forge_runtime/cluster/
heartbeat.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::Arc;
3use std::time::Duration;
4
5use forge_core::cluster::NodeId;
6use tokio::sync::watch;
7
/// Heartbeat loop configuration.
///
/// Controls how often a node refreshes its own heartbeat and when other
/// nodes are considered dead. Configurations compare by value, so two
/// instances with the same settings are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// Interval between heartbeats.
    pub interval: Duration,
    /// Threshold for marking nodes as dead: a node whose last heartbeat is
    /// older than this is treated as stale.
    pub dead_threshold: Duration,
    /// Whether this node should mark stale nodes as dead after sending its
    /// own heartbeat.
    pub mark_dead_nodes: bool,
}
18
19impl Default for HeartbeatConfig {
20    fn default() -> Self {
21        Self {
22            interval: Duration::from_secs(5),
23            dead_threshold: Duration::from_secs(15),
24            mark_dead_nodes: true,
25        }
26    }
27}
28
29/// Heartbeat loop for cluster health.
30pub struct HeartbeatLoop {
31    pool: sqlx::PgPool,
32    node_id: NodeId,
33    config: HeartbeatConfig,
34    running: Arc<AtomicBool>,
35    shutdown_tx: watch::Sender<bool>,
36    shutdown_rx: watch::Receiver<bool>,
37}
38
39impl HeartbeatLoop {
40    /// Create a new heartbeat loop.
41    pub fn new(pool: sqlx::PgPool, node_id: NodeId, config: HeartbeatConfig) -> Self {
42        let (shutdown_tx, shutdown_rx) = watch::channel(false);
43        Self {
44            pool,
45            node_id,
46            config,
47            running: Arc::new(AtomicBool::new(false)),
48            shutdown_tx,
49            shutdown_rx,
50        }
51    }
52
53    /// Check if the loop is running.
54    pub fn is_running(&self) -> bool {
55        self.running.load(Ordering::SeqCst)
56    }
57
58    /// Get a shutdown receiver.
59    pub fn shutdown_receiver(&self) -> watch::Receiver<bool> {
60        self.shutdown_rx.clone()
61    }
62
63    /// Stop the heartbeat loop.
64    pub fn stop(&self) {
65        let _ = self.shutdown_tx.send(true);
66        self.running.store(false, Ordering::SeqCst);
67    }
68
69    /// Run the heartbeat loop.
70    pub async fn run(&self) {
71        self.running.store(true, Ordering::SeqCst);
72        let mut shutdown_rx = self.shutdown_rx.clone();
73
74        loop {
75            tokio::select! {
76                _ = tokio::time::sleep(self.config.interval) => {
77                    // Update our heartbeat
78                    if let Err(e) = self.send_heartbeat().await {
79                        tracing::warn!("Failed to send heartbeat: {}", e);
80                    }
81
82                    // Mark dead nodes if enabled
83                    if self.config.mark_dead_nodes {
84                        if let Err(e) = self.mark_dead_nodes().await {
85                            tracing::warn!("Failed to mark dead nodes: {}", e);
86                        }
87                    }
88                }
89                _ = shutdown_rx.changed() => {
90                    if *shutdown_rx.borrow() {
91                        tracing::info!("Heartbeat loop shutting down");
92                        break;
93                    }
94                }
95            }
96        }
97
98        self.running.store(false, Ordering::SeqCst);
99    }
100
101    /// Send a heartbeat update.
102    async fn send_heartbeat(&self) -> forge_core::Result<()> {
103        sqlx::query(
104            r#"
105            UPDATE forge_nodes
106            SET last_heartbeat = NOW()
107            WHERE id = $1
108            "#,
109        )
110        .bind(self.node_id.as_uuid())
111        .execute(&self.pool)
112        .await
113        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
114
115        Ok(())
116    }
117
118    /// Mark stale nodes as dead.
119    async fn mark_dead_nodes(&self) -> forge_core::Result<u64> {
120        let threshold_secs = self.config.dead_threshold.as_secs() as f64;
121
122        let result = sqlx::query(
123            r#"
124            UPDATE forge_nodes
125            SET status = 'dead'
126            WHERE status = 'active'
127              AND last_heartbeat < NOW() - make_interval(secs => $1)
128            "#,
129        )
130        .bind(threshold_secs)
131        .execute(&self.pool)
132        .await
133        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
134
135        let count = result.rows_affected();
136        if count > 0 {
137            tracing::info!("Marked {} nodes as dead", count);
138        }
139
140        Ok(count)
141    }
142
143    /// Update load metrics.
144    pub async fn update_load(
145        &self,
146        current_connections: u32,
147        current_jobs: u32,
148        cpu_usage: f32,
149        memory_usage: f32,
150    ) -> forge_core::Result<()> {
151        sqlx::query(
152            r#"
153            UPDATE forge_nodes
154            SET current_connections = $2,
155                current_jobs = $3,
156                cpu_usage = $4,
157                memory_usage = $5,
158                last_heartbeat = NOW()
159            WHERE id = $1
160            "#,
161        )
162        .bind(self.node_id.as_uuid())
163        .bind(current_connections as i32)
164        .bind(current_jobs as i32)
165        .bind(cpu_usage)
166        .bind(memory_usage)
167        .execute(&self.pool)
168        .await
169        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
170
171        Ok(())
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration heartbeats every 5 seconds, treats nodes as
    /// dead after 15 seconds, and has dead-node marking switched on.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            interval,
            dead_threshold,
            mark_dead_nodes,
        } = HeartbeatConfig::default();

        assert_eq!(interval, Duration::from_secs(5));
        assert_eq!(dead_threshold, Duration::from_secs(15));
        assert!(mark_dead_nodes);
    }
}
186}