use std::{
io,
mem::MaybeUninit,
os::{
fd::{AsRawFd, RawFd},
unix::net::UnixStream,
},
sync::{
atomic::{AtomicU8, Ordering},
Arc,
},
};
use crate::cutils::cerr;
use libc::{c_int, siginfo_t, MSG_DONTWAIT};
use signal_hook::low_level::{emulate_default_handler, signal_name};
use signal_hook_registry::{register_sigaction, unregister, SigId, FORBIDDEN};
use super::interface::ProcessId;
/// Size in bytes of one `siginfo_t` record; this is exactly how many bytes `send` writes to,
/// and `SignalHandler::recv` expects to read from, the handler's socket per delivered signal.
const SIGINFO_SIZE: usize = std::mem::size_of::<siginfo_t>();
/// A signal number as understood by libc (a `c_int`).
pub type SignalNumber = c_int;
/// Details about one delivered signal, wrapping the raw `siginfo_t` that was forwarded
/// through a [`SignalHandler`]'s socket.
pub struct SignalInfo {
    info: siginfo_t,
}

impl SignalInfo {
    /// Returns `true` when the signal's `si_code` is zero or negative.
    ///
    /// Non-positive codes are used for signals raised from user space rather than by the
    /// kernel. NOTE(review): this also counts codes such as timer/queue-originated
    /// signals as "user signaled" — confirm that is the intended classification.
    pub fn is_user_signaled(&self) -> bool {
        let code = self.info.si_code;
        code <= 0
    }

    /// Returns the `si_pid` field of the signal information.
    ///
    /// The value is only meaningful for signals that carry sender information.
    pub fn pid(&self) -> ProcessId {
        // SAFETY: `si_pid` reads a field of the `siginfo_t` union. The wrapped value is a
        // complete `siginfo_t`, so the read is in bounds. Its meaning depends on the signal.
        unsafe { self.info.si_pid() }
    }

    /// Returns the number of the delivered signal (`si_signo`).
    pub fn signal(&self) -> SignalNumber {
        let Self { info } = self;
        info.si_signo
    }
}
/// What a registered signal handler does each time its signal is delivered.
///
/// The action is stored as its `u8` discriminant in an `AtomicU8` shared with the signal
/// handler closure, so the discriminants are explicit and must stay stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SignalAction {
    /// Forward the raw `siginfo_t` through the handler's socket so it can be `recv`ed.
    Stream = 0,
    /// Emulate the signal's default disposition.
    Default = 1,
    /// Do nothing when the signal arrives.
    Ignore = 2,
}

impl SignalAction {
    /// Decodes a byte previously produced by `action as u8`.
    ///
    /// Returns `None` for any value that is not one of this enum's discriminants.
    fn try_new(val: u8) -> Option<Self> {
        // Named consts let the match stay in sync with the `#[repr(u8)]` discriminants.
        const STREAM: u8 = SignalAction::Stream as u8;
        const DEFAULT: u8 = SignalAction::Default as u8;
        const IGNORE: u8 = SignalAction::Ignore as u8;

        match val {
            STREAM => Some(Self::Stream),
            DEFAULT => Some(Self::Default),
            IGNORE => Some(Self::Ignore),
            _ => None,
        }
    }
}
/// A signal-hook registration for a single signal.
///
/// Depending on the shared `action`, each delivery is forwarded over a Unix socket pair
/// (readable through `rx`), handled with the default disposition, or ignored.
pub struct SignalHandler {
    /// The signal this handler was registered for.
    signal: SignalNumber,
    /// Identifier returned by `register_sigaction`; passed to `unregister`.
    sig_id: SigId,
    /// Receiving end of the socket pair; the sending end is owned by the registered closure.
    rx: UnixStream,
    /// Current `SignalAction` encoded as `u8`, shared with the registered closure.
    action: Arc<AtomicU8>,
}
impl Drop for SignalHandler {
    /// Switches the registered action to `SignalAction::Default` so the signal falls back to
    /// its default behavior once this handler is gone.
    ///
    /// NOTE(review): this does not unregister the closure, so the registration and the
    /// sending socket it owns stay alive for the rest of the process.
    fn drop(&mut self) {
        let _previous = self.set_action(SignalAction::Default);
    }
}
impl AsRawFd for SignalHandler {
    /// Returns the receiving end of the handler's socket.
    ///
    /// Presumably this is used so callers can poll for readiness before calling `recv`.
    fn as_raw_fd(&self) -> RawFd {
        let Self { rx, .. } = self;
        rx.as_raw_fd()
    }
}
impl SignalHandler {
    /// Registers a handler that streams every delivery of `signal` through its socket.
    ///
    /// Equivalent to `SignalHandler::with_action(signal, SignalAction::Stream)`.
    ///
    /// # Errors
    /// Fails if the socket pair cannot be created or the action cannot be registered.
    ///
    /// # Panics
    /// Panics if `signal` is in signal-hook's `FORBIDDEN` list.
    pub fn new(signal: SignalNumber) -> io::Result<Self> {
        Self::with_action(signal, SignalAction::Stream)
    }

    /// Registers a handler for `signal` whose initial behavior is `action`.
    ///
    /// A connected Unix socket pair is created. The sending end is moved into the registered
    /// closure, and the receiving end is kept here and read by [`recv`](Self::recv). The action
    /// lives in an atomic shared with the closure, so [`set_action`](Self::set_action) can
    /// change it after registration.
    ///
    /// # Errors
    /// Fails if the socket pair cannot be created or the action cannot be registered.
    ///
    /// # Panics
    /// Panics if `signal` is in signal-hook's `FORBIDDEN` list.
    pub fn with_action(signal: SignalNumber, action: SignalAction) -> io::Result<Self> {
        if FORBIDDEN.contains(&signal) {
            panic!(
                "SignalHandler cannot be used to handle the forbidden {} signal",
                signal_name(signal).unwrap()
            );
        }
        let (rx, tx) = UnixStream::pair()?;
        let action = Arc::new(AtomicU8::from(action as u8));
        let sig_id = {
            let action = Arc::clone(&action);
            // SAFETY: the closure runs in signal-handler context. It only performs one of:
            // - an atomic load;
            // - a non-blocking `send(2)` on a socket it owns;
            // - a call to `emulate_default_handler` (documented by signal-hook as
            //   async-signal-safe).
            // It never unwraps or allocates.
            unsafe {
                register_sigaction(signal, move |info| {
                    if let Some(action) = SignalAction::try_new(action.load(Ordering::SeqCst)) {
                        match action {
                            SignalAction::Stream => send(&tx, info),
                            SignalAction::Default => {
                                emulate_default_handler(signal).ok();
                            }
                            SignalAction::Ignore => {}
                        }
                    }
                })
            }?
        };
        Ok(Self {
            signal,
            sig_id,
            rx,
            action,
        })
    }

    /// Atomically replaces the current action and returns the previous one.
    ///
    /// A stored byte that does not decode to a `SignalAction` is reported as `Ignore`.
    pub fn set_action(&self, action: SignalAction) -> SignalAction {
        SignalAction::try_new(self.action.swap(action as u8, Ordering::SeqCst))
            .unwrap_or(SignalAction::Ignore)
    }

    /// Reads the next `siginfo_t` record forwarded by the registered closure.
    ///
    /// Blocks unless the socket has been made non-blocking.
    ///
    /// # Errors
    /// Returns the OS error if `recv(2)` fails (for example `EINTR`). Returns an
    /// `UnexpectedEof` error if a full `siginfo_t` could not be read, for example because
    /// the peer was closed or the call was interrupted after a partial read.
    pub fn recv(&mut self) -> io::Result<SignalInfo> {
        let mut info = MaybeUninit::<siginfo_t>::uninit();
        let fd = self.rx.as_raw_fd();
        // The socket pair is a stream socket, so a plain `recv` may legally return only part
        // of a record, which would be reported as an error and desynchronize the stream.
        // `MSG_WAITALL` asks the kernel to wait until the whole `siginfo_t` has arrived.
        // SAFETY: `info` provides `SIGINFO_SIZE` writable bytes, and `recv` writes at most
        // the requested length into it.
        let bytes = cerr(unsafe {
            libc::recv(
                fd,
                info.as_mut_ptr().cast(),
                SIGINFO_SIZE,
                libc::MSG_WAITALL,
            )
        })?;
        if bytes as usize != SIGINFO_SIZE {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "Not enough bytes when receiving `siginfo_t`",
            ));
        }
        // SAFETY: exactly `SIGINFO_SIZE` bytes, i.e. a complete `siginfo_t` written by `send`,
        // were stored in `info` above.
        let info = unsafe { info.assume_init() };
        Ok(SignalInfo { info })
    }

    /// Returns the signal this handler was registered for.
    pub fn signal(&self) -> SignalNumber {
        self.signal
    }

    /// Removes this handler's registered action from signal-hook-registry.
    ///
    /// NOTE(review): per signal-hook-registry's documentation, removing the last action
    /// for a signal does not restore its default disposition. `Drop` therefore only switches
    /// the action to `Default` instead of calling this.
    pub fn unregister(&self) {
        unregister(self.sig_id);
    }
}
/// Writes the raw bytes of `info` to `tx` without blocking.
///
/// This is called from the closure registered in `SignalHandler::with_action`, i.e. in
/// signal-handler context. The result is therefore deliberately discarded: if the socket
/// buffer is full or the write fails, the notification is dropped. The handler never
/// blocks or panics here.
fn send(tx: &UnixStream, info: &siginfo_t) {
    let record: *const siginfo_t = info;
    // SAFETY: `record` points to a live `siginfo_t` of `SIGINFO_SIZE` bytes for the duration
    // of the call, and `send(2)` only reads from it.
    let _ = unsafe { libc::send(tx.as_raw_fd(), record.cast(), SIGINFO_SIZE, MSG_DONTWAIT) };
}