//! claude_wrapper/retry.rs — retry policy and retry driver for transient CLI failures.
1use std::time::Duration;
2
3use tracing::warn;
4
5use crate::error::Error;
6
/// Retry policy for transient CLI failures.
///
/// Configure max attempts, backoff strategy, and which errors to retry.
///
/// # Example
///
/// ```
/// use claude_wrapper::RetryPolicy;
/// use std::time::Duration;
///
/// let policy = RetryPolicy::new()
///     .max_attempts(3)
///     .initial_backoff(Duration::from_secs(1))
///     .exponential()
///     .retry_on_timeout(true)
///     .retry_on_exit_codes([1, 2]);
/// ```
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    // Total attempts including the first try (1 = no retries).
    pub(crate) max_attempts: u32,
    // Delay before the first retry; also the base for exponential growth.
    pub(crate) initial_backoff: Duration,
    // Upper bound applied to every computed delay (caps exponential growth).
    pub(crate) max_backoff: Duration,
    // Fixed or exponential spacing between attempts.
    pub(crate) backoff_strategy: BackoffStrategy,
    // Whether `Error::Timeout` counts as retryable.
    pub(crate) retry_on_timeout: bool,
    // Non-zero exit codes of `Error::CommandFailed` that are worth retrying.
    pub(crate) retry_exit_codes: Vec<i32>,
}
33
/// Backoff strategy between retry attempts.
///
/// Every computed delay is additionally capped by `RetryPolicy::max_backoff`.
#[derive(Debug, Clone, Copy)]
pub enum BackoffStrategy {
    /// Fixed delay between attempts (always `initial_backoff`).
    Fixed,
    /// Exponential backoff (delay doubles each attempt).
    Exponential,
}
42
43impl Default for RetryPolicy {
44    fn default() -> Self {
45        Self {
46            max_attempts: 3,
47            initial_backoff: Duration::from_secs(1),
48            max_backoff: Duration::from_secs(30),
49            backoff_strategy: BackoffStrategy::Fixed,
50            retry_on_timeout: true,
51            retry_exit_codes: Vec::new(),
52        }
53    }
54}
55
56impl RetryPolicy {
57    /// Create a new retry policy with default settings (3 attempts, 1s fixed backoff).
58    #[must_use]
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    /// Set the maximum number of attempts (including the initial attempt).
64    ///
65    /// A value of 1 means no retries.
66    #[must_use]
67    pub fn max_attempts(mut self, n: u32) -> Self {
68        self.max_attempts = n;
69        self
70    }
71
72    /// Set the initial delay before the first retry.
73    #[must_use]
74    pub fn initial_backoff(mut self, duration: Duration) -> Self {
75        self.initial_backoff = duration;
76        self
77    }
78
79    /// Set the maximum delay between retries (caps exponential growth).
80    #[must_use]
81    pub fn max_backoff(mut self, duration: Duration) -> Self {
82        self.max_backoff = duration;
83        self
84    }
85
86    /// Use fixed backoff (same delay between each attempt).
87    #[must_use]
88    pub fn fixed(mut self) -> Self {
89        self.backoff_strategy = BackoffStrategy::Fixed;
90        self
91    }
92
93    /// Use exponential backoff (delay doubles each attempt, capped by max_backoff).
94    #[must_use]
95    pub fn exponential(mut self) -> Self {
96        self.backoff_strategy = BackoffStrategy::Exponential;
97        self
98    }
99
100    /// Retry on timeout errors.
101    #[must_use]
102    pub fn retry_on_timeout(mut self, retry: bool) -> Self {
103        self.retry_on_timeout = retry;
104        self
105    }
106
107    /// Retry on specific non-zero exit codes.
108    #[must_use]
109    pub fn retry_on_exit_codes(mut self, codes: impl IntoIterator<Item = i32>) -> Self {
110        self.retry_exit_codes = codes.into_iter().collect();
111        self
112    }
113
114    /// Calculate the delay for a given attempt (0-indexed).
115    pub(crate) fn delay_for_attempt(&self, attempt: u32) -> Duration {
116        let delay = match self.backoff_strategy {
117            BackoffStrategy::Fixed => self.initial_backoff,
118            BackoffStrategy::Exponential => self
119                .initial_backoff
120                .saturating_mul(2u32.saturating_pow(attempt)),
121        };
122        delay.min(self.max_backoff)
123    }
124
125    /// Check if the given error should be retried.
126    pub(crate) fn should_retry(&self, error: &Error) -> bool {
127        match error {
128            Error::Timeout { .. } => self.retry_on_timeout,
129            Error::CommandFailed { exit_code, .. } => self.retry_exit_codes.contains(exit_code),
130            _ => false,
131        }
132    }
133}
134
135/// Execute a fallible async operation with retry.
136pub(crate) async fn with_retry<F, Fut, T>(
137    policy: &RetryPolicy,
138    mut operation: F,
139) -> crate::error::Result<T>
140where
141    F: FnMut() -> Fut,
142    Fut: std::future::Future<Output = crate::error::Result<T>>,
143{
144    let mut last_error = None;
145
146    for attempt in 0..policy.max_attempts {
147        match operation().await {
148            Ok(result) => return Ok(result),
149            Err(e) => {
150                if attempt + 1 < policy.max_attempts && policy.should_retry(&e) {
151                    let delay = policy.delay_for_attempt(attempt);
152                    warn!(
153                        attempt = attempt + 1,
154                        max_attempts = policy.max_attempts,
155                        delay_ms = delay.as_millis() as u64,
156                        error = %e,
157                        "retrying after transient error"
158                    );
159                    tokio::time::sleep(delay).await;
160                    last_error = Some(e);
161                } else {
162                    return Err(e);
163                }
164            }
165        }
166    }
167
168    Err(last_error.expect("at least one attempt was made"))
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults: 3 attempts, 1s backoff, timeouts retryable, no exit codes.
    #[test]
    fn test_default_policy() {
        let policy = RetryPolicy::new();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_backoff, Duration::from_secs(1));
        assert!(policy.retry_on_timeout);
        assert!(policy.retry_exit_codes.is_empty());
    }

    // Every builder method stores its value.
    #[test]
    fn test_builder() {
        let policy = RetryPolicy::new()
            .max_attempts(5)
            .initial_backoff(Duration::from_millis(500))
            .exponential()
            .retry_on_timeout(false)
            .retry_on_exit_codes([1, 2, 3]);

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_backoff, Duration::from_millis(500));
        assert!(!policy.retry_on_timeout);
        assert_eq!(policy.retry_exit_codes, vec![1, 2, 3]);
    }

    // Fixed strategy yields the same delay regardless of attempt number.
    #[test]
    fn test_fixed_delay() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(2))
            .fixed();

        assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(5), Duration::from_secs(2));
    }

    // Exponential strategy doubles each attempt: 1s, 2s, 4s, 8s, ...
    // and is clamped to max_backoff once growth exceeds it.
    #[test]
    fn test_exponential_delay() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(1))
            .max_backoff(Duration::from_secs(30))
            .exponential();

        assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(8));
        // Capped at max_backoff
        assert_eq!(policy.delay_for_attempt(10), Duration::from_secs(30));
    }

    // Timeout retryability follows the retry_on_timeout flag exactly.
    #[test]
    fn test_should_retry_timeout() {
        let policy = RetryPolicy::new().retry_on_timeout(true);
        let error = Error::Timeout {
            timeout_seconds: 60,
        };
        assert!(policy.should_retry(&error));

        let policy = RetryPolicy::new().retry_on_timeout(false);
        assert!(!policy.should_retry(&error));
    }

    // Only exit codes in the configured list are retryable.
    #[test]
    fn test_should_retry_exit_code() {
        let policy = RetryPolicy::new().retry_on_exit_codes([1, 2]);

        let retryable = Error::CommandFailed {
            command: "test".into(),
            exit_code: 1,
            stdout: String::new(),
            stderr: String::new(),
            working_dir: None,
        };
        assert!(policy.should_retry(&retryable));

        let not_retryable = Error::CommandFailed {
            command: "test".into(),
            exit_code: 99,
            stdout: String::new(),
            stderr: String::new(),
            working_dir: None,
        };
        assert!(!policy.should_retry(&not_retryable));
    }

    // Error variants other than Timeout/CommandFailed are never retried,
    // even when both retry knobs are enabled.
    #[test]
    fn test_should_not_retry_other_errors() {
        let policy = RetryPolicy::new()
            .retry_on_timeout(true)
            .retry_on_exit_codes([1]);

        let error = Error::NotFound;
        assert!(!policy.should_retry(&error));
    }

    // Immediate success returns without any retry machinery engaging.
    #[tokio::test]
    async fn test_with_retry_succeeds_first_try() {
        let policy = RetryPolicy::new().max_attempts(3);
        let result = with_retry(&policy, || async { Ok::<_, Error>(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    // Two retryable failures followed by a success: with_retry keeps going
    // and the operation runs exactly max_attempts (3) times.
    #[tokio::test]
    async fn test_with_retry_succeeds_after_failures() {
        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let attempt = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let attempt_clone = attempt.clone();

        let result = with_retry(&policy, || {
            let attempt = attempt_clone.clone();
            async move {
                let n = attempt.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                if n < 2 {
                    Err(Error::Timeout {
                        timeout_seconds: 60,
                    })
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt.load(std::sync::atomic::Ordering::SeqCst), 3);
    }

    // Permanent failure: after max_attempts the last error is surfaced.
    #[tokio::test]
    async fn test_with_retry_exhausts_attempts() {
        let policy = RetryPolicy::new()
            .max_attempts(2)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let result: crate::error::Result<()> = with_retry(&policy, || async {
            Err(Error::Timeout {
                timeout_seconds: 60,
            })
        })
        .await;

        assert!(matches!(result, Err(Error::Timeout { .. })));
    }

    // Non-retryable errors short-circuit: the error returns after one call
    // even though attempts remain.
    #[tokio::test]
    async fn test_with_retry_no_retry_on_non_retryable() {
        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(false);

        let attempt = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let attempt_clone = attempt.clone();

        let result: crate::error::Result<()> = with_retry(&policy, || {
            let attempt = attempt_clone.clone();
            async move {
                attempt.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                Err(Error::Timeout {
                    timeout_seconds: 60,
                })
            }
        })
        .await;

        assert!(result.is_err());
        // Should only attempt once since timeout is not retryable
        assert_eq!(attempt.load(std::sync::atomic::Ordering::SeqCst), 1);
    }
}