Skip to main content

llm_stack/
registry.rs

1//! Dynamic provider registry for configuration-driven provider instantiation.
2//!
3//! The registry allows providers to be registered by name and instantiated from
4//! configuration at runtime. This enables:
5//!
6//! - Config-file driven provider selection
7//! - Third-party provider registration
8//! - Dynamic provider switching without code changes
9//!
10//! # Example
11//!
12//! ```rust,no_run
13//! use llm_stack::registry::{ProviderRegistry, ProviderConfig};
14//!
15//! // Get the global registry (providers register themselves on startup)
16//! let registry = ProviderRegistry::global();
17//!
18//! // Build a provider from config
19//! let config = ProviderConfig {
20//!     provider: "anthropic".into(),
21//!     api_key: Some("sk-...".into()),
22//!     model: "claude-sonnet-4-20250514".into(),
23//!     ..Default::default()
24//! };
25//!
26//! let provider = registry.build(&config).expect("provider registered");
27//! ```
28//!
29//! # Registering providers
30//!
31//! Provider crates register their factory on initialization:
32//!
33//! ```rust,ignore
34//! use llm_stack::registry::{ProviderRegistry, ProviderFactory, ProviderConfig};
35//!
36//! struct MyProviderFactory;
37//!
38//! impl ProviderFactory for MyProviderFactory {
39//!     fn name(&self) -> &str { "my-provider" }
40//!
41//!     fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
42//!         // Build and return provider
43//!     }
44//! }
45//!
46//! // Register on crate initialization
47//! ProviderRegistry::global().register(Box::new(MyProviderFactory));
48//! ```
49
50use std::collections::HashMap;
51use std::sync::{Arc, OnceLock, RwLock};
52use std::time::Duration;
53
54use crate::error::LlmError;
55use crate::provider::DynProvider;
56
57/// Configuration for building a provider from the registry.
58///
59/// This struct contains common configuration fields that work across
60/// all providers. Provider-specific options go in the `extra` map.
61#[derive(Debug, Clone, Default)]
62pub struct ProviderConfig {
63    /// Provider name (e.g., "anthropic", "openai", "ollama").
64    pub provider: String,
65
66    /// API key for authenticated providers.
67    pub api_key: Option<String>,
68
69    /// Model identifier (e.g., "claude-sonnet-4-20250514", "gpt-4o").
70    pub model: String,
71
72    /// Custom base URL for the API endpoint.
73    pub base_url: Option<String>,
74
75    /// Request timeout.
76    pub timeout: Option<Duration>,
77
78    /// Provider-specific configuration options.
79    ///
80    /// Use this for options that don't fit the common fields above.
81    /// Each provider documents which keys it recognizes.
82    pub extra: HashMap<String, serde_json::Value>,
83}
84
85impl ProviderConfig {
86    /// Creates a new config with the given provider and model.
87    pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
88        Self {
89            provider: provider.into(),
90            model: model.into(),
91            ..Default::default()
92        }
93    }
94
95    /// Sets the API key.
96    #[must_use]
97    pub fn api_key(mut self, key: impl Into<String>) -> Self {
98        self.api_key = Some(key.into());
99        self
100    }
101
102    /// Sets the base URL.
103    #[must_use]
104    pub fn base_url(mut self, url: impl Into<String>) -> Self {
105        self.base_url = Some(url.into());
106        self
107    }
108
109    /// Sets the timeout.
110    #[must_use]
111    pub fn timeout(mut self, timeout: Duration) -> Self {
112        self.timeout = Some(timeout);
113        self
114    }
115
116    /// Adds a provider-specific extra option.
117    #[must_use]
118    pub fn extra(mut self, key: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
119        self.extra.insert(key.into(), value.into());
120        self
121    }
122
123    /// Gets a string value from extra options.
124    pub fn get_extra_str(&self, key: &str) -> Option<&str> {
125        self.extra.get(key).and_then(|v| v.as_str())
126    }
127
128    /// Gets a bool value from extra options.
129    pub fn get_extra_bool(&self, key: &str) -> Option<bool> {
130        self.extra.get(key).and_then(serde_json::Value::as_bool)
131    }
132
133    /// Gets an integer value from extra options.
134    pub fn get_extra_i64(&self, key: &str) -> Option<i64> {
135        self.extra.get(key).and_then(serde_json::Value::as_i64)
136    }
137}
138
/// Factory trait for creating providers from configuration.
///
/// Implement this trait to register a provider with the registry.
/// The trait requires `Send + Sync` and `build` takes `&self`, so a single
/// factory instance may be shared by the registry across threads.
pub trait ProviderFactory: Send + Sync {
    /// Returns the provider name used for registration and lookup.
    ///
    /// This should be a lowercase identifier (e.g., "anthropic", "openai").
    /// The registry lowercases this value when registering, so lookups
    /// match regardless of the casing used here or in the config.
    fn name(&self) -> &str;

    /// Creates a provider instance from the given configuration.
    ///
    /// `config.provider` has already been matched against [`name`](Self::name)
    /// by the registry when this is called via `ProviderRegistry::build`.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration is invalid or missing
    /// required fields for this provider.
    fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError>;
}
156
/// A registry of provider factories for dynamic provider instantiation.
///
/// The registry maintains a map of provider names to their factories,
/// allowing providers to be created from configuration at runtime.
/// Names are stored lowercased, so registration and lookup are
/// case-insensitive.
///
/// # Thread Safety
///
/// The registry is thread-safe and can be accessed concurrently.
/// Registration and lookup use interior mutability via `RwLock`.
///
/// # Global vs Local Registries
///
/// Use [`ProviderRegistry::global()`] for the shared global registry,
/// or create local registries with [`ProviderRegistry::new()`] for
/// testing or isolated contexts.
pub struct ProviderRegistry {
    // Lowercased provider name -> factory. `Arc` lets a factory be shared
    // via `register_shared` as well as owned by the registry.
    factories: RwLock<HashMap<String, Arc<dyn ProviderFactory>>>,
}
175
176impl std::fmt::Debug for ProviderRegistry {
177    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178        let factories = self
179            .factories
180            .read()
181            .expect("provider registry lock poisoned");
182        let names: Vec<_> = factories.keys().collect();
183        f.debug_struct("ProviderRegistry")
184            .field("providers", &names)
185            .finish()
186    }
187}
188
189impl Default for ProviderRegistry {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
195impl ProviderRegistry {
196    /// Creates a new empty registry.
197    pub fn new() -> Self {
198        Self {
199            factories: RwLock::new(HashMap::new()),
200        }
201    }
202
203    /// Returns the global shared registry.
204    ///
205    /// Provider crates should register their factories here on initialization.
206    /// Application code can then build providers from configuration without
207    /// knowing which providers are available at compile time.
208    pub fn global() -> &'static Self {
209        static GLOBAL: OnceLock<ProviderRegistry> = OnceLock::new();
210        GLOBAL.get_or_init(ProviderRegistry::new)
211    }
212
213    /// Registers a provider factory.
214    ///
215    /// If a factory with the same name already exists, it is replaced.
216    ///
217    /// # Example
218    ///
219    /// ```rust,ignore
220    /// use llm_stack::registry::{ProviderRegistry, ProviderFactory};
221    ///
222    /// ProviderRegistry::global().register(Box::new(MyProviderFactory));
223    /// ```
224    pub fn register(&self, factory: Box<dyn ProviderFactory>) -> &Self {
225        let name = factory.name().to_lowercase();
226        let mut factories = self
227            .factories
228            .write()
229            .expect("provider registry lock poisoned");
230        factories.insert(name, Arc::from(factory));
231        self
232    }
233
234    /// Registers a provider factory (chainable Arc version).
235    ///
236    /// Use this when you want to share the factory instance.
237    pub fn register_shared(&self, factory: Arc<dyn ProviderFactory>) -> &Self {
238        let name = factory.name().to_lowercase();
239        let mut factories = self
240            .factories
241            .write()
242            .expect("provider registry lock poisoned");
243        factories.insert(name, factory);
244        self
245    }
246
247    /// Unregisters a provider by name.
248    ///
249    /// Returns `true` if the provider was registered and removed.
250    pub fn unregister(&self, name: &str) -> bool {
251        let mut factories = self
252            .factories
253            .write()
254            .expect("provider registry lock poisoned");
255        factories.remove(&name.to_lowercase()).is_some()
256    }
257
258    /// Checks if a provider is registered.
259    pub fn contains(&self, name: &str) -> bool {
260        let factories = self
261            .factories
262            .read()
263            .expect("provider registry lock poisoned");
264        factories.contains_key(&name.to_lowercase())
265    }
266
267    /// Returns the names of all registered providers.
268    pub fn providers(&self) -> Vec<String> {
269        let factories = self
270            .factories
271            .read()
272            .expect("provider registry lock poisoned");
273        factories.keys().cloned().collect()
274    }
275
276    /// Builds a provider from configuration.
277    ///
278    /// Looks up the factory by `config.provider` and delegates to it.
279    ///
280    /// # Errors
281    ///
282    /// Returns [`LlmError::InvalidRequest`] if no factory is registered
283    /// for the requested provider name.
284    pub fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
285        let name = config.provider.to_lowercase();
286        let factories = self
287            .factories
288            .read()
289            .expect("provider registry lock poisoned");
290
291        let factory = factories.get(&name).ok_or_else(|| {
292            let available: Vec<_> = factories.keys().cloned().collect();
293            LlmError::InvalidRequest(format!(
294                "unknown provider '{}'. Available: {:?}",
295                config.provider, available
296            ))
297        })?;
298
299        factory.build(config)
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatResponse, ContentBlock, StopReason};
    use crate::provider::{ChatParams, Provider, ProviderMetadata};
    use crate::stream::ChatStream;
    use crate::usage::Usage;
    use std::collections::{HashMap, HashSet};

    /// Minimal provider that echoes the configured model name back.
    struct TestProvider {
        model: String,
    }

    impl Provider for TestProvider {
        async fn generate(&self, _params: &ChatParams) -> Result<ChatResponse, LlmError> {
            Ok(ChatResponse {
                content: vec![ContentBlock::Text("test".into())],
                usage: Usage::default(),
                stop_reason: StopReason::EndTurn,
                model: self.model.clone(),
                metadata: HashMap::default(),
            })
        }

        async fn stream(&self, _params: &ChatParams) -> Result<ChatStream, LlmError> {
            Err(LlmError::InvalidRequest("not implemented".into()))
        }

        fn metadata(&self) -> ProviderMetadata {
            ProviderMetadata {
                name: "test".into(),
                model: self.model.clone(),
                context_window: 4096,
                capabilities: HashSet::new(),
            }
        }
    }

    /// Factory registered under the name "test".
    struct TestFactory;

    impl ProviderFactory for TestFactory {
        fn name(&self) -> &'static str {
            "test"
        }

        fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
            Ok(Box::new(TestProvider {
                model: config.model.clone(),
            }))
        }
    }

    #[test]
    fn test_registry_register_and_build() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        assert!(registry.contains("test"));
        assert!(registry.contains("TEST")); // case insensitive

        let config = ProviderConfig::new("test", "test-model");
        let provider = registry.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "test-model");
    }

    #[test]
    fn test_registry_build_is_case_insensitive() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let config = ProviderConfig::new("TeSt", "model");
        let provider = registry.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "model");
    }

    #[test]
    fn test_registry_register_shared() {
        let registry = ProviderRegistry::new();
        let factory: Arc<dyn ProviderFactory> = Arc::new(TestFactory);
        registry.register_shared(Arc::clone(&factory));

        assert!(registry.contains("test"));
        // The registry keeps its own reference alongside ours.
        assert_eq!(Arc::strong_count(&factory), 2);

        assert!(registry.unregister("test"));
        assert_eq!(Arc::strong_count(&factory), 1);
    }

    #[test]
    fn test_registry_unknown_provider() {
        let registry = ProviderRegistry::new();

        let config = ProviderConfig::new("unknown", "model");
        let result = registry.build(&config);

        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(matches!(err, LlmError::InvalidRequest(_)));
    }

    #[test]
    fn test_registry_unknown_provider_message() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let config = ProviderConfig::new("Missing", "model");
        let Err(LlmError::InvalidRequest(message)) = registry.build(&config) else {
            panic!("expected InvalidRequest for an unregistered provider");
        };
        // The requested name is echoed as given, and available names are listed.
        assert!(message.contains("'Missing'"));
        assert!(message.contains("\"test\""));
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test")); // already removed
    }

    #[test]
    fn test_registry_providers_list() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let providers = registry.providers();
        assert_eq!(providers, vec!["test"]);
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("anthropic", "claude-3")
            .api_key("sk-123")
            .base_url("https://custom.api")
            .timeout(Duration::from_secs(60))
            .extra("organization", "org-123");

        assert_eq!(config.provider, "anthropic");
        assert_eq!(config.model, "claude-3");
        assert_eq!(config.api_key, Some("sk-123".into()));
        assert_eq!(config.base_url, Some("https://custom.api".into()));
        assert_eq!(config.timeout, Some(Duration::from_secs(60)));
        assert_eq!(config.get_extra_str("organization"), Some("org-123"));
    }

    #[test]
    fn test_provider_config_extra_types() {
        let config = ProviderConfig::new("test", "model")
            .extra("flag", true)
            .extra("count", 42i64)
            .extra("name", "value");

        assert_eq!(config.get_extra_bool("flag"), Some(true));
        assert_eq!(config.get_extra_i64("count"), Some(42));
        assert_eq!(config.get_extra_str("name"), Some("value"));
        assert_eq!(config.get_extra_str("missing"), None);
        // Wrong-typed lookups yield None rather than coercing.
        assert_eq!(config.get_extra_str("count"), None);
        assert_eq!(config.get_extra_bool("name"), None);
    }

    #[tokio::test]
    async fn test_built_provider_works() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let config = ProviderConfig::new("test", "my-model");
        let provider = registry.build(&config).unwrap();

        let response = provider
            .generate_boxed(&ChatParams::default())
            .await
            .unwrap();
        assert_eq!(response.model, "my-model");
    }

    #[test]
    fn test_registry_replace_factory() {
        struct AltFactory;
        impl ProviderFactory for AltFactory {
            fn name(&self) -> &'static str {
                "test"
            }
            fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
                Ok(Box::new(TestProvider {
                    model: format!("alt-{}", config.model),
                }))
            }
        }

        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));
        registry.register(Box::new(AltFactory)); // replaces

        let config = ProviderConfig::new("test", "model");
        let provider = registry.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "alt-model");
    }
}