use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tracing::info;
use super::gossip::{GossipConfig, GossipEvent, GossipMember, GossipProtocol, MemberState};
/// Tunables governing health checking, failure detection, and failover.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfig {
    /// Interval between health probes, in milliseconds.
    /// NOTE(review): not consumed in this module — presumably read by the
    /// external health-check driver; confirm against callers.
    pub health_check_interval_ms: u64,
    /// Probe timeout in milliseconds (also not consumed in this module).
    pub health_timeout_ms: u64,
    /// Consecutive probe failures before a node is marked `Unhealthy`.
    pub failure_threshold: u32,
    /// Consecutive probe successes before a Suspect/Unhealthy node is
    /// promoted back to `Healthy`.
    pub recovery_threshold: u32,
    /// When true, an unhealthy primary triggers automatic replica promotion.
    pub auto_failover: bool,
    /// Minimum number of read-capable nodes required for cluster quorum.
    pub min_quorum: u32,
}
impl Default for ClusterConfig {
fn default() -> Self {
Self {
health_check_interval_ms: 5000,
health_timeout_ms: 10000,
failure_threshold: 3,
recovery_threshold: 2,
auto_failover: true,
min_quorum: 1,
}
}
}
/// Role a node plays in the cluster topology.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeRole {
    /// Accepts writes for its shards and is eligible for leader election.
    Primary,
    /// Read-serving copy; may be promoted to `Primary` during failover.
    Replica,
    /// NOTE(review): not referenced elsewhere in this module — presumably a
    /// control-plane-only role; confirm semantics against callers.
    Coordinator,
    /// NOTE(review): not referenced elsewhere in this module — presumably a
    /// non-serving watcher; confirm semantics against callers.
    Observer,
}
/// Health state machine for a cluster node.
///
/// Transitions driven by this module:
/// `Joining` -> `Healthy` on the first successful probe;
/// `Healthy` -> `Suspect` on a single failed probe;
/// any state -> `Unhealthy` once `failure_threshold` consecutive failures
/// accumulate; `Suspect`/`Unhealthy` -> `Healthy` after `recovery_threshold`
/// consecutive successes. Gossip `Left` members map to `Offline`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Fully operational; serves reads (and writes, if primary).
    Healthy,
    /// One or more recent failures, below the failure threshold.
    Suspect,
    /// Failure threshold reached; excluded from serving.
    Unhealthy,
    /// Being drained; still counted as read-capable by `can_serve_reads`.
    Draining,
    /// Left the cluster (gossip `MemberState::Left`).
    Offline,
    /// Newly registered; becomes `Healthy` on the first probe success.
    Joining,
}
/// Per-node health snapshot: status, probe counters, and resource metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeHealth {
    /// Current position in the health state machine.
    pub status: NodeStatus,
    /// Wall-clock time (ms since Unix epoch) of the last successful probe.
    pub last_healthy_ms: u64,
    /// Consecutive probe failures; reset to 0 on any success.
    pub failure_count: u32,
    /// Consecutive probe successes; reset to 0 on any failure.
    pub success_count: u32,
    // The metrics below are carried through `update_node_health` but never
    // computed in this module — presumably populated by the probing layer.
    pub avg_response_ms: f64,
    pub cpu_percent: f32,
    pub memory_percent: f32,
    pub active_connections: u32,
    /// Replication lag in ms; `None` when unknown or not applicable.
    pub replication_lag_ms: Option<u64>,
}
impl Default for NodeHealth {
fn default() -> Self {
Self {
status: NodeStatus::Joining,
last_healthy_ms: 0,
failure_count: 0,
success_count: 0,
avg_response_ms: 0.0,
cpu_percent: 0.0,
memory_percent: 0.0,
active_connections: 0,
replication_lag_ms: None,
}
}
}
/// Full description of one cluster member as tracked by the coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique, stable identifier for the node.
    pub node_id: String,
    /// Network address (stringified; gossip members use `SocketAddr`).
    pub address: String,
    /// Topology role; may change at runtime via failover promotion.
    pub role: NodeRole,
    /// Shards this node participates in.
    pub shard_ids: Vec<u32>,
    /// Latest health snapshot.
    pub health: NodeHealth,
    /// Free-form key/value metadata, merged in from gossip updates.
    pub metadata: HashMap<String, String>,
    /// Per-node generation counter (starts at 0; not advanced in this module).
    pub generation: u64,
}
impl NodeInfo {
pub fn new(node_id: String, address: String, role: NodeRole) -> Self {
Self {
node_id,
address,
role,
shard_ids: Vec::new(),
health: NodeHealth::default(),
metadata: HashMap::new(),
generation: 0,
}
}
pub fn can_serve_reads(&self) -> bool {
matches!(
self.health.status,
NodeStatus::Healthy | NodeStatus::Draining
)
}
pub fn can_serve_writes(&self) -> bool {
self.health.status == NodeStatus::Healthy && self.role == NodeRole::Primary
}
}
/// Point-in-time view of the whole cluster, guarded by the coordinator's
/// `RwLock` and handed out by value from `get_state`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClusterState {
    /// Monotonic generation, bumped on membership and status transitions.
    pub generation: u64,
    /// All known nodes, keyed by `node_id`.
    pub nodes: HashMap<String, NodeInfo>,
    /// Elected leader, if any (lowest node_id among healthy primaries).
    pub leader_id: Option<String>,
    /// Currently mirrors `has_quorum` (see `update_cluster_health`).
    pub is_healthy: bool,
    /// True when `healthy_node_count >= ClusterConfig::min_quorum`.
    pub has_quorum: bool,
    /// Count of read-capable nodes (Healthy or Draining).
    pub healthy_node_count: u32,
    /// Total registered nodes, regardless of health.
    pub total_node_count: u32,
    /// Wall-clock ms of the last `update_cluster_health` run.
    pub last_update_ms: u64,
}
/// Central authority for cluster membership, health, leadership, and
/// failover. Optionally driven by gossip events; otherwise membership is
/// managed explicitly via `register_node` / `deregister_node`.
pub struct ClusterCoordinator {
    config: ClusterConfig,
    /// Shared, lock-guarded cluster view.
    state: Arc<RwLock<ClusterState>>,
    /// Identity of this node (currently unused; kept for future use).
    _local_node_id: String,
    /// Source for `ClusterState::generation` bumps (lock-free counter).
    generation: AtomicU64,
    /// Construction time, for `uptime_secs`.
    start_time: Instant,
    /// Gossip transport; `None` when built via `new`.
    gossip: Option<Arc<GossipProtocol>>,
    /// Receiving side of the gossip event channel; drained by
    /// `process_gossip_events`.
    gossip_event_rx: Option<mpsc::Receiver<GossipEvent>>,
}
impl ClusterCoordinator {
    /// Creates a coordinator with no gossip transport; membership must be
    /// driven manually via `register_node` / `deregister_node`.
    pub fn new(config: ClusterConfig, local_node_id: String) -> Self {
        Self {
            config,
            state: Arc::new(RwLock::new(ClusterState::default())),
            _local_node_id: local_node_id,
            generation: AtomicU64::new(0),
            start_time: Instant::now(),
            gossip: None,
            gossip_event_rx: None,
        }
    }
    /// Creates a coordinator wired to the gossip protocol from
    /// `super::gossip`. Membership events flow through a bounded channel
    /// (capacity 1000) and are consumed by `process_gossip_events`.
    pub fn with_gossip(
        config: ClusterConfig,
        local_node_id: String,
        local_address: std::net::SocketAddr,
        api_address: String,
        role: NodeRole,
        gossip_config: GossipConfig,
    ) -> Self {
        let (event_tx, event_rx) = mpsc::channel(1000);
        let local_member =
            GossipMember::new(local_node_id.clone(), local_address, api_address, role);
        let gossip = GossipProtocol::new(gossip_config, local_member, event_tx);
        Self {
            config,
            state: Arc::new(RwLock::new(ClusterState::default())),
            _local_node_id: local_node_id,
            generation: AtomicU64::new(0),
            start_time: Instant::now(),
            gossip: Some(Arc::new(gossip)),
            gossip_event_rx: Some(event_rx),
        }
    }
    /// Starts the gossip protocol; errors if the coordinator was built
    /// without gossip (`new`).
    pub async fn start_gossip(&self) -> Result<(), String> {
        if let Some(gossip) = &self.gossip {
            gossip.start().await.map_err(|e| e.to_string())
        } else {
            Err("Gossip protocol not configured".to_string())
        }
    }
    /// Stops the gossip protocol. No-op when gossip is not configured.
    pub fn stop_gossip(&self) {
        if let Some(gossip) = &self.gossip {
            gossip.stop();
        }
    }
    /// Announces a graceful departure from the cluster via gossip.
    pub async fn leave_cluster(&self) -> Result<(), String> {
        if let Some(gossip) = &self.gossip {
            gossip.leave().await.map_err(|e| e.to_string())
        } else {
            Err("Gossip protocol not configured".to_string())
        }
    }
    /// Drains all pending gossip events and applies them to cluster state.
    /// Returns the number of events processed (0 when gossip is absent).
    ///
    /// Events are collected into a `Vec` inside a block so the `&mut`
    /// borrow of `gossip_event_rx` ends before the `&self` handler calls —
    /// `handle_gossip_event` needs to take the state lock.
    pub async fn process_gossip_events(&mut self) -> Result<usize, String> {
        let events: Vec<GossipEvent> = {
            let rx = match &mut self.gossip_event_rx {
                Some(rx) => rx,
                None => return Ok(0),
            };
            let mut events = Vec::new();
            loop {
                match rx.try_recv() {
                    Ok(event) => events.push(event),
                    // Empty just means we've caught up; Disconnected means
                    // the gossip side dropped its sender — a hard error.
                    Err(mpsc::error::TryRecvError::Empty) => break,
                    Err(mpsc::error::TryRecvError::Disconnected) => {
                        return Err("Gossip event channel disconnected".to_string());
                    }
                }
            }
            events
        };
        let count = events.len();
        for event in events {
            self.handle_gossip_event(event)?;
        }
        Ok(count)
    }
    /// Routes a single gossip event to its specific handler.
    fn handle_gossip_event(&self, event: GossipEvent) -> Result<(), String> {
        match event {
            GossipEvent::NodeJoined(member) => {
                self.handle_member_joined(member)?;
            }
            GossipEvent::NodeLeft(node_id) => {
                self.handle_member_left(&node_id)?;
            }
            GossipEvent::NodeFailed(node_id) => {
                self.handle_member_failed(&node_id)?;
            }
            GossipEvent::NodeRecovered(node_id) => {
                self.handle_member_recovered(&node_id)?;
            }
            GossipEvent::NodeUpdated(member) => {
                self.handle_member_state_updated(member)?;
            }
        }
        Ok(())
    }
    /// Registers a newly-gossiped member as a cluster node.
    ///
    /// NOTE(review): the role is hard-coded to `Replica` even though
    /// `GossipMember::new` is given a role in `with_gossip` — confirm
    /// whether the member's gossiped role should be used here instead.
    /// Also note `register_node` overwrites any existing entry, resetting
    /// health/shards for a re-joining node.
    fn handle_member_joined(&self, member: GossipMember) -> Result<(), String> {
        let node = NodeInfo::new(
            member.node_id.clone(),
            member.address.to_string(),
            NodeRole::Replica, );
        self.register_node(node)
    }
    /// Removes a departed member from cluster state.
    fn handle_member_left(&self, node_id: &str) -> Result<(), String> {
        self.deregister_node(node_id)?;
        Ok(())
    }
    /// Marks a gossip-failed member `Unhealthy` and, for primaries with
    /// auto-failover enabled, promotes replicas for its shards.
    ///
    /// `failure_count` is set to the threshold so the recovery path sees a
    /// consistent "tripped" counter. The borrow of `node` ends before
    /// `trigger_failover` re-borrows `state` mutably (NLL).
    fn handle_member_failed(&self, node_id: &str) -> Result<(), String> {
        let mut state = self.state.write().map_err(|e| e.to_string())?;
        if let Some(node) = state.nodes.get_mut(node_id) {
            node.health.status = NodeStatus::Unhealthy;
            node.health.failure_count = self.config.failure_threshold;
            if node.role == NodeRole::Primary && self.config.auto_failover {
                let node_id_clone = node_id.to_string();
                self.trigger_failover(&mut state, &node_id_clone);
            }
        }
        self.update_cluster_health(&mut state);
        Ok(())
    }
    /// Marks a gossip-recovered member `Healthy`, pre-loading
    /// `success_count` to the recovery threshold so the counters reflect a
    /// completed recovery.
    fn handle_member_recovered(&self, node_id: &str) -> Result<(), String> {
        let mut state = self.state.write().map_err(|e| e.to_string())?;
        if let Some(node) = state.nodes.get_mut(node_id) {
            node.health.status = NodeStatus::Healthy;
            node.health.failure_count = 0;
            node.health.success_count = self.config.recovery_threshold;
            node.health.last_healthy_ms = current_time_ms();
        }
        self.update_cluster_health(&mut state);
        Ok(())
    }
    /// Applies a gossip state refresh: address change, member-state to
    /// node-status mapping, and metadata merge (existing keys overwritten,
    /// unmentioned keys kept). Unknown members are ignored.
    fn handle_member_state_updated(&self, member: GossipMember) -> Result<(), String> {
        let mut state = self.state.write().map_err(|e| e.to_string())?;
        let member_addr_str = member.address.to_string();
        if let Some(node) = state.nodes.get_mut(&member.node_id) {
            if node.address != member_addr_str {
                node.address = member_addr_str;
            }
            node.health.status = match member.state {
                MemberState::Alive => NodeStatus::Healthy,
                MemberState::Suspect => NodeStatus::Suspect,
                MemberState::Dead => NodeStatus::Unhealthy,
                MemberState::Left => NodeStatus::Offline,
            };
            for (key, value) in member.metadata {
                node.metadata.insert(key, value);
            }
        }
        self.update_cluster_health(&mut state);
        Ok(())
    }
    /// Access to the underlying gossip protocol, if configured.
    pub fn gossip(&self) -> Option<&Arc<GossipProtocol>> {
        self.gossip.as_ref()
    }
    /// Current gossip membership list; empty when gossip is not configured.
    pub async fn get_gossip_members(&self) -> Vec<GossipMember> {
        if let Some(gossip) = &self.gossip {
            gossip.get_members().await
        } else {
            Vec::new()
        }
    }
    /// Publishes a metadata key/value to the cluster via gossip.
    pub async fn broadcast_metadata(&self, key: String, value: String) -> Result<(), String> {
        if let Some(gossip) = &self.gossip {
            gossip.update_metadata(key, value).await;
            Ok(())
        } else {
            Err("Gossip protocol not configured".to_string())
        }
    }
    /// Adds (or replaces) a node and bumps the cluster generation.
    ///
    /// `fetch_add` returns the pre-increment value, so `+ 1` stores the
    /// post-increment generation. Errors only if the state lock is poisoned.
    pub fn register_node(&self, node: NodeInfo) -> Result<(), String> {
        let mut state = self.state.write().map_err(|e| e.to_string())?;
        let gen = self.generation.fetch_add(1, Ordering::SeqCst) + 1;
        state.generation = gen;
        state.nodes.insert(node.node_id.clone(), node);
        state.total_node_count = state.nodes.len() as u32;
        self.update_cluster_health(&mut state);
        Ok(())
    }
    /// Removes a node, returning it if it existed. If the removed node was
    /// the leader, a new leader is elected immediately.
    pub fn deregister_node(&self, node_id: &str) -> Result<Option<NodeInfo>, String> {
        let mut state = self.state.write().map_err(|e| e.to_string())?;
        let gen = self.generation.fetch_add(1, Ordering::SeqCst) + 1;
        state.generation = gen;
        let removed = state.nodes.remove(node_id);
        state.total_node_count = state.nodes.len() as u32;
        if state.leader_id.as_deref() == Some(node_id) {
            state.leader_id = None;
            self.elect_leader(&mut state);
        }
        self.update_cluster_health(&mut state);
        Ok(removed)
    }
    /// Replaces a node's entire health record.
    ///
    /// Two-phase structure: the status transition is captured first so the
    /// mutable borrow of the node ends before `trigger_failover` needs
    /// `&mut state`. The generation is bumped only on a status *transition*;
    /// a Healthy->Unhealthy primary triggers failover when enabled.
    /// Unknown `node_id` is silently `Ok` — presumably intentional; confirm.
    pub fn update_node_health(&self, node_id: &str, health: NodeHealth) -> Result<(), String> {
        let mut state = self.state.write().map_err(|e| e.to_string())?;
        let transition_info = if let Some(node) = state.nodes.get_mut(node_id) {
            let old_status = node.health.status;
            let new_status = health.status;
            let role = node.role;
            node.health = health;
            if old_status != new_status {
                Some((old_status, new_status, role))
            } else {
                None
            }
        } else {
            None
        };
        if let Some((old_status, new_status, role)) = transition_info {
            let gen = self.generation.fetch_add(1, Ordering::SeqCst) + 1;
            state.generation = gen;
            if old_status == NodeStatus::Healthy
                && new_status == NodeStatus::Unhealthy
                && role == NodeRole::Primary
                && self.config.auto_failover
            {
                self.trigger_failover(&mut state, node_id);
            }
        }
        self.update_cluster_health(&mut state);
        Ok(())
    }
    /// Records one successful health probe: resets the failure streak and
    /// promotes to `Healthy` when the recovery threshold is met, or
    /// immediately for `Joining` nodes (first success completes the join).
    ///
    /// NOTE(review): unlike `update_node_health`, this status transition
    /// does not bump `state.generation` — confirm whether that is intended.
    pub fn record_health_success(&self, node_id: &str) -> Result<(), String> {
        let mut state = self.state.write().map_err(|e| e.to_string())?;
        if let Some(node) = state.nodes.get_mut(node_id) {
            node.health.success_count += 1;
            node.health.failure_count = 0;
            node.health.last_healthy_ms = current_time_ms();
            if (matches!(
                node.health.status,
                NodeStatus::Suspect | NodeStatus::Unhealthy
            ) && node.health.success_count >= self.config.recovery_threshold)
                || node.health.status == NodeStatus::Joining
            {
                info!(node_id = %node_id, old_status = ?node.health.status, "Node recovered to Healthy");
                node.health.status = NodeStatus::Healthy;
                self.update_cluster_health(&mut state);
            }
        }
        Ok(())
    }
    /// Records one failed health probe: a `Healthy` node first degrades to
    /// `Suspect`; once the failure threshold is reached it becomes
    /// `Unhealthy` (triggering failover for primaries when enabled).
    ///
    /// NOTE(review): as with `record_health_success`, `state.generation` is
    /// not bumped here despite possible status transitions.
    pub fn record_health_failure(&self, node_id: &str) -> Result<(), String> {
        let mut state = self.state.write().map_err(|e| e.to_string())?;
        if let Some(node) = state.nodes.get_mut(node_id) {
            node.health.failure_count += 1;
            node.health.success_count = 0;
            if node.health.failure_count >= self.config.failure_threshold {
                if node.health.status != NodeStatus::Unhealthy {
                    node.health.status = NodeStatus::Unhealthy;
                    if node.role == NodeRole::Primary && self.config.auto_failover {
                        let node_id_clone = node_id.to_string();
                        self.trigger_failover(&mut state, &node_id_clone);
                    }
                }
            } else if node.health.status == NodeStatus::Healthy {
                node.health.status = NodeStatus::Suspect;
            }
            self.update_cluster_health(&mut state);
        }
        Ok(())
    }
    /// Returns a deep copy of the current cluster state.
    /// Panics only if the lock is poisoned (a prior writer panicked).
    pub fn get_state(&self) -> ClusterState {
        self.state
            .read()
            .expect("cluster state lock poisoned in get_state")
            .clone()
    }
    /// All read-capable nodes hosting `shard_id`.
    pub fn get_healthy_nodes_for_shard(&self, shard_id: u32) -> Vec<NodeInfo> {
        let state = self
            .state
            .read()
            .expect("cluster state lock poisoned in get_healthy_nodes_for_shard");
        state
            .nodes
            .values()
            .filter(|n| n.shard_ids.contains(&shard_id) && n.can_serve_reads())
            .cloned()
            .collect()
    }
    /// The write-capable primary for `shard_id`, if one exists.
    /// (HashMap iteration order means ties are resolved arbitrarily, but a
    /// shard should have at most one primary.)
    pub fn get_primary_for_shard(&self, shard_id: u32) -> Option<NodeInfo> {
        let state = self
            .state
            .read()
            .expect("cluster state lock poisoned in get_primary_for_shard");
        state
            .nodes
            .values()
            .find(|n| {
                n.shard_ids.contains(&shard_id)
                    && n.role == NodeRole::Primary
                    && n.can_serve_writes()
            })
            .cloned()
    }
    /// All read-capable nodes in the cluster.
    pub fn get_healthy_nodes(&self) -> Vec<NodeInfo> {
        let state = self
            .state
            .read()
            .expect("cluster state lock poisoned in get_healthy_nodes");
        state
            .nodes
            .values()
            .filter(|n| n.can_serve_reads())
            .cloned()
            .collect()
    }
    /// Whether the cluster currently meets `min_quorum`.
    pub fn has_quorum(&self) -> bool {
        self.state
            .read()
            .expect("cluster state lock poisoned in has_quorum")
            .has_quorum
    }
    /// Seconds since this coordinator was constructed.
    pub fn uptime_secs(&self) -> u64 {
        self.start_time.elapsed().as_secs()
    }
    /// Recomputes derived health fields. Note `is_healthy` is currently
    /// defined as exactly `has_quorum`, and "healthy" counts Draining nodes
    /// (anything `can_serve_reads`).
    fn update_cluster_health(&self, state: &mut ClusterState) {
        state.healthy_node_count =
            state.nodes.values().filter(|n| n.can_serve_reads()).count() as u32;
        state.has_quorum = state.healthy_node_count >= self.config.min_quorum;
        state.is_healthy = state.has_quorum;
        state.last_update_ms = current_time_ms();
    }
    /// Deterministic leader election: the read-capable primary with the
    /// lexicographically smallest node_id. Leaves `leader_id` as `None`
    /// when no candidate exists.
    fn elect_leader(&self, state: &mut ClusterState) {
        let leader = state
            .nodes
            .values()
            .filter(|n| n.can_serve_reads() && n.role == NodeRole::Primary)
            .min_by(|a, b| a.node_id.cmp(&b.node_id));
        state.leader_id = leader.map(|n| n.node_id.clone());
    }
    /// Promotes one replica to `Primary` for each shard the failed node
    /// hosted, then re-elects the leader if the failed node held it.
    ///
    /// NOTE(review): `find` over HashMap values picks an arbitrary eligible
    /// replica per shard (iteration order is unspecified), so promotion is
    /// nondeterministic when several replicas qualify. The failed node keeps
    /// its `Primary` role; it is only excluded by its Unhealthy status.
    fn trigger_failover(&self, state: &mut ClusterState, failed_node_id: &str) {
        let shards: Vec<u32> = state
            .nodes
            .get(failed_node_id)
            .map(|n| n.shard_ids.clone())
            .unwrap_or_default();
        for shard_id in shards {
            let replica = state.nodes.values_mut().find(|n| {
                n.node_id != failed_node_id
                    && n.shard_ids.contains(&shard_id)
                    && n.role == NodeRole::Replica
                    && n.can_serve_reads()
            });
            if let Some(new_primary) = replica {
                new_primary.role = NodeRole::Primary;
            }
        }
        if state.leader_id.as_deref() == Some(failed_node_id) {
            self.elect_leader(state);
        }
    }
}
fn current_time_ms() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis() as u64
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Registering a node makes it visible in cluster state.
    #[test]
    fn test_node_registration() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config, "node-1".to_string());
        let node = NodeInfo::new(
            "node-1".to_string(),
            "localhost:8080".to_string(),
            NodeRole::Primary,
        );
        coordinator.register_node(node).unwrap();
        let state = coordinator.get_state();
        assert_eq!(state.nodes.len(), 1);
        assert!(state.nodes.contains_key("node-1"));
    }
    /// Full health cycle: Healthy -> Suspect -> Unhealthy -> (recovery) ->
    /// Healthy, with failure_threshold = recovery_threshold = 2.
    #[test]
    fn test_health_transitions() {
        let config = ClusterConfig {
            failure_threshold: 2,
            recovery_threshold: 2,
            ..Default::default()
        };
        let coordinator = ClusterCoordinator::new(config, "node-1".to_string());
        let mut node = NodeInfo::new(
            "node-1".to_string(),
            "localhost:8080".to_string(),
            NodeRole::Primary,
        );
        node.health.status = NodeStatus::Healthy;
        coordinator.register_node(node).unwrap();
        // First failure: below threshold, degrades Healthy -> Suspect.
        coordinator.record_health_failure("node-1").unwrap();
        let state = coordinator.get_state();
        assert_eq!(state.nodes["node-1"].health.status, NodeStatus::Suspect);
        // Second failure: threshold reached, Suspect -> Unhealthy.
        coordinator.record_health_failure("node-1").unwrap();
        let state = coordinator.get_state();
        assert_eq!(state.nodes["node-1"].health.status, NodeStatus::Unhealthy);
        // Two consecutive successes meet recovery_threshold: back to Healthy.
        coordinator.record_health_success("node-1").unwrap();
        coordinator.record_health_success("node-1").unwrap();
        let state = coordinator.get_state();
        assert_eq!(state.nodes["node-1"].health.status, NodeStatus::Healthy);
        assert_eq!(state.nodes["node-1"].health.failure_count, 0);
    }
    /// Quorum requires min_quorum read-capable nodes.
    #[test]
    fn test_quorum() {
        let config = ClusterConfig {
            min_quorum: 2,
            ..Default::default()
        };
        let coordinator = ClusterCoordinator::new(config, "node-1".to_string());
        let mut node1 = NodeInfo::new(
            "node-1".to_string(),
            "localhost:8080".to_string(),
            NodeRole::Primary,
        );
        node1.health.status = NodeStatus::Healthy;
        coordinator.register_node(node1).unwrap();
        assert!(!coordinator.has_quorum());
        let mut node2 = NodeInfo::new(
            "node-2".to_string(),
            "localhost:8081".to_string(),
            NodeRole::Replica,
        );
        node2.health.status = NodeStatus::Healthy;
        coordinator.register_node(node2).unwrap();
        assert!(coordinator.has_quorum());
    }
    /// Shard lookups return only healthy nodes hosting the shard.
    #[test]
    fn test_get_nodes_for_shard() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config, "node-1".to_string());
        let mut node1 = NodeInfo::new(
            "node-1".to_string(),
            "localhost:8080".to_string(),
            NodeRole::Primary,
        );
        node1.shard_ids = vec![0, 1];
        node1.health.status = NodeStatus::Healthy;
        coordinator.register_node(node1).unwrap();
        let mut node2 = NodeInfo::new(
            "node-2".to_string(),
            "localhost:8081".to_string(),
            NodeRole::Replica,
        );
        node2.shard_ids = vec![0, 2];
        node2.health.status = NodeStatus::Healthy;
        coordinator.register_node(node2).unwrap();
        let shard0_nodes = coordinator.get_healthy_nodes_for_shard(0);
        assert_eq!(shard0_nodes.len(), 2);
        let shard1_nodes = coordinator.get_healthy_nodes_for_shard(1);
        assert_eq!(shard1_nodes.len(), 1);
        let shard2_nodes = coordinator.get_healthy_nodes_for_shard(2);
        assert_eq!(shard2_nodes.len(), 1);
    }
    /// Deregistering returns the removed node and empties the state.
    #[test]
    fn test_deregister_node() {
        let config = ClusterConfig::default();
        let coordinator = ClusterCoordinator::new(config, "node-1".to_string());
        let node = NodeInfo::new(
            "node-1".to_string(),
            "localhost:8080".to_string(),
            NodeRole::Primary,
        );
        coordinator.register_node(node).unwrap();
        let removed = coordinator.deregister_node("node-1").unwrap();
        assert!(removed.is_some());
        let state = coordinator.get_state();
        assert!(state.nodes.is_empty());
    }
}