forge_runtime/cluster/
shutdown.rs1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeStatus;
6use tokio::sync::watch;
7
8use super::registry::NodeRegistry;
9use crate::pg::LeaderElection;
10
/// Tunables controlling how [`GracefulShutdown`] drains in-flight work.
#[derive(Debug, Clone)]
pub struct ShutdownConfig {
    /// How long shutdown waits for in-flight work to reach zero before giving up.
    pub drain_timeout: Duration,
    /// How often the in-flight counter is re-checked while draining.
    pub poll_interval: Duration,
}

/// Default drain budget: 30 seconds.
const DEFAULT_DRAIN_TIMEOUT: Duration = Duration::from_secs(30);
/// Default polling cadence while draining: 100 milliseconds.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(100);

impl Default for ShutdownConfig {
    /// Returns a 30 s drain timeout polled every 100 ms.
    fn default() -> Self {
        ShutdownConfig {
            drain_timeout: DEFAULT_DRAIN_TIMEOUT,
            poll_interval: DEFAULT_POLL_INTERVAL,
        }
    }
}
28
29pub struct GracefulShutdown {
31 registry: Arc<NodeRegistry>,
32 leader_election: Option<Arc<LeaderElection>>,
33 config: ShutdownConfig,
34 shutdown_requested: Arc<AtomicBool>,
35 in_flight_count: Arc<AtomicU32>,
36 shutdown_tx: watch::Sender<bool>,
37}
38
39impl GracefulShutdown {
40 pub fn new(
41 registry: Arc<NodeRegistry>,
42 leader_election: Option<Arc<LeaderElection>>,
43 config: ShutdownConfig,
44 ) -> Self {
45 let (shutdown_tx, _) = watch::channel(false);
46 Self {
47 registry,
48 leader_election,
49 config,
50 shutdown_requested: Arc::new(AtomicBool::new(false)),
51 in_flight_count: Arc::new(AtomicU32::new(0)),
52 shutdown_tx,
53 }
54 }
55
56 pub fn is_shutdown_requested(&self) -> bool {
57 self.shutdown_requested.load(Ordering::SeqCst)
58 }
59
60 pub fn in_flight_count(&self) -> u32 {
61 self.in_flight_count.load(Ordering::SeqCst)
62 }
63
64 pub fn increment_in_flight(&self) {
65 self.in_flight_count.fetch_add(1, Ordering::SeqCst);
66 }
67
68 pub fn decrement_in_flight(&self) {
69 self.in_flight_count.fetch_sub(1, Ordering::SeqCst);
70 }
71
72 pub fn subscribe(&self) -> watch::Receiver<bool> {
76 self.shutdown_tx.subscribe()
77 }
78
79 pub fn should_accept_work(&self) -> bool {
80 !self.shutdown_requested.load(Ordering::SeqCst)
81 }
82
83 pub async fn shutdown(&self) -> forge_core::Result<()> {
84 self.shutdown_requested.store(true, Ordering::SeqCst);
85 self.shutdown_tx.send_replace(true);
86
87 tracing::info!("Starting graceful shutdown");
88
89 if let Err(e) = self.registry.set_status(NodeStatus::Draining).await {
90 tracing::warn!("Failed to set draining status: {}", e);
91 }
92
93 let drain_result = self.wait_for_drain().await;
94 match drain_result {
95 DrainResult::Completed => {
96 tracing::info!("All in-flight requests completed");
97 }
98 DrainResult::Timeout(remaining) => {
99 tracing::warn!(
100 "Drain timeout reached with {} requests still in-flight",
101 remaining
102 );
103 }
104 }
105
106 if let Some(ref election) = self.leader_election {
107 if let Err(e) = election.release_leadership().await {
108 tracing::warn!("Failed to release leadership: {}", e);
109 } else {
110 tracing::debug!("Leadership released");
111 }
112 }
113
114 if let Err(e) = self.registry.deregister().await {
115 tracing::warn!("Failed to deregister from cluster: {}", e);
116 }
117
118 tracing::info!("Graceful shutdown complete");
119 Ok(())
120 }
121
122 async fn wait_for_drain(&self) -> DrainResult {
123 let deadline = tokio::time::Instant::now() + self.config.drain_timeout;
124
125 loop {
126 let count = self.in_flight_count.load(Ordering::SeqCst);
127
128 if count == 0 {
129 return DrainResult::Completed;
130 }
131
132 if tokio::time::Instant::now() >= deadline {
133 return DrainResult::Timeout(count);
134 }
135
136 tokio::time::sleep(self.config.poll_interval).await;
137 }
138 }
139}
140
/// Outcome of waiting for in-flight work to finish during shutdown.
#[derive(Debug)]
enum DrainResult {
    /// The in-flight counter reached zero before the deadline.
    Completed,
    /// The deadline passed; carries the in-flight count observed at that moment.
    Timeout(u32),
}
147
148pub struct InFlightGuard {
150 shutdown: Arc<GracefulShutdown>,
151}
152
153impl InFlightGuard {
154 pub fn try_new(shutdown: Arc<GracefulShutdown>) -> Option<Self> {
156 if shutdown.should_accept_work() {
157 shutdown.increment_in_flight();
158 Some(Self { shutdown })
159 } else {
160 None
161 }
162 }
163}
164
165impl Drop for InFlightGuard {
166 fn drop(&mut self) {
167 self.shutdown.decrement_in_flight();
168 }
169}
170
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use forge_core::cluster::{NodeInfo, NodeRole};
    use sqlx::postgres::PgPoolOptions;
    use std::net::{IpAddr, Ipv4Addr};

    /// Builds a coordinator whose registry sits on a lazily-connecting pool aimed at an
    /// unreachable address (port 1). These tests only exercise in-memory state and never
    /// call `shutdown()`, so no connection is expected to be attempted.
    fn fixture() -> Arc<GracefulShutdown> {
        let pool = PgPoolOptions::new()
            .connect_lazy("postgres://localhost:1/never")
            .unwrap();
        let node = NodeInfo::new_local(
            "test-host".to_string(),
            IpAddr::V4(Ipv4Addr::LOCALHOST),
            9081,
            9082,
            vec![NodeRole::Gateway],
            vec!["default".to_string()],
            "test".to_string(),
        );
        let registry = Arc::new(NodeRegistry::new(pool, node));
        let coordinator = GracefulShutdown::new(registry, None, ShutdownConfig::default());
        Arc::new(coordinator)
    }

    /// Sets the shutdown flag directly, without running the async `shutdown()` sequence.
    fn request_shutdown(sd: &GracefulShutdown) {
        sd.shutdown_requested.store(true, Ordering::SeqCst);
    }

    /// Publishes `true` on the watch channel directly, without running `shutdown()`.
    fn broadcast_shutdown(sd: &GracefulShutdown) {
        sd.shutdown_tx.send_replace(true);
    }

    #[test]
    fn test_shutdown_config_default() {
        let ShutdownConfig {
            drain_timeout,
            poll_interval,
        } = ShutdownConfig::default();
        assert_eq!(drain_timeout, Duration::from_secs(30));
        assert_eq!(poll_interval, Duration::from_millis(100));
    }

    #[tokio::test]
    async fn fresh_shutdown_accepts_work_and_has_zero_in_flight() {
        let sd = fixture();
        assert!(!sd.is_shutdown_requested());
        assert!(sd.should_accept_work());
        assert_eq!(sd.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn in_flight_counter_increments_and_decrements() {
        let sd = fixture();
        (0..2).for_each(|_| sd.increment_in_flight());
        assert_eq!(sd.in_flight_count(), 2);
        // Each decrement should step the counter down: 2 -> 1 -> 0.
        for expected in (0..2u32).rev() {
            sd.decrement_in_flight();
            assert_eq!(sd.in_flight_count(), expected);
        }
    }

    #[tokio::test]
    async fn in_flight_guard_tracks_counter_via_raii() {
        let sd = fixture();
        let guards: Vec<InFlightGuard> = (0..2)
            .map(|_| InFlightGuard::try_new(Arc::clone(&sd)).expect("should admit work"))
            .collect();
        assert_eq!(sd.in_flight_count(), 2);
        drop(guards);
        assert_eq!(sd.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn in_flight_guard_refuses_work_after_shutdown_flag_set() {
        let sd = fixture();
        request_shutdown(&sd);
        assert!(!sd.should_accept_work());
        assert!(InFlightGuard::try_new(Arc::clone(&sd)).is_none());
        assert_eq!(sd.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn subscribe_returns_independent_receivers() {
        let sd = fixture();
        let mut receivers = [sd.subscribe(), sd.subscribe()];
        broadcast_shutdown(&sd);
        for rx in receivers.iter_mut() {
            assert!(rx.changed().await.is_ok());
            assert!(*rx.borrow());
        }
    }

    #[test]
    fn shutdown_config_clone_preserves_custom_values() {
        let original = ShutdownConfig {
            drain_timeout: Duration::from_millis(250),
            poll_interval: Duration::from_millis(5),
        };
        let ShutdownConfig {
            drain_timeout,
            poll_interval,
        } = original.clone();
        assert_eq!(drain_timeout, Duration::from_millis(250));
        assert_eq!(poll_interval, Duration::from_millis(5));
    }

    #[tokio::test]
    async fn late_subscribers_see_shutdown_state() {
        let sd = fixture();
        broadcast_shutdown(&sd);

        let late = sd.subscribe();
        let observed = *late.borrow();
        assert!(
            observed,
            "late subscriber must see shutdown=true from watch channel"
        );
    }

    #[tokio::test]
    async fn guard_admitted_before_shutdown_still_decrements_after_flag_set() {
        let sd = fixture();
        let guard = InFlightGuard::try_new(Arc::clone(&sd)).expect("admit");
        assert_eq!(sd.in_flight_count(), 1);

        request_shutdown(&sd);
        assert!(!sd.should_accept_work(), "no new work after flag set");

        drop(guard);
        assert_eq!(
            sd.in_flight_count(),
            0,
            "RAII drop must decrement even mid-shutdown"
        );
    }

    #[tokio::test]
    async fn concurrent_increments_and_decrements_keep_counter_consistent() {
        let sd = fixture();
        let handles: Vec<_> = (0..16)
            .map(|_| {
                let worker = Arc::clone(&sd);
                tokio::spawn(async move {
                    for _ in 0..50 {
                        worker.increment_in_flight();
                        worker.decrement_in_flight();
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.await.expect("task did not panic");
        }
        assert_eq!(sd.in_flight_count(), 0);
    }
}