1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::Arc;
3
4pub struct WaitGroup {
6 pub total: Arc<AtomicU64>,
7 pub recv: Arc<flume::Receiver<u64>>,
8 pub send: Arc<flume::Sender<u64>>,
9}
10
11impl Clone for WaitGroup {
12 fn clone(&self) -> Self {
13 self.add(1);
14 Self {
15 total: self.total.clone(),
16 recv: self.recv.clone(),
17 send: self.send.clone(),
18 }
19 }
20}
21
22impl WaitGroup {
23 pub fn new() -> Self {
24 let (s, r) = flume::unbounded();
25 Self {
26 total: Arc::new(AtomicU64::new(0)),
27 recv: Arc::new(r),
28 send: Arc::new(s),
29 }
30 }
31
32 pub fn add(&self, v: u64) {
33 let current = self.total.fetch_or(0, Ordering::SeqCst);
34 self.total.store(current + v, Ordering::SeqCst);
35 }
36
37 #[cfg(feature = "async")]
38 pub async fn wait_async(&self) {
39 let mut total;
40 let mut current = 0;
41 loop {
42 match self.recv.recv_async().await {
43 Ok(v) => {
44 current += v;
45 total = self.total.fetch_or(0, Ordering::SeqCst);
46 if current >= total {
47 break;
48 }
49 }
50 Err(_) => {
51 break;
52 }
53 }
54 }
55 }
56
57 pub fn wait(&self) {
58 let mut total;
59 let mut current = 0;
60 loop {
61 match self.recv.recv() {
62 Ok(v) => {
63 current += v;
64 total = self.total.fetch_or(0, Ordering::SeqCst);
65 if current >= total {
66 break;
67 }
68 }
69 Err(_) => {
70 break;
71 }
72 }
73 }
74 }
75}
76
77impl Drop for WaitGroup {
78 fn drop(&mut self) {
79 let _ = self.send.send(1);
80 }
81}