//! Circuit breaker for PostgreSQL operations.
//!
//! Tracks recent failures in a sliding window and fails fast while the
//! database is considered unhealthy, probing for recovery after a timeout.

#![allow(clippy::wildcard_in_or_patterns)]
#![allow(clippy::significant_drop_tightening)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::significant_drop_in_scrutinee)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::float_cmp)]

use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use std::time::{Duration, Instant};

use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::{debug, error, instrument, warn};

use crate::PostgresError;

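/// State of the circuit breaker.
///
/// Stored as a `u8` so it can be kept in an [`AtomicU8`] and read without locking.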
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum CircuitState {
    /// Requests flow through normally; failures are being counted.
    Closed = 0,
    /// Requests are rejected immediately until the recovery timeout elapses.
    Open = 1,
    /// Trial requests are allowed through to probe whether the backend has recovered.
    HalfOpen = 2,
}

impl From<u8> for CircuitState {
    fn from(value: u8) -> Self {
        match value {
            0 => Self::Closed,
            2 => Self::HalfOpen,
            1 | _ => Self::Open,
        }
    }
}

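/// Configuration for a [`CircuitBreaker`].
///
/// The defaults open the circuit after 5 failures out of at least 10 requests
/// within a 60-second window, and probe recovery after 30 seconds.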
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
    /// Number of failures within the rolling window required to open the circuit.
    pub failure_threshold: u64,
    /// Number of successes recorded while half-open required to close the circuit.
    pub success_threshold: u64,
    /// How long the circuit stays open before a trial request is allowed through.
    pub timeout_duration: Duration,
    /// Length of the sliding window over which requests and failures are counted.
    pub rolling_window: Duration,
    /// Minimum number of requests in the window before the circuit is allowed to open.
    pub minimum_requests: u64,
}

impl Default for CircuitBreakerConfig {
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 3,
            timeout_duration: Duration::from_secs(30),
            rolling_window: Duration::from_secs(60),
            minimum_requests: 10,
        }
    }
}

impl CircuitBreakerConfig {
    /// Preset that opens the circuit quickly and recovers cautiously.
    pub const fn conservative() -> Self {
        Self {
            failure_threshold: 3,
            success_threshold: 5,
            timeout_duration: Duration::from_secs(60),
            rolling_window: Duration::from_secs(120),
            minimum_requests: 5,
        }
    }

    /// Preset that tolerates more failures before opening and probes recovery quickly.
    pub const fn aggressive() -> Self {
        Self {
            failure_threshold: 10,
            success_threshold: 2,
            timeout_duration: Duration::from_secs(10),
            rolling_window: Duration::from_secs(30),
            minimum_requests: 20,
        }
    }
}

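/// Errors returned by [`CircuitBreaker::execute`].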
#[derive(Debug, Error)]
pub enum CircuitBreakerError {
    #[error("Circuit breaker is open, operation rejected. Last failure: {last_failure:?}")]
    Open {
        last_failure: Option<String>,
    },

    #[error("Operation failed: {source}")]
    OperationFailed {
        #[source]
        source: PostgresError,
    },
}

impl From<CircuitBreakerError> for PostgresError {
    fn from(error: CircuitBreakerError) -> Self {
        match error {
            CircuitBreakerError::Open { .. } => Self::Connection(sqlx::Error::PoolClosed),
            CircuitBreakerError::OperationFailed { source } => source,
        }
    }
}

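/// Rolling record of recent request outcomes (`true` = success), used to
/// compute failure metrics over a fixed window of time.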
#[derive(Debug)]
struct SlidingWindow {
    requests: RwLock<Vec<(Instant, bool)>>,
    window_duration: Duration,
}

impl SlidingWindow {
    fn new(window_duration: Duration) -> Self {
        Self {
            requests: RwLock::new(Vec::new()),
            window_duration,
        }
    }

    async fn record_request(&self, success: bool) {
        let now = Instant::now();
        let mut requests = self.requests.write().await;

        requests.push((now, success));

        // Evict entries that have fallen outside the rolling window. `checked_sub`
        // returns `None` when the window reaches past the earliest representable
        // instant (e.g. shortly after boot); nothing is old enough to evict then.
        if let Some(cutoff) = now.checked_sub(self.window_duration) {
            requests.retain(|(timestamp, _)| *timestamp > cutoff);
        }
    }

    async fn get_metrics(&self) -> (u64, u64) {
        let now = Instant::now();
        // `None` means the window reaches past the earliest representable instant,
        // so every recorded request is still inside it.
        let cutoff = now.checked_sub(self.window_duration);
        let requests = self.requests.read().await;

        let recent_requests: Vec<_> = match cutoff {
            Some(cutoff) => requests
                .iter()
                .filter(|(timestamp, _)| *timestamp > cutoff)
                .collect(),
            None => requests.iter().collect(),
        };

        let total = recent_requests.len() as u64;
        let failures = recent_requests
            .iter()
            .filter(|(_, success)| !*success)
            .count() as u64;

        (total, failures)
    }
}

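/// A circuit breaker that guards PostgreSQL operations.
///
/// Failures are tracked in a sliding window; once the configured thresholds are
/// exceeded the circuit opens and calls fail fast until the recovery timeout
/// elapses, after which trial requests decide whether to close it again.
///
/// Illustrative usage sketch (not compiled as a doc test; `run_query` stands in
/// for any future returning `Result<T, PostgresError>`):
///
/// ```rust,ignore
/// let breaker = CircuitBreaker::new(CircuitBreakerConfig::default());
///
/// match breaker.execute(|| async { run_query().await }).await {
///     Ok(rows) => {
///         // The operation ran and succeeded; the success was recorded.
///     }
///     Err(CircuitBreakerError::Open { .. }) => {
///         // The breaker rejected the call without running the operation.
///     }
///     Err(CircuitBreakerError::OperationFailed { source }) => {
///         // The operation ran and failed; the failure was recorded.
///     }
/// }
/// ```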
#[derive(Debug)]
pub struct CircuitBreaker {
    config: CircuitBreakerConfig,
    state: AtomicU8,
    failure_count: AtomicU64,
    success_count: AtomicU64,
    last_failure_time: RwLock<Option<Instant>>,
    last_failure_reason: RwLock<Option<String>>,
    sliding_window: SlidingWindow,
}

impl CircuitBreaker {
    /// Creates a new circuit breaker with the given configuration, starting in the closed state.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            sliding_window: SlidingWindow::new(config.rolling_window),
            config,
            state: AtomicU8::new(CircuitState::Closed as u8),
            failure_count: AtomicU64::new(0),
            success_count: AtomicU64::new(0),
            last_failure_time: RwLock::new(None),
            last_failure_reason: RwLock::new(None),
        }
    }

    /// Returns the current state of the circuit.
    pub fn state(&self) -> CircuitState {
        CircuitState::from(self.state.load(Ordering::Acquire))
    }

    /// Returns a snapshot of the breaker's counters and sliding-window metrics.
    pub async fn metrics(&self) -> CircuitBreakerMetrics {
        let (total_requests, total_failures) = self.sliding_window.get_metrics().await;
        let failure_rate = if total_requests > 0 {
            total_failures as f64 / total_requests as f64
        } else {
            0.0
        };

        let last_failure_time = *self.last_failure_time.read().await;
        let last_failure_reason = self.last_failure_reason.read().await.clone();

        let last_failure_timestamp = last_failure_time.map(|instant| instant.elapsed().as_secs());

        CircuitBreakerMetrics {
            state: self.state(),
            failure_count: self.failure_count.load(Ordering::Relaxed),
            success_count: self.success_count.load(Ordering::Relaxed),
            total_requests,
            total_failures,
            failure_rate,
            last_failure_time: last_failure_timestamp,
            last_failure_reason,
        }
    }

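    /// Runs `operation` through the circuit breaker.
    ///
    /// If the circuit is open and the recovery timeout has not elapsed, the
    /// operation is not invoked and [`CircuitBreakerError::Open`] is returned.
    /// Otherwise the operation runs and its outcome is recorded, driving the
    /// transitions between the closed, open, and half-open states.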
    #[instrument(skip(self, operation))]
    pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T, CircuitBreakerError>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, PostgresError>>,
    {
        if !self.should_allow_request().await {
            let last_failure = self.last_failure_reason.read().await.clone();
            return Err(CircuitBreakerError::Open { last_failure });
        }

        match operation().await {
            Ok(result) => {
                self.record_success().await;
                Ok(result)
            }
            Err(error) => {
                let error_msg = error.to_string();
                self.record_failure(error_msg).await;
                Err(CircuitBreakerError::OperationFailed { source: error })
            }
        }
    }

    async fn should_allow_request(&self) -> bool {
        match self.state() {
            CircuitState::Closed => true,
            CircuitState::Open => {
                if let Some(last_failure) = *self.last_failure_time.read().await {
                    if last_failure.elapsed() >= self.config.timeout_duration {
                        debug!("Circuit breaker transitioning from Open to HalfOpen");
                        self.transition_to_half_open();
                        true
                    } else {
                        false
                    }
                } else {
                    false
                }
            }
            CircuitState::HalfOpen => true,
        }
    }

    async fn record_success(&self) {
        self.sliding_window.record_request(true).await;

        let current_state = self.state();
        match current_state {
            CircuitState::Closed => {
                self.failure_count.store(0, Ordering::Relaxed);
            }
            CircuitState::HalfOpen => {
                let success_count = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
                debug!("Circuit breaker half-open success count: {}", success_count);

                if success_count >= self.config.success_threshold {
                    debug!("Circuit breaker transitioning from HalfOpen to Closed");
                    self.transition_to_closed();
                }
            }
            CircuitState::Open => {
                warn!("Recorded success while circuit was open");
            }
        }
    }

    async fn record_failure(&self, error_msg: String) {
        self.sliding_window.record_request(false).await;

        *self.last_failure_time.write().await = Some(Instant::now());
        *self.last_failure_reason.write().await = Some(error_msg);

        let current_state = self.state();
        match current_state {
            CircuitState::Closed => {
                let failure_count = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
                debug!("Circuit breaker failure count: {}", failure_count);

                let (total_requests, total_failures) = self.sliding_window.get_metrics().await;

                if total_requests >= self.config.minimum_requests
                    && total_failures >= self.config.failure_threshold
                {
                    warn!(
                        "Circuit breaker opening due to failure threshold. Failures: {}/{}",
                        total_failures, total_requests
                    );
                    self.transition_to_open();
                }
            }
            CircuitState::HalfOpen => {
                debug!("Circuit breaker transitioning from HalfOpen to Open due to failure");
                self.transition_to_open();
            }
            CircuitState::Open => {
                self.failure_count.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    fn transition_to_closed(&self) {
        self.state
            .store(CircuitState::Closed as u8, Ordering::Release);
        self.failure_count.store(0, Ordering::Relaxed);
        self.success_count.store(0, Ordering::Relaxed);
        debug!("Circuit breaker state changed to Closed");
    }

    fn transition_to_open(&self) {
        self.state
            .store(CircuitState::Open as u8, Ordering::Release);
        self.success_count.store(0, Ordering::Relaxed);
        error!("Circuit breaker state changed to Open");
    }

    fn transition_to_half_open(&self) {
        self.state
            .store(CircuitState::HalfOpen as u8, Ordering::Release);
        self.success_count.store(0, Ordering::Relaxed);
        debug!("Circuit breaker state changed to HalfOpen");
    }

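    /// Manually resets the circuit to the closed state and clears the last failure details.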
    pub async fn reset(&self) {
        debug!("Manually resetting circuit breaker");
        self.transition_to_closed();
        *self.last_failure_time.write().await = None;
        *self.last_failure_reason.write().await = None;
    }

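    /// Manually forces the circuit into the open state, recording a synthetic failure reason.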
    pub async fn force_open(&self) {
        warn!("Manually forcing circuit breaker to open state");
        self.transition_to_open();
        *self.last_failure_time.write().await = Some(Instant::now());
        *self.last_failure_reason.write().await = Some("Manually forced open".to_string());
    }
}

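/// Point-in-time snapshot of a [`CircuitBreaker`]'s state and counters, as
/// returned by [`CircuitBreaker::metrics`].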
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerMetrics {
    /// Current state of the circuit.
    pub state: CircuitState,
    /// Failure counter maintained by the breaker.
    pub failure_count: u64,
    /// Success counter used while the circuit is half-open.
    pub success_count: u64,
    /// Total requests observed in the sliding window.
    pub total_requests: u64,
    /// Failed requests observed in the sliding window.
    pub total_failures: u64,
    /// Ratio of failures to total requests in the sliding window.
    pub failure_rate: f64,
    /// Seconds elapsed since the last recorded failure, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_failure_time: Option<u64>,
    /// Error message from the last recorded failure, if any.
    pub last_failure_reason: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_circuit_breaker_closed_to_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            minimum_requests: 3,
            ..CircuitBreakerConfig::default()
        };

        let breaker = CircuitBreaker::new(config);
        assert_eq!(breaker.state(), CircuitState::Closed);

        for i in 0..3 {
            let result = breaker
                .execute(|| async {
                    Err::<(), _>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
                })
                .await;

            assert!(result.is_err());

            if i < 2 {
                assert_eq!(breaker.state(), CircuitState::Closed);
            } else {
                assert_eq!(breaker.state(), CircuitState::Open);
            }
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_open_to_half_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            minimum_requests: 1,
            timeout_duration: Duration::from_millis(50),
            ..CircuitBreakerConfig::default()
        };

        let breaker = CircuitBreaker::new(config);

        let _ = breaker
            .execute(|| async {
                Err::<(), _>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
            })
            .await;

        assert_eq!(breaker.state(), CircuitState::Open);

        tokio::time::sleep(Duration::from_millis(60)).await;

        let result = breaker
            .execute(|| async { Ok::<(), PostgresError>(()) })
            .await;

        assert!(result.is_ok());
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_to_closed() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            minimum_requests: 1,
            success_threshold: 2,
            timeout_duration: Duration::from_millis(50),
            ..CircuitBreakerConfig::default()
        };

        let breaker = CircuitBreaker::new(config);

        let _ = breaker
            .execute(|| async {
                Err::<(), _>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
            })
            .await;

        tokio::time::sleep(Duration::from_millis(60)).await;

        let _ = breaker
            .execute(|| async { Ok::<(), PostgresError>(()) })
            .await;

        let _ = breaker
            .execute(|| async { Ok::<(), PostgresError>(()) })
            .await;

        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_metrics() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::default());

        let _ = breaker
            .execute(|| async { Ok::<(), PostgresError>(()) })
            .await;

        let _ = breaker
            .execute(|| async {
                Err::<(), _>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
            })
            .await;

        let metrics = breaker.metrics().await;
        assert_eq!(metrics.state, CircuitState::Closed);
        assert_eq!(metrics.total_requests, 2);
        assert_eq!(metrics.total_failures, 1);
        assert_eq!(metrics.failure_rate, 0.5);
    }

    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            minimum_requests: 1,
            ..CircuitBreakerConfig::default()
        };

        let breaker = CircuitBreaker::new(config);

        let _ = breaker
            .execute(|| async {
                Err::<(), _>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
            })
            .await;

        assert_eq!(breaker.state(), CircuitState::Open);

        breaker.reset().await;
        assert_eq!(breaker.state(), CircuitState::Closed);

        let metrics = breaker.metrics().await;
        assert!(metrics.last_failure_time.is_none());
        assert!(metrics.last_failure_reason.is_none());
    }

    #[tokio::test]
    async fn test_sliding_window() {
        let window = SlidingWindow::new(Duration::from_millis(100));

        window.record_request(true).await;
        window.record_request(false).await;
        window.record_request(true).await;

        let (total, failures) = window.get_metrics().await;
        assert_eq!(total, 3);
        assert_eq!(failures, 1);

        tokio::time::sleep(Duration::from_millis(150)).await;

        let (total, failures) = window.get_metrics().await;
        assert_eq!(total, 0);
        assert_eq!(failures, 0);
    }
}