use std::sync::{Arc, Condvar, Mutex};
/// A counter-based barrier that lets threads block until a set of tasks has
/// completed.
///
/// [`WaitGroup::add`] adjusts the number of outstanding tasks,
/// [`WaitGroup::done`] marks one task as finished, and [`WaitGroup::wait`]
/// blocks while the number of outstanding tasks is greater than zero.
///
/// Cloning a `WaitGroup` yields another handle to the *same* counter, because
/// the state lives behind an [`Arc`].
#[derive(Clone, Debug)]
pub struct WaitGroup {
    /// Shared state: the outstanding-task count and the condition variable
    /// used to wake waiters when that count reaches zero.
    counter: Arc<(Mutex<i64>, Condvar)>,
}

impl WaitGroup {
    /// Creates a wait group whose counter starts at zero.
    pub fn new() -> Self {
        WaitGroup {
            counter: Arc::new((Mutex::new(0), Condvar::new())),
        }
    }

    /// Adds `delta` (which may be negative) to the counter and wakes every
    /// waiter if the counter becomes zero.
    ///
    /// # Panics
    ///
    /// Panics with `"negative WaitGroup counter"` if the result would be
    /// negative, and with `"WaitGroup counter overflow"` if the addition would
    /// overflow `i64`. In both cases the counter is left unchanged and the
    /// internal lock is released before panicking, so other handles to the
    /// same group remain usable.
    pub fn add(&self, delta: i64) {
        let (lock, cvar) = &*self.counter;

        let rejected = {
            // This type never panics while holding the lock, so a poisoned
            // state carries no broken invariant; recover the guard instead of
            // cascading the panic.
            let mut count = lock
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);

            // Validate before storing so a rejected call has no side effects.
            let next = (*count).checked_add(delta);
            if let Some(value) = next {
                if value >= 0 {
                    *count = value;
                    if value == 0 {
                        cvar.notify_all();
                    }
                    return;
                }
            }
            next
        }; // The guard is dropped here, before any panic, to avoid poisoning.

        match rejected {
            Some(_) => panic!("negative WaitGroup counter"),
            None => panic!("WaitGroup counter overflow"),
        }
    }

    /// Marks one task as finished; shorthand for `add(-1)`.
    ///
    /// # Panics
    ///
    /// Panics if the counter is already zero (see [`WaitGroup::add`]).
    pub fn done(&self) {
        self.add(-1);
    }

    /// Blocks the calling thread until the counter is zero.
    ///
    /// Returns immediately if the counter is already zero.
    pub fn wait(&self) {
        let (lock, cvar) = &*self.counter;
        let guard = lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // `wait_while` re-checks the predicate after every wakeup, which also
        // covers spurious wakeups.
        drop(
            cvar.wait_while(guard, |pending| *pending > 0)
                .unwrap_or_else(std::sync::PoisonError::into_inner),
        );
    }
}
impl Default for WaitGroup {
fn default() -> Self {
WaitGroup::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicI64, Ordering};
    use std::thread;

    /// Spawns a batch of workers, each registered with the wait group, and
    /// checks that `wait` only returns once every worker has undone its
    /// increment of the shared tally.
    #[test]
    fn it_works() {
        let outstanding = Arc::new(AtomicI64::new(0));
        let group = WaitGroup::default();
        let workers = 10;

        (0..workers).for_each(|_| {
            // Register the task on the spawning thread, before the worker
            // starts, so `wait` cannot observe a zero counter too early.
            let handle = group.clone();
            handle.add(1);

            let tally = Arc::clone(&outstanding);
            tally.fetch_add(1, Ordering::Relaxed);

            thread::spawn(move || {
                tally.fetch_sub(1, Ordering::Relaxed);
                handle.done();
            });
        });

        group.wait();
        assert_eq!(outstanding.load(Ordering::Relaxed), 0);
    }
}