Skip to main content

goldrush_sdk/
circuit_breaker.rs

1use crate::{Error, Result};
2use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::RwLock;
6use tracing::{debug, info, warn, instrument};
7
/// The three states a circuit breaker moves between.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: requests are let through.
    Closed,
    /// Tripped: requests are rejected without being attempted.
    Open,
    /// Probing: requests are let through again to find out whether the
    /// downstream dependency has recovered.
    HalfOpen,
}
18
19impl std::fmt::Display for CircuitState {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        match self {
22            CircuitState::Closed => write!(f, "CLOSED"),
23            CircuitState::Open => write!(f, "OPEN"),
24            CircuitState::HalfOpen => write!(f, "HALF_OPEN"),
25        }
26    }
27}
28
/// Configuration for the circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures inside `failure_time_window` that opens the circuit.
    ///
    /// The failures do not need to be consecutive: recording a success does
    /// not reset the count; failures only stop counting once they are older
    /// than `failure_time_window`.
    pub failure_threshold: u32,
    /// Minimum time the circuit stays open before a request is let through
    /// again as a half-open probe.
    pub timeout: Duration,
    /// Number of successful requests needed to close the circuit from half-open.
    pub success_threshold: u32,
    /// Sliding time window over which failures are counted towards
    /// `failure_threshold`.
    pub failure_time_window: Duration,
}
41
42impl Default for CircuitBreakerConfig {
43    fn default() -> Self {
44        Self {
45            failure_threshold: 5,
46            timeout: Duration::from_secs(30),
47            success_threshold: 3,
48            failure_time_window: Duration::from_secs(60),
49        }
50    }
51}
52
53/// Circuit breaker for preventing cascading failures.
54#[derive(Debug)]
55pub struct CircuitBreaker {
56    config: CircuitBreakerConfig,
57    state: Arc<RwLock<CircuitState>>,
58    failure_count: AtomicU64,
59    success_count: AtomicU64,
60    last_failure_time: Arc<RwLock<Option<Instant>>>,
61    last_state_change: Arc<RwLock<Instant>>,
62    /// Recent failures within the time window
63    recent_failures: Arc<RwLock<Vec<Instant>>>,
64}
65
66impl CircuitBreaker {
67    pub fn new(config: CircuitBreakerConfig) -> Self {
68        Self {
69            config,
70            state: Arc::new(RwLock::new(CircuitState::Closed)),
71            failure_count: AtomicU64::new(0),
72            success_count: AtomicU64::new(0),
73            last_failure_time: Arc::new(RwLock::new(None)),
74            last_state_change: Arc::new(RwLock::new(Instant::now())),
75            recent_failures: Arc::new(RwLock::new(Vec::new())),
76        }
77    }
78    
79    /// Check if a request can proceed.
80    #[instrument(skip(self))]
81    pub async fn can_proceed(&self) -> bool {
82        let current_state = *self.state.read().await;
83        
84        match current_state {
85            CircuitState::Closed => {
86                debug!("Circuit closed, allowing request");
87                true
88            }
89            CircuitState::Open => {
90                let should_attempt_reset = self.should_attempt_reset().await;
91                if should_attempt_reset {
92                    self.transition_to_half_open().await;
93                    debug!("Circuit transitioning to half-open, allowing request");
94                    true
95                } else {
96                    debug!("Circuit open, rejecting request");
97                    false
98                }
99            }
100            CircuitState::HalfOpen => {
101                debug!("Circuit half-open, allowing limited request");
102                true
103            }
104        }
105    }
106    
107    /// Record a successful operation.
108    #[instrument(skip(self))]
109    pub async fn record_success(&self) {
110        let current_state = *self.state.read().await;
111        self.success_count.fetch_add(1, Ordering::Relaxed);
112        
113        debug!(state = %current_state, "Recording success");
114        
115        match current_state {
116            CircuitState::HalfOpen => {
117                let success_count = self.success_count.load(Ordering::Relaxed);
118                if success_count >= self.config.success_threshold as u64 {
119                    self.transition_to_closed().await;
120                }
121            }
122            CircuitState::Open => {
123                // Success in open state should not happen, but if it does,
124                // it might indicate the service is recovering
125                warn!("Unexpected success in open circuit state");
126            }
127            CircuitState::Closed => {
128                // Normal operation, no state change needed
129            }
130        }
131    }
132    
133    /// Record a failed operation.
134    #[instrument(skip(self), fields(error = %error))]
135    pub async fn record_failure<E: std::fmt::Display>(&self, error: E) {
136        let now = Instant::now();
137        let current_state = *self.state.read().await;
138        
139        debug!(state = %current_state, error = %error, "Recording failure");
140        
141        self.failure_count.fetch_add(1, Ordering::Relaxed);
142        *self.last_failure_time.write().await = Some(now);
143        
144        // Add to recent failures
145        {
146            let mut recent_failures = self.recent_failures.write().await;
147            recent_failures.push(now);
148            
149            // Clean up old failures outside the time window
150            let cutoff = now - self.config.failure_time_window;
151            recent_failures.retain(|&failure_time| failure_time > cutoff);
152        }
153        
154        // Check if we should open the circuit
155        let recent_failure_count = self.recent_failures.read().await.len();
156        
157        if current_state != CircuitState::Open 
158            && recent_failure_count >= self.config.failure_threshold as usize {
159            self.transition_to_open().await;
160        }
161    }
162    
163    /// Get current circuit breaker statistics.
164    pub async fn stats(&self) -> CircuitBreakerStats {
165        let state = *self.state.read().await;
166        let recent_failures = self.recent_failures.read().await.len();
167        let last_state_change = *self.last_state_change.read().await;
168        
169        CircuitBreakerStats {
170            state,
171            total_failures: self.failure_count.load(Ordering::Relaxed),
172            total_successes: self.success_count.load(Ordering::Relaxed),
173            recent_failures: recent_failures as u64,
174            time_in_current_state: last_state_change.elapsed(),
175            failure_rate: self.calculate_failure_rate().await,
176        }
177    }
178    
179    /// Reset the circuit breaker to closed state.
180    #[instrument(skip(self))]
181    pub async fn reset(&self) {
182        info!("Manually resetting circuit breaker");
183        self.transition_to_closed().await;
184        self.failure_count.store(0, Ordering::Relaxed);
185        self.success_count.store(0, Ordering::Relaxed);
186        *self.last_failure_time.write().await = None;
187        self.recent_failures.write().await.clear();
188    }
189    
190    /// Check if enough time has passed to attempt reset from open state.
191    async fn should_attempt_reset(&self) -> bool {
192        let last_state_change = *self.last_state_change.read().await;
193        last_state_change.elapsed() >= self.config.timeout
194    }
195    
196    /// Transition to closed state.
197    async fn transition_to_closed(&self) {
198        let mut state = self.state.write().await;
199        if *state != CircuitState::Closed {
200            info!(previous_state = %*state, "Circuit breaker transitioning to CLOSED");
201            *state = CircuitState::Closed;
202            *self.last_state_change.write().await = Instant::now();
203            self.success_count.store(0, Ordering::Relaxed);
204        }
205    }
206    
207    /// Transition to open state.
208    async fn transition_to_open(&self) {
209        let mut state = self.state.write().await;
210        if *state != CircuitState::Open {
211            warn!(previous_state = %*state, "Circuit breaker transitioning to OPEN");
212            *state = CircuitState::Open;
213            *self.last_state_change.write().await = Instant::now();
214        }
215    }
216    
217    /// Transition to half-open state.
218    async fn transition_to_half_open(&self) {
219        let mut state = self.state.write().await;
220        if *state != CircuitState::HalfOpen {
221            info!(previous_state = %*state, "Circuit breaker transitioning to HALF_OPEN");
222            *state = CircuitState::HalfOpen;
223            *self.last_state_change.write().await = Instant::now();
224            self.success_count.store(0, Ordering::Relaxed);
225        }
226    }
227    
228    /// Calculate current failure rate.
229    async fn calculate_failure_rate(&self) -> f64 {
230        let recent_failures = self.recent_failures.read().await;
231        let total_failures = self.failure_count.load(Ordering::Relaxed);
232        let total_successes = self.success_count.load(Ordering::Relaxed);
233        let total_requests = total_failures + total_successes;
234        
235        if total_requests == 0 {
236            0.0
237        } else {
238            (recent_failures.len() as f64) / (total_requests as f64) * 100.0
239        }
240    }
241}
242
243/// Circuit breaker statistics.
244#[derive(Debug, Clone)]
245pub struct CircuitBreakerStats {
246    pub state: CircuitState,
247    pub total_failures: u64,
248    pub total_successes: u64,
249    pub recent_failures: u64,
250    pub time_in_current_state: Duration,
251    pub failure_rate: f64,
252}
253
254impl CircuitBreakerStats {
255    /// Check if the circuit breaker is healthy.
256    pub fn is_healthy(&self) -> bool {
257        matches!(self.state, CircuitState::Closed) && self.failure_rate < 5.0
258    }
259    
260    /// Get a human-readable status string.
261    pub fn status_string(&self) -> String {
262        format!(
263            "Circuit: {} | Failures: {}/{} | Rate: {:.1}% | Uptime: {}s",
264            self.state,
265            self.recent_failures,
266            self.total_failures + self.total_successes,
267            self.failure_rate,
268            self.time_in_current_state.as_secs()
269        )
270    }
271}
272
273/// Wrapper for executing operations through a circuit breaker.
274pub struct CircuitBreakerExecutor {
275    circuit_breaker: Arc<CircuitBreaker>,
276}
277
278impl CircuitBreakerExecutor {
279    pub fn new(config: CircuitBreakerConfig) -> Self {
280        Self {
281            circuit_breaker: Arc::new(CircuitBreaker::new(config)),
282        }
283    }
284    
285    /// Execute an operation through the circuit breaker.
286    #[instrument(skip(self, operation))]
287    pub async fn execute<F, T, E>(&self, operation: F) -> Result<T>
288    where
289        F: std::future::Future<Output = std::result::Result<T, E>>,
290        E: std::fmt::Display + std::error::Error + Send + Sync + 'static,
291    {
292        // Check if request can proceed
293        if !self.circuit_breaker.can_proceed().await {
294            return Err(Error::Config("Circuit breaker is open, request rejected".to_string()));
295        }
296        
297        // Execute the operation
298        match operation.await {
299            Ok(result) => {
300                self.circuit_breaker.record_success().await;
301                Ok(result)
302            }
303            Err(error) => {
304                self.circuit_breaker.record_failure(&error).await;
305                Err(Error::Config(format!("Operation failed: {}", error)))
306            }
307        }
308    }
309    
310    /// Get circuit breaker statistics.
311    pub async fn stats(&self) -> CircuitBreakerStats {
312        self.circuit_breaker.stats().await
313    }
314    
315    /// Reset the circuit breaker.
316    pub async fn reset(&self) {
317        self.circuit_breaker.reset().await
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;
    use tokio::time::{sleep, Duration};

    /// Small thresholds and a short timeout so state changes happen quickly.
    fn fast_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_millis(100),
            success_threshold: 1,
            failure_time_window: Duration::from_secs(10),
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_states() {
        let cb = CircuitBreaker::new(fast_config());

        // Initially closed
        assert!(cb.can_proceed().await);

        // Record failures to open circuit
        cb.record_failure("test error 1").await;
        cb.record_failure("test error 2").await;

        // Should be open now
        assert!(!cb.can_proceed().await);

        // Wait for timeout and transition to half-open
        sleep(Duration::from_millis(150)).await;
        assert!(cb.can_proceed().await);

        // Record success to close circuit
        cb.record_success().await;
        assert!(cb.can_proceed().await);

        let stats = cb.stats().await;
        assert_eq!(stats.state, CircuitState::Closed);
    }

    #[tokio::test]
    async fn failure_while_half_open_reopens_circuit() {
        let cb = CircuitBreaker::new(fast_config());
        cb.record_failure("e1").await;
        cb.record_failure("e2").await;

        sleep(Duration::from_millis(150)).await;
        assert!(cb.can_proceed().await);
        assert_eq!(cb.stats().await.state, CircuitState::HalfOpen);

        cb.record_failure("probe failed").await;
        assert_eq!(cb.stats().await.state, CircuitState::Open);
        assert!(!cb.can_proceed().await);
    }

    #[tokio::test]
    async fn reset_closes_circuit_and_clears_counters() {
        let cb = CircuitBreaker::new(fast_config());
        cb.record_failure("e1").await;
        cb.record_failure("e2").await;
        assert!(!cb.can_proceed().await);

        cb.reset().await;
        assert!(cb.can_proceed().await);

        let stats = cb.stats().await;
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.total_failures, 0);
        assert_eq!(stats.total_successes, 0);
        assert_eq!(stats.recent_failures, 0);
    }

    #[tokio::test]
    async fn stats_track_successes_and_failures() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig::default());
        cb.record_success().await;
        cb.record_success().await;
        cb.record_success().await;
        cb.record_failure("boom").await;

        let stats = cb.stats().await;
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.total_successes, 3);
        assert_eq!(stats.total_failures, 1);
        assert_eq!(stats.recent_failures, 1);
        assert!((stats.failure_rate - 25.0).abs() < 1e-9);
        assert!(!stats.is_healthy());
    }

    #[test]
    fn state_display_labels() {
        assert_eq!(CircuitState::Closed.to_string(), "CLOSED");
        assert_eq!(CircuitState::Open.to_string(), "OPEN");
        assert_eq!(CircuitState::HalfOpen.to_string(), "HALF_OPEN");
    }

    #[tokio::test]
    async fn executor_does_not_run_operation_when_open() {
        let exec = CircuitBreakerExecutor::new(CircuitBreakerConfig {
            failure_threshold: 1,
            ..fast_config()
        });

        let ok = exec.execute(async { Ok::<u32, std::io::Error>(7) }).await;
        assert!(matches!(ok, Ok(7)));

        let failed = exec
            .execute(async {
                Err::<u32, _>(std::io::Error::new(std::io::ErrorKind::Other, "boom"))
            })
            .await;
        assert!(failed.is_err());

        // The circuit is now open: the operation must be rejected unpolled.
        let ran = AtomicBool::new(false);
        let rejected = exec
            .execute(async {
                ran.store(true, Ordering::SeqCst);
                Ok::<u32, std::io::Error>(1)
            })
            .await;
        assert!(rejected.is_err());
        assert!(!ran.load(Ordering::SeqCst));
        assert_eq!(exec.stats().await.state, CircuitState::Open);
    }
}