use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;
use tokio::sync::broadcast;
use tokio::time::timeout;
/// Shared handle for coordinating a graceful shutdown.
///
/// A shutdown is announced with [`ShutdownCoordinator::shutdown`] to every receiver obtained via
/// [`ShutdownCoordinator::subscribe`]. Completion is reported with
/// [`ShutdownCoordinator::notify_shutdown_complete`], which releases
/// [`ShutdownCoordinator::wait_for_shutdown`]. That wait gives up after `timeout_duration`.
///
/// Clones share the same broadcast channel and the same completion notifier.
#[derive(Clone, Debug)]
pub struct ShutdownCoordinator {
    /// Sends the unit "shut down" message to all current subscribers.
    shutdown_tx: broadcast::Sender<()>,
    /// Signalled when shutdown work has finished; awaited by `wait_for_shutdown`.
    shutdown_complete: Arc<Notify>,
    /// Maximum time `wait_for_shutdown` waits for the completion signal.
    timeout_duration: Duration,
}
impl ShutdownCoordinator {
    /// Builds a coordinator whose [`wait_for_shutdown`](Self::wait_for_shutdown) gives up after
    /// `timeout_duration`.
    ///
    /// The broadcast channel buffers a single message. The receiver created together with the
    /// channel is discarded, so only receivers obtained through [`subscribe`](Self::subscribe)
    /// observe the shutdown signal.
    pub fn new(timeout_duration: Duration) -> Self {
        let (shutdown_tx, _) = broadcast::channel(1);
        Self {
            shutdown_tx,
            shutdown_complete: Arc::new(Notify::new()),
            timeout_duration,
        }
    }

    /// Returns a fresh receiver for the shutdown broadcast.
    ///
    /// A receiver only sees messages sent after it was created.
    pub fn subscribe(&self) -> broadcast::Receiver<()> {
        broadcast::Sender::subscribe(&self.shutdown_tx)
    }

    /// Broadcasts the shutdown signal to every current subscriber.
    pub fn shutdown(&self) {
        // `send` fails only when no receiver exists. In that case nobody can be told, so the
        // error is deliberately discarded.
        let _send_outcome = self.shutdown_tx.send(());
    }

    /// Reports that shutdown work has finished.
    ///
    /// Uses `Notify::notify_one`. If nothing is waiting yet, a permit is stored, so a later
    /// [`wait_for_shutdown`](Self::wait_for_shutdown) still completes immediately.
    pub fn notify_shutdown_complete(&self) {
        Notify::notify_one(&self.shutdown_complete);
    }

    /// Waits for [`notify_shutdown_complete`](Self::notify_shutdown_complete), but no longer
    /// than the configured timeout.
    ///
    /// Prints a success line to stdout, or a timeout line to stderr. Returns `()` in both cases,
    /// so callers cannot distinguish the two outcomes from the return value.
    pub async fn wait_for_shutdown(&self) {
        let limit = self.timeout_duration;
        let outcome = timeout(limit, self.shutdown_complete.notified()).await;
        if outcome.is_ok() {
            println!("Graceful shutdown completed");
        } else {
            eprintln!("Shutdown timeout after {:?}, forcing termination", limit);
        }
    }

    /// The timeout used by [`wait_for_shutdown`](Self::wait_for_shutdown).
    pub fn timeout_duration(&self) -> Duration {
        self.timeout_duration
    }
}
pub async fn shutdown_signal() {
use tokio::signal;
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("Failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("Failed to install SIGTERM handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {
println!("Received Ctrl+C signal");
}
_ = terminate => {
println!("Received SIGTERM signal");
}
}
}
/// Runs `future` until it finishes or a shutdown notice arrives, whichever happens first.
///
/// Returns `Some(output)` when `future` wins, and `None` when `shutdown_rx.recv()` resolves
/// first. `recv()` resolves on a shutdown message, but also on any receive error, such as every
/// sender having been dropped or the receiver having lagged. A losing `future` is dropped
/// without being polled again. If both sides are ready at once, `tokio::select!` picks the
/// branch to poll first at random.
pub async fn with_shutdown<F>(
    future: F,
    mut shutdown_rx: broadcast::Receiver<()>,
) -> Option<F::Output>
where
    F: Future,
{
    let stop_requested = shutdown_rx.recv();
    tokio::select! {
        _ = stop_requested => None,
        output = future => Some(output),
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the shutdown coordinator and the `with_shutdown` helper.
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn test_shutdown_coordinator_creation() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(30));
        assert_eq!(coordinator.timeout_duration(), Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_shutdown_signal_propagation() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(1));
        let mut rx = coordinator.subscribe();
        coordinator.shutdown();
        let result = rx.recv().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(1));
        let mut rx1 = coordinator.subscribe();
        let mut rx2 = coordinator.subscribe();
        coordinator.shutdown();
        assert!(rx1.recv().await.is_ok());
        assert!(rx2.recv().await.is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_notification() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(1));
        let coordinator_clone = coordinator.clone();
        tokio::spawn(async move {
            coordinator_clone.notify_shutdown_complete();
        });
        let start = std::time::Instant::now();
        coordinator.wait_for_shutdown().await;
        // `wait_for_shutdown` returns `()` even when it times out. Finishing before the timeout
        // is the only evidence that the notification, not the timer, released it.
        assert!(start.elapsed() < coordinator.timeout_duration());
    }

    #[tokio::test]
    async fn test_notify_before_wait_is_not_lost() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(1));
        // `notify_one` stores a permit when nobody is waiting yet.
        coordinator.notify_shutdown_complete();
        let start = std::time::Instant::now();
        coordinator.wait_for_shutdown().await;
        assert!(start.elapsed() < coordinator.timeout_duration());
    }

    #[tokio::test]
    async fn test_shutdown_timeout() {
        let coordinator = ShutdownCoordinator::new(Duration::from_millis(100));
        let start = std::time::Instant::now();
        coordinator.wait_for_shutdown().await;
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(100));
        // Generous wall-clock bound: a tight limit (e.g. 200ms) flakes on loaded CI machines.
        // This only guards against the wait hanging well past its timeout.
        assert!(elapsed < Duration::from_secs(5));
    }

    #[tokio::test]
    async fn test_with_shutdown_completes_normally() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(1));
        let shutdown_rx = coordinator.subscribe();
        let work = async { 42 };
        let result = with_shutdown(work, shutdown_rx).await;
        assert_eq!(result, Some(42));
    }

    #[tokio::test]
    async fn test_with_shutdown_interrupted() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(1));
        let shutdown_rx = coordinator.subscribe();
        let work = async {
            tokio::time::sleep(Duration::from_millis(100)).await;
            42
        };
        coordinator.shutdown();
        let result = with_shutdown(work, shutdown_rx).await;
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_with_shutdown_returns_none_when_senders_dropped() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(1));
        let shutdown_rx = coordinator.subscribe();
        // Dropping the only sender closes the channel, so `recv` resolves with `Closed`.
        drop(coordinator);
        let result = with_shutdown(std::future::pending::<i32>(), shutdown_rx).await;
        assert_eq!(result, None);
    }
}