1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::Arc;
3
4pub struct WaitGroup {
7 pub total: Arc<AtomicU64>,
8 pub recv: Arc<flume::Receiver<u64>>,
9 pub send: Arc<flume::Sender<u64>>,
10}
11
12impl Clone for WaitGroup {
13 fn clone(&self) -> Self {
14 self.add(1);
15 Self {
16 total: self.total.clone(),
17 recv: self.recv.clone(),
18 send: self.send.clone(),
19 }
20 }
21}
22
23impl WaitGroup {
24 pub fn new() -> Self {
25 let (s, r) = flume::unbounded();
26 Self {
27 total: Arc::new(AtomicU64::new(0)),
28 recv: Arc::new(r),
29 send: Arc::new(s),
30 }
31 }
32
33 pub fn add(&self, v: u64) {
34 let current = self.total.fetch_or(0, Ordering::SeqCst);
35 self.total.store(current + v, Ordering::SeqCst);
36 }
37
38 #[cfg(feature = "async")]
39 pub async fn wait_async(&self) {
40 let mut total;
41 let mut current = 0;
42 loop {
43 match self.recv.recv_async().await {
44 Ok(v) => {
45 current += v;
46 total = self.total.fetch_or(0, Ordering::SeqCst);
47 if current >= total {
48 break;
49 }
50 }
51 Err(_) => {
52 break;
53 }
54 }
55 }
56 }
57
58 pub fn wait(&self) {
59 let mut total;
60 let mut current = 0;
61 loop {
62 match self.recv.recv() {
63 Ok(v) => {
64 current += v;
65 total = self.total.fetch_or(0, Ordering::SeqCst);
66 if current >= total {
67 break;
68 }
69 }
70 Err(_) => {
71 break;
72 }
73 }
74 }
75 }
76}
77
78impl Drop for WaitGroup {
79 fn drop(&mut self) {
80 let _ = self.send.send(1);
81 }
82}