use std::{
fmt::{Debug, Display, Formatter},
ops::{Deref, DerefMut},
sync::LockResult,
};
use crate::{
level::{Level, LevelGuard},
map_guard,
};
#[derive(Debug, Default)]
pub struct Mutex<T> {
inner: std::sync::Mutex<T>,
level: Level,
}
impl<T> Mutex<T> {
pub fn new(t: T) -> Self {
Self::with_level(t, 0)
}
pub fn with_level(t: T, level: u32) -> Self {
Mutex {
inner: std::sync::Mutex::new(t),
level: Level::new(level),
}
}
pub fn lock(&self) -> LockResult<MutexGuard<T>> {
let level = self.level.lock();
map_guard(self.inner.lock(), |guard| MutexGuard {
inner: guard,
_level: level,
})
}
pub fn get_mut(&mut self) -> LockResult<&mut T> {
self.inner.get_mut()
}
pub fn into_inner(self) -> LockResult<T> {
self.inner.into_inner()
}
}
impl<T> From<T> for Mutex<T> {
fn from(value: T) -> Self {
Mutex::new(value)
}
}
pub struct MutexGuard<'a, T> {
inner: std::sync::MutexGuard<'a, T>,
_level: LevelGuard,
}
impl<T: Debug> Debug for MutexGuard<'_, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.inner, f)
}
}
impl<T: Display> Display for MutexGuard<'_, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.inner, f)
}
}
impl<T> Deref for MutexGuard<'_, T> {
type Target = T;
fn deref(&self) -> &T {
self.inner.deref()
}
}
impl<T> DerefMut for MutexGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner.deref_mut()
}
}
#[cfg(test)]
mod tests {
use std::{hint::black_box, sync::Arc, thread};
use super::*;
#[test]
fn acquire_resource() {
let mutex = Mutex::new(42);
let guard = mutex.lock().unwrap();
assert_eq!(42, *guard)
}
#[test]
fn allow_mutation() {
let mutex = Mutex::new(42);
let mut guard = mutex.lock().unwrap();
*guard = 43;
assert_eq!(43, *guard)
}
#[test]
fn multithreaded() {
let mutex = Arc::new(Mutex::new(()));
let thread = thread::spawn({
let mutex = mutex.clone();
move || {
black_box(mutex.lock().unwrap());
}
});
black_box(mutex.lock().unwrap());
thread.join().unwrap();
}
#[test]
#[should_panic(
expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
)]
#[cfg(debug_assertions)]
fn self_deadlock() {
let mutex = Mutex::new(());
let _guard = mutex.lock().unwrap();
let _guard = mutex.lock().unwrap();
}
#[test]
#[should_panic(
expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
)]
#[cfg(debug_assertions)]
fn poisoned_lock() {
let mutex = Mutex::new(());
std::panic::catch_unwind(|| {
let _guard = mutex.lock();
panic!("lock is poisoned now");
})
.unwrap_err();
let _guard_a = mutex.lock().unwrap_err().into_inner();
let _guard_b = mutex.lock();
}
#[test]
#[cfg(debug_assertions)]
fn correct_level_locked() {
let mutex = Mutex::with_level((), 1);
let _guard_a = mutex.lock().unwrap();
assert_eq!(_guard_a._level.level, 1);
let mutex = Mutex::new(());
let _guard_a = mutex.lock().unwrap();
assert_eq!(_guard_a._level.level, 0);
}
#[test]
#[cfg(debug_assertions)]
fn created_by_default_impl_should_be_level_0() {
let mutex = Mutex::<()>::default();
assert_eq!(mutex.level.level, 0);
}
#[test]
#[cfg(debug_assertions)]
fn mutex_created_by_from_impl_should_be_level_0() {
let mutex: Mutex<u8> = 42.into();
assert_eq!(mutex.level.level, 0);
}
}