Skip to main content

sediment/
retry.rs

1//! Retry utilities with exponential backoff
2//!
3//! Provides generic retry logic for transient failures in async operations.
4
5use std::fmt::Display;
6use std::future::Future;
7use std::time::Duration;
8
9use tokio::time::sleep;
10use tracing::{debug, warn};
11
/// Total number of tries (first attempt included) used by the default configuration.
const DEFAULT_MAX_ATTEMPTS: u32 = 3;

/// Wait before the first retry, in milliseconds, used by the default configuration.
const DEFAULT_INITIAL_DELAY_MS: u64 = 100;

/// Ceiling for any single wait between retries, in milliseconds, used by the default
/// configuration.
const DEFAULT_MAX_DELAY_MS: u64 = 2_000;
20
/// Configuration for retry behavior.
///
/// All fields are public, so a value can be built with a struct literal as well as
/// through [`RetryConfig::new`] or [`Default`]. Only `new` clamps `max_attempts`;
/// a struct literal stores exactly what it is given.
///
/// Equality compares all three fields, which makes configurations easy to assert
/// on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Maximum number of attempts (including the initial attempt).
    pub max_attempts: u32,
    /// Delay before the first retry, in milliseconds; later retries double it.
    pub initial_delay_ms: u64,
    /// Maximum delay between retries in milliseconds (caps exponential growth).
    pub max_delay_ms: u64,
}
31
32impl Default for RetryConfig {
33    fn default() -> Self {
34        Self {
35            max_attempts: DEFAULT_MAX_ATTEMPTS,
36            initial_delay_ms: DEFAULT_INITIAL_DELAY_MS,
37            max_delay_ms: DEFAULT_MAX_DELAY_MS,
38        }
39    }
40}
41
42impl RetryConfig {
43    /// Create a new retry configuration with custom values.
44    /// `max_attempts` is clamped to a minimum of 1.
45    pub fn new(max_attempts: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
46        Self {
47            max_attempts: max_attempts.max(1),
48            initial_delay_ms,
49            max_delay_ms,
50        }
51    }
52
53    /// Calculate the delay for a given attempt number (0-indexed)
54    fn delay_for_attempt(&self, attempt: u32) -> Duration {
55        let delay_ms = self
56            .initial_delay_ms
57            .saturating_mul(1u64.checked_shl(attempt).unwrap_or(u64::MAX));
58        let capped_delay_ms = delay_ms.min(self.max_delay_ms);
59        Duration::from_millis(capped_delay_ms)
60    }
61}
62
63/// Execute an async operation with exponential backoff retry.
64///
65/// The operation is retried up to `config.max_attempts` times on failure.
66/// The delay between retries grows exponentially, starting at `initial_delay_ms`
67/// and capped at `max_delay_ms`.
68///
69/// # Arguments
70///
71/// * `config` - Retry configuration
72/// * `operation` - A closure that returns a Future yielding Result<T, E>
73///
74/// # Returns
75///
76/// The result of the successful operation, or the last error if all attempts fail.
77///
78/// # Example
79///
80/// ```ignore
81/// use sediment::retry::{with_retry, RetryConfig};
82///
83/// let result = with_retry(&RetryConfig::default(), || async {
84///     // Your fallible async operation here
85///     Ok::<_, String>("success")
86/// }).await;
87/// ```
88pub async fn with_retry<T, E, F, Fut>(config: &RetryConfig, operation: F) -> Result<T, E>
89where
90    F: Fn() -> Fut,
91    Fut: Future<Output = Result<T, E>>,
92    E: Display,
93{
94    let mut last_error: Option<E> = None;
95
96    for attempt in 0..config.max_attempts {
97        match operation().await {
98            Ok(result) => {
99                if attempt > 0 {
100                    debug!("Operation succeeded on attempt {}", attempt + 1);
101                }
102                return Ok(result);
103            }
104            Err(e) => {
105                let is_last_attempt = attempt + 1 >= config.max_attempts;
106
107                if is_last_attempt {
108                    warn!(
109                        "Operation failed after {} attempts: {}",
110                        config.max_attempts, e
111                    );
112                    last_error = Some(e);
113                } else {
114                    let delay = config.delay_for_attempt(attempt);
115                    warn!(
116                        "Operation failed (attempt {}/{}): {}. Retrying in {:?}...",
117                        attempt + 1,
118                        config.max_attempts,
119                        e,
120                        delay
121                    );
122                    sleep(delay).await;
123                    last_error = Some(e);
124                }
125            }
126        }
127    }
128
129    // Return the last error
130    Err(last_error.expect("at least one attempt should have been made"))
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// A success on the very first call is passed straight through.
    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let outcome: Result<&str, &str> =
            with_retry(&RetryConfig::default(), || async { Ok("success") }).await;

        assert_eq!(outcome, Ok("success"));
    }

    /// Two transient failures followed by a success take exactly three calls.
    #[tokio::test]
    async fn test_retry_success_after_failures() {
        // Millisecond-scale delays keep the test quick.
        let config = RetryConfig::new(3, 10, 100);
        let calls = Arc::new(AtomicU32::new(0));
        let shared = Arc::clone(&calls);

        let outcome: Result<&str, &str> = with_retry(&config, || {
            let calls = Arc::clone(&shared);
            async move {
                // `fetch_add` yields the count *before* this call.
                match calls.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err("transient error"),
                    _ => Ok("success"),
                }
            }
        })
        .await;

        assert_eq!(outcome, Ok("success"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// When every call fails, the last error comes back after `max_attempts` calls.
    #[tokio::test]
    async fn test_retry_all_failures() {
        // Millisecond-scale delays keep the test quick.
        let config = RetryConfig::new(3, 10, 100);
        let calls = Arc::new(AtomicU32::new(0));
        let shared = Arc::clone(&calls);

        let outcome: Result<&str, &str> = with_retry(&config, || {
            let calls = Arc::clone(&shared);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err("persistent error")
            }
        })
        .await;

        assert_eq!(outcome, Err("persistent error"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// Regression guard (Bug #10): attempt numbers at or beyond the bit width of
    /// `u64` must neither panic nor exceed the configured cap.
    #[test]
    fn test_delay_for_attempt_no_overflow() {
        let config = RetryConfig::new(100, 100, 2000);
        let cap = Duration::from_millis(2000);

        for &attempt in &[64u32, 99] {
            assert_eq!(config.delay_for_attempt(attempt), cap);
        }
    }

    /// Delays double from the initial value until they reach the cap.
    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig::new(5, 100, 1000);
        // Expected delay in milliseconds per 0-indexed attempt; the last entry is capped.
        let expected_ms: [u64; 5] = [100, 200, 400, 800, 1000];

        for (attempt, &ms) in (0u32..).zip(expected_ms.iter()) {
            assert_eq!(config.delay_for_attempt(attempt), Duration::from_millis(ms));
        }
    }
}