use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
mod openai;
mod anthropic;
mod azure;
mod ollama;
mod google;
pub use openai::OpenAiProvider;
pub use anthropic::AnthropicProvider;
pub use azure::AzureOpenAiProvider;
pub use ollama::OllamaProvider;
pub use google::GoogleProvider;
/// Errors returned by provider construction and provider operations.
///
/// The `Display` text of each variant is taken from its `#[error(...)]`
/// attribute; variants with a `String` payload interpolate it as `{0}`.
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
/// Displays as `API error: <detail>`.
#[error("API error: {0}")]
Api(String),
/// Displays as `Rate limit exceeded`; carries no further detail.
#[error("Rate limit exceeded")]
RateLimit,
/// Displays as `Invalid configuration: <detail>`.
/// `ProviderRegistry::from_config` uses this variant for unknown provider names.
#[error("Invalid configuration: {0}")]
Config(String),
/// Displays as `Network error: <detail>`.
#[error("Network error: {0}")]
Network(String),
/// Displays as `Unsupported model: <detail>`.
#[error("Unsupported model: {0}")]
UnsupportedModel(String),
/// Displays as `Token limit exceeded: <detail>`.
#[error("Token limit exceeded: {0}")]
TokenLimit(String),
}
/// Shorthand for a `Result` whose error type is [`ProviderError`].
pub type ProviderResult<T> = Result<T, ProviderError>;
/// Role attached to a [`ChatMessage`].
///
/// Because of `rename_all = "lowercase"`, each variant is (de)serialized as
/// its lowercase name, e.g. `MessageRole::System` <-> `"system"`.
///
/// The enum is field-less, so equality is total; `Eq` is derived alongside
/// `PartialEq` so the role can be used wherever a full equivalence relation
/// is required.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"function"`.
    Function,
    /// Serialized as `"tool"`.
    Tool,
}
/// A single chat message exchanged with a provider.
///
/// Every `Option` field is left out of the serialized output when it is
/// `None` (`skip_serializing_if = "Option::is_none"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// Role associated with this message.
pub role: MessageRole,
/// Message text.
pub content: String,
/// Optional name attached to the message; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Optional function call carried by the message; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
/// Optional list of tool calls carried by the message; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// Optional tool-call identifier; omitted from output when `None`.
/// NOTE(review): presumably refers to a `ToolCall::id` — confirm against the providers.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
/// A function invocation: the function name plus its arguments as a string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
/// Name of the function being called.
pub name: String,
/// Arguments, kept as a raw string.
/// NOTE(review): presumably JSON-encoded text — confirm against the providers.
pub arguments: String,
}
/// A tool call: an identifier, a type tag and the wrapped [`FunctionCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this tool call.
pub id: String,
/// Type tag of the call; (de)serialized under the key `type`.
#[serde(rename = "type")]
pub call_type: String,
/// The function invocation carried by this tool call.
pub function: FunctionCall,
}
/// Declaration of a function: its name, optional description and parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
/// Name of the function.
pub name: String,
/// Optional human-readable description.
pub description: Option<String>,
/// Parameter description as an arbitrary JSON value.
/// NOTE(review): presumably a JSON Schema object — confirm against the providers.
pub parameters: serde_json::Value,
}
/// Declaration of a tool: a type tag plus the wrapped [`FunctionDef`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDef {
/// Type tag of the tool; (de)serialized under the key `type`.
#[serde(rename = "type")]
pub tool_type: String,
/// The function this tool declares.
pub function: FunctionDef,
}
/// Parameters of a chat request passed to `LlmProvider::chat` and
/// `LlmProvider::chat_stream`.
///
/// Derives `Default`: the default request has no messages, every optional
/// field set to `None`, and `stream` set to `false`. Unlike the message
/// types, this struct does not derive `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Default)]
pub struct LlmRequest {
/// Messages making up the conversation.
pub messages: Vec<ChatMessage>,
/// Optional model identifier.
pub model: Option<String>,
/// Optional token cap.
/// NOTE(review): presumably limits generated tokens — confirm per provider.
pub max_tokens: Option<usize>,
/// Optional sampling temperature.
pub temperature: Option<f32>,
/// Optional `top_p` sampling parameter.
pub top_p: Option<f32>,
/// Optional list of stop strings.
pub stop: Option<Vec<String>>,
/// Optional tool declarations offered with the request.
pub tools: Option<Vec<ToolDef>>,
/// Optional tool-choice directive as an arbitrary JSON value.
pub tool_choice: Option<serde_json::Value>,
/// Streaming flag.
/// NOTE(review): how this interacts with `chat` vs `chat_stream` is not visible here.
pub stream: bool,
}
/// Result of a non-streaming chat call (`LlmProvider::chat`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
/// Identifier of the response.
pub id: String,
/// Model identifier reported with the response.
pub model: String,
/// The returned message.
pub message: ChatMessage,
/// Optional finish reason string.
pub finish_reason: Option<String>,
/// Optional token accounting for the call.
pub usage: Option<TokenUsage>,
}
/// Token counters attached to an [`LlmResponse`].
///
/// NOTE(review): nothing here enforces `total_tokens == prompt_tokens +
/// completion_tokens`; the three values are stored independently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
/// Token count for the prompt.
pub prompt_tokens: usize,
/// Token count for the completion.
pub completion_tokens: usize,
/// Total token count.
pub total_tokens: usize,
}
/// One item yielded by the stream returned from `LlmProvider::chat_stream`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
/// Identifier carried by the chunk.
pub id: String,
/// Incremental message content carried by this chunk.
pub delta: ChatDelta,
/// Optional finish reason string.
pub finish_reason: Option<String>,
}
/// Partial message carried by a [`StreamChunk`]; every field is optional.
///
/// Unlike [`ChatMessage`], the fields have no `skip_serializing_if`, so
/// `None` values are serialized rather than omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatDelta {
/// Role, when present in this delta.
pub role: Option<MessageRole>,
/// Text fragment, when present in this delta.
pub content: Option<String>,
/// Partial function call, when present in this delta.
pub function_call: Option<FunctionCallDelta>,
/// Partial tool calls, when present in this delta.
pub tool_calls: Option<Vec<ToolCallDelta>>,
}
/// Partial [`FunctionCall`] in which both fields are optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
/// Function name, when present in this delta.
pub name: Option<String>,
/// Arguments text, when present in this delta.
/// NOTE(review): presumably a fragment to be concatenated by the consumer — confirm.
pub arguments: Option<String>,
}
/// Partial [`ToolCall`]: a required index plus optional pieces of the call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
/// Index of the tool call.
/// NOTE(review): presumably the position within the message's tool-call list — confirm.
pub index: usize,
/// Tool-call identifier, when present in this delta.
pub id: Option<String>,
/// Type tag, when present; (de)serialized under the key `type`.
#[serde(rename = "type")]
pub call_type: Option<String>,
/// Partial function call, when present in this delta.
pub function: Option<FunctionCallDelta>,
}
/// Configuration from which a provider is built
/// (see `ProviderRegistry::from_config`).
///
/// Only `provider` is required; how each optional field is interpreted is up
/// to the individual provider constructors, which are not shown in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmProviderConfig {
/// Provider kind; `ProviderRegistry::from_config` recognizes `openai`,
/// `anthropic`, `azure`, `ollama`, `google` and `gemini`.
pub provider: String,
/// Optional API key.
pub api_key: Option<String>,
/// Optional endpoint.
pub endpoint: Option<String>,
/// Optional model identifier.
pub model: Option<String>,
/// Optional organization.
pub organization: Option<String>,
/// Optional deployment name.
/// NOTE(review): presumably Azure-specific — confirm in the Azure provider.
pub deployment: Option<String>,
/// Optional API version string.
pub api_version: Option<String>,
/// Optional timeout; the field name indicates milliseconds.
pub timeout_ms: Option<u64>,
/// Optional retry limit.
pub max_retries: Option<usize>,
/// Optional map of header names to values.
pub headers: Option<HashMap<String, String>>,
}
/// Common interface implemented by every LLM backend.
///
/// Implementors must be `Send + Sync`; the registry stores them as
/// `Arc<dyn LlmProvider>`. Async methods are provided through `async_trait`.
#[async_trait]
pub trait LlmProvider: Send + Sync {
/// Returns the provider's name.
fn name(&self) -> &str;
/// Asynchronously lists models as [`ModelInfo`] records.
async fn list_models(&self) -> ProviderResult<Vec<ModelInfo>>;
/// Sends `request` and resolves to the complete [`LlmResponse`].
async fn chat(&self, request: LlmRequest) -> ProviderResult<LlmResponse>;
/// Sends `request` and resolves to a boxed stream whose items are
/// `ProviderResult<StreamChunk>`; the stream is `Send + Unpin`.
async fn chat_stream(
&self,
request: LlmRequest,
) -> ProviderResult<Box<dyn futures::Stream<Item = ProviderResult<StreamChunk>> + Send + Unpin>>;
/// Returns a token count for `text` with respect to `model`.
/// The counting method is left to the implementation.
fn count_tokens(&self, text: &str, model: &str) -> ProviderResult<usize>;
/// Reports whether this provider supports `model`.
fn supports_model(&self, model: &str) -> bool;
/// Returns metadata for `model`, or `None` when none is available.
fn model_info(&self, model: &str) -> Option<ModelInfo>;
}
/// Metadata describing a model, as returned by `LlmProvider::list_models`
/// and `LlmProvider::model_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
/// Model identifier.
pub id: String,
/// Model name.
pub name: String,
/// Name of the provider offering the model.
pub provider: String,
/// Context length.
/// NOTE(review): presumably measured in tokens — confirm.
pub context_length: usize,
/// Whether the model supports functions.
pub supports_functions: bool,
/// Whether the model supports vision.
pub supports_vision: bool,
/// Optional input cost.
/// NOTE(review): presumably per 1,000 tokens; currency is not specified here.
pub input_cost_per_1k: Option<f64>,
/// Optional output cost.
/// NOTE(review): presumably per 1,000 tokens; currency is not specified here.
pub output_cost_per_1k: Option<f64>,
}
/// Name-keyed collection of shared providers with an optional default.
pub struct ProviderRegistry {
// Registered providers, keyed by the name passed to `register`.
providers: HashMap<String, Arc<dyn LlmProvider>>,
// Name of the default provider, if one has been selected via `set_default`.
default_provider: Option<String>,
}
impl ProviderRegistry {
    /// Creates an empty registry: no providers and no default.
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
            default_provider: None,
        }
    }

    /// Registers `provider` under `name`.
    ///
    /// A provider already stored under the same name is replaced.
    pub fn register(&mut self, name: &str, provider: Arc<dyn LlmProvider>) {
        self.providers.insert(name.to_string(), provider);
    }

    /// Marks the provider registered as `name` as the default.
    ///
    /// Returns `true` on success. Returns `false`, leaving the current default
    /// untouched, when no provider is registered under `name`.
    pub fn set_default(&mut self, name: &str) -> bool {
        let known = self.providers.contains_key(name);
        if known {
            self.default_provider = Some(name.to_string());
        }
        known
    }

    /// Returns a shared handle to the provider registered as `name`, if any.
    pub fn get(&self, name: &str) -> Option<Arc<dyn LlmProvider>> {
        self.providers.get(name).cloned()
    }

    /// Returns a shared handle to the default provider.
    ///
    /// Yields `None` when no default has been set.
    pub fn get_default(&self) -> Option<Arc<dyn LlmProvider>> {
        let name = self.default_provider.as_deref()?;
        self.get(name)
    }

    /// Returns the names of all registered providers in ascending order.
    ///
    /// `HashMap` iteration order is unspecified, so the names are sorted to
    /// give callers (logs, UIs, tests) a deterministic result.
    pub fn list(&self) -> Vec<String> {
        let mut names: Vec<String> = self.providers.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Builds a provider from `config`, selected by `config.provider`.
    ///
    /// The provider kind is matched after trimming surrounding whitespace and
    /// ASCII-lowercasing, so `"OpenAI"` and `" openai "` both select the
    /// OpenAI provider. Recognized kinds: `openai`, `anthropic`, `azure`,
    /// `ollama`, and `google` / `gemini`.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::Config`] for an unrecognized kind (the message
    /// echoes `config.provider` exactly as given), and propagates any error
    /// returned by the selected provider's constructor.
    pub fn from_config(config: &LlmProviderConfig) -> ProviderResult<Arc<dyn LlmProvider>> {
        // Normalize only for matching; the original string is kept for the
        // error message so users see exactly what they configured.
        let kind = config.provider.trim().to_ascii_lowercase();
        match kind.as_str() {
            "openai" => Ok(Arc::new(OpenAiProvider::new(config)?)),
            "anthropic" => Ok(Arc::new(AnthropicProvider::new(config)?)),
            "azure" => Ok(Arc::new(AzureOpenAiProvider::new(config)?)),
            "ollama" => Ok(Arc::new(OllamaProvider::new(config)?)),
            "google" | "gemini" => Ok(Arc::new(GoogleProvider::new(config)?)),
            _ => Err(ProviderError::Config(format!(
                "Unknown provider: {}",
                config.provider
            ))),
        }
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}