mea/waitgroup/mod.rs
1// Copyright 2024 tison <wander4096@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! A synchronization primitive for waiting on multiple tasks to complete.
16//!
17//! Similar to Go's WaitGroup, this type allows a task to wait for multiple other
18//! tasks to finish. Each task holds a handle to the WaitGroup, and the main task
19//! can wait for all handles to be dropped before proceeding.
20//!
21//! A WaitGroup waits for a collection of tasks to finish. The main task calls
22//! [`clone()`] to create a new worker handle for each task, and can then wait
23//! for all tasks to complete by calling `.await` on the WaitGroup.
24//!
25//! # Examples
26//!
27//! ```
28//! # #[tokio::main]
29//! # async fn main() {
30//! use std::time::Duration;
31//!
32//! use mea::waitgroup::WaitGroup;
33//! let wg = WaitGroup::new();
34//!
35//! for i in 0..3 {
36//! let wg = wg.clone();
37//! tokio::spawn(async move {
38//! println!("Task {} starting", i);
39//! tokio::time::sleep(Duration::from_millis(100)).await;
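//!         // dropping `wg` decrements the WaitGroup counter; this would also happen
//!         // implicitly at the end of the task, the explicit `drop` just makes it visible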
41//! drop(wg);
42//! });
43//! }
44//!
45//! // Wait for all tasks to complete
46//! wg.await;
47//! println!("All tasks completed");
48//! # }
49//! ```
50//!
51//! [`clone()`]: WaitGroup::clone
52
53use std::fmt;
54use std::future::Future;
55use std::future::IntoFuture;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::Context;
59use std::task::Poll;
60
61use crate::internal::CountdownState;
62
63#[cfg(test)]
64mod tests;
65
66/// A synchronization primitive for waiting on multiple tasks to complete.
67///
68/// See the [module level documentation](self) for more.
69pub struct WaitGroup {
70 state: Arc<CountdownState>,
71}
72
73impl fmt::Debug for WaitGroup {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 f.debug_struct("WaitGroup").finish_non_exhaustive()
76 }
77}
78
79impl Default for WaitGroup {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl WaitGroup {
86 /// Creates a new `WaitGroup`.
87 ///
88 /// # Examples
89 ///
90 /// ```
91 /// use mea::waitgroup::WaitGroup;
92 ///
93 /// let wg = WaitGroup::new();
94 /// ```
95 pub fn new() -> Self {
96 Self {
97 state: Arc::new(CountdownState::new(1)),
98 }
99 }
100}
101
102impl Clone for WaitGroup {
103 /// Creates a new worker handle for the WaitGroup.
104 ///
105 /// This increments the WaitGroup counter. The counter will be decremented
106 /// when the new handle is dropped.
107 fn clone(&self) -> Self {
108 let sync = self.state.clone();
109 let mut cnt = sync.state();
110 loop {
111 let new_cnt = cnt.saturating_add(1);
112 match sync.cas_state(cnt, new_cnt) {
113 Ok(_) => return Self { state: sync },
114 Err(x) => cnt = x,
115 }
116 }
117 }
118}
119
120impl Drop for WaitGroup {
121 fn drop(&mut self) {
122 if self.state.decrement(1) {
123 self.state.wake_all();
124 }
125 }
126}
127
128impl IntoFuture for WaitGroup {
129 type Output = ();
130 type IntoFuture = Wait;
131
132 /// Converts the WaitGroup into a future that completes when all tasks finish. This decreases
133 /// the WaitGroup counter.
134 fn into_future(self) -> Self::IntoFuture {
135 let state = self.state.clone();
136 drop(self);
137 Wait { idx: None, state }
138 }
139}
140
141/// A future that completes when all tasks in a WaitGroup have finished.
142///
143/// This type is created by either: (1) calling `.await` on a `WaitGroup`, or (2) cloning
144/// itself, which does not increase the WaitGroup counter, but creates a new future that
145/// will complete when the WaitGroup counter reaches zero.
146#[must_use = "futures do nothing unless you `.await` or poll them"]
147pub struct Wait {
148 idx: Option<usize>,
149 state: Arc<CountdownState>,
150}
151
152impl Clone for Wait {
153 /// Creates a new future that also completes when the WaitGroup counter reaches zero.
154 ///
155 /// This does not increment the WaitGroup counter.
156 fn clone(&self) -> Self {
157 Wait {
158 idx: None,
159 state: self.state.clone(),
160 }
161 }
162}
163
164impl fmt::Debug for Wait {
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 f.debug_struct("Wait").finish_non_exhaustive()
167 }
168}
169
impl Future for Wait {
    type Output = ();

    /// Resolves once the WaitGroup counter has reached zero.
    ///
    /// The counter is first checked without touching any waker. If it is not
    /// yet zero, the task's waker is registered and the counter is checked a
    /// second time. This ensures that a release landing between the first check
    /// and the registration is not missed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Wait` is `Unpin` (an `Option<usize>` and an `Arc`), so the pin can be
        // unwrapped to borrow both fields mutably.
        let Self { idx, state } = self.get_mut();

        // register waker if the counter is not zero
        // NOTE(review): 16 presumably bounds how long `spin_wait` spins before
        // reporting `Err` — confirm against `CountdownState::spin_wait`.
        if state.spin_wait(16).is_err() {
            // `idx` presumably lets repeated polls reuse this future's waker slot
            // instead of registering a new one each time — TODO confirm.
            state.register_waker(idx, cx);
            // double check after register waker, to catch the update between two steps
            if state.spin_wait(0).is_err() {
                return Poll::Pending;
            }
            // NOTE(review): when the re-check succeeds, the waker registered just
            // above is not removed here; presumably `CountdownState` tolerates
            // such stale entries — confirm.
        }

        Poll::Ready(())
    }
}