forge_runtime/cluster/
heartbeat.rs1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
3use std::time::Duration;
4
5use forge_core::cluster::NodeId;
6use tokio::sync::watch;
7
/// Tuning knobs for the cluster heartbeat loop.
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// Base delay between heartbeats. The loop starts at this interval and
    /// returns to it whenever the number of active nodes changes.
    pub interval: Duration,
    /// Lower bound on how stale a node's last heartbeat must be before it is
    /// marked dead. The effective threshold is never less than three times the
    /// sweeping node's current interval.
    pub dead_threshold: Duration,
    /// When `true`, every heartbeat tick also marks stale `active` nodes as
    /// `dead`.
    pub mark_dead_nodes: bool,
    /// Upper bound on the adaptive heartbeat interval reached by backing off
    /// while cluster membership is stable.
    pub max_interval: Duration,
}

impl Default for HeartbeatConfig {
    /// A 5 s base interval, a 15 s (3x interval) dead threshold, dead-node
    /// sweeping enabled, and a 60 s (12x interval) backoff cap.
    fn default() -> Self {
        let interval = Duration::from_secs(5);
        Self {
            interval,
            dead_threshold: interval * 3,
            mark_dead_nodes: true,
            max_interval: interval * 12,
        }
    }
}
31
32pub struct HeartbeatLoop {
34 pool: sqlx::PgPool,
35 node_id: NodeId,
36 config: HeartbeatConfig,
37 running: Arc<AtomicBool>,
38 shutdown_tx: watch::Sender<bool>,
39 shutdown_rx: watch::Receiver<bool>,
40 current_interval_ms: AtomicU64,
41 stable_count: AtomicU32,
42 last_active_count: AtomicU32,
43}
44
45impl HeartbeatLoop {
46 pub fn new(pool: sqlx::PgPool, node_id: NodeId, config: HeartbeatConfig) -> Self {
48 let (shutdown_tx, shutdown_rx) = watch::channel(false);
49 let interval_ms = config.interval.as_millis() as u64;
50 Self {
51 pool,
52 node_id,
53 config,
54 running: Arc::new(AtomicBool::new(false)),
55 shutdown_tx,
56 shutdown_rx,
57 current_interval_ms: AtomicU64::new(interval_ms),
58 stable_count: AtomicU32::new(0),
59 last_active_count: AtomicU32::new(0),
60 }
61 }
62
63 pub fn is_running(&self) -> bool {
65 self.running.load(Ordering::SeqCst)
66 }
67
68 pub fn shutdown_receiver(&self) -> watch::Receiver<bool> {
70 self.shutdown_rx.clone()
71 }
72
73 pub fn stop(&self) {
75 let _ = self.shutdown_tx.send(true);
76 self.running.store(false, Ordering::SeqCst);
77 }
78
79 pub async fn run(&self) {
81 self.running.store(true, Ordering::SeqCst);
82 let mut shutdown_rx = self.shutdown_rx.clone();
83
84 loop {
85 let interval = self.current_interval();
86 tokio::select! {
87 _ = tokio::time::sleep(interval) => {
88 if let Err(e) = self.send_heartbeat().await {
90 tracing::debug!(error = %e, "Failed to send heartbeat");
91 }
92
93 self.adjust_interval().await;
95
96 if self.config.mark_dead_nodes
98 && let Err(e) = self.mark_dead_nodes().await
99 {
100 tracing::debug!(error = %e, "Failed to mark dead nodes");
101 }
102 }
103 _ = shutdown_rx.changed() => {
104 if *shutdown_rx.borrow() {
105 tracing::debug!("Heartbeat loop shutting down");
106 break;
107 }
108 }
109 }
110 }
111
112 self.running.store(false, Ordering::SeqCst);
113 }
114
115 fn current_interval(&self) -> Duration {
117 Duration::from_millis(self.current_interval_ms.load(Ordering::Relaxed))
118 }
119
120 async fn active_node_count(&self) -> forge_core::Result<u32> {
122 let row: (i64,) =
123 sqlx::query_as("SELECT COUNT(*) FROM forge_nodes WHERE status = 'active'")
124 .fetch_one(&self.pool)
125 .await
126 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
127
128 Ok(row.0 as u32)
129 }
130
131 async fn adjust_interval(&self) {
133 let count = match self.active_node_count().await {
134 Ok(c) => c,
135 Err(e) => {
136 tracing::debug!(error = %e, "Failed to query active node count");
137 return;
138 }
139 };
140
141 let last = self.last_active_count.load(Ordering::Relaxed);
142 if last != 0 && count == last {
143 let stable = self.stable_count.fetch_add(1, Ordering::Relaxed) + 1;
144 if stable >= 3 {
145 let base_ms = self.config.interval.as_millis() as u64;
146 let max_ms = self.config.max_interval.as_millis() as u64;
147 let cur = self.current_interval_ms.load(Ordering::Relaxed);
148 let next = (cur * 2).min(max_ms).max(base_ms);
149 self.current_interval_ms.store(next, Ordering::Relaxed);
150 }
151 } else {
152 self.stable_count.store(0, Ordering::Relaxed);
154 let base_ms = self.config.interval.as_millis() as u64;
155 self.current_interval_ms.store(base_ms, Ordering::Relaxed);
156 }
157 self.last_active_count.store(count, Ordering::Relaxed);
158 }
159
160 async fn send_heartbeat(&self) -> forge_core::Result<()> {
162 sqlx::query(
163 r#"
164 UPDATE forge_nodes
165 SET last_heartbeat = NOW()
166 WHERE id = $1
167 "#,
168 )
169 .bind(self.node_id.as_uuid())
170 .execute(&self.pool)
171 .await
172 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
173
174 Ok(())
175 }
176
177 async fn mark_dead_nodes(&self) -> forge_core::Result<u64> {
180 let adaptive = self.current_interval().as_secs_f64() * 3.0;
181 let configured = self.config.dead_threshold.as_secs_f64();
182 let threshold_secs = adaptive.max(configured);
183
184 let result = sqlx::query(
185 r#"
186 UPDATE forge_nodes
187 SET status = 'dead'
188 WHERE status = 'active'
189 AND last_heartbeat < NOW() - make_interval(secs => $1)
190 "#,
191 )
192 .bind(threshold_secs)
193 .execute(&self.pool)
194 .await
195 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
196
197 let count = result.rows_affected();
198 if count > 0 {
199 tracing::warn!(count, "Marked nodes as dead");
200 }
201
202 Ok(count)
203 }
204
205 pub async fn update_load(
207 &self,
208 current_connections: u32,
209 current_jobs: u32,
210 cpu_usage: f32,
211 memory_usage: f32,
212 ) -> forge_core::Result<()> {
213 sqlx::query(
214 r#"
215 UPDATE forge_nodes
216 SET current_connections = $2,
217 current_jobs = $3,
218 cpu_usage = $4,
219 memory_usage = $5,
220 last_heartbeat = NOW()
221 WHERE id = $1
222 "#,
223 )
224 .bind(self.node_id.as_uuid())
225 .bind(current_connections as i32)
226 .bind(current_jobs as i32)
227 .bind(cpu_usage)
228 .bind(memory_usage)
229 .execute(&self.pool)
230 .await
231 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
232
233 Ok(())
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every field of `HeartbeatConfig::default()`. The exhaustive
    /// destructuring forces this test to be revisited if a field is added.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            interval,
            dead_threshold,
            mark_dead_nodes,
            max_interval,
        } = HeartbeatConfig::default();

        assert_eq!(interval, Duration::from_secs(5));
        assert_eq!(dead_threshold, Duration::from_secs(15));
        assert!(mark_dead_nodes);
        assert_eq!(max_interval, Duration::from_secs(60));
    }
}