use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
use std::sync::Arc;
use signal_hook::consts::{SIGHUP, SIGINT, SIGTERM};
/// Atomic record of whether a shutdown has been requested and which signal requested it.
///
/// Both fields are atomics, so the value can be shared behind an `Arc` and
/// updated through a shared reference (see `trigger`).
#[derive(Debug)]
pub struct ShutdownSignal {
/// `false` on construction; set to `true` by `trigger`.
signaled: AtomicBool,
/// `0` on construction; overwritten with the number passed to each `trigger` call.
signal_num: AtomicI32,
}
impl ShutdownSignal {
    /// Creates a signal in the "not signaled" state with signal number `0`.
    ///
    /// This is a `const fn`, so a `ShutdownSignal` can be stored in a `static`
    /// as well as behind an `Arc`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            signaled: AtomicBool::new(false),
            signal_num: AtomicI32::new(0),
        }
    }

    /// Returns `true` once [`trigger`](Self::trigger) has been called.
    #[inline]
    #[must_use]
    pub fn is_signaled(&self) -> bool {
        self.signaled.load(Ordering::Acquire)
    }

    /// Returns the number passed to the most recent [`trigger`](Self::trigger)
    /// call, or `0` if `trigger` has not been called.
    #[must_use]
    pub fn signal_number(&self) -> i32 {
        self.signal_num.load(Ordering::Acquire)
    }

    /// Records `signal_num` and marks the shutdown as requested.
    ///
    /// Every call overwrites the previously stored number.
    pub fn trigger(&self, signal_num: i32) {
        // Store the number before the flag: a reader that observes the flag via
        // the `Acquire` load in `is_signaled` then also observes this number.
        self.signal_num.store(signal_num, Ordering::Release);
        self.signaled.store(true, Ordering::Release);
    }

    /// Restores the initial state (not signaled, number `0`). Test builds only.
    #[cfg(test)]
    pub fn reset(&self) {
        self.signaled.store(false, Ordering::Release);
        self.signal_num.store(0, Ordering::Release);
    }

    /// Returns a static name for the stored signal number.
    ///
    /// `SIGINT`, `SIGTERM` and `SIGHUP` map to their own names, `0` maps to
    /// `"none"`, and every other value maps to `"unknown"`.
    #[must_use]
    pub fn signal_name(&self) -> &'static str {
        match self.signal_number() {
            SIGINT => "SIGINT",
            SIGTERM => "SIGTERM",
            SIGHUP => "SIGHUP",
            0 => "none",
            _ => "unknown",
        }
    }
}
impl Default for ShutdownSignal {
fn default() -> Self {
Self::new()
}
}
pub fn install_signal_handlers(shutdown: Arc<ShutdownSignal>) -> std::io::Result<()> {
let shutdown_sigint = shutdown.clone();
unsafe {
signal_hook::low_level::register(SIGINT, move || {
shutdown_sigint.trigger(SIGINT);
})?;
}
let shutdown_sigterm = shutdown.clone();
unsafe {
signal_hook::low_level::register(SIGTERM, move || {
shutdown_sigterm.trigger(SIGTERM);
})?;
}
#[cfg(unix)]
{
let shutdown_sighup = shutdown.clone();
unsafe {
signal_hook::low_level::register(SIGHUP, move || {
shutdown_sighup.trigger(SIGHUP);
})?;
}
}
Ok(())
}
/// Registers handlers that set `shutdown_flag` to `true` when a signal arrives.
///
/// `SIGINT` and `SIGTERM` are always registered; `SIGHUP` is registered only on
/// unix targets. The flag is written with `Ordering::Release`, and each handler
/// holds its own clone of the `Arc`.
///
/// # Errors
///
/// Returns the `std::io::Error` of the first registration that fails. Handlers
/// registered before the failure are left installed.
pub fn install_simple_signal_handlers(shutdown_flag: Arc<AtomicBool>) -> std::io::Result<()> {
    #[cfg(unix)]
    let handled = [SIGINT, SIGTERM, SIGHUP];
    #[cfg(not(unix))]
    let handled = [SIGINT, SIGTERM];

    for &signum in &handled {
        let flag = Arc::clone(&shutdown_flag);
        // SAFETY: `register` requires an async-signal-safe callback. The closure
        // performs a single atomic store and nothing else.
        unsafe {
            signal_hook::low_level::register(signum, move || {
                flag.store(true, Ordering::Release);
            })?;
        }
    }
    Ok(())
}
/// Drop guard that requests shutdown when it goes out of scope while still armed.
///
/// When the guard is dropped without `disarm` having been called, and the shared
/// signal is not already signaled, its `Drop` impl calls `trigger(0)` on the signal.
#[must_use = "dropping the guard right away triggers shutdown; bind it to a named variable"]
#[derive(Debug)]
pub struct ShutdownGuard {
    /// Shared signal state that `Drop` may trigger.
    shutdown: Arc<ShutdownSignal>,
    /// `true` once `disarm` has been called. Despite the name, this records
    /// "disarmed", not "the signal has been triggered".
    triggered: bool,
}
impl ShutdownGuard {
pub fn new(shutdown: Arc<ShutdownSignal>) -> Self {
Self {
shutdown,
triggered: false,
}
}
pub fn disarm(&mut self) {
self.triggered = true;
}
}
impl Drop for ShutdownGuard {
    /// Triggers the shared signal with number `0` unless the guard was disarmed
    /// or the signal is already set.
    fn drop(&mut self) {
        // A disarmed guard never touches the signal.
        if self.triggered {
            return;
        }
        // Leave an already-recorded signal number untouched.
        if self.shutdown.is_signaled() {
            return;
        }
        self.shutdown.trigger(0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built signal reports no pending shutdown.
    #[test]
    fn test_shutdown_signal_initial_state() {
        let fresh = ShutdownSignal::new();
        assert!(!fresh.is_signaled());
        assert_eq!(fresh.signal_number(), 0);
        assert_eq!(fresh.signal_name(), "none");
    }

    /// `trigger` sets the flag and records the number and its name.
    #[test]
    fn test_shutdown_signal_trigger() {
        let state = ShutdownSignal::new();
        state.trigger(SIGINT);
        assert!(state.is_signaled());
        assert_eq!(state.signal_number(), SIGINT);
        assert_eq!(state.signal_name(), "SIGINT");
    }

    /// `reset` returns a triggered signal to its initial state.
    #[test]
    fn test_shutdown_signal_reset() {
        let state = ShutdownSignal::new();
        state.trigger(SIGTERM);
        assert!(state.is_signaled());

        state.reset();
        assert!(!state.is_signaled());
        assert_eq!(state.signal_number(), 0);
    }

    /// Dropping a disarmed guard leaves the signal unset.
    #[test]
    fn test_shutdown_guard_disarm() {
        let shared = Arc::new(ShutdownSignal::new());
        let mut guard = ShutdownGuard::new(Arc::clone(&shared));
        guard.disarm();
        drop(guard);
        assert!(!shared.is_signaled());
    }

    /// Dropping an armed guard sets the signal.
    #[test]
    fn test_shutdown_guard_no_disarm() {
        let shared = Arc::new(ShutdownSignal::new());
        let guard = ShutdownGuard::new(Arc::clone(&shared));
        drop(guard);
        assert!(shared.is_signaled());
    }

    /// Each handled signal number maps to its own name.
    #[test]
    fn test_signal_names() {
        let state = ShutdownSignal::new();
        let expectations = [(SIGINT, "SIGINT"), (SIGTERM, "SIGTERM"), (SIGHUP, "SIGHUP")];
        for &(number, name) in &expectations {
            state.reset();
            state.trigger(number);
            assert_eq!(state.signal_name(), name);
        }
    }

    /// Reader threads can poll the flag while another thread triggers it.
    #[test]
    fn test_concurrent_access() {
        use std::thread;

        let shared = Arc::new(ShutdownSignal::new());
        let mut readers = Vec::with_capacity(10);
        for _ in 0..10 {
            let observer = Arc::clone(&shared);
            readers.push(thread::spawn(move || {
                for _ in 0..1000 {
                    let _ = observer.is_signaled();
                }
            }));
        }

        shared.trigger(SIGINT);

        for reader in readers {
            reader.join().unwrap();
        }
        assert!(shared.is_signaled());
    }
}