use core::fmt::{self, Debug, Display, Formatter};
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::ptr;
use core::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
use crate::cfg::atomic::{fence, AtomicPtr, AtomicPtrNull};
use crate::cfg::cell::{UnsafeCell, UnsafeCellOptionWith, UnsafeCellWith};
use crate::lock::{Lock, Wait};
use crate::relax::Relax;
#[cfg(feature = "thread_local")]
mod thread_local;
#[cfg(feature = "thread_local")]
pub use thread_local::LocalMutexNode;
#[cfg(all(feature = "thread_local", feature = "barging"))]
pub use thread_local::Key;
/// A queue node in its initialized state.
///
/// It carries the link to the node queued right behind it and the lock state
/// that the node's owner waits on.
#[derive(Debug)]
pub struct MutexNodeInit<L> {
    /// Pointer to the successor node; null until a successor links itself.
    next: AtomicPtr<Self>,
    /// Lock state of this node.
    lock: L,
}

impl<L> MutexNodeInit<L> {
    /// Returns the address of this node as a raw mutable pointer.
    const fn as_ptr(&self) -> *mut Self {
        let this: *const Self = self;
        this.cast_mut()
    }

    /// Spins until `next` holds a non-null pointer, then returns that pointer.
    ///
    /// Every load of `next` uses `Relaxed` ordering, and the relax policy `R`
    /// is invoked once between two consecutive loads.
    fn wait_next_relaxed<R: Relax>(&self) -> *mut Self {
        let mut backoff = R::new();
        loop {
            let successor = self.next.load(Relaxed);
            if !successor.is_null() {
                return successor;
            }
            backoff.relax();
        }
    }
}
impl<L: Lock> MutexNodeInit<L> {
    /// Creates a node that has no successor and whose lock starts out locked.
    ///
    /// This is the `const` flavor, compiled whenever the crate is not being
    /// built for loom tests.
    #[cfg(not(all(loom, test)))]
    const fn locked() -> Self {
        Self { next: AtomicPtr::NULL_MUT, lock: L::LOCKED }
    }

    /// Creates a node that has no successor and whose lock starts out locked.
    ///
    /// This is the non-`const` flavor, compiled only for loom tests, where the
    /// constructors are plain functions rather than associated constants.
    #[cfg(all(loom, test))]
    #[cfg(not(tarpaulin_include))]
    fn locked() -> Self {
        Self { next: AtomicPtr::null_mut(), lock: L::locked() }
    }
}
/// Storage for a queue node that may not have been initialized yet.
///
/// The storage is (re)initialized through `initialize` each time the node is
/// handed to a locking operation.
#[derive(Debug)]
#[repr(transparent)]
pub struct MutexNode<L> {
    /// Possibly uninitialized node state.
    inner: MaybeUninit<MutexNodeInit<L>>,
}

impl<L> MutexNode<L> {
    /// Creates a node whose storage is left uninitialized.
    pub const fn new() -> Self {
        let inner = MaybeUninit::uninit();
        Self { inner }
    }
}

impl<L: Lock> MutexNode<L> {
    /// Overwrites the storage with a fresh locked node and returns a shared
    /// reference to the value just written.
    fn initialize(&mut self) -> &MutexNodeInit<L> {
        let fresh = MutexNodeInit::locked();
        self.inner.write(fresh)
    }
}

#[cfg(not(tarpaulin_include))]
impl<L> Default for MutexNode<L> {
    /// Equivalent to [`MutexNode::new`].
    #[inline(always)]
    fn default() -> Self {
        MutexNode::new()
    }
}
/// A mutual exclusion lock whose waiters are tracked as a linked queue of
/// caller-provided nodes.
///
/// `L` is the per-node lock state type and `W` selects the waiting policy.
pub struct Mutex<T: ?Sized, L, W> {
    /// Last node of the waiter queue; null while the mutex is unlocked.
    tail: AtomicPtr<MutexNodeInit<L>>,
    /// Zero-sized marker tying the waiting policy `W` to the type.
    wait: PhantomData<W>,
    /// The protected value.
    data: UnsafeCell<T>,
}

// SAFETY: These are the same `T: Send` bounds that `std::sync::Mutex` uses:
// the protected value is only ever reachable by the single thread that
// currently holds the lock, so moving or sharing the mutex across threads
// amounts to sending `T` between them.
unsafe impl<T: ?Sized + Send, L, W> Send for Mutex<T, L, W> {}
unsafe impl<T: ?Sized + Send, L, W> Sync for Mutex<T, L, W> {}

impl<T, L, W> Mutex<T, L, W> {
    /// Creates an unlocked mutex wrapping `value`.
    #[cfg(not(all(loom, test)))]
    pub const fn new(value: T) -> Self {
        Self { tail: AtomicPtr::NULL_MUT, wait: PhantomData, data: UnsafeCell::new(value) }
    }

    /// Creates an unlocked mutex wrapping `value` (non-`const` flavor that is
    /// compiled only for loom tests).
    #[cfg(all(loom, test))]
    #[cfg(not(tarpaulin_include))]
    pub fn new(value: T) -> Self {
        Self { tail: AtomicPtr::null_mut(), wait: PhantomData, data: UnsafeCell::new(value) }
    }

    /// Consumes the mutex and returns the protected value.
    pub fn into_inner(self) -> T {
        let cell = self.data;
        cell.into_inner()
    }
}
// Core try-lock, lock and unlock routines. They operate directly on the
// waiter queue (`tail` plus each node's `next` link) using caller-provided
// queue nodes.
impl<T: ?Sized, L: Lock, W: Wait> Mutex<T, L, W> {
/// Attempts to acquire the lock without waiting, using `n` as the caller's
/// queue node.
///
/// `n` is first re-initialized to a locked node with no successor. The queue
/// tail is then swapped from null to that node with a single compare-exchange
/// (`AcqRel` on success, `Relaxed` on failure). Returns `Some(guard)` when the
/// exchange succeeded and `None` when the tail was not null.
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out in this file. From the code,
/// a successful call leaves the address of `n` stored in `self.tail`, and only
/// the guard's `Drop` (through `unlock_with`) takes it out of the queue again.
/// The caller presumably must ensure the returned guard is really dropped
/// (e.g. never passed to `mem::forget`) before `n` is reused or goes away —
/// confirm.
unsafe fn try_lock_with<'a>(&'a self, n: &'a mut MutexNode<L>) -> OptionGuard<'a, T, L, W> {
let node = n.initialize();
// Only an empty queue (null tail) can be claimed; on failure the node was
// never published, so nothing has to be undone.
self.tail
.compare_exchange(ptr::null_mut(), node.as_ptr(), AcqRel, Relaxed)
.map(|_| MutexGuard::new(self, node))
.ok()
}
/// Acquires the lock, waiting with policy `W` if it is currently held, using
/// `n` as the caller's queue node.
///
/// `n` is first re-initialized to a locked node with no successor and then
/// installed as the new queue tail. If there was a previous tail, this node is
/// linked behind it and the call waits on this node's own `lock` before
/// returning the guard.
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out in this file. From the code,
/// the address of `n` is stored in the queue and only the guard's `Drop`
/// (through `unlock_with`) removes it. The caller presumably must ensure the
/// returned guard is really dropped (e.g. never passed to `mem::forget`)
/// before `n` is reused or goes away — confirm.
unsafe fn lock_with<'a>(&'a self, n: &'a mut MutexNode<L>) -> MutexGuard<'a, T, L, W> {
let node = n.initialize();
// Install our node as the new tail; the previous tail, if any, is our
// predecessor in the queue.
let pred = self.tail.swap(node.as_ptr(), AcqRel);
// A null predecessor means the queue was empty and the lock is ours already.
if !pred.is_null() {
// Link ourselves behind the predecessor so that its unlock can find us.
// SAFETY: `pred` was just checked to be non-null. NOTE(review): this also
// relies on the predecessor's node still being alive; `unlock_with` does not
// return for a node that is no longer the tail until its `next` has been
// set — confirm this is the intended invariant.
unsafe { &(*pred).next }.store(node.as_ptr(), Release);
// Wait on our own node's lock state, applying waiting policy `W`.
node.lock.wait_lock_relaxed::<W>();
// The wait is relaxed (per its name), so acquire ordering is supplied here —
// presumably pairing with `notify_release` in `unlock_with`.
fence(Acquire);
}
MutexGuard::new(self, node)
}
/// Releases the lock held through `head`.
///
/// If a successor is queued behind `head`, it is notified through its own
/// `lock`. Otherwise the queue tail is reset to null.
fn unlock_with(&self, head: &MutexNodeInit<L>) {
let mut next = head.next.load(Relaxed);
// No successor has linked itself to us (yet).
if next.is_null() {
// If `head` is still the tail, the queue becomes empty and the lock is free:
// nothing else to do.
let false = self.try_unlock_release(head.as_ptr()) else { return };
// Otherwise some thread already replaced the tail but has not stored itself
// into `head.next` yet; spin until that link shows up.
next = head.wait_next_relaxed::<W::UnlockRelax>();
}
// `next` was obtained through relaxed loads; this fence orders the access
// below after the successor's `Release` store to `next` in `lock_with`.
fence(Acquire);
// SAFETY: `next` is non-null here: either the first load was non-null, or it
// came from `wait_next_relaxed`, which only returns a non-null pointer.
unsafe { &(*next).lock }.notify_release();
}
}
impl<T: ?Sized, L, W> Mutex<T, L, W> {
    /// Reports whether the queue tail was non-null at the time of a `Relaxed`
    /// load, i.e. whether some node was holding or waiting for the lock.
    pub fn is_locked(&self) -> bool {
        let tail = self.tail.load(Relaxed);
        !tail.is_null()
    }

    /// Returns a mutable reference to the protected value without locking.
    #[cfg(not(all(loom, test)))]
    pub fn get_mut(&mut self) -> &mut T {
        let data = self.data.get();
        // SAFETY: `&mut self` is an exclusive borrow of the mutex, so nobody
        // else can be accessing the protected value for as long as the
        // returned reference lives.
        unsafe { &mut *data }
    }

    /// Tries to reset the tail from `node` back to null.
    ///
    /// Returns `true` when `node` was still the tail and the exchange
    /// succeeded (`Release` on success, `Relaxed` on failure).
    fn try_unlock_release(&self, node: *mut MutexNodeInit<L>) -> bool {
        let unset = ptr::null_mut();
        self.tail.compare_exchange(node, unset, Release, Relaxed).is_ok()
    }
}
impl<T: ?Sized, L: Lock, W: Wait> Mutex<T, L, W> {
    /// Attempts to acquire the lock without waiting and runs `f` with the
    /// outcome.
    ///
    /// `f` receives `Some(&mut T)` when the lock was acquired and `None`
    /// otherwise. If the lock was acquired, it is released once `f` has
    /// returned. The value returned by `f` is passed through.
    pub fn try_lock_with_then<F, Ret>(&self, node: &mut MutexNode<L>, f: F) -> Ret
    where
        F: FnOnce(Option<&mut T>) -> Ret,
    {
        // SAFETY: The guard lives in a local of this function and is always
        // dropped before returning; it is never leaked or forgotten.
        let mut guard = unsafe { self.try_lock_with(node) };
        guard.as_deref_mut_with_mut(f)
    }

    /// Acquires the lock, waiting with policy `W` if needed, and runs `f`
    /// with exclusive access to the protected value.
    ///
    /// The lock is released once `f` has returned. The value returned by `f`
    /// is passed through.
    pub fn lock_with_then<F, Ret>(&self, node: &mut MutexNode<L>, f: F) -> Ret
    where
        F: FnOnce(&mut T) -> Ret,
    {
        // SAFETY: The guard lives in a local of this function and is always
        // dropped before returning; it is never leaked or forgotten.
        let mut guard = unsafe { self.lock_with(node) };
        guard.with_mut(f)
    }
}
impl<T: ?Sized + Debug, L: Lock, W: Wait> Debug for Mutex<T, L, W> {
    /// Formats the mutex as a `Mutex` struct with a single `data` field.
    ///
    /// The lock is only tried, never waited for: when it cannot be acquired
    /// the field is rendered as `<locked>` instead of the protected value.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut d = f.debug_struct("Mutex");
        let mut node = MutexNode::new();
        self.try_lock_with_then(&mut node, |data| {
            if let Some(value) = data {
                d.field("data", &value);
            } else {
                d.field("data", &format_args!("<locked>"));
            }
        });
        d.finish()
    }
}
/// Shorthand for the optional guard produced by a try-lock attempt.
type OptionGuard<'a, T, L, W> = Option<MutexGuard<'a, T, L, W>>;

/// Witness that the mutex is held through the queue node `head`.
///
/// Dropping the guard releases the lock.
#[must_use = "if unused the Mutex will immediately unlock"]
struct MutexGuard<'a, T: ?Sized, L: Lock, W: Wait> {
    /// The mutex this guard keeps locked.
    lock: &'a Mutex<T, L, W>,
    /// The queue node through which the lock was acquired.
    head: &'a MutexNodeInit<L>,
}

// SAFETY: Same bound as `std::sync::MutexGuard`: sharing the guard between
// threads only hands out shared access to `T`, which is fine when `T: Sync`.
unsafe impl<T: ?Sized + Sync, L: Lock, W: Wait> Sync for MutexGuard<'_, T, L, W> {}

impl<'a, T: ?Sized, L: Lock, W: Wait> MutexGuard<'a, T, L, W> {
    /// Bundles a locked mutex with the node that acquired it.
    const fn new(lock: &'a Mutex<T, L, W>, head: &'a MutexNodeInit<L>) -> Self {
        MutexGuard { lock, head }
    }

    /// Runs `f` with a shared reference to the protected value.
    #[cfg(not(tarpaulin_include))]
    fn with<F, Ret>(&self, f: F) -> Ret
    where
        F: FnOnce(&T) -> Ret,
    {
        let data = &self.lock.data;
        // SAFETY: A guard only exists while the lock is held, so no other
        // thread is mutating the protected value.
        unsafe { data.with_unchecked(f) }
    }

    /// Runs `f` with an exclusive reference to the protected value.
    fn with_mut<F, Ret>(&mut self, f: F) -> Ret
    where
        F: FnOnce(&mut T) -> Ret,
    {
        let data = &self.lock.data;
        // SAFETY: A guard only exists while the lock is held, and `&mut self`
        // rules out any other access going through this same guard.
        unsafe { data.with_mut_unchecked(f) }
    }
}
/// Closure-based access to the value behind an optional guard.
trait AsDerefMutWithMut {
    /// The type of the protected value.
    type Target: ?Sized;

    /// Runs `f` with `Some(&mut Target)` when a guard is present and with
    /// `None` otherwise, returning whatever `f` returns.
    fn as_deref_mut_with_mut<F, Ret>(&mut self, f: F) -> Ret
    where
        F: FnOnce(Option<&mut Self::Target>) -> Ret;
}

impl<T: ?Sized, L: Lock, W: Wait> AsDerefMutWithMut for OptionGuard<'_, T, L, W> {
    type Target = T;

    fn as_deref_mut_with_mut<F, Ret>(&mut self, f: F) -> Ret
    where
        F: FnOnce(Option<&mut Self::Target>) -> Ret,
    {
        // Project the optional guard onto the cell that holds the value.
        let cell = self.as_ref().map(|held| &held.lock.data);
        // SAFETY: When a guard is present the lock is held for as long as the
        // guard lives, and `&mut self` rules out any other access through it.
        unsafe { cell.as_deref_with_mut_unchecked(f) }
    }
}
#[cfg(not(tarpaulin_include))]
impl<T: ?Sized + Debug, L: Lock, W: Wait> Debug for MutexGuard<'_, T, L, W> {
    /// Forwards to the `Debug` implementation of the protected value.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        self.with(|value| Debug::fmt(value, f))
    }
}
#[cfg(not(tarpaulin_include))]
impl<T: ?Sized + Display, L: Lock, W: Wait> Display for MutexGuard<'_, T, L, W> {
    /// Forwards to the `Display` implementation of the protected value.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        self.with(|value| Display::fmt(value, f))
    }
}
#[cfg(not(all(loom, test)))]
#[cfg(not(tarpaulin_include))]
impl<T: ?Sized, L: Lock, W: Wait> core::ops::Deref for MutexGuard<'_, T, L, W> {
    type Target = T;

    /// Borrows the protected value for as long as the guard is borrowed.
    #[inline(always)]
    fn deref(&self) -> &T {
        let data = self.lock.data.get();
        // SAFETY: A guard only exists while the lock is held, so no other
        // thread is mutating the protected value.
        unsafe { &*data }
    }
}
#[cfg(not(all(loom, test)))]
#[cfg(not(tarpaulin_include))]
impl<T: ?Sized, L: Lock, W: Wait> core::ops::DerefMut for MutexGuard<'_, T, L, W> {
    /// Mutably borrows the protected value for as long as the guard is
    /// mutably borrowed.
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut T {
        let data = self.lock.data.get();
        // SAFETY: A guard only exists while the lock is held, and `&mut self`
        // rules out any other access going through this same guard.
        unsafe { &mut *data }
    }
}
impl<T: ?Sized, L: Lock, W: Wait> Drop for MutexGuard<'_, T, L, W> {
    /// Releases the lock through the node that acquired it.
    #[inline]
    fn drop(&mut self) {
        let mutex = self.lock;
        let head = self.head;
        mutex.unlock_with(head);
    }
}