ipfrs_storage/
retry.rs

//! Retry logic with exponential backoff and jitter
//!
//! Provides sophisticated retry strategies for handling transient failures:
//! - Exponential backoff to prevent overwhelming failing services
//! - Jitter to prevent thundering herd problems
//! - Configurable max attempts and timeouts
//! - Retry condition predicates
//!
//! ## Example
//! ```no_run
//! use ipfrs_storage::RetryPolicy;
//! use std::time::Duration;
//!
//! async fn flaky_operation() -> Result<String, std::io::Error> {
//!     // Your operation that might fail transiently
//!     Ok("success".to_string())
//! }
//!
//! #[tokio::main]
//! async fn main() {
//!     let policy = RetryPolicy::exponential(
//!         Duration::from_millis(100),
//!         3
//!     );
//!
//!     let result = policy.retry(|| flaky_operation()).await;
//!     println!("Result: {:?}", result);
//! }
//! ```

31use anyhow::{anyhow, Result};
32use serde::{Deserialize, Serialize};
33use std::future::Future;
34use std::time::Duration;
35use tokio::time::sleep;
36
37/// Backoff strategy for retries
/// Backoff strategy for retries.
///
/// Selects how the wait time grows between successive attempts; the
/// concrete delay is computed by `RetryPolicy::calculate_delay`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackoffStrategy {
    /// Fixed delay between retries: every wait equals `base_delay`.
    Fixed,
    /// Exponential backoff: the delay is multiplied by
    /// `backoff_multiplier` for each further attempt.
    Exponential,
    /// Linear backoff: the delay grows by `base_delay` per attempt
    /// (`base_delay * attempt`).
    Linear,
}
47
48/// Jitter type to add randomness to backoff
/// Jitter type to add randomness to backoff delays.
///
/// Randomizing the delay spreads retries from many clients over time so
/// they do not all hit a recovering service at once (thundering herd).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum JitterType {
    /// No jitter: use the computed delay exactly.
    None,
    /// Full jitter: uniform random value in [0, computed delay].
    Full,
    /// Equal jitter: half the computed delay plus a random half,
    /// i.e. uniform in [delay/2, delay].
    Equal,
    /// Decorrelated jitter (AWS recommended): random value between the
    /// base delay and three times the previous delay, capped at `max_delay`.
    Decorrelated,
}
60
61/// Retry policy configuration
/// Retry policy configuration.
///
/// Controls how many times an operation is attempted, how long to wait
/// between attempts, and how that wait is randomized. Construct with
/// [`RetryPolicy::exponential`], [`RetryPolicy::fixed`] or
/// [`RetryPolicy::linear`], then customize via the `with_*` builders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of attempts (including initial attempt)
    pub max_attempts: u32,
    /// Base delay for backoff; the starting point for every strategy
    pub base_delay: Duration,
    /// Maximum delay between retries; computed delays are capped here
    pub max_delay: Duration,
    /// Backoff strategy (how the delay grows across attempts)
    pub strategy: BackoffStrategy,
    /// Jitter type (randomization applied to the computed delay)
    pub jitter: JitterType,
    /// Multiplier for exponential backoff (default: 2.0)
    pub backoff_multiplier: f64,
    /// Overall wall-clock timeout across all retry attempts, if any
    pub total_timeout: Option<Duration>,
}
79
80impl RetryPolicy {
81    /// Create a new retry policy with exponential backoff
82    ///
83    /// # Arguments
84    /// * `base_delay` - Initial delay between retries
85    /// * `max_attempts` - Maximum number of attempts
86    pub fn exponential(base_delay: Duration, max_attempts: u32) -> Self {
87        Self {
88            max_attempts,
89            base_delay,
90            max_delay: Duration::from_secs(60),
91            strategy: BackoffStrategy::Exponential,
92            jitter: JitterType::Equal,
93            backoff_multiplier: 2.0,
94            total_timeout: None,
95        }
96    }
97
98    /// Create a retry policy with fixed delays
99    pub fn fixed(delay: Duration, max_attempts: u32) -> Self {
100        Self {
101            max_attempts,
102            base_delay: delay,
103            max_delay: delay,
104            strategy: BackoffStrategy::Fixed,
105            jitter: JitterType::None,
106            backoff_multiplier: 1.0,
107            total_timeout: None,
108        }
109    }
110
111    /// Create a retry policy with linear backoff
112    pub fn linear(base_delay: Duration, max_attempts: u32) -> Self {
113        Self {
114            max_attempts,
115            base_delay,
116            max_delay: Duration::from_secs(60),
117            strategy: BackoffStrategy::Linear,
118            jitter: JitterType::Equal,
119            backoff_multiplier: 1.0,
120            total_timeout: None,
121        }
122    }
123
124    /// Set maximum delay
125    pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
126        self.max_delay = max_delay;
127        self
128    }
129
130    /// Set jitter type
131    pub fn with_jitter(mut self, jitter: JitterType) -> Self {
132        self.jitter = jitter;
133        self
134    }
135
136    /// Set total timeout
137    pub fn with_timeout(mut self, timeout: Duration) -> Self {
138        self.total_timeout = Some(timeout);
139        self
140    }
141
142    /// Set backoff multiplier
143    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
144        self.backoff_multiplier = multiplier;
145        self
146    }
147
148    /// Calculate delay for a given attempt number
149    fn calculate_delay(&self, attempt: u32) -> Duration {
150        if attempt == 0 {
151            return Duration::from_secs(0);
152        }
153
154        let base_ms = self.base_delay.as_millis() as f64;
155
156        let computed_delay_ms = match self.strategy {
157            BackoffStrategy::Fixed => base_ms,
158            BackoffStrategy::Exponential => {
159                base_ms * self.backoff_multiplier.powi(attempt as i32 - 1)
160            }
161            BackoffStrategy::Linear => base_ms * attempt as f64,
162        };
163
164        // Cap at max delay
165        let capped_ms = computed_delay_ms.min(self.max_delay.as_millis() as f64);
166
167        // Apply jitter
168        let final_ms = match self.jitter {
169            JitterType::None => capped_ms,
170            JitterType::Full => {
171                // Random value between 0 and computed delay
172                fastrand::f64() * capped_ms
173            }
174            JitterType::Equal => {
175                // Half of computed delay + random half
176                capped_ms / 2.0 + (fastrand::f64() * capped_ms / 2.0)
177            }
178            JitterType::Decorrelated => {
179                // AWS recommended: min(max_delay, random(base, last_delay * 3))
180                let last_delay = if attempt > 1 {
181                    self.calculate_delay(attempt - 1).as_millis() as f64
182                } else {
183                    base_ms
184                };
185                let random_delay = base_ms + (fastrand::f64() * (last_delay * 3.0 - base_ms));
186                random_delay.min(self.max_delay.as_millis() as f64)
187            }
188        };
189
190        Duration::from_millis(final_ms as u64)
191    }
192
193    /// Execute a function with retry logic
194    ///
195    /// # Arguments
196    /// * `f` - Function to retry
197    ///
198    /// # Returns
199    /// Result of the function or last error
200    pub async fn retry<F, Fut, T, E>(&self, mut f: F) -> Result<T>
201    where
202        F: FnMut() -> Fut,
203        Fut: Future<Output = Result<T, E>>,
204        E: std::error::Error + Send + Sync + 'static,
205    {
206        let start_time = std::time::Instant::now();
207        let mut last_error = None;
208
209        for attempt in 0..self.max_attempts {
210            // Check total timeout
211            if let Some(timeout) = self.total_timeout {
212                if start_time.elapsed() >= timeout {
213                    return Err(anyhow!("Retry timeout exceeded after {attempt} attempts"));
214                }
215            }
216
217            // Try the operation
218            match f().await {
219                Ok(result) => return Ok(result),
220                Err(e) => {
221                    last_error = Some(e);
222
223                    // Don't sleep after the last attempt
224                    if attempt + 1 < self.max_attempts {
225                        let delay = self.calculate_delay(attempt + 1);
226                        sleep(delay).await;
227                    }
228                }
229            }
230        }
231
232        // All attempts failed
233        if let Some(e) = last_error {
234            Err(anyhow!(
235                "Operation failed after {} attempts: {}",
236                self.max_attempts,
237                e
238            ))
239        } else {
240            Err(anyhow!(
241                "Operation failed after {} attempts",
242                self.max_attempts
243            ))
244        }
245    }
246}
247
248impl Default for RetryPolicy {
249    fn default() -> Self {
250        Self::exponential(Duration::from_millis(100), 3)
251    }
252}
253
254/// Trait for retryable operations
/// Trait for retryable operations.
///
/// Implementors run themselves under the supplied [`RetryPolicy`],
/// resolving once the operation succeeds or its attempts are exhausted.
pub trait Retryable<T, E> {
    /// Execute with retry policy, returning the eventual result.
    fn with_retry(self, policy: RetryPolicy) -> impl Future<Output = Result<T>>;
}
259
260/// Retry statistics
/// Retry statistics.
///
/// Aggregate counters describing retry activity. NOTE(review): nothing in
/// this module populates these fields — `RetryPolicy::retry` does not
/// record stats — so callers are presumably expected to maintain them;
/// confirm against the rest of the crate.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RetryStats {
    /// Total retry attempts made
    pub total_attempts: u64,
    /// Successful operations
    pub successful_ops: u64,
    /// Failed operations (after all retries)
    pub failed_ops: u64,
    /// Total delay time spent in retries (milliseconds)
    pub total_delay_ms: u64,
}
272
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    // NOTE: delay-calculation tests below are synchronous (`#[test]`);
    // only tests that actually drive `retry()` need a tokio runtime.

    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let policy = RetryPolicy::exponential(Duration::from_millis(10), 3);

        let result = policy
            .retry(|| async { Ok::<_, std::io::Error>("success") })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let counter = Arc::new(AtomicU32::new(0));
        let policy = RetryPolicy::exponential(Duration::from_millis(10), 3);

        let counter_clone = counter.clone();
        let result = policy
            .retry(|| {
                let c = counter_clone.clone();
                async move {
                    let count = c.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        // Fail the first two attempts, then succeed
                        Err(std::io::Error::new(
                            std::io::ErrorKind::Other,
                            "Transient failure",
                        ))
                    } else {
                        Ok("success")
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_all_attempts_fail() {
        let policy = RetryPolicy::exponential(Duration::from_millis(10), 3);

        let result = policy
            .retry(|| async {
                Err::<&str, std::io::Error>(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "Always fails",
                ))
            })
            .await;

        assert!(result.is_err());
    }

    #[test]
    fn test_fixed_backoff() {
        let policy = RetryPolicy::fixed(Duration::from_millis(50), 3);

        for i in 1..=3 {
            let delay = policy.calculate_delay(i);
            assert_eq!(delay.as_millis(), 50);
        }
    }

    #[test]
    fn test_exponential_backoff() {
        let policy = RetryPolicy::exponential(Duration::from_millis(100), 4);

        // Exponential growth (without jitter for this test)
        let policy_no_jitter = policy.with_jitter(JitterType::None);
        let d1 = policy_no_jitter.calculate_delay(1).as_millis();
        let d2 = policy_no_jitter.calculate_delay(2).as_millis();
        let d3 = policy_no_jitter.calculate_delay(3).as_millis();

        assert_eq!(d1, 100);
        assert_eq!(d2, 200);
        assert_eq!(d3, 400);
    }

    #[test]
    fn test_linear_backoff() {
        let policy =
            RetryPolicy::linear(Duration::from_millis(100), 4).with_jitter(JitterType::None);

        let d1 = policy.calculate_delay(1).as_millis();
        let d2 = policy.calculate_delay(2).as_millis();
        let d3 = policy.calculate_delay(3).as_millis();

        assert_eq!(d1, 100);
        assert_eq!(d2, 200);
        assert_eq!(d3, 300);
    }

    #[test]
    fn test_max_delay_cap() {
        let policy = RetryPolicy::exponential(Duration::from_millis(100), 10)
            .with_max_delay(Duration::from_millis(500))
            .with_jitter(JitterType::None);

        let delay = policy.calculate_delay(5);
        assert!(delay.as_millis() <= 500);
    }

    #[test]
    fn test_jitter_full() {
        let policy =
            RetryPolicy::exponential(Duration::from_millis(100), 3).with_jitter(JitterType::Full);

        // With full jitter, delay should be between 0 and computed delay
        for _ in 0..10 {
            let delay = policy.calculate_delay(1);
            assert!(delay.as_millis() <= 100);
        }
    }

    #[test]
    fn test_jitter_equal() {
        let policy =
            RetryPolicy::exponential(Duration::from_millis(100), 3).with_jitter(JitterType::Equal);

        // With equal jitter, delay should be between 50 and 100
        for _ in 0..10 {
            let delay = policy.calculate_delay(1);
            let ms = delay.as_millis();
            assert!(ms >= 50 && ms <= 100);
        }
    }

    #[tokio::test]
    async fn test_timeout() {
        let policy = RetryPolicy::exponential(Duration::from_millis(50), 10)
            .with_timeout(Duration::from_millis(150));

        let start = std::time::Instant::now();
        let result = policy
            .retry(|| async {
                Err::<&str, std::io::Error>(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "Always fails",
                ))
            })
            .await;

        let elapsed = start.elapsed();
        assert!(result.is_err());
        // Should timeout before all retries complete, but allow some margin
        assert!(elapsed < Duration::from_millis(500));
        assert!(elapsed >= Duration::from_millis(150));
    }
}