Skip to main content

forge_runtime/cluster/
heartbeat.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeId;
6use forge_core::config::cluster::ClusterConfig;
7use tokio::sync::watch;
8
/// Tuning knobs for [`HeartbeatLoop`].
///
/// The loop starts at `interval`, may back off towards `max_interval` while the
/// set of active nodes stays unchanged, and (when `mark_dead_nodes` is set)
/// flags peers whose last heartbeat is older than the effective dead threshold.
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// Base (and initial) delay between two heartbeats.
    pub interval: Duration,
    /// Minimum heartbeat staleness before a peer is marked dead; the loop never
    /// uses less than three times its current adaptive interval.
    pub dead_threshold: Duration,
    /// When `true`, every tick also sweeps for and marks stale peers as dead.
    pub mark_dead_nodes: bool,
    /// Ceiling for the adaptive interval while cluster membership is stable.
    pub max_interval: Duration,
}

impl Default for HeartbeatConfig {
    /// 5 s base interval, 15 s dead threshold, 60 s back-off ceiling, and
    /// dead-node marking enabled.
    fn default() -> Self {
        let interval = Duration::from_millis(5_000);
        let dead_threshold = Duration::from_millis(15_000);
        let max_interval = Duration::from_millis(60_000);

        HeartbeatConfig {
            interval,
            dead_threshold,
            mark_dead_nodes: true,
            max_interval,
        }
    }
}
32
33impl HeartbeatConfig {
34    /// Create a HeartbeatConfig from the user-facing ClusterConfig.
35    pub fn from_cluster_config(cluster: &ClusterConfig) -> Self {
36        use forge_core::config::cluster::DiscoveryMethod;
37
38        if cluster.discovery != DiscoveryMethod::Postgres {
39            tracing::warn!(
40                discovery = ?cluster.discovery,
41                "Only Postgres discovery is currently supported; ignoring configured discovery method"
42            );
43        }
44
45        Self {
46            interval: Duration::from_secs(cluster.heartbeat_interval_secs),
47            dead_threshold: Duration::from_secs(cluster.dead_threshold_secs),
48            mark_dead_nodes: true,
49            max_interval: Duration::from_secs(cluster.heartbeat_interval_secs * 12),
50        }
51    }
52}
53
54/// Heartbeat loop for cluster health.
55pub struct HeartbeatLoop {
56    pool: sqlx::PgPool,
57    node_id: NodeId,
58    config: HeartbeatConfig,
59    running: Arc<AtomicBool>,
60    shutdown_tx: watch::Sender<bool>,
61    shutdown_rx: watch::Receiver<bool>,
62    current_interval_ms: AtomicU64,
63    stable_count: AtomicU32,
64    last_active_count: AtomicU32,
65}
66
67impl HeartbeatLoop {
68    /// Create a new heartbeat loop.
69    pub fn new(pool: sqlx::PgPool, node_id: NodeId, config: HeartbeatConfig) -> Self {
70        let (shutdown_tx, shutdown_rx) = watch::channel(false);
71        let interval_ms = config.interval.as_millis() as u64;
72        Self {
73            pool,
74            node_id,
75            config,
76            running: Arc::new(AtomicBool::new(false)),
77            shutdown_tx,
78            shutdown_rx,
79            current_interval_ms: AtomicU64::new(interval_ms),
80            stable_count: AtomicU32::new(0),
81            last_active_count: AtomicU32::new(0),
82        }
83    }
84
85    /// Check if the loop is running.
86    pub fn is_running(&self) -> bool {
87        self.running.load(Ordering::SeqCst)
88    }
89
90    /// Get a shutdown receiver.
91    pub fn shutdown_receiver(&self) -> watch::Receiver<bool> {
92        self.shutdown_rx.clone()
93    }
94
95    /// Stop the heartbeat loop.
96    pub fn stop(&self) {
97        let _ = self.shutdown_tx.send(true);
98        self.running.store(false, Ordering::SeqCst);
99    }
100
101    /// Run the heartbeat loop.
102    pub async fn run(&self) {
103        self.running.store(true, Ordering::SeqCst);
104        let mut shutdown_rx = self.shutdown_rx.clone();
105
106        loop {
107            let interval = self.current_interval();
108            tokio::select! {
109                _ = tokio::time::sleep(interval) => {
110                    // Update our heartbeat
111                    let hb_start = std::time::Instant::now();
112                    if let Err(e) = self.send_heartbeat().await {
113                        tracing::debug!(error = %e, "Failed to send heartbeat");
114                    }
115                    super::metrics::record_heartbeat_latency(hb_start.elapsed().as_secs_f64());
116
117                    // Adjust interval based on cluster stability
118                    self.adjust_interval().await;
119
120                    // Mark dead nodes if enabled
121                    if self.config.mark_dead_nodes
122                        && let Err(e) = self.mark_dead_nodes().await
123                    {
124                        tracing::debug!(error = %e, "Failed to mark dead nodes");
125                    }
126                }
127                _ = shutdown_rx.changed() => {
128                    if *shutdown_rx.borrow() {
129                        tracing::debug!("Heartbeat loop shutting down");
130                        break;
131                    }
132                }
133            }
134        }
135
136        self.running.store(false, Ordering::SeqCst);
137    }
138
139    /// Current adaptive interval.
140    fn current_interval(&self) -> Duration {
141        Duration::from_millis(self.current_interval_ms.load(Ordering::Relaxed))
142    }
143
144    /// Query the number of active nodes in the cluster.
145    async fn active_node_count(&self) -> forge_core::Result<u32> {
146        let row = sqlx::query_scalar!("SELECT COUNT(*) FROM forge_nodes WHERE status = 'active'")
147            .fetch_one(&self.pool)
148            .await
149            .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
150
151        Ok(row.unwrap_or(0) as u32)
152    }
153
154    /// Adjust heartbeat interval based on cluster stability.
155    async fn adjust_interval(&self) {
156        let count = match self.active_node_count().await {
157            Ok(c) => c,
158            Err(e) => {
159                tracing::debug!(error = %e, "Failed to query active node count");
160                return;
161            }
162        };
163
164        super::metrics::set_node_counts(count as i64, 0);
165
166        let last = self.last_active_count.load(Ordering::Relaxed);
167        if last != 0 && count == last {
168            let stable = self.stable_count.fetch_add(1, Ordering::Relaxed) + 1;
169            if stable >= 3 {
170                let base_ms = self.config.interval.as_millis() as u64;
171                let max_ms = self.config.max_interval.as_millis() as u64;
172                let cur = self.current_interval_ms.load(Ordering::Relaxed);
173                let next = (cur * 2).min(max_ms).max(base_ms);
174                self.current_interval_ms.store(next, Ordering::Relaxed);
175            }
176        } else {
177            // Membership changed, reset to base interval
178            self.stable_count.store(0, Ordering::Relaxed);
179            let base_ms = self.config.interval.as_millis() as u64;
180            self.current_interval_ms.store(base_ms, Ordering::Relaxed);
181        }
182        self.last_active_count.store(count, Ordering::Relaxed);
183    }
184
185    /// Send a heartbeat update.
186    async fn send_heartbeat(&self) -> forge_core::Result<()> {
187        sqlx::query(
188            r#"
189            UPDATE forge_nodes
190            SET last_heartbeat = NOW()
191            WHERE id = $1
192            "#,
193        )
194        .bind(self.node_id.as_uuid())
195        .execute(&self.pool)
196        .await
197        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
198
199        Ok(())
200    }
201
202    /// Mark stale nodes as dead. Threshold is the greater of the configured
203    /// dead_threshold and 3x the current adaptive interval.
204    async fn mark_dead_nodes(&self) -> forge_core::Result<u64> {
205        let adaptive = self.current_interval().as_secs_f64() * 3.0;
206        let configured = self.config.dead_threshold.as_secs_f64();
207        let threshold_secs = adaptive.max(configured);
208
209        let result = sqlx::query(
210            r#"
211            UPDATE forge_nodes
212            SET status = 'dead'
213            WHERE status = 'active'
214              AND last_heartbeat < NOW() - make_interval(secs => $1)
215            "#,
216        )
217        .bind(threshold_secs)
218        .execute(&self.pool)
219        .await
220        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
221
222        let count = result.rows_affected();
223        if count > 0 {
224            tracing::warn!(count, "Marked nodes as dead");
225            // Update dead node count in metrics (we know `count` nodes just became dead)
226            super::metrics::set_node_counts(
227                self.last_active_count.load(Ordering::Relaxed) as i64,
228                count as i64,
229            );
230        }
231
232        Ok(count)
233    }
234
235    /// Update load metrics.
236    pub async fn update_load(
237        &self,
238        current_connections: u32,
239        current_jobs: u32,
240        cpu_usage: f32,
241        memory_usage: f32,
242    ) -> forge_core::Result<()> {
243        sqlx::query(
244            r#"
245            UPDATE forge_nodes
246            SET current_connections = $2,
247                current_jobs = $3,
248                cpu_usage = $4,
249                memory_usage = $5,
250                last_heartbeat = NOW()
251            WHERE id = $1
252            "#,
253        )
254        .bind(self.node_id.as_uuid())
255        .bind(current_connections as i32)
256        .bind(current_jobs as i32)
257        .bind(cpu_usage)
258        .bind(memory_usage)
259        .execute(&self.pool)
260        .await
261        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
262
263        Ok(())
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults: 5 s interval, 15 s dead threshold, marking on, 60 s ceiling.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            interval,
            dead_threshold,
            mark_dead_nodes,
            max_interval,
        } = HeartbeatConfig::default();

        assert_eq!(interval, Duration::from_secs(5));
        assert_eq!(dead_threshold, Duration::from_secs(15));
        assert!(mark_dead_nodes);
        assert_eq!(max_interval, Duration::from_secs(60));
    }

    /// Values are taken from `ClusterConfig`; the ceiling is 12x the interval.
    #[test]
    fn test_heartbeat_config_from_cluster_config() {
        let HeartbeatConfig {
            interval,
            dead_threshold,
            mark_dead_nodes,
            max_interval,
        } = HeartbeatConfig::from_cluster_config(&ClusterConfig {
            heartbeat_interval_secs: 10,
            dead_threshold_secs: 30,
            ..ClusterConfig::default()
        });

        assert_eq!(interval, Duration::from_secs(10));
        assert_eq!(dead_threshold, Duration::from_secs(30));
        assert!(mark_dead_nodes);
        assert_eq!(max_interval, Duration::from_secs(120));
    }
}