use std::cell::UnsafeCell;
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::sync::atomic::{AtomicU8, Ordering};
use crate::misc::{PhantomMutex, PhantomMutexGuard};
use crate::result::{LockResult, PoisonError, TryLockError, TryLockResult};
/// A mutual-exclusion primitive protecting a value of type `T`.
///
/// The lock state lives in a single atomic status byte; acquisition is done
/// with an atomic compare-and-swap on that byte (see `do_try_lock`), and the
/// protected value is only reachable through a `MutexGuard` or through
/// exclusive (`&mut` / by-value) access to the mutex itself.
pub struct Mutex<T: ?Sized> {
/// Status byte holding the `LOCK_FLAG` and `POISON_FLAG` bits.
lock: AtomicU8,
/// Marker from `crate::misc`.
// NOTE(review): presumably carries variance / auto-trait information for `T`;
// confirm against the definition of `PhantomMutex`.
_phantom: PhantomMutex<T>,
/// The protected value; interior mutability is required because guards
/// hand out `&mut T` while only holding `&Mutex<T>`.
data: UnsafeCell<T>,
}
impl<T> Mutex<T> {
    /// Creates a new mutex owning `t`, with the status byte set to `INIT`
    /// (neither locked nor poisoned).
    pub const fn new(t: T) -> Self {
        Self {
            _phantom: PhantomMutex {},
            data: UnsafeCell::new(t),
            lock: AtomicU8::new(INIT),
        }
    }

    /// Consumes the mutex and hands back the protected value.
    ///
    /// # Errors
    ///
    /// If the poison flag is set, the value is still returned, but wrapped in
    /// a `PoisonError`.
    pub fn into_inner(self) -> LockResult<T> {
        // The poison state must be read before the cell is moved out of `self`.
        if self.is_poisoned() {
            return Err(PoisonError::new(self.data.into_inner()));
        }
        Ok(self.data.into_inner())
    }
}
impl<T: ?Sized> Mutex<T> {
    /// Acquires the mutex, retrying until the lock flag could be set.
    ///
    /// Between failed attempts the current thread yields to the scheduler
    /// (`std::thread::yield_now`); it is never parked.
    ///
    /// # Errors
    ///
    /// If the poison flag was set when the lock was acquired, the guard is
    /// returned inside a `PoisonError`. The lock is held in that case too and
    /// is released when the wrapped guard is dropped.
    pub fn lock(&self) -> LockResult<MutexGuard<'_, T>> {
        loop {
            let previous = self.do_try_lock();
            if is_locked(previous) {
                // Somebody else holds the lock: give up the time slice and retry.
                std::thread::yield_now();
                continue;
            }
            let guard = MutexGuard::new(self);
            return if is_poisoned(previous) {
                Err(PoisonError::new(guard))
            } else {
                Ok(guard)
            };
        }
    }

    /// Attempts to acquire the mutex without waiting.
    ///
    /// # Errors
    ///
    /// * `TryLockError::WouldBlock` if the lock flag is currently set; no
    ///   guard is created.
    /// * `TryLockError::Poisoned` if the lock was acquired but the poison flag
    ///   was set; the guard is carried inside the error.
    pub fn try_lock(&self) -> TryLockResult<MutexGuard<'_, T>> {
        let previous = self.do_try_lock();
        if is_locked(previous) {
            return Err(TryLockError::WouldBlock);
        }
        let guard = MutexGuard::new(self);
        if is_poisoned(previous) {
            Err(TryLockError::Poisoned(PoisonError::new(guard)))
        } else {
            Ok(guard)
        }
    }

    /// Makes one logical attempt to set the lock flag.
    ///
    /// Returns the status byte observed *before* the attempt:
    ///
    /// * lock flag set    -> another guard holds the mutex, nothing was changed;
    /// * lock flag clear  -> this call set the lock flag; the returned value
    ///   still tells the caller whether the poison flag was set.
    fn do_try_lock(&self) -> LockStatus {
        let mut expected = INIT;
        loop {
            let desired = acquire_lock(expected);
            // `compare_exchange` replaces the deprecated `compare_and_swap`;
            // (Acquire, Acquire) is the documented equivalent of
            // `compare_and_swap(.., Ordering::Acquire)`.
            match self.lock.compare_exchange(
                expected,
                desired,
                Ordering::Acquire,
                Ordering::Acquire,
            ) {
                // The exchange happened: we now own the lock.
                Ok(previous) => return previous,
                // Held by someone else: report it to the caller.
                Err(current) if is_locked(current) => return current,
                // Unlocked but different from our guess (e.g. poisoned):
                // retry with the status that was actually observed.
                Err(current) => expected = current,
            }
        }
    }

    /// Returns `true` if the poison flag is set in the status byte.
    pub fn is_poisoned(&self) -> bool {
        is_poisoned(self.lock.load(Ordering::Relaxed))
    }

    /// Returns a mutable reference to the protected value without locking.
    ///
    /// # Errors
    ///
    /// If the poison flag is set, the reference is returned inside a
    /// `PoisonError`.
    pub fn get_mut(&mut self) -> LockResult<&mut T> {
        let poisoned = self.is_poisoned();
        // SAFETY: `&mut self` is an exclusive borrow of the mutex. Every guard
        // borrows the mutex through `&'a Mutex<T>`, so none can be alive while
        // this borrow exists and nothing else can reach `data`.
        let data = unsafe { &mut *self.data.get() };
        if poisoned {
            Err(PoisonError::new(data))
        } else {
            Ok(data)
        }
    }
}
impl<T: ?Sized + fmt::Debug> fmt::Debug for Mutex<T> {
    /// Formats the mutex as `Mutex { data: .. }`.
    ///
    /// The value is printed when the lock can be taken with `try_lock`
    /// (including the poisoned case); otherwise `<locked>` is printed in its
    /// place, so formatting never waits for the lock.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Stand-in printed for the `data` field while the lock is held elsewhere.
        struct LockedPlaceholder;
        impl fmt::Debug for LockedPlaceholder {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("<locked>")
            }
        }

        let mut out = f.debug_struct("Mutex");
        match self.try_lock() {
            Ok(guard) => {
                out.field("data", &&*guard);
            }
            Err(TryLockError::Poisoned(err)) => {
                out.field("data", &&**err.get_ref());
            }
            Err(TryLockError::WouldBlock) => {
                out.field("data", &LockedPlaceholder);
            }
        }
        out.finish()
    }
}
impl<T> From<T> for Mutex<T> {
fn from(t: T) -> Self {
Mutex::new(t)
}
}
impl<T: ?Sized + Default> Default for Mutex<T> {
fn default() -> Self {
Mutex::new(T::default())
}
}
/// Scoped access to the value protected by a `Mutex`.
///
/// The value is reached through the guard's `Deref` / `DerefMut` impls.
/// Dropping the guard clears the lock flag of the mutex (and sets the poison
/// flag if the dropping thread is panicking).
#[must_use = "if unused the Mutex will immediately unlock"]
pub struct MutexGuard<'a, T: ?Sized + 'a> {
/// The mutex this guard was created for.
mutex: &'a Mutex<T>,
/// Marker from `crate::misc`.
// NOTE(review): presumably controls the guard's auto traits (e.g. `Send`);
// confirm against the definition of `PhantomMutexGuard`.
_phantom: PhantomMutexGuard<'a, T>, }
impl<'a, T: ?Sized> MutexGuard<'a, T> {
fn new(mutex: &'a Mutex<T>) -> Self {
Self {
mutex,
_phantom: Default::default(),
}
}
}
impl<T: ?Sized> Drop for MutexGuard<'_, T> {
    /// Releases the lock by clearing the lock flag.
    ///
    /// If the current thread is panicking at the time of the drop, the poison
    /// flag is set as well. An already-set poison flag is preserved.
    fn drop(&mut self) {
        let status = &self.mutex.lock;

        let held = status.load(Ordering::Relaxed);
        debug_assert!(is_locked(held));

        let released = release_lock(held);
        let next = if std::thread::panicking() {
            set_poison_flag(released)
        } else {
            released
        };

        // `Release` publishes the writes made through this guard to the next
        // thread that acquires the lock with `Acquire`.
        status.store(next, Ordering::Release);
    }
}
impl<T: ?Sized> Deref for MutexGuard<'_, T> {
    type Target = T;

    /// Gives shared access to the protected value.
    fn deref(&self) -> &Self::Target {
        let value = self.mutex.data.get();
        // SAFETY: the pointer comes from the mutex's own `UnsafeCell`, which
        // outlives the guard. NOTE(review): soundness relies on a guard only
        // existing while the lock flag is held on its behalf.
        unsafe { &*value }
    }
}
impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
    /// Gives exclusive access to the protected value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let value = self.mutex.data.get();
        // SAFETY: the pointer comes from the mutex's own `UnsafeCell`, which
        // outlives the guard, and `&mut self` makes this the only access made
        // through this guard. NOTE(review): soundness relies on a guard only
        // existing while the lock flag is held on its behalf.
        unsafe { &mut *value }
    }
}
impl<T: ?Sized + fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
    /// Formats the guarded value with its own `Debug` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = &**self;
        <T as fmt::Debug>::fmt(value, f)
    }
}
impl<T: ?Sized + fmt::Display> fmt::Display for MutexGuard<'_, T> {
    /// Formats the guarded value with its own `Display` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = &**self;
        <T as fmt::Display>::fmt(value, f)
    }
}
// Unwind safety: a guard dropped while its thread is panicking sets the poison
// flag, and `lock` / `try_lock` / `get_mut` / `into_inner` report that flag,
// so a value left in a broken state by a panic is not observed silently.
impl<T: ?Sized> UnwindSafe for Mutex<T> {}
impl<T: ?Sized> RefUnwindSafe for Mutex<T> {}
// SAFETY: the mutex owns its `T`, so moving the mutex to another thread moves
// the `T` with it; that is fine whenever `T: Send`.
unsafe impl<T: ?Sized + Send> Send for Mutex<T> {}
// SAFETY: through `&Mutex<T>` a thread can obtain `&mut T` (via a guard), but
// only while it holds the lock flag, i.e. one thread at a time. That amounts
// to handing the value from thread to thread, hence the `T: Send` bound
// (`T: Sync` is not needed).
unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}
// SAFETY: a shared `&MutexGuard` only yields `&T` (through `Deref`), so sharing
// the guard between threads is sharing `&T`, which requires `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for MutexGuard<'_, T> {}
/// The status word of a mutex: a bit set made of `LOCK_FLAG` and `POISON_FLAG`.
type LockStatus = u8;

/// Initial status: neither locked nor poisoned.
const INIT: LockStatus = 0;
/// Bit that is set while the mutex is held.
const LOCK_FLAG: LockStatus = 0x01;
/// Bit that is set once a guard has been dropped by a panicking thread.
const POISON_FLAG: LockStatus = 0x02;
/// Every bit that is not a known flag (`0xfc`). A valid status never has any
/// of them set. Derived from the flags so it cannot drift out of sync.
const NOT_USED_MASK: LockStatus = !(LOCK_FLAG | POISON_FLAG);

/// Returns `true` if `s` has the lock flag set.
#[inline]
#[must_use]
fn is_locked(s: LockStatus) -> bool {
    debug_assert_eq!(0, s & NOT_USED_MASK);
    (s & LOCK_FLAG) != 0
}

/// Returns `s` with the lock flag set.
///
/// `s` must not be locked already (checked in debug builds only).
#[inline]
#[must_use]
fn acquire_lock(s: LockStatus) -> LockStatus {
    debug_assert!(!is_locked(s));
    s | LOCK_FLAG
}

/// Returns `s` with the lock flag cleared; the poison flag is left untouched.
///
/// `s` must be locked (checked in debug builds only).
#[inline]
#[must_use]
fn release_lock(s: LockStatus) -> LockStatus {
    debug_assert!(is_locked(s));
    s & !LOCK_FLAG
}

/// Returns `true` if `s` has the poison flag set.
#[inline]
#[must_use]
fn is_poisoned(s: LockStatus) -> bool {
    debug_assert_eq!(0, s & NOT_USED_MASK);
    (s & POISON_FLAG) != 0
}

/// Returns `s` with the poison flag set; the lock flag is left untouched.
#[inline]
#[must_use]
fn set_poison_flag(s: LockStatus) -> LockStatus {
    debug_assert_eq!(0, s & NOT_USED_MASK);
    s | POISON_FLAG
}