use std::future::Future;
use std::time::Duration;
use rand::Rng;
use tracing::debug;
use crate::error::{Result, SeerError};
/// Settings that control how a failed operation is retried with
/// exponential backoff.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
/// Total number of attempts, counting the very first try.
pub max_attempts: usize,
/// Delay used before the first retry (attempt index 0).
pub initial_delay: Duration,
/// Ceiling applied to the exponentially grown delay.
pub max_delay: Duration,
/// Factor the delay is multiplied by for each further attempt.
pub multiplier: f64,
/// When `true`, grown delays are scaled by a random factor in `[0, 1)`.
pub jitter: bool,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 3,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(5),
multiplier: 2.0,
jitter: true,
}
}
}
impl RetryPolicy {
    /// Creates a policy with the default settings (see the `Default` impl).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the total number of attempts. Values below 1 are raised to 1 so
    /// the operation is always tried at least once.
    pub fn with_max_attempts(mut self, attempts: usize) -> Self {
        self.max_attempts = attempts.max(1);
        self
    }

    /// Sets the delay used before the first retry.
    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Sets the upper bound for the delay between attempts.
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Sets the backoff growth factor. Values below 1.0 are raised to 1.0 so
    /// the delay never shrinks from one attempt to the next.
    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
        self.multiplier = multiplier.max(1.0);
        self
    }

    /// Enables or disables random jitter on the grown delays.
    pub fn with_jitter(mut self, jitter: bool) -> Self {
        self.jitter = jitter;
        self
    }

    /// Creates a policy that tries the operation exactly once; every other
    /// setting keeps its default value.
    pub fn no_retry() -> Self {
        Self {
            max_attempts: 1,
            ..Self::default()
        }
    }

    /// Computes how long to wait after the failed attempt with zero-based
    /// index `attempt`.
    ///
    /// * Attempt 0 yields `initial_delay` without jitter.
    /// * Later attempts yield `initial_delay * multiplier^attempt`, computed
    ///   in whole milliseconds, optionally scaled by a random factor in
    ///   `[0, 1)` when `jitter` is enabled.
    ///
    /// The result never exceeds `max_delay`, for any attempt index.
    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
        if attempt == 0 {
            // The first delay must honour the configured ceiling too;
            // otherwise a policy with `initial_delay > max_delay` would
            // sleep longer than the caller allowed.
            return self.initial_delay.min(self.max_delay);
        }

        // Clamp the exponent so `powi` cannot blow up for huge attempt
        // numbers; the cap below bounds the result anyway.
        let exponent = attempt.min(20) as i32;
        let base_delay = self.initial_delay.as_millis() as f64 * self.multiplier.powi(exponent);
        let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);

        let final_delay = if self.jitter {
            let mut rng = rand::thread_rng();
            let jitter_factor = rng.gen_range(0.0..1.0);
            capped_delay * jitter_factor
        } else {
            capped_delay
        };

        // Float-to-int `as` casts saturate, so this cannot wrap or panic.
        Duration::from_millis(final_delay as u64)
    }
}
/// Decides whether a failed operation is worth another attempt.
///
/// Implementors must be `Send + Sync`.
pub trait RetryClassifier: Send + Sync {
/// Returns `true` when `error` should lead to a retry.
fn is_retryable(&self, error: &SeerError) -> bool;
}
/// Stateless classifier for network-style failures; its retry rules live in
/// its `RetryClassifier` implementation.
#[derive(Debug, Clone, Default)]
pub struct NetworkRetryClassifier;
impl NetworkRetryClassifier {
pub fn new() -> Self {
Self
}
}
impl RetryClassifier for NetworkRetryClassifier {
    /// Classifies `error` as transient (retryable) or permanent.
    ///
    /// Timeouts, WHOIS connection failures and rate limiting are always
    /// retried. Message-carrying variants are retried only when their
    /// lowercased text contains one of a fixed set of markers. A
    /// `RetryExhausted` wrapper is classified by the error it wraps.
    fn is_retryable(&self, error: &SeerError) -> bool {
        /// Case-insensitive check: does `message` contain any of `needles`?
        fn mentions_any(message: &str, needles: &[&str]) -> bool {
            let lowered = message.to_lowercase();
            needles.iter().any(|needle| lowered.contains(*needle))
        }

        match error {
            // Always transient.
            SeerError::Timeout(_)
            | SeerError::WhoisConnectionFailed(_)
            | SeerError::RateLimited(_) => true,

            SeerError::ReqwestError(inner) => is_transient_reqwest_error(inner),

            // Decided by markers in the message text.
            SeerError::WhoisError(text) => {
                mentions_any(text, &["connection", "timeout", "refused", "reset"])
            }
            SeerError::RdapError(text) => {
                mentions_any(text, &["status 5", "status 429", "timeout"])
            }
            SeerError::RdapBootstrapError(text) => mentions_any(text, &["timeout", "connection"]),
            SeerError::DnsError(text) => mentions_any(text, &["timeout", "temporary"]),
            SeerError::HttpError(text) => {
                mentions_any(text, &["timeout", "connection", "status 5", "status 429"])
            }

            // A wrapper is as retryable as the failure it carries.
            SeerError::RetryExhausted { last_error, .. } => self.is_retryable(last_error),

            // Never retried. Every variant is listed so that adding a new
            // one forces a decision here.
            SeerError::InvalidDomain(_)
            | SeerError::DomainNotAllowed { .. }
            | SeerError::InvalidIpAddress(_)
            | SeerError::InvalidRecordType(_)
            | SeerError::WhoisServerNotFound(_)
            | SeerError::JsonError(_)
            | SeerError::CertificateError(_)
            | SeerError::SslError(_)
            | SeerError::DnsResolverError(_)
            | SeerError::BulkOperationError { .. }
            | SeerError::LookupFailed { .. }
            | SeerError::ConfigError(_)
            | SeerError::InvalidInput(_)
            | SeerError::Other(_) => false,
        }
    }
}
/// Reports whether a `reqwest` failure looks transient.
///
/// Connect and timeout failures are transient. If the error carries an HTTP
/// status, only 429 and 5xx count as transient. Without a status, request
/// and body errors are treated as permanent and anything else as transient.
fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
    if error.is_connect() || error.is_timeout() {
        return true;
    }

    match error.status() {
        Some(code) => code.as_u16() == 429 || code.is_server_error(),
        None => !(error.is_request() || error.is_body()),
    }
}
/// Runs fallible async operations under a [`RetryPolicy`], using the
/// classifier `C` to decide which errors are retried.
#[derive(Debug, Clone)]
pub struct RetryExecutor<C: RetryClassifier> {
// Backoff settings: attempt count and delay schedule.
policy: RetryPolicy,
// Decides whether a given error should trigger another attempt.
classifier: C,
}
impl RetryExecutor<NetworkRetryClassifier> {
    /// Creates an executor that pairs `policy` with a
    /// [`NetworkRetryClassifier`].
    pub fn new(policy: RetryPolicy) -> Self {
        let classifier = NetworkRetryClassifier::new();
        Self { policy, classifier }
    }
}
impl<C: RetryClassifier> RetryExecutor<C> {
    /// Creates an executor that uses a caller-supplied classifier.
    pub fn with_classifier(policy: RetryPolicy, classifier: C) -> Self {
        Self { policy, classifier }
    }

    /// Runs `operation`, retrying it according to the policy while the
    /// classifier reports the error as retryable.
    ///
    /// The operation is always invoked at least once, even if the policy's
    /// public `max_attempts` field was set to 0 by hand. Between attempts
    /// the task sleeps for `RetryPolicy::delay_for_attempt`.
    ///
    /// # Errors
    ///
    /// * A failure on the very first attempt that is not retried (because it
    ///   is non-retryable or only one attempt is allowed) is returned as-is.
    /// * Any failure that ends the loop after at least one retry is wrapped
    ///   in `SeerError::RetryExhausted` together with the number of attempts
    ///   made. This includes a non-retryable error hit on a later attempt.
    pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T>>,
    {
        // `max_attempts` is a public field, so it may be 0 despite the
        // builder clamping it. Treat that as "try once" instead of failing
        // without ever calling the operation.
        let max_attempts = self.policy.max_attempts.max(1);
        let mut attempt = 0;

        loop {
            let e = match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => e,
            };

            let is_retryable = self.classifier.is_retryable(&e);
            let is_final_attempt = attempt + 1 >= max_attempts;

            if !is_retryable || is_final_attempt {
                if attempt == 0 {
                    // Nothing was retried, so hand back the original error.
                    return Err(e);
                }
                debug!(
                    attempt = attempt + 1,
                    max_attempts = max_attempts,
                    error = %e,
                    "Operation failed after retries"
                );
                return Err(SeerError::RetryExhausted {
                    attempts: attempt + 1,
                    last_error: Box::new(e),
                });
            }

            let delay = self.policy.delay_for_attempt(attempt);
            debug!(
                attempt = attempt + 1,
                max_attempts = max_attempts,
                delay_ms = delay.as_millis(),
                error = %e,
                "Retrying after transient error"
            );
            tokio::time::sleep(delay).await;
            attempt += 1;
        }
    }

    /// Runs `operation` exactly once and returns its result unchanged; the
    /// policy and the classifier are not consulted.
    pub async fn execute_once<F, Fut, T>(&self, operation: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T>>,
    {
        operation().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// The default policy exposes the stock values.
    #[test]
    fn test_retry_policy_defaults() {
        let RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            multiplier,
            jitter,
        } = RetryPolicy::default();

        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay, Duration::from_millis(100));
        assert_eq!(max_delay, Duration::from_secs(5));
        assert_eq!(multiplier, 2.0);
        assert!(jitter);
    }

    /// Every builder method stores the value it was given.
    #[test]
    fn test_retry_policy_builder() {
        let built = RetryPolicy::new()
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(200))
            .with_max_delay(Duration::from_secs(10))
            .with_multiplier(3.0)
            .with_jitter(false);

        assert_eq!(built.max_attempts, 5);
        assert_eq!(built.initial_delay, Duration::from_millis(200));
        assert_eq!(built.max_delay, Duration::from_secs(10));
        assert_eq!(built.multiplier, 3.0);
        assert!(!built.jitter);
    }

    /// Without jitter the delay doubles on every attempt.
    #[test]
    fn test_delay_calculation_no_jitter() {
        let backoff = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        let expected_ms: [u64; 4] = [100, 200, 400, 800];
        for (attempt, ms) in expected_ms.iter().enumerate() {
            assert_eq!(backoff.delay_for_attempt(attempt), Duration::from_millis(*ms));
        }
    }

    /// A delay that would outgrow the ceiling is clamped to it.
    #[test]
    fn test_delay_capped_at_max() {
        let ceiling = Duration::from_secs(5);
        let backoff = RetryPolicy::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_multiplier(10.0)
            .with_max_delay(ceiling)
            .with_jitter(false);

        assert_eq!(backoff.delay_for_attempt(2), ceiling);
    }

    #[test]
    fn test_classifier_timeout_is_retryable() {
        let failure = SeerError::Timeout("test".to_string());
        assert!(NetworkRetryClassifier::new().is_retryable(&failure));
    }

    #[test]
    fn test_classifier_invalid_domain_not_retryable() {
        let failure = SeerError::InvalidDomain("test".to_string());
        assert!(!NetworkRetryClassifier::new().is_retryable(&failure));
    }

    #[test]
    fn test_classifier_server_not_found_not_retryable() {
        let failure = SeerError::WhoisServerNotFound("test".to_string());
        assert!(!NetworkRetryClassifier::new().is_retryable(&failure));
    }

    #[test]
    fn test_classifier_rate_limited_is_retryable() {
        let failure = SeerError::RateLimited("test".to_string());
        assert!(NetworkRetryClassifier::new().is_retryable(&failure));
    }

    /// A successful first call means no retry happens.
    #[tokio::test]
    async fn test_executor_success_on_first_try() {
        let executor = RetryExecutor::new(RetryPolicy::new().with_max_attempts(3));
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<&str> = executor
            .execute(|| {
                let counter = Arc::clone(&calls);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok("success")
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by a success take three calls.
    #[tokio::test]
    async fn test_executor_retries_on_transient_error() {
        let executor = RetryExecutor::new(
            RetryPolicy::new()
                .with_max_attempts(3)
                .with_initial_delay(Duration::from_millis(1))
                .with_jitter(false),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<&str> = executor
            .execute(|| {
                let counter = Arc::clone(&calls);
                async move {
                    let earlier_calls = counter.fetch_add(1, Ordering::SeqCst);
                    if earlier_calls < 2 {
                        Err(SeerError::Timeout("test timeout".to_string()))
                    } else {
                        Ok("success after retries")
                    }
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success after retries");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A permanent error stops the executor after a single call.
    #[tokio::test]
    async fn test_executor_no_retry_on_non_retryable_error() {
        let executor = RetryExecutor::new(
            RetryPolicy::new()
                .with_max_attempts(3)
                .with_initial_delay(Duration::from_millis(1)),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<&str> = executor
            .execute(|| {
                let counter = Arc::clone(&calls);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::InvalidDomain("bad.".to_string()))
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// An operation that always fails uses every attempt and reports
    /// `RetryExhausted` with the attempt count.
    #[tokio::test]
    async fn test_executor_exhausts_retries() {
        let executor = RetryExecutor::new(
            RetryPolicy::new()
                .with_max_attempts(3)
                .with_initial_delay(Duration::from_millis(1))
                .with_jitter(false),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<&str> = executor
            .execute(|| {
                let counter = Arc::clone(&calls);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::Timeout("always fails".to_string()))
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);

        let failure = outcome.unwrap_err();
        if let SeerError::RetryExhausted { attempts, .. } = failure {
            assert_eq!(attempts, 3);
        } else {
            panic!("Expected RetryExhausted, got {:?}", failure);
        }
    }

    #[test]
    fn test_no_retry_policy() {
        assert_eq!(RetryPolicy::no_retry().max_attempts, 1);
    }

    /// Very large attempt numbers must stay within the configured ceiling.
    #[test]
    fn test_delay_overflow_protection() {
        let ceiling = Duration::from_secs(5);
        let backoff = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(ceiling)
            .with_jitter(false);

        let huge_attempts: [usize; 3] = [50, 100, 1000];
        for attempt in huge_attempts.iter() {
            assert!(backoff.delay_for_attempt(*attempt) <= ceiling);
        }
    }

    /// A `RetryExhausted` wrapper is exactly as retryable as its inner error.
    #[test]
    fn retry_exhausted_is_retryable_if_inner_is() {
        let classifier = NetworkRetryClassifier::new();

        let wrapped_retryable = SeerError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(SeerError::Timeout("inner timed out".to_string())),
        };
        assert!(
            classifier.is_retryable(&wrapped_retryable),
            "RetryExhausted wrapping a retryable Timeout should be retryable",
        );

        let wrapped_non_retryable = SeerError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(SeerError::InvalidDomain("bad.".to_string())),
        };
        assert!(
            !classifier.is_retryable(&wrapped_non_retryable),
            "RetryExhausted wrapping a non-retryable InvalidDomain must not be retryable",
        );
    }
}