1use std::{
2 io,
3 os::raw::c_int,
4 task::{self, Poll},
5};
6
7use parking_lot::Mutex;
8use sealed::sealed;
9use tokio::signal;
10#[cfg(unix)]
11use tokio::signal::unix;
12use tokio_util::sync::ReusableBoxFuture;
13
14use crate::{
15 addr::Addr,
16 envelope::{Envelope, MessageKind},
17 message::Message,
18 tracing::TraceId,
19};
20
21pub struct Signal<F> {
27 inner: Mutex<SignalInner>,
28 message_factory: F,
29}
30
31enum SignalInner {
32 CtrlC(ReusableBoxFuture<'static, io::Result<()>>),
33 #[cfg(unix)]
34 Unix(unix::Signal),
35 Empty,
36}
37
/// A kind of signal to listen to.
///
/// [`SignalKind::CtrlC`] is handled through `tokio::signal::ctrl_c()`.
/// Every other kind is translated to the matching
/// `tokio::signal::unix::SignalKind` on unix targets; on non-unix targets
/// those kinds produce a source that never emits anything.
///
/// The type is `Copy`, comparable and hashable, so it can be used directly
/// as a key in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SignalKind {
    /// The "ctrl-c" notification.
    CtrlC,

    /// An arbitrary signal identified by its raw number.
    Raw(c_int),
    /// `SIGALRM`
    Alarm,
    /// `SIGCHLD`
    Child,
    /// `SIGHUP`
    Hangup,
    /// `SIGINT`
    Interrupt,
    /// `SIGIO`
    Io,
    /// `SIGPIPE`
    Pipe,
    /// `SIGQUIT`
    Quit,
    /// `SIGTERM`
    Terminate,
    /// `SIGUSR1`
    User1,
    /// `SIGUSR2`
    User2,
    /// `SIGWINCH`
    WindowChange,
}
70
71impl<F> Signal<F> {
72 pub fn new(kind: SignalKind, message_factory: F) -> Self {
73 let inner = match kind {
74 SignalKind::CtrlC => SignalInner::CtrlC(ReusableBoxFuture::new(signal::ctrl_c())),
76 #[cfg(unix)]
77 _ => match create_by_kind(kind) {
78 Ok(inner) => SignalInner::Unix(inner),
79 Err(err) => {
80 tracing::warn!(
81 kind = ?kind,
82 error = %err,
83 "failed to create a signal handler"
84 );
85 SignalInner::Empty
86 }
87 },
88 #[cfg(not(unix))]
89 _ => SignalInner::Empty,
90 };
91
92 Self {
93 inner: Mutex::new(inner),
94 message_factory,
95 }
96 }
97}
98
99#[cfg(unix)]
100fn create_by_kind(kind: SignalKind) -> io::Result<unix::Signal> {
101 use signal::unix::SignalKind as TSK;
102
103 let kind = match kind {
104 SignalKind::CtrlC => TSK::interrupt(),
105 SignalKind::Raw(signum) => TSK::from_raw(signum),
106 SignalKind::Alarm => TSK::alarm(),
107 SignalKind::Child => TSK::child(),
108 SignalKind::Hangup => TSK::hangup(),
109 SignalKind::Interrupt => TSK::interrupt(),
110 SignalKind::Io => TSK::io(),
111 SignalKind::Pipe => TSK::pipe(),
112 SignalKind::Quit => TSK::quit(),
113 SignalKind::Terminate => TSK::terminate(),
114 SignalKind::User1 => TSK::user_defined1(),
115 SignalKind::User2 => TSK::user_defined2(),
116 SignalKind::WindowChange => TSK::window_change(),
117 };
118
119 unix::signal(kind)
120}
121
122#[sealed]
123impl<M, F> crate::source::Source for Signal<F>
124where
125 F: Fn() -> M,
126 M: Message,
127{
128 fn poll_recv(&self, cx: &mut task::Context<'_>) -> Poll<Option<Envelope>> {
129 match &mut *self.inner.lock() {
130 SignalInner::CtrlC(inner) => {
131 if let Err(err) = futures::ready!(inner.poll(cx)) {
132 tracing::error!(error = %err, "failed to receive a signal");
133 }
134
135 assert!(inner.try_set(signal::ctrl_c()).is_ok());
136 }
137 #[cfg(unix)]
138 SignalInner::Unix(inner) => {
139 if !matches!(inner.poll_recv(cx), Poll::Ready(Some(()))) {
140 return Poll::Pending;
141 }
142 }
143 SignalInner::Empty => return Poll::Pending,
144 };
145
146 let message = (self.message_factory)();
147 let kind = MessageKind::Regular { sender: Addr::NULL };
148 let trace_id = TraceId::generate();
149 let envelope = Envelope::with_trace_id(message, kind, trace_id).upcast();
150 Poll::Ready(Some(envelope))
151 }
152}
153
#[cfg(test)]
#[cfg(feature = "test-util")]
mod tests {
    use super::*;

    use futures::{future::poll_fn, poll};

    use elfo_macros::message;

    use crate::{assert_msg, source::Source};

    /// The message emitted by the sources under test.
    #[message(elfo = crate)]
    struct SomeSignal;

    /// A `User1` source stays pending until `SIGUSR1` is delivered and then
    /// emits exactly one message per delivered signal.
    #[tokio::test]
    #[cfg(unix)]
    async fn unix_signal() {
        let signal = Signal::new(SignalKind::User1, || SomeSignal);

        for _ in 0..=5 {
            // Nothing has been sent yet in this round, so the source must be idle.
            assert!(poll!(poll_fn(|cx| signal.poll_recv(cx))).is_pending());
            send_signal(libc::SIGUSR1);
            assert_msg!(
                poll_fn(|cx| signal.poll_recv(cx)).await.unwrap(),
                SomeSignal
            );
        }
    }

    /// A `CtrlC` source stays pending until `SIGINT` is delivered and keeps
    /// working across repeated deliveries.
    #[tokio::test]
    #[cfg(unix)]
    async fn ctrl_c() {
        let signal = Signal::new(SignalKind::CtrlC, || SomeSignal);

        for _ in 0..=5 {
            // Nothing has been sent yet in this round, so the source must be idle.
            assert!(poll!(poll_fn(|cx| signal.poll_recv(cx))).is_pending());
            send_signal(libc::SIGINT);
            assert_msg!(
                poll_fn(|cx| signal.poll_recv(cx)).await.unwrap(),
                SomeSignal
            );
        }
    }

    /// Sends `signal` to the current process and panics if `kill` reports a
    /// failure.
    ///
    /// Gated on unix because both callers are unix-only and `libc::kill` is
    /// not available on other targets.
    #[cfg(unix)]
    fn send_signal(signal: libc::c_int) {
        use libc::{getpid, kill};

        // SAFETY: `getpid` has no preconditions, and `kill` only receives the
        // pid of this very process plus a plain signal number; no pointers or
        // memory owned by Rust are involved.
        let rc = unsafe { kill(getpid(), signal) };
        assert_eq!(rc, 0);
    }
}