use pollen_types::NodeId;
use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};
/// A consistent-hash ring that maps byte keys onto a set of member nodes.
///
/// Each node is placed on the ring at `replicas` positions ("virtual nodes"),
/// derived by hashing `"{node}:{i}"`. A key is owned by the first virtual node
/// at or after the key's hash, wrapping around to the lowest position, so
/// adding or removing a node only reassigns keys adjacent to that node's
/// positions.
#[derive(Debug, Clone)]
pub struct HashRing {
    /// Ring position -> node owning that position.
    ring: BTreeMap<u64, NodeId>,
    /// Virtual nodes placed per member node (always >= 1).
    replicas: usize,
    /// Member nodes in insertion order, without duplicates.
    nodes: Vec<NodeId>,
}

impl HashRing {
    /// Creates an empty ring that places `replicas` virtual nodes per node.
    ///
    /// A `replicas` of 0 is treated as 1: with no virtual nodes an added node
    /// would be counted by [`len`](Self::len) yet could never own a key.
    pub fn new(replicas: usize) -> Self {
        Self {
            ring: BTreeMap::new(),
            replicas: replicas.max(1),
            nodes: Vec::new(),
        }
    }

    /// Adds `node` and its virtual nodes to the ring.
    ///
    /// Adding a node that is already a member is a no-op.
    pub fn add(&mut self, node: NodeId) {
        if self.nodes.contains(&node) {
            return;
        }
        self.nodes.push(node);
        for i in 0..self.replicas {
            let hash = self.vnode_hash(node, i);
            // On a (vanishingly unlikely) 64-bit collision with another node's
            // virtual node, the position is handed to `node`.
            self.ring.insert(hash, node);
        }
    }

    /// Removes `node` and its virtual nodes from the ring.
    ///
    /// Removing a node that is not a member leaves the ring unchanged.
    pub fn remove(&mut self, node: NodeId) {
        self.nodes.retain(|n| *n != node);
        for i in 0..self.replicas {
            let hash = self.vnode_hash(node, i);
            // Only drop positions still owned by `node`, so a hash collision
            // can never strip a virtual node that belongs to another node.
            if self.ring.get(&hash) == Some(&node) {
                self.ring.remove(&hash);
            }
        }
    }

    /// Removes every node; the configured replica count is kept.
    pub fn clear(&mut self) {
        self.ring.clear();
        self.nodes.clear();
    }

    /// Returns the node that owns `key`, or `None` if the ring is empty.
    ///
    /// The owner is the first virtual node at or after the key's hash,
    /// wrapping around to the lowest position on the ring.
    pub fn get(&self, key: &[u8]) -> Option<&NodeId> {
        let hash = self.hash(key);
        self.ring
            .range(hash..)
            .next()
            .or_else(|| self.ring.iter().next())
            .map(|(_, node)| node)
    }

    /// Returns up to `n` distinct nodes for `key`, walking the ring clockwise
    /// from the key's position.
    ///
    /// The first element (if any) is the same node [`get`](Self::get) returns.
    /// Fewer than `n` nodes are returned when the ring has fewer members, and an
    /// empty vector is returned when `n == 0` or the ring is empty.
    pub fn get_n(&self, key: &[u8], n: usize) -> Vec<NodeId> {
        let want = n.min(self.nodes.len());
        if want == 0 || self.ring.is_empty() {
            return Vec::new();
        }
        let hash = self.hash(key);
        let mut result = Vec::with_capacity(want);
        let mut seen = std::collections::HashSet::with_capacity(want);
        // Positions >= hash first, then wrap to the positions below it.
        let clockwise = self.ring.range(hash..).chain(self.ring.range(..hash));
        for (_, &node) in clockwise {
            if seen.insert(node) {
                result.push(node);
                if result.len() == want {
                    break;
                }
            }
        }
        result
    }

    /// Returns `true` if the ring has no member nodes (consistent with [`len`](Self::len)).
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Returns the number of member nodes (not virtual nodes).
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the member nodes in insertion order.
    pub fn nodes(&self) -> &[NodeId] {
        &self.nodes
    }

    /// Ring position of virtual node `i` of `node`.
    ///
    /// Shared by `add` and `remove` so both always agree on the key format.
    fn vnode_hash(&self, node: NodeId, i: usize) -> u64 {
        let label = format!("{}:{}", node, i);
        self.hash(label.as_bytes())
    }

    /// Hashes `key` onto a ring position.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is unspecified and may change
    /// between Rust releases, so positions are only guaranteed stable within a
    /// single build. If rings built by different binaries must agree (e.g. across
    /// a cluster), switch to a fixed, documented hash function.
    fn hash(&self, key: &[u8]) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    /// Builds a ring with `replicas` virtual nodes per node containing every node in `nodes`.
    fn ring_of(replicas: usize, nodes: &[NodeId]) -> HashRing {
        let mut ring = HashRing::new(replicas);
        for node in nodes {
            ring.add(*node);
        }
        ring
    }

    #[test]
    fn test_empty_ring() {
        let ring = HashRing::new(10);
        assert!(ring.get(b"test").is_none());
        assert!(ring.get_n(b"test", 3).is_empty());
        assert!(ring.is_empty());
        assert_eq!(ring.len(), 0);
    }

    #[test]
    fn test_single_node() {
        let node = NodeId::from_raw(1);
        let ring = ring_of(10, &[node]);
        assert_eq!(ring.get(b"key1"), Some(&node));
        assert_eq!(ring.get(b"key2"), Some(&node));
        assert_eq!(ring.get(b"key3"), Some(&node));
        assert_eq!(ring.get_n(b"key1", 3), vec![node]);
    }

    #[test]
    fn test_duplicate_add_is_ignored() {
        let node = NodeId::from_raw(1);
        let mut ring = ring_of(10, &[node]);
        ring.add(node);
        assert_eq!(ring.len(), 1);
        assert_eq!(ring.nodes(), &[node][..]);
    }

    #[test]
    fn test_multiple_nodes() {
        let nodes: Vec<_> = (1..=3).map(NodeId::from_raw).collect();
        let ring = ring_of(100, &nodes);
        assert_eq!(ring.len(), 3);
        let mut distribution = HashMap::new();
        for i in 0..1000 {
            let key = format!("key{}", i);
            if let Some(node) = ring.get(key.as_bytes()) {
                *distribution.entry(*node).or_insert(0) += 1;
            }
        }
        // Every node must own at least one key.
        for node in &nodes {
            assert!(distribution.contains_key(node));
        }
    }

    #[test]
    fn test_get_n() {
        let nodes: Vec<_> = (1..=5).map(NodeId::from_raw).collect();
        let ring = ring_of(100, &nodes);
        let replicas = ring.get_n(b"test", 3);
        assert_eq!(replicas.len(), 3);
        let unique: HashSet<_> = replicas.iter().collect();
        assert_eq!(unique.len(), 3);
        // The primary replica is the key's owner.
        assert_eq!(ring.get(b"test"), Some(&replicas[0]));
        // Asking for no replicas yields none.
        assert!(ring.get_n(b"test", 0).is_empty());
    }

    #[test]
    fn test_get_n_caps_at_node_count() {
        let nodes: Vec<_> = (1..=3).map(NodeId::from_raw).collect();
        let ring = ring_of(100, &nodes);
        let all = ring.get_n(b"test", 10);
        assert_eq!(all.len(), 3);
        let unique: HashSet<_> = all.iter().collect();
        assert_eq!(unique.len(), 3);
    }

    #[test]
    fn test_node_removal() {
        let node1 = NodeId::from_raw(1);
        let node2 = NodeId::from_raw(2);
        let mut ring = ring_of(100, &[node1, node2]);
        ring.remove(node1);
        assert_eq!(ring.len(), 1);
        assert_eq!(ring.get(b"test"), Some(&node2));
        // Removing a node that is not a member changes nothing.
        ring.remove(NodeId::from_raw(42));
        assert_eq!(ring.len(), 1);
        assert_eq!(ring.get(b"test"), Some(&node2));
    }

    #[test]
    fn test_clear() {
        let nodes: Vec<_> = (1..=3).map(NodeId::from_raw).collect();
        let mut ring = ring_of(10, &nodes);
        ring.clear();
        assert!(ring.is_empty());
        assert_eq!(ring.len(), 0);
        assert!(ring.get(b"test").is_none());
    }

    #[test]
    fn test_consistent_hashing() {
        let nodes: Vec<_> = (1..=3).map(NodeId::from_raw).collect();
        let mut ring = ring_of(100, &nodes);
        let mut assignments: HashMap<String, NodeId> = HashMap::new();
        for i in 0..100 {
            let key = format!("key{}", i);
            if let Some(node) = ring.get(key.as_bytes()) {
                assignments.insert(key, *node);
            }
        }

        let new_node = NodeId::from_raw(4);
        ring.add(new_node);
        let mut unchanged = 0;
        for (key, old_node) in &assignments {
            let owner = *ring.get(key.as_bytes()).expect("ring is non-empty");
            if owner == *old_node {
                unchanged += 1;
            } else {
                // Consistent hashing: keys only ever move to the node that joined.
                assert_eq!(owner, new_node, "key {} moved to an existing node", key);
            }
        }
        assert!(unchanged >= 70);

        // Removing the new node restores every original assignment.
        ring.remove(new_node);
        for (key, old_node) in &assignments {
            assert_eq!(ring.get(key.as_bytes()), Some(old_node));
        }
    }
}