use reqwest::Client;
use serde::Serialize;
use std::sync::Arc;
use crate::config::{ProviderConfig, SharedProviderConfig};
use crate::error::LlmConnectorError;
/// HTTP transport used by provider clients.
///
/// Bundles a reference-counted `reqwest` [`Client`] with the provider configuration
/// whose `api_key` and optional `headers` are attached to every outgoing request.
/// Cloning a transport bumps the client's `Arc` refcount rather than building a new client.
#[derive(Debug, Clone)]
pub struct HttpTransport {
    /// Shared HTTP client; every clone of the transport sends through this same instance.
    pub client: Arc<Client>,
    /// Provider settings (API key, extra headers) applied to each request.
    pub config: SharedProviderConfig,
}
impl HttpTransport {
    /// Creates a transport that takes ownership of `client` and `config`,
    /// wrapping them in `Arc` / [`SharedProviderConfig`] respectively.
    pub fn new(client: Client, config: ProviderConfig) -> Self {
        Self {
            client: Arc::new(client),
            config: SharedProviderConfig::new(config),
        }
    }

    /// Creates a transport from an already-shared client and configuration,
    /// e.g. so several transports can send through one `Client`.
    pub fn from_shared(client: Arc<Client>, config: SharedProviderConfig) -> Self {
        Self { client, config }
    }

    /// Builds a `reqwest` [`Client`] from the provider's transport settings.
    ///
    /// * `proxy` - optional proxy URL, applied to all schemes.
    /// * `timeout_ms` - optional timeout in milliseconds. NOTE: reqwest treats this as a
    ///   total deadline that also covers reading the response body, so it bounds
    ///   streamed responses as well.
    /// * `base_url` - provider base URL. When it parses and its host is `localhost` or a
    ///   loopback IP (`127.0.0.0/8`, `::1`), all proxies are disabled, including the one
    ///   given in `proxy`.
    ///
    /// # Errors
    ///
    /// Returns [`LlmConnectorError::ConfigError`] if the proxy URL is invalid or the
    /// client cannot be built.
    pub fn build_client(
        proxy: &Option<String>,
        timeout_ms: Option<u64>,
        base_url: Option<&String>,
    ) -> Result<Client, LlmConnectorError> {
        let mut builder = Client::builder();

        if let Some(proxy_url) = proxy {
            // A malformed proxy is a configuration problem, same as a failed `build()`.
            // The URL itself is not echoed because it may embed credentials.
            let proxy = reqwest::Proxy::all(proxy_url).map_err(|e| {
                LlmConnectorError::ConfigError(format!("invalid proxy URL: {}", e))
            })?;
            builder = builder.proxy(proxy);
        }

        if let Some(timeout) = timeout_ms {
            builder = builder.timeout(std::time::Duration::from_millis(timeout));
        }

        // A proxy cannot reach the caller's own loopback interface, so local endpoints
        // (e.g. a model server on this machine) must bypass every proxy.
        if base_url.map_or(false, |base| Self::is_loopback_url(base)) {
            builder = builder.no_proxy();
        }

        builder
            .build()
            .map_err(|e| LlmConnectorError::ConfigError(e.to_string()))
    }

    /// Returns `true` when `base_url` parses as a URL whose host is loopback
    /// (see [`Self::is_loopback_host`]); unparsable or host-less URLs yield `false`.
    fn is_loopback_url(base_url: &str) -> bool {
        reqwest::Url::parse(base_url)
            .ok()
            .and_then(|url| url.host_str().map(Self::is_loopback_host))
            .unwrap_or(false)
    }

    /// Returns `true` for `localhost` and for IPv4/IPv6 loopback address literals.
    ///
    /// `Url::host_str` reports IPv6 hosts in brackets (e.g. `[::1]`), so the brackets
    /// are stripped before parsing the address.
    fn is_loopback_host(host: &str) -> bool {
        if host.eq_ignore_ascii_case("localhost") {
            return true;
        }
        let bare = host
            .strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(host);
        bare.parse::<std::net::IpAddr>()
            .map_or(false, |ip| ip.is_loopback())
    }

    /// Returns `true` when environment variable `name` is set to exactly `"1"`.
    fn env_flag(name: &str) -> bool {
        std::env::var(name).map(|v| v == "1").unwrap_or(false)
    }

    /// Appends every entry of the config's optional `headers` map to `request`.
    ///
    /// Headers are added after whatever `request` already carries. `RequestBuilder::header`
    /// appends, so a configured name that is already present (e.g. `Authorization`) ends up
    /// with two values. When `log` is `true`, each header is echoed to stderr.
    fn with_config_headers(
        &self,
        mut request: reqwest::RequestBuilder,
        log: bool,
    ) -> reqwest::RequestBuilder {
        if let Some(headers) = &self.config.headers {
            for (key, value) in headers {
                request = request.header(key, value);
                if log {
                    eprintln!("[request-debug] Header: {}: {}", key, value);
                }
            }
        }
        request
    }

    /// Sends an authenticated `GET` request to `url`.
    ///
    /// Sends `Authorization: Bearer <api_key>` followed by the configured extra headers.
    /// The response is returned as-is; its HTTP status is not checked.
    ///
    /// # Errors
    ///
    /// Returns the [`LlmConnectorError`] converted from the `reqwest` error when the
    /// request fails to send.
    pub async fn get(&self, url: &str) -> Result<reqwest::Response, LlmConnectorError> {
        let request = self
            .client
            .get(url)
            .header("Authorization", format!("Bearer {}", self.config.api_key));
        self.with_config_headers(request, false)
            .send()
            .await
            .map_err(LlmConnectorError::from)
    }

    /// Sends an authenticated JSON `POST` of `body` to `url`.
    ///
    /// The response is returned as-is; its HTTP status is not checked.
    ///
    /// Debugging aids, each enabled by setting the variable to `1`:
    /// * `LLM_DEBUG_REQUEST_RAW` - prints the URL, pretty-printed body, configured extra
    ///   headers and send errors to stderr.
    /// * `LLM_DEBUG_RESPONSE_RAW` - prints the response status and headers to stderr.
    ///
    /// # Errors
    ///
    /// Returns the [`LlmConnectorError`] converted from the `reqwest` error when the
    /// request fails to send (including body serialization failures reported by `reqwest`).
    pub async fn post<T: Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<reqwest::Response, LlmConnectorError> {
        // Read the request-debug switch once per call instead of at every use site.
        let debug_request = Self::env_flag("LLM_DEBUG_REQUEST_RAW");
        if debug_request {
            eprintln!("[request-debug] URL: {}", url);
            if let Ok(json_body) = serde_json::to_string_pretty(body) {
                eprintln!("[request-debug] Body: {}", json_body);
            }
        }

        let request = self
            .client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json");
        let response = self
            .with_config_headers(request, debug_request)
            .json(body)
            .send()
            .await
            .map_err(|e| {
                if debug_request {
                    eprintln!("[request-error] Network error: {}", e);
                    eprintln!("[request-error] URL: {}", url);
                    if e.is_timeout() {
                        eprintln!("[request-error] This is a timeout error");
                    }
                    if e.is_connect() {
                        eprintln!("[request-error] This is a connection error");
                    }
                }
                LlmConnectorError::from(e)
            })?;

        if Self::env_flag("LLM_DEBUG_RESPONSE_RAW") {
            eprintln!("[response-debug] Status: {}", response.status());
            eprintln!("[response-debug] Headers: {:?}", response.headers());
        }
        Ok(response)
    }

    /// Sends an authenticated JSON `POST` of `body` to `url` and returns the response
    /// body as a stream of byte chunks.
    ///
    /// # Errors
    ///
    /// * The [`LlmConnectorError`] converted from the `reqwest` error if sending fails.
    /// * [`LlmConnectorError::ProviderError`] for a non-2xx status. The message is
    ///   `"HTTP error: <status>"`, followed by `": <body>"` when the error response has a
    ///   non-empty body.
    #[cfg(feature = "streaming")]
    pub async fn stream<T: Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<
        impl futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>>,
        LlmConnectorError,
    > {
        let request = self
            .client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json");
        let response = self
            .with_config_headers(request, false)
            .json(body)
            .send()
            .await
            .map_err(LlmConnectorError::from)?;

        let status = response.status();
        if !status.is_success() {
            // Error responses typically explain the failure (bad key, quota, unknown
            // model) in the body. Keep that text instead of discarding it. If reading
            // the body fails, fall back to the status-only message.
            let detail = response.text().await.unwrap_or_default();
            let detail = detail.trim();
            let message = if detail.is_empty() {
                format!("HTTP error: {}", status)
            } else {
                format!("HTTP error: {}: {}", status, detail)
            };
            return Err(LlmConnectorError::ProviderError(message));
        }
        Ok(response.bytes_stream())
    }
}