riglr_core/retry.rs

//! Generic retry utilities for robust async operations
//!
//! This module provides a centralized retry mechanism with exponential backoff,
//! jitter, and error classification for any async operation.
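//!
//! # Quick start
//!
//! A minimal sketch, assuming a hypothetical `fetch_price` operation that
//! returns `Result<u64, String>`:
//!
//! ```rust,ignore
//! use riglr_core::retry::{retry_with_backoff, RetryConfig};
//!
//! async fn fetch_with_retries() -> Result<u64, String> {
//!     retry_with_backoff(
//!         || async { fetch_price().await },
//!         &RetryConfig::fast(),
//!         "fetch_price",
//!     )
//!     .await
//! }
//! ```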

use backoff::{backoff::Backoff, ExponentialBackoff, ExponentialBackoffBuilder};
use std::future::Future;
use std::time::Duration;
use tracing::{debug, warn};

/// Configuration for retry behavior
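///
/// The defaults are a reasonable starting point. A sketch of tuning the values
/// by hand (the numbers here are illustrative):
///
/// ```rust,ignore
/// use riglr_core::retry::RetryConfig;
///
/// // Start from the defaults and override only what you need.
/// let config = RetryConfig {
///     max_retries: 4,
///     base_delay_ms: 250,
///     ..RetryConfig::default()
/// };
///
/// // Or pick one of the presets below.
/// let rpc_config = RetryConfig::fast();
/// let rate_limited_config = RetryConfig::slow();
/// ```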
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts
    pub max_retries: u32,
    /// Initial retry delay in milliseconds
    pub base_delay_ms: u64,
    /// Maximum retry delay in milliseconds
    pub max_delay_ms: u64,
    /// Multiplier for exponential backoff
    pub backoff_multiplier: f64,
    /// Whether to use jitter to avoid thundering herd
    pub use_jitter: bool,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_retries: 3,
            base_delay_ms: 1000,     // Start with 1 second
            max_delay_ms: 30_000,    // Cap at 30 seconds
            backoff_multiplier: 2.0, // Double each time
            use_jitter: true,
        }
    }
}

impl RetryConfig {
    /// Create a config for fast retries (e.g., RPC calls)
    pub fn fast() -> Self {
        Self {
            max_retries: 5,
            base_delay_ms: 100,      // Start with 100ms
            max_delay_ms: 5_000,     // Cap at 5 seconds
            backoff_multiplier: 1.5, // Gentler increase
            use_jitter: true,
        }
    }

    /// Create a config for slow retries (e.g., rate-limited APIs)
    pub fn slow() -> Self {
        Self {
            max_retries: 3,
            base_delay_ms: 5000,     // Start with 5 seconds
            max_delay_ms: 60_000,    // Cap at 1 minute
            backoff_multiplier: 2.0, // Double each time
            use_jitter: true,
        }
    }
}

/// Error classification for retry logic
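///
/// For example, an HTTP client might map status codes to classes along these
/// lines (a sketch; the mapping is illustrative, not prescribed by this crate):
///
/// ```rust,ignore
/// use riglr_core::retry::ErrorClass;
///
/// fn classify_status(status: u16) -> ErrorClass {
///     match status {
///         429 => ErrorClass::RateLimited,     // Too Many Requests
///         500..=599 => ErrorClass::Retryable, // Server-side errors
///         _ => ErrorClass::Permanent,         // Client errors, etc.
///     }
/// }
/// ```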
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorClass {
    /// Error is permanent and should not be retried
    Permanent,
    /// Error is temporary and should be retried
    Retryable,
    /// Error indicates rate limiting, use longer backoff
    RateLimited,
}

/// Create a backoff instance from our config
fn create_backoff(config: &RetryConfig) -> ExponentialBackoff {
    let mut builder = ExponentialBackoffBuilder::new();
    builder.with_initial_interval(Duration::from_millis(config.base_delay_ms));
    builder.with_max_interval(Duration::from_millis(config.max_delay_ms));
    builder.with_multiplier(config.backoff_multiplier);
    builder.with_max_elapsed_time(None); // We handle max retries manually

    if !config.use_jitter {
        builder.with_randomization_factor(0.0);
    } else {
        // Use ±25% jitter (the crate's default randomization factor is 0.5, i.e. ±50%)
        builder.with_randomization_factor(0.25);
    }

    builder.build()
}

/// Execute an async operation with retry logic
///
/// This function provides a generic retry mechanism for any async operation.
/// It automatically applies exponential backoff with optional jitter.
///
/// # Arguments
///
/// * `operation` - Async closure that performs the operation
/// * `classifier` - Function to classify errors as permanent or retryable
/// * `config` - Retry configuration
/// * `operation_name` - Human-readable name for logging
///
/// # Type Parameters
///
/// * `F` - The async operation closure type
/// * `Fut` - The future type returned by the operation
/// * `T` - The success type
/// * `E` - The error type
/// * `C` - The error classifier function type
///
/// # Returns
///
/// Returns the successful result or the last error after all retries are exhausted
///
/// # Examples
///
/// ```rust,ignore
/// use riglr_core::retry::{retry_async, RetryConfig, ErrorClass};
///
/// async fn example() -> Result<String, MyError> {
///     retry_async(
///         || async {
///             // Your async operation here
///             fetch_data().await
///         },
///         |error| {
///             // Classify error
///             match error {
///                 MyError::NetworkTimeout => ErrorClass::Retryable,
///                 MyError::InvalidInput => ErrorClass::Permanent,
///                 MyError::RateLimited => ErrorClass::RateLimited,
///             }
///         },
///         &RetryConfig::default(),
///         "fetch_data"
///     ).await
/// }
/// ```
pub async fn retry_async<F, Fut, T, E, C>(
    mut operation: F,
    classifier: C,
    config: &RetryConfig,
    operation_name: &str,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: std::fmt::Display,
    C: Fn(&E) -> ErrorClass,
{
    debug!(
        "Starting operation '{}' with retry config: max_retries={}, base_delay={}ms",
        operation_name, config.max_retries, config.base_delay_ms
    );

    let mut backoff = create_backoff(config);
    let mut attempts = 0u32;

    loop {
        attempts += 1;
        debug!("Attempt {} for '{}'", attempts, operation_name);

        match operation().await {
            Ok(result) => {
                if attempts > 1 {
                    debug!(
                        "Operation '{}' succeeded after {} attempts",
                        operation_name, attempts
                    );
                }
                return Ok(result);
            }
            Err(error) => {
                let error_class = classifier(&error);

                warn!(
                    "Operation '{}' failed (attempt {}): {} (class: {:?})",
                    operation_name, attempts, error, error_class
                );

                // Check if we should retry
                match error_class {
                    ErrorClass::Permanent => {
                        debug!("Error is permanent, not retrying");
                        return Err(error);
                    }
                    ErrorClass::Retryable | ErrorClass::RateLimited => {
                        if attempts > config.max_retries {
                            warn!(
                                "Operation '{}' failed after {} attempts",
                                operation_name, attempts
                            );
                            return Err(error);
                        }

                        // Get next backoff duration
                        let delay = if let Some(duration) = backoff.next_backoff() {
                            // For rate-limited errors, double the delay
                            if error_class == ErrorClass::RateLimited {
                                duration * 2
                            } else {
                                duration
                            }
                        } else {
                            // Backoff exhausted (shouldn't happen with our config)
                            warn!("Backoff exhausted for '{}'", operation_name);
                            return Err(error);
                        };

                        debug!("Retrying '{}' after {:?}", operation_name, delay);
                        tokio::time::sleep(delay).await;
                    }
                }
            }
        }
    }
}

/// Simplified retry for operations with `String` errors; every error is treated as retryable
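///
/// A minimal sketch, assuming a hypothetical `ping_endpoint` operation whose
/// error converts to `String`:
///
/// ```rust,ignore
/// use riglr_core::retry::{retry_with_backoff, RetryConfig};
///
/// async fn ping() -> Result<(), String> {
///     retry_with_backoff(
///         || async { ping_endpoint().await.map_err(|e| e.to_string()) },
///         &RetryConfig::default(),
///         "ping_endpoint",
///     )
///     .await
/// }
/// ```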
pub async fn retry_with_backoff<F, Fut, T>(
    operation: F,
    config: &RetryConfig,
    operation_name: &str,
) -> Result<T, String>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, String>>,
{
    retry_async(
        operation,
        |_| ErrorClass::Retryable, // Treat all errors as retryable by default
        config,
        operation_name,
    )
    .await
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_retry_succeeds_first_attempt() {
        let config = RetryConfig::fast();
        let result = retry_async(
            || async { Ok::<_, String>("success") },
            |_| ErrorClass::Retryable,
            &config,
            "test_op",
        )
        .await;

        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let config = RetryConfig::fast();
        let result = retry_async(
            || {
                let attempts = attempts_clone.clone();
                async move {
                    let count = attempts.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err("temporary failure".to_string())
                    } else {
                        Ok("success")
                    }
                }
            },
            |_| ErrorClass::Retryable,
            &config,
            "test_op",
        )
        .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_permanent_error_no_retry() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let config = RetryConfig::fast();
        let result = retry_async(
            || {
                let attempts = attempts_clone.clone();
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>("permanent error".to_string())
                }
            },
            |_| ErrorClass::Permanent,
            &config,
            "test_op",
        )
        .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1); // Only one attempt
    }

    #[tokio::test]
    async fn test_retry_exhausts_all_attempts() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let config = RetryConfig {
            max_retries: 2,
            base_delay_ms: 10,
            max_delay_ms: 100,
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        let result = retry_async(
            || {
                let attempts = attempts_clone.clone();
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>("always fails".to_string())
                }
            },
            |_| ErrorClass::Retryable,
            &config,
            "test_op",
        )
        .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3); // Initial + 2 retries
    }

    #[test]
    fn test_create_backoff_config() {
        // Test that create_backoff produces a correct ExponentialBackoff from our config
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        let backoff = create_backoff(&config);
        assert_eq!(backoff.initial_interval, Duration::from_millis(100));
        assert_eq!(backoff.max_interval, Duration::from_millis(10_000));
        assert_eq!(backoff.multiplier, 2.0);
        assert_eq!(backoff.randomization_factor, 0.0); // No jitter
    }

    #[test]
    fn test_create_backoff_with_jitter() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
            use_jitter: true,
        };

        let backoff = create_backoff(&config);
        assert_eq!(backoff.randomization_factor, 0.25); // 25% jitter
    }

    #[test]
    fn test_retry_config_presets() {
        let fast = RetryConfig::fast();
        assert_eq!(fast.base_delay_ms, 100);
        assert_eq!(fast.max_retries, 5);

        let slow = RetryConfig::slow();
        assert_eq!(slow.base_delay_ms, 5000);
        assert_eq!(slow.max_retries, 3);
    }
}