use std::time::Duration;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
/// A spawned service task tracked by a [`Supervisor`].
#[derive(Debug)]
struct Service {
    /// Name used as the `service` field in this module's log events.
    name: String,
    /// Join handle of the tokio task running the service body.
    handle: JoinHandle<()>,
}

/// Owns a group of tokio tasks that share one cancellation token and are shut
/// down together via `shutdown`.
#[derive(Debug)]
pub struct Supervisor {
    /// Token handed (by clone) to every spawned service; cancelled on shutdown.
    cancel: CancellationToken,
    /// Tracked services, in the order they were spawned.
    services: Vec<Service>,
}
impl Supervisor {
    /// Creates an empty supervisor whose services will all observe `cancel`.
    ///
    /// The token is stored as-is (not as a child token), so cancelling it via
    /// [`Supervisor::shutdown`] or [`Supervisor::cancel_token`] is visible to every
    /// other holder of a clone of `cancel`.
    pub fn new(cancel: CancellationToken) -> Self {
        Self {
            cancel,
            services: Vec::new(),
        }
    }

    /// Returns the cancellation token shared with every spawned service.
    pub fn cancel_token(&self) -> &CancellationToken {
        &self.cancel
    }

    /// Spawns `body` as a tokio task and tracks it under `name`.
    ///
    /// `body` receives a clone of the supervisor's cancellation token and is
    /// expected to return promptly once that token is cancelled; otherwise it
    /// will hit the timeout in [`Supervisor::shutdown`].
    ///
    /// # Panics
    ///
    /// Panics if called outside a tokio runtime (propagated from [`tokio::spawn`]).
    pub fn spawn<F, Fut>(&mut self, name: impl Into<String>, body: F)
    where
        F: FnOnce(CancellationToken) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        let name = name.into();
        let cancel = self.cancel.clone();
        let handle = tokio::spawn(async move {
            body(cancel).await;
        });
        tracing::debug!(service = %name, "service spawned");
        self.services.push(Service { name, handle });
    }

    /// Number of tracked services (test-only).
    #[cfg(test)]
    pub(crate) fn len(&self) -> usize {
        self.services.len()
    }

    /// Whether no services are tracked (test-only).
    #[cfg(test)]
    pub(crate) fn is_empty(&self) -> bool {
        self.services.is_empty()
    }

    /// Cancels the shared token, then waits for each service in spawn order.
    ///
    /// Each service gets up to `per_service_timeout`, measured from when the
    /// supervisor starts waiting on it. Services are awaited one after another,
    /// so the worst case is roughly `services × per_service_timeout`.
    ///
    /// A service that does not finish in time is aborted rather than left running
    /// detached. An abort only takes effect at the task's next `.await`, so a task
    /// stuck in synchronous code can still outlive this call.
    ///
    /// Panics and cancellations inside services are logged, never propagated.
    pub async fn shutdown(self, per_service_timeout: Duration) {
        self.cancel.cancel();
        for Service { name, mut handle } in self.services {
            // Await a borrow so the handle is still ours to abort on timeout;
            // `JoinHandle` is `Unpin`, so `&mut JoinHandle` is itself a future.
            match tokio::time::timeout(per_service_timeout, &mut handle).await {
                Ok(Ok(())) => tracing::debug!(service = %name, "service exited cleanly"),
                Ok(Err(e)) if e.is_cancelled() => {
                    tracing::warn!(service = %name, "service was cancelled before exiting")
                }
                Ok(Err(e)) => tracing::warn!(service = %name, error = %e, "service panicked"),
                Err(_) => {
                    // Dropping a JoinHandle merely detaches the task; abort it so it
                    // does not keep running unsupervised after shutdown returns.
                    handle.abort();
                    // `as_secs_f64` keeps sub-second timeouts from being logged as 0.
                    tracing::warn!(
                        service = %name,
                        timeout_secs = per_service_timeout.as_secs_f64(),
                        "service shutdown timed out; aborting"
                    );
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn spawn_increments_len_and_runs_body() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut sup = Supervisor::new(CancellationToken::new());
        let c = Arc::clone(&counter);
        sup.spawn("worker", move |_cancel| async move {
            c.fetch_add(1, Ordering::SeqCst);
        });
        assert_eq!(sup.len(), 1);
        sup.shutdown(Duration::from_secs(1)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // Paused clock: if cancellation were not delivered, the runtime would
    // auto-advance to the timeout and `elapsed` would equal it, failing the check.
    #[tokio::test(start_paused = true)]
    async fn shutdown_cancels_long_running_service() {
        let observed = Arc::new(AtomicBool::new(false));
        let mut sup = Supervisor::new(CancellationToken::new());
        let flag = Arc::clone(&observed);
        sup.spawn("idler", move |cancel| async move {
            cancel.cancelled().await;
            flag.store(true, Ordering::SeqCst);
        });
        let timeout = Duration::from_secs(1);
        let start = tokio::time::Instant::now();
        sup.shutdown(timeout).await;
        assert!(
            observed.load(Ordering::SeqCst),
            "service never observed cancellation"
        );
        assert!(
            start.elapsed() < timeout,
            "service should exit via cancellation, not via the shutdown timeout"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn shutdown_times_out_on_wedged_service() {
        let mut sup = Supervisor::new(CancellationToken::new());
        sup.spawn("wedged", move |_cancel| async move {
            tokio::time::sleep(Duration::from_secs(86_400)).await;
        });
        let start = tokio::time::Instant::now();
        sup.shutdown(Duration::from_millis(50)).await;
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_secs(1),
            "shutdown must not block past the timeout, took {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn empty_supervisor_shuts_down_immediately() {
        let sup = Supervisor::new(CancellationToken::new());
        assert!(sup.is_empty());
        sup.shutdown(Duration::from_secs(1)).await;
    }

    #[test]
    fn cancel_token_returns_underlying_token() {
        let outer = CancellationToken::new();
        let sup = Supervisor::new(outer.clone());
        sup.cancel_token().cancel();
        assert!(outer.is_cancelled());
    }

    // A panicking service must neither propagate out of `shutdown` nor prevent
    // its siblings from running to completion.
    #[tokio::test]
    async fn shutdown_logs_panicking_service_via_join_error() {
        let survivor_done = Arc::new(AtomicBool::new(false));
        let mut sup = Supervisor::new(CancellationToken::new());
        sup.spawn("panicker", move |_cancel| async move {
            panic!("intentional");
        });
        let flag = Arc::clone(&survivor_done);
        sup.spawn("survivor", move |cancel| async move {
            cancel.cancelled().await;
            flag.store(true, Ordering::SeqCst);
        });
        sup.shutdown(Duration::from_secs(1)).await;
        assert!(survivor_done.load(Ordering::SeqCst));
    }
}