use std::collections::HashMap;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::embedding::{EmbeddingModel, EmbeddingRequest, EmbeddingResponse};
use crate::error::{AiError, AiResult};
use crate::model::Model;
use crate::provider::{Provider, ProviderConfig};
use crate::transcription::{
TimestampGranularity, TranscriptionModel, TranscriptionOptions, TranscriptionResponse,
TranscriptionResponseFormat, VerboseTranscriptionResponse,
};
use crate::types::{ApiErrorResponse, ChatCompletionRequest, ChatCompletionResponse, Message};
/// High-level client that fans requests out across the configured AI providers.
///
/// Holds at most one [`ProviderConfig`] per [`Provider`], plus optional
/// client-wide defaults that chat calls fall back to when no per-call
/// value is supplied.
#[derive(Debug)]
pub struct AiClient {
// Registered providers, keyed by provider identity.
providers: HashMap<Provider, ProviderConfig>,
// Fallback sampling temperature used when a chat call supplies none.
default_temperature: Option<f32>,
// Fallback max-token cap used when a chat call supplies none.
default_max_tokens: Option<u32>,
}
// `Default` mirrors `AiClient::new()`: a client with no providers configured
// and no client-wide option defaults.
impl Default for AiClient {
fn default() -> Self {
Self::new()
}
}
impl AiClient {
/// Creates an empty client: no providers registered, no option defaults set.
pub fn new() -> Self {
    Self {
        default_temperature: None,
        default_max_tokens: None,
        providers: HashMap::new(),
    }
}
pub fn from_env() -> Self {
let mut client = Self::new();
for provider in [Provider::Groq, Provider::OpenRouter, Provider::SambaNova] {
if let Some(config) = ProviderConfig::from_env(provider) {
client.providers.insert(provider, config);
}
}
client
}
/// Registers `config`, replacing any existing entry for the same provider
/// (builder style).
pub fn with_provider(mut self, config: ProviderConfig) -> Self {
    let provider = config.provider;
    self.providers.insert(provider, config);
    self
}
pub fn with_default_temperature(mut self, temperature: f32) -> Self {
self.default_temperature = Some(temperature);
self
}
pub fn with_default_max_tokens(mut self, max_tokens: u32) -> Self {
self.default_max_tokens = Some(max_tokens);
self
}
/// Returns `true` when at least one provider is registered.
pub fn has_providers(&self) -> bool {
    self.providers.keys().next().is_some()
}
/// Returns the registered providers (iteration order is unspecified).
pub fn providers(&self) -> Vec<Provider> {
    let mut registered = Vec::with_capacity(self.providers.len());
    registered.extend(self.providers.keys().copied());
    registered
}
/// Lists every model reachable through at least one registered provider.
pub fn available_models(&self) -> Vec<Model> {
    let active = self.providers();
    Model::available_for_providers(&active)
}
pub fn chat(&self, model: Model, messages: Vec<Message>) -> AiResult<ChatCompletionResponse> {
let model_info = model.info();
let mut errors = Vec::new();
for mapping in &model_info.providers {
if let Some(config) = self.providers.get(&mapping.provider) {
let mut request = ChatCompletionRequest::new(mapping.model_id, messages.clone());
if let Some(temp) = self.default_temperature {
request = request.with_temperature(temp);
}
if let Some(max) = self.default_max_tokens {
request = request.with_max_tokens(max);
}
match self.send_request(config, request) {
Ok(response) => return Ok(response),
Err(e) => {
errors.push((mapping.provider, e.to_string()));
}
}
}
}
if errors.is_empty() {
Err(AiError::ModelNotAvailable(model.name().to_string()))
} else {
Err(AiError::AllProvidersFailed(errors))
}
}
pub fn chat_with_options(
&self,
model: Model,
messages: Vec<Message>,
temperature: Option<f32>,
max_tokens: Option<u32>,
) -> AiResult<ChatCompletionResponse> {
let model_info = model.info();
let mut errors = Vec::new();
for mapping in &model_info.providers {
if let Some(config) = self.providers.get(&mapping.provider) {
let mut request = ChatCompletionRequest::new(mapping.model_id, messages.clone());
if let Some(temp) = temperature.or(self.default_temperature) {
request = request.with_temperature(temp);
}
if let Some(max) = max_tokens.or(self.default_max_tokens) {
request = request.with_max_tokens(max);
}
match self.send_request(config, request) {
Ok(response) => return Ok(response),
Err(e) => {
errors.push((mapping.provider, e.to_string()));
}
}
}
}
if errors.is_empty() {
Err(AiError::ModelNotAvailable(model.name().to_string()))
} else {
Err(AiError::AllProvidersFailed(errors))
}
}
/// Sends a caller-built request to one specific provider, with no failover.
///
/// Returns `ApiKeyNotFound` when `provider` is not configured on this client.
pub fn chat_raw(
    &self,
    provider: Provider,
    request: ChatCompletionRequest,
) -> AiResult<ChatCompletionResponse> {
    match self.providers.get(&provider) {
        Some(config) => self.send_request(config, request),
        None => Err(AiError::ApiKeyNotFound(provider)),
    }
}
/// POSTs `request` to the provider's chat-completions endpoint and parses the body.
///
/// Returns `EmptyResponse` when the completion parses but carries no content,
/// `ApiError` when the body matches the structured provider-error shape, and
/// `ParseError` when the body matches neither shape or cannot be read.
fn send_request(
&self,
config: &ProviderConfig,
request: ChatCompletionRequest,
) -> AiResult<ChatCompletionResponse> {
let url = config.chat_completions_url();
// One agent per request; `timeout_global` bounds the whole exchange.
let agent = ureq::Agent::new_with_config(
ureq::Agent::config_builder()
.timeout_global(Some(Duration::from_secs(config.timeout_secs)))
.build(),
);
let response = agent
.post(&url)
.header("Authorization", &format!("Bearer {}", config.api_key))
.header("Content-Type", "application/json")
.send_json(&request)
.map_err(|e| self.handle_http_error(config.provider, e))?;
let body = response
.into_body()
.read_to_string()
.map_err(|e| AiError::ParseError(e.to_string()))?;
// Try the success shape first. A completion with empty/missing content is
// reported as an error so the failover loops in the callers move on.
if let Ok(completion) = serde_json::from_str::<ChatCompletionResponse>(&body) {
if completion.content().map(|c| c.is_empty()).unwrap_or(true) {
return Err(AiError::EmptyResponse(config.provider));
}
return Ok(completion);
}
// Not a completion: see whether the body is a structured API error.
if let Ok(error) = serde_json::from_str::<ApiErrorResponse>(&body) {
return Err(AiError::ApiError {
provider: config.provider,
message: error.error.message,
status_code: None,
});
}
// Neither shape matched; surface the raw body for diagnosis.
Err(AiError::ParseError(format!(
"Failed to parse response: {}",
body
)))
}
/// Maps a `ureq` transport error onto this crate's [`AiError`] variants.
fn handle_http_error(&self, provider: Provider, error: ureq::Error) -> AiError {
match error {
// NOTE(review): reports a fixed 120s even though the agent is built with
// `config.timeout_secs`; the actual timeout is not in scope here — confirm
// whether `AiError::Timeout` should carry the configured value.
ureq::Error::Timeout(_) => AiError::Timeout(120),
ureq::Error::StatusCode(status) => {
// 429 gets a dedicated variant so callers can recognize rate limiting.
if status == 429 {
AiError::RateLimitExceeded(provider)
} else {
AiError::ApiError {
provider,
message: format!("HTTP {}", status),
status_code: Some(status),
}
}
}
// Any other transport failure is reported as a generic HTTP error string.
_ => AiError::HttpError(error.to_string()),
}
}
pub fn embed(&self, model: EmbeddingModel, input: &str) -> AiResult<EmbeddingResponse> {
let model_info = model.info();
let mut errors = Vec::new();
for mapping in &model_info.providers {
if let Some(config) = self.providers.get(&mapping.provider) {
let request = EmbeddingRequest::new(mapping.model_id, input);
match self.send_embedding_request(config, request) {
Ok(response) => return Ok(response),
Err(e) => {
errors.push((mapping.provider, e.to_string()));
}
}
}
}
if errors.is_empty() {
Err(AiError::ModelNotAvailable(model.name().to_string()))
} else {
Err(AiError::AllProvidersFailed(errors))
}
}
pub fn embed_batch(
&self,
model: EmbeddingModel,
inputs: Vec<String>,
) -> AiResult<EmbeddingResponse> {
let model_info = model.info();
let mut errors = Vec::new();
for mapping in &model_info.providers {
if let Some(config) = self.providers.get(&mapping.provider) {
let request = EmbeddingRequest::new_batch(mapping.model_id, inputs.clone());
match self.send_embedding_request(config, request) {
Ok(response) => return Ok(response),
Err(e) => {
errors.push((mapping.provider, e.to_string()));
}
}
}
}
if errors.is_empty() {
Err(AiError::ModelNotAvailable(model.name().to_string()))
} else {
Err(AiError::AllProvidersFailed(errors))
}
}
/// POSTs `request` to the provider's embeddings endpoint and parses the body.
///
/// Returns `ApiError` when the body matches the structured provider-error
/// shape, and `ParseError` when it matches neither shape or cannot be read.
fn send_embedding_request(
&self,
config: &ProviderConfig,
request: EmbeddingRequest,
) -> AiResult<EmbeddingResponse> {
let url = config.embeddings_url();
// One agent per request; `timeout_global` bounds the whole exchange.
let agent = ureq::Agent::new_with_config(
ureq::Agent::config_builder()
.timeout_global(Some(Duration::from_secs(config.timeout_secs)))
.build(),
);
let response = agent
.post(&url)
.header("Authorization", &format!("Bearer {}", config.api_key))
.header("Content-Type", "application/json")
.send_json(&request)
.map_err(|e| self.handle_http_error(config.provider, e))?;
let body = response
.into_body()
.read_to_string()
.map_err(|e| AiError::ParseError(e.to_string()))?;
// Try the success shape first, then the structured API-error shape.
if let Ok(embedding) = serde_json::from_str::<EmbeddingResponse>(&body) {
return Ok(embedding);
}
if let Ok(error) = serde_json::from_str::<ApiErrorResponse>(&body) {
return Err(AiError::ApiError {
provider: config.provider,
message: error.error.message,
status_code: None,
});
}
// Neither shape matched; surface the raw body for diagnosis.
Err(AiError::ParseError(format!(
"Failed to parse embedding response: {}",
body
)))
}
/// Transcribes the audio file at `file_path` using default options.
///
/// Convenience wrapper over [`Self::transcribe_file_with_options`].
pub fn transcribe_file(
&self,
model: TranscriptionModel,
file_path: &std::path::Path,
) -> AiResult<TranscriptionResponse> {
self.transcribe_file_with_options(model, file_path, TranscriptionOptions::default())
}
/// Reads the audio file at `file_path` and transcribes it with `options`.
///
/// A read failure is reported as `InvalidRequest`. When the file name is
/// absent or not valid UTF-8, "audio.mp3" is used as the upload name.
pub fn transcribe_file_with_options(
    &self,
    model: TranscriptionModel,
    file_path: &std::path::Path,
    options: TranscriptionOptions,
) -> AiResult<TranscriptionResponse> {
    let audio_data = match std::fs::read(file_path) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(AiError::InvalidRequest(format!(
                "Failed to read audio file: {}",
                e
            )))
        }
    };
    let file_name = file_path
        .file_name()
        .and_then(|n| n.to_str())
        .map_or_else(|| "audio.mp3".to_string(), |n| n.to_string());
    self.transcribe_bytes(model, &audio_data, &file_name, options)
}
pub fn transcribe_bytes(
&self,
model: TranscriptionModel,
audio_data: &[u8],
file_name: &str,
options: TranscriptionOptions,
) -> AiResult<TranscriptionResponse> {
let model_info = model.info();
let mut errors = Vec::new();
for mapping in &model_info.providers {
if let Some(config) = self.providers.get(&mapping.provider) {
match self.send_transcription_request(
config,
mapping.model_id,
audio_data,
file_name,
&options,
) {
Ok(response) => return Ok(response),
Err(e) => {
errors.push((mapping.provider, e.to_string()));
}
}
}
}
if errors.is_empty() {
Err(AiError::ModelNotAvailable(model.name().to_string()))
} else {
Err(AiError::AllProvidersFailed(errors))
}
}
pub fn transcribe_bytes_verbose(
&self,
model: TranscriptionModel,
audio_data: &[u8],
file_name: &str,
options: TranscriptionOptions,
) -> AiResult<VerboseTranscriptionResponse> {
let model_info = model.info();
let mut errors = Vec::new();
let mut options = options;
options.response_format = Some(TranscriptionResponseFormat::VerboseJson);
for mapping in &model_info.providers {
if let Some(config) = self.providers.get(&mapping.provider) {
match self.send_transcription_request_verbose(
config,
mapping.model_id,
audio_data,
file_name,
&options,
) {
Ok(response) => return Ok(response),
Err(e) => {
errors.push((mapping.provider, e.to_string()));
}
}
}
}
if errors.is_empty() {
Err(AiError::ModelNotAvailable(model.name().to_string()))
} else {
Err(AiError::AllProvidersFailed(errors))
}
}
/// Sends one transcription request and parses the plain (non-verbose) reply.
///
/// Delegates HTTP/multipart assembly to `send_transcription_request_raw`,
/// then tries the success shape first and the API-error shape second;
/// anything else is a `ParseError` carrying the raw body.
fn send_transcription_request(
&self,
config: &ProviderConfig,
model_id: &str,
audio_data: &[u8],
file_name: &str,
options: &TranscriptionOptions,
) -> AiResult<TranscriptionResponse> {
let body =
self.send_transcription_request_raw(config, model_id, audio_data, file_name, options)?;
if let Ok(transcription) = serde_json::from_str::<TranscriptionResponse>(&body) {
return Ok(transcription);
}
// Not a transcription: see whether the body is a structured API error.
if let Ok(error) = serde_json::from_str::<ApiErrorResponse>(&body) {
return Err(AiError::ApiError {
provider: config.provider,
message: error.error.message,
status_code: None,
});
}
Err(AiError::ParseError(format!(
"Failed to parse transcription response: {}",
body
)))
}
/// Sends one transcription request and parses the verbose (timestamped) reply.
///
/// Same flow as `send_transcription_request`, but deserializes into
/// [`VerboseTranscriptionResponse`]; callers are expected to have set the
/// verbose response format in `options`.
fn send_transcription_request_verbose(
&self,
config: &ProviderConfig,
model_id: &str,
audio_data: &[u8],
file_name: &str,
options: &TranscriptionOptions,
) -> AiResult<VerboseTranscriptionResponse> {
let body =
self.send_transcription_request_raw(config, model_id, audio_data, file_name, options)?;
if let Ok(transcription) = serde_json::from_str::<VerboseTranscriptionResponse>(&body) {
return Ok(transcription);
}
// Not a transcription: see whether the body is a structured API error.
if let Ok(error) = serde_json::from_str::<ApiErrorResponse>(&body) {
return Err(AiError::ApiError {
provider: config.provider,
message: error.error.message,
status_code: None,
});
}
Err(AiError::ParseError(format!(
"Failed to parse verbose transcription response: {}",
body
)))
}
fn send_transcription_request_raw(
&self,
config: &ProviderConfig,
model_id: &str,
audio_data: &[u8],
file_name: &str,
options: &TranscriptionOptions,
) -> AiResult<String> {
let url = config.transcriptions_url();
let agent = ureq::Agent::new_with_config(
ureq::Agent::config_builder()
.timeout_global(Some(Duration::from_secs(config.timeout_secs)))
.build(),
);
let boundary = format!("----WebKitFormBoundary{:x}", rand_boundary());
let content_type = format!("multipart/form-data; boundary={}", boundary);
let mut body = Vec::new();
body.extend_from_slice(format!("--{}\r\n", boundary).as_bytes());
body.extend_from_slice(
format!(
"Content-Disposition: form-data; name=\"file\"; filename=\"{}\"\r\n",
file_name
)
.as_bytes(),
);
body.extend_from_slice(b"Content-Type: application/octet-stream\r\n\r\n");
body.extend_from_slice(audio_data);
body.extend_from_slice(b"\r\n");
body.extend_from_slice(format!("--{}\r\n", boundary).as_bytes());
body.extend_from_slice(b"Content-Disposition: form-data; name=\"model\"\r\n\r\n");
body.extend_from_slice(model_id.as_bytes());
body.extend_from_slice(b"\r\n");
if let Some(ref lang) = options.language {
body.extend_from_slice(format!("--{}\r\n", boundary).as_bytes());
body.extend_from_slice(b"Content-Disposition: form-data; name=\"language\"\r\n\r\n");
body.extend_from_slice(lang.as_bytes());
body.extend_from_slice(b"\r\n");
}
if let Some(ref prompt) = options.prompt {
body.extend_from_slice(format!("--{}\r\n", boundary).as_bytes());
body.extend_from_slice(b"Content-Disposition: form-data; name=\"prompt\"\r\n\r\n");
body.extend_from_slice(prompt.as_bytes());
body.extend_from_slice(b"\r\n");
}
if let Some(ref format) = options.response_format {
let format_str = match format {
TranscriptionResponseFormat::Json => "json",
TranscriptionResponseFormat::Text => "text",
TranscriptionResponseFormat::VerboseJson => "verbose_json",
};
body.extend_from_slice(format!("--{}\r\n", boundary).as_bytes());
body.extend_from_slice(
b"Content-Disposition: form-data; name=\"response_format\"\r\n\r\n",
);
body.extend_from_slice(format_str.as_bytes());
body.extend_from_slice(b"\r\n");
}
if let Some(temp) = options.temperature {
body.extend_from_slice(format!("--{}\r\n", boundary).as_bytes());
body.extend_from_slice(b"Content-Disposition: form-data; name=\"temperature\"\r\n\r\n");
body.extend_from_slice(temp.to_string().as_bytes());
body.extend_from_slice(b"\r\n");
}
if let Some(ref granularities) = options.timestamp_granularities {
for granularity in granularities {
let g_str = match granularity {
TimestampGranularity::Word => "word",
TimestampGranularity::Segment => "segment",
};
body.extend_from_slice(format!("--{}\r\n", boundary).as_bytes());
body.extend_from_slice(
b"Content-Disposition: form-data; name=\"timestamp_granularities[]\"\r\n\r\n",
);
body.extend_from_slice(g_str.as_bytes());
body.extend_from_slice(b"\r\n");
}
}
body.extend_from_slice(format!("--{}--\r\n", boundary).as_bytes());
let response = agent
.post(&url)
.header("Authorization", &format!("Bearer {}", config.api_key))
.header("Content-Type", &content_type)
.send(&body[..])
.map_err(|e| self.handle_http_error(config.provider, e))?;
response
.into_body()
.read_to_string()
.map_err(|e| AiError::ParseError(e.to_string()))
}
}
/// Derives a quick pseudo-random value for multipart boundaries from the
/// current wall clock. Not cryptographic; uniqueness per request suffices.
fn rand_boundary() -> u64 {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_nanos() as u64);
    // XOR with an alternating bit pattern so the value isn't bare clock data.
    nanos ^ 0x5555_5555_5555_5555
}
/// One-shot convenience: builds a client from the environment and sends
/// `prompt` as a single user message to the default general model.
///
/// Returns `NoApiKey` when no provider could be configured from the
/// environment, and `ParseError` when the reply carries no content.
pub fn chat_simple(prompt: &str) -> AiResult<String> {
    let client = AiClient::from_env();
    if !client.has_providers() {
        return Err(AiError::NoApiKey);
    }
    let response = client.chat(Model::default_general(), vec![Message::user(prompt)])?;
    match response.content() {
        Some(text) => Ok(text.to_string()),
        None => Err(AiError::ParseError("No content in response".to_string())),
    }
}
// Offline unit tests only: construction, builder wiring, and the error path
// when no provider is configured. Network behavior is not exercised here.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_client_creation() {
// A fresh client starts with no providers registered.
let client = AiClient::new();
assert!(!client.has_providers());
}
#[test]
fn test_client_with_provider() {
// Builder calls chain and register the Groq provider.
let client = AiClient::new()
.with_provider(ProviderConfig::new(Provider::Groq, "test-key"))
.with_default_temperature(0.7);
assert!(client.has_providers());
assert!(client.providers().contains(&Provider::Groq));
}
#[test]
fn test_model_not_available() {
// With no providers configured, chat reports the model as unavailable
// rather than AllProvidersFailed (no attempt was ever made).
let client = AiClient::new();
let result = client.chat(Model::Llama3_3_70B, vec![Message::user("Hello")]);
assert!(matches!(result, Err(AiError::ModelNotAvailable(_))));
}
}