use crate::current;
use crate::future::batch_semaphore::{BatchSemaphore, Fairness};
use crate::runtime::task::TaskId;
use std::cell::RefCell;
use std::fmt::{Debug, Display};
use std::ops::{Deref, DerefMut};
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::sync::{LockResult, PoisonError, TryLockError, TryLockResult};
use tracing::trace;
/// A mutual-exclusion primitive whose API mirrors `std::sync::Mutex`.
///
/// Mutual exclusion between tasks is tracked by a one-permit semaphore together with the
/// recorded holder task. The wrapped `std::sync::Mutex` owns the protected data and supplies
/// poisoning semantics.
pub struct Mutex<T: ?Sized> {
    // Which task (if any) currently holds the lock; used for deadlock detection and invariant checks.
    state: RefCell<MutexState>,
    // Semaphore created with a single permit (see `new`); a task must obtain it to take the lock.
    semaphore: BatchSemaphore,
    // The protected value. It is only `try_lock`ed after this task has been recorded as the holder.
    inner: std::sync::Mutex<T>,
}
/// Guard returned by [`Mutex::lock`] and [`Mutex::try_lock`].
///
/// The protected data is reachable through `Deref`/`DerefMut`. The lock is released when the
/// guard is dropped. Like `std::sync::MutexGuard`, discarding the guard immediately unlocks the
/// mutex, so it is marked `#[must_use]`.
#[must_use = "if unused the Mutex will immediately unlock"]
pub struct MutexGuard<'a, T: ?Sized> {
    // Guard for `mutex.inner`; always `Some` while the guard is live, reset to `None` in `Drop`.
    inner: Option<std::sync::MutexGuard<'a, T>>,
    // The mutex this guard belongs to; its bookkeeping is updated on drop.
    mutex: &'a Mutex<T>,
}
/// Bookkeeping for a [`Mutex`], kept behind a `RefCell`.
#[derive(Debug)]
struct MutexState {
    // The task currently holding the lock, or `None` when the mutex is unlocked.
    holder: Option<TaskId>,
}
impl<T> Mutex<T> {
    /// Creates a new, unlocked mutex that protects `value`.
    ///
    /// No task is recorded as holder, and the semaphore starts with a single permit
    /// (`Fairness::Unfair`).
    pub const fn new(value: T) -> Self {
        Self {
            state: RefCell::new(MutexState { holder: None }),
            semaphore: BatchSemaphore::const_new(1, Fairness::Unfair),
            inner: std::sync::Mutex::new(value),
        }
    }
}
impl<T: ?Sized> Mutex<T> {
pub fn lock(&self) -> LockResult<MutexGuard<'_, T>> {
let me = current::me();
let mut state = self.state.borrow_mut();
trace!(holder=?state.holder, semaphore=?self.semaphore, "waiting to acquire mutex {:p}", self);
drop(state);
if !self.semaphore.is_closed() {
state = self.state.borrow_mut();
assert!(
match &state.holder {
Some(holder) => *holder != me,
None => true,
},
"deadlock! task {me:?} tried to acquire a Mutex it already holds"
);
drop(state);
self.semaphore.acquire_blocking(1).unwrap();
}
state = self.state.borrow_mut();
assert!(state.holder.is_none());
state.holder = Some(me);
drop(state);
trace!(semaphore=?self.semaphore, "acquired mutex {:p}", self);
let result = match self.inner.try_lock() {
Ok(guard) => Ok(MutexGuard {
inner: Some(guard),
mutex: self,
}),
Err(TryLockError::Poisoned(guard)) => Err(PoisonError::new(MutexGuard {
inner: Some(guard.into_inner()),
mutex: self,
})),
Err(TryLockError::WouldBlock) => unreachable!("mutex state out of sync"),
};
result
}
pub fn try_lock(&self) -> TryLockResult<MutexGuard<T>> {
let me = current::me();
let mut state = self.state.borrow_mut();
trace!(holder=?state.holder, semaphore=?self.semaphore, "trying to acquire mutex {:p}", self);
drop(state);
self.semaphore.try_acquire(1).map_err(|_| TryLockError::WouldBlock)?;
state = self.state.borrow_mut();
state.holder = Some(me);
drop(state);
trace!(semaphore=?self.semaphore, "acquired mutex {:p}", self);
let result = match self.inner.try_lock() {
Ok(guard) => Ok(MutexGuard {
inner: Some(guard),
mutex: self,
}),
Err(TryLockError::Poisoned(guard)) => Err(TryLockError::Poisoned(PoisonError::new(MutexGuard {
inner: Some(guard.into_inner()),
mutex: self,
}))),
Err(TryLockError::WouldBlock) => unreachable!("mutex state out of sync"),
};
result
}
#[inline]
pub fn get_mut(&mut self) -> LockResult<&mut T> {
self.inner.get_mut()
}
pub fn into_inner(self) -> LockResult<T>
where
T: Sized,
{
let state = self.state.borrow();
assert!(state.holder.is_none());
self.semaphore.try_acquire(1).unwrap();
self.inner.into_inner()
}
}
// SAFETY: NOTE(review): `state` is a `RefCell`, which is not `Sync`, so these impls cannot be
// derived automatically. They presumably rely on the surrounding runtime never accessing `state`
// from two OS threads concurrently — confirm. The `T: Send` bounds match `std::sync::Mutex`.
unsafe impl<T: Send + ?Sized> Send for Mutex<T> {}
unsafe impl<T: Send + ?Sized> Sync for Mutex<T> {}
// Like `std::sync::Mutex`, this type is unwind safe regardless of `T`. A panic while the lock is
// held poisons the inner std mutex, and `lock`/`try_lock` surface that poisoning to later callers.
impl<T> UnwindSafe for Mutex<T>
where
    T: ?Sized,
{
}

impl<T> RefUnwindSafe for Mutex<T>
where
    T: ?Sized,
{
}
impl<T: Default> Default for Mutex<T> {
    /// Creates an unlocked mutex holding `T::default()`.
    fn default() -> Self {
        let value = T::default();
        Self::new(value)
    }
}
impl<T: ?Sized + Debug> Debug for Mutex<T> {
    /// Formats exactly as the wrapped `std::sync::Mutex` would.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <std::sync::Mutex<T> as Debug>::fmt(&self.inner, f)
    }
}
impl<'a, T: ?Sized> MutexGuard<'a, T> {
    /// Releases this guard and hands back the mutex it was protecting.
    ///
    /// The release itself is performed by the guard's `Drop` impl.
    pub(super) fn unlock(self) -> &'a Mutex<T> {
        let mutex = self.mutex;
        drop(self);
        mutex
    }
}
impl<T: ?Sized> Drop for MutexGuard<'_, T> {
    /// Unlocks the inner std mutex, clears the recorded holder, then returns the semaphore
    /// permit, in that order.
    fn drop(&mut self) {
        // Release the std guard first. If we are unwinding, this is also the point at which the
        // inner std mutex becomes poisoned.
        drop(self.inner.take());

        {
            let mut state = self.mutex.state.borrow_mut();
            trace!(semaphore=?self.mutex.semaphore, "releasing mutex {:p}", self.mutex);
            state.holder = None;
        }

        self.mutex.semaphore.release(1);
    }
}
impl<T: ?Sized> Deref for MutexGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value through the inner std guard.
    fn deref(&self) -> &Self::Target {
        // `inner` is set to `Some` on construction and only cleared inside `Drop`.
        self.inner.as_deref().unwrap()
    }
}
impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
    /// Mutably borrows the protected value through the inner std guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // `inner` is set to `Some` on construction and only cleared inside `Drop`.
        self.inner.as_deref_mut().unwrap()
    }
}
impl<T: Debug + ?Sized> Debug for MutexGuard<'_, T> {
    /// Formats the protected value itself, exactly as `std::sync::MutexGuard` does.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value: &T = self;
        Debug::fmt(value, f)
    }
}
impl<T: Display + ?Sized> Display for MutexGuard<'_, T> {
    /// Displays the protected value.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        Display::fmt(&**self, f)
    }
}
impl<T> crate::annotations::WithName for &Mutex<T> {
    /// Forwards `name` and `kind` to the mutex's semaphore and returns the same reference.
    ///
    /// `kind` defaults to `"shuttle::sync::Mutex"` when the caller passes `None`.
    fn with_name_and_kind(self, name: Option<&str>, kind: Option<&str>) -> Self {
        let effective_kind = kind.or(Some("shuttle::sync::Mutex"));
        (&self.semaphore).with_name_and_kind(name, effective_kind);
        self
    }
}
impl<T> crate::annotations::WithName for Mutex<T> {
    /// Owned counterpart of the `&Mutex<T>` impl: delegates to it, then returns `self`.
    fn with_name_and_kind(self, name: Option<&str>, kind: Option<&str>) -> Self {
        <&Mutex<T> as crate::annotations::WithName>::with_name_and_kind(&self, name, kind);
        self
    }
}