async_waitgroup/
lib.rs

//! Go-like `WaitGroup` implementation that supports both sync and async Rust.
2
3#![deny(missing_docs)]
4#![deny(unsafe_code)]
5#![deny(unused_qualifications)]
6
7extern crate alloc;
8
9use alloc::sync::Arc;
10use core::fmt;
11use core::marker::PhantomPinned;
12use core::pin::Pin;
13use core::sync::atomic::{AtomicUsize, Ordering};
14use core::task::Poll;
15
16use event_listener::{Event, EventListener};
17use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};
18use futures_core::ready;
19use pin_project_lite::pin_project;
20
/// Enables tasks to synchronize the beginning or end of some computation.
///
/// All handles of one wait group share a single counter. Cloning a handle
/// increments the counter and dropping a handle decrements it.
/// [`WaitGroup::wait()`] consumes its own handle and completes once the counter
/// reaches zero, i.e. once every handle has been dropped.
///
/// # Examples
///
/// ```
/// use async_waitgroup::WaitGroup;
///
/// # #[tokio::main(flavor = "current_thread")] async fn main() {
/// // Create a new wait group.
/// let wg = WaitGroup::new();
///
/// for _ in 0..4 {
///     // Create another reference to the wait group.
///     let wg = wg.clone();
///
///     tokio::spawn(async move {
///         // Do some work.
///
///         // Drop the reference to the wait group.
///         drop(wg);
///     });
/// }
///
/// // Block until all tasks have finished their work.
/// wg.wait().await;
/// # }
/// ```
pub struct WaitGroup {
    // Counter and wake-up event shared by every clone of this handle.
    inner: Arc<WgInner>,
}
51
/// Inner state of a `WaitGroup`, shared (through an `Arc`) by all of its clones.
struct WgInner {
    // Number of live `WaitGroup` handles. It starts at 1 for the handle created by
    // `Default`/`new()`, is incremented by `clone()` and decremented by `drop()`.
    count: AtomicUsize,
    // Notified (all listeners) by the handle whose drop brings `count` to zero.
    drop_ops: Event,
}
57
58impl Default for WaitGroup {
59    fn default() -> Self {
60        Self {
61            inner: Arc::new(WgInner {
62                count: AtomicUsize::new(1),
63                drop_ops: Event::new(),
64            }),
65        }
66    }
67}
68
69impl WaitGroup {
70    /// Creates a new wait group and returns the single reference to it.
71    ///
72    /// # Examples
73    ///
74    /// ```
75    /// use async_waitgroup::WaitGroup;
76    ///
77    /// let wg = WaitGroup::new();
78    /// ```
79    pub fn new() -> Self {
80        Self::default()
81    }
82
83    /// Drops this reference and waits until all other references are dropped.
84    ///
85    /// # Examples
86    ///
87    /// ```
88    /// use async_waitgroup::WaitGroup;
89    ///
90    /// # #[tokio::main(flavor = "current_thread")] async fn main() {
91    /// let wg = WaitGroup::new();
92    ///
93    /// tokio::spawn({
94    ///     let wg = wg.clone();
95    ///     async move {
96    ///         // Block until both tasks have reached `wait()`.
97    ///         wg.wait().await;
98    ///     }
99    /// });
100    ///
101    /// // Block until all tasks have finished their work.
102    /// wg.wait().await;
103    /// # }
104    /// ```
105    pub fn wait(self) -> Wait {
106        let w = Wait::_new(WaitInner {
107            wg: self.inner.clone(),
108            listener: None,
109            _pin: PhantomPinned,
110        });
111        drop(self);
112        w
113    }
114
115    /// Waits using the blocking strategy.
116    ///
117    /// # Examples
118    ///
119    /// ```
120    /// use std::thread;
121    ///
122    /// use async_waitgroup::WaitGroup;
123    ///
124    /// let wg = WaitGroup::new();
125    ///
126    /// thread::spawn({
127    ///     let wg = wg.clone();
128    ///     move || {
129    ///         wg.wait_blocking();
130    ///     }
131    /// });
132    ///
133    /// wg.wait_blocking();
134    /// ```
135    #[cfg(all(feature = "std", not(target_family = "wasm")))]
136    pub fn wait_blocking(self) {
137        self.wait().wait();
138    }
139}
140
easy_wrapper! {
    /// A future returned by [`WaitGroup::wait()`].
    ///
    /// It resolves once every [`WaitGroup`] handle of the same group has been dropped.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct Wait(WaitInner => ());
    // The blocking `wait()` is only generated with the `std` feature on non-wasm
    // targets; `WaitGroup::wait_blocking()` is built on top of it.
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}
148
pin_project! {
    /// State driven by the [`Wait`] future.
    #[project(!Unpin)]
    struct WaitInner {
        // Shared state of the group being waited on.
        wg: Arc<WgInner>,
        // Listener registered on `wg.drop_ops`. It starts as `None` and is created
        // lazily while the counter is still non-zero.
        listener: Option<EventListener>,
        // Keeps this type `!Unpin`.
        #[pin]
        _pin: PhantomPinned
    }
}
158
159impl EventListenerFuture for WaitInner {
160    type Output = ();
161
162    fn poll_with_strategy<'a, S: Strategy<'a>>(
163        self: Pin<&mut Self>,
164        strategy: &mut S,
165        context: &mut S::Context,
166    ) -> Poll<Self::Output> {
167        let this = self.project();
168
169        if this.wg.count.load(Ordering::SeqCst) == 0 {
170            return Poll::Ready(());
171        }
172
173        let mut count = this.wg.count.load(Ordering::SeqCst);
174        while count > 0 {
175            if this.listener.is_some() {
176                ready!(strategy.poll(&mut *this.listener, context))
177            } else {
178                *this.listener = Some(this.wg.drop_ops.listen());
179            }
180            count = this.wg.count.load(Ordering::SeqCst);
181        }
182
183        Poll::Ready(())
184    }
185}
186
187impl Drop for WaitGroup {
188    fn drop(&mut self) {
189        if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
190            self.inner.drop_ops.notify(usize::MAX);
191        }
192    }
193}
194
195impl Clone for WaitGroup {
196    fn clone(&self) -> Self {
197        self.inner.count.fetch_add(1, Ordering::SeqCst);
198
199        Self {
200            inner: self.inner.clone(),
201        }
202    }
203}
204
205impl fmt::Debug for WaitGroup {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        let count = self.inner.count.load(Ordering::SeqCst);
208        f.debug_struct("WaitGroup").field("count", &count).finish()
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    // Gate the import exactly like `test_wait_blocking`, its only user. Otherwise
    // `std` builds for wasm targets emit an unused-import warning.
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    use std::thread;

    #[tokio::test]
    async fn test_wait() {
        const LOOP: usize = if cfg!(miri) { 100 } else { 10_000 };

        let wg = WaitGroup::new();
        let cnt = Arc::new(AtomicUsize::new(0));

        for _ in 0..LOOP {
            tokio::spawn({
                let wg = wg.clone();
                let cnt = cnt.clone();
                async move {
                    cnt.fetch_add(1, Ordering::Relaxed);
                    drop(wg);
                }
            });
        }

        wg.wait().await;
        assert_eq!(cnt.load(Ordering::Relaxed), LOOP)
    }

    #[tokio::test]
    async fn test_wait_sole_handle_is_immediately_ready() {
        // With no other handles alive, `wait()` releases the last reference and
        // the future completes on its very first poll.
        let w = WaitGroup::new().wait();
        pin_utils::pin_mut!(w);
        assert_eq!(futures_util::poll!(w.as_mut()), Poll::Ready(()));
    }

    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    #[test]
    fn test_wait_blocking() {
        const LOOP: usize = 100;

        let wg = WaitGroup::new();
        let cnt = Arc::new(AtomicUsize::new(0));

        for _ in 0..LOOP {
            thread::spawn({
                let wg = wg.clone();
                let cnt = cnt.clone();
                move || {
                    cnt.fetch_add(1, Ordering::Relaxed);
                    drop(wg);
                }
            });
        }

        wg.wait_blocking();
        assert_eq!(cnt.load(Ordering::Relaxed), LOOP)
    }

    #[test]
    fn test_clone() {
        let wg = WaitGroup::new();
        assert_eq!(Arc::strong_count(&wg.inner), 1);
        // The group's own counter must track handles too, not just the `Arc`.
        assert_eq!(wg.inner.count.load(Ordering::SeqCst), 1);

        let wg2 = wg.clone();
        assert_eq!(Arc::strong_count(&wg.inner), 2);
        assert_eq!(Arc::strong_count(&wg2.inner), 2);
        assert_eq!(wg.inner.count.load(Ordering::SeqCst), 2);

        drop(wg2);
        assert_eq!(Arc::strong_count(&wg.inner), 1);
        assert_eq!(wg.inner.count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_debug() {
        let wg = WaitGroup::new();
        let wg2 = wg.clone();
        assert_eq!(alloc::format!("{:?}", wg), "WaitGroup { count: 2 }");
        drop(wg2);
        assert_eq!(alloc::format!("{:?}", wg), "WaitGroup { count: 1 }");
    }

    #[tokio::test]
    async fn test_futures() {
        let wg = WaitGroup::new();
        let wg2 = wg.clone();

        let w = wg.wait();
        pin_utils::pin_mut!(w);
        assert_eq!(futures_util::poll!(w.as_mut()), Poll::Pending);
        assert_eq!(futures_util::poll!(w.as_mut()), Poll::Pending);

        drop(wg2);
        assert_eq!(futures_util::poll!(w.as_mut()), Poll::Ready(()));
    }
}
288}