1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
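//! Go-like [WaitGroup](https://pkg.go.dev/sync#WaitGroup) implementation for async Rust.
//!
//! An [`AwaitGroup`] is a cheaply cloneable handle. Give a clone to every task you
//! want to wait for, then `.await` one handle. The await completes once every
//! other clone of the group has been dropped.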
//!
//! ## Usage
//!
//! Add this to your `Cargo.toml`:
//!
//! ```toml
//! [dependencies]
//! await-group = "0.1"
//! ```
//!
//! ## Example
//! ```rust
//! use await_group::AwaitGroup;
//!
//! #[tokio::main]
//! async fn main() {
//!     let wg = AwaitGroup::new();
//!     for _ in 0..10 {
//!         let w = wg.clone();
//!         tokio::spawn(async move {
//!             _ = w;
//!         });
//!     }
//!     wg.await;
//! }
//!
//! ```

extern crate alloc;

use core::{
    future::{Future, IntoFuture},
    pin::Pin,
    task::{Context, Poll},
};

use alloc::{sync::Arc, sync::Weak};

use atomic_waker::AtomicWaker;

/// A handle to a group of tasks that can be awaited until all of them are done.
///
/// Every clone of an `AwaitGroup` shares the same underlying state. Awaiting a
/// handle consumes it and resolves once every remaining clone has been dropped.
/// The usual pattern is to move one clone into each spawned task; the task
/// releases its clone simply by finishing (dropping it).
///
/// `AwaitGroup::new()` and `AwaitGroup::default()` create equivalent groups.
#[derive(Clone, Default)]
pub struct AwaitGroup {
    // Strong reference to the shared state. Awaiting resolves once no strong
    // reference to that state is left alive.
    inner: Arc<Inner>,
}

impl AwaitGroup {
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Inner {
                waker: AtomicWaker::new(),
            }),
        }
    }
}

impl IntoFuture for AwaitGroup {
    type Output = ();

    type IntoFuture = AwaitGroupFuture;

    fn into_future(self) -> Self::IntoFuture {
        AwaitGroupFuture {
            inner: Arc::downgrade(&self.inner),
        }
    }
}

/// Future returned by awaiting an [`AwaitGroup`] (via `IntoFuture`).
///
/// It holds only a weak reference to the group's shared state and resolves
/// once that state can no longer be upgraded, i.e. every handle is gone.
// Like the std futures, this does nothing unless polled, so ignoring it is a bug.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct AwaitGroupFuture {
    // Weak so that the awaiting side never keeps the group alive on its own.
    inner: Weak<Inner>,
}

impl Future for AwaitGroupFuture {
    type Output = ();

    /// Resolves once no strong reference to the group's shared state remains.
    ///
    /// While the state is still alive, the current task's waker is registered so
    /// that dropping the state (see `Inner`'s `Drop`) can wake it.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if let Some(shared) = self.inner.upgrade() {
            // Holding `shared` keeps the state alive during registration, so the
            // final drop (and its wake-up) cannot happen between the upgrade and
            // the register call. If `shared` turns out to be the last strong
            // reference, dropping it at the end of this scope triggers the wake
            // of the just-registered waker, and the next poll sees the failed upgrade.
            shared.waker.register(cx.waker());
            // NOTE(review): `AtomicWaker` is documented to hold a single waker, so
            // if several futures from the same group are polled from different
            // tasks, only the most recently registered one may be woken — confirm
            // whether multiple concurrent waiters need to be supported.
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}

#[derive(Default)]
struct Inner {
    waker: AtomicWaker,
}

impl Drop for Inner {
    fn drop(&mut self) {
        self.waker.wake();
    }
}

#[cfg(test)]
mod test {
    use crate::AwaitGroup;
    use alloc::sync::Arc;
    use core::sync::atomic::{AtomicUsize, Ordering};

    /// Awaiting a group whose only handle is the one being awaited completes at once.
    #[tokio::test]
    async fn empty_group_completes() {
        AwaitGroup::new().await;
    }

    /// The await only completes after every worker has run and dropped its clone.
    #[tokio::test]
    async fn smoke() {
        const WORKERS: usize = 10;

        let wg = AwaitGroup::new();
        let finished = Arc::new(AtomicUsize::new(0));
        for _ in 0..WORKERS {
            let w = wg.clone();
            let finished = Arc::clone(&finished);
            tokio::spawn(async move {
                // Yield first so the awaiting task is parked (Pending) before the
                // handles start being dropped, exercising the wake-up path.
                tokio::task::yield_now().await;
                finished.fetch_add(1, Ordering::SeqCst);
                drop(w);
            });
        }
        wg.await;
        assert_eq!(finished.load(Ordering::SeqCst), WORKERS);
    }
}