use std::collections::HashMap;
use std::time::Duration;
use crate::error::{AiError, AiResult};
use crate::model::Model;
use crate::provider::{Provider, ProviderConfig};
use crate::types::{ApiErrorResponse, ChatCompletionRequest, ChatCompletionResponse, Message};
/// Multi-provider chat-completion client.
///
/// Holds at most one [`ProviderConfig`] per [`Provider`], plus optional
/// client-wide request defaults (temperature / max tokens) that are applied
/// whenever a call does not override them.
#[derive(Debug)]
pub struct AiClient {
    // Configured providers, keyed by provider identity; lookup order for a
    // model is driven by the model's provider mappings, not this map.
    providers: HashMap<Provider, ProviderConfig>,
    // Default sampling temperature applied when a request doesn't set one.
    default_temperature: Option<f32>,
    // Default max-token cap applied when a request doesn't set one.
    default_max_tokens: Option<u32>,
}
impl Default for AiClient {
fn default() -> Self {
Self::new()
}
}
impl AiClient {
pub fn new() -> Self {
Self {
providers: HashMap::new(),
default_temperature: None,
default_max_tokens: None,
}
}
pub fn from_env() -> Self {
let mut client = Self::new();
for provider in [Provider::Groq, Provider::OpenRouter, Provider::SambaNova] {
if let Some(config) = ProviderConfig::from_env(provider) {
client.providers.insert(provider, config);
}
}
client
}
/// Registers (or replaces) the configuration for a provider.
///
/// Builder-style: consumes and returns `self` for chaining.
pub fn with_provider(mut self, config: ProviderConfig) -> Self {
    let key = config.provider;
    self.providers.insert(key, config);
    self
}
pub fn with_default_temperature(mut self, temperature: f32) -> Self {
self.default_temperature = Some(temperature);
self
}
pub fn with_default_max_tokens(mut self, max_tokens: u32) -> Self {
self.default_max_tokens = Some(max_tokens);
self
}
/// Returns `true` if at least one provider has been configured.
pub fn has_providers(&self) -> bool {
    !self.providers.is_empty()
}
/// Lists the configured providers.
///
/// Order is unspecified (backed by a `HashMap`).
pub fn providers(&self) -> Vec<Provider> {
    let mut out = Vec::with_capacity(self.providers.len());
    for provider in self.providers.keys() {
        out.push(*provider);
    }
    out
}
pub fn chat(&self, model: Model, messages: Vec<Message>) -> AiResult<ChatCompletionResponse> {
let model_info = model.info();
let mut errors = Vec::new();
for mapping in &model_info.providers {
if let Some(config) = self.providers.get(&mapping.provider) {
let mut request = ChatCompletionRequest::new(mapping.model_id, messages.clone());
if let Some(temp) = self.default_temperature {
request = request.with_temperature(temp);
}
if let Some(max) = self.default_max_tokens {
request = request.with_max_tokens(max);
}
match self.send_request(config, request) {
Ok(response) => return Ok(response),
Err(e) => {
errors.push((mapping.provider, e.to_string()));
}
}
}
}
if errors.is_empty() {
Err(AiError::ModelNotAvailable(model.name().to_string()))
} else {
Err(AiError::AllProvidersFailed(errors))
}
}
/// Sends a chat request with optional per-call overrides for temperature
/// and max tokens; `None` falls back to the client-level default (if any).
///
/// Providers are tried in the model's priority order; the first success
/// is returned. Failures are collected so the final error can report
/// every provider that was attempted.
///
/// # Errors
/// - [`AiError::ModelNotAvailable`] if no provider for the model is
///   configured on this client.
/// - [`AiError::AllProvidersFailed`] with one `(provider, message)` pair
///   per failed attempt otherwise.
pub fn chat_with_options(
    &self,
    model: Model,
    messages: Vec<Message>,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
) -> AiResult<ChatCompletionResponse> {
    let info = model.info();
    let mut failures = Vec::new();
    for mapping in &info.providers {
        // Skip providers the user hasn't configured.
        let config = match self.providers.get(&mapping.provider) {
            Some(c) => c,
            None => continue,
        };
        let mut req = ChatCompletionRequest::new(mapping.model_id, messages.clone());
        if let Some(t) = temperature.or(self.default_temperature) {
            req = req.with_temperature(t);
        }
        if let Some(m) = max_tokens.or(self.default_max_tokens) {
            req = req.with_max_tokens(m);
        }
        match self.send_request(config, req) {
            Ok(resp) => return Ok(resp),
            Err(err) => failures.push((mapping.provider, err.to_string())),
        }
    }
    if failures.is_empty() {
        Err(AiError::ModelNotAvailable(model.name().to_string()))
    } else {
        Err(AiError::AllProvidersFailed(failures))
    }
}
/// Sends a fully caller-constructed request to one specific provider,
/// bypassing model-to-provider mapping and client defaults.
///
/// # Errors
/// [`AiError::ApiKeyNotFound`] if `provider` is not configured on this
/// client; otherwise whatever `send_request` produces.
pub fn chat_raw(
    &self,
    provider: Provider,
    request: ChatCompletionRequest,
) -> AiResult<ChatCompletionResponse> {
    match self.providers.get(&provider) {
        Some(config) => self.send_request(config, request),
        None => Err(AiError::ApiKeyNotFound(provider)),
    }
}
fn send_request(
&self,
config: &ProviderConfig,
request: ChatCompletionRequest,
) -> AiResult<ChatCompletionResponse> {
let url = config.chat_completions_url();
let agent = ureq::Agent::new_with_config(
ureq::Agent::config_builder()
.timeout_global(Some(Duration::from_secs(config.timeout_secs)))
.build(),
);
let response = agent
.post(&url)
.header("Authorization", &format!("Bearer {}", config.api_key))
.header("Content-Type", "application/json")
.send_json(&request)
.map_err(|e| self.handle_http_error(config.provider, e))?;
let body = response
.into_body()
.read_to_string()
.map_err(|e| AiError::ParseError(e.to_string()))?;
if let Ok(completion) = serde_json::from_str::<ChatCompletionResponse>(&body) {
return Ok(completion);
}
if let Ok(error) = serde_json::from_str::<ApiErrorResponse>(&body) {
return Err(AiError::ApiError {
provider: config.provider,
message: error.error.message,
status_code: None,
});
}
Err(AiError::ParseError(format!(
"Failed to parse response: {}",
body
)))
}
fn handle_http_error(&self, provider: Provider, error: ureq::Error) -> AiError {
match error {
ureq::Error::Timeout(_) => AiError::Timeout(120),
ureq::Error::StatusCode(status) => {
if status == 429 {
AiError::RateLimitExceeded(provider)
} else {
AiError::ApiError {
provider,
message: format!("HTTP {}", status),
status_code: Some(status),
}
}
}
_ => AiError::HttpError(error.to_string()),
}
}
}
/// One-shot convenience helper: builds a client from the environment,
/// sends `prompt` as a single user message to the default general model,
/// and returns the reply text.
///
/// # Errors
/// - [`AiError::NoApiKey`] if no provider could be configured from env.
/// - [`AiError::ParseError`] if the response carries no content.
/// - Any error propagated from [`AiClient::chat`].
pub fn chat_simple(prompt: &str) -> AiResult<String> {
    let client = AiClient::from_env();
    if !client.has_providers() {
        return Err(AiError::NoApiKey);
    }
    let response = client.chat(Model::default_general(), vec![Message::user(prompt)])?;
    match response.content() {
        Some(text) => Ok(text.to_string()),
        None => Err(AiError::ParseError("No content in response".to_string())),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built client knows about no providers.
    #[test]
    fn test_client_creation() {
        assert!(!AiClient::new().has_providers());
    }

    /// Builder methods register the provider and keep it discoverable.
    #[test]
    fn test_client_with_provider() {
        let client = AiClient::new()
            .with_provider(ProviderConfig::new(Provider::Groq, "test-key"))
            .with_default_temperature(0.7);
        assert!(client.has_providers());
        assert!(client.providers().contains(&Provider::Groq));
    }

    /// With no providers configured, `chat` reports the model unavailable
    /// rather than an aggregate failure.
    #[test]
    fn test_model_not_available() {
        let result = AiClient::new().chat(Model::Llama3_3_70B, vec![Message::user("Hello")]);
        assert!(matches!(result, Err(AiError::ModelNotAvailable(_))));
    }
}