1use std::time::Duration;
2
3use tracing::warn;
4
5use crate::error::Error;
6
7#[derive(Debug, Clone)]
25pub struct RetryPolicy {
26 pub(crate) max_attempts: u32,
27 pub(crate) initial_backoff: Duration,
28 pub(crate) max_backoff: Duration,
29 pub(crate) backoff_strategy: BackoffStrategy,
30 pub(crate) retry_on_timeout: bool,
31 pub(crate) retry_exit_codes: Vec<i32>,
32}
33
/// Strategy used to compute the pause between two consecutive attempts.
///
/// The type is `Copy` and comparable, so a configured strategy can be checked
/// with `==` / `assert_eq!` instead of pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackoffStrategy {
    /// Wait the same amount of time before every retry.
    Fixed,
    /// Double the wait after every failed attempt.
    Exponential,
}
42
43impl Default for RetryPolicy {
44 fn default() -> Self {
45 Self {
46 max_attempts: 3,
47 initial_backoff: Duration::from_secs(1),
48 max_backoff: Duration::from_secs(30),
49 backoff_strategy: BackoffStrategy::Fixed,
50 retry_on_timeout: true,
51 retry_exit_codes: Vec::new(),
52 }
53 }
54}
55
56impl RetryPolicy {
57 #[must_use]
59 pub fn new() -> Self {
60 Self::default()
61 }
62
63 #[must_use]
67 pub fn max_attempts(mut self, n: u32) -> Self {
68 self.max_attempts = n;
69 self
70 }
71
72 #[must_use]
74 pub fn initial_backoff(mut self, duration: Duration) -> Self {
75 self.initial_backoff = duration;
76 self
77 }
78
79 #[must_use]
81 pub fn max_backoff(mut self, duration: Duration) -> Self {
82 self.max_backoff = duration;
83 self
84 }
85
86 #[must_use]
88 pub fn fixed(mut self) -> Self {
89 self.backoff_strategy = BackoffStrategy::Fixed;
90 self
91 }
92
93 #[must_use]
95 pub fn exponential(mut self) -> Self {
96 self.backoff_strategy = BackoffStrategy::Exponential;
97 self
98 }
99
100 #[must_use]
102 pub fn retry_on_timeout(mut self, retry: bool) -> Self {
103 self.retry_on_timeout = retry;
104 self
105 }
106
107 #[must_use]
109 pub fn retry_on_exit_codes(mut self, codes: impl IntoIterator<Item = i32>) -> Self {
110 self.retry_exit_codes = codes.into_iter().collect();
111 self
112 }
113
114 pub(crate) fn delay_for_attempt(&self, attempt: u32) -> Duration {
116 let delay = match self.backoff_strategy {
117 BackoffStrategy::Fixed => self.initial_backoff,
118 BackoffStrategy::Exponential => self
119 .initial_backoff
120 .saturating_mul(2u32.saturating_pow(attempt)),
121 };
122 delay.min(self.max_backoff)
123 }
124
125 pub(crate) fn should_retry(&self, error: &Error) -> bool {
127 match error {
128 Error::Timeout { .. } => self.retry_on_timeout,
129 Error::CommandFailed { exit_code, .. } => self.retry_exit_codes.contains(exit_code),
130 _ => false,
131 }
132 }
133}
134
135pub(crate) async fn with_retry<F, Fut, T>(
137 policy: &RetryPolicy,
138 mut operation: F,
139) -> crate::error::Result<T>
140where
141 F: FnMut() -> Fut,
142 Fut: std::future::Future<Output = crate::error::Result<T>>,
143{
144 let mut last_error = None;
145
146 for attempt in 0..policy.max_attempts {
147 match operation().await {
148 Ok(result) => return Ok(result),
149 Err(e) => {
150 if attempt + 1 < policy.max_attempts && policy.should_retry(&e) {
151 let delay = policy.delay_for_attempt(attempt);
152 warn!(
153 attempt = attempt + 1,
154 max_attempts = policy.max_attempts,
155 delay_ms = delay.as_millis() as u64,
156 error = %e,
157 "retrying after transient error"
158 );
159 tokio::time::sleep(delay).await;
160 last_error = Some(e);
161 } else {
162 return Err(e);
163 }
164 }
165 }
166 }
167
168 Err(last_error.expect("at least one attempt was made"))
169}
170
/// Unit tests for the retry policy builder, delay computation, error
/// classification and the `with_retry` driver.
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    use super::*;

    /// The default policy exposes the documented baseline values.
    #[test]
    fn test_default_policy() {
        let policy = RetryPolicy::new();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_backoff, Duration::from_secs(1));
        assert_eq!(policy.max_backoff, Duration::from_secs(30));
        assert!(matches!(policy.backoff_strategy, BackoffStrategy::Fixed));
        assert!(policy.retry_on_timeout);
        assert!(policy.retry_exit_codes.is_empty());
    }

    /// Every builder method stores the value it was given.
    #[test]
    fn test_builder() {
        let policy = RetryPolicy::new()
            .max_attempts(5)
            .initial_backoff(Duration::from_millis(500))
            .exponential()
            .retry_on_timeout(false)
            .retry_on_exit_codes([1, 2, 3]);

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_backoff, Duration::from_millis(500));
        assert!(matches!(
            policy.backoff_strategy,
            BackoffStrategy::Exponential
        ));
        assert!(!policy.retry_on_timeout);
        assert_eq!(policy.retry_exit_codes, vec![1, 2, 3]);
    }

    /// The fixed strategy ignores the attempt number.
    #[test]
    fn test_fixed_delay() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(2))
            .fixed();

        assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(5), Duration::from_secs(2));
    }

    /// The exponential strategy doubles per attempt and respects the cap.
    #[test]
    fn test_exponential_delay() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(1))
            .max_backoff(Duration::from_secs(30))
            .exponential();

        assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(1));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(8));
        // 2^10 seconds exceeds the 30 second ceiling, so the cap applies.
        assert_eq!(policy.delay_for_attempt(10), Duration::from_secs(30));
    }

    /// Timeouts are retried only when the flag is enabled.
    #[test]
    fn test_should_retry_timeout() {
        let policy = RetryPolicy::new().retry_on_timeout(true);
        let error = Error::Timeout {
            timeout_seconds: 60,
        };
        assert!(policy.should_retry(&error));

        let policy = RetryPolicy::new().retry_on_timeout(false);
        assert!(!policy.should_retry(&error));
    }

    /// Command failures are retried only for registered exit codes.
    #[test]
    fn test_should_retry_exit_code() {
        let policy = RetryPolicy::new().retry_on_exit_codes([1, 2]);

        let retryable = Error::CommandFailed {
            command: "test".into(),
            exit_code: 1,
            stdout: String::new(),
            stderr: String::new(),
        };
        assert!(policy.should_retry(&retryable));

        let not_retryable = Error::CommandFailed {
            command: "test".into(),
            exit_code: 99,
            stdout: String::new(),
            stderr: String::new(),
        };
        assert!(!policy.should_retry(&not_retryable));
    }

    /// Errors other than timeouts and command failures are never retried.
    #[test]
    fn test_should_not_retry_other_errors() {
        let policy = RetryPolicy::new()
            .retry_on_timeout(true)
            .retry_on_exit_codes([1]);

        let error = Error::NotFound;
        assert!(!policy.should_retry(&error));
    }

    /// A successful first attempt is returned immediately.
    #[tokio::test]
    async fn test_with_retry_succeeds_first_try() {
        let policy = RetryPolicy::new().max_attempts(3);
        let result = with_retry(&policy, || async { Ok::<_, Error>(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    /// Two retryable failures followed by a success use three attempts.
    #[tokio::test]
    async fn test_with_retry_succeeds_after_failures() {
        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let attempt = Arc::new(AtomicU32::new(0));
        let attempt_clone = Arc::clone(&attempt);

        let result = with_retry(&policy, || {
            let attempt = Arc::clone(&attempt_clone);
            async move {
                let n = attempt.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(Error::Timeout {
                        timeout_seconds: 60,
                    })
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt.load(Ordering::SeqCst), 3);
    }

    /// When every attempt fails, the last error is surfaced.
    #[tokio::test]
    async fn test_with_retry_exhausts_attempts() {
        let policy = RetryPolicy::new()
            .max_attempts(2)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(true);

        let result: crate::error::Result<()> = with_retry(&policy, || async {
            Err(Error::Timeout {
                timeout_seconds: 60,
            })
        })
        .await;

        assert!(matches!(result, Err(Error::Timeout { .. })));
    }

    /// A non-retryable error stops after the first attempt.
    #[tokio::test]
    async fn test_with_retry_no_retry_on_non_retryable() {
        let policy = RetryPolicy::new()
            .max_attempts(3)
            .initial_backoff(Duration::from_millis(1))
            .retry_on_timeout(false);

        let attempt = Arc::new(AtomicU32::new(0));
        let attempt_clone = Arc::clone(&attempt);

        let result: crate::error::Result<()> = with_retry(&policy, || {
            let attempt = Arc::clone(&attempt_clone);
            async move {
                attempt.fetch_add(1, Ordering::SeqCst);
                Err(Error::Timeout {
                    timeout_seconds: 60,
                })
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(attempt.load(Ordering::SeqCst), 1);
    }
}