use crate::{endpoints::EndpointGroup, error::Trace, hash::thread_local_xxhash};
use junction_api::backend::{Backend, LbPolicy, RequestHashPolicy, RequestHasher, RingHashParams};
use smol_str::ToSmolStr;
use std::{
net::SocketAddr,
sync::{
atomic::{AtomicUsize, Ordering},
RwLock,
},
};
/// A backend's configuration together with its load-balancing state.
#[derive(Debug)]
pub struct BackendLb {
    /// The backend configuration.
    pub config: Backend,
    /// Load-balancing state used to pick endpoints for this backend.
    // NOTE(review): presumably built with `LoadBalancer::from_config` from this
    // backend's LB policy — confirm at the construction site.
    pub load_balancer: LoadBalancer,
}
/// The load-balancing algorithm, and its runtime state, for a backend.
#[derive(Debug)]
pub enum LoadBalancer {
    /// Rotate through the endpoints using a shared counter.
    RoundRobin(RoundRobinLb),
    /// Hash each request onto a consistent-hash ring of endpoints.
    RingHash(RingHashLb),
}
impl LoadBalancer {
    /// Choose an address from `endpoints` for a single request.
    ///
    /// Dispatches to the configured algorithm:
    /// * round-robin receives `previous_addrs` and ignores the request itself;
    /// * ring-hash hashes `url` and `headers` using the hash policies stored in
    ///   its own configuration.
    ///
    /// Returns `None` when the selected algorithm cannot produce an endpoint.
    pub(crate) fn load_balance<'e>(
        &self,
        trace: &mut Trace,
        endpoints: &'e EndpointGroup,
        url: &crate::Url,
        headers: &http::HeaderMap,
        previous_addrs: &[SocketAddr],
    ) -> Option<&'e SocketAddr> {
        match self {
            Self::RingHash(ring_hash) => {
                let policies = &ring_hash.config.hash_params;
                ring_hash.pick_endpoint(trace, endpoints, url, headers, policies)
            }
            Self::RoundRobin(round_robin) => {
                round_robin.pick_endpoint(trace, endpoints, previous_addrs)
            }
        }
    }
}
impl LoadBalancer {
    /// Build a fresh load balancer for the given policy.
    ///
    /// Ring-hash keeps a copy of its parameters. Both an explicit round-robin
    /// policy and an unspecified policy produce a round-robin balancer whose
    /// counter starts at zero.
    pub(crate) fn from_config(config: &LbPolicy) -> Self {
        match config {
            LbPolicy::RingHash(params) => Self::RingHash(RingHashLb::new(params)),
            LbPolicy::RoundRobin | LbPolicy::Unspecified => {
                Self::RoundRobin(RoundRobinLb::default())
            }
        }
    }
}
/// Round-robin load balancer state.
#[derive(Debug, Default)]
pub struct RoundRobinLb {
    /// Pick counter, incremented atomically on every pick (wrapping on
    /// overflow) and reduced modulo the endpoint count to choose an endpoint.
    idx: AtomicUsize,
}
impl RoundRobinLb {
    /// Pick the next endpoint in rotation.
    ///
    /// A shared atomic counter is advanced on every call and reduced modulo the
    /// number of endpoints currently in `endpoint_group`. An empty group yields
    /// `None` instead of panicking, which matches the ring-hash balancer. The
    /// pick, including a `None` result, is always recorded in `trace`.
    ///
    /// `previous_addrs` is accepted for interface parity but is currently
    /// ignored.
    fn pick_endpoint<'e>(
        &self,
        trace: &mut Trace,
        endpoint_group: &'e EndpointGroup,
        previous_addrs: &[SocketAddr],
    ) -> Option<&'e SocketAddr> {
        let _ = previous_addrs;

        // `checked_rem` yields `None` for an empty group, where a plain `%`
        // would panic with a division by zero.
        let idx = self
            .idx
            .fetch_add(1, Ordering::SeqCst)
            .checked_rem(endpoint_group.len());
        let addr = idx.and_then(|idx| endpoint_group.nth(idx));

        trace.load_balance("ROUND_ROBIN", addr, Vec::new());
        addr
    }
}
/// Ring-hash (consistent hashing) load balancer.
#[derive(Debug)]
pub struct RingHashLb {
    /// Ring parameters: the minimum ring size and the request hash policies.
    config: RingHashParams,
    /// The hash ring. It is rebuilt lazily on a pick whenever the endpoint
    /// group's hash no longer matches the one the ring was built from.
    ring: RwLock<Ring>,
}
/// A single point on the hash ring.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RingEntry {
    /// Position of this point on the ring.
    hash: u64,
    /// Index of the endpoint this point maps to, as produced by enumerating
    /// `EndpointGroup::iter` at rebuild time.
    idx: usize,
}
impl RingHashLb {
    /// Create a ring-hash balancer with an empty ring.
    ///
    /// The ring is populated lazily on the first pick. Space for
    /// `min_ring_size` entries is reserved up front.
    fn new(config: &RingHashParams) -> Self {
        Self {
            config: config.clone(),
            ring: RwLock::new(Ring {
                eg_hash: 0,
                entries: Vec::with_capacity(config.min_ring_size as usize),
            }),
        }
    }

    /// Pick an endpoint by hashing the request onto the ring.
    ///
    /// The request hash is computed from `hash_params` (see `hash_request`).
    /// When no policy produces a value, a random hash is used instead. The
    /// chosen hash and address are recorded in `trace`.
    ///
    /// Returns `None` if the ring is empty or the picked index is not present
    /// in `endpoints`.
    fn pick_endpoint<'e>(
        &self,
        trace: &mut Trace,
        endpoints: &'e EndpointGroup,
        url: &crate::Url,
        headers: &http::HeaderMap,
        hash_params: &Vec<RequestHashPolicy>,
    ) -> Option<&'e SocketAddr> {
        let request_hash =
            hash_request(hash_params, url, headers).unwrap_or_else(crate::rand::random);
        let endpoint_idx = self.with_ring(endpoints, |r| r.pick(request_hash))?;
        let addr = endpoints.nth(endpoint_idx);
        trace.load_balance("RING_HASH", addr, vec![("hash", request_hash.to_smolstr())]);
        addr
    }

    /// Run `cb` against a ring that is up to date for `endpoint_group`.
    ///
    /// The fast path only takes the read lock. If the ring is stale, the read
    /// lock is released and the write lock is taken. Staleness is checked
    /// again under the write lock, so when several threads race to rebuild,
    /// only the first one does the work.
    ///
    /// # Panics
    ///
    /// Panics if the ring lock has been poisoned.
    fn with_ring<F, T>(&self, endpoint_group: &EndpointGroup, cb: F) -> T
    where
        F: FnOnce(&Ring) -> T,
    {
        {
            let ring = self.ring.read().unwrap();
            if Self::ring_is_current(&ring, endpoint_group) {
                return cb(&ring);
            }
        }

        let mut ring = self.ring.write().unwrap();
        // Another thread may have rebuilt the ring between our read and write
        // lock acquisitions; skip the (sort-dominated) rebuild in that case.
        if !Self::ring_is_current(&ring, endpoint_group) {
            ring.rebuild(self.config.min_ring_size as usize, endpoint_group);
        }
        cb(&ring)
    }

    /// Whether `ring` was built from `endpoint_group`.
    ///
    /// Besides comparing group hashes, an empty ring is treated as stale for a
    /// non-empty group. A rebuild from a non-empty group always produces
    /// entries, so that state only occurs before the first build. This
    /// prevents the initial `eg_hash` of 0 from being mistaken for a real
    /// group that happens to hash to 0.
    fn ring_is_current(ring: &Ring, endpoint_group: &EndpointGroup) -> bool {
        ring.eg_hash == endpoint_group.hash
            && !(ring.entries.is_empty() && endpoint_group.len() > 0)
    }
}
/// A consistent-hash ring built from an endpoint group.
#[derive(Debug)]
struct Ring {
    /// `EndpointGroup::hash` of the group this ring was last rebuilt from.
    eg_hash: u64,
    /// Ring points, sorted by `hash` after each rebuild.
    entries: Vec<RingEntry>,
}
impl Ring {
    /// Rebuild the ring from scratch for `endpoint_group`.
    ///
    /// Every endpoint gets the same number of points: `min_size` divided by the
    /// endpoint count, rounded up, and never fewer than one. Point `i` of an
    /// endpoint is placed at the hash of `(endpoint, i)`, and its `idx` is the
    /// endpoint's position in `endpoint_group.iter()`. Entries end up sorted by
    /// hash, and the group's hash is recorded in `eg_hash`.
    fn rebuild(&mut self, min_size: usize, endpoint_group: &EndpointGroup) {
        let n_endpoints = endpoint_group.len();
        let per_endpoint = ((min_size as f64 / n_endpoints as f64).ceil() as usize).max(1);

        self.entries.clear();
        self.entries.reserve(per_endpoint * n_endpoints);
        self.entries.extend(
            endpoint_group
                .iter()
                .enumerate()
                .flat_map(|(idx, endpoint)| {
                    (0..per_endpoint).map(move |i| RingEntry {
                        hash: thread_local_xxhash::hash(&(endpoint, i)),
                        idx,
                    })
                }),
        );
        self.eg_hash = endpoint_group.hash;
        // Stable sort: entries with identical hashes keep their insertion order.
        self.entries.sort_by_key(|entry| entry.hash);
    }

    /// Map a request hash to an endpoint index.
    ///
    /// Selects the first entry whose hash is `>= endpoint_hash`. If
    /// `endpoint_hash` lies past the last point, it wraps around to the first
    /// entry. Returns `None` only when the ring has no entries.
    fn pick(&self, endpoint_hash: u64) -> Option<usize> {
        let at_or_after = self
            .entries
            .partition_point(|entry| entry.hash < endpoint_hash);
        self.entries
            .get(at_or_after)
            .or_else(|| self.entries.first())
            .map(|entry| entry.idx)
    }
}
/// Compute a hash for a request from an ordered list of hash policies.
///
/// Policies are evaluated in order. Each policy that yields a value is folded
/// into the running hash as `hash.rotate_left(1) ^ new_hash`; the first value
/// is used as-is. Evaluation stops after a `terminal` policy that produced a
/// value. A terminal policy that produced nothing does not stop evaluation.
///
/// Returns `None` if no policy produced a value.
pub(crate) fn hash_request(
    hash_policies: &[RequestHashPolicy],
    url: &crate::Url,
    headers: &http::HeaderMap,
) -> Option<u64> {
    let mut hash: Option<u64> = None;
    for hash_policy in hash_policies {
        if let Some(new_hash) = hash_component(hash_policy, url, headers) {
            hash = Some(hash.map_or(new_hash, |acc| acc.rotate_left(1) ^ new_hash));
            if hash_policy.terminal {
                break;
            }
        }
    }
    hash
}
/// Compute the hash contribution of a single request hash policy.
///
/// * `Header`: hashes every value of the named header. Values are sorted
///   byte-wise first, so repeated headers hash the same regardless of order.
///   Returns `None` when the header is absent.
/// * `QueryParam`: hashes every value of the named query parameter, in the
///   order they appear in the query string. Returns `None` when the URL has no
///   query string or the parameter does not occur in it.
fn hash_component(
    policy: &RequestHashPolicy,
    url: &crate::Url,
    headers: &http::HeaderMap,
) -> Option<u64> {
    match &policy.hasher {
        RequestHasher::Header { name } => {
            let mut header_values: Vec<_> = headers
                .get_all(name)
                .iter()
                .map(http::HeaderValue::as_bytes)
                .collect();
            if header_values.is_empty() {
                return None;
            }
            header_values.sort();
            Some(thread_local_xxhash::hash_iter(header_values))
        }
        RequestHasher::QueryParam { name } => {
            let query = url.query()?;
            let mut matching_vals = form_urlencoded::parse(query.as_bytes())
                .filter_map(|(param, value)| (&param == name).then_some(value))
                .peekable();
            // Mirror the header hasher: a missing parameter produces no hash,
            // not the hash of an empty sequence. Otherwise every request
            // without the parameter would land on the same endpoint, and a
            // terminal policy would stop evaluation.
            matching_vals.peek()?;
            Some(thread_local_xxhash::hash_iter(matching_vals))
        }
    }
}
#[cfg(test)]
mod test_ring_hash {
    use crate::endpoints::Locality;
    use super::*;

    // With a minimum ring size of 0 each endpoint gets exactly one point.
    // Adding a third endpoint must leave the points of the first two unchanged,
    // which is the consistent-hashing property this ring relies on.
    #[test]
    fn test_rebuild_ring() {
        let mut ring = Ring {
            eg_hash: 0,
            entries: Vec::new(),
        };
        ring.rebuild(0, &endpoint_group(123, ["1.1.1.1:80", "1.1.1.2:80"]));
        assert_eq!(ring.eg_hash, 123);
        assert_eq!(ring.entries.len(), 2);
        assert_eq!(ring_indexes(&ring), (0..2).collect::<Vec<_>>());
        assert_hashes_unique(&ring);
        let first_ring = ring.entries.clone();
        ring.rebuild(
            0,
            &endpoint_group(123, ["1.1.1.1:80", "1.1.1.2:80", "1.1.1.3:80"]),
        );
        assert_eq!(ring.eg_hash, 123);
        assert_eq!(ring.entries.len(), 3);
        assert_eq!(ring_indexes(&ring), (0..3).collect::<Vec<_>>());
        assert_hashes_unique(&ring);
        // Dropping the new endpoint's point must give back the first ring.
        let second_ring: Vec<_> = ring
            .entries
            .iter()
            .filter(|e| e.idx != 2)
            .cloned()
            .collect();
        assert_eq!(first_ring, second_ring);
    }

    // A minimum size of 1024 over 3 endpoints rounds up to 342 points each,
    // so the ring holds 3 * 342 = 1026 entries split evenly.
    #[test]
    fn test_rebuild_ring_min_size() {
        let mut ring = Ring {
            eg_hash: 0,
            entries: Vec::new(),
        };
        ring.rebuild(
            1024,
            &endpoint_group(123, ["1.1.1.1:80", "1.1.1.2:80", "1.1.1.3:80"]),
        );
        assert_eq!(ring.entries.len(), 1026);
        let mut counts = [0usize; 3];
        for entry in &ring.entries {
            counts[entry.idx] += 1;
        }
        assert!(counts.iter().all(|&c| c == 342));
        assert_hashes_unique(&ring);
    }

    // Hashes at or below the first point, and past the last point (wrapping
    // around), map to the first entry. Hashes just below or equal to the last
    // point map to the last entry.
    #[test]
    fn test_pick() {
        let mut ring = Ring {
            eg_hash: 0,
            entries: vec![],
        };
        ring.rebuild(
            0,
            &EndpointGroup::new(
                [(
                    Locality::Unknown,
                    vec![
                        "1.1.1.1:80".parse().unwrap(),
                        "1.1.1.2:80".parse().unwrap(),
                        "1.1.1.3:80".parse().unwrap(),
                    ],
                )]
                .into(),
            ),
        );
        let hashes_to_first = [0, ring.entries[0].hash, ring.entries[2].hash + 1];
        for hash in hashes_to_first {
            assert_eq!(ring.pick(hash), Some(ring.entries[0].idx),)
        }
        let hashes_to_last = [ring.entries[2].hash - 1, ring.entries[2].hash];
        for hash in hashes_to_last {
            assert_eq!(ring.pick(hash), Some(ring.entries[2].idx));
        }
    }

    // Endpoint indexes of every ring entry, sorted ascending.
    fn ring_indexes(r: &Ring) -> Vec<usize> {
        let mut indexes: Vec<_> = r.entries.iter().map(|e| e.idx).collect();
        indexes.sort();
        indexes
    }

    // Build a single-locality endpoint group and override its hash so the test
    // controls the value the ring records in `eg_hash`.
    fn endpoint_group(hash: u64, addrs: impl IntoIterator<Item = &'static str>) -> EndpointGroup {
        let addrs = addrs.into_iter().map(|s| s.parse().unwrap()).collect();
        let mut eg = EndpointGroup::new([(Locality::Unknown, addrs)].into());
        eg.hash = hash;
        eg
    }

    // Assert that no two ring entries share the same hash.
    fn assert_hashes_unique(r: &Ring) {
        let mut hashes: Vec<_> = r.entries.iter().map(|e| e.hash).collect();
        hashes.sort();
        hashes.dedup();
        assert_eq!(hashes.len(), r.entries.len());
    }
}