reifydb_core/util/
wait_group.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the AGPL-3.0-or-later, see license.md file
3
4// This file includes and modifies code from the wg project (https://github.com/al8n/wg),
5// originally licensed under the Apache License, Version 2.0.
6// Original copyright:
7//   Copyright (c) 2024 Al Liu
8//
9// The original Apache License can be found at:
10//   http://www.apache.org/licenses/LICENSE-2.0
11
12use std::sync::{Arc, Condvar, Mutex};
13
/// State shared by every handle of one `WaitGroup`.
struct Inner {
	/// Number of outstanding units of work; raised by `add`, lowered by `done`.
	count: Mutex<usize>,
	/// Notified when `count` goes from one to zero, waking blocked `wait` callers.
	cvar: Condvar,
}
18
19pub struct WaitGroup {
20	inner: Arc<Inner>,
21}
22
23impl Default for WaitGroup {
24	fn default() -> Self {
25		Self::new()
26	}
27}
28
29impl From<usize> for WaitGroup {
30	fn from(count: usize) -> Self {
31		Self {
32			inner: Arc::new(Inner {
33				cvar: Condvar::new(),
34				count: Mutex::new(count),
35			}),
36		}
37	}
38}
39
40impl Clone for WaitGroup {
41	fn clone(&self) -> Self {
42		Self {
43			inner: self.inner.clone(),
44		}
45	}
46}
47
48impl std::fmt::Debug for WaitGroup {
49	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50		let count = self.inner.count.lock().unwrap();
51		f.debug_struct("WaitGroup").field("count", &*count).finish()
52	}
53}
54
55impl WaitGroup {
56	pub fn new() -> Self {
57		Self {
58			inner: Arc::new(Inner {
59				cvar: Condvar::new(),
60				count: Mutex::new(0),
61			}),
62		}
63	}
64
65	pub fn add(&self, num: usize) -> Self {
66		let mut ctr = self.inner.count.lock().unwrap();
67		*ctr += num;
68		Self {
69			inner: self.inner.clone(),
70		}
71	}
72
73	pub fn done(&self) -> usize {
74		let mut val = self.inner.count.lock().unwrap();
75
76		*val = if val.eq(&1) {
77			self.inner.cvar.notify_all();
78			0
79		} else if val.eq(&0) {
80			0
81		} else {
82			*val - 1
83		};
84		*val
85	}
86
87	pub fn waitings(&self) -> usize {
88		*self.inner.count.lock().unwrap()
89	}
90
91	pub fn wait(&self) {
92		let mut ctr = self.inner.count.lock().unwrap();
93
94		if ctr.eq(&0) {
95			return;
96		}
97
98		while *ctr > 0 {
99			ctr = self.inner.cvar.wait(ctr).unwrap();
100		}
101	}
102}
103
#[cfg(test)]
mod tests {
	use std::{
		sync::{
			Arc,
			atomic::{AtomicUsize, Ordering},
		},
		thread::{sleep, spawn},
		time::Duration,
	};

	// Import from the defining module directly rather than via a re-export path.
	use super::WaitGroup;

	#[test]
	fn test_sync_wait_group_reuse() {
		let wg = WaitGroup::new();
		let ctr = Arc::new(AtomicUsize::new(0));
		for _ in 0..6 {
			let wg = wg.add(1);
			let ctrx = ctr.clone();
			spawn(move || {
				sleep(Duration::from_millis(5));
				ctrx.fetch_add(1, Ordering::Relaxed);
				wg.done();
			});
		}

		wg.wait();
		assert_eq!(ctr.load(Ordering::Relaxed), 6);

		let worker = wg.add(1);
		let ctrx = ctr.clone();
		spawn(move || {
			sleep(Duration::from_millis(5));
			ctrx.fetch_add(1, Ordering::Relaxed);
			worker.done();
		});
		wg.wait();
		assert_eq!(ctr.load(Ordering::Relaxed), 7);
	}

	#[test]
	fn test_sync_wait_group_nested() {
		let wg = WaitGroup::new();
		let ctr = Arc::new(AtomicUsize::new(0));
		for _ in 0..5 {
			let worker = wg.add(1);
			let ctrx = ctr.clone();
			spawn(move || {
				let nested_worker = worker.add(1);
				let ctrxx = ctrx.clone();
				spawn(move || {
					ctrxx.fetch_add(1, Ordering::Relaxed);
					nested_worker.done();
				});
				ctrx.fetch_add(1, Ordering::Relaxed);
				worker.done();
			});
		}

		wg.wait();
		assert_eq!(ctr.load(Ordering::Relaxed), 10);
	}

	#[test]
	fn test_sync_wait_group_from() {
		std::thread::scope(|s| {
			let wg = WaitGroup::from(5);
			for _ in 0..5 {
				let t = wg.clone();
				s.spawn(move || {
					t.done();
				});
			}
			wg.wait();
		});
	}

	#[test]
	fn test_clone_and_fmt() {
		let swg = WaitGroup::new();
		let swg1 = swg.clone();
		swg1.add(3);
		assert_eq!(format!("{:?}", swg), format!("{:?}", swg1));
	}

	#[test]
	fn test_waitings() {
		let wg = WaitGroup::new();
		wg.add(1);
		wg.add(1);
		assert_eq!(wg.waitings(), 2);
	}

	/// `done` on an empty group must not underflow, and returns the new count.
	#[test]
	fn test_done_saturates_at_zero() {
		let wg = WaitGroup::new();
		assert_eq!(wg.done(), 0);
		assert_eq!(wg.waitings(), 0);

		let worker = wg.add(2);
		assert_eq!(worker.done(), 1);
		assert_eq!(worker.done(), 0);
		assert_eq!(worker.done(), 0);
		assert_eq!(wg.waitings(), 0);
		wg.wait();
	}

	/// Waiting on a group with a zero counter must return immediately.
	#[test]
	fn test_wait_on_empty_group_returns() {
		WaitGroup::new().wait();
		WaitGroup::default().wait();
		WaitGroup::from(0).wait();
	}

	/// `From` seeds the counter and `Debug` reports it.
	#[test]
	fn test_from_initial_count_and_debug() {
		let wg = WaitGroup::from(3);
		assert_eq!(wg.waitings(), 3);
		assert_eq!(format!("{:?}", wg), "WaitGroup { count: 3 }");
	}
}