use rand::Rng;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Random perturbation applied to a computed retry delay.
#[derive(Debug, Clone, PartialEq)]
pub enum JitterType {
    /// Add or subtract (with equal probability) a uniformly random amount in
    /// `[0, amount]`, drawn at whole-millisecond granularity.
    Additive(Duration),
    /// Scale the delay by a uniformly random factor in `[1 - |f|, 1 + |f|]`.
    Multiplicative(f64),
}

impl JitterType {
    /// Returns `delay` perturbed by this jitter.
    ///
    /// * `Additive(amount)`: draws `j` uniformly from `0..=amount` (whole
    ///   milliseconds) and returns either `delay + j` (saturating at
    ///   `Duration::MAX`) or `delay - j` (saturating at zero), each with
    ///   probability 1/2.
    /// * `Multiplicative(f)`: multiplies the delay's whole milliseconds by a
    ///   factor drawn uniformly from `[1 - |f|, 1 + |f|]`. If `|f| > 1`, a
    ///   negative product yields zero. A non-finite `f` (NaN/inf) is treated as
    ///   `0.0`, i.e. no jitter. In that case the delay is still truncated to
    ///   whole milliseconds.
    ///
    /// Never panics for any jitter configuration, so a bad deserialized config
    /// cannot crash the retry loop.
    pub fn apply(&self, delay: Duration) -> Duration {
        let mut rng = rand::thread_rng();
        match self {
            JitterType::Additive(jitter_amount) => {
                // Clamp instead of a truncating `as` cast for absurdly large amounts.
                let max_millis = jitter_amount.as_millis().min(u64::MAX as u128) as u64;
                let jitter = Duration::from_millis(rng.gen_range(0..=max_millis));
                if rng.gen_bool(0.5) {
                    delay.saturating_add(jitter)
                } else {
                    delay.saturating_sub(jitter)
                }
            }
            JitterType::Multiplicative(factor) => {
                // `gen_range` panics on an empty (low > high) or non-finite range,
                // which a negative or NaN/inf factor would produce. Use the
                // magnitude, and treat non-finite factors as "no jitter".
                let spread = if factor.is_finite() { factor.abs() } else { 0.0 };
                let jitter_factor = rng.gen_range((1.0 - spread)..=(1.0 + spread));
                // Float-to-int `as` saturates: negative products become 0.
                let jittered_millis = (delay.as_millis() as f64 * jitter_factor) as u64;
                Duration::from_millis(jittered_millis)
            }
        }
    }
}
/// Writes a [`JitterType`] as a two-field struct named `"JitterType"`:
/// `{"type": "Additive", "duration_ms": <u64>}` or
/// `{"type": "Multiplicative", "factor": <f64>}`.
///
/// Additive durations are written in whole milliseconds; any sub-millisecond
/// remainder is dropped.
impl Serialize for JitterType {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        // Both variants have exactly two fields, so the header is shared.
        let mut out = serializer.serialize_struct("JitterType", 2)?;
        match *self {
            Self::Additive(amount) => {
                out.serialize_field("type", "Additive")?;
                out.serialize_field("duration_ms", &(amount.as_millis() as u64))?;
            }
            Self::Multiplicative(factor) => {
                out.serialize_field("type", "Multiplicative")?;
                out.serialize_field("factor", &factor)?;
            }
        }
        out.end()
    }
}
/// Reads the struct form written by `Serialize for JitterType`.
///
/// Expects a map with a `"type"` key selecting the variant:
/// * `"Additive"` requires `"duration_ms"` (`u64`, milliseconds);
/// * `"Multiplicative"` requires `"factor"` (`f64`).
///
/// Unknown keys are skipped, duplicate known keys are rejected, and any other
/// `"type"` value is reported as an unknown variant. Only map input is handled
/// (there is no `visit_seq`), so this works with self-describing formats such
/// as JSON.
impl<'de> Deserialize<'de> for JitterType {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct JitterTypeVisitor;

        impl<'de> Visitor<'de> for JitterTypeVisitor {
            type Value = JitterType;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a jitter type")
            }

            fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
            where
                M: MapAccess<'de>,
            {
                let mut jitter_type: Option<String> = None;
                let mut duration_ms: Option<u64> = None;
                let mut factor: Option<f64> = None;

                // Collect every known key first; the variant is decided afterwards
                // because `"type"` may appear after the payload field.
                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "type" => {
                            if jitter_type.is_some() {
                                return Err(de::Error::duplicate_field("type"));
                            }
                            jitter_type = Some(map.next_value()?);
                        }
                        "duration_ms" => {
                            if duration_ms.is_some() {
                                return Err(de::Error::duplicate_field("duration_ms"));
                            }
                            duration_ms = Some(map.next_value()?);
                        }
                        "factor" => {
                            if factor.is_some() {
                                return Err(de::Error::duplicate_field("factor"));
                            }
                            factor = Some(map.next_value()?);
                        }
                        _ => {
                            // Consume and discard the value of an unrecognised key.
                            let _: serde::de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                let jitter_type = jitter_type.ok_or_else(|| de::Error::missing_field("type"))?;
                match jitter_type.as_str() {
                    "Additive" => {
                        let duration_ms =
                            duration_ms.ok_or_else(|| de::Error::missing_field("duration_ms"))?;
                        Ok(JitterType::Additive(Duration::from_millis(duration_ms)))
                    }
                    "Multiplicative" => {
                        let factor = factor.ok_or_else(|| de::Error::missing_field("factor"))?;
                        Ok(JitterType::Multiplicative(factor))
                    }
                    _ => Err(de::Error::unknown_variant(
                        &jitter_type,
                        &["Additive", "Multiplicative"],
                    )),
                }
            }
        }

        deserializer.deserialize_struct(
            "JitterType",
            &["type", "duration_ms", "factor"],
            JitterTypeVisitor,
        )
    }
}
/// A policy for computing the wait between retry attempts.
///
/// The delay for a given attempt number is produced by
/// `RetryStrategy::calculate_delay`. `Debug`, `Clone`, `PartialEq`,
/// `Serialize` and `Deserialize` are implemented by hand because the boxed
/// closure in `Custom` implements none of those traits. Consequently, cloning a
/// `Custom` strategy panics, serializing one returns an error, and
/// deserialization never produces one.
pub enum RetryStrategy {
/// The same delay for every attempt.
Fixed(Duration),
/// Delay grows by a constant step: `base + increment * attempt`.
Linear {
base: Duration,
increment: Duration,
/// Upper bound on the computed delay; `None` means uncapped.
max_delay: Option<Duration>,
},
/// Delay grows geometrically: `base * multiplier^(attempt - 1)`.
Exponential {
base: Duration,
multiplier: f64,
/// Upper bound applied before jitter; `None` means uncapped.
max_delay: Option<Duration>,
/// Optional random perturbation applied after the cap.
jitter: Option<JitterType>,
},
/// Delay follows the Fibonacci sequence: `base * fibonacci(attempt)`.
Fibonacci {
base: Duration,
/// Upper bound on the computed delay; `None` means uncapped.
max_delay: Option<Duration>,
},
/// Caller-supplied function mapping an attempt number to a delay.
Custom(Box<dyn Fn(u32) -> Duration + Send + Sync>),
}
impl RetryStrategy {
    /// Returns the delay to wait before retry number `attempt`.
    ///
    /// Per strategy:
    /// * `Fixed(d)`: always `d`.
    /// * `Linear`: `base + increment * attempt`, capped at `max_delay` if set.
    /// * `Exponential`: `base * multiplier^(attempt - 1)` (attempts 0 and 1 both
    ///   give `base`), capped at `max_delay` if set, then perturbed by `jitter`
    ///   if set. Jitter is applied after the cap, so a jittered delay may exceed
    ///   `max_delay`.
    /// * `Fibonacci`: `base * fibonacci(attempt)`, capped at `max_delay` if set.
    /// * `Custom(f)`: `f(attempt)`.
    ///
    /// Every result, including jittered exponential delays, is floored at 1 ms.
    /// Intermediate arithmetic saturates at `Duration::MAX` instead of
    /// panicking, so very large `attempt` values are safe. A negative or NaN
    /// scaled delay (e.g. from a negative `multiplier`) is treated as zero before
    /// the floor.
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        let delay = match self {
            RetryStrategy::Fixed(delay) => *delay,
            RetryStrategy::Linear {
                base,
                increment,
                max_delay,
            } => {
                // Exact integer arithmetic; saturates rather than panicking on overflow.
                let delay = base.saturating_add(increment.saturating_mul(attempt));
                Self::cap(delay, *max_delay)
            }
            RetryStrategy::Exponential {
                base,
                multiplier,
                max_delay,
                jitter,
            } => {
                // Clamp so the exponent cannot wrap negative for attempt > i32::MAX.
                let exponent = attempt.saturating_sub(1).min(i32::MAX as u32) as i32;
                let delay = Self::scale_saturating(*base, multiplier.powi(exponent));
                let capped = Self::cap(delay, *max_delay);
                // No early return: the jittered value must still pass the 1 ms floor.
                match jitter {
                    Some(jitter_type) => jitter_type.apply(capped),
                    None => capped,
                }
            }
            RetryStrategy::Fibonacci { base, max_delay } => {
                let delay = Self::scale_saturating(*base, fibonacci(attempt) as f64);
                Self::cap(delay, *max_delay)
            }
            RetryStrategy::Custom(func) => func(attempt),
        };
        delay.max(Duration::from_millis(1))
    }

    /// Limits `delay` to `max_delay` when one is configured.
    fn cap(delay: Duration, max_delay: Option<Duration>) -> Duration {
        match max_delay {
            Some(max) => delay.min(max),
            None => delay,
        }
    }

    /// Computes `base * factor` like `Duration::mul_f64`, but never panics.
    ///
    /// Overflow (including `+inf`) saturates to `Duration::MAX`. A negative or
    /// NaN product becomes `Duration::ZERO`.
    fn scale_saturating(base: Duration, factor: f64) -> Duration {
        let secs = base.as_secs_f64() * factor;
        if secs.is_nan() || secs <= 0.0 {
            Duration::ZERO
        } else if secs >= u64::MAX as f64 {
            Duration::MAX
        } else {
            Duration::from_secs_f64(secs)
        }
    }

    /// Creates a strategy that waits `delay` before every attempt.
    pub fn fixed(delay: Duration) -> Self {
        RetryStrategy::Fixed(delay)
    }

    /// Creates a strategy waiting `base + increment * attempt`, optionally capped.
    pub fn linear(base: Duration, increment: Duration, max_delay: Option<Duration>) -> Self {
        RetryStrategy::Linear {
            base,
            increment,
            max_delay,
        }
    }

    /// Creates an un-jittered exponential strategy: `base * multiplier^(attempt - 1)`.
    pub fn exponential(base: Duration, multiplier: f64, max_delay: Option<Duration>) -> Self {
        RetryStrategy::Exponential {
            base,
            multiplier,
            max_delay,
            jitter: None,
        }
    }

    /// Creates an exponential strategy whose capped delay is perturbed by `jitter`.
    pub fn exponential_with_jitter(
        base: Duration,
        multiplier: f64,
        max_delay: Option<Duration>,
        jitter: JitterType,
    ) -> Self {
        RetryStrategy::Exponential {
            base,
            multiplier,
            max_delay,
            jitter: Some(jitter),
        }
    }

    /// Creates a strategy waiting `base * fibonacci(attempt)`, optionally capped.
    pub fn fibonacci(base: Duration, max_delay: Option<Duration>) -> Self {
        RetryStrategy::Fibonacci { base, max_delay }
    }

    /// Wraps an arbitrary attempt-to-delay function.
    ///
    /// Note that the resulting strategy cannot be cloned (panics), serialized
    /// (errors), or compared equal to anything.
    pub fn custom<F>(func: F) -> Self
    where
        F: Fn(u32) -> Duration + Send + Sync + 'static,
    {
        RetryStrategy::Custom(Box::new(func))
    }
}
/// Returns the `n`-th Fibonacci number with `fibonacci(0) == 0` and
/// `fibonacci(1) == fibonacci(2) == 1`.
///
/// Values that would overflow `u64` (from `n = 94` onward) saturate at
/// `u64::MAX`.
pub fn fibonacci(n: u32) -> u64 {
    match n {
        0 => 0,
        // Starting from the pair (F1, F2) = (1, 1), each step advances the pair
        // by one index; `2..n` runs zero times for n = 1 and n = 2.
        _ => (2..n)
            .fold((1u64, 1u64), |(older, newer), _| {
                (newer, older.saturating_add(newer))
            })
            .1,
    }
}
impl std::fmt::Debug for RetryStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RetryStrategy::Fixed(duration) => f.debug_tuple("Fixed").field(duration).finish(),
RetryStrategy::Linear {
base,
increment,
max_delay,
} => f
.debug_struct("Linear")
.field("base", base)
.field("increment", increment)
.field("max_delay", max_delay)
.finish(),
RetryStrategy::Exponential {
base,
multiplier,
max_delay,
jitter,
} => f
.debug_struct("Exponential")
.field("base", base)
.field("multiplier", multiplier)
.field("max_delay", max_delay)
.field("jitter", jitter)
.finish(),
RetryStrategy::Fibonacci { base, max_delay } => f
.debug_struct("Fibonacci")
.field("base", base)
.field("max_delay", max_delay)
.finish(),
RetryStrategy::Custom(_) => f.write_str("Custom(<function>)"),
}
}
}
impl Clone for RetryStrategy {
fn clone(&self) -> Self {
match self {
RetryStrategy::Fixed(duration) => RetryStrategy::Fixed(*duration),
RetryStrategy::Linear {
base,
increment,
max_delay,
} => RetryStrategy::Linear {
base: *base,
increment: *increment,
max_delay: *max_delay,
},
RetryStrategy::Exponential {
base,
multiplier,
max_delay,
jitter,
} => RetryStrategy::Exponential {
base: *base,
multiplier: *multiplier,
max_delay: *max_delay,
jitter: jitter.clone(),
},
RetryStrategy::Fibonacci { base, max_delay } => RetryStrategy::Fibonacci {
base: *base,
max_delay: *max_delay,
},
RetryStrategy::Custom(_) => {
panic!("Cannot clone custom retry strategy functions")
}
}
}
}
impl PartialEq for RetryStrategy {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(RetryStrategy::Fixed(a), RetryStrategy::Fixed(b)) => a == b,
(
RetryStrategy::Linear {
base: a_base,
increment: a_inc,
max_delay: a_max,
},
RetryStrategy::Linear {
base: b_base,
increment: b_inc,
max_delay: b_max,
},
) => a_base == b_base && a_inc == b_inc && a_max == b_max,
(
RetryStrategy::Exponential {
base: a_base,
multiplier: a_mult,
max_delay: a_max,
jitter: a_jitter,
},
RetryStrategy::Exponential {
base: b_base,
multiplier: b_mult,
max_delay: b_max,
jitter: b_jitter,
},
) => a_base == b_base && a_mult == b_mult && a_max == b_max && a_jitter == b_jitter,
(
RetryStrategy::Fibonacci {
base: a_base,
max_delay: a_max,
},
RetryStrategy::Fibonacci {
base: b_base,
max_delay: b_max,
},
) => a_base == b_base && a_max == b_max,
(RetryStrategy::Custom(_), RetryStrategy::Custom(_)) => false, _ => false,
}
}
}
impl Serialize for RetryStrategy {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
match self {
RetryStrategy::Fixed(duration) => {
let mut state = serializer.serialize_struct("RetryStrategy", 2)?;
state.serialize_field("type", "Fixed")?;
state.serialize_field("duration_ms", &(duration.as_millis() as u64))?;
state.end()
}
RetryStrategy::Linear {
base,
increment,
max_delay,
} => {
let mut state = serializer.serialize_struct("RetryStrategy", 4)?;
state.serialize_field("type", "Linear")?;
state.serialize_field("base_ms", &(base.as_millis() as u64))?;
state.serialize_field("increment_ms", &(increment.as_millis() as u64))?;
state.serialize_field("max_delay_ms", &max_delay.map(|d| d.as_millis() as u64))?;
state.end()
}
RetryStrategy::Exponential {
base,
multiplier,
max_delay,
jitter,
} => {
let mut state = serializer.serialize_struct("RetryStrategy", 5)?;
state.serialize_field("type", "Exponential")?;
state.serialize_field("base_ms", &(base.as_millis() as u64))?;
state.serialize_field("multiplier", multiplier)?;
state.serialize_field("max_delay_ms", &max_delay.map(|d| d.as_millis() as u64))?;
state.serialize_field("jitter", jitter)?;
state.end()
}
RetryStrategy::Fibonacci { base, max_delay } => {
let mut state = serializer.serialize_struct("RetryStrategy", 3)?;
state.serialize_field("type", "Fibonacci")?;
state.serialize_field("base_ms", &(base.as_millis() as u64))?;
state.serialize_field("max_delay_ms", &max_delay.map(|d| d.as_millis() as u64))?;
state.end()
}
RetryStrategy::Custom(_) => Err(serde::ser::Error::custom(
"Cannot serialize custom retry strategy functions",
)),
}
}
}
/// Reads the struct form written by `Serialize for RetryStrategy`.
///
/// Expects a map whose `"type"` key selects the variant:
///
/// | `type`        | required keys              | optional keys            |
/// |---------------|----------------------------|--------------------------|
/// | `Fixed`       | `duration_ms`              |                          |
/// | `Linear`      | `base_ms`, `increment_ms`  | `max_delay_ms`           |
/// | `Exponential` | `base_ms`, `multiplier`    | `max_delay_ms`, `jitter` |
/// | `Fibonacci`   | `base_ms`                  | `max_delay_ms`           |
///
/// Durations are whole milliseconds (`u64`). An absent or `null`
/// `max_delay_ms` means "uncapped" and an absent or `null` `jitter` means "no
/// jitter", matching serde's usual treatment of `Option` fields. Unknown keys
/// are skipped, and duplicate known keys are rejected. `Custom` strategies
/// cannot be deserialized. Only map input is handled (there is no `visit_seq`).
impl<'de> Deserialize<'de> for RetryStrategy {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct RetryStrategyVisitor;

        impl<'de> Visitor<'de> for RetryStrategyVisitor {
            type Value = RetryStrategy;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a retry strategy")
            }

            fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
            where
                M: MapAccess<'de>,
            {
                let mut strategy_type: Option<String> = None;
                let mut duration_ms: Option<u64> = None;
                let mut base_ms: Option<u64> = None;
                let mut increment_ms: Option<u64> = None;
                // Outer Option: key present? Inner Option: value non-null?
                let mut max_delay_ms: Option<Option<u64>> = None;
                let mut multiplier: Option<f64> = None;
                let mut jitter: Option<Option<JitterType>> = None;

                // Collect every known key first; the variant is decided afterwards
                // because `"type"` may appear anywhere in the map.
                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "type" => {
                            if strategy_type.is_some() {
                                return Err(de::Error::duplicate_field("type"));
                            }
                            strategy_type = Some(map.next_value()?);
                        }
                        "duration_ms" => {
                            if duration_ms.is_some() {
                                return Err(de::Error::duplicate_field("duration_ms"));
                            }
                            duration_ms = Some(map.next_value()?);
                        }
                        "base_ms" => {
                            if base_ms.is_some() {
                                return Err(de::Error::duplicate_field("base_ms"));
                            }
                            base_ms = Some(map.next_value()?);
                        }
                        "increment_ms" => {
                            if increment_ms.is_some() {
                                return Err(de::Error::duplicate_field("increment_ms"));
                            }
                            increment_ms = Some(map.next_value()?);
                        }
                        "max_delay_ms" => {
                            if max_delay_ms.is_some() {
                                return Err(de::Error::duplicate_field("max_delay_ms"));
                            }
                            max_delay_ms = Some(map.next_value()?);
                        }
                        "multiplier" => {
                            if multiplier.is_some() {
                                return Err(de::Error::duplicate_field("multiplier"));
                            }
                            multiplier = Some(map.next_value()?);
                        }
                        "jitter" => {
                            if jitter.is_some() {
                                return Err(de::Error::duplicate_field("jitter"));
                            }
                            jitter = Some(map.next_value()?);
                        }
                        _ => {
                            // Consume and discard the value of an unrecognised key.
                            let _: serde::de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                // A missing `max_delay_ms` key behaves like an explicit `null`.
                let max_delay = max_delay_ms.flatten().map(Duration::from_millis);

                let strategy_type =
                    strategy_type.ok_or_else(|| de::Error::missing_field("type"))?;
                match strategy_type.as_str() {
                    "Fixed" => {
                        let duration_ms =
                            duration_ms.ok_or_else(|| de::Error::missing_field("duration_ms"))?;
                        Ok(RetryStrategy::Fixed(Duration::from_millis(duration_ms)))
                    }
                    "Linear" => {
                        let base_ms = base_ms.ok_or_else(|| de::Error::missing_field("base_ms"))?;
                        let increment_ms =
                            increment_ms.ok_or_else(|| de::Error::missing_field("increment_ms"))?;
                        Ok(RetryStrategy::Linear {
                            base: Duration::from_millis(base_ms),
                            increment: Duration::from_millis(increment_ms),
                            max_delay,
                        })
                    }
                    "Exponential" => {
                        let base_ms = base_ms.ok_or_else(|| de::Error::missing_field("base_ms"))?;
                        let multiplier =
                            multiplier.ok_or_else(|| de::Error::missing_field("multiplier"))?;
                        Ok(RetryStrategy::Exponential {
                            base: Duration::from_millis(base_ms),
                            multiplier,
                            max_delay,
                            jitter: jitter.flatten(),
                        })
                    }
                    "Fibonacci" => {
                        let base_ms = base_ms.ok_or_else(|| de::Error::missing_field("base_ms"))?;
                        Ok(RetryStrategy::Fibonacci {
                            base: Duration::from_millis(base_ms),
                            max_delay,
                        })
                    }
                    _ => Err(de::Error::unknown_variant(
                        &strategy_type,
                        &["Fixed", "Linear", "Exponential", "Fibonacci"],
                    )),
                }
            }
        }

        deserializer.deserialize_struct(
            "RetryStrategy",
            &[
                "type",
                "duration_ms",
                "base_ms",
                "increment_ms",
                "max_delay_ms",
                "multiplier",
                "jitter",
            ],
            RetryStrategyVisitor,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Asserts that `strategy` yields `Duration::from_secs(secs)` for each
    /// `(attempt, secs)` pair.
    fn assert_delays_secs(strategy: &RetryStrategy, expected: &[(u32, u64)]) {
        for &(attempt, secs) in expected {
            assert_eq!(
                strategy.calculate_delay(attempt),
                Duration::from_secs(secs),
                "attempt {}",
                attempt
            );
        }
    }

    #[test]
    fn test_fibonacci_sequence() {
        let expected = [0u64, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55];
        for (n, &value) in expected.iter().enumerate() {
            assert_eq!(fibonacci(n as u32), value);
        }
    }

    #[test]
    fn test_fixed_retry_strategy() {
        let strategy = RetryStrategy::Fixed(Duration::from_secs(30));
        assert_delays_secs(&strategy, &[(1, 30), (5, 30), (10, 30)]);
    }

    #[test]
    fn test_linear_retry_strategy() {
        let strategy = RetryStrategy::Linear {
            base: Duration::from_secs(10),
            increment: Duration::from_secs(5),
            max_delay: Some(Duration::from_secs(40)),
        };
        // base + increment * attempt, capped at 40s.
        assert_delays_secs(&strategy, &[(1, 15), (2, 20), (3, 25), (6, 40), (10, 40)]);
    }

    #[test]
    fn test_exponential_retry_strategy() {
        let strategy = RetryStrategy::Exponential {
            base: Duration::from_secs(1),
            multiplier: 2.0,
            max_delay: Some(Duration::from_secs(60)),
            jitter: None,
        };
        // base * 2^(attempt - 1), capped at 60s.
        assert_delays_secs(
            &strategy,
            &[
                (1, 1),
                (2, 2),
                (3, 4),
                (4, 8),
                (5, 16),
                (6, 32),
                (7, 60),
                (10, 60),
            ],
        );
    }

    #[test]
    fn test_fibonacci_retry_strategy() {
        let strategy = RetryStrategy::Fibonacci {
            base: Duration::from_secs(2),
            max_delay: Some(Duration::from_secs(100)),
        };
        // base * fibonacci(attempt).
        assert_delays_secs(
            &strategy,
            &[(1, 2), (2, 2), (3, 4), (4, 6), (5, 10), (6, 16), (7, 26)],
        );
    }

    #[test]
    fn test_custom_retry_strategy() {
        let strategy = RetryStrategy::Custom(Box::new(|attempt| match attempt {
            1..=3 => Duration::from_secs(5),
            4..=6 => Duration::from_secs(30),
            _ => Duration::from_secs(300),
        }));
        assert_delays_secs(
            &strategy,
            &[(1, 5), (3, 5), (4, 30), (6, 30), (7, 300), (100, 300)],
        );
    }

    #[test]
    fn test_additive_jitter() {
        let jitter = JitterType::Additive(Duration::from_secs(10));
        let base_delay = Duration::from_secs(60);
        let allowed = Duration::from_secs(50)..=Duration::from_secs(70);
        for _ in 0..100 {
            let jittered = jitter.apply(base_delay);
            assert!(allowed.contains(&jittered), "{:?} out of range", jittered);
        }
    }

    #[test]
    fn test_multiplicative_jitter() {
        let jitter = JitterType::Multiplicative(0.2);
        let base_delay = Duration::from_secs(100);
        let allowed = Duration::from_secs(80)..=Duration::from_secs(120);
        for _ in 0..100 {
            let jittered = jitter.apply(base_delay);
            assert!(allowed.contains(&jittered), "{:?} out of range", jittered);
        }
    }

    #[test]
    fn test_exponential_with_jitter() {
        let strategy = RetryStrategy::Exponential {
            base: Duration::from_secs(1),
            multiplier: 2.0,
            max_delay: None,
            jitter: Some(JitterType::Multiplicative(0.1)),
        };
        let first = strategy.calculate_delay(1);
        assert!((Duration::from_millis(900)..=Duration::from_millis(1100)).contains(&first));
        let third = strategy.calculate_delay(3);
        assert!((Duration::from_millis(3600)..=Duration::from_millis(4400)).contains(&third));
    }

    #[test]
    fn test_strategy_builder_methods() {
        let fixed = RetryStrategy::fixed(Duration::from_secs(30));
        assert_eq!(fixed.calculate_delay(1), Duration::from_secs(30));

        let linear = RetryStrategy::linear(
            Duration::from_secs(5),
            Duration::from_secs(10),
            Some(Duration::from_secs(120)),
        );
        assert_eq!(linear.calculate_delay(1), Duration::from_secs(15));

        let exponential =
            RetryStrategy::exponential(Duration::from_secs(1), 2.0, Some(Duration::from_secs(600)));
        assert_eq!(exponential.calculate_delay(1), Duration::from_secs(1));

        let fibonacci =
            RetryStrategy::fibonacci(Duration::from_secs(2), Some(Duration::from_secs(300)));
        assert_eq!(fibonacci.calculate_delay(1), Duration::from_secs(2));
    }

    #[test]
    fn test_minimum_delay_enforcement() {
        let strategy = RetryStrategy::Custom(Box::new(|_| Duration::from_millis(0)));
        assert_eq!(strategy.calculate_delay(1), Duration::from_millis(1));
    }

    #[test]
    fn test_serialization() {
        let strategies = vec![
            RetryStrategy::Fixed(Duration::from_secs(30)),
            RetryStrategy::Linear {
                base: Duration::from_secs(10),
                increment: Duration::from_secs(5),
                max_delay: Some(Duration::from_secs(60)),
            },
            RetryStrategy::Exponential {
                base: Duration::from_secs(1),
                multiplier: 2.0,
                max_delay: Some(Duration::from_secs(600)),
                jitter: None,
            },
            RetryStrategy::Fibonacci {
                base: Duration::from_secs(2),
                max_delay: Some(Duration::from_secs(300)),
            },
        ];
        for original in &strategies {
            let json = serde_json::to_string(original).unwrap();
            let restored: RetryStrategy = serde_json::from_str(&json).unwrap();
            if !matches!(original, RetryStrategy::Custom(_)) {
                for &attempt in &[1u32, 3] {
                    assert_eq!(
                        original.calculate_delay(attempt),
                        restored.calculate_delay(attempt)
                    );
                }
            }
        }
    }
}