eventcore_postgres/
circuit_breaker.rs

//! Circuit breaker pattern implementation for `PostgreSQL` operations
//!
//! This module provides resilient circuit breaker functionality to prevent
//! cascading failures when database operations are failing consistently.
6#![allow(clippy::wildcard_in_or_patterns)]
7#![allow(clippy::significant_drop_tightening)]
8#![allow(clippy::cast_precision_loss)]
9#![allow(clippy::match_same_arms)]
10#![allow(clippy::significant_drop_in_scrutinee)]
11#![allow(clippy::option_if_let_else)]
12#![allow(clippy::float_cmp)]
13
14use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
15use std::time::{Duration, Instant};
16
17use serde::{Deserialize, Serialize};
18use thiserror::Error;
19use tokio::sync::RwLock;
20use tracing::{debug, error, instrument, warn};
21
22use crate::PostgresError;
23
/// Circuit breaker states
///
/// Stored inside the breaker as an `AtomicU8` (hence the explicit
/// `#[repr(u8)]` and discriminant values); converted back via `From<u8>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum CircuitState {
    /// Circuit is closed, operations flow normally
    Closed = 0,
    /// Circuit is open, operations are rejected immediately
    Open = 1,
    /// Circuit is half-open, testing if service has recovered
    HalfOpen = 2,
}
35
36impl From<u8> for CircuitState {
37    fn from(value: u8) -> Self {
38        match value {
39            0 => Self::Closed,
40            2 => Self::HalfOpen,
41            1 | _ => Self::Open, // Default to safest state
42        }
43    }
44}
45
/// Circuit breaker configuration
///
/// Tune per operation criticality; `conservative()` and `aggressive()`
/// provide ready-made presets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
    /// Number of failures required to open the circuit
    ///
    /// Compared against failures inside `rolling_window`, not a lifetime total.
    pub failure_threshold: u64,
    /// Number of requests in half-open state before deciding to close circuit
    pub success_threshold: u64,
    /// Duration to wait before transitioning from open to half-open
    pub timeout_duration: Duration,
    /// Rolling window duration for failure counting
    pub rolling_window: Duration,
    /// Minimum number of requests in window before considering failure rate
    ///
    /// Prevents a handful of early failures from tripping the circuit.
    pub minimum_requests: u64,
}
60
61impl Default for CircuitBreakerConfig {
62    fn default() -> Self {
63        Self {
64            failure_threshold: 5,
65            success_threshold: 3,
66            timeout_duration: Duration::from_secs(30),
67            rolling_window: Duration::from_secs(60),
68            minimum_requests: 10,
69        }
70    }
71}
72
73impl CircuitBreakerConfig {
74    /// Create a conservative configuration for critical operations
75    pub const fn conservative() -> Self {
76        Self {
77            failure_threshold: 3,
78            success_threshold: 5,
79            timeout_duration: Duration::from_secs(60),
80            rolling_window: Duration::from_secs(120),
81            minimum_requests: 5,
82        }
83    }
84
85    /// Create an aggressive configuration for non-critical operations
86    pub const fn aggressive() -> Self {
87        Self {
88            failure_threshold: 10,
89            success_threshold: 2,
90            timeout_duration: Duration::from_secs(10),
91            rolling_window: Duration::from_secs(30),
92            minimum_requests: 20,
93        }
94    }
95}
96
/// Errors related to circuit breaker operation
#[derive(Debug, Error)]
pub enum CircuitBreakerError {
    /// Circuit breaker is open, operation rejected
    ///
    /// The wrapped operation was never executed.
    #[error("Circuit breaker is open, operation rejected. Last failure: {last_failure:?}")]
    Open {
        /// Reason for the last failure that caused the circuit to open
        last_failure: Option<String>,
    },

    /// Operation failed and circuit breaker recorded the failure
    #[error("Operation failed: {source}")]
    OperationFailed {
        /// The underlying error that caused the operation to fail
        #[source]
        source: PostgresError,
    },
}
115
116impl From<CircuitBreakerError> for PostgresError {
117    fn from(error: CircuitBreakerError) -> Self {
118        match error {
119            CircuitBreakerError::Open { .. } => Self::Connection(sqlx::Error::PoolClosed),
120            CircuitBreakerError::OperationFailed { source } => source,
121        }
122    }
123}
124
/// Sliding window for tracking request metrics
///
/// Holds one `(timestamp, was_success)` entry per recorded request; entries
/// older than `window_duration` are pruned on each write.
#[derive(Debug)]
struct SlidingWindow {
    // (timestamp, was_success) pairs, newest appended last.
    requests: RwLock<Vec<(Instant, bool)>>, // (timestamp, was_success)
    window_duration: Duration,
}
131
132impl SlidingWindow {
133    fn new(window_duration: Duration) -> Self {
134        Self {
135            requests: RwLock::new(Vec::new()),
136            window_duration,
137        }
138    }
139
140    async fn record_request(&self, success: bool) {
141        let now = Instant::now();
142        let mut requests = self.requests.write().await;
143
144        // Add new request
145        requests.push((now, success));
146
147        // Clean old requests outside the window
148        let cutoff = now.checked_sub(self.window_duration).unwrap();
149        requests.retain(|(timestamp, _)| *timestamp > cutoff);
150    }
151
152    async fn get_metrics(&self) -> (u64, u64) {
153        let now = Instant::now();
154        let cutoff = now.checked_sub(self.window_duration).unwrap();
155        let requests = self.requests.read().await;
156
157        let recent_requests: Vec<_> = requests
158            .iter()
159            .filter(|(timestamp, _)| *timestamp > cutoff)
160            .collect();
161
162        let total = recent_requests.len() as u64;
163        let failures = recent_requests
164            .iter()
165            .filter(|(_, success)| !*success)
166            .count() as u64;
167
168        (total, failures)
169    }
170}
171
/// Circuit breaker implementation
///
/// Transitions are driven by `execute`: Closed -> Open when the
/// rolling-window failure threshold is reached, Open -> HalfOpen after
/// `timeout_duration`, HalfOpen -> Closed after `success_threshold`
/// successes (any half-open failure re-opens the circuit).
#[derive(Debug)]
pub struct CircuitBreaker {
    config: CircuitBreakerConfig,
    state: AtomicU8, // Uses CircuitState representation
    // Consecutive failures while Closed (reset on any success or close);
    // informational only — the open decision uses the sliding window.
    failure_count: AtomicU64,
    // Successes accumulated while HalfOpen; reset on every transition.
    success_count: AtomicU64,
    last_failure_time: RwLock<Option<Instant>>,
    last_failure_reason: RwLock<Option<String>>,
    sliding_window: SlidingWindow,
}
183
impl CircuitBreaker {
    /// Create a new circuit breaker with the given configuration
    ///
    /// Starts in the `Closed` state with zeroed counters.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            // Built first because `config` is moved into the struct below.
            sliding_window: SlidingWindow::new(config.rolling_window),
            config,
            state: AtomicU8::new(CircuitState::Closed as u8),
            failure_count: AtomicU64::new(0),
            success_count: AtomicU64::new(0),
            last_failure_time: RwLock::new(None),
            last_failure_reason: RwLock::new(None),
        }
    }

    /// Get current circuit breaker state
    pub fn state(&self) -> CircuitState {
        // Acquire pairs with the Release stores in the transition methods.
        CircuitState::from(self.state.load(Ordering::Acquire))
    }

    /// Get current metrics
    ///
    /// Point-in-time snapshot; individual fields may be slightly stale
    /// relative to each other under concurrent updates.
    pub async fn metrics(&self) -> CircuitBreakerMetrics {
        let (total_requests, total_failures) = self.sliding_window.get_metrics().await;
        let failure_rate = if total_requests > 0 {
            total_failures as f64 / total_requests as f64
        } else {
            0.0
        };

        let last_failure_time = *self.last_failure_time.read().await;
        let last_failure_reason = self.last_failure_reason.read().await.clone();

        // NOTE(review): `elapsed()` yields seconds *since* the failure, not
        // a seconds-since-epoch timestamp, and the value grows between
        // snapshots — despite what the metrics field doc claims.
        let last_failure_timestamp = last_failure_time.map(|instant| instant.elapsed().as_secs());

        CircuitBreakerMetrics {
            state: self.state(),
            failure_count: self.failure_count.load(Ordering::Relaxed),
            success_count: self.success_count.load(Ordering::Relaxed),
            total_requests,
            total_failures,
            failure_rate,
            last_failure_time: last_failure_timestamp,
            last_failure_reason,
        }
    }

    /// Execute an operation through the circuit breaker
    ///
    /// Rejects immediately with [`CircuitBreakerError::Open`] when the
    /// circuit is open; otherwise runs `operation`, records its outcome,
    /// and wraps any failure in [`CircuitBreakerError::OperationFailed`].
    #[instrument(skip(self, operation))]
    pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T, CircuitBreakerError>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, PostgresError>>,
    {
        // Check if circuit should allow the request
        if !self.should_allow_request().await {
            let last_failure = self.last_failure_reason.read().await.clone();
            return Err(CircuitBreakerError::Open { last_failure });
        }

        // Execute the operation
        match operation().await {
            Ok(result) => {
                self.record_success().await;
                Ok(result)
            }
            Err(error) => {
                // Capture the message before the error is moved into the
                // returned variant.
                let error_msg = error.to_string();
                self.record_failure(error_msg).await;
                Err(CircuitBreakerError::OperationFailed { source: error })
            }
        }
    }

    /// Check if the circuit should allow a request
    ///
    /// NOTE(review): the state read and the Open -> HalfOpen transition are
    /// not atomic, so multiple concurrent callers can all observe the
    /// elapsed timeout and pass through; likewise HalfOpen admits every
    /// request rather than a single trial probe.
    async fn should_allow_request(&self) -> bool {
        match self.state() {
            CircuitState::Closed => true,
            CircuitState::Open => {
                // Check if timeout has elapsed to transition to half-open
                if let Some(last_failure) = *self.last_failure_time.read().await {
                    if last_failure.elapsed() >= self.config.timeout_duration {
                        debug!("Circuit breaker transitioning from Open to HalfOpen");
                        self.transition_to_half_open();
                        true
                    } else {
                        false
                    }
                } else {
                    // Open but no failure time recorded (transiently possible
                    // during force_open): keep rejecting.
                    false
                }
            }
            CircuitState::HalfOpen => true,
        }
    }

    /// Record a successful operation
    async fn record_success(&self) {
        self.sliding_window.record_request(true).await;

        let current_state = self.state();
        match current_state {
            CircuitState::Closed => {
                // Reset failure count on success
                self.failure_count.store(0, Ordering::Relaxed);
            }
            CircuitState::HalfOpen => {
                // fetch_add returns the previous value; add 1 for the
                // post-increment count.
                let success_count = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
                debug!("Circuit breaker half-open success count: {}", success_count);

                if success_count >= self.config.success_threshold {
                    debug!("Circuit breaker transitioning from HalfOpen to Closed");
                    self.transition_to_closed();
                }
            }
            CircuitState::Open => {
                // Shouldn't happen, but handle gracefully
                warn!("Recorded success while circuit was open");
            }
        }
    }

    /// Record a failed operation
    async fn record_failure(&self, error_msg: String) {
        self.sliding_window.record_request(false).await;

        // Update failure tracking
        *self.last_failure_time.write().await = Some(Instant::now());
        *self.last_failure_reason.write().await = Some(error_msg);

        let current_state = self.state();
        match current_state {
            CircuitState::Closed => {
                let failure_count = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
                debug!("Circuit breaker failure count: {}", failure_count);

                // Check if we should open the circuit based on recent failure
                // rate. The atomic counter above is informational only; the
                // decision uses the rolling window.
                let (total_requests, total_failures) = self.sliding_window.get_metrics().await;

                if total_requests >= self.config.minimum_requests
                    && total_failures >= self.config.failure_threshold
                {
                    warn!(
                        "Circuit breaker opening due to failure threshold. Failures: {}/{}",
                        total_failures, total_requests
                    );
                    self.transition_to_open();
                }
            }
            CircuitState::HalfOpen => {
                // Any failure during the probe phase re-opens the circuit.
                debug!("Circuit breaker transitioning from HalfOpen to Open due to failure");
                self.transition_to_open();
            }
            CircuitState::Open => {
                // Already open, just update counters
                self.failure_count.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    /// Transition to closed state
    ///
    /// Clears both counters so the next episode starts fresh.
    fn transition_to_closed(&self) {
        self.state
            .store(CircuitState::Closed as u8, Ordering::Release);
        self.failure_count.store(0, Ordering::Relaxed);
        self.success_count.store(0, Ordering::Relaxed);
        debug!("Circuit breaker state changed to Closed");
    }

    /// Transition to open state
    fn transition_to_open(&self) {
        self.state
            .store(CircuitState::Open as u8, Ordering::Release);
        self.success_count.store(0, Ordering::Relaxed);
        error!("Circuit breaker state changed to Open");
    }

    /// Transition to half-open state
    ///
    /// Resets the success counter that gates the HalfOpen -> Closed move.
    fn transition_to_half_open(&self) {
        self.state
            .store(CircuitState::HalfOpen as u8, Ordering::Release);
        self.success_count.store(0, Ordering::Relaxed);
        debug!("Circuit breaker state changed to HalfOpen");
    }

    /// Manually reset the circuit breaker to closed state
    pub async fn reset(&self) {
        debug!("Manually resetting circuit breaker");
        self.transition_to_closed();
        *self.last_failure_time.write().await = None;
        *self.last_failure_reason.write().await = None;
    }

    /// Force the circuit breaker to open (for testing or manual intervention)
    pub async fn force_open(&self) {
        warn!("Manually forcing circuit breaker to open state");
        self.transition_to_open();
        *self.last_failure_time.write().await = Some(Instant::now());
        *self.last_failure_reason.write().await = Some("Manually forced open".to_string());
    }
}
384
/// Circuit breaker metrics for monitoring
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerMetrics {
    /// Current state of the circuit breaker
    pub state: CircuitState,
    /// Total number of failures
    ///
    /// Consecutive failures from the breaker's atomic counter (reset on
    /// success/close), not the rolling-window total.
    pub failure_count: u64,
    /// Total number of successes in current state
    pub success_count: u64,
    /// Total requests in the rolling window
    pub total_requests: u64,
    /// Total failures in the rolling window
    pub total_failures: u64,
    /// Current failure rate (0.0 to 1.0)
    pub failure_rate: f64,
    /// Seconds elapsed since the last failure, if any
    ///
    /// NOTE(review): derived from `Instant::elapsed`, so this is an age in
    /// seconds — not a seconds-since-epoch timestamp — and it increases
    /// between snapshots.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_failure_time: Option<u64>,
    /// Reason for last failure
    pub last_failure_reason: Option<String>,
}
406
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): several tests drive real time with short sleeps; under
    // heavy CI load they can be flaky. tokio::time::pause()/advance() would
    // make them deterministic.

    /// Closed -> Open after `failure_threshold` failures with
    /// `minimum_requests` satisfied.
    #[tokio::test]
    async fn test_circuit_breaker_closed_to_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            minimum_requests: 3,
            ..CircuitBreakerConfig::default()
        };

        let breaker = CircuitBreaker::new(config);
        assert_eq!(breaker.state(), CircuitState::Closed);

        // Simulate failures
        for i in 0..3 {
            let result = breaker
                .execute(|| async {
                    Err::<(), _>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
                })
                .await;

            assert!(result.is_err());

            // The third failure (i == 2) reaches both thresholds and trips
            // the circuit.
            if i < 2 {
                assert_eq!(breaker.state(), CircuitState::Closed);
            } else {
                assert_eq!(breaker.state(), CircuitState::Open);
            }
        }
    }

    /// Open -> HalfOpen once `timeout_duration` has elapsed.
    #[tokio::test]
    async fn test_circuit_breaker_open_to_half_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            minimum_requests: 1,
            timeout_duration: Duration::from_millis(50),
            ..CircuitBreakerConfig::default()
        };

        let breaker = CircuitBreaker::new(config);

        // Cause failure to open circuit
        let _ = breaker
            .execute(|| async {
                Err::<(), _>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
            })
            .await;

        assert_eq!(breaker.state(), CircuitState::Open);

        // Wait for timeout
        tokio::time::sleep(Duration::from_millis(60)).await;

        // Next request should transition to half-open
        let result = breaker
            .execute(|| async { Ok::<(), PostgresError>(()) })
            .await;

        assert!(result.is_ok());
        // After first success in half-open, should still be half-open (need success_threshold successes)
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
    }

    /// HalfOpen -> Closed after `success_threshold` successes.
    #[tokio::test]
    async fn test_circuit_breaker_half_open_to_closed() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            minimum_requests: 1,
            success_threshold: 2,
            timeout_duration: Duration::from_millis(50),
            ..CircuitBreakerConfig::default()
        };

        let breaker = CircuitBreaker::new(config);

        // Open the circuit
        let _ = breaker
            .execute(|| async {
                Err::<(), _>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
            })
            .await;

        // Wait and execute successful operations
        tokio::time::sleep(Duration::from_millis(60)).await;

        // First success (should be half-open)
        let _ = breaker
            .execute(|| async { Ok::<(), PostgresError>(()) })
            .await;

        // Second success (should close circuit)
        let _ = breaker
            .execute(|| async { Ok::<(), PostgresError>(()) })
            .await;

        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    /// Metrics reflect rolling-window totals and failure rate.
    #[tokio::test]
    async fn test_circuit_breaker_metrics() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::default());

        // Execute some operations
        let _ = breaker
            .execute(|| async { Ok::<(), PostgresError>(()) })
            .await;

        let _ = breaker
            .execute(|| async {
                Err::<(), _>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
            })
            .await;

        let metrics = breaker.metrics().await;
        assert_eq!(metrics.state, CircuitState::Closed);
        assert_eq!(metrics.total_requests, 2);
        assert_eq!(metrics.total_failures, 1);
        // 0.5 is exactly representable, so direct f64 comparison is safe.
        assert_eq!(metrics.failure_rate, 0.5);
    }

    /// `reset` returns the breaker to Closed and clears failure details.
    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            minimum_requests: 1,
            ..CircuitBreakerConfig::default()
        };

        let breaker = CircuitBreaker::new(config);

        // Open the circuit
        let _ = breaker
            .execute(|| async {
                Err::<(), _>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
            })
            .await;

        assert_eq!(breaker.state(), CircuitState::Open);

        // Reset the circuit
        breaker.reset().await;
        assert_eq!(breaker.state(), CircuitState::Closed);

        let metrics = breaker.metrics().await;
        assert!(metrics.last_failure_time.is_none());
        assert!(metrics.last_failure_reason.is_none());
    }

    /// Entries older than the window duration drop out of the metrics.
    #[tokio::test]
    async fn test_sliding_window() {
        let window = SlidingWindow::new(Duration::from_millis(100));

        // Record some requests
        window.record_request(true).await;
        window.record_request(false).await;
        window.record_request(true).await;

        let (total, failures) = window.get_metrics().await;
        assert_eq!(total, 3);
        assert_eq!(failures, 1);

        // Wait for window to expire
        tokio::time::sleep(Duration::from_millis(150)).await;

        let (total, failures) = window.get_metrics().await;
        assert_eq!(total, 0);
        assert_eq!(failures, 0);
    }
}