#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
#![cfg_attr(feature = "fmt", warn(missing_debug_implementations))]
pub mod raw_impls;
use core::cell::UnsafeCell;
use core::marker::PhantomData;
use core::ops::{Deref, DerefMut};
use core::panic::AssertUnwindSafe;
pub use mutex_traits::{ConstInit, RawMutex, ScopedRawMutex};
/// A mutex that pairs a raw mutex of type `R` with the value of type `T` that
/// it protects.
///
/// The value lives in an [`UnsafeCell`] so that the locking methods on this
/// type can hand out mutable access to it through a shared `&self`.
pub struct BlockingMutex<R, T: ?Sized> {
// The raw mutex used to serialize access to `data`.
raw: R,
// The protected value.
data: UnsafeCell<T>,
}
/// A guard for a locked [`BlockingMutex`].
///
/// The guard dereferences to the protected value, and dropping it unlocks the
/// raw mutex. Because it is `#[must_use]`, discarding a freshly created guard
/// produces a compiler warning.
#[must_use]
pub struct MutexGuard<'mutex, R: RawMutex, T: ?Sized> {
// The mutex this guard was created from.
lock: &'mutex BlockingMutex<R, T>,
// Zero-sized marker carrying the raw mutex's `GuardMarker` type.
_marker: PhantomData<R::GuardMarker>,
}
// SAFETY: moving the mutex to another thread moves both the raw mutex and the
// protected value there, so both `R` and `T` are required to be `Send`.
unsafe impl<R, T> Send for BlockingMutex<R, T>
where
    R: ScopedRawMutex + Send,
    T: ?Sized + Send,
{
}
// SAFETY: relies on the raw mutex granting access to the protected value to
// one caller at a time. Under that assumption sharing `&BlockingMutex` only
// ever moves access to `T` between threads, which is why `T: Send` (rather
// than `T: Sync`) is the bound used here. `R` itself is accessed through a
// shared reference, hence `R: Sync`.
unsafe impl<R, T> Sync for BlockingMutex<R, T>
where
    R: ScopedRawMutex + Sync,
    T: ?Sized + Send,
{
}
/// Runs `f`, turning an unwinding panic into an `Err` that carries the panic
/// payload.
///
/// Thin wrapper around [`std::panic::catch_unwind`], compiled only when the
/// `std` feature is enabled.
#[cfg(feature = "std")]
#[inline(always)]
fn catch_unwind<F, R>(f: F) -> std::thread::Result<R>
where
    F: FnOnce() -> R + std::panic::UnwindSafe,
{
    std::panic::catch_unwind(f)
}

/// Fallback used when the `std` feature is disabled: runs `f` and wraps its
/// result in `Ok`.
///
/// The error type is [`core::convert::Infallible`], so this variant never
/// produces an `Err`.
#[cfg(not(feature = "std"))]
#[inline(always)]
fn catch_unwind<F, R>(f: F) -> Result<R, core::convert::Infallible>
where
    F: FnOnce() -> R,
{
    let value = f();
    Ok(value)
}
impl<R, T> BlockingMutex<R, T>
where
    R: ConstInit,
{
    /// Creates a mutex protecting `val`, with the raw mutex initialized from
    /// `R::INIT`.
    ///
    /// This is a `const fn`, so it can be evaluated in constant contexts such
    /// as the initializer of a `static`.
    #[inline]
    pub const fn new(val: T) -> Self {
        let data = UnsafeCell::new(val);
        Self { raw: R::INIT, data }
    }
}
impl<R: ScopedRawMutex, T: ?Sized> BlockingMutex<R, T> {
    /// Runs `f` with mutable access to the protected value inside
    /// `R::with_lock`, and returns whatever `f` returns.
    ///
    /// With the `std` feature enabled, a panic raised by `f` is caught before
    /// control leaves the raw mutex's closure and is resumed only after
    /// `R::with_lock` has returned.
    pub fn with_lock<U>(&self, f: impl FnOnce(&mut T) -> U) -> U {
        let outcome = self.raw.with_lock(|| {
            // SAFETY: relies on `R::with_lock` running this closure only while
            // the raw mutex is held, so no other reference to the value is
            // live for the duration of the closure.
            let data = unsafe { &mut *self.data.get() };
            // The closure is asserted unwind-safe so that a panic in `f` can
            // be carried out of the raw mutex as a value.
            catch_unwind(AssertUnwindSafe(|| f(data)))
        });
        match outcome {
            Ok(value) => value,
            #[cfg(feature = "std")]
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }

    /// Like [`with_lock`](Self::with_lock), but goes through
    /// `R::try_with_lock` and returns `None` when that call returns `None`
    /// (in which case `f` is not run by this method).
    ///
    /// With the `std` feature enabled, a panic raised by `f` is resumed after
    /// `R::try_with_lock` has returned.
    #[must_use]
    pub fn try_with_lock<U>(&self, f: impl FnOnce(&mut T) -> U) -> Option<U> {
        let attempt = self.raw.try_with_lock(|| {
            // SAFETY: relies on `R::try_with_lock` running this closure only
            // while the raw mutex is held, so no other reference to the value
            // is live for the duration of the closure.
            let data = unsafe { &mut *self.data.get() };
            catch_unwind(AssertUnwindSafe(|| f(data)))
        });
        attempt.map(|outcome| match outcome {
            Ok(value) => value,
            #[cfg(feature = "std")]
            Err(payload) => std::panic::resume_unwind(payload),
        })
    }
}
impl<R: RawMutex, T: ?Sized> BlockingMutex<R, T> {
    /// Locks the raw mutex via `R::lock` and returns a guard for the
    /// protected value.
    ///
    /// The raw mutex is unlocked again when the returned guard is dropped.
    pub fn lock(&self) -> MutexGuard<'_, R, T> {
        self.raw.lock();
        let _marker = PhantomData;
        MutexGuard {
            lock: self,
            _marker,
        }
    }

    /// Attempts to lock the raw mutex via `R::try_lock`.
    ///
    /// Returns `Some(guard)` when `R::try_lock` returns `true`, and `None`
    /// otherwise.
    pub fn try_lock(&self) -> Option<MutexGuard<'_, R, T>> {
        // The guard is only constructed when the raw lock was acquired.
        self.raw.try_lock().then(|| MutexGuard {
            lock: self,
            _marker: PhantomData,
        })
    }
}
impl<R, T> BlockingMutex<R, T> {
    /// Creates a mutex from an explicitly supplied raw mutex and an initial
    /// value.
    ///
    /// Unlike `new`, this places no trait bound on `R`; the caller provides
    /// the raw mutex instance directly.
    #[inline]
    pub const fn const_new(raw_mutex: R, val: T) -> BlockingMutex<R, T> {
        BlockingMutex {
            raw: raw_mutex,
            data: UnsafeCell::new(val),
        }
    }

    /// Consumes the mutex and returns the protected value.
    ///
    /// No locking takes place: owning the mutex by value means nobody else
    /// can be borrowing it.
    #[inline]
    pub fn into_inner(self) -> T {
        self.data.into_inner()
    }

    /// Returns a mutable reference to the protected value.
    ///
    /// No locking takes place: the `&mut self` borrow already guarantees
    /// exclusive access, so the safe `UnsafeCell::get_mut` is sufficient.
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        self.data.get_mut()
    }

    /// Returns a raw pointer to the protected value without touching the raw
    /// mutex.
    ///
    /// # Safety
    ///
    /// This method performs no locking. The caller must make sure that every
    /// access through the returned pointer is synchronized with all other
    /// access to the value (for example access made through the locking
    /// methods of this mutex), so that no aliasing `&mut T` or data race can
    /// occur.
    pub unsafe fn get_unchecked(&self) -> *mut T {
        self.data.get()
    }
}
#[cfg(feature = "fmt")]
impl<R, T> core::fmt::Debug for BlockingMutex<R, T>
where
    R: ScopedRawMutex + core::fmt::Debug,
    T: ?Sized + core::fmt::Debug,
{
    /// Formats the mutex as a struct with a `raw` and a `data` field.
    ///
    /// The value is read through `try_with_lock`; when that returns `None`
    /// the `data` field is rendered as `<locked>` instead.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("BlockingMutex");
        builder.field("raw", &self.raw);
        let locked_view = self.try_with_lock(|data| builder.field("data", &data).finish());
        match locked_view {
            Some(result) => result,
            None => builder.field("data", &format_args!("<locked>")).finish(),
        }
    }
}
impl<R: RawMutex, T: ?Sized> Drop for MutexGuard<'_, R, T> {
    /// Unlocks the raw mutex when the guard is dropped.
    ///
    /// In builds with debug assertions enabled, this first checks that the
    /// raw mutex still reports itself as locked.
    fn drop(&mut self) {
        let raw = &self.lock.raw;
        debug_assert!(
            raw.is_locked(),
            "tried to unlock a `Mutex` that was not locked! this is almost \
             certainly a bug in the `RawMutex` implementation (`{}`)",
            core::any::type_name::<R>(),
        );
        // SAFETY: guards are only created by `lock`/`try_lock` after the raw
        // mutex was acquired, so this releases a lock that this guard holds --
        // presumably what `RawMutex::unlock` requires of its caller.
        unsafe { raw.unlock() };
    }
}
impl<R: RawMutex, T: ?Sized> Deref for MutexGuard<'_, R, T> {
    type Target = T;

    /// Borrows the protected value.
    ///
    /// In builds with debug assertions enabled, this first checks that the
    /// raw mutex reports itself as locked.
    #[inline]
    fn deref(&self) -> &T {
        let mutex = self.lock;
        debug_assert!(
            mutex.raw.is_locked(),
            "tried to dereference a `MutexGuard` that was not locked! this is \
             almost certainly a bug in the `RawMutex` implementation (`{}`)",
            core::any::type_name::<R>(),
        );
        // SAFETY: guards are only created by `lock`/`try_lock` after the raw
        // mutex was acquired; this relies on the raw mutex staying locked for
        // as long as the guard exists, so no `&mut T` is handed out elsewhere.
        unsafe { &*mutex.data.get() }
    }
}
impl<R: RawMutex, T: ?Sized> DerefMut for MutexGuard<'_, R, T> {
    /// Mutably borrows the protected value.
    ///
    /// In builds with debug assertions enabled, this first checks that the
    /// raw mutex reports itself as locked.
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        let mutex = self.lock;
        debug_assert!(
            mutex.raw.is_locked(),
            "tried to mutably dereference a `MutexGuard` that was not locked! \
             this is almost certainly a bug in the `RawMutex` implementation \
             (`{}`)",
            core::any::type_name::<R>(),
        );
        // SAFETY: guards are only created by `lock`/`try_lock` after the raw
        // mutex was acquired, and `&mut self` makes this the only borrow made
        // through this guard; this relies on the raw mutex staying locked for
        // as long as the guard exists.
        unsafe { &mut *mutex.data.get() }
    }
}
// SAFETY: owning a guard gives mutable access to the protected value (see
// `DerefMut`), so the guard may only move to another thread when `T: Send`.
// The `R::GuardMarker: Send` bound additionally lets the raw mutex
// implementation decide, through its marker type, whether its guards may be
// sent at all.
unsafe impl<R: RawMutex, T: ?Sized + Send> Send for MutexGuard<'_, R, T> where
    R::GuardMarker: Send
{
}
// SAFETY: a shared reference to a guard only gives shared access to the
// protected value (see `Deref`), so the guard may be shared between threads
// when `T: Sync`.
//
// NOTE(review): there is no `R: Sync` bound here, yet `Deref::deref` calls
// `R::is_locked` through a shared reference when debug assertions are
// enabled -- confirm that every `RawMutex` is expected to tolerate that.
unsafe impl<R: RawMutex, T: ?Sized + Sync> Sync for MutexGuard<'_, R, T> {}
#[cfg(feature = "fmt")]
impl<R, T> core::fmt::Debug for MutexGuard<'_, R, T>
where
    T: ?Sized + core::fmt::Debug,
    R: RawMutex,
{
    /// Formats the protected value using its own `Debug` implementation.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `MutexGuard` has no `data` field of its own (its fields are `lock`
        // and `_marker`), so reach the protected `T` through the guard's
        // `Deref` implementation instead.
        core::fmt::Debug::fmt(&**self, f)
    }
}
#[cfg(feature = "fmt")]
impl<R, T> core::fmt::Display for MutexGuard<'_, R, T>
where
    T: ?Sized + core::fmt::Display,
    R: RawMutex,
{
    /// Formats the protected value using its own `Display` implementation.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `MutexGuard` has no `data` field of its own (its fields are `lock`
        // and `_marker`), so reach the protected `T` through the guard's
        // `Deref` implementation instead.
        core::fmt::Display::fmt(&**self, f)
    }
}
#[cfg(feature = "std")]
#[cfg(test)]
mod test {
    use core::sync::atomic::{AtomicBool, Ordering};

    use crate::{raw_impls::cs::CriticalSectionRawMutex, BlockingMutex};

    /// Adds one to `num` and returns the value it held beforehand.
    fn post_increment(num: &mut u64) -> u64 {
        let old = *num;
        *num += 1;
        old
    }

    /// A panic inside `with_lock` must leave the mutex usable afterwards.
    #[test]
    fn unlocks_on_unwind() {
        static MUTEX: BlockingMutex<CriticalSectionRawMutex, u64> = BlockingMutex::new(0);

        // Ordinary access sees the initial value.
        let res = std::thread::spawn(|| MUTEX.with_lock(post_increment))
            .join()
            .unwrap();
        assert_eq!(0, res);

        // Panic while inside `with_lock`; the thread must report the panic.
        std::thread::spawn(|| {
            MUTEX.with_lock(|_num| {
                panic!();
            })
        })
        .join()
        .unwrap_err();

        // The mutex must be usable again, and the earlier increment must
        // still be visible.
        let res = std::thread::spawn(|| MUTEX.with_lock(post_increment))
            .join()
            .unwrap();
        assert_eq!(1, res);
    }

    /// Same scenario as `unlocks_on_unwind`, going through `try_with_lock`.
    #[test]
    fn try_unlocks_on_unwind() {
        static MUTEX: BlockingMutex<CriticalSectionRawMutex, u64> = BlockingMutex::new(0);

        let res = std::thread::spawn(|| MUTEX.try_with_lock(post_increment))
            .join()
            .unwrap();
        assert_eq!(Some(0), res);

        // Records that the panicking closure really ran, i.e. that the lock
        // was acquired before the panic.
        static TRIED: AtomicBool = AtomicBool::new(false);
        std::thread::spawn(|| {
            MUTEX.try_with_lock(|_num| {
                TRIED.store(true, Ordering::Relaxed);
                panic!();
            })
        })
        .join()
        .unwrap_err();
        assert!(TRIED.load(Ordering::Relaxed));

        // The lock must be obtainable again after the panic.
        let res = std::thread::spawn(|| MUTEX.try_with_lock(post_increment))
            .join()
            .unwrap();
        assert_eq!(Some(1), res);
    }
}