guts_node/resilience/mod.rs

//! # Resilience Module
//!
//! Production-grade resilience patterns including:
//!
//! - **Retry Policy**: Configurable retry with exponential backoff
//! - **Circuit Breaker**: Fail-fast for failing services
//! - **Timeout Management**: Request and operation timeouts
//! - **Rate Limiting**: Token-bucket request throttling
//!
//! ## Usage
//!
//! ```rust,no_run
//! use guts_node::resilience::{RetryPolicy, CircuitBreaker, TimeoutConfig, RateLimiter};
//! use std::time::Duration;
//!
//! let retry = RetryPolicy::default();
//! let circuit_breaker = CircuitBreaker::new(5, 3, Duration::from_secs(30));
//! let timeout = TimeoutConfig::default();
//! let rate_limiter = RateLimiter::new(100.0);
//! ```

use parking_lot::RwLock;
use std::future::Future;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};

/// Error kinds that can be retried.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryableError {
    /// Network timeout.
    Timeout,
    /// Connection failed.
    ConnectionFailed,
    /// Service temporarily unavailable.
    ServiceUnavailable,
    /// Rate limited.
    RateLimited,
}

/// Retry policy configuration.
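///
/// All fields are public, so a custom policy can be built with struct-update
/// syntax from the defaults:
///
/// ```rust,no_run
/// use guts_node::resilience::RetryPolicy;
/// use std::time::Duration;
///
/// let policy = RetryPolicy {
///     max_attempts: 5,
///     initial_delay: Duration::from_millis(50),
///     ..RetryPolicy::default()
/// };
/// ```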
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of attempts, including the initial call.
    pub max_attempts: u32,
    /// Initial delay between retries.
    pub initial_delay: Duration,
    /// Maximum delay between retries.
    pub max_delay: Duration,
    /// Backoff multiplier.
    pub multiplier: f64,
    /// Whether to add jitter to delays.
    pub jitter: bool,
}

impl Default for RetryPolicy {
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            multiplier: 2.0,
            jitter: true,
        }
    }
}

impl RetryPolicy {
    /// Create a new retry policy.
    pub fn new(max_attempts: u32, initial_delay: Duration) -> Self {
        Self {
            max_attempts,
            initial_delay,
            ..Default::default()
        }
    }

    /// Calculate delay for a given attempt number.
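    ///
    /// With jitter disabled the delay is `initial_delay * multiplier^(attempt - 1)`,
    /// capped at `max_delay`; attempt `0` yields no delay. For example:
    ///
    /// ```rust,no_run
    /// use guts_node::resilience::RetryPolicy;
    /// use std::time::Duration;
    ///
    /// let policy = RetryPolicy { jitter: false, ..RetryPolicy::default() };
    /// assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(100));
    /// assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(200));
    /// assert_eq!(policy.delay_for_attempt(3), Duration::from_millis(400));
    /// ```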
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }

        let base_delay_ms = self.initial_delay.as_millis() as f64;
        let delay_ms = base_delay_ms * self.multiplier.powi(attempt as i32 - 1);
        let capped_delay =
            Duration::from_millis(delay_ms.min(self.max_delay.as_millis() as f64) as u64);

        if self.jitter {
            // Add up to 25% jitter (note: this can push the delay slightly past `max_delay`)
            let jitter_factor = 1.0 + (rand::random::<f64>() * 0.25);
            Duration::from_millis((capped_delay.as_millis() as f64 * jitter_factor) as u64)
        } else {
            capped_delay
        }
    }

    /// Execute an async operation, retrying failed attempts until it succeeds
    /// or `max_attempts` total attempts have been made.
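    ///
    /// A minimal usage sketch (assumes a tokio runtime; `fetch_value` is a
    /// hypothetical fallible async operation):
    ///
    /// ```rust,no_run
    /// use guts_node::resilience::RetryPolicy;
    ///
    /// async fn fetch_value() -> Result<u32, std::io::Error> {
    ///     // ... some fallible network call ...
    ///     Ok(42)
    /// }
    ///
    /// # async fn demo() -> Result<(), std::io::Error> {
    /// let policy = RetryPolicy::default();
    /// // The closure is re-invoked for every attempt.
    /// let value = policy.execute(|| fetch_value()).await?;
    /// # Ok(())
    /// # }
    /// ```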
    pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        E: std::fmt::Debug,
    {
        let mut attempt = 0;

        loop {
            attempt += 1;

            match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    if attempt >= self.max_attempts {
                        tracing::warn!(
                            attempt = attempt,
                            max_attempts = self.max_attempts,
                            error = ?e,
                            "Retry exhausted"
                        );
                        return Err(e);
                    }

                    let delay = self.delay_for_attempt(attempt);
                    tracing::debug!(
                        attempt = attempt,
                        delay_ms = delay.as_millis(),
                        error = ?e,
                        "Retrying after delay"
                    );

                    tokio::time::sleep(delay).await;
                }
            }
        }
    }
}

/// Circuit breaker state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Circuit is closed, requests flow normally.
    Closed,
    /// Circuit is open, requests fail immediately.
    Open,
    /// Circuit is testing if the service has recovered.
    HalfOpen,
}

/// Circuit breaker for failing services.
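///
/// After `failure_threshold` consecutive failures the circuit opens and
/// rejects requests; once `timeout` elapses, the next state check moves it
/// to half-open, and `success_threshold` successes close it again. A minimal
/// sketch of the open transition:
///
/// ```rust,no_run
/// use guts_node::resilience::{CircuitBreaker, CircuitState};
/// use std::time::Duration;
///
/// let cb = CircuitBreaker::new(2, 1, Duration::from_secs(10));
/// assert!(cb.allow_request());
/// cb.record_failure();
/// cb.record_failure(); // second failure reaches the threshold
/// assert_eq!(cb.state(), CircuitState::Open);
/// assert!(!cb.allow_request());
/// ```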
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Number of failures before opening.
    failure_threshold: u32,
    /// Number of successes needed to close from half-open.
    success_threshold: u32,
    /// How long to stay open before transitioning to half-open.
    timeout: Duration,
    /// Current state.
    state: RwLock<CircuitState>,
    /// Current failure count.
    failure_count: AtomicU32,
    /// Current success count (in half-open state).
    success_count: AtomicU32,
    /// When the circuit was opened.
    opened_at: RwLock<Option<Instant>>,
}

impl CircuitBreaker {
    /// Create a new circuit breaker.
    pub fn new(failure_threshold: u32, success_threshold: u32, timeout: Duration) -> Self {
        Self {
            failure_threshold,
            success_threshold,
            timeout,
            state: RwLock::new(CircuitState::Closed),
            failure_count: AtomicU32::new(0),
            success_count: AtomicU32::new(0),
            opened_at: RwLock::new(None),
        }
    }

    /// Get the current state.
    pub fn state(&self) -> CircuitState {
        self.maybe_transition_from_open();
        *self.state.read()
    }

    /// Check if requests should be allowed.
    pub fn allow_request(&self) -> bool {
        self.maybe_transition_from_open();
        let state = *self.state.read();
        matches!(state, CircuitState::Closed | CircuitState::HalfOpen)
    }

    /// Record a successful request.
    pub fn record_success(&self) {
        let state = *self.state.read();
        match state {
            CircuitState::Closed => {
                // Reset failure count on success
                self.failure_count.store(0, Ordering::SeqCst);
            }
            CircuitState::HalfOpen => {
                let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
                if count >= self.success_threshold {
                    self.close();
                }
            }
            CircuitState::Open => {}
        }
    }

    /// Record a failed request.
    pub fn record_failure(&self) {
        let state = *self.state.read();
        match state {
            CircuitState::Closed => {
                let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
                if count >= self.failure_threshold {
                    self.open();
                }
            }
            CircuitState::HalfOpen => {
                // Any failure in half-open state opens the circuit
                self.open();
            }
            CircuitState::Open => {}
        }
    }

    /// Open the circuit.
    fn open(&self) {
        tracing::warn!("Circuit breaker opened");
        *self.state.write() = CircuitState::Open;
        *self.opened_at.write() = Some(Instant::now());
        self.success_count.store(0, Ordering::SeqCst);
    }

    /// Close the circuit.
    fn close(&self) {
        tracing::info!("Circuit breaker closed");
        *self.state.write() = CircuitState::Closed;
        self.failure_count.store(0, Ordering::SeqCst);
        self.success_count.store(0, Ordering::SeqCst);
        *self.opened_at.write() = None;
    }

    /// Check if we should transition from open to half-open.
    fn maybe_transition_from_open(&self) {
        let state = *self.state.read();
        if state != CircuitState::Open {
            return;
        }

        if let Some(opened_at) = *self.opened_at.read() {
            if opened_at.elapsed() >= self.timeout {
                tracing::info!("Circuit breaker transitioning to half-open");
                *self.state.write() = CircuitState::HalfOpen;
                self.success_count.store(0, Ordering::SeqCst);
            }
        }
    }

    /// Execute an operation with circuit breaker protection.
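    ///
    /// If the circuit is open the operation is not invoked and
    /// [`CircuitBreakerError::Open`] is returned immediately. A minimal sketch
    /// (assumes a tokio runtime; `call_service` is a hypothetical async call):
    ///
    /// ```rust,no_run
    /// use guts_node::resilience::{CircuitBreaker, CircuitBreakerError};
    /// use std::time::Duration;
    ///
    /// async fn call_service() -> Result<String, std::io::Error> {
    ///     Ok("ok".to_string())
    /// }
    ///
    /// # async fn demo() {
    /// let cb = CircuitBreaker::new(5, 3, Duration::from_secs(30));
    /// match cb.execute(|| call_service()).await {
    ///     Ok(body) => println!("response: {body}"),
    ///     Err(CircuitBreakerError::Open) => println!("circuit open, request skipped"),
    ///     Err(CircuitBreakerError::Inner(e)) => println!("request failed: {e}"),
    /// }
    /// # }
    /// ```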
    pub async fn execute<F, Fut, T, E>(&self, operation: F) -> Result<T, CircuitBreakerError<E>>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
    {
        if !self.allow_request() {
            return Err(CircuitBreakerError::Open);
        }

        match operation().await {
            Ok(result) => {
                self.record_success();
                Ok(result)
            }
            Err(e) => {
                self.record_failure();
                Err(CircuitBreakerError::Inner(e))
            }
        }
    }
}

/// Circuit breaker error wrapper.
#[derive(Debug)]
pub enum CircuitBreakerError<E> {
    /// Circuit is open, request was not attempted.
    Open,
    /// The inner operation failed.
    Inner(E),
}

impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Open => write!(f, "Circuit breaker is open"),
            Self::Inner(e) => write!(f, "{}", e),
        }
    }
}

impl<E: std::error::Error> std::error::Error for CircuitBreakerError<E> {}

/// Timeout configuration.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Connection timeout.
    pub connect: Duration,
    /// Read timeout.
    pub read: Duration,
    /// Write timeout.
    pub write: Duration,
    /// Total operation timeout.
    pub total: Duration,
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        Self {
            connect: Duration::from_secs(5),
            read: Duration::from_secs(30),
            write: Duration::from_secs(30),
            total: Duration::from_secs(60),
        }
    }
}

impl TimeoutConfig {
    /// Create a new timeout config.
    pub fn new(connect: Duration, total: Duration) -> Self {
        Self {
            connect,
            read: total,
            write: total,
            total,
        }
    }

    /// Execute an operation, failing with [`TimeoutError`] if it exceeds the `total` timeout.
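    ///
    /// Only the `total` timeout is enforced by this method; the per-phase
    /// `connect`, `read`, and `write` values are left for callers to apply.
    /// A minimal sketch (assumes a tokio runtime; `slow_operation` is
    /// hypothetical):
    ///
    /// ```rust,no_run
    /// use guts_node::resilience::TimeoutConfig;
    ///
    /// async fn slow_operation() -> u32 {
    ///     42
    /// }
    ///
    /// # async fn demo() {
    /// let timeouts = TimeoutConfig::default();
    /// match timeouts.execute(|| slow_operation()).await {
    ///     Ok(v) => println!("finished: {v}"),
    ///     Err(e) => println!("{e}"), // e.g. "Operation timed out after 60s"
    /// }
    /// # }
    /// ```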
    pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T, TimeoutError>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = T>,
    {
        match tokio::time::timeout(self.total, operation()).await {
            Ok(result) => Ok(result),
            Err(_) => Err(TimeoutError {
                timeout: self.total,
            }),
        }
    }
}

/// Timeout error.
#[derive(Debug)]
pub struct TimeoutError {
    /// The timeout duration that was exceeded.
    pub timeout: Duration,
}

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Operation timed out after {:?}", self.timeout)
    }
}

impl std::error::Error for TimeoutError {}

/// Rate limiter using a token bucket algorithm.
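///
/// A minimal sketch of gating requests:
///
/// ```rust,no_run
/// use guts_node::resilience::RateLimiter;
///
/// let limiter = RateLimiter::new(100.0); // allow ~100 requests per second
/// if limiter.try_acquire() {
///     // token acquired: proceed with the request
/// } else {
///     // rate limited: reject or back off
/// }
/// ```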
#[derive(Debug)]
pub struct RateLimiter {
    /// Maximum tokens (requests) per window.
    max_tokens: u32,
    /// Current tokens available.
    tokens: AtomicU32,
    /// Last refill time (unix timestamp in millis).
    last_refill: AtomicU64,
    /// Refill rate (tokens per second).
    refill_rate: f64,
}

impl RateLimiter {
    /// Create a new rate limiter.
    pub fn new(requests_per_second: f64) -> Self {
        let max_tokens = (requests_per_second.ceil() as u32).max(1);
        Self {
            max_tokens,
            tokens: AtomicU32::new(max_tokens),
            last_refill: AtomicU64::new(
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_millis() as u64,
            ),
            refill_rate: requests_per_second,
        }
    }

    /// Try to acquire a token (permission to make a request).
    pub fn try_acquire(&self) -> bool {
        self.refill();

        loop {
            let current = self.tokens.load(Ordering::SeqCst);
            if current == 0 {
                return false;
            }
            if self
                .tokens
                .compare_exchange(current, current - 1, Ordering::SeqCst, Ordering::SeqCst)
                .is_ok()
            {
                return true;
            }
        }
    }

    /// Refill tokens based on elapsed time.
    fn refill(&self) {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64;

        let last = self.last_refill.load(Ordering::SeqCst);
        let elapsed_ms = now.saturating_sub(last);
        let elapsed_secs = elapsed_ms as f64 / 1000.0;
        let tokens_to_add = (elapsed_secs * self.refill_rate) as u32;

        if tokens_to_add > 0
            && self
                .last_refill
                .compare_exchange(last, now, Ordering::SeqCst, Ordering::SeqCst)
                .is_ok()
        {
            // Use fetch_update so a concurrent `try_acquire` between a plain
            // load and store cannot be overwritten (a lost-update race).
            let _ = self
                .tokens
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                    Some(current.saturating_add(tokens_to_add).min(self.max_tokens))
                });
        }
    }

    /// Get remaining tokens.
    pub fn remaining(&self) -> u32 {
        self.refill();
        self.tokens.load(Ordering::SeqCst)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_policy_delay() {
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            multiplier: 2.0,
            jitter: false,
        };

        assert_eq!(policy.delay_for_attempt(0), Duration::ZERO);
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(100));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(200));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_millis(400));
    }

    #[test]
    fn test_circuit_breaker_states() {
        let cb = CircuitBreaker::new(2, 1, Duration::from_millis(100));

        // Initially closed
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.allow_request());

        // Record failures
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allow_request());

        // Wait for timeout
        std::thread::sleep(Duration::from_millis(150));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        assert!(cb.allow_request());

        // Record success
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new(10.0);

        // Should be able to acquire tokens
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }

        // Should be rate limited
        assert!(!limiter.try_acquire());
    }
}