Skip to main content

chat_router/
circuit_breaker.rs

1use std::time::{Duration, Instant};
2
/// Configuration for the circuit breaker.
///
/// Use [`Default::default`] for the stock settings, or build the struct directly to
/// override individual values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures before the circuit opens (skips provider).
    pub failure_threshold: u32,
    /// How long an open circuit waits before allowing a probe request.
    pub recovery_timeout: Duration,
}

impl Default for CircuitBreakerConfig {
    /// Returns the stock settings: the circuit opens after 3 consecutive failures and
    /// allows a probe request again after 30 seconds.
    fn default() -> Self {
        Self {
            failure_threshold: 3,
            recovery_timeout: Duration::from_secs(30),
        }
    }
}
19
/// Failure bookkeeping for a single provider.
struct ProviderCircuit {
    /// Count of failures recorded in a row.
    consecutive_failures: u32,
    /// Time of the most recently recorded failure, if there is one.
    last_failure_at: Option<Instant>,
}

impl ProviderCircuit {
    /// Creates a circuit with a zero failure count and no failure timestamp.
    fn new() -> Self {
        let consecutive_failures = 0;
        let last_failure_at = None;
        Self {
            consecutive_failures,
            last_failure_at,
        }
    }
}
33
34pub(crate) struct CircuitBreaker {
35    config: CircuitBreakerConfig,
36    circuits: Vec<ProviderCircuit>,
37}
38
39impl CircuitBreaker {
40    pub fn new(config: CircuitBreakerConfig, provider_count: usize) -> Self {
41        let circuits = (0..provider_count).map(|_| ProviderCircuit::new()).collect();
42        Self { config, circuits }
43    }
44
45    /// Returns true if the provider should be tried (closed or half-open).
46    pub fn is_available(&self, idx: usize) -> bool {
47        let circuit = &self.circuits[idx];
48        if circuit.consecutive_failures < self.config.failure_threshold {
49            return true;
50        }
51        // Circuit is open — check if recovery timeout has elapsed (half-open)
52        match circuit.last_failure_at {
53            Some(at) => at.elapsed() >= self.config.recovery_timeout,
54            None => true,
55        }
56    }
57
58    /// Returns the index of the provider whose circuit has been open the longest,
59    /// i.e. the one most likely to have recovered. Used as a last resort when all
60    /// circuits are open.
61    pub fn longest_open(&self) -> Option<usize> {
62        self.circuits
63            .iter()
64            .enumerate()
65            .filter(|(_, c)| c.consecutive_failures >= self.config.failure_threshold)
66            .min_by_key(|(_, c)| c.last_failure_at)
67            .map(|(idx, _)| idx)
68    }
69
70    pub fn record_success(&mut self, idx: usize) {
71        let circuit = &mut self.circuits[idx];
72        circuit.consecutive_failures = 0;
73        circuit.last_failure_at = None;
74    }
75
76    pub fn record_failure(&mut self, idx: usize) {
77        let circuit = &mut self.circuits[idx];
78        circuit.consecutive_failures += 1;
79        circuit.last_failure_at = Some(Instant::now());
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a breaker with the given threshold, recovery timeout and provider count.
    fn breaker(
        failure_threshold: u32,
        recovery_timeout: Duration,
        provider_count: usize,
    ) -> CircuitBreaker {
        CircuitBreaker::new(
            CircuitBreakerConfig {
                failure_threshold,
                recovery_timeout,
            },
            provider_count,
        )
    }

    #[test]
    fn new_circuit_is_available() {
        let cb = breaker(3, Duration::from_secs(30), 2);
        assert!(cb.is_available(0));
        assert!(cb.is_available(1));
    }

    #[test]
    fn opens_after_threshold() {
        let mut cb = breaker(2, Duration::from_secs(30), 1);
        cb.record_failure(0);
        assert!(cb.is_available(0)); // 1 < 2
        cb.record_failure(0);
        assert!(!cb.is_available(0)); // 2 >= 2, just failed
    }

    #[test]
    fn success_resets() {
        let mut cb = breaker(2, Duration::from_secs(30), 1);
        cb.record_failure(0);
        cb.record_failure(0);
        assert!(!cb.is_available(0));
        cb.record_success(0);
        assert!(cb.is_available(0));
    }

    #[test]
    fn half_open_after_recovery_timeout() {
        let mut cb = breaker(1, Duration::from_millis(0), 1);
        cb.record_failure(0);
        // With 0ms timeout, should immediately be half-open
        assert!(cb.is_available(0));
    }

    #[test]
    fn longest_open_is_none_when_no_circuit_is_open() {
        let cb = breaker(1, Duration::from_secs(999), 3);
        assert_eq!(cb.longest_open(), None);
    }

    #[test]
    fn longest_open_picks_oldest_failure() {
        let mut cb = breaker(1, Duration::from_secs(999), 3);
        // Pin the failure times explicitly: `Instant::now()` is only guaranteed to be
        // non-decreasing, so back-to-back `record_failure` calls may share a timestamp,
        // and a tie would resolve to the lowest index rather than the oldest failure.
        let base = Instant::now();
        for (idx, offset_secs) in [(1, 0), (2, 1), (0, 2)] {
            cb.record_failure(idx);
            cb.circuits[idx].last_failure_at = Some(base + Duration::from_secs(offset_secs));
        }
        // Provider 1 failed first → oldest
        assert_eq!(cb.longest_open(), Some(1));
    }
}