use futures::future::{select, Either};
use futures::FutureExt;
use std::future::Future;
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::sync::{broadcast, oneshot};
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
/// Type-erased shutdown callback as stored inside an `ExitHandler`.
///
/// The boxed closure can be called exactly once (`FnOnce`) and returns a
/// boxed, `Unpin + Send` future that resolves to `()` when the cleanup work
/// it represents has finished.
type ShutdownAction =
    Box<dyn FnOnce() -> Box<dyn Future<Output = ()> + Unpin + Send> + Send + Sync>;
/// Relative position of a shutdown handler in the shutdown sequence.
///
/// Handlers are sorted by this value in ascending order before they run, so
/// a smaller discriminant means "runs earlier".
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExitPriority {
    /// Runs ahead of every other priority.
    First = 0,
    /// Middle of the sequence.
    Normal = 50,
    /// Runs after every other priority.
    Last = 100,
}
/// Outcome of running a single shutdown handler.
///
/// The derives make the status loggable (`Debug`), cheap to pass around
/// (`Clone`/`Copy`) and comparable (`PartialEq`/`Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownActionStatus {
    /// The handler's future resolved before the timeout elapsed.
    Completed,
    /// The handler's future was still pending when the timeout elapsed.
    TimedOut,
    /// A force-shutdown signal arrived before the handler finished.
    Forced,
    /// The handler failed.
    // NOTE(review): no code path visible in this file produces this variant.
    Failed,
}
/// One registered shutdown handler.
struct ExitHandler {
    /// Human-readable label used in log messages.
    name: String,
    /// Callback invoked once to produce the handler's cleanup future.
    action: ShutdownAction,
    /// Sort key that decides where in the sequence this handler runs.
    priority: ExitPriority,
}
/// Tally of handler outcomes for one shutdown sequence.
///
/// `Default` yields the all-zero tally, and `PartialEq`/`Eq` let callers and
/// tests compare results directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ShutdownResult {
    /// Handlers that finished within the timeout.
    pub completed: usize,
    /// Handlers that were still running when the timeout elapsed.
    pub timed_out: usize,
    /// Handlers interrupted by a force-shutdown signal.
    pub forced: usize,
    /// Handlers reported as failed.
    pub failed: usize,
}
/// Coordinates an ordered shutdown: collects named handlers and later runs
/// them one after another, each bounded by a timeout.
pub struct ShutdownManager {
    /// Handlers registered so far; emptied when the shutdown sequence starts.
    handlers: Arc<Mutex<Vec<ExitHandler>>>,
    /// Maximum time a single handler may run before it counts as timed out.
    timeout_duration: Duration,
    /// Broadcast sender for the force signal; a message received while a
    /// handler is running makes that handler finish with the `Forced` status.
    pub force_shutdown_tx: broadcast::Sender<()>,
    /// Label included in this manager's log messages.
    name: String,
    /// Set to `true` by the first shutdown request; later requests are ignored.
    shutdown_started: Arc<AtomicBool>,
}
impl Default for ShutdownManager {
fn default() -> Self {
Self::new("app")
}
}
impl ShutdownManager {
pub fn new(name: &str) -> Self {
let (force_shutdown_tx, _) = broadcast::channel(1);
Self {
handlers: Arc::new(Mutex::new(Vec::new())),
timeout_duration: Duration::from_secs(5),
force_shutdown_tx,
name: name.to_string(),
shutdown_started: Arc::new(AtomicBool::new(false)),
}
}
/// Builder-style setter: consumes the manager and returns it with the
/// per-handler timeout replaced by `duration`.
pub fn with_timeout(self, duration: Duration) -> Self {
    let mut configured = self;
    configured.timeout_duration = duration;
    configured
}
pub fn register<F, Fut>(&self, name: &str, priority: ExitPriority, action: F)
where
F: FnOnce() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let boxed_action = Box::new(move || {
let fut = action();
Box::new(Box::pin(fut)) as Box<dyn Future<Output = ()> + Unpin + Send>
});
let mut handlers = self.handlers.lock().unwrap();
handlers.push(ExitHandler {
name: name.to_string(),
action: boxed_action,
priority,
});
debug!(
"[{}] Registered shutdown handler: {} (priority: {:?})",
self.name, name, priority
);
}
async fn execute_action(
&self,
handler: ExitHandler,
force_rx: &mut broadcast::Receiver<()>,
) -> ShutdownActionStatus {
info!(
"[{}] Executing shutdown handler: {} (priority: {:?})",
self.name, handler.name, handler.priority
);
let future = (handler.action)();
let timeout_fut = Box::pin(timeout(self.timeout_duration, future));
let force_fut = Box::pin(force_rx.recv().fuse());
match select(force_fut, timeout_fut).await {
Either::Left((_, _)) => {
warn!(
"[{}] Shutdown handler forced to terminate: {}",
self.name, handler.name
);
ShutdownActionStatus::Forced
}
Either::Right((timeout_result, _)) => match timeout_result {
Ok(_) => {
info!(
"[{}] Shutdown handler completed: {}",
self.name, handler.name
);
ShutdownActionStatus::Completed
}
Err(_) => {
error!(
"[{}] Shutdown handler timed out after {:?}: {}",
self.name, self.timeout_duration, handler.name
);
ShutdownActionStatus::TimedOut
}
},
}
}
/// Runs every registered handler once, in ascending priority order, and
/// returns a tally of the outcomes.
///
/// Only the first shutdown request does any work: if the
/// `shutdown_started` flag is already set, a warning is logged and an
/// all-zero result is returned.
///
/// Handlers run sequentially, each through `execute_action`. Once two
/// handlers have timed out, every handler that has not started yet is
/// forced rather than given its own full timeout.
///
/// # Panics
/// Panics if the handler list's mutex is poisoned.
pub async fn execute_shutdown(&self) -> ShutdownResult {
    /// Number of timed-out handlers after which the rest are forced.
    const TIMEOUTS_BEFORE_FORCE: usize = 2;

    if self.shutdown_started.swap(true, Ordering::SeqCst) {
        warn!("Shutdown already in progress, ignoring duplicate request");
        return ShutdownResult { completed: 0, timed_out: 0, forced: 0, failed: 0 };
    }
    info!("Starting shutdown sequence for {}", self.name);
    let handlers = {
        let mut handlers = self.handlers.lock().unwrap();
        // Stable sort: handlers with equal priority keep registration order.
        handlers.sort_by_key(|h| h.priority);
        // Take ownership so the lock is not held while handlers run.
        std::mem::take(&mut *handlers)
    };
    let mut result = ShutdownResult {
        completed: 0,
        timed_out: 0,
        forced: 0,
        failed: 0,
    };
    let mut force_rx = self.force_shutdown_tx.subscribe();
    let mut force_remaining = false;
    for handler in handlers {
        if force_remaining {
            // Each broadcast message is consumed by exactly one `recv`, so
            // a single send would force only one handler. Re-arm the signal
            // before every remaining handler so all of them are forced.
            let _ = self.force_shutdown_tx.send(());
        }
        let status = self.execute_action(handler, &mut force_rx).await;
        match status {
            ShutdownActionStatus::Completed => result.completed += 1,
            ShutdownActionStatus::TimedOut => {
                result.timed_out += 1;
                if !force_remaining && result.timed_out >= TIMEOUTS_BEFORE_FORCE {
                    warn!(
                        "[{}] Multiple actions timed out, forcing remaining shutdowns",
                        self.name
                    );
                    force_remaining = true;
                }
            }
            ShutdownActionStatus::Forced => result.forced += 1,
            ShutdownActionStatus::Failed => result.failed += 1,
        }
    }
    info!(
        "[{}] Graceful shutdown sequence completed: {:?}",
        self.name, result
    );
    result
}
/// Logs that a forced shutdown was requested and runs the shutdown
/// sequence via `execute_shutdown`, returning its tally.
///
/// If a shutdown has already started, a warning is logged and an all-zero
/// result is returned. This method does not itself send on
/// `force_shutdown_tx`.
pub async fn force_shutdown(&self) -> ShutdownResult {
    // Only *read* the flag here. `execute_shutdown` claims it atomically
    // with `swap`; setting it in this method as well would make that call
    // see "already started" and return without running any handler.
    if self.shutdown_started.load(Ordering::SeqCst) {
        warn!("Shutdown already in progress, ignoring duplicate request");
        return ShutdownResult { completed: 0, timed_out: 0, forced: 0, failed: 0 };
    }
    warn!("Force shutdown triggered for {}", self.name);
    self.execute_shutdown().await
}
/// Suspends until `crate::shutdown_handler::wait_for_signal` returns.
pub async fn wait_for_shutdown(&self) {
    let signal_arrived = crate::shutdown_handler::wait_for_signal();
    signal_arrived.await;
}
}
/// Waits until the process receives SIGINT (Ctrl+C), SIGTERM or SIGQUIT,
/// logs which one arrived, and returns.
///
/// # Panics
/// Panics if the SIGTERM or SIGQUIT listener cannot be set up, or if
/// listening for Ctrl+C fails.
#[cfg(unix)]
pub async fn wait_for_signal() {
    use tokio::signal::unix::{signal, SignalKind};
    let mut sigterm = signal(SignalKind::terminate()).expect("Failed to setup SIGTERM channel");
    let mut sigquit = signal(SignalKind::quit()).expect("Failed to setup SIGQUIT channel");
    tokio::select! {
        ctrl_c = tokio::signal::ctrl_c() => {
            // `ctrl_c()` resolves to an `io::Result`; an `Err` means the
            // listener failed, not that a signal arrived, so it must not be
            // logged as a received SIGINT.
            ctrl_c.expect("Failed to listen for Ctrl+C");
            info!("Received SIGINT (Ctrl+C)");
        },
        _ = sigterm.recv() => {
            info!("Received SIGTERM");
        },
        _ = sigquit.recv() => {
            info!("Received SIGQUIT");
        }
    }
}
/// Waits until the process receives Ctrl+C, logs it, and returns.
///
/// # Panics
/// Panics if listening for Ctrl+C fails.
#[cfg(not(unix))]
pub async fn wait_for_signal() {
    let ctrl_c_outcome = tokio::signal::ctrl_c().await;
    ctrl_c_outcome.expect("Failed to listen for Ctrl+C");
    info!("Received SIGINT (Ctrl+C)");
}
// Unit-test module, compiled only for test builds. It currently defines no
// test cases; the glob import brings the parent module's items into scope
// for tests added later.
#[cfg(test)]
mod tests {
use super::*;
}