Skip to main content

forge_runtime/cluster/
heartbeat.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeId;
6use forge_core::config::cluster::ClusterConfig;
7use tokio::sync::watch;
8
/// Tunable settings for the cluster heartbeat loop.
///
/// `interval` is the base cadence. While the number of active nodes stays
/// unchanged the loop backs off towards `max_interval`, and it drops back to
/// `interval` as soon as the count changes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// Base interval between heartbeats; the adaptive interval never goes
    /// below this value.
    pub interval: Duration,
    /// Lower bound on the heartbeat age after which an active node is
    /// marked dead.
    pub dead_threshold: Duration,
    /// Whether the loop sweeps for stale nodes and marks them dead.
    pub mark_dead_nodes: bool,
    /// Upper bound for the adaptive heartbeat interval when the cluster is
    /// stable.
    pub max_interval: Duration,
}
21
22impl Default for HeartbeatConfig {
23    fn default() -> Self {
24        Self {
25            interval: Duration::from_secs(5),
26            dead_threshold: Duration::from_secs(15),
27            mark_dead_nodes: true,
28            max_interval: Duration::from_secs(60),
29        }
30    }
31}
32
33impl HeartbeatConfig {
34    /// Create a HeartbeatConfig from the user-facing ClusterConfig.
35    pub fn from_cluster_config(cluster: &ClusterConfig) -> Self {
36        use forge_core::config::cluster::DiscoveryMethod;
37
38        match cluster.discovery {
39            DiscoveryMethod::Postgres => {
40                tracing::debug!("Using PostgreSQL-based cluster discovery");
41            }
42            DiscoveryMethod::Dns => {
43                tracing::info!(
44                    dns_name = ?cluster.dns_name,
45                    "Using DNS-based cluster discovery"
46                );
47            }
48            DiscoveryMethod::Kubernetes => {
49                tracing::info!(
50                    dns_name = ?cluster.dns_name,
51                    "Using Kubernetes-based cluster discovery (via headless service DNS)"
52                );
53            }
54            DiscoveryMethod::Static => {
55                tracing::info!(
56                    seed_count = cluster.seed_nodes.len(),
57                    "Using static seed node discovery"
58                );
59            }
60        }
61
62        Self {
63            interval: Duration::from_secs(cluster.heartbeat_interval_secs),
64            dead_threshold: Duration::from_secs(cluster.dead_threshold_secs),
65            mark_dead_nodes: true,
66            max_interval: Duration::from_secs(cluster.heartbeat_interval_secs * 12),
67        }
68    }
69}
70
71/// Heartbeat loop for cluster health.
72pub struct HeartbeatLoop {
73    pool: sqlx::PgPool,
74    node_id: NodeId,
75    config: HeartbeatConfig,
76    running: Arc<AtomicBool>,
77    shutdown_tx: watch::Sender<bool>,
78    shutdown_rx: watch::Receiver<bool>,
79    current_interval_ms: AtomicU64,
80    stable_count: AtomicU32,
81    last_active_count: AtomicU32,
82}
83
84impl HeartbeatLoop {
85    /// Create a new heartbeat loop.
86    pub fn new(pool: sqlx::PgPool, node_id: NodeId, config: HeartbeatConfig) -> Self {
87        let (shutdown_tx, shutdown_rx) = watch::channel(false);
88        let interval_ms = config.interval.as_millis() as u64;
89        Self {
90            pool,
91            node_id,
92            config,
93            running: Arc::new(AtomicBool::new(false)),
94            shutdown_tx,
95            shutdown_rx,
96            current_interval_ms: AtomicU64::new(interval_ms),
97            stable_count: AtomicU32::new(0),
98            last_active_count: AtomicU32::new(0),
99        }
100    }
101
102    /// Check if the loop is running.
103    pub fn is_running(&self) -> bool {
104        self.running.load(Ordering::SeqCst)
105    }
106
107    /// Get a shutdown receiver.
108    pub fn shutdown_receiver(&self) -> watch::Receiver<bool> {
109        self.shutdown_rx.clone()
110    }
111
112    /// Stop the heartbeat loop.
113    pub fn stop(&self) {
114        let _ = self.shutdown_tx.send(true);
115        self.running.store(false, Ordering::SeqCst);
116    }
117
118    /// Run the heartbeat loop.
119    pub async fn run(&self) {
120        self.running.store(true, Ordering::SeqCst);
121        let mut shutdown_rx = self.shutdown_rx.clone();
122
123        loop {
124            let interval = self.current_interval();
125            tokio::select! {
126                _ = tokio::time::sleep(interval) => {
127                    // Update our heartbeat
128                    let hb_start = std::time::Instant::now();
129                    if let Err(e) = self.send_heartbeat().await {
130                        tracing::debug!(error = %e, "Failed to send heartbeat");
131                    }
132                    super::metrics::record_heartbeat_latency(hb_start.elapsed().as_secs_f64());
133
134                    // Adjust interval based on cluster stability
135                    self.adjust_interval().await;
136
137                    // Mark dead nodes if enabled
138                    if self.config.mark_dead_nodes
139                        && let Err(e) = self.mark_dead_nodes().await
140                    {
141                        tracing::debug!(error = %e, "Failed to mark dead nodes");
142                    }
143                }
144                _ = shutdown_rx.changed() => {
145                    if *shutdown_rx.borrow() {
146                        tracing::debug!("Heartbeat loop shutting down");
147                        break;
148                    }
149                }
150            }
151        }
152
153        self.running.store(false, Ordering::SeqCst);
154    }
155
156    /// Current adaptive interval.
157    fn current_interval(&self) -> Duration {
158        Duration::from_millis(self.current_interval_ms.load(Ordering::Relaxed))
159    }
160
161    /// Query the number of active nodes in the cluster.
162    async fn active_node_count(&self) -> forge_core::Result<u32> {
163        let row = sqlx::query_scalar!("SELECT COUNT(*) FROM forge_nodes WHERE status = 'active'")
164            .fetch_one(&self.pool)
165            .await
166            .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
167
168        Ok(row.unwrap_or(0) as u32)
169    }
170
171    /// Adjust heartbeat interval based on cluster stability.
172    async fn adjust_interval(&self) {
173        let count = match self.active_node_count().await {
174            Ok(c) => c,
175            Err(e) => {
176                tracing::debug!(error = %e, "Failed to query active node count");
177                return;
178            }
179        };
180
181        super::metrics::set_node_counts(count as i64, 0);
182
183        let last = self.last_active_count.load(Ordering::Relaxed);
184        if last != 0 && count == last {
185            let stable = self.stable_count.fetch_add(1, Ordering::Relaxed) + 1;
186            if stable >= 3 {
187                let base_ms = self.config.interval.as_millis() as u64;
188                let max_ms = self.config.max_interval.as_millis() as u64;
189                let cur = self.current_interval_ms.load(Ordering::Relaxed);
190                let next = (cur * 2).min(max_ms).max(base_ms);
191                self.current_interval_ms.store(next, Ordering::Relaxed);
192            }
193        } else {
194            // Membership changed, reset to base interval
195            self.stable_count.store(0, Ordering::Relaxed);
196            let base_ms = self.config.interval.as_millis() as u64;
197            self.current_interval_ms.store(base_ms, Ordering::Relaxed);
198        }
199        self.last_active_count.store(count, Ordering::Relaxed);
200    }
201
202    /// Send a heartbeat update.
203    async fn send_heartbeat(&self) -> forge_core::Result<()> {
204        sqlx::query!(
205            r#"
206            UPDATE forge_nodes
207            SET last_heartbeat = NOW()
208            WHERE id = $1
209            "#,
210            self.node_id.as_uuid(),
211        )
212        .execute(&self.pool)
213        .await
214        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
215
216        Ok(())
217    }
218
219    /// Mark stale nodes as dead. Threshold is the greater of the configured
220    /// dead_threshold and 3x the current adaptive interval.
221    async fn mark_dead_nodes(&self) -> forge_core::Result<u64> {
222        let adaptive = self.current_interval().as_secs_f64() * 3.0;
223        let configured = self.config.dead_threshold.as_secs_f64();
224        let threshold_secs = adaptive.max(configured);
225
226        let result = sqlx::query!(
227            r#"
228            UPDATE forge_nodes
229            SET status = 'dead'
230            WHERE status = 'active'
231              AND last_heartbeat < NOW() - make_interval(secs => $1)
232            "#,
233            threshold_secs,
234        )
235        .execute(&self.pool)
236        .await
237        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
238
239        let count = result.rows_affected();
240        if count > 0 {
241            tracing::warn!(count, "Marked nodes as dead");
242            // Update dead node count in metrics (we know `count` nodes just became dead)
243            super::metrics::set_node_counts(
244                self.last_active_count.load(Ordering::Relaxed) as i64,
245                count as i64,
246            );
247        }
248
249        Ok(count)
250    }
251
252    /// Update load metrics.
253    pub async fn update_load(
254        &self,
255        current_connections: u32,
256        current_jobs: u32,
257        cpu_usage: f32,
258        memory_usage: f32,
259    ) -> forge_core::Result<()> {
260        sqlx::query!(
261            r#"
262            UPDATE forge_nodes
263            SET current_connections = $2,
264                current_jobs = $3,
265                cpu_usage = $4,
266                memory_usage = $5,
267                last_heartbeat = NOW()
268            WHERE id = $1
269            "#,
270            self.node_id.as_uuid(),
271            current_connections as i32,
272            current_jobs as i32,
273            cpu_usage as f64,
274            memory_usage as f64,
275        )
276        .execute(&self.pool)
277        .await
278        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
279
280        Ok(())
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses a 5 s heartbeat, a 15 s dead
    /// threshold, dead-node marking enabled, and a 60 s adaptive ceiling.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            interval,
            dead_threshold,
            mark_dead_nodes,
            max_interval,
        } = HeartbeatConfig::default();

        assert_eq!(interval, Duration::from_secs(5));
        assert_eq!(dead_threshold, Duration::from_secs(15));
        assert!(mark_dead_nodes);
        assert_eq!(max_interval, Duration::from_secs(60));
    }

    /// Values from `ClusterConfig` are carried over, and the adaptive
    /// ceiling is twelve times the heartbeat interval.
    #[test]
    fn test_heartbeat_config_from_cluster_config() {
        let cluster = ClusterConfig {
            heartbeat_interval_secs: 10,
            dead_threshold_secs: 30,
            ..ClusterConfig::default()
        };

        let derived = HeartbeatConfig::from_cluster_config(&cluster);

        assert_eq!(derived.interval, Duration::from_secs(10));
        assert_eq!(derived.dead_threshold, Duration::from_secs(30));
        assert!(derived.mark_dead_nodes);
        assert_eq!(derived.max_interval, Duration::from_secs(120));
    }
}