use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::signal;
use tokio::sync::{broadcast, RwLock};
use tokio::time::timeout;
#[derive(Debug)]
pub struct ShutdownSignal {
tx: broadcast::Sender<()>,
timeout_secs: u64,
}
impl ShutdownSignal {
pub fn new(timeout_secs: u64) -> Self {
let (tx, _rx) = broadcast::channel(1);
ShutdownSignal { tx, timeout_secs }
}
pub fn sender(&self) -> broadcast::Sender<()> {
self.tx.clone()
}
pub fn receiver(&self) -> broadcast::Receiver<()> {
self.tx.subscribe()
}
pub fn trigger(&self) {
let _ = self.tx.send(());
}
pub async fn wait(&self) {
let mut rx = self.tx.subscribe();
let _ = rx.recv().await;
}
}
pub struct GracefulShutdown {
signal: Arc<ShutdownSignal>,
running: Arc<RwLock<bool>>,
active_tasks: Arc<RwLock<usize>>,
}
impl GracefulShutdown {
pub fn new(timeout_secs: u64) -> Self {
GracefulShutdown {
signal: Arc::new(ShutdownSignal::new(timeout_secs)),
running: Arc::new(RwLock::new(true)),
active_tasks: Arc::new(RwLock::new(0)),
}
}
pub async fn register_task(&self) {
let mut tasks = self.active_tasks.write().await;
*tasks += 1;
}
pub async fn complete_task(&self) {
let mut tasks = self.active_tasks.write().await;
if *tasks > 0 {
*tasks -= 1;
}
}
pub async fn active_tasks(&self) -> usize {
*self.active_tasks.read().await
}
pub async fn is_running(&self) -> bool {
*self.running.read().await
}
pub fn install_signal_handlers(&self) {
let signal = self.signal.clone();
let running = self.running.clone();
tokio::spawn(async move {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("Failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("Failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
println!("Shutdown signal received, stopping gracefully...");
{
let mut r = running.write().await;
*r = false;
}
signal.trigger();
});
}
pub async fn wait_for_shutdown(&self) {
let mut rx = self.signal.receiver();
let timeout_duration = Duration::from_secs(self.signal.timeout_secs);
match timeout(timeout_duration, rx.recv()).await {
Ok(_) => {
println!("Shutdown completed gracefully");
}
Err(_) => {
println!("Shutdown timeout reached, forcing exit");
}
}
}
pub async fn wait_for_tasks(&self) {
let timeout_duration = Duration::from_secs(self.signal.timeout_secs);
let start = std::time::Instant::now();
loop {
let active = self.active_tasks().await;
if active == 0 {
break;
}
if start.elapsed() > timeout_duration {
println!(
"{} tasks still running after timeout, forcing shutdown",
active
);
break;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shutdown_signal_creation() {
        let signal = ShutdownSignal::new(30);
        assert_eq!(signal.timeout_secs, 30);
    }

    #[tokio::test]
    async fn test_shutdown_signal_trigger() {
        let signal = ShutdownSignal::new(5);
        let mut rx = signal.receiver();
        signal.trigger();
        let result = rx.recv().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_graceful_shutdown_creation() {
        let shutdown = GracefulShutdown::new(30);
        assert!(shutdown.is_running().await);
        assert_eq!(shutdown.active_tasks().await, 0);
    }

    #[tokio::test]
    async fn test_task_tracking() {
        let shutdown = GracefulShutdown::new(30);
        shutdown.register_task().await;
        assert_eq!(shutdown.active_tasks().await, 1);
        shutdown.register_task().await;
        assert_eq!(shutdown.active_tasks().await, 2);
        shutdown.complete_task().await;
        assert_eq!(shutdown.active_tasks().await, 1);
    }

    #[tokio::test]
    async fn test_complete_task_never_underflows() {
        let shutdown = GracefulShutdown::new(30);
        // Completing with nothing registered must leave the counter at zero.
        shutdown.complete_task().await;
        assert_eq!(shutdown.active_tasks().await, 0);
        shutdown.register_task().await;
        shutdown.complete_task().await;
        shutdown.complete_task().await;
        assert_eq!(shutdown.active_tasks().await, 0);
    }

    #[tokio::test]
    async fn test_wait_for_tasks_returns_when_idle() {
        // With no registered tasks this must return without waiting out the
        // 30 s timeout.
        let shutdown = GracefulShutdown::new(30);
        shutdown.wait_for_tasks().await;
        assert_eq!(shutdown.active_tasks().await, 0);
    }

    #[tokio::test]
    async fn test_wait_for_tasks_gives_up_after_timeout() {
        let shutdown = GracefulShutdown::new(0);
        shutdown.register_task().await;
        shutdown.wait_for_tasks().await;
        // Hitting the timeout ends the wait but does not alter the task count.
        assert_eq!(shutdown.active_tasks().await, 1);
    }

    #[tokio::test]
    async fn test_wait_for_tasks_observes_concurrent_completion() {
        let shutdown = Arc::new(GracefulShutdown::new(5));
        shutdown.register_task().await;
        let worker = Arc::clone(&shutdown);
        let handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            worker.complete_task().await;
        });
        shutdown.wait_for_tasks().await;
        assert_eq!(shutdown.active_tasks().await, 0);
        handle.await.expect("worker task panicked");
    }
}