await_group/
lib.rs

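//! A Go-like [WaitGroup](https://pkg.go.dev/sync#WaitGroup) for async Rust.
//!
//! Every clone of an `AwaitGroup` is a member of the same group. Awaiting an `AwaitGroup`
//! consumes that member and completes once every other clone has been dropped.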
2//!
3//! ## Usage
4//!
5//! Add this to your `Cargo.toml`:
6//!
7//! ```toml
//! [dependencies]
9//! await-group = "0.1"
10//! ```
11//!
12//! ## Example
13//! ```rust
14//! use await_group::AwaitGroup;
15//!
16//! #[tokio::main]
17//! async fn main() {
18//!     let wg = AwaitGroup::new();
19//!     for _ in 0..10 {
20//!         let w = wg.clone();
21//!         tokio::spawn(async move {
//!             // ... do some work, then release this member of the group.
//!             drop(w);
23//!         });
24//!     }
25//!     wg.await;
26//! }
27//!
28//! ```
29
30extern crate alloc;
31
32use core::{
33    future::{Future, IntoFuture},
34    pin::Pin,
35    task::{Context, Poll},
36};
37
38use alloc::{sync::Arc, sync::Weak};
39
40use atomic_waker::AtomicWaker;
41
42#[derive(Clone, Default)]
43pub struct AwaitGroup {
44    inner: Arc<Inner>,
45}
46
47impl AwaitGroup {
48    pub fn new() -> Self {
49        Self {
50            inner: Arc::new(Inner {
51                waker: AtomicWaker::new(),
52            }),
53        }
54    }
55}
56
57impl IntoFuture for AwaitGroup {
58    type Output = ();
59
60    type IntoFuture = AwaitGroupFuture;
61
62    fn into_future(self) -> Self::IntoFuture {
63        AwaitGroupFuture {
64            inner: Arc::downgrade(&self.inner),
65        }
66    }
67}
68
69pub struct AwaitGroupFuture {
70    inner: Weak<Inner>,
71}
72
73impl Future for AwaitGroupFuture {
74    type Output = ();
75
76    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
77        match self.inner.upgrade() {
78            Some(inner) => {
79                inner.waker.register(cx.waker());
80                Poll::Pending
81            }
82            None => Poll::Ready(()),
83        }
84    }
85}
86
87#[derive(Default)]
88struct Inner {
89    waker: AtomicWaker,
90}
91
92impl Drop for Inner {
93    fn drop(&mut self) {
94        self.waker.wake();
95    }
96}
97
#[cfg(test)]
mod test {
    use alloc::sync::Arc;
    use core::sync::atomic::{AtomicUsize, Ordering};

    use crate::AwaitGroup;

    /// Awaiting the group must not complete until every spawned task has dropped its member.
    ///
    /// Each task records its completion *before* releasing its member, so a premature
    /// completion of `wg.await` would show up as a count below `TASKS`.
    #[tokio::test]
    async fn smoke() {
        const TASKS: usize = 10;
        let finished = Arc::new(AtomicUsize::new(0));
        let wg = AwaitGroup::new();
        for _ in 0..TASKS {
            let member = wg.clone();
            let finished = Arc::clone(&finished);
            tokio::spawn(async move {
                finished.fetch_add(1, Ordering::SeqCst);
                // Release the member only after the work has been recorded.
                drop(member);
            });
        }
        wg.await;
        assert_eq!(finished.load(Ordering::SeqCst), TASKS);
    }

    /// A group with no other members resolves on its first poll.
    #[tokio::test]
    async fn lone_member_completes_immediately() {
        AwaitGroup::new().await;
    }
}