async_task_group/lib.rs
1//! Manage a group of tasks on a runtime.
2//!
//! Enables cancellation to be propagated across tasks, and ensures that if an
//! error occurs, all sibling tasks in the group are cancelled too.
5//!
6//! # Closures and References
7//!
//! When calling `group`, a `TaskGroup` instance is passed into the closure.
//! The intended design is that this `TaskGroup` instance never outlives the
//! closure it is handed to. That way, once the group exits or is cancelled,
//! it is also no longer possible to spawn more tasks on the group.
//!
//! However, async closures are not yet a primitive in Rust, so we cannot yet
//! pass `TaskGroup` by reference. Instead, we approximate the borrowing
//! behavior by requiring the `TaskGroup` instance to be returned at the end
//! of the closure.
18//!
19//! # Credit
20//!
21//! This codebase is based on the
22//! [`task-group`](https://github.com/pchickey/task-group) project by Pat
23//! Hickey.
24//!
25//! # Examples
26//!
27//! Create an echo tcp server which processes incoming connections in a loop
28//! without ever creating any dangling tasks:
29//!
30//! ```no_run
31//! use async_std::io;
32//! use async_std::net::{TcpListener, TcpStream};
33//! use async_std::prelude::*;
34//! use async_std::task;
35//!
36//! async fn process(stream: TcpStream) -> io::Result<()> {
37//! println!("Accepted from: {}", stream.peer_addr()?);
38//!
39//! let mut reader = stream.clone();
40//! let mut writer = stream;
41//! io::copy(&mut reader, &mut writer).await?;
42//!
43//! Ok(())
44//! }
45//!
46//! #[async_std::main]
47//! async fn main() -> io::Result<()> {
48//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
49//! println!("Listening on {}", listener.local_addr()?);
50//!
51//! let handle = async_task_group::group(|group| async move {
52//! let mut incoming = listener.incoming();
53//! while let Some(stream) = incoming.next().await {
54//! let stream = stream?;
55//! group.spawn(async move { process(stream).await });
56//! }
57//! Ok(group)
58//! });
59//! handle.await?;
60//! Ok(())
61//! }
62//! ```
63
64#![deny(missing_debug_implementations, nonstandard_style)]
65#![warn(missing_docs, unreachable_pub)]
66
67use async_std::channel::{self, Receiver, Sender};
68use async_std::future::Future;
69use async_std::stream::Stream;
70use async_std::task::{self, JoinHandle};
71use async_std::task::{Context, Poll};
72use core::pin::Pin;
73
/// A `TaskGroup` is used to spawn a collection of tasks. The collection is
/// intended to have two properties:
/// * if any task returns an error or panics, all tasks are terminated.
/// * if the `GroupJoinHandle` returned by `group` is dropped, all tasks are
///   terminated.
///
/// NOTE(review): both properties depend on dropping a task's `ChildHandle`
/// actually cancelling the task. `ChildHandle`'s `Drop` impl is empty, and
/// async-std documents that dropping a `JoinHandle` detaches the task rather
/// than cancelling it. Verify that tasks are really terminated.
pub struct TaskGroup<E> {
    // Sending half of the group's channel. Each spawned task's handle is sent
    // through it to the `GroupJoinHandle`, which owns the receiving half.
    sender: Sender<ChildHandle<E>>,
}
80
81impl<E> std::fmt::Debug for TaskGroup<E> {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("TaskGroup").finish_non_exhaustive()
84 }
85}
86
87/// Create a new instance.
88pub fn group<E, Fut, F>(f: F) -> GroupJoinHandle<E>
89where
90 E: Send + 'static,
91 F: FnOnce(TaskGroup<E>) -> Fut,
92 Fut: Future<Output = Result<TaskGroup<E>, E>> + Send + 'static,
93{
94 let (sender, receiver) = channel::unbounded();
95 let group = TaskGroup { sender };
96 let join_handle = GroupJoinHandle::new(receiver);
97 let fut = f(group.clone());
98 group.spawn(async move {
99 let _ = fut.await;
100 Ok(())
101 });
102 join_handle
103}
104
105impl<E> TaskGroup<E>
106where
107 E: Send + 'static,
108{
109 /// Spawn a new task on the runtime.
110 pub fn spawn<F>(&self, f: F)
111 where
112 F: Future<Output = Result<(), E>> + Send + 'static,
113 {
114 let join = task::spawn(f);
115 self.sender
116 .try_send(ChildHandle { handle: join })
117 .expect("Sending a task to the channel failed");
118 }
119
120 /// Spawn a new local task on the runtime.
121 pub fn spawn_local<F>(&self, f: F)
122 where
123 F: Future<Output = Result<(), E>> + 'static,
124 {
125 let join = task::spawn_local(f);
126 self.sender
127 .try_send(ChildHandle { handle: join })
128 .expect("Sending a task to the channel failed");
129 }
130
131 /// Create a new builder.
132 pub fn build(&self) -> GroupBuilder<'_, E> {
133 GroupBuilder {
134 task_group: self,
135 builder: task::Builder::new(),
136 }
137 }
138
139 /// Returns `true` if the task group has been shut down.
140 pub fn is_closed(&self) -> bool {
141 self.sender.is_closed()
142 }
143
144 // Private clone method. This should not be public to guarantee no handle to
145 // `TaskGroup` cannot outlive the closure in which it is granted. Once Rust
146 // has async closures, we can pass `&TaskGroup` down the to closure.
147 fn clone(&self) -> Self {
148 Self {
149 sender: self.sender.clone(),
150 }
151 }
152}
153
/// Task builder that configures the settings of a new task before it is
/// spawned on a `TaskGroup`. Created by `TaskGroup::build`.
#[derive(Debug)]
pub struct GroupBuilder<'a, E> {
    // Group that the configured task is registered with once it is spawned.
    task_group: &'a TaskGroup<E>,
    // Underlying async-std task builder that accumulates the settings.
    builder: task::Builder,
}
160
161impl<'a, E> GroupBuilder<'a, E>
162where
163 E: Send + 'static,
164{
165 /// Configures the name of the task.
166 pub fn name<A: AsRef<String>>(mut self, name: A) -> Self {
167 self.builder = self.builder.name(name.as_ref().to_owned());
168 self
169 }
170
171 /// Spawns a task with the configured settings.
172 pub fn spawn<F>(self, future: F)
173 where
174 F: Future<Output = Result<(), E>> + Send + 'static,
175 {
176 let handle = self.builder.spawn(future).unwrap();
177 self.task_group
178 .sender
179 .try_send(ChildHandle { handle })
180 .expect("Sending a task to the channel failed");
181 }
182
183 ///Spawns a task locally with the configured settings.
184 pub fn spawn_local<F>(self, future: F)
185 where
186 F: Future<Output = Result<(), E>> + 'static,
187 {
188 let handle = self.builder.local(future).unwrap();
189 self.task_group
190 .sender
191 .try_send(ChildHandle { handle })
192 .expect("Sending a task to the channel failed");
193 }
194}
195
/// A task spawned on a `TaskGroup`, as tracked by the group's
/// `GroupJoinHandle`.
#[derive(Debug)]
struct ChildHandle<E> {
    // Join handle of the spawned task. It is polled through `pin_join`.
    handle: JoinHandle<Result<(), E>>,
}
200
201impl<E> ChildHandle<E> {
202 // Pin projection. Since there is only this one required, avoid pulling in the proc macro.
203 fn pin_join(self: Pin<&mut Self>) -> Pin<&mut JoinHandle<Result<(), E>>> {
204 unsafe { self.map_unchecked_mut(|s| &mut s.handle) }
205 }
206}
207
// Note that this Drop impl is empty and does not cancel anything itself:
// dropping a `ChildHandle` only drops its inner `JoinHandle`. The intent is
// that when a `GroupJoinHandle` is dropped, all of its children are cancelled.
// NOTE(review): async-std documents that dropping a `JoinHandle` *detaches*
// the task rather than cancelling it, so children may keep running. Verify
// this. Cancelling would need `JoinHandle::cancel`, which consumes the handle,
// so the `handle` field would have to become movable (e.g. an `Option`).
impl<E> Drop for ChildHandle<E> {
    fn drop(&mut self) {}
}
213
/// A `GroupJoinHandle` is used to manage a collection of tasks. There are two
/// things you can do with it:
/// * `GroupJoinHandle` impls `Future`, so you can poll or await it. It
///   resolves to `Ok(())` when all tasks have returned `Ok(())` and every
///   associated `TaskGroup` has been dropped (so no more tasks can be
///   created). It resolves to `Err(E)` as soon as a polled task returns
///   `Err(E)`.
/// * Dropping a `GroupJoinHandle` is intended to cancel (terminate) all the
///   tasks it contains. That way, a combinator like
///   `async_std::future::timeout(duration, handle).await` terminates every
///   task if the timeout occurs.
///
/// NOTE(review): cancel-on-drop relies on dropping each `ChildHandle`, whose
/// `Drop` impl is empty. async-std documents that dropping a `JoinHandle`
/// detaches the task, so tasks may keep running. Verify this. Panicking
/// tasks are not handled explicitly here either; how a panic surfaces
/// depends on async-std's `JoinHandle`, so verify that too.
#[derive(Debug)]
pub struct GroupJoinHandle<E> {
    // Receiving half of the group's channel. Set to `None` once the channel
    // has closed, or after a task has failed.
    channel: Option<Receiver<ChildHandle<E>>>,
    // Tasks received from the channel that have not completed yet.
    children: Vec<Pin<Box<ChildHandle<E>>>>,
}
229
230impl<E> GroupJoinHandle<E> {
231 fn new(channel: Receiver<ChildHandle<E>>) -> Self {
232 Self {
233 channel: Some(channel),
234 children: Vec::new(),
235 }
236 }
237}
238
239impl<E> Future for GroupJoinHandle<E> {
240 type Output = Result<(), E>;
241 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
242 let mut s = self.as_mut();
243
244 // If the channel is still open, take it out of s to satisfy the borrow checker.
245 // We'll put it right back once we're done polling it.
246 if let Some(mut channel) = s.channel.take() {
247 // This loop processes each message in the channel until it is either empty
248 // or closed.
249 s.channel = loop {
250 match unsafe { Pin::new_unchecked(&mut channel) }.poll_next(ctx) {
251 Poll::Pending => {
252 // No more messages, but channel still open
253 break Some(channel);
254 }
255 Poll::Ready(Some(new_child)) => {
256 // Put element from channel into the children
257 s.children.push(Box::pin(new_child));
258 }
259 Poll::Ready(None) => {
260 // Channel has closed and all messages have been recieved. No
261 // longer need channel.
262 break None;
263 }
264 }
265 };
266 }
267
268 // Need to mutate s after discovering error: store here temporarily
269 let mut err = None;
270 // Need to iterate through vec, possibly removing via swap_remove, so we cant use
271 // a normal iterator:
272 let mut child_ix = 0;
273 while s.children.get(child_ix).is_some() {
274 let child = s
275 .children
276 .get_mut(child_ix)
277 .expect("precondition: child exists at index");
278 match child.as_mut().pin_join().poll(ctx) {
279 // Pending children get retained - move to next
280 Poll::Pending => child_ix += 1,
281 // Child returns successfully: remove it from children.
282 // Then execute the loop body again with ix unchanged, because
283 // last element was swapped into child_ix.
284 Poll::Ready(Ok(())) => {
285 let _ = s.children.swap_remove(child_ix);
286 }
287 // Child returns with error: yield the error
288 Poll::Ready(Err(error)) => {
289 err = Some(error);
290 break;
291 }
292 }
293 }
294
295 if let Some(err) = err {
296 // Drop all children, and the channel reciever, current tasks are destroyed
297 // and new tasks cannot be created:
298 s.children.truncate(0);
299 s.channel.take();
300 // Return the error:
301 Poll::Ready(Err(err))
302 } else if s.children.is_empty() {
303 if s.channel.is_none() {
304 // Task manager is complete when there are no more children, and
305 // no more channel to get more children:
306 Poll::Ready(Ok(()))
307 } else {
308 // Channel is still pending, so we are not done:
309 Poll::Pending
310 }
311 } else {
312 Poll::Pending
313 }
314 }
315}
316
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::anyhow;

    /// A group whose closure spawns nothing completes successfully.
    #[async_std::test]
    async fn no_task() {
        let result = group(|group| async move { Ok::<_, ()>(group) }).await;
        assert!(result.is_ok());
    }

    /// A single child that succeeds immediately lets the group complete.
    #[async_std::test]
    async fn one_empty_task() {
        let result = group(|group| async move {
            group.spawn(async move { Ok(()) });
            Ok::<_, ()>(group)
        })
        .await;
        assert!(result.is_ok());
    }

    /// An error returned by a spawned task becomes the group's result.
    #[async_std::test]
    async fn root_task_errors() {
        let result = group(|group| async move {
            group.spawn(async { Err(anyhow!("idk!")) });
            Ok(group)
        })
        .await;
        assert!(result.is_err());
        assert_eq!(format!("{:?}", result), "Err(idk!)");
    }
}