Skip to main content

argus_worker/
shutdown.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::Arc;
3
4use tokio::signal;
5use tokio::sync::broadcast;
6
7#[derive(Clone)]
8pub struct ShutdownSignal {
9    sender: broadcast::Sender<()>,
10    is_shutdown: Arc<AtomicBool>,
11}
12
13impl ShutdownSignal {
14    pub fn new() -> Self {
15        let (sender, _) = broadcast::channel(1);
16        Self {
17            sender,
18            is_shutdown: Arc::new(AtomicBool::new(false)),
19        }
20    }
21
22    pub fn subscribe(&self) -> broadcast::Receiver<()> {
23        self.sender.subscribe()
24    }
25
26    pub fn trigger(&self) {
27        self.is_shutdown.store(true, Ordering::SeqCst);
28        let _ = self.sender.send(());
29        tracing::info!("shutdown signal triggered");
30    }
31
32    pub fn is_shutdown(&self) -> bool {
33        self.is_shutdown.load(Ordering::SeqCst)
34    }
35
36    pub async fn wait_for_signal(&self) {
37        let mut receiver = self.subscribe();
38        let _ = receiver.recv().await;
39    }
40}
41
42impl Default for ShutdownSignal {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48pub async fn listen_for_shutdown(signal: ShutdownSignal) {
49    tokio::select! {
50        _ = signal::ctrl_c() => {
51            tracing::info!("received SIGINT (Ctrl+C)");
52            signal.trigger();
53        }
54        _ = wait_for_sigterm() => {
55            tracing::info!("received SIGTERM");
56            signal.trigger();
57        }
58    }
59}
60
61#[cfg(unix)]
62async fn wait_for_sigterm() {
63    use tokio::signal::unix::{signal, SignalKind};
64    let mut sigterm = signal(SignalKind::terminate()).expect("failed to setup SIGTERM handler");
65    sigterm.recv().await;
66}
67
/// Non-Unix: this future never resolves, so the SIGTERM branch of a
/// `select!` is never taken.
#[cfg(not(unix))]
async fn wait_for_sigterm() {
    let never = std::future::pending::<()>();
    never.await;
}
72
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a notification to arrive.
    const DEADLINE: Duration = Duration::from_millis(100);

    /// The flag starts cleared and is set by `trigger()`.
    #[tokio::test]
    async fn shutdown_signal_triggers() {
        let shutdown = ShutdownSignal::new();
        assert!(!shutdown.is_shutdown());

        shutdown.trigger();
        assert!(shutdown.is_shutdown());
    }

    /// Every receiver subscribed before the trigger gets the notification.
    #[tokio::test]
    async fn multiple_subscribers_receive_signal() {
        let shutdown = ShutdownSignal::new();
        let mut first = shutdown.subscribe();
        let mut second = shutdown.subscribe();

        shutdown.trigger();

        for receiver in [&mut first, &mut second] {
            let outcome = tokio::time::timeout(DEADLINE, receiver.recv()).await;
            outcome.expect("timeout").expect("receive");
        }
    }

    /// A task parked in `wait_for_signal` finishes once the signal fires.
    #[tokio::test]
    async fn wait_for_signal_completes_on_trigger() {
        let shutdown = ShutdownSignal::new();

        let waiter = {
            let shutdown = shutdown.clone();
            tokio::spawn(async move { shutdown.wait_for_signal().await })
        };

        // Give the spawned task time to subscribe before triggering.
        tokio::time::sleep(Duration::from_millis(10)).await;
        shutdown.trigger();

        let joined = tokio::time::timeout(DEADLINE, waiter).await;
        joined.expect("timeout").expect("join");
    }
}