use std::future::Future;
use std::num::NonZeroUsize;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll, ready};
use event_listener::{Event, EventListener, listener};
use futures_core::Stream;
use pin_project_lite::pin_project;
/// A counting notification primitive.
///
/// Pairs an atomic counter with an [`Event`]: the counter is incremented by
/// `notify_n` and decremented by `fast_path`, while the event is used to park
/// and wake tasks that are waiting for the counter to become non-zero.
#[derive(Debug, Default)]
pub struct Notify {
/// Number of stored notifications that have not been consumed yet.
count: AtomicUsize,
/// Event used to wake tasks waiting for `count` to become non-zero.
event: Event,
}
impl Notify {
    /// Creates a `Notify` with no stored notifications.
    pub const fn new() -> Self {
        Self {
            count: AtomicUsize::new(0),
            event: Event::new(),
        }
    }

    /// Stores a single notification and wakes one waiting listener, if any.
    ///
    /// Equivalent to [`notify_n`](Self::notify_n) with `n == 1`.
    #[inline]
    pub fn notify(&self) {
        // `NonZeroUsize::MIN` is the constant 1, so no runtime `unwrap` is needed.
        self.notify_n(NonZeroUsize::MIN)
    }

    /// Stores `n` notifications and wakes up to `n` additional waiting listeners.
    ///
    /// Each stored notification can later be consumed by exactly one call to
    /// [`notified`](Self::notified).
    #[inline]
    pub fn notify_n(&self, n: NonZeroUsize) {
        let n = n.get();
        self.count.fetch_add(n, Ordering::Release);
        // `Event::notify(n)` only guarantees that at least `n` listeners are in
        // the notified state; a listener that was woken by an earlier call but
        // has not run yet still counts towards that total. Two back-to-back
        // calls could therefore store two notifications while waking only one
        // waiter, leaving the second waiter asleep. `notify_additional` always
        // wakes `n` listeners that have not been notified yet.
        self.event.notify_additional(n);
    }

    /// Wakes listeners through `Event::notify(n)` without storing a notification.
    ///
    /// A woken waiter that finds no stored notification re-registers and keeps
    /// waiting (see the loop in [`notified`](Self::notified)).
    #[inline]
    pub fn notify_waiters(&self, n: NonZeroUsize) {
        self.event.notify(n.get());
    }

    /// Waits until a stored notification is available and consumes it.
    ///
    /// Returns immediately if a notification is already stored.
    #[inline]
    pub async fn notified(&self) {
        loop {
            if self.fast_path() {
                return;
            }
            // Register a listener first, then re-check the counter, so that a
            // notification stored between the first check and the registration
            // is not missed.
            listener!(self.event => listener);
            if self.fast_path() {
                return;
            }
            listener.await;
        }
    }

    /// Tries to consume one stored notification.
    ///
    /// Returns `true` if the counter was non-zero and has been decremented by
    /// one, `false` if the counter was zero.
    fn fast_path(&self) -> bool {
        // NOTE(review): presumably silences a deprecation warning for
        // `fetch_update` on newer toolchains — confirm.
        #[allow(deprecated)]
        self.count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |c| c.checked_sub(1))
            .is_ok()
    }
}
pin_project! {
/// A [`Stream`] adapter over a [`Notify`] handle.
///
/// Yields `()` each time a stored notification is consumed; the stream
/// implementation in this file never yields `None`.
pub struct NotifyStream<T: Deref<Target=Notify>> {
// Handle that dereferences to the underlying `Notify`.
#[pin]
notify: T,
// Listener registered on the `Notify`'s event while the stream is waiting;
// `None` when no listener is currently registered.
listener: Option<EventListener>,
}
}
impl<T: Deref<Target = Notify>> NotifyStream<T> {
    /// Wraps `notify` in a stream that has not registered any listener yet.
    pub const fn new(notify: T) -> Self {
        let listener = None;
        Self { notify, listener }
    }

    /// Consumes the stream and hands back the wrapped handle.
    pub fn into_inner(self) -> T {
        let handle = self.notify;
        handle
    }
}
impl<T: Deref<Target = Notify>> AsRef<Notify> for NotifyStream<T> {
    /// Borrows the `Notify` that the wrapped handle dereferences to.
    fn as_ref(&self) -> &Notify {
        &*self.notify
    }
}
impl<T: Deref<Target = Notify>> Stream for NotifyStream<T> {
    type Item = ();

    /// Yields `Some(())` once a stored notification has been consumed.
    ///
    /// Returns `Poll::Pending` while no notification is stored; the task is
    /// woken through the registered listener. This implementation never
    /// returns `Poll::Ready(None)`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let notify = this.notify.deref();
        loop {
            if notify.fast_path() {
                // The notification has been consumed; any listener still
                // registered is no longer needed.
                *this.listener = None;
                return Poll::Ready(Some(()));
            }
            match this.listener.as_mut() {
                None => {
                    // Register first, then loop back to re-check the counter so
                    // a notification stored in between is not missed.
                    *this.listener = Some(notify.event.listen());
                }
                Some(listener) => {
                    ready!(Pin::new(listener).poll(cx));
                    // The listener has completed. It must not be polled again:
                    // if the counter check above fails on the next iteration
                    // (e.g. the wake came from `notify_waiters`, or another
                    // consumer took the notification), a fresh listener has to
                    // be registered instead of re-polling the finished one.
                    *this.listener = None;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use futures_util::{FutureExt, StreamExt, select};

    use super::*;

    // A notification sent from a spawned task completes `notified()`.
    #[test]
    fn test() {
        async_global_executor::block_on(async {
            let notify = Arc::new(Notify::new());
            let notify2 = notify.clone();
            async_global_executor::spawn(async move {
                notify2.notify();
                println!("sent notification");
            })
            .detach();
            // Log only after the notification has actually been observed.
            notify.notified().await;
            println!("received notification");
        })
    }

    // Every `notify()` call stores one notification that is consumed by
    // exactly one `notified()` call.
    #[test]
    fn test_multi_notify() {
        async_global_executor::block_on(async {
            let notify = Arc::new(Notify::new());
            let notify2 = notify.clone();
            notify.notify();
            notify.notify();
            select! {
                _ = notify2.notified().fuse() => {}
                default => unreachable!("there should be notified")
            }
            select! {
                _ = notify2.notified().fuse() => {}
                default => unreachable!("there should be notified")
            }
            select! {
                _ = notify2.notified().fuse() => unreachable!("there should not be notified"),
                default => {}
            }
            notify.notify();
            select! {
                _ = notify2.notified().fuse() => {}
                default => unreachable!("there should be notified")
            }
        })
    }

    // `notify_n(3)` stores exactly three notifications.
    #[test]
    fn test_notify_n() {
        async_global_executor::block_on(async {
            let notify = Arc::new(Notify::new());
            let notify2 = notify.clone();
            notify.notify_n(3.try_into().unwrap());
            for _ in 0..3 {
                select! {
                    _ = notify2.notified().fuse() => {}
                    default => unreachable!("there should be notified")
                }
            }
            select! {
                _ = notify2.notified().fuse() => unreachable!("there should not be notified"),
                default => {}
            }
        })
    }

    // Two waiting tasks both finish after two stored notifications, with a
    // `notify_waiters` call in between.
    #[test]
    fn test_notify_waiters() {
        async_global_executor::block_on(async {
            let notify = Arc::new(Notify::new());
            let notify2 = notify.clone();
            let notify3 = notify.clone();
            let t1 = async_global_executor::spawn(async move {
                notify2.notified().await;
            });
            let t2 = async_global_executor::spawn(async move {
                notify3.notified().await;
            });
            async_global_executor::spawn(async {}).await;
            notify.notify();
            notify.notify_waiters(NonZeroUsize::new(2).unwrap());
            notify.notify();
            t1.await;
            t2.await;
        })
    }

    // A notification sent from a spawned task is yielded by `NotifyStream`.
    #[test]
    fn stream() {
        async_global_executor::block_on(async {
            let notify = Arc::new(Notify::new());
            let mut notify_stream = NotifyStream::new(notify.clone());
            async_global_executor::spawn(async move {
                notify.notify();
                println!("sent notification");
            })
            .detach();
            notify_stream.next().await.unwrap();
        })
    }
}