use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use crate::config::{ServiceConfig, UpstreamConfig};
use tracing::info;
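/// Number of consecutive failures after which an upstream is taken out of rotation.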
pub const FAILURE_THRESHOLD: u32 = 3;
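/// Default cooldown, in seconds, once the failure threshold is reached.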
pub const COOLDOWN_SECS: u64 = 30;
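/// Exponential backoff policy applied to an upstream's cooldown after repeated penalties.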
#[derive(Debug, Clone, Copy)]
pub struct CooldownBackoff {
    /// Multiplier applied once per prior penalty; a value of 0 or 1 disables backoff.
    pub factor: u64,
    /// Cap on the computed cooldown in seconds; 0 caps the result at the base cooldown itself.
    pub max_secs: u64,
}
impl CooldownBackoff {
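    /// Computes the cooldown for a given `penalty_streak`: `base_secs` multiplied by
    /// `factor` once per prior penalty, clamped to `max_secs` (a cap of 0 keeps the
    /// result at `base_secs`, and a nonzero cap is never below `base_secs`). A zero
    /// base yields zero, and a factor of 0 or 1 leaves the base unchanged.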
fn effective_cooldown_secs(&self, base_secs: u64, penalty_streak: u32) -> u64 {
if base_secs == 0 {
return 0;
}
if self.factor <= 1 {
return base_secs;
}
let cap = if self.max_secs == 0 {
base_secs
} else {
self.max_secs.max(base_secs)
};
let mut secs = base_secs;
for _ in 0..penalty_streak.min(64) {
secs = secs.saturating_mul(self.factor);
if secs >= cap {
return cap;
}
}
secs.min(cap)
}
}
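/// Per-service load-balancing state; every `Vec` is indexed by upstream position.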
#[derive(Debug, Default)]
pub struct LbState {
    /// Consecutive failure count per upstream.
    pub failure_counts: Vec<u32>,
    /// If set, the upstream is skipped until this instant.
    pub cooldown_until: Vec<Option<std::time::Instant>>,
    /// Marks upstreams whose usage quota is exhausted.
    pub usage_exhausted: Vec<bool>,
    /// Index of the most recently successful upstream, preferred on the next selection.
    pub last_good_index: Option<usize>,
    /// Consecutive penalties per upstream, used to scale the cooldown backoff.
    pub penalty_streak: Vec<u32>,
}
impl LbState {
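    /// Resets all per-upstream vectors to `len` entries (and clears `last_good_index`)
    /// whenever the tracked upstream count differs from the service's current count.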
fn ensure_len(&mut self, len: usize) {
if self.failure_counts.len() != len {
self.failure_counts = vec![0; len];
self.cooldown_until = vec![None; len];
self.usage_exhausted = vec![false; len];
self.penalty_streak = vec![0; len];
self.last_good_index = None;
}
}
}
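/// The upstream chosen for a request, along with the owning service's name and the
/// upstream's index within that service.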
#[derive(Debug, Clone)]
pub struct SelectedUpstream {
pub config_name: String,
pub index: usize,
pub upstream: UpstreamConfig,
}
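/// Failover-style load balancer over a service's upstreams. Selection state is shared
/// through `states` (keyed by service name), so clones of the balancer observe the same
/// failure counts, cooldowns, and penalty streaks.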
#[derive(Clone)]
pub struct LoadBalancer {
pub service: Arc<ServiceConfig>,
pub states: Arc<Mutex<HashMap<String, LbState>>>,
}
impl LoadBalancer {
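    /// Creates a load balancer for `service` that stores its selection state in the
    /// shared `states` map.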
pub fn new(service: Arc<ServiceConfig>, states: Arc<Mutex<HashMap<String, LbState>>>) -> Self {
Self { service, states }
}
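    /// Test-only convenience: selects an upstream with an empty avoid set.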
#[cfg(test)]
pub fn select_upstream(&self) -> Option<SelectedUpstream> {
self.select_upstream_avoiding(&HashSet::new())
}
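    /// Selects an upstream, skipping the indices in `avoid`. Prefers the last known-good
    /// upstream, then healthy non-exhausted upstreams, then healthy but exhausted ones,
    /// and finally falls back to any non-avoided index rather than returning `None`.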
pub fn select_upstream_avoiding(&self, avoid: &HashSet<usize>) -> Option<SelectedUpstream> {
self.select_upstream_avoiding_inner(avoid, false)
}
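    /// Like [`Self::select_upstream_avoiding`], but returns `None` instead of falling
    /// back to an upstream that has reached its failure threshold.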
pub fn select_upstream_avoiding_strict(
&self,
avoid: &HashSet<usize>,
) -> Option<SelectedUpstream> {
self.select_upstream_avoiding_inner(avoid, true)
}
fn select_upstream_avoiding_inner(
&self,
avoid: &HashSet<usize>,
strict: bool,
) -> Option<SelectedUpstream> {
if self.service.upstreams.is_empty() {
return None;
}
let mut map = match self.states.lock() {
Ok(m) => m,
Err(e) => e.into_inner(),
};
let entry = map.entry(self.service.name.clone()).or_default();
entry.ensure_len(self.service.upstreams.len());
let now = std::time::Instant::now();
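        // Re-enable upstreams whose cooldown has expired (penalty streaks only reset on success).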
for idx in 0..self.service.upstreams.len() {
if let Some(until) = entry.cooldown_until.get(idx).and_then(|v| *v)
&& now >= until
{
entry.failure_counts[idx] = 0;
if let Some(slot) = entry.cooldown_until.get_mut(idx) {
*slot = None;
}
}
}
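        // Prefer the last known-good upstream while it is healthy, not exhausted, and not avoided.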
if let Some(idx) = entry.last_good_index
&& idx < self.service.upstreams.len()
&& entry.failure_counts[idx] < FAILURE_THRESHOLD
&& !entry.usage_exhausted.get(idx).copied().unwrap_or(false)
&& !avoid.contains(&idx)
{
let upstream = self.service.upstreams[idx].clone();
return Some(SelectedUpstream {
config_name: self.service.name.clone(),
index: idx,
upstream,
});
}
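        // Otherwise pick the first upstream below the failure threshold that is not usage-exhausted.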
if let Some(idx) = self
.service
.upstreams
.iter()
.enumerate()
.find_map(|(idx, _)| {
if avoid.contains(&idx) {
return None;
}
if entry.failure_counts[idx] >= FAILURE_THRESHOLD {
return None;
}
if entry.usage_exhausted.get(idx).copied().unwrap_or(false) {
return None;
}
Some(idx)
})
{
let upstream = self.service.upstreams[idx].clone();
return Some(SelectedUpstream {
config_name: self.service.name.clone(),
index: idx,
upstream,
});
}
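        // Next, accept usage-exhausted upstreams that are still below the failure threshold.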
if let Some(idx) = self
.service
.upstreams
.iter()
.enumerate()
.find_map(|(idx, _)| {
if avoid.contains(&idx) {
return None;
}
if entry.failure_counts[idx] >= FAILURE_THRESHOLD {
None
} else {
Some(idx)
}
})
{
let upstream = self.service.upstreams[idx].clone();
return Some(SelectedUpstream {
config_name: self.service.name.clone(),
index: idx,
upstream,
});
}
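        // In strict mode, never fall back to an upstream past its failure threshold.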
if strict {
return None;
}
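        // Last resort: the first non-avoided index, or index 0 if every index is avoided.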
let idx = (0..self.service.upstreams.len())
.find(|i| !avoid.contains(i))
.unwrap_or(0);
let upstream = self.service.upstreams[idx].clone();
Some(SelectedUpstream {
config_name: self.service.name.clone(),
index: idx,
upstream,
})
}
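    /// Immediately pushes upstream `index` to the failure threshold and starts a cooldown
    /// of `cooldown_secs`, scaled by `backoff` according to the upstream's current penalty
    /// streak. The streak is then incremented and the upstream loses its last-good status.
    /// No-op if the state mutex is poisoned or `index` is out of range.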
pub fn penalize_with_backoff(
&self,
index: usize,
cooldown_secs: u64,
reason: &str,
backoff: CooldownBackoff,
) {
let mut map = match self.states.lock() {
Ok(m) => m,
Err(_) => return,
};
        let entry = map.entry(self.service.name.clone()).or_default();
entry.ensure_len(self.service.upstreams.len());
if index >= entry.failure_counts.len() {
return;
}
let streak = entry.penalty_streak.get(index).copied().unwrap_or(0);
let effective_secs = backoff.effective_cooldown_secs(cooldown_secs, streak);
entry.failure_counts[index] = FAILURE_THRESHOLD;
if let Some(slot) = entry.cooldown_until.get_mut(index) {
*slot =
Some(std::time::Instant::now() + std::time::Duration::from_secs(effective_secs));
}
if let Some(slot) = entry.penalty_streak.get_mut(index) {
*slot = streak.saturating_add(1);
}
if entry.last_good_index == Some(index) {
entry.last_good_index = None;
}
        info!(
            "lb: service '{}' upstream {} penalized for {}s (reason: {})",
            self.service.name, index, effective_secs, reason
        );
}
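    /// Records the outcome of a request against upstream `index`. A success clears the
    /// failure count, cooldown, and penalty streak and marks the upstream as last-good;
    /// a failure increments the count and, once `FAILURE_THRESHOLD` is reached, starts
    /// (or extends) a cooldown derived from `failure_threshold_cooldown_secs` (defaulting
    /// to `COOLDOWN_SECS` when 0) and `backoff`. No-op if the state mutex is poisoned or
    /// `index` is out of range.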
pub fn record_result_with_backoff(
&self,
index: usize,
success: bool,
failure_threshold_cooldown_secs: u64,
backoff: CooldownBackoff,
) {
let mut map = match self.states.lock() {
Ok(m) => m,
Err(_) => return,
};
        let entry = map.entry(self.service.name.clone()).or_default();
entry.ensure_len(self.service.upstreams.len());
if index >= entry.failure_counts.len() {
return;
}
if success {
entry.failure_counts[index] = 0;
if let Some(slot) = entry.cooldown_until.get_mut(index) {
*slot = None;
}
if let Some(slot) = entry.penalty_streak.get_mut(index) {
*slot = 0;
}
entry.last_good_index = Some(index);
} else {
entry.failure_counts[index] = entry.failure_counts[index].saturating_add(1);
if entry.failure_counts[index] >= FAILURE_THRESHOLD
&& let Some(slot) = entry.cooldown_until.get_mut(index)
{
let base_secs = if failure_threshold_cooldown_secs == 0 {
COOLDOWN_SECS
} else {
failure_threshold_cooldown_secs
};
let streak = entry.penalty_streak.get(index).copied().unwrap_or(0);
let effective_secs = backoff.effective_cooldown_secs(base_secs, streak);
let now = std::time::Instant::now();
let new_until = now + std::time::Duration::from_secs(effective_secs);
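                // Only ever extend an existing cooldown; never shorten it.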
let should_update = match *slot {
Some(existing) => new_until > existing,
None => true,
};
if should_update {
*slot = Some(new_until);
}
if let Some(slot) = entry.penalty_streak.get_mut(index) {
*slot = streak.saturating_add(1);
}
                info!(
                    "lb: service '{}' upstream {} reached failure threshold {} (count = {}), entering cooldown for {}s",
                    self.service.name,
                    index,
                    FAILURE_THRESHOLD,
                    entry.failure_counts[index],
                    effective_secs
                );
if entry.last_good_index == Some(index) {
entry.last_good_index = None;
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{ServiceConfig, UpstreamAuth, UpstreamConfig};
fn make_service(name: &str, urls: &[&str]) -> ServiceConfig {
ServiceConfig {
name: name.to_string(),
alias: None,
enabled: true,
level: 1,
upstreams: urls
.iter()
.map(|u| UpstreamConfig {
base_url: u.to_string(),
auth: UpstreamAuth {
auth_token: Some("sk-test".to_string()),
auth_token_env: None,
api_key: None,
api_key_env: None,
},
tags: HashMap::new(),
supported_models: HashMap::new(),
model_mapping: HashMap::new(),
})
.collect(),
}
}
#[test]
fn lb_prefers_non_exhausted_upstream_when_available() {
let service = make_service(
"codex-main",
&["https://primary.example", "https://backup.example"],
);
let states = Arc::new(Mutex::new(HashMap::new()));
let lb = LoadBalancer::new(Arc::new(service), states.clone());
let first = lb.select_upstream().expect("should select an upstream");
assert_eq!(first.index, 0);
{
let mut guard = states.lock().unwrap();
let entry = guard
.entry("codex-main".to_string())
.or_insert_with(LbState::default);
entry.ensure_len(2);
entry.usage_exhausted[0] = true;
entry.usage_exhausted[1] = false;
}
let second = lb.select_upstream().expect("should select backup upstream");
assert_eq!(second.index, 1);
}
#[test]
fn lb_falls_back_when_all_exhausted() {
let service = make_service(
"codex-main",
&["https://primary.example", "https://backup.example"],
);
let states = Arc::new(Mutex::new(HashMap::new()));
let lb = LoadBalancer::new(Arc::new(service), states.clone());
let _ = lb.select_upstream();
{
let mut guard = states.lock().unwrap();
let entry = guard
.entry("codex-main".to_string())
.or_insert_with(LbState::default);
entry.ensure_len(2);
entry.usage_exhausted[0] = true;
entry.usage_exhausted[1] = true;
}
let selected = lb
.select_upstream()
.expect("should still select an upstream");
assert_eq!(selected.index, 0);
}
#[test]
fn lb_avoids_upstreams_past_failure_threshold() {
let service = make_service(
"codex-main",
&["https://primary.example", "https://backup.example"],
);
let states = Arc::new(Mutex::new(HashMap::new()));
let lb = LoadBalancer::new(Arc::new(service), states.clone());
let disabled_backoff = CooldownBackoff {
factor: 1,
max_secs: 0,
};
for _ in 0..FAILURE_THRESHOLD {
lb.record_result_with_backoff(0, false, COOLDOWN_SECS, disabled_backoff);
}
let selected = lb
.select_upstream()
.expect("should select backup after failures");
assert_eq!(selected.index, 1);
}
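    // The tests below are additional illustrative checks, sketched from the selection and
    // backoff logic above; the expected values follow directly from `FAILURE_THRESHOLD`,
    // `COOLDOWN_SECS`, and `CooldownBackoff::effective_cooldown_secs`.
    #[test]
    fn cooldown_backoff_grows_exponentially_up_to_cap() {
        let backoff = CooldownBackoff {
            factor: 2,
            max_secs: 120,
        };
        // No prior penalties: the base cooldown applies unchanged.
        assert_eq!(backoff.effective_cooldown_secs(30, 0), 30);
        // Each prior penalty multiplies the cooldown by `factor`...
        assert_eq!(backoff.effective_cooldown_secs(30, 1), 60);
        // ...until the configured cap is reached.
        assert_eq!(backoff.effective_cooldown_secs(30, 2), 120);
        assert_eq!(backoff.effective_cooldown_secs(30, 10), 120);
        // A factor of 1 (or 0) disables the growth entirely.
        let flat = CooldownBackoff {
            factor: 1,
            max_secs: 0,
        };
        assert_eq!(flat.effective_cooldown_secs(30, 5), 30);
        // A zero base always yields a zero cooldown.
        assert_eq!(backoff.effective_cooldown_secs(0, 3), 0);
    }
    #[test]
    fn lb_success_resets_failures_and_prefers_last_good() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());
        let disabled_backoff = CooldownBackoff {
            factor: 1,
            max_secs: 0,
        };
        // Two failures on index 0 stay below the threshold; index 1 then succeeds.
        lb.record_result_with_backoff(0, false, COOLDOWN_SECS, disabled_backoff);
        lb.record_result_with_backoff(0, false, COOLDOWN_SECS, disabled_backoff);
        lb.record_result_with_backoff(1, true, COOLDOWN_SECS, disabled_backoff);
        // The last known-good upstream (index 1) should now be preferred.
        let selected = lb.select_upstream().expect("should select an upstream");
        assert_eq!(selected.index, 1);
        let guard = states.lock().unwrap();
        let entry = guard.get("codex-main").expect("state entry should exist");
        assert_eq!(entry.failure_counts[1], 0);
        assert_eq!(entry.last_good_index, Some(1));
        // Index 0 still carries its two recorded failures.
        assert_eq!(entry.failure_counts[0], 2);
    }
    #[test]
    fn lb_penalize_puts_upstream_into_cooldown() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());
        let disabled_backoff = CooldownBackoff {
            factor: 1,
            max_secs: 0,
        };
        // Penalizing index 0 pushes it straight to the failure threshold.
        lb.penalize_with_backoff(0, COOLDOWN_SECS, "usage exhausted", disabled_backoff);
        {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("state entry should exist");
            assert_eq!(entry.failure_counts[0], FAILURE_THRESHOLD);
            assert!(entry.cooldown_until[0].is_some());
            assert_eq!(entry.penalty_streak[0], 1);
        }
        // Selection should now skip index 0 and fall through to the backup.
        let selected = lb.select_upstream().expect("should select backup upstream");
        assert_eq!(selected.index, 1);
    }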
}