ares/llm/
provider_registry.rs1use crate::llm::client::{LLMClient, Provider};
7use crate::types::{AppError, Result};
8use crate::utils::toml_config::{AresConfig, ModelConfig, ProviderConfig};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12pub struct ProviderRegistry {
17 providers: HashMap<String, ProviderConfig>,
19 models: HashMap<String, ModelConfig>,
21 default_model: Option<String>,
23}
24
25impl ProviderRegistry {
26 pub fn new() -> Self {
28 Self {
29 providers: HashMap::new(),
30 models: HashMap::new(),
31 default_model: None,
32 }
33 }
34
35 pub fn from_config(config: &AresConfig) -> Self {
37 Self {
38 providers: config.providers.clone(),
39 models: config.models.clone(),
40 default_model: config.models.keys().next().cloned(),
41 }
42 }
43
44 pub fn set_default_model(&mut self, model_name: &str) {
46 self.default_model = Some(model_name.to_string());
47 }
48
49 pub fn register_provider(&mut self, name: &str, config: ProviderConfig) {
51 self.providers.insert(name.to_string(), config);
52 }
53
54 pub fn register_model(&mut self, name: &str, config: ModelConfig) {
56 self.models.insert(name.to_string(), config);
57 }
58
59 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
61 self.providers.get(name)
62 }
63
64 pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
66 self.models.get(name)
67 }
68
69 pub fn provider_names(&self) -> Vec<&str> {
71 self.providers.keys().map(|s| s.as_str()).collect()
72 }
73
74 pub fn model_names(&self) -> Vec<&str> {
76 self.models.keys().map(|s| s.as_str()).collect()
77 }
78
79 pub async fn create_client_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
83 let model_config = self.get_model(model_name).ok_or_else(|| {
84 AppError::Configuration(format!("Model '{}' not found in configuration", model_name))
85 })?;
86
87 let provider_config = self.get_provider(&model_config.provider).ok_or_else(|| {
88 AppError::Configuration(format!(
89 "Provider '{}' referenced by model '{}' not found",
90 model_config.provider, model_name
91 ))
92 })?;
93
94 let provider = Provider::from_model_config(model_config, provider_config)?;
95 provider.create_client().await
96 }
97
98 pub async fn create_client_for_provider(
102 &self,
103 provider_name: &str,
104 ) -> Result<Box<dyn LLMClient>> {
105 let provider_config = self.get_provider(provider_name).ok_or_else(|| {
106 AppError::Configuration(format!(
107 "Provider '{}' not found in configuration",
108 provider_name
109 ))
110 })?;
111
112 let provider = Provider::from_config(provider_config, None)?;
113 provider.create_client().await
114 }
115
116 pub async fn create_default_client(&self) -> Result<Box<dyn LLMClient>> {
118 let model_name = self
119 .default_model
120 .as_ref()
121 .ok_or_else(|| AppError::Configuration("No default model configured".into()))?;
122
123 self.create_client_for_model(model_name).await
124 }
125
126 pub fn has_model(&self, name: &str) -> bool {
128 self.models.contains_key(name)
129 }
130
131 pub fn has_provider(&self, name: &str) -> bool {
133 self.providers.contains_key(name)
134 }
135}
136
137impl Default for ProviderRegistry {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143pub struct ConfigBasedLLMFactory {
147 registry: Arc<ProviderRegistry>,
148 default_model: String,
149}
150
151impl ConfigBasedLLMFactory {
152 pub fn new(registry: Arc<ProviderRegistry>, default_model: &str) -> Self {
154 Self {
155 registry,
156 default_model: default_model.to_string(),
157 }
158 }
159
160 pub fn from_config(config: &AresConfig) -> Result<Self> {
162 let registry = ProviderRegistry::from_config(config);
163
164 let default_model =
166 config.models.keys().next().cloned().ok_or_else(|| {
167 AppError::Configuration("No models defined in configuration".into())
168 })?;
169
170 Ok(Self {
171 registry: Arc::new(registry),
172 default_model,
173 })
174 }
175
176 pub fn registry(&self) -> &Arc<ProviderRegistry> {
178 &self.registry
179 }
180
181 pub async fn create_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
183 self.registry.create_client_for_model(model_name).await
184 }
185
186 pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
188 self.registry
189 .create_client_for_model(&self.default_model)
190 .await
191 }
192
193 pub fn default_model(&self) -> &str {
195 &self.default_model
196 }
197
198 pub fn set_default_model(&mut self, model_name: &str) {
200 self.default_model = model_name.to_string();
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Name under which the local Ollama provider is registered in these tests.
    const LOCAL: &str = "ollama-local";

    /// Provider configuration for a local Ollama server.
    fn local_ollama() -> ProviderConfig {
        ProviderConfig::Ollama {
            base_url: String::from("http://localhost:11434"),
            default_model: String::from("ministral-3:3b"),
        }
    }

    /// A registry with only the local Ollama provider registered.
    fn registry_with_local_provider() -> ProviderRegistry {
        let mut reg = ProviderRegistry::new();
        reg.register_provider(LOCAL, local_ollama());
        reg
    }

    #[test]
    fn test_empty_registry() {
        let reg = ProviderRegistry::new();
        assert!(reg.model_names().is_empty());
        assert!(reg.provider_names().is_empty());
    }

    #[test]
    fn test_register_provider() {
        let reg = registry_with_local_provider();

        assert!(!reg.has_provider("nonexistent"));
        assert!(reg.has_provider(LOCAL));
    }

    #[test]
    fn test_register_model() {
        let mut reg = registry_with_local_provider();
        let fast = ModelConfig {
            provider: String::from(LOCAL),
            model: String::from("ministral-3:3b"),
            temperature: 0.7,
            max_tokens: 256,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
        };
        reg.register_model("fast", fast);

        assert!(!reg.has_model("nonexistent"));
        assert!(reg.has_model("fast"));
    }
}