Skip to main content

llm_stack/
registry.rs

1//! Dynamic provider registry for configuration-driven provider instantiation.
2//!
3//! The registry allows providers to be registered by name and instantiated from
4//! configuration at runtime. This enables:
5//!
6//! - Config-file driven provider selection
7//! - Third-party provider registration
8//! - Dynamic provider switching without code changes
9//!
10//! # Example
11//!
12//! ```rust,no_run
13//! use llm_stack::registry::{ProviderRegistry, ProviderConfig};
14//!
//! // Get the global registry (factories must be registered before `build` is called)
16//! let registry = ProviderRegistry::global();
17//!
18//! // Build a provider from config
19//! let config = ProviderConfig {
20//!     provider: "anthropic".into(),
21//!     api_key: Some("sk-...".into()),
22//!     model: "claude-sonnet-4-20250514".into(),
23//!     ..Default::default()
24//! };
25//!
26//! let provider = registry.build(&config).expect("provider registered");
27//! ```
28//!
29//! # Registering providers
30//!
//! Provider crates expose a factory, which must be registered explicitly before
//! use (Rust has no automatic crate initialization), typically early in `main`:
32//!
33//! ```rust,ignore
34//! use llm_stack::registry::{ProviderRegistry, ProviderFactory, ProviderConfig};
35//!
36//! struct MyProviderFactory;
37//!
38//! impl ProviderFactory for MyProviderFactory {
39//!     fn name(&self) -> &str { "my-provider" }
40//!
41//!     fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
42//!         // Build and return provider
43//!     }
44//! }
45//!
//! // Register explicitly, e.g. during application startup
47//! ProviderRegistry::global().register(Box::new(MyProviderFactory));
48//! ```
49
50use std::collections::HashMap;
51use std::sync::{Arc, OnceLock, RwLock};
52use std::time::Duration;
53
54use crate::error::LlmError;
55use crate::provider::DynProvider;
56
57/// Configuration for building a provider from the registry.
58///
59/// This struct contains common configuration fields that work across
60/// all providers. Provider-specific options go in the `extra` map.
61#[derive(Debug, Clone, Default)]
62pub struct ProviderConfig {
63    /// Provider name (e.g., "anthropic", "openai", "ollama").
64    pub provider: String,
65
66    /// API key for authenticated providers.
67    pub api_key: Option<String>,
68
69    /// Model identifier (e.g., "claude-sonnet-4-20250514", "gpt-4o").
70    pub model: String,
71
72    /// Custom base URL for the API endpoint.
73    pub base_url: Option<String>,
74
75    /// Request timeout.
76    pub timeout: Option<Duration>,
77
78    /// Shared HTTP client for connection pooling.
79    ///
80    /// When set, the provider will use this client instead of creating
81    /// its own. Useful when multiple providers should share a connection
82    /// pool (e.g., in multi-agent systems).
83    pub client: Option<reqwest::Client>,
84
85    /// Provider-specific configuration options.
86    ///
87    /// Use this for options that don't fit the common fields above.
88    /// Each provider documents which keys it recognizes.
89    pub extra: HashMap<String, serde_json::Value>,
90}
91
92impl ProviderConfig {
93    /// Creates a new config with the given provider and model.
94    pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
95        Self {
96            provider: provider.into(),
97            model: model.into(),
98            ..Default::default()
99        }
100    }
101
102    /// Sets the API key.
103    #[must_use]
104    pub fn api_key(mut self, key: impl Into<String>) -> Self {
105        self.api_key = Some(key.into());
106        self
107    }
108
109    /// Sets the base URL.
110    #[must_use]
111    pub fn base_url(mut self, url: impl Into<String>) -> Self {
112        self.base_url = Some(url.into());
113        self
114    }
115
116    /// Sets the timeout.
117    #[must_use]
118    pub fn timeout(mut self, timeout: Duration) -> Self {
119        self.timeout = Some(timeout);
120        self
121    }
122
123    /// Sets a shared HTTP client for connection pooling.
124    #[must_use]
125    pub fn client(mut self, client: reqwest::Client) -> Self {
126        self.client = Some(client);
127        self
128    }
129
130    /// Adds a provider-specific extra option.
131    #[must_use]
132    pub fn extra(mut self, key: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
133        self.extra.insert(key.into(), value.into());
134        self
135    }
136
137    /// Gets a string value from extra options.
138    pub fn get_extra_str(&self, key: &str) -> Option<&str> {
139        self.extra.get(key).and_then(|v| v.as_str())
140    }
141
142    /// Gets a bool value from extra options.
143    pub fn get_extra_bool(&self, key: &str) -> Option<bool> {
144        self.extra.get(key).and_then(serde_json::Value::as_bool)
145    }
146
147    /// Gets an integer value from extra options.
148    pub fn get_extra_i64(&self, key: &str) -> Option<i64> {
149        self.extra.get(key).and_then(serde_json::Value::as_i64)
150    }
151}
152
/// Factory trait for creating providers from configuration.
///
/// Implement this trait to make a provider constructible by name through a
/// [`ProviderRegistry`]. Implementors must be `Send + Sync` because a
/// registry may be shared across threads (the global registry is a `static`).
pub trait ProviderFactory: Send + Sync {
    /// Returns the provider name used for registration and lookup.
    ///
    /// This should be a lowercase identifier (e.g., "anthropic", "openai").
    /// The registry lowercases this name when registering and lowercases the
    /// requested name when looking up, so matching is case-insensitive.
    fn name(&self) -> &str;

    /// Creates a provider instance from the given configuration.
    ///
    /// When called through [`ProviderRegistry::build`], `config.provider`
    /// has already been matched (case-insensitively) to [`Self::name`].
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration is invalid or missing
    /// required fields for this provider.
    fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError>;
}
170
/// A registry of provider factories for dynamic provider instantiation.
///
/// The registry maintains a map of provider names to their factories,
/// allowing providers to be created from configuration at runtime.
/// Names are lowercased both on registration and on lookup, so matching
/// is case-insensitive.
///
/// # Thread Safety
///
/// The registry is thread-safe and can be accessed concurrently.
/// Registration and lookup use interior mutability via `RwLock`, which is
/// why every method takes `&self` (including on the `'static` global).
///
/// # Global vs Local Registries
///
/// Use [`ProviderRegistry::global()`] for the shared global registry,
/// or create local registries with [`ProviderRegistry::new()`] for
/// testing or isolated contexts.
pub struct ProviderRegistry {
    // Lowercased provider name -> factory. Values are `Arc` so the same
    // factory instance can also be held by callers (see `register_shared`).
    factories: RwLock<HashMap<String, Arc<dyn ProviderFactory>>>,
}
189
190impl std::fmt::Debug for ProviderRegistry {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        let factories = self
193            .factories
194            .read()
195            .expect("provider registry lock poisoned");
196        let names: Vec<_> = factories.keys().collect();
197        f.debug_struct("ProviderRegistry")
198            .field("providers", &names)
199            .finish()
200    }
201}
202
203impl Default for ProviderRegistry {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
209impl ProviderRegistry {
210    /// Creates a new empty registry.
211    pub fn new() -> Self {
212        Self {
213            factories: RwLock::new(HashMap::new()),
214        }
215    }
216
217    /// Returns the global shared registry.
218    ///
219    /// Provider crates should register their factories here on initialization.
220    /// Application code can then build providers from configuration without
221    /// knowing which providers are available at compile time.
222    pub fn global() -> &'static Self {
223        static GLOBAL: OnceLock<ProviderRegistry> = OnceLock::new();
224        GLOBAL.get_or_init(ProviderRegistry::new)
225    }
226
227    /// Registers a provider factory.
228    ///
229    /// If a factory with the same name already exists, it is replaced.
230    ///
231    /// # Example
232    ///
233    /// ```rust,ignore
234    /// use llm_stack::registry::{ProviderRegistry, ProviderFactory};
235    ///
236    /// ProviderRegistry::global().register(Box::new(MyProviderFactory));
237    /// ```
238    pub fn register(&self, factory: Box<dyn ProviderFactory>) -> &Self {
239        let name = factory.name().to_lowercase();
240        let mut factories = self
241            .factories
242            .write()
243            .expect("provider registry lock poisoned");
244        factories.insert(name, Arc::from(factory));
245        self
246    }
247
248    /// Registers a provider factory (chainable Arc version).
249    ///
250    /// Use this when you want to share the factory instance.
251    pub fn register_shared(&self, factory: Arc<dyn ProviderFactory>) -> &Self {
252        let name = factory.name().to_lowercase();
253        let mut factories = self
254            .factories
255            .write()
256            .expect("provider registry lock poisoned");
257        factories.insert(name, factory);
258        self
259    }
260
261    /// Unregisters a provider by name.
262    ///
263    /// Returns `true` if the provider was registered and removed.
264    pub fn unregister(&self, name: &str) -> bool {
265        let mut factories = self
266            .factories
267            .write()
268            .expect("provider registry lock poisoned");
269        factories.remove(&name.to_lowercase()).is_some()
270    }
271
272    /// Checks if a provider is registered.
273    pub fn contains(&self, name: &str) -> bool {
274        let factories = self
275            .factories
276            .read()
277            .expect("provider registry lock poisoned");
278        factories.contains_key(&name.to_lowercase())
279    }
280
281    /// Returns the names of all registered providers.
282    pub fn providers(&self) -> Vec<String> {
283        let factories = self
284            .factories
285            .read()
286            .expect("provider registry lock poisoned");
287        factories.keys().cloned().collect()
288    }
289
290    /// Builds a provider from configuration.
291    ///
292    /// Looks up the factory by `config.provider` and delegates to it.
293    ///
294    /// # Errors
295    ///
296    /// Returns [`LlmError::InvalidRequest`] if no factory is registered
297    /// for the requested provider name.
298    pub fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
299        let name = config.provider.to_lowercase();
300        let factories = self
301            .factories
302            .read()
303            .expect("provider registry lock poisoned");
304
305        let factory = factories.get(&name).ok_or_else(|| {
306            let available: Vec<_> = factories.keys().cloned().collect();
307            LlmError::InvalidRequest(format!(
308                "unknown provider '{}'. Available: {:?}",
309                config.provider, available
310            ))
311        })?;
312
313        factory.build(config)
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatResponse, ContentBlock, StopReason};
    use crate::provider::{ChatParams, Provider, ProviderMetadata};
    use crate::stream::ChatStream;
    use crate::usage::Usage;
    use std::collections::{HashMap, HashSet};

    /// Minimal provider whose responses and metadata echo its model name.
    struct TestProvider {
        model: String,
    }

    impl Provider for TestProvider {
        async fn generate(&self, _params: &ChatParams) -> Result<ChatResponse, LlmError> {
            Ok(ChatResponse {
                content: vec![ContentBlock::Text("test".into())],
                usage: Usage::default(),
                stop_reason: StopReason::EndTurn,
                model: self.model.clone(),
                metadata: HashMap::default(),
            })
        }

        async fn stream(&self, _params: &ChatParams) -> Result<ChatStream, LlmError> {
            Err(LlmError::InvalidRequest("not implemented".into()))
        }

        fn metadata(&self) -> ProviderMetadata {
            ProviderMetadata {
                name: "test".into(),
                model: self.model.clone(),
                context_window: 4096,
                capabilities: HashSet::new(),
            }
        }
    }

    /// Factory registered under the name "test".
    struct TestFactory;

    impl ProviderFactory for TestFactory {
        fn name(&self) -> &'static str {
            "test"
        }

        fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
            Ok(Box::new(TestProvider {
                model: config.model.clone(),
            }))
        }
    }

    #[test]
    fn test_registry_register_and_build() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        assert!(registry.contains("test"));
        assert!(registry.contains("TEST")); // case insensitive

        let config = ProviderConfig::new("test", "test-model");
        let provider = registry.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "test-model");
    }

    #[test]
    fn test_registry_build_is_case_insensitive() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let config = ProviderConfig::new("TeSt", "model");
        let provider = registry.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "model");
    }

    #[test]
    fn test_registry_unknown_provider() {
        let registry = ProviderRegistry::new();

        let config = ProviderConfig::new("unknown", "model");
        let result = registry.build(&config);

        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(matches!(err, LlmError::InvalidRequest(_)));
    }

    #[test]
    fn test_registry_unknown_provider_message() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let config = ProviderConfig::new("missing", "model");
        let Err(LlmError::InvalidRequest(msg)) = registry.build(&config) else {
            panic!("expected InvalidRequest for an unregistered provider");
        };

        // The message names the requested provider and lists what is available.
        assert!(msg.contains("'missing'"));
        assert!(msg.contains("\"test\""));
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test")); // already removed
    }

    #[test]
    fn test_registry_register_shared() {
        let registry = ProviderRegistry::new();
        let factory: Arc<dyn ProviderFactory> = Arc::new(TestFactory);

        registry.register_shared(Arc::clone(&factory));
        assert!(registry.contains("test"));
        // The registry keeps its own handle to the shared factory.
        assert_eq!(Arc::strong_count(&factory), 2);

        assert!(registry.unregister("TEST")); // removal is case-insensitive
        assert_eq!(Arc::strong_count(&factory), 1);
    }

    #[test]
    fn test_registry_providers_list() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let providers = registry.providers();
        assert_eq!(providers, vec!["test"]);
    }

    #[test]
    fn test_global_registry_is_singleton() {
        assert!(std::ptr::eq(
            ProviderRegistry::global(),
            ProviderRegistry::global()
        ));
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("anthropic", "claude-3")
            .api_key("sk-123")
            .base_url("https://custom.api")
            .timeout(Duration::from_secs(60))
            .extra("organization", "org-123");

        assert_eq!(config.provider, "anthropic");
        assert_eq!(config.model, "claude-3");
        assert_eq!(config.api_key, Some("sk-123".into()));
        assert_eq!(config.base_url, Some("https://custom.api".into()));
        assert_eq!(config.timeout, Some(Duration::from_secs(60)));
        assert_eq!(config.get_extra_str("organization"), Some("org-123"));
    }

    #[test]
    fn test_provider_config_extra_types() {
        let config = ProviderConfig::new("test", "model")
            .extra("flag", true)
            .extra("count", 42i64)
            .extra("name", "value");

        assert_eq!(config.get_extra_bool("flag"), Some(true));
        assert_eq!(config.get_extra_i64("count"), Some(42));
        assert_eq!(config.get_extra_str("name"), Some("value"));
        assert_eq!(config.get_extra_str("missing"), None);
    }

    #[test]
    fn test_provider_config_extra_type_mismatch() {
        let config = ProviderConfig::new("test", "model")
            .extra("flag", true)
            .extra("count", 42i64);

        // Typed getters return None rather than coercing between JSON types.
        assert_eq!(config.get_extra_str("flag"), None);
        assert_eq!(config.get_extra_i64("flag"), None);
        assert_eq!(config.get_extra_bool("count"), None);
    }

    #[tokio::test]
    async fn test_built_provider_works() {
        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));

        let config = ProviderConfig::new("test", "my-model");
        let provider = registry.build(&config).unwrap();

        let response = provider
            .generate_boxed(&ChatParams::default())
            .await
            .unwrap();
        assert_eq!(response.model, "my-model");
    }

    #[test]
    fn test_registry_replace_factory() {
        struct AltFactory;
        impl ProviderFactory for AltFactory {
            fn name(&self) -> &'static str {
                "test"
            }
            fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
                Ok(Box::new(TestProvider {
                    model: format!("alt-{}", config.model),
                }))
            }
        }

        let registry = ProviderRegistry::new();
        registry.register(Box::new(TestFactory));
        registry.register(Box::new(AltFactory)); // replaces

        let config = ProviderConfig::new("test", "model");
        let provider = registry.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "alt-model");
    }
}