Skip to main content

punch_kernel/
shutdown.rs

1//! Graceful shutdown coordinator for the Punch kernel.
2//!
3//! Tracks in-flight requests, broadcasts a shutdown signal, waits for
4//! requests to drain (with a configurable timeout), and fires registered
5//! shutdown hooks in order.
6
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use std::time::Duration;
12
13use tokio::sync::{Mutex, watch};
14use tracing::{info, warn};
15
/// A callback that runs during shutdown.
///
/// The callback is a factory: each call produces a fresh boxed future, which
/// the coordinator then awaits. It is `Fn` (not `FnOnce`) so it can be invoked
/// through a shared reference while stored in the coordinator's hook list.
///
/// Bounds:
/// - the closure itself must be `Send + Sync` so the hook list can be shared
///   across tasks;
/// - the returned future must be `Send` and resolves to `()`.
pub type ShutdownHook = Box<dyn Fn() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
18
/// Coordinates graceful shutdown across the system.
///
/// # Shutdown phases
///
/// 1. **Stop accepting** — the shutdown signal is broadcast; from then on
///    [`track_request`](ShutdownCoordinator::track_request) returns `false`
///    so callers can reject new requests (e.g. with a 503).
/// 2. **Drain in-flight** — wait for the in-flight counter to reach zero, or
///    for `drain_timeout` to elapse, whichever comes first.
/// 3. **Cage all gorillas** — stop autonomous agents.
/// 4. **Flush logs** — give tracing subscribers time to flush.
/// 5. **Close DB** — release database handles.
///
/// Phases 1 and 2 are implemented by the coordinator itself. Phases 3–5 are
/// not built in: they are whatever hooks the application registers via
/// [`register_hook`](ShutdownCoordinator::register_hook). Hooks fire in
/// registration order after the drain phase.
pub struct ShutdownCoordinator {
    /// Sender half — set to `true` to broadcast shutdown.
    shutdown_signal: watch::Sender<bool>,
    /// Receiver half — kept for the coordinator's whole lifetime; subscribers
    /// get clones of it, and `is_shutting_down` reads its current value.
    shutdown_receiver: watch::Receiver<bool>,
    /// Number of requests currently in flight.
    in_flight: AtomicUsize,
    /// Maximum time to wait for in-flight requests to finish.
    drain_timeout: Duration,
    /// Ordered list of hooks to run during shutdown (registration order).
    hooks: Mutex<Vec<ShutdownHook>>,
    /// Whether shutdown has already been initiated (idempotency guard).
    initiated: AtomicBool,
}
45
46impl ShutdownCoordinator {
47    /// Create a new shutdown coordinator with the given drain timeout.
48    pub fn new(drain_timeout: Duration) -> Arc<Self> {
49        let (tx, rx) = watch::channel(false);
50        Arc::new(Self {
51            shutdown_signal: tx,
52            shutdown_receiver: rx,
53            in_flight: AtomicUsize::new(0),
54            drain_timeout,
55            hooks: Mutex::new(Vec::new()),
56            initiated: AtomicBool::new(false),
57        })
58    }
59
60    /// Create a coordinator with the default 30-second drain timeout.
61    pub fn with_default_timeout() -> Arc<Self> {
62        Self::new(Duration::from_secs(30))
63    }
64
65    /// Subscribe to the shutdown signal.
66    ///
67    /// The returned receiver yields `true` once shutdown is initiated.
68    pub fn subscribe(&self) -> watch::Receiver<bool> {
69        self.shutdown_receiver.clone()
70    }
71
72    /// Returns `true` if shutdown has been initiated.
73    pub fn is_shutting_down(&self) -> bool {
74        *self.shutdown_receiver.borrow()
75    }
76
77    /// Increment the in-flight request counter.
78    ///
79    /// Returns `false` if shutdown is in progress (caller should reject
80    /// the request with 503).
81    pub fn track_request(&self) -> bool {
82        if self.is_shutting_down() {
83            return false;
84        }
85        self.in_flight.fetch_add(1, Ordering::SeqCst);
86        true
87    }
88
89    /// Decrement the in-flight request counter.
90    pub fn finish_request(&self) {
91        let prev = self.in_flight.fetch_sub(1, Ordering::SeqCst);
92        // Guard against underflow (shouldn't happen, but be safe).
93        if prev == 0 {
94            self.in_flight.store(0, Ordering::SeqCst);
95        }
96    }
97
98    /// Current number of in-flight requests.
99    pub fn in_flight_count(&self) -> usize {
100        self.in_flight.load(Ordering::SeqCst)
101    }
102
103    /// Register a hook that will be called during shutdown.
104    ///
105    /// Hooks fire in registration order after the drain phase completes
106    /// or times out.
107    pub async fn register_hook<F, Fut>(&self, hook: F)
108    where
109        F: Fn() -> Fut + Send + Sync + 'static,
110        Fut: Future<Output = ()> + Send + 'static,
111    {
112        let boxed: ShutdownHook = Box::new(move || Box::pin(hook()));
113        self.hooks.lock().await.push(boxed);
114    }
115
116    /// Initiate graceful shutdown.
117    ///
118    /// This method is idempotent — calling it multiple times has no
119    /// additional effect.
120    ///
121    /// 1. Broadcasts the shutdown signal (new requests will be rejected).
122    /// 2. Waits for in-flight requests to drain (up to `drain_timeout`).
123    /// 3. Runs all registered hooks in order.
124    pub async fn initiate_shutdown(&self) {
125        // Idempotency: only run once.
126        if self
127            .initiated
128            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
129            .is_err()
130        {
131            info!("shutdown already in progress, ignoring duplicate signal");
132            return;
133        }
134
135        info!("initiating graceful shutdown");
136
137        // Phase 1: broadcast signal.
138        let _ = self.shutdown_signal.send(true);
139
140        // Phase 2: drain in-flight requests.
141        let drain_start = tokio::time::Instant::now();
142        let deadline = drain_start + self.drain_timeout;
143
144        loop {
145            let count = self.in_flight.load(Ordering::SeqCst);
146            if count == 0 {
147                info!("all in-flight requests drained");
148                break;
149            }
150
151            if tokio::time::Instant::now() >= deadline {
152                warn!(
153                    remaining = count,
154                    "drain timeout reached, force-terminating remaining requests"
155                );
156                break;
157            }
158
159            tokio::time::sleep(Duration::from_millis(50)).await;
160        }
161
162        // Phase 3-5: run hooks (cage gorillas, flush logs, close DB, etc.).
163        let hooks = self.hooks.lock().await;
164        for (i, hook) in hooks.iter().enumerate() {
165            info!(hook_index = i, "running shutdown hook");
166            hook().await;
167        }
168
169        info!("shutdown complete");
170    }
171}
172
173// ---------------------------------------------------------------------------
174// Tests
175// ---------------------------------------------------------------------------
176
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU32;

    /// Subscribers — whether created before or after shutdown — observe the
    /// signal, and `is_shutting_down` flips accordingly.
    #[tokio::test]
    async fn shutdown_signal_propagates() {
        let coord = ShutdownCoordinator::with_default_timeout();
        let mut rx = coord.subscribe();

        assert!(!coord.is_shutting_down());
        assert!(!*rx.borrow());

        coord.initiate_shutdown().await;

        // Receiver should see the signal.
        rx.changed().await.ok();
        assert!(*rx.borrow());
        assert!(coord.is_shutting_down());

        // A subscriber created after the fact sees the current value too.
        assert!(*coord.subscribe().borrow());
    }

    /// `track_request` / `finish_request` move the counter up and down by one.
    #[tokio::test]
    async fn in_flight_counter_tracks() {
        let coord = ShutdownCoordinator::with_default_timeout();

        assert_eq!(coord.in_flight_count(), 0);

        assert!(coord.track_request());
        assert_eq!(coord.in_flight_count(), 1);

        assert!(coord.track_request());
        assert_eq!(coord.in_flight_count(), 2);

        coord.finish_request();
        assert_eq!(coord.in_flight_count(), 1);

        coord.finish_request();
        assert_eq!(coord.in_flight_count(), 0);
    }

    /// An unmatched `finish_request` must not wrap the counter around, and
    /// the counter must keep working afterwards.
    #[tokio::test]
    async fn unmatched_finish_request_stays_at_zero() {
        let coord = ShutdownCoordinator::with_default_timeout();

        coord.finish_request();
        assert_eq!(coord.in_flight_count(), 0);

        assert!(coord.track_request());
        assert_eq!(coord.in_flight_count(), 1);
    }

    /// Shutdown blocks while a request is in flight and completes promptly
    /// once that request finishes.
    #[tokio::test]
    async fn drain_waits_for_in_flight() {
        let coord = ShutdownCoordinator::new(Duration::from_secs(5));
        let coord_clone = Arc::clone(&coord);

        // Start a request.
        assert!(coord.track_request());

        // Start shutdown in the background.
        let handle = tokio::spawn(async move {
            coord_clone.initiate_shutdown().await;
        });

        // Give shutdown a moment to start draining.
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Finish the request — shutdown should then complete.
        coord.finish_request();

        // Shutdown should complete within a reasonable time.
        tokio::time::timeout(Duration::from_secs(2), handle)
            .await
            .expect("shutdown should complete")
            .expect("shutdown task should not panic");
    }

    /// With a request that never finishes, shutdown waits for the drain
    /// timeout and then proceeds anyway.
    #[tokio::test]
    async fn drain_timeout_forces_shutdown() {
        let drain_timeout = Duration::from_millis(100);
        let coord = ShutdownCoordinator::new(drain_timeout);

        // Start a request but never finish it.
        assert!(coord.track_request());

        let start = tokio::time::Instant::now();
        coord.initiate_shutdown().await;
        let elapsed = start.elapsed();

        // Should have actually waited for the drain window...
        assert!(elapsed >= drain_timeout);
        // ...but timed out, not waited forever.
        assert!(elapsed < Duration::from_secs(2));
        // Request is still in flight (force-terminated).
        assert_eq!(coord.in_flight_count(), 1);
    }

    /// Hooks run in the order they were registered.
    #[tokio::test]
    async fn hooks_fire_in_order() {
        let coord = ShutdownCoordinator::with_default_timeout();
        let order = Arc::new(Mutex::new(Vec::<u32>::new()));

        for id in 1..=3u32 {
            let sink = Arc::clone(&order);
            coord
                .register_hook(move || {
                    let sink = Arc::clone(&sink);
                    async move {
                        sink.lock().await.push(id);
                    }
                })
                .await;
        }

        coord.initiate_shutdown().await;

        let fired = order.lock().await;
        assert_eq!(*fired, vec![1, 2, 3]);
    }

    /// Repeated shutdown calls run the hooks only once.
    #[tokio::test]
    async fn multiple_shutdown_signals_are_idempotent() {
        let coord = ShutdownCoordinator::with_default_timeout();
        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);

        coord
            .register_hook(move || {
                let c = Arc::clone(&c);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                }
            })
            .await;

        coord.initiate_shutdown().await;
        coord.initiate_shutdown().await;
        coord.initiate_shutdown().await;

        // Hook should have fired exactly once.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// After shutdown, new requests are rejected and leave no trace in the
    /// in-flight counter.
    #[tokio::test]
    async fn track_request_rejected_during_shutdown() {
        let coord = ShutdownCoordinator::with_default_timeout();

        coord.initiate_shutdown().await;

        // New requests should be rejected.
        assert!(!coord.track_request());
        // A rejected request must not be counted as in flight.
        assert_eq!(coord.in_flight_count(), 0);
    }
}
329
330        // New requests should be rejected.
331        assert!(!coord.track_request());
332    }
333}