// model_context_protocol/circuit_breaker.rs
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};

use parking_lot::RwLock;

/// The three positions a [`CircuitBreaker`] can be in.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: every request is allowed through and consecutive
    /// failures are counted toward the trip threshold.
    Closed,
    /// Tripped: requests are rejected until the configured open window has
    /// elapsed since the most recently recorded failure.
    Open,
    /// Probing: requests are allowed; enough successes close the breaker,
    /// while any failure re-opens it.
    HalfOpen,
}

/// Tuning knobs for a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures (while closed) that trips the breaker open.
    /// A success in the closed state resets the streak.
    pub failure_threshold: u32,
    /// How long the breaker stays open, measured from the most recently
    /// recorded failure, before a request is let through in the half-open state.
    pub open_duration: Duration,
    /// Number of successes required while half-open before the breaker closes.
    pub half_open_successes: u32,
}

impl Default for CircuitBreakerConfig {
    /// Trip after 5 consecutive failures, stay open for 30 seconds, and close
    /// again after 2 half-open successes.
    fn default() -> Self {
        let open_duration = Duration::from_secs(30);
        CircuitBreakerConfig {
            half_open_successes: 2,
            failure_threshold: 5,
            open_duration,
        }
    }
}

42#[derive(Debug)]
44pub struct CircuitBreaker {
45 config: CircuitBreakerConfig,
46 state: RwLock<CircuitState>,
47 failure_count: AtomicU32,
48 success_count: AtomicU32,
49 last_failure_time: RwLock<Option<Instant>>,
50 total_requests: AtomicU64,
51 total_failures: AtomicU64,
52}
53
54impl CircuitBreaker {
55 pub fn new() -> Self {
57 Self::with_config(CircuitBreakerConfig::default())
58 }
59
60 pub fn with_config(config: CircuitBreakerConfig) -> Self {
62 Self {
63 config,
64 state: RwLock::new(CircuitState::Closed),
65 failure_count: AtomicU32::new(0),
66 success_count: AtomicU32::new(0),
67 last_failure_time: RwLock::new(None),
68 total_requests: AtomicU64::new(0),
69 total_failures: AtomicU64::new(0),
70 }
71 }
72
73 pub fn state(&self) -> CircuitState {
75 *self.state.read()
76 }
77
78 pub fn allow_request(&self) -> bool {
82 self.total_requests.fetch_add(1, Ordering::Relaxed);
83
84 let current_state = *self.state.read();
85
86 match current_state {
87 CircuitState::Closed => true,
88 CircuitState::Open => {
89 if let Some(last_failure) = *self.last_failure_time.read() {
91 if last_failure.elapsed() >= self.config.open_duration {
92 let mut state = self.state.write();
93 if *state == CircuitState::Open {
94 *state = CircuitState::HalfOpen;
95 self.success_count.store(0, Ordering::Relaxed);
96 drop(state);
97 return true;
98 }
99 }
100 }
101 false
102 }
103 CircuitState::HalfOpen => true,
104 }
105 }
106
107 pub fn record_success(&self) {
109 let current_state = *self.state.read();
110
111 match current_state {
112 CircuitState::Closed => {
113 self.failure_count.store(0, Ordering::Relaxed);
115 }
116 CircuitState::HalfOpen => {
117 let successes = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
118 if successes >= self.config.half_open_successes {
119 let mut state = self.state.write();
120 *state = CircuitState::Closed;
121 self.failure_count.store(0, Ordering::Relaxed);
122 self.success_count.store(0, Ordering::Relaxed);
123 }
124 }
125 CircuitState::Open => {
126 }
128 }
129 }
130
131 pub fn record_failure(&self) {
133 self.total_failures.fetch_add(1, Ordering::Relaxed);
134 *self.last_failure_time.write() = Some(Instant::now());
135
136 let current_state = *self.state.read();
137
138 match current_state {
139 CircuitState::Closed => {
140 let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
141 if failures >= self.config.failure_threshold {
142 let mut state = self.state.write();
143 *state = CircuitState::Open;
144 }
145 }
146 CircuitState::HalfOpen => {
147 let mut state = self.state.write();
149 *state = CircuitState::Open;
150 self.success_count.store(0, Ordering::Relaxed);
151 }
152 CircuitState::Open => {
153 }
155 }
156 }
157
158 pub fn reset(&self) {
160 let mut state = self.state.write();
161 *state = CircuitState::Closed;
162 self.failure_count.store(0, Ordering::Relaxed);
163 self.success_count.store(0, Ordering::Relaxed);
164 }
165
166 pub fn stats(&self) -> CircuitBreakerStats {
168 CircuitBreakerStats {
169 state: *self.state.read(),
170 failure_count: self.failure_count.load(Ordering::Relaxed),
171 total_requests: self.total_requests.load(Ordering::Relaxed),
172 total_failures: self.total_failures.load(Ordering::Relaxed),
173 }
174 }
175}
176
177impl Default for CircuitBreaker {
178 fn default() -> Self {
179 Self::new()
180 }
181}
182
183#[derive(Debug, Clone)]
185pub struct CircuitBreakerStats {
186 pub state: CircuitState,
187 pub failure_count: u32,
188 pub total_requests: u64,
189 pub total_failures: u64,
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Config whose open window is already over, so the first `allow_request`
    /// after tripping moves the breaker straight to half-open.
    fn instant_retry_config(failure_threshold: u32) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold,
            open_duration: Duration::from_secs(0),
            half_open_successes: 2,
        }
    }

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.allow_request());
    }

    #[test]
    fn test_circuit_opens_after_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let cb = CircuitBreaker::with_config(config);

        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allow_request());
    }

    #[test]
    fn test_success_resets_failure_count() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let cb = CircuitBreaker::with_config(config);

        cb.record_failure();
        cb.record_failure();
        cb.record_success();
        cb.record_failure();
        cb.record_failure();

        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_manual_reset() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            ..Default::default()
        };
        let cb = CircuitBreaker::with_config(config);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.allow_request());
    }

    #[test]
    fn test_half_open_closes_after_enough_successes() {
        let cb = CircuitBreaker::with_config(instant_retry_config(1));
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        assert!(cb.allow_request());
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_half_open_failure_reopens() {
        let cb = CircuitBreaker::with_config(instant_retry_config(1));
        cb.record_failure();
        assert!(cb.allow_request());
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_success_while_open_is_ignored() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            ..Default::default()
        };
        let cb = CircuitBreaker::with_config(config);

        cb.record_failure();
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allow_request());
    }

    #[test]
    fn test_stats_count_rejected_requests_and_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            ..Default::default()
        };
        let cb = CircuitBreaker::with_config(config);

        assert!(cb.allow_request());
        cb.record_failure();
        assert!(!cb.allow_request());

        let stats = cb.stats();
        assert_eq!(stats.state, CircuitState::Open);
        assert_eq!(stats.failure_count, 1);
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.total_failures, 1);
    }
}