Skip to main content

forge_runtime/cluster/
shutdown.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeStatus;
6use tokio::sync::broadcast;
7
8use super::leader::LeaderElection;
9use super::registry::NodeRegistry;
10
/// Tunables for the graceful shutdown sequence.
#[derive(Clone, Debug)]
pub struct ShutdownConfig {
    /// Upper bound on how long shutdown waits for in-flight requests to finish.
    pub drain_timeout: Duration,
    /// Delay between successive checks of the in-flight counter while draining.
    pub poll_interval: Duration,
}

impl Default for ShutdownConfig {
    /// Defaults to a 30-second drain timeout, polled every 100 milliseconds.
    fn default() -> Self {
        let drain_timeout = Duration::from_secs(30);
        let poll_interval = Duration::from_millis(100);

        ShutdownConfig {
            drain_timeout,
            poll_interval,
        }
    }
}
28
29/// Graceful shutdown coordinator.
30pub struct GracefulShutdown {
31    registry: Arc<NodeRegistry>,
32    leader_election: Option<Arc<LeaderElection>>,
33    config: ShutdownConfig,
34    shutdown_requested: Arc<AtomicBool>,
35    in_flight_count: Arc<AtomicU32>,
36    shutdown_tx: broadcast::Sender<()>,
37}
38
39impl GracefulShutdown {
40    /// Create a new graceful shutdown coordinator.
41    pub fn new(
42        registry: Arc<NodeRegistry>,
43        leader_election: Option<Arc<LeaderElection>>,
44        config: ShutdownConfig,
45    ) -> Self {
46        let (shutdown_tx, _) = broadcast::channel(1);
47        Self {
48            registry,
49            leader_election,
50            config,
51            shutdown_requested: Arc::new(AtomicBool::new(false)),
52            in_flight_count: Arc::new(AtomicU32::new(0)),
53            shutdown_tx,
54        }
55    }
56
57    /// Check if shutdown has been requested.
58    pub fn is_shutdown_requested(&self) -> bool {
59        self.shutdown_requested.load(Ordering::SeqCst)
60    }
61
62    /// Get the current in-flight count.
63    pub fn in_flight_count(&self) -> u32 {
64        self.in_flight_count.load(Ordering::SeqCst)
65    }
66
67    /// Increment the in-flight counter.
68    pub fn increment_in_flight(&self) {
69        self.in_flight_count.fetch_add(1, Ordering::SeqCst);
70    }
71
72    /// Decrement the in-flight counter.
73    pub fn decrement_in_flight(&self) {
74        self.in_flight_count.fetch_sub(1, Ordering::SeqCst);
75    }
76
77    /// Subscribe to shutdown notifications.
78    pub fn subscribe(&self) -> broadcast::Receiver<()> {
79        self.shutdown_tx.subscribe()
80    }
81
82    /// Check if new work should be accepted.
83    pub fn should_accept_work(&self) -> bool {
84        !self.shutdown_requested.load(Ordering::SeqCst)
85    }
86
87    /// Perform graceful shutdown.
88    pub async fn shutdown(&self) -> forge_core::Result<()> {
89        // Mark shutdown as requested
90        self.shutdown_requested.store(true, Ordering::SeqCst);
91
92        // Notify all listeners
93        let _ = self.shutdown_tx.send(());
94
95        tracing::info!("Starting graceful shutdown");
96
97        // 1. Set status to draining
98        if let Err(e) = self.registry.set_status(NodeStatus::Draining).await {
99            tracing::warn!("Failed to set draining status: {}", e);
100        }
101
102        // 2. Wait for in-flight requests with timeout
103        let drain_result = self.wait_for_drain().await;
104        match drain_result {
105            DrainResult::Completed => {
106                tracing::info!("All in-flight requests completed");
107            }
108            DrainResult::Timeout(remaining) => {
109                tracing::warn!(
110                    "Drain timeout reached with {} requests still in-flight",
111                    remaining
112                );
113            }
114        }
115
116        // 3. Release leadership explicitly so another node can take over immediately
117        if let Some(ref election) = self.leader_election {
118            if let Err(e) = election.release_leadership().await {
119                tracing::warn!("Failed to release leadership: {}", e);
120            } else {
121                tracing::debug!("Leadership released");
122            }
123        }
124
125        // 4. Deregister from cluster
126        if let Err(e) = self.registry.deregister().await {
127            tracing::warn!("Failed to deregister from cluster: {}", e);
128        }
129
130        tracing::info!("Graceful shutdown complete");
131        Ok(())
132    }
133
134    /// Wait for all in-flight requests to complete.
135    async fn wait_for_drain(&self) -> DrainResult {
136        let deadline = tokio::time::Instant::now() + self.config.drain_timeout;
137
138        loop {
139            let count = self.in_flight_count.load(Ordering::SeqCst);
140
141            if count == 0 {
142                return DrainResult::Completed;
143            }
144
145            if tokio::time::Instant::now() >= deadline {
146                return DrainResult::Timeout(count);
147            }
148
149            tokio::time::sleep(self.config.poll_interval).await;
150        }
151    }
152}
153
/// Outcome of waiting for in-flight requests to drain.
#[derive(Debug)]
enum DrainResult {
    /// The in-flight counter reached zero before the deadline.
    Completed,
    /// The deadline passed first; carries the in-flight count seen at that point.
    Timeout(u32),
}
162
163/// RAII guard for tracking in-flight requests.
164pub struct InFlightGuard {
165    shutdown: Arc<GracefulShutdown>,
166}
167
168impl InFlightGuard {
169    /// Create a new in-flight guard.
170    /// Returns None if shutdown is in progress.
171    pub fn try_new(shutdown: Arc<GracefulShutdown>) -> Option<Self> {
172        if shutdown.should_accept_work() {
173            shutdown.increment_in_flight();
174            Some(Self { shutdown })
175        } else {
176            None
177        }
178    }
179}
180
181impl Drop for InFlightGuard {
182    fn drop(&mut self) {
183        self.shutdown.decrement_in_flight();
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must expose a 30 s drain timeout and a
    /// 100 ms poll interval.
    #[test]
    fn test_shutdown_config_default() {
        let ShutdownConfig {
            drain_timeout,
            poll_interval,
        } = ShutdownConfig::default();

        assert_eq!(drain_timeout, Duration::from_secs(30));
        assert_eq!(poll_interval, Duration::from_millis(100));
    }
}