Skip to main content

aranya_util/
task.rs

1use core::{any::Any, future::Future, panic::AssertUnwindSafe};
2use std::panic::resume_unwind;
3
4use futures_util::FutureExt as _;
5use tokio::sync::mpsc;
6use tokio_util::task::TaskTracker;
7use tracing::Instrument;
8
9// TODO(jdygert): Abort all tasks on drop?
10
11/// Creates a scope for spawning scoped async tasks for structured concurrency.
12///
13/// The given function will be provided a [`Scope`] through which tasks can be
14/// [`spawned`][Scope::spawn]. The resulting future will pend until all tasks have finished.
15/// The tracked tasks can free their memory as soon as they finish, making this well suited for
16/// spawning many short lived tasks over time.
17///
18/// Unlike [`std::thread::scope`], this function does not let you spawn tasks which borrow from the
19/// local scope, since there is no safe way to do so.
20///
21/// # Panics
22///
23/// If any of the spawned tasks panic, this future will panic.
24///
25/// # Example
26///
27/// ```no_run
28/// # async fn test() {
29/// # use core::time::Duration;
30/// # use tokio::time::sleep;
31/// use aranya_util::task::scope;
32/// // prints "Hello, world!" after 1s and resolves after 10s.
33/// scope(async |s| {
34///     s.spawn(async {
35///         sleep(Duration::from_secs(10)).await;
36///     });
37///
38///     let msg = String::from("Hello, world!");
39///     s.spawn(async move {
40///         sleep(Duration::from_secs(1)).await;
41///         println!("{msg}");
42///     });
43/// })
44/// .await;
45/// # }
46/// ```
47pub async fn scope<F>(f: F)
48where
49    F: for<'a> AsyncFnOnce(&'a mut Scope),
50{
51    #![allow(clippy::disallowed_macros, reason = "unreachable in select")]
52
53    let (mut scope, mut rx) = Scope::new();
54    let run = async {
55        f(&mut scope).await;
56        scope.tracker.close();
57        scope.tracker.wait().await;
58    };
59    tokio::select! {
60        Some(err) = rx.recv() => {
61            resume_unwind(err);
62        }
63        () = run => {
64            drop(scope);
65            if let Some(err) = rx.recv().await {
66                resume_unwind(err);
67            }
68        }
69    }
70}
71
/// Payload of a caught panic, as produced by `catch_unwind` and consumed by `resume_unwind`.
type Panic = Box<dyn Any + Send + 'static>;
73
74#[derive(Debug)]
75pub struct Scope {
76    tracker: TaskTracker,
77    tx: mpsc::Sender<Panic>,
78}
79
80impl Scope {
81    fn new() -> (Self, mpsc::Receiver<Panic>) {
82        let (tx, rx) = mpsc::channel(1);
83        (
84            Self {
85                tracker: TaskTracker::new(),
86                tx,
87            },
88            rx,
89        )
90    }
91
92    /// Spawns a future as a task.
93    ///
94    /// The future must be `Send + 'static`.
95    pub fn spawn<Fut>(&mut self, fut: Fut)
96    where
97        Fut: Future<Output = ()> + Send + 'static,
98    {
99        let tx = self.tx.clone();
100        self.tracker.spawn(
101            async move {
102                // Note: Tokio's join error gives you the panic payload anyways, so using
103                // `AssertUnwindSafe` here isn't any less unwind-safe than that.
104                // (`UnwindSafe` is more like a lint anyways.)
105                if let Err(err) = AssertUnwindSafe(fut).catch_unwind().await {
106                    _ = tx.try_send(err);
107                }
108            }
109            .in_current_span(),
110        );
111    }
112}
113
#[cfg(test)]
mod test {
    #![allow(clippy::panic)]

    use std::{
        future::pending,
        sync::atomic::{AtomicU32, Ordering},
        time::Duration,
    };

    use tokio::time::sleep;
    use tokio_util::time::FutureExt as _;

    use super::scope;

    /// Many concurrent tasks all run to completion before `scope` resolves.
    #[tokio::test]
    async fn test_scope_usage() {
        const ITERATIONS: u32 = 1000;
        const DELAY: Duration = Duration::from_millis(100);
        const TIMEOUT: Duration = Duration::from_secs(5);

        static COUNTER: AtomicU32 = AtomicU32::new(0);

        // This ensures that the task cannot be run sequentially within the timeout.
        assert!(ITERATIONS * DELAY > TIMEOUT);

        scope(async |s| {
            for _ in 0..ITERATIONS {
                s.spawn(async {
                    sleep(DELAY).await;
                    COUNTER.fetch_add(1, Ordering::AcqRel);
                });
            }
        })
        .timeout(TIMEOUT)
        .await
        .unwrap();
        assert_eq!(COUNTER.load(Ordering::Acquire), ITERATIONS);
    }

    /// A scope that spawns nothing resolves as soon as its closure returns.
    #[tokio::test]
    async fn test_empty_scope() {
        scope(async |_| {})
            .timeout(Duration::from_secs(1))
            .await
            .unwrap();
    }

    /// A task panic is re-raised even though the closure itself never returns.
    #[tokio::test]
    #[should_panic(expected = "panic while spawning")]
    async fn test_panic_while_spawning() {
        scope(async |s| {
            s.spawn(pending());
            s.spawn(async move {
                panic!("panic while spawning");
            });
            s.spawn(pending());
            pending::<()>().await;
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }

    /// A task panic is re-raised after the closure returned, while other tasks still pend.
    #[tokio::test]
    #[should_panic(expected = "panic after spawning")]
    async fn test_panic_after_spawning() {
        scope(async |s| {
            s.spawn(pending());
            s.spawn({
                async {
                    sleep(Duration::from_millis(100)).await;
                    panic!("panic after spawning");
                }
            });
            s.spawn(pending());
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }

    /// A panic in the last task to finish is still re-raised.
    ///
    /// The panic is reported whichever way `scope` observes it: directly from the panic
    /// channel, or by draining the channel after all tasks have completed.
    #[tokio::test]
    #[should_panic(expected = "panic in last task")]
    async fn test_panic_in_last_task() {
        scope(async |s| {
            s.spawn(async {
                panic!("panic in last task");
            });
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }

    /// A panic in the closure itself propagates out of `scope`.
    #[tokio::test]
    #[should_panic(expected = "panic in scope")]
    async fn test_panic_in_scope() {
        scope(async |s| {
            s.spawn(pending());
            panic!("panic in scope")
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }
}