use std::collections::HashMap;
use std::sync::Arc;
use crate::config::Config;
use crate::error::LlmConnectorError;
use crate::providers::Provider;
use crate::types::{ChatRequest, ChatResponse};
#[cfg(feature = "streaming")]
use crate::types::ChatStream;
/// Entry point for sending chat requests to configured LLM providers.
///
/// Requests are routed by the `provider/` prefix of their model name
/// (e.g. `"openai/gpt-4"`) to the matching entry in `providers`.
pub struct Client {
// Registered providers keyed by provider name (e.g. `"openai"`, `"deepseek"`).
providers: HashMap<String, Arc<dyn Provider>>,
// The configuration this client was built from.
config: Config,
}
impl Client {
    /// Builds a client from an explicit [`Config`].
    ///
    /// With the `reqwest` feature enabled, a provider is constructed and registered for every
    /// provider section present in `config`; without it, no providers are registered.
    pub fn with_config(config: Config) -> Self {
        // `config` is owned, so it is moved into the struct rather than cloned.
        // `mut` is only needed when the `reqwest` feature registers providers.
        #[cfg_attr(not(feature = "reqwest"), allow(unused_mut))]
        let mut client = Self {
            providers: HashMap::new(),
            config,
        };
        #[cfg(feature = "reqwest")]
        client.initialize_providers();
        client
    }

    /// Builds a client from [`Config::from_env`].
    pub fn from_env() -> Self {
        Self::with_config(Config::from_env())
    }

    /// Registers a `GenericProvider` for every provider section present in `self.config`.
    ///
    /// Best-effort: when `GenericProvider::new` fails for a provider, that provider is skipped
    /// silently. Later requests for it fail with a "not configured" error instead of client
    /// construction aborting.
    #[cfg(feature = "reqwest")]
    fn initialize_providers(&mut self) {
        use crate::protocols::GenericProvider;

        // Disjoint field borrows: read the config while inserting into the provider map.
        let config = &self.config;
        let providers = &mut self.providers;

        // Registers provider `$name` if `config.$section` is set. `$protocol` is only
        // evaluated when that section is present.
        macro_rules! register {
            ($name:literal, $section:ident, $protocol:expr) => {
                if let Some(section) = &config.$section {
                    if let Ok(provider) = GenericProvider::new(section.clone(), $protocol) {
                        providers.insert($name.to_string(), Arc::new(provider));
                    }
                }
            };
        }

        register!("openai", openai, crate::protocols::openai::openai());
        register!("deepseek", deepseek, crate::protocols::openai::deepseek());
        register!("aliyun", aliyun, crate::protocols::aliyun::aliyun());
        register!("zhipu", zhipu, crate::protocols::openai::zhipu());
        register!("moonshot", moonshot, crate::protocols::openai::moonshot());
        register!("volcengine", volcengine, crate::protocols::openai::volcengine());
        register!("longcat", longcat, crate::protocols::openai::longcat());
    }

    /// Sends a non-streaming chat request.
    ///
    /// `request.model` must carry a provider prefix (e.g. `"openai/gpt-4"`). The prefix selects
    /// the provider and is stripped before the request is forwarded.
    ///
    /// # Errors
    /// - `InvalidRequest` if `validate_chat_request` rejects the request (e.g. no messages).
    /// - `UnsupportedModel` if no provider prefix can be detected from the model name.
    /// - `ConfigError` if the detected provider is not registered.
    /// - Any error returned by the provider itself.
    pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
        let (provider, request) = self.route_request(request)?;
        provider.chat(&request).await
    }

    /// Sends a streaming chat request.
    ///
    /// Routing and errors are the same as [`Client::chat`]. Additionally, `stream` is forced
    /// to `Some(true)` before the request is forwarded.
    #[cfg(feature = "streaming")]
    pub async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream, LlmConnectorError> {
        let (provider, mut request) = self.route_request(request)?;
        request.stream = Some(true);
        provider.chat_stream(&request).await
    }

    /// Validates `request`, resolves its provider, and strips the provider prefix from
    /// `request.model` so the provider receives its own model name.
    fn route_request(
        &self,
        mut request: ChatRequest,
    ) -> Result<(Arc<dyn Provider>, ChatRequest), LlmConnectorError> {
        validate_chat_request(&request)?;
        let provider = self.get_provider_for_model(&request.model)?;
        request.model = clean_model_name(&request.model).to_string();
        Ok((provider, request))
    }

    /// Looks up the registered provider named by the prefix of `model`.
    ///
    /// # Errors
    /// - `UnsupportedModel` if no provider prefix can be detected from `model`.
    /// - `ConfigError` if the prefix names a provider that is not registered.
    fn get_provider_for_model(&self, model: &str) -> Result<Arc<dyn Provider>, LlmConnectorError> {
        let provider_name = detect_provider_from_model(model)
            .ok_or_else(|| LlmConnectorError::UnsupportedModel(model.to_string()))?;
        self.providers.get(provider_name).cloned().ok_or_else(|| {
            LlmConnectorError::ConfigError(format!("Provider '{}' not configured", provider_name))
        })
    }

    /// Lists the models of all registered providers, sorted and deduplicated.
    ///
    /// Each model appears both as `provider/model` and as the bare `model` name.
    /// NOTE(review): `chat` only routes prefixed names, so the bare entries cannot be passed
    /// to `chat` directly. Confirm callers rely on them before removing.
    pub fn list_models(&self) -> Vec<String> {
        let mut models = Vec::new();
        for (provider_name, provider) in &self.providers {
            for model in provider.supported_models() {
                models.push(format!("{}/{}", provider_name, model));
                models.push(model);
            }
        }
        // Equal strings are indistinguishable, so an unstable sort gives the same result.
        models.sort_unstable();
        models.dedup();
        models
    }

    /// Returns the provider names reported by the configuration (`Config::list_providers`).
    ///
    /// NOTE(review): this reflects the config, not the registered `providers` map. A provider
    /// that failed to initialize (or every provider, without the `reqwest` feature) may still
    /// be listed. Confirm this is intended.
    pub fn list_providers(&self) -> Vec<String> {
        self.config.list_providers()
    }

    /// Returns `true` if `model`'s provider is registered and that provider reports support
    /// for the unprefixed model name.
    pub fn supports_model(&self, model: &str) -> bool {
        self.get_provider_for_model(model)
            .map(|provider| provider.supports_model(clean_model_name(model)))
            .unwrap_or(false)
    }

    /// Returns the provider prefix of `model`, without checking that the provider is registered.
    pub fn get_provider_info(&self, model: &str) -> Option<String> {
        detect_provider_from_model(model).map(|s| s.to_string())
    }
}
fn validate_chat_request(request: &ChatRequest) -> Result<(), LlmConnectorError> {
if request.messages.is_empty() {
return Err(LlmConnectorError::InvalidRequest(
"Messages cannot be empty".to_string(),
));
}
Ok(())
}
/// Strips the provider prefix (everything up to and including the first `/`) from `model`.
///
/// Names without a `/` are returned unchanged: `"openai/gpt-4"` becomes `"gpt-4"`, and
/// `"gpt-4"` stays `"gpt-4"`.
fn clean_model_name(model: &str) -> &str {
    match model.split_once('/') {
        Some((_provider, rest)) => rest,
        None => model,
    }
}
/// Returns the provider prefix of `model`: the text before the first `/`.
///
/// Returns `None` when the name contains no `/` or when the prefix is empty (e.g.
/// `"/gpt-4"`). Such names are then reported as unsupported models rather than as an
/// unconfigured provider named `""`.
fn detect_provider_from_model(model: &str) -> Option<&str> {
    model
        .split_once('/')
        .map(|(provider, _)| provider)
        .filter(|provider| !provider.is_empty())
}
impl Default for Client {
fn default() -> Self {
Self::from_env()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Message;

    /// Builds a client from `Config::default()`. Unlike `Client::default()`, this does not
    /// read the process environment, so results do not depend on the machine's env vars.
    fn test_client() -> Client {
        Client::with_config(Config::default())
    }

    #[test]
    fn test_client_creation() {
        let client = test_client();
        assert_eq!(client.list_providers().len(), 0);
    }

    #[test]
    fn test_model_support_detection() {
        let client = test_client();
        assert!(client.get_provider_info("gpt-4").is_none());
        assert!(client.get_provider_info("claude-3-haiku").is_none());
        assert!(client.get_provider_info("deepseek-chat").is_none());
        assert!(client.get_provider_info("unknown-model").is_none());
        assert_eq!(
            client.get_provider_info("openai/gpt-4").as_deref(),
            Some("openai")
        );
    }

    #[test]
    fn test_clean_model_name() {
        assert_eq!(clean_model_name("openai/gpt-4"), "gpt-4");
        assert_eq!(clean_model_name("gpt-4"), "gpt-4");
        assert_eq!(clean_model_name("a/b/c"), "b/c");
    }

    #[test]
    fn test_unconfigured_provider_is_not_supported() {
        let client = test_client();
        assert!(!client.supports_model("openai/gpt-4"));
        assert!(!client.supports_model("gpt-4"));
    }

    #[tokio::test]
    async fn test_request_validation() {
        let client = test_client();
        let request = ChatRequest {
            model: "".to_string(),
            messages: vec![Message::user("Hello")],
            ..Default::default()
        };
        let result = client.chat(request).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            LlmConnectorError::InvalidRequest(_) | LlmConnectorError::UnsupportedModel(_)
        ));
    }

    #[tokio::test]
    async fn test_empty_messages_rejected() {
        let client = test_client();
        let request = ChatRequest {
            model: "openai/gpt-4".to_string(),
            messages: vec![],
            ..Default::default()
        };
        let result = client.chat(request).await;
        assert!(matches!(result, Err(LlmConnectorError::InvalidRequest(_))));
    }
}