peace_lock/
mutex.rs

1#[cfg(feature = "owning_ref")]
2use owning_ref::StableAddress;
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5#[cfg(any(debug_assertions, feature = "check"))]
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::{
8    cell::UnsafeCell,
9    ops::{Deref, DerefMut},
10    panic::{RefUnwindSafe, UnwindSafe},
11};
12
/// A mutual-exclusion lock that never blocks.
///
/// When `debug_assertions` or the `check` feature is enabled, an atomic flag
/// records whether the lock is currently held, so that locking an already-held
/// `Mutex` is detected at runtime. Otherwise the flag is compiled out and the
/// `Mutex` carries no bookkeeping beyond the wrapped value.
#[derive(Debug)]
pub struct Mutex<T: ?Sized> {
    /// `true` while the lock is held; only present when checking is enabled.
    #[cfg(any(debug_assertions, feature = "check"))]
    state: AtomicBool,
    /// The protected value. It must remain the last field because `T` may be
    /// unsized.
    value: UnsafeCell<T>,
}
20
21impl<T> RefUnwindSafe for Mutex<T> where T: ?Sized {}
22impl<T> UnwindSafe for Mutex<T> where T: ?Sized {}
23unsafe impl<T> Send for Mutex<T> where T: ?Sized + Send {}
24unsafe impl<T> Sync for Mutex<T> where T: ?Sized + Send {}
25
26impl<T> From<T> for Mutex<T> {
27    fn from(val: T) -> Self {
28        Self::new(val)
29    }
30}
31
32impl<T> Default for Mutex<T>
33where
34    T: ?Sized + Default,
35{
36    fn default() -> Self {
37        Self::new(T::default())
38    }
39}
40
41impl<T> Mutex<T> {
42    /// Create a new `Mutex`.
43    #[inline]
44    pub fn new(val: T) -> Self {
45        Self {
46            #[cfg(any(debug_assertions, feature = "check"))]
47            state: AtomicBool::new(false),
48            value: UnsafeCell::new(val),
49        }
50    }
51
52    /// Consume the `Mutex`, returning the inner value.
53    #[inline]
54    pub fn into_inner(self) -> T {
55        self.value.into_inner()
56    }
57}
58
59impl<T> Mutex<T>
60where
61    T: ?Sized,
62{
63    /// Get a mutable reference of the inner value T. This is safe because we
64    /// have the mutable reference of the lock.
65    #[inline]
66    pub fn get_mut(&mut self) -> &mut T {
67        self.value.get_mut()
68    }
69
70    /// Try lock the `Mutex`, returns the mutex guard. Returns None if the
71    /// `Mutex` is write locked.
72    #[inline]
73    pub fn try_lock<'a>(&'a self) -> Option<MutexGuard<'a, T>> {
74        self.lock_exclusive().then(|| MutexGuard { lock: self })
75    }
76
77    /// Lock the `Mutex`, returns the mutex guard.
78    ///
79    /// # Panics
80    ///
81    /// If the `Mutex` is already locked, this will panic if the `check` feature
82    /// is turned on.
83    #[inline]
84    pub fn lock<'a>(&'a self) -> MutexGuard<'a, T> {
85        if !self.lock_exclusive() {
86            #[cfg(any(debug_assertions, feature = "check"))]
87            panic!("The lock is already write locked")
88        }
89
90        MutexGuard { lock: self }
91    }
92
93    #[inline]
94    fn lock_exclusive(&self) -> bool {
95        #[cfg(any(debug_assertions, feature = "check"))]
96        {
97            self.state
98                .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
99                .is_ok()
100        }
101
102        #[cfg(not(any(debug_assertions, feature = "check")))]
103        true
104    }
105
106    #[inline]
107    fn unlock_exclusive(&self) -> bool {
108        #[cfg(any(debug_assertions, feature = "check"))]
109        {
110            self.state
111                .compare_exchange(true, false, Ordering::Acquire, Ordering::Relaxed)
112                .is_ok()
113        }
114
115        #[cfg(not(any(debug_assertions, feature = "check")))]
116        true
117    }
118}
119
120pub struct MutexGuard<'a, T>
121where
122    T: ?Sized,
123{
124    lock: &'a Mutex<T>,
125}
126
127impl<'a, T> Deref for MutexGuard<'a, T>
128where
129    T: ?Sized,
130{
131    type Target = T;
132
133    #[inline]
134    fn deref(&self) -> &T {
135        unsafe { &*self.lock.value.get() }
136    }
137}
138
139impl<'a, T> DerefMut for MutexGuard<'a, T>
140where
141    T: ?Sized,
142{
143    #[inline]
144    fn deref_mut(&mut self) -> &mut T {
145        unsafe { &mut *self.lock.value.get() }
146    }
147}
148
149impl<'a, T> Drop for MutexGuard<'a, T>
150where
151    T: ?Sized,
152{
153    #[inline]
154    fn drop(&mut self) {
155        self.lock.unlock_exclusive();
156    }
157}
158
// SAFETY: the guard derefs into the `UnsafeCell` inside the borrowed `Mutex`.
// That address does not change while the borrow (and thus the guard) lives, so
// moving the guard never moves its deref target.
#[cfg(feature = "owning_ref")]
unsafe impl<'a, T> StableAddress for MutexGuard<'a, T> where T: ?Sized + 'a {}
161
/// Serializes the protected value while briefly holding the lock.
///
/// # Panics
///
/// Panics if the `Mutex` is already locked and `debug_assertions` or the
/// `check` feature is enabled, because this goes through `Mutex::lock`.
#[cfg(feature = "serde")]
impl<T: Serialize + ?Sized> Serialize for Mutex<T> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let guard = self.lock();
        T::serialize(&*guard, serializer)
    }
}
174
/// Deserializes a `T` and wraps it in a new, unlocked `Mutex`.
#[cfg(feature = "serde")]
impl<'de, T: Deserialize<'de> + ?Sized> Deserialize<'de> for Mutex<T> {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let inner = T::deserialize(deserializer)?;
        Ok(Mutex::new(inner))
    }
}