halldyll_core/observe/
shutdown.rs1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
24use std::sync::Arc;
25use std::time::Duration;
26use tokio::sync::Notify;
27use tokio::time::timeout;
28
29pub struct GracefulShutdown {
31 shutting_down: AtomicBool,
33 in_flight: AtomicUsize,
35 all_complete: Arc<Notify>,
37 shutdown_timeout: Duration,
39 completed: AtomicBool,
41}
42
43impl GracefulShutdown {
44 pub fn new(shutdown_timeout: Duration) -> Self {
46 Self {
47 shutting_down: AtomicBool::new(false),
48 in_flight: AtomicUsize::new(0),
49 all_complete: Arc::new(Notify::new()),
50 shutdown_timeout,
51 completed: AtomicBool::new(false),
52 }
53 }
54
55 pub fn default_timeout() -> Self {
57 Self::new(Duration::from_secs(30))
58 }
59
60 pub fn is_shutting_down(&self) -> bool {
62 self.shutting_down.load(Ordering::SeqCst)
63 }
64
65 pub fn in_flight_count(&self) -> usize {
67 self.in_flight.load(Ordering::SeqCst)
68 }
69
70 pub fn start_request(&self) -> Option<RequestGuard<'_>> {
72 if self.is_shutting_down() {
73 return None;
74 }
75
76 self.in_flight.fetch_add(1, Ordering::SeqCst);
77
78 if self.is_shutting_down() {
80 self.finish_request();
81 return None;
82 }
83
84 Some(RequestGuard {
85 shutdown: self,
86 })
87 }
88
89 fn finish_request(&self) {
91 let prev = self.in_flight.fetch_sub(1, Ordering::SeqCst);
92 if prev == 1 && self.is_shutting_down() {
93 self.all_complete.notify_waiters();
95 }
96 }
97
98 pub fn initiate(&self) {
100 self.shutting_down.store(true, Ordering::SeqCst);
101
102 if self.in_flight.load(Ordering::SeqCst) == 0 {
104 self.all_complete.notify_waiters();
105 }
106 }
107
108 pub async fn wait_for_completion(&self) -> ShutdownResult {
110 if !self.is_shutting_down() {
111 self.initiate();
112 }
113
114 if self.in_flight.load(Ordering::SeqCst) == 0 {
115 self.completed.store(true, Ordering::SeqCst);
116 return ShutdownResult::Clean;
117 }
118
119 match timeout(self.shutdown_timeout, self.all_complete.notified()).await {
120 Ok(_) => {
121 self.completed.store(true, Ordering::SeqCst);
122 ShutdownResult::Clean
123 }
124 Err(_) => {
125 let remaining = self.in_flight.load(Ordering::SeqCst);
126 self.completed.store(true, Ordering::SeqCst);
127 ShutdownResult::Timeout { remaining_requests: remaining }
128 }
129 }
130 }
131
132 pub fn is_completed(&self) -> bool {
134 self.completed.load(Ordering::SeqCst)
135 }
136
137 pub fn status(&self) -> ShutdownStatus {
139 ShutdownStatus {
140 shutting_down: self.is_shutting_down(),
141 in_flight: self.in_flight_count(),
142 completed: self.is_completed(),
143 }
144 }
145}
146
/// Outcome of waiting for a graceful shutdown to finish.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShutdownResult {
    /// The in-flight request count reached zero before the timeout.
    Clean,
    /// The timeout elapsed before the in-flight count reached zero.
    Timeout {
        /// In-flight request count observed when the timeout fired.
        remaining_requests: usize,
    },
}

impl ShutdownResult {
    /// Returns `true` for [`ShutdownResult::Clean`] and `false` for a timeout.
    pub fn is_clean(&self) -> bool {
        match self {
            Self::Clean => true,
            Self::Timeout { .. } => false,
        }
    }
}
165
166pub struct RequestGuard<'a> {
168 shutdown: &'a GracefulShutdown,
169}
170
171impl Drop for RequestGuard<'_> {
172 fn drop(&mut self) {
173 self.shutdown.finish_request();
174 }
175}
176
/// Point-in-time snapshot of a shutdown controller's state.
///
/// All fields are plain values, so the snapshot is `Copy` and can be compared
/// directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShutdownStatus {
    /// Whether shutdown had been initiated when the snapshot was taken.
    pub shutting_down: bool,
    /// Number of requests in flight when the snapshot was taken.
    pub in_flight: usize,
    /// Whether the wait for completion had already returned.
    pub completed: bool,
}
187
188#[cfg(unix)]
190pub async fn wait_for_shutdown_signal() {
191 use tokio::signal::unix::{signal, SignalKind};
192
193 let mut sigterm = signal(SignalKind::terminate()).expect("Failed to install SIGTERM handler");
194 let mut sigint = signal(SignalKind::interrupt()).expect("Failed to install SIGINT handler");
195
196 tokio::select! {
197 _ = sigterm.recv() => {
198 tracing::info!("Received SIGTERM, initiating graceful shutdown");
199 }
200 _ = sigint.recv() => {
201 tracing::info!("Received SIGINT, initiating graceful shutdown");
202 }
203 }
204}
205
/// Resolves when the process receives Ctrl+C, logging the event.
///
/// # Panics
///
/// Panics if listening for Ctrl+C fails.
#[cfg(windows)]
pub async fn wait_for_shutdown_signal() {
    tokio::signal::ctrl_c()
        .await
        .expect("Failed to listen for Ctrl+C");
    tracing::info!("Received Ctrl+C, initiating graceful shutdown");
}
214
215pub async fn run_with_graceful_shutdown<F, Fut>(
217 shutdown: Arc<GracefulShutdown>,
218 main_task: F,
219) -> ShutdownResult
220where
221 F: FnOnce() -> Fut,
222 Fut: std::future::Future<Output = ()>,
223{
224 let shutdown_clone = shutdown.clone();
226 tokio::spawn(async move {
227 wait_for_shutdown_signal().await;
228 shutdown_clone.initiate();
229 });
230
231 main_task().await;
233
234 shutdown.wait_for_completion().await
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_clean_shutdown() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(5));

        let guard1 = shutdown.start_request().unwrap();
        let guard2 = shutdown.start_request().unwrap();

        assert_eq!(shutdown.in_flight_count(), 2);

        shutdown.initiate();

        // New requests are refused once shutdown has begun.
        assert!(shutdown.start_request().is_none());

        drop(guard1);
        assert_eq!(shutdown.in_flight_count(), 1);
        drop(guard2);
        assert_eq!(shutdown.in_flight_count(), 0);

        let result = shutdown.wait_for_completion().await;
        assert_eq!(result, ShutdownResult::Clean);
    }

    #[tokio::test]
    async fn test_waits_for_in_flight_request() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(5));
        let guard = shutdown.start_request().unwrap();

        // The guard is dropped only after the waiter has been polled once,
        // so completion must come through the notification path.
        let (result, ()) = tokio::join!(shutdown.wait_for_completion(), async move {
            tokio::task::yield_now().await;
            drop(guard);
        });

        assert_eq!(result, ShutdownResult::Clean);
        assert_eq!(shutdown.in_flight_count(), 0);
        assert!(shutdown.is_completed());
    }

    #[tokio::test]
    async fn test_shutdown_timeout() {
        let shutdown = GracefulShutdown::new(Duration::from_millis(100));

        let _guard = shutdown.start_request().unwrap();

        shutdown.initiate();

        let result = shutdown.wait_for_completion().await;
        assert!(matches!(result, ShutdownResult::Timeout { remaining_requests: 1 }));
    }

    #[tokio::test]
    async fn test_empty_shutdown() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(5));

        let result = shutdown.wait_for_completion().await;
        assert_eq!(result, ShutdownResult::Clean);
    }

    #[tokio::test]
    async fn test_status_after_completion() {
        let shutdown = GracefulShutdown::new(Duration::from_secs(5));

        assert!(shutdown.wait_for_completion().await.is_clean());

        let status = shutdown.status();
        assert!(status.shutting_down);
        assert_eq!(status.in_flight, 0);
        assert!(status.completed);
    }

    #[test]
    fn test_status() {
        let shutdown = GracefulShutdown::default_timeout();

        let status = shutdown.status();
        assert!(!status.shutting_down);
        assert_eq!(status.in_flight, 0);
        assert!(!status.completed);
    }
}