use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
/// Shared handle used to coordinate a graceful shutdown.
///
/// Bundles a cancellation token (the "shutdown requested" flag), a counter of
/// in-flight tasks, and the number of seconds `wait_for_tasks` is willing to
/// wait for those tasks. Cloning is cheap and every clone observes the same
/// token and the same counter.
#[derive(Clone, Debug)]
pub struct ShutdownCoordinator {
    /// Token that is cancelled when shutdown is triggered.
    token: CancellationToken,
    /// Number of tasks currently being tracked; shared between all clones.
    in_flight: Arc<AtomicUsize>,
    /// Upper bound, in seconds, on how long `wait_for_tasks` waits.
    timeout_seconds: u64,
}
impl ShutdownCoordinator {
    /// Creates a coordinator with no tasks in flight and shutdown not yet requested.
    ///
    /// `timeout_seconds` is the maximum time [`wait_for_tasks`](Self::wait_for_tasks)
    /// waits for tracked tasks to finish.
    pub fn new(timeout_seconds: u64) -> Self {
        Self {
            token: CancellationToken::new(),
            in_flight: Arc::new(AtomicUsize::new(0)),
            timeout_seconds,
        }
    }

    /// Returns a clone of the cancellation token that is cancelled by
    /// [`trigger`](Self::trigger).
    pub fn token(&self) -> CancellationToken {
        self.token.clone()
    }

    /// Returns `true` once the token has been cancelled.
    pub fn is_shutting_down(&self) -> bool {
        self.token.is_cancelled()
    }

    /// Registers one in-flight task and returns the guard that un-registers it.
    ///
    /// The counter is incremented immediately and decremented again when the
    /// returned [`TaskGuard`] is dropped, so the guard must be kept alive for
    /// as long as the task is running.
    #[must_use = "dropping the guard immediately marks the task as finished"]
    pub fn track_task(&self) -> TaskGuard {
        self.in_flight.fetch_add(1, Ordering::SeqCst);
        TaskGuard {
            in_flight: Arc::clone(&self.in_flight),
        }
    }

    /// Returns the number of tasks currently tracked.
    pub fn in_flight_count(&self) -> usize {
        self.in_flight.load(Ordering::SeqCst)
    }

    /// Requests shutdown by cancelling the token.
    pub fn trigger(&self) {
        self.token.cancel();
    }

    /// Waits until no tasks are in flight or the configured timeout elapses.
    ///
    /// The counter is polled every 250 ms (or sooner when less time than that
    /// is left before the deadline). Returns `true` if the counter reached
    /// zero and `false` if the timeout expired with tasks still pending.
    ///
    /// If `timeout_seconds` is so large that the deadline cannot be
    /// represented as an instant, the wait is treated as having no deadline
    /// instead of panicking.
    pub async fn wait_for_tasks(&self) -> bool {
        let timeout = std::time::Duration::from_secs(self.timeout_seconds);
        let poll_interval = std::time::Duration::from_millis(250);
        // `Instant + Duration` panics on overflow; `checked_add` yields `None`
        // instead, which is interpreted below as "no deadline".
        let deadline = tokio::time::Instant::now().checked_add(timeout);

        loop {
            let count = self.in_flight_count();
            if count == 0 {
                info!("All in-flight tasks completed");
                return true;
            }

            let nap = match deadline {
                Some(deadline) => {
                    // Zero exactly when `now >= deadline`.
                    let left = deadline.saturating_duration_since(tokio::time::Instant::now());
                    if left.is_zero() {
                        warn!(
                            remaining_tasks = count,
                            timeout_seconds = self.timeout_seconds,
                            "Shutdown timeout expired with in-flight tasks still pending"
                        );
                        return false;
                    }
                    // Never sleep past the deadline.
                    poll_interval.min(left)
                }
                None => poll_interval,
            };

            info!(
                remaining_tasks = count,
                "Waiting for in-flight tasks to complete"
            );
            tokio::time::sleep(nap).await;
        }
    }
}
/// Guard representing one in-flight task.
///
/// Holds a handle to a shared in-flight counter and decrements that counter
/// by one when dropped. Keep the guard alive for the whole duration of the
/// work it represents; letting it go out of scope early makes the task look
/// finished.
#[derive(Debug)]
#[must_use = "dropping a TaskGuard immediately marks the task as finished"]
pub struct TaskGuard {
    /// Shared counter that this guard decrements on drop.
    in_flight: Arc<AtomicUsize>,
}

impl Drop for TaskGuard {
    /// Decrements the shared in-flight counter by one.
    fn drop(&mut self) {
        self.in_flight.fetch_sub(1, Ordering::SeqCst);
    }
}
/// Completes once shutdown has been requested, then triggers `coordinator`.
///
/// On Unix the future waits for whichever happens first: SIGTERM, SIGINT, or
/// cancellation of the coordinator's token. On other platforms it waits for
/// Ctrl-C or cancellation of the token. In every case `coordinator.trigger()`
/// is called before returning.
///
/// # Panics
///
/// On Unix, panics if the SIGTERM or SIGINT handler cannot be installed.
pub async fn shutdown_signal(coordinator: ShutdownCoordinator) {
    let cancel = coordinator.token();

    #[cfg(unix)]
    {
        use tokio::signal::unix::{self as unix_signal, SignalKind};

        let mut terminate = unix_signal::signal(SignalKind::terminate())
            .expect("failed to install SIGTERM handler");
        let mut interrupt = unix_signal::signal(SignalKind::interrupt())
            .expect("failed to install SIGINT handler");

        tokio::select! {
            _ = terminate.recv() => info!("Shutdown signal received (SIGTERM)"),
            _ = interrupt.recv() => info!("Shutdown signal received (SIGINT)"),
            _ = cancel.cancelled() => info!("Shutdown signal received (token cancelled)"),
        }
    }

    #[cfg(not(unix))]
    {
        // NOTE(review): the result of `ctrl_c()` is discarded, so an error
        // from it also completes this wait — confirm that is intended.
        tokio::select! {
            _ = tokio::signal::ctrl_c() => info!("Shutdown signal received (Ctrl-C)"),
            _ = cancel.cancelled() => info!("Shutdown signal received (token cancelled)"),
        }
    }

    // Propagate the shutdown to every holder of the token.
    coordinator.trigger();
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A fresh coordinator is not shutting down and tracks no tasks.
    #[test]
    fn test_coordinator_initial_state() {
        let coord = ShutdownCoordinator::new(30);
        assert!(!coord.is_shutting_down());
        assert_eq!(coord.in_flight_count(), 0);
    }

    /// `trigger` flips the coordinator into the shutting-down state.
    #[test]
    fn test_coordinator_trigger() {
        let coord = ShutdownCoordinator::new(30);
        assert!(!coord.is_shutting_down());
        coord.trigger();
        assert!(coord.is_shutting_down());
    }

    /// Each guard adds one to the counter and removes it again on drop.
    #[test]
    fn test_task_guard_increments_and_decrements() {
        let coord = ShutdownCoordinator::new(30);
        assert_eq!(coord.in_flight_count(), 0);
        let guard1 = coord.track_task();
        assert_eq!(coord.in_flight_count(), 1);
        let guard2 = coord.track_task();
        assert_eq!(coord.in_flight_count(), 2);
        drop(guard1);
        assert_eq!(coord.in_flight_count(), 1);
        drop(guard2);
        assert_eq!(coord.in_flight_count(), 0);
    }

    /// Clones observe the same counter and the same cancellation state.
    #[test]
    fn test_coordinator_clone_shares_state() {
        let coord = ShutdownCoordinator::new(30);
        let coord2 = coord.clone();
        let _guard = coord.track_task();
        assert_eq!(coord2.in_flight_count(), 1);
        coord.trigger();
        assert!(coord2.is_shutting_down());
    }

    /// With nothing in flight the wait returns `true` right away.
    #[tokio::test]
    async fn test_wait_for_tasks_immediate_when_empty() {
        let coord = ShutdownCoordinator::new(1);
        assert!(coord.wait_for_tasks().await);
    }

    /// The wait succeeds once the only outstanding guard has been dropped.
    #[tokio::test]
    async fn test_wait_for_tasks_completes_when_guard_dropped() {
        let coord = ShutdownCoordinator::new(5);
        // Register the task before spawning so the in-flight count is
        // deterministic and does not depend on scheduler timing.
        let guard = coord.track_task();
        tokio::spawn(async move {
            let _guard = guard;
            tokio::time::sleep(Duration::from_millis(100)).await;
        });
        assert_eq!(coord.in_flight_count(), 1);
        assert!(coord.wait_for_tasks().await);
    }

    /// A guard that is never dropped makes the wait time out and return `false`.
    #[tokio::test]
    async fn test_wait_for_tasks_timeout() {
        let coord = ShutdownCoordinator::new(1);
        let _guard = coord.track_task();
        let start = tokio::time::Instant::now();
        assert!(!coord.wait_for_tasks().await);
        assert!(start.elapsed() >= Duration::from_secs(1));
    }

    /// Cancelling the token from another task unblocks `shutdown_signal`.
    #[tokio::test]
    async fn test_shutdown_signal_via_token_cancellation() {
        let coord = ShutdownCoordinator::new(30);
        let coord2 = coord.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            coord2.trigger();
        });
        let start = tokio::time::Instant::now();
        shutdown_signal(coord.clone()).await;
        assert!(start.elapsed().as_millis() < 1000);
        assert!(coord.is_shutting_down());
    }
}