1use std::{
2 fmt::{Debug, Display, Formatter},
3 ops::{Deref, DerefMut},
4 sync::LockResult,
5};
6
7use crate::{
8 level::{Level, LevelGuard},
9 map_guard,
10};
11
12#[derive(Debug, Default)]
28pub struct RwLock<T> {
29 inner: std::sync::RwLock<T>,
30 level: Level,
31}
32
33impl<T> RwLock<T> {
34 pub fn new(t: T) -> Self {
37 Self::with_level(t, 0)
38 }
39
40 pub fn with_level(t: T, level: u32) -> Self {
44 RwLock {
45 inner: std::sync::RwLock::new(t),
46 level: Level::new(level),
47 }
48 }
49
50 pub fn read(&self) -> LockResult<RwLockReadGuard<T>> {
52 let level = self.level.lock();
53 map_guard(self.inner.read(), |guard| RwLockReadGuard {
54 inner: guard,
55 _level: level,
56 })
57 }
58
59 pub fn write(&self) -> LockResult<RwLockWriteGuard<T>> {
61 let level = self.level.lock();
62 map_guard(self.inner.write(), |guard| RwLockWriteGuard {
63 inner: guard,
64 _level: level,
65 })
66 }
67
68 pub fn get_mut(&mut self) -> LockResult<&mut T> {
70 self.inner.get_mut()
72 }
73
74 pub fn into_inner(self) -> LockResult<T> {
76 self.inner.into_inner()
78 }
79}
80
81impl<T> From<T> for RwLock<T> {
82 fn from(value: T) -> Self {
85 RwLock::new(value)
86 }
87}
88
89pub struct RwLockReadGuard<'a, T> {
90 inner: std::sync::RwLockReadGuard<'a, T>,
91 _level: LevelGuard,
92}
93
94impl<T: Debug> Debug for RwLockReadGuard<'_, T> {
95 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
96 Debug::fmt(&self.inner, f)
97 }
98}
99
100impl<T: Display> Display for RwLockReadGuard<'_, T> {
101 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102 Display::fmt(&self.inner, f)
103 }
104}
105
106impl<T> Deref for RwLockReadGuard<'_, T> {
107 type Target = T;
108
109 fn deref(&self) -> &T {
110 self.inner.deref()
111 }
112}
113
114pub struct RwLockWriteGuard<'a, T> {
115 inner: std::sync::RwLockWriteGuard<'a, T>,
116 _level: LevelGuard,
117}
118
119impl<T: Debug> Debug for RwLockWriteGuard<'_, T> {
120 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
121 Debug::fmt(&self.inner, f)
122 }
123}
124
125impl<T: Display> Display for RwLockWriteGuard<'_, T> {
126 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
127 Display::fmt(&self.inner, f)
128 }
129}
130
131impl<T> Deref for RwLockWriteGuard<'_, T> {
132 type Target = T;
133
134 fn deref(&self) -> &T {
135 self.inner.deref()
136 }
137}
138
139impl<T> DerefMut for RwLockWriteGuard<'_, T> {
140 fn deref_mut(&mut self) -> &mut Self::Target {
141 self.inner.deref_mut()
142 }
143}
144
#[cfg(test)]
mod tests {
    use std::{hint::black_box, sync::Arc, thread};

    use super::*;

    #[test]
    fn acquire_resource() {
        let lock = RwLock::new(42);
        let guard = lock.read().unwrap();
        assert_eq!(42, *guard);
        drop(guard);

        let guard = lock.write().unwrap();
        assert_eq!(42, *guard);
        drop(guard);
    }

    #[test]
    fn allow_mutation() {
        let lock = RwLock::new(42);
        let mut guard = lock.write().unwrap();

        *guard = 43;

        assert_eq!(43, *guard);
    }

    /// `get_mut` and `into_inner` go straight to the inner lock; no guard is involved.
    #[test]
    fn get_mut_and_into_inner() {
        let mut lock = RwLock::new(42);
        *lock.get_mut().unwrap() = 43;
        assert_eq!(43, lock.into_inner().unwrap());
    }

    #[test]
    fn multithreaded() {
        let lock = Arc::new(RwLock::new(()));
        let thread = thread::spawn({
            let lock = Arc::clone(&lock);
            move || {
                // Each guard is released immediately; `drop` makes that explicit.
                drop(black_box(lock.read().unwrap()));
                drop(black_box(lock.write().unwrap()));
            }
        });
        drop(black_box(lock.read().unwrap()));
        drop(black_box(lock.write().unwrap()));
        thread.join().unwrap();
    }

    /// Returns a lock that was poisoned by panicking while its write guard was held.
    #[cfg(debug_assertions)]
    fn poisoned_lock() -> RwLock<()> {
        let lock = RwLock::new(());
        std::panic::catch_unwind(|| {
            let _guard = lock.write();
            panic!("lock is poisoned now");
        })
        .unwrap_err();
        lock
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn poisoned_read_lock() {
        let lock = poisoned_lock();

        // The guard recovered from the poison error must still hold the level.
        let _guard_a = lock.read().unwrap_err().into_inner();
        let _guard_b = lock.read();
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn poisoned_write_lock() {
        let lock = poisoned_lock();

        // The guard recovered from the poison error must still hold the level.
        let _guard_a = lock.write().unwrap_err().into_inner();
        let _guard_b = lock.write();
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn self_deadlock_write() {
        let lock = RwLock::new(());
        // Shadowing does not drop, so the read guard is still held here.
        let _guard = lock.read().unwrap();
        let _guard = lock.write().unwrap();
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn self_deadlock_read() {
        let lock = RwLock::new(());
        // Recursive read locking is rejected too, not only read-then-write.
        let _guard = lock.read().unwrap();
        let _guard = lock.read().unwrap();
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn self_deadlock_read_after_write() {
        let lock = RwLock::new(());
        // The level is checked before the inner lock, so this panics instead of blocking.
        let _guard = lock.write().unwrap();
        let _guard = lock.read().unwrap();
    }

    #[test]
    #[cfg(debug_assertions)]
    fn correct_level_locked() {
        let lock = RwLock::with_level((), 1);
        let guard = lock.read().unwrap();
        assert_eq!(guard._level.level, 1);
        drop(guard);
        let guard = lock.write().unwrap();
        assert_eq!(guard._level.level, 1);
        drop(guard);

        let lock = RwLock::new(());
        let guard = lock.read().unwrap();
        assert_eq!(guard._level.level, 0);
        drop(guard);
        let guard = lock.write().unwrap();
        assert_eq!(guard._level.level, 0);
        drop(guard);
    }

    #[test]
    #[cfg(debug_assertions)]
    fn created_by_default_impl_should_be_level_0() {
        let lock = RwLock::<()>::default();
        assert_eq!(lock.level.level, 0);
    }

    #[test]
    #[cfg(debug_assertions)]
    fn mutex_created_by_from_impl_should_be_level_0() {
        let lock: RwLock<u8> = 42.into();
        assert_eq!(lock.level.level, 0);
    }
}