moduvex_runtime/sync/mutex.rs
1//! Async mutex — cooperative mutual exclusion for async tasks.
2//!
3//! Unlike `std::sync::Mutex`, locking suspends the calling task (yields back
4//! to the executor) instead of blocking the OS thread. This is critical inside
5//! async contexts where blocking would starve other tasks sharing the thread.
6//!
7//! # Design
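//! - The `locked` flag, the wait queue, and the value live behind a
//!   `std::sync::Mutex`, which is held only for the short critical sections
//!   that update the flag and the queue; the value itself is reached through
//!   the async guard.
//! - Waiting tasks are queued in arrival order and woken one at a time on
//!   release. Fairness is best-effort: a task that calls `lock` while the lock
//!   is free can acquire it before a woken waiter is re-polled.
//! - `MutexGuard` releases the lock and wakes the next waiter when dropped.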
12
13use std::cell::UnsafeCell;
14use std::collections::VecDeque;
15use std::future::Future;
16use std::ops::{Deref, DerefMut};
17use std::pin::Pin;
18use std::sync::{Arc, Mutex as StdMutex};
19use std::task::{Context, Poll, Waker};
20
21// ── Inner state ───────────────────────────────────────────────────────────────
22
23struct Inner<T> {
24 /// Whether the async lock is currently held by a `MutexGuard`.
25 locked: bool,
26 /// Tasks waiting to acquire the lock, in arrival order (FIFO).
27 waiters: VecDeque<Waker>,
28 /// The protected value.
29 ///
30 /// `UnsafeCell` allows mutation through a shared `Arc<Inner<T>>`.
31 /// Safe because access is serialised: only the current `MutexGuard`
32 /// holder may dereference this pointer, and there is at most one guard
33 /// alive at a time (enforced by `locked`).
34 value: UnsafeCell<T>,
35}
36
37// SAFETY: `Mutex<T>` must be `Send + Sync` when `T: Send` so it can be shared
38// across async tasks. The `UnsafeCell` is safe because mutation is serialised
39// by the `locked` flag inside the `StdMutex<Inner>`.
40unsafe impl<T: Send> Send for Inner<T> {}
41unsafe impl<T: Send> Sync for Inner<T> {}
42
43// ── Mutex ─────────────────────────────────────────────────────────────────────
44
45/// Async-aware mutual exclusion primitive.
46///
47/// Wraps a value of type `T`; concurrent tasks suspend (not block) while
48/// waiting for the lock.
49pub struct Mutex<T> {
50 inner: Arc<StdMutex<Inner<T>>>,
51}
52
53impl<T> Mutex<T> {
54 /// Create a new `Mutex` wrapping `value`.
55 pub fn new(value: T) -> Self {
56 Self {
57 inner: Arc::new(StdMutex::new(Inner {
58 locked: false,
59 waiters: VecDeque::new(),
60 value: UnsafeCell::new(value),
61 })),
62 }
63 }
64
65 /// Acquire the lock asynchronously, returning a `MutexGuard<T>`.
66 ///
67 /// The returned future suspends if the lock is already held and resumes
68 /// once the previous holder's `MutexGuard` is dropped.
69 pub fn lock(&self) -> LockFuture<'_, T> {
70 LockFuture { inner: &self.inner }
71 }
72}
73
74// ── LockFuture ────────────────────────────────────────────────────────────────
75
76/// Future returned by [`Mutex::lock`].
77pub struct LockFuture<'a, T> {
78 inner: &'a Arc<StdMutex<Inner<T>>>,
79}
80
81impl<T> Future for LockFuture<'_, T> {
82 type Output = MutexGuard<T>;
83
84 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85 let mut g = self.inner.lock().unwrap();
86 if !g.locked {
87 g.locked = true;
88 // Cache the pointer to the protected value at lock acquisition time.
89 let value_ptr = g.value.get();
90 Poll::Ready(MutexGuard {
91 inner: Arc::clone(self.inner),
92 value_ptr,
93 })
94 } else {
95 g.waiters.push_back(cx.waker().clone());
96 Poll::Pending
97 }
98 }
99}
100
101// ── MutexGuard ────────────────────────────────────────────────────────────────
102
103/// RAII guard that releases the async lock on drop and wakes the next waiter.
104pub struct MutexGuard<T> {
105 inner: Arc<StdMutex<Inner<T>>>,
106 /// Cached raw pointer to the protected value. Avoids acquiring the
107 /// StdMutex on every deref. Valid for the lifetime of this guard because:
108 /// - The Arc keeps the Inner allocation alive.
109 /// - The async `locked` flag prevents concurrent mutation.
110 value_ptr: *mut T,
111}
112
113// SAFETY: MutexGuard<T> is Send+Sync when T: Send because:
114// - The async lock serialises all access to the value.
115// - The raw pointer comes from UnsafeCell inside an Arc (heap-stable).
116unsafe impl<T: Send> Send for MutexGuard<T> {}
117unsafe impl<T: Send> Sync for MutexGuard<T> {}
118
119impl<T> Deref for MutexGuard<T> {
120 type Target = T;
121
122 fn deref(&self) -> &T {
123 // SAFETY: we hold the async lock (`locked == true`), so no other
124 // `MutexGuard` exists concurrently. The Arc keeps memory alive.
125 // `value_ptr` was obtained at lock acquisition time.
126 unsafe { &*self.value_ptr }
127 }
128}
129
130impl<T> DerefMut for MutexGuard<T> {
131 fn deref_mut(&mut self) -> &mut T {
132 // SAFETY: we hold the async lock exclusively; `&mut self` ensures
133 // no aliased mutable references exist via this guard.
134 unsafe { &mut *self.value_ptr }
135 }
136}
137
138impl<T> Drop for MutexGuard<T> {
139 fn drop(&mut self) {
140 let mut g = self.inner.lock().unwrap();
141 // Release the lock and wake the next waiter, if any.
142 g.locked = false;
143 if let Some(w) = g.waiters.pop_front() {
144 drop(g); // release inner mutex before waking
145 w.wake();
146 }
147 }
148}
149
150// ── Tests ─────────────────────────────────────────────────────────────────────
151
152#[cfg(test)]
153mod tests {
154 use super::*;
155 use crate::executor::{block_on, block_on_with_spawn, spawn};
156 use std::sync::Arc as StdArc;
157
158 #[test]
159 fn lock_and_mutate() {
160 block_on(async {
161 let m = Mutex::new(0u32);
162 {
163 let mut g = m.lock().await;
164 *g += 1;
165 }
166 {
167 let g = m.lock().await;
168 assert_eq!(*g, 1);
169 }
170 });
171 }
172
173 #[test]
174 fn sequential_locks_in_single_task() {
175 block_on(async {
176 let m = Mutex::new(Vec::<u32>::new());
177 for i in 0..5 {
178 m.lock().await.push(i);
179 }
180 let g = m.lock().await;
181 assert_eq!(*g, vec![0, 1, 2, 3, 4]);
182 });
183 }
184
185 #[test]
186 fn concurrent_lock_via_spawn() {
187 let counter = StdArc::new(Mutex::new(0u32));
188 let c1 = counter.clone();
189 let c2 = counter.clone();
190
191 block_on_with_spawn(async move {
192 let jh1 = spawn(async move {
193 let mut g = c1.lock().await;
194 *g += 1;
195 });
196 let jh2 = spawn(async move {
197 let mut g = c2.lock().await;
198 *g += 1;
199 });
200 jh1.await.unwrap();
201 jh2.await.unwrap();
202 });
203
204 // Run a fresh block_on to read the result.
205 let final_val = block_on(async { *counter.lock().await });
206 assert_eq!(final_val, 2);
207 }
208
209 #[test]
210 fn guard_drops_release_lock() {
211 block_on(async {
212 let m = Mutex::new(42u32);
213 let g = m.lock().await;
214 assert_eq!(*g, 42);
215 drop(g);
216 // After drop we must be able to lock again immediately.
217 let g2 = m.lock().await;
218 assert_eq!(*g2, 42);
219 });
220 }
221}