use std::future::Future;
use tokio::task::JoinHandle;
/// Spawns `fut` as a Tokio task and supervises it with a second task that
/// logs how it ended.
///
/// The returned handle belongs to the *supervisor* task, not to `fut` itself.
/// The supervisor awaits the inner task and always finishes with `()`, so a
/// panic inside `fut` is logged rather than surfaced through the returned
/// handle.
///
/// Outcomes are reported through `tracing`, tagged with `task = name`:
/// - normal exit and cancellation are logged at `debug`;
/// - a panic (with its payload rendered by `format_panic`) and any other join
///   failure are logged at `error`.
///
/// Aborting the returned handle also aborts the inner task: the supervisor
/// owns the inner handle through a guard that calls `abort()` when the
/// supervisor future is dropped. Merely dropping the returned handle does not
/// cancel anything; both tasks keep running detached.
///
/// # Panics
///
/// Panics if called outside of a Tokio runtime context, because it calls
/// `tokio::spawn`.
pub fn spawn_supervised<F>(name: &'static str, fut: F) -> JoinHandle<()>
where
    F: Future<Output = ()> + Send + 'static,
{
    /// Owns the inner task's handle and aborts that task when dropped.
    struct AbortOnDrop(JoinHandle<()>);

    impl Drop for AbortOnDrop {
        fn drop(&mut self) {
            // Has no effect if the task has already finished.
            self.0.abort();
        }
    }

    let inner = AbortOnDrop(tokio::spawn(fut));
    tokio::spawn(async move {
        // Rebind so the whole guard (not just its field) lives inside this
        // future; if the supervisor is aborted, dropping the future drops the
        // guard and the inner task is cancelled instead of leaking.
        let mut inner = inner;

        // Await by `&mut` so the guard keeps ownership of the handle.
        match (&mut inner.0).await {
            Ok(()) => {
                tracing::debug!(task = name, "supervised task exited");
            }
            Err(e) if e.is_panic() => {
                tracing::error!(
                    task = name,
                    panic = %format_panic(e),
                    "supervised task panicked",
                );
            }
            Err(e) if e.is_cancelled() => {
                tracing::debug!(task = name, "supervised task cancelled");
            }
            Err(e) => {
                tracing::error!(
                    task = name,
                    error = %e,
                    "supervised task join failed",
                );
            }
        }
    })
}
/// Renders a task join error as a human-readable string.
///
/// For a panic, the payload is returned when it is a `&'static str` or a
/// `String`; any other payload type yields a fixed placeholder. A join error
/// that is not a panic falls back to its `Display` output.
fn format_panic(e: tokio::task::JoinError) -> String {
    // Guard clause: anything that is not a panic is rendered via `Display`.
    if !e.is_panic() {
        return e.to_string();
    }

    let payload = e.into_panic();

    // Try `&'static str` first, then `String`, exactly in that order.
    payload
        .downcast_ref::<&'static str>()
        .map(|msg| (*msg).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "<non-string panic payload>".to_owned())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// A well-behaved future runs to the end and the supervisor handle joins
    /// successfully.
    #[tokio::test]
    async fn supervised_task_runs_to_completion() {
        let flag = Arc::new(AtomicBool::new(false));

        let supervisor = {
            let flag = Arc::clone(&flag);
            spawn_supervised("test.ok", async move {
                flag.store(true, Ordering::SeqCst);
            })
        };

        supervisor.await.unwrap();
        assert!(flag.load(Ordering::SeqCst));
    }

    /// A panicking future must not surface as a join error on the supervisor
    /// handle.
    #[tokio::test]
    async fn supervised_task_catches_panic() {
        let supervisor = spawn_supervised("test.panic", async move {
            panic!("intentional");
        });

        let joined = supervisor.await;
        assert!(joined.is_ok(), "supervisor must absorb the panic");
    }
}