Skip to main content

reifydb_core/util/
wait_group.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4// This file includes and modifies code from the wg project (https://github.com/al8n/wg),
5// originally licensed under the Apache License, Version 2.0.
6// Original copyright:
7//   Copyright (c) 2024 Al Liu
8//
9// The original Apache License can be found at:
10//   http://www.apache.org/licenses/LICENSE-2.0
11
12use std::sync::{
13	Arc,
14	atomic::{AtomicUsize, Ordering},
15};
16
17use reifydb_runtime::sync::{condvar::Condvar, mutex::Mutex};
18
19struct Inner {
20	count: AtomicUsize,
21	mutex: Mutex<()>,
22	condvar: Condvar,
23}
24
25pub struct WaitGroup {
26	inner: Arc<Inner>,
27}
28
29impl Default for WaitGroup {
30	fn default() -> Self {
31		Self::new()
32	}
33}
34
35impl From<usize> for WaitGroup {
36	fn from(count: usize) -> Self {
37		Self {
38			inner: Arc::new(Inner {
39				count: AtomicUsize::new(count),
40				mutex: Mutex::new(()),
41				condvar: Condvar::new(),
42			}),
43		}
44	}
45}
46
47impl Clone for WaitGroup {
48	fn clone(&self) -> Self {
49		Self {
50			inner: self.inner.clone(),
51		}
52	}
53}
54
55impl std::fmt::Debug for WaitGroup {
56	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57		let count = self.inner.count.load(Ordering::Acquire);
58		f.debug_struct("WaitGroup").field("count", &count).finish()
59	}
60}
61
62impl WaitGroup {
63	pub fn new() -> Self {
64		Self {
65			inner: Arc::new(Inner {
66				count: AtomicUsize::new(0),
67				mutex: Mutex::new(()),
68				condvar: Condvar::new(),
69			}),
70		}
71	}
72
73	pub fn add(&self, num: usize) -> Self {
74		self.inner.count.fetch_add(num, Ordering::AcqRel);
75		Self {
76			inner: self.inner.clone(),
77		}
78	}
79
80	pub fn done(&self) -> usize {
81		let prev = self.inner.count.fetch_sub(1, Ordering::AcqRel);
82		if prev == 1 {
83			let _guard = self.inner.mutex.lock();
84			self.inner.condvar.notify_all();
85		}
86		if prev == 0 {
87			// Already at zero, restore it (shouldn't happen in correct usage)
88			self.inner.count.fetch_add(1, Ordering::AcqRel);
89			return 0;
90		}
91		prev - 1
92	}
93
94	pub fn waitings(&self) -> usize {
95		self.inner.count.load(Ordering::Acquire)
96	}
97
98	pub fn wait(&self) {
99		let mut guard = self.inner.mutex.lock();
100		while self.inner.count.load(Ordering::Acquire) > 0 {
101			self.inner.condvar.wait(&mut guard);
102		}
103	}
104}
105
#[cfg(test)]
pub mod tests {
	use std::{
		sync::{
			Arc,
			atomic::{AtomicUsize, Ordering},
		},
		time::Duration,
	};

	use crate::util::wait_group::WaitGroup;

	/// A group can be waited on, re-armed with `add`, and waited on again.
	#[test]
	fn test_wait_group_reuse() {
		let wg = WaitGroup::new();
		let ctr = Arc::new(AtomicUsize::new(0));
		for _ in 0..6 {
			let wg = wg.add(1);
			let ctrx = ctr.clone();
			std::thread::spawn(move || {
				std::thread::sleep(Duration::from_millis(5));
				ctrx.fetch_add(1, Ordering::Relaxed);
				wg.done();
			});
		}

		wg.wait();
		assert_eq!(ctr.load(Ordering::Relaxed), 6);

		let worker = wg.add(1);
		let ctrx = ctr.clone();
		std::thread::spawn(move || {
			std::thread::sleep(Duration::from_millis(5));
			ctrx.fetch_add(1, Ordering::Relaxed);
			worker.done();
		});
		wg.wait();
		assert_eq!(ctr.load(Ordering::Relaxed), 7);
	}

	/// Workers may register nested workers on the same group. The nested `add`
	/// runs before the outer `done`, so the counter never reaches zero early.
	#[test]
	fn test_wait_group_nested() {
		let wg = WaitGroup::new();
		let ctr = Arc::new(AtomicUsize::new(0));
		for _ in 0..5 {
			let worker = wg.add(1);
			let ctrx = ctr.clone();
			std::thread::spawn(move || {
				let nested_worker = worker.add(1);
				let ctrxx = ctrx.clone();
				std::thread::spawn(move || {
					ctrxx.fetch_add(1, Ordering::Relaxed);
					nested_worker.done();
				});
				ctrx.fetch_add(1, Ordering::Relaxed);
				worker.done();
			});
		}

		wg.wait();
		assert_eq!(ctr.load(Ordering::Relaxed), 10);
	}

	/// A group created with an initial count waits for that many `done` calls.
	#[test]
	fn test_wait_group_from() {
		let wg = WaitGroup::from(5);
		for _ in 0..5 {
			let t = wg.clone();
			std::thread::spawn(move || {
				t.done();
			});
		}
		wg.wait();
	}

	/// Clones share one counter, so they format identically.
	#[test]
	fn test_clone_and_fmt() {
		let swg = WaitGroup::new();
		let swg1 = swg.clone();
		swg1.add(3);
		assert_eq!(format!("{:?}", swg), format!("{:?}", swg1));
		assert_eq!(format!("{:?}", swg), "WaitGroup { count: 3 }");
	}

	#[test]
	fn test_waitings() {
		let wg = WaitGroup::new();
		wg.add(1);
		wg.add(1);
		assert_eq!(wg.waitings(), 2);
	}

	/// `done` reports how many calls are still outstanding.
	#[test]
	fn test_done_returns_remaining() {
		let wg = WaitGroup::from(3);
		assert_eq!(wg.done(), 2);
		assert_eq!(wg.done(), 1);
		assert_eq!(wg.done(), 0);
		assert_eq!(wg.waitings(), 0);
		wg.wait();
	}

	/// An unbalanced `done` must leave the counter at zero rather than underflow.
	#[test]
	fn test_done_on_zero_saturates() {
		let wg = WaitGroup::new();
		assert_eq!(wg.done(), 0);
		assert_eq!(wg.waitings(), 0);
		// The group must still be usable afterwards.
		let worker = wg.add(1);
		assert_eq!(wg.waitings(), 1);
		assert_eq!(worker.done(), 0);
		wg.wait();
	}
}