use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::oneshot::{Receiver, Sender, channel, error::TryRecvError};
/// Sending half of a one-shot event created by [`once_event`].
///
/// The event can be fired at most once, because [`OnceTrigger::trigger`] consumes the
/// trigger. Dropping the trigger without firing it is reported to the paired
/// [`OnceWaiter`] as [`Triggered::Dropped`].
#[derive(Debug)]
pub struct OnceTrigger(Sender<()>);

impl OnceTrigger {
    /// Fires the event, consuming the trigger.
    ///
    /// Returns `true` if the signal was handed to the paired waiter, or `false` if that
    /// waiter had already been dropped.
    pub fn trigger(self) -> bool {
        let OnceTrigger(sender) = self;
        match sender.send(()) {
            Ok(()) => true,
            Err(()) => false,
        }
    }

    /// Waits asynchronously until the paired [`OnceWaiter`] has been dropped.
    pub async fn dropped(&mut self) {
        let OnceTrigger(sender) = self;
        sender.closed().await;
    }

    /// Reports, without waiting, whether the paired [`OnceWaiter`] has already been dropped.
    pub fn is_dropped(&self) -> bool {
        let OnceTrigger(sender) = self;
        sender.is_closed()
    }

    /// Poll-based form of [`OnceTrigger::dropped`], for hand-written futures or `poll_fn`.
    ///
    /// Resolves to `Poll::Ready(())` once the paired [`OnceWaiter`] has been dropped.
    pub fn poll_dropped(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        let OnceTrigger(sender) = self;
        sender.poll_closed(cx)
    }
}
/// State of a one-shot event as observed by its [`OnceWaiter`].
///
/// A waiter starts out `Pending` and moves to exactly one of the two final states.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum Triggered {
    /// Neither fired nor dropped yet; this is the initial (default) state.
    #[default]
    Pending,
    /// The paired [`OnceTrigger`] fired the event.
    Triggered,
    /// The paired [`OnceTrigger`] was dropped without firing.
    Dropped,
}
/// Receiving half of a one-shot event created by [`once_event`].
///
/// The waiter caches the first final state it observes. After that it can be queried
/// with [`OnceWaiter::triggered`] or awaited again (e.g. as `&mut waiter` inside a
/// `select!` loop) without touching the exhausted channel a second time.
#[derive(Debug)]
pub struct OnceWaiter {
    recv: Receiver<()>,
    triggered: Triggered,
}

impl OnceWaiter {
    /// Returns the current state of the event without blocking.
    ///
    /// While the cached state is still [`Triggered::Pending`], this performs one
    /// non-blocking `try_recv` and caches the result. Once a final state has been seen,
    /// it is returned on every later call.
    pub fn triggered(&mut self) -> Triggered {
        if self.triggered != Triggered::Pending {
            return self.triggered;
        }
        self.triggered = match self.recv.try_recv() {
            Ok(()) => Triggered::Triggered,
            Err(TryRecvError::Closed) => Triggered::Dropped,
            Err(_) => Triggered::Pending,
        };
        self.triggered
    }

    /// Consuming variant of [`OnceWaiter::triggered`]: reports the current state and
    /// then drops the waiter.
    ///
    /// The paired trigger can observe the drop via [`OnceTrigger::is_dropped`].
    pub fn has_been_triggered(self) -> Triggered {
        let mut this = self;
        this.triggered()
    }

    /// Blocks the current thread until the event fires or the trigger is dropped.
    ///
    /// Returns `true` if the event fired and `false` if the trigger was dropped without
    /// firing. If a final state is already cached, it is returned without blocking.
    ///
    /// This delegates to tokio's `Receiver::blocking_recv`, which (per tokio's
    /// documentation) panics when called from within an asynchronous execution context.
    pub fn blocking_wait(self) -> bool {
        match self.triggered {
            Triggered::Pending => self.recv.blocking_recv().is_ok(),
            settled => settled == Triggered::Triggered,
        }
    }
}
impl Future for OnceWaiter {
    type Output = bool;

    /// Resolves to `true` if the event fired, or `false` if the trigger was dropped first.
    ///
    /// The outcome is cached on the waiter. Polling again after completion (for example
    /// when awaiting `&mut waiter` repeatedly) returns the same result instead of polling
    /// the already-consumed channel.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `OnceWaiter` is `Unpin` (both of its fields are), so `Pin::get_mut` is available
        // and we can work through a plain `&mut` reference.
        let this = self.get_mut();
        if this.triggered == Triggered::Pending {
            this.triggered = match Pin::new(&mut this.recv).poll(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Ok(())) => Triggered::Triggered,
                Poll::Ready(Err(_)) => Triggered::Dropped,
            };
        }
        Poll::Ready(this.triggered == Triggered::Triggered)
    }
}
pub fn once_event() -> (OnceTrigger, OnceWaiter) {
let triggered = Default::default();
let (send, recv) = channel();
(OnceTrigger(send), OnceWaiter { recv, triggered })
}
#[cfg(test)]
mod tests {
use super::*;
// Async (`Future`) side of `OnceWaiter`: awaiting by value and through `&mut`, and
// combining `.await` with the non-blocking `triggered()` query.
#[tokio::test(flavor = "multi_thread")]
async fn async_wait() {
// Fired from another task: awaiting resolves to `true`.
let (trigger, waiter) = once_event();
tokio::spawn(async move {
assert!(trigger.trigger());
});
assert!(waiter.await);
// Waiter dropped first: firing reports that nobody received the signal.
let (trigger, waiter) = once_event();
drop(waiter);
assert!(!trigger.trigger());
// Trigger dropped without firing: awaiting resolves to `false`.
let (trigger, waiter) = once_event();
drop(trigger);
assert!(!waiter.await);
// Spin on `triggered()` until the spawned task fires. The cached state is then
// stable, and `.await` reports the same outcome.
let (trigger, mut waiter) = once_event();
tokio::spawn(async move {
assert!(trigger.trigger());
});
while waiter.triggered() == Triggered::Pending {}
assert_eq!(waiter.triggered(), Triggered::Triggered);
assert!(waiter.await);
// Same as above, for a trigger that is dropped instead of fired.
let (trigger, mut waiter) = once_event();
drop(trigger);
while waiter.triggered() == Triggered::Pending {}
assert_eq!(waiter.triggered(), Triggered::Dropped);
assert!(!waiter.await);
// Awaiting through `&mut` keeps the waiter alive; the outcome is then still visible
// via `triggered()` and `has_been_triggered()`.
let (trigger, mut waiter) = once_event();
assert_eq!(waiter.triggered(), Triggered::Pending);
tokio::spawn(async move {
assert!(trigger.trigger());
});
assert!((&mut waiter).await);
assert_eq!(waiter.triggered(), Triggered::Triggered);
assert_eq!(waiter.has_been_triggered(), Triggered::Triggered);
// Same as above, for a dropped trigger.
let (trigger, mut waiter) = once_event();
assert_eq!(waiter.triggered(), Triggered::Pending);
drop(trigger);
assert!(!(&mut waiter).await);
assert_eq!(waiter.triggered(), Triggered::Dropped);
assert_eq!(waiter.has_been_triggered(), Triggered::Dropped);
}
// Thread-based counterpart of `async_wait` using `blocking_wait`. This is a plain
// `#[test]` with no tokio runtime, because `blocking_wait` calls tokio's `blocking_recv`.
#[test]
fn blocking_wait() {
use std::thread;
// Fired from another OS thread: blocking wait returns `true`.
let (trigger, waiter) = once_event();
thread::spawn(move || {
assert!(trigger.trigger());
});
assert!(waiter.blocking_wait());
// Waiter dropped first: firing fails.
let (trigger, waiter) = once_event();
drop(waiter);
assert!(!trigger.trigger());
// Trigger dropped without firing: blocking wait returns `false`.
let (trigger, waiter) = once_event();
drop(trigger);
assert!(!waiter.blocking_wait());
// Once `triggered()` has cached a final state, `blocking_wait` returns it directly.
let (trigger, mut waiter) = once_event();
thread::spawn(move || {
assert!(trigger.trigger());
});
while waiter.triggered() == Triggered::Pending {}
assert_eq!(waiter.triggered(), Triggered::Triggered);
assert!(waiter.blocking_wait());
let (trigger, mut waiter) = once_event();
drop(trigger);
while waiter.triggered() == Triggered::Pending {}
assert_eq!(waiter.triggered(), Triggered::Dropped);
assert!(!waiter.blocking_wait());
}
// `triggered()` never blocks, and repeated calls return the same state.
#[test]
fn triggered() {
let (trigger, mut waiter) = once_event();
assert_eq!(waiter.triggered(), Triggered::Pending);
assert_eq!(waiter.triggered(), Triggered::Pending);
assert!(trigger.trigger());
assert_eq!(waiter.triggered(), Triggered::Triggered);
assert_eq!(waiter.triggered(), Triggered::Triggered);
let (trigger, mut waiter) = once_event();
drop(trigger);
assert_eq!(waiter.triggered(), Triggered::Dropped);
assert_eq!(waiter.triggered(), Triggered::Dropped);
}
// `has_been_triggered` consumes the waiter, which the trigger then sees as dropped.
#[test]
fn has_been_triggered() {
let (trigger, waiter) = once_event();
assert!(!trigger.is_dropped());
assert_eq!(waiter.has_been_triggered(), Triggered::Pending);
assert!(trigger.is_dropped());
assert!(!trigger.trigger());
let (trigger, waiter) = once_event();
assert!(trigger.trigger());
assert_eq!(waiter.has_been_triggered(), Triggered::Triggered);
let (trigger, waiter) = once_event();
drop(trigger);
assert_eq!(waiter.has_been_triggered(), Triggered::Dropped);
}
// Dropping the waiter is visible synchronously via `is_dropped`, and firing afterwards fails.
#[test]
fn is_dropped() {
let (trigger, waiter) = once_event();
assert!(!trigger.is_dropped());
drop(waiter);
assert!(trigger.is_dropped());
assert!(!trigger.trigger());
}
// `dropped().await` completes once the waiter is dropped on another task.
#[tokio::test(flavor = "multi_thread")]
async fn dropped() {
let (mut trigger, waiter) = once_event();
assert!(!trigger.is_dropped());
tokio::spawn(async move {
drop(waiter);
});
trigger.dropped().await;
assert!(trigger.is_dropped());
assert!(!trigger.trigger());
}
// Same as `dropped`, but driven manually through `poll_dropped` via `poll_fn`.
#[tokio::test(flavor = "multi_thread")]
async fn poll_dropped() {
use std::future::poll_fn;
let (mut trigger, waiter) = once_event();
assert!(!trigger.is_dropped());
tokio::spawn(async move {
drop(waiter);
});
poll_fn(|cx| trigger.poll_dropped(cx)).await;
assert!(trigger.is_dropped());
assert!(!trigger.trigger());
}
// Races a 200ms interval against `&mut waiter` inside `select!`, while the trigger
// fires (or is dropped) after 500ms. tokio's first `interval.tick()` completes
// immediately, so three ticks (~0, 200 and 400ms) are expected before the waiter wins.
// While the loop keeps going, the waiter must still report `Pending`.
// NOTE(review): this relies on wall-clock timing and may be flaky on a heavily loaded machine.
#[tokio::test(flavor = "multi_thread")]
async fn select_waiter() {
use std::time::Duration;
use tokio::time::{interval as _interval, sleep};
let mut ticks = 0;
let mut interval = _interval(Duration::from_millis(200));
let (trigger, mut waiter) = once_event();
tokio::spawn(async move {
sleep(Duration::from_millis(500)).await;
trigger.trigger();
});
loop {
tokio::select! {
_ = interval.tick() => ticks += 1,
_ = &mut waiter => break,
}
assert_eq!(waiter.triggered(), Triggered::Pending);
}
assert_eq!(ticks, 3);
assert_eq!(waiter.triggered(), Triggered::Triggered);
assert_eq!(waiter.has_been_triggered(), Triggered::Triggered);
// Same race, with the trigger dropped instead of fired.
let mut ticks = 0;
let mut interval = _interval(Duration::from_millis(200));
let (trigger, mut waiter) = once_event();
tokio::spawn(async move {
sleep(Duration::from_millis(500)).await;
drop(trigger);
});
loop {
tokio::select! {
_ = interval.tick() => ticks += 1,
_ = &mut waiter => break,
}
assert_eq!(waiter.triggered(), Triggered::Pending);
}
assert_eq!(ticks, 3);
assert_eq!(waiter.triggered(), Triggered::Dropped);
assert_eq!(waiter.has_been_triggered(), Triggered::Dropped);
}
// Trigger side racing `dropped()` against a sleep, with the waiter awaited under a
// 100ms timeout. `timeout` takes the waiter by value, so the waiter is dropped together
// with the finished timeout future.
#[tokio::test(flavor = "multi_thread")]
async fn select_dropped() {
use std::time::Duration;
use tokio::time::sleep;
use tokio_util::time::FutureExt;
let timeout = Duration::from_millis(100);
// The trigger would fire only after 500ms, so the waiter times out first. Dropping the
// waiter then lets `trigger.dropped()` win the spawned task's race.
let (mut trigger, waiter) = once_event();
tokio::spawn(async move {
tokio::select! {
_ = trigger.dropped() => (),
_ = sleep(Duration::from_millis(500)) => {
trigger.trigger();
}
}
});
assert!(waiter.timeout(timeout).await.is_err());
// The trigger is dropped after 5ms, well within the timeout: the waiter yields `false`.
let (mut trigger, waiter) = once_event();
tokio::spawn(async move {
tokio::select! {
_ = trigger.dropped() => (),
_ = sleep(Duration::from_millis(5)) => {
drop(trigger);
}
}
});
assert_eq!(waiter.timeout(timeout).await, Ok(false));
// The trigger fires after 5ms, well within the timeout: the waiter yields `true`.
let (mut trigger, waiter) = once_event();
tokio::spawn(async move {
tokio::select! {
_ = trigger.dropped() => (),
_ = sleep(Duration::from_millis(5)) => {
trigger.trigger();
}
}
});
assert_eq!(waiter.timeout(timeout).await, Ok(true));
}
}