mea/waitgroup/
mod.rs

1// Copyright 2024 tison <wander4096@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! A synchronization primitive for waiting on multiple tasks to complete.
16//!
17//! Similar to Go's WaitGroup, this type allows a task to wait for multiple other
18//! tasks to finish. Each task holds a handle to the WaitGroup, and the main task
19//! can wait for all handles to be dropped before proceeding.
20//!
//! The main task calls [`clone()`] to hand a worker handle to each task, then
//! waits for all of them by awaiting its own handle with `.await`, which also
//! releases that handle.
24//!
25//! # Examples
26//!
27//! ```
28//! # #[tokio::main]
29//! # async fn main() {
30//! use std::time::Duration;
31//!
32//! use mea::waitgroup::WaitGroup;
33//! let wg = WaitGroup::new();
34//!
35//! for i in 0..3 {
36//!     let wg = wg.clone();
37//!     tokio::spawn(async move {
38//!         println!("Task {} starting", i);
39//!         tokio::time::sleep(Duration::from_millis(100)).await;
40//!         // wg is automatically decremented when dropped
41//!         drop(wg);
42//!     });
43//! }
44//!
45//! // Wait for all tasks to complete
46//! wg.await;
47//! println!("All tasks completed");
48//! # }
49//! ```
50//!
51//! [`clone()`]: WaitGroup::clone
52
53use std::fmt;
54use std::future::Future;
55use std::future::IntoFuture;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::Context;
59use std::task::Poll;
60
61use crate::internal::CountdownState;
62
63#[cfg(test)]
64mod tests;
65
66/// A synchronization primitive for waiting on multiple tasks to complete.
67///
68/// See the [module level documentation](self) for more.
69pub struct WaitGroup {
70    state: Arc<CountdownState>,
71}
72
73impl fmt::Debug for WaitGroup {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        f.debug_struct("WaitGroup").finish_non_exhaustive()
76    }
77}
78
79impl Default for WaitGroup {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl WaitGroup {
86    /// Creates a new `WaitGroup`.
87    ///
88    /// # Examples
89    ///
90    /// ```
91    /// use mea::waitgroup::WaitGroup;
92    ///
93    /// let wg = WaitGroup::new();
94    /// ```
95    pub fn new() -> Self {
96        Self {
97            state: Arc::new(CountdownState::new(1)),
98        }
99    }
100}
101
102impl Clone for WaitGroup {
103    /// Creates a new worker handle for the WaitGroup.
104    ///
105    /// This increments the WaitGroup counter. The counter will be decremented
106    /// when the new handle is dropped.
107    fn clone(&self) -> Self {
108        let sync = self.state.clone();
109        let mut cnt = sync.state();
110        loop {
111            let new_cnt = cnt.saturating_add(1);
112            match sync.cas_state(cnt, new_cnt) {
113                Ok(_) => return Self { state: sync },
114                Err(x) => cnt = x,
115            }
116        }
117    }
118}
119
120impl Drop for WaitGroup {
121    fn drop(&mut self) {
122        if self.state.decrement(1) {
123            self.state.wake_all();
124        }
125    }
126}
127
128impl IntoFuture for WaitGroup {
129    type Output = ();
130    type IntoFuture = Wait;
131
132    /// Converts the WaitGroup into a future that completes when all tasks finish. This decreases
133    /// the WaitGroup counter.
134    fn into_future(self) -> Self::IntoFuture {
135        let state = self.state.clone();
136        drop(self);
137        Wait { idx: None, state }
138    }
139}
140
141/// A future that completes when all tasks in a WaitGroup have finished.
142///
143/// This type is created by either: (1) calling `.await` on a `WaitGroup`, or (2) cloning
144/// itself, which does not increase the WaitGroup counter, but creates a new future that
145/// will complete when the WaitGroup counter reaches zero.
146#[must_use = "futures do nothing unless you `.await` or poll them"]
147pub struct Wait {
148    idx: Option<usize>,
149    state: Arc<CountdownState>,
150}
151
152impl Clone for Wait {
153    /// Creates a new future that also completes when the WaitGroup counter reaches zero.
154    ///
155    /// This does not increment the WaitGroup counter.
156    fn clone(&self) -> Self {
157        Wait {
158            idx: None,
159            state: self.state.clone(),
160        }
161    }
162}
163
164impl fmt::Debug for Wait {
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        f.debug_struct("Wait").finish_non_exhaustive()
167    }
168}
169
170impl Future for Wait {
171    type Output = ();
172
173    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
174        let Self { idx, state } = self.get_mut();
175
176        // register waker if the counter is not zero
177        if state.spin_wait(16).is_err() {
178            state.register_waker(idx, cx);
179            // double check after register waker, to catch the update between two steps
180            if state.spin_wait(0).is_err() {
181                return Poll::Pending;
182            }
183        }
184
185        Poll::Ready(())
186    }
187}