use dashmap::DashMap;
use std::{collections::HashMap, sync::Arc};
/// A cluster member together with the partitions currently assigned to it.
#[derive(Clone, Debug)]
pub struct Node {
    /// Identifier of the node; the registry stores the node under this value.
    pub id: u32,
    /// Address string of the node (the tests in this file use `host:port`).
    pub address: String,
    /// Whether the node is currently eligible to own partitions.
    pub healthy: bool,
    /// Partition ids owned by this node; rewritten by the registry on every
    /// rebalance.
    pub assigned_partitions: Vec<u32>,
}
/// Registry of cluster nodes that spreads a fixed number of partitions across
/// the nodes currently marked healthy.
pub struct NodeRegistry {
    /// Total number of partitions; valid partition ids are
    /// `0..partition_count`.
    partition_count: u32,
    /// Registered nodes keyed by node id, held in a shared concurrent map.
    nodes: Arc<DashMap<u32, Node>>,
}
impl NodeRegistry {
    /// Creates an empty registry that distributes the partitions
    /// `0..partition_count` across its healthy nodes.
    pub fn new(partition_count: u32) -> Self {
        Self {
            partition_count,
            nodes: Arc::new(DashMap::new()),
        }
    }

    /// Stores `node` under `node.id` (replacing any node already registered
    /// with that id) and rebalances the partitions.
    ///
    /// Whatever `assigned_partitions` the caller supplied is discarded; the
    /// registry is the only writer of that field.
    pub fn register_node(&self, mut node: Node) {
        node.assigned_partitions.clear();
        self.nodes.insert(node.id, node);
        self.rebalance_partitions();
    }

    /// Removes the node with `node_id`, rebalancing only if a node was
    /// actually removed.
    pub fn unregister_node(&self, node_id: u32) {
        // Removing an unknown id leaves the membership untouched, so the
        // existing assignment is still the one a rebalance would produce.
        if self.nodes.remove(&node_id).is_some() {
            self.rebalance_partitions();
        }
    }

    /// Updates the health flag of `node_id` and rebalances when the flag
    /// actually changed.
    ///
    /// Unknown ids and repeated reports of the current state are no-ops.
    pub fn set_node_health(&self, node_id: u32, healthy: bool) {
        // The map guard lives only inside the match arm, so it is released
        // before `rebalance_partitions` touches the map again.
        let changed = match self.nodes.get_mut(&node_id) {
            Some(mut node) => std::mem::replace(&mut node.healthy, healthy) != healthy,
            None => false,
        };
        if changed {
            self.rebalance_partitions();
        }
    }

    /// Recomputes every node's partition list.
    ///
    /// Healthy node ids are sorted ascending and partition `p` goes to the
    /// node at index `p % healthy_count` of that list. Unhealthy nodes end up
    /// with an empty list. With no healthy node, every list is emptied.
    fn rebalance_partitions(&self) {
        let mut healthy_ids: Vec<u32> = self
            .nodes
            .iter()
            .filter(|entry| entry.value().healthy)
            .map(|entry| *entry.key())
            .collect();
        healthy_ids.sort_unstable();
        let stride = healthy_ids.len();

        // Single pass: each node's list is rebuilt as a whole instead of
        // being cleared first and refilled one partition at a time.
        for mut entry in self.nodes.iter_mut() {
            // Position of this node in the sorted healthy list, if any. A
            // node missing from the snapshot is treated as owning nothing.
            let slot = if entry.value().healthy {
                healthy_ids.binary_search(entry.key()).ok()
            } else {
                None
            };

            let partitions = &mut entry.value_mut().assigned_partitions;
            partitions.clear();
            if let Some(slot) = slot {
                // `slot` is `Some` only when `healthy_ids` is non-empty, so
                // `stride >= 1` and `step_by` cannot panic. The sequence is
                // slot, slot + stride, ... i.e. every p with p % stride == slot.
                partitions.extend((0..self.partition_count).skip(slot).step_by(stride));
            }
        }
    }

    /// Returns the id of the healthy node that owns `partition_id`, or `None`
    /// when no healthy node has that partition in its list.
    pub fn node_for_partition(&self, partition_id: u32) -> Option<u32> {
        self.nodes
            .iter()
            .find(|entry| {
                let node = entry.value();
                node.healthy && node.assigned_partitions.contains(&partition_id)
            })
            .map(|entry| entry.value().id)
    }

    /// Returns a clone of the node registered under `node_id`, if any.
    pub fn get_node(&self, node_id: u32) -> Option<Node> {
        self.nodes.get(&node_id).map(|entry| entry.value().clone())
    }

    /// Returns clones of all registered nodes, in the map's iteration order.
    pub fn all_nodes(&self) -> Vec<Node> {
        self.nodes
            .iter()
            .map(|entry| entry.value().clone())
            .collect()
    }

    /// Returns clones of the nodes currently flagged healthy.
    pub fn healthy_nodes(&self) -> Vec<Node> {
        self.nodes
            .iter()
            .filter(|entry| entry.value().healthy)
            .map(|entry| entry.value().clone())
            .collect()
    }

    /// Returns a map from healthy node id to a copy of that node's partition
    /// list. Unhealthy nodes are omitted.
    pub fn partition_distribution(&self) -> HashMap<u32, Vec<u32>> {
        self.nodes
            .iter()
            .filter(|entry| entry.value().healthy)
            .map(|entry| (*entry.key(), entry.value().assigned_partitions.clone()))
            .collect()
    }

    /// Returns `true` when every partition in `0..partition_count` appears in
    /// the list of at least one healthy node.
    ///
    /// A registry with zero partitions is reported healthy.
    pub fn is_cluster_healthy(&self) -> bool {
        // Mark each partition seen on a healthy node, then check for gaps.
        // This visits every assignment once instead of rescanning all nodes
        // for each partition.
        let mut covered = vec![false; self.partition_count as usize];
        for entry in self.nodes.iter().filter(|entry| entry.value().healthy) {
            for &partition_id in &entry.value().assigned_partitions {
                if let Some(seen) = covered.get_mut(partition_id as usize) {
                    *seen = true;
                }
            }
        }
        covered.into_iter().all(|seen| seen)
    }

    /// Returns the number of registered nodes, healthy or not.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the number of registered nodes flagged healthy.
    pub fn healthy_node_count(&self) -> usize {
        self.nodes
            .iter()
            .filter(|entry| entry.value().healthy)
            .count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of partitions used by every test registry.
    const PARTITIONS: u32 = 32;

    /// Builds a node with the address `node-<id>:8080` and no partitions.
    fn make_node(id: u32, healthy: bool) -> Node {
        Node {
            id,
            address: format!("node-{id}:8080"),
            healthy,
            assigned_partitions: Vec::new(),
        }
    }

    /// Builds a registry and registers the given `(id, healthy)` nodes in
    /// order.
    fn registry_with(nodes: &[(u32, bool)]) -> NodeRegistry {
        let registry = NodeRegistry::new(PARTITIONS);
        for &(id, healthy) in nodes {
            registry.register_node(make_node(id, healthy));
        }
        registry
    }

    #[test]
    fn test_create_registry() {
        let registry = NodeRegistry::new(PARTITIONS);
        assert_eq!(registry.node_count(), 0);
        assert_eq!(registry.healthy_node_count(), 0);
    }

    #[test]
    fn test_register_node() {
        let registry = registry_with(&[(0, true)]);
        assert_eq!(registry.node_count(), 1);
        assert_eq!(registry.healthy_node_count(), 1);

        // A lone healthy node owns every partition.
        let only = registry.get_node(0).unwrap();
        assert_eq!(only.assigned_partitions.len(), 32);
    }

    #[test]
    fn test_two_node_distribution() {
        let registry = registry_with(&[(0, true), (1, true)]);
        let first = registry.get_node(0).unwrap();
        let second = registry.get_node(1).unwrap();

        assert_eq!(first.assigned_partitions.len(), 16);
        assert_eq!(second.assigned_partitions.len(), 16);

        // The two halves must not overlap.
        for partition_id in &first.assigned_partitions {
            assert!(!second.assigned_partitions.contains(partition_id));
        }
    }

    #[test]
    fn test_node_for_partition() {
        let registry = registry_with(&[(0, true), (1, true)]);
        for partition_id in 0..32 {
            let owner = registry.node_for_partition(partition_id);
            assert!(owner.is_some());
        }
    }

    #[test]
    fn test_unhealthy_node_excluded() {
        let registry = registry_with(&[(0, true), (1, false)]);
        let healthy = registry.get_node(0).unwrap();
        let unhealthy = registry.get_node(1).unwrap();

        assert_eq!(healthy.assigned_partitions.len(), 32);
        assert_eq!(unhealthy.assigned_partitions.len(), 0);
    }

    #[test]
    fn test_rebalance_on_health_change() {
        let registry = registry_with(&[(0, true), (1, true)]);
        let before = registry.get_node(0).unwrap();
        assert_eq!(before.assigned_partitions.len(), 16);

        // Losing node 1 hands all partitions to node 0.
        registry.set_node_health(1, false);
        let after = registry.get_node(0).unwrap();
        assert_eq!(after.assigned_partitions.len(), 32);
    }

    #[test]
    fn test_cluster_health() {
        let registry = NodeRegistry::new(PARTITIONS);
        assert!(!registry.is_cluster_healthy());

        registry.register_node(make_node(0, true));
        assert!(registry.is_cluster_healthy());

        registry.set_node_health(0, false);
        assert!(!registry.is_cluster_healthy());
    }

    #[test]
    fn test_partition_distribution() {
        let registry = registry_with(&[(0, true), (1, true)]);
        let distribution = registry.partition_distribution();

        assert_eq!(distribution.len(), 2);
        assert_eq!(distribution.get(&0).unwrap().len(), 16);
        assert_eq!(distribution.get(&1).unwrap().len(), 16);
    }

    #[test]
    fn test_deterministic_assignment() {
        let left = NodeRegistry::new(PARTITIONS);
        let right = NodeRegistry::new(PARTITIONS);
        for id in 0..4 {
            let node = make_node(id, true);
            left.register_node(node.clone());
            right.register_node(node);
        }

        // Identical registration sequences must yield identical owners.
        for partition_id in 0..32 {
            let from_left = left.node_for_partition(partition_id);
            let from_right = right.node_for_partition(partition_id);
            assert_eq!(from_left, from_right);
        }
    }
}