use std::time::Duration;
/// Describes how long to wait between retries and when to stop retrying.
///
/// Built with one of the strategy constructors (`constant`, `linear`, `exponential`,
/// `fibonacci`) and refined with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryPolicy {
    /// How the un-jittered delay grows with the attempt number.
    strategy: RetryStrategy,
    /// Attempt numbers `>= max_retries` get no delay (`delay_for_attempt` returns `None`);
    /// `None` means no limit.
    max_retries: Option<u32>,
    /// Upper bound applied to every delay, both before and after jitter; `None` means uncapped.
    max_delay: Option<Duration>,
    /// Randomisation applied by `delay_with_jitter`.
    jitter: JitterStrategy,
}
/// How the un-jittered delay grows with the zero-based attempt number `n` passed to
/// `RetryPolicy::delay_for_attempt`.
///
/// The multiplications involved saturate at `Duration::MAX` instead of overflowing.
#[derive(Debug, Clone, PartialEq)]
pub enum RetryStrategy {
    /// The same delay for every attempt.
    Constant(Duration),
    /// `base * (n + 1)`: `base`, `2 * base`, `3 * base`, ...
    Linear {
        /// Delay for the first attempt and the per-attempt increment.
        base: Duration,
    },
    /// `base * 2^n`: `base`, `2 * base`, `4 * base`, ...
    Exponential {
        /// Delay for the first attempt.
        base: Duration,
    },
    /// `base * fib(n + 1)`: `base`, `base`, `2 * base`, `3 * base`, `5 * base`, ...
    Fibonacci {
        /// Delay for the first two attempts.
        base: Duration,
    },
}
/// Randomisation applied on top of a strategy's delay by `JitterStrategy::apply`.
///
/// Randomisation only happens when the crate is built with the `jitter` feature; without it
/// every variant leaves the delay unchanged. In all cases the result is capped at the
/// `max_delay` passed to `apply`.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum JitterStrategy {
    /// No jitter (the default).
    #[default]
    None,
    /// Uniformly random delay in `[base - factor * base, base + factor * base]`, floored at
    /// zero. `RetryPolicy::with_jitter` clamps `factor` to `[0.0, 1.0]`.
    Proportional(f64),
    /// Uniformly random delay in `[0, base]` ("full jitter").
    Full,
    /// Uniformly random delay in `[base, 3 * previous]`, where `previous` is the prior delay
    /// (or `base` when none is given); yields `base` when that range is empty.
    Decorrelated,
}
/// Details about a failed attempt: which attempt it was, the error it failed with, the delay
/// scheduled before the next attempt and the time elapsed so far.
///
/// `Clone` and `Copy` are implemented by hand below rather than derived: `#[derive(Clone)]`
/// would add an `E: Clone` bound even though the event only borrows the error, which would
/// make events for non-`Clone` errors (such as `std::io::Error`) impossible to clone.
#[derive(Debug)]
pub struct RetryEvent<'a, E> {
    /// Number of the attempt that failed (NOTE(review): zero- vs one-based is decided by the
    /// code that builds the event — confirm).
    pub attempt: u32,
    /// The error the attempt failed with.
    pub error: &'a E,
    /// Delay before the next attempt; presumably `None` when no further retry will be made —
    /// confirm against the code that constructs events.
    pub next_delay: Option<Duration>,
    /// Time elapsed so far (presumably measured from the first attempt — confirm).
    pub elapsed: Duration,
}

impl<E> Clone for RetryEvent<'_, E> {
    fn clone(&self) -> Self {
        *self
    }
}

// Every field is `Copy` for any `E` (a shared reference is always `Copy`), so the event is too.
impl<E> Copy for RetryEvent<'_, E> {}
impl RetryPolicy {
    /// Builds a policy around `strategy` with no retry limit, no delay cap and no jitter.
    ///
    /// Shared by the public strategy constructors so their defaults cannot drift apart.
    fn from_strategy(strategy: RetryStrategy) -> Self {
        Self {
            strategy,
            max_retries: None,
            max_delay: None,
            jitter: JitterStrategy::None,
        }
    }

    /// Creates a policy that waits `delay` before every retry.
    ///
    /// Like every constructor, the result has no bounds yet; add them with
    /// [`with_max_retries`](Self::with_max_retries) or [`with_max_delay`](Self::with_max_delay)
    /// (see [`validate`](Self::validate)).
    #[must_use]
    pub fn constant(delay: Duration) -> Self {
        Self::from_strategy(RetryStrategy::Constant(delay))
    }

    /// Creates a policy whose delay grows linearly: `base * (attempt + 1)`.
    #[must_use]
    pub fn linear(base: Duration) -> Self {
        Self::from_strategy(RetryStrategy::Linear { base })
    }

    /// Creates a policy whose delay doubles on each attempt: `base * 2^attempt`.
    #[must_use]
    pub fn exponential(base: Duration) -> Self {
        Self::from_strategy(RetryStrategy::Exponential { base })
    }

    /// Creates a policy whose delay follows the Fibonacci sequence: `base * fib(attempt + 1)`.
    #[must_use]
    pub fn fibonacci(base: Duration) -> Self {
        Self::from_strategy(RetryStrategy::Fibonacci { base })
    }

    /// Limits the policy to `n` retries: attempts `0..n` get a delay, later ones get `None`.
    #[must_use]
    pub fn with_max_retries(mut self, n: u32) -> Self {
        self.max_retries = Some(n);
        self
    }

    /// Caps every delay (before and after jitter) at `d`.
    #[must_use]
    pub fn with_max_delay(mut self, d: Duration) -> Self {
        self.max_delay = Some(d);
        self
    }

    /// Enables proportional jitter of up to `factor * delay` in either direction.
    ///
    /// `factor` is clamped to `[0.0, 1.0]`. A NaN factor is treated as `0.0` (no jitter)
    /// instead of being stored: `Proportional(NaN)` would compare unequal to itself and would
    /// produce a NaN upper bound for the random sample.
    #[must_use]
    pub fn with_jitter(mut self, factor: f64) -> Self {
        let factor = if factor.is_nan() {
            0.0
        } else {
            factor.clamp(0.0, 1.0)
        };
        self.jitter = JitterStrategy::Proportional(factor);
        self
    }

    /// Enables full jitter: a random delay between zero and the computed delay.
    #[must_use]
    pub fn with_full_jitter(mut self) -> Self {
        self.jitter = JitterStrategy::Full;
        self
    }

    /// Enables decorrelated jitter, which depends on the previous delay.
    #[must_use]
    pub fn with_decorrelated_jitter(mut self) -> Self {
        self.jitter = JitterStrategy::Decorrelated;
        self
    }

    /// Returns the configured retry limit, if any.
    pub fn max_retries(&self) -> Option<u32> {
        self.max_retries
    }

    /// Returns the configured per-delay cap, if any.
    pub fn max_delay(&self) -> Option<Duration> {
        self.max_delay
    }

    /// Returns the configured jitter strategy.
    pub fn jitter(&self) -> &JitterStrategy {
        &self.jitter
    }

    /// Returns the configured backoff strategy.
    pub fn strategy(&self) -> &RetryStrategy {
        &self.strategy
    }

    /// Returns the un-jittered delay before retry number `attempt` (zero-based), or `None`
    /// once `attempt` reaches `max_retries`.
    ///
    /// The delay is computed from the strategy with saturating arithmetic (it never panics or
    /// wraps, even for `attempt == u32::MAX`) and then capped at `max_delay` if one is set.
    #[must_use]
    pub fn delay_for_attempt(&self, attempt: u32) -> Option<Duration> {
        if matches!(self.max_retries, Some(max) if attempt >= max) {
            return None;
        }
        // A plain `attempt + 1` overflows at u32::MAX: a panic in debug builds, and in release
        // a wrap to 0 that would turn Linear/Fibonacci delays into zero.
        let step = attempt.saturating_add(1);
        let base_delay = match &self.strategy {
            RetryStrategy::Constant(d) => *d,
            RetryStrategy::Linear { base } => base.saturating_mul(step),
            RetryStrategy::Exponential { base } => {
                base.saturating_mul(2u32.saturating_pow(attempt))
            }
            RetryStrategy::Fibonacci { base } => base.saturating_mul(fibonacci(step)),
        };
        Some(self.max_delay.map_or(base_delay, |max| base_delay.min(max)))
    }

    /// Like [`delay_for_attempt`](Self::delay_for_attempt), but with this policy's jitter
    /// applied. `prev_delay` is forwarded to [`JitterStrategy::apply`], and the jittered result
    /// is capped at `max_delay` again.
    #[doc(hidden)]
    #[must_use]
    pub fn delay_with_jitter(
        &self,
        attempt: u32,
        prev_delay: Option<Duration>,
    ) -> Option<Duration> {
        let base_delay = self.delay_for_attempt(attempt)?;
        Some(self.jitter.apply(base_delay, prev_delay, self.max_delay))
    }

    /// Checks that the policy has at least one bound.
    ///
    /// # Errors
    ///
    /// Returns an error if neither `max_retries` nor `max_delay` is set.
    ///
    /// NOTE: `max_delay` only caps each individual delay. A policy bounded solely by it still
    /// yields a delay for every attempt number, so it never stops retrying on its own.
    pub fn validate(&self) -> Result<(), &'static str> {
        if self.max_retries.is_some() || self.max_delay.is_some() {
            Ok(())
        } else {
            Err("RetryPolicy must have at least one bound (max_retries or max_delay)")
        }
    }
}
impl JitterStrategy {
    /// Applies this jitter strategy to `base_delay` and caps the result at `max_delay`.
    ///
    /// * `base_delay` — the un-jittered delay.
    /// * `prev_delay` — the previous delay; only read by [`JitterStrategy::Decorrelated`],
    ///   which falls back to `base_delay` when it is `None`.
    /// * `max_delay` — optional upper bound applied after jitter.
    ///
    /// Randomisation only happens when the crate is built with the `jitter` feature; without
    /// it every variant returns `base_delay` (still capped at `max_delay`).
    ///
    /// Random delays are computed in whole nanoseconds using `u128` arithmetic, so
    /// sub-millisecond delays are jittered rather than collapsing to zero. Very large delays
    /// saturate at `Duration::MAX` instead of being silently truncated.
    #[must_use]
    pub fn apply(
        &self,
        base_delay: Duration,
        #[cfg_attr(not(feature = "jitter"), allow(unused_variables))] prev_delay: Option<Duration>,
        max_delay: Option<Duration>,
    ) -> Duration {
        let jittered = match self {
            JitterStrategy::None => base_delay,
            #[cfg(feature = "jitter")]
            JitterStrategy::Proportional(factor) => {
                use rand::Rng;
                // The variant is public, so a negative, NaN or infinite factor can reach here
                // without going through `RetryPolicy::with_jitter`. Apply the same sanitising
                // so the sampling range is never empty or non-finite.
                let factor = *factor;
                let factor = if factor.is_nan() {
                    0.0
                } else {
                    factor.clamp(0.0, 1.0)
                };
                let base = base_delay.as_nanos() as f64;
                let spread = base * factor;
                let low = (base - spread).max(0.0);
                let high = base + spread;
                let mut rng = rand::rng();
                let sampled = rng.random_range(low..=high);
                // `f64 as u128` saturates, and the helper saturates at Duration::MAX.
                Self::duration_from_nanos(sampled as u128)
            }
            #[cfg(not(feature = "jitter"))]
            JitterStrategy::Proportional(_) => base_delay,
            #[cfg(feature = "jitter")]
            JitterStrategy::Full => {
                use rand::Rng;
                let max_nanos = base_delay.as_nanos();
                if max_nanos == 0 {
                    Duration::ZERO
                } else {
                    let mut rng = rand::rng();
                    Self::duration_from_nanos(rng.random_range(0..=max_nanos))
                }
            }
            #[cfg(not(feature = "jitter"))]
            JitterStrategy::Full => base_delay,
            #[cfg(feature = "jitter")]
            JitterStrategy::Decorrelated => {
                use rand::Rng;
                let low = base_delay.as_nanos();
                // Stay in u128: narrowing to u64 here could wrap for huge previous delays.
                let high = prev_delay.unwrap_or(base_delay).as_nanos().saturating_mul(3);
                if high <= low {
                    base_delay
                } else {
                    let mut rng = rand::rng();
                    Self::duration_from_nanos(rng.random_range(low..=high))
                }
            }
            #[cfg(not(feature = "jitter"))]
            JitterStrategy::Decorrelated => base_delay,
        };
        max_delay.map_or(jittered, |max| jittered.min(max))
    }

    /// Converts a nanosecond count into a `Duration`, saturating at `Duration::MAX`.
    #[cfg(feature = "jitter")]
    fn duration_from_nanos(nanos: u128) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;
        let secs = nanos / NANOS_PER_SEC;
        // Always < 1e9, so it fits in u32 and `Duration::new` cannot carry or panic.
        let subsec = (nanos % NANOS_PER_SEC) as u32;
        if secs > u128::from(u64::MAX) {
            Duration::MAX
        } else {
            Duration::new(secs as u64, subsec)
        }
    }
}
/// Returns the `n`-th Fibonacci number (`fib(0) = 0`, `fib(1) = fib(2) = 1`), saturating at
/// `u32::MAX`.
///
/// `fib(47)` is the largest Fibonacci number that fits in a `u32`; every larger `n` yields
/// `u32::MAX`. The loop stops as soon as the value saturates, so a huge `n` returns
/// immediately instead of running billions of no-op iterations.
fn fibonacci(n: u32) -> u32 {
    if n == 0 {
        return 0;
    }
    let (mut prev, mut curr) = (0u32, 1u32);
    for _ in 1..n {
        // Once saturated, every further step would yield u32::MAX again.
        if curr == u32::MAX {
            break;
        }
        let next = prev.saturating_add(curr);
        prev = curr;
        curr = next;
    }
    curr
}
#[cfg(test)]
mod policy_tests {
    use super::*;

    /// Shorthand for a millisecond `Duration`.
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    /// Asserts that `policy` yields `expected[i]` (in milliseconds, `None` = stop) for
    /// attempt `i`.
    fn assert_delays(policy: &RetryPolicy, expected: &[Option<u64>]) {
        for (attempt, want) in (0u32..).zip(expected.iter().copied()) {
            assert_eq!(
                policy.delay_for_attempt(attempt),
                want.map(ms),
                "unexpected delay for attempt {attempt}"
            );
        }
    }

    #[test]
    fn test_constant_delay() {
        let policy = RetryPolicy::constant(ms(100)).with_max_retries(3);
        assert_delays(&policy, &[Some(100), Some(100), Some(100), None]);
    }

    #[test]
    fn test_linear_delay() {
        let policy = RetryPolicy::linear(ms(100)).with_max_retries(5);
        assert_delays(&policy, &[Some(100), Some(200), Some(300), Some(400)]);
    }

    #[test]
    fn test_exponential_delay() {
        let policy = RetryPolicy::exponential(ms(100)).with_max_retries(5);
        assert_delays(&policy, &[Some(100), Some(200), Some(400), Some(800)]);
    }

    #[test]
    fn test_fibonacci_delay() {
        let policy = RetryPolicy::fibonacci(ms(100)).with_max_retries(6);
        assert_delays(
            &policy,
            &[Some(100), Some(100), Some(200), Some(300), Some(500), Some(800)],
        );
    }

    #[test]
    fn test_max_delay_cap() {
        let policy = RetryPolicy::exponential(ms(100))
            .with_max_retries(10)
            .with_max_delay(ms(500));
        assert_delays(
            &policy,
            &[Some(100), Some(200), Some(400), Some(500), Some(500)],
        );
    }

    #[test]
    fn test_exponential_saturates_then_caps() {
        // 2^40 saturates the u32 multiplier; the cap must still apply afterwards.
        let policy =
            RetryPolicy::exponential(ms(100)).with_max_delay(Duration::from_secs(60));
        assert_eq!(policy.delay_for_attempt(40), Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_max_retries_limit() {
        let policy = RetryPolicy::constant(ms(100)).with_max_retries(2);
        assert!(policy.delay_for_attempt(0).is_some());
        assert!(policy.delay_for_attempt(1).is_some());
        assert!(policy.delay_for_attempt(2).is_none());
    }

    #[test]
    fn test_delay_with_jitter_without_jitter() {
        let policy = RetryPolicy::linear(ms(100))
            .with_max_retries(3)
            .with_max_delay(ms(250));
        assert_eq!(policy.delay_with_jitter(0, None), Some(ms(100)));
        assert_eq!(policy.delay_with_jitter(2, Some(ms(200))), Some(ms(250)));
        assert_eq!(policy.delay_with_jitter(3, None), None);
    }

    #[test]
    fn test_policy_is_clone() {
        let policy = RetryPolicy::exponential(ms(100)).with_max_retries(3);
        let cloned = policy.clone();
        assert_eq!(policy, cloned);
    }

    #[test]
    fn test_policy_is_debug() {
        let policy = RetryPolicy::exponential(ms(100)).with_max_retries(3);
        let debug = format!("{:?}", policy);
        assert!(debug.contains("RetryPolicy"));
    }

    #[test]
    fn test_fibonacci_function() {
        let expected = [0, 1, 1, 2, 3, 5, 8, 13];
        for (n, want) in (0u32..).zip(expected) {
            assert_eq!(fibonacci(n), want, "fibonacci({n})");
        }
    }

    #[test]
    fn test_fibonacci_saturates() {
        assert_eq!(fibonacci(47), 2_971_215_073);
        assert_eq!(fibonacci(48), u32::MAX);
        assert_eq!(fibonacci(60), u32::MAX);
    }

    #[test]
    fn test_validate_with_max_retries() {
        let policy = RetryPolicy::constant(ms(100)).with_max_retries(3);
        assert!(policy.validate().is_ok());
    }

    #[test]
    fn test_validate_with_max_delay() {
        let policy = RetryPolicy::constant(ms(100)).with_max_delay(Duration::from_secs(5));
        assert!(policy.validate().is_ok());
    }

    #[test]
    fn test_validate_with_both_bounds() {
        let policy = RetryPolicy::constant(ms(100))
            .with_max_retries(3)
            .with_max_delay(Duration::from_secs(5));
        assert!(policy.validate().is_ok());
    }

    #[test]
    fn test_validate_no_bounds() {
        let policy = RetryPolicy::constant(ms(100));
        assert!(policy.validate().is_err());
    }

    #[test]
    fn test_jitter_strategy_default() {
        let jitter = JitterStrategy::default();
        assert_eq!(jitter, JitterStrategy::None);
    }

    #[test]
    fn test_jitter_none_returns_base_delay() {
        let base = ms(100);
        assert_eq!(JitterStrategy::None.apply(base, None, None), base);
    }

    #[test]
    fn test_jitter_none_respects_max_delay() {
        assert_eq!(JitterStrategy::None.apply(ms(500), None, Some(ms(200))), ms(200));
    }

    #[test]
    fn test_with_jitter_clamps_factor() {
        let policy = RetryPolicy::constant(ms(100));
        assert_eq!(
            policy.clone().with_jitter(2.0).jitter(),
            &JitterStrategy::Proportional(1.0)
        );
        assert_eq!(
            policy.with_jitter(-1.0).jitter(),
            &JitterStrategy::Proportional(0.0)
        );
    }

    #[cfg(feature = "jitter")]
    #[test]
    fn test_random_jitter_stays_in_bounds() {
        for _ in 0..100 {
            let full = JitterStrategy::Full.apply(ms(100), None, None);
            assert!(full <= ms(100));
            let proportional = JitterStrategy::Proportional(0.25).apply(ms(100), None, None);
            assert!((ms(75)..=ms(125)).contains(&proportional));
            let decorrelated = JitterStrategy::Decorrelated.apply(ms(100), Some(ms(100)), None);
            assert!((ms(100)..=ms(300)).contains(&decorrelated));
            let capped = JitterStrategy::Full.apply(ms(100), None, Some(ms(10)));
            assert!(capped <= ms(10));
        }
    }

    #[test]
    fn test_policy_getters() {
        let policy = RetryPolicy::exponential(ms(100))
            .with_max_retries(3)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(0.25);
        assert_eq!(policy.max_retries(), Some(3));
        assert_eq!(policy.max_delay(), Some(Duration::from_secs(5)));
        assert!(matches!(policy.jitter(), JitterStrategy::Proportional(_)));
        assert!(matches!(
            policy.strategy(),
            RetryStrategy::Exponential { .. }
        ));
    }
}