// radiate_core/domain/sync/group.rs

use std::sync::{
    Arc, Condvar, Mutex, PoisonError,
    atomic::{AtomicUsize, Ordering},
};

6#[derive(Clone)]
7pub struct WaitGroup {
8    inner: Arc<Inner>,
9    total_count: Arc<AtomicUsize>,
10}
11
/// State shared by every clone of a `WaitGroup`.
struct Inner {
    /// Number of guards that have been issued but not yet dropped.
    counter: std::sync::atomic::AtomicUsize,
    /// Carries no data; it exists only to pair with `cvar` for sleeping/waking.
    lock: std::sync::Mutex<()>,
    /// Notified by the guard whose drop brings `counter` to zero.
    cvar: std::sync::Condvar,
}

18impl WaitGroup {
19    pub fn new() -> Self {
20        Self {
21            inner: Arc::new(Inner {
22                counter: AtomicUsize::new(0),
23                lock: Mutex::new(()),
24                cvar: Condvar::new(),
25            }),
26            total_count: Arc::new(AtomicUsize::new(0)),
27        }
28    }
29
30    pub fn get_count(&self) -> usize {
31        self.total_count.load(Ordering::Acquire)
32    }
33
34    pub fn guard(&self) -> WaitGuard {
35        self.inner.counter.fetch_add(1, Ordering::AcqRel);
36        self.total_count.fetch_add(1, Ordering::AcqRel);
37
38        WaitGuard { wg: self.clone() }
39    }
40
41    /// Waits until the counter reaches zero.
42    pub fn wait(&self) -> usize {
43        if self.inner.counter.load(Ordering::Acquire) == 0 {
44            return 0;
45        }
46
47        let lock = self.inner.lock.lock().unwrap();
48        let _unused = self
49            .inner
50            .cvar
51            .wait_while(lock, |_| self.inner.counter.load(Ordering::Acquire) != 0);
52
53        self.get_count()
54    }
55}
56
57pub struct WaitGuard {
58    wg: WaitGroup,
59}
60
61impl Drop for WaitGuard {
62    fn drop(&mut self) {
63        if self.wg.inner.counter.fetch_sub(1, Ordering::AcqRel) == 1 {
64            let _guard = self.wg.inner.lock.lock().unwrap();
65            self.wg.inner.cvar.notify_all();
66        }
67    }
68}