// lock_hierarchy/rwlock.rs

1use std::{
2    fmt::{Debug, Display, Formatter},
3    ops::{Deref, DerefMut},
4    sync::LockResult,
5};
6
7use crate::{
8    level::{Level, LevelGuard},
9    map_guard,
10};
11
12/// Wrapper around a [`std::sync::RwLock`] which uses a thread local variable in order to check for
13/// lock hierarchy violations in debug builds.
14///
15/// See the [crate level documentation](crate) for more general information.
16///
17/// ```
18/// use lock_hierarchy::RwLock;
19///
20/// let mutex_a = RwLock::new(()); // Level 0
21/// let mutex_b = RwLock::with_level((), 0); // also level 0
22/// // Fine, first mutex in thread
23/// let _guard_a = mutex_a.read().unwrap();
24/// // Would panic, lock hierarchy violation
25/// // let _guard_b = mutex_b.read().unwrap();
26/// ```
27#[derive(Debug, Default)]
28pub struct RwLock<T> {
29    inner: std::sync::RwLock<T>,
30    level: Level,
31}
32
33impl<T> RwLock<T> {
34    /// Creates a lock with level 0. Use this constructor if you want to get an error in debug builds
35    /// every time you acquire another lock while holding this one.
36    pub fn new(t: T) -> Self {
37        Self::with_level(t, 0)
38    }
39
40    /// Creates a lock and assigns it a level in the lock hierarchy. Higher levels must be acquired
41    /// first if locks are to be held simultaneously. This way we can ensure locks are always
42    /// acquired in the same order. This prevents deadlocks.
43    pub fn with_level(t: T, level: u32) -> Self {
44        RwLock {
45            inner: std::sync::RwLock::new(t),
46            level: Level::new(level),
47        }
48    }
49
50    /// See [std::sync::RwLock::read]
51    pub fn read(&self) -> LockResult<RwLockReadGuard<T>> {
52        let level = self.level.lock();
53        map_guard(self.inner.read(), |guard| RwLockReadGuard {
54            inner: guard,
55            _level: level,
56        })
57    }
58
59    /// See [std::sync::RwLock::write]
60    pub fn write(&self) -> LockResult<RwLockWriteGuard<T>> {
61        let level = self.level.lock();
62        map_guard(self.inner.write(), |guard| RwLockWriteGuard {
63            inner: guard,
64            _level: level,
65        })
66    }
67
68    /// See [std::sync::RwLock::get_mut]
69    pub fn get_mut(&mut self) -> LockResult<&mut T> {
70        // No need to check hierarchy, this does not lock
71        self.inner.get_mut()
72    }
73
74    /// See [std::sync::RwLock::into_inner]
75    pub fn into_inner(self) -> LockResult<T> {
76        // No need to check hierarchy, this does not lock
77        self.inner.into_inner()
78    }
79}
80
81impl<T> From<T> for RwLock<T> {
82    /// Creates a new mutex in an unlocked state ready for use.
83    /// This is equivalent to [`RwLock::new`].
84    fn from(value: T) -> Self {
85        RwLock::new(value)
86    }
87}
88
89pub struct RwLockReadGuard<'a, T> {
90    inner: std::sync::RwLockReadGuard<'a, T>,
91    _level: LevelGuard,
92}
93
94impl<T: Debug> Debug for RwLockReadGuard<'_, T> {
95    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
96        Debug::fmt(&self.inner, f)
97    }
98}
99
100impl<T: Display> Display for RwLockReadGuard<'_, T> {
101    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102        Display::fmt(&self.inner, f)
103    }
104}
105
106impl<T> Deref for RwLockReadGuard<'_, T> {
107    type Target = T;
108
109    fn deref(&self) -> &T {
110        self.inner.deref()
111    }
112}
113
114pub struct RwLockWriteGuard<'a, T> {
115    inner: std::sync::RwLockWriteGuard<'a, T>,
116    _level: LevelGuard,
117}
118
119impl<T: Debug> Debug for RwLockWriteGuard<'_, T> {
120    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
121        Debug::fmt(&self.inner, f)
122    }
123}
124
125impl<T: Display> Display for RwLockWriteGuard<'_, T> {
126    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
127        Display::fmt(&self.inner, f)
128    }
129}
130
131impl<T> Deref for RwLockWriteGuard<'_, T> {
132    type Target = T;
133
134    fn deref(&self) -> &T {
135        self.inner.deref()
136    }
137}
138
139impl<T> DerefMut for RwLockWriteGuard<'_, T> {
140    fn deref_mut(&mut self) -> &mut Self::Target {
141        self.inner.deref_mut()
142    }
143}
144
#[cfg(test)]
mod tests {
    use std::{hint::black_box, sync::Arc, thread};

    use super::*;

    #[test]
    fn acquire_resource() {
        let mutex = RwLock::new(42);
        let guard = mutex.read().unwrap();
        assert_eq!(42, *guard);
        drop(guard);

        let guard = mutex.write().unwrap();
        assert_eq!(42, *guard);
        drop(guard);
    }

    #[test]
    fn allow_mutation() {
        let mutex = RwLock::new(42);
        let mut guard = mutex.write().unwrap();

        *guard = 43;

        assert_eq!(43, *guard)
    }

    #[test]
    fn multithreaded() {
        // Levels are tracked per thread, so two threads using the same lock must not conflict.
        let mutex = Arc::new(RwLock::new(()));
        let thread = thread::spawn({
            let mutex = mutex.clone();
            move || {
                black_box(mutex.read().unwrap());
                black_box(mutex.write().unwrap());
            }
        });
        black_box(mutex.read().unwrap());
        black_box(mutex.write().unwrap());
        thread.join().unwrap();
    }

    #[test]
    fn nested_locks_in_hierarchy_order() {
        // Holding a higher level while acquiring a lower one is the permitted order.
        let outer = RwLock::with_level(1, 1);
        let inner = RwLock::new(2);

        let mut outer_guard = outer.write().unwrap();
        let inner_guard = inner.read().unwrap();
        *outer_guard += *inner_guard;

        assert_eq!(3, *outer_guard);
    }

    #[test]
    #[should_panic(expected = "This is a violation of lock hierarchies which could lead to deadlocks.")]
    #[cfg(debug_assertions)]
    fn nested_locks_against_hierarchy_order() {
        // Holding level 0 and then acquiring level 1 breaks the "higher levels first" rule.
        let outer = RwLock::new(());
        let inner = RwLock::with_level((), 1);

        let _guard_outer = outer.read().unwrap();
        let _guard_inner = inner.write().unwrap();
    }

    /// Returns a level-0 lock that was poisoned by a panic while its write guard was held.
    #[cfg(debug_assertions)]
    fn poisoned_lock() -> RwLock<()> {
        let mutex = RwLock::new(());
        std::panic::catch_unwind(|| {
            let _guard = mutex.write();
            panic!("lock is poisoned now");
        })
        .unwrap_err();
        mutex
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn poisoned_read_lock() {
        let mutex = poisoned_lock();

        // The guard inside the poison error must still hold the level.
        let _guard_a = mutex.read().unwrap_err().into_inner();
        let _guard_b = mutex.read();
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn poisoned_write_lock() {
        let mutex = poisoned_lock();

        // The guard inside the poison error must still hold the level.
        let _guard_a = mutex.write().unwrap_err().into_inner();
        let _guard_b = mutex.write();
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn self_deadlock_write() {
        // This ensures that the level is locked in RwLock::write before locking the std lock which might otherwise cause a deadlock
        let mutex = RwLock::new(());
        let _guard = mutex.read().unwrap();
        let _guard = mutex.write().unwrap();
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn self_deadlock_read() {
        // This ensures that the level is locked in RwLock::read before locking the std lock which might otherwise cause an unchecked deadlock
        let mutex = RwLock::new(());
        let _guard = mutex.read().unwrap();
        let _guard = mutex.read().unwrap();
    }

    #[test]
    #[cfg(debug_assertions)]
    fn correct_level_locked() {
        let mutex = RwLock::with_level((), 1);
        let guard = mutex.read().unwrap();
        assert_eq!(guard._level.level, 1);
        drop(guard);
        let guard = mutex.write().unwrap();
        assert_eq!(guard._level.level, 1);
        drop(guard);

        let mutex = RwLock::new(());
        let guard = mutex.read().unwrap();
        assert_eq!(guard._level.level, 0);
        drop(guard);
        let guard = mutex.write().unwrap();
        assert_eq!(guard._level.level, 0);
        drop(guard);
    }

    #[test]
    #[cfg(debug_assertions)]
    fn created_by_default_impl_should_be_level_0() {
        let mutex = RwLock::<()>::default();
        assert_eq!(mutex.level.level, 0);
    }

    #[test]
    #[cfg(debug_assertions)]
    fn mutex_created_by_from_impl_should_be_level_0() {
        let mutex: RwLock<u8> = 42.into();
        assert_eq!(mutex.level.level, 0);
    }
}