use std::str::FromStr;
use std::time::Duration;
use parse_display::{
Display,
DisplayFormat,
FromStr as DeriveFromStr,
FromStrFormat,
ParseError,
};
use rand::RngExt;
use serde::{
Deserialize,
Serialize,
};
use crate::RetryDelay;
use crate::constants::DEFAULT_RETRY_JITTER;
/// Randomization applied to a retry delay.
///
/// The textual form is either `none` or `factor:<value>`; `Display` and
/// `FromStr` are both derived through `parse_display` from the attributes
/// on each variant.
#[derive(Debug, Clone, Copy, PartialEq, Display, DeriveFromStr, Serialize, Deserialize)]
pub enum RetryJitter {
/// No jitter: `apply` returns the base delay unchanged.
///
/// Displayed as `none`. The parsing regex is case-insensitive and allows
/// whitespace on either side of the keyword.
#[display("none")]
#[from_str(regex = r"(?i)\s*none\s*")]
None,
/// Jitter proportional to the base delay; the payload is the factor.
///
/// Displayed as `factor:<value>`. The parsing regex allows whitespace
/// before `factor:`, between the colon and the value, and after the value.
/// Unlike the `none` regex it carries no `(?i)` flag, so the `factor:`
/// keyword is matched case-sensitively.
/// The field is formatted through `RetryJitterFactorFormat`.
/// NOTE(review): that type also implements `FromStrFormat`, so the captured
/// text is presumably range-checked to `[0.0, 1.0]` on parse — confirm
/// against the `parse_display` `with` semantics.
#[display("factor:{0}")]
#[from_str(regex = r"\s*factor:\s*(?<0>\S(?:.*\S)?)\s*")]
Factor(#[display(with = RetryJitterFactorFormat)] f64),
}
/// Zero-sized format type for the jitter factor; it carries the
/// `DisplayFormat<f64>` and `FromStrFormat<f64>` implementations used by
/// `RetryJitter::Factor`.
struct RetryJitterFactorFormat;
impl DisplayFormat<f64> for RetryJitterFactorFormat {
    /// Writes `value` to `f` using `f64`'s plain `{}` formatting.
    fn write(&self, f: &mut std::fmt::Formatter<'_>, value: &f64) -> std::fmt::Result {
        f.write_fmt(format_args!("{}", *value))
    }
}
impl FromStrFormat<f64> for RetryJitterFactorFormat {
    type Err = ParseError;

    /// Parses `s` as an `f64` jitter factor.
    ///
    /// # Errors
    ///
    /// Returns a `ParseError` when `s` is not a valid `f64`, or when the
    /// parsed number lies outside `[0.0, 1.0]` (NaN is outside every range
    /// and is therefore rejected as well).
    fn parse(&self, s: &str) -> Result<f64, Self::Err> {
        match s.parse::<f64>() {
            Err(_) => Err(ParseError::with_message("invalid retry jitter factor")),
            Ok(factor) if (0.0..=1.0).contains(&factor) => Ok(factor),
            Ok(_) => Err(ParseError::with_message(
                "retry jitter factor must be in range [0.0, 1.0]",
            )),
        }
    }
}
impl RetryJitter {
    /// Returns the variant that applies no jitter.
    #[inline]
    pub fn none() -> Self {
        Self::None
    }

    /// Returns a proportional jitter with the given factor.
    ///
    /// The factor is stored as given, without any range check; use
    /// [`RetryJitter::validate`] to verify it is finite and within
    /// `[0.0, 1.0]`.
    #[inline]
    pub fn factor(factor: f64) -> Self {
        Self::Factor(factor)
    }

    /// Applies this jitter to `base` and returns the resulting delay.
    ///
    /// `base` is returned unchanged when:
    /// - the jitter is [`RetryJitter::None`];
    /// - the factor is not finite or is `<= 0.0`;
    /// - `base` is zero;
    /// - `base` does not fit in `u64` nanoseconds.
    ///
    /// Otherwise a random offset drawn from `[-span, +span]` nanoseconds,
    /// with `span = base * factor`, is added to `base`, and the sum is
    /// clamped to `[0, u64::MAX]` nanoseconds.
    ///
    /// This method never panics for a finite factor: the span is capped so
    /// the sampling range always has finite bounds and a finite width.
    pub fn apply(&self, base: Duration) -> Duration {
        match self {
            Self::None => base,
            Self::Factor(factor) if !factor.is_finite() || *factor <= 0.0 || base.is_zero() => base,
            Self::Factor(factor) => {
                // Durations beyond `u64` nanoseconds cannot be rebuilt with
                // `Duration::from_nanos`, so they are passed through as is.
                let Ok(base_nanos) = u64::try_from(base.as_nanos()) else {
                    return base;
                };
                let base_nanos = base_nanos as f64;
                let max_nanos = u64::MAX as f64;

                // An unvalidated, very large factor can push the product (or
                // the width of the sampling range) to infinity, which
                // `random_range` rejects by panicking. The result is clamped
                // to `max_nanos` below, so capping the span there keeps the
                // same reachable set of delays while staying finite.
                let span = (base_nanos * factor).min(max_nanos);

                let mut rng = rand::rng();
                let jitter = rng.random_range(-span..=span);

                // The float-to-int cast saturates, so the upper clamp bound
                // (2^64 as f64) maps to `u64::MAX`.
                let nanos = (base_nanos + jitter).clamp(0.0, max_nanos) as u64;
                Duration::from_nanos(nanos)
            }
        }
    }

    /// Returns the jittered delay for the given retry `attempt`.
    ///
    /// The base delay is obtained from `delay_strategy.base_delay(attempt)`
    /// and then passed through [`RetryJitter::apply`].
    pub fn delay_for_attempt(&self, delay_strategy: &RetryDelay, attempt: u32) -> Duration {
        self.apply(delay_strategy.base_delay(attempt))
    }

    /// Checks that this jitter configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns an error message when the variant is
    /// [`RetryJitter::Factor`] and its factor is not finite or lies outside
    /// `[0.0, 1.0]`.
    pub fn validate(&self) -> Result<(), String> {
        match self {
            Self::None => Ok(()),
            Self::Factor(factor) if factor.is_finite() && (0.0..=1.0).contains(factor) => Ok(()),
            Self::Factor(_) => {
                Err("jitter factor must be finite and in range [0.0, 1.0]".to_string())
            }
        }
    }
}
impl Default for RetryJitter {
    /// Builds the default jitter by parsing `DEFAULT_RETRY_JITTER`.
    ///
    /// # Panics
    ///
    /// Panics if `DEFAULT_RETRY_JITTER` is not a valid `RetryJitter` string.
    #[inline]
    fn default() -> Self {
        let parsed = <Self as FromStr>::from_str(DEFAULT_RETRY_JITTER);
        parsed.expect("DEFAULT_RETRY_JITTER must be a valid RetryJitter string")
    }
}