1use std::io;
2use std::sync::Mutex;
3use std::sync::Once;
4
5use slab::Slab;
6use tracing::instrument;
7
8static LIVE_GUARDS: Mutex<GuardTable> = Mutex::new(Slab::new());
10
11type GuardTable = Slab<Box<dyn FnOnce() + Send>>;
12
13pub fn init() {
15 static CALLED: Once = Once::new();
17 CALLED.call_once(|| {
18 if let Err(ref e) = unsafe { platform::init() } {
19 eprintln!("couldn't register signal handler: {e}");
20 }
21 });
22}
23
/// Runs a cleanup closure when dropped. On Unix, if `init` has been called,
/// the closure instead runs when SIGINT or SIGTERM arrives first.
///
/// The closure itself lives in the global guard registry; the guard only
/// remembers its slot. Discarding the guard immediately runs the closure
/// immediately, which is why the type is `#[must_use]`.
#[must_use = "dropping a CleanupGuard immediately runs its cleanup closure"]
pub struct CleanupGuard {
    // Key of this guard's closure in the `LIVE_GUARDS` slab.
    slot: usize,
}
28
29impl CleanupGuard {
30 pub fn new<F: FnOnce() + Send + 'static>(f: F) -> Self {
32 let guards = &mut *LIVE_GUARDS.lock().unwrap();
33 Self {
34 slot: guards.insert(Box::new(f)),
35 }
36 }
37}
38
39impl Drop for CleanupGuard {
40 #[instrument(skip_all)]
41 fn drop(&mut self) {
42 let guards = &mut *LIVE_GUARDS.lock().unwrap();
43 let f = guards.remove(self.slot);
44 f();
45 }
46}
47
48#[cfg(unix)]
49mod platform {
50 use std::os::unix::io::IntoRawFd as _;
51 use std::os::unix::io::RawFd;
52 use std::os::unix::net::UnixDatagram;
53 use std::panic::AssertUnwindSafe;
54 use std::sync::atomic::AtomicBool;
55 use std::sync::atomic::Ordering;
56 use std::thread;
57
58 use libc::SIGINT;
59 use libc::SIGTERM;
60 use libc::c_int;
61
62 use super::*;
63
64 pub unsafe fn init() -> io::Result<()> {
66 unsafe {
67 let (send, recv) = UnixDatagram::pair()?;
68
69 thread::spawn(move || {
72 let mut buf = [0];
73 let signal = match recv.recv(&mut buf) {
74 Ok(1) => c_int::from(buf[0]),
75 _ => unreachable!(),
76 };
77 let guards = &mut *LIVE_GUARDS.lock().unwrap();
80 if let Err(e) = std::panic::catch_unwind(AssertUnwindSafe(|| on_signal(guards))) {
81 match e.downcast::<String>() {
82 Ok(s) => eprintln!("signal handler panicked: {s}"),
83 Err(_) => eprintln!("signal handler panicked"),
84 }
85 }
86 libc::signal(signal, libc::SIG_DFL);
87 libc::raise(signal);
88 });
89
90 SIGNAL_SEND = send.into_raw_fd();
91 libc::signal(SIGINT, handler as *const () as libc::sighandler_t);
92 libc::signal(SIGTERM, handler as *const () as libc::sighandler_t);
93 Ok(())
94 }
95 }
96
97 fn on_signal(guards: &mut GuardTable) {
99 for guard in guards.drain() {
100 guard();
101 }
102 }
103
104 unsafe extern "C" fn handler(signal: c_int) {
105 unsafe {
106 static SIGNALED: AtomicBool = AtomicBool::new(false);
108 if SIGNALED.swap(true, Ordering::Relaxed) {
109 libc::signal(signal, libc::SIG_DFL);
110 libc::raise(signal);
111 }
112
113 let buf = [signal as u8];
114 libc::write(SIGNAL_SEND, buf.as_ptr().cast(), buf.len());
115 }
116 }
117
118 static mut SIGNAL_SEND: RawFd = 0;
119}
120
/// Fallback for non-Unix targets: no signal handling is installed, so cleanup
/// closures run only when their guards are dropped.
#[cfg(not(unix))]
mod platform {
    use std::io;

    /// No-op stand-in for the Unix signal setup; always succeeds.
    ///
    /// # Safety
    /// Has no safety requirements; it is `unsafe` only to match the Unix
    /// signature.
    pub unsafe fn init() -> io::Result<()> {
        Ok(())
    }
}