use anyhow::{Context, Result};
use reqwest::Client;
use std::time::Duration;
use tokio::time::sleep;
/// Settings for the HTTP client and for the retry loop in `execute_with_retry`.
///
/// Every duration-like field is expressed in whole seconds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpClientConfig {
    /// Maximum number of retries after the initial attempt (so `max_retries + 1` requests
    /// at most).
    pub max_retries: u32,
    /// Upper bound, in seconds, on the exponential backoff. Random jitter of up to a quarter
    /// of the capped value is added on top, so the actual sleep can exceed this value.
    pub max_wait: u64,
    /// Base backoff in seconds; retry number `n` waits `base_delay * 2^n` seconds before the
    /// `max_wait` cap and jitter are applied.
    pub base_delay: u64,
    /// Overall per-request timeout in seconds.
    pub timeout: u64,
    /// Timeout in seconds for establishing a connection.
    pub connect_timeout: u64,
}

impl Default for HttpClientConfig {
    /// 3 retries, 1 s base delay, 60 s backoff cap, 120 s request timeout and
    /// 30 s connect timeout.
    fn default() -> Self {
        Self {
            max_retries: 3,
            max_wait: 60,
            base_delay: 1,
            timeout: 120,
            connect_timeout: 30,
        }
    }
}
impl HttpClientConfig {
    /// Builds a `reqwest::Client` using `timeout` and `connect_timeout` (both in seconds).
    ///
    /// # Errors
    /// Returns an error (with context "Failed to create HTTP client") if the client builder
    /// fails.
    pub fn create_client(&self) -> Result<Client> {
        Client::builder()
            .timeout(Duration::from_secs(self.timeout))
            .connect_timeout(Duration::from_secs(self.connect_timeout))
            .build()
            .context("Failed to create HTTP client")
    }

    /// Builds a config whose request timeout (seconds) is taken from, in order of precedence:
    /// `timeout_flag`, the `RAPS_TIMEOUT` environment variable, then the default timeout.
    ///
    /// A missing or unparsable `RAPS_TIMEOUT` is ignored; surrounding whitespace in it is
    /// trimmed before parsing. All other fields keep their `Default` values.
    pub fn from_cli_and_env(timeout_flag: Option<u64>) -> Self {
        // Single source of truth for defaults instead of repeating the literal timeout here.
        let defaults = Self::default();
        let timeout = timeout_flag
            .or_else(|| {
                std::env::var("RAPS_TIMEOUT")
                    .ok()
                    .and_then(|v| v.trim().parse().ok())
            })
            .unwrap_or(defaults.timeout);
        Self {
            timeout,
            ..defaults
        }
    }
}
/// Runs `request_fn` until it succeeds or `should_retry_error` declines another attempt.
///
/// `request_fn` is called once per attempt and must return a fresh boxed future each time.
/// After every retryable failure the retry is logged through `crate::logging::log_verbose`,
/// and the task sleeps for the backoff from `calculate_delay` before trying again.
///
/// # Errors
/// Returns the most recent error unchanged once it is not retryable or the retry budget
/// (`config.max_retries`) is exhausted.
pub async fn execute_with_retry<F, T>(config: &HttpClientConfig, mut request_fn: F) -> Result<T>
where
    F: FnMut() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send>>,
{
    let mut retries_so_far = 0;
    loop {
        let failure = match request_fn().await {
            Ok(value) => return Ok(value),
            Err(e) => e,
        };

        if !should_retry_error(&failure, retries_so_far, config.max_retries) {
            return Err(failure);
        }

        retries_so_far += 1;
        let backoff = calculate_delay(retries_so_far, config.base_delay, config.max_wait);
        let message = format!(
            "Request failed (attempt {}/{}), retrying in {}s...",
            retries_so_far,
            config.max_retries,
            backoff.as_secs()
        );
        crate::logging::log_verbose(&message);
        sleep(backoff).await;
    }
}
/// Decides whether a failed request should be retried.
///
/// Returns `false` once `attempt` has reached `max_retries`. Otherwise, when the error
/// downcasts to a `reqwest::Error`:
/// - a 429 or 5xx status is retried;
/// - any other 4xx status is not retried;
/// - otherwise timeout, connect and request errors are retried.
///
/// Anything still unclassified falls back to a heuristic scan of the lower-cased text of the
/// whole error chain. That scan looks for 429/500/502/503/504 as standalone numbers, "too many
/// requests", "server error", "timeout", "connection" or "network".
fn should_retry_error(err: &anyhow::Error, attempt: u32, max_retries: u32) -> bool {
    if attempt >= max_retries {
        return false;
    }

    if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
        if reqwest_err.is_status()
            && let Some(status) = reqwest_err.status()
        {
            if status.as_u16() == 429 || status.is_server_error() {
                return true;
            }
            if status.is_client_error() {
                return false;
            }
        }
        if reqwest_err.is_timeout() || reqwest_err.is_connect() || reqwest_err.is_request() {
            return true;
        }
    }

    // `{:#}` renders the full context chain ("outer: inner: ..."). Plain `to_string()` shows
    // only the outermost context, which would hide e.g. a wrapped "HTTP 503" message.
    let error_str = format!("{err:#}").to_lowercase();

    if contains_status_code(&error_str, "429") || error_str.contains("too many requests") {
        return true;
    }
    if ["500", "502", "503", "504"]
        .iter()
        .any(|code| contains_status_code(&error_str, code))
        || error_str.contains("server error")
    {
        return true;
    }
    ["timeout", "connection", "network"]
        .iter()
        .any(|needle| error_str.contains(needle))
}

/// Returns `true` when `code` occurs in `text` as a standalone number, i.e. not embedded in
/// a longer run of ASCII digits such as "1500" or "5000".
fn contains_status_code(text: &str, code: &str) -> bool {
    text.match_indices(code).any(|(start, _)| {
        let before = text[..start].chars().next_back();
        let after = text[start + code.len()..].chars().next();
        !before.is_some_and(|c| c.is_ascii_digit()) && !after.is_some_and(|c| c.is_ascii_digit())
    })
}
/// Computes the backoff to sleep before retry number `attempt`.
///
/// The base value is `base_delay * 2^attempt` seconds, capped at `max_wait`. A random jitter
/// in `0..=capped/4` seconds is then added, so the result can exceed `max_wait` by up to 25%.
/// All arithmetic saturates, so arbitrarily large `attempt` values cannot overflow or panic.
fn calculate_delay(attempt: u32, base_delay: u64, max_wait: u64) -> Duration {
    use rand::Rng;
    // `2_u64.pow(attempt)` overflows for attempt >= 64, and the multiply can overflow sooner.
    // Saturating keeps the `max_wait` cap meaningful; a zero `base_delay` still yields zero.
    let exponential_delay = base_delay.saturating_mul(2_u64.saturating_pow(attempt));
    let capped_delay = exponential_delay.min(max_wait);
    let mut rng = rand::thread_rng();
    let jitter = rng.gen_range(0..=(capped_delay / 4));
    Duration::from_secs(capped_delay.saturating_add(jitter))
}