use alloc::boxed::Box;
use core::fmt::{self, Debug, Display, Formatter};
use core::marker::PhantomData;
use core::ptr::NonNull;
use core::sync::atomic::Ordering::{AcqRel, Acquire};
use crate::cfg::atomic::{fence, AtomicPtr, UnsyncLoad};
use crate::cfg::cell::{Cell, CellNullMut, UnsafeCell, UnsafeCellWith};
use crate::lock::{Lock, Wait};
/// Heap-allocated queue entry shared between a lock holder and its successor.
///
/// `prev` records the node that was the queue tail when this node was
/// enqueued by `Mutex::lock_with`. `lock` is the flag the successor waits on
/// until the owner releases it.
#[derive(Debug)]
struct MutexNodeInner<L> {
    // Predecessor observed by `lock_with`; `MutexGuard::unlock` takes over
    // ownership of it on release.
    prev: Cell<*mut Self>,
    // Flag set to locked before enqueueing and notified on unlock.
    lock: L,
}
impl<L: Lock> MutexNodeInner<L> {
    /// Returns a node whose flag starts out locked and that has no predecessor.
    #[cfg(not(all(loom, test)))]
    const fn locked() -> Self {
        Self { prev: Cell::NULL_MUT, lock: Lock::LOCKED }
    }

    /// Loom build of `locked`, using the non-`const` constructors instead of
    /// the associated constants.
    #[cfg(all(loom, test))]
    #[cfg(not(tarpaulin_include))]
    fn locked() -> Self {
        Self { prev: Cell::null_mut(), lock: Lock::locked() }
    }

    /// Returns a node whose flag starts out unlocked and that has no
    /// predecessor.
    #[cfg(not(all(loom, test)))]
    const fn unlocked() -> Self {
        Self { prev: Cell::NULL_MUT, lock: Lock::UNLOCKED }
    }

    /// Loom build of `unlocked`, using the non-`const` constructors instead of
    /// the associated constants.
    #[cfg(all(loom, test))]
    #[cfg(not(tarpaulin_include))]
    fn unlocked() -> Self {
        Self { prev: Cell::null_mut(), lock: Lock::unlocked() }
    }

    /// Resets this node's flag to locked.
    ///
    /// Takes `&mut self`, so the write happens while the node is exclusively
    /// owned, before it is enqueued.
    #[cfg(not(all(loom, test)))]
    fn lock(&mut self) {
        let flag: L = Lock::LOCKED;
        self.lock = flag;
    }

    /// Loom build of `lock`.
    #[cfg(all(loom, test))]
    #[cfg(not(tarpaulin_include))]
    fn lock(&mut self) {
        let flag: L = Lock::locked();
        self.lock = flag;
    }
}
/// Owning handle to a heap-allocated queue node used to acquire a [`Mutex`].
///
/// A node is consumed by [`Mutex::lock_with`]. On release, the guard takes
/// over the predecessor's node, which [`MutexGuard::into_node`] hands back for
/// reuse. Dropping the handle frees the node it points to.
#[derive(Debug)]
#[repr(transparent)]
pub struct MutexNode<L> {
    // Non-null pointer to a `Box`-allocated node, freed in `Drop`.
    inner: NonNull<MutexNodeInner<L>>,
}
// SAFETY: NOTE(review): these impls put no bound on `L`.
// `Sync`: the only `&self` access `MutexNode` offers here is the derived
// `Debug`, which prints the pointer.
// `Send`: moves ownership of the boxed node, and therefore its `L` value, to
// another thread. This presumably relies on `L` being a thread-safe (atomic)
// flag; confirm against the `Lock` implementors.
unsafe impl<L> Send for MutexNode<L> {}
unsafe impl<L> Sync for MutexNode<L> {}
impl<L> MutexNode<L> {
    /// Wraps a raw node pointer in an owning handle.
    ///
    /// # Safety
    ///
    /// `ptr` must be non-null. Dropping the returned handle frees it with
    /// `Box::from_raw`, so it must be a `Box` allocation that no other handle
    /// owns.
    const unsafe fn from_ptr(ptr: *mut MutexNodeInner<L>) -> Self {
        // SAFETY: non-nullness is part of the caller's contract.
        let inner = unsafe { NonNull::new_unchecked(ptr) };
        Self { inner }
    }

    /// Repoints `this` at `ptr` without freeing the node it previously
    /// referenced.
    ///
    /// # Safety
    ///
    /// Same requirements on `ptr` as [`Self::from_ptr`]. The caller is
    /// responsible for whatever node `this` pointed to before.
    unsafe fn set(this: &mut Self, ptr: *mut MutexNodeInner<L>) {
        // SAFETY: non-nullness is part of the caller's contract.
        let replacement = unsafe { NonNull::new_unchecked(ptr) };
        this.inner = replacement;
    }
}
impl<L: Lock> MutexNode<L> {
    /// Allocates a new queue node for use with [`Mutex::lock_with`].
    pub fn new() -> Self {
        Self::locked()
    }

    /// Allocates a node whose flag starts out locked.
    fn locked() -> Self {
        let node: Box<MutexNodeInner<L>> = Box::new(MutexNodeInner::locked());
        // `Box::leak` yields a non-null `&mut`, so building the `NonNull` needs
        // no `unsafe`. `Drop` later reclaims the allocation with
        // `Box::from_raw`, which the `Box::leak` contract permits.
        let inner = NonNull::from(Box::leak(node));
        Self { inner }
    }

    /// Allocates a node whose flag starts out unlocked and returns it as a raw
    /// pointer. The caller takes ownership of the allocation.
    fn unlocked() -> *mut MutexNodeInner<L> {
        let node = MutexNodeInner::unlocked();
        Box::into_raw(Box::new(node))
    }
}
impl<L> Drop for MutexNode<L> {
    /// Frees the heap node this handle points to.
    fn drop(&mut self) {
        // SAFETY: `inner` always refers to a `Box` allocation owned by this
        // handle (see `MutexNode::locked` and `from_ptr`'s contract).
        let boxed = unsafe { Box::from_raw(self.inner.as_ptr()) };
        drop(boxed);
    }
}
#[cfg(not(tarpaulin_include))]
impl<L: Lock> Default for MutexNode<L> {
#[inline(always)]
fn default() -> Self {
Self::new()
}
}
/// A CLH-style queue mutex protecting a value of type `T`.
///
/// Lockers swap their node into `tail` and then wait on the predecessor
/// node's flag, using the `W: Wait` type as the waiting strategy. `L` is the
/// flag type stored in each node.
pub struct Mutex<T: ?Sized, L, W> {
    // Most recently enqueued node. It starts as an unlocked node allocated by
    // `new` and is freed when the mutex is dropped.
    tail: AtomicPtr<MutexNodeInner<L>>,
    // Type-level marker for the wait strategy; no value is stored.
    wait: PhantomData<W>,
    // The protected value.
    data: UnsafeCell<T>,
}
// SAFETY: access to `T` goes through a guard that holds the lock, so at most
// one thread touches the data at a time. Sending or sharing the mutex
// therefore only requires `T: Send`, mirroring `std::sync::Mutex`.
unsafe impl<T: ?Sized + Send, L, W> Send for Mutex<T, L, W> {}
unsafe impl<T: ?Sized + Send, L, W> Sync for Mutex<T, L, W> {}
impl<T, L: Lock, W> Mutex<T, L, W> {
    /// Creates a new, unlocked mutex wrapping `value`.
    ///
    /// One unlocked node is allocated up front to act as the initial queue
    /// tail.
    pub fn new(value: T) -> Self {
        Self {
            tail: AtomicPtr::new(MutexNode::unlocked()),
            data: UnsafeCell::new(value),
            wait: PhantomData,
        }
    }
}
impl<T: ?Sized, L: Lock, W: Wait> Mutex<T, L, W> {
    /// Acquires the mutex using `node` as this thread's queue entry. Blocks
    /// until the lock is held, then returns a guard that owns the node.
    ///
    /// On release the guard takes over the predecessor's node:
    /// [`MutexGuard::into_node`] returns it for reuse, while dropping the guard
    /// frees it.
    pub fn lock_with(&self, mut node: MutexNode<L>) -> MutexGuard<'_, T, L, W> {
        // The node is exclusively owned here, so its flag is reset with a
        // plain write before the node is published through `tail`.
        unsafe { node.inner.as_mut() }.lock();
        // Publish our node as the new tail. AcqRel releases the write above to
        // whoever swaps next, and acquires the predecessor node we receive.
        let prev = self.tail.swap(node.inner.as_ptr(), AcqRel);
        // Remember the predecessor so `MutexGuard::unlock` can adopt it.
        unsafe { node.inner.as_ref() }.prev.set(prev);
        // `tail` always holds a live node, so `prev` is valid to dereference.
        // Wait for the predecessor to release its flag (per the name, using
        // relaxed loads)...
        unsafe { &*prev }.lock.lock_wait_relaxed::<W>();
        // ...then the Acquire fence orders our later accesses to `data` after
        // that release. NOTE(review): assumes `notify` performs a Release store.
        fence(Acquire);
        MutexGuard::new(self, node)
    }
}
impl<T: ?Sized, L, W> Drop for Mutex<T, L, W> {
    /// Frees the node still referenced by `tail`, which no guard owns.
    fn drop(&mut self) {
        // SAFETY: `&mut self` rules out concurrent lockers, and `tail` always
        // points at a live `Box`-allocated node owned by the mutex itself.
        let last = unsafe { MutexNode::from_ptr(self.tail.load_unsynced()) };
        drop(last);
    }
}
impl<T: ?Sized, L, W> Mutex<T, L, W> {
    /// Returns a mutable reference to the protected data.
    ///
    /// No locking is performed: `&mut self` already guarantees exclusive
    /// access.
    #[cfg(not(all(loom, test)))]
    pub fn get_mut(&mut self) -> &mut T {
        let ptr = self.data.get();
        // SAFETY: `&mut self` guarantees no guard or other reference exists.
        unsafe { &mut *ptr }
    }
}
impl<T: ?Sized + Debug, L: Lock, W: Wait> Debug for Mutex<T, L, W> {
    /// Locks the mutex with a freshly allocated node and formats the data as
    /// `Mutex { data: .. }`.
    ///
    /// This blocks until the lock can be acquired, so formatting a mutex whose
    /// guard is held by the calling thread does not return.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let node = MutexNode::new();
        let mut out = f.debug_struct("Mutex");
        {
            let guard = self.lock_with(node);
            guard.with(|data| out.field("data", &data));
            // `guard` is dropped here, releasing the lock before finishing.
        }
        out.finish()
    }
}
/// RAII guard returned by [`Mutex::lock_with`].
///
/// The lock is held until the guard is dropped or converted back into a node
/// with [`MutexGuard::into_node`].
#[must_use = "if unused the Mutex will immediately unlock"]
pub struct MutexGuard<'a, T: ?Sized, L: Lock, W> {
    // Mutex this guard has locked.
    lock: &'a Mutex<T, L, W>,
    // Node this guard owns: the acquiring node while locked, and the
    // predecessor's node after `unlock`.
    head: MutexNode<L>,
}
// SAFETY: a guard grants `&mut T` through `DerefMut`, so sending it requires
// `T: Send`. A shared `&MutexGuard` only exposes `&T`, so sharing it requires
// `T: Sync`.
unsafe impl<T: ?Sized + Send, L: Lock, W> Send for MutexGuard<'_, T, L, W> {}
unsafe impl<T: ?Sized + Sync, L: Lock, W> Sync for MutexGuard<'_, T, L, W> {}
impl<'a, T: ?Sized, L: Lock, W> MutexGuard<'a, T, L, W> {
    /// Wraps an acquired mutex together with the node used to acquire it.
    const fn new(lock: &'a Mutex<T, L, W>, head: MutexNode<L>) -> Self {
        Self { lock, head }
    }

    /// Runs `f` with a shared reference to the protected data.
    fn with<F, Ret>(&self, f: F) -> Ret
    where
        F: FnOnce(&T) -> Ret,
    {
        // SAFETY: an existing guard means this thread holds the lock.
        unsafe { self.lock.data.with_unchecked(f) }
    }

    /// Releases the lock and returns a node that can be passed to
    /// [`Mutex::lock_with`] again.
    ///
    /// The returned node is the predecessor's node, adopted on release, rather
    /// than the node originally passed to `lock_with`.
    #[must_use]
    pub fn into_node(mut self) -> MutexNode<L> {
        unsafe { self.unlock() }
        let inner = self.head.inner;
        // Skip `Drop`, which would unlock a second time and free the node.
        // Ownership of the node moves into the returned handle instead.
        core::mem::forget(self);
        MutexNode { inner }
    }

    /// Releases the lock and repoints `head` at the predecessor node, which
    /// this guard now owns.
    ///
    /// # Safety
    ///
    /// Must be called at most once per acquisition. The guard must not access
    /// `data` afterwards.
    unsafe fn unlock(&mut self) {
        let inner = unsafe { self.head.inner.as_ref() };
        // Read `prev` before notifying: once the flag is released, this node
        // may be taken over by the successor and must not be read again.
        let prev = inner.prev.get();
        inner.lock.notify();
        // Adopt the predecessor's node without freeing the current one, which
        // now belongs to the successor (or remains as the mutex `tail`).
        unsafe { MutexNode::set(&mut self.head, prev) }
    }
}
impl<T: ?Sized, L: Lock, W> Drop for MutexGuard<'_, T, L, W> {
    /// Releases the lock. The node the guard then owns (the predecessor's) is
    /// freed when the `head` field is dropped afterwards.
    #[inline]
    fn drop(&mut self) {
        // SAFETY: `Drop` runs at most once, and `into_node` forgets the guard
        // after unlocking, so the lock is released exactly once.
        unsafe { Self::unlock(self) }
    }
}
impl<T: ?Sized + Debug, L: Lock, W> Debug for MutexGuard<'_, T, L, W> {
    /// Formats the protected data with its `Debug` implementation, forwarding
    /// the formatter (and its flags) unchanged.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        self.with(|data| Debug::fmt(data, f))
    }
}
impl<T: ?Sized + Display, L: Lock, W> Display for MutexGuard<'_, T, L, W> {
    /// Formats the protected data with its `Display` implementation,
    /// forwarding the formatter (and its flags) unchanged.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        self.with(|data| Display::fmt(data, f))
    }
}
#[cfg(not(all(loom, test)))]
impl<T: ?Sized, L: Lock, W> core::ops::Deref for MutexGuard<'_, T, L, W> {
    type Target = T;

    /// Borrows the protected data while the guard holds the lock.
    #[inline(always)]
    fn deref(&self) -> &T {
        let ptr = self.lock.data.get();
        // SAFETY: the guard holds the lock, so no `&mut T` exists elsewhere.
        unsafe { &*ptr }
    }
}
#[cfg(not(all(loom, test)))]
impl<T: ?Sized, L: Lock, W> core::ops::DerefMut for MutexGuard<'_, T, L, W> {
    /// Mutably borrows the protected data while the guard holds the lock.
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut T {
        let ptr = self.lock.data.get();
        // SAFETY: the guard holds the lock and is borrowed mutably, so this is
        // the only reference to the data.
        unsafe { &mut *ptr }
    }
}
// Loom-only integration: exposes the protected loom `UnsafeCell` to the
// crate's `crate::loom::Guard` test helpers.
// SAFETY: NOTE(review): `Guard` is an unsafe trait; soundness presumably
// relies on this guard holding the lock for as long as it lives. Confirm
// against the trait's documented contract.
#[cfg(all(loom, test))]
#[cfg(not(tarpaulin_include))]
unsafe impl<T: ?Sized, L: Lock, W> crate::loom::Guard for MutexGuard<'_, T, L, W> {
    type Target = T;
    fn get(&self) -> &loom::cell::UnsafeCell<Self::Target> {
        &self.lock.data
    }
}