1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeId;
6use forge_core::config::cluster::ClusterConfig;
7use tokio::sync::{Mutex, watch};
8
/// Tuning knobs for [`HeartbeatLoop`].
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// Base delay between heartbeats; the adaptive interval never drops below this.
    pub interval: Duration,
    /// Lower bound for how stale an `active` node's heartbeat may get before it is
    /// marked `dead` (the effective threshold may be larger, see `dead_node_threshold`).
    pub dead_threshold: Duration,
    /// Whether each tick also sweeps stale `active` nodes to `dead`.
    pub mark_dead_nodes: bool,
    /// Upper bound for the adaptive heartbeat interval.
    pub max_interval: Duration,
}
18
19impl Default for HeartbeatConfig {
20 fn default() -> Self {
21 Self {
22 interval: Duration::from_secs(5),
23 dead_threshold: Duration::from_secs(15),
24 mark_dead_nodes: true,
25 max_interval: Duration::from_secs(60),
26 }
27 }
28}
29
30impl HeartbeatConfig {
31 pub fn from_cluster_config(cluster: &ClusterConfig) -> Self {
32 use forge_core::config::cluster::DiscoveryMethod;
33
34 match cluster.discovery {
35 DiscoveryMethod::Postgres => {
36 tracing::debug!("Using PostgreSQL-based cluster discovery");
37 }
38 DiscoveryMethod::Dns => {
39 tracing::info!(
40 dns_name = ?cluster.dns_name,
41 "Using DNS-based cluster discovery"
42 );
43 }
44 DiscoveryMethod::Kubernetes => {
45 tracing::info!(
46 dns_name = ?cluster.dns_name,
47 "Using Kubernetes-based cluster discovery (via headless service DNS)"
48 );
49 }
50 DiscoveryMethod::Static => {
51 tracing::info!(
52 seed_count = cluster.seed_nodes.len(),
53 "Using static seed node discovery"
54 );
55 }
56 }
57
58 Self {
59 interval: *cluster.heartbeat_interval,
60 dead_threshold: *cluster.dead_threshold,
61 mark_dead_nodes: true,
62 max_interval: Duration::from_secs(cluster.heartbeat_interval.as_secs() * 12),
63 }
64 }
65}
66
67pub struct HeartbeatLoop {
69 pool: sqlx::PgPool,
70 node_id: NodeId,
71 config: HeartbeatConfig,
72 running: Arc<AtomicBool>,
73 shutdown_tx: watch::Sender<bool>,
74 shutdown_rx: watch::Receiver<bool>,
75 current_interval_ms: AtomicU64,
76 stable_count: AtomicU32,
77 last_active_count: AtomicU32,
78 heartbeat_conn: Mutex<sqlx::pool::PoolConnection<sqlx::Postgres>>,
80}
81
82impl HeartbeatLoop {
83 pub async fn new(
85 pool: sqlx::PgPool,
86 node_id: NodeId,
87 config: HeartbeatConfig,
88 ) -> forge_core::Result<Self> {
89 let conn = pool
90 .acquire()
91 .await
92 .map_err(forge_core::ForgeError::Database)?;
93 let (shutdown_tx, shutdown_rx) = watch::channel(false);
94 let interval_ms = config.interval.as_millis() as u64;
95 Ok(Self {
96 pool,
97 node_id,
98 config,
99 running: Arc::new(AtomicBool::new(false)),
100 shutdown_tx,
101 shutdown_rx,
102 current_interval_ms: AtomicU64::new(interval_ms),
103 stable_count: AtomicU32::new(0),
104 last_active_count: AtomicU32::new(0),
105 heartbeat_conn: Mutex::new(conn),
106 })
107 }
108
109 pub fn is_running(&self) -> bool {
110 self.running.load(Ordering::SeqCst)
111 }
112
113 pub fn stop(&self) {
114 let _ = self.shutdown_tx.send(true);
115 self.running.store(false, Ordering::SeqCst);
116 }
117
118 pub async fn run(&self) {
119 self.running.store(true, Ordering::SeqCst);
120 let mut shutdown_rx = self.shutdown_rx.clone();
121
122 loop {
123 let interval = self.current_interval();
124 tokio::select! {
125 _ = tokio::time::sleep(interval) => {
126 let hb_start = std::time::Instant::now();
127 if let Err(e) = self.send_heartbeat().await {
128 tracing::debug!(error = %e, "Failed to send heartbeat");
129 }
130 super::metrics::record_heartbeat_latency(hb_start.elapsed().as_secs_f64());
131
132 self.adjust_interval().await;
133
134 if self.config.mark_dead_nodes
135 && let Err(e) = self.mark_dead_nodes().await
136 {
137 tracing::debug!(error = %e, "Failed to mark dead nodes");
138 }
139 }
140 _ = shutdown_rx.changed() => {
141 if *shutdown_rx.borrow() {
142 tracing::debug!("Heartbeat loop shutting down");
143 break;
144 }
145 }
146 }
147 }
148
149 self.running.store(false, Ordering::SeqCst);
150 }
151
152 fn current_interval(&self) -> Duration {
153 Duration::from_millis(self.current_interval_ms.load(Ordering::Relaxed))
154 }
155
156 async fn active_node_count(&self) -> forge_core::Result<u32> {
157 let row = sqlx::query_scalar!("SELECT COUNT(*) FROM forge_nodes WHERE status = 'active'")
158 .fetch_one(&self.pool)
159 .await
160 .map_err(forge_core::ForgeError::Database)?;
161
162 Ok(row.unwrap_or(0) as u32)
163 }
164
165 async fn adjust_interval(&self) {
166 let count = match self.active_node_count().await {
167 Ok(c) => c,
168 Err(e) => {
169 tracing::debug!(error = %e, "Failed to query active node count");
170 return;
171 }
172 };
173
174 super::metrics::set_node_counts(count as i64, 0);
175
176 let last = self.last_active_count.load(Ordering::Relaxed);
177 let stable = self.stable_count.load(Ordering::Relaxed);
178 let cur = self.current_interval_ms.load(Ordering::Relaxed);
179 let base_ms = self.config.interval.as_millis() as u64;
180 let max_ms = self.config.max_interval.as_millis() as u64;
181
182 let (next_ms, next_stable) =
183 next_adaptive_interval(count, last, stable, cur, base_ms, max_ms);
184 self.current_interval_ms.store(next_ms, Ordering::Relaxed);
185 self.stable_count.store(next_stable, Ordering::Relaxed);
186 self.last_active_count.store(count, Ordering::Relaxed);
187 }
188
189 async fn heartbeat_conn(
190 &self,
191 ) -> forge_core::Result<tokio::sync::MutexGuard<'_, sqlx::pool::PoolConnection<sqlx::Postgres>>>
192 {
193 use sqlx::Connection as _;
194 let mut guard = self.heartbeat_conn.lock().await;
195 if guard.ping().await.is_err() {
196 tracing::debug!("Heartbeat connection lost; reconnecting");
197 let new_conn = self
198 .pool
199 .acquire()
200 .await
201 .map_err(forge_core::ForgeError::Database)?;
202 *guard = new_conn;
203 }
204 Ok(guard)
205 }
206
207 async fn send_heartbeat(&self) -> forge_core::Result<()> {
208 let mut conn = self.heartbeat_conn().await?;
209 sqlx::query!(
210 r#"
211 UPDATE forge_nodes
212 SET last_heartbeat = NOW()
213 WHERE id = $1
214 "#,
215 self.node_id.as_uuid(),
216 )
217 .execute(&mut **conn)
218 .await
219 .map_err(forge_core::ForgeError::Database)?;
220
221 Ok(())
222 }
223
224 async fn mark_dead_nodes(&self) -> forge_core::Result<u64> {
225 let threshold_secs =
226 dead_node_threshold(self.current_interval(), self.config.dead_threshold).as_secs_f64();
227
228 let result = sqlx::query!(
229 r#"
230 UPDATE forge_nodes
231 SET status = 'dead'
232 WHERE status = 'active'
233 AND last_heartbeat < NOW() - make_interval(secs => $1)
234 "#,
235 threshold_secs,
236 )
237 .execute(&self.pool)
238 .await
239 .map_err(forge_core::ForgeError::Database)?;
240
241 let count = result.rows_affected();
242 if count > 0 {
243 tracing::warn!(count, "Marked nodes as dead");
244 super::metrics::set_node_counts(
245 self.last_active_count.load(Ordering::Relaxed) as i64,
246 count as i64,
247 );
248 }
249
250 Ok(count)
251 }
252
253 pub async fn update_load(
254 &self,
255 current_connections: u32,
256 current_jobs: u32,
257 cpu_usage: f32,
258 memory_usage: f32,
259 ) -> forge_core::Result<()> {
260 let mut conn = self.heartbeat_conn().await?;
261 sqlx::query!(
262 r#"
263 UPDATE forge_nodes
264 SET current_connections = $2,
265 current_jobs = $3,
266 cpu_usage = $4,
267 memory_usage = $5,
268 last_heartbeat = NOW()
269 WHERE id = $1
270 "#,
271 self.node_id.as_uuid(),
272 current_connections as i32,
273 current_jobs as i32,
274 cpu_usage as f64,
275 memory_usage as f64,
276 )
277 .execute(&mut **conn)
278 .await
279 .map_err(forge_core::ForgeError::Database)?;
280
281 Ok(())
282 }
283}
284
/// Picks the next heartbeat delay (ms) and stable-tick streak.
///
/// Any change in the active-node count, or the first observation (`last == 0`),
/// resets to `base_ms` with a zero streak. Otherwise the streak grows (saturating).
/// Once it reaches three ticks, the delay doubles (saturating), is capped at `max_ms`,
/// and is then floored at `base_ms`.
fn next_adaptive_interval(
    observed: u32,
    last: u32,
    stable: u32,
    current_ms: u64,
    base_ms: u64,
    max_ms: u64,
) -> (u64, u32) {
    const TICKS_BEFORE_BACKOFF: u32 = 3;

    let membership_steady = last != 0 && observed == last;
    if !membership_steady {
        return (base_ms, 0);
    }

    let streak = stable.saturating_add(1);
    let next_ms = if streak < TICKS_BEFORE_BACKOFF {
        current_ms
    } else {
        current_ms.saturating_mul(2).min(max_ms).max(base_ms)
    };
    (next_ms, streak)
}
305
/// Staleness after which an `active` node is considered dead.
///
/// Three missed heartbeats at `current_interval` (saturating rather than overflowing),
/// but never less than the `configured` floor.
fn dead_node_threshold(current_interval: Duration, configured: Duration) -> Duration {
    let three_missed = current_interval.saturating_mul(3);
    std::cmp::max(configured, three_missed)
}
311
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn heartbeat_config_default_matches_documented_values() {
        let config = HeartbeatConfig::default();
        assert_eq!(config.interval, Duration::from_secs(5));
        assert_eq!(config.dead_threshold, Duration::from_secs(15));
        assert!(config.mark_dead_nodes);
        assert_eq!(config.max_interval, Duration::from_secs(60));
    }

    #[test]
    fn heartbeat_config_from_cluster_config_propagates_durations() {
        let mut cluster = ClusterConfig::default();
        cluster.heartbeat_interval = forge_core::config::DurationStr::new(Duration::from_secs(10));
        cluster.dead_threshold = forge_core::config::DurationStr::new(Duration::from_secs(30));

        let config = HeartbeatConfig::from_cluster_config(&cluster);
        assert_eq!(config.interval, Duration::from_secs(10));
        assert_eq!(config.dead_threshold, Duration::from_secs(30));
        assert!(config.mark_dead_nodes);
        assert_eq!(config.max_interval, Duration::from_secs(120));
    }

    #[test]
    fn adaptive_interval_first_observation_seeds_base() {
        let (next, stable) = next_adaptive_interval(5, 0, 0, 9_999, 5_000, 60_000);
        assert_eq!(next, 5_000);
        assert_eq!(stable, 0);
    }

    #[test]
    fn adaptive_interval_membership_change_resets() {
        let (next, stable) = next_adaptive_interval(7, 5, 9, 40_000, 5_000, 60_000);
        assert_eq!(next, 5_000);
        assert_eq!(stable, 0);
    }

    #[test]
    fn adaptive_interval_stable_under_threshold_holds_current() {
        let (next, stable) = next_adaptive_interval(5, 5, 0, 5_000, 5_000, 60_000);
        assert_eq!(next, 5_000);
        assert_eq!(stable, 1);

        let (next, stable) = next_adaptive_interval(5, 5, 1, 5_000, 5_000, 60_000);
        assert_eq!(next, 5_000);
        assert_eq!(stable, 2);
    }

    #[test]
    fn adaptive_interval_doubles_at_third_stable_tick() {
        let (next, stable) = next_adaptive_interval(5, 5, 2, 5_000, 5_000, 60_000);
        assert_eq!(next, 10_000);
        assert_eq!(stable, 3);
    }

    #[test]
    fn adaptive_interval_doubles_clamps_to_max() {
        let (next, stable) = next_adaptive_interval(5, 5, 5, 40_000, 5_000, 60_000);
        assert_eq!(next, 60_000);
        assert_eq!(stable, 6);
    }

    #[test]
    fn adaptive_interval_doubles_floor_to_base() {
        let (next, _) = next_adaptive_interval(5, 5, 5, 1_000, 5_000, 60_000);
        assert_eq!(next, 5_000);
    }

    #[test]
    fn adaptive_interval_doubling_saturates_without_overflow() {
        let (next, _) = next_adaptive_interval(5, 5, 5, u64::MAX - 1, 5_000, 60_000);
        assert_eq!(next, 60_000);
    }

    #[test]
    fn adaptive_interval_base_wins_when_max_below_base() {
        // The cap is applied before the floor, so a misconfigured max < base
        // degrades to "never back off" rather than "heartbeat faster than base".
        let (next, stable) = next_adaptive_interval(5, 5, 5, 10_000, 5_000, 1_000);
        assert_eq!(next, 5_000);
        assert_eq!(stable, 6);
    }

    #[test]
    fn adaptive_interval_stable_count_saturates() {
        let (next, stable) = next_adaptive_interval(5, 5, u32::MAX, 5_000, 5_000, 60_000);
        assert_eq!(next, 10_000);
        assert_eq!(stable, u32::MAX);
    }

    #[test]
    fn dead_threshold_uses_adaptive_when_larger() {
        let got = dead_node_threshold(Duration::from_secs(30), Duration::from_secs(15));
        assert_eq!(got, Duration::from_secs(90));
    }

    #[test]
    fn dead_threshold_uses_configured_when_larger() {
        let got = dead_node_threshold(Duration::from_secs(5), Duration::from_secs(60));
        assert_eq!(got, Duration::from_secs(60));
    }

    #[test]
    fn dead_threshold_equal_candidates() {
        let got = dead_node_threshold(Duration::from_secs(5), Duration::from_secs(15));
        assert_eq!(got, Duration::from_secs(15));
    }

    #[test]
    fn dead_threshold_saturates_on_huge_interval() {
        let got = dead_node_threshold(Duration::MAX, Duration::from_secs(60));
        // A `>= 60s` check would also accept a regression that fell back to the
        // configured floor; saturation must pin the result to the maximum.
        assert_eq!(got, Duration::MAX);
    }
}
417
#[cfg(all(test, feature = "testcontainers"))]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::panic,
    clippy::disallowed_methods
)]
mod integration_tests {
    use super::*;
    use forge_core::testing::{IsolatedTestDb, TestDatabase};

    /// Provisions an isolated database for `test_name` and installs the Forge system schema.
    async fn setup_db(test_name: &str) -> IsolatedTestDb {
        let shared = TestDatabase::from_env()
            .await
            .expect("Failed to create test database");
        let isolated = shared
            .isolated(test_name)
            .await
            .expect("Failed to create isolated db");
        let schema = crate::pg::migration::get_all_system_sql();
        isolated
            .run_sql(&schema)
            .await
            .expect("Failed to apply system schema");
        isolated
    }

    /// Inserts a `forge_nodes` row for `id` with `status` and a `last_heartbeat`
    /// that lies `heartbeat_age_secs` seconds in the past.
    async fn seed_node(pool: &sqlx::PgPool, id: NodeId, status: &str, heartbeat_age_secs: i64) {
        let age = heartbeat_age_secs as f64;
        sqlx::query(
            r#"
            INSERT INTO forge_nodes (
                id, hostname, ip_address, http_port, grpc_port, status, last_heartbeat
            ) VALUES ($1, $2, $3, $4, $5, $6, NOW() - make_interval(secs => $7))
            "#,
        )
        .bind(id.as_uuid())
        .bind("test-host")
        .bind("127.0.0.1")
        .bind(8080_i32)
        .bind(8081_i32)
        .bind(status)
        .bind(age)
        .execute(pool)
        .await
        .unwrap();
    }

    /// Reads the `status` column of node `id`.
    async fn node_status(pool: &sqlx::PgPool, id: NodeId) -> String {
        sqlx::query_scalar("SELECT status FROM forge_nodes WHERE id = $1")
            .bind(id.as_uuid())
            .fetch_one(pool)
            .await
            .unwrap()
    }

    /// Seconds since node `id` last heartbeated, measured with the database clock.
    async fn heartbeat_age(pool: &sqlx::PgPool, id: NodeId) -> f64 {
        sqlx::query_scalar(
            "SELECT EXTRACT(EPOCH FROM (NOW() - last_heartbeat))::float8 FROM forge_nodes WHERE id = $1",
        )
        .bind(id.as_uuid())
        .fetch_one(pool)
        .await
        .unwrap()
    }

    /// Builds a loop with a 1 s base interval and a 10 s dead-threshold floor.
    async fn loop_for(pool: sqlx::PgPool, node_id: NodeId) -> HeartbeatLoop {
        let config = HeartbeatConfig {
            interval: Duration::from_secs(1),
            dead_threshold: Duration::from_secs(10),
            mark_dead_nodes: true,
            max_interval: Duration::from_secs(60),
        };
        HeartbeatLoop::new(pool, node_id, config)
            .await
            .expect("HeartbeatLoop::new")
    }

    #[tokio::test]
    async fn send_heartbeat_bumps_last_heartbeat_to_now() {
        let db = setup_db("hb_send").await;
        let me = NodeId::new();
        seed_node(db.pool(), me, "active", 30).await;

        let heartbeat = loop_for(db.pool().clone(), me).await;
        heartbeat.send_heartbeat().await.unwrap();

        let age = heartbeat_age(db.pool(), me).await;
        assert!(age < 2.0, "heartbeat should be fresh, got age = {age}");
    }

    #[tokio::test]
    async fn active_node_count_only_counts_active_status() {
        let db = setup_db("hb_count").await;
        let me = NodeId::new();
        seed_node(db.pool(), me, "active", 0).await;
        seed_node(db.pool(), NodeId::new(), "active", 0).await;
        seed_node(db.pool(), NodeId::new(), "dead", 0).await;
        seed_node(db.pool(), NodeId::new(), "starting", 0).await;

        let heartbeat = loop_for(db.pool().clone(), me).await;
        assert_eq!(heartbeat.active_node_count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn mark_dead_nodes_flips_stale_active_nodes() {
        let db = setup_db("hb_mark_dead").await;
        let me = NodeId::new();
        let stale = NodeId::new();
        let fresh = NodeId::new();
        seed_node(db.pool(), me, "active", 0).await;
        seed_node(db.pool(), stale, "active", 120).await;
        seed_node(db.pool(), fresh, "active", 0).await;

        let heartbeat = loop_for(db.pool().clone(), me).await;
        assert_eq!(heartbeat.mark_dead_nodes().await.unwrap(), 1);

        assert_eq!(node_status(db.pool(), stale).await, "dead");
        assert_eq!(node_status(db.pool(), fresh).await, "active");
    }

    #[tokio::test]
    async fn mark_dead_nodes_does_not_revive_already_dead_nodes() {
        let db = setup_db("hb_no_revive").await;
        let me = NodeId::new();
        let already_dead = NodeId::new();
        seed_node(db.pool(), me, "active", 0).await;
        seed_node(db.pool(), already_dead, "dead", 120).await;

        let heartbeat = loop_for(db.pool().clone(), me).await;
        let marked = heartbeat.mark_dead_nodes().await.unwrap();
        assert_eq!(marked, 0, "dead nodes should not be re-touched");

        assert_eq!(node_status(db.pool(), already_dead).await, "dead");
    }

    #[tokio::test]
    async fn update_load_persists_metrics_and_refreshes_heartbeat() {
        let db = setup_db("hb_update_load").await;
        let me = NodeId::new();
        seed_node(db.pool(), me, "active", 30).await;

        let heartbeat = loop_for(db.pool().clone(), me).await;
        heartbeat.update_load(42, 7, 0.5, 0.25).await.unwrap();

        let (conns, jobs, cpu, mem, age): (i32, i32, Option<f64>, Option<f64>, f64) =
            sqlx::query_as(
                r#"
                SELECT current_connections, current_jobs, cpu_usage, memory_usage,
                       EXTRACT(EPOCH FROM (NOW() - last_heartbeat))::float8
                FROM forge_nodes WHERE id = $1
                "#,
            )
            .bind(me.as_uuid())
            .fetch_one(db.pool())
            .await
            .unwrap();

        assert_eq!(conns, 42);
        assert_eq!(jobs, 7);
        assert!((cpu.unwrap() - 0.5).abs() < 1e-5);
        assert!((mem.unwrap() - 0.25).abs() < 1e-5);
        assert!(age < 2.0);
    }
}