use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use tokio::task::JoinHandle;
/// Cloneable, read-only view of a [`GracefulShutdown`] coordinator's state.
///
/// Obtained from [`GracefulShutdown::signal`] or [`ShutdownGuard::signal`];
/// every clone observes the same shared flag and watch channel.
#[derive(Clone)]
pub struct ShutdownSignal {
    /// Flag shared with the coordinator; set to `true` once shutdown is requested.
    requested: Arc<AtomicBool>,
    /// Receiving end of the coordinator's watch channel, used to await the request.
    rx: watch::Receiver<bool>,
}
impl ShutdownSignal {
    /// Reports whether shutdown has been requested, without blocking.
    pub fn is_shutdown_requested(&self) -> bool {
        let flag = &self.requested;
        flag.load(Ordering::SeqCst)
    }

    /// Resolves once shutdown has been requested.
    ///
    /// Returns immediately when the shared flag is already set. Otherwise it
    /// watches the channel until it carries `true`. It also returns when the
    /// sending side has been dropped; in that case
    /// [`ShutdownSignal::is_shutdown_requested`] may still report `false`.
    pub async fn wait(&self) {
        if !self.is_shutdown_requested() {
            let mut watcher = self.rx.clone();
            // `borrow_and_update` marks the current value as seen, so the
            // following `changed()` only wakes for a newer value.
            while !*watcher.borrow_and_update() {
                if watcher.changed().await.is_err() {
                    // The sender is gone; no request can arrive anymore.
                    break;
                }
            }
        }
    }
}
impl std::fmt::Debug for ShutdownSignal {
    /// Shows only the current request flag; the channel receiver is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let requested = self.is_shutdown_requested();
        let mut out = f.debug_struct("ShutdownSignal");
        out.field("requested", &requested);
        out.finish()
    }
}
/// Coordinates a graceful shutdown across tasks.
///
/// Shutdown is requested either manually via [`GracefulShutdown::trigger`]
/// or by OS signals after [`GracefulShutdown::install`] has been called.
pub struct GracefulShutdown {
    /// Flag shared with every [`ShutdownSignal`]; set when shutdown is requested.
    requested: Arc<AtomicBool>,
    /// Sending end of the watch channel that wakes waiting signals.
    tx: watch::Sender<bool>,
    /// Receiver cloned into each [`ShutdownSignal`] handed out.
    rx: watch::Receiver<bool>,
    /// How long the installed task sleeps after a signal before logging a warning.
    grace_period: Duration,
}
impl GracefulShutdown {
pub fn new(grace_period: Duration) -> Self {
let (tx, rx) = watch::channel(false);
Self {
requested: Arc::new(AtomicBool::new(false)),
tx,
rx,
grace_period,
}
}
pub fn with_default_grace_period() -> Self {
Self::new(Duration::from_secs(5))
}
pub fn grace_period(&self) -> Duration {
self.grace_period
}
pub fn is_shutdown_requested(&self) -> bool {
self.requested.load(Ordering::SeqCst)
}
pub fn signal(&self) -> ShutdownSignal {
ShutdownSignal {
requested: Arc::clone(&self.requested),
rx: self.rx.clone(),
}
}
pub fn trigger(&self) {
self.requested.store(true, Ordering::SeqCst);
let _ = self.tx.send(true);
}
pub fn install(self) -> ShutdownGuard {
let requested = Arc::clone(&self.requested);
let tx = self.tx;
let grace_period = self.grace_period;
let signal = ShutdownSignal {
requested: Arc::clone(&self.requested),
rx: self.rx.clone(),
};
let handle = tokio::spawn(async move {
let sigint = tokio::signal::ctrl_c();
#[cfg(unix)]
{
let mut sigterm =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to register SIGTERM handler");
tokio::select! {
_ = sigint => {
tracing::info!("Received SIGINT (Ctrl-C), initiating graceful shutdown");
}
_ = sigterm.recv() => {
tracing::info!("Received SIGTERM, initiating graceful shutdown");
}
}
}
#[cfg(not(unix))]
{
let _ = sigint.await;
tracing::info!("Received Ctrl-C, initiating graceful shutdown");
}
requested.store(true, Ordering::SeqCst);
let _ = tx.send(true);
tracing::info!(
grace_period_secs = grace_period.as_secs_f64(),
"Shutdown signal sent; grace period started"
);
tokio::time::sleep(grace_period).await;
tracing::warn!("Grace period elapsed; force termination may follow");
});
ShutdownGuard { handle, signal }
}
}
impl std::fmt::Debug for GracefulShutdown {
    /// Shows the request flag and grace period; the channel ends are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let requested = self.is_shutdown_requested();
        let mut out = f.debug_struct("GracefulShutdown");
        out.field("requested", &requested);
        out.field("grace_period", &self.grace_period);
        out.finish()
    }
}
/// Handle to the OS-signal listener task spawned by [`GracefulShutdown::install`].
///
/// Dropping the guard detaches the task (it keeps running);
/// use [`ShutdownGuard::cancel`] to abort it.
pub struct ShutdownGuard {
    /// The spawned listener task.
    handle: JoinHandle<()>,
    /// A signal observing the same state the task updates.
    signal: ShutdownSignal,
}
impl ShutdownGuard {
pub fn signal(&self) -> ShutdownSignal {
self.signal.clone()
}
pub async fn wait(self) {
let _ = self.handle.await;
}
pub fn cancel(self) {
self.handle.abort();
}
}
impl std::fmt::Debug for ShutdownGuard {
    /// Shows the wrapped signal and whether the listener task has finished.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ShutdownGuard");
        out.field("signal", &self.signal);
        out.field("finished", &self.handle.is_finished());
        out.finish()
    }
}
// Unit tests for the shutdown coordinator.
//
// Tests that call `install()` spawn the OS-signal listener task, which may
// register process-wide Ctrl-C (and, on Unix, SIGTERM) handlers before it is
// aborted; each such test calls `cancel()` on the guard before returning.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_grace_period() {
        let shutdown = GracefulShutdown::with_default_grace_period();
        assert_eq!(shutdown.grace_period(), Duration::from_secs(5));
    }

    #[test]
    fn test_custom_grace_period() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(30));
        assert_eq!(shutdown.grace_period(), Duration::from_secs(30));
    }

    #[test]
    fn test_not_requested_initially() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(1));
        assert!(!shutdown.is_shutdown_requested());
        let signal = shutdown.signal();
        assert!(!signal.is_shutdown_requested());
    }

    #[test]
    fn test_manual_trigger() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(1));
        let signal = shutdown.signal();
        assert!(!signal.is_shutdown_requested());
        shutdown.trigger();
        assert!(signal.is_shutdown_requested());
        assert!(shutdown.is_shutdown_requested());
    }

    // Clones of a signal and fresh signals from the same coordinator must all
    // observe one shared flag.
    #[test]
    fn test_signal_clone_shares_state() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(1));
        let s1 = shutdown.signal();
        let s2 = s1.clone();
        let s3 = shutdown.signal();
        assert!(!s1.is_shutdown_requested());
        assert!(!s2.is_shutdown_requested());
        assert!(!s3.is_shutdown_requested());
        shutdown.trigger();
        assert!(s1.is_shutdown_requested());
        assert!(s2.is_shutdown_requested());
        assert!(s3.is_shutdown_requested());
    }

    #[tokio::test]
    async fn test_signal_wait_resolves_on_trigger() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(1));
        let signal = shutdown.signal();
        // Trigger from another task after a short delay so `wait()` most
        // likely has to block on the watch channel instead of the fast path.
        let trigger_handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            shutdown.trigger();
        });
        tokio::time::timeout(Duration::from_secs(2), signal.wait())
            .await
            .expect("signal.wait() should have resolved within timeout");
        trigger_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_signal_wait_returns_immediately_if_already_triggered() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(1));
        let signal = shutdown.signal();
        shutdown.trigger();
        tokio::time::timeout(Duration::from_millis(100), signal.wait())
            .await
            .expect("signal.wait() should return immediately when already triggered");
    }

    #[tokio::test]
    async fn test_guard_signal() {
        let shutdown = GracefulShutdown::new(Duration::from_millis(50));
        let signal = shutdown.signal();
        let guard = shutdown.install();
        let guard_signal = guard.signal();
        // No OS signal has been sent, so neither handle reports a request.
        assert!(!guard_signal.is_shutdown_requested());
        assert!(!signal.is_shutdown_requested());
        guard.cancel();
    }

    #[tokio::test]
    async fn test_guard_cancel() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(60));
        let guard = shutdown.install();
        // Aborts the listener task; the test must not wait on it.
        guard.cancel();
    }

    #[tokio::test]
    async fn test_multiple_signals_from_same_shutdown() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(1));
        let signals: Vec<_> = (0..10).map(|_| shutdown.signal()).collect();
        for s in &signals {
            assert!(!s.is_shutdown_requested());
        }
        shutdown.trigger();
        for s in &signals {
            assert!(s.is_shutdown_requested());
        }
    }

    #[tokio::test]
    async fn test_wait_with_dropped_sender() {
        // The coordinator, and with it the watch sender, is dropped at the end
        // of this block, so `wait()` can only return via the closed-channel path.
        let signal = {
            let shutdown = GracefulShutdown::new(Duration::from_secs(1));
            shutdown.signal()
        };
        tokio::time::timeout(Duration::from_millis(100), signal.wait())
            .await
            .expect("signal.wait() should resolve when sender is dropped");
    }

    #[test]
    fn test_debug_impls() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(5));
        let debug_str = format!("{:?}", shutdown);
        assert!(debug_str.contains("GracefulShutdown"));
        assert!(debug_str.contains("requested"));
        let signal = shutdown.signal();
        let debug_str = format!("{:?}", signal);
        assert!(debug_str.contains("ShutdownSignal"));
    }

    #[tokio::test]
    async fn test_guard_debug() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(1));
        let guard = shutdown.install();
        let debug_str = format!("{:?}", guard);
        assert!(debug_str.contains("ShutdownGuard"));
        guard.cancel();
    }

    // Many tasks calling `trigger()` at once must neither panic nor lose the request.
    #[tokio::test]
    async fn test_concurrent_trigger_is_safe() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(1));
        let signal = shutdown.signal();
        let shutdown = Arc::new(shutdown);
        let mut handles = Vec::new();
        for _ in 0..10 {
            let s = Arc::clone(&shutdown);
            handles.push(tokio::spawn(async move {
                s.trigger();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        assert!(signal.is_shutdown_requested());
    }

    #[tokio::test]
    async fn test_signal_wait_multiple_waiters() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(1));
        let mut handles = Vec::new();
        for _ in 0..5 {
            let signal = shutdown.signal();
            handles.push(tokio::spawn(async move {
                signal.wait().await;
            }));
        }
        // Give the waiters a chance to start blocking before triggering.
        tokio::time::sleep(Duration::from_millis(50)).await;
        shutdown.trigger();
        for h in handles {
            tokio::time::timeout(Duration::from_secs(2), h)
                .await
                .expect("waiter should complete")
                .expect("waiter task should not panic");
        }
    }
}