use core::{
cell::Cell,
future::Future,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll, Waker},
};
use std::collections::VecDeque;
use std::fmt;
use fairly_unsafe_cell::*;
use crate::{extend_lifetime, extend_lifetime_mut};
/// An async mutual-exclusion lock intended for single-task-executor use.
///
/// `read` and `write` acquire the same exclusive lock; the only difference is
/// whether the returned guard hands out `&T` or `&mut T`. Because the type
/// contains a `Cell`, it is `!Sync` and cannot be shared across threads.
pub struct Mutex<T> {
    /// The protected value.
    value: FairlyUnsafeCell<T>,
    /// `true` while a `ReadGuard` or `WriteGuard` for this mutex is alive.
    currently_used: Cell<bool>,
    /// Wakers of tasks whose `read`/`write` future found the lock taken.
    parked: FairlyUnsafeCell<VecDeque<Waker>>,
}
impl<T> Mutex<T> {
    /// Creates an unlocked mutex that owns `value`.
    pub fn new(value: T) -> Self {
        Mutex {
            value: FairlyUnsafeCell::new(value),
            currently_used: Cell::new(false),
            parked: FairlyUnsafeCell::new(VecDeque::new()),
        }
    }

    /// Consumes the mutex and returns the protected value.
    pub fn into_inner(self) -> T {
        self.value.into_inner()
    }

    /// Waits for the lock and returns a guard granting shared (`&T`) access.
    ///
    /// Reads are *not* concurrent: `read` and `write` take the same exclusive
    /// lock, so a live `ReadGuard` blocks every other reader and writer.
    pub async fn read(&self) -> ReadGuard<T> {
        ReadFuture(self).await
    }

    /// Takes the lock for shared access without waiting.
    ///
    /// Returns `None` if any guard (read or write) is currently alive.
    pub fn try_read(&self) -> Option<ReadGuard<T>> {
        self.try_acquire().then(|| ReadGuard { mutex: self })
    }

    /// Waits for the lock and returns a guard granting mutable (`&mut T`) access.
    pub async fn write(&self) -> WriteGuard<T> {
        WriteFuture(self).await
    }

    /// Takes the lock for mutable access without waiting.
    ///
    /// Returns `None` if any guard (read or write) is currently alive.
    pub fn try_write(&self) -> Option<WriteGuard<T>> {
        self.try_acquire().then(|| WriteGuard { mutex: self })
    }

    /// Waits for the lock and overwrites the value with `to`.
    pub async fn set(&self, to: T) {
        let mut guard = self.write().await;
        *guard = to;
    }

    /// Waits for the lock, stores `to`, and returns the previous value.
    pub async fn replace(&self, to: T) -> T {
        let mut guard = self.write().await;
        core::mem::replace(&mut *guard, to)
    }

    /// Waits for the lock and replaces the value with `with(&current)`.
    pub async fn update(&self, with: impl FnOnce(&T) -> T) {
        let mut guard = self.write().await;
        *guard = with(&guard);
    }

    /// Like [`Mutex::update`], but if `with` fails the value is left untouched.
    ///
    /// # Errors
    /// Propagates the `Err` returned by `with`.
    pub async fn fallible_update<E>(
        &self,
        with: impl FnOnce(&T) -> Result<T, E>,
    ) -> Result<(), E> {
        let mut guard = self.write().await;
        // `?` returns before the assignment, so an error keeps the old value.
        *guard = with(&guard)?;
        Ok(())
    }

    /// Waits for the lock and lets `with` modify the value in place.
    pub async fn mutate(&self, with: impl FnOnce(&mut T)) {
        let mut guard = self.write().await;
        with(&mut guard)
    }

    /// Waits for the lock and lets `with` modify the value in place.
    ///
    /// Any changes `with` made before returning `Err` are kept.
    ///
    /// # Errors
    /// Propagates the `Err` returned by `with`.
    pub async fn fallible_mutate<E>(
        &self,
        with: impl FnOnce(&mut T) -> Result<(), E>,
    ) -> Result<(), E> {
        let mut guard = self.write().await;
        with(&mut guard)
    }

    /// Marks the lock as taken; returns `true` if it was previously free.
    fn try_acquire(&self) -> bool {
        !self.currently_used.replace(true)
    }

    /// Wakes every parked task so each one can retry acquiring the lock.
    ///
    /// Waking only the first queued waker is not enough: the queue can hold
    /// stale wakers left behind by `ReadFuture`/`WriteFuture`s that parked and
    /// were then dropped (e.g. cancelled by a timeout). Spending the single
    /// wakeup on such a waker would leave the remaining waiters asleep while
    /// the lock is free. Tasks that lose the race simply park again.
    fn wake_next(&self) {
        // Move the wakers out first so the `parked` borrow has ended before any
        // `wake` runs; a waker that polls synchronously could re-enter `park`.
        let wakers = {
            // SAFETY: `parked` is only borrowed inside `park` and `wake_next`,
            // neither of which wakes a task while the borrow is live, and
            // `Mutex` is `!Sync` (it holds a `Cell`), so no other thread can
            // borrow it concurrently.
            let mut parked = unsafe { self.parked.borrow_mut() };
            core::mem::take(parked.deref_mut())
        };
        for waker in wakers {
            waker.wake();
        }
    }

    /// Registers the task polling `cx` to be woken when a guard is released.
    ///
    /// A future may be polled many times while the lock stays taken (e.g. by
    /// a `join` combinator), so a waker that would wake the same task as an
    /// already-queued one is not pushed again. This keeps the queue roughly
    /// bounded by the number of distinct waiting tasks.
    fn park(&self, cx: &mut Context<'_>) {
        let waker = cx.waker();
        // SAFETY: `parked` is only borrowed inside `park` and `wake_next`,
        // neither of which wakes a task while the borrow is live, and `Mutex`
        // is `!Sync` (it holds a `Cell`).
        let mut parked = unsafe { self.parked.borrow_mut() };
        let queue = parked.deref_mut();
        if !queue.iter().any(|queued| queued.will_wake(waker)) {
            queue.push_back(waker.clone());
        }
    }
}
impl<T> AsMut<T> for Mutex<T> {
    /// Borrows the protected value mutably without taking the lock.
    ///
    /// No locking is needed: guards borrow the mutex, so `&mut self` proves
    /// that no guard is alive.
    fn as_mut(&mut self) -> &mut T {
        let cell = &mut self.value;
        cell.get_mut()
    }
}
impl<T: fmt::Debug> fmt::Debug for Mutex<T> {
    /// Formats as `Mutex(<value>)`, or `Mutex(<locked>)` while a guard is alive.
    ///
    /// The value is read through a temporary `try_read` guard, which is
    /// dropped (and so runs the guard's release logic) before `finish`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("Mutex");
        if let Some(guard) = self.try_read() {
            tuple.field(&*guard);
        } else {
            tuple.field(&format_args!("<locked>"));
        }
        tuple.finish()
    }
}
#[derive(Debug)]
pub struct ReadGuard<'mutex, T> {
mutex: &'mutex Mutex<T>,
}
impl<T> Drop for ReadGuard<'_, T> {
    /// Releases the lock, then calls `wake_next` so a parked task can retry.
    fn drop(&mut self) {
        let owner = self.mutex;
        owner.currently_used.set(false);
        owner.wake_next();
    }
}
impl<T> Deref for ReadGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected value.
    fn deref(&self) -> &T {
        // SAFETY: a live guard means `currently_used` is set, so no
        // `WriteGuard` can hand out `&mut T` at the same time; the extended
        // reference is re-bounded by `&self` through the return type.
        // NOTE(review): relies on `FairlyUnsafeCell::borrow`'s contract
        // tolerating the borrow handle being dropped first — confirm.
        let cell_ref = unsafe { self.mutex.value.borrow() };
        unsafe { extend_lifetime(cell_ref.deref()) }
    }
}
#[derive(Debug)]
pub struct WriteGuard<'mutex, T> {
mutex: &'mutex Mutex<T>,
}
impl<T> Drop for WriteGuard<'_, T> {
    /// Releases the lock, then calls `wake_next` so a parked task can retry.
    fn drop(&mut self) {
        let owner = self.mutex;
        owner.currently_used.set(false);
        owner.wake_next();
    }
}
impl<T> Deref for WriteGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected value.
    fn deref(&self) -> &T {
        // SAFETY: a live guard means `currently_used` is set, so no other
        // guard exists; `&self` prevents this guard's own `deref_mut` from
        // running while the returned reference is alive.
        // NOTE(review): relies on `FairlyUnsafeCell::borrow`'s contract
        // tolerating the borrow handle being dropped first — confirm.
        let cell_ref = unsafe { self.mutex.value.borrow() };
        unsafe { extend_lifetime(cell_ref.deref()) }
    }
}
impl<T> DerefMut for WriteGuard<'_, T> {
    /// Exclusive access to the protected value.
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: a live guard means `currently_used` is set, so no other
        // guard exists; `&mut self` makes this the only outstanding borrow
        // obtained through this guard.
        // NOTE(review): relies on `FairlyUnsafeCell::borrow_mut`'s contract
        // tolerating the borrow handle being dropped first — confirm.
        let mut cell_mut = unsafe { self.mutex.value.borrow_mut() };
        unsafe { extend_lifetime_mut(cell_mut.deref_mut()) }
    }
}
/// Future behind [`Mutex::read`]: resolves once `try_read` succeeds.
struct ReadFuture<'mutex, T>(&'mutex Mutex<T>);

impl<'mutex, T> Future for ReadFuture<'mutex, T> {
    type Output = ReadGuard<'mutex, T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mutex = self.0;
        if let Some(guard) = mutex.try_read() {
            return Poll::Ready(guard);
        }
        // Lock is taken: queue this task's waker so `wake_next` can wake it.
        mutex.park(cx);
        Poll::Pending
    }
}
/// Future behind [`Mutex::write`]: resolves once `try_write` succeeds.
struct WriteFuture<'mutex, T>(&'mutex Mutex<T>);

impl<'mutex, T> Future for WriteFuture<'mutex, T> {
    type Output = WriteGuard<'mutex, T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mutex = self.0;
        if let Some(guard) = mutex.try_write() {
            return Poll::Ready(guard);
        }
        // Lock is taken: queue this task's waker so `wake_next` can wake it.
        mutex.park(cx);
        Poll::Pending
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::time::Duration;
    use smol::{block_on, Timer};

    /// Two writers: the second must wait for the first and observe its write.
    #[test]
    fn test_mutex_basic() {
        let m = Mutex::new(0);
        let set1 = async {
            let mut handle = m.write().await;
            // Hold the lock across an await so `set2` has to park.
            Timer::after(Duration::from_millis(50)).await;
            assert_eq!(0, *handle.deref());
            *handle.deref_mut() = 1;
        };
        let set2 = async {
            Timer::after(Duration::from_millis(10)).await;
            let mut handle = m.write().await;
            assert_eq!(1, *handle.deref());
            *handle.deref_mut() = 2;
        };
        block_on(futures::future::join(set1, set2));
        // The second writer's update must have landed.
        assert_eq!(2, m.into_inner());
    }

    /// Read and write share one exclusive lock, released when the guard drops.
    #[test]
    fn test_try_lock_is_exclusive() {
        let m = Mutex::new(5);
        let guard = m.try_write().expect("lock starts free");
        assert!(m.try_read().is_none());
        assert!(m.try_write().is_none());
        drop(guard);
        assert_eq!(5, *m.try_read().expect("lock released on drop"));
    }

    /// `Debug` shows the value when free and `<locked>` while a guard is alive.
    #[test]
    fn test_debug_reports_lock_state() {
        let m = Mutex::new(7);
        assert_eq!("Mutex(7)", format!("{:?}", m));
        let guard = m.try_write().expect("lock starts free");
        assert_eq!("Mutex(<locked>)", format!("{:?}", m));
        drop(guard);
        assert_eq!("Mutex(7)", format!("{:?}", m));
    }
}