extern crate siphasher;
use siphasher::sip::SipHasher;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::hash::BuildHasher;
use std::hash::Hash;
/// The hasher factory used by `HashRing::new()` / `HashRing::default()`.
///
/// This is a stateless unit struct: every call to `build_hasher` hands out a
/// freshly constructed `SipHasher::new()`, so no per-instance key is involved
/// (per the `siphasher` docs, `new()` uses fixed zero keys). The hard-coded
/// node placements in this file's tests rely on that determinism.
#[derive(Debug, Clone, PartialEq)]
pub struct DefaultHashBuilder;

impl BuildHasher for DefaultHashBuilder {
    type Hasher = SipHasher;

    fn build_hasher(&self) -> SipHasher {
        SipHasher::new()
    }
}
/// One entry on the ring: a node value paired with its ring position.
///
/// Equality and ordering consider only `key`. As a result, `T` needs no
/// `PartialEq`/`Ord` of its own, and a sorted `Vec<Node<T>>` is ordered by
/// ring position.
#[derive(Clone, Debug)]
struct Node<T> {
    /// Ring position; `HashRing` fills this with the hash of `node`.
    key: u64,
    /// The user-supplied node value.
    node: T,
}

impl<T> Node<T> {
    fn new(key: u64, node: T) -> Self {
        Self { key, node }
    }
}

impl<T> PartialEq for Node<T> {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), Ordering::Equal)
    }
}

impl<T> Eq for Node<T> {}

impl<T> PartialOrd for Node<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<T> Ord for Node<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.key, &other.key)
    }
}
/// A consistent-hash ring that maps arbitrary keys onto a set of nodes.
///
/// Each node sits on the ring at the hash of the node itself, computed with a
/// hasher from `hash_builder`. A key is served by the first node at or after
/// the key's own hash, wrapping around to the start of the ring.
///
/// The derived `PartialEq` compares the hash builders and the entries' ring
/// positions. It does not compare the `T` values themselves.
#[derive(Clone, PartialEq, Debug)]
pub struct HashRing<T, S = DefaultHashBuilder> {
    /// Produces the hasher used for both node placement and key lookup.
    hash_builder: S,
    /// Entries kept sorted by `key` so lookups can binary-search them.
    ring: Vec<Node<T>>,
}
impl<T> Default for HashRing<T> {
    /// Returns an empty ring that hashes with `DefaultHashBuilder`.
    fn default() -> Self {
        let ring = Vec::new();
        Self {
            ring,
            hash_builder: DefaultHashBuilder,
        }
    }
}

impl<T> HashRing<T> {
    /// Creates an empty ring that uses `DefaultHashBuilder`.
    ///
    /// Equivalent to `HashRing::default()`.
    pub fn new() -> HashRing<T> {
        Self::default()
    }
}
impl<T, S> HashRing<T, S> {
    /// Creates an empty ring that hashes nodes and keys with `hash_builder`.
    ///
    /// Use this to plug in a hasher other than `DefaultHashBuilder`.
    pub fn with_hasher(hash_builder: S) -> HashRing<T, S> {
        HashRing {
            hash_builder,
            ring: Vec::new(),
        }
    }

    /// Returns the number of entries on the ring.
    ///
    /// `add` does not deduplicate. Adding the same node twice creates two
    /// entries, and both are counted.
    pub fn len(&self) -> usize {
        self.ring.len()
    }

    /// Returns `true` if the ring has no entries.
    pub fn is_empty(&self) -> bool {
        self.ring.is_empty()
    }
}
impl<T: Hash, S: BuildHasher> HashRing<T, S> {
    /// Returns the index of the first ring entry whose key is `>= key`.
    ///
    /// Returns `self.ring.len()` when every entry is smaller; callers treat
    /// that as "wrap around to the start of the ring". Among entries sharing
    /// the same key, this always picks the first one, whereas `binary_search`
    /// may return any of them.
    fn lower_bound(&self, key: u64) -> usize {
        self.ring.partition_point(|entry| entry.key < key)
    }

    /// Adds `node` to the ring at the position given by its hash.
    ///
    /// Nodes are not deduplicated. Adding the same node twice stores two
    /// entries at the same position.
    pub fn add(&mut self, node: T) {
        let key = get_key(&self.hash_builder, &node);
        // Insert after any entries with an equal key. The ring stays sorted and
        // ends up exactly as `push` followed by a stable `sort` would leave it,
        // without re-sorting the whole vector on every insert.
        let index = self.ring.partition_point(|entry| entry.key <= key);
        self.ring.insert(index, Node::new(key, node));
    }

    /// Adds every node in `nodes`, sorting the ring once at the end.
    pub fn batch_add(&mut self, nodes: Vec<T>) {
        // Borrow the builder separately so `self.ring` can be extended mutably.
        let hash_builder = &self.hash_builder;
        self.ring.extend(nodes.into_iter().map(|node| {
            let key = get_key(hash_builder, &node);
            Node::new(key, node)
        }));
        self.ring.sort();
    }

    /// Removes the entry at `node`'s hash position and returns its stored value.
    ///
    /// Matching is by hash only. Returns `None` when no entry has that hash.
    pub fn remove(&mut self, node: &T) -> Option<T> {
        let key = get_key(&self.hash_builder, node);
        let index = self.lower_bound(key);
        if self.ring.get(index).is_some_and(|entry| entry.key == key) {
            Some(self.ring.remove(index).node)
        } else {
            None
        }
    }

    /// Returns the node responsible for `key`.
    ///
    /// This is the first node whose position is at or after `key`'s hash,
    /// wrapping around to the first node on the ring. Returns `None` only when
    /// the ring is empty.
    pub fn get<U: Hash>(&self, key: &U) -> Option<&T> {
        let index = self.lower_bound(get_key(&self.hash_builder, key));
        // `index == len` means "past the last node": wrap to the front.
        self.ring
            .get(index)
            .or_else(|| self.ring.first())
            .map(|entry| &entry.node)
    }

    /// Returns the node responsible for `key`, followed by up to `replicas`
    /// of its successors in ring order.
    ///
    /// The result holds `min(replicas + 1, self.len())` nodes, so no ring
    /// entry is returned twice, even when `replicas >= self.len()`. Returns
    /// `None` when the ring is empty.
    ///
    /// The `Debug` bound is not used by the body. It is kept so the method's
    /// signature stays unchanged for callers.
    pub fn get_with_replicas<U: Hash>(&self, key: &U, replicas: usize) -> Option<Vec<T>>
    where
        T: Clone + Debug,
    {
        if self.ring.is_empty() {
            return None;
        }
        // The primary plus `replicas` successors, capped at the ring size.
        // Walking further would wrap around and repeat nodes.
        let count = replicas.saturating_add(1).min(self.ring.len());
        let start = self.lower_bound(get_key(&self.hash_builder, key));
        // Walk forward from `start`, wrapping to the front, without cloning or
        // rotating the whole ring.
        let (before, from_start) = self.ring.split_at(start);
        let nodes = from_start
            .iter()
            .chain(before)
            .take(count)
            .map(|entry| entry.node.clone())
            .collect();
        Some(nodes)
    }
}
/// Consuming iterator over the nodes of a `HashRing`, created by
/// `HashRing::into_iter`.
///
/// Nodes are yielded in the ring's stored order, which is ascending hash
/// position and not insertion order.
#[derive(Debug)]
pub struct HashRingIterator<T> {
    ring: std::vec::IntoIter<Node<T>>,
}

impl<T> Iterator for HashRingIterator<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        self.ring.next().map(|node| node.node)
    }

    // Forward the exact remaining length so `collect` and similar adaptors
    // can preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.ring.size_hint()
    }
}

impl<T> DoubleEndedIterator for HashRingIterator<T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.ring.next_back().map(|node| node.node)
    }
}

impl<T> ExactSizeIterator for HashRingIterator<T> {}

// Implemented for every hasher type `S`, not just the default one: consuming
// the ring never touches the hash builder.
impl<T, S> IntoIterator for HashRing<T, S> {
    type Item = T;
    type IntoIter = HashRingIterator<T>;

    fn into_iter(self) -> Self::IntoIter {
        HashRingIterator {
            ring: self.ring.into_iter(),
        }
    }
}
/// Returns the ring position of `input`: its 64-bit hash under a hasher
/// built from `hash_builder`.
///
/// Node placement (`add`, `batch_add`, `remove`) and key lookup (`get`,
/// `get_with_replicas`) all go through this function, so nodes and keys are
/// hashed the same way.
fn get_key<B: BuildHasher, K: Hash>(hash_builder: &B, input: K) -> u64 {
    BuildHasher::hash_one(hash_builder, input)
}
#[cfg(test)]
mod tests {
// Unit tests for `HashRing`. Many expected placements below are hard-coded:
// they depend on `VNode`'s `Hash` impl and on `HashRing::new()` using
// `DefaultHashBuilder`, so changing either changes the expected nodes.
use std::hash::Hash;
use std::hash::Hasher;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::str::FromStr;
use super::HashRing;
/// Test node: a numeric id plus an IPv4 socket address.
#[derive(Debug, Copy, Clone, PartialEq)]
struct VNode {
id: usize,
addr: SocketAddrV4,
}
impl VNode {
/// Builds a `VNode`; panics (via `unwrap`) if `ip` is not a valid IPv4 literal.
fn new(ip: &str, port: u16, id: usize) -> Self {
let addr = SocketAddrV4::new(Ipv4Addr::from_str(ip).unwrap(), port);
VNode { id, addr }
}
}
/// Hashes the tuple `(id, port, ip)`. This exact field order determines every
/// hard-coded ring placement in these tests.
impl Hash for VNode {
fn hash<H: Hasher>(&self, s: &mut H) {
(self.id, self.addr.port(), self.addr.ip()).hash(s)
}
}
/// Exercises `add`, `batch_add`, `remove`, `len` and `is_empty` together.
#[test]
fn add_and_remove_nodes() {
let mut ring: HashRing<VNode> = HashRing::new();
assert_eq!(ring.len(), 0);
assert!(ring.is_empty());
let vnode1 = VNode::new("127.0.0.1", 1024, 1);
let vnode2 = VNode::new("127.0.0.1", 1024, 2);
let vnode3 = VNode::new("127.0.0.2", 1024, 1);
ring.add(vnode1);
ring.add(vnode2);
ring.add(vnode3);
assert_eq!(ring.len(), 3);
assert!(!ring.is_empty());
// `remove` hands back the stored node.
assert_eq!(ring.remove(&vnode2).unwrap(), vnode2);
assert_eq!(ring.len(), 2);
let vnode4 = VNode::new("127.0.0.2", 1024, 2);
let vnode5 = VNode::new("127.0.0.2", 1024, 3);
let vnode6 = VNode::new("127.0.0.3", 1024, 1);
ring.batch_add(vec![vnode4, vnode5, vnode6]);
// Nodes added singly and in a batch can both be removed afterwards.
assert_eq!(ring.remove(&vnode1).unwrap(), vnode1);
assert_eq!(ring.remove(&vnode3).unwrap(), vnode3);
assert_eq!(ring.remove(&vnode6).unwrap(), vnode6);
assert_eq!(ring.len(), 2);
}
/// Checks hard-coded key-to-node placements, then that many integer keys
/// reach every node at least once.
#[test]
fn get_nodes() {
let mut ring: HashRing<VNode> = HashRing::new();
assert_eq!(ring.get(&"foo"), None);
let vnode1 = VNode::new("127.0.0.1", 1024, 1);
let vnode2 = VNode::new("127.0.0.1", 1024, 2);
let vnode3 = VNode::new("127.0.0.2", 1024, 1);
let vnode4 = VNode::new("127.0.0.2", 1024, 2);
let vnode5 = VNode::new("127.0.0.2", 1024, 3);
let vnode6 = VNode::new("127.0.0.3", 1024, 1);
ring.add(vnode1);
ring.add(vnode2);
ring.add(vnode3);
ring.add(vnode4);
ring.add(vnode5);
ring.add(vnode6);
// Expected owners are hash-dependent; see the note at the top of the module.
assert_eq!(ring.get(&"foo"), Some(&vnode6));
assert_eq!(ring.get(&"bar"), Some(&vnode5));
assert_eq!(ring.get(&"baz"), Some(&vnode4));
assert_eq!(ring.get(&"abc"), Some(&vnode1));
assert_eq!(ring.get(&"def"), Some(&vnode1));
assert_eq!(ring.get(&"ghi"), Some(&vnode6));
assert_eq!(ring.get(&"cat"), Some(&vnode5));
assert_eq!(ring.get(&"dog"), Some(&vnode4));
assert_eq!(ring.get(&"bird"), Some(&vnode4));
// Count how many of 50,000 integer keys land on each node.
let mut nodes = vec![0; 6];
for x in 0..50_000 {
let node = ring.get(&x).unwrap();
if vnode1 == *node {
nodes[0] += 1;
}
if vnode2 == *node {
nodes[1] += 1;
}
if vnode3 == *node {
nodes[2] += 1;
}
if vnode4 == *node {
nodes[3] += 1;
}
if vnode5 == *node {
nodes[4] += 1;
}
if vnode6 == *node {
nodes[5] += 1;
}
}
// Printed only for manual inspection of the distribution.
println!("{:?}", nodes);
// Every node must own at least one key.
assert!(nodes.iter().all(|x| *x != 0));
}
/// `get_with_replicas(key, r)` returns the primary node followed by `r`
/// further nodes in ring order.
#[test]
fn get_nodes_with_replicas() {
let mut ring: HashRing<VNode> = HashRing::new();
assert_eq!(ring.get(&"foo"), None);
assert_eq!(ring.get_with_replicas(&"foo", 1), None);
let vnode1 = VNode::new("127.0.0.1", 1024, 1);
let vnode2 = VNode::new("127.0.0.1", 1024, 2);
let vnode3 = VNode::new("127.0.0.2", 1024, 3);
let vnode4 = VNode::new("127.0.0.2", 1024, 4);
let vnode5 = VNode::new("127.0.0.2", 1024, 5);
let vnode6 = VNode::new("127.0.0.3", 1024, 6);
ring.add(vnode1);
ring.add(vnode2);
ring.add(vnode3);
ring.add(vnode4);
ring.add(vnode5);
ring.add(vnode6);
// 2 replicas -> 3 nodes in total.
assert_eq!(
ring.get_with_replicas(&"bar", 2).unwrap(),
vec![vnode3, vnode1, vnode2]
);
// 4 replicas -> 5 nodes in total.
assert_eq!(
ring.get_with_replicas(&"foo", 4).unwrap(),
vec![vnode5, vnode4, vnode3, vnode1, vnode2]
);
}
/// Asking for more replicas than there are nodes returns each node once.
#[test]
fn get_with_replicas_returns_too_many_replicas() {
let mut ring: HashRing<VNode> = HashRing::new();
assert_eq!(ring.get(&"foo"), None);
assert_eq!(ring.get_with_replicas(&"foo", 1), None);
let vnode1 = VNode::new("127.0.0.1", 1024, 1);
let vnode2 = VNode::new("127.0.0.1", 1024, 2);
let vnode3 = VNode::new("127.0.0.2", 1024, 3);
let vnode4 = VNode::new("127.0.0.2", 1024, 4);
let vnode5 = VNode::new("127.0.0.2", 1024, 5);
let vnode6 = VNode::new("127.0.0.3", 1024, 6);
ring.add(vnode1);
ring.add(vnode2);
ring.add(vnode3);
ring.add(vnode4);
ring.add(vnode5);
ring.add(vnode6);
assert_eq!(
ring.get_with_replicas(&"bar", 20).unwrap(),
vec![vnode3, vnode1, vnode2, vnode6, vnode5, vnode4],
"too high of replicas causes the count to shrink to ring length"
);
}
/// Consuming iteration yields nodes in ring (hash) order, which here differs
/// from insertion order.
#[test]
fn into_iter() {
let mut ring: HashRing<VNode> = HashRing::new();
assert_eq!(ring.get(&"foo"), None);
let vnode1 = VNode::new("127.0.0.1", 1024, 1);
let vnode2 = VNode::new("127.0.0.1", 1024, 2);
let vnode3 = VNode::new("127.0.0.2", 1024, 1);
ring.add(vnode1);
ring.add(vnode2);
ring.add(vnode3);
let mut iter = ring.into_iter();
assert_eq!(Some(vnode3), iter.next());
assert_eq!(Some(vnode1), iter.next());
assert_eq!(Some(vnode2), iter.next());
assert_eq!(None, iter.next());
}
/// Rings holding the same nodes compare equal, regardless of the add/remove
/// history that produced them.
#[test]
fn hash_ring_eq() {
let mut ring: HashRing<VNode> = HashRing::new();
let mut other = ring.clone();
assert_eq!(ring, other);
assert_eq!(ring.len(), 0);
let vnode1 = VNode::new("127.0.0.1", 1024, 1);
let vnode2 = VNode::new("127.0.0.1", 1024, 2);
let vnode3 = VNode::new("127.0.0.2", 1024, 1);
other.add(vnode1);
other.add(vnode2);
other.add(vnode3);
assert_ne!(ring, other);
assert_eq!(other.len(), 3);
// Removing everything brings `other` back to equality with the empty ring.
other.remove(&vnode1).unwrap();
other.remove(&vnode2).unwrap();
other.remove(&vnode3).unwrap();
assert_eq!(ring, other);
assert_eq!(other.len(), 0);
// Different paths that both end with only `vnode2` on the ring.
ring.add(vnode1);
ring.add(vnode2);
other.add(vnode2);
other.add(vnode3);
other.remove(&vnode3);
ring.remove(&vnode1);
assert_eq!(ring, other);
assert_eq!(ring.len(), 1);
assert_eq!(other.len(), 1);
}
}