use tokio::sync::watch;
use tracing::info;
/// Owner side of the shutdown broadcast: triggers shutdown and hands out
/// [`Shutdown`] tokens.
///
/// Clones share one underlying channel, so triggering shutdown through any
/// clone is observed by every token subscribed through any other clone.
#[derive(Clone, Debug)]
pub struct ShutdownController {
    /// Holds `false` until shutdown is requested, then `true`.
    tx: watch::Sender<bool>,
}
impl ShutdownController {
    /// Creates a controller that fires on the first process shutdown signal
    /// (`SIGTERM`/`SIGINT` on Unix, Ctrl-C elsewhere), in addition to
    /// explicit [`shutdown`](Self::shutdown) calls.
    ///
    /// Signal listening happens in a detached background task that owns its
    /// own clone of the sender. Consequently, unlike [`manual`](Self::manual),
    /// dropping every controller returned here does not close the channel
    /// while that task is still waiting for a signal.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it uses
    /// `tokio::spawn`.
    #[must_use]
    pub fn install() -> Self {
        let controller = Self::manual();
        let signal_tx = controller.tx.clone();
        tokio::spawn(async move {
            wait_for_signal().await;
            signal_tx.send_replace(true);
        });
        controller
    }

    /// Creates a controller that fires only when [`shutdown`](Self::shutdown)
    /// is called; no signal handlers are involved.
    #[must_use]
    pub fn manual() -> Self {
        // The initial receiver is not needed: tokens subscribe on demand, and
        // `send_replace` updates the stored value even with no receivers.
        let (tx, _) = watch::channel(false);
        Self { tx }
    }

    /// Requests shutdown, waking every outstanding token.
    ///
    /// Safe to call repeatedly; the flag simply stays `true`.
    pub fn shutdown(&self) {
        // `send_replace` (unlike `send`) succeeds even when no token exists
        // yet, so tokens created later still observe the request.
        let _ = self.tx.send_replace(true);
    }

    /// Returns a fresh [`Shutdown`] token tied to this controller.
    ///
    /// A token created after shutdown was requested reports it immediately.
    #[must_use]
    pub fn token(&self) -> Shutdown {
        let rx = self.tx.subscribe();
        Shutdown { rx }
    }

    /// Reports whether shutdown has been requested, without waiting.
    #[must_use]
    pub fn is_triggered(&self) -> bool {
        *self.tx.borrow()
    }
}
/// Cloneable, receive-only handle that resolves once shutdown is requested.
///
/// Obtained from [`ShutdownController::token`].
#[derive(Clone, Debug)]
pub struct Shutdown {
    /// Observes the controller's shutdown flag.
    rx: watch::Receiver<bool>,
}
impl Shutdown {
    /// Consumes the token and waits until shutdown is requested.
    ///
    /// Also returns once every sender for the channel has been dropped,
    /// since no shutdown request can arrive after that.
    pub async fn wait(mut self) {
        Self::wait_ref(&mut self).await;
    }

    /// Waits until shutdown is requested without consuming the token, so it
    /// can serve as one branch of a `tokio::select!`.
    ///
    /// Returns immediately if shutdown was already requested, and also
    /// returns as soon as the channel closes (every sender dropped).
    pub async fn wait_ref(&mut self) {
        while !self.is_triggered() {
            // An error here means the channel is closed: stop waiting.
            if self.rx.changed().await.is_err() {
                break;
            }
        }
    }

    /// Reports whether shutdown has been requested, without waiting.
    #[must_use]
    pub fn is_triggered(&self) -> bool {
        *self.rx.borrow()
    }
}
async fn wait_for_signal() {
#[cfg(unix)]
{
use tokio::signal::unix::{SignalKind, signal};
let mut sigterm = signal(SignalKind::terminate()).expect("install SIGTERM handler");
let mut sigint = signal(SignalKind::interrupt()).expect("install SIGINT handler");
tokio::select! {
_ = sigterm.recv() => info!(signal = "SIGTERM", "shutdown signal received, draining"),
_ = sigint.recv() => info!(signal = "SIGINT", "shutdown signal received, draining"),
}
}
#[cfg(not(unix))]
{
if let Ok(()) = tokio::signal::ctrl_c().await {
info!(signal = "ctrl-c", "shutdown signal received, draining");
}
}
}
// Unit tests for `ShutdownController` and `Shutdown`. Each runs on the Tokio
// test runtime; `timeout` wrappers turn a missed wakeup into a test failure
// rather than a hung test.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::time::timeout;
// A token subscribed before `shutdown()` resolves after it.
#[tokio::test]
async fn manual_shutdown_fires_token() {
let ctrl = ShutdownController::manual();
let tok = ctrl.token();
ctrl.shutdown();
timeout(Duration::from_millis(200), tok.wait())
.await
.expect("wait should complete after manual shutdown");
}
// A single `shutdown()` releases every outstanding token.
#[tokio::test]
async fn multiple_tokens_all_fire() {
let ctrl = ShutdownController::manual();
let t1 = ctrl.token();
let t2 = ctrl.token();
let t3 = ctrl.token();
ctrl.shutdown();
let handle = tokio::spawn(async move {
tokio::join!(t1.wait(), t2.wait(), t3.wait());
});
timeout(Duration::from_millis(500), handle)
.await
.expect("all three tokens should fire")
.expect("join handle");
}
// Repeated `shutdown()` calls are harmless and leave the flag set.
#[tokio::test]
async fn shutdown_is_idempotent() {
let ctrl = ShutdownController::manual();
let tok = ctrl.token();
ctrl.shutdown();
ctrl.shutdown();
ctrl.shutdown();
assert!(ctrl.is_triggered());
assert!(tok.is_triggered());
tok.wait().await;
}
// A token created after shutdown observes the triggered state immediately.
#[tokio::test]
async fn token_fires_when_subscribed_after_shutdown() {
let ctrl = ShutdownController::manual();
ctrl.shutdown();
let late_tok = ctrl.token();
assert!(late_tok.is_triggered());
timeout(Duration::from_millis(200), late_tok.wait())
.await
.expect("late subscriber wait should return immediately");
}
// `wait_ref` borrows the token, so it can be one branch of `select!`.
// `work_tx` is held until the very end so the `work_rx` branch never
// resolves; the shutdown branch is expected to win and set the flag.
#[tokio::test]
async fn wait_ref_can_be_polled_in_select() {
let ctrl = ShutdownController::manual();
let mut tok = ctrl.token();
let ctrl_clone = ctrl.clone();
let (work_tx, work_rx) = oneshot::channel::<()>();
let triggered_after_work = Arc::new(AtomicBool::new(false));
let t_flag = triggered_after_work.clone();
let worker = tokio::spawn(async move {
tokio::select! {
_ = work_rx => {
t_flag.store(tok.is_triggered(), Ordering::SeqCst);
}
() = tok.wait_ref() => {
t_flag.store(true, Ordering::SeqCst);
}
}
});
// Presumably gives the worker a chance to start polling `select!`
// before shutdown is triggered.
tokio::time::sleep(Duration::from_millis(20)).await;
ctrl_clone.shutdown();
timeout(Duration::from_millis(500), worker)
.await
.expect("worker should complete")
.expect("join");
assert!(triggered_after_work.load(Ordering::SeqCst));
drop(work_tx);
}
// `is_triggered` on both controller and token flips from false to true.
#[tokio::test]
async fn is_triggered_reflects_state() {
let ctrl = ShutdownController::manual();
let tok = ctrl.token();
assert!(!ctrl.is_triggered());
assert!(!tok.is_triggered());
ctrl.shutdown();
assert!(ctrl.is_triggered());
assert!(tok.is_triggered());
}
// Dropping the only (manual) controller closes the channel, which also
// releases waiters even though `shutdown()` was never called.
#[tokio::test]
async fn controller_drop_unblocks_waiters() {
let ctrl = ShutdownController::manual();
let tok = ctrl.token();
drop(ctrl);
timeout(Duration::from_millis(200), tok.wait())
.await
.expect("wait should unblock when controller is dropped");
}
// Clones share one channel: shutdown via one is visible through the other.
#[tokio::test]
async fn cloned_controllers_share_state() {
let ctrl1 = ShutdownController::manual();
let ctrl2 = ctrl1.clone();
let tok = ctrl2.token();
ctrl1.shutdown();
assert!(ctrl2.is_triggered());
tok.wait().await;
}
// Smoke test: `install()` does not trigger on its own. Assumes the test
// process receives no SIGTERM/SIGINT during the 50 ms window.
#[tokio::test]
async fn install_signal_handlers_does_not_fire_without_signal() {
let ctrl = ShutdownController::install();
assert!(!ctrl.is_triggered());
tokio::time::sleep(Duration::from_millis(50)).await;
assert!(!ctrl.is_triggered());
ctrl.shutdown();
}
}