use event_listener::{Event, EventListener};
use core::{
pin::Pin,
sync::atomic::{AtomicUsize, Ordering},
task::{Context, Poll},
};
#[cfg(feature = "triomphe")]
use triomphe::Arc;
#[cfg(all(any(feature = "std", feature = "alloc"), not(feature = "triomphe")))]
use std::sync::Arc;
/// State shared by every handle to a `WaitGroup`; each handle holds it behind an `Arc`.
#[derive(Debug)]
struct AsyncInner {
/// Outstanding count: raised by `WaitGroup::add`, lowered by `WaitGroup::done`.
counter: AtomicUsize,
/// Notified by `WaitGroup::done` when `counter` transitions to zero; waiters listen on it.
event: Event,
}
/// A cloneable counter that callers can wait on until it drops to zero.
///
/// `add` raises the counter, `done` lowers it, and `wait` (async) or
/// `wait_blocking` (std, non-wasm only) return once it reads zero. Every
/// clone shares the same counter and event through one `Arc<AsyncInner>`.
#[cfg_attr(docsrs, doc(cfg(feature = "future")))]
pub struct WaitGroup {
inner: Arc<AsyncInner>,
}
impl Default for WaitGroup {
    /// Builds an empty wait group: the counter starts at zero.
    fn default() -> Self {
        let shared = AsyncInner {
            counter: AtomicUsize::new(0),
            event: Event::new(),
        };
        Self {
            inner: Arc::new(shared),
        }
    }
}
impl From<usize> for WaitGroup {
    /// Builds a wait group whose counter starts at `count`.
    fn from(count: usize) -> Self {
        let counter = AtomicUsize::new(count);
        let event = Event::new();
        Self {
            inner: Arc::new(AsyncInner { counter, event }),
        }
    }
}
impl Clone for WaitGroup {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl core::fmt::Debug for WaitGroup {
    /// Formats as `WaitGroup { counter: .. }`, showing only the shared counter.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let counter = &self.inner.counter;
        let mut builder = f.debug_struct("WaitGroup");
        builder.field("counter", counter);
        builder.finish()
    }
}
impl core::ops::AddAssign<usize> for WaitGroup {
fn add_assign(&mut self, rhs: usize) {
self.add(rhs);
}
}
impl WaitGroup {
pub fn new() -> Self {
Self::default()
}
pub fn add(&self, num: usize) -> Self {
self
.inner
.counter
.fetch_update(Ordering::Release, Ordering::Relaxed, |prev| {
prev.checked_add(num)
})
.expect("WaitGroup counter overflow");
Self {
inner: self.inner.clone(),
}
}
pub fn done(&self) -> usize {
match self
.inner
.counter
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
{
Ok(old) => {
let remaining = old - 1;
if remaining == 0 {
self.inner.event.notify(usize::MAX);
}
remaining
}
Err(_) => 0,
}
}
pub fn remaining(&self) -> usize {
self.inner.counter.load(Ordering::Acquire)
}
pub fn wait(&self) -> WaitGroupFuture<'_> {
WaitGroupFuture {
inner: self,
notified: self.inner.event.listen(),
_pin: core::marker::PhantomPinned,
}
}
#[cfg(all(feature = "std", not(target_family = "wasm")))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "std", not(target_family = "wasm")))))]
pub fn wait_blocking(&self) {
use event_listener::Listener;
while self.inner.counter.load(Ordering::Acquire) != 0 {
let ln = self.inner.event.listen();
if self.inner.counter.load(Ordering::Acquire) == 0 {
return;
}
ln.wait();
}
}
}
pin_project_lite::pin_project! {
/// Future returned by `WaitGroup::wait`; resolves once the group's counter reads zero.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WaitGroupFuture<'a> {
// The wait group whose counter is being waited on.
inner: &'a WaitGroup,
// Listener on the group's event; polled through the pinned projection.
#[pin]
notified: EventListener,
// `PhantomPinned` marker; as a pinned field it makes this future `!Unpin`.
#[pin]
_pin: core::marker::PhantomPinned,
}
}
impl core::future::Future for WaitGroupFuture<'_> {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.inner.inner.counter.load(Ordering::Acquire) == 0 {
return Poll::Ready(());
}
let mut this = self.project();
match this.notified.as_mut().poll(cx) {
Poll::Pending => {
if this.inner.inner.counter.load(Ordering::Acquire) == 0 {
Poll::Ready(())
} else {
Poll::Pending
}
}
Poll::Ready(_) => {
if this.inner.inner.counter.load(Ordering::Acquire) == 0 {
Poll::Ready(())
} else {
*this.notified = this.inner.inner.event.listen();
match this.notified.as_mut().poll(cx) {
Poll::Pending => {
if this.inner.inner.counter.load(Ordering::Acquire) == 0 {
Poll::Ready(())
} else {
Poll::Pending
}
}
Poll::Ready(_) => {
if this.inner.inner.counter.load(Ordering::Acquire) == 0 {
Poll::Ready(())
} else {
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
}
}
}
}
}