maybe_fut/api/sync/
mutex.rs

1mod guard;
2
3use std::sync::{PoisonError, TryLockError};
4
5pub use self::guard::MutexGuard;
6use crate::maybe_fut_constructor_sync;
7
/// A mutual exclusion primitive useful for protecting shared data.
///
/// This type wraps either a [`std::sync::Mutex`] or, when the `tokio-sync` feature is enabled,
/// a `tokio::sync::Mutex`.
/// With the std mutex inside, threads waiting for the lock are blocked until it becomes available;
/// with the Tokio mutex inside, waiting happens by awaiting the future returned by [`Mutex::lock`].
///
/// The mutex can be created via a [`Mutex::new`] constructor.
/// Each mutex has a type parameter `<T>` which represents the data that it is protecting.
///
/// The data can only be accessed through the RAII guards returned from [`Mutex::lock`] and [`Mutex::try_lock`],
/// which guarantees that the data is only ever accessed when the mutex is locked.
#[derive(Debug, Unwrap)]
#[unwrap_types(
    std(std::sync::Mutex),
    tokio(tokio::sync::Mutex),
    tokio_gated("tokio-sync")
)]
pub struct Mutex<T>(MutexInner<T>);
23
/// Inner wrapper for [`Mutex`].
///
/// Holds the concrete mutex implementation that a [`Mutex`] delegates to.
/// The `Tokio` variant only exists when the `tokio_sync` cfg is set.
#[derive(Debug)]
enum MutexInner<T> {
    /// Mutex backed by [`std::sync::Mutex`].
    Std(std::sync::Mutex<T>),
    /// Mutex backed by `tokio::sync::Mutex`.
    #[cfg(tokio_sync)]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-sync")))]
    Tokio(tokio::sync::Mutex<T>),
}
34
35impl<T> From<std::sync::Mutex<T>> for Mutex<T> {
36    fn from(mutex: std::sync::Mutex<T>) -> Self {
37        Mutex(MutexInner::Std(mutex))
38    }
39}
40
/// Wraps an already constructed `tokio::sync::Mutex` in the `Tokio` variant.
#[cfg(tokio_sync)]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-sync")))]
impl<T> From<tokio::sync::Mutex<T>> for Mutex<T> {
    fn from(inner: tokio::sync::Mutex<T>) -> Self {
        Self(MutexInner::Tokio(inner))
    }
}
48
49impl<T> Mutex<T>
50where
51    T: Sized,
52{
53    maybe_fut_constructor_sync!(
54        /// Creates a new lock in an unlocked state ready for use.
55        new(t: T) -> Self,
56        std::sync::Mutex::new,
57        tokio::sync::Mutex::new,
58        tokio_sync
59    );
60
61    /// Clear the poisoned state from a mutex.
62    ///
63    /// If the mutex is poisoned, it will remain poisoned until this function is called.
64    /// This allows recovering from a poisoned state and marking that it has recovered.
65    /// For example, if the value is overwritten by a known-good value, then the mutex can be marked as un-poisoned.
66    ///
67    /// If the inner type is a [`tokio::sync::Mutex`], this function is a no-op.
68    pub fn clear_poison(&self) {
69        #[allow(irrefutable_let_patterns)]
70        if let MutexInner::Std(mutex) = &self.0 {
71            mutex.clear_poison();
72        }
73    }
74
75    /// Returns `true` if the mutex is poisoned.
76    ///
77    /// If the inner type is a [`tokio::sync::Mutex`], this function will always return `false`
78    pub fn is_poisoned(&self) -> bool {
79        match &self.0 {
80            MutexInner::Std(mutex) => mutex.is_poisoned(),
81            #[cfg(tokio_sync)]
82            MutexInner::Tokio(_) => false, // Tokio mutexes are not poisoned
83        }
84    }
85
86    /// Acquires a mutex, blocking the current thread until it is able to do so.
87    ///
88    /// This function will block the local thread until it is available to acquire the mutex.
89    /// Upon returning, the thread is the only thread with the lock held. An RAII guard is returned to allow scoped
90    /// unlock of the lock. When the guard goes out of scope, the mutex will be unlocked.
91    pub async fn lock(
92        &self,
93    ) -> Result<MutexGuard<'_, T>, PoisonError<std::sync::MutexGuard<'_, T>>> {
94        match &self.0 {
95            MutexInner::Std(mutex) => {
96                let guard = mutex.lock()?;
97                Ok(MutexGuard::from(guard))
98            }
99            #[cfg(tokio_sync)]
100            MutexInner::Tokio(mutex) => {
101                let guard = mutex.lock().await;
102                Ok(MutexGuard::from(guard))
103            }
104        }
105    }
106
107    /// Attempts to acquire this lock.
108    ///
109    /// If the lock could not be acquired at this time, then [`TryLockError`] is returned.
110    /// Otherwise, an RAII guard is returned.
111    /// The lock will be unlocked when the guard is dropped.
112    pub async fn try_lock(
113        &self,
114    ) -> Result<MutexGuard<'_, T>, TryLockError<std::sync::MutexGuard<'_, T>>> {
115        match &self.0 {
116            MutexInner::Std(mutex) => {
117                let guard = mutex.try_lock()?;
118                Ok(MutexGuard::from(guard))
119            }
120            #[cfg(tokio_sync)]
121            MutexInner::Tokio(mutex) => {
122                let guard = mutex.try_lock().map_err(|_| TryLockError::WouldBlock)?;
123                Ok(MutexGuard::from(guard))
124            }
125        }
126    }
127}
128
129impl<T> From<T> for Mutex<T> {
130    fn from(t: T) -> Self {
131        Mutex::new(t)
132    }
133}
134
135impl<T> Default for Mutex<T>
136where
137    T: Default,
138{
139    fn default() -> Self {
140        Mutex::new(T::default())
141    }
142}
143
#[cfg(test)]
mod test {

    use super::*;
    use crate::SyncRuntime;

    // Constructors: outside a Tokio runtime the `Std` variant is expected, inside one the `Tokio`
    // variant is expected.

    #[test]
    fn test_mutex_default_sync() {
        let mutex: Mutex<i32> = Mutex::default();
        assert!(matches!(mutex.0, MutexInner::Std(_)));
    }

    #[cfg(tokio_sync)]
    #[tokio::test]
    async fn test_mutex_default_tokio_sync() {
        let mutex: Mutex<i32> = Mutex::default();
        assert!(matches!(mutex.0, MutexInner::Tokio(_)));
    }

    #[test]
    fn test_mutex_from_sync() {
        let std_mutex = std::sync::Mutex::new(42);
        let mutex: Mutex<i32> = Mutex::from(std_mutex);
        assert!(matches!(mutex.0, MutexInner::Std(_)));
    }

    #[cfg(tokio_sync)]
    #[tokio::test]
    async fn test_mutex_from_tokio() {
        let tokio_mutex = tokio::sync::Mutex::new(42);
        let mutex: Mutex<i32> = Mutex::from(tokio_mutex);
        assert!(matches!(mutex.0, MutexInner::Tokio(_)));
    }

    #[test]
    fn test_mutex_new_sync() {
        let mutex = Mutex::new(42);
        assert!(matches!(mutex.0, MutexInner::Std(_)));
    }

    #[cfg(tokio_sync)]
    #[tokio::test]
    async fn test_mutex_new_tokio_sync() {
        let mutex = Mutex::new(42);
        assert!(matches!(mutex.0, MutexInner::Tokio(_)));
    }

    #[test]
    fn test_should_lock_sync_mutex() {
        let mutex = Mutex::new(42);
        // The guard is consumed by `unwrap()` and released at the end of the statement.
        let guard = SyncRuntime::block_on(mutex.lock());
        assert_eq!(*guard.unwrap(), 42);

        // write
        let mut guard = SyncRuntime::block_on(mutex.lock()).unwrap();
        *guard = 43;
        assert_eq!(*guard, 43);
        // read
        drop(guard);
        let guard = SyncRuntime::block_on(mutex.lock()).unwrap();
        assert_eq!(*guard, 43);
    }

    #[cfg(tokio_sync)]
    #[tokio::test]
    async fn test_should_lock_tokio_mutex() {
        let mutex = Mutex::new(42);
        let guard = mutex.lock().await;
        assert_eq!(*guard.unwrap(), 42);

        // write
        let mut guard = mutex.lock().await.unwrap();
        *guard = 43;
        assert_eq!(*guard, 43);
        // read
        drop(guard);
        let guard = mutex.lock().await.unwrap();
        assert_eq!(*guard, 43);
    }

    #[test]
    fn test_should_try_lock_sync_mutex() {
        let mutex = Mutex::new(42);
        let guard = SyncRuntime::block_on(mutex.try_lock());
        assert_eq!(*guard.unwrap(), 42);

        // write
        let mut guard = SyncRuntime::block_on(mutex.try_lock()).unwrap();
        *guard = 43;
        assert_eq!(*guard, 43);
        // read
        drop(guard);
        let guard = SyncRuntime::block_on(mutex.try_lock()).unwrap();
        assert_eq!(*guard, 43);
    }

    #[cfg(tokio_sync)]
    #[tokio::test]
    async fn test_should_try_lock_tokio_mutex() {
        let mutex = Mutex::new(42);
        let guard = mutex.try_lock().await;
        assert_eq!(*guard.unwrap(), 42);

        // write
        let mut guard = mutex.try_lock().await.unwrap();
        *guard = 43;
        assert_eq!(*guard, 43);
        // read
        drop(guard);
        let guard = mutex.try_lock().await.unwrap();
        assert_eq!(*guard, 43);
    }

    #[test]
    fn test_mutex_poisoned_sync() {
        let mutex = Mutex::new(42);
        assert!(!mutex.is_poisoned());

        // Poison the mutex: panic while the guard is alive so it is dropped during unwinding.
        // `AssertUnwindSafe` is needed because the closure borrows the mutex across the panic.
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = SyncRuntime::block_on(mutex.lock()).unwrap();
            panic!("poisoning the mutex on purpose");
        }));
        assert!(outcome.is_err());
        assert!(mutex.is_poisoned());
        assert!(SyncRuntime::block_on(mutex.lock()).is_err());

        // Clearing the poison makes the mutex usable again.
        mutex.clear_poison();
        assert!(!mutex.is_poisoned());
        let guard = SyncRuntime::block_on(mutex.lock()).unwrap();
        assert_eq!(*guard, 42);
    }
}
262        assert!(!mutex.is_poisoned());
263    }
264}