use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::time::Duration;
use tokio_util::sync::CancellationToken;
/// Drain timeout used by `ShutdownCoordinator::new` (30 seconds).
// NOTE(review): `allow(dead_code)` is presumably here because `new` is not called outside tests
// in some build configurations — confirm and document the reason.
#[allow(dead_code)] const DEFAULT_DRAIN_TIMEOUT: Duration = Duration::from_secs(30);
/// Sleep interval between successive checks of the in-flight counter while draining (100 ms).
const DRAIN_POLL_INTERVAL: Duration = Duration::from_millis(100);
/// Phases of a graceful restart, as reported through `ShutdownCoordinator::log_phase`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RestartPhase {
/// Draining has begun; `in_flight` is the in-flight count sampled when the phase was logged.
DrainStart { in_flight: u32 },
/// The drain phase has ended.
DrainComplete,
/// The shutdown token is about to be cancelled.
Shutdown,
}
/// Record describing a request that was still in flight when the drain timeout expired.
///
/// Within this file it is only built by `ShutdownCoordinator::initiate_restart` in order to
/// emit one warning log line per remaining request.
#[derive(Debug, Clone)]
pub struct ForceTerminatedRequest {
/// Identifier of the request. `initiate_restart` fills this with a synthetic
/// `in-flight-<index>` value.
pub request_id: String,
/// Time attributed to the request. `initiate_restart` sets this to the drain timeout.
pub elapsed: Duration,
}
/// Coordinates graceful shutdown and restart: counts in-flight messages, stops accepting new
/// ones, waits (bounded by a timeout) for the count to reach zero, then cancels a token.
pub struct ShutdownCoordinator {
// Number of live `MessageGuard`s; incremented in `acquire`, decremented when a guard drops.
in_flight: AtomicU32,
// `true` while `acquire` may hand out guards; cleared when shutdown or restart starts.
accepting: AtomicBool,
// Cancelled at the end of `initiate_shutdown` / `initiate_restart`.
shutdown_token: CancellationToken,
// Upper bound on how long the drain loop waits for `in_flight` to reach zero.
drain_timeout: Duration,
// Set once `initiate_restart` has been called; read back through `is_restart`.
restart_initiated: AtomicBool,
}
impl ShutdownCoordinator {
    /// Creates a coordinator that drains for at most [`DEFAULT_DRAIN_TIMEOUT`].
    ///
    /// The coordinator starts out accepting work, with no in-flight messages.
    #[allow(dead_code)]
    pub fn new(shutdown_token: CancellationToken) -> Self {
        Self::with_drain_timeout(shutdown_token, DEFAULT_DRAIN_TIMEOUT)
    }

    /// Creates a coordinator that drains for at most `drain_timeout`.
    ///
    /// `shutdown_token` is cancelled once a shutdown or restart sequence has finished draining
    /// (or has given up waiting).
    pub fn with_drain_timeout(shutdown_token: CancellationToken, drain_timeout: Duration) -> Self {
        Self {
            in_flight: AtomicU32::new(0),
            accepting: AtomicBool::new(true),
            shutdown_token,
            drain_timeout,
            restart_initiated: AtomicBool::new(false),
        }
    }

    /// Registers one in-flight message and returns a guard that unregisters it on drop.
    ///
    /// Returns `None` once the coordinator has stopped accepting work.
    #[must_use = "dropping the guard immediately releases the in-flight slot"]
    pub fn acquire(&self) -> Option<MessageGuard<'_>> {
        if !self.accepting.load(Ordering::SeqCst) {
            return None;
        }
        self.in_flight.fetch_add(1, Ordering::SeqCst);
        // Re-check after incrementing: a shutdown may have started between the first check and
        // the increment, in which case the slot is handed back and the caller is rejected.
        if !self.accepting.load(Ordering::SeqCst) {
            self.in_flight.fetch_sub(1, Ordering::SeqCst);
            return None;
        }
        Some(MessageGuard { coordinator: self })
    }

    /// Stops accepting work, waits for in-flight messages to drain, then cancels the token.
    ///
    /// If the drain timeout expires first, a warning with the remaining count is logged and the
    /// token is cancelled anyway.
    pub async fn initiate_shutdown(&self) {
        self.accepting.store(false, Ordering::SeqCst);
        tracing::info!("shutdown initiated, draining in-flight messages…");
        if self.wait_for_drain().await {
            tracing::info!("all in-flight messages drained");
        } else {
            tracing::warn!(
                remaining = self.in_flight.load(Ordering::SeqCst),
                "drain timeout reached, forcing shutdown"
            );
        }
        self.shutdown_token.cancel();
    }

    /// Runs the graceful-restart sequence: mark restart, stop accepting, drain, cancel the token.
    ///
    /// Each phase is reported through [`Self::log_phase`]. If the drain timeout expires, one
    /// warning is logged per remaining in-flight message before the sequence continues. The
    /// `DrainComplete` phase is logged whether or not the drain finished in time.
    pub async fn initiate_restart(&self) {
        self.restart_initiated.store(true, Ordering::SeqCst);
        // Stop accepting before sampling the counter, so that messages admitted while the
        // drain-start event is being logged cannot be missing from the reported count.
        self.accepting.store(false, Ordering::SeqCst);
        Self::log_phase(&RestartPhase::DrainStart {
            in_flight: self.in_flight.load(Ordering::SeqCst),
        });

        if !self.wait_for_drain().await {
            self.log_force_terminated();
        }

        Self::log_phase(&RestartPhase::DrainComplete);
        Self::log_phase(&RestartPhase::Shutdown);
        self.shutdown_token.cancel();
    }

    /// Polls the in-flight counter every [`DRAIN_POLL_INTERVAL`] until it reaches zero or the
    /// drain timeout elapses.
    ///
    /// Returns `true` if the counter reached zero in time and `false` on timeout.
    async fn wait_for_drain(&self) -> bool {
        let drained = async {
            while self.in_flight.load(Ordering::SeqCst) > 0 {
                tokio::time::sleep(DRAIN_POLL_INTERVAL).await;
            }
        };
        tokio::time::timeout(self.drain_timeout, drained)
            .await
            .is_ok()
    }

    /// Logs the force-termination summary plus one warning per message still in flight.
    ///
    /// Only a count is tracked, so the request identifiers are synthetic (`in-flight-<index>`)
    /// and the elapsed time reported for each is the drain timeout. Nothing is terminated here;
    /// this function only logs.
    fn log_force_terminated(&self) {
        let remaining = self.in_flight.load(Ordering::SeqCst);
        tracing::warn!(
            remaining,
            drain_timeout_secs = self.drain_timeout.as_secs(),
            event = "force-terminate",
            "drain timeout expired, force-terminating remaining requests"
        );
        for i in 0..remaining {
            let terminated = ForceTerminatedRequest {
                request_id: format!("in-flight-{}", i),
                elapsed: self.drain_timeout,
            };
            tracing::warn!(
                request_id = %terminated.request_id,
                elapsed_secs = terminated.elapsed.as_secs(),
                event = "request-force-terminated",
                "force-terminated in-flight request due to drain timeout"
            );
        }
    }

    /// Emits the structured `info` log event that corresponds to `phase`.
    pub fn log_phase(phase: &RestartPhase) {
        match phase {
            RestartPhase::DrainStart { in_flight } => {
                tracing::info!(
                    event = "drain-start",
                    in_flight = in_flight,
                    "graceful restart: drain phase started"
                );
            }
            RestartPhase::DrainComplete => {
                tracing::info!(
                    event = "drain-complete",
                    "graceful restart: drain phase completed"
                );
            }
            RestartPhase::Shutdown => {
                tracing::info!(
                    event = "shutdown",
                    "graceful restart: shutdown phase — process exiting"
                );
            }
        }
    }

    /// Returns the current number of in-flight messages.
    pub fn in_flight_count(&self) -> u32 {
        self.in_flight.load(Ordering::SeqCst)
    }

    /// Returns `true` while [`Self::acquire`] may still hand out guards.
    pub fn is_accepting(&self) -> bool {
        self.accepting.load(Ordering::SeqCst)
    }

    /// Returns `true` once [`Self::initiate_restart`] has been called.
    pub fn is_restart(&self) -> bool {
        self.restart_initiated.load(Ordering::SeqCst)
    }

    /// Returns the configured drain timeout.
    pub fn drain_timeout(&self) -> Duration {
        self.drain_timeout
    }
}
/// Spawns a detached task that waits for a SIGUSR1 and then runs
/// `ShutdownCoordinator::initiate_restart` on `coordinator`.
///
/// Only the first SIGUSR1 is acted on; the task ends once the restart sequence returns.
/// The task's `JoinHandle` is dropped, so callers cannot await or abort it.
///
/// # Panics
///
/// Panics if the SIGUSR1 listener cannot be registered.
#[cfg(unix)]
pub async fn register_sigusr1_handler(
    coordinator: std::sync::Arc<ShutdownCoordinator>,
) {
    use tokio::signal::unix::{signal, SignalKind};

    let mut sigusr1 = signal(SignalKind::user_defined1())
        .expect("failed to register SIGUSR1 handler");

    tokio::spawn(async move {
        // `recv` yields `None` when the stream can deliver no further events. That is not a
        // signal, so it must not be mistaken for a restart request.
        if sigusr1.recv().await.is_none() {
            tracing::warn!(
                event = "sigusr1-stream-closed",
                "SIGUSR1 stream closed before a signal arrived, graceful restart not triggered"
            );
            return;
        }
        tracing::info!(
            event = "sigusr1-received",
            "received SIGUSR1, initiating graceful restart"
        );
        coordinator.initiate_restart().await;
    });
}
/// Guard representing one in-flight message, handed out by `ShutdownCoordinator::acquire`.
///
/// The coordinator's in-flight counter is decremented when the guard is dropped, so the guard
/// must be kept alive for as long as the message is being processed.
#[must_use = "the in-flight slot is released as soon as the guard is dropped"]
pub struct MessageGuard<'a> {
    // Coordinator whose in-flight counter this guard holds one unit of.
    coordinator: &'a ShutdownCoordinator,
}

impl Drop for MessageGuard<'_> {
    /// Releases the in-flight slot taken when the guard was created.
    fn drop(&mut self) {
        self.coordinator.in_flight.fetch_sub(1, Ordering::SeqCst);
    }
}
// Unit tests for `ShutdownCoordinator` and `MessageGuard`.
// The async tests use real sleeps and real timeouts; statement order and guard lifetimes matter.
#[cfg(test)]
mod tests {
use super::*;
// The in-flight counter follows guard creation and destruction: 0 -> 1 -> 2 -> 1 -> 0.
#[test]
fn acquire_increments_and_drop_decrements() {
let token = CancellationToken::new();
let coord = ShutdownCoordinator::new(token);
assert_eq!(coord.in_flight_count(), 0);
assert!(coord.is_accepting());
let g1 = coord.acquire().expect("should acquire");
assert_eq!(coord.in_flight_count(), 1);
let g2 = coord.acquire().expect("should acquire");
assert_eq!(coord.in_flight_count(), 2);
drop(g1);
assert_eq!(coord.in_flight_count(), 1);
drop(g2);
assert_eq!(coord.in_flight_count(), 0);
}
// Flipping the private `accepting` flag directly makes `acquire` reject new work.
#[test]
fn acquire_returns_none_after_accepting_set_false() {
let token = CancellationToken::new();
let coord = ShutdownCoordinator::new(token);
coord.accepting.store(false, Ordering::SeqCst);
assert!(!coord.is_accepting());
assert!(coord.acquire().is_none());
}
// Shutdown waits for the single guard (released after 200 ms, well inside the 5 s timeout)
// and then cancels the token.
#[tokio::test]
async fn initiate_shutdown_drains_then_cancels_token() {
let token = CancellationToken::new();
let coord = ShutdownCoordinator::with_drain_timeout(token.clone(), Duration::from_secs(5));
let guard = coord.acquire().expect("should acquire");
let ((), ()) = tokio::join!(coord.initiate_shutdown(), async {
tokio::time::sleep(Duration::from_millis(200)).await;
drop(guard);
});
assert!(!coord.is_accepting());
assert_eq!(coord.in_flight_count(), 0);
assert!(token.is_cancelled());
}
// With a 100 ms timeout and a guard that is never released, shutdown still cancels the token
// and the counter stays at 1.
#[tokio::test]
async fn initiate_shutdown_times_out_and_still_cancels() {
let token = CancellationToken::new();
let coord = ShutdownCoordinator::with_drain_timeout(
token.clone(),
Duration::from_millis(100), );
let _guard = coord.acquire().expect("should acquire");
coord.initiate_shutdown().await;
assert!(token.is_cancelled());
assert_eq!(coord.in_flight_count(), 1);
}
// Shutdown with nothing in flight cancels the token and stops accepting.
#[tokio::test]
async fn shutdown_with_no_in_flight_completes_immediately() {
let token = CancellationToken::new();
let coord = ShutdownCoordinator::new(token.clone());
coord.initiate_shutdown().await;
assert!(token.is_cancelled());
assert!(!coord.is_accepting());
assert_eq!(coord.in_flight_count(), 0);
}
// Ten simultaneous guards are all counted, and dropping the Vec releases all of them.
#[test]
fn multiple_guards_track_correctly() {
let token = CancellationToken::new();
let coord = ShutdownCoordinator::new(token);
let guards: Vec<_> = (0..10)
.map(|_| coord.acquire().expect("should acquire"))
.collect();
assert_eq!(coord.in_flight_count(), 10);
drop(guards);
assert_eq!(coord.in_flight_count(), 0);
}
// Restart with nothing in flight: stops accepting, sets the restart flag, cancels the token.
#[tokio::test]
async fn initiate_restart_stops_accepting_new_connections() {
let token = CancellationToken::new();
let coord = ShutdownCoordinator::new(token.clone());
assert!(coord.is_accepting());
coord.initiate_restart().await;
assert!(!coord.is_accepting());
assert!(coord.is_restart());
assert!(token.is_cancelled());
}
// Restart waits for the single guard (released after 100 ms, inside the 5 s timeout) before
// cancelling the token.
#[tokio::test]
async fn initiate_restart_drains_in_flight_requests() {
let token = CancellationToken::new();
let coord = ShutdownCoordinator::with_drain_timeout(token.clone(), Duration::from_secs(5));
let guard = coord.acquire().expect("should acquire");
assert_eq!(coord.in_flight_count(), 1);
let ((), ()) = tokio::join!(coord.initiate_restart(), async {
tokio::time::sleep(Duration::from_millis(100)).await;
drop(guard);
});
assert_eq!(coord.in_flight_count(), 0);
assert!(!coord.is_accepting());
assert!(token.is_cancelled());
}
// With a 100 ms timeout and two guards held, restart gives up waiting and cancels the token;
// the guards themselves are untouched, so the counter remains 2.
#[tokio::test]
async fn initiate_restart_force_terminates_on_timeout() {
let token = CancellationToken::new();
let coord = ShutdownCoordinator::with_drain_timeout(
token.clone(),
Duration::from_millis(100), );
let _guard1 = coord.acquire().expect("should acquire");
let _guard2 = coord.acquire().expect("should acquire");
coord.initiate_restart().await;
assert!(token.is_cancelled());
assert!(!coord.is_accepting());
assert_eq!(coord.in_flight_count(), 2);
}
// While a restart is draining on a spawned task, `acquire` already returns `None`.
// The 50 ms sleep gives the spawned task time to start before the assertion.
#[tokio::test]
async fn restart_rejects_new_connections_immediately() {
let token = CancellationToken::new();
let coord = std::sync::Arc::new(ShutdownCoordinator::with_drain_timeout(
token.clone(),
Duration::from_secs(5),
));
let guard = coord.acquire().expect("should acquire");
let coord_clone = coord.clone();
let restart_handle = tokio::spawn(async move {
coord_clone.initiate_restart().await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
assert!(coord.acquire().is_none());
drop(guard);
restart_handle.await.unwrap();
}
// Smoke test: logging each phase variant does not panic.
#[test]
fn restart_phase_log_does_not_panic() {
ShutdownCoordinator::log_phase(&RestartPhase::DrainStart { in_flight: 5 });
ShutdownCoordinator::log_phase(&RestartPhase::DrainComplete);
ShutdownCoordinator::log_phase(&RestartPhase::Shutdown);
}
// `new` applies the 30-second default drain timeout.
#[test]
fn default_drain_timeout_is_30_seconds() {
let token = CancellationToken::new();
let coord = ShutdownCoordinator::new(token);
assert_eq!(coord.drain_timeout(), Duration::from_secs(30));
}
}