// codex_memory/mcp/circuit_breaker.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3use tokio::sync::RwLock;
4use tracing::{debug, error, info, warn};
5
/// Position of a [`CircuitBreaker`] in its Closed / Open / HalfOpen state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls pass straight through; failures are counted.
    Closed,
    /// Calls are rejected until the configured timeout has elapsed since the last failure.
    Open,
    /// A limited number of trial calls are admitted to probe whether the dependency recovered.
    HalfOpen,
}
12
/// Tuning parameters for a [`CircuitBreaker`].
///
/// Derives `Debug` and `PartialEq`/`Eq` so configurations can be logged and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Failures recorded while `Closed` (with no success in between) that open the circuit.
    pub failure_threshold: u32,
    /// Successful trial calls while `HalfOpen` needed to close the circuit again.
    pub success_threshold: u32,
    /// How long the circuit stays `Open` after the most recent failure before a trial call
    /// is allowed.
    pub timeout: Duration,
    /// Cap on trial calls admitted while `HalfOpen`; calls beyond it are rejected.
    pub half_open_max_calls: u32,
}
20
21impl Default for CircuitBreakerConfig {
22    fn default() -> Self {
23        Self {
24            failure_threshold: 5,
25            success_threshold: 2,
26            timeout: Duration::from_secs(60),
27            half_open_max_calls: 3,
28        }
29    }
30}
31
32pub struct CircuitBreaker {
33    config: CircuitBreakerConfig,
34    state: Arc<RwLock<CircuitState>>,
35    failure_count: Arc<RwLock<u32>>,
36    success_count: Arc<RwLock<u32>>,
37    last_failure_time: Arc<RwLock<Option<Instant>>>,
38    half_open_calls: Arc<RwLock<u32>>,
39}
40
41impl CircuitBreaker {
42    pub fn new(config: CircuitBreakerConfig) -> Self {
43        Self {
44            config,
45            state: Arc::new(RwLock::new(CircuitState::Closed)),
46            failure_count: Arc::new(RwLock::new(0)),
47            success_count: Arc::new(RwLock::new(0)),
48            last_failure_time: Arc::new(RwLock::new(None)),
49            half_open_calls: Arc::new(RwLock::new(0)),
50        }
51    }
52
53    pub async fn call<F, T, E>(&self, f: F) -> Result<T, E>
54    where
55        F: FnOnce() -> Result<T, E>,
56        E: std::fmt::Display,
57    {
58        let state = self.get_state().await;
59
60        match state {
61            CircuitState::Open => {
62                if self.should_attempt_reset().await {
63                    self.transition_to_half_open().await;
64                } else {
65                    error!("Circuit breaker is open, rejecting call");
66                    return Err(self.create_circuit_open_error());
67                }
68            }
69            CircuitState::HalfOpen => {
70                let calls = *self.half_open_calls.read().await;
71                if calls >= self.config.half_open_max_calls {
72                    warn!("Circuit breaker half-open limit reached");
73                    return Err(self.create_circuit_open_error());
74                }
75                *self.half_open_calls.write().await += 1;
76            }
77            CircuitState::Closed => {}
78        }
79
80        match f() {
81            Ok(result) => {
82                self.on_success().await;
83                Ok(result)
84            }
85            Err(error) => {
86                self.on_failure().await;
87                error!("Circuit breaker call failed: {}", error);
88                Err(error)
89            }
90        }
91    }
92
93    async fn get_state(&self) -> CircuitState {
94        *self.state.read().await
95    }
96
97    async fn on_success(&self) {
98        let mut state = self.state.write().await;
99        let mut success_count = self.success_count.write().await;
100        let mut failure_count = self.failure_count.write().await;
101
102        match *state {
103            CircuitState::HalfOpen => {
104                *success_count += 1;
105                if *success_count >= self.config.success_threshold {
106                    *state = CircuitState::Closed;
107                    *failure_count = 0;
108                    *success_count = 0;
109                    *self.half_open_calls.write().await = 0;
110                    info!("Circuit breaker closed after successful recovery");
111                }
112            }
113            CircuitState::Closed => {
114                *failure_count = 0;
115            }
116            _ => {}
117        }
118    }
119
120    async fn on_failure(&self) {
121        let mut state = self.state.write().await;
122        let mut failure_count = self.failure_count.write().await;
123        let mut last_failure_time = self.last_failure_time.write().await;
124
125        *failure_count += 1;
126        *last_failure_time = Some(Instant::now());
127
128        match *state {
129            CircuitState::Closed => {
130                if *failure_count >= self.config.failure_threshold {
131                    *state = CircuitState::Open;
132                    warn!("Circuit breaker opened after {} failures", failure_count);
133                }
134            }
135            CircuitState::HalfOpen => {
136                *state = CircuitState::Open;
137                *self.success_count.write().await = 0;
138                *self.half_open_calls.write().await = 0;
139                warn!("Circuit breaker reopened from half-open state");
140            }
141            _ => {}
142        }
143    }
144
145    async fn should_attempt_reset(&self) -> bool {
146        if let Some(last_failure) = *self.last_failure_time.read().await {
147            last_failure.elapsed() >= self.config.timeout
148        } else {
149            false
150        }
151    }
152
153    async fn transition_to_half_open(&self) {
154        let mut state = self.state.write().await;
155        *state = CircuitState::HalfOpen;
156        *self.half_open_calls.write().await = 0;
157        info!("Circuit breaker transitioned to half-open");
158    }
159
160    fn create_circuit_open_error<E>(&self) -> E
161    where
162        E: std::fmt::Display,
163    {
164        // This is a placeholder - in real implementation, you'd create proper error type
165        panic!("Circuit breaker is open")
166    }
167
168    pub async fn get_stats(&self) -> CircuitBreakerStats {
169        CircuitBreakerStats {
170            state: *self.state.read().await,
171            failure_count: *self.failure_count.read().await,
172            success_count: *self.success_count.read().await,
173            half_open_calls: *self.half_open_calls.read().await,
174        }
175    }
176
177    pub async fn reset(&self) {
178        *self.state.write().await = CircuitState::Closed;
179        *self.failure_count.write().await = 0;
180        *self.success_count.write().await = 0;
181        *self.last_failure_time.write().await = None;
182        *self.half_open_calls.write().await = 0;
183        debug!("Circuit breaker manually reset");
184    }
185}
186
187#[derive(Debug, Clone)]
188pub struct CircuitBreakerStats {
189    pub state: CircuitState,
190    pub failure_count: u32,
191    pub success_count: u32,
192    pub half_open_calls: u32,
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Small thresholds and a short timeout so every transition is reachable quickly.
    fn fast_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 2,
            timeout: Duration::from_millis(100),
            half_open_max_calls: 3,
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_transitions() {
        let cb = CircuitBreaker::new(fast_config());

        // Initially closed
        assert_eq!(cb.get_state().await, CircuitState::Closed);

        // Simulate failures to open the circuit
        for _ in 0..2 {
            cb.on_failure().await;
        }
        assert_eq!(cb.get_state().await, CircuitState::Open);

        // After the timeout, the next call is admitted as a half-open trial.
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(cb.should_attempt_reset().await);
        assert_eq!(cb.call(|| Ok::<u32, String>(1)).await, Ok(1));
        assert_eq!(cb.get_state().await, CircuitState::HalfOpen);

        // Reaching success_threshold closes the circuit and clears every counter.
        assert_eq!(cb.call(|| Ok::<u32, String>(2)).await, Ok(2));
        let stats = cb.get_stats().await;
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.failure_count, 0);
        assert_eq!(stats.success_count, 0);
        assert_eq!(stats.half_open_calls, 0);
    }

    #[tokio::test]
    async fn test_half_open_failure_reopens() {
        let cb = CircuitBreaker::new(fast_config());
        for _ in 0..2 {
            cb.on_failure().await;
        }
        tokio::time::sleep(Duration::from_millis(150)).await;

        let outcome = cb.call(|| Err::<u32, String>("boom".to_string())).await;
        assert_eq!(outcome, Err("boom".to_string()));
        assert_eq!(cb.get_state().await, CircuitState::Open);
        // The reopening failure restarts the timeout.
        assert!(!cb.should_attempt_reset().await);
    }

    #[tokio::test]
    async fn test_success_resets_failure_streak() {
        let cb = CircuitBreaker::new(fast_config());

        cb.on_failure().await;
        assert_eq!(cb.call(|| Ok::<(), String>(())).await, Ok(()));
        cb.on_failure().await;

        // The two failures were not consecutive, so the threshold of 2 was never reached.
        assert_eq!(cb.get_state().await, CircuitState::Closed);
        assert_eq!(cb.get_stats().await.failure_count, 1);
    }

    #[tokio::test]
    async fn test_circuit_breaker_stats() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig::default());

        let stats = cb.get_stats().await;
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.failure_count, 0);
        assert_eq!(stats.success_count, 0);
        assert_eq!(stats.half_open_calls, 0);
    }
}