use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Notify, RwLock};
use tracing::{error, info, warn};
/// Ordered stage of a graceful shutdown to which a hook is assigned.
///
/// Variants are declared in the order `ShutdownManager::shutdown` executes
/// them, and the derived `Ord` follows that same declaration order. Values
/// serialize in `snake_case` (for example `pre_drain`).
///
/// NOTE(review): the manager attaches no meaning to a phase beyond its
/// position in the sequence; what each phase is *for* is decided by the hooks
/// registered under it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ShutdownPhase {
/// First phase to run.
PreDrain,
/// Second phase to run.
Drain,
/// Third phase to run.
Cleanup,
/// Last phase to run.
Final,
}
/// A named callback executed during one phase of shutdown.
pub struct ShutdownHook {
/// Label used in log events and in the resulting `HookResult`.
pub name: String,
/// Phase during which the callback is invoked.
pub phase: ShutdownPhase,
/// Synchronous callback; an `Err(message)` is recorded as a failed hook.
pub callback: Box<dyn Fn() -> Result<(), String> + Send + Sync>,
}
impl std::fmt::Debug for ShutdownHook {
    /// Writes the hook's `name` and `phase`.
    ///
    /// The `callback` field is left out of the output: the boxed closure has
    /// no `Debug` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            name,
            phase,
            callback: _,
        } = self;
        let mut builder = f.debug_struct("ShutdownHook");
        builder.field("name", name);
        builder.field("phase", phase);
        builder.finish()
    }
}
/// Outcome of running a single shutdown hook.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookResult {
/// Name of the hook that was run.
pub name: String,
/// Phase in which the hook was run.
pub phase: ShutdownPhase,
/// `true` when the callback returned `Ok(())`.
pub success: bool,
/// Error message returned by the callback; `None` on success.
pub error: Option<String>,
/// Wall-clock time spent inside the callback, in milliseconds.
pub duration_ms: u64,
}
/// Summary returned by `ShutdownManager::shutdown`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShutdownReport {
/// Time from the start of the `shutdown` call to the end of the last hook, in milliseconds.
pub total_duration_ms: u64,
/// Number of hooks whose callback returned `Ok(())`.
pub hooks_succeeded: usize,
/// Number of hooks whose callback returned `Err(_)`.
pub hooks_failed: usize,
/// `true` when the elapsed time did not exceed the manager's timeout.
pub completed_in_time: bool,
/// One entry per executed hook, in execution order.
pub hook_results: Vec<HookResult>,
}
/// Internal lifecycle of a `ShutdownManager`.
///
/// The manager starts in `Running`, moves to `ShuttingDown` when `shutdown`
/// is first called, and ends in `Completed` once the hook pass has finished.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ShutdownState {
/// Shutdown has not been requested yet.
Running,
/// Shutdown was requested and hooks are being executed.
ShuttingDown,
/// The hook pass has finished.
Completed,
}
/// Coordinates a phased graceful shutdown.
///
/// Cloning is cheap: `state`, `hooks` and `notify` live behind `Arc`, so all
/// clones observe the same hooks and the same shutdown state.
#[derive(Clone)]
pub struct ShutdownManager {
/// Current lifecycle state, shared between clones.
state: Arc<RwLock<ShutdownState>>,
/// Registered hooks in registration order, shared between clones.
hooks: Arc<RwLock<Vec<ShutdownHook>>>,
/// Signalled when shutdown begins.
notify: Arc<Notify>,
/// Overall time budget for the hook pass.
timeout: Duration,
}
impl std::fmt::Debug for ShutdownManager {
    /// Writes only the configured `timeout`; the shared state, hook list and
    /// notifier are not included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ShutdownManager");
        builder.field("timeout", &self.timeout);
        builder.finish()
    }
}
impl ShutdownManager {
pub fn new(timeout: Duration) -> Self {
Self {
state: Arc::new(RwLock::new(ShutdownState::Running)),
hooks: Arc::new(RwLock::new(Vec::new())),
notify: Arc::new(Notify::new()),
timeout,
}
}
pub async fn register_hook(&self, hook: ShutdownHook) {
let mut hooks = self.hooks.write().await;
info!(name = %hook.name, phase = ?hook.phase, "Shutdown hook registered");
hooks.push(hook);
}
pub async fn on_shutdown(
&self,
name: impl Into<String>,
phase: ShutdownPhase,
callback: impl Fn() -> Result<(), String> + Send + Sync + 'static,
) {
self.register_hook(ShutdownHook {
name: name.into(),
phase,
callback: Box::new(callback),
})
.await;
}
pub async fn is_shutting_down(&self) -> bool {
let state = self.state.read().await;
*state != ShutdownState::Running
}
pub fn shutdown_signal(&self) -> Arc<Notify> {
self.notify.clone()
}
pub async fn shutdown(&self) -> ShutdownReport {
let start = Instant::now();
{
let mut state = self.state.write().await;
if *state != ShutdownState::Running {
return ShutdownReport {
total_duration_ms: 0,
hooks_succeeded: 0,
hooks_failed: 0,
completed_in_time: true,
hook_results: Vec::new(),
};
}
*state = ShutdownState::ShuttingDown;
}
self.notify.notify_waiters();
info!("Graceful shutdown initiated");
let mut all_results = Vec::new();
let mut succeeded = 0;
let mut failed = 0;
let phases = [
ShutdownPhase::PreDrain,
ShutdownPhase::Drain,
ShutdownPhase::Cleanup,
ShutdownPhase::Final,
];
let hooks = self.hooks.read().await;
for phase in &phases {
if start.elapsed() > self.timeout {
warn!(phase = ?phase, "Shutdown timeout reached, skipping remaining hooks");
break;
}
let phase_hooks: Vec<&ShutdownHook> =
hooks.iter().filter(|h| h.phase == *phase).collect();
if !phase_hooks.is_empty() {
info!(phase = ?phase, count = phase_hooks.len(), "Executing shutdown phase");
}
for hook in phase_hooks {
let hook_start = Instant::now();
let result = (hook.callback)();
let duration_ms = hook_start.elapsed().as_millis() as u64;
match result {
Ok(()) => {
info!(name = %hook.name, duration_ms, "Shutdown hook completed");
succeeded += 1;
all_results.push(HookResult {
name: hook.name.clone(),
phase: *phase,
success: true,
error: None,
duration_ms,
});
}
Err(e) => {
error!(name = %hook.name, error = %e, "Shutdown hook failed");
failed += 1;
all_results.push(HookResult {
name: hook.name.clone(),
phase: *phase,
success: false,
error: Some(e),
duration_ms,
});
}
}
}
}
let total_duration_ms = start.elapsed().as_millis() as u64;
let completed_in_time = start.elapsed() <= self.timeout;
{
let mut state = self.state.write().await;
*state = ShutdownState::Completed;
}
info!(
total_ms = total_duration_ms,
succeeded, failed, "Graceful shutdown complete"
);
ShutdownReport {
total_duration_ms,
hooks_succeeded: succeeded,
hooks_failed: failed,
completed_in_time,
hook_results: all_results,
}
}
pub async fn hook_count(&self) -> usize {
self.hooks.read().await.len()
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    /// Builds a manager with a five-second shutdown budget.
    fn manager() -> ShutdownManager {
        ShutdownManager::new(Duration::from_secs(5))
    }

    /// A fresh manager reports that it is not shutting down.
    #[tokio::test]
    async fn test_initial_state() {
        let fresh = manager();
        let shutting_down = fresh.is_shutting_down().await;
        assert!(!shutting_down);
    }

    /// Calling `shutdown` flips the shutting-down flag.
    #[tokio::test]
    async fn test_shutdown_changes_state() {
        let mgr = manager();
        mgr.shutdown().await;
        let shutting_down = mgr.is_shutting_down().await;
        assert!(shutting_down);
    }

    /// A registered hook is executed and counted as a success.
    #[tokio::test]
    async fn test_hooks_run() {
        let mgr = manager();
        let flag = Arc::new(AtomicBool::new(false));
        let hook_flag = Arc::clone(&flag);
        mgr.on_shutdown("test-hook", ShutdownPhase::Cleanup, move || {
            hook_flag.store(true, Ordering::SeqCst);
            Ok(())
        })
        .await;

        let report = mgr.shutdown().await;

        assert!(flag.load(Ordering::SeqCst));
        assert_eq!(report.hooks_succeeded, 1);
        assert_eq!(report.hooks_failed, 0);
    }

    /// A hook returning `Err` is counted as a failure and keeps its message.
    #[tokio::test]
    async fn test_failed_hooks() {
        let mgr = manager();
        mgr.on_shutdown("fail-hook", ShutdownPhase::Cleanup, || {
            Err("intentional failure".to_string())
        })
        .await;

        let report = mgr.shutdown().await;

        assert_eq!(report.hooks_failed, 1);
        assert_eq!(report.hooks_succeeded, 0);
        assert!(report.hook_results[0].error.is_some());
    }

    /// Hooks run in phase order regardless of registration order.
    #[tokio::test]
    async fn test_phase_order() {
        let mgr = manager();
        let seen = Arc::new(std::sync::Mutex::new(Vec::new()));
        let registrations = [
            ("final", ShutdownPhase::Final),
            ("pre-drain", ShutdownPhase::PreDrain),
            ("cleanup", ShutdownPhase::Cleanup),
            ("drain", ShutdownPhase::Drain),
        ];
        for (label, phase) in registrations {
            let sink = Arc::clone(&seen);
            let owned_label = label.to_string();
            mgr.on_shutdown(label, phase, move || {
                sink.lock().unwrap().push(owned_label.clone());
                Ok(())
            })
            .await;
        }

        mgr.shutdown().await;

        let seen = seen.lock().unwrap();
        assert_eq!(&*seen, &["pre-drain", "drain", "cleanup", "final"]);
    }

    /// Every hook registered for the same phase is executed.
    #[tokio::test]
    async fn test_multiple_hooks_per_phase() {
        let mgr = manager();
        let calls = Arc::new(AtomicU32::new(0));
        for i in 0..3 {
            let hook_calls = Arc::clone(&calls);
            mgr.on_shutdown(format!("hook-{i}"), ShutdownPhase::Cleanup, move || {
                hook_calls.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .await;
        }

        let report = mgr.shutdown().await;

        assert_eq!(calls.load(Ordering::SeqCst), 3);
        assert_eq!(report.hooks_succeeded, 3);
    }

    /// A second `shutdown` call does not re-run hooks and reports nothing.
    #[tokio::test]
    async fn test_double_shutdown() {
        let mgr = manager();
        let calls = Arc::new(AtomicU32::new(0));
        let hook_calls = Arc::clone(&calls);
        mgr.on_shutdown("once", ShutdownPhase::Cleanup, move || {
            hook_calls.fetch_add(1, Ordering::SeqCst);
            Ok(())
        })
        .await;

        mgr.shutdown().await;
        let second = mgr.shutdown().await;

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(second.hooks_succeeded, 0);
    }

    /// With no hooks the shutdown finishes well inside the budget.
    #[tokio::test]
    async fn test_report_timing() {
        let report = manager().shutdown().await;
        assert!(report.completed_in_time);
        assert!(report.total_duration_ms < 1000);
    }

    /// `hook_count` tracks the number of registrations.
    #[tokio::test]
    async fn test_hook_count() {
        let mgr = manager();
        assert_eq!(mgr.hook_count().await, 0);

        mgr.on_shutdown("h1", ShutdownPhase::Cleanup, || Ok(()))
            .await;
        mgr.on_shutdown("h2", ShutdownPhase::Final, || Ok(())).await;

        assert_eq!(mgr.hook_count().await, 2);
    }

    /// A task already waiting on the shutdown signal is woken by `shutdown`.
    #[tokio::test]
    async fn test_shutdown_signal() {
        let mgr = manager();
        let signal = mgr.shutdown_signal();
        let woken = Arc::new(AtomicBool::new(false));
        let task_woken = Arc::clone(&woken);
        let waiter = tokio::spawn(async move {
            signal.notified().await;
            task_woken.store(true, Ordering::SeqCst);
        });

        // Give the spawned task time to start waiting before signalling.
        tokio::time::sleep(Duration::from_millis(10)).await;
        mgr.shutdown().await;
        waiter.await.unwrap();

        assert!(woken.load(Ordering::SeqCst));
    }

    /// The report survives a JSON round trip.
    #[tokio::test]
    async fn test_report_serializable() {
        let mgr = manager();
        mgr.on_shutdown("ser", ShutdownPhase::Cleanup, || Ok(()))
            .await;
        let report = mgr.shutdown().await;

        let encoded = serde_json::to_string(&report).unwrap();
        assert!(encoded.contains("\"hooks_succeeded\":1"));

        let decoded: ShutdownReport = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.hooks_succeeded, 1);
    }

    /// The derived ordering on phases matches execution order.
    #[test]
    fn test_phase_ordering() {
        let ascending = [
            ShutdownPhase::PreDrain,
            ShutdownPhase::Drain,
            ShutdownPhase::Cleanup,
            ShutdownPhase::Final,
        ];
        assert!(ascending.windows(2).all(|pair| pair[0] < pair[1]));
    }

    /// The per-hook duration reflects time spent inside the callback.
    #[tokio::test]
    async fn test_hook_duration() {
        let mgr = manager();
        mgr.on_shutdown("slow", ShutdownPhase::Cleanup, || {
            std::thread::sleep(Duration::from_millis(10));
            Ok(())
        })
        .await;

        let report = mgr.shutdown().await;

        assert!(report.hook_results[0].duration_ms >= 5);
    }

    /// Successes and failures are tallied independently across phases.
    #[tokio::test]
    async fn test_mixed_results() {
        let mgr = manager();
        mgr.on_shutdown("ok1", ShutdownPhase::PreDrain, || Ok(()))
            .await;
        mgr.on_shutdown("fail1", ShutdownPhase::Drain, || Err("oops".to_string()))
            .await;
        mgr.on_shutdown("ok2", ShutdownPhase::Cleanup, || Ok(()))
            .await;
        mgr.on_shutdown("fail2", ShutdownPhase::Final, || Err("boom".to_string()))
            .await;

        let report = mgr.shutdown().await;

        assert_eq!(report.hooks_succeeded, 2);
        assert_eq!(report.hooks_failed, 2);
    }

    /// Clones share both the hook list and the shutdown state.
    #[tokio::test]
    async fn test_clone_shares_state() {
        let original = manager();
        let copy = original.clone();
        original
            .on_shutdown("shared", ShutdownPhase::Cleanup, || Ok(()))
            .await;
        assert_eq!(copy.hook_count().await, 1);

        copy.shutdown().await;

        assert!(original.is_shutting_down().await);
    }

    /// A pre-drain hook's side effect is visible to a later cleanup hook.
    #[tokio::test]
    async fn test_predrain_before_cleanup() {
        let mgr = manager();
        let shared = Arc::new(std::sync::Mutex::new(0u32));
        let reader = Arc::clone(&shared);
        let writer = Arc::clone(&shared);
        mgr.on_shutdown("cleanup", ShutdownPhase::Cleanup, move || {
            let observed = *reader.lock().unwrap();
            assert_eq!(observed, 1);
            Ok(())
        })
        .await;
        mgr.on_shutdown("predrain", ShutdownPhase::PreDrain, move || {
            *writer.lock().unwrap() = 1;
            Ok(())
        })
        .await;

        let report = mgr.shutdown().await;

        assert_eq!(report.hooks_succeeded, 2);
    }
}