use async_trait::async_trait;
use rand::seq::IndexedRandom;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, trace, warn};
use zentinel_common::errors::{ZentinelError, ZentinelResult};
use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
/// Configuration for [`LocalityAwareBalancer`].
#[derive(Debug, Clone)]
pub struct LocalityAwareConfig {
/// Zone this instance runs in; targets whose parsed zone equals this string are "local".
pub local_zone: String,
/// What `select` does when fewer than `min_local_healthy` local targets are healthy.
pub fallback_strategy: LocalityFallback,
/// Number of healthy local targets required before traffic is kept in the local zone only.
pub min_local_healthy: usize,
/// When true, round-robin selection is weighted by each target's `weight`.
pub use_weights: bool,
/// Preferred order of remote zones on the fallback path. Remote candidates are sorted so
/// listed zones come first (in list order) and unlisted zones last. Selection still
/// rotates (or picks randomly) across all candidates, so this orders them rather than
/// restricting them to the top zone.
pub zone_priority: Vec<String>,
}
impl Default for LocalityAwareConfig {
    /// Builds a configuration whose local zone is taken from the environment.
    ///
    /// The zone is the first of `ZENTINEL_ZONE`, `ZONE`, `REGION` that is set to a
    /// valid-Unicode, non-blank value (surrounding whitespace is trimmed). If none
    /// qualifies, the zone is `"default"`. Other fields: round-robin fallback,
    /// at least one healthy local target, weights enabled, no zone priority.
    fn default() -> Self {
        // Blank values are skipped. Otherwise `ZONE=""` (common when a deployment
        // template leaves the variable empty) would become the local zone, no target
        // would ever be local, and locality would be silently disabled.
        let local_zone = ["ZENTINEL_ZONE", "ZONE", "REGION"]
            .iter()
            .filter_map(|key| std::env::var(key).ok())
            .map(|value| value.trim().to_string())
            .find(|value| !value.is_empty())
            .unwrap_or_else(|| "default".to_string());

        Self {
            local_zone,
            fallback_strategy: LocalityFallback::RoundRobin,
            min_local_healthy: 1,
            use_weights: true,
            zone_priority: Vec::new(),
        }
    }
}
/// Behaviour of [`LocalityAwareBalancer`] when the local zone has too few healthy targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LocalityFallback {
/// Round-robin (weighted when `use_weights` is set) over healthy local and remote targets.
RoundRobin,
/// Pick uniformly at random among healthy local and remote targets.
Random,
/// Do not fall back: the request fails with `NoHealthyUpstream`.
FailLocal,
}
/// An upstream target paired with the zone parsed from its configured address.
#[derive(Debug, Clone)]
struct ZonedTarget {
/// The target with any zone annotation stripped from its address.
target: UpstreamTarget,
/// Parsed zone name, or `"unknown"` when the address carried no zone annotation.
zone: String,
}
/// Load balancer that keeps traffic inside the local zone while enough local targets are
/// healthy, and otherwise follows the configured [`LocalityFallback`].
pub struct LocalityAwareBalancer {
/// All configured targets, in configuration order.
targets: Vec<ZonedTarget>,
/// Last reported health per address, written by `report_health`. Selection treats an
/// address missing from this map as healthy.
health_status: Arc<RwLock<HashMap<String, bool>>>,
/// Round-robin cursor used when serving from the local zone.
local_counter: AtomicUsize,
/// Round-robin cursor used on the fallback path.
fallback_counter: AtomicUsize,
/// Balancer configuration.
config: LocalityAwareConfig,
}
impl LocalityAwareBalancer {
    /// Creates a balancer over `targets`, assigning each target to a zone.
    ///
    /// Each address is parsed with [`Self::parse_zone_from_target`]. The zone annotation is
    /// stripped, so the balancer hands out (and expects health reports for) plain
    /// addresses. Every target starts out healthy.
    pub fn new(targets: Vec<UpstreamTarget>, config: LocalityAwareConfig) -> Self {
        let zoned_targets: Vec<ZonedTarget> = targets
            .iter()
            .map(|raw| {
                let (zone, target) = Self::parse_zone_from_target(raw);
                ZonedTarget { target, zone }
            })
            .collect();

        // Key health by the zone-stripped address. Selection looks up that address,
        // returns it to callers, and callers report it back. Keying by the raw
        // `zone=...` address would create entries that never match any lookup.
        let health_status: HashMap<String, bool> = zoned_targets
            .iter()
            .map(|zoned| (zoned.target.full_address(), true))
            .collect();

        debug!(
            local_zone = %config.local_zone,
            total_targets = zoned_targets.len(),
            local_targets = zoned_targets.iter().filter(|t| t.zone == config.local_zone).count(),
            "Created locality-aware balancer"
        );

        Self {
            targets: zoned_targets,
            health_status: Arc::new(RwLock::new(health_status)),
            local_counter: AtomicUsize::new(0),
            fallback_counter: AtomicUsize::new(0),
            config,
        }
    }

    /// Splits an optional zone annotation off a target's address.
    ///
    /// Recognised encodings:
    /// * `zone=<zone>,<host>`, e.g. `zone=us-west-1,10.0.0.1`
    /// * `<zone>/<host>`, e.g. `us-east-1/10.0.0.1`. The part before `/` is only treated as
    ///   a zone when it contains neither `:` nor `.`, so a host is never mistaken for one.
    ///
    /// Returns the zone and a target rebuilt from the bare host with the original port and
    /// weight. Without a recognised annotation the zone is `"unknown"` and the target is
    /// cloned unchanged.
    fn parse_zone_from_target(target: &UpstreamTarget) -> (String, UpstreamTarget) {
        let addr = &target.address;
        let annotated = addr
            .strip_prefix("zone=")
            .and_then(|rest| rest.split_once(','))
            .or_else(|| {
                addr.split_once('/')
                    .filter(|(zone, _)| !zone.contains(':') && !zone.contains('.'))
            });

        match annotated {
            Some((zone, host)) => (
                zone.to_string(),
                UpstreamTarget::new(host, target.port, target.weight),
            ),
            None => ("unknown".to_string(), target.clone()),
        }
    }

    /// Whether `zoned` is healthy per `health`; targets never reported are assumed healthy.
    fn is_healthy(health: &HashMap<String, bool>, zoned: &ZonedTarget) -> bool {
        health
            .get(&zoned.target.full_address())
            .copied()
            .unwrap_or(true)
    }

    /// Healthy targets in `zone`, in configuration order.
    async fn healthy_in_zone(&self, zone: &str) -> Vec<&ZonedTarget> {
        let health = self.health_status.read().await;
        self.targets
            .iter()
            .filter(|t| t.zone == zone && Self::is_healthy(&health, t))
            .collect()
    }

    /// Healthy targets outside the local zone.
    ///
    /// When `zone_priority` is non-empty, the result is stably sorted so that listed zones
    /// come first (in list order) and unlisted zones last. Within a zone, configuration
    /// order is kept. This only orders the candidates; callers still choose among all of them.
    async fn healthy_fallback(&self) -> Vec<&ZonedTarget> {
        let health = self.health_status.read().await;
        let local_zone = self.config.local_zone.as_str();
        let mut fallback: Vec<&ZonedTarget> = self
            .targets
            .iter()
            .filter(|t| t.zone != local_zone && Self::is_healthy(&health, t))
            .collect();

        let priority = &self.config.zone_priority;
        if !priority.is_empty() {
            // `sort_by_key` is stable, matching the previous `sort_by` semantics.
            fallback.sort_by_key(|t| {
                priority
                    .iter()
                    .position(|z| z == &t.zone)
                    .unwrap_or(usize::MAX)
            });
        }
        fallback
    }

    /// Picks the next target from `targets`, advancing the shared cursor `counter`.
    ///
    /// With `use_weights`, the cursor walks a virtual sequence in which each target owns
    /// `weight` consecutive slots. A target therefore receives consecutive picks until its
    /// share is used up. Without weights, or when every weight is zero, targets are taken in
    /// plain rotation. Returns `None` only for an empty slice.
    fn select_round_robin<'a>(
        &self,
        targets: &[&'a ZonedTarget],
        counter: &AtomicUsize,
    ) -> Option<&'a ZonedTarget> {
        if targets.is_empty() {
            return None;
        }
        let cursor = counter.fetch_add(1, Ordering::Relaxed);

        // Sum in u64: a u32 sum can overflow (and panic in debug) with many heavy targets.
        let total_weight: u64 = if self.config.use_weights {
            targets.iter().map(|t| u64::from(t.target.weight)).sum()
        } else {
            0
        };

        if total_weight == 0 {
            // Unweighted mode, or all weights are zero: rotate evenly rather than sending
            // every request to the first target.
            return Some(targets[cursor % targets.len()]);
        }

        // `usize -> u64` is lossless on supported platforms. A `u32` cast would truncate
        // the cursor and skew the sequence once the counter passed `u32::MAX`.
        let mut slot = cursor as u64 % total_weight;
        for &candidate in targets {
            let weight = u64::from(candidate.target.weight);
            if slot < weight {
                return Some(candidate);
            }
            slot -= weight;
        }
        // Not reached: `slot < total_weight` always falls inside some target's share.
        targets.first().copied()
    }

    /// Picks a uniformly random target, or `None` for an empty slice.
    fn select_random<'a>(&self, targets: &[&'a ZonedTarget]) -> Option<&'a ZonedTarget> {
        // `choose` is provided by `rand::seq::IndexedRandom` (imported at file level).
        targets.choose(&mut rand::rng()).copied()
    }
}
#[async_trait]
impl LoadBalancer for LocalityAwareBalancer {
    /// Selects an upstream, preferring healthy targets in the local zone.
    ///
    /// If at least `min_local_healthy` local targets are healthy (and at least one is),
    /// a local target is chosen round-robin. Otherwise `fallback_strategy` applies:
    /// `FailLocal` errors, while `RoundRobin`/`Random` choose among the healthy local targets
    /// followed by healthy remote targets. The result's metadata carries `zone` and
    /// `locality` (`"local"` or `"remote"`).
    ///
    /// # Errors
    /// `ZentinelError::NoHealthyUpstream` when fallback is disabled and the local zone is
    /// below threshold, or when no healthy target exists anywhere.
    async fn select(&self, _context: Option<&RequestContext>) -> ZentinelResult<TargetSelection> {
        // Builds the selection result tagged with the chosen target's zone and locality.
        fn to_selection(chosen: &ZonedTarget, locality: &str) -> TargetSelection {
            let mut metadata = HashMap::new();
            metadata.insert("zone".to_string(), chosen.zone.clone());
            metadata.insert("locality".to_string(), locality.to_string());
            TargetSelection {
                address: chosen.target.full_address(),
                weight: chosen.target.weight,
                metadata,
            }
        }

        trace!(
            total_targets = self.targets.len(),
            local_zone = %self.config.local_zone,
            algorithm = "locality_aware",
            "Selecting upstream target"
        );

        let local_healthy = self.healthy_in_zone(&self.config.local_zone).await;

        // The non-empty check matters when `min_local_healthy` is 0. Without it, an empty
        // local zone passes `0 >= 0` and fails with `NoHealthyUpstream` even though
        // remote targets are healthy.
        if !local_healthy.is_empty() && local_healthy.len() >= self.config.min_local_healthy {
            let selected = self
                .select_round_robin(&local_healthy, &self.local_counter)
                .ok_or(ZentinelError::NoHealthyUpstream)?;
            trace!(
                selected_target = %selected.target.full_address(),
                zone = %selected.zone,
                local_healthy = local_healthy.len(),
                algorithm = "locality_aware",
                "Selected local target"
            );
            return Ok(to_selection(selected, "local"));
        }

        let use_random = match self.config.fallback_strategy {
            LocalityFallback::FailLocal => {
                warn!(
                    local_zone = %self.config.local_zone,
                    local_healthy = local_healthy.len(),
                    min_required = self.config.min_local_healthy,
                    algorithm = "locality_aware",
                    "No healthy local targets and fallback disabled"
                );
                return Err(ZentinelError::NoHealthyUpstream);
            }
            LocalityFallback::RoundRobin => false,
            LocalityFallback::Random => true,
        };

        // Healthy-but-insufficient local targets stay eligible and come first, followed
        // by remote targets in zone-priority order.
        let mut candidates = local_healthy;
        candidates.extend(self.healthy_fallback().await);

        if candidates.is_empty() {
            warn!(
                total_targets = self.targets.len(),
                algorithm = "locality_aware",
                "No healthy upstream targets available"
            );
            return Err(ZentinelError::NoHealthyUpstream);
        }

        let picked = if use_random {
            self.select_random(&candidates)
        } else {
            self.select_round_robin(&candidates, &self.fallback_counter)
        };
        let selected = picked.ok_or(ZentinelError::NoHealthyUpstream)?;

        let is_local = selected.zone == self.config.local_zone;
        debug!(
            selected_target = %selected.target.full_address(),
            zone = %selected.zone,
            is_local = is_local,
            fallback_used = !is_local,
            algorithm = "locality_aware",
            "Selected target (fallback path)"
        );

        Ok(to_selection(selected, if is_local { "local" } else { "remote" }))
    }

    /// Records the health of `address` (as returned in `TargetSelection::address`).
    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            target = %address,
            healthy = healthy,
            algorithm = "locality_aware",
            "Updating target health status"
        );
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    /// Addresses of this balancer's targets that are currently considered healthy.
    ///
    /// Targets with no health report count as healthy, matching selection. Results follow
    /// configuration order and contain no duplicates. Reports for addresses that are not
    /// configured targets are ignored.
    async fn healthy_targets(&self) -> Vec<String> {
        let health = self.health_status.read().await;
        let mut healthy: Vec<String> = Vec::with_capacity(self.targets.len());
        for zoned in &self.targets {
            let address = zoned.target.full_address();
            let is_up = health.get(&address).copied().unwrap_or(true);
            // Upstream lists are small, so a linear duplicate check is cheap here.
            if is_up && !healthy.contains(&address) {
                healthy.push(address);
            }
        }
        healthy
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Five targets spread over three zones, each encoded as `zone=<zone>,<host>`.
    fn make_zoned_targets() -> Vec<UpstreamTarget> {
        let addresses = [
            "zone=us-west-1,10.0.0.1",
            "zone=us-west-1,10.0.0.2",
            "zone=us-east-1,10.1.0.1",
            "zone=us-east-1,10.1.0.2",
            "zone=eu-west-1,10.2.0.1",
        ];
        addresses
            .iter()
            .map(|&addr| UpstreamTarget::new(addr, 8080, 100))
            .collect()
    }

    /// Default configuration with the local zone pinned to `us-west-1`.
    fn us_west_config() -> LocalityAwareConfig {
        LocalityAwareConfig {
            local_zone: "us-west-1".to_string(),
            ..Default::default()
        }
    }

    /// Reports both `us-west-1` targets as unhealthy.
    async fn take_down_local_zone(balancer: &LocalityAwareBalancer) {
        for address in ["10.0.0.1:8080", "10.0.0.2:8080"] {
            balancer.report_health(address, false).await;
        }
    }

    #[test]
    fn test_zone_parsing() {
        // (raw address, expected zone); every case strips down to the same host.
        let cases = [
            ("zone=us-west-1,10.0.0.1", "us-west-1"),
            ("us-east-1/10.0.0.1", "us-east-1"),
            ("10.0.0.1", "unknown"),
        ];
        for (raw, expected_zone) in cases {
            let input = UpstreamTarget::new(raw, 8080, 100);
            let (zone, parsed) = LocalityAwareBalancer::parse_zone_from_target(&input);
            assert_eq!(zone, expected_zone);
            assert_eq!(parsed.address, "10.0.0.1");
        }
    }

    #[tokio::test]
    async fn test_prefers_local_zone() {
        let balancer = LocalityAwareBalancer::new(make_zoned_targets(), us_west_config());
        for _ in 0..10 {
            let picked = balancer.select(None).await.unwrap();
            assert!(
                picked.address.starts_with("10.0.0."),
                "Expected local target, got {}",
                picked.address
            );
            assert_eq!(
                picked.metadata.get("locality").map(String::as_str),
                Some("local")
            );
        }
    }

    #[tokio::test]
    async fn test_fallback_when_local_unhealthy() {
        let config = LocalityAwareConfig {
            min_local_healthy: 1,
            ..us_west_config()
        };
        let balancer = LocalityAwareBalancer::new(make_zoned_targets(), config);
        take_down_local_zone(&balancer).await;

        let picked = balancer.select(None).await.unwrap();
        assert!(
            !picked.address.starts_with("10.0.0."),
            "Expected fallback target, got {}",
            picked.address
        );
        assert_eq!(
            picked.metadata.get("locality").map(String::as_str),
            Some("remote")
        );
    }

    #[tokio::test]
    async fn test_zone_priority() {
        let config = LocalityAwareConfig {
            min_local_healthy: 1,
            zone_priority: vec!["us-east-1".into(), "eu-west-1".into()],
            ..us_west_config()
        };
        let balancer = LocalityAwareBalancer::new(make_zoned_targets(), config);
        take_down_local_zone(&balancer).await;

        let picked = balancer.select(None).await.unwrap();
        assert!(
            picked.address.starts_with("10.1.0."),
            "Expected us-east-1 target, got {}",
            picked.address
        );
    }

    #[tokio::test]
    async fn test_fail_local_strategy() {
        let config = LocalityAwareConfig {
            fallback_strategy: LocalityFallback::FailLocal,
            ..us_west_config()
        };
        let balancer = LocalityAwareBalancer::new(make_zoned_targets(), config);
        take_down_local_zone(&balancer).await;

        assert!(balancer.select(None).await.is_err());
    }
}