//! lb.rs — load-balancer state tracking and upstream selection (codex_helper_core).

1use std::collections::HashMap;
2use std::collections::HashSet;
3use std::sync::{Arc, Mutex};
4
5use crate::config::{ServiceConfig, UpstreamConfig};
6use crate::runtime_identity::ProviderEndpointKey;
7use tracing::info;
8
/// Consecutive failures after which an upstream is considered tripped (circuit open).
pub const FAILURE_THRESHOLD: u32 = 3;
/// Default cooldown window (seconds) applied when the failure threshold is reached
/// and the caller passes a zero cooldown.
pub const COOLDOWN_SECS: u64 = 30;
11
/// Exponential growth policy for cooldown durations.
///
/// `factor` multiplies the cooldown once per entry in a penalty streak;
/// `max_secs` bounds the growth (a value of 0 means "no configured max",
/// which pins the result to the base cooldown).
#[derive(Debug, Clone, Copy)]
pub struct CooldownBackoff {
    pub factor: u64,
    pub max_secs: u64,
}

impl CooldownBackoff {
    /// Computes the cooldown for `base_secs` after `penalty_streak`
    /// consecutive penalties.
    ///
    /// Returns 0 for a zero base, the plain base when growth is disabled
    /// (`factor <= 1`), and otherwise `base * factor^streak` with saturating
    /// arithmetic, clamped to the cap. The streak is bounded at 64 rounds,
    /// which is more than enough to saturate any `u64`.
    pub(crate) fn effective_cooldown_secs(&self, base_secs: u64, penalty_streak: u32) -> u64 {
        if base_secs == 0 {
            return 0;
        }
        if self.factor <= 1 {
            return base_secs;
        }
        // max_secs == 0 means "unset": clamp to the base itself. Otherwise
        // the cap is at least the base so a large base is never shrunk.
        let cap = match self.max_secs {
            0 => base_secs,
            max => max.max(base_secs),
        };

        let rounds = penalty_streak.min(64);
        let mut current = base_secs;
        let mut applied = 0u32;
        while applied < rounds && current < cap {
            current = current.saturating_mul(self.factor);
            applied += 1;
        }
        current.min(cap)
    }
}
42
/// Per-service load-balancer bookkeeping.
///
/// All vectors are parallel and indexed by the upstream's position in the
/// service's upstream list; `ensure_layout` keeps them in sync with config.
#[derive(Debug, Clone, Default)]
pub struct LbState {
    /// Consecutive failure count per upstream; reset on success or cooldown expiry.
    pub failure_counts: Vec<u32>,
    /// When `Some`, the upstream is in cooldown until the given instant.
    pub cooldown_until: Vec<Option<std::time::Instant>>,
    /// Upstreams flagged (by callers) as having exhausted their usage/quota.
    pub usage_exhausted: Vec<bool>,
    /// Sticky index of the most recent upstream that served a success.
    pub last_good_index: Option<usize>,
    /// Consecutive penalty count per upstream, used to grow cooldowns via backoff.
    pub penalty_streak: Vec<u32>,
    /// Identity keys of the upstream layout these vectors were built for.
    pub(crate) upstream_signature: Vec<String>,
}
52
impl LbState {
    /// Rebuilds or migrates the per-upstream vectors so they match the
    /// current upstream list of `service_name`.
    ///
    /// A signature is computed per upstream (provider/endpoint identity plus
    /// base URL). If signatures are ambiguous (duplicates), state is reset
    /// wholesale; if the layout and vector lengths already match, this is a
    /// no-op; otherwise state is migrated entry-by-entry.
    pub(crate) fn ensure_layout(&mut self, service_name: &str, upstreams: &[UpstreamConfig]) {
        let signature = upstreams
            .iter()
            .enumerate()
            .map(|(idx, upstream)| upstream_signature_key(service_name, idx, upstream))
            .collect::<Vec<_>>();
        // Earlier state keyed entries by base URL only (see the legacy
        // migration test below); keep that form as a migration fallback.
        let legacy_signature = upstreams
            .iter()
            .map(|upstream| upstream.base_url.clone())
            .collect::<Vec<_>>();

        if has_duplicate_signatures(&signature) {
            self.reset_for_layout(signature);
            return;
        }

        let len = upstreams.len();
        if self.upstream_signature == signature
            && self.failure_counts.len() == len
            && self.cooldown_until.len() == len
            && self.usage_exhausted.len() == len
            && self.penalty_streak.len() == len
        {
            return;
        }

        self.migrate_layout(signature, legacy_signature);
    }

    /// Drops all tracked state and re-initializes the vectors for `signature`.
    fn reset_for_layout(&mut self, signature: Vec<String>) {
        let len = signature.len();
        self.failure_counts = vec![0; len];
        self.cooldown_until = vec![None; len];
        self.usage_exhausted = vec![false; len];
        self.penalty_streak = vec![0; len];
        // When the upstream layout changes, the previous sticky index is no
        // longer trustworthy, so clear it outright.
        self.last_good_index = None;
        self.upstream_signature = signature;
    }

    /// Carries per-upstream state over from the previous layout to the new
    /// one, matching entries by signature and falling back to the legacy
    /// base-URL signature when that is unambiguous. Entries with no match
    /// start fresh (zeros / `None`).
    fn migrate_layout(&mut self, signature: Vec<String>, legacy_signature: Vec<String>) {
        if self.upstream_signature.is_empty() {
            self.reset_for_layout(signature);
            return;
        }

        let old_signature = std::mem::take(&mut self.upstream_signature);
        if has_duplicate_signatures(&old_signature) {
            self.reset_for_layout(signature);
            return;
        }

        let old_index_by_signature = old_signature
            .iter()
            .enumerate()
            .map(|(idx, key)| (key.clone(), idx))
            .collect::<std::collections::HashMap<_, _>>();
        // Base-URL fallback is only safe when every base URL is unique.
        let legacy_fallback_enabled = !has_duplicate_signatures(&legacy_signature);

        // Take the old vectors out so fresh ones can be built in place.
        let old_failure_counts = std::mem::take(&mut self.failure_counts);
        let old_cooldown_until = std::mem::take(&mut self.cooldown_until);
        let old_usage_exhausted = std::mem::take(&mut self.usage_exhausted);
        let old_penalty_streak = std::mem::take(&mut self.penalty_streak);
        let old_last_good_index = self.last_good_index.take();

        let len = signature.len();
        self.failure_counts = vec![0; len];
        self.cooldown_until = vec![None; len];
        self.usage_exhausted = vec![false; len];
        self.penalty_streak = vec![0; len];

        for (new_idx, key) in signature.iter().enumerate() {
            // Primary match: identity signature. Fallback: same position's
            // legacy base-URL key, only when legacy keys are unambiguous.
            let old_idx = old_index_by_signature.get(key).copied().or_else(|| {
                legacy_fallback_enabled
                    .then(|| legacy_signature.get(new_idx))
                    .flatten()
                    .and_then(|legacy_key| old_index_by_signature.get(legacy_key).copied())
            });
            let Some(old_idx) = old_idx else {
                continue;
            };
            self.failure_counts[new_idx] = old_failure_counts.get(old_idx).copied().unwrap_or(0);
            self.cooldown_until[new_idx] = old_cooldown_until.get(old_idx).and_then(|until| *until);
            self.usage_exhausted[new_idx] =
                old_usage_exhausted.get(old_idx).copied().unwrap_or(false);
            self.penalty_streak[new_idx] = old_penalty_streak.get(old_idx).copied().unwrap_or(0);
        }

        // Remap the sticky index through the same signature/legacy mapping;
        // it is dropped if the old entry no longer exists.
        self.last_good_index = old_last_good_index.and_then(|old_idx| {
            old_signature.get(old_idx).and_then(|key| {
                signature
                    .iter()
                    .position(|new_key| new_key == key)
                    .or_else(|| {
                        legacy_fallback_enabled
                            .then(|| {
                                legacy_signature
                                    .iter()
                                    .position(|legacy_key| legacy_key == key)
                            })
                            .flatten()
                    })
            })
        });
        self.upstream_signature = signature;
    }
}
161
/// Returns `true` when `values` contains at least one repeated signature.
fn has_duplicate_signatures(values: &[String]) -> bool {
    let mut seen = HashSet::with_capacity(values.len());
    for value in values {
        // `insert` reports false for an already-present key — a duplicate.
        if !seen.insert(value) {
            return true;
        }
    }
    false
}
166
167fn upstream_signature_key(
168    service_name: &str,
169    upstream_index: usize,
170    upstream: &UpstreamConfig,
171) -> String {
172    let provider_id = upstream
173        .tags
174        .get("provider_id")
175        .cloned()
176        .unwrap_or_else(|| format!("{service_name}#{upstream_index}"));
177    let endpoint_id = upstream
178        .tags
179        .get("endpoint_id")
180        .cloned()
181        .unwrap_or_else(|| upstream_index.to_string());
182    let provider_endpoint = ProviderEndpointKey::new(service_name, provider_id, endpoint_id);
183    format!("{}|{}", provider_endpoint.stable_key(), upstream.base_url)
184}
185
/// Upstream selection result: the owning service ("station"), the index of
/// the chosen upstream within that service's upstream list, and a clone of
/// its configuration.
#[derive(Debug, Clone)]
pub struct SelectedUpstream {
    pub station_name: String,
    pub index: usize,
    pub upstream: UpstreamConfig,
}
193
/// Upstream selector for a single service.
///
/// Selection walks upstreams in declaration order with sticky routing,
/// failure counting, and cooldowns (see `select_upstream_avoiding_inner`);
/// intended to grow richer usage-/failure-aware strategies over time.
#[derive(Clone)]
pub struct LoadBalancer {
    /// The service whose upstream list is being selected from.
    pub service: Arc<ServiceConfig>,
    /// Cross-request state, keyed by service name and shared between balancers.
    pub states: Arc<Mutex<HashMap<String, LbState>>>,
}
200
impl LoadBalancer {
    /// Creates a selector over `service`, sharing the cross-request state map.
    pub fn new(service: Arc<ServiceConfig>, states: Arc<Mutex<HashMap<String, LbState>>>) -> Self {
        Self { service, states }
    }

    /// Test-only convenience: select with an empty avoid set.
    #[cfg(test)]
    pub fn select_upstream(&self) -> Option<SelectedUpstream> {
        self.select_upstream_avoiding(&HashSet::new())
    }

    /// Selects an upstream while skipping the indices in `avoid` when
    /// possible; may still fall back to an avoided/tripped upstream so that
    /// callers always get something.
    pub fn select_upstream_avoiding(&self, avoid: &HashSet<usize>) -> Option<SelectedUpstream> {
        self.select_upstream_avoiding_inner(avoid, false)
    }

    /// Like `select_upstream_avoiding`, but returns `None` instead of falling
    /// back when every candidate is avoided or past the failure threshold.
    pub fn select_upstream_avoiding_strict(
        &self,
        avoid: &HashSet<usize>,
    ) -> Option<SelectedUpstream> {
        self.select_upstream_avoiding_inner(avoid, true)
    }

    /// Core selection: refresh cooldowns, then try (1) the sticky last-good
    /// upstream, (2) the first healthy non-exhausted upstream in order,
    /// (3) the first non-tripped upstream ignoring usage exhaustion, and
    /// finally (unless `strict`) an unconditional fallback.
    fn select_upstream_avoiding_inner(
        &self,
        avoid: &HashSet<usize>,
        strict: bool,
    ) -> Option<SelectedUpstream> {
        if self.service.upstreams.is_empty() {
            return None;
        }

        // Recover the state map even if another thread panicked while
        // holding the lock (poisoned mutex).
        let mut map = match self.states.lock() {
            Ok(m) => m,
            Err(e) => e.into_inner(),
        };
        let entry = map.entry(self.service.name.clone()).or_default();
        entry.ensure_layout(self.service.name.as_str(), &self.service.upstreams);

        let now = std::time::Instant::now();

        // Refresh cooldown state: once a cooldown has elapsed, reset the
        // failure count and clear the cooldown slot.
        for idx in 0..self.service.upstreams.len() {
            if let Some(until) = entry.cooldown_until.get(idx).and_then(|v| *v)
                && now >= until
            {
                entry.failure_counts[idx] = 0;
                if let Some(slot) = entry.cooldown_until.get_mut(idx) {
                    *slot = None;
                }
            }
        }

        // Sticky routing: prefer the upstream that most recently succeeded,
        // so once traffic has moved to a working line it stays there instead
        // of re-walking the failover order on every request.
        if let Some(idx) = entry.last_good_index
            && idx < self.service.upstreams.len()
            && entry.failure_counts[idx] < FAILURE_THRESHOLD
            && !entry.usage_exhausted.get(idx).copied().unwrap_or(false)
            && !avoid.contains(&idx)
        {
            let upstream = self.service.upstreams[idx].clone();
            return Some(SelectedUpstream {
                station_name: self.service.name.clone(),
                index: idx,
                upstream,
            });
        }

        // Pass 1: in declaration order, pick the first upstream that is
        // neither tripped (failure threshold) nor marked usage-exhausted.
        if let Some(idx) = self
            .service
            .upstreams
            .iter()
            .enumerate()
            .find_map(|(idx, _)| {
                if avoid.contains(&idx) {
                    return None;
                }
                if entry.failure_counts[idx] >= FAILURE_THRESHOLD {
                    return None;
                }
                if entry.usage_exhausted.get(idx).copied().unwrap_or(false) {
                    return None;
                }
                Some(idx)
            })
        {
            let upstream = self.service.upstreams[idx].clone();
            return Some(SelectedUpstream {
                station_name: self.service.name.clone(),
                index: idx,
                upstream,
            });
        }

        // Pass 2: ignore usage_exhausted and only honor the failure
        // threshold; still the first match in declaration order.
        if let Some(idx) = self
            .service
            .upstreams
            .iter()
            .enumerate()
            .find_map(|(idx, _)| {
                if avoid.contains(&idx) {
                    return None;
                }
                if entry.failure_counts[idx] >= FAILURE_THRESHOLD {
                    None
                } else {
                    Some(idx)
                }
            })
        {
            let upstream = self.service.upstreams[idx].clone();
            return Some(SelectedUpstream {
                station_name: self.service.name.clone(),
                index: idx,
                upstream,
            });
        }

        if strict {
            return None;
        }

        // Last resort: every upstream has hit the failure threshold, but we
        // still return one so there is always a candidate. Prefer the first
        // non-avoided index; if `avoid` excludes everything, return index 0.
        let idx = (0..self.service.upstreams.len())
            .find(|i| !avoid.contains(i))
            .unwrap_or(0);
        let upstream = self.service.upstreams[idx].clone();
        Some(SelectedUpstream {
            station_name: self.service.name.clone(),
            index: idx,
            upstream,
        })
    }

    /// Immediately trips upstream `index`: sets its failure count to the
    /// threshold, starts a cooldown grown by `backoff` from the current
    /// penalty streak, bumps the streak, and clears the sticky index if it
    /// pointed here. `reason` is only used for logging.
    ///
    /// NOTE(review): unlike selection, a poisoned lock here silently drops
    /// the update (`Err(_) => return`) — confirm this asymmetry is intended.
    pub fn penalize_with_backoff(
        &self,
        index: usize,
        cooldown_secs: u64,
        reason: &str,
        backoff: CooldownBackoff,
    ) {
        let mut map = match self.states.lock() {
            Ok(m) => m,
            Err(_) => return,
        };
        let entry = map
            .entry(self.service.name.clone())
            .or_insert_with(LbState::default);
        entry.ensure_layout(self.service.name.as_str(), &self.service.upstreams);
        if index >= entry.failure_counts.len() {
            return;
        }

        let streak = entry.penalty_streak.get(index).copied().unwrap_or(0);
        let effective_secs = backoff.effective_cooldown_secs(cooldown_secs, streak);

        entry.failure_counts[index] = FAILURE_THRESHOLD;
        if let Some(slot) = entry.cooldown_until.get_mut(index) {
            *slot =
                Some(std::time::Instant::now() + std::time::Duration::from_secs(effective_secs));
        }
        if let Some(slot) = entry.penalty_streak.get_mut(index) {
            *slot = streak.saturating_add(1);
        }
        if entry.last_good_index == Some(index) {
            entry.last_good_index = None;
        }
        info!(
            "lb: upstream '{}' index {} penalized for {}s (reason: {})",
            self.service.name, index, effective_secs, reason
        );
    }

    /// Records a request outcome for upstream `index`.
    ///
    /// Success clears the failure count, cooldown, and penalty streak, and
    /// makes this upstream the sticky choice. Failure increments the count;
    /// on reaching `FAILURE_THRESHOLD` a backoff-grown cooldown is started
    /// (never shortening an existing one) and the streak is bumped. A zero
    /// `failure_threshold_cooldown_secs` falls back to `COOLDOWN_SECS`.
    pub fn record_result_with_backoff(
        &self,
        index: usize,
        success: bool,
        failure_threshold_cooldown_secs: u64,
        backoff: CooldownBackoff,
    ) {
        let mut map = match self.states.lock() {
            Ok(m) => m,
            Err(_) => return,
        };
        let entry = map
            .entry(self.service.name.clone())
            .or_insert_with(LbState::default);
        entry.ensure_layout(self.service.name.as_str(), &self.service.upstreams);
        if index >= entry.failure_counts.len() {
            return;
        }
        if success {
            entry.failure_counts[index] = 0;
            if let Some(slot) = entry.cooldown_until.get_mut(index) {
                *slot = None;
            }
            if let Some(slot) = entry.penalty_streak.get_mut(index) {
                *slot = 0;
            }
            // A successful request marks this upstream as the most recent
            // good line; subsequent selections prefer it (sticky routing).
            entry.last_good_index = Some(index);
        } else {
            entry.failure_counts[index] = entry.failure_counts[index].saturating_add(1);
            if entry.failure_counts[index] >= FAILURE_THRESHOLD
                && let Some(slot) = entry.cooldown_until.get_mut(index)
            {
                let base_secs = if failure_threshold_cooldown_secs == 0 {
                    COOLDOWN_SECS
                } else {
                    failure_threshold_cooldown_secs
                };
                let streak = entry.penalty_streak.get(index).copied().unwrap_or(0);
                let effective_secs = backoff.effective_cooldown_secs(base_secs, streak);
                let now = std::time::Instant::now();
                let new_until = now + std::time::Duration::from_secs(effective_secs);
                // Only ever extend an existing cooldown, never shorten it.
                let should_update = match *slot {
                    Some(existing) => new_until > existing,
                    None => true,
                };
                if should_update {
                    *slot = Some(new_until);
                }
                if let Some(slot) = entry.penalty_streak.get_mut(index) {
                    *slot = streak.saturating_add(1);
                }
                info!(
                    "lb: upstream '{}' index {} reached failure threshold {} (count = {}), entering cooldown for {}s",
                    self.service.name,
                    index,
                    FAILURE_THRESHOLD,
                    entry.failure_counts[index],
                    effective_secs
                );
                // When the circuit trips, drop the sticky index if it points
                // at this line so later selections can move elsewhere.
                if entry.last_good_index == Some(index) {
                    entry.last_good_index = None;
                }
            }
        }
    }
}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ServiceConfig, UpstreamAuth, UpstreamConfig};

    /// Builds a service whose upstreams carry only base URLs (no identity tags).
    fn make_service(name: &str, urls: &[&str]) -> ServiceConfig {
        ServiceConfig {
            name: name.to_string(),
            alias: None,
            enabled: true,
            level: 1,
            upstreams: urls
                .iter()
                .map(|u| UpstreamConfig {
                    base_url: u.to_string(),
                    auth: UpstreamAuth {
                        auth_token: Some("sk-test".to_string()),
                        auth_token_env: None,
                        api_key: None,
                        api_key_env: None,
                    },
                    tags: HashMap::new(),
                    supported_models: HashMap::new(),
                    model_mapping: HashMap::new(),
                })
                .collect(),
        }
    }

    /// Builds a service whose upstreams carry `provider_id` / `endpoint_id`
    /// tags, from (base_url, provider_id, endpoint_id) triples.
    fn make_provider_endpoint_service(
        name: &str,
        upstreams: &[(&str, &str, &str)],
    ) -> ServiceConfig {
        ServiceConfig {
            name: name.to_string(),
            alias: None,
            enabled: true,
            level: 1,
            upstreams: upstreams
                .iter()
                .map(|(base_url, provider_id, endpoint_id)| UpstreamConfig {
                    base_url: (*base_url).to_string(),
                    auth: UpstreamAuth {
                        auth_token: Some("sk-test".to_string()),
                        auth_token_env: None,
                        api_key: None,
                        api_key_env: None,
                    },
                    tags: HashMap::from([
                        ("provider_id".to_string(), (*provider_id).to_string()),
                        ("endpoint_id".to_string(), (*endpoint_id).to_string()),
                    ]),
                    supported_models: HashMap::new(),
                    model_mapping: HashMap::new(),
                })
                .collect(),
        }
    }

    #[test]
    fn lb_prefers_non_exhausted_upstream_when_available() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        // The first selection should pick the first upstream (index 0).
        let first = lb.select_upstream().expect("should select an upstream");
        assert_eq!(first.index, 0);

        // Mark index 0 as usage_exhausted and index 1 as available.
        {
            let mut guard = states.lock().unwrap();
            let entry = guard
                .entry("codex-main".to_string())
                .or_insert_with(LbState::default);
            entry.ensure_layout(lb.service.name.as_str(), &lb.service.upstreams);
            entry.usage_exhausted[0] = true;
            entry.usage_exhausted[1] = false;
        }

        // The non-exhausted index 1 should now be preferred.
        let second = lb.select_upstream().expect("should select backup upstream");
        assert_eq!(second.index, 1);
    }

    #[test]
    fn lb_falls_back_when_all_exhausted() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        // Initialize the state.
        let _ = lb.select_upstream();

        {
            let mut guard = states.lock().unwrap();
            let entry = guard
                .entry("codex-main".to_string())
                .or_insert_with(LbState::default);
            entry.ensure_layout(lb.service.name.as_str(), &lb.service.upstreams);
            entry.usage_exhausted[0] = true;
            entry.usage_exhausted[1] = true;
        }

        // With every upstream exhausted, index 0 must still be returned as
        // the fallback.
        let selected = lb
            .select_upstream()
            .expect("should still select an upstream");
        assert_eq!(selected.index, 0);
    }

    #[test]
    fn lb_strict_mode_still_falls_back_when_all_usage_exhausted() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        {
            let mut guard = states.lock().unwrap();
            let entry = guard
                .entry("codex-main".to_string())
                .or_insert_with(LbState::default);
            entry.ensure_layout(lb.service.name.as_str(), &lb.service.upstreams);
            entry.usage_exhausted[0] = true;
            entry.usage_exhausted[1] = true;
        }

        // Strict mode only refuses threshold-tripped upstreams; pure usage
        // exhaustion is ignored in the second selection pass.
        let selected = lb
            .select_upstream_avoiding_strict(&HashSet::new())
            .expect("strict mode should still ignore usage exhaustion on fallback");
        assert_eq!(selected.index, 0);
    }

    #[test]
    fn lb_resets_state_when_upstream_layout_changes() {
        let states = Arc::new(Mutex::new(HashMap::new()));
        let initial = LoadBalancer::new(
            Arc::new(make_service(
                "codex-main",
                &["https://primary.example", "https://backup.example"],
            )),
            states.clone(),
        );
        initial.record_result_with_backoff(
            0,
            false,
            COOLDOWN_SECS,
            CooldownBackoff {
                factor: 1,
                max_secs: 0,
            },
        );

        {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("state exists");
            assert_eq!(entry.failure_counts, vec![1, 0]);
        }

        // Untagged upstreams are keyed by index + URL, so reordering them is
        // a layout change and the state is rebuilt from scratch.
        let reordered = LoadBalancer::new(
            Arc::new(make_service(
                "codex-main",
                &["https://backup.example", "https://primary.example"],
            )),
            states.clone(),
        );
        let selected = reordered
            .select_upstream()
            .expect("should select an upstream");
        assert_eq!(selected.index, 0);

        let guard = states.lock().unwrap();
        let entry = guard.get("codex-main").expect("state exists");
        assert_eq!(entry.failure_counts, vec![0, 0]);
        assert_eq!(entry.last_good_index, None);
    }

    #[test]
    fn lb_migrates_state_when_provider_endpoint_order_changes() {
        let states = Arc::new(Mutex::new(HashMap::new()));
        let initial = LoadBalancer::new(
            Arc::new(make_provider_endpoint_service(
                "routing",
                &[
                    ("https://primary.example", "primary", "default"),
                    ("https://backup.example", "backup", "default"),
                ],
            )),
            states.clone(),
        );

        {
            let mut guard = states.lock().unwrap();
            let entry = guard
                .entry("routing".to_string())
                .or_insert_with(LbState::default);
            entry.ensure_layout(initial.service.name.as_str(), &initial.service.upstreams);
            entry.failure_counts[0] = 2;
            entry.cooldown_until[0] =
                Some(std::time::Instant::now() + std::time::Duration::from_secs(30));
            entry.penalty_streak[0] = 3;
            entry.usage_exhausted[1] = true;
            entry.last_good_index = Some(1);
        }

        // Tagged upstreams keep their identity across reordering, so all
        // per-upstream state should follow the entries to their new indices.
        let reordered = LoadBalancer::new(
            Arc::new(make_provider_endpoint_service(
                "routing",
                &[
                    ("https://backup.example", "backup", "default"),
                    ("https://primary.example", "primary", "default"),
                ],
            )),
            states.clone(),
        );
        let selected = reordered
            .select_upstream()
            .expect("should select a migrated non-exhausted upstream");
        assert_eq!(selected.index, 1);

        let guard = states.lock().unwrap();
        let entry = guard.get("routing").expect("state exists");
        assert_eq!(entry.failure_counts, vec![0, 2]);
        assert_eq!(entry.usage_exhausted, vec![true, false]);
        assert_eq!(entry.penalty_streak, vec![0, 3]);
        assert!(entry.cooldown_until[0].is_none());
        assert!(entry.cooldown_until[1].is_some());
        assert_eq!(entry.last_good_index, Some(0));
    }

    #[test]
    fn lb_migrates_legacy_base_url_signature_when_endpoint_identity_is_unambiguous() {
        let states = Arc::new(Mutex::new(HashMap::new()));
        let primary_url = "https://primary.example";
        let backup_url = "https://backup.example";

        // Seed state whose signatures are plain base URLs (legacy format).
        {
            let mut guard = states.lock().unwrap();
            guard.insert(
                "routing".to_string(),
                LbState {
                    failure_counts: vec![FAILURE_THRESHOLD, 0],
                    cooldown_until: vec![None, None],
                    usage_exhausted: vec![false, true],
                    last_good_index: Some(1),
                    penalty_streak: vec![2, 0],
                    upstream_signature: vec![primary_url.to_string(), backup_url.to_string()],
                },
            );
        }

        let reordered = LoadBalancer::new(
            Arc::new(make_provider_endpoint_service(
                "routing",
                &[
                    (backup_url, "backup", "default"),
                    (primary_url, "primary", "default"),
                ],
            )),
            states.clone(),
        );
        {
            let mut guard = states.lock().unwrap();
            let entry = guard.get_mut("routing").expect("state exists");
            entry.ensure_layout(
                reordered.service.name.as_str(),
                &reordered.service.upstreams,
            );
        }

        let guard = states.lock().unwrap();
        let entry = guard.get("routing").expect("state exists");
        assert_eq!(entry.failure_counts, vec![0, FAILURE_THRESHOLD]);
        assert_eq!(entry.usage_exhausted, vec![true, false]);
        assert_eq!(entry.penalty_streak, vec![0, 2]);
        assert_eq!(entry.last_good_index, Some(0));
    }

    #[test]
    fn lb_replaces_state_when_provider_endpoint_base_url_changes() {
        let states = Arc::new(Mutex::new(HashMap::new()));
        let initial = LoadBalancer::new(
            Arc::new(make_provider_endpoint_service(
                "routing",
                &[("https://old.example", "input", "default")],
            )),
            states.clone(),
        );

        {
            let mut guard = states.lock().unwrap();
            let entry = guard
                .entry("routing".to_string())
                .or_insert_with(LbState::default);
            entry.ensure_layout(initial.service.name.as_str(), &initial.service.upstreams);
            entry.failure_counts[0] = FAILURE_THRESHOLD;
            entry.cooldown_until[0] =
                Some(std::time::Instant::now() + std::time::Duration::from_secs(30));
            entry.usage_exhausted[0] = true;
            entry.penalty_streak[0] = 2;
            entry.last_good_index = Some(0);
        }

        // The base URL is part of the signature, so a URL change is a new
        // identity and the old (tripped) state must not carry over.
        let updated = LoadBalancer::new(
            Arc::new(make_provider_endpoint_service(
                "routing",
                &[("https://new.example", "input", "default")],
            )),
            states.clone(),
        );
        let selected = updated
            .select_upstream()
            .expect("new endpoint URL should be selectable after state replacement");
        assert_eq!(selected.index, 0);

        let guard = states.lock().unwrap();
        let entry = guard.get("routing").expect("state exists");
        assert_eq!(entry.failure_counts, vec![0]);
        assert_eq!(entry.cooldown_until, vec![None]);
        assert_eq!(entry.usage_exhausted, vec![false]);
        assert_eq!(entry.penalty_streak, vec![0]);
        assert_eq!(entry.last_good_index, None);
    }

    #[test]
    fn lb_avoids_upstreams_past_failure_threshold() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        let disabled_backoff = CooldownBackoff {
            factor: 1,
            max_secs: 0,
        };

        // Record FAILURE_THRESHOLD consecutive failures against primary.
        for _ in 0..FAILURE_THRESHOLD {
            lb.record_result_with_backoff(0, false, COOLDOWN_SECS, disabled_backoff);
        }

        // Backup (index 1) should be chosen because index 0 hit the failure
        // threshold.
        let selected = lb
            .select_upstream()
            .expect("should select backup after failures");
        assert_eq!(selected.index, 1);
    }

    #[test]
    fn lb_cooldown_expiry_restores_upstream_selection() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        let disabled_backoff = CooldownBackoff {
            factor: 1,
            max_secs: 0,
        };

        for _ in 0..FAILURE_THRESHOLD {
            lb.record_result_with_backoff(0, false, 2, disabled_backoff);
        }

        {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("lb state exists");
            assert_eq!(entry.failure_counts[0], FAILURE_THRESHOLD);
            assert!(entry.cooldown_until[0].is_some());
        }

        let during_cooldown = lb
            .select_upstream()
            .expect("should select backup while primary cools down");
        assert_eq!(during_cooldown.index, 1);

        // Force the cooldown into the past instead of sleeping.
        {
            let mut guard = states.lock().unwrap();
            let entry = guard.get_mut("codex-main").expect("lb state exists");
            entry.cooldown_until[0] =
                Some(std::time::Instant::now() - std::time::Duration::from_secs(1));
        }

        let recovered = lb
            .select_upstream()
            .expect("should select primary after cooldown expiry");
        assert_eq!(recovered.index, 0);

        {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("lb state exists");
            assert_eq!(entry.failure_counts[0], 0);
            assert!(entry.cooldown_until[0].is_none());
        }
    }

    #[test]
    fn lb_threshold_cooldown_backoff_grows_and_success_resets_streak() {
        let service = make_service(
            "codex-main",
            &["https://primary.example", "https://backup.example"],
        );
        let states = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), states.clone());

        let backoff = CooldownBackoff {
            factor: 2,
            max_secs: 10,
        };

        for _ in 0..FAILURE_THRESHOLD {
            lb.record_result_with_backoff(0, false, 2, backoff);
        }

        // First trip: streak becomes 1 and the cooldown is the 2s base.
        let first_remaining_secs = {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("lb state exists");
            assert_eq!(entry.penalty_streak[0], 1);
            entry.cooldown_until[0]
                .map(|until| {
                    until
                        .saturating_duration_since(std::time::Instant::now())
                        .as_secs()
                })
                .expect("first cooldown exists")
        };
        assert!(first_remaining_secs <= 2);

        // Expire the cooldown and let a selection pass clear the counters.
        {
            let mut guard = states.lock().unwrap();
            let entry = guard.get_mut("codex-main").expect("lb state exists");
            entry.cooldown_until[0] =
                Some(std::time::Instant::now() - std::time::Duration::from_secs(1));
        }
        let _ = lb.select_upstream();

        for _ in 0..FAILURE_THRESHOLD {
            lb.record_result_with_backoff(0, false, 2, backoff);
        }

        // Second trip: streak becomes 2 and the cooldown grows (2s * factor).
        let second_remaining_secs = {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("lb state exists");
            assert_eq!(entry.penalty_streak[0], 2);
            entry.cooldown_until[0]
                .map(|until| {
                    until
                        .saturating_duration_since(std::time::Instant::now())
                        .as_secs()
                })
                .expect("second cooldown exists")
        };
        assert!(second_remaining_secs <= 4);
        assert!(second_remaining_secs >= first_remaining_secs);

        lb.record_result_with_backoff(0, true, 2, backoff);

        // A success clears everything and restores sticky routing.
        {
            let guard = states.lock().unwrap();
            let entry = guard.get("codex-main").expect("lb state exists");
            assert_eq!(entry.failure_counts[0], 0);
            assert!(entry.cooldown_until[0].is_none());
            assert_eq!(entry.penalty_streak[0], 0);
            assert_eq!(entry.last_good_index, Some(0));
        }
    }
}