moduvex_runtime/sync/
mutex.rs1use std::cell::UnsafeCell;
14use std::collections::VecDeque;
15use std::future::Future;
16use std::ops::{Deref, DerefMut};
17use std::pin::Pin;
18use std::sync::{Arc, Mutex as StdMutex};
19use std::task::{Context, Poll, Waker};
20
/// Bookkeeping shared by a [`Mutex`], its pending [`LockFuture`]s and its [`MutexGuard`].
///
/// `locked` and `waiters` are only read or written while the surrounding `StdMutex` is
/// held; `value` is reached through a raw pointer by the current `MutexGuard`.
struct Inner<T> {
    /// Whether a `MutexGuard` currently owns the async lock.
    locked: bool,
    /// Wakers of lock futures that returned `Pending`; a released guard pops the front.
    waiters: VecDeque<Waker>,
    /// The protected value.
    value: UnsafeCell<T>,
}
36
// SAFETY: `Inner` owns a `T` (inside `UnsafeCell`) plus a bool and wakers, which are
// `Send`; moving it to another thread is sound whenever `T: Send`.
unsafe impl<T: Send> Send for Inner<T> {}
// SAFETY: `Inner` is only reached through a `StdMutex`, and `value` is dereferenced only
// by the single `MutexGuard` that holds the async lock, so shared references never lead
// to concurrent access to `T`. This is the same bound `std::sync::Mutex<T>` uses for `Sync`.
// NOTE(review): `StdMutex<Inner<T>>` is already `Sync` when `Inner<T>: Send`, so this
// impl appears unnecessary; consider removing it to shrink the unsafe surface.
unsafe impl<T: Send> Sync for Inner<T> {}
42
/// An asynchronous mutual-exclusion lock.
///
/// [`Mutex::lock`] returns a future that stays `Pending` while another guard holds the
/// lock, instead of blocking the thread; the internal `StdMutex` is only held briefly
/// for bookkeeping.
pub struct Mutex<T> {
    /// Lock flag, wait queue and protected value; `LockFuture` borrows this `Arc` and
    /// every `MutexGuard` holds a clone of it.
    inner: Arc<StdMutex<Inner<T>>>,
}
52
53impl<T> Mutex<T> {
54 pub fn new(value: T) -> Self {
56 Self {
57 inner: Arc::new(StdMutex::new(Inner {
58 locked: false,
59 waiters: VecDeque::new(),
60 value: UnsafeCell::new(value),
61 })),
62 }
63 }
64
65 pub fn lock(&self) -> LockFuture<'_, T> {
70 LockFuture {
71 inner: &self.inner,
72 registered_waker: None,
73 }
74 }
75}
76
77pub struct LockFuture<'a, T> {
85 inner: &'a Arc<StdMutex<Inner<T>>>,
86 registered_waker: Option<Waker>,
89}
90
91impl<T> Future for LockFuture<'_, T> {
92 type Output = MutexGuard<T>;
93
94 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
95 let mut g = self.inner.lock().unwrap();
96 if !g.locked {
97 g.locked = true;
98 self.registered_waker = None; let value_ptr = g.value.get();
100 Poll::Ready(MutexGuard {
101 inner: Arc::clone(self.inner),
102 value_ptr,
103 })
104 } else {
105 let new_waker = cx.waker().clone();
106 if let Some(ref existing) = self.registered_waker {
107 if !existing.will_wake(&new_waker) {
109 for w in &mut g.waiters {
111 if w.will_wake(existing) {
112 *w = new_waker.clone();
113 break;
114 }
115 }
116 self.registered_waker = Some(new_waker);
117 }
118 } else {
119 g.waiters.push_back(new_waker.clone());
121 self.registered_waker = Some(new_waker);
122 }
123 Poll::Pending
124 }
125 }
126}
127
128impl<T> Drop for LockFuture<'_, T> {
129 fn drop(&mut self) {
130 if let Some(ref waker) = self.registered_waker {
131 if let Ok(mut g) = self.inner.lock() {
133 if let Some(pos) = g.waiters.iter().position(|w| w.will_wake(waker)) {
135 g.waiters.remove(pos);
136 }
137 }
138 }
139 }
140}
141
142pub struct MutexGuard<T> {
146 inner: Arc<StdMutex<Inner<T>>>,
147 value_ptr: *mut T,
152}
153
154unsafe impl<T: Send> Send for MutexGuard<T> {}
158unsafe impl<T: Send> Sync for MutexGuard<T> {}
159
160impl<T> Deref for MutexGuard<T> {
161 type Target = T;
162
163 fn deref(&self) -> &T {
164 unsafe { &*self.value_ptr }
168 }
169}
170
171impl<T> DerefMut for MutexGuard<T> {
172 fn deref_mut(&mut self) -> &mut T {
173 unsafe { &mut *self.value_ptr }
176 }
177}
178
179impl<T> Drop for MutexGuard<T> {
180 fn drop(&mut self) {
181 let mut g = self.inner.lock().unwrap();
182 g.locked = false;
184 if let Some(w) = g.waiters.pop_front() {
185 drop(g); w.wake();
187 }
188 }
189}
190
// Unit tests for the async `Mutex`. They drive futures with the crate's executor
// (`block_on` / `block_on_with_spawn` / `spawn` from `crate::executor`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::{block_on, block_on_with_spawn, spawn};
    use std::sync::Arc as StdArc;

    /// A write made under one guard is visible to the next lock holder.
    #[test]
    fn lock_and_mutate() {
        block_on(async {
            let m = Mutex::new(0u32);
            {
                let mut g = m.lock().await;
                *g += 1;
            }
            {
                let g = m.lock().await;
                assert_eq!(*g, 1);
            }
        });
    }

    /// Each temporary guard is dropped at the end of its statement, so locking
    /// repeatedly from one task does not deadlock.
    #[test]
    fn sequential_locks_in_single_task() {
        block_on(async {
            let m = Mutex::new(Vec::<u32>::new());
            for i in 0..5 {
                m.lock().await.push(i);
            }
            let g = m.lock().await;
            assert_eq!(*g, vec![0, 1, 2, 3, 4]);
        });
    }

    /// Two spawned tasks each increment once; both increments must be observed.
    #[test]
    fn concurrent_lock_via_spawn() {
        let counter = StdArc::new(Mutex::new(0u32));
        let c1 = counter.clone();
        let c2 = counter.clone();

        block_on_with_spawn(async move {
            let jh1 = spawn(async move {
                let mut g = c1.lock().await;
                *g += 1;
            });
            let jh2 = spawn(async move {
                let mut g = c2.lock().await;
                *g += 1;
            });
            jh1.await.unwrap();
            jh2.await.unwrap();
        });

        let final_val = block_on(async { *counter.lock().await });
        assert_eq!(final_val, 2);
    }

    /// Explicitly dropping a guard releases the lock so it can be re-acquired.
    #[test]
    fn guard_drops_release_lock() {
        block_on(async {
            let m = Mutex::new(42u32);
            let g = m.lock().await;
            assert_eq!(*g, 42);
            drop(g);
            let g2 = m.lock().await;
            assert_eq!(*g2, 42);
        });
    }

    /// 100 spawned tasks each increment once; no increment may be lost.
    #[test]
    fn mutex_stress_100_concurrent_increments() {
        let counter = StdArc::new(Mutex::new(0u64));
        let c = counter.clone();
        block_on_with_spawn(async move {
            let mut handles = Vec::new();
            for _ in 0..100 {
                let cc = c.clone();
                handles.push(spawn(async move {
                    let mut g = cc.lock().await;
                    *g += 1;
                }));
            }
            for h in handles {
                h.await.unwrap();
            }
        });
        let final_val = block_on(async { *counter.lock().await });
        assert_eq!(final_val, 100);
    }

    /// Five tasks each push one entry; only the number of entries is checked,
    /// not their order.
    #[test]
    fn mutex_fifo_all_entries_recorded() {
        let order = StdArc::new(Mutex::new(Vec::<u32>::new()));
        let o = order.clone();
        block_on_with_spawn(async move {
            let mut handles = Vec::new();
            for i in 0u32..5 {
                let oo = o.clone();
                handles.push(spawn(async move {
                    let mut g = oo.lock().await;
                    g.push(i);
                }));
            }
            for h in handles {
                h.await.unwrap();
            }
        });
        let v = block_on(async { order.lock().await.len() });
        assert_eq!(v, 5);
    }

    /// Methods and indexing of the protected value are reachable through `Deref`.
    #[test]
    fn mutex_guard_deref() {
        block_on(async {
            let m = Mutex::new(vec![1u32, 2, 3]);
            let g = m.lock().await;
            assert_eq!(g.len(), 3);
            assert_eq!((*g)[1], 2);
        });
    }

    /// A write through `DerefMut` persists after the guard is dropped.
    #[test]
    fn mutex_guard_deref_mut() {
        block_on(async {
            let m = Mutex::new(0u32);
            let mut g = m.lock().await;
            *g = 99;
            drop(g);
            assert_eq!(*m.lock().await, 99);
        });
    }

    /// A task that may be waiting on the lock is aborted while the lock is held; the
    /// lock must still be acquirable afterwards.
    /// NOTE(review): whether the aborted task ever polled `lock()` before `abort()`
    /// depends on the executor's scheduling, so the waiter-removal path may not run.
    #[test]
    fn mutex_reentrant_after_abort_no_deadlock() {
        block_on_with_spawn(async {
            let m = StdArc::new(Mutex::new(0u32));
            let m2 = m.clone();
            let guard = m.lock().await;
            let jh = spawn(async move {
                let _ = m2.lock().await;
            });
            jh.abort();
            drop(guard);
            *m.lock().await += 1;
            assert_eq!(*m.lock().await, 1);
        });
    }

    /// The value passed to `Mutex::new` is what the first guard sees.
    #[test]
    fn mutex_initial_value_preserved() {
        block_on(async {
            let m = Mutex::new(String::from("initial"));
            let g = m.lock().await;
            assert_eq!(*g, "initial");
        });
    }

    /// The last of several overwriting writes wins.
    #[test]
    fn mutex_multiple_sequential_mutations() {
        block_on(async {
            let m = Mutex::new(0u32);
            for i in 1..=10u32 {
                *m.lock().await = i;
            }
            assert_eq!(*m.lock().await, 10);
        });
    }

    /// Appends made under separate guards accumulate in order.
    #[test]
    fn mutex_string_value() {
        block_on(async {
            let m = Mutex::new(String::new());
            for i in 0..5 {
                m.lock().await.push_str(&i.to_string());
            }
            assert_eq!(*m.lock().await, "01234");
        });
    }

    /// Pushes made under separate guards accumulate in order.
    #[test]
    fn mutex_vec_value_append() {
        block_on(async {
            let m = Mutex::new(Vec::<u32>::new());
            for i in 0..5u32 {
                m.lock().await.push(i);
            }
            let g = m.lock().await;
            assert_eq!(*g, vec![0, 1, 2, 3, 4]);
        });
    }

    /// Ten spawned tasks each increment once; no increment may be lost.
    #[test]
    fn mutex_concurrent_10_tasks() {
        let counter = StdArc::new(Mutex::new(0u32));
        let c = counter.clone();
        block_on_with_spawn(async move {
            let mut handles = Vec::new();
            for _ in 0..10 {
                let cc = c.clone();
                handles.push(spawn(async move {
                    *cc.lock().await += 1;
                }));
            }
            for h in handles {
                h.await.unwrap();
            }
        });
        let v = block_on(async { *counter.lock().await });
        assert_eq!(v, 10);
    }

    /// A freshly created mutex exposes its initial value.
    #[test]
    fn mutex_new_value_is_accessible() {
        block_on(async {
            let m = Mutex::new(42u64);
            assert_eq!(*m.lock().await, 42);
        });
    }

    /// Lock/modify/drop cycles in one task all take effect.
    #[test]
    fn mutex_lock_after_multiple_releases() {
        block_on(async {
            let m = Mutex::new(0u32);
            for _ in 0..5 {
                let mut g = m.lock().await;
                *g += 1;
                drop(g);
            }
            assert_eq!(*m.lock().await, 5);
        });
    }

    /// The spawned task's `lock()` cannot complete until `g` is dropped; its
    /// increment must then be visible.
    /// NOTE(review): whether the task actually contends depends on when the executor
    /// first polls it relative to `drop(g)`.
    #[test]
    fn mutex_guard_cannot_alias() {
        let m = StdArc::new(Mutex::new(0u32));
        let m2 = m.clone();
        block_on_with_spawn(async move {
            let g = m.lock().await;
            let jh = spawn(async move {
                *m2.lock().await += 1;
            });
            drop(g);
            jh.await.unwrap();
            assert_eq!(*m.lock().await, 1);
        });
    }

    /// Inserts made under separate guards are all present.
    #[test]
    fn mutex_hashmap_value() {
        block_on(async {
            use std::collections::HashMap;
            let m = Mutex::new(HashMap::<String, u32>::new());
            m.lock().await.insert("a".to_string(), 1);
            m.lock().await.insert("b".to_string(), 2);
            let g = m.lock().await;
            assert_eq!(g.len(), 2);
            assert_eq!(g.get("a"), Some(&1));
        });
    }

    /// Boundary integer values round-trip unchanged.
    #[test]
    fn mutex_u64_max_value() {
        block_on(async {
            let m = Mutex::new(u64::MAX);
            assert_eq!(*m.lock().await, u64::MAX);
        });
    }

    /// The guard dereferences to the stored `Arc` itself (not a copy), so the strong
    /// count covers `inner` plus the clone stored in the mutex.
    #[test]
    fn mutex_wraps_arc() {
        block_on(async {
            let inner = StdArc::new(0u32);
            let m = Mutex::new(inner.clone());
            let g = m.lock().await;
            assert_eq!(StdArc::strong_count(&*g), 2);
        });
    }

    /// A guard dropped immediately after acquisition releases the lock.
    #[test]
    fn mutex_lock_and_immediately_drop() {
        block_on(async {
            let m = Mutex::new(42u32);
            drop(m.lock().await);
            assert_eq!(*m.lock().await, 42);
        });
    }

    /// Twenty spawned tasks each increment once; no increment may be lost.
    #[test]
    fn mutex_20_concurrent_tasks() {
        let counter = StdArc::new(Mutex::new(0u32));
        let c = counter.clone();
        block_on_with_spawn(async move {
            let handles: Vec<_> = (0..20)
                .map(|_| {
                    let cc = c.clone();
                    spawn(async move { *cc.lock().await += 1 })
                })
                .collect();
            for h in handles {
                h.await.unwrap();
            }
        });
        let v = block_on(async { *counter.lock().await });
        assert_eq!(v, 20);
    }
}