use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap, HashSet};
/// Identifier of a cluster node; used as the key for registered nodes, load
/// counters and vnode owners in `ConsistentHashRing`.
pub type NodeId = String;
/// A member of the hash ring together with its placement parameters.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClusterNode {
/// Unique node id; `ConsistentHashRing::add_node` registers the node under it.
pub id: NodeId,
/// Network address (e.g. "10.0.0.1:8080" in the tests); stored but not read by the ring.
pub address: String,
/// Placement weight: the ring creates `vnodes_per_node * weight` vnodes, treating 0 as 1.
pub weight: u32,
/// NOTE(review): not read or updated by `ConsistentHashRing`, which keeps its own load
/// counters (`increment_load` / `decrement_load`); confirm whether this field is still needed.
pub load: usize,
/// Inactive nodes keep their vnodes but are skipped by key lookups and replica selection.
pub active: bool,
/// Zone label used by `get_zone_aware_replicas` to spread replicas; `None` counts as "default".
pub zone: Option<String>,
}
impl ClusterNode {
    /// Creates an active node with weight 1, a `load` of 0 and no zone.
    #[must_use]
    pub fn new(id: impl Into<String>, address: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            address: address.into(),
            weight: 1,
            load: 0,
            active: true,
            zone: None,
        }
    }

    /// Returns the node with its placement weight replaced.
    ///
    /// `ConsistentHashRing::add_node` places `vnodes_per_node * weight` vnodes
    /// for the node and treats a weight of 0 as 1. The builder consumes `self`,
    /// so ignoring the return value silently discards the change.
    #[must_use]
    pub fn with_weight(mut self, weight: u32) -> Self {
        self.weight = weight;
        self
    }

    /// Returns the node with its zone label set (used by zone-aware replica
    /// selection). Like `with_weight`, the result must be used.
    #[must_use]
    pub fn with_zone(mut self, zone: impl Into<String>) -> Self {
        self.zone = Some(zone.into());
        self
    }
}
/// Tuning parameters for a `ConsistentHashRing`.
#[derive(Debug, Clone)]
pub struct HashRingConfig {
/// Vnodes placed per unit of node weight.
pub vnodes_per_node: usize,
/// Bounded-loads multiplier: when `Some(c)`, `get_node_for_key` skips nodes whose load has
/// reached a cap of roughly `c` times the average load; `None` disables the cap.
pub load_factor: Option<f64>,
/// How many distinct nodes `get_replicas` / `get_zone_aware_replicas` try to return.
pub replica_count: usize,
}
impl Default for HashRingConfig {
fn default() -> Self {
Self {
vnodes_per_node: 150,
load_factor: Some(1.25),
replica_count: 3,
}
}
}
/// Snapshot of ring membership and load counters produced by `ConsistentHashRing::stats`.
#[derive(Debug, Clone, Default)]
pub struct RingStats {
/// Number of registered nodes, active or not.
pub node_count: usize,
/// Number of nodes currently marked active.
pub active_node_count: usize,
/// Total number of vnode positions on the ring.
pub vnode_count: usize,
/// Copy of the per-node load counters (inactive nodes included).
pub load_distribution: HashMap<NodeId, usize>,
/// Population standard deviation of the values in `load_distribution` (0.0 when empty).
pub load_stddev: f64,
/// Largest value in `load_distribution` (0 when empty).
pub max_load: usize,
/// Smallest value in `load_distribution` (0 when empty).
pub min_load: usize,
}
/// Consistent-hash ring with weighted virtual nodes, optional bounded loads and
/// (zone-aware) replica selection.
pub struct ConsistentHashRing {
/// Ring tuning parameters.
config: HashRingConfig,
/// Vnode position (hash) -> (owning node id, index of that vnode within its node).
ring: BTreeMap<u64, (NodeId, usize)>,
/// Registered nodes by id.
nodes: HashMap<NodeId, ClusterNode>,
/// Per-node load counters maintained via `increment_load` / `decrement_load`.
loads: HashMap<NodeId, usize>,
}
impl ConsistentHashRing {
pub fn new(config: HashRingConfig) -> Self {
Self {
config,
ring: BTreeMap::new(),
nodes: HashMap::new(),
loads: HashMap::new(),
}
}
pub fn with_defaults() -> Self {
Self::new(HashRingConfig::default())
}
pub fn config(&self) -> &HashRingConfig {
&self.config
}
pub fn add_node(&mut self, node: ClusterNode) -> usize {
let node_id = node.id.clone();
let weight = node.weight.max(1) as usize;
let vnode_count = self.config.vnodes_per_node * weight;
self.nodes.insert(node_id.clone(), node);
self.loads.entry(node_id.clone()).or_insert(0);
let mut created = 0;
for i in 0..vnode_count {
let hash = hash_vnode(&node_id, i);
self.ring.insert(hash, (node_id.clone(), i));
created += 1;
}
created
}
pub fn remove_node(&mut self, node_id: &str) -> bool {
if self.nodes.remove(node_id).is_some() {
self.ring.retain(|_, (id, _)| id != node_id);
self.loads.remove(node_id);
true
} else {
false
}
}
pub fn deactivate_node(&mut self, node_id: &str) -> bool {
if let Some(node) = self.nodes.get_mut(node_id) {
node.active = false;
true
} else {
false
}
}
pub fn activate_node(&mut self, node_id: &str) -> bool {
if let Some(node) = self.nodes.get_mut(node_id) {
node.active = true;
true
} else {
false
}
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn active_node_count(&self) -> usize {
self.nodes.values().filter(|n| n.active).count()
}
pub fn vnode_count(&self) -> usize {
self.ring.len()
}
pub fn get_node(&self, node_id: &str) -> Option<&ClusterNode> {
self.nodes.get(node_id)
}
pub fn node_ids(&self) -> Vec<NodeId> {
self.nodes.keys().cloned().collect()
}
pub fn get_node_for_key(&self, key: &[u8]) -> Option<NodeId> {
if self.ring.is_empty() {
return None;
}
let hash = hash_key(key);
let max_load = self.max_allowed_load();
let candidates = self
.ring
.range(hash..)
.chain(self.ring.iter())
.take(self.ring.len());
for (_, (node_id, _)) in candidates {
let node = match self.nodes.get(node_id) {
Some(n) => n,
None => continue,
};
if !node.active {
continue;
}
if let Some(max) = max_load {
let current_load = self.loads.get(node_id).copied().unwrap_or(0);
if current_load >= max {
continue;
}
}
return Some(node_id.clone());
}
self.ring
.range(hash..)
.chain(self.ring.iter())
.find(|(_, (id, _))| self.nodes.get(id).is_some_and(|n| n.active))
.map(|(_, (id, _))| id.clone())
}
pub fn get_replicas(&self, key: &[u8]) -> Vec<NodeId> {
if self.ring.is_empty() {
return Vec::new();
}
let hash = hash_key(key);
let mut replicas = Vec::new();
let mut seen = HashSet::new();
let candidates = self
.ring
.range(hash..)
.chain(self.ring.iter())
.take(self.ring.len());
for (_, (node_id, _)) in candidates {
if seen.contains(node_id) {
continue;
}
if let Some(node) = self.nodes.get(node_id) {
if node.active {
seen.insert(node_id.clone());
replicas.push(node_id.clone());
if replicas.len() >= self.config.replica_count {
break;
}
}
}
}
replicas
}
pub fn get_zone_aware_replicas(&self, key: &[u8]) -> Vec<NodeId> {
if self.ring.is_empty() {
return Vec::new();
}
let hash = hash_key(key);
let mut replicas = Vec::new();
let mut seen_nodes = HashSet::new();
let mut seen_zones = HashSet::new();
let candidates = self
.ring
.range(hash..)
.chain(self.ring.iter())
.take(self.ring.len());
let all_candidates: Vec<_> = candidates.collect();
for (_, (node_id, _)) in &all_candidates {
if seen_nodes.contains(node_id) {
continue;
}
if let Some(node) = self.nodes.get(node_id) {
if !node.active {
continue;
}
let zone = node.zone.as_deref().unwrap_or("default");
if !seen_zones.contains(zone) {
seen_zones.insert(zone.to_string());
seen_nodes.insert(node_id.clone());
replicas.push(node_id.clone());
if replicas.len() >= self.config.replica_count {
return replicas;
}
}
}
}
for (_, (node_id, _)) in &all_candidates {
if seen_nodes.contains(node_id) {
continue;
}
if let Some(node) = self.nodes.get(node_id) {
if node.active {
seen_nodes.insert(node_id.clone());
replicas.push(node_id.clone());
if replicas.len() >= self.config.replica_count {
break;
}
}
}
}
replicas
}
pub fn increment_load(&mut self, node_id: &str) {
if let Some(load) = self.loads.get_mut(node_id) {
*load += 1;
}
}
pub fn decrement_load(&mut self, node_id: &str) {
if let Some(load) = self.loads.get_mut(node_id) {
*load = load.saturating_sub(1);
}
}
pub fn reset_loads(&mut self) {
for load in self.loads.values_mut() {
*load = 0;
}
}
pub fn node_load(&self, node_id: &str) -> usize {
self.loads.get(node_id).copied().unwrap_or(0)
}
fn max_allowed_load(&self) -> Option<usize> {
let factor = self.config.load_factor?;
let active_count = self.active_node_count();
if active_count == 0 {
return None;
}
let total_load: usize = self.loads.values().sum();
let avg_load = total_load as f64 / active_count as f64;
Some((avg_load * factor).ceil() as usize)
}
pub fn affected_ranges_on_add(
&self,
new_node_id: &str,
weight: u32,
) -> HashMap<NodeId, Vec<u64>> {
let mut affected: HashMap<NodeId, Vec<u64>> = HashMap::new();
let vnode_count = self.config.vnodes_per_node * weight.max(1) as usize;
for i in 0..vnode_count {
let hash = hash_vnode(new_node_id, i);
if let Some((_, (current_owner, _))) = self
.ring
.range(hash..)
.next()
.or_else(|| self.ring.iter().next())
{
affected
.entry(current_owner.clone())
.or_default()
.push(hash);
}
}
affected
}
pub fn affected_ranges_on_remove(&self, node_id: &str) -> HashMap<NodeId, usize> {
let mut affected: HashMap<NodeId, usize> = HashMap::new();
let vnodes: Vec<u64> = self
.ring
.iter()
.filter(|(_, (id, _))| id == node_id)
.map(|(&pos, _)| pos)
.collect();
for pos in vnodes {
let successor = self
.ring
.range((pos + 1)..)
.chain(self.ring.iter())
.find(|(_, (id, _))| id != node_id)
.map(|(_, (id, _))| id.clone());
if let Some(succ_id) = successor {
*affected.entry(succ_id).or_insert(0) += 1;
}
}
affected
}
pub fn stats(&self) -> RingStats {
let load_distribution: HashMap<NodeId, usize> = self.loads.clone();
let loads: Vec<usize> = load_distribution.values().copied().collect();
let (max_load, min_load) = if loads.is_empty() {
(0, 0)
} else {
(
loads.iter().copied().max().unwrap_or(0),
loads.iter().copied().min().unwrap_or(0),
)
};
let mean = if loads.is_empty() {
0.0
} else {
loads.iter().sum::<usize>() as f64 / loads.len() as f64
};
let variance = if loads.is_empty() {
0.0
} else {
loads
.iter()
.map(|&l| (l as f64 - mean).powi(2))
.sum::<f64>()
/ loads.len() as f64
};
RingStats {
node_count: self.nodes.len(),
active_node_count: self.active_node_count(),
vnode_count: self.ring.len(),
load_distribution,
load_stddev: variance.sqrt(),
max_load,
min_load,
}
}
}
/// 64-bit FNV-1a hash of `key`: XOR each byte into the state, then multiply
/// by the FNV prime (wrapping), starting from the FNV offset basis.
fn hash_key(key: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    key.iter().fold(FNV_OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// Ring position of the `vnode_index`-th virtual node of `node_id`: the
/// `hash_key` of the label `"{node_id}#vnode#{vnode_index}"`.
fn hash_vnode(node_id: &str, vnode_index: usize) -> u64 {
    let label = [node_id, "#vnode#", vnode_index.to_string().as_str()].concat();
    hash_key(label.as_bytes())
}
// Unit tests. The shared fixtures use `HashRingConfig::default()` (150 vnodes per
// unit of weight, replica_count 3, load_factor Some(1.25)).
#[cfg(test)]
mod tests {
use super::*;
// Empty ring with the default configuration.
fn default_ring() -> ConsistentHashRing {
ConsistentHashRing::with_defaults()
}
// Default-config ring holding three weight-1 nodes: node-a, node-b, node-c.
fn three_node_ring() -> ConsistentHashRing {
let mut ring = default_ring();
ring.add_node(ClusterNode::new("node-a", "10.0.0.1:8080"));
ring.add_node(ClusterNode::new("node-b", "10.0.0.2:8080"));
ring.add_node(ClusterNode::new("node-c", "10.0.0.3:8080"));
ring
}
// A weight-1 node gets exactly `vnodes_per_node` (150) vnodes.
#[test]
fn test_add_node() {
let mut ring = default_ring();
let created = ring.add_node(ClusterNode::new("node-1", "10.0.0.1:8080"));
assert_eq!(created, 150); assert_eq!(ring.node_count(), 1);
assert_eq!(ring.vnode_count(), 150);
}
// Weight 2 doubles the vnode count to 300.
#[test]
fn test_add_weighted_node() {
let mut ring = default_ring();
let node = ClusterNode::new("node-1", "10.0.0.1:8080").with_weight(2);
let created = ring.add_node(node);
assert_eq!(created, 300); }
// Removing a known node succeeds; removing an unknown id reports false.
#[test]
fn test_remove_node() {
let mut ring = three_node_ring();
assert!(ring.remove_node("node-b"));
assert_eq!(ring.node_count(), 2);
assert!(!ring.remove_node("nonexistent"));
}
// Deactivating and reactivating a node is reflected in active_node_count.
#[test]
fn test_deactivate_activate_node() {
let mut ring = three_node_ring();
assert!(ring.deactivate_node("node-a"));
assert_eq!(ring.active_node_count(), 2);
assert!(ring.activate_node("node-a"));
assert_eq!(ring.active_node_count(), 3);
}
// Deactivating an unknown id reports false.
#[test]
fn test_deactivate_nonexistent() {
let mut ring = default_ring();
assert!(!ring.deactivate_node("nonexistent"));
}
// Registered nodes can be fetched by id with their stored address.
#[test]
fn test_get_node() {
let ring = three_node_ring();
let node = ring.get_node("node-a");
assert!(node.is_some());
assert_eq!(node.map(|n| n.address.as_str()), Some("10.0.0.1:8080"));
}
// node_ids comes from HashMap keys (unordered), so sort before comparing.
#[test]
fn test_node_ids() {
let ring = three_node_ring();
let mut ids = ring.node_ids();
ids.sort();
assert_eq!(ids, vec!["node-a", "node-b", "node-c"]);
}
// A populated ring resolves any key to some node.
#[test]
fn test_get_node_for_key() {
let ring = three_node_ring();
let node = ring.get_node_for_key(b"some-key");
assert!(node.is_some());
}
// An empty ring resolves nothing.
#[test]
fn test_empty_ring_lookup() {
let ring = default_ring();
assert!(ring.get_node_for_key(b"key").is_none());
}
// Repeated lookups of the same key on an unchanged ring agree.
#[test]
fn test_key_consistency() {
let ring = three_node_ring();
let node1 = ring.get_node_for_key(b"consistent-key");
let node2 = ring.get_node_for_key(b"consistent-key");
assert_eq!(node1, node2);
}
// 100 distinct keys should not all land on a single node.
#[test]
fn test_different_keys_may_map_differently() {
let ring = three_node_ring();
let mut mappings = HashSet::new();
for i in 0..100 {
let key = format!("key-{i}");
if let Some(node) = ring.get_node_for_key(key.as_bytes()) {
mappings.insert(node);
}
}
assert!(mappings.len() >= 2, "Expected distribution across nodes");
}
// After deactivating a key's owner, the key maps to a different node.
#[test]
fn test_inactive_node_skipped() {
let mut ring = three_node_ring();
let key = b"test-key";
let original = ring.get_node_for_key(key).expect("should have node");
ring.deactivate_node(&original);
let new_node = ring.get_node_for_key(key).expect("should have node");
assert_ne!(original, new_node);
}
// Three nodes with replica_count 3: all three come back, without duplicates.
#[test]
fn test_get_replicas() {
let ring = three_node_ring();
let replicas = ring.get_replicas(b"some-key");
assert_eq!(replicas.len(), 3); let unique: HashSet<_> = replicas.iter().collect();
assert_eq!(unique.len(), 3);
}
// Asking for 5 replicas with only 2 nodes returns just those 2.
#[test]
fn test_get_replicas_fewer_nodes() {
let mut ring = ConsistentHashRing::new(HashRingConfig {
replica_count: 5,
..HashRingConfig::default()
});
ring.add_node(ClusterNode::new("a", "10.0.0.1:8080"));
ring.add_node(ClusterNode::new("b", "10.0.0.2:8080"));
let replicas = ring.get_replicas(b"key");
assert_eq!(replicas.len(), 2); }
// An empty ring yields no replicas.
#[test]
fn test_get_replicas_empty_ring() {
let ring = default_ring();
assert!(ring.get_replicas(b"key").is_empty());
}
// One node per zone: zone-aware selection returns all three distinct nodes.
#[test]
fn test_zone_aware_replicas() {
let mut ring = ConsistentHashRing::new(HashRingConfig {
replica_count: 3,
..HashRingConfig::default()
});
ring.add_node(ClusterNode::new("a", "10.0.0.1:8080").with_zone("us-east"));
ring.add_node(ClusterNode::new("b", "10.0.0.2:8080").with_zone("us-west"));
ring.add_node(ClusterNode::new("c", "10.0.0.3:8080").with_zone("eu-west"));
let replicas = ring.get_zone_aware_replicas(b"key");
assert_eq!(replicas.len(), 3);
let unique: HashSet<_> = replicas.iter().collect();
assert_eq!(unique.len(), 3);
}
// Only two zones for three replicas: the fill-up pass must supply the third.
#[test]
fn test_zone_aware_same_zone() {
let mut ring = ConsistentHashRing::new(HashRingConfig {
replica_count: 3,
..HashRingConfig::default()
});
ring.add_node(ClusterNode::new("a", "10.0.0.1:8080").with_zone("zone1"));
ring.add_node(ClusterNode::new("b", "10.0.0.2:8080").with_zone("zone1"));
ring.add_node(ClusterNode::new("c", "10.0.0.3:8080").with_zone("zone2"));
let replicas = ring.get_zone_aware_replicas(b"key");
assert_eq!(replicas.len(), 3);
}
// Load counters move up and down by one per call.
#[test]
fn test_increment_decrement_load() {
let mut ring = three_node_ring();
ring.increment_load("node-a");
ring.increment_load("node-a");
assert_eq!(ring.node_load("node-a"), 2);
ring.decrement_load("node-a");
assert_eq!(ring.node_load("node-a"), 1);
}
// Decrementing a zero load saturates at 0.
#[test]
fn test_decrement_below_zero() {
let mut ring = three_node_ring();
ring.decrement_load("node-a"); assert_eq!(ring.node_load("node-a"), 0);
}
// reset_loads zeroes every counter.
#[test]
fn test_reset_loads() {
let mut ring = three_node_ring();
ring.increment_load("node-a");
ring.increment_load("node-b");
ring.reset_loads();
assert_eq!(ring.node_load("node-a"), 0);
assert_eq!(ring.node_load("node-b"), 0);
}
// Node a carries load 100 versus 0 on b; with load_factor 1.5, a is far above
// the cap, so lookups should mostly spill over to b.
#[test]
fn test_bounded_load() {
let mut ring = ConsistentHashRing::new(HashRingConfig {
load_factor: Some(1.5),
replica_count: 1,
vnodes_per_node: 10,
});
ring.add_node(ClusterNode::new("a", "10.0.0.1:8080"));
ring.add_node(ClusterNode::new("b", "10.0.0.2:8080"));
for _ in 0..100 {
ring.increment_load("a");
}
let mut b_count = 0;
for i in 0..50 {
let key = format!("key-{i}");
if let Some(node) = ring.get_node_for_key(key.as_bytes()) {
if node == "b" {
b_count += 1;
}
}
}
assert!(
b_count > 20,
"Expected most keys to go to node b, got {b_count}"
);
}
// Adding a fourth node should remap only a minority of keys (ideally ~25%);
// the assertion allows up to 59 of 100.
#[test]
fn test_minimal_redistribution_on_add() {
let ring = three_node_ring();
let mut before: HashMap<String, String> = HashMap::new();
for i in 0..100 {
let key = format!("key-{i}");
if let Some(node) = ring.get_node_for_key(key.as_bytes()) {
before.insert(key, node);
}
}
let mut ring_after = three_node_ring();
ring_after.add_node(ClusterNode::new("node-d", "10.0.0.4:8080"));
let mut moved = 0;
for i in 0..100 {
let key = format!("key-{i}");
if let Some(new_node) = ring_after.get_node_for_key(key.as_bytes()) {
if let Some(old_node) = before.get(&key) {
if old_node != &new_node {
moved += 1;
}
}
}
}
assert!(
moved < 60,
"Too many keys moved: {moved}/100 (expected ~25%)"
);
}
// Previewing a new node on a populated ring reports at least one current owner.
#[test]
fn test_affected_ranges_on_add() {
let ring = three_node_ring();
let affected = ring.affected_ranges_on_add("node-d", 1);
assert!(!affected.is_empty());
}
// Previewing a removal reports successors that would take over node-b's ranges.
#[test]
fn test_affected_ranges_on_remove() {
let ring = three_node_ring();
let affected = ring.affected_ranges_on_remove("node-b");
assert!(!affected.is_empty());
let total_affected: usize = affected.values().sum();
assert!(total_affected > 0);
}
// 3 nodes x 150 vnodes = 450 positions.
#[test]
fn test_stats() {
let ring = three_node_ring();
let stats = ring.stats();
assert_eq!(stats.node_count, 3);
assert_eq!(stats.active_node_count, 3);
assert_eq!(stats.vnode_count, 450); }
// Loads a=2, b=1, c=0 give max 2 and min 0.
#[test]
fn test_stats_with_loads() {
let mut ring = three_node_ring();
ring.increment_load("node-a");
ring.increment_load("node-a");
ring.increment_load("node-b");
let stats = ring.stats();
assert_eq!(stats.max_load, 2);
assert_eq!(stats.min_load, 0);
}
// Equal loads give a (near) zero standard deviation.
#[test]
fn test_stats_stddev() {
let mut ring = three_node_ring();
ring.increment_load("node-a");
ring.increment_load("node-b");
ring.increment_load("node-c");
let stats = ring.stats();
assert!(stats.load_stddev < 0.001);
}
// With a single node every key maps to it.
#[test]
fn test_single_node_ring() {
let mut ring = default_ring();
ring.add_node(ClusterNode::new("solo", "10.0.0.1:8080"));
let node = ring.get_node_for_key(b"any-key");
assert_eq!(node, Some("solo".to_string()));
}
// With every node inactive, lookups find nothing even though vnodes remain.
#[test]
fn test_all_nodes_inactive() {
let mut ring = three_node_ring();
ring.deactivate_node("node-a");
ring.deactivate_node("node-b");
ring.deactivate_node("node-c");
assert!(ring.get_node_for_key(b"key").is_none());
}
// Only node-c is active, so it is the sole replica.
#[test]
fn test_replicas_with_one_active() {
let mut ring = three_node_ring();
ring.deactivate_node("node-a");
ring.deactivate_node("node-b");
let replicas = ring.get_replicas(b"key");
assert_eq!(replicas.len(), 1);
assert_eq!(replicas[0], "node-c");
}
// hash_key is deterministic.
#[test]
fn test_hash_key_consistency() {
let h1 = hash_key(b"test");
let h2 = hash_key(b"test");
assert_eq!(h1, h2);
}
// Different inputs hash differently (for these samples).
#[test]
fn test_hash_key_different_inputs() {
let h1 = hash_key(b"hello");
let h2 = hash_key(b"world");
assert_ne!(h1, h2);
}
// hash_vnode is deterministic.
#[test]
fn test_hash_vnode_consistency() {
let h1 = hash_vnode("node-a", 0);
let h2 = hash_vnode("node-a", 0);
assert_eq!(h1, h2);
}
// Different vnode indices of one node land on different positions.
#[test]
fn test_hash_vnode_different_indices() {
let h1 = hash_vnode("node-a", 0);
let h2 = hash_vnode("node-a", 1);
assert_ne!(h1, h2);
}
// Builder methods set weight and zone; new nodes start active.
#[test]
fn test_cluster_node_builder() {
let node = ClusterNode::new("n1", "10.0.0.1:8080")
.with_weight(3)
.with_zone("us-east");
assert_eq!(node.id, "n1");
assert_eq!(node.weight, 3);
assert_eq!(node.zone, Some("us-east".to_string()));
assert!(node.active);
}
// ClusterNode::new defaults the weight to 1.
#[test]
fn test_cluster_node_default_weight() {
let node = ClusterNode::new("n1", "10.0.0.1:8080");
assert_eq!(node.weight, 1);
}
// Default config values the fixtures above rely on.
#[test]
fn test_default_config() {
let cfg = HashRingConfig::default();
assert_eq!(cfg.vnodes_per_node, 150);
assert_eq!(cfg.replica_count, 3);
assert!(cfg.load_factor.is_some());
}
// config() exposes the configuration the ring was built with.
#[test]
fn test_config_access() {
let ring = default_ring();
assert_eq!(ring.config().vnodes_per_node, 150);
}
}