1use crate::llm::capabilities::{CapabilityRequirements, ModelCapabilities, ModelWithCapabilities};
24use crate::llm::client::{LLMClient, Provider};
25use crate::types::{AppError, Result};
26use crate::utils::toml_config::{AresConfig, ModelConfig, ProviderConfig};
27use std::collections::HashMap;
28use std::sync::Arc;
29
/// Registry of configured LLM providers and the named models that reference them.
///
/// Built either empty (via [`ProviderRegistry::new`]) or from a parsed
/// [`AresConfig`], and used to construct [`LLMClient`] instances on demand.
pub struct ProviderRegistry {
    /// Provider configurations keyed by provider name.
    providers: HashMap<String, ProviderConfig>,
    /// Model configurations keyed by model alias (the user-facing name).
    models: HashMap<String, ModelConfig>,
    /// Alias of the model used when no explicit model is requested, if any.
    default_model: Option<String>,
}
42
43impl ProviderRegistry {
44 pub fn new() -> Self {
46 Self {
47 providers: HashMap::new(),
48 models: HashMap::new(),
49 default_model: None,
50 }
51 }
52
53 pub fn from_config(config: &AresConfig) -> Self {
55 Self {
56 providers: config.providers.clone(),
57 models: config.models.clone(),
58 default_model: config.models.keys().next().cloned(),
59 }
60 }
61
62 pub fn set_default_model(&mut self, model_name: &str) {
64 self.default_model = Some(model_name.to_string());
65 }
66
67 pub fn register_provider(&mut self, name: &str, config: ProviderConfig) {
69 self.providers.insert(name.to_string(), config);
70 }
71
72 pub fn register_model(&mut self, name: &str, config: ModelConfig) {
74 self.models.insert(name.to_string(), config);
75 }
76
77 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
79 self.providers.get(name)
80 }
81
82 pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
84 self.models.get(name)
85 }
86
87 pub fn provider_names(&self) -> Vec<&str> {
89 self.providers.keys().map(|s| s.as_str()).collect()
90 }
91
92 pub fn model_names(&self) -> Vec<&str> {
94 self.models.keys().map(|s| s.as_str()).collect()
95 }
96
97 pub async fn create_client_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
101 let model_config = self.get_model(model_name).ok_or_else(|| {
102 AppError::Configuration(format!("Model '{}' not found in configuration", model_name))
103 })?;
104
105 let provider_config = self.get_provider(&model_config.provider).ok_or_else(|| {
106 AppError::Configuration(format!(
107 "Provider '{}' referenced by model '{}' not found",
108 model_config.provider, model_name
109 ))
110 })?;
111
112 let provider = Provider::from_model_config(model_config, provider_config)?;
113 provider.create_client().await
114 }
115
116 pub async fn create_client_for_provider(
120 &self,
121 provider_name: &str,
122 ) -> Result<Box<dyn LLMClient>> {
123 let provider_config = self.get_provider(provider_name).ok_or_else(|| {
124 AppError::Configuration(format!(
125 "Provider '{}' not found in configuration",
126 provider_name
127 ))
128 })?;
129
130 let provider = Provider::from_config(provider_config, None)?;
131 provider.create_client().await
132 }
133
134 pub async fn create_default_client(&self) -> Result<Box<dyn LLMClient>> {
136 let model_name = self
137 .default_model
138 .as_ref()
139 .ok_or_else(|| AppError::Configuration("No default model configured".into()))?;
140
141 self.create_client_for_model(model_name).await
142 }
143
144 pub fn has_model(&self, name: &str) -> bool {
146 self.models.contains_key(name)
147 }
148
149 pub fn has_provider(&self, name: &str) -> bool {
151 self.providers.contains_key(name)
152 }
153
154 pub fn get_model_capabilities(&self, model_name: &str) -> Option<ModelCapabilities> {
161 let model_config = self.get_model(model_name)?;
162 let provider_config = self.get_provider(&model_config.provider)?;
163
164 let mut caps = ModelCapabilities::for_model(&model_config.model);
166
167 match provider_config {
169 ProviderConfig::Ollama { .. } => {
170 caps.is_local = true;
171 caps.cost_tier = "free".to_string();
172 }
173 ProviderConfig::LlamaCpp { .. } => {
174 caps.is_local = true;
175 caps.cost_tier = "free".to_string();
176 }
177 ProviderConfig::OpenAI { .. } => {
178 caps.is_local = false;
179 }
180 ProviderConfig::Anthropic { .. } => {
181 caps.is_local = false;
182 }
183 }
184
185 Some(caps)
186 }
187
188 pub fn models_with_capabilities(&self) -> Vec<ModelWithCapabilities> {
190 self.models
191 .iter()
192 .filter_map(|(name, config)| {
193 let caps = self.get_model_capabilities(name)?;
194 Some(ModelWithCapabilities {
195 name: name.clone(),
196 provider: config.provider.clone(),
197 model_id: config.model.clone(),
198 capabilities: caps,
199 })
200 })
201 .collect()
202 }
203
204 pub fn find_models(&self, requirements: &CapabilityRequirements) -> Vec<ModelWithCapabilities> {
208 let mut matches: Vec<_> = self
209 .models_with_capabilities()
210 .into_iter()
211 .filter(|m| m.capabilities.satisfies(requirements))
212 .collect();
213
214 matches.sort_by(|a, b| {
216 let score_a = a.capabilities.score(requirements);
217 let score_b = b.capabilities.score(requirements);
218 score_b.cmp(&score_a)
219 });
220
221 matches
222 }
223
224 pub fn find_best_model(
229 &self,
230 requirements: &CapabilityRequirements,
231 ) -> Option<ModelWithCapabilities> {
232 self.find_models(requirements).into_iter().next()
233 }
234
235 pub async fn create_client_for_requirements(
248 &self,
249 requirements: &CapabilityRequirements,
250 ) -> Result<Box<dyn LLMClient>> {
251 let model = self.find_best_model(requirements).ok_or_else(|| {
252 AppError::Configuration(format!(
253 "No model found matching requirements: {:?}",
254 requirements
255 ))
256 })?;
257
258 self.create_client_for_model(&model.name).await
259 }
260
261 pub fn find_agent_models(&self) -> Vec<ModelWithCapabilities> {
263 self.find_models(&CapabilityRequirements::for_agent())
264 }
265
266 pub fn find_vision_models(&self) -> Vec<ModelWithCapabilities> {
268 self.find_models(&CapabilityRequirements::for_vision())
269 }
270
271 pub fn find_coding_models(&self) -> Vec<ModelWithCapabilities> {
273 self.find_models(&CapabilityRequirements::for_coding())
274 }
275
276 pub fn find_local_models(&self) -> Vec<ModelWithCapabilities> {
278 self.find_models(&CapabilityRequirements::for_local())
279 }
280
281 pub fn list_models(&self) -> Vec<ModelInfo> {
283 self.models
284 .iter()
285 .map(|(name, config)| ModelInfo {
286 name: name.clone(),
287 provider: config.provider.clone(),
288 model: config.model.clone(),
289 })
290 .collect()
291 }
292}
293
/// Serializable summary of one registered model, as returned by
/// [`ProviderRegistry::list_models`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct ModelInfo {
    /// Registry alias the model was registered under.
    pub name: String,
    /// Name of the provider the model references.
    pub provider: String,
    /// Provider-side model identifier.
    pub model: String,
}
301
/// `Default` mirrors [`ProviderRegistry::new`]: an empty registry.
impl Default for ProviderRegistry {
    fn default() -> Self {
        Self::new()
    }
}
307
/// Factory that creates LLM clients from a shared [`ProviderRegistry`],
/// with a designated default model for callers that don't specify one.
pub struct ConfigBasedLLMFactory {
    /// Shared registry used to resolve models and providers.
    registry: Arc<ProviderRegistry>,
    /// Alias of the model used by [`ConfigBasedLLMFactory::create_default`].
    default_model: String,
}
315
316impl ConfigBasedLLMFactory {
317 pub fn new(registry: Arc<ProviderRegistry>, default_model: &str) -> Self {
319 Self {
320 registry,
321 default_model: default_model.to_string(),
322 }
323 }
324
325 pub fn from_config(config: &AresConfig) -> Result<Self> {
327 let registry = ProviderRegistry::from_config(config);
328
329 let default_model =
331 config.models.keys().next().cloned().ok_or_else(|| {
332 AppError::Configuration("No models defined in configuration".into())
333 })?;
334
335 Ok(Self {
336 registry: Arc::new(registry),
337 default_model,
338 })
339 }
340
341 pub fn registry(&self) -> &Arc<ProviderRegistry> {
343 &self.registry
344 }
345
346 pub async fn create_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
348 self.registry.create_client_for_model(model_name).await
349 }
350
351 pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
353 self.registry
354 .create_client_for_model(&self.default_model)
355 .await
356 }
357
358 pub fn default_model(&self) -> &str {
360 &self.default_model
361 }
362
363 pub fn set_default_model(&mut self, model_name: &str) {
365 self.default_model = model_name.to_string();
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::capabilities::CapabilityRequirements;

    // A freshly constructed registry exposes nothing.
    #[test]
    fn test_empty_registry() {
        let registry = ProviderRegistry::new();
        assert!(registry.provider_names().is_empty());
        assert!(registry.model_names().is_empty());
    }

    // Registering a provider makes it discoverable by exact name only.
    #[test]
    fn test_register_provider() {
        let mut registry = ProviderRegistry::new();
        registry.register_provider(
            "ollama-local",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );

        assert!(registry.has_provider("ollama-local"));
        assert!(!registry.has_provider("nonexistent"));
    }

    // Registering a model makes it discoverable by its alias.
    #[test]
    fn test_register_model() {
        let mut registry = ProviderRegistry::new();
        registry.register_provider(
            "ollama-local",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );
        registry.register_model(
            "fast",
            ModelConfig {
                provider: "ollama-local".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 256,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        assert!(registry.has_model("fast"));
        assert!(!registry.has_model("nonexistent"));
    }

    // Shared fixture: 3 providers (ollama, anthropic, openai) and 4 models —
    // two local (ollama-backed) and two remote (anthropic/openai-backed).
    fn create_test_registry() -> ProviderRegistry {
        let mut registry = ProviderRegistry::new();

        registry.register_provider(
            "ollama",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "llama-3.3-70b-instruct".to_string(),
            },
        );

        registry.register_provider(
            "anthropic",
            ProviderConfig::Anthropic {
                api_key_env: "ANTHROPIC_API_KEY".to_string(),
                default_model: "claude-3-5-sonnet-20241022".to_string(),
            },
        );

        registry.register_provider(
            "openai",
            ProviderConfig::OpenAI {
                api_key_env: "OPENAI_API_KEY".to_string(),
                api_base: "https://api.openai.com/v1".to_string(),
                default_model: "gpt-4o".to_string(),
            },
        );

        registry.register_model(
            "fast-local",
            ModelConfig {
                provider: "ollama".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 512,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "powerful-local",
            ModelConfig {
                provider: "ollama".to_string(),
                model: "llama-3.3-70b-instruct".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "claude-sonnet",
            ModelConfig {
                provider: "anthropic".to_string(),
                model: "claude-3-5-sonnet-20241022".to_string(),
                temperature: 0.7,
                max_tokens: 4096,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "gpt4o",
            ModelConfig {
                provider: "openai".to_string(),
                model: "gpt-4o-2024-08-06".to_string(),
                temperature: 0.7,
                max_tokens: 4096,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry
    }

    // Provider type overlays capability facts: Ollama models become
    // local/free; Anthropic models stay remote with their catalog caps.
    #[test]
    fn test_get_model_capabilities() {
        let registry = create_test_registry();

        let fast_caps = registry.get_model_capabilities("fast-local").unwrap();
        assert!(fast_caps.is_local);
        assert_eq!(fast_caps.cost_tier, "free");
        assert!(fast_caps.supports_tools);

        let claude_caps = registry.get_model_capabilities("claude-sonnet").unwrap();
        assert!(!claude_caps.is_local);
        assert!(claude_caps.supports_tools);
        assert!(claude_caps.supports_vision);
        assert_eq!(claude_caps.context_window, 200_000);
    }

    // All four fixture models resolve capabilities (no dangling providers).
    #[test]
    fn test_models_with_capabilities() {
        let registry = create_test_registry();
        let models = registry.models_with_capabilities();

        assert_eq!(models.len(), 4);

        for model in &models {
            assert!(!model.name.is_empty());
            assert!(!model.provider.is_empty());
            assert!(model.capabilities.supports_tools);
        }
    }

    // Exactly the two ollama-backed models count as local, and local implies free.
    #[test]
    fn test_find_local_models() {
        let registry = create_test_registry();
        let local_models = registry.find_local_models();

        assert_eq!(local_models.len(), 2);
        for model in &local_models {
            assert!(model.capabilities.is_local);
            assert_eq!(model.capabilities.cost_tier, "free");
        }
    }

    // The two remote frontier models (claude-sonnet, gpt4o) support vision.
    #[test]
    fn test_find_vision_models() {
        let registry = create_test_registry();
        let vision_models = registry.find_vision_models();

        assert_eq!(vision_models.len(), 2);
        for model in &vision_models {
            assert!(model.capabilities.supports_vision);
        }
    }

    // Agent requirements are satisfiable, and the winner is tool-capable
    // and production-ready.
    #[test]
    fn test_find_best_model_for_agent() {
        let registry = create_test_registry();

        let requirements = CapabilityRequirements::for_agent();
        let best = registry.find_best_model(&requirements);

        assert!(best.is_some());
        let best = best.unwrap();
        assert!(best.capabilities.supports_tools);
        assert!(best.capabilities.production_ready);
    }

    // A minimum-context-window filter excludes models below the threshold.
    #[test]
    fn test_find_best_model_with_context_window() {
        let registry = create_test_registry();

        let requirements = CapabilityRequirements::builder()
            .min_context_window(100_000)
            .build();

        let matches = registry.find_models(&requirements);

        assert!(matches.len() >= 2);
        for model in &matches {
            assert!(model.capabilities.context_window >= 100_000);
        }
    }

    // Scoring prefers cheaper options: with only a tools requirement, a
    // local/free model should win over the paid remote ones.
    #[test]
    fn test_find_best_model_prefers_cheaper() {
        let registry = create_test_registry();

        let requirements = CapabilityRequirements::builder().requires_tools().build();

        let best = registry.find_best_model(&requirements).unwrap();

        assert!(
            best.capabilities.is_local || best.capabilities.cost_tier == "free",
            "Expected best model to be local/free, got: {} (cost: {})",
            best.name,
            best.capabilities.cost_tier
        );
    }

    // local + vision is unsatisfiable in this fixture (only remote models
    // have vision), so the search yields nothing.
    #[test]
    fn test_no_model_matches_impossible_requirements() {
        let registry = create_test_registry();

        let requirements = CapabilityRequirements::builder()
            .requires_local()
            .requires_vision()
            .build();

        let matches = registry.find_models(&requirements);
        assert!(matches.is_empty());
    }

    // Every coding-qualified model has tools, reasoning, and a 32k+ window.
    #[test]
    fn test_find_coding_models() {
        let registry = create_test_registry();
        let coding_models = registry.find_coding_models();

        for model in &coding_models {
            assert!(model.capabilities.supports_tools);
            assert!(model.capabilities.supports_reasoning);
            assert!(model.capabilities.context_window >= 32_000);
        }
    }
}