use crate::{
collection::{Deque, backward_deque_idx},
misc::_unlikely_unreachable,
sync::{AtomicUsize, SyncMutex, sync_mutex::SyncMutexGuard},
};
use core::{
cell::UnsafeCell,
fmt::{Debug, Formatter},
mem,
ops::{Deref, DerefMut},
pin::Pin,
sync::atomic::Ordering,
task::{Context, Poll, Waker},
};
/// Bit of the mutex `state` word that is raised when a waiter is registered and cleared once
/// `Waiters::waiting_count` drops to zero or the waiter queue is found empty.
const HAS_WAITERS: usize = 0b10;
/// Bit of the mutex `state` word that is raised while a guard exists (set by `try_lock`,
/// cleared when the guard is dropped).
const IS_LOCKED: usize = 0b1;
/// Mutual-exclusion primitive whose `lock` operation is a future instead of a blocking call.
///
/// The protected value is reached through an `AsyncMutexGuard`, obtained from `lock` or
/// `try_lock`.
pub struct AsyncMutex<T> {
/// Bit set composed of `IS_LOCKED` and `HAS_WAITERS`.
state: AtomicUsize,
/// The protected value. Shared access goes through a guard; exclusive access is also possible
/// through `get_mut` and `into_inner`.
value: UnsafeCell<T>,
/// Bookkeeping for tasks whose lock attempt did not succeed, protected by a synchronous mutex.
waiters: SyncMutex<Waiters>,
}
impl<T> AsyncMutex<T> {
    /// Creates an unlocked mutex that owns `t` and has an empty waiter queue.
    #[inline]
    pub const fn new(t: T) -> Self {
        Self {
            state: AtomicUsize::new(0),
            value: UnsafeCell::new(t),
            waiters: SyncMutex::new(Waiters {
                added: 0,
                deque: Deque::new(),
                last_added: 0,
                waiting_count: 0,
            }),
        }
    }

    /// Returns a mutable reference to the protected value.
    ///
    /// No locking takes place: holding `&mut self` statically proves that no guard or pending
    /// lock future can exist at the same time.
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        // `UnsafeCell::get_mut` is the safe equivalent of dereferencing `UnsafeCell::get` when
        // the cell itself is exclusively borrowed.
        self.value.get_mut()
    }

    /// Consumes the mutex and returns the protected value.
    #[inline]
    pub fn into_inner(self) -> T {
        self.value.into_inner()
    }

    /// Returns a future that resolves to a guard once the lock has been acquired.
    ///
    /// Calling this method only builds the future; the lock is attempted when it is polled.
    #[inline]
    pub const fn lock(&self) -> AsyncMutexGuardFuture<'_, T> {
        AsyncMutexGuardFuture { idx_opt: None, mutex_opt: Some(self) }
    }

    /// Attempts to acquire the lock without waiting.
    ///
    /// Returns `None` when the `IS_LOCKED` bit was already set, otherwise a guard that clears
    /// the bit again when dropped.
    #[inline]
    pub fn try_lock(&self) -> Option<AsyncMutexGuard<'_, T>> {
        let prev = self.state.fetch_or(IS_LOCKED, Ordering::Acquire);
        if is_locked(prev) { None } else { Some(AsyncMutexGuard { mutex: self }) }
    }
}
#[expect(clippy::missing_fields_in_debug, reason = "best effort")]
impl<T> Debug for AsyncMutex<T> {
    /// Prints the two state flags only.
    ///
    /// The protected value is deliberately left out because reading it would require holding
    /// the lock.
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let state = self.state.load(Ordering::SeqCst);
        // The struct name matches the type name so that nested output, such as the one produced
        // by `AsyncMutexGuard`, is not misleading.
        f.debug_struct("AsyncMutex")
            .field("is_locked", &is_locked(state))
            .field("has_waiters", &has_waiters(state))
            .finish()
    }
}
// SAFETY: sending the mutex to another thread also sends the owned `T`, hence `T: Send`.
unsafe impl<T: Send> Send for AsyncMutex<T> {}
// SAFETY: through `&AsyncMutex<T>` the value is only handed out by a guard, and `try_lock`
// creates a guard only when `IS_LOCKED` was previously clear, so at most one thread touches `T`
// at a time. That is why `T: Send` is required instead of `T: Sync` (same bound as
// `std::sync::Mutex`).
unsafe impl<T: Send> Sync for AsyncMutex<T> {}
/// Proof that the lock of an `AsyncMutex` is held.
///
/// Dereferences to the protected value. Dropping the guard clears `IS_LOCKED` and, if waiters
/// were registered, wakes one of them.
#[clippy::has_significant_drop]
pub struct AsyncMutexGuard<'any, T> {
/// The mutex this guard unlocks on drop.
mutex: &'any AsyncMutex<T>,
}
impl<T> Debug for AsyncMutexGuard<'_, T>
where
    T: Debug,
{
    /// Prints the state of the owning mutex followed by the protected value.
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        // The guard proves the lock is held, so the value can be read for display purposes.
        let value: &T = &**self;
        let mut builder = f.debug_struct("AsyncMutexGuard");
        let _ = builder.field("mutex", &self.mutex);
        let _ = builder.field("value", &value);
        builder.finish()
    }
}
impl<T> Deref for AsyncMutexGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected value for as long as the guard is borrowed.
    #[inline]
    fn deref(&self) -> &T {
        let value_ptr: *const T = self.mutex.value.get();
        // SAFETY: a guard is only created after `IS_LOCKED` was acquired and the bit is cleared
        // when the guard is dropped, so no other guard can reach the value while `self` lives.
        unsafe { &*value_ptr }
    }
}
impl<T> DerefMut for AsyncMutexGuard<'_, T> {
    /// Exclusive access to the protected value for as long as the guard is mutably borrowed.
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        let value_ptr: *mut T = self.mutex.value.get();
        // SAFETY: a guard is only created after `IS_LOCKED` was acquired, so it is the only path
        // to the value, and `&mut self` guarantees that the guard itself is not aliased.
        unsafe { &mut *value_ptr }
    }
}
impl<T> Drop for AsyncMutexGuard<'_, T> {
#[inline]
fn drop(&mut self) {
let prev = self.mutex.state.fetch_and(!IS_LOCKED, Ordering::AcqRel);
if has_waiters(prev) {
wake(&self.mutex.state, self.mutex.waiters.lock());
}
}
}
// SAFETY: a guard hands out `&mut T` through `DerefMut`, so moving it to another thread
// requires `T: Send`.
unsafe impl<T: Send> Send for AsyncMutexGuard<'_, T> {}
// SAFETY: a shared guard hands out `&T` through `Deref`, so sharing it across threads requires
// `T: Sync`.
unsafe impl<T: Sync> Sync for AsyncMutexGuard<'_, T> {}
/// Future returned by `AsyncMutex::lock` that resolves to an `AsyncMutexGuard`.
#[derive(Debug)]
pub struct AsyncMutexGuardFuture<'mutex, T> {
/// Ticket of this future in the waiter queue. `None` until a poll fails to take the lock and
/// registers a waker.
idx_opt: Option<usize>,
/// The mutex being locked. Set to `None` when the future completes, which also turns the
/// `Drop` implementation into a no-op.
mutex_opt: Option<&'mutex AsyncMutex<T>>,
}
impl<T> Drop for AsyncMutexGuardFuture<'_, T> {
    /// Withdraws a still pending future from the waiter queue.
    ///
    /// If the future had already been selected for wake-up but is dropped before acquiring the
    /// lock, the notification is forwarded so that it is not lost.
    #[inline]
    fn drop(&mut self) {
        // Completed futures clear `mutex_opt`; never-parked futures have no ticket.
        let Some(mutex) = self.mutex_opt else {
            return;
        };
        let Some(idx) = self.idx_opt else {
            return;
        };
        let mut waiters = mutex.waiters.lock();
        let was_woken = matches!(remove_waker(idx, &mutex.state, &mut waiters), Some(Waiter::Woken));
        if was_woken {
            wake(&mutex.state, waiters);
        }
    }
}
impl<'any, T> Future for AsyncMutexGuardFuture<'any, T> {
type Output = AsyncMutexGuard<'any, T>;
/// Tries to take the lock. On failure the task's waker is stored in the waiter queue and the
/// lock is attempted a second time before `Poll::Pending` is returned.
#[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// `mutex_opt` is only `None` after completion, i.e. when the future is polled again after
// having returned `Poll::Ready`.
let Some(mutex) = self.mutex_opt else { _unlikely_unreachable() };
// Fast path: the lock is free.
if let Some(mutex_guard) = mutex.try_lock() {
// A previous poll may have parked this future; release its slot.
if let Some(idx) = self.idx_opt {
drop(remove_waker(idx, &mutex.state, &mut mutex_guard.mutex.waiters.lock()));
}
// Mark the future as finished so that `Drop` does not touch the queue again.
self.mutex_opt = None;
return Poll::Ready(mutex_guard);
}
let AsyncMutex { state, waiters, value: _ } = mutex;
let mut waiters_guard = waiters.lock();
if let Some(idx) = self.idx_opt {
// Already has a ticket: refresh the stored waker (the slot may have been marked `Woken`).
let Waiters { added: _, deque, last_added, waiting_count } = &mut *waiters_guard;
let actual_idx = backward_deque_idx(idx, *last_added);
if let Some(elem) = deque.get_mut(actual_idx) {
elem.register(waiting_count, cx.waker());
if *waiting_count > 0 {
let _ = state.fetch_or(HAS_WAITERS, Ordering::Relaxed);
}
}
} else {
// First failed attempt: take the next ticket and queue a new slot at the front.
waiters_guard.last_added = waiters_guard.added;
self.idx_opt = Some(waiters_guard.last_added);
waiters_guard.added = waiters_guard.added.wrapping_add(1);
// NOTE(review): the result of `push_front` is ignored. If it can fail, the ticket and
// `waiting_count` are still updated without a matching slot — confirm this is acceptable.
let _rslt = waiters_guard.deque.push_front(Waiter::Waiting(cx.waker().clone()));
waiters_guard.waiting_count = waiters_guard.waiting_count.wrapping_add(1);
if waiters_guard.waiting_count == 1 {
let _ = state.fetch_or(HAS_WAITERS, Ordering::Relaxed);
}
}
// Second attempt: the guard may have been dropped between the first `try_lock` and the
// registration above, in which case nobody would wake this task.
if let Some(mutex_guard) = mutex.try_lock() {
if let Some(idx) = self.idx_opt {
drop(remove_waker(idx, &mutex.state, &mut waiters_guard));
}
drop(waiters_guard);
self.mutex_opt = None;
return Poll::Ready(mutex_guard);
}
Poll::Pending
}
}
// SAFETY: the future only stores an index and a shared reference to the mutex, and
// `AsyncMutex<T>` is `Sync` when `T: Send`.
// NOTE(review): with these field types the auto traits look like they would yield the same
// bounds — confirm whether the manual implementations are still required.
unsafe impl<T: Send> Send for AsyncMutexGuardFuture<'_, T> {}
// SAFETY: same reasoning as above; a shared reference to the future exposes nothing beyond the
// shared reference to the mutex.
unsafe impl<T: Send> Sync for AsyncMutexGuardFuture<'_, T> {}
/// State of one slot of the waiter queue.
#[derive(Debug)]
enum Waiter {
    /// The owning future finished or was dropped; the slot only awaits removal from the queue.
    Removed,
    /// The owning future is parked and must be notified through the stored waker.
    Waiting(Waker),
    /// The stored waker was taken and is about to be (or already was) invoked.
    Woken,
}

impl Waiter {
    /// Stores `waker` in this slot.
    ///
    /// A slot that was not in the `Waiting` state becomes `Waiting` and `waiting_count` is
    /// incremented (wrapping). A slot that was already `Waiting` only has its waker refreshed
    /// and the counter is left untouched.
    #[inline]
    fn register(&mut self, waiting_count: &mut usize, waker: &Waker) {
        match self {
            Self::Waiting(stored) => stored.clone_from(waker),
            Self::Removed | Self::Woken => {
                let parked = Self::Waiting(waker.clone());
                *waiting_count = waiting_count.wrapping_add(1);
                *self = parked;
            }
        }
    }
}
/// Queue of parked lock attempts together with the counters used to address its slots.
#[derive(Debug)]
struct Waiters {
/// Ticket that the next newly parked future receives; advanced with wrapping arithmetic on
/// every registration.
added: usize,
/// Waiter slots. New slots are pushed to the front and wake-ups are served from the back.
deque: Deque<Waiter>,
/// Ticket of the most recently queued slot; passed along with a future's ticket to
/// `backward_deque_idx` to locate that future's slot.
last_added: usize,
/// Number of slots currently in the `Waiting` state; `HAS_WAITERS` is cleared when it reaches
/// zero.
waiting_count: usize,
}
/// Tells whether the `IS_LOCKED` bit is raised in `state`.
#[inline]
const fn is_locked(state: usize) -> bool {
    let lock_bit = state & IS_LOCKED;
    lock_bit != 0
}
/// Tells whether the `HAS_WAITERS` bit is raised in `state`.
#[inline]
const fn has_waiters(state: usize) -> bool {
    let waiters_bit = state & HAS_WAITERS;
    waiters_bit != 0
}
/// Marks the slot that belongs to ticket `idx` as `Removed` and returns its previous content.
///
/// Returns `None` when the queue has no slot at the computed position. If the slot was
/// `Waiting`, `waiting_count` is decremented and `HAS_WAITERS` is cleared once it hits zero.
#[inline]
fn remove_waker(idx: usize, state: &AtomicUsize, waiters: &mut Waiters) -> Option<Waiter> {
    let position = backward_deque_idx(idx, waiters.last_added);
    let slot = waiters.deque.get_mut(position)?;
    let previous = mem::replace(slot, Waiter::Removed);
    let was_waiting = matches!(previous, Waiter::Waiting(_));
    if was_waiting {
        let remaining = waiters.waiting_count.wrapping_sub(1);
        waiters.waiting_count = remaining;
        if remaining == 0 {
            let _ = state.fetch_and(!HAS_WAITERS, Ordering::Relaxed);
        }
    }
    Some(previous)
}
/// Notifies the oldest parked waiter, discarding `Removed` slots found at the back of the queue.
///
/// The waiters lock is released before the waker is invoked. Nothing is woken when the queue is
/// empty or when the oldest slot has already been marked as `Woken`.
#[inline]
fn wake(state: &AtomicUsize, mut waiters: SyncMutexGuard<'_, Waiters>) {
    let clear_has_waiters = || {
        let _ = state.fetch_and(!HAS_WAITERS, Ordering::Relaxed);
    };
    let chosen = loop {
        // The back of the queue holds the oldest slot; it is unconditionally marked as woken.
        let previous = match waiters.deque.last_mut() {
            Some(oldest) => mem::replace(oldest, Waiter::Woken),
            None => {
                clear_has_waiters();
                break None;
            }
        };
        match previous {
            Waiter::Waiting(waker) => {
                let remaining = waiters.waiting_count.wrapping_sub(1);
                waiters.waiting_count = remaining;
                if remaining == 0 {
                    clear_has_waiters();
                }
                break Some(waker);
            }
            Waiter::Woken => break None,
            Waiter::Removed => {
                // Nobody owns this slot anymore: discard it and inspect the next oldest one.
                let _discarded = waiters.deque.pop_back();
            }
        }
    };
    drop(waiters);
    if let Some(waker) = chosen {
        waker.wake();
    }
}
#[cfg(test)]
mod tests {
use crate::{
executor::Runtime,
misc::PollOnce,
sync::{
Arc, AsyncMutex,
async_mutex::{has_waiters, is_locked},
},
};
use core::sync::atomic::Ordering;
// Many spawned tasks increment a shared counter behind the mutex; the final value must equal
// the number of tasks and the mutex must end up unlocked with no waiters.
#[test]
fn competition() {
let (tx, rx) = std::sync::mpsc::channel();
let mutex = Arc::new(AsyncMutex::new(0));
let num_threads = 1000;
let runtime = Runtime::new();
let tx = Arc::new(tx);
for _ in 0..num_threads {
let tx = tx.clone();
let mutex = mutex.clone();
let _fut = runtime
.spawn_threaded(async move {
let mut guard = mutex.lock().await;
*guard += 1;
tx.send(()).unwrap();
})
.unwrap();
}
runtime.block_on(async {
// Each task sends one message after incrementing the counter.
for _ in 0..num_threads {
rx.recv().unwrap();
}
assert_eq!(num_threads, *mutex.lock().await);
});
// Presumably gives the spawned tasks time to drop their guards before the state is
// inspected — TODO confirm.
std::thread::sleep(std::time::Duration::from_millis(500));
check_mutex(&mutex);
}
// Repeated uncontended lock/unlock cycles must leave the mutex in its initial state.
#[test]
fn sequential() {
Runtime::new().block_on(async {
let mutex = AsyncMutex::new(());
for _ in 0..10 {
let _guard = mutex.lock().await;
}
check_mutex(&mutex);
});
}
// A lock future that is pending while the mutex is held must complete once the guard is dropped.
#[test]
fn wakes_waiter() {
Runtime::new().block_on(async {
let mutex = AsyncMutex::new(());
{
let lock0 = mutex.lock().await;
let mut lock1_fut = mutex.lock();
assert!(PollOnce::new(&mut lock1_fut).await.is_none());
drop(lock0);
assert!(PollOnce::new(&mut lock1_fut).await.is_some());
}
check_mutex(&mutex);
});
}
// Asserts that the mutex is unlocked, has no `HAS_WAITERS` flag and counts zero waiting slots.
fn check_mutex<T>(mutex: &AsyncMutex<T>) {
let state = mutex.state.load(Ordering::Relaxed);
let waiters = mutex.waiters.lock();
assert!(!has_waiters(state));
assert!(!is_locked(state));
assert_eq!(waiters.waiting_count, 0);
}
}