forge_runtime/cluster/
shutdown.rs1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeStatus;
6use tokio::sync::broadcast;
7
8use super::leader::LeaderElection;
9use super::registry::NodeRegistry;
10
/// Tunables that control how [`GracefulShutdown`] drains in-flight work.
#[derive(Clone, Debug)]
pub struct ShutdownConfig {
    /// How long shutdown waits for in-flight work to finish before giving up.
    pub drain_timeout: Duration,
    /// How often the in-flight counter is re-checked while draining.
    pub poll_interval: Duration,
}

impl Default for ShutdownConfig {
    /// A 30-second drain timeout, re-checked every 100 milliseconds.
    fn default() -> Self {
        let drain_timeout = Duration::from_millis(30_000);
        let poll_interval = Duration::from_millis(100);
        ShutdownConfig {
            drain_timeout,
            poll_interval,
        }
    }
}
28
29pub struct GracefulShutdown {
31 registry: Arc<NodeRegistry>,
32 leader_election: Option<Arc<LeaderElection>>,
33 config: ShutdownConfig,
34 shutdown_requested: Arc<AtomicBool>,
35 in_flight_count: Arc<AtomicU32>,
36 shutdown_tx: broadcast::Sender<()>,
37}
38
39impl GracefulShutdown {
40 pub fn new(
42 registry: Arc<NodeRegistry>,
43 leader_election: Option<Arc<LeaderElection>>,
44 config: ShutdownConfig,
45 ) -> Self {
46 let (shutdown_tx, _) = broadcast::channel(1);
47 Self {
48 registry,
49 leader_election,
50 config,
51 shutdown_requested: Arc::new(AtomicBool::new(false)),
52 in_flight_count: Arc::new(AtomicU32::new(0)),
53 shutdown_tx,
54 }
55 }
56
57 pub fn is_shutdown_requested(&self) -> bool {
59 self.shutdown_requested.load(Ordering::SeqCst)
60 }
61
62 pub fn in_flight_count(&self) -> u32 {
64 self.in_flight_count.load(Ordering::SeqCst)
65 }
66
67 pub fn increment_in_flight(&self) {
69 self.in_flight_count.fetch_add(1, Ordering::SeqCst);
70 }
71
72 pub fn decrement_in_flight(&self) {
74 self.in_flight_count.fetch_sub(1, Ordering::SeqCst);
75 }
76
77 pub fn subscribe(&self) -> broadcast::Receiver<()> {
79 self.shutdown_tx.subscribe()
80 }
81
82 pub fn should_accept_work(&self) -> bool {
84 !self.shutdown_requested.load(Ordering::SeqCst)
85 }
86
87 pub async fn shutdown(&self) -> forge_core::Result<()> {
89 self.shutdown_requested.store(true, Ordering::SeqCst);
91
92 let _ = self.shutdown_tx.send(());
94
95 tracing::info!("Starting graceful shutdown");
96
97 if let Err(e) = self.registry.set_status(NodeStatus::Draining).await {
99 tracing::warn!("Failed to set draining status: {}", e);
100 }
101
102 let drain_result = self.wait_for_drain().await;
104 match drain_result {
105 DrainResult::Completed => {
106 tracing::info!("All in-flight requests completed");
107 }
108 DrainResult::Timeout(remaining) => {
109 tracing::warn!(
110 "Drain timeout reached with {} requests still in-flight",
111 remaining
112 );
113 }
114 }
115
116 if let Some(ref election) = self.leader_election {
118 if let Err(e) = election.release_leadership().await {
119 tracing::warn!("Failed to release leadership: {}", e);
120 } else {
121 tracing::debug!("Leadership released");
122 }
123 }
124
125 if let Err(e) = self.registry.deregister().await {
127 tracing::warn!("Failed to deregister from cluster: {}", e);
128 }
129
130 tracing::info!("Graceful shutdown complete");
131 Ok(())
132 }
133
134 async fn wait_for_drain(&self) -> DrainResult {
136 let deadline = tokio::time::Instant::now() + self.config.drain_timeout;
137
138 loop {
139 let count = self.in_flight_count.load(Ordering::SeqCst);
140
141 if count == 0 {
142 return DrainResult::Completed;
143 }
144
145 if tokio::time::Instant::now() >= deadline {
146 return DrainResult::Timeout(count);
147 }
148
149 tokio::time::sleep(self.config.poll_interval).await;
150 }
151 }
152}
153
/// Outcome of waiting for in-flight work to finish during shutdown.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DrainResult {
    /// The in-flight count reached zero before the deadline.
    Completed,
    /// The deadline passed; carries the in-flight count observed at that point.
    Timeout(u32),
}
162
163pub struct InFlightGuard {
165 shutdown: Arc<GracefulShutdown>,
166}
167
168impl InFlightGuard {
169 pub fn try_new(shutdown: Arc<GracefulShutdown>) -> Option<Self> {
172 if shutdown.should_accept_work() {
173 shutdown.increment_in_flight();
174 Some(Self { shutdown })
175 } else {
176 None
177 }
178 }
179}
180
181impl Drop for InFlightGuard {
182 fn drop(&mut self) {
183 self.shutdown.decrement_in_flight();
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must keep its documented values.
    #[test]
    fn test_shutdown_config_default() {
        let ShutdownConfig {
            drain_timeout,
            poll_interval,
        } = ShutdownConfig::default();

        assert_eq!(drain_timeout, Duration::from_secs(30));
        assert_eq!(poll_interval, Duration::from_millis(100));
    }
}