use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::failover::FailoverController;
use crate::shard::{ShardId, ShardRegistry};
use crate::types::NodeId;
/// Role of a single node as reported in a [`ClusterTopology`] snapshot.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeState {
    /// The node named by the snapshot's leader hint.
    Leader,
    /// A live node that is not the hinted leader.
    Follower,
    /// Candidate role; never produced by `TopologyCollector` itself.
    Candidate,
    /// The failover controller reports this node as failed.
    Offline,
}
/// Point-in-time status of one node inside a [`ClusterTopology`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeStatus {
    /// Identifier of the node this entry describes.
    pub node_id: NodeId,
    /// Role assigned to the node in this snapshot.
    pub state: NodeState,
    /// Number of shards placed on this node.
    pub shard_count: usize,
    /// Milliseconds since the Unix epoch, or `None` when no timestamp is
    /// available (e.g. the node is offline).
    pub last_heartbeat_ms: Option<u64>,
    /// Replication log index; `TopologyCollector` currently always reports `0`.
    pub log_index: u64,
    /// Whether this node is reported as the leader.
    pub is_leader: bool,
}
/// Snapshot of cluster membership, leadership and shard placement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterTopology {
    /// One status entry per node requested for the snapshot.
    pub nodes: Vec<NodeStatus>,
    /// Sum of the shard lists in `shard_distribution`.
    pub total_shards: usize,
    /// Identifier of the leader, if one is known.
    pub leader_node_id: Option<NodeId>,
    /// Shard identifiers grouped by the node that owns them.
    pub shard_distribution: HashMap<NodeId, Vec<ShardId>>,
}
/// Builds [`ClusterTopology`] snapshots from failover state and, optionally,
/// a shard registry.
pub struct TopologyCollector {
    /// Supplies the set of nodes currently considered failed.
    failover: Arc<FailoverController>,
    /// Source of shard placement; `None` means snapshots carry no shards.
    registry: Option<Arc<ShardRegistry>>,
}
impl TopologyCollector {
pub fn new(failover: Arc<FailoverController>) -> Self {
Self {
failover,
registry: None,
}
}
pub fn with_registry(failover: Arc<FailoverController>, registry: Arc<ShardRegistry>) -> Self {
Self {
failover,
registry: Some(registry),
}
}
pub fn snapshot(&self, nodes: &[NodeId]) -> ClusterTopology {
self.snapshot_with_leader(nodes, None)
}
pub fn snapshot_with_leader(
&self,
nodes: &[NodeId],
leader_hint: Option<NodeId>,
) -> ClusterTopology {
let mut shard_distribution: HashMap<NodeId, Vec<ShardId>> =
nodes.iter().map(|&nid| (nid, Vec::new())).collect();
if let Some(ref reg) = self.registry {
for shard in reg.get_all() {
shard_distribution
.entry(shard.node_id)
.or_default()
.push(shard.id);
}
}
let total_shards: usize = shard_distribution.values().map(|v| v.len()).sum();
let now_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let mut node_statuses = Vec::with_capacity(nodes.len());
let failed_set: std::collections::HashSet<NodeId> =
self.failover.failed_nodes().into_iter().collect();
for &node_id in nodes {
let is_offline = failed_set.contains(&node_id);
let is_leader = leader_hint.map(|l| l == node_id).unwrap_or(false);
let state = if is_offline {
NodeState::Offline
} else if is_leader {
NodeState::Leader
} else {
NodeState::Follower
};
let shard_count = shard_distribution
.get(&node_id)
.map(|v| v.len())
.unwrap_or(0);
let last_heartbeat_ms = if is_offline { None } else { Some(now_ms) };
node_statuses.push(NodeStatus {
node_id,
state,
shard_count,
last_heartbeat_ms,
log_index: 0,
is_leader,
});
}
ClusterTopology {
nodes: node_statuses,
total_shards,
leader_node_id: leader_hint,
shard_distribution,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::failover::FailoverController;
    use crate::shard::{KeyRange, ShardMetadata, ShardRegistry};
    use amaters_core::Key;
    use std::time::Duration;

    /// Failover controller with the 500 ms timeout shared by every test.
    fn make_controller() -> Arc<FailoverController> {
        let timeout = Duration::from_millis(500);
        Arc::new(FailoverController::new(timeout))
    }

    /// Registry with three shards: shards 1 and 2 on node 1, shard 3 on node 2.
    fn make_registry_with_shards() -> Arc<ShardRegistry> {
        let layout: [(u64, u64, &str, &str); 3] = [
            (1, 1, "a0", "a1"),
            (2, 1, "a1", "a2"),
            (3, 2, "b0", "b1"),
        ];
        let registry = ShardRegistry::new();
        for &(shard_id, owner, start, end) in layout.iter() {
            let range = KeyRange::new(Key::from_str(start), Key::from_str(end)).expect("range");
            registry
                .register(ShardMetadata::new(shard_id, range, owner))
                .expect("register");
        }
        Arc::new(registry)
    }

    /// Returns the status entry for `id`, panicking with `missing` if absent.
    fn status_of<'a>(topology: &'a ClusterTopology, id: NodeId, missing: &str) -> &'a NodeStatus {
        topology
            .nodes
            .iter()
            .find(|status| status.node_id == id)
            .expect(missing)
    }

    #[test]
    fn test_topology_snapshot_contains_all_nodes() {
        let collector = TopologyCollector::new(make_controller());
        let requested = [1u64, 2, 3];
        let topology = collector.snapshot(&requested);
        assert_eq!(topology.nodes.len(), 3, "topology must contain all 3 nodes");
        let reported: Vec<NodeId> = topology.nodes.iter().map(|status| status.node_id).collect();
        for nid in requested {
            assert!(reported.contains(&nid), "node {} must appear in topology", nid);
        }
    }

    #[test]
    fn test_topology_marks_failed_nodes_offline() {
        let controller = make_controller();
        controller.mark_failed(2);
        let collector = TopologyCollector::new(controller);
        let topology = collector.snapshot(&[1u64, 2, 3]);

        let failed = status_of(&topology, 2, "node 2 must be present");
        assert_eq!(
            failed.state,
            NodeState::Offline,
            "failed node must be marked Offline"
        );
        assert!(
            failed.last_heartbeat_ms.is_none(),
            "offline node must have no heartbeat timestamp"
        );

        for nid in [1u64, 3] {
            let live = status_of(&topology, nid, "node must be present");
            assert_ne!(
                live.state,
                NodeState::Offline,
                "node {} should not be offline",
                nid
            );
        }
    }

    #[test]
    fn test_topology_shard_distribution() {
        let collector =
            TopologyCollector::with_registry(make_controller(), make_registry_with_shards());
        let topology = collector.snapshot(&[1u64, 2]);
        assert_eq!(topology.total_shards, 3);
        let per_node = &topology.shard_distribution;
        assert_eq!(per_node[&1].len(), 2, "node 1 should have 2 shards");
        assert_eq!(per_node[&2].len(), 1, "node 2 should have 1 shard");
    }

    #[test]
    fn test_topology_leader_hint() {
        let collector = TopologyCollector::new(make_controller());
        let topology = collector.snapshot_with_leader(&[1u64, 2, 3], Some(1));
        assert_eq!(topology.leader_node_id, Some(1));
        let leader = status_of(&topology, 1, "node 1 must be present");
        assert!(leader.is_leader);
        assert_eq!(leader.state, NodeState::Leader);
    }

    #[test]
    fn test_topology_serialises_to_json() {
        let topology = TopologyCollector::new(make_controller()).snapshot(&[1, 2]);
        let json = serde_json::to_string(&topology).expect("serialize");
        assert!(json.contains("\"nodes\""));
        assert!(json.contains("\"total_shards\""));
    }
}