use std::future::Future;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, watch};
use tracing::{debug, info, warn};
/// Tuning knobs for [`ShutdownCoordinator`] and [`run_with_graceful_shutdown`].
#[derive(Debug, Clone)]
pub struct ShutdownConfig {
    /// Maximum time `shutdown` waits for the active request count to reach zero.
    pub timeout: Duration,
    /// Whether `run_with_graceful_shutdown` installs OS signal handlers that
    /// trigger the shutdown.
    pub handle_signals: bool,
    /// Extra time slept after `timeout` expires while requests are still active.
    pub force_shutdown_delay: Duration,
}

impl Default for ShutdownConfig {
    /// 30 s drain timeout, signal handling enabled, 5 s force-shutdown delay.
    fn default() -> Self {
        let timeout = Duration::from_secs(30);
        let force_shutdown_delay = Duration::from_secs(5);
        ShutdownConfig {
            timeout,
            handle_signals: true,
            force_shutdown_delay,
        }
    }
}
/// Coordinates a graceful shutdown: broadcasts the shutdown trigger, counts
/// in-flight requests, and publishes `ShutdownState` transitions.
///
/// Clones share state: the counter, flag and state sender sit behind `Arc`s,
/// and the broadcast sender / watch receiver are handles to shared channels.
#[derive(Clone)]
pub struct ShutdownCoordinator {
// Broadcast fired by `shutdown`; listeners are obtained via `subscribe`.
shutdown_tx: broadcast::Sender<()>,
// Publishes Running -> Draining -> Shutdown transitions.
state_tx: Arc<watch::Sender<ShutdownState>>,
// Source for `watch_state` clones; also keeps at least one receiver alive
// for as long as any clone of the coordinator exists.
state_rx: watch::Receiver<ShutdownState>,
// Number of requests between `request_started` and `request_completed`.
active_count: Arc<AtomicUsize>,
// Set by the first `shutdown` call; used to make shutdown run only once.
is_shutting_down: Arc<AtomicBool>,
// Timeouts and signal-handling settings.
config: ShutdownConfig,
}
/// Lifecycle phase published by a `ShutdownCoordinator` through its watch channel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownState {
    /// Normal operation; no shutdown has been requested yet.
    Running,
    /// Shutdown has started and in-flight requests are being drained.
    Draining,
    /// Shutdown finished: requests drained, or the timeout and force delay elapsed.
    Shutdown,
}
impl ShutdownCoordinator {
pub fn new(config: ShutdownConfig) -> Self {
let (shutdown_tx, _) = broadcast::channel(1);
let (state_tx, state_rx) = watch::channel(ShutdownState::Running);
Self {
shutdown_tx,
state_tx: Arc::new(state_tx),
state_rx,
active_count: Arc::new(AtomicUsize::new(0)),
is_shutting_down: Arc::new(AtomicBool::new(false)),
config,
}
}
pub fn with_defaults() -> Self {
Self::new(ShutdownConfig::default())
}
pub fn is_shutting_down(&self) -> bool {
self.is_shutting_down.load(Ordering::SeqCst)
}
pub fn active_count(&self) -> usize {
self.active_count.load(Ordering::SeqCst)
}
pub fn subscribe(&self) -> broadcast::Receiver<()> {
self.shutdown_tx.subscribe()
}
pub fn watch_state(&self) -> watch::Receiver<ShutdownState> {
self.state_rx.clone()
}
pub fn request_started(&self) {
self.active_count.fetch_add(1, Ordering::SeqCst);
}
pub fn request_completed(&self) {
let prev = self.active_count.fetch_sub(1, Ordering::SeqCst);
debug!(active_count = prev - 1, "Request completed");
}
pub fn request_guard(&self) -> RequestGuard {
self.request_started();
RequestGuard {
coordinator: self.clone(),
}
}
pub async fn shutdown(&self) {
if self.is_shutting_down.swap(true, Ordering::SeqCst) {
debug!("Shutdown already in progress");
return;
}
info!("Initiating graceful shutdown...");
let _ = self.state_tx.send(ShutdownState::Draining);
let _ = self.shutdown_tx.send(());
let start = std::time::Instant::now();
let timeout = self.config.timeout;
loop {
let active = self.active_count();
if active == 0 {
info!("All requests completed, shutting down");
break;
}
if start.elapsed() >= timeout {
warn!(
active_count = active,
"Shutdown timeout reached, {} requests still active", active
);
break;
}
debug!(
active_count = active,
elapsed_secs = start.elapsed().as_secs(),
"Waiting for {} active request(s) to complete",
active
);
tokio::time::sleep(Duration::from_millis(100)).await;
}
if self.active_count() > 0 {
warn!(
"Waiting {} seconds before forcing shutdown...",
self.config.force_shutdown_delay.as_secs()
);
tokio::time::sleep(self.config.force_shutdown_delay).await;
}
let _ = self.state_tx.send(ShutdownState::Shutdown);
info!("Graceful shutdown complete");
}
pub fn shutdown_signal(&self) -> impl Future<Output = ()> + Send + 'static {
let mut rx = self.subscribe();
async move {
let _ = rx.recv().await;
}
}
}
/// Guard returned by `ShutdownCoordinator::request_guard`.
///
/// The request counts as active until this guard is dropped; its `Drop`
/// impl calls `request_completed` on the coordinator.
#[must_use = "the request is only tracked while the guard is alive; dropping it immediately marks the request completed"]
pub struct RequestGuard {
    coordinator: ShutdownCoordinator,
}
impl Drop for RequestGuard {
fn drop(&mut self) {
self.coordinator.request_completed();
}
}
/// Completes when the process receives the first of SIGTERM or SIGINT,
/// logging which one arrived.
///
/// # Panics
///
/// Panics if either signal handler cannot be installed.
#[cfg(unix)]
pub async fn signal_shutdown() {
    use tokio::signal::unix::{signal, SignalKind};
    let mut terminate =
        signal(SignalKind::terminate()).expect("failed to install SIGTERM handler");
    let mut interrupt =
        signal(SignalKind::interrupt()).expect("failed to install SIGINT handler");
    tokio::select! {
        _ = terminate.recv() => info!("Received SIGTERM"),
        _ = interrupt.recv() => info!("Received SIGINT (Ctrl+C)"),
    }
}
/// Completes when Ctrl+C is received (non-Unix platforms).
///
/// # Panics
///
/// Panics if the Ctrl+C handler cannot be installed.
#[cfg(not(unix))]
pub async fn signal_shutdown() {
    let outcome = tokio::signal::ctrl_c().await;
    outcome.expect("failed to install Ctrl+C handler");
    info!("Received Ctrl+C");
}
/// Alias for [`signal_shutdown`]: completes when the platform's shutdown signal arrives.
pub async fn os_signal_shutdown() {
    let wait_for_signal = signal_shutdown();
    wait_for_signal.await
}
/// Serves `app` on `listener` and, once a shutdown is triggered, drains
/// in-flight requests through a [`ShutdownCoordinator`].
///
/// When `config.handle_signals` is set, the OS shutdown signal
/// (`signal_shutdown`) triggers the coordinator's shutdown. When it is not
/// set, nothing in this function triggers it, and the coordinator is not
/// returned to the caller.
///
/// # Errors
///
/// Returns any I/O error reported by `axum::serve`. In that case the
/// coordinator's `shutdown` is not run.
pub async fn run_with_graceful_shutdown(
    listener: tokio::net::TcpListener,
    app: axum::Router,
    config: ShutdownConfig,
) -> std::io::Result<()> {
    let handle_signals = config.handle_signals;
    let coordinator = ShutdownCoordinator::new(config);
    // Subscribe BEFORE spawning the signal task. A broadcast receiver only
    // sees messages sent after it was created, so subscribing second would
    // let an early signal go unnoticed and the server would never stop.
    let server_shutdown = coordinator.shutdown_signal();
    if handle_signals {
        let signal_coordinator = coordinator.clone();
        tokio::spawn(async move {
            signal_shutdown().await;
            signal_coordinator.shutdown().await;
        });
    }
    axum::serve(listener, app)
        .with_graceful_shutdown(server_shutdown)
        .await?;
    // Runs the drain if nothing else started it; if the signal task already
    // did, `shutdown` detects that through its in-progress flag.
    coordinator.shutdown().await;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Config with short durations and no signal handling so tests stay fast.
    fn fast_config(timeout: Duration) -> ShutdownConfig {
        ShutdownConfig {
            timeout,
            handle_signals: false,
            force_shutdown_delay: Duration::from_millis(10),
        }
    }

    /// A request that never completes must not block shutdown beyond
    /// roughly `timeout + force_shutdown_delay`.
    #[tokio::test]
    async fn test_shutdown_resilience() {
        let coordinator = ShutdownCoordinator::new(fast_config(Duration::from_millis(50)));
        coordinator.request_started();
        let start = Instant::now();
        coordinator.shutdown().await;
        let duration = start.elapsed();
        assert!(duration >= Duration::from_millis(50));
        assert!(duration < Duration::from_secs(5));
        assert_eq!(coordinator.active_count(), 1);
        assert_eq!(*coordinator.watch_state().borrow(), ShutdownState::Shutdown);
    }

    #[tokio::test]
    async fn test_shutdown_coordinator_creation() {
        let coordinator = ShutdownCoordinator::with_defaults();
        assert!(!coordinator.is_shutting_down());
        assert_eq!(coordinator.active_count(), 0);
        assert_eq!(*coordinator.watch_state().borrow(), ShutdownState::Running);
    }

    #[tokio::test]
    async fn test_request_tracking() {
        let coordinator = ShutdownCoordinator::with_defaults();
        coordinator.request_started();
        assert_eq!(coordinator.active_count(), 1);
        coordinator.request_started();
        assert_eq!(coordinator.active_count(), 2);
        coordinator.request_completed();
        assert_eq!(coordinator.active_count(), 1);
        coordinator.request_completed();
        assert_eq!(coordinator.active_count(), 0);
    }

    #[tokio::test]
    async fn test_request_guard() {
        let coordinator = ShutdownCoordinator::with_defaults();
        assert_eq!(coordinator.active_count(), 0);
        {
            let _guard = coordinator.request_guard();
            assert_eq!(coordinator.active_count(), 1);
        }
        assert_eq!(coordinator.active_count(), 0);
    }

    /// Concurrent `shutdown` calls must both finish without panicking and
    /// leave the coordinator in the final state.
    #[tokio::test]
    async fn test_shutdown_idempotency() {
        let coordinator = ShutdownCoordinator::with_defaults();
        let c1 = coordinator.clone();
        let task1 = tokio::spawn(async move { c1.shutdown().await });
        let c2 = coordinator.clone();
        let task2 = tokio::spawn(async move { c2.shutdown().await });
        let (first, second) = tokio::join!(task1, task2);
        first.expect("first shutdown task panicked");
        second.expect("second shutdown task panicked");
        assert!(coordinator.is_shutting_down());
        assert_eq!(*coordinator.watch_state().borrow(), ShutdownState::Shutdown);
    }

    #[tokio::test]
    async fn test_shutdown_with_no_active_requests() {
        let coordinator = ShutdownCoordinator::new(fast_config(Duration::from_millis(100)));
        let start = Instant::now();
        coordinator.shutdown().await;
        assert!(start.elapsed() < Duration::from_millis(50));
        assert!(coordinator.is_shutting_down());
        assert_eq!(*coordinator.watch_state().borrow(), ShutdownState::Shutdown);
    }

    #[tokio::test]
    async fn test_shutdown_waits_for_active_requests() {
        let coordinator = ShutdownCoordinator::new(fast_config(Duration::from_secs(5)));
        let coordinator_clone = coordinator.clone();
        let worker = tokio::spawn(async move {
            let _guard = coordinator_clone.request_guard();
            tokio::time::sleep(Duration::from_millis(200)).await;
        });
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(coordinator.active_count(), 1);
        coordinator.shutdown().await;
        assert_eq!(coordinator.active_count(), 0);
        worker.await.expect("worker task panicked");
    }

    /// The subscriber must actually receive the broadcast, not merely have
    /// its task join successfully.
    #[tokio::test]
    async fn test_shutdown_signal_propagation() {
        let coordinator = ShutdownCoordinator::with_defaults();
        let mut rx = coordinator.subscribe();
        let handle = tokio::spawn(async move { rx.recv().await });
        coordinator.shutdown().await;
        let received = handle.await.expect("receiver task panicked");
        assert!(received.is_ok(), "subscriber should receive the shutdown broadcast");
    }

    /// The future from `shutdown_signal` resolves once shutdown is broadcast.
    #[tokio::test]
    async fn test_shutdown_signal_future_resolves() {
        let coordinator = ShutdownCoordinator::with_defaults();
        let signal = coordinator.shutdown_signal();
        coordinator.shutdown().await;
        tokio::time::timeout(Duration::from_secs(1), signal)
            .await
            .expect("shutdown_signal should resolve after shutdown");
    }

    #[tokio::test]
    async fn test_shutdown_state_transitions() {
        let coordinator = ShutdownCoordinator::new(fast_config(Duration::from_millis(100)));
        let mut state_rx = coordinator.watch_state();
        assert_eq!(*state_rx.borrow(), ShutdownState::Running);
        let c_clone = coordinator.clone();
        tokio::spawn(async move {
            c_clone.shutdown().await;
        });
        while state_rx.changed().await.is_ok() {
            let state = *state_rx.borrow();
            if state == ShutdownState::Shutdown {
                break;
            }
        }
        assert_eq!(*state_rx.borrow(), ShutdownState::Shutdown);
    }
}