// ares/llm/provider_registry.rs

//! Provider Registry for managing multiple LLM providers
//!
//! This module provides a registry for managing named LLM providers
//! that can be configured via TOML configuration.
//!
//! # Model Capabilities (DIR-43)
//!
//! The registry now supports capability-based model selection:
//!
//! ```rust,ignore
//! use ares::llm::{ProviderRegistry, CapabilityRequirements};
//!
//! let requirements = CapabilityRequirements::builder()
//!     .requires_tools()
//!     .requires_vision()
//!     .min_context_window(100_000)
//!     .build();
//!
//! let model = registry.find_best_model(&requirements)?;
//! let client = registry.create_client_for_model(&model.name).await?;
//! ```

use crate::llm::capabilities::{CapabilityRequirements, ModelCapabilities, ModelWithCapabilities};
use crate::llm::client::{LLMClient, Provider};
use crate::types::{AppError, Result};
use crate::utils::toml_config::{AresConfig, ModelConfig, ProviderConfig};
use std::collections::HashMap;
use std::sync::Arc;
29
/// Registry for managing multiple named LLM providers
///
/// The ProviderRegistry holds references to provider configurations and allows
/// creating LLM clients for specific models or providers by name.
pub struct ProviderRegistry {
    /// Provider configurations keyed by name
    providers: HashMap<String, ProviderConfig>,
    /// Model configurations keyed by name
    models: HashMap<String, ModelConfig>,
    /// Default model name to use when none specified
    default_model: Option<String>,
}
42
43impl ProviderRegistry {
44    /// Create a new empty provider registry
45    pub fn new() -> Self {
46        Self {
47            providers: HashMap::new(),
48            models: HashMap::new(),
49            default_model: None,
50        }
51    }
52
53    /// Create a provider registry from TOML configuration
54    pub fn from_config(config: &AresConfig) -> Self {
55        Self {
56            providers: config.providers.clone(),
57            models: config.models.clone(),
58            default_model: config.models.keys().next().cloned(),
59        }
60    }
61
62    /// Set the default model name
63    pub fn set_default_model(&mut self, model_name: &str) {
64        self.default_model = Some(model_name.to_string());
65    }
66
67    /// Register a provider configuration
68    pub fn register_provider(&mut self, name: &str, config: ProviderConfig) {
69        self.providers.insert(name.to_string(), config);
70    }
71
72    /// Register a model configuration
73    pub fn register_model(&mut self, name: &str, config: ModelConfig) {
74        self.models.insert(name.to_string(), config);
75    }
76
77    /// Get a provider configuration by name
78    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
79        self.providers.get(name)
80    }
81
82    /// Get a model configuration by name
83    pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
84        self.models.get(name)
85    }
86
87    /// Get all provider names
88    pub fn provider_names(&self) -> Vec<&str> {
89        self.providers.keys().map(|s| s.as_str()).collect()
90    }
91
92    /// Get all model names
93    pub fn model_names(&self) -> Vec<&str> {
94        self.models.keys().map(|s| s.as_str()).collect()
95    }
96
97    /// Create an LLM client for a specific model by name
98    ///
99    /// This resolves the model -> provider chain and creates the appropriate client.
100    pub async fn create_client_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
101        let model_config = self.get_model(model_name).ok_or_else(|| {
102            AppError::Configuration(format!("Model '{}' not found in configuration", model_name))
103        })?;
104
105        let provider_config = self.get_provider(&model_config.provider).ok_or_else(|| {
106            AppError::Configuration(format!(
107                "Provider '{}' referenced by model '{}' not found",
108                model_config.provider, model_name
109            ))
110        })?;
111
112        let provider = Provider::from_model_config(model_config, provider_config)?;
113        provider.create_client().await
114    }
115
116    /// Create an LLM client for a specific provider by name
117    ///
118    /// Uses the provider's default model.
119    pub async fn create_client_for_provider(
120        &self,
121        provider_name: &str,
122    ) -> Result<Box<dyn LLMClient>> {
123        let provider_config = self.get_provider(provider_name).ok_or_else(|| {
124            AppError::Configuration(format!(
125                "Provider '{}' not found in configuration",
126                provider_name
127            ))
128        })?;
129
130        let provider = Provider::from_config(provider_config, None)?;
131        provider.create_client().await
132    }
133
134    /// Create an LLM client using the default model
135    pub async fn create_default_client(&self) -> Result<Box<dyn LLMClient>> {
136        let model_name = self
137            .default_model
138            .as_ref()
139            .ok_or_else(|| AppError::Configuration("No default model configured".into()))?;
140
141        self.create_client_for_model(model_name).await
142    }
143
144    /// Check if a model exists in the registry
145    pub fn has_model(&self, name: &str) -> bool {
146        self.models.contains_key(name)
147    }
148
149    /// Check if a provider exists in the registry
150    pub fn has_provider(&self, name: &str) -> bool {
151        self.providers.contains_key(name)
152    }
153
154    // ================== Capability-Based Model Selection (DIR-43) ==================
155
156    /// Get capabilities for a registered model.
157    ///
158    /// Attempts to auto-detect capabilities based on the model name,
159    /// or returns default capabilities if unknown.
160    pub fn get_model_capabilities(&self, model_name: &str) -> Option<ModelCapabilities> {
161        let model_config = self.get_model(model_name)?;
162        let provider_config = self.get_provider(&model_config.provider)?;
163
164        // Start with auto-detected capabilities based on model ID
165        let mut caps = ModelCapabilities::for_model(&model_config.model);
166
167        // Override with provider-specific info
168        match provider_config {
169            ProviderConfig::Ollama { .. } => {
170                caps.is_local = true;
171                caps.cost_tier = "free".to_string();
172            }
173            ProviderConfig::LlamaCpp { .. } => {
174                caps.is_local = true;
175                caps.cost_tier = "free".to_string();
176            }
177            ProviderConfig::OpenAI { .. } => {
178                caps.is_local = false;
179            }
180            ProviderConfig::Anthropic { .. } => {
181                caps.is_local = false;
182            }
183        }
184
185        Some(caps)
186    }
187
188    /// Get all models with their capabilities.
189    pub fn models_with_capabilities(&self) -> Vec<ModelWithCapabilities> {
190        self.models
191            .iter()
192            .filter_map(|(name, config)| {
193                let caps = self.get_model_capabilities(name)?;
194                Some(ModelWithCapabilities {
195                    name: name.clone(),
196                    provider: config.provider.clone(),
197                    model_id: config.model.clone(),
198                    capabilities: caps,
199                })
200            })
201            .collect()
202    }
203
204    /// Find models that satisfy the given capability requirements.
205    ///
206    /// Returns matching models sorted by score (best match first).
207    pub fn find_models(&self, requirements: &CapabilityRequirements) -> Vec<ModelWithCapabilities> {
208        let mut matches: Vec<_> = self
209            .models_with_capabilities()
210            .into_iter()
211            .filter(|m| m.capabilities.satisfies(requirements))
212            .collect();
213
214        // Sort by score (highest first)
215        matches.sort_by(|a, b| {
216            let score_a = a.capabilities.score(requirements);
217            let score_b = b.capabilities.score(requirements);
218            score_b.cmp(&score_a)
219        });
220
221        matches
222    }
223
224    /// Find the best model for the given requirements.
225    ///
226    /// Returns the highest-scoring model that satisfies all requirements,
227    /// or None if no model matches.
228    pub fn find_best_model(
229        &self,
230        requirements: &CapabilityRequirements,
231    ) -> Option<ModelWithCapabilities> {
232        self.find_models(requirements).into_iter().next()
233    }
234
235    /// Create an LLM client for the best model matching requirements.
236    ///
237    /// # Example
238    ///
239    /// ```rust,ignore
240    /// let requirements = CapabilityRequirements::builder()
241    ///     .requires_tools()
242    ///     .requires_vision()
243    ///     .build();
244    ///
245    /// let client = registry.create_client_for_requirements(&requirements).await?;
246    /// ```
247    pub async fn create_client_for_requirements(
248        &self,
249        requirements: &CapabilityRequirements,
250    ) -> Result<Box<dyn LLMClient>> {
251        let model = self.find_best_model(requirements).ok_or_else(|| {
252            AppError::Configuration(format!(
253                "No model found matching requirements: {:?}",
254                requirements
255            ))
256        })?;
257
258        self.create_client_for_model(&model.name).await
259    }
260
261    /// Find models suitable for agent tasks (tool calling required).
262    pub fn find_agent_models(&self) -> Vec<ModelWithCapabilities> {
263        self.find_models(&CapabilityRequirements::for_agent())
264    }
265
266    /// Find models suitable for vision tasks.
267    pub fn find_vision_models(&self) -> Vec<ModelWithCapabilities> {
268        self.find_models(&CapabilityRequirements::for_vision())
269    }
270
271    /// Find models suitable for coding tasks.
272    pub fn find_coding_models(&self) -> Vec<ModelWithCapabilities> {
273        self.find_models(&CapabilityRequirements::for_coding())
274    }
275
276    /// Find local-only models.
277    pub fn find_local_models(&self) -> Vec<ModelWithCapabilities> {
278        self.find_models(&CapabilityRequirements::for_local())
279    }
280
281    /// List all registered models with their provider info.
282    pub fn list_models(&self) -> Vec<ModelInfo> {
283        self.models
284            .iter()
285            .map(|(name, config)| ModelInfo {
286                name: name.clone(),
287                provider: config.provider.clone(),
288                model: config.model.clone(),
289            })
290            .collect()
291    }
292}
293
/// Model info for listing available models via API.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ModelInfo {
    /// Registry name of the model (the key it was registered under).
    pub name: String,
    /// Name of the provider configuration this model references.
    pub provider: String,
    /// Provider-side model identifier (e.g. the model ID sent on requests).
    pub model: String,
}
301
302impl Default for ProviderRegistry {
303    fn default() -> Self {
304        Self::new()
305    }
306}
307
/// Configuration-based LLM client factory using the provider registry
///
/// This is the new factory that uses TOML configuration instead of environment variables.
pub struct ConfigBasedLLMFactory {
    /// Shared registry the factory resolves models and providers against.
    registry: Arc<ProviderRegistry>,
    /// Model name used by `create_default`; always set (unlike the
    /// registry's own optional default).
    default_model: String,
}
315
316impl ConfigBasedLLMFactory {
317    /// Create a new factory from a provider registry
318    pub fn new(registry: Arc<ProviderRegistry>, default_model: &str) -> Self {
319        Self {
320            registry,
321            default_model: default_model.to_string(),
322        }
323    }
324
325    /// Create a factory from TOML configuration
326    pub fn from_config(config: &AresConfig) -> Result<Self> {
327        let registry = ProviderRegistry::from_config(config);
328
329        // Get the first model as default, or error if no models defined
330        let default_model =
331            config.models.keys().next().cloned().ok_or_else(|| {
332                AppError::Configuration("No models defined in configuration".into())
333            })?;
334
335        Ok(Self {
336            registry: Arc::new(registry),
337            default_model,
338        })
339    }
340
341    /// Get the provider registry
342    pub fn registry(&self) -> &Arc<ProviderRegistry> {
343        &self.registry
344    }
345
346    /// Create an LLM client for a specific model
347    pub async fn create_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
348        self.registry.create_client_for_model(model_name).await
349    }
350
351    /// Create an LLM client using the default model
352    pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
353        self.registry
354            .create_client_for_model(&self.default_model)
355            .await
356    }
357
358    /// Get the default model name
359    pub fn default_model(&self) -> &str {
360        &self.default_model
361    }
362
363    /// Set the default model name
364    pub fn set_default_model(&mut self, model_name: &str) {
365        self.default_model = model_name.to_string();
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::capabilities::CapabilityRequirements;

    // A fresh registry starts with no providers and no models.
    #[test]
    fn test_empty_registry() {
        let registry = ProviderRegistry::new();
        assert!(registry.provider_names().is_empty());
        assert!(registry.model_names().is_empty());
    }

    #[test]
    fn test_register_provider() {
        let mut registry = ProviderRegistry::new();
        registry.register_provider(
            "ollama-local",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );

        assert!(registry.has_provider("ollama-local"));
        assert!(!registry.has_provider("nonexistent"));
    }

    #[test]
    fn test_register_model() {
        let mut registry = ProviderRegistry::new();
        registry.register_provider(
            "ollama-local",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );
        // Model registration does not validate the provider reference;
        // resolution happens later in create_client_for_model.
        registry.register_model(
            "fast",
            ModelConfig {
                provider: "ollama-local".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 256,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        assert!(registry.has_model("fast"));
        assert!(!registry.has_model("nonexistent"));
    }

    // ================== DIR-43: Capability Tests ==================

    // Builds a registry with two local (Ollama) models and two cloud models
    // (Anthropic Claude, OpenAI GPT-4o) for the capability-selection tests.
    fn create_test_registry() -> ProviderRegistry {
        let mut registry = ProviderRegistry::new();

        // Register providers
        registry.register_provider(
            "ollama",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "llama-3.3-70b-instruct".to_string(),
            },
        );

        registry.register_provider(
            "anthropic",
            ProviderConfig::Anthropic {
                api_key_env: "ANTHROPIC_API_KEY".to_string(),
                default_model: "claude-3-5-sonnet-20241022".to_string(),
            },
        );

        registry.register_provider(
            "openai",
            ProviderConfig::OpenAI {
                api_key_env: "OPENAI_API_KEY".to_string(),
                api_base: "https://api.openai.com/v1".to_string(),
                default_model: "gpt-4o".to_string(),
            },
        );

        // Register models
        registry.register_model(
            "fast-local",
            ModelConfig {
                provider: "ollama".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 512,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "powerful-local",
            ModelConfig {
                provider: "ollama".to_string(),
                model: "llama-3.3-70b-instruct".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "claude-sonnet",
            ModelConfig {
                provider: "anthropic".to_string(),
                model: "claude-3-5-sonnet-20241022".to_string(),
                temperature: 0.7,
                max_tokens: 4096,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "gpt4o",
            ModelConfig {
                provider: "openai".to_string(),
                model: "gpt-4o-2024-08-06".to_string(),
                temperature: 0.7,
                max_tokens: 4096,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry
    }

    // Capability values asserted below come from ModelCapabilities::for_model
    // plus the provider-based overrides in get_model_capabilities.
    #[test]
    fn test_get_model_capabilities() {
        let registry = create_test_registry();

        // Test local model capabilities
        let fast_caps = registry.get_model_capabilities("fast-local").unwrap();
        assert!(fast_caps.is_local);
        assert_eq!(fast_caps.cost_tier, "free");
        assert!(fast_caps.supports_tools);

        // Test cloud model capabilities
        let claude_caps = registry.get_model_capabilities("claude-sonnet").unwrap();
        assert!(!claude_caps.is_local);
        assert!(claude_caps.supports_tools);
        assert!(claude_caps.supports_vision);
        assert_eq!(claude_caps.context_window, 200_000);
    }

    #[test]
    fn test_models_with_capabilities() {
        let registry = create_test_registry();
        let models = registry.models_with_capabilities();

        assert_eq!(models.len(), 4);

        // Verify all models have capabilities
        for model in &models {
            assert!(!model.name.is_empty());
            assert!(!model.provider.is_empty());
            // All these models should support tools
            assert!(model.capabilities.supports_tools);
        }
    }

    #[test]
    fn test_find_local_models() {
        let registry = create_test_registry();
        let local_models = registry.find_local_models();

        // Should find the two Ollama models
        assert_eq!(local_models.len(), 2);
        for model in &local_models {
            assert!(model.capabilities.is_local);
            assert_eq!(model.capabilities.cost_tier, "free");
        }
    }

    #[test]
    fn test_find_vision_models() {
        let registry = create_test_registry();
        let vision_models = registry.find_vision_models();

        // Claude and GPT-4o support vision
        assert_eq!(vision_models.len(), 2);
        for model in &vision_models {
            assert!(model.capabilities.supports_vision);
        }
    }

    #[test]
    fn test_find_best_model_for_agent() {
        let registry = create_test_registry();

        let requirements = CapabilityRequirements::for_agent();
        let best = registry.find_best_model(&requirements);

        assert!(best.is_some());
        let best = best.unwrap();
        assert!(best.capabilities.supports_tools);
        assert!(best.capabilities.production_ready);
    }

    #[test]
    fn test_find_best_model_with_context_window() {
        let registry = create_test_registry();

        // Require large context window
        let requirements = CapabilityRequirements::builder()
            .min_context_window(100_000)
            .build();

        let matches = registry.find_models(&requirements);

        // Should match Claude (200k), GPT-4o (128k), and Llama (128k)
        assert!(matches.len() >= 2);
        for model in &matches {
            assert!(model.capabilities.context_window >= 100_000);
        }
    }

    #[test]
    fn test_find_best_model_prefers_cheaper() {
        let registry = create_test_registry();

        // Basic requirements that all models satisfy
        let requirements = CapabilityRequirements::builder().requires_tools().build();

        let best = registry.find_best_model(&requirements).unwrap();

        // Should prefer local/free models when all else is equal
        // (scoring penalizes cost)
        assert!(
            best.capabilities.is_local || best.capabilities.cost_tier == "free",
            "Expected best model to be local/free, got: {} (cost: {})",
            best.name,
            best.capabilities.cost_tier
        );
    }

    #[test]
    fn test_no_model_matches_impossible_requirements() {
        let registry = create_test_registry();

        // Impossible requirements: local + vision (no local vision models in test registry)
        let requirements = CapabilityRequirements::builder()
            .requires_local()
            .requires_vision()
            .build();

        let matches = registry.find_models(&requirements);
        assert!(matches.is_empty());
    }

    #[test]
    fn test_find_coding_models() {
        let registry = create_test_registry();
        let coding_models = registry.find_coding_models();

        // Should find models that support tools + reasoning + large context
        // NOTE(review): does not assert non-emptiness — an empty result would
        // pass vacuously; confirm whether that is intended.
        for model in &coding_models {
            assert!(model.capabilities.supports_tools);
            assert!(model.capabilities.supports_reasoning);
            assert!(model.capabilities.context_window >= 32_000);
        }
    }
}