use super::config::{HttpClientConfig, HttpClientConfigBuilder};
use super::error::HttpError;
use super::request::{CancellableRequest, Method, RequestBody, RequestBuilder};
use super::response::Response;
/// HTTP client that pairs a `reqwest::Client` with the [`HttpClientConfig`]
/// it was built from.
///
/// The stored configuration is also consulted per request (e.g. for the
/// default retry count and retry strategy when a request does not override them).
#[derive(Debug, Clone)]
pub struct HttpClient {
/// Underlying `reqwest` client that actually sends the requests.
inner: reqwest::Client,
/// Configuration `inner` was built from; also provides per-request defaults.
config: HttpClientConfig,
}
impl HttpClient {
    /// Creates a client from [`HttpClientConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns [`HttpError::ClientBuild`] if the underlying `reqwest` client
    /// cannot be built (see [`HttpClient::with_config`]).
    pub fn new() -> Result<Self, HttpError> {
        Self::with_config(HttpClientConfig::default())
    }

    /// Creates a client from an explicit configuration.
    ///
    /// Applies the connect/request timeouts, user agent, redirect policy,
    /// certificate-validation flag and optional HTTP/HTTPS proxies from
    /// `config`, then keeps `config` so per-request defaults can be read later.
    ///
    /// # Errors
    ///
    /// Returns [`HttpError::ClientBuild`] if a proxy URL is rejected or if
    /// `reqwest` fails to build the client.
    pub fn with_config(config: HttpClientConfig) -> Result<Self, HttpError> {
        let mut builder = reqwest::Client::builder()
            .connect_timeout(config.connect_timeout)
            .timeout(config.request_timeout)
            .user_agent(&config.user_agent)
            .redirect(if config.follow_redirects {
                reqwest::redirect::Policy::limited(config.max_redirects as usize)
            } else {
                reqwest::redirect::Policy::none()
            });
        // Disables TLS certificate verification; only enabled when the config opts in.
        if config.accept_invalid_certs {
            builder = builder.danger_accept_invalid_certs(true);
        }
        if let Some(proxy) = &config.proxy {
            if let Some(http) = &proxy.http {
                builder = builder.proxy(
                    reqwest::Proxy::http(http)
                        .map_err(|e| HttpError::ClientBuild(format!("HTTP 代理配置错误: {}", e)))?,
                );
            }
            if let Some(https) = &proxy.https {
                builder = builder.proxy(
                    reqwest::Proxy::https(https)
                        .map_err(|e| HttpError::ClientBuild(format!("HTTPS 代理配置错误: {}", e)))?,
                );
            }
        }
        let inner = builder
            .build()
            .map_err(|e| HttpError::ClientBuild(e.to_string()))?;
        Ok(Self { inner, config })
    }

    /// Returns a fresh [`HttpClientConfigBuilder`] for assembling a configuration.
    pub fn builder() -> HttpClientConfigBuilder {
        HttpClientConfigBuilder::new()
    }

    /// Returns the configuration this client was built with.
    pub fn config(&self) -> &HttpClientConfig {
        &self.config
    }

    /// Returns the underlying `reqwest::Client` for operations this wrapper does not cover.
    pub fn inner(&self) -> &reqwest::Client {
        &self.inner
    }

    /// Sends a `GET` request to `url` (with retries, see [`HttpClient::execute`]).
    pub async fn get(&self, url: &str) -> Result<Response, HttpError> {
        self.execute(RequestBuilder::new(Method::GET, url)).await
    }

    /// Sends a `POST` request whose body is `body` encoded via `RequestBuilder::json`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `RequestBuilder::json` if `body` cannot be
    /// encoded, otherwise any error from [`HttpClient::execute`].
    pub async fn post_json<T: serde::Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<Response, HttpError> {
        let builder = RequestBuilder::new(Method::POST, url).json(body)?;
        self.execute(builder).await
    }

    /// Sends a `POST` request whose body is `body` encoded via `RequestBuilder::form`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `RequestBuilder::form` if `body` cannot be
    /// encoded, otherwise any error from [`HttpClient::execute`].
    pub async fn post_form<T: serde::Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<Response, HttpError> {
        let builder = RequestBuilder::new(Method::POST, url).form(body)?;
        self.execute(builder).await
    }

    /// Sends a `PUT` request whose body is `body` encoded via `RequestBuilder::json`.
    ///
    /// # Errors
    ///
    /// Same as [`HttpClient::post_json`].
    pub async fn put_json<T: serde::Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<Response, HttpError> {
        let builder = RequestBuilder::new(Method::PUT, url).json(body)?;
        self.execute(builder).await
    }

    /// Sends a `DELETE` request to `url` (with retries, see [`HttpClient::execute`]).
    pub async fn delete(&self, url: &str) -> Result<Response, HttpError> {
        self.execute(RequestBuilder::new(Method::DELETE, url)).await
    }

    /// Sends a `PATCH` request whose body is `body` encoded via `RequestBuilder::json`.
    ///
    /// # Errors
    ///
    /// Same as [`HttpClient::post_json`].
    pub async fn patch_json<T: serde::Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<Response, HttpError> {
        let builder = RequestBuilder::new(Method::PATCH, url).json(body)?;
        self.execute(builder).await
    }

    /// Sends a `HEAD` request to `url` (with retries, see [`HttpClient::execute`]).
    pub async fn head(&self, url: &str) -> Result<Response, HttpError> {
        self.execute(RequestBuilder::new(Method::HEAD, url)).await
    }

    /// Starts a custom request. The returned builder is not tied to this
    /// client; pass it to [`HttpClient::execute`] to send it.
    pub fn request(&self, method: Method, url: &str) -> RequestBuilder {
        RequestBuilder::new(method, url)
    }

    /// Sends the request described by `builder`, applying the retry policy.
    ///
    /// The retry count and strategy come from `builder` when set, otherwise
    /// from the client configuration.
    pub async fn execute(&self, builder: RequestBuilder) -> Result<Response, HttpError> {
        self.execute_with_retry(builder).await
    }

    /// Like [`HttpClient::execute`], but aborts when the request's cancel token fires.
    ///
    /// The retrying execution is raced against `cancel_token.cancelled()`.
    /// If cancellation wins, the in-flight attempt (or retry sleep) is dropped
    /// and `Err(HttpError::Cancelled)` is returned.
    pub async fn execute_cancellable(
        &self,
        request: CancellableRequest,
    ) -> Result<Response, HttpError> {
        tokio::select! {
            result = self.execute_with_retry(request.builder) => result,
            _ = request.cancel_token.cancelled() => Err(HttpError::Cancelled),
        }
    }

    /// Runs `builder` until it yields a non-retryable outcome or retries are exhausted.
    ///
    /// An attempt counts as retryable when it returns a response for which
    /// `is_server_error()` holds, or an error for which `is_retryable()` holds.
    /// At most `max_retries` extra attempts are made. Between attempts the
    /// task sleeps for `retry_strategy.delay_for_attempt(attempt)`. When the
    /// strategy returns `None`, retrying stops and the last real outcome
    /// (response or error) is returned. The outcome is never replaced by a
    /// synthetic error.
    async fn execute_with_retry(&self, builder: RequestBuilder) -> Result<Response, HttpError> {
        let max_retries = builder.max_retries.unwrap_or(self.config.max_retries);
        // Borrow instead of cloning: the strategy is only read.
        let retry_strategy = builder
            .retry_strategy
            .as_ref()
            .unwrap_or(&self.config.retry_strategy);
        let mut attempts = 0;
        loop {
            attempts += 1;
            let outcome = self.execute_once(&builder).await;
            let retryable = match &outcome {
                Ok(resp) => resp.is_server_error(),
                Err(err) => err.is_retryable(),
            };
            if !retryable || attempts > max_retries {
                return outcome;
            }
            match retry_strategy.delay_for_attempt(attempts) {
                Some(delay) => tokio::time::sleep(delay).await,
                // The strategy declined another attempt: surface what actually happened.
                None => return outcome,
            }
        }
    }

    /// Performs a single HTTP exchange for `builder` without retrying.
    ///
    /// Builds the URL, copies headers and the optional per-request timeout,
    /// attaches the body, sends it, and converts the reply into a [`Response`].
    async fn execute_once(&self, builder: &RequestBuilder) -> Result<Response, HttpError> {
        let url = builder.build_url()?;
        let method = match builder.method {
            Method::GET => reqwest::Method::GET,
            Method::POST => reqwest::Method::POST,
            Method::PUT => reqwest::Method::PUT,
            Method::DELETE => reqwest::Method::DELETE,
            Method::PATCH => reqwest::Method::PATCH,
            Method::HEAD => reqwest::Method::HEAD,
            Method::OPTIONS => reqwest::Method::OPTIONS,
        };
        let mut req = self.inner.request(method, &url);
        // Header names are case-insensitive, so compare that way when looking for
        // a caller-supplied Content-Type.
        let mut has_content_type = false;
        for (key, value) in &builder.headers {
            has_content_type |= key.as_str().eq_ignore_ascii_case("content-type");
            req = req.header(key.as_str(), value.as_str());
        }
        if let Some(timeout) = builder.timeout {
            req = req.timeout(timeout);
        }
        // `reqwest::RequestBuilder::body` never sets Content-Type, so supply the
        // one implied by the body kind unless the caller already provided one.
        match &builder.body {
            RequestBody::None => {}
            RequestBody::Json(bytes) => {
                if !has_content_type {
                    req = req.header(reqwest::header::CONTENT_TYPE, "application/json");
                }
                req = req.body(bytes.clone());
            }
            RequestBody::Form(encoded) => {
                if !has_content_type {
                    req = req.header(
                        reqwest::header::CONTENT_TYPE,
                        "application/x-www-form-urlencoded",
                    );
                }
                req = req.body(encoded.clone());
            }
            RequestBody::Bytes(bytes) => {
                req = req.body(bytes.clone());
            }
            RequestBody::Text(text) => {
                req = req.body(text.clone());
            }
        }
        let resp = req.send().await?;
        Response::from_reqwest(resp).await
    }
}
impl Default for HttpClient {
fn default() -> Self {
Self::new().expect("创建默认 HTTP 客户端失败")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// The default configuration uses a 10-second connect timeout and 3 retries.
    #[test]
    fn test_client_config() {
        let defaults = HttpClientConfig::default();
        assert_eq!(defaults.max_retries, 3);
        assert_eq!(defaults.connect_timeout, Duration::from_secs(10));
    }

    /// `no_retry()` overrides an earlier `max_retries(5)`, while the timeout
    /// set through `timeout()` lands in `request_timeout`.
    #[test]
    fn test_builder_methods() {
        let config = HttpClient::builder()
            .timeout(Duration::from_secs(60))
            .max_retries(5)
            .no_retry()
            .build();
        assert_eq!(config.max_retries, 0);
        assert_eq!(config.request_timeout, Duration::from_secs(60));
    }
}