// reifydb_core/util/wait_group.rs

use std::sync::{Arc, Condvar, Mutex};
/// Shared state behind a [`WaitGroup`]: the counter of outstanding
/// tasks guarded by a mutex, plus a condvar used to wake blocked
/// waiters once the counter reaches zero.
struct Inner {
    /// Number of registered tasks that have not yet called `done`.
    count: Mutex<usize>,
    /// Notified (`notify_all`) when `count` drops to zero.
    cvar: Condvar,
}
18
/// A synchronization primitive that lets threads wait until a set of
/// tasks, tracked by a shared counter, have all completed.
///
/// Cloning (or calling [`WaitGroup::add`]) hands out additional
/// handles that share the same counter through an `Arc`.
pub struct WaitGroup {
    // Shared counter + condvar; every clone points at the same `Inner`.
    inner: Arc<Inner>,
}
22
impl Default for WaitGroup {
    /// Equivalent to [`WaitGroup::new`]: a wait group whose counter
    /// starts at zero.
    fn default() -> Self {
        Self::new()
    }
}
28
29impl From<usize> for WaitGroup {
30 fn from(count: usize) -> Self {
31 Self {
32 inner: Arc::new(Inner {
33 cvar: Condvar::new(),
34 count: Mutex::new(count),
35 }),
36 }
37 }
38}
39
40impl Clone for WaitGroup {
41 fn clone(&self) -> Self {
42 Self {
43 inner: self.inner.clone(),
44 }
45 }
46}
47
48impl std::fmt::Debug for WaitGroup {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 let count = self.inner.count.lock().unwrap();
51 f.debug_struct("WaitGroup").field("count", &*count).finish()
52 }
53}
54
55impl WaitGroup {
56 pub fn new() -> Self {
57 Self {
58 inner: Arc::new(Inner {
59 cvar: Condvar::new(),
60 count: Mutex::new(0),
61 }),
62 }
63 }
64
65 pub fn add(&self, num: usize) -> Self {
66 let mut ctr = self.inner.count.lock().unwrap();
67 *ctr += num;
68 Self {
69 inner: self.inner.clone(),
70 }
71 }
72
73 pub fn done(&self) -> usize {
74 let mut val = self.inner.count.lock().unwrap();
75
76 *val = if val.eq(&1) {
77 self.inner.cvar.notify_all();
78 0
79 } else if val.eq(&0) {
80 0
81 } else {
82 *val - 1
83 };
84 *val
85 }
86
87 pub fn waitings(&self) -> usize {
88 *self.inner.count.lock().unwrap()
89 }
90
91 pub fn wait(&self) {
92 let mut ctr = self.inner.count.lock().unwrap();
93
94 if ctr.eq(&0) {
95 return;
96 }
97
98 while *ctr > 0 {
99 ctr = self.inner.cvar.wait(ctr).unwrap();
100 }
101 }
102}
103
104#[cfg(test)]
105mod tests {
106 use std::{
107 sync::{
108 Arc,
109 atomic::{AtomicUsize, Ordering},
110 },
111 thread::{sleep, spawn},
112 time::Duration,
113 };
114
115 use crate::util::WaitGroup;
116
117 #[test]
118 fn test_sync_wait_group_reuse() {
119 let wg = WaitGroup::new();
120 let ctr = Arc::new(AtomicUsize::new(0));
121 for _ in 0..6 {
122 let wg = wg.add(1);
123 let ctrx = ctr.clone();
124 spawn(move || {
125 sleep(Duration::from_millis(5));
126 ctrx.fetch_add(1, Ordering::Relaxed);
127 wg.done();
128 });
129 }
130
131 wg.wait();
132 assert_eq!(ctr.load(Ordering::Relaxed), 6);
133
134 let worker = wg.add(1);
135 let ctrx = ctr.clone();
136 spawn(move || {
137 sleep(Duration::from_millis(5));
138 ctrx.fetch_add(1, Ordering::Relaxed);
139 worker.done();
140 });
141 wg.wait();
142 assert_eq!(ctr.load(Ordering::Relaxed), 7);
143 }
144
145 #[test]
146 fn test_sync_wait_group_nested() {
147 let wg = WaitGroup::new();
148 let ctr = Arc::new(AtomicUsize::new(0));
149 for _ in 0..5 {
150 let worker = wg.add(1);
151 let ctrx = ctr.clone();
152 spawn(move || {
153 let nested_worker = worker.add(1);
154 let ctrxx = ctrx.clone();
155 spawn(move || {
156 ctrxx.fetch_add(1, Ordering::Relaxed);
157 nested_worker.done();
158 });
159 ctrx.fetch_add(1, Ordering::Relaxed);
160 worker.done();
161 });
162 }
163
164 wg.wait();
165 assert_eq!(ctr.load(Ordering::Relaxed), 10);
166 }
167
168 #[test]
169 fn test_sync_wait_group_from() {
170 std::thread::scope(|s| {
171 let wg = WaitGroup::from(5);
172 for _ in 0..5 {
173 let t = wg.clone();
174 s.spawn(move || {
175 t.done();
176 });
177 }
178 wg.wait();
179 });
180 }
181
182 #[test]
183 fn test_clone_and_fmt() {
184 let swg = WaitGroup::new();
185 let swg1 = swg.clone();
186 swg1.add(3);
187 assert_eq!(format!("{:?}", swg), format!("{:?}", swg1));
188 }
189
190 #[test]
191 fn test_waitings() {
192 let wg = WaitGroup::new();
193 wg.add(1);
194 wg.add(1);
195 assert_eq!(wg.waitings(), 2);
196 }
197}