mecomp_daemon/
termination.rs1use std::sync::{Arc, atomic::AtomicBool};
2
3#[cfg(unix)]
4use tokio::signal::unix::signal;
5use tokio::sync::broadcast;
6
/// The reason a shutdown of the application was requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Interrupted {
    /// `SIGINT` was delivered to the process.
    OsSigInt,
    /// `SIGQUIT` was delivered to the process.
    OsSigQuit,
    /// `SIGTERM` was delivered to the process.
    OsSigTerm,
    /// An interrupt requested by the user, e.g. a Ctrl-C reported by
    /// `tokio::signal::ctrl_c`, or a shutdown requested in-process.
    UserInt,
}
14
/// How many termination signals the signal handler accepts before it stops waiting for a
/// graceful shutdown and exits the process with status 1.
const FORCE_QUIT_THRESHOLD: u8 = 3;
16
17#[derive(Debug)]
18pub struct InterruptReceiver {
24 interrupt_rx: broadcast::Receiver<Interrupted>,
25 stopped: Arc<AtomicBool>,
26}
27
28impl InterruptReceiver {
29 #[must_use]
30 #[inline]
31 pub fn new(interrupt_rx: broadcast::Receiver<Interrupted>) -> Self {
32 Self {
33 interrupt_rx,
34 stopped: Arc::new(AtomicBool::new(false)),
35 }
36 }
37
38 #[must_use]
39 #[inline]
40 pub fn dummy() -> Self {
44 let (tx, rx) = broadcast::channel(1);
45
46 std::mem::forget(tx);
48
49 Self {
50 interrupt_rx: rx,
51 stopped: Arc::new(AtomicBool::new(false)),
52 }
53 }
54
55 #[inline]
61 pub async fn wait(&mut self) -> Result<Interrupted, tokio::sync::broadcast::error::RecvError> {
62 let interrupted = self.interrupt_rx.recv().await?;
63
64 self.stopped
66 .store(true, std::sync::atomic::Ordering::SeqCst);
67
68 Ok(interrupted)
69 }
70
71 #[must_use]
75 #[inline]
76 pub fn resubscribe(&self) -> Self {
77 Self {
79 interrupt_rx: self.interrupt_rx.resubscribe(),
80 stopped: self.stopped.clone(),
81 }
82 }
83
84 #[must_use]
86 #[inline]
87 pub fn is_stopped(&self) -> bool {
88 self.stopped.load(std::sync::atomic::Ordering::SeqCst)
89 }
90}
91
92#[derive(Debug, Clone)]
93pub struct Terminator {
99 interrupt_tx: broadcast::Sender<Interrupted>,
100}
101
102impl Terminator {
103 #[must_use]
104 #[inline]
105 pub const fn new(interrupt_tx: broadcast::Sender<Interrupted>) -> Self {
106 Self { interrupt_tx }
107 }
108
109 #[inline]
115 pub fn terminate(&self, interrupted: Interrupted) -> anyhow::Result<()> {
116 self.interrupt_tx.send(interrupted)?;
117
118 Ok(())
119 }
120}
121
122#[cfg(unix)]
123#[inline]
124async fn terminate_by_signal(terminator: Terminator) {
125 let mut interrupt_signal = signal(tokio::signal::unix::SignalKind::interrupt())
126 .expect("failed to create interrupt signal stream");
127 let mut term_signal = signal(tokio::signal::unix::SignalKind::terminate())
128 .expect("failed to create terminate signal stream");
129 let mut quit_signal = signal(tokio::signal::unix::SignalKind::quit())
130 .expect("failed to create quit signal stream");
131
132 let mut signal_tick = tokio::time::interval(std::time::Duration::from_secs(1));
133 signal_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
134
135 let mut kill_count = 0;
136
137 loop {
138 if kill_count >= FORCE_QUIT_THRESHOLD {
140 log::warn!(
141 "Received {FORCE_QUIT_THRESHOLD} signals, forcefully terminating the application"
142 );
143 std::process::exit(1);
144 }
145
146 tokio::select! {
147 _ = signal_tick.tick() => {
148 }
151 _ = interrupt_signal.recv() => {
152 if let Err(e) = terminator.terminate(Interrupted::OsSigInt) {
153 log::warn!("failed to send interrupt signal: {e}");
154 }
155 kill_count += 1;
156 }
157 _ = term_signal.recv() => {
158 if let Err(e) = terminator.terminate(Interrupted::OsSigTerm) {
159 log::warn!("failed to send terminate signal: {e}");
160 }
161 kill_count += 1;
162 }
163 _ = quit_signal.recv() => {
164 if let Err(e) = terminator.terminate(Interrupted::OsSigQuit) {
165 log::warn!("failed to send quit signal: {e}");
166 }
167 kill_count += 1;
168 }
169 _ = tokio::signal::ctrl_c() => {
170 if let Err(e) = terminator.terminate(Interrupted::UserInt) {
171 log::warn!("failed to send interrupt signal: {e}");
172 }
173 kill_count += 1;
174 }
175 }
176 }
177}
178
/// Listens for Ctrl-C and forwards each one to `terminator` as [`Interrupted::UserInt`].
///
/// Every Ctrl-C counts toward [`FORCE_QUIT_THRESHOLD`]. Once that many have been received, the
/// process exits with status 1 without waiting for a graceful shutdown.
///
/// If listening for Ctrl-C fails, the error is logged and this task returns. The failure is not
/// counted as a signal, so a broken handler cannot force the process to exit.
#[cfg(not(unix))]
async fn terminate_by_signal(terminator: Terminator) {
    // Periodic wake-up; its branch does no work.
    let mut signal_tick = tokio::time::interval(std::time::Duration::from_secs(1));
    signal_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);

    let mut kill_count = 0;

    loop {
        if kill_count >= FORCE_QUIT_THRESHOLD {
            log::warn!(
                "Received {FORCE_QUIT_THRESHOLD} signals, forcefully terminating the application"
            );
            std::process::exit(1);
        }

        tokio::select! {
            _ = signal_tick.tick() => {}
            result = tokio::signal::ctrl_c() => {
                // A failed registration completes immediately; retrying would spin, and
                // counting it as a signal would exit the process after a few iterations.
                if let Err(e) = result {
                    log::error!("failed to listen for ctrl-c events: {e}");
                    return;
                }
                if let Err(e) = terminator.terminate(Interrupted::UserInt) {
                    log::warn!("failed to send interrupt signal: {e}");
                }
                kill_count += 1;
            }
        }
    }
}
212
213#[allow(clippy::module_name_repetitions)]
219#[must_use]
220#[inline]
221pub fn create_termination() -> (Terminator, InterruptReceiver) {
222 let (tx, rx) = broadcast::channel(2);
223 let terminator = Terminator::new(tx);
224 let interrupt = InterruptReceiver::new(rx);
225
226 tokio::spawn(terminate_by_signal(terminator.clone()));
227
228 (terminator, interrupt)
229}
230
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// A manual `terminate` call reaches the receiver returned alongside the terminator.
    #[rstest]
    #[timeout(Duration::from_secs(1))]
    #[tokio::test]
    async fn test_terminate() {
        let (terminator, mut receiver) = create_termination();

        terminator
            .terminate(Interrupted::UserInt)
            .expect("failed to send interrupt signal");

        let received = receiver.wait().await;
        assert_eq!(received, Ok(Interrupted::UserInt));
    }
}