use http::Method;
use reqwest::Request;
use std::sync::Arc;
use tracing::debug;
use crate::{
CircuitBreaker, HttpClientConfig, HttpClientError, RequestBuilder, Response, Result,
RetryStrategy,
};
/// HTTP client that bundles a `reqwest::Client` with the configuration it was
/// built from and an optional circuit breaker.
///
/// The configuration and the circuit breaker are held behind `Arc`, so clones
/// of this struct share the same instances of both.
#[derive(Clone)]
pub struct HttpClient {
/// Underlying reqwest client used to send requests.
inner: reqwest::Client,
/// Configuration this client was constructed from.
config: Arc<HttpClientConfig>,
/// Circuit breaker consulted before executing a request; `None` when the
/// configuration does not define one.
circuit_breaker: Option<Arc<CircuitBreaker>>,
}
impl HttpClient {
    /// Builds a client from `config`.
    ///
    /// Timeouts, connection-pool limits, user agent, response decompression and
    /// the redirect policy are all taken from `config`. A circuit breaker is
    /// created only when `config.circuit_breaker` is set.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest::Client` cannot be built.
    pub fn new(config: HttpClientConfig) -> Self {
        let redirect_policy = if config.follow_redirects {
            reqwest::redirect::Policy::limited(config.max_redirects)
        } else {
            reqwest::redirect::Policy::none()
        };

        // reqwest turns gzip/brotli decoding on by default when the matching
        // crate feature is compiled in, so the flags are passed unconditionally:
        // only calling the setters for `true` would silently ignore `false`.
        let inner = reqwest::Client::builder()
            .timeout(config.timeout)
            .connect_timeout(config.connect_timeout)
            .pool_idle_timeout(config.pool_idle_timeout)
            .pool_max_idle_per_host(config.pool_max_idle_per_host)
            .user_agent(&config.user_agent)
            .gzip(config.gzip)
            .brotli(config.brotli)
            .redirect(redirect_policy)
            .build()
            .expect("Failed to build HTTP client");

        let circuit_breaker = config
            .circuit_breaker
            .as_ref()
            .map(|cb_config| Arc::new(CircuitBreaker::new(cb_config.clone())));

        Self {
            inner,
            config: Arc::new(config),
            circuit_breaker,
        }
    }

    /// Builds a client from `HttpClientConfig::default()`.
    pub fn default_client() -> Self {
        Self::new(HttpClientConfig::default())
    }

    /// Returns the wrapped `reqwest::Client`.
    pub fn inner(&self) -> &reqwest::Client {
        &self.inner
    }

    /// Returns the configuration this client was built from.
    pub fn config(&self) -> &HttpClientConfig {
        &self.config
    }

    /// Starts building a `GET` request for `url`.
    pub fn get(&self, url: impl Into<String>) -> RequestBuilder<'_> {
        RequestBuilder::new(self, Method::GET, url.into())
    }

    /// Starts building a `POST` request for `url`.
    pub fn post(&self, url: impl Into<String>) -> RequestBuilder<'_> {
        RequestBuilder::new(self, Method::POST, url.into())
    }

    /// Starts building a `PUT` request for `url`.
    pub fn put(&self, url: impl Into<String>) -> RequestBuilder<'_> {
        RequestBuilder::new(self, Method::PUT, url.into())
    }

    /// Starts building a `PATCH` request for `url`.
    pub fn patch(&self, url: impl Into<String>) -> RequestBuilder<'_> {
        RequestBuilder::new(self, Method::PATCH, url.into())
    }

    /// Starts building a `DELETE` request for `url`.
    pub fn delete(&self, url: impl Into<String>) -> RequestBuilder<'_> {
        RequestBuilder::new(self, Method::DELETE, url.into())
    }

    /// Starts building a `HEAD` request for `url`.
    pub fn head(&self, url: impl Into<String>) -> RequestBuilder<'_> {
        RequestBuilder::new(self, Method::HEAD, url.into())
    }

    /// Starts building a request for `url` with an arbitrary `method`.
    pub fn request(&self, method: Method, url: impl Into<String>) -> RequestBuilder<'_> {
        RequestBuilder::new(self, method, url.into())
    }

    /// Sends `request`, applying the circuit breaker and (if configured) the
    /// retry policy.
    ///
    /// # Errors
    ///
    /// Returns `HttpClientError::CircuitOpen` without sending anything when the
    /// circuit breaker rejects the call; otherwise returns whatever the single
    /// attempt or the retry loop produces.
    pub(crate) async fn execute(&self, request: Request) -> Result<Response> {
        if let Some(cb) = &self.circuit_breaker {
            if !cb.is_allowed() {
                return Err(HttpClientError::CircuitOpen);
            }
        }

        match &self.config.retry {
            Some(retry_config) => self.execute_with_retry(request, retry_config).await,
            None => {
                // The breaker must see outcomes on this path as well; otherwise
                // a breaker configured without retries could never trip.
                let outcome = self.execute_once(request).await;
                self.record_outcome(outcome.is_ok());
                outcome
            }
        }
    }

    /// Reports the result of one attempt to the circuit breaker, if any.
    fn record_outcome(&self, succeeded: bool) {
        if let Some(cb) = &self.circuit_breaker {
            if succeeded {
                cb.record_success();
            } else {
                cb.record_failure();
            }
        }
    }

    /// Sends `request` repeatedly according to `retry_config`.
    ///
    /// Each attempt sends a copy produced by `clone_request`, and its outcome is
    /// reported to the circuit breaker. An attempt is retried when the policy
    /// marks the response status or the error as retriable and fewer than
    /// `max_attempts` attempts have been made. Between attempts the task sleeps
    /// for `retry_config.delay_for_attempt(..)`.
    ///
    /// # Errors
    ///
    /// * The error of the last attempt when it is not retriable or the attempt
    ///   limit is reached.
    /// * `HttpClientError::RetryExhausted` when `max_retry_time` has elapsed
    ///   before another retry could start.
    async fn execute_with_retry(
        &self,
        request: Request,
        retry_config: &crate::RetryConfig,
    ) -> Result<Response> {
        // Number of attempts completed so far.
        let mut attempt = 0;
        let mut last_error: Option<HttpClientError> = None;
        let start = std::time::Instant::now();

        loop {
            // The time budget limits retries only; the first attempt is always
            // made, even with a zero budget.
            if attempt > 0 {
                if let Some(max_time) = retry_config.max_retry_time {
                    if start.elapsed() > max_time {
                        break;
                    }
                }
            }

            let failure = match self.execute_once(clone_request(&request)).await {
                Ok(response) => {
                    // NOTE(review): a response with a retriable status is still
                    // recorded as a circuit-breaker success — confirm intended.
                    self.record_outcome(true);

                    // `attempt + 1 < max_attempts` rather than
                    // `attempt < max_attempts - 1`: the latter underflows when
                    // `max_attempts` is zero.
                    if !(retry_config.should_retry_status(response.status().as_u16())
                        && attempt + 1 < retry_config.max_attempts)
                    {
                        return Ok(response);
                    }

                    debug!(
                        attempt = attempt + 1,
                        status = %response.status(),
                        "Retrying request due to status code"
                    );
                    HttpClientError::Response {
                        status: response.status().as_u16(),
                        message: "Retriable status code".to_string(),
                    }
                }
                Err(e) => {
                    self.record_outcome(false);

                    if !(retry_config.should_retry(attempt, &e)
                        && attempt + 1 < retry_config.max_attempts)
                    {
                        return Err(e);
                    }

                    debug!(
                        attempt = attempt + 1,
                        error = %e,
                        "Retrying request due to error"
                    );
                    e
                }
            };

            // Shared back-off tail for both retriable outcomes.
            last_error = Some(failure);
            attempt += 1;
            tokio::time::sleep(retry_config.delay_for_attempt(attempt)).await;
        }

        // Reached only through the time-budget `break`, where exactly `attempt`
        // sends have completed.
        Err(HttpClientError::RetryExhausted {
            attempts: attempt,
            message: last_error
                .map(|e| e.to_string())
                .unwrap_or_else(|| "Unknown error".to_string()),
        })
    }

    /// Sends `request` exactly once through the inner reqwest client and
    /// converts the reply with `Response::from_reqwest`.
    async fn execute_once(&self, request: Request) -> Result<Response> {
        let response = self.inner.execute(request).await?;
        Ok(Response::from_reqwest(response).await)
    }
}
fn clone_request(request: &Request) -> Request {
let mut builder = reqwest::Request::new(request.method().clone(), request.url().clone());
*builder.headers_mut() = request.headers().clone();
builder
}
impl Default for HttpClient {
fn default() -> Self {
Self::default_client()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A default client reports both gzip and brotli as enabled in its config.
    #[test]
    fn test_client_creation() {
        let client = HttpClient::default();
        let cfg = client.config();
        assert!(cfg.gzip);
        assert!(cfg.brotli);
    }

    /// Values set through the config builder are visible via `config()`.
    #[test]
    fn test_client_with_config() {
        let expected_timeout = Duration::from_secs(60);
        let expected_base = "https://api.example.com";

        let client = HttpClient::new(
            HttpClientConfig::builder()
                .timeout(expected_timeout)
                .base_url(expected_base)
                .build(),
        );

        let cfg = client.config();
        assert_eq!(cfg.timeout, expected_timeout);
        assert_eq!(cfg.base_url.as_deref(), Some(expected_base));
    }
}