1use std::{
2 fmt::{Debug, Display, Formatter},
3 ops::{Deref, DerefMut},
4 sync::LockResult,
5};
6
7use crate::{
8 level::{Level, LevelGuard},
9 map_guard,
10};
11
12#[derive(Debug, Default)]
28pub struct Mutex<T> {
29 inner: std::sync::Mutex<T>,
30 level: Level,
31}
32
33impl<T> Mutex<T> {
34 pub fn new(t: T) -> Self {
37 Self::with_level(t, 0)
38 }
39
40 pub fn with_level(t: T, level: u32) -> Self {
44 Mutex {
45 inner: std::sync::Mutex::new(t),
46 level: Level::new(level),
47 }
48 }
49
50 pub fn lock(&self) -> LockResult<MutexGuard<T>> {
52 let level = self.level.lock();
53 map_guard(self.inner.lock(), |guard| MutexGuard {
54 inner: guard,
55 _level: level,
56 })
57 }
58
59 pub fn get_mut(&mut self) -> LockResult<&mut T> {
61 self.inner.get_mut()
63 }
64
65 pub fn into_inner(self) -> LockResult<T> {
67 self.inner.into_inner()
69 }
70}
71
72impl<T> From<T> for Mutex<T> {
73 fn from(value: T) -> Self {
76 Mutex::new(value)
77 }
78}
79
80pub struct MutexGuard<'a, T> {
81 inner: std::sync::MutexGuard<'a, T>,
82 _level: LevelGuard,
83}
84
85impl<T: Debug> Debug for MutexGuard<'_, T> {
86 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
87 Debug::fmt(&self.inner, f)
88 }
89}
90
91impl<T: Display> Display for MutexGuard<'_, T> {
92 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
93 Display::fmt(&self.inner, f)
94 }
95}
96
97impl<T> Deref for MutexGuard<'_, T> {
98 type Target = T;
99
100 fn deref(&self) -> &T {
101 self.inner.deref()
102 }
103}
104
105impl<T> DerefMut for MutexGuard<'_, T> {
106 fn deref_mut(&mut self) -> &mut Self::Target {
107 self.inner.deref_mut()
108 }
109}
110
#[cfg(test)]
mod tests {
    use std::{hint::black_box, sync::Arc, thread};

    use super::*;

    #[test]
    fn acquire_resource() {
        let mutex = Mutex::new(42);
        let guard = mutex.lock().unwrap();

        assert_eq!(42, *guard)
    }

    #[test]
    fn allow_mutation() {
        let mutex = Mutex::new(42);
        let mut guard = mutex.lock().unwrap();

        *guard = 43;

        assert_eq!(43, *guard)
    }

    #[test]
    fn multithreaded() {
        let mutex = Arc::new(Mutex::new(()));
        let thread = thread::spawn({
            let mutex = Arc::clone(&mutex);
            move || {
                // Acquire the lock and release it again right away.
                drop(black_box(mutex.lock().unwrap()));
            }
        });
        drop(black_box(mutex.lock().unwrap()));
        thread.join().unwrap();
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn self_deadlock() {
        let mutex = Mutex::new(());
        // The first guard stays alive (shadowed, not dropped) while the
        // second acquisition is attempted.
        let _guard = mutex.lock().unwrap();
        let _guard = mutex.lock().unwrap();
    }

    #[test]
    #[should_panic(
        expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
    )]
    #[cfg(debug_assertions)]
    fn poisoned_lock() {
        let mutex = Mutex::new(());
        std::panic::catch_unwind(|| {
            let _guard = mutex.lock();
            panic!("lock is poisoned now");
        })
        .unwrap_err();

        // A guard recovered from a poison error must still hold the level.
        let _guard_a = mutex.lock().unwrap_err().into_inner();
        let _guard_b = mutex.lock();
    }

    #[test]
    #[cfg(debug_assertions)]
    fn correct_level_locked() {
        let mutex = Mutex::with_level((), 1);
        let _guard_a = mutex.lock().unwrap();
        assert_eq!(_guard_a._level.level, 1);

        let mutex = Mutex::new(());
        let _guard_a = mutex.lock().unwrap();
        assert_eq!(_guard_a._level.level, 0);
    }

    #[test]
    #[cfg(debug_assertions)]
    fn created_by_default_impl_should_be_level_0() {
        let mutex = Mutex::<()>::default();
        assert_eq!(mutex.level.level, 0);
    }

    #[test]
    #[cfg(debug_assertions)]
    fn mutex_created_by_from_impl_should_be_level_0() {
        let mutex: Mutex<u8> = 42.into();
        assert_eq!(mutex.level.level, 0);
    }

    #[test]
    fn from_impl_wraps_value() {
        let mutex: Mutex<u8> = 42.into();
        assert_eq!(42, *mutex.lock().unwrap());
    }

    #[test]
    fn get_mut_and_into_inner_access_value_without_locking() {
        let mut mutex = Mutex::new(1);
        *mutex.get_mut().unwrap() += 1;
        assert_eq!(2, mutex.into_inner().unwrap());
    }

    #[test]
    fn guard_forwards_debug_and_display() {
        let mutex = Mutex::new(42);
        let guard = mutex.lock().unwrap();
        assert_eq!("42", format!("{guard:?}"));
        assert_eq!("42", guard.to_string());
    }
}