use crate::{
blocking::{DefaultMutex, ScopedRawMutex},
loom::cell::{MutPtr, UnsafeCell},
util::fmt,
wait_queue::{self, WaitQueue},
};
use core::{
future::Future,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};
use pin_project::pin_project;
#[cfg(test)]
mod tests;
/// An asynchronous mutual exclusion lock protecting data of type `T`.
///
/// The lock is acquired either by awaiting the [`Lock`] future returned by
/// [`Mutex::lock`], or without waiting via [`Mutex::try_lock`]; both produce a
/// [`MutexGuard`] that releases the lock when it is dropped. `L` is the raw
/// mutex type handed to the internal `WaitQueue` (see
/// [`Mutex::new_with_raw_mutex`]); [`Mutex::new`] uses `DefaultMutex`.
pub struct Mutex<T: ?Sized, L: ScopedRawMutex = DefaultMutex> {
    /// Queue of tasks waiting for the lock. Acquiring the lock means a wait
    /// on this queue completed successfully; releasing it means calling
    /// `wake` on the queue.
    wait: WaitQueue<L>,
    /// The protected data. It is only accessed through a guard, or through
    /// `&mut self` / `self`, when no guard can exist.
    data: UnsafeCell<T>,
}
/// An RAII guard granting exclusive access to the data protected by a
/// [`Mutex`].
///
/// Returned by [`Mutex::try_lock`] and by awaiting the [`Lock`] future from
/// [`Mutex::lock`]. The guard dereferences to the protected data. When it is
/// dropped, its `WakeOnDrop` field calls `wake` on the mutex's wait queue,
/// which releases the lock.
#[must_use = "if unused, the `Mutex` will immediately unlock"]
pub struct MutexGuard<'a, T: ?Sized, L: ScopedRawMutex = DefaultMutex> {
    /// Pointer to the protected data, obtained from the mutex's `UnsafeCell`.
    data: MutPtr<T>,
    // Fields are dropped in declaration order, so the lock is released (by
    // `_wake`) only after `data` has been dropped. Keep this field last.
    _wake: WakeOnDrop<'a, T, L>,
}
/// The [`Future`] returned by [`Mutex::lock`].
///
/// It resolves to a [`MutexGuard`] once the lock has been acquired.
#[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
#[pin_project]
#[derive(Debug)]
pub struct Lock<'a, T: ?Sized, L: ScopedRawMutex = DefaultMutex> {
    /// The pending wait on the mutex's wait queue. This field is structurally
    /// pinned, because it is polled through a pinned projection.
    #[pin]
    wait: wait_queue::Wait<'a, L>,
    /// The mutex being locked. It is used to build the guard once `wait`
    /// completes.
    mutex: &'a Mutex<T, L>,
}
/// Calls `wake` on the borrowed [`Mutex`]'s wait queue when dropped.
///
/// It is stored inside a [`MutexGuard`], so dropping the guard releases the
/// lock.
struct WakeOnDrop<'a, T: ?Sized, L: ScopedRawMutex>(&'a Mutex<T, L>);
impl<T> Mutex<T> {
    loom_const_fn! {
        /// Returns a new `Mutex` protecting `data`.
        ///
        /// The internal wait queue uses `DefaultMutex` as its raw mutex. Use
        /// [`Mutex::new_with_raw_mutex`] to choose a different one.
        #[must_use]
        pub fn new(data: T) -> Self {
            let raw = DefaultMutex::new();
            Self::new_with_raw_mutex(data, raw)
        }
    }
}
impl<T, L: ScopedRawMutex> Mutex<T, L> {
    loom_const_fn! {
        /// Returns a new `Mutex` protecting `data`.
        ///
        /// The internal wait queue is built with the caller-supplied raw
        /// mutex `lock`.
        pub fn new_with_raw_mutex(data: T, lock: L) -> Self {
            // The queue is created with `new_woken`. Presumably this lets the
            // first attempt to lock succeed without waiting.
            let wait = WaitQueue::<L>::new_woken(lock);
            let data = UnsafeCell::new(data);
            Self { wait, data }
        }
    }
}
impl<T, L: ScopedRawMutex> Mutex<T, L> {
    /// Consumes the mutex and returns the protected data.
    ///
    /// No locking is needed. Taking `self` by value guarantees that no guard
    /// borrowing this mutex is still alive.
    #[inline]
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self { data, .. } = self;
        data.into_inner()
    }
}
impl<T: ?Sized, L: ScopedRawMutex> Mutex<T, L> {
    /// Returns a [`Lock`] future that resolves to a [`MutexGuard`] once the
    /// lock has been acquired.
    pub fn lock(&self) -> Lock<'_, T, L> {
        let wait = self.wait.wait();
        Lock { mutex: self, wait }
    }

    /// Attempts to acquire the lock without waiting.
    ///
    /// Returns `Some(guard)` if the lock was acquired. Returns `None` if the
    /// wait queue reported `Pending`, meaning the lock is currently held.
    pub fn try_lock(&self) -> Option<MutexGuard<'_, T, L>> {
        let acquired = match self.wait.try_wait() {
            Poll::Ready(Ok(_)) => true,
            Poll::Pending => false,
            // SAFETY: this type never closes its wait queue, so an error
            // result cannot occur.
            Poll::Ready(Err(_)) => unsafe {
                unreachable_unchecked!("`Mutex` never calls `WaitQueue::close`")
            },
        };

        if acquired {
            // SAFETY: `try_wait` returned `Ready(Ok(_))`, which this type
            // treats as having acquired the lock.
            Some(unsafe { self.guard() })
        } else {
            None
        }
    }

    /// Returns a mutable reference to the protected data.
    ///
    /// No locking is performed. The exclusive borrow of `self` statically
    /// rules out any outstanding [`MutexGuard`].
    pub fn get_mut(&mut self) -> &mut T {
        // SAFETY: `&mut self` guarantees no guard (and therefore no other
        // reference to the data) exists, so this `&mut T` cannot alias.
        unsafe { self.data.with_mut(|ptr| &mut *ptr) }
    }

    /// Builds a [`MutexGuard`] over this mutex's data.
    ///
    /// # Safety
    ///
    /// The caller must have just acquired the lock through `self.wait`. The
    /// returned guard releases the lock when dropped.
    unsafe fn guard(&self) -> MutexGuard<'_, T, L> {
        // Create the releaser before taking the data pointer. If obtaining
        // the pointer were to panic, the lock would still be released.
        let wake = WakeOnDrop(self);
        let data = self.data.get_mut();
        MutexGuard { data, _wake: wake }
    }
}
impl<T: Default> Default for Mutex<T> {
    /// Creates a `Mutex` holding `T::default()` via [`Mutex::new`].
    fn default() -> Self {
        let data = T::default();
        Self::new(data)
    }
}
impl<T, L> fmt::Debug for Mutex<T, L>
where
    T: ?Sized + fmt::Debug,
    L: ScopedRawMutex + fmt::Debug,
{
    /// Formats the mutex and its wait queue.
    ///
    /// The data is shown only if the lock can be taken with `try_lock`.
    /// Otherwise `<locked>` is printed. The temporary guard is held until
    /// formatting has finished.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let guard = self.try_lock();
        let mut out = f.debug_struct("Mutex");
        out.field("data", &fmt::opt(&guard).or_else("<locked>"));
        out.field("wait", &self.wait);
        out.finish()
    }
}
// SAFETY: moving a `Mutex` to another thread moves both the protected `T` and
// the raw mutex `L` with it, so both must be `Send`.
unsafe impl<T: ?Sized + Send, L: ScopedRawMutex + Send> Send for Mutex<T, L> {}
// SAFETY: through `&Mutex`, the data is only reachable via a guard, and only
// one guard can exist at a time. The data therefore effectively moves between
// threads, which requires `T: Send` (like `std::sync::Mutex`). The wait queue,
// and thus `L`, is used concurrently through `&self`, so `L` must be `Sync`.
unsafe impl<T: ?Sized + Send, L: ScopedRawMutex + Sync> Sync for Mutex<T, L> {}
impl<'a, T, L> Future for Lock<'a, T, L>
where
    T: ?Sized,
    L: ScopedRawMutex,
{
    type Output = MutexGuard<'a, T, L>;

    /// Polls the pinned wait-queue entry.
    ///
    /// Once the entry completes successfully, the lock is held and a
    /// [`MutexGuard`] is returned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        match projected.wait.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => {
                // SAFETY: the wait completed successfully, so this task now
                // holds the lock.
                let acquired = unsafe { projected.mutex.guard() };
                Poll::Ready(acquired)
            }
            // SAFETY: this type never closes its wait queue, so an error
            // result cannot occur.
            Poll::Ready(Err(_)) => unsafe {
                unreachable_unchecked!("`Mutex` never calls `WaitQueue::close`")
            },
        }
    }
}
impl<T: ?Sized, L: ScopedRawMutex> Deref for MutexGuard<'_, T, L> {
    type Target = T;

    /// Borrows the protected data immutably.
    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY: this guard holds the lock, so nothing else can mutate the
        // data while this shared borrow of the guard is alive.
        let shared: &T = unsafe { self.data.deref() };
        shared
    }
}
impl<T: ?Sized, L: ScopedRawMutex> DerefMut for MutexGuard<'_, T, L> {
    /// Borrows the protected data mutably.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: this guard holds the lock, and `&mut self` ensures this is
        // the only live borrow obtained through the guard.
        let exclusive: &mut T = unsafe { self.data.deref() };
        exclusive
    }
}
impl<T, L> fmt::Debug for MutexGuard<'_, T, L>
where
T: ?Sized + fmt::Debug,
L: ScopedRawMutex,
{
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(f)
}
}
// SAFETY: sending a guard sends exclusive access to `T` (requires `T: Send`)
// and a `&Mutex<T, L>`. Sending that reference requires `Mutex<T, L>: Sync`,
// i.e. `L: Sync`, because the guard's drop calls `wake` on the queue from the
// new thread.
unsafe impl<T: ?Sized + Send, L: ScopedRawMutex + Sync> Send for MutexGuard<'_, T, L> {}
// SAFETY: a shared `&MutexGuard` only exposes `&T`, so sharing the guard
// across threads requires `T: Sync`.
unsafe impl<T: ?Sized + Send + Sync, L: ScopedRawMutex + Sync> Sync for MutexGuard<'_, T, L> {}
impl<T: ?Sized, L: ScopedRawMutex> Drop for WakeOnDrop<'_, T, L> {
    /// Calls `wake` on the mutex's wait queue, releasing the lock held by
    /// the owning guard.
    fn drop(&mut self) {
        let Self(mutex) = self;
        mutex.wait.wake();
    }
}
feature! {
#![feature = "alloc"]
use alloc::sync::Arc;
/// An RAII guard granting exclusive access to the data of a [`Mutex`] held in
/// an [`Arc`].
///
/// Returned by [`Mutex::lock_owned`] and [`Mutex::try_lock_owned`]. Unlike
/// [`MutexGuard`], it owns a clone of the `Arc` rather than borrowing the
/// mutex, so it has no lifetime parameter. Dropping it releases the lock.
#[must_use = "if unused, the Mutex will immediately unlock"]
pub struct OwnedMutexGuard<T: ?Sized, L: ScopedRawMutex> {
    /// Pointer to the protected data.
    data: MutPtr<T>,
    // Fields are dropped in declaration order, so the lock is released (by
    // `_wake`) only after `data` has been dropped. Keep this field last.
    _wake: WakeArcOnDrop<T, L>,
}
impl<T: ?Sized, L: ScopedRawMutex> Mutex<T, L> {
    /// Locks a mutex held in an [`Arc`] and waits until the lock is acquired.
    ///
    /// Returns an [`OwnedMutexGuard`] that keeps the `Arc` alive instead of
    /// borrowing the mutex.
    ///
    /// # Panics
    ///
    /// Panics if the wait queue reports an error. That should not happen,
    /// because `Mutex` never closes its wait queue.
    pub async fn lock_owned(self: Arc<Self>) -> OwnedMutexGuard<T, L> {
        let acquired = self.wait.wait().await;
        acquired.unwrap();
        // SAFETY: the wait completed successfully, so this task holds the lock.
        unsafe { self.owned_guard() }
    }

    /// Attempts to lock a mutex held in an [`Arc`] without waiting.
    ///
    /// Returns `Ok(guard)` on success. If the lock is currently held, returns
    /// `Err` with the original `Arc`, so the caller keeps ownership.
    pub fn try_lock_owned(self: Arc<Self>) -> Result<OwnedMutexGuard<T, L>, Arc<Self>> {
        let attempt = self.wait.try_wait();
        match attempt {
            Poll::Ready(Ok(_)) => {
                // SAFETY: `try_wait` returned `Ready(Ok(_))`, so the lock is
                // now held by us.
                let guard = unsafe { self.owned_guard() };
                Ok(guard)
            }
            // SAFETY: this type never closes its wait queue, so an error
            // result cannot occur.
            Poll::Ready(Err(_)) => unsafe {
                unreachable_unchecked!("`Mutex` never calls `WaitQueue::close`")
            },
            Poll::Pending => Err(self),
        }
    }

    /// Builds an [`OwnedMutexGuard`] that takes ownership of `self`.
    ///
    /// # Safety
    ///
    /// The caller must have just acquired the lock through `self.wait`. The
    /// returned guard releases the lock when dropped.
    unsafe fn owned_guard(self: Arc<Self>) -> OwnedMutexGuard<T, L> {
        // Take the data pointer while `self` is still accessible, then move
        // the `Arc` into the releaser.
        let data = self.data.get_mut();
        let wake = WakeArcOnDrop(self);
        OwnedMutexGuard { data, _wake: wake }
    }
}
/// The owned counterpart of `WakeOnDrop`.
///
/// It holds an [`Arc`] to a [`Mutex`] and calls `wake` on the mutex's wait
/// queue when dropped, which releases the lock.
struct WakeArcOnDrop<T: ?Sized, L: ScopedRawMutex>(Arc<Mutex<T, L>>);
impl<T: ?Sized, L: ScopedRawMutex> Deref for OwnedMutexGuard<T, L> {
    type Target = T;

    /// Borrows the protected data immutably.
    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY: this guard holds the lock, so nothing else can mutate the
        // data while this shared borrow of the guard is alive.
        let shared: &T = unsafe { self.data.deref() };
        shared
    }
}
impl<T: ?Sized, L: ScopedRawMutex> DerefMut for OwnedMutexGuard<T, L> {
    /// Borrows the protected data mutably.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: this guard holds the lock, and `&mut self` ensures this is
        // the only live borrow obtained through the guard.
        let exclusive: &mut T = unsafe { self.data.deref() };
        exclusive
    }
}
impl<T, L> fmt::Debug for OwnedMutexGuard<T, L>
where
T: ?Sized + fmt::Debug,
L: ScopedRawMutex,
{
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(f)
}
}
// SAFETY: an `OwnedMutexGuard` grants exclusive (`&mut T`) access to the
// protected data, so moving it to another thread requires `T: Send`.
//
// The guard also owns an `Arc<Mutex<T, L>>`. Moving an `Arc<X>` to another
// thread requires `X: Send + Sync`, and for `Mutex<T, L>` that means
// `L: Send + Sync`:
// - `L: Sync`, because the guard's drop calls `wake` on the shared wait queue
//   from the new thread;
// - `L: Send`, because if the guard holds the last reference, dropping it on
//   the new thread also drops the raw mutex `L` (and `T`) there.
unsafe impl<T, L> Send for OwnedMutexGuard<T, L>
where
    T: ?Sized + Send,
    L: ScopedRawMutex + Send + Sync,
{
}
// SAFETY: a shared `&OwnedMutexGuard` only exposes `&T` (through `Deref` and
// `Debug`). The inner `Arc` cannot be cloned or dropped through it, so sharing
// the guard across threads requires `T: Sync`.
unsafe impl<T: ?Sized + Send + Sync, L: ScopedRawMutex + Sync> Sync for OwnedMutexGuard<T, L> {}
impl<T: ?Sized, L: ScopedRawMutex> Drop for WakeArcOnDrop<T, L> {
    /// Calls `wake` on the mutex's wait queue, releasing the lock held by
    /// the owning guard. The `Arc` itself is dropped afterwards.
    fn drop(&mut self) {
        let Self(mutex) = self;
        mutex.wait.wake();
    }
}
}