Skip to main content

forge_runtime/cluster/
shutdown.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeStatus;
6use tokio::sync::watch;
7
8use super::registry::NodeRegistry;
9use crate::pg::LeaderElection;
10
/// Tunables for draining a node during graceful shutdown.
#[derive(Debug, Clone)]
pub struct ShutdownConfig {
    /// Upper bound on how long shutdown waits for in-flight requests to finish.
    pub drain_timeout: Duration,
    /// Delay between successive checks of the in-flight counter while draining.
    pub poll_interval: Duration,
}

impl Default for ShutdownConfig {
    /// Waits up to 30 seconds for in-flight work, re-checking every 100 milliseconds.
    fn default() -> Self {
        let drain_timeout = Duration::from_secs(30);
        let poll_interval = Duration::from_millis(100);

        Self {
            drain_timeout,
            poll_interval,
        }
    }
}
28
29/// Graceful shutdown coordinator.
30pub struct GracefulShutdown {
31    registry: Arc<NodeRegistry>,
32    leader_election: Option<Arc<LeaderElection>>,
33    config: ShutdownConfig,
34    shutdown_requested: Arc<AtomicBool>,
35    in_flight_count: Arc<AtomicU32>,
36    shutdown_tx: watch::Sender<bool>,
37}
38
39impl GracefulShutdown {
40    pub fn new(
41        registry: Arc<NodeRegistry>,
42        leader_election: Option<Arc<LeaderElection>>,
43        config: ShutdownConfig,
44    ) -> Self {
45        let (shutdown_tx, _) = watch::channel(false);
46        Self {
47            registry,
48            leader_election,
49            config,
50            shutdown_requested: Arc::new(AtomicBool::new(false)),
51            in_flight_count: Arc::new(AtomicU32::new(0)),
52            shutdown_tx,
53        }
54    }
55
56    pub fn is_shutdown_requested(&self) -> bool {
57        self.shutdown_requested.load(Ordering::SeqCst)
58    }
59
60    pub fn in_flight_count(&self) -> u32 {
61        self.in_flight_count.load(Ordering::SeqCst)
62    }
63
64    pub fn increment_in_flight(&self) {
65        self.in_flight_count.fetch_add(1, Ordering::SeqCst);
66    }
67
68    pub fn decrement_in_flight(&self) {
69        self.in_flight_count.fetch_sub(1, Ordering::SeqCst);
70    }
71
72    /// Subscribe to shutdown notifications.
73    ///
74    /// Late subscribers immediately see `true` if shutdown was already requested.
75    pub fn subscribe(&self) -> watch::Receiver<bool> {
76        self.shutdown_tx.subscribe()
77    }
78
79    pub fn should_accept_work(&self) -> bool {
80        !self.shutdown_requested.load(Ordering::SeqCst)
81    }
82
83    pub async fn shutdown(&self) -> forge_core::Result<()> {
84        self.shutdown_requested.store(true, Ordering::SeqCst);
85        self.shutdown_tx.send_replace(true);
86
87        tracing::info!("Starting graceful shutdown");
88
89        if let Err(e) = self.registry.set_status(NodeStatus::Draining).await {
90            tracing::warn!("Failed to set draining status: {}", e);
91        }
92
93        let drain_result = self.wait_for_drain().await;
94        match drain_result {
95            DrainResult::Completed => {
96                tracing::info!("All in-flight requests completed");
97            }
98            DrainResult::Timeout(remaining) => {
99                tracing::warn!(
100                    "Drain timeout reached with {} requests still in-flight",
101                    remaining
102                );
103            }
104        }
105
106        if let Some(ref election) = self.leader_election {
107            if let Err(e) = election.release_leadership().await {
108                tracing::warn!("Failed to release leadership: {}", e);
109            } else {
110                tracing::debug!("Leadership released");
111            }
112        }
113
114        if let Err(e) = self.registry.deregister().await {
115            tracing::warn!("Failed to deregister from cluster: {}", e);
116        }
117
118        tracing::info!("Graceful shutdown complete");
119        Ok(())
120    }
121
122    async fn wait_for_drain(&self) -> DrainResult {
123        let deadline = tokio::time::Instant::now() + self.config.drain_timeout;
124
125        loop {
126            let count = self.in_flight_count.load(Ordering::SeqCst);
127
128            if count == 0 {
129                return DrainResult::Completed;
130            }
131
132            if tokio::time::Instant::now() >= deadline {
133                return DrainResult::Timeout(count);
134            }
135
136            tokio::time::sleep(self.config.poll_interval).await;
137        }
138    }
139}
140
/// Outcome of waiting for in-flight requests to finish.
#[derive(Debug)]
enum DrainResult {
    /// The in-flight counter reached zero.
    Completed,
    /// The wait gave up; carries the number of requests still in flight.
    Timeout(u32),
}
147
148/// RAII guard for tracking in-flight requests.
149pub struct InFlightGuard {
150    shutdown: Arc<GracefulShutdown>,
151}
152
153impl InFlightGuard {
154    /// Returns `None` if shutdown is in progress.
155    pub fn try_new(shutdown: Arc<GracefulShutdown>) -> Option<Self> {
156        if shutdown.should_accept_work() {
157            shutdown.increment_in_flight();
158            Some(Self { shutdown })
159        } else {
160            None
161        }
162    }
163}
164
165impl Drop for InFlightGuard {
166    fn drop(&mut self) {
167        self.shutdown.decrement_in_flight();
168    }
169}
170
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use forge_core::cluster::{NodeInfo, NodeRole};
    use sqlx::postgres::PgPoolOptions;
    use std::net::{IpAddr, Ipv4Addr};

    /// Builds a coordinator with default config and no leader election.
    fn make_shutdown() -> Arc<GracefulShutdown> {
        // `connect_lazy` never opens the socket, so we can build a NodeRegistry
        // without a live Postgres. None of the methods exercised below touch
        // the pool — they only read/write atomics and the watch channel.
        let pool = PgPoolOptions::new()
            .connect_lazy("postgres://localhost:1/never")
            .unwrap();
        let node = NodeInfo::new_local(
            "test-host".to_string(),
            IpAddr::V4(Ipv4Addr::LOCALHOST),
            9081,
            9082,
            vec![NodeRole::Gateway],
            vec!["default".to_string()],
            "test".to_string(),
        );
        let registry = Arc::new(NodeRegistry::new(pool, node));
        Arc::new(GracefulShutdown::new(
            registry,
            None,
            ShutdownConfig::default(),
        ))
    }

    #[test]
    fn test_shutdown_config_default() {
        let config = ShutdownConfig::default();
        assert_eq!(config.drain_timeout, Duration::from_secs(30));
        assert_eq!(config.poll_interval, Duration::from_millis(100));
    }

    #[tokio::test]
    async fn fresh_shutdown_accepts_work_and_has_zero_in_flight() {
        let sd = make_shutdown();
        assert!(!sd.is_shutdown_requested());
        assert!(sd.should_accept_work());
        assert_eq!(sd.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn in_flight_counter_increments_and_decrements() {
        let sd = make_shutdown();
        sd.increment_in_flight();
        sd.increment_in_flight();
        assert_eq!(sd.in_flight_count(), 2);
        sd.decrement_in_flight();
        assert_eq!(sd.in_flight_count(), 1);
        sd.decrement_in_flight();
        assert_eq!(sd.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn in_flight_guard_tracks_counter_via_raii() {
        let sd = make_shutdown();
        {
            let _g1 = InFlightGuard::try_new(sd.clone()).expect("should admit work");
            let _g2 = InFlightGuard::try_new(sd.clone()).expect("should admit work");
            assert_eq!(sd.in_flight_count(), 2);
        }
        // Both guards dropped — counter back to zero.
        assert_eq!(sd.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn in_flight_guard_refuses_work_after_shutdown_flag_set() {
        let sd = make_shutdown();
        // Flip the flag directly — emulates state after `shutdown()` ran past
        // step 1 without needing the registry/DB calls.
        sd.shutdown_requested.store(true, Ordering::SeqCst);
        assert!(!sd.should_accept_work());
        assert!(InFlightGuard::try_new(sd.clone()).is_none());
        // The refused attempt must leave the counter at zero.
        assert_eq!(sd.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn subscribe_returns_independent_receivers() {
        let sd = make_shutdown();
        let mut r1 = sd.subscribe();
        let mut r2 = sd.subscribe();
        // Both should see the state change.
        sd.shutdown_tx.send_replace(true);
        assert!(r1.changed().await.is_ok());
        assert!(*r1.borrow());
        assert!(r2.changed().await.is_ok());
        assert!(*r2.borrow());
    }

    #[test]
    fn shutdown_config_clone_preserves_custom_values() {
        let original = ShutdownConfig {
            drain_timeout: Duration::from_millis(250),
            poll_interval: Duration::from_millis(5),
        };
        let cloned = original.clone();
        assert_eq!(cloned.drain_timeout, Duration::from_millis(250));
        assert_eq!(cloned.poll_interval, Duration::from_millis(5));
    }

    #[tokio::test]
    async fn late_subscribers_see_shutdown_state() {
        // A watch channel always holds its latest value, so a receiver created
        // after the send still reads `true` through `borrow()`.
        let sd = make_shutdown();
        sd.shutdown_tx.send_replace(true);

        let late = sd.subscribe();
        assert!(
            *late.borrow(),
            "late subscriber must see shutdown=true from watch channel"
        );
    }

    #[tokio::test]
    async fn guard_admitted_before_shutdown_still_decrements_after_flag_set() {
        // Models a request that began serving before shutdown was requested;
        // when it finishes, the counter must come back to zero so the drain
        // loop can exit.
        let sd = make_shutdown();
        let guard = InFlightGuard::try_new(sd.clone()).expect("admit");
        assert_eq!(sd.in_flight_count(), 1);

        sd.shutdown_requested.store(true, Ordering::SeqCst);
        assert!(!sd.should_accept_work(), "no new work after flag set");

        drop(guard);
        assert_eq!(
            sd.in_flight_count(),
            0,
            "RAII drop must decrement even mid-shutdown"
        );
    }

    #[tokio::test]
    async fn concurrent_increments_and_decrements_keep_counter_consistent() {
        // Hammer the counter from several OS threads; the final balance must
        // be zero. `spawn_blocking` is used instead of `tokio::spawn` because
        // the default test runtime is single-threaded: await-free async tasks
        // would simply run one after another and never contend.
        let sd = make_shutdown();
        let mut handles = Vec::new();
        for _ in 0..16 {
            let s = sd.clone();
            handles.push(tokio::task::spawn_blocking(move || {
                for _ in 0..1_000 {
                    s.increment_in_flight();
                    s.decrement_in_flight();
                }
            }));
        }
        for h in handles {
            h.await.expect("task did not panic");
        }
        assert_eq!(sd.in_flight_count(), 0);
    }
}