Skip to main content

research_master/utils/
circuit_breaker.rs

//! Circuit breaker pattern implementation for API resilience.
//!
//! The circuit breaker prevents cascading failures by temporarily disabling
//! requests to sources that are failing. It has three states:
//!
//! - **Closed**: Normal operation, requests pass through
//! - **Open**: Source is failing, requests are immediately rejected
//! - **Half-Open**: Testing if the source has recovered
//!
//! # Usage
//!
//! ```rust
//! use research_master::utils::{CircuitBreaker, CircuitState};
//!
//! let breaker = CircuitBreaker::new("semantic", 5, std::time::Duration::from_secs(60));
//!
//! assert_eq!(breaker.state(), CircuitState::Closed);
//! ```
19
20use std::sync::atomic::{AtomicUsize, Ordering};
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23
/// Circuit breaker states.
///
/// The discriminants are pinned explicitly because the state is stored as a raw
/// `u8` (via `as u8`) and decoded again through `TryFrom<u8>`; relying on
/// declaration order would let a reordering of the variants silently break that
/// round trip.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation - requests pass through
    Closed = 0,
    /// Failing - requests are rejected
    Open = 1,
    /// Testing recovery - limited requests allowed
    HalfOpen = 2,
}
34
/// Circuit breaker configuration.
///
/// [`Default`] yields the values used for breakers created with default
/// settings: 5 failures to open, 3 successes to close, 60 seconds open.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Number of failures before opening the circuit
    pub failure_threshold: usize,

    /// Number of successes in half-open state to close the circuit
    pub success_threshold: usize,

    /// Duration to stay open before trying half-open
    pub open_duration: Duration,
}

impl Default for CircuitBreakerConfig {
    /// Returns the stock configuration (5 failures / 3 successes / 60 s).
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 3,
            open_duration: Duration::from_secs(60),
        }
    }
}
47
/// Outcome of an operation that was routed through a circuit breaker.
#[derive(Clone, Debug)]
pub enum CircuitResult<T> {
    /// The operation completed successfully; carries its value.
    Success(T),
    /// The operation was attempted and failed; carries the error text.
    Failure(String),
    /// The circuit was open, so the request was rejected; carries the reason.
    Rejected(String),
    /// The request was let through while the circuit was half-open; carries a
    /// status message.
    RetryAllowed(String),
}

impl<T> CircuitResult<T> {
    /// Returns `true` only for [`CircuitResult::Success`].
    pub fn is_success(&self) -> bool {
        match self {
            Self::Success(_) => true,
            Self::Failure(_) | Self::Rejected(_) | Self::RetryAllowed(_) => false,
        }
    }

    /// Returns `true` only for [`CircuitResult::Rejected`].
    pub fn is_rejected(&self) -> bool {
        match self {
            Self::Rejected(_) => true,
            Self::Success(_) | Self::Failure(_) | Self::RetryAllowed(_) => false,
        }
    }

    /// Extracts the value held by [`CircuitResult::Success`].
    ///
    /// # Panics
    ///
    /// Panics with `unwrap on <Variant>: <message>` for every other variant.
    pub fn unwrap(self) -> T {
        // Resolve the variant label and message first so there is one panic site.
        let (variant, message) = match self {
            Self::Success(value) => return value,
            Self::Failure(message) => ("Failure", message),
            Self::Rejected(message) => ("Rejected", message),
            Self::RetryAllowed(message) => ("RetryAllowed", message),
        };
        panic!("unwrap on {}: {}", variant, message)
    }
}
82
83/// Thread-safe circuit breaker implementation
84#[derive(Debug)]
85pub struct CircuitBreaker {
86    /// Circuit name (e.g., "arxiv", "semantic")
87    name: String,
88
89    /// Current state
90    state: std::sync::atomic::AtomicU8,
91
92    /// Number of consecutive failures
93    failure_count: Arc<AtomicUsize>,
94
95    /// Number of consecutive successes (in half-open state)
96    success_count: Arc<AtomicUsize>,
97
98    /// Time when circuit was opened (Instant::now().elapsed() in milliseconds)
99    open_since_ms: std::sync::atomic::AtomicU64,
100
101    /// Configuration
102    config: CircuitBreakerConfig,
103}
104
105impl CircuitBreaker {
106    /// Create a new circuit breaker
107    ///
108    /// - `name`: Identifier for this circuit (e.g., "arxiv")
109    /// - `failure_threshold`: Failures before opening (default: 5)
110    /// - `open_duration`: Time to stay open before half-open (default: 60s)
111    pub fn new(name: &str, failure_threshold: usize, open_duration: Duration) -> Self {
112        Self {
113            name: name.to_string(),
114            state: std::sync::atomic::AtomicU8::new(CircuitState::Closed as u8),
115            failure_count: Arc::new(AtomicUsize::new(0)),
116            success_count: Arc::new(AtomicUsize::new(0)),
117            open_since_ms: std::sync::atomic::AtomicU64::new(0),
118            config: CircuitBreakerConfig {
119                failure_threshold,
120                success_threshold: 3,
121                open_duration,
122            },
123        }
124    }
125
126    /// Create with default settings
127    pub fn default_for(name: &str) -> Self {
128        Self::new(name, 5, Duration::from_secs(60))
129    }
130
131    /// Get the current state
132    pub fn state(&self) -> CircuitState {
133        let state = self.state.load(Ordering::SeqCst);
134        let state = CircuitState::try_from(state).unwrap_or(CircuitState::Closed);
135
136        // Check if we should transition from open to half-open
137        if state == CircuitState::Open {
138            if let Some(since) = self.open_time() {
139                if since.elapsed() >= self.config.open_duration {
140                    return CircuitState::HalfOpen;
141                }
142            }
143        }
144
145        state
146    }
147
148    /// Get the time when the circuit was opened
149    fn open_time(&self) -> Option<Instant> {
150        let ts = self.open_since_ms.load(Ordering::SeqCst);
151        if ts == 0 {
152            None
153        } else {
154            Some(Instant::now() - Duration::from_millis(ts))
155        }
156    }
157
158    /// Record a success
159    pub fn record_success(&self) {
160        let state = self.state();
161
162        match state {
163            CircuitState::Closed => {
164                // Reset failure count on success
165                self.failure_count.store(0, Ordering::SeqCst);
166            }
167            CircuitState::HalfOpen => {
168                let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
169                if count >= self.config.success_threshold {
170                    // Transition back to closed
171                    self.state
172                        .store(CircuitState::Closed as u8, Ordering::SeqCst);
173                    self.failure_count.store(0, Ordering::SeqCst);
174                    self.success_count.store(0, Ordering::SeqCst);
175                    self.open_since_ms.store(0, Ordering::SeqCst);
176                    tracing::info!(
177                        "[circuit-breaker] {}: circuit closed (recovered)",
178                        self.name
179                    );
180                }
181            }
182            CircuitState::Open => {
183                // Shouldn't happen, but handle gracefully
184            }
185        }
186    }
187
188    /// Record a failure
189    pub fn record_failure(&self) {
190        let state = self.state();
191
192        match state {
193            CircuitState::Closed => {
194                let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
195                if count >= self.config.failure_threshold {
196                    // Transition to open and record the time
197                    self.state.store(CircuitState::Open as u8, Ordering::SeqCst);
198                    self.open_since_ms.store(
199                        Instant::now().elapsed().as_millis().try_into().unwrap_or(0),
200                        Ordering::SeqCst,
201                    );
202                    tracing::warn!(
203                        "[circuit-breaker] {}: circuit opened ({} failures)",
204                        self.name,
205                        count
206                    );
207                }
208            }
209            CircuitState::HalfOpen => {
210                // Any failure in half-open goes back to open
211                self.state.store(CircuitState::Open as u8, Ordering::SeqCst);
212                self.success_count.store(0, Ordering::SeqCst);
213                tracing::warn!(
214                    "[circuit-breaker] {}: circuit reopened (failure in half-open)",
215                    self.name
216                );
217            }
218            CircuitState::Open => {
219                // Already open, nothing to do
220            }
221        }
222    }
223
224    /// Check if a request should be allowed
225    pub fn can_request(&self) -> bool {
226        let state = self.state();
227        match state {
228            CircuitState::Closed | CircuitState::HalfOpen => true,
229            CircuitState::Open => false,
230        }
231    }
232
233    /// Execute an async operation with circuit breaker protection
234    ///
235    /// Returns `CircuitResult::Rejected` if the circuit is open.
236    pub async fn execute<F, T, E>(&self, operation: F) -> CircuitResult<T>
237    where
238        F: std::future::Future<Output = Result<T, E>>,
239        E: std::fmt::Display,
240    {
241        let state = self.state();
242
243        match state {
244            CircuitState::Closed => match operation.await {
245                Ok(result) => {
246                    self.record_success();
247                    CircuitResult::Success(result)
248                }
249                Err(e) => {
250                    self.record_failure();
251                    CircuitResult::Failure(e.to_string())
252                }
253            },
254            CircuitState::Open => CircuitResult::Rejected(format!(
255                "circuit is open for {} (source may be temporarily unavailable)",
256                self.name
257            )),
258            CircuitState::HalfOpen => {
259                // Allow one request to test recovery
260                match operation.await {
261                    Ok(_result) => {
262                        self.record_success();
263                        CircuitResult::RetryAllowed("half-open: success".to_string())
264                    }
265                    Err(e) => {
266                        self.record_failure();
267                        CircuitResult::Failure(e.to_string())
268                    }
269                }
270            }
271        }
272    }
273
274    /// Reset the circuit breaker to closed state
275    pub fn reset(&self) {
276        self.state
277            .store(CircuitState::Closed as u8, Ordering::SeqCst);
278        self.failure_count.store(0, Ordering::SeqCst);
279        self.success_count.store(0, Ordering::SeqCst);
280        self.open_since_ms.store(0, Ordering::SeqCst);
281    }
282}
283
284impl TryFrom<u8> for CircuitState {
285    type Error = ();
286
287    fn try_from(value: u8) -> Result<Self, Self::Error> {
288        match value {
289            0 => Ok(CircuitState::Closed),
290            1 => Ok(CircuitState::Open),
291            2 => Ok(CircuitState::HalfOpen),
292            _ => Err(()),
293        }
294    }
295}
296
297/// Manager for multiple circuit breakers (one per source)
298#[derive(Debug, Default)]
299pub struct CircuitBreakerManager {
300    breakers: Arc<std::sync::RwLock<std::collections::HashMap<String, Arc<CircuitBreaker>>>>,
301}
302
303impl CircuitBreakerManager {
304    /// Create a new manager
305    pub fn new() -> Self {
306        Self {
307            breakers: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
308        }
309    }
310
311    /// Get or create a circuit breaker for a source
312    pub fn get(&self, source_id: &str) -> Arc<CircuitBreaker> {
313        {
314            let read_guard = self.breakers.read().expect("RwLock poisoned");
315            if let Some(breaker) = read_guard.get(source_id) {
316                return Arc::clone(breaker);
317            }
318        }
319
320        {
321            let mut write_guard = self.breakers.write().expect("RwLock poisoned");
322            // Double-check after acquiring write lock
323            if let Some(breaker) = write_guard.get(source_id) {
324                return Arc::clone(breaker);
325            }
326
327            let breaker = Arc::new(CircuitBreaker::default_for(source_id));
328            write_guard.insert(source_id.to_string(), Arc::clone(&breaker));
329            breaker
330        }
331    }
332
333    /// Reset all circuit breakers
334    pub fn reset_all(&self) {
335        let guard = self.breakers.write().expect("RwLock poisoned");
336        for breaker in guard.values() {
337            breaker.reset();
338        }
339    }
340
341    /// Get status of all circuit breakers
342    pub fn status(&self) -> Vec<(String, CircuitState, bool)> {
343        let guard = self.breakers.read().expect("RwLock poisoned");
344        guard
345            .iter()
346            .map(|(name, breaker)| (name.clone(), breaker.state(), breaker.can_request()))
347            .collect()
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a shared breaker named "test" that stays open for 60 seconds.
    fn shared_breaker(failure_threshold: usize) -> Arc<CircuitBreaker> {
        Arc::new(CircuitBreaker::new(
            "test",
            failure_threshold,
            Duration::from_secs(60),
        ))
    }

    #[tokio::test]
    async fn test_circuit_breaker_closed_by_default() {
        let breaker = CircuitBreaker::default_for("test");

        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.can_request());
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let breaker = shared_breaker(3);

        // Two failures stay below the threshold of three.
        for _ in 0..2 {
            breaker.record_failure();
        }
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.can_request());

        // The third failure trips the breaker.
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.can_request());
    }

    #[tokio::test]
    async fn test_circuit_breaker_success_resets() {
        let breaker = shared_breaker(3);

        for _ in 0..2 {
            breaker.record_failure();
        }
        assert_eq!(breaker.failure_count.load(Ordering::SeqCst), 2);

        // A success while closed clears the failure streak.
        breaker.record_success();
        assert_eq!(breaker.failure_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_circuit_breaker_execute_success() {
        let breaker = shared_breaker(3);

        let outcome = breaker.execute(async { Ok::<i32, &str>(42) }).await;

        assert!(outcome.is_success());
        assert_eq!(outcome.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_circuit_breaker_execute_rejected() {
        let breaker = shared_breaker(1);

        // With a threshold of one, a single failure opens the circuit.
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);

        // While open, the operation must be rejected.
        let outcome = breaker.execute(async { Ok::<i32, &str>(42) }).await;
        assert!(outcome.is_rejected());
    }

    #[test]
    fn test_manager() {
        let manager = CircuitBreakerManager::new();

        let first = manager.get("source1");
        let second = manager.get("source2");
        let first_again = manager.get("source1");

        // The same source id yields the same shared instance.
        assert!(Arc::ptr_eq(&first, &first_again));
        // Distinct source ids yield distinct breakers.
        assert!(!Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_manager_status() {
        let manager = CircuitBreakerManager::new();

        for source in ["arxiv", "semantic"] {
            let _ = manager.get(source);
        }

        let status = manager.status();
        assert_eq!(status.len(), 2);
        assert!(status
            .iter()
            .all(|(_, state, _)| *state == CircuitState::Closed));
    }
}