use super::rank::LockRank;
use std::{cell::Cell, panic::Location};
/// A `parking_lot::Mutex` wrapper that checks, on every `lock` call, that the
/// calling thread acquires locks in an order permitted by their [`LockRank`]s.
pub struct Mutex<T> {
/// The underlying lock that actually guards the data.
inner: parking_lot::Mutex<T>,
/// This lock's rank, checked against the most recently acquired lock's rank
/// when `lock` is called.
rank: LockRank,
}
/// Guard returned by [`Mutex::lock`]. Dropping it restores the thread's lock
/// state to `saved` and then unlocks `inner`.
pub struct MutexGuard<'a, T> {
/// The underlying `parking_lot` guard; it unlocks the mutex when dropped.
inner: parking_lot::MutexGuard<'a, T>,
/// The thread's lock state from before this lock was acquired (the value
/// returned by `acquire`), restored by `release` on drop.
saved: LockState,
}
// Per-thread record of the most recently acquired ranked lock and of how many
// ranked locks this thread currently holds. Updated by `acquire` and restored by
// `release`. The `const` initializer avoids lazy-initialization overhead.
thread_local! {
static LOCK_STATE: Cell<LockState> = const { Cell::new(LockState::INITIAL) };
}
/// Snapshot of one thread's ranked-lock bookkeeping.
///
/// It is `Copy` so that guards can save a snapshot on acquire and hand it back
/// to `release` on drop.
#[derive(Debug, Copy, Clone)]
struct LockState {
/// Rank and call site of the most recently acquired lock, or `None` in the
/// initial state.
last_acquired: Option<(LockRank, &'static Location<'static>)>,
/// Number of ranked locks acquired and not yet released. `acquire` increments
/// it, and `release` uses it to detect out-of-order releases.
depth: u32,
}
impl LockState {
const INITIAL: LockState = LockState {
last_acquired: None,
depth: 0,
};
}
/// Record that the current thread is acquiring a lock of rank `new_rank` at
/// `location`, after checking the order against the last lock it acquired.
///
/// Returns the thread's lock state from *before* this acquisition. The caller
/// stores it in its guard and passes it to [`release`] when the guard drops.
///
/// # Panics
///
/// Panics if the thread already holds a lock whose rank does not list
/// `new_rank` among its permitted followers. The thread-local state is left
/// unchanged in that case.
fn acquire(new_rank: LockRank, location: &'static Location<'static>) -> LockState {
    let previous = LOCK_STATE.get();

    // `LockRank` and `&Location` are both `Copy`, so bind them by value.
    if let Some((held_rank, held_location)) = previous.last_acquired {
        assert!(
            held_rank.followers.contains(new_rank.bit),
            "Attempt to acquire nested mutexes in wrong order:\n\
             last locked {:<35} at {}\n\
             now locking {:<35} at {}\n\
             Locking {} after locking {} is not permitted.",
            held_rank.bit.name(),
            held_location,
            new_rank.bit.name(),
            location,
            new_rank.bit.name(),
            held_rank.bit.name(),
        );
    }

    let next = LockState {
        last_acquired: Some((new_rank, location)),
        depth: previous.depth + 1,
    };
    LOCK_STATE.set(next);

    previous
}
/// Restore the current thread's lock state to `saved`, the value that the
/// matching `acquire` call returned. Guard `Drop` impls call this.
///
/// # Panics
///
/// Panics with "Lock not released in stacking order" if the lock being
/// released was not the most recently acquired one still held, i.e. the
/// thread's depth is not exactly one more than `saved.depth`.
///
/// The check is skipped while the thread is already unwinding. A second panic
/// raised from a destructor during unwinding aborts the whole process and hides
/// the original panic. This happens, for example, when an out-of-order drop has
/// just panicked and the remaining guards are being dropped. The state is still
/// restored in that case.
fn release(saved: LockState) {
    let prior = LOCK_STATE.replace(saved);

    if std::thread::panicking() {
        return;
    }

    assert_eq!(
        prior.depth,
        saved.depth + 1,
        "Lock not released in stacking order"
    );
}
impl<T> Mutex<T> {
    /// Create a mutex of the given `rank` that protects `value`.
    pub fn new(rank: LockRank, value: T) -> Mutex<T> {
        let inner = parking_lot::Mutex::new(value);
        Self { inner, rank }
    }

    /// Lock the mutex and return a guard.
    ///
    /// The rank check runs *before* blocking on the underlying lock. A
    /// misordered acquisition therefore panics, rather than possibly
    /// deadlocking first.
    ///
    /// # Panics
    ///
    /// Panics if this lock's rank may not follow the rank of the lock this
    /// thread acquired most recently. The message reports the caller's location.
    #[track_caller]
    pub fn lock(&self) -> MutexGuard<'_, T> {
        // Inside a `#[track_caller]` fn this yields the location of our caller.
        let call_site = Location::caller();
        let saved = acquire(self.rank, call_site);
        let inner = self.inner.lock();
        MutexGuard { inner, saved }
    }
}
impl<T> std::ops::Deref for MutexGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected value.
    fn deref(&self) -> &T {
        &*self.inner
    }
}

impl<T> std::ops::DerefMut for MutexGuard<'_, T> {
    /// Exclusive access to the protected value.
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.inner
    }
}

impl<T> Drop for MutexGuard<'_, T> {
    /// Restore the thread's lock bookkeeping. The `inner` field, which unlocks
    /// the mutex, is dropped after this method returns.
    fn drop(&mut self) {
        let previous_state = self.saved;
        release(previous_state);
    }
}
impl<T: std::fmt::Debug> std::fmt::Debug for Mutex<T> {
    /// Format exactly as the underlying `parking_lot::Mutex` does; the rank is
    /// not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(&self.inner, f)
    }
}
/// A `parking_lot::RwLock` wrapper that checks, on every `read` or `write`
/// call, that the calling thread acquires locks in an order permitted by their
/// [`LockRank`]s.
pub struct RwLock<T> {
/// The underlying reader-writer lock that actually guards the data.
inner: parking_lot::RwLock<T>,
/// This lock's rank, used for both read and write acquisitions.
rank: LockRank,
}
/// Shared guard returned by [`RwLock::read`]. Dropping it restores the
/// thread's lock state to `saved` and then releases `inner`.
pub struct RwLockReadGuard<'a, T> {
/// The underlying `parking_lot` read guard.
inner: parking_lot::RwLockReadGuard<'a, T>,
/// The thread's lock state from before this lock was acquired.
saved: LockState,
}
/// Exclusive guard returned by [`RwLock::write`]. Dropping it restores the
/// thread's lock state to `saved` and then releases `inner`.
pub struct RwLockWriteGuard<'a, T> {
/// The underlying `parking_lot` write guard.
inner: parking_lot::RwLockWriteGuard<'a, T>,
/// The thread's lock state from before this lock was acquired.
saved: LockState,
}
impl<T> RwLock<T> {
    /// Create a reader-writer lock of the given `rank` that protects `value`.
    pub fn new(rank: LockRank, value: T) -> RwLock<T> {
        let inner = parking_lot::RwLock::new(value);
        Self { inner, rank }
    }

    /// Acquire shared access.
    ///
    /// # Panics
    ///
    /// Panics, before blocking, if this lock's rank may not follow the rank of
    /// the lock this thread acquired most recently.
    #[track_caller]
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        let call_site = Location::caller();
        let saved = acquire(self.rank, call_site);
        let inner = self.inner.read();
        RwLockReadGuard { inner, saved }
    }

    /// Acquire exclusive access.
    ///
    /// # Panics
    ///
    /// Panics, before blocking, if this lock's rank may not follow the rank of
    /// the lock this thread acquired most recently.
    #[track_caller]
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        let call_site = Location::caller();
        let saved = acquire(self.rank, call_site);
        let inner = self.inner.write();
        RwLockWriteGuard { inner, saved }
    }
}
impl<T: std::fmt::Debug> std::fmt::Debug for RwLock<T> {
    /// Format exactly as the underlying `parking_lot::RwLock` does; the rank is
    /// not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(&self.inner, f)
    }
}
impl<T> std::ops::Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected value.
    fn deref(&self) -> &T {
        &*self.inner
    }
}

impl<T> Drop for RwLockReadGuard<'_, T> {
    /// Restore the thread's lock bookkeeping. The `inner` read guard is dropped
    /// after this method returns.
    fn drop(&mut self) {
        let previous_state = self.saved;
        release(previous_state);
    }
}

impl<T> std::ops::Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected value.
    fn deref(&self) -> &T {
        &*self.inner
    }
}

impl<T> std::ops::DerefMut for RwLockWriteGuard<'_, T> {
    /// Exclusive access to the protected value.
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.inner
    }
}

impl<T> Drop for RwLockWriteGuard<'_, T> {
    /// Restore the thread's lock bookkeeping. The `inner` write guard is dropped
    /// after this method returns.
    fn drop(&mut self) {
        let previous_state = self.saved;
        release(previous_state);
    }
}
/// Locking a ROOK-ranked mutex while holding a PAWN-ranked one must succeed.
#[test]
fn permitted() {
use super::rank;
let lock1 = Mutex::new(rank::PAWN, ());
let lock2 = Mutex::new(rank::ROOK, ());
// Named bindings (not `_`) keep both guards alive until the end of the test.
let _guard1 = lock1.lock();
let _guard2 = lock2.lock();
}
/// Locking PAWN while holding ROOK must panic with the wrong-order message.
/// Presumably ROOK does not list PAWN as a follower (per the test name); confirm
/// in `rank`.
#[test]
#[should_panic(expected = "Locking pawn after locking rook")]
fn forbidden_unrelated() {
use super::rank;
let lock1 = Mutex::new(rank::ROOK, ());
let lock2 = Mutex::new(rank::PAWN, ());
let _guard1 = lock1.lock();
// This second acquisition is expected to panic inside `acquire`.
let _guard2 = lock2.lock();
}
/// Locking KNIGHT directly after PAWN must panic. Per the test name, KNIGHT
/// presumably comes later in the ranking but is not a *direct* follower of
/// PAWN; confirm in `rank`.
#[test]
#[should_panic(expected = "Locking knight after locking pawn")]
fn forbidden_skip() {
use super::rank;
let lock1 = Mutex::new(rank::PAWN, ());
let lock2 = Mutex::new(rank::KNIGHT, ());
let _guard1 = lock1.lock();
// This second acquisition is expected to panic inside `acquire`.
let _guard2 = lock2.lock();
}
/// Releasing guards in reverse acquisition order must restore the earlier state.
/// After ROOK is released, PAWN is again the last lock held, so locking BISHOP
/// is checked against PAWN (and must be permitted).
#[test]
fn stack_like() {
use super::rank;
let lock1 = Mutex::new(rank::PAWN, ());
let lock2 = Mutex::new(rank::ROOK, ());
let lock3 = Mutex::new(rank::BISHOP, ());
let guard1 = lock1.lock();
let guard2 = lock2.lock();
// Explicit drops pin the release order that this test is exercising.
drop(guard2);
let guard3 = lock3.lock();
drop(guard3);
drop(guard1);
}
/// Releasing an outer guard while an inner one is still recorded as held must
/// panic. Forgetting `guard2` skips its `release`, so when `guard1` drops the
/// thread's depth is 2 instead of the expected 1.
#[test]
#[should_panic(expected = "Lock not released in stacking order")]
fn non_stack_like() {
use super::rank;
let lock1 = Mutex::new(rank::PAWN, ());
let lock2 = Mutex::new(rank::ROOK, ());
let guard1 = lock1.lock();
let guard2 = lock2.lock();
// `forget` runs no destructor, so `release` is never called for `guard2`.
std::mem::forget(guard2);
drop(guard1);
}