use crate::config::{ServerConfig, Strategy};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
/// An upstream server together with its runtime health flag and in-flight
/// connection count.
///
/// Both pieces of runtime state are atomics, so they can be updated through a
/// shared reference (`&self`).
#[derive(Debug)]
pub struct Backend {
    /// Address of the upstream server.
    pub url: String,
    /// Relative weight of this server.
    pub weight: u32,
    /// Current health flag; starts as `true` and is changed via `set_healthy`.
    healthy: AtomicBool,
    /// Number of connections currently counted against this backend.
    active_connections: AtomicUsize,
}

// All atomics below use `Ordering::Relaxed`: each value is read and written on
// its own and nothing in this type relies on it to order other memory accesses.
impl Backend {
    /// Creates a backend that starts out healthy with zero active connections.
    pub fn new(url: String, weight: u32) -> Self {
        Self {
            url,
            weight,
            healthy: AtomicBool::new(true),
            active_connections: AtomicUsize::new(0),
        }
    }

    /// Returns the current value of the health flag.
    pub fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Overwrites the health flag.
    pub fn set_healthy(&self, healthy: bool) {
        self.healthy.store(healthy, Ordering::Relaxed);
    }

    /// Adds one to the active-connection count.
    pub fn inc_connections(&self) {
        self.active_connections.fetch_add(1, Ordering::Relaxed);
    }

    /// Subtracts one from the active-connection count, saturating at zero.
    ///
    /// A plain `fetch_sub` would wrap to `usize::MAX` on an unbalanced call,
    /// making the backend look maximally loaded from then on.
    pub fn dec_connections(&self) {
        // `checked_sub` yields `None` when the count is already zero; in that
        // case `fetch_update` leaves the value untouched and returns `Err`,
        // which is intentionally ignored.
        let _ = self.active_connections.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |count| count.checked_sub(1),
        );
    }

    /// Returns the current active-connection count.
    pub fn connections(&self) -> usize {
        self.active_connections.load(Ordering::Relaxed)
    }
}
/// Picks a [`Backend`] for each call according to a fixed [`Strategy`],
/// considering only backends whose health flag is set.
pub struct LoadBalancer {
    /// Name given to this balancer at construction.
    pub name: String,
    /// Selection strategy; fixed at construction.
    strategy: Strategy,
    /// All configured backends, healthy or not, in configuration order.
    backends: Vec<Arc<Backend>>,
    /// Call counter driving the round-robin and weighted rotations.
    rr_counter: AtomicUsize,
    /// Sticky-cookie setting; only stored and exposed by this type.
    sticky_cookie: Option<String>,
}

impl LoadBalancer {
    /// Builds a balancer with one [`Backend`] per entry of `servers`, copying
    /// each entry's `url` and `weight`.
    pub fn new(
        name: String,
        strategy: Strategy,
        servers: &[ServerConfig],
        sticky_cookie: Option<String>,
    ) -> Self {
        let backends = servers
            .iter()
            .map(|s| Arc::new(Backend::new(s.url.clone(), s.weight)))
            .collect();
        Self {
            name,
            strategy,
            backends,
            rr_counter: AtomicUsize::new(0),
            sticky_cookie,
        }
    }

    /// Selects the backend that should receive the next request.
    ///
    /// Returns `None` only when no backend is currently healthy.
    ///
    /// * `RoundRobin` rotates through the healthy backends in order.
    /// * `Weighted` rotates through them in proportion to `weight`; if every
    ///   healthy backend has weight zero, the first healthy one is returned.
    /// * `LeastConnections` returns the healthy backend with the fewest
    ///   active connections (the earliest one on ties).
    /// * `Random` picks a healthy backend from a clock-derived value.
    pub fn next_backend(&self) -> Option<Arc<Backend>> {
        // Health flags can be flipped through `&self` while this runs, so take
        // one snapshot and make every decision against it. Re-scanning the
        // list after computing an index could otherwise yield `None` even
        // though healthy backends exist.
        let healthy: Vec<&Arc<Backend>> =
            self.backends.iter().filter(|b| b.is_healthy()).collect();
        if healthy.is_empty() {
            return None;
        }

        let chosen: Option<&Arc<Backend>> = match self.strategy {
            Strategy::RoundRobin => {
                let turn = self.rr_counter.fetch_add(1, Ordering::Relaxed);
                healthy.get(turn % healthy.len()).copied()
            }
            Strategy::Weighted => {
                // Accumulate in u64: a sum of u32 weights can exceed u32::MAX.
                let total_weight: u64 = healthy.iter().map(|b| u64::from(b.weight)).sum();
                if total_weight == 0 {
                    healthy.first().copied()
                } else {
                    // usize is at most 64 bits wide, so this widening is lossless.
                    let turn = self.rr_counter.fetch_add(1, Ordering::Relaxed) as u64;
                    let target = turn % total_weight;
                    let mut cumulative = 0u64;
                    let mut pick = None;
                    for &backend in &healthy {
                        cumulative += u64::from(backend.weight);
                        if target < cumulative {
                            pick = Some(backend);
                            break;
                        }
                    }
                    // `target < total_weight`, so the loop always picks; the
                    // fallback is purely defensive.
                    pick.or_else(|| healthy.last().copied())
                }
            }
            Strategy::LeastConnections => {
                healthy.iter().copied().min_by_key(|b| b.connections())
            }
            Strategy::Random => {
                // The remainder is < healthy.len(), so narrowing back to usize
                // cannot truncate.
                let idx = (Self::clock_entropy() % (healthy.len() as u64)) as usize;
                healthy.get(idx).copied()
            }
        };
        chosen.cloned()
    }

    /// Derives a cheap pseudo-random value from the wall clock.
    ///
    /// The raw nanosecond reading is not used directly: on clocks that tick in
    /// coarse steps (the standard library documents 100 ns intervals on
    /// Windows) it shares factors with common backend counts, so a plain
    /// `nanos % n` would keep landing on the same slot. The SplitMix64
    /// finalizer spreads the reading over all 64 bits first.
    fn clock_entropy() -> u64 {
        // Truncating the u128 to its low 64 bits is intended; those are the
        // bits that change between calls.
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64;
        let mut z = nanos.wrapping_add(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Returns every configured backend, regardless of health.
    pub fn backends(&self) -> &[Arc<Backend>] {
        &self.backends
    }

    /// Returns how many backends currently report healthy.
    pub fn healthy_count(&self) -> usize {
        self.backends.iter().filter(|b| b.is_healthy()).count()
    }

    /// Returns the number of configured backends.
    // Only the unit tests in this file call this; the lint is silenced so the
    // accessor can stay available. NOTE(review): confirm it is still wanted.
    #[allow(dead_code)]
    pub fn total_count(&self) -> usize {
        self.backends.len()
    }

    /// Returns the sticky-cookie setting, if one was configured.
    // Only the unit tests in this file call this; see `total_count`.
    #[allow(dead_code)]
    pub fn sticky_cookie(&self) -> Option<&str> {
        self.sticky_cookie.as_deref()
    }

    /// Returns the strategy this balancer was built with.
    // Not called anywhere in this file. NOTE(review): confirm it is still wanted.
    #[allow(dead_code)]
    pub fn strategy(&self) -> &Strategy {
        &self.strategy
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single server config.
    fn server(url: &str, weight: u32) -> ServerConfig {
        ServerConfig {
            url: url.to_string(),
            weight,
        }
    }

    /// Builds weight-1 server configs for the given URLs, in order.
    fn make_servers(urls: &[&str]) -> Vec<ServerConfig> {
        urls.iter().map(|&url| server(url, 1)).collect()
    }

    /// Two servers weighted 3:1 (`a` gets the larger share).
    fn make_weighted_servers() -> Vec<ServerConfig> {
        vec![server("http://a:8001", 3), server("http://b:8002", 1)]
    }

    /// Builds a balancer named "test" with no sticky cookie.
    fn make_lb(strategy: Strategy, servers: &[ServerConfig]) -> LoadBalancer {
        LoadBalancer::new("test".into(), strategy, servers, None)
    }

    #[test]
    fn test_round_robin_single() {
        let servers = make_servers(&["http://127.0.0.1:8001"]);
        let lb = make_lb(Strategy::RoundRobin, &servers);
        let b = lb.next_backend().unwrap();
        assert_eq!(b.url, "http://127.0.0.1:8001");
    }

    #[test]
    fn test_round_robin_cycles() {
        let servers = make_servers(&["http://a:8001", "http://b:8002", "http://c:8003"]);
        let lb = make_lb(Strategy::RoundRobin, &servers);
        let urls: Vec<String> = (0..6)
            .map(|_| lb.next_backend().unwrap().url.clone())
            .collect();
        assert_eq!(
            urls,
            [
                "http://a:8001",
                "http://b:8002",
                "http://c:8003",
                "http://a:8001",
                "http://b:8002",
                "http://c:8003",
            ]
        );
    }

    #[test]
    fn test_round_robin_skips_unhealthy() {
        let servers = make_servers(&["http://a:8001", "http://b:8002"]);
        let lb = make_lb(Strategy::RoundRobin, &servers);
        lb.backends()[0].set_healthy(false);
        let b = lb.next_backend().unwrap();
        assert_eq!(b.url, "http://b:8002");
    }

    #[test]
    fn test_all_unhealthy_returns_none() {
        let servers = make_servers(&["http://a:8001"]);
        let lb = make_lb(Strategy::RoundRobin, &servers);
        lb.backends()[0].set_healthy(false);
        assert!(lb.next_backend().is_none());
    }

    #[test]
    fn test_weighted_distribution() {
        let servers = make_weighted_servers();
        let lb = make_lb(Strategy::Weighted, &servers);
        let a_count = (0..100)
            .filter(|_| lb.next_backend().unwrap().url == "http://a:8001")
            .count();
        let b_count = 100 - a_count;
        // The weighted rotation is deterministic: 100 calls are 25 full cycles
        // of a 3:1 schedule, so the split is exact rather than approximate.
        assert_eq!(a_count, 75, "a={} b={}", a_count, b_count);
        assert_eq!(b_count, 25, "a={} b={}", a_count, b_count);
    }

    #[test]
    fn test_least_connections() {
        let servers = make_servers(&["http://a:8001", "http://b:8002"]);
        let lb = make_lb(Strategy::LeastConnections, &servers);
        lb.backends()[0].inc_connections();
        lb.backends()[0].inc_connections();
        let b = lb.next_backend().unwrap();
        assert_eq!(b.url, "http://b:8002");
    }

    #[test]
    fn test_random_returns_something() {
        let servers = make_servers(&["http://a:8001", "http://b:8002"]);
        let lb = make_lb(Strategy::Random, &servers);
        assert!(lb.next_backend().is_some());
    }

    #[test]
    fn test_backend_health() {
        let b = Backend::new("http://test:8001".to_string(), 1);
        assert!(b.is_healthy());
        b.set_healthy(false);
        assert!(!b.is_healthy());
        b.set_healthy(true);
        assert!(b.is_healthy());
    }

    #[test]
    fn test_backend_connections() {
        let b = Backend::new("http://test:8001".to_string(), 1);
        assert_eq!(b.connections(), 0);
        b.inc_connections();
        b.inc_connections();
        assert_eq!(b.connections(), 2);
        b.dec_connections();
        assert_eq!(b.connections(), 1);
    }

    #[test]
    fn test_healthy_count() {
        let servers = make_servers(&["http://a:8001", "http://b:8002", "http://c:8003"]);
        let lb = make_lb(Strategy::RoundRobin, &servers);
        assert_eq!(lb.healthy_count(), 3);
        assert_eq!(lb.total_count(), 3);
        lb.backends()[1].set_healthy(false);
        assert_eq!(lb.healthy_count(), 2);
        assert_eq!(lb.total_count(), 3);
    }

    #[test]
    fn test_sticky_cookie() {
        let servers = make_servers(&["http://a:8001"]);
        let lb = LoadBalancer::new(
            "test".into(),
            Strategy::RoundRobin,
            &servers,
            Some("session_id".to_string()),
        );
        assert_eq!(lb.sticky_cookie(), Some("session_id"));

        let lb2 = make_lb(Strategy::RoundRobin, &servers);
        assert_eq!(lb2.sticky_cookie(), None);
    }

    #[test]
    fn test_empty_backends() {
        let lb = make_lb(Strategy::RoundRobin, &[]);
        assert!(lb.next_backend().is_none());
        assert_eq!(lb.healthy_count(), 0);
        assert_eq!(lb.total_count(), 0);
    }

    #[test]
    fn test_weighted_zero_total_weight() {
        let servers = vec![server("http://a:8001", 0), server("http://b:8002", 0)];
        let lb = make_lb(Strategy::Weighted, &servers);
        let b = lb.next_backend();
        assert!(b.is_some());
        assert!(b.unwrap().url.starts_with("http://"));
    }

    #[test]
    fn test_weighted_all_unhealthy() {
        let servers = make_weighted_servers();
        let lb = make_lb(Strategy::Weighted, &servers);
        for b in lb.backends() {
            b.set_healthy(false);
        }
        assert!(lb.next_backend().is_none());
    }

    #[test]
    fn test_round_robin_healthy_skips_all_unhealthy() {
        let servers = make_servers(&["http://a:8001", "http://b:8002", "http://c:8003"]);
        let lb = make_lb(Strategy::RoundRobin, &servers);
        for b in lb.backends() {
            b.set_healthy(false);
        }
        assert!(lb.next_backend().is_none());
    }

    #[test]
    fn test_random_skips_unhealthy() {
        let servers = make_servers(&["http://a:8001", "http://b:8002", "http://c:8003"]);
        let lb = make_lb(Strategy::Random, &servers);
        lb.backends()[0].set_healthy(false);
        lb.backends()[1].set_healthy(false);
        for _ in 0..10 {
            let b = lb.next_backend().unwrap();
            assert_eq!(b.url, "http://c:8003");
        }
    }

    #[test]
    fn test_random_all_unhealthy() {
        let servers = make_servers(&["http://a:8001", "http://b:8002"]);
        let lb = make_lb(Strategy::Random, &servers);
        lb.backends()[0].set_healthy(false);
        lb.backends()[1].set_healthy(false);
        assert!(lb.next_backend().is_none());
    }
}