use core::{
cell::{Cell, UnsafeCell},
pin::Pin,
ptr::NonNull,
task::{Context, Poll, Waker},
};
/// FIFO queue of waiters, stored as a doubly linked list of [`WaiterNode`]s
/// that are referenced by raw pointer rather than owned.
///
/// Every field is a `Cell`, so the queue is mutated through `&self`.
pub struct WaiterQueue {
// Head of the list: the next node `notify` will pop.
front: Cell<Option<NonNull<WaiterNode>>>,
// Tail of the list: where `add_waiter` appends new nodes.
back: Cell<Option<NonNull<WaiterNode>>>,
// Number of nodes currently linked; adjusted by `add_waiter`,
// `remove_waiter` and `notify`.
waiter_count: Cell<usize>,
}
/// One entry of the intrusive list kept by [`WaiterQueue`].
///
/// The node is embedded in a [`Waiter`]; the queue only stores pointers to it.
struct WaiterNode {
// Registration status of the owning waiter (see `WaiterLifecycle`).
lifecycle: Cell<WaiterLifecycle>,
// Notification state; accessed through `with_state` or a direct
// `UnsafeCell` dereference.
state: UnsafeCell<WaiterNodeState>,
// Neighbour closer to the back of the queue, `None` when last or unlinked.
next: Cell<Option<NonNull<Self>>>,
// Neighbour closer to the front of the queue, `None` when first or unlinked.
previous: Cell<Option<NonNull<Self>>>,
}
/// Notification state of a single [`WaiterNode`].
enum WaiterNodeState {
// Initial state: no waker has been stored yet.
Pending,
// `poll_notification` stored the task's waker and is awaiting `notify`.
Polled { waker: Waker },
// `notify` selected this node; the notification has not been observed yet.
Notified,
// Terminal state: the notification was consumed by a poll, or the waiter
// was cancelled.
Shutdown,
}
impl WaiterNode {
    /// Builds a detached node: unregistered, pending, and linked to nothing.
    pub fn new() -> Self {
        let lifecycle = Cell::new(WaiterLifecycle::Unregistered);
        let state = UnsafeCell::new(WaiterNodeState::Pending);
        Self {
            lifecycle,
            state,
            next: Cell::new(None),
            previous: Cell::new(None),
        }
    }

    /// Runs `f` with exclusive access to the node's notification state and
    /// returns whatever `f` returns.
    fn with_state<R>(&self, f: impl FnOnce(&mut WaiterNodeState) -> R) -> R {
        let state_ptr = self.state.get();
        // SAFETY: the `UnsafeCell` allows mutation through `&self`. Soundness
        // relies on no other reference to the state being live while `f`
        // runs, so `f` must not re-enter this node's state.
        let state = unsafe { &mut *state_ptr };
        f(state)
    }

    /// Marks the node as notified and hands back the stored waker, if any.
    ///
    /// Returns `None` when the node was never polled (no waker to wake).
    ///
    /// # Panics
    ///
    /// Panics if the node is already `Notified` or `Shutdown`; the state is
    /// left untouched in that case.
    fn notify(&self) -> Option<Waker> {
        self.with_state(|state| {
            // Reject the terminal states before touching anything.
            if matches!(
                *state,
                WaiterNodeState::Notified | WaiterNodeState::Shutdown
            ) {
                unreachable!();
            }
            // Only `Pending` and `Polled` remain; both become `Notified`.
            match core::mem::replace(state, WaiterNodeState::Notified) {
                WaiterNodeState::Polled { waker } => Some(waker),
                _ => None,
            }
        })
    }
}
impl WaiterQueue {
    /// Creates an empty queue.
    pub fn new() -> Self {
        Self {
            front: Cell::new(None),
            back: Cell::new(None),
            waiter_count: Cell::new(0),
        }
    }

    /// Returns the number of waiters currently linked into the queue.
    pub fn waiter_count(&self) -> usize {
        self.waiter_count.get()
    }

    /// Creates an unregistered [`Waiter`] bound to this queue.
    ///
    /// # Safety
    ///
    /// Once polled, the waiter links itself into this queue by address, and
    /// `Waiter` has no visible `Drop` impl that would unlink it. The caller
    /// must therefore keep a registered waiter at a stable address and make
    /// sure it is unlinked (via [`Waiter::cancel`], or by wrapping it in
    /// [`WaitUntil`], whose `Drop` does so) before it is moved or dropped.
    pub unsafe fn wait(&self) -> Waiter<'_> {
        Waiter::new(self)
    }

    /// Waits until `condition` yields `Some`, returning the contained value.
    ///
    /// `condition` is evaluated once up front and again after every
    /// notification delivered to this waiter.
    #[inline]
    pub async fn wait_for<T>(&self, mut condition: impl FnMut() -> Option<T>) -> T {
        loop {
            if let Some(ready) = condition() {
                return ready;
            }
            WaitUntil {
                // SAFETY: the waiter is owned by a `WaitUntil`, which is
                // awaited in place and cancels the registration in `Drop`.
                waiter: unsafe { self.wait() },
            }
            .await;
        }
    }

    /// Waits until `condition` returns `true`.
    ///
    /// `condition` is evaluated once up front and again after every
    /// notification delivered to this waiter.
    #[inline]
    pub async fn wait_until(&self, mut condition: impl FnMut() -> bool) {
        self.wait_for(|| condition().then_some(())).await
    }

    /// Notifies up to `count` waiters in FIFO order and returns how many were
    /// actually notified.
    ///
    /// Each notified node is unlinked first, then marked notified, then its
    /// waker (if one was stored) is woken.
    pub fn notify(&self, count: usize) -> usize {
        // Bound the work by the waiters queued on entry: a waker that
        // re-enters the queue and registers new waiters must not be able to
        // keep this loop (in particular `notify_all`) running forever.
        let limit = count.min(self.waiter_count.get());
        let mut notified_count = 0;
        while notified_count < limit {
            let Some(front_ptr) = self.front.get() else {
                break;
            };
            // SAFETY: linked nodes are required to stay alive and in place
            // until they are unlinked (see `wait`).
            let front = unsafe { front_ptr.as_ref() };
            notified_count += 1;
            self.waiter_count.set(self.waiter_count.get() - 1);

            // Unlink the front node.
            let next_ptr = front.next.take();
            self.front.set(next_ptr);
            if let Some(new_front_ptr) = next_ptr {
                // SAFETY: same liveness requirement as for `front_ptr`.
                unsafe { new_front_ptr.as_ref() }.previous.set(None);
            } else {
                // The popped node was the only one, so it was also the back.
                debug_assert_eq!(Some(front_ptr), self.back.get());
                debug_assert!(front.previous.get().is_none());
                self.back.set(None);
            }

            // Wake only after the node is fully unlinked and marked; the node
            // is not touched again once `wake` has been called.
            let maybe_waker = front.notify();
            if let Some(waker) = maybe_waker {
                waker.wake();
            }
        }
        notified_count
    }

    /// Notifies every waiter queued at the time of the call and returns how
    /// many were notified.
    pub fn notify_all(&self) -> usize {
        self.notify(usize::MAX)
    }

    /// Unlinks `node` from the queue.
    ///
    /// Returns `false` (and changes nothing) when the node is not linked,
    /// `true` after it has been removed.
    fn remove_waiter(&self, node: NonNull<WaiterNode>) -> bool {
        // SAFETY: callers pass a pointer to a node that is still alive.
        let node_ref = unsafe { node.as_ref() };
        let prev = node_ref.previous.get();
        let next = node_ref.next.get();
        // A node with no neighbours is linked only if it is the sole element.
        if prev.is_none() && next.is_none() && self.front.get() != Some(node) {
            return false;
        }
        self.waiter_count.set(self.waiter_count.get() - 1);
        node_ref.next.set(None);
        node_ref.previous.set(None);
        if Some(node) == self.back.get() {
            self.back.set(prev);
            debug_assert!(next.is_none());
        }
        if Some(node) == self.front.get() {
            self.front.set(next);
            if let Some(next) = next {
                // SAFETY: neighbours of a linked node are themselves linked
                // and therefore alive.
                unsafe { next.as_ref() }.previous.set(None);
            } else {
                debug_assert!(self.back.get().is_none());
            }
        } else if let Some(prev) = prev {
            // SAFETY: as above, linked neighbours are alive.
            unsafe { prev.as_ref() }.next.set(next);
            if let Some(next) = next {
                // SAFETY: as above, linked neighbours are alive.
                unsafe { next.as_ref() }.previous.set(Some(prev));
            }
        }
        true
    }

    /// Appends `new_node` to the back of the queue.
    ///
    /// The node must not already be linked (checked in debug builds only).
    fn add_waiter(&self, new_node: NonNull<WaiterNode>) {
        self.waiter_count.set(self.waiter_count.get() + 1);
        // SAFETY: callers pass a pointer to a node that is still alive.
        let new_ref = unsafe { new_node.as_ref() };
        debug_assert!(new_ref.next.get().is_none());
        debug_assert!(new_ref.previous.get().is_none());
        let prev_back = self.back.replace(Some(new_node));
        if let Some(prev_back) = prev_back {
            new_ref.previous.set(Some(prev_back));
            // SAFETY: the previous back node is linked and therefore alive.
            unsafe { prev_back.as_ref() }.next.set(Some(new_node));
        } else {
            // Empty queue: the new node is both front and back.
            self.front.set(Some(new_node));
            debug_assert!(new_ref.next.get().is_none());
            debug_assert!(new_ref.previous.get().is_none());
        }
    }
}
/// Registration status of a [`Waiter`] with respect to its queue.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum WaiterLifecycle {
// Initial status: `register` has not added the node to the queue yet.
Unregistered,
// Set by `register` just before the node is added to the queue.
Registered,
// Set by `cancel`, or by `poll_notification` when it returns `Ready`.
Releasing,
}
/// A single waiter on a [`WaiterQueue`], holding the list node that the queue
/// points at once the waiter has been registered.
pub struct Waiter<'a> {
// Queue this waiter registers with and is removed from.
waiter_queue: &'a WaiterQueue,
// Node linked into `waiter_queue` by address after registration.
waiter_node: UnsafeCell<WaiterNode>,
}
impl<'a> Waiter<'a> {
pub fn new(waiter_queue: &'a WaiterQueue) -> Self {
Self {
waiter_queue,
waiter_node: UnsafeCell::new(WaiterNode::new()),
}
}
fn lifecycle(&self) -> WaiterLifecycle {
unsafe { &*self.waiter_node.get() }.lifecycle.get()
}
fn set_lifecycle(&self, new_value: WaiterLifecycle) {
unsafe { &*self.waiter_node.get() }.lifecycle.set(new_value);
}
fn register(self: Pin<&Self>) {
if self.lifecycle() != WaiterLifecycle::Unregistered {
return;
}
let waiter_node_ptr = NonNull::from(unsafe { &*self.waiter_node.get() });
self.set_lifecycle(WaiterLifecycle::Registered);
self.waiter_queue.add_waiter(waiter_node_ptr);
}
pub fn cancel(&self) -> bool {
match self.lifecycle() {
WaiterLifecycle::Registered => {
self.set_lifecycle(WaiterLifecycle::Releasing);
let waiter_node = unsafe { &*self.waiter_node.get() };
let state = unsafe { &mut *waiter_node.state.get() };
match core::mem::replace(state, WaiterNodeState::Shutdown) {
WaiterNodeState::Notified => true,
WaiterNodeState::Shutdown => false,
_ => {
self.waiter_queue.remove_waiter(NonNull::from(waiter_node));
false
}
}
}
_ => false,
}
}
pub fn poll_notification(self: Pin<&'_ Self>, context: &'_ mut Context<'_>) -> Poll<()> {
self.register();
let waiter_node = unsafe { &*self.waiter_node.get() };
let is_notified = waiter_node.with_state(|state| {
match state {
WaiterNodeState::Pending => {
*state = WaiterNodeState::Polled {
waker: context.waker().clone(),
};
}
WaiterNodeState::Polled { waker } => {
let new_waker = context.waker();
if !waker.will_wake(new_waker) {
*state = WaiterNodeState::Polled {
waker: new_waker.clone(),
}
} else {
*state = WaiterNodeState::Polled {
waker: waker.clone(),
}
}
}
WaiterNodeState::Notified => {
*state = WaiterNodeState::Shutdown;
return true;
}
WaiterNodeState::Shutdown => {
panic!("waitq::local waiter polled after shutdown");
}
}
false
});
if is_notified {
debug_assert_eq!(self.lifecycle(), WaiterLifecycle::Registered);
self.set_lifecycle(WaiterLifecycle::Releasing);
return Poll::Ready(());
}
Poll::Pending
}
}
/// Future that resolves once its [`Waiter`] has been notified; it cancels the
/// waiter when dropped.
pub struct WaitUntil<'a> {
// The waiter polled by this future.
waiter: Waiter<'a>,
}
impl<'a> WaitUntil<'a> {
    /// Projects the pinned future onto its embedded waiter.
    #[inline]
    fn waiter(self: Pin<&'_ Self>) -> Pin<&'_ Waiter<'a>> {
        let this = self.get_ref();
        // SAFETY: `waiter` is a field of the already pinned `WaitUntil` and
        // the visible code never moves it out, so it stays in place for as
        // long as the enclosing future does.
        unsafe { Pin::new_unchecked(&this.waiter) }
    }
}
impl core::future::Future for WaitUntil<'_> {
    type Output = ();

    /// Delegates to the embedded waiter: `Ready(())` once it has been
    /// notified, `Pending` otherwise.
    #[inline]
    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        // Only shared access is needed, so downgrade the pinned `&mut`.
        let pinned_waiter = self.into_ref().waiter();
        pinned_waiter.poll_notification(context)
    }
}
impl Drop for WaitUntil<'_> {
    /// Cancels the waiter; if it had been notified without ever observing the
    /// notification, one notification is passed on to the queue instead.
    fn drop(&mut self) {
        let waiter = &self.waiter;
        let notification_unused = waiter.cancel();
        if notification_unused {
            waiter.waiter_queue.notify(1);
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Nodes can be linked and then unlinked in any order; unlinking a node
    /// that is no longer linked reports `false`.
    #[test]
    fn test_add_waiter() {
        let waiter_queue = WaiterQueue::new();
        let a = WaiterNode::new();
        let b = WaiterNode::new();
        let c = WaiterNode::new();
        waiter_queue.add_waiter(NonNull::from(&a));
        waiter_queue.add_waiter(NonNull::from(&b));
        waiter_queue.add_waiter(NonNull::from(&c));
        assert_eq!(waiter_queue.waiter_count(), 3);

        // Middle, front, then the last remaining node.
        assert!(waiter_queue.remove_waiter(NonNull::from(&b)));
        assert!(waiter_queue.remove_waiter(NonNull::from(&a)));
        assert!(waiter_queue.remove_waiter(NonNull::from(&c)));
        assert_eq!(waiter_queue.waiter_count(), 0);

        // Removing again must be a no-op.
        assert!(!waiter_queue.remove_waiter(NonNull::from(&a)));
        assert!(!waiter_queue.remove_waiter(NonNull::from(&b)));
        assert!(!waiter_queue.remove_waiter(NonNull::from(&c)));
        assert_eq!(waiter_queue.waiter_count(), 0);
    }

    /// Cancelling waiters that were registered but never notified reports no
    /// unused notification and empties the queue.
    #[test]
    fn test_register_waiter() {
        let waiter_queue = WaiterQueue::new();
        let a = core::pin::pin!(Waiter::new(&waiter_queue));
        let b = core::pin::pin!(Waiter::new(&waiter_queue));
        let c = core::pin::pin!(Waiter::new(&waiter_queue));
        a.as_ref().register();
        b.as_ref().register();
        c.as_ref().register();
        assert_eq!(waiter_queue.waiter_count(), 3);

        assert!(!b.cancel());
        assert!(!a.cancel());
        assert!(!c.cancel());
        assert_eq!(waiter_queue.waiter_count(), 0);
    }

    /// `notify(2)` reaches the first two waiters in FIFO order; cancelling a
    /// notified waiter reports the unused notification.
    #[test]
    fn test_notify() {
        let waiter_queue = WaiterQueue::new();
        let a = core::pin::pin!(Waiter::new(&waiter_queue));
        let b = core::pin::pin!(Waiter::new(&waiter_queue));
        let c = core::pin::pin!(Waiter::new(&waiter_queue));
        a.as_ref().register();
        b.as_ref().register();
        c.as_ref().register();

        assert_eq!(waiter_queue.notify(2), 2);
        assert_eq!(waiter_queue.waiter_count(), 1);

        assert!(b.cancel());
        assert!(a.cancel());
        assert!(!c.cancel());
        assert_eq!(waiter_queue.waiter_count(), 0);
    }
}