// reifydb_core/util/wait_group.rs

use std::sync::{
	Arc,
	atomic::{AtomicUsize, Ordering},
};

use reifydb_runtime::sync::{condvar::Condvar, mutex::Mutex};
/// Shared state behind every handle of a `WaitGroup`.
///
/// `count` is the number of outstanding tasks; `mutex` and `condvar`
/// exist only to park and wake threads blocked in `WaitGroup::wait` —
/// the mutex guards no data of its own.
struct Inner {
	// Outstanding task count; `wait` returns once this reaches 0.
	count: AtomicUsize,
	// Paired with `condvar`; held only around notify/wait.
	mutex: Mutex<()>,
	// Signalled (notify_all) when `count` drops to zero.
	condvar: Condvar,
}
24
/// A synchronization primitive in the style of Go's `sync.WaitGroup`:
/// register pending tasks with `add`, report completion with `done`,
/// and block in `wait` until the outstanding count reaches zero.
///
/// Cloning is cheap — all clones share the same counter through `Arc`.
pub struct WaitGroup {
	inner: Arc<Inner>,
}
28
29impl Default for WaitGroup {
30 fn default() -> Self {
31 Self::new()
32 }
33}
34
35impl From<usize> for WaitGroup {
36 fn from(count: usize) -> Self {
37 Self {
38 inner: Arc::new(Inner {
39 count: AtomicUsize::new(count),
40 mutex: Mutex::new(()),
41 condvar: Condvar::new(),
42 }),
43 }
44 }
45}
46
47impl Clone for WaitGroup {
48 fn clone(&self) -> Self {
49 Self {
50 inner: self.inner.clone(),
51 }
52 }
53}
54
55impl std::fmt::Debug for WaitGroup {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 let count = self.inner.count.load(Ordering::Acquire);
58 f.debug_struct("WaitGroup").field("count", &count).finish()
59 }
60}
61
62impl WaitGroup {
63 pub fn new() -> Self {
64 Self {
65 inner: Arc::new(Inner {
66 count: AtomicUsize::new(0),
67 mutex: Mutex::new(()),
68 condvar: Condvar::new(),
69 }),
70 }
71 }
72
73 pub fn add(&self, num: usize) -> Self {
74 self.inner.count.fetch_add(num, Ordering::AcqRel);
75 Self {
76 inner: self.inner.clone(),
77 }
78 }
79
80 pub fn done(&self) -> usize {
81 let prev = self.inner.count.fetch_sub(1, Ordering::AcqRel);
82 if prev == 1 {
83 let _guard = self.inner.mutex.lock();
84 self.inner.condvar.notify_all();
85 }
86 if prev == 0 {
87 self.inner.count.fetch_add(1, Ordering::AcqRel);
89 return 0;
90 }
91 prev - 1
92 }
93
94 pub fn waitings(&self) -> usize {
95 self.inner.count.load(Ordering::Acquire)
96 }
97
98 pub fn wait(&self) {
99 let mut guard = self.inner.mutex.lock();
100 while self.inner.count.load(Ordering::Acquire) > 0 {
101 self.inner.condvar.wait(&mut guard);
102 }
103 }
104}
105
#[cfg(test)]
pub mod tests {
	use std::{
		sync::{
			Arc,
			atomic::{AtomicUsize, Ordering},
		},
		time::Duration,
	};

	use crate::util::wait_group::WaitGroup;

	/// A group can be waited on, then reused for a second round of work.
	#[test]
	fn test_wait_group_reuse() {
		let wg = WaitGroup::new();
		let counter = Arc::new(AtomicUsize::new(0));
		for _ in 0..6 {
			let handle = wg.add(1);
			let worker_counter = counter.clone();
			std::thread::spawn(move || {
				std::thread::sleep(Duration::from_millis(5));
				worker_counter.fetch_add(1, Ordering::Relaxed);
				handle.done();
			});
		}

		wg.wait();
		assert_eq!(counter.load(Ordering::Relaxed), 6);

		// Second round on the same group after it drained to zero.
		let handle = wg.add(1);
		let worker_counter = counter.clone();
		std::thread::spawn(move || {
			std::thread::sleep(Duration::from_millis(5));
			worker_counter.fetch_add(1, Ordering::Relaxed);
			handle.done();
		});
		wg.wait();
		assert_eq!(counter.load(Ordering::Relaxed), 7);
	}

	/// Workers may themselves register nested workers before finishing.
	#[test]
	fn test_wait_group_nested() {
		let wg = WaitGroup::new();
		let counter = Arc::new(AtomicUsize::new(0));
		for _ in 0..5 {
			let outer = wg.add(1);
			let outer_counter = counter.clone();
			std::thread::spawn(move || {
				let inner = outer.add(1);
				let inner_counter = outer_counter.clone();
				std::thread::spawn(move || {
					inner_counter.fetch_add(1, Ordering::Relaxed);
					inner.done();
				});
				outer_counter.fetch_add(1, Ordering::Relaxed);
				outer.done();
			});
		}

		wg.wait();
		assert_eq!(counter.load(Ordering::Relaxed), 10);
	}

	/// `From<usize>` pre-seeds the counter without explicit `add` calls.
	#[test]
	fn test_wait_group_from() {
		let wg = WaitGroup::from(5);
		for _ in 0..5 {
			let worker = wg.clone();
			std::thread::spawn(move || {
				worker.done();
			});
		}
		wg.wait();
	}

	/// Clones share state, so both handles render the same debug output.
	#[test]
	fn test_clone_and_fmt() {
		let original = WaitGroup::new();
		let cloned = original.clone();
		cloned.add(3);
		assert_eq!(format!("{:?}", original), format!("{:?}", cloned));
	}

	/// `waitings` reports the number of registered, unfinished tasks.
	#[test]
	fn test_waitings() {
		let wg = WaitGroup::new();
		wg.add(1);
		wg.add(1);
		assert_eq!(wg.waitings(), 2);
	}
}