Skip to main content

codex_helper_core/
lb.rs

1use std::collections::HashMap;
2use std::collections::HashSet;
3use std::sync::{Arc, Mutex};
4
5use crate::config::{ServiceConfig, UpstreamConfig};
6use tracing::info;
7
/// Failure count at which an upstream is treated as tripped: selection skips it and
/// `record_result_with_backoff` schedules a cooldown. The count restarts from zero on a
/// success or once the cooldown has expired.
pub const FAILURE_THRESHOLD: u32 = 3;
/// Default cooldown in seconds, used by `record_result_with_backoff` when the caller passes
/// `0` as the cooldown length.
pub const COOLDOWN_SECS: u64 = 30;
10
/// Exponential backoff settings for upstream cooldowns.
///
/// Every penalty in a streak multiplies the base cooldown by `factor`; the grown value is
/// clamped by `max_secs`.
#[derive(Debug, Clone, Copy)]
pub struct CooldownBackoff {
    /// Multiplier applied once per penalty in the current streak. Values `<= 1` turn growth off.
    pub factor: u64,
    /// Ceiling for the grown cooldown, in seconds. `0` pins the cooldown to the base value.
    pub max_secs: u64,
}

impl CooldownBackoff {
    /// Computes the cooldown, in seconds, for an upstream that has already been penalized
    /// `penalty_streak` times.
    ///
    /// The result is `base_secs * factor^min(penalty_streak, 64)`, clamped to
    /// `max(max_secs, base_secs)`. A `base_secs` of `0` always yields `0`, a `factor` of `0`
    /// or `1` always yields `base_secs`, and a `max_secs` of `0` makes the ceiling equal to
    /// `base_secs`, so no growth takes place.
    fn effective_cooldown_secs(&self, base_secs: u64, penalty_streak: u32) -> u64 {
        if base_secs == 0 {
            return 0;
        }
        if self.factor <= 1 {
            return base_secs;
        }

        // The ceiling is never below the base, so a small `max_secs` cannot shorten a cooldown.
        let ceiling = match self.max_secs {
            0 => base_secs,
            limit => limit.max(base_secs),
        };

        // Arithmetic overflow means the true product is beyond any possible ceiling, so it
        // collapses to the ceiling as well.
        self.factor
            .checked_pow(penalty_streak.min(64))
            .and_then(|multiplier| base_secs.checked_mul(multiplier))
            .map_or(ceiling, |grown| grown.min(ceiling))
    }
}
41
/// Health bookkeeping for one service; every vector is indexed by upstream position.
#[derive(Debug, Default)]
pub struct LbState {
    /// Failures recorded per upstream since its last success or cooldown expiry.
    pub failure_counts: Vec<u32>,
    /// Instant until which each upstream stays in cooldown; `None` when it is not cooling down.
    pub cooldown_until: Vec<Option<std::time::Instant>>,
    /// Per-upstream "usage exhausted" flag; flagged upstreams are only picked as a second choice.
    pub usage_exhausted: Vec<bool>,
    /// Index of the upstream that most recently succeeded, if any.
    pub last_good_index: Option<usize>,
    /// Number of penalties each upstream has accumulated in a row, used to scale its cooldown.
    pub penalty_streak: Vec<u32>,
}

impl LbState {
    /// Makes every per-upstream vector exactly `len` entries long.
    ///
    /// When the upstream count itself changed (detected through `failure_counts`), all state
    /// is reset, because the old indices most likely no longer refer to the same upstreams.
    /// Otherwise the existing data is kept and only vectors that drifted out of step are
    /// padded or truncated. The fields are public, so a state built by hand (for example with
    /// `..Default::default()`) can arrive with mismatched lengths.
    fn ensure_len(&mut self, len: usize) {
        if self.failure_counts.len() != len {
            self.failure_counts = vec![0; len];
            self.cooldown_until = vec![None; len];
            self.usage_exhausted = vec![false; len];
            self.penalty_streak = vec![0; len];
            // The previous "last good" index was relative to the old upstream list.
            self.last_good_index = None;
            return;
        }

        // Same upstream count: `resize` is a no-op for vectors that already have `len` entries.
        self.cooldown_until.resize(len, None);
        self.usage_exhausted.resize(len, false);
        self.penalty_streak.resize(len, 0);
        // Drop a sticky index that cannot refer to any current upstream.
        self.last_good_index = self.last_good_index.filter(|&idx| idx < len);
    }
}
63
/// Upstream selection result: which service and which upstream slot were chosen.
#[derive(Debug, Clone)]
pub struct SelectedUpstream {
    /// Name of the service the upstream belongs to.
    pub config_name: String,
    /// Position of the upstream in the service's `upstreams` list.
    pub index: usize,
    /// Clone of the selected upstream's configuration.
    pub upstream: UpstreamConfig,
}
71
/// Simple upstream selector for a single service.
///
/// Selection is ordered, not random: it sticks to the upstream that most recently succeeded
/// and otherwise takes the first upstream below the failure threshold, preferring upstreams
/// that are not flagged as usage-exhausted.
#[derive(Clone)]
pub struct LoadBalancer {
    /// Service whose upstreams are being selected from.
    pub service: Arc<ServiceConfig>,
    /// Health state keyed by service name, shared behind a mutex.
    pub states: Arc<Mutex<HashMap<String, LbState>>>,
}
78
79impl LoadBalancer {
80    pub fn new(service: Arc<ServiceConfig>, states: Arc<Mutex<HashMap<String, LbState>>>) -> Self {
81        Self { service, states }
82    }
83
84    #[cfg(test)]
85    pub fn select_upstream(&self) -> Option<SelectedUpstream> {
86        self.select_upstream_avoiding(&HashSet::new())
87    }
88
89    pub fn select_upstream_avoiding(&self, avoid: &HashSet<usize>) -> Option<SelectedUpstream> {
90        self.select_upstream_avoiding_inner(avoid, false)
91    }
92
93    pub fn select_upstream_avoiding_strict(
94        &self,
95        avoid: &HashSet<usize>,
96    ) -> Option<SelectedUpstream> {
97        self.select_upstream_avoiding_inner(avoid, true)
98    }
99
100    fn select_upstream_avoiding_inner(
101        &self,
102        avoid: &HashSet<usize>,
103        strict: bool,
104    ) -> Option<SelectedUpstream> {
105        if self.service.upstreams.is_empty() {
106            return None;
107        }
108
109        let mut map = match self.states.lock() {
110            Ok(m) => m,
111            Err(e) => e.into_inner(),
112        };
113        let entry = map.entry(self.service.name.clone()).or_default();
114        entry.ensure_len(self.service.upstreams.len());
115
116        let now = std::time::Instant::now();
117
118        // 更新冷却状态:如果冷却期已过,重置失败计数和冷却时间。
119        for idx in 0..self.service.upstreams.len() {
120            if let Some(until) = entry.cooldown_until.get(idx).and_then(|v| *v)
121                && now >= until
122            {
123                entry.failure_counts[idx] = 0;
124                if let Some(slot) = entry.cooldown_until.get_mut(idx) {
125                    *slot = None;
126                }
127            }
128        }
129
130        // 优先使用最近一次“成功”的 upstream,实现粘性路由:
131        // 一旦已经切换到可用线路,就尽量保持在该线路上,而不是每次都从头熔断。
132        if let Some(idx) = entry.last_good_index
133            && idx < self.service.upstreams.len()
134            && entry.failure_counts[idx] < FAILURE_THRESHOLD
135            && !entry.usage_exhausted.get(idx).copied().unwrap_or(false)
136            && !avoid.contains(&idx)
137        {
138            let upstream = self.service.upstreams[idx].clone();
139            return Some(SelectedUpstream {
140                config_name: self.service.name.clone(),
141                index: idx,
142                upstream,
143            });
144        }
145
146        // 第一轮:按顺序选择第一个「未熔断 + 未标记用量用尽」的 upstream。
147        if let Some(idx) = self
148            .service
149            .upstreams
150            .iter()
151            .enumerate()
152            .find_map(|(idx, _)| {
153                if avoid.contains(&idx) {
154                    return None;
155                }
156                if entry.failure_counts[idx] >= FAILURE_THRESHOLD {
157                    return None;
158                }
159                if entry.usage_exhausted.get(idx).copied().unwrap_or(false) {
160                    return None;
161                }
162                Some(idx)
163            })
164        {
165            let upstream = self.service.upstreams[idx].clone();
166            return Some(SelectedUpstream {
167                config_name: self.service.name.clone(),
168                index: idx,
169                upstream,
170            });
171        }
172
173        // 第二轮:忽略 usage_exhausted,只看失败阈值,仍然按顺序选第一个。
174        if let Some(idx) = self
175            .service
176            .upstreams
177            .iter()
178            .enumerate()
179            .find_map(|(idx, _)| {
180                if avoid.contains(&idx) {
181                    return None;
182                }
183                if entry.failure_counts[idx] >= FAILURE_THRESHOLD {
184                    None
185                } else {
186                    Some(idx)
187                }
188            })
189        {
190            let upstream = self.service.upstreams[idx].clone();
191            return Some(SelectedUpstream {
192                config_name: self.service.name.clone(),
193                index: idx,
194                upstream,
195            });
196        }
197
198        if strict {
199            return None;
200        }
201
202        // 兜底:所有 upstream 都已达到失败阈值时,仍然返回第一个,以保证永远有兜底。
203        // 如果 avoid 把所有都排除了,则兜底返回第一个“非 avoid”的 upstream;仍然没有则返回 0。
204        let idx = (0..self.service.upstreams.len())
205            .find(|i| !avoid.contains(i))
206            .unwrap_or(0);
207        let upstream = self.service.upstreams[idx].clone();
208        Some(SelectedUpstream {
209            config_name: self.service.name.clone(),
210            index: idx,
211            upstream,
212        })
213    }
214
215    pub fn penalize_with_backoff(
216        &self,
217        index: usize,
218        cooldown_secs: u64,
219        reason: &str,
220        backoff: CooldownBackoff,
221    ) {
222        let mut map = match self.states.lock() {
223            Ok(m) => m,
224            Err(_) => return,
225        };
226        let entry = map
227            .entry(self.service.name.clone())
228            .or_insert_with(LbState::default);
229        entry.ensure_len(self.service.upstreams.len());
230        if index >= entry.failure_counts.len() {
231            return;
232        }
233
234        let streak = entry.penalty_streak.get(index).copied().unwrap_or(0);
235        let effective_secs = backoff.effective_cooldown_secs(cooldown_secs, streak);
236
237        entry.failure_counts[index] = FAILURE_THRESHOLD;
238        if let Some(slot) = entry.cooldown_until.get_mut(index) {
239            *slot =
240                Some(std::time::Instant::now() + std::time::Duration::from_secs(effective_secs));
241        }
242        if let Some(slot) = entry.penalty_streak.get_mut(index) {
243            *slot = streak.saturating_add(1);
244        }
245        if entry.last_good_index == Some(index) {
246            entry.last_good_index = None;
247        }
248        info!(
249            "lb: upstream '{}' index {} penalized for {}s (reason: {})",
250            self.service.name, index, effective_secs, reason
251        );
252    }
253
254    pub fn record_result_with_backoff(
255        &self,
256        index: usize,
257        success: bool,
258        failure_threshold_cooldown_secs: u64,
259        backoff: CooldownBackoff,
260    ) {
261        let mut map = match self.states.lock() {
262            Ok(m) => m,
263            Err(_) => return,
264        };
265        let entry = map
266            .entry(self.service.name.clone())
267            .or_insert_with(LbState::default);
268        entry.ensure_len(self.service.upstreams.len());
269        if index >= entry.failure_counts.len() {
270            return;
271        }
272        if success {
273            entry.failure_counts[index] = 0;
274            if let Some(slot) = entry.cooldown_until.get_mut(index) {
275                *slot = None;
276            }
277            if let Some(slot) = entry.penalty_streak.get_mut(index) {
278                *slot = 0;
279            }
280            // 成功请求会将该 upstream 记为“最近可用线路”,后续优先继续使用。
281            entry.last_good_index = Some(index);
282        } else {
283            entry.failure_counts[index] = entry.failure_counts[index].saturating_add(1);
284            if entry.failure_counts[index] >= FAILURE_THRESHOLD
285                && let Some(slot) = entry.cooldown_until.get_mut(index)
286            {
287                let base_secs = if failure_threshold_cooldown_secs == 0 {
288                    COOLDOWN_SECS
289                } else {
290                    failure_threshold_cooldown_secs
291                };
292                let streak = entry.penalty_streak.get(index).copied().unwrap_or(0);
293                let effective_secs = backoff.effective_cooldown_secs(base_secs, streak);
294                let now = std::time::Instant::now();
295                let new_until = now + std::time::Duration::from_secs(effective_secs);
296                let should_update = match *slot {
297                    Some(existing) => new_until > existing,
298                    None => true,
299                };
300                if should_update {
301                    *slot = Some(new_until);
302                }
303                if let Some(slot) = entry.penalty_streak.get_mut(index) {
304                    *slot = streak.saturating_add(1);
305                }
306                info!(
307                    "lb: upstream '{}' index {} reached failure threshold {} (count = {}), entering cooldown for {}s",
308                    self.service.name,
309                    index,
310                    FAILURE_THRESHOLD,
311                    entry.failure_counts[index],
312                    effective_secs
313                );
314                // 触发熔断时,如当前 last_good_index 指向该线路,则清空,允许后续选择其他线路。
315                if entry.last_good_index == Some(index) {
316                    entry.last_good_index = None;
317                }
318            }
319        }
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ServiceConfig, UpstreamAuth, UpstreamConfig};

    /// Shared state map type used by the balancer under test.
    type SharedStates = Arc<Mutex<HashMap<String, LbState>>>;

    const SERVICE_NAME: &str = "codex-main";

    /// Builds an enabled service named `name` with one token-authenticated upstream per URL.
    fn make_service(name: &str, urls: &[&str]) -> ServiceConfig {
        let mut upstreams = Vec::with_capacity(urls.len());
        for url in urls {
            upstreams.push(UpstreamConfig {
                base_url: url.to_string(),
                auth: UpstreamAuth {
                    auth_token: Some("sk-test".to_string()),
                    auth_token_env: None,
                    api_key: None,
                    api_key_env: None,
                },
                tags: HashMap::new(),
                supported_models: HashMap::new(),
                model_mapping: HashMap::new(),
            });
        }
        ServiceConfig {
            name: name.to_string(),
            alias: None,
            enabled: true,
            level: 1,
            upstreams,
        }
    }

    /// Builds a balancer over a primary and a backup upstream, plus a handle to its state map.
    fn primary_backup_lb() -> (LoadBalancer, SharedStates) {
        let service = make_service(
            SERVICE_NAME,
            &["https://primary.example", "https://backup.example"],
        );
        let states: SharedStates = Arc::new(Mutex::new(HashMap::new()));
        let lb = LoadBalancer::new(Arc::new(service), Arc::clone(&states));
        (lb, states)
    }

    /// Overwrites the usage-exhausted flags of both upstreams of the test service.
    fn set_usage_exhausted(states: &SharedStates, primary: bool, backup: bool) {
        let mut guard = states.lock().unwrap();
        let entry = guard
            .entry(SERVICE_NAME.to_string())
            .or_insert_with(LbState::default);
        entry.ensure_len(2);
        entry.usage_exhausted[0] = primary;
        entry.usage_exhausted[1] = backup;
    }

    #[test]
    fn lb_prefers_non_exhausted_upstream_when_available() {
        let (lb, states) = primary_backup_lb();

        // The very first selection takes the first upstream (index 0).
        let first = lb.select_upstream().expect("should select an upstream");
        assert_eq!(first.index, 0);

        // Flag index 0 as usage-exhausted and leave index 1 available.
        set_usage_exhausted(&states, true, false);

        // The non-exhausted index 1 must now be preferred.
        let second = lb.select_upstream().expect("should select backup upstream");
        assert_eq!(second.index, 1);
    }

    #[test]
    fn lb_falls_back_when_all_exhausted() {
        let (lb, states) = primary_backup_lb();

        // Initialise the state entry.
        let _ = lb.select_upstream();

        set_usage_exhausted(&states, true, true);

        // With every upstream exhausted, index 0 is still returned as the fallback.
        let selected = lb
            .select_upstream()
            .expect("should still select an upstream");
        assert_eq!(selected.index, 0);
    }

    #[test]
    fn lb_avoids_upstreams_past_failure_threshold() {
        let (lb, _states) = primary_backup_lb();

        let disabled_backoff = CooldownBackoff {
            factor: 1,
            max_secs: 0,
        };

        // Record FAILURE_THRESHOLD failures in a row against the primary.
        (0..FAILURE_THRESHOLD).for_each(|_| {
            lb.record_result_with_backoff(0, false, COOLDOWN_SECS, disabled_backoff);
        });

        // Index 0 has reached the failure threshold, so the backup (index 1) is chosen.
        let selected = lb
            .select_upstream()
            .expect("should select backup after failures");
        assert_eq!(selected.index, 1);
    }
}
435}