1use crate::{Error, Middleware, Result};
2use serde_json::Value;
3use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7
/// Operating mode of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls are admitted and failures are counted toward tripping the breaker.
    Closed,
    /// Calls are rejected until the configured cool-down has elapsed.
    Open,
    /// Trial calls are admitted; enough successes close the breaker, a failure re-opens it.
    HalfOpen,
}
15
/// Tuning parameters for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failures counted while `Closed` (any success resets the count) that trip the breaker open.
    pub failure_threshold: usize,
    /// How long the breaker stays `Open` before a trial call is admitted.
    pub timeout: Duration,
    /// Successes required while `HalfOpen` before the breaker closes again.
    pub success_threshold: usize,
}

impl Default for CircuitBreakerConfig {
    /// Five failures to open, a 60-second cool-down, two successes to close.
    fn default() -> Self {
        let failure_threshold = 5;
        let success_threshold = 2;
        Self {
            failure_threshold,
            timeout: Duration::from_secs(60),
            success_threshold,
        }
    }
}
36
37pub struct CircuitBreaker {
39 config: CircuitBreakerConfig,
40 state: Arc<RwLock<CircuitState>>,
41 failure_count: Arc<AtomicUsize>,
42 success_count: Arc<AtomicUsize>,
43 last_failure_time: Arc<RwLock<Option<Instant>>>,
44}
45
46impl CircuitBreaker {
47 pub fn new(config: CircuitBreakerConfig) -> Self {
48 Self {
49 config,
50 state: Arc::new(RwLock::new(CircuitState::Closed)),
51 failure_count: Arc::new(AtomicUsize::new(0)),
52 success_count: Arc::new(AtomicUsize::new(0)),
53 last_failure_time: Arc::new(RwLock::new(None)),
54 }
55 }
56
57 pub async fn get_state(&self) -> CircuitState {
58 *self.state.read().await
59 }
60
61 pub async fn call<F, Fut, T>(&self, operation: F) -> Result<T>
62 where
63 F: FnOnce() -> Fut,
64 Fut: std::future::Future<Output = Result<T>>,
65 {
66 let current_state = self.get_state().await;
68
69 match current_state {
70 CircuitState::Open => {
71 if let Some(last_failure) = *self.last_failure_time.read().await {
73 if last_failure.elapsed() >= self.config.timeout {
74 *self.state.write().await = CircuitState::HalfOpen;
76 self.success_count.store(0, Ordering::SeqCst);
77 } else {
78 return Err(Error::Handler("Circuit breaker is OPEN".to_string()));
79 }
80 }
81 }
82 _ => {}
83 }
84
85 match operation().await {
87 Ok(result) => {
88 self.on_success().await;
89 Ok(result)
90 }
91 Err(error) => {
92 self.on_failure().await;
93 Err(error)
94 }
95 }
96 }
97
98 async fn on_success(&self) {
99 let state = self.get_state().await;
100
101 match state {
102 CircuitState::HalfOpen => {
103 let successes = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
104 if successes >= self.config.success_threshold {
105 *self.state.write().await = CircuitState::Closed;
106 self.failure_count.store(0, Ordering::SeqCst);
107 self.success_count.store(0, Ordering::SeqCst);
108 }
109 }
110 CircuitState::Closed => {
111 self.failure_count.store(0, Ordering::SeqCst);
113 }
114 _ => {}
115 }
116 }
117
118 async fn on_failure(&self) {
119 let state = self.get_state().await;
120
121 match state {
122 CircuitState::Closed => {
123 let failures = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
124 if failures >= self.config.failure_threshold {
125 *self.state.write().await = CircuitState::Open;
126 *self.last_failure_time.write().await = Some(Instant::now());
127 }
128 }
129 CircuitState::HalfOpen => {
130 *self.state.write().await = CircuitState::Open;
132 *self.last_failure_time.write().await = Some(Instant::now());
133 self.failure_count.store(self.config.failure_threshold, Ordering::SeqCst);
134 }
135 _ => {}
136 }
137 }
138
139 pub fn get_stats(&self) -> CircuitBreakerStats {
140 CircuitBreakerStats {
141 failure_count: self.failure_count.load(Ordering::SeqCst),
142 success_count: self.success_count.load(Ordering::SeqCst),
143 }
144 }
145}
146
/// Point-in-time copy of a breaker's counters, as produced by `get_stats`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
    /// Value of the failure counter when the snapshot was taken.
    pub failure_count: usize,
    /// Value of the half-open success counter when the snapshot was taken.
    pub success_count: usize,
}
152
/// Wraps an async closure that turns an error into a substitute response.
///
/// Trait bounds live on the `impl` block rather than on the struct. The
/// phantom marker is `fn() -> Fut` (not `Fut`) so the handler's `Send`/`Sync`
/// depend only on `F`: futures are frequently `!Sync`, and `PhantomData<Fut>`
/// would needlessly make the whole handler `!Sync`, preventing it from being
/// shared behind an `Arc` across tasks.
pub struct FallbackHandler<F, Fut> {
    fallback_fn: F,
    _phantom: std::marker::PhantomData<fn() -> Fut>,
}
162
163impl<F, Fut> FallbackHandler<F, Fut>
164where
165 F: Fn(Error) -> Fut + Send + Sync,
166 Fut: std::future::Future<Output = Result<Value>> + Send,
167{
168 pub fn new(fallback_fn: F) -> Self {
169 Self {
170 fallback_fn,
171 _phantom: std::marker::PhantomData,
172 }
173 }
174
175 pub async fn handle_error(&self, error: Error) -> Result<Value> {
176 (self.fallback_fn)(error).await
177 }
178}
179
180pub struct ErrorTracker {
182 total_errors: Arc<AtomicU64>,
183 errors_by_type: Arc<RwLock<std::collections::HashMap<String, u64>>>,
184}
185
186impl ErrorTracker {
187 pub fn new() -> Self {
188 Self {
189 total_errors: Arc::new(AtomicU64::new(0)),
190 errors_by_type: Arc::new(RwLock::new(std::collections::HashMap::new())),
191 }
192 }
193
194 pub async fn track_error(&self, error: &Error) {
195 self.total_errors.fetch_add(1, Ordering::SeqCst);
196
197 let error_type = self.classify_error(error);
198 let mut errors = self.errors_by_type.write().await;
199 *errors.entry(error_type).or_insert(0) += 1;
200 }
201
202 fn classify_error(&self, error: &Error) -> String {
203 match error {
204 Error::Handler(msg) => {
205 if msg.contains("timeout") || msg.contains("timed out") {
206 "timeout".to_string()
207 } else if msg.contains("connection") {
208 "connection".to_string()
209 } else {
210 "handler_error".to_string()
211 }
212 }
213 _ => "unknown".to_string(),
214 }
215 }
216
217 pub fn total_errors(&self) -> u64 {
218 self.total_errors.load(Ordering::SeqCst)
219 }
220
221 pub async fn errors_by_type(&self) -> std::collections::HashMap<String, u64> {
222 self.errors_by_type.read().await.clone()
223 }
224
225 pub async fn reset(&self) {
226 self.total_errors.store(0, Ordering::SeqCst);
227 self.errors_by_type.write().await.clear();
228 }
229}
230
231impl Default for ErrorTracker {
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
237pub struct RecoveryMiddleware {
239 circuit_breaker: Option<Arc<CircuitBreaker>>,
240 error_tracker: Arc<ErrorTracker>,
241}
242
243impl RecoveryMiddleware {
244 pub fn new() -> Self {
245 Self {
246 circuit_breaker: None,
247 error_tracker: Arc::new(ErrorTracker::new()),
248 }
249 }
250
251 pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
252 self.circuit_breaker = Some(Arc::new(CircuitBreaker::new(config)));
253 self
254 }
255
256 pub fn error_tracker(&self) -> Arc<ErrorTracker> {
257 self.error_tracker.clone()
258 }
259}
260
261impl Default for RecoveryMiddleware {
262 fn default() -> Self {
263 Self::new()
264 }
265}
266
267#[async_trait::async_trait]
268impl Middleware for RecoveryMiddleware {
269 async fn before(&self, request: Value) -> Result<Value> {
270 if let Some(cb) = &self.circuit_breaker {
272 let state = cb.get_state().await;
273 if state == CircuitState::Open {
274 return Err(Error::Handler("Circuit breaker is OPEN - service unavailable".to_string()));
275 }
276 }
277 Ok(request)
278 }
279
280 async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
281 self.error_tracker.track_error(&error).await;
283
284 if let Some(cb) = &self.circuit_breaker {
286 cb.on_failure().await;
287 }
288
289 Err(error)
290 }
291
292 async fn after(&self, _request: Value, response: Value) -> Result<Value> {
293 if let Some(cb) = &self.circuit_breaker {
295 cb.on_success().await;
296 }
297
298 Ok(response)
299 }
300}
301
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config without repeating field names in every test.
    fn config(
        failure_threshold: usize,
        timeout: Duration,
        success_threshold: usize,
    ) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold,
            timeout,
            success_threshold,
        }
    }

    /// Drives `cb` through `times` failing calls, discarding the results.
    async fn fail_times(cb: &CircuitBreaker, times: usize) {
        for _ in 0..times {
            let _ = cb
                .call(|| async { Err::<(), _>(Error::Handler("test error".to_string())) })
                .await;
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_closed_to_open() {
        let cb = CircuitBreaker::new(config(3, Duration::from_secs(1), 2));
        assert_eq!(cb.get_state().await, CircuitState::Closed);

        fail_times(&cb, 3).await;
        assert_eq!(cb.get_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_recovery() {
        let cb = CircuitBreaker::new(config(2, Duration::from_millis(100), 2));
        fail_times(&cb, 2).await;
        assert_eq!(cb.get_state().await, CircuitState::Open);

        // Wait past the 100 ms cool-down so the next call is admitted as a probe.
        tokio::time::sleep(Duration::from_millis(150)).await;

        // First success: still probing.
        let _ = cb.call(|| async { Ok::<_, Error>(42) }).await;
        assert_eq!(cb.get_state().await, CircuitState::HalfOpen);

        // Second success meets success_threshold and closes the circuit.
        let _ = cb.call(|| async { Ok::<_, Error>(42) }).await;
        assert_eq!(cb.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_rejects_when_open() {
        let cb = CircuitBreaker::new(config(1, Duration::from_secs(60), 2));
        fail_times(&cb, 1).await;

        let rejected = cb.call(|| async { Ok::<_, Error>(42) }).await;
        let err = rejected.expect_err("an open breaker must short-circuit the call");
        assert!(err.to_string().contains("Circuit breaker is OPEN"));
    }

    #[tokio::test]
    async fn test_error_tracker() {
        let tracker = ErrorTracker::new();
        for msg in ["timeout error", "timeout error", "connection error", "other error"] {
            tracker.track_error(&Error::Handler(msg.to_string())).await;
        }

        assert_eq!(tracker.total_errors(), 4);

        let by_type = tracker.errors_by_type().await;
        for (kind, count) in [("timeout", 2), ("connection", 1), ("handler_error", 1)] {
            assert_eq!(by_type.get(kind), Some(&count), "bucket {}", kind);
        }
    }

    #[tokio::test]
    async fn test_fallback_handler() {
        let fallback = FallbackHandler::new(|error: Error| async move {
            let _ = error;
            Ok(serde_json::json!({"fallback": true}))
        });

        let value = fallback
            .handle_error(Error::Handler("test".to_string()))
            .await
            .unwrap();
        assert_eq!(value["fallback"], true);
    }

    #[tokio::test]
    async fn test_recovery_middleware_integration() {
        let middleware =
            RecoveryMiddleware::new().with_circuit_breaker(config(2, Duration::from_secs(60), 2));
        let tracker = middleware.error_tracker();

        for _ in 0..2 {
            let _ = middleware
                .on_error(
                    serde_json::json!({}),
                    Error::Handler("test error".to_string()),
                )
                .await;
        }

        assert_eq!(tracker.total_errors(), 2);

        // Two failures reached the threshold, so the next request is refused up front.
        assert!(middleware.before(serde_json::json!({})).await.is_err());
    }
}