use super::types::{LoadBalancer, LoadBalancingStrategy, ProviderStats};
use crate::sdk::{config::ClientConfig, config::SdkProviderConfig, errors::*};
use crate::utils::net::ClientUtils;
use crate::utils::net::http::{create_custom_client, create_streaming_client};
use reqwest;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::info;
/// Multi-provider LLM client.
///
/// Holds the validated client configuration, the HTTP clients used to talk to
/// providers, per-provider runtime statistics and the load balancer used to pick
/// between providers.
#[derive(Debug)]
pub struct LLMClient {
/// Configuration the client was built from (providers, default provider, settings).
pub(crate) config: ClientConfig,
/// HTTP client for regular (non-streaming) requests; `LLMClient::new` builds it
/// with a timeout of `config.settings.timeout` seconds.
pub(crate) http_client: reqwest::Client,
/// Separate HTTP client for streaming requests, built by `create_streaming_client`.
pub(crate) stream_http_client: reqwest::Client,
/// Runtime statistics keyed by provider id. Empty after `LLMClient::new`;
/// seeded by `initialize_providers` (called from `LLMClient::new_async`).
pub(crate) provider_stats: Arc<RwLock<HashMap<String, ProviderStats>>>,
/// Provider selection strategy; `LLMClient::new` creates it with
/// `LoadBalancingStrategy::WeightedRandom`.
pub(crate) load_balancer: Arc<LoadBalancer>,
}
impl LLMClient {
    /// Builds a client from `config` without touching the network.
    ///
    /// Two HTTP clients are created: a regular one whose timeout is
    /// `config.settings.timeout` seconds, and a separate client for streaming
    /// responses. Provider statistics start out empty; use
    /// [`LLMClient::new_async`] to also seed them.
    ///
    /// # Errors
    ///
    /// Returns [`SDKError::ConfigError`] when `config.providers` is empty or
    /// when either HTTP client cannot be constructed.
    pub fn new(config: ClientConfig) -> Result<Self> {
        if config.providers.is_empty() {
            return Err(SDKError::ConfigError("No providers configured".to_string()));
        }

        let http_client = create_custom_client(Duration::from_secs(config.settings.timeout))
            .map_err(|e| SDKError::ConfigError(format!("Failed to create HTTP client: {}", e)))?;
        let stream_http_client = create_streaming_client().map_err(|e| {
            SDKError::ConfigError(format!("Failed to create streaming HTTP client: {}", e))
        })?;

        let provider_stats = Arc::new(RwLock::new(HashMap::new()));
        let load_balancer = Arc::new(LoadBalancer::new(LoadBalancingStrategy::WeightedRandom));

        info!("LLMClient created with {} providers", config.providers.len());

        Ok(Self {
            config,
            http_client,
            stream_http_client,
            provider_stats,
            load_balancer,
        })
    }

    /// Builds a client like [`LLMClient::new`] and then seeds the per-provider
    /// statistics via `initialize_providers`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`LLMClient::new`] or `initialize_providers`.
    pub async fn new_async(config: ClientConfig) -> Result<Self> {
        let client = Self::new(config)?;
        client.initialize_providers().await?;
        Ok(client)
    }

    /// Inserts a fresh `ProviderStats` entry (health score `1.0`, every other
    /// field defaulted) for each configured provider, enabled or not.
    ///
    /// Existing entries for the same provider id are overwritten. The stats
    /// write lock is held for the whole loop; nothing inside it awaits.
    pub(crate) async fn initialize_providers(&self) -> Result<()> {
        use tracing::debug;

        let mut stats = self.provider_stats.write().await;
        for provider in &self.config.providers {
            let provider_stats = ProviderStats {
                health_score: 1.0,
                ..Default::default()
            };
            stats.insert(provider.id.clone(), provider_stats);
            debug!("Initialized provider: {}", provider.id);
        }
        Ok(())
    }

    /// Returns the ids of all configured providers (including disabled ones),
    /// in configuration order.
    pub fn list_providers(&self) -> Vec<String> {
        self.config.providers.iter().map(|p| p.id.clone()).collect()
    }

    /// Borrows the configuration this client was built from.
    pub fn config(&self) -> &ClientConfig {
        &self.config
    }

    /// Looks up a provider by id, regardless of whether it is enabled.
    ///
    /// # Errors
    ///
    /// Returns [`SDKError::ProviderNotFound`] when no provider has `provider_id`.
    pub(crate) fn provider_config(&self, provider_id: &str) -> Result<&SdkProviderConfig> {
        self.config
            .providers
            .iter()
            .find(|provider| provider.id == provider_id)
            .ok_or_else(|| SDKError::ProviderNotFound(provider_id.to_string()))
    }

    /// Returns the configured default provider, but only if it is set, exists
    /// in the provider list and is enabled; otherwise `None`.
    pub(crate) fn default_enabled_provider(&self) -> Option<&SdkProviderConfig> {
        self.config
            .default_provider
            .as_ref()
            .and_then(|provider_id| self.provider_config(provider_id).ok())
            .filter(|provider| provider.enabled)
    }

    /// Returns the first enabled provider in configuration order.
    ///
    /// # Errors
    ///
    /// Returns [`SDKError::NoDefaultProvider`] when every provider is disabled.
    pub(crate) fn first_enabled_provider(&self) -> Result<&SdkProviderConfig> {
        self.config
            .providers
            .iter()
            .find(|provider| provider.enabled)
            .ok_or_else(|| SDKError::NoDefaultProvider)
    }

    /// Returns the first enabled provider whose `models` list contains `model`
    /// (exact string match), in configuration order.
    ///
    /// # Errors
    ///
    /// Returns [`SDKError::ModelNotFound`] when no enabled provider lists `model`.
    pub(crate) fn provider_for_model(&self, model: &str) -> Result<&SdkProviderConfig> {
        self.config
            .providers
            .iter()
            .find(|provider| {
                provider.enabled && provider.models.iter().any(|candidate| candidate == model)
            })
            .ok_or_else(|| {
                SDKError::ModelNotFound(format!("Model '{}' not supported by any provider", model))
            })
    }

    /// Returns the provider's first listed model, or `fallback` when it lists none.
    pub(crate) fn provider_default_model<'a>(
        &self,
        provider: &'a SdkProviderConfig,
        fallback: &'a str,
    ) -> &'a str {
        provider
            .models
            .first()
            .map(String::as_str)
            .unwrap_or(fallback)
    }

    /// Returns the provider's configured `base_url`, or `fallback` when unset.
    pub(crate) fn provider_base_url<'a>(
        &self,
        provider: &'a SdkProviderConfig,
        fallback: &'a str,
    ) -> &'a str {
        provider.base_url.as_deref().unwrap_or(fallback)
    }

    /// Joins `endpoint` onto the provider's base URL (or `fallback_base` when
    /// the provider has none) using `ClientUtils::add_path_to_api_base`.
    pub(crate) fn provider_endpoint(
        &self,
        provider: &SdkProviderConfig,
        fallback_base: &str,
        endpoint: &str,
    ) -> String {
        ClientUtils::add_path_to_api_base(self.provider_base_url(provider, fallback_base), endpoint)
    }

    /// Builds the Anthropic Messages endpoint for `provider`.
    ///
    /// The base URL defaults to `https://api.anthropic.com`. If its path
    /// already contains a `v1` segment, only `messages` is appended; otherwise
    /// `v1/messages` is appended.
    pub(crate) fn anthropic_messages_endpoint(&self, provider: &SdkProviderConfig) -> String {
        let base_url = self.provider_base_url(provider, "https://api.anthropic.com");
        let endpoint = if Self::base_url_has_v1_segment(base_url) {
            "messages"
        } else {
            "v1/messages"
        };
        ClientUtils::add_path_to_api_base(base_url, endpoint)
    }

    /// Returns `true` when the *path* of `base_url` contains a segment that is
    /// exactly `v1`.
    ///
    /// Only the path is inspected; the scheme, host, query and fragment are
    /// ignored. This prevents a plain substring test from misclassifying hosts
    /// such as `https://v1.proxy.example`, or paths such as `/v1beta` and `/v10`,
    /// as already-versioned base URLs.
    fn base_url_has_v1_segment(base_url: &str) -> bool {
        let without_scheme = base_url
            .split_once("://")
            .map_or(base_url, |(_, rest)| rest);
        // Everything after the first '/' following the authority is the path.
        let path = without_scheme
            .split_once('/')
            .map_or("", |(_, path)| path);
        // `split` always yields at least one item, so `unwrap_or` is only a guard.
        let path = path
            .split(|c: char| c == '?' || c == '#')
            .next()
            .unwrap_or("");
        path.split('/').any(|segment| segment == "v1")
    }

    /// Probes every configured provider (including disabled ones), one after
    /// another, and reports whether each probe succeeded.
    ///
    /// A provider maps to `true` when `check_provider_health` returns `Ok`.
    /// This method itself never returns `Err`; individual failures only show up
    /// as `false` entries.
    pub async fn health_check(&self) -> Result<HashMap<String, bool>> {
        let mut health_status = HashMap::new();
        for provider in &self.config.providers {
            let is_healthy = self.check_provider_health(&provider.id).await.is_ok();
            health_status.insert(provider.id.clone(), is_healthy);
        }
        Ok(health_status)
    }

    /// Sends a minimal chat request ("Hi", `max_tokens = 1`) to `provider_id`
    /// through `execute_chat_request` and discards the response.
    ///
    /// This issues a real request to the provider. The request's `model` is an
    /// empty string.
    /// NOTE(review): presumably `execute_chat_request` substitutes a provider
    /// default for an empty model — confirm against its implementation.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `execute_chat_request`.
    pub(crate) async fn check_provider_health(&self, provider_id: &str) -> Result<()> {
        use crate::sdk::types::{ChatOptions, Content, Message, Role, SdkChatRequest};

        let simple_request = SdkChatRequest {
            model: String::new(),
            messages: vec![Message {
                role: Role::User,
                content: Some(Content::Text("Hi".to_string())),
                name: None,
                tool_calls: None,
            }],
            options: ChatOptions {
                max_tokens: Some(1),
                ..Default::default()
            },
        };
        self.execute_chat_request(provider_id, simple_request)
            .await?;
        Ok(())
    }
}