1use std::collections::HashMap;
6use std::convert::TryFrom;
7use std::io::{Read, Write};
8use std::os::raw::c_int;
9use std::os::unix::net::UnixStream;
10use std::pin::Pin;
11use std::sync::atomic::{AtomicU32, Ordering};
12use std::sync::{Arc, Mutex, Once};
13use std::task::Context;
14use std::task::Poll;
15use std::{io, thread};
16
17use crossbeam_queue::SegQueue;
18use crossbeam_skiplist::SkipSet;
19use futures_util::task::AtomicWaker;
20use futures_util::Stream;
21use nix::sys::signal::{sigaction, SaFlags, SigAction, SigHandler, SigSet, Signal};
22use once_cell::sync::Lazy;
23
24static PIPE: Lazy<io::Result<(UnixStream, UnixStream)>> = Lazy::new(|| {
25 let (reader, write) = UnixStream::pair()?;
26 write.set_nonblocking(true)?;
28
29 Ok((reader, write))
30});
31
32static SIGNALS_SET: Lazy<SignalsSet> = Lazy::new(SignalsSet::default);
33
34#[derive(Debug)]
35struct SignalsSet {
36 id_gen: AtomicU32,
37 signal_notifiers: Mutex<SignalNotifiers>,
38 start_signal_handle_thread: Once,
39}
40
41impl Default for SignalsSet {
42 fn default() -> Self {
43 Self {
44 id_gen: Default::default(),
45 signal_notifiers: Mutex::new(Default::default()),
46 start_signal_handle_thread: Once::new(),
47 }
48 }
49}
50
51impl SignalsSet {
52 fn generate_id(&self) -> u32 {
53 self.id_gen.fetch_add(1, Ordering::Relaxed)
54 }
55
56 fn start_signal_handle_thread(&'static self) {
57 self.start_signal_handle_thread.call_once(|| {
58 thread::spawn(|| {
59 let mut reader = get_pipe_reader();
60 loop {
61 let mut buf = [0];
62 (&mut reader).read_exact(&mut buf).unwrap_or_else(|err| {
63 panic!("pipe reader read return error {err}, that should not happened")
64 });
65
66 let signal_notifiers = self.signal_notifiers.lock().unwrap();
67 for signal_inner in signal_notifiers.notifiers.values() {
68 if signal_inner.interest(buf[0] as _) {
69 signal_inner.notify(buf[0] as _);
70 }
71 }
72 }
73 });
74 })
75 }
76}
77
78#[derive(Debug, Default)]
79struct SignalNotifiers {
80 installed_signals: HashMap<c_int, usize>,
81 notifiers: HashMap<u32, Arc<SignalsInner>>,
82}
83
84impl SignalNotifiers {
85 fn add_signal(&mut self, signal: c_int) -> io::Result<()> {
86 let count = self
87 .installed_signals
88 .entry(signal)
89 .and_modify(|count| *count += 1)
90 .or_insert(1);
91 if *count == 1 {
92 let handler = SigHandler::Handler(handle);
93 let action = SigAction::new(handler, SaFlags::SA_RESTART, SigSet::empty());
94
95 unsafe {
96 sigaction(Signal::try_from(signal)?, &action)?;
97 }
98 }
99
100 Ok(())
101 }
102
103 fn remove_signal(&mut self, signal: c_int) -> io::Result<()> {
104 if let Some(count) = self.installed_signals.get_mut(&signal) {
105 *count -= 1;
106 if *count == 0 {
107 let action =
108 SigAction::new(SigHandler::SigDfl, SaFlags::SA_RESTART, SigSet::empty());
109
110 unsafe {
111 let signal = Signal::try_from(signal)
112 .unwrap_or_else(|_| panic!("signal {signal} should be valid"));
113 let _ = sigaction(signal, &action);
114 }
115
116 self.installed_signals.remove(&signal);
117 }
118 }
119
120 Ok(())
121 }
122
123 fn add_notifier(&mut self, id: u32, signals_inner: Arc<SignalsInner>) {
124 self.notifiers.insert(id, signals_inner);
125 }
126
127 fn remove_notifier(&mut self, id: u32) {
128 self.notifiers.remove(&id);
129 }
130}
131
132fn get_pipe_writer() -> &'static UnixStream {
133 match PIPE.as_ref() {
134 Err(_) => unreachable!("if init pipe failed, should not get pipe writer"),
135 Ok((_, writer)) => writer,
136 }
137}
138
139fn get_pipe_reader() -> &'static UnixStream {
140 match PIPE.as_ref() {
141 Err(_) => unreachable!("if init pipe failed, should not get pipe writer"),
142 Ok((reader, _)) => reader,
143 }
144}
145
146extern "C" fn handle(receive_signal: c_int) {
147 let _ = get_pipe_writer().write(&[receive_signal as _]);
148}
149
150#[derive(Debug)]
151struct SignalsInner {
152 queue: SegQueue<c_int>,
153 waker: AtomicWaker,
154 interests: SkipSet<c_int>,
155}
156
157impl SignalsInner {
158 fn new(interests: SkipSet<c_int>) -> Self {
159 Self {
160 queue: Default::default(),
161 waker: Default::default(),
162 interests,
163 }
164 }
165
166 fn interest(&self, signal: c_int) -> bool {
167 self.interests.contains(&signal)
168 }
169
170 fn notify(&self, signal: c_int) {
171 self.queue.push(signal);
172 self.waker.wake();
173 }
174}
175
176#[derive(Debug)]
187pub struct Signals {
188 id: u32,
189 inner: Arc<SignalsInner>,
190}
191
192impl Drop for Signals {
193 fn drop(&mut self) {
194 let mut signal_notifiers = SIGNALS_SET.signal_notifiers.lock().unwrap();
195 signal_notifiers.remove_notifier(self.id);
196
197 for signal in &self.inner.interests {
198 let _ = signal_notifiers.remove_signal(*signal.value());
199 }
200 }
201}
202
203impl Signals {
204 pub fn new<I: IntoIterator<Item = c_int>>(signals: I) -> io::Result<Signals> {
227 Self::check_pipe()?;
228
229 SIGNALS_SET.start_signal_handle_thread();
230
231 let signals = signals.into_iter().collect::<SkipSet<_>>();
232 let id = SIGNALS_SET.generate_id();
233 let mut signal_notifiers = SIGNALS_SET.signal_notifiers.lock().unwrap();
234 for signal in &signals {
235 signal_notifiers.add_signal(*signal.value())?;
236 }
237
238 let inner = Arc::new(SignalsInner::new(signals));
239 signal_notifiers.add_notifier(id, inner.clone());
240
241 Ok(Self { id, inner })
242 }
243
244 fn check_pipe() -> io::Result<()> {
245 PIPE.as_ref()
246 .map(|_| ())
247 .map_err(|err| io::Error::new(err.kind(), err.to_string()))
248 }
249
250 pub fn add_signal(&mut self, signal: c_int) -> io::Result<()> {
275 SIGNALS_SET
276 .signal_notifiers
277 .lock()
278 .unwrap()
279 .add_signal(signal)?;
280 self.inner.interests.insert(signal);
281
282 Ok(())
283 }
284}
285
286impl Stream for Signals {
287 type Item = c_int;
288
289 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
290 self.inner.waker.register(cx.waker());
292
293 if let Some(signal) = self.inner.queue.pop() {
294 return Poll::Ready(Some(signal));
295 }
296
297 Poll::Pending
298 }
299}
300
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;
    use nix::sys;
    use nix::unistd;

    use super::*;

    /// Sends `signal` to this very process.
    fn raise_to_self(signal: sys::signal::Signal) {
        let me = unistd::getpid();
        sys::signal::kill(me, Some(signal)).unwrap();
    }

    /// A stream created for SIGINT yields SIGINT after the process signals itself.
    #[tokio::test]
    async fn interrupt() {
        let mut stream = Signals::new(vec![libc::SIGINT]).unwrap();

        raise_to_self(sys::signal::SIGINT);

        let received = stream.next().await.unwrap();
        assert_eq!(received, libc::SIGINT);
    }

    /// A signal added after construction is delivered like the initial ones.
    #[tokio::test]
    async fn add_signal() {
        let mut stream = Signals::new(vec![libc::SIGHUP]).unwrap();
        stream.add_signal(libc::SIGINT).unwrap();

        raise_to_self(sys::signal::SIGINT);

        let received = stream.next().await.unwrap();
        assert_eq!(received, libc::SIGINT);
    }

    /// Two streams watching the same signal each receive it.
    #[tokio::test]
    async fn multi_signals() {
        let mut first = Signals::new(vec![libc::SIGINT]).unwrap();
        let mut second = Signals::new(vec![libc::SIGINT]).unwrap();

        raise_to_self(sys::signal::SIGINT);

        assert_eq!(first.next().await.unwrap(), libc::SIGINT);
        assert_eq!(second.next().await.unwrap(), libc::SIGINT);
    }
}