use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;
use tracing::{debug, info};
/// Cloneable handle to a shared shutdown flag.
///
/// Every clone points at the same `ShutdownInner` through an `Arc`. Triggering
/// shutdown through any clone is therefore observed by all other clones.
#[derive(Clone)]
pub struct ShutdownSignal {
    // Shared state; cloning the handle only bumps the `Arc` reference count.
    inner: Arc<ShutdownInner>,
}
/// State shared by all clones of a [`ShutdownSignal`].
struct ShutdownInner {
    /// Starts `false`; set to `true` by the first `ShutdownSignal::shutdown` call.
    triggered: AtomicBool,
    /// Wakes tasks waiting in `ShutdownSignal::cancelled` when shutdown fires.
    notify: Notify,
}
impl ShutdownSignal {
    /// Creates a new signal in the "not shut down" state.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(ShutdownInner {
                triggered: AtomicBool::new(false),
                notify: Notify::new(),
            }),
        }
    }

    /// Triggers shutdown and wakes every task currently waiting in
    /// [`cancelled`](Self::cancelled) or [`wait_timeout`](Self::wait_timeout).
    ///
    /// Idempotent: only the first call logs and notifies; later calls are no-ops.
    pub fn shutdown(&self) {
        // `swap` ensures exactly one caller observes the false -> true transition.
        if !self.inner.triggered.swap(true, Ordering::SeqCst) {
            info!("Shutdown signal triggered");
            self.inner.notify.notify_waiters();
        }
    }

    /// Returns `true` once [`shutdown`](Self::shutdown) has been called on any clone.
    pub fn is_shutdown(&self) -> bool {
        self.inner.triggered.load(Ordering::SeqCst)
    }

    /// Completes once shutdown has been triggered, immediately if it already has been.
    pub async fn cancelled(&self) {
        // Register interest *before* checking the flag. `notify_waiters()` does
        // not store a permit. With a check-then-`notified()` order, a
        // `shutdown()` running between the two steps would be missed, and this
        // future would wait forever. Tokio guarantees that a `Notified` future
        // receives `notify_waiters()` wakeups from the moment it is created,
        // even before its first poll.
        let notified = self.inner.notify.notified();
        if self.is_shutdown() {
            return;
        }
        notified.await;
    }

    /// Waits at most `timeout` for shutdown to be triggered.
    ///
    /// Returns `true` if shutdown was triggered, including before this call,
    /// and `false` if the timeout elapsed first.
    pub async fn wait_timeout(&self, timeout: Duration) -> bool {
        // `tokio::time::timeout` polls the inner future before checking the
        // timer, so an already-triggered signal always reports `true`. An
        // unbiased `select!` chooses randomly between branches that are both
        // ready, and so could report `false` even though shutdown happened.
        tokio::time::timeout(timeout, self.cancelled()).await.is_ok()
    }
}
impl Default for ShutdownSignal {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for ShutdownSignal {
    /// Renders as `ShutdownSignal { triggered: <bool> }`, using the current flag value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let triggered = self.is_shutdown();
        let mut out = f.debug_struct("ShutdownSignal");
        out.field("triggered", &triggered);
        out.finish()
    }
}
/// Triggers the wrapped [`ShutdownSignal`] when dropped, unless it has been disarmed.
///
/// Dropping happens when the owning scope exits, including during panic
/// unwinding. Call [`ShutdownGuard::disarm`] to release the guard without
/// triggering shutdown.
#[must_use = "the guard triggers shutdown as soon as it is dropped"]
pub struct ShutdownGuard {
    signal: ShutdownSignal,
    /// Cleared by `disarm` so that `Drop` skips triggering shutdown.
    armed: bool,
}

impl ShutdownGuard {
    /// Creates an armed guard for `signal`.
    pub fn new(signal: ShutdownSignal) -> Self {
        Self {
            signal,
            armed: true,
        }
    }

    /// Consumes the guard without triggering shutdown.
    pub fn disarm(mut self) {
        // `self` is dropped normally at the end of this function. `Drop` sees
        // `armed == false` and does nothing, and the held `ShutdownSignal` (an
        // `Arc` clone) is released. The previous `mem::forget(self)` leaked
        // that reference, so the shared state could never be freed.
        self.armed = false;
    }
}

impl Drop for ShutdownGuard {
    fn drop(&mut self) {
        if self.armed {
            debug!("ShutdownGuard dropped, triggering shutdown");
            self.signal.shutdown();
        }
    }
}
pub fn install_signal_handlers(signal: ShutdownSignal) {
#[cfg(unix)]
{
use tokio::signal::unix::{signal as unix_signal, SignalKind};
let signal_clone = signal.clone();
tokio::spawn(async move {
let mut sigint =
unix_signal(SignalKind::interrupt()).expect("Failed to install SIGINT handler");
let mut sigterm =
unix_signal(SignalKind::terminate()).expect("Failed to install SIGTERM handler");
let mut sigquit =
unix_signal(SignalKind::quit()).expect("Failed to install SIGQUIT handler");
tokio::select! {
_ = sigint.recv() => info!("Received SIGINT"),
_ = sigterm.recv() => info!("Received SIGTERM"),
_ = sigquit.recv() => info!("Received SIGQUIT"),
}
signal_clone.shutdown();
});
}
#[cfg(windows)]
{
let signal_clone = signal.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c()
.await
.expect("Failed to install Ctrl+C handler");
info!("Received Ctrl+C");
signal_clone.shutdown();
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generous upper bound for awaits that should finish promptly. A lost
    /// wakeup then fails the test instead of hanging the whole test run.
    const PROMPT: Duration = Duration::from_secs(1);

    #[tokio::test]
    async fn test_shutdown_signal_basic() {
        let signal = ShutdownSignal::new();
        assert!(!signal.is_shutdown());
        signal.shutdown();
        assert!(signal.is_shutdown());
    }

    #[tokio::test]
    async fn test_shutdown_is_idempotent() {
        let signal = ShutdownSignal::new();
        signal.shutdown();
        signal.shutdown();
        assert!(signal.is_shutdown());
    }

    #[tokio::test]
    async fn test_shutdown_signal_cancelled() {
        let signal = ShutdownSignal::new();
        let signal_clone = signal.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            signal_clone.shutdown();
        });
        tokio::time::timeout(PROMPT, signal.cancelled())
            .await
            .expect("cancelled() was not woken by shutdown()");
        assert!(signal.is_shutdown());
    }

    #[tokio::test]
    async fn test_shutdown_wakes_all_waiters() {
        let signal = ShutdownSignal::new();
        let waiters: Vec<_> = (0..4)
            .map(|_| {
                let waiter = signal.clone();
                tokio::spawn(async move { waiter.cancelled().await })
            })
            .collect();
        // Let the spawned waiters reach their first `.await` before firing.
        tokio::task::yield_now().await;
        signal.shutdown();
        for handle in waiters {
            tokio::time::timeout(PROMPT, handle)
                .await
                .expect("waiter was not woken by shutdown()")
                .expect("waiter task panicked");
        }
    }

    #[tokio::test]
    async fn test_shutdown_already_triggered() {
        let signal = ShutdownSignal::new();
        signal.shutdown();
        tokio::time::timeout(Duration::from_millis(10), signal.cancelled())
            .await
            .expect("Should complete immediately");
    }

    #[tokio::test]
    async fn test_shutdown_wait_timeout() {
        let signal = ShutdownSignal::new();
        let result = signal.wait_timeout(Duration::from_millis(10)).await;
        assert!(!result);
        signal.shutdown();
        let result = signal.wait_timeout(Duration::from_millis(10)).await;
        assert!(result);
    }

    #[tokio::test]
    async fn test_shutdown_guard() {
        let signal = ShutdownSignal::new();
        {
            let _guard = ShutdownGuard::new(signal.clone());
        }
        assert!(signal.is_shutdown());
    }

    #[tokio::test]
    async fn test_shutdown_guard_disarm() {
        let signal = ShutdownSignal::new();
        {
            let guard = ShutdownGuard::new(signal.clone());
            guard.disarm();
        }
        assert!(!signal.is_shutdown());
    }

    #[tokio::test]
    async fn test_multiple_clones() {
        let signal = ShutdownSignal::new();
        let clone1 = signal.clone();
        let clone2 = signal.clone();
        clone1.shutdown();
        assert!(signal.is_shutdown());
        assert!(clone1.is_shutdown());
        assert!(clone2.is_shutdown());
    }
}