use std::io;
use std::sync::{Mutex, Once};
use slab::Slab;
use tracing::instrument;
/// Process-wide table of cleanup callbacks that have been registered but not yet run.
///
/// `CleanupGuard::new` inserts a callback, dropping the guard removes and runs it, and the
/// Unix signal listener drains whatever is left when SIGINT/SIGTERM arrives. All access goes
/// through the mutex.
static LIVE_GUARDS: Mutex<GuardTable> = Mutex::new(Slab::new());
/// Slab of boxed one-shot callbacks; the slab key is what a `CleanupGuard` stores as its `slot`.
/// The `Send` bound lets a callback run on a thread other than the one that registered it.
type GuardTable = Slab<Box<dyn FnOnce() + Send>>;
/// Performs the one-time, platform-specific signal setup.
///
/// Only the first call does any work; later calls return immediately. A setup failure is
/// reported on stderr and otherwise ignored, so this function never fails.
pub fn init() {
    static CALLED: Once = Once::new();
    CALLED.call_once(|| {
        // SAFETY: `CALLED` guarantees the platform setup runs at most once per process.
        let outcome = unsafe { platform::init() };
        match outcome {
            Ok(()) => {}
            Err(e) => eprintln!("couldn't register signal handler: {e}"),
        }
    });
}
/// Handle for a registered cleanup callback.
///
/// The callback is stored in the global guard table when the guard is created and is removed
/// and executed when the guard is dropped. Binding the guard to a variable is therefore what
/// keeps the callback pending; discarding it immediately runs the callback immediately.
#[must_use = "dropping a CleanupGuard runs its cleanup callback right away"]
pub struct CleanupGuard {
    /// Key of this guard's callback inside the global guard table.
    slot: usize,
}
impl CleanupGuard {
    /// Registers `f` in the global guard table and returns the guard that owns its slot.
    ///
    /// `f` must be `Send + 'static` because it is stored in a process-wide table and may be
    /// invoked from a different thread than the one that created the guard.
    ///
    /// # Panics
    ///
    /// Panics if the guard-table mutex is poisoned.
    pub fn new<F: FnOnce() + Send + 'static>(f: F) -> Self {
        let mut table = LIVE_GUARDS.lock().unwrap();
        let slot = table.insert(Box::new(f));
        Self { slot }
    }
}
impl Drop for CleanupGuard {
    /// Removes this guard's callback from the global table and runs it.
    ///
    /// The table stays locked until the callback has returned, so the callback must not create
    /// or drop another `CleanupGuard` on the same thread (that would try to re-lock the table).
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned or if the guard's slot is no longer occupied.
    #[instrument(skip_all)]
    fn drop(&mut self) {
        let mut table = LIVE_GUARDS.lock().unwrap();
        let cleanup = table.remove(self.slot);
        // Deliberately invoked while `table` is still locked; the lock is released on return.
        cleanup();
    }
}
/// Unix implementation: SIGINT/SIGTERM are forwarded from a minimal signal handler to a
/// listener thread over a datagram socket pair. The listener runs every pending cleanup
/// callback and then re-raises the signal with its default disposition.
#[cfg(unix)]
mod platform {
    use std::os::unix::io::{IntoRawFd as _, RawFd};
    use std::os::unix::net::UnixDatagram;
    use std::panic::AssertUnwindSafe;
    use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
    use std::sync::PoisonError;
    use std::thread;
    use libc::{c_int, SIGINT, SIGTERM};
    use super::*;

    /// Write end of the socket pair used by [`handler`]; `-1` until [`init`] has stored a real
    /// descriptor, so a premature write fails instead of landing on fd 0.
    static SIGNAL_SEND: AtomicI32 = AtomicI32::new(-1);

    /// Spawns the listener thread and installs the SIGINT/SIGTERM handlers.
    ///
    /// # Errors
    ///
    /// Returns the OS error if the socket pair cannot be created, the listener thread cannot
    /// be spawned, or either signal handler cannot be installed.
    ///
    /// # Safety
    ///
    /// Replaces the process-wide SIGINT and SIGTERM dispositions. Callers are expected to
    /// invoke this once per process; every call spawns a new listener thread and replaces the
    /// descriptor the handler writes to.
    pub unsafe fn init() -> io::Result<()> {
        let (send, recv) = UnixDatagram::pair()?;
        // `Builder::spawn` reports a spawn failure as an `io::Error` instead of panicking.
        thread::Builder::new()
            .name("cleanup-signal".to_owned())
            .spawn(move || {
                let mut buf = [0];
                let signal = loop {
                    match recv.recv(&mut buf) {
                        Ok(1) => break c_int::from(buf[0]),
                        // A signal delivered to this thread can interrupt the read; retry.
                        Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                        Ok(n) => {
                            eprintln!("signal listener received unexpected {n}-byte message");
                            return;
                        }
                        Err(e) => {
                            eprintln!("signal listener failed: {e}");
                            return;
                        }
                    }
                };
                // A poisoned table is still structurally valid, and the pending cleanups
                // should run regardless, so recover the guard instead of panicking.
                let mut guards = LIVE_GUARDS
                    .lock()
                    .unwrap_or_else(PoisonError::into_inner);
                let outcome =
                    std::panic::catch_unwind(AssertUnwindSafe(|| on_signal(&mut guards)));
                if let Err(payload) = outcome {
                    // `panic!` payloads are a `String` (formatted) or a `&'static str` (literal).
                    let message = payload
                        .downcast_ref::<String>()
                        .map(String::as_str)
                        .or_else(|| payload.downcast_ref::<&'static str>().copied());
                    match message {
                        Some(s) => eprintln!("signal handler panicked: {s}"),
                        None => eprintln!("signal handler panicked"),
                    }
                }
                // The table stays locked here so no other thread can start a cleanup while
                // the process is about to be terminated by the re-raised signal.
                // SAFETY: `signal` is the number our own handler forwarded; restoring the
                // default disposition and re-raising have no memory-safety preconditions.
                unsafe {
                    libc::signal(signal, libc::SIG_DFL);
                    libc::raise(signal);
                }
            })?;
        // Publish the descriptor before any handler that reads it can be installed.
        let fd: RawFd = send.into_raw_fd();
        SIGNAL_SEND.store(fd, Ordering::SeqCst);
        // SAFETY: `handler` only touches atomics and calls `signal`/`raise`/`write`, and the
        // descriptor it writes to has just been published above.
        unsafe {
            install(SIGINT)?;
            install(SIGTERM)?;
        }
        Ok(())
    }

    /// Installs [`handler`] for `signal`, turning `SIG_ERR` into the corresponding OS error.
    ///
    /// # Safety
    ///
    /// Changes a process-wide signal disposition; see [`init`].
    unsafe fn install(signal: c_int) -> io::Result<()> {
        // SAFETY: `handler` is a valid `extern "C" fn(c_int)` that lives for the whole program.
        let previous = unsafe { libc::signal(signal, handler as libc::sighandler_t) };
        if previous == libc::SIG_ERR {
            return Err(io::Error::last_os_error());
        }
        Ok(())
    }

    /// Runs and removes every callback still registered in `guards`.
    fn on_signal(guards: &mut GuardTable) {
        for guard in guards.drain() {
            guard();
        }
    }

    /// Signal handler: forwards the signal number to the listener thread.
    ///
    /// If a signal has already been forwarded once, the default disposition is restored and
    /// the signal re-raised, so a repeated signal is no longer held back by the cleanup path.
    unsafe extern "C" fn handler(signal: c_int) {
        static SIGNALED: AtomicBool = AtomicBool::new(false);
        if SIGNALED.swap(true, Ordering::Relaxed) {
            // SAFETY: resetting the disposition and re-raising take no pointers.
            unsafe {
                libc::signal(signal, libc::SIG_DFL);
                libc::raise(signal);
            }
        }
        // Only SIGINT/SIGTERM are registered, so the signal number fits in one byte.
        let buf = [signal as u8];
        // SAFETY: `buf` is a live one-byte buffer for the duration of the call. The result is
        // ignored because nothing useful can be done about a failed write inside a handler.
        unsafe {
            libc::write(
                SIGNAL_SEND.load(Ordering::SeqCst),
                buf.as_ptr().cast(),
                buf.len(),
            );
        }
    }
}
/// Fallback for non-Unix targets: no signal handling is installed.
#[cfg(not(unix))]
mod platform {
    use std::io;

    /// Does nothing and always succeeds.
    ///
    /// # Safety
    ///
    /// This implementation performs no unsafe operations; the function is `unsafe` only so its
    /// signature matches the Unix variant.
    pub unsafe fn init() -> io::Result<()> {
        Ok(())
    }
}