use super::request::{Request, RequestBuilder, RequestSpec};
use crate::Config;
use crate::error::{ApiError, ApiErrorKind, OpenAIError, RequestError};
use crate::interceptor::InterceptorChain;
use crate::utils::traits::AsyncFrom;
use rand::Rng;
use reqwest::{Client, Response};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
pub struct HttpExecutor {
config: Arc<RwLock<Config>>,
reqwest_client: RwLock<Client>,
}
impl HttpExecutor {
pub fn new(config: Config) -> HttpExecutor {
let reqwest_client = config.http().build_reqwest_client();
HttpExecutor {
config: Arc::new(RwLock::new(config)),
reqwest_client: RwLock::new(reqwest_client),
}
}
pub(crate) fn config(&self) -> Arc<RwLock<Config>> {
self.config.clone()
}
pub async fn rebuild_reqwest_client(&self) {
let new_client = {
let config_guard = self.config.read().await;
config_guard.http().build_reqwest_client()
};
let mut client_guard = self.reqwest_client.write().await;
*client_guard = new_client;
}
pub async fn post<U, F>(&self, params: RequestSpec<U, F>) -> Result<Response, OpenAIError>
where
U: FnOnce(&Config) -> String,
F: FnOnce(&Config, &mut RequestBuilder),
{
let client = {
let client_guard = self.reqwest_client.read().await;
client_guard.clone()
};
let (retry_count, global_interceptors, request, module_interceptors) = {
let config_guard = self.config.read().await;
let retry_count = if params.retry_count != 0 {
params.retry_count
} else {
config_guard.retry_count()
};
let mut request_builder = RequestBuilder::new(
reqwest::Method::POST,
(params.url_fn)(&config_guard).as_str(),
);
(params.builder_fn)(&config_guard, &mut request_builder);
HttpExecutor::apply_global_http_settings(&config_guard, &mut request_builder);
let request = request_builder.build();
let global_interceptors = config_guard.global_interceptors().clone();
let module_interceptors = params.module_interceptors;
(
retry_count,
global_interceptors,
request,
module_interceptors,
)
};
HttpExecutor::send_with_retries(
request,
retry_count,
global_interceptors,
module_interceptors,
client,
)
.await
}
pub async fn get<U, F>(&self, params: RequestSpec<U, F>) -> Result<Response, OpenAIError>
where
U: FnOnce(&Config) -> String,
F: FnOnce(&Config, &mut RequestBuilder),
{
let client = {
let client_guard = self.reqwest_client.read().await;
client_guard.clone()
};
let (retry_count, global_interceptors, request, module_interceptors) = {
let config_guard = self.config.read().await;
let retry_count = if params.retry_count != 0 {
params.retry_count
} else {
config_guard.retry_count()
};
let mut request_builder = RequestBuilder::new(
reqwest::Method::GET,
(params.url_fn)(&config_guard).as_str(),
);
(params.builder_fn)(&config_guard, &mut request_builder);
HttpExecutor::apply_global_http_settings(&config_guard, &mut request_builder);
let request = request_builder.build();
let global_interceptors = config_guard.global_interceptors().clone();
let module_interceptors = params.module_interceptors;
(
retry_count,
global_interceptors,
request,
module_interceptors,
)
};
HttpExecutor::send_with_retries(
request,
retry_count,
global_interceptors,
module_interceptors,
client,
)
.await
}
}
impl HttpExecutor {
fn apply_global_http_settings(config: &Config, request_builder: &mut RequestBuilder) {
config.http().querys().iter().for_each(|(k, v)| {
if !request_builder.has_query(k) {
request_builder.query(k, v);
}
});
config.http().headers().iter().for_each(|(k, v)| {
if !request_builder.has_header(k) {
request_builder.header(k, v);
}
});
config.http().bodys().iter().for_each(|(k, v)| {
if !request_builder.has_body_field(k) {
request_builder.body_field(k, v.clone());
}
});
}
async fn apply_request_interceptors(
mut request: Request,
global_interceptors: &InterceptorChain,
module_interceptors: Option<&InterceptorChain>,
) -> Result<Request, OpenAIError> {
if !global_interceptors.is_empty() {
request = global_interceptors
.execute_request_interceptors(request)
.await?;
}
if let Some(module_chain) = module_interceptors {
request = module_chain.execute_request_interceptors(request).await?;
}
Ok(request)
}
async fn apply_response_interceptors(
mut response: Response,
module_interceptors: Option<&InterceptorChain>,
global_interceptors: &InterceptorChain,
) -> Result<Response, OpenAIError> {
if let Some(module_chain) = module_interceptors {
response = module_chain.execute_response_interceptors(response).await?;
}
if !global_interceptors.is_empty() {
response = global_interceptors
.execute_response_interceptors(response)
.await?;
}
Ok(response)
}
async fn apply_error_interceptors(
mut error: OpenAIError,
module_interceptors: Option<&InterceptorChain>,
global_interceptors: &InterceptorChain,
) -> Result<OpenAIError, OpenAIError> {
if let Some(module_chain) = module_interceptors {
error = module_chain.execute_error_interceptors(error).await?;
}
if !global_interceptors.is_empty() {
error = global_interceptors
.execute_error_interceptors(error)
.await?;
}
Ok(error)
}
async fn send_with_retries(
request: Request,
retry_count: u32,
global_interceptors: InterceptorChain,
module_interceptors: Option<InterceptorChain>,
client: reqwest::Client,
) -> Result<Response, OpenAIError> {
let mut attempts = 0;
let max_attempts = retry_count.max(1);
let processed_request = HttpExecutor::apply_request_interceptors(
request,
&global_interceptors,
module_interceptors.as_ref(),
)
.await?;
loop {
attempts += 1;
let request_builder = processed_request.to_reqwest(&client);
match request_builder.send().await {
Ok(response) => {
let retry_after = response
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
.map(Duration::from_secs);
if response.status().is_success() {
let processed_response = HttpExecutor::apply_response_interceptors(
response,
module_interceptors.as_ref(),
&global_interceptors,
)
.await?;
return Ok(processed_response);
} else {
let api_error = ApiError::async_from(response).await;
if attempts >= max_attempts || !api_error.is_retryable() {
let error = HttpExecutor::apply_error_interceptors(
api_error.into(),
module_interceptors.as_ref(),
&global_interceptors,
)
.await?;
return Err(error);
}
tracing::debug!(
"Attempt {}/{}: Retrying after API error: {:?}",
attempts,
max_attempts,
api_error
);
tokio::time::sleep(calculate_retry_delay(
attempts,
&api_error.kind,
retry_after,
))
.await;
}
}
Err(e) => {
let request_error: RequestError = e.into();
if attempts >= max_attempts || !request_error.is_retryable() {
let error = HttpExecutor::apply_error_interceptors(
request_error.into(),
module_interceptors.as_ref(),
&global_interceptors,
)
.await?;
return Err(error);
}
tracing::debug!(
"Attempt {}/{}: Retrying after request error: {:?}",
attempts,
max_attempts,
request_error
);
tokio::time::sleep(calculate_retry_delay_for_request_error(
attempts,
&request_error,
))
.await;
}
}
}
}
}
/// Computes how long to wait before retrying after an API (HTTP status) error.
///
/// If the server supplied a `Retry-After` delay, that delay is used plus
/// 0–999 ms of random jitter; the addition saturates instead of panicking on
/// absurdly large header values.
///
/// Otherwise exponential backoff is used:
/// - the base delay depends on `error_kind`: 5 s for `RateLimit`, 1 s for
///   `InternalServer`, 500 ms for anything else;
/// - it doubles with each attempt and is capped at 30 s;
/// - 0–9 % random jitter is added on top.
///
/// `attempt` is 1-based; `0` is treated like `1`. The arithmetic saturates, so a
/// large configured retry count reaches the cap instead of overflowing.
fn calculate_retry_delay(
    attempt: u32,
    error_kind: &ApiErrorKind,
    retry_after: Option<Duration>,
) -> Duration {
    if let Some(duration) = retry_after {
        let jitter = Duration::from_millis(rand::thread_rng().gen_range(0..1000));
        // Retry-After comes from the server, so it is untrusted; `+` on
        // Duration panics on overflow.
        return duration.saturating_add(jitter);
    }
    let base_delay_ms: u64 = match error_kind {
        ApiErrorKind::RateLimit => 5000,
        ApiErrorKind::InternalServer => 1000,
        _ => 500,
    };
    // `2^(attempt-1)` and the product overflow u64 for large attempt numbers
    // (a panic in debug builds), so saturate and let the cap take over.
    let multiplier = 2u64.saturating_pow(attempt.saturating_sub(1));
    let delay_ms = base_delay_ms.saturating_mul(multiplier).min(30_000);
    let jitter_percent: u64 = rand::thread_rng().gen_range(0..10);
    let jitter_ms = delay_ms * jitter_percent / 100;
    Duration::from_millis(delay_ms + jitter_ms)
}
/// Computes how long to wait before retrying after a transport-level error.
///
/// Uses exponential backoff:
/// - the base delay depends on the error variant: 200 ms for `Connection`,
///   100 ms for `Timeout` and anything else;
/// - it doubles with each attempt and is capped at 10 s;
/// - 0–9 % random jitter is added on top.
///
/// `attempt` is 1-based; `0` is treated like `1`. The arithmetic saturates, so a
/// large configured retry count reaches the cap instead of overflowing.
fn calculate_retry_delay_for_request_error(attempt: u32, error: &RequestError) -> Duration {
    let base_delay_ms: u64 = match error {
        RequestError::Timeout(_) => 100,
        RequestError::Connection(_) => 200,
        _ => 100,
    };
    // `2^(attempt-1)` and the product overflow u64 for large attempt numbers
    // (a panic in debug builds), so saturate and let the cap take over.
    let multiplier = 2u64.saturating_pow(attempt.saturating_sub(1));
    let delay_ms = base_delay_ms.saturating_mul(multiplier).min(10_000);
    // Same jitter source as `calculate_retry_delay`, without modulo bias.
    let jitter_percent: u64 = rand::thread_rng().gen_range(0..10);
    let jitter_ms = delay_ms * jitter_percent / 100;
    Duration::from_millis(delay_ms + jitter_ms)
}