use std::time::Duration;
/// How (and whether) to wait between retries of a failed operation.
///
/// `attempt` indices passed to [`RetryStrategy::delay_for_attempt`] are
/// zero-based; every variant other than [`RetryStrategy::None`] permits a
/// retry only while `attempt < max_attempts`.
#[derive(Debug, Clone)]
pub enum RetryStrategy {
    /// Never retry.
    None,
    /// Wait the same `delay` before every permitted retry.
    Fixed {
        /// Delay returned for every permitted attempt.
        delay: Duration,
        /// Exclusive upper bound on the zero-based attempt index.
        max_attempts: u32,
    },
    /// Wait `initial_delay * multiplier^attempt`, capped at `max_delay`.
    Exponential {
        /// Delay for attempt 0.
        initial_delay: Duration,
        /// Upper bound applied to every computed delay.
        max_delay: Duration,
        /// Growth factor applied once per attempt.
        multiplier: f64,
        /// Exclusive upper bound on the zero-based attempt index.
        max_attempts: u32,
    },
    /// Wait `initial_delay + increment * attempt`, capped at `max_delay`.
    Linear {
        /// Delay for attempt 0.
        initial_delay: Duration,
        /// Amount added per attempt.
        increment: Duration,
        /// Upper bound applied to every computed delay.
        max_delay: Duration,
        /// Exclusive upper bound on the zero-based attempt index.
        max_attempts: u32,
    },
}

impl Default for RetryStrategy {
    /// Exponential back-off starting at 100 ms, doubling per attempt,
    /// capped at 30 s, with `max_attempts` of 5.
    fn default() -> Self {
        Self::Exponential {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            multiplier: 2.0,
            max_attempts: 5,
        }
    }
}

impl RetryStrategy {
    /// Strategy that never retries.
    #[must_use]
    pub const fn none() -> Self {
        Self::None
    }

    /// Constant-delay strategy allowing attempts `0..max_attempts`.
    #[must_use]
    pub const fn fixed(delay: Duration, max_attempts: u32) -> Self {
        Self::Fixed {
            delay,
            max_attempts,
        }
    }

    /// Exponential strategy with multiplier 2.0 and a 60 s cap.
    #[must_use]
    pub const fn exponential(initial_delay: Duration, max_attempts: u32) -> Self {
        Self::Exponential {
            initial_delay,
            max_delay: Duration::from_secs(60),
            multiplier: 2.0,
            max_attempts,
        }
    }

    /// Returns the delay to wait before the zero-based `attempt`, or `None`
    /// when the strategy does not permit it.
    ///
    /// Computed delays are capped at `max_delay`. This never panics:
    /// arithmetic that would overflow `Duration` (or yield an infinite `f64`)
    /// saturates to `max_delay`, and non-positive or NaN exponential results
    /// (negative multiplier, `0 * inf`) become `Duration::ZERO`.
    #[must_use]
    pub fn delay_for_attempt(&self, attempt: u32) -> Option<Duration> {
        match self {
            Self::None => None,
            Self::Fixed {
                delay,
                max_attempts,
            } => (attempt < *max_attempts).then_some(*delay),
            Self::Exponential {
                initial_delay,
                max_delay,
                multiplier,
                max_attempts,
            } => {
                if attempt >= *max_attempts {
                    return None;
                }
                // `powi` takes an `i32`; clamp first so a huge `attempt` cannot wrap
                // to a negative exponent (which would yield a tiny delay). The cast
                // after `min` is lossless.
                let exponent = attempt.min(i32::MAX as u32) as i32;
                let secs = initial_delay.as_secs_f64() * multiplier.powi(exponent);
                // `Duration::from_secs_f64` panics on negative, NaN, infinite or
                // out-of-range input. `secs > 0.0` is false for NaN and -0.0 too.
                let delay = if secs > 0.0 {
                    Duration::try_from_secs_f64(secs).unwrap_or(*max_delay)
                } else {
                    Duration::ZERO
                };
                Some(delay.min(*max_delay))
            }
            Self::Linear {
                initial_delay,
                increment,
                max_delay,
                max_attempts,
            } => (attempt < *max_attempts).then(|| {
                // `Duration` `*` and `+` panic on overflow; saturate instead, the
                // result is capped by `max_delay` regardless.
                initial_delay
                    .saturating_add(increment.saturating_mul(attempt))
                    .min(*max_delay)
            }),
        }
    }

    /// `true` when [`RetryStrategy::delay_for_attempt`] returns `Some` for `attempt`.
    #[must_use]
    pub fn should_retry(&self, attempt: u32) -> bool {
        self.delay_for_attempt(attempt).is_some()
    }

    /// The configured `max_attempts` bound.
    ///
    /// NOTE(review): `None` reports 1 here even though it never permits a
    /// retry via `delay_for_attempt`; callers presumably read this as a total
    /// attempt count for `None` — confirm against call sites.
    #[must_use]
    pub const fn max_attempts(&self) -> u32 {
        match self {
            Self::None => 1,
            Self::Fixed { max_attempts, .. }
            | Self::Exponential { max_attempts, .. }
            | Self::Linear { max_attempts, .. } => *max_attempts,
        }
    }
}
/// Retry settings for connections and commands, plus classification of
/// error messages as retryable or not.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Strategy used for connection retries.
    pub connection: RetryStrategy,
    /// Strategy used for command retries.
    pub command: RetryStrategy,
    /// Whether timeouts should be retried (not consulted by this type itself).
    pub retry_on_timeout: bool,
    /// Whether disconnects should be retried (not consulted by this type itself).
    pub retry_on_disconnect: bool,
    /// Substrings marking an error message as non-retryable; matched
    /// case-insensitively by [`RetryPolicy::is_retryable`].
    pub non_retryable_errors: Vec<String>,
}

impl Default for RetryPolicy {
    /// Connections use [`RetryStrategy::default`], commands are not retried,
    /// both flags are `true`, and "authentication failed" / "permission denied"
    /// are classified as non-retryable.
    fn default() -> Self {
        Self {
            connection: RetryStrategy::default(),
            command: RetryStrategy::none(),
            retry_on_timeout: true,
            retry_on_disconnect: true,
            non_retryable_errors: vec![
                "authentication failed".to_string(),
                "permission denied".to_string(),
            ],
        }
    }
}

impl RetryPolicy {
    /// Same as [`RetryPolicy::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter replacing the connection retry strategy.
    #[must_use]
    pub const fn with_connection_retries(mut self, strategy: RetryStrategy) -> Self {
        self.connection = strategy;
        self
    }

    /// Builder-style setter replacing the command retry strategy.
    #[must_use]
    pub const fn with_command_retries(mut self, strategy: RetryStrategy) -> Self {
        self.command = strategy;
        self
    }

    /// Returns `false` if `error` contains any entry of `non_retryable_errors`,
    /// `true` otherwise.
    ///
    /// The comparison is case-insensitive on both sides: the message and each
    /// pattern are lowercased, so a pattern added as `"Quota Exceeded"` still
    /// matches. Previously only the message was lowercased, so patterns with
    /// uppercase letters could never match.
    #[must_use]
    pub fn is_retryable(&self, error: &str) -> bool {
        let error_lower = error.to_lowercase();
        !self
            .non_retryable_errors
            .iter()
            .any(|pattern| error_lower.contains(pattern.to_lowercase().as_str()))
    }
}
/// Tracks progress through a [`RetryStrategy`]: how many attempts have been
/// recorded and the sum of the delays assigned to them.
#[derive(Debug)]
pub struct RetryState {
    /// Zero-based index of the next attempt to record.
    attempt: u32,
    strategy: RetryStrategy,
    /// Sum of the delays returned by `next_delay` at each `record_attempt`.
    total_delay: Duration,
}

impl RetryState {
    /// Starts at attempt 0 with no accumulated delay.
    #[must_use]
    pub const fn new(strategy: RetryStrategy) -> Self {
        Self {
            attempt: 0,
            strategy,
            total_delay: Duration::ZERO,
        }
    }

    /// Number of attempts recorded since creation or the last reset.
    #[must_use]
    pub const fn attempt(&self) -> u32 {
        self.attempt
    }

    /// Whether the strategy permits a retry at the current attempt index.
    #[must_use]
    pub fn should_retry(&self) -> bool {
        self.strategy.should_retry(self.attempt)
    }

    /// Delay the strategy assigns to the current attempt index, if any.
    #[must_use]
    pub fn next_delay(&self) -> Option<Duration> {
        self.strategy.delay_for_attempt(self.attempt)
    }

    /// Adds the current attempt's delay (if any) to the running total, then
    /// advances the attempt counter.
    ///
    /// Both updates saturate rather than overflow: `Duration + Duration`
    /// panics on overflow (e.g. with a `Duration::MAX` fixed delay), and the
    /// `u32` counter would panic in debug builds.
    pub fn record_attempt(&mut self) {
        if let Some(delay) = self.next_delay() {
            self.total_delay = self.total_delay.saturating_add(delay);
        }
        self.attempt = self.attempt.saturating_add(1);
    }

    /// Sum of all delays accumulated by `record_attempt`.
    #[must_use]
    pub const fn total_delay(&self) -> Duration {
        self.total_delay
    }

    /// Returns to attempt 0 with no accumulated delay; the strategy is kept.
    pub const fn reset(&mut self) {
        self.attempt = 0;
        self.total_delay = Duration::ZERO;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixed strategy permits attempts strictly below its bound.
    #[test]
    fn fixed_strategy() {
        let strategy = RetryStrategy::fixed(Duration::from_millis(100), 3);
        for permitted in [0, 2] {
            assert!(strategy.should_retry(permitted));
        }
        assert!(!strategy.should_retry(3));
    }

    /// Exponential delays strictly increase across permitted attempts.
    #[test]
    fn exponential_strategy() {
        let strategy = RetryStrategy::exponential(Duration::from_millis(100), 3);
        let delays: Vec<Duration> = (0..3)
            .map(|n| strategy.delay_for_attempt(n).unwrap())
            .collect();
        for pair in delays.windows(2) {
            assert!(pair[1] > pair[0]);
        }
    }

    /// `RetryState` stops permitting retries once the bound is reached.
    #[test]
    fn retry_state() {
        let mut state = RetryState::new(RetryStrategy::fixed(Duration::from_millis(100), 2));
        for _ in 0..2 {
            assert!(state.should_retry());
            state.record_attempt();
        }
        assert!(!state.should_retry());
    }
}