use std::{
future::Future,
pin::Pin,
sync::{
Arc,
atomic::{AtomicBool, Ordering}
},
task::{Context, Poll},
time::Duration
};
use tokio::{task, time};
use wakerizer::{Waiter, Wakers};
/// A manually fired event that any number of [`TriggerWaiter`] futures can
/// await. Nothing in this type ever clears the flag, so once fired it stays
/// fired.
#[derive(Default)]
struct Trigger {
  /// `true` once `trigger()` has been called; shared (via `Arc`) with every
  /// waiter created by `waiter()`.
  state: Arc<AtomicBool>,
  /// Waker registry that waiters are created from and that `trigger()` wakes.
  wakers: Wakers
}

impl Trigger {
  /// Fires the trigger: sets the shared flag, then wakes all registered
  /// wakers.
  fn trigger(&self) {
    // The flag is stored before waking so that a woken task re-polling the
    // future observes `true`.
    self.state.store(true, Ordering::SeqCst);
    self.wakers.wake_all();
  }

  /// Returns a future that resolves once this trigger has been fired.
  fn waiter(&self) -> TriggerWaiter {
    let state = Arc::clone(&self.state);
    let waiter = self.wakers.waiter();
    TriggerWaiter { state, waiter }
  }
}
/// Future that resolves once the `Trigger` it was created from has fired.
///
/// Instances are produced by `Trigger::waiter`, which hands out a clone of
/// the trigger's shared flag and a `Waiter` obtained from its `Wakers`.
struct TriggerWaiter {
  /// Shared flag set to `true` by `Trigger::trigger`.
  state: Arc<AtomicBool>,
  /// Handle used to register the polling task's waker with the trigger.
  waiter: Waiter
}

impl Future for TriggerWaiter {
  type Output = ();

  fn poll(
    mut self: Pin<&mut Self>,
    ctx: &mut Context<'_>
  ) -> Poll<Self::Output> {
    // Fast path: already fired, no need to register a waker at all.
    if self.state.load(Ordering::SeqCst) {
      return Poll::Ready(());
    }

    // Register this task's waker before deciding to return `Pending`.
    self.waiter.prime(ctx);

    // Re-check after registering. `Trigger::trigger` may have stored `true`
    // and called `wake_all()` between the load above and `prime()`. At that
    // point this task's waker was not yet registered, so `wake_all()` could
    // not wake it. Returning `Pending` in that case would lose the wake-up
    // and stall the task forever (reachable on a multi-threaded runtime).
    // NOTE(review): this relies on `Wakers` ordering `prime`/`wake_all`
    // internally (e.g. via a lock) — confirm against wakerizer.
    if self.state.load(Ordering::SeqCst) {
      Poll::Ready(())
    } else {
      Poll::Pending
    }
  }
}
/// Spawns three tasks awaiting the same trigger and fires it immediately,
/// without giving the tasks a chance to run first; every task must still
/// complete.
#[tokio::test]
async fn all_nowait() {
  let button = Trigger::default();

  // Spawn one waiting task per handle, in order.
  let handles: Vec<_> = (0..3)
    .map(|_| {
      let waiter = button.waiter();
      task::spawn(async move { waiter.await })
    })
    .collect();

  button.trigger();

  for handle in handles {
    handle.await.unwrap();
  }
}
/// Spawns three tasks awaiting the same trigger and sleeps for 100 ms before
/// firing it. The tasks have presumably been polled and are pending by then.
/// Every task must be woken and complete.
#[tokio::test]
async fn all_wait() {
  let button = Trigger::default();

  // Spawn one waiting task per handle, in order.
  let handles: Vec<_> = (0..3)
    .map(|_| {
      let waiter = button.waiter();
      task::spawn(async move { waiter.await })
    })
    .collect();

  // Yield to the runtime long enough for the spawned tasks to start waiting.
  time::sleep(Duration::from_millis(100)).await;
  button.trigger();

  for handle in handles {
    handle.await.unwrap();
  }
}