use crate::error::Error;
use reqwest::header::HeaderMap;
use std::time::{Duration, Instant, SystemTime};
use tokio::time::sleep;
/// Tuning knobs for the exponential-backoff retry helpers in this module.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of attempts, counting the very first try.
    pub max_attempts: usize,
    /// Base delay in milliseconds; multiplied by `backoff_multiplier` once per attempt.
    pub initial_delay_ms: u64,
    /// Upper bound, in milliseconds, used to clamp the computed backoff delay.
    pub max_delay_ms: u64,
    /// Growth factor applied to the delay for each additional attempt (2.0 doubles it).
    pub backoff_multiplier: f64,
    /// Whether `calculate_retry_delay` adds random jitter to each delay.
    pub jitter: bool,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 100 ms and doubling up to 5 s, with jitter enabled.
    fn default() -> Self {
        Self {
            jitter: true,
            backoff_multiplier: 2.0,
            max_delay_ms: 5_000,
            initial_delay_ms: 100,
            max_attempts: 3,
        }
    }
}
/// Timeouts applied by `create_resilient_client` when building an HTTP client.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Connection-establishment timeout in milliseconds (`ClientBuilder::connect_timeout`).
    pub connect_timeout_ms: u64,
    /// Total per-request timeout in milliseconds (`ClientBuilder::timeout`).
    pub request_timeout_ms: u64,
}

impl Default for TimeoutConfig {
    /// 10 seconds to connect, 30 seconds per request.
    fn default() -> Self {
        let connect_timeout_ms = 10 * 1_000;
        let request_timeout_ms = 30 * 1_000;
        Self {
            connect_timeout_ms,
            request_timeout_ms,
        }
    }
}
/// Describes a single retry; `execute_with_retry_tracking` records one per retried failure.
#[derive(Debug, Clone)]
pub struct RetryInfo {
    /// Attempt number that failed (1-based when produced by `execute_with_retry_tracking`).
    pub attempt: u32,
    /// HTTP status code of the failed attempt, if the error carried one.
    pub status_code: Option<u16>,
    /// Delay, in milliseconds, waited before the next attempt.
    pub delay_ms: u64,
    /// Human-readable explanation of the failure that caused the retry.
    pub reason: String,
}

impl RetryInfo {
    /// Builds a record; `reason` accepts anything convertible into a `String`.
    #[must_use]
    pub fn new(
        attempt: u32,
        status_code: Option<u16>,
        delay_ms: u64,
        reason: impl Into<String>,
    ) -> Self {
        let reason: String = reason.into();
        Self {
            attempt,
            status_code,
            delay_ms,
            reason,
        }
    }
}
/// Outcome of `execute_with_retry_tracking`: the final result plus a log of the retries made.
#[derive(Debug)]
pub struct RetryResult<T> {
/// `Ok` with the operation's value, or the error that ended the retry loop.
pub result: Result<T, Error>,
/// One entry per failed attempt that was followed by a retry; a failure that ends the
/// loop (non-retryable, or the last allowed attempt) is not recorded here.
pub retry_history: Vec<RetryInfo>,
/// Number of attempts reported for this run (the operation invocations made, or the
/// configured attempt limit when retries were exhausted).
pub total_attempts: u32,
}
#[must_use]
pub fn parse_retry_after_header(headers: &HeaderMap) -> Option<Duration> {
let retry_after = headers.get("retry-after")?;
let value = retry_after.to_str().ok()?;
parse_retry_after_value(value)
}
/// Parses a `Retry-After` value, either delay-seconds (`"120"`) or an HTTP-date.
///
/// Surrounding whitespace is ignored. A valid HTTP-date that is already in the past
/// (for example because of clock skew) yields `Some(Duration::ZERO)`, meaning
/// "retry now". That result is a valid hint, which is different from "no hint".
///
/// Returns `None` only when the value is neither an integer nor a parseable HTTP-date.
#[must_use]
pub fn parse_retry_after_value(value: &str) -> Option<Duration> {
    let value = value.trim();
    if let Ok(seconds) = value.parse::<u64>() {
        return Some(Duration::from_secs(seconds));
    }
    let retry_at = httpdate::parse_http_date(value).ok()?;
    // `duration_since` errors when `retry_at` is earlier than now; clamp that to zero
    // instead of discarding a well-formed header.
    Some(
        retry_at
            .duration_since(SystemTime::now())
            .unwrap_or(Duration::ZERO),
    )
}
/// Combines the exponential backoff for `attempt` with an optional server `Retry-After` hint.
///
/// The larger of the two delays is used, so the client never retries sooner than the
/// server asked. The result is always capped at `config.max_delay_ms`, whether or not a
/// hint is present.
#[must_use]
pub fn calculate_retry_delay_with_header(
    config: &RetryConfig,
    attempt: usize,
    retry_after: Option<Duration>,
) -> Duration {
    let backoff = calculate_retry_delay(config, attempt);
    // No hint behaves like a zero hint: `max` then leaves the backoff unchanged.
    let server_hint = retry_after.unwrap_or(Duration::ZERO);
    let ceiling = Duration::from_millis(config.max_delay_ms);
    backoff.max(server_hint).min(ceiling)
}
/// Decides whether a failed `reqwest` call is worth retrying.
///
/// - Builder errors (e.g. an invalid URL or header) and redirect-policy errors are
///   deterministic, so they are never retried.
/// - Connection failures and timeouts are always retried.
/// - Errors carrying an HTTP status defer to `is_retryable_status`.
/// - Any other error without a status is treated as retryable.
#[must_use]
pub fn is_retryable_error(error: &reqwest::Error) -> bool {
    // Repeating the same malformed request or redirect loop cannot succeed; retrying only
    // adds backoff sleeps before the inevitable failure.
    if error.is_builder() || error.is_redirect() {
        return false;
    }
    if error.is_connect() || error.is_timeout() {
        return true;
    }
    error
        .status()
        .is_none_or(|status| is_retryable_status(status.as_u16()))
}
/// Returns `true` for HTTP status codes that are reasonable to retry.
///
/// Retryable codes are 408 (Request Timeout), 429 (Too Many Requests) and every 5xx except
/// 501 (Not Implemented) and 505 (HTTP Version Not Supported). Those two describe a
/// permanent mismatch that a retry will not fix.
#[must_use]
pub const fn is_retryable_status(status: u16) -> bool {
    matches!(status, 408 | 429 | 500 | 502..=504 | 506..=599)
}
/// Computes the exponential backoff delay before retrying after the zero-based `attempt`.
///
/// The delay is `initial_delay_ms * backoff_multiplier^attempt`. The exponent is clamped
/// at 30 to keep the float bounded. When `config.jitter` is set, the delay is scaled by a
/// random factor in `[1.0, 1.25)`. The final value never exceeds `config.max_delay_ms`,
/// because the cap is applied after jitter.
#[must_use]
// The casts below are intentional: millisecond counts go through f64 for `powi`, the
// exponent is clamped to 30 before narrowing to i32, and f64 -> u64 saturates (negative
// or NaN becomes 0).
#[allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::cast_possible_wrap
)]
pub fn calculate_retry_delay(config: &RetryConfig, attempt: usize) -> Duration {
    let max_delay_ms = config.max_delay_ms as f64;
    let exponent = attempt.min(30) as i32;
    let backoff_ms = config.initial_delay_ms as f64 * config.backoff_multiplier.powi(exponent);
    let jittered_ms = if config.jitter {
        backoff_ms * fastrand::f64().mul_add(0.25, 1.0)
    } else {
        backoff_ms
    };
    // Cap after jitter so `max_delay_ms` is a hard ceiling, not a soft one.
    Duration::from_millis(jittered_ms.min(max_delay_ms) as u64)
}
pub async fn execute_with_retry<F, Fut, T>(
config: &RetryConfig,
_operation_name: &str,
mut operation: F,
) -> Result<T, Error>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T, reqwest::Error>>,
{
let _start_time = Instant::now();
let mut last_error = None;
for attempt in 0..config.max_attempts {
match operation().await {
Ok(result) => {
return Ok(result);
}
Err(error) => {
let is_last_attempt = attempt + 1 >= config.max_attempts;
let is_retryable = is_retryable_error(&error);
if !is_retryable {
let error_message = error.to_string();
return Err(Error::transient_network_error(error_message, false));
}
if is_last_attempt {
let error_message = error.to_string();
last_error = Some(error_message);
break;
}
let delay = calculate_retry_delay(config, attempt);
sleep(delay).await;
last_error = Some(error.to_string());
}
}
}
Err(Error::retry_limit_exceeded(
config.max_attempts.try_into().unwrap_or(u32::MAX),
last_error.unwrap_or_else(|| "Unknown error".to_string()),
))
}
#[allow(clippy::cast_possible_truncation)]
pub async fn execute_with_retry_tracking<F, Fut, T>(
config: &RetryConfig,
operation_name: &str,
mut operation: F,
) -> RetryResult<T>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T, reqwest::Error>>,
{
let mut retry_history = Vec::new();
let mut last_error = None;
for attempt in 0..config.max_attempts {
match operation().await {
Ok(result) => {
return RetryResult {
result: Ok(result),
retry_history,
total_attempts: (attempt + 1) as u32,
};
}
Err(error) => {
let is_last_attempt = attempt + 1 >= config.max_attempts;
let is_retryable = is_retryable_error(&error);
let status_code = error.status().map(|s| s.as_u16());
let error_message = error.to_string();
if !is_retryable {
return RetryResult {
result: Err(Error::transient_network_error(error_message, false)),
retry_history,
total_attempts: (attempt + 1) as u32,
};
}
if is_last_attempt {
last_error = Some(error_message);
break;
}
let delay = calculate_retry_delay(config, attempt);
let delay_ms = delay.as_millis() as u64;
retry_history.push(RetryInfo::new(
(attempt + 1) as u32,
status_code,
delay_ms,
format!("{operation_name}: {error_message}"),
));
sleep(delay).await;
last_error = Some(error_message);
}
}
}
RetryResult {
result: Err(Error::retry_limit_exceeded(
config.max_attempts.try_into().unwrap_or(u32::MAX),
last_error.unwrap_or_else(|| "Unknown error".to_string()),
)),
retry_history,
total_attempts: config.max_attempts as u32,
}
}
pub fn create_resilient_client(timeout_config: &TimeoutConfig) -> Result<reqwest::Client, Error> {
reqwest::Client::builder()
.connect_timeout(Duration::from_millis(timeout_config.connect_timeout_ms))
.timeout(Duration::from_millis(timeout_config.request_timeout_ms))
.build()
.map_err(|e| {
Error::network_request_failed(format!("Failed to create resilient HTTP client: {e}"))
})
}
// Tests for this module: backoff arithmetic, Retry-After parsing, bookkeeping types and
// status classification, plus wiremock-backed integration tests for the retry executors.
#[cfg(test)]
mod tests {
use super::*;
// --- Backoff delay computation ---
#[test]
fn test_calculate_retry_delay() {
let config = RetryConfig {
max_attempts: 5,
initial_delay_ms: 100,
max_delay_ms: 1000,
backoff_multiplier: 2.0,
jitter: false,
};
let delay1 = calculate_retry_delay(&config, 0);
let delay2 = calculate_retry_delay(&config, 1);
let delay3 = calculate_retry_delay(&config, 2);
assert_eq!(delay1.as_millis(), 100);
assert_eq!(delay2.as_millis(), 200);
assert_eq!(delay3.as_millis(), 400);
// 100 * 2^10 far exceeds max_delay_ms, so the result is clamped.
let delay_max = calculate_retry_delay(&config, 10);
assert_eq!(delay_max.as_millis(), 1000);
}
#[test]
fn test_calculate_retry_delay_with_jitter() {
let config = RetryConfig {
max_attempts: 3,
initial_delay_ms: 100,
max_delay_ms: 1000,
backoff_multiplier: 2.0,
jitter: true,
};
// Jitter only scales the delay upward, by at most 25%.
let delay1 = calculate_retry_delay(&config, 0);
let delay2 = calculate_retry_delay(&config, 0);
assert!(delay1.as_millis() >= 100 && delay1.as_millis() <= 125);
assert!(delay2.as_millis() >= 100 && delay2.as_millis() <= 125);
}
#[test]
fn test_default_configs() {
let retry_config = RetryConfig::default();
assert_eq!(retry_config.max_attempts, 3);
assert_eq!(retry_config.initial_delay_ms, 100);
let timeout_config = TimeoutConfig::default();
assert_eq!(timeout_config.connect_timeout_ms, 10_000);
assert_eq!(timeout_config.request_timeout_ms, 30_000);
}
// --- Retry-After header parsing ---
#[test]
fn test_parse_retry_after_header_seconds() {
let mut headers = HeaderMap::new();
headers.insert("retry-after", "120".parse().unwrap());
let duration = parse_retry_after_header(&headers);
assert_eq!(duration, Some(Duration::from_secs(120)));
}
#[test]
fn test_parse_retry_after_header_zero() {
let mut headers = HeaderMap::new();
headers.insert("retry-after", "0".parse().unwrap());
let duration = parse_retry_after_header(&headers);
assert_eq!(duration, Some(Duration::from_secs(0)));
}
#[test]
fn test_parse_retry_after_header_missing() {
let headers = HeaderMap::new();
let duration = parse_retry_after_header(&headers);
assert_eq!(duration, None);
}
#[test]
fn test_parse_retry_after_header_invalid() {
let mut headers = HeaderMap::new();
headers.insert("retry-after", "not-a-number".parse().unwrap());
let duration = parse_retry_after_header(&headers);
assert_eq!(duration, None);
}
// --- Combining computed backoff with a server Retry-After hint ---
#[test]
fn test_calculate_retry_delay_with_header_none() {
let config = RetryConfig {
max_attempts: 3,
initial_delay_ms: 100,
max_delay_ms: 5000,
backoff_multiplier: 2.0,
jitter: false,
};
let delay = calculate_retry_delay_with_header(&config, 0, None);
assert_eq!(delay.as_millis(), 100);
}
#[test]
fn test_calculate_retry_delay_with_header_uses_server_delay_when_larger() {
let config = RetryConfig {
max_attempts: 3,
initial_delay_ms: 100,
max_delay_ms: 5000,
backoff_multiplier: 2.0,
jitter: false,
};
let retry_after = Some(Duration::from_secs(3));
let delay = calculate_retry_delay_with_header(&config, 0, retry_after);
assert_eq!(delay.as_secs(), 3);
}
#[test]
fn test_calculate_retry_delay_with_header_uses_calculated_when_larger() {
let config = RetryConfig {
max_attempts: 3,
initial_delay_ms: 5000,
max_delay_ms: 30_000,
backoff_multiplier: 2.0,
jitter: false,
};
let retry_after = Some(Duration::from_secs(1));
let delay = calculate_retry_delay_with_header(&config, 0, retry_after);
assert_eq!(delay.as_millis(), 5000);
}
#[test]
fn test_calculate_retry_delay_with_header_caps_at_max() {
let config = RetryConfig {
max_attempts: 3,
initial_delay_ms: 100,
max_delay_ms: 5000,
backoff_multiplier: 2.0,
jitter: false,
};
// A 60 s server hint is clamped to max_delay_ms (5 s).
let retry_after = Some(Duration::from_secs(60));
let delay = calculate_retry_delay_with_header(&config, 0, retry_after);
assert_eq!(delay.as_millis(), 5000);
}
// --- RetryInfo / RetryResult bookkeeping types ---
#[test]
fn test_retry_info_new() {
let info = RetryInfo::new(1, Some(429), 500, "Rate limited");
assert_eq!(info.attempt, 1);
assert_eq!(info.status_code, Some(429));
assert_eq!(info.delay_ms, 500);
assert_eq!(info.reason, "Rate limited");
}
#[test]
fn test_retry_info_without_status_code() {
let info = RetryInfo::new(2, None, 1000, "Connection refused");
assert_eq!(info.attempt, 2);
assert_eq!(info.status_code, None);
assert_eq!(info.delay_ms, 1000);
assert_eq!(info.reason, "Connection refused");
}
#[test]
fn test_retry_result_success_no_retries() {
let result: RetryResult<i32> = RetryResult {
result: Ok(42),
retry_history: vec![],
total_attempts: 1,
};
assert!(result.result.is_ok());
assert!(result.retry_history.is_empty());
assert_eq!(result.total_attempts, 1);
}
#[test]
fn test_retry_result_success_after_retries() {
let result: RetryResult<i32> = RetryResult {
result: Ok(42),
retry_history: vec![RetryInfo::new(1, Some(503), 100, "Service unavailable")],
total_attempts: 2,
};
assert!(result.result.is_ok());
assert_eq!(result.retry_history.len(), 1);
assert_eq!(result.total_attempts, 2);
}
// --- HTTP status classification ---
#[test]
fn test_is_retryable_status_408_request_timeout() {
assert!(is_retryable_status(408));
}
#[test]
fn test_is_retryable_status_429_too_many_requests() {
assert!(is_retryable_status(429));
}
#[test]
fn test_is_retryable_status_500_internal_server_error() {
assert!(is_retryable_status(500));
}
#[test]
fn test_is_retryable_status_502_bad_gateway() {
assert!(is_retryable_status(502));
}
#[test]
fn test_is_retryable_status_503_service_unavailable() {
assert!(is_retryable_status(503));
}
#[test]
fn test_is_retryable_status_504_gateway_timeout() {
assert!(is_retryable_status(504));
}
#[test]
fn test_is_retryable_status_501_not_implemented_not_retryable() {
assert!(!is_retryable_status(501));
}
#[test]
fn test_is_retryable_status_505_http_version_not_supported_not_retryable() {
assert!(!is_retryable_status(505));
}
#[test]
fn test_is_retryable_status_4xx_not_retryable() {
assert!(!is_retryable_status(400)); assert!(!is_retryable_status(401)); assert!(!is_retryable_status(403)); assert!(!is_retryable_status(404)); assert!(!is_retryable_status(405)); assert!(!is_retryable_status(422)); }
#[test]
fn test_is_retryable_status_2xx_not_retryable() {
assert!(!is_retryable_status(200));
assert!(!is_retryable_status(201));
assert!(!is_retryable_status(204));
}
#[test]
fn test_is_retryable_status_3xx_not_retryable() {
assert!(!is_retryable_status(301));
assert!(!is_retryable_status(302));
assert!(!is_retryable_status(304));
}
// --- Integration tests: drive the executors against a local wiremock server ---
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
// Responder that answers 503 for the first `fail_for` requests, then 200 with body "done".
// `count` is shared so tests can inspect how many requests were served.
struct FailThenSucceed {
fail_for: usize,
count: Arc<AtomicUsize>,
}
impl wiremock::Respond for FailThenSucceed {
fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
let n = self.count.fetch_add(1, Ordering::SeqCst);
if n < self.fail_for {
ResponseTemplate::new(503)
} else {
ResponseTemplate::new(200).set_body_string("done")
}
}
}
// Deterministic config with tiny delays (1 ms initial, 10 ms cap) so retry tests stay fast.
fn no_jitter_config(max_attempts: usize) -> RetryConfig {
RetryConfig {
max_attempts,
initial_delay_ms: 1,
max_delay_ms: 10,
backoff_multiplier: 2.0,
jitter: false,
}
}
// GETs `url`, turning non-2xx responses into errors via `error_for_status`, and returns the body.
async fn make_request(client: &reqwest::Client, url: &str) -> Result<String, reqwest::Error> {
let resp = client.get(url).send().await?;
let resp = resp.error_for_status()?;
let body = resp.text().await?;
Ok(body)
}
#[tokio::test]
async fn test_execute_with_retry_immediate_success() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/ok"))
.respond_with(ResponseTemplate::new(200).set_body_string("ok"))
.expect(1)
.mount(&server)
.await;
let client = reqwest::Client::new();
let url = format!("{}/ok", server.uri());
let config = no_jitter_config(3);
let result = execute_with_retry(&config, "test", || {
let client = client.clone();
let url = url.clone();
async move { make_request(&client, &url).await }
})
.await;
assert!(result.is_ok(), "should succeed on first attempt");
assert_eq!(result.unwrap(), "ok");
}
#[tokio::test]
async fn test_execute_with_retry_non_retryable_short_circuits() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/bad"))
.respond_with(ResponseTemplate::new(400))
.mount(&server)
.await;
let client = reqwest::Client::new();
let url = format!("{}/bad", server.uri());
let config = no_jitter_config(3);
// Counts invocations of the operation closure itself.
let call_count = Arc::new(AtomicUsize::new(0));
let cc = call_count.clone();
let result = execute_with_retry(&config, "test", || {
let client = client.clone();
let url = url.clone();
let cc = cc.clone();
async move {
cc.fetch_add(1, Ordering::SeqCst);
make_request(&client, &url).await
}
})
.await;
assert!(result.is_err(), "non-retryable error must propagate");
assert_eq!(call_count.load(Ordering::SeqCst), 1, "must not retry a 400");
}
#[tokio::test]
async fn test_execute_with_retry_succeeds_after_transient_failures() {
let server = MockServer::start().await;
let count = Arc::new(AtomicUsize::new(0));
Mock::given(method("GET"))
.and(path("/flaky"))
.respond_with(FailThenSucceed {
fail_for: 2,
count: count.clone(),
})
.expect(3)
.mount(&server)
.await;
let client = reqwest::Client::new();
let url = format!("{}/flaky", server.uri());
let config = no_jitter_config(3);
let result = execute_with_retry(&config, "test", || {
let client = client.clone();
let url = url.clone();
async move { make_request(&client, &url).await }
})
.await;
assert!(result.is_ok(), "should succeed after two transient 503s");
assert_eq!(count.load(Ordering::SeqCst), 3, "must have made 3 calls");
}
#[tokio::test]
async fn test_execute_with_retry_exhaustion_returns_error() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/always-fail"))
.respond_with(ResponseTemplate::new(503))
.expect(3)
.mount(&server)
.await;
let client = reqwest::Client::new();
let url = format!("{}/always-fail", server.uri());
let config = no_jitter_config(3);
let result = execute_with_retry(&config, "test", || {
let client = client.clone();
let url = url.clone();
async move { make_request(&client, &url).await }
})
.await;
assert!(result.is_err(), "all attempts exhausted must return error");
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("Retry limit exceeded"),
"error must mention retry exhaustion, got: {msg}"
);
}
#[tokio::test]
async fn test_execute_with_retry_tracking_immediate_success() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/ok"))
.respond_with(ResponseTemplate::new(200).set_body_string("ok"))
.expect(1)
.mount(&server)
.await;
let client = reqwest::Client::new();
let url = format!("{}/ok", server.uri());
let config = no_jitter_config(3);
let ret = execute_with_retry_tracking(&config, "test", || {
let client = client.clone();
let url = url.clone();
async move { make_request(&client, &url).await }
})
.await;
assert!(ret.result.is_ok());
assert_eq!(ret.total_attempts, 1);
assert!(
ret.retry_history.is_empty(),
"no retries on immediate success"
);
}
#[tokio::test]
async fn test_execute_with_retry_tracking_non_retryable_short_circuits() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/bad"))
.respond_with(ResponseTemplate::new(400))
.expect(1)
.mount(&server)
.await;
let client = reqwest::Client::new();
let url = format!("{}/bad", server.uri());
let config = no_jitter_config(3);
let ret = execute_with_retry_tracking(&config, "test", || {
let client = client.clone();
let url = url.clone();
async move { make_request(&client, &url).await }
})
.await;
assert!(ret.result.is_err());
assert_eq!(ret.total_attempts, 1, "must stop after the first attempt");
assert!(
ret.retry_history.is_empty(),
"non-retryable error must not populate retry history"
);
}
#[tokio::test]
async fn test_execute_with_retry_tracking_records_history() {
let server = MockServer::start().await;
let count = Arc::new(AtomicUsize::new(0));
Mock::given(method("GET"))
.and(path("/flaky"))
.respond_with(FailThenSucceed {
fail_for: 2,
count: count.clone(),
})
.expect(3)
.mount(&server)
.await;
let client = reqwest::Client::new();
let url = format!("{}/flaky", server.uri());
let config = no_jitter_config(3);
let ret = execute_with_retry_tracking(&config, "test-op", || {
let client = client.clone();
let url = url.clone();
async move { make_request(&client, &url).await }
})
.await;
// Two failures were each followed by a retry, so two history entries (1-based attempts).
assert!(ret.result.is_ok(), "should eventually succeed");
assert_eq!(ret.total_attempts, 3);
assert_eq!(ret.retry_history.len(), 2);
assert_eq!(ret.retry_history[0].attempt, 1);
assert_eq!(ret.retry_history[1].attempt, 2);
}
#[tokio::test]
async fn test_execute_with_retry_tracking_exhaustion_total_attempts() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/always-fail"))
.respond_with(ResponseTemplate::new(503))
.expect(3)
.mount(&server)
.await;
let client = reqwest::Client::new();
let url = format!("{}/always-fail", server.uri());
let config = no_jitter_config(3);
let ret = execute_with_retry_tracking(&config, "test-op", || {
let client = client.clone();
let url = url.clone();
async move { make_request(&client, &url).await }
})
.await;
// The final (third) failure ends the loop without a retry, so only two entries are logged.
assert!(ret.result.is_err(), "all attempts exhausted");
assert_eq!(ret.total_attempts, 3);
assert_eq!(ret.retry_history.len(), 2);
}
}