1use std::time::Duration;
2
3use tracing::warn;
4
5use crate::error::Error;
6
/// Configuration describing when and how failed operations are retried.
///
/// Constructed via [`RetryPolicy::new`] / [`Default`] and customized with the
/// builder-style setters; consumed internally by the `with_retry` helpers.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of attempts (the initial try counts as attempt 1).
    pub(crate) max_attempts: u32,
    /// Delay before the first retry; also every retry's delay under `Fixed`.
    pub(crate) initial_backoff: Duration,
    /// Upper bound applied to every computed delay.
    pub(crate) max_backoff: Duration,
    /// How the delay grows between attempts; see [`BackoffStrategy`].
    pub(crate) backoff_strategy: BackoffStrategy,
    /// Whether `Error::Timeout` is treated as transient (retryable).
    pub(crate) retry_on_timeout: bool,
    /// Exit codes of `Error::CommandFailed` that are considered retryable.
    pub(crate) retry_exit_codes: Vec<i32>,
}
33
/// How the delay between retry attempts evolves.
#[derive(Debug, Clone, Copy)]
pub enum BackoffStrategy {
    /// Every retry waits exactly `initial_backoff`.
    Fixed,
    /// Delay doubles each attempt (`initial_backoff * 2^attempt`),
    /// clamped to `max_backoff`.
    Exponential,
}
42
43impl Default for RetryPolicy {
44 fn default() -> Self {
45 Self {
46 max_attempts: 3,
47 initial_backoff: Duration::from_secs(1),
48 max_backoff: Duration::from_secs(30),
49 backoff_strategy: BackoffStrategy::Fixed,
50 retry_on_timeout: true,
51 retry_exit_codes: Vec::new(),
52 }
53 }
54}
55
56impl RetryPolicy {
57 #[must_use]
59 pub fn new() -> Self {
60 Self::default()
61 }
62
63 #[must_use]
67 pub fn max_attempts(mut self, n: u32) -> Self {
68 self.max_attempts = n;
69 self
70 }
71
72 #[must_use]
74 pub fn initial_backoff(mut self, duration: Duration) -> Self {
75 self.initial_backoff = duration;
76 self
77 }
78
79 #[must_use]
81 pub fn max_backoff(mut self, duration: Duration) -> Self {
82 self.max_backoff = duration;
83 self
84 }
85
86 #[must_use]
88 pub fn fixed(mut self) -> Self {
89 self.backoff_strategy = BackoffStrategy::Fixed;
90 self
91 }
92
93 #[must_use]
95 pub fn exponential(mut self) -> Self {
96 self.backoff_strategy = BackoffStrategy::Exponential;
97 self
98 }
99
100 #[must_use]
102 pub fn retry_on_timeout(mut self, retry: bool) -> Self {
103 self.retry_on_timeout = retry;
104 self
105 }
106
107 #[must_use]
109 pub fn retry_on_exit_codes(mut self, codes: impl IntoIterator<Item = i32>) -> Self {
110 self.retry_exit_codes = codes.into_iter().collect();
111 self
112 }
113
114 pub(crate) fn delay_for_attempt(&self, attempt: u32) -> Duration {
116 let delay = match self.backoff_strategy {
117 BackoffStrategy::Fixed => self.initial_backoff,
118 BackoffStrategy::Exponential => self
119 .initial_backoff
120 .saturating_mul(2u32.saturating_pow(attempt)),
121 };
122 delay.min(self.max_backoff)
123 }
124
125 pub(crate) fn should_retry(&self, error: &Error) -> bool {
127 match error {
128 Error::Timeout { .. } => self.retry_on_timeout,
129 Error::CommandFailed { exit_code, .. } => self.retry_exit_codes.contains(exit_code),
130 _ => false,
131 }
132 }
133}
134
/// Runs `operation` up to `policy.max_attempts` times, sleeping between
/// retryable failures according to the policy's backoff strategy.
///
/// Returns the first `Ok` result, or the final error once attempts are
/// exhausted or a non-retryable error occurs.
#[cfg(feature = "async")]
pub(crate) async fn with_retry<F, Fut, T>(
    policy: &RetryPolicy,
    mut operation: F,
) -> crate::error::Result<T>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = crate::error::Result<T>>,
{
    // Guard against max_attempts == 0: previously the loop never ran and the
    // trailing `last_error.expect(...)` panicked. Treat 0 as a single attempt.
    let attempts = policy.max_attempts.max(1);

    for attempt in 0..attempts {
        match operation().await {
            Ok(result) => return Ok(result),
            Err(e) => {
                // Give up on the final attempt, or immediately for errors the
                // policy does not classify as transient.
                if attempt + 1 >= attempts || !policy.should_retry(&e) {
                    return Err(e);
                }
                let delay = policy.delay_for_attempt(attempt);
                warn!(
                    attempt = attempt + 1,
                    max_attempts = policy.max_attempts,
                    delay_ms = delay.as_millis() as u64,
                    error = %e,
                    "retrying after transient error"
                );
                tokio::time::sleep(delay).await;
            }
        }
    }

    // With attempts >= 1 the loop body always runs and every iteration either
    // returns Ok or (on the last attempt) returns Err.
    unreachable!("retry loop returns on the final attempt")
}
171
/// Blocking counterpart of `with_retry`: runs `operation` up to
/// `policy.max_attempts` times, sleeping the current thread between
/// retryable failures.
///
/// Returns the first `Ok` result, or the final error once attempts are
/// exhausted or a non-retryable error occurs.
#[cfg(feature = "sync")]
pub(crate) fn with_retry_sync<F, T>(
    policy: &RetryPolicy,
    mut operation: F,
) -> crate::error::Result<T>
where
    F: FnMut() -> crate::error::Result<T>,
{
    // Guard against max_attempts == 0: previously the loop never ran and the
    // trailing `last_error.expect(...)` panicked. Treat 0 as a single attempt.
    let attempts = policy.max_attempts.max(1);

    for attempt in 0..attempts {
        match operation() {
            Ok(result) => return Ok(result),
            Err(e) => {
                // Give up on the final attempt, or immediately for errors the
                // policy does not classify as transient.
                if attempt + 1 >= attempts || !policy.should_retry(&e) {
                    return Err(e);
                }
                let delay = policy.delay_for_attempt(attempt);
                warn!(
                    attempt = attempt + 1,
                    max_attempts = policy.max_attempts,
                    delay_ms = delay.as_millis() as u64,
                    error = %e,
                    "retrying after transient error"
                );
                std::thread::sleep(delay);
            }
        }
    }

    // With attempts >= 1 the loop body always runs and every iteration either
    // returns Ok or (on the last attempt) returns Err.
    unreachable!("retry loop returns on the final attempt")
}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_policy() {
        let policy = RetryPolicy::new();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_backoff, Duration::from_secs(1));
        assert!(policy.retry_on_timeout);
        assert!(policy.retry_exit_codes.is_empty());
    }

    #[test]
    fn test_builder() {
        let policy = RetryPolicy::new()
            .max_attempts(5)
            .initial_backoff(Duration::from_millis(500))
            .exponential()
            .retry_on_timeout(false)
            .retry_on_exit_codes([1, 2, 3]);

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_backoff, Duration::from_millis(500));
        assert!(!policy.retry_on_timeout);
        assert_eq!(policy.retry_exit_codes, vec![1, 2, 3]);
    }

    #[test]
    fn test_fixed_delay() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(2))
            .fixed();

        // Fixed strategy: same delay regardless of the attempt number.
        assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(5), Duration::from_secs(2));
    }

    #[test]
    fn test_exponential_delay() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(1))
            .max_backoff(Duration::from_secs(30))
            .exponential();

        assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(8));
        // 2^10 s = 1024 s would exceed the cap; clamped to max_backoff.
        assert_eq!(policy.delay_for_attempt(10), Duration::from_secs(30));
    }

    #[test]
    fn test_should_retry_timeout() {
        let policy = RetryPolicy::new().retry_on_timeout(true);
        let error = Error::Timeout {
            timeout_seconds: 60,
        };
        assert!(policy.should_retry(&error));

        let policy = RetryPolicy::new().retry_on_timeout(false);
        assert!(!policy.should_retry(&error));
    }

    #[test]
    fn test_should_retry_exit_code() {
        let policy = RetryPolicy::new().retry_on_exit_codes([1, 2]);

        let retryable = Error::CommandFailed {
            command: "test".into(),
            exit_code: 1,
            stdout: String::new(),
            stderr: String::new(),
            working_dir: None,
        };
        assert!(policy.should_retry(&retryable));

        let not_retryable = Error::CommandFailed {
            command: "test".into(),
            exit_code: 99,
            stdout: String::new(),
            stderr: String::new(),
            working_dir: None,
        };
        assert!(!policy.should_retry(&not_retryable));
    }

    #[test]
    fn test_should_not_retry_other_errors() {
        let policy = RetryPolicy::new()
            .retry_on_timeout(true)
            .retry_on_exit_codes([1]);

        let error = Error::NotFound;
        assert!(!policy.should_retry(&error));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_with_retry_succeeds_first_try() {
        let policy = RetryPolicy::new().max_attempts(3);
        let result = with_retry(&policy, || async { Ok::<_, Error>(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_with_retry_succeeds_after_failures() {
        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let attempt = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let attempt_clone = attempt.clone();

        // Fails twice, then succeeds on the third (final) attempt.
        let result = with_retry(&policy, || {
            let attempt = attempt_clone.clone();
            async move {
                let n = attempt.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                if n < 2 {
                    Err(Error::Timeout {
                        timeout_seconds: 60,
                    })
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt.load(std::sync::atomic::Ordering::SeqCst), 3);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_with_retry_exhausts_attempts() {
        let policy = RetryPolicy::new()
            .max_attempts(2)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let result: crate::error::Result<()> = with_retry(&policy, || async {
            Err(Error::Timeout {
                timeout_seconds: 60,
            })
        })
        .await;

        assert!(matches!(result, Err(Error::Timeout { .. })));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_with_retry_no_retry_on_non_retryable() {
        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(false);

        let attempt = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let attempt_clone = attempt.clone();

        let result: crate::error::Result<()> = with_retry(&policy, || {
            let attempt = attempt_clone.clone();
            async move {
                attempt.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                Err(Error::Timeout {
                    timeout_seconds: 60,
                })
            }
        })
        .await;

        assert!(result.is_err());
        // Timeout is not retryable here, so exactly one attempt is made.
        assert_eq!(attempt.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[cfg(feature = "sync")]
    #[test]
    fn test_with_retry_sync_succeeds_first_try() {
        let policy = RetryPolicy::new().max_attempts(3);
        let result = with_retry_sync(&policy, || Ok::<_, Error>(42));
        assert_eq!(result.unwrap(), 42);
    }

    #[cfg(feature = "sync")]
    #[test]
    fn test_with_retry_sync_succeeds_after_failures() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let attempt = AtomicU32::new(0);
        let result = with_retry_sync(&policy, || {
            let n = attempt.fetch_add(1, Ordering::SeqCst);
            if n < 2 {
                Err(Error::Timeout {
                    timeout_seconds: 60,
                })
            } else {
                Ok(42)
            }
        });

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt.load(Ordering::SeqCst), 3);
    }

    #[cfg(feature = "sync")]
    #[test]
    fn test_with_retry_sync_exhausts_attempts() {
        let policy = RetryPolicy::new()
            .max_attempts(2)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let result: crate::error::Result<()> = with_retry_sync(&policy, || {
            Err(Error::Timeout {
                timeout_seconds: 60,
            })
        });

        assert!(matches!(result, Err(Error::Timeout { .. })));
    }

    #[cfg(feature = "sync")]
    #[test]
    fn test_with_retry_sync_no_retry_on_non_retryable() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(false);

        let attempt = AtomicU32::new(0);
        let result: crate::error::Result<()> = with_retry_sync(&policy, || {
            attempt.fetch_add(1, Ordering::SeqCst);
            Err(Error::Timeout {
                timeout_seconds: 60,
            })
        });

        assert!(result.is_err());
        assert_eq!(attempt.load(Ordering::SeqCst), 1);
    }
}