use std::thread;
use std::time::Duration;
use tracing::{debug, warn};
/// Attempt budget used by [`RetryConfig::default`].
const DEFAULT_MAX_ATTEMPTS: u32 = 3;
/// Wait before the first retry for [`RetryConfig::default`], in milliseconds.
const DEFAULT_INITIAL_DELAY_MS: u64 = 100;
/// Backoff ceiling for [`RetryConfig::default`], in seconds.
const DEFAULT_MAX_DELAY_SECS: u64 = 5;
/// Growth factor applied to the delay after each retry; shared by every preset.
const BACKOFF_MULTIPLIER: f64 = 2.0;
/// Attempt budget used by [`RetryConfig::for_network`].
const NETWORK_MAX_ATTEMPTS: u32 = 4;
/// Wait before the first retry for [`RetryConfig::for_network`], in milliseconds.
const NETWORK_INITIAL_DELAY_MS: u64 = 500;
/// Backoff ceiling for [`RetryConfig::for_network`], in seconds.
const NETWORK_MAX_DELAY_SECS: u64 = 5;
/// Attempt budget used by [`RetryConfig::for_connection`].
const CONNECTION_MAX_ATTEMPTS: u32 = 6;
/// Wait before the first retry for [`RetryConfig::for_connection`], in milliseconds.
const CONNECTION_INITIAL_DELAY_MS: u64 = 100;
/// Backoff ceiling for [`RetryConfig::for_connection`], in seconds.
const CONNECTION_MAX_DELAY_SECS: u64 = 2;

/// Parameters controlling [`retry_with_backoff`]: how many times an operation
/// is tried and how the wait between tries grows.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of attempts, counting the first call.
    pub max_attempts: u32,
    /// Wait before the first retry.
    pub initial_delay: Duration,
    /// Ceiling that the growing backoff delay is clamped to.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each retry.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// General-purpose preset: 3 attempts, starting at 100 ms, capped at 5 s.
    fn default() -> Self {
        Self {
            max_attempts: DEFAULT_MAX_ATTEMPTS,
            initial_delay: Duration::from_millis(DEFAULT_INITIAL_DELAY_MS),
            max_delay: Duration::from_secs(DEFAULT_MAX_DELAY_SECS),
            backoff_multiplier: BACKOFF_MULTIPLIER,
        }
    }
}

impl RetryConfig {
    /// Network preset: 4 attempts, starting at 500 ms, capped at 5 s.
    pub fn for_network() -> Self {
        Self {
            max_attempts: NETWORK_MAX_ATTEMPTS,
            initial_delay: Duration::from_millis(NETWORK_INITIAL_DELAY_MS),
            max_delay: Duration::from_secs(NETWORK_MAX_DELAY_SECS),
            backoff_multiplier: BACKOFF_MULTIPLIER,
        }
    }

    /// Connection preset: 6 attempts, starting at 100 ms, capped at 2 s.
    ///
    /// It has more attempts with shorter waits than [`RetryConfig::for_network`].
    /// The intended use (presumably establishing connections) is inferred from
    /// the name only.
    pub fn for_connection() -> Self {
        Self {
            max_attempts: CONNECTION_MAX_ATTEMPTS,
            initial_delay: Duration::from_millis(CONNECTION_INITIAL_DELAY_MS),
            max_delay: Duration::from_secs(CONNECTION_MAX_DELAY_SECS),
            backoff_multiplier: BACKOFF_MULTIPLIER,
        }
    }
}
/// Runs `operation` until it succeeds, fails with an error that `should_retry`
/// rejects, or `config.max_attempts` attempts have been made.
///
/// Between attempts the calling thread sleeps (`std::thread::sleep`).
/// - The first wait is `config.initial_delay`, capped at `config.max_delay`.
/// - Each later wait is the previous one times `config.backoff_multiplier`,
///   again capped at `config.max_delay`.
///
/// `operation` is always invoked at least once, even when `config.max_attempts`
/// is 0. `operation_name` is used only as a field in log events.
///
/// # Errors
///
/// Returns the error from the last attempt when `should_retry` rejects it, or
/// when the attempt budget is exhausted.
pub fn retry_with_backoff<T, E, F, R>(
    config: RetryConfig,
    operation_name: &str,
    mut operation: F,
    should_retry: R,
) -> Result<T, E>
where
    F: FnMut() -> Result<T, E>,
    R: Fn(&E) -> bool,
    E: std::fmt::Display,
{
    let mut attempt: u32 = 0;
    // Clamp the first wait too, so `max_delay` bounds every sleep and not only
    // the grown ones.
    let mut delay = config.initial_delay.min(config.max_delay);
    loop {
        attempt += 1;
        let e = match operation() {
            Ok(result) => {
                if attempt > 1 {
                    debug!(
                        operation = %operation_name,
                        attempts = attempt,
                        "operation succeeded after retry"
                    );
                }
                return Ok(result);
            }
            Err(e) => e,
        };
        // Classify before checking the budget. A permanent failure on the final
        // attempt is then reported as non-retryable, not as budget exhaustion.
        if !should_retry(&e) {
            debug!(
                operation = %operation_name,
                attempt = attempt,
                error = %e,
                "operation failed with non-retryable error"
            );
            return Err(e);
        }
        if attempt >= config.max_attempts {
            warn!(
                operation = %operation_name,
                attempts = attempt,
                error = %e,
                "operation failed after max attempts"
            );
            return Err(e);
        }
        warn!(
            operation = %operation_name,
            attempt = attempt,
            max_attempts = config.max_attempts,
            delay_ms = delay.as_millis(),
            error = %e,
            "operation failed, will retry"
        );
        thread::sleep(delay);
        delay = next_backoff_delay(delay, config.backoff_multiplier, config.max_delay);
    }
}

/// Computes the wait before the next retry: `current * multiplier`, capped at `max_delay`.
///
/// Any product that is not a finite, non-negative value below the cap yields
/// `max_delay` instead. That covers a NaN or negative multiplier and overflow.
/// `Duration::from_secs_f64` would panic on negative or overflowing input, and
/// comparing against the cap in f64 first keeps the call in its valid range.
fn next_backoff_delay(current: Duration, multiplier: f64, max_delay: Duration) -> Duration {
    let scaled_secs = current.as_secs_f64() * multiplier;
    if scaled_secs >= 0.0 && scaled_secs < max_delay.as_secs_f64() {
        Duration::from_secs_f64(scaled_secs)
    } else {
        max_delay
    }
}
/// Heuristically classifies an error message as a transient network failure
/// that is worth retrying.
///
/// The message is lowercased and searched for any known marker phrase:
/// - connection-level failures;
/// - name-resolution failures;
/// - retryable HTTP statuses (429/502/503/504) and rate limiting;
/// - interrupted or would-block I/O.
///
/// Returns `false` when no marker is present.
pub fn is_transient_network_error(error_msg: &str) -> bool {
    const TRANSIENT_MARKERS: &[&str] = &[
        // Connection-level failures.
        "connection refused",
        "connection reset",
        "connection timed out",
        "network is unreachable",
        "no route to host",
        "temporary failure",
        "try again",
        "resource temporarily unavailable",
        // Name resolution.
        "name resolution",
        "dns",
        "could not resolve",
        "no such host",
        // Retryable HTTP statuses.
        "502 bad gateway",
        "503 service unavailable",
        "504 gateway timeout",
        "429 too many requests",
        // Rate limiting / quotas.
        "toomanyrequests",
        "rate limit",
        "quota exceeded",
        // Low-level I/O hiccups.
        "broken pipe",
        "interrupted",
        "eagain",
        "ewouldblock",
    ];

    let lowered = error_msg.to_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|&marker| lowered.contains(marker))
}
/// Heuristically classifies an error message as a permanent failure that
/// retrying will not fix.
///
/// The message is lowercased and searched for any known marker phrase:
/// - authentication or authorization failures;
/// - missing resources (404, unknown manifest/name, missing repository);
/// - malformed references or images.
///
/// Returns `false` when no marker is present.
pub fn is_permanent_error(error_msg: &str) -> bool {
    const PERMANENT_MARKERS: &[&str] = &[
        // Authentication / authorization.
        "401 unauthorized",
        "403 forbidden",
        "authentication required",
        "access denied",
        // Missing resources.
        "404 not found",
        "manifest unknown",
        "name unknown",
        "repository does not exist",
        // Invalid input.
        "invalid reference",
        "invalid image",
        "malformed",
    ];

    let lowered = error_msg.to_lowercase();
    PERMANENT_MARKERS
        .iter()
        .any(|&marker| lowered.contains(marker))
}
/// Reports whether an I/O error's kind is one that may succeed on retry.
///
/// The transient kinds are:
/// - connection problems: refused, reset, aborted, not connected, broken pipe;
/// - timeouts;
/// - interrupted or would-block operations.
///
/// Every other kind is treated as non-transient.
pub fn is_transient_io_error(error: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    const TRANSIENT_KINDS: [ErrorKind; 8] = [
        ErrorKind::ConnectionRefused,
        ErrorKind::ConnectionReset,
        ErrorKind::ConnectionAborted,
        ErrorKind::NotConnected,
        ErrorKind::BrokenPipe,
        ErrorKind::TimedOut,
        ErrorKind::Interrupted,
        ErrorKind::WouldBlock,
    ];

    let kind = error.kind();
    TRANSIENT_KINDS.contains(&kind)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// Config with millisecond-scale delays so retry tests finish quickly.
    fn fast_config(max_attempts: u32) -> RetryConfig {
        RetryConfig {
            max_attempts,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
        }
    }

    #[test]
    fn test_retry_success_first_attempt() {
        let result: Result<i32, &str> =
            retry_with_backoff(RetryConfig::default(), "test", || Ok(42), |_| true);
        assert_eq!(result.unwrap(), 42);
    }

    #[test]
    fn test_retry_success_after_failures() {
        let attempts = Cell::new(0);
        let result: Result<i32, &str> = retry_with_backoff(
            fast_config(3),
            "test",
            || {
                attempts.set(attempts.get() + 1);
                if attempts.get() < 3 {
                    Err("transient error")
                } else {
                    Ok(42)
                }
            },
            |_| true,
        );
        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.get(), 3);
    }

    #[test]
    fn test_retry_exhausted() {
        let attempts = Cell::new(0);
        let result: Result<i32, &str> = retry_with_backoff(
            fast_config(3),
            "test",
            || {
                attempts.set(attempts.get() + 1);
                Err("always fails")
            },
            |_| true,
        );
        // The error from the final attempt is handed back to the caller.
        assert_eq!(result, Err("always fails"));
        assert_eq!(attempts.get(), 3);
    }

    #[test]
    fn test_retry_non_retryable_error() {
        let attempts = Cell::new(0);
        let result: Result<i32, &str> = retry_with_backoff(
            RetryConfig::default(),
            "test",
            || {
                attempts.set(attempts.get() + 1);
                Err("permanent error")
            },
            |_| false,
        );
        assert!(result.is_err());
        assert_eq!(attempts.get(), 1);
    }

    #[test]
    fn test_zero_max_attempts_still_calls_once() {
        let attempts = Cell::new(0);
        let result: Result<i32, &str> = retry_with_backoff(
            fast_config(0),
            "test",
            || {
                attempts.set(attempts.get() + 1);
                Err("fails")
            },
            |_| true,
        );
        assert!(result.is_err());
        assert_eq!(attempts.get(), 1);
    }

    #[test]
    fn test_transient_network_errors() {
        assert!(is_transient_network_error("connection refused"));
        assert!(is_transient_network_error("Connection timed out"));
        assert!(is_transient_network_error("503 Service Unavailable"));
        assert!(is_transient_network_error("rate limit exceeded"));
        assert!(!is_transient_network_error("404 not found"));
        assert!(!is_transient_network_error("some random error"));
    }

    #[test]
    fn test_permanent_errors() {
        assert!(is_permanent_error("401 Unauthorized"));
        assert!(is_permanent_error("404 Not Found"));
        assert!(is_permanent_error("manifest unknown"));
        assert!(!is_permanent_error("connection refused"));
        assert!(!is_permanent_error("503 Service Unavailable"));
    }

    #[test]
    fn test_transient_io_errors() {
        use std::io::{Error, ErrorKind};
        assert!(is_transient_io_error(&Error::new(ErrorKind::ConnectionReset, "reset")));
        assert!(is_transient_io_error(&Error::new(ErrorKind::TimedOut, "timeout")));
        assert!(is_transient_io_error(&Error::new(ErrorKind::Interrupted, "eintr")));
        assert!(!is_transient_io_error(&Error::new(ErrorKind::NotFound, "missing")));
        assert!(!is_transient_io_error(&Error::new(ErrorKind::PermissionDenied, "denied")));
    }

    #[test]
    fn test_config_presets() {
        let default = RetryConfig::default();
        assert_eq!(default.max_attempts, 3);
        assert_eq!(default.initial_delay, Duration::from_millis(100));
        assert_eq!(default.max_delay, Duration::from_secs(5));

        let network = RetryConfig::for_network();
        assert_eq!(network.max_attempts, 4);
        assert_eq!(network.initial_delay, Duration::from_millis(500));
        assert_eq!(network.max_delay, Duration::from_secs(5));

        let connection = RetryConfig::for_connection();
        assert_eq!(connection.max_attempts, 6);
        assert_eq!(connection.initial_delay, Duration::from_millis(100));
        assert_eq!(connection.max_delay, Duration::from_secs(2));
    }
}