1use std::collections::HashMap;
2use std::collections::HashSet;
3use std::sync::{Arc, Mutex};
4
5use crate::config::{ServiceConfig, UpstreamConfig};
6use tracing::info;
7
/// Number of consecutive failures after which an upstream is considered
/// unhealthy and placed on cooldown.
pub const FAILURE_THRESHOLD: u32 = 3;
/// Default cooldown duration in seconds, used when the configured
/// failure-threshold cooldown is 0.
pub const COOLDOWN_SECS: u64 = 30;
10
/// Multiplicative cooldown backoff: each penalty in a streak multiplies the
/// base cooldown by `factor`, capped at `max_secs`.
#[derive(Debug, Clone, Copy)]
pub struct CooldownBackoff {
    /// Multiplier applied once per penalty in the streak; `<= 1` disables backoff.
    pub factor: u64,
    /// Upper bound on the escalated cooldown in seconds; `0` means "cap at the base".
    pub max_secs: u64,
}

impl CooldownBackoff {
    /// Computes the cooldown for the given `penalty_streak`, escalating
    /// `base_secs` by `factor` per streak step (streak clamped to 64 steps)
    /// and never exceeding the effective cap.
    fn effective_cooldown_secs(&self, base_secs: u64, penalty_streak: u32) -> u64 {
        // A zero base always means "no cooldown", regardless of backoff settings.
        if base_secs == 0 {
            return 0;
        }
        // Factor 0 or 1 would never grow the value; short-circuit to the base.
        if self.factor <= 1 {
            return base_secs;
        }
        // The cap is at least the base so backoff can never *shorten* a cooldown.
        let ceiling = if self.max_secs == 0 {
            base_secs
        } else {
            self.max_secs.max(base_secs)
        };

        let mut remaining = penalty_streak.min(64);
        let mut current = base_secs;
        // Multiply up the cooldown, stopping early once the ceiling is reached.
        while remaining > 0 && current < ceiling {
            current = current.saturating_mul(self.factor);
            remaining -= 1;
        }
        current.min(ceiling)
    }
}
41
/// Per-service load-balancing bookkeeping. All `Vec`s are indexed by the
/// upstream's position in the service's upstream list.
#[derive(Debug, Default)]
pub struct LbState {
    /// Consecutive failure count per upstream.
    pub failure_counts: Vec<u32>,
    /// Cooldown expiry per upstream; `None` means not on cooldown.
    pub cooldown_until: Vec<Option<std::time::Instant>>,
    /// Whether an upstream has reported usage exhaustion.
    pub usage_exhausted: Vec<bool>,
    /// Index of the last upstream that succeeded, if any.
    pub last_good_index: Option<usize>,
    /// Number of penalties applied in a row, drives cooldown backoff.
    pub penalty_streak: Vec<u32>,
}

impl LbState {
    /// Resets all per-upstream vectors to `len` entries whenever the upstream
    /// count changed (e.g. after a config reload); a matching length leaves
    /// the existing state untouched.
    fn ensure_len(&mut self, len: usize) {
        if self.failure_counts.len() == len {
            return;
        }
        self.failure_counts = vec![0; len];
        self.cooldown_until = vec![None; len];
        self.usage_exhausted = vec![false; len];
        self.penalty_streak = vec![0; len];
        self.last_good_index = None;
    }
}
63
/// Result of an upstream selection: which upstream was chosen and from which
/// service configuration.
#[derive(Debug, Clone)]
pub struct SelectedUpstream {
    /// Name of the service configuration the upstream belongs to.
    pub config_name: String,
    /// Position of the upstream within the service's upstream list.
    pub index: usize,
    /// Cloned configuration of the selected upstream.
    pub upstream: UpstreamConfig,
}
71
72#[derive(Clone)]
74pub struct LoadBalancer {
75 pub service: Arc<ServiceConfig>,
76 pub states: Arc<Mutex<HashMap<String, LbState>>>,
77}
78
79impl LoadBalancer {
80 pub fn new(service: Arc<ServiceConfig>, states: Arc<Mutex<HashMap<String, LbState>>>) -> Self {
81 Self { service, states }
82 }
83
84 #[cfg(test)]
85 pub fn select_upstream(&self) -> Option<SelectedUpstream> {
86 self.select_upstream_avoiding(&HashSet::new())
87 }
88
89 pub fn select_upstream_avoiding(&self, avoid: &HashSet<usize>) -> Option<SelectedUpstream> {
90 self.select_upstream_avoiding_inner(avoid, false)
91 }
92
93 pub fn select_upstream_avoiding_strict(
94 &self,
95 avoid: &HashSet<usize>,
96 ) -> Option<SelectedUpstream> {
97 self.select_upstream_avoiding_inner(avoid, true)
98 }
99
100 fn select_upstream_avoiding_inner(
101 &self,
102 avoid: &HashSet<usize>,
103 strict: bool,
104 ) -> Option<SelectedUpstream> {
105 if self.service.upstreams.is_empty() {
106 return None;
107 }
108
109 let mut map = match self.states.lock() {
110 Ok(m) => m,
111 Err(e) => e.into_inner(),
112 };
113 let entry = map.entry(self.service.name.clone()).or_default();
114 entry.ensure_len(self.service.upstreams.len());
115
116 let now = std::time::Instant::now();
117
118 for idx in 0..self.service.upstreams.len() {
120 if let Some(until) = entry.cooldown_until.get(idx).and_then(|v| *v)
121 && now >= until
122 {
123 entry.failure_counts[idx] = 0;
124 if let Some(slot) = entry.cooldown_until.get_mut(idx) {
125 *slot = None;
126 }
127 }
128 }
129
130 if let Some(idx) = entry.last_good_index
133 && idx < self.service.upstreams.len()
134 && entry.failure_counts[idx] < FAILURE_THRESHOLD
135 && !entry.usage_exhausted.get(idx).copied().unwrap_or(false)
136 && !avoid.contains(&idx)
137 {
138 let upstream = self.service.upstreams[idx].clone();
139 return Some(SelectedUpstream {
140 config_name: self.service.name.clone(),
141 index: idx,
142 upstream,
143 });
144 }
145
146 if let Some(idx) = self
148 .service
149 .upstreams
150 .iter()
151 .enumerate()
152 .find_map(|(idx, _)| {
153 if avoid.contains(&idx) {
154 return None;
155 }
156 if entry.failure_counts[idx] >= FAILURE_THRESHOLD {
157 return None;
158 }
159 if entry.usage_exhausted.get(idx).copied().unwrap_or(false) {
160 return None;
161 }
162 Some(idx)
163 })
164 {
165 let upstream = self.service.upstreams[idx].clone();
166 return Some(SelectedUpstream {
167 config_name: self.service.name.clone(),
168 index: idx,
169 upstream,
170 });
171 }
172
173 if let Some(idx) = self
175 .service
176 .upstreams
177 .iter()
178 .enumerate()
179 .find_map(|(idx, _)| {
180 if avoid.contains(&idx) {
181 return None;
182 }
183 if entry.failure_counts[idx] >= FAILURE_THRESHOLD {
184 None
185 } else {
186 Some(idx)
187 }
188 })
189 {
190 let upstream = self.service.upstreams[idx].clone();
191 return Some(SelectedUpstream {
192 config_name: self.service.name.clone(),
193 index: idx,
194 upstream,
195 });
196 }
197
198 if strict {
199 return None;
200 }
201
202 let idx = (0..self.service.upstreams.len())
205 .find(|i| !avoid.contains(i))
206 .unwrap_or(0);
207 let upstream = self.service.upstreams[idx].clone();
208 Some(SelectedUpstream {
209 config_name: self.service.name.clone(),
210 index: idx,
211 upstream,
212 })
213 }
214
215 pub fn penalize_with_backoff(
216 &self,
217 index: usize,
218 cooldown_secs: u64,
219 reason: &str,
220 backoff: CooldownBackoff,
221 ) {
222 let mut map = match self.states.lock() {
223 Ok(m) => m,
224 Err(_) => return,
225 };
226 let entry = map
227 .entry(self.service.name.clone())
228 .or_insert_with(LbState::default);
229 entry.ensure_len(self.service.upstreams.len());
230 if index >= entry.failure_counts.len() {
231 return;
232 }
233
234 let streak = entry.penalty_streak.get(index).copied().unwrap_or(0);
235 let effective_secs = backoff.effective_cooldown_secs(cooldown_secs, streak);
236
237 entry.failure_counts[index] = FAILURE_THRESHOLD;
238 if let Some(slot) = entry.cooldown_until.get_mut(index) {
239 *slot =
240 Some(std::time::Instant::now() + std::time::Duration::from_secs(effective_secs));
241 }
242 if let Some(slot) = entry.penalty_streak.get_mut(index) {
243 *slot = streak.saturating_add(1);
244 }
245 if entry.last_good_index == Some(index) {
246 entry.last_good_index = None;
247 }
248 info!(
249 "lb: upstream '{}' index {} penalized for {}s (reason: {})",
250 self.service.name, index, effective_secs, reason
251 );
252 }
253
254 pub fn record_result_with_backoff(
255 &self,
256 index: usize,
257 success: bool,
258 failure_threshold_cooldown_secs: u64,
259 backoff: CooldownBackoff,
260 ) {
261 let mut map = match self.states.lock() {
262 Ok(m) => m,
263 Err(_) => return,
264 };
265 let entry = map
266 .entry(self.service.name.clone())
267 .or_insert_with(LbState::default);
268 entry.ensure_len(self.service.upstreams.len());
269 if index >= entry.failure_counts.len() {
270 return;
271 }
272 if success {
273 entry.failure_counts[index] = 0;
274 if let Some(slot) = entry.cooldown_until.get_mut(index) {
275 *slot = None;
276 }
277 if let Some(slot) = entry.penalty_streak.get_mut(index) {
278 *slot = 0;
279 }
280 entry.last_good_index = Some(index);
282 } else {
283 entry.failure_counts[index] = entry.failure_counts[index].saturating_add(1);
284 if entry.failure_counts[index] >= FAILURE_THRESHOLD
285 && let Some(slot) = entry.cooldown_until.get_mut(index)
286 {
287 let base_secs = if failure_threshold_cooldown_secs == 0 {
288 COOLDOWN_SECS
289 } else {
290 failure_threshold_cooldown_secs
291 };
292 let streak = entry.penalty_streak.get(index).copied().unwrap_or(0);
293 let effective_secs = backoff.effective_cooldown_secs(base_secs, streak);
294 let now = std::time::Instant::now();
295 let new_until = now + std::time::Duration::from_secs(effective_secs);
296 let should_update = match *slot {
297 Some(existing) => new_until > existing,
298 None => true,
299 };
300 if should_update {
301 *slot = Some(new_until);
302 }
303 if let Some(slot) = entry.penalty_streak.get_mut(index) {
304 *slot = streak.saturating_add(1);
305 }
306 info!(
307 "lb: upstream '{}' index {} reached failure threshold {} (count = {}), entering cooldown for {}s",
308 self.service.name,
309 index,
310 FAILURE_THRESHOLD,
311 entry.failure_counts[index],
312 effective_secs
313 );
314 if entry.last_good_index == Some(index) {
316 entry.last_good_index = None;
317 }
318 }
319 }
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ServiceConfig, UpstreamAuth, UpstreamConfig};

    /// Builds a minimal enabled service with one token-authed upstream per URL.
    fn make_service(name: &str, urls: &[&str]) -> ServiceConfig {
        let upstreams = urls
            .iter()
            .map(|url| UpstreamConfig {
                base_url: (*url).to_string(),
                auth: UpstreamAuth {
                    auth_token: Some("sk-test".to_string()),
                    auth_token_env: None,
                    api_key: None,
                    api_key_env: None,
                },
                tags: HashMap::new(),
                supported_models: HashMap::new(),
                model_mapping: HashMap::new(),
            })
            .collect();
        ServiceConfig {
            name: name.to_string(),
            alias: None,
            enabled: true,
            level: 1,
            upstreams,
        }
    }

    #[test]
    fn lb_prefers_non_exhausted_upstream_when_available() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let shared_states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), shared_states.clone());

        let initial = lb.select_upstream().expect("should select an upstream");
        assert_eq!(initial.index, 0);

        // Mark the primary as usage-exhausted; the backup stays available.
        {
            let mut state_map = shared_states.lock().unwrap();
            let entry = state_map.entry("codex-main".to_string()).or_default();
            entry.ensure_len(2);
            entry.usage_exhausted[0] = true;
            entry.usage_exhausted[1] = false;
        }

        let fallback = lb.select_upstream().expect("should select backup upstream");
        assert_eq!(fallback.index, 1);
    }

    #[test]
    fn lb_falls_back_when_all_exhausted() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let shared_states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), shared_states.clone());

        let _ = lb.select_upstream();

        // Exhaust every upstream; selection should still return something.
        {
            let mut state_map = shared_states.lock().unwrap();
            let entry = state_map.entry("codex-main".to_string()).or_default();
            entry.ensure_len(2);
            entry.usage_exhausted[0] = true;
            entry.usage_exhausted[1] = true;
        }

        let selected = lb
            .select_upstream()
            .expect("should still select an upstream");
        assert_eq!(selected.index, 0);
    }

    #[test]
    fn lb_avoids_upstreams_past_failure_threshold() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let shared_states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), shared_states.clone());

        // factor <= 1 disables backoff escalation for this test.
        let disabled_backoff = CooldownBackoff {
            factor: 1,
            max_secs: 0,
        };

        // Drive the primary over the failure threshold.
        for _ in 0..FAILURE_THRESHOLD {
            lb.record_result_with_backoff(0, false, COOLDOWN_SECS, disabled_backoff);
        }

        let selected = lb
            .select_upstream()
            .expect("should select backup after failures");
        assert_eq!(selected.index, 1);
    }
}
435}