use std::sync::Mutex as StdMutex;
use tokio::task::JoinHandle;
/// Tracks the [`JoinHandle`]s of tasks spawned through it so they can later be
/// joined or aborted as a group.
#[derive(Debug)]
pub struct TaskManager {
    // A std (not tokio) mutex: the lock is only taken for short, synchronous
    // updates of the handle list and must never be held across an `.await`.
    tasks: StdMutex<Vec<JoinHandle<()>>>,
}
impl Default for TaskManager {
fn default() -> Self {
Self::new()
}
}
impl TaskManager {
pub fn new() -> Self {
Self {
tasks: StdMutex::new(Vec::new()),
}
}
#[track_caller]
pub fn spawn<F>(&self, fut: F)
where
F: std::future::Future<Output = ()> + Send + 'static,
{
use tracing::Instrument;
let location = std::panic::Location::caller();
let span = tracing::trace_span!(
"task",
file = location.file(),
line = location.line(),
column = location.column(),
);
let handle = tokio::spawn(fut.instrument(span));
let mut tasks = self.tasks.lock().unwrap();
tasks.retain(|h| !h.is_finished());
tasks.push(handle);
}
pub async fn join_all(&self) {
let handles = {
let mut tasks = self.tasks.lock().unwrap();
std::mem::take(&mut *tasks)
};
for handle in handles {
let _ = handle.await;
}
}
pub async fn abort_all(&self) {
let mut tasks = self.tasks.lock().unwrap();
for handle in tasks.drain(..) {
handle.abort();
}
}
}