use std::collections::HashSet;
use std::time::Duration;
/// Configuration describing which failures are retryable and how long to wait
/// before each retry.
///
/// Build one with [`RetryPolicy::new`] (or [`Default`]) and refine it with the
/// `with_*` / `add_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Upper bound on the number of retries. Stored for the caller's retry loop;
    /// none of the methods on this type read it.
    pub max_retries: u32,
    /// Delay before the first retry (`attempt == 0` in
    /// [`RetryPolicy::backoff_delay`]); doubled for each subsequent attempt.
    pub backoff_base: Duration,
    /// Cap applied to every value returned by [`RetryPolicy::backoff_delay`].
    pub backoff_max: Duration,
    /// Status codes for which [`RetryPolicy::should_retry_status`] returns `true`
    /// (presumably HTTP status codes — the defaults are 429, 500, 502, 503, 504).
    pub retryable_status_codes: HashSet<u16>,
    /// Whether connection errors should be retried. Not consulted by this type's
    /// methods; interpretation is left to the caller.
    pub retry_on_connection_error: bool,
    /// Whether timeouts should be retried. Not consulted by this type's methods;
    /// interpretation is left to the caller.
    pub retry_on_timeout: bool,
}

impl RetryPolicy {
    /// Status codes treated as retryable by [`RetryPolicy::new`].
    const DEFAULT_RETRYABLE_STATUS_CODES: [u16; 5] = [429, 500, 502, 503, 504];

    /// Creates a policy allowing `max_retries` retries.
    ///
    /// Defaults: 100 ms base backoff, 30 s backoff cap, status codes
    /// 429/500/502/503/504 retryable, and retries on connection errors and
    /// timeouts enabled.
    pub fn new(max_retries: u32) -> Self {
        Self {
            max_retries,
            backoff_base: Duration::from_millis(100),
            backoff_max: Duration::from_secs(30),
            retryable_status_codes: Self::DEFAULT_RETRYABLE_STATUS_CODES
                .iter()
                .copied()
                .collect(),
            retry_on_connection_error: true,
            retry_on_timeout: true,
        }
    }

    /// Sets the delay used before the first retry.
    #[must_use]
    pub fn with_backoff_base(mut self, base: Duration) -> Self {
        self.backoff_base = base;
        self
    }

    /// Sets the maximum delay [`RetryPolicy::backoff_delay`] may return.
    #[must_use]
    pub fn with_backoff_max(mut self, max: Duration) -> Self {
        self.backoff_max = max;
        self
    }

    /// Replaces the whole set of retryable status codes; the defaults are discarded.
    #[must_use]
    pub fn with_retryable_status_codes(mut self, codes: HashSet<u16>) -> Self {
        self.retryable_status_codes = codes;
        self
    }

    /// Adds `code` to the current set of retryable status codes.
    #[must_use]
    pub fn add_retryable_status(mut self, code: u16) -> Self {
        self.retryable_status_codes.insert(code);
        self
    }

    /// Sets whether connection errors should be retried.
    #[must_use]
    pub fn with_retry_on_connection_error(mut self, retry: bool) -> Self {
        self.retry_on_connection_error = retry;
        self
    }

    /// Sets whether timeouts should be retried.
    #[must_use]
    pub fn with_retry_on_timeout(mut self, retry: bool) -> Self {
        self.retry_on_timeout = retry;
        self
    }

    /// Returns `true` if `status_code` is in `self.retryable_status_codes`.
    #[must_use]
    pub fn should_retry_status(&self, status_code: u16) -> bool {
        self.retryable_status_codes.contains(&status_code)
    }

    /// Returns the delay to wait before retry number `attempt` (zero-based).
    ///
    /// The delay is `backoff_base * 2^attempt`, capped at `backoff_max`. It
    /// therefore equals `backoff_max` for every attempt when
    /// `backoff_max < backoff_base`.
    ///
    /// All arithmetic saturates rather than overflowing or panicking, so any
    /// `attempt` value is safe. The `2^attempt` multiplier saturates at
    /// `u32::MAX` once `attempt >= 32`. From then on the result is
    /// `min(backoff_base * u32::MAX, backoff_max)`.
    #[must_use]
    pub fn backoff_delay(&self, attempt: u32) -> Duration {
        let factor = 2u32.saturating_pow(attempt);
        self.backoff_base
            .saturating_mul(factor)
            .min(self.backoff_max)
    }
}

impl Default for RetryPolicy {
    /// Equivalent to `RetryPolicy::new(3)`: three retries with the default settings.
    fn default() -> Self {
        Self::new(3)
    }
}
// Unit tests for `RetryPolicy`: defaults, backoff growth/capping/saturation,
// and the builder methods.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_retry_policy() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_retries, 3);
        assert!(policy.should_retry_status(503));
        assert!(policy.should_retry_status(429));
        assert!(!policy.should_retry_status(404));
    }

    #[test]
    fn test_backoff_delay() {
        let policy = RetryPolicy::new(5).with_backoff_base(Duration::from_millis(100));
        assert_eq!(policy.backoff_delay(0), Duration::from_millis(100));
        assert_eq!(policy.backoff_delay(1), Duration::from_millis(200));
        assert_eq!(policy.backoff_delay(2), Duration::from_millis(400));
        assert_eq!(policy.backoff_delay(3), Duration::from_millis(800));
    }

    #[test]
    fn test_backoff_delay_capped() {
        let policy = RetryPolicy::new(5)
            .with_backoff_base(Duration::from_secs(1))
            .with_backoff_max(Duration::from_secs(5));
        assert_eq!(policy.backoff_delay(0), Duration::from_secs(1));
        assert_eq!(policy.backoff_delay(1), Duration::from_secs(2));
        assert_eq!(policy.backoff_delay(2), Duration::from_secs(4));
        assert_eq!(policy.backoff_delay(3), Duration::from_secs(5));
    }

    #[test]
    fn test_backoff_delay_saturates_for_large_attempts() {
        // Exponent overflow must neither panic nor wrap; the cap still applies.
        let policy = RetryPolicy::new(5)
            .with_backoff_base(Duration::from_secs(1))
            .with_backoff_max(Duration::from_secs(5));
        assert_eq!(policy.backoff_delay(32), Duration::from_secs(5));
        assert_eq!(policy.backoff_delay(u32::MAX), Duration::from_secs(5));
    }

    #[test]
    fn test_zero_base_never_delays() {
        let policy = RetryPolicy::new(3).with_backoff_base(Duration::from_secs(0));
        assert_eq!(policy.backoff_delay(0), Duration::from_secs(0));
        assert_eq!(policy.backoff_delay(40), Duration::from_secs(0));
    }

    #[test]
    fn test_custom_retryable_status() {
        let policy = RetryPolicy::new(1).add_retryable_status(418);
        assert!(policy.should_retry_status(418));
        assert!(policy.should_retry_status(503));
    }

    #[test]
    fn test_with_retryable_status_codes_replaces_defaults() {
        let codes: HashSet<u16> = [418u16].iter().copied().collect();
        let policy = RetryPolicy::new(1).with_retryable_status_codes(codes);
        assert!(policy.should_retry_status(418));
        assert!(!policy.should_retry_status(503));
    }

    #[test]
    fn test_error_kind_flags() {
        let policy = RetryPolicy::default();
        assert!(policy.retry_on_connection_error);
        assert!(policy.retry_on_timeout);

        let policy = policy
            .with_retry_on_connection_error(false)
            .with_retry_on_timeout(false);
        assert!(!policy.retry_on_connection_error);
        assert!(!policy.retry_on_timeout);
    }
}