Skip to main content

forge_runtime/cluster/
heartbeat.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeId;
6use forge_core::config::cluster::ClusterConfig;
7use tokio::sync::watch;
8
/// Tunable parameters for a node's heartbeat loop.
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// Base delay between two consecutive heartbeats.
    pub interval: Duration,
    /// How long a node may stay silent before it is treated as dead.
    pub dead_threshold: Duration,
    /// When `true`, stale nodes are flagged as dead.
    pub mark_dead_nodes: bool,
    /// Upper bound for the heartbeat delay while the cluster is stable.
    pub max_interval: Duration,
}
21
22impl Default for HeartbeatConfig {
23    fn default() -> Self {
24        Self {
25            interval: Duration::from_secs(5),
26            dead_threshold: Duration::from_secs(15),
27            mark_dead_nodes: true,
28            max_interval: Duration::from_secs(60),
29        }
30    }
31}
32
33impl HeartbeatConfig {
34    /// Create a HeartbeatConfig from the user-facing ClusterConfig.
35    pub fn from_cluster_config(cluster: &ClusterConfig) -> Self {
36        use forge_core::config::cluster::DiscoveryMethod;
37
38        if cluster.discovery != DiscoveryMethod::Postgres {
39            tracing::warn!(
40                discovery = ?cluster.discovery,
41                "Only Postgres discovery is currently supported; ignoring configured discovery method"
42            );
43        }
44
45        Self {
46            interval: Duration::from_secs(cluster.heartbeat_interval_secs),
47            dead_threshold: Duration::from_secs(cluster.dead_threshold_secs),
48            mark_dead_nodes: true,
49            max_interval: Duration::from_secs(cluster.heartbeat_interval_secs * 12),
50        }
51    }
52}
53
54/// Heartbeat loop for cluster health.
55pub struct HeartbeatLoop {
56    pool: sqlx::PgPool,
57    node_id: NodeId,
58    config: HeartbeatConfig,
59    running: Arc<AtomicBool>,
60    shutdown_tx: watch::Sender<bool>,
61    shutdown_rx: watch::Receiver<bool>,
62    current_interval_ms: AtomicU64,
63    stable_count: AtomicU32,
64    last_active_count: AtomicU32,
65}
66
67impl HeartbeatLoop {
68    /// Create a new heartbeat loop.
69    pub fn new(pool: sqlx::PgPool, node_id: NodeId, config: HeartbeatConfig) -> Self {
70        let (shutdown_tx, shutdown_rx) = watch::channel(false);
71        let interval_ms = config.interval.as_millis() as u64;
72        Self {
73            pool,
74            node_id,
75            config,
76            running: Arc::new(AtomicBool::new(false)),
77            shutdown_tx,
78            shutdown_rx,
79            current_interval_ms: AtomicU64::new(interval_ms),
80            stable_count: AtomicU32::new(0),
81            last_active_count: AtomicU32::new(0),
82        }
83    }
84
85    /// Check if the loop is running.
86    pub fn is_running(&self) -> bool {
87        self.running.load(Ordering::SeqCst)
88    }
89
90    /// Get a shutdown receiver.
91    pub fn shutdown_receiver(&self) -> watch::Receiver<bool> {
92        self.shutdown_rx.clone()
93    }
94
95    /// Stop the heartbeat loop.
96    pub fn stop(&self) {
97        let _ = self.shutdown_tx.send(true);
98        self.running.store(false, Ordering::SeqCst);
99    }
100
101    /// Run the heartbeat loop.
102    pub async fn run(&self) {
103        self.running.store(true, Ordering::SeqCst);
104        let mut shutdown_rx = self.shutdown_rx.clone();
105
106        loop {
107            let interval = self.current_interval();
108            tokio::select! {
109                _ = tokio::time::sleep(interval) => {
110                    // Update our heartbeat
111                    let hb_start = std::time::Instant::now();
112                    if let Err(e) = self.send_heartbeat().await {
113                        tracing::debug!(error = %e, "Failed to send heartbeat");
114                    }
115                    super::metrics::record_heartbeat_latency(hb_start.elapsed().as_secs_f64());
116
117                    // Adjust interval based on cluster stability
118                    self.adjust_interval().await;
119
120                    // Mark dead nodes if enabled
121                    if self.config.mark_dead_nodes
122                        && let Err(e) = self.mark_dead_nodes().await
123                    {
124                        tracing::debug!(error = %e, "Failed to mark dead nodes");
125                    }
126                }
127                _ = shutdown_rx.changed() => {
128                    if *shutdown_rx.borrow() {
129                        tracing::debug!("Heartbeat loop shutting down");
130                        break;
131                    }
132                }
133            }
134        }
135
136        self.running.store(false, Ordering::SeqCst);
137    }
138
139    /// Current adaptive interval.
140    fn current_interval(&self) -> Duration {
141        Duration::from_millis(self.current_interval_ms.load(Ordering::Relaxed))
142    }
143
144    /// Query the number of active nodes in the cluster.
145    async fn active_node_count(&self) -> forge_core::Result<u32> {
146        let row: (i64,) =
147            sqlx::query_as("SELECT COUNT(*) FROM forge_nodes WHERE status = 'active'")
148                .fetch_one(&self.pool)
149                .await
150                .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
151
152        Ok(row.0 as u32)
153    }
154
155    /// Adjust heartbeat interval based on cluster stability.
156    async fn adjust_interval(&self) {
157        let count = match self.active_node_count().await {
158            Ok(c) => c,
159            Err(e) => {
160                tracing::debug!(error = %e, "Failed to query active node count");
161                return;
162            }
163        };
164
165        super::metrics::set_node_counts(count as i64, 0);
166
167        let last = self.last_active_count.load(Ordering::Relaxed);
168        if last != 0 && count == last {
169            let stable = self.stable_count.fetch_add(1, Ordering::Relaxed) + 1;
170            if stable >= 3 {
171                let base_ms = self.config.interval.as_millis() as u64;
172                let max_ms = self.config.max_interval.as_millis() as u64;
173                let cur = self.current_interval_ms.load(Ordering::Relaxed);
174                let next = (cur * 2).min(max_ms).max(base_ms);
175                self.current_interval_ms.store(next, Ordering::Relaxed);
176            }
177        } else {
178            // Membership changed, reset to base interval
179            self.stable_count.store(0, Ordering::Relaxed);
180            let base_ms = self.config.interval.as_millis() as u64;
181            self.current_interval_ms.store(base_ms, Ordering::Relaxed);
182        }
183        self.last_active_count.store(count, Ordering::Relaxed);
184    }
185
186    /// Send a heartbeat update.
187    async fn send_heartbeat(&self) -> forge_core::Result<()> {
188        sqlx::query(
189            r#"
190            UPDATE forge_nodes
191            SET last_heartbeat = NOW()
192            WHERE id = $1
193            "#,
194        )
195        .bind(self.node_id.as_uuid())
196        .execute(&self.pool)
197        .await
198        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
199
200        Ok(())
201    }
202
203    /// Mark stale nodes as dead. Threshold is the greater of the configured
204    /// dead_threshold and 3x the current adaptive interval.
205    async fn mark_dead_nodes(&self) -> forge_core::Result<u64> {
206        let adaptive = self.current_interval().as_secs_f64() * 3.0;
207        let configured = self.config.dead_threshold.as_secs_f64();
208        let threshold_secs = adaptive.max(configured);
209
210        let result = sqlx::query(
211            r#"
212            UPDATE forge_nodes
213            SET status = 'dead'
214            WHERE status = 'active'
215              AND last_heartbeat < NOW() - make_interval(secs => $1)
216            "#,
217        )
218        .bind(threshold_secs)
219        .execute(&self.pool)
220        .await
221        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
222
223        let count = result.rows_affected();
224        if count > 0 {
225            tracing::warn!(count, "Marked nodes as dead");
226            // Update dead node count in metrics (we know `count` nodes just became dead)
227            super::metrics::set_node_counts(
228                self.last_active_count.load(Ordering::Relaxed) as i64,
229                count as i64,
230            );
231        }
232
233        Ok(count)
234    }
235
236    /// Update load metrics.
237    pub async fn update_load(
238        &self,
239        current_connections: u32,
240        current_jobs: u32,
241        cpu_usage: f32,
242        memory_usage: f32,
243    ) -> forge_core::Result<()> {
244        sqlx::query(
245            r#"
246            UPDATE forge_nodes
247            SET current_connections = $2,
248                current_jobs = $3,
249                cpu_usage = $4,
250                memory_usage = $5,
251                last_heartbeat = NOW()
252            WHERE id = $1
253            "#,
254        )
255        .bind(self.node_id.as_uuid())
256        .bind(current_connections as i32)
257        .bind(current_jobs as i32)
258        .bind(cpu_usage)
259        .bind(memory_usage)
260        .execute(&self.pool)
261        .await
262        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
263
264        Ok(())
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in defaults are 5 s / 15 s / 60 s with dead-node marking on.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            interval,
            dead_threshold,
            mark_dead_nodes,
            max_interval,
        } = HeartbeatConfig::default();

        assert_eq!(interval, Duration::from_secs(5));
        assert_eq!(dead_threshold, Duration::from_secs(15));
        assert!(mark_dead_nodes);
        assert_eq!(max_interval, Duration::from_secs(60));
    }

    /// Values from `ClusterConfig` carry over, and the ceiling is 12x the interval.
    #[test]
    fn test_heartbeat_config_from_cluster_config() {
        let cluster = ClusterConfig {
            heartbeat_interval_secs: 10,
            dead_threshold_secs: 30,
            ..ClusterConfig::default()
        };

        let derived = HeartbeatConfig::from_cluster_config(&cluster);

        assert_eq!(derived.interval, Duration::from_secs(10));
        assert_eq!(derived.dead_threshold, Duration::from_secs(30));
        assert!(derived.mark_dead_nodes);
        assert_eq!(derived.max_interval, Duration::from_secs(120));
    }
}
295}