1use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use std::time::Duration;
12
13use tokio::sync::{Mutex, watch};
14use tracing::{info, warn};
15
/// A registered shutdown hook: a shareable (`Send + Sync`) closure that can be
/// called repeatedly, each call returning a new pinned, boxed `Send` future.
///
/// Values of this type are built by `ShutdownCoordinator::register_hook`.
pub type ShutdownHook = Box<dyn Fn() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
18
/// Coordinates graceful shutdown.
///
/// It broadcasts a shutdown flag, counts in-flight requests so they can drain,
/// and runs registered hooks once draining finishes or times out.
pub struct ShutdownCoordinator {
    /// Sending half of the shutdown flag; it is sent `true` when shutdown starts.
    shutdown_signal: watch::Sender<bool>,
    /// Receiver held by the coordinator itself, so the channel always has at
    /// least one receiver. It is read by `is_shutting_down` and cloned by `subscribe`.
    shutdown_receiver: watch::Receiver<bool>,
    /// Number of requests currently tracked as in flight.
    in_flight: AtomicUsize,
    /// Upper bound on how long `initiate_shutdown` waits for `in_flight` to reach zero.
    drain_timeout: Duration,
    /// Shutdown hooks, in registration order.
    hooks: Mutex<Vec<ShutdownHook>>,
    /// Set by the first `initiate_shutdown` call; later calls see it and return early.
    initiated: AtomicBool,
}
45
46impl ShutdownCoordinator {
47 pub fn new(drain_timeout: Duration) -> Arc<Self> {
49 let (tx, rx) = watch::channel(false);
50 Arc::new(Self {
51 shutdown_signal: tx,
52 shutdown_receiver: rx,
53 in_flight: AtomicUsize::new(0),
54 drain_timeout,
55 hooks: Mutex::new(Vec::new()),
56 initiated: AtomicBool::new(false),
57 })
58 }
59
60 pub fn with_default_timeout() -> Arc<Self> {
62 Self::new(Duration::from_secs(30))
63 }
64
65 pub fn subscribe(&self) -> watch::Receiver<bool> {
69 self.shutdown_receiver.clone()
70 }
71
72 pub fn is_shutting_down(&self) -> bool {
74 *self.shutdown_receiver.borrow()
75 }
76
77 pub fn track_request(&self) -> bool {
82 if self.is_shutting_down() {
83 return false;
84 }
85 self.in_flight.fetch_add(1, Ordering::SeqCst);
86 true
87 }
88
89 pub fn finish_request(&self) {
91 let prev = self.in_flight.fetch_sub(1, Ordering::SeqCst);
92 if prev == 0 {
94 self.in_flight.store(0, Ordering::SeqCst);
95 }
96 }
97
98 pub fn in_flight_count(&self) -> usize {
100 self.in_flight.load(Ordering::SeqCst)
101 }
102
103 pub async fn register_hook<F, Fut>(&self, hook: F)
108 where
109 F: Fn() -> Fut + Send + Sync + 'static,
110 Fut: Future<Output = ()> + Send + 'static,
111 {
112 let boxed: ShutdownHook = Box::new(move || Box::pin(hook()));
113 self.hooks.lock().await.push(boxed);
114 }
115
116 pub async fn initiate_shutdown(&self) {
125 if self
127 .initiated
128 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
129 .is_err()
130 {
131 info!("shutdown already in progress, ignoring duplicate signal");
132 return;
133 }
134
135 info!("initiating graceful shutdown");
136
137 let _ = self.shutdown_signal.send(true);
139
140 let drain_start = tokio::time::Instant::now();
142 let deadline = drain_start + self.drain_timeout;
143
144 loop {
145 let count = self.in_flight.load(Ordering::SeqCst);
146 if count == 0 {
147 info!("all in-flight requests drained");
148 break;
149 }
150
151 if tokio::time::Instant::now() >= deadline {
152 warn!(
153 remaining = count,
154 "drain timeout reached, force-terminating remaining requests"
155 );
156 break;
157 }
158
159 tokio::time::sleep(Duration::from_millis(50)).await;
160 }
161
162 let hooks = self.hooks.lock().await;
164 for (i, hook) in hooks.iter().enumerate() {
165 info!(hook_index = i, "running shutdown hook");
166 hook().await;
167 }
168
169 info!("shutdown complete");
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU32;

    #[tokio::test]
    async fn shutdown_signal_propagates() {
        let coord = ShutdownCoordinator::with_default_timeout();
        let mut rx = coord.subscribe();

        assert!(!coord.is_shutting_down());

        coord.initiate_shutdown().await;

        // The flag was already updated, so this does not block.
        rx.changed().await.ok();
        assert!(*rx.borrow());
        assert!(coord.is_shutting_down());
    }

    #[tokio::test]
    async fn in_flight_counter_tracks() {
        let coord = ShutdownCoordinator::with_default_timeout();

        assert_eq!(coord.in_flight_count(), 0);

        assert!(coord.track_request());
        assert_eq!(coord.in_flight_count(), 1);

        assert!(coord.track_request());
        assert_eq!(coord.in_flight_count(), 2);

        coord.finish_request();
        assert_eq!(coord.in_flight_count(), 1);

        coord.finish_request();
        assert_eq!(coord.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn unbalanced_finish_request_does_not_underflow() {
        let coord = ShutdownCoordinator::with_default_timeout();

        // Finishing with nothing in flight must leave the counter at zero...
        coord.finish_request();
        assert_eq!(coord.in_flight_count(), 0);

        // ...and later tracking must still count correctly.
        assert!(coord.track_request());
        assert_eq!(coord.in_flight_count(), 1);
        coord.finish_request();
        assert_eq!(coord.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn drain_waits_for_in_flight() {
        let coord = ShutdownCoordinator::new(Duration::from_secs(5));
        let coord_clone = Arc::clone(&coord);
        let finished = Arc::new(AtomicBool::new(false));
        let finished_clone = Arc::clone(&finished);

        // One request is in flight before shutdown starts.
        assert!(coord.track_request());

        let handle = tokio::spawn(async move {
            coord_clone.initiate_shutdown().await;
            finished_clone.store(true, Ordering::SeqCst);
        });

        // Give the shutdown task time to start draining.
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Without this check the test would also pass if draining did not wait at all.
        assert!(
            !finished.load(Ordering::SeqCst),
            "shutdown completed while a request was still in flight"
        );

        coord.finish_request();

        tokio::time::timeout(Duration::from_secs(2), handle)
            .await
            .expect("shutdown should complete")
            .expect("shutdown task should not panic");
        assert!(finished.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn drain_timeout_forces_shutdown() {
        let coord = ShutdownCoordinator::new(Duration::from_millis(100));

        // This request is never finished, so only the timeout can end the drain.
        assert!(coord.track_request());

        let start = tokio::time::Instant::now();
        coord.initiate_shutdown().await;
        let elapsed = start.elapsed();

        assert!(elapsed < Duration::from_secs(2));
        // The timeout does not clear the counter; the request is still tracked.
        assert_eq!(coord.in_flight_count(), 1);
    }

    #[tokio::test]
    async fn hooks_fire_in_order() {
        let coord = ShutdownCoordinator::with_default_timeout();

        let order = Arc::new(Mutex::new(Vec::<u32>::new()));
        let o1 = Arc::clone(&order);
        let o2 = Arc::clone(&order);
        let o3 = Arc::clone(&order);

        coord
            .register_hook(move || {
                let o = Arc::clone(&o1);
                async move {
                    o.lock().await.push(1);
                }
            })
            .await;

        coord
            .register_hook(move || {
                let o = Arc::clone(&o2);
                async move {
                    o.lock().await.push(2);
                }
            })
            .await;

        coord
            .register_hook(move || {
                let o = Arc::clone(&o3);
                async move {
                    o.lock().await.push(3);
                }
            })
            .await;

        coord.initiate_shutdown().await;

        let fired = order.lock().await;
        assert_eq!(*fired, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn multiple_shutdown_signals_are_idempotent() {
        let coord = ShutdownCoordinator::with_default_timeout();
        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);

        coord
            .register_hook(move || {
                let c = Arc::clone(&c);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                }
            })
            .await;

        coord.initiate_shutdown().await;
        coord.initiate_shutdown().await;
        coord.initiate_shutdown().await;

        // Only the first call runs the hooks.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn track_request_rejected_during_shutdown() {
        let coord = ShutdownCoordinator::with_default_timeout();

        coord.initiate_shutdown().await;

        assert!(!coord.track_request());
        // A rejected request must not leave anything counted as in flight.
        assert_eq!(coord.in_flight_count(), 0);
    }
}