use std::time::Duration;
use rand::prelude::*;
use tokio::time::sleep;
use typed_builder::TypedBuilder;
/// Configuration for retrying a fallible async operation with exponential backoff.
///
/// Build one with `RetryPolicy::builder()` (generated by `TypedBuilder`); every field
/// is optional in the builder and falls back to the default given in its attribute.
#[derive(Debug, Clone, TypedBuilder)]
pub struct RetryPolicy {
/// Total number of attempts, counting the first call (builder default: 3).
/// The retry loop treats a value of 0 as 1.
#[builder(default = 3)]
max_attempts: u32,
/// Delay before the first retry; later delays grow from this value
/// (builder default: 500 ms).
#[builder(default = Duration::from_millis(500))]
initial_delay: Duration,
/// Cap applied to the exponentially growing backoff delay (builder default: 60 s).
#[builder(default = Duration::from_secs(60))]
max_delay: Duration,
/// Multiplier applied to the delay for each successive retry
/// (builder default: 2.0).
#[builder(default = 2.0)]
backoff_factor: f64,
/// Whether each delay is randomly scaled to spread out concurrent retries
/// (builder default: true).
#[builder(default = true)]
jitter: bool,
}
impl RetryPolicy {
    /// Creates a policy with the default settings (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Runs `operation`, retrying on every error until it succeeds or the
    /// attempt budget is exhausted.
    ///
    /// Equivalent to [`RetryPolicy::execute_with_condition`] with a predicate
    /// that treats every error as retryable.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the final attempt if every attempt fails.
    pub async fn execute<F, Fut, T, E>(&self, operation: F) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        E: std::fmt::Display,
    {
        self.execute_with_condition(operation, |_| true).await
    }

    /// Runs `operation`, retrying errors for which `is_retryable` returns `true`.
    ///
    /// At least one attempt is always made, even when `max_attempts` is 0.
    /// Between attempts the task sleeps for the delay computed by
    /// `calculate_delay`; no sleep follows the final attempt.
    ///
    /// # Errors
    ///
    /// Returns immediately with the first error that `is_retryable` rejects.
    /// Otherwise returns the error from the final attempt once the attempt
    /// budget is used up.
    pub async fn execute_with_condition<F, Fut, T, E, P>(
        &self,
        mut operation: F,
        is_retryable: P,
    ) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        E: std::fmt::Display,
        P: Fn(&E) -> bool,
    {
        // Clamp so a policy configured with 0 attempts still runs the operation once.
        let max_attempts = self.max_attempts.max(1);
        // 1-based attempt counter; it only increments while `attempt < max_attempts`,
        // so it cannot overflow.
        let mut attempt: u32 = 1;
        loop {
            if attempt > 1 {
                tracing::debug!(attempt, max = max_attempts, "🔄 Retry attempt");
            }
            match operation().await {
                Ok(result) => {
                    if attempt > 1 {
                        tracing::info!(attempts = attempt, "✅ Operation succeeded after retries");
                    }
                    return Ok(result);
                }
                Err(e) if !is_retryable(&e) => {
                    tracing::warn!(error = %e, "Non-retryable error encountered");
                    return Err(e);
                }
                Err(e) if attempt >= max_attempts => {
                    // Budget exhausted: report the last error without claiming a retry.
                    tracing::warn!(
                        attempt,
                        max = max_attempts,
                        error = %e,
                        "❌ Operation failed, no retry attempts remaining"
                    );
                    return Err(e);
                }
                Err(e) => {
                    tracing::warn!(
                        attempt,
                        max = max_attempts,
                        error = %e,
                        "🔄 Operation failed, retrying"
                    );
                    // `calculate_delay` takes the 0-based index of the failed attempt.
                    let delay = self.calculate_delay(attempt - 1);
                    tracing::debug!(delay = ?delay, "🔄 Waiting before retry");
                    sleep(delay).await;
                    attempt += 1;
                }
            }
        }
    }

    /// Configured maximum number of attempts (the retry loop treats 0 as 1).
    pub fn max_attempts(&self) -> u32 {
        self.max_attempts
    }

    /// Delay before the first retry.
    pub fn initial_delay(&self) -> Duration {
        self.initial_delay
    }

    /// Upper bound on any single retry delay.
    pub fn max_delay(&self) -> Duration {
        self.max_delay
    }

    /// Multiplier applied to the delay for each successive retry.
    pub fn backoff_factor(&self) -> f64 {
        self.backoff_factor
    }

    /// Whether retry delays are randomly jittered.
    pub fn has_jitter(&self) -> bool {
        self.jitter
    }

    /// Computes the sleep that precedes the retry after the failed attempt
    /// with 0-based index `attempt`.
    ///
    /// The base delay is `initial_delay * backoff_factor^attempt`, capped at
    /// `max_delay`. With jitter enabled, it is scaled by a random factor in
    /// `[0.5, 1.5]`. The result is then clamped into `[0, max_delay]`, so
    /// jitter can never exceed the configured maximum.
    fn calculate_delay(&self, attempt: u32) -> Duration {
        // `powi` takes an i32; saturate rather than letting a huge `attempt` wrap negative.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let cap_secs = self.max_delay.as_secs_f64();
        // Work in f64 seconds to keep sub-millisecond precision. `f64::min` returns
        // the non-NaN operand, so an overflowing product collapses to the cap.
        let base_secs =
            (self.initial_delay.as_secs_f64() * self.backoff_factor.powi(exponent)).min(cap_secs);
        let jittered_secs = if self.jitter {
            let mut rng = rand::rng();
            let jitter_factor: f64 = rng.random_range(0.5..=1.5);
            base_secs * jitter_factor
        } else {
            base_secs
        };
        // Re-clamp after jitter; `max(0.0)` guards against a negative backoff factor.
        let secs = jittered_secs.min(cap_secs).max(0.0);
        // After clamping, conversion can only fail if `secs` rounds past
        // `Duration::MAX`, which is possible only when the cap itself is at that limit.
        Duration::try_from_secs_f64(secs).unwrap_or(self.max_delay)
    }
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 3,
initial_delay: Duration::from_millis(500),
max_delay: Duration::from_secs(60),
backoff_factor: 2.0,
jitter: true,
}
}
}
impl std::fmt::Display for RetryPolicy {
    /// Writes a compact one-line summary: the attempt budget, the backoff
    /// multiplier and the jitter flag. Delay values are not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            max_attempts,
            backoff_factor,
            jitter,
            ..
        } = self;
        write!(
            f,
            "RetryPolicy(attempts={}, backoff={}x, jitter={})",
            max_attempts, backoff_factor, jitter
        )
    }
}
pub fn is_http_error_retryable(error: &reqwest::Error) -> bool {
let is_timeout = error.is_timeout();
let is_connect = error.is_connect();
let is_request = error.is_request();
if is_timeout || is_connect || is_request {
return true;
}
if let Some(status) = error.status() {
let code = status.as_u16();
return matches!(
code,
408 | 429 | 500..=599 );
}
false
}