use core::cell::UnsafeCell;
use core::future::Future;
use core::ops::{Deref, DerefMut};
use core::sync::atomic::{AtomicBool, Ordering};
use core::task::{Context, Poll};
use atomic_waker::AtomicWaker;
/// An asynchronous mutual-exclusion lock protecting a value of type `T`.
///
/// Acquire it with [`LocalLock::try_lock`] (non-blocking) or by awaiting
/// [`LocalLock::lock`]. Either way you get a [`LocalLockGuard`], which gives access to
/// the value and releases the lock when dropped.
///
/// NOTE(review): `waker` is a single `AtomicWaker` slot. If several tasks await `lock()`
/// at the same time, each `register` replaces the previously stored waker, and the
/// displaced task is not woken by `unlock`. It may then stall until it is polled for
/// some other reason. Confirm that at most one waiter is expected at a time, or replace
/// the slot with a waiter queue.
pub struct LocalLock<T> {
    // `true` while a `LocalLockGuard` for this lock is alive.
    held: AtomicBool,
    // Waker of the task currently parked in `LockFuture::poll`, woken on unlock.
    waker: AtomicWaker,
    // The protected value; only dereferenced through a guard while `held` is `true`.
    value: UnsafeCell<T>,
}

// SAFETY: the lock owns its `T`, so moving the lock to another thread moves the value;
// that is sound exactly when `T: Send`.
unsafe impl<T> Send for LocalLock<T> where T: Send {}

// SAFETY: the `held` flag admits at most one guard at a time, so a shared `&LocalLock`
// can only hand the value to one thread at a time. That is the same contract as
// `std::sync::Mutex`, which requires only `T: Send` to be `Sync`.
unsafe impl<T> Sync for LocalLock<T> where T: Send {}
impl<T> LocalLock<T> {
    /// Creates a new, unlocked lock wrapping `value`.
    pub const fn new(value: T) -> Self {
        Self {
            held: AtomicBool::new(false),
            waker: AtomicWaker::new(),
            value: UnsafeCell::new(value),
        }
    }

    /// Consumes the lock and returns the protected value.
    ///
    /// Taking `self` by value proves that no guard or pending [`LockFuture`] still
    /// borrows the lock, so no locking is needed.
    pub fn into_inner(self) -> T {
        self.value.into_inner()
    }

    /// Returns a mutable reference to the protected value without locking.
    ///
    /// `&mut self` statically guarantees that no guard or pending [`LockFuture`]
    /// borrows the lock, so no runtime check is needed.
    pub fn get_mut(&mut self) -> &mut T {
        self.value.get_mut()
    }

    /// Attempts to acquire the lock without waiting.
    ///
    /// Returns `Some(guard)` if the lock was free, or `None` if it is currently held.
    #[must_use = "discarding the guard acquires and immediately releases the lock"]
    pub fn try_lock(&self) -> Option<LocalLockGuard<'_, T>> {
        // Acquire on success pairs with the Release store in `unlock`, so writes made
        // under the previous guard are visible to this one. A failed attempt reads
        // nothing guarded, so Relaxed is enough there.
        self.held
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .ok()
            .map(|_| LocalLockGuard { lock: self })
    }

    /// Returns a future that resolves to a guard once the lock is acquired.
    ///
    /// Dropping the future before it completes does not acquire the lock.
    pub fn lock(&self) -> LockFuture<'_, T> {
        LockFuture { lock: self }
    }

    /// Releases the lock and wakes the registered waiter, if any.
    ///
    /// Called only from `LocalLockGuard::drop`.
    fn unlock(&self) {
        // Clear the flag before waking, so the woken task finds the lock free when it
        // re-polls and its `try_lock` can succeed.
        self.held.store(false, Ordering::Release);
        self.waker.wake();
    }
}
/// RAII guard giving exclusive access to the value inside a [`LocalLock`].
///
/// Dereferences to `T`; the lock is released when the guard is dropped.
#[must_use = "if unused the LocalLock will immediately unlock"]
pub struct LocalLockGuard<'a, T> {
    lock: &'a LocalLock<T>,
}

// SAFETY: a shared `&LocalLockGuard` only exposes `&T` (via `Deref`), so sharing the
// guard across threads is sound exactly when `T: Sync`. Without this explicit impl the
// compiler would derive `Sync` from the `&LocalLock<T>` field, i.e. whenever `T: Send`.
// That would let two threads hold `&T` to a `Send + !Sync` value (such as `Cell`) at
// the same time. Declaring the impl replaces the auto-derived one, matching the bounds
// of `std::sync::MutexGuard`.
unsafe impl<T: Sync> Sync for LocalLockGuard<'_, T> {}
impl<T> Deref for LocalLockGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for as long as the guard is borrowed.
    fn deref(&self) -> &T {
        let raw: *mut T = self.lock.value.get();
        // SAFETY: a guard exists only while `held` is `true`, and only one guard can
        // exist at a time, so no `&mut T` to the value can be live concurrently.
        unsafe { &*raw }
    }
}
impl<T> DerefMut for LocalLockGuard<'_, T> {
    /// Mutably borrows the protected value for as long as the guard is mutably borrowed.
    fn deref_mut(&mut self) -> &mut T {
        let raw: *mut T = self.lock.value.get();
        // SAFETY: this guard is the only one for the lock (`held` is `true`), and
        // `&mut self` makes this the only live borrow derived from it, so the
        // exclusive reference cannot alias.
        unsafe { &mut *raw }
    }
}
impl<T> Drop for LocalLockGuard<'_, T> {
fn drop(&mut self) {
self.lock.unlock();
}
}
/// Future returned by [`LocalLock::lock`]; resolves to a [`LocalLockGuard`].
///
/// `#[must_use]` is spelled out here because a concrete future type does not inherit
/// the `Future` trait's `must_use` lint. Without it, a stray `lock.lock();` would
/// compile silently and never acquire anything.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct LockFuture<'a, T> {
    lock: &'a LocalLock<T>,
}
impl<'a, T> Future for LockFuture<'a, T> {
type Output = LocalLockGuard<'a, T>;
fn poll(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(g) = self.lock.try_lock() {
return Poll::Ready(g);
}
self.lock.waker.register(cx.waker());
if let Some(g) = self.lock.try_lock() {
return Poll::Ready(g);
}
Poll::Pending
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::pin::pin;
    use core::task::{RawWaker, RawWakerVTable, Waker};

    /// Builds a waker whose clone/wake/drop do nothing; the tests drive futures by
    /// polling them directly.
    fn noop_waker() -> Waker {
        const VT: RawWakerVTable = RawWakerVTable::new(|_| RAW, |_| {}, |_| {}, |_| {});
        const RAW: RawWaker = RawWaker::new(core::ptr::null(), &VT);
        // SAFETY: no vtable entry dereferences the (null) data pointer, and `clone`
        // returns another instance of the same no-op waker.
        unsafe { Waker::from_raw(RAW) }
    }

    /// Busy-polls `fut` to completion using a no-op waker.
    fn run<F: Future>(fut: F) -> F::Output {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        // `pin!` pins the future on the stack without any `unsafe`.
        let mut fut = pin!(fut);
        loop {
            if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
                return v;
            }
        }
    }

    #[test]
    fn lock_and_unlock() {
        let lock = LocalLock::new(42i32);
        {
            let mut g = run(lock.lock());
            *g += 1;
        }
        let g = run(lock.lock());
        assert_eq!(*g, 43);
    }

    #[test]
    fn try_lock_returns_none_when_held() {
        let lock = LocalLock::new(0u8);
        let g = lock.try_lock().unwrap();
        assert!(lock.try_lock().is_none());
        drop(g);
        assert!(lock.try_lock().is_some());
    }

    #[test]
    fn lock_future_becomes_ready_after_drop() {
        let lock = LocalLock::new(0u8);
        let g = lock.try_lock().unwrap();
        let mut fut = pin!(lock.lock());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        assert!(matches!(fut.as_mut().poll(&mut cx), Poll::Pending));
        drop(g);
        assert!(
            matches!(fut.as_mut().poll(&mut cx), Poll::Ready(_)),
            "expected ready after unlock"
        );
    }

    #[test]
    fn into_inner_returns_last_written_value() {
        let lock = LocalLock::new(1u32);
        *lock.try_lock().unwrap() = 7;
        assert_eq!(lock.into_inner(), 7);
    }

    #[test]
    fn lock_and_guard_are_send_and_sync_for_sync_data() {
        fn assert_send_sync<X: Send + Sync>() {}
        assert_send_sync::<LocalLock<u32>>();
        assert_send_sync::<LocalLockGuard<'static, u32>>();
    }
}