use core::pin::pin;
use core::sync::atomic::Ordering::*;
use core::{
cell::UnsafeCell,
marker::PhantomData,
mem,
ops::{Deref, DerefMut},
pin::Pin,
ptr,
sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize},
task::{Poll, Waker},
};
use alloc::boxed::Box;
use super::q::Queue;
use crate::Guard;
use crate::sync::{Waiter, Waiters};
/// A lock around a `T`, built from an atomic "held" flag plus a queue of waiters.
///
/// The protected value is reached through `Guard<'_, Mutex<T>>`, which is handed
/// out by `try_lock` or by awaiting the future returned from `lock`.
pub struct Mutex<T> {
/// `true` while the lock is held; flipped from `false` to `true` by compare-exchange.
locked: AtomicBool,
/// Waiters registered through `enqueue`.
waiters: Waiters,
/// The protected value; dereferenced through the `Deref`/`DerefMut` impls on `Guard`.
value: UnsafeCell<T>,
}
// SAFETY: both impls require only `T: Send`. In this file `value` is dereferenced
// solely through a `Guard`, and the guards built here follow a successful
// compare-exchange on `locked`.
// NOTE(review): soundness also depends on `Waiters` being safe to share across
// threads and on how `Guard` releases the lock — neither is visible here; confirm.
unsafe impl<T: Send> Send for Mutex<T> {}
unsafe impl<T: Send> Sync for Mutex<T> {}
/// Shared access to the value protected by the mutex this guard refers to.
impl<'a, T> Deref for Guard<'a, Mutex<T>> {
    type Target = T;

    /// Borrows the protected value immutably for as long as the guard is borrowed.
    fn deref(&self) -> &Self::Target {
        let value: *mut T = self.0.value.get();
        // SAFETY: `UnsafeCell::get` never returns null, and the pointee lives as long
        // as the mutex the guard borrows.
        // NOTE(review): absence of a concurrent `&mut T` relies on the guard being the
        // sole holder of the lock — confirm against `Guard`'s construction sites.
        unsafe { &*value }
    }
}
/// Exclusive access to the value protected by the mutex this guard refers to.
impl<'a, T> DerefMut for Guard<'a, Mutex<T>> {
    /// Borrows the protected value mutably for as long as the guard is mutably borrowed.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let value: *mut T = self.0.value.get();
        // SAFETY: `UnsafeCell::get` never returns null, and the pointee lives as long
        // as the mutex the guard borrows.
        // NOTE(review): uniqueness of this `&mut T` relies on the guard being the sole
        // holder of the lock — confirm against `Guard`'s construction sites.
        unsafe { &mut *value }
    }
}
impl<T> Mutex<T> {
    /// Creates an unlocked mutex that owns `value`, with a freshly created waiter queue.
    pub const fn new(value: T) -> Self {
        Self {
            locked: AtomicBool::new(false),
            waiters: Queue::new(),
            value: UnsafeCell::new(value),
        }
    }

    /// Tries to take the lock without waiting.
    ///
    /// Returns `Some(guard)` when the `locked` flag was flipped from `false` to
    /// `true` (with `Acquire` ordering on success), and `None` when it was
    /// already set.
    pub fn try_lock(&self) -> Option<Guard<'_, Self>> {
        match self.locked.compare_exchange(false, true, Acquire, Relaxed) {
            Ok(_) => Some(Guard(self)),
            Err(_) => None,
        }
    }

    /// Returns a boxed, pinned future that resolves to a guard for this mutex.
    ///
    /// `waiter` is stored in the future and used by its `poll` implementation
    /// when the lock cannot be taken immediately.
    pub fn lock(
        &self,
        waiter: Waiter,
    ) -> Lock<'_, Self, impl Future<Output = Guard<'_, Self>> + Send>
    where
        T: Send,
    {
        let key = Key {
            lock: self,
            node: waiter.into(),
            queued: false,
            marker: PhantomData,
        };
        Lock(Box::pin(key))
    }

    /// Adds `waiter` to this mutex's waiter queue.
    pub fn enqueue(&self, waiter: Waiter) {
        // SAFETY: NOTE(review): the safety contract of the queue's `enqueue` is not
        // visible from this file — confirm what callers must uphold.
        unsafe {
            self.waiters.enqueue(waiter);
        }
    }
}
/// Future that resolves to a `Guard` for `lock`; constructed by `Mutex::lock`.
pub struct Key<'a, T> {
/// The mutex being acquired.
lock: &'a Mutex<T>,
/// Waiter supplied by the caller of `Mutex::lock`; `poll` takes it out of the
/// option (leaving `None`) the first time the immediate `try_lock` fails.
node: Option<Waiter>,
/// Set once `poll` has handed the waiter to `Mutex::enqueue`.
queued: bool,
/// Zero-sized marker typed as `&'a mut T`.
/// NOTE(review): presumably models exclusive access to `T` for `'a` — confirm intent.
marker: PhantomData<&'a mut T>,
}
impl<'a, T> Future for Key<'a, T> {
    type Output = Guard<'a, Mutex<T>>;

    /// Drives acquisition of the mutex.
    ///
    /// 1. Fast path: if `try_lock` succeeds, the guard it produced is returned
    ///    directly, so no temporary guard is created and dropped along the way.
    /// 2. First contended poll: the stored waiter is given the current waker. If
    ///    it already reports `signaled()`, the lock flag is retried while it reads
    ///    as free. Otherwise the waiter is handed to `Mutex::enqueue` and the
    ///    future returns `Pending`.
    /// 3. Later contended polls: the waiter is already owned by the queue, so this
    ///    future cannot re-register a waker through it. It wakes the current task
    ///    itself and returns `Pending`, which guarantees another `try_lock` attempt
    ///    instead of panicking or stalling.
    ///
    /// # Panics
    ///
    /// Panics with "lock() requires a waiter" if the future holds no waiter and
    /// has never enqueued one, which would be a construction bug.
    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        let Self {
            lock, node, queued, ..
        } = &mut *self;
        // Copy the shared reference out so every guard built below carries `'a`.
        let mutex: &'a Mutex<T> = *lock;

        // Return the guard `try_lock` produced. Building a second guard after
        // discarding the first would hand out a guard for a lock that a
        // drop-releasing `Guard` (its `Drop` is not visible here) already gave up.
        if let Some(guard) = mutex.try_lock() {
            return Poll::Ready(guard);
        }

        let Some(node) = node.take() else {
            // No waiter left: legitimate only if an earlier poll enqueued it.
            assert!(*queued, "lock() requires a waiter");
            // NOTE(review): the queued waiter still holds the waker from the poll
            // that enqueued it. Self-waking keeps this future live on re-polls; if
            // the waiter API allows refreshing the waker in place, prefer that.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        };

        node.assign_waker(cx.waker().clone());

        if node.signaled() {
            // `compare_exchange_weak` may fail spuriously, so retry for as long as
            // the flag still reads as free.
            while !mutex.locked.load(Relaxed) {
                if mutex
                    .locked
                    .compare_exchange_weak(false, true, Acquire, Relaxed)
                    .is_ok()
                {
                    return Poll::Ready(Guard(mutex));
                }
            }
        }

        if !*queued {
            // `Mutex::enqueue` is a safe function; no `unsafe` block is required.
            mutex.enqueue(node);
            *queued = true;
        }
        Poll::Pending
    }
}
/// `Lock` specialised to a `Mutex<T>` whose inner future is a `Key`.
pub type LockKey<'a, T> = Lock<'a, Mutex<T>, Key<'a, T>>;
/// Wrapper around a boxed, pinned future `F` whose output is a `Guard<'a, T>`.
///
/// Polling a `Lock` polls the wrapped future.
pub struct Lock<'a, T: 'a, F>(Pin<Box<F>>)
where
F: Future<Output = Guard<'a, T>>;
/// Forwards polling to the boxed future stored inside the `Lock`.
impl<'a, T: 'a, F: Future<Output = Guard<'a, T>>> Future for Lock<'a, T, F> {
    type Output = F::Output;

    /// Polls the wrapped future once and returns its result unchanged.
    fn poll(mut self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
        // The inner future is already pinned on the heap; re-borrow it as
        // `Pin<&mut F>` and delegate.
        let inner: Pin<&mut F> = self.0.as_mut();
        inner.poll(cx)
    }
}