use std::collections::HashSet;
use std::os::unix::io::RawFd;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, Ordering};
use nix::sys::signal::{SaFlags, SigAction, SigHandler, SigSet, Signal, sigaction};
/// Raised by the signal handler whenever SIGHUP or SIGTERM is delivered, and
/// lowered again at the start of every `drain_pending_signals` call.
static PENDING_EXIT_SIGNAL: AtomicBool = AtomicBool::new(false);

/// Tells whether a SIGHUP or SIGTERM has been delivered since signals were last
/// drained.
///
/// The `Acquire` load pairs with the handler's `Release` store.
pub fn has_pending_exit_signal() -> bool {
    AtomicBool::load(&PENDING_EXIT_SIGNAL, Ordering::Acquire)
}
/// Every signal this module can translate between name and number, as
/// `(number, name)` pairs. Names are stored without the `SIG` prefix.
///
/// `signal_name_to_number` and `signal_number_to_name` search this table in order.
/// `capture_ignored_on_entry` scans it (skipping SIGKILL and SIGSTOP) to snapshot
/// which signals were ignored at startup.
pub const SIGNAL_TABLE: &[(i32, &str)] = &[
(libc::SIGHUP, "HUP"),
(libc::SIGINT, "INT"),
(libc::SIGQUIT, "QUIT"),
(libc::SIGABRT, "ABRT"),
(libc::SIGKILL, "KILL"),
(libc::SIGUSR1, "USR1"),
(libc::SIGUSR2, "USR2"),
(libc::SIGPIPE, "PIPE"),
(libc::SIGALRM, "ALRM"),
(libc::SIGTERM, "TERM"),
(libc::SIGCHLD, "CHLD"),
(libc::SIGCONT, "CONT"),
(libc::SIGSTOP, "STOP"),
(libc::SIGTSTP, "TSTP"),
(libc::SIGTTIN, "TTIN"),
(libc::SIGTTOU, "TTOU"),
];
/// Signals that `init_signal_handling` routes to the self-pipe handler, as
/// `(number, name)` pairs. Entries already ignored at startup are skipped.
///
/// Every entry must also appear in `SIGNAL_TABLE`. `reset_child_signals` walks this
/// same list in child processes and sets each entry back to `SIG_DFL` or `SIG_IGN`.
pub const HANDLED_SIGNALS: &[(i32, &str)] = &[
(libc::SIGHUP, "HUP"),
(libc::SIGINT, "INT"),
(libc::SIGQUIT, "QUIT"),
(libc::SIGALRM, "ALRM"),
(libc::SIGTERM, "TERM"),
(libc::SIGUSR1, "USR1"),
(libc::SIGUSR2, "USR2"),
];
/// Resolves a signal name to its number.
///
/// Matching is ASCII case-insensitive, and an optional leading `SIG` is accepted.
/// For example, `"INT"`, `"sigint"` and `"SigInt"` all resolve to `SIGINT`. Only
/// names listed in `SIGNAL_TABLE` are recognised.
///
/// # Errors
/// Returns `"unknown signal: <name>"` when nothing matches. The message echoes the
/// caller's original spelling.
pub fn signal_name_to_number(name: &str) -> Result<i32, String> {
    let normalized = name.to_ascii_uppercase();
    let short = match normalized.strip_prefix("SIG") {
        Some(rest) => rest,
        None => normalized.as_str(),
    };
    SIGNAL_TABLE
        .iter()
        .find(|&&(_, label)| label == short)
        .map(|&(number, _)| number)
        .ok_or_else(|| format!("unknown signal: {name}"))
}
/// Returns the `SIG`-less name recorded in `SIGNAL_TABLE` for `num`. Returns `None`
/// for numbers the table does not list.
pub fn signal_number_to_name(num: i32) -> Option<&'static str> {
    SIGNAL_TABLE
        .iter()
        .find_map(|&(candidate, label)| (candidate == num).then_some(label))
}
/// `(read_fd, write_fd)` of the self-pipe, set exactly once by
/// `init_signal_handling`. `signal_handler` writes one byte per caught signal to
/// the write end, and `drain_pending_signals` reads them back from the read end.
static SELF_PIPE: OnceLock<(RawFd, RawFd)> = OnceLock::new();
/// Signals whose disposition was `SIG_IGN` when `init_signal_handling` first ran.
/// The snapshot is taken before that call installs any handler.
static IGNORED_ON_ENTRY: OnceLock<HashSet<i32>> = OnceLock::new();
/// Queries the current disposition of every `SIGNAL_TABLE` entry and returns those
/// set to `SIG_IGN`.
///
/// SIGKILL and SIGSTOP are skipped because their dispositions cannot be changed.
/// Signals whose query fails are left out of the set.
fn capture_ignored_on_entry() -> HashSet<i32> {
    SIGNAL_TABLE
        .iter()
        .map(|&(number, _)| number)
        .filter(|&number| number != libc::SIGKILL && number != libc::SIGSTOP)
        .filter(|&number| {
            // SAFETY: `libc::sigaction` is plain data; the all-zero pattern is valid.
            let mut current: libc::sigaction = unsafe { std::mem::zeroed() };
            // SAFETY: a null `act` turns this into a read-only query; `current` is a
            // valid, writable out-pointer.
            let status = unsafe { libc::sigaction(number, std::ptr::null(), &mut current) };
            status == 0 && current.sa_sigaction == libc::SIG_IGN
        })
        .collect()
}
/// Reports whether `sig` was already `SIG_IGN` when `init_signal_handling` first
/// ran. Always `false` before that snapshot exists.
pub fn is_ignored_on_entry(sig: i32) -> bool {
    match IGNORED_ON_ENTRY.get() {
        Some(entry_ignored) => entry_ignored.contains(&sig),
        None => false,
    }
}
pub fn ignored_on_entry_set_opt() -> Option<&'static HashSet<i32>> {
IGNORED_ON_ENTRY.get()
}
/// Returns the ignored-on-entry snapshot, assuming it has been taken.
///
/// # Panics
/// Panics if `init_signal_handling` has not been called yet.
// `dead_code` is allowed because nothing in this file calls this accessor;
// presumably it is kept for callers that can guarantee initialisation.
#[allow(dead_code)]
pub fn ignored_on_entry_set() -> &'static HashSet<i32> {
    let Some(entry_ignored) = IGNORED_ON_ENTRY.get() else {
        panic!("init_signal_handling() must be called first");
    };
    entry_ignored
}
/// Handler installed for every signal this module catches: `HANDLED_SIGNALS` and,
/// via `init_job_control_signals`, SIGCHLD.
///
/// For SIGHUP and SIGTERM it first raises `PENDING_EXIT_SIGNAL`. It then writes the
/// signal number as a single byte to the self-pipe's write end, where
/// `drain_pending_signals` picks it up. It runs in signal context, so keep it to
/// atomics and raw `write(2)`: no allocation, locking or formatting.
// NOTE(review): `write` may set errno (e.g. EAGAIN on a full pipe) and errno is not
// saved/restored here, so code interrupted between a failing call and its errno check
// could see a clobbered value — consider preserving errno.
extern "C" fn signal_handler(sig: libc::c_int) {
    if sig == libc::SIGHUP || sig == libc::SIGTERM {
        PENDING_EXIT_SIGNAL.store(true, Ordering::Release);
    }
    // SELF_PIPE is only set after init_signal_handling() has installed the handlers,
    // so a signal arriving during initialisation is dropped here.
    let Some(&(_, write_fd)) = SELF_PIPE.get() else {
        return;
    };
    // Narrowed to one byte; drain_pending_signals widens it back to i32.
    let byte = sig as u8;
    unsafe {
        // Result deliberately ignored: the write end is O_NONBLOCK, so a full pipe
        // loses this byte instead of blocking inside the handler.
        libc::write(write_fd, &byte as *const u8 as *const libc::c_void, 1);
    }
}
pub fn init_signal_handling() {
SELF_PIPE.get_or_init(|| {
let entry_ignored = IGNORED_ON_ENTRY.get_or_init(capture_ignored_on_entry);
let mut fds: [libc::c_int; 2] = [0; 2];
let ret = unsafe { libc::pipe(fds.as_mut_ptr()) };
assert_eq!(ret, 0, "pipe() failed");
let read_fd = unsafe { libc::fcntl(fds[0], libc::F_DUPFD_CLOEXEC, 10) };
assert!(read_fd >= 10, "F_DUPFD_CLOEXEC failed for read end");
unsafe { libc::close(fds[0]) };
let write_fd = unsafe { libc::fcntl(fds[1], libc::F_DUPFD_CLOEXEC, 10) };
assert!(write_fd >= 10, "F_DUPFD_CLOEXEC failed for write end");
unsafe { libc::close(fds[1]) };
for &fd in &[read_fd, write_fd] {
let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
unsafe {
libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK);
}
}
let sa_restart = SigAction::new(
SigHandler::Handler(signal_handler),
SaFlags::SA_RESTART,
SigSet::empty(),
);
let sa_no_restart = SigAction::new(
SigHandler::Handler(signal_handler),
SaFlags::empty(),
SigSet::empty(),
);
for &(num, _) in HANDLED_SIGNALS {
if entry_ignored.contains(&num) {
continue;
}
let sig = Signal::try_from(num).expect("invalid signal number in HANDLED_SIGNALS");
let sa = if num == libc::SIGHUP || num == libc::SIGTERM {
&sa_no_restart
} else {
&sa_restart
};
unsafe {
sigaction(sig, sa).expect("sigaction failed");
}
}
(read_fd, write_fd)
});
}
/// Clears the pending-exit flag and returns every signal number queued in the
/// self-pipe since the previous drain, in the order the bytes were written.
///
/// Returns an empty vector when `init_signal_handling` has not run. Reading stops at
/// the first `read` that returns 0 or an error. Because the read end is
/// non-blocking, that is normally the "pipe is empty" case.
pub fn drain_pending_signals() -> Vec<i32> {
    // Lower the flag before reading. A SIGHUP/SIGTERM that lands mid-drain then
    // raises it again, and its byte is either read below or left for the next drain.
    PENDING_EXIT_SIGNAL.store(false, Ordering::Release);
    let read_fd = match SELF_PIPE.get() {
        Some(&(fd, _)) => fd,
        None => return Vec::new(),
    };
    let mut chunk = [0u8; 128];
    let mut collected: Vec<i32> = Vec::new();
    loop {
        // SAFETY: `chunk` is valid for writes of `chunk.len()` bytes.
        let got = unsafe {
            libc::read(read_fd, chunk.as_mut_ptr().cast::<libc::c_void>(), chunk.len())
        };
        if got > 0 {
            collected.extend(chunk[..got as usize].iter().map(|&byte| i32::from(byte)));
        } else {
            break collected;
        }
    }
}
/// Returns the self-pipe's read end. Presumably callers use it to wait for
/// signal activity alongside other descriptors.
///
/// # Panics
/// Panics if `init_signal_handling` has not been called.
pub fn self_pipe_read_fd() -> RawFd {
    let &(read_end, _write_end) = SELF_PIPE
        .get()
        .expect("init_signal_handling() must be called first");
    read_end
}
/// Sets `sig`'s disposition to `SIG_IGN`, with no flags and an empty mask.
///
/// # Panics
/// Panics in either of these cases:
/// - `sig` is not a valid signal number.
/// - `sigaction` rejects the change, as it does for SIGKILL and SIGSTOP.
pub fn ignore_signal(sig: i32) {
    let ignore_action = SigAction::new(SigHandler::SigIgn, SaFlags::empty(), SigSet::empty());
    let target = Signal::try_from(sig).expect("invalid signal number");
    // SAFETY: installing SIG_IGN registers no Rust code to run in signal context.
    unsafe { sigaction(target, &ignore_action) }.expect("sigaction(SIG_IGN) failed");
}
/// Restores `sig` to its default action (`SIG_DFL`), with no flags and an empty mask.
///
/// # Panics
/// Panics in either of these cases:
/// - `sig` is not a valid signal number.
/// - `sigaction` rejects the change.
pub fn default_signal(sig: i32) {
    let default_action = SigAction::new(SigHandler::SigDfl, SaFlags::empty(), SigSet::empty());
    let target = Signal::try_from(sig).expect("invalid signal number");
    // SAFETY: installing SIG_DFL registers no Rust code to run in signal context.
    unsafe { sigaction(target, &default_action) }.expect("sigaction(SIG_DFL) failed");
}
/// Prepares signal dispositions in a freshly forked child, then closes the
/// inherited self-pipe.
///
/// Each `HANDLED_SIGNALS` entry becomes `SIG_IGN` if it is listed in `ignored` or
/// was ignored on entry to the shell. Otherwise it becomes `SIG_DFL`.
///
/// `init_job_control_signals` routes SIGCHLD to `signal_handler`. If SIGCHLD is
/// still caught at this point, it is reset to `SIG_DFL` before the pipe is closed;
/// exec would reset a caught handler anyway. Without this, a SIGCHLD reaching the
/// child would make the handler write into a closed descriptor, or into an
/// unrelated file if that descriptor number has been reused.
///
/// Only call this in a child process. The closed descriptors remain recorded in
/// `SELF_PIPE`, so this module's signal delivery is unusable in the calling process
/// afterwards.
pub fn reset_child_signals(ignored: &[i32]) {
    let entry_set = IGNORED_ON_ENTRY.get();
    for &(num, _) in HANDLED_SIGNALS {
        let keep_ignored =
            ignored.contains(&num) || entry_set.is_some_and(|s| s.contains(&num));
        if keep_ignored {
            ignore_signal(num);
        } else {
            default_signal(num);
        }
    }
    if sigchld_is_caught() {
        default_signal(libc::SIGCHLD);
    }
    if let Some(&(read_fd, write_fd)) = SELF_PIPE.get() {
        // SAFETY: both descriptors were created by init_signal_handling and, in this
        // child, nothing else uses them once the handlers above are gone.
        unsafe {
            libc::close(read_fd);
            libc::close(write_fd);
        }
    }
}

/// Reports whether SIGCHLD currently has a handler function installed (neither
/// `SIG_DFL` nor `SIG_IGN`). Returns `false` if the disposition cannot be queried.
fn sigchld_is_caught() -> bool {
    // SAFETY: `libc::sigaction` is plain data; the all-zero pattern is valid.
    let mut current: libc::sigaction = unsafe { std::mem::zeroed() };
    // SAFETY: a null `act` makes this a pure query; `current` is a valid out-pointer.
    let rc = unsafe { libc::sigaction(libc::SIGCHLD, std::ptr::null(), &mut current) };
    rc == 0 && current.sa_sigaction != libc::SIG_DFL && current.sa_sigaction != libc::SIG_IGN
}
/// Enables job-control signal handling in the shell process.
///
/// SIGTSTP, SIGTTIN and SIGTTOU are ignored. SIGCHLD is routed to `signal_handler`
/// with `SA_RESTART`, so child state changes are reported on the self-pipe once
/// `init_signal_handling` has created it.
///
/// # Panics
/// Panics if any `sigaction` call fails.
pub fn init_job_control_signals() {
    for stop_signal in [libc::SIGTSTP, libc::SIGTTIN, libc::SIGTTOU] {
        ignore_signal(stop_signal);
    }
    let on_child = SigAction::new(
        SigHandler::Handler(signal_handler),
        SaFlags::SA_RESTART,
        SigSet::empty(),
    );
    let chld = Signal::try_from(libc::SIGCHLD).expect("SIGCHLD is valid");
    // SAFETY: signal_handler is written to perform only async-signal-safe work.
    unsafe { sigaction(chld, &on_child) }.expect("sigaction(SIGCHLD) failed");
}
pub fn reset_job_control_signals() {
default_signal(libc::SIGTSTP);
default_signal(libc::SIGTTIN);
default_signal(libc::SIGTTOU);
default_signal(libc::SIGCHLD);
}
pub fn setup_foreground_child_signals(ignored: &[i32]) {
reset_child_signals(ignored);
if !ignored.contains(&libc::SIGTSTP) {
default_signal(libc::SIGTSTP);
}
if !ignored.contains(&libc::SIGTTIN) {
default_signal(libc::SIGTTIN);
}
if !ignored.contains(&libc::SIGTTOU) {
default_signal(libc::SIGTTOU);
}
}
pub fn setup_background_child_signals(ignored: &[i32]) {
reset_child_signals(ignored);
ignore_signal(libc::SIGTTIN);
if !ignored.contains(&libc::SIGTSTP) {
default_signal(libc::SIGTSTP);
}
if !ignored.contains(&libc::SIGTTOU) {
default_signal(libc::SIGTTOU);
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for signal-name lookup and self-pipe setup.
    //!
    //! All tests share one process. `init_signal_handling` is a process-wide one-shot,
    //! and some tests temporarily change real signal dispositions. Every test that
    //! touches dispositions therefore calls `init_signal_handling()` first, so the
    //! ignored-on-entry snapshot is taken before anything is altered.
    use super::*;
    #[test]
    fn test_signal_name_to_number_int() {
        assert_eq!(signal_name_to_number("INT").unwrap(), 2);
    }
    #[test]
    fn test_signal_name_to_number_sigint() {
        assert_eq!(signal_name_to_number("SIGINT").unwrap(), 2);
    }
    #[test]
    fn test_signal_name_to_number_case_insensitive() {
        assert_eq!(signal_name_to_number("hup").unwrap(), 1);
    }
    #[test]
    fn test_signal_name_to_number_mixed_case_prefix() {
        assert_eq!(signal_name_to_number("SigTerm").unwrap(), libc::SIGTERM);
    }
    #[test]
    fn test_signal_name_to_number_term() {
        assert_eq!(signal_name_to_number("TERM").unwrap(), 15);
    }
    #[test]
    fn test_signal_name_to_number_kill() {
        assert_eq!(signal_name_to_number("KILL").unwrap(), 9);
    }
    #[test]
    fn test_signal_name_to_number_invalid() {
        assert!(signal_name_to_number("INVALID").is_err());
    }
    #[test]
    fn test_signal_name_to_number_bare_sig_prefix_is_invalid() {
        assert!(signal_name_to_number("SIG").is_err());
    }
    #[test]
    fn test_signal_name_to_number_error_echoes_input() {
        assert_eq!(
            signal_name_to_number("bogus").unwrap_err(),
            "unknown signal: bogus"
        );
    }
    #[test]
    fn test_signal_number_to_name_2() {
        assert_eq!(signal_number_to_name(2), Some("INT"));
    }
    #[test]
    fn test_signal_number_to_name_15() {
        assert_eq!(signal_number_to_name(15), Some("TERM"));
    }
    #[test]
    fn test_signal_number_to_name_9() {
        assert_eq!(signal_number_to_name(9), Some("KILL"));
    }
    #[test]
    fn test_signal_number_to_name_999() {
        assert_eq!(signal_number_to_name(999), None);
    }
    #[test]
    fn test_signal_name_number_roundtrip() {
        for &(num, name) in SIGNAL_TABLE {
            assert_eq!(signal_name_to_number(name), Ok(num));
            assert_eq!(signal_number_to_name(num), Some(name));
        }
    }
    #[test]
    fn test_handled_signals_are_in_signal_table() {
        for &(num, name) in HANDLED_SIGNALS {
            let found = SIGNAL_TABLE.iter().any(|&(n, nm)| n == num && nm == name);
            assert!(
                found,
                "HANDLED_SIGNALS entry ({num}, {name}) not found in SIGNAL_TABLE"
            );
        }
    }
    #[test]
    fn test_init_signal_handling() {
        init_signal_handling();
        init_signal_handling();
        let fd = self_pipe_read_fd();
        assert!(fd >= 10, "self-pipe read end should live at fd >= 10, got {fd}");
        // A repeated init must not replace the pipe.
        init_signal_handling();
        assert_eq!(self_pipe_read_fd(), fd);
        let &(read_fd, write_fd) = SELF_PIPE.get().expect("initialized above");
        assert_eq!(read_fd, fd);
        for end in [read_fd, write_fd] {
            assert!(end >= 10, "self-pipe fd {end} should be >= 10");
            // drain_pending_signals and the handler both depend on O_NONBLOCK.
            let status = unsafe { libc::fcntl(end, libc::F_GETFL) };
            assert!(
                status >= 0 && (status & libc::O_NONBLOCK) != 0,
                "self-pipe fd {end} should be O_NONBLOCK"
            );
            let fd_flags = unsafe { libc::fcntl(end, libc::F_GETFD) };
            assert!(
                fd_flags >= 0 && (fd_flags & libc::FD_CLOEXEC) != 0,
                "self-pipe fd {end} should be FD_CLOEXEC"
            );
        }
    }
    #[test]
    fn test_drain_pending_signals_empty() {
        init_signal_handling();
        let signals = drain_pending_signals();
        assert!(
            signals.is_empty(),
            "expected no pending signals, got: {signals:?}"
        );
    }
    #[test]
    fn test_signal_table_has_job_control_signals() {
        assert_eq!(signal_name_to_number("CHLD").unwrap(), libc::SIGCHLD);
        assert_eq!(signal_name_to_number("CONT").unwrap(), libc::SIGCONT);
        assert_eq!(signal_name_to_number("STOP").unwrap(), libc::SIGSTOP);
        assert_eq!(signal_name_to_number("TSTP").unwrap(), libc::SIGTSTP);
        assert_eq!(signal_name_to_number("TTIN").unwrap(), libc::SIGTTIN);
        assert_eq!(signal_name_to_number("TTOU").unwrap(), libc::SIGTTOU);
    }
    #[test]
    fn test_signal_number_to_name_job_control() {
        assert_eq!(signal_number_to_name(libc::SIGCHLD), Some("CHLD"));
        assert_eq!(signal_number_to_name(libc::SIGTSTP), Some("TSTP"));
    }
    #[test]
    fn test_job_control_signal_functions_exist() {
        let _ = init_job_control_signals as fn();
        let _ = reset_job_control_signals as fn();
        let _ = setup_foreground_child_signals as fn(&[i32]);
        let _ = setup_background_child_signals as fn(&[i32]);
    }
    #[test]
    fn test_reset_job_control_signals_after_init() {
        init_signal_handling();
        init_job_control_signals();
        reset_job_control_signals();
    }
    #[test]
    fn test_is_ignored_on_entry_false_for_unlikely_signal() {
        init_signal_handling();
        assert!(
            !is_ignored_on_entry(libc::SIGALRM),
            "SIGALRM should not be ignored-on-entry in a normal test environment"
        );
    }
    #[test]
    fn test_capture_ignored_on_entry_detects_sig_ign() {
        // Initialise first so the real snapshot is taken before SIGALRM is altered.
        init_signal_handling();
        let sig_num = libc::SIGALRM;
        let mut original: libc::sigaction = unsafe { std::mem::zeroed() };
        let rc = unsafe { libc::sigaction(sig_num, std::ptr::null(), &mut original) };
        assert_eq!(rc, 0);
        let ign_sa = SigAction::new(SigHandler::SigIgn, SaFlags::empty(), SigSet::empty());
        let sig = Signal::try_from(sig_num).unwrap();
        unsafe {
            sigaction(sig, &ign_sa).unwrap();
        }
        let captured = capture_ignored_on_entry();
        assert!(
            captured.contains(&sig_num),
            "capture_ignored_on_entry should detect SIGALRM SIG_IGN, got {:?}",
            captured
        );
        let rc = unsafe { libc::sigaction(sig_num, &original, std::ptr::null_mut()) };
        assert_eq!(rc, 0);
    }
    #[test]
    fn test_capture_ignored_on_entry_excludes_default() {
        // Initialise first so the real snapshot is taken before SIGPIPE is altered.
        init_signal_handling();
        let sig_num = libc::SIGPIPE;
        let mut original: libc::sigaction = unsafe { std::mem::zeroed() };
        let rc = unsafe { libc::sigaction(sig_num, std::ptr::null(), &mut original) };
        assert_eq!(rc, 0);
        let dfl_sa = SigAction::new(SigHandler::SigDfl, SaFlags::empty(), SigSet::empty());
        let sig = Signal::try_from(sig_num).unwrap();
        unsafe {
            sigaction(sig, &dfl_sa).unwrap();
        }
        let captured = capture_ignored_on_entry();
        assert!(
            !captured.contains(&sig_num),
            "capture_ignored_on_entry should not include SIG_DFL signals, got {:?}",
            captured
        );
        let rc = unsafe { libc::sigaction(sig_num, &original, std::ptr::null_mut()) };
        assert_eq!(rc, 0);
    }
    #[test]
    fn test_signal_table_matches_libc_constants() {
        for &(num, name) in SIGNAL_TABLE {
            let expected = match name {
                "HUP" => libc::SIGHUP,
                "INT" => libc::SIGINT,
                "QUIT" => libc::SIGQUIT,
                "ABRT" => libc::SIGABRT,
                "KILL" => libc::SIGKILL,
                "USR1" => libc::SIGUSR1,
                "USR2" => libc::SIGUSR2,
                "PIPE" => libc::SIGPIPE,
                "ALRM" => libc::SIGALRM,
                "TERM" => libc::SIGTERM,
                "CHLD" => libc::SIGCHLD,
                "CONT" => libc::SIGCONT,
                "STOP" => libc::SIGSTOP,
                "TSTP" => libc::SIGTSTP,
                "TTIN" => libc::SIGTTIN,
                "TTOU" => libc::SIGTTOU,
                other => panic!("unexpected signal name in table: {other}"),
            };
            assert_eq!(
                num, expected,
                "SIGNAL_TABLE entry for {name} has {num}, libc says {expected}"
            );
        }
    }
    #[test]
    fn test_handled_signals_match_libc_constants() {
        for &(num, name) in HANDLED_SIGNALS {
            let expected = match name {
                "HUP" => libc::SIGHUP,
                "INT" => libc::SIGINT,
                "QUIT" => libc::SIGQUIT,
                "ALRM" => libc::SIGALRM,
                "TERM" => libc::SIGTERM,
                "USR1" => libc::SIGUSR1,
                "USR2" => libc::SIGUSR2,
                other => panic!("unexpected signal name in HANDLED_SIGNALS: {other}"),
            };
            assert_eq!(
                num, expected,
                "HANDLED_SIGNALS entry for {name} has {num}, libc says {expected}"
            );
        }
    }
}