use crate::errors::MarketDataError;
use std::time::Duration;
/// Retry/backoff configuration used by `run`.
///
/// Attempts are numbered from 1. Attempt 1 runs immediately. Before attempt
/// `n >= 2` the delay is `initial_backoff * 2^(n - 2)`, capped at `max_backoff`,
/// plus a pseudo-random jitter in `[0, initial_backoff)`. Because the jitter is
/// added *after* the cap, a single delay may exceed `max_backoff` by up to
/// `initial_backoff`.
#[derive(Debug, Clone, Copy)]
pub struct RetryPolicy {
    /// Upper bound on how many times `run` invokes the operation, counting the first call.
    pub max_attempts: u32,
    /// Delay before the first retry; each subsequent retry doubles it.
    pub initial_backoff: Duration,
    /// Cap on the exponential part of the delay (jitter is added on top).
    pub max_backoff: Duration,
}

impl RetryPolicy {
    /// Builds a policy from explicit limits.
    pub fn new(max_attempts: u32, initial_backoff: Duration, max_backoff: Duration) -> Self {
        Self {
            max_attempts,
            initial_backoff,
            max_backoff,
        }
    }

    /// Preset: 3 attempts, 100 ms initial backoff, 2 s cap.
    pub fn conservative() -> Self {
        Self::new(3, Duration::from_millis(100), Duration::from_secs(2))
    }

    /// Preset: 5 attempts, 250 ms initial backoff, 10 s cap.
    pub fn aggressive() -> Self {
        Self::new(5, Duration::from_millis(250), Duration::from_secs(10))
    }

    /// Returns how long to wait before attempt number `attempt` (1-based).
    ///
    /// Attempts `0` and `1` get no delay. For later attempts the exponential
    /// component is `initial_backoff * 2^(attempt - 2)`, so the first retry
    /// waits exactly `initial_backoff` (plus jitter). The product saturates
    /// instead of overflowing and is then capped at `max_backoff`; jitter in
    /// `[0, initial_backoff)` is added after the cap.
    pub(crate) fn delay_for_attempt(&self, attempt: u32) -> Duration {
        if attempt <= 1 {
            return Duration::ZERO;
        }
        // Number of doublings applied so far: 0 for the first retry (attempt 2).
        let doublings = attempt - 2;
        // Saturating u128 arithmetic lets the delay keep growing until it hits
        // `max_backoff` for any attempt number or initial backoff, instead of
        // stalling at a fixed exponent.
        let exponential = self
            .initial_backoff
            .as_nanos()
            .saturating_mul(2u128.saturating_pow(doublings));
        let capped = exponential
            .min(self.max_backoff.as_nanos())
            .min(u128::from(u64::MAX));
        // Lossless: `capped` was clamped to `u64::MAX` just above.
        let base = Duration::from_nanos(capped as u64);
        base + jitter(self.initial_backoff)
    }
}

/// Returns a pseudo-random duration in `[0, ceiling)`, or zero when `ceiling`
/// is zero. Ceilings above `u64::MAX` nanoseconds are clamped to that value.
///
/// Randomness comes from `RandomState`'s randomly keyed hasher, mixed with the
/// current `Instant`. This is not cryptographically secure; it only spreads out
/// retry delays.
fn jitter(ceiling: Duration) -> Duration {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::time::Instant;

    let nanos_ceil = ceiling.as_nanos().min(u128::from(u64::MAX)) as u64;
    if nanos_ceil == 0 {
        return Duration::ZERO;
    }
    let mut hasher = RandomState::new().build_hasher();
    // Hash the timestamp itself; `Instant::now().elapsed()` would be ~0 ns and
    // contribute nothing.
    Instant::now().hash(&mut hasher);
    Duration::from_nanos(hasher.finish() % nanos_ceil)
}
/// Invokes `op` until it succeeds, fails with a non-retryable error, or
/// `policy.max_attempts` attempts have been made.
///
/// Before attempt `n` the calling thread sleeps for
/// `policy.delay_for_attempt(n)` (skipped when that delay is zero). Because this
/// uses `std::thread::sleep`, it blocks the current thread; avoid calling it
/// from an async runtime's worker threads.
///
/// # Errors
///
/// - The first error for which `is_retryable()` is false is returned immediately.
/// - If every attempt fails with a retryable error, the error from the final
///   attempt is returned unchanged.
/// - If `policy.max_attempts` is 0, `op` is never called and a
///   `MarketDataError::RuntimeError` is returned.
pub(crate) fn run<T>(
    policy: &RetryPolicy,
    mut op: impl FnMut() -> Result<T, MarketDataError>,
) -> Result<T, MarketDataError> {
    for attempt in 1..=policy.max_attempts {
        let delay = policy.delay_for_attempt(attempt);
        if !delay.is_zero() {
            std::thread::sleep(delay);
        }
        match op() {
            Ok(value) => return Ok(value),
            Err(err) => {
                // Permanent failures and the final attempt's failure are
                // surfaced as-is; anything else falls through to another try.
                if !err.is_retryable() || attempt == policy.max_attempts {
                    return Err(err);
                }
            }
        }
    }
    // Only reachable when `max_attempts == 0`: with at least one attempt, the
    // last iteration of the loop always returns.
    Err(MarketDataError::RuntimeError {
        msg: "retry loop exited without error or success".into(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    #[test]
    fn test_conservative_preset() {
        let p = RetryPolicy::conservative();
        assert_eq!(p.max_attempts, 3);
        assert_eq!(p.initial_backoff, Duration::from_millis(100));
        assert_eq!(p.max_backoff, Duration::from_secs(2));
    }

    #[test]
    fn test_aggressive_preset() {
        let p = RetryPolicy::aggressive();
        assert_eq!(p.max_attempts, 5);
        assert_eq!(p.initial_backoff, Duration::from_millis(250));
        assert_eq!(p.max_backoff, Duration::from_secs(10));
    }

    #[test]
    fn test_first_attempt_no_delay() {
        let p = RetryPolicy::conservative();
        assert_eq!(p.delay_for_attempt(1), Duration::ZERO);
    }

    #[test]
    fn test_delay_capped_at_max() {
        // Jitter (up to `initial_backoff`) is added after the cap, hence the
        // 600 ms bound for a 500 ms `max_backoff`.
        let p = RetryPolicy::new(20, Duration::from_millis(100), Duration::from_millis(500));
        for attempt in (1..=20).chain(std::iter::once(u32::MAX)) {
            let d = p.delay_for_attempt(attempt);
            assert!(d <= Duration::from_millis(600), "attempt {} = {:?}", attempt, d);
        }
    }

    #[test]
    fn test_jitter_within_ceiling() {
        assert_eq!(jitter(Duration::ZERO), Duration::ZERO);
        for _ in 0..100 {
            assert!(jitter(Duration::from_millis(5)) < Duration::from_millis(5));
        }
    }

    #[test]
    fn test_run_succeeds_first_attempt() {
        let p = RetryPolicy::new(3, Duration::from_millis(1), Duration::from_millis(10));
        let attempts = Cell::new(0u32);
        let result: Result<i32, MarketDataError> = run(&p, || {
            attempts.set(attempts.get() + 1);
            Ok(42)
        });
        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.get(), 1);
    }

    #[test]
    fn test_run_retries_retryable() {
        let p = RetryPolicy::new(3, Duration::from_millis(1), Duration::from_millis(10));
        let attempts = Cell::new(0u32);
        let result: Result<i32, MarketDataError> = run(&p, || {
            let n = attempts.get() + 1;
            attempts.set(n);
            if n < 2 {
                Err(MarketDataError::ApiError {
                    status: 503,
                    message: "transient".into(),
                })
            } else {
                Ok(42)
            }
        });
        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.get(), 2);
    }

    #[test]
    fn test_run_does_not_retry_non_retryable() {
        let p = RetryPolicy::new(5, Duration::from_millis(1), Duration::from_millis(10));
        let attempts = Cell::new(0u32);
        let result: Result<i32, MarketDataError> = run(&p, || {
            attempts.set(attempts.get() + 1);
            Err(MarketDataError::ApiError {
                status: 401,
                message: "unauthorized".into(),
            })
        });
        assert_eq!(attempts.get(), 1);
        match result.unwrap_err() {
            MarketDataError::ApiError { status, .. } => assert_eq!(status, 401),
            other => panic!("expected ApiError, got {:?}", other),
        }
    }

    #[test]
    fn test_run_exhausts_and_returns_last_error() {
        let p = RetryPolicy::new(3, Duration::from_millis(1), Duration::from_millis(10));
        let attempts = Cell::new(0u32);
        let result: Result<i32, MarketDataError> = run(&p, || {
            attempts.set(attempts.get() + 1);
            Err(MarketDataError::ApiError {
                status: 503,
                message: "still down".into(),
            })
        });
        assert_eq!(attempts.get(), 3);
        match result.unwrap_err() {
            MarketDataError::ApiError { status, message } => {
                assert_eq!(status, 503);
                assert_eq!(message, "still down");
            }
            other => panic!("expected ApiError, got {:?}", other),
        }
    }

    #[test]
    fn test_run_zero_attempts_never_calls_op() {
        let p = RetryPolicy::new(0, Duration::from_millis(1), Duration::from_millis(10));
        let attempts = Cell::new(0u32);
        let result: Result<i32, MarketDataError> = run(&p, || {
            attempts.set(attempts.get() + 1);
            Ok(42)
        });
        assert_eq!(attempts.get(), 0);
        match result.unwrap_err() {
            MarketDataError::RuntimeError { .. } => {}
            other => panic!("expected RuntimeError, got {:?}", other),
        }
    }
}