use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::Notify;
/// A counter of outstanding work items that async tasks can wait on.
///
/// Both fields live behind `Arc`, so every clone of a `WaitGroup` observes and
/// updates the same counter and shares the same notifier.
#[derive(Clone, Debug)]
pub(crate) struct WaitGroup {
    /// Number of outstanding items; raised by `add` and lowered by `done`.
    count: Arc<AtomicUsize>,
    /// Signalled by `done` when it lowers the count from one to zero.
    notify_on_zero: Arc<Notify>,
}
impl WaitGroup {
    /// Creates a wait group whose counter starts at zero.
    pub fn new() -> Self {
        Self {
            count: Arc::new(AtomicUsize::new(0)),
            notify_on_zero: Arc::new(Notify::new()),
        }
    }

    /// Raises the counter by `delta`; a `delta` of zero is a no-op.
    ///
    /// The addition is not overflow-checked: the counter wraps if the total
    /// would exceed `usize::MAX`.
    pub fn add(&self, delta: usize) {
        if delta == 0 {
            return;
        }
        let old_count = self.count.fetch_add(delta, Ordering::Relaxed);
        if old_count == 0 {
            tracing::trace!(delta, new_count = delta, "WaitGroup count increased from zero");
        }
    }

    /// Lowers the counter by one and wakes all waiters when it reaches zero.
    ///
    /// # Panics
    ///
    /// Panics if the counter is already zero. In that case the counter is left
    /// untouched.
    pub fn done(&self) {
        // A checked decrement in a single atomic update: the counter never
        // transiently wraps to `usize::MAX`, so concurrent `done()`/`wait()`
        // callers cannot observe a bogus value while misuse is being reported.
        let previous = self.count.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |current| current.checked_sub(1),
        );
        match previous {
            Ok(1) => {
                self.notify_on_zero.notify_waiters();
                tracing::trace!("WaitGroup count reached zero, notifying waiters");
            }
            Ok(_) => {}
            Err(_) => panic!("WaitGroup::done() called when count was already zero!"),
        }
    }

    /// Waits until the counter is zero; returns immediately if it already is.
    ///
    /// If the counter is raised again before a woken waiter re-reads it, the
    /// waiter goes back to waiting for the next transition to zero.
    pub async fn wait(&self) {
        if self.count.load(Ordering::Acquire) == 0 {
            tracing::trace!("WaitGroup::wait() called when count is already zero");
            return;
        }
        loop {
            // Create the `Notified` future BEFORE re-reading the counter.
            // `notify_waiters()` stores no permit; it only reaches futures that
            // already exist. Checking first and registering afterwards would
            // let a `done()` that hits zero in between go unnoticed, parking
            // this task forever.
            let notified = self.notify_on_zero.notified();
            if self.count.load(Ordering::Acquire) == 0 {
                tracing::trace!("WaitGroup::wait() observed zero count before parking");
                return;
            }
            notified.await;
            if self.count.load(Ordering::Acquire) == 0 {
                tracing::trace!("WaitGroup::wait() released after notification");
                return;
            }
            tracing::trace!("WaitGroup::wait() woke, but count is non-zero; re-waiting");
        }
    }

    /// Returns a snapshot of the counter; it may be stale by the time the
    /// caller uses it.
    // Only the unit tests in this file call this, so non-test builds would
    // otherwise warn about dead code.
    #[allow(dead_code)]
    pub fn get_count(&self) -> usize {
        self.count.load(Ordering::Relaxed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Two units of work are registered; the waiter must stay blocked until
    /// both have been marked done.
    #[tokio::test]
    async fn test_waitgroup_add_done_wait() {
        let wg = WaitGroup::new();
        assert_eq!(wg.get_count(), 0);
        wg.add(2);
        assert_eq!(wg.get_count(), 2);

        // Task 1 finishes on its own after a short delay.
        let wg_clone1 = wg.clone();
        let task1 = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            wg_clone1.done();
            "Task 1 Done"
        });

        // Task 2 finishes only once the test releases it through `notify`.
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let wg_clone2 = wg.clone();
        let task2 = tokio::spawn(async move {
            notify_clone.notified().await;
            wg_clone2.done();
            "Task 2 Done"
        });

        let wg_clone_wait = wg.clone();
        let mut wait_task = tokio::spawn(async move {
            wg_clone_wait.wait().await;
            "Wait Finished"
        });

        // After 15 ms task 1 has called done(), task 2 is still parked.
        tokio::time::sleep(Duration::from_millis(15)).await;
        assert_eq!(wg.get_count(), 1);
        assert!(
            timeout(Duration::from_millis(5), &mut wait_task).await.is_err(),
            "Wait task should not have finished yet"
        );

        // Release task 2; the count should now drop to zero.
        notify.notify_one();
        let task1_res = task1.await.unwrap();
        let task2_res = task2.await.unwrap();
        let wait_res = timeout(Duration::from_millis(50), wait_task).await;

        assert_eq!(task1_res, "Task 1 Done");
        assert_eq!(task2_res, "Task 2 Done");
        assert_eq!(wg.get_count(), 0);
        assert!(wait_res.is_ok(), "Wait task should finish after task 2 completes");
        assert_eq!(wait_res.unwrap().unwrap(), "Wait Finished");
    }

    /// Waiting on a fresh (zero-count) group must return without blocking.
    #[tokio::test]
    async fn test_waitgroup_wait_on_zero() {
        let wg = WaitGroup::new();
        let start = tokio::time::Instant::now();
        wg.wait().await;
        assert!(start.elapsed() < Duration::from_millis(10));
        assert_eq!(wg.get_count(), 0);
    }

    /// Work added while a waiter is already parked must also be completed
    /// before the waiter is released.
    #[tokio::test]
    async fn test_waitgroup_add_after_wait_starts() {
        let wg = WaitGroup::new();
        wg.add(1);

        let wg_clone_wait = wg.clone();
        let mut wait_task = tokio::spawn(async move {
            wg_clone_wait.wait().await;
        });

        // Give the waiter time to start waiting before adding more work.
        tokio::time::sleep(Duration::from_millis(10)).await;
        wg.add(1);
        assert_eq!(wg.get_count(), 2);

        wg.done();
        assert_eq!(wg.get_count(), 1);
        assert!(
            timeout(Duration::from_millis(10), &mut wait_task).await.is_err(),
            "Wait task should still be blocked after one done()"
        );

        wg.done();
        assert_eq!(wg.get_count(), 0);
        assert!(
            timeout(Duration::from_millis(50), wait_task).await.is_ok(),
            "Wait task should complete after second done()"
        );
    }

    /// Calling done() with nothing outstanding is a caller bug and must panic
    /// with the dedicated misuse message, not merely with any panic.
    #[tokio::test]
    #[should_panic(expected = "called when count was already zero")]
    async fn test_waitgroup_done_panic_on_zero() {
        let wg = WaitGroup::new();
        wg.done();
    }
}