use std::collections::BTreeMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Mutex;
use super::{node_set_fingerprint, SelectionContext, Strategy};
use crate::Node;
/// Consistent-hashing selection strategy.
///
/// Candidates are placed on a hash ring (`replicas` points per candidate) and a
/// request is routed to the candidate owning the first ring point at or after
/// the request's hash key. The most recently built ring is kept in `cache` and
/// reused while the candidate-set fingerprint stays the same.
#[derive(Debug, bon::Builder)]
pub struct ConsistentHash {
/// Number of ring points created for each candidate; the builder default is 150.
#[builder(default = 150)]
replicas: usize,
/// Last ring built, together with the fingerprint of the candidate set it was
/// built for. The builder always initialises this to an empty slot
/// (`Mutex::new(None)`); it is filled on the first selection.
#[builder(skip = Mutex::new(None))]
cache: Mutex<Option<CachedRing>>,
}
/// A built hash ring plus the fingerprint of the candidate set it belongs to.
#[derive(Debug)]
struct CachedRing {
/// Result of `node_set_fingerprint` for the candidates `ring` was built from;
/// compared against the current candidates to decide whether `ring` is reusable.
fingerprint: u64,
/// Ring points (hash values) mapped to the index of the owning candidate in
/// the slice that was passed when the ring was built.
ring: BTreeMap<u64, usize>,
}
impl Default for ConsistentHash {
fn default() -> Self {
Self::builder().build()
}
}
/// Builds the hash ring for `candidates`.
///
/// Every candidate contributes `replicas` points. A point is the
/// `DefaultHasher` digest of the candidate's id followed by the replica number,
/// and it maps to the candidate's index in `candidates`. Should two points
/// collide, the one inserted later (higher index, then higher replica number)
/// replaces the earlier one.
///
/// The ring is empty when `candidates` is empty or `replicas` is zero.
fn build_ring<N: Node>(candidates: &[N], replicas: usize) -> BTreeMap<u64, usize> {
    // Ring position of a single (candidate, replica) pair.
    let point_of = |node: &N, replica: usize| -> u64 {
        let mut state = DefaultHasher::new();
        node.id().hash(&mut state);
        replica.hash(&mut state);
        state.finish()
    };

    let mut ring = BTreeMap::new();
    for (slot, node) in candidates.iter().enumerate() {
        (0..replicas).for_each(|replica| {
            ring.insert(point_of(node, replica), slot);
        });
    }
    ring
}
impl<N: Node> Strategy<N> for ConsistentHash {
    /// Picks the candidate owning the first ring point at or after the request hash.
    ///
    /// The request hash is `ctx.hash_key`, or `0` when no key is set. Ring points
    /// are visited in ascending order starting at that hash, wrapping around to
    /// the lowest point, and the first point whose candidate index is not
    /// excluded by `ctx` determines the result.
    ///
    /// Returns `None` when `candidates` is empty, when every candidate on the
    /// ring is excluded, or when the ring has no points (`replicas == 0`).
    ///
    /// The ring is cached: it is rebuilt only when
    /// `node_set_fingerprint(candidates)` differs from the fingerprint stored
    /// with the cached ring. The cached ring stores slice indices, so this
    /// presumably relies on the fingerprint changing whenever the order of
    /// `candidates` changes (TODO confirm against `node_set_fingerprint`).
    fn select(&self, candidates: &[N], ctx: &SelectionContext) -> Option<usize> {
        if candidates.is_empty() {
            return None;
        }
        let request_hash = ctx.hash_key.unwrap_or(0);
        let fingerprint = node_set_fingerprint(candidates);

        // A poisoned lock only tells us that some thread panicked while holding
        // it. The slot is always either empty or a completely built ring (it is
        // replaced in one step, after `build_ring` has returned), so it is safe
        // to keep using it instead of panicking on every subsequent call.
        let mut cache = self
            .cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        let ring = match cache.as_ref() {
            Some(cached) if cached.fingerprint == fingerprint => &cached.ring,
            // Missing or stale: build a fresh ring and borrow it straight from
            // the slot it was stored in.
            _ => {
                &cache
                    .insert(CachedRing {
                        fingerprint,
                        ring: build_ring(candidates, self.replicas),
                    })
                    .ring
            }
        };

        // Walk from the request hash to the end of the ring, then wrap around and
        // cover only the points below the request hash, so no point is visited twice.
        ring.range(request_hash..)
            .chain(ring.range(..request_hash))
            .find(|(_, idx)| !ctx.is_excluded(**idx))
            .map(|(_, &idx)| idx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal node whose id is the wrapped string.
    struct N(String);

    impl crate::Node for N {
        type Id = String;

        fn id(&self) -> &String {
            &self.0
        }
    }

    /// Fixture: three nodes with ids "a", "b" and "c".
    fn abc() -> [N; 3] {
        ["a", "b", "c"].map(|id| N(id.into()))
    }

    /// Two selections with the same key and node set agree.
    #[test]
    fn consistent_for_same_key() {
        let strategy = ConsistentHash::default();
        let nodes = abc();
        let ctx = SelectionContext::builder().hash_key(42).build();

        let first = strategy.select(&nodes, &ctx);
        let second = strategy.select(&nodes, &ctx);

        assert_eq!(first, second);
    }

    /// Across 1000 hashed keys, at least two distinct nodes get selected.
    #[test]
    fn different_keys_can_hit_different_nodes() {
        use std::hash::{DefaultHasher, Hash, Hasher};

        let strategy = ConsistentHash::default();
        let nodes = abc();

        let mut chosen = std::collections::HashSet::new();
        for seed in 0_i32..1000 {
            let mut hasher = DefaultHasher::new();
            seed.hash(&mut hasher);
            let ctx = SelectionContext::builder().hash_key(hasher.finish()).build();
            chosen.insert(strategy.select(&nodes, &ctx).unwrap());
        }

        assert!(chosen.len() >= 2);
    }

    /// Results for a node set stay stable across repeated calls and after the
    /// strategy has been used with a different node set in between.
    #[test]
    fn ring_is_cached() {
        let strategy = ConsistentHash::builder().replicas(10).build();
        let ctx = SelectionContext::builder().hash_key(100).build();
        let pair = [N("a".into()), N("b".into())];
        let triple = abc();

        let before = strategy.select(&pair, &ctx);
        let repeated = strategy.select(&pair, &ctx);
        assert_eq!(before, repeated);

        // Presumably a different node set changes the fingerprint and so
        // replaces the cached ring; going back must reproduce the first answer.
        let _ = strategy.select(&triple, &ctx);
        let after = strategy.select(&pair, &ctx);
        assert_eq!(before, after);
    }

    /// Excluding the node normally chosen for a key makes selection fall
    /// through to a different node.
    #[test]
    fn skips_excluded_in_ring() {
        let strategy = ConsistentHash::default();
        let nodes = abc();

        let plain_ctx = SelectionContext::builder().hash_key(42).build();
        let normal = strategy.select(&nodes, &plain_ctx).unwrap();

        let excluding_ctx = SelectionContext::builder()
            .hash_key(42)
            .exclude(vec![normal])
            .build();
        let fallback = strategy.select(&nodes, &excluding_ctx);

        assert!(fallback.is_some());
        assert_ne!(fallback.unwrap(), normal);
    }
}