1use std::cell::UnsafeCell;
2use std::marker::PhantomData;
3use std::ops::{Deref, DerefMut};
4use std::pin::Pin;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::{Arc, Mutex as StdMutex};
7use std::{fmt, mem};
8
9use slab::Slab;
10
11use futures_core::future::{FusedFuture, Future};
12use futures_core::task::{Context, Poll, Waker};
13
/// A futures-aware mutex: acquiring returns a future rather than blocking.
pub struct Mutex<T: ?Sized> {
    // Bit flags (IS_LOCKED | HAS_WAITERS); see the consts below.
    state: AtomicUsize,
    // Wakers of tasks waiting to acquire the lock. Protected by a blocking
    // std mutex, which is only held for short, non-awaiting critical sections.
    waiters: StdMutex<Slab<Waiter>>,
    // The protected value; access is gated by the IS_LOCKED bit in `state`.
    value: UnsafeCell<T>,
}
27
28impl<T: ?Sized> fmt::Debug for Mutex<T> {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 let state = self.state.load(Ordering::SeqCst);
31 f.debug_struct("Mutex")
32 .field("is_locked", &((state & IS_LOCKED) != 0))
33 .field("has_waiters", &((state & HAS_WAITERS) != 0))
34 .finish()
35 }
36}
37
38impl<T> From<T> for Mutex<T> {
39 fn from(t: T) -> Self {
40 Self::new(t)
41 }
42}
43
44impl<T: Default> Default for Mutex<T> {
45 fn default() -> Self {
46 Self::new(Default::default())
47 }
48}
49
/// One entry in the mutex's waiter slab.
enum Waiter {
    /// Task has registered a waker and is waiting to be notified.
    Waiting(Waker),
    /// Slot was notified; the wakeup has been consumed but the waiting
    /// future has not yet re-polled (or was dropped).
    Woken,
}
54
55impl Waiter {
56 fn register(&mut self, waker: &Waker) {
57 match self {
58 Self::Waiting(w) if waker.will_wake(w) => {}
59 _ => *self = Self::Waiting(waker.clone()),
60 }
61 }
62
63 fn wake(&mut self) {
64 match mem::replace(self, Self::Woken) {
65 Self::Waiting(waker) => waker.wake(),
66 Self::Woken => {}
67 }
68 }
69}
70
// `state` bit: set while the mutex is held.
const IS_LOCKED: usize = 1 << 0;
// `state` bit: set while the waiter slab is non-empty.
const HAS_WAITERS: usize = 1 << 1;
73
impl<T> Mutex<T> {
    /// Creates a new, unlocked mutex wrapping `t`.
    pub const fn new(t: T) -> Self {
        Self {
            state: AtomicUsize::new(0),
            waiters: StdMutex::new(Slab::new()),
            value: UnsafeCell::new(t),
        }
    }

    /// Consumes the mutex and returns the inner value.
    ///
    /// No locking is needed: taking `self` by value proves no guards or
    /// other references to this mutex exist.
    pub fn into_inner(self) -> T {
        self.value.into_inner()
    }
}
98
impl<T: ?Sized> Mutex<T> {
    /// Attempts to acquire the lock immediately.
    ///
    /// Returns `None` if the lock is already held.
    pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
        // Unconditionally set IS_LOCKED; the previous value tells us whether
        // we won the race. Acquire pairs with the release half of the
        // AcqRel in `unlock`.
        let old_state = self.state.fetch_or(IS_LOCKED, Ordering::Acquire);
        if (old_state & IS_LOCKED) == 0 {
            Some(MutexGuard { mutex: self })
        } else {
            None
        }
    }

    /// Like [`Mutex::try_lock`], but returns a guard that owns an `Arc`
    /// clone of the mutex, so it is not tied to a borrow's lifetime.
    pub fn try_lock_owned(self: &Arc<Self>) -> Option<OwnedMutexGuard<T>> {
        let old_state = self.state.fetch_or(IS_LOCKED, Ordering::Acquire);
        if (old_state & IS_LOCKED) == 0 {
            Some(OwnedMutexGuard { mutex: self.clone() })
        } else {
            None
        }
    }

    /// Returns a future that resolves to a guard once the lock is acquired.
    pub fn lock(&self) -> MutexLockFuture<'_, T> {
        MutexLockFuture { mutex: Some(self), wait_key: WAIT_KEY_NONE }
    }

    /// Owned variant of [`Mutex::lock`]; consumes the `Arc`, so the future
    /// borrows nothing.
    pub fn lock_owned(self: Arc<Self>) -> OwnedMutexLockFuture<T> {
        OwnedMutexLockFuture { mutex: Some(self), wait_key: WAIT_KEY_NONE }
    }

    /// Returns a mutable reference to the value without locking.
    pub fn get_mut(&mut self) -> &mut T {
        // SAFETY: `&mut self` guarantees exclusive access to the mutex,
        // hence to the value inside the UnsafeCell.
        unsafe { &mut *self.value.get() }
    }

    /// Removes this future's entry from the waiter slab, if registered.
    ///
    /// If the entry was already `Woken` and `wake_another` is set, the
    /// notification is forwarded to another waiter so a wakeup is never
    /// lost (used when a lock future is dropped after being woken but
    /// before it could take the lock).
    fn remove_waker(&self, wait_key: usize, wake_another: bool) {
        if wait_key != WAIT_KEY_NONE {
            let mut waiters = self.waiters.lock().unwrap();
            match waiters.remove(wait_key) {
                Waiter::Waiting(_) => {}
                Waiter::Woken => {
                    // This waiter consumed a wakeup it will never act on;
                    // pass it along so another task gets a chance to lock.
                    if wake_another {
                        if let Some((_i, waiter)) = waiters.iter_mut().next() {
                            waiter.wake();
                        }
                    }
                }
            }
            if waiters.is_empty() {
                // Last waiter gone: clear the flag so `unlock` can skip
                // taking the waiter lock.
                self.state.fetch_and(!HAS_WAITERS, Ordering::Relaxed);
            }
        }
    }

    /// Releases the lock and wakes one waiter, if any are registered.
    fn unlock(&self) {
        let old_state = self.state.fetch_and(!IS_LOCKED, Ordering::AcqRel);
        if (old_state & HAS_WAITERS) != 0 {
            let mut waiters = self.waiters.lock().unwrap();
            if let Some((_i, waiter)) = waiters.iter_mut().next() {
                waiter.wake();
            }
        }
    }
}
196
197const WAIT_KEY_NONE: usize = usize::MAX;
199
/// Future returned by [`Mutex::lock_owned`]; resolves to an [`OwnedMutexGuard`].
pub struct OwnedMutexLockFuture<T: ?Sized> {
    // `None` once the future has completed (lock acquired).
    mutex: Option<Arc<Mutex<T>>>,
    // Index of this future's entry in the waiter slab, or WAIT_KEY_NONE.
    wait_key: usize,
}
206
207impl<T: ?Sized> fmt::Debug for OwnedMutexLockFuture<T> {
208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209 f.debug_struct("OwnedMutexLockFuture")
210 .field("was_acquired", &self.mutex.is_none())
211 .field("mutex", &self.mutex)
212 .field(
213 "wait_key",
214 &(if self.wait_key == WAIT_KEY_NONE { None } else { Some(self.wait_key) }),
215 )
216 .finish()
217 }
218}
219
220impl<T: ?Sized> FusedFuture for OwnedMutexLockFuture<T> {
221 fn is_terminated(&self) -> bool {
222 self.mutex.is_none()
223 }
224}
225
impl<T: ?Sized> Future for OwnedMutexLockFuture<T> {
    type Output = OwnedMutexGuard<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Pin::get_mut` requires `Self: Unpin`, which holds automatically
        // (fields are Option<Arc<_>> and usize); no structural pinning.
        let this = self.get_mut();

        let mutex = this.mutex.as_ref().expect("polled OwnedMutexLockFuture after completion");

        // Fast path: grab the lock before touching the waiter slab.
        if let Some(lock) = mutex.try_lock_owned() {
            mutex.remove_waker(this.wait_key, false);
            this.mutex = None;
            return Poll::Ready(lock);
        }

        {
            let mut waiters = mutex.waiters.lock().unwrap();
            if this.wait_key == WAIT_KEY_NONE {
                // First failed poll: register a fresh waiter entry.
                this.wait_key = waiters.insert(Waiter::Waiting(cx.waker().clone()));
                if waiters.len() == 1 {
                    // First waiter: tell `unlock` it must check the slab.
                    mutex.state.fetch_or(HAS_WAITERS, Ordering::Relaxed);
                }
            } else {
                // Already registered: refresh the stored waker.
                waiters[this.wait_key].register(cx.waker());
            }
        }

        // Retry after registering, closing the race where the lock was
        // released between the first try_lock and the waiter insertion —
        // without this the future could sleep forever.
        if let Some(lock) = mutex.try_lock_owned() {
            mutex.remove_waker(this.wait_key, false);
            this.mutex = None;
            return Poll::Ready(lock);
        }

        Poll::Pending
    }
}
263
impl<T: ?Sized> Drop for OwnedMutexLockFuture<T> {
    fn drop(&mut self) {
        // `mutex` is still `Some` only if the future never completed.
        // Deregister, and (wake_another = true) forward any pending wakeup
        // to another waiter so the notification is not lost.
        if let Some(mutex) = self.mutex.as_ref() {
            mutex.remove_waker(self.wait_key, true);
        }
    }
}
275
/// RAII guard that releases the mutex on drop; owns an `Arc` to the mutex,
/// so it is not tied to a borrow's lifetime.
#[clippy::has_significant_drop]
pub struct OwnedMutexGuard<T: ?Sized> {
    mutex: Arc<Mutex<T>>,
}
283
284impl<T: ?Sized + fmt::Debug> fmt::Debug for OwnedMutexGuard<T> {
285 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286 f.debug_struct("OwnedMutexGuard")
287 .field("value", &&**self)
288 .field("mutex", &self.mutex)
289 .finish()
290 }
291}
292
impl<T: ?Sized> Drop for OwnedMutexGuard<T> {
    /// Releases the lock and wakes one waiting task, if any.
    fn drop(&mut self) {
        self.mutex.unlock()
    }
}

impl<T: ?Sized> Deref for OwnedMutexGuard<T> {
    type Target = T;
    fn deref(&self) -> &T {
        // SAFETY: the guard exists only while the lock is held, so no
        // mutable access to the value can exist concurrently.
        unsafe { &*self.mutex.value.get() }
    }
}

impl<T: ?Sized> DerefMut for OwnedMutexGuard<T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: holding the guard grants exclusive access to the value,
        // and `&mut self` prevents aliasing through this guard.
        unsafe { &mut *self.mutex.value.get() }
    }
}
311
/// Future returned by [`Mutex::lock`]; resolves to a [`MutexGuard`].
pub struct MutexLockFuture<'a, T: ?Sized> {
    // `None` once the future has completed (lock acquired).
    mutex: Option<&'a Mutex<T>>,
    // Index of this future's entry in the waiter slab, or WAIT_KEY_NONE.
    wait_key: usize,
}
318
319impl<T: ?Sized> fmt::Debug for MutexLockFuture<'_, T> {
320 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321 f.debug_struct("MutexLockFuture")
322 .field("was_acquired", &self.mutex.is_none())
323 .field("mutex", &self.mutex)
324 .field(
325 "wait_key",
326 &(if self.wait_key == WAIT_KEY_NONE { None } else { Some(self.wait_key) }),
327 )
328 .finish()
329 }
330}
331
332impl<T: ?Sized> FusedFuture for MutexLockFuture<'_, T> {
333 fn is_terminated(&self) -> bool {
334 self.mutex.is_none()
335 }
336}
337
impl<'a, T: ?Sized> Future for MutexLockFuture<'a, T> {
    type Output = MutexGuard<'a, T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Field access through `Pin` uses DerefMut, which requires
        // `Self: Unpin`; that holds (fields are Option<&_> and usize).
        let mutex = self.mutex.expect("polled MutexLockFuture after completion");

        // Fast path: grab the lock before touching the waiter slab.
        if let Some(lock) = mutex.try_lock() {
            mutex.remove_waker(self.wait_key, false);
            self.mutex = None;
            return Poll::Ready(lock);
        }

        {
            let mut waiters = mutex.waiters.lock().unwrap();
            if self.wait_key == WAIT_KEY_NONE {
                // First failed poll: register a fresh waiter entry.
                self.wait_key = waiters.insert(Waiter::Waiting(cx.waker().clone()));
                if waiters.len() == 1 {
                    // First waiter: tell `unlock` it must check the slab.
                    mutex.state.fetch_or(HAS_WAITERS, Ordering::Relaxed);
                }
            } else {
                // Already registered: refresh the stored waker.
                waiters[self.wait_key].register(cx.waker());
            }
        }

        // Retry after registering, closing the race where the lock was
        // released between the first try_lock and the waiter insertion —
        // without this the future could sleep forever.
        if let Some(lock) = mutex.try_lock() {
            mutex.remove_waker(self.wait_key, false);
            self.mutex = None;
            return Poll::Ready(lock);
        }

        Poll::Pending
    }
}
373
impl<T: ?Sized> Drop for MutexLockFuture<'_, T> {
    fn drop(&mut self) {
        // `mutex` is still `Some` only if the future never completed.
        // Deregister, and (wake_another = true) forward any pending wakeup
        // to another waiter so the notification is not lost.
        if let Some(mutex) = self.mutex {
            mutex.remove_waker(self.wait_key, true);
        }
    }
}
385
/// RAII guard that releases the mutex on drop; borrows the mutex.
#[clippy::has_significant_drop]
pub struct MutexGuard<'a, T: ?Sized> {
    mutex: &'a Mutex<T>,
}
393
impl<'a, T: ?Sized> MutexGuard<'a, T> {
    /// Returns a guard for data reachable from the locked value (e.g. a
    /// struct field), keeping the mutex locked.
    ///
    /// The original guard is `mem::forget`-ten so its `Drop` does not
    /// unlock; the returned `MappedMutexGuard` takes over unlocking.
    #[inline]
    pub fn map<U: ?Sized, F>(this: Self, f: F) -> MappedMutexGuard<'a, T, U>
    where
        F: FnOnce(&mut T) -> &mut U,
    {
        let mutex = this.mutex;
        // SAFETY: the lock stays held for the life of either guard, so this
        // `&mut T` cannot alias another live reference to the value.
        let value = f(unsafe { &mut *this.mutex.value.get() });
        mem::forget(this);
        MappedMutexGuard { mutex, value, _marker: PhantomData }
    }
}
423
424impl<T: ?Sized + fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
425 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426 f.debug_struct("MutexGuard").field("value", &&**self).field("mutex", &self.mutex).finish()
427 }
428}
429
impl<T: ?Sized> Drop for MutexGuard<'_, T> {
    /// Releases the lock and wakes one waiting task, if any.
    fn drop(&mut self) {
        self.mutex.unlock()
    }
}

impl<T: ?Sized> Deref for MutexGuard<'_, T> {
    type Target = T;
    fn deref(&self) -> &T {
        // SAFETY: the guard exists only while the lock is held, so no
        // mutable access to the value can exist concurrently.
        unsafe { &*self.mutex.value.get() }
    }
}

impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: holding the guard grants exclusive access to the value,
        // and `&mut self` prevents aliasing through this guard.
        unsafe { &mut *self.mutex.value.get() }
    }
}
448
/// Guard for a value derived from locked data via [`MutexGuard::map`];
/// unlocks the original mutex on drop.
#[clippy::has_significant_drop]
pub struct MappedMutexGuard<'a, T: ?Sized, U: ?Sized> {
    mutex: &'a Mutex<T>,
    // Raw pointer into the locked value; valid while the lock is held.
    value: *mut U,
    // Lets the guard act like `&'a mut U` for borrow/variance purposes
    // without storing a reference.
    _marker: PhantomData<&'a mut U>,
}
457
impl<'a, T: ?Sized, U: ?Sized> MappedMutexGuard<'a, T, U> {
    /// Further narrows the guarded view, keeping the mutex locked.
    ///
    /// As in [`MutexGuard::map`], the old guard is forgotten so only the
    /// new guard's `Drop` unlocks the mutex.
    #[inline]
    pub fn map<V: ?Sized, F>(this: Self, f: F) -> MappedMutexGuard<'a, T, V>
    where
        F: FnOnce(&mut U) -> &mut V,
    {
        let mutex = this.mutex;
        // SAFETY: the lock stays held for the life of either guard, so the
        // pointer is valid and the `&mut U` cannot alias.
        let value = f(unsafe { &mut *this.value });
        mem::forget(this);
        MappedMutexGuard { mutex, value, _marker: PhantomData }
    }
}
488
489impl<T: ?Sized, U: ?Sized + fmt::Debug> fmt::Debug for MappedMutexGuard<'_, T, U> {
490 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
491 f.debug_struct("MappedMutexGuard")
492 .field("value", &&**self)
493 .field("mutex", &self.mutex)
494 .finish()
495 }
496}
497
impl<T: ?Sized, U: ?Sized> Drop for MappedMutexGuard<'_, T, U> {
    /// Releases the lock and wakes one waiting task, if any.
    fn drop(&mut self) {
        self.mutex.unlock()
    }
}

impl<T: ?Sized, U: ?Sized> Deref for MappedMutexGuard<'_, T, U> {
    type Target = U;
    fn deref(&self) -> &U {
        // SAFETY: `value` points into the locked data and the lock is held
        // for the guard's lifetime, so the pointer is valid and unaliased
        // by any mutable reference.
        unsafe { &*self.value }
    }
}

impl<T: ?Sized, U: ?Sized> DerefMut for MappedMutexGuard<'_, T, U> {
    fn deref_mut(&mut self) -> &mut U {
        // SAFETY: as above, plus `&mut self` prevents aliasing through
        // this guard.
        unsafe { &mut *self.value }
    }
}
516
// SAFETY: the mutex moves `T` across threads only when `T: Send`.
unsafe impl<T: ?Sized + Send> Send for Mutex<T> {}
// SAFETY: `Sync` needs only `T: Send` (not `T: Sync`) because a shared
// `&Mutex<T>` never exposes `&T` without first acquiring the lock, which
// grants exclusive access.
unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}

// SAFETY: the future holds only a reference to the mutex; `T: Send` is
// required because completing it yields a guard with access to `T`.
unsafe impl<T: ?Sized + Send> Send for MutexLockFuture<'_, T> {}

// SAFETY: no `&self` method of the future gives access to `T`, so sharing
// the future across threads cannot expose the protected value.
unsafe impl<T: ?Sized> Sync for MutexLockFuture<'_, T> {}

// SAFETY: same reasoning as the borrowed future, with an `Arc` instead of
// a reference.
unsafe impl<T: ?Sized + Send> Send for OwnedMutexLockFuture<T> {}

unsafe impl<T: ?Sized> Sync for OwnedMutexLockFuture<T> {}

// SAFETY: guards behave like `&mut T`: sending needs `T: Send`, sharing
// (which exposes `&T` via Deref) needs `T: Sync`.
unsafe impl<T: ?Sized + Send> Send for MutexGuard<'_, T> {}
unsafe impl<T: ?Sized + Sync> Sync for MutexGuard<'_, T> {}

unsafe impl<T: ?Sized + Send> Send for OwnedMutexGuard<T> {}
unsafe impl<T: ?Sized + Sync> Sync for OwnedMutexGuard<T> {}

// SAFETY: a mapped guard additionally stores `*mut U`, which it treats as
// `&mut U`; bounds on both `T` and `U` mirror the guard rules above.
unsafe impl<T: ?Sized + Send, U: ?Sized + Send> Send for MappedMutexGuard<'_, T, U> {}
unsafe impl<T: ?Sized + Sync, U: ?Sized + Sync> Sync for MappedMutexGuard<'_, T, U> {}
546
#[cfg(test)]
mod tests {
    use super::*;
    use std::format;

    // Guards hold a reference back to their mutex, so `Debug` must not
    // recurse from the mutex field back into the locked value.
    #[test]
    fn test_mutex_guard_debug_not_recurse() {
        let mutex = Mutex::new(42);
        let guard = mutex.try_lock().unwrap();
        let _ = format!("{guard:?}");
        let mapped = MutexGuard::map(guard, |n| n);
        let _ = format!("{mapped:?}");
    }
}