use libc::{
c_int, c_void, pthread_kill, pthread_sigmask, pthread_t, sigaction, sigaddset, sigemptyset,
sigfillset, siginfo_t, sigismember, sigpending, sigset_t, sigtimedwait, timespec, EAGAIN,
EINTR, EINVAL, SIG_BLOCK, SIG_UNBLOCK,
};
use crate::errno;
use std::fmt::{self, Display};
use std::io;
use std::mem;
use std::os::unix::thread::JoinHandleExt;
use std::ptr::{null, null_mut};
use std::result;
use std::thread::JoinHandle;
/// Errors produced by the signal-mask helpers in this module.
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
    /// Building a `sigset_t` (via `create_sigset`) failed.
    CreateSigset(errno::Error),
    /// `block_signal` found the given signal number already present in the old mask.
    SignalAlreadyBlocked(c_int),
    /// `sigismember` failed while inspecting the previously blocked set.
    CompareBlockedSignals(errno::Error),
    /// `pthread_sigmask(SIG_BLOCK, ...)` failed.
    BlockSignal(errno::Error),
    /// `pthread_sigmask` failed while only querying the mask; holds the raw error number.
    RetrieveSignalMask(c_int),
    /// `pthread_sigmask(SIG_UNBLOCK, ...)` failed.
    UnblockSignal(errno::Error),
    /// `sigtimedwait` failed with something other than EAGAIN/EINTR in `clear_signal`.
    ClearWaitPending(errno::Error),
    /// `sigpending` failed in `clear_signal`.
    ClearGetPending(errno::Error),
    /// `sigismember` on the pending set failed in `clear_signal`.
    ClearCheckPending(errno::Error),
}
impl Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use self::Error::*;
match self {
CreateSigset(e) => write!(f, "couldn't create a sigset: {}", e),
SignalAlreadyBlocked(num) => write!(f, "signal {} already blocked", num),
CompareBlockedSignals(e) => write!(
f,
"failed to check whether requested signal is in the blocked set: {}",
e,
),
BlockSignal(e) => write!(f, "signal could not be blocked: {}", e),
RetrieveSignalMask(errno) => write!(
f,
"failed to retrieve signal mask: {}",
io::Error::from_raw_os_error(*errno),
),
UnblockSignal(e) => write!(f, "signal could not be unblocked: {}", e),
ClearWaitPending(e) => write!(f, "failed to wait for given signal: {}", e),
ClearGetPending(e) => write!(f, "failed to get pending signals: {}", e),
ClearCheckPending(e) => write!(
f,
"failed to check whether given signal is in the pending set: {}",
e,
),
}
}
}
/// Result type used by the signal-mask helpers in this module.
pub type SignalResult<T> = result::Result<T, Error>;

/// Signature of a handler installed by `register_signal_handler` (which sets `SA_SIGINFO`):
/// the signal number, a pointer to its `siginfo_t`, and a third pointer argument that is
/// not used here.
pub type SignalHandler = extern "C" fn(num: c_int, info: *mut siginfo_t, _unused: *mut c_void);
// C-library functions that report the real-time signal range. They are declared
// manually here rather than taken from the `libc` crate.
extern "C" {
    fn __libc_current_sigrtmin() -> c_int;
    fn __libc_current_sigrtmax() -> c_int;
}

/// Lowest real-time signal number, as reported by `__libc_current_sigrtmin()`.
///
/// This is a function rather than a constant because the value is obtained from the
/// C library at call time.
#[allow(non_snake_case)] // Mirrors the C macro name `SIGRTMIN`.
pub fn SIGRTMIN() -> c_int {
    // SAFETY: takes no arguments and only returns an integer.
    unsafe { __libc_current_sigrtmin() }
}

/// Highest real-time signal number, as reported by `__libc_current_sigrtmax()`.
#[allow(non_snake_case)] // Mirrors the C macro name `SIGRTMAX`.
pub fn SIGRTMAX() -> c_int {
    // SAFETY: takes no arguments and only returns an integer.
    unsafe { __libc_current_sigrtmax() }
}
pub fn validate_signal_num(num: c_int) -> errno::Result<()> {
if (libc::SIGHUP..=libc::SIGSYS).contains(&num) || (SIGRTMIN() <= num && num <= SIGRTMAX()) {
Ok(())
} else {
Err(errno::Error::new(EINVAL))
}
}
/// Installs `handler` for signal `num` using `sigaction` with `SA_SIGINFO`.
///
/// Every signal is blocked while the handler runs (the action's mask is filled).
///
/// # Errors
/// - `EINVAL` if `num` fails `validate_signal_num` or is `SIGKILL`/`SIGSTOP`.
/// - The errno left by `sigfillset` or `sigaction` if either call fails.
pub fn register_signal_handler(num: c_int, handler: SignalHandler) -> errno::Result<()> {
    validate_signal_num(num)?;
    // These two signals can never have a handler installed.
    if matches!(num, libc::SIGKILL | libc::SIGSTOP) {
        return Err(errno::Error::new(EINVAL));
    }

    // SAFETY: `sigaction` is a plain C struct; all-zero bytes is a valid starting value.
    let mut action: sigaction = unsafe { mem::zeroed() };
    action.sa_sigaction = handler as *const () as usize;
    action.sa_flags = libc::SA_SIGINFO;

    // SAFETY: `action.sa_mask` is a valid, writable sigset_t.
    let fill_ret = unsafe { sigfillset(&mut action.sa_mask as *mut sigset_t) };
    if fill_ret < 0 {
        return errno::errno_result();
    }

    // SAFETY: `action` is fully initialized; a null old-action pointer is permitted.
    let install_ret = unsafe { sigaction(num, &action, null_mut()) };
    if install_ret == 0 {
        Ok(())
    } else {
        errno::errno_result()
    }
}
/// Builds a `sigset_t` containing exactly the signals in `signals`.
///
/// # Errors
/// Returns the errno left by `sigemptyset` or by the first failing `sigaddset`
/// (e.g. for an invalid signal number).
pub fn create_sigset(signals: &[c_int]) -> errno::Result<sigset_t> {
    // SAFETY: sigset_t is plain data; it is properly initialized by sigemptyset below.
    let mut set: sigset_t = unsafe { mem::zeroed() };
    // SAFETY: `set` is a valid, writable sigset_t.
    if unsafe { sigemptyset(&mut set) } < 0 {
        return errno::errno_result();
    }

    for &sig in signals {
        // SAFETY: `set` was initialized by sigemptyset above.
        if unsafe { sigaddset(&mut set, sig) } < 0 {
            return errno::errno_result();
        }
    }

    Ok(set)
}
/// Returns, in ascending order, the signal numbers in `1..=SIGRTMAX()` that are blocked
/// in the calling thread's signal mask.
///
/// # Errors
/// `Error::RetrieveSignalMask` carrying the error number returned by `pthread_sigmask`.
pub fn get_blocked_signals() -> SignalResult<Vec<c_int>> {
    // SAFETY: sigset_t is plain data; pthread_sigmask overwrites it on success.
    let mut current: sigset_t = unsafe { mem::zeroed() };
    // SAFETY: a null `set` only queries the mask; `current` is a valid out-pointer.
    let ret = unsafe { pthread_sigmask(SIG_BLOCK, null(), &mut current as *mut sigset_t) };
    // pthread_sigmask reports failure by *returning* the error number (it never returns a
    // negative value and does not set errno), so any non-zero value is an error.
    if ret != 0 {
        return Err(Error::RetrieveSignalMask(ret));
    }

    // Signal 0 is not a valid signal (sigismember would fail with EINVAL and clobber errno),
    // so the scan starts at 1.
    let blocked = (1..=SIGRTMAX())
        // SAFETY: `current` was filled in by pthread_sigmask; sigismember only reads it.
        .filter(|&num| unsafe { sigismember(&current, num) } > 0)
        .collect();
    Ok(blocked)
}
#[allow(clippy::comparison_chain)]
pub fn block_signal(num: c_int) -> SignalResult<()> {
let sigset = create_sigset(&[num]).map_err(Error::CreateSigset)?;
unsafe {
let mut old_sigset: sigset_t = mem::zeroed();
let ret = pthread_sigmask(SIG_BLOCK, &sigset, &mut old_sigset as *mut sigset_t);
if ret < 0 {
return Err(Error::BlockSignal(errno::Error::last()));
}
let ret = sigismember(&old_sigset, num);
if ret < 0 {
return Err(Error::CompareBlockedSignals(errno::Error::last()));
} else if ret > 0 {
return Err(Error::SignalAlreadyBlocked(num));
}
}
Ok(())
}
pub fn unblock_signal(num: c_int) -> SignalResult<()> {
let sigset = create_sigset(&[num]).map_err(Error::CreateSigset)?;
let ret = unsafe { pthread_sigmask(SIG_UNBLOCK, &sigset, null_mut()) };
if ret < 0 {
return Err(Error::UnblockSignal(errno::Error::last()));
}
Ok(())
}
pub fn clear_signal(num: c_int) -> SignalResult<()> {
let sigset = create_sigset(&[num]).map_err(Error::CreateSigset)?;
while {
unsafe {
let mut siginfo: siginfo_t = mem::zeroed();
let ts = timespec {
tv_sec: 0,
tv_nsec: 0,
};
let ret = sigtimedwait(&sigset, &mut siginfo, &ts);
if ret < 0 {
let e = errno::Error::last();
match e.errno() {
EAGAIN | EINTR => {}
_ => {
return Err(Error::ClearWaitPending(errno::Error::last()));
}
}
}
let mut chkset: sigset_t = mem::zeroed();
let ret = sigpending(&mut chkset);
if ret < 0 {
return Err(Error::ClearGetPending(errno::Error::last()));
}
let ret = sigismember(&chkset, num);
if ret < 0 {
return Err(Error::ClearCheckPending(errno::Error::last()));
}
ret != 0
}
} {}
Ok(())
}
pub unsafe trait Killable {
fn pthread_handle(&self) -> pthread_t;
fn kill(&self, num: c_int) -> errno::Result<()> {
validate_signal_num(num)?;
let ret = unsafe { pthread_kill(self.pthread_handle(), num) };
if ret < 0 {
return errno::errno_result();
}
Ok(())
}
}
// SAFETY: an un-joined JoinHandle keeps its underlying pthread_t valid.
unsafe impl<T> Killable for JoinHandle<T> {
    /// Returns the handle from `JoinHandleExt::as_pthread_t`, converted through `usize`.
    ///
    /// # Panics
    /// Panics if `pthread_t` and `usize` differ in size.
    fn pthread_handle(&self) -> pthread_t {
        let handle_size = mem::size_of::<pthread_t>();
        let word_size = mem::size_of::<usize>();
        assert_eq!(handle_size, word_size);
        let raw = self.as_pthread_t();
        raw as usize as pthread_t
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::undocumented_unsafe_blocks)]
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread;
    use std::time::Duration;

    // Set from a signal handler and polled from the test thread. An atomic is used instead
    // of `static mut` so this cross-thread write/read is not a data race.
    static SIGNAL_HANDLER_CALLED: AtomicBool = AtomicBool::new(false);

    extern "C" fn handle_signal(_: c_int, _: *mut siginfo_t, _: *mut c_void) {
        SIGNAL_HANDLER_CALLED.store(true, Ordering::SeqCst);
    }

    /// Returns true when `sigpending` reports `signal` as pending.
    fn is_pending(signal: c_int) -> bool {
        unsafe {
            let mut chkset: sigset_t = mem::zeroed();
            sigpending(&mut chkset);
            sigismember(&chkset, signal) == 1
        }
    }

    #[test]
    fn test_register_signal_handler() {
        // SIGKILL/SIGSTOP are rejected explicitly; SIGRTMAX() + 1 fails validation.
        assert!(register_signal_handler(libc::SIGKILL, handle_signal).is_err());
        assert!(register_signal_handler(libc::SIGSTOP, handle_signal).is_err());
        assert!(register_signal_handler(SIGRTMAX() + 1, handle_signal).is_err());
        // Exercises Debug formatting of the result; the outcome itself is not asserted.
        format!("{:?}", register_signal_handler(SIGRTMAX(), handle_signal));
        assert!(register_signal_handler(SIGRTMIN(), handle_signal).is_ok());
        assert!(register_signal_handler(libc::SIGSYS, handle_signal).is_ok());
    }

    #[test]
    #[allow(clippy::empty_loop)]
    fn test_killing_thread() {
        let killable = thread::spawn(|| thread::current().id());
        let killable_id = killable.join().unwrap();
        assert_ne!(killable_id, thread::current().id());
        register_signal_handler(SIGRTMIN(), handle_signal)
            .expect("failed to register vcpu signal handler");
        let killable = thread::spawn(|| loop {});
        let res = killable.kill(SIGRTMAX() + 1);
        assert!(res.is_err());
        format!("{:?}", res);
        assert!(!SIGNAL_HANDLER_CALLED.load(Ordering::SeqCst));
        assert!(killable.kill(SIGRTMIN()).is_ok());
        // Poll for up to ~2 seconds for the handler to run.
        const MAX_WAIT_ITERS: u32 = 20;
        let mut iter_count = 0;
        loop {
            thread::sleep(Duration::from_millis(100));
            if SIGNAL_HANDLER_CALLED.load(Ordering::SeqCst) {
                break;
            }
            iter_count += 1;
            assert!(iter_count <= MAX_WAIT_ITERS);
        }
    }

    #[test]
    fn test_block_unblock_signal() {
        let signal = SIGRTMIN();
        unsafe {
            let mut sigset: sigset_t = mem::zeroed();
            pthread_sigmask(SIG_BLOCK, null(), &mut sigset as *mut sigset_t);
            assert_eq!(sigismember(&sigset, signal), 0);
        }
        block_signal(signal).unwrap();
        assert!(get_blocked_signals().unwrap().contains(&(signal)));
        unblock_signal(signal).unwrap();
        assert!(!get_blocked_signals().unwrap().contains(&(signal)));
    }

    #[test]
    fn test_clear_pending() {
        let signal = SIGRTMIN() + 1;
        // Blocked before spawning so the spawned thread inherits the blocked mask and the
        // signal stays pending instead of being delivered.
        block_signal(signal).unwrap();
        let killable = thread::spawn(move || {
            loop {
                thread::sleep(Duration::from_millis(100));
                if is_pending(signal) {
                    clear_signal(signal).unwrap();
                    assert!(!is_pending(signal));
                    break;
                }
            }
        });
        assert!(killable.kill(SIGRTMIN() + 1).is_ok());
        killable.join().unwrap();
    }
}