lock_hierarchy/
mutex.rs

1use std::{
2    fmt::{Debug, Display, Formatter},
3    ops::{Deref, DerefMut},
4    sync::LockResult,
5};
6
7use crate::{
8    level::{Level, LevelGuard},
9    map_guard,
10};
11
12/// Wrapper around a [`std::sync::Mutex`] which uses a thread local variable in order to check for
13/// lock hierarchy violations in debug builds.
14///
15/// See the [crate level documentation](crate) for more general information.
16///
17/// ```
18/// use lock_hierarchy::Mutex;
19///
20/// let mutex_a = Mutex::new(()); // Level 0
21/// let mutex_b = Mutex::with_level((), 0); // also level 0
22/// // Fine, first mutex in thread
23/// let _guard_a = mutex_a.lock().unwrap();
24/// // Would panic, lock hierarchy violation
25/// // let _guard_b = mutex_b.lock().unwrap();
26/// ```
27#[derive(Debug, Default)]
28pub struct Mutex<T> {
29    inner: std::sync::Mutex<T>,
30    level: Level,
31}
32
33impl<T> Mutex<T> {
34    /// Creates lock with level 0. Use this constructor if you want to get an error in debug builds
35    /// every time you acquire another lock while holding this one.
36    pub fn new(t: T) -> Self {
37        Self::with_level(t, 0)
38    }
39
40    /// Creates a lock and assigns it a level in the lock hierarchy. Higher levels must be acquired
41    /// first if locks are to be held simultaneously. This way we can ensure locks are always
42    /// acquired in the same order. This prevents deadlocks.
43    pub fn with_level(t: T, level: u32) -> Self {
44        Mutex {
45            inner: std::sync::Mutex::new(t),
46            level: Level::new(level),
47        }
48    }
49
50    /// See [std::sync::Mutex::lock]
51    pub fn lock(&self) -> LockResult<MutexGuard<T>> {
52        let level = self.level.lock();
53        map_guard(self.inner.lock(), |guard| MutexGuard {
54            inner: guard,
55            _level: level,
56        })
57    }
58
59    /// See [std::sync::Mutex::get_mut]
60    pub fn get_mut(&mut self) -> LockResult<&mut T> {
61        // No need to check hierarchy, this does not lock
62        self.inner.get_mut()
63    }
64
65    /// See [std::sync::Mutex::into_inner]
66    pub fn into_inner(self) -> LockResult<T> {
67        // No need to check hierarchy, this does not lock
68        self.inner.into_inner()
69    }
70}
71
72impl<T> From<T> for Mutex<T> {
73    /// Creates a new mutex in an unlocked state ready for use.
74    /// This is equivalent to [`Mutex::new`].
75    fn from(value: T) -> Self {
76        Mutex::new(value)
77    }
78}
79
80pub struct MutexGuard<'a, T> {
81    inner: std::sync::MutexGuard<'a, T>,
82    _level: LevelGuard,
83}
84
85impl<T: Debug> Debug for MutexGuard<'_, T> {
86    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
87        Debug::fmt(&self.inner, f)
88    }
89}
90
91impl<T: Display> Display for MutexGuard<'_, T> {
92    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
93        Display::fmt(&self.inner, f)
94    }
95}
96
97impl<T> Deref for MutexGuard<'_, T> {
98    type Target = T;
99
100    fn deref(&self) -> &T {
101        self.inner.deref()
102    }
103}
104
105impl<T> DerefMut for MutexGuard<'_, T> {
106    fn deref_mut(&mut self) -> &mut Self::Target {
107        self.inner.deref_mut()
108    }
109}
110
#[cfg(test)]
mod tests {
    use std::{hint::black_box, sync::Arc, thread};

    use super::*;

    #[test]
    fn acquire_resource() {
        let mutex = Mutex::new(42);
        let value = *mutex.lock().unwrap();
        assert_eq!(value, 42);
    }

    #[test]
    fn allow_mutation() {
        let mutex = Mutex::new(42);
        let mut guard = mutex.lock().unwrap();
        *guard += 1;
        assert_eq!(*guard, 43);
    }

    #[test]
    fn multithreaded() {
        let shared = Arc::new(Mutex::new(()));
        let worker = {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                drop(black_box(shared.lock().unwrap()));
            })
        };
        drop(black_box(shared.lock().unwrap()));
        worker.join().unwrap();
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn self_deadlock() {
        // Mutex::lock must check the level before locking the std mutex; otherwise the second
        // call would block forever instead of panicking.
        let mutex = Mutex::new(());
        let _first = mutex.lock().unwrap();
        let _second = mutex.lock().unwrap();
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn poisoned_lock() {
        let mutex = Mutex::new(());

        // Poison the mutex by panicking while a guard is alive.
        let outcome = std::panic::catch_unwind(|| {
            let _guard = mutex.lock();
            panic!("lock is poisoned now");
        });
        assert!(outcome.is_err());

        // A guard recovered from a poisoned lock must still count towards the hierarchy.
        let _poisoned = mutex.lock().unwrap_err().into_inner();
        let _second = mutex.lock();
    }

    #[test]
    #[cfg(debug_assertions)]
    fn correct_level_locked() {
        let high = Mutex::with_level((), 1);
        let high_guard = high.lock().unwrap();
        assert_eq!(high_guard._level.level, 1);

        // `high_guard` is still held here, so this also shows a lower level may be acquired next.
        let low = Mutex::new(());
        let low_guard = low.lock().unwrap();
        assert_eq!(low_guard._level.level, 0);
    }

    #[test]
    #[cfg(debug_assertions)]
    fn created_by_default_impl_should_be_level_0() {
        assert_eq!(Mutex::<()>::default().level.level, 0);
    }

    #[test]
    #[cfg(debug_assertions)]
    fn mutex_created_by_from_impl_should_be_level_0() {
        let mutex = Mutex::<u8>::from(42);
        assert_eq!(mutex.level.level, 0);
    }
}