mecomp_daemon/
termination.rs

1use std::sync::{Arc, atomic::AtomicBool};
2
3#[cfg(unix)]
4use tokio::signal::unix::signal;
5use tokio::sync::broadcast;
6
/// The reason the application was asked to shut down.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Interrupted {
    /// The process received `SIGINT` from the operating system.
    OsSigInt,
    /// The process received `SIGQUIT` from the operating system.
    OsSigQuit,
    /// The process received `SIGTERM` from the operating system.
    OsSigTerm,
    /// A user-initiated interrupt (e.g. Ctrl-C reported by `tokio::signal::ctrl_c`, or an
    /// explicit `Terminator::terminate` call).
    UserInt,
}
14
/// Number of termination signals after which the signal listener stops waiting for a graceful
/// shutdown and exits the process immediately.
const FORCE_QUIT_THRESHOLD: u8 = 3;
16
17#[derive(Debug)]
18/// Used to handle the termination of the application.
19///
20/// A struct that handles listening for interrupt signals, and/or tracking whether an interrupt signal has been received.
21///
22/// Essentially, the receiving side of the broadcast channel.
23pub struct InterruptReceiver {
24    interrupt_rx: broadcast::Receiver<Interrupted>,
25    stopped: Arc<AtomicBool>,
26}
27
28impl InterruptReceiver {
29    #[must_use]
30    #[inline]
31    pub fn new(interrupt_rx: broadcast::Receiver<Interrupted>) -> Self {
32        Self {
33            interrupt_rx,
34            stopped: Arc::new(AtomicBool::new(false)),
35        }
36    }
37
38    #[must_use]
39    #[inline]
40    /// Create a dummy receiver that doesn't receive any signals
41    ///
42    /// Attempting to wait on this receiver will wait indefinitely.
43    pub fn dummy() -> Self {
44        let (tx, rx) = broadcast::channel(1);
45
46        // forget the sender so it's dropped w/o calling its destructor
47        std::mem::forget(tx);
48
49        Self {
50            interrupt_rx: rx,
51            stopped: Arc::new(AtomicBool::new(false)),
52        }
53    }
54
55    /// Wait for an interrupt signal to be received.
56    ///
57    /// # Errors
58    ///
59    /// Fails if the interrupt signal cannot be received (e.g. the sender has been dropped)
60    #[inline]
61    pub async fn wait(&mut self) -> Result<Interrupted, tokio::sync::broadcast::error::RecvError> {
62        let interrupted = self.interrupt_rx.recv().await?;
63
64        // Set the stopped flag to true
65        self.stopped
66            .store(true, std::sync::atomic::Ordering::SeqCst);
67
68        Ok(interrupted)
69    }
70
71    /// Re-subscribe to the broadcast channel.
72    ///
73    /// Gives you a new receiver that can be used to receive interrupt signals.
74    #[must_use]
75    #[inline]
76    pub fn resubscribe(&self) -> Self {
77        // Resubscribe to the broadcast channel
78        Self {
79            interrupt_rx: self.interrupt_rx.resubscribe(),
80            stopped: self.stopped.clone(),
81        }
82    }
83
84    /// Check if an interrupt signal has been received previously.
85    #[must_use]
86    #[inline]
87    pub fn is_stopped(&self) -> bool {
88        self.stopped.load(std::sync::atomic::Ordering::SeqCst)
89    }
90}
91
92#[derive(Debug, Clone)]
93/// Used to handle the termination of the application.
94///
95/// A struct that handles sending interrupt signals to the application.
96///
97/// Essentially, the sending side of the broadcast channel.
98pub struct Terminator {
99    interrupt_tx: broadcast::Sender<Interrupted>,
100}
101
102impl Terminator {
103    #[must_use]
104    #[inline]
105    pub const fn new(interrupt_tx: broadcast::Sender<Interrupted>) -> Self {
106        Self { interrupt_tx }
107    }
108
109    /// Send an interrupt signal to the application.
110    ///
111    /// # Errors
112    ///
113    /// Fails if the interrupt signal cannot be sent (e.g. the receiver has been dropped)
114    #[inline]
115    pub fn terminate(&self, interrupted: Interrupted) -> anyhow::Result<()> {
116        self.interrupt_tx.send(interrupted)?;
117
118        Ok(())
119    }
120}
121
122#[cfg(unix)]
123#[inline]
124async fn terminate_by_signal(terminator: Terminator) {
125    let mut interrupt_signal = signal(tokio::signal::unix::SignalKind::interrupt())
126        .expect("failed to create interrupt signal stream");
127    let mut term_signal = signal(tokio::signal::unix::SignalKind::terminate())
128        .expect("failed to create terminate signal stream");
129    let mut quit_signal = signal(tokio::signal::unix::SignalKind::quit())
130        .expect("failed to create quit signal stream");
131
132    let mut signal_tick = tokio::time::interval(std::time::Duration::from_secs(1));
133    signal_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
134
135    let mut kill_count = 0;
136
137    loop {
138        // if we've received 3 signals, we should forcefully terminate the application
139        if kill_count >= FORCE_QUIT_THRESHOLD {
140            log::warn!(
141                "Received {FORCE_QUIT_THRESHOLD} signals, forcefully terminating the application"
142            );
143            std::process::exit(1);
144        }
145
146        tokio::select! {
147            _ = signal_tick.tick() => {
148                // If we receive a signal within 1 second, we can ignore it
149                // and wait for the next signal.
150            }
151            _ = interrupt_signal.recv() => {
152                if let Err(e) = terminator.terminate(Interrupted::OsSigInt) {
153                    log::warn!("failed to send interrupt signal: {e}");
154                }
155                kill_count += 1;
156            }
157            _ = term_signal.recv() => {
158                if let Err(e) = terminator.terminate(Interrupted::OsSigTerm) {
159                    log::warn!("failed to send terminate signal: {e}");
160                }
161                kill_count += 1;
162            }
163            _ = quit_signal.recv() => {
164                if let Err(e) = terminator.terminate(Interrupted::OsSigQuit) {
165                    log::warn!("failed to send quit signal: {e}");
166                }
167                kill_count += 1;
168            }
169            _ = tokio::signal::ctrl_c() => {
170                if let Err(e) = terminator.terminate(Interrupted::UserInt) {
171                    log::warn!("failed to send interrupt signal: {e}");
172                }
173                kill_count += 1;
174            }
175        }
176    }
177}
178
/// Listen for Ctrl-C and forward it to the application through `terminator`.
///
/// On non-unix platforms Ctrl-C is the only termination signal handled. Each one is broadcast
/// as [`Interrupted::UserInt`] and counted; once [`FORCE_QUIT_THRESHOLD`] have been received,
/// the process exits immediately with status 1.
///
/// If the Ctrl-C listener cannot be installed, the error is logged and the task returns. The
/// failure is not counted as a received signal.
#[cfg(not(unix))]
async fn terminate_by_signal(terminator: Terminator) {
    let mut kill_count: u8 = 0;

    loop {
        // `ctrl_c()` resolves to `Err` immediately when the handler cannot be installed.
        // Counting that as a signal would spin through the loop and force-exit the process.
        if let Err(e) = tokio::signal::ctrl_c().await {
            log::error!("failed to listen for ctrl-c: {e}");
            return;
        }

        if let Err(e) = terminator.terminate(Interrupted::UserInt) {
            log::warn!("failed to send interrupt signal: {e}");
        }

        kill_count += 1;
        // Repeated signals mean the graceful shutdown is taking too long (or is stuck).
        if kill_count >= FORCE_QUIT_THRESHOLD {
            log::warn!(
                "Received {FORCE_QUIT_THRESHOLD} signals, forcefully terminating the application"
            );
            std::process::exit(1);
        }
    }
}
212
213/// create a broadcast channel for retrieving the application kill signal
214///
215/// # Panics
216///
217/// This function will panic if the tokio runtime cannot be created.
218#[allow(clippy::module_name_repetitions)]
219#[must_use]
220#[inline]
221pub fn create_termination() -> (Terminator, InterruptReceiver) {
222    let (tx, rx) = broadcast::channel(2);
223    let terminator = Terminator::new(tx);
224    let interrupt = InterruptReceiver::new(rx);
225
226    tokio::spawn(terminate_by_signal(terminator.clone()));
227
228    (terminator, interrupt)
229}
230
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// An interrupt sent through the [`Terminator`] must wake the paired
    /// [`InterruptReceiver`] with that very same interrupt.
    #[rstest]
    #[timeout(Duration::from_secs(1))]
    #[tokio::test]
    async fn test_terminate() {
        let (terminator, mut receiver) = create_termination();

        let sent = Interrupted::UserInt;
        terminator
            .terminate(sent)
            .expect("failed to send interrupt signal");

        let received = receiver.wait().await;
        assert_eq!(received, Ok(sent));
    }
}