Skip to main content

elfo_core/
signal.rs

1use std::{
2    any::Any,
3    io,
4    os::raw::c_int,
5    pin::Pin,
6    task::{self, Poll},
7};
8
9use pin_project::pin_project;
10use sealed::sealed;
11use serde::{Deserialize, Serialize};
12#[cfg(unix)]
13use tokio::signal;
14#[cfg(unix)]
15use tokio::signal::unix;
16#[cfg(windows)]
17use tokio::signal::windows;
18
19use crate::{
20    envelope::{Envelope, MessageKind},
21    message::Message,
22    source::{SourceArc, SourceStream, UnattachedSource},
23    tracing::TraceId,
24    Addr,
25};
26
/// A source that emits a message every time a signal is received.
/// The stored message is cloned for every received signal.
///
/// It's based on the tokio implementation, so it's worth reading about its
/// [caveats](https://docs.rs/tokio/latest/tokio/signal/unix/struct.Signal.html).
///
/// If the [`SignalKind`] has no effect on the current platform, or the signal
/// handler cannot be registered, the source is created disabled and emits no
/// messages. A registration failure is additionally logged as a warning.
///
/// # Tracing
///
/// Every message starts a new trace, thus a new trace id is generated and
/// assigned to the current scope.
///
/// # Example
///
/// ```
/// # use elfo_core as elfo;
/// # async fn exec(mut ctx: elfo::Context) {
/// # use elfo::{message, msg};
/// use elfo::signal::{Signal, SignalKind};
///
/// #[message]
/// struct ReloadFile;
///
/// ctx.attach(Signal::new(SignalKind::UnixHangup, ReloadFile));
///
/// while let Some(envelope) = ctx.recv().await {
///     msg!(match envelope {
///         ReloadFile => { /* ... */ },
///     });
/// }
/// # }
/// ```
pub struct Signal<M> {
    // Shared handle to the underlying stream; also used to terminate it.
    source: SourceArc<SignalSource<M>>,
}
62
63#[sealed]
64impl<M: Message> crate::source::SourceHandle for Signal<M> {
65    fn is_terminated(&self) -> bool {
66        self.source.is_terminated()
67    }
68
69    fn terminate_by_ref(&self) -> bool {
70        self.source.terminate_by_ref()
71    }
72}
73
/// The stream behind [`Signal`]: pairs the message to emit with the
/// platform-specific signal listener.
#[pin_project]
struct SignalSource<M> {
    /// The message cloned into a new envelope on every received signal.
    message: M,
    /// The OS signal listener, or `SignalInner::Disabled` if none is active.
    inner: SignalInner,
}
79
/// A platform-specific signal listener.
enum SignalInner {
    /// No listener is active: either the requested kind has no effect on this
    /// platform or registering the handler failed. Polling it yields `None`
    /// immediately.
    Disabled,
    /// Listens for "ctrl-c" notifications on Windows.
    #[cfg(windows)]
    WindowsCtrlC(windows::CtrlC),
    /// Listens for a UNIX signal.
    #[cfg(unix)]
    Unix(unix::Signal),
}
87
/// A kind of signal to listen to.
///
/// Every variant exists on every platform, which helps to avoid writing
/// `#[cfg(_)]` everywhere around signals:
///
/// * `Unix*` variants take effect only on UNIX systems; on other systems the
///   resulting [`Signal`] produces nothing.
/// * `Windows*` variants take effect only on Windows; on other systems the
///   resulting [`Signal`] produces nothing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SignalKind {
    /// The "ctrl-c" notification (Windows only).
    WindowsCtrlC,

    /// A raw signal number, passed as-is to
    /// `tokio::signal::unix::SignalKind::from_raw` (UNIX only).
    ///
    /// If a handler for it cannot be registered, a warning is logged and the
    /// resulting [`Signal`] produces nothing.
    UnixRaw(c_int),
    /// SIGALRM
    UnixAlarm,
    /// SIGCHLD
    UnixChild,
    /// SIGHUP
    UnixHangup,
    /// SIGINT
    UnixInterrupt,
    /// SIGIO
    UnixIo,
    /// SIGPIPE
    UnixPipe,
    /// SIGQUIT
    UnixQuit,
    /// SIGTERM
    UnixTerminate,
    /// SIGUSR1
    UnixUser1,
    /// SIGUSR2
    UnixUser2,
    /// SIGWINCH
    UnixWindowChange,
}
127
128impl<M: Message> Signal<M> {
129    /// Creates an unattached instance of [`Signal`].
130    pub fn new(kind: SignalKind, message: M) -> UnattachedSource<Self> {
131        let inner = SignalInner::new(kind).unwrap_or_else(|err| {
132            tracing::warn!(kind = ?kind, error = %err, "failed to create a signal handler");
133            SignalInner::Disabled
134        });
135
136        let source = SourceArc::new(SignalSource { message, inner }, false);
137        UnattachedSource::new(source, |source| Self { source })
138    }
139
140    /// Replaces a stored message with the provided one.
141    pub fn set_message(&self, message: M) {
142        let mut guard = ward!(self.source.lock());
143        *guard.stream().project().message = message;
144    }
145}
146
147impl SignalInner {
148    #[cfg(unix)]
149    fn new(kind: SignalKind) -> io::Result<SignalInner> {
150        use signal::unix::SignalKind as U;
151
152        let kind = match kind {
153            SignalKind::UnixRaw(signum) => U::from_raw(signum),
154            SignalKind::UnixAlarm => U::alarm(),
155            SignalKind::UnixChild => U::child(),
156            SignalKind::UnixHangup => U::hangup(),
157            SignalKind::UnixInterrupt => U::interrupt(),
158            SignalKind::UnixIo => U::io(),
159            SignalKind::UnixPipe => U::pipe(),
160            SignalKind::UnixQuit => U::quit(),
161            SignalKind::UnixTerminate => U::terminate(),
162            SignalKind::UnixUser1 => U::user_defined1(),
163            SignalKind::UnixUser2 => U::user_defined2(),
164            SignalKind::UnixWindowChange => U::window_change(),
165            _ => return Ok(SignalInner::Disabled),
166        };
167
168        unix::signal(kind).map(SignalInner::Unix)
169    }
170
171    #[cfg(windows)]
172    fn new(kind: SignalKind) -> io::Result<SignalInner> {
173        match kind {
174            SignalKind::WindowsCtrlC => windows::ctrl_c().map(SignalInner::WindowsCtrlC),
175            _ => Ok(SignalInner::Disabled),
176        }
177    }
178
179    fn poll_recv(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<()>> {
180        match self {
181            SignalInner::Disabled => Poll::Ready(None),
182            #[cfg(windows)]
183            SignalInner::WindowsCtrlC(inner) => inner.poll_recv(cx),
184            #[cfg(unix)]
185            SignalInner::Unix(inner) => inner.poll_recv(cx),
186        }
187    }
188}
189
190impl<M: Message> SourceStream for SignalSource<M> {
191    fn as_any_mut(self: Pin<&mut Self>) -> Pin<&mut dyn Any> {
192        // SAFETY: we only cast here, it cannot move data.
193        unsafe { self.map_unchecked_mut(|s| s) }
194    }
195
196    fn poll_recv(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Envelope>> {
197        let this = self.project();
198
199        match this.inner.poll_recv(cx) {
200            Poll::Ready(Some(())) => {}
201            Poll::Ready(None) => return Poll::Ready(None),
202            Poll::Pending => return Poll::Pending,
203        }
204
205        let message = this.message.clone();
206        let kind = MessageKind::regular(Addr::NULL);
207        let trace_id = TraceId::generate();
208        let envelope = Envelope::with_trace_id(message, kind, trace_id);
209        Poll::Ready(Some(envelope))
210    }
211}