use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, watch, Notify};
use tracing::{debug, info, warn};
/// Default overall shutdown timeout, in seconds.
/// NOTE(review): not referenced within this file — presumably consumed by
/// callers elsewhere; confirm before removing.
pub const DEFAULT_SHUTDOWN_TIMEOUT_SECS: u64 = 30;
/// Default time allowed for in-flight connections to drain, in seconds.
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
/// Lifecycle phases of the graceful-shutdown sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownState {
    /// Normal operation; new work is accepted.
    Running,
    /// Shutdown initiated; waiting for in-flight connections to finish.
    Draining,
    /// Shutdown complete; no further work is accepted.
    Stopped,
}
impl std::fmt::Display for ShutdownState {
    /// Render the state as the lowercase label used in health payloads.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ShutdownState::Running => "running",
            ShutdownState::Draining => "draining",
            ShutdownState::Stopped => "stopped",
        };
        f.write_str(label)
    }
}
/// Cheaply-cloneable handle coordinating graceful shutdown: all clones share
/// one inner state via `Arc`, so any clone can initiate or observe shutdown.
#[derive(Clone)]
pub struct ShutdownController {
    inner: Arc<ShutdownControllerInner>,
}
/// Shared state behind [`ShutdownController`]; every clone points at one
/// instance of this struct.
struct ShutdownControllerInner {
    /// Latched to `true` by the first `initiate_shutdown` call.
    is_shutting_down: AtomicBool,
    /// Current lifecycle state (mirrored into `state_tx` on every change).
    state: std::sync::RwLock<ShutdownState>,
    /// Wakes `wait_for_shutdown` callers when shutdown begins.
    shutdown_notify: Notify,
    /// Publishes state transitions to watch receivers.
    state_tx: watch::Sender<ShutdownState>,
    /// Template receiver cloned out by `state_receiver`.
    state_rx: watch::Receiver<ShutdownState>,
    /// One-shot shutdown broadcast for `subscribe` listeners.
    shutdown_tx: broadcast::Sender<()>,
    /// Count of in-flight connections/requests being tracked.
    active_connections: AtomicU64,
    /// Maximum time to wait for active connections to drain.
    drain_timeout: Duration,
    /// Timestamp of shutdown initiation, if it has occurred.
    shutdown_started: std::sync::RwLock<Option<Instant>>,
}
impl ShutdownController {
    /// Create a controller with the default drain timeout
    /// ([`DEFAULT_DRAIN_TIMEOUT_SECS`]).
    pub fn new() -> Self {
        Self::with_timeout(Duration::from_secs(DEFAULT_DRAIN_TIMEOUT_SECS))
    }

    /// Create a controller that waits up to `drain_timeout` for active
    /// connections to finish before completing shutdown.
    pub fn with_timeout(drain_timeout: Duration) -> Self {
        let (state_tx, state_rx) = watch::channel(ShutdownState::Running);
        let (shutdown_tx, _) = broadcast::channel(16);
        Self {
            inner: Arc::new(ShutdownControllerInner {
                is_shutting_down: AtomicBool::new(false),
                state: std::sync::RwLock::new(ShutdownState::Running),
                shutdown_notify: Notify::new(),
                state_tx,
                state_rx,
                shutdown_tx,
                active_connections: AtomicU64::new(0),
                drain_timeout,
                shutdown_started: std::sync::RwLock::new(None),
            }),
        }
    }

    /// True once `initiate_shutdown` has been called (on any clone).
    pub fn is_shutting_down(&self) -> bool {
        self.inner.is_shutting_down.load(Ordering::SeqCst)
    }

    /// Current lifecycle state.
    pub fn state(&self) -> ShutdownState {
        *self.inner.state.read().unwrap()
    }

    /// A watch receiver that observes state transitions.
    pub fn state_receiver(&self) -> watch::Receiver<ShutdownState> {
        self.inner.state_rx.clone()
    }

    /// Subscribe to the one-shot shutdown broadcast.
    pub fn subscribe(&self) -> broadcast::Receiver<()> {
        self.inner.shutdown_tx.subscribe()
    }

    /// Run the graceful-shutdown sequence: mark draining, notify waiters and
    /// subscribers, wait for connections to drain (bounded by the drain
    /// timeout), flush logs, then mark stopped.
    ///
    /// Idempotent: only the first caller runs the sequence; concurrent or
    /// repeated calls return immediately.
    pub async fn initiate_shutdown(&self) {
        // compare_exchange latches the flag exactly once; losers bail out.
        if self
            .inner
            .is_shutting_down
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            debug!("Shutdown already in progress");
            return;
        }
        info!("Initiating graceful shutdown");
        *self.inner.shutdown_started.write().unwrap() = Some(Instant::now());
        self.set_state(ShutdownState::Draining);
        #[cfg(target_os = "linux")]
        systemd_notify_stopping();
        // Ignore the send error: it only means there are no subscribers.
        let _ = self.inner.shutdown_tx.send(());
        self.inner.shutdown_notify.notify_waiters();
        self.wait_for_drain().await;
        self.flush_logs().await;
        self.set_state(ShutdownState::Stopped);
        info!("Graceful shutdown complete");
    }

    /// Wait until shutdown is initiated; returns immediately if it already
    /// has been.
    pub async fn wait_for_shutdown(&self) {
        let notified = self.inner.shutdown_notify.notified();
        tokio::pin!(notified);
        // Register this waiter *before* checking the flag. `notify_waiters`
        // only wakes waiters that are already registered, so checking first
        // and then awaiting `notified()` could miss a notification sent in
        // between and hang this task forever.
        notified.as_mut().enable();
        if self.is_shutting_down() {
            return;
        }
        notified.await;
    }

    /// Record the start of a tracked connection/request.
    pub fn connection_start(&self) {
        self.inner.active_connections.fetch_add(1, Ordering::SeqCst);
    }

    /// Record the end of a tracked connection/request.
    pub fn connection_end(&self) {
        self.inner.active_connections.fetch_sub(1, Ordering::SeqCst);
    }

    /// Number of connections currently tracked.
    pub fn active_connections(&self) -> u64 {
        self.inner.active_connections.load(Ordering::SeqCst)
    }

    /// RAII tracking: increments the counter now; the returned guard
    /// decrements it on drop.
    pub fn connection_guard(&self) -> ConnectionGuard {
        self.connection_start();
        ConnectionGuard {
            controller: self.clone(),
        }
    }

    /// Configured drain timeout.
    pub fn drain_timeout(&self) -> Duration {
        self.inner.drain_timeout
    }

    /// Time since shutdown was initiated, or `None` if it has not been.
    pub fn shutdown_elapsed(&self) -> Option<Duration> {
        self.inner
            .shutdown_started
            .read()
            .unwrap()
            .map(|started| started.elapsed())
    }

    /// Suggested `Retry-After` value (seconds): remaining drain time plus a
    /// 5-second safety margin. Before shutdown starts, assumes the full
    /// default drain window.
    pub fn retry_after_secs(&self) -> u64 {
        match self.shutdown_elapsed() {
            Some(elapsed) => {
                let remaining = self.inner.drain_timeout.saturating_sub(elapsed);
                remaining.as_secs().saturating_add(5)
            }
            None => DEFAULT_DRAIN_TIMEOUT_SECS + 5,
        }
    }

    /// Update both state stores (lock and watch channel) and log the change.
    fn set_state(&self, state: ShutdownState) {
        *self.inner.state.write().unwrap() = state;
        // Send error only means no receivers are listening.
        let _ = self.inner.state_tx.send(state);
        info!("Shutdown state changed to: {}", state);
    }

    /// Poll the connection counter every 100ms until it reaches zero or the
    /// drain timeout elapses.
    async fn wait_for_drain(&self) {
        let timeout = self.inner.drain_timeout;
        let start = Instant::now();
        info!(
            "Waiting for {} active connections to drain (timeout: {:?})",
            self.active_connections(),
            timeout
        );
        loop {
            let active = self.active_connections();
            if active == 0 {
                info!("All connections drained successfully");
                return;
            }
            if start.elapsed() >= timeout {
                warn!(
                    "Drain timeout reached with {} active connections remaining",
                    active
                );
                return;
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
    }

    /// Give async log appenders a brief window to flush before exit.
    async fn flush_logs(&self) {
        debug!("Flushing logs before shutdown");
        // Use the async timer here: std::thread::sleep would block the
        // runtime worker thread for the whole duration.
        tokio::time::sleep(Duration::from_millis(50)).await;
    }
}
impl Default for ShutdownController {
fn default() -> Self {
Self::new()
}
}
/// Guard representing one tracked connection. Created by
/// `ShutdownController::connection_guard`; decrements the active-connection
/// counter when dropped, so the count stays correct on any exit path.
pub struct ConnectionGuard {
    controller: ShutdownController,
}
impl Drop for ConnectionGuard {
    // Decrement the active-connection counter when the guard leaves scope,
    // including on panic or early return.
    fn drop(&mut self) {
        self.controller.connection_end();
    }
}
/// Resolve when a termination signal is received.
///
/// On Unix this waits for SIGTERM, SIGINT, or SIGHUP in addition to Ctrl+C;
/// on other platforms only Ctrl+C is handled.
///
/// # Panics
/// Panics if any signal handler cannot be installed.
pub async fn shutdown_signal() {
    let ctrl_c = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };
    #[cfg(unix)]
    let terminate = async {
        use tokio::signal::unix::{signal, SignalKind};
        let mut sigterm =
            signal(SignalKind::terminate()).expect("Failed to install SIGTERM handler");
        let mut sigint = signal(SignalKind::interrupt()).expect("Failed to install SIGINT handler");
        let mut sighup = signal(SignalKind::hangup()).expect("Failed to install SIGHUP handler");
        tokio::select! {
            _ = sigterm.recv() => {
                info!("Received SIGTERM");
            }
            _ = sigint.recv() => {
                info!("Received SIGINT");
            }
            _ = sighup.recv() => {
                info!("Received SIGHUP");
            }
        }
    };
    // Non-Unix targets have no extra signal stream; this arm never resolves.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();
    tokio::select! {
        _ = ctrl_c => {
            info!("Received Ctrl+C");
        }
        _ = terminate => {}
    }
}
/// Wait for a termination signal, then run the controller's full
/// graceful-shutdown sequence. Resolves when shutdown is complete.
pub async fn shutdown_signal_with_controller(controller: ShutdownController) {
    shutdown_signal().await;
    controller.initiate_shutdown().await;
}
/// Tell systemd the service is ready (`READY=1`). Failure is logged at
/// debug level only, since we may simply not be running under systemd.
#[cfg(target_os = "linux")]
pub fn systemd_notify_ready() {
    match sd_notify("READY=1") {
        Ok(()) => info!("Notified systemd: READY"),
        Err(e) => debug!(
            "Failed to notify systemd ready (may not be running under systemd): {}",
            e
        ),
    }
}
/// Tell systemd the service has begun shutting down (`STOPPING=1`).
#[cfg(target_os = "linux")]
pub fn systemd_notify_stopping() {
    match sd_notify("STOPPING=1") {
        Ok(()) => info!("Notified systemd: STOPPING"),
        Err(e) => debug!("Failed to notify systemd stopping: {}", e),
    }
}
/// Publish a free-form status string to systemd (`STATUS=...`).
#[cfg(target_os = "linux")]
pub fn systemd_notify_status(status: &str) {
    let payload = format!("STATUS={}", status);
    if let Err(e) = sd_notify(&payload) {
        debug!("Failed to notify systemd status: {}", e);
    }
}
/// Send one watchdog keep-alive (`WATCHDOG=1`) to systemd.
#[cfg(target_os = "linux")]
pub fn systemd_watchdog_ping() {
    match sd_notify("WATCHDOG=1") {
        Ok(()) => {}
        Err(e) => debug!("Failed to send watchdog ping: {}", e),
    }
}
/// Send a state string to the systemd notification socket (sd_notify(3)).
///
/// Returns `Ok(())` without doing anything when `NOTIFY_SOCKET` is unset,
/// i.e. when not running under systemd.
///
/// # Errors
/// Returns any I/O error from creating the socket or sending the datagram.
#[cfg(target_os = "linux")]
fn sd_notify(state: &str) -> std::io::Result<()> {
    use std::os::unix::net::UnixDatagram;
    let socket_path = match std::env::var("NOTIFY_SOCKET") {
        Ok(path) => path,
        Err(_) => {
            return Ok(());
        }
    };
    let socket = UnixDatagram::unbound()?;
    if let Some(name) = socket_path.strip_prefix('@') {
        // A leading '@' denotes a Linux abstract-namespace socket. It must
        // be addressed via from_abstract_name: building a *pathname* address
        // from the name (as the previous code did) would send the datagram
        // to a filesystem path of the same name, never reaching systemd.
        use std::os::linux::net::SocketAddrExt;
        let addr = std::os::unix::net::SocketAddr::from_abstract_name(name)?;
        socket.send_to_addr(state.as_bytes(), &addr)?;
    } else {
        socket.send_to(state.as_bytes(), &socket_path)?;
    }
    Ok(())
}
/// Non-Linux stub: systemd is unavailable, so this is a logged no-op.
#[cfg(not(target_os = "linux"))]
pub fn systemd_notify_ready() {
    debug!("systemd_notify_ready: not on Linux, skipping");
}
/// Non-Linux stub: systemd is unavailable, so this is a logged no-op.
#[cfg(not(target_os = "linux"))]
pub fn systemd_notify_stopping() {
    debug!("systemd_notify_stopping: not on Linux, skipping");
}
/// Non-Linux stub: systemd is unavailable, so this is a logged no-op.
#[cfg(not(target_os = "linux"))]
pub fn systemd_notify_status(_status: &str) {
    debug!("systemd_notify_status: not on Linux, skipping");
}
/// Non-Linux stub: systemd is unavailable, so this is a logged no-op.
#[cfg(not(target_os = "linux"))]
pub fn systemd_watchdog_ping() {
    debug!("systemd_watchdog_ping: not on Linux, skipping");
}
/// Periodically ping the systemd watchdog until a shutdown broadcast
/// arrives, then exit. Pings every `interval` (no ping is sent before the
/// first full interval elapses).
pub async fn watchdog_task(interval: Duration, mut shutdown_rx: broadcast::Receiver<()>) {
    info!(
        "Starting systemd watchdog task with {:?} interval",
        interval
    );
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                info!("Watchdog task stopping due to shutdown");
                break;
            }
            _ = tokio::time::sleep(interval) => {
                systemd_watchdog_ping();
            }
        }
    }
}
/// JSON-serializable snapshot of service health, suitable for a readiness
/// or health-check endpoint.
#[derive(Debug, Clone, serde::Serialize)]
pub struct HealthStatus {
    /// Short label: "ok", "draining", or "stopped".
    pub status: String,
    /// Whether the service should be considered able to take traffic.
    pub healthy: bool,
    /// String form of the current [`ShutdownState`].
    pub shutdown_state: String,
    /// Number of tracked in-flight connections.
    pub active_connections: u64,
    /// Seconds left in the drain window; omitted from JSON while running.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub drain_remaining_secs: Option<u64>,
    /// Suggested Retry-After value; omitted from JSON while running.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry_after_secs: Option<u64>,
}
impl ShutdownController {
    /// Build a point-in-time [`HealthStatus`] snapshot from the current
    /// shutdown state and connection count. While draining or stopped the
    /// snapshot is marked unhealthy and carries drain/retry hints.
    pub fn health_status(&self) -> HealthStatus {
        let state = self.state();
        let active = self.active_connections();
        // Derive the state-dependent fields in one place, then build the
        // snapshot with a single constructor.
        let (status, healthy, drain_remaining_secs, retry_after_secs) = match state {
            ShutdownState::Running => ("ok", true, None, None),
            ShutdownState::Draining => {
                let remaining = self
                    .shutdown_elapsed()
                    .map(|elapsed| self.drain_timeout().saturating_sub(elapsed).as_secs());
                ("draining", false, remaining, Some(self.retry_after_secs()))
            }
            ShutdownState::Stopped => ("stopped", false, Some(0), Some(self.retry_after_secs())),
        };
        HealthStatus {
            status: status.to_string(),
            healthy,
            shutdown_state: state.to_string(),
            active_connections: active,
            drain_remaining_secs,
            retry_after_secs,
        }
    }
}
pub mod axum_integration {
    //! Tower/axum middleware that rejects new requests with
    //! `503 Service Unavailable` (plus a `Retry-After` hint) once shutdown
    //! has begun, and counts in-flight requests so draining can complete.
    use super::*;
    use axum::{
        body::Body,
        http::{header, Request, Response, StatusCode},
    };
    use std::task::{Context, Poll};
    use tower::{Layer, Service};

    /// Layer that wraps a service in [`ShutdownService`].
    #[derive(Clone)]
    pub struct ShutdownLayer {
        controller: ShutdownController,
    }

    impl ShutdownLayer {
        /// Create a layer bound to `controller`.
        pub fn new(controller: ShutdownController) -> Self {
            Self { controller }
        }
    }

    impl<S> Layer<S> for ShutdownLayer {
        type Service = ShutdownService<S>;
        fn layer(&self, inner: S) -> Self::Service {
            ShutdownService {
                inner,
                controller: self.controller.clone(),
            }
        }
    }

    /// Middleware service enforcing shutdown semantics around `inner`.
    #[derive(Clone)]
    pub struct ShutdownService<S> {
        inner: S,
        controller: ShutdownController,
    }

    impl<S, ReqBody> Service<Request<ReqBody>> for ShutdownService<S>
    where
        S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
        S::Future: Send,
        ReqBody: Send + 'static,
    {
        type Response = Response<Body>;
        type Error = S::Error;
        type Future = std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
        >;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
            let controller = self.controller.clone();
            // Take the service that poll_ready() reported ready and leave a
            // fresh clone in its place. Driving a plain clone (as before)
            // could invoke a service that was never polled ready — the
            // readiness/clone pitfall documented for tower::Service.
            let clone = self.inner.clone();
            let mut inner = std::mem::replace(&mut self.inner, clone);
            Box::pin(async move {
                if controller.is_shutting_down() {
                    let retry_after = controller.retry_after_secs().to_string();
                    let health = controller.health_status();
                    let body = serde_json::to_string(&health).unwrap_or_else(|_| {
                        r#"{"status":"unavailable","healthy":false}"#.to_string()
                    });
                    // Builder cannot fail here: the status code and both
                    // header values are known-valid.
                    let response = Response::builder()
                        .status(StatusCode::SERVICE_UNAVAILABLE)
                        .header(header::RETRY_AFTER, retry_after)
                        .header(header::CONTENT_TYPE, "application/json")
                        .body(Body::from(body))
                        .unwrap();
                    return Ok(response);
                }
                // Guard keeps the request counted for its whole lifetime.
                let _guard = controller.connection_guard();
                inner.call(req).await
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh controller starts running with no tracked connections.
    #[tokio::test]
    async fn test_shutdown_controller_new() {
        let controller = ShutdownController::new();
        assert!(!controller.is_shutting_down());
        assert_eq!(controller.state(), ShutdownState::Running);
        assert_eq!(controller.active_connections(), 0);
    }

    // A custom drain timeout is stored and reported back.
    #[tokio::test]
    async fn test_shutdown_controller_with_timeout() {
        let controller = ShutdownController::with_timeout(Duration::from_secs(60));
        assert_eq!(controller.drain_timeout(), Duration::from_secs(60));
    }

    // Manual start/end calls move the active-connection counter.
    #[tokio::test]
    async fn test_connection_tracking() {
        let controller = ShutdownController::new();
        controller.connection_start();
        assert_eq!(controller.active_connections(), 1);
        controller.connection_start();
        assert_eq!(controller.active_connections(), 2);
        controller.connection_end();
        assert_eq!(controller.active_connections(), 1);
        controller.connection_end();
        assert_eq!(controller.active_connections(), 0);
    }

    // Guards increment on creation and decrement when dropped (nestable).
    #[tokio::test]
    async fn test_connection_guard() {
        let controller = ShutdownController::new();
        {
            let _guard = controller.connection_guard();
            assert_eq!(controller.active_connections(), 1);
            {
                let _guard2 = controller.connection_guard();
                assert_eq!(controller.active_connections(), 2);
            }
            assert_eq!(controller.active_connections(), 1);
        }
        assert_eq!(controller.active_connections(), 0);
    }

    // The full shutdown sequence ends in the Stopped state.
    #[tokio::test]
    async fn test_shutdown_initiation() {
        let controller = ShutdownController::with_timeout(Duration::from_millis(100));
        assert!(!controller.is_shutting_down());
        assert_eq!(controller.state(), ShutdownState::Running);
        controller.initiate_shutdown().await;
        assert!(controller.is_shutting_down());
        assert_eq!(controller.state(), ShutdownState::Stopped);
    }

    // Concurrent initiations are safe; both tasks complete without panic.
    #[tokio::test]
    async fn test_shutdown_only_once() {
        let controller = ShutdownController::with_timeout(Duration::from_millis(100));
        let controller2 = controller.clone();
        let handle1 = tokio::spawn(async move {
            controller.initiate_shutdown().await;
        });
        let handle2 = tokio::spawn(async move {
            controller2.initiate_shutdown().await;
        });
        let (r1, r2) = tokio::join!(handle1, handle2);
        r1.unwrap();
        r2.unwrap();
    }

    // A running controller reports healthy with no retry hint.
    #[tokio::test]
    async fn test_health_status_running() {
        let controller = ShutdownController::new();
        let health = controller.health_status();
        assert!(health.healthy);
        assert_eq!(health.status, "ok");
        assert_eq!(health.shutdown_state, "running");
        assert!(health.retry_after_secs.is_none());
    }

    // Broadcast subscribers receive a message when shutdown begins.
    #[tokio::test]
    async fn test_subscribe_and_notify() {
        let controller = ShutdownController::with_timeout(Duration::from_millis(100));
        let mut rx = controller.subscribe();
        let controller2 = controller.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            controller2.initiate_shutdown().await;
        });
        let result = tokio::time::timeout(Duration::from_secs(1), rx.recv()).await;
        assert!(result.is_ok());
    }

    // Watch receivers observe the transition out of Running.
    // (Either Draining or Stopped may be seen, depending on timing.)
    #[tokio::test]
    async fn test_state_receiver() {
        let controller = ShutdownController::with_timeout(Duration::from_millis(100));
        let mut rx = controller.state_receiver();
        assert_eq!(*rx.borrow(), ShutdownState::Running);
        let controller2 = controller.clone();
        tokio::spawn(async move {
            controller2.initiate_shutdown().await;
        });
        rx.changed().await.unwrap();
        let state = *rx.borrow();
        assert!(state == ShutdownState::Draining || state == ShutdownState::Stopped);
    }

    // Drain completes promptly once the last guard is dropped.
    #[tokio::test]
    async fn test_drain_with_active_connections() {
        let controller = ShutdownController::with_timeout(Duration::from_millis(500));
        let guard = controller.connection_guard();
        let controller2 = controller.clone();
        let shutdown_handle = tokio::spawn(async move {
            controller2.initiate_shutdown().await;
        });
        tokio::time::sleep(Duration::from_millis(100)).await;
        drop(guard);
        tokio::time::timeout(Duration::from_secs(1), shutdown_handle)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(controller.state(), ShutdownState::Stopped);
    }

    // Display renders the lowercase state labels used in health payloads.
    #[test]
    fn test_shutdown_state_display() {
        assert_eq!(ShutdownState::Running.to_string(), "running");
        assert_eq!(ShutdownState::Draining.to_string(), "draining");
        assert_eq!(ShutdownState::Stopped.to_string(), "stopped");
    }

    // Before shutdown starts, the hint is the default drain window + 5s.
    #[test]
    fn test_retry_after_secs() {
        let controller = ShutdownController::with_timeout(Duration::from_secs(30));
        assert_eq!(controller.retry_after_secs(), 35);
    }

    // None-valued optional fields are omitted from the JSON payload.
    #[test]
    fn test_health_status_serialization() {
        let status = HealthStatus {
            status: "ok".to_string(),
            healthy: true,
            shutdown_state: "running".to_string(),
            active_connections: 5,
            drain_remaining_secs: None,
            retry_after_secs: None,
        };
        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("\"status\":\"ok\""));
        assert!(json.contains("\"healthy\":true"));
        assert!(!json.contains("drain_remaining_secs"));
        assert!(!json.contains("retry_after_secs"));
    }

    // Default::default() matches new(): running, default drain timeout.
    #[tokio::test]
    async fn test_default_trait() {
        let controller = ShutdownController::default();
        assert!(!controller.is_shutting_down());
        assert_eq!(
            controller.drain_timeout(),
            Duration::from_secs(DEFAULT_DRAIN_TIMEOUT_SECS)
        );
    }
}