//! `pforge_runtime` — recovery.rs
//!
//! Fault-tolerance primitives: circuit breaker, fallback handling,
//! error tracking, and a recovery middleware that ties them together.

1use crate::{Error, Middleware, Result};
2use serde_json::Value;
3use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7
/// The three operating modes of a circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: requests flow through.
    Closed,
    /// Too many failures: requests are rejected outright.
    Open,
    /// Probing after the cool-down: requests are let through to test recovery.
    HalfOpen,
}
15
/// Tunables governing circuit-breaker behavior.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures tolerated before the circuit opens.
    pub failure_threshold: usize,
    /// Cool-down period an open circuit waits before allowing a probe.
    pub timeout: Duration,
    /// Successes required while half-open to close the circuit again.
    pub success_threshold: usize,
}

impl Default for CircuitBreakerConfig {
    /// Defaults: open after 5 failures, probe after 60 s, close after 2 successes.
    fn default() -> Self {
        CircuitBreakerConfig {
            failure_threshold: 5,
            success_threshold: 2,
            timeout: Duration::from_secs(60),
        }
    }
}
36
/// Circuit breaker for fault tolerance
pub struct CircuitBreaker {
    // Thresholds and cool-down timing; see `CircuitBreakerConfig`.
    config: CircuitBreakerConfig,
    // Current state, shared behind an async RwLock.
    state: Arc<RwLock<CircuitState>>,
    // Consecutive failures while Closed; reset to 0 on a successful call.
    failure_count: Arc<AtomicUsize>,
    // Successes accumulated while HalfOpen; reset when probing begins.
    success_count: Arc<AtomicUsize>,
    // Instant at which the circuit last opened; drives the cool-down check.
    last_failure_time: Arc<RwLock<Option<Instant>>>,
}
45
46impl CircuitBreaker {
47    pub fn new(config: CircuitBreakerConfig) -> Self {
48        Self {
49            config,
50            state: Arc::new(RwLock::new(CircuitState::Closed)),
51            failure_count: Arc::new(AtomicUsize::new(0)),
52            success_count: Arc::new(AtomicUsize::new(0)),
53            last_failure_time: Arc::new(RwLock::new(None)),
54        }
55    }
56
57    pub async fn get_state(&self) -> CircuitState {
58        *self.state.read().await
59    }
60
61    pub async fn call<F, Fut, T>(&self, operation: F) -> Result<T>
62    where
63        F: FnOnce() -> Fut,
64        Fut: std::future::Future<Output = Result<T>>,
65    {
66        // Check if we should attempt the operation
67        let current_state = self.get_state().await;
68
69        match current_state {
70            CircuitState::Open => {
71                // Check if timeout has elapsed
72                if let Some(last_failure) = *self.last_failure_time.read().await {
73                    if last_failure.elapsed() >= self.config.timeout {
74                        // Transition to half-open
75                        *self.state.write().await = CircuitState::HalfOpen;
76                        self.success_count.store(0, Ordering::SeqCst);
77                    } else {
78                        return Err(Error::Handler("Circuit breaker is OPEN".to_string()));
79                    }
80                }
81            }
82            _ => {}
83        }
84
85        // Attempt the operation
86        match operation().await {
87            Ok(result) => {
88                self.on_success().await;
89                Ok(result)
90            }
91            Err(error) => {
92                self.on_failure().await;
93                Err(error)
94            }
95        }
96    }
97
98    async fn on_success(&self) {
99        let state = self.get_state().await;
100
101        match state {
102            CircuitState::HalfOpen => {
103                let successes = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
104                if successes >= self.config.success_threshold {
105                    *self.state.write().await = CircuitState::Closed;
106                    self.failure_count.store(0, Ordering::SeqCst);
107                    self.success_count.store(0, Ordering::SeqCst);
108                }
109            }
110            CircuitState::Closed => {
111                // Reset failure count on success
112                self.failure_count.store(0, Ordering::SeqCst);
113            }
114            _ => {}
115        }
116    }
117
118    async fn on_failure(&self) {
119        let state = self.get_state().await;
120
121        match state {
122            CircuitState::Closed => {
123                let failures = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
124                if failures >= self.config.failure_threshold {
125                    *self.state.write().await = CircuitState::Open;
126                    *self.last_failure_time.write().await = Some(Instant::now());
127                }
128            }
129            CircuitState::HalfOpen => {
130                // Any failure in half-open state immediately opens circuit
131                *self.state.write().await = CircuitState::Open;
132                *self.last_failure_time.write().await = Some(Instant::now());
133                self.failure_count.store(self.config.failure_threshold, Ordering::SeqCst);
134            }
135            _ => {}
136        }
137    }
138
139    pub fn get_stats(&self) -> CircuitBreakerStats {
140        CircuitBreakerStats {
141            failure_count: self.failure_count.load(Ordering::SeqCst),
142            success_count: self.success_count.load(Ordering::SeqCst),
143        }
144    }
145}
146
/// Snapshot of a circuit breaker's internal counters.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
    /// Current failure count.
    pub failure_count: usize,
    /// Current success count.
    pub success_count: usize,
}
152
153/// Fallback handler for error recovery
154pub struct FallbackHandler<F, Fut>
155where
156    F: Fn(Error) -> Fut + Send + Sync,
157    Fut: std::future::Future<Output = Result<Value>> + Send,
158{
159    fallback_fn: F,
160    _phantom: std::marker::PhantomData<Fut>,
161}
162
163impl<F, Fut> FallbackHandler<F, Fut>
164where
165    F: Fn(Error) -> Fut + Send + Sync,
166    Fut: std::future::Future<Output = Result<Value>> + Send,
167{
168    pub fn new(fallback_fn: F) -> Self {
169        Self {
170            fallback_fn,
171            _phantom: std::marker::PhantomData,
172        }
173    }
174
175    pub async fn handle_error(&self, error: Error) -> Result<Value> {
176        (self.fallback_fn)(error).await
177    }
178}
179
/// Error tracking for monitoring and debugging
pub struct ErrorTracker {
    // Running total of every error tracked; monotonic until reset().
    total_errors: Arc<AtomicU64>,
    // Per-category counts keyed by the classification string produced by
    // classify_error(): "timeout", "connection", "handler_error", "unknown".
    errors_by_type: Arc<RwLock<std::collections::HashMap<String, u64>>>,
}
185
186impl ErrorTracker {
187    pub fn new() -> Self {
188        Self {
189            total_errors: Arc::new(AtomicU64::new(0)),
190            errors_by_type: Arc::new(RwLock::new(std::collections::HashMap::new())),
191        }
192    }
193
194    pub async fn track_error(&self, error: &Error) {
195        self.total_errors.fetch_add(1, Ordering::SeqCst);
196
197        let error_type = self.classify_error(error);
198        let mut errors = self.errors_by_type.write().await;
199        *errors.entry(error_type).or_insert(0) += 1;
200    }
201
202    fn classify_error(&self, error: &Error) -> String {
203        match error {
204            Error::Handler(msg) => {
205                if msg.contains("timeout") || msg.contains("timed out") {
206                    "timeout".to_string()
207                } else if msg.contains("connection") {
208                    "connection".to_string()
209                } else {
210                    "handler_error".to_string()
211                }
212            }
213            _ => "unknown".to_string(),
214        }
215    }
216
217    pub fn total_errors(&self) -> u64 {
218        self.total_errors.load(Ordering::SeqCst)
219    }
220
221    pub async fn errors_by_type(&self) -> std::collections::HashMap<String, u64> {
222        self.errors_by_type.read().await.clone()
223    }
224
225    pub async fn reset(&self) {
226        self.total_errors.store(0, Ordering::SeqCst);
227        self.errors_by_type.write().await.clear();
228    }
229}
230
231impl Default for ErrorTracker {
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
/// Recovery middleware - integrates circuit breaker and fallback
pub struct RecoveryMiddleware {
    // Optional breaker; when None the middleware only tracks errors.
    circuit_breaker: Option<Arc<CircuitBreaker>>,
    // Shared error statistics, also handed out via error_tracker().
    error_tracker: Arc<ErrorTracker>,
}
242
243impl RecoveryMiddleware {
244    pub fn new() -> Self {
245        Self {
246            circuit_breaker: None,
247            error_tracker: Arc::new(ErrorTracker::new()),
248        }
249    }
250
251    pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
252        self.circuit_breaker = Some(Arc::new(CircuitBreaker::new(config)));
253        self
254    }
255
256    pub fn error_tracker(&self) -> Arc<ErrorTracker> {
257        self.error_tracker.clone()
258    }
259}
260
261impl Default for RecoveryMiddleware {
262    fn default() -> Self {
263        Self::new()
264    }
265}
266
267#[async_trait::async_trait]
268impl Middleware for RecoveryMiddleware {
269    async fn before(&self, request: Value) -> Result<Value> {
270        // Check circuit breaker before processing
271        if let Some(cb) = &self.circuit_breaker {
272            let state = cb.get_state().await;
273            if state == CircuitState::Open {
274                return Err(Error::Handler("Circuit breaker is OPEN - service unavailable".to_string()));
275            }
276        }
277        Ok(request)
278    }
279
280    async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
281        // Track the error
282        self.error_tracker.track_error(&error).await;
283
284        // Record failure in circuit breaker
285        if let Some(cb) = &self.circuit_breaker {
286            cb.on_failure().await;
287        }
288
289        Err(error)
290    }
291
292    async fn after(&self, _request: Value, response: Value) -> Result<Value> {
293        // Record success in circuit breaker
294        if let Some(cb) = &self.circuit_breaker {
295            cb.on_success().await;
296        }
297
298        Ok(response)
299    }
300}
301
302
#[cfg(test)]
mod tests {
    use super::*;

    // Closed -> Open: reaching failure_threshold consecutive failures trips the breaker.
    #[tokio::test]
    async fn test_circuit_breaker_closed_to_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            timeout: Duration::from_secs(1),
            success_threshold: 2,
        };

        let cb = CircuitBreaker::new(config);

        // Initially closed
        assert_eq!(cb.get_state().await, CircuitState::Closed);

        // 3 failures should open circuit
        for _ in 0..3 {
            let _ = cb
                .call(|| async { Err::<(), _>(Error::Handler("test error".to_string())) })
                .await;
        }

        assert_eq!(cb.get_state().await, CircuitState::Open);
    }

    // Open -> HalfOpen -> Closed: after the cool-down, success_threshold
    // successful probes close the circuit again.
    #[tokio::test]
    async fn test_circuit_breaker_half_open_recovery() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_millis(100),
            success_threshold: 2,
        };

        let cb = CircuitBreaker::new(config);

        // Open the circuit
        for _ in 0..2 {
            let _ = cb
                .call(|| async { Err::<(), _>(Error::Handler("test error".to_string())) })
                .await;
        }

        assert_eq!(cb.get_state().await, CircuitState::Open);

        // Wait for timeout
        tokio::time::sleep(Duration::from_millis(150)).await;

        // Next call should transition to half-open
        let _ = cb.call(|| async { Ok::<_, Error>(42) }).await;
        assert_eq!(cb.get_state().await, CircuitState::HalfOpen);

        // One more success should close circuit
        let _ = cb.call(|| async { Ok::<_, Error>(42) }).await;
        assert_eq!(cb.get_state().await, CircuitState::Closed);
    }

    // While Open and before the cool-down, calls are rejected without running.
    #[tokio::test]
    async fn test_circuit_breaker_rejects_when_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            timeout: Duration::from_secs(60),
            success_threshold: 2,
        };

        let cb = CircuitBreaker::new(config);

        // Open the circuit
        let _ = cb
            .call(|| async { Err::<(), _>(Error::Handler("test error".to_string())) })
            .await;

        // Should reject immediately
        let result = cb.call(|| async { Ok::<_, Error>(42) }).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Circuit breaker is OPEN"));
    }

    // classify_error buckets by message substring; totals and per-type counts agree.
    #[tokio::test]
    async fn test_error_tracker() {
        let tracker = ErrorTracker::new();

        // Track different errors
        tracker.track_error(&Error::Handler("timeout error".to_string())).await;
        tracker.track_error(&Error::Handler("timeout error".to_string())).await;
        tracker.track_error(&Error::Handler("connection error".to_string())).await;
        tracker.track_error(&Error::Handler("other error".to_string())).await;

        assert_eq!(tracker.total_errors(), 4);

        let by_type = tracker.errors_by_type().await;
        assert_eq!(by_type.get("timeout"), Some(&2));
        assert_eq!(by_type.get("connection"), Some(&1));
        assert_eq!(by_type.get("handler_error"), Some(&1));
    }

    // The fallback closure receives the error and supplies a substitute value.
    #[tokio::test]
    async fn test_fallback_handler() {
        let fallback = FallbackHandler::new(|error: Error| async move {
            // Return default value on error
            let _ = error;
            Ok(serde_json::json!({"fallback": true}))
        });

        let result = fallback
            .handle_error(Error::Handler("test".to_string()))
            .await
            .unwrap();

        assert_eq!(result["fallback"], true);
    }

    // End-to-end: on_error feeds the tracker and the breaker; once the breaker
    // opens, before() rejects subsequent requests (60 s timeout keeps it open).
    #[tokio::test]
    async fn test_recovery_middleware_integration() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_secs(60),
            success_threshold: 2,
        };

        let middleware = RecoveryMiddleware::new().with_circuit_breaker(config);
        let tracker = middleware.error_tracker();

        // Simulate failures
        let _ = middleware
            .on_error(
                serde_json::json!({}),
                Error::Handler("test error".to_string()),
            )
            .await;

        let _ = middleware
            .on_error(
                serde_json::json!({}),
                Error::Handler("test error".to_string()),
            )
            .await;

        // Check error tracking
        assert_eq!(tracker.total_errors(), 2);

        // Circuit should be open, before hook should fail
        let result = middleware.before(serde_json::json!({})).await;
        assert!(result.is_err());
    }
}