#![doc = include_str!("../README.md")]
#![forbid(unsafe_code)]
mod iter;
mod partitioner;
mod range;
mod token;
use {
crate::{
iter::HashRingIter,
RingDirection::{Clockwise, CounterClockwise},
},
crossbeam_skiplist::SkipMap,
std::{
hash::Hash,
ops::Bound::{Excluded, Unbounded},
sync::Arc,
},
};
pub use {partitioner::*, range::*, token::RingToken};
/// Marker trait for values that can be stored on a [`HashRing`].
///
/// It is blanket-implemented for every `Hash + Send + 'static` type, so callers
/// never implement it by hand.
pub trait RingNode: Hash + Send + 'static {}

impl<T: Hash + Send + 'static> RingNode for T {}
/// A point on the ring; every `u64` value, `0..=u64::MAX`, is a valid position.
pub type RingPosition = u64;

/// Number of probe positions examined per key when a ring is created through
/// `HashRing::new` / `Default`.
pub const DEFAULT_PROBE_COUNT: usize = 23;
/// Direction in which to walk the ring.
// Debug/PartialEq/Eq let callers compare and print directions (e.g. in tests
// and diagnostics); the enum is public, so these are expected derives.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RingDirection {
    /// Toward increasing positions, wrapping from the highest position back to `0`.
    Clockwise,
    /// Toward decreasing positions, wrapping from `0` back to the highest position.
    CounterClockwise,
}
/// A consistent-hash ring that maps [`RingPosition`]s to nodes of type `N`.
///
/// `P` is the partitioner used to hash keys and nodes onto the ring. The
/// position map lives behind an [`Arc`], so clones share it: a node added or
/// removed through one clone is visible through every other clone.
pub struct HashRing<N: RingNode, P = DefaultPartitioner> {
    /// Hashes keys and nodes to ring positions.
    partitioner: P,
    /// Occupied positions in ascending order; shared between clones.
    positions: Arc<SkipMap<RingPosition, N>>,
    /// Number of probe positions examined per key by `node`.
    probe_count: usize,
}

// Implemented by hand instead of derived: `#[derive(Clone)]` would require
// `N: Clone`, but nodes are never copied on clone — only the `Arc` around the
// position map is, so only the partitioner needs to be `Clone`.
impl<N: RingNode, P: Clone> Clone for HashRing<N, P> {
    /// Returns a handle sharing this ring's positions (cheap `Arc` clone)
    /// with a cloned partitioner and the same probe count.
    fn clone(&self) -> Self {
        Self {
            partitioner: self.partitioner.clone(),
            positions: Arc::clone(&self.positions),
            probe_count: self.probe_count,
        }
    }
}
impl<N: RingNode> Default for HashRing<N> {
    /// Builds an empty ring that uses [`DefaultPartitioner`] and
    /// [`DEFAULT_PROBE_COUNT`] probes per lookup.
    fn default() -> Self {
        let partitioner = DefaultPartitioner::new();
        let positions = Arc::new(SkipMap::new());
        Self {
            probe_count: DEFAULT_PROBE_COUNT,
            positions,
            partitioner,
        }
    }
}
impl<N: RingNode> HashRing<N> {
    /// Creates an empty ring using [`DefaultPartitioner`] and
    /// [`DEFAULT_PROBE_COUNT`] probes per lookup.
    pub fn new() -> Self {
        Self::default()
    }

    /// Places `node` at the explicit position `pos`, bypassing the partitioner.
    ///
    /// An entry already stored at `pos` is replaced (crossbeam
    /// `SkipMap::insert` semantics).
    pub fn insert(&self, pos: RingPosition, node: N) {
        self.positions.insert(pos, node);
    }

    /// Places `node` at the position the partitioner assigns to it.
    pub fn add(&self, node: N) {
        let pos = self.partitioner.position(&node);
        self.positions.insert(pos, node);
    }

    /// Removes the entry at the position the partitioner assigns to `node`.
    ///
    /// NOTE: removal is by position only. The stored node is not compared with
    /// `node`, so whatever occupies that position (e.g. a node placed there via
    /// [`insert`](Self::insert)) is removed.
    pub fn remove(&self, node: &N) {
        let pos = self.partitioner.position(node);
        self.positions.remove(&pos);
    }

    /// Returns up to `k` tokens met while walking clockwise from `key`'s hashed
    /// position, wrapping around the ring. Fewer than `k` are returned when the
    /// ring holds fewer tokens.
    ///
    /// NOTE(review): this walks from the single position of `key`, whereas
    /// [`node`](Self::node) selects among several probe positions, so the first
    /// replica is not necessarily the token `node(key)` returns.
    pub fn replicas<K: Hash>(&self, key: &K, k: usize) -> Vec<RingToken<'_, N>> {
        self.tokens(self.position(key), Clockwise)
            .take(k)
            .collect::<Vec<_>>()
    }

    /// Returns the key range ending at `node`'s partitioner position (see
    /// [`key_range`](Self::key_range)), wrapped in a single-element `Vec`.
    ///
    /// Returns `None` when the ring is empty. NOTE(review): this does not check
    /// that `node` is actually present on the ring.
    pub fn intervals(&self, node: &N) -> Option<Vec<KeyRange<RingPosition>>> {
        let pos = self.position(node);
        self.key_range(pos).map(|range| vec![range])
    }

    /// Hashes `key` to a ring position with this ring's partitioner.
    pub fn position<K: Hash>(&self, key: &K) -> RingPosition {
        self.partitioner.position(key)
    }

    /// Returns the token responsible for `key`, or `None` if the ring is empty.
    ///
    /// The partitioner produces `probe_count` probe positions for `key`. For
    /// each probe the first token clockwise from it is found, and the token
    /// with the smallest clockwise distance from its probe wins (the earliest
    /// probe wins ties).
    pub fn node<K: Hash>(&self, key: &K) -> Option<RingToken<'_, N>> {
        self.primary_token(key)
    }

    /// Multi-probe lookup backing [`node`](Self::node).
    fn primary_token<K: Hash>(&self, key: &K) -> Option<RingToken<'_, N>> {
        let mut min_distance = RingPosition::MAX;
        let mut min_token = None;
        for pos in self.partitioner.positions(key, self.probe_count) {
            // No successor for a probe means the ring is empty.
            let token = self.tokens(pos, Clockwise).next()?;
            let dist = distance(pos, token.position());
            // Always accept the first candidate: a successor may sit at the
            // maximal distance `RingPosition::MAX`, which a bare `<` against
            // the initial sentinel would reject, yielding `None` for a
            // non-empty ring. Strict `<` afterwards keeps the earliest probe
            // on ties.
            if min_token.is_none() || dist < min_distance {
                min_distance = dist;
                min_token = Some(token);
            }
        }
        min_token
    }

    /// Iterates every token on the ring exactly once, starting at `start`.
    ///
    /// * `Clockwise` yields positions `>= start` in increasing order, then
    ///   wraps to positions `< start`, also increasing.
    /// * `CounterClockwise` yields positions `<= start` in decreasing order,
    ///   then wraps to positions `> start`, also decreasing.
    ///
    /// A token located exactly at `start` is yielded first in both directions.
    #[must_use]
    fn tokens(
        &self,
        start: RingPosition,
        dir: RingDirection,
    ) -> impl DoubleEndedIterator<Item = RingToken<'_, N>> {
        match dir {
            Clockwise => HashRingIter::Clockwise(
                self.positions
                    .range(start..)
                    .chain(self.positions.range(0..start)),
            ),
            CounterClockwise => HashRingIter::CounterClockwise(
                self.positions
                    .range(..=start)
                    .rev()
                    .chain(self.positions.range((Excluded(start), Unbounded)).rev()),
            ),
        }
        .map(Into::into)
    }

    /// Returns the key range that ends at `pos`, or `None` for an empty ring.
    ///
    /// The range starts at the nearest token strictly below `pos`. When no
    /// token lies below `pos`, it wraps around to the highest occupied position,
    /// which is `pos` itself if `pos` holds the only token.
    pub fn key_range(&self, pos: RingPosition) -> Option<KeyRange<RingPosition>> {
        if self.positions.is_empty() {
            return None;
        }
        // The last token of a full clockwise walk from `pos` is its predecessor.
        let prev_pos = self.tokens(pos, Clockwise).next_back();
        let start = prev_pos.map_or(0, |token| token.position());
        Some(KeyRange::new(start, pos))
    }

    /// Number of occupied positions on the ring.
    pub fn len(&self) -> usize {
        self.positions.len()
    }

    /// `true` when no position on the ring is occupied.
    pub fn is_empty(&self) -> bool {
        self.positions.is_empty()
    }
}
const fn distance(pos1: RingPosition, pos2: RingPosition) -> RingPosition {
if pos1 > pos2 {
RingPosition::MAX - pos1 + pos2
} else {
pos2 - pos1
}
}
#[cfg(test)]
mod tests {
    use {super::*, rand::random, std::collections::BTreeSet};

    /// Test node identified by a random 64-bit id.
    #[derive(Hash, Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
    struct Node {
        id: u64,
    }

    impl Node {
        fn random() -> Self {
            Self { id: random() }
        }
    }

    /// Nodes visited, in order, when walking `ring` from `start` towards `dir`.
    fn walk(ring: &HashRing<Node>, start: u64, dir: RingDirection) -> Vec<Node> {
        ring.tokens(start, dir).map(|token| *token.node()).collect()
    }

    #[test]
    fn tokens() {
        let ring = HashRing::new();
        let [a, b, c] = [Node::random(), Node::random(), Node::random()];
        for node in [a, b, c] {
            ring.add(node);
        }

        for dir in [Clockwise, CounterClockwise] {
            let seen = walk(&ring, 0, dir);
            assert_eq!(seen.len(), 3);
            for node in [a, b, c] {
                assert!(seen.contains(&node));
            }
        }

        ring.remove(&b);
        let seen = walk(&ring, 0, Clockwise);
        assert_eq!(seen.len(), 2);
        assert!(seen.contains(&a));
        assert!(!seen.contains(&b));
        assert!(seen.contains(&c));
    }

    #[test]
    fn tokens_wrap_around() {
        let ring = HashRing::new();
        let nodes: Vec<Node> = (0..3).map(|_| Node::random()).collect();
        for node in &nodes {
            ring.add(*node);
        }
        let expected: BTreeSet<Node> = nodes.iter().copied().collect();

        let clockwise: BTreeSet<Node> =
            walk(&ring, u64::MAX - 1, Clockwise).into_iter().collect();
        assert_eq!(clockwise, expected);

        let counter: BTreeSet<Node> = walk(&ring, 1, CounterClockwise).into_iter().collect();
        assert_eq!(counter, expected);
    }

    #[test]
    fn tokens_corner_cases() {
        let ring = HashRing::new();
        let [n1, n2, n3] = [Node::random(), Node::random(), Node::random()];
        let mid = u64::MAX / 2;
        ring.insert(0, n1);
        ring.insert(mid, n2);
        ring.insert(u64::MAX, n3);

        let cases = [
            (0, Clockwise, [n1, n2, n3]),
            (0, CounterClockwise, [n1, n3, n2]),
            (1, Clockwise, [n2, n3, n1]),
            (1, CounterClockwise, [n1, n3, n2]),
            (mid, Clockwise, [n2, n3, n1]),
            (mid, CounterClockwise, [n2, n1, n3]),
            (mid + 1, Clockwise, [n3, n1, n2]),
            (mid + 1, CounterClockwise, [n2, n1, n3]),
            (u64::MAX, Clockwise, [n3, n1, n2]),
            (u64::MAX, CounterClockwise, [n3, n2, n1]),
        ];
        for (start, dir, expected) in cases {
            assert_eq!(
                walk(&ring, start, dir),
                expected.to_vec(),
                "walk starting at {}",
                start
            );
        }
    }

    #[test]
    fn tokens_for_key() {
        let ring = HashRing::new();
        for _ in 0..3 {
            ring.add(Node::random());
        }
        let start = ring.position(&"foo");
        assert_eq!(ring.tokens(start, Clockwise).count(), 3);
    }
}