use crate::error::{Error, Result};
#[cfg(unix)]
use nix::sys::signal::{kill, Signal};
#[cfg(unix)]
use nix::unistd::Pid;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::signal;
use tracing::{debug, info, warn};
/// Records whether a shutdown signal has been received and sends signals to other
/// processes.
#[derive(Debug)]
pub struct SignalHandler {
    /// Set to `true` by the listener task spawned in `setup_handlers`; wrapped in an
    /// `Arc` so that task can keep updating it after `setup_handlers` has returned.
    shutdown_requested: Arc<AtomicBool>,
}
impl SignalHandler {
pub fn new() -> Self {
Self {
shutdown_requested: Arc::new(AtomicBool::new(false)),
}
}
pub fn is_shutdown_requested(&self) -> bool {
self.shutdown_requested.load(Ordering::Relaxed)
}
pub fn reset_shutdown_flag(&self) {
self.shutdown_requested.store(false, Ordering::Relaxed);
}
#[cfg(unix)]
pub fn send_signal(&self, pid: u32, signal: ProcessSignal) -> Result<()> {
let pid = Pid::from_raw(pid as i32);
let signal = match signal {
ProcessSignal::Term => Signal::SIGTERM,
ProcessSignal::Kill => Signal::SIGKILL,
ProcessSignal::Int => Signal::SIGINT,
ProcessSignal::Quit => Signal::SIGQUIT,
ProcessSignal::Usr1 => Signal::SIGUSR1,
ProcessSignal::Usr2 => Signal::SIGUSR2,
};
debug!("Sending signal {} to PID {}", signal, pid);
kill(pid, signal).map_err(|e| {
Error::signal(format!(
"Failed to send signal {} to PID {}: {}",
signal, pid, e
))
})?;
Ok(())
}
#[cfg(windows)]
pub fn send_signal(&self, pid: u32, signal: ProcessSignal) -> Result<()> {
use std::process::Command;
debug!("Sending signal {} to PID {} (Windows)", signal, pid);
match signal {
ProcessSignal::Term | ProcessSignal::Kill => {
let output = Command::new("taskkill")
.args(&["/PID", &pid.to_string(), "/F"])
.output()
.map_err(|e| Error::signal(format!("Failed to execute taskkill: {}", e)))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(Error::signal(format!(
"Failed to kill process {}: {}",
pid, stderr
)));
}
Ok(())
}
_ => {
Err(Error::signal(format!(
"Signal {} is not supported on Windows",
signal
)))
}
}
}
#[cfg(unix)]
pub async fn setup_handlers(&self) -> Result<()> {
info!("Setting up signal handlers for graceful shutdown (Unix)");
let shutdown_flag = Arc::clone(&self.shutdown_requested);
let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())
.map_err(|e| Error::signal(format!("Failed to setup SIGTERM handler: {}", e)))?;
let mut sigint = signal::unix::signal(signal::unix::SignalKind::interrupt())
.map_err(|e| Error::signal(format!("Failed to setup SIGINT handler: {}", e)))?;
let shutdown_flag_term = Arc::clone(&shutdown_flag);
let shutdown_flag_int = Arc::clone(&shutdown_flag);
tokio::spawn(async move {
tokio::select! {
_ = sigterm.recv() => {
info!("Received SIGTERM, initiating graceful shutdown");
shutdown_flag_term.store(true, Ordering::Relaxed);
}
_ = sigint.recv() => {
info!("Received SIGINT, initiating graceful shutdown");
shutdown_flag_int.store(true, Ordering::Relaxed);
}
}
});
debug!("Signal handlers setup completed");
Ok(())
}
#[cfg(windows)]
pub async fn setup_handlers(&self) -> Result<()> {
info!("Setting up signal handlers for graceful shutdown (Windows)");
let shutdown_flag = Arc::clone(&self.shutdown_requested);
let ctrl_c = signal::ctrl_c();
tokio::spawn(async move {
ctrl_c.await.ok();
info!("Received Ctrl+C, initiating graceful shutdown");
shutdown_flag.store(true, Ordering::Relaxed);
});
debug!("Signal handlers setup completed");
Ok(())
}
#[cfg(unix)]
pub async fn graceful_shutdown(&self, pid: u32, timeout_ms: u64) -> Result<()> {
debug!(
"Initiating graceful shutdown for PID {} with timeout {}ms (Unix)",
pid, timeout_ms
);
self.send_signal(pid, ProcessSignal::Term)?;
let timeout = tokio::time::Duration::from_millis(timeout_ms);
let start = tokio::time::Instant::now();
let poll_interval = tokio::time::Duration::from_millis(100);
while start.elapsed() < timeout {
match kill(Pid::from_raw(pid as i32), None) {
Ok(_) => {
debug!("Process {} still running, waiting...", pid);
tokio::time::sleep(poll_interval).await;
}
Err(_) => {
info!("Process {} exited gracefully", pid);
return Ok(());
}
}
}
warn!(
"Process {} did not exit gracefully within {}ms, sending SIGKILL",
pid, timeout_ms
);
self.send_signal(pid, ProcessSignal::Kill)?;
let kill_timeout = tokio::time::Duration::from_millis(500);
let kill_start = tokio::time::Instant::now();
while kill_start.elapsed() < kill_timeout {
match kill(Pid::from_raw(pid as i32), None) {
Ok(_) => {
tokio::time::sleep(poll_interval).await;
}
Err(_) => {
info!("Process {} forcefully terminated", pid);
return Ok(());
}
}
}
Err(Error::signal(format!(
"Failed to kill process {} even with SIGKILL",
pid
)))
}
#[cfg(windows)]
pub async fn graceful_shutdown(&self, pid: u32, timeout_ms: u64) -> Result<()> {
debug!(
"Initiating graceful shutdown for PID {} with timeout {}ms (Windows)",
pid, timeout_ms
);
if let Ok(_) = self.send_signal(pid, ProcessSignal::Term) {
let timeout = tokio::time::Duration::from_millis(timeout_ms);
let start = tokio::time::Instant::now();
let poll_interval = tokio::time::Duration::from_millis(100);
while start.elapsed() < timeout {
let output = std::process::Command::new("tasklist")
.args(&["/FI", &format!("PID eq {}", pid)])
.output();
if let Ok(output) = output {
let stdout = String::from_utf8_lossy(&output.stdout);
if !stdout.contains(&pid.to_string()) {
info!("Process {} exited gracefully", pid);
return Ok(());
}
}
tokio::time::sleep(poll_interval).await;
}
}
warn!(
"Process {} did not exit gracefully within {}ms, force killing",
pid, timeout_ms
);
self.send_signal(pid, ProcessSignal::Kill)?;
info!("Process {} forcefully terminated", pid);
Ok(())
}
}
/// A signal that `SignalHandler::send_signal` can deliver.
///
/// On Unix each variant is sent as the signal named in its documentation. On
/// Windows, `send_signal` supports only `Term` and `Kill`.
/// `Hash` is derived alongside `Eq` so signals can be used in sets and as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProcessSignal {
    /// `SIGTERM`.
    Term,
    /// `SIGKILL`.
    Kill,
    /// `SIGINT`.
    Int,
    /// `SIGQUIT`.
    Quit,
    /// `SIGUSR1`.
    Usr1,
    /// `SIGUSR2`.
    Usr2,
}
impl Default for SignalHandler {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for ProcessSignal {
    /// Writes the conventional signal name with its `SIG` prefix, e.g. `SIGTERM`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            ProcessSignal::Term => "SIGTERM",
            ProcessSignal::Kill => "SIGKILL",
            ProcessSignal::Int => "SIGINT",
            ProcessSignal::Quit => "SIGQUIT",
            ProcessSignal::Usr1 => "SIGUSR1",
            ProcessSignal::Usr2 => "SIGUSR2",
        };
        f.write_str(name)
    }
}
impl FromStr for ProcessSignal {
    type Err = String;

    /// Parses a signal name case-insensitively, with or without a `SIG` prefix
    /// (`"term"`, `"SIGTERM"` and `"SigTerm"` all yield `ProcessSignal::Term`).
    ///
    /// Matching is ASCII-only and allocation-free. The Unicode-aware
    /// `str::to_uppercase` maps e.g. `'ſ'` (U+017F) to `'S'`, which let lookalikes
    /// such as `"ſigterm"` parse.
    ///
    /// # Errors
    /// Returns `"Invalid signal: <input>"` for any other input, including `""` and
    /// a bare `"SIG"`.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        const NAMES: [(&str, ProcessSignal); 6] = [
            ("TERM", ProcessSignal::Term),
            ("KILL", ProcessSignal::Kill),
            ("INT", ProcessSignal::Int),
            ("QUIT", ProcessSignal::Quit),
            ("USR1", ProcessSignal::Usr1),
            ("USR2", ProcessSignal::Usr2),
        ];
        // Strip at most one "SIG" prefix. `get` returns `None` instead of panicking
        // when byte 3 is not a char boundary.
        let name = match s.get(..3) {
            Some(prefix) if prefix.eq_ignore_ascii_case("SIG") => &s[3..],
            _ => s,
        };
        NAMES
            .iter()
            .find(|(candidate, _)| candidate.eq_ignore_ascii_case(name))
            .map(|&(_, signal)| signal)
            .ok_or_else(|| format!("Invalid signal: {}", s))
    }
}
impl ProcessSignal {
    /// Every variant, in declaration order.
    pub fn all() -> Vec<ProcessSignal> {
        [
            ProcessSignal::Term,
            ProcessSignal::Kill,
            ProcessSignal::Int,
            ProcessSignal::Quit,
            ProcessSignal::Usr1,
            ProcessSignal::Usr2,
        ]
        .to_vec()
    }

    /// `true` for `Term`, `Kill`, `Int` and `Quit`; `false` for the user signals.
    ///
    /// Written as an exhaustive match so a newly added variant must be classified.
    pub fn is_termination_signal(&self) -> bool {
        match self {
            ProcessSignal::Term
            | ProcessSignal::Kill
            | ProcessSignal::Int
            | ProcessSignal::Quit => true,
            ProcessSignal::Usr1 | ProcessSignal::Usr2 => false,
        }
    }

    /// `true` only for `Usr1` and `Usr2`.
    pub fn is_user_signal(&self) -> bool {
        match self {
            ProcessSignal::Usr1 | ProcessSignal::Usr2 => true,
            ProcessSignal::Term
            | ProcessSignal::Kill
            | ProcessSignal::Int
            | ProcessSignal::Quit => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// A PID no live process can have on Linux or macOS.
    ///
    /// It is above Linux's `PID_MAX_LIMIT` (2^22 on 64-bit) and macOS's maximum of
    /// 99999, yet still a positive `i32`, so the request reaches the OS and fails
    /// there. The previous value, 999999, is a valid PID on Linux hosts whose
    /// `kernel.pid_max` exceeds it. On such hosts these tests could have sent
    /// SIGTERM/SIGKILL to an unrelated process. The cast is lossless because
    /// `i32::MAX` fits in `u32`.
    const NONEXISTENT_PID: u32 = i32::MAX as u32;

    /// Fails unless `msg` is one of the failure messages produced by `send_signal`:
    /// "Failed to send signal" on Unix, or "Failed to kill process" /
    /// "Failed to execute taskkill" on Windows.
    fn assert_send_failure(msg: &str) {
        assert!(
            msg.contains("Failed to send signal")
                || msg.contains("Failed to kill process")
                || msg.contains("Failed to execute taskkill"),
            "unexpected error message: {}",
            msg
        );
    }

    #[test]
    fn test_signal_handler_new() {
        let handler = SignalHandler::new();
        assert!(!handler.is_shutdown_requested());
    }

    #[test]
    fn test_signal_handler_default() {
        let handler = SignalHandler::default();
        assert!(!handler.is_shutdown_requested());
    }

    #[test]
    fn test_signal_handler_shutdown_flag() {
        let handler = SignalHandler::new();
        assert!(!handler.is_shutdown_requested());
        handler.shutdown_requested.store(true, Ordering::Relaxed);
        assert!(handler.is_shutdown_requested());
        handler.reset_shutdown_flag();
        assert!(!handler.is_shutdown_requested());
    }

    #[test]
    fn test_process_signal_display() {
        assert_eq!(ProcessSignal::Term.to_string(), "SIGTERM");
        assert_eq!(ProcessSignal::Kill.to_string(), "SIGKILL");
        assert_eq!(ProcessSignal::Int.to_string(), "SIGINT");
        assert_eq!(ProcessSignal::Quit.to_string(), "SIGQUIT");
        assert_eq!(ProcessSignal::Usr1.to_string(), "SIGUSR1");
        assert_eq!(ProcessSignal::Usr2.to_string(), "SIGUSR2");
    }

    #[test]
    fn test_process_signal_debug() {
        let debug_str = format!("{:?}", ProcessSignal::Term);
        assert!(debug_str.contains("Term"));
    }

    #[test]
    #[allow(clippy::clone_on_copy)] // deliberately exercises the derived `Clone` impl
    fn test_process_signal_clone() {
        let original = ProcessSignal::Term;
        let cloned = original.clone();
        assert_eq!(original, cloned);
    }

    #[test]
    fn test_process_signal_copy() {
        let original = ProcessSignal::Term;
        let copied = original;
        // Using `original` after the by-value assignment only compiles for `Copy` types.
        assert_eq!(original, copied);
    }

    #[test]
    fn test_process_signal_all() {
        let signals = ProcessSignal::all();
        assert_eq!(signals.len(), 6);
        assert!(signals.contains(&ProcessSignal::Term));
        assert!(signals.contains(&ProcessSignal::Kill));
        assert!(signals.contains(&ProcessSignal::Int));
        assert!(signals.contains(&ProcessSignal::Quit));
        assert!(signals.contains(&ProcessSignal::Usr1));
        assert!(signals.contains(&ProcessSignal::Usr2));
    }

    #[test]
    fn test_process_signal_from_str() {
        assert_eq!("TERM".parse::<ProcessSignal>(), Ok(ProcessSignal::Term));
        assert_eq!("SIGTERM".parse::<ProcessSignal>(), Ok(ProcessSignal::Term));
        assert_eq!("term".parse::<ProcessSignal>(), Ok(ProcessSignal::Term));
        assert_eq!("sigterm".parse::<ProcessSignal>(), Ok(ProcessSignal::Term));
        assert_eq!("KILL".parse::<ProcessSignal>(), Ok(ProcessSignal::Kill));
        assert_eq!("SIGKILL".parse::<ProcessSignal>(), Ok(ProcessSignal::Kill));
        assert_eq!("INT".parse::<ProcessSignal>(), Ok(ProcessSignal::Int));
        assert_eq!("SIGINT".parse::<ProcessSignal>(), Ok(ProcessSignal::Int));
        assert_eq!("QUIT".parse::<ProcessSignal>(), Ok(ProcessSignal::Quit));
        assert_eq!("SIGQUIT".parse::<ProcessSignal>(), Ok(ProcessSignal::Quit));
        assert_eq!("USR1".parse::<ProcessSignal>(), Ok(ProcessSignal::Usr1));
        assert_eq!("SIGUSR1".parse::<ProcessSignal>(), Ok(ProcessSignal::Usr1));
        assert_eq!("USR2".parse::<ProcessSignal>(), Ok(ProcessSignal::Usr2));
        assert_eq!("SIGUSR2".parse::<ProcessSignal>(), Ok(ProcessSignal::Usr2));
        assert!("INVALID".parse::<ProcessSignal>().is_err());
        assert!("".parse::<ProcessSignal>().is_err());
    }

    #[test]
    fn test_process_signal_is_termination_signal() {
        assert!(ProcessSignal::Term.is_termination_signal());
        assert!(ProcessSignal::Kill.is_termination_signal());
        assert!(ProcessSignal::Int.is_termination_signal());
        assert!(ProcessSignal::Quit.is_termination_signal());
        assert!(!ProcessSignal::Usr1.is_termination_signal());
        assert!(!ProcessSignal::Usr2.is_termination_signal());
    }

    #[test]
    fn test_process_signal_is_user_signal() {
        assert!(!ProcessSignal::Term.is_user_signal());
        assert!(!ProcessSignal::Kill.is_user_signal());
        assert!(!ProcessSignal::Int.is_user_signal());
        assert!(!ProcessSignal::Quit.is_user_signal());
        assert!(ProcessSignal::Usr1.is_user_signal());
        assert!(ProcessSignal::Usr2.is_user_signal());
    }

    #[tokio::test]
    async fn test_send_signal_to_nonexistent_process() {
        let handler = SignalHandler::new();
        let result = handler.send_signal(NONEXISTENT_PID, ProcessSignal::Term);
        assert!(result.is_err());
        assert_send_failure(&result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn test_graceful_shutdown_nonexistent_process() {
        let handler = SignalHandler::new();
        let result = handler.graceful_shutdown(NONEXISTENT_PID, 100).await;
        assert!(result.is_err());
        assert_send_failure(&result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn test_setup_handlers() {
        let handler = SignalHandler::new();
        let result = handler.setup_handlers().await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_process_signal_equality() {
        assert_eq!(ProcessSignal::Term, ProcessSignal::Term);
        assert_ne!(ProcessSignal::Term, ProcessSignal::Kill);
        assert_ne!(ProcessSignal::Int, ProcessSignal::Quit);
        assert_eq!(ProcessSignal::Usr1, ProcessSignal::Usr1);
    }

    #[test]
    fn test_process_signal_comprehensive_coverage() {
        // Listed explicitly (not via `all()`) so this test checks the variants independently.
        let signals = [
            ProcessSignal::Term,
            ProcessSignal::Kill,
            ProcessSignal::Int,
            ProcessSignal::Quit,
            ProcessSignal::Usr1,
            ProcessSignal::Usr2,
        ];
        for signal in signals.iter().copied() {
            let display_str = signal.to_string();
            assert!(!display_str.is_empty());
            assert!(display_str.starts_with("SIG"));
            let debug_str = format!("{:?}", signal);
            assert!(!debug_str.is_empty());
            let is_term = signal.is_termination_signal();
            let is_user = signal.is_user_signal();
            // Every signal is exactly one of "termination" or "user".
            assert!(is_term != is_user);
        }
    }
}