Skip to main content

forge_runtime/cluster/
heartbeat.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeId;
6use tokio::sync::watch;
7
/// Tuning knobs for the cluster heartbeat loop.
#[derive(Clone, Debug)]
pub struct HeartbeatConfig {
    /// Base (and minimum) delay between heartbeats. The adaptive interval
    /// starts here and falls back to this value whenever cluster membership
    /// changes.
    pub interval: Duration,
    /// Minimum age a node's last heartbeat must reach before that node may be
    /// marked dead.
    pub dead_threshold: Duration,
    /// When `true`, each tick also marks stale nodes as dead.
    pub mark_dead_nodes: bool,
    /// Upper bound the heartbeat delay may back off to while the cluster is
    /// stable.
    pub max_interval: Duration,
}

impl Default for HeartbeatConfig {
    /// 5 s base interval, 15 s dead threshold, dead-node marking enabled,
    /// 60 s maximum interval.
    fn default() -> Self {
        let secs = Duration::from_secs;
        Self {
            max_interval: secs(60),
            mark_dead_nodes: true,
            dead_threshold: secs(15),
            interval: secs(5),
        }
    }
}
31
32/// Heartbeat loop for cluster health.
33pub struct HeartbeatLoop {
34    pool: sqlx::PgPool,
35    node_id: NodeId,
36    config: HeartbeatConfig,
37    running: Arc<AtomicBool>,
38    shutdown_tx: watch::Sender<bool>,
39    shutdown_rx: watch::Receiver<bool>,
40    current_interval_ms: AtomicU64,
41    stable_count: AtomicU32,
42    last_active_count: AtomicU32,
43}
44
45impl HeartbeatLoop {
46    /// Create a new heartbeat loop.
47    pub fn new(pool: sqlx::PgPool, node_id: NodeId, config: HeartbeatConfig) -> Self {
48        let (shutdown_tx, shutdown_rx) = watch::channel(false);
49        let interval_ms = config.interval.as_millis() as u64;
50        Self {
51            pool,
52            node_id,
53            config,
54            running: Arc::new(AtomicBool::new(false)),
55            shutdown_tx,
56            shutdown_rx,
57            current_interval_ms: AtomicU64::new(interval_ms),
58            stable_count: AtomicU32::new(0),
59            last_active_count: AtomicU32::new(0),
60        }
61    }
62
63    /// Check if the loop is running.
64    pub fn is_running(&self) -> bool {
65        self.running.load(Ordering::SeqCst)
66    }
67
68    /// Get a shutdown receiver.
69    pub fn shutdown_receiver(&self) -> watch::Receiver<bool> {
70        self.shutdown_rx.clone()
71    }
72
73    /// Stop the heartbeat loop.
74    pub fn stop(&self) {
75        let _ = self.shutdown_tx.send(true);
76        self.running.store(false, Ordering::SeqCst);
77    }
78
79    /// Run the heartbeat loop.
80    pub async fn run(&self) {
81        self.running.store(true, Ordering::SeqCst);
82        let mut shutdown_rx = self.shutdown_rx.clone();
83
84        loop {
85            let interval = self.current_interval();
86            tokio::select! {
87                _ = tokio::time::sleep(interval) => {
88                    // Update our heartbeat
89                    if let Err(e) = self.send_heartbeat().await {
90                        tracing::debug!(error = %e, "Failed to send heartbeat");
91                    }
92
93                    // Adjust interval based on cluster stability
94                    self.adjust_interval().await;
95
96                    // Mark dead nodes if enabled
97                    if self.config.mark_dead_nodes
98                        && let Err(e) = self.mark_dead_nodes().await
99                    {
100                        tracing::debug!(error = %e, "Failed to mark dead nodes");
101                    }
102                }
103                _ = shutdown_rx.changed() => {
104                    if *shutdown_rx.borrow() {
105                        tracing::debug!("Heartbeat loop shutting down");
106                        break;
107                    }
108                }
109            }
110        }
111
112        self.running.store(false, Ordering::SeqCst);
113    }
114
115    /// Current adaptive interval.
116    fn current_interval(&self) -> Duration {
117        Duration::from_millis(self.current_interval_ms.load(Ordering::Relaxed))
118    }
119
120    /// Query the number of active nodes in the cluster.
121    async fn active_node_count(&self) -> forge_core::Result<u32> {
122        let row: (i64,) =
123            sqlx::query_as("SELECT COUNT(*) FROM forge_nodes WHERE status = 'active'")
124                .fetch_one(&self.pool)
125                .await
126                .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
127
128        Ok(row.0 as u32)
129    }
130
131    /// Adjust heartbeat interval based on cluster stability.
132    async fn adjust_interval(&self) {
133        let count = match self.active_node_count().await {
134            Ok(c) => c,
135            Err(e) => {
136                tracing::debug!(error = %e, "Failed to query active node count");
137                return;
138            }
139        };
140
141        let last = self.last_active_count.load(Ordering::Relaxed);
142        if last != 0 && count == last {
143            let stable = self.stable_count.fetch_add(1, Ordering::Relaxed) + 1;
144            if stable >= 3 {
145                let base_ms = self.config.interval.as_millis() as u64;
146                let max_ms = self.config.max_interval.as_millis() as u64;
147                let cur = self.current_interval_ms.load(Ordering::Relaxed);
148                let next = (cur * 2).min(max_ms).max(base_ms);
149                self.current_interval_ms.store(next, Ordering::Relaxed);
150            }
151        } else {
152            // Membership changed, reset to base interval
153            self.stable_count.store(0, Ordering::Relaxed);
154            let base_ms = self.config.interval.as_millis() as u64;
155            self.current_interval_ms.store(base_ms, Ordering::Relaxed);
156        }
157        self.last_active_count.store(count, Ordering::Relaxed);
158    }
159
160    /// Send a heartbeat update.
161    async fn send_heartbeat(&self) -> forge_core::Result<()> {
162        sqlx::query(
163            r#"
164            UPDATE forge_nodes
165            SET last_heartbeat = NOW()
166            WHERE id = $1
167            "#,
168        )
169        .bind(self.node_id.as_uuid())
170        .execute(&self.pool)
171        .await
172        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
173
174        Ok(())
175    }
176
177    /// Mark stale nodes as dead. Threshold is the greater of the configured
178    /// dead_threshold and 3x the current adaptive interval.
179    async fn mark_dead_nodes(&self) -> forge_core::Result<u64> {
180        let adaptive = self.current_interval().as_secs_f64() * 3.0;
181        let configured = self.config.dead_threshold.as_secs_f64();
182        let threshold_secs = adaptive.max(configured);
183
184        let result = sqlx::query(
185            r#"
186            UPDATE forge_nodes
187            SET status = 'dead'
188            WHERE status = 'active'
189              AND last_heartbeat < NOW() - make_interval(secs => $1)
190            "#,
191        )
192        .bind(threshold_secs)
193        .execute(&self.pool)
194        .await
195        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
196
197        let count = result.rows_affected();
198        if count > 0 {
199            tracing::warn!(count, "Marked nodes as dead");
200        }
201
202        Ok(count)
203    }
204
205    /// Update load metrics.
206    pub async fn update_load(
207        &self,
208        current_connections: u32,
209        current_jobs: u32,
210        cpu_usage: f32,
211        memory_usage: f32,
212    ) -> forge_core::Result<()> {
213        sqlx::query(
214            r#"
215            UPDATE forge_nodes
216            SET current_connections = $2,
217                current_jobs = $3,
218                cpu_usage = $4,
219                memory_usage = $5,
220                last_heartbeat = NOW()
221            WHERE id = $1
222            "#,
223        )
224        .bind(self.node_id.as_uuid())
225        .bind(current_connections as i32)
226        .bind(current_jobs as i32)
227        .bind(cpu_usage)
228        .bind(memory_usage)
229        .execute(&self.pool)
230        .await
231        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
232
233        Ok(())
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults must stay in sync with the documented values.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            interval,
            dead_threshold,
            mark_dead_nodes,
            max_interval,
        } = HeartbeatConfig::default();

        assert_eq!(interval, Duration::from_secs(5));
        assert_eq!(dead_threshold, Duration::from_secs(15));
        assert!(mark_dead_nodes);
        assert_eq!(max_interval, Duration::from_secs(60));
    }
}