Skip to main content

ares/llm/
provider_registry.rs

//! Provider registry for managing multiple LLM providers.
//!
//! This module provides a registry of named LLM providers and models
//! that can be configured via TOML configuration.
5
6use crate::llm::client::{LLMClient, Provider};
7use crate::types::{AppError, Result};
8use crate::utils::toml_config::{AresConfig, ModelConfig, ProviderConfig};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// Registry for managing multiple named LLM providers
13///
14/// The ProviderRegistry holds references to provider configurations and allows
15/// creating LLM clients for specific models or providers by name.
16pub struct ProviderRegistry {
17    /// Provider configurations keyed by name
18    providers: HashMap<String, ProviderConfig>,
19    /// Model configurations keyed by name
20    models: HashMap<String, ModelConfig>,
21    /// Default model name to use when none specified
22    default_model: Option<String>,
23}
24
25impl ProviderRegistry {
26    /// Create a new empty provider registry
27    pub fn new() -> Self {
28        Self {
29            providers: HashMap::new(),
30            models: HashMap::new(),
31            default_model: None,
32        }
33    }
34
35    /// Create a provider registry from TOML configuration
36    pub fn from_config(config: &AresConfig) -> Self {
37        Self {
38            providers: config.providers.clone(),
39            models: config.models.clone(),
40            default_model: config.models.keys().next().cloned(),
41        }
42    }
43
44    /// Set the default model name
45    pub fn set_default_model(&mut self, model_name: &str) {
46        self.default_model = Some(model_name.to_string());
47    }
48
49    /// Register a provider configuration
50    pub fn register_provider(&mut self, name: &str, config: ProviderConfig) {
51        self.providers.insert(name.to_string(), config);
52    }
53
54    /// Register a model configuration
55    pub fn register_model(&mut self, name: &str, config: ModelConfig) {
56        self.models.insert(name.to_string(), config);
57    }
58
59    /// Get a provider configuration by name
60    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
61        self.providers.get(name)
62    }
63
64    /// Get a model configuration by name
65    pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
66        self.models.get(name)
67    }
68
69    /// Get all provider names
70    pub fn provider_names(&self) -> Vec<&str> {
71        self.providers.keys().map(|s| s.as_str()).collect()
72    }
73
74    /// Get all model names
75    pub fn model_names(&self) -> Vec<&str> {
76        self.models.keys().map(|s| s.as_str()).collect()
77    }
78
79    /// Create an LLM client for a specific model by name
80    ///
81    /// This resolves the model -> provider chain and creates the appropriate client.
82    pub async fn create_client_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
83        let model_config = self.get_model(model_name).ok_or_else(|| {
84            AppError::Configuration(format!("Model '{}' not found in configuration", model_name))
85        })?;
86
87        let provider_config = self.get_provider(&model_config.provider).ok_or_else(|| {
88            AppError::Configuration(format!(
89                "Provider '{}' referenced by model '{}' not found",
90                model_config.provider, model_name
91            ))
92        })?;
93
94        let provider = Provider::from_model_config(model_config, provider_config)?;
95        provider.create_client().await
96    }
97
98    /// Create an LLM client for a specific provider by name
99    ///
100    /// Uses the provider's default model.
101    pub async fn create_client_for_provider(
102        &self,
103        provider_name: &str,
104    ) -> Result<Box<dyn LLMClient>> {
105        let provider_config = self.get_provider(provider_name).ok_or_else(|| {
106            AppError::Configuration(format!(
107                "Provider '{}' not found in configuration",
108                provider_name
109            ))
110        })?;
111
112        let provider = Provider::from_config(provider_config, None)?;
113        provider.create_client().await
114    }
115
116    /// Create an LLM client using the default model
117    pub async fn create_default_client(&self) -> Result<Box<dyn LLMClient>> {
118        let model_name = self
119            .default_model
120            .as_ref()
121            .ok_or_else(|| AppError::Configuration("No default model configured".into()))?;
122
123        self.create_client_for_model(model_name).await
124    }
125
126    /// Check if a model exists in the registry
127    pub fn has_model(&self, name: &str) -> bool {
128        self.models.contains_key(name)
129    }
130
131    /// Check if a provider exists in the registry
132    pub fn has_provider(&self, name: &str) -> bool {
133        self.providers.contains_key(name)
134    }
135}
136
137impl Default for ProviderRegistry {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143/// Configuration-based LLM client factory using the provider registry
144///
145/// This is the new factory that uses TOML configuration instead of environment variables.
146pub struct ConfigBasedLLMFactory {
147    registry: Arc<ProviderRegistry>,
148    default_model: String,
149}
150
151impl ConfigBasedLLMFactory {
152    /// Create a new factory from a provider registry
153    pub fn new(registry: Arc<ProviderRegistry>, default_model: &str) -> Self {
154        Self {
155            registry,
156            default_model: default_model.to_string(),
157        }
158    }
159
160    /// Create a factory from TOML configuration
161    pub fn from_config(config: &AresConfig) -> Result<Self> {
162        let registry = ProviderRegistry::from_config(config);
163
164        // Get the first model as default, or error if no models defined
165        let default_model =
166            config.models.keys().next().cloned().ok_or_else(|| {
167                AppError::Configuration("No models defined in configuration".into())
168            })?;
169
170        Ok(Self {
171            registry: Arc::new(registry),
172            default_model,
173        })
174    }
175
176    /// Get the provider registry
177    pub fn registry(&self) -> &Arc<ProviderRegistry> {
178        &self.registry
179    }
180
181    /// Create an LLM client for a specific model
182    pub async fn create_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
183        self.registry.create_client_for_model(model_name).await
184    }
185
186    /// Create an LLM client using the default model
187    pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
188        self.registry
189            .create_client_for_model(&self.default_model)
190            .await
191    }
192
193    /// Get the default model name
194    pub fn default_model(&self) -> &str {
195        &self.default_model
196    }
197
198    /// Set the default model name
199    pub fn set_default_model(&mut self, model_name: &str) {
200        self.default_model = model_name.to_string();
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the local Ollama provider configuration shared by the tests below.
    fn local_ollama() -> ProviderConfig {
        ProviderConfig::Ollama {
            base_url: "http://localhost:11434".to_string(),
            default_model: "ministral-3:3b".to_string(),
        }
    }

    #[test]
    fn test_empty_registry() {
        let empty = ProviderRegistry::new();
        assert_eq!(empty.provider_names().len(), 0);
        assert_eq!(empty.model_names().len(), 0);
    }

    #[test]
    fn test_register_provider() {
        let mut reg = ProviderRegistry::new();
        reg.register_provider("ollama-local", local_ollama());

        assert!(reg.has_provider("ollama-local"));
        assert!(!reg.has_provider("nonexistent"));
    }

    #[test]
    fn test_register_model() {
        let mut reg = ProviderRegistry::new();
        reg.register_provider("ollama-local", local_ollama());

        let fast = ModelConfig {
            provider: "ollama-local".to_string(),
            model: "ministral-3:3b".to_string(),
            temperature: 0.7,
            max_tokens: 256,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
        };
        reg.register_model("fast", fast);

        assert!(reg.has_model("fast"));
        assert!(!reg.has_model("nonexistent"));
    }
}