use core::{any::Any, future::Future, panic::AssertUnwindSafe};
use std::panic::resume_unwind;

use futures_util::FutureExt as _;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tracing::Instrument;
8
9pub async fn scope<F>(f: F)
48where
49 F: for<'a> AsyncFnOnce(&'a mut Scope),
50{
51 #![allow(clippy::disallowed_macros, reason = "unreachable in select")]
52
53 let (mut scope, mut rx) = Scope::new();
54 let run = async {
55 f(&mut scope).await;
56 scope.tracker.close();
57 scope.tracker.wait().await;
58 };
59 tokio::select! {
60 Some(err) = rx.recv() => {
61 resume_unwind(err);
62 }
63 () = run => {
64 drop(scope);
65 if let Some(err) = rx.recv().await {
66 resume_unwind(err);
67 }
68 }
69 }
70}
71
/// Payload of a caught panic: what `catch_unwind` yields on failure and `resume_unwind` accepts.
type Panic = Box<dyn Any + Send + 'static>;
73
74#[derive(Debug)]
75pub struct Scope {
76 tracker: TaskTracker,
77 tx: mpsc::Sender<Panic>,
78}
79
80impl Scope {
81 fn new() -> (Self, mpsc::Receiver<Panic>) {
82 let (tx, rx) = mpsc::channel(1);
83 (
84 Self {
85 tracker: TaskTracker::new(),
86 tx,
87 },
88 rx,
89 )
90 }
91
92 pub fn spawn<Fut>(&mut self, fut: Fut)
96 where
97 Fut: Future<Output = ()> + Send + 'static,
98 {
99 let tx = self.tx.clone();
100 self.tracker.spawn(
101 async move {
102 if let Err(err) = AssertUnwindSafe(fut).catch_unwind().await {
106 _ = tx.try_send(err);
107 }
108 }
109 .in_current_span(),
110 );
111 }
112}
113
#[cfg(test)]
mod test {
    #![allow(clippy::panic)]

    use std::{
        future::pending,
        sync::atomic::{AtomicU32, Ordering},
        time::Duration,
    };

    use tokio::time::sleep;
    use tokio_util::time::FutureExt as _;

    use super::scope;

    /// Many sleeping tasks must run concurrently, and `scope` must wait for all of them.
    #[tokio::test]
    async fn test_scope_usage() {
        const ITERATIONS: u32 = 1000;
        const DELAY: Duration = Duration::from_millis(100);
        const TIMEOUT: Duration = Duration::from_secs(5);

        static COUNTER: AtomicU32 = AtomicU32::new(0);

        // Running the tasks one after another would exceed the timeout, so finishing in
        // time shows they ran concurrently.
        assert!(ITERATIONS * DELAY > TIMEOUT);

        scope(async |s| {
            for _ in 0..ITERATIONS {
                s.spawn(async {
                    sleep(DELAY).await;
                    COUNTER.fetch_add(1, Ordering::AcqRel);
                });
            }
        })
        .timeout(TIMEOUT)
        .await
        .unwrap();
        // Every task ran to completion before `scope` returned.
        assert_eq!(COUNTER.load(Ordering::Acquire), ITERATIONS);
    }

    /// A task panic propagates even though the body and the other tasks never finish.
    #[tokio::test]
    #[should_panic(expected = "panic while spawning")]
    async fn test_panic_while_spawning() {
        scope(async |s| {
            s.spawn(pending());
            s.spawn(async move {
                panic!("panic while spawning");
            });
            s.spawn(pending());
            pending::<()>().await;
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }

    /// A task panic that happens after the body has returned, while other tasks are still
    /// pending, propagates.
    #[tokio::test]
    #[should_panic(expected = "panic after spawning")]
    async fn test_panic_after_spawning() {
        scope(async |s| {
            s.spawn(pending());
            s.spawn(async {
                sleep(Duration::from_millis(100)).await;
                panic!("panic after spawning");
            });
            s.spawn(pending());
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }

    /// A panic in the scope body itself propagates out of `scope`.
    #[tokio::test]
    #[should_panic(expected = "panic in scope")]
    async fn test_panic_in_scope() {
        scope(async |s| {
            s.spawn(pending());
            panic!("panic in scope")
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }

    /// The only task panics and then finishes, so both `select!` branches in `scope` can
    /// become ready at once. Whichever wins, the panic must be re-raised: either directly
    /// or by reading it from the channel after the body and all tasks have completed.
    #[tokio::test]
    #[should_panic(expected = "panic in finished task")]
    async fn test_panic_in_finished_task() {
        scope(async |s| {
            s.spawn(async {
                panic!("panic in finished task");
            });
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }
}