use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
use std::time::Duration;
use crate::error::LlmError;
use crate::provider::DynProvider;
/// Configuration used to select and construct a provider through a
/// [`ProviderRegistry`].
///
/// Create one with [`ProviderConfig::new`] and refine it with the chained
/// setter methods.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written to logs: the output shows only whether a key is present.
/// Values in `extra` are printed verbatim, so avoid storing secrets there.
#[derive(Clone, Default)]
pub struct ProviderConfig {
    /// Name of the provider to build; the registry lowercases it before lookup.
    pub provider: String,
    /// API key, if any. How it is used is up to the factory.
    pub api_key: Option<String>,
    /// Model identifier handed to the factory.
    pub model: String,
    /// Optional base URL; its interpretation is left to the factory.
    pub base_url: Option<String>,
    /// Optional timeout; its interpretation is left to the factory.
    pub timeout: Option<Duration>,
    /// Optional pre-built HTTP client.
    pub client: Option<reqwest::Client>,
    /// Free-form, provider-specific settings keyed by name.
    pub extra: HashMap<String, serde_json::Value>,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None shape visible for diagnostics, but never the
        // secret itself.
        let api_key = self.api_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("ProviderConfig")
            .field("provider", &self.provider)
            .field("api_key", &api_key)
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .field("client", &self.client)
            .field("extra", &self.extra)
            .finish()
    }
}
impl ProviderConfig {
    /// Creates a configuration for `provider` and `model`, leaving every
    /// other field at its default value.
    pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
        let provider = provider.into();
        let model = model.into();
        Self {
            provider,
            model,
            ..Self::default()
        }
    }

    /// Sets the API key, replacing any previous one.
    #[must_use]
    pub fn api_key(self, key: impl Into<String>) -> Self {
        Self {
            api_key: Some(key.into()),
            ..self
        }
    }

    /// Sets the base URL, replacing any previous one.
    #[must_use]
    pub fn base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: Some(url.into()),
            ..self
        }
    }

    /// Sets the timeout, replacing any previous one.
    #[must_use]
    pub fn timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Sets the HTTP client, replacing any previous one.
    #[must_use]
    pub fn client(self, client: reqwest::Client) -> Self {
        Self {
            client: Some(client),
            ..self
        }
    }

    /// Stores `value` under `key` in the extra settings, overwriting any
    /// value already stored under that key.
    #[must_use]
    pub fn extra(mut self, key: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
        let name = key.into();
        let json = value.into();
        self.extra.insert(name, json);
        self
    }

    /// Returns the extra setting `key` as a string slice, or `None` when the
    /// key is absent or its value is not a JSON string.
    pub fn get_extra_str(&self, key: &str) -> Option<&str> {
        self.extra.get(key)?.as_str()
    }

    /// Returns the extra setting `key` as a boolean, or `None` when the key
    /// is absent or its value is not a JSON boolean.
    pub fn get_extra_bool(&self, key: &str) -> Option<bool> {
        self.extra.get(key)?.as_bool()
    }

    /// Returns the extra setting `key` as an `i64`, or `None` when the key is
    /// absent or its value is not an integer that fits in `i64`.
    pub fn get_extra_i64(&self, key: &str) -> Option<i64> {
        self.extra.get(key)?.as_i64()
    }
}
/// Factory that builds provider instances from a [`ProviderConfig`].
///
/// Factories are stored in a [`ProviderRegistry`] under the lowercased value
/// of [`name`](ProviderFactory::name). The `Send + Sync` supertraits are
/// required because the registry can live in a process-wide static.
pub trait ProviderFactory: Send + Sync {
/// Name this factory is registered under (the registry lowercases it).
fn name(&self) -> &str;
/// Builds a provider for `config`, or returns an [`LlmError`] on failure.
fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError>;
}
/// Registry that maps provider names to the factories able to build them.
///
/// The map is kept behind an [`RwLock`], which is why every method takes
/// `&self` even when it modifies the registry.
pub struct ProviderRegistry {
// Keys are lowercased factory names; values are shared factory handles.
factories: RwLock<HashMap<String, Arc<dyn ProviderFactory>>>,
}
impl std::fmt::Debug for ProviderRegistry {
    /// Formats the registry as `ProviderRegistry { providers: [...] }` with
    /// the registered names in sorted order.
    ///
    /// Formatting never blocks and never panics: if the lock is currently
    /// unavailable the provider list is shown as `<locked>`, and a poisoned
    /// lock is read anyway.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::sync::TryLockError;

        // `Debug` is often reached from logging or panic paths. A blocking
        // `read()` could deadlock or panic if this thread already holds the
        // lock, and panicking on poison while unwinding would abort.
        let factories = match self.factories.try_read() {
            Ok(guard) => guard,
            // The map itself is still readable after a writer panicked.
            Err(TryLockError::Poisoned(poisoned)) => poisoned.into_inner(),
            Err(TryLockError::WouldBlock) => {
                return f
                    .debug_struct("ProviderRegistry")
                    .field("providers", &format_args!("<locked>"))
                    .finish();
            }
        };

        // `HashMap` iteration order is arbitrary; sort for stable output.
        let mut names: Vec<&String> = factories.keys().collect();
        names.sort_unstable();

        f.debug_struct("ProviderRegistry")
            .field("providers", &names)
            .finish()
    }
}
impl Default for ProviderRegistry {
    /// Returns an empty registry, equivalent to [`ProviderRegistry::new`].
    fn default() -> Self {
        Self {
            factories: RwLock::default(),
        }
    }
}
impl ProviderRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            factories: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the process-wide registry, creating it empty on first use.
    pub fn global() -> &'static Self {
        static GLOBAL: OnceLock<ProviderRegistry> = OnceLock::new();
        GLOBAL.get_or_init(ProviderRegistry::new)
    }

    /// Registers `factory` under its lowercased name, replacing any factory
    /// already registered under that name.
    ///
    /// Returns `&self` so registrations can be chained.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn register(&self, factory: Box<dyn ProviderFactory>) -> &Self {
        self.register_shared(Arc::from(factory))
    }

    /// Registers an already shared `factory` under its lowercased name,
    /// replacing any factory already registered under that name.
    ///
    /// Returns `&self` so registrations can be chained.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn register_shared(&self, factory: Arc<dyn ProviderFactory>) -> &Self {
        // Query the name before locking so no factory code runs under the
        // write guard.
        let key = Self::key_for(factory.name());
        // The guard is a temporary released at the end of this statement, so
        // a replaced factory is dropped only after the lock is free.
        let replaced = self.write_factories().insert(key, factory);
        drop(replaced);
        self
    }

    /// Removes the factory registered under `name` (case-insensitive).
    ///
    /// Returns `true` if a factory was removed, `false` if none was
    /// registered under that name.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn unregister(&self, name: &str) -> bool {
        let key = Self::key_for(name);
        // As in `register_shared`, the removed factory outlives the guard.
        let removed = self.write_factories().remove(&key);
        removed.is_some()
    }

    /// Returns whether a factory is registered under `name`
    /// (case-insensitive).
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn contains(&self, name: &str) -> bool {
        let key = Self::key_for(name);
        self.read_factories().contains_key(&key)
    }

    /// Returns the lowercased names of all registered providers, sorted.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn providers(&self) -> Vec<String> {
        let mut names: Vec<String> = self.read_factories().keys().cloned().collect();
        // `HashMap` iteration order is arbitrary; sort for a stable result.
        names.sort_unstable();
        names
    }

    /// Builds a provider using the factory registered under
    /// `config.provider` (case-insensitive).
    ///
    /// The registry lock is released before the factory runs, so a factory
    /// may itself use the registry without deadlocking.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::InvalidRequest`] listing the available providers
    /// if no factory is registered under that name, or whatever error the
    /// factory's `build` returns.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
        let key = Self::key_for(&config.provider);

        // Clone the handle out and let the guard drop at the end of this
        // block, so the factory's code below runs without the lock held.
        let factory = {
            let factories = self.read_factories();
            let Some(found) = factories.get(&key) else {
                let mut available: Vec<&str> = factories.keys().map(String::as_str).collect();
                available.sort_unstable();
                return Err(LlmError::InvalidRequest(format!(
                    "unknown provider '{}'. Available: {:?}",
                    config.provider, available
                )));
            };
            Arc::clone(found)
        };

        factory.build(config)
    }

    /// Normalizes a provider name into the key used in the map.
    fn key_for(name: &str) -> String {
        name.to_lowercase()
    }

    /// Acquires the read lock, panicking if it is poisoned.
    fn read_factories(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<String, Arc<dyn ProviderFactory>>> {
        self.factories
            .read()
            .expect("provider registry lock poisoned")
    }

    /// Acquires the write lock, panicking if it is poisoned.
    fn write_factories(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<String, Arc<dyn ProviderFactory>>> {
        self.factories
            .write()
            .expect("provider registry lock poisoned")
    }
}
// Unit tests for `ProviderConfig` and `ProviderRegistry`, using a minimal
// in-memory provider and factory as fixtures.
#[cfg(test)]
mod tests {
use super::*;
use crate::chat::{ChatResponse, ContentBlock, StopReason};
use crate::provider::{ChatParams, Provider, ProviderMetadata};
use crate::stream::ChatStream;
use crate::usage::Usage;
use std::collections::{HashMap, HashSet};
// Fixture provider that remembers only the model name it was built with.
struct TestProvider {
model: String,
}
impl Provider for TestProvider {
// Always succeeds with a single "test" text block and echoes the model name.
async fn generate(&self, _params: &ChatParams) -> Result<ChatResponse, LlmError> {
Ok(ChatResponse {
content: vec![ContentBlock::Text("test".into())],
usage: Usage::default(),
stop_reason: StopReason::EndTurn,
model: self.model.clone(),
metadata: HashMap::default(),
})
}
// Streaming is not supported by the fixture; it always returns an error.
async fn stream(&self, _params: &ChatParams) -> Result<ChatStream, LlmError> {
Err(LlmError::InvalidRequest("not implemented".into()))
}
// Reports the fixture name "test" together with the configured model.
fn metadata(&self) -> ProviderMetadata {
ProviderMetadata {
name: "test".into(),
model: self.model.clone(),
context_window: 4096,
capabilities: HashSet::new(),
}
}
}
// Fixture factory registered under "test" that builds a `TestProvider`.
struct TestFactory;
impl ProviderFactory for TestFactory {
fn name(&self) -> &'static str {
"test"
}
fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
Ok(Box::new(TestProvider {
model: config.model.clone(),
}))
}
}
// A registered factory is found case-insensitively and builds a provider
// carrying the model from the config.
#[test]
fn test_registry_register_and_build() {
let registry = ProviderRegistry::new();
registry.register(Box::new(TestFactory));
assert!(registry.contains("test"));
assert!(registry.contains("TEST"));
let config = ProviderConfig::new("test", "test-model");
let provider = registry.build(&config).unwrap();
assert_eq!(provider.metadata().model, "test-model");
}
// Building with a name nobody registered yields `InvalidRequest`.
#[test]
fn test_registry_unknown_provider() {
let registry = ProviderRegistry::new();
let config = ProviderConfig::new("unknown", "model");
let result = registry.build(&config);
assert!(result.is_err());
let err = result.err().unwrap();
assert!(matches!(err, LlmError::InvalidRequest(_)));
}
// Unregistering removes the factory and reports whether anything was removed.
#[test]
fn test_registry_unregister() {
let registry = ProviderRegistry::new();
registry.register(Box::new(TestFactory));
assert!(registry.contains("test"));
assert!(registry.unregister("test"));
assert!(!registry.contains("test"));
// A second removal finds nothing and therefore returns false.
assert!(!registry.unregister("test")); }
// `providers` lists the registered names.
#[test]
fn test_registry_providers_list() {
let registry = ProviderRegistry::new();
registry.register(Box::new(TestFactory));
let providers = registry.providers();
assert_eq!(providers, vec!["test"]);
}
// Each chained setter stores its value in the corresponding field.
#[test]
fn test_provider_config_builder() {
let config = ProviderConfig::new("anthropic", "claude-3")
.api_key("sk-123")
.base_url("https://custom.api")
.timeout(Duration::from_secs(60))
.extra("organization", "org-123");
assert_eq!(config.provider, "anthropic");
assert_eq!(config.model, "claude-3");
assert_eq!(config.api_key, Some("sk-123".into()));
assert_eq!(config.base_url, Some("https://custom.api".into()));
assert_eq!(config.timeout, Some(Duration::from_secs(60)));
assert_eq!(config.get_extra_str("organization"), Some("org-123"));
}
// Typed accessors return bool, integer and string extras; a missing key gives `None`.
#[test]
fn test_provider_config_extra_types() {
let config = ProviderConfig::new("test", "model")
.extra("flag", true)
.extra("count", 42i64)
.extra("name", "value");
assert_eq!(config.get_extra_bool("flag"), Some(true));
assert_eq!(config.get_extra_i64("count"), Some(42));
assert_eq!(config.get_extra_str("name"), Some("value"));
assert_eq!(config.get_extra_str("missing"), None);
}
// A provider obtained from the registry can be called through the boxed interface.
#[tokio::test]
async fn test_built_provider_works() {
let registry = ProviderRegistry::new();
registry.register(Box::new(TestFactory));
let config = ProviderConfig::new("test", "my-model");
let provider = registry.build(&config).unwrap();
let response = provider
.generate_boxed(&ChatParams::default())
.await
.unwrap();
assert_eq!(response.model, "my-model");
}
// Registering a second factory under the same name replaces the first one.
#[test]
fn test_registry_replace_factory() {
// Same name as `TestFactory`, but prefixes the model with "alt-" so the
// two factories can be told apart.
struct AltFactory;
impl ProviderFactory for AltFactory {
fn name(&self) -> &'static str {
"test"
}
fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
Ok(Box::new(TestProvider {
model: format!("alt-{}", config.model),
}))
}
}
let registry = ProviderRegistry::new();
registry.register(Box::new(TestFactory));
registry.register(Box::new(AltFactory));
let config = ProviderConfig::new("test", "model");
let provider = registry.build(&config).unwrap();
assert_eq!(provider.metadata().model, "alt-model");
}
}