#[cfg(feature = "bilock")]
use futures_core::future::Future;
use futures_core::task::{Context, Poll, Waker};
use core::cell::UnsafeCell;
use core::fmt;
use core::ops::{Deref, DerefMut};
use core::pin::Pin;
use core::sync::atomic::AtomicUsize;
use core::sync::atomic::Ordering::SeqCst;
use alloc::boxed::Box;
use alloc::sync::Arc;
/// One half of a lock shared between exactly two owners.
///
/// Both halves are produced together by `BiLock::new` and refer to the same
/// reference-counted `Inner`; the type is not `Clone`, so no third half exists.
#[derive(Debug)]
pub struct BiLock<T> {
    /// Shared lock word and protected value, common to both halves.
    arc: Arc<Inner<T>>,
}
/// State shared by the two halves of a `BiLock`.
#[derive(Debug)]
struct Inner<T> {
    /// Lock word: `0` = unlocked, `1` = locked with no parked waker, any other
    /// value = raw pointer (from `Box::into_raw`) to the waker of a task
    /// waiting for the lock.
    state: AtomicUsize,
    /// The protected value; only `None` once `into_value` has taken it out.
    value: Option<UnsafeCell<T>>,
}
// SAFETY: `value` is only reached through a `BiLockGuard`, and a guard is only
// created after the `state` lock word has been acquired, so the two halves take
// turns owning the value. As with `std::sync::Mutex`, moving ownership between
// threads is what this relies on, hence the `T: Send` bound for both traits.
unsafe impl<T> Send for Inner<T> where T: Send {}
unsafe impl<T> Sync for Inner<T> where T: Send {}
impl<T> BiLock<T> {
/// Creates a lock around `t` and returns its two halves.
///
/// Both halves share one `Inner`, which starts unlocked (`state == 0`).
pub fn new(t: T) -> (BiLock<T>, BiLock<T>) {
let arc = Arc::new(Inner {
state: AtomicUsize::new(0),
value: Some(UnsafeCell::new(t)),
});
(BiLock { arc: arc.clone() }, BiLock { arc })
}
/// Attempts to acquire the lock without blocking.
///
/// Returns `Poll::Ready` with a guard if the lock was free. Otherwise it boxes
/// the current task's waker, stores the raw box pointer in `state`, and
/// returns `Poll::Pending`; `unlock` later takes and wakes that waker.
///
/// Encoding of `state`: `0` = unlocked, `1` = locked with no parked waker,
/// anything else = `Box<Waker>` pointer produced by `Box::into_raw`.
///
/// # Panics
///
/// Panics with `invalid state: {n}` if the compare-exchange observes a value
/// other than `0` or `1` after this call stored `1`.
pub fn poll_lock(&self, cx: &mut Context<'_>) -> Poll<BiLockGuard<'_, T>> {
// Waker allocation carried across loop iterations so it can be reused.
let mut waker = None;
loop {
// Unconditionally mark the lock as held and inspect what was there before.
match self.arc.state.swap(1, SeqCst) {
// It was unlocked: we now own the lock.
0 => return Poll::Ready(BiLockGuard { bilock: self }),
// Held by the other half with no waker parked yet.
1 => {}
// A waker was parked. Only two halves exist (see `new`), so this is
// presumably our own waker from an earlier poll: reclaim its box and
// overwrite it with the current waker (dropping the old one).
n => unsafe {
let mut prev = Box::from_raw(n as *mut Waker);
*prev = cx.waker().clone();
waker = Some(prev);
}
}
// Box our waker (reusing a reclaimed allocation if we have one) and
// convert it to an integer so it can live in the atomic word.
let me: Box<Waker> = waker.take().unwrap_or_else(||Box::new(cx.waker().clone()));
let me = Box::into_raw(me) as usize;
// Publish the waker only if the state is still the plain "locked" value.
match self.arc.state.compare_exchange(1, me, SeqCst, SeqCst) {
// Parked successfully; `unlock` will wake us.
Ok(_) => return Poll::Pending,
// The other half unlocked between our swap and this CAS, so the
// pointer was never published: take ownership of the box back and retry.
Err(0) => unsafe {
waker = Some(Box::from_raw(me as *mut Waker));
},
// Someone else parked a waker after our swap; the protocol treats this
// as impossible.
Err(n) => panic!("invalid state: {}", n),
}
}
}
/// Returns a future that resolves to a `BiLockGuard` once the lock is acquired.
#[cfg(feature = "bilock")]
pub fn lock(&self) -> BiLockAcquire<'_, T> {
BiLockAcquire {
bilock: self,
}
}
/// Recombines two halves and returns the protected value.
///
/// # Errors
///
/// Returns `ReuniteError(self, other)` (handing both halves back) if the two
/// halves do not share the same `Inner`.
///
/// # Panics
///
/// Panics if `Arc::try_unwrap` fails after `other` is dropped, i.e. if another
/// reference to the shared state still exists.
pub fn reunite(self, other: Self) -> Result<T, ReuniteError<T>>
where
T: Unpin,
{
if Arc::ptr_eq(&self.arc, &other.arc) {
// Release the second reference so `self.arc` becomes the sole owner.
drop(other);
let inner = Arc::try_unwrap(self.arc)
.ok()
.expect("futures: try_unwrap failed in BiLock<T>::reunite");
Ok(unsafe { inner.into_value() })
} else {
Err(ReuniteError(self, other))
}
}
/// Releases the lock. If a waker was parked, frees its box and wakes it.
///
/// # Panics
///
/// Panics with `invalid unlocked state` if the lock was not held.
fn unlock(&self) {
match self.arc.state.swap(0, SeqCst) {
0 => panic!("invalid unlocked state"),
// Nobody was waiting.
1 => {}
// Take back ownership of the parked waker's box and consume it by waking.
n => unsafe {
Box::from_raw(n as *mut Waker).wake();
}
}
}
}
impl<T: Unpin> Inner<T> {
    /// Consumes the shared state and returns the protected value.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body performs no unsafe operation itself; its only
    /// visible caller (`BiLock::reunite`) uses it after obtaining sole
    /// ownership via `Arc::try_unwrap`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is `None`.
    unsafe fn into_value(mut self) -> T {
        // `Inner` implements `Drop`, so the field cannot be moved out directly;
        // swap in `None` and unwrap what was there.
        let cell = self.value.take().unwrap();
        cell.into_inner()
    }
}
impl<T> Drop for Inner<T> {
    /// Checks that the lock word is back to "unlocked" when the shared state
    /// is destroyed.
    ///
    /// # Panics
    ///
    /// Panics if `state` is not `0`, i.e. the lock is still held or a waker
    /// pointer is still stored in it.
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so reading the atomic through
        // `get_mut` observes the same value an atomic load would.
        let current = *self.state.get_mut();
        assert_eq!(current, 0);
    }
}
/// Error returned by `BiLock::reunite` when the two halves do not belong to
/// the same lock; both halves are handed back unchanged.
pub struct ReuniteError<T>(
    pub BiLock<T>,
    pub BiLock<T>,
);
impl<T> fmt::Debug for ReuniteError<T> {
    /// Formats as `ReuniteError("...")` so that `T: Debug` is not required.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("ReuniteError");
        tuple.field(&"...");
        tuple.finish()
    }
}
impl<T> fmt::Display for ReuniteError<T> {
    /// Writes a fixed, human-readable description of the failed reunite.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("tried to reunite two BiLocks that don't form a pair")
    }
}
/// Makes `ReuniteError` usable as a standard error when the `std` feature is
/// enabled. The `Any` bound restricts this to `'static` payload types.
#[cfg(feature = "std")]
impl<T> std::error::Error for ReuniteError<T> where T: core::any::Any {}
/// RAII guard proving that one half of a `BiLock` holds the lock.
///
/// It dereferences to the protected value. Dropping it releases the lock and
/// wakes the other half if that half parked a waker.
#[derive(Debug)]
pub struct BiLockGuard<'a, T> {
    /// The half that acquired the lock; used to reach the value and to unlock.
    bilock: &'a BiLock<T>,
}

// SAFETY: `&BiLockGuard` exposes `&T` through `Deref`, so sharing a guard
// across threads shares `&T`, which requires `T: Sync`. `T: Send` keeps the
// requirement the field `&BiLock<T>` already imposed.
//
// Writing this impl explicitly replaces the auto-derived one. The auto impl
// would make the guard `Sync` for every `T: Send`, because `Inner<T>` is `Sync`
// for `T: Send`. That includes non-`Sync` types such as `Cell<_>`, for which
// concurrent `Deref` is a data race. `Send` for the guard is left auto-derived
// and unchanged.
unsafe impl<T: Send + Sync> Sync for BiLockGuard<'_, T> {}
impl<T> Deref for BiLockGuard<'_, T> {
    type Target = T;

    /// Shared access to the locked value.
    ///
    /// # Panics
    ///
    /// Panics if the value has already been taken out of the shared state.
    fn deref(&self) -> &T {
        let cell = self.bilock.arc.value.as_ref().unwrap();
        // SAFETY: a guard is only created after acquiring the lock and it
        // unlocks on drop, so the other half is not accessing the value now.
        unsafe { &*cell.get() }
    }
}
impl<T: Unpin> DerefMut for BiLockGuard<'_, T> {
    /// Exclusive access to the locked value. Offered only for `Unpin` values;
    /// pinned access for other types goes through `as_pin_mut`.
    fn deref_mut(&mut self) -> &mut T {
        let raw: *mut T = self.bilock.arc.value.as_ref().unwrap().get();
        // SAFETY: holding the guard means this half holds the lock, and
        // `&mut self` rules out any other borrow obtained through this guard.
        unsafe { &mut *raw }
    }
}
impl<T> BiLockGuard<'_, T> {
    /// Returns a pinned mutable reference to the locked value.
    ///
    /// This works for `!Unpin` values, which cannot use `DerefMut`.
    pub fn as_pin_mut(&mut self) -> Pin<&mut T> {
        let raw = self.bilock.arc.value.as_ref().unwrap().get();
        // SAFETY: the guard proves the lock is held, so access is exclusive.
        // The value is never moved out unpinned in this file: `DerefMut` and
        // `reunite` both require `T: Unpin`.
        unsafe { Pin::new_unchecked(&mut *raw) }
    }
}
impl<T> Drop for BiLockGuard<'_, T> {
fn drop(&mut self) {
self.bilock.unlock();
}
}
/// Future returned by `BiLock::lock`; resolves to a `BiLockGuard` once the
/// lock has been acquired.
#[cfg(feature = "bilock")]
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct BiLockAcquire<'a, T> {
    /// The half whose lock this future acquires.
    bilock: &'a BiLock<T>,
}
// The future stores only a shared reference to its half, so moving it is
// always fine regardless of `T`.
#[cfg(feature = "bilock")]
impl<'a, T> Unpin for BiLockAcquire<'a, T> {}
#[cfg(feature = "bilock")]
impl<'a, T> Future for BiLockAcquire<'a, T> {
    type Output = BiLockGuard<'a, T>;

    /// Delegates to `BiLock::poll_lock` on the borrowed half.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Copy the `&'a` reference out so the guard borrows for `'a`, not
        // merely for the duration of this `poll` call.
        let bilock: &'a BiLock<T> = self.bilock;
        bilock.poll_lock(cx)
    }
}