use core::{any::Any, future::Future, panic::AssertUnwindSafe};
use std::panic::resume_unwind;
use futures_util::FutureExt as _;
use tokio::sync::mpsc;
use tokio_util::task::TaskTracker;
use tracing::Instrument;
/// Runs `f` with a [`Scope`] and waits for every task it spawns.
///
/// `f` receives `&mut Scope` and may call [`Scope::spawn`] any number of
/// times. Once `f` returns, the scope's tracker is closed and this function
/// waits until every tracked task has finished.
///
/// # Panics
///
/// - If `f` itself panics, the panic propagates out of this call.
/// - If a spawned task panics, its payload is re-raised here with
///   [`resume_unwind`]. This can happen before `f` has returned, in which case
///   `f`'s future is dropped. Only one payload is ever re-raised; any other
///   task panics are lost.
///
/// NOTE(review): on that early exit nothing in this function aborts the tasks
/// that are still running. As far as I can tell they keep running detached.
/// Confirm against the `TaskTracker` docs if callers rely on siblings being
/// cancelled.
pub async fn scope<F>(f: F)
where
    F: for<'a> AsyncFnOnce(&'a mut Scope),
{
    #![allow(clippy::disallowed_macros, reason = "unreachable in select")]
    let (mut scope, mut rx) = Scope::new();
    // Body of the scope: run `f`, then wait for everything it spawned.
    // This future borrows `scope` mutably until it is dropped.
    let run = async {
        f(&mut scope).await;
        // `wait` only resolves once the tracker is closed *and* empty.
        scope.tracker.close();
        scope.tracker.wait().await;
    };
    tokio::select! {
        // A task panicked while the body or other tasks were still running:
        // re-raise immediately. A `None` from `recv` does not match the
        // pattern and only disables this branch.
        Some(err) = rx.recv() => {
            resume_unwind(err);
        }
        () = run => {
            // `run` has been consumed by `select!`, which releases its borrow.
            // Dropping `scope` drops the last `Sender`: each finished task's
            // clone lived inside its now-completed future. The `recv` below
            // therefore yields `None` instead of waiting forever when no task
            // panicked.
            drop(scope);
            // A panic may have been queued just as the last task finished.
            if let Some(err) = rx.recv().await {
                resume_unwind(err);
            }
        }
    }
}
/// Payload of a caught panic, as produced by `catch_unwind` and accepted by
/// `resume_unwind`.
type Panic = Box<dyn Any + Send>;

/// Handle given to the body of [`scope`] for spawning tasks that the scope
/// waits on before it returns.
#[derive(Debug)]
pub struct Scope {
    /// Tracks every task spawned through this handle so they can be awaited.
    tracker: TaskTracker,
    /// Sending half of the panic channel; every spawned task gets a clone.
    tx: mpsc::Sender<Panic>,
}

impl Scope {
    /// Creates an empty scope plus the receiver that observes panics from its
    /// tasks. The channel buffers a single payload.
    fn new() -> (Self, mpsc::Receiver<Panic>) {
        let (tx, rx) = mpsc::channel(1);
        let scope = Self {
            tracker: TaskTracker::new(),
            tx,
        };
        (scope, rx)
    }

    /// Spawns `fut` as a task owned by this scope.
    ///
    /// The task is instrumented with the tracing span that is current at the
    /// time of this call. If `fut` panics, the panic is caught and forwarded
    /// to the enclosing [`scope`] call, which re-raises it.
    pub fn spawn<Fut>(&mut self, fut: Fut)
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        let panic_tx = self.tx.clone();
        let guarded = async move {
            let outcome = AssertUnwindSafe(fut).catch_unwind().await;
            if let Err(payload) = outcome {
                // Best effort: either a payload is already buffered, or the
                // receiver is gone because `scope` is already unwinding. In
                // both cases another panic is already being reported.
                _ = panic_tx.try_send(payload);
            }
        };
        self.tracker.spawn(guarded.in_current_span());
    }
}
#[cfg(test)]
mod test {
    #![allow(clippy::panic)]
    use std::{
        future::pending,
        sync::atomic::{AtomicU32, Ordering},
        time::Duration,
    };
    use tokio::time::sleep;
    use tokio_util::time::FutureExt as _;
    use super::scope;

    /// Every spawned task runs to completion before `scope` returns, and the
    /// tasks run concurrently: run one after another they would need far
    /// longer than `TIMEOUT`.
    #[tokio::test]
    async fn test_scope_usage() {
        const ITERATIONS: u32 = 1000;
        const DELAY: Duration = Duration::from_millis(100);
        const TIMEOUT: Duration = Duration::from_secs(5);
        static COUNTER: AtomicU32 = AtomicU32::new(0);
        assert!(ITERATIONS * DELAY > TIMEOUT);
        scope(async |s| {
            for _ in 0..ITERATIONS {
                s.spawn(async {
                    sleep(DELAY).await;
                    COUNTER.fetch_add(1, Ordering::AcqRel);
                });
            }
        })
        .timeout(TIMEOUT)
        .await
        .unwrap();
        assert_eq!(COUNTER.load(Ordering::Acquire), ITERATIONS);
    }

    /// A scope that spawns nothing returns promptly. This pins down the path
    /// where the body finishes with no panic: the panic channel must report
    /// "closed" once the scope is dropped, rather than blocking forever.
    #[tokio::test]
    async fn test_empty_scope() {
        scope(async |_| {})
            .timeout(Duration::from_secs(1))
            .await
            .unwrap();
    }

    /// A task panic is re-raised even while the body itself never finishes.
    #[tokio::test]
    #[should_panic(expected = "panic while spawning")]
    async fn test_panic_while_spawning() {
        scope(async |s| {
            s.spawn(pending());
            s.spawn(async move {
                panic!("panic while spawning");
            });
            s.spawn(pending());
            pending::<()>().await;
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }

    /// A task panic is re-raised after the body has returned while sibling
    /// tasks are still running.
    #[tokio::test]
    #[should_panic(expected = "panic after spawning")]
    async fn test_panic_after_spawning() {
        scope(async |s| {
            s.spawn(pending());
            s.spawn({
                async {
                    sleep(Duration::from_millis(100)).await;
                    panic!("panic after spawning");
                }
            });
            s.spawn(pending());
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }

    /// A panic in the only spawned task is not lost when the tracker drains
    /// at the same moment. The scope can observe it either from the first
    /// `select!` branch or from the final `recv` after the scope is dropped.
    #[tokio::test]
    #[should_panic(expected = "panic in last task")]
    async fn test_panic_in_last_task() {
        scope(async |s| {
            s.spawn(async {
                panic!("panic in last task");
            });
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }

    /// A panic in the scope body itself propagates out of `scope`.
    #[tokio::test]
    #[should_panic(expected = "panic in scope")]
    async fn test_panic_in_scope() {
        scope(async |s| {
            s.spawn(pending());
            panic!("panic in scope")
        })
        .timeout(Duration::from_secs(1))
        .await
        .unwrap();
    }
}