use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

use tracing::{debug, info, warn};

use crate::catalog::ClusterCatalog;
use crate::loop_metrics::LoopMetrics;
use crate::rpc_codec::{
    JoinNodeInfo, PingRequest, PongResponse, RaftRpc, TopologyAck, TopologyUpdate,
};
use crate::topology::{ClusterTopology, NodeState};
use crate::transport::NexarTransport;

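/// How often the health monitor pings each peer for liveness.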
pub const DEFAULT_PING_INTERVAL: Duration = Duration::from_secs(5);

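/// Consecutive ping failures tolerated before a peer is marked `Draining`.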
pub const DEFAULT_FAILURE_THRESHOLD: u32 = 3;

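/// Tunables for the periodic peer health checks.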
#[derive(Debug, Clone)]
pub struct HealthConfig {
    pub ping_interval: Duration,
    pub failure_threshold: u32,
}

impl Default for HealthConfig {
    fn default() -> Self {
        Self {
            ping_interval: DEFAULT_PING_INTERVAL,
            failure_threshold: DEFAULT_FAILURE_THRESHOLD,
        }
    }
}

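/// Periodically pings every known peer, tracking consecutive failures per
/// node. A peer that misses `failure_threshold` pings in a row is demoted to
/// `Draining`; a peer that answers again is promoted back to `Active`. Any
/// resulting topology change is persisted to the catalog and broadcast to the
/// rest of the cluster.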
pub struct HealthMonitor {
    node_id: u64,
    transport: Arc<NexarTransport>,
    topology: Arc<RwLock<ClusterTopology>>,
    catalog: Arc<ClusterCatalog>,
    config: HealthConfig,
    ping_failures: Mutex<HashMap<u64, u32>>,
    loop_metrics: Arc<LoopMetrics>,
}

impl HealthMonitor {
    pub fn new(
        node_id: u64,
        transport: Arc<NexarTransport>,
        topology: Arc<RwLock<ClusterTopology>>,
        catalog: Arc<ClusterCatalog>,
        config: HealthConfig,
    ) -> Self {
        Self {
            node_id,
            transport,
            topology,
            catalog,
            config,
            ping_failures: Mutex::new(HashMap::new()),
            loop_metrics: LoopMetrics::new("health_loop"),
        }
    }

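    /// Handle to this loop's up/latency/error metrics.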
    pub fn loop_metrics(&self) -> Arc<LoopMetrics> {
        Arc::clone(&self.loop_metrics)
    }

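    /// Snapshot of consecutive ping-failure counts, keyed by peer node id.
    /// A peer drops out of the map as soon as it answers a ping again.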
    pub fn suspect_peers(&self) -> HashMap<u64, u32> {
        self.ping_failures
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .clone()
    }

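    /// Drives the ping loop until `shutdown` flips to `true`.
    ///
    /// A minimal wiring sketch (assumes a running tokio runtime; the names
    /// are illustrative, not part of this API):
    ///
    /// ```ignore
    /// let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
    /// let monitor = Arc::new(monitor);
    /// let handle = tokio::spawn({
    ///     let monitor = Arc::clone(&monitor);
    ///     async move { monitor.run(shutdown_rx).await }
    /// });
    /// // ... later, to stop the loop:
    /// let _ = shutdown_tx.send(true);
    /// handle.await.unwrap();
    /// ```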
    pub async fn run(&self, mut shutdown: tokio::sync::watch::Receiver<bool>) {
        let mut interval = tokio::time::interval(self.config.ping_interval);
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);

        info!(node_id = self.node_id, "health monitor started");
        self.loop_metrics.set_up(true);

        loop {
            tokio::select! {
                _ = interval.tick() => {
                    let started = Instant::now();
                    self.ping_all_peers().await;
                    self.loop_metrics.observe(started.elapsed());
                }
                _ = shutdown.changed() => {
                    if *shutdown.borrow() {
                        debug!("health monitor shutting down");
                        break;
                    }
                }
            }
        }
        self.loop_metrics.set_up(false);
    }

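    /// Pings every live peer concurrently and reconciles topology state from
    /// the responses.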
    async fn ping_all_peers(&self) {
        let peers = self.collect_peers();
        if peers.is_empty() {
            return;
        }

        let topo_version = {
            let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
            topo.version()
        };

        let mut handles = Vec::new();
        for (peer_id, addr) in peers {
            let transport = self.transport.clone();
            let ping = RaftRpc::Ping(PingRequest {
                sender_id: self.node_id,
                topology_version: topo_version,
            });
            handles.push(tokio::spawn(async move {
                let result = transport.send_rpc(peer_id, ping).await;
                (peer_id, addr, result)
            }));
        }

        let mut topology_changed = false;
        for handle in handles {
            let (peer_id, _addr, result) = match handle.await {
                Ok(r) => r,
                Err(_) => continue,
            };

            match result {
                Ok(RaftRpc::Pong(pong)) => {
                    topology_changed |= self.handle_pong(peer_id, &pong);
                }
                Ok(_) => {
                    // An unexpected response type counts as a failed ping.
                    topology_changed |= self.record_ping_failure(peer_id);
                }
                Err(_) => {
                    topology_changed |= self.record_ping_failure(peer_id);
                }
            }
        }

        if topology_changed {
            self.persist_and_broadcast().await;
        }
    }

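    /// Processes a pong: clears the peer's failure counter, pushes our
    /// topology to peers reporting a stale version, and reactivates the peer
    /// if it was previously suspect. Returns `true` if the topology changed.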
    fn handle_pong(&self, peer_id: u64, pong: &PongResponse) -> bool {
        // A successful pong clears any accumulated failure count.
        {
            let mut failures = self.ping_failures.lock().unwrap_or_else(|p| p.into_inner());
            failures.remove(&peer_id);
        }

        // Push a fresh topology snapshot to peers that report a stale version.
        let our_version = {
            let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
            topo.version()
        };
        if pong.topology_version < our_version {
            debug!(
                peer_id,
                peer_version = pong.topology_version,
                our_version,
                "peer has stale topology, pushing update"
            );
            let transport = self.transport.clone();
            let topology = self.topology.clone();
            let self_id = self.node_id;
            tokio::spawn(async move {
                broadcast_topology_to_peer(self_id, peer_id, &topology, &transport).await;
            });
        }

        // A responsive peer that is neither active nor decommissioned has recovered.
        let mut topo = self.topology.write().unwrap_or_else(|p| p.into_inner());
        if let Some(node) = topo.get_node(peer_id)
            && node.state != NodeState::Active
            && node.state != NodeState::Decommissioned
        {
            info!(peer_id, "peer recovered, marking active");
            topo.set_state(peer_id, NodeState::Active);
            return true;
        }
        false
    }

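    /// Bumps the failure counter for `peer_id`; once the count reaches the
    /// threshold an `Active` peer is demoted to `Draining`. Returns `true`
    /// on demotion.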
    fn record_ping_failure(&self, peer_id: u64) -> bool {
        self.loop_metrics.record_error("ping");
        let count = {
            let mut failures = self.ping_failures.lock().unwrap_or_else(|p| p.into_inner());
            let count = failures.entry(peer_id).or_insert(0);
            *count += 1;
            *count
        };

        if count >= self.config.failure_threshold {
            let mut topo = self.topology.write().unwrap_or_else(|p| p.into_inner());
            if let Some(node) = topo.get_node(peer_id)
                && node.state == NodeState::Active
            {
                warn!(
                    peer_id,
                    failures = count,
                    "peer unreachable, marking draining"
                );
                topo.set_state(peer_id, NodeState::Draining);
                return true;
            }
        }
        false
    }

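    /// Saves the current topology to the catalog and fans it out to peers.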
    async fn persist_and_broadcast(&self) {
        let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
        if let Err(e) = self.catalog.save_topology(&topo) {
            warn!(error = %e, "failed to persist topology update");
        }
        drop(topo);
        broadcast_topology(self.node_id, &self.topology, &self.transport);
    }

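    /// All peers worth pinging: every addressable node except ourselves and
    /// decommissioned nodes.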
    fn collect_peers(&self) -> Vec<(u64, SocketAddr)> {
        let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
        topo.all_nodes()
            .filter(|n| n.node_id != self.node_id && n.state != NodeState::Decommissioned)
            .filter_map(|n| n.socket_addr().map(|addr| (n.node_id, addr)))
            .collect()
    }
}

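/// Sends the current topology snapshot to every active peer, fire-and-forget:
/// each send runs in its own task and failures are only logged, since a peer
/// that misses an update converges via the version check on the next ping.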
pub fn broadcast_topology(
    self_node_id: u64,
    topology: &RwLock<ClusterTopology>,
    transport: &Arc<NexarTransport>,
) {
    let (update, active_peers) = {
        let topo = topology.read().unwrap_or_else(|p| p.into_inner());
        let update = RaftRpc::TopologyUpdate(TopologyUpdate {
            version: topo.version(),
            nodes: topo
                .all_nodes()
                .map(|n| JoinNodeInfo {
                    node_id: n.node_id,
                    addr: n.addr.clone(),
                    state: n.state.as_u8(),
                    raft_groups: n.raft_groups.clone(),
                    wire_version: n.wire_version,
                    spiffe_id: n.spiffe_id.clone(),
                    spki_pin: n.spki_pin.map(|arr| arr.to_vec()),
                })
                .collect(),
        });
        let peers: Vec<u64> = topo
            .active_nodes()
            .iter()
            .map(|n| n.node_id)
            .filter(|&id| id != self_node_id)
            .collect();
        (update, peers)
    };

    for peer_id in active_peers {
        let transport = transport.clone();
        let msg = update.clone();
        tokio::spawn(async move {
            if let Err(e) = transport.send_rpc(peer_id, msg).await {
                debug!(peer_id, error = %e, "topology broadcast failed");
            }
        });
    }
}

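/// Pushes the current topology snapshot to a single stale peer.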
async fn broadcast_topology_to_peer(
    _self_node_id: u64,
    peer_id: u64,
    topology: &RwLock<ClusterTopology>,
    transport: &NexarTransport,
) {
    let update = {
        let topo = topology.read().unwrap_or_else(|p| p.into_inner());
        RaftRpc::TopologyUpdate(TopologyUpdate {
            version: topo.version(),
            nodes: topo
                .all_nodes()
                .map(|n| JoinNodeInfo {
                    node_id: n.node_id,
                    addr: n.addr.clone(),
                    state: n.state.as_u8(),
                    raft_groups: n.raft_groups.clone(),
                    wire_version: n.wire_version,
                    spiffe_id: n.spiffe_id.clone(),
                    spki_pin: n.spki_pin.map(|arr| arr.to_vec()),
                })
                .collect(),
        })
    };
    if let Err(e) = transport.send_rpc(peer_id, update).await {
        debug!(peer_id, error = %e, "targeted topology push failed");
    }
}

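/// Builds the pong reply for an incoming ping, echoing our topology version.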
pub fn handle_ping(node_id: u64, topology_version: u64, _req: &PingRequest) -> RaftRpc {
    RaftRpc::Pong(PongResponse {
        responder_id: node_id,
        topology_version,
    })
}

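/// Applies an incoming topology update if its version is newer than ours,
/// rebuilding the local topology from the wire representation. Returns
/// whether the topology changed, plus the ack to send back.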
pub fn handle_topology_update(
    node_id: u64,
    topology: &RwLock<ClusterTopology>,
    update: &TopologyUpdate,
) -> (bool, RaftRpc) {
    let mut topo = topology.write().unwrap_or_else(|p| p.into_inner());

    let updated = if update.version > topo.version() {
        let mut new_topo = ClusterTopology::new();
        // Rebuild the topology wholesale rather than merging node by node.
        for node in &update.nodes {
            let state = crate::topology::NodeState::from_u8(node.state)
                .unwrap_or(crate::topology::NodeState::Active);
            let spki_pin: Option<[u8; 32]> = node.spki_pin.as_deref().and_then(|b| {
                if b.len() == 32 {
                    let mut arr = [0u8; 32];
                    arr.copy_from_slice(b);
                    Some(arr)
                } else {
                    None
                }
            });
            let mut info = crate::topology::NodeInfo::new(
                node.node_id,
                node.addr.parse().unwrap_or_else(|_| {
                    "0.0.0.0:0"
                        .parse()
                        .expect("invariant: \"0.0.0.0:0\" is a valid SocketAddr literal")
                }),
                state,
            )
            .with_wire_version(node.wire_version)
            .with_spiffe_id(node.spiffe_id.clone())
            .with_spki_pin(spki_pin);
            info.raft_groups = node.raft_groups.clone();
            new_topo.add_node(info);
        }
        *topo = new_topo;
        true
    } else {
        false
    };

    let ack = RaftRpc::TopologyAck(TopologyAck {
        responder_id: node_id,
        accepted_version: topo.version(),
    });

    (updated, ack)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::topology::NodeInfo;

    #[test]
    fn handle_ping_returns_pong() {
        let req = PingRequest {
            sender_id: 2,
            topology_version: 5,
        };
        let resp = handle_ping(1, 7, &req);
        match resp {
            RaftRpc::Pong(pong) => {
                assert_eq!(pong.responder_id, 1);
                assert_eq!(pong.topology_version, 7);
            }
            other => panic!("expected Pong, got {other:?}"),
        }
    }

    #[test]
    fn topology_update_adopts_newer_version() {
        let topo = RwLock::new(ClusterTopology::new());
        let update = TopologyUpdate {
            version: 3,
            nodes: vec![
                JoinNodeInfo {
                    node_id: 1,
                    addr: "10.0.0.1:9400".into(),
                    state: 1,
                    raft_groups: vec![],
                    wire_version: crate::topology::CLUSTER_WIRE_FORMAT_VERSION,
                    spiffe_id: None,
                    spki_pin: None,
                },
                JoinNodeInfo {
                    node_id: 2,
                    addr: "10.0.0.2:9400".into(),
                    state: 1,
                    raft_groups: vec![],
                    wire_version: crate::topology::CLUSTER_WIRE_FORMAT_VERSION,
                    spiffe_id: None,
                    spki_pin: None,
                },
            ],
        };

        let (updated, ack) = handle_topology_update(1, &topo, &update);
        assert!(updated);

        let t = topo.read().unwrap();
        assert_eq!(t.node_count(), 2);

        match ack {
            RaftRpc::TopologyAck(a) => assert_eq!(a.accepted_version, t.version()),
            other => panic!("expected TopologyAck, got {other:?}"),
        }
    }

    #[test]
    fn topology_update_ignores_stale_version() {
        let topo = RwLock::new(ClusterTopology::new());
        {
            let mut t = topo.write().unwrap();
            t.add_node(NodeInfo::new(
                1,
                "10.0.0.1:9400".parse().unwrap(),
                NodeState::Active,
            ));
        }

        let update = TopologyUpdate {
            version: 0,
            nodes: vec![],
        };

        let (updated, _) = handle_topology_update(1, &topo, &update);
        assert!(!updated);

        let t = topo.read().unwrap();
        assert_eq!(t.node_count(), 1);
    }

    #[tokio::test]
    async fn failure_tracking_marks_draining() {
        let topo = Arc::new(RwLock::new(ClusterTopology::new()));
        {
            let mut t = topo.write().unwrap();
            t.add_node(NodeInfo::new(
                1,
                "10.0.0.1:9400".parse().unwrap(),
                NodeState::Active,
            ));
            t.add_node(NodeInfo::new(
                2,
                "10.0.0.2:9400".parse().unwrap(),
                NodeState::Active,
            ));
        }

        let transport = Arc::new(
            NexarTransport::new(
                1,
                "127.0.0.1:0".parse().unwrap(),
                crate::transport::credentials::TransportCredentials::Insecure,
            )
            .unwrap(),
        );
        let dir = tempfile::tempdir().unwrap();
        let catalog = Arc::new(ClusterCatalog::open(&dir.path().join("cluster.redb")).unwrap());

        let monitor = HealthMonitor::new(
            1,
            transport,
            topo.clone(),
            catalog,
            HealthConfig {
                ping_interval: Duration::from_secs(5),
                failure_threshold: 3,
            },
        );

        // Only the third consecutive failure crosses the threshold.
        assert!(!monitor.record_ping_failure(2));
        assert!(!monitor.record_ping_failure(2));
        assert!(monitor.record_ping_failure(2));

        let t = topo.read().unwrap();
        assert_eq!(t.get_node(2).unwrap().state, NodeState::Draining);
    }

    #[tokio::test]
    async fn pong_recovers_node() {
        let topo = Arc::new(RwLock::new(ClusterTopology::new()));
        {
            let mut t = topo.write().unwrap();
            t.add_node(NodeInfo::new(
                1,
                "10.0.0.1:9400".parse().unwrap(),
                NodeState::Active,
            ));
            t.add_node(NodeInfo::new(
                2,
                "10.0.0.2:9400".parse().unwrap(),
                NodeState::Draining,
            ));
        }

        let transport = Arc::new(
            NexarTransport::new(
                1,
                "127.0.0.1:0".parse().unwrap(),
                crate::transport::credentials::TransportCredentials::Insecure,
            )
            .unwrap(),
        );
        let dir = tempfile::tempdir().unwrap();
        let catalog = Arc::new(ClusterCatalog::open(&dir.path().join("cluster.redb")).unwrap());

        let monitor =
            HealthMonitor::new(1, transport, topo.clone(), catalog, HealthConfig::default());

        let pong = PongResponse {
            responder_id: 2,
            topology_version: 1,
        };
        let changed = monitor.handle_pong(2, &pong);
        assert!(changed);

        let t = topo.read().unwrap();
        assert_eq!(t.get_node(2).unwrap().state, NodeState::Active);
    }
}