use core::{cell::Cell, fmt, ops, panic::Location};
use super::rank::LockRank;
// Public alias for `LockState`: this is the type returned by
// `RwLockReadGuard::forget` and accepted by `RwLock::force_unlock_read`.
pub use LockState as RankData;
/// A mutex tagged with a [`LockRank`], built on top of `parking_lot::Mutex`.
///
/// Every call to `lock` first compares this mutex's rank with the rank of the
/// lock most recently acquired by the current thread and panics if that
/// ordering is not permitted.
pub struct Mutex<T> {
/// The underlying lock that actually guards the value.
inner: parking_lot::Mutex<T>,
/// Rank that is checked and recorded each time this mutex is locked.
rank: LockRank,
}
/// Guard returned by `Mutex::lock`; dereferences to the protected value.
///
/// Struct fields are dropped in declaration order, so the underlying mutex is
/// unlocked first and the thread's rank bookkeeping is restored afterwards.
pub struct MutexGuard<'a, T> {
/// Guard of the wrapped `parking_lot` mutex.
inner: parking_lot::MutexGuard<'a, T>,
/// Lock state captured before this acquisition; written back when dropped.
saved: LockStateGuard,
}
std::thread_local! {
// Per-thread bookkeeping: the most recently acquired ranked lock and the
// current nesting depth. Starts out as `LockState::INITIAL`.
static LOCK_STATE: Cell<LockState> = const { Cell::new(LockState::INITIAL) };
}
/// Snapshot of one thread's ranked-lock bookkeeping.
///
/// `acquire` returns the snapshot that was current before the acquisition and
/// `release` writes such a snapshot back.
#[derive(Debug, Copy, Clone)]
pub struct LockState {
/// Rank and call site of the most recent acquisition; `None` in the initial
/// state.
last_acquired: Option<(LockRank, &'static Location<'static>)>,
/// Nesting counter: `acquire` stores the previous value plus one, and
/// `release` asserts that exactly one level is being undone.
depth: u32,
}
impl LockState {
const INITIAL: LockState = LockState {
last_acquired: None,
depth: 0,
};
}
/// Carries the `LockState` captured before an acquisition and hands it to
/// `release` when dropped, so lock guards do not each need their own `Drop`.
struct LockStateGuard(LockState);
impl Drop for LockStateGuard {
fn drop(&mut self) {
release(self.0)
}
}
/// Records that the current thread is about to take a lock of rank
/// `new_rank` at `location`.
///
/// Returns the thread's lock state as it was *before* this call; the caller
/// is expected to pass it to `release` once the lock is given up.
///
/// # Panics
///
/// Panics if the thread's most recent acquisition has a rank whose
/// `followers` set does not contain `new_rank.bit`.
fn acquire(new_rank: LockRank, location: &'static Location<'static>) -> LockState {
    let prior = LOCK_STATE.get();

    // The ordering rule only applies when there is an earlier acquisition on record.
    match &prior.last_acquired {
        Some((held_rank, held_at)) if !held_rank.followers.contains(new_rank.bit) => panic!(
            "Attempt to acquire nested mutexes in wrong order:\n\
             last locked {:<35} at {}\n\
             now locking {:<35} at {}\n\
             Locking {} after locking {} is not permitted.",
            held_rank.bit.member_name(),
            held_at,
            new_rank.bit.member_name(),
            location,
            new_rank.bit.member_name(),
            held_rank.bit.member_name(),
        ),
        _ => {}
    }

    // Push the new acquisition on top of the previous state.
    let next = LockState {
        last_acquired: Some((new_rank, location)),
        depth: prior.depth + 1,
    };
    LOCK_STATE.set(next);

    prior
}
/// Restores the thread's lock state to `saved`, undoing one `acquire`.
///
/// # Panics
///
/// Panics if the state being replaced is not exactly one level deeper than
/// `saved`, i.e. the lock is not being released in stacking order.
fn release(saved: LockState) {
    let popped = LOCK_STATE.replace(saved);
    let expected_depth = saved.depth + 1;
    assert_eq!(popped.depth, expected_depth, "Lock not released in stacking order");
}
impl<T> Mutex<T> {
    /// Creates a mutex of the given `rank` protecting `value`.
    pub fn new(rank: LockRank, value: T) -> Mutex<T> {
        let inner = parking_lot::Mutex::new(value);
        Self { inner, rank }
    }

    /// Locks the mutex and returns a guard for the protected value.
    ///
    /// The rank check runs before the underlying lock is taken, and the
    /// caller's source location is recorded for use in diagnostics.
    ///
    /// # Panics
    ///
    /// Panics if this mutex's rank may not follow the rank of the lock most
    /// recently acquired by the current thread.
    #[track_caller]
    pub fn lock(&self) -> MutexGuard<'_, T> {
        let prior = acquire(self.rank, Location::caller());
        let inner = self.inner.lock();
        MutexGuard {
            inner,
            saved: LockStateGuard(prior),
        }
    }
}
impl<T> ops::Deref for MutexGuard<'_, T> {
    type Target = T;

    /// Borrows the value protected by the locked mutex.
    fn deref(&self) -> &T {
        &self.inner
    }
}
impl<T> ops::DerefMut for MutexGuard<'_, T> {
    /// Mutably borrows the value protected by the locked mutex.
    fn deref_mut(&mut self) -> &mut T {
        &mut self.inner
    }
}
impl<T: fmt::Debug> fmt::Debug for Mutex<T> {
    /// Formats exactly like the wrapped `parking_lot` mutex; the rank is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.inner, f)
    }
}
/// A reader-writer lock tagged with a [`LockRank`], built on top of
/// `parking_lot::RwLock`.
///
/// Both `read` and `write` run the same rank check as `Mutex::lock` before
/// taking the underlying lock.
pub struct RwLock<T> {
/// The underlying lock that actually guards the value.
inner: parking_lot::RwLock<T>,
/// Rank that is checked and recorded on every `read` and `write`.
rank: LockRank,
}
/// Shared-access guard produced by `RwLock::read` or by
/// `RwLockWriteGuard::downgrade`.
pub struct RwLockReadGuard<'a, T> {
/// Read guard of the wrapped `parking_lot` lock.
inner: parking_lot::RwLockReadGuard<'a, T>,
/// Lock state captured before this acquisition; written back when dropped.
saved: LockStateGuard,
}
/// Exclusive-access guard produced by `RwLock::write`.
pub struct RwLockWriteGuard<'a, T> {
/// Write guard of the wrapped `parking_lot` lock.
inner: parking_lot::RwLockWriteGuard<'a, T>,
/// Lock state captured before this acquisition; written back when dropped.
saved: LockStateGuard,
}
impl<T> RwLock<T> {
    /// Creates a reader-writer lock of the given `rank` protecting `value`.
    pub fn new(rank: LockRank, value: T) -> RwLock<T> {
        let inner = parking_lot::RwLock::new(value);
        Self { inner, rank }
    }

    /// Acquires the lock for shared access.
    ///
    /// The rank check runs before the underlying lock is taken.
    ///
    /// # Panics
    ///
    /// Panics if this lock's rank may not follow the rank of the lock most
    /// recently acquired by the current thread.
    #[track_caller]
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        let prior = acquire(self.rank, Location::caller());
        let inner = self.inner.read();
        RwLockReadGuard {
            inner,
            saved: LockStateGuard(prior),
        }
    }

    /// Acquires the lock for exclusive access.
    ///
    /// The rank check runs before the underlying lock is taken.
    ///
    /// # Panics
    ///
    /// Panics if this lock's rank may not follow the rank of the lock most
    /// recently acquired by the current thread.
    #[track_caller]
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        let prior = acquire(self.rank, Location::caller());
        let inner = self.inner.write();
        RwLockWriteGuard {
            inner,
            saved: LockStateGuard(prior),
        }
    }

    /// Restores the thread's rank state to `data` and then force-unlocks one
    /// read lock on the underlying `parking_lot` lock.
    ///
    /// `data` is expected to be the value returned by
    /// `RwLockReadGuard::forget` for the read lock being given up.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of
    /// `parking_lot::RwLock::force_unlock_read`, which this method calls
    /// without any further checks.
    ///
    /// # Panics
    ///
    /// Panics (inside `release`) if the thread's current depth is not exactly
    /// one more than `data`'s depth.
    pub unsafe fn force_unlock_read(&self, data: RankData) {
        // Bookkeeping first, then the real unlock.
        release(data);
        // SAFETY: the obligations of the underlying call are forwarded to our
        // caller through this function's own safety contract.
        unsafe { self.inner.force_unlock_read() };
    }
}
impl<'a, T> RwLockReadGuard<'a, T> {
pub fn forget(this: Self) -> RankData {
core::mem::forget(this.inner);
this.saved.0
}
}
impl<'a, T> RwLockWriteGuard<'a, T> {
pub fn downgrade(this: Self) -> RwLockReadGuard<'a, T> {
RwLockReadGuard {
inner: parking_lot::RwLockWriteGuard::downgrade(this.inner),
saved: this.saved,
}
}
}
impl<T: fmt::Debug> fmt::Debug for RwLock<T> {
    /// Formats exactly like the wrapped `parking_lot` lock; the rank is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.inner, f)
    }
}
impl<T> ops::Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    /// Borrows the value protected by the read-locked lock.
    fn deref(&self) -> &T {
        &self.inner
    }
}
impl<T> ops::Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    /// Borrows the value protected by the write-locked lock.
    fn deref(&self) -> &T {
        &self.inner
    }
}
impl<T> ops::DerefMut for RwLockWriteGuard<'_, T> {
    /// Mutably borrows the value protected by the write-locked lock.
    fn deref_mut(&mut self) -> &mut T {
        &mut self.inner
    }
}
/// Taking a `ROOK` lock while holding a `PAWN` lock is expected to succeed.
#[test]
fn permitted() {
    use super::rank;

    let pawn = Mutex::new(rank::PAWN, ());
    let rook = Mutex::new(rank::ROOK, ());

    let _pawn_guard = pawn.lock();
    let _rook_guard = rook.lock();
}
/// Taking a `PAWN` lock while holding a `ROOK` lock is expected to trip the
/// rank check.
#[test]
#[should_panic(expected = "Locking pawn after locking rook")]
fn forbidden_unrelated() {
    use super::rank;

    let rook = Mutex::new(rank::ROOK, ());
    let pawn = Mutex::new(rank::PAWN, ());

    let _rook_guard = rook.lock();
    // This second acquisition is the one expected to panic.
    let _pawn_guard = pawn.lock();
}
/// Taking a `KNIGHT` lock while holding a `PAWN` lock is expected to trip the
/// rank check.
#[test]
#[should_panic(expected = "Locking knight after locking pawn")]
fn forbidden_skip() {
    use super::rank;

    let pawn = Mutex::new(rank::PAWN, ());
    let knight = Mutex::new(rank::KNIGHT, ());

    let _pawn_guard = pawn.lock();
    // This second acquisition is the one expected to panic.
    let _knight_guard = knight.lock();
}
/// Guards released in strict last-in, first-out order must not panic, even
/// when a different lock is taken after an inner guard has been dropped.
#[test]
fn stack_like() {
    use super::rank;

    let pawn = Mutex::new(rank::PAWN, ());
    let rook = Mutex::new(rank::ROOK, ());
    let bishop = Mutex::new(rank::BISHOP, ());

    let pawn_guard = pawn.lock();

    // First inner lock: taken and released while the pawn lock is held.
    let rook_guard = rook.lock();
    drop(rook_guard);

    // Second inner lock: same pattern with a different rank.
    let bishop_guard = bishop.lock();
    drop(bishop_guard);

    drop(pawn_guard);
}
/// Dropping an outer guard while an inner acquisition was never released
/// (its guard was forgotten) is expected to trip the depth assertion.
#[test]
#[should_panic(expected = "Lock not released in stacking order")]
fn non_stack_like() {
    use super::rank;

    let pawn = Mutex::new(rank::PAWN, ());
    let rook = Mutex::new(rank::ROOK, ());

    let pawn_guard = pawn.lock();
    let rook_guard = rook.lock();

    // Forgetting the inner guard skips its release, so the recorded depth
    // stays one level too deep for the drop below.
    core::mem::forget(rook_guard);
    drop(pawn_guard);
}