use futures_channel::oneshot;
use futures_task::{Context, Poll};
use std::{
cell::UnsafeCell,
clone::Clone,
collections::VecDeque,
future::Future,
ops::{Deref, DerefMut},
pin::Pin,
sync
};
use super::{FutState, TryLockError};
#[cfg(feature = "tokio")] use tokio::task;
/// Guard returned by `Mutex::try_lock` and by awaiting `Mutex::lock`.
///
/// While the guard is alive it has exclusive access to the protected data via
/// `Deref`/`DerefMut`. Dropping it releases the lock. If a waiter is still
/// queued, ownership is handed directly to the first one.
#[must_use = "if unused the Mutex will immediately unlock"]
#[derive(Debug)]
pub struct MutexGuard<T: ?Sized> {
    // Keeps the shared state alive for as long as the guard exists.
    mutex: Mutex<T>
}

// SAFETY: without this explicit impl, `MutexGuard<T>` would be auto-`Sync`
// whenever `Mutex<T>` is, i.e. for any `T: Send`. But `&MutexGuard<T>` derefs
// to `&T`, so sharing the guard across threads shares `&T` across threads.
// That is only sound when `T: Sync`. Writing the impl by hand suppresses the
// compiler's automatic (too permissive) one. `Send` stays automatic: moving
// the guard moves exclusive access to `T`, which needs only `T: Send`.
unsafe impl<T: ?Sized + Send + Sync> Sync for MutexGuard<T> {}
impl<T: ?Sized> Drop for MutexGuard<T> {
fn drop(&mut self) {
self.mutex.unlock();
}
}
impl<T: ?Sized> Deref for MutexGuard<T> {
    type Target = T;

    /// Shared access to the protected value.
    fn deref(&self) -> &T {
        let cell = &self.mutex.inner.data;
        // SAFETY: a guard is only constructed after its creator set `owned`
        // (or had ownership handed to it by `unlock`). So while this guard
        // lives, no other guard can be producing a `&mut T`.
        unsafe { &*cell.get() }
    }
}
impl<T: ?Sized> DerefMut for MutexGuard<T> {
    /// Exclusive access to the protected value.
    fn deref_mut(&mut self) -> &mut T {
        let cell = &self.mutex.inner.data;
        // SAFETY: this guard owns the lock, and `&mut self` rules out other
        // borrows obtained through this same guard.
        unsafe { &mut *cell.get() }
    }
}
/// Future returned by `Mutex::lock`; it resolves to a `MutexGuard`.
///
/// Nothing happens until the future is first polled. Only then does it try
/// to take the lock or enqueue itself as a waiter.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct MutexFut<T: ?Sized> {
    // Progress of this acquisition attempt.
    state: FutState,
    // Handle to the mutex being acquired.
    mutex: Mutex<T>,
}
impl<T: ?Sized> MutexFut<T> {
fn new(state: FutState, mutex: Mutex<T>) -> Self {
MutexFut{state, mutex}
}
}
impl<T: ?Sized> Drop for MutexFut<T> {
    /// Withdraws a still-pending lock request.
    ///
    /// If `unlock` already transferred ownership to this future before it was
    /// dropped, the lock is released again so it is not leaked.
    fn drop(&mut self) {
        let rx = match self.state {
            FutState::Pending(ref mut rx) => rx,
            // `New`: the lock was never touched. `Acquired`: the guard that
            // was handed out is responsible for unlocking.
            FutState::New | FutState::Acquired => return,
        };
        // Close first so that a later `unlock` fails to send and skips us.
        // After that, the only way we can own the lock is a message that was
        // already delivered.
        rx.close();
        if let Ok(Some(())) = rx.try_recv() {
            self.mutex.unlock()
        }
    }
}
impl<T: ?Sized> Future for MutexFut<T> {
    type Output = MutexGuard<T>;

    /// Drives the acquisition state machine:
    ///
    /// * `New`: take the lock immediately if it is free. Otherwise enqueue a
    ///   oneshot sender in the wait queue and move to `Pending`.
    /// * `Pending`: wait for `unlock` to signal that ownership was handed over.
    /// * `Acquired`: the guard was already returned, so polling again panics.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // `MutexFut` is `Unpin`: assigning through the `Pin` already relied on
        // that. A plain `&mut` lets us borrow `state` and `mutex` independently.
        let this = &mut *self;
        match this.state {
            FutState::New => {
                let mut mtx_data = this.mutex.inner.mutex.lock()
                    .expect("sync::Mutex::lock");
                if !mtx_data.owned {
                    // Uncontended: claim the lock right away.
                    mtx_data.owned = true;
                    this.state = FutState::Acquired;
                    return Poll::Ready(MutexGuard { mutex: this.mutex.clone() });
                }
                // Contended: queue up. `unlock` must take `mtx_data`'s lock
                // before it can send, and we still hold that lock, so this first
                // poll is necessarily pending. It also registers our waker.
                let (tx, mut rx) = oneshot::channel::<()>();
                mtx_data.waiters.push_back(tx);
                assert!(Pin::new(&mut rx).poll(cx).is_pending());
                this.state = FutState::Pending(rx);
                Poll::Pending
            }
            FutState::Pending(ref mut rx) => {
                if Pin::new(rx).poll(cx).is_pending() {
                    return Poll::Pending;
                }
                // `unlock` sent on our channel. Ownership was transferred to us
                // without `owned` ever being cleared.
                this.state = FutState::Acquired;
                Poll::Ready(MutexGuard { mutex: this.mutex.clone() })
            }
            FutState::Acquired => panic!("Double-poll of ready Future"),
        }
    }
}
/// Lock bookkeeping, protected by the inner `sync::Mutex`.
#[derive(Default, Debug)]
struct MutexData {
    /// `true` while a guard, or a waiter that was handed the lock, owns it.
    owned: bool,
    /// Pending lock requests. `unlock` pops them front-first, so the queue is FIFO.
    waiters: VecDeque<oneshot::Sender<()>>,
}
/// Shared state behind every `Mutex` and `MutexWeak` handle.
#[derive(Default, Debug)]
struct Inner<T: ?Sized> {
    /// Ownership flag and wait queue.
    mutex: sync::Mutex<MutexData>,
    /// The protected value. It is reached through a `MutexGuard` while the
    /// lock is owned, or through exclusive/by-value access to the whole `Inner`.
    data: UnsafeCell<T>,
}
/// A weak handle to a `Mutex` that does not keep the protected data alive.
///
/// It is created by `Mutex::downgrade`. Call `upgrade` to get a usable
/// `Mutex` back.
#[derive(Debug)]
pub struct MutexWeak<T: ?Sized> {
    inner: sync::Weak<Inner<T>>,
}
impl<T: ?Sized> MutexWeak<T> {
    /// Attempts to recover a strong `Mutex` handle.
    ///
    /// Returns `None` once every strong handle is gone. That includes the
    /// clones held by guards and lock futures.
    pub fn upgrade(&self) -> Option<Mutex<T>> {
        self.inner.upgrade().map(|inner| Mutex { inner })
    }
}
impl<T: ?Sized> Clone for MutexWeak<T> {
fn clone(&self) -> MutexWeak<T> {
MutexWeak {inner: self.inner.clone()}
}
}
// SAFETY: a `MutexWeak` can reach `T` only by being upgraded into a
// `Mutex<T>`. It may therefore be sent and shared under the same bound that
// the `Mutex<T>` impls in this file use (`T: Send`).
// The clippy lint is silenced on purpose. The `sync::Weak<Inner<T>>` field is
// not `Send` by itself, because `Inner` holds an `UnsafeCell` and so is not
// `Sync`. Access to that cell is mediated by the lock protocol instead.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<T: ?Sized + Send> Send for MutexWeak<T> {}
unsafe impl<T: ?Sized + Send> Sync for MutexWeak<T> {}
/// An asynchronous mutex whose `lock` returns a future rather than blocking.
///
/// `Mutex` is a cheaply clonable handle (an `Arc` internally). Every clone
/// refers to the same lock and protected value. Awaiting `lock` yields a
/// `MutexGuard`. Contended callers are queued and served in FIFO order.
#[derive(Debug, Default)]
pub struct Mutex<T: ?Sized> {
    inner: sync::Arc<Inner<T>>,
}
impl<T: ?Sized> Clone for Mutex<T> {
fn clone(&self) -> Mutex<T> {
Mutex { inner: self.inner.clone()}
}
}
impl<T> Mutex<T> {
    /// Creates a new, unlocked `Mutex` protecting `t`.
    pub fn new(t: T) -> Mutex<T> {
        let inner = Inner {
            mutex: sync::Mutex::new(MutexData {
                owned: false,
                waiters: VecDeque::new(),
            }),
            data: UnsafeCell::new(t),
        };
        Mutex { inner: sync::Arc::new(inner) }
    }

    /// Consumes this handle and returns the protected value, provided it is
    /// the only strong handle.
    ///
    /// # Errors
    ///
    /// Returns the `Mutex` unchanged if other strong handles still exist.
    /// Those include the clones held by a live `MutexGuard` or `MutexFut`.
    pub fn try_unwrap(self) -> Result<T, Mutex<T>> {
        sync::Arc::try_unwrap(self.inner)
            // `UnsafeCell::into_inner` is a safe fn. Owning the `Inner` by value
            // already proves that nobody else can reach the cell.
            .map(|inner| inner.data.into_inner())
            .map_err(|arc| Mutex { inner: arc })
    }
}
impl<T: ?Sized> Mutex<T> {
    /// Creates a `MutexWeak` handle that does not keep the data alive.
    pub fn downgrade(this: &Mutex<T>) -> MutexWeak<T> {
        MutexWeak {inner: sync::Arc::<Inner<T>>::downgrade(&this.inner)}
    }

    /// Returns a mutable reference to the protected data without locking.
    ///
    /// This succeeds only when `self` is the sole handle: there are no clones
    /// and no `MutexWeak`s. That in turn means no live guards or lock futures,
    /// because each of those holds its own clone. Otherwise it returns `None`.
    ///
    /// # Panics
    ///
    /// Panics if the internal `sync::Mutex` is poisoned.
    pub fn get_mut(&mut self) -> Option<&mut T> {
        let inner = sync::Arc::get_mut(&mut self.inner)?;
        let lock_data = inner.mutex.get_mut().expect("sync::Mutex::get_mut");
        // Every guard holds a `Mutex` clone, so exclusive access to the `Arc`
        // implies that the lock is free.
        debug_assert!(!lock_data.owned);
        // Safe accessor: `&mut Inner` already proves exclusive access to the
        // cell, so no raw-pointer dereference is needed.
        Some(inner.data.get_mut())
    }

    /// Returns a future that resolves to a `MutexGuard` once the lock is
    /// acquired. Nothing is attempted until the future is polled.
    pub fn lock(&self) -> MutexFut<T> {
        MutexFut::new(FutState::New, self.clone())
    }

    /// Acquires the lock if it is currently free, without waiting.
    ///
    /// # Errors
    ///
    /// Returns `TryLockError` if the lock is already owned.
    pub fn try_lock(&self) -> Result<MutexGuard<T>, TryLockError> {
        let mut mtx_data = self.inner.mutex.lock().expect("sync::Mutex::lock");
        if mtx_data.owned {
            Err(TryLockError)
        } else {
            mtx_data.owned = true;
            Ok(MutexGuard{mutex: self.clone()})
        }
    }

    /// Releases the lock. Called when a guard is dropped, and when a dropped
    /// lock future turns out to have already been handed ownership.
    fn unlock(&self) {
        let mut mtx_data = self.inner.mutex.lock().expect("sync::Mutex::lock");
        assert!(mtx_data.owned);
        // Hand ownership straight to the first waiter that can still receive.
        // `owned` stays true across the transfer. Senders whose receivers were
        // closed (see `MutexFut::drop`) fail to send and are discarded.
        while let Some(tx) = mtx_data.waiters.pop_front() {
            if tx.send(()).is_ok() {
                return;
            }
        }
        mtx_data.owned = false;
    }

    /// Returns `true` if both handles refer to the same underlying mutex.
    pub fn ptr_eq(this: &Mutex<T>, other: &Mutex<T>) -> bool {
        sync::Arc::ptr_eq(&this.inner, &other.inner)
    }
}
impl<T: 'static + ?Sized> Mutex<T> {
    /// Acquires the lock and runs `f` on the guard inside a newly spawned
    /// Tokio task.
    ///
    /// The task is spawned as soon as `with` is called, not when the returned
    /// future is first polled. The returned future yields the output of the
    /// future that `f` produced.
    ///
    /// # Panics
    ///
    /// The returned future panics if joining the spawned task fails, because
    /// the join result is unwrapped.
    #[cfg(any(feature = "tokio", all(docsrs, rustdoc)))]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    pub fn with<B, F, R>(&self, f: F)
        -> impl Future<Output = R>
        where F: FnOnce(MutexGuard<T>) -> B + Send + 'static,
              B: Future<Output = R> + Send + 'static,
              R: Send + 'static,
              T: Send
    {
        let acquire = self.lock();
        let handle = tokio::spawn(async move {
            let guard = acquire.await;
            f(guard).await
        });
        async move { handle.await.unwrap() }
    }

    /// Like `with`, but runs `f` on a `tokio::task::LocalSet`, so neither `f`
    /// nor its future needs to be `Send`.
    ///
    /// The returned future first drives the whole `LocalSet` to completion,
    /// then yields the task's result.
    ///
    /// # Panics
    ///
    /// The returned future panics if joining the local task fails, because the
    /// join result is unwrapped.
    #[cfg(any(feature = "tokio", all(docsrs, rustdoc)))]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    pub fn with_local<B, F, R>(&self, f: F)
        -> impl Future<Output = R>
        where F: FnOnce(MutexGuard<T>) -> B + 'static,
              B: Future<Output = R> + 'static + Unpin,
              R: 'static
    {
        let local_set = task::LocalSet::new();
        let acquire = self.lock();
        let handle = local_set.spawn_local(async move {
            let guard = acquire.await;
            f(guard).await
        });
        async move {
            local_set.await;
            handle.await.unwrap()
        }
    }
}
// SAFETY: `Mutex<T>` exposes `T` only through a `MutexGuard`. The `owned` flag
// and the hand-off in `unlock` keep at most one guard alive at a time. Moving
// or sharing the handle across threads therefore at most lets another thread
// gain exclusive access to `T`, which requires `T: Send`.
// The clippy lint is silenced on purpose. The `sync::Arc<Inner<T>>` field is
// not `Send` by itself, because `Inner` holds an `UnsafeCell` and so is not
// `Sync`.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<T: ?Sized + Send> Send for Mutex<T> {}
unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}
#[cfg(test)]
mod t {
    use super::*;

    /// `Debug` must be implemented and must actually describe the mutex.
    #[test]
    fn debug() {
        let m = Mutex::<u32>::new(0);
        let s = format!("{:?}", &m);
        assert!(s.starts_with("Mutex"));
    }

    #[test]
    fn test_default() {
        let m = Mutex::default();
        let value: u32 = m.try_unwrap().unwrap();
        let expected = u32::default();
        assert_eq!(expected, value);
    }

    /// Exercises the synchronous API: `get_mut`, `try_lock`, guard release
    /// on drop, and `try_unwrap`.
    #[test]
    fn sync_api() {
        let mut m = Mutex::new(1u32);
        *m.get_mut().expect("sole handle") += 1;
        {
            let g = m.try_lock().ok().expect("lock is free");
            assert_eq!(*g, 2);
            assert!(m.try_lock().is_err());
        }
        assert!(m.try_lock().is_ok());
        let other = m.clone();
        assert!(m.get_mut().is_none());
        drop(other);
        assert_eq!(m.try_unwrap().unwrap(), 2);
    }
}