use std::time::Duration;
use tokio::time::sleep;
/// Tuning knobs for the retry helpers: how many attempts to make and how the
/// wait between attempts grows.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of times the operation is invoked, including the first call.
    /// The retry helpers treat values below 1 as 1.
    pub max_attempts: u32,
    /// Wait before the first retry.
    pub initial_delay: Duration,
    /// Ceiling for the wait as it grows by `backoff_multiplier`.
    pub max_delay: Duration,
    /// Factor the wait is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 100 ms and doubling, capped at 30 s.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Builds a config with a custom attempt count and starting delay; the
    /// remaining fields come from [`RetryConfig::default`].
    ///
    /// `max_attempts` is raised to 1 if it is 0.
    pub fn new(max_attempts: u32, initial_delay: Duration) -> Self {
        Self {
            max_attempts: max_attempts.max(1),
            initial_delay,
            ..Self::default()
        }
    }

    /// Preset with more attempts and longer waits: 5 attempts, starting at
    /// 500 ms and doubling, capped at 60 s.
    pub fn network() -> Self {
        Self {
            max_attempts: 5,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(60),
            backoff_multiplier: 2.0,
        }
    }

    /// Preset with gentler growth: 3 attempts, starting at 100 ms and growing
    /// by 1.5x, capped at 10 s.
    pub fn storage() -> Self {
        Self {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 1.5,
        }
    }
}
/// Calls the synchronous, fallible `operation` up to `config.max_attempts` times,
/// sleeping with exponential backoff between failed attempts.
///
/// The first retry waits `config.initial_delay`. Each later wait is the previous one
/// multiplied by `config.backoff_multiplier`. Every wait, including the first, is
/// capped at `config.max_delay`. A `max_attempts` of `0` is treated as `1`.
///
/// `operation` runs directly on the calling async task, so it should not block for
/// long. For I/O-bound work, prefer the future-returning `retry_async_with_backoff`.
///
/// # Errors
///
/// Returns the error from the final attempt once every attempt has failed.
pub async fn retry_with_backoff<F, T, E>(config: &RetryConfig, mut operation: F) -> Result<T, E>
where
    F: FnMut() -> Result<T, E>,
    E: std::fmt::Display,
{
    let max_attempts = config.max_attempts.max(1);
    // Honour the ceiling even on the first retry.
    let mut delay = config.initial_delay.min(config.max_delay);
    let mut attempt = 1;
    loop {
        match operation() {
            Ok(result) => return Ok(result),
            // Out of attempts: hand the last error straight back to the caller.
            Err(e) if attempt >= max_attempts => return Err(e),
            Err(e) => {
                tracing::debug!(
                    "Operation failed (attempt {}/{}): {}. Retrying in {:?}...",
                    attempt,
                    max_attempts,
                    e,
                    delay
                );
                sleep(delay).await;
                // `Duration::from_secs_f64` panics on negative, non-finite, or
                // overflowing input (e.g. a bad multiplier); saturate to the cap instead.
                let scaled = delay.as_secs_f64() * config.backoff_multiplier;
                delay = Duration::try_from_secs_f64(scaled)
                    .map_or(config.max_delay, |next| next.min(config.max_delay));
                attempt += 1;
            }
        }
    }
}
/// Awaits the future produced by `operation` up to `config.max_attempts` times,
/// sleeping with exponential backoff between failed attempts.
///
/// `operation` is called once per attempt to build a fresh future. The first retry
/// waits `config.initial_delay`. Each later wait is the previous one multiplied by
/// `config.backoff_multiplier`. Every wait, including the first, is capped at
/// `config.max_delay`. A `max_attempts` of `0` is treated as `1`.
///
/// # Errors
///
/// Returns the error from the final attempt once every attempt has failed.
pub async fn retry_async_with_backoff<F, Fut, T, E>(
    config: &RetryConfig,
    mut operation: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    E: std::fmt::Display,
{
    let max_attempts = config.max_attempts.max(1);
    // Honour the ceiling even on the first retry.
    let mut delay = config.initial_delay.min(config.max_delay);
    let mut attempt = 1;
    loop {
        match operation().await {
            Ok(result) => return Ok(result),
            // Out of attempts: hand the last error straight back to the caller.
            Err(e) if attempt >= max_attempts => return Err(e),
            Err(e) => {
                tracing::debug!(
                    "Async operation failed (attempt {}/{}): {}. Retrying in {:?}...",
                    attempt,
                    max_attempts,
                    e,
                    delay
                );
                sleep(delay).await;
                // `Duration::from_secs_f64` panics on negative, non-finite, or
                // overflowing input (e.g. a bad multiplier); saturate to the cap instead.
                let scaled = delay.as_secs_f64() * config.backoff_multiplier;
                delay = Duration::try_from_secs_f64(scaled)
                    .map_or(config.max_delay, |next| next.min(config.max_delay));
                attempt += 1;
            }
        }
    }
}
/// Classifies an error as transient (worth retrying) or permanent.
pub trait IsRetryable {
    /// Returns `true` if the operation that produced this error may succeed when
    /// attempted again.
    fn is_retryable(&self) -> bool;
}
/// Like an async retry-with-backoff loop, but stops immediately on errors whose
/// [`IsRetryable::is_retryable`] returns `false`.
///
/// `operation` is called once per attempt to build a fresh future. Retryable
/// failures are retried up to `config.max_attempts` times in total. The first retry
/// waits `config.initial_delay`. Each later wait is the previous one multiplied by
/// `config.backoff_multiplier`. Every wait, including the first, is capped at
/// `config.max_delay`. A `max_attempts` of `0` is treated as `1`.
///
/// # Errors
///
/// Returns the first non-retryable error as soon as it occurs. Otherwise it returns
/// the error from the final attempt once every attempt has failed.
pub async fn retry_if_retryable<F, Fut, T, E>(
    config: &RetryConfig,
    mut operation: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    E: IsRetryable + std::fmt::Display,
{
    let max_attempts = config.max_attempts.max(1);
    // Honour the ceiling even on the first retry.
    let mut delay = config.initial_delay.min(config.max_delay);
    let mut attempt = 1;
    loop {
        match operation().await {
            Ok(result) => return Ok(result),
            // Permanent failure: retrying cannot help.
            Err(e) if !e.is_retryable() => return Err(e),
            // Out of attempts: hand the last error straight back to the caller.
            Err(e) if attempt >= max_attempts => return Err(e),
            Err(e) => {
                tracing::debug!(
                    "Retryable error (attempt {}/{}): {}. Retrying in {:?}...",
                    attempt,
                    max_attempts,
                    e,
                    delay
                );
                sleep(delay).await;
                // `Duration::from_secs_f64` panics on negative, non-finite, or
                // overflowing input (e.g. a bad multiplier); saturate to the cap instead.
                let scaled = delay.as_secs_f64() * config.backoff_multiplier;
                delay = Duration::try_from_secs_f64(scaled)
                    .map_or(config.max_delay, |next| next.min(config.max_delay));
                attempt += 1;
            }
        }
    }
}