async_task_group/
lib.rs

1//! Manage a group of tasks on a runtime.
2//!
3//! Enables cancellation to be propagated across tasks, and ensures if an error
4//! occurs that all sibling tasks in the group are cancelled too.
5//!
6//! # Closures and References
7//!
//! When calling `group`, a `TaskGroup` instance is passed into the closure.
//! The intended design is that the instance of `TaskGroup` never outlives the
//! closure it's contained within. This makes it so that when the `TaskGroup`
//! exits or is cancelled, it is also no longer possible to spawn more tasks on
//! the group.
13//!
//! However, async closures are not yet a primitive in Rust, so we cannot yet
//! pass `TaskGroup` by reference. Instead we approximate the borrowing
//! behavior by requiring that the `TaskGroup` instance be returned at the end
//! of the closure.
18//!
19//! # Credit
20//!
21//! This codebase is based on the
22//! [`task-group`](https://github.com/pchickey/task-group) project by Pat
23//! Hickey.
24//!
25//! # Examples
26//!
27//! Create an echo tcp server which processes incoming connections in a loop
28//! without ever creating any dangling tasks:
29//!
30//! ```no_run
31//! use async_std::io;
32//! use async_std::net::{TcpListener, TcpStream};
33//! use async_std::prelude::*;
34//! use async_std::task;
35//!
36//! async fn process(stream: TcpStream) -> io::Result<()> {
37//!     println!("Accepted from: {}", stream.peer_addr()?);
38//!
39//!     let mut reader = stream.clone();
40//!     let mut writer = stream;
41//!     io::copy(&mut reader, &mut writer).await?;
42//!
43//!     Ok(())
44//! }
45//!
46//! #[async_std::main]
47//! async fn main() -> io::Result<()> {
48//!     let listener = TcpListener::bind("127.0.0.1:8080").await?;
49//!     println!("Listening on {}", listener.local_addr()?);
50//!
51//!     let handle = async_task_group::group(|group| async move {
52//!         let mut incoming = listener.incoming();
53//!         while let Some(stream) = incoming.next().await {
54//!             let stream = stream?;
55//!             group.spawn(async move { process(stream).await });
56//!         }
57//!         Ok(group)
58//!     });
59//!     handle.await?;
60//!     Ok(())
61//! }
62//! ```
63
64#![deny(missing_debug_implementations, nonstandard_style)]
65#![warn(missing_docs, unreachable_pub)]
66
67use async_std::channel::{self, Receiver, Sender};
68use async_std::future::Future;
69use async_std::stream::Stream;
70use async_std::task::{self, JoinHandle};
71use async_std::task::{Context, Poll};
72use core::pin::Pin;
73
74/// A TaskGroup is used to spawn a collection of tasks. The collection has two properties:
75/// * if any task returns an error or panicks, all tasks are terminated.
76/// * if the `JoinHandle` returned by `group` is dropped, all tasks are terminated.
77pub struct TaskGroup<E> {
78    sender: Sender<ChildHandle<E>>,
79}
80
81impl<E> std::fmt::Debug for TaskGroup<E> {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("TaskGroup").finish_non_exhaustive()
84    }
85}
86
87/// Create a new instance.
88pub fn group<E, Fut, F>(f: F) -> GroupJoinHandle<E>
89where
90    E: Send + 'static,
91    F: FnOnce(TaskGroup<E>) -> Fut,
92    Fut: Future<Output = Result<TaskGroup<E>, E>> + Send + 'static,
93{
94    let (sender, receiver) = channel::unbounded();
95    let group = TaskGroup { sender };
96    let join_handle = GroupJoinHandle::new(receiver);
97    let fut = f(group.clone());
98    group.spawn(async move {
99        let _ = fut.await;
100        Ok(())
101    });
102    join_handle
103}
104
105impl<E> TaskGroup<E>
106where
107    E: Send + 'static,
108{
109    /// Spawn a new task on the runtime.
110    pub fn spawn<F>(&self, f: F)
111    where
112        F: Future<Output = Result<(), E>> + Send + 'static,
113    {
114        let join = task::spawn(f);
115        self.sender
116            .try_send(ChildHandle { handle: join })
117            .expect("Sending a task to the channel failed");
118    }
119
120    /// Spawn a new local task on the runtime.
121    pub fn spawn_local<F>(&self, f: F)
122    where
123        F: Future<Output = Result<(), E>> + 'static,
124    {
125        let join = task::spawn_local(f);
126        self.sender
127            .try_send(ChildHandle { handle: join })
128            .expect("Sending a task to the channel failed");
129    }
130
131    /// Create a new builder.
132    pub fn build(&self) -> GroupBuilder<'_, E> {
133        GroupBuilder {
134            task_group: self,
135            builder: task::Builder::new(),
136        }
137    }
138
139    /// Returns `true` if the task group has been shut down.
140    pub fn is_closed(&self) -> bool {
141        self.sender.is_closed()
142    }
143
144    // Private clone method. This should not be public to guarantee no handle to
145    // `TaskGroup` cannot outlive the closure in which it is granted. Once Rust
146    // has async closures, we can pass `&TaskGroup` down the to closure.
147    fn clone(&self) -> Self {
148        Self {
149            sender: self.sender.clone(),
150        }
151    }
152}
153
154/// Task builder that configures the settings of a new task
155#[derive(Debug)]
156pub struct GroupBuilder<'a, E> {
157    task_group: &'a TaskGroup<E>,
158    builder: task::Builder,
159}
160
161impl<'a, E> GroupBuilder<'a, E>
162where
163    E: Send + 'static,
164{
165    /// Configures the name of the task.
166    pub fn name<A: AsRef<String>>(mut self, name: A) -> Self {
167        self.builder = self.builder.name(name.as_ref().to_owned());
168        self
169    }
170
171    /// Spawns a task with the configured settings.
172    pub fn spawn<F>(self, future: F)
173    where
174        F: Future<Output = Result<(), E>> + Send + 'static,
175    {
176        let handle = self.builder.spawn(future).unwrap();
177        self.task_group
178            .sender
179            .try_send(ChildHandle { handle })
180            .expect("Sending a task to the channel failed");
181    }
182
183    ///Spawns a task locally with the configured settings.
184    pub fn spawn_local<F>(self, future: F)
185    where
186        F: Future<Output = Result<(), E>> + 'static,
187    {
188        let handle = self.builder.local(future).unwrap();
189        self.task_group
190            .sender
191            .try_send(ChildHandle { handle })
192            .expect("Sending a task to the channel failed");
193    }
194}
195
196#[derive(Debug)]
197struct ChildHandle<E> {
198    handle: JoinHandle<Result<(), E>>,
199}
200
201impl<E> ChildHandle<E> {
202    // Pin projection. Since there is only this one required, avoid pulling in the proc macro.
203    fn pin_join(self: Pin<&mut Self>) -> Pin<&mut JoinHandle<Result<(), E>>> {
204        unsafe { self.map_unchecked_mut(|s| &mut s.handle) }
205    }
206}
207
208// As a consequence of this Drop impl, when a JoinHandle is dropped, all of its children will be
209// canceled.
210impl<E> Drop for ChildHandle<E> {
211    fn drop(&mut self) {}
212}
213
214/// A JoinHandle is used to manage a collection of tasks. There are two
215/// things you can do with it:
216/// * JoinHandle impls Future, so you can poll or await on it. It will be
217/// Ready when all tasks return Ok(()) and the associated `TaskGroup` is
218/// dropped (so no more tasks can be created), or when any task panicks or
219/// returns an Err(E).
220/// * When a JoinHandle is dropped, all tasks it contains are canceled
221/// (terminated). So, if you use a combinator like
222/// `tokio::time::timeout(duration, task_manager).await`, all tasks will be
223/// terminated if the timeout occurs.
224#[derive(Debug)]
225pub struct GroupJoinHandle<E> {
226    channel: Option<Receiver<ChildHandle<E>>>,
227    children: Vec<Pin<Box<ChildHandle<E>>>>,
228}
229
230impl<E> GroupJoinHandle<E> {
231    fn new(channel: Receiver<ChildHandle<E>>) -> Self {
232        Self {
233            channel: Some(channel),
234            children: Vec::new(),
235        }
236    }
237}
238
239impl<E> Future for GroupJoinHandle<E> {
240    type Output = Result<(), E>;
241    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
242        let mut s = self.as_mut();
243
244        // If the channel is still open, take it out of s to satisfy the borrow checker.
245        // We'll put it right back once we're done polling it.
246        if let Some(mut channel) = s.channel.take() {
247            // This loop processes each message in the channel until it is either empty
248            // or closed.
249            s.channel = loop {
250                match unsafe { Pin::new_unchecked(&mut channel) }.poll_next(ctx) {
251                    Poll::Pending => {
252                        // No more messages, but channel still open
253                        break Some(channel);
254                    }
255                    Poll::Ready(Some(new_child)) => {
256                        // Put element from channel into the children
257                        s.children.push(Box::pin(new_child));
258                    }
259                    Poll::Ready(None) => {
260                        // Channel has closed and all messages have been recieved. No
261                        // longer need channel.
262                        break None;
263                    }
264                }
265            };
266        }
267
268        // Need to mutate s after discovering error: store here temporarily
269        let mut err = None;
270        // Need to iterate through vec, possibly removing via swap_remove, so we cant use
271        // a normal iterator:
272        let mut child_ix = 0;
273        while s.children.get(child_ix).is_some() {
274            let child = s
275                .children
276                .get_mut(child_ix)
277                .expect("precondition: child exists at index");
278            match child.as_mut().pin_join().poll(ctx) {
279                // Pending children get retained - move to next
280                Poll::Pending => child_ix += 1,
281                // Child returns successfully: remove it from children.
282                // Then execute the loop body again with ix unchanged, because
283                // last element was swapped into child_ix.
284                Poll::Ready(Ok(())) => {
285                    let _ = s.children.swap_remove(child_ix);
286                }
287                // Child returns with error: yield the error
288                Poll::Ready(Err(error)) => {
289                    err = Some(error);
290                    break;
291                }
292            }
293        }
294
295        if let Some(err) = err {
296            // Drop all children, and the channel reciever, current tasks are destroyed
297            // and new tasks cannot be created:
298            s.children.truncate(0);
299            s.channel.take();
300            // Return the error:
301            Poll::Ready(Err(err))
302        } else if s.children.is_empty() {
303            if s.channel.is_none() {
304                // Task manager is complete when there are no more children, and
305                // no more channel to get more children:
306                Poll::Ready(Ok(()))
307            } else {
308                // Channel is still pending, so we are not done:
309                Poll::Pending
310            }
311        } else {
312            Poll::Pending
313        }
314    }
315}
316
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::anyhow;

    // A group whose root closure spawns nothing completes successfully.
    #[async_std::test]
    async fn no_task() {
        let outcome = group(|tasks| async move { Ok::<_, ()>(tasks) }).await;
        assert!(outcome.is_ok());
    }

    // A single child that immediately succeeds lets the group succeed.
    #[async_std::test]
    async fn one_empty_task() {
        let outcome = group(|tasks| async move {
            tasks.spawn(async move { Ok(()) });
            Ok::<_, ()>(tasks)
        })
        .await;
        assert!(outcome.is_ok());
    }

    // An error returned by a spawned child is reported by the join handle.
    #[async_std::test]
    async fn root_task_errors() {
        let outcome = group(|tasks| async move {
            tasks.spawn(async { Err(anyhow!("idk!")) });
            Ok(tasks)
        })
        .await;
        assert!(outcome.is_err());
        assert_eq!(format!("{:?}", outcome), "Err(idk!)");
    }
}
345        assert_eq!(format!("{:?}", res), "Err(idk!)");
346    }
347}