use std::{
sync::{
Arc,
atomic::{AtomicBool, AtomicU64, Ordering},
},
time::Duration,
};
use tokio::{
sync::{Notify, broadcast, watch},
time::timeout,
};
/// Coordinates a graceful shutdown: flips readiness to `false`, broadcasts a stop signal to
/// subscribed components, and waits (bounded by [`ShutdownConfig::timeout`]) for in-flight
/// requests to finish.
///
/// Construct with [`ShutdownCoordinator::new`], which returns it wrapped in an [`Arc`].
#[derive(Debug)]
pub struct ShutdownCoordinator {
    /// Set by the first `shutdown` call; `request_started` refuses new requests afterwards.
    shutdown_initiated: AtomicBool,
    /// Carries the single stop message to every receiver obtained through `subscribe`.
    shutdown_tx: broadcast::Sender<()>,
    /// Readiness flag: starts `true`, set to `false` when shutdown begins.
    ready_tx: watch::Sender<bool>,
    /// Receiver kept alongside the sender so readiness can be read and cloned for observers.
    ready_rx: watch::Receiver<bool>,
    /// Count of live request registrations (see `request_started` / `RequestGuard`).
    in_flight: AtomicU64,
    /// Notified when the last in-flight request completes while shutting down.
    drain_complete: Notify,
    /// Delay/timeout settings used by `shutdown`.
    config: ShutdownConfig,
}
/// Timing settings for the graceful-shutdown sequence run by `ShutdownCoordinator::shutdown`.
#[derive(Debug, Clone)]
pub struct ShutdownConfig {
    /// Maximum time to wait for in-flight requests to drain after the stop signal has been
    /// broadcast. Requests still running when it elapses are only logged, not awaited.
    pub timeout: Duration,
    /// Pause between marking the service not-ready and broadcasting the stop signal, leaving
    /// time for a load balancer to deregister this instance.
    pub delay: Duration,
}

impl Default for ShutdownConfig {
    /// A 30-second drain timeout and a 5-second deregistration delay.
    fn default() -> Self {
        let timeout = Duration::from_secs(30);
        let delay = Duration::from_secs(5);
        Self { timeout, delay }
    }
}
impl ShutdownCoordinator {
    /// Creates a coordinator that starts out ready and not shutting down.
    ///
    /// The coordinator is returned in an [`Arc`] so the signal-handling task and the request
    /// handlers can share it.
    pub fn new(config: ShutdownConfig) -> Arc<Self> {
        // The channel only ever carries the single shutdown message. The initial receiver is
        // discarded; components get their own through `subscribe`.
        let (shutdown_tx, _) = broadcast::channel(1);
        // A receiver is stored next to the sender so `is_ready` / `ready_watch` can read and
        // hand out the current readiness value.
        let (ready_tx, ready_rx) = watch::channel(true);
        Arc::new(Self {
            shutdown_initiated: AtomicBool::new(false),
            shutdown_tx,
            ready_tx,
            ready_rx,
            in_flight: AtomicU64::new(0),
            drain_complete: Notify::new(),
            config,
        })
    }

    /// Returns a receiver that yields `()` when [`shutdown`](Self::shutdown) broadcasts the
    /// stop signal.
    ///
    /// A broadcast receiver only sees messages sent after it was created. A component that
    /// subscribes after the signal went out will therefore never receive it, so such callers
    /// should also check [`is_shutting_down`](Self::is_shutting_down).
    pub fn subscribe(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.subscribe()
    }

    /// Returns a watch receiver for the readiness flag. The flag starts `true` and is set to
    /// `false` by [`shutdown`](Self::shutdown).
    pub fn ready_watch(&self) -> watch::Receiver<bool> {
        self.ready_rx.clone()
    }

    /// Current readiness; `false` once [`shutdown`](Self::shutdown) has marked the service as
    /// not ready.
    pub fn is_ready(&self) -> bool {
        *self.ready_rx.borrow()
    }

    /// Whether [`shutdown`](Self::shutdown) has been called.
    pub fn is_shutting_down(&self) -> bool {
        self.shutdown_initiated.load(Ordering::SeqCst)
    }

    /// Registers a new in-flight request.
    ///
    /// Returns `None` once shutdown has begun. Otherwise returns a guard that keeps the request
    /// counted until the guard is dropped.
    #[must_use = "the request is counted as finished as soon as the guard is dropped"]
    pub fn request_started(&self) -> Option<RequestGuard<'_>> {
        // Count the request *before* checking the flag. Checking first left a window where:
        //   1. a request passed the check;
        //   2. `shutdown` read `in_flight == 0` and finished;
        //   3. only then was the request counted.
        // With SeqCst on both sides, either this load observes the flag, or `shutdown`'s later
        // read of `in_flight` observes this increment.
        self.in_flight.fetch_add(1, Ordering::SeqCst);
        if self.is_shutting_down() {
            // Undo via the normal completion path so a drain waiter is woken if this
            // provisional registration happened to be the last one.
            self.request_completed();
            return None;
        }
        Some(RequestGuard { coordinator: self })
    }

    /// Number of currently registered requests. This is a snapshot and may be stale as soon
    /// as it is returned.
    pub fn in_flight_count(&self) -> u64 {
        self.in_flight.load(Ordering::SeqCst)
    }

    /// Unregisters one request. Called from `RequestGuard::drop` and from the back-out path in
    /// `request_started`.
    fn request_completed(&self) {
        let prev = self.in_flight.fetch_sub(1, Ordering::SeqCst);
        // Only the 1 -> 0 transition matters, and only while `shutdown` may be draining.
        if prev == 1 && self.is_shutting_down() {
            self.drain_complete.notify_waiters();
        }
    }

    /// Runs the graceful-shutdown sequence. Only the first call does any work:
    ///
    /// 1. Marks the service not ready.
    /// 2. Sleeps for [`ShutdownConfig::delay`].
    /// 3. Broadcasts the stop signal to all [`subscribe`](Self::subscribe) receivers.
    /// 4. Waits up to [`ShutdownConfig::timeout`] for in-flight requests to drain. If some are
    ///    still running when the timeout expires, it logs a warning and returns anyway.
    ///
    /// Later calls return immediately, without waiting for the first call to finish.
    ///
    /// NOTE(review): `shutdown_initiated` is set before the delay. `request_started` therefore
    /// already refuses new requests during the deregistration window, while the load balancer
    /// may still route traffic here. Confirm that is intended.
    pub async fn shutdown(&self) {
        // `swap` makes the transition atomic: exactly one caller observes `false`.
        if self.shutdown_initiated.swap(true, Ordering::SeqCst) {
            return;
        }
        tracing::info!("Initiating graceful shutdown");
        let _ = self.ready_tx.send(false);
        tracing::info!("Marked as not ready, waiting for load balancer deregistration");
        tokio::time::sleep(self.config.delay).await;
        // `send` fails only when there are no subscribers, which is harmless here.
        let _ = self.shutdown_tx.send(());
        tracing::info!("Shutdown signal sent to all components");
        let in_flight = self.in_flight.load(Ordering::SeqCst);
        if in_flight > 0 {
            tracing::info!("Waiting for {} in-flight requests to complete", in_flight);
            let drain_result = timeout(self.config.timeout, self.wait_for_drain()).await;
            match drain_result {
                Ok(()) => {
                    tracing::info!("All in-flight requests completed");
                },
                Err(_) => {
                    let remaining = self.in_flight.load(Ordering::SeqCst);
                    tracing::warn!(
                        "Shutdown timeout reached with {} requests still in-flight",
                        remaining
                    );
                },
            }
        }
        tracing::info!("Graceful shutdown complete");
    }

    /// Resolves once the in-flight counter reaches zero.
    ///
    /// Relies on `request_completed` calling `notify_waiters` on the 1 -> 0 transition while
    /// shutting down.
    async fn wait_for_drain(&self) {
        loop {
            // Create the `Notified` future *before* re-reading the counter. tokio guarantees it
            // receives `notify_waiters` calls made after its creation, even before it is first
            // polled. A request finishing between the check and the await is therefore no
            // longer missed; previously that left this loop parked until the timeout.
            let notified = self.drain_complete.notified();
            if self.in_flight.load(Ordering::SeqCst) == 0 {
                return;
            }
            notified.await;
        }
    }
}
/// RAII handle for one in-flight request, obtained from
/// `ShutdownCoordinator::request_started`.
///
/// Dropping the guard decrements the coordinator's in-flight counter. If it was the last one
/// while shutting down, this also wakes the pending drain. Keep the guard alive for the whole
/// request:
///
/// - `let _guard = ...;` keeps it until the end of scope.
/// - `let _ = ...;` drops it immediately.
#[must_use = "the request is counted as finished as soon as the guard is dropped"]
pub struct RequestGuard<'a> {
    coordinator: &'a ShutdownCoordinator,
}

impl Drop for RequestGuard<'_> {
    fn drop(&mut self) {
        self.coordinator.request_completed();
    }
}
/// Resolves once the process receives Ctrl+C or, on Unix, SIGTERM, and logs which one arrived.
///
/// On non-Unix targets only Ctrl+C is observed.
///
/// # Panics
///
/// Panics if the Ctrl+C handler (or, on Unix, the SIGTERM handler) cannot be installed.
pub async fn shutdown_signal() {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        use tokio::signal::unix::{SignalKind, signal};

        let mut stream = signal(SignalKind::terminate()).expect("Failed to install SIGTERM handler");
        stream.recv().await;
    };

    // No SIGTERM equivalent here: a future that never completes leaves Ctrl+C as the only
    // trigger.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        () = interrupt => {
            tracing::info!("Received Ctrl+C signal");
        }
        () = sigterm => {
            tracing::info!("Received SIGTERM signal");
        }
    }
}