chie_core/
circuit_breaker.rs

1//! Circuit breaker pattern for resilient external service calls.
2//!
3//! The circuit breaker prevents cascading failures by stopping requests to failing services
4//! and allowing them time to recover.
5//!
6//! # States
7//!
8//! - **Closed**: Normal operation, requests flow through
9//! - **Open**: Service is failing, requests are rejected immediately
10//! - **HalfOpen**: Testing if service has recovered
11//!
12//! # Example
13//!
14//! ```
15//! use chie_core::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
16//!
17//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
18//! let config = CircuitBreakerConfig::default();
19//! let mut breaker = CircuitBreaker::new("api-service", config);
20//!
21//! // Try calling a service
22//! match breaker.call(|| async {
23//!     // Make API call here
24//!     Ok::<_, String>("success")
25//! }).await {
26//!     Ok(result) => println!("Success: {}", result),
27//!     Err(e) => eprintln!("Failed: {}", e),
28//! }
29//! # Ok(())
30//! # }
31//! ```
32
33use std::sync::Arc;
34use std::time::{Duration, Instant};
35use tokio::sync::RwLock;
36
/// The three states a circuit breaker can be in.
///
/// Transitions: `Closed` -> `Open` when the failure threshold is reached,
/// `Open` -> `HalfOpen` once the timeout has elapsed, and `HalfOpen` -> `Closed`
/// after enough successes (or back to `Open` on any failure).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: requests are forwarded to the service.
    Closed,
    /// Tripped: requests are rejected without contacting the service.
    Open,
    /// Probing: requests are forwarded to check whether the service recovered.
    HalfOpen,
}
47
/// Tuning parameters for a circuit breaker.
///
/// `Default` yields 5 failures, a 60 second timeout and 2 successes.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures in the closed state that trip the circuit open
    /// (any success while closed resets the count).
    pub failure_threshold: u32,
    /// How long the circuit stays open, measured from the most recently
    /// recorded failure, before a trial call is admitted in the half-open state.
    pub timeout: Duration,
    /// Successful trial calls needed in the half-open state to close the
    /// circuit again.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Conservative defaults: trip after 5 failures, wait 60 s, close after 2 successes.
    #[inline]
    fn default() -> Self {
        let failure_threshold = 5;
        let success_threshold = 2;
        let timeout = Duration::from_secs(60);
        Self {
            failure_threshold,
            timeout,
            success_threshold,
        }
    }
}
69
70/// Circuit breaker for resilient service calls.
71pub struct CircuitBreaker {
72    name: String,
73    config: CircuitBreakerConfig,
74    state: Arc<RwLock<CircuitBreakerState>>,
75}
76
77#[derive(Debug)]
78struct CircuitBreakerState {
79    circuit_state: CircuitState,
80    failure_count: u32,
81    success_count: u32,
82    last_failure_time: Option<Instant>,
83}
84
85impl CircuitBreaker {
86    /// Create a new circuit breaker.
87    #[must_use]
88    pub fn new(name: impl Into<String>, config: CircuitBreakerConfig) -> Self {
89        Self {
90            name: name.into(),
91            config,
92            state: Arc::new(RwLock::new(CircuitBreakerState {
93                circuit_state: CircuitState::Closed,
94                failure_count: 0,
95                success_count: 0,
96                last_failure_time: None,
97            })),
98        }
99    }
100
101    /// Get the current state of the circuit breaker.
102    #[must_use]
103    pub async fn state(&self) -> CircuitState {
104        self.state.read().await.circuit_state
105    }
106
107    /// Get the name of the circuit breaker.
108    #[must_use]
109    #[inline]
110    pub fn name(&self) -> &str {
111        &self.name
112    }
113
114    /// Execute a function with circuit breaker protection.
115    pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitBreakerError<E>>
116    where
117        F: FnOnce() -> Fut,
118        Fut: std::future::Future<Output = Result<T, E>>,
119    {
120        // Check if we should attempt the call
121        if let Err(CircuitBreakerError::CircuitOpen) = self.check_state::<E>().await {
122            return Err(CircuitBreakerError::CircuitOpen);
123        }
124
125        // Execute the function
126        match f().await {
127            Ok(result) => {
128                self.on_success().await;
129                Ok(result)
130            }
131            Err(e) => {
132                self.on_failure().await;
133                Err(CircuitBreakerError::CallFailed(e))
134            }
135        }
136    }
137
138    /// Check if we can make a call based on current state.
139    async fn check_state<E>(&self) -> Result<(), CircuitBreakerError<E>> {
140        let mut state = self.state.write().await;
141
142        match state.circuit_state {
143            CircuitState::Closed => Ok(()),
144            CircuitState::Open => {
145                // Check if timeout has elapsed
146                if let Some(last_failure) = state.last_failure_time {
147                    if last_failure.elapsed() >= self.config.timeout {
148                        // Transition to half-open
149                        state.circuit_state = CircuitState::HalfOpen;
150                        state.success_count = 0;
151                        tracing::info!(
152                            circuit_breaker = %self.name,
153                            "Circuit breaker transitioning to half-open"
154                        );
155                        Ok(())
156                    } else {
157                        Err(CircuitBreakerError::CircuitOpen)
158                    }
159                } else {
160                    Err(CircuitBreakerError::CircuitOpen)
161                }
162            }
163            CircuitState::HalfOpen => Ok(()),
164        }
165    }
166
167    /// Handle successful call.
168    async fn on_success(&self) {
169        let mut state = self.state.write().await;
170
171        match state.circuit_state {
172            CircuitState::Closed => {
173                // Reset failure count on success
174                state.failure_count = 0;
175            }
176            CircuitState::HalfOpen => {
177                state.success_count += 1;
178                if state.success_count >= self.config.success_threshold {
179                    // Close the circuit
180                    state.circuit_state = CircuitState::Closed;
181                    state.failure_count = 0;
182                    state.success_count = 0;
183                    tracing::info!(
184                        circuit_breaker = %self.name,
185                        "Circuit breaker closed after recovery"
186                    );
187                }
188            }
189            CircuitState::Open => {}
190        }
191    }
192
193    /// Handle failed call.
194    async fn on_failure(&self) {
195        let mut state = self.state.write().await;
196
197        match state.circuit_state {
198            CircuitState::Closed => {
199                state.failure_count += 1;
200                if state.failure_count >= self.config.failure_threshold {
201                    // Open the circuit
202                    state.circuit_state = CircuitState::Open;
203                    state.last_failure_time = Some(Instant::now());
204                    tracing::warn!(
205                        circuit_breaker = %self.name,
206                        failures = state.failure_count,
207                        "Circuit breaker opened due to failures"
208                    );
209                }
210            }
211            CircuitState::HalfOpen => {
212                // Failed during half-open, go back to open
213                state.circuit_state = CircuitState::Open;
214                state.last_failure_time = Some(Instant::now());
215                state.success_count = 0;
216                tracing::warn!(
217                    circuit_breaker = %self.name,
218                    "Circuit breaker reopened after failed recovery attempt"
219                );
220            }
221            CircuitState::Open => {
222                // Update last failure time
223                state.last_failure_time = Some(Instant::now());
224            }
225        }
226    }
227
228    /// Manually reset the circuit breaker to closed state.
229    pub async fn reset(&self) {
230        let mut state = self.state.write().await;
231        state.circuit_state = CircuitState::Closed;
232        state.failure_count = 0;
233        state.success_count = 0;
234        state.last_failure_time = None;
235        tracing::info!(circuit_breaker = %self.name, "Circuit breaker manually reset");
236    }
237}
238
239/// Circuit breaker error types.
240#[derive(Debug, thiserror::Error)]
241pub enum CircuitBreakerError<E> {
242    /// The circuit is open and not accepting requests
243    #[error("Circuit breaker is open")]
244    CircuitOpen,
245    /// The underlying call failed
246    #[error("Call failed: {0}")]
247    CallFailed(E),
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_circuit_breaker_closed_state() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            timeout: Duration::from_secs(1),
            success_threshold: 2,
        };
        let breaker = CircuitBreaker::new("test", config);

        assert_eq!(breaker.name(), "test");
        assert_eq!(breaker.state().await, CircuitState::Closed);

        // A successful call passes its value through and keeps the circuit closed.
        let result = breaker.call(|| async { Ok::<_, String>("success") }).await;
        assert!(matches!(result, Ok("success")));
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            timeout: Duration::from_secs(1),
            success_threshold: 2,
        };
        let breaker = CircuitBreaker::new("test", config);

        // The circuit must stay closed until the third failure.
        for i in 0..3 {
            assert_eq!(
                breaker.state().await,
                CircuitState::Closed,
                "opened early before failure {}",
                i
            );
            let result = breaker.call(|| async { Err::<(), _>("error") }).await;
            assert!(matches!(result, Err(CircuitBreakerError::CallFailed("error"))));
        }

        assert_eq!(breaker.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_rejects_when_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_secs(10),
            success_threshold: 2,
        };
        let breaker = CircuitBreaker::new("test", config);

        // Open the circuit
        for _ in 0..2 {
            let _ = breaker.call(|| async { Err::<(), _>("error") }).await;
        }

        // The next call is rejected and the operation is never invoked.
        let invoked = std::cell::Cell::new(false);
        let result = breaker
            .call(|| {
                invoked.set(true);
                async { Ok::<_, String>("success") }
            })
            .await;
        assert!(matches!(result, Err(CircuitBreakerError::CircuitOpen)));
        assert!(!invoked.get(), "operation must not run while the circuit is open");
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_millis(100),
            success_threshold: 2,
        };
        let breaker = CircuitBreaker::new("test", config);

        // Open the circuit
        for _ in 0..2 {
            let _ = breaker.call(|| async { Err::<(), _>("error") }).await;
        }

        assert_eq!(breaker.state().await, CircuitState::Open);

        // Wait for timeout
        tokio::time::sleep(Duration::from_millis(150)).await;

        // The probe is admitted; one success is below success_threshold (2),
        // so the breaker must still be half-open.
        let result = breaker.call(|| async { Ok::<_, String>("success") }).await;
        assert!(matches!(result, Ok("success")));
        assert_eq!(breaker.state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_breaker_recovery() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_millis(100),
            success_threshold: 2,
        };
        let breaker = CircuitBreaker::new("test", config);

        // Open the circuit
        for _ in 0..2 {
            let _ = breaker.call(|| async { Err::<(), _>("error") }).await;
        }

        // Wait for timeout
        tokio::time::sleep(Duration::from_millis(150)).await;

        // Make 2 successful calls to close circuit
        for _ in 0..2 {
            let result = breaker.call(|| async { Ok::<_, String>("success") }).await;
            assert!(matches!(result, Ok("success")));
        }

        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_failure_reopens() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_millis(100),
            success_threshold: 2,
        };
        let breaker = CircuitBreaker::new("test", config);

        for _ in 0..2 {
            let _ = breaker.call(|| async { Err::<(), _>("error") }).await;
        }
        tokio::time::sleep(Duration::from_millis(150)).await;

        // The probe is admitted but fails, so the circuit reopens immediately.
        let result = breaker.call(|| async { Err::<(), _>("still down") }).await;
        assert!(matches!(result, Err(CircuitBreakerError::CallFailed("still down"))));
        assert_eq!(breaker.state().await, CircuitState::Open);

        // The timeout restarted, so the next call is rejected.
        let result = breaker.call(|| async { Ok::<_, &str>("success") }).await;
        assert!(matches!(result, Err(CircuitBreakerError::CircuitOpen)));
    }

    #[tokio::test]
    async fn test_circuit_breaker_success_resets_failure_count() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            timeout: Duration::from_secs(10),
            success_threshold: 2,
        };
        let breaker = CircuitBreaker::new("test", config);

        for _ in 0..2 {
            let _ = breaker.call(|| async { Err::<(), _>("error") }).await;
        }
        // A success while closed clears the consecutive-failure count...
        let _ = breaker.call(|| async { Ok::<_, &str>(()) }).await;
        for _ in 0..2 {
            let _ = breaker.call(|| async { Err::<(), _>("error") }).await;
        }
        assert_eq!(breaker.state().await, CircuitState::Closed);

        // ...so only the third consecutive failure opens the circuit.
        let _ = breaker.call(|| async { Err::<(), _>("error") }).await;
        assert_eq!(breaker.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_secs(10),
            success_threshold: 2,
        };
        let breaker = CircuitBreaker::new("test", config);

        // Open the circuit
        for _ in 0..2 {
            let _ = breaker.call(|| async { Err::<(), _>("error") }).await;
        }

        assert_eq!(breaker.state().await, CircuitState::Open);

        // Reset manually; calls flow again immediately.
        breaker.reset().await;
        assert_eq!(breaker.state().await, CircuitState::Closed);
        let result = breaker.call(|| async { Ok::<_, String>(1) }).await;
        assert!(matches!(result, Ok(1)));
    }

    #[test]
    fn test_circuit_breaker_error_display() {
        let open: CircuitBreakerError<String> = CircuitBreakerError::CircuitOpen;
        assert_eq!(open.to_string(), "Circuit breaker is open");

        let failed = CircuitBreakerError::CallFailed("boom");
        assert_eq!(failed.to_string(), "Call failed: boom");
    }
}