use std::collections::HashMap;
use crate::error::{Result, SaorsaAiError};
use crate::types::{CompletionRequest, CompletionResponse, StreamEvent};
/// Identifies which LLM backend a [`ProviderConfig`] targets.
///
/// Also used as the lookup key for factories in a [`ProviderRegistry`],
/// which is why it derives `Eq` and `Hash`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ProviderKind {
    /// Anthropic API (default base URL `https://api.anthropic.com`).
    Anthropic,
    /// OpenAI API (default base URL `https://api.openai.com`).
    OpenAi,
    /// Google Gemini API (default base URL includes the `/v1beta` path).
    Gemini,
    /// Ollama server (default base URL `http://localhost:11434`).
    Ollama,
    /// A server exposing an OpenAI-compatible API. Its default base URL is the
    /// empty string, so a real URL has to be supplied separately.
    OpenAiCompatible,
}

impl ProviderKind {
    /// Static per-kind metadata as `(default_base_url, display_name)`.
    ///
    /// Keeping both strings in one table means each variant is described in
    /// exactly one place.
    fn metadata(self) -> (&'static str, &'static str) {
        match self {
            Self::Anthropic => ("https://api.anthropic.com", "Anthropic"),
            Self::OpenAi => ("https://api.openai.com", "OpenAI"),
            Self::Gemini => (
                "https://generativelanguage.googleapis.com/v1beta",
                "Google Gemini",
            ),
            Self::Ollama => ("http://localhost:11434", "Ollama"),
            // Deliberately empty: there is no canonical host for compatible servers.
            Self::OpenAiCompatible => ("", "OpenAI-Compatible"),
        }
    }

    /// Returns the base URL a new config starts with for this kind.
    ///
    /// Returns `""` for [`ProviderKind::OpenAiCompatible`], which has no default host.
    #[must_use]
    pub fn default_base_url(self) -> &'static str {
        let (base_url, _) = self.metadata();
        base_url
    }

    /// Returns a human-readable name for this backend, e.g. `"Google Gemini"`.
    #[must_use]
    pub fn display_name(self) -> &'static str {
        let (_, name) = self.metadata();
        name
    }
}
/// Settings used to construct a provider: which backend, credentials, endpoint,
/// model and generation limit.
///
/// Build one with [`ProviderConfig::new`] and refine it with the `with_*`
/// builder methods. `Debug` is implemented by hand so the API key is never
/// printed (e.g. when a config ends up in logs).
#[derive(Clone)]
pub struct ProviderConfig {
    /// Which backend this configuration targets.
    pub kind: ProviderKind,
    /// Credential for the backend. May be empty; nothing here validates it.
    pub api_key: String,
    /// Root URL for requests; [`ProviderConfig::new`] initialises it from
    /// [`ProviderKind::default_base_url`].
    pub base_url: String,
    /// Model identifier for the backend.
    pub model: String,
    /// Token limit; [`ProviderConfig::new`] initialises it to
    /// [`ProviderConfig::DEFAULT_MAX_TOKENS`].
    pub max_tokens: u32,
}

impl ProviderConfig {
    /// `max_tokens` value applied by [`ProviderConfig::new`] until overridden
    /// with [`ProviderConfig::with_max_tokens`].
    pub const DEFAULT_MAX_TOKENS: u32 = 4096;

    /// Creates a config for `kind`, using that kind's default base URL and
    /// [`ProviderConfig::DEFAULT_MAX_TOKENS`].
    ///
    /// The default base URL of [`ProviderKind::OpenAiCompatible`] is empty, so
    /// such configs need [`ProviderConfig::with_base_url`].
    #[must_use]
    pub fn new(kind: ProviderKind, api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            base_url: kind.default_base_url().to_string(),
            kind,
            api_key: api_key.into(),
            model: model.into(),
            max_tokens: Self::DEFAULT_MAX_TOKENS,
        }
    }

    /// Replaces the base URL (e.g. for a proxy or a self-hosted endpoint).
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Replaces the token limit.
    #[must_use]
    pub fn with_max_tokens(mut self, max: u32) -> Self {
        self.max_tokens = max;
        self
    }
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the credential itself; only reveal whether one is set.
        let api_key = if self.api_key.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("ProviderConfig")
            .field("kind", &self.kind)
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .finish()
    }
}
/// A backend that answers a [`CompletionRequest`] with one complete
/// [`CompletionResponse`].
///
/// The `Send + Sync` supertraits allow implementations to be shared across
/// threads, e.g. behind the `Box<dyn StreamingProvider>` handed out by
/// `ProviderRegistry::create`.
#[async_trait::async_trait]
pub trait Provider: Send + Sync {
    /// Sends `request` to the backend and returns the full response.
    ///
    /// # Errors
    ///
    /// Returns an error if the completion fails; the exact conditions are
    /// implementation-defined.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
}
/// A [`Provider`] that can also deliver a response incrementally as
/// [`StreamEvent`]s.
#[async_trait::async_trait]
pub trait StreamingProvider: Provider {
    /// Starts a streaming completion for `request`.
    ///
    /// On success, returns the receiving half of a bounded `tokio` mpsc
    /// channel. Each item is itself a `Result`, so failures that occur
    /// mid-stream arrive in-band rather than through the outer `Result`.
    /// NOTE(review): presumably the channel closes when the stream ends —
    /// confirm against the implementations.
    ///
    /// # Errors
    ///
    /// The outer `Result` reports failures that prevent the stream from
    /// starting; the exact conditions are implementation-defined.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<tokio::sync::mpsc::Receiver<Result<StreamEvent>>>;
}
/// Boxed constructor stored in a [`ProviderRegistry`]: turns a
/// [`ProviderConfig`] into a boxed streaming provider, or fails with the crate
/// error type.
type ProviderFactory =
    Box<dyn Fn(ProviderConfig) -> Result<Box<dyn StreamingProvider>> + Send + Sync>;

/// Maps each [`ProviderKind`] to the factory that constructs providers of that kind.
///
/// [`ProviderRegistry::new`] starts empty; [`ProviderRegistry::default`]
/// registers a factory for every built-in kind.
pub struct ProviderRegistry {
    factories: HashMap<ProviderKind, ProviderFactory>,
}

impl std::fmt::Debug for ProviderRegistry {
    /// Lists the registered kinds. The factories themselves are opaque
    /// closures, so `Debug` cannot be derived and they are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderRegistry")
            .field("providers", &self.factories.keys().collect::<Vec<_>>())
            .finish()
    }
}
impl ProviderRegistry {
#[must_use]
pub fn new() -> Self {
Self {
factories: HashMap::new(),
}
}
pub fn register<F>(&mut self, kind: ProviderKind, factory: F)
where
F: Fn(ProviderConfig) -> Result<Box<dyn StreamingProvider>> + Send + Sync + 'static,
{
self.factories.insert(kind, Box::new(factory));
}
pub fn create(&self, config: ProviderConfig) -> Result<Box<dyn StreamingProvider>> {
let factory = self
.factories
.get(&config.kind)
.ok_or_else(|| SaorsaAiError::Provider {
provider: config.kind.display_name().to_string(),
message: "no factory registered for this provider".to_string(),
})?;
factory(config)
}
#[must_use]
pub fn has_provider(&self, kind: ProviderKind) -> bool {
self.factories.contains_key(&kind)
}
}
impl Default for ProviderRegistry {
fn default() -> Self {
let mut reg = Self::new();
reg.register(ProviderKind::Anthropic, |config| {
let provider = crate::anthropic::AnthropicProvider::new(config)?;
Ok(Box::new(provider))
});
reg.register(ProviderKind::OpenAi, |config| {
let provider = crate::openai::OpenAiProvider::new(config)?;
Ok(Box::new(provider))
});
reg.register(ProviderKind::Gemini, |config| {
let provider = crate::gemini::GeminiProvider::new(config)?;
Ok(Box::new(provider))
});
reg.register(ProviderKind::Ollama, |config| {
let provider = crate::ollama::OllamaProvider::new(config)?;
Ok(Box::new(provider))
});
reg.register(ProviderKind::OpenAiCompatible, |config| {
let provider = crate::openai_compat::OpenAiCompatProvider::new(config)?;
Ok(Box::new(provider))
});
reg
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn provider_kind_default_base_url() {
        assert_eq!(
            ProviderKind::Anthropic.default_base_url(),
            "https://api.anthropic.com"
        );
        assert_eq!(
            ProviderKind::OpenAi.default_base_url(),
            "https://api.openai.com"
        );
        assert_eq!(
            ProviderKind::Gemini.default_base_url(),
            "https://generativelanguage.googleapis.com/v1beta"
        );
        assert_eq!(
            ProviderKind::Ollama.default_base_url(),
            "http://localhost:11434"
        );
        assert_eq!(ProviderKind::OpenAiCompatible.default_base_url(), "");
    }

    #[test]
    fn provider_kind_display_name() {
        assert_eq!(ProviderKind::Anthropic.display_name(), "Anthropic");
        assert_eq!(ProviderKind::OpenAi.display_name(), "OpenAI");
        assert_eq!(ProviderKind::Gemini.display_name(), "Google Gemini");
        assert_eq!(ProviderKind::Ollama.display_name(), "Ollama");
        assert_eq!(
            ProviderKind::OpenAiCompatible.display_name(),
            "OpenAI-Compatible"
        );
    }

    #[test]
    fn provider_config_defaults_from_kind() {
        let config = ProviderConfig::new(
            ProviderKind::Anthropic,
            "sk-test",
            "claude-sonnet-4-5-20250929",
        );
        assert_eq!(config.base_url, "https://api.anthropic.com");
        assert_eq!(config.max_tokens, 4096);
        assert_eq!(config.kind, ProviderKind::Anthropic);
        let config = ProviderConfig::new(ProviderKind::OpenAi, "sk-test", "gpt-4o");
        assert_eq!(config.base_url, "https://api.openai.com");
        let config = ProviderConfig::new(ProviderKind::Ollama, "", "llama3");
        assert_eq!(config.base_url, "http://localhost:11434");
    }

    #[test]
    fn provider_config_custom_base_url() {
        let config = ProviderConfig::new(ProviderKind::Anthropic, "key", "model")
            .with_base_url("https://custom.api.com");
        assert_eq!(config.base_url, "https://custom.api.com");
    }

    #[test]
    fn provider_config_builder() {
        let config = ProviderConfig::new(
            ProviderKind::Anthropic,
            "sk-test",
            "claude-sonnet-4-5-20250929",
        )
        .with_base_url("https://custom.api.com")
        .with_max_tokens(8192);
        assert_eq!(config.api_key, "sk-test");
        assert_eq!(config.model, "claude-sonnet-4-5-20250929");
        assert_eq!(config.base_url, "https://custom.api.com");
        assert_eq!(config.max_tokens, 8192);
    }

    #[test]
    fn registry_has_provider() {
        let reg = ProviderRegistry::default();
        assert!(reg.has_provider(ProviderKind::Anthropic));
        assert!(reg.has_provider(ProviderKind::OpenAi));
        assert!(reg.has_provider(ProviderKind::Gemini));
        assert!(reg.has_provider(ProviderKind::Ollama));
        assert!(reg.has_provider(ProviderKind::OpenAiCompatible));
    }

    #[test]
    fn registry_create_anthropic() {
        let reg = ProviderRegistry::default();
        let config = ProviderConfig::new(
            ProviderKind::Anthropic,
            "sk-test",
            "claude-sonnet-4-5-20250929",
        );
        let result = reg.create(config);
        assert!(result.is_ok());
    }

    #[test]
    fn registry_create_openai() {
        let reg = ProviderRegistry::default();
        let config = ProviderConfig::new(ProviderKind::OpenAi, "sk-test", "gpt-4o");
        let result = reg.create(config);
        assert!(result.is_ok());
    }

    #[test]
    fn registry_create_gemini() {
        let reg = ProviderRegistry::default();
        let config = ProviderConfig::new(ProviderKind::Gemini, "test-key", "gemini-2.0-flash");
        let result = reg.create(config);
        assert!(result.is_ok());
    }

    #[test]
    fn registry_create_ollama() {
        let reg = ProviderRegistry::default();
        let config = ProviderConfig::new(ProviderKind::Ollama, "", "llama3");
        let result = reg.create(config);
        assert!(result.is_ok());
    }

    #[test]
    fn registry_create_unknown_returns_error() {
        let reg = ProviderRegistry::new();
        let config = ProviderConfig::new(ProviderKind::Anthropic, "key", "model");
        let result = reg.create(config);
        // Pin the specific failure: an empty registry must report the missing
        // factory and name the provider by its display name, rather than
        // failing for some unrelated reason.
        assert!(matches!(
            &result,
            Err(SaorsaAiError::Provider { provider, message })
                if provider == "Anthropic"
                    && message == "no factory registered for this provider"
        ));
    }

    #[test]
    fn registry_create_openai_compatible() {
        let reg = ProviderRegistry::default();
        let config = ProviderConfig::new(ProviderKind::OpenAiCompatible, "key", "model")
            .with_base_url("https://api.example.com");
        let result = reg.create(config);
        assert!(result.is_ok());
    }

    #[test]
    fn registry_custom_factory() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let mut reg = ProviderRegistry::new();
        reg.register(ProviderKind::Anthropic, move |config| {
            called_clone.store(true, Ordering::Relaxed);
            crate::anthropic::AnthropicProvider::new(config)
                .map(|p| Box::new(p) as Box<dyn StreamingProvider>)
        });
        let config = ProviderConfig::new(
            ProviderKind::Anthropic,
            "sk-test",
            "claude-sonnet-4-5-20250929",
        );
        let result = reg.create(config);
        assert!(result.is_ok());
        assert!(called.load(Ordering::Relaxed));
    }

    #[test]
    fn registry_register_replaces_existing_factory() {
        let mut reg = ProviderRegistry::new();
        reg.register(ProviderKind::Ollama, |_config| {
            Err(SaorsaAiError::Provider {
                provider: "first".to_string(),
                message: "first factory".to_string(),
            })
        });
        reg.register(ProviderKind::Ollama, |_config| {
            Err(SaorsaAiError::Provider {
                provider: "second".to_string(),
                message: "second factory".to_string(),
            })
        });
        let result = reg.create(ProviderConfig::new(ProviderKind::Ollama, "", "llama3"));
        // The later registration wins, and the factory's own error is
        // propagated unchanged by `create`.
        assert!(matches!(
            &result,
            Err(SaorsaAiError::Provider { provider, .. }) if provider == "second"
        ));
    }
}