1use crate::{Error, Result};
2use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::RwLock;
6use tracing::{debug, info, warn, instrument};
7
/// Operating mode of a [`CircuitBreaker`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: every request is let through.
    Closed,
    /// Tripped: requests are rejected until the open timeout has elapsed.
    Open,
    /// Probing: trial requests are let through to test whether the
    /// dependency has recovered.
    HalfOpen,
}
18
19impl std::fmt::Display for CircuitState {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 match self {
22 CircuitState::Closed => write!(f, "CLOSED"),
23 CircuitState::Open => write!(f, "OPEN"),
24 CircuitState::HalfOpen => write!(f, "HALF_OPEN"),
25 }
26 }
27}
28
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of failures inside `failure_time_window` that trips a closed
    /// circuit open.
    pub failure_threshold: u32,
    /// Time since the last state change after which an open circuit admits a
    /// trial request.
    pub timeout: Duration,
    /// Number of successful trial requests needed to close a half-open circuit.
    pub success_threshold: u32,
    /// Length of the sliding window over which recent failures are counted.
    pub failure_time_window: Duration,
}
41
42impl Default for CircuitBreakerConfig {
43 fn default() -> Self {
44 Self {
45 failure_threshold: 5,
46 timeout: Duration::from_secs(30),
47 success_threshold: 3,
48 failure_time_window: Duration::from_secs(60),
49 }
50 }
51}
52
53#[derive(Debug)]
55pub struct CircuitBreaker {
56 config: CircuitBreakerConfig,
57 state: Arc<RwLock<CircuitState>>,
58 failure_count: AtomicU64,
59 success_count: AtomicU64,
60 last_failure_time: Arc<RwLock<Option<Instant>>>,
61 last_state_change: Arc<RwLock<Instant>>,
62 recent_failures: Arc<RwLock<Vec<Instant>>>,
64}
65
66impl CircuitBreaker {
67 pub fn new(config: CircuitBreakerConfig) -> Self {
68 Self {
69 config,
70 state: Arc::new(RwLock::new(CircuitState::Closed)),
71 failure_count: AtomicU64::new(0),
72 success_count: AtomicU64::new(0),
73 last_failure_time: Arc::new(RwLock::new(None)),
74 last_state_change: Arc::new(RwLock::new(Instant::now())),
75 recent_failures: Arc::new(RwLock::new(Vec::new())),
76 }
77 }
78
79 #[instrument(skip(self))]
81 pub async fn can_proceed(&self) -> bool {
82 let current_state = *self.state.read().await;
83
84 match current_state {
85 CircuitState::Closed => {
86 debug!("Circuit closed, allowing request");
87 true
88 }
89 CircuitState::Open => {
90 let should_attempt_reset = self.should_attempt_reset().await;
91 if should_attempt_reset {
92 self.transition_to_half_open().await;
93 debug!("Circuit transitioning to half-open, allowing request");
94 true
95 } else {
96 debug!("Circuit open, rejecting request");
97 false
98 }
99 }
100 CircuitState::HalfOpen => {
101 debug!("Circuit half-open, allowing limited request");
102 true
103 }
104 }
105 }
106
107 #[instrument(skip(self))]
109 pub async fn record_success(&self) {
110 let current_state = *self.state.read().await;
111 self.success_count.fetch_add(1, Ordering::Relaxed);
112
113 debug!(state = %current_state, "Recording success");
114
115 match current_state {
116 CircuitState::HalfOpen => {
117 let success_count = self.success_count.load(Ordering::Relaxed);
118 if success_count >= self.config.success_threshold as u64 {
119 self.transition_to_closed().await;
120 }
121 }
122 CircuitState::Open => {
123 warn!("Unexpected success in open circuit state");
126 }
127 CircuitState::Closed => {
128 }
130 }
131 }
132
133 #[instrument(skip(self), fields(error = %error))]
135 pub async fn record_failure<E: std::fmt::Display>(&self, error: E) {
136 let now = Instant::now();
137 let current_state = *self.state.read().await;
138
139 debug!(state = %current_state, error = %error, "Recording failure");
140
141 self.failure_count.fetch_add(1, Ordering::Relaxed);
142 *self.last_failure_time.write().await = Some(now);
143
144 {
146 let mut recent_failures = self.recent_failures.write().await;
147 recent_failures.push(now);
148
149 let cutoff = now - self.config.failure_time_window;
151 recent_failures.retain(|&failure_time| failure_time > cutoff);
152 }
153
154 let recent_failure_count = self.recent_failures.read().await.len();
156
157 if current_state != CircuitState::Open
158 && recent_failure_count >= self.config.failure_threshold as usize {
159 self.transition_to_open().await;
160 }
161 }
162
163 pub async fn stats(&self) -> CircuitBreakerStats {
165 let state = *self.state.read().await;
166 let recent_failures = self.recent_failures.read().await.len();
167 let last_state_change = *self.last_state_change.read().await;
168
169 CircuitBreakerStats {
170 state,
171 total_failures: self.failure_count.load(Ordering::Relaxed),
172 total_successes: self.success_count.load(Ordering::Relaxed),
173 recent_failures: recent_failures as u64,
174 time_in_current_state: last_state_change.elapsed(),
175 failure_rate: self.calculate_failure_rate().await,
176 }
177 }
178
179 #[instrument(skip(self))]
181 pub async fn reset(&self) {
182 info!("Manually resetting circuit breaker");
183 self.transition_to_closed().await;
184 self.failure_count.store(0, Ordering::Relaxed);
185 self.success_count.store(0, Ordering::Relaxed);
186 *self.last_failure_time.write().await = None;
187 self.recent_failures.write().await.clear();
188 }
189
190 async fn should_attempt_reset(&self) -> bool {
192 let last_state_change = *self.last_state_change.read().await;
193 last_state_change.elapsed() >= self.config.timeout
194 }
195
196 async fn transition_to_closed(&self) {
198 let mut state = self.state.write().await;
199 if *state != CircuitState::Closed {
200 info!(previous_state = %*state, "Circuit breaker transitioning to CLOSED");
201 *state = CircuitState::Closed;
202 *self.last_state_change.write().await = Instant::now();
203 self.success_count.store(0, Ordering::Relaxed);
204 }
205 }
206
207 async fn transition_to_open(&self) {
209 let mut state = self.state.write().await;
210 if *state != CircuitState::Open {
211 warn!(previous_state = %*state, "Circuit breaker transitioning to OPEN");
212 *state = CircuitState::Open;
213 *self.last_state_change.write().await = Instant::now();
214 }
215 }
216
217 async fn transition_to_half_open(&self) {
219 let mut state = self.state.write().await;
220 if *state != CircuitState::HalfOpen {
221 info!(previous_state = %*state, "Circuit breaker transitioning to HALF_OPEN");
222 *state = CircuitState::HalfOpen;
223 *self.last_state_change.write().await = Instant::now();
224 self.success_count.store(0, Ordering::Relaxed);
225 }
226 }
227
228 async fn calculate_failure_rate(&self) -> f64 {
230 let recent_failures = self.recent_failures.read().await;
231 let total_failures = self.failure_count.load(Ordering::Relaxed);
232 let total_successes = self.success_count.load(Ordering::Relaxed);
233 let total_requests = total_failures + total_successes;
234
235 if total_requests == 0 {
236 0.0
237 } else {
238 (recent_failures.len() as f64) / (total_requests as f64) * 100.0
239 }
240 }
241}
242
243#[derive(Debug, Clone)]
245pub struct CircuitBreakerStats {
246 pub state: CircuitState,
247 pub total_failures: u64,
248 pub total_successes: u64,
249 pub recent_failures: u64,
250 pub time_in_current_state: Duration,
251 pub failure_rate: f64,
252}
253
254impl CircuitBreakerStats {
255 pub fn is_healthy(&self) -> bool {
257 matches!(self.state, CircuitState::Closed) && self.failure_rate < 5.0
258 }
259
260 pub fn status_string(&self) -> String {
262 format!(
263 "Circuit: {} | Failures: {}/{} | Rate: {:.1}% | Uptime: {}s",
264 self.state,
265 self.recent_failures,
266 self.total_failures + self.total_successes,
267 self.failure_rate,
268 self.time_in_current_state.as_secs()
269 )
270 }
271}
272
273pub struct CircuitBreakerExecutor {
275 circuit_breaker: Arc<CircuitBreaker>,
276}
277
278impl CircuitBreakerExecutor {
279 pub fn new(config: CircuitBreakerConfig) -> Self {
280 Self {
281 circuit_breaker: Arc::new(CircuitBreaker::new(config)),
282 }
283 }
284
285 #[instrument(skip(self, operation))]
287 pub async fn execute<F, T, E>(&self, operation: F) -> Result<T>
288 where
289 F: std::future::Future<Output = std::result::Result<T, E>>,
290 E: std::fmt::Display + std::error::Error + Send + Sync + 'static,
291 {
292 if !self.circuit_breaker.can_proceed().await {
294 return Err(Error::Config("Circuit breaker is open, request rejected".to_string()));
295 }
296
297 match operation.await {
299 Ok(result) => {
300 self.circuit_breaker.record_success().await;
301 Ok(result)
302 }
303 Err(error) => {
304 self.circuit_breaker.record_failure(&error).await;
305 Err(Error::Config(format!("Operation failed: {}", error)))
306 }
307 }
308 }
309
310 pub async fn stats(&self) -> CircuitBreakerStats {
312 self.circuit_breaker.stats().await
313 }
314
315 pub async fn reset(&self) {
317 self.circuit_breaker.reset().await
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// Config with a short open timeout so tests can reach HALF_OPEN quickly.
    fn fast_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_millis(100),
            success_threshold: 1,
            failure_time_window: Duration::from_secs(10),
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_states() {
        let cb = CircuitBreaker::new(fast_config());

        // A new breaker starts closed.
        assert!(cb.can_proceed().await);

        // Reaching the failure threshold opens the circuit.
        cb.record_failure("test error 1").await;
        cb.record_failure("test error 2").await;
        assert!(!cb.can_proceed().await);

        // After the timeout a trial request is admitted (HALF_OPEN).
        sleep(Duration::from_millis(150)).await;
        assert!(cb.can_proceed().await);

        // One success meets success_threshold and closes the circuit.
        cb.record_success().await;
        assert!(cb.can_proceed().await);

        let stats = cb.stats().await;
        assert_eq!(stats.state, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_half_open_failure_reopens() {
        let cb = CircuitBreaker::new(fast_config());

        cb.record_failure("e1").await;
        cb.record_failure("e2").await;
        assert!(!cb.can_proceed().await);

        sleep(Duration::from_millis(150)).await;
        assert!(cb.can_proceed().await);
        assert_eq!(cb.stats().await.state, CircuitState::HalfOpen);

        // A failed trial request must trip the breaker again.
        cb.record_failure("e3").await;
        assert_eq!(cb.stats().await.state, CircuitState::Open);
        assert!(!cb.can_proceed().await);
    }

    #[tokio::test]
    async fn test_manual_reset() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            ..Default::default()
        });

        cb.record_failure("boom").await;
        assert_eq!(cb.stats().await.state, CircuitState::Open);

        cb.reset().await;
        let stats = cb.stats().await;
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.total_failures, 0);
        assert_eq!(stats.recent_failures, 0);
        assert!(cb.can_proceed().await);
    }

    #[tokio::test]
    async fn test_executor_rejects_when_open() {
        let exec = CircuitBreakerExecutor::new(CircuitBreakerConfig {
            failure_threshold: 1,
            ..Default::default()
        });

        let failed = exec
            .execute(async {
                Err::<u32, std::io::Error>(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "down",
                ))
            })
            .await;
        assert!(failed.is_err());
        assert_eq!(exec.stats().await.state, CircuitState::Open);

        // While open (default 30 s timeout), even a would-succeed call is rejected.
        let rejected = exec.execute(async { Ok::<u32, std::io::Error>(1) }).await;
        assert!(rejected.is_err());
    }
}