jj_cli/
cleanup_guard.rs

1use std::io;
2use std::sync::Mutex;
3use std::sync::Once;
4
5use slab::Slab;
6use tracing::instrument;
7
8/// Contains the callbacks passed to currently-live [`CleanupGuard`]s
9static LIVE_GUARDS: Mutex<GuardTable> = Mutex::new(Slab::new());
10
11type GuardTable = Slab<Box<dyn FnOnce() + Send>>;
12
13/// Prepare to run [`CleanupGuard`]s on `SIGINT`/`SIGTERM`
14pub fn init() {
15    // Safety: `` ensures at most one call
16    static CALLED: Once = Once::new();
17    CALLED.call_once(|| {
18        if let Err(ref e) = unsafe { platform::init() } {
19            eprintln!("couldn't register signal handler: {e}");
20        }
21    });
22}
23
24/// A drop guard that also runs on `SIGINT`/`SIGTERM`
25pub struct CleanupGuard {
26    slot: usize,
27}
28
29impl CleanupGuard {
30    /// Invoke `f` when dropped or killed by `SIGINT`/`SIGTERM`
31    pub fn new<F: FnOnce() + Send + 'static>(f: F) -> Self {
32        let guards = &mut *LIVE_GUARDS.lock().unwrap();
33        Self {
34            slot: guards.insert(Box::new(f)),
35        }
36    }
37}
38
39impl Drop for CleanupGuard {
40    #[instrument(skip_all)]
41    fn drop(&mut self) {
42        let guards = &mut *LIVE_GUARDS.lock().unwrap();
43        let f = guards.remove(self.slot);
44        f();
45    }
46}
47
48#[cfg(unix)]
49mod platform {
50    use std::os::unix::io::IntoRawFd as _;
51    use std::os::unix::io::RawFd;
52    use std::os::unix::net::UnixDatagram;
53    use std::panic::AssertUnwindSafe;
54    use std::sync::atomic::AtomicBool;
55    use std::sync::atomic::Ordering;
56    use std::thread;
57
58    use libc::SIGINT;
59    use libc::SIGTERM;
60    use libc::c_int;
61
62    use super::*;
63
64    /// Safety: Must be called at most once
65    pub unsafe fn init() -> io::Result<()> {
66        unsafe {
67            let (send, recv) = UnixDatagram::pair()?;
68
69            // Spawn a background thread that waits for the signal handler to write a signal
70            // into it
71            thread::spawn(move || {
72                let mut buf = [0];
73                let signal = match recv.recv(&mut buf) {
74                    Ok(1) => c_int::from(buf[0]),
75                    _ => unreachable!(),
76                };
77                // We must hold the lock for the remainder of the process's lifetime to avoid a
78                // race where a guard is created between `on_signal` and `raise`.
79                let guards = &mut *LIVE_GUARDS.lock().unwrap();
80                if let Err(e) = std::panic::catch_unwind(AssertUnwindSafe(|| on_signal(guards))) {
81                    match e.downcast::<String>() {
82                        Ok(s) => eprintln!("signal handler panicked: {s}"),
83                        Err(_) => eprintln!("signal handler panicked"),
84                    }
85                }
86                libc::signal(signal, libc::SIG_DFL);
87                libc::raise(signal);
88            });
89
90            SIGNAL_SEND = send.into_raw_fd();
91            libc::signal(SIGINT, handler as *const () as libc::sighandler_t);
92            libc::signal(SIGTERM, handler as *const () as libc::sighandler_t);
93            Ok(())
94        }
95    }
96
97    // Invoked on a background thread. Process exits after this returns.
98    fn on_signal(guards: &mut GuardTable) {
99        for guard in guards.drain() {
100            guard();
101        }
102    }
103
104    unsafe extern "C" fn handler(signal: c_int) {
105        unsafe {
106            // Treat the second signal as instantly fatal.
107            static SIGNALED: AtomicBool = AtomicBool::new(false);
108            if SIGNALED.swap(true, Ordering::Relaxed) {
109                libc::signal(signal, libc::SIG_DFL);
110                libc::raise(signal);
111            }
112
113            let buf = [signal as u8];
114            libc::write(SIGNAL_SEND, buf.as_ptr().cast(), buf.len());
115        }
116    }
117
118    static mut SIGNAL_SEND: RawFd = 0;
119}
120
#[cfg(not(unix))]
mod platform {
    use super::*;

    /// Signal-triggered cleanup is not implemented on this platform, so there
    /// is nothing to install; guards run only when they are dropped.
    ///
    /// # Safety
    ///
    /// Always safe to call. It is declared `unsafe` only so its signature
    /// matches the Unix implementation.
    pub unsafe fn init() -> io::Result<()> {
        io::Result::Ok(())
    }
}