use std::collections::hash_map;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
/// Consistent-hash ring that maps hashable values onto keys of type `K`.
///
/// Each key added is represented by one or more virtual nodes whose positions
/// on the ring are computed with the hasher `H` (the standard library's
/// `DefaultHasher` unless another hasher type is given).
pub struct ConsistentHash<K, H = hash_map::DefaultHasher>
where
    K: Hash,
    H: Hasher,
{
    /// Virtual nodes, kept ordered by their hash value.
    ring: Ring<K, H>,
}
/// One replica of a key on the ring: `.0` is the replica id, `.1` the key.
///
/// The derived `Hash` feeds the replica id first and then the key; that hash
/// is what places the replica on the ring.
// NOTE(review): `Vitual` is a typo for `Virtual`; renaming needs every use site updated.
#[derive(Hash)]
struct VitualNode<K>(u8, K)
where
    K: Hash;
/// A virtual node together with its precomputed position on the ring.
#[derive(Hash)]
struct RingEntry<K>
where
    K: Hash,
{
    /// Hash of `node`; a `Ring` keeps its entries ordered by this value.
    sum: u64,
    /// The virtual node that occupies this position.
    node: VitualNode<K>,
}
/// Vector of [`RingEntry`]s kept sorted by `sum`, plus the hasher type `H`
/// used to compute those sums.
///
/// `H` itself is never stored: every hash builds a fresh `H::default()`.
/// The marker is therefore `PhantomData<fn() -> H>`, which records that a
/// `Ring` only *produces* `H` values. Unlike `PhantomData<H>`, it does not
/// make `Ring` (or `ConsistentHash`) `!Send`/`!Sync` when `H` happens to be.
struct Ring<K: Hash, H: Hasher> {
    // Only ever set to `None` in this file. The `Option` wrapper is kept so
    // that existing `_h: None` initialisers keep compiling.
    _h: Option<PhantomData<fn() -> H>>,
    pub ring: Vec<RingEntry<K>>,
}
impl<K: Hash + Ord, H: Hasher + Default> ConsistentHash<K, H> {
    /// Creates an empty ring with no keys.
    pub fn new() -> Self {
        Self { ring: Ring::new() }
    }
}

impl<K: Hash + Ord, H: Hasher + Default> Default for ConsistentHash<K, H> {
    /// Equivalent to [`ConsistentHash::new`]: an empty ring.
    fn default() -> Self {
        Self::new()
    }
}
impl<K: Hash + Ord + Clone, H: Hasher + Default> ConsistentHash<K, H> {
    /// Adds `k` to the ring as `replicas` virtual nodes with ids `0..replicas`.
    ///
    /// `replicas` is a `u8`, so a key can have at most 255 virtual nodes;
    /// `replicas == 0` adds nothing.
    pub fn add(&mut self, k: K, replicas: u8) {
        for id in 0..replicas {
            self.ring.add(VitualNode(id, k.clone()))
        }
    }

    /// Removes every virtual node whose key compares equal to `q`.
    pub fn remove<Q: ?Sized + PartialEq<K>>(&mut self, q: &Q) {
        self.ring.remove(q)
    }

    /// Returns the key responsible for `q`.
    ///
    /// # Panics
    ///
    /// Panics if no keys have been added (the ring is empty). Use
    /// [`ConsistentHash::try_pick`] to get `None` instead.
    pub fn pick<Q: Hash + ?Sized>(&self, q: &Q) -> &K {
        self.try_pick(q)
            .expect("ConsistentHash::pick called on an empty ring")
    }

    /// Returns the key responsible for `q`, or `None` if the ring is empty.
    pub fn try_pick<Q: Hash + ?Sized>(&self, q: &Q) -> Option<&K> {
        if self.ring.ring.is_empty() {
            None
        } else {
            Some(self.ring.pick_node(q))
        }
    }
}
impl<K: Hash> VitualNode<K> {
fn new<Q: Into<K>>(id: u8, k: Q) -> Self {
Self(id, k.into())
}
}
impl<K: Hash + Ord, H: Hasher + Default> Ring<K, H> {
    /// Creates an empty ring.
    pub fn new() -> Self {
        Self {
            _h: None,
            ring: Vec::new(),
        }
    }

    /// Hashes `value` with a fresh `H::default()`; the result is a ring position.
    fn hash_of<T: Hash + ?Sized>(value: &T) -> u64 {
        let mut h = H::default();
        value.hash(&mut h);
        h.finish()
    }

    /// Inserts `node` at the position given by its hash, keeping the ring sorted.
    ///
    /// Entries are ordered by `(sum, key, replica id)`. When different virtual
    /// nodes hash to the same `sum`, all of them are kept and the smallest key
    /// sorts first. Lookups therefore resolve the same way regardless of
    /// insertion order, and removing one key hands its slot to the next.
    /// Adding a virtual node that is already present (same sum, key and id)
    /// is a no-op.
    pub fn add(&mut self, node: VitualNode<K>) {
        let sum = Self::hash_of(&node);
        let index = self
            .ring
            .partition_point(|e| (e.sum, &e.node.1, e.node.0) < (sum, &node.1, node.0));
        let already_present = self
            .ring
            .get(index)
            .map_or(false, |e| e.sum == sum && e.node.0 == node.0 && e.node.1 == node.1);
        if !already_present {
            // `Vec::insert` shifts the tail in place; no temporary allocation.
            self.ring.insert(index, RingEntry { sum, node });
        }
    }

    /// Removes every entry whose key compares equal to `k`, preserving order.
    pub fn remove<Q: ?Sized + PartialEq<K>>(&mut self, k: &Q) {
        self.ring.retain(|e| k != &e.node.1);
    }

    /// Returns the key of the first entry whose `sum` is >= the hash of `q`.
    ///
    /// If the hash of `q` is past the last entry, this wraps to the first entry.
    ///
    /// # Panics
    ///
    /// Panics if the ring is empty.
    pub fn pick_node<Q: Hash + ?Sized>(&self, q: &Q) -> &K {
        assert!(!self.ring.is_empty(), "cannot pick a node from an empty ring");
        // `find` returns `len()` past the last entry; `%` wraps that to 0.
        let i = self.find(Self::hash_of(q)) % self.ring.len();
        &self.ring[i].node.1
    }

    /// Index of the first entry with `entry.sum >= sum`, or `len()` if there is none.
    ///
    /// Relies on `self.ring` being sorted by `sum`.
    fn find(&self, sum: u64) -> usize {
        self.ring.partition_point(|e| e.sum < sum)
    }
}
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::collections::HashMap;

    use bytes::Bytes;

    use super::{ConsistentHash, Ring, RingEntry, VitualNode};

    /// `find` returns the index of the first entry whose `sum` is >= the probe,
    /// and `ring.len()` once the probe is past the last entry.
    #[test]
    fn test_ring_find() {
        let ring = Ring::<String, DefaultHasher> {
            _h: None,
            ring: vec![
                RingEntry {
                    sum: 1,
                    node: VitualNode::new(0, ""),
                },
                RingEntry {
                    sum: 2,
                    node: VitualNode::new(0, String::new()),
                },
                RingEntry {
                    sum: 4,
                    node: VitualNode::new(0, ""),
                },
                RingEntry {
                    sum: 6,
                    node: VitualNode::new(0, ""),
                },
            ],
        };
        assert_eq!(ring.find(0), 0);
        assert_eq!(ring.find(1), 0);
        assert_eq!(ring.find(2), 1);
        assert_eq!(ring.find(3), 2);
        assert_eq!(ring.find(4), 2);
        assert_eq!(ring.find(5), 3);
        assert_eq!(ring.find(6), 3);
        assert_eq!(ring.find(7), 4);
        assert_eq!(ring.find(100), 4);
    }

    /// Adding keeps the ring strictly sorted by hash, re-adding an identical
    /// virtual node is ignored, and `remove` drops every replica of a key.
    #[test]
    fn test_ring() {
        let mut ring = Ring::<String, DefaultHasher>::new();
        ring.add(VitualNode::new(0, "8080"));
        ring.add(VitualNode::new(1, "8080"));
        ring.add(VitualNode::new(0, "8081"));
        ring.add(VitualNode::new(1, "8081"));
        ring.add(VitualNode::new(1, "8082"));
        ring.add(VitualNode::new(0, "8082"));
        // Duplicate of the previous virtual node; must not grow the ring.
        ring.add(VitualNode::new(0, "8082"));
        assert_eq!(ring.ring.len(), 6);
        for i in 0..ring.ring.len() - 1 {
            assert!(ring.ring[i].sum < ring.ring[i + 1].sum);
        }
        ring.remove("8080");
        assert_eq!(ring.ring.len(), 4);
        ring.remove("8081");
        assert_eq!(ring.ring.len(), 2);
        ring.remove("8082");
        assert_eq!(ring.ring.len(), 0);
    }

    /// Picks are spread roughly evenly across nodes. Removing a node leaves it
    /// with no picks and splits its share between the remaining nodes.
    #[test]
    fn test_consistent_hash() {
        let mut map = ConsistentHash::<Bytes>::new();
        map.add("8080".into(), 255);
        map.add("8081".into(), 255);
        map.add("8082".into(), 255);
        assert_eq!(map.ring.ring.len(), 255 * 3);

        let mut count_map = [b"8080", b"8081", b"8082"]
            .into_iter()
            .map(|node| (Bytes::from_static(node), 0))
            .collect::<HashMap<_, _>>();
        let round = 1000000;
        // Sample exactly `round` keys so the bounds below are true fractions of `round`.
        for i in 0..round {
            // Every pick must be one of the added nodes; a miss is a bug, not noise.
            *count_map
                .get_mut(map.pick(&i))
                .expect("pick returned a node that was never added") += 1;
        }
        // Three nodes: each should own between 30% and ~36.7% of the keys.
        for (n, &c) in count_map.iter() {
            let node = std::str::from_utf8(n).unwrap();
            assert!(c > round * 3 / 10, "count of {}: {}", node, c);
            assert!(c < round * 11 / 30, "count of {}: {}", node, c);
        }

        map.remove("8082");
        for c in count_map.values_mut() {
            *c = 0;
        }
        for i in 0..round {
            *count_map
                .get_mut(map.pick(&i))
                .expect("pick returned a node that was never added") += 1;
        }
        // Two nodes left: each should own between 45% and 55%; the removed one none.
        for (n, &c) in count_map.iter() {
            let node = std::str::from_utf8(n).unwrap();
            if node == "8082" {
                assert_eq!(c, 0)
            } else {
                assert!(c > round * 9 / 20, "count of {}: {}", node, c);
                assert!(c < round * 11 / 20, "count of {}: {}", node, c);
            }
        }
    }
}