// ai_lib/circuit_breaker/breaker.rs

1//! Circuit breaker implementation
2
use crate::circuit_breaker::{CircuitBreakerConfig, CircuitState};
use crate::metrics::Metrics;
use crate::types::AiLibError;
use futures::Future;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::time::timeout;
12
13/// Circuit breaker error types
14#[derive(Debug, thiserror::Error)]
15pub enum CircuitBreakerError {
16    #[error("Circuit breaker is open: {0}")]
17    CircuitOpen(String),
18    #[error("Request timeout: {0}")]
19    RequestTimeout(String),
20    #[error("Underlying error: {0}")]
21    Underlying(#[from] AiLibError),
22    #[error("Circuit breaker is disabled")]
23    Disabled,
24}
25
26/// Circuit breaker metrics for monitoring
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct CircuitBreakerMetrics {
29    pub state: CircuitState,
30    pub total_requests: u64,
31    pub successful_requests: u64,
32    pub failed_requests: u64,
33    pub timeout_requests: u64,
34    pub circuit_open_count: u64,
35    pub circuit_close_count: u64,
36    pub current_failure_count: u32,
37    pub current_success_count: u32,
38    #[serde(skip)]
39    pub last_failure_time: Option<Instant>,
40    #[serde(skip)]
41    pub uptime: Duration,
42}
43
44/// Circuit breaker implementation
45pub struct CircuitBreaker {
46    state: Arc<Mutex<CircuitState>>,
47    config: CircuitBreakerConfig,
48    failure_count: Arc<AtomicU32>,
49    success_count: Arc<AtomicU32>,
50    last_failure_time: Arc<Mutex<Option<Instant>>>,
51    // Metrics
52    total_requests: Arc<AtomicU64>,
53    successful_requests: Arc<AtomicU64>,
54    failed_requests: Arc<AtomicU64>,
55    timeout_requests: Arc<AtomicU64>,
56    circuit_open_count: Arc<AtomicU64>,
57    circuit_close_count: Arc<AtomicU64>,
58    start_time: Instant,
59    // Optional metrics collector
60    metrics: Option<Arc<dyn Metrics>>,
61    // Circuit breaker enabled flag
62    enabled: bool,
63}
64
65impl CircuitBreaker {
66    /// Create a new circuit breaker with the given configuration
67    pub fn new(config: CircuitBreakerConfig) -> Self {
68        Self {
69            state: Arc::new(Mutex::new(CircuitState::Closed)),
70            config,
71            failure_count: Arc::new(AtomicU32::new(0)),
72            success_count: Arc::new(AtomicU32::new(0)),
73            last_failure_time: Arc::new(Mutex::new(None)),
74            total_requests: Arc::new(AtomicU64::new(0)),
75            successful_requests: Arc::new(AtomicU64::new(0)),
76            failed_requests: Arc::new(AtomicU64::new(0)),
77            timeout_requests: Arc::new(AtomicU64::new(0)),
78            circuit_open_count: Arc::new(AtomicU64::new(0)),
79            circuit_close_count: Arc::new(AtomicU64::new(0)),
80            start_time: Instant::now(),
81            metrics: None,
82            enabled: true,
83        }
84    }
85
86    /// Create a new circuit breaker with metrics collection
87    pub fn with_metrics(config: CircuitBreakerConfig, metrics: Arc<dyn Metrics>) -> Self {
88        Self {
89            state: Arc::new(Mutex::new(CircuitState::Closed)),
90            config,
91            failure_count: Arc::new(AtomicU32::new(0)),
92            success_count: Arc::new(AtomicU32::new(0)),
93            last_failure_time: Arc::new(Mutex::new(None)),
94            total_requests: Arc::new(AtomicU64::new(0)),
95            successful_requests: Arc::new(AtomicU64::new(0)),
96            failed_requests: Arc::new(AtomicU64::new(0)),
97            timeout_requests: Arc::new(AtomicU64::new(0)),
98            circuit_open_count: Arc::new(AtomicU64::new(0)),
99            circuit_close_count: Arc::new(AtomicU64::new(0)),
100            start_time: Instant::now(),
101            metrics: Some(metrics),
102            enabled: true,
103        }
104    }
105
106    /// Enable or disable the circuit breaker
107    pub fn set_enabled(&mut self, enabled: bool) {
108        self.enabled = enabled;
109    }
110
111    /// Execute a function with circuit breaker protection
112    pub async fn call<F, T>(&self, f: F) -> Result<T, CircuitBreakerError>
113    where
114        F: Future<Output = Result<T, AiLibError>>,
115    {
116        // Check if circuit breaker is enabled
117        if !self.enabled {
118            return f.await.map_err(CircuitBreakerError::Underlying);
119        }
120
121        // Increment total requests counter
122        self.total_requests
123            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
124
125        // Check if we should allow the request
126        if !self.should_allow_request().await {
127            return Err(CircuitBreakerError::CircuitOpen(
128                "Circuit breaker is open".to_string(),
129            ));
130        }
131
132        // Execute the request with timeout
133        let result = timeout(self.config.request_timeout, f).await;
134
135        match result {
136            Ok(Ok(response)) => {
137                self.on_success().await;
138                Ok(response)
139            }
140            Ok(Err(error)) => {
141                self.on_failure().await;
142                Err(CircuitBreakerError::Underlying(error))
143            }
144            Err(_) => {
145                self.on_timeout().await;
146                Err(CircuitBreakerError::RequestTimeout(
147                    "Request timed out".to_string(),
148                ))
149            }
150        }
151    }
152
153    /// Check if the circuit breaker should allow a request
154    async fn should_allow_request(&self) -> bool {
155        let state = *self.state.lock().unwrap();
156
157        match state {
158            CircuitState::Closed => true,
159            CircuitState::Open => {
160                // Check if enough time has passed to try half-open
161                let allow_half_open = {
162                    let last = self.last_failure_time.lock().unwrap();
163                    last.and_then(|t| Some(t.elapsed() >= self.config.recovery_timeout))
164                        .unwrap_or(false)
165                };
166                if allow_half_open {
167                    self.transition_to_half_open().await;
168                    true
169                } else {
170                    false
171                }
172            }
173            CircuitState::HalfOpen => true,
174        }
175    }
176
177    /// Handle successful request
178    async fn on_success(&self) {
179        self.successful_requests
180            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
181
182        let mut record_closed_metric = false;
183        {
184            let mut state = self.state.lock().unwrap();
185            match *state {
186                CircuitState::Closed => {
187                    // Reset failure count on success
188                    self.failure_count
189                        .store(0, std::sync::atomic::Ordering::Relaxed);
190                }
191                CircuitState::HalfOpen => {
192                    let success_count = self
193                        .success_count
194                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
195                        + 1;
196                    if success_count >= self.config.success_threshold {
197                        *state = CircuitState::Closed;
198                        self.success_count
199                            .store(0, std::sync::atomic::Ordering::Relaxed);
200                        self.circuit_close_count
201                            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
202                        record_closed_metric = true;
203                    }
204                }
205                CircuitState::Open => {
206                    // This shouldn't happen, but handle gracefully
207                }
208            }
209        }
210        if record_closed_metric {
211            if let Some(metrics) = &self.metrics {
212                metrics.incr_counter("circuit_breaker.closed", 1).await;
213            }
214        }
215    }
216
217    /// Handle failed request
218    async fn on_failure(&self) {
219        self.failed_requests
220            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
221
222        let failure_count = self
223            .failure_count
224            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
225            + 1;
226
227        // Record failure time
228        *self.last_failure_time.lock().unwrap() = Some(Instant::now());
229
230        // Check if we should open the circuit
231        if failure_count >= self.config.failure_threshold {
232            {
233                let mut state = self.state.lock().unwrap();
234                *state = CircuitState::Open;
235            }
236            self.circuit_open_count
237                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
238
239            // Record metrics
240            if let Some(metrics) = &self.metrics {
241                let m = metrics.clone();
242                m.incr_counter("circuit_breaker.opened", 1).await;
243            }
244        }
245    }
246
247    /// Handle timeout request
248    async fn on_timeout(&self) {
249        self.timeout_requests
250            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
251        self.failed_requests
252            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
253
254        let failure_count = self
255            .failure_count
256            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
257            + 1;
258
259        // Record failure time
260        *self.last_failure_time.lock().unwrap() = Some(Instant::now());
261
262        // Check if we should open the circuit
263        if failure_count >= self.config.failure_threshold {
264            {
265                let mut state = self.state.lock().unwrap();
266                *state = CircuitState::Open;
267            }
268            self.circuit_open_count
269                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
270
271            // Record metrics
272            if let Some(metrics) = &self.metrics {
273                let m = metrics.clone();
274                m.incr_counter("circuit_breaker.opened", 1).await;
275            }
276        }
277    }
278
279    /// Transition to half-open state
280    async fn transition_to_half_open(&self) {
281        let mut state = self.state.lock().unwrap();
282        *state = CircuitState::HalfOpen;
283        self.success_count
284            .store(0, std::sync::atomic::Ordering::Relaxed);
285    }
286
287    /// Get current circuit state
288    pub fn state(&self) -> CircuitState {
289        *self.state.lock().unwrap()
290    }
291
292    /// Get current failure count
293    pub fn failure_count(&self) -> u32 {
294        self.failure_count
295            .load(std::sync::atomic::Ordering::Relaxed)
296    }
297
298    /// Get current success count
299    pub fn success_count(&self) -> u32 {
300        self.success_count
301            .load(std::sync::atomic::Ordering::Relaxed)
302    }
303
304    /// Get comprehensive metrics
305    pub fn get_metrics(&self) -> CircuitBreakerMetrics {
306        CircuitBreakerMetrics {
307            state: self.state(),
308            total_requests: self
309                .total_requests
310                .load(std::sync::atomic::Ordering::Relaxed),
311            successful_requests: self
312                .successful_requests
313                .load(std::sync::atomic::Ordering::Relaxed),
314            failed_requests: self
315                .failed_requests
316                .load(std::sync::atomic::Ordering::Relaxed),
317            timeout_requests: self
318                .timeout_requests
319                .load(std::sync::atomic::Ordering::Relaxed),
320            circuit_open_count: self
321                .circuit_open_count
322                .load(std::sync::atomic::Ordering::Relaxed),
323            circuit_close_count: self
324                .circuit_close_count
325                .load(std::sync::atomic::Ordering::Relaxed),
326            current_failure_count: self.failure_count(),
327            current_success_count: self.success_count(),
328            last_failure_time: *self.last_failure_time.lock().unwrap(),
329            uptime: self.start_time.elapsed(),
330        }
331    }
332
333    /// Get success rate as a percentage
334    pub fn success_rate(&self) -> f64 {
335        let total = self
336            .total_requests
337            .load(std::sync::atomic::Ordering::Relaxed);
338        if total == 0 {
339            return 100.0;
340        }
341        let successful = self
342            .successful_requests
343            .load(std::sync::atomic::Ordering::Relaxed);
344        (successful as f64 / total as f64) * 100.0
345    }
346
347    /// Get failure rate as a percentage
348    pub fn failure_rate(&self) -> f64 {
349        let total = self
350            .total_requests
351            .load(std::sync::atomic::Ordering::Relaxed);
352        if total == 0 {
353            return 0.0;
354        }
355        let failed = self
356            .failed_requests
357            .load(std::sync::atomic::Ordering::Relaxed);
358        (failed as f64 / total as f64) * 100.0
359    }
360
361    /// Check if circuit breaker is healthy
362    pub fn is_healthy(&self) -> bool {
363        self.state() == CircuitState::Closed && self.failure_rate() < 50.0
364    }
365
366    /// Reset all counters and state
367    pub fn reset(&self) {
368        self.failure_count
369            .store(0, std::sync::atomic::Ordering::Relaxed);
370        self.success_count
371            .store(0, std::sync::atomic::Ordering::Relaxed);
372        self.total_requests
373            .store(0, std::sync::atomic::Ordering::Relaxed);
374        self.successful_requests
375            .store(0, std::sync::atomic::Ordering::Relaxed);
376        self.failed_requests
377            .store(0, std::sync::atomic::Ordering::Relaxed);
378        self.timeout_requests
379            .store(0, std::sync::atomic::Ordering::Relaxed);
380        self.circuit_open_count
381            .store(0, std::sync::atomic::Ordering::Relaxed);
382        self.circuit_close_count
383            .store(0, std::sync::atomic::Ordering::Relaxed);
384
385        let mut state = self.state.lock().unwrap();
386        *state = CircuitState::Closed;
387
388        let mut last_failure = self.last_failure_time.lock().unwrap();
389        *last_failure = None;
390    }
391
392    /// Force circuit breaker to open state
393    pub fn force_open(&self) {
394        let mut state = self.state.lock().unwrap();
395        *state = CircuitState::Open;
396        self.circuit_open_count
397            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
398    }
399
400    /// Force circuit breaker to closed state
401    pub fn force_close(&self) {
402        let mut state = self.state.lock().unwrap();
403        *state = CircuitState::Closed;
404        self.failure_count
405            .store(0, std::sync::atomic::Ordering::Relaxed);
406        self.success_count
407            .store(0, std::sync::atomic::Ordering::Relaxed);
408        self.circuit_close_count
409            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
410    }
411}