1use std::collections::HashMap;
2use std::collections::HashSet;
3use std::sync::{Arc, Mutex};
4
5use crate::config::{ServiceConfig, UpstreamConfig};
6use crate::runtime_identity::ProviderEndpointKey;
7use tracing::info;
8
/// Consecutive-failure count at which an upstream is considered unhealthy
/// and enters cooldown.
pub const FAILURE_THRESHOLD: u32 = 3;
/// Default cooldown duration in seconds, used when the caller passes 0 as
/// the configured threshold cooldown.
pub const COOLDOWN_SECS: u64 = 30;
11
/// Exponential-backoff settings applied to upstream cooldown durations.
#[derive(Debug, Clone, Copy)]
pub struct CooldownBackoff {
    /// Multiplier applied once per penalty-streak step; a value of 0 or 1
    /// disables backoff entirely (base cooldown is used as-is).
    pub factor: u64,
    /// Upper bound in seconds for the backed-off cooldown; 0 means "cap at
    /// the base cooldown itself" (i.e. no growth beyond the base).
    pub max_secs: u64,
}
17
18impl CooldownBackoff {
19 pub(crate) fn effective_cooldown_secs(&self, base_secs: u64, penalty_streak: u32) -> u64 {
20 if base_secs == 0 {
21 return 0;
22 }
23 if self.factor <= 1 {
24 return base_secs;
25 }
26 let cap = if self.max_secs == 0 {
27 base_secs
28 } else {
29 self.max_secs.max(base_secs)
30 };
31
32 let mut secs = base_secs;
33 for _ in 0..penalty_streak.min(64) {
34 secs = secs.saturating_mul(self.factor);
35 if secs >= cap {
36 return cap;
37 }
38 }
39 secs.min(cap)
40 }
41}
42
/// Per-service load-balancer bookkeeping; every `Vec` is indexed by upstream
/// position and kept the same length as the service's upstream list via
/// `ensure_layout`.
#[derive(Debug, Clone, Default)]
pub struct LbState {
    // Consecutive failure count per upstream; reset on success or cooldown expiry.
    pub failure_counts: Vec<u32>,
    // If set, the upstream is skipped until this instant.
    pub cooldown_until: Vec<Option<std::time::Instant>>,
    // Upstreams flagged as having exhausted their usage quota (set externally).
    pub usage_exhausted: Vec<bool>,
    // Sticky index of the last upstream that succeeded, preferred on selection.
    pub last_good_index: Option<usize>,
    // Number of consecutive cooldown penalties per upstream; drives backoff growth.
    pub penalty_streak: Vec<u32>,
    // Identity keys the current layout was built for; used to detect layout
    // changes and to migrate state across upstream reordering.
    pub(crate) upstream_signature: Vec<String>,
}
52
impl LbState {
    /// Makes this state's vectors match `upstreams`, migrating per-upstream
    /// counters by signature key where possible and resetting otherwise.
    pub(crate) fn ensure_layout(&mut self, service_name: &str, upstreams: &[UpstreamConfig]) {
        // Identity keys for the incoming layout (provider/endpoint aware).
        let signature = upstreams
            .iter()
            .enumerate()
            .map(|(idx, upstream)| upstream_signature_key(service_name, idx, upstream))
            .collect::<Vec<_>>();
        // Older key scheme (base URL only), kept so state recorded before
        // signature keys existed can still be migrated.
        let legacy_signature = upstreams
            .iter()
            .map(|upstream| upstream.base_url.clone())
            .collect::<Vec<_>>();

        // Duplicate keys would make index mapping ambiguous; start fresh.
        if has_duplicate_signatures(&signature) {
            self.reset_for_layout(signature);
            return;
        }

        let len = upstreams.len();
        // Fast path: identical layout and every vector already correctly sized.
        if self.upstream_signature == signature
            && self.failure_counts.len() == len
            && self.cooldown_until.len() == len
            && self.usage_exhausted.len() == len
            && self.penalty_streak.len() == len
        {
            return;
        }

        self.migrate_layout(signature, legacy_signature);
    }

    /// Discards all per-upstream state and resizes every vector for `signature`.
    fn reset_for_layout(&mut self, signature: Vec<String>) {
        let len = signature.len();
        self.failure_counts = vec![0; len];
        self.cooldown_until = vec![None; len];
        self.usage_exhausted = vec![false; len];
        self.penalty_streak = vec![0; len];
        self.last_good_index = None;
        self.upstream_signature = signature;
    }

    /// Rebuilds the state vectors for `signature`, carrying each upstream's
    /// previous counters over when its key — or, as a fallback, the legacy
    /// base-URL key at the same position — is found in the old layout.
    fn migrate_layout(&mut self, signature: Vec<String>, legacy_signature: Vec<String>) {
        // Nothing recorded yet: migration degenerates to a plain reset.
        if self.upstream_signature.is_empty() {
            self.reset_for_layout(signature);
            return;
        }

        let old_signature = std::mem::take(&mut self.upstream_signature);
        // Ambiguous old keys cannot be mapped reliably; reset instead.
        if has_duplicate_signatures(&old_signature) {
            self.reset_for_layout(signature);
            return;
        }

        // Old index per signature key, for O(1) lookups below.
        let old_index_by_signature = old_signature
            .iter()
            .enumerate()
            .map(|(idx, key)| (key.clone(), idx))
            .collect::<std::collections::HashMap<_, _>>();
        // The legacy fallback is only trustworthy when base URLs are unique.
        let legacy_fallback_enabled = !has_duplicate_signatures(&legacy_signature);

        // Take ownership of the old vectors so fresh zeroed ones can be built.
        let old_failure_counts = std::mem::take(&mut self.failure_counts);
        let old_cooldown_until = std::mem::take(&mut self.cooldown_until);
        let old_usage_exhausted = std::mem::take(&mut self.usage_exhausted);
        let old_penalty_streak = std::mem::take(&mut self.penalty_streak);
        let old_last_good_index = self.last_good_index.take();

        let len = signature.len();
        self.failure_counts = vec![0; len];
        self.cooldown_until = vec![None; len];
        self.usage_exhausted = vec![false; len];
        self.penalty_streak = vec![0; len];

        for (new_idx, key) in signature.iter().enumerate() {
            // Prefer an exact signature match; otherwise try the legacy
            // base-URL key at the same new position (handles state persisted
            // under the old base-URL-only scheme).
            let old_idx = old_index_by_signature.get(key).copied().or_else(|| {
                legacy_fallback_enabled
                    .then(|| legacy_signature.get(new_idx))
                    .flatten()
                    .and_then(|legacy_key| old_index_by_signature.get(legacy_key).copied())
            });
            let Some(old_idx) = old_idx else {
                // New upstream with no prior history: keep the zeroed defaults.
                continue;
            };
            self.failure_counts[new_idx] = old_failure_counts.get(old_idx).copied().unwrap_or(0);
            self.cooldown_until[new_idx] = old_cooldown_until.get(old_idx).and_then(|until| *until);
            self.usage_exhausted[new_idx] =
                old_usage_exhausted.get(old_idx).copied().unwrap_or(false);
            self.penalty_streak[new_idx] = old_penalty_streak.get(old_idx).copied().unwrap_or(0);
        }

        // Re-point the sticky "last good" index at the same upstream's new
        // position, if that upstream still exists under either key scheme.
        self.last_good_index = old_last_good_index.and_then(|old_idx| {
            old_signature.get(old_idx).and_then(|key| {
                signature
                    .iter()
                    .position(|new_key| new_key == key)
                    .or_else(|| {
                        legacy_fallback_enabled
                            .then(|| {
                                legacy_signature
                                    .iter()
                                    .position(|legacy_key| legacy_key == key)
                            })
                            .flatten()
                    })
            })
        });
        self.upstream_signature = signature;
    }
}
161
/// Returns true when `values` contains at least one repeated entry.
fn has_duplicate_signatures(values: &[String]) -> bool {
    // A set smaller than the slice means at least one entry collapsed.
    let unique: HashSet<&String> = values.iter().collect();
    unique.len() != values.len()
}
166
167fn upstream_signature_key(
168 service_name: &str,
169 upstream_index: usize,
170 upstream: &UpstreamConfig,
171) -> String {
172 let provider_id = upstream
173 .tags
174 .get("provider_id")
175 .cloned()
176 .unwrap_or_else(|| format!("{service_name}#{upstream_index}"));
177 let endpoint_id = upstream
178 .tags
179 .get("endpoint_id")
180 .cloned()
181 .unwrap_or_else(|| upstream_index.to_string());
182 let provider_endpoint = ProviderEndpointKey::new(service_name, provider_id, endpoint_id);
183 format!("{}|{}", provider_endpoint.stable_key(), upstream.base_url)
184}
185
/// The upstream chosen by the load balancer for one request.
#[derive(Debug, Clone)]
pub struct SelectedUpstream {
    // Name of the service ("station") this upstream belongs to.
    pub station_name: String,
    // Position of the upstream within the service's upstream list.
    pub index: usize,
    // A clone of the selected upstream's configuration.
    pub upstream: UpstreamConfig,
}
193
194#[derive(Clone)]
196pub struct LoadBalancer {
197 pub service: Arc<ServiceConfig>,
198 pub states: Arc<Mutex<HashMap<String, LbState>>>,
199}
200
impl LoadBalancer {
    /// Creates a balancer for `service` backed by the shared `states` map.
    pub fn new(service: Arc<ServiceConfig>, states: Arc<Mutex<HashMap<String, LbState>>>) -> Self {
        Self { service, states }
    }

    /// Test-only convenience: select with an empty avoid-set, non-strict.
    #[cfg(test)]
    pub fn select_upstream(&self) -> Option<SelectedUpstream> {
        self.select_upstream_avoiding(&HashSet::new())
    }

    /// Selects an upstream, skipping indices in `avoid`; falls back to an
    /// arbitrary non-avoided upstream (or index 0) when nothing is healthy.
    pub fn select_upstream_avoiding(&self, avoid: &HashSet<usize>) -> Option<SelectedUpstream> {
        self.select_upstream_avoiding_inner(avoid, false)
    }

    /// Like `select_upstream_avoiding`, but returns `None` instead of the
    /// final unconditional fallback when every candidate is past the
    /// failure threshold or avoided.
    pub fn select_upstream_avoiding_strict(
        &self,
        avoid: &HashSet<usize>,
    ) -> Option<SelectedUpstream> {
        self.select_upstream_avoiding_inner(avoid, true)
    }

    /// Core selection: expires cooldowns, then tries in order
    /// (1) the sticky last-good index, (2) the first healthy non-exhausted
    /// upstream, (3) the first healthy upstream ignoring usage exhaustion,
    /// and finally (non-strict only) (4) any non-avoided index, defaulting
    /// to 0.
    fn select_upstream_avoiding_inner(
        &self,
        avoid: &HashSet<usize>,
        strict: bool,
    ) -> Option<SelectedUpstream> {
        if self.service.upstreams.is_empty() {
            return None;
        }

        // Recover state even from a poisoned lock — selection must not fail
        // just because another thread panicked while holding it.
        let mut map = match self.states.lock() {
            Ok(m) => m,
            Err(e) => e.into_inner(),
        };
        let entry = map.entry(self.service.name.clone()).or_default();
        entry.ensure_layout(self.service.name.as_str(), &self.service.upstreams);

        let now = std::time::Instant::now();

        // Expire elapsed cooldowns: clear the failure count and the slot so
        // the upstream becomes selectable again.
        for idx in 0..self.service.upstreams.len() {
            if let Some(until) = entry.cooldown_until.get(idx).and_then(|v| *v)
                && now >= until
            {
                entry.failure_counts[idx] = 0;
                if let Some(slot) = entry.cooldown_until.get_mut(idx) {
                    *slot = None;
                }
            }
        }

        // Pass 1: stick with the last known-good upstream when it is still
        // in range, healthy, not exhausted, and not avoided.
        if let Some(idx) = entry.last_good_index
            && idx < self.service.upstreams.len()
            && entry.failure_counts[idx] < FAILURE_THRESHOLD
            && !entry.usage_exhausted.get(idx).copied().unwrap_or(false)
            && !avoid.contains(&idx)
        {
            let upstream = self.service.upstreams[idx].clone();
            return Some(SelectedUpstream {
                station_name: self.service.name.clone(),
                index: idx,
                upstream,
            });
        }

        // Pass 2: first upstream that is below the failure threshold and
        // not usage-exhausted.
        if let Some(idx) = self
            .service
            .upstreams
            .iter()
            .enumerate()
            .find_map(|(idx, _)| {
                if avoid.contains(&idx) {
                    return None;
                }
                if entry.failure_counts[idx] >= FAILURE_THRESHOLD {
                    return None;
                }
                if entry.usage_exhausted.get(idx).copied().unwrap_or(false) {
                    return None;
                }
                Some(idx)
            })
        {
            let upstream = self.service.upstreams[idx].clone();
            return Some(SelectedUpstream {
                station_name: self.service.name.clone(),
                index: idx,
                upstream,
            });
        }

        // Pass 3: relax the usage-exhaustion requirement — better to use an
        // exhausted-but-healthy upstream than a failing one.
        if let Some(idx) = self
            .service
            .upstreams
            .iter()
            .enumerate()
            .find_map(|(idx, _)| {
                if avoid.contains(&idx) {
                    return None;
                }
                if entry.failure_counts[idx] >= FAILURE_THRESHOLD {
                    None
                } else {
                    Some(idx)
                }
            })
        {
            let upstream = self.service.upstreams[idx].clone();
            return Some(SelectedUpstream {
                station_name: self.service.name.clone(),
                index: idx,
                upstream,
            });
        }

        // Strict callers would rather get nothing than a known-bad upstream.
        if strict {
            return None;
        }

        // Pass 4 (last resort): any non-avoided index, or 0 if all avoided.
        let idx = (0..self.service.upstreams.len())
            .find(|i| !avoid.contains(i))
            .unwrap_or(0);
        let upstream = self.service.upstreams[idx].clone();
        Some(SelectedUpstream {
            station_name: self.service.name.clone(),
            index: idx,
            upstream,
        })
    }

    /// Immediately forces the upstream at `index` into cooldown (failure
    /// count pinned to the threshold), applying exponential backoff based
    /// on its current penalty streak, and clears the sticky last-good index
    /// if it pointed here.
    pub fn penalize_with_backoff(
        &self,
        index: usize,
        cooldown_secs: u64,
        reason: &str,
        backoff: CooldownBackoff,
    ) {
        // NOTE(review): on a poisoned lock this silently drops the penalty,
        // while selection above recovers via into_inner() — confirm the
        // asymmetry is intentional.
        let mut map = match self.states.lock() {
            Ok(m) => m,
            Err(_) => return,
        };
        let entry = map
            .entry(self.service.name.clone())
            .or_insert_with(LbState::default);
        entry.ensure_layout(self.service.name.as_str(), &self.service.upstreams);
        if index >= entry.failure_counts.len() {
            return;
        }

        let streak = entry.penalty_streak.get(index).copied().unwrap_or(0);
        let effective_secs = backoff.effective_cooldown_secs(cooldown_secs, streak);

        entry.failure_counts[index] = FAILURE_THRESHOLD;
        if let Some(slot) = entry.cooldown_until.get_mut(index) {
            *slot =
                Some(std::time::Instant::now() + std::time::Duration::from_secs(effective_secs));
        }
        if let Some(slot) = entry.penalty_streak.get_mut(index) {
            *slot = streak.saturating_add(1);
        }
        if entry.last_good_index == Some(index) {
            entry.last_good_index = None;
        }
        info!(
            "lb: upstream '{}' index {} penalized for {}s (reason: {})",
            self.service.name, index, effective_secs, reason
        );
    }

    /// Records a request outcome for the upstream at `index`.
    ///
    /// Success clears the failure count, cooldown, and penalty streak, and
    /// makes this index the sticky last-good choice. Failure increments the
    /// count; on reaching `FAILURE_THRESHOLD` the upstream enters a
    /// backed-off cooldown (`COOLDOWN_SECS` when the configured value is 0),
    /// never shortening an already-longer cooldown.
    pub fn record_result_with_backoff(
        &self,
        index: usize,
        success: bool,
        failure_threshold_cooldown_secs: u64,
        backoff: CooldownBackoff,
    ) {
        // NOTE(review): poisoned-lock handling diverges from selection here
        // too (early return instead of into_inner) — confirm intentional.
        let mut map = match self.states.lock() {
            Ok(m) => m,
            Err(_) => return,
        };
        let entry = map
            .entry(self.service.name.clone())
            .or_insert_with(LbState::default);
        entry.ensure_layout(self.service.name.as_str(), &self.service.upstreams);
        if index >= entry.failure_counts.len() {
            return;
        }
        if success {
            // Full recovery: reset counters and make this the sticky choice.
            entry.failure_counts[index] = 0;
            if let Some(slot) = entry.cooldown_until.get_mut(index) {
                *slot = None;
            }
            if let Some(slot) = entry.penalty_streak.get_mut(index) {
                *slot = 0;
            }
            entry.last_good_index = Some(index);
        } else {
            entry.failure_counts[index] = entry.failure_counts[index].saturating_add(1);
            if entry.failure_counts[index] >= FAILURE_THRESHOLD
                && let Some(slot) = entry.cooldown_until.get_mut(index)
            {
                // 0 means "use the built-in default cooldown".
                let base_secs = if failure_threshold_cooldown_secs == 0 {
                    COOLDOWN_SECS
                } else {
                    failure_threshold_cooldown_secs
                };
                let streak = entry.penalty_streak.get(index).copied().unwrap_or(0);
                let effective_secs = backoff.effective_cooldown_secs(base_secs, streak);
                let now = std::time::Instant::now();
                let new_until = now + std::time::Duration::from_secs(effective_secs);
                // Only extend an existing cooldown, never shorten it.
                let should_update = match *slot {
                    Some(existing) => new_until > existing,
                    None => true,
                };
                if should_update {
                    *slot = Some(new_until);
                }
                if let Some(slot) = entry.penalty_streak.get_mut(index) {
                    *slot = streak.saturating_add(1);
                }
                info!(
                    "lb: upstream '{}' index {} reached failure threshold {} (count = {}), entering cooldown for {}s",
                    self.service.name,
                    index,
                    FAILURE_THRESHOLD,
                    entry.failure_counts[index],
                    effective_secs
                );
                if entry.last_good_index == Some(index) {
                    entry.last_good_index = None;
                }
            }
        }
    }
}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ServiceConfig, UpstreamAuth, UpstreamConfig};

    // Builds a service whose upstreams differ only by base URL (no identity tags).
    fn make_service(name: &str, urls: &[&str]) -> ServiceConfig {
        ServiceConfig {
            name: name.to_string(),
            alias: None,
            enabled: true,
            level: 1,
            upstreams: urls
                .iter()
                .map(|u| UpstreamConfig {
                    base_url: u.to_string(),
                    auth: UpstreamAuth {
                        auth_token: Some("sk-test".to_string()),
                        auth_token_env: None,
                        api_key: None,
                        api_key_env: None,
                    },
                    tags: HashMap::new(),
                    supported_models: HashMap::new(),
                    model_mapping: HashMap::new(),
                })
                .collect(),
        }
    }

    // Builds a service whose upstreams carry provider_id/endpoint_id tags,
    // so signature keys are identity-based rather than positional.
    fn make_provider_endpoint_service(
        name: &str,
        upstreams: &[(&str, &str, &str)],
    ) -> ServiceConfig {
        ServiceConfig {
            name: name.to_string(),
            alias: None,
            enabled: true,
            level: 1,
            upstreams: upstreams
                .iter()
                .map(|(base_url, provider_id, endpoint_id)| UpstreamConfig {
                    base_url: (*base_url).to_string(),
                    auth: UpstreamAuth {
                        auth_token: Some("sk-test".to_string()),
                        auth_token_env: None,
                        api_key: None,
                        api_key_env: None,
                    },
                    tags: HashMap::from([
                        ("provider_id".to_string(), (*provider_id).to_string()),
                        ("endpoint_id".to_string(), (*endpoint_id).to_string()),
                    ]),
                    supported_models: HashMap::new(),
                    model_mapping: HashMap::new(),
                })
                .collect(),
        }
    }

    // Exhausting the primary should route traffic to the non-exhausted backup.
    #[test]
    fn lb_prefers_non_exhausted_upstream_when_available() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        let first = lb.select_upstream().expect("should select an upstream");
        assert_eq!(first.index, 0);

        {
            let mut guard = states.lock().unwrap();
            let entry = guard
                .entry("codex-main".to_string())
                .or_insert_with(LbState::default);
            entry.ensure_layout(lb.service.name.as_str(), &lb.service.upstreams);
            entry.usage_exhausted[0] = true;
            entry.usage_exhausted[1] = false;
        }

        let second = lb.select_upstream().expect("should select backup upstream");
        assert_eq!(second.index, 1);
    }

    // With every upstream exhausted, selection still returns one (pass 3).
    #[test]
    fn lb_falls_back_when_all_exhausted() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        let _ = lb.select_upstream();

        {
            let mut guard = states.lock().unwrap();
            let entry = guard
                .entry("codex-main".to_string())
                .or_insert_with(LbState::default);
            entry.ensure_layout(lb.service.name.as_str(), &lb.service.upstreams);
            entry.usage_exhausted[0] = true;
            entry.usage_exhausted[1] = true;
        }

        let selected = lb
            .select_upstream()
            .expect("should still select an upstream");
        assert_eq!(selected.index, 0);
    }

    // Strict mode only blocks the final fallback; usage exhaustion alone
    // (without failure-threshold breaches) must not make it return None.
    #[test]
    fn lb_strict_mode_still_falls_back_when_all_usage_exhausted() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        {
            let mut guard = states.lock().unwrap();
            let entry = guard
                .entry("codex-main".to_string())
                .or_insert_with(LbState::default);
            entry.ensure_layout(lb.service.name.as_str(), &lb.service.upstreams);
            entry.usage_exhausted[0] = true;
            entry.usage_exhausted[1] = true;
        }

        let selected = lb
            .select_upstream_avoiding_strict(&HashSet::new())
            .expect("strict mode should still ignore usage exhaustion on fallback");
        assert_eq!(selected.index, 0);
    }

    // Untagged upstreams get positional signature keys, so reordering them
    // changes the layout and resets the recorded state.
    #[test]
    fn lb_resets_state_when_upstream_layout_changes() {
        let states = Arc::new(Mutex::new(HashMap::new()));
        let initial = LoadBalancer::new(
            Arc::new(make_service(
                "codex-main",
                &["https://primary.example", "https://backup.example"],
            )),
            states.clone(),
        );
        initial.record_result_with_backoff(
            0,
            false,
            COOLDOWN_SECS,
            CooldownBackoff {
                factor: 1,
                max_secs: 0,
            },
        );

        {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("state exists");
            assert_eq!(entry.failure_counts, vec![1, 0]);
        }

        let reordered = LoadBalancer::new(
            Arc::new(make_service(
                "codex-main",
                &["https://backup.example", "https://primary.example"],
            )),
            states.clone(),
        );
        let selected = reordered
            .select_upstream()
            .expect("should select an upstream");
        assert_eq!(selected.index, 0);

        let guard = states.lock().unwrap();
        let entry = guard.get("codex-main").expect("state exists");
        assert_eq!(entry.failure_counts, vec![0, 0]);
        assert_eq!(entry.last_good_index, None);
    }

    // Tagged upstreams keep their identity when reordered: all per-upstream
    // counters follow the provider/endpoint to its new index.
    #[test]
    fn lb_migrates_state_when_provider_endpoint_order_changes() {
        let states = Arc::new(Mutex::new(HashMap::new()));
        let initial = LoadBalancer::new(
            Arc::new(make_provider_endpoint_service(
                "routing",
                &[
                    ("https://primary.example", "primary", "default"),
                    ("https://backup.example", "backup", "default"),
                ],
            )),
            states.clone(),
        );

        {
            let mut guard = states.lock().unwrap();
            let entry = guard
                .entry("routing".to_string())
                .or_insert_with(LbState::default);
            entry.ensure_layout(initial.service.name.as_str(), &initial.service.upstreams);
            entry.failure_counts[0] = 2;
            entry.cooldown_until[0] =
                Some(std::time::Instant::now() + std::time::Duration::from_secs(30));
            entry.penalty_streak[0] = 3;
            entry.usage_exhausted[1] = true;
            entry.last_good_index = Some(1);
        }

        let reordered = LoadBalancer::new(
            Arc::new(make_provider_endpoint_service(
                "routing",
                &[
                    ("https://backup.example", "backup", "default"),
                    ("https://primary.example", "primary", "default"),
                ],
            )),
            states.clone(),
        );
        let selected = reordered
            .select_upstream()
            .expect("should select a migrated non-exhausted upstream");
        assert_eq!(selected.index, 1);

        let guard = states.lock().unwrap();
        let entry = guard.get("routing").expect("state exists");
        assert_eq!(entry.failure_counts, vec![0, 2]);
        assert_eq!(entry.usage_exhausted, vec![true, false]);
        assert_eq!(entry.penalty_streak, vec![0, 3]);
        assert!(entry.cooldown_until[0].is_none());
        assert!(entry.cooldown_until[1].is_some());
        assert_eq!(entry.last_good_index, Some(0));
    }

    // State persisted under the old base-URL-only signature scheme migrates
    // to the new identity-based scheme via the legacy fallback path.
    #[test]
    fn lb_migrates_legacy_base_url_signature_when_endpoint_identity_is_unambiguous() {
        let states = Arc::new(Mutex::new(HashMap::new()));
        let primary_url = "https://primary.example";
        let backup_url = "https://backup.example";

        {
            let mut guard = states.lock().unwrap();
            guard.insert(
                "routing".to_string(),
                LbState {
                    failure_counts: vec![FAILURE_THRESHOLD, 0],
                    cooldown_until: vec![None, None],
                    usage_exhausted: vec![false, true],
                    last_good_index: Some(1),
                    penalty_streak: vec![2, 0],
                    upstream_signature: vec![primary_url.to_string(), backup_url.to_string()],
                },
            );
        }

        let reordered = LoadBalancer::new(
            Arc::new(make_provider_endpoint_service(
                "routing",
                &[
                    (backup_url, "backup", "default"),
                    (primary_url, "primary", "default"),
                ],
            )),
            states.clone(),
        );
        {
            let mut guard = states.lock().unwrap();
            let entry = guard.get_mut("routing").expect("state exists");
            entry.ensure_layout(
                reordered.service.name.as_str(),
                &reordered.service.upstreams,
            );
        }

        let guard = states.lock().unwrap();
        let entry = guard.get("routing").expect("state exists");
        assert_eq!(entry.failure_counts, vec![0, FAILURE_THRESHOLD]);
        assert_eq!(entry.usage_exhausted, vec![true, false]);
        assert_eq!(entry.penalty_streak, vec![0, 2]);
        assert_eq!(entry.last_good_index, Some(0));
    }

    // Changing an endpoint's base URL changes its signature, so old state
    // (cooldowns, exhaustion) must not carry over to the new URL.
    #[test]
    fn lb_replaces_state_when_provider_endpoint_base_url_changes() {
        let states = Arc::new(Mutex::new(HashMap::new()));
        let initial = LoadBalancer::new(
            Arc::new(make_provider_endpoint_service(
                "routing",
                &[("https://old.example", "input", "default")],
            )),
            states.clone(),
        );

        {
            let mut guard = states.lock().unwrap();
            let entry = guard
                .entry("routing".to_string())
                .or_insert_with(LbState::default);
            entry.ensure_layout(initial.service.name.as_str(), &initial.service.upstreams);
            entry.failure_counts[0] = FAILURE_THRESHOLD;
            entry.cooldown_until[0] =
                Some(std::time::Instant::now() + std::time::Duration::from_secs(30));
            entry.usage_exhausted[0] = true;
            entry.penalty_streak[0] = 2;
            entry.last_good_index = Some(0);
        }

        let updated = LoadBalancer::new(
            Arc::new(make_provider_endpoint_service(
                "routing",
                &[("https://new.example", "input", "default")],
            )),
            states.clone(),
        );
        let selected = updated
            .select_upstream()
            .expect("new endpoint URL should be selectable after state replacement");
        assert_eq!(selected.index, 0);

        let guard = states.lock().unwrap();
        let entry = guard.get("routing").expect("state exists");
        assert_eq!(entry.failure_counts, vec![0]);
        assert_eq!(entry.cooldown_until, vec![None]);
        assert_eq!(entry.usage_exhausted, vec![false]);
        assert_eq!(entry.penalty_streak, vec![0]);
        assert_eq!(entry.last_good_index, None);
    }

    // Reaching the failure threshold on the primary diverts selection.
    #[test]
    fn lb_avoids_upstreams_past_failure_threshold() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        let disabled_backoff = CooldownBackoff {
            factor: 1,
            max_secs: 0,
        };

        for _ in 0..FAILURE_THRESHOLD {
            lb.record_result_with_backoff(0, false, COOLDOWN_SECS, disabled_backoff);
        }

        let selected = lb
            .select_upstream()
            .expect("should select backup after failures");
        assert_eq!(selected.index, 1);
    }

    // Once the cooldown instant passes, selection clears the counters and
    // the primary becomes eligible again.
    #[test]
    fn lb_cooldown_expiry_restores_upstream_selection() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        let disabled_backoff = CooldownBackoff {
            factor: 1,
            max_secs: 0,
        };

        for _ in 0..FAILURE_THRESHOLD {
            lb.record_result_with_backoff(0, false, 2, disabled_backoff);
        }

        {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("lb state exists");
            assert_eq!(entry.failure_counts[0], FAILURE_THRESHOLD);
            assert!(entry.cooldown_until[0].is_some());
        }

        let during_cooldown = lb
            .select_upstream()
            .expect("should select backup while primary cools down");
        assert_eq!(during_cooldown.index, 1);

        {
            let mut guard = states.lock().unwrap();
            let entry = guard.get_mut("codex-main").expect("lb state exists");
            entry.cooldown_until[0] =
                Some(std::time::Instant::now() - std::time::Duration::from_secs(1));
        }

        let recovered = lb
            .select_upstream()
            .expect("should select primary after cooldown expiry");
        assert_eq!(recovered.index, 0);

        {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("lb state exists");
            assert_eq!(entry.failure_counts[0], 0);
            assert!(entry.cooldown_until[0].is_none());
        }
    }

    // Successive threshold breaches double the cooldown (factor 2), and a
    // single success resets both the streak and the cooldown.
    #[test]
    fn lb_threshold_cooldown_backoff_grows_and_success_resets_streak() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        let backoff = CooldownBackoff {
            factor: 2,
            max_secs: 10,
        };

        for _ in 0..FAILURE_THRESHOLD {
            lb.record_result_with_backoff(0, false, 2, backoff);
        }

        let first_remaining_secs = {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("lb state exists");
            assert_eq!(entry.penalty_streak[0], 1);
            entry.cooldown_until[0]
                .map(|until| {
                    until
                        .saturating_duration_since(std::time::Instant::now())
                        .as_secs()
                })
                .expect("first cooldown exists")
        };
        assert!(first_remaining_secs <= 2);

        {
            let mut guard = states.lock().unwrap();
            let entry = guard.get_mut("codex-main").expect("lb state exists");
            entry.cooldown_until[0] =
                Some(std::time::Instant::now() - std::time::Duration::from_secs(1));
        }
        let _ = lb.select_upstream();

        for _ in 0..FAILURE_THRESHOLD {
            lb.record_result_with_backoff(0, false, 2, backoff);
        }

        let second_remaining_secs = {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("lb state exists");
            assert_eq!(entry.penalty_streak[0], 2);
            entry.cooldown_until[0]
                .map(|until| {
                    until
                        .saturating_duration_since(std::time::Instant::now())
                        .as_secs()
                })
                .expect("second cooldown exists")
        };
        assert!(second_remaining_secs <= 4);
        assert!(second_remaining_secs >= first_remaining_secs);

        lb.record_result_with_backoff(0, true, 2, backoff);

        {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("lb state exists");
            assert_eq!(entry.failure_counts[0], 0);
            assert!(entry.cooldown_until[0].is_none());
            assert_eq!(entry.penalty_streak[0], 0);
            assert_eq!(entry.last_good_index, Some(0));
        }
    }
}