use crate::{
loom::{
cell::UnsafeCell,
sync::atomic::{
AtomicUsize,
Ordering::{self, *},
},
},
util::{fmt, CachePadded},
Closed,
};
use core::{
future::Future,
ops,
pin::Pin,
task::{self, Context, Poll, Waker},
};
/// A cell that stores at most one [`Waker`] for a task waiting on it.
///
/// A task registers its waker via [`WaitCell::poll_wait`] (or the [`Wait`] /
/// [`Subscribe`] futures), and [`WaitCell::wake`] / [`WaitCell::close`] take
/// and wake it. Access to the stored waker is coordinated by the bit flags
/// held in the atomic `state` word.
pub struct WaitCell {
    /// Bit flags (see `State`) coordinating registering, waking and closing.
    state: CachePadded<AtomicUsize>,
    /// The registered waker, if any. It is only accessed by a caller that
    /// holds the `REGISTERING` bit (in `poll_wait`) or the `WAKING` bit (in
    /// `take_waker`) of `state`.
    waker: UnsafeCell<Option<Waker>>,
}
/// Error returned by [`WaitCell::poll_wait`] when the calling task could not
/// wait on the cell.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum PollWaitError {
    /// The cell has been closed by [`WaitCell::close`].
    Closed,
    /// Another task is concurrently registering its waker in this cell.
    Busy,
}
/// Future returned by [`WaitCell::wait`], and produced by awaiting a
/// [`Subscribe`] future. It completes when the cell is woken or closed.
#[derive(Debug)]
#[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
pub struct Wait<'a> {
    /// The cell being waited on.
    cell: &'a WaitCell,
    /// Outcome already observed when `Subscribe` registered the waker, or
    /// `Poll::Pending` if nothing had happened yet (always `Pending` when
    /// created by `WaitCell::wait`).
    presubscribe: Poll<Result<(), super::Closed>>,
}
/// Future returned by [`WaitCell::subscribe`].
///
/// Polling it registers the current task's waker with the cell and resolves
/// to a [`Wait`] future that can be awaited later.
#[derive(Debug)]
#[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
pub struct Subscribe<'a> {
    /// The cell to subscribe to.
    cell: &'a WaitCell,
}
/// Bit set describing a `WaitCell`'s state; the individual flags are the
/// associated constants defined in `impl State`.
#[derive(Eq, PartialEq, Copy, Clone)]
struct State(usize);
impl WaitCell {
    // NOTE(review): `loom_const_fn!` is defined elsewhere; presumably it makes
    // `new` a `const fn` except when building under loom — confirm.
    loom_const_fn! {
        /// Returns a new `WaitCell` with no waker stored and the state set to
        /// `WAITING` (no flags set).
        #[must_use]
        pub fn new() -> Self {
            Self {
                state: CachePadded::new(AtomicUsize::new(State::WAITING.0)),
                waker: UnsafeCell::new(None),
            }
        }
    }
}
impl Default for WaitCell {
fn default() -> Self {
Self::new()
}
}
impl WaitCell {
    /// Polls to wait on this cell, registering `cx`'s waker so that it is
    /// woken by a later call to [`wake`](Self::wake) or [`close`](Self::close).
    ///
    /// Returns:
    ///
    /// - `Poll::Pending` once the waker has been registered; the task should
    ///   yield until it is woken.
    /// - `Poll::Ready(Ok(()))` if a wakeup was already recorded (the cell was
    ///   woken while nobody was waiting), if a wake is in progress right now,
    ///   or if a wake raced with registering the waker.
    /// - `Poll::Ready(Err(PollWaitError::Closed))` if the cell is closed.
    /// - `Poll::Ready(Err(PollWaitError::Busy))` if another task is
    ///   concurrently registering its waker.
    ///
    /// If a waker that would not wake the same task was already stored, it is
    /// replaced and the previous waker is woken.
    pub fn poll_wait(&self, cx: &mut Context<'_>) -> Poll<Result<(), PollWaitError>> {
        enter_test_debug_span!("WaitCell::poll_wait", cell = ?fmt::ptr(self));

        // Try to acquire the REGISTERING bit. Because WAITING is "no flags
        // set", this only succeeds when nobody is waking, registering, or has
        // closed the cell, and no wakeup is pending.
        match test_dbg!(self.compare_exchange(State::WAITING, State::REGISTERING, Acquire)) {
            Err(actual) if test_dbg!(actual.contains(State::CLOSED)) => {
                return Poll::Ready(Err(PollWaitError::Closed));
            }
            Err(actual) if test_dbg!(actual.contains(State::WOKEN)) => {
                // Consume the recorded wakeup so it is observed only once.
                self.fetch_and(!State::WOKEN, Release);
                return Poll::Ready(Ok(()));
            }
            // Someone is waking the cell right now, so don't bother waiting.
            Err(actual) if test_dbg!(actual.contains(State::WAKING)) => {
                return Poll::Ready(Ok(()));
            }
            // The remaining case: another task holds the REGISTERING bit.
            Err(_) => return Poll::Ready(Err(PollWaitError::Busy)),
            Ok(_) => {}
        }

        let waker = cx.waker();
        trace!(wait_cell = ?fmt::ptr(self), ?waker, "registering waker");

        // SAFETY: we hold the REGISTERING bit, and `take_waker` never touches
        // the waker slot while REGISTERING is set, so access is exclusive.
        let prev_waker = self.waker.with_mut(|old_waker| unsafe {
            match &mut *old_waker {
                // Skip the clone if the stored waker already wakes this task.
                Some(old_waker) if waker.will_wake(old_waker) => None,
                old => old.replace(waker.clone()),
            }
        });

        if let Some(prev_waker) = prev_waker {
            test_debug!("Replaced an old waker in cell, waking");
            prev_waker.wake();
        }

        // Release the REGISTERING bit. If this fails, a `wake` or `close`
        // added bits while we were registering. Because REGISTERING was set,
        // that caller could not take the waker, so we handle the
        // notification ourselves.
        if let Err(actual) =
            test_dbg!(self.compare_exchange(State::REGISTERING, State::WAITING, AcqRel))
        {
            test_trace!(state = ?actual, "was notified");
            // SAFETY: REGISTERING is still set, so `take_waker` will not
            // access the waker slot concurrently.
            let waker = self.waker.with_mut(|waker| unsafe { (*waker).take() });
            // Clear every bit except CLOSED. This keeps CLOSED if it is set,
            // but never sets it.
            let state = test_dbg!(self.fetch_and(State::CLOSED, AcqRel));
            // The only transition allowed while we were registering is that
            // the cell was closed.
            debug_assert!(
                state == actual || state == actual | State::CLOSED,
                "state changed unexpectedly while parking!"
            );

            if let Some(waker) = waker {
                waker.wake();
            }

            if state.contains(State::CLOSED) {
                return Poll::Ready(Err(PollWaitError::Closed));
            }

            return Poll::Ready(Ok(()));
        }

        // The waker is registered; yield until woken.
        Poll::Pending
    }

    /// Returns a future that completes when this cell is woken or closed.
    ///
    /// The waker is not registered until the returned [`Wait`] future is
    /// first polled. Use [`subscribe`](Self::subscribe) to register before
    /// awaiting. The future resolves to `Err(Closed)` if the cell is closed.
    pub fn wait(&self) -> Wait<'_> {
        Wait {
            cell: self,
            presubscribe: Poll::Pending,
        }
    }

    /// Returns a future that registers the current task's waker with this
    /// cell when polled, and resolves to a [`Wait`] future.
    ///
    /// A wake (or close) that happens after registration but before the
    /// [`Wait`] future is awaited is recorded in the cell's state, so it is
    /// observed when the [`Wait`] future is polled.
    pub fn subscribe(&self) -> Subscribe<'_> {
        Subscribe { cell: self }
    }

    /// Wakes the task waiting on this cell, if any.
    ///
    /// The wakeup is always recorded in the cell's state (the `WOKEN` bit).
    /// If no task is waiting, the next [`poll_wait`](Self::poll_wait)
    /// therefore completes immediately instead of registering. Repeated wakes
    /// before that poll collapse into a single wakeup.
    ///
    /// Returns `true` if a registered waker was taken and woken.
    pub fn wake(&self) -> bool {
        enter_test_debug_span!("WaitCell::wake", cell = ?fmt::ptr(self));
        if let Some(waker) = self.take_waker(false) {
            waker.wake();
            true
        } else {
            false
        }
    }

    /// Closes this cell and wakes the waiting task, if any.
    ///
    /// Once closed, [`poll_wait`](Self::poll_wait) returns
    /// `Err(PollWaitError::Closed)`. Returns `true` if a registered waker was
    /// taken and woken, and `false` otherwise (including when the cell was
    /// already closed).
    pub fn close(&self) -> bool {
        enter_test_debug_span!("WaitCell::close", cell = ?fmt::ptr(self));
        if let Some(waker) = self.take_waker(true) {
            waker.wake();
            true
        } else {
            false
        }
    }

    /// Waits on this cell until `f` returns `true`, re-evaluating `f` after
    /// each wakeup.
    ///
    /// The waker is registered (via [`subscribe`](Self::subscribe)) *before*
    /// `f` is evaluated, so a wake that happens between checking `f` and
    /// waiting is not missed.
    ///
    /// # Errors
    ///
    /// Returns `Err(Closed)` if the cell is closed while `f` still returns
    /// `false`.
    pub async fn wait_for<F: FnMut() -> bool>(&self, mut f: F) -> Result<(), Closed> {
        loop {
            let wait = self.subscribe().await;
            if f() {
                return Ok(());
            }
            wait.await?;
        }
    }

    /// Waits on this cell until `f` returns `Some`, re-evaluating `f` after
    /// each wakeup, and returns the produced value.
    ///
    /// As with [`wait_for`](Self::wait_for), the waker is registered before
    /// `f` is evaluated, so no wakeup is missed.
    ///
    /// # Errors
    ///
    /// Returns `Err(Closed)` if the cell is closed while `f` still returns
    /// `None`.
    pub async fn wait_for_value<T, F: FnMut() -> Option<T>>(&self, mut f: F) -> Result<T, Closed> {
        loop {
            let wait = self.subscribe().await;
            if let Some(t) = f() {
                return Ok(t);
            }
            wait.await?;
        }
    }

    /// Returns `true` if this cell has been [closed](Self::close).
    #[must_use]
    pub fn is_closed(&self) -> bool {
        // `close` sets CLOSED together with WOKEN (and transiently WAKING),
        // and WOKEN can stay set afterwards. Test the CLOSED bit rather than
        // comparing the whole state word, which would report `false` for a
        // closed cell in the `WOKEN | CLOSED` state.
        self.current_state().contains(State::CLOSED)
    }

    /// Records a wakeup (and, if `close` is `true`, closes the cell), then
    /// takes the registered waker if this caller may access it.
    ///
    /// The caller is responsible for waking the returned waker. Returns
    /// `None` in these cases:
    ///
    /// - no waker was stored;
    /// - another caller is already waking the cell;
    /// - the cell was already closed;
    /// - a task is currently registering. That task sees the added bits when
    ///   it fails to release REGISTERING in `poll_wait`, and handles the
    ///   wakeup itself.
    pub(crate) fn take_waker(&self, close: bool) -> Option<Waker> {
        trace!(wait_cell = ?fmt::ptr(self), ?close, "notifying");
        // Set WAKING (to claim the waker slot) and WOKEN (to record the
        // wakeup), plus CLOSED when closing.
        let state = {
            let mut bits = State::WAKING | State::WOKEN;
            if close {
                bits.0 |= State::CLOSED.0;
            }
            test_dbg!(self.fetch_or(bits, AcqRel))
        };

        // Only touch the waker if nobody else held the slot and the cell was
        // not already closed.
        if !test_dbg!(state.contains(State::WAKING | State::REGISTERING | State::CLOSED)) {
            // SAFETY: we set WAKING while neither WAKING nor REGISTERING was
            // held, so we have exclusive access to the waker slot.
            let waker = self.waker.with_mut(|thread| unsafe { (*thread).take() });
            // Release the slot. WOKEN (and CLOSED, if set) remain set.
            self.fetch_and(!State::WAKING, Release);
            if let Some(waker) = test_dbg!(waker) {
                trace!(wait_cell = ?fmt::ptr(self), ?close, ?waker, "notified");
                return Some(waker);
            }
        }
        None
    }
}
impl WaitCell {
    /// Atomically replaces the state with `new` if it currently equals `curr`.
    ///
    /// Uses `success` ordering when the exchange succeeds and `Acquire` when
    /// it fails. Either way the previous state is returned (`Ok` on success,
    /// `Err` on failure).
    #[inline(always)]
    fn compare_exchange(
        &self,
        curr: State,
        new: State,
        success: Ordering,
    ) -> Result<State, State> {
        match self.state.compare_exchange(curr.0, new.0, success, Acquire) {
            Ok(previous) => Ok(State(previous)),
            Err(previous) => Err(State(previous)),
        }
    }

    /// Atomically ANDs `mask` into the state and returns the previous state.
    #[inline(always)]
    fn fetch_and(&self, mask: State, order: Ordering) -> State {
        let previous = self.state.fetch_and(mask.0, order);
        State(previous)
    }

    /// Atomically ORs `bits` into the state and returns the previous state.
    #[inline(always)]
    fn fetch_or(&self, bits: State, order: Ordering) -> State {
        let previous = self.state.fetch_or(bits.0, order);
        State(previous)
    }

    /// Loads the current state with `Acquire` ordering.
    #[inline(always)]
    fn current_state(&self) -> State {
        let raw = self.state.load(Acquire);
        State(raw)
    }
}
// SAFETY: the field that prevents auto-`Sync` is the `UnsafeCell` holding the
// waker. Every access to it is made while holding either the REGISTERING bit
// (`poll_wait`, acquired only from the all-clear WAITING state) or the WAKING
// bit (`take_waker`, which proceeds only if neither WAKING nor REGISTERING was
// already set). The two accesses are therefore never concurrent. `Waker`
// itself is `Send + Sync`.
unsafe impl Send for WaitCell {}
unsafe impl Sync for WaitCell {}
impl fmt::Debug for WaitCell {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the state word is shown; the stored waker is not read and is
        // printed as `..`.
        let state = self.current_state();
        let mut out = f.debug_struct("WaitCell");
        out.field("state", &state);
        out.field("waker", &fmt::display(".."));
        out.finish()
    }
}
impl Drop for WaitCell {
    /// Closes the cell when it is dropped, waking any registered waker.
    fn drop(&mut self) {
        let _woke_waiter = self.close();
    }
}
impl Future for Wait<'_> {
    type Output = Result<(), Closed>;

    /// Resolves with the outcome captured at subscription time, if there is
    /// one. Otherwise polls the cell, yielding and retrying while another
    /// task is busy registering.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        enter_test_debug_span!("Wait::poll");

        // A wakeup (or close) may already have been observed by `Subscribe`.
        if test_dbg!(self.presubscribe.is_ready()) {
            return self.presubscribe;
        }

        let outcome = task::ready!(test_dbg!(self.cell.poll_wait(cx)));
        let result = match outcome {
            Err(PollWaitError::Busy) => {
                // Another task is registering; ask to be polled again so we
                // can register once it is done.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Err(PollWaitError::Closed) => Err(Closed(())),
            Ok(()) => Ok(()),
        };
        Poll::Ready(result)
    }
}
impl<'cell> Future for Subscribe<'cell> {
    type Output = Wait<'cell>;

    /// Registers the task's waker and resolves to a [`Wait`] future.
    ///
    /// Any outcome already observed while registering (woken or closed) is
    /// stored in the `Wait` so it is reported when that future is awaited.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        enter_test_debug_span!("Subscribe::poll");
        let cell = self.cell;
        let presubscribe = match test_dbg!(self.cell.poll_wait(cx)) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(PollWaitError::Closed)) => Poll::Ready(Err(Closed(()))),
            Poll::Ready(Err(PollWaitError::Busy)) => {
                // Another task is registering; retry on the next poll.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        };
        Poll::Ready(Wait { cell, presubscribe })
    }
}
impl State {
const WAITING: Self = Self(0b0000);
const REGISTERING: Self = Self(0b0001);
const WAKING: Self = Self(0b0010);
const WOKEN: Self = Self(0b0100);
const CLOSED: Self = Self(0b1000);
fn contains(self, Self(state): Self) -> bool {
self.0 & state > 0
}
}
impl ops::BitOr for State {
    type Output = Self;

    /// Union of the two flag sets.
    fn bitor(self, rhs: Self) -> Self::Output {
        let Self(lhs) = self;
        Self(lhs | rhs.0)
    }
}
impl ops::BitAnd for State {
    type Output = Self;

    /// Intersection of the two flag sets.
    fn bitand(self, rhs: Self) -> Self::Output {
        let Self(lhs) = self;
        Self(lhs & rhs.0)
    }
}
impl ops::Not for State {
type Output = Self;
fn not(self) -> Self::Output {
Self(!self.0)
}
}
impl fmt::Debug for State {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // NOTE(review): `fmt_bits!` is defined elsewhere. It is expected to
        // write each listed flag that is set and flip `wrote_any` when it
        // writes one — confirm against the macro.
        let mut wrote_any = false;
        fmt_bits!(self, f, wrote_any, REGISTERING, WAKING, CLOSED, WOKEN);
        if wrote_any {
            return Ok(());
        }
        // No known flag was set: either the empty WAITING state or garbage.
        if *self == Self::WAITING {
            return f.write_str("WAITING");
        }
        f.debug_tuple("UnknownState")
            .field(&format_args!("{:#b}", self.0))
            .finish()
    }
}
// Unit tests that drive `WaitCell` and its futures by hand with
// `tokio_test::task`, so each test controls exactly when polls happen.
#[cfg(all(feature = "alloc", not(loom), test))]
mod tests {
    use super::*;
    use alloc::sync::Arc;
    use tokio_test::{assert_pending, assert_ready, assert_ready_ok, task};

    // A waiting task stays pending until `wake` is called. `wake` reports that
    // it woke someone, and the task then completes on its next poll.
    #[test]
    fn wait_smoke() {
        let _trace = crate::util::test::trace_init();
        let wait = Arc::new(WaitCell::new());
        let mut task = task::spawn({
            let wait = wait.clone();
            async move { wait.wait().await }
        });
        assert_pending!(task.poll());
        assert!(wait.wake());
        assert!(task.is_woken());
        assert_ready_ok!(task.poll());
    }

    // Polling a pending `wait` future again before any wake must stay
    // pending, and the task must still complete after `wake`.
    #[test]
    fn wait_spurious_poll() {
        let _trace = crate::util::test::trace_init();
        let cell = Arc::new(WaitCell::new());
        let mut task = task::spawn({
            let cell = cell.clone();
            async move { cell.wait().await }
        });
        assert_pending!(task.poll(), "first poll should be pending");
        assert_pending!(task.poll(), "second poll should be pending");
        cell.wake();
        assert_ready_ok!(task.poll(), "should have been woken");
    }

    // A wake between `subscribe().await` and awaiting the resulting `Wait`
    // future is not lost.
    #[test]
    fn subscribe() {
        let _trace = crate::util::test::trace_init();
        futures::executor::block_on(async {
            let cell = WaitCell::new();
            let wait = cell.subscribe().await;
            cell.wake();
            wait.await.unwrap();
        })
    }

    // A wake with no waiter is recorded, so the first subscriber completes
    // immediately. That consumes the wakeup, so a second subscriber waits
    // for a fresh `wake`.
    #[test]
    fn wake_before_subscribe() {
        let _trace = crate::util::test::trace_init();
        let cell = Arc::new(WaitCell::new());
        cell.wake();
        let mut task = task::spawn({
            let cell = cell.clone();
            async move {
                let wait = cell.subscribe().await;
                wait.await.unwrap();
            }
        });
        assert_ready!(task.poll(), "woken task should complete");
        let mut task = task::spawn({
            let cell = cell.clone();
            async move {
                let wait = cell.subscribe().await;
                wait.await.unwrap();
            }
        });
        assert_pending!(task.poll(), "wait cell hasn't been woken yet");
        cell.wake();
        assert!(task.is_woken());
        assert_ready!(task.poll());
    }

    // Two `wake` calls while one task waits count as a single wakeup. The
    // next waiter must not be woken until a fresh `wake`.
    #[test]
    fn wake_debounce() {
        let _trace = crate::util::test::trace_init();
        let cell = Arc::new(WaitCell::new());
        let mut task = task::spawn({
            let cell = cell.clone();
            async move {
                cell.wait().await.unwrap();
            }
        });
        assert_pending!(task.poll());
        cell.wake();
        cell.wake();
        assert!(task.is_woken());
        assert_ready!(task.poll());
        let mut task = task::spawn({
            let cell = cell.clone();
            async move {
                cell.wait().await.unwrap();
            }
        });
        assert_pending!(task.poll());
        assert!(!task.is_woken());
        cell.wake();
        assert!(task.is_woken());
        assert_ready!(task.poll());
    }

    // Subscribing must not wake the subscribing task itself. After the first
    // wait completes, the second subscription leaves the task pending and
    // un-woken until the next `wake`.
    #[test]
    fn subscribe_doesnt_self_wake() {
        let _trace = crate::util::test::trace_init();
        let cell = Arc::new(WaitCell::new());
        let mut task = task::spawn({
            let cell = cell.clone();
            async move {
                let wait = cell.subscribe().await;
                wait.await.unwrap();
                let wait = cell.subscribe().await;
                wait.await.unwrap();
            }
        });
        assert_pending!(task.poll());
        assert!(!task.is_woken());
        cell.wake();
        assert!(task.is_woken());
        assert_pending!(task.poll());
        assert!(!task.is_woken());
        assert_pending!(task.poll());
        cell.wake();
        assert!(task.is_woken());
        assert_ready!(task.poll());
    }
}
// Loom model-checking tests that explore thread interleavings of `WaitCell`
// operations.
#[cfg(all(loom, test))]
mod loom {
    use super::*;
    use crate::loom::{future, sync::Arc, thread};

    // A waiter racing with concurrent `wake` and `close` calls on other
    // threads must complete. Its result (woken or closed) is ignored.
    #[test]
    fn basic() {
        crate::loom::model(|| {
            let wait = Arc::new(WaitCell::new());
            let waker = wait.clone();
            let closer = wait.clone();
            thread::spawn(move || {
                tracing::info!("waking");
                waker.wake();
                tracing::info!("woken");
            });
            thread::spawn(move || {
                tracing::info!("closing");
                closer.close();
                tracing::info!("closed");
            });
            tracing::info!("waiting");
            let _ = future::block_on(wait.wait());
            tracing::info!("wait'd");
        });
    }

    // After subscribing, a `wake` from another thread must complete the
    // `Wait` future with `Ok` (woken), never as closed.
    #[test]
    fn subscribe() {
        crate::loom::model(|| {
            future::block_on(async move {
                let cell = Arc::new(WaitCell::new());
                let wait = cell.subscribe().await;
                thread::spawn({
                    let waker = cell.clone();
                    move || {
                        tracing::info!("waking");
                        waker.wake();
                        tracing::info!("woken");
                    }
                });
                tracing::info!("waiting");
                wait.await.expect("wait should be woken, not closed");
                tracing::info!("wait'd");
            });
        });
    }
}