// forge_runtime/cluster/heartbeat.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeId;
6use forge_core::config::cluster::ClusterConfig;
7use tokio::sync::{Mutex, watch};
8
/// Tuning knobs for the cluster heartbeat loop.
///
/// `interval` is the starting tick period. While the observed active-node count
/// stays unchanged the loop backs off toward `max_interval`.
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// Base period between heartbeat ticks.
    pub interval: Duration,
    /// Lower bound on how stale a node's heartbeat must be before it is marked dead.
    pub dead_threshold: Duration,
    /// Whether each tick also flips stale `active` nodes to `dead`.
    pub mark_dead_nodes: bool,
    /// Max interval when cluster is stable (adaptive backoff ceiling).
    pub max_interval: Duration,
}

impl Default for HeartbeatConfig {
    /// 5s base interval, 15s dead threshold, dead-node marking on, 60s backoff ceiling.
    fn default() -> Self {
        let base = Duration::from_secs(5);
        Self {
            interval: base,
            // Three missed base ticks.
            dead_threshold: base * 3,
            mark_dead_nodes: true,
            // Twelve base ticks.
            max_interval: base * 12,
        }
    }
}
29
30impl HeartbeatConfig {
31    pub fn from_cluster_config(cluster: &ClusterConfig) -> Self {
32        use forge_core::config::cluster::DiscoveryMethod;
33
34        match cluster.discovery {
35            DiscoveryMethod::Postgres => {
36                tracing::debug!("Using PostgreSQL-based cluster discovery");
37            }
38            DiscoveryMethod::Dns => {
39                tracing::info!(
40                    dns_name = ?cluster.dns_name,
41                    "Using DNS-based cluster discovery"
42                );
43            }
44            DiscoveryMethod::Kubernetes => {
45                tracing::info!(
46                    dns_name = ?cluster.dns_name,
47                    "Using Kubernetes-based cluster discovery (via headless service DNS)"
48                );
49            }
50            DiscoveryMethod::Static => {
51                tracing::info!(
52                    seed_count = cluster.seed_nodes.len(),
53                    "Using static seed node discovery"
54                );
55            }
56        }
57
58        Self {
59            interval: *cluster.heartbeat_interval,
60            dead_threshold: *cluster.dead_threshold,
61            mark_dead_nodes: true,
62            max_interval: Duration::from_secs(cluster.heartbeat_interval.as_secs() * 12),
63        }
64    }
65}
66
67/// Heartbeat loop using a dedicated connection for pool-exhaustion safety.
68pub struct HeartbeatLoop {
69    pool: sqlx::PgPool,
70    node_id: NodeId,
71    config: HeartbeatConfig,
72    running: Arc<AtomicBool>,
73    shutdown_tx: watch::Sender<bool>,
74    shutdown_rx: watch::Receiver<bool>,
75    current_interval_ms: AtomicU64,
76    stable_count: AtomicU32,
77    last_active_count: AtomicU32,
78    /// Dedicated connection held outside the shared pool for liveness safety.
79    heartbeat_conn: Mutex<sqlx::pool::PoolConnection<sqlx::Postgres>>,
80}
81
82impl HeartbeatLoop {
83    /// Acquire a dedicated connection outside the shared pool.
84    pub async fn new(
85        pool: sqlx::PgPool,
86        node_id: NodeId,
87        config: HeartbeatConfig,
88    ) -> forge_core::Result<Self> {
89        let conn = pool
90            .acquire()
91            .await
92            .map_err(forge_core::ForgeError::Database)?;
93        let (shutdown_tx, shutdown_rx) = watch::channel(false);
94        let interval_ms = config.interval.as_millis() as u64;
95        Ok(Self {
96            pool,
97            node_id,
98            config,
99            running: Arc::new(AtomicBool::new(false)),
100            shutdown_tx,
101            shutdown_rx,
102            current_interval_ms: AtomicU64::new(interval_ms),
103            stable_count: AtomicU32::new(0),
104            last_active_count: AtomicU32::new(0),
105            heartbeat_conn: Mutex::new(conn),
106        })
107    }
108
109    pub fn is_running(&self) -> bool {
110        self.running.load(Ordering::SeqCst)
111    }
112
113    pub fn stop(&self) {
114        let _ = self.shutdown_tx.send(true);
115        self.running.store(false, Ordering::SeqCst);
116    }
117
118    pub async fn run(&self) {
119        self.running.store(true, Ordering::SeqCst);
120        let mut shutdown_rx = self.shutdown_rx.clone();
121
122        loop {
123            let interval = self.current_interval();
124            tokio::select! {
125                _ = tokio::time::sleep(interval) => {
126                    let hb_start = std::time::Instant::now();
127                    if let Err(e) = self.send_heartbeat().await {
128                        tracing::debug!(error = %e, "Failed to send heartbeat");
129                    }
130                    super::metrics::record_heartbeat_latency(hb_start.elapsed().as_secs_f64());
131
132                    self.adjust_interval().await;
133
134                    if self.config.mark_dead_nodes
135                        && let Err(e) = self.mark_dead_nodes().await
136                    {
137                        tracing::debug!(error = %e, "Failed to mark dead nodes");
138                    }
139                }
140                _ = shutdown_rx.changed() => {
141                    if *shutdown_rx.borrow() {
142                        tracing::debug!("Heartbeat loop shutting down");
143                        break;
144                    }
145                }
146            }
147        }
148
149        self.running.store(false, Ordering::SeqCst);
150    }
151
152    fn current_interval(&self) -> Duration {
153        Duration::from_millis(self.current_interval_ms.load(Ordering::Relaxed))
154    }
155
156    async fn active_node_count(&self) -> forge_core::Result<u32> {
157        let row = sqlx::query_scalar!("SELECT COUNT(*) FROM forge_nodes WHERE status = 'active'")
158            .fetch_one(&self.pool)
159            .await
160            .map_err(forge_core::ForgeError::Database)?;
161
162        Ok(row.unwrap_or(0) as u32)
163    }
164
165    async fn adjust_interval(&self) {
166        let count = match self.active_node_count().await {
167            Ok(c) => c,
168            Err(e) => {
169                tracing::debug!(error = %e, "Failed to query active node count");
170                return;
171            }
172        };
173
174        super::metrics::set_node_counts(count as i64, 0);
175
176        let last = self.last_active_count.load(Ordering::Relaxed);
177        let stable = self.stable_count.load(Ordering::Relaxed);
178        let cur = self.current_interval_ms.load(Ordering::Relaxed);
179        let base_ms = self.config.interval.as_millis() as u64;
180        let max_ms = self.config.max_interval.as_millis() as u64;
181
182        let (next_ms, next_stable) =
183            next_adaptive_interval(count, last, stable, cur, base_ms, max_ms);
184        self.current_interval_ms.store(next_ms, Ordering::Relaxed);
185        self.stable_count.store(next_stable, Ordering::Relaxed);
186        self.last_active_count.store(count, Ordering::Relaxed);
187    }
188
189    async fn heartbeat_conn(
190        &self,
191    ) -> forge_core::Result<tokio::sync::MutexGuard<'_, sqlx::pool::PoolConnection<sqlx::Postgres>>>
192    {
193        use sqlx::Connection as _;
194        let mut guard = self.heartbeat_conn.lock().await;
195        if guard.ping().await.is_err() {
196            tracing::debug!("Heartbeat connection lost; reconnecting");
197            let new_conn = self
198                .pool
199                .acquire()
200                .await
201                .map_err(forge_core::ForgeError::Database)?;
202            *guard = new_conn;
203        }
204        Ok(guard)
205    }
206
207    async fn send_heartbeat(&self) -> forge_core::Result<()> {
208        let mut conn = self.heartbeat_conn().await?;
209        sqlx::query!(
210            r#"
211            UPDATE forge_nodes
212            SET last_heartbeat = NOW()
213            WHERE id = $1
214            "#,
215            self.node_id.as_uuid(),
216        )
217        .execute(&mut **conn)
218        .await
219        .map_err(forge_core::ForgeError::Database)?;
220
221        Ok(())
222    }
223
224    async fn mark_dead_nodes(&self) -> forge_core::Result<u64> {
225        let threshold_secs =
226            dead_node_threshold(self.current_interval(), self.config.dead_threshold).as_secs_f64();
227
228        let result = sqlx::query!(
229            r#"
230            UPDATE forge_nodes
231            SET status = 'dead'
232            WHERE status = 'active'
233              AND last_heartbeat < NOW() - make_interval(secs => $1)
234            "#,
235            threshold_secs,
236        )
237        .execute(&self.pool)
238        .await
239        .map_err(forge_core::ForgeError::Database)?;
240
241        let count = result.rows_affected();
242        if count > 0 {
243            tracing::warn!(count, "Marked nodes as dead");
244            super::metrics::set_node_counts(
245                self.last_active_count.load(Ordering::Relaxed) as i64,
246                count as i64,
247            );
248        }
249
250        Ok(count)
251    }
252
253    pub async fn update_load(
254        &self,
255        current_connections: u32,
256        current_jobs: u32,
257        cpu_usage: f32,
258        memory_usage: f32,
259    ) -> forge_core::Result<()> {
260        let mut conn = self.heartbeat_conn().await?;
261        sqlx::query!(
262            r#"
263            UPDATE forge_nodes
264            SET current_connections = $2,
265                current_jobs = $3,
266                cpu_usage = $4,
267                memory_usage = $5,
268                last_heartbeat = NOW()
269            WHERE id = $1
270            "#,
271            self.node_id.as_uuid(),
272            current_connections as i32,
273            current_jobs as i32,
274            cpu_usage as f64,
275            memory_usage as f64,
276        )
277        .execute(&mut **conn)
278        .await
279        .map_err(forge_core::ForgeError::Database)?;
280
281        Ok(())
282    }
283}
284
/// Compute the next adaptive heartbeat period.
///
/// Inputs:
/// - `observed`: the active-node count seen this tick.
/// - `last`: the count from the previous tick (0 means "no previous observation").
/// - `stable`: the current streak of unchanged ticks.
/// - `current_ms`: the current period.
///
/// Returns `(next_interval_ms, stable_count)`:
/// - any membership change, or no prior observation, snaps back to `base_ms` and
///   zeroes the streak;
/// - otherwise the streak grows. From the third unchanged tick on, the period
///   doubles and is clamped to `[base_ms, max_ms]` (`base_ms` wins if they conflict).
fn next_adaptive_interval(
    observed: u32,
    last: u32,
    stable: u32,
    current_ms: u64,
    base_ms: u64,
    max_ms: u64,
) -> (u64, u32) {
    // Consecutive unchanged observations required before backoff starts.
    const BACKOFF_AFTER: u32 = 3;

    let membership_unchanged = last != 0 && observed == last;
    if !membership_unchanged {
        return (base_ms, 0);
    }

    let streak = stable.saturating_add(1);
    let next_ms = if streak < BACKOFF_AFTER {
        current_ms
    } else {
        let doubled = current_ms.saturating_mul(2);
        std::cmp::max(std::cmp::min(doubled, max_ms), base_ms)
    };
    (next_ms, streak)
}
305
/// Staleness cutoff for declaring a node dead.
///
/// Three ticks at the current adaptive `current_interval`, but never less than the
/// `configured` floor. The multiplication saturates at `Duration::MAX` instead of
/// panicking.
fn dead_node_threshold(current_interval: Duration, configured: Duration) -> Duration {
    const MISSED_TICKS: u32 = 3;
    let three_ticks = current_interval.saturating_mul(MISSED_TICKS);
    if three_ticks > configured {
        three_ticks
    } else {
        configured
    }
}
311
#[cfg(test)]
mod tests {
    use super::*;

    // Every adaptive-interval case below runs against the same 5s base / 60s ceiling.
    const BASE_MS: u64 = 5_000;
    const MAX_MS: u64 = 60_000;

    /// `next_adaptive_interval` with the shared base and ceiling filled in.
    fn step(observed: u32, last: u32, stable: u32, current_ms: u64) -> (u64, u32) {
        next_adaptive_interval(observed, last, stable, current_ms, BASE_MS, MAX_MS)
    }

    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    #[test]
    fn heartbeat_config_default_matches_documented_values() {
        let HeartbeatConfig {
            interval,
            dead_threshold,
            mark_dead_nodes,
            max_interval,
        } = HeartbeatConfig::default();
        assert_eq!(interval, secs(5));
        assert_eq!(dead_threshold, secs(15));
        assert!(mark_dead_nodes);
        assert_eq!(max_interval, secs(60));
    }

    #[test]
    fn heartbeat_config_from_cluster_config_propagates_durations() {
        let mut cluster = ClusterConfig::default();
        cluster.heartbeat_interval = forge_core::config::DurationStr::new(secs(10));
        cluster.dead_threshold = forge_core::config::DurationStr::new(secs(30));

        let cfg = HeartbeatConfig::from_cluster_config(&cluster);
        assert_eq!(cfg.interval, secs(10));
        assert_eq!(cfg.dead_threshold, secs(30));
        assert!(cfg.mark_dead_nodes);
        // The backoff ceiling is 12x the heartbeat interval.
        assert_eq!(cfg.max_interval, secs(120));
    }

    #[test]
    fn adaptive_interval_first_observation_seeds_base() {
        // last == 0 means "no prior observation": always reset, whatever the current value.
        assert_eq!(step(5, 0, 0, 9_999), (BASE_MS, 0));
    }

    #[test]
    fn adaptive_interval_membership_change_resets() {
        assert_eq!(step(7, 5, 9, 40_000), (BASE_MS, 0));
    }

    #[test]
    fn adaptive_interval_stable_under_threshold_holds_current() {
        // The first two unchanged ticks only grow the streak; the interval holds.
        for (stable_in, stable_out) in [(0, 1), (1, 2)] {
            assert_eq!(step(5, 5, stable_in, 5_000), (5_000, stable_out));
        }
    }

    #[test]
    fn adaptive_interval_doubles_at_third_stable_tick() {
        // Streak 2 -> 3 triggers the first doubling: 5_000 -> 10_000.
        assert_eq!(step(5, 5, 2, 5_000), (10_000, 3));
    }

    #[test]
    fn adaptive_interval_doubles_clamps_to_max() {
        // 40_000 * 2 = 80_000 exceeds the 60_000 ceiling.
        assert_eq!(step(5, 5, 5, 40_000), (MAX_MS, 6));
    }

    #[test]
    fn adaptive_interval_doubles_floor_to_base() {
        // A current value below base (only possible after a config change) is floored.
        let (next, _) = step(5, 5, 5, 1_000);
        assert_eq!(next, BASE_MS);
    }

    #[test]
    fn adaptive_interval_doubling_saturates_without_overflow() {
        // Doubling a value near u64::MAX must saturate instead of panicking.
        let (next, _) = step(5, 5, 5, u64::MAX - 1);
        assert_eq!(next, MAX_MS);
    }

    #[test]
    fn dead_threshold_uses_adaptive_when_larger() {
        // 3 x 30s = 90s beats the 15s floor.
        assert_eq!(dead_node_threshold(secs(30), secs(15)), secs(90));
    }

    #[test]
    fn dead_threshold_uses_configured_when_larger() {
        // 3 x 5s = 15s loses to the 60s floor.
        assert_eq!(dead_node_threshold(secs(5), secs(60)), secs(60));
    }

    #[test]
    fn dead_threshold_saturates_on_huge_interval() {
        // Multiplying Duration::MAX must not panic.
        assert!(dead_node_threshold(Duration::MAX, secs(60)) >= secs(60));
    }
}
417
// Database-backed tests for `HeartbeatLoop`. They only build with the
// `testcontainers` feature, because they need a live PostgreSQL instance.
#[cfg(all(test, feature = "testcontainers"))]
// Test-only code: panicking on setup or assertion failure is the intended failure mode.
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::panic,
    clippy::disallowed_methods
)]
mod integration_tests {
    use super::*;
    use forge_core::testing::{IsolatedTestDb, TestDatabase};

    /// Create an isolated database named after the test and apply the system schema.
    async fn setup_db(test_name: &str) -> IsolatedTestDb {
        let base = TestDatabase::from_env()
            .await
            .expect("Failed to create test database");
        let db = base
            .isolated(test_name)
            .await
            .expect("Failed to create isolated db");
        let system_sql = crate::pg::migration::get_all_system_sql();
        db.run_sql(&system_sql)
            .await
            .expect("Failed to apply system schema");
        db
    }

    /// Insert a `forge_nodes` row with the given status and a heartbeat
    /// `heartbeat_age_secs` seconds in the past.
    async fn seed_node(pool: &sqlx::PgPool, id: NodeId, status: &str, heartbeat_age_secs: i64) {
        sqlx::query(
            r#"
            INSERT INTO forge_nodes (
                id, hostname, ip_address, http_port, grpc_port, status, last_heartbeat
            ) VALUES ($1, $2, $3, $4, $5, $6, NOW() - make_interval(secs => $7))
            "#,
        )
        .bind(id.as_uuid())
        .bind("test-host")
        .bind("127.0.0.1")
        .bind(8080_i32)
        .bind(8081_i32)
        .bind(status)
        .bind(heartbeat_age_secs as f64)
        .execute(pool)
        .await
        .unwrap();
    }

    /// Heartbeat loop with a 1s base interval and 10s configured dead threshold.
    /// The effective dead cutoff is therefore max(3 x 1s, 10s) = 10s.
    async fn loop_for(pool: sqlx::PgPool, node_id: NodeId) -> HeartbeatLoop {
        HeartbeatLoop::new(
            pool,
            node_id,
            HeartbeatConfig {
                interval: Duration::from_secs(1),
                dead_threshold: Duration::from_secs(10),
                mark_dead_nodes: true,
                max_interval: Duration::from_secs(60),
            },
        )
        .await
        .expect("HeartbeatLoop::new")
    }

    #[tokio::test]
    async fn send_heartbeat_bumps_last_heartbeat_to_now() {
        let db = setup_db("hb_send").await;
        let node = NodeId::new();
        seed_node(db.pool(), node, "active", 30).await;

        let hb = loop_for(db.pool().clone(), node).await;
        hb.send_heartbeat().await.unwrap();

        let age: f64 = sqlx::query_scalar(
            "SELECT EXTRACT(EPOCH FROM (NOW() - last_heartbeat))::float8 FROM forge_nodes WHERE id = $1",
        )
        .bind(node.as_uuid())
        .fetch_one(db.pool())
        .await
        .unwrap();
        assert!(age < 2.0, "heartbeat should be fresh, got age = {age}");
    }

    #[tokio::test]
    async fn active_node_count_only_counts_active_status() {
        let db = setup_db("hb_count").await;
        let self_id = NodeId::new();
        seed_node(db.pool(), self_id, "active", 0).await;
        seed_node(db.pool(), NodeId::new(), "active", 0).await;
        seed_node(db.pool(), NodeId::new(), "dead", 0).await;
        seed_node(db.pool(), NodeId::new(), "starting", 0).await;

        let hb = loop_for(db.pool().clone(), self_id).await;
        let count = hb.active_node_count().await.unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn mark_dead_nodes_flips_stale_active_nodes() {
        let db = setup_db("hb_mark_dead").await;
        let self_id = NodeId::new();
        let stale = NodeId::new();
        let fresh = NodeId::new();
        seed_node(db.pool(), self_id, "active", 0).await;
        // 120s is far past the effective 10s cutoff: max(3 x 1s interval, 10s configured).
        seed_node(db.pool(), stale, "active", 120).await;
        seed_node(db.pool(), fresh, "active", 0).await;

        let hb = loop_for(db.pool().clone(), self_id).await;
        let marked = hb.mark_dead_nodes().await.unwrap();
        assert_eq!(marked, 1);

        let stale_status: String =
            sqlx::query_scalar("SELECT status FROM forge_nodes WHERE id = $1")
                .bind(stale.as_uuid())
                .fetch_one(db.pool())
                .await
                .unwrap();
        assert_eq!(stale_status, "dead");

        let fresh_status: String =
            sqlx::query_scalar("SELECT status FROM forge_nodes WHERE id = $1")
                .bind(fresh.as_uuid())
                .fetch_one(db.pool())
                .await
                .unwrap();
        assert_eq!(fresh_status, "active");
    }

    #[tokio::test]
    async fn mark_dead_nodes_does_not_revive_already_dead_nodes() {
        let db = setup_db("hb_no_revive").await;
        let self_id = NodeId::new();
        let already_dead = NodeId::new();
        seed_node(db.pool(), self_id, "active", 0).await;
        seed_node(db.pool(), already_dead, "dead", 120).await;

        let hb = loop_for(db.pool().clone(), self_id).await;
        let marked = hb.mark_dead_nodes().await.unwrap();
        assert_eq!(marked, 0, "dead nodes should not be re-touched");

        let status: String = sqlx::query_scalar("SELECT status FROM forge_nodes WHERE id = $1")
            .bind(already_dead.as_uuid())
            .fetch_one(db.pool())
            .await
            .unwrap();
        assert_eq!(status, "dead");
    }

    #[tokio::test]
    async fn update_load_persists_metrics_and_refreshes_heartbeat() {
        let db = setup_db("hb_update_load").await;
        let node = NodeId::new();
        seed_node(db.pool(), node, "active", 30).await;

        let hb = loop_for(db.pool().clone(), node).await;
        hb.update_load(42, 7, 0.5, 0.25).await.unwrap();

        let (conns, jobs, cpu, mem, age): (i32, i32, Option<f64>, Option<f64>, f64) =
            sqlx::query_as(
                r#"
                SELECT current_connections, current_jobs, cpu_usage, memory_usage,
                       EXTRACT(EPOCH FROM (NOW() - last_heartbeat))::float8
                FROM forge_nodes WHERE id = $1
                "#,
            )
            .bind(node.as_uuid())
            .fetch_one(db.pool())
            .await
            .unwrap();

        assert_eq!(conns, 42);
        assert_eq!(jobs, 7);
        // f32 -> f64 conversion introduces tiny drift; compare loosely.
        assert!((cpu.unwrap() - 0.5).abs() < 1e-5);
        assert!((mem.unwrap() - 0.25).abs() < 1e-5);
        assert!(age < 2.0);
    }
}
588        assert!((cpu.unwrap() - 0.5).abs() < 1e-5);
589        assert!((mem.unwrap() - 0.25).abs() < 1e-5);
590        assert!(age < 2.0);
591    }
592}