use anyhow::Result;
use async_trait::async_trait;
use super::types::{CompletionRequest, CompletionResponse};
/// Common interface for every model backend built by this module.
///
/// The factory in this module returns these as `Box<dyn ModelProvider>`. The `Send + Sync`
/// supertrait bound allows such a box to be moved to and shared between threads.
#[async_trait]
pub trait ModelProvider: Send + Sync {
    /// Name of this provider.
    fn name(&self) -> &str;

    /// Identifier of the model this provider targets.
    fn model_id(&self) -> &str;

    /// Reports whether this provider supports tool use.
    fn supports_tools(&self) -> bool;

    /// Runs a single completion request against the backend.
    ///
    /// # Errors
    ///
    /// Any failure reported by the implementation, wrapped in [`anyhow::Error`].
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
}
/// Supported model backends, selected by name (see `ProviderType::from_str_loose`).
///
/// This is a fieldless tag, so it also derives `Copy` (cheap to pass by value)
/// and `Eq` + `Hash` (usable as a `HashMap`/`HashSet` key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderType {
    /// Anthropic (Claude) cloud API; requires an API key.
    Anthropic,
    /// Groq cloud API; requires an API key.
    Groq,
    /// Google (Gemini) cloud API; requires an API key.
    Google,
    /// Ollama server (default URL `http://localhost:11434`).
    Ollama,
    /// LM Studio server (default URL `http://localhost:1234`).
    LmStudio,
    /// llama.cpp server (default URL `http://localhost:8080`).
    LlamaCpp,
    /// User-configured endpoint, driven through the local-model provider.
    Custom,
}
impl ProviderType {
    /// Parses a provider name leniently.
    ///
    /// Leading/trailing whitespace is ignored and matching is case-insensitive. Several
    /// aliases are accepted per provider, e.g. `"claude"` for [`ProviderType::Anthropic`],
    /// `"gemini"` / `"google_ai"` for [`ProviderType::Google`], and `"lm-studio"` /
    /// `"lm_studio"` for [`ProviderType::LmStudio`].
    ///
    /// Returns `None` if the name matches no known provider.
    pub fn from_str_loose(s: &str) -> Option<Self> {
        // Trim first so values read from env vars or config files that carry stray
        // whitespace or a trailing newline still resolve.
        match s.trim().to_lowercase().as_str() {
            "anthropic" | "claude" => Some(Self::Anthropic),
            "groq" => Some(Self::Groq),
            "google" | "gemini" | "google_ai" => Some(Self::Google),
            "ollama" => Some(Self::Ollama),
            "lmstudio" | "lm_studio" | "lm-studio" => Some(Self::LmStudio),
            "llama_cpp" | "llama-cpp" | "llamacpp" => Some(Self::LlamaCpp),
            "custom" => Some(Self::Custom),
            _ => None,
        }
    }

    /// Default API base URL for this provider.
    ///
    /// The result is a string literal, so it is `'static` and does not borrow `self`.
    pub fn default_base_url(&self) -> &'static str {
        match self {
            Self::Anthropic => "https://api.anthropic.com",
            Self::Groq => "https://api.groq.com/openai",
            Self::Google => "https://generativelanguage.googleapis.com",
            Self::Ollama => "http://localhost:11434",
            Self::LmStudio => "http://localhost:1234",
            Self::LlamaCpp => "http://localhost:8080",
            Self::Custom => "http://localhost:8000",
        }
    }

    /// Default model identifier for this provider.
    ///
    /// `create_provider` falls back to this when no model name is supplied. The result
    /// is a string literal, so it is `'static` and does not borrow `self`.
    pub fn default_model(&self) -> &'static str {
        match self {
            Self::Anthropic => "claude-sonnet-4-20250514",
            Self::Groq => "llama-3.3-70b-versatile",
            Self::Google => "gemini-2.5-flash",
            Self::Ollama => "llama3",
            Self::LmStudio => "local-model",
            Self::LlamaCpp => "local-model",
            Self::Custom => "custom-model",
        }
    }
}
/// Builds the [`ModelProvider`] selected by `provider_name`.
///
/// * `config` – application configuration supplying API keys and local server URLs.
/// * `provider_name` – provider name, parsed with [`ProviderType::from_str_loose`].
/// * `model_name` – model to use. Surrounding whitespace is trimmed. `None`, an empty
///   string, or a whitespace-only string falls back to [`ProviderType::default_model`].
///
/// # Errors
///
/// Returns an error if `provider_name` is not a recognised provider. Also returns an error
/// if a cloud provider (Anthropic, Groq, Google) is selected and its API key is missing
/// or blank.
pub fn create_provider(
    config: &crate::config::AppConfig,
    provider_name: &str,
    model_name: Option<String>,
) -> Result<Box<dyn ModelProvider>> {
    let provider_type = ProviderType::from_str_loose(provider_name)
        .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", provider_name))?;

    // A blank model name (e.g. from an empty CLI flag or env var) means "use the default".
    let model = model_name
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .unwrap_or_else(|| provider_type.default_model().to_string());

    // Cloud providers need an API key and return immediately. Local runtimes yield the
    // base URL and runtime, and are built by the single call below the match.
    // A blank key is treated as unset, so it produces the "not set" error instead of
    // being handed to the provider.
    let (base_url, runtime) = match provider_type {
        ProviderType::Anthropic => {
            let api_key = config
                .anthropic_api_key
                .clone()
                .filter(|key| !key.trim().is_empty())
                .ok_or_else(|| {
                    anyhow::anyhow!("Anthropic API key is not set. Please set the ANTHROPIC_API_KEY environment variable.")
                })?;
            return Ok(Box::new(super::anthropic::AnthropicProvider::new(
                api_key,
                Some(model),
                None,
            )));
        }
        ProviderType::Groq => {
            let api_key = config
                .groq_api_key
                .clone()
                .filter(|key| !key.trim().is_empty())
                .ok_or_else(|| {
                    anyhow::anyhow!("Groq API key is not set. Please set the GROQ_API_KEY environment variable.")
                })?;
            return Ok(Box::new(super::cloud::GroqProvider::new(
                api_key,
                Some(model),
                None,
            )));
        }
        ProviderType::Google => {
            let api_key = config
                .google_api_key
                .clone()
                .filter(|key| !key.trim().is_empty())
                .ok_or_else(|| {
                    anyhow::anyhow!("Google/Gemini API key is not set. Please set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable.")
                })?;
            return Ok(Box::new(super::cloud::GoogleProvider::new(
                api_key,
                Some(model),
                None,
            )));
        }
        ProviderType::Ollama => (
            config.ollama_url.clone(),
            super::local::LocalRuntime::Ollama,
        ),
        ProviderType::LmStudio => (
            config.lm_studio_url.clone(),
            super::local::LocalRuntime::LmStudio,
        ),
        ProviderType::LlamaCpp => (
            config.llama_cpp_url.clone(),
            super::local::LocalRuntime::LlamaCpp,
        ),
        ProviderType::Custom => (
            config.custom_api_url.clone(),
            super::local::LocalRuntime::Custom,
        ),
    };

    // NOTE(review): every local runtime passes `false` as the last argument, unchanged from
    // before; its meaning is defined by `LocalModelProvider::new` — confirm there.
    Ok(Box::new(super::local::LocalModelProvider::new(
        base_url, model, runtime, false,
    )))
}