use crate::sync::{Condvar, Mutex};
use std::fmt;
use std::sync::Arc;
/// Synchronization handle for waiting until a group of concurrent tasks has
/// finished.
///
/// Each handle counts toward a shared total: the one returned by
/// `WaitGroup::new`/`Default` plus every clone. Dropping a handle decrements
/// the total. `WaitGroup::wait` releases the caller's own handle and then
/// blocks until the total reaches zero, i.e. until every other handle is gone.
pub struct WaitGroup {
/// Shared counter and condition variable, common to every clone of this handle.
inner: Arc<Inner>,
}
/// State shared by all clones of one `WaitGroup`.
struct Inner {
/// Signalled with `notify_all` when `count` drops to zero; waiters in
/// `WaitGroup::wait` block on it while `count > 0`.
cvar: Condvar,
/// Number of live `WaitGroup` handles: starts at 1, incremented by `clone`,
/// decremented by `drop`.
count: Mutex<usize>,
}
impl Default for WaitGroup {
fn default() -> Self {
Self {
inner: Arc::new(Inner {
cvar: Condvar::new(),
count: Mutex::new(1),
}),
}
}
}
impl WaitGroup {
pub fn new() -> Self {
Self::default()
}
pub fn wait(self) {
if *self.inner.count.lock().unwrap() == 1 {
return;
}
let inner = self.inner.clone();
drop(self);
let mut count = inner.count.lock().unwrap();
while *count > 0 {
count = inner.cvar.wait(count).unwrap();
}
}
}
impl Drop for WaitGroup {
    /// Releases this handle's share of the group count. If it was the last
    /// handle, wakes every waiter blocked on the condvar in `WaitGroup::wait`.
    fn drop(&mut self) {
        let mut count = self.inner.count.lock().unwrap();
        *count -= 1;
        let was_last = *count == 0;
        // Unlock before notifying. Woken waiters must re-acquire this mutex to
        // re-check `count`, so notifying while still holding it only makes them
        // block again immediately. `self.inner` keeps the condvar alive, so
        // notifying after the unlock is sound, and no wake-up is lost: the
        // count was updated under the lock.
        drop(count);
        if was_last {
            self.inner.cvar.notify_all();
        }
    }
}
impl Clone for WaitGroup {
fn clone(&self) -> Self {
let mut count = self.inner.count.lock().unwrap();
*count += 1;
Self {
inner: self.inner.clone(),
}
}
}
impl fmt::Debug for WaitGroup {
    /// Formats as `WaitGroup { count: N }`, where `N` is the number of live
    /// handles observed while holding the counter's lock.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let guard = self.inner.count.lock().unwrap();
        let mut builder = f.debug_struct("WaitGroup");
        builder.field("count", &*guard);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::WaitGroup;
    use crate::sync::mpsc;
    use std::time::Duration;

    /// Number of coroutines spawned by each test.
    const THREADS: usize = 10;

    /// Coroutines waiting on clones of a group stay blocked while the original
    /// handle is alive, and are all released once it waits too.
    #[test]
    fn wait() {
        let group = WaitGroup::new();
        let (done_tx, done_rx) = mpsc::channel();

        for _ in 0..THREADS {
            let member = group.clone();
            let done_tx = done_tx.clone();
            go!(move || {
                member.wait();
                done_tx.send(()).unwrap();
            });
        }

        // Give the coroutines time to run; none may get past `wait` yet
        // because `group` is still held here.
        crate::coroutine::sleep(Duration::from_millis(100));
        assert!(done_rx.try_recv().is_err());

        group.wait();
        (0..THREADS).for_each(|_| done_rx.recv().unwrap());
    }

    /// Workers gated on `gate` send a signal and then drop their `group`
    /// handle, so once `group.wait()` returns every signal is already queued.
    #[test]
    fn wait_and_drop() {
        let group = WaitGroup::new();
        let gate = WaitGroup::new();
        let (done_tx, done_rx) = mpsc::channel();

        for _ in 0..THREADS {
            let member = group.clone();
            let gate_member = gate.clone();
            let done_tx = done_tx.clone();
            go!(move || {
                gate_member.wait();
                done_tx.send(()).unwrap();
                drop(member);
            });
        }

        assert!(done_rx.try_recv().is_err());
        // Releasing the last outside handle on `gate` lets the workers proceed.
        drop(gate);
        group.wait();
        (0..THREADS).for_each(|_| done_rx.try_recv().unwrap());
    }
}