Skip to main content

codetether_agent/provider/
registry.rs

//! [`ProviderRegistry`] — name → provider map with resolution.
//!
//! Holds all initialised providers and resolves `"provider/model"` strings
//! to the correct [`Provider`] instance.
//!
//! # Examples
//!
//! ```rust
//! use codetether_agent::provider::ProviderRegistry;
//!
//! let registry = ProviderRegistry::new();
//! assert!(registry.list().is_empty());
//! ```

15use super::parse::parse_model_string;
16use super::traits::Provider;
17use anyhow::Result;
18use std::collections::HashMap;
19use std::sync::Arc;
20
21/// Registry of available providers.
22pub struct ProviderRegistry {
23    pub(crate) providers: HashMap<String, Arc<dyn Provider>>,
24}
25
26impl std::fmt::Debug for ProviderRegistry {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("ProviderRegistry")
29            .field("provider_count", &self.providers.len())
30            .field("providers", &self.providers.keys().collect::<Vec<_>>())
31            .finish()
32    }
33}
34
35impl ProviderRegistry {
36    /// Create an empty registry.
37    ///
38    /// # Examples
39    ///
40    /// ```rust
41    /// use codetether_agent::provider::ProviderRegistry;
42    /// let registry = ProviderRegistry::new();
43    /// assert!(registry.list().is_empty());
44    /// ```
45    pub fn new() -> Self {
46        Self {
47            providers: HashMap::new(),
48        }
49    }
50
51    /// Register a provider (automatically wrapped with metrics instrumentation).
52    ///
53    /// # Examples
54    ///
55    /// ```rust,no_run
56    /// use codetether_agent::provider::ProviderRegistry;
57    /// use std::sync::Arc;
58    /// # fn demo(registry: &mut ProviderRegistry, p: Arc<dyn codetether_agent::provider::Provider>) {
59    /// registry.register(p);
60    /// # }
61    /// ```
62    pub fn register(&mut self, provider: Arc<dyn Provider>) {
63        let name = provider.name().to_string();
64        let wrapped = super::metrics::MetricsProvider::wrap(provider);
65        self.providers.insert(name, wrapped);
66    }
67
68    /// Get a provider by name.
69    ///
70    /// # Examples
71    ///
72    /// ```rust
73    /// use codetether_agent::provider::ProviderRegistry;
74    /// let registry = ProviderRegistry::new();
75    /// assert!(registry.get("openai").is_none());
76    /// ```
77    pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
78        self.providers.get(name).cloned()
79    }
80
81    /// List all registered provider names.
82    ///
83    /// # Examples
84    ///
85    /// ```rust
86    /// use codetether_agent::provider::ProviderRegistry;
87    /// let registry = ProviderRegistry::new();
88    /// assert!(registry.list().is_empty());
89    /// ```
90    pub fn list(&self) -> Vec<&str> {
91        self.providers.keys().map(|s| s.as_str()).collect()
92    }
93
94    /// Resolve a model string to a provider and model name.
95    ///
96    /// Accepts:
97    /// - `"provider/model"` (e.g. `"openai/gpt-4o"`)
98    /// - `"model"` alone (uses first available provider)
99    ///
100    /// # Examples
101    ///
102    /// ```rust
103    /// use codetether_agent::provider::ProviderRegistry;
104    ///
105    /// let registry = ProviderRegistry::new();
106    /// // No providers ⇒ error
107    /// assert!(registry.resolve_model("gpt-4o").is_err());
108    /// ```
109    pub fn resolve_model(&self, model_str: &str) -> Result<(Arc<dyn Provider>, String)> {
110        let (provider_name, model) = parse_model_string(model_str);
111
112        if let Some(provider_name) = provider_name {
113            let normalized = match provider_name {
114                "local-cuda" | "localcuda" => "local_cuda",
115                "zhipuai" => "zai",
116                other => other,
117            };
118
119            let provider = self.providers.get(normalized).cloned().ok_or_else(|| {
120                anyhow::anyhow!(
121                    "Provider '{}' not found. Available: {:?}",
122                    normalized,
123                    self.list()
124                )
125            })?;
126            Ok((provider, model.to_string()))
127        } else {
128            let first_provider = self
129                .providers
130                .values()
131                .next()
132                .ok_or_else(|| anyhow::anyhow!("No providers available in registry"))?;
133            Ok((first_provider.clone(), model_str.to_string()))
134        }
135    }
136}
137
138impl Default for ProviderRegistry {
139    fn default() -> Self {
140        Self::new()
141    }
142}