radiate_core/domain/sync/
group.rs1use std::sync::{
2 Arc, Condvar, Mutex,
3 atomic::{AtomicUsize, Ordering},
4};
5
6#[derive(Clone)]
7pub struct WaitGroup {
8 inner: Arc<Inner>,
9 total_count: Arc<AtomicUsize>,
10}
11
/// Synchronisation state shared by every clone of a wait group and by the
/// guards it hands out.
struct Inner {
    /// Number of guards that are currently alive (created but not yet dropped).
    counter: AtomicUsize,
    /// Mutex paired with `cvar`; it protects no data and exists only so the
    /// condition variable has a lock to wait on.
    lock: Mutex<()>,
    /// Notified when the last live guard is dropped.
    cvar: Condvar,
}
17
18impl WaitGroup {
19 pub fn new() -> Self {
20 Self {
21 inner: Arc::new(Inner {
22 counter: AtomicUsize::new(0),
23 lock: Mutex::new(()),
24 cvar: Condvar::new(),
25 }),
26 total_count: Arc::new(AtomicUsize::new(0)),
27 }
28 }
29
30 pub fn get_count(&self) -> usize {
31 self.total_count.load(Ordering::Acquire)
32 }
33
34 pub fn guard(&self) -> WaitGuard {
35 self.inner.counter.fetch_add(1, Ordering::AcqRel);
36 self.total_count.fetch_add(1, Ordering::AcqRel);
37
38 WaitGuard { wg: self.clone() }
39 }
40
41 pub fn wait(&self) -> usize {
43 if self.inner.counter.load(Ordering::Acquire) == 0 {
44 return 0;
45 }
46
47 let lock = self.inner.lock.lock().unwrap();
48 let _unused = self
49 .inner
50 .cvar
51 .wait_while(lock, |_| self.inner.counter.load(Ordering::Acquire) != 0);
52
53 self.get_count()
54 }
55}
56
57pub struct WaitGuard {
58 wg: WaitGroup,
59}
60
61impl Drop for WaitGuard {
62 fn drop(&mut self) {
63 if self.wg.inner.counter.fetch_sub(1, Ordering::AcqRel) == 1 {
64 let _guard = self.wg.inner.lock.lock().unwrap();
65 self.wg.inner.cvar.notify_all();
66 }
67 }
68}