use parking_lot::{Condvar, Mutex};
use std::{
cell::UnsafeCell,
ops::{Deref, DerefMut},
thread::ThreadId,
};
/// A mutual-exclusion lock that panics, rather than deadlocking, when the
/// thread that currently holds it tries to lock it again.
///
/// Ownership is recorded as the [`ThreadId`] of the locking thread; other
/// threads wait on a condition variable until that slot is cleared.
pub struct ReentrantSafeMutex<T>
where
    T: ?Sized,
{
    /// Notified when the lock is released so a waiting thread can retry.
    condvar: Condvar,
    /// `Some(id)` while thread `id` holds the lock, `None` while it is free.
    thread: Mutex<Option<ThreadId>>,
    /// The protected data; reached only through a guard.
    /// Must stay the last field because `T` may be unsized.
    value: UnsafeCell<T>,
}
/// RAII guard returned by [`ReentrantSafeMutex::lock`].
///
/// While the guard is alive it grants exclusive access to the protected value
/// through [`Deref`]/[`DerefMut`]; dropping it releases the lock.
#[must_use = "if unused the ReentrantSafeMutex will immediately unlock"]
pub struct ReentrantSafeMutexGuard<'a, T> {
    mutex: &'a ReentrantSafeMutex<T>,
    // The guard hands out `&T`/`&mut T`, so its auto traits must be at least as
    // strict as those of `&'a mut T`. Without this marker the guard inherits
    // `Sync` from `&ReentrantSafeMutex<T>` whenever `T: Send`, which would let
    // several threads share one guard and reach a `!Sync` value (e.g. `Cell`)
    // concurrently. With it, `Sync` additionally requires `T: Sync`; `Send`
    // (still `T: Send`) and variance are unchanged.
    _marker: std::marker::PhantomData<&'a mut T>,
}

impl<T> ReentrantSafeMutex<T> {
    /// Creates a new, unlocked mutex protecting `value`.
    pub fn new(value: T) -> Self {
        Self {
            thread: Mutex::new(None),
            condvar: Condvar::new(),
            value: UnsafeCell::new(value),
        }
    }

    /// Acquires the mutex, blocking the current thread until it is available.
    ///
    /// Ownership is recorded as the calling thread's [`ThreadId`] and is only
    /// cleared when the returned guard is dropped.
    ///
    /// # Panics
    ///
    /// Panics with `"Reentrant locking attempt"` if the recorded owner is the
    /// calling thread. A plain mutex would deadlock in that situation.
    pub fn lock(&self) -> ReentrantSafeMutexGuard<'_, T> {
        let current_thread_id = std::thread::current().id();
        let mut thread = self.thread.lock();
        // Re-check after every wakeup: another waiter may have taken the lock
        // first, and condition variables may wake spuriously.
        while let Some(id) = *thread {
            assert!(id != current_thread_id, "Reentrant locking attempt");
            self.condvar.wait(&mut thread);
        }
        debug_assert!((*thread).is_none());
        *thread = Some(current_thread_id);
        ReentrantSafeMutexGuard {
            mutex: self,
            _marker: std::marker::PhantomData,
        }
    }
}
impl<'a, T> Drop for ReentrantSafeMutexGuard<'a, T> {
fn drop(&mut self) {
let mut thread = self.mutex.thread.lock();
*thread = None;
self.mutex.condvar.notify_one();
}
}
impl<T> Deref for ReentrantSafeMutexGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for as long as the guard is borrowed.
    fn deref(&self) -> &T {
        let cell = &self.mutex.value;
        // SAFETY: a guard is only created after `lock` has recorded an owner,
        // and the owner is cleared only when this guard drops, so no other
        // guard can produce a conflicting `&mut T` while this borrow lives.
        unsafe { &*cell.get() }
    }
}
impl<T> DerefMut for ReentrantSafeMutexGuard<'_, T> {
    /// Mutably borrows the protected value for as long as the guard is
    /// mutably borrowed.
    fn deref_mut(&mut self) -> &mut T {
        let cell = &self.mutex.value;
        // SAFETY: this guard is the only one alive for this mutex (`lock`
        // waits until the owner slot is empty), and `&mut self` rules out any
        // other borrow obtained through this same guard.
        unsafe { &mut *cell.get() }
    }
}
// SAFETY: moving the mutex moves the `T` it owns, so `T: Send` is required
// and sufficient; no guard can be alive while the mutex is moved because a
// guard borrows it.
unsafe impl<T> Send for ReentrantSafeMutex<T> where T: ?Sized + Send {}
// SAFETY: sharing `&ReentrantSafeMutex<T>` lets threads take turns getting
// exclusive access to `T`, which amounts to handing `T` between threads, so
// `T: Send` is the required bound (the same one `std::sync::Mutex` uses).
unsafe impl<T> Sync for ReentrantSafeMutex<T> where T: ?Sized + Send {}
#[cfg(test)]
mod tests {
    use std::{cell::Cell, sync::Arc};

    use super::*;

    #[test]
    #[should_panic(expected = "Reentrant locking attempt")]
    fn panics_on_reentrant_use() {
        let m = ReentrantSafeMutex::new(1);
        let _g1 = m.lock();
        // A second lock on the owning thread must panic instead of deadlocking.
        let _g2 = m.lock();
    }

    #[test]
    fn does_not_panic_on_sequential_use() {
        let m = ReentrantSafeMutex::new(1);
        {
            let mut g1 = m.lock();
            *g1 += 1;
        }
        // The first guard has been dropped, so relocking succeeds and sees the write.
        let g2 = m.lock();
        assert_eq!(*g2, 2);
    }

    #[test]
    fn works_as_a_mutex() {
        let m = Arc::new(ReentrantSafeMutex::new(0));
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let m = Arc::clone(&m);
                // Even-indexed threads subtract 1 and odd-indexed threads add 1,
                // so the net result is 0.
                let inc = (i % 2) * 2 - 1;
                std::thread::spawn(move || {
                    for _ in 0..1000 {
                        *m.lock() += inc;
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(*m.lock(), 0);
    }

    #[test]
    fn shares_non_sync_contents_between_threads() {
        // `Cell` is `Send` but not `Sync`; the mutex must still make it shareable.
        let m = Arc::new(ReentrantSafeMutex::new(Cell::new(0)));
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let m = Arc::clone(&m);
                std::thread::spawn(move || {
                    for _ in 0..100 {
                        let guard = m.lock();
                        guard.set(guard.get() + 1);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(m.lock().get(), 400);
    }

    #[test]
    fn expected_auto_traits() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<ReentrantSafeMutex<Cell<i32>>>();
        assert_sync::<ReentrantSafeMutex<Cell<i32>>>();
        assert_send::<ReentrantSafeMutexGuard<'static, i32>>();
        assert_sync::<ReentrantSafeMutexGuard<'static, i32>>();
    }
}