use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::{ProxyError, Result, UpstreamPool, UpstreamState};
#[derive(Debug, Clone)]
pub struct LoadBalancer {
pool: UpstreamPool,
slots: Arc<Vec<usize>>,
counter: Arc<AtomicUsize>,
}
impl LoadBalancer {
pub fn new(pool: UpstreamPool) -> Self {
let slots = pool
.all()
.iter()
.enumerate()
.flat_map(|(idx, backend)| std::iter::repeat_n(idx, backend.weight() as usize))
.collect::<Vec<usize>>();
Self {
pool,
slots: Arc::new(slots),
counter: Arc::new(AtomicUsize::new(0)),
}
}
pub fn next(&self) -> Result<UpstreamState> {
let slots = self.slots.len();
if slots == 0 {
return Err(ProxyError::NoHealthyUpstream);
}
let start = self.counter.fetch_add(1, Ordering::Relaxed);
let backends = self.pool.all();
(0..slots)
.map(|offset| {
let slot_idx = (start + offset) % slots;
&backends[self.slots[slot_idx]]
})
.find(|backend| backend.is_healthy())
.cloned()
.ok_or(ProxyError::NoHealthyUpstream)
}
pub fn pool(&self) -> &UpstreamPool {
&self.pool
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ValidatedUpstream;
fn make_pool(specs: &[(&str, u32)]) -> UpstreamPool {
let validated = specs
.iter()
.map(|(addr, weight)| ValidatedUpstream {
uri: addr.parse().unwrap(),
weight: *weight,
})
.collect::<Vec<ValidatedUpstream>>();
UpstreamPool::from_validated(&validated)
}
#[test]
fn single_backend_always_selected() {
let pool = make_pool(&[("http://b1:3000", 1)]);
let balancer = LoadBalancer::new(pool);
for _ in 0..10 {
let selected = balancer.next().unwrap();
assert_eq!(
selected.uri(),
&"http://b1:3000".parse::<hyper::Uri>().unwrap()
);
}
}
#[test]
fn equal_weight_round_robins() {
let pool = make_pool(&[("http://b1:3000", 1), ("http://b2:3000", 1)]);
let balancer = LoadBalancer::new(pool);
let first = balancer.next().unwrap();
let second = balancer.next().unwrap();
let third = balancer.next().unwrap();
assert_ne!(first.uri(), second.uri());
assert_eq!(first.uri(), third.uri());
}
#[test]
fn weighted_distribution_respects_weights() {
let pool = make_pool(&[("http://b1:3000", 3), ("http://b2:3000", 1)]);
let balancer = LoadBalancer::new(pool);
let mut b1_count = 0u32;
let mut b2_count = 0u32;
let b1_uri = "http://b1:3000".parse::<hyper::Uri>().unwrap();
for _ in 0..400 {
let selected = balancer.next().unwrap();
if *selected.uri() == b1_uri {
b1_count += 1;
} else {
b2_count += 1;
}
}
assert_eq!(b1_count, 300);
assert_eq!(b2_count, 100);
}
#[test]
fn skips_unhealthy_backends() {
let pool = make_pool(&[("http://b1:3000", 1), ("http://b2:3000", 1)]);
let balancer = LoadBalancer::new(pool);
balancer.pool().all()[0].mark_unhealthy();
for _ in 0..10 {
let selected = balancer.next().unwrap();
assert_eq!(
selected.uri(),
&"http://b2:3000".parse::<hyper::Uri>().unwrap()
);
}
}
#[test]
fn all_unhealthy_returns_error() {
let pool = make_pool(&[("http://b1:3000", 1), ("http://b2:3000", 1)]);
let balancer = LoadBalancer::new(pool);
balancer.pool().all()[0].mark_unhealthy();
balancer.pool().all()[1].mark_unhealthy();
let result = balancer.next();
assert!(result.is_err());
}
#[test]
fn recovery_after_mark_healthy() {
let pool = make_pool(&[("http://b1:3000", 1), ("http://b2:3000", 1)]);
let balancer = LoadBalancer::new(pool);
balancer.pool().all()[0].mark_unhealthy();
balancer.pool().all()[1].mark_unhealthy();
assert!(balancer.next().is_err());
balancer.pool().all()[0].mark_healthy();
let selected = balancer.next().unwrap();
assert_eq!(
selected.uri(),
&"http://b1:3000".parse::<hyper::Uri>().unwrap()
);
}
}