async_signals/
lib.rs

//! Library for easy and safe Unix signal handling through an async `Stream`.
2//!
3//! You can use this crate with any async runtime.
4
5use std::collections::HashMap;
6use std::convert::TryFrom;
7use std::io::{Read, Write};
8use std::os::raw::c_int;
9use std::os::unix::net::UnixStream;
10use std::pin::Pin;
11use std::sync::atomic::{AtomicU32, Ordering};
12use std::sync::{Arc, Mutex, Once};
13use std::task::Context;
14use std::task::Poll;
15use std::{io, thread};
16
17use crossbeam_queue::SegQueue;
18use crossbeam_skiplist::SkipSet;
19use futures_util::task::AtomicWaker;
20use futures_util::Stream;
21use nix::sys::signal::{sigaction, SaFlags, SigAction, SigHandler, SigSet, Signal};
22use once_cell::sync::Lazy;
23
24static PIPE: Lazy<io::Result<(UnixStream, UnixStream)>> = Lazy::new(|| {
25    let (reader, write) = UnixStream::pair()?;
26    // don't block the signal handler
27    write.set_nonblocking(true)?;
28
29    Ok((reader, write))
30});
31
/// Global, lazily created registry shared by every `Signals` instance.
static SIGNALS_SET: Lazy<SignalsSet> = Lazy::new(SignalsSet::default);
33
/// Global state behind all `Signals` instances.
#[derive(Debug)]
struct SignalsSet {
    // Counter handing out the ids under which notifiers are registered.
    id_gen: AtomicU32,
    // Installed-signal reference counts and registered notifiers, guarded by one lock.
    signal_notifiers: Mutex<SignalNotifiers>,
    // Ensures the dispatch thread is spawned at most once.
    start_signal_handle_thread: Once,
}
40
41impl Default for SignalsSet {
42    fn default() -> Self {
43        Self {
44            id_gen: Default::default(),
45            signal_notifiers: Mutex::new(Default::default()),
46            start_signal_handle_thread: Once::new(),
47        }
48    }
49}
50
51impl SignalsSet {
52    fn generate_id(&self) -> u32 {
53        self.id_gen.fetch_add(1, Ordering::Relaxed)
54    }
55
56    fn start_signal_handle_thread(&'static self) {
57        self.start_signal_handle_thread.call_once(|| {
58            thread::spawn(|| {
59                let mut reader = get_pipe_reader();
60                loop {
61                    let mut buf = [0];
62                    (&mut reader).read_exact(&mut buf).unwrap_or_else(|err| {
63                        panic!("pipe reader read return error {err}, that should not happened")
64                    });
65
66                    let signal_notifiers = self.signal_notifiers.lock().unwrap();
67                    for signal_inner in signal_notifiers.notifiers.values() {
68                        if signal_inner.interest(buf[0] as _) {
69                            signal_inner.notify(buf[0] as _);
70                        }
71                    }
72                }
73            });
74        })
75    }
76}
77
/// Bookkeeping protected by the registry lock.
#[derive(Debug, Default)]
struct SignalNotifiers {
    // Signal number -> number of registrations currently holding the handler installed.
    installed_signals: HashMap<c_int, usize>,
    // Notifier id -> shared state of the corresponding `Signals`.
    notifiers: HashMap<u32, Arc<SignalsInner>>,
}
83
84impl SignalNotifiers {
85    fn add_signal(&mut self, signal: c_int) -> io::Result<()> {
86        let count = self
87            .installed_signals
88            .entry(signal)
89            .and_modify(|count| *count += 1)
90            .or_insert(1);
91        if *count == 1 {
92            let handler = SigHandler::Handler(handle);
93            let action = SigAction::new(handler, SaFlags::SA_RESTART, SigSet::empty());
94
95            unsafe {
96                sigaction(Signal::try_from(signal)?, &action)?;
97            }
98        }
99
100        Ok(())
101    }
102
103    fn remove_signal(&mut self, signal: c_int) -> io::Result<()> {
104        if let Some(count) = self.installed_signals.get_mut(&signal) {
105            *count -= 1;
106            if *count == 0 {
107                let action =
108                    SigAction::new(SigHandler::SigDfl, SaFlags::SA_RESTART, SigSet::empty());
109
110                unsafe {
111                    let signal = Signal::try_from(signal)
112                        .unwrap_or_else(|_| panic!("signal {signal} should be valid"));
113                    let _ = sigaction(signal, &action);
114                }
115
116                self.installed_signals.remove(&signal);
117            }
118        }
119
120        Ok(())
121    }
122
123    fn add_notifier(&mut self, id: u32, signals_inner: Arc<SignalsInner>) {
124        self.notifiers.insert(id, signals_inner);
125    }
126
127    fn remove_notifier(&mut self, id: u32) {
128        self.notifiers.remove(&id);
129    }
130}
131
132fn get_pipe_writer() -> &'static UnixStream {
133    match PIPE.as_ref() {
134        Err(_) => unreachable!("if init pipe failed, should not get pipe writer"),
135        Ok((_, writer)) => writer,
136    }
137}
138
139fn get_pipe_reader() -> &'static UnixStream {
140    match PIPE.as_ref() {
141        Err(_) => unreachable!("if init pipe failed, should not get pipe writer"),
142        Ok((reader, _)) => reader,
143    }
144}
145
/// Signal handler installed through `sigaction`: forwards the received signal
/// number as a single byte over the pipe's write end.
///
/// The write result is ignored, so a failed write drops that signal notification.
// NOTE(review): `write` may modify `errno` and the handler does not save/restore
// it — confirm whether interrupted code can observe the change.
extern "C" fn handle(receive_signal: c_int) {
    let _ = get_pipe_writer().write(&[receive_signal as _]);
}
149
/// State shared between a `Signals` stream and the global registry.
#[derive(Debug)]
struct SignalsInner {
    // Signals received but not yet yielded by the stream.
    queue: SegQueue<c_int>,
    // Waker of the task polling the stream.
    waker: AtomicWaker,
    // Signal numbers this instance wants to receive.
    interests: SkipSet<c_int>,
}
156
157impl SignalsInner {
158    fn new(interests: SkipSet<c_int>) -> Self {
159        Self {
160            queue: Default::default(),
161            waker: Default::default(),
162            interests,
163        }
164    }
165
166    fn interest(&self, signal: c_int) -> bool {
167        self.interests.contains(&signal)
168    }
169
170    fn notify(&self, signal: c_int) {
171        self.queue.push(signal);
172        self.waker.wake();
173    }
174}
175
/// Handles Unix signals in the spirit of `signal_hook::iterator::Signals`,
/// delivering them through a `futures::stream::Stream`.
///
/// If several `Signals` register the same signal, each of them receives it.
///
/// Once every `Signals` handling a given signal (for example `SIGINT`) has been
/// dropped, the system default handler is used again for that signal.
///
/// # Notes
/// `SIGKILL` and `SIGSTOP` cannot be handled.
#[derive(Debug)]
pub struct Signals {
    // Key of this instance in the global notifier registry.
    id: u32,
    // State shared with the registry: pending queue, waker and interests.
    inner: Arc<SignalsInner>,
}
191
192impl Drop for Signals {
193    fn drop(&mut self) {
194        let mut signal_notifiers = SIGNALS_SET.signal_notifiers.lock().unwrap();
195        signal_notifiers.remove_notifier(self.id);
196
197        for signal in &self.inner.interests {
198            let _ = signal_notifiers.remove_signal(*signal.value());
199        }
200    }
201}
202
203impl Signals {
204    /// Creates the `Signals` structure, all signals will be registered.
205    ///
206    /// # Examples
207    ///
208    /// ```
209    /// use async_signals::Signals;
210    /// use futures_util::StreamExt;
211    /// use nix::sys;
212    /// use nix::unistd;
213    ///
214    /// #[tokio::main(flavor = "current_thread")]
215    /// async fn main() {
216    ///     let mut signals = Signals::new(vec![libc::SIGINT]).unwrap();
217    ///
218    ///     let pid = unistd::getpid();
219    ///     sys::signal::kill(pid, Some(sys::signal::SIGINT)).unwrap();
220    ///
221    ///     let signal = signals.next().await.unwrap();
222    ///
223    ///     assert_eq!(signal, libc::SIGINT);
224    /// }
225    /// ```
226    pub fn new<I: IntoIterator<Item = c_int>>(signals: I) -> io::Result<Signals> {
227        Self::check_pipe()?;
228
229        SIGNALS_SET.start_signal_handle_thread();
230
231        let signals = signals.into_iter().collect::<SkipSet<_>>();
232        let id = SIGNALS_SET.generate_id();
233        let mut signal_notifiers = SIGNALS_SET.signal_notifiers.lock().unwrap();
234        for signal in &signals {
235            signal_notifiers.add_signal(*signal.value())?;
236        }
237
238        let inner = Arc::new(SignalsInner::new(signals));
239        signal_notifiers.add_notifier(id, inner.clone());
240
241        Ok(Self { id, inner })
242    }
243
244    fn check_pipe() -> io::Result<()> {
245        PIPE.as_ref()
246            .map(|_| ())
247            .map_err(|err| io::Error::new(err.kind(), err.to_string()))
248    }
249
250    /// Registers another signal to a created `Signals`.
251    ///
252    /// # Examples
253    ///
254    /// ```
255    /// use async_signals::Signals;
256    /// use futures_util::StreamExt;
257    /// use nix::sys;
258    /// use nix::unistd;
259    ///
260    /// #[tokio::main(flavor = "current_thread")]
261    /// async fn main() {
262    ///     let mut signals = Signals::new(vec![libc::SIGHUP]).unwrap();
263    ///
264    ///     signals.add_signal(libc::SIGINT).unwrap();
265    ///
266    ///     let pid = unistd::getpid();
267    ///     sys::signal::kill(pid, Some(sys::signal::SIGINT)).unwrap();
268    ///
269    ///     let signal = signals.next().await.unwrap();
270    ///
271    ///     assert_eq!(signal, libc::SIGINT);
272    /// }
273    /// ```
274    pub fn add_signal(&mut self, signal: c_int) -> io::Result<()> {
275        SIGNALS_SET
276            .signal_notifiers
277            .lock()
278            .unwrap()
279            .add_signal(signal)?;
280        self.inner.interests.insert(signal);
281
282        Ok(())
283    }
284}
285
286impl Stream for Signals {
287    type Item = c_int;
288
289    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
290        // register at first, make sure when ready, we can be notified
291        self.inner.waker.register(cx.waker());
292
293        if let Some(signal) = self.inner.queue.pop() {
294            return Poll::Ready(Some(signal));
295        }
296
297        Poll::Pending
298    }
299}
300
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;
    use nix::sys;
    use nix::unistd;

    use super::*;

    /// Sends `SIGINT` to the current process.
    fn raise_interrupt() {
        let own_pid = unistd::getpid();
        sys::signal::kill(own_pid, Some(sys::signal::SIGINT)).unwrap();
    }

    /// A single `Signals` receives the signal it was created for.
    #[tokio::test]
    async fn interrupt() {
        let mut stream = Signals::new(vec![libc::SIGINT]).unwrap();

        raise_interrupt();

        let received = stream.next().await.unwrap();
        assert_eq!(received, libc::SIGINT);
    }

    /// A signal added after construction is delivered as well.
    #[tokio::test]
    async fn add_signal() {
        let mut stream = Signals::new(vec![libc::SIGHUP]).unwrap();
        stream.add_signal(libc::SIGINT).unwrap();

        raise_interrupt();

        let received = stream.next().await.unwrap();
        assert_eq!(received, libc::SIGINT);
    }

    /// Every `Signals` registered for the same signal receives it.
    #[tokio::test]
    async fn multi_signals() {
        let mut first = Signals::new(vec![libc::SIGINT]).unwrap();
        let mut second = Signals::new(vec![libc::SIGINT]).unwrap();

        raise_interrupt();

        assert_eq!(first.next().await.unwrap(), libc::SIGINT);
        assert_eq!(second.next().await.unwrap(), libc::SIGINT);
    }
}