use std::time::Duration;
/// Configuration describing how often an operation may be retried and how
/// long to wait before each retry.
///
/// Start from one of the presets ([`RetryStrategy::exponential_backoff`],
/// [`RetryStrategy::fixed_delay`], [`RetryStrategy::no_retry`]) and refine it
/// with the chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct RetryStrategy {
    /// Retry budget checked by [`RetryStrategy::should_retry`]; `0` disables
    /// retries and makes every computed delay zero.
    max_retries: u32,
    /// Delay returned for the first attempt, before any growth is applied.
    initial_delay: Duration,
    /// Upper bound applied to every computed delay.
    max_delay: Duration,
    /// Factor the delay is multiplied by for each additional attempt.
    multiplier: f64,
    /// When `true`, the capped delay is scaled by a random factor.
    jitter: bool,
}

impl RetryStrategy {
    /// Preset: 3 retries, starting at 1 second, doubling each attempt,
    /// capped at 60 seconds, with jitter enabled.
    pub fn exponential_backoff() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            multiplier: 2.0,
            jitter: true,
        }
    }

    /// Preset: 3 retries that each wait `delay`, without growth or jitter.
    pub fn fixed_delay(delay: Duration) -> Self {
        Self {
            max_retries: 3,
            initial_delay: delay,
            max_delay: delay,
            multiplier: 1.0,
            jitter: false,
        }
    }

    /// Preset: zero retries and zero delays.
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            initial_delay: Duration::from_secs(0),
            max_delay: Duration::from_secs(0),
            multiplier: 1.0,
            jitter: false,
        }
    }

    /// Replaces the retry budget.
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Replaces the delay used for the first attempt.
    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Replaces the upper bound applied to computed delays.
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Replaces the per-attempt growth factor.
    ///
    /// # Panics
    ///
    /// Panics if `multiplier` is not a positive, finite number.
    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
        assert!(
            multiplier > 0.0 && multiplier.is_finite(),
            "RetryStrategy multiplier must be a positive finite number, got {}",
            multiplier
        );
        self.multiplier = multiplier;
        self
    }

    /// Enables or disables random jitter on computed delays.
    pub fn with_jitter(mut self, jitter: bool) -> Self {
        self.jitter = jitter;
        self
    }

    /// Returns the retry budget.
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    /// Computes the delay for the given 1-based `attempt`.
    ///
    /// The result is `initial_delay * multiplier^(attempt - 1)`, capped at
    /// `max_delay`. It is zero when `attempt` is `0`, when retries are
    /// disabled (`max_retries == 0`), or when `initial_delay` is zero.
    ///
    /// With jitter enabled the capped delay is multiplied by the value of
    /// `rand::random::<f64>()`.
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        let zero = Duration::from_secs(0);
        // A zero initial delay stays zero however large the growth factor
        // gets; short-circuiting also avoids `0.0 * inf = NaN` once the
        // power below overflows.
        if attempt == 0 || self.max_retries == 0 || self.initial_delay == zero {
            return zero;
        }

        // Clamp instead of casting directly: a plain `as i32` would wrap to a
        // negative exponent for very large attempts and shrink the delay.
        let exponent = (attempt - 1).min(i32::MAX as u32) as i32;
        let uncapped_secs = self.initial_delay.as_secs_f64() * self.multiplier.powi(exponent);

        // Return `max_delay` itself when the cap applies. Converting the cap
        // back from `f64` loses precision and panics for values close to
        // `Duration::MAX`, whose float representation rounds up to 2^64.
        let capped = if uncapped_secs >= self.max_delay.as_secs_f64() {
            self.max_delay
        } else {
            Duration::from_secs_f64(uncapped_secs)
        };

        if self.jitter {
            // NOTE(review): `rand` documents this factor as uniform in
            // `[0, 1)`, so the product never exceeds `capped` - confirm
            // against the `rand` version in use.
            capped.mul_f64(rand::random::<f64>())
        } else {
            capped
        }
    }

    /// Returns `true` while `attempt` is below the retry budget.
    pub fn should_retry(&self, attempt: u32) -> bool {
        attempt < self.max_retries
    }

    /// Returns the delay used for the first attempt.
    pub fn initial_delay(&self) -> Duration {
        self.initial_delay
    }

    /// Returns the upper bound applied to computed delays.
    pub fn max_delay(&self) -> Duration {
        self.max_delay
    }

    /// Returns the per-attempt growth factor.
    pub fn multiplier(&self) -> f64 {
        self.multiplier
    }

    /// Returns whether random jitter is enabled.
    pub fn has_jitter(&self) -> bool {
        self.jitter
    }
}
impl Default for RetryStrategy {
fn default() -> Self {
Self::exponential_backoff()
}
}
/// Tracks how many attempts have been recorded under a [`RetryStrategy`].
#[derive(Debug, Clone)]
pub struct RetryState {
    /// Strategy consulted for retry decisions and delays.
    strategy: RetryStrategy,
    /// Number of attempts recorded since creation or the last reset.
    attempts: u32,
}

impl RetryState {
    /// Creates a state with no recorded attempts.
    pub fn new(strategy: RetryStrategy) -> Self {
        Self {
            strategy,
            attempts: 0,
        }
    }

    /// Returns the number of attempts recorded so far.
    pub fn attempts(&self) -> u32 {
        self.attempts
    }

    /// Records one more attempt.
    ///
    /// The counter saturates at `u32::MAX`, so it neither panics on overflow
    /// nor wraps back to zero (which would make retries look available again).
    pub fn record_attempt(&mut self) {
        self.attempts = self.attempts.saturating_add(1);
    }

    /// Returns whether the strategy allows a retry after the recorded attempts.
    pub fn can_retry(&self) -> bool {
        self.strategy.should_retry(self.attempts)
    }

    /// Returns the strategy's delay for the attempt following the recorded ones.
    pub fn next_delay(&self) -> Duration {
        // Saturate for the same reason as `record_attempt`: `attempts + 1`
        // would overflow once the counter reaches `u32::MAX`.
        self.strategy
            .calculate_delay(self.attempts.saturating_add(1))
    }

    /// Returns the strategy this state was created with.
    pub fn strategy(&self) -> &RetryStrategy {
        &self.strategy
    }

    /// Clears the recorded attempts.
    pub fn reset(&mut self) {
        self.attempts = 0;
    }
}
// Unit tests for the retry strategy presets, builder methods, delay
// computation, and the attempt-tracking state.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exponential_backoff_creation() {
        let strategy = RetryStrategy::exponential_backoff();
        assert_eq!(strategy.max_retries(), 3);
        assert_eq!(strategy.initial_delay(), Duration::from_secs(1));
        assert_eq!(strategy.max_delay(), Duration::from_secs(60));
        assert_eq!(strategy.multiplier(), 2.0);
        assert!(strategy.has_jitter());
    }

    #[test]
    fn test_fixed_delay_creation() {
        let delay = Duration::from_secs(5);
        let strategy = RetryStrategy::fixed_delay(delay);
        assert_eq!(strategy.initial_delay(), delay);
        assert_eq!(strategy.max_delay(), delay);
        assert_eq!(strategy.multiplier(), 1.0);
        assert!(!strategy.has_jitter());
    }

    #[test]
    fn test_no_retry_creation() {
        let strategy = RetryStrategy::no_retry();
        assert_eq!(strategy.max_retries(), 0);
        // With retries disabled every computed delay is zero.
        assert_eq!(strategy.calculate_delay(1), Duration::from_secs(0));
    }

    #[test]
    fn test_strategy_builder() {
        let strategy = RetryStrategy::exponential_backoff()
            .with_max_retries(5)
            .with_initial_delay(Duration::from_millis(500))
            .with_max_delay(Duration::from_secs(120))
            .with_multiplier(3.0)
            .with_jitter(false);
        assert_eq!(strategy.max_retries(), 5);
        assert_eq!(strategy.initial_delay(), Duration::from_millis(500));
        assert_eq!(strategy.max_delay(), Duration::from_secs(120));
        assert_eq!(strategy.multiplier(), 3.0);
        assert!(!strategy.has_jitter());
    }

    #[test]
    fn test_calculate_delay_without_jitter() {
        let strategy = RetryStrategy::exponential_backoff()
            .with_initial_delay(Duration::from_secs(1))
            .with_multiplier(2.0)
            .with_jitter(false);
        let delay1 = strategy.calculate_delay(1);
        let delay2 = strategy.calculate_delay(2);
        let delay3 = strategy.calculate_delay(3);
        assert_eq!(delay1, Duration::from_secs(1));
        assert_eq!(delay2, Duration::from_secs(2));
        assert_eq!(delay3, Duration::from_secs(4));
    }

    #[test]
    fn test_calculate_delay_zero_attempt() {
        let strategy = RetryStrategy::exponential_backoff().with_jitter(false);
        assert_eq!(strategy.calculate_delay(0), Duration::from_secs(0));
    }

    #[test]
    fn test_calculate_delay_with_max() {
        let strategy = RetryStrategy::exponential_backoff()
            .with_initial_delay(Duration::from_secs(1))
            .with_max_delay(Duration::from_secs(5))
            .with_multiplier(2.0)
            .with_jitter(false);
        let delay5 = strategy.calculate_delay(5);
        assert_eq!(delay5, Duration::from_secs(5));
    }

    #[test]
    fn test_calculate_delay_with_jitter_is_bounded() {
        // Jitter is on by default for this preset; whatever factor is drawn,
        // the result must never exceed the configured cap.
        let strategy = RetryStrategy::exponential_backoff()
            .with_initial_delay(Duration::from_secs(1))
            .with_max_delay(Duration::from_secs(5))
            .with_multiplier(2.0);
        for attempt in 1..=10 {
            assert!(strategy.calculate_delay(attempt) <= Duration::from_secs(5));
        }
    }

    #[test]
    fn test_should_retry() {
        let strategy = RetryStrategy::exponential_backoff().with_max_retries(3);
        assert!(strategy.should_retry(0));
        assert!(strategy.should_retry(1));
        assert!(strategy.should_retry(2));
        assert!(!strategy.should_retry(3));
        assert!(!strategy.should_retry(4));
    }

    #[test]
    fn test_retry_state() {
        let strategy = RetryStrategy::exponential_backoff().with_max_retries(2);
        let mut state = RetryState::new(strategy);
        assert_eq!(state.attempts(), 0);
        assert!(state.can_retry());
        state.record_attempt();
        assert_eq!(state.attempts(), 1);
        assert!(state.can_retry());
        state.record_attempt();
        assert_eq!(state.attempts(), 2);
        assert!(!state.can_retry());
    }

    #[test]
    fn test_retry_state_reset() {
        let strategy = RetryStrategy::exponential_backoff();
        let mut state = RetryState::new(strategy);
        state.record_attempt();
        state.record_attempt();
        assert_eq!(state.attempts(), 2);
        state.reset();
        assert_eq!(state.attempts(), 0);
    }

    #[test]
    fn test_next_delay() {
        let strategy = RetryStrategy::exponential_backoff()
            .with_initial_delay(Duration::from_secs(1))
            .with_jitter(false);
        let mut state = RetryState::new(strategy);
        state.record_attempt();
        let delay = state.next_delay();
        // One recorded attempt means the next delay is for attempt 2:
        // 1s * 2.0^(2 - 1) = 2s. Pin the exact value rather than a lower bound.
        assert_eq!(delay, Duration::from_secs(2));
    }
}