use crate::raft::{NodeId, NodeState};
use dashmap::DashMap;
use ipfrs_core::{Error, Result};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::RwLock;

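/// Configuration for cluster membership tracking and failure detection.
///
/// A minimal construction sketch (the override shown is illustrative, not a
/// recommendation):
///
/// ```ignore
/// let config = ClusterConfig {
///     heartbeat_interval_ms: 500,
///     ..ClusterConfig::default()
/// };
/// ```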
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// Interval between heartbeat checks, in milliseconds.
    pub heartbeat_interval_ms: u64,
    /// Number of consecutive missed heartbeats before a node is considered down.
    pub failure_threshold: u32,
    /// Minimum number of nodes required for the cluster to operate.
    pub min_cluster_size: usize,
    /// Maximum number of nodes the cluster will accept.
    pub max_cluster_size: usize,
}

impl Default for ClusterConfig {
    fn default() -> Self {
        Self {
            heartbeat_interval_ms: 1000,
            failure_threshold: 3,
            min_cluster_size: 3,
            max_cluster_size: 100,
        }
    }
}

/// Snapshot of a cluster member as tracked by the coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique identifier of the node.
    pub node_id: NodeId,
    /// Network address the node is reachable at.
    pub address: SocketAddr,
    /// Current Raft role of the node.
    pub state: NodeState,
    /// Time the last heartbeat was received.
    pub last_heartbeat: SystemTime,
    /// Current health assessment of the node.
    pub health: NodeHealth,
    /// Number of consecutive heartbeat intervals missed.
    pub missed_heartbeats: u32,
}

/// Health of a node as judged by missed heartbeats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeHealth {
    /// Heartbeats are arriving on schedule.
    Healthy,
    /// At least one heartbeat interval has been missed.
    Degraded,
    /// At least half the failure threshold has been missed; the node may be failing.
    Suspected,
    /// The failure threshold has been reached; the node is considered down.
    Down,
}

/// Shared, optionally-set callback invoked with the failed leader's id on failover.
type FailoverCallback = Arc<RwLock<Option<Box<dyn Fn(NodeId) + Send + Sync>>>>;

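/// Tracks cluster membership and node health, records the current leader, and
/// invokes a failover callback when the leader is detected as down.
///
/// A minimal usage sketch, assuming a tokio runtime is available (error
/// handling elided with `unwrap` as in the tests below):
///
/// ```ignore
/// let coordinator = ClusterCoordinator::new(ClusterConfig::default());
///
/// // Register members and start the background health checks.
/// coordinator
///     .add_node(NodeId(1), "127.0.0.1:8001".parse().unwrap())
///     .await
///     .unwrap();
/// coordinator.start_health_monitoring().await;
///
/// // Each node reports liveness by calling `heartbeat` periodically.
/// coordinator.heartbeat(NodeId(1)).await.unwrap();
/// ```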
pub struct ClusterCoordinator {
    /// Coordinator configuration.
    config: ClusterConfig,
    /// Known cluster members, keyed by node id.
    nodes: Arc<DashMap<NodeId, NodeInfo>>,
    /// Current leader, if one is known.
    leader: Arc<RwLock<Option<NodeId>>>,
    /// Set to true to stop the health monitoring task.
    shutdown: Arc<RwLock<bool>>,
    /// Optional callback invoked when the leader fails.
    failover_callback: FailoverCallback,
}

impl ClusterCoordinator {
    /// Creates a new coordinator with the given configuration and no known nodes.
    pub fn new(config: ClusterConfig) -> Self {
        Self {
            config,
            nodes: Arc::new(DashMap::new()),
            leader: Arc::new(RwLock::new(None)),
            shutdown: Arc::new(RwLock::new(false)),
            failover_callback: Arc::new(RwLock::new(None)),
        }
    }

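    /// Registers a callback that receives the failed leader's `NodeId` whenever a
    /// failover is triggered, either by the health monitor or via
    /// `trigger_failover`.
    ///
    /// A minimal sketch; the callback body here is purely illustrative:
    ///
    /// ```ignore
    /// coordinator
    ///     .set_failover_callback(|failed_leader| {
    ///         tracing::warn!("leader {} failed, requesting re-election", failed_leader.0);
    ///     })
    ///     .await;
    /// ```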
    pub async fn set_failover_callback<F>(&self, callback: F)
    where
        F: Fn(NodeId) + Send + Sync + 'static,
    {
        *self.failover_callback.write().await = Some(Box::new(callback));
    }

    /// Adds a node to the cluster, rejecting it if the configured maximum
    /// cluster size has been reached.
    #[allow(clippy::unused_async)]
    pub async fn add_node(&self, node_id: NodeId, address: SocketAddr) -> Result<()> {
        if self.nodes.len() >= self.config.max_cluster_size {
            return Err(Error::Network(format!(
                "Cluster size limit reached: {}",
                self.config.max_cluster_size
            )));
        }

        let node_info = NodeInfo {
            node_id,
            address,
            state: NodeState::Follower,
            last_heartbeat: SystemTime::now(),
            health: NodeHealth::Healthy,
            missed_heartbeats: 0,
        };

        self.nodes.insert(node_id, node_info);
        tracing::info!("Added node {} to cluster at {}", node_id.0, address);

        Ok(())
    }

    /// Removes a node from the cluster, clearing the leader if the removed node
    /// was leading.
    pub async fn remove_node(&self, node_id: NodeId) -> Result<()> {
        self.nodes.remove(&node_id);
        tracing::info!("Removed node {} from cluster", node_id.0);

        let mut leader = self.leader.write().await;
        if *leader == Some(node_id) {
            *leader = None;
        }

        Ok(())
    }

    /// Updates a node's Raft state, recording it as the leader when it reports
    /// `NodeState::Leader`.
    pub async fn update_node_state(&self, node_id: NodeId, state: NodeState) -> Result<()> {
        if let Some(mut node) = self.nodes.get_mut(&node_id) {
            node.state = state;

            if state == NodeState::Leader {
                *self.leader.write().await = Some(node_id);
                tracing::info!("Node {} is now the leader", node_id.0);
            }

            Ok(())
        } else {
            Err(Error::Network(format!("Node {} not found", node_id.0)))
        }
    }

    /// Records a heartbeat from a node, resetting its missed-heartbeat counter
    /// and marking it healthy.
    #[allow(clippy::unused_async)]
    pub async fn heartbeat(&self, node_id: NodeId) -> Result<()> {
        if let Some(mut node) = self.nodes.get_mut(&node_id) {
            node.last_heartbeat = SystemTime::now();
            node.missed_heartbeats = 0;
            node.health = NodeHealth::Healthy;
            Ok(())
        } else {
            Err(Error::Network(format!("Node {} not found", node_id.0)))
        }
    }

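    /// Starts the background task that watches heartbeats, escalates node health,
    /// and triggers failover when the current leader is marked `Down`.
    ///
    /// A minimal lifecycle sketch:
    ///
    /// ```ignore
    /// coordinator.start_health_monitoring().await;
    /// // ... run the node ...
    /// coordinator.shutdown().await; // the monitoring loop exits on its next iteration
    /// ```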
    #[allow(clippy::unused_async)]
    pub async fn start_health_monitoring(&self) {
        let nodes = self.nodes.clone();
        let config = self.config.clone();
        let shutdown = self.shutdown.clone();
        let leader = self.leader.clone();
        let failover_callback = self.failover_callback.clone();

        tokio::spawn(async move {
            let interval = Duration::from_millis(config.heartbeat_interval_ms);

            loop {
                if *shutdown.read().await {
                    break;
                }

                let mut leader_down = false;
                let mut failed_leader_id = None;

                for mut entry in nodes.iter_mut() {
                    let node = entry.value_mut();

                    if let Ok(elapsed) = node.last_heartbeat.elapsed() {
                        let missed =
                            (elapsed.as_millis() / config.heartbeat_interval_ms as u128) as u32;

                        if missed > node.missed_heartbeats {
                            node.missed_heartbeats = missed;

                            // Escalate health based on how many intervals have been missed.
                            let old_health = node.health;
                            node.health = if missed >= config.failure_threshold {
                                NodeHealth::Down
                            } else if missed >= config.failure_threshold / 2 {
                                NodeHealth::Suspected
                            } else if missed > 0 {
                                NodeHealth::Degraded
                            } else {
                                NodeHealth::Healthy
                            };

                            // Only react on the transition into Down, not on every pass.
                            if node.health == NodeHealth::Down && old_health != NodeHealth::Down {
                                tracing::warn!(
                                    "Node {} is down (missed {} heartbeats)",
                                    node.node_id.0,
                                    missed
                                );

                                let current_leader = leader.read().await;
                                if *current_leader == Some(node.node_id) {
                                    leader_down = true;
                                    failed_leader_id = Some(node.node_id);
                                }
                            }
                        }
                    }
                }

                if leader_down {
                    if let Some(leader_id) = failed_leader_id {
                        tracing::warn!("Leader {} has failed, triggering failover", leader_id.0);

                        *leader.write().await = None;

                        if let Some(callback) = failover_callback.read().await.as_ref() {
                            callback(leader_id);
                        }
                    }
                }

                tokio::time::sleep(interval).await;
            }
        });
    }

    /// Manually triggers a failover: clears the current leader and invokes the
    /// failover callback, if one is registered.
    pub async fn trigger_failover(&self) -> Result<()> {
        let current_leader = *self.leader.read().await;

        if let Some(leader_id) = current_leader {
            tracing::info!("Manually triggering failover for leader {}", leader_id.0);

            *self.leader.write().await = None;

            if let Some(callback) = self.failover_callback.read().await.as_ref() {
                callback(leader_id);
            }

            Ok(())
        } else {
            Err(Error::Network("No leader to failover from".into()))
        }
    }

    /// Returns true when no leader is known and the cluster still has quorum,
    /// i.e. an election could be started.
    pub async fn should_trigger_reelection(&self) -> bool {
        let current_leader = *self.leader.read().await;

        current_leader.is_none() && self.has_quorum()
    }

    /// Returns the nodes eligible to stand for election (those that are
    /// `Healthy` or `Degraded`).
    pub fn get_election_candidates(&self) -> Vec<NodeId> {
        self.nodes
            .iter()
            .filter(|entry| {
                let node = entry.value();
                matches!(node.health, NodeHealth::Healthy | NodeHealth::Degraded)
            })
            .map(|entry| *entry.key())
            .collect()
    }

    /// Returns the total number of known nodes.
    pub fn cluster_size(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the number of nodes currently marked `Healthy`.
    pub fn healthy_nodes(&self) -> usize {
        self.nodes
            .iter()
            .filter(|entry| entry.value().health == NodeHealth::Healthy)
            .count()
    }

    /// Returns true when the healthy node count reaches a majority of the
    /// configured minimum cluster size (`min_cluster_size / 2 + 1`).
    pub fn has_quorum(&self) -> bool {
        let healthy = self.healthy_nodes();
        healthy >= (self.config.min_cluster_size / 2 + 1)
    }

    /// Returns the current leader, if one is known.
    pub async fn get_leader(&self) -> Option<NodeId> {
        *self.leader.read().await
    }

    /// Returns the ids of all known nodes.
    pub fn get_node_ids(&self) -> Vec<NodeId> {
        self.nodes.iter().map(|entry| *entry.key()).collect()
    }

    /// Returns a snapshot of the given node's info, if it is known.
    pub fn get_node_info(&self, node_id: NodeId) -> Option<NodeInfo> {
        self.nodes.get(&node_id).map(|entry| entry.value().clone())
    }

    /// Aggregates per-node health into cluster-wide statistics.
    pub fn get_cluster_stats(&self) -> ClusterStats {
        let total = self.nodes.len();
        let mut healthy = 0;
        let mut degraded = 0;
        let mut suspected = 0;
        let mut down = 0;

        for entry in self.nodes.iter() {
            match entry.value().health {
                NodeHealth::Healthy => healthy += 1,
                NodeHealth::Degraded => degraded += 1,
                NodeHealth::Suspected => suspected += 1,
                NodeHealth::Down => down += 1,
            }
        }

        ClusterStats {
            total_nodes: total,
            healthy_nodes: healthy,
            degraded_nodes: degraded,
            suspected_nodes: suspected,
            down_nodes: down,
            has_quorum: self.has_quorum(),
        }
    }

    /// Signals the health monitoring task to stop.
    pub async fn shutdown(&self) {
        *self.shutdown.write().await = true;
    }
}

/// Cluster-wide health statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStats {
    /// Total number of known nodes.
    pub total_nodes: usize,
    /// Nodes currently marked `Healthy`.
    pub healthy_nodes: usize,
    /// Nodes currently marked `Degraded`.
    pub degraded_nodes: usize,
    /// Nodes currently marked `Suspected`.
    pub suspected_nodes: usize,
    /// Nodes currently marked `Down`.
    pub down_nodes: usize,
    /// Whether the cluster currently has quorum.
    pub has_quorum: bool,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_cluster_add_remove_node() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config);

        let node_id = NodeId(1);
        let addr: SocketAddr = "127.0.0.1:8000".parse().unwrap();

        coordinator.add_node(node_id, addr).await.unwrap();
        assert_eq!(coordinator.cluster_size(), 1);

        coordinator.remove_node(node_id).await.unwrap();
        assert_eq!(coordinator.cluster_size(), 0);
    }

    #[tokio::test]
    async fn test_cluster_size_limit() {
        let config = ClusterConfig {
            max_cluster_size: 2,
            ..Default::default()
        };
        let coordinator = ClusterCoordinator::new(config);

        coordinator
            .add_node(NodeId(1), "127.0.0.1:8001".parse().unwrap())
            .await
            .unwrap();

        coordinator
            .add_node(NodeId(2), "127.0.0.1:8002".parse().unwrap())
            .await
            .unwrap();

        let result = coordinator
            .add_node(NodeId(3), "127.0.0.1:8003".parse().unwrap())
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_heartbeat() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config);

        let node_id = NodeId(1);
        coordinator
            .add_node(node_id, "127.0.0.1:8000".parse().unwrap())
            .await
            .unwrap();

        coordinator.heartbeat(node_id).await.unwrap();

        let info = coordinator.get_node_info(node_id).unwrap();
        assert_eq!(info.health, NodeHealth::Healthy);
        assert_eq!(info.missed_heartbeats, 0);
    }

    #[tokio::test]
    async fn test_leader_tracking() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config);

        let node_id = NodeId(1);
        coordinator
            .add_node(node_id, "127.0.0.1:8000".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(coordinator.get_leader().await, None);

        coordinator
            .update_node_state(node_id, NodeState::Leader)
            .await
            .unwrap();

        assert_eq!(coordinator.get_leader().await, Some(node_id));
    }

    #[tokio::test]
    async fn test_quorum() {
        let config = ClusterConfig {
            min_cluster_size: 3,
            ..Default::default()
        };
        let coordinator = ClusterCoordinator::new(config);

        coordinator
            .add_node(NodeId(1), "127.0.0.1:8001".parse().unwrap())
            .await
            .unwrap();

        coordinator
            .add_node(NodeId(2), "127.0.0.1:8002".parse().unwrap())
            .await
            .unwrap();

        coordinator
            .add_node(NodeId(3), "127.0.0.1:8003".parse().unwrap())
            .await
            .unwrap();

        assert!(coordinator.has_quorum());

        let stats = coordinator.get_cluster_stats();
        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.healthy_nodes, 3);
        assert!(stats.has_quorum);
    }

    #[tokio::test]
    async fn test_cluster_stats() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config);

        coordinator
            .add_node(NodeId(1), "127.0.0.1:8001".parse().unwrap())
            .await
            .unwrap();

        coordinator
            .add_node(NodeId(2), "127.0.0.1:8002".parse().unwrap())
            .await
            .unwrap();

        let stats = coordinator.get_cluster_stats();
        assert_eq!(stats.total_nodes, 2);
        assert_eq!(stats.healthy_nodes, 2);
    }

    #[tokio::test]
    async fn test_manual_failover() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config);

        let node_id = NodeId(1);
        coordinator
            .add_node(node_id, "127.0.0.1:8000".parse().unwrap())
            .await
            .unwrap();

        coordinator
            .update_node_state(node_id, NodeState::Leader)
            .await
            .unwrap();

        assert_eq!(coordinator.get_leader().await, Some(node_id));

        coordinator.trigger_failover().await.unwrap();

        assert_eq!(coordinator.get_leader().await, None);
    }

    #[tokio::test]
    async fn test_failover_callback() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config);

        let node_id = NodeId(1);
        coordinator
            .add_node(node_id, "127.0.0.1:8000".parse().unwrap())
            .await
            .unwrap();

        let callback_triggered = Arc::new(AtomicBool::new(false));
        let callback_triggered_clone = callback_triggered.clone();

        coordinator
            .set_failover_callback(move |_| {
                callback_triggered_clone.store(true, Ordering::SeqCst);
            })
            .await;

        coordinator
            .update_node_state(node_id, NodeState::Leader)
            .await
            .unwrap();

        coordinator.trigger_failover().await.unwrap();

        assert!(callback_triggered.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_reelection_trigger_check() {
        let config = ClusterConfig {
            min_cluster_size: 3,
            ..Default::default()
        };
        let coordinator = ClusterCoordinator::new(config);

        coordinator
            .add_node(NodeId(1), "127.0.0.1:8001".parse().unwrap())
            .await
            .unwrap();

        coordinator
            .add_node(NodeId(2), "127.0.0.1:8002".parse().unwrap())
            .await
            .unwrap();

        coordinator
            .add_node(NodeId(3), "127.0.0.1:8003".parse().unwrap())
            .await
            .unwrap();

        assert!(coordinator.should_trigger_reelection().await);

        coordinator
            .update_node_state(NodeId(1), NodeState::Leader)
            .await
            .unwrap();

        assert!(!coordinator.should_trigger_reelection().await);
    }

    #[tokio::test]
    async fn test_election_candidates() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config);

        coordinator
            .add_node(NodeId(1), "127.0.0.1:8001".parse().unwrap())
            .await
            .unwrap();

        coordinator
            .add_node(NodeId(2), "127.0.0.1:8002".parse().unwrap())
            .await
            .unwrap();

        let candidates = coordinator.get_election_candidates();
        assert_eq!(candidates.len(), 2);
        assert!(candidates.contains(&NodeId(1)));
        assert!(candidates.contains(&NodeId(2)));
    }
}