reifydb_core/util/
wait_group.rs1use std::sync::{
13 Arc,
14 atomic::{AtomicUsize, Ordering},
15};
16
17use tokio::sync::Notify;
18
19struct Inner {
20 count: AtomicUsize,
21 notify: Notify,
22}
23
24pub struct WaitGroup {
25 inner: Arc<Inner>,
26}
27
28impl Default for WaitGroup {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl From<usize> for WaitGroup {
35 fn from(count: usize) -> Self {
36 Self {
37 inner: Arc::new(Inner {
38 count: AtomicUsize::new(count),
39 notify: Notify::new(),
40 }),
41 }
42 }
43}
44
45impl Clone for WaitGroup {
46 fn clone(&self) -> Self {
47 Self {
48 inner: self.inner.clone(),
49 }
50 }
51}
52
53impl std::fmt::Debug for WaitGroup {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 let count = self.inner.count.load(Ordering::Acquire);
56 f.debug_struct("WaitGroup").field("count", &count).finish()
57 }
58}
59
60impl WaitGroup {
61 pub fn new() -> Self {
62 Self {
63 inner: Arc::new(Inner {
64 count: AtomicUsize::new(0),
65 notify: Notify::new(),
66 }),
67 }
68 }
69
70 pub fn add(&self, num: usize) -> Self {
71 self.inner.count.fetch_add(num, Ordering::AcqRel);
72 Self {
73 inner: self.inner.clone(),
74 }
75 }
76
77 pub fn done(&self) -> usize {
78 let prev = self.inner.count.fetch_sub(1, Ordering::AcqRel);
79 if prev == 1 {
80 self.inner.notify.notify_waiters();
81 }
82 if prev == 0 {
83 self.inner.count.fetch_add(1, Ordering::AcqRel);
85 return 0;
86 }
87 prev - 1
88 }
89
90 pub fn waitings(&self) -> usize {
91 self.inner.count.load(Ordering::Acquire)
92 }
93
94 pub async fn wait(&self) {
95 loop {
96 if self.inner.count.load(Ordering::Acquire) == 0 {
97 return;
98 }
99 self.inner.notify.notified().await;
100 }
101 }
102}
103
104#[cfg(test)]
105mod tests {
106 use std::sync::{
107 Arc,
108 atomic::{AtomicUsize, Ordering},
109 };
110
111 use tokio::time::{Duration, sleep};
112
113 use crate::util::WaitGroup;
114
115 #[tokio::test]
116 async fn test_wait_group_reuse() {
117 let wg = WaitGroup::new();
118 let ctr = Arc::new(AtomicUsize::new(0));
119 for _ in 0..6 {
120 let wg = wg.add(1);
121 let ctrx = ctr.clone();
122 tokio::spawn(async move {
123 sleep(Duration::from_millis(5)).await;
124 ctrx.fetch_add(1, Ordering::Relaxed);
125 wg.done();
126 });
127 }
128
129 wg.wait().await;
130 assert_eq!(ctr.load(Ordering::Relaxed), 6);
131
132 let worker = wg.add(1);
133 let ctrx = ctr.clone();
134 tokio::spawn(async move {
135 sleep(Duration::from_millis(5)).await;
136 ctrx.fetch_add(1, Ordering::Relaxed);
137 worker.done();
138 });
139 wg.wait().await;
140 assert_eq!(ctr.load(Ordering::Relaxed), 7);
141 }
142
143 #[tokio::test]
144 async fn test_wait_group_nested() {
145 let wg = WaitGroup::new();
146 let ctr = Arc::new(AtomicUsize::new(0));
147 for _ in 0..5 {
148 let worker = wg.add(1);
149 let ctrx = ctr.clone();
150 tokio::spawn(async move {
151 let nested_worker = worker.add(1);
152 let ctrxx = ctrx.clone();
153 tokio::spawn(async move {
154 ctrxx.fetch_add(1, Ordering::Relaxed);
155 nested_worker.done();
156 });
157 ctrx.fetch_add(1, Ordering::Relaxed);
158 worker.done();
159 });
160 }
161
162 wg.wait().await;
163 assert_eq!(ctr.load(Ordering::Relaxed), 10);
164 }
165
166 #[tokio::test]
167 async fn test_wait_group_from() {
168 let wg = WaitGroup::from(5);
169 for _ in 0..5 {
170 let t = wg.clone();
171 tokio::spawn(async move {
172 t.done();
173 });
174 }
175 wg.wait().await;
176 }
177
178 #[test]
179 fn test_clone_and_fmt() {
180 let swg = WaitGroup::new();
181 let swg1 = swg.clone();
182 swg1.add(3);
183 assert_eq!(format!("{:?}", swg), format!("{:?}", swg1));
184 }
185
186 #[test]
187 fn test_waitings() {
188 let wg = WaitGroup::new();
189 wg.add(1);
190 wg.add(1);
191 assert_eq!(wg.waitings(), 2);
192 }
193}