litellm-rs 0.4.16

A high-performance AI Gateway written in Rust, providing OpenAI-compatible APIs with intelligent routing, load balancing, and enterprise features
Documentation
//! Shared HTTP client for optimal connection pooling
//!
//! This module provides a high-performance shared HTTP client with connection reuse.
//! Using a shared client avoids the overhead of creating new connection pools and
//! DNS resolution caches for each request.
//!
//! # Performance Benefits
//!
//! - **Connection Reuse**: Keeps TCP connections alive across requests
//! - **DNS Caching**: Avoids repeated DNS lookups
//! - **HTTP/2 Multiplexing**: Multiple requests over a single connection
//! - **Reduced Latency**: 20-50% improvement in request latency
//!
//! # Usage
//!
//! ```rust,no_run
//! # use litellm_rs::utils::net::http::get_shared_client;
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let client = get_shared_client();
//! let response = client.get("https://api.openai.com").send().await?;
//! # Ok(())
//! # }
//! ```

use dashmap::DashMap;
use reqwest::{Client, ClientBuilder};
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use tracing::{debug, warn};

use crate::config::validation::is_private_or_internal_ip;

/// DNS resolver that rejects private/internal IP addresses at resolution time.
///
/// This mitigates DNS-rebinding attacks: even if a hostname resolves to a public IP
/// at config-validation time, every actual request re-validates the resolved address,
/// so a later rebind to an internal IP will be caught and rejected.
struct SsrfSafeDnsResolver;

impl reqwest::dns::Resolve for SsrfSafeDnsResolver {
    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
        let host = name.as_str().to_owned();
        Box::pin(async move {
            let addrs: std::io::Result<Vec<SocketAddr>> = tokio::task::spawn_blocking(move || {
                (host.as_str(), 0u16)
                    .to_socket_addrs()
                    .map(|iter| iter.collect())
            })
            .await
            .map_err(std::io::Error::other)?;

            let addrs = addrs?;
            let safe: Vec<SocketAddr> = addrs
                .into_iter()
                .filter(|addr| !is_private_or_internal_ip(&addr.ip()))
                .collect();

            if safe.is_empty() {
                return Err(
                    "Host resolves to private/internal IP address (SSRF protection)"
                        .to_string()
                        .into(),
                );
            }

            Ok(Box::new(safe.into_iter()) as reqwest::dns::Addrs)
        })
    }
}

/// Configuration for the HTTP client pool
#[derive(Debug, Clone)]
pub struct HttpClientPoolConfig {
    /// Maximum idle connections per host
    pub pool_max_idle_per_host: usize,
    /// Idle connection timeout
    pub pool_idle_timeout: Duration,
    /// Connection timeout
    pub connect_timeout: Duration,
    /// TCP keepalive interval
    pub tcp_keepalive: Duration,
    /// User agent string
    pub user_agent: &'static str,
}

impl Default for HttpClientPoolConfig {
    fn default() -> Self {
        Self {
            pool_max_idle_per_host: 100, // Increased for high throughput
            pool_idle_timeout: Duration::from_secs(90),
            connect_timeout: Duration::from_secs(10),
            tcp_keepalive: Duration::from_secs(60),
            user_agent: "LiteLLM-RS/0.1.0",
        }
    }
}

/// Shared HTTP client instance with optimized settings
static SHARED_HTTP_CLIENT: OnceLock<Client> = OnceLock::new();

/// Timeout-specific client cache (keyed by milliseconds)
static TIMEOUT_CLIENT_CACHE: OnceLock<DashMap<u64, Arc<Client>>> = OnceLock::new();

/// Create a reqwest client builder with unified pool/timeout defaults.
pub fn create_client_builder_with_config(
    timeout: Duration,
    config: &HttpClientPoolConfig,
) -> ClientBuilder {
    ClientBuilder::new()
        // Connection pool settings
        .pool_max_idle_per_host(config.pool_max_idle_per_host)
        .pool_idle_timeout(config.pool_idle_timeout)
        // Request timeouts
        .timeout(timeout)
        .connect_timeout(config.connect_timeout)
        // TCP optimizations
        .tcp_keepalive(config.tcp_keepalive)
        .tcp_nodelay(true)
        // User agent
        .user_agent(config.user_agent)
}

/// Create a reqwest client builder with default pool configuration.
pub fn create_client_builder(timeout: Duration) -> ClientBuilder {
    create_client_builder_with_config(timeout, &HttpClientPoolConfig::default())
}

/// Get the shared HTTP client instance
///
/// This client uses a default timeout of 30 seconds. For custom timeouts,
/// use `get_client_with_timeout`.
pub fn get_shared_client() -> &'static Client {
    SHARED_HTTP_CLIENT.get_or_init(|| {
        debug!("Initializing shared HTTP client with optimized settings");
        create_optimized_client(Duration::from_secs(30))
    })
}

/// Get or create a client with a specific timeout
///
/// Clients are cached by timeout duration (in milliseconds) to avoid creating
/// multiple clients with the same configuration.
pub fn get_client_with_timeout(timeout: Duration) -> Arc<Client> {
    let cache = TIMEOUT_CLIENT_CACHE.get_or_init(DashMap::new);
    let timeout_millis = timeout.as_millis().min(u64::MAX as u128) as u64;

    cache
        .entry(timeout_millis)
        .or_insert_with(|| {
            debug!(timeout_millis, "Creating cached HTTP client for timeout");
            Arc::new(create_optimized_client(timeout))
        })
        .clone()
}

/// Get or create a client with a specific timeout, returning errors on failure
///
/// This is useful when caller error semantics must be preserved.
pub fn get_client_with_timeout_fallible(timeout: Duration) -> Result<Arc<Client>, reqwest::Error> {
    let cache = TIMEOUT_CLIENT_CACHE.get_or_init(DashMap::new);
    let timeout_millis = timeout.as_millis().min(u64::MAX as u128) as u64;

    if let Some(existing) = cache.get(&timeout_millis) {
        return Ok(existing.clone());
    }

    let client = Arc::new(create_custom_client(timeout)?);
    cache.insert(timeout_millis, client.clone());
    Ok(client)
}

/// Create an optimized HTTP client with the given timeout
fn create_optimized_client(timeout: Duration) -> Client {
    let config = HttpClientPoolConfig::default();

    create_client_builder_with_config(timeout, &config)
        .build()
        .unwrap_or_else(|e| {
            warn!(
                "Failed to create optimized HTTP client, falling back to default: {}",
                e
            );
            Client::new()
        })
}

/// Create a custom HTTP client with specific timeout and pool configuration.
pub fn create_custom_client_with_config(
    timeout: Duration,
    config: &HttpClientPoolConfig,
) -> Result<Client, reqwest::Error> {
    create_client_builder_with_config(timeout, config).build()
}

/// Create a custom HTTP client with specific timeout
///
/// Use this when you need a one-off client that won't be reused.
/// For reusable clients, prefer `get_client_with_timeout`.
pub fn create_custom_client(timeout: Duration) -> Result<Client, reqwest::Error> {
    create_custom_client_with_config(timeout, &HttpClientPoolConfig::default())
}

/// Get or create an HTTP client with SSRF-safe DNS resolution for the given timeout.
///
/// Unlike `get_client_with_timeout_fallible`, this client installs `SsrfSafeDnsResolver`
/// so every request re-validates the resolved IP against private/internal ranges.
/// Use this for providers whose endpoint URL is user-controlled to prevent DNS-rebinding attacks.
pub fn get_ssrf_safe_client_with_timeout_fallible(
    timeout: Duration,
) -> Result<Arc<Client>, reqwest::Error> {
    create_client_builder_with_config(timeout, &HttpClientPoolConfig::default())
        .dns_resolver(Arc::new(SsrfSafeDnsResolver))
        .build()
        .map(Arc::new)
}

/// Create a custom HTTP client with specific timeout and default headers
pub fn create_custom_client_with_headers(
    timeout: Duration,
    default_headers: reqwest::header::HeaderMap,
) -> Result<Client, reqwest::Error> {
    create_client_builder(timeout)
        .default_headers(default_headers)
        .build()
}

/// Get statistics about the client cache
pub fn get_cache_stats() -> HttpClientCacheStats {
    let cache = TIMEOUT_CLIENT_CACHE.get_or_init(DashMap::new);
    HttpClientCacheStats {
        cached_clients: cache.len(),
        timeout_configs: cache.iter().map(|e| *e.key()).collect(),
    }
}

/// Statistics about the HTTP client cache
#[derive(Debug, Clone)]
pub struct HttpClientCacheStats {
    /// Number of cached clients
    pub cached_clients: usize,
    /// List of cached timeout configurations (in milliseconds)
    pub timeout_configs: Vec<u64>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shared_client_creation() {
        let client = get_shared_client();
        // Just verify we can get the client without panicking
        assert!(std::ptr::addr_of!(*client) == std::ptr::addr_of!(*get_shared_client()));
    }

    #[test]
    fn test_custom_client_creation() {
        let client = create_custom_client(Duration::from_secs(15));
        assert!(client.is_ok());
    }

    #[test]
    fn test_client_with_timeout_caching() {
        let client1 = get_client_with_timeout(Duration::from_secs(60));
        let client2 = get_client_with_timeout(Duration::from_secs(60));

        // Same timeout should return the same cached client
        assert!(Arc::ptr_eq(&client1, &client2));

        // Different timeout should return different client
        let client3 = get_client_with_timeout(Duration::from_secs(120));
        assert!(!Arc::ptr_eq(&client1, &client3));
    }

    #[test]
    fn test_client_with_timeout_fallible_caching() {
        let client1 = get_client_with_timeout_fallible(Duration::from_millis(1500)).unwrap();
        let client2 = get_client_with_timeout_fallible(Duration::from_millis(1500)).unwrap();

        assert!(Arc::ptr_eq(&client1, &client2));
    }

    #[test]
    fn test_cache_stats() {
        // Ensure some clients are cached
        let _ = get_client_with_timeout(Duration::from_secs(30));
        let _ = get_client_with_timeout(Duration::from_secs(45));

        let stats = get_cache_stats();
        assert!(stats.cached_clients >= 2);
        assert!(stats.timeout_configs.contains(&30_000));
        assert!(stats.timeout_configs.contains(&45_000));
    }

    #[test]
    fn test_pool_config_defaults() {
        let config = HttpClientPoolConfig::default();
        assert_eq!(config.pool_max_idle_per_host, 100);
        assert_eq!(config.pool_idle_timeout, Duration::from_secs(90));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.tcp_keepalive, Duration::from_secs(60));
        assert_eq!(config.user_agent, "LiteLLM-RS/0.1.0");
    }
}