// rc_core/retry.rs
//! Retry mechanism with exponential backoff and jitter
//!
//! Implements retry logic for transient failures like network errors and 503 responses.

use std::time::Duration;

use crate::alias::RetryConfig;
use crate::error::{Error, Result};
10/// Retry a fallible async operation with exponential backoff
11///
12/// # Arguments
13/// * `config` - Retry configuration
14/// * `operation` - Async closure that returns `Result<T>`
15/// * `is_retryable` - Closure that determines if an error should trigger retry
16///
17/// # Example
18/// ```ignore
19/// let result = retry_with_backoff(
20///     &config,
21///     || async { client.get_object(path).await },
22///     |e| e.is_retryable(),
23/// ).await;
24/// ```
25pub async fn retry_with_backoff<T, F, Fut, R>(
26    config: &RetryConfig,
27    mut operation: F,
28    is_retryable: R,
29) -> Result<T>
30where
31    F: FnMut() -> Fut,
32    Fut: std::future::Future<Output = Result<T>>,
33    R: Fn(&Error) -> bool,
34{
35    let mut attempt = 0;
36
37    loop {
38        attempt += 1;
39
40        match operation().await {
41            Ok(result) => return Ok(result),
42            Err(e) => {
43                if attempt >= config.max_attempts || !is_retryable(&e) {
44                    return Err(e);
45                }
46
47                let backoff = calculate_backoff(config, attempt);
48                tracing::debug!(
49                    attempt = attempt,
50                    backoff_ms = backoff.as_millis(),
51                    error = %e,
52                    "Retrying after transient error"
53                );
54
55                tokio::time::sleep(backoff).await;
56            }
57        }
58    }
59}
60
61/// Calculate backoff duration with jitter
62fn calculate_backoff(config: &RetryConfig, attempt: u32) -> Duration {
63    // Exponential backoff: initial * 2^(attempt-1)
64    let base_ms = config.initial_backoff_ms * (1u64 << (attempt - 1).min(10));
65    let capped_ms = base_ms.min(config.max_backoff_ms);
66
67    // Add jitter: random value between 0 and backoff
68    let jitter_ms = rand_jitter(capped_ms);
69    Duration::from_millis(capped_ms + jitter_ms)
70}
71
/// Generate pseudo-random jitter without external RNG dependency
///
/// Returns a value in `[0, max)` derived from the sub-second portion of the
/// system clock (always `0` when `max <= 1`). Not cryptographically random —
/// it only needs to de-synchronize concurrent retry loops.
fn rand_jitter(max: u64) -> u64 {
    use std::time::SystemTime;
    let since_epoch = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH);
    let entropy = u64::from(since_epoch.unwrap_or_default().subsec_nanos());
    entropy % max.max(1)
}
81
82/// Check if an error is retryable (transient)
83pub fn is_retryable_error(error: &Error) -> bool {
84    match error {
85        Error::Network(msg) => {
86            // Retryable network errors
87            let msg_lower = msg.to_lowercase();
88            msg_lower.contains("timeout")
89                || msg_lower.contains("connection reset")
90                || msg_lower.contains("connection refused")
91                || msg_lower.contains("503")
92                || msg_lower.contains("service unavailable")
93                || msg_lower.contains("too many requests")
94                || msg_lower.contains("429")
95                || msg_lower.contains("request rate")
96                || msg_lower.contains("slow down")
97        }
98        Error::Io(e) => {
99            // Retryable I/O errors
100            matches!(
101                e.kind(),
102                std::io::ErrorKind::ConnectionReset
103                    | std::io::ErrorKind::ConnectionRefused
104                    | std::io::ErrorKind::TimedOut
105                    | std::io::ErrorKind::Interrupted
106            )
107        }
108        // Non-retryable errors
109        Error::Auth(_)
110        | Error::NotFound(_)
111        | Error::AliasNotFound(_)
112        | Error::Conflict(_)
113        | Error::InvalidPath(_)
114        | Error::Config(_)
115        | Error::UnsupportedFeature(_) => false,
116        // General errors might be retryable
117        Error::General(msg) => {
118            let msg_lower = msg.to_lowercase();
119            msg_lower.contains("timeout") || msg_lower.contains("temporary")
120        }
121        _ => false,
122    }
123}
124
/// Retry configuration builder for easy customization
///
/// Collects retry settings fluently and produces a `RetryConfig` via
/// [`RetryBuilder::build`]. Defaults (see `new`): 3 attempts, 100 ms
/// initial backoff, 10 s maximum backoff.
#[derive(Debug, Clone)]
pub struct RetryBuilder {
    /// Total number of attempts, including the initial call.
    max_attempts: u32,
    /// Delay before the first retry, in milliseconds.
    initial_backoff_ms: u64,
    /// Upper bound on the exponential backoff, in milliseconds.
    max_backoff_ms: u64,
}
132
impl RetryBuilder {
    /// Create a builder with the default settings:
    /// 3 attempts, 100 ms initial backoff, 10 s maximum backoff.
    pub fn new() -> Self {
        Self {
            max_attempts: 3,
            initial_backoff_ms: 100,
            max_backoff_ms: 10000,
        }
    }

    /// Set the total number of attempts (including the initial call).
    pub fn max_attempts(mut self, n: u32) -> Self {
        self.max_attempts = n;
        self
    }

    /// Set the delay before the first retry, in milliseconds.
    pub fn initial_backoff_ms(mut self, ms: u64) -> Self {
        self.initial_backoff_ms = ms;
        self
    }

    /// Set the upper bound on backoff delay, in milliseconds.
    pub fn max_backoff_ms(mut self, ms: u64) -> Self {
        self.max_backoff_ms = ms;
        self
    }

    /// Consume the builder and produce the final `RetryConfig`.
    pub fn build(self) -> RetryConfig {
        RetryConfig {
            max_attempts: self.max_attempts,
            initial_backoff_ms: self.initial_backoff_ms,
            max_backoff_ms: self.max_backoff_ms,
        }
    }
}
165
impl Default for RetryBuilder {
    /// Equivalent to [`RetryBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Backoff doubles per attempt; each asserted range is
    /// [base, 2*base) because jitter adds a value in [0, base).
    #[test]
    fn test_calculate_backoff() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff_ms: 100,
            max_backoff_ms: 10000,
        };

        // First attempt should have base backoff
        let b1 = calculate_backoff(&config, 1);
        assert!(b1.as_millis() >= 100 && b1.as_millis() < 200);

        // Second attempt doubles
        let b2 = calculate_backoff(&config, 2);
        assert!(b2.as_millis() >= 200 && b2.as_millis() < 400);

        // Third attempt quadruples
        let b3 = calculate_backoff(&config, 3);
        assert!(b3.as_millis() >= 400 && b3.as_millis() < 800);
    }

    /// The base delay is capped at max_backoff_ms; jitter can add at
    /// most another max_backoff_ms on top, hence the 2x bound below.
    #[test]
    fn test_backoff_cap() {
        let config = RetryConfig {
            max_attempts: 10,
            initial_backoff_ms: 1000,
            max_backoff_ms: 5000,
        };

        // Even with many attempts, should not exceed max
        let b = calculate_backoff(&config, 10);
        assert!(b.as_millis() <= 10000); // max + jitter
    }

    /// Classification of transient vs. permanent errors.
    #[test]
    fn test_is_retryable_error() {
        // Network errors are retryable
        assert!(is_retryable_error(&Error::Network(
            "connection timeout".to_string()
        )));
        assert!(is_retryable_error(&Error::Network(
            "503 Service Unavailable".to_string()
        )));
        assert!(is_retryable_error(&Error::Network(
            "429 Too Many Requests".to_string()
        )));

        // Auth errors are not retryable
        assert!(!is_retryable_error(&Error::Auth(
            "access denied".to_string()
        )));

        // Not found is not retryable
        assert!(!is_retryable_error(&Error::NotFound(
            "object not found".to_string()
        )));
    }

    /// Builder values flow through to the produced RetryConfig unchanged.
    #[test]
    fn test_retry_builder() {
        let config = RetryBuilder::new()
            .max_attempts(5)
            .initial_backoff_ms(200)
            .max_backoff_ms(20000)
            .build();

        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.initial_backoff_ms, 200);
        assert_eq!(config.max_backoff_ms, 20000);
    }

    /// An immediate success should invoke the operation exactly once.
    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let config = RetryConfig::default();
        let mut calls = 0;

        let result = retry_with_backoff(
            &config,
            || {
                calls += 1;
                async { Ok::<_, Error>(42) }
            },
            |_| true,
        )
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(calls, 1);
    }

    /// Two transient failures followed by a success: all three calls run.
    #[tokio::test]
    async fn test_retry_success_after_failure() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff_ms: 1, // Fast for tests
            max_backoff_ms: 10,
        };
        // Atomic counter shared with the closure so each attempt is counted.
        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let result = retry_with_backoff(
            &config,
            || {
                let cc = call_count_clone.clone();
                async move {
                    let count = cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    if count < 2 {
                        Err(Error::Network("timeout".to_string()))
                    } else {
                        Ok(42)
                    }
                }
            },
            is_retryable_error,
        )
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 3);
    }

    /// A persistent failure stops after exactly max_attempts calls.
    #[tokio::test]
    async fn test_retry_exhausted() {
        let config = RetryConfig {
            max_attempts: 2,
            initial_backoff_ms: 1,
            max_backoff_ms: 10,
        };
        let mut calls = 0;

        let result: Result<()> = retry_with_backoff(
            &config,
            || {
                calls += 1;
                async { Err(Error::Network("always fails".to_string())) }
            },
            |_| true,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(calls, 2);
    }

    /// A non-retryable error short-circuits before any retry.
    #[tokio::test]
    async fn test_retry_non_retryable() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff_ms: 1,
            max_backoff_ms: 10,
        };
        let mut calls = 0;

        let result: Result<()> = retry_with_backoff(
            &config,
            || {
                calls += 1;
                async { Err(Error::NotFound("not found".to_string())) }
            },
            is_retryable_error,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(calls, 1); // Should not retry
    }
}