use async_trait::async_trait;
use rand::seq::IndexedRandom;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, trace, warn};
use xxhash_rust::xxh3::xxh3_64;
use zentinel_common::errors::{ZentinelError, ZentinelResult};
use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
/// Configuration for [`SubsetBalancer`].
///
/// Two balancers given the same targets, `subset_size`, and `proxy_id` choose the same
/// subset, because the choice is a pure function of those inputs.
#[derive(Clone, Debug)]
pub struct SubsetConfig {
    /// Upper bound on how many targets the subset holds (capped at the number of targets).
    pub subset_size: usize,
    /// Identifier mixed into each target's hash score; different ids yield different subsets.
    pub proxy_id: String,
    /// Strategy used to pick among the healthy members of the subset.
    pub inner_algorithm: SubsetInnerAlgorithm,
}
impl Default for SubsetConfig {
    /// A subset of 10 targets, round-robin selection inside the subset, and a randomly
    /// generated `proxy-<u32>` identifier.
    ///
    /// Because the identifier is random, separately created default configs will generally
    /// pick different subsets; set `proxy_id` explicitly if the subset must be stable (for
    /// example across restarts).
    fn default() -> Self {
        let proxy_id = format!("proxy-{}", rand::random::<u32>());
        Self {
            proxy_id,
            inner_algorithm: SubsetInnerAlgorithm::RoundRobin,
            subset_size: 10,
        }
    }
}
/// How a target is chosen among the healthy members of the current subset.
#[derive(Clone, Copy, Debug, Default)]
pub enum SubsetInnerAlgorithm {
    /// Rotate through the healthy members using a shared counter (the default).
    #[default]
    RoundRobin,
    /// Pick a healthy member at random.
    Random,
    /// Pick the healthy member with the fewest outstanding selections.
    LeastConnections,
}
/// Load balancer that routes each proxy to a deterministic subset of all upstream targets.
pub struct SubsetBalancer {
    /// Every configured target; subset rebuilds draw from this list.
    all_targets: Vec<UpstreamTarget>,
    /// The subset currently used for selection.
    subset: Arc<RwLock<Vec<UpstreamTarget>>>,
    /// Round-robin cursor shared by all callers.
    current: AtomicUsize,
    /// Outstanding selections per address (maintained only for `LeastConnections`).
    connections: Arc<RwLock<HashMap<String, usize>>>,
    /// Last reported health per address; addresses missing here are treated as healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Configuration this balancer was built with.
    config: SubsetConfig,
}
impl SubsetBalancer {
    /// Creates a balancer that routes only to a deterministic subset of `targets`.
    ///
    /// Every target starts out healthy with a connection count of zero. The initial subset
    /// is computed from the full target list by [`Self::compute_subset`], so two balancers
    /// built from the same targets and the same `config.proxy_id` choose the same subset.
    pub fn new(targets: Vec<UpstreamTarget>, config: SubsetConfig) -> Self {
        let mut health_status = HashMap::with_capacity(targets.len());
        let mut connections = HashMap::with_capacity(targets.len());
        for target in &targets {
            let addr = target.full_address();
            health_status.insert(addr.clone(), true);
            connections.insert(addr, 0);
        }

        let subset = Self::compute_subset(&targets, &config);

        info!(
            total_targets = targets.len(),
            subset_size = subset.len(),
            proxy_id = %config.proxy_id,
            algorithm = "deterministic_subset",
            "Created subset balancer"
        );
        for target in &subset {
            debug!(
                target = %target.full_address(),
                proxy_id = %config.proxy_id,
                "Target included in subset"
            );
        }

        Self {
            all_targets: targets,
            subset: Arc::new(RwLock::new(subset)),
            current: AtomicUsize::new(0),
            connections: Arc::new(RwLock::new(connections)),
            health_status: Arc::new(RwLock::new(health_status)),
            config,
        }
    }

    /// Deterministically picks up to `config.subset_size` targets out of `targets`.
    ///
    /// Each target is scored by [`Self::subset_score`] and the lowest scores win. The sort
    /// is stable, so targets with equal scores keep their input order. Returns an empty
    /// vector when `targets` is empty or `subset_size` is 0.
    fn compute_subset(targets: &[UpstreamTarget], config: &SubsetConfig) -> Vec<UpstreamTarget> {
        let subset_size = config.subset_size.min(targets.len());
        if subset_size == 0 {
            return Vec::new();
        }

        // Score borrowed targets so only the winners are cloned.
        let mut scored: Vec<(u64, &UpstreamTarget)> = targets
            .iter()
            .map(|t| (Self::subset_score(&t.full_address(), &config.proxy_id), t))
            .collect();
        scored.sort_by_key(|&(score, _)| score);

        scored
            .into_iter()
            .take(subset_size)
            .map(|(_, t)| t.clone())
            .collect()
    }

    /// xxh3-64 hash of `"<target_addr>:<proxy_id>"`, used to rank targets for one proxy.
    fn subset_score(target_addr: &str, proxy_id: &str) -> u64 {
        let combined = format!("{}:{}", target_addr, proxy_id);
        xxh3_64(combined.as_bytes())
    }

    /// Rebuilds the subset when too few of its members are healthy.
    ///
    /// The threshold is half of the *effective* subset size (the configured size capped at
    /// the number of known targets), and never less than one. Addresses without a recorded
    /// status count as healthy.
    ///
    /// NOTE: if fewer targets are healthy fleet-wide than the threshold, the rebuilt subset
    /// still falls short of it, so every call rebuilds until health recovers.
    async fn rebuild_subset_if_needed(&self) {
        // Without the cap, `subset_size / 2` can exceed the number of targets that exist,
        // which made the check fail (and rebuild) on every single call.
        let effective_size = self.config.subset_size.min(self.all_targets.len());
        if effective_size == 0 {
            // No target could ever be selected; a rebuild cannot change that.
            return;
        }
        // At least one, so that a subset with no healthy members is always rebuilt
        // (with `subset_size == 1`, `1 / 2 == 0` would otherwise never trigger).
        let threshold = (effective_size / 2).max(1);

        let healthy_in_subset = {
            let health = self.health_status.read().await;
            let current_subset = self.subset.read().await;
            current_subset
                .iter()
                .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
                .count()
        };

        if healthy_in_subset < threshold {
            self.rebuild_subset().await;
        }
    }

    /// Recomputes the subset from the targets currently considered healthy.
    ///
    /// Addresses without a recorded status count as healthy. If no target is healthy, the
    /// existing subset is left untouched.
    async fn rebuild_subset(&self) {
        let healthy_targets: Vec<UpstreamTarget> = {
            let health = self.health_status.read().await;
            self.all_targets
                .iter()
                .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
                .cloned()
                .collect()
        };

        if healthy_targets.is_empty() {
            return;
        }

        let new_subset = Self::compute_subset(&healthy_targets, &self.config);
        info!(
            new_subset_size = new_subset.len(),
            healthy_total = healthy_targets.len(),
            proxy_id = %self.config.proxy_id,
            algorithm = "deterministic_subset",
            "Rebuilt subset from healthy targets"
        );

        let mut subset = self.subset.write().await;
        *subset = new_subset;
    }

    /// Applies the configured inner algorithm to the already-filtered healthy members.
    ///
    /// Returns `None` only when `healthy` is empty.
    async fn select_from_subset<'a>(
        &self,
        healthy: &[&'a UpstreamTarget],
    ) -> Option<&'a UpstreamTarget> {
        if healthy.is_empty() {
            return None;
        }
        match self.config.inner_algorithm {
            SubsetInnerAlgorithm::RoundRobin => {
                let idx = self.current.fetch_add(1, Ordering::Relaxed) % healthy.len();
                Some(healthy[idx])
            }
            SubsetInnerAlgorithm::Random => {
                // `choose` comes from `rand::seq::IndexedRandom`, imported at file level.
                let mut rng = rand::rng();
                healthy.choose(&mut rng).copied()
            }
            SubsetInnerAlgorithm::LeastConnections => {
                let conns = self.connections.read().await;
                healthy
                    .iter()
                    .min_by_key(|t| conns.get(&t.full_address()).copied().unwrap_or(0))
                    .copied()
            }
        }
    }
}
#[async_trait]
impl LoadBalancer for SubsetBalancer {
    /// Picks a target from the healthy members of this proxy's subset.
    ///
    /// The subset is first rebuilt if too few of its members are healthy. For
    /// `LeastConnections`, the chosen target's connection count is incremented; callers are
    /// expected to pair this with [`LoadBalancer::release`].
    ///
    /// # Errors
    /// Returns `ZentinelError::NoHealthyUpstream` when no subset member is healthy.
    async fn select(&self, _context: Option<&RequestContext>) -> ZentinelResult<TargetSelection> {
        trace!(
            total_targets = self.all_targets.len(),
            algorithm = "deterministic_subset",
            "Selecting upstream target"
        );

        self.rebuild_subset_if_needed().await;

        let health = self.health_status.read().await;
        let subset = self.subset.read().await;
        let healthy_subset: Vec<_> = subset
            .iter()
            .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
            .collect();
        drop(health);

        if healthy_subset.is_empty() {
            warn!(
                subset_size = subset.len(),
                total_targets = self.all_targets.len(),
                proxy_id = %self.config.proxy_id,
                algorithm = "deterministic_subset",
                "No healthy targets in subset"
            );
            drop(subset);
            return Err(ZentinelError::NoHealthyUpstream);
        }

        let target = self
            .select_from_subset(&healthy_subset)
            .await
            .ok_or(ZentinelError::NoHealthyUpstream)?;

        if matches!(
            self.config.inner_algorithm,
            SubsetInnerAlgorithm::LeastConnections
        ) {
            let mut conns = self.connections.write().await;
            *conns.entry(target.full_address()).or_insert(0) += 1;
        }

        trace!(
            selected_target = %target.full_address(),
            subset_size = subset.len(),
            healthy_count = healthy_subset.len(),
            proxy_id = %self.config.proxy_id,
            algorithm = "deterministic_subset",
            "Selected target from subset"
        );

        Ok(TargetSelection {
            address: target.full_address(),
            weight: target.weight,
            metadata: HashMap::new(),
        })
    }

    /// Decrements the connection count for `selection` (only for `LeastConnections`).
    ///
    /// Counts never go below zero, and unknown addresses are ignored.
    async fn release(&self, selection: &TargetSelection) {
        if matches!(
            self.config.inner_algorithm,
            SubsetInnerAlgorithm::LeastConnections
        ) {
            let mut conns = self.connections.write().await;
            if let Some(count) = conns.get_mut(&selection.address) {
                *count = count.saturating_sub(1);
            }
        }
    }

    /// Records the health of `address` and re-evaluates the subset if the status changed.
    ///
    /// An address with no previous record is assumed to have been healthy.
    async fn report_health(&self, address: &str, healthy: bool) {
        // Swap the new status in under one write lock. The previous value compared below is
        // then exactly the one replaced; with separate read and write locks, a concurrent
        // report could land in between and make the "changed" check wrong.
        let prev_health = self
            .health_status
            .write()
            .await
            .insert(address.to_string(), healthy)
            .unwrap_or(true);

        trace!(
            target = %address,
            healthy = healthy,
            prev_healthy = prev_health,
            algorithm = "deterministic_subset",
            "Updating target health status"
        );

        if prev_health != healthy {
            self.rebuild_subset_if_needed().await;
        }
    }

    /// Addresses whose last recorded status is healthy (across all targets, not just the subset).
    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` distinct targets named `backend-0`, `backend-1`, ….
    fn make_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget::new(format!("backend-{}", i), 8080, 100))
            .collect()
    }

    /// Round-robin configuration with the given subset size and proxy id.
    fn round_robin_config(subset_size: usize, proxy_id: &str) -> SubsetConfig {
        SubsetConfig {
            subset_size,
            proxy_id: proxy_id.to_string(),
            inner_algorithm: SubsetInnerAlgorithm::RoundRobin,
        }
    }

    /// Addresses in the balancer's current subset.
    ///
    /// Uses `blocking_read`, which panics inside an async runtime, so call this only from
    /// plain `#[test]` functions.
    fn subset_addresses(balancer: &SubsetBalancer) -> Vec<String> {
        balancer
            .subset
            .blocking_read()
            .iter()
            .map(|t| t.full_address())
            .collect()
    }

    #[test]
    fn test_subset_size_limited() {
        let balancer =
            SubsetBalancer::new(make_targets(100), round_robin_config(10, "test-proxy"));
        assert_eq!(balancer.subset.blocking_read().len(), 10);
    }

    #[test]
    fn test_subset_deterministic() {
        let targets = make_targets(50);
        let balancer1 = SubsetBalancer::new(targets.clone(), round_robin_config(10, "proxy-a"));
        let balancer2 = SubsetBalancer::new(targets, round_robin_config(10, "proxy-a"));
        assert_eq!(subset_addresses(&balancer1), subset_addresses(&balancer2));
    }

    #[test]
    fn test_different_proxies_get_different_subsets() {
        let targets = make_targets(50);
        let balancer1 = SubsetBalancer::new(targets.clone(), round_robin_config(10, "proxy-a"));
        let balancer2 = SubsetBalancer::new(targets, round_robin_config(10, "proxy-b"));
        assert_ne!(subset_addresses(&balancer1), subset_addresses(&balancer2));
    }

    #[tokio::test]
    async fn test_selects_from_subset_only() {
        let balancer =
            SubsetBalancer::new(make_targets(50), round_robin_config(5, "test-proxy"));
        // Inside the runtime, so read the subset asynchronously rather than via `blocking_read`.
        let subset_addrs: Vec<String> = balancer
            .subset
            .read()
            .await
            .iter()
            .map(|t| t.full_address())
            .collect();

        for _ in 0..20 {
            let selection = balancer.select(None).await.unwrap();
            assert!(
                subset_addrs.contains(&selection.address),
                "Selected {} which is not in subset {:?}",
                selection.address,
                subset_addrs
            );
        }
    }

    #[test]
    fn test_even_distribution_across_proxies() {
        let targets = make_targets(100);
        let num_proxies = 100;
        let subset_size = 10;

        let mut backend_counts: HashMap<String, usize> = HashMap::new();
        for i in 0..num_proxies {
            let config = round_robin_config(subset_size, &format!("proxy-{}", i));
            for target in SubsetBalancer::compute_subset(&targets, &config) {
                *backend_counts.entry(target.full_address()).or_insert(0) += 1;
            }
        }

        // Count every backend, defaulting to 0. A never-selected backend has no map entry,
        // so taking the minimum over `backend_counts.values()` alone could never be zero and
        // the "never selected" assertion below would be vacuous.
        let counts: Vec<usize> = targets
            .iter()
            .map(|t| backend_counts.get(&t.full_address()).copied().unwrap_or(0))
            .collect();

        let expected = (num_proxies * subset_size) / targets.len();
        let min_count = counts.iter().copied().min().unwrap_or(0);
        let max_count = counts.iter().copied().max().unwrap_or(0);
        assert!(min_count > 0, "Some backends were never selected");
        assert!(
            max_count <= expected * 3,
            "Backend received too much traffic: {} (expected ~{})",
            max_count,
            expected
        );

        // Never-selected backends contribute (0 - mean)^2 to the variance as well.
        let mean = (num_proxies * subset_size) as f64 / targets.len() as f64;
        let variance: f64 = counts
            .iter()
            .map(|&c| (c as f64 - mean).powi(2))
            .sum::<f64>()
            / counts.len() as f64;
        let std_dev = variance.sqrt();
        assert!(
            std_dev < mean,
            "Distribution too uneven: std_dev={:.2}, mean={:.2}",
            std_dev,
            mean
        );
    }
}