use dashmap::DashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, warn};
/// Number of consecutive failed probes at which the background health checker
/// marks a backend unhealthy.
const UNHEALTHY_THRESHOLD: u64 = 3;
/// Policy a [`BackendGroup`] applies when choosing a backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LbStrategy {
    /// Rotate through the group's backends in order, skipping unhealthy ones.
    RoundRobin,
    /// Pick the healthy backend with the fewest tracked active connections.
    LeastConnections,
}
/// Health verdict stored on a [`Backend`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    /// The backend may be returned by selection.
    Healthy,
    /// The backend is skipped by every selection strategy.
    Unhealthy,
}
/// One upstream endpoint together with its load and health bookkeeping.
///
/// Every mutable field uses interior mutability, so all updates go through
/// `&self`.
pub struct Backend {
    /// Socket address of the endpoint.
    pub addr: SocketAddr,
    /// Count of connections currently tracked through [`ConnectionGuard`]s.
    active_connections: AtomicU64,
    /// Current health verdict.
    health: std::sync::RwLock<HealthStatus>,
    /// Failures recorded since the counter was last reset.
    consecutive_failures: AtomicU64,
}
/// Formats a backend by reading the current value of each field.
impl std::fmt::Debug for Backend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Read each field into a plain value first so the formatter only
        // deals with ordinary `Debug` types.
        let connections = self.active_connections.load(Ordering::Relaxed);
        let health = *self.health.read().unwrap();
        let failures = self.consecutive_failures.load(Ordering::Relaxed);

        let mut out = f.debug_struct("Backend");
        out.field("addr", &self.addr);
        out.field("active_connections", &connections);
        out.field("health", &health);
        out.field("consecutive_failures", &failures);
        out.finish()
    }
}
impl Backend {
#[must_use]
pub fn new(addr: SocketAddr) -> Self {
Self {
addr,
active_connections: AtomicU64::new(0),
health: std::sync::RwLock::new(HealthStatus::Healthy),
consecutive_failures: AtomicU64::new(0),
}
}
pub fn track_connection(self: &Arc<Self>) -> ConnectionGuard {
self.active_connections.fetch_add(1, Ordering::Relaxed);
ConnectionGuard {
backend: Arc::clone(self),
}
}
pub fn active_connections(&self) -> u64 {
self.active_connections.load(Ordering::Relaxed)
}
pub fn is_healthy(&self) -> bool {
*self.health.read().unwrap() == HealthStatus::Healthy
}
pub fn set_healthy(&self) {
*self.health.write().unwrap() = HealthStatus::Healthy;
}
pub fn set_unhealthy(&self) {
*self.health.write().unwrap() = HealthStatus::Unhealthy;
}
pub fn record_failure(&self) {
self.consecutive_failures.fetch_add(1, Ordering::Relaxed);
}
pub fn reset_failures(&self) {
self.consecutive_failures.store(0, Ordering::Relaxed);
}
pub fn consecutive_failures(&self) -> u64 {
self.consecutive_failures.load(Ordering::Relaxed)
}
}
/// Token for one tracked connection to a [`Backend`].
///
/// The backend's active-connection count is decremented when the guard is
/// dropped, so the guard must be kept alive for as long as the connection is
/// in use. Discarding it right away would make the connection invisible to
/// least-connections selection, hence the `#[must_use]`.
#[must_use = "dropping the guard immediately stops counting the connection"]
pub struct ConnectionGuard {
    /// Backend whose counter this guard decrements on drop.
    backend: Arc<Backend>,
}
/// Dropping a guard subtracts one from its backend's active-connection count.
impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let counter = &self.backend.active_connections;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}
/// The backends of one service plus the policy used to choose among them.
pub struct BackendGroup {
    /// Backends in the order round-robin walks them.
    pub backends: Vec<Arc<Backend>>,
    /// Policy applied by [`BackendGroup::select`].
    pub strategy: LbStrategy,
    /// Running counter that determines where round-robin starts scanning.
    rr_counter: AtomicUsize,
}
impl BackendGroup {
    /// Creates an empty group that selects backends using `strategy`.
    #[must_use]
    pub fn new(strategy: LbStrategy) -> Self {
        Self {
            backends: Vec::new(),
            strategy,
            rr_counter: AtomicUsize::new(0),
        }
    }

    /// Picks a healthy backend according to the group's strategy.
    ///
    /// Returns `None` when the group is empty or no backend is healthy.
    pub fn select(&self) -> Option<Arc<Backend>> {
        // The round-robin path takes a modulus by the backend count, so the
        // empty case has to be answered before dispatching.
        if self.backends.is_empty() {
            return None;
        }
        match self.strategy {
            LbStrategy::RoundRobin => self.select_round_robin(),
            LbStrategy::LeastConnections => self.select_least_connections(),
        }
    }

    /// Advances the rotation counter and returns the first healthy backend at
    /// or after the resulting position, wrapping around the list once.
    ///
    /// Must only be called with a non-empty backend list (see [`Self::select`]).
    fn select_round_robin(&self) -> Option<Arc<Backend>> {
        let len = self.backends.len();
        let start = self.rr_counter.fetch_add(1, Ordering::Relaxed) % len;
        self.backends
            .iter()
            .cycle()
            .skip(start)
            .take(len)
            .find(|b| b.is_healthy())
            .cloned()
    }

    /// Returns the healthy backend with the fewest active connections; ties go
    /// to the backend that appears first in the list.
    fn select_least_connections(&self) -> Option<Arc<Backend>> {
        self.backends
            .iter()
            .filter(|b| b.is_healthy())
            .min_by_key(|b| b.active_connections())
            .cloned()
    }

    /// Replaces the backend list with `addrs`, in the given order.
    ///
    /// A backend whose address is already in the group is carried over as the
    /// same `Arc`, so its connection count, health and failure counter
    /// survive the update. Addresses not seen before get a fresh backend.
    ///
    /// Repeated addresses in `addrs` are collapsed to their first occurrence.
    /// Without that, a repeated new address would produce several independent
    /// `Backend` objects for one endpoint, which disagrees with
    /// [`Self::add_backend`] refusing duplicates.
    pub fn update_backends(&mut self, addrs: Vec<SocketAddr>) {
        let mut seen = std::collections::HashSet::with_capacity(addrs.len());
        let mut next = Vec::with_capacity(addrs.len());
        for addr in addrs {
            // `insert` returns false when the address was already handled.
            if !seen.insert(addr) {
                continue;
            }
            let backend = self
                .backends
                .iter()
                .find(|b| b.addr == addr)
                .map_or_else(|| Arc::new(Backend::new(addr)), Arc::clone);
            next.push(backend);
        }
        self.backends = next;
    }

    /// Appends a fresh backend for `addr` unless the address is already present.
    pub fn add_backend(&mut self, addr: SocketAddr) {
        if !self.backends.iter().any(|b| b.addr == addr) {
            self.backends.push(Arc::new(Backend::new(addr)));
        }
    }

    /// Removes every backend whose address equals `addr`.
    pub fn remove_backend(&mut self, addr: &SocketAddr) {
        self.backends.retain(|b| b.addr != *addr);
    }
}
/// Plain-data copy of one backend's state at the moment it was captured.
#[derive(Debug, Clone)]
pub struct BackendSnapshot {
    /// Socket address of the backend.
    pub addr: SocketAddr,
    /// Whether the backend was healthy when captured.
    pub healthy: bool,
    /// Active-connection count when captured.
    pub active_connections: u64,
    /// Consecutive-failure count when captured.
    pub consecutive_failures: u64,
}
/// Plain-data copy of a backend group: its strategy and one snapshot per backend.
#[derive(Debug, Clone)]
pub struct BackendGroupSnapshot {
    /// Selection strategy of the group.
    pub strategy: LbStrategy,
    /// Per-backend snapshots, in the group's backend order.
    pub backends: Vec<BackendSnapshot>,
}
/// Registry of backend groups, looked up by service name.
pub struct LoadBalancer {
    /// Backend group for each registered service name.
    groups: DashMap<String, BackendGroup>,
}
impl Default for LoadBalancer {
fn default() -> Self {
Self::new()
}
}
impl LoadBalancer {
#[must_use]
pub fn new() -> Self {
Self {
groups: DashMap::new(),
}
}
pub fn register(&self, service: &str, addrs: Vec<SocketAddr>, strategy: LbStrategy) {
let mut group = BackendGroup::new(strategy);
group.backends = addrs
.into_iter()
.map(|a| Arc::new(Backend::new(a)))
.collect();
self.groups.insert(service.to_string(), group);
}
#[must_use]
pub fn select(&self, service: &str) -> Option<Arc<Backend>> {
self.groups.get(service).and_then(|g| g.select())
}
pub fn update_backends(&self, service: &str, addrs: Vec<SocketAddr>) {
if let Some(mut group) = self.groups.get_mut(service) {
group.update_backends(addrs);
}
}
pub fn unregister(&self, service: &str) {
self.groups.remove(service);
}
pub fn add_backend(&self, service: &str, addr: SocketAddr) {
if let Some(mut group) = self.groups.get_mut(service) {
group.add_backend(addr);
debug!(service = service, backend = %addr, total = group.backends.len(), "Added backend to LB group");
} else {
warn!(service = service, backend = %addr, "Cannot add backend: LB group not registered");
}
}
pub fn remove_backend(&self, service: &str, addr: &SocketAddr) {
if let Some(mut group) = self.groups.get_mut(service) {
group.remove_backend(addr);
}
}
#[must_use]
pub fn backend_count(&self, service: &str) -> usize {
self.groups.get(service).map_or(0, |g| g.backends.len())
}
#[must_use]
pub fn healthy_count(&self, service: &str) -> usize {
self.groups
.get(service)
.map_or(0, |g| g.backends.iter().filter(|b| b.is_healthy()).count())
}
pub fn mark_health(&self, service: &str, addr: &SocketAddr, healthy: bool) {
if let Some(group) = self.groups.get(service) {
if let Some(backend) = group.backends.iter().find(|b| b.addr == *addr) {
if healthy {
backend.set_healthy();
backend.reset_failures();
} else {
backend.set_unhealthy();
backend.record_failure();
}
}
}
}
#[must_use]
pub fn list_service_names(&self) -> Vec<String> {
self.groups.iter().map(|e| e.key().clone()).collect()
}
#[must_use]
pub fn group_snapshot(&self, service: &str) -> Option<BackendGroupSnapshot> {
self.groups.get(service).map(|g| BackendGroupSnapshot {
strategy: g.strategy,
backends: g
.backends
.iter()
.map(|b| BackendSnapshot {
addr: b.addr,
healthy: b.is_healthy(),
active_connections: b.active_connections(),
consecutive_failures: b.consecutive_failures(),
})
.collect(),
})
}
#[must_use]
pub fn spawn_health_checker(
self: &Arc<Self>,
interval: Duration,
timeout: Duration,
) -> tokio::task::JoinHandle<()> {
let lb = Arc::clone(self);
tokio::spawn(async move {
let semaphore = Arc::new(tokio::sync::Semaphore::new(64));
loop {
let backends: Vec<Arc<Backend>> = lb
.groups
.iter()
.flat_map(|entry| entry.value().backends.clone())
.collect();
debug!(
backend_count = backends.len(),
"Starting health check sweep"
);
let mut handles = Vec::with_capacity(backends.len());
for backend in backends {
let sem = Arc::clone(&semaphore);
let probe_timeout = timeout;
handles.push(tokio::spawn(async move {
let _permit = sem.acquire().await.expect("semaphore closed");
let addr = backend.addr;
match tokio::time::timeout(
probe_timeout,
tokio::net::TcpStream::connect(addr),
)
.await
{
Ok(Ok(_stream)) => {
if !backend.is_healthy() {
debug!(%addr, "Backend recovered");
}
backend.set_healthy();
backend.reset_failures();
}
Ok(Err(e)) => {
backend.record_failure();
let failures = backend.consecutive_failures();
if failures >= UNHEALTHY_THRESHOLD {
if backend.is_healthy() {
warn!(
%addr,
error = %e,
failures,
"Backend marked unhealthy after consecutive failures"
);
}
backend.set_unhealthy();
} else {
debug!(
%addr,
error = %e,
failures,
"Health check failed ({failures}/{UNHEALTHY_THRESHOLD} before unhealthy)"
);
}
}
Err(_elapsed) => {
backend.record_failure();
let failures = backend.consecutive_failures();
if failures >= UNHEALTHY_THRESHOLD {
if backend.is_healthy() {
warn!(
%addr,
failures,
"Backend marked unhealthy after consecutive timeout failures"
);
}
backend.set_unhealthy();
} else {
debug!(
%addr,
failures,
"Health check timed out ({failures}/{UNHEALTHY_THRESHOLD} before unhealthy)"
);
}
}
}
}));
}
for handle in handles {
let _ = handle.await;
}
tokio::time::sleep(interval).await;
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a loopback socket address on `port`.
    fn addr(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    /// Wraps a fresh backend for a loopback `port` in an `Arc`.
    fn backend(port: u16) -> Arc<Backend> {
        Arc::new(Backend::new(addr(port)))
    }

    /// Round-robin walks the backends in order and wraps around.
    #[test]
    fn test_round_robin_selection() {
        let mut group = BackendGroup::new(LbStrategy::RoundRobin);
        group.backends = vec![backend(8001), backend(8002), backend(8003)];

        let picked: Vec<SocketAddr> = (0..4).map(|_| group.select().unwrap().addr).collect();

        assert_eq!(
            picked,
            vec![addr(8001), addr(8002), addr(8003), addr(8001)]
        );
    }

    /// Least-connections avoids backends that already carry connections.
    #[test]
    fn test_least_connections_selection() {
        let mut group = BackendGroup::new(LbStrategy::LeastConnections);
        let first = backend(8001);
        let second = backend(8002);
        let third = backend(8003);

        let _busy_first = first.track_connection();
        group.backends = vec![first, Arc::clone(&second), third];

        let choice = group.select().unwrap();
        assert_ne!(choice.addr, addr(8001));
        assert!(choice.addr == addr(8002) || choice.addr == addr(8003));

        let _busy_second = second.track_connection();
        let choice = group.select().unwrap();
        assert_eq!(choice.addr, addr(8003));
    }

    /// An unhealthy backend is never returned by round-robin.
    #[test]
    fn test_unhealthy_backends_skipped() {
        let mut group = BackendGroup::new(LbStrategy::RoundRobin);
        let sick = backend(8002);
        sick.set_unhealthy();
        group.backends = vec![backend(8001), sick, backend(8003)];

        for _ in 0..10 {
            let choice = group.select().unwrap();
            assert_ne!(choice.addr, addr(8002), "Unhealthy backend was selected");
        }
    }

    /// Each guard accounts for exactly one connection until it is dropped.
    #[test]
    fn test_connection_guard_decrement() {
        let target = backend(9000);
        assert_eq!(target.active_connections(), 0);

        let first = target.track_connection();
        assert_eq!(target.active_connections(), 1);
        let second = target.track_connection();
        assert_eq!(target.active_connections(), 2);

        drop(first);
        assert_eq!(target.active_connections(), 1);
        drop(second);
        assert_eq!(target.active_connections(), 0);
    }

    /// Updating the list keeps state for surviving addresses, creates fresh
    /// backends for new ones and drops the rest.
    #[test]
    fn test_update_backends_preserves_state() {
        let mut group = BackendGroup::new(LbStrategy::RoundRobin);
        let kept = backend(8001);
        let dropped = backend(8002);
        let _guard = kept.track_connection();
        dropped.set_unhealthy();
        group.backends = vec![Arc::clone(&kept), Arc::clone(&dropped)];

        group.update_backends(vec![addr(8001), addr(8003)]);
        assert_eq!(group.backends.len(), 2);

        let preserved = group
            .backends
            .iter()
            .find(|b| b.addr == addr(8001))
            .unwrap();
        assert_eq!(preserved.active_connections(), 1);

        let fresh = group
            .backends
            .iter()
            .find(|b| b.addr == addr(8003))
            .unwrap();
        assert_eq!(fresh.active_connections(), 0);
        assert!(fresh.is_healthy());

        assert!(group.backends.iter().all(|b| b.addr != addr(8002)));
    }

    /// With every backend unhealthy, neither strategy yields a backend.
    #[test]
    fn test_all_unhealthy_returns_none() {
        let mut group = BackendGroup::new(LbStrategy::RoundRobin);
        let first = backend(8001);
        let second = backend(8002);
        first.set_unhealthy();
        second.set_unhealthy();
        group.backends = vec![first, second];

        assert!(group.select().is_none());

        group.strategy = LbStrategy::LeastConnections;
        assert!(group.select().is_none());
    }

    /// Registered services can be selected from; unknown ones cannot.
    #[test]
    fn test_register_and_select() {
        let lb = LoadBalancer::new();
        lb.register("web", vec![addr(8080), addr(8081)], LbStrategy::RoundRobin);

        let choice = lb.select("web").unwrap();
        assert!(choice.addr == addr(8080) || choice.addr == addr(8081));
        assert!(lb.select("nonexistent").is_none());
    }

    /// Adding is idempotent per address; removing drops the address.
    #[test]
    fn test_add_remove_backend() {
        let lb = LoadBalancer::new();
        lb.register("api", vec![addr(9001)], LbStrategy::RoundRobin);

        // The map guard is a temporary inside the closure, so it is released
        // before the next mutating call on the load balancer.
        let group_len = |lb: &LoadBalancer| lb.groups.get("api").unwrap().backends.len();

        lb.add_backend("api", addr(9002));
        assert_eq!(group_len(&lb), 2);

        lb.add_backend("api", addr(9002));
        assert_eq!(group_len(&lb), 2);

        lb.remove_backend("api", &addr(9001));
        {
            let group = lb.groups.get("api").unwrap();
            assert_eq!(group.backends.len(), 1);
            assert_eq!(group.backends[0].addr, addr(9002));
        }
    }

    /// After unregistering, the service can no longer be selected.
    #[test]
    fn test_unregister() {
        let lb = LoadBalancer::new();
        lb.register("svc", vec![addr(5000)], LbStrategy::RoundRobin);
        assert!(lb.select("svc").is_some());

        lb.unregister("svc");
        assert!(lb.select("svc").is_none());
    }

    /// Updating through the load balancer replaces the group's backend list.
    #[test]
    fn test_update_backends_via_lb() {
        let lb = LoadBalancer::new();
        lb.register("svc", vec![addr(3000)], LbStrategy::RoundRobin);
        lb.update_backends("svc", vec![addr(3001), addr(3002)]);

        let group = lb.groups.get("svc").unwrap();
        assert_eq!(group.backends.len(), 2);
        for port in [3001u16, 3002].iter() {
            assert!(group.backends.iter().any(|b| b.addr == addr(*port)));
        }
    }

    /// An empty group yields nothing under either strategy.
    #[test]
    fn test_empty_group_returns_none() {
        let round_robin = BackendGroup::new(LbStrategy::RoundRobin);
        assert!(round_robin.select().is_none());

        let least_conn = BackendGroup::new(LbStrategy::LeastConnections);
        assert!(least_conn.select().is_none());
    }

    /// Failures accumulate until they are reset.
    #[test]
    fn test_failure_tracking() {
        let target = Backend::new(addr(7000));
        assert_eq!(target.consecutive_failures(), 0);

        target.record_failure();
        target.record_failure();
        assert_eq!(target.consecutive_failures(), 2);

        target.reset_failures();
        assert_eq!(target.consecutive_failures(), 0);
    }

    /// Health can be toggled in both directions.
    #[test]
    fn test_health_transitions() {
        let target = Backend::new(addr(7001));
        assert!(target.is_healthy());

        target.set_unhealthy();
        assert!(!target.is_healthy());

        target.set_healthy();
        assert!(target.is_healthy());
    }
}