use std::sync::Arc;
use parking_lot::Mutex;
use tokio::task::JoinHandle;
/// A background task registered with [`BackgroundTaskMonitor`], paired with the
/// label that is used when reporting how it exited.
struct TrackedTask {
    /// Label included in the `tracing` event and the returned error message.
    name: &'static str,
    /// Handle through which the task's completion (or panic/cancellation) is observed.
    handle: JoinHandle<()>,
}
/// Watches a set of long-running background tasks and reports the first one to exit.
///
/// Cloning is cheap and every clone shares the same underlying task list (it is
/// held behind an `Arc`), so a task registered through one clone is visible to
/// the others. `Default` produces an empty monitor, equivalent to
/// `BackgroundTaskMonitor::new()`.
#[derive(Clone, Default)]
pub(crate) struct BackgroundTaskMonitor {
    /// Tasks registered so far that have not yet been taken by `wait_for_any_exit`.
    inner: Arc<Mutex<Vec<TrackedTask>>>,
}
impl BackgroundTaskMonitor {
    /// Creates a monitor with no registered tasks.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Registers a spawned background task under `name` so that
    /// [`wait_for_any_exit`](Self::wait_for_any_exit) can observe it.
    ///
    /// All clones of a monitor share one task list, so a task registered through
    /// any clone is visible to every other clone.
    pub fn register(&self, name: &'static str, handle: JoinHandle<()>) {
        self.inner.lock().push(TrackedTask { name, handle });
    }

    /// Returns a future that resolves as soon as any registered background task
    /// exits, yielding an error describing how it exited.
    ///
    /// Background tasks are expected to run indefinitely, so every kind of exit is
    /// reported (and logged with `tracing::error!`) as an error:
    /// - a panic → "Background task '<name>' panicked: …"
    /// - any other join error, e.g. cancellation → "Background task '<name>' failed: …"
    /// - a clean return → "Background task '<name>' exited unexpectedly"
    ///
    /// The registered tasks are removed from the monitor *when this method is
    /// called*, not when the returned future is first polled. The shared list is
    /// left empty, and tasks registered afterwards (through any clone) are not
    /// watched by this future.
    ///
    /// If no tasks are registered, the returned future never resolves.
    ///
    /// After the first exit (or if the future is dropped before resolving) the
    /// remaining `JoinHandle`s are dropped, which per tokio's documentation
    /// detaches those tasks rather than aborting them.
    ///
    /// The handles are polled directly by the returned future; no helper tasks
    /// are spawned.
    pub fn wait_for_any_exit(self) -> impl std::future::Future<Output = anyhow::Error> + Send {
        // Take the tasks synchronously; the lock guard is a temporary released at
        // the end of this statement, so it is never held inside the future.
        let mut tasks: Vec<TrackedTask> = std::mem::take(&mut *self.inner.lock());
        async move {
            if tasks.is_empty() {
                // Nothing to watch: never resolve rather than report a bogus exit.
                std::future::pending::<()>().await;
                unreachable!();
            }
            let (name, result) =
                std::future::poll_fn(|cx| Self::poll_first_exit(&mut tasks, cx)).await;
            match result {
                Err(e) if e.is_panic() => {
                    tracing::error!(task = name, "Background task panicked: {e}");
                    anyhow::anyhow!("Background task '{name}' panicked: {e}")
                }
                Err(e) => {
                    tracing::error!(task = name, "Background task failed: {e}");
                    anyhow::anyhow!("Background task '{name}' failed: {e}")
                }
                Ok(()) => {
                    tracing::error!(
                        task = name,
                        "Background task exited unexpectedly (clean return)"
                    );
                    anyhow::anyhow!("Background task '{name}' exited unexpectedly")
                }
            }
        }
    }

    /// Polls the tracked handles in registration order and returns the name and
    /// join result of the first one that has completed, or `Pending` if none has.
    ///
    /// Each handle polled while pending is given `cx`, so the caller is woken when
    /// any of them completes. If several are ready at once, the earliest-registered
    /// one is reported.
    fn poll_first_exit(
        tasks: &mut [TrackedTask],
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<(&'static str, Result<(), tokio::task::JoinError>)> {
        tasks
            .iter_mut()
            .find_map(|task| {
                // `JoinHandle` is `Unpin`, so it can be pinned in place via `Pin::new`.
                match std::future::Future::poll(std::pin::Pin::new(&mut task.handle), cx) {
                    std::task::Poll::Ready(result) => Some((task.name, result)),
                    std::task::Poll::Pending => None,
                }
            })
            .map_or(std::task::Poll::Pending, std::task::Poll::Ready)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Spawns a task that sleeps for `dur` and then returns cleanly.
    fn spawn_sleeper(dur: Duration) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            tokio::time::sleep(dur).await;
        })
    }

    #[tokio::test]
    async fn panicking_task_is_detected() {
        let monitor = BackgroundTaskMonitor::new();
        let h1 = spawn_sleeper(Duration::from_secs(60));
        let h2 = tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            panic!("background task panic");
        });
        monitor.register("sleeper", h1);
        monitor.register("panicker", h2);
        let msg = monitor.wait_for_any_exit().await.to_string();
        assert!(
            msg.contains("panicker") && msg.contains("panicked"),
            "Expected panic error for 'panicker', got: {msg}"
        );
    }

    #[tokio::test]
    async fn clean_exit_is_detected() {
        let monitor = BackgroundTaskMonitor::new();
        let h1 = spawn_sleeper(Duration::from_secs(60));
        let h2 = tokio::spawn(async {});
        monitor.register("sleeper", h1);
        monitor.register("quick_exit", h2);
        let msg = monitor.wait_for_any_exit().await.to_string();
        assert!(
            msg.contains("quick_exit") && msg.contains("exited unexpectedly"),
            "Expected 'exited unexpectedly' error for 'quick_exit', got: {msg}"
        );
    }

    /// Covers the non-panic `JoinError` branch: an aborted task is reported as failed.
    #[tokio::test]
    async fn aborted_task_is_reported_as_failure() {
        let monitor = BackgroundTaskMonitor::new();
        let sleeper = spawn_sleeper(Duration::from_secs(60));
        let doomed = spawn_sleeper(Duration::from_secs(60));
        doomed.abort();
        monitor.register("sleeper", sleeper);
        monitor.register("doomed", doomed);
        let msg = monitor.wait_for_any_exit().await.to_string();
        assert!(
            msg.contains("doomed") && msg.contains("failed"),
            "Expected 'failed' error for 'doomed', got: {msg}"
        );
    }

    #[tokio::test]
    async fn empty_monitor_never_resolves() {
        let monitor = BackgroundTaskMonitor::new();
        let result =
            tokio::time::timeout(Duration::from_millis(50), monitor.wait_for_any_exit()).await;
        assert!(result.is_err(), "Empty monitor should not resolve");
    }

    /// Registrations made through a clone land in the shared task list.
    #[tokio::test]
    async fn clone_shares_registrations() {
        let monitor = BackgroundTaskMonitor::new();
        let registrar = monitor.clone();
        registrar.register("via_clone", tokio::spawn(async {}));
        let msg = monitor.wait_for_any_exit().await.to_string();
        assert!(
            msg.contains("via_clone"),
            "Expected task registered via clone to be observed, got: {msg}"
        );
    }

    #[tokio::test]
    async fn first_exit_wins() {
        let monitor = BackgroundTaskMonitor::new();
        // A wide gap between the two exits keeps this robust on a loaded machine.
        // The test still completes after the fast exit (~10ms).
        let h1 = spawn_sleeper(Duration::from_secs(1));
        let h2 = spawn_sleeper(Duration::from_millis(10));
        let h3 = spawn_sleeper(Duration::from_secs(60));
        monitor.register("slow_exit", h1);
        monitor.register("fast_exit", h2);
        monitor.register("sleeper", h3);
        let msg = monitor.wait_for_any_exit().await.to_string();
        assert!(
            msg.contains("fast_exit"),
            "Expected 'fast_exit' to be detected first, got: {msg}"
        );
    }
}