1use std::thread;
2use std::time::Duration;
3
4use crate::error::CliError;
5
/// Retry configuration consumed by [`run_with_retry`].
///
/// The policy is a small `Copy` value, so it is passed around by value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RetryPolicy {
    /// How many extra attempts are allowed after the first attempt fails with a
    /// runtime error (exit code `1`).
    pub retries: u8,
    /// Pause between attempts, in milliseconds; `0` means retry immediately.
    pub retry_delay_ms: u64,
}
11
12pub fn run_with_retry<T, F>(policy: RetryPolicy, mut op: F) -> Result<(T, u8), CliError>
13where
14 F: FnMut() -> Result<T, CliError>,
15{
16 let mut attempt = 0u8;
17 loop {
18 attempt = attempt.saturating_add(1);
19 match op() {
20 Ok(value) => return Ok((value, attempt)),
21 Err(err) => {
22 if err.exit_code() != 1 || attempt > policy.retries {
23 if err.exit_code() == 1 {
24 return Err(err.with_hint(format!(
25 "operation failed after {attempt} attempt(s) with retries={} and retry_delay_ms={}",
26 policy.retries, policy.retry_delay_ms
27 )));
28 }
29 return Err(err);
30 }
31 if policy.retry_delay_ms > 0 {
32 thread::sleep(Duration::from_millis(policy.retry_delay_ms));
33 }
34 }
35 }
36 }
37}
38
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use pretty_assertions::assert_eq;

    use super::{RetryPolicy, run_with_retry};
    use crate::error::CliError;

    /// Transient runtime errors are retried until `op` succeeds, and the reported
    /// attempt count includes the successful call.
    #[test]
    fn retries_runtime_errors_until_success() {
        // A local counter (rather than a `static`) keeps the test free of shared global state.
        let calls = Cell::new(0u8);
        let policy = RetryPolicy {
            retries: 2,
            retry_delay_ms: 0,
        };

        let (value, attempts) = run_with_retry(policy, || {
            let n = calls.get();
            calls.set(n + 1);
            if n < 2 {
                Err(CliError::runtime("transient"))
            } else {
                Ok("ok")
            }
        })
        .expect("retry should eventually succeed");

        assert_eq!(value, "ok");
        assert_eq!(attempts, 3);
        assert_eq!(calls.get(), 3);
    }

    /// A runtime error that never clears is returned after `retries + 1` attempts.
    #[test]
    fn gives_up_after_retries_are_exhausted() {
        let calls = Cell::new(0u8);
        let policy = RetryPolicy {
            retries: 2,
            retry_delay_ms: 0,
        };

        run_with_retry::<(), _>(policy, || {
            calls.set(calls.get() + 1);
            Err(CliError::runtime("still failing"))
        })
        .expect_err("persistent runtime errors must eventually be returned");

        assert_eq!(calls.get(), 3);
    }

    /// Usage errors are returned immediately: `op` runs exactly once even though the
    /// policy would allow retries.
    #[test]
    fn does_not_retry_usage_errors() {
        let calls = Cell::new(0u8);
        let policy = RetryPolicy {
            retries: 3,
            retry_delay_ms: 0,
        };

        let err = run_with_retry::<(), _>(policy, || {
            calls.set(calls.get() + 1);
            Err(CliError::usage("bad args"))
        })
        .expect_err("usage errors must not be retried");

        assert_eq!(err.exit_code(), 2);
        // Checking the exit code alone would pass even if the error had been retried.
        assert_eq!(calls.get(), 1);
    }
}