Skip to main content

moduvex_runtime/sync/
mutex.rs

1//! Async mutex — cooperative mutual exclusion for async tasks.
2//!
3//! Unlike `std::sync::Mutex`, locking suspends the calling task (yields back
4//! to the executor) instead of blocking the OS thread. This is critical inside
5//! async contexts where blocking would starve other tasks sharing the thread.
6//!
7//! # Design
8//! - Inner value protected by a `std::sync::Mutex` for the critical section
9//!   of updating waker queues and the locked flag.
10//! - A `VecDeque<Waker>` wait queue ensures FIFO fairness across contenders.
11//! - `MutexGuard` drops the lock and wakes the next waiter on `Drop`.
12
13use std::cell::UnsafeCell;
14use std::collections::VecDeque;
15use std::future::Future;
16use std::ops::{Deref, DerefMut};
17use std::pin::Pin;
18use std::sync::{Arc, Mutex as StdMutex};
19use std::task::{Context, Poll, Waker};
20
21// ── Inner state ───────────────────────────────────────────────────────────────
22
struct Inner<T> {
    /// Whether the async lock is currently held by a `MutexGuard`.
    locked: bool,
    /// Tasks waiting to acquire the lock, in arrival order (FIFO).
    waiters: VecDeque<Waker>,
    /// The protected value.
    ///
    /// `UnsafeCell` allows mutation through a shared `Arc<Inner<T>>`.
    /// Safe because access is serialised: only the current `MutexGuard`
    /// holder may dereference this pointer, and there is at most one guard
    /// alive at a time (enforced by `locked`).
    value: UnsafeCell<T>,
}

// SAFETY: `Inner<T>` must be `Send + Sync` when `T: Send` so the
// `Arc<StdMutex<Inner<T>>>` can be shared across async tasks. The
// `UnsafeCell` makes `Inner` `!Sync` by default, but mutation of `value` is
// serialised by the `locked` flag (a `&T`/`&mut T` to the payload is only
// ever produced while holding the async lock), mirroring the reasoning
// behind `std::sync::Mutex<T>: Sync where T: Send`.
unsafe impl<T: Send> Send for Inner<T> {}
unsafe impl<T: Send> Sync for Inner<T> {}
42
43// ── Mutex ─────────────────────────────────────────────────────────────────────
44
/// Async-aware mutual exclusion primitive.
///
/// Wraps a value of type `T`; concurrent tasks suspend (not block) while
/// waiting for the lock.
pub struct Mutex<T> {
    /// Shared state. The `StdMutex` protects only the short bookkeeping
    /// critical sections (flipping `locked`, queue updates); the `Arc` lets
    /// guards and lock futures keep the state alive independently.
    inner: Arc<StdMutex<Inner<T>>>,
}
52
53impl<T> Mutex<T> {
54    /// Create a new `Mutex` wrapping `value`.
55    pub fn new(value: T) -> Self {
56        Self {
57            inner: Arc::new(StdMutex::new(Inner {
58                locked: false,
59                waiters: VecDeque::new(),
60                value: UnsafeCell::new(value),
61            })),
62        }
63    }
64
65    /// Acquire the lock asynchronously, returning a `MutexGuard<T>`.
66    ///
67    /// The returned future suspends if the lock is already held and resumes
68    /// once the previous holder's `MutexGuard` is dropped.
69    pub fn lock(&self) -> LockFuture<'_, T> {
70        LockFuture {
71            inner: &self.inner,
72            registered_waker: None,
73        }
74    }
75}
76
77// ── LockFuture ────────────────────────────────────────────────────────────────
78
/// Future returned by [`Mutex::lock`].
///
/// Stores its registered waker so it can remove itself from the queue on
/// cancellation (drop before completion). This prevents `MutexGuard::drop`
/// from waking an already-dropped task.
pub struct LockFuture<'a, T> {
    /// Shared mutex state; borrowed, so the future cannot outlive the `Mutex`.
    inner: &'a Arc<StdMutex<Inner<T>>>,
    /// The waker we pushed into `waiters`, stored so Drop can remove it.
    /// `None` if we have not yet registered (or have already been resolved).
    registered_waker: Option<Waker>,
}
90
91impl<T> Future for LockFuture<'_, T> {
92    type Output = MutexGuard<T>;
93
94    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
95        let mut g = self.inner.lock().unwrap();
96        if !g.locked {
97            g.locked = true;
98            self.registered_waker = None; // lock acquired; no longer in waiter queue
99            let value_ptr = g.value.get();
100            Poll::Ready(MutexGuard {
101                inner: Arc::clone(self.inner),
102                value_ptr,
103            })
104        } else {
105            let new_waker = cx.waker().clone();
106            if let Some(ref existing) = self.registered_waker {
107                // Already registered: update in place if waker changed.
108                if !existing.will_wake(&new_waker) {
109                    // Replace our stale waker in the queue with the new one.
110                    for w in &mut g.waiters {
111                        if w.will_wake(existing) {
112                            *w = new_waker.clone();
113                            break;
114                        }
115                    }
116                    self.registered_waker = Some(new_waker);
117                }
118            } else {
119                // First time blocked — push waker and remember it for cleanup.
120                g.waiters.push_back(new_waker.clone());
121                self.registered_waker = Some(new_waker);
122            }
123            Poll::Pending
124        }
125    }
126}
127
128impl<T> Drop for LockFuture<'_, T> {
129    fn drop(&mut self) {
130        if let Some(ref waker) = self.registered_waker {
131            // Remove our waker so MutexGuard::drop doesn't wake a dead task.
132            if let Ok(mut g) = self.inner.lock() {
133                // Remove the first waker in the queue that matches ours.
134                if let Some(pos) = g.waiters.iter().position(|w| w.will_wake(waker)) {
135                    g.waiters.remove(pos);
136                }
137            }
138        }
139    }
140}
141
142// ── MutexGuard ────────────────────────────────────────────────────────────────
143
/// RAII guard that releases the async lock on drop and wakes the next waiter.
pub struct MutexGuard<T> {
    /// Shared state; the `Arc` keeps the `Inner` allocation (and therefore
    /// `value_ptr`'s target) alive for the guard's lifetime.
    inner: Arc<StdMutex<Inner<T>>>,
    /// Cached raw pointer to the protected value. Avoids acquiring the
    /// StdMutex on every deref. Valid for the lifetime of this guard because:
    /// - The Arc keeps the Inner allocation alive.
    /// - The async `locked` flag prevents concurrent mutation.
    value_ptr: *mut T,
}
153
154// SAFETY: MutexGuard<T> is Send+Sync when T: Send because:
155// - The async lock serialises all access to the value.
156// - The raw pointer comes from UnsafeCell inside an Arc (heap-stable).
157unsafe impl<T: Send> Send for MutexGuard<T> {}
158unsafe impl<T: Send> Sync for MutexGuard<T> {}
159
impl<T> Deref for MutexGuard<T> {
    type Target = T;

    /// Shared access to the protected value without re-locking the StdMutex.
    fn deref(&self) -> &T {
        // SAFETY: we hold the async lock (`locked == true`), so no other
        // `MutexGuard` exists concurrently. The Arc keeps memory alive.
        // `value_ptr` was obtained at lock acquisition time.
        unsafe { &*self.value_ptr }
    }
}
170
impl<T> DerefMut for MutexGuard<T> {
    /// Exclusive access to the protected value without re-locking the StdMutex.
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: we hold the async lock exclusively; `&mut self` ensures
        // no aliased mutable references exist via this guard.
        unsafe { &mut *self.value_ptr }
    }
}
178
179impl<T> Drop for MutexGuard<T> {
180    fn drop(&mut self) {
181        let mut g = self.inner.lock().unwrap();
182        // Release the lock and wake the next waiter, if any.
183        g.locked = false;
184        if let Some(w) = g.waiters.pop_front() {
185            drop(g); // release inner mutex before waking
186            w.wake();
187        }
188    }
189}
190
191// ── Tests ─────────────────────────────────────────────────────────────────────
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::{block_on, block_on_with_spawn, spawn};
    use std::sync::Arc as StdArc;

    // Basic lock / unlock / mutate within a single task.
    #[test]
    fn lock_and_mutate() {
        block_on(async {
            let m = Mutex::new(0u32);
            {
                let mut g = m.lock().await;
                *g += 1;
            }
            {
                let g = m.lock().await;
                assert_eq!(*g, 1);
            }
        });
    }

    #[test]
    fn sequential_locks_in_single_task() {
        block_on(async {
            let m = Mutex::new(Vec::<u32>::new());
            for i in 0..5 {
                m.lock().await.push(i);
            }
            let g = m.lock().await;
            assert_eq!(*g, vec![0, 1, 2, 3, 4]);
        });
    }

    #[test]
    fn concurrent_lock_via_spawn() {
        let counter = StdArc::new(Mutex::new(0u32));
        let c1 = counter.clone();
        let c2 = counter.clone();

        block_on_with_spawn(async move {
            let jh1 = spawn(async move {
                let mut g = c1.lock().await;
                *g += 1;
            });
            let jh2 = spawn(async move {
                let mut g = c2.lock().await;
                *g += 1;
            });
            jh1.await.unwrap();
            jh2.await.unwrap();
        });

        // Run a fresh block_on to read the result.
        let final_val = block_on(async { *counter.lock().await });
        assert_eq!(final_val, 2);
    }

    #[test]
    fn guard_drops_release_lock() {
        block_on(async {
            let m = Mutex::new(42u32);
            let g = m.lock().await;
            assert_eq!(*g, 42);
            drop(g);
            // After drop we must be able to lock again immediately.
            let g2 = m.lock().await;
            assert_eq!(*g2, 42);
        });
    }

    // ── Additional mutex tests ─────────────────────────────────────────────

    #[test]
    fn mutex_stress_100_concurrent_increments() {
        let counter = StdArc::new(Mutex::new(0u64));
        let c = counter.clone();
        block_on_with_spawn(async move {
            let mut handles = Vec::new();
            for _ in 0..100 {
                let cc = c.clone();
                handles.push(spawn(async move {
                    let mut g = cc.lock().await;
                    *g += 1;
                }));
            }
            for h in handles {
                h.await.unwrap();
            }
        });
        let final_val = block_on(async { *counter.lock().await });
        assert_eq!(final_val, 100);
    }

    #[test]
    fn mutex_fifo_all_entries_recorded() {
        // All lockers queue; each pushes a known value.
        // NOTE(review): despite the name this only asserts the count, not
        // FIFO order — task scheduling order is executor-dependent.
        let order = StdArc::new(Mutex::new(Vec::<u32>::new()));
        let o = order.clone();
        block_on_with_spawn(async move {
            let mut handles = Vec::new();
            for i in 0u32..5 {
                let oo = o.clone();
                handles.push(spawn(async move {
                    let mut g = oo.lock().await;
                    g.push(i);
                }));
            }
            for h in handles {
                h.await.unwrap();
            }
        });
        let v = block_on(async { order.lock().await.len() });
        assert_eq!(v, 5);
    }

    #[test]
    fn mutex_guard_deref() {
        block_on(async {
            let m = Mutex::new(vec![1u32, 2, 3]);
            let g = m.lock().await;
            assert_eq!(g.len(), 3);
            assert_eq!((*g)[1], 2);
        });
    }

    #[test]
    fn mutex_guard_deref_mut() {
        block_on(async {
            let m = Mutex::new(0u32);
            let mut g = m.lock().await;
            *g = 99;
            drop(g);
            assert_eq!(*m.lock().await, 99);
        });
    }

    // Cancellation-safety: aborting a task that is waiting on the lock must
    // not wedge the mutex (exercises LockFuture::drop cleanup).
    #[test]
    fn mutex_reentrant_after_abort_no_deadlock() {
        block_on_with_spawn(async {
            let m = StdArc::new(Mutex::new(0u32));
            let m2 = m.clone();
            // Hold the lock in one task
            let guard = m.lock().await;
            // Spawn a task that will block trying to acquire
            let jh = spawn(async move {
                // This will be Pending because guard holds the lock
                let _ = m2.lock().await;
            });
            // Abort the waiting task
            jh.abort();
            drop(guard); // release lock — should not deadlock
            // Verify we can still acquire the lock
            *m.lock().await += 1;
            assert_eq!(*m.lock().await, 1);
        });
    }

    #[test]
    fn mutex_initial_value_preserved() {
        block_on(async {
            let m = Mutex::new(String::from("initial"));
            let g = m.lock().await;
            assert_eq!(*g, "initial");
        });
    }

    #[test]
    fn mutex_multiple_sequential_mutations() {
        block_on(async {
            let m = Mutex::new(0u32);
            for i in 1..=10u32 {
                *m.lock().await = i;
            }
            assert_eq!(*m.lock().await, 10);
        });
    }

    #[test]
    fn mutex_string_value() {
        block_on(async {
            let m = Mutex::new(String::new());
            for i in 0..5 {
                m.lock().await.push_str(&i.to_string());
            }
            assert_eq!(*m.lock().await, "01234");
        });
    }

    #[test]
    fn mutex_vec_value_append() {
        block_on(async {
            let m = Mutex::new(Vec::<u32>::new());
            for i in 0..5u32 {
                m.lock().await.push(i);
            }
            let g = m.lock().await;
            assert_eq!(*g, vec![0, 1, 2, 3, 4]);
        });
    }

    #[test]
    fn mutex_concurrent_10_tasks() {
        let counter = StdArc::new(Mutex::new(0u32));
        let c = counter.clone();
        block_on_with_spawn(async move {
            let mut handles = Vec::new();
            for _ in 0..10 {
                let cc = c.clone();
                handles.push(spawn(async move {
                    *cc.lock().await += 1;
                }));
            }
            for h in handles {
                h.await.unwrap();
            }
        });
        let v = block_on(async { *counter.lock().await });
        assert_eq!(v, 10);
    }

    #[test]
    fn mutex_new_value_is_accessible() {
        block_on(async {
            let m = Mutex::new(42u64);
            assert_eq!(*m.lock().await, 42);
        });
    }

    #[test]
    fn mutex_lock_after_multiple_releases() {
        block_on(async {
            let m = Mutex::new(0u32);
            for _ in 0..5 {
                let mut g = m.lock().await;
                *g += 1;
                drop(g);
            }
            assert_eq!(*m.lock().await, 5);
        });
    }

    #[test]
    fn mutex_guard_cannot_alias() {
        // Taking a second lock while guard is held blocks (we verify by spawning)
        let m = StdArc::new(Mutex::new(0u32));
        let m2 = m.clone();
        block_on_with_spawn(async move {
            let g = m.lock().await;
            let jh = spawn(async move {
                // This should block until g is dropped
                *m2.lock().await += 1;
            });
            // Release g after spawning
            drop(g);
            jh.await.unwrap();
            assert_eq!(*m.lock().await, 1);
        });
    }

    #[test]
    fn mutex_hashmap_value() {
        block_on(async {
            use std::collections::HashMap;
            let m = Mutex::new(HashMap::<String, u32>::new());
            m.lock().await.insert("a".to_string(), 1);
            m.lock().await.insert("b".to_string(), 2);
            let g = m.lock().await;
            assert_eq!(g.len(), 2);
            assert_eq!(g.get("a"), Some(&1));
        });
    }

    #[test]
    fn mutex_u64_max_value() {
        block_on(async {
            let m = Mutex::new(u64::MAX);
            assert_eq!(*m.lock().await, u64::MAX);
        });
    }

    #[test]
    fn mutex_wraps_arc() {
        block_on(async {
            let inner = StdArc::new(0u32);
            let m = Mutex::new(inner.clone());
            let g = m.lock().await;
            assert_eq!(StdArc::strong_count(&*g), 2); // inner + guard's ref
        });
    }

    #[test]
    fn mutex_lock_and_immediately_drop() {
        block_on(async {
            let m = Mutex::new(42u32);
            drop(m.lock().await); // lock and release immediately
            // Verify we can lock again
            assert_eq!(*m.lock().await, 42);
        });
    }

    #[test]
    fn mutex_20_concurrent_tasks() {
        let counter = StdArc::new(Mutex::new(0u32));
        let c = counter.clone();
        block_on_with_spawn(async move {
            let handles: Vec<_> = (0..20)
                .map(|_| {
                    let cc = c.clone();
                    spawn(async move { *cc.lock().await += 1 })
                })
                .collect();
            for h in handles {
                h.await.unwrap();
            }
        });
        let v = block_on(async { *counter.lock().await });
        assert_eq!(v, 20);
    }
}