use std::{
fmt::{Debug, Display, Formatter},
ops::{Deref, DerefMut},
sync::LockResult,
};
use crate::{
level::{Level, LevelGuard},
map_guard,
};
#[derive(Debug, Default)]
pub struct RwLock<T> {
inner: std::sync::RwLock<T>,
level: Level,
}
impl<T> RwLock<T> {
pub fn new(t: T) -> Self {
Self::with_level(t, 0)
}
pub fn with_level(t: T, level: u32) -> Self {
RwLock {
inner: std::sync::RwLock::new(t),
level: Level::new(level),
}
}
pub fn read(&self) -> LockResult<RwLockReadGuard<T>> {
let level = self.level.lock();
map_guard(self.inner.read(), |guard| RwLockReadGuard {
inner: guard,
_level: level,
})
}
pub fn write(&self) -> LockResult<RwLockWriteGuard<T>> {
let level = self.level.lock();
map_guard(self.inner.write(), |guard| RwLockWriteGuard {
inner: guard,
_level: level,
})
}
pub fn get_mut(&mut self) -> LockResult<&mut T> {
self.inner.get_mut()
}
pub fn into_inner(self) -> LockResult<T> {
self.inner.into_inner()
}
}
impl<T> From<T> for RwLock<T> {
fn from(value: T) -> Self {
RwLock::new(value)
}
}
pub struct RwLockReadGuard<'a, T> {
inner: std::sync::RwLockReadGuard<'a, T>,
_level: LevelGuard,
}
impl<T: Debug> Debug for RwLockReadGuard<'_, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.inner, f)
}
}
impl<T: Display> Display for RwLockReadGuard<'_, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.inner, f)
}
}
impl<T> Deref for RwLockReadGuard<'_, T> {
type Target = T;
fn deref(&self) -> &T {
self.inner.deref()
}
}
pub struct RwLockWriteGuard<'a, T> {
inner: std::sync::RwLockWriteGuard<'a, T>,
_level: LevelGuard,
}
impl<T: Debug> Debug for RwLockWriteGuard<'_, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.inner, f)
}
}
impl<T: Display> Display for RwLockWriteGuard<'_, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.inner, f)
}
}
impl<T> Deref for RwLockWriteGuard<'_, T> {
type Target = T;
fn deref(&self) -> &T {
self.inner.deref()
}
}
impl<T> DerefMut for RwLockWriteGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner.deref_mut()
}
}
#[cfg(test)]
mod tests {
use std::{hint::black_box, sync::Arc, thread};
use super::*;
#[test]
fn acquire_resource() {
let mutex = RwLock::new(42);
let guard = mutex.read().unwrap();
assert_eq!(42, *guard);
drop(guard);
let guard = mutex.write().unwrap();
assert_eq!(42, *guard);
drop(guard);
}
#[test]
fn allow_mutation() {
let mutex = RwLock::new(42);
let mut guard = mutex.write().unwrap();
*guard = 43;
assert_eq!(43, *guard)
}
#[test]
fn multithreaded() {
let mutex = Arc::new(RwLock::new(()));
let thread = thread::spawn({
let mutex = mutex.clone();
move || {
black_box(mutex.read().unwrap());
black_box(mutex.write().unwrap());
}
});
black_box(mutex.read().unwrap());
black_box(mutex.write().unwrap());
thread.join().unwrap();
}
#[cfg(debug_assertions)]
fn poisoned_lock() -> RwLock<()> {
let mutex = RwLock::new(());
std::panic::catch_unwind(|| {
let _guard = mutex.write();
panic!("lock is poisoned now");
})
.unwrap_err();
mutex
}
#[test]
#[should_panic(
expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
)]
#[cfg(debug_assertions)]
fn poisoned_read_lock() {
let mutex = poisoned_lock();
let _guard_a = mutex.read().unwrap_err().into_inner();
let _guard_b = mutex.read();
}
#[test]
#[should_panic(
expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
)]
#[cfg(debug_assertions)]
fn poisoned_write_lock() {
let mutex = poisoned_lock();
let _guard_a = mutex.write().unwrap_err().into_inner();
let _guard_b = mutex.write();
}
#[test]
#[should_panic(
expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
)]
#[cfg(debug_assertions)]
fn self_deadlock_write() {
let mutex = RwLock::new(());
let _guard = mutex.read().unwrap();
let _guard = mutex.write().unwrap();
}
#[test]
#[should_panic(
expected = "Tried to acquire lock with level 0 while a lock with level 0 is acquired. This is a violation of lock hierarchies which could lead to deadlocks."
)]
#[cfg(debug_assertions)]
fn self_deadlock_read() {
let mutex = RwLock::new(());
let _guard = mutex.read().unwrap();
let _guard = mutex.read().unwrap();
}
#[test]
#[cfg(debug_assertions)]
fn correct_level_locked() {
let mutex = RwLock::with_level((), 1);
let guard = mutex.read().unwrap();
assert_eq!(guard._level.level, 1);
drop(guard);
let guard = mutex.write().unwrap();
assert_eq!(guard._level.level, 1);
drop(guard);
let mutex = RwLock::new(());
let guard = mutex.read().unwrap();
assert_eq!(guard._level.level, 0);
drop(guard);
let guard = mutex.write().unwrap();
assert_eq!(guard._level.level, 0);
drop(guard);
}
#[test]
#[cfg(debug_assertions)]
fn created_by_default_impl_should_be_level_0() {
let mutex = RwLock::<()>::default();
assert_eq!(mutex.level.level, 0);
}
#[test]
#[cfg(debug_assertions)]
fn mutex_created_by_from_impl_should_be_level_0() {
let mutex: RwLock<u8> = 42.into();
assert_eq!(mutex.level.level, 0);
}
}