codex_memory/mcp/
circuit_breaker.rs1use std::sync::Arc;
2use std::time::{Duration, Instant};
3use tokio::sync::RwLock;
4use tracing::{debug, error, info, warn};
5
/// Operating mode of a [`CircuitBreaker`].
///
/// * `Closed` — calls pass through and failures are counted.
/// * `Open` — calls are rejected until the configured timeout has elapsed
///   since the most recent recorded failure.
/// * `HalfOpen` — a bounded number of trial calls are let through; their
///   outcome decides whether the breaker closes again or re-opens.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: every call is forwarded.
    Closed,
    /// Tripped: calls are rejected without being run.
    Open,
    /// Probing: a limited number of calls are admitted to test recovery.
    HalfOpen,
}
12
/// Tuning parameters for a [`CircuitBreaker`].
///
/// The [`Default`] values are `failure_threshold = 5`, `success_threshold = 2`,
/// `timeout = 60 s` and `half_open_max_calls = 3`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Failures recorded while closed (with no success in between) that trip
    /// the breaker open.
    pub failure_threshold: u32,
    /// Successful calls needed while half-open before the breaker closes again.
    pub success_threshold: u32,
    /// How long after the most recent failure an open breaker waits before it
    /// lets a call through as a half-open trial.
    pub timeout: Duration,
    /// Cap on the calls admitted while half-open; once reached, further calls
    /// are rejected until the breaker closes or re-opens. Keep this at least
    /// `success_threshold`: otherwise the breaker may exhaust its half-open
    /// budget before seeing enough successes and stay half-open until reset.
    pub half_open_max_calls: u32,
}

impl Default for CircuitBreakerConfig {
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 2,
            timeout: Duration::from_secs(60),
            half_open_max_calls: 3,
        }
    }
}
31
32pub struct CircuitBreaker {
33 config: CircuitBreakerConfig,
34 state: Arc<RwLock<CircuitState>>,
35 failure_count: Arc<RwLock<u32>>,
36 success_count: Arc<RwLock<u32>>,
37 last_failure_time: Arc<RwLock<Option<Instant>>>,
38 half_open_calls: Arc<RwLock<u32>>,
39}
40
41impl CircuitBreaker {
42 pub fn new(config: CircuitBreakerConfig) -> Self {
43 Self {
44 config,
45 state: Arc::new(RwLock::new(CircuitState::Closed)),
46 failure_count: Arc::new(RwLock::new(0)),
47 success_count: Arc::new(RwLock::new(0)),
48 last_failure_time: Arc::new(RwLock::new(None)),
49 half_open_calls: Arc::new(RwLock::new(0)),
50 }
51 }
52
53 pub async fn call<F, T, E>(&self, f: F) -> Result<T, E>
54 where
55 F: FnOnce() -> Result<T, E>,
56 E: std::fmt::Display,
57 {
58 let state = self.get_state().await;
59
60 match state {
61 CircuitState::Open => {
62 if self.should_attempt_reset().await {
63 self.transition_to_half_open().await;
64 } else {
65 error!("Circuit breaker is open, rejecting call");
66 return Err(self.create_circuit_open_error());
67 }
68 }
69 CircuitState::HalfOpen => {
70 let calls = *self.half_open_calls.read().await;
71 if calls >= self.config.half_open_max_calls {
72 warn!("Circuit breaker half-open limit reached");
73 return Err(self.create_circuit_open_error());
74 }
75 *self.half_open_calls.write().await += 1;
76 }
77 CircuitState::Closed => {}
78 }
79
80 match f() {
81 Ok(result) => {
82 self.on_success().await;
83 Ok(result)
84 }
85 Err(error) => {
86 self.on_failure().await;
87 error!("Circuit breaker call failed: {}", error);
88 Err(error)
89 }
90 }
91 }
92
93 async fn get_state(&self) -> CircuitState {
94 *self.state.read().await
95 }
96
97 async fn on_success(&self) {
98 let mut state = self.state.write().await;
99 let mut success_count = self.success_count.write().await;
100 let mut failure_count = self.failure_count.write().await;
101
102 match *state {
103 CircuitState::HalfOpen => {
104 *success_count += 1;
105 if *success_count >= self.config.success_threshold {
106 *state = CircuitState::Closed;
107 *failure_count = 0;
108 *success_count = 0;
109 *self.half_open_calls.write().await = 0;
110 info!("Circuit breaker closed after successful recovery");
111 }
112 }
113 CircuitState::Closed => {
114 *failure_count = 0;
115 }
116 _ => {}
117 }
118 }
119
120 async fn on_failure(&self) {
121 let mut state = self.state.write().await;
122 let mut failure_count = self.failure_count.write().await;
123 let mut last_failure_time = self.last_failure_time.write().await;
124
125 *failure_count += 1;
126 *last_failure_time = Some(Instant::now());
127
128 match *state {
129 CircuitState::Closed => {
130 if *failure_count >= self.config.failure_threshold {
131 *state = CircuitState::Open;
132 warn!("Circuit breaker opened after {} failures", failure_count);
133 }
134 }
135 CircuitState::HalfOpen => {
136 *state = CircuitState::Open;
137 *self.success_count.write().await = 0;
138 *self.half_open_calls.write().await = 0;
139 warn!("Circuit breaker reopened from half-open state");
140 }
141 _ => {}
142 }
143 }
144
145 async fn should_attempt_reset(&self) -> bool {
146 if let Some(last_failure) = *self.last_failure_time.read().await {
147 last_failure.elapsed() >= self.config.timeout
148 } else {
149 false
150 }
151 }
152
153 async fn transition_to_half_open(&self) {
154 let mut state = self.state.write().await;
155 *state = CircuitState::HalfOpen;
156 *self.half_open_calls.write().await = 0;
157 info!("Circuit breaker transitioned to half-open");
158 }
159
160 fn create_circuit_open_error<E>(&self) -> E
161 where
162 E: std::fmt::Display,
163 {
164 panic!("Circuit breaker is open")
166 }
167
168 pub async fn get_stats(&self) -> CircuitBreakerStats {
169 CircuitBreakerStats {
170 state: *self.state.read().await,
171 failure_count: *self.failure_count.read().await,
172 success_count: *self.success_count.read().await,
173 half_open_calls: *self.half_open_calls.read().await,
174 }
175 }
176
177 pub async fn reset(&self) {
178 *self.state.write().await = CircuitState::Closed;
179 *self.failure_count.write().await = 0;
180 *self.success_count.write().await = 0;
181 *self.last_failure_time.write().await = None;
182 *self.half_open_calls.write().await = 0;
183 debug!("Circuit breaker manually reset");
184 }
185}
186
187#[derive(Debug, Clone)]
188pub struct CircuitBreakerStats {
189 pub state: CircuitState,
190 pub failure_count: u32,
191 pub success_count: u32,
192 pub half_open_calls: u32,
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with a short timeout so tests can wait out the open period quickly.
    fn fast_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 2,
            timeout: Duration::from_millis(100),
            half_open_max_calls: 3,
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_transitions() {
        let cb = CircuitBreaker::new(fast_config());
        assert_eq!(cb.get_state().await, CircuitState::Closed);

        // Reaching failure_threshold trips the breaker.
        for _ in 0..2 {
            cb.on_failure().await;
        }
        assert_eq!(cb.get_state().await, CircuitState::Open);
        assert!(!cb.should_attempt_reset().await);

        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(cb.should_attempt_reset().await);

        // The first call after the timeout is admitted as a half-open trial.
        assert_eq!(cb.call(|| Ok::<u32, String>(1)).await, Ok(1));
        assert_eq!(cb.get_state().await, CircuitState::HalfOpen);

        // Reaching success_threshold closes the breaker and clears counters.
        assert_eq!(cb.call(|| Ok::<u32, String>(2)).await, Ok(2));
        let stats = cb.get_stats().await;
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.failure_count, 0);
        assert_eq!(stats.success_count, 0);
        assert_eq!(stats.half_open_calls, 0);
    }

    #[tokio::test]
    async fn test_half_open_failure_reopens() {
        let cb = CircuitBreaker::new(fast_config());
        for _ in 0..2 {
            cb.on_failure().await;
        }
        tokio::time::sleep(Duration::from_millis(150)).await;

        let result = cb.call(|| Err::<(), String>("boom".to_string())).await;
        assert_eq!(result, Err("boom".to_string()));

        let stats = cb.get_stats().await;
        assert_eq!(stats.state, CircuitState::Open);
        assert_eq!(stats.success_count, 0);
        assert_eq!(stats.half_open_calls, 0);
        // The failed trial call restarts the open timeout.
        assert!(!cb.should_attempt_reset().await);
    }

    #[tokio::test]
    async fn test_success_while_closed_resets_failures() {
        let cb = CircuitBreaker::new(fast_config());
        cb.on_failure().await;
        assert_eq!(cb.get_stats().await.failure_count, 1);

        assert_eq!(cb.call(|| Ok::<(), String>(())).await, Ok(()));
        let stats = cb.get_stats().await;
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.failure_count, 0);
    }

    #[tokio::test]
    async fn test_reset_closes_open_breaker() {
        let cb = CircuitBreaker::new(fast_config());
        for _ in 0..2 {
            cb.on_failure().await;
        }
        assert_eq!(cb.get_state().await, CircuitState::Open);

        cb.reset().await;
        let stats = cb.get_stats().await;
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.failure_count, 0);
        assert_eq!(stats.success_count, 0);
        assert_eq!(stats.half_open_calls, 0);
        // The last failure time is cleared, so no reset attempt is pending.
        assert!(!cb.should_attempt_reset().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_stats() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig::default());

        let stats = cb.get_stats().await;
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.failure_count, 0);
        assert_eq!(stats.success_count, 0);
    }
}