use futures::FutureExt;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::sync::Mutex;
use tracing::{debug, error};
/// Keeps track of the join handles of background tasks spawned on the Tokio runtime.
///
/// Handles are registered by `spawn`, drained by `shutdown`, and aborted when
/// the manager is dropped.
// NOTE(review): `dead_code` is allowed here, presumably because not every build
// configuration uses this type — confirm and document the reason.
#[allow(dead_code)]
pub(crate) struct BackgroundTaskManager {
/// Join handles of the tracked tasks. The mutex allows the list to be
/// modified through `&self`.
tasks: Mutex<Vec<tokio::task::JoinHandle<()>>>,
}
/// Debug output shows only how many handles are tracked, not the handles themselves.
impl std::fmt::Debug for BackgroundTaskManager {
    /// Writes `BackgroundTaskManager { tasks_count: .. }`.
    ///
    /// When the mutex is poisoned the count is replaced by the string
    /// `"<poisoned>"` rather than panicking inside a formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("BackgroundTaskManager");
        if let Ok(guard) = self.tasks.lock() {
            builder.field("tasks_count", &guard.len());
        } else {
            builder.field("tasks_count", &"<poisoned>");
        }
        builder.finish()
    }
}
// NOTE(review): `dead_code` is allowed here, presumably because some of these
// methods are unused in certain build configurations — confirm.
#[allow(dead_code)]
impl BackgroundTaskManager {
    /// Creates a manager that is not tracking any tasks.
    pub fn new() -> Self {
        Self {
            tasks: Mutex::new(Vec::new()),
        }
    }

    /// Spawns `future` on the Tokio runtime and records its join handle.
    ///
    /// A panic raised while polling `future` is caught and reported through
    /// `tracing::error!` instead of being left in the join handle. Handles of
    /// tasks that have already finished are pruned on every call, so the
    /// internal list does not grow without bound.
    ///
    /// # Panics
    ///
    /// Panics if called outside of a Tokio runtime context, because
    /// `tokio::spawn` does.
    pub fn spawn<F>(&self, future: F)
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let handle = tokio::spawn(async move {
            if let Err(panic_payload) = AssertUnwindSafe(future).catch_unwind().await {
                // Panic payloads are `&'static str` for literal messages and
                // `String` for formatted ones; anything else cannot be printed.
                let msg = panic_payload
                    .downcast_ref::<&str>()
                    .copied()
                    .or_else(|| panic_payload.downcast_ref::<String>().map(String::as_str))
                    .unwrap_or("<non-string panic>");
                error!("Background task panicked: {msg}");
            }
        });

        // A poisoned lock only means another thread panicked while holding it;
        // the handle list itself is still usable, so recover the guard.
        let mut tasks = self
            .tasks
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        tasks.retain(|h| !h.is_finished());
        tasks.push(handle);
    }

    /// Aborts every tracked task and waits until each one has terminated.
    ///
    /// All tasks are signalled to abort *before* any of them is awaited.
    /// This lets them wind down concurrently, and it guarantees that every
    /// task has received the abort request even if the future returned by
    /// this method is dropped before it completes (dropping a `JoinHandle`
    /// merely detaches the task, it does not cancel it).
    ///
    /// After this returns the manager tracks no tasks; it can still be used
    /// to spawn new ones.
    pub async fn shutdown(&self) {
        // Take the handles inside a dedicated scope so the mutex guard is
        // released before the first `.await`.
        let handles = {
            let mut guard = self
                .tasks
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            std::mem::take(&mut *guard)
        };

        let count = handles.len();
        debug!("BackgroundTaskManager: shutting down {count} background task(s).");

        // Phase 1: request cancellation of everything up front.
        for handle in &handles {
            handle.abort();
        }

        // Phase 2: wait for each task to actually stop. The join result is
        // ignored on purpose: a cancellation error is the expected outcome.
        for handle in handles {
            let _ = handle.await;
        }
    }
}
/// Dropping the manager aborts every task it still tracks.
impl Drop for BackgroundTaskManager {
    /// Logs the number of tracked tasks and requests cancellation of each.
    ///
    /// The tasks are only signalled to abort; nothing is awaited here because
    /// `drop` is synchronous.
    fn drop(&mut self) {
        // `&mut self` gives exclusive access, so no locking is needed; a
        // poisoned mutex still yields its contents.
        let handles = match self.tasks.get_mut() {
            Ok(inner) => inner,
            Err(poisoned) => poisoned.into_inner(),
        };
        let pending = handles.len();
        debug!(
            "BackgroundTaskManager: aborting {} background task(s).",
            pending,
        );
        handles.drain(..).for_each(|task| task.abort());
    }
}
// Unit tests for `BackgroundTaskManager`: construction, Debug output,
// spawning, pruning of finished handles, shutdown, and abort-on-drop.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use tokio::time::Duration;
// A freshly constructed manager tracks no handles.
#[test]
fn new_manager_has_no_tasks() {
let manager = BackgroundTaskManager::new();
assert_eq!(manager.tasks.lock().unwrap().len(), 0);
}
// The Debug representation exposes a `tasks_count` field.
#[test]
fn debug_shows_task_count() {
let manager = BackgroundTaskManager::new();
let debug_str = format!("{:?}", manager);
assert!(debug_str.contains("tasks_count"));
}
// Spawning registers exactly one handle; dropping the manager afterwards
// must not panic (there is no assertion after the drop).
#[tokio::test]
async fn drop_cleans_up_tasks() {
let manager = BackgroundTaskManager::new();
manager.spawn(async {});
assert_eq!(manager.tasks.lock().unwrap().len(), 1);
drop(manager);
}
// A spawned task is actually driven by the runtime: it increments a shared
// counter five times, yielding between increments.
#[tokio::test]
async fn task_runs_to_completion() {
let counter = Arc::new(AtomicU32::new(0));
let manager = BackgroundTaskManager::new();
{
let counter = Arc::clone(&counter);
manager.spawn(async move {
for _ in 0..5 {
counter.fetch_add(1, Ordering::SeqCst);
tokio::task::yield_now().await;
}
});
}
// Poll cooperatively until the counter reaches 5, bounded by a timeout so a
// regression fails the test instead of hanging it.
tokio::time::timeout(Duration::from_secs(5), async {
while counter.load(Ordering::SeqCst) < 5 {
tokio::task::yield_now().await;
}
})
.await
.expect("task should complete within timeout");
assert_eq!(counter.load(Ordering::SeqCst), 5);
}
// Dropping the manager aborts a task that is still running.
#[tokio::test]
async fn drop_aborts_running_task() {
let started = Arc::new(AtomicBool::new(false));
let completed = Arc::new(AtomicBool::new(false));
let manager = BackgroundTaskManager::new();
{
let started = Arc::clone(&started);
let completed = Arc::clone(&completed);
manager.spawn(async move {
started.store(true, Ordering::SeqCst);
// Long cooperative loop: keeps the task alive while giving the runtime a
// cancellation point at every yield.
for _ in 0..1_000_000 {
tokio::task::yield_now().await;
}
completed.store(true, Ordering::SeqCst);
});
}
// Wait until the task has really begun before dropping the manager.
tokio::time::timeout(Duration::from_secs(5), async {
while !started.load(Ordering::SeqCst) {
tokio::task::yield_now().await;
}
})
.await
.expect("task should start within timeout");
drop(manager);
// Yield once so the runtime gets a chance to process the abort.
// NOTE(review): after a single yield the task would be unlikely to have
// finished its 1_000_000 iterations even without an abort, so this
// assertion may pass regardless — consider a stronger check.
tokio::task::yield_now().await;
assert!(
!completed.load(Ordering::SeqCst),
"task should have been aborted, not completed"
);
}
// `shutdown` returns while a long-running task is in flight and leaves the
// manager with an empty handle list.
#[tokio::test]
async fn shutdown_awaits_task_termination() {
let started = Arc::new(AtomicBool::new(false));
let manager = BackgroundTaskManager::new();
{
let started = Arc::clone(&started);
manager.spawn(async move {
started.store(true, Ordering::SeqCst);
for _ in 0..1_000_000 {
tokio::task::yield_now().await;
}
});
}
// Make sure the task is running before shutting down.
tokio::time::timeout(Duration::from_secs(5), async {
while !started.load(Ordering::SeqCst) {
tokio::task::yield_now().await;
}
})
.await
.expect("task should start within timeout");
manager.shutdown().await;
assert_eq!(manager.tasks.lock().unwrap().len(), 0);
}
// A second `spawn` removes handles of tasks that have already finished.
#[tokio::test]
async fn spawn_prunes_finished_handles() {
let manager = BackgroundTaskManager::new();
manager.spawn(async {});
// Wait until every tracked handle reports that its task has finished.
tokio::time::timeout(Duration::from_secs(5), async {
loop {
let all_done = manager
.tasks
.lock()
.unwrap()
.iter()
.all(|h| h.is_finished());
if all_done {
break;
}
tokio::task::yield_now().await;
}
})
.await
.expect("task should finish within timeout");
// The finished handle is pruned, so only the new one remains.
manager.spawn(async {});
assert_eq!(manager.tasks.lock().unwrap().len(), 1);
}
// Calls `spawn` from 20 tasks on a 4-worker multi-thread runtime and checks
// that every background task runs.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_spawn_is_safe() {
let manager = Arc::new(BackgroundTaskManager::new());
let done_count = Arc::new(AtomicU32::new(0));
let mut spawner_handles = Vec::new();
for _ in 0..20 {
let mgr = Arc::clone(&manager);
let done_count = Arc::clone(&done_count);
spawner_handles.push(tokio::spawn(async move {
mgr.spawn(async move {
done_count.fetch_add(1, Ordering::SeqCst);
});
}));
}
// First make sure every spawner task has finished calling `spawn`.
for jh in spawner_handles {
jh.await.unwrap();
}
// Then wait for all 20 background tasks to have incremented the counter.
tokio::time::timeout(Duration::from_secs(5), async {
while done_count.load(Ordering::SeqCst) < 20 {
tokio::task::yield_now().await;
}
})
.await
.expect("all background tasks should complete within timeout");
assert_eq!(done_count.load(Ordering::SeqCst), 20);
}
}