1use std::collections::HashMap;
26use std::sync::RwLock;
27use std::time::{Duration, Instant};
28
/// Tuning knobs for `CircuitBreaker`: when a circuit trips, how long it stays
/// open, how it recovers, and which kinds of errors count against it.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures inside `failure_window` that opens a closed circuit.
    pub failure_threshold: u32,
    /// Number of successes required while half-open before the circuit closes.
    pub success_threshold: u32,
    /// How long an open circuit rejects requests before a trial request is let through.
    pub open_duration: Duration,
    /// Sliding window over which failures are counted.
    pub failure_window: Duration,
    /// Whether `record_timeout` is treated as a failure.
    pub timeout_as_failure: bool,
    /// Whether `record_server_error` is treated as a failure.
    pub server_error_as_failure: bool,
    /// Whether `record_rate_limit` is treated as a failure.
    pub rate_limit_as_failure: bool,
}

impl Default for CircuitBreakerConfig {
    /// Balanced defaults: trip after 5 failures in 60s, cool down for 30s,
    /// close again after 2 successful trial requests. Timeouts and server
    /// errors count as failures; rate limiting does not.
    fn default() -> Self {
        let open_duration = Duration::from_secs(30);
        let failure_window = Duration::from_secs(60);
        Self {
            failure_threshold: 5,
            success_threshold: 2,
            open_duration,
            failure_window,
            timeout_as_failure: true,
            server_error_as_failure: true,
            rate_limit_as_failure: false,
        }
    }
}

impl CircuitBreakerConfig {
    /// Tolerant preset: trip after 10 failures in 120s, cool down for 60s,
    /// close after 3 successful trial requests. Error classification matches
    /// `Default`.
    pub fn production() -> Self {
        Self {
            failure_threshold: 10,
            success_threshold: 3,
            open_duration: Duration::from_secs(60),
            failure_window: Duration::from_secs(120),
            ..Self::default()
        }
    }

    /// Sensitive preset: trip after 3 failures in 30s, cool down for 15s,
    /// close after a single successful trial request. Unlike `Default`, rate
    /// limiting also counts as a failure.
    pub fn aggressive() -> Self {
        Self {
            failure_threshold: 3,
            success_threshold: 1,
            open_duration: Duration::from_secs(15),
            failure_window: Duration::from_secs(30),
            rate_limit_as_failure: true,
            ..Self::default()
        }
    }
}
89
/// Observable state of a single domain's circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are allowed and failures are counted.
    Closed,
    /// Requests are rejected until `open_duration` has elapsed since opening.
    Open,
    /// Trial requests are allowed; enough successes close the circuit and any
    /// failure re-opens it.
    HalfOpen,
}

/// Per-domain bookkeeping for one circuit.
#[derive(Debug)]
struct DomainCircuit {
    /// Current state of this domain's circuit.
    state: CircuitState,
    /// Timestamps of recorded failures; pruned by `cleanup_old_failures`.
    failures: Vec<Instant>,
    /// Successes recorded since the circuit last entered `HalfOpen`.
    successes_in_half_open: u32,
    /// When the circuit (most recently) entered `Open`; `None` while closed.
    opened_at: Option<Instant>,
    /// Timestamp of the most recent failure, if any.
    last_failure: Option<Instant>,
}

impl DomainCircuit {
    /// Creates a closed circuit with no recorded history.
    fn new() -> Self {
        Self {
            state: CircuitState::Closed,
            failures: Vec::new(),
            successes_in_half_open: 0,
            opened_at: None,
            last_failure: None,
        }
    }

    /// Returns `true` when `failed_at` lies strictly less than `window` before `now`.
    ///
    /// Expressed as an elapsed-duration comparison rather than `now - window`:
    /// subtracting a `Duration` from an `Instant` panics if the result would
    /// precede the clock's origin (e.g. a long window shortly after boot, or a
    /// very large configured window), while `saturating_duration_since` cannot fail.
    fn within_window(now: Instant, failed_at: Instant, window: Duration) -> bool {
        now.saturating_duration_since(failed_at) < window
    }

    /// Counts the failures recorded within the last `window`.
    ///
    /// The count saturates at `u32::MAX` instead of silently truncating.
    fn recent_failures(&self, window: Duration) -> u32 {
        let now = Instant::now();
        let count = self
            .failures
            .iter()
            .filter(|&&t| Self::within_window(now, t, window))
            .count();
        count.min(u32::MAX as usize) as u32
    }

    /// Discards failure timestamps that fall outside the last `window`.
    fn cleanup_old_failures(&mut self, window: Duration) {
        let now = Instant::now();
        self.failures.retain(|&t| Self::within_window(now, t, window));
    }
}
134
135pub struct CircuitBreaker {
137 config: CircuitBreakerConfig,
138 circuits: RwLock<HashMap<String, DomainCircuit>>,
139}
140
141impl CircuitBreaker {
142 pub fn new(config: CircuitBreakerConfig) -> Self {
144 Self {
145 config,
146 circuits: RwLock::new(HashMap::new()),
147 }
148 }
149
150 pub fn default_config() -> Self {
152 Self::new(CircuitBreakerConfig::default())
153 }
154
155 pub fn allow_request(&self, domain: &str) -> bool {
157 let mut circuits = self.circuits.write().unwrap();
158 let circuit = circuits.entry(domain.to_string()).or_insert_with(DomainCircuit::new);
159
160 match circuit.state {
161 CircuitState::Closed => true,
162 CircuitState::Open => {
163 if let Some(opened_at) = circuit.opened_at {
165 if opened_at.elapsed() >= self.config.open_duration {
166 circuit.state = CircuitState::HalfOpen;
167 circuit.successes_in_half_open = 0;
168 true
169 } else {
170 false
171 }
172 } else {
173 false
174 }
175 }
176 CircuitState::HalfOpen => true,
177 }
178 }
179
180 pub fn record_success(&self, domain: &str) {
182 let mut circuits = self.circuits.write().unwrap();
183 if let Some(circuit) = circuits.get_mut(domain) {
184 match circuit.state {
185 CircuitState::HalfOpen => {
186 circuit.successes_in_half_open += 1;
187 if circuit.successes_in_half_open >= self.config.success_threshold {
188 circuit.state = CircuitState::Closed;
190 circuit.failures.clear();
191 circuit.opened_at = None;
192 circuit.successes_in_half_open = 0;
193 }
194 }
195 CircuitState::Closed => {
196 circuit.cleanup_old_failures(self.config.failure_window);
198 }
199 CircuitState::Open => {
200 }
202 }
203 }
204 }
205
206 pub fn record_failure(&self, domain: &str) {
208 let mut circuits = self.circuits.write().unwrap();
209 let circuit = circuits.entry(domain.to_string()).or_insert_with(DomainCircuit::new);
210
211 circuit.failures.push(Instant::now());
212 circuit.last_failure = Some(Instant::now());
213 circuit.cleanup_old_failures(self.config.failure_window);
214
215 match circuit.state {
216 CircuitState::Closed => {
217 if circuit.recent_failures(self.config.failure_window) >= self.config.failure_threshold {
218 circuit.state = CircuitState::Open;
220 circuit.opened_at = Some(Instant::now());
221 }
222 }
223 CircuitState::HalfOpen => {
224 circuit.state = CircuitState::Open;
226 circuit.opened_at = Some(Instant::now());
227 circuit.successes_in_half_open = 0;
228 }
229 CircuitState::Open => {
230 circuit.opened_at = Some(Instant::now());
232 }
233 }
234 }
235
236 pub fn record_timeout(&self, domain: &str) {
238 if self.config.timeout_as_failure {
239 self.record_failure(domain);
240 }
241 }
242
243 pub fn record_server_error(&self, domain: &str) {
245 if self.config.server_error_as_failure {
246 self.record_failure(domain);
247 }
248 }
249
250 pub fn record_rate_limit(&self, domain: &str) {
252 if self.config.rate_limit_as_failure {
253 self.record_failure(domain);
254 }
255 }
256
257 pub fn get_state(&self, domain: &str) -> CircuitState {
259 let circuits = self.circuits.read().unwrap();
260 circuits.get(domain).map(|c| c.state).unwrap_or(CircuitState::Closed)
261 }
262
263 pub fn get_open_circuits(&self) -> Vec<String> {
265 let circuits = self.circuits.read().unwrap();
266 circuits
267 .iter()
268 .filter(|(_, c)| c.state == CircuitState::Open)
269 .map(|(domain, _)| domain.clone())
270 .collect()
271 }
272
273 pub fn reset(&self, domain: &str) {
275 let mut circuits = self.circuits.write().unwrap();
276 circuits.remove(domain);
277 }
278
279 pub fn reset_all(&self) {
281 let mut circuits = self.circuits.write().unwrap();
282 circuits.clear();
283 }
284
285 pub fn stats(&self) -> CircuitBreakerStats {
287 let circuits = self.circuits.read().unwrap();
288 let total = circuits.len();
289 let open = circuits.values().filter(|c| c.state == CircuitState::Open).count();
290 let half_open = circuits.values().filter(|c| c.state == CircuitState::HalfOpen).count();
291 let closed = circuits.values().filter(|c| c.state == CircuitState::Closed).count();
292
293 CircuitBreakerStats {
294 total_domains: total,
295 open_circuits: open,
296 half_open_circuits: half_open,
297 closed_circuits: closed,
298 }
299 }
300}
301
/// Point-in-time count of tracked domains, broken down by circuit state.
///
/// `total_domains` equals the sum of the three per-state counts.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
    /// Number of domains currently tracked.
    pub total_domains: usize,
    /// Domains whose circuit is `Open`.
    pub open_circuits: usize,
    /// Domains whose circuit is `HalfOpen`.
    pub half_open_circuits: usize,
    /// Domains whose circuit is `Closed`.
    pub closed_circuits: usize,
}
314
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_circuit_starts_closed() {
        let breaker = CircuitBreaker::default_config();
        assert!(breaker.allow_request("example.com"));
        assert_eq!(breaker.get_state("example.com"), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_opens_after_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);

        for _ in 0..3 {
            breaker.record_failure("example.com");
        }

        assert_eq!(breaker.get_state("example.com"), CircuitState::Open);
        assert!(!breaker.allow_request("example.com"));
    }

    #[test]
    fn test_circuit_transitions_to_half_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_millis(10),
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);

        breaker.record_failure("example.com");
        breaker.record_failure("example.com");
        assert_eq!(breaker.get_state("example.com"), CircuitState::Open);

        // Wait out the cool-down; the next allow_request performs the transition.
        std::thread::sleep(Duration::from_millis(15));

        assert!(breaker.allow_request("example.com"));
        assert_eq!(breaker.get_state("example.com"), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_closes_after_successes() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 2,
            open_duration: Duration::from_millis(10),
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);

        breaker.record_failure("example.com");
        breaker.record_failure("example.com");

        std::thread::sleep(Duration::from_millis(15));
        assert!(breaker.allow_request("example.com"));

        // One success is below the threshold; the circuit stays half-open.
        breaker.record_success("example.com");
        assert_eq!(breaker.get_state("example.com"), CircuitState::HalfOpen);

        breaker.record_success("example.com");
        assert_eq!(breaker.get_state("example.com"), CircuitState::Closed);
    }

    #[test]
    fn test_failure_in_half_open_reopens() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_millis(10),
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);

        breaker.record_failure("example.com");
        breaker.record_failure("example.com");
        std::thread::sleep(Duration::from_millis(15));
        assert!(breaker.allow_request("example.com"));
        assert_eq!(breaker.get_state("example.com"), CircuitState::HalfOpen);

        // A single failed trial request re-opens the circuit regardless of threshold.
        breaker.record_failure("example.com");
        assert_eq!(breaker.get_state("example.com"), CircuitState::Open);
    }

    #[test]
    fn test_rate_limits_ignored_by_default() {
        let breaker = CircuitBreaker::default_config();
        for _ in 0..10 {
            breaker.record_rate_limit("example.com");
        }
        assert_eq!(breaker.get_state("example.com"), CircuitState::Closed);
        assert!(breaker.allow_request("example.com"));
    }

    #[test]
    fn test_aggressive_counts_rate_limits() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::aggressive());
        for _ in 0..3 {
            breaker.record_rate_limit("example.com");
        }
        assert_eq!(breaker.get_state("example.com"), CircuitState::Open);
    }

    #[test]
    fn test_error_classification_flags() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            timeout_as_failure: false,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);

        breaker.record_timeout("example.com");
        assert_eq!(breaker.get_state("example.com"), CircuitState::Closed);

        breaker.record_server_error("example.com");
        assert_eq!(breaker.get_state("example.com"), CircuitState::Open);
    }

    #[test]
    fn test_reset_clears_state() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);

        breaker.record_failure("a.com");
        breaker.record_failure("b.com");
        let mut open = breaker.get_open_circuits();
        open.sort();
        assert_eq!(open, vec!["a.com".to_string(), "b.com".to_string()]);

        breaker.reset("a.com");
        assert_eq!(breaker.get_state("a.com"), CircuitState::Closed);
        assert_eq!(breaker.get_open_circuits(), vec!["b.com".to_string()]);

        breaker.reset_all();
        assert!(breaker.get_open_circuits().is_empty());
        assert_eq!(breaker.stats().total_domains, 0);
    }

    #[test]
    fn test_stats() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);

        // allow_request registers the domain even though nothing was recorded.
        breaker.allow_request("good.com");
        breaker.record_failure("bad.com");
        breaker.record_failure("bad.com");

        let stats = breaker.stats();
        assert_eq!(stats.total_domains, 2);
        assert_eq!(stats.open_circuits, 1);
        assert_eq!(stats.half_open_circuits, 0);
        assert_eq!(stats.closed_circuits, 1);
    }
}