use async_trait::async_trait;
use std::collections::HashMap;
use crate::error::LlmError;
use crate::stream::ChatStream;
use crate::types::*;
/// Core chat interface implemented by every chat-capable provider.
///
/// Implementors supply [`chat_with_tools`](Self::chat_with_tools) and
/// [`chat_stream`](Self::chat_stream); [`chat`](Self::chat) is derived from the former.
#[async_trait]
pub trait ChatCapability: Send + Sync {
    /// Sends `messages` without offering any tools.
    ///
    /// The default implementation forwards to
    /// [`chat_with_tools`](Self::chat_with_tools) with `None` for the tool list.
    async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LlmError> {
        let no_tools: Option<Vec<Tool>> = None;
        self.chat_with_tools(messages, no_tools).await
    }

    /// Sends `messages`, optionally together with a list of `tools`, and returns
    /// the complete response.
    async fn chat_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatResponse, LlmError>;

    /// Same inputs as [`chat_with_tools`](Self::chat_with_tools), but yields a
    /// [`ChatStream`] instead of a finished response.
    async fn chat_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatStream, LlmError>;
}
/// Convenience helpers layered on top of [`ChatCapability`].
///
/// Every method has a default implementation expressed in terms of the
/// `ChatCapability` primitives, so implementors normally override nothing.
#[async_trait]
pub trait ChatExtensions: ChatCapability {
    /// Runs [`ChatCapability::chat_with_tools`] through `retry_api::retry_with`.
    ///
    /// The closure handed to `retry_with` clones `messages` and `tools` on each
    /// invocation so every attempt starts from the same inputs.
    async fn chat_with_retry(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        options: crate::retry_api::RetryOptions,
    ) -> Result<ChatResponse, LlmError> {
        use crate::retry_api;
        retry_api::retry_with(
            || {
                let attempt_messages = messages.clone();
                let attempt_tools = tools.clone();
                async move { self.chat_with_tools(attempt_messages, attempt_tools).await }
            },
            options,
        )
        .await
    }

    /// Returns stored conversation memory, if the implementor keeps any.
    ///
    /// The default implementation returns `Ok(None)`.
    async fn memory_contents(&self) -> Result<Option<Vec<ChatMessage>>, LlmError> {
        Ok(None)
    }

    /// Asks the model to summarize `messages` in 2-3 sentences.
    ///
    /// Each message is rendered as `"<role:?>: <text>"` (empty text when the
    /// message has no textual content) and the lines are sent as one user prompt.
    ///
    /// # Errors
    /// Propagates chat errors; returns `LlmError::InternalError` when the
    /// response contains no text.
    async fn summarize_history(&self, messages: Vec<ChatMessage>) -> Result<String, LlmError> {
        let prompt = format!(
            "Summarize in 2-3 sentences:\n{}",
            messages
                .iter()
                .map(|m| format!("{:?}: {}", m.role, m.content_text().unwrap_or("")))
                .collect::<Vec<_>>()
                .join("\n")
        );
        let request_messages = vec![ChatMessage::user(prompt).build()];
        let response = self.chat(request_messages).await?;
        extract_response_text(&response, "No text in summary response")
    }

    /// Sends `prompt` as a single user message and returns the response text.
    ///
    /// # Errors
    /// Propagates chat errors; returns `LlmError::InternalError` when the
    /// response contains no text.
    async fn ask(&self, prompt: String) -> Result<String, LlmError> {
        let message = ChatMessage::user(prompt).build();
        let response = self.chat(vec![message]).await?;
        extract_response_text(&response, "No text in response")
    }

    /// Like [`ask`](Self::ask), but issues the request through
    /// [`chat_with_retry`](Self::chat_with_retry) with the given `options`.
    async fn ask_with_retry(
        &self,
        prompt: String,
        options: crate::retry_api::RetryOptions,
    ) -> Result<String, LlmError> {
        let message = ChatMessage::user(prompt).build();
        let response = self.chat_with_retry(vec![message], None, options).await?;
        extract_response_text(&response, "No text in response")
    }

    /// Sends a system message followed by a user message and returns the
    /// response text.
    async fn ask_with_system(
        &self,
        system_prompt: String,
        user_prompt: String,
    ) -> Result<String, LlmError> {
        let messages = vec![
            ChatMessage::system(system_prompt).build(),
            ChatMessage::user(user_prompt).build(),
        ];
        let response = self.chat(messages).await?;
        extract_response_text(&response, "No text in response")
    }

    /// Appends `new_message` as a user turn, sends the whole conversation, and
    /// appends the reply as an assistant turn.
    ///
    /// Returns the reply text together with the updated conversation.
    async fn continue_conversation(
        &self,
        mut conversation: Vec<ChatMessage>,
        new_message: String,
    ) -> Result<(String, Vec<ChatMessage>), LlmError> {
        conversation.push(ChatMessage::user(new_message).build());
        // The conversation is cloned because it is still needed for the return value.
        let response = self.chat(conversation.clone()).await?;
        let response_text = extract_response_text(&response, "No text in response")?;
        conversation.push(ChatMessage::assistant(response_text.clone()).build());
        Ok((response_text, conversation))
    }

    /// Asks the model to translate `text` into `target_language`.
    async fn translate(&self, text: String, target_language: String) -> Result<String, LlmError> {
        let prompt = format!("Translate the following text to {target_language}: {text}");
        self.ask(prompt).await
    }

    /// Asks the model to explain `concept`, addressed "to `audience`" when one is
    /// given and "in simple terms" otherwise.
    async fn explain(&self, concept: String, audience: Option<String>) -> Result<String, LlmError> {
        let audience_str = audience
            .map(|a| format!(" to {a}"))
            .unwrap_or_else(|| " in simple terms".to_string());
        let prompt = format!("Explain {concept}{audience_str}");
        self.ask(prompt).await
    }

    /// Asks the model to generate a `content_type` from `prompt`, using a fixed
    /// "creative writer" system prompt.
    async fn generate(&self, content_type: String, prompt: String) -> Result<String, LlmError> {
        let system_prompt = format!(
            "You are a creative writer. Generate a {content_type} based on the user's prompt."
        );
        self.ask_with_system(system_prompt, prompt).await
    }
}

/// Returns the text of `response` as an owned `String`, or
/// `LlmError::InternalError(missing_text_msg)` when the response carries no text.
///
/// Shared by the `ChatExtensions` helpers so that the extraction and the error
/// shape are defined in exactly one place.
fn extract_response_text(
    response: &ChatResponse,
    missing_text_msg: &str,
) -> Result<String, LlmError> {
    response
        .content_text()
        .map(std::string::ToString::to_string)
        .ok_or_else(|| LlmError::InternalError(missing_text_msg.to_string()))
}
/// Blanket implementation: every sized `ChatCapability` type gets the
/// `ChatExtensions` helpers with their default bodies.
impl<T> ChatExtensions for T where T: ChatCapability {}
/// Audio operations a provider may offer.
///
/// Only [`supported_features`](Self::supported_features) is required. Every other
/// operation has a default: the primitive operations fail with
/// `LlmError::UnsupportedOperation`, and the convenience wrappers are built on
/// top of those primitives.
#[async_trait]
pub trait AudioCapability: Send + Sync {
    /// Returns the audio features this provider reports as supported.
    fn supported_features(&self) -> &[AudioFeature];

    /// Text-to-speech. Default: always `UnsupportedOperation`.
    async fn text_to_speech(&self, _request: TtsRequest) -> Result<TtsResponse, LlmError> {
        let reason = String::from("Text-to-speech not supported by this provider");
        Err(LlmError::UnsupportedOperation(reason))
    }

    /// Streaming text-to-speech. Default: always `UnsupportedOperation`.
    async fn text_to_speech_stream(&self, _request: TtsRequest) -> Result<AudioStream, LlmError> {
        let reason = String::from("Streaming text-to-speech not supported by this provider");
        Err(LlmError::UnsupportedOperation(reason))
    }

    /// Speech-to-text. Default: always `UnsupportedOperation`.
    async fn speech_to_text(&self, _request: SttRequest) -> Result<SttResponse, LlmError> {
        let reason = String::from("Speech-to-text not supported by this provider");
        Err(LlmError::UnsupportedOperation(reason))
    }

    /// Audio translation. Default: always `UnsupportedOperation`.
    async fn translate_audio(
        &self,
        _request: AudioTranslationRequest,
    ) -> Result<SttResponse, LlmError> {
        let reason = String::from("Audio translation not supported by this provider");
        Err(LlmError::UnsupportedOperation(reason))
    }

    /// Lists available voices. Default: always `UnsupportedOperation`.
    async fn get_voices(&self) -> Result<Vec<VoiceInfo>, LlmError> {
        let reason = String::from("Voice listing not supported by this provider");
        Err(LlmError::UnsupportedOperation(reason))
    }

    /// Lists supported languages. Default: always `UnsupportedOperation`.
    async fn get_supported_languages(&self) -> Result<Vec<LanguageInfo>, LlmError> {
        let reason = String::from("Language listing not supported by this provider");
        Err(LlmError::UnsupportedOperation(reason))
    }

    /// Audio format names. Default: `mp3`, `wav`, `ogg` (in that order).
    fn get_supported_audio_formats(&self) -> Vec<String> {
        ["mp3", "wav", "ogg"]
            .iter()
            .copied()
            .map(String::from)
            .collect()
    }

    /// Builds a `TtsRequest` from `text`, runs
    /// [`text_to_speech`](Self::text_to_speech), and returns the audio bytes.
    async fn speech(&self, text: String) -> Result<Vec<u8>, LlmError> {
        self.text_to_speech(TtsRequest::new(text))
            .await
            .map(|tts| tts.audio_data)
    }

    /// Transcribes in-memory `audio` via [`speech_to_text`](Self::speech_to_text).
    async fn transcribe(&self, audio: Vec<u8>) -> Result<String, LlmError> {
        self.speech_to_text(SttRequest::from_audio(audio))
            .await
            .map(|stt| stt.text)
    }

    /// Transcribes the audio at `file_path` via
    /// [`speech_to_text`](Self::speech_to_text).
    async fn transcribe_file(&self, file_path: String) -> Result<String, LlmError> {
        self.speech_to_text(SttRequest::from_file(file_path))
            .await
            .map(|stt| stt.text)
    }

    /// Translates in-memory `audio` via [`translate_audio`](Self::translate_audio)
    /// and returns the resulting text.
    async fn translate(&self, audio: Vec<u8>) -> Result<String, LlmError> {
        self.translate_audio(AudioTranslationRequest::from_audio(audio))
            .await
            .map(|translated| translated.text)
    }

    /// Translates the audio at `file_path` via
    /// [`translate_audio`](Self::translate_audio) and returns the resulting text.
    async fn translate_file(&self, file_path: String) -> Result<String, LlmError> {
        self.translate_audio(AudioTranslationRequest::from_file(file_path))
            .await
            .map(|translated| translated.text)
    }
}
/// Image analysis and image generation operations.
#[async_trait]
pub trait VisionCapability: Send + Sync {
    /// Processes a `VisionRequest` and returns the provider's `VisionResponse`.
    async fn analyze_image(&self, request: VisionRequest) -> Result<VisionResponse, LlmError>;

    /// Processes an `ImageGenRequest` and returns the provider's `ImageResponse`.
    async fn generate_image(&self, request: ImageGenRequest) -> Result<ImageResponse, LlmError>;

    /// Accepted input format names. Default: `jpeg`, `png`, `webp`.
    fn get_supported_input_formats(&self) -> Vec<String> {
        ["jpeg", "png", "webp"]
            .iter()
            .copied()
            .map(String::from)
            .collect()
    }

    /// Produced output format names. Default: `png`, `jpeg`.
    fn get_supported_output_formats(&self) -> Vec<String> {
        ["png", "jpeg"].iter().copied().map(String::from).collect()
    }
}
/// Text embedding interface.
#[async_trait]
pub trait EmbeddingCapability: Send + Sync {
    /// Embeds each string in `input` and returns the provider's `EmbeddingResponse`.
    async fn embed(&self, input: Vec<String>) -> Result<EmbeddingResponse, LlmError>;

    /// Dimension of the embedding vectors, as reported by the implementor.
    fn embedding_dimension(&self) -> usize;

    /// Maximum tokens per embedded input. Default: 8192.
    fn max_tokens_per_embedding(&self) -> usize {
        8_192
    }

    /// Identifiers of the supported embedding models. Default: `["default"]`.
    fn supported_embedding_models(&self) -> Vec<String> {
        vec![String::from("default")]
    }
}
/// Optional helpers layered on top of [`EmbeddingCapability`].
///
/// Every method has a default body expressed through the base trait.
#[async_trait]
pub trait EmbeddingExtensions: EmbeddingCapability {
    /// Embeds using a full `EmbeddingRequest`.
    ///
    /// The default implementation forwards only `request.input` to
    /// [`EmbeddingCapability::embed`]; any other request fields are ignored
    /// unless an implementor overrides this method.
    async fn embed_with_config(
        &self,
        request: EmbeddingRequest,
    ) -> Result<EmbeddingResponse, LlmError> {
        self.embed(request.input).await
    }

    /// Runs each request of the batch sequentially through
    /// [`embed_with_config`](Self::embed_with_config).
    ///
    /// Per-request failures are recorded as their error's string form rather
    /// than aborting the call. When `batch_options.fail_fast` is set, processing
    /// stops after the first failure (which is still included in the output).
    /// The returned metadata map is empty.
    async fn embed_batch(
        &self,
        requests: BatchEmbeddingRequest,
    ) -> Result<BatchEmbeddingResponse, LlmError> {
        let fail_fast = requests.batch_options.fail_fast;
        let mut responses = Vec::new();
        for request in requests.requests {
            let result = self
                .embed_with_config(request)
                .await
                .map_err(|e| e.to_string());
            // Record the outcome before the push so no unwrap on `last()` is needed.
            let failed = result.is_err();
            responses.push(result);
            if fail_fast && failed {
                break;
            }
        }
        Ok(BatchEmbeddingResponse {
            responses,
            metadata: HashMap::new(),
        })
    }

    /// Builds an `EmbeddingModelInfo` for every entry of
    /// [`EmbeddingCapability::supported_embedding_models`].
    ///
    /// The model identifier is passed as both of the first two constructor
    /// arguments, followed by the provider-wide dimension and token limit.
    async fn list_embedding_models(&self) -> Result<Vec<EmbeddingModelInfo>, LlmError> {
        let models = self.supported_embedding_models();
        let model_infos = models
            .into_iter()
            .map(|id| {
                EmbeddingModelInfo::new(
                    id.clone(),
                    id,
                    self.embedding_dimension(),
                    self.max_tokens_per_embedding(),
                )
            })
            .collect();
        Ok(model_infos)
    }

    /// Cosine similarity of two embedding vectors.
    ///
    /// # Errors
    /// Returns `LlmError::InvalidInput` when the slices differ in length or when
    /// either vector has zero norm (which includes empty slices).
    fn calculate_similarity(
        &self,
        embedding1: &[f32],
        embedding2: &[f32],
    ) -> Result<f32, LlmError> {
        if embedding1.len() != embedding2.len() {
            return Err(LlmError::InvalidInput(
                "Embedding vectors must have the same dimension".to_string(),
            ));
        }
        let dot_product: f32 = embedding1
            .iter()
            .zip(embedding2.iter())
            .map(|(a, b)| a * b)
            .sum();
        let norm1: f32 = embedding1.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm2: f32 = embedding2.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm1 == 0.0 || norm2 == 0.0 {
            return Err(LlmError::InvalidInput(
                "Cannot calculate similarity for zero vectors".to_string(),
            ));
        }
        // f32 rounding can push the quotient marginally outside [-1, 1] (e.g. for
        // identical vectors), which breaks callers that feed it to `acos`.
        Ok((dot_product / (norm1 * norm2)).clamp(-1.0, 1.0))
    }
}
/// Image generation interface with optional editing and variation support.
#[async_trait]
pub trait ImageGenerationCapability: Send + Sync {
    /// Generates images for `request`.
    async fn generate_images(
        &self,
        request: ImageGenerationRequest,
    ) -> Result<ImageGenerationResponse, LlmError>;

    /// Image editing. Default: always `UnsupportedOperation`.
    async fn edit_image(
        &self,
        _request: ImageEditRequest,
    ) -> Result<ImageGenerationResponse, LlmError> {
        let reason = String::from("Image editing not supported by this provider");
        Err(LlmError::UnsupportedOperation(reason))
    }

    /// Image variations. Default: always `UnsupportedOperation`.
    async fn create_variation(
        &self,
        _request: ImageVariationRequest,
    ) -> Result<ImageGenerationResponse, LlmError> {
        let reason = String::from("Image variations not supported by this provider");
        Err(LlmError::UnsupportedOperation(reason))
    }

    /// Supported image sizes, as reported by the implementor.
    fn get_supported_sizes(&self) -> Vec<String>;

    /// Supported image formats, as reported by the implementor.
    fn get_supported_formats(&self) -> Vec<String>;

    /// Whether [`edit_image`](Self::edit_image) is supported. Default: `false`.
    fn supports_image_editing(&self) -> bool {
        false
    }

    /// Whether [`create_variation`](Self::create_variation) is supported.
    /// Default: `false`.
    fn supports_image_variations(&self) -> bool {
        false
    }

    /// Convenience wrapper around [`generate_images`](Self::generate_images).
    ///
    /// Builds a request from `prompt`, `size`, and `count` (1 when `None`), with
    /// all remaining fields at their defaults, and returns the URLs of the
    /// generated images. Images without a URL are skipped.
    async fn generate_image(
        &self,
        prompt: String,
        size: Option<String>,
        count: Option<u32>,
    ) -> Result<Vec<String>, LlmError> {
        let image_count = match count {
            Some(requested) => requested,
            None => 1,
        };
        let request = ImageGenerationRequest {
            prompt,
            size,
            count: image_count,
            ..Default::default()
        };
        let generated = self.generate_images(request).await?;

        let mut urls = Vec::new();
        for image in generated.images {
            if let Some(url) = image.url {
                urls.push(url);
            }
        }
        Ok(urls)
    }
}
/// File-management operations a provider may expose.
///
/// All methods are required; the trait supplies no default behavior, so the
/// exact semantics of each operation are defined by the implementor.
#[async_trait]
pub trait FileManagementCapability: Send + Sync {
/// Takes a `FileUploadRequest` and returns the resulting `FileObject`.
async fn upload_file(&self, request: FileUploadRequest) -> Result<FileObject, LlmError>;
/// Returns a `FileListResponse`; `query` is optional.
async fn list_files(&self, query: Option<FileListQuery>) -> Result<FileListResponse, LlmError>;
/// Returns the `FileObject` identified by `file_id`.
async fn retrieve_file(&self, file_id: String) -> Result<FileObject, LlmError>;
/// Deletes the file identified by `file_id` and returns a `FileDeleteResponse`.
async fn delete_file(&self, file_id: String) -> Result<FileDeleteResponse, LlmError>;
/// Returns the raw bytes of the file identified by `file_id`.
async fn get_file_content(&self, file_id: String) -> Result<Vec<u8>, LlmError>;
}
/// Content moderation interface.
#[async_trait]
pub trait ModerationCapability: Send + Sync {
    /// Processes a `ModerationRequest` and returns the provider's `ModerationResponse`.
    async fn moderate(&self, request: ModerationRequest) -> Result<ModerationResponse, LlmError>;

    /// Names of the moderation categories. The default list is fixed and is
    /// returned in the order written below.
    fn supported_categories(&self) -> Vec<String> {
        const DEFAULT_CATEGORIES: [&str; 11] = [
            "hate",
            "hate/threatening",
            "harassment",
            "harassment/threatening",
            "self-harm",
            "self-harm/intent",
            "self-harm/instructions",
            "sexual",
            "sexual/minors",
            "violence",
            "violence/graphic",
        ];
        DEFAULT_CATEGORIES
            .iter()
            .copied()
            .map(String::from)
            .collect()
    }
}
/// Model discovery interface.
#[async_trait]
pub trait ModelListingCapability: Send + Sync {
    /// Returns information about the models known to the provider.
    async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError>;

    /// Returns information about the model identified by `model_id`.
    async fn get_model(&self, model_id: String) -> Result<ModelInfo, LlmError>;

    /// Reports whether `model_id` can be looked up.
    ///
    /// Built on [`get_model`](Self::get_model): success maps to `true`,
    /// `LlmError::NotFound` maps to `false`, and any other error is returned as is.
    async fn is_model_available(&self, model_id: String) -> Result<bool, LlmError> {
        self.get_model(model_id)
            .await
            .map(|_| true)
            .or_else(|err| match err {
                LlmError::NotFound(_) => Ok(false),
                other => Err(other),
            })
    }
}
/// Text completion interface.
#[async_trait]
pub trait CompletionCapability: Send + Sync {
    /// Processes a `CompletionRequest` and returns the full `CompletionResponse`.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError>;

    /// Streaming completion. Default: always `UnsupportedOperation`.
    async fn complete_stream(
        &self,
        _request: CompletionRequest,
    ) -> Result<CompletionStream, LlmError> {
        let reason = String::from("Streaming completion not supported by this provider");
        Err(LlmError::UnsupportedOperation(reason))
    }
}
/// Adds deadline-bounded variants of the [`ChatCapability`] calls, implemented
/// with `tokio::time::timeout`.
#[async_trait]
pub trait TimeoutCapability: ChatCapability + Send + Sync {
    /// Runs [`ChatCapability::chat_with_tools`] under `timeout`.
    ///
    /// # Errors
    /// Returns the underlying call's error unchanged, or `LlmError::TimeoutError`
    /// when the deadline elapses first.
    async fn chat_with_timeout(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        timeout: std::time::Duration,
    ) -> Result<ChatResponse, LlmError> {
        let pending = self.chat_with_tools(messages, tools);
        match tokio::time::timeout(timeout, pending).await {
            Ok(outcome) => outcome,
            Err(_elapsed) => Err(LlmError::TimeoutError(format!(
                "Operation timed out after {timeout:?} (including retries)"
            ))),
        }
    }

    /// Runs [`ChatCapability::chat_stream`] under `timeout`.
    ///
    /// The deadline covers only the call that produces the `ChatStream`; this
    /// method applies no timeout to consuming the stream afterwards.
    ///
    /// # Errors
    /// Returns the underlying call's error unchanged, or `LlmError::TimeoutError`
    /// when the deadline elapses first.
    async fn chat_stream_with_timeout(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        timeout: std::time::Duration,
    ) -> Result<ChatStream, LlmError> {
        let pending = self.chat_stream(messages, tools);
        match tokio::time::timeout(timeout, pending).await {
            Ok(outcome) => outcome,
            Err(_elapsed) => Err(LlmError::TimeoutError(format!(
                "Stream initialization timed out after {timeout:?}"
            ))),
        }
    }
}
/// Blanket implementation: every sized `ChatCapability + Send + Sync` type gets
/// the timeout helpers with their default bodies.
impl<T: ChatCapability + Send + Sync> TimeoutCapability for T {}
/// Provider-specific chat operations; judging by the name these target OpenAI
/// providers. All methods are required and have no default behavior here.
#[async_trait]
pub trait OpenAiCapability: Send + Sync {
/// Sends `messages` together with a `JsonSchema` and returns a
/// `StructuredResponse`.
// NOTE(review): presumably the reply is constrained to `schema` — confirm with implementors.
async fn chat_with_structured_output(
&self,
messages: Vec<ChatMessage>,
schema: JsonSchema,
) -> Result<StructuredResponse, LlmError>;
/// Submits a list of `BatchRequest`s and returns a single `BatchResponse`.
async fn create_batch(&self, requests: Vec<BatchRequest>) -> Result<BatchResponse, LlmError>;
/// Chat variant that accepts an optional list of `OpenAiBuiltInTool`s instead
/// of the generic tool type.
async fn chat_with_responses_api(
&self,
messages: Vec<ChatMessage>,
tools: Option<Vec<OpenAiBuiltInTool>>,
) -> Result<ChatResponse, LlmError>;
}
/// Provider-specific chat operations; judging by the name these target
/// Anthropic providers. All methods are required and have no default behavior here.
#[async_trait]
pub trait AnthropicCapability: Send + Sync {
/// Sends a `ChatRequest` together with a `CacheConfig` and returns a
/// `ChatResponse`.
async fn chat_with_cache(
&self,
request: ChatRequest,
cache_config: CacheConfig,
) -> Result<ChatResponse, LlmError>;
/// Sends a `ChatRequest` and returns a `ThinkingResponse`.
async fn chat_with_thinking(&self, request: ChatRequest) -> Result<ThinkingResponse, LlmError>;
}
/// Embedding operations with extra parameters, layered on
/// [`EmbeddingCapability`]; judging by the name these target OpenAI providers.
#[async_trait]
pub trait OpenAiEmbeddingCapability: EmbeddingCapability {
/// Embeds `input` with an explicit `dimensions` value.
// NOTE(review): presumably `dimensions` is the requested output vector size — confirm.
async fn embed_with_dimensions(
&self,
input: Vec<String>,
dimensions: u32,
) -> Result<EmbeddingResponse, LlmError>;
/// Embeds `input` using the given `EmbeddingFormat`.
async fn embed_with_format(
&self,
input: Vec<String>,
format: EmbeddingFormat,
) -> Result<EmbeddingResponse, LlmError>;
}
/// Reranking interface.
#[async_trait]
pub trait RerankCapability: Send + Sync {
    /// Processes a `RerankRequest` and returns the provider's `RerankResponse`.
    async fn rerank(&self, request: RerankRequest) -> Result<RerankResponse, LlmError>;

    /// Upper bound on documents per request, if any. Default: `None`.
    fn max_documents(&self) -> Option<u32> {
        Option::None
    }

    /// Identifiers of supported rerank models. Default: an empty list.
    fn supported_models(&self) -> Vec<String> {
        Vec::new()
    }
}
/// Provider-specific operations; judging by the name these target Gemini
/// providers. All methods are required and have no default behavior here.
#[async_trait]
pub trait GeminiCapability: Send + Sync {
/// Sends a `ChatRequest` together with a `SearchConfig` and returns a
/// `ChatResponse`.
async fn chat_with_search(
&self,
request: ChatRequest,
search_config: SearchConfig,
) -> Result<ChatResponse, LlmError>;
/// Takes source `code` plus a `language` name and returns an
/// `ExecutionResponse`.
// NOTE(review): where and how the code is executed is implementor-defined — confirm before relying on it.
async fn execute_code(
&self,
code: String,
language: String,
) -> Result<ExecutionResponse, LlmError>;
}
/// Embedding operations with extra parameters, layered on
/// [`EmbeddingCapability`]; judging by the name these target Gemini providers.
#[async_trait]
pub trait GeminiEmbeddingCapability: EmbeddingCapability {
/// Embeds `input` with the given `EmbeddingTaskType`.
async fn embed_with_task_type(
&self,
input: Vec<String>,
task_type: EmbeddingTaskType,
) -> Result<EmbeddingResponse, LlmError>;
/// Embeds `input` with an accompanying `title`.
async fn embed_with_title(
&self,
input: Vec<String>,
title: String,
) -> Result<EmbeddingResponse, LlmError>;
/// Embeds `input` with an explicit `output_dimensionality` value.
// NOTE(review): presumably the requested output vector size — confirm.
async fn embed_with_output_dimensionality(
&self,
input: Vec<String>,
output_dimensionality: u32,
) -> Result<EmbeddingResponse, LlmError>;
}
/// Embedding operations with extra parameters, layered on
/// [`EmbeddingCapability`]; judging by the name these target Ollama providers.
#[async_trait]
pub trait OllamaEmbeddingCapability: EmbeddingCapability {
/// Embeds `input` using the named `model` and a map of JSON-valued `options`.
async fn embed_with_model_options(
&self,
input: Vec<String>,
model: String,
options: HashMap<String, serde_json::Value>,
) -> Result<EmbeddingResponse, LlmError>;
/// Embeds `input` with an explicit `truncate` flag.
async fn embed_with_truncation(
&self,
input: Vec<String>,
truncate: bool,
) -> Result<EmbeddingResponse, LlmError>;
/// Embeds `input` with a `keep_alive` setting passed as a string.
// NOTE(review): the accepted `keep_alive` format is not visible here — confirm with implementors.
async fn embed_with_keep_alive(
&self,
input: Vec<String>,
keep_alive: String,
) -> Result<EmbeddingResponse, LlmError>;
}
/// Identity and metadata common to every provider.
pub trait LlmProvider: Send + Sync {
/// Static name identifying the provider.
fn provider_name(&self) -> &'static str;
/// Identifiers of the models the provider reports as supported.
fn supported_models(&self) -> Vec<String>;
/// Capability flags describing what this provider supports.
fn capabilities(&self) -> ProviderCapabilities;
/// Borrowed `reqwest::Client` held by the provider.
fn http_client(&self) -> &reqwest::Client;
}
/// Set of feature flags describing what a provider supports.
///
/// All flags default to `false` and `custom_features` defaults to empty; use
/// the `with_*` builder methods to enable features.
#[derive(Debug, Clone, Default)]
pub struct ProviderCapabilities {
    /// Built-in flag queried as `"chat"`.
    pub chat: bool,
    /// Built-in flag queried as `"audio"`.
    pub audio: bool,
    /// Built-in flag queried as `"vision"`.
    pub vision: bool,
    /// Built-in flag queried as `"tools"`.
    pub tools: bool,
    /// Built-in flag queried as `"embedding"`.
    pub embedding: bool,
    /// Built-in flag queried as `"streaming"`.
    pub streaming: bool,
    /// Built-in flag queried as `"file_management"`.
    pub file_management: bool,
    /// Additional named flags, consulted by [`supports`](Self::supports) for
    /// any name that is not one of the built-in ones.
    pub custom_features: HashMap<String, bool>,
}

impl ProviderCapabilities {
    /// Creates a capability set with every feature disabled.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Enables the `chat` flag.
    pub const fn with_chat(mut self) -> Self {
        self.chat = true;
        self
    }

    /// Enables the `audio` flag.
    pub const fn with_audio(mut self) -> Self {
        self.audio = true;
        self
    }

    /// Enables the `vision` flag.
    pub const fn with_vision(mut self) -> Self {
        self.vision = true;
        self
    }

    /// Enables the `tools` flag.
    pub const fn with_tools(mut self) -> Self {
        self.tools = true;
        self
    }

    /// Enables the `embedding` flag.
    pub const fn with_embedding(mut self) -> Self {
        self.embedding = true;
        self
    }

    /// Enables the `streaming` flag.
    pub const fn with_streaming(mut self) -> Self {
        self.streaming = true;
        self
    }

    /// Enables the `file_management` flag.
    pub const fn with_file_management(mut self) -> Self {
        self.file_management = true;
        self
    }

    /// Records a custom feature under `name`, replacing any earlier entry with
    /// the same name.
    pub fn with_custom_feature(mut self, name: impl Into<String>, enabled: bool) -> Self {
        let key: String = name.into();
        self.custom_features.insert(key, enabled);
        self
    }

    /// Reports whether `feature` is enabled.
    ///
    /// Built-in names are answered from their dedicated fields and therefore
    /// take precedence over a custom feature of the same name. Any other name
    /// is looked up in `custom_features`; unknown names yield `false`.
    pub fn supports(&self, feature: &str) -> bool {
        let builtin = match feature {
            "chat" => Some(self.chat),
            "audio" => Some(self.audio),
            "vision" => Some(self.vision),
            "tools" => Some(self.tools),
            "embedding" => Some(self.embedding),
            "streaming" => Some(self.streaming),
            "file_management" => Some(self.file_management),
            _ => None,
        };
        match builtin {
            Some(flag) => flag,
            None => matches!(self.custom_features.get(feature), Some(&true)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods enable exactly the requested features.
    #[test]
    fn test_provider_capabilities() {
        let caps = ProviderCapabilities::new()
            .with_chat()
            .with_streaming()
            .with_custom_feature("custom_feature", true);
        assert!(caps.supports("chat"));
        assert!(caps.supports("streaming"));
        assert!(caps.supports("custom_feature"));
        assert!(!caps.supports("audio"));
    }

    /// Compile-time check that the capability trait objects can be shared
    /// across threads behind an `Arc`.
    #[test]
    fn test_capability_traits_are_send_sync() {
        use std::sync::Arc;

        // Merely naming `Arc<dyn Trait>` compiles whether or not the trait
        // object is thread-safe; this bound is what actually enforces it.
        fn assert_send_sync<T: Send + Sync>() {}

        assert_send_sync::<Arc<dyn ChatCapability>>();
        assert_send_sync::<Arc<dyn AudioCapability>>();
        assert_send_sync::<Arc<dyn VisionCapability>>();
        assert_send_sync::<Arc<dyn EmbeddingCapability>>();
        assert_send_sync::<Arc<dyn ImageGenerationCapability>>();
        assert_send_sync::<Arc<dyn FileManagementCapability>>();
        assert_send_sync::<Arc<dyn ModerationCapability>>();
        assert_send_sync::<Arc<dyn ModelListingCapability>>();
        assert_send_sync::<Arc<dyn CompletionCapability>>();
        assert_send_sync::<Arc<dyn OpenAiCapability>>();
        assert_send_sync::<Arc<dyn AnthropicCapability>>();
        assert_send_sync::<Arc<dyn GeminiCapability>>();
    }

    /// A single `Arc<dyn ChatCapability>` can be used concurrently from
    /// several spawned tasks.
    #[tokio::test]
    async fn test_capability_traits_multithreading() {
        use std::sync::Arc;
        use tokio::task;

        struct MockCapability;

        #[async_trait::async_trait]
        impl ChatCapability for MockCapability {
            async fn chat_with_tools(
                &self,
                _messages: Vec<ChatMessage>,
                _tools: Option<Vec<Tool>>,
            ) -> Result<ChatResponse, crate::error::LlmError> {
                Ok(ChatResponse {
                    id: Some("mock-id".to_string()),
                    content: MessageContent::Text("Mock response".to_string()),
                    model: Some("mock-model".to_string()),
                    usage: None,
                    finish_reason: Some(crate::types::FinishReason::Stop),
                    tool_calls: None,
                    thinking: None,
                    metadata: std::collections::HashMap::new(),
                })
            }

            async fn chat_stream(
                &self,
                _messages: Vec<ChatMessage>,
                _tools: Option<Vec<Tool>>,
            ) -> Result<crate::stream::ChatStream, crate::error::LlmError> {
                Err(crate::error::LlmError::UnsupportedOperation(
                    "Mock streaming not implemented".to_string(),
                ))
            }
        }

        let capability: Arc<dyn ChatCapability> = Arc::new(MockCapability);

        let mut handles = Vec::new();
        for i in 0..5 {
            let capability_clone = capability.clone();
            let handle = task::spawn(async move {
                let messages = vec![ChatMessage::user("Test message").build()];
                let result = capability_clone.chat_with_tools(messages, None).await;
                assert!(result.is_ok());
                // Return the task index so the join order can be verified below.
                i
            });
            handles.push(handle);
        }

        let mut results = Vec::new();
        for handle in handles {
            let result = handle.await.unwrap();
            results.push(result);
        }

        assert_eq!(results.len(), 5);
        for (i, result) in results.iter().enumerate() {
            assert_eq!(*result, i);
        }
    }
}