halldyll_core/fetch/
circuit_breaker.rs

1//! Circuit Breaker - Prevents cascade failures per domain
2//!
3//! Implements the circuit breaker pattern to protect against failing domains:
4//! - **Closed**: Normal operation, requests pass through
5//! - **Open**: Domain is failing, requests are rejected immediately
6//! - **Half-Open**: Testing if domain recovered
7//!
8//! ## Usage
9//!
10//! ```rust,ignore
11//! let breaker = CircuitBreaker::new(CircuitBreakerConfig::default());
12//! 
13//! // Before making a request
14//! if !breaker.allow_request("example.com") {
15//!     return Err(Error::CircuitOpen);
16//! }
17//!
18//! // After request
19//! match result {
20//!     Ok(_) => breaker.record_success("example.com"),
21//!     Err(e) => breaker.record_failure("example.com"),
22//! }
23//! ```
24
25use std::collections::HashMap;
26use std::sync::RwLock;
27use std::time::{Duration, Instant};
28
/// Tunable thresholds and timings that drive every per-domain circuit.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures inside `failure_window` that opens a closed circuit.
    pub failure_threshold: u32,
    /// Number of half-open successes required to close the circuit again.
    pub success_threshold: u32,
    /// How long an open circuit rejects requests before a half-open probe is allowed.
    pub open_duration: Duration,
    /// Sliding window over which failures are counted.
    pub failure_window: Duration,
    /// Whether `CircuitBreaker::record_timeout` counts as a failure.
    pub timeout_as_failure: bool,
    /// Whether `CircuitBreaker::record_server_error` (5xx) counts as a failure.
    pub server_error_as_failure: bool,
    /// Whether `CircuitBreaker::record_rate_limit` (429) counts as a failure.
    pub rate_limit_as_failure: bool,
}

impl Default for CircuitBreakerConfig {
    /// Balanced baseline: 5 failures within 60 s open the circuit for 30 s,
    /// and 2 half-open successes close it again.
    fn default() -> Self {
        let open_duration = Duration::from_secs(30);
        let failure_window = Duration::from_secs(60);

        Self {
            failure_threshold: 5,
            success_threshold: 2,
            open_duration,
            failure_window,
            timeout_as_failure: true,
            server_error_as_failure: true,
            // Rate limits are expected, not failures
            rate_limit_as_failure: false,
        }
    }
}

impl CircuitBreakerConfig {
    /// Production preset - tolerates more failures and waits longer before probing.
    pub fn production() -> Self {
        Self {
            failure_threshold: 10,
            success_threshold: 3,
            open_duration: Duration::from_secs(60),
            failure_window: Duration::from_secs(120),
            rate_limit_as_failure: false,
            // Timeouts and 5xx responses count as failures, same as the default.
            ..Self::default()
        }
    }

    /// Aggressive preset - opens after few failures and probes again quickly.
    /// Unlike the other presets, rate limits count as failures here.
    pub fn aggressive() -> Self {
        Self {
            failure_threshold: 3,
            success_threshold: 1,
            open_duration: Duration::from_secs(15),
            failure_window: Duration::from_secs(30),
            rate_limit_as_failure: true,
            // Timeouts and 5xx responses count as failures, same as the default.
            ..Self::default()
        }
    }
}
89
/// Circuit state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Circuit is closed, requests pass through
    Closed,
    /// Circuit is open, requests are rejected
    Open,
    /// Testing if domain has recovered
    HalfOpen,
}

/// Per-domain circuit bookkeeping.
#[derive(Debug)]
struct DomainCircuit {
    /// Current position in the closed / open / half-open state machine.
    state: CircuitState,
    /// Timestamps of recorded failures; pruned by `cleanup_old_failures`.
    failures: Vec<Instant>,
    /// Successes seen since the circuit last entered half-open.
    successes_in_half_open: u32,
    /// When the circuit was last opened (or its open timer refreshed).
    opened_at: Option<Instant>,
    /// Timestamp of the most recent failure, if any.
    last_failure: Option<Instant>,
}

impl DomainCircuit {
    /// Creates a closed circuit with no recorded history.
    fn new() -> Self {
        Self {
            state: CircuitState::Closed,
            failures: Vec::new(),
            successes_in_half_open: 0,
            opened_at: None,
            last_failure: None,
        }
    }

    /// Returns `true` when `at` lies strictly less than `window` before `now`.
    ///
    /// Comparing elapsed durations instead of computing `now - window` avoids
    /// the panic `Instant - Duration` raises when the result cannot be
    /// represented (e.g. a very large window).
    fn is_within(now: Instant, at: Instant, window: Duration) -> bool {
        now.saturating_duration_since(at) < window
    }

    /// Count recent failures within the window.
    ///
    /// The count saturates at `u32::MAX` instead of silently truncating.
    fn recent_failures(&self, window: Duration) -> u32 {
        let now = Instant::now();
        let recent = self
            .failures
            .iter()
            .filter(|&&at| Self::is_within(now, at, window))
            .count();
        // Clamp before narrowing so the cast below can never wrap.
        recent.min(u32::MAX as usize) as u32
    }

    /// Drop failures that are no longer inside the window.
    fn cleanup_old_failures(&mut self, window: Duration) {
        let now = Instant::now();
        self.failures.retain(|&at| Self::is_within(now, at, window));
    }
}
134
135/// Circuit breaker for multiple domains
136pub struct CircuitBreaker {
137    config: CircuitBreakerConfig,
138    circuits: RwLock<HashMap<String, DomainCircuit>>,
139}
140
141impl CircuitBreaker {
142    /// Create new circuit breaker
143    pub fn new(config: CircuitBreakerConfig) -> Self {
144        Self {
145            config,
146            circuits: RwLock::new(HashMap::new()),
147        }
148    }
149
150    /// Create with default config
151    pub fn default_config() -> Self {
152        Self::new(CircuitBreakerConfig::default())
153    }
154
155    /// Check if a request to this domain is allowed
156    pub fn allow_request(&self, domain: &str) -> bool {
157        let mut circuits = self.circuits.write().unwrap();
158        let circuit = circuits.entry(domain.to_string()).or_insert_with(DomainCircuit::new);
159
160        match circuit.state {
161            CircuitState::Closed => true,
162            CircuitState::Open => {
163                // Check if we should transition to half-open
164                if let Some(opened_at) = circuit.opened_at {
165                    if opened_at.elapsed() >= self.config.open_duration {
166                        circuit.state = CircuitState::HalfOpen;
167                        circuit.successes_in_half_open = 0;
168                        true
169                    } else {
170                        false
171                    }
172                } else {
173                    false
174                }
175            }
176            CircuitState::HalfOpen => true,
177        }
178    }
179
180    /// Record a successful request
181    pub fn record_success(&self, domain: &str) {
182        let mut circuits = self.circuits.write().unwrap();
183        if let Some(circuit) = circuits.get_mut(domain) {
184            match circuit.state {
185                CircuitState::HalfOpen => {
186                    circuit.successes_in_half_open += 1;
187                    if circuit.successes_in_half_open >= self.config.success_threshold {
188                        // Transition back to closed
189                        circuit.state = CircuitState::Closed;
190                        circuit.failures.clear();
191                        circuit.opened_at = None;
192                        circuit.successes_in_half_open = 0;
193                    }
194                }
195                CircuitState::Closed => {
196                    // Nothing special, just cleanup old failures
197                    circuit.cleanup_old_failures(self.config.failure_window);
198                }
199                CircuitState::Open => {
200                    // Shouldn't happen, but handle gracefully
201                }
202            }
203        }
204    }
205
206    /// Record a failed request
207    pub fn record_failure(&self, domain: &str) {
208        let mut circuits = self.circuits.write().unwrap();
209        let circuit = circuits.entry(domain.to_string()).or_insert_with(DomainCircuit::new);
210
211        circuit.failures.push(Instant::now());
212        circuit.last_failure = Some(Instant::now());
213        circuit.cleanup_old_failures(self.config.failure_window);
214
215        match circuit.state {
216            CircuitState::Closed => {
217                if circuit.recent_failures(self.config.failure_window) >= self.config.failure_threshold {
218                    // Open the circuit
219                    circuit.state = CircuitState::Open;
220                    circuit.opened_at = Some(Instant::now());
221                }
222            }
223            CircuitState::HalfOpen => {
224                // Any failure in half-open reopens the circuit
225                circuit.state = CircuitState::Open;
226                circuit.opened_at = Some(Instant::now());
227                circuit.successes_in_half_open = 0;
228            }
229            CircuitState::Open => {
230                // Already open, refresh the timer
231                circuit.opened_at = Some(Instant::now());
232            }
233        }
234    }
235
236    /// Record a timeout (may or may not count as failure based on config)
237    pub fn record_timeout(&self, domain: &str) {
238        if self.config.timeout_as_failure {
239            self.record_failure(domain);
240        }
241    }
242
243    /// Record a server error (5xx)
244    pub fn record_server_error(&self, domain: &str) {
245        if self.config.server_error_as_failure {
246            self.record_failure(domain);
247        }
248    }
249
250    /// Record a rate limit (429)
251    pub fn record_rate_limit(&self, domain: &str) {
252        if self.config.rate_limit_as_failure {
253            self.record_failure(domain);
254        }
255    }
256
257    /// Get circuit state for a domain
258    pub fn get_state(&self, domain: &str) -> CircuitState {
259        let circuits = self.circuits.read().unwrap();
260        circuits.get(domain).map(|c| c.state).unwrap_or(CircuitState::Closed)
261    }
262
263    /// Get all open circuits (for monitoring)
264    pub fn get_open_circuits(&self) -> Vec<String> {
265        let circuits = self.circuits.read().unwrap();
266        circuits
267            .iter()
268            .filter(|(_, c)| c.state == CircuitState::Open)
269            .map(|(domain, _)| domain.clone())
270            .collect()
271    }
272
273    /// Reset circuit for a domain
274    pub fn reset(&self, domain: &str) {
275        let mut circuits = self.circuits.write().unwrap();
276        circuits.remove(domain);
277    }
278
279    /// Reset all circuits
280    pub fn reset_all(&self) {
281        let mut circuits = self.circuits.write().unwrap();
282        circuits.clear();
283    }
284
285    /// Get circuit statistics
286    pub fn stats(&self) -> CircuitBreakerStats {
287        let circuits = self.circuits.read().unwrap();
288        let total = circuits.len();
289        let open = circuits.values().filter(|c| c.state == CircuitState::Open).count();
290        let half_open = circuits.values().filter(|c| c.state == CircuitState::HalfOpen).count();
291        let closed = circuits.values().filter(|c| c.state == CircuitState::Closed).count();
292
293        CircuitBreakerStats {
294            total_domains: total,
295            open_circuits: open,
296            half_open_circuits: half_open,
297            closed_circuits: closed,
298        }
299    }
300}
301
/// Point-in-time summary of how many tracked domains are in each circuit state.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
    /// Count of all domains that currently have a circuit entry.
    pub total_domains: usize,
    /// How many of those circuits are open (rejecting requests).
    pub open_circuits: usize,
    /// How many of those circuits are half-open (probing for recovery).
    pub half_open_circuits: usize,
    /// How many of those circuits are closed (operating normally).
    pub closed_circuits: usize,
}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Domain used by every single-domain scenario.
    const DOMAIN: &str = "example.com";

    /// Records `count` consecutive failures against `domain`.
    fn fail_times(breaker: &CircuitBreaker, domain: &str, count: usize) {
        (0..count).for_each(|_| breaker.record_failure(domain));
    }

    /// Builds a breaker that opens after two failures and stays open for
    /// 10 ms; `success_threshold` is the only setting that varies.
    fn fast_recovery_breaker(success_threshold: u32) -> CircuitBreaker {
        CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold,
            open_duration: Duration::from_millis(10),
            ..Default::default()
        })
    }

    /// A domain that was never seen is allowed and reported as closed.
    #[test]
    fn test_circuit_starts_closed() {
        let breaker = CircuitBreaker::default_config();

        let allowed = breaker.allow_request(DOMAIN);

        assert!(allowed);
        assert_eq!(breaker.get_state(DOMAIN), CircuitState::Closed);
    }

    /// Reaching the failure threshold opens the circuit and rejects requests.
    #[test]
    fn test_circuit_opens_after_failures() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        });

        fail_times(&breaker, DOMAIN, 3);

        assert_eq!(breaker.get_state(DOMAIN), CircuitState::Open);
        assert!(!breaker.allow_request(DOMAIN));
    }

    /// After the open duration has passed, the next request is let through
    /// and the circuit is half-open.
    #[test]
    fn test_circuit_transitions_to_half_open() {
        // Default success threshold (2), as in the base configuration.
        let breaker = fast_recovery_breaker(CircuitBreakerConfig::default().success_threshold);

        fail_times(&breaker, DOMAIN, 2);
        assert_eq!(breaker.get_state(DOMAIN), CircuitState::Open);

        // Outlast the 10 ms open duration.
        std::thread::sleep(Duration::from_millis(15));

        assert!(breaker.allow_request(DOMAIN));
        assert_eq!(breaker.get_state(DOMAIN), CircuitState::HalfOpen);
    }

    /// Enough successes while half-open close the circuit again.
    #[test]
    fn test_circuit_closes_after_successes() {
        let breaker = fast_recovery_breaker(2);

        fail_times(&breaker, DOMAIN, 2);

        // Outlast the open duration, then trigger the open -> half-open step.
        std::thread::sleep(Duration::from_millis(15));
        let _ = breaker.allow_request(DOMAIN);

        for _ in 0..2 {
            breaker.record_success(DOMAIN);
        }

        assert_eq!(breaker.get_state(DOMAIN), CircuitState::Closed);
    }

    /// Stats count one closed and one open circuit across two domains.
    #[test]
    fn test_stats() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        });

        // `allow_request` registers the healthy domain; two failures trip the other.
        let _ = breaker.allow_request("good.com");
        fail_times(&breaker, "bad.com", 2);

        let snapshot = breaker.stats();
        assert_eq!(snapshot.total_domains, 2);
        assert_eq!(snapshot.open_circuits, 1);
        assert_eq!(snapshot.closed_circuits, 1);
    }
}