use crate::{
blocking::DefaultMutex,
loom::cell::{MutPtr, UnsafeCell},
util::fmt,
};
use core::{
marker::PhantomData,
ops::{Deref, DerefMut},
};
pub use mutex_traits::{RawMutex, ScopedRawMutex};
/// A mutual-exclusion wrapper that stores a value of type `T` next to the raw
/// lock (`Lock`) that guards access to it.
///
/// When no `Lock` type is named, [`DefaultMutex`] is used. Which methods are
/// available depends on the traits `Lock` implements: the closure-based
/// `with_lock` / `try_with_lock` require [`ScopedRawMutex`], while the
/// guard-returning `lock` / `try_lock` require [`RawMutex`].
pub struct Mutex<T, Lock = DefaultMutex> {
/// The raw lock that is taken before `data` is handed out through `&self`.
lock: Lock,
/// The protected value. This is the crate's `loom::cell::UnsafeCell`
/// wrapper, not `core::cell::UnsafeCell`.
data: UnsafeCell<T>,
}
/// A guard giving access to the data inside a [`Mutex`] while its raw lock is
/// held.
///
/// Guards are produced by [`Mutex::lock`] and [`Mutex::try_lock`]. The data is
/// reached through the guard's [`Deref`] and [`DerefMut`] implementations, and
/// the raw lock is unlocked when the guard is dropped.
#[must_use = "if unused, the `Mutex` will immediately unlock"]
pub struct MutexGuard<'a, T, Lock: RawMutex> {
/// Pointer to the protected data, taken from the mutex's cell when the guard
/// is created.
ptr: MutPtr<T>,
/// The raw lock that is unlocked when this guard is dropped.
lock: &'a Lock,
/// Zero-sized marker for the raw lock's `GuardMarker` type; the guard's
/// `Send` impl additionally requires `Lock::GuardMarker: Send`.
_marker: PhantomData<Lock::GuardMarker>,
}
impl<T> Mutex<T> {
loom_const_fn! {
/// Returns a new `Mutex` protecting `data`, using a freshly constructed
/// [`DefaultMutex`] as the raw lock.
///
/// Use [`Mutex::new_with_raw_mutex`] to supply a different raw lock.
#[must_use]
pub fn new(data: T) -> Self {
Self {
lock: DefaultMutex::new(),
data: UnsafeCell::new(data),
}
}
}
}
impl<T, Lock> Mutex<T, Lock> {
    loom_const_fn! {
        /// Builds a `Mutex` around `data`, guarded by the caller-supplied raw
        /// lock `lock`.
        ///
        /// No trait bound is placed on `Lock` here; the locking methods are
        /// only available when `Lock` implements [`ScopedRawMutex`] or
        /// [`RawMutex`].
        #[must_use]
        pub fn new_with_raw_mutex(data: T, lock: Lock) -> Self {
            Mutex {
                lock,
                data: UnsafeCell::new(data),
            }
        }
    }

    /// Consumes the mutex and returns the value it protected.
    ///
    /// The raw lock is not touched: taking `self` by value already means no
    /// borrow of the mutex is outstanding.
    #[inline]
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self { data, .. } = self;
        data.into_inner()
    }

    /// Returns a mutable reference to the protected value without locking.
    ///
    /// The exclusive borrow of the mutex stands in for the lock.
    pub fn get_mut(&mut self) -> &mut T {
        self.data.with_mut(|ptr| {
            // SAFETY: `&mut self` guarantees that no other reference to this
            // mutex (and so no guard borrowed from it) is live, so the data
            // cannot be aliased while the returned borrow exists.
            unsafe { &mut *ptr }
        })
    }
}
impl<T, Lock: ScopedRawMutex> Mutex<T, Lock> {
    /// Runs `f` with mutable access to the protected data and returns its
    /// result.
    ///
    /// The call to `f` is wrapped in the raw lock's
    /// [`ScopedRawMutex::with_lock`], so acquiring and releasing the lock is
    /// entirely up to the `Lock` implementation.
    #[track_caller]
    pub fn with_lock<U>(&self, f: impl FnOnce(&mut T) -> U) -> U {
        let critical_section = || {
            self.data.with_mut(|ptr| {
                // SAFETY: this closure is only run by the raw lock's
                // `with_lock`, which per the `ScopedRawMutex` contract means
                // the lock is held for the duration of the call.
                let data = unsafe { &mut *ptr };
                f(data)
            })
        };
        self.lock.with_lock(critical_section)
    }

    /// Like [`with_lock`](Self::with_lock), but delegates to the raw lock's
    /// [`ScopedRawMutex::try_with_lock`].
    ///
    /// Returns `Some` holding `f`'s result when the raw lock ran the closure,
    /// and `None` otherwise (the `Debug` impl for `Mutex` renders that case
    /// as `<locked>`).
    #[track_caller]
    pub fn try_with_lock<U>(&self, f: impl FnOnce(&mut T) -> U) -> Option<U> {
        let critical_section = || {
            self.data.with_mut(|ptr| {
                // SAFETY: as in `with_lock`, this only runs from inside the
                // raw lock's `try_with_lock`, i.e. with the lock held.
                let data = unsafe { &mut *ptr };
                f(data)
            })
        };
        self.lock.try_with_lock(critical_section)
    }
}
impl<T, Lock> Mutex<T, Lock>
where
    Lock: RawMutex,
{
    /// Builds a guard for this mutex.
    ///
    /// This does not lock anything itself: callers must already have acquired
    /// the raw lock, because dropping the guard unlocks it.
    fn guard(&self) -> MutexGuard<'_, T, Lock> {
        let ptr = self.data.get_mut();
        MutexGuard {
            ptr,
            lock: &self.lock,
            _marker: PhantomData,
        }
    }

    /// Attempts to acquire the lock without waiting.
    ///
    /// Returns `Some(guard)` when the raw lock's `try_lock` reports success
    /// and `None` when it does not.
    #[must_use]
    #[cfg_attr(test, track_caller)]
    pub fn try_lock(&self) -> Option<MutexGuard<'_, T, Lock>> {
        self.lock.try_lock().then(|| self.guard())
    }

    /// Acquires the raw lock and returns a guard for the protected data.
    ///
    /// The raw lock is unlocked again when the returned guard is dropped.
    #[cfg_attr(test, track_caller)]
    pub fn lock(&self) -> MutexGuard<'_, T, Lock> {
        self.lock.lock();
        self.guard()
    }

    /// Unlocks the raw lock directly, without going through a guard.
    ///
    /// # Safety
    ///
    /// This forwards straight to the raw lock's `unlock`, so the caller must
    /// uphold that method's contract. Nothing here checks whether a
    /// [`MutexGuard`] for this mutex is still alive; such a guard would keep
    /// handing out references to the data after the lock has been released.
    pub unsafe fn force_unlock(&self) {
        self.lock.unlock()
    }
}
impl<T: Default, Lock: Default> Default for Mutex<T, Lock> {
    /// Creates a mutex holding `T::default()`, guarded by `Lock::default()`.
    fn default() -> Self {
        // The raw lock is constructed before the data.
        let lock = Lock::default();
        let data = UnsafeCell::new(T::default());
        Self { lock, data }
    }
}
impl<T, Lock> fmt::Debug for Mutex<T, Lock>
where
    T: fmt::Debug,
    Lock: ScopedRawMutex,
{
    /// Formats the mutex as a `Mutex { data, lock }` struct.
    ///
    /// The data is only shown if the lock can be taken via `try_with_lock`;
    /// otherwise the `data` field is printed as `<locked>`. The `lock` field
    /// always shows the type name of `Lock`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let lock_type = core::any::type_name::<Lock>();

        // Format the data from inside the critical section, if we can get in.
        let rendered = self.try_with_lock(|data| {
            f.debug_struct("Mutex")
                .field("data", data)
                .field("lock", &format_args!("{}", lock_type))
                .finish()
        });

        match rendered {
            Some(result) => result,
            None => f
                .debug_struct("Mutex")
                .field("data", &format_args!("<locked>"))
                .field("lock", &format_args!("{}", lock_type))
                .finish(),
        }
    }
}
// SAFETY: sending a `Mutex` to another thread moves both the protected value
// and the raw lock, so both are required to be `Send`.
unsafe impl<T: Send, Lock: Send> Send for Mutex<T, Lock> {}
// SAFETY: a shared `&Mutex` lets whichever thread takes the lock obtain
// `&mut T`, which is why `T: Send` is required; the raw lock itself is used
// through a shared reference from several threads, hence `Lock: Sync`.
unsafe impl<T: Send, Lock: Sync> Sync for Mutex<T, Lock> {}
impl<T, Lock: RawMutex> Deref for MutexGuard<'_, T, Lock> {
    type Target = T;

    /// Borrows the protected data immutably for as long as the guard is
    /// borrowed.
    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY: guards are only built by `Mutex::guard`, which is called
        // after the raw lock has been acquired, and the lock is not released
        // until this guard is dropped.
        let data: &mut T = unsafe { self.ptr.deref() };
        data
    }
}
impl<T, Lock: RawMutex> DerefMut for MutexGuard<'_, T, Lock> {
    /// Borrows the protected data mutably for as long as the guard is
    /// mutably borrowed.
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: guards are only built by `Mutex::guard` after the raw lock
        // has been acquired, and `&mut self` makes this the only borrow handed
        // out by this guard.
        unsafe { self.ptr.deref() }
    }
}
impl<T, Lock, R: ?Sized> AsRef<R> for MutexGuard<'_, T, Lock>
where
T: AsRef<R>,
Lock: RawMutex,
{
#[inline]
fn as_ref(&self) -> &R {
self.deref().as_ref()
}
}
impl<T, Lock, R> AsMut<R> for MutexGuard<'_, T, Lock>
where
    T: AsMut<R>,
    Lock: RawMutex,
    R: ?Sized,
{
    /// Forwards to the protected value's `AsMut<R>` implementation.
    #[inline]
    fn as_mut(&mut self) -> &mut R {
        // Deref coercion through the guard's `DerefMut` impl.
        let inner: &mut T = self;
        inner.as_mut()
    }
}
impl<T, Lock> Drop for MutexGuard<'_, T, Lock>
where
    Lock: RawMutex,
{
    /// Unlocks the raw lock this guard was created for.
    #[inline]
    #[cfg_attr(test, track_caller)]
    fn drop(&mut self) {
        let raw: &Lock = self.lock;
        // SAFETY: a guard is only created (by `Mutex::lock` / `Mutex::try_lock`)
        // after the raw lock has been acquired, so this releases a lock that
        // this guard is holding.
        unsafe { raw.unlock() }
    }
}
impl<T, Lock> fmt::Debug for MutexGuard<'_, T, Lock>
where
T: fmt::Debug,
Lock: RawMutex,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(f)
}
}
impl<T, Lock> fmt::Display for MutexGuard<'_, T, Lock>
where
T: fmt::Display,
Lock: RawMutex,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(f)
}
}
// SAFETY: the bounds below are what make sending a guard to another thread
// acceptable.
unsafe impl<T, Lock> Send for MutexGuard<'_, T, Lock>
where
// The guard hands out `&mut T` (see `DerefMut`), so the data must be `Send`.
T: Send,
// The guard stores `&'a Lock`, so the raw lock must be shareable across
// threads.
Lock: RawMutex + Sync,
// Lets the raw lock implementation decide, through its marker type, whether
// its guards may be sent between threads.
Lock::GuardMarker: Send,
{
}
#[cfg(test)]
mod tests {
    use crate::loom::{self, thread};
    use crate::spin::Spinlock;
    use std::prelude::v1::*;
    use std::sync::Arc;

    use super::*;

    /// A spawned thread and the spawning thread each lock the same mutex and
    /// append to the shared string, all inside `loom::model`.
    #[test]
    fn multithreaded() {
        loom::model(|| {
            let shared = Arc::new(Mutex::new_with_raw_mutex(String::new(), Spinlock::new()));
            let for_worker = Arc::clone(&shared);

            let worker = thread::spawn(move || {
                tracing::info!("t1: locking...");
                let mut guard = for_worker.lock();
                tracing::info!("t1: locked");
                guard.push_str("bbbbb");
                tracing::info!("t1: dropping...");
            });

            // Scoped so the guard is dropped before joining the worker.
            {
                tracing::info!("t2: locking...");
                let mut guard = shared.lock();
                tracing::info!("t2: locked");
                guard.push_str("bbbbb");
                tracing::info!("t2: dropping...");
            }

            worker.join().unwrap();
        });
    }

    /// `try_lock` succeeds on an unlocked mutex, fails while a guard is alive,
    /// and succeeds again once that guard has been dropped.
    #[test]
    fn try_lock() {
        loom::model(|| {
            let mutex = Mutex::new_with_raw_mutex(42, Spinlock::new());

            let first = mutex.try_lock();
            assert_eq!(first.as_deref().copied(), Some(42));

            let second = mutex.try_lock();
            assert!(second.is_none());

            ::core::mem::drop(first);

            let third = mutex.try_lock();
            assert_eq!(third.as_deref().copied(), Some(42));
        });
    }
}