use std::cell::Cell;
use std::cell::UnsafeCell;
use std::error::Error;
use std::fmt;
use std::fmt::Debug;
use std::fmt::Display;
use std::future::Future;
use std::mem::ManuallyDrop;
use std::pin::Pin;
use std::process;
use std::ptr::NonNull;
use std::task;
use std::task::Poll;
/// An intrusive FIFO queue of waiting tasks.
///
/// Each waiting task supplies an input of type `I` when it starts waiting (see
/// [`WaitList::wait`]) and receives an output of type `O` when it is woken. All access to the
/// queue goes through a [`Borrowed`] guard obtained from [`WaitList::borrow`] or
/// [`WaitList::try_borrow`]. The borrow flag is a plain `Cell`, so the list is not `Sync`.
pub struct WaitList<I, O> {
// Borrow flag: `true` while a `Borrowed` guard for this list exists.
borrowed: Cell<bool>,
// The queue's links; only accessed through a `Borrowed` guard.
inner: UnsafeCell<Inner<I, O>>,
}
/// The doubly-linked list of waiters. The nodes themselves live inside the futures returned by
/// `WaitList::wait`; the list only stores pointers to them.
struct Inner<I, O> {
// First (oldest) waiter, i.e. the next one `wake_one` wakes; `None` when the list is empty.
head: Option<NonNull<UnsafeCell<Waiter<I, O>>>>,
// Last (newest) waiter, where `enqueue` appends; `None` when the list is empty.
tail: Option<NonNull<UnsafeCell<Waiter<I, O>>>>,
}
impl<I, O> Inner<I, O> {
/// Appends `waiter` to the back (tail) of the queue.
///
/// # Safety
///
/// - `waiter` must not already be linked into any list. Its `prev` is overwritten here, but
///   its `next` is not reset, so it must already be `None`.
/// - `waiter` must stay at the same address and remain valid until it is unlinked with
///   [`Inner::dequeue`], because the list keeps raw pointers to it.
/// - Every waiter currently linked into the list must be valid, and no other reference to the
///   current tail waiter may be live.
unsafe fn enqueue(&mut self, waiter: &UnsafeCell<Waiter<I, O>>) {
// SAFETY: the caller guarantees exclusive access to the new, unlinked `waiter`.
unsafe {
(*waiter.get()).prev = self.tail;
}
// Link the old tail (if any) forward to the new node.
if let Some(prev) = self.tail {
// SAFETY: the tail is a linked, hence valid, waiter with no other live references.
let prev = unsafe { &mut *prev.as_ref().get() };
debug_assert_eq!(prev.next, None);
prev.next = Some(NonNull::from(waiter));
}
self.tail = Some(NonNull::from(waiter));
// The new node only becomes the head if the list was empty.
self.head.get_or_insert(NonNull::from(waiter));
}
/// Unlinks `waiter` from the queue. Its predecessor (or `head`) is pointed at its successor,
/// and its successor (or `tail`) is pointed at its predecessor.
///
/// `waiter`'s own `next`/`prev` fields are left unchanged.
///
/// # Safety
///
/// `waiter` must currently be linked into this list. Every linked waiter must be valid, with
/// no other live references to its `next`/`prev` fields.
unsafe fn dequeue(&mut self, waiter: &UnsafeCell<Waiter<I, O>>) {
// SAFETY: `waiter` is linked and valid per the caller's contract.
let next = unsafe { (*waiter.get()).next };
let prev = unsafe { (*waiter.get()).prev };
// The pointer that currently points forward at `waiter`: the predecessor's `next`, or
// `head` if `waiter` is first.
let prev_next_pointer = match prev {
Some(prev) => unsafe { &mut (*prev.as_ref().get()).next },
None => &mut self.head,
};
debug_assert_eq!(*prev_next_pointer, Some(NonNull::from(waiter)));
*prev_next_pointer = next;
// The pointer that currently points backward at `waiter`: the successor's `prev`, or
// `tail` if `waiter` is last.
let next_prev_pointer = match next {
Some(next) => unsafe { &mut (*next.as_ref().get()).prev },
None => &mut self.tail,
};
debug_assert_eq!(*next_prev_pointer, Some(NonNull::from(waiter)));
*next_prev_pointer = prev;
}
}
/// A node in the wait list, stored as a local inside the future returned by `WaitList::wait`.
struct Waiter<I, O> {
// Neighbouring nodes; only meaningful while this node is linked into the list.
next: Option<NonNull<UnsafeCell<Waiter<I, O>>>>,
prev: Option<NonNull<UnsafeCell<Waiter<I, O>>>>,
// Holds `input` while the task is queued, and `output` once it has been woken.
state: State<I, O>,
// `Some` while the task is queued; taken by `Borrowed::wake_one` when it wakes the task.
// Whether it is `Some` also records which field of `state` is initialized (see `Drop`).
waker: Option<task::Waker>,
}
/// Untagged storage for a waiter's input or output. The owning `Waiter::waker` tracks which
/// field is live: `Some` means `input`, `None` means `output`.
union State<I, O> {
input: ManuallyDrop<I>,
output: ManuallyDrop<O>,
}
impl<I, O> Drop for Waiter<I, O> {
    /// Drops whichever field of `state` is currently initialized.
    fn drop(&mut self) {
        // A waiter that still holds its waker has not been woken, so its input is live. Once
        // `wake_one` has taken the waker, the input has been moved out and the output stored.
        match self.waker {
            // SAFETY: waker present => `state.input` is the initialized field.
            Some(_) => unsafe { ManuallyDrop::drop(&mut self.state.input) },
            // SAFETY: waker absent => `state.output` is the initialized field.
            None => unsafe { ManuallyDrop::drop(&mut self.state.output) },
        }
    }
}
impl<I, O> WaitList<I, O> {
    /// Creates a new, empty wait list.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            borrowed: Cell::new(false),
            inner: UnsafeCell::new(Inner {
                head: None,
                tail: None,
            }),
        }
    }

    /// Attempts to take exclusive access to the list.
    ///
    /// Returns `None` if a [`Borrowed`] guard for this list already exists.
    #[must_use]
    pub fn try_borrow(&self) -> Option<Borrowed<'_, I, O>> {
        if self.borrowed.replace(true) {
            return None;
        }
        Some(Borrowed { list: self })
    }

    /// Takes exclusive access to the list.
    ///
    /// # Panics
    ///
    /// Panics if a [`Borrowed`] guard for this list already exists.
    #[must_use]
    pub fn borrow(&self) -> Borrowed<'_, I, O> {
        self.try_borrow()
            .expect("attempted to borrow `WaitList` while it is already borrowed")
    }

    /// Joins the back of the queue with `input` and waits until woken by
    /// [`Borrowed::wake_one`], resolving to the output passed to that call.
    ///
    /// If the future is dropped before it is woken, it leaves the queue and `input` is dropped.
    /// If it is dropped after being woken but before completing, the output is dropped.
    ///
    /// # Panics
    ///
    /// Every poll borrows the list, so polling this future panics if the list is borrowed at
    /// that moment. Dropping the future while it is still queued and the list is borrowed
    /// aborts the process, because the waiter cannot be unlinked safely.
    pub async fn wait(&self, input: I) -> O {
        // `CloneWaker` completes on its first poll. Capturing the waker before wrapping `input`
        // means the input is never held in a half-built value across an `.await`.
        let waker = CloneWaker.await;
        let waiter = UnsafeCell::new(Waiter {
            next: None,
            prev: None,
            state: State {
                input: ManuallyDrop::new(input),
            },
            waker: Some(waker),
        });

        // SAFETY: `waiter` is freshly built and unlinked, with its waker set and its input
        // stored. It is a local of this future, which must be pinned to be polled. It is only
        // dropped or moved (`into_inner` below) after the `WaitInner` temporary has been
        // dropped, and that drop unlinks it if it is still queued.
        unsafe {
            WaitInner::new(self, &waiter).await;
        }

        // `WaitInner` only completes once `wake_one` has taken the waker and stored the output.
        // Wrap in `ManuallyDrop` so `Waiter::drop` does not run after the output is moved out.
        let mut waiter = ManuallyDrop::new(waiter.into_inner());
        debug_assert!(waiter.waker.is_none());
        // SAFETY: the waker is `None`, so `state.output` is the initialized field.
        unsafe { ManuallyDrop::take(&mut waiter.state.output) }
    }
}

impl<I, O> Default for WaitList<I, O> {
    /// Creates an empty wait list; equivalent to [`WaitList::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<I, O> Debug for WaitList<I, O> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.pad("WaitList")
}
}
#[derive(Debug)]
pub struct Borrowed<'wait_list, I, O> {
list: &'wait_list WaitList<I, O>,
}
impl<'wait_list, I, O> Borrowed<'wait_list, I, O> {
    /// Shared view of the list's links.
    fn inner(&self) -> &Inner<I, O> {
        // SAFETY: a `Borrowed` only exists while the `borrowed` flag is set, and every access
        // to `inner` goes through a `Borrowed`. So nothing else can hold `&mut Inner` now.
        unsafe { &*self.list.inner.get() }
    }

    /// Exclusive view of the list's links.
    fn inner_mut(&mut self) -> &mut Inner<I, O> {
        // SAFETY: as for `inner`. In addition, `&mut self` rules out an outstanding shared view.
        unsafe { &mut *self.list.inner.get() }
    }

    /// The waiter at the front of the queue (the oldest one), if any.
    fn head(&self) -> Option<&UnsafeCell<Waiter<I, O>>> {
        let first = self.inner().head?;
        // SAFETY: waiters stay valid for as long as they are linked into the list.
        Some(unsafe { first.as_ref() })
    }

    /// Returns `true` if no task is currently waiting on the list.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.head().is_none()
    }

    /// Returns the input of the task at the front of the queue. That is the task
    /// [`wake_one`](Self::wake_one) would wake next.
    #[must_use]
    pub fn head_input(&self) -> Option<&I> {
        let cell = self.head()?;
        // SAFETY: a linked waiter still holds its waker, so `state.input` is initialized. The
        // borrow flag keeps it from being woken or unlinked while `&self` is held.
        let input: &I = unsafe { &(*cell.get()).state.input };
        Some(input)
    }

    /// Mutable access to the input of the task at the front of the queue.
    #[must_use]
    pub fn head_input_mut(&mut self) -> Option<&mut I> {
        let cell = self.head()?;
        // SAFETY: as for `head_input`. In addition, `&mut self` makes this the only reference
        // handed out to that input.
        let input: &mut I = unsafe { &mut (*cell.get()).state.input };
        Some(input)
    }

    /// Wakes the task at the front of the queue and hands it `output`.
    ///
    /// On success the woken task is removed from the list and its input is returned. Its
    /// `wait` future resolves to `output` the next time it is polled.
    ///
    /// # Errors
    ///
    /// If no task is waiting, returns a [`WakeOneError`] that gives `output` back.
    pub fn wake_one(&mut self, output: O) -> Result<I, WakeOneError<O>> {
        let inner = self.inner_mut();
        let first = match inner.head {
            None => return Err(WakeOneError { output }),
            Some(first) => first,
        };

        // Move the waiter from "queued" to "woken":
        // - take its waker, which marks `state` as holding the output;
        // - move its input out;
        // - store the output in its place.
        // The exclusive reference is confined to this block, so it is gone before `dequeue`
        // re-derives pointers into the same node.
        let (waker, input) = {
            // SAFETY: `first` is linked, hence valid, and the borrow flag guarantees no other
            // reference to it is live.
            let waiter = unsafe { &mut *first.as_ref().get() };
            let waker = waiter.waker.take().unwrap_or_else(|| unreachable!());
            // SAFETY: the waker was `Some`, so `state.input` was still initialized.
            let input = unsafe { ManuallyDrop::take(&mut waiter.state.input) };
            waiter.state.output = ManuallyDrop::new(output);
            (waker, input)
        };

        // SAFETY: `first` is still linked into this list; nothing above unlinked it.
        unsafe {
            inner.dequeue(first.as_ref());
        }
        // Wake only once the list is consistent again, since the waker may run arbitrary code.
        waker.wake();
        Ok(input)
    }
}
impl<I, O> Drop for Borrowed<'_, I, O> {
    /// Clears the borrow flag so the list can be borrowed again.
    fn drop(&mut self) {
        let was_borrowed = self.list.borrowed.replace(false);
        debug_assert!(was_borrowed);
    }
}
/// Future that keeps a waiter linked into a `WaitList` and completes once `Borrowed::wake_one`
/// has woken it (taken its waker). Dropping it unlinks the waiter if it is still queued.
struct WaitInner<'list, 'waiter, I, O> {
// The list the waiter is queued on.
list: &'list WaitList<I, O>,
// The queued node, owned by the enclosing `WaitList::wait` future.
waiter: &'waiter UnsafeCell<Waiter<I, O>>,
}
impl<'list, 'waiter, I, O> WaitInner<'list, 'waiter, I, O> {
    /// Links `waiter` onto the back of `list` and returns a future that completes once the
    /// waiter has been woken.
    ///
    /// # Safety
    ///
    /// - `waiter` must be freshly created: unlinked, with `next`/`prev` both `None`, its waker
    ///   `Some`, and `state` holding the input.
    /// - It must not be moved or freed until the returned value has been dropped, because the
    ///   list holds a raw pointer to it until then.
    ///
    /// # Panics
    ///
    /// Panics if `list` is currently borrowed.
    unsafe fn new(list: &'list WaitList<I, O>, waiter: &'waiter UnsafeCell<Waiter<I, O>>) -> Self {
        let mut guard = list.borrow();
        // SAFETY: upheld by the caller; see `# Safety` above.
        unsafe {
            guard.inner_mut().enqueue(waiter);
        }
        drop(guard);
        Self { list, waiter }
    }
}
impl<I, O> Future for WaitInner<'_, '_, I, O> {
    type Output = ();

    /// Completes once `wake_one` has taken the waiter's waker. Until then, it keeps the stored
    /// waker up to date with the one in `cx`.
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        // Taking the borrow asserts (by panicking otherwise) that no `Borrowed` guard is alive
        // while the waker is inspected and replaced.
        let _guard = self.list.borrow();
        // SAFETY: the borrow flag is held, so `wake_one` cannot be touching this waiter, and
        // only the `waker` field is referenced.
        let slot = unsafe { &mut (*self.waiter.get()).waker };
        match slot {
            None => Poll::Ready(()),
            Some(stored) => {
                // Avoid a clone when the task is polled with an equivalent waker.
                if !stored.will_wake(cx.waker()) {
                    *stored = cx.waker().clone();
                }
                Poll::Pending
            }
        }
    }
}
impl<I, O> Drop for WaitInner<'_, '_, I, O> {
    /// Unlinks the waiter from the list if it has not been woken yet.
    ///
    /// A waiter that was already woken has been unlinked by `Borrowed::wake_one`. Nothing needs
    /// to happen in that case, so the list is not touched at all. As a result, dropping an
    /// already-woken future while the list is borrowed no longer aborts the process. That
    /// includes a drop triggered from inside the waker called by `wake_one`.
    fn drop(&mut self) {
        // SAFETY: only the `waker` field is read here, and no live reference covers it.
        // - The references `Borrowed` and `poll` create to waiter fields are short-lived.
        // - The only long-lived references a caller can hold (`head_input`/`head_input_mut`)
        //   cover just `state`.
        let still_queued = unsafe { (*self.waiter.get()).waker.is_some() };
        if !still_queued {
            return;
        }

        // The waiter is still linked, so it must be unlinked before its storage is freed.
        // If the list is borrowed, that cannot be done without aliasing the borrower's access.
        // Panicking is not an option either: the waiter's storage could still be freed during
        // unwinding, leaving a dangling node in the list.
        let mut list = match self.list.try_borrow() {
            Some(guard) => guard,
            None => process::abort(),
        };
        // SAFETY: the waker is `Some`, so the waiter is linked into this list (`wake_one`
        // unlinks exactly when it takes the waker). `list` grants exclusive access to the links.
        unsafe {
            list.inner_mut().dequeue(self.waiter);
        }
    }
}
/// Error returned by [`Borrowed::wake_one`] when no task was waiting.
///
/// The output that would have been handed to the woken task is given back in `output`.
#[non_exhaustive]
#[derive(Debug)]
pub struct WakeOneError<O> {
    /// The value passed to `wake_one`, returned unused.
    pub output: O,
}

impl<O> Display for WakeOneError<O> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "no tasks were waiting")
    }
}

impl<O> Error for WakeOneError<O> where O: Debug {}
/// A future that completes on its first poll, yielding a clone of the polling task's waker.
///
/// `WaitList::wait` uses it to obtain a waker before the waiter is enqueued.
struct CloneWaker;

impl Future for CloneWaker {
    type Output = task::Waker;

    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let current = cx.waker();
        Poll::Ready(task::Waker::clone(current))
    }
}
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::task::Poll;

    use super::WaitList;
    use crate::utils::noop_cx;

    /// Repeatedly polling an un-woken waiter and then dropping it must unlink it cleanly.
    #[test]
    fn cancel() {
        let cx = &mut noop_cx();
        let list = WaitList::<u32, ()>::new();
        let mut future = Box::pin(list.wait(5));
        for _ in 0..10 {
            assert_eq!(future.as_mut().poll(cx), Poll::Pending);
        }
        drop(future);
    }

    /// Dropping waiters in an order other than FIFO must keep the links consistent.
    #[test]
    fn drop_in_middle() {
        let cx = &mut noop_cx();
        let list = WaitList::<u32, ()>::new();
        let mut f1 = Box::pin(list.wait(1));
        let mut f2 = Box::pin(list.wait(2));
        let mut f3 = Box::pin(list.wait(3));
        assert_eq!(f1.as_mut().poll(cx), Poll::Pending);
        assert_eq!(f2.as_mut().poll(cx), Poll::Pending);
        assert_eq!(f3.as_mut().poll(cx), Poll::Pending);
        drop(f2);
        drop(f3);
        drop(f1);
        assert!(list.borrow().is_empty());
    }

    /// `wake_one` wakes waiters in FIFO order, returns their inputs and delivers the outputs.
    /// With no waiter left, it hands the output back in the error.
    #[test]
    fn wake_one_fifo() {
        let cx = &mut noop_cx();
        let list = WaitList::<u32, &'static str>::new();
        let mut f1 = Box::pin(list.wait(1));
        let mut f2 = Box::pin(list.wait(2));
        assert_eq!(f1.as_mut().poll(cx), Poll::Pending);
        assert_eq!(f2.as_mut().poll(cx), Poll::Pending);
        {
            let mut borrowed = list.borrow();
            assert_eq!(borrowed.head_input(), Some(&1));
            assert_eq!(borrowed.wake_one("a").unwrap(), 1);
            assert_eq!(borrowed.head_input(), Some(&2));
        }
        assert_eq!(f1.as_mut().poll(cx), Poll::Ready("a"));
        assert_eq!(list.borrow().wake_one("b").unwrap(), 2);
        assert_eq!(f2.as_mut().poll(cx), Poll::Ready("b"));
        assert!(list.borrow().is_empty());
        assert_eq!(list.borrow().wake_one("c").unwrap_err().output, "c");
    }

    /// A waiter that is woken but never polled again must drop its output, not its input.
    /// Heap-owning types make a wrong or double drop visible under Miri.
    #[test]
    fn drop_after_wake() {
        let cx = &mut noop_cx();
        let list = WaitList::<String, Box<u32>>::new();
        let mut future = Box::pin(list.wait(String::from("in")));
        assert_eq!(future.as_mut().poll(cx), Poll::Pending);
        let input = list.borrow().wake_one(Box::new(5)).unwrap();
        assert_eq!(input, "in");
        drop(future);
        assert!(list.borrow().is_empty());
    }
}