// claude_wrapper/retry.rs
use std::time::Duration;

use tracing::warn;

use crate::error::Error;

/// Retry policy for transient CLI failures.
///
/// Configure max attempts, backoff strategy, and which errors to retry.
///
/// # Example
///
/// ```
/// use claude_wrapper::RetryPolicy;
/// use std::time::Duration;
///
/// let policy = RetryPolicy::new()
///     .max_attempts(3)
///     .initial_backoff(Duration::from_secs(1))
///     .exponential()
///     .retry_on_timeout(true)
///     .retry_on_exit_codes([1, 2]);
/// ```
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    // Total attempts including the first; 1 means no retries.
    pub(crate) max_attempts: u32,
    // Delay before the first retry (and every retry under `BackoffStrategy::Fixed`).
    pub(crate) initial_backoff: Duration,
    // Upper bound on any single delay; caps exponential growth.
    pub(crate) max_backoff: Duration,
    // Fixed or exponential spacing between attempts.
    pub(crate) backoff_strategy: BackoffStrategy,
    // Whether `Error::Timeout` counts as transient.
    pub(crate) retry_on_timeout: bool,
    // Non-zero exit codes of `Error::CommandFailed` that count as transient.
    pub(crate) retry_exit_codes: Vec<i32>,
}

/// Backoff strategy between retry attempts.
#[derive(Debug, Clone, Copy)]
pub enum BackoffStrategy {
    /// Fixed delay between attempts (always `initial_backoff`).
    Fixed,
    /// Exponential backoff (delay doubles each attempt, capped by `max_backoff`).
    Exponential,
}

43impl Default for RetryPolicy {
44    fn default() -> Self {
45        Self {
46            max_attempts: 3,
47            initial_backoff: Duration::from_secs(1),
48            max_backoff: Duration::from_secs(30),
49            backoff_strategy: BackoffStrategy::Fixed,
50            retry_on_timeout: true,
51            retry_exit_codes: Vec::new(),
52        }
53    }
54}
55
56impl RetryPolicy {
57    /// Create a new retry policy with default settings (3 attempts, 1s fixed backoff).
58    #[must_use]
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    /// Set the maximum number of attempts (including the initial attempt).
64    ///
65    /// A value of 1 means no retries.
66    #[must_use]
67    pub fn max_attempts(mut self, n: u32) -> Self {
68        self.max_attempts = n;
69        self
70    }
71
72    /// Set the initial delay before the first retry.
73    #[must_use]
74    pub fn initial_backoff(mut self, duration: Duration) -> Self {
75        self.initial_backoff = duration;
76        self
77    }
78
79    /// Set the maximum delay between retries (caps exponential growth).
80    #[must_use]
81    pub fn max_backoff(mut self, duration: Duration) -> Self {
82        self.max_backoff = duration;
83        self
84    }
85
86    /// Use fixed backoff (same delay between each attempt).
87    #[must_use]
88    pub fn fixed(mut self) -> Self {
89        self.backoff_strategy = BackoffStrategy::Fixed;
90        self
91    }
92
93    /// Use exponential backoff (delay doubles each attempt, capped by max_backoff).
94    #[must_use]
95    pub fn exponential(mut self) -> Self {
96        self.backoff_strategy = BackoffStrategy::Exponential;
97        self
98    }
99
100    /// Retry on timeout errors.
101    #[must_use]
102    pub fn retry_on_timeout(mut self, retry: bool) -> Self {
103        self.retry_on_timeout = retry;
104        self
105    }
106
107    /// Retry on specific non-zero exit codes.
108    #[must_use]
109    pub fn retry_on_exit_codes(mut self, codes: impl IntoIterator<Item = i32>) -> Self {
110        self.retry_exit_codes = codes.into_iter().collect();
111        self
112    }
113
114    /// Calculate the delay for a given attempt (0-indexed).
115    pub(crate) fn delay_for_attempt(&self, attempt: u32) -> Duration {
116        let delay = match self.backoff_strategy {
117            BackoffStrategy::Fixed => self.initial_backoff,
118            BackoffStrategy::Exponential => self
119                .initial_backoff
120                .saturating_mul(2u32.saturating_pow(attempt)),
121        };
122        delay.min(self.max_backoff)
123    }
124
125    /// Check if the given error should be retried.
126    pub(crate) fn should_retry(&self, error: &Error) -> bool {
127        match error {
128            Error::Timeout { .. } => self.retry_on_timeout,
129            Error::CommandFailed { exit_code, .. } => self.retry_exit_codes.contains(exit_code),
130            _ => false,
131        }
132    }
133}
134
135/// Execute a fallible async operation with retry.
136pub(crate) async fn with_retry<F, Fut, T>(
137    policy: &RetryPolicy,
138    mut operation: F,
139) -> crate::error::Result<T>
140where
141    F: FnMut() -> Fut,
142    Fut: std::future::Future<Output = crate::error::Result<T>>,
143{
144    let mut last_error = None;
145
146    for attempt in 0..policy.max_attempts {
147        match operation().await {
148            Ok(result) => return Ok(result),
149            Err(e) => {
150                if attempt + 1 < policy.max_attempts && policy.should_retry(&e) {
151                    let delay = policy.delay_for_attempt(attempt);
152                    warn!(
153                        attempt = attempt + 1,
154                        max_attempts = policy.max_attempts,
155                        delay_ms = delay.as_millis() as u64,
156                        error = %e,
157                        "retrying after transient error"
158                    );
159                    tokio::time::sleep(delay).await;
160                    last_error = Some(e);
161                } else {
162                    return Err(e);
163                }
164            }
165        }
166    }
167
168    Err(last_error.expect("at least one attempt was made"))
169}
170
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    use super::*;

    /// Shorthand for the timeout error used across these tests.
    fn timeout_error() -> Error {
        Error::Timeout { timeout_seconds: 60 }
    }

    #[test]
    fn test_default_policy() {
        let p = RetryPolicy::new();
        assert_eq!(p.max_attempts, 3);
        assert_eq!(p.initial_backoff, Duration::from_secs(1));
        assert!(p.retry_on_timeout);
        assert!(p.retry_exit_codes.is_empty());
    }

    #[test]
    fn test_builder() {
        let p = RetryPolicy::new()
            .max_attempts(5)
            .initial_backoff(Duration::from_millis(500))
            .exponential()
            .retry_on_timeout(false)
            .retry_on_exit_codes([1, 2, 3]);

        assert_eq!(p.max_attempts, 5);
        assert_eq!(p.initial_backoff, Duration::from_millis(500));
        assert!(!p.retry_on_timeout);
        assert_eq!(p.retry_exit_codes, vec![1, 2, 3]);
    }

    #[test]
    fn test_fixed_delay() {
        let p = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(2))
            .fixed();

        // Fixed backoff ignores the attempt index entirely.
        for attempt in [0, 1, 5] {
            assert_eq!(p.delay_for_attempt(attempt), Duration::from_secs(2));
        }
    }

    #[test]
    fn test_exponential_delay() {
        let p = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(1))
            .max_backoff(Duration::from_secs(30))
            .exponential();

        // Doubles each attempt: 1s, 2s, 4s, 8s, ...
        for (attempt, secs) in [(0, 1), (1, 2), (2, 4), (3, 8)] {
            assert_eq!(p.delay_for_attempt(attempt), Duration::from_secs(secs));
        }
        // Capped at max_backoff.
        assert_eq!(p.delay_for_attempt(10), Duration::from_secs(30));
    }

    #[test]
    fn test_should_retry_timeout() {
        let error = timeout_error();
        assert!(RetryPolicy::new().retry_on_timeout(true).should_retry(&error));
        assert!(!RetryPolicy::new().retry_on_timeout(false).should_retry(&error));
    }

    #[test]
    fn test_should_retry_exit_code() {
        let p = RetryPolicy::new().retry_on_exit_codes([1, 2]);

        let command_error = |exit_code| Error::CommandFailed {
            command: "test".into(),
            exit_code,
            stdout: String::new(),
            stderr: String::new(),
        };

        assert!(p.should_retry(&command_error(1)));
        assert!(!p.should_retry(&command_error(99)));
    }

    #[test]
    fn test_should_not_retry_other_errors() {
        let p = RetryPolicy::new()
            .retry_on_timeout(true)
            .retry_on_exit_codes([1]);

        assert!(!p.should_retry(&Error::NotFound));
    }

    #[tokio::test]
    async fn test_with_retry_succeeds_first_try() {
        let p = RetryPolicy::new().max_attempts(3);
        let outcome = with_retry(&p, || async { Ok::<_, Error>(42) }).await;
        assert_eq!(outcome.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_with_retry_succeeds_after_failures() {
        let p = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_op = Arc::clone(&calls);

        let outcome = with_retry(&p, || {
            let calls = Arc::clone(&calls_in_op);
            async move {
                // Fail the first two attempts, then succeed.
                if calls.fetch_add(1, Ordering::SeqCst) < 2 {
                    Err(timeout_error())
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_exhausts_attempts() {
        let p = RetryPolicy::new()
            .max_attempts(2)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let outcome: crate::error::Result<()> =
            with_retry(&p, || async { Err(timeout_error()) }).await;

        assert!(matches!(outcome, Err(Error::Timeout { .. })));
    }

    #[tokio::test]
    async fn test_with_retry_no_retry_on_non_retryable() {
        let p = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(false);

        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_op = Arc::clone(&calls);

        let outcome: crate::error::Result<()> = with_retry(&p, || {
            let calls = Arc::clone(&calls_in_op);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err(timeout_error())
            }
        })
        .await;

        assert!(outcome.is_err());
        // Timeouts are not retryable here, so exactly one attempt.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}