use std::collections::{BTreeMap, HashMap};
use std::fmt;
/// Initial hash state for the 64-bit FNV-1a hash.
const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
/// Multiplier applied after each input byte has been mixed in.
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

/// Hashes `data` with 64-bit FNV-1a.
///
/// Starting from [`FNV_OFFSET_BASIS`], every byte is XOR-ed into the running
/// state, which is then multiplied by [`FNV_PRIME`] with wrapping arithmetic.
/// Empty input therefore hashes to the offset basis itself.
fn fnv1a_64(data: &[u8]) -> u64 {
    data.iter().fold(FNV_OFFSET_BASIS, |state, &octet| {
        // XOR first, multiply second: this order is what makes it the "1a" variant.
        (state ^ u64::from(octet)).wrapping_mul(FNV_PRIME)
    })
}
/// Identifier of a node, wrapping a raw `u64`.
///
/// The derived ordering and hashing follow the wrapped integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub u64);

impl fmt::Display for NodeId {
    /// Renders the id as `node:<raw>`, for example `node:42`.
    ///
    /// `ConsistentHash::add_node` builds its virtual-node labels from this
    /// text, so changing the format would move every ring position.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(raw) = self;
        f.write_fmt(format_args!("node:{raw}"))
    }
}
/// Consistent-hash ring that maps hashed positions to the node owning them.
///
/// Every real node is placed on the ring at `virtual_nodes_per_node`
/// positions. A key is served by the first position at or after the key's own
/// hash, wrapping around to the smallest position when none is found.
#[derive(Debug, Clone)]
pub struct ConsistentHash {
    /// Ring positions (hashes of virtual-node labels) mapped to the owning node.
    ring: BTreeMap<u64, NodeId>,
    /// Number of ring positions created per real node; never less than 1.
    virtual_nodes_per_node: u32,
}

impl ConsistentHash {
    /// Creates an empty ring.
    ///
    /// `virtual_nodes` is the number of positions each added node receives;
    /// a value of `0` is raised to `1` so that an added node is always routable.
    pub fn new(virtual_nodes: u32) -> Self {
        Self {
            ring: BTreeMap::new(),
            virtual_nodes_per_node: virtual_nodes.max(1),
        }
    }

    /// Places `node_id` on the ring at its virtual-node positions.
    ///
    /// Positions are derived from the label `"<node_id>_<index>"`, so adding
    /// the same node again lands on the same positions and changes nothing.
    /// If a position is already held by a different node, the newly added
    /// node takes it over.
    pub fn add_node(&mut self, node_id: NodeId) {
        for i in 0..self.virtual_nodes_per_node {
            let label = format!("{node_id}_{i}");
            let pos = fnv1a_64(label.as_bytes());
            self.ring.insert(pos, node_id);
        }
    }

    /// Removes every ring position owned by `node_id`.
    ///
    /// Removing a node that is not on the ring is a no-op.
    pub fn remove_node(&mut self, node_id: NodeId) {
        // Single in-place pass instead of collecting positions and removing
        // them one lookup at a time.
        self.ring.retain(|_, owner| *owner != node_id);
    }

    /// Returns the node responsible for `key`, or `None` if the ring is empty.
    pub fn get_node(&self, key: &[u8]) -> Option<NodeId> {
        let pos = fnv1a_64(key);
        self.ring
            .range(pos..)
            .next()
            // Nothing at or after `pos`: wrap around to the start of the ring.
            // On an empty ring this also yields `None`.
            .or_else(|| self.ring.iter().next())
            .map(|(_, &nid)| nid)
    }

    /// Returns up to `n` distinct nodes for `key`, walking the ring clockwise
    /// from the key's position.
    ///
    /// The first element, when present, is the node [`get_node`] returns for
    /// the same key. Fewer than `n` nodes are returned when the ring holds
    /// fewer distinct nodes; an empty vector is returned when `n == 0` or the
    /// ring is empty.
    ///
    /// [`get_node`]: ConsistentHash::get_node
    pub fn get_n_nodes(&self, key: &[u8], n: usize) -> Vec<NodeId> {
        // There can never be more distinct nodes than ring positions. Clamping
        // before allocating keeps a very large `n` (e.g. `usize::MAX` meaning
        // "all of them") from panicking with a capacity overflow or reserving
        // an enormous buffer.
        let wanted = n.min(self.ring.len());
        if wanted == 0 {
            return Vec::new();
        }

        let pos = fnv1a_64(key);
        let clockwise = self
            .ring
            .range(pos..)
            .chain(self.ring.range(..pos))
            .map(|(_, &nid)| nid);

        let mut picked: Vec<NodeId> = Vec::with_capacity(wanted);
        for node in clockwise {
            // `picked` stays tiny (at most `wanted` entries), so a linear
            // scan is sufficient for de-duplication.
            if !picked.contains(&node) {
                picked.push(node);
                if picked.len() == wanted {
                    break;
                }
            }
        }
        picked
    }

    /// Number of positions currently on the ring.
    pub fn virtual_node_count(&self) -> usize {
        self.ring.len()
    }

    /// Number of distinct nodes that own at least one ring position.
    pub fn real_node_count(&self) -> usize {
        let mut nodes: Vec<NodeId> = self.ring.values().copied().collect();
        nodes.sort_unstable();
        nodes.dedup();
        nodes.len()
    }
}
/// A node's view of the cluster: its own identity plus the ring it routes with.
#[derive(Debug, Clone)]
pub struct DistributedCacheClient {
    /// Identity of the node this client runs on.
    pub local_node: NodeId,
    /// Ring consulted to find the owner of a key.
    pub ring: ConsistentHash,
}

impl DistributedCacheClient {
    /// Bundles a node identity with the ring it should route against.
    pub fn new(local_node: NodeId, ring: ConsistentHash) -> Self {
        DistributedCacheClient { ring, local_node }
    }

    /// Returns the node that owns `key` according to the ring.
    ///
    /// When the ring has no owner to offer, the key is routed to the local node.
    pub fn route_key(&self, key: &[u8]) -> NodeId {
        match self.ring.get_node(key) {
            Some(owner) => owner,
            None => self.local_node,
        }
    }

    /// Reports whether `key` routes to this client's own node.
    pub fn is_local_key(&self, key: &[u8]) -> bool {
        let owner = self.route_key(key);
        self.local_node == owner
    }
}
/// Acknowledgement thresholds for reads and writes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReplicationFactor {
    /// Minimum number of responses for a read to count as satisfied.
    pub reads: u8,
    /// Minimum number of responses for a write to count as satisfied.
    pub writes: u8,
}

impl ReplicationFactor {
    /// Creates thresholds from explicit read and write counts.
    pub fn new(reads: u8, writes: u8) -> Self {
        ReplicationFactor { reads, writes }
    }

    /// Returns `true` when `responses` reaches the read threshold.
    pub fn is_quorum_read_met(&self, responses: u8) -> bool {
        self.reads <= responses
    }

    /// Returns `true` when `responses` reaches the write threshold.
    pub fn is_quorum_write_met(&self, responses: u8) -> bool {
        self.writes <= responses
    }

    /// Preset requiring two responses for both reads and writes.
    ///
    /// NOTE(review): the name suggests a three-replica deployment, but the
    /// replica count itself is not stored in this type — confirm with callers.
    pub fn rf3() -> Self {
        Self::new(2, 2)
    }

    /// Preset requiring three responses for both reads and writes.
    pub fn rf3_strong() -> Self {
        Self::new(3, 3)
    }
}

impl Default for ReplicationFactor {
    /// Defaults to the [`ReplicationFactor::rf3`] preset.
    fn default() -> Self {
        ReplicationFactor::rf3()
    }
}
/// Keeps one [`DistributedCacheClient`] per node together with the quorum
/// thresholds used for availability checks.
#[derive(Debug)]
pub struct CacheCoordinator {
    /// Registered clients, keyed by the node id they were added under.
    pub clients: HashMap<NodeId, DistributedCacheClient>,
    /// Read/write thresholds applied by the quorum checks.
    pub replication: ReplicationFactor,
}

impl CacheCoordinator {
    /// Creates a coordinator with no clients and the given thresholds.
    pub fn new(replication: ReplicationFactor) -> Self {
        Self {
            clients: HashMap::new(),
            replication,
        }
    }

    /// Registers `client` under its `local_node`, replacing any client that
    /// was previously registered for the same node.
    pub fn add_client(&mut self, client: DistributedCacheClient) {
        self.clients.insert(client.local_node, client);
    }

    /// Unregisters the client stored under `node_id`, if any.
    pub fn remove_client(&mut self, node_id: NodeId) {
        self.clients.remove(&node_id);
    }

    /// Returns the node that owns `key` as seen by the reference client, or
    /// `None` when no client is registered.
    pub fn primary_node_for(&self, key: &[u8]) -> Option<NodeId> {
        self.reference_client().map(|c| c.route_key(key))
    }

    /// Returns up to `n` distinct replica nodes for `key` as seen by the
    /// reference client; empty when no client is registered.
    pub fn replica_nodes_for(&self, key: &[u8], n: usize) -> Vec<NodeId> {
        self.reference_client()
            .map(|c| c.ring.get_n_nodes(key, n))
            .unwrap_or_default()
    }

    /// Reports whether a write to `key` can be acknowledged given the nodes
    /// in `available_nodes`.
    ///
    /// Only the first `replication.writes` replicas of the key are consulted,
    /// so the result is `true` only when every one of them is available. If
    /// the ring offers fewer replicas than `writes`, the quorum cannot be met.
    pub fn can_write_quorum(&self, key: &[u8], available_nodes: &[NodeId]) -> bool {
        let acks = self.available_replicas(key, self.replication.writes, available_nodes);
        self.replication.is_quorum_write_met(acks)
    }

    /// Reports whether a read of `key` can be served given the nodes in
    /// `available_nodes`.
    ///
    /// Only the first `replication.reads` replicas of the key are consulted,
    /// so the result is `true` only when every one of them is available. If
    /// the ring offers fewer replicas than `reads`, the quorum cannot be met.
    pub fn can_read_quorum(&self, key: &[u8], available_nodes: &[NodeId]) -> bool {
        let responses = self.available_replicas(key, self.replication.reads, available_nodes);
        self.replication.is_quorum_read_met(responses)
    }

    /// Number of registered clients.
    pub fn node_count(&self) -> usize {
        self.clients.len()
    }

    /// Picks the client whose ring answers routing queries.
    ///
    /// `HashMap` iteration order is unspecified, so taking whichever client
    /// happens to come first would let the answer vary between coordinator
    /// instances whenever the clients' rings disagree. Always using the
    /// smallest key makes the choice stable.
    fn reference_client(&self) -> Option<&DistributedCacheClient> {
        self.clients
            .iter()
            .min_by_key(|(id, _)| **id)
            .map(|(_, client)| client)
    }

    /// Counts how many of the first `wanted` replicas of `key` appear in
    /// `available_nodes`.
    fn available_replicas(&self, key: &[u8], wanted: u8, available_nodes: &[NodeId]) -> u8 {
        let hits = self
            .replica_nodes_for(key, usize::from(wanted))
            .iter()
            .filter(|&nid| available_nodes.contains(nid))
            .count();
        // `hits` cannot exceed `wanted`, but clamp before narrowing so the
        // cast can never silently wrap.
        hits.min(usize::from(u8::MAX)) as u8
    }
}
/// Unit tests for the ring, the client, the quorum settings and the coordinator.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ring with `vn` virtual nodes per node and one node per id in `ids`.
    fn make_ring_with_nodes(vn: u32, ids: &[u64]) -> ConsistentHash {
        let mut ring = ConsistentHash::new(vn);
        for &id in ids {
            ring.add_node(NodeId(id));
        }
        ring
    }

    #[test]
    fn test_node_id_display() {
        let nid = NodeId(42);
        assert_eq!(format!("{nid}"), "node:42");
    }

    #[test]
    fn test_empty_ring_get_node() {
        let ring = ConsistentHash::new(10);
        assert!(ring.get_node(b"any_key").is_none());
    }

    #[test]
    fn test_single_node_routing() {
        let ring = make_ring_with_nodes(20, &[1]);
        for key in [b"a".as_ref(), b"hello", b"oximedia"] {
            assert_eq!(ring.get_node(key), Some(NodeId(1)));
        }
    }

    #[test]
    fn test_two_nodes_split_keyspace() {
        let ring = make_ring_with_nodes(150, &[1, 2]);
        let mut counts = [0usize; 2];
        for i in 0u32..1000 {
            let key = i.to_le_bytes();
            match ring.get_node(&key) {
                Some(NodeId(1)) => counts[0] += 1,
                Some(NodeId(2)) => counts[1] += 1,
                // A populated ring must always answer with one of its own
                // nodes; fail loudly instead of silently skipping the key.
                other => panic!("unexpected routing result: {:?}", other),
            }
        }
        assert!(counts[0] > 100, "node 1 got too few keys: {}", counts[0]);
        assert!(counts[1] > 100, "node 2 got too few keys: {}", counts[1]);
    }

    #[test]
    fn test_virtual_node_count() {
        let ring = make_ring_with_nodes(50, &[1, 2, 3]);
        assert_eq!(ring.virtual_node_count(), 150);
    }

    #[test]
    fn test_real_node_count() {
        let ring = make_ring_with_nodes(20, &[10, 20, 30, 40]);
        assert_eq!(ring.real_node_count(), 4);
    }

    #[test]
    fn test_remove_node() {
        let mut ring = make_ring_with_nodes(10, &[1, 2]);
        ring.remove_node(NodeId(1));
        assert_eq!(ring.real_node_count(), 1);
        assert_eq!(ring.virtual_node_count(), 10);
        for i in 0u32..50 {
            assert_eq!(ring.get_node(&i.to_le_bytes()), Some(NodeId(2)));
        }
    }

    #[test]
    fn test_add_node_twice_does_not_double_positions() {
        let mut ring = ConsistentHash::new(10);
        ring.add_node(NodeId(7));
        ring.add_node(NodeId(7));
        assert!(ring.virtual_node_count() <= 10);
    }

    #[test]
    fn test_get_n_nodes_distinct() {
        let ring = make_ring_with_nodes(100, &[1, 2, 3]);
        let nodes = ring.get_n_nodes(b"replicated_key", 3);
        assert_eq!(nodes.len(), 3);
        let unique: std::collections::HashSet<_> = nodes.iter().copied().collect();
        assert_eq!(unique.len(), 3);
    }

    #[test]
    fn test_get_n_nodes_exceeds_real_count() {
        let ring = make_ring_with_nodes(50, &[1, 2]);
        let nodes = ring.get_n_nodes(b"key", 10);
        assert_eq!(nodes.len(), 2);
    }

    #[test]
    fn test_get_n_nodes_zero() {
        let ring = make_ring_with_nodes(50, &[1, 2, 3]);
        assert!(ring.get_n_nodes(b"key", 0).is_empty());
    }

    #[test]
    fn test_get_n_nodes_empty_ring() {
        let ring = ConsistentHash::new(10);
        assert!(ring.get_n_nodes(b"key", 3).is_empty());
    }

    #[test]
    fn test_consistent_routing() {
        let ring = make_ring_with_nodes(100, &[1, 2, 3, 4, 5]);
        for key in [b"video_001".as_ref(), b"audio_002", b"manifest"] {
            let first = ring.get_node(key);
            for _ in 0..10 {
                assert_eq!(ring.get_node(key), first, "routing is not deterministic");
            }
        }
    }

    #[test]
    fn test_distributed_cache_client_route() {
        let ring = make_ring_with_nodes(100, &[1, 2, 3]);
        let client = DistributedCacheClient::new(NodeId(1), ring);
        let routed = client.route_key(b"some_key");
        assert!(routed.0 >= 1 && routed.0 <= 3);
    }

    #[test]
    fn test_is_local_key_single_node() {
        let mut ring = ConsistentHash::new(50);
        ring.add_node(NodeId(99));
        let client = DistributedCacheClient::new(NodeId(99), ring);
        assert!(client.is_local_key(b"anything"));
    }

    #[test]
    fn test_replication_factor_read_quorum() {
        let rf = ReplicationFactor::new(2, 2);
        assert!(!rf.is_quorum_read_met(1));
        assert!(rf.is_quorum_read_met(2));
        assert!(rf.is_quorum_read_met(3));
    }

    #[test]
    fn test_replication_factor_write_quorum() {
        let rf = ReplicationFactor::new(2, 3);
        assert!(!rf.is_quorum_write_met(2));
        assert!(rf.is_quorum_write_met(3));
    }

    #[test]
    fn test_rf3_defaults() {
        let rf = ReplicationFactor::rf3();
        assert_eq!(rf.reads, 2);
        assert_eq!(rf.writes, 2);
    }

    #[test]
    fn test_cache_coordinator_node_count() {
        let mut coord = CacheCoordinator::new(ReplicationFactor::rf3());
        let ring = make_ring_with_nodes(50, &[1, 2, 3]);
        for id in 1..=3u64 {
            coord.add_client(DistributedCacheClient::new(NodeId(id), ring.clone()));
        }
        assert_eq!(coord.node_count(), 3);
        coord.remove_client(NodeId(2));
        assert_eq!(coord.node_count(), 2);
    }

    #[test]
    fn test_can_write_quorum_all_nodes_up() {
        let ring = make_ring_with_nodes(100, &[1, 2, 3]);
        let mut coord = CacheCoordinator::new(ReplicationFactor::new(2, 2));
        for id in 1..=3u64 {
            coord.add_client(DistributedCacheClient::new(NodeId(id), ring.clone()));
        }
        let all_nodes = vec![NodeId(1), NodeId(2), NodeId(3)];
        assert!(coord.can_write_quorum(b"key", &all_nodes));
    }

    #[test]
    fn test_can_write_quorum_insufficient() {
        let ring = make_ring_with_nodes(100, &[1, 2, 3]);
        let mut coord = CacheCoordinator::new(ReplicationFactor::new(2, 3));
        for id in 1..=3u64 {
            coord.add_client(DistributedCacheClient::new(NodeId(id), ring.clone()));
        }
        let partial = vec![NodeId(1)];
        assert!(!coord.can_write_quorum(b"key", &partial));
    }

    #[test]
    fn test_primary_node_for() {
        let ring = make_ring_with_nodes(100, &[5, 6, 7]);
        let mut coord = CacheCoordinator::new(ReplicationFactor::default());
        coord.add_client(DistributedCacheClient::new(NodeId(5), ring));
        let primary = coord.primary_node_for(b"video_segment");
        assert!(primary.is_some());
    }

    #[test]
    fn test_primary_node_for_empty() {
        let coord = CacheCoordinator::new(ReplicationFactor::default());
        assert!(coord.primary_node_for(b"key").is_none());
    }

    #[test]
    fn test_routing_consistency_after_removal() {
        let mut ring = make_ring_with_nodes(100, &[1, 2, 3, 4, 5]);
        let key = b"stable_key";
        let before = ring.get_node(key);
        // Node 99 was never added, so removing it must not move any key.
        ring.remove_node(NodeId(99));
        let after = ring.get_node(key);
        assert_eq!(before, after, "routing changed when removing absent node");
    }

    #[test]
    fn test_uniform_distribution_three_nodes() {
        let ring = make_ring_with_nodes(200, &[1, 2, 3]);
        let mut counts: HashMap<u64, usize> = HashMap::new();
        for i in 0u32..3000 {
            let key = format!("key_{i}");
            if let Some(nid) = ring.get_node(key.as_bytes()) {
                *counts.entry(nid.0).or_insert(0) += 1;
            }
        }
        // A node that received no keys is absent from `counts`, so the
        // per-node bounds below would pass vacuously without this check.
        assert_eq!(
            counts.len(),
            3,
            "some node received no keys at all: {:?}",
            counts
        );
        for (node, count) in &counts {
            assert!(
                *count > 300 && *count < 2400,
                "node {node} has unbalanced load: {count} / 3000"
            );
        }
    }
}