Skip to main content

model_context_protocol/
circuit_breaker.rs

//! Circuit breaker pattern for resilient server connections.
//!
//! Implements a simple circuit breaker that tracks failures and prevents
//! cascading failures by temporarily blocking requests to unhealthy servers.
5
6use parking_lot::RwLock;
7use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
8use std::time::{Duration, Instant};
9
/// Circuit breaker states.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CircuitState {
    /// Normal operation: requests are allowed.
    Closed,
    /// The circuit has tripped: requests are blocked until the configured
    /// cool-down period has elapsed.
    Open,
    /// Probing whether the service has recovered. Requests are allowed
    /// through (the number of probes is not capped); enough successes close
    /// the circuit again and any failure reopens it.
    HalfOpen,
}
20
/// Tunable thresholds and timings that govern a circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// How many failures in a row (with no success in between) trip the
    /// breaker from closed to open.
    pub failure_threshold: u32,
    /// Cool-down period the breaker stays open before it lets probe
    /// requests through again (half-open).
    pub open_duration: Duration,
    /// How many successful probes are required while half-open before the
    /// breaker closes again.
    pub half_open_successes: u32,
}
31
32impl Default for CircuitBreakerConfig {
33    fn default() -> Self {
34        Self {
35            failure_threshold: 5,
36            open_duration: Duration::from_secs(30),
37            half_open_successes: 2,
38        }
39    }
40}
41
42/// Circuit breaker for managing connection health
43#[derive(Debug)]
44pub struct CircuitBreaker {
45    config: CircuitBreakerConfig,
46    state: RwLock<CircuitState>,
47    failure_count: AtomicU32,
48    success_count: AtomicU32,
49    last_failure_time: RwLock<Option<Instant>>,
50    total_requests: AtomicU64,
51    total_failures: AtomicU64,
52}
53
54impl CircuitBreaker {
55    /// Create a new circuit breaker with default configuration
56    pub fn new() -> Self {
57        Self::with_config(CircuitBreakerConfig::default())
58    }
59
60    /// Create a new circuit breaker with custom configuration
61    pub fn with_config(config: CircuitBreakerConfig) -> Self {
62        Self {
63            config,
64            state: RwLock::new(CircuitState::Closed),
65            failure_count: AtomicU32::new(0),
66            success_count: AtomicU32::new(0),
67            last_failure_time: RwLock::new(None),
68            total_requests: AtomicU64::new(0),
69            total_failures: AtomicU64::new(0),
70        }
71    }
72
73    /// Get the current circuit state
74    pub fn state(&self) -> CircuitState {
75        *self.state.read()
76    }
77
78    /// Check if a request should be allowed
79    ///
80    /// Returns `true` if the request can proceed, `false` if it should be blocked.
81    pub fn allow_request(&self) -> bool {
82        self.total_requests.fetch_add(1, Ordering::Relaxed);
83
84        let current_state = *self.state.read();
85
86        match current_state {
87            CircuitState::Closed => true,
88            CircuitState::Open => {
89                // Check if we should transition to half-open
90                if let Some(last_failure) = *self.last_failure_time.read() {
91                    if last_failure.elapsed() >= self.config.open_duration {
92                        let mut state = self.state.write();
93                        if *state == CircuitState::Open {
94                            *state = CircuitState::HalfOpen;
95                            self.success_count.store(0, Ordering::Relaxed);
96                            drop(state);
97                            return true;
98                        }
99                    }
100                }
101                false
102            }
103            CircuitState::HalfOpen => true,
104        }
105    }
106
107    /// Record a successful request
108    pub fn record_success(&self) {
109        let current_state = *self.state.read();
110
111        match current_state {
112            CircuitState::Closed => {
113                // Reset failure count on success
114                self.failure_count.store(0, Ordering::Relaxed);
115            }
116            CircuitState::HalfOpen => {
117                let successes = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
118                if successes >= self.config.half_open_successes {
119                    let mut state = self.state.write();
120                    *state = CircuitState::Closed;
121                    self.failure_count.store(0, Ordering::Relaxed);
122                    self.success_count.store(0, Ordering::Relaxed);
123                }
124            }
125            CircuitState::Open => {
126                // Shouldn't happen, but reset if it does
127            }
128        }
129    }
130
131    /// Record a failed request
132    pub fn record_failure(&self) {
133        self.total_failures.fetch_add(1, Ordering::Relaxed);
134        *self.last_failure_time.write() = Some(Instant::now());
135
136        let current_state = *self.state.read();
137
138        match current_state {
139            CircuitState::Closed => {
140                let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
141                if failures >= self.config.failure_threshold {
142                    let mut state = self.state.write();
143                    *state = CircuitState::Open;
144                }
145            }
146            CircuitState::HalfOpen => {
147                // Any failure in half-open immediately opens the circuit
148                let mut state = self.state.write();
149                *state = CircuitState::Open;
150                self.success_count.store(0, Ordering::Relaxed);
151            }
152            CircuitState::Open => {
153                // Already open, nothing to do
154            }
155        }
156    }
157
158    /// Manually reset the circuit breaker to closed state
159    pub fn reset(&self) {
160        let mut state = self.state.write();
161        *state = CircuitState::Closed;
162        self.failure_count.store(0, Ordering::Relaxed);
163        self.success_count.store(0, Ordering::Relaxed);
164    }
165
166    /// Get statistics about the circuit breaker
167    pub fn stats(&self) -> CircuitBreakerStats {
168        CircuitBreakerStats {
169            state: *self.state.read(),
170            failure_count: self.failure_count.load(Ordering::Relaxed),
171            total_requests: self.total_requests.load(Ordering::Relaxed),
172            total_failures: self.total_failures.load(Ordering::Relaxed),
173        }
174    }
175}
176
177impl Default for CircuitBreaker {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
183/// Statistics about circuit breaker state
184#[derive(Debug, Clone)]
185pub struct CircuitBreakerStats {
186    pub state: CircuitState,
187    pub failure_count: u32,
188    pub total_requests: u64,
189    pub total_failures: u64,
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Breaker that trips on the first failure and may probe again immediately,
    /// so the half-open path can be exercised without sleeping.
    fn instant_retry_breaker(half_open_successes: u32) -> CircuitBreaker {
        CircuitBreaker::with_config(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::ZERO,
            half_open_successes,
        })
    }

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.allow_request());
    }

    #[test]
    fn test_circuit_opens_after_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let cb = CircuitBreaker::with_config(config);

        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allow_request());
    }

    #[test]
    fn test_success_resets_failure_count() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let cb = CircuitBreaker::with_config(config);

        cb.record_failure();
        cb.record_failure();
        cb.record_success();
        cb.record_failure();
        cb.record_failure();

        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_manual_reset() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            ..Default::default()
        };
        let cb = CircuitBreaker::with_config(config);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.allow_request());
    }

    #[test]
    fn test_half_open_closes_after_enough_successes() {
        let cb = instant_retry_breaker(2);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        // With a zero open duration the first admission check performs the
        // Open -> HalfOpen transition and lets the probe through.
        assert!(cb.allow_request());
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_half_open_failure_reopens() {
        let cb = instant_retry_breaker(2);

        cb.record_failure();
        assert!(cb.allow_request());
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_success_while_open_is_ignored() {
        let cb = instant_retry_breaker(1);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_stats_track_totals() {
        let cb = CircuitBreaker::new();

        assert!(cb.allow_request());
        cb.record_failure();

        let stats = cb.stats();
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.failure_count, 1);
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.total_failures, 1);
    }
}