1use crate::circuit_breaker::{CircuitBreakerConfig, CircuitState};
4use crate::metrics::Metrics;
5use crate::types::AiLibError;
6use futures::Future;
7use serde::{Deserialize, Serialize};
8use std::sync::atomic::{AtomicU32, AtomicU64};
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11use tokio::time::timeout;
12
13#[derive(Debug, thiserror::Error)]
15pub enum CircuitBreakerError {
16 #[error("Circuit breaker is open: {0}")]
17 CircuitOpen(String),
18 #[error("Request timeout: {0}")]
19 RequestTimeout(String),
20 #[error("Underlying error: {0}")]
21 Underlying(#[from] AiLibError),
22 #[error("Circuit breaker is disabled")]
23 Disabled,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct CircuitBreakerMetrics {
29 pub state: CircuitState,
30 pub total_requests: u64,
31 pub successful_requests: u64,
32 pub failed_requests: u64,
33 pub timeout_requests: u64,
34 pub circuit_open_count: u64,
35 pub circuit_close_count: u64,
36 pub current_failure_count: u32,
37 pub current_success_count: u32,
38 #[serde(skip)]
39 pub last_failure_time: Option<Instant>,
40 #[serde(skip)]
41 pub uptime: Duration,
42}
43
44pub struct CircuitBreaker {
46 state: Arc<Mutex<CircuitState>>,
47 config: CircuitBreakerConfig,
48 failure_count: Arc<AtomicU32>,
49 success_count: Arc<AtomicU32>,
50 last_failure_time: Arc<Mutex<Option<Instant>>>,
51 total_requests: Arc<AtomicU64>,
53 successful_requests: Arc<AtomicU64>,
54 failed_requests: Arc<AtomicU64>,
55 timeout_requests: Arc<AtomicU64>,
56 circuit_open_count: Arc<AtomicU64>,
57 circuit_close_count: Arc<AtomicU64>,
58 start_time: Instant,
59 metrics: Option<Arc<dyn Metrics>>,
61 enabled: bool,
63}
64
65impl CircuitBreaker {
66 pub fn new(config: CircuitBreakerConfig) -> Self {
68 Self {
69 state: Arc::new(Mutex::new(CircuitState::Closed)),
70 config,
71 failure_count: Arc::new(AtomicU32::new(0)),
72 success_count: Arc::new(AtomicU32::new(0)),
73 last_failure_time: Arc::new(Mutex::new(None)),
74 total_requests: Arc::new(AtomicU64::new(0)),
75 successful_requests: Arc::new(AtomicU64::new(0)),
76 failed_requests: Arc::new(AtomicU64::new(0)),
77 timeout_requests: Arc::new(AtomicU64::new(0)),
78 circuit_open_count: Arc::new(AtomicU64::new(0)),
79 circuit_close_count: Arc::new(AtomicU64::new(0)),
80 start_time: Instant::now(),
81 metrics: None,
82 enabled: true,
83 }
84 }
85
86 pub fn with_metrics(config: CircuitBreakerConfig, metrics: Arc<dyn Metrics>) -> Self {
88 Self {
89 state: Arc::new(Mutex::new(CircuitState::Closed)),
90 config,
91 failure_count: Arc::new(AtomicU32::new(0)),
92 success_count: Arc::new(AtomicU32::new(0)),
93 last_failure_time: Arc::new(Mutex::new(None)),
94 total_requests: Arc::new(AtomicU64::new(0)),
95 successful_requests: Arc::new(AtomicU64::new(0)),
96 failed_requests: Arc::new(AtomicU64::new(0)),
97 timeout_requests: Arc::new(AtomicU64::new(0)),
98 circuit_open_count: Arc::new(AtomicU64::new(0)),
99 circuit_close_count: Arc::new(AtomicU64::new(0)),
100 start_time: Instant::now(),
101 metrics: Some(metrics),
102 enabled: true,
103 }
104 }
105
106 pub fn set_enabled(&mut self, enabled: bool) {
108 self.enabled = enabled;
109 }
110
111 pub async fn call<F, T>(&self, f: F) -> Result<T, CircuitBreakerError>
113 where
114 F: Future<Output = Result<T, AiLibError>>,
115 {
116 if !self.enabled {
118 return f.await.map_err(CircuitBreakerError::Underlying);
119 }
120
121 self.total_requests
123 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
124
125 if !self.should_allow_request().await {
127 return Err(CircuitBreakerError::CircuitOpen(
128 "Circuit breaker is open".to_string(),
129 ));
130 }
131
132 let result = timeout(self.config.request_timeout, f).await;
134
135 match result {
136 Ok(Ok(response)) => {
137 self.on_success().await;
138 Ok(response)
139 }
140 Ok(Err(error)) => {
141 self.on_failure().await;
142 Err(CircuitBreakerError::Underlying(error))
143 }
144 Err(_) => {
145 self.on_timeout().await;
146 Err(CircuitBreakerError::RequestTimeout(
147 "Request timed out".to_string(),
148 ))
149 }
150 }
151 }
152
153 async fn should_allow_request(&self) -> bool {
155 let state = *self.state.lock().unwrap();
156
157 match state {
158 CircuitState::Closed => true,
159 CircuitState::Open => {
160 let allow_half_open = {
162 let last = self.last_failure_time.lock().unwrap();
163 last.and_then(|t| Some(t.elapsed() >= self.config.recovery_timeout))
164 .unwrap_or(false)
165 };
166 if allow_half_open {
167 self.transition_to_half_open().await;
168 true
169 } else {
170 false
171 }
172 }
173 CircuitState::HalfOpen => true,
174 }
175 }
176
177 async fn on_success(&self) {
179 self.successful_requests
180 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
181
182 let mut record_closed_metric = false;
183 {
184 let mut state = self.state.lock().unwrap();
185 match *state {
186 CircuitState::Closed => {
187 self.failure_count
189 .store(0, std::sync::atomic::Ordering::Relaxed);
190 }
191 CircuitState::HalfOpen => {
192 let success_count = self
193 .success_count
194 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
195 + 1;
196 if success_count >= self.config.success_threshold {
197 *state = CircuitState::Closed;
198 self.success_count
199 .store(0, std::sync::atomic::Ordering::Relaxed);
200 self.circuit_close_count
201 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
202 record_closed_metric = true;
203 }
204 }
205 CircuitState::Open => {
206 }
208 }
209 }
210 if record_closed_metric {
211 if let Some(metrics) = &self.metrics {
212 metrics.incr_counter("circuit_breaker.closed", 1).await;
213 }
214 }
215 }
216
217 async fn on_failure(&self) {
219 self.failed_requests
220 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
221
222 let failure_count = self
223 .failure_count
224 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
225 + 1;
226
227 *self.last_failure_time.lock().unwrap() = Some(Instant::now());
229
230 if failure_count >= self.config.failure_threshold {
232 {
233 let mut state = self.state.lock().unwrap();
234 *state = CircuitState::Open;
235 }
236 self.circuit_open_count
237 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
238
239 if let Some(metrics) = &self.metrics {
241 let m = metrics.clone();
242 m.incr_counter("circuit_breaker.opened", 1).await;
243 }
244 }
245 }
246
247 async fn on_timeout(&self) {
249 self.timeout_requests
250 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
251 self.failed_requests
252 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
253
254 let failure_count = self
255 .failure_count
256 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
257 + 1;
258
259 *self.last_failure_time.lock().unwrap() = Some(Instant::now());
261
262 if failure_count >= self.config.failure_threshold {
264 {
265 let mut state = self.state.lock().unwrap();
266 *state = CircuitState::Open;
267 }
268 self.circuit_open_count
269 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
270
271 if let Some(metrics) = &self.metrics {
273 let m = metrics.clone();
274 m.incr_counter("circuit_breaker.opened", 1).await;
275 }
276 }
277 }
278
279 async fn transition_to_half_open(&self) {
281 let mut state = self.state.lock().unwrap();
282 *state = CircuitState::HalfOpen;
283 self.success_count
284 .store(0, std::sync::atomic::Ordering::Relaxed);
285 }
286
287 pub fn state(&self) -> CircuitState {
289 *self.state.lock().unwrap()
290 }
291
292 pub fn failure_count(&self) -> u32 {
294 self.failure_count
295 .load(std::sync::atomic::Ordering::Relaxed)
296 }
297
298 pub fn success_count(&self) -> u32 {
300 self.success_count
301 .load(std::sync::atomic::Ordering::Relaxed)
302 }
303
304 pub fn get_metrics(&self) -> CircuitBreakerMetrics {
306 CircuitBreakerMetrics {
307 state: self.state(),
308 total_requests: self
309 .total_requests
310 .load(std::sync::atomic::Ordering::Relaxed),
311 successful_requests: self
312 .successful_requests
313 .load(std::sync::atomic::Ordering::Relaxed),
314 failed_requests: self
315 .failed_requests
316 .load(std::sync::atomic::Ordering::Relaxed),
317 timeout_requests: self
318 .timeout_requests
319 .load(std::sync::atomic::Ordering::Relaxed),
320 circuit_open_count: self
321 .circuit_open_count
322 .load(std::sync::atomic::Ordering::Relaxed),
323 circuit_close_count: self
324 .circuit_close_count
325 .load(std::sync::atomic::Ordering::Relaxed),
326 current_failure_count: self.failure_count(),
327 current_success_count: self.success_count(),
328 last_failure_time: *self.last_failure_time.lock().unwrap(),
329 uptime: self.start_time.elapsed(),
330 }
331 }
332
333 pub fn success_rate(&self) -> f64 {
335 let total = self
336 .total_requests
337 .load(std::sync::atomic::Ordering::Relaxed);
338 if total == 0 {
339 return 100.0;
340 }
341 let successful = self
342 .successful_requests
343 .load(std::sync::atomic::Ordering::Relaxed);
344 (successful as f64 / total as f64) * 100.0
345 }
346
347 pub fn failure_rate(&self) -> f64 {
349 let total = self
350 .total_requests
351 .load(std::sync::atomic::Ordering::Relaxed);
352 if total == 0 {
353 return 0.0;
354 }
355 let failed = self
356 .failed_requests
357 .load(std::sync::atomic::Ordering::Relaxed);
358 (failed as f64 / total as f64) * 100.0
359 }
360
361 pub fn is_healthy(&self) -> bool {
363 self.state() == CircuitState::Closed && self.failure_rate() < 50.0
364 }
365
366 pub fn reset(&self) {
368 self.failure_count
369 .store(0, std::sync::atomic::Ordering::Relaxed);
370 self.success_count
371 .store(0, std::sync::atomic::Ordering::Relaxed);
372 self.total_requests
373 .store(0, std::sync::atomic::Ordering::Relaxed);
374 self.successful_requests
375 .store(0, std::sync::atomic::Ordering::Relaxed);
376 self.failed_requests
377 .store(0, std::sync::atomic::Ordering::Relaxed);
378 self.timeout_requests
379 .store(0, std::sync::atomic::Ordering::Relaxed);
380 self.circuit_open_count
381 .store(0, std::sync::atomic::Ordering::Relaxed);
382 self.circuit_close_count
383 .store(0, std::sync::atomic::Ordering::Relaxed);
384
385 let mut state = self.state.lock().unwrap();
386 *state = CircuitState::Closed;
387
388 let mut last_failure = self.last_failure_time.lock().unwrap();
389 *last_failure = None;
390 }
391
392 pub fn force_open(&self) {
394 let mut state = self.state.lock().unwrap();
395 *state = CircuitState::Open;
396 self.circuit_open_count
397 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
398 }
399
400 pub fn force_close(&self) {
402 let mut state = self.state.lock().unwrap();
403 *state = CircuitState::Closed;
404 self.failure_count
405 .store(0, std::sync::atomic::Ordering::Relaxed);
406 self.success_count
407 .store(0, std::sync::atomic::Ordering::Relaxed);
408 self.circuit_close_count
409 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
410 }
411}