// moduvex_runtime/sync/mutex.rs
1//! Async mutex — cooperative mutual exclusion for async tasks.
2//!
3//! Unlike `std::sync::Mutex`, locking suspends the calling task (yields back
4//! to the executor) instead of blocking the OS thread. This is critical inside
5//! async contexts where blocking would starve other tasks sharing the thread.
6//!
7//! # Design
8//! - Inner value protected by a `std::sync::Mutex` for the critical section
9//!   of updating waker queues and the locked flag.
10//! - A `VecDeque<Waker>` wait queue ensures FIFO fairness across contenders.
11//! - `MutexGuard` drops the lock and wakes the next waiter on `Drop`.
12
13use std::cell::UnsafeCell;
14use std::collections::VecDeque;
15use std::future::Future;
16use std::ops::{Deref, DerefMut};
17use std::pin::Pin;
18use std::sync::{Arc, Mutex as StdMutex};
19use std::task::{Context, Poll, Waker};
20
21// ── Inner state ───────────────────────────────────────────────────────────────
22
/// Shared lock state, always accessed under the outer `StdMutex`.
struct Inner<T> {
    /// Whether the async lock is currently held by a `MutexGuard`.
    /// Set in `LockFuture::poll`, cleared in `MutexGuard::drop`.
    locked: bool,
    /// Tasks waiting to acquire the lock, in arrival order (FIFO).
    /// `MutexGuard::drop` wakes exactly one entry from the front.
    waiters: VecDeque<Waker>,
    /// The protected value.
    ///
    /// `UnsafeCell` allows mutation through a shared `Arc<Inner<T>>`.
    /// Safe because access is serialised: only the current `MutexGuard`
    /// holder may dereference this pointer, and there is at most one guard
    /// alive at a time (enforced by `locked`).
    value: UnsafeCell<T>,
}
36
// SAFETY: `Inner<T>` must be `Send + Sync` when `T: Send` so the
// `Arc<StdMutex<Inner<T>>>` can be shared across tasks and threads.
// (The comment previously said `Mutex<T>`; these impls are for `Inner<T>`.)
// `UnsafeCell<T>` would normally forbid `Sync`, but every access to the cell
// is serialised by the `locked` flag inside the `StdMutex`, so at most one
// thread touches the value at a time — the same reasoning that lets
// `std::sync::Mutex<T>: Sync` require only `T: Send`.
unsafe impl<T: Send> Send for Inner<T> {}
unsafe impl<T: Send> Sync for Inner<T> {}
42
43// ── Mutex ─────────────────────────────────────────────────────────────────────
44
/// Async-aware mutual exclusion primitive.
///
/// Wraps a value of type `T`; concurrent tasks suspend (not block) while
/// waiting for the lock. Cheap to share: cloning the inner `Arc` is how
/// guards keep the state alive.
pub struct Mutex<T> {
    // All state lives behind a std mutex so waker-queue updates and the
    // `locked` flag change atomically with respect to other threads.
    inner: Arc<StdMutex<Inner<T>>>,
}
52
53impl<T> Mutex<T> {
54    /// Create a new `Mutex` wrapping `value`.
55    pub fn new(value: T) -> Self {
56        Self {
57            inner: Arc::new(StdMutex::new(Inner {
58                locked: false,
59                waiters: VecDeque::new(),
60                value: UnsafeCell::new(value),
61            })),
62        }
63    }
64
65    /// Acquire the lock asynchronously, returning a `MutexGuard<T>`.
66    ///
67    /// The returned future suspends if the lock is already held and resumes
68    /// once the previous holder's `MutexGuard` is dropped.
69    pub fn lock(&self) -> LockFuture<'_, T> {
70        LockFuture { inner: &self.inner }
71    }
72}
73
74// ── LockFuture ────────────────────────────────────────────────────────────────
75
/// Future returned by [`Mutex::lock`].
///
/// Holds only a borrow of the shared state; dropping it before completion
/// abandons the acquisition attempt.
pub struct LockFuture<'a, T> {
    // Borrowed (not cloned) so an un-awaited `lock()` call allocates nothing.
    inner: &'a Arc<StdMutex<Inner<T>>>,
}
80
81impl<T> Future for LockFuture<'_, T> {
82    type Output = MutexGuard<T>;
83
84    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85        let mut g = self.inner.lock().unwrap();
86        if !g.locked {
87            g.locked = true;
88            // Cache the pointer to the protected value at lock acquisition time.
89            let value_ptr = g.value.get();
90            Poll::Ready(MutexGuard {
91                inner: Arc::clone(self.inner),
92                value_ptr,
93            })
94        } else {
95            g.waiters.push_back(cx.waker().clone());
96            Poll::Pending
97        }
98    }
99}
100
101// ── MutexGuard ────────────────────────────────────────────────────────────────
102
/// RAII guard that releases the async lock on drop and wakes the next waiter.
pub struct MutexGuard<T> {
    // Keeps the `Inner` allocation alive and lets `Drop` reach the state.
    inner: Arc<StdMutex<Inner<T>>>,
    /// Cached raw pointer to the protected value. Avoids acquiring the
    /// StdMutex on every deref. Valid for the lifetime of this guard because:
    /// - The Arc keeps the Inner allocation alive.
    /// - The async `locked` flag prevents concurrent mutation.
    value_ptr: *mut T,
}
112
113// SAFETY: MutexGuard<T> is Send+Sync when T: Send because:
114// - The async lock serialises all access to the value.
115// - The raw pointer comes from UnsafeCell inside an Arc (heap-stable).
116unsafe impl<T: Send> Send for MutexGuard<T> {}
117unsafe impl<T: Send> Sync for MutexGuard<T> {}
118
impl<T> Deref for MutexGuard<T> {
    type Target = T;

    /// Borrow the protected value without re-taking the StdMutex.
    fn deref(&self) -> &T {
        // SAFETY: we hold the async lock (`locked == true`), so no other
        // `MutexGuard` exists concurrently. The Arc keeps memory alive.
        // `value_ptr` was obtained at lock acquisition time.
        unsafe { &*self.value_ptr }
    }
}
129
impl<T> DerefMut for MutexGuard<T> {
    /// Mutably borrow the protected value without re-taking the StdMutex.
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: we hold the async lock exclusively; `&mut self` ensures
        // no aliased mutable references exist via this guard.
        unsafe { &mut *self.value_ptr }
    }
}
137
138impl<T> Drop for MutexGuard<T> {
139    fn drop(&mut self) {
140        let mut g = self.inner.lock().unwrap();
141        // Release the lock and wake the next waiter, if any.
142        g.locked = false;
143        if let Some(w) = g.waiters.pop_front() {
144            drop(g); // release inner mutex before waking
145            w.wake();
146        }
147    }
148}
149
150// ── Tests ─────────────────────────────────────────────────────────────────────
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::{block_on, block_on_with_spawn, spawn};
    use std::sync::Arc as StdArc;

    /// A write made through one guard is visible to the next lock.
    #[test]
    fn lock_and_mutate() {
        block_on(async {
            let m = Mutex::new(0u32);
            {
                let mut guard = m.lock().await;
                *guard += 1;
            }
            let guard = m.lock().await;
            assert_eq!(*guard, 1);
        });
    }

    /// Repeated lock/unlock cycles within one task never deadlock.
    #[test]
    fn sequential_locks_in_single_task() {
        block_on(async {
            let m = Mutex::new(Vec::<u32>::new());
            for n in 0u32..5 {
                let mut guard = m.lock().await;
                guard.push(n);
            }
            assert_eq!(*m.lock().await, vec![0, 1, 2, 3, 4]);
        });
    }

    /// Two spawned tasks contending for the lock both complete their update.
    #[test]
    fn concurrent_lock_via_spawn() {
        let counter = StdArc::new(Mutex::new(0u32));
        let a = counter.clone();
        let b = counter.clone();

        block_on_with_spawn(async move {
            let first = spawn(async move {
                *a.lock().await += 1;
            });
            let second = spawn(async move {
                *b.lock().await += 1;
            });
            first.await.unwrap();
            second.await.unwrap();
        });

        // Run a fresh block_on to read the result.
        assert_eq!(block_on(async { *counter.lock().await }), 2);
    }

    /// Dropping a guard releases the lock for an immediate re-acquire.
    #[test]
    fn guard_drops_release_lock() {
        block_on(async {
            let m = Mutex::new(42u32);
            let guard = m.lock().await;
            assert_eq!(*guard, 42);
            drop(guard);
            // After drop we must be able to lock again immediately.
            assert_eq!(*m.lock().await, 42);
        });
    }
}