forge_runtime/cluster/
shutdown.rs1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeStatus;
6use tokio::sync::broadcast;
7
8use super::leader::LeaderElection;
9use super::registry::NodeRegistry;
10
/// Tunables for the graceful-shutdown drain phase.
#[derive(Debug, Clone)]
pub struct ShutdownConfig {
    /// Maximum time to wait for in-flight requests to finish draining.
    pub drain_timeout: Duration,
    /// How often the drain loop re-checks the in-flight counter.
    pub poll_interval: Duration,
}
19
20impl Default for ShutdownConfig {
21 fn default() -> Self {
22 Self {
23 drain_timeout: Duration::from_secs(30),
24 poll_interval: Duration::from_millis(100),
25 }
26 }
27}
28
29pub struct GracefulShutdown {
31 registry: Arc<NodeRegistry>,
32 #[allow(dead_code)]
33 leader_election: Option<Arc<LeaderElection>>,
34 config: ShutdownConfig,
35 shutdown_requested: Arc<AtomicBool>,
36 in_flight_count: Arc<AtomicU32>,
37 shutdown_tx: broadcast::Sender<()>,
38}
39
40impl GracefulShutdown {
41 pub fn new(
43 registry: Arc<NodeRegistry>,
44 leader_election: Option<Arc<LeaderElection>>,
45 config: ShutdownConfig,
46 ) -> Self {
47 let (shutdown_tx, _) = broadcast::channel(1);
48 Self {
49 registry,
50 leader_election,
51 config,
52 shutdown_requested: Arc::new(AtomicBool::new(false)),
53 in_flight_count: Arc::new(AtomicU32::new(0)),
54 shutdown_tx,
55 }
56 }
57
58 pub fn is_shutdown_requested(&self) -> bool {
60 self.shutdown_requested.load(Ordering::SeqCst)
61 }
62
63 pub fn in_flight_count(&self) -> u32 {
65 self.in_flight_count.load(Ordering::SeqCst)
66 }
67
68 pub fn increment_in_flight(&self) {
70 self.in_flight_count.fetch_add(1, Ordering::SeqCst);
71 }
72
73 pub fn decrement_in_flight(&self) {
75 self.in_flight_count.fetch_sub(1, Ordering::SeqCst);
76 }
77
78 pub fn subscribe(&self) -> broadcast::Receiver<()> {
80 self.shutdown_tx.subscribe()
81 }
82
83 pub fn should_accept_work(&self) -> bool {
85 !self.shutdown_requested.load(Ordering::SeqCst)
86 }
87
88 pub async fn shutdown(&self) -> forge_core::Result<()> {
90 self.shutdown_requested.store(true, Ordering::SeqCst);
92
93 let _ = self.shutdown_tx.send(());
95
96 tracing::info!("Starting graceful shutdown");
97
98 if let Err(e) = self.registry.set_status(NodeStatus::Draining).await {
100 tracing::warn!("Failed to set draining status: {}", e);
101 }
102
103 let drain_result = self.wait_for_drain().await;
105 match drain_result {
106 DrainResult::Completed => {
107 tracing::info!("All in-flight requests completed");
108 }
109 DrainResult::Timeout(remaining) => {
110 tracing::warn!(
111 "Drain timeout reached with {} requests still in-flight",
112 remaining
113 );
114 }
115 }
116
117 if let Err(e) = self.registry.deregister().await {
121 tracing::warn!("Failed to deregister from cluster: {}", e);
122 }
123
124 tracing::info!("Graceful shutdown complete");
125 Ok(())
126 }
127
128 async fn wait_for_drain(&self) -> DrainResult {
130 let deadline = tokio::time::Instant::now() + self.config.drain_timeout;
131
132 loop {
133 let count = self.in_flight_count.load(Ordering::SeqCst);
134
135 if count == 0 {
136 return DrainResult::Completed;
137 }
138
139 if tokio::time::Instant::now() >= deadline {
140 return DrainResult::Timeout(count);
141 }
142
143 tokio::time::sleep(self.config.poll_interval).await;
144 }
145 }
146}
147
/// Outcome of waiting for in-flight requests to drain.
#[derive(Debug)]
enum DrainResult {
    /// Every tracked request finished before the deadline.
    Completed,
    /// The deadline passed; payload is the number of requests still
    /// in flight when the wait gave up.
    Timeout(u32),
}
156
157pub struct InFlightGuard {
159 shutdown: Arc<GracefulShutdown>,
160}
161
162impl InFlightGuard {
163 pub fn try_new(shutdown: Arc<GracefulShutdown>) -> Option<Self> {
166 if shutdown.should_accept_work() {
167 shutdown.increment_in_flight();
168 Some(Self { shutdown })
169 } else {
170 None
171 }
172 }
173}
174
175impl Drop for InFlightGuard {
176 fn drop(&mut self) {
177 self.shutdown.decrement_in_flight();
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config must stay at 30 s drain / 100 ms poll.
    #[test]
    fn test_shutdown_config_default() {
        let ShutdownConfig {
            drain_timeout,
            poll_interval,
        } = ShutdownConfig::default();
        assert_eq!(drain_timeout, Duration::from_secs(30));
        assert_eq!(poll_interval, Duration::from_millis(100));
    }
}