elfo_core/
signal.rs

1use std::{
2    io,
3    os::raw::c_int,
4    task::{self, Poll},
5};
6
7use parking_lot::Mutex;
8use sealed::sealed;
9use tokio::signal;
10#[cfg(unix)]
11use tokio::signal::unix;
12use tokio_util::sync::ReusableBoxFuture;
13
14use crate::{
15    addr::Addr,
16    envelope::{Envelope, MessageKind},
17    message::Message,
18    tracing::TraceId,
19};
20
21/// `Source` that watches signals.
22///
23/// All signals except `SignalKind::CtrlC` produces messages on UNIX system
24/// only. For other systems they produce nothing. It's useful and helps to
25/// avoid writing `#[cfg(unix)]` everywhere around signals.
26pub struct Signal<F> {
27    inner: Mutex<SignalInner>,
28    message_factory: F,
29}
30
31enum SignalInner {
32    CtrlC(ReusableBoxFuture<'static, io::Result<()>>),
33    #[cfg(unix)]
34    Unix(unix::Signal),
35    Empty,
36}
37
/// The OS signal a `Signal` source listens for.
///
/// Only `CtrlC` is honoured on every platform. The remaining kinds map to
/// UNIX signals and never fire on other systems. If a handler cannot be
/// registered, a warning is logged and the source never fires.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SignalKind {
    /// Fires when a "ctrl-c" notification is delivered to the process.
    ///
    /// See <https://docs.rs/tokio/1.7.1/tokio/signal/fn.ctrl_c.html>
    CtrlC,

    /// An arbitrary OS signal identified by its raw number.
    Raw(c_int),
    /// `SIGALRM`
    Alarm,
    /// `SIGCHLD`
    Child,
    /// `SIGHUP`
    Hangup,
    /// `SIGINT`
    Interrupt,
    /// `SIGIO`
    Io,
    /// `SIGPIPE`
    Pipe,
    /// `SIGQUIT`
    Quit,
    /// `SIGTERM`
    Terminate,
    /// `SIGUSR1`
    User1,
    /// `SIGUSR2`
    User2,
    /// `SIGWINCH`
    WindowChange,
}
70
71impl<F> Signal<F> {
72    pub fn new(kind: SignalKind, message_factory: F) -> Self {
73        let inner = match kind {
74            // TODO: remove this line for `unix` after testing windows.
75            SignalKind::CtrlC => SignalInner::CtrlC(ReusableBoxFuture::new(signal::ctrl_c())),
76            #[cfg(unix)]
77            _ => match create_by_kind(kind) {
78                Ok(inner) => SignalInner::Unix(inner),
79                Err(err) => {
80                    tracing::warn!(
81                        kind = ?kind,
82                        error = %err,
83                        "failed to create a signal handler"
84                    );
85                    SignalInner::Empty
86                }
87            },
88            #[cfg(not(unix))]
89            _ => SignalInner::Empty,
90        };
91
92        Self {
93            inner: Mutex::new(inner),
94            message_factory,
95        }
96    }
97}
98
99#[cfg(unix)]
100fn create_by_kind(kind: SignalKind) -> io::Result<unix::Signal> {
101    use signal::unix::SignalKind as TSK;
102
103    let kind = match kind {
104        SignalKind::CtrlC => TSK::interrupt(),
105        SignalKind::Raw(signum) => TSK::from_raw(signum),
106        SignalKind::Alarm => TSK::alarm(),
107        SignalKind::Child => TSK::child(),
108        SignalKind::Hangup => TSK::hangup(),
109        SignalKind::Interrupt => TSK::interrupt(),
110        SignalKind::Io => TSK::io(),
111        SignalKind::Pipe => TSK::pipe(),
112        SignalKind::Quit => TSK::quit(),
113        SignalKind::Terminate => TSK::terminate(),
114        SignalKind::User1 => TSK::user_defined1(),
115        SignalKind::User2 => TSK::user_defined2(),
116        SignalKind::WindowChange => TSK::window_change(),
117    };
118
119    unix::signal(kind)
120}
121
122#[sealed]
123impl<M, F> crate::source::Source for Signal<F>
124where
125    F: Fn() -> M,
126    M: Message,
127{
128    fn poll_recv(&self, cx: &mut task::Context<'_>) -> Poll<Option<Envelope>> {
129        match &mut *self.inner.lock() {
130            SignalInner::CtrlC(inner) => {
131                if let Err(err) = futures::ready!(inner.poll(cx)) {
132                    tracing::error!(error = %err, "failed to receive a signal");
133                }
134
135                assert!(inner.try_set(signal::ctrl_c()).is_ok());
136            }
137            #[cfg(unix)]
138            SignalInner::Unix(inner) => {
139                if !matches!(inner.poll_recv(cx), Poll::Ready(Some(()))) {
140                    return Poll::Pending;
141                }
142            }
143            SignalInner::Empty => return Poll::Pending,
144        };
145
146        let message = (self.message_factory)();
147        let kind = MessageKind::Regular { sender: Addr::NULL };
148        let trace_id = TraceId::generate();
149        let envelope = Envelope::with_trace_id(message, kind, trace_id).upcast();
150        Poll::Ready(Some(envelope))
151    }
152}
153
#[cfg(test)]
#[cfg(feature = "test-util")]
mod tests {
    use super::*;

    use futures::{future::poll_fn, poll};

    use elfo_macros::message;

    use crate::{assert_msg, source::Source};

    #[message(elfo = crate)]
    struct SomeSignal;

    /// Checks that `signal` stays idle until `signum` is raised and then yields
    /// a single `SomeSignal`, repeating the cycle so that re-arming is exercised.
    #[cfg(unix)]
    async fn check_delivery(signal: &Signal<impl Fn() -> SomeSignal>, signum: libc::c_int) {
        const ROUNDS: usize = 6;

        for _round in 0..ROUNDS {
            let idle = poll!(poll_fn(|cx| signal.poll_recv(cx)));
            assert!(idle.is_pending());

            send_signal(signum);

            let envelope = poll_fn(|cx| signal.poll_recv(cx)).await.unwrap();
            assert_msg!(envelope, SomeSignal);
        }
    }

    #[tokio::test]
    #[cfg(unix)]
    async fn unix_signal() {
        let signal = Signal::new(SignalKind::User1, || SomeSignal);
        check_delivery(&signal, libc::SIGUSR1).await;
    }

    #[tokio::test]
    #[cfg(unix)]
    async fn ctrl_c() {
        let signal = Signal::new(SignalKind::CtrlC, || SomeSignal);
        check_delivery(&signal, libc::SIGINT).await;
    }

    /// Raises `signal` in the current process and asserts that `kill` succeeded.
    fn send_signal(signal: libc::c_int) {
        use libc::{getpid, kill};

        // SAFETY: `getpid` and `kill` take and return plain integers; signalling
        // our own pid does not touch any Rust-managed memory.
        let rc = unsafe { kill(getpid(), signal) };
        assert_eq!(rc, 0);
    }
}