use std::time::Duration;
/// Configuration describing how many times an operation may be retried and
/// how long to wait between attempts.
///
/// The delay for a given attempt is derived from `initial_delay` scaled by
/// `backoff_multiplier` raised to the attempt number, capped at `max_delay`.
#[derive(Clone, Copy, Debug)]
pub struct RetryStrategy {
    /// Number of attempts for which a retry is still permitted.
    pub max_attempts: usize,
    /// Delay used for the first retry (attempt `0`).
    pub initial_delay: Duration,
    /// Upper bound applied to every computed delay.
    pub max_delay: Duration,
    /// Factor by which the delay grows from one attempt to the next.
    pub backoff_multiplier: f64,
}
impl RetryStrategy {
    /// Creates a strategy whose delay doubles with every attempt.
    ///
    /// The delay starts at `initial_delay`, is multiplied by `2.0` per
    /// attempt, and never exceeds `max_delay`. A debug event describing the
    /// configuration is emitted.
    pub fn exponential(max_attempts: usize, initial_delay: Duration, max_delay: Duration) -> Self {
        tracing::debug!(
            max_attempts = max_attempts,
            initial_delay_ms = initial_delay.as_millis(),
            max_delay_ms = max_delay.as_millis(),
            "⚙️ Creating exponential retry strategy"
        );
        Self {
            max_attempts,
            initial_delay,
            max_delay,
            backoff_multiplier: 2.0,
        }
    }

    /// Creates a strategy that waits the same `delay` before every retry.
    ///
    /// Despite the name, the delay does not grow: the multiplier is `1.0`
    /// and both the initial and the maximum delay are set to `delay`.
    pub fn linear(max_attempts: usize, delay: Duration) -> Self {
        tracing::debug!(
            max_attempts = max_attempts,
            delay_ms = delay.as_millis(),
            "⚙️ Creating linear retry strategy"
        );
        Self {
            max_attempts,
            initial_delay: delay,
            max_delay: delay,
            backoff_multiplier: 1.0,
        }
    }

    /// Creates a strategy that never retries (`max_attempts` is `0`).
    pub fn none() -> Self {
        tracing::debug!("⚙️ Creating no-retry strategy");
        Self {
            max_attempts: 0,
            initial_delay: Duration::from_secs(0),
            max_delay: Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Returns the delay to wait before retrying after the given `attempt`.
    ///
    /// The result is `initial_delay * backoff_multiplier^attempt`, capped at
    /// `max_delay`. A zero duration is returned when `attempt` is not below
    /// `max_attempts`.
    ///
    /// This method never panics, even for unusual field values:
    /// * a zero `initial_delay` always yields a zero delay;
    /// * a negative computed delay (negative multiplier) is clamped to zero;
    /// * an overflowing or NaN computed delay yields `max_delay`.
    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
        if attempt >= self.max_attempts {
            return Duration::from_secs(0);
        }

        // Saturate instead of a wrapping `as` cast: an attempt number above
        // `i32::MAX` must not turn into a negative exponent.
        let exponent = attempt.min(i32::MAX as usize) as i32;
        let growth = self.backoff_multiplier.powi(exponent);

        let initial_secs = self.initial_delay.as_secs_f64();
        let max_secs = self.max_delay.as_secs_f64();

        // `0.0 * inf` is NaN, so a zero initial delay is handled explicitly.
        let scaled_secs = if initial_secs == 0.0 {
            0.0
        } else {
            initial_secs * growth
        };

        let result = if scaled_secs.is_nan() || scaled_secs >= max_secs {
            // Return the cap itself rather than converting it through `f64`,
            // which could overflow for a `max_delay` close to `Duration::MAX`.
            self.max_delay
        } else if scaled_secs <= 0.0 {
            Duration::from_secs(0)
        } else {
            // Here `0 < scaled_secs < max_secs`, so the conversion cannot panic.
            Duration::from_secs_f64(scaled_secs)
        };

        tracing::debug!(
            attempt = attempt,
            delay_ms = result.as_millis(),
            "🔄 Calculated retry delay"
        );
        result
    }

    /// Reports whether another retry is allowed after the given `attempt`,
    /// i.e. whether `attempt` is strictly below `max_attempts`.
    pub fn should_retry(&self, attempt: usize) -> bool {
        let should_retry = attempt < self.max_attempts;
        tracing::debug!(
            attempt = attempt,
            max_attempts = self.max_attempts,
            should_retry = should_retry,
            "🔄 Checking retry status"
        );
        should_retry
    }
}
impl Default for RetryStrategy {
fn default() -> Self {
Self::exponential(3, Duration::from_secs(1), Duration::from_secs(30))
}
}
impl std::fmt::Display for RetryStrategy {
    /// Writes a compact one-line summary containing the attempt limit, the
    /// initial delay in milliseconds and the backoff multiplier.
    /// `max_delay` is not part of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            max_attempts,
            initial_delay,
            backoff_multiplier,
            ..
        } = self;
        let initial_delay_ms = initial_delay.as_millis();
        write!(
            f,
            "RetryStrategy(max_attempts={}, initial_delay={}ms, backoff={}x)",
            max_attempts, initial_delay_ms, backoff_multiplier
        )
    }
}