use std::future::Future;
use std::pin::Pin;
use std::sync::Mutex;
use std::task::{Context, Poll, Waker};
/// Shared state guarded by [`AsyncNotify`]'s mutex.
#[derive(Debug)]
struct NotifyState {
    /// Wakers of tasks currently parked on a [`Notified`] future. Wakers of
    /// futures that were dropped before completing stay queued until the next
    /// `notify_waiters` call (there is no `Drop` cleanup).
    waiters: Vec<Waker>,
    /// Number of `notify_waiters` calls so far (wrapping). A [`Notified`]
    /// future completes once this differs from the value it captured.
    generation: u64,
}

/// A minimal broadcast notification primitive for async tasks.
///
/// Tasks obtain a [`Notified`] future via [`AsyncNotify::notified`]. Every
/// such future created before a call to [`AsyncNotify::notify_waiters`]
/// completes after that call. No permit is stored: futures created after a
/// notification wait for the next one.
#[derive(Debug)]
pub struct AsyncNotify {
    state: Mutex<NotifyState>,
}

impl AsyncNotify {
    /// Creates a notifier with no registered waiters.
    #[inline]
    pub fn new() -> Self {
        Self {
            state: Mutex::new(NotifyState {
                waiters: Vec::new(),
                generation: 0,
            }),
        }
    }

    /// Wakes every [`Notified`] future created before this call.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn notify_waiters(&self) {
        let waiters = {
            let mut state = self.lock_state();
            state.generation = state.generation.wrapping_add(1);
            std::mem::take(&mut state.waiters)
        };
        // Wake outside the lock so a woken task polling immediately (or a
        // panicking waker) can neither contend on nor poison the state mutex.
        for waker in waiters {
            waker.wake();
        }
    }

    /// Returns a future that completes after the next `notify_waiters` call
    /// made after this method returns.
    ///
    /// The notification generation is captured here rather than on first
    /// poll. The "create future, check condition, then `.await`" pattern
    /// therefore cannot miss a notification that lands in between.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn notified(&self) -> Notified<'_> {
        let generation = self.lock_state().generation;
        Notified {
            notify: self,
            generation,
        }
    }

    /// Locks the shared state, panicking on poison with the crate's message.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, NotifyState> {
        self.state.lock().expect("async notify mutex poisoned")
    }
}

impl Default for AsyncNotify {
    fn default() -> Self {
        Self::new()
    }
}

/// Future returned by [`AsyncNotify::notified`].
///
/// Completes once `notify_waiters` has been called after the future was
/// created. Polling again after completion keeps returning `Ready`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Notified<'a> {
    notify: &'a AsyncNotify,
    /// Value of `NotifyState::generation` when this future was created.
    generation: u64,
}

impl Future for Notified<'_> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Notified` holds only a reference and an integer, so it is `Unpin`.
        let this = self.get_mut();
        let mut state = this.notify.lock_state();
        // Completion is decided by shared state, not by how many times we have
        // been polled, so spurious re-polls cannot complete the future early.
        if state.generation != this.generation {
            return Poll::Ready(());
        }
        // Still waiting: register the current waker unless an equivalent one
        // is already queued (e.g. a spurious re-poll from the same task). This
        // keeps the list from growing on repeated polls; the scan is O(waiters).
        if !state.waiters.iter().any(|w| w.will_wake(cx.waker())) {
            state.waiters.push(cx.waker().clone());
        }
        Poll::Pending
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Notifying with no registered waiters must be a harmless no-op.
    #[test]
    fn default_creates_empty_notify() {
        let notifier = AsyncNotify::default();
        notifier.notify_waiters();
    }

    /// A task parked on `notified()` is released by a later `notify_waiters()`.
    #[tokio::test]
    async fn notified_completes_after_notify_waiters() {
        let shared = Arc::new(AsyncNotify::new());
        let waiter = Arc::clone(&shared);
        let task = tokio::spawn(async move {
            waiter.notified().await;
        });
        // Give the spawned task a chance to run and register before notifying.
        tokio::task::yield_now().await;
        shared.notify_waiters();
        task.await.expect("notified task should complete");
    }
}