1#![deny(missing_docs)]
4#![deny(unsafe_code)]
5#![deny(unused_qualifications)]
6
7extern crate alloc;
8
9use alloc::sync::Arc;
10use core::fmt;
11use core::marker::PhantomPinned;
12use core::pin::Pin;
13use core::sync::atomic::{AtomicUsize, Ordering};
14use core::task::Poll;
15
16use event_listener::{Event, EventListener};
17use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};
18use futures_core::ready;
19use pin_project_lite::pin_project;
20
21pub struct WaitGroup {
49 inner: Arc<WgInner>,
50}
51
52struct WgInner {
54 count: AtomicUsize,
55 drop_ops: Event,
56}
57
58impl Default for WaitGroup {
59 fn default() -> Self {
60 Self {
61 inner: Arc::new(WgInner {
62 count: AtomicUsize::new(1),
63 drop_ops: Event::new(),
64 }),
65 }
66 }
67}
68
69impl WaitGroup {
70 pub fn new() -> Self {
80 Self::default()
81 }
82
83 pub fn wait(self) -> Wait {
106 let w = Wait::_new(WaitInner {
107 wg: self.inner.clone(),
108 listener: None,
109 _pin: PhantomPinned,
110 });
111 drop(self);
112 w
113 }
114
115 #[cfg(all(feature = "std", not(target_family = "wasm")))]
136 pub fn wait_blocking(self) {
137 self.wait().wait();
138 }
139}
140
141easy_wrapper! {
142 #[must_use = "futures do nothing unless you `.await` or poll them"]
144 pub struct Wait(WaitInner => ());
145 #[cfg(all(feature = "std", not(target_family = "wasm")))]
146 pub(crate) wait();
147}
148
pin_project! {
    // Internal state behind the public `Wait` future; driven by
    // `EventListenerFuture::poll_with_strategy`.
    // `#[project(!Unpin)]` opts the generated type out of `Unpin`.
    #[project(!Unpin)]
    struct WaitInner {
        // Shared counter and event. Holding the `Arc` keeps the group state
        // alive for as long as this future exists.
        wg: Arc<WgInner>,
        // Listener on `wg.drop_ops`. Starts as `None` and is registered on
        // demand while polling observes a non-zero counter.
        listener: Option<EventListener>,
        #[pin]
        _pin: PhantomPinned
    }
}
158
159impl EventListenerFuture for WaitInner {
160 type Output = ();
161
162 fn poll_with_strategy<'a, S: Strategy<'a>>(
163 self: Pin<&mut Self>,
164 strategy: &mut S,
165 context: &mut S::Context,
166 ) -> Poll<Self::Output> {
167 let this = self.project();
168
169 if this.wg.count.load(Ordering::SeqCst) == 0 {
170 return Poll::Ready(());
171 }
172
173 let mut count = this.wg.count.load(Ordering::SeqCst);
174 while count > 0 {
175 if this.listener.is_some() {
176 ready!(strategy.poll(&mut *this.listener, context))
177 } else {
178 *this.listener = Some(this.wg.drop_ops.listen());
179 }
180 count = this.wg.count.load(Ordering::SeqCst);
181 }
182
183 Poll::Ready(())
184 }
185}
186
187impl Drop for WaitGroup {
188 fn drop(&mut self) {
189 if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
190 self.inner.drop_ops.notify(usize::MAX);
191 }
192 }
193}
194
195impl Clone for WaitGroup {
196 fn clone(&self) -> Self {
197 self.inner.count.fetch_add(1, Ordering::SeqCst);
198
199 Self {
200 inner: self.inner.clone(),
201 }
202 }
203}
204
205impl fmt::Debug for WaitGroup {
206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207 let count = self.inner.count.load(Ordering::SeqCst);
208 f.debug_struct("WaitGroup").field("count", &count).finish()
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    // Gated exactly like `test_wait_blocking`, its only user, so the import is
    // not reported as unused on wasm targets built with the `std` feature.
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    use std::thread;

    // Many tasks each hold a clone. `wait` must not resolve until all of them
    // have run and dropped their handle.
    #[tokio::test]
    async fn test_wait() {
        const LOOP: usize = if cfg!(miri) { 100 } else { 10_000 };

        let wg = WaitGroup::new();
        let cnt = Arc::new(AtomicUsize::new(0));

        for _ in 0..LOOP {
            tokio::spawn({
                let wg = wg.clone();
                let cnt = cnt.clone();
                async move {
                    cnt.fetch_add(1, Ordering::Relaxed);
                    drop(wg);
                }
            });
        }

        wg.wait().await;
        assert_eq!(cnt.load(Ordering::Relaxed), LOOP)
    }

    // A group whose only handle is the one being waited on completes at once.
    #[tokio::test]
    async fn test_wait_without_clones() {
        WaitGroup::new().wait().await;
    }

    // Same as `test_wait`, but with OS threads and the blocking API.
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    #[test]
    fn test_wait_blocking() {
        const LOOP: usize = 100;

        let wg = WaitGroup::new();
        let cnt = Arc::new(AtomicUsize::new(0));

        for _ in 0..LOOP {
            thread::spawn({
                let wg = wg.clone();
                let cnt = cnt.clone();
                move || {
                    cnt.fetch_add(1, Ordering::Relaxed);
                    drop(wg);
                }
            });
        }

        wg.wait_blocking();
        assert_eq!(cnt.load(Ordering::Relaxed), LOOP)
    }

    // Clones share one allocation, and dropping a clone releases its reference.
    #[test]
    fn test_clone() {
        let wg = WaitGroup::new();
        assert_eq!(Arc::strong_count(&wg.inner), 1);

        let wg2 = wg.clone();
        assert_eq!(Arc::strong_count(&wg.inner), 2);
        assert_eq!(Arc::strong_count(&wg2.inner), 2);
        drop(wg2);
        assert_eq!(Arc::strong_count(&wg.inner), 1);
    }

    // The future stays pending while a clone is alive (repeated polls do not
    // spuriously complete) and resolves right after that clone is dropped.
    #[tokio::test]
    async fn test_futures() {
        let wg = WaitGroup::new();
        let wg2 = wg.clone();

        let w = wg.wait();
        pin_utils::pin_mut!(w);
        assert_eq!(futures_util::poll!(w.as_mut()), Poll::Pending);
        assert_eq!(futures_util::poll!(w.as_mut()), Poll::Pending);

        drop(wg2);
        assert_eq!(futures_util::poll!(w.as_mut()), Poll::Ready(()));
    }
}