use std::time::Duration;
/// Possible outcomes of a retry decision.
///
/// NOTE(review): nothing in the visible code constructs or matches on this
/// enum; `RetryPolicy::should_retry` returns `Option<Duration>` instead. The
/// variant descriptions below are inferred from their names — confirm with
/// callers before relying on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryAction {
/// Presumably: retry after waiting for the carried delay.
ShouldRetry(Duration),
/// Presumably: do not retry this time.
NoRetry,
/// Presumably: the failure must never be retried.
PermanentFailure,
}
/// Decides whether, and after how long, a failed operation is retried.
///
/// Attempts are zero-based: the policies in this module return their first
/// delay for `attempt == 0`.
pub trait RetryPolicy: Send + Sync {
    /// Returns the delay to wait before retry number `attempt`, or `None`
    /// when no further retry should be made.
    fn should_retry(&self, attempt: u32) -> Option<Duration>;

    /// Upper bound on the number of retries this policy grants.
    ///
    /// The default is `u32::MAX`, i.e. no explicit bound; policies with a
    /// finite budget override it.
    fn max_retries(&self) -> u32 {
        let unbounded = u32::MAX;
        unbounded
    }
}
/// Retry policy whose delay doubles with every attempt.
///
/// The delay for zero-based attempt `n` is `initial_delay * 2^n`, capped at
/// `max_delay` before optional jitter is applied.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
/// Retry budget: `should_retry` returns `None` once `attempt >= max_retries`.
max_retries: u32,
/// Delay before the first retry (attempt 0).
initial_delay: Duration,
/// Cap on the exponentially growing delay (30 s when built via `new`).
max_delay: Duration,
/// Whether each computed delay is perturbed by jitter.
jitter: bool,
}
impl ExponentialBackoff {
    /// Creates a policy granting `max_retries` retries whose delay starts at
    /// `initial_delay` and doubles on each attempt.
    ///
    /// The delay is capped at 30 seconds (override with
    /// [`with_max_delay`](Self::with_max_delay)) and jitter is disabled
    /// (enable with [`with_jitter`](Self::with_jitter)).
    pub fn new(max_retries: u32, initial_delay: Duration) -> Self {
        Self {
            max_retries,
            initial_delay,
            max_delay: Duration::from_secs(30),
            jitter: false,
        }
    }

    /// Sets the upper bound applied to the exponentially growing delay.
    ///
    /// The bound is applied before jitter, so a jittered delay may exceed it
    /// by up to 10%.
    #[must_use]
    pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
        self.max_delay = max_delay;
        self
    }

    /// Enables jitter: each delay is shifted by a pseudo-random offset in
    /// `[-10%, +10%)` of its nominal value, so that clients failing together
    /// do not all retry at the same instant.
    #[must_use]
    pub fn with_jitter(mut self) -> Self {
        self.jitter = true;
        self
    }

    /// Computes the delay, in (fractional) milliseconds, before retry number
    /// `attempt` (zero-based).
    ///
    /// The nominal value is `initial_delay * 2^attempt` capped at `max_delay`;
    /// jitter, when enabled, is applied on top. The result is always finite
    /// and non-negative.
    fn calculate_delay(&self, attempt: u32) -> f64 {
        // Start from nanoseconds so sub-millisecond initial delays are not
        // truncated to zero the way `as_millis` would truncate them.
        let base_delay_ms = self.initial_delay.as_nanos() as f64 / 1_000_000.0;
        let max_delay_ms = self.max_delay.as_nanos() as f64 / 1_000_000.0;

        // Clamp the exponent to the largest one whose power of two is still a
        // finite f64. This avoids `attempt as i32` wrapping negative for
        // attempts above i32::MAX, and avoids `0 * inf = NaN` for a zero base.
        let exponent = attempt.min((f64::MAX_EXP - 1) as u32) as i32;
        let delay = (base_delay_ms * 2f64.powi(exponent)).min(max_delay_ms);

        if self.jitter {
            // Uniform offset in [-delay/10, +delay/10).
            let jitter_range = delay * 0.2;
            delay + jitter_range * (Self::random_unit() - 0.5)
        } else {
            delay
        }
    }

    /// Returns a pseudo-random value uniformly distributed in `[0, 1)`.
    ///
    /// Each `RandomState` is created with fresh random keys, so hashing with it
    /// yields a dependency-free random `u64`. This is good enough for jitter,
    /// not for cryptography.
    fn random_unit() -> f64 {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};

        let bits = RandomState::new().build_hasher().finish();
        // Keep the top 53 bits so the quotient is exactly representable.
        (bits >> 11) as f64 / (1u64 << 53) as f64
    }
}
impl RetryPolicy for ExponentialBackoff {
    /// Grants a retry for attempts `0..max_retries`.
    ///
    /// The delay comes from `calculate_delay`, truncated to whole
    /// milliseconds.
    fn should_retry(&self, attempt: u32) -> Option<Duration> {
        (attempt < self.max_retries).then(|| {
            let whole_millis = self.calculate_delay(attempt) as u64;
            Duration::from_millis(whole_millis)
        })
    }

    /// Reports the configured retry budget.
    fn max_retries(&self) -> u32 {
        self.max_retries
    }
}
/// Retry policy that waits the same `delay` before every retry, for at most
/// `max_retries` retries.
#[derive(Debug, Clone)]
pub struct FixedDelay {
/// Retry budget: `should_retry` returns `None` once `attempt >= max_retries`.
max_retries: u32,
/// Delay returned for every granted retry.
delay: Duration,
}
impl FixedDelay {
pub fn new(max_retries: u32, delay: Duration) -> Self {
Self { max_retries, delay }
}
}
impl RetryPolicy for FixedDelay {
    /// Grants the configured `delay` for attempts `0..max_retries`, then
    /// declines.
    fn should_retry(&self, attempt: u32) -> Option<Duration> {
        let within_budget = attempt < self.max_retries;
        if within_budget {
            Some(self.delay)
        } else {
            None
        }
    }

    /// Reports the configured retry budget.
    fn max_retries(&self) -> u32 {
        self.max_retries
    }
}
/// Policy that never grants a retry.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoRetry;

impl RetryPolicy for NoRetry {
    /// Always declines, whatever the attempt number.
    fn should_retry(&self, _attempt: u32) -> Option<Duration> {
        Option::None
    }

    /// The retry budget is zero.
    fn max_retries(&self) -> u32 {
        0u32
    }
}
/// Wraps another retry policy together with a predicate over error messages.
///
/// NOTE(review): judging by the type name, the predicate presumably returns
/// `true` for errors considered transient — confirm with callers. The
/// `RetryPolicy::should_retry` signature only carries the attempt number, so
/// the predicate has to be consulted by code that has the error text in hand.
pub struct TransientFilter<P> {
/// Policy that supplies the retry budget and delays.
inner: P,
/// Classifier applied to an error message (`&str`).
predicate: Box<dyn Fn(&str) -> bool + Send + Sync>,
}
impl<P: RetryPolicy> TransientFilter<P> {
pub fn new(policy: P, predicate: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
Self {
inner: policy,
predicate: Box::new(predicate),
}
}
}
impl<P: RetryPolicy> RetryPolicy for TransientFilter<P> {
fn should_retry(&self, attempt: u32) -> Option<Duration> {
self.inner.should_retry(attempt)
}
fn max_retries(&self) -> u32 {
self.inner.max_retries()
}
}
/// Convenience methods available on every sized [`RetryPolicy`].
pub trait RetryPolicyExt: RetryPolicy + Sized {
    /// Returns an iterator over the delays this policy grants.
    ///
    /// The iterator borrows the policy and starts at attempt 0.
    fn delays(&self) -> DelayIterator<'_, Self> {
        DelayIterator {
            attempt: 0,
            policy: self,
        }
    }
}

impl<T> RetryPolicyExt for T where T: RetryPolicy {}
#[derive(Debug)]
pub struct DelayIterator<'a, P: RetryPolicy> {
policy: &'a P,
attempt: u32,
}
impl<P: RetryPolicy> Iterator for DelayIterator<'_, P> {
type Item = Duration;
fn next(&mut self) -> Option<Self::Item> {
let delay = self.policy.should_retry(self.attempt);
self.attempt += 1;
delay
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exponential_backoff() {
        let policy = ExponentialBackoff::new(3, Duration::from_millis(100));
        assert_eq!(policy.should_retry(0), Some(Duration::from_millis(100)));
        assert_eq!(policy.should_retry(1), Some(Duration::from_millis(200)));
        assert_eq!(policy.should_retry(2), Some(Duration::from_millis(400)));
        assert_eq!(policy.should_retry(3), None);
    }

    #[test]
    fn test_exponential_backoff_with_max_delay() {
        let policy = ExponentialBackoff::new(10, Duration::from_millis(100))
            .with_max_delay(Duration::from_millis(500));
        assert_eq!(policy.should_retry(0), Some(Duration::from_millis(100)));
        assert_eq!(policy.should_retry(1), Some(Duration::from_millis(200)));
        assert_eq!(policy.should_retry(2), Some(Duration::from_millis(400)));
        assert_eq!(policy.should_retry(3), Some(Duration::from_millis(500)));
        assert_eq!(policy.should_retry(4), Some(Duration::from_millis(500)));
    }

    /// Jittered delays must stay within ±10% of the nominal (capped) delay.
    #[test]
    fn test_exponential_backoff_jitter_stays_within_ten_percent() {
        let policy = ExponentialBackoff::new(10, Duration::from_millis(1000)).with_jitter();
        for attempt in 0..6u32 {
            let nominal = (1000u128 << attempt).min(30_000);
            let delay = policy
                .should_retry(attempt)
                .expect("attempt is within max_retries")
                .as_millis();
            assert!(
                delay >= nominal * 9 / 10 && delay <= nominal * 11 / 10,
                "attempt {}: delay {} ms outside ±10% of {} ms",
                attempt,
                delay,
                nominal
            );
        }
        assert_eq!(policy.should_retry(10), None);
    }

    #[test]
    fn test_fixed_delay() {
        let policy = FixedDelay::new(3, Duration::from_secs(1));
        assert_eq!(policy.should_retry(0), Some(Duration::from_secs(1)));
        assert_eq!(policy.should_retry(1), Some(Duration::from_secs(1)));
        assert_eq!(policy.should_retry(2), Some(Duration::from_secs(1)));
        assert_eq!(policy.should_retry(3), None);
    }

    #[test]
    fn test_no_retry() {
        let policy = NoRetry;
        assert_eq!(policy.should_retry(0), None);
        assert_eq!(policy.should_retry(1), None);
    }

    /// Every policy reports the budget it was configured with.
    #[test]
    fn test_max_retries_reporting() {
        assert_eq!(ExponentialBackoff::new(3, Duration::from_millis(1)).max_retries(), 3);
        assert_eq!(FixedDelay::new(5, Duration::from_secs(1)).max_retries(), 5);
        assert_eq!(NoRetry.max_retries(), 0);
    }

    /// The `RetryPolicy` impl of `TransientFilter` forwards to its inner policy.
    #[test]
    fn test_transient_filter_delegates_to_inner_policy() {
        let policy = TransientFilter::new(
            FixedDelay::new(2, Duration::from_millis(50)),
            |error: &str| error.contains("timeout"),
        );
        assert_eq!(policy.should_retry(0), Some(Duration::from_millis(50)));
        assert_eq!(policy.should_retry(1), Some(Duration::from_millis(50)));
        assert_eq!(policy.should_retry(2), None);
        assert_eq!(policy.max_retries(), 2);
    }

    #[test]
    fn test_delay_iterator() {
        let policy = ExponentialBackoff::new(3, Duration::from_millis(100));
        let delays: Vec<_> = policy.delays().collect();
        assert_eq!(delays.len(), 3);
        assert_eq!(delays[0], Duration::from_millis(100));
        assert_eq!(delays[1], Duration::from_millis(200));
        assert_eq!(delays[2], Duration::from_millis(400));
    }

    /// `delays()` works for the other policies too.
    #[test]
    fn test_delay_iterator_for_fixed_and_no_retry() {
        let fixed = FixedDelay::new(2, Duration::from_secs(1));
        assert_eq!(fixed.delays().collect::<Vec<_>>(), vec![Duration::from_secs(1); 2]);
        assert_eq!(NoRetry.delays().count(), 0);
    }
}