1use std::collections::HashMap;
51use std::sync::{Arc, OnceLock, RwLock};
52use std::time::Duration;
53
54use crate::error::LlmError;
55use crate::provider::DynProvider;
56
57#[derive(Debug, Clone, Default)]
62pub struct ProviderConfig {
63 pub provider: String,
65
66 pub api_key: Option<String>,
68
69 pub model: String,
71
72 pub base_url: Option<String>,
74
75 pub timeout: Option<Duration>,
77
78 pub client: Option<reqwest::Client>,
84
85 pub extra: HashMap<String, serde_json::Value>,
90}
91
92impl ProviderConfig {
93 pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
95 Self {
96 provider: provider.into(),
97 model: model.into(),
98 ..Default::default()
99 }
100 }
101
102 #[must_use]
104 pub fn api_key(mut self, key: impl Into<String>) -> Self {
105 self.api_key = Some(key.into());
106 self
107 }
108
109 #[must_use]
111 pub fn base_url(mut self, url: impl Into<String>) -> Self {
112 self.base_url = Some(url.into());
113 self
114 }
115
116 #[must_use]
118 pub fn timeout(mut self, timeout: Duration) -> Self {
119 self.timeout = Some(timeout);
120 self
121 }
122
123 #[must_use]
125 pub fn client(mut self, client: reqwest::Client) -> Self {
126 self.client = Some(client);
127 self
128 }
129
130 #[must_use]
132 pub fn extra(mut self, key: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
133 self.extra.insert(key.into(), value.into());
134 self
135 }
136
137 pub fn get_extra_str(&self, key: &str) -> Option<&str> {
139 self.extra.get(key).and_then(|v| v.as_str())
140 }
141
142 pub fn get_extra_bool(&self, key: &str) -> Option<bool> {
144 self.extra.get(key).and_then(serde_json::Value::as_bool)
145 }
146
147 pub fn get_extra_i64(&self, key: &str) -> Option<i64> {
149 self.extra.get(key).and_then(serde_json::Value::as_i64)
150 }
151}
152
/// Constructs providers from a [`ProviderConfig`].
///
/// Factories are stored in a [`ProviderRegistry`] under the lowercased value
/// of [`name`](ProviderFactory::name). The `Send + Sync` bound allows them to
/// be shared across threads, including through the `static` global registry.
pub trait ProviderFactory: Send + Sync {
    /// Name this factory is registered under.
    ///
    /// The registry lowercases it, so lookups by name are case-insensitive.
    fn name(&self) -> &str;

    /// Builds a provider from `config`.
    ///
    /// # Errors
    ///
    /// Returns an [`LlmError`] when a provider cannot be built from `config`.
    /// Which conditions count as errors is up to the implementation.
    fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError>;
}
170
/// Registry mapping provider names to [`ProviderFactory`] instances.
///
/// Names are lowercased on registration and lookup, so matching is
/// case-insensitive. Use [`ProviderRegistry::new`] for an isolated registry or
/// [`ProviderRegistry::global`] for the process-wide one.
pub struct ProviderRegistry {
    /// Factories keyed by lowercased provider name. `Arc` allows the same
    /// factory to be shared (see `register_shared`). The `RwLock` lets the
    /// registry be mutated through `&self`.
    factories: RwLock<HashMap<String, Arc<dyn ProviderFactory>>>,
}
189
190impl std::fmt::Debug for ProviderRegistry {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 let factories = self
193 .factories
194 .read()
195 .expect("provider registry lock poisoned");
196 let names: Vec<_> = factories.keys().collect();
197 f.debug_struct("ProviderRegistry")
198 .field("providers", &names)
199 .finish()
200 }
201}
202
203impl Default for ProviderRegistry {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
209impl ProviderRegistry {
210 pub fn new() -> Self {
212 Self {
213 factories: RwLock::new(HashMap::new()),
214 }
215 }
216
217 pub fn global() -> &'static Self {
223 static GLOBAL: OnceLock<ProviderRegistry> = OnceLock::new();
224 GLOBAL.get_or_init(ProviderRegistry::new)
225 }
226
227 pub fn register(&self, factory: Box<dyn ProviderFactory>) -> &Self {
239 let name = factory.name().to_lowercase();
240 let mut factories = self
241 .factories
242 .write()
243 .expect("provider registry lock poisoned");
244 factories.insert(name, Arc::from(factory));
245 self
246 }
247
248 pub fn register_shared(&self, factory: Arc<dyn ProviderFactory>) -> &Self {
252 let name = factory.name().to_lowercase();
253 let mut factories = self
254 .factories
255 .write()
256 .expect("provider registry lock poisoned");
257 factories.insert(name, factory);
258 self
259 }
260
261 pub fn unregister(&self, name: &str) -> bool {
265 let mut factories = self
266 .factories
267 .write()
268 .expect("provider registry lock poisoned");
269 factories.remove(&name.to_lowercase()).is_some()
270 }
271
272 pub fn contains(&self, name: &str) -> bool {
274 let factories = self
275 .factories
276 .read()
277 .expect("provider registry lock poisoned");
278 factories.contains_key(&name.to_lowercase())
279 }
280
281 pub fn providers(&self) -> Vec<String> {
283 let factories = self
284 .factories
285 .read()
286 .expect("provider registry lock poisoned");
287 factories.keys().cloned().collect()
288 }
289
290 pub fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
299 let name = config.provider.to_lowercase();
300 let factories = self
301 .factories
302 .read()
303 .expect("provider registry lock poisoned");
304
305 let factory = factories.get(&name).ok_or_else(|| {
306 let available: Vec<_> = factories.keys().cloned().collect();
307 LlmError::InvalidRequest(format!(
308 "unknown provider '{}'. Available: {:?}",
309 config.provider, available
310 ))
311 })?;
312
313 factory.build(config)
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatResponse, ContentBlock, StopReason};
    use crate::provider::{ChatParams, Provider, ProviderMetadata};
    use crate::stream::ChatStream;
    use crate::usage::Usage;
    use std::collections::{HashMap, HashSet};

    /// Minimal provider whose responses and metadata echo the configured model.
    struct TestProvider {
        model: String,
    }

    impl Provider for TestProvider {
        async fn generate(&self, _params: &ChatParams) -> Result<ChatResponse, LlmError> {
            Ok(ChatResponse {
                content: vec![ContentBlock::Text("test".into())],
                usage: Usage::default(),
                stop_reason: StopReason::EndTurn,
                model: self.model.clone(),
                metadata: HashMap::default(),
            })
        }

        async fn stream(&self, _params: &ChatParams) -> Result<ChatStream, LlmError> {
            Err(LlmError::InvalidRequest("not implemented".into()))
        }

        fn metadata(&self) -> ProviderMetadata {
            ProviderMetadata {
                name: "test".into(),
                model: self.model.clone(),
                context_window: 4096,
                capabilities: HashSet::new(),
            }
        }
    }

    /// Factory registered under the name `"test"` that builds `TestProvider`s.
    struct TestFactory;

    impl ProviderFactory for TestFactory {
        fn name(&self) -> &'static str {
            "test"
        }

        fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
            Ok(Box::new(TestProvider {
                model: config.model.clone(),
            }))
        }
    }

    #[test]
    fn test_registry_register_and_build() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        assert!(registry.contains("test"));
        assert!(registry.contains("TEST"));

        let config = ProviderConfig::new("test", "test-model");
        let provider = registry.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "test-model");
    }

    #[test]
    fn test_registry_build_is_case_insensitive() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let config = ProviderConfig::new("TeSt", "mixed-case");
        let provider = registry.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "mixed-case");
    }

    #[test]
    fn test_registry_unknown_provider() {
        let registry = ProviderRegistry::new();

        let config = ProviderConfig::new("unknown", "model");
        let result = registry.build(&config);

        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(matches!(err, LlmError::InvalidRequest(_)));
    }

    #[test]
    fn test_registry_unknown_provider_message_lists_available() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let config = ProviderConfig::new("missing", "model");
        let Some(LlmError::InvalidRequest(message)) = registry.build(&config).err() else {
            panic!("expected LlmError::InvalidRequest for an unknown provider");
        };

        assert!(message.contains("'missing'"), "unexpected message: {message}");
        assert!(message.contains(r#"["test"]"#), "unexpected message: {message}");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test"));
    }

    #[test]
    fn test_registry_unregister_is_case_insensitive() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        assert!(registry.unregister("TEST"));
        assert!(!registry.contains("test"));
    }

    #[test]
    fn test_registry_providers_list() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let providers = registry.providers();
        assert_eq!(providers, vec!["test"]);
    }

    #[test]
    fn test_registry_register_shared() {
        let registry = ProviderRegistry::new();
        let factory: Arc<dyn ProviderFactory> = Arc::new(TestFactory);

        registry.register_shared(Arc::clone(&factory));

        assert!(registry.contains("test"));
        // The registry holds the second strong reference.
        assert_eq!(Arc::strong_count(&factory), 2);
    }

    #[test]
    fn test_registry_debug_lists_providers() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        assert_eq!(
            format!("{registry:?}"),
            r#"ProviderRegistry { providers: ["test"] }"#
        );
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("anthropic", "claude-3")
            .api_key("sk-123")
            .base_url("https://custom.api")
            .timeout(Duration::from_secs(60))
            .extra("organization", "org-123");

        assert_eq!(config.provider, "anthropic");
        assert_eq!(config.model, "claude-3");
        assert_eq!(config.api_key, Some("sk-123".into()));
        assert_eq!(config.base_url, Some("https://custom.api".into()));
        assert_eq!(config.timeout, Some(Duration::from_secs(60)));
        assert_eq!(config.get_extra_str("organization"), Some("org-123"));
    }

    #[test]
    fn test_provider_config_extra_types() {
        let config = ProviderConfig::new("test", "model")
            .extra("flag", true)
            .extra("count", 42i64)
            .extra("name", "value");

        assert_eq!(config.get_extra_bool("flag"), Some(true));
        assert_eq!(config.get_extra_i64("count"), Some(42));
        assert_eq!(config.get_extra_str("name"), Some("value"));
        assert_eq!(config.get_extra_str("missing"), None);
    }

    #[tokio::test]
    async fn test_built_provider_works() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let config = ProviderConfig::new("test", "my-model");
        let provider = registry.build(&config).unwrap();

        let response = provider
            .generate_boxed(&ChatParams::default())
            .await
            .unwrap();
        assert_eq!(response.model, "my-model");
    }

    #[test]
    fn test_registry_replace_factory() {
        /// Factory with the same name as `TestFactory` but a distinguishable model.
        struct AltFactory;

        impl ProviderFactory for AltFactory {
            fn name(&self) -> &'static str {
                "test"
            }

            fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
                Ok(Box::new(TestProvider {
                    model: format!("alt-{}", config.model),
                }))
            }
        }

        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));
        registry.register(Box::new(AltFactory));

        let config = ProviderConfig::new("test", "model");
        let provider = registry.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "alt-model");
    }
}