use std::cmp::Ordering;
use std::io::Write;
use std::net::SocketAddr;
use crc32fast::Hasher;
#[cfg(feature = "v2")]
use i_key_sort::sort::one_key_cmp::OneKeyAndCmpSort;
/// Number of ring points generated per unit of bucket weight when building a
/// [`Version::V1`] ring (a bucket of weight `w` contributes `w * DEFAULT_POINT_MULTIPLE`
/// points).
pub const DEFAULT_POINT_MULTIPLE: u32 = 160;
/// A weighted upstream address to be placed on a [`Continuum`].
///
/// A bucket of weight `w` receives `w` times its ring version's point multiple, so
/// heavier buckets own proportionally more of the ring.
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)]
pub struct Bucket {
    /// Upstream address; its IP and port are what gets hashed onto the ring.
    node: SocketAddr,
    /// Relative weight; never zero (enforced by [`Bucket::new`]).
    weight: u32,
}

impl Bucket {
    /// Creates a bucket for `node` with relative `weight`.
    ///
    /// # Panics
    ///
    /// Panics if `weight` is zero.
    pub fn new(node: SocketAddr, weight: u32) -> Self {
        assert!(weight > 0, "weight must be at least one");
        Self { node, weight }
    }
}
/// A V1 ring point: `node` is an index into the continuum's address table and `hash`
/// is the point's position on the ring.
///
/// NOTE(review): `Ord`/`PartialOrd` look only at `hash`, while the derived `PartialEq`
/// also compares `node`, so two points can be unequal yet compare `Equal`. The ring code
/// sorts with this hash-only ordering and dedups through an explicit hash closure, so
/// `==` on points is not used there — keep this in mind before mixing the two.
#[derive(Clone, Debug, Eq, PartialEq)]
struct PointV1 {
    node: u32,
    hash: u32,
}

impl PointV1 {
    /// Creates a point owned by address index `node` at ring position `hash`.
    fn new(node: u32, hash: u32) -> Self {
        Self { node, hash }
    }
}

impl PartialOrd for PointV1 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for PointV1 {
    /// Orders points by ring position (`hash`) only.
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.hash, &other.hash)
    }
}
/// A compact V2 ring point packing a 32-bit hash and a 16-bit node index into 6 bytes.
///
/// Bytes `0..4` hold the hash and bytes `4..6` the node index, both in native byte
/// order; `repr(transparent)` gives the type exactly the layout of `[u8; 6]`.
#[cfg(feature = "v2")]
#[derive(Copy, Clone, Eq, PartialEq)]
#[repr(transparent)]
struct PointV2([u8; 6]);

#[cfg(feature = "v2")]
impl PointV2 {
    /// Packs `hash` and `node` into a new point.
    fn new(node: u16, hash: u32) -> Self {
        let h = hash.to_ne_bytes();
        let n = node.to_ne_bytes();
        Self([h[0], h[1], h[2], h[3], n[0], n[1]])
    }

    /// The point's position on the ring.
    fn hash(&self) -> u32 {
        let bytes: [u8; 4] = self.0[..4].try_into().expect("There are exactly 4 bytes");
        u32::from_ne_bytes(bytes)
    }

    /// Index of the owning address in the continuum's address table.
    fn node(&self) -> u16 {
        let bytes: [u8; 2] = self.0[4..].try_into().expect("There are exactly 2 bytes");
        u16::from_ne_bytes(bytes)
    }
}
/// Selects the point format and point density of the ring built by a [`Continuum`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum Version {
    /// Default format: 8-byte points and [`DEFAULT_POINT_MULTIPLE`] points per unit weight.
    #[default]
    V1,
    /// Compact 6-byte points with a caller-chosen density.
    #[cfg(feature = "v2")]
    V2 {
        /// Points placed on the ring per unit of bucket weight.
        point_multiple: u32,
    },
}

impl Version {
    /// Number of ring points generated per unit of bucket weight for this version.
    fn point_multiple(&self) -> u32 {
        match *self {
            Self::V1 => DEFAULT_POINT_MULTIPLE,
            #[cfg(feature = "v2")]
            Self::V2 { point_multiple } => point_multiple,
        }
    }
}
/// Mutable accumulator of ring points for one [`Version`]. It is sorted and then
/// converted into an immutable [`VersionedRing`].
enum RingBuilder {
    V1(Vec<PointV1>),
    #[cfg(feature = "v2")]
    V2(Vec<PointV2>),
}

impl RingBuilder {
    /// Creates an empty builder with room for `total_weight * point_multiple` points.
    fn new(version: Version, total_weight: u32) -> Self {
        match version {
            Version::V1 => RingBuilder::V1(Vec::with_capacity(Self::capacity_for(
                total_weight,
                DEFAULT_POINT_MULTIPLE,
            ))),
            #[cfg(feature = "v2")]
            Version::V2 { point_multiple } => RingBuilder::V2(Vec::with_capacity(
                Self::capacity_for(total_weight, point_multiple),
            )),
        }
    }

    /// Capacity hint for `total_weight * point_multiple` points.
    ///
    /// Multiplying in `u32` overflows (a panic in debug builds) once the product exceeds
    /// `u32::MAX`, e.g. 100 buckets of weight 300_000 at 160 points each. Widening to
    /// `usize` first avoids that, and saturating keeps narrow `usize` targets from
    /// wrapping around to a bogus small value.
    fn capacity_for(total_weight: u32, point_multiple: u32) -> usize {
        (total_weight as usize).saturating_mul(point_multiple as usize)
    }

    /// Appends a point owned by address index `node` at ring position `hash`.
    fn push(&mut self, node: u16, hash: u32) {
        match self {
            RingBuilder::V1(ring) => ring.push(PointV1::new(u32::from(node), hash)),
            #[cfg(feature = "v2")]
            RingBuilder::V2(ring) => ring.push(PointV2::new(node, hash)),
        }
    }

    /// Sorts the points by hash and drops every point whose hash repeats an earlier one.
    /// Each hash therefore appears at most once on the finished ring.
    ///
    /// V1 sorts with `PointV1`'s hash-only ordering using an unstable sort, so the node
    /// that survives a hash collision is not specified by the sort. V2 sorts by hash and
    /// (presumably, per `sort_by_one_key_then_by`) breaks ties by comparing the owning
    /// addresses looked up in `addresses`.
    // `addresses` is only read by the V2 arm, so it is unused when that feature is off;
    // silence the lint only in that configuration.
    #[cfg_attr(not(feature = "v2"), allow(unused_variables))]
    fn sort(&mut self, addresses: &[SocketAddr]) {
        match self {
            RingBuilder::V1(ring) => {
                ring.sort_unstable();
                ring.dedup_by(|a, b| a.hash == b.hash);
            }
            #[cfg(feature = "v2")]
            RingBuilder::V2(ring) => {
                ring.sort_by_one_key_then_by(
                    true,
                    |p| p.hash(),
                    |p1, p2| addresses[p1.node() as usize].cmp(&addresses[p2.node() as usize]),
                );
                // Bytes 0..4 of a PointV2 are its hash, so this drops repeated hashes.
                ring.dedup_by(|a, b| a.0[0..4] == b.0[0..4]);
            }
        }
    }
}
/// Freezes a (sorted) builder into an immutable ring of the same version.
impl From<RingBuilder> for VersionedRing {
    fn from(builder: RingBuilder) -> Self {
        match builder {
            RingBuilder::V1(points) => Self::V1(points.into()),
            #[cfg(feature = "v2")]
            RingBuilder::V2(points) => Self::V2(points.into()),
        }
    }
}
/// Immutable ring of points, sorted by hash, in one of the [`Version`] formats.
enum VersionedRing {
    V1(Box<[PointV1]>),
    #[cfg(feature = "v2")]
    V2(Box<[PointV2]>),
}

impl VersionedRing {
    /// Returns the position of the point whose hash equals `hash`, otherwise of the first
    /// point with a larger hash. It wraps to 0 when `hash` lies past the last point, and
    /// returns 0 for an empty ring.
    pub fn node_idx(&self, hash: u32) -> usize {
        let found = match self {
            Self::V1(points) => points.binary_search_by(|p| p.hash.cmp(&hash)),
            #[cfg(feature = "v2")]
            Self::V2(points) => points.binary_search_by(|p| p.hash().cmp(&hash)),
        };
        // An exact hit and an insertion point are both usable ring positions.
        let (Ok(idx) | Err(idx)) = found;
        if idx == self.len() {
            0
        } else {
            idx
        }
    }

    /// Returns the address-table index owning the point at `index`, or `None` when
    /// `index` is out of range.
    pub fn get(&self, index: usize) -> Option<usize> {
        let node = match self {
            Self::V1(points) => points.get(index)?.node as usize,
            #[cfg(feature = "v2")]
            Self::V2(points) => points.get(index)?.node() as usize,
        };
        Some(node)
    }

    /// Number of points on the ring.
    pub fn len(&self) -> usize {
        match self {
            Self::V1(points) => points.len(),
            #[cfg(feature = "v2")]
            Self::V2(points) => points.len(),
        }
    }
}
pub struct Continuum {
ring: VersionedRing,
addrs: Box<[SocketAddr]>,
}
impl Continuum {
pub fn new(buckets: &[Bucket]) -> Self {
Self::new_with_version(buckets, Version::default())
}
pub fn new_with_version(buckets: &[Bucket], version: Version) -> Self {
if buckets.is_empty() {
return Continuum {
ring: VersionedRing::V1(Box::new([])),
addrs: Box::new([]),
};
}
let total_weight: u32 = buckets.iter().fold(0, |sum, b| sum + b.weight);
let mut ring = RingBuilder::new(version, total_weight);
let mut addrs = Vec::with_capacity(buckets.len());
for bucket in buckets {
let mut hasher = Hasher::new();
let mut hash_bytes = Vec::with_capacity(39 + 1 + 5);
write!(&mut hash_bytes, "{}", bucket.node.ip()).unwrap();
write!(&mut hash_bytes, "\0").unwrap();
write!(&mut hash_bytes, "{}", bucket.node.port()).unwrap();
hasher.update(hash_bytes.as_ref());
let num_points = bucket.weight * version.point_multiple();
let mut prev_hash: u32 = 0;
addrs.push(bucket.node);
let node = addrs.len() - 1;
for _ in 0..num_points {
let mut hasher = hasher.clone();
hasher.update(&prev_hash.to_le_bytes());
let hash = hasher.finalize();
ring.push(node as u16, hash);
prev_hash = hash;
}
}
let addrs = addrs.into_boxed_slice();
ring.sort(&addrs);
Continuum {
ring: ring.into(),
addrs,
}
}
pub fn node_idx(&self, input: &[u8]) -> usize {
let hash = crc32fast::hash(input);
self.ring.node_idx(hash)
}
pub fn node(&self, hash_key: &[u8]) -> Option<SocketAddr> {
self.ring
.get(self.node_idx(hash_key)) .map(|n| self.addrs[n])
}
pub fn node_iter(&self, hash_key: &[u8]) -> NodeIterator<'_> {
NodeIterator {
idx: self.node_idx(hash_key),
continuum: self,
}
}
pub fn get_addr(&self, idx: &mut usize) -> Option<&SocketAddr> {
let point = self.ring.get(*idx);
if point.is_some() {
*idx = (*idx + 1) % self.ring.len();
}
point.map(|n| &self.addrs[n])
}
}
/// Iterator over ring addresses starting at a key's position, created by
/// [`Continuum::node_iter`].
///
/// It advances one ring point per step and wraps around. On a non-empty continuum it
/// therefore never ends, and the same address can be yielded several times in a row.
pub struct NodeIterator<'a> {
    idx: usize,
    continuum: &'a Continuum,
}

impl<'a> Iterator for NodeIterator<'a> {
    type Item = &'a SocketAddr;

    fn next(&mut self) -> Option<&'a SocketAddr> {
        let continuum = self.continuum;
        continuum.get_addr(&mut self.idx)
    }
}
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::path::Path;

    use super::{Bucket, Continuum};

    /// Parses a `host:port` literal; panics on malformed test input.
    fn get_sockaddr(ip: &str) -> SocketAddr {
        ip.parse().unwrap()
    }

    /// Builds a continuum in which every host in `hosts` has the same `weight`.
    fn continuum_of(hosts: &[&str], weight: u32) -> Continuum {
        let buckets: Vec<Bucket> = hosts
            .iter()
            .map(|h| Bucket::new(get_sockaddr(h), weight))
            .collect();
        Continuum::new(&buckets)
    }

    /// Asserts that iterating from `key` first yields exactly the addresses in `expected`.
    fn assert_iter_prefix(c: &Continuum, key: &[u8], expected: &[&str]) {
        let mut iter = c.node_iter(key);
        for want in expected {
            assert_eq!(iter.next(), Some(&get_sockaddr(want)));
        }
    }

    #[test]
    fn consistency_after_adding_host() {
        // Keys "a" and "b" must keep their owners when an 11th host joins the ring.
        fn assert_hosts(c: &Continuum) {
            assert_eq!(c.node(b"a"), Some(get_sockaddr("127.0.0.10:6443")));
            assert_eq!(c.node(b"b"), Some(get_sockaddr("127.0.0.5:6443")));
        }

        for end in [11, 12] {
            let buckets: Vec<_> = (1..end)
                .map(|u| Bucket::new(get_sockaddr(&format!("127.0.0.{u}:6443")), 1))
                .collect();
            assert_hosts(&Continuum::new(&buckets));
        }
    }

    #[test]
    fn matches_nginx_sample() {
        let c = continuum_of(&["127.0.0.1:7777", "127.0.0.1:7778"], 1);
        let cases = [
            ("/some/path", "127.0.0.1:7778"),
            ("/some/longer/path", "127.0.0.1:7777"),
            ("/sad/zaidoon", "127.0.0.1:7778"),
            ("/g", "127.0.0.1:7777"),
            ("/pingora/team/is/cool/and/this/is/a/long/uri", "127.0.0.1:7778"),
            ("/i/am/not/confident/in/this/code", "127.0.0.1:7777"),
        ];
        for (key, want) in cases {
            assert_eq!(c.node(key.as_bytes()), Some(get_sockaddr(want)));
        }
    }

    #[test]
    fn matches_nginx_sample_data() {
        let hosts = [
            "10.0.0.1:443",
            "10.0.0.2:443",
            "10.0.0.3:443",
            "10.0.0.4:443",
            "10.0.0.5:443",
            "10.0.0.6:443",
            "10.0.0.7:443",
            "10.0.0.8:443",
            "10.0.0.9:443",
        ];
        let c = continuum_of(&hosts, 100);

        let path = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("test-data")
            .join("sample-nginx-upstream.csv");
        let mut rdr = csv::ReaderBuilder::new()
            .has_headers(false)
            .from_path(path)
            .unwrap();

        // Each record is `uri,expected_upstream`.
        for record in rdr.records() {
            let record = record.unwrap();
            let (uri, upstream) = (record.get(0).unwrap(), record.get(1).unwrap());
            assert_eq!(c.node(uri.as_bytes()).unwrap(), get_sockaddr(upstream));
        }
    }

    #[test]
    fn node_iter() {
        let c = continuum_of(&["127.0.0.1:7777", "127.0.0.1:7778", "127.0.0.1:7779"], 1);
        assert_iter_prefix(
            &c,
            b"doghash",
            &[
                "127.0.0.1:7778",
                "127.0.0.1:7779",
                "127.0.0.1:7779",
                "127.0.0.1:7777",
                "127.0.0.1:7777",
                "127.0.0.1:7778",
                "127.0.0.1:7778",
                "127.0.0.1:7779",
            ],
        );

        let c = continuum_of(&["127.0.0.1:7777", "127.0.0.1:7779"], 1);
        assert_iter_prefix(
            &c,
            b"doghash",
            &[
                "127.0.0.1:7779",
                "127.0.0.1:7779",
                "127.0.0.1:7777",
                "127.0.0.1:7777",
                "127.0.0.1:7779",
            ],
        );

        // A full lap over the ring must bring the cursor back to where it started.
        let c = continuum_of(&["127.0.0.1:7777"], 1);
        let mut iter = c.node_iter(b"doghash");
        let start_idx = iter.idx;
        for _ in 0..c.ring.len() {
            assert!(iter.next().is_some());
        }
        assert_eq!(start_idx, iter.idx);
    }

    #[test]
    fn test_empty() {
        let c = Continuum::new(&[]);
        assert!(c.node(b"doghash").is_none());

        let mut iter = c.node_iter(b"doghash");
        for _ in 0..3 {
            assert!(iter.next().is_none());
        }
    }

    #[test]
    fn test_ipv6_ring() {
        let c = continuum_of(&["[::1]:7777", "[::1]:7778", "[::1]:7779"], 1);
        assert_iter_prefix(
            &c,
            b"doghash",
            &[
                "[::1]:7777",
                "[::1]:7778",
                "[::1]:7777",
                "[::1]:7778",
                "[::1]:7778",
                "[::1]:7777",
                "[::1]:7779",
            ],
        );
    }
}