1use std::time::Duration;
2
3use tracing::warn;
4
5use crate::error::Error;
6
/// Configuration describing when and how a failed operation is retried.
///
/// Values are normally produced with [`RetryPolicy::new`] followed by the
/// chained setter methods. Policies compare by value, so two policies with the
/// same settings are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Total number of times the operation may be run, counting the first try.
    pub(crate) max_attempts: u32,
    /// Wait before a retry: used as-is for [`BackoffStrategy::Fixed`] and as
    /// the starting value for [`BackoffStrategy::Exponential`].
    pub(crate) initial_backoff: Duration,
    /// Upper bound applied to every computed wait.
    pub(crate) max_backoff: Duration,
    /// How the wait changes from one attempt to the next.
    pub(crate) backoff_strategy: BackoffStrategy,
    /// Whether a timeout error is considered retryable.
    pub(crate) retry_on_timeout: bool,
    /// Exit codes of a failed command that are considered retryable.
    pub(crate) retry_exit_codes: Vec<i32>,
}

/// How the wait between attempts evolves as attempts accumulate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackoffStrategy {
    /// Wait the same amount of time before every retry.
    Fixed,
    /// Double the wait for each attempt already made.
    Exponential,
}
42
43impl Default for RetryPolicy {
44 fn default() -> Self {
45 Self {
46 max_attempts: 3,
47 initial_backoff: Duration::from_secs(1),
48 max_backoff: Duration::from_secs(30),
49 backoff_strategy: BackoffStrategy::Fixed,
50 retry_on_timeout: true,
51 retry_exit_codes: Vec::new(),
52 }
53 }
54}
55
56impl RetryPolicy {
57 #[must_use]
59 pub fn new() -> Self {
60 Self::default()
61 }
62
63 #[must_use]
67 pub fn max_attempts(mut self, n: u32) -> Self {
68 self.max_attempts = n;
69 self
70 }
71
72 #[must_use]
74 pub fn initial_backoff(mut self, duration: Duration) -> Self {
75 self.initial_backoff = duration;
76 self
77 }
78
79 #[must_use]
81 pub fn max_backoff(mut self, duration: Duration) -> Self {
82 self.max_backoff = duration;
83 self
84 }
85
86 #[must_use]
88 pub fn fixed(mut self) -> Self {
89 self.backoff_strategy = BackoffStrategy::Fixed;
90 self
91 }
92
93 #[must_use]
95 pub fn exponential(mut self) -> Self {
96 self.backoff_strategy = BackoffStrategy::Exponential;
97 self
98 }
99
100 #[must_use]
102 pub fn retry_on_timeout(mut self, retry: bool) -> Self {
103 self.retry_on_timeout = retry;
104 self
105 }
106
107 #[must_use]
109 pub fn retry_on_exit_codes(mut self, codes: impl IntoIterator<Item = i32>) -> Self {
110 self.retry_exit_codes = codes.into_iter().collect();
111 self
112 }
113
114 pub(crate) fn delay_for_attempt(&self, attempt: u32) -> Duration {
116 let delay = match self.backoff_strategy {
117 BackoffStrategy::Fixed => self.initial_backoff,
118 BackoffStrategy::Exponential => self
119 .initial_backoff
120 .saturating_mul(2u32.saturating_pow(attempt)),
121 };
122 delay.min(self.max_backoff)
123 }
124
125 pub(crate) fn should_retry(&self, error: &Error) -> bool {
127 match error {
128 Error::Timeout { .. } => self.retry_on_timeout,
129 Error::CommandFailed { exit_code, .. } => self.retry_exit_codes.contains(exit_code),
130 _ => false,
131 }
132 }
133}
134
135pub(crate) async fn with_retry<F, Fut, T>(
137 policy: &RetryPolicy,
138 mut operation: F,
139) -> crate::error::Result<T>
140where
141 F: FnMut() -> Fut,
142 Fut: std::future::Future<Output = crate::error::Result<T>>,
143{
144 let mut last_error = None;
145
146 for attempt in 0..policy.max_attempts {
147 match operation().await {
148 Ok(result) => return Ok(result),
149 Err(e) => {
150 if attempt + 1 < policy.max_attempts && policy.should_retry(&e) {
151 let delay = policy.delay_for_attempt(attempt);
152 warn!(
153 attempt = attempt + 1,
154 max_attempts = policy.max_attempts,
155 delay_ms = delay.as_millis() as u64,
156 error = %e,
157 "retrying after transient error"
158 );
159 tokio::time::sleep(delay).await;
160 last_error = Some(e);
161 } else {
162 return Err(e);
163 }
164 }
165 }
166 }
167
168 Err(last_error.expect("at least one attempt was made"))
169}
170
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    use super::*;

    /// Builds the timeout error used as the retryable failure in these tests.
    fn timeout_error() -> Error {
        Error::Timeout {
            timeout_seconds: 60,
        }
    }

    /// Builds a command failure that exited with `exit_code`.
    fn command_failed(exit_code: i32) -> Error {
        Error::CommandFailed {
            command: "test".into(),
            exit_code,
            stdout: String::new(),
            stderr: String::new(),
            working_dir: None,
        }
    }

    #[test]
    fn test_default_policy() {
        let policy = RetryPolicy::new();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_backoff, Duration::from_secs(1));
        assert_eq!(policy.max_backoff, Duration::from_secs(30));
        assert!(matches!(policy.backoff_strategy, BackoffStrategy::Fixed));
        assert!(policy.retry_on_timeout);
        assert!(policy.retry_exit_codes.is_empty());
    }

    #[test]
    fn test_builder() {
        let policy = RetryPolicy::new()
            .max_attempts(5)
            .initial_backoff(Duration::from_millis(500))
            .exponential()
            .retry_on_timeout(false)
            .retry_on_exit_codes([1, 2, 3]);

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_backoff, Duration::from_millis(500));
        assert!(matches!(
            policy.backoff_strategy,
            BackoffStrategy::Exponential
        ));
        assert!(!policy.retry_on_timeout);
        assert_eq!(policy.retry_exit_codes, vec![1, 2, 3]);
    }

    #[test]
    fn test_fixed_delay() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(2))
            .fixed();

        assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(5), Duration::from_secs(2));
    }

    #[test]
    fn test_exponential_delay() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(1))
            .max_backoff(Duration::from_secs(30))
            .exponential();

        assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(8));
        // 2^10 seconds is far above the ceiling, so the cap applies.
        assert_eq!(policy.delay_for_attempt(10), Duration::from_secs(30));
    }

    #[test]
    fn test_should_retry_timeout() {
        let policy = RetryPolicy::new().retry_on_timeout(true);
        let error = timeout_error();
        assert!(policy.should_retry(&error));

        let policy = RetryPolicy::new().retry_on_timeout(false);
        assert!(!policy.should_retry(&error));
    }

    #[test]
    fn test_should_retry_exit_code() {
        let policy = RetryPolicy::new().retry_on_exit_codes([1, 2]);

        let retryable = command_failed(1);
        assert!(policy.should_retry(&retryable));

        let not_retryable = command_failed(99);
        assert!(!policy.should_retry(&not_retryable));
    }

    #[test]
    fn test_should_not_retry_other_errors() {
        let policy = RetryPolicy::new()
            .retry_on_timeout(true)
            .retry_on_exit_codes([1]);

        let error = Error::NotFound;
        assert!(!policy.should_retry(&error));
    }

    #[tokio::test]
    async fn test_with_retry_succeeds_first_try() {
        let policy = RetryPolicy::new().max_attempts(3);
        let result = with_retry(&policy, || async { Ok::<_, Error>(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_with_retry_succeeds_after_failures() {
        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);

        let result = with_retry(&policy, || {
            let counter = Arc::clone(&counter);
            async move {
                // Fail the first two calls, succeed on the third.
                let n = counter.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(timeout_error())
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_exhausts_attempts() {
        let policy = RetryPolicy::new()
            .max_attempts(2)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let result: crate::error::Result<()> =
            with_retry(&policy, || async { Err(timeout_error()) }).await;

        assert!(matches!(result, Err(Error::Timeout { .. })));
    }

    #[tokio::test]
    async fn test_with_retry_no_retry_on_non_retryable() {
        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(false);

        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);

        let result: crate::error::Result<()> = with_retry(&policy, || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(timeout_error())
            }
        })
        .await;

        assert!(result.is_err());
        // Timeouts are not retryable under this policy, so only one call ran.
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }
}