1use std::collections::HashMap;
51use std::sync::{Arc, OnceLock, RwLock};
52use std::time::Duration;
53
54use crate::error::LlmError;
55use crate::provider::DynProvider;
56
57#[derive(Debug, Clone, Default)]
62pub struct ProviderConfig {
63 pub provider: String,
65
66 pub api_key: Option<String>,
68
69 pub model: String,
71
72 pub base_url: Option<String>,
74
75 pub timeout: Option<Duration>,
77
78 pub extra: HashMap<String, serde_json::Value>,
83}
84
85impl ProviderConfig {
86 pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
88 Self {
89 provider: provider.into(),
90 model: model.into(),
91 ..Default::default()
92 }
93 }
94
95 #[must_use]
97 pub fn api_key(mut self, key: impl Into<String>) -> Self {
98 self.api_key = Some(key.into());
99 self
100 }
101
102 #[must_use]
104 pub fn base_url(mut self, url: impl Into<String>) -> Self {
105 self.base_url = Some(url.into());
106 self
107 }
108
109 #[must_use]
111 pub fn timeout(mut self, timeout: Duration) -> Self {
112 self.timeout = Some(timeout);
113 self
114 }
115
116 #[must_use]
118 pub fn extra(mut self, key: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
119 self.extra.insert(key.into(), value.into());
120 self
121 }
122
123 pub fn get_extra_str(&self, key: &str) -> Option<&str> {
125 self.extra.get(key).and_then(|v| v.as_str())
126 }
127
128 pub fn get_extra_bool(&self, key: &str) -> Option<bool> {
130 self.extra.get(key).and_then(serde_json::Value::as_bool)
131 }
132
133 pub fn get_extra_i64(&self, key: &str) -> Option<i64> {
135 self.extra.get(key).and_then(serde_json::Value::as_i64)
136 }
137}
138
139pub trait ProviderFactory: Send + Sync {
143 fn name(&self) -> &str;
147
148 fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError>;
155}
156
157pub struct ProviderRegistry {
173 factories: RwLock<HashMap<String, Arc<dyn ProviderFactory>>>,
174}
175
176impl std::fmt::Debug for ProviderRegistry {
177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178 let factories = self
179 .factories
180 .read()
181 .expect("provider registry lock poisoned");
182 let names: Vec<_> = factories.keys().collect();
183 f.debug_struct("ProviderRegistry")
184 .field("providers", &names)
185 .finish()
186 }
187}
188
189impl Default for ProviderRegistry {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
195impl ProviderRegistry {
196 pub fn new() -> Self {
198 Self {
199 factories: RwLock::new(HashMap::new()),
200 }
201 }
202
203 pub fn global() -> &'static Self {
209 static GLOBAL: OnceLock<ProviderRegistry> = OnceLock::new();
210 GLOBAL.get_or_init(ProviderRegistry::new)
211 }
212
213 pub fn register(&self, factory: Box<dyn ProviderFactory>) -> &Self {
225 let name = factory.name().to_lowercase();
226 let mut factories = self
227 .factories
228 .write()
229 .expect("provider registry lock poisoned");
230 factories.insert(name, Arc::from(factory));
231 self
232 }
233
234 pub fn register_shared(&self, factory: Arc<dyn ProviderFactory>) -> &Self {
238 let name = factory.name().to_lowercase();
239 let mut factories = self
240 .factories
241 .write()
242 .expect("provider registry lock poisoned");
243 factories.insert(name, factory);
244 self
245 }
246
247 pub fn unregister(&self, name: &str) -> bool {
251 let mut factories = self
252 .factories
253 .write()
254 .expect("provider registry lock poisoned");
255 factories.remove(&name.to_lowercase()).is_some()
256 }
257
258 pub fn contains(&self, name: &str) -> bool {
260 let factories = self
261 .factories
262 .read()
263 .expect("provider registry lock poisoned");
264 factories.contains_key(&name.to_lowercase())
265 }
266
267 pub fn providers(&self) -> Vec<String> {
269 let factories = self
270 .factories
271 .read()
272 .expect("provider registry lock poisoned");
273 factories.keys().cloned().collect()
274 }
275
276 pub fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
285 let name = config.provider.to_lowercase();
286 let factories = self
287 .factories
288 .read()
289 .expect("provider registry lock poisoned");
290
291 let factory = factories.get(&name).ok_or_else(|| {
292 let available: Vec<_> = factories.keys().cloned().collect();
293 LlmError::InvalidRequest(format!(
294 "unknown provider '{}'. Available: {:?}",
295 config.provider, available
296 ))
297 })?;
298
299 factory.build(config)
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatResponse, ContentBlock, StopReason};
    use crate::provider::{ChatParams, Provider, ProviderMetadata};
    use crate::stream::ChatStream;
    use crate::usage::Usage;
    use std::collections::{HashMap, HashSet};

    /// Stub provider: `generate` always answers "test" and echoes the configured model;
    /// `stream` is unsupported.
    struct TestProvider {
        model: String,
    }

    impl Provider for TestProvider {
        async fn generate(&self, _params: &ChatParams) -> Result<ChatResponse, LlmError> {
            Ok(ChatResponse {
                content: vec![ContentBlock::Text("test".into())],
                usage: Usage::default(),
                stop_reason: StopReason::EndTurn,
                model: self.model.clone(),
                metadata: HashMap::default(),
            })
        }

        async fn stream(&self, _params: &ChatParams) -> Result<ChatStream, LlmError> {
            Err(LlmError::InvalidRequest("not implemented".into()))
        }

        fn metadata(&self) -> ProviderMetadata {
            ProviderMetadata {
                name: "test".into(),
                model: self.model.clone(),
                context_window: 4096,
                capabilities: HashSet::new(),
            }
        }
    }

    /// Factory registered as "test" that builds a [`TestProvider`] for the configured model.
    struct TestFactory;

    impl ProviderFactory for TestFactory {
        fn name(&self) -> &'static str {
            "test"
        }

        fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
            Ok(Box::new(TestProvider {
                model: config.model.clone(),
            }))
        }
    }

    /// A fresh registry with only [`TestFactory`] registered.
    fn registry_with_test_factory() -> ProviderRegistry {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));
        registry
    }

    #[test]
    fn test_registry_register_and_build() {
        let registry = registry_with_test_factory();

        assert!(registry.contains("test"));
        assert!(registry.contains("TEST"));

        let config = ProviderConfig::new("test", "test-model");
        let provider = registry.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "test-model");
    }

    #[test]
    fn test_registry_build_is_case_insensitive() {
        let registry = registry_with_test_factory();

        let provider = registry
            .build(&ProviderConfig::new("TeSt", "mixed-case"))
            .unwrap();

        assert_eq!(provider.metadata().model, "mixed-case");
    }

    #[test]
    fn test_registry_unknown_provider() {
        let registry = ProviderRegistry::new();

        let config = ProviderConfig::new("unknown", "model");
        let Err(LlmError::InvalidRequest(msg)) = registry.build(&config) else {
            panic!("expected InvalidRequest for an unregistered provider");
        };

        assert!(
            msg.contains("unknown provider 'unknown'"),
            "unexpected message: {msg}"
        );
    }

    #[test]
    fn test_registry_unregister() {
        let registry = registry_with_test_factory();

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test"));
    }

    #[test]
    fn test_registry_unregister_is_case_insensitive() {
        let registry = registry_with_test_factory();

        assert!(registry.unregister("TEST"));
        assert!(!registry.contains("test"));
    }

    #[test]
    fn test_registry_providers_list() {
        let registry = registry_with_test_factory();

        let providers = registry.providers();
        assert_eq!(providers, vec!["test"]);
    }

    #[test]
    fn test_registry_register_shared_keeps_caller_handle() {
        let registry = ProviderRegistry::new();
        let factory: Arc<dyn ProviderFactory> = Arc::new(TestFactory);

        registry.register_shared(Arc::clone(&factory));

        assert!(registry.contains("test"));
        // One handle held here, one stored inside the registry.
        assert_eq!(Arc::strong_count(&factory), 2);
    }

    #[test]
    fn test_registry_default_is_empty() {
        assert!(ProviderRegistry::default().providers().is_empty());
    }

    #[test]
    fn test_registry_global_is_singleton() {
        assert!(std::ptr::eq(
            ProviderRegistry::global(),
            ProviderRegistry::global()
        ));
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("anthropic", "claude-3")
            .api_key("sk-123")
            .base_url("https://custom.api")
            .timeout(Duration::from_secs(60))
            .extra("organization", "org-123");

        assert_eq!(config.provider, "anthropic");
        assert_eq!(config.model, "claude-3");
        assert_eq!(config.api_key, Some("sk-123".into()));
        assert_eq!(config.base_url, Some("https://custom.api".into()));
        assert_eq!(config.timeout, Some(Duration::from_secs(60)));
        assert_eq!(config.get_extra_str("organization"), Some("org-123"));
    }

    #[test]
    fn test_provider_config_extra_types() {
        let config = ProviderConfig::new("test", "model")
            .extra("flag", true)
            .extra("count", 42i64)
            .extra("name", "value");

        assert_eq!(config.get_extra_bool("flag"), Some(true));
        assert_eq!(config.get_extra_i64("count"), Some(42));
        assert_eq!(config.get_extra_str("name"), Some("value"));
        assert_eq!(config.get_extra_str("missing"), None);
        // Type mismatches yield None rather than a coerced value.
        assert_eq!(config.get_extra_str("flag"), None);
        assert_eq!(config.get_extra_bool("count"), None);
    }

    #[tokio::test]
    async fn test_built_provider_works() {
        let registry = registry_with_test_factory();

        let config = ProviderConfig::new("test", "my-model");
        let provider = registry.build(&config).unwrap();

        let response = provider
            .generate_boxed(&ChatParams::default())
            .await
            .unwrap();
        assert_eq!(response.model, "my-model");
    }

    #[test]
    fn test_registry_replace_factory() {
        /// Registers under the same name as `TestFactory` but prefixes the model with "alt-".
        struct AltFactory;

        impl ProviderFactory for AltFactory {
            fn name(&self) -> &'static str {
                "test"
            }

            fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
                Ok(Box::new(TestProvider {
                    model: format!("alt-{}", config.model),
                }))
            }
        }

        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));
        registry.register(Box::new(AltFactory));

        let config = ProviderConfig::new("test", "model");
        let provider = registry.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "alt-model");
    }
}