use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::StreamExt;
use std::env;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
/// Fully buffered response returned by [`HttpClient::post`].
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so results can be logged, compared
/// in tests, and used with `Result::unwrap_err`/`expect_err`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpResponse {
    /// HTTP status code of the response.
    pub status: u16,
    /// Complete response body decoded as text.
    pub body: String,
}
/// Response returned by [`HttpClient::post_streaming`], whose body is exposed
/// as a stream of byte chunks instead of being buffered.
pub struct StreamingHttpResponse {
    /// HTTP status code of the response.
    pub status: u16,
    /// Raw `Retry-After` header value, when the server sent one that is valid
    /// visible ASCII.
    pub retry_after: Option<String>,
    /// Body chunks of the response. [`ReqwestHttpClient`] supplies an empty
    /// stream for non-2xx responses.
    pub byte_stream: Pin<Box<dyn futures::Stream<Item = Result<bytes::Bytes>> + Send>>,
    /// Body text of a non-2xx response (empty for successful responses in
    /// [`ReqwestHttpClient`]).
    pub error_body: String,
}

// The boxed stream is opaque and not `Debug`, so it is omitted from the output;
// this lets callers use `{:?}` and `Result::unwrap_err` on streaming results.
impl std::fmt::Debug for StreamingHttpResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StreamingHttpResponse")
            .field("status", &self.status)
            .field("retry_after", &self.retry_after)
            .field("error_body", &self.error_body)
            .finish_non_exhaustive()
    }
}
/// Timing and size information about one HTTP exchange, handed to the callback
/// installed with [`set_http_metrics_callback`].
#[derive(Debug, Clone)]
pub struct HttpMetricsRecord {
    /// Target URL of the request.
    pub url: String,
    /// HTTP method, e.g. `"POST"`.
    pub method: String,
    /// Response status code.
    pub status: u16,
    /// Elapsed wall-clock time in milliseconds, measured from the start of the
    /// client call until the record was produced.
    pub duration_ms: f64,
    /// Length in bytes of the compact JSON encoding of the request body.
    pub request_bytes: u64,
    /// Length in bytes of the response body text; `0` for streaming requests,
    /// whose body has not been read when the record is produced.
    pub response_bytes: u64,
    /// Whether the record came from a streaming request.
    pub streaming: bool,
}

/// Callback invoked with an [`HttpMetricsRecord`] after each HTTP exchange.
pub type HttpMetricsCallback = Arc<dyn Fn(HttpMetricsRecord) + Send + Sync>;

/// Process-wide metrics hook; `None` means metrics are not reported.
static HTTP_METRICS_CALLBACK: std::sync::RwLock<Option<HttpMetricsCallback>> =
    std::sync::RwLock::new(None);

/// Installs `callback` as the process-wide HTTP metrics hook, replacing any
/// previously installed one.
pub fn set_http_metrics_callback(callback: HttpMetricsCallback) {
    // The slot only ever holds a complete `Option<Arc<..>>`, so a poisoned lock
    // still guards a valid value and can be recovered.
    let previous = HTTP_METRICS_CALLBACK
        .write()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .replace(callback);
    // Dropped after the write guard is released, so the old callback's
    // destructor never runs while the lock is held.
    drop(previous);
}

/// Removes the process-wide HTTP metrics hook, if one is installed.
pub fn clear_http_metrics_callback() {
    let previous = HTTP_METRICS_CALLBACK
        .write()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .take();
    drop(previous);
}

/// Forwards `record` to the installed metrics callback, if any.
fn maybe_record_metrics(record: HttpMetricsRecord) {
    // Clone the `Arc` and release the read lock *before* running user code:
    // a callback that calls `set_`/`clear_http_metrics_callback` would otherwise
    // wait on the write lock while this thread still holds the read lock, and a
    // slow callback would block every writer.
    let callback = HTTP_METRICS_CALLBACK
        .read()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .clone();
    if let Some(callback) = callback {
        callback(record);
    }
}
/// Asynchronous HTTP client abstraction for issuing JSON POST requests.
///
/// The trait is used as a trait object (`Arc<dyn HttpClient>`, see
/// [`default_http_client`]), which lets callers substitute their own client.
#[async_trait]
pub trait HttpClient: Send + Sync {
    /// Sends `body` as a JSON POST to `url` with the extra `headers` and returns
    /// the status code together with the full response body text.
    ///
    /// `cancel_token` lets the caller abandon the request; [`ReqwestHttpClient`]
    /// returns an error when it fires. In [`ReqwestHttpClient`] a non-2xx status
    /// is reported through [`HttpResponse::status`] rather than as an `Err`.
    async fn post(
        &self,
        url: &str,
        headers: Vec<(&str, &str)>,
        body: &serde_json::Value,
        cancel_token: CancellationToken,
    ) -> Result<HttpResponse>;
    /// Sends `body` as a JSON POST to `url` and returns the response with its
    /// body exposed as a byte stream rather than buffered.
    ///
    /// In [`ReqwestHttpClient`], a non-2xx response yields an empty stream and
    /// the body text is placed in [`StreamingHttpResponse::error_body`].
    async fn post_streaming(
        &self,
        url: &str,
        headers: Vec<(&str, &str)>,
        body: &serde_json::Value,
        cancel_token: CancellationToken,
    ) -> Result<StreamingHttpResponse>;
}
/// [`HttpClient`] implementation backed by a shared `reqwest::Client`.
///
/// Cloning is cheap: `reqwest::Client` is reference-counted internally, so
/// clones share one connection pool.
#[derive(Debug, Clone)]
pub struct ReqwestHttpClient {
    /// Underlying client, configured by `build_reqwest_client`.
    client: reqwest::Client,
}
impl ReqwestHttpClient {
pub fn new() -> Self {
Self {
client: build_reqwest_client(None, None).expect("failed to build default HTTP client"),
}
}
}
impl Default for ReqwestHttpClient {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl HttpClient for ReqwestHttpClient {
async fn post(
&self,
url: &str,
headers: Vec<(&str, &str)>,
body: &serde_json::Value,
cancel_token: CancellationToken,
) -> Result<HttpResponse> {
let start = std::time::Instant::now();
let request_body = serde_json::to_string(body).unwrap_or_default();
let request_bytes = request_body.len() as u64;
tracing::debug!(
"HTTP POST to {}: {}",
url,
serde_json::to_string_pretty(body)?
);
let mut request = self.client.post(url);
for (key, value) in headers {
request = request.header(key, value);
}
request = request.json(body);
let response = tokio::select! {
_ = cancel_token.cancelled() => {
anyhow::bail!("HTTP request cancelled");
}
result = request.send() => {
result.context(format!("Failed to send request to {}", url))?
}
};
let status = response.status().as_u16();
let response_body = response.text().await?;
let response_bytes = response_body.len() as u64;
let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
maybe_record_metrics(HttpMetricsRecord {
url: url.to_string(),
method: "POST".to_string(),
status,
duration_ms,
request_bytes,
response_bytes,
streaming: false,
});
Ok(HttpResponse {
status,
body: response_body,
})
}
async fn post_streaming(
&self,
url: &str,
headers: Vec<(&str, &str)>,
body: &serde_json::Value,
cancel_token: CancellationToken,
) -> Result<StreamingHttpResponse> {
let start = std::time::Instant::now();
let request_body = serde_json::to_string(body).unwrap_or_default();
let request_bytes = request_body.len() as u64;
let mut request = self.client.post(url);
for (key, value) in headers {
request = request.header(key, value);
}
request = request.json(body);
let response = tokio::select! {
_ = cancel_token.cancelled() => {
anyhow::bail!("HTTP streaming request cancelled");
}
result = request.send() => {
result.context(format!("Failed to send streaming request to {}", url))?
}
};
let status = response.status().as_u16();
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.map(String::from);
let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
maybe_record_metrics(HttpMetricsRecord {
url: url.to_string(),
method: "POST".to_string(),
status,
duration_ms,
request_bytes,
response_bytes: 0, streaming: true,
});
if (200..300).contains(&status) {
let byte_stream = response
.bytes_stream()
.map(|r| r.map_err(|e| anyhow::anyhow!("Stream error: {}", e)));
Ok(StreamingHttpResponse {
status,
retry_after,
byte_stream: Box::pin(byte_stream),
error_body: String::new(),
})
} else {
let error_body = response.text().await.unwrap_or_default();
let empty: futures::stream::Empty<Result<bytes::Bytes>> = futures::stream::empty();
Ok(StreamingHttpResponse {
status,
retry_after,
byte_stream: Box::pin(empty),
error_body,
})
}
}
}
/// Returns the stock [`HttpClient`] implementation as a shareable trait object.
///
/// # Panics
///
/// Panics if the underlying reqwest client cannot be built (see
/// [`ReqwestHttpClient::new`]).
pub fn default_http_client() -> Arc<dyn HttpClient> {
    let concrete = ReqwestHttpClient::new();
    Arc::new(concrete)
}
/// Proxy URLs resolved from the environment by `explicit_proxy_config_from_env`
/// and applied by `build_reqwest_client`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct ExplicitProxyConfig {
    /// Proxy used for `http://` URLs (from `http_proxy`, then `HTTP_PROXY`).
    http: Option<String>,
    /// Proxy used for `https://` URLs (from `https_proxy`, then `HTTPS_PROXY`,
    /// falling back to the HTTP proxy when neither is set).
    https: Option<String>,
}
/// Builds a `reqwest` client with an optional timeout and default headers,
/// routed through the proxies named by the standard proxy environment variables.
///
/// reqwest's automatic proxy detection is disabled. Only
/// `http_proxy`/`HTTP_PROXY` and `https_proxy`/`HTTPS_PROXY` (falling back to
/// the HTTP proxy) are honored, and hosts listed in `NO_PROXY`/`no_proxy`
/// bypass those proxies.
///
/// # Errors
///
/// Returns an error if a proxy URL is invalid or the client cannot be built.
pub(crate) fn build_reqwest_client(
    timeout: Option<Duration>,
    default_headers: Option<reqwest::header::HeaderMap>,
) -> Result<reqwest::Client> {
    // Disable reqwest's own (system/env) proxy discovery; proxies are added
    // explicitly below from `explicit_proxy_config_from_env`.
    let mut builder = reqwest::Client::builder().no_proxy();
    if let Some(timeout) = timeout {
        builder = builder.timeout(timeout);
    }
    if let Some(default_headers) = default_headers {
        builder = builder.default_headers(default_headers);
    }
    let proxy_config = explicit_proxy_config_from_env();
    // `.no_proxy()` above also drops reqwest's NO_PROXY handling, so the
    // exclusion list is re-applied to each explicit proxy. Without it, hosts
    // such as a local model server listed in NO_PROXY would still be proxied.
    if let Some(http_proxy) = proxy_config.http.as_deref() {
        let proxy = reqwest::Proxy::http(http_proxy)
            .with_context(|| format!("Invalid HTTP proxy URL: {http_proxy}"))?
            .no_proxy(reqwest::NoProxy::from_env());
        builder = builder.proxy(proxy);
    }
    if let Some(https_proxy) = proxy_config.https.as_deref() {
        let proxy = reqwest::Proxy::https(https_proxy)
            .with_context(|| format!("Invalid HTTPS proxy URL: {https_proxy}"))?
            .no_proxy(reqwest::NoProxy::from_env());
        builder = builder.proxy(proxy);
    }
    builder.build().context("Failed to build reqwest client")
}
/// Reads the HTTP and HTTPS proxy URLs from the environment.
///
/// Lowercase variable names take precedence over uppercase ones. When no HTTPS
/// proxy is configured, HTTPS traffic reuses the HTTP proxy.
fn explicit_proxy_config_from_env() -> ExplicitProxyConfig {
    let http = first_non_empty_env(&["http_proxy", "HTTP_PROXY"]);
    let https = match first_non_empty_env(&["https_proxy", "HTTPS_PROXY"]) {
        Some(dedicated) => Some(dedicated),
        None => http.clone(),
    };
    ExplicitProxyConfig { http, https }
}
/// Returns the trimmed value of the first variable in `keys` that is set,
/// valid Unicode, and not blank, or `None` if no such variable exists.
fn first_non_empty_env(keys: &[&str]) -> Option<String> {
    for &key in keys {
        if let Ok(raw) = env::var(key) {
            let trimmed = raw.trim();
            if !trimmed.is_empty() {
                return Some(trimmed.to_owned());
            }
        }
    }
    None
}
/// Strips trailing slashes and any trailing `/v1` segments from `base_url`,
/// so that API paths can be appended uniformly.
pub(crate) fn normalize_base_url(base_url: &str) -> String {
    let without_slashes = base_url.trim_end_matches('/');
    let without_version = without_slashes.trim_end_matches("/v1");
    without_version.trim_end_matches('/').to_owned()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

    /// Every proxy variable read by `explicit_proxy_config_from_env`.
    const PROXY_ENV_KEYS: [&str; 4] = ["http_proxy", "HTTP_PROXY", "https_proxy", "HTTPS_PROXY"];

    /// Process-wide lock serializing the tests in this module that read or
    /// mutate the proxy environment variables.
    fn proxy_env_lock() -> &'static Mutex<()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Removes every proxy variable from the process environment.
    fn clear_proxy_env() {
        for key in PROXY_ENV_KEYS {
            // SAFETY: tests in this module mutate proxy env vars only while
            // holding `proxy_env_lock`. NOTE(review): concurrently running tests
            // elsewhere that read the environment are not covered by this lock.
            unsafe { env::remove_var(key) };
        }
    }

    /// Sets one proxy variable; callers must hold a [`ProxyEnvGuard`].
    fn set_proxy_env(key: &str, value: &str) {
        // SAFETY: see `clear_proxy_env`; the caller holds `proxy_env_lock`.
        unsafe { env::set_var(key, value) };
    }

    /// Holds `proxy_env_lock` and guarantees the proxy variables are cleared
    /// both on acquisition and on drop, even when the test panics.
    struct ProxyEnvGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl ProxyEnvGuard {
        fn acquire() -> Self {
            // A test that panicked while holding the lock poisons it. The
            // protected data is `()`, so recovering the guard is always sound
            // and keeps one failure from cascading into every other test.
            let lock = proxy_env_lock()
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            clear_proxy_env();
            Self { _lock: lock }
        }
    }

    impl Drop for ProxyEnvGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so cleanup is still serialized.
            clear_proxy_env();
        }
    }

    #[test]
    fn test_normalize_base_url() {
        assert_eq!(
            normalize_base_url("https://api.example.com"),
            "https://api.example.com"
        );
        assert_eq!(
            normalize_base_url("https://api.example.com/"),
            "https://api.example.com"
        );
        assert_eq!(
            normalize_base_url("https://api.example.com/v1"),
            "https://api.example.com"
        );
        assert_eq!(
            normalize_base_url("https://api.example.com/v1/"),
            "https://api.example.com"
        );
    }

    #[test]
    fn test_normalize_base_url_edge_cases() {
        assert_eq!(
            normalize_base_url("http://localhost:8080/v1"),
            "http://localhost:8080"
        );
        assert_eq!(
            normalize_base_url("http://localhost:8080"),
            "http://localhost:8080"
        );
        assert_eq!(
            normalize_base_url("https://api.example.com/v1/"),
            "https://api.example.com"
        );
    }

    #[test]
    fn test_normalize_base_url_multiple_trailing_slashes() {
        assert_eq!(
            normalize_base_url("https://api.example.com//"),
            "https://api.example.com"
        );
    }

    #[test]
    fn test_normalize_base_url_with_port() {
        assert_eq!(
            normalize_base_url("http://localhost:11434/v1/"),
            "http://localhost:11434"
        );
    }

    #[test]
    fn test_normalize_base_url_already_normalized() {
        assert_eq!(
            normalize_base_url("https://api.openai.com"),
            "https://api.openai.com"
        );
    }

    #[test]
    fn test_normalize_base_url_empty_string() {
        assert_eq!(normalize_base_url(""), "");
    }

    #[test]
    fn test_default_http_client_creation() {
        // Building a client reads the proxy env vars; hold the guard so this
        // test is hermetic and cannot race the tests that mutate them.
        let _env = ProxyEnvGuard::acquire();
        let _client = default_http_client();
    }

    #[test]
    fn test_explicit_proxy_config_from_env_prefers_lowercase_vars() {
        let _env = ProxyEnvGuard::acquire();
        set_proxy_env("http_proxy", "http://lower-http:3128");
        set_proxy_env("HTTP_PROXY", "http://upper-http:3128");
        set_proxy_env("https_proxy", "http://lower-https:3128");
        set_proxy_env("HTTPS_PROXY", "http://upper-https:3128");
        let proxy_config = explicit_proxy_config_from_env();
        assert_eq!(
            proxy_config,
            ExplicitProxyConfig {
                http: Some("http://lower-http:3128".to_string()),
                https: Some("http://lower-https:3128".to_string()),
            }
        );
    }

    #[test]
    fn test_explicit_proxy_config_from_env_falls_back_to_http_for_https() {
        let _env = ProxyEnvGuard::acquire();
        set_proxy_env("HTTP_PROXY", "http://proxy.example:3128");
        let proxy_config = explicit_proxy_config_from_env();
        assert_eq!(
            proxy_config,
            ExplicitProxyConfig {
                http: Some("http://proxy.example:3128".to_string()),
                https: Some("http://proxy.example:3128".to_string()),
            }
        );
    }

    #[test]
    fn test_build_reqwest_client_accepts_proxy_env_urls() {
        let _env = ProxyEnvGuard::acquire();
        set_proxy_env("http_proxy", "http://127.0.0.1:3128");
        set_proxy_env("https_proxy", "http://127.0.0.1:3128");
        let client = build_reqwest_client(None, None);
        assert!(client.is_ok());
    }
}