use mio::{
event::{self},
unix::SourceFd,
Interest, Registry, Token,
};
use std::{fmt, mem, os::unix::prelude::RawFd, str::FromStr};
use crate::{Error, Result};
/// A signal this module knows how to name, parse, block and receive.
///
/// The enum is `#[repr(i32)]` and every discriminant is the corresponding
/// `libc` signal number, so `signal as libc::c_int` yields the raw value
/// expected by `sigaddset`, `sigprocmask`, etc.
///
/// Keep this variant list in sync with `as_str`, the `FromStr` impl and the
/// `TryFrom<libc::c_int>` impl.
///
/// NOTE(review): `SIGSTKFLT` and `SIGPWR` are Linux-specific `libc`
/// constants — presumably this module is only built for Linux; confirm.
#[repr(i32)]
// More signals may be added later; code outside this crate must match with a wildcard arm.
#[non_exhaustive]
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum Signal {
SIGHUP = libc::SIGHUP,
SIGINT = libc::SIGINT,
SIGQUIT = libc::SIGQUIT,
SIGILL = libc::SIGILL,
SIGTRAP = libc::SIGTRAP,
SIGABRT = libc::SIGABRT,
SIGBUS = libc::SIGBUS,
SIGFPE = libc::SIGFPE,
SIGKILL = libc::SIGKILL,
SIGUSR1 = libc::SIGUSR1,
SIGSEGV = libc::SIGSEGV,
SIGUSR2 = libc::SIGUSR2,
SIGPIPE = libc::SIGPIPE,
SIGALRM = libc::SIGALRM,
SIGTERM = libc::SIGTERM,
SIGSTKFLT = libc::SIGSTKFLT,
SIGCHLD = libc::SIGCHLD,
SIGCONT = libc::SIGCONT,
SIGSTOP = libc::SIGSTOP,
SIGTSTP = libc::SIGTSTP,
SIGTTIN = libc::SIGTTIN,
SIGTTOU = libc::SIGTTOU,
SIGURG = libc::SIGURG,
SIGXCPU = libc::SIGXCPU,
SIGXFSZ = libc::SIGXFSZ,
SIGVTALRM = libc::SIGVTALRM,
SIGPROF = libc::SIGPROF,
SIGWINCH = libc::SIGWINCH,
SIGIO = libc::SIGIO,
SIGPWR = libc::SIGPWR,
SIGSYS = libc::SIGSYS,
}
impl Signal {
pub const fn as_str(self) -> &'static str {
match self {
Signal::SIGHUP => "SIGHUP",
Signal::SIGINT => "SIGINT",
Signal::SIGQUIT => "SIGQUIT",
Signal::SIGILL => "SIGILL",
Signal::SIGTRAP => "SIGTRAP",
Signal::SIGABRT => "SIGABRT",
Signal::SIGBUS => "SIGBUS",
Signal::SIGFPE => "SIGFPE",
Signal::SIGKILL => "SIGKILL",
Signal::SIGUSR1 => "SIGUSR1",
Signal::SIGSEGV => "SIGSEGV",
Signal::SIGUSR2 => "SIGUSR2",
Signal::SIGPIPE => "SIGPIPE",
Signal::SIGALRM => "SIGALRM",
Signal::SIGTERM => "SIGTERM",
Signal::SIGSTKFLT => "SIGSTKFLT",
Signal::SIGCHLD => "SIGCHLD",
Signal::SIGCONT => "SIGCONT",
Signal::SIGSTOP => "SIGSTOP",
Signal::SIGTSTP => "SIGTSTP",
Signal::SIGTTIN => "SIGTTIN",
Signal::SIGTTOU => "SIGTTOU",
Signal::SIGURG => "SIGURG",
Signal::SIGXCPU => "SIGXCPU",
Signal::SIGXFSZ => "SIGXFSZ",
Signal::SIGVTALRM => "SIGVTALRM",
Signal::SIGPROF => "SIGPROF",
Signal::SIGWINCH => "SIGWINCH",
Signal::SIGIO => "SIGIO",
Signal::SIGPWR => "SIGPWR",
Signal::SIGSYS => "SIGSYS",
}
}
}
impl FromStr for Signal {
type Err = Error;
fn from_str(s: &str) -> Result<Signal> {
Ok(match s {
"SIGHUP" => Signal::SIGHUP,
"SIGINT" => Signal::SIGINT,
"SIGQUIT" => Signal::SIGQUIT,
"SIGILL" => Signal::SIGILL,
"SIGTRAP" => Signal::SIGTRAP,
"SIGABRT" => Signal::SIGABRT,
"SIGBUS" => Signal::SIGBUS,
"SIGFPE" => Signal::SIGFPE,
"SIGKILL" => Signal::SIGKILL,
"SIGUSR1" => Signal::SIGUSR1,
"SIGSEGV" => Signal::SIGSEGV,
"SIGUSR2" => Signal::SIGUSR2,
"SIGPIPE" => Signal::SIGPIPE,
"SIGALRM" => Signal::SIGALRM,
"SIGTERM" => Signal::SIGTERM,
"SIGSTKFLT" => Signal::SIGSTKFLT,
"SIGCHLD" => Signal::SIGCHLD,
"SIGCONT" => Signal::SIGCONT,
"SIGSTOP" => Signal::SIGSTOP,
"SIGTSTP" => Signal::SIGTSTP,
"SIGTTIN" => Signal::SIGTTIN,
"SIGTTOU" => Signal::SIGTTOU,
"SIGURG" => Signal::SIGURG,
"SIGXCPU" => Signal::SIGXCPU,
"SIGXFSZ" => Signal::SIGXFSZ,
"SIGVTALRM" => Signal::SIGVTALRM,
"SIGPROF" => Signal::SIGPROF,
"SIGWINCH" => Signal::SIGWINCH,
"SIGIO" => Signal::SIGIO,
"SIGPWR" => Signal::SIGPWR,
"SIGSYS" => Signal::SIGSYS,
_ => {
return Err(Error::Syscall(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"invalid signal",
)))
}
})
}
}
impl TryFrom<libc::c_int> for Signal {
type Error = Error;
fn try_from(signum: libc::c_int) -> std::result::Result<Self, Self::Error> {
if 0 < signum && signum < 32 {
Ok(unsafe { mem::transmute(signum) })
} else {
Err(Error::Syscall(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"invalid signal number",
)))
}
}
}
impl AsRef<str> for Signal {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl fmt::Display for Signal {
    /// Writes the signal's canonical name, e.g. `SIGTERM`.
    ///
    /// Uses [`fmt::Formatter::pad`] instead of `write_str`. Width, alignment
    /// and precision flags (`{:>10}`, `{:<8}`, `{:.4}`) are then honoured
    /// exactly as they are for a plain `str`. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.pad(self.as_str())
    }
}
/// A set of signals backed by a `libc::sigset_t`.
///
/// Sets are always initialised through `sigemptyset`/`sigfillset` and modified
/// with `sigaddset`/`sigdelset`. Equality compares the wrapped `sigset_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SignalSet(libc::sigset_t);

impl SignalSet {
    /// Returns a set initialised by `sigfillset`, i.e. containing every signal.
    ///
    /// # Errors
    ///
    /// Propagates a failure reported by `sigfillset`.
    pub fn fill() -> Result<SignalSet> {
        let mut raw = mem::MaybeUninit::<libc::sigset_t>::uninit();
        syscall!(sigfillset(raw.as_mut_ptr()))?;
        // SAFETY: `sigfillset` succeeded, so it has fully initialised `raw`.
        let raw = unsafe { raw.assume_init() };
        Ok(SignalSet(raw))
    }

    /// Returns a set initialised by `sigemptyset`, i.e. containing no signals.
    ///
    /// # Errors
    ///
    /// Propagates a failure reported by `sigemptyset`.
    pub fn empty() -> Result<SignalSet> {
        let mut raw = mem::MaybeUninit::<libc::sigset_t>::uninit();
        syscall!(sigemptyset(raw.as_mut_ptr()))?;
        // SAFETY: `sigemptyset` succeeded, so it has fully initialised `raw`.
        let raw = unsafe { raw.assume_init() };
        Ok(SignalSet(raw))
    }

    /// Inserts `signal` into the set (via `sigaddset`).
    ///
    /// # Errors
    ///
    /// Propagates a failure reported by `sigaddset`.
    pub fn add(&mut self, signal: Signal) -> Result<()> {
        let signum = signal as libc::c_int;
        let raw: *mut libc::sigset_t = &mut self.0;
        syscall!(sigaddset(raw, signum))?;
        Ok(())
    }

    /// Removes `signal` from the set (via `sigdelset`). Removing a signal that
    /// is not a member leaves the set unchanged.
    ///
    /// # Errors
    ///
    /// Propagates a failure reported by `sigdelset`.
    pub fn remove(&mut self, signal: Signal) -> Result<()> {
        let signum = signal as libc::c_int;
        let raw: *mut libc::sigset_t = &mut self.0;
        syscall!(sigdelset(raw, signum))?;
        Ok(())
    }
}
impl AsRef<libc::sigset_t> for SignalSet {
fn as_ref(&self) -> &libc::sigset_t {
&self.0
}
}
impl From<&[Signal]> for SignalSet {
    /// Builds a set containing exactly the given signals.
    ///
    /// Duplicates in `signals` are harmless; an empty slice yields an empty set.
    ///
    /// # Panics
    ///
    /// Panics with `"syscall failed"` if `sigemptyset` or `sigaddset` reports
    /// an error.
    fn from(signals: &[Signal]) -> Self {
        let mut set = SignalSet::empty().expect("syscall failed");
        for &signal in signals {
            set.add(signal).expect("syscall failed");
        }
        set
    }
}
/// Descriptor argument for `signalfd(2)` meaning "create a new signalfd"
/// rather than update the mask of an existing one.
const SIGNALFD_NEW: libc::c_int = -1;
/// Owning wrapper around a raw signalfd descriptor.
///
/// The descriptor is closed when the value is dropped.
#[derive(Debug, PartialEq, Eq)]
pub struct SignalFd(RawFd);
impl SignalFd {
pub fn new(signals: SignalSet) -> Result<SignalFd> {
let fd = syscall!(signalfd(
SIGNALFD_NEW,
signals.as_ref() as *const libc::sigset_t,
0
))?;
Ok(SignalFd(fd))
}
pub fn read_signal(&mut self) -> Result<Signal> {
let mut siginfo = mem::MaybeUninit::<libc::signalfd_siginfo>::uninit();
let size = mem::size_of_val(&siginfo);
let num = syscall!(read(
self.0,
siginfo.as_mut_ptr() as *mut libc::c_void,
size
))?;
if num as usize != size {
return Err(Error::Syscall(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"invalid signal",
)));
}
let siginfo = unsafe { siginfo.assume_init() };
let signum = siginfo.ssi_signo as libc::c_int;
signum.try_into()
}
}
impl event::Source for SignalFd {
    /// Registers the wrapped descriptor with `registry`, delegating to mio's
    /// [`SourceFd`] adapter.
    fn register(
        &mut self,
        registry: &Registry,
        token: Token,
        interests: Interest,
    ) -> std::io::Result<()> {
        let mut fd_source = SourceFd(&self.0);
        fd_source.register(registry, token, interests)
    }

    /// Updates the token/interests of an already-registered descriptor,
    /// delegating to [`SourceFd`].
    fn reregister(
        &mut self,
        registry: &Registry,
        token: Token,
        interests: Interest,
    ) -> std::io::Result<()> {
        let mut fd_source = SourceFd(&self.0);
        fd_source.reregister(registry, token, interests)
    }

    /// Removes the wrapped descriptor from `registry`, delegating to
    /// [`SourceFd`].
    fn deregister(&mut self, registry: &Registry) -> std::io::Result<()> {
        let mut fd_source = SourceFd(&self.0);
        fd_source.deregister(registry)
    }
}
impl Drop for SignalFd {
fn drop(&mut self) {
unsafe { libc::close(self.0) };
if std::io::Error::last_os_error().raw_os_error().unwrap_or(0) == libc::EBADF {
panic!("closing invalid signal fd");
}
}
}
pub fn signal_block(set: SignalSet) -> Result<SignalSet> {
let mut old = SignalSet::empty()?;
syscall!(sigprocmask(
libc::SIG_BLOCK,
&set.0 as *const libc::sigset_t,
&mut old.0 as &mut libc::sigset_t
))?;
Ok(old)
}
pub fn signal_restore(set: SignalSet) -> Result<SignalSet> {
let mut old = SignalSet::empty()?;
syscall!(sigprocmask(
libc::SIG_SETMASK,
&set.0 as *const libc::sigset_t,
&mut old.0 as &mut libc::sigset_t
))?;
Ok(old)
}
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::{signal_block, signal_restore, Signal, SignalFd, SignalSet};

    #[test]
    fn signal_set_add() -> Result<()> {
        let mut set = SignalSet::empty()?;
        set.add(Signal::SIGCHLD)?;
        set.add(Signal::SIGPIPE)?;
        assert_ne!(set, SignalSet::empty()?);
        assert_ne!(set, SignalSet::fill()?);
        Ok(())
    }

    #[test]
    fn signal_set_remove() -> Result<()> {
        let mut set = SignalSet::empty()?;
        set.add(Signal::SIGHUP)?;
        set.add(Signal::SIGQUIT)?;
        set.remove(Signal::SIGHUP)?;
        set.remove(Signal::SIGQUIT)?;
        assert_eq!(set, SignalSet::empty()?);
        assert_ne!(set, SignalSet::fill()?);
        Ok(())
    }

    #[test]
    fn signal_set_remove_unknown() -> Result<()> {
        let mut set = SignalSet::empty()?;
        set.remove(Signal::SIGCHLD)?;
        set.add(Signal::SIGHUP)?;
        set.remove(Signal::SIGCHLD)?;
        set.remove(Signal::SIGHUP)?;
        assert_eq!(set, SignalSet::empty()?);
        assert_ne!(set, SignalSet::fill()?);
        Ok(())
    }

    #[test]
    fn signal_try_from() -> Result<()> {
        let signum = Signal::SIGQUIT as libc::c_int;
        let sig: Signal = signum.try_into()?;
        assert_eq!(signum, libc::SIGQUIT);
        assert_eq!(sig, Signal::SIGQUIT);
        let res: std::result::Result<Signal, _> = (255 as libc::c_int).try_into();
        assert_eq!(
            format!("{:?}", res.err().unwrap()),
            "Syscall(Custom { kind: InvalidData, error: \"invalid signal number\" })"
        );
        Ok(())
    }

    /// Name, number and `Display` conversions must agree for every variant.
    #[test]
    fn signal_round_trip() -> Result<()> {
        let all = [
            Signal::SIGHUP,
            Signal::SIGINT,
            Signal::SIGQUIT,
            Signal::SIGILL,
            Signal::SIGTRAP,
            Signal::SIGABRT,
            Signal::SIGBUS,
            Signal::SIGFPE,
            Signal::SIGKILL,
            Signal::SIGUSR1,
            Signal::SIGSEGV,
            Signal::SIGUSR2,
            Signal::SIGPIPE,
            Signal::SIGALRM,
            Signal::SIGTERM,
            Signal::SIGSTKFLT,
            Signal::SIGCHLD,
            Signal::SIGCONT,
            Signal::SIGSTOP,
            Signal::SIGTSTP,
            Signal::SIGTTIN,
            Signal::SIGTTOU,
            Signal::SIGURG,
            Signal::SIGXCPU,
            Signal::SIGXFSZ,
            Signal::SIGVTALRM,
            Signal::SIGPROF,
            Signal::SIGWINCH,
            Signal::SIGIO,
            Signal::SIGPWR,
            Signal::SIGSYS,
        ];
        for sig in all {
            let by_name: Signal = sig.as_str().parse()?;
            assert_eq!(by_name, sig);
            let by_number: Signal = (sig as libc::c_int).try_into()?;
            assert_eq!(by_number, sig);
            assert_eq!(sig.to_string(), sig.as_str());
        }
        assert!("SIGNOPE".parse::<Signal>().is_err());
        assert!(Signal::try_from(0 as libc::c_int).is_err());
        Ok(())
    }

    #[test]
    fn block_signals() -> Result<()> {
        let signals = vec![
            Signal::SIGCHLD,
            Signal::SIGINT,
            Signal::SIGQUIT,
            Signal::SIGTERM,
        ];
        let old = signal_block(signals.as_slice().into())?;
        assert_eq!(old, SignalSet::empty()?);
        assert_ne!(old, SignalSet::fill()?);
        let blocked = signal_block(old)?;
        // Put the original mask back so later tests sharing this thread do not
        // start with these signals blocked.
        signal_restore(old)?;
        assert_eq!(blocked, signals.as_slice().into());
        Ok(())
    }

    #[test]
    fn restore_signals() -> Result<()> {
        let signals = vec![
            Signal::SIGCHLD,
            Signal::SIGINT,
            Signal::SIGQUIT,
            Signal::SIGTERM,
        ];
        let old = signal_block(signals.as_slice().into())?;
        let blocked = signal_restore(old)?;
        assert_eq!(blocked, signals.as_slice().into());
        Ok(())
    }

    #[test]
    fn signalfd_new() -> Result<()> {
        let _ = SignalFd::new(SignalSet::empty()?)?;
        Ok(())
    }

    #[test]
    #[should_panic(expected = "closing invalid signal fd")]
    fn signalfd_drop_invalid() {
        let fake = SignalFd(-1);
        drop(fake);
    }
}