1use std::thread;
10use std::time::Duration;
11use tracing::{debug, warn};
12
// Retry tuning constants. The unit of each delay is given by its suffix
// (`_MS` = milliseconds, `_SECS` = seconds); the presets on `RetryConfig`
// convert them to `Duration`.

// --- Default preset (`RetryConfig::default`) ---

/// Attempt budget for `RetryConfig::default` (also reused by `RetryConfig::for_connection`).
const DEFAULT_MAX_ATTEMPTS: u32 = 3;
/// Starting delay, in milliseconds, for `RetryConfig::default`.
const DEFAULT_INITIAL_DELAY_MS: u64 = 100;
/// Delay cap, in seconds, for `RetryConfig::default`.
const DEFAULT_MAX_DELAY_SECS: u64 = 5;

/// Factor the delay grows by after each retry; shared by every preset.
const BACKOFF_MULTIPLIER: f64 = 2.0;

// --- Network preset (`RetryConfig::for_network`) ---

/// Attempt budget for `RetryConfig::for_network`.
const NETWORK_MAX_ATTEMPTS: u32 = 4;
/// Starting delay, in milliseconds, for `RetryConfig::for_network`.
const NETWORK_INITIAL_DELAY_MS: u64 = 500;
/// Delay cap, in seconds, for `RetryConfig::for_network`.
const NETWORK_MAX_DELAY_SECS: u64 = 5;

// --- Connection preset (`RetryConfig::for_connection`) ---

/// Starting delay, in milliseconds, for `RetryConfig::for_connection`.
const CONNECTION_INITIAL_DELAY_MS: u64 = 50;
/// Delay cap, in seconds, for `RetryConfig::for_connection`.
const CONNECTION_MAX_DELAY_SECS: u64 = 2;
63
/// Exponential-backoff policy consumed by `retry_with_backoff`.
///
/// Ready-made presets are available via `RetryConfig::default`,
/// `RetryConfig::for_network` and `RetryConfig::for_connection`; all fields are
/// public so callers can also build a custom policy.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Total number of times the operation may be invoked, including the
    /// first call. Values below 1 behave like 1.
    pub max_attempts: u32,
    /// Starting delay used before the first retry.
    pub initial_delay: Duration,
    /// Upper bound applied to the delay as it grows between retries.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each retry.
    pub backoff_multiplier: f64,
}
76
77impl Default for RetryConfig {
78 fn default() -> Self {
79 Self {
80 max_attempts: DEFAULT_MAX_ATTEMPTS,
81 initial_delay: Duration::from_millis(DEFAULT_INITIAL_DELAY_MS),
82 max_delay: Duration::from_secs(DEFAULT_MAX_DELAY_SECS),
83 backoff_multiplier: BACKOFF_MULTIPLIER,
84 }
85 }
86}
87
88impl RetryConfig {
89 pub fn for_network() -> Self {
94 Self {
95 max_attempts: NETWORK_MAX_ATTEMPTS,
96 initial_delay: Duration::from_millis(NETWORK_INITIAL_DELAY_MS),
97 max_delay: Duration::from_secs(NETWORK_MAX_DELAY_SECS),
98 backoff_multiplier: BACKOFF_MULTIPLIER,
99 }
100 }
101
102 pub fn for_connection() -> Self {
107 Self {
108 max_attempts: DEFAULT_MAX_ATTEMPTS,
109 initial_delay: Duration::from_millis(CONNECTION_INITIAL_DELAY_MS),
110 max_delay: Duration::from_secs(CONNECTION_MAX_DELAY_SECS),
111 backoff_multiplier: BACKOFF_MULTIPLIER,
112 }
113 }
114}
115
116pub fn retry_with_backoff<T, E, F, R>(
132 config: RetryConfig,
133 operation_name: &str,
134 mut operation: F,
135 should_retry: R,
136) -> Result<T, E>
137where
138 F: FnMut() -> Result<T, E>,
139 R: Fn(&E) -> bool,
140 E: std::fmt::Display,
141{
142 let mut attempt = 0;
143 let mut delay = config.initial_delay;
144
145 loop {
146 attempt += 1;
147
148 match operation() {
149 Ok(result) => {
150 if attempt > 1 {
151 debug!(
152 operation = %operation_name,
153 attempts = attempt,
154 "operation succeeded after retry"
155 );
156 }
157 return Ok(result);
158 }
159 Err(e) => {
160 if attempt >= config.max_attempts {
161 warn!(
162 operation = %operation_name,
163 attempts = attempt,
164 error = %e,
165 "operation failed after max attempts"
166 );
167 return Err(e);
168 }
169
170 if !should_retry(&e) {
171 debug!(
172 operation = %operation_name,
173 attempt = attempt,
174 error = %e,
175 "operation failed with non-retryable error"
176 );
177 return Err(e);
178 }
179
180 warn!(
181 operation = %operation_name,
182 attempt = attempt,
183 max_attempts = config.max_attempts,
184 delay_ms = delay.as_millis(),
185 error = %e,
186 "operation failed, will retry"
187 );
188
189 thread::sleep(delay);
190
191 delay = Duration::from_secs_f64(
193 (delay.as_secs_f64() * config.backoff_multiplier)
194 .min(config.max_delay.as_secs_f64()),
195 );
196 }
197 }
198 }
199}
200
/// Heuristically decides whether an error message describes a transient
/// network condition worth retrying.
///
/// Matching is case-insensitive: the message is lowercased and checked for
/// any of a fixed set of substrings (connection failures, name resolution,
/// gateway/overload HTTP statuses, rate limiting, and EAGAIN-style errors).
/// Returns `false` when no marker is present, including for an empty message.
pub fn is_transient_network_error(error_msg: &str) -> bool {
    const TRANSIENT_MARKERS: &[&str] = &[
        // Connection-level failures.
        "connection refused",
        "connection reset",
        "connection timed out",
        "network is unreachable",
        "no route to host",
        "temporary failure",
        "try again",
        "resource temporarily unavailable",
        // Name resolution.
        "name resolution",
        "dns",
        "could not resolve",
        "no such host",
        // Gateway / overload HTTP statuses.
        "502 bad gateway",
        "503 service unavailable",
        "504 gateway timeout",
        "429 too many requests",
        // Rate limiting.
        "toomanyrequests",
        "rate limit",
        "quota exceeded",
        // Interrupted or would-block I/O.
        "broken pipe",
        "interrupted",
        "eagain",
        "ewouldblock",
    ];

    let haystack = error_msg.to_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .copied()
        .any(|marker| haystack.contains(marker))
}
258
/// Heuristically decides whether an error message describes a permanent
/// failure that retrying will not fix.
///
/// Matching is case-insensitive: the message is lowercased and checked for
/// any of a fixed set of substrings covering authentication/authorization
/// failures, missing resources, and invalid references. Returns `false` when
/// no marker is present, including for an empty message.
pub fn is_permanent_error(error_msg: &str) -> bool {
    const PERMANENT_MARKERS: &[&str] = &[
        // Authentication / authorization.
        "401 unauthorized",
        "403 forbidden",
        "authentication required",
        "access denied",
        // Missing resources.
        "404 not found",
        "manifest unknown",
        "name unknown",
        "repository does not exist",
        // Invalid input.
        "invalid reference",
        "invalid image",
        "malformed",
    ];

    let haystack = error_msg.to_lowercase();
    PERMANENT_MARKERS
        .iter()
        .copied()
        .any(|marker| haystack.contains(marker))
}
291
/// Returns `true` when an I/O error's kind indicates a transient condition
/// (dropped/refused connections, timeouts, interruptions, would-block).
pub fn is_transient_io_error(error: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    // Every `ErrorKind` considered worth retrying.
    const TRANSIENT_KINDS: [ErrorKind; 8] = [
        ErrorKind::ConnectionRefused,
        ErrorKind::ConnectionReset,
        ErrorKind::ConnectionAborted,
        ErrorKind::NotConnected,
        ErrorKind::BrokenPipe,
        ErrorKind::TimedOut,
        ErrorKind::Interrupted,
        ErrorKind::WouldBlock,
    ];

    TRANSIENT_KINDS.contains(&error.kind())
}
308
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// Policy with millisecond delays so tests that really retry stay fast.
    fn fast_config(max_attempts: u32) -> RetryConfig {
        RetryConfig {
            max_attempts,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
        }
    }

    #[test]
    fn test_retry_success_first_attempt() {
        let result: Result<i32, &str> =
            retry_with_backoff(RetryConfig::default(), "test", || Ok(42), |_| true);
        assert_eq!(result, Ok(42));
    }

    #[test]
    fn test_retry_success_after_failures() {
        let attempts = Cell::new(0);
        let result: Result<i32, &str> = retry_with_backoff(
            fast_config(3),
            "test",
            || {
                attempts.set(attempts.get() + 1);
                if attempts.get() < 3 {
                    Err("transient error")
                } else {
                    Ok(42)
                }
            },
            |_| true,
        );
        assert_eq!(result, Ok(42));
        assert_eq!(attempts.get(), 3);
    }

    #[test]
    fn test_retry_exhausted() {
        let attempts = Cell::new(0);
        let result: Result<i32, &str> = retry_with_backoff(
            fast_config(3),
            "test",
            || {
                attempts.set(attempts.get() + 1);
                Err("always fails")
            },
            |_| true,
        );
        assert_eq!(result, Err("always fails"));
        assert_eq!(attempts.get(), 3);
    }

    #[test]
    fn test_retry_non_retryable_error() {
        let attempts = Cell::new(0);
        let result: Result<i32, &str> = retry_with_backoff(
            RetryConfig::default(),
            "test",
            || {
                attempts.set(attempts.get() + 1);
                Err("permanent error")
            },
            |_| false,
        );
        assert_eq!(result, Err("permanent error"));
        assert_eq!(attempts.get(), 1);
    }

    #[test]
    fn test_retry_single_and_zero_attempt_budgets_call_once() {
        // A budget of 1 allows no retries, and 0 behaves the same way.
        for max_attempts in [0, 1] {
            let attempts = Cell::new(0);
            let result: Result<i32, &str> = retry_with_backoff(
                fast_config(max_attempts),
                "test",
                || {
                    attempts.set(attempts.get() + 1);
                    Err("fails")
                },
                |_| true,
            );
            assert_eq!(result, Err("fails"));
            assert_eq!(attempts.get(), 1);
        }
    }

    #[test]
    fn test_transient_network_errors() {
        assert!(is_transient_network_error("connection refused"));
        assert!(is_transient_network_error("Connection timed out"));
        assert!(is_transient_network_error("503 Service Unavailable"));
        assert!(is_transient_network_error("rate limit exceeded"));
        assert!(is_transient_network_error("could not resolve host"));
        assert!(!is_transient_network_error("404 not found"));
        assert!(!is_transient_network_error("some random error"));
        assert!(!is_transient_network_error(""));
    }

    #[test]
    fn test_permanent_errors() {
        assert!(is_permanent_error("401 Unauthorized"));
        assert!(is_permanent_error("404 Not Found"));
        assert!(is_permanent_error("manifest unknown"));
        assert!(!is_permanent_error("connection refused"));
        assert!(!is_permanent_error("503 Service Unavailable"));
        assert!(!is_permanent_error(""));
    }

    #[test]
    fn test_transient_io_errors() {
        use std::io::{Error, ErrorKind};

        assert!(is_transient_io_error(&Error::from(ErrorKind::TimedOut)));
        assert!(is_transient_io_error(&Error::from(ErrorKind::ConnectionReset)));
        assert!(is_transient_io_error(&Error::from(ErrorKind::WouldBlock)));
        assert!(is_transient_io_error(&Error::from(ErrorKind::Interrupted)));
        assert!(!is_transient_io_error(&Error::from(ErrorKind::NotFound)));
        assert!(!is_transient_io_error(&Error::from(ErrorKind::PermissionDenied)));
    }

    #[test]
    fn test_config_presets() {
        let default = RetryConfig::default();
        assert_eq!(default.max_attempts, 3);
        assert_eq!(default.initial_delay, Duration::from_millis(100));
        assert_eq!(default.max_delay, Duration::from_secs(5));
        assert_eq!(default.backoff_multiplier, 2.0);

        let network = RetryConfig::for_network();
        assert_eq!(network.max_attempts, 4);
        assert_eq!(network.initial_delay, Duration::from_millis(500));
        assert_eq!(network.max_delay, Duration::from_secs(5));
        assert_eq!(network.backoff_multiplier, 2.0);

        let connection = RetryConfig::for_connection();
        assert_eq!(connection.max_attempts, 3);
        assert_eq!(connection.initial_delay, Duration::from_millis(50));
        assert_eq!(connection.max_delay, Duration::from_secs(2));
        assert_eq!(connection.backoff_multiplier, 2.0);
    }
}