//! Retry support for transient CLI failures (`claude_wrapper/retry.rs`).
use std::time::Duration;

use tracing::warn;

use crate::error::Error;

/// Retry policy for transient CLI failures.
///
/// Configure max attempts, backoff strategy, and which errors to retry.
///
/// # Example
///
/// ```
/// use claude_wrapper::RetryPolicy;
/// use std::time::Duration;
///
/// let policy = RetryPolicy::new()
///     .max_attempts(3)
///     .initial_backoff(Duration::from_secs(1))
///     .exponential()
///     .retry_on_timeout(true)
///     .retry_on_exit_codes([1, 2]);
/// ```
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    // Total number of attempts, including the initial one (1 = no retries).
    pub(crate) max_attempts: u32,
    // Delay before the first retry; base value for exponential growth.
    pub(crate) initial_backoff: Duration,
    // Upper bound applied to every computed delay.
    pub(crate) max_backoff: Duration,
    // Fixed delay vs. doubling delay between attempts.
    pub(crate) backoff_strategy: BackoffStrategy,
    // Whether `Error::Timeout` is considered transient.
    pub(crate) retry_on_timeout: bool,
    // Non-zero exit codes considered transient for `Error::CommandFailed`.
    pub(crate) retry_exit_codes: Vec<i32>,
}

/// Backoff strategy between retry attempts.
#[derive(Debug, Clone, Copy)]
pub enum BackoffStrategy {
    /// Fixed delay between attempts (always `initial_backoff`).
    Fixed,
    /// Exponential backoff (delay doubles each attempt, capped by `max_backoff`).
    Exponential,
}

43impl Default for RetryPolicy {
44    fn default() -> Self {
45        Self {
46            max_attempts: 3,
47            initial_backoff: Duration::from_secs(1),
48            max_backoff: Duration::from_secs(30),
49            backoff_strategy: BackoffStrategy::Fixed,
50            retry_on_timeout: true,
51            retry_exit_codes: Vec::new(),
52        }
53    }
54}
55
56impl RetryPolicy {
57    /// Create a new retry policy with default settings (3 attempts, 1s fixed backoff).
58    #[must_use]
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    /// Set the maximum number of attempts (including the initial attempt).
64    ///
65    /// A value of 1 means no retries.
66    #[must_use]
67    pub fn max_attempts(mut self, n: u32) -> Self {
68        self.max_attempts = n;
69        self
70    }
71
72    /// Set the initial delay before the first retry.
73    #[must_use]
74    pub fn initial_backoff(mut self, duration: Duration) -> Self {
75        self.initial_backoff = duration;
76        self
77    }
78
79    /// Set the maximum delay between retries (caps exponential growth).
80    #[must_use]
81    pub fn max_backoff(mut self, duration: Duration) -> Self {
82        self.max_backoff = duration;
83        self
84    }
85
86    /// Use fixed backoff (same delay between each attempt).
87    #[must_use]
88    pub fn fixed(mut self) -> Self {
89        self.backoff_strategy = BackoffStrategy::Fixed;
90        self
91    }
92
93    /// Use exponential backoff (delay doubles each attempt, capped by max_backoff).
94    #[must_use]
95    pub fn exponential(mut self) -> Self {
96        self.backoff_strategy = BackoffStrategy::Exponential;
97        self
98    }
99
100    /// Retry on timeout errors.
101    #[must_use]
102    pub fn retry_on_timeout(mut self, retry: bool) -> Self {
103        self.retry_on_timeout = retry;
104        self
105    }
106
107    /// Retry on specific non-zero exit codes.
108    #[must_use]
109    pub fn retry_on_exit_codes(mut self, codes: impl IntoIterator<Item = i32>) -> Self {
110        self.retry_exit_codes = codes.into_iter().collect();
111        self
112    }
113
114    /// Calculate the delay for a given attempt (0-indexed).
115    pub(crate) fn delay_for_attempt(&self, attempt: u32) -> Duration {
116        let delay = match self.backoff_strategy {
117            BackoffStrategy::Fixed => self.initial_backoff,
118            BackoffStrategy::Exponential => self
119                .initial_backoff
120                .saturating_mul(2u32.saturating_pow(attempt)),
121        };
122        delay.min(self.max_backoff)
123    }
124
125    /// Check if the given error should be retried.
126    pub(crate) fn should_retry(&self, error: &Error) -> bool {
127        match error {
128            Error::Timeout { .. } => self.retry_on_timeout,
129            Error::CommandFailed { exit_code, .. } => self.retry_exit_codes.contains(exit_code),
130            _ => false,
131        }
132    }
133}
134
/// Execute a fallible async operation with retry.
///
/// Runs `operation` up to `policy.max_attempts` times (a configured value
/// of 0 is treated as 1, so at least one attempt is always made). Between
/// retryable failures it sleeps for `policy.delay_for_attempt(..)` via
/// `tokio::time::sleep`. The first success is returned; a non-retryable
/// error, or the error from the final attempt, is propagated unchanged.
#[cfg(feature = "async")]
pub(crate) async fn with_retry<F, Fut, T>(
    policy: &RetryPolicy,
    mut operation: F,
) -> crate::error::Result<T>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = crate::error::Result<T>>,
{
    // Clamp to at least one attempt: with max_attempts == 0 the loop would
    // never run and there would be no result or error to return.
    let max_attempts = policy.max_attempts.max(1);

    for attempt in 0..max_attempts {
        match operation().await {
            Ok(result) => return Ok(result),
            // Retry only when attempts remain and the policy deems the
            // error transient.
            Err(e) if attempt + 1 < max_attempts && policy.should_retry(&e) => {
                let delay = policy.delay_for_attempt(attempt);
                warn!(
                    attempt = attempt + 1,
                    max_attempts,
                    delay_ms = delay.as_millis() as u64,
                    error = %e,
                    "retrying after transient error"
                );
                tokio::time::sleep(delay).await;
            }
            // Non-retryable error, or the final attempt failed.
            Err(e) => return Err(e),
        }
    }

    // Every iteration either returns Ok or returns Err on its last chance;
    // with max_attempts >= 1 the loop body runs at least once.
    unreachable!("retry loop always returns from its final attempt")
}

/// Execute a fallible blocking operation with retry. Sync mirror of
/// [`with_retry`]; waits between attempts with [`std::thread::sleep`].
///
/// Runs `operation` up to `policy.max_attempts` times (a configured value
/// of 0 is treated as 1, so at least one attempt is always made). The
/// first success is returned; a non-retryable error, or the error from
/// the final attempt, is propagated unchanged.
#[cfg(feature = "sync")]
pub(crate) fn with_retry_sync<F, T>(
    policy: &RetryPolicy,
    mut operation: F,
) -> crate::error::Result<T>
where
    F: FnMut() -> crate::error::Result<T>,
{
    // Clamp to at least one attempt: with max_attempts == 0 the loop would
    // never run and there would be no result or error to return.
    let max_attempts = policy.max_attempts.max(1);

    for attempt in 0..max_attempts {
        match operation() {
            Ok(result) => return Ok(result),
            // Retry only when attempts remain and the policy deems the
            // error transient.
            Err(e) if attempt + 1 < max_attempts && policy.should_retry(&e) => {
                let delay = policy.delay_for_attempt(attempt);
                warn!(
                    attempt = attempt + 1,
                    max_attempts,
                    delay_ms = delay.as_millis() as u64,
                    error = %e,
                    "retrying after transient error"
                );
                std::thread::sleep(delay);
            }
            // Non-retryable error, or the final attempt failed.
            Err(e) => return Err(e),
        }
    }

    // Every iteration either returns Ok or returns Err on its last chance;
    // with max_attempts >= 1 the loop body runs at least once.
    unreachable!("retry loop always returns from its final attempt")
}

#[cfg(test)]
mod tests {
    use super::*;

    // Defaults must stay in sync with `impl Default for RetryPolicy`.
    #[test]
    fn test_default_policy() {
        let policy = RetryPolicy::new();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_backoff, Duration::from_secs(1));
        assert!(policy.retry_on_timeout);
        assert!(policy.retry_exit_codes.is_empty());
    }

    // Every builder method should overwrite its corresponding field.
    #[test]
    fn test_builder() {
        let policy = RetryPolicy::new()
            .max_attempts(5)
            .initial_backoff(Duration::from_millis(500))
            .exponential()
            .retry_on_timeout(false)
            .retry_on_exit_codes([1, 2, 3]);

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_backoff, Duration::from_millis(500));
        assert!(!policy.retry_on_timeout);
        assert_eq!(policy.retry_exit_codes, vec![1, 2, 3]);
    }

    // Fixed strategy: delay is constant regardless of attempt index.
    #[test]
    fn test_fixed_delay() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(2))
            .fixed();

        assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(5), Duration::from_secs(2));
    }

    // Exponential strategy: initial_backoff * 2^attempt, capped at max_backoff.
    #[test]
    fn test_exponential_delay() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(1))
            .max_backoff(Duration::from_secs(30))
            .exponential();

        assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(8));
        // Capped at max_backoff
        assert_eq!(policy.delay_for_attempt(10), Duration::from_secs(30));
    }

    // Timeout retryability follows the retry_on_timeout flag.
    #[test]
    fn test_should_retry_timeout() {
        let policy = RetryPolicy::new().retry_on_timeout(true);
        let error = Error::Timeout {
            timeout_seconds: 60,
        };
        assert!(policy.should_retry(&error));

        let policy = RetryPolicy::new().retry_on_timeout(false);
        assert!(!policy.should_retry(&error));
    }

    // Only exit codes listed in retry_exit_codes are retryable.
    #[test]
    fn test_should_retry_exit_code() {
        let policy = RetryPolicy::new().retry_on_exit_codes([1, 2]);

        let retryable = Error::CommandFailed {
            command: "test".into(),
            exit_code: 1,
            stdout: String::new(),
            stderr: String::new(),
            working_dir: None,
        };
        assert!(policy.should_retry(&retryable));

        let not_retryable = Error::CommandFailed {
            command: "test".into(),
            exit_code: 99,
            stdout: String::new(),
            stderr: String::new(),
            working_dir: None,
        };
        assert!(!policy.should_retry(&not_retryable));
    }

    // Errors other than Timeout/CommandFailed are never retried.
    #[test]
    fn test_should_not_retry_other_errors() {
        let policy = RetryPolicy::new()
            .retry_on_timeout(true)
            .retry_on_exit_codes([1]);

        let error = Error::NotFound;
        assert!(!policy.should_retry(&error));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_with_retry_succeeds_first_try() {
        let policy = RetryPolicy::new().max_attempts(3);
        let result = with_retry(&policy, || async { Ok::<_, Error>(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    // Fails twice with a retryable error, then succeeds on the third attempt.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_with_retry_succeeds_after_failures() {
        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let attempt = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let attempt_clone = attempt.clone();

        let result = with_retry(&policy, || {
            let attempt = attempt_clone.clone();
            async move {
                let n = attempt.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                if n < 2 {
                    Err(Error::Timeout {
                        timeout_seconds: 60,
                    })
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt.load(std::sync::atomic::Ordering::SeqCst), 3);
    }

    // When every attempt fails, the last error is returned to the caller.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_with_retry_exhausts_attempts() {
        let policy = RetryPolicy::new()
            .max_attempts(2)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let result: crate::error::Result<()> = with_retry(&policy, || async {
            Err(Error::Timeout {
                timeout_seconds: 60,
            })
        })
        .await;

        assert!(matches!(result, Err(Error::Timeout { .. })));
    }

    // A non-retryable error must short-circuit after the first attempt.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_with_retry_no_retry_on_non_retryable() {
        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(false);

        let attempt = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let attempt_clone = attempt.clone();

        let result: crate::error::Result<()> = with_retry(&policy, || {
            let attempt = attempt_clone.clone();
            async move {
                attempt.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                Err(Error::Timeout {
                    timeout_seconds: 60,
                })
            }
        })
        .await;

        assert!(result.is_err());
        // Should only attempt once since timeout is not retryable
        assert_eq!(attempt.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[cfg(feature = "sync")]
    #[test]
    fn test_with_retry_sync_succeeds_first_try() {
        let policy = RetryPolicy::new().max_attempts(3);
        let result = with_retry_sync(&policy, || Ok::<_, Error>(42));
        assert_eq!(result.unwrap(), 42);
    }

    // Sync mirror: fails twice, then succeeds on the third attempt.
    #[cfg(feature = "sync")]
    #[test]
    fn test_with_retry_sync_succeeds_after_failures() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let attempt = AtomicU32::new(0);
        let result = with_retry_sync(&policy, || {
            let n = attempt.fetch_add(1, Ordering::SeqCst);
            if n < 2 {
                Err(Error::Timeout {
                    timeout_seconds: 60,
                })
            } else {
                Ok(42)
            }
        });

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt.load(Ordering::SeqCst), 3);
    }

    // Sync mirror: last error is surfaced after attempts run out.
    #[cfg(feature = "sync")]
    #[test]
    fn test_with_retry_sync_exhausts_attempts() {
        let policy = RetryPolicy::new()
            .max_attempts(2)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let result: crate::error::Result<()> = with_retry_sync(&policy, || {
            Err(Error::Timeout {
                timeout_seconds: 60,
            })
        });

        assert!(matches!(result, Err(Error::Timeout { .. })));
    }

    // Sync mirror: non-retryable error short-circuits after one attempt.
    #[cfg(feature = "sync")]
    #[test]
    fn test_with_retry_sync_no_retry_on_non_retryable() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(false);

        let attempt = AtomicU32::new(0);
        let result: crate::error::Result<()> = with_retry_sync(&policy, || {
            attempt.fetch_add(1, Ordering::SeqCst);
            Err(Error::Timeout {
                timeout_seconds: 60,
            })
        });

        assert!(result.is_err());
        assert_eq!(attempt.load(Ordering::SeqCst), 1);
    }
}