1use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, Instant};
15use tokio::sync::mpsc;
16use tracing::info;
17
18use super::gossip::{GossipConfig, GossipEvent, GossipMember, GossipProtocol, MemberState};
19
/// Tunable parameters governing health checking, quorum, and failover.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfig {
    /// Interval between periodic health checks, in milliseconds.
    pub health_check_interval_ms: u64,
    /// Time after which an unresponsive health check is considered failed,
    /// in milliseconds.
    pub health_timeout_ms: u64,
    /// Consecutive failed checks before a node is marked `Unhealthy`.
    pub failure_threshold: u32,
    /// Consecutive successful checks before a Suspect/Unhealthy node is
    /// restored to `Healthy`.
    pub recovery_threshold: u32,
    /// When true, an unhealthy primary automatically has replicas promoted
    /// to cover its shards.
    pub auto_failover: bool,
    /// Minimum number of read-capable nodes required for quorum.
    pub min_quorum: u32,
}
36
37impl Default for ClusterConfig {
38 fn default() -> Self {
39 Self {
40 health_check_interval_ms: 5000,
41 health_timeout_ms: 10000,
42 failure_threshold: 3,
43 recovery_threshold: 2,
44 auto_failover: true,
45 min_quorum: 1,
46 }
47 }
48}
49
/// The role a node plays in the cluster topology.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeRole {
    /// Shard owner; the only role allowed to serve writes.
    Primary,
    /// Read-serving copy; eligible for promotion during failover.
    Replica,
    /// Coordination role (not otherwise exercised in this module).
    Coordinator,
    /// Non-serving observer (not otherwise exercised in this module).
    Observer,
}
62
/// Liveness/availability state of a node as tracked by the coordinator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Passing health checks; serves reads (and writes when Primary).
    Healthy,
    /// Failed at least one check from Healthy, still below the failure threshold.
    Suspect,
    /// Hit `failure_threshold` consecutive failures, or reported dead via gossip.
    Unhealthy,
    /// Being drained; still counted as read-capable (see `can_serve_reads`).
    Draining,
    /// Departed the cluster (gossip `Left` state maps here).
    Offline,
    /// Newly registered; promoted to Healthy on the first successful check.
    Joining,
}
79
/// Rolling health-check bookkeeping and resource gauges for a single node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeHealth {
    /// Current liveness classification.
    pub status: NodeStatus,
    /// Unix-epoch milliseconds of the last successful health check.
    pub last_healthy_ms: u64,
    /// Consecutive failed health checks (reset to 0 on success).
    pub failure_count: u32,
    /// Consecutive successful health checks (reset to 0 on failure).
    pub success_count: u32,
    /// Average health-check response time in milliseconds.
    /// NOTE(review): not updated in this module — presumably fed by callers.
    pub avg_response_ms: f64,
    /// CPU utilization percentage; supplied by callers.
    pub cpu_percent: f32,
    /// Memory utilization percentage; supplied by callers.
    pub memory_percent: f32,
    /// Active client connection count; supplied by callers.
    pub active_connections: u32,
    /// Replication lag in milliseconds, if known.
    pub replication_lag_ms: Option<u64>,
}
103
104impl Default for NodeHealth {
105 fn default() -> Self {
106 Self {
107 status: NodeStatus::Joining,
108 last_healthy_ms: 0,
109 failure_count: 0,
110 success_count: 0,
111 avg_response_ms: 0.0,
112 cpu_percent: 0.0,
113 memory_percent: 0.0,
114 active_connections: 0,
115 replication_lag_ms: None,
116 }
117 }
118}
119
/// Everything the coordinator tracks about one cluster node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique node identifier (key into `ClusterState::nodes`).
    pub node_id: String,
    /// Network address of the node, as a string.
    pub address: String,
    /// Topology role (Primary/Replica/...).
    pub role: NodeRole,
    /// Shards hosted by this node.
    pub shard_ids: Vec<u32>,
    /// Current health bookkeeping.
    pub health: NodeHealth,
    /// Free-form key/value metadata (merged in from gossip updates).
    pub metadata: HashMap<String, String>,
    /// Per-node version counter.
    /// NOTE(review): never incremented in this module — confirm whether
    /// callers maintain it.
    pub generation: u64,
}
138
139impl NodeInfo {
140 pub fn new(node_id: String, address: String, role: NodeRole) -> Self {
142 Self {
143 node_id,
144 address,
145 role,
146 shard_ids: Vec::new(),
147 health: NodeHealth::default(),
148 metadata: HashMap::new(),
149 generation: 0,
150 }
151 }
152
153 pub fn can_serve_reads(&self) -> bool {
155 matches!(
156 self.health.status,
157 NodeStatus::Healthy | NodeStatus::Draining
158 )
159 }
160
161 pub fn can_serve_writes(&self) -> bool {
163 self.health.status == NodeStatus::Healthy && self.role == NodeRole::Primary
164 }
165}
166
/// Aggregate, cloneable snapshot of cluster membership and health.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClusterState {
    /// Monotonically increasing version, bumped on membership/status changes.
    pub generation: u64,
    /// All known nodes, keyed by node id.
    pub nodes: HashMap<String, NodeInfo>,
    /// Current leader's node id, if one has been elected.
    pub leader_id: Option<String>,
    /// Overall cluster health; currently mirrors `has_quorum`.
    pub is_healthy: bool,
    /// True when at least `min_quorum` nodes can serve reads.
    pub has_quorum: bool,
    /// Number of nodes currently able to serve reads.
    pub healthy_node_count: u32,
    /// Total registered nodes.
    pub total_node_count: u32,
    /// Unix-epoch milliseconds of the last health recomputation.
    pub last_update_ms: u64,
}
187
/// Tracks cluster membership, node health, quorum, and failover; optionally
/// fed by a gossip protocol.
pub struct ClusterCoordinator {
    /// Health/quorum/failover tuning.
    config: ClusterConfig,
    /// Shared mutable cluster state.
    state: Arc<RwLock<ClusterState>>,
    /// This process's node id (currently unused beyond construction).
    _local_node_id: String,
    /// Source for the monotonically increasing state generation.
    generation: AtomicU64,
    /// Construction time, backing `uptime_secs`.
    start_time: Instant,
    /// Optional gossip transport for membership events.
    gossip: Option<Arc<GossipProtocol>>,
    /// Receiving half of the gossip event channel, drained by
    /// `process_gossip_events`.
    gossip_event_rx: Option<mpsc::Receiver<GossipEvent>>,
}
205
206impl ClusterCoordinator {
207 pub fn new(config: ClusterConfig, local_node_id: String) -> Self {
209 Self {
210 config,
211 state: Arc::new(RwLock::new(ClusterState::default())),
212 _local_node_id: local_node_id,
213 generation: AtomicU64::new(0),
214 start_time: Instant::now(),
215 gossip: None,
216 gossip_event_rx: None,
217 }
218 }
219
220 pub fn with_gossip(
230 config: ClusterConfig,
231 local_node_id: String,
232 local_address: std::net::SocketAddr,
233 api_address: String,
234 role: NodeRole,
235 gossip_config: GossipConfig,
236 ) -> Self {
237 let (event_tx, event_rx) = mpsc::channel(1000);
238 let local_member =
239 GossipMember::new(local_node_id.clone(), local_address, api_address, role);
240 let gossip = GossipProtocol::new(gossip_config, local_member, event_tx);
241
242 Self {
243 config,
244 state: Arc::new(RwLock::new(ClusterState::default())),
245 _local_node_id: local_node_id,
246 generation: AtomicU64::new(0),
247 start_time: Instant::now(),
248 gossip: Some(Arc::new(gossip)),
249 gossip_event_rx: Some(event_rx),
250 }
251 }
252
253 pub async fn start_gossip(&self) -> Result<(), String> {
255 if let Some(gossip) = &self.gossip {
256 gossip.start().await.map_err(|e| e.to_string())
257 } else {
258 Err("Gossip protocol not configured".to_string())
259 }
260 }
261
262 pub fn stop_gossip(&self) {
264 if let Some(gossip) = &self.gossip {
265 gossip.stop();
266 }
267 }
268
269 pub async fn leave_cluster(&self) -> Result<(), String> {
271 if let Some(gossip) = &self.gossip {
272 gossip.leave().await.map_err(|e| e.to_string())
273 } else {
274 Err("Gossip protocol not configured".to_string())
275 }
276 }
277
278 pub async fn process_gossip_events(&mut self) -> Result<usize, String> {
280 let events: Vec<GossipEvent> = {
282 let rx = match &mut self.gossip_event_rx {
283 Some(rx) => rx,
284 None => return Ok(0),
285 };
286
287 let mut events = Vec::new();
288 loop {
289 match rx.try_recv() {
290 Ok(event) => events.push(event),
291 Err(mpsc::error::TryRecvError::Empty) => break,
292 Err(mpsc::error::TryRecvError::Disconnected) => {
293 return Err("Gossip event channel disconnected".to_string());
294 }
295 }
296 }
297 events
298 };
299
300 let count = events.len();
302 for event in events {
303 self.handle_gossip_event(event)?;
304 }
305
306 Ok(count)
307 }
308
309 fn handle_gossip_event(&self, event: GossipEvent) -> Result<(), String> {
311 match event {
312 GossipEvent::NodeJoined(member) => {
313 self.handle_member_joined(member)?;
314 }
315 GossipEvent::NodeLeft(node_id) => {
316 self.handle_member_left(&node_id)?;
317 }
318 GossipEvent::NodeFailed(node_id) => {
319 self.handle_member_failed(&node_id)?;
320 }
321 GossipEvent::NodeRecovered(node_id) => {
322 self.handle_member_recovered(&node_id)?;
323 }
324 GossipEvent::NodeUpdated(member) => {
325 self.handle_member_state_updated(member)?;
326 }
327 }
328 Ok(())
329 }
330
    // Registers a node discovered via a gossip join.
    //
    // NOTE(review): the role is hard-coded to `Replica`, even though the
    // local gossip member is constructed with a role (see `with_gossip`).
    // If `GossipMember` carries the sender's role, it is ignored here —
    // confirm that is intentional (e.g. roles assigned out-of-band) before
    // relying on the `NodeRole` of gossip-joined peers.
    fn handle_member_joined(&self, member: GossipMember) -> Result<(), String> {
        let node = NodeInfo::new(
            member.node_id.clone(),
            member.address.to_string(),
            NodeRole::Replica, );
        self.register_node(node)
    }
340
341 fn handle_member_left(&self, node_id: &str) -> Result<(), String> {
343 self.deregister_node(node_id)?;
344 Ok(())
345 }
346
347 fn handle_member_failed(&self, node_id: &str) -> Result<(), String> {
349 let mut state = self.state.write().map_err(|e| e.to_string())?;
350
351 if let Some(node) = state.nodes.get_mut(node_id) {
352 node.health.status = NodeStatus::Unhealthy;
353 node.health.failure_count = self.config.failure_threshold;
354
355 if node.role == NodeRole::Primary && self.config.auto_failover {
357 let node_id_clone = node_id.to_string();
358 self.trigger_failover(&mut state, &node_id_clone);
359 }
360 }
361
362 self.update_cluster_health(&mut state);
363 Ok(())
364 }
365
366 fn handle_member_recovered(&self, node_id: &str) -> Result<(), String> {
368 let mut state = self.state.write().map_err(|e| e.to_string())?;
369
370 if let Some(node) = state.nodes.get_mut(node_id) {
371 node.health.status = NodeStatus::Healthy;
372 node.health.failure_count = 0;
373 node.health.success_count = self.config.recovery_threshold;
374 node.health.last_healthy_ms = current_time_ms();
375 }
376
377 self.update_cluster_health(&mut state);
378 Ok(())
379 }
380
381 fn handle_member_state_updated(&self, member: GossipMember) -> Result<(), String> {
383 let mut state = self.state.write().map_err(|e| e.to_string())?;
384
385 let member_addr_str = member.address.to_string();
386 if let Some(node) = state.nodes.get_mut(&member.node_id) {
387 if node.address != member_addr_str {
389 node.address = member_addr_str;
390 }
391
392 node.health.status = match member.state {
394 MemberState::Alive => NodeStatus::Healthy,
395 MemberState::Suspect => NodeStatus::Suspect,
396 MemberState::Dead => NodeStatus::Unhealthy,
397 MemberState::Left => NodeStatus::Offline,
398 };
399
400 for (key, value) in member.metadata {
402 node.metadata.insert(key, value);
403 }
404 }
405
406 self.update_cluster_health(&mut state);
407 Ok(())
408 }
409
410 pub fn gossip(&self) -> Option<&Arc<GossipProtocol>> {
412 self.gossip.as_ref()
413 }
414
415 pub async fn get_gossip_members(&self) -> Vec<GossipMember> {
417 if let Some(gossip) = &self.gossip {
418 gossip.get_members().await
419 } else {
420 Vec::new()
421 }
422 }
423
424 pub async fn broadcast_metadata(&self, key: String, value: String) -> Result<(), String> {
426 if let Some(gossip) = &self.gossip {
427 gossip.update_metadata(key, value).await;
428 Ok(())
429 } else {
430 Err("Gossip protocol not configured".to_string())
431 }
432 }
433
434 pub fn register_node(&self, node: NodeInfo) -> Result<(), String> {
436 let mut state = self.state.write().map_err(|e| e.to_string())?;
437
438 let gen = self.generation.fetch_add(1, Ordering::SeqCst) + 1;
439 state.generation = gen;
440
441 state.nodes.insert(node.node_id.clone(), node);
442 state.total_node_count = state.nodes.len() as u32;
443
444 self.update_cluster_health(&mut state);
445
446 Ok(())
447 }
448
449 pub fn deregister_node(&self, node_id: &str) -> Result<Option<NodeInfo>, String> {
451 let mut state = self.state.write().map_err(|e| e.to_string())?;
452
453 let gen = self.generation.fetch_add(1, Ordering::SeqCst) + 1;
454 state.generation = gen;
455
456 let removed = state.nodes.remove(node_id);
457 state.total_node_count = state.nodes.len() as u32;
458
459 if state.leader_id.as_deref() == Some(node_id) {
461 state.leader_id = None;
462 self.elect_leader(&mut state);
463 }
464
465 self.update_cluster_health(&mut state);
466
467 Ok(removed)
468 }
469
    /// Replaces a node's full health record.
    ///
    /// When the status changed, bumps the cluster generation; a
    /// Healthy -> Unhealthy transition on a primary triggers failover when
    /// `auto_failover` is enabled. Unknown node ids are ignored.
    ///
    /// NOTE(review): a Suspect -> Unhealthy transition on a primary does NOT
    /// fail over here (only Healthy -> Unhealthy does), whereas
    /// `record_health_failure` fails over on any first crossing into
    /// Unhealthy — confirm the asymmetry is intended.
    pub fn update_node_health(&self, node_id: &str, health: NodeHealth) -> Result<(), String> {
        let mut state = self.state.write().map_err(|e| e.to_string())?;

        // Capture the transition while `node` mutably borrows `state`, then
        // let that borrow end before `trigger_failover` needs `&mut state`.
        let transition_info = if let Some(node) = state.nodes.get_mut(node_id) {
            let old_status = node.health.status;
            let new_status = health.status;
            let role = node.role;
            node.health = health;

            if old_status != new_status {
                Some((old_status, new_status, role))
            } else {
                None
            }
        } else {
            None
        };

        if let Some((old_status, new_status, role)) = transition_info {
            // Status changed: advance the generation counter.
            let gen = self.generation.fetch_add(1, Ordering::SeqCst) + 1;
            state.generation = gen;

            if old_status == NodeStatus::Healthy
                && new_status == NodeStatus::Unhealthy
                && role == NodeRole::Primary
                && self.config.auto_failover
            {
                self.trigger_failover(&mut state, node_id);
            }
        }

        self.update_cluster_health(&mut state);

        Ok(())
    }
509
510 pub fn record_health_success(&self, node_id: &str) -> Result<(), String> {
512 let mut state = self.state.write().map_err(|e| e.to_string())?;
513
514 if let Some(node) = state.nodes.get_mut(node_id) {
515 node.health.success_count += 1;
516 node.health.failure_count = 0;
517 node.health.last_healthy_ms = current_time_ms();
518
519 if (matches!(
521 node.health.status,
522 NodeStatus::Suspect | NodeStatus::Unhealthy
523 ) && node.health.success_count >= self.config.recovery_threshold)
524 || node.health.status == NodeStatus::Joining
525 {
526 info!(node_id = %node_id, old_status = ?node.health.status, "Node recovered to Healthy");
527 node.health.status = NodeStatus::Healthy;
528 self.update_cluster_health(&mut state);
529 }
530 }
531
532 Ok(())
533 }
534
535 pub fn record_health_failure(&self, node_id: &str) -> Result<(), String> {
537 let mut state = self.state.write().map_err(|e| e.to_string())?;
538
539 if let Some(node) = state.nodes.get_mut(node_id) {
540 node.health.failure_count += 1;
541 node.health.success_count = 0;
542
543 if node.health.failure_count >= self.config.failure_threshold {
545 if node.health.status != NodeStatus::Unhealthy {
546 node.health.status = NodeStatus::Unhealthy;
547
548 if node.role == NodeRole::Primary && self.config.auto_failover {
550 let node_id_clone = node_id.to_string();
551 self.trigger_failover(&mut state, &node_id_clone);
552 }
553 }
554 } else if node.health.status == NodeStatus::Healthy {
555 node.health.status = NodeStatus::Suspect;
556 }
557
558 self.update_cluster_health(&mut state);
559 }
560
561 Ok(())
562 }
563
564 pub fn get_state(&self) -> ClusterState {
566 self.state
567 .read()
568 .expect("cluster state lock poisoned in get_state")
569 .clone()
570 }
571
572 pub fn get_healthy_nodes_for_shard(&self, shard_id: u32) -> Vec<NodeInfo> {
574 let state = self
575 .state
576 .read()
577 .expect("cluster state lock poisoned in get_healthy_nodes_for_shard");
578
579 state
580 .nodes
581 .values()
582 .filter(|n| n.shard_ids.contains(&shard_id) && n.can_serve_reads())
583 .cloned()
584 .collect()
585 }
586
587 pub fn get_primary_for_shard(&self, shard_id: u32) -> Option<NodeInfo> {
589 let state = self
590 .state
591 .read()
592 .expect("cluster state lock poisoned in get_primary_for_shard");
593
594 state
595 .nodes
596 .values()
597 .find(|n| {
598 n.shard_ids.contains(&shard_id)
599 && n.role == NodeRole::Primary
600 && n.can_serve_writes()
601 })
602 .cloned()
603 }
604
605 pub fn get_healthy_nodes(&self) -> Vec<NodeInfo> {
607 let state = self
608 .state
609 .read()
610 .expect("cluster state lock poisoned in get_healthy_nodes");
611
612 state
613 .nodes
614 .values()
615 .filter(|n| n.can_serve_reads())
616 .cloned()
617 .collect()
618 }
619
620 pub fn has_quorum(&self) -> bool {
622 self.state
623 .read()
624 .expect("cluster state lock poisoned in has_quorum")
625 .has_quorum
626 }
627
628 pub fn uptime_secs(&self) -> u64 {
630 self.start_time.elapsed().as_secs()
631 }
632
633 fn update_cluster_health(&self, state: &mut ClusterState) {
636 state.healthy_node_count =
637 state.nodes.values().filter(|n| n.can_serve_reads()).count() as u32;
638
639 state.has_quorum = state.healthy_node_count >= self.config.min_quorum;
640 state.is_healthy = state.has_quorum;
641 state.last_update_ms = current_time_ms();
642 }
643
644 fn elect_leader(&self, state: &mut ClusterState) {
645 let leader = state
647 .nodes
648 .values()
649 .filter(|n| n.can_serve_reads() && n.role == NodeRole::Primary)
650 .min_by(|a, b| a.node_id.cmp(&b.node_id));
651
652 state.leader_id = leader.map(|n| n.node_id.clone());
653 }
654
    /// Promotes replicas to cover the shards of a failed primary, and
    /// re-elects the leader if the failed node was leading.
    ///
    /// For each shard held by the failed node, the first read-capable replica
    /// that also hosts the shard is promoted to `Primary`.
    ///
    /// NOTE(review): the failed node keeps its `Primary` role; if it later
    /// recovers (see `handle_member_recovered`) the cluster can end up with
    /// two primaries for the same shard — confirm a demotion/reconciliation
    /// path exists elsewhere.
    fn trigger_failover(&self, state: &mut ClusterState, failed_node_id: &str) {
        // Copy the shard list out so `state.nodes` can be mutably borrowed below.
        let shards: Vec<u32> = state
            .nodes
            .get(failed_node_id)
            .map(|n| n.shard_ids.clone())
            .unwrap_or_default();

        for shard_id in shards {
            // HashMap iteration order is arbitrary, so "first match" is an
            // arbitrary eligible replica, not a deterministic choice.
            let replica = state.nodes.values_mut().find(|n| {
                n.node_id != failed_node_id
                    && n.shard_ids.contains(&shard_id)
                    && n.role == NodeRole::Replica
                    && n.can_serve_reads()
            });

            if let Some(new_primary) = replica {
                new_primary.role = NodeRole::Primary;
            }
        }

        if state.leader_id.as_deref() == Some(failed_node_id) {
            self.elect_leader(state);
        }
    }
683}
684
/// Milliseconds since the Unix epoch, clamped to 0 if the system clock reads
/// earlier than the epoch.
fn current_time_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
692
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering a node makes it visible in the cluster state snapshot.
    #[test]
    fn test_node_registration() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config, "node-1".to_string());

        let node = NodeInfo::new(
            "node-1".to_string(),
            "localhost:8080".to_string(),
            NodeRole::Primary,
        );

        coordinator.register_node(node).unwrap();

        let state = coordinator.get_state();
        assert_eq!(state.nodes.len(), 1);
        assert!(state.nodes.contains_key("node-1"));
    }

    /// Healthy -> Suspect on the first failure, -> Unhealthy at the
    /// configured threshold, then successes accumulate toward recovery.
    #[test]
    fn test_health_transitions() {
        let config = ClusterConfig {
            failure_threshold: 2,
            recovery_threshold: 2,
            ..Default::default()
        };
        let coordinator = ClusterCoordinator::new(config, "node-1".to_string());

        let mut node = NodeInfo::new(
            "node-1".to_string(),
            "localhost:8080".to_string(),
            NodeRole::Primary,
        );
        node.health.status = NodeStatus::Healthy;
        coordinator.register_node(node).unwrap();

        // First failure: below failure_threshold, so Healthy degrades to Suspect.
        coordinator.record_health_failure("node-1").unwrap();
        let state = coordinator.get_state();
        assert_eq!(state.nodes["node-1"].health.status, NodeStatus::Suspect);

        // Second failure reaches failure_threshold (2): Unhealthy.
        coordinator.record_health_failure("node-1").unwrap();
        let state = coordinator.get_state();
        assert_eq!(state.nodes["node-1"].health.status, NodeStatus::Unhealthy);

        // Two successes meet recovery_threshold (2), which should restore the
        // node to Healthy. NOTE(review): the recovery is not asserted here.
        coordinator.record_health_success("node-1").unwrap();
        coordinator.record_health_success("node-1").unwrap();
    }

    /// Quorum requires at least `min_quorum` read-capable nodes.
    #[test]
    fn test_quorum() {
        let config = ClusterConfig {
            min_quorum: 2,
            ..Default::default()
        };
        let coordinator = ClusterCoordinator::new(config, "node-1".to_string());

        let mut node1 = NodeInfo::new(
            "node-1".to_string(),
            "localhost:8080".to_string(),
            NodeRole::Primary,
        );
        node1.health.status = NodeStatus::Healthy;
        coordinator.register_node(node1).unwrap();

        // One healthy node < min_quorum of 2.
        assert!(!coordinator.has_quorum());

        let mut node2 = NodeInfo::new(
            "node-2".to_string(),
            "localhost:8081".to_string(),
            NodeRole::Replica,
        );
        node2.health.status = NodeStatus::Healthy;
        coordinator.register_node(node2).unwrap();

        // Second healthy node reaches quorum.
        assert!(coordinator.has_quorum());
    }

    /// Shard lookup returns only healthy nodes that actually host the shard.
    #[test]
    fn test_get_nodes_for_shard() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config, "node-1".to_string());

        let mut node1 = NodeInfo::new(
            "node-1".to_string(),
            "localhost:8080".to_string(),
            NodeRole::Primary,
        );
        node1.shard_ids = vec![0, 1];
        node1.health.status = NodeStatus::Healthy;
        coordinator.register_node(node1).unwrap();

        let mut node2 = NodeInfo::new(
            "node-2".to_string(),
            "localhost:8081".to_string(),
            NodeRole::Replica,
        );
        node2.shard_ids = vec![0, 2];
        node2.health.status = NodeStatus::Healthy;
        coordinator.register_node(node2).unwrap();

        // Shard 0 is hosted by both nodes; shards 1 and 2 by one each.
        let shard0_nodes = coordinator.get_healthy_nodes_for_shard(0);
        assert_eq!(shard0_nodes.len(), 2);

        let shard1_nodes = coordinator.get_healthy_nodes_for_shard(1);
        assert_eq!(shard1_nodes.len(), 1);

        let shard2_nodes = coordinator.get_healthy_nodes_for_shard(2);
        assert_eq!(shard2_nodes.len(), 1);
    }

    /// Deregistering returns the removed record and empties the node map.
    #[test]
    fn test_deregister_node() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config, "node-1".to_string());

        let node = NodeInfo::new(
            "node-1".to_string(),
            "localhost:8080".to_string(),
            NodeRole::Primary,
        );
        coordinator.register_node(node).unwrap();

        let removed = coordinator.deregister_node("node-1").unwrap();
        assert!(removed.is_some());

        let state = coordinator.get_state();
        assert!(state.nodes.is_empty());
    }
}