use crate::error::LlmConnectorError;
use reqwest::Client;
use serde::Serialize;
use std::collections::HashMap;
use std::time::Duration;
/// Thin wrapper around a reqwest `Client` that remembers a base URL and a set of
/// default headers attached to every request it sends.
///
/// Cloning copies the header map; reqwest documents its `Client` as reference-counted
/// internally, so the inner client itself is shared rather than rebuilt.
#[derive(Clone)]
pub struct HttpClient {
    /// Underlying reqwest client; carries the timeout/proxy configuration chosen at construction.
    client: Client,
    /// Base URL with trailing `/` characters removed. It is exposed through `base_url()`
    /// but is not joined onto the `url` arguments of the request methods.
    base_url: String,
    /// Default headers (name -> value) added to every request this client sends.
    headers: HashMap<String, String>,
}
impl HttpClient {
    /// Whole-request timeout, in seconds, used when the caller does not configure one.
    const DEFAULT_TIMEOUT_SECS: u64 = 60;

    /// Creates a client for `base_url` with a 60-second timeout and proxying disabled.
    ///
    /// Equivalent to `HttpClient::with_config(base_url, None, None)`.
    ///
    /// # Errors
    ///
    /// Returns `LlmConnectorError::ConfigError` if the underlying reqwest client
    /// cannot be built.
    pub fn new(base_url: &str) -> Result<Self, LlmConnectorError> {
        Self::with_config(base_url, None, None)
    }

    /// Creates a client with an optional timeout and an optional proxy.
    ///
    /// * `base_url` - stored with trailing `/` characters removed.
    /// * `timeout_secs` - whole-request timeout; defaults to 60 seconds when `None`.
    /// * `proxy` - when `Some`, all traffic is routed through it (`reqwest::Proxy::all`);
    ///   when `None`, proxying is disabled with `no_proxy`, which per reqwest's docs also
    ///   turns off automatic system-proxy detection.
    ///
    /// # Errors
    ///
    /// Returns `LlmConnectorError::ConfigError` if `proxy` is not a valid proxy URL or
    /// the underlying reqwest client cannot be built.
    pub fn with_config(
        base_url: &str,
        timeout_secs: Option<u64>,
        proxy: Option<&str>,
    ) -> Result<Self, LlmConnectorError> {
        let timeout = Duration::from_secs(timeout_secs.unwrap_or(Self::DEFAULT_TIMEOUT_SECS));
        let builder = Client::builder().timeout(timeout);
        let builder = match proxy {
            Some(proxy_url) => {
                let proxy = reqwest::Proxy::all(proxy_url).map_err(|e| {
                    LlmConnectorError::ConfigError(format!("Invalid proxy URL: {}", e))
                })?;
                builder.proxy(proxy)
            }
            None => builder.no_proxy(),
        };
        let client = builder.build().map_err(|e| {
            LlmConnectorError::ConfigError(format!("Failed to create HTTP client: {}", e))
        })?;
        Ok(Self {
            client,
            base_url: base_url.trim_end_matches('/').to_string(),
            headers: HashMap::new(),
        })
    }

    /// Adds `headers` to the defaults sent with every request.
    ///
    /// An existing entry with exactly the same key is replaced.
    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
        self.headers.extend(headers);
        self
    }

    /// Adds (or replaces, on an exact key match) a single default header.
    pub fn with_header(mut self, key: String, value: String) -> Self {
        self.headers.insert(key, value);
        self
    }

    /// Returns the base URL this client was created with, minus any trailing `/`.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Sends a GET request to `url` with the default headers.
    ///
    /// `url` is used as given; it is not joined with `base_url()`. HTTP error
    /// statuses are returned as `Ok`; only failures to send are errors.
    ///
    /// # Errors
    ///
    /// `TimeoutError` if the request timed out, `ConnectionError` if the connection
    /// failed, `NetworkError` for any other send failure (including invalid header
    /// names/values, which reqwest reports when the request is sent).
    pub async fn get(&self, url: &str) -> Result<reqwest::Response, LlmConnectorError> {
        let request = Self::apply_headers(self.client.get(url), self.default_header_pairs());
        Self::send(request, "GET", "").await
    }

    /// Sends `body` as a JSON POST to `url` with the default headers.
    ///
    /// A configured `Content-Type` default header is kept as the only `Content-Type`;
    /// otherwise reqwest's JSON serialization sets `application/json`.
    ///
    /// # Errors
    ///
    /// Same mapping as `get`: `TimeoutError`, `ConnectionError`, or `NetworkError`.
    pub async fn post<T: Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<reqwest::Response, LlmConnectorError> {
        // Headers go on before `json` so a configured Content-Type is not appended
        // alongside the one `json` inserts.
        let request =
            Self::apply_headers(self.client.post(url), self.default_header_pairs()).json(body);
        Self::send(request, "POST", "").await
    }

    /// Sends `body` as a JSON POST intended for a server-sent-events response.
    ///
    /// Adds `Accept: text/event-stream`, `Cache-Control: no-cache` and
    /// `Connection: keep-alive`. Default headers with the same (case-insensitive)
    /// name override these values. Note that the client-wide timeout also bounds
    /// reading the response body (see reqwest's `ClientBuilder::timeout`), so long
    /// streams can be cut off by it.
    ///
    /// # Errors
    ///
    /// Same mapping as `get`: `TimeoutError`, `ConnectionError`, or `NetworkError`.
    #[cfg(feature = "streaming")]
    pub async fn stream<T: Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<reqwest::Response, LlmConnectorError> {
        let sse_defaults = [
            ("Accept", "text/event-stream"),
            ("Cache-Control", "no-cache"),
            ("Connection", "keep-alive"),
        ];
        // SSE defaults first, so caller-configured headers replace rather than duplicate them.
        let layers = sse_defaults.iter().copied().chain(self.default_header_pairs());
        let request = Self::apply_headers(self.client.post(url), layers).json(body);
        Self::send(
            request,
            "Stream",
            ". Consider increasing timeout for long-running streams.",
        )
        .await
    }

    /// Sends `body` as a JSON POST with the default headers plus `custom_headers`.
    ///
    /// When a name appears in both (compared case-insensitively), the per-call
    /// `custom_headers` value wins and the header is sent once.
    ///
    /// # Errors
    ///
    /// Same mapping as `get`: `TimeoutError`, `ConnectionError`, or `NetworkError`.
    pub async fn post_with_custom_headers<T: Serialize>(
        &self,
        url: &str,
        body: &T,
        custom_headers: &HashMap<String, String>,
    ) -> Result<reqwest::Response, LlmConnectorError> {
        let per_call = custom_headers
            .iter()
            .map(|(name, value)| (name.as_str(), value.as_str()));
        // Per-call headers come last so they override same-named defaults.
        let layers = self.default_header_pairs().chain(per_call);
        let request = Self::apply_headers(self.client.post(url), layers).json(body);
        Self::send(request, "POST", "").await
    }

    /// Borrows the configured default headers as `(name, value)` pairs.
    fn default_header_pairs(&self) -> impl Iterator<Item = (&str, &str)> + '_ {
        self.headers
            .iter()
            .map(|(name, value)| (name.as_str(), value.as_str()))
    }

    /// Attaches header layers to `request`, sending each header name only once.
    ///
    /// `RequestBuilder::header` appends rather than replaces, so layers are merged
    /// first: names compare case-insensitively (HTTP field names are case-insensitive)
    /// and a later layer's value replaces an earlier one. If a single `HashMap` layer
    /// holds two keys differing only in case, which one wins is unspecified.
    fn apply_headers<'a>(
        request: reqwest::RequestBuilder,
        layers: impl IntoIterator<Item = (&'a str, &'a str)>,
    ) -> reqwest::RequestBuilder {
        let mut merged: Vec<(&'a str, &'a str)> = Vec::new();
        for (name, value) in layers {
            match merged
                .iter_mut()
                .find(|slot| slot.0.eq_ignore_ascii_case(name))
            {
                Some(slot) => *slot = (name, value),
                None => merged.push((name, value)),
            }
        }
        merged
            .into_iter()
            .fold(request, |request, (name, value)| request.header(name, value))
    }

    /// Sends `request` and maps reqwest failures onto `LlmConnectorError`.
    ///
    /// `label` prefixes every message; `timeout_hint` is appended to timeout messages only.
    async fn send(
        request: reqwest::RequestBuilder,
        label: &str,
        timeout_hint: &str,
    ) -> Result<reqwest::Response, LlmConnectorError> {
        request.send().await.map_err(|e| {
            if e.is_timeout() {
                LlmConnectorError::TimeoutError(format!(
                    "{} request timeout: {}{}",
                    label, e, timeout_hint
                ))
            } else if e.is_connect() {
                LlmConnectorError::ConnectionError(format!("{} connection failed: {}", label, e))
            } else {
                LlmConnectorError::NetworkError(format!("{} request failed: {}", label, e))
            }
        })
    }
}
impl std::fmt::Debug for HttpClient {
    /// Prints only the base URL and how many default headers are configured.
    ///
    /// Header values and the inner reqwest client are left out (presumably to keep
    /// credentials such as API keys out of debug output).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { base_url, headers, .. } = self;
        let header_count = headers.len();
        f.debug_struct("HttpClient")
            .field("base_url", base_url)
            .field("headers_count", &header_count)
            .finish()
    }
}