use std::time::Duration;
use async_trait::async_trait;
use rand::{RngExt, rng};
use reqwest::Client;
use tokio::time::sleep;
use crate::error::{WxPayError, WxPayResult};
/// Minimal asynchronous HTTP transport abstraction used by this crate.
///
/// Every method issues one logical request to `url`:
/// - `headers` is a list of name/value pairs to attach.
/// - `body`, for the methods that take one, is sent as the request body.
///
/// Each method returns the response as an [`HttpResponse`].
///
/// Behavior of the in-module implementation, `ReqwestHttpClient`:
/// - A non-2xx status is returned as `Ok`; callers should inspect
///   [`HttpResponse::is_success`].
/// - Whether a request is retried is decided by that implementation, not by this trait.
///
/// The `Send + Sync` supertraits let implementors be shared across threads.
#[async_trait]
pub trait HttpClient: Send + Sync {
/// Sends a GET request to `url` with `headers`.
async fn get(&self, url: &str, headers: Vec<(String, String)>) -> WxPayResult<HttpResponse>;
/// Sends a POST request to `url` with `headers` and `body`.
async fn post(
&self,
url: &str,
headers: Vec<(String, String)>,
body: &str,
) -> WxPayResult<HttpResponse>;
/// Sends a PUT request to `url` with `headers` and `body`.
async fn put(
&self,
url: &str,
headers: Vec<(String, String)>,
body: &str,
) -> WxPayResult<HttpResponse>;
/// Sends a DELETE request to `url` with `headers`.
async fn delete(&self, url: &str, headers: Vec<(String, String)>) -> WxPayResult<HttpResponse>;
/// Sends a PATCH request to `url` with `headers` and `body`.
async fn patch(
&self,
url: &str,
headers: Vec<(String, String)>,
body: &str,
) -> WxPayResult<HttpResponse>;
}
/// A completely read HTTP response: status code, headers and body text.
#[derive(Debug, Clone)]
pub struct HttpResponse {
    /// Numeric HTTP status code (e.g. `200`, `404`).
    pub status: u16,
    /// Header name/value pairs; a name may appear more than once.
    pub headers: Vec<(String, String)>,
    /// Response body as a string.
    pub body: String,
}

impl HttpResponse {
    /// Bundles the given status, headers and body into a response value.
    pub fn new(status: u16, headers: Vec<(String, String)>, body: String) -> Self {
        Self { status, headers, body }
    }

    /// Looks up a header by name, ignoring ASCII case.
    ///
    /// When the name occurs several times, the value of the first matching entry is
    /// returned. Returns `None` if no entry matches.
    pub fn get_header(&self, name: &str) -> Option<&str> {
        for (key, value) in &self.headers {
            if key.eq_ignore_ascii_case(name) {
                return Some(value.as_str());
            }
        }
        None
    }

    /// `true` for any 2xx status (200 through 299 inclusive).
    pub fn is_success(&self) -> bool {
        matches!(self.status, 200..=299)
    }
}
/// [`HttpClient`] implementation backed by a `reqwest::Client`.
///
/// Construct it with [`ReqwestHttpClient::builder`].
pub struct ReqwestHttpClient {
    /// Underlying reqwest client; its pool settings come from the builder.
    client: Client,
    /// Maximum number of retries after the first attempt, for requests that allow retrying.
    max_retries: u32,
}

impl ReqwestHttpClient {
    /// Returns a builder pre-populated with the default settings.
    pub fn builder() -> ReqwestHttpClientBuilder {
        ReqwestHttpClientBuilder::new()
    }

    /// Whether a response status is worth retrying: 429 (Too Many Requests) or any 5xx.
    fn is_retriable_status(status: u16) -> bool {
        status == 429 || (500..=599).contains(&status)
    }

    /// Backoff delay, in milliseconds, before retry number `retry_count`.
    ///
    /// The base is `40 * 2^min(retry_count, 8)` ms, plus a uniformly random jitter of
    /// up to half the base. The exponent clamp bounds the delay at 15 360 ms and keeps
    /// the shift from overflowing.
    fn retry_delay_ms(retry_count: u32) -> u64 {
        let base = 40_u64.saturating_mul(1_u64 << retry_count.min(8));
        let jitter = rng().random_range(0..=base / 2);
        base.saturating_add(jitter)
    }

    /// Reads the status, headers and full body text of `response`.
    ///
    /// Header values are decoded lossily: bytes that are not valid UTF-8 become U+FFFD
    /// instead of the whole value being discarded.
    ///
    /// # Errors
    ///
    /// Returns [`WxPayError::ResponseParseError`] if the body cannot be read.
    async fn read_response(response: reqwest::Response) -> WxPayResult<HttpResponse> {
        let status = response.status().as_u16();
        let response_headers: Vec<(String, String)> = response
            .headers()
            .iter()
            .map(|(k, v)| {
                (
                    k.to_string(),
                    String::from_utf8_lossy(v.as_bytes()).into_owned(),
                )
            })
            .collect();
        let body = response
            .text()
            .await
            .map_err(|e| WxPayError::ResponseParseError(format!("读取响应体失败:{}", e)))?;
        Ok(HttpResponse::new(status, response_headers, body))
    }

    /// Sends the request produced by `request_factory`, retrying when allowed.
    ///
    /// `request_factory` is called once per attempt, because `send()` consumes the
    /// builder.
    ///
    /// When `retry_on_error` is `true`, up to `max_retries` extra attempts are made.
    /// Each retry is preceded by a [`Self::retry_delay_ms`] backoff. A retry happens:
    /// - on a transport error, except a request-construction (builder) error, which
    ///   would fail identically every time;
    /// - on a 429 or 5xx status.
    ///
    /// Any other status, or the response of the last permitted attempt, is returned
    /// as `Ok`.
    ///
    /// # Errors
    ///
    /// - [`WxPayError::NetworkError`] when the final attempt fails at the transport
    ///   level or the request cannot be built.
    /// - [`WxPayError::ResponseParseError`] when a response body cannot be read. This
    ///   error is not retried.
    async fn execute_with_retries<F>(
        &self,
        mut request_factory: F,
        retry_on_error: bool,
    ) -> WxPayResult<HttpResponse>
    where
        F: FnMut() -> reqwest::RequestBuilder,
    {
        let mut attempt: u32 = 0;
        loop {
            // Another attempt is possible only while the retry budget is not exhausted.
            let can_retry = retry_on_error && attempt < self.max_retries;
            match request_factory().send().await {
                Ok(response) => {
                    let response = Self::read_response(response).await?;
                    if !(can_retry && Self::is_retriable_status(response.status)) {
                        return Ok(response);
                    }
                }
                Err(error) => {
                    // Builder errors (e.g. an invalid header name or value) come from the
                    // factory itself; re-sending the same request cannot succeed.
                    if !can_retry || error.is_builder() {
                        return Err(WxPayError::NetworkError(error));
                    }
                }
            }
            attempt += 1;
            sleep(Duration::from_millis(Self::retry_delay_ms(attempt))).await;
        }
    }

    /// Appends every `(name, value)` pair in `headers` to `request`, in order.
    ///
    /// Invalid names or values are not reported here. reqwest records the error inside
    /// the builder, and it surfaces from `send()`/`build()`.
    fn append_headers(
        request: reqwest::RequestBuilder,
        headers: &[(String, String)],
    ) -> reqwest::RequestBuilder {
        headers
            .iter()
            .fold(request, |request, (name, value)| request.header(name, value))
    }
}
/// Configures and creates a [`ReqwestHttpClient`].
///
/// Durations are whole seconds. Defaults:
/// - request timeout: 30 s;
/// - idle connections kept per host: 100;
/// - idle-connection timeout: 90 s;
/// - retries: 3.
#[derive(Debug, Clone)]
pub struct ReqwestHttpClientBuilder {
    /// Request timeout in seconds, passed to reqwest's `ClientBuilder::timeout`.
    timeout: u64,
    /// Passed to reqwest's `ClientBuilder::pool_max_idle_per_host`, i.e. it is a
    /// per-host limit.
    max_idle_connections: usize,
    /// How long an idle pooled connection is kept, in seconds.
    idle_timeout: u64,
    /// Extra attempts allowed after the first one, for requests that permit retrying.
    max_retries: u32,
}

impl ReqwestHttpClientBuilder {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    const DEFAULT_MAX_IDLE_PER_HOST: usize = 100;
    const DEFAULT_IDLE_TIMEOUT_SECS: u64 = 90;
    const DEFAULT_MAX_RETRIES: u32 = 3;

    /// Creates a builder holding the default settings.
    pub fn new() -> Self {
        Self {
            timeout: Self::DEFAULT_TIMEOUT_SECS,
            max_idle_connections: Self::DEFAULT_MAX_IDLE_PER_HOST,
            idle_timeout: Self::DEFAULT_IDLE_TIMEOUT_SECS,
            max_retries: Self::DEFAULT_MAX_RETRIES,
        }
    }

    /// Sets the request timeout, in seconds.
    pub fn timeout(self, timeout: u64) -> Self {
        Self { timeout, ..self }
    }

    /// Sets how many idle connections are kept per host.
    pub fn max_idle_connections(self, max_idle_connections: usize) -> Self {
        Self {
            max_idle_connections,
            ..self
        }
    }

    /// Sets how long idle pooled connections are kept, in seconds.
    pub fn idle_timeout(self, idle_timeout: u64) -> Self {
        Self {
            idle_timeout,
            ..self
        }
    }

    /// Sets the number of retries allowed after the first attempt.
    pub fn max_retries(self, max_retries: u32) -> Self {
        Self {
            max_retries,
            ..self
        }
    }

    /// Builds the underlying reqwest client and wraps it in a [`ReqwestHttpClient`].
    ///
    /// # Errors
    ///
    /// Returns [`WxPayError::InternalError`] if reqwest fails to build the client.
    pub fn build(self) -> WxPayResult<ReqwestHttpClient> {
        let Self {
            timeout,
            max_idle_connections,
            idle_timeout,
            max_retries,
        } = self;
        let client = Client::builder()
            .timeout(Duration::from_secs(timeout))
            .pool_max_idle_per_host(max_idle_connections)
            .pool_idle_timeout(Duration::from_secs(idle_timeout))
            .build()
            .map_err(|e| WxPayError::InternalError(format!("创建 HTTP 客户端失败:{}", e)))?;
        Ok(ReqwestHttpClient {
            client,
            max_retries,
        })
    }
}

impl Default for ReqwestHttpClientBuilder {
    /// Same as [`ReqwestHttpClientBuilder::new`].
    fn default() -> Self {
        ReqwestHttpClientBuilder::new()
    }
}
/// [`HttpClient`] implementation that delegates every verb to `execute_with_retries`.
///
/// Retry policy by verb:
/// - GET and DELETE pass `retry_on_error = true`, so they may be re-sent, up to the
///   client's `max_retries`.
/// - POST, PUT and PATCH pass `retry_on_error = false`, so they are attempted exactly
///   once.
///
/// The body-carrying methods copy `body` into an owned `String`, because the request
/// factory closure may be invoked once per attempt.
#[async_trait]
impl HttpClient for ReqwestHttpClient {
    async fn get(&self, url: &str, headers: Vec<(String, String)>) -> WxPayResult<HttpResponse> {
        let make_request = || Self::append_headers(self.client.get(url), &headers);
        self.execute_with_retries(make_request, true).await
    }

    async fn post(
        &self,
        url: &str,
        headers: Vec<(String, String)>,
        body: &str,
    ) -> WxPayResult<HttpResponse> {
        let payload = body.to_owned();
        let make_request = move || {
            let builder = self.client.post(url).body(payload.clone());
            Self::append_headers(builder, &headers)
        };
        self.execute_with_retries(make_request, false).await
    }

    async fn put(
        &self,
        url: &str,
        headers: Vec<(String, String)>,
        body: &str,
    ) -> WxPayResult<HttpResponse> {
        let payload = body.to_owned();
        let make_request = move || {
            let builder = self.client.put(url).body(payload.clone());
            Self::append_headers(builder, &headers)
        };
        self.execute_with_retries(make_request, false).await
    }

    async fn delete(&self, url: &str, headers: Vec<(String, String)>) -> WxPayResult<HttpResponse> {
        let make_request = || Self::append_headers(self.client.delete(url), &headers);
        self.execute_with_retries(make_request, true).await
    }

    async fn patch(
        &self,
        url: &str,
        headers: Vec<(String, String)>,
        body: &str,
    ) -> WxPayResult<HttpResponse> {
        let payload = body.to_owned();
        let make_request = move || {
            let builder = self.client.patch(url).body(payload.clone());
            Self::append_headers(builder, &headers)
        };
        self.execute_with_retries(make_request, false).await
    }
}
/// Debug output shows the configured retry budget.
///
/// The inner reqwest `Client` is omitted; `finish_non_exhaustive` renders the omission
/// as `..` so readers know fields are hidden.
impl std::fmt::Debug for ReqwestHttpClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReqwestHttpClient")
            .field("max_retries", &self.max_retries)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_http_response() {
        let response = HttpResponse::new(
            200,
            vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("X-Request-Id".to_string(), "12345".to_string()),
            ],
            r#"{"code":"SUCCESS"}"#.to_string(),
        );
        assert!(response.is_success());
        assert_eq!(response.status, 200);
        assert_eq!(
            response.get_header("Content-Type"),
            Some("application/json")
        );
        assert_eq!(response.get_header("X-Request-Id"), Some("12345"));
        assert_eq!(response.get_header("Non-Existent"), None);
    }

    #[test]
    fn test_http_response_not_success() {
        let response = HttpResponse::new(400, vec![], r#"{"code":"PARAM_ERROR"}"#.to_string());
        assert!(!response.is_success());
    }

    #[test]
    fn test_reqwest_http_client_builder() {
        let builder = ReqwestHttpClientBuilder::new()
            .timeout(60)
            .max_idle_connections(50)
            .idle_timeout(120)
            .max_retries(3);
        assert_eq!(builder.timeout, 60);
        assert_eq!(builder.max_idle_connections, 50);
        assert_eq!(builder.idle_timeout, 120);
        assert_eq!(builder.max_retries, 3);
    }

    #[test]
    fn test_reqwest_http_client_builder_default() {
        let builder = ReqwestHttpClientBuilder::default();
        assert_eq!(builder.timeout, 30);
        assert_eq!(builder.max_idle_connections, 100);
        assert_eq!(builder.idle_timeout, 90);
        assert_eq!(builder.max_retries, 3);
    }

    #[test]
    fn test_is_retriable_status() {
        assert!(ReqwestHttpClient::is_retriable_status(429));
        assert!(ReqwestHttpClient::is_retriable_status(500));
        assert!(ReqwestHttpClient::is_retriable_status(502));
        assert!(ReqwestHttpClient::is_retriable_status(599));
        assert!(!ReqwestHttpClient::is_retriable_status(200));
        assert!(!ReqwestHttpClient::is_retriable_status(301));
        assert!(!ReqwestHttpClient::is_retriable_status(400));
        assert!(!ReqwestHttpClient::is_retriable_status(401));
        assert!(!ReqwestHttpClient::is_retriable_status(404));
    }

    /// Each delay lies in `[base, base + base / 2]`, where `base = 40 * 2^min(n, 8)` ms.
    #[test]
    fn test_retry_delay_is_bounded_and_increasing() {
        let cases = [
            (0_u32, 40_u64),
            (1, 80),
            (3, 320),
            (8, 10_240),
            // The exponent is clamped at 8, so huge retry counts neither overflow nor grow.
            (u32::MAX, 10_240),
        ];
        for &(retry_count, base) in &cases {
            let delay = ReqwestHttpClient::retry_delay_ms(retry_count);
            assert!(
                (base..=base + base / 2).contains(&delay),
                "retry {}: delay {} outside [{}, {}]",
                retry_count,
                delay,
                base,
                base + base / 2
            );
        }
        // The jitter ranges of increasing exponents never overlap, so growth is guaranteed.
        assert!(ReqwestHttpClient::retry_delay_ms(1) > ReqwestHttpClient::retry_delay_ms(0));
        assert!(ReqwestHttpClient::retry_delay_ms(3) > ReqwestHttpClient::retry_delay_ms(1));
    }

    /// `append_headers` must attach every supplied header to the built request.
    #[test]
    fn test_append_headers_applies_all() {
        let headers = vec![
            ("Authorization".to_string(), "Bearer x".to_string()),
            ("Wechatpay-Serial".to_string(), "SERIAL123".to_string()),
            ("X-Request-Id".to_string(), "12345".to_string()),
        ];
        let request = ReqwestHttpClient::append_headers(
            reqwest::Client::new().get("https://example.com/v3/pay"),
            &headers,
        )
        .build()
        .expect("valid headers must produce a buildable request");
        for (name, value) in &headers {
            let actual = request
                .headers()
                .get(name.as_str())
                .and_then(|v| v.to_str().ok());
            assert_eq!(actual, Some(value.as_str()), "header {}", name);
        }
    }

    /// Invalid header names are deferred by reqwest and surface as builder errors.
    #[test]
    fn test_append_headers_defers_invalid_header_error() {
        let headers = vec![("bad header".to_string(), "x".to_string())];
        let error = ReqwestHttpClient::append_headers(
            reqwest::Client::new().get("https://example.com"),
            &headers,
        )
        .build()
        .expect_err("a header name containing a space is invalid");
        assert!(error.is_builder());
    }
}