Skip to main content

sediment/
retry.rs

1//! Retry utilities with exponential backoff
2//!
3//! Provides generic retry logic for transient failures in async operations.
4
5use std::fmt::Display;
6use std::future::Future;
7use std::time::Duration;
8
9use tokio::time::sleep;
10use tracing::{debug, warn};
11
/// Number of attempts made by default, counting the very first try.
const DEFAULT_MAX_ATTEMPTS: u32 = 3;

/// Default delay (milliseconds) slept after the first failed attempt.
const DEFAULT_INITIAL_DELAY_MS: u64 = 100;

/// Default upper bound (milliseconds) on any single backoff delay.
const DEFAULT_MAX_DELAY_MS: u64 = 2_000;
20
/// Configuration for retry behavior with exponential backoff.
///
/// After each failed attempt (except the last) the delay doubles, starting at
/// `initial_delay_ms` and never exceeding `max_delay_ms`.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly
/// (e.g. in tests or when checking whether settings changed).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Maximum number of attempts (including the initial attempt)
    pub max_attempts: u32,
    /// Initial delay between retries in milliseconds
    pub initial_delay_ms: u64,
    /// Maximum delay between retries in milliseconds (caps exponential growth)
    pub max_delay_ms: u64,
}
31
32impl Default for RetryConfig {
33    fn default() -> Self {
34        Self {
35            max_attempts: DEFAULT_MAX_ATTEMPTS,
36            initial_delay_ms: DEFAULT_INITIAL_DELAY_MS,
37            max_delay_ms: DEFAULT_MAX_DELAY_MS,
38        }
39    }
40}
41
42impl RetryConfig {
43    /// Create a new retry configuration with custom values
44    pub fn new(max_attempts: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
45        Self {
46            max_attempts,
47            initial_delay_ms,
48            max_delay_ms,
49        }
50    }
51
52    /// Calculate the delay for a given attempt number (0-indexed)
53    fn delay_for_attempt(&self, attempt: u32) -> Duration {
54        let delay_ms = self.initial_delay_ms * 2u64.pow(attempt);
55        let capped_delay_ms = delay_ms.min(self.max_delay_ms);
56        Duration::from_millis(capped_delay_ms)
57    }
58}
59
60/// Execute an async operation with exponential backoff retry.
61///
62/// The operation is retried up to `config.max_attempts` times on failure.
63/// The delay between retries grows exponentially, starting at `initial_delay_ms`
64/// and capped at `max_delay_ms`.
65///
66/// # Arguments
67///
68/// * `config` - Retry configuration
69/// * `operation` - A closure that returns a Future yielding Result<T, E>
70///
71/// # Returns
72///
73/// The result of the successful operation, or the last error if all attempts fail.
74///
75/// # Example
76///
77/// ```ignore
78/// use sediment::retry::{with_retry, RetryConfig};
79///
80/// let result = with_retry(&RetryConfig::default(), || async {
81///     // Your fallible async operation here
82///     Ok::<_, String>("success")
83/// }).await;
84/// ```
85pub async fn with_retry<T, E, F, Fut>(config: &RetryConfig, operation: F) -> Result<T, E>
86where
87    F: Fn() -> Fut,
88    Fut: Future<Output = Result<T, E>>,
89    E: Display,
90{
91    let mut last_error: Option<E> = None;
92
93    for attempt in 0..config.max_attempts {
94        match operation().await {
95            Ok(result) => {
96                if attempt > 0 {
97                    debug!("Operation succeeded on attempt {}", attempt + 1);
98                }
99                return Ok(result);
100            }
101            Err(e) => {
102                let is_last_attempt = attempt + 1 >= config.max_attempts;
103
104                if is_last_attempt {
105                    warn!(
106                        "Operation failed after {} attempts: {}",
107                        config.max_attempts, e
108                    );
109                    last_error = Some(e);
110                } else {
111                    let delay = config.delay_for_attempt(attempt);
112                    warn!(
113                        "Operation failed (attempt {}/{}): {}. Retrying in {:?}...",
114                        attempt + 1,
115                        config.max_attempts,
116                        e,
117                        delay
118                    );
119                    sleep(delay).await;
120                    last_error = Some(e);
121                }
122            }
123        }
124    }
125
126    // Return the last error
127    Err(last_error.expect("at least one attempt should have been made"))
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Three attempts with tiny delays so the retrying tests finish quickly.
    fn fast_config() -> RetryConfig {
        RetryConfig::new(3, 10, 100)
    }

    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let outcome: Result<&str, &str> =
            with_retry(&RetryConfig::default(), || async { Ok("success") }).await;
        assert_eq!(outcome, Ok("success"));
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let calls = Arc::new(AtomicU32::new(0));

        let outcome: Result<&str, &str> = with_retry(&fast_config(), || {
            let calls = Arc::clone(&calls);
            async move {
                // The first two calls fail; the third one succeeds.
                match calls.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err("transient error"),
                    _ => Ok("success"),
                }
            }
        })
        .await;

        assert_eq!(outcome, Ok("success"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_all_failures() {
        let calls = Arc::new(AtomicU32::new(0));

        let outcome: Result<&str, &str> = with_retry(&fast_config(), || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err("persistent error")
            }
        })
        .await;

        assert_eq!(outcome, Err("persistent error"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig::new(5, 100, 1000);
        // Doubling from 100ms; the fifth value would be 1600ms but is capped.
        let expected_ms: [u64; 5] = [100, 200, 400, 800, 1000];
        for (attempt, ms) in (0u32..).zip(expected_ms.iter().copied()) {
            assert_eq!(
                config.delay_for_attempt(attempt),
                Duration::from_millis(ms),
                "attempt {}",
                attempt
            );
        }
    }
}