use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Poll, Waker};
use parking_lot::Mutex;
/// State shared (through an `Arc`) by a `WaitGroup` and every `WaitToken`
/// handed out from it.
#[derive(Debug, Default)]
struct WaitGroupInner {
    /// Waker stored by `WaitGroupFuture::poll` while tokens are outstanding;
    /// taken and woken by the `WaitToken` whose drop brings the count to zero.
    waker: Mutex<Option<Waker>>,
    /// Number of live `WaitToken`s (incremented by `token`/`clone`,
    /// decremented on drop).
    token_count: AtomicUsize,
    /// `true` while a `wait` is in progress; used to reject a second,
    /// concurrent `wait` on the same group.
    is_waiting: AtomicBool,
}
/// Lets a single task wait until a dynamic set of `WaitToken`s has been
/// dropped.
///
/// Hand out tokens with `WaitGroup::token`, then `wait().await` until all of
/// them (and all of their clones) have been released.
// `Debug` is derived so this public type can be printed; the shared inner
// state already derives `Debug`.
#[derive(Default, Debug)]
pub struct WaitGroup {
    inner: Arc<WaitGroupInner>,
}
impl WaitGroup {
    /// Registers a new participant and returns the token representing it.
    ///
    /// The group's count stays above zero until the returned token, and every
    /// clone made from it, has been dropped.
    pub fn token(&self) -> WaitToken {
        self.inner.token_count.fetch_add(1, Ordering::Relaxed);
        WaitToken {
            inner: Arc::clone(&self.inner),
        }
    }

    /// Waits until every outstanding `WaitToken` has been dropped.
    ///
    /// Completes on its first poll if no tokens are outstanding at that point.
    ///
    /// # Panics
    ///
    /// Panics (when first polled) if another `wait` on this group is already
    /// in progress; only one waiter is supported at a time.
    pub async fn wait(&self) {
        let was_waiting = self.inner.is_waiting.swap(true, Ordering::Relaxed);
        // Checked before the future is constructed. If this panics, no
        // `WaitGroupFuture` exists, so its `Drop` cannot reset the
        // `is_waiting` flag that the in-progress waiter still owns.
        assert!(
            !was_waiting,
            "WaitGroup::wait called while another wait is already in progress"
        );
        WaitGroupFuture { inner: &self.inner }.await
    }
}
struct WaitGroupFuture<'a> {
inner: &'a Arc<WaitGroupInner>,
}
impl Future for WaitGroupFuture<'_> {
    type Output = ();

    /// Completes once no `WaitToken`s remain; otherwise registers the current
    /// waker and stays pending.
    ///
    /// The unlocked check is a fast path that skips the mutex when the count
    /// is already zero. The re-check under the lock closes the window in which
    /// the last token drops between the first check and storing the waker: a
    /// `WaitToken` must take this same lock before it can take the waker.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = self.inner;
        let drained = || inner.token_count.load(Ordering::Acquire) == 0;

        if drained() {
            return Poll::Ready(());
        }

        let mut slot = inner.waker.lock();
        if drained() {
            return Poll::Ready(());
        }

        *slot = Some(cx.waker().clone());
        Poll::Pending
    }
}
impl Drop for WaitGroupFuture<'_> {
    /// Ends this wait: forgets any registered waker and frees the group for
    /// the next `wait` call.
    fn drop(&mut self) {
        // A wait dropped while still pending (e.g. cancelled) would otherwise
        // leave its waker in the shared slot. That keeps the waker alive and
        // lets the last `WaitToken` wake a task that is no longer waiting.
        // The slot is cleared *before* `is_waiting` is released, so a later
        // `wait` cannot register its waker first and have it wiped here.
        let stale_waker = self.inner.waker.lock().take();
        self.inner.is_waiting.store(false, Ordering::Relaxed);
        // The guard was a temporary of the statement above, so this drop runs
        // outside the lock (a waker's destructor is executor-defined code).
        drop(stale_waker);
    }
}
/// Handle that keeps its `WaitGroup`'s token count above zero while alive.
///
/// Obtained from `WaitGroup::token` or by cloning another token. Each live
/// token accounts for one unit of the shared count, which is given back when
/// the token is dropped.
#[derive(Debug)]
pub struct WaitToken {
    /// State shared with the originating `WaitGroup`.
    inner: Arc<WaitGroupInner>,
}
impl Clone for WaitToken {
fn clone(&self) -> Self {
self.inner.token_count.fetch_add(1, Ordering::Relaxed);
Self {
inner: self.inner.clone(),
}
}
}
impl Drop for WaitToken {
    /// Releases this token; the token that brings the count to zero wakes the
    /// registered waiter, if there is one.
    fn drop(&mut self) {
        // Release pairs with the Acquire loads in `WaitGroupFuture::poll`.
        if self.inner.token_count.fetch_sub(1, Ordering::Release) != 1 {
            return;
        }
        // Take the waker in its own statement so the mutex guard is dropped
        // before `wake()` runs. As an `if let` scrutinee the guard would live
        // through the whole block. `wake()` would then execute executor code
        // with the lock held, the woken task would contend on this lock, and
        // an executor that polls inline could deadlock re-locking it.
        let waker = self.inner.waker.lock().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}