reifydb_core/util/
wait_group.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the AGPL-3.0-or-later, see license.md file
3
4// This file includes and modifies code from the wg project (https://github.com/al8n/wg),
5// originally licensed under the Apache License, Version 2.0.
6// Original copyright:
7//   Copyright (c) 2024 Al Liu
8//
9// The original Apache License can be found at:
10//   http://www.apache.org/licenses/LICENSE-2.0
11
12use std::sync::{
13	Arc,
14	atomic::{AtomicUsize, Ordering},
15};
16
17use tokio::sync::Notify;
18
19struct Inner {
20	count: AtomicUsize,
21	notify: Notify,
22}
23
24pub struct WaitGroup {
25	inner: Arc<Inner>,
26}
27
28impl Default for WaitGroup {
29	fn default() -> Self {
30		Self::new()
31	}
32}
33
34impl From<usize> for WaitGroup {
35	fn from(count: usize) -> Self {
36		Self {
37			inner: Arc::new(Inner {
38				count: AtomicUsize::new(count),
39				notify: Notify::new(),
40			}),
41		}
42	}
43}
44
45impl Clone for WaitGroup {
46	fn clone(&self) -> Self {
47		Self {
48			inner: self.inner.clone(),
49		}
50	}
51}
52
53impl std::fmt::Debug for WaitGroup {
54	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55		let count = self.inner.count.load(Ordering::Acquire);
56		f.debug_struct("WaitGroup").field("count", &count).finish()
57	}
58}
59
60impl WaitGroup {
61	pub fn new() -> Self {
62		Self {
63			inner: Arc::new(Inner {
64				count: AtomicUsize::new(0),
65				notify: Notify::new(),
66			}),
67		}
68	}
69
70	pub fn add(&self, num: usize) -> Self {
71		self.inner.count.fetch_add(num, Ordering::AcqRel);
72		Self {
73			inner: self.inner.clone(),
74		}
75	}
76
77	pub fn done(&self) -> usize {
78		let prev = self.inner.count.fetch_sub(1, Ordering::AcqRel);
79		if prev == 1 {
80			self.inner.notify.notify_waiters();
81		}
82		if prev == 0 {
83			// Already at zero, restore it (shouldn't happen in correct usage)
84			self.inner.count.fetch_add(1, Ordering::AcqRel);
85			return 0;
86		}
87		prev - 1
88	}
89
90	pub fn waitings(&self) -> usize {
91		self.inner.count.load(Ordering::Acquire)
92	}
93
94	pub async fn wait(&self) {
95		loop {
96			if self.inner.count.load(Ordering::Acquire) == 0 {
97				return;
98			}
99			self.inner.notify.notified().await;
100		}
101	}
102}
103
#[cfg(test)]
mod tests {
	use std::sync::{
		Arc,
		atomic::{AtomicUsize, Ordering},
	};

	use tokio::time::{Duration, sleep};

	use crate::util::WaitGroup;

	#[tokio::test]
	async fn test_wait_group_reuse() {
		let wg = WaitGroup::new();
		let ctr = Arc::new(AtomicUsize::new(0));
		for _ in 0..6 {
			let wg = wg.add(1);
			let ctrx = ctr.clone();
			tokio::spawn(async move {
				sleep(Duration::from_millis(5)).await;
				ctrx.fetch_add(1, Ordering::Relaxed);
				wg.done();
			});
		}

		wg.wait().await;
		assert_eq!(ctr.load(Ordering::Relaxed), 6);

		let worker = wg.add(1);
		let ctrx = ctr.clone();
		tokio::spawn(async move {
			sleep(Duration::from_millis(5)).await;
			ctrx.fetch_add(1, Ordering::Relaxed);
			worker.done();
		});
		wg.wait().await;
		assert_eq!(ctr.load(Ordering::Relaxed), 7);
	}

	#[tokio::test]
	async fn test_wait_group_nested() {
		let wg = WaitGroup::new();
		let ctr = Arc::new(AtomicUsize::new(0));
		for _ in 0..5 {
			let worker = wg.add(1);
			let ctrx = ctr.clone();
			tokio::spawn(async move {
				let nested_worker = worker.add(1);
				let ctrxx = ctrx.clone();
				tokio::spawn(async move {
					ctrxx.fetch_add(1, Ordering::Relaxed);
					nested_worker.done();
				});
				ctrx.fetch_add(1, Ordering::Relaxed);
				worker.done();
			});
		}

		wg.wait().await;
		assert_eq!(ctr.load(Ordering::Relaxed), 10);
	}

	#[tokio::test]
	async fn test_wait_group_from() {
		let wg = WaitGroup::from(5);
		for _ in 0..5 {
			let t = wg.clone();
			tokio::spawn(async move {
				t.done();
			});
		}
		wg.wait().await;
	}

	#[test]
	fn test_clone_and_fmt() {
		let swg = WaitGroup::new();
		let swg1 = swg.clone();
		swg1.add(3);
		// Clones share one counter, so both must render the same snapshot.
		assert_eq!(format!("{:?}", swg), "WaitGroup { count: 3 }");
		assert_eq!(format!("{:?}", swg), format!("{:?}", swg1));
	}

	#[test]
	fn test_waitings() {
		let wg = WaitGroup::new();
		wg.add(1);
		wg.add(1);
		assert_eq!(wg.waitings(), 2);
	}

	#[test]
	fn test_done_on_empty_group_stays_at_zero() {
		let wg = WaitGroup::new();
		// An unmatched `done()` must not underflow the counter.
		assert_eq!(wg.done(), 0);
		assert_eq!(wg.waitings(), 0);

		// The group must remain fully usable afterwards.
		wg.add(2);
		assert_eq!(wg.done(), 1);
		assert_eq!(wg.done(), 0);
		assert_eq!(wg.waitings(), 0);
	}

	#[tokio::test]
	async fn test_wait_on_empty_group_returns_immediately() {
		tokio::time::timeout(Duration::from_secs(1), WaitGroup::new().wait())
			.await
			.expect("wait() on an empty group must not block");
	}
}