use dashmap::DashMap;
use reqwest::{Client, ClientBuilder};
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use tracing::{debug, warn};
use crate::config::validation::is_private_or_internal_ip;
/// DNS resolver for `reqwest` that removes private/internal addresses from
/// lookup results.
///
/// Each resolved address is checked with `is_private_or_internal_ip`; only
/// addresses that pass the check are handed back. When none pass, resolution
/// fails with an error instead of returning an empty list.
struct SsrfSafeDnsResolver;

impl reqwest::dns::Resolve for SsrfSafeDnsResolver {
    /// Resolves `name` with the standard library resolver and filters the result.
    ///
    /// # Errors
    ///
    /// The returned future yields an error when the blocking task cannot be
    /// joined, when the lookup itself fails, or when no address survives the
    /// private/internal filter.
    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
        let hostname = name.as_str().to_owned();
        Box::pin(async move {
            // `to_socket_addrs` is a blocking call, so it runs on Tokio's blocking pool.
            let lookup = tokio::task::spawn_blocking(move || {
                // `ToSocketAddrs` needs a port to resolve a host name; only the
                // IP part of each result is inspected below.
                (hostname.as_str(), 0u16)
                    .to_socket_addrs()
                    .map(|found| found.collect::<Vec<SocketAddr>>())
            });
            let joined: std::io::Result<Vec<SocketAddr>> =
                lookup.await.map_err(std::io::Error::other)?;
            let mut candidates = joined?;

            // Keep only addresses that are not private/internal.
            candidates.retain(|addr| !is_private_or_internal_ip(&addr.ip()));

            if candidates.is_empty() {
                Err(
                    "Host resolves to private/internal IP address (SSRF protection)"
                        .to_string()
                        .into(),
                )
            } else {
                Ok(Box::new(candidates.into_iter()) as reqwest::dns::Addrs)
            }
        })
    }
}
/// Connection-pool and transport settings applied when building HTTP clients.
#[derive(Debug, Clone)]
pub struct HttpClientPoolConfig {
    /// Value passed to `ClientBuilder::pool_max_idle_per_host`.
    pub pool_max_idle_per_host: usize,
    /// Value passed to `ClientBuilder::pool_idle_timeout`.
    pub pool_idle_timeout: Duration,
    /// Value passed to `ClientBuilder::connect_timeout`.
    pub connect_timeout: Duration,
    /// Value passed to `ClientBuilder::tcp_keepalive`.
    pub tcp_keepalive: Duration,
    /// Value passed to `ClientBuilder::user_agent`.
    pub user_agent: &'static str,
}

impl Default for HttpClientPoolConfig {
    /// Returns the default settings: 100 idle connections per host, 90 s idle
    /// timeout, 10 s connect timeout, 60 s TCP keepalive and the
    /// `LiteLLM-RS/0.1.0` user agent.
    fn default() -> Self {
        let secs = Duration::from_secs;
        Self {
            user_agent: "LiteLLM-RS/0.1.0",
            pool_max_idle_per_host: 100,
            pool_idle_timeout: secs(90),
            connect_timeout: secs(10),
            tcp_keepalive: secs(60),
        }
    }
}
/// Lazily initialised client handed out by `get_shared_client`.
static SHARED_HTTP_CLIENT: OnceLock<Client> = OnceLock::new();
/// Lazily initialised cache of clients, keyed by request timeout in whole
/// milliseconds (see `get_client_with_timeout`).
static TIMEOUT_CLIENT_CACHE: OnceLock<DashMap<u64, Arc<Client>>> = OnceLock::new();
/// Creates a `ClientBuilder` pre-populated from `config` and `timeout`.
///
/// `timeout` is applied as the overall request timeout; the pool, connect
/// timeout, keepalive and user-agent values are taken from `config`.
/// `TCP_NODELAY` is always enabled. The builder is returned unbuilt so callers
/// can layer further options on top.
pub fn create_client_builder_with_config(
    timeout: Duration,
    config: &HttpClientPoolConfig,
) -> ClientBuilder {
    // Connection-pool behaviour.
    let pooled = ClientBuilder::new()
        .pool_max_idle_per_host(config.pool_max_idle_per_host)
        .pool_idle_timeout(config.pool_idle_timeout);

    // Request and connect deadlines.
    let timed = pooled
        .timeout(timeout)
        .connect_timeout(config.connect_timeout);

    // Socket options and identification.
    timed
        .tcp_keepalive(config.tcp_keepalive)
        .tcp_nodelay(true)
        .user_agent(config.user_agent)
}
/// Creates a `ClientBuilder` with the given request `timeout` and the default
/// `HttpClientPoolConfig`.
pub fn create_client_builder(timeout: Duration) -> ClientBuilder {
    let pool_defaults = HttpClientPoolConfig::default();
    create_client_builder_with_config(timeout, &pool_defaults)
}
/// Returns the shared HTTP client, creating it on first use.
///
/// The client is built once with a 30-second request timeout and stored in a
/// `OnceLock`; every call returns a reference to that same instance.
pub fn get_shared_client() -> &'static Client {
    /// One-time initialiser for the shared client.
    fn build_shared() -> Client {
        debug!("Initializing shared HTTP client with optimized settings");
        create_optimized_client(Duration::from_secs(30))
    }

    SHARED_HTTP_CLIENT.get_or_init(build_shared)
}
/// Returns a cached client for `timeout`, building one on the first request
/// for that timeout.
///
/// The cache key is the timeout in whole milliseconds (saturated to
/// `u64::MAX`), so timeouts that differ only below one millisecond share a
/// client. Client construction goes through `create_optimized_client`, which
/// does not return an error.
pub fn get_client_with_timeout(timeout: Duration) -> Arc<Client> {
    // Saturate before narrowing so the cast below cannot wrap.
    let key_ceiling = u128::from(u64::MAX);
    let timeout_millis = timeout.as_millis().min(key_ceiling) as u64;

    let clients = TIMEOUT_CLIENT_CACHE.get_or_init(DashMap::new);
    let slot = clients.entry(timeout_millis).or_insert_with(|| {
        debug!(timeout_millis, "Creating cached HTTP client for timeout");
        Arc::new(create_optimized_client(timeout))
    });
    slot.clone()
}
/// Returns a cached client for `timeout`, reporting build failures to the caller.
///
/// Shares the cache used by `get_client_with_timeout`. The cache key is the
/// timeout in whole milliseconds (saturated to `u64::MAX`).
///
/// # Errors
///
/// Returns the `reqwest::Error` from `create_custom_client` when a new client
/// has to be built and the build fails. Nothing is cached in that case.
pub fn get_client_with_timeout_fallible(timeout: Duration) -> Result<Arc<Client>, reqwest::Error> {
    let cache = TIMEOUT_CLIENT_CACHE.get_or_init(DashMap::new);
    let timeout_millis = timeout.as_millis().min(u64::MAX as u128) as u64;

    // Fast path: a client for this timeout already exists.
    if let Some(existing) = cache.get(&timeout_millis) {
        return Ok(existing.clone());
    }

    // Build outside of any map lock so a failed build never touches the cache.
    let client = Arc::new(create_custom_client(timeout)?);

    // Another caller may have inserted a client for this key while this one was
    // being built. `or_insert` keeps the existing entry in that case, so every
    // caller gets the same `Arc` instead of the later insert replacing the
    // earlier one.
    Ok(cache.entry(timeout_millis).or_insert(client).clone())
}
/// Builds a client with the default pool configuration and the given `timeout`.
///
/// A build failure is logged at `warn` level and `Client::new()` is returned
/// instead, so this function never returns an error.
// NOTE(review): `Client::new()` may itself panic if the default client cannot
// be constructed — confirm against the reqwest documentation.
fn create_optimized_client(timeout: Duration) -> Client {
    let pool_defaults = HttpClientPoolConfig::default();
    match create_client_builder_with_config(timeout, &pool_defaults).build() {
        Ok(client) => client,
        Err(e) => {
            warn!(
                "Failed to create optimized HTTP client, falling back to default: {}",
                e
            );
            Client::new()
        }
    }
}
/// Builds a new, uncached client from `timeout` and an explicit pool `config`.
///
/// # Errors
///
/// Returns the `reqwest::Error` produced by `ClientBuilder::build`.
pub fn create_custom_client_with_config(
    timeout: Duration,
    config: &HttpClientPoolConfig,
) -> Result<Client, reqwest::Error> {
    let builder = create_client_builder_with_config(timeout, config);
    builder.build()
}
/// Builds a new, uncached client with the given `timeout` and the default
/// `HttpClientPoolConfig`.
///
/// # Errors
///
/// Returns the `reqwest::Error` from `create_custom_client_with_config`.
pub fn create_custom_client(timeout: Duration) -> Result<Client, reqwest::Error> {
    let pool_defaults = HttpClientPoolConfig::default();
    create_custom_client_with_config(timeout, &pool_defaults)
}
/// Builds a client whose DNS lookups go through `SsrfSafeDnsResolver`.
///
/// Uses the default `HttpClientPoolConfig` and the given `timeout`. A new
/// client is built on every call; the timeout cache is not consulted.
///
/// # Errors
///
/// Returns the `reqwest::Error` produced by `ClientBuilder::build`.
// NOTE(review): the filter only runs when a host name is resolved. Confirm how
// URLs (and redirects) that contain an IP literal are handled, since they may
// never reach the resolver.
pub fn get_ssrf_safe_client_with_timeout_fallible(
    timeout: Duration,
) -> Result<Arc<Client>, reqwest::Error> {
    let pool_defaults = HttpClientPoolConfig::default();
    let resolver = Arc::new(SsrfSafeDnsResolver);
    let client = create_client_builder_with_config(timeout, &pool_defaults)
        .dns_resolver(resolver)
        .build()?;
    Ok(Arc::new(client))
}
/// Builds a new, uncached client that sends `default_headers` with every request.
///
/// Uses the default pool configuration and the given `timeout`.
///
/// # Errors
///
/// Returns the `reqwest::Error` produced by `ClientBuilder::build`.
pub fn create_custom_client_with_headers(
    timeout: Duration,
    default_headers: reqwest::header::HeaderMap,
) -> Result<Client, reqwest::Error> {
    let builder = create_client_builder(timeout).default_headers(default_headers);
    builder.build()
}
/// Returns a snapshot of the timeout-keyed client cache.
///
/// `timeout_configs` lists the cached timeouts in milliseconds in ascending
/// order, and `cached_clients` is the length of that same list, so the two
/// fields always agree even when other threads insert concurrently.
pub fn get_cache_stats() -> HttpClientCacheStats {
    let cache = TIMEOUT_CLIENT_CACHE.get_or_init(DashMap::new);

    // Take a single pass over the map and derive the count from it; calling
    // `len()` separately could observe a different number of entries.
    let mut timeout_configs: Vec<u64> = cache.iter().map(|entry| *entry.key()).collect();
    // Map iteration order is unspecified; sort for deterministic output.
    timeout_configs.sort_unstable();

    HttpClientCacheStats {
        cached_clients: timeout_configs.len(),
        timeout_configs,
    }
}
/// Snapshot of the timeout-keyed client cache, as returned by `get_cache_stats`.
#[derive(Debug, Clone)]
pub struct HttpClientCacheStats {
    /// Number of clients in the cache.
    pub cached_clients: usize,
    /// Cache keys, i.e. the request timeouts in milliseconds.
    pub timeout_configs: Vec<u64>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The shared client is a singleton: repeated calls return the same address.
    #[test]
    fn test_shared_client_creation() {
        let first = get_shared_client();
        let second = get_shared_client();
        assert!(std::ptr::eq(first, second));
    }

    /// A custom client with default pool settings builds successfully.
    #[test]
    fn test_custom_client_creation() {
        let built = create_custom_client(Duration::from_secs(15));
        assert!(built.is_ok());
    }

    /// Equal timeouts share a cached client; different timeouts do not.
    #[test]
    fn test_client_with_timeout_caching() {
        let one_minute = Duration::from_secs(60);
        let client1 = get_client_with_timeout(one_minute);
        let client2 = get_client_with_timeout(one_minute);
        assert!(Arc::ptr_eq(&client1, &client2));

        let client3 = get_client_with_timeout(Duration::from_secs(120));
        assert!(!Arc::ptr_eq(&client1, &client3));
    }

    /// The fallible accessor caches by timeout as well.
    #[test]
    fn test_client_with_timeout_fallible_caching() {
        let timeout = Duration::from_millis(1500);
        let client1 = get_client_with_timeout_fallible(timeout).unwrap();
        let client2 = get_client_with_timeout_fallible(timeout).unwrap();
        assert!(Arc::ptr_eq(&client1, &client2));
    }

    /// Cache statistics report every timeout that has been requested.
    #[test]
    fn test_cache_stats() {
        for secs in [30, 45] {
            let _ = get_client_with_timeout(Duration::from_secs(secs));
        }

        let stats = get_cache_stats();
        assert!(stats.cached_clients >= 2);
        assert!(stats.timeout_configs.contains(&30_000));
        assert!(stats.timeout_configs.contains(&45_000));
    }

    /// Default pool settings match the documented values.
    #[test]
    fn test_pool_config_defaults() {
        let defaults = HttpClientPoolConfig::default();
        assert_eq!(defaults.pool_max_idle_per_host, 100);
        assert_eq!(defaults.pool_idle_timeout, Duration::from_secs(90));
        assert_eq!(defaults.connect_timeout, Duration::from_secs(10));
        assert_eq!(defaults.tcp_keepalive, Duration::from_secs(60));
        assert_eq!(defaults.user_agent, "LiteLLM-RS/0.1.0");
    }
}