use std::{
cell::Cell,
hint::spin_loop,
ops::Deref,
sync::atomic::{
AtomicBool, AtomicUsize,
Ordering::{Acquire, Relaxed, Release},
},
};
/// A re-entrant spin lock around a value of type `T`.
///
/// The thread that currently owns the lock may call `lock` again without
/// deadlocking; each call only bumps a depth counter, and the lock is released
/// when the depth drops back to zero (i.e. when the outermost guard is dropped).
/// The guard type implements `Deref`, giving shared (`&T`) access.
///
/// NOTE(review): `reentrant_count` is a `Cell`, which makes this type `!Sync`
/// unless an `unsafe impl Sync` exists elsewhere — none is visible in this
/// chunk; confirm how it is shared across threads.
pub(crate) struct Mutex<T> {
    /// The protected value.
    inner: T,
    /// `true` while some thread holds the lock.
    spin_lock: AtomicBool,
    /// Re-entrancy depth of the owning thread; read and written only by the
    /// thread that holds `spin_lock`.
    reentrant_count: Cell<usize>,
    /// Identity (from `current_thread_unique_ptr`) of the owning thread, or 0
    /// when the lock is not held.
    owner: AtomicUsize,
}
impl<T> Mutex<T> {
    /// Wraps `inner` in a new, unlocked mutex with no owner and zero depth.
    pub(super) fn new(inner: T) -> Self {
        Self {
            inner,
            spin_lock: AtomicBool::new(false),
            reentrant_count: Cell::new(0),
            owner: AtomicUsize::new(0),
        }
    }

    /// Acquires the lock for the calling thread and returns a guard.
    ///
    /// If the calling thread already owns the lock, this only increments the
    /// re-entrancy depth. Otherwise it spins until the lock is free, then
    /// records the caller as owner with depth 1.
    ///
    /// # Panics
    ///
    /// Panics if the re-entrancy depth would overflow `usize`.
    pub(super) fn lock(&self) -> MutexGuard<'_, T> {
        let this_thread = current_thread_unique_ptr();
        if self.owner.load(Relaxed) == this_thread {
            // Only the thread holding `spin_lock` stores its own id into
            // `owner` (and resets it to 0 before releasing). Seeing our id
            // therefore means we already hold the lock, so just go one level deeper.
            let depth = self
                .reentrant_count
                .get()
                .checked_add(1)
                .expect("reentrant_count overflow");
            self.reentrant_count.set(depth);
        } else {
            // Test-and-test-and-set: attempt the (write-intent) CAS only when
            // the lock looks free, and otherwise wait with cheap relaxed loads.
            // This avoids bouncing the cache line between contending cores.
            while self
                .spin_lock
                .compare_exchange_weak(false, true, Acquire, Relaxed)
                .is_err()
            {
                while self.spin_lock.load(Relaxed) {
                    spin_loop();
                }
            }
            self.owner.store(this_thread, Relaxed);
            debug_assert_eq!(self.reentrant_count.get(), 0);
            self.reentrant_count.set(1);
        }
        MutexGuard { lock: self }
    }
}
/// Proof that the current thread holds a `Mutex`, obtained from `Mutex::lock`.
///
/// Dropping the guard undoes one level of re-entrancy; dropping the outermost
/// guard releases the lock.
#[must_use = "if unused the Mutex will immediately unlock"]
pub(super) struct MutexGuard<'a, T> {
    /// The mutex this guard holds.
    lock: &'a Mutex<T>,
}
impl<T> Drop for MutexGuard<'_, T> {
    /// Leaves one level of re-entrancy. When the depth reaches zero, the owner
    /// is cleared and the spin lock is released with `Release` ordering.
    ///
    /// # Panics
    ///
    /// Panics if the depth is already zero (an accounting bug).
    #[inline]
    fn drop(&mut self) {
        let mutex = self.lock;

        debug_assert!(mutex.spin_lock.load(Relaxed));
        debug_assert!(mutex.reentrant_count.get() > 0);
        debug_assert_eq!(mutex.owner.load(Relaxed), current_thread_unique_ptr());

        let remaining = mutex
            .reentrant_count
            .get()
            .checked_sub(1)
            .expect("reentrant_count underflow");
        mutex.reentrant_count.set(remaining);

        if remaining != 0 {
            // An outer guard on this thread still holds the lock.
            return;
        }

        // Clear the owner before unlocking so no other thread can observe a
        // stale owner id once it acquires the lock.
        mutex.owner.store(0, Relaxed);
        mutex.spin_lock.store(false, Release);
    }
}
impl<T> Deref for MutexGuard<'_, T> {
type Target = T;
fn deref(&self) -> &T {
&self.lock.inner
}
}
/// Returns a non-zero identifier for the calling thread.
///
/// The value is stable for the lifetime of the thread and is never handed out
/// to another thread, even after this thread exits. An address of a
/// thread-local can be reused by a later thread. If a thread leaked a lock
/// guard and then exited, that later thread could then be mistaken for the
/// owner. A monotonically increasing counter avoids this. 0 is never returned,
/// so callers can use it as "no thread".
///
/// # Panics
///
/// Panics if the id space (`usize::MAX - 1` threads) is exhausted.
pub(crate) fn current_thread_unique_ptr() -> usize {
    // Next id to hand out; starts at 1 because 0 means "unowned".
    static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
    thread_local! {
        // Relaxed suffices: RMWs on a single atomic are totally ordered, so
        // every thread observes a distinct previous value.
        static ID: usize = NEXT_ID
            .fetch_update(Relaxed, Relaxed, |id| id.checked_add(1))
            .expect("exhausted thread ids");
    }
    ID.with(|id| *id)
}
/// Exposes this spin mutex through the parent module's `Reentrant` trait.
impl<'s, T> super::Reentrant<'s, T> for Mutex<T>
where
    T: 's,
{
    /// Builds a new, unlocked mutex around `data`.
    fn create(data: T) -> Self {
        Self::new(data)
    }

    /// Locks (or re-enters) the mutex and returns the guard as an opaque `Deref`.
    fn reentrant_lock(&'s self) -> impl Deref<Target = T> + 's {
        Mutex::lock(self)
    }
}