use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use futures::stream::StreamExt;
use log::debug;
use signal_hook::{
consts::{SIGINT, SIGTERM},
flag,
};
use signal_hook_tokio::Signals;
use tokio_util::sync::CancellationToken;
/// A handle used to request and observe a graceful shutdown.
///
/// Wraps a [`CancellationToken`]. Handles produced by `subscribe` (and `clone`)
/// share the same token, while `branch` produces a child token that is
/// cancelled with its parent but can also be cancelled on its own.
#[derive(Debug)]
pub struct Shutdown {
/// Shared (via `Arc`) with every handle derived from this one; records whether
/// the SIGINT/SIGTERM handlers have already been installed.
registered: Arc<AtomicBool>,
/// Cancelled once shutdown has been signalled.
token: CancellationToken,
}
impl Clone for Shutdown {
fn clone(&self) -> Self {
self.subscribe()
}
}
impl From<Shutdown> for CancellationToken {
fn from(value: Shutdown) -> Self {
value.token
}
}
impl Shutdown {
    /// Creates a root handle and immediately installs `SIGINT`/`SIGTERM`
    /// handlers (see [`Shutdown::register_signals`]).
    ///
    /// # Errors
    ///
    /// Returns the I/O error from [`Signals::new`] if the handlers cannot be
    /// installed.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime.
    pub fn new() -> Result<Self, std::io::Error> {
        let mut shutdown = Shutdown::unregistered();
        shutdown.register_signals()?;
        Ok(shutdown)
    }

    /// Creates a root handle without installing any OS signal handlers.
    ///
    /// Until [`Shutdown::register_signals`] is called, the handle is only
    /// signalled through [`Shutdown::signal`] or by cancelling its token.
    pub fn unregistered() -> Self {
        Self {
            registered: Arc::default(),
            token: CancellationToken::new(),
        }
    }

    /// Installs `SIGINT`/`SIGTERM` handlers that cancel this handle's token.
    ///
    /// After the first such signal, the token is cancelled and a conditional
    /// shutdown is armed. Any further `SIGINT`/`SIGTERM` then terminates the
    /// process immediately with exit status 0.
    ///
    /// Registration happens at most once for all handles sharing the
    /// registration flag, i.e. handles derived via [`Shutdown::subscribe`],
    /// [`Shutdown::branch`] or `clone`. Later calls return `Ok(())` and do
    /// nothing. The flag is claimed atomically, so concurrent calls from
    /// different handles cannot install duplicate handlers.
    ///
    /// NOTE(review): because the flag is shared with branches, whichever handle
    /// registers first receives the OS signal. If a branch registers before its
    /// root, a signal cancels only that branch's subtree. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from [`Signals::new`]. The registration flag is
    /// released in that case, so a later call may retry.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime while not yet registered.
    pub fn register_signals(&mut self) -> Result<(), std::io::Error> {
        // Fast path: already registered, so no runtime is needed.
        if self.registered.load(Ordering::SeqCst) {
            return Ok(());
        }
        // Resolve the runtime *before* claiming the flag. A missing runtime
        // then panics without leaving the flag set.
        let runtime = tokio::runtime::Handle::current();
        if self
            .registered
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            // Another handle sharing the flag registered concurrently.
            return Ok(());
        }
        let mut signals = match Signals::new([SIGINT, SIGTERM]) {
            Ok(signals) => signals,
            Err(err) => {
                // Release the claim so that a later call can retry.
                self.registered.store(false, Ordering::SeqCst);
                return Err(err);
            }
        };
        let token = self.token.clone();
        runtime.spawn(async move {
            if let Some(signal) = signals.next().await {
                debug!("Received a shutdown signal: {}", signal);
                // With the condition flag fixed to `true`, any further
                // SIGINT/SIGTERM exits the process immediately.
                // NOTE(review): exit status 0 reports success to the parent
                // process; confirm that is intended.
                for sig in [SIGINT, SIGTERM] {
                    flag::register_conditional_shutdown(sig, 0, Arc::new(AtomicBool::new(true)))
                        .expect("SIGINT/SIGTERM are not forbidden signals");
                }
                token.cancel();
            }
        });
        Ok(())
    }

    /// Creates a child handle with its own child token.
    ///
    /// Signalling this handle (or any ancestor) signals the child. Signalling
    /// the child does not signal this handle. The registration flag is shared.
    pub fn branch(&self) -> Self {
        Self {
            registered: self.registered.clone(),
            token: self.token.child_token(),
        }
    }

    /// Creates another handle on the *same* token.
    ///
    /// Signalling either handle signals both.
    pub fn subscribe(&self) -> Self {
        Self {
            registered: self.registered.clone(),
            token: self.token.clone(),
        }
    }

    /// Returns `true` once shutdown has been signalled for this handle.
    pub fn is_signalled(&self) -> bool {
        self.token.is_cancelled()
    }

    /// Signals shutdown for this handle, its subscribers and its branches.
    pub fn signal(&self) {
        self.token.cancel();
    }

    /// Completes once shutdown has been signalled for this handle.
    pub async fn signalled(&self) {
        self.token.cancelled().await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};
    use tokio::time::{sleep, Duration};

    /// Serialises tests that install OS signal handlers.
    ///
    /// signal-hook handlers are process-wide, and the test harness runs tests
    /// on parallel threads. A `SIGINT` raised by `shutdown_sigint` therefore
    /// also reaches every `Shutdown` registered by a concurrently running test,
    /// which previously made `not_notified`/`shutdown_branch` flaky. Every test
    /// that calls `Shutdown::new()` holds this lock. Tests that only exercise
    /// manual cancellation use `Shutdown::unregistered()` and need no lock.
    static SIGNAL_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`SIGNAL_LOCK`], ignoring poisoning so that one failing test
    /// does not cascade into failures of the others.
    fn signal_lock() -> MutexGuard<'static, ()> {
        SIGNAL_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Routes `log` output, including the signal task's `debug!`, to test output.
    fn init_logger() {
        let _ = env_logger::Builder::new()
            .format_timestamp(None)
            .filter(None, log::LevelFilter::Debug)
            .is_test(true)
            .try_init();
    }

    /// Waits until `shutdown` is signalled or one second elapses, whichever is first.
    async fn wait_briefly(shutdown: &Shutdown) {
        tokio::select! {
            _ = shutdown.signalled() => (),
            _ = sleep(Duration::from_secs(1)) => (),
        }
    }

    // Holding a std `MutexGuard` across `.await` is deliberate. The lock only
    // serialises test threads, and each #[tokio::test] runs on its own runtime.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn not_notified() {
        let _guard = signal_lock();
        init_logger();
        let root = Shutdown::new().unwrap();
        let branch1 = root.branch();
        let branch2 = branch1.branch();
        let sub1 = branch1.subscribe();
        let sub2 = branch2.subscribe();
        wait_briefly(&root).await;
        assert!(!root.is_signalled(), "root shutdown without notify");
        assert!(!branch1.is_signalled(), "branch1 shutdown without notify");
        assert!(!branch2.is_signalled(), "branch2 shutdown without notify");
        assert!(!sub1.is_signalled(), "subscriber1 shutdown without notify");
        assert!(!sub2.is_signalled(), "subscriber2 shutdown without notify");
    }

    // See `not_notified` for why the lock is held across `.await`.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn shutdown_sigint() {
        let _guard = signal_lock();
        init_logger();
        let root = Shutdown::new().unwrap();
        let branch1 = root.branch();
        let branch2 = branch1.branch();
        let sub1 = branch1.subscribe();
        let sub2 = branch2.subscribe();
        // SAFETY: `raise` has no memory-safety preconditions. `Shutdown::new`
        // above installed a signal-hook handler for SIGINT, which receives
        // this signal.
        unsafe { libc::raise(signal_hook::consts::SIGINT) };
        wait_briefly(&root).await;
        wait_briefly(&sub1).await;
        wait_briefly(&sub2).await;
        assert!(root.is_signalled(), "root not shutdown (signal)");
        assert!(branch1.is_signalled(), "branch1 not shutdown (signal)");
        assert!(branch2.is_signalled(), "branch2 not shutdown (signal)");
        assert!(sub1.is_signalled(), "subscriber1 not shutdown (signal)");
        assert!(sub2.is_signalled(), "subscriber2 not shutdown (signal)");
    }

    #[tokio::test]
    async fn shutdown_now() {
        // No OS handlers, so stray signals from other tests cannot interfere.
        let root = Shutdown::unregistered();
        let branch1 = root.branch();
        let branch2 = branch1.branch();
        let sub1 = branch1.subscribe();
        let sub2 = branch2.subscribe();
        root.signal();
        wait_briefly(&root).await;
        wait_briefly(&sub1).await;
        wait_briefly(&sub2).await;
        assert!(root.is_signalled(), "root not shutdown (manual)");
        assert!(branch1.is_signalled(), "branch1 not shutdown (manual)");
        assert!(branch2.is_signalled(), "branch2 not shutdown (manual)");
        assert!(sub1.is_signalled(), "subscriber1 not shutdown (manual)");
        assert!(sub2.is_signalled(), "subscriber2 not shutdown (manual)");
    }

    #[tokio::test]
    async fn shutdown_branch() {
        // No OS handlers: a SIGINT raised by `shutdown_sigint` would otherwise
        // cancel `root` and break the "not signalled" assertions below.
        let root = Shutdown::unregistered();
        let branch1 = root.branch();
        let branch2 = branch1.branch();
        let sub1 = branch1.subscribe();
        let sub2 = branch2.subscribe();
        sub2.signal();
        wait_briefly(&root).await;
        wait_briefly(&sub1).await;
        wait_briefly(&sub2).await;
        assert!(!root.is_signalled(), "root shutdown without notify");
        assert!(!branch1.is_signalled(), "branch1 shutdown without notify");
        assert!(!sub1.is_signalled(), "subscriber1 shutdown without notify");
        assert!(branch2.is_signalled(), "branch2 not shutdown (manual)");
        assert!(sub2.is_signalled(), "subscriber2 not shutdown (manual)");
    }

    #[tokio::test]
    async fn shutdown_signal_via_token_cancel() {
        let shutdown = Shutdown::unregistered();
        let token: CancellationToken = shutdown.clone().into();
        token.cancel();
        assert!(
            shutdown.is_signalled(),
            "shutdown not signalled via token.cancel()"
        );
    }

    #[tokio::test]
    async fn shutdown_token_cancel_via_signal() {
        let shutdown = Shutdown::unregistered();
        let token: CancellationToken = shutdown.clone().into();
        shutdown.signal();
        assert!(
            token.is_cancelled(),
            "token not cancelled via shutdown.signal()"
        );
    }
}