use std::collections::HashMap;
/// C-ABI function-pointer type usable as a signal handler: takes a single `c_int`
/// (presumably the number of the signal being handled) and returns nothing.
pub type SignalHandler = extern "C" fn(signum: libc::c_int);
/// A raw signal number, wrapped in a newtype so it can be used as a typed, hashable,
/// ordered key (for example in a `HashMap`).
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SignalKind(libc::c_int);
/// Expands each `PATH = name` pair into a `pub const fn name() -> Self` that wraps the
/// value found at `PATH` in `Self(..)`, with a generated doc line linking to `PATH`.
///
/// Meant to be invoked inside the `impl` of a single-field tuple struct; the pairs are
/// comma-separated and a trailing comma is optional.
macro_rules! impl_signal_delegates {
    ($($signal:path = $ctor:ident),* $(,)?) => {
        $(
            #[doc = concat!(
                "Wrapper around [`", stringify!($signal), "`](", stringify!($signal), ")."
            )]
            pub const fn $ctor() -> Self {
                Self($signal)
            }
        )*
    };
}
impl SignalKind {
pub const fn as_raw(&self) -> libc::c_int {
self.0
}
impl_signal_delegates!(
libc::SIGABRT = abort,
libc::SIGFPE = fpe,
libc::SIGINT = int,
libc::SIGILL = invalid,
libc::SIGSEGV = segv,
libc::SIGTERM = term,
);
#[cfg(unix)]
impl_signal_delegates!(
libc::SIGALRM = alarm,
libc::SIGBUS = bus,
libc::SIGCHLD = child,
libc::SIGCONT = r#continue,
libc::SIGHUP = hangup,
libc::SIGKILL = kill,
libc::SIGPIPE = pipe,
libc::SIGQUIT = quit,
libc::SIGSTOP = stop,
libc::SIGTSTP = terminal_stop,
libc::SIGTTIN = tty_in,
libc::SIGTTOU = tty_out,
libc::SIGUSR1 = user1,
libc::SIGUSR2 = user2,
libc::SIGSYS = sys,
libc::SIGTRAP = trap,
libc::SIGURG = urgent,
libc::SIGVTALRM = virtual_alarm,
libc::SIGXCPU = xcpu,
libc::SIGXFSZ = xfsz,
);
}
impl From<SignalKind> for libc::c_int {
fn from(value: SignalKind) -> Self {
value.as_raw()
}
}
impl From<libc::c_int> for SignalKind {
fn from(value: libc::c_int) -> Self {
Self(value)
}
}
/// RAII guard over process signal handlers.
///
/// It records, per signal, the handler that was in place before it changed that
/// signal's disposition. Its `Drop` implementation reinstalls those handlers. The guard
/// must therefore be kept alive for as long as the changed dispositions should apply;
/// binding it to `_` or discarding it undoes the change immediately.
#[derive(Debug)]
#[must_use = "dropping a SignalGuard immediately restores the previous signal handlers"]
pub struct SignalGuard {
    // Handler that was installed for each signal before this guard replaced it.
    stashed_signals: HashMap<SignalKind, libc::sighandler_t>,
}
impl SignalGuard {
    /// Sets every signal in `signals` to `SIG_IGN` for the lifetime of the returned guard.
    ///
    /// Returns `None` if `libc::signal` reports `SIG_ERR` for any of them. In that case,
    /// every signal this call had already changed is restored before returning.
    #[must_use = "the previous handlers are restored as soon as the returned guard is dropped"]
    pub fn ignore(signals: impl IntoIterator<Item = SignalKind>) -> Option<Self> {
        Self::new_impl_with_fallback(
            signals.into_iter(),
            None,
            libc::SIG_IGN as libc::sighandler_t,
        )
    }

    /// Sets every signal in `signals` to `SIG_DFL` for the lifetime of the returned guard.
    ///
    /// Returns `None` if `libc::signal` reports `SIG_ERR` for any of them. In that case,
    /// every signal this call had already changed is restored before returning.
    #[must_use = "the previous handlers are restored as soon as the returned guard is dropped"]
    pub fn default(signals: impl IntoIterator<Item = SignalKind>) -> Option<Self> {
        Self::new_impl_with_fallback(
            signals.into_iter(),
            None,
            libc::SIG_DFL as libc::sighandler_t,
        )
    }

    /// Installs a handler for each signal in `signals` and remembers the handler it replaced.
    ///
    /// The handler for a signal is taken from `keys` when present. Signals missing from
    /// `keys`, or every signal when `keys` is `None`, get `fallback`.
    ///
    /// Returns `None` as soon as `libc::signal` reports `SIG_ERR`. The partially built
    /// guard is dropped on that path, so handlers already replaced by this call are
    /// reinstalled rather than left modified.
    fn new_impl_with_fallback(
        signals: impl Iterator<Item = SignalKind>,
        keys: Option<&HashMap<SignalKind, SignalHandler>>,
        fallback: libc::sighandler_t,
    ) -> Option<Self> {
        let handler_for = |kind: SignalKind| {
            keys.and_then(|map| map.get(&kind))
                .map(|&handler| handler as libc::sighandler_t)
                .unwrap_or(fallback)
        };

        // Create the guard before touching any handler, so that the early `return None`
        // below drops it and its `Drop` impl rolls back the changes made so far.
        let mut guard = Self {
            stashed_signals: HashMap::new(),
        };
        for signal in signals {
            // SAFETY: `libc::signal` receives a signal number from `SignalKind` and either
            // `fallback` or an `extern "C" fn(c_int)` taken from `keys`. Its return value is
            // checked against `SIG_ERR` immediately below.
            let previous = unsafe { libc::signal(signal.as_raw(), handler_for(signal)) };
            if previous == libc::SIG_ERR as libc::sighandler_t {
                return None;
            }
            // For a signal listed more than once, `previous` on later occurrences is the
            // handler this very call installed. Keep the first stashed value, which is the
            // real original.
            guard.stashed_signals.entry(signal).or_insert(previous);
        }
        Some(guard)
    }
}
impl Drop for SignalGuard {
    /// Reinstalls every handler recorded in `stashed_signals`.
    ///
    /// `drop` has no way to report an error, so a failed reinstall is silently skipped.
    fn drop(&mut self) {
        for (&signal, &previous) in &self.stashed_signals {
            // SAFETY: `previous` is a value that `libc::signal` itself returned for this
            // same signal number when the guard was built.
            let _ = unsafe { libc::signal(signal.as_raw(), previous) };
        }
    }
}
#[cfg(test)]
mod tests {
    // The only test is unix-gated, so this import is unused on other targets.
    #[allow(unused_imports)]
    use super::*;

    /// Raising SIGTERM while a guard has it set to `SIG_IGN` must not kill the test process.
    #[test]
    #[cfg(unix)]
    fn basic_sigterm() {
        let term = SignalKind::term();
        let _guard = SignalGuard::ignore([term]).expect("failed to ignore SIGTERM");
        // SAFETY: `_guard` keeps SIGTERM set to `SIG_IGN` for the rest of this scope, so
        // raising it does not terminate the process.
        let rc = unsafe { libc::raise(term.as_raw()) };
        // Check `raise` actually succeeded; otherwise the test would pass without the
        // signal ever being delivered.
        assert_eq!(rc, 0, "raise(SIGTERM) failed");
    }
}