use std::collections::HashMap;
use std::sync::Arc;
use crate::config::Config;
use crate::error::LlmConnectorError;
use crate::providers::{Provider, utils};
use crate::types::{ChatRequest, ChatResponse};
#[cfg(feature = "streaming")]
use crate::types::ChatStream;
/// Entry point for sending chat requests to LLM providers.
///
/// Each request is routed to a registered provider chosen from its model name.
/// The client also keeps the [`Config`] it was built from. Providers are held
/// behind [`Arc`], so a cloned client refers to the same provider instances
/// rather than duplicating them.
#[derive(Clone)]
pub struct Client {
    /// Registered providers, keyed by the provider name produced by
    /// `utils::detect_provider_from_model`.
    providers: HashMap<String, Arc<dyn Provider>>,
    /// Configuration the client was constructed with.
    config: Config,
}
impl Client {
    /// Creates a client from an explicit [`Config`].
    ///
    /// When the `reqwest` feature is enabled, `initialize_providers` is called to
    /// populate the provider registry. Otherwise the registry starts (and stays) empty.
    pub fn with_config(config: Config) -> Self {
        // `config` is already owned, so it is moved in instead of cloned.
        #[allow(unused_mut)] // only mutated when the `reqwest` feature is enabled
        let mut client = Self {
            providers: HashMap::new(),
            config,
        };
        #[cfg(feature = "reqwest")]
        client.initialize_providers();
        client
    }

    /// Creates a client whose configuration comes from `Config::from_env()`.
    pub fn from_env() -> Self {
        Self::with_config(Config::from_env())
    }

    /// Populates `self.providers` from the configuration.
    ///
    /// NOTE(review): the body is currently empty, so no provider is ever registered.
    /// Every model that `utils::detect_provider_from_model` recognises therefore fails
    /// lookup with `ConfigError`. Confirm whether this is an unfinished stub.
    #[cfg(feature = "reqwest")]
    fn initialize_providers(&mut self) {
        // TODO: register providers here.
    }

    /// Validates `request`, resolves its provider and normalises its model name.
    ///
    /// `chat` and `chat_stream` both use this so they apply identical preprocessing.
    /// The provider is resolved from the model string as given. The returned request
    /// carries the name produced by `utils::clean_model_name`.
    fn prepare_request(
        &self,
        mut request: ChatRequest,
    ) -> Result<(Arc<dyn Provider>, ChatRequest), LlmConnectorError> {
        utils::validate_chat_request(&request)?;
        let provider = self.get_provider_for_model(&request.model)?;
        request.model = utils::clean_model_name(&request.model).to_string();
        Ok((provider, request))
    }

    /// Sends a non-streaming chat request to the provider serving `request.model`.
    ///
    /// # Errors
    /// - Any error returned by `utils::validate_chat_request`.
    /// - `UnsupportedModel` if no provider can be detected for the model.
    /// - `ConfigError` if the detected provider is not registered.
    /// - Otherwise, whatever the provider's `chat` returns.
    pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
        let (provider, request) = self.prepare_request(request)?;
        provider.chat(&request).await
    }

    /// Sends a streaming chat request. `stream` is forced to `Some(true)`.
    ///
    /// # Errors
    /// The same validation and lookup errors as [`Client::chat`], or whatever the
    /// provider's `chat_stream` returns.
    #[cfg(feature = "streaming")]
    pub async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream, LlmConnectorError> {
        let (provider, mut request) = self.prepare_request(request)?;
        request.stream = Some(true);
        provider.chat_stream(&request).await
    }

    /// Looks up the registered provider for `model`.
    ///
    /// # Errors
    /// - `UnsupportedModel` if `utils::detect_provider_from_model` returns `None`.
    /// - `ConfigError` if the detected provider name has no registered entry.
    fn get_provider_for_model(&self, model: &str) -> Result<Arc<dyn Provider>, LlmConnectorError> {
        let provider_name = utils::detect_provider_from_model(model)
            .ok_or_else(|| LlmConnectorError::UnsupportedModel(model.to_string()))?;
        self.providers
            .get(provider_name)
            .cloned()
            .ok_or_else(|| {
                LlmConnectorError::ConfigError(format!(
                    "Provider '{}' not configured",
                    provider_name
                ))
            })
    }

    /// Lists every model supported by the registered providers.
    ///
    /// Each model appears twice: once as `"<provider>/<model>"` and once bare.
    /// The result is sorted with duplicates removed.
    pub fn list_models(&self) -> Vec<String> {
        let mut models = Vec::new();
        for (provider_name, provider) in &self.providers {
            for model in provider.supported_models() {
                models.push(format!("{}/{}", provider_name, model));
                models.push(model);
            }
        }
        // Equal strings are indistinguishable, so an unstable sort is sufficient.
        models.sort_unstable();
        models.dedup();
        models
    }

    /// Returns the provider names reported by the configuration.
    ///
    /// This reads `self.config`, not the provider registry, so it may list providers
    /// that have not actually been registered.
    pub fn list_providers(&self) -> Vec<String> {
        self.config.list_providers()
    }

    /// Returns `true` if a registered provider claims support for `model`.
    ///
    /// The provider is resolved from `model` as given. The support check uses the
    /// name produced by `utils::clean_model_name`.
    pub fn supports_model(&self, model: &str) -> bool {
        match self.get_provider_for_model(model) {
            Ok(provider) => provider.supports_model(utils::clean_model_name(model)),
            Err(_) => false,
        }
    }

    /// Returns the provider name detected for `model`, if any.
    ///
    /// This only inspects the model name. It does not check whether that provider
    /// is registered on this client.
    pub fn get_provider_info(&self, model: &str) -> Option<String> {
        utils::detect_provider_from_model(model).map(|s| s.to_string())
    }
}
impl Default for Client {
fn default() -> Self {
Self::from_env()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Message;

    /// A client built from the default configuration reports no providers.
    #[test]
    fn test_client_creation() {
        let client = Client::with_config(Config::default());
        assert!(client.list_providers().is_empty());
    }

    /// Provider detection recognises known model families and rejects unknown ones.
    #[test]
    fn test_model_support_detection() {
        let client = Client::default();
        for known in &["gpt-4", "claude-3-haiku", "deepseek-chat"] {
            assert!(
                client.get_provider_info(known).is_some(),
                "expected a provider for {}",
                known
            );
        }
        assert_eq!(client.get_provider_info("unknown-model"), None);
    }

    /// An empty model name is rejected as an invalid request before any provider lookup.
    #[tokio::test]
    async fn test_request_validation() {
        let client = Client::default();
        let greeting = Message {
            role: "user".to_string(),
            content: "Hello".to_string(),
            ..Default::default()
        };
        let request = ChatRequest {
            model: String::new(),
            messages: vec![greeting],
            ..Default::default()
        };
        let err = client
            .chat(request)
            .await
            .expect_err("an empty model name must be rejected");
        assert!(matches!(err, LlmConnectorError::InvalidRequest(_)));
    }
}