use crate::auth::AuthMethod;
use crate::network::{Http2Config, TlsConfig};
use crate::stats::{RequestResult, StatsCollector};
use crate::user_agent::UserAgentManager;
use anyhow::Result;
use chrono::Utc;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::{Client, Method};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// HTTP client bound to a single target URL and request shape.
///
/// Wraps a pre-configured [`Client`] together with everything needed to build
/// each request and to record its outcome in a shared [`StatsCollector`].
pub struct HttpClient {
    // ---- transport ----
    /// Underlying reqwest client, configured once at construction.
    client: Client,
    /// HTTP/2 settings applied to the client builder at construction.
    http2_config: Arc<Http2Config>,
    /// TLS settings applied to the client builder at construction.
    tls_config: Arc<TlsConfig>,

    // ---- request shape ----
    /// Target URL; auth query parameters may be appended to it per request.
    url: String,
    /// HTTP method used for every request.
    method: Method,
    /// Extra headers parsed from `"Name: value"` strings at construction.
    headers: HeaderMap,
    /// Request body; only attached for POST, PUT and PATCH.
    payload: Option<String>,

    // ---- shared collaborators ----
    /// Source of the `User-Agent` header value.
    user_agent_manager: Arc<UserAgentManager>,
    /// Receives one result per logical request.
    stats_collector: Arc<StatsCollector>,
    /// Supplies optional auth query parameters and/or an auth header.
    auth_method: Arc<AuthMethod>,

    // ---- behaviour knobs ----
    /// Extra attempts made after a transport-level failure.
    retry_count: u32,
    /// Substring a non-error response body must contain, if set.
    expect_body: Option<String>,
}
impl HttpClient {
#[allow(clippy::too_many_arguments)]
pub fn new(
url: String,
method: Method,
headers: Vec<String>,
payload: Option<String>,
user_agent_manager: Arc<UserAgentManager>,
stats_collector: Arc<StatsCollector>,
timeout: Option<Duration>,
auth_method: Arc<AuthMethod>,
http2_config: Arc<Http2Config>,
tls_config: Arc<TlsConfig>,
follow_redirects: bool,
retry_count: u32,
expect_body: Option<String>,
) -> Result<Self> {
let mut client_builder = Client::builder();
if let Some(timeout) = timeout {
client_builder = client_builder.timeout(timeout);
}
if !follow_redirects {
client_builder = client_builder.redirect(reqwest::redirect::Policy::none());
}
client_builder = http2_config.apply_to_client_builder(client_builder);
client_builder = tls_config.apply_to_client_builder(client_builder)?;
let client = client_builder.build()?;
let mut header_map = HeaderMap::new();
for header in headers {
if let Some((key, value)) = header.split_once(':') {
let key = key.trim();
let value = value.trim();
if let (Ok(name), Ok(val)) = (
HeaderName::from_bytes(key.as_bytes()),
HeaderValue::from_str(value),
) {
header_map.insert(name, val);
}
}
}
Ok(Self {
client,
url,
method,
headers: header_map,
payload,
user_agent_manager,
stats_collector,
auth_method,
http2_config,
tls_config,
retry_count,
expect_body,
})
}
pub async fn send_request(&self) -> Result<()> {
let timestamp = Utc::now();
let user_agent = self.user_agent_manager.get_user_agent().to_string();
let mut url = self.url.clone();
if let Some((key, value)) = self.auth_method.get_query_params().await {
let separator = if url.contains('?') { "&" } else { "?" };
url = format!("{url}{separator}{key}={value}");
}
let max_attempts = self.retry_count + 1;
for attempt in 0..max_attempts {
let mut builder = self.client.request(self.method.clone(), &url);
builder = builder.header("User-Agent", &user_agent);
if let Some((key, value)) = self.auth_method.get_auth_header().await {
builder = builder.header(key, value);
}
for (key, value) in &self.headers {
builder = builder.header(key, value);
}
if let Some(payload) = &self.payload {
if self.method == Method::POST
|| self.method == Method::PUT
|| self.method == Method::PATCH
{
builder = builder.body(payload.clone());
if !self.headers.contains_key("content-type")
&& (payload.trim_start().starts_with('{')
|| payload.trim_start().starts_with('['))
{
builder = builder.header("Content-Type", "application/json");
}
}
}
let start = Instant::now();
let request = builder.build()?;
match self.client.execute(request).await {
Ok(response) => {
let status_code = response.status().as_u16();
let duration = start.elapsed().as_millis() as u64;
let mut error = if status_code >= 400 {
Some(format!(
"HTTP {}: {}",
status_code,
response.status().canonical_reason().unwrap_or("Unknown")
))
} else {
None
};
let body_bytes = response.bytes().await.unwrap_or_default();
let bytes_received = body_bytes.len() as u64;
if error.is_none() {
if let Some(expected) = &self.expect_body {
let body_str = String::from_utf8_lossy(&body_bytes);
if !body_str.contains(expected.as_str()) {
error = Some(format!(
"body validation failed: expected {:?} in response",
expected
));
}
}
}
let result = RequestResult {
timestamp,
duration_ms: duration,
status_code: Some(status_code),
error,
user_agent: Some(user_agent.clone()),
bytes_received,
};
self.stats_collector.record_request(result).await;
return Ok(());
}
Err(e) => {
if attempt + 1 < max_attempts {
continue;
}
let duration = start.elapsed().as_millis() as u64;
let err_str = e.to_string();
let result = RequestResult {
timestamp,
duration_ms: duration,
status_code: None,
error: Some(err_str),
user_agent: Some(user_agent.clone()),
bytes_received: 0,
};
self.stats_collector.record_request(result).await;
return Err(e.into());
}
}
}
unreachable!()
}
pub fn get_http2_config(&self) -> &Http2Config {
&self.http2_config
}
pub fn get_tls_config(&self) -> &TlsConfig {
&self.tls_config
}
pub fn get_protocol_info(&self) -> crate::network::ProtocolInfo {
self.http2_config.get_protocol_info()
}
pub fn get_tls_info(&self) -> crate::network::TlsInfo {
self.tls_config.get_summary()
}
}