use std::time::Duration;
use super::CircuitBreaker;
/// Attempt budget used by `RetryPolicy::default()`.
const DEFAULT_MAX_ATTEMPTS: u32 = 5;
/// Base backoff delay, in milliseconds, used by `RetryPolicy::default()`.
const DEFAULT_INITIAL_DELAY_MS: u32 = 1000;
/// Ceiling, in milliseconds, applied to the exponential backoff by `RetryPolicy::default()`.
const DEFAULT_MAX_DELAY_MS: u32 = 30000;
/// Jitter fraction used by `RetryPolicy::default()`; the jitter range is this
/// fraction of the capped delay.
const DEFAULT_JITTER_FACTOR: f32 = 0.5;
/// Tunable parameters for retrying a fallible operation with exponential
/// backoff and optional jitter.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Attempt budget handed to the retry loop (the first call counts).
    pub max_attempts: u32,
    /// Backoff delay for the first retry, in milliseconds; later retries
    /// double it.
    pub initial_delay_ms: u32,
    /// Ceiling, in milliseconds, applied to the exponential delay before
    /// jitter is added.
    pub max_delay_ms: u32,
    /// Fraction of the capped delay used as the jitter range. Values that are
    /// not greater than `0.0` disable jitter.
    pub jitter_factor: f32,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: DEFAULT_MAX_ATTEMPTS,
initial_delay_ms: DEFAULT_INITIAL_DELAY_MS,
max_delay_ms: DEFAULT_MAX_DELAY_MS,
jitter_factor: DEFAULT_JITTER_FACTOR,
}
}
}
impl RetryPolicy {
    /// Policy that makes a single attempt and never waits.
    pub fn no_retry() -> Self {
        Self {
            max_attempts: 1,
            initial_delay_ms: 0,
            max_delay_ms: 0,
            jitter_factor: 0.0,
        }
    }

    /// Policy with many attempts and short waits: 7 attempts, 500 ms base
    /// delay, 10 s cap, 30 % jitter.
    pub fn aggressive() -> Self {
        Self {
            max_attempts: 7,
            initial_delay_ms: 500,
            max_delay_ms: 10000,
            jitter_factor: 0.3,
        }
    }

    /// Policy with few attempts and long waits: 3 attempts, 2 s base delay,
    /// 60 s cap, 50 % jitter.
    pub fn conservative() -> Self {
        Self {
            max_attempts: 3,
            initial_delay_ms: 2000,
            max_delay_ms: 60000,
            jitter_factor: 0.5,
        }
    }

    /// Returns how long to wait before the given attempt.
    ///
    /// * `attempt == 0` (the initial call) never waits.
    /// * For `attempt >= 1` the base delay is
    ///   `initial_delay_ms * 2^min(attempt - 1, 10)`, capped at `max_delay_ms`.
    /// * When `jitter_factor > 0.0`, an offset in `[-range, +range]` is added,
    ///   where `range = capped_delay * jitter_factor`. Jitter is applied
    ///   *after* the cap, so the result may exceed `max_delay_ms`. The lower
    ///   end saturates at zero and the result never exceeds `u32::MAX` ms.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // Doubling stops after this many steps so the shift itself cannot overflow.
        const MAX_BACKOFF_SHIFT: u32 = 10;

        if attempt == 0 {
            return Duration::ZERO;
        }

        let multiplier = 1u32 << (attempt - 1).min(MAX_BACKOFF_SHIFT);
        let capped_delay_ms = self
            .initial_delay_ms
            .saturating_mul(multiplier)
            .min(self.max_delay_ms);

        let final_delay_ms = if self.jitter_factor > 0.0 {
            // Float-to-integer `as` casts saturate, so an oversized factor
            // yields `u32::MAX` here instead of wrapping.
            let jitter_range = (capped_delay_ms as f32 * self.jitter_factor) as u32;
            // Widen before computing the window: `jitter_range * 2 + 1`
            // overflows `u32` once the range reaches 2^31.
            let window = u64::from(jitter_range) * 2 + 1;
            let jitter = u64::from(random_u32()) % window;
            let lower_bound = u64::from(capped_delay_ms.saturating_sub(jitter_range));
            // Keep the historical ceiling of `u32::MAX` milliseconds.
            (lower_bound + jitter).min(u64::from(u32::MAX))
        } else {
            u64::from(capped_delay_ms)
        };

        Duration::from_millis(final_delay_ms)
    }
}
/// Outcome of running an operation under a retry policy.
#[derive(Debug)]
pub enum RetryResult<T, E> {
    /// The operation returned `Ok`; carries the produced value.
    Success(T),
    /// The operation returned an error that reported itself as non-retryable.
    Failure(E),
    /// The attempt budget ran out while the operation kept failing.
    Exhausted {
        /// Error returned by the most recent failed call.
        last_error: E,
        /// Attempt count reported by the retry loop.
        attempts: u32,
    },
}
impl<T, E> RetryResult<T, E> {
pub fn into_result(self) -> Result<T, E> {
match self {
RetryResult::Success(v) => Ok(v),
RetryResult::Failure(e) => Err(e),
RetryResult::Exhausted { last_error, .. } => Err(last_error),
}
}
pub fn is_success(&self) -> bool {
matches!(self, RetryResult::Success(_))
}
}
/// Implemented by error types that can tell the retry loop how to react.
pub trait RetryableError {
    /// Whether another attempt may be made after this error.
    fn is_retryable(&self) -> bool;

    /// Optional explicit wait before the next attempt. When `Some`, the retry
    /// loop uses it instead of the policy's computed backoff.
    fn retry_after(&self) -> Option<Duration>;
}
/// Runs `operation` until it succeeds, fails with a non-retryable error, or
/// the policy's attempt budget is used up.
///
/// This function blocks the current thread with `std::thread::sleep` between
/// attempts.
///
/// # Parameters
/// * `policy` – attempt budget and backoff configuration.
/// * `circuit` – optional circuit breaker. It is consulted before every call
///   and is told about every success and failure of an executed call.
/// * `operation` – the fallible work to run.
///
/// # Behavior
/// * Before each retry the loop waits for the error's `retry_after()` if it
///   provides one, otherwise for `policy.delay_for_attempt(n)` where `n` is
///   the number of calls already made (so the first real call never waits).
/// * When the circuit breaker refuses a call, the loop waits 100 ms and gives
///   up that slot of the budget without calling `operation`.
/// * `operation` is always called at least once: a `max_attempts` of `0` is
///   treated as `1`, and if the circuit breaker has refused every earlier slot
///   the final slot runs regardless. Without this there would be no error
///   value to hand back to the caller.
///
/// # Returns
/// * `RetryResult::Success` with the first `Ok` value.
/// * `RetryResult::Failure` as soon as an error is not retryable.
/// * `RetryResult::Exhausted` with the last error and the number of times
///   `operation` was actually called.
pub fn with_retry<T, E, F>(
    policy: &RetryPolicy,
    circuit: Option<&CircuitBreaker>,
    mut operation: F,
) -> RetryResult<T, E>
where
    F: FnMut() -> Result<T, E>,
    E: RetryableError,
{
    // Pause before re-checking a circuit breaker that refused the call.
    const CIRCUIT_OPEN_PAUSE: Duration = Duration::from_millis(100);

    // At least one slot, so there is always a real outcome to return.
    let budget = policy.max_attempts.max(1);
    let mut last_error: Option<E> = None;
    let mut calls_made: u32 = 0;

    for slot in 0..budget {
        if let Some(cb) = circuit {
            if !cb.can_execute() {
                // On the final slot with nothing to report yet, fall through
                // and run the operation anyway; every other refused slot is
                // simply skipped after a short pause.
                let must_run_anyway = slot + 1 == budget && last_error.is_none();
                if !must_run_anyway {
                    std::thread::sleep(CIRCUIT_OPEN_PAUSE);
                    continue;
                }
            }
        }

        // Backoff is indexed by calls actually made, not by loop slot, so
        // slots skipped by the circuit breaker do not inflate the delay.
        let delay = last_error
            .as_ref()
            .and_then(|err| err.retry_after())
            .unwrap_or_else(|| policy.delay_for_attempt(calls_made));
        if !delay.is_zero() {
            std::thread::sleep(delay);
        }

        calls_made += 1;
        match operation() {
            Ok(value) => {
                if let Some(cb) = circuit {
                    cb.record_success();
                }
                return RetryResult::Success(value);
            }
            Err(err) => {
                if let Some(cb) = circuit {
                    cb.record_failure();
                }
                if !err.is_retryable() {
                    return RetryResult::Failure(err);
                }
                last_error = Some(err);
            }
        }
    }

    RetryResult::Exhausted {
        // Invariant: the loop has at least one slot, and its final slot always
        // calls `operation` when no earlier error exists.
        last_error: last_error.expect("the final slot always runs when no error was recorded"),
        attempts: calls_made,
    }
}
/// Returns the next value of a small per-thread xorshift32 generator.
///
/// The generator is seeded lazily, once per thread, from the low 32 bits of
/// the wall-clock time in nanoseconds. It is only meant to spread retry
/// delays apart and is not suitable for anything security-sensitive.
fn random_u32() -> u32 {
    use std::cell::Cell;
    use std::time::SystemTime;

    // Used when the clock is unusable or when the clock-derived seed is zero.
    const FALLBACK_SEED: u32 = 12345;

    thread_local! {
        static STATE: Cell<u32> = Cell::new({
            let seed = SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                // Truncation is intentional: only the fast-changing low bits matter.
                .map(|elapsed| elapsed.as_nanos() as u32)
                .unwrap_or(FALLBACK_SEED);
            // xorshift maps 0 to 0 forever, so a zero seed would pin every
            // later output at zero.
            if seed == 0 { FALLBACK_SEED } else { seed }
        });
    }

    STATE.with(|state| {
        let mut x = state.get();
        x ^= x << 13;
        x ^= x >> 17;
        x ^= x << 5;
        state.set(x);
        x
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Error type whose retry behavior is fully controlled by the test.
    #[derive(Debug)]
    struct TestError {
        retryable: bool,
        retry_after: Option<Duration>,
    }

    impl RetryableError for TestError {
        fn is_retryable(&self) -> bool {
            self.retryable
        }

        fn retry_after(&self) -> Option<Duration> {
            self.retry_after
        }
    }

    /// Builds a jitter-free policy so delays are deterministic.
    fn fixed_policy(max_attempts: u32, initial_delay_ms: u32, max_delay_ms: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            initial_delay_ms,
            max_delay_ms,
            jitter_factor: 0.0,
        }
    }

    /// Builds an error without a `retry_after` hint.
    fn failure(retryable: bool) -> TestError {
        TestError {
            retryable,
            retry_after: None,
        }
    }

    #[test]
    fn test_default_policy() {
        let defaults = RetryPolicy::default();
        assert_eq!(defaults.max_attempts, 5);
        assert_eq!(defaults.initial_delay_ms, 1000);
        assert_eq!(defaults.max_delay_ms, 30000);
        assert_eq!(defaults.jitter_factor, 0.5);
    }

    #[test]
    fn test_no_retry_policy() {
        assert_eq!(RetryPolicy::no_retry().max_attempts, 1);
    }

    #[test]
    fn test_delay_calculation() {
        let policy = fixed_policy(5, 1000, 30000);
        // Index is the attempt number; value is the expected delay in ms.
        let expected_ms: [u64; 5] = [0, 1000, 2000, 4000, 8000];
        for (attempt, &millis) in expected_ms.iter().enumerate() {
            assert_eq!(
                policy.delay_for_attempt(attempt as u32),
                Duration::from_millis(millis)
            );
        }
    }

    #[test]
    fn test_delay_capped_at_max() {
        let policy = fixed_policy(10, 1000, 5000);
        for attempt in [5, 6] {
            assert_eq!(policy.delay_for_attempt(attempt), Duration::from_millis(5000));
        }
    }

    #[test]
    fn test_retry_success_first_attempt() {
        let policy = RetryPolicy::no_retry();
        let mut invocations = 0;
        let outcome: RetryResult<&str, TestError> = with_retry(&policy, None, || {
            invocations += 1;
            Ok("success")
        });
        assert!(outcome.is_success());
        assert_eq!(invocations, 1);
    }

    #[test]
    fn test_retry_success_after_failures() {
        // Millisecond-scale delays keep the test fast.
        let policy = fixed_policy(5, 1, 10);
        let mut invocations = 0;
        let outcome: RetryResult<&str, TestError> = with_retry(&policy, None, || {
            invocations += 1;
            if invocations >= 3 {
                Ok("success")
            } else {
                Err(failure(true))
            }
        });
        assert!(outcome.is_success());
        assert_eq!(invocations, 3);
    }

    #[test]
    fn test_retry_non_retryable_error() {
        let policy = RetryPolicy::default();
        let mut invocations = 0;
        let outcome: RetryResult<&str, TestError> = with_retry(&policy, None, || {
            invocations += 1;
            Err(failure(false))
        });
        assert!(matches!(outcome, RetryResult::Failure(_)));
        // A non-retryable error must stop the loop after the first call.
        assert_eq!(invocations, 1);
    }

    #[test]
    fn test_retry_exhausted() {
        let policy = fixed_policy(3, 1, 10);
        let mut invocations = 0;
        let outcome: RetryResult<&str, TestError> = with_retry(&policy, None, || {
            invocations += 1;
            Err(failure(true))
        });
        assert!(matches!(outcome, RetryResult::Exhausted { attempts: 3, .. }));
        assert_eq!(invocations, 3);
    }

    #[test]
    fn test_jitter_produces_variation() {
        let policy = RetryPolicy {
            max_attempts: 5,
            initial_delay_ms: 1000,
            max_delay_ms: 30000,
            jitter_factor: 0.5,
        };
        let mut seen: HashSet<Duration> = HashSet::new();
        for _ in 0..10 {
            seen.insert(policy.delay_for_attempt(2));
        }
        assert!(seen.len() > 1, "Expected jitter to produce variation");
    }
}