Skip to main content

radiate_core/domain/sync/
group.rs

1use std::sync::{
2    Arc, Condvar, Mutex,
3    atomic::{AtomicUsize, Ordering},
4};
5
6#[derive(Clone)]
7pub struct WaitGroup {
8    inner: Arc<Inner>,
9    total_count: Arc<AtomicUsize>,
10}
11
/// Shared state behind every clone of a `WaitGroup` and every `WaitGuard`
/// it hands out.
struct Inner {
    /// Mutex paired with `cvar`. It protects no data (`()`); it exists only
    /// so that waiters can block on the condition variable.
    lock: Mutex<()>,
    /// Notified by the guard whose drop brings `counter` to zero.
    cvar: Condvar,
    /// Number of guards currently alive: incremented by `WaitGroup::guard`,
    /// decremented when a `WaitGuard` is dropped.
    counter: AtomicUsize,
}
17
18impl Default for WaitGroup {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl WaitGroup {
25    pub fn new() -> Self {
26        Self {
27            inner: Arc::new(Inner {
28                counter: AtomicUsize::new(0),
29                lock: Mutex::new(()),
30                cvar: Condvar::new(),
31            }),
32            total_count: Arc::new(AtomicUsize::new(0)),
33        }
34    }
35
36    pub fn get_count(&self) -> usize {
37        self.total_count.load(Ordering::Acquire)
38    }
39
40    pub fn guard(&self) -> WaitGuard {
41        self.inner.counter.fetch_add(1, Ordering::AcqRel);
42        self.total_count.fetch_add(1, Ordering::AcqRel);
43
44        WaitGuard { wg: self.clone() }
45    }
46
47    /// Waits until the counter reaches zero.
48    pub fn wait(&self) -> usize {
49        if self.inner.counter.load(Ordering::Acquire) == 0 {
50            return 0;
51        }
52
53        let lock = self.inner.lock.lock().unwrap();
54        let _unused = self
55            .inner
56            .cvar
57            .wait_while(lock, |_| self.inner.counter.load(Ordering::Acquire) != 0);
58
59        self.get_count()
60    }
61}
62
63pub struct WaitGuard {
64    wg: WaitGroup,
65}
66
67impl Drop for WaitGuard {
68    fn drop(&mut self) {
69        if self.wg.inner.counter.fetch_sub(1, Ordering::AcqRel) == 1 {
70            let _guard = self.wg.inner.lock.lock().unwrap();
71            self.wg.inner.cvar.notify_all();
72        }
73    }
74}