use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use tokio::task::{AbortHandle, JoinHandle};
use crate::error::{Result, ServiceError};
use crate::shutdown::ShutdownToken;
/// Category label attached to every task tracked by the registry.
///
/// Within this file the kind is purely descriptive: it is stored alongside the
/// task when it is spawned and copied into each `TaskSummary`; no code path
/// here branches on the variant.
///
/// Marked `#[non_exhaustive]` so additional kinds can be introduced without
/// breaking downstream `match` expressions.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskKind {
    BackgroundLoop,
    PeerConnection,
    RpcHandler,
    Maintenance,
}
/// Cheaply cloneable handle to a shared set of tracked tasks.
///
/// All state lives behind an `Arc`, so every clone refers to the same
/// underlying registry.
#[derive(Clone)]
pub struct TaskRegistry {
    /// Shared state; cloning the registry only bumps this reference count.
    inner: Arc<Inner>,
}
/// State shared by every clone of a `TaskRegistry`.
struct Inner {
    /// Token supplied at construction and exposed through `TaskRegistry::shutdown`.
    shutdown: ShutdownToken,
    /// Tracked tasks in spawn order, guarded by a synchronous mutex.
    entries: Mutex<Vec<Entry>>,
}
/// Bookkeeping record for a single spawned task.
struct Entry {
    /// Static label used when logging about the task.
    name: &'static str,
    /// Category the caller assigned when spawning.
    kind: TaskKind,
    /// Moment the entry was recorded, taken right after the task was spawned.
    spawned_at: Instant,
    /// Join handle used to await (or abort) the task.
    handle: JoinHandle<()>,
}
/// Read-only description of a tracked task, as returned by `TaskRegistry::snapshot`.
///
/// Carries the same metadata as the internal entry, minus the join handle.
#[derive(Debug, Clone)]
pub struct TaskSummary {
    /// Static label the task was spawned with.
    pub name: &'static str,
    /// Moment the task was registered.
    pub spawned_at: Instant,
    /// Category the task was spawned with.
    pub kind: TaskKind,
}
impl TaskRegistry {
    /// Creates an empty registry that stores `shutdown`.
    ///
    /// The token is only kept and handed back by [`TaskRegistry::shutdown`];
    /// nothing in this impl triggers or awaits it.
    pub fn new(shutdown: ShutdownToken) -> Self {
        Self {
            inner: Arc::new(Inner {
                shutdown,
                entries: Mutex::new(Vec::new()),
            }),
        }
    }

    /// Spawns `fut` onto the Tokio runtime and tracks it under `name` / `kind`.
    ///
    /// An `Err` returned by `fut` is logged at error level and then discarded,
    /// so the tracked task itself always resolves to `()`.
    ///
    /// Returns an [`AbortHandle`] the caller can use to cancel the task early.
    ///
    /// The entry stays in the registry (and keeps being counted by
    /// [`TaskRegistry::len`] and listed by [`TaskRegistry::snapshot`]) until
    /// [`TaskRegistry::join_all`] drains it, even after the task has finished.
    ///
    /// # Panics
    ///
    /// Panics when called outside of a Tokio runtime, because `tokio::spawn` does.
    pub fn spawn<F>(&self, name: &'static str, kind: TaskKind, fut: F) -> AbortHandle
    where
        F: Future<Output = anyhow::Result<()>> + Send + 'static,
    {
        let handle: JoinHandle<()> = tokio::spawn(async move {
            if let Err(e) = fut.await {
                tracing::error!(task = name, error = %e, "task exited with error");
            }
        });
        // Grab the abort handle before the join handle moves into the entry.
        let abort = handle.abort_handle();
        self.inner.entries.lock().push(Entry {
            name,
            kind,
            spawned_at: Instant::now(),
            handle,
        });
        abort
    }

    /// Returns the shutdown token this registry was constructed with.
    pub fn shutdown(&self) -> &ShutdownToken {
        &self.inner.shutdown
    }

    /// Drains the registry and waits for every tracked task, sharing one
    /// overall `deadline` across all of them.
    ///
    /// Tasks are awaited one after another in spawn order. Each wait is bounded
    /// by whatever is left of `deadline`, measured from the start of this call.
    /// A task that is still running when the budget runs out is aborted and
    /// counted as pending. A task that has already finished is always reaped,
    /// even after the deadline, so it is never reported as pending and a panic
    /// or cancellation in it is still logged.
    ///
    /// Tasks spawned after the drain are not covered by this call.
    ///
    /// # Errors
    ///
    /// Returns [`ServiceError::ShutdownDeadlineExceeded`] when at least one
    /// task had to be aborted; `pending` is the number of such tasks.
    pub async fn join_all(&self, deadline: Duration) -> Result<()> {
        // Keep the guard in its own scope: the synchronous mutex must be
        // released before the first `.await` below.
        let entries = {
            let mut guard = self.inner.entries.lock();
            std::mem::take(&mut *guard)
        };
        let start = Instant::now();
        let mut pending = 0usize;
        for mut entry in entries {
            let remaining = deadline.saturating_sub(start.elapsed());
            // Await by `&mut` so the handle stays ours. Dropping a `JoinHandle`
            // only detaches the task, so we need it afterwards to abort.
            let outcome = if entry.handle.is_finished() {
                // Already done: collecting the result does not wait, so it is
                // safe to do regardless of how much budget is left.
                Ok((&mut entry.handle).await)
            } else if remaining.is_zero() {
                entry.handle.abort();
                pending += 1;
                continue;
            } else {
                tokio::time::timeout(remaining, &mut entry.handle).await
            };
            match outcome {
                Ok(Ok(())) => {}
                Ok(Err(join_err)) => {
                    tracing::warn!(
                        task = entry.name,
                        elapsed = ?entry.spawned_at.elapsed(),
                        panic = join_err.is_panic(),
                        "tracked task did not exit cleanly",
                    );
                }
                Err(_elapsed) => {
                    // Actually cancel the task; without this it would keep
                    // running detached after shutdown.
                    entry.handle.abort();
                    pending += 1;
                    tracing::warn!(
                        task = entry.name,
                        "task exceeded shutdown deadline; aborting",
                    );
                }
            }
        }
        if pending == 0 {
            Ok(())
        } else {
            Err(ServiceError::ShutdownDeadlineExceeded { deadline, pending })
        }
    }

    /// Returns a point-in-time copy of the metadata of every tracked entry,
    /// in spawn order.
    pub fn snapshot(&self) -> Vec<TaskSummary> {
        self.inner
            .entries
            .lock()
            .iter()
            .map(|e| TaskSummary {
                name: e.name,
                spawned_at: e.spawned_at,
                kind: e.kind,
            })
            .collect()
    }

    /// Number of entries currently tracked, including tasks that have already
    /// finished but have not yet been drained by [`TaskRegistry::join_all`].
    pub fn len(&self) -> usize {
        self.inner.entries.lock().len()
    }

    /// Returns `true` when no entries are tracked.
    pub fn is_empty(&self) -> bool {
        self.inner.entries.lock().is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built registry tracks nothing.
    #[test]
    fn empty_on_construction() {
        let registry = TaskRegistry::new(ShutdownToken::new());
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    /// Spawning records exactly one entry carrying the given name and kind.
    #[tokio::test]
    async fn spawn_registers() {
        let registry = TaskRegistry::new(ShutdownToken::new());
        let _abort = registry.spawn("t1", TaskKind::BackgroundLoop, async { Ok(()) });
        assert_eq!(registry.len(), 1);
        let first_name = registry.snapshot()[0].name;
        assert_eq!(first_name, "t1");
        let first_kind = registry.snapshot()[0].kind;
        assert_eq!(first_kind, TaskKind::BackgroundLoop);
    }

    /// A task that finishes well inside the deadline yields `Ok`.
    #[tokio::test]
    async fn join_all_awaits_fast_tasks() {
        let registry = TaskRegistry::new(ShutdownToken::new());
        registry.spawn("fast", TaskKind::BackgroundLoop, async {
            tokio::time::sleep(Duration::from_millis(5)).await;
            Ok(())
        });
        let outcome = registry.join_all(Duration::from_secs(5)).await;
        assert!(outcome.is_ok());
    }

    /// A task that outlives the deadline is reported as one pending task.
    #[tokio::test]
    async fn join_all_aborts_slow_tasks() {
        let registry = TaskRegistry::new(ShutdownToken::new());
        registry.spawn("slow", TaskKind::BackgroundLoop, async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Ok(())
        });
        let outcome = registry.join_all(Duration::from_millis(20)).await;
        let exceeded_with_one_pending = matches!(
            outcome,
            Err(ServiceError::ShutdownDeadlineExceeded { pending: 1, .. })
        );
        assert!(exceeded_with_one_pending);
    }
}