use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::broadcast;
use tracing::{info, warn};
/// Cloneable handle for coordinating a graceful shutdown.
///
/// All clones share the same broadcast channel and "shutdown initiated" flag, so
/// calling `shutdown` on any clone is observed through every other clone.
#[derive(Clone, Debug)]
pub struct ShutdownCoordinator {
    /// Broadcasts `()` to subscribers when shutdown is initiated.
    sender: broadcast::Sender<()>,
    /// Set once shutdown has been initiated; shared by all clones.
    shutdown_initiated: Arc<AtomicBool>,
}
impl ShutdownCoordinator {
pub fn new() -> Self {
let (sender, _) = broadcast::channel(16);
Self {
sender,
shutdown_initiated: Arc::new(AtomicBool::new(false)),
}
}
pub fn subscribe(&self) -> broadcast::Receiver<()> {
self.sender.subscribe()
}
pub fn shutdown(&self) {
if self.shutdown_initiated.swap(true, Ordering::SeqCst) {
return;
}
info!("Initiating graceful shutdown");
if let Err(e) = self.sender.send(()) {
warn!("Failed to broadcast shutdown signal: {}", e);
}
}
pub fn is_shutting_down(&self) -> bool {
self.shutdown_initiated.load(Ordering::SeqCst)
}
}
impl Default for ShutdownCoordinator {
fn default() -> Self {
Self::new()
}
}
/// Listens for SIGTERM/SIGINT (Unix) or Ctrl+C (other targets) and calls
/// `coordinator.shutdown()` on the first signal received.
///
/// Uses `tokio::spawn`, so it must run inside a Tokio runtime. On Unix the
/// listeners are registered before this future completes, so a signal delivered
/// any time afterwards is routed to graceful shutdown instead of the default
/// (terminating) disposition. If registration fails, a warning is logged and no
/// listener is spawned.
///
/// NOTE: per the Tokio docs, installing a Unix signal handler replaces the default
/// behavior for the rest of the process; once the first signal has been handled,
/// later SIGTERM/SIGINT deliveries are no longer acted on by this function.
pub async fn setup_signal_handlers(coordinator: ShutdownCoordinator) {
    // Registering here rather than inside the spawned task closes the window
    // between this function returning and the task first being polled.
    let signals = match ShutdownSignals::register() {
        Ok(signals) => signals,
        Err(e) => {
            warn!("Error setting up signal handlers: {}", e);
            return;
        }
    };
    tokio::spawn(async move {
        if let Err(e) = signals.wait().await {
            warn!("Error setting up signal handlers: {}", e);
            return;
        }
        info!("Received shutdown signal");
        coordinator.shutdown();
    });
}

/// Signal listeners that request a graceful shutdown (Unix: SIGTERM and SIGINT).
#[cfg(unix)]
struct ShutdownSignals {
    sigterm: tokio::signal::unix::Signal,
    sigint: tokio::signal::unix::Signal,
}

#[cfg(unix)]
impl ShutdownSignals {
    /// Installs the SIGTERM and SIGINT handlers immediately.
    fn register() -> Result<Self, std::io::Error> {
        use tokio::signal::unix::{SignalKind, signal};
        Ok(Self {
            sigterm: signal(SignalKind::terminate())?,
            sigint: signal(SignalKind::interrupt())?,
        })
    }

    /// Completes when either SIGTERM or SIGINT is received.
    async fn wait(mut self) -> Result<(), std::io::Error> {
        tokio::select! {
            _ = self.sigterm.recv() => {
                info!("Received SIGTERM");
            }
            _ = self.sigint.recv() => {
                info!("Received SIGINT");
            }
        }
        Ok(())
    }
}

/// Signal listener placeholder for non-Unix targets (Ctrl+C only).
#[cfg(not(unix))]
struct ShutdownSignals;

#[cfg(not(unix))]
impl ShutdownSignals {
    /// Nothing to register eagerly: `tokio::signal::ctrl_c` installs its listener
    /// when its future is first polled.
    fn register() -> Result<Self, std::io::Error> {
        Ok(Self)
    }

    /// Completes when Ctrl+C is received; errors if the listener cannot be installed.
    async fn wait(self) -> Result<(), std::io::Error> {
        tokio::signal::ctrl_c().await?;
        info!("Received Ctrl+C");
        Ok(())
    }
}
pub struct ShutdownGuard {
coordinator: ShutdownCoordinator,
disarmed: Arc<AtomicBool>,
}
impl ShutdownGuard {
pub fn new(coordinator: ShutdownCoordinator) -> Self {
Self {
coordinator,
disarmed: Arc::new(AtomicBool::new(false)),
}
}
pub fn disarm(&self) {
self.disarmed.store(true, Ordering::SeqCst);
}
}
impl Drop for ShutdownGuard {
    // A disarmed guard drops silently; an armed one logs and requests shutdown.
    fn drop(&mut self) {
        let was_disarmed = self.disarmed.load(Ordering::SeqCst);
        if was_disarmed {
            return;
        }
        warn!("ShutdownGuard dropped without disarming - triggering shutdown");
        self.coordinator.shutdown();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::sync::broadcast::error::TryRecvError;
    use tokio::time::timeout;

    /// Asserts that `receiver` yields the shutdown message within 100ms.
    ///
    /// Checks the received value, not just that the timeout did not elapse:
    /// `timeout(..).await.is_ok()` would also accept a `RecvError` delivered in time.
    async fn assert_receives_signal(receiver: &mut tokio::sync::broadcast::Receiver<()>) {
        let result = timeout(Duration::from_millis(100), receiver.recv()).await;
        assert!(
            matches!(result, Ok(Ok(()))),
            "expected shutdown signal, got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_shutdown_coordinator() {
        let coordinator = ShutdownCoordinator::new();
        let mut receiver = coordinator.subscribe();
        assert!(!coordinator.is_shutting_down());
        coordinator.shutdown();
        assert!(coordinator.is_shutting_down());
        assert_receives_signal(&mut receiver).await;
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let coordinator = ShutdownCoordinator::new();
        let mut rx1 = coordinator.subscribe();
        let mut rx2 = coordinator.subscribe();
        let mut rx3 = coordinator.subscribe();
        coordinator.shutdown();
        assert_receives_signal(&mut rx1).await;
        assert_receives_signal(&mut rx2).await;
        assert_receives_signal(&mut rx3).await;
    }

    #[tokio::test]
    async fn test_shutdown_idempotent() {
        let coordinator = ShutdownCoordinator::new();
        let mut receiver = coordinator.subscribe();
        coordinator.shutdown();
        coordinator.shutdown();
        assert!(coordinator.is_shutting_down());
        assert_receives_signal(&mut receiver).await;
        // The second `shutdown` must not broadcast a second message.
        assert!(matches!(receiver.try_recv(), Err(TryRecvError::Empty)));
    }

    #[test]
    fn test_late_subscriber_sees_flag_only() {
        let coordinator = ShutdownCoordinator::new();
        coordinator.shutdown();
        let mut late = coordinator.subscribe();
        // Broadcast receivers only see messages sent after they subscribed, so a
        // late subscriber must rely on the flag.
        assert!(matches!(late.try_recv(), Err(TryRecvError::Empty)));
        assert!(coordinator.is_shutting_down());
    }

    #[test]
    fn test_shutdown_guard_disarm() {
        let coordinator = ShutdownCoordinator::new();
        let guard = ShutdownGuard::new(coordinator.clone());
        guard.disarm();
        drop(guard);
        assert!(!coordinator.is_shutting_down());
    }

    #[test]
    fn test_shutdown_guard_trigger() {
        let coordinator = ShutdownCoordinator::new();
        let guard = ShutdownGuard::new(coordinator.clone());
        drop(guard);
        assert!(coordinator.is_shutting_down());
    }
}