use crate::client::RetryConfig;
use crate::error::{Error, Result};
use fastrand::Rng;
use reqwest::{
header::{HeaderMap, RETRY_AFTER},
RequestBuilder, Response,
};
use std::time::{Duration, Instant, SystemTime};
use tokio::time::{sleep, timeout};
/// Canonical operation-name labels.
///
/// These strings tag retry attempts and error messages (see
/// `execute_with_retry_builder`'s `operation_name` parameter) so logs and
/// errors can be correlated with the API call that produced them.
pub mod operations {
    pub const TEXT_COMPLETION: &str = "text_completion";
    pub const WEB_SEARCH: &str = "web_search";
    pub const LIST_MODELS: &str = "list_models";
    pub const GET_BALANCE: &str = "get_balance";
    pub const GET_ACTIVITY: &str = "get_activity";
    pub const GET_PROVIDERS: &str = "get_providers";
    pub const GET_GENERATION: &str = "get_generation";
    pub const STRUCTURED_GENERATE: &str = "structured_generate";
    pub const CHAT_COMPLETION: &str = "chat_completion";
    pub const GET_KEY_INFO: &str = "get_key_info";
    pub const GET_EMBEDDINGS: &str = "get_embeddings";
}
/// Sends an HTTP request with retries, exponential backoff and an overall
/// wall-clock budget of `config.total_timeout`.
///
/// `request_builder` is called once per attempt to produce a fresh
/// `RequestBuilder` (reqwest builders are consumed by `send()`, so they
/// cannot be reused across attempts).
///
/// An attempt is retried — at most `config.max_retries` times — when:
/// * the attempt times out (each attempt is bounded by the *remaining*
///   overall budget via `tokio::time::timeout`),
/// * the transport error is retryable per `is_retryable_reqwest_error`, or
/// * the response status is listed in `config.retry_on_status_codes`
///   (honoring a `Retry-After` header when present).
///
/// # Errors
/// * `Error::TimeoutError` when the overall budget is exhausted or the
///   final attempt times out.
/// * The converted reqwest error for non-retryable transport failures or
///   when retries are exhausted.
///
/// Non-retryable (or retry-exhausted) HTTP status responses are returned
/// as `Ok(response)`; classifying the status is left to the caller.
pub async fn execute_with_retry_builder<F>(
    config: &RetryConfig,
    operation_name: &str,
    mut request_builder: F,
) -> Result<Response>
where
    F: FnMut() -> RequestBuilder,
{
    let mut retry_count = 0usize;
    // Current base backoff; doubles after each backoff-driven retry.
    let mut backoff_ms = config.initial_backoff_ms;
    let mut rng = Rng::new();
    let start_time = Instant::now();
    loop {
        // How much of the overall budget is left; zero means give up now,
        // before issuing another attempt.
        let remaining = config.total_timeout.saturating_sub(start_time.elapsed());
        if remaining.is_zero() {
            return Err(Error::TimeoutError(format!(
                "Retry timeout exceeded for {}: {}ms limit",
                operation_name,
                config.total_timeout.as_millis()
            )));
        }
        // Bound this attempt by the remaining overall budget.
        let send_fut = request_builder().send();
        match timeout(remaining, send_fut).await {
            // The attempt itself timed out.
            Err(_) => {
                if retry_count < config.max_retries as usize {
                    retry_count += 1;
                    // NOTE(review): `remaining` was captured before the attempt
                    // ran, so the sleep cap inside jittered_backoff_ms is
                    // slightly stale; the budget check at the top of the loop
                    // still enforces the overall deadline on the next pass.
                    let sleep_ms =
                        jittered_backoff_ms(backoff_ms, config.max_backoff_ms, &mut rng, remaining);
                    sleep(Duration::from_millis(sleep_ms)).await;
                    backoff_ms = next_backoff(backoff_ms, config.max_backoff_ms);
                    continue;
                } else {
                    return Err(Error::TimeoutError(format!(
                        "Request timeout for {} after {:?}",
                        operation_name, config.total_timeout
                    )));
                }
            }
            // Transport-level failure (connect error, reqwest-internal timeout, ...).
            Ok(Err(e)) => {
                if is_retryable_reqwest_error(&e) && retry_count < config.max_retries as usize {
                    retry_count += 1;
                    let sleep_ms =
                        jittered_backoff_ms(backoff_ms, config.max_backoff_ms, &mut rng, remaining);
                    sleep(Duration::from_millis(sleep_ms)).await;
                    backoff_ms = next_backoff(backoff_ms, config.max_backoff_ms);
                    continue;
                }
                // Not retryable, or retries exhausted.
                return Err(e.into());
            }
            // Got an HTTP response; retry only for configured status codes.
            Ok(Ok(response)) => {
                let status = response.status();
                let status_code = status.as_u16();
                if config.retry_on_status_codes.contains(&status_code)
                    && retry_count < config.max_retries as usize
                {
                    retry_count += 1;
                    // Prefer the server's Retry-After hint over our own backoff.
                    // Read the header before consuming the response below.
                    let retry_after_ms = parse_retry_after_ms(response.headers());
                    // Drain the body so the connection can be returned to the
                    // pool; a failure here is logged but does not abort the retry.
                    if let Err(e) = response.bytes().await {
                        #[cfg(feature = "tracing")]
                        tracing::warn!(
                            operation = operation_name,
                            error = %e,
                            "Failed to consume response body during retry"
                        );
                        #[cfg(not(feature = "tracing"))]
                        eprintln!(
                            "Warning: Failed to consume response body during retry for {}: {}",
                            operation_name, e
                        );
                    }
                    let base_ms = retry_after_ms.unwrap_or(backoff_ms);
                    let sleep_ms =
                        jittered_backoff_ms(base_ms, config.max_backoff_ms, &mut rng, remaining);
                    sleep(Duration::from_millis(sleep_ms)).await;
                    // Only advance the exponential schedule when the server did
                    // not dictate the delay — Retry-After sleeps leave it untouched.
                    if retry_after_ms.is_none() {
                        backoff_ms = next_backoff(backoff_ms, config.max_backoff_ms);
                    }
                    continue;
                }
                // Success, or a status we don't retry: hand it back as-is.
                return Ok(response);
            }
        }
    }
}
/// Extracts a retry delay in milliseconds from a `Retry-After` header.
///
/// Both forms permitted by RFC 9110 are supported: a non-negative integer
/// number of seconds and an HTTP-date. Delays are clamped to one hour to
/// bound worst-case sleeps; a date in the past yields `Some(0)`. Returns
/// `None` when the header is absent, non-ASCII, or unparsable.
fn parse_retry_after_ms(headers: &HeaderMap) -> Option<u64> {
    // Upper bound on any server-requested delay.
    const MAX_SECONDS: u64 = 3600;
    let raw = headers.get(RETRY_AFTER)?.to_str().ok()?.trim();

    // delta-seconds form, e.g. "Retry-After: 120".
    if let Ok(secs) = raw.parse::<u64>() {
        return Some(secs.min(MAX_SECONDS) * 1000);
    }

    // HTTP-date form, e.g. "Retry-After: Fri, 31 Dec 1999 23:59:59 GMT".
    match httpdate::parse_http_date(raw) {
        Ok(when) => {
            // A date already in the past means "retry immediately".
            let wait = when
                .duration_since(SystemTime::now())
                .unwrap_or(Duration::ZERO);
            Some(wait.min(Duration::from_secs(MAX_SECONDS)).as_millis() as u64)
        }
        Err(_) => None,
    }
}
/// Returns `true` for transport-level failures that are worth retrying:
/// connection establishment errors and timeouts. Other reqwest errors
/// (body, decode, redirect policy, ...) are treated as permanent.
fn is_retryable_reqwest_error(e: &reqwest::Error) -> bool {
    e.is_connect() || e.is_timeout()
}
/// Applies random jitter to a backoff value and caps the result so a
/// single sleep can never exceed the per-step maximum, a hard 5-minute
/// ceiling, or (almost all of) the remaining overall retry budget.
fn jittered_backoff_ms(
    base_ms: u64,
    max_backoff_ms: u64,
    rng: &mut Rng,
    remaining_overall: Duration,
) -> u64 {
    // Absolute ceiling that guards against absurd base values.
    const HARD_CAP_MS: u64 = 300_000;
    let capped = base_ms.min(HARD_CAP_MS).min(max_backoff_ms);
    // Multiplier in [0.75, 1.25): spreads out retries from concurrent
    // clients so they don't thunder in lockstep.
    let factor = 0.75 + 0.5 * rng.f64();
    let jittered = (capped as f64 * factor) as u64;
    // Keep a 25ms margin so the caller's deadline check can still fire
    // before the budget is fully spent.
    let budget_ms = remaining_overall.as_millis().saturating_sub(25) as u64;
    jittered.min(budget_ms)
}
/// Advances the exponential backoff schedule: doubles the current value
/// (saturating), clamped to the configured maximum and a hard 5-minute
/// ceiling.
fn next_backoff(current_ms: u64, max_backoff_ms: u64) -> u64 {
    const HARD_CAP_MS: u64 = 300_000;
    current_ms
        .saturating_mul(2)
        .min(max_backoff_ms.min(HARD_CAP_MS))
}
pub async fn handle_response_text(response: Response, operation_name: &str) -> Result<String> {
let status = response.status();
let status_code = status.as_u16();
let body = response.text().await?;
if !status.is_success() {
let err = Error::from_response_text(status_code, &body);
return Err(err);
}
if body.trim().is_empty() {
return Err(Error::ApiError {
code: status_code,
message: format!("Empty response body for {}", operation_name),
metadata: None,
});
}
Ok(body)
}
pub async fn handle_response_json<T: serde::de::DeserializeOwned>(
response: Response,
operation_name: &str,
) -> Result<T> {
let status = response.status();
let status_code = status.as_u16();
let body = response.text().await?;
if !status.is_success() {
let err = Error::from_response_text(status_code, &body);
return Err(err);
}
if body.trim().is_empty() {
return Err(Error::ApiError {
code: status_code,
message: format!("Empty response body for {}", operation_name),
metadata: None,
});
}
serde_json::from_str::<T>(&body).map_err(|e| Error::DeserializationError {
status_code,
message: crate::utils::security::create_safe_error_message(
&format!(
"Failed to decode JSON response for {}: {}. Body (elided) was: {}",
operation_name,
e,
elide(&body, 2_000)
),
&format!("{} JSON parsing error", operation_name),
),
})
}
/// Truncates `s` to at most `max` bytes and appends an ellipsis plus the
/// original byte length; strings that already fit are returned unchanged.
/// The cut position backs off to a valid UTF-8 char boundary so slicing
/// never panics.
fn elide(s: &str, max: usize) -> String {
    if s.len() > max {
        // Walk backwards until the cut lands on a char boundary.
        let mut cut = max;
        while cut > 0 && !s.is_char_boundary(cut) {
            cut -= 1;
        }
        format!("{}… ({} bytes total)", &s[..cut], s.len())
    } else {
        s.to_string()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::client::RetryConfig;
use reqwest::header::HeaderValue;
#[test]
fn test_retry_config_default() {
let config = RetryConfig::default();
assert_eq!(config.max_retries, 3);
assert_eq!(config.initial_backoff_ms, 500);
assert_eq!(config.max_backoff_ms, 10000);
assert!(config.retry_on_status_codes.contains(&429));
assert!(config.retry_on_status_codes.contains(&500));
}
#[test]
fn test_exponential_backoff_calculation() {
let config = RetryConfig::default();
let mut backoff_ms = config.initial_backoff_ms;
assert_eq!(backoff_ms, 500);
backoff_ms = next_backoff(backoff_ms, config.max_backoff_ms);
assert_eq!(backoff_ms, 1000);
backoff_ms = next_backoff(backoff_ms, config.max_backoff_ms);
assert_eq!(backoff_ms, 2000);
for _ in 0..10 {
backoff_ms = next_backoff(backoff_ms, config.max_backoff_ms);
}
assert_eq!(backoff_ms, config.max_backoff_ms.min(300_000));
}
#[test]
fn test_parse_retry_after_delta_seconds() {
let mut h = HeaderMap::new();
h.insert(RETRY_AFTER, HeaderValue::from_static("120"));
assert_eq!(parse_retry_after_ms(&h), Some(120_000));
}
#[test]
fn test_parse_retry_after_http_date_future() {
let mut h = HeaderMap::new();
let future = SystemTime::now() + Duration::from_secs(5);
let s = httpdate::fmt_http_date(future);
h.insert(RETRY_AFTER, HeaderValue::from_str(&s).unwrap());
let ms = parse_retry_after_ms(&h).unwrap();
assert!(ms <= 5000 && ms > 0);
}
#[test]
fn test_parse_retry_after_http_date_past() {
let mut h = HeaderMap::new();
let past = SystemTime::now() - Duration::from_secs(5);
let s = httpdate::fmt_http_date(past);
h.insert(RETRY_AFTER, HeaderValue::from_str(&s).unwrap());
assert_eq!(parse_retry_after_ms(&h), Some(0));
}
#[tokio::test]
async fn test_retry_config_status_codes() {
let config = RetryConfig::default();
assert!(config.retry_on_status_codes.contains(&429)); assert!(config.retry_on_status_codes.contains(&500)); assert!(config.retry_on_status_codes.contains(&502)); assert!(config.retry_on_status_codes.contains(&503)); assert!(config.retry_on_status_codes.contains(&504));
assert!(!config.retry_on_status_codes.contains(&200));
assert!(!config.retry_on_status_codes.contains(&201));
assert!(!config.retry_on_status_codes.contains(&400));
assert!(!config.retry_on_status_codes.contains(&401));
assert!(!config.retry_on_status_codes.contains(&404));
}
#[tokio::test]
async fn test_retry_config_new_fields() {
use std::time::Duration;
let config = RetryConfig::default();
assert_eq!(config.total_timeout, Duration::from_secs(120));
let custom_config = config.with_total_timeout(Duration::from_secs(300));
assert_eq!(custom_config.total_timeout, Duration::from_secs(300));
}
#[tokio::test]
async fn test_total_timeout_enforcement() {
let config = RetryConfig {
max_retries: 3,
initial_backoff_ms: 50,
max_backoff_ms: 100,
retry_on_status_codes: vec![429, 500, 502, 503, 504],
total_timeout: Duration::from_millis(200), max_retry_interval: Duration::from_secs(30),
};
let client = reqwest::Client::new();
let result = execute_with_retry_builder(&config, "test_timeout", || {
client.get("http://192.0.2.1:99999") })
.await;
assert!(result.is_err());
let error = result.unwrap_err();
match &error {
Error::TimeoutError(msg) => {
assert!(
msg.contains("timeout") || msg.contains("Timeout"),
"Expected timeout message, got: {}",
msg
);
}
Error::HttpError(_) => {
}
_ => panic!("Expected timeout or network error, got: {:?}", error),
}
}
#[tokio::test]
async fn test_individual_retry_capping() {
use reqwest::StatusCode;
use wiremock::{matchers, Mock, MockServer, ResponseTemplate};
let mock_server = MockServer::start().await;
Mock::given(matchers::method("GET"))
.respond_with(ResponseTemplate::new(StatusCode::INTERNAL_SERVER_ERROR))
.mount(&mock_server)
.await;
let config = RetryConfig {
max_retries: 3,
initial_backoff_ms: 100,
max_backoff_ms: 5000, retry_on_status_codes: vec![500],
total_timeout: Duration::from_secs(10), max_retry_interval: Duration::from_secs(30),
};
let start_time = std::time::Instant::now();
let client = reqwest::Client::new();
let result =
execute_with_retry_builder(&config, "test_capping", || client.get(mock_server.uri()))
.await;
let elapsed = start_time.elapsed();
match result {
Ok(response) => {
assert_eq!(response.status().as_u16(), 500);
}
Err(error) => {
match &error {
Error::ApiError {
code: _,
message: _,
metadata: _,
} => {} Error::TimeoutError(_) => {
}
_ => panic!("Expected API error or timeout, got: {:?}", error),
}
}
}
assert!(
elapsed < Duration::from_millis(1500),
"Took too long: {:?}",
elapsed
);
}
#[tokio::test]
async fn test_concurrent_retry_limits() {
use reqwest::StatusCode;
use std::sync::Arc;
use wiremock::{matchers, Mock, MockServer, ResponseTemplate};
let mock_server = MockServer::start().await;
Mock::given(matchers::method("GET"))
.respond_with(ResponseTemplate::new(StatusCode::INTERNAL_SERVER_ERROR))
.up_to_n_times(1)
.mount(&mock_server)
.await;
Mock::given(matchers::method("GET"))
.respond_with(ResponseTemplate::new(StatusCode::OK))
.mount(&mock_server)
.await;
let config = RetryConfig {
max_retries: 2,
initial_backoff_ms: 50,
max_backoff_ms: 200,
retry_on_status_codes: vec![500],
total_timeout: Duration::from_secs(5),
max_retry_interval: Duration::from_secs(30),
};
let config = Arc::new(config);
let server_url = mock_server.uri();
let handles: Vec<_> = (0..5)
.map(|_| {
let config = config.clone();
let url = server_url.clone();
tokio::spawn(async move {
let client = reqwest::Client::new();
execute_with_retry_builder(&config, "concurrent_test", || client.get(&url))
.await
})
})
.collect();
let results: Vec<_> = futures::future::join_all(handles).await;
let mut successes = 0;
let mut failures = 0;
for result in results {
match result {
Ok(Ok(_)) => successes += 1,
Ok(Err(_)) => failures += 1,
Err(_) => failures += 1, }
}
assert!(
successes >= 3,
"Expected at least 3 successes, got {}",
successes
);
assert_eq!(failures, 0, "Expected no failures, got {}", failures);
}
#[tokio::test]
async fn test_retry_performance_impact() {
use reqwest::StatusCode;
use wiremock::{matchers, Mock, MockServer, ResponseTemplate};
let mock_server = MockServer::start().await;
Mock::given(matchers::method("GET"))
.respond_with(ResponseTemplate::new(StatusCode::OK))
.mount(&mock_server)
.await;
let config = RetryConfig::default();
let start_time = std::time::Instant::now();
let client = reqwest::Client::new();
let result = execute_with_retry_builder(&config, "performance_test", || {
client.get(mock_server.uri())
})
.await;
let elapsed = start_time.elapsed();
assert!(result.is_ok());
assert!(
elapsed < Duration::from_millis(100),
"Took too long: {:?}",
elapsed
);
}
#[tokio::test]
async fn test_backoff_jitter_variation() {
use wiremock::{matchers, Mock, MockServer, ResponseTemplate};
let mock_server = MockServer::start().await;
Mock::given(matchers::method("GET"))
.respond_with(ResponseTemplate::new(429))
.mount(&mock_server)
.await;
let config = RetryConfig {
max_retries: 3,
initial_backoff_ms: 100,
max_backoff_ms: 200,
retry_on_status_codes: vec![429],
total_timeout: Duration::from_secs(5),
max_retry_interval: Duration::from_secs(30),
};
let start_time = std::time::Instant::now();
let client = reqwest::Client::new();
let result =
execute_with_retry_builder(&config, "jitter_test", || client.get(mock_server.uri()))
.await;
let elapsed = start_time.elapsed();
match result {
Ok(response) => {
assert_eq!(response.status().as_u16(), 429);
}
Err(error) => {
match &error {
Error::ApiError {
code: _,
message: _,
metadata: _,
} => {} Error::TimeoutError(_) => {
}
_ => panic!("Expected API error or timeout, got: {:?}", error),
}
}
}
assert!(
elapsed > Duration::from_millis(300),
"Too fast, likely no retries: {:?}",
elapsed
);
assert!(
elapsed < Duration::from_millis(1000),
"Too slow, possible issue: {:?}",
elapsed
);
}
}