use super::request::{Request, RequestBuilder, RequestSpec};
use crate::Config;
use crate::common::types::RetryCount;
use crate::error::{ApiError, ApiErrorKind, OpenAIError, RequestError};
use crate::utils::traits::AsyncFrom;
use rand::Rng;
use reqwest::{Client, Response};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
/// Sends HTTP requests built from a shared [`Config`], with retry/backoff handling.
pub(crate) struct HttpExecutor {
    /// Shared, mutable configuration; a handle to it is returned by `config()`.
    config: Arc<RwLock<Config>>,
    /// Client built from `config.http()`; only replaced by `rebuild_reqwest_client`.
    reqwest_client: RwLock<Client>,
}
impl HttpExecutor {
    /// Creates an executor from `config`, building the initial `reqwest::Client` from the
    /// config's HTTP settings.
    pub fn new(config: Config) -> HttpExecutor {
        let reqwest_client = config.http().build_reqwest_client();
        HttpExecutor {
            config: Arc::new(RwLock::new(config)),
            reqwest_client: RwLock::new(reqwest_client),
        }
    }

    /// Returns a shared handle to the executor's configuration.
    ///
    /// The cached `reqwest::Client` is only built in [`HttpExecutor::new`] and
    /// [`HttpExecutor::rebuild_reqwest_client`], so changes to the settings it is built from
    /// take effect only after a rebuild.
    pub(crate) fn config(&self) -> Arc<RwLock<Config>> {
        Arc::clone(&self.config)
    }

    /// Rebuilds the cached `reqwest::Client` from the current configuration.
    ///
    /// The config read lock is released before the client write lock is taken, so the two
    /// locks are never held at the same time.
    pub async fn rebuild_reqwest_client(&self) {
        let new_client = {
            let config_guard = self.config.read().await;
            config_guard.http().build_reqwest_client()
        };
        let mut client_guard = self.reqwest_client.write().await;
        *client_guard = new_client;
    }

    /// Builds and sends a `POST` request described by `params`, retrying on retryable errors.
    pub async fn post<U, F>(&self, params: RequestSpec<U, F>) -> Result<Response, OpenAIError>
    where
        U: FnOnce(&Config) -> String,
        F: FnOnce(&Config, &mut RequestBuilder),
    {
        self.execute_request(reqwest::Method::POST, params).await
    }

    /// Builds and sends a `GET` request described by `params`, retrying on retryable errors.
    pub async fn get<U, F>(&self, params: RequestSpec<U, F>) -> Result<Response, OpenAIError>
    where
        U: FnOnce(&Config) -> String,
        F: FnOnce(&Config, &mut RequestBuilder),
    {
        self.execute_request(reqwest::Method::GET, params).await
    }

    /// Shared implementation of [`HttpExecutor::post`] / [`HttpExecutor::get`].
    ///
    /// Under the config read lock: builds the URL via `params.url_fn`, lets `params.builder_fn`
    /// customise the request, fills in global default headers/body fields, and picks the retry
    /// count. The lock is released before the request is sent.
    async fn execute_request<U, F>(
        &self,
        method: reqwest::Method,
        params: RequestSpec<U, F>,
    ) -> Result<Response, OpenAIError>
    where
        U: FnOnce(&Config) -> String,
        F: FnOnce(&Config, &mut RequestBuilder),
    {
        // Clone the client out so its lock is not held while the request is in flight.
        let client = {
            let client_guard = self.reqwest_client.read().await;
            client_guard.clone()
        };
        let (retry_count, request) = {
            let config_guard = self.config.read().await;
            let url = (params.url_fn)(&config_guard);
            let mut request_builder = RequestBuilder::new(method, url.as_str());
            (params.builder_fn)(&config_guard, &mut request_builder);
            HttpExecutor::apply_global_http_settings(&config_guard, &mut request_builder);
            // A per-request `RetryCount` of 0 is treated as "unset": use the config default.
            let retry_count = request_builder
                .extensions()
                .get::<RetryCount>()
                .map(|retry| retry.0)
                .filter(|&count| count != 0)
                .unwrap_or_else(|| config_guard.retry_count());
            (retry_count, request_builder.build())
        };
        HttpExecutor::send_with_retries(request, retry_count as u32, client).await
    }
}
impl HttpExecutor {
    /// Adds the config-level default headers and body fields to `request_builder`.
    ///
    /// Anything already set on the builder wins: a default is only added when no header /
    /// body field with the same key is present.
    fn apply_global_http_settings(config: &Config, request_builder: &mut RequestBuilder) {
        for (name, value) in config.http().headers().iter() {
            if request_builder.has_header(name) {
                continue;
            }
            request_builder.header(name, value.clone());
        }
        for (field, value) in config.http().bodys().iter() {
            if request_builder.has_body_field(field) {
                continue;
            }
            request_builder.body_field(field, value.clone());
        }
    }

    /// Sends `request` through `client`, retrying retryable failures with backoff.
    ///
    /// At most `retry_count.max(1)` attempts are made in total, so a `retry_count` of 0 still
    /// sends once. A non-success status becomes an [`ApiError`]; a transport failure becomes a
    /// [`RequestError`]. Either is retried only while attempts remain and `is_retryable()`
    /// holds, sleeping for `calculate_retry_delay` / `calculate_retry_delay_for_request_error`
    /// between attempts.
    ///
    /// NOTE(review): `retry_count` acts as the total attempt budget rather than the number of
    /// retries after the first attempt — confirm this matches callers' expectations.
    ///
    /// # Errors
    /// Returns the last error, converted into [`OpenAIError`], once it is non-retryable or the
    /// attempt budget is used up.
    async fn send_with_retries(
        request: Request,
        retry_count: u32,
        client: reqwest::Client,
    ) -> Result<Response, OpenAIError> {
        let max_attempts = retry_count.max(1);
        let mut attempt: u32 = 0;
        loop {
            attempt += 1;
            let out_of_attempts = attempt >= max_attempts;
            let outcome = request.to_reqwest(&client).send().await;
            let backoff = match outcome {
                Ok(response) if response.status().is_success() => return Ok(response),
                Ok(response) => {
                    // Read `Retry-After` (integer seconds only) before the body is consumed.
                    let retry_after = response
                        .headers()
                        .get(reqwest::header::RETRY_AFTER)
                        .and_then(|header| header.to_str().ok())
                        .and_then(|text| text.parse::<u64>().ok())
                        .map(Duration::from_secs);
                    let api_error = ApiError::async_from(response).await;
                    if out_of_attempts || !api_error.is_retryable() {
                        return Err(api_error.into());
                    }
                    tracing::debug!(
                        "Attempt {}/{}: Retrying after API error: {:?}",
                        attempt,
                        max_attempts,
                        api_error
                    );
                    calculate_retry_delay(attempt, &api_error.kind, retry_after)
                }
                Err(err) => {
                    let request_error: RequestError = err.into();
                    if out_of_attempts || !request_error.is_retryable() {
                        return Err(request_error.into());
                    }
                    tracing::debug!(
                        "Attempt {}/{}: Retrying after request error: {:?}",
                        attempt,
                        max_attempts,
                        request_error
                    );
                    calculate_retry_delay_for_request_error(attempt, &request_error)
                }
            };
            tokio::time::sleep(backoff).await;
        }
    }
}
/// Computes how long to wait before retrying after an API (HTTP status) error.
///
/// If the server sent a `Retry-After` delay, that delay is used plus 0–999 ms of random
/// jitter. Otherwise this uses exponential backoff. The base delay depends on `error_kind`:
/// 5 s for rate limits, 1 s for internal server errors, and 500 ms otherwise. It is doubled
/// for each prior attempt, capped at 30 s, and extended by 0–9 % random jitter.
///
/// `attempt` is the 1-based number of the attempt that just failed. All arithmetic saturates,
/// so very large attempt numbers or `Retry-After` values never panic on overflow.
fn calculate_retry_delay(
    attempt: u32,
    error_kind: &ApiErrorKind,
    retry_after: Option<Duration>,
) -> Duration {
    if let Some(duration) = retry_after {
        let jitter = Duration::from_millis(rand::thread_rng().gen_range(0..1000));
        // `Retry-After` is server-controlled: saturate instead of panicking on overflow.
        return duration.saturating_add(jitter);
    }
    let base_delay_ms: u64 = match error_kind {
        ApiErrorKind::RateLimit => 5000,
        ApiErrorKind::InternalServer => 1000,
        _ => 500,
    };
    // base * 2^(attempt - 1), capped at 30 s; checked math so large attempts (or attempt == 0)
    // cannot overflow/underflow.
    let delay_ms = 2u64
        .checked_pow(attempt.saturating_sub(1))
        .and_then(|factor| base_delay_ms.checked_mul(factor))
        .unwrap_or(u64::MAX)
        .min(30_000);
    // delay_ms <= 30_000, so this multiplication cannot overflow.
    let jitter_ms = delay_ms * rand::thread_rng().gen_range(0..10u64) / 100;
    Duration::from_millis(delay_ms + jitter_ms)
}
/// Computes how long to wait before retrying after a transport-level request error.
///
/// This uses exponential backoff. The base delay is 200 ms for connection errors and 100 ms
/// for timeouts and all other errors. It is doubled for each prior attempt, capped at 10 s,
/// and extended by 0–9 % random jitter.
///
/// `attempt` is the 1-based number of the attempt that just failed. The arithmetic saturates,
/// so very large attempt numbers (or 0) never panic.
fn calculate_retry_delay_for_request_error(attempt: u32, error: &RequestError) -> Duration {
    let base_delay_ms: u64 = match error {
        RequestError::Timeout(_) => 100,
        RequestError::Connection(_) => 200,
        _ => 100,
    };
    // base * 2^(attempt - 1), capped at 10 s; checked math avoids overflow/underflow.
    let delay_ms = 2u64
        .checked_pow(attempt.saturating_sub(1))
        .and_then(|factor| base_delay_ms.checked_mul(factor))
        .unwrap_or(u64::MAX)
        .min(10_000);
    // Same jitter source as `calculate_retry_delay`; delay_ms <= 10_000 so no overflow.
    let jitter_ms = delay_ms * rand::thread_rng().gen_range(0..10u64) / 100;
    Duration::from_millis(delay_ms + jitter_ms)
}