use async_trait::async_trait;
use pingora::upstreams::peer::HttpPeer;
use rand::seq::IndexedRandom;
use std::collections::HashMap;
use std::net::ToSocketAddrs;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, error, info, trace, warn};
use zentinel_common::{
errors::{ZentinelError, ZentinelResult},
types::{CircuitBreakerConfig, LoadBalancingAlgorithm},
CircuitBreaker, UpstreamId,
};
use zentinel_config::UpstreamConfig;
/// A single backend endpoint (host + port) with a relative load-balancing weight.
#[derive(Debug, Clone)]
pub struct UpstreamTarget {
    /// Hostname or IP address (without the port).
    pub address: String,
    /// TCP port of the backend.
    pub port: u16,
    /// Relative weight for weighted algorithms; `from_address` defaults it to 100.
    pub weight: u32,
}
impl UpstreamTarget {
    /// Creates a target from explicit parts.
    pub fn new(address: impl Into<String>, port: u16, weight: u32) -> Self {
        Self {
            address: address.into(),
            port,
            weight,
        }
    }

    /// Parses a `"host:port"` string into a target with the default weight (100).
    ///
    /// Returns `None` when the string has no `:` separator or the port does not
    /// parse as a `u16`. The split happens on the *last* colon, so the host part
    /// may itself contain colons.
    pub fn from_address(addr: &str) -> Option<Self> {
        // `rsplit_once` keeps the exact same last-colon split semantics as the
        // previous `rsplitn(2, ':')` version, but without allocating an
        // intermediate Vec for the parts.
        let (address, port) = addr.rsplit_once(':')?;
        let port = port.parse().ok()?;
        Some(Self {
            address: address.to_string(),
            port,
            weight: 100,
        })
    }

    /// Builds a target from a config entry, overriding the default weight with
    /// the configured one. Returns `None` when the address fails to parse.
    pub fn from_config(config: &zentinel_config::UpstreamTarget) -> Option<Self> {
        Self::from_address(&config.address).map(|mut t| {
            t.weight = config.weight;
            t
        })
    }

    /// Renders the target back to `"host:port"` form — the key used for health
    /// maps and circuit breakers throughout this module.
    pub fn full_address(&self) -> String {
        format!("{}:{}", self.address, self.port)
    }
}
pub mod adaptive;
pub mod consistent_hash;
pub mod drain;
pub mod health;
pub mod inference_health;
pub mod least_tokens;
pub mod locality;
pub mod maglev;
pub mod p2c;
pub mod peak_ewma;
pub mod sticky_session;
pub mod subset;
pub mod weighted_least_conn;
pub use adaptive::{AdaptiveBalancer, AdaptiveConfig};
pub use consistent_hash::{ConsistentHashBalancer, ConsistentHashConfig};
pub use health::{ActiveHealthChecker, HealthCheckRunner};
pub use inference_health::InferenceHealthCheck;
pub use least_tokens::{
LeastTokensQueuedBalancer, LeastTokensQueuedConfig, LeastTokensQueuedTargetStats,
};
pub use locality::{LocalityAwareBalancer, LocalityAwareConfig};
pub use maglev::{MaglevBalancer, MaglevConfig};
pub use p2c::{P2cBalancer, P2cConfig};
pub use peak_ewma::{PeakEwmaBalancer, PeakEwmaConfig};
pub use sticky_session::{StickySessionBalancer, StickySessionRuntimeConfig};
pub use subset::{SubsetBalancer, SubsetConfig};
pub use weighted_least_conn::{WeightedLeastConnBalancer, WeightedLeastConnConfig};
/// Per-request information made available to load balancers that need it
/// (e.g. IP hashing or cookie-based sticky sessions).
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Peer address of the downstream client, when known.
    pub client_ip: Option<std::net::SocketAddr>,
    /// Request headers (name -> value).
    pub headers: HashMap<String, String>,
    /// Request path.
    pub path: String,
    /// HTTP method.
    pub method: String,
}
/// Strategy interface implemented by every load-balancing algorithm in this
/// module. Implementations are shared across tasks, hence `Send + Sync`.
#[async_trait]
pub trait LoadBalancer: Send + Sync {
    /// Picks a target for the next request; `context` is only consulted by
    /// context-sensitive algorithms (IP hash, sticky sessions, ...).
    async fn select(&self, context: Option<&RequestContext>) -> ZentinelResult<TargetSelection>;
    /// Marks a target (keyed by its `"host:port"` address) healthy or unhealthy.
    async fn report_health(&self, address: &str, healthy: bool);
    /// Returns the addresses currently considered healthy.
    async fn healthy_targets(&self) -> Vec<String>;
    /// Releases a previous selection (e.g. to decrement a connection count).
    /// Default: no-op.
    async fn release(&self, _selection: &TargetSelection) {
    }
    /// Reports the outcome of a request made against a previous selection.
    /// Default: no-op.
    async fn report_result(
        &self,
        _selection: &TargetSelection,
        _success: bool,
        _latency: Option<Duration>,
    ) {
    }
    /// Reports an outcome (with optional latency) keyed by address. The
    /// default implementation folds it into a plain health report and drops
    /// the latency; latency-aware balancers override this.
    async fn report_result_with_latency(
        &self,
        address: &str,
        success: bool,
        _latency: Option<Duration>,
    ) {
        self.report_health(address, success).await;
    }
}
/// Result of a load-balancer decision.
#[derive(Debug, Clone)]
pub struct TargetSelection {
    /// Chosen target in `"host:port"` form.
    pub address: String,
    /// Configured weight of the chosen target.
    pub weight: u32,
    /// Algorithm-specific extras surfaced to the caller (e.g. sticky-session
    /// cookie data); empty for the simple algorithms in this file.
    pub metadata: HashMap<String, String>,
}
/// A pool of targets behind one logical upstream: owns the load balancer,
/// per-target circuit breakers, the TLS/HTTP settings used to build Pingora
/// peers, and aggregate statistics.
pub struct UpstreamPool {
    /// Logical identifier of this upstream.
    id: UpstreamId,
    /// All configured targets (healthy or not).
    targets: Vec<UpstreamTarget>,
    /// Selection strategy built from the configured algorithm.
    load_balancer: Arc<dyn LoadBalancer>,
    /// Connection/timeout settings applied to every created peer.
    pool_config: ConnectionPoolConfig,
    /// Allowed HTTP version range and H2 tuning.
    http_version: HttpVersionOptions,
    /// Whether upstream connections use TLS.
    tls_enabled: bool,
    /// Optional SNI override; otherwise the target host is used.
    tls_sni: Option<String>,
    /// Full upstream TLS settings (verification, mTLS client cert, ...).
    tls_config: Option<zentinel_config::UpstreamTlsConfig>,
    /// One circuit breaker per target address.
    circuit_breakers: Arc<RwLock<HashMap<String, CircuitBreaker>>>,
    /// Aggregate request/failure counters.
    stats: Arc<PoolStats>,
}
/// Runtime connection-pool and timeout settings (converted from the
/// seconds-based config representation).
pub struct ConnectionPoolConfig {
    /// Upper bound on concurrent connections.
    pub max_connections: usize,
    /// Upper bound on idle pooled connections.
    pub max_idle: usize,
    /// How long an idle connection may be kept.
    pub idle_timeout: Duration,
    /// Optional cap on a connection's total lifetime.
    pub max_lifetime: Option<Duration>,
    /// TCP/TLS connect timeout.
    pub connection_timeout: Duration,
    /// Per-read timeout.
    pub read_timeout: Duration,
    /// Per-write timeout.
    pub write_timeout: Duration,
}
/// HTTP version negotiation range and HTTP/2 tuning for upstream peers.
pub struct HttpVersionOptions {
    /// Minimum HTTP major version to negotiate.
    pub min_version: u8,
    /// Maximum HTTP major version to negotiate.
    pub max_version: u8,
    /// H2 keepalive ping interval; `Duration::ZERO` disables pings.
    pub h2_ping_interval: Duration,
    /// Maximum concurrent H2 streams per connection.
    pub max_h2_streams: usize,
}
impl ConnectionPoolConfig {
    /// Translates the serialized pool/timeout config into runtime values,
    /// converting every second-based field into a `Duration`.
    pub fn from_config(
        pool_config: &zentinel_config::ConnectionPoolConfig,
        timeouts: &zentinel_config::UpstreamTimeouts,
    ) -> Self {
        // Alias the conversion once instead of repeating the full path.
        let secs = Duration::from_secs;
        Self {
            max_connections: pool_config.max_connections,
            max_idle: pool_config.max_idle,
            idle_timeout: secs(pool_config.idle_timeout_secs),
            max_lifetime: pool_config.max_lifetime_secs.map(secs),
            connection_timeout: secs(timeouts.connect_secs),
            read_timeout: secs(timeouts.read_secs),
            write_timeout: secs(timeouts.write_secs),
        }
    }
}
/// Aggregate counters for one upstream pool; all fields start at zero and
/// are updated with relaxed atomics (monitoring only, no synchronization).
#[derive(Default)]
pub struct PoolStats {
    /// Total peer-selection attempts.
    pub requests: AtomicU64,
    /// Selections that produced a usable peer.
    pub successes: AtomicU64,
    /// Failed selections plus reported connection failures.
    pub failures: AtomicU64,
    /// Retry attempts.
    pub retries: AtomicU64,
    /// Selections skipped because a target's circuit breaker was open.
    pub circuit_breaker_trips: AtomicU64,
    /// Requests currently in flight (see `increment_active`/`decrement_active`).
    pub active_requests: AtomicU64,
}
/// Destination for mirrored ("shadow") traffic, addressed by URL parts
/// rather than by a Pingora peer.
#[derive(Debug, Clone)]
pub struct ShadowTarget {
    /// `"http"` or `"https"`.
    pub scheme: String,
    /// Target hostname or IP.
    pub host: String,
    /// Target port.
    pub port: u16,
    /// Optional SNI to present when the shadow connection uses TLS.
    pub sni: Option<String>,
}

impl ShadowTarget {
    /// Joins scheme, host, port and `path` into a full URL, omitting the
    /// port when it is the scheme's default (http/80, https/443).
    pub fn build_url(&self, path: &str) -> String {
        let default_port = matches!(
            (self.scheme.as_str(), self.port),
            ("http", 80) | ("https", 443)
        );
        if default_port {
            format!("{}://{}{}", self.scheme, self.host, path)
        } else {
            format!("{}://{}:{}{}", self.scheme, self.host, self.port, path)
        }
    }
}
/// Plain-data snapshot of a pool's connection configuration, with all
/// durations flattened back to seconds (for admin/introspection output).
#[derive(Debug, Clone)]
pub struct PoolConfigSnapshot {
    pub max_connections: usize,
    pub max_idle: usize,
    pub idle_timeout_secs: u64,
    pub max_lifetime_secs: Option<u64>,
    pub connection_timeout_secs: u64,
    pub read_timeout_secs: u64,
    pub write_timeout_secs: u64,
}
/// Rotates through healthy targets in order using an atomic cursor.
struct RoundRobinBalancer {
    targets: Vec<UpstreamTarget>,
    current: AtomicUsize,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}

impl RoundRobinBalancer {
    /// Builds the balancer with every target initially marked healthy.
    fn new(targets: Vec<UpstreamTarget>) -> Self {
        let health_status: HashMap<String, bool> = targets
            .iter()
            .map(|t| (t.full_address(), true))
            .collect();
        Self {
            targets,
            current: AtomicUsize::new(0),
            health_status: Arc::new(RwLock::new(health_status)),
        }
    }
}
#[async_trait]
impl LoadBalancer for RoundRobinBalancer {
async fn select(&self, _context: Option<&RequestContext>) -> ZentinelResult<TargetSelection> {
trace!(
total_targets = self.targets.len(),
algorithm = "round_robin",
"Selecting upstream target"
);
let health = self.health_status.read().await;
let healthy_targets: Vec<_> = self
.targets
.iter()
.filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
.collect();
if healthy_targets.is_empty() {
warn!(
total_targets = self.targets.len(),
algorithm = "round_robin",
"No healthy upstream targets available"
);
return Err(ZentinelError::NoHealthyUpstream);
}
let index = self.current.fetch_add(1, Ordering::Relaxed) % healthy_targets.len();
let target = healthy_targets[index];
trace!(
selected_target = %target.full_address(),
healthy_count = healthy_targets.len(),
index = index,
algorithm = "round_robin",
"Selected target via round robin"
);
Ok(TargetSelection {
address: target.full_address(),
weight: target.weight,
metadata: HashMap::new(),
})
}
async fn report_health(&self, address: &str, healthy: bool) {
trace!(
target = %address,
healthy = healthy,
algorithm = "round_robin",
"Updating target health status"
);
self.health_status
.write()
.await
.insert(address.to_string(), healthy);
}
async fn healthy_targets(&self) -> Vec<String> {
self.health_status
.read()
.await
.iter()
.filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
.collect()
}
}
/// Picks a uniformly random healthy target on every call.
struct RandomBalancer {
    targets: Vec<UpstreamTarget>,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}

impl RandomBalancer {
    /// Builds the balancer with every target initially marked healthy.
    fn new(targets: Vec<UpstreamTarget>) -> Self {
        let health_status: HashMap<String, bool> = targets
            .iter()
            .map(|t| (t.full_address(), true))
            .collect();
        Self {
            targets,
            health_status: Arc::new(RwLock::new(health_status)),
        }
    }
}
#[async_trait]
impl LoadBalancer for RandomBalancer {
    /// Picks a uniformly random healthy target; the request context is
    /// ignored by this algorithm.
    ///
    /// `choose` is supplied by `rand::seq::IndexedRandom`, which is imported
    /// at the top of the file (rand 0.9 API, matching `rand::rng()` below).
    /// The previous function-local `use rand::seq::SliceRandom;` was unused
    /// and only produced an unused-import warning, so it was removed.
    async fn select(&self, _context: Option<&RequestContext>) -> ZentinelResult<TargetSelection> {
        trace!(
            total_targets = self.targets.len(),
            algorithm = "random",
            "Selecting upstream target"
        );
        let health = self.health_status.read().await;
        // Unknown addresses default to healthy.
        let healthy_targets: Vec<_> = self
            .targets
            .iter()
            .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
            .collect();
        if healthy_targets.is_empty() {
            warn!(
                total_targets = self.targets.len(),
                algorithm = "random",
                "No healthy upstream targets available"
            );
            return Err(ZentinelError::NoHealthyUpstream);
        }
        let mut rng = rand::rng();
        // `choose` returns None only for an empty slice, ruled out above;
        // the `ok_or` is purely defensive.
        let target = healthy_targets
            .choose(&mut rng)
            .ok_or(ZentinelError::NoHealthyUpstream)?;
        trace!(
            selected_target = %target.full_address(),
            healthy_count = healthy_targets.len(),
            algorithm = "random",
            "Selected target via random selection"
        );
        Ok(TargetSelection {
            address: target.full_address(),
            weight: target.weight,
            metadata: HashMap::new(),
        })
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            target = %address,
            healthy = healthy,
            algorithm = "random",
            "Updating target health status"
        );
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}
/// Tracks a per-target connection count and selects the least-loaded
/// healthy target.
struct LeastConnectionsBalancer {
    targets: Vec<UpstreamTarget>,
    connections: Arc<RwLock<HashMap<String, usize>>>,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}

impl LeastConnectionsBalancer {
    /// Builds the balancer with every target healthy and zero tracked
    /// connections.
    fn new(targets: Vec<UpstreamTarget>) -> Self {
        let health_status: HashMap<String, bool> = targets
            .iter()
            .map(|t| (t.full_address(), true))
            .collect();
        let connections: HashMap<String, usize> = targets
            .iter()
            .map(|t| (t.full_address(), 0))
            .collect();
        Self {
            targets,
            connections: Arc::new(RwLock::new(connections)),
            health_status: Arc::new(RwLock::new(health_status)),
        }
    }
}
#[async_trait]
impl LoadBalancer for LeastConnectionsBalancer {
    /// Linear scan over all targets: skips unhealthy ones and keeps the
    /// target with the strictly smallest tracked connection count, so the
    /// first of several tied targets wins (stable in target order).
    ///
    /// NOTE(review): nothing in this file increments the `connections` map
    /// (no `release`/`report_result` override on this type), so counts stay
    /// at their initial 0 unless something outside this file updates them —
    /// confirm against callers.
    async fn select(&self, _context: Option<&RequestContext>) -> ZentinelResult<TargetSelection> {
        trace!(
            total_targets = self.targets.len(),
            algorithm = "least_connections",
            "Selecting upstream target"
        );
        // Both read locks are held together so health and counts are seen
        // as one consistent view during the scan.
        let health = self.health_status.read().await;
        let conns = self.connections.read().await;
        let mut best_target = None;
        let mut min_connections = usize::MAX;
        for target in &self.targets {
            let addr = target.full_address();
            // Missing health entries default to healthy.
            if !*health.get(&addr).unwrap_or(&true) {
                trace!(
                    target = %addr,
                    algorithm = "least_connections",
                    "Skipping unhealthy target"
                );
                continue;
            }
            // Missing connection entries default to zero.
            let conn_count = *conns.get(&addr).unwrap_or(&0);
            trace!(
                target = %addr,
                connections = conn_count,
                "Evaluating target connection count"
            );
            if conn_count < min_connections {
                min_connections = conn_count;
                best_target = Some(target);
            }
        }
        match best_target {
            Some(target) => {
                trace!(
                    selected_target = %target.full_address(),
                    connections = min_connections,
                    algorithm = "least_connections",
                    "Selected target with fewest connections"
                );
                Ok(TargetSelection {
                    address: target.full_address(),
                    weight: target.weight,
                    metadata: HashMap::new(),
                })
            }
            None => {
                warn!(
                    total_targets = self.targets.len(),
                    algorithm = "least_connections",
                    "No healthy upstream targets available"
                );
                Err(ZentinelError::NoHealthyUpstream)
            }
        }
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            target = %address,
            healthy = healthy,
            algorithm = "least_connections",
            "Updating target health status"
        );
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}
/// Weighted round robin: targets are selected in proportion to their weight.
/// Unlike the other balancers it has no constructor; `create_load_balancer*`
/// builds it inline with an initially empty health map (absent entries are
/// treated as healthy on lookup).
struct WeightedBalancer {
    targets: Vec<UpstreamTarget>,
    /// Per-target weights, index-aligned with `targets`.
    weights: Vec<u32>,
    /// Monotonic counter; `counter % total_weight` picks a weighted slot.
    current_index: AtomicUsize,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
#[async_trait]
impl LoadBalancer for WeightedBalancer {
    /// Weighted round robin over the currently-healthy targets.
    ///
    /// A monotonically increasing counter, reduced modulo the total healthy
    /// weight, yields a "slot"; walking the healthy targets' cumulative
    /// weights finds the slot's owner. Over any window of `total_weight`
    /// consecutive calls (with a stable healthy set) each target is chosen
    /// in proportion to its weight.
    async fn select(&self, _context: Option<&RequestContext>) -> ZentinelResult<TargetSelection> {
        trace!(
            total_targets = self.targets.len(),
            algorithm = "weighted",
            "Selecting upstream target"
        );
        let health = self.health_status.read().await;
        // Indices (not references) of the healthy targets, so weights can be
        // looked up in the parallel `weights` vec.
        let healthy: Vec<_> = self
            .targets
            .iter()
            .enumerate()
            .filter(|(_, t)| *health.get(&t.full_address()).unwrap_or(&true))
            .map(|(i, _)| i)
            .collect();
        if healthy.is_empty() {
            warn!(
                total_targets = self.targets.len(),
                algorithm = "weighted",
                "No healthy upstream targets available"
            );
            return Err(ZentinelError::NoHealthyUpstream);
        }
        // Missing weight entries default to 1.
        let total_weight: u32 = healthy
            .iter()
            .map(|&i| self.weights.get(i).copied().unwrap_or(1))
            .sum();
        // All-zero weights would make the modulo below divide by zero; treat
        // it as "nothing selectable".
        if total_weight == 0 {
            return Err(ZentinelError::NoHealthyUpstream);
        }
        let slot = (self.current_index.fetch_add(1, Ordering::Relaxed) as u32) % total_weight;
        // Walk cumulative weights until the slot falls inside a target's range.
        let mut cumulative = 0u32;
        let mut target_idx = healthy[0];
        for &i in &healthy {
            let w = self.weights.get(i).copied().unwrap_or(1);
            cumulative += w;
            if slot < cumulative {
                target_idx = i;
                break;
            }
        }
        let target = &self.targets[target_idx];
        let weight = self.weights.get(target_idx).copied().unwrap_or(1);
        trace!(
            selected_target = %target.full_address(),
            weight = weight,
            healthy_count = healthy.len(),
            algorithm = "weighted",
            "Selected target via weighted round robin"
        );
        Ok(TargetSelection {
            address: target.full_address(),
            weight,
            metadata: HashMap::new(),
        })
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            target = %address,
            healthy = healthy,
            algorithm = "weighted",
            "Updating target health status"
        );
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}
/// Hashes the client IP to a stable target, giving per-client affinity.
/// Built inline by `create_load_balancer*` with an initially empty health
/// map (absent entries are treated as healthy on lookup).
struct IpHashBalancer {
    targets: Vec<UpstreamTarget>,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
#[async_trait]
impl LoadBalancer for IpHashBalancer {
    /// Hashes the client IP (from the request context) to pick a stable
    /// target for that client. Requests without a known client IP all map
    /// to the first healthy target (hash 0).
    async fn select(&self, context: Option<&RequestContext>) -> ZentinelResult<TargetSelection> {
        trace!(
            total_targets = self.targets.len(),
            algorithm = "ip_hash",
            "Selecting upstream target"
        );
        let status = self.health_status.read().await;
        // Unknown addresses default to healthy.
        let healthy_targets: Vec<_> = self
            .targets
            .iter()
            .filter(|t| status.get(&t.full_address()).copied().unwrap_or(true))
            .collect();
        if healthy_targets.is_empty() {
            warn!(
                total_targets = self.targets.len(),
                algorithm = "ip_hash",
                "No healthy upstream targets available"
            );
            return Err(ZentinelError::NoHealthyUpstream);
        }
        let client_ip = context.and_then(|ctx| ctx.client_ip.as_ref());
        let (hash, client_ip_str) = match client_ip {
            Some(ip) => {
                use std::hash::{Hash, Hasher};
                let mut hasher = std::collections::hash_map::DefaultHasher::new();
                ip.hash(&mut hasher);
                (hasher.finish(), Some(ip.to_string()))
            }
            None => (0, None),
        };
        let idx = (hash as usize) % healthy_targets.len();
        let target = healthy_targets[idx];
        trace!(
            selected_target = %target.full_address(),
            client_ip = client_ip_str.as_deref().unwrap_or("unknown"),
            hash = hash,
            index = idx,
            healthy_count = healthy_targets.len(),
            algorithm = "ip_hash",
            "Selected target via IP hash"
        );
        Ok(TargetSelection {
            address: target.full_address(),
            weight: target.weight,
            metadata: HashMap::new(),
        })
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            target = %address,
            healthy = healthy,
            algorithm = "ip_hash",
            "Updating target health status"
        );
        let mut status = self.health_status.write().await;
        status.insert(address.to_string(), healthy);
    }

    async fn healthy_targets(&self) -> Vec<String> {
        let status = self.health_status.read().await;
        status
            .iter()
            .filter(|&(_, &healthy)| healthy)
            .map(|(addr, _)| addr.clone())
            .collect()
    }
}
impl UpstreamPool {
/// Builds an [`UpstreamPool`] from its configuration.
///
/// Parses the target list, constructs the configured load balancer,
/// derives connection-pool/timeout/HTTP-version/TLS settings, and creates
/// one circuit breaker (default thresholds) per target.
///
/// # Errors
/// Returns a config error when no target address parses successfully, or
/// when load-balancer construction fails (e.g. Sticky without config).
pub async fn new(config: UpstreamConfig) -> ZentinelResult<Self> {
    let id = UpstreamId::new(&config.id);
    info!(
        upstream_id = %config.id,
        target_count = config.targets.len(),
        algorithm = ?config.load_balancing,
        "Creating upstream pool"
    );
    // Unparseable target entries are silently dropped here; the emptiness
    // check below is the only guard against a fully-invalid target list.
    let targets: Vec<UpstreamTarget> = config
        .targets
        .iter()
        .filter_map(UpstreamTarget::from_config)
        .collect();
    if targets.is_empty() {
        error!(
            upstream_id = %config.id,
            "No valid upstream targets configured"
        );
        return Err(ZentinelError::Config {
            message: "No valid upstream targets".to_string(),
            source: None,
        });
    }
    for target in &targets {
        debug!(
            upstream_id = %config.id,
            target = %target.full_address(),
            weight = target.weight,
            "Registered upstream target"
        );
    }
    debug!(
        upstream_id = %config.id,
        algorithm = ?config.load_balancing,
        "Creating load balancer"
    );
    let load_balancer = Self::create_load_balancer(&config.load_balancing, &targets, &config)?;
    debug!(
        upstream_id = %config.id,
        max_connections = config.connection_pool.max_connections,
        max_idle = config.connection_pool.max_idle,
        idle_timeout_secs = config.connection_pool.idle_timeout_secs,
        connect_timeout_secs = config.timeouts.connect_secs,
        read_timeout_secs = config.timeouts.read_secs,
        write_timeout_secs = config.timeouts.write_secs,
        "Creating connection pool configuration"
    );
    let pool_config =
        ConnectionPoolConfig::from_config(&config.connection_pool, &config.timeouts);
    let http_version = HttpVersionOptions {
        min_version: config.http_version.min_version,
        max_version: config.http_version.max_version,
        // A configured interval of 0 means "H2 pings disabled", encoded
        // internally as Duration::ZERO (see create_peer).
        h2_ping_interval: if config.http_version.h2_ping_interval_secs > 0 {
            Duration::from_secs(config.http_version.h2_ping_interval_secs)
        } else {
            Duration::ZERO
        },
        max_h2_streams: config.http_version.max_h2_streams,
    };
    // TLS is enabled by the mere presence of the `tls` config section.
    let tls_enabled = config.tls.is_some();
    let tls_sni = config.tls.as_ref().and_then(|t| t.sni.clone());
    let tls_config = config.tls.clone();
    if let Some(ref tls) = tls_config {
        if tls.client_cert.is_some() {
            info!(
                upstream_id = %config.id,
                "mTLS enabled for upstream (client certificate configured)"
            );
        }
    }
    if http_version.max_version >= 2 && tls_enabled {
        info!(
            upstream_id = %config.id,
            "HTTP/2 enabled for upstream (via ALPN)"
        );
    }
    // One breaker per target so a single bad backend cannot trip the others.
    let mut circuit_breakers = HashMap::new();
    for target in &targets {
        trace!(
            upstream_id = %config.id,
            target = %target.full_address(),
            "Initializing circuit breaker for target"
        );
        circuit_breakers.insert(
            target.full_address(),
            CircuitBreaker::new(CircuitBreakerConfig::default()),
        );
    }
    let pool = Self {
        id: id.clone(),
        targets,
        load_balancer,
        pool_config,
        http_version,
        tls_enabled,
        tls_sni,
        tls_config,
        circuit_breakers: Arc::new(RwLock::new(circuit_breakers)),
        stats: Arc::new(PoolStats::default()),
    };
    info!(
        upstream_id = %id,
        target_count = pool.targets.len(),
        "Upstream pool created successfully"
    );
    Ok(pool)
}
fn create_load_balancer(
algorithm: &LoadBalancingAlgorithm,
targets: &[UpstreamTarget],
config: &UpstreamConfig,
) -> ZentinelResult<Arc<dyn LoadBalancer>> {
let balancer: Arc<dyn LoadBalancer> = match algorithm {
LoadBalancingAlgorithm::RoundRobin => {
Arc::new(RoundRobinBalancer::new(targets.to_vec()))
}
LoadBalancingAlgorithm::LeastConnections => {
Arc::new(LeastConnectionsBalancer::new(targets.to_vec()))
}
LoadBalancingAlgorithm::Weighted => {
let weights: Vec<u32> = targets.iter().map(|t| t.weight).collect();
Arc::new(WeightedBalancer {
targets: targets.to_vec(),
weights,
current_index: AtomicUsize::new(0),
health_status: Arc::new(RwLock::new(HashMap::new())),
})
}
LoadBalancingAlgorithm::IpHash => Arc::new(IpHashBalancer {
targets: targets.to_vec(),
health_status: Arc::new(RwLock::new(HashMap::new())),
}),
LoadBalancingAlgorithm::Random => Arc::new(RandomBalancer::new(targets.to_vec())),
LoadBalancingAlgorithm::ConsistentHash => Arc::new(ConsistentHashBalancer::new(
targets.to_vec(),
ConsistentHashConfig::default(),
)),
LoadBalancingAlgorithm::PowerOfTwoChoices => {
Arc::new(P2cBalancer::new(targets.to_vec(), P2cConfig::default()))
}
LoadBalancingAlgorithm::Adaptive => Arc::new(AdaptiveBalancer::new(
targets.to_vec(),
AdaptiveConfig::default(),
)),
LoadBalancingAlgorithm::LeastTokensQueued => Arc::new(LeastTokensQueuedBalancer::new(
targets.to_vec(),
LeastTokensQueuedConfig::default(),
)),
LoadBalancingAlgorithm::Maglev => Arc::new(MaglevBalancer::new(
targets.to_vec(),
MaglevConfig::default(),
)),
LoadBalancingAlgorithm::LocalityAware => Arc::new(LocalityAwareBalancer::new(
targets.to_vec(),
LocalityAwareConfig::default(),
)),
LoadBalancingAlgorithm::PeakEwma => Arc::new(PeakEwmaBalancer::new(
targets.to_vec(),
PeakEwmaConfig::default(),
)),
LoadBalancingAlgorithm::DeterministicSubset => Arc::new(SubsetBalancer::new(
targets.to_vec(),
SubsetConfig::default(),
)),
LoadBalancingAlgorithm::WeightedLeastConnections => {
Arc::new(WeightedLeastConnBalancer::new(
targets.to_vec(),
WeightedLeastConnConfig::default(),
))
}
LoadBalancingAlgorithm::Sticky => {
let sticky_config = config.sticky_session.as_ref().ok_or_else(|| {
ZentinelError::Config {
message: format!(
"Upstream '{}' uses Sticky algorithm but no sticky_session config provided",
config.id
),
source: None,
}
})?;
let runtime_config = StickySessionRuntimeConfig::from_config(sticky_config);
let fallback = Self::create_load_balancer_inner(&sticky_config.fallback, targets)?;
info!(
upstream_id = %config.id,
cookie_name = %runtime_config.cookie_name,
cookie_ttl_secs = runtime_config.cookie_ttl_secs,
fallback_algorithm = ?sticky_config.fallback,
"Creating sticky session balancer"
);
Arc::new(StickySessionBalancer::new(
targets.to_vec(),
runtime_config,
fallback,
))
}
};
Ok(balancer)
}
/// Constructs a balancer for any non-sticky algorithm, using each
/// algorithm's default settings. This is used both for an upstream's main
/// balancer and as the factory for a sticky session's fallback balancer;
/// `Sticky` itself is rejected here so a fallback can never recurse.
fn create_load_balancer_inner(
    algorithm: &LoadBalancingAlgorithm,
    targets: &[UpstreamTarget],
) -> ZentinelResult<Arc<dyn LoadBalancer>> {
    let balancer: Arc<dyn LoadBalancer> = match algorithm {
        LoadBalancingAlgorithm::RoundRobin => {
            Arc::new(RoundRobinBalancer::new(targets.to_vec()))
        }
        LoadBalancingAlgorithm::LeastConnections => {
            Arc::new(LeastConnectionsBalancer::new(targets.to_vec()))
        }
        LoadBalancingAlgorithm::Weighted => {
            let weights: Vec<u32> = targets.iter().map(|t| t.weight).collect();
            // Built inline (no constructor); the empty health map means all
            // targets start healthy because lookups default to healthy.
            Arc::new(WeightedBalancer {
                targets: targets.to_vec(),
                weights,
                current_index: AtomicUsize::new(0),
                health_status: Arc::new(RwLock::new(HashMap::new())),
            })
        }
        LoadBalancingAlgorithm::IpHash => Arc::new(IpHashBalancer {
            targets: targets.to_vec(),
            health_status: Arc::new(RwLock::new(HashMap::new())),
        }),
        LoadBalancingAlgorithm::Random => Arc::new(RandomBalancer::new(targets.to_vec())),
        LoadBalancingAlgorithm::ConsistentHash => Arc::new(ConsistentHashBalancer::new(
            targets.to_vec(),
            ConsistentHashConfig::default(),
        )),
        LoadBalancingAlgorithm::PowerOfTwoChoices => {
            Arc::new(P2cBalancer::new(targets.to_vec(), P2cConfig::default()))
        }
        LoadBalancingAlgorithm::Adaptive => Arc::new(AdaptiveBalancer::new(
            targets.to_vec(),
            AdaptiveConfig::default(),
        )),
        LoadBalancingAlgorithm::LeastTokensQueued => Arc::new(LeastTokensQueuedBalancer::new(
            targets.to_vec(),
            LeastTokensQueuedConfig::default(),
        )),
        LoadBalancingAlgorithm::Maglev => Arc::new(MaglevBalancer::new(
            targets.to_vec(),
            MaglevConfig::default(),
        )),
        LoadBalancingAlgorithm::LocalityAware => Arc::new(LocalityAwareBalancer::new(
            targets.to_vec(),
            LocalityAwareConfig::default(),
        )),
        LoadBalancingAlgorithm::PeakEwma => Arc::new(PeakEwmaBalancer::new(
            targets.to_vec(),
            PeakEwmaConfig::default(),
        )),
        LoadBalancingAlgorithm::DeterministicSubset => Arc::new(SubsetBalancer::new(
            targets.to_vec(),
            SubsetConfig::default(),
        )),
        LoadBalancingAlgorithm::WeightedLeastConnections => {
            Arc::new(WeightedLeastConnBalancer::new(
                targets.to_vec(),
                WeightedLeastConnConfig::default(),
            ))
        }
        LoadBalancingAlgorithm::Sticky => {
            // Rejected here: a sticky fallback must be a concrete algorithm,
            // otherwise fallback construction would recurse forever.
            return Err(ZentinelError::Config {
                message: "Sticky algorithm cannot be used as fallback for sticky sessions"
                    .to_string(),
                source: None,
            });
        }
    };
    Ok(balancer)
}
/// Selects an upstream peer, skipping targets whose circuit breaker is
/// open, and returns the Pingora peer together with the load balancer's
/// selection metadata.
///
/// Selection is retried up to `2 * targets.len()` times to give the
/// balancer a chance to rotate past breaker-open targets.
///
/// NOTE(review): `stats.successes` is incremented on successful *selection*,
/// before any request is actually sent — confirm that is the intended
/// meaning of the counter.
pub async fn select_peer_with_metadata(
    &self,
    context: Option<&RequestContext>,
) -> ZentinelResult<(HttpPeer, HashMap<String, String>)> {
    let request_num = self.stats.requests.fetch_add(1, Ordering::Relaxed) + 1;
    trace!(
        upstream_id = %self.id,
        request_num = request_num,
        target_count = self.targets.len(),
        "Starting peer selection with metadata"
    );
    let mut attempts = 0;
    let max_attempts = self.targets.len() * 2;
    while attempts < max_attempts {
        attempts += 1;
        trace!(
            upstream_id = %self.id,
            attempt = attempts,
            max_attempts = max_attempts,
            "Attempting to select peer"
        );
        // NOTE(review): a select error (e.g. NoHealthyUpstream) still burns
        // a retry and loops again; if the error is deterministic this spins
        // through all remaining attempts — confirm that is acceptable.
        let selection = match self.load_balancer.select(context).await {
            Ok(s) => s,
            Err(e) => {
                warn!(
                    upstream_id = %self.id,
                    attempt = attempts,
                    error = %e,
                    "Load balancer selection failed"
                );
                continue;
            }
        };
        trace!(
            upstream_id = %self.id,
            target = %selection.address,
            attempt = attempts,
            "Load balancer selected target"
        );
        // Breaker-open targets are skipped, not failed: the next attempt may
        // land on a different target.
        let breakers = self.circuit_breakers.read().await;
        if let Some(breaker) = breakers.get(&selection.address) {
            if !breaker.is_closed() {
                debug!(
                    upstream_id = %self.id,
                    target = %selection.address,
                    attempt = attempts,
                    "Circuit breaker is open, skipping target"
                );
                self.stats
                    .circuit_breaker_trips
                    .fetch_add(1, Ordering::Relaxed);
                continue;
            }
        }
        trace!(
            upstream_id = %self.id,
            target = %selection.address,
            "Creating peer for upstream (Pingora handles connection reuse)"
        );
        let peer = self.create_peer(&selection)?;
        debug!(
            upstream_id = %self.id,
            target = %selection.address,
            attempt = attempts,
            metadata_keys = ?selection.metadata.keys().collect::<Vec<_>>(),
            "Selected upstream peer with metadata"
        );
        self.stats.successes.fetch_add(1, Ordering::Relaxed);
        return Ok((peer, selection.metadata));
    }
    self.stats.failures.fetch_add(1, Ordering::Relaxed);
    error!(
        upstream_id = %self.id,
        attempts = attempts,
        max_attempts = max_attempts,
        "Failed to select upstream after max attempts"
    );
    Err(ZentinelError::upstream(
        self.id.to_string(),
        "Failed to select upstream after max attempts",
    ))
}
/// Convenience wrapper around [`Self::select_peer_with_metadata`] that
/// discards the selection metadata.
pub async fn select_peer(&self, context: Option<&RequestContext>) -> ZentinelResult<HttpPeer> {
    let (peer, _metadata) = self.select_peer_with_metadata(context).await?;
    Ok(peer)
}
/// Builds a Pingora [`HttpPeer`] for the selected target, applying the
/// pool's timeouts, TCP keepalive, TLS/ALPN negotiation and (optionally)
/// an mTLS client certificate.
///
/// # Errors
/// Fails when the target address does not resolve via DNS, or when the
/// configured mTLS client certificate cannot be loaded.
fn create_peer(&self, selection: &TargetSelection) -> ZentinelResult<HttpPeer> {
    // SNI: explicit config wins; otherwise use the host part of the address.
    // NOTE(review): `split(':').next()` takes everything before the FIRST
    // colon, which is wrong for raw IPv6 "host:port" addresses — confirm
    // whether IPv6 targets are expected here.
    let sni_hostname = self.tls_sni.clone().unwrap_or_else(|| {
        selection
            .address
            .split(':')
            .next()
            .unwrap_or(&selection.address)
            .to_string()
    });
    // Synchronous DNS resolution; the first returned address is used.
    let resolved_address = selection
        .address
        .to_socket_addrs()
        .map_err(|e| {
            error!(
                upstream = %self.id,
                address = %selection.address,
                error = %e,
                "Failed to resolve upstream address"
            );
            ZentinelError::Upstream {
                upstream: self.id.to_string(),
                message: format!("DNS resolution failed for {}: {}", selection.address, e),
                retryable: true,
                source: None,
            }
        })?
        .next()
        .ok_or_else(|| {
            error!(
                upstream = %self.id,
                address = %selection.address,
                "No addresses returned from DNS resolution"
            );
            ZentinelError::Upstream {
                upstream: self.id.to_string(),
                message: format!("No addresses for {}", selection.address),
                retryable: true,
                source: None,
            }
        })?;
    let mut peer = HttpPeer::new(resolved_address, self.tls_enabled, sni_hostname.clone());
    peer.options.idle_timeout = Some(self.pool_config.idle_timeout);
    peer.options.connection_timeout = Some(self.pool_config.connection_timeout);
    // NOTE(review): total connection timeout is hard-coded to 10s rather
    // than taken from config — confirm intentional.
    peer.options.total_connection_timeout = Some(Duration::from_secs(10));
    peer.options.read_timeout = Some(self.pool_config.read_timeout);
    peer.options.write_timeout = Some(self.pool_config.write_timeout);
    peer.options.tcp_keepalive = Some(pingora::protocols::TcpKeepalive {
        idle: Duration::from_secs(60),
        interval: Duration::from_secs(10),
        count: 3,
        #[cfg(target_os = "linux")]
        user_timeout: Duration::from_secs(60),
    });
    if self.tls_enabled {
        // Map the configured version range onto an ALPN preference.
        // NOTE(review): the `(1, 2)` arm is subsumed by `(_, 2)` and is
        // therefore redundant; the result is the same either way.
        let alpn = match (self.http_version.min_version, self.http_version.max_version) {
            (2, _) => {
                // H2 only.
                pingora::upstreams::peer::ALPN::H2
            }
            (1, 2) | (_, 2) => {
                // Prefer H2, fall back to H1.
                pingora::upstreams::peer::ALPN::H2H1
            }
            _ => {
                // H1 only.
                pingora::upstreams::peer::ALPN::H1
            }
        };
        peer.options.alpn = alpn;
        if let Some(ref tls_config) = self.tls_config {
            if tls_config.insecure_skip_verify {
                peer.options.verify_cert = false;
                peer.options.verify_hostname = false;
                warn!(
                    upstream_id = %self.id,
                    target = %selection.address,
                    "TLS certificate verification DISABLED (insecure_skip_verify=true)"
                );
            }
            // The configured SNI also doubles as an accepted certificate CN.
            if let Some(ref sni) = tls_config.sni {
                peer.options.alternative_cn = Some(sni.clone());
                trace!(
                    upstream_id = %self.id,
                    target = %selection.address,
                    alternative_cn = %sni,
                    "Set alternative CN for TLS verification"
                );
            }
            // mTLS: only configured when BOTH cert and key paths are present.
            if let (Some(cert_path), Some(key_path)) =
                (&tls_config.client_cert, &tls_config.client_key)
            {
                match crate::tls::load_client_cert_key(cert_path, key_path) {
                    Ok(cert_key) => {
                        peer.client_cert_key = Some(cert_key);
                        info!(
                            upstream_id = %self.id,
                            target = %selection.address,
                            cert_path = ?cert_path,
                            "mTLS client certificate configured"
                        );
                    }
                    Err(e) => {
                        error!(
                            upstream_id = %self.id,
                            target = %selection.address,
                            error = %e,
                            "Failed to load mTLS client certificate"
                        );
                        return Err(ZentinelError::Tls {
                            message: format!("Failed to load client certificate: {}", e),
                            source: None,
                        });
                    }
                }
            }
        }
        trace!(
            upstream_id = %self.id,
            target = %selection.address,
            alpn = ?peer.options.alpn,
            min_version = self.http_version.min_version,
            max_version = self.http_version.max_version,
            verify_cert = peer.options.verify_cert,
            verify_hostname = peer.options.verify_hostname,
            "Configured ALPN and TLS options for HTTP version negotiation"
        );
    }
    if self.http_version.max_version >= 2 {
        // Duration::ZERO encodes "pings disabled" (see UpstreamPool::new).
        if !self.http_version.h2_ping_interval.is_zero() {
            peer.options.h2_ping_interval = Some(self.http_version.h2_ping_interval);
            trace!(
                upstream_id = %self.id,
                target = %selection.address,
                h2_ping_interval_secs = self.http_version.h2_ping_interval.as_secs(),
                "Configured H2 ping interval"
            );
        }
    }
    trace!(
        upstream_id = %self.id,
        target = %selection.address,
        tls = self.tls_enabled,
        sni = %sni_hostname,
        idle_timeout_secs = self.pool_config.idle_timeout.as_secs(),
        http_max_version = self.http_version.max_version,
        "Created peer with Pingora connection pooling enabled"
    );
    Ok(peer)
}
/// Feeds a request outcome into the target's circuit breaker and, on
/// success, marks the target healthy again. On failure the target is
/// marked unhealthy only when the breaker actually opens, so isolated
/// failures do not eject a target from rotation.
pub async fn report_result(&self, target: &str, success: bool) {
    trace!(
        upstream_id = %self.id,
        target = %target,
        success = success,
        "Reporting connection result"
    );
    if success {
        if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
            breaker.record_success();
            trace!(
                upstream_id = %self.id,
                target = %target,
                "Recorded success in circuit breaker"
            );
        }
        self.load_balancer.report_health(target, true).await;
    } else {
        // `record_failure` reports whether this failure tripped the breaker.
        let breaker_opened =
            if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
                let opened = breaker.record_failure();
                debug!(
                    upstream_id = %self.id,
                    target = %target,
                    circuit_breaker_opened = opened,
                    "Recorded failure in circuit breaker"
                );
                opened
            } else {
                false
            };
        // Only degrade the balancer's health view when the breaker opens.
        if breaker_opened {
            self.load_balancer.report_health(target, false).await;
        }
        self.stats.failures.fetch_add(1, Ordering::Relaxed);
        warn!(
            upstream_id = %self.id,
            target = %target,
            circuit_breaker_opened = breaker_opened,
            "Connection failure reported for target"
        );
    }
}
/// Latency-aware variant of [`Self::report_result`] for balancers that
/// track response times (the trait's default folds this into a plain
/// health report).
///
/// NOTE(review): on failure, the latency report is forwarded to the load
/// balancer only when the circuit breaker transitions to open (mirroring
/// `report_result`'s health gating) — latency-aware balancers therefore
/// never see non-tripping failures. Confirm this is intended.
pub async fn report_result_with_latency(
    &self,
    target: &str,
    success: bool,
    latency: Option<Duration>,
) {
    trace!(
        upstream_id = %self.id,
        target = %target,
        success = success,
        latency_ms = latency.map(|l| l.as_millis() as u64),
        "Reporting result with latency for adaptive LB"
    );
    if success {
        if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
            breaker.record_success();
        }
        self.load_balancer
            .report_result_with_latency(target, true, latency)
            .await;
    } else {
        let breaker_opened =
            if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
                breaker.record_failure()
            } else {
                false
            };
        self.stats.failures.fetch_add(1, Ordering::Relaxed);
        if breaker_opened {
            self.load_balancer
                .report_result_with_latency(target, false, latency)
                .await;
        }
    }
}
/// Aggregate counters for this pool.
pub fn stats(&self) -> &PoolStats {
    &self.stats
}

/// Logical identifier of this upstream.
pub fn id(&self) -> &UpstreamId {
    &self.id
}

/// Number of configured targets (healthy or not).
pub fn target_count(&self) -> usize {
    self.targets.len()
}

/// Plain-data snapshot of the pool/timeout configuration, flattened back
/// to seconds for introspection output.
pub fn pool_config(&self) -> PoolConfigSnapshot {
    PoolConfigSnapshot {
        max_connections: self.pool_config.max_connections,
        max_idle: self.pool_config.max_idle,
        idle_timeout_secs: self.pool_config.idle_timeout.as_secs(),
        max_lifetime_secs: self.pool_config.max_lifetime.map(|d| d.as_secs()),
        connection_timeout_secs: self.pool_config.connection_timeout.as_secs(),
        read_timeout_secs: self.pool_config.read_timeout.as_secs(),
        write_timeout_secs: self.pool_config.write_timeout.as_secs(),
    }
}
/// Returns `true` when the load balancer reports at least one healthy
/// target.
pub async fn has_healthy_targets(&self) -> bool {
    !self.load_balancer.healthy_targets().await.is_empty()
}
/// Selects a target for traffic shadowing (request mirroring), expressed
/// as URL parts instead of a Pingora peer.
///
/// # Errors
/// Returns an error when no healthy target is available or the chosen
/// target's circuit breaker is open.
pub async fn select_shadow_target(
    &self,
    context: Option<&RequestContext>,
) -> ZentinelResult<ShadowTarget> {
    let selection = self.load_balancer.select(context).await?;
    let breakers = self.circuit_breakers.read().await;
    if let Some(breaker) = breakers.get(&selection.address) {
        if !breaker.is_closed() {
            return Err(ZentinelError::upstream(
                self.id.to_string(),
                "Circuit breaker is open for shadow target",
            ));
        }
    }
    // Scheme-appropriate default used when the address carries no (valid)
    // port component.
    let default_port = if self.tls_enabled { 443 } else { 80 };
    // Split on the LAST colon so hosts containing colons keep their prefix.
    // (The previous `contains(':')` + `rsplitn` version carried an
    // unreachable "len != 2" branch: `rsplitn(2, ':')` always yields two
    // parts when the separator is present.)
    let (host, port) = match selection.address.rsplit_once(':') {
        Some((h, p)) => (h.to_string(), p.parse::<u16>().unwrap_or(default_port)),
        None => (selection.address.clone(), default_port),
    };
    Ok(ShadowTarget {
        scheme: if self.tls_enabled { "https" } else { "http" }.to_string(),
        host,
        port,
        sni: self.tls_sni.clone(),
    })
}
/// Whether peers to this upstream are created with TLS.
pub fn is_tls_enabled(&self) -> bool {
    self.tls_enabled
}

/// Number of requests currently in flight (maintained by
/// `increment_active`/`decrement_active`).
pub fn active_request_count(&self) -> u64 {
    self.stats.active_requests.load(Ordering::Relaxed)
}

/// Marks one request as started; pair every call with `decrement_active`.
pub fn increment_active(&self) {
    self.stats.active_requests.fetch_add(1, Ordering::Relaxed);
}

/// Marks one request as finished.
pub fn decrement_active(&self) {
    self.stats.active_requests.fetch_sub(1, Ordering::Relaxed);
}
/// Logs final statistics on shutdown. Nothing is torn down explicitly
/// here: upstream connections are owned and pooled by Pingora.
pub async fn shutdown(&self) {
    info!(
        upstream_id = %self.id,
        target_count = self.targets.len(),
        total_requests = self.stats.requests.load(Ordering::Relaxed),
        total_successes = self.stats.successes.load(Ordering::Relaxed),
        total_failures = self.stats.failures.load(Ordering::Relaxed),
        "Shutting down upstream pool"
    );
    debug!(upstream_id = %self.id, "Upstream pool shutdown complete");
}
}