1use crate::{Error, Middleware, Result};
2use serde_json::Value;
3use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7
/// State of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: requests run and consecutive failures are counted.
    Closed,
    /// Tripped: `CircuitBreaker::call` rejects requests until the configured timeout has
    /// elapsed since the circuit opened.
    Open,
    /// Probing: requests run; enough successes close the circuit, any failure re-opens it.
    HalfOpen,
}
15
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures while closed that trip the circuit open.
    pub failure_threshold: usize,
    /// How long the circuit stays open, measured from when it opened, before a probe is allowed.
    pub timeout: Duration,
    /// Successful probes while half-open required to close the circuit again.
    pub success_threshold: usize,
}

impl Default for CircuitBreakerConfig {
    /// Trip after 5 consecutive failures, stay open for 60 seconds, close after 2 good probes.
    fn default() -> Self {
        CircuitBreakerConfig {
            success_threshold: 2,
            timeout: Duration::from_secs(60),
            failure_threshold: 5,
        }
    }
}
36
37pub struct CircuitBreaker {
39 config: CircuitBreakerConfig,
40 state: Arc<RwLock<CircuitState>>,
41 failure_count: Arc<AtomicUsize>,
42 success_count: Arc<AtomicUsize>,
43 last_failure_time: Arc<RwLock<Option<Instant>>>,
44}
45
46impl CircuitBreaker {
47 pub fn new(config: CircuitBreakerConfig) -> Self {
48 Self {
49 config,
50 state: Arc::new(RwLock::new(CircuitState::Closed)),
51 failure_count: Arc::new(AtomicUsize::new(0)),
52 success_count: Arc::new(AtomicUsize::new(0)),
53 last_failure_time: Arc::new(RwLock::new(None)),
54 }
55 }
56
57 pub async fn get_state(&self) -> CircuitState {
58 *self.state.read().await
59 }
60
61 pub async fn call<F, Fut, T>(&self, operation: F) -> Result<T>
62 where
63 F: FnOnce() -> Fut,
64 Fut: std::future::Future<Output = Result<T>>,
65 {
66 let current_state = self.get_state().await;
68
69 if current_state == CircuitState::Open {
70 if let Some(last_failure) = *self.last_failure_time.read().await {
72 if last_failure.elapsed() >= self.config.timeout {
73 *self.state.write().await = CircuitState::HalfOpen;
75 self.success_count.store(0, Ordering::SeqCst);
76 } else {
77 return Err(Error::Handler("Circuit breaker is OPEN".to_string()));
78 }
79 }
80 }
81
82 match operation().await {
84 Ok(result) => {
85 self.on_success().await;
86 Ok(result)
87 }
88 Err(error) => {
89 self.on_failure().await;
90 Err(error)
91 }
92 }
93 }
94
95 async fn on_success(&self) {
96 let state = self.get_state().await;
97
98 match state {
99 CircuitState::HalfOpen => {
100 let successes = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
101 if successes >= self.config.success_threshold {
102 *self.state.write().await = CircuitState::Closed;
103 self.failure_count.store(0, Ordering::SeqCst);
104 self.success_count.store(0, Ordering::SeqCst);
105 }
106 }
107 CircuitState::Closed => {
108 self.failure_count.store(0, Ordering::SeqCst);
110 }
111 _ => {}
112 }
113 }
114
115 async fn on_failure(&self) {
116 let state = self.get_state().await;
117
118 match state {
119 CircuitState::Closed => {
120 let failures = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
121 if failures >= self.config.failure_threshold {
122 *self.state.write().await = CircuitState::Open;
123 *self.last_failure_time.write().await = Some(Instant::now());
124 }
125 }
126 CircuitState::HalfOpen => {
127 *self.state.write().await = CircuitState::Open;
129 *self.last_failure_time.write().await = Some(Instant::now());
130 self.failure_count
131 .store(self.config.failure_threshold, Ordering::SeqCst);
132 }
133 _ => {}
134 }
135 }
136
137 pub fn get_stats(&self) -> CircuitBreakerStats {
138 CircuitBreakerStats {
139 failure_count: self.failure_count.load(Ordering::SeqCst),
140 success_count: self.success_count.load(Ordering::SeqCst),
141 }
142 }
143}
144
/// Point-in-time copy of a circuit breaker's counters.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
    /// Failures currently counted toward tripping the circuit.
    pub failure_count: usize,
    /// Successful half-open probes currently counted toward closing the circuit.
    pub success_count: usize,
}
150
151pub struct FallbackHandler<F, Fut>
153where
154 F: Fn(Error) -> Fut + Send + Sync,
155 Fut: std::future::Future<Output = Result<Value>> + Send,
156{
157 fallback_fn: F,
158 _phantom: std::marker::PhantomData<Fut>,
159}
160
161impl<F, Fut> FallbackHandler<F, Fut>
162where
163 F: Fn(Error) -> Fut + Send + Sync,
164 Fut: std::future::Future<Output = Result<Value>> + Send,
165{
166 pub fn new(fallback_fn: F) -> Self {
167 Self {
168 fallback_fn,
169 _phantom: std::marker::PhantomData,
170 }
171 }
172
173 pub async fn handle_error(&self, error: Error) -> Result<Value> {
174 (self.fallback_fn)(error).await
175 }
176}
177
178pub struct ErrorTracker {
180 total_errors: Arc<AtomicU64>,
181 errors_by_type: Arc<RwLock<rustc_hash::FxHashMap<String, u64>>>,
182}
183
184impl ErrorTracker {
185 pub fn new() -> Self {
186 Self {
187 total_errors: Arc::new(AtomicU64::new(0)),
188 errors_by_type: Arc::new(RwLock::new(rustc_hash::FxHashMap::default())),
189 }
190 }
191
192 pub async fn track_error(&self, error: &Error) {
193 self.total_errors.fetch_add(1, Ordering::SeqCst);
194
195 let error_type = self.classify_error(error);
196 let mut errors = self.errors_by_type.write().await;
197 *errors.entry(error_type).or_insert(0) += 1;
198 }
199
200 fn classify_error(&self, error: &Error) -> String {
201 match error {
202 Error::Handler(msg) => {
203 if msg.contains("timeout") || msg.contains("timed out") {
204 "timeout".to_string()
205 } else if msg.contains("connection") {
206 "connection".to_string()
207 } else {
208 "handler_error".to_string()
209 }
210 }
211 _ => "unknown".to_string(),
212 }
213 }
214
215 pub fn total_errors(&self) -> u64 {
216 self.total_errors.load(Ordering::SeqCst)
217 }
218
219 pub async fn errors_by_type(&self) -> rustc_hash::FxHashMap<String, u64> {
220 self.errors_by_type.read().await.clone()
221 }
222
223 pub async fn reset(&self) {
224 self.total_errors.store(0, Ordering::SeqCst);
225 self.errors_by_type.write().await.clear();
226 }
227}
228
229impl Default for ErrorTracker {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
235pub struct RecoveryMiddleware {
237 circuit_breaker: Option<Arc<CircuitBreaker>>,
238 error_tracker: Arc<ErrorTracker>,
239}
240
241impl RecoveryMiddleware {
242 pub fn new() -> Self {
243 Self {
244 circuit_breaker: None,
245 error_tracker: Arc::new(ErrorTracker::new()),
246 }
247 }
248
249 pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
250 self.circuit_breaker = Some(Arc::new(CircuitBreaker::new(config)));
251 self
252 }
253
254 pub fn error_tracker(&self) -> Arc<ErrorTracker> {
255 self.error_tracker.clone()
256 }
257}
258
259impl Default for RecoveryMiddleware {
260 fn default() -> Self {
261 Self::new()
262 }
263}
264
265#[async_trait::async_trait]
266impl Middleware for RecoveryMiddleware {
267 async fn before(&self, request: Value) -> Result<Value> {
268 if let Some(cb) = &self.circuit_breaker {
270 let state = cb.get_state().await;
271 if state == CircuitState::Open {
272 return Err(Error::Handler(
273 "Circuit breaker is OPEN - service unavailable".to_string(),
274 ));
275 }
276 }
277 Ok(request)
278 }
279
280 async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
281 self.error_tracker.track_error(&error).await;
283
284 if let Some(cb) = &self.circuit_breaker {
286 cb.on_failure().await;
287 }
288
289 Err(error)
290 }
291
292 async fn after(&self, _request: Value, response: Value) -> Result<Value> {
293 if let Some(cb) = &self.circuit_breaker {
295 cb.on_success().await;
296 }
297
298 Ok(response)
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a breaker config from its three knobs.
    fn config(
        failure_threshold: usize,
        timeout: Duration,
        success_threshold: usize,
    ) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold,
            timeout,
            success_threshold,
        }
    }

    /// Runs one failing operation through `cb`, discarding the result.
    async fn fail_once(cb: &CircuitBreaker) {
        let _ = cb
            .call(|| async { Err::<(), _>(Error::Handler("test error".to_string())) })
            .await;
    }

    /// Runs one successful operation through `cb` and returns its result.
    async fn succeed_once(cb: &CircuitBreaker) -> Result<i32> {
        cb.call(|| async { Ok::<_, Error>(42) }).await
    }

    #[tokio::test]
    async fn test_circuit_breaker_closed_to_open() {
        let cb = CircuitBreaker::new(config(3, Duration::from_secs(1), 2));
        assert_eq!(cb.get_state().await, CircuitState::Closed);

        for _ in 0..3 {
            fail_once(&cb).await;
        }

        assert_eq!(cb.get_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_recovery() {
        let cb = CircuitBreaker::new(config(2, Duration::from_millis(100), 2));

        for _ in 0..2 {
            fail_once(&cb).await;
        }
        assert_eq!(cb.get_state().await, CircuitState::Open);

        // Wait past the open-state timeout so the next call becomes a probe.
        tokio::time::sleep(Duration::from_millis(150)).await;

        let _ = succeed_once(&cb).await;
        assert_eq!(cb.get_state().await, CircuitState::HalfOpen);

        let _ = succeed_once(&cb).await;
        assert_eq!(cb.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_rejects_when_open() {
        let cb = CircuitBreaker::new(config(1, Duration::from_secs(60), 2));
        fail_once(&cb).await;

        let result = succeed_once(&cb).await;
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Circuit breaker is OPEN"));
    }

    #[tokio::test]
    async fn test_error_tracker() {
        let tracker = ErrorTracker::new();

        for msg in ["timeout error", "timeout error", "connection error", "other error"] {
            tracker.track_error(&Error::Handler(msg.to_string())).await;
        }

        assert_eq!(tracker.total_errors(), 4);

        let by_type = tracker.errors_by_type().await;
        assert_eq!(by_type.get("timeout"), Some(&2));
        assert_eq!(by_type.get("connection"), Some(&1));
        assert_eq!(by_type.get("handler_error"), Some(&1));
    }

    #[tokio::test]
    async fn test_fallback_handler() {
        let fallback = FallbackHandler::new(|error: Error| async move {
            let _ = error;
            Ok(serde_json::json!({"fallback": true}))
        });

        let result = fallback
            .handle_error(Error::Handler("test".to_string()))
            .await
            .unwrap();

        assert_eq!(result["fallback"], true);
    }

    #[tokio::test]
    async fn test_recovery_middleware_integration() {
        let middleware =
            RecoveryMiddleware::new().with_circuit_breaker(config(2, Duration::from_secs(60), 2));
        let tracker = middleware.error_tracker();

        for _ in 0..2 {
            let _ = middleware
                .on_error(
                    serde_json::json!({}),
                    Error::Handler("test error".to_string()),
                )
                .await;
        }

        assert_eq!(tracker.total_errors(), 2);

        let result = middleware.before(serde_json::json!({})).await;
        assert!(result.is_err());
    }
}