pforge_runtime/recovery.rs

use crate::{Error, Middleware, Result};
use serde_json::Value;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

/// Circuit breaker states
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    Closed,   // Normal operation
    Open,     // Failing, reject requests
    HalfOpen, // Testing if service recovered
}
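
// How the three states connect (a summary of the transition logic
// implemented by `call`, `on_success`, and `on_failure` below):
//
//   Closed   --[failure_threshold consecutive failures]--> Open
//   Open     --[timeout elapsed; next call probes]-------> HalfOpen
//   HalfOpen --[success_threshold successes]-------------> Closed
//   HalfOpen --[any failure]-----------------------------> Open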

/// Circuit breaker configuration
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures before opening circuit
    pub failure_threshold: usize,
    /// Time to wait before attempting recovery
    pub timeout: Duration,
    /// Number of successes needed to close circuit
    pub success_threshold: usize,
}

impl Default for CircuitBreakerConfig {
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            timeout: Duration::from_secs(60),
            success_threshold: 2,
        }
    }
}
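
// A minimal configuration sketch: since `CircuitBreakerConfig` implements
// `Default` and all fields are public, struct-update syntax lets callers
// override a single field (the value below is illustrative, not a
// recommendation):
//
//     let config = CircuitBreakerConfig {
//         failure_threshold: 10,
//         ..Default::default()
//     };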

/// Circuit breaker for fault tolerance
pub struct CircuitBreaker {
    config: CircuitBreakerConfig,
    state: Arc<RwLock<CircuitState>>,
    failure_count: Arc<AtomicUsize>,
    success_count: Arc<AtomicUsize>,
    last_failure_time: Arc<RwLock<Option<Instant>>>,
}

impl CircuitBreaker {
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            state: Arc::new(RwLock::new(CircuitState::Closed)),
            failure_count: Arc::new(AtomicUsize::new(0)),
            success_count: Arc::new(AtomicUsize::new(0)),
            last_failure_time: Arc::new(RwLock::new(None)),
        }
    }

    pub async fn get_state(&self) -> CircuitState {
        *self.state.read().await
    }

    pub async fn call<F, Fut, T>(&self, operation: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T>>,
    {
        // Check if we should attempt the operation
        let current_state = self.get_state().await;

        if current_state == CircuitState::Open {
            // Check if the recovery timeout has elapsed. `last_failure_time`
            // is always set when the circuit opens, so if it is `None` we
            // simply fall through and attempt the operation.
            if let Some(last_failure) = *self.last_failure_time.read().await {
                if last_failure.elapsed() >= self.config.timeout {
                    // Transition to half-open and start counting probe successes
                    *self.state.write().await = CircuitState::HalfOpen;
                    self.success_count.store(0, Ordering::SeqCst);
                } else {
                    return Err(Error::Handler("Circuit breaker is OPEN".to_string()));
                }
            }
        }

        // Attempt the operation
        match operation().await {
            Ok(result) => {
                self.on_success().await;
                Ok(result)
            }
            Err(error) => {
                self.on_failure().await;
                Err(error)
            }
        }
    }

    async fn on_success(&self) {
        let state = self.get_state().await;

        match state {
            CircuitState::HalfOpen => {
                let successes = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
                if successes >= self.config.success_threshold {
                    *self.state.write().await = CircuitState::Closed;
                    self.failure_count.store(0, Ordering::SeqCst);
                    self.success_count.store(0, Ordering::SeqCst);
                }
            }
            CircuitState::Closed => {
                // Reset failure count on success
                self.failure_count.store(0, Ordering::SeqCst);
            }
            _ => {}
        }
    }

    async fn on_failure(&self) {
        let state = self.get_state().await;

        match state {
            CircuitState::Closed => {
                let failures = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
                if failures >= self.config.failure_threshold {
                    *self.state.write().await = CircuitState::Open;
                    *self.last_failure_time.write().await = Some(Instant::now());
                }
            }
            CircuitState::HalfOpen => {
                // Any failure in half-open state immediately opens circuit
                *self.state.write().await = CircuitState::Open;
                *self.last_failure_time.write().await = Some(Instant::now());
                self.failure_count
                    .store(self.config.failure_threshold, Ordering::SeqCst);
            }
            _ => {}
        }
    }

    pub fn get_stats(&self) -> CircuitBreakerStats {
        CircuitBreakerStats {
            failure_count: self.failure_count.load(Ordering::SeqCst),
            success_count: self.success_count.load(Ordering::SeqCst),
        }
    }
}
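
// A minimal usage sketch for `call`, assuming an async (tokio) context and
// a hypothetical fallible operation `fetch_upstream` that is not part of
// this module:
//
//     let cb = CircuitBreaker::new(CircuitBreakerConfig::default());
//     match cb.call(|| async { fetch_upstream().await }).await {
//         Ok(response) => { /* use the response */ }
//         Err(e) => { /* the call failed, or the breaker rejected it */ }
//     }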

#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
    pub failure_count: usize,
    pub success_count: usize,
}
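
// Stats are a point-in-time snapshot of the counters, e.g. for a periodic
// metrics poll (illustrative):
//
//     let stats = cb.get_stats();
//     println!("failures={} successes={}", stats.failure_count, stats.success_count);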

/// Fallback handler for error recovery
pub struct FallbackHandler<F, Fut>
where
    F: Fn(Error) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<Value>> + Send,
{
    fallback_fn: F,
    _phantom: std::marker::PhantomData<Fut>,
}

impl<F, Fut> FallbackHandler<F, Fut>
where
    F: Fn(Error) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<Value>> + Send,
{
    pub fn new(fallback_fn: F) -> Self {
        Self {
            fallback_fn,
            _phantom: std::marker::PhantomData,
        }
    }

    pub async fn handle_error(&self, error: Error) -> Result<Value> {
        (self.fallback_fn)(error).await
    }
}
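
// A sketch of pairing the breaker with a fallback: if the protected call
// fails (or the circuit is open), serve a degraded default instead of
// propagating the error. `do_work` is hypothetical and returns
// `Result<Value>`:
//
//     let fallback = FallbackHandler::new(|_err: Error| async move {
//         Ok(serde_json::json!({ "cached": true }))
//     });
//     let value = match cb.call(|| async { do_work().await }).await {
//         Ok(v) => v,
//         Err(e) => fallback.handle_error(e).await?,
//     };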

/// Error tracking for monitoring and debugging
pub struct ErrorTracker {
    total_errors: Arc<AtomicU64>,
    errors_by_type: Arc<RwLock<rustc_hash::FxHashMap<String, u64>>>,
}

impl ErrorTracker {
    pub fn new() -> Self {
        Self {
            total_errors: Arc::new(AtomicU64::new(0)),
            errors_by_type: Arc::new(RwLock::new(rustc_hash::FxHashMap::default())),
        }
    }

    pub async fn track_error(&self, error: &Error) {
        self.total_errors.fetch_add(1, Ordering::SeqCst);

        let error_type = self.classify_error(error);
        let mut errors = self.errors_by_type.write().await;
        *errors.entry(error_type).or_insert(0) += 1;
    }

    fn classify_error(&self, error: &Error) -> String {
        match error {
            Error::Handler(msg) => {
                if msg.contains("timeout") || msg.contains("timed out") {
                    "timeout".to_string()
                } else if msg.contains("connection") {
                    "connection".to_string()
                } else {
                    "handler_error".to_string()
                }
            }
            _ => "unknown".to_string(),
        }
    }

    pub fn total_errors(&self) -> u64 {
        self.total_errors.load(Ordering::SeqCst)
    }

    pub async fn errors_by_type(&self) -> rustc_hash::FxHashMap<String, u64> {
        self.errors_by_type.read().await.clone()
    }

    pub async fn reset(&self) {
        self.total_errors.store(0, Ordering::SeqCst);
        self.errors_by_type.write().await.clear();
    }
}

impl Default for ErrorTracker {
    fn default() -> Self {
        Self::new()
    }
}
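
// Reading the tracker back out, e.g. for a periodic metrics dump. The
// classifier above is substring-based, so a message containing
// "connection" lands in the "connection" bucket (illustrative):
//
//     tracker.track_error(&Error::Handler("connection refused".to_string())).await;
//     assert_eq!(tracker.total_errors(), 1);
//     let by_type = tracker.errors_by_type().await; // {"connection": 1}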

/// Recovery middleware - integrates circuit breaking and error tracking
pub struct RecoveryMiddleware {
    circuit_breaker: Option<Arc<CircuitBreaker>>,
    error_tracker: Arc<ErrorTracker>,
}

impl RecoveryMiddleware {
    pub fn new() -> Self {
        Self {
            circuit_breaker: None,
            error_tracker: Arc::new(ErrorTracker::new()),
        }
    }

    pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
        self.circuit_breaker = Some(Arc::new(CircuitBreaker::new(config)));
        self
    }

    pub fn error_tracker(&self) -> Arc<ErrorTracker> {
        self.error_tracker.clone()
    }
}

impl Default for RecoveryMiddleware {
    fn default() -> Self {
        Self::new()
    }
}
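
// Typical construction: the builder consumes and returns `self`, so calls
// chain; keep an `error_tracker()` handle if you want to read counts later.
// How the middleware is registered with a chain is runtime-specific and not
// shown here:
//
//     let middleware = RecoveryMiddleware::new()
//         .with_circuit_breaker(CircuitBreakerConfig::default());
//     let tracker = middleware.error_tracker();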

#[async_trait::async_trait]
impl Middleware for RecoveryMiddleware {
    async fn before(&self, request: Value) -> Result<Value> {
        // Reject fast while the circuit is open. Note that this hook only
        // observes the state: the Open -> HalfOpen timeout transition runs
        // inside `CircuitBreaker::call`, so the breaker recovers only if
        // requests are also routed through `call` (or the state is driven
        // elsewhere).
        if let Some(cb) = &self.circuit_breaker {
            let state = cb.get_state().await;
            if state == CircuitState::Open {
                return Err(Error::Handler(
                    "Circuit breaker is OPEN - service unavailable".to_string(),
                ));
            }
        }
        Ok(request)
    }

    async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
        // Track the error
        self.error_tracker.track_error(&error).await;

        // Record failure in circuit breaker
        if let Some(cb) = &self.circuit_breaker {
            cb.on_failure().await;
        }

        Err(error)
    }

    async fn after(&self, _request: Value, response: Value) -> Result<Value> {
        // Record success in circuit breaker
        if let Some(cb) = &self.circuit_breaker {
            cb.on_success().await;
        }

        Ok(response)
    }
}
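
// Assumed hook order per request, based on the `Middleware` trait's method
// names: `before` runs ahead of the handler, then exactly one of `after`
// (on success) or `on_error` (on failure) runs afterwards.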

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_circuit_breaker_closed_to_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            timeout: Duration::from_secs(1),
            success_threshold: 2,
        };

        let cb = CircuitBreaker::new(config);

        // Initially closed
        assert_eq!(cb.get_state().await, CircuitState::Closed);

        // 3 failures should open circuit
        for _ in 0..3 {
            let _ = cb
                .call(|| async { Err::<(), _>(Error::Handler("test error".to_string())) })
                .await;
        }

        assert_eq!(cb.get_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_recovery() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_millis(100),
            success_threshold: 2,
        };

        let cb = CircuitBreaker::new(config);

        // Open the circuit
        for _ in 0..2 {
            let _ = cb
                .call(|| async { Err::<(), _>(Error::Handler("test error".to_string())) })
                .await;
        }

        assert_eq!(cb.get_state().await, CircuitState::Open);

        // Wait for timeout
        tokio::time::sleep(Duration::from_millis(150)).await;

        // Next call should transition to half-open
        let _ = cb.call(|| async { Ok::<_, Error>(42) }).await;
        assert_eq!(cb.get_state().await, CircuitState::HalfOpen);

        // One more success should close circuit
        let _ = cb.call(|| async { Ok::<_, Error>(42) }).await;
        assert_eq!(cb.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_rejects_when_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            timeout: Duration::from_secs(60),
            success_threshold: 2,
        };

        let cb = CircuitBreaker::new(config);

        // Open the circuit
        let _ = cb
            .call(|| async { Err::<(), _>(Error::Handler("test error".to_string())) })
            .await;

        // Should reject immediately
        let result = cb.call(|| async { Ok::<_, Error>(42) }).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Circuit breaker is OPEN"));
    }

    #[tokio::test]
    async fn test_error_tracker() {
        let tracker = ErrorTracker::new();

        // Track different errors
        tracker
            .track_error(&Error::Handler("timeout error".to_string()))
            .await;
        tracker
            .track_error(&Error::Handler("timeout error".to_string()))
            .await;
        tracker
            .track_error(&Error::Handler("connection error".to_string()))
            .await;
        tracker
            .track_error(&Error::Handler("other error".to_string()))
            .await;

        assert_eq!(tracker.total_errors(), 4);

        let by_type = tracker.errors_by_type().await;
        assert_eq!(by_type.get("timeout"), Some(&2));
        assert_eq!(by_type.get("connection"), Some(&1));
        assert_eq!(by_type.get("handler_error"), Some(&1));
    }

    #[tokio::test]
    async fn test_fallback_handler() {
        let fallback = FallbackHandler::new(|error: Error| async move {
            // Return default value on error
            let _ = error;
            Ok(serde_json::json!({"fallback": true}))
        });

        let result = fallback
            .handle_error(Error::Handler("test".to_string()))
            .await
            .unwrap();

        assert_eq!(result["fallback"], true);
    }

    #[tokio::test]
    async fn test_recovery_middleware_integration() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_secs(60),
            success_threshold: 2,
        };

        let middleware = RecoveryMiddleware::new().with_circuit_breaker(config);
        let tracker = middleware.error_tracker();

        // Simulate failures
        let _ = middleware
            .on_error(
                serde_json::json!({}),
                Error::Handler("test error".to_string()),
            )
            .await;

        let _ = middleware
            .on_error(
                serde_json::json!({}),
                Error::Handler("test error".to_string()),
            )
            .await;

        // Check error tracking
        assert_eq!(tracker.total_errors(), 2);

        // Circuit should be open, before hook should fail
        let result = middleware.before(serde_json::json!({})).await;
        assert!(result.is_err());
    }
}