1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::Arc;
3
4use tokio::signal;
5use tokio::sync::broadcast;
6
7#[derive(Clone)]
8pub struct ShutdownSignal {
9 sender: broadcast::Sender<()>,
10 is_shutdown: Arc<AtomicBool>,
11}
12
13impl ShutdownSignal {
14 pub fn new() -> Self {
15 let (sender, _) = broadcast::channel(1);
16 Self {
17 sender,
18 is_shutdown: Arc::new(AtomicBool::new(false)),
19 }
20 }
21
22 pub fn subscribe(&self) -> broadcast::Receiver<()> {
23 self.sender.subscribe()
24 }
25
26 pub fn trigger(&self) {
27 self.is_shutdown.store(true, Ordering::SeqCst);
28 let _ = self.sender.send(());
29 tracing::info!("shutdown signal triggered");
30 }
31
32 pub fn is_shutdown(&self) -> bool {
33 self.is_shutdown.load(Ordering::SeqCst)
34 }
35
36 pub async fn wait_for_signal(&self) {
37 let mut receiver = self.subscribe();
38 let _ = receiver.recv().await;
39 }
40}
41
42impl Default for ShutdownSignal {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48pub async fn listen_for_shutdown(signal: ShutdownSignal) {
49 tokio::select! {
50 _ = signal::ctrl_c() => {
51 tracing::info!("received SIGINT (Ctrl+C)");
52 signal.trigger();
53 }
54 _ = wait_for_sigterm() => {
55 tracing::info!("received SIGTERM");
56 signal.trigger();
57 }
58 }
59}
60
61#[cfg(unix)]
62async fn wait_for_sigterm() {
63 use tokio::signal::unix::{signal, SignalKind};
64 let mut sigterm = signal(SignalKind::terminate()).expect("failed to setup SIGTERM handler");
65 sigterm.recv().await;
66}
67
/// Non-Unix stand-in for the SIGTERM listener: this future never resolves.
#[cfg(not(unix))]
async fn wait_for_sigterm() {
    let never = std::future::pending::<()>();
    never.await;
}
72
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a notification or a task.
    const WAIT_LIMIT: Duration = Duration::from_millis(100);

    /// The flag starts out `false` and flips to `true` after `trigger`.
    #[tokio::test]
    async fn shutdown_signal_triggers() {
        let shutdown = ShutdownSignal::new();
        assert!(!shutdown.is_shutdown());

        shutdown.trigger();
        assert!(shutdown.is_shutdown());
    }

    /// Every receiver subscribed before `trigger` gets the notification.
    #[tokio::test]
    async fn multiple_subscribers_receive_signal() {
        let shutdown = ShutdownSignal::new();
        let mut receivers = [shutdown.subscribe(), shutdown.subscribe()];

        shutdown.trigger();

        for receiver in &mut receivers {
            tokio::time::timeout(WAIT_LIMIT, receiver.recv())
                .await
                .expect("timeout")
                .expect("receive");
        }
    }

    /// A task blocked in `wait_for_signal` finishes once `trigger` is called.
    #[tokio::test]
    async fn wait_for_signal_completes_on_trigger() {
        let shutdown = ShutdownSignal::new();
        let waiter = {
            let shutdown = shutdown.clone();
            tokio::spawn(async move { shutdown.wait_for_signal().await })
        };

        // Give the spawned task time to start waiting before triggering.
        tokio::time::sleep(Duration::from_millis(10)).await;
        shutdown.trigger();

        tokio::time::timeout(WAIT_LIMIT, waiter)
            .await
            .expect("timeout")
            .expect("join");
    }
}