1use crate::llm::capabilities::{CapabilityRequirements, ModelCapabilities, ModelWithCapabilities};
24use crate::llm::client::{LLMClient, Provider};
25use crate::types::{AppError, Result};
26use crate::utils::toml_config::{AresConfig, ModelConfig, ProviderConfig};
27use std::collections::HashMap;
28use std::sync::Arc;
29
/// Registry mapping user-facing provider and model names to their parsed
/// configuration, used to construct `LLMClient` instances on demand.
pub struct ProviderRegistry {
    // Provider name -> provider connection/auth configuration.
    providers: HashMap<String, ProviderConfig>,
    // Model alias -> model configuration (references a provider by name).
    models: HashMap<String, ModelConfig>,
    // Model alias used by `create_default_client`; `None` until configured.
    default_model: Option<String>,
}
42
43impl ProviderRegistry {
44 pub fn new() -> Self {
46 Self {
47 providers: HashMap::new(),
48 models: HashMap::new(),
49 default_model: None,
50 }
51 }
52
53 pub fn from_config(config: &AresConfig) -> Self {
55 Self {
56 providers: config.providers.clone(),
57 models: config.models.clone(),
58 default_model: config.models.keys().next().cloned(),
59 }
60 }
61
62 pub fn set_default_model(&mut self, model_name: &str) {
64 self.default_model = Some(model_name.to_string());
65 }
66
67 pub fn register_provider(&mut self, name: &str, config: ProviderConfig) {
69 self.providers.insert(name.to_string(), config);
70 }
71
72 pub fn register_model(&mut self, name: &str, config: ModelConfig) {
74 self.models.insert(name.to_string(), config);
75 }
76
77 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
79 self.providers.get(name)
80 }
81
82 pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
84 self.models.get(name)
85 }
86
87 pub fn provider_names(&self) -> Vec<&str> {
89 self.providers.keys().map(|s| s.as_str()).collect()
90 }
91
92 pub fn model_names(&self) -> Vec<&str> {
94 self.models.keys().map(|s| s.as_str()).collect()
95 }
96
97 pub async fn create_client_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
101 let model_config = self.get_model(model_name).ok_or_else(|| {
102 AppError::Configuration(format!("Model '{}' not found in configuration", model_name))
103 })?;
104
105 let provider_config = self.get_provider(&model_config.provider).ok_or_else(|| {
106 AppError::Configuration(format!(
107 "Provider '{}' referenced by model '{}' not found",
108 model_config.provider, model_name
109 ))
110 })?;
111
112 let provider = Provider::from_model_config(model_config, provider_config)?;
113 provider.create_client().await
114 }
115
116 pub async fn create_client_for_provider(
120 &self,
121 provider_name: &str,
122 ) -> Result<Box<dyn LLMClient>> {
123 let provider_config = self.get_provider(provider_name).ok_or_else(|| {
124 AppError::Configuration(format!(
125 "Provider '{}' not found in configuration",
126 provider_name
127 ))
128 })?;
129
130 let provider = Provider::from_config(provider_config, None)?;
131 provider.create_client().await
132 }
133
134 pub async fn create_default_client(&self) -> Result<Box<dyn LLMClient>> {
136 let model_name = self
137 .default_model
138 .as_ref()
139 .ok_or_else(|| AppError::Configuration("No default model configured".into()))?;
140
141 self.create_client_for_model(model_name).await
142 }
143
144 pub fn has_model(&self, name: &str) -> bool {
146 self.models.contains_key(name)
147 }
148
149 pub fn has_provider(&self, name: &str) -> bool {
151 self.providers.contains_key(name)
152 }
153
154 pub fn get_model_capabilities(&self, model_name: &str) -> Option<ModelCapabilities> {
161 let model_config = self.get_model(model_name)?;
162 let provider_config = self.get_provider(&model_config.provider)?;
163
164 let mut caps = ModelCapabilities::for_model(&model_config.model);
166
167 match provider_config {
169 ProviderConfig::Ollama { .. } => {
170 caps.is_local = true;
171 caps.cost_tier = "free".to_string();
172 }
173 ProviderConfig::LlamaCpp { .. } => {
174 caps.is_local = true;
175 caps.cost_tier = "free".to_string();
176 }
177 ProviderConfig::OpenAI { .. } => {
178 caps.is_local = false;
179 }
180 ProviderConfig::Anthropic { .. } => {
181 caps.is_local = false;
182 }
183 }
184
185 Some(caps)
186 }
187
188 pub fn models_with_capabilities(&self) -> Vec<ModelWithCapabilities> {
190 self.models
191 .iter()
192 .filter_map(|(name, config)| {
193 let caps = self.get_model_capabilities(name)?;
194 Some(ModelWithCapabilities {
195 name: name.clone(),
196 provider: config.provider.clone(),
197 model_id: config.model.clone(),
198 capabilities: caps,
199 })
200 })
201 .collect()
202 }
203
204 pub fn find_models(&self, requirements: &CapabilityRequirements) -> Vec<ModelWithCapabilities> {
208 let mut matches: Vec<_> = self
209 .models_with_capabilities()
210 .into_iter()
211 .filter(|m| m.capabilities.satisfies(requirements))
212 .collect();
213
214 matches.sort_by(|a, b| {
216 let score_a = a.capabilities.score(requirements);
217 let score_b = b.capabilities.score(requirements);
218 score_b.cmp(&score_a)
219 });
220
221 matches
222 }
223
224 pub fn find_best_model(
229 &self,
230 requirements: &CapabilityRequirements,
231 ) -> Option<ModelWithCapabilities> {
232 self.find_models(requirements).into_iter().next()
233 }
234
235 pub async fn create_client_for_requirements(
248 &self,
249 requirements: &CapabilityRequirements,
250 ) -> Result<Box<dyn LLMClient>> {
251 let model = self.find_best_model(requirements).ok_or_else(|| {
252 AppError::Configuration(format!(
253 "No model found matching requirements: {:?}",
254 requirements
255 ))
256 })?;
257
258 self.create_client_for_model(&model.name).await
259 }
260
261 pub fn find_agent_models(&self) -> Vec<ModelWithCapabilities> {
263 self.find_models(&CapabilityRequirements::for_agent())
264 }
265
266 pub fn find_vision_models(&self) -> Vec<ModelWithCapabilities> {
268 self.find_models(&CapabilityRequirements::for_vision())
269 }
270
271 pub fn find_coding_models(&self) -> Vec<ModelWithCapabilities> {
273 self.find_models(&CapabilityRequirements::for_coding())
274 }
275
276 pub fn find_local_models(&self) -> Vec<ModelWithCapabilities> {
278 self.find_models(&CapabilityRequirements::for_local())
279 }
280
281 pub fn list_models(&self) -> Vec<ModelInfo> {
283 self.models.iter().map(|(name, config)| ModelInfo {
284 name: name.clone(),
285 provider: config.provider.clone(),
286 model: config.model.clone(),
287 }).collect()
288 }
289}
290
/// Serializable summary of a registered model, as returned by
/// `ProviderRegistry::list_models`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ModelInfo {
    /// Registry alias for the model (key in the models map).
    pub name: String,
    /// Name of the provider the model references.
    pub provider: String,
    /// Provider-side model identifier.
    pub model: String,
}
298
299impl Default for ProviderRegistry {
300 fn default() -> Self {
301 Self::new()
302 }
303}
304
/// Factory that hands out `LLMClient`s backed by a shared `ProviderRegistry`,
/// remembering a default model name for convenience.
pub struct ConfigBasedLLMFactory {
    // Shared registry; `Arc` so the factory can hand out cheap shared handles.
    registry: Arc<ProviderRegistry>,
    // Model alias used by `create_default`.
    default_model: String,
}
312
313impl ConfigBasedLLMFactory {
314 pub fn new(registry: Arc<ProviderRegistry>, default_model: &str) -> Self {
316 Self {
317 registry,
318 default_model: default_model.to_string(),
319 }
320 }
321
322 pub fn from_config(config: &AresConfig) -> Result<Self> {
324 let registry = ProviderRegistry::from_config(config);
325
326 let default_model =
328 config.models.keys().next().cloned().ok_or_else(|| {
329 AppError::Configuration("No models defined in configuration".into())
330 })?;
331
332 Ok(Self {
333 registry: Arc::new(registry),
334 default_model,
335 })
336 }
337
338 pub fn registry(&self) -> &Arc<ProviderRegistry> {
340 &self.registry
341 }
342
343 pub async fn create_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
345 self.registry.create_client_for_model(model_name).await
346 }
347
348 pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
350 self.registry
351 .create_client_for_model(&self.default_model)
352 .await
353 }
354
355 pub fn default_model(&self) -> &str {
357 &self.default_model
358 }
359
360 pub fn set_default_model(&mut self, model_name: &str) {
362 self.default_model = model_name.to_string();
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::capabilities::CapabilityRequirements;

    // A fresh registry exposes no providers or models.
    #[test]
    fn test_empty_registry() {
        let registry = ProviderRegistry::new();
        assert!(registry.provider_names().is_empty());
        assert!(registry.model_names().is_empty());
    }

    // Registering a provider makes it visible via `has_provider`.
    #[test]
    fn test_register_provider() {
        let mut registry = ProviderRegistry::new();
        registry.register_provider(
            "ollama-local",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );

        assert!(registry.has_provider("ollama-local"));
        assert!(!registry.has_provider("nonexistent"));
    }

    // Registering a model makes it visible via `has_model`.
    #[test]
    fn test_register_model() {
        let mut registry = ProviderRegistry::new();
        registry.register_provider(
            "ollama-local",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );
        registry.register_model(
            "fast",
            ModelConfig {
                provider: "ollama-local".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 256,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        assert!(registry.has_model("fast"));
        assert!(!registry.has_model("nonexistent"));
    }

    // Shared fixture: one local provider (ollama) and two hosted providers
    // (anthropic, openai), with two local models and two hosted models.
    // Capability-related assertions below depend on the capability database
    // (`ModelCapabilities::for_model`) knowing these model identifiers.
    fn create_test_registry() -> ProviderRegistry {
        let mut registry = ProviderRegistry::new();

        registry.register_provider(
            "ollama",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "llama-3.3-70b-instruct".to_string(),
            },
        );

        registry.register_provider(
            "anthropic",
            ProviderConfig::Anthropic {
                api_key_env: "ANTHROPIC_API_KEY".to_string(),
                default_model: "claude-3-5-sonnet-20241022".to_string(),
            },
        );

        registry.register_provider(
            "openai",
            ProviderConfig::OpenAI {
                api_key_env: "OPENAI_API_KEY".to_string(),
                api_base: "https://api.openai.com/v1".to_string(),
                default_model: "gpt-4o".to_string(),
            },
        );

        registry.register_model(
            "fast-local",
            ModelConfig {
                provider: "ollama".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 512,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "powerful-local",
            ModelConfig {
                provider: "ollama".to_string(),
                model: "llama-3.3-70b-instruct".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "claude-sonnet",
            ModelConfig {
                provider: "anthropic".to_string(),
                model: "claude-3-5-sonnet-20241022".to_string(),
                temperature: 0.7,
                max_tokens: 4096,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "gpt4o",
            ModelConfig {
                provider: "openai".to_string(),
                model: "gpt-4o-2024-08-06".to_string(),
                temperature: 0.7,
                max_tokens: 4096,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry
    }

    // Provider kind overrides locality/cost; model id drives base caps.
    #[test]
    fn test_get_model_capabilities() {
        let registry = create_test_registry();

        let fast_caps = registry.get_model_capabilities("fast-local").unwrap();
        assert!(fast_caps.is_local);
        assert_eq!(fast_caps.cost_tier, "free");
        assert!(fast_caps.supports_tools);

        let claude_caps = registry.get_model_capabilities("claude-sonnet").unwrap();
        assert!(!claude_caps.is_local);
        assert!(claude_caps.supports_tools);
        assert!(claude_caps.supports_vision);
        // 200k context expected from the capability database for this model.
        assert_eq!(claude_caps.context_window, 200_000);
    }

    // All four fixture models resolve capabilities (no provider is missing).
    #[test]
    fn test_models_with_capabilities() {
        let registry = create_test_registry();
        let models = registry.models_with_capabilities();

        assert_eq!(models.len(), 4);

        for model in &models {
            assert!(!model.name.is_empty());
            assert!(!model.provider.is_empty());
            assert!(model.capabilities.supports_tools);
        }
    }

    // Only the two ollama-backed models qualify as local/free.
    #[test]
    fn test_find_local_models() {
        let registry = create_test_registry();
        let local_models = registry.find_local_models();

        assert_eq!(local_models.len(), 2);
        for model in &local_models {
            assert!(model.capabilities.is_local);
            assert_eq!(model.capabilities.cost_tier, "free");
        }
    }

    // Expected: exactly the two hosted models (claude, gpt-4o) support vision.
    #[test]
    fn test_find_vision_models() {
        let registry = create_test_registry();
        let vision_models = registry.find_vision_models();

        assert_eq!(vision_models.len(), 2);
        for model in &vision_models {
            assert!(model.capabilities.supports_vision);
        }
    }

    // The agent preset must yield a tool-capable, production-ready model.
    #[test]
    fn test_find_best_model_for_agent() {
        let registry = create_test_registry();

        let requirements = CapabilityRequirements::for_agent();
        let best = registry.find_best_model(&requirements);

        assert!(best.is_some());
        let best = best.unwrap();
        assert!(best.capabilities.supports_tools);
        assert!(best.capabilities.production_ready);
    }

    // Context-window filter excludes models below the threshold.
    #[test]
    fn test_find_best_model_with_context_window() {
        let registry = create_test_registry();

        let requirements = CapabilityRequirements::builder()
            .min_context_window(100_000)
            .build();

        let matches = registry.find_models(&requirements);

        assert!(matches.len() >= 2);
        for model in &matches {
            assert!(model.capabilities.context_window >= 100_000);
        }
    }

    // Scoring should prefer local/free models when requirements allow.
    #[test]
    fn test_find_best_model_prefers_cheaper() {
        let registry = create_test_registry();

        let requirements = CapabilityRequirements::builder().requires_tools().build();

        let best = registry.find_best_model(&requirements).unwrap();

        assert!(
            best.capabilities.is_local || best.capabilities.cost_tier == "free",
            "Expected best model to be local/free, got: {} (cost: {})",
            best.name,
            best.capabilities.cost_tier
        );
    }

    // Fixture has no local vision model, so local+vision matches nothing.
    #[test]
    fn test_no_model_matches_impossible_requirements() {
        let registry = create_test_registry();

        let requirements = CapabilityRequirements::builder()
            .requires_local()
            .requires_vision()
            .build();

        let matches = registry.find_models(&requirements);
        assert!(matches.is_empty());
    }

    // Every coding-preset match must have tools, reasoning, and >=32k context.
    #[test]
    fn test_find_coding_models() {
        let registry = create_test_registry();
        let coding_models = registry.find_coding_models();

        for model in &coding_models {
            assert!(model.capabilities.supports_tools);
            assert!(model.capabilities.supports_reasoning);
            assert!(model.capabilities.context_window >= 32_000);
        }
    }
}