// ares/llm/provider_registry.rs
//! Provider Registry for managing multiple LLM providers
//!
//! This module provides a registry for managing named LLM providers
//! that can be configured via TOML configuration.
//!
//! # Model Capabilities (DIR-43)
//!
//! The registry now supports capability-based model selection:
//!
//! ```rust,ignore
//! use ares::llm::{ProviderRegistry, CapabilityRequirements};
//!
//! let requirements = CapabilityRequirements::builder()
//!     .requires_tools()
//!     .requires_vision()
//!     .min_context_window(100_000)
//!     .build();
//!
//! let model = registry.find_best_model(&requirements)?;
//! let client = registry.create_client_for_model(&model.name).await?;
//! ```

use crate::llm::capabilities::{CapabilityRequirements, ModelCapabilities, ModelWithCapabilities};
use crate::llm::client::{LLMClient, Provider};
use crate::types::{AppError, Result};
use crate::utils::toml_config::{AresConfig, ModelConfig, ProviderConfig};
use std::cmp::Reverse;
use std::collections::HashMap;
use std::sync::Arc;

/// Registry for managing multiple named LLM providers
///
/// The ProviderRegistry holds references to provider configurations and allows
/// creating LLM clients for specific models or providers by name.
pub struct ProviderRegistry {
    /// Provider configurations keyed by name
    providers: HashMap<String, ProviderConfig>,
    /// Model configurations keyed by name; each model references a provider by name
    models: HashMap<String, ModelConfig>,
    /// Default model name to use when none specified
    /// (set by `from_config` or `set_default_model`)
    default_model: Option<String>,
}

43impl ProviderRegistry {
44    /// Create a new empty provider registry
45    pub fn new() -> Self {
46        Self {
47            providers: HashMap::new(),
48            models: HashMap::new(),
49            default_model: None,
50        }
51    }
52
53    /// Create a provider registry from TOML configuration
54    pub fn from_config(config: &AresConfig) -> Self {
55        Self {
56            providers: config.providers.clone(),
57            models: config.models.clone(),
58            default_model: config.models.keys().next().cloned(),
59        }
60    }
61
62    /// Set the default model name
63    pub fn set_default_model(&mut self, model_name: &str) {
64        self.default_model = Some(model_name.to_string());
65    }
66
67    /// Register a provider configuration
68    pub fn register_provider(&mut self, name: &str, config: ProviderConfig) {
69        self.providers.insert(name.to_string(), config);
70    }
71
72    /// Register a model configuration
73    pub fn register_model(&mut self, name: &str, config: ModelConfig) {
74        self.models.insert(name.to_string(), config);
75    }
76
77    /// Get a provider configuration by name
78    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
79        self.providers.get(name)
80    }
81
82    /// Get a model configuration by name
83    pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
84        self.models.get(name)
85    }
86
87    /// Get all provider names
88    pub fn provider_names(&self) -> Vec<&str> {
89        self.providers.keys().map(|s| s.as_str()).collect()
90    }
91
92    /// Get all model names
93    pub fn model_names(&self) -> Vec<&str> {
94        self.models.keys().map(|s| s.as_str()).collect()
95    }
96
97    /// Create an LLM client for a specific model by name
98    ///
99    /// This resolves the model -> provider chain and creates the appropriate client.
100    pub async fn create_client_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
101        let model_config = self.get_model(model_name).ok_or_else(|| {
102            AppError::Configuration(format!("Model '{}' not found in configuration", model_name))
103        })?;
104
105        let provider_config = self.get_provider(&model_config.provider).ok_or_else(|| {
106            AppError::Configuration(format!(
107                "Provider '{}' referenced by model '{}' not found",
108                model_config.provider, model_name
109            ))
110        })?;
111
112        let provider = Provider::from_model_config(model_config, provider_config)?;
113        provider.create_client().await
114    }
115
116    /// Create an LLM client for a specific provider by name
117    ///
118    /// Uses the provider's default model.
119    pub async fn create_client_for_provider(
120        &self,
121        provider_name: &str,
122    ) -> Result<Box<dyn LLMClient>> {
123        let provider_config = self.get_provider(provider_name).ok_or_else(|| {
124            AppError::Configuration(format!(
125                "Provider '{}' not found in configuration",
126                provider_name
127            ))
128        })?;
129
130        let provider = Provider::from_config(provider_config, None)?;
131        provider.create_client().await
132    }
133
134    /// Create an LLM client using the default model
135    pub async fn create_default_client(&self) -> Result<Box<dyn LLMClient>> {
136        let model_name = self
137            .default_model
138            .as_ref()
139            .ok_or_else(|| AppError::Configuration("No default model configured".into()))?;
140
141        self.create_client_for_model(model_name).await
142    }
143
144    /// Check if a model exists in the registry
145    pub fn has_model(&self, name: &str) -> bool {
146        self.models.contains_key(name)
147    }
148
149    /// Check if a provider exists in the registry
150    pub fn has_provider(&self, name: &str) -> bool {
151        self.providers.contains_key(name)
152    }
153
154    // ================== Capability-Based Model Selection (DIR-43) ==================
155
156    /// Get capabilities for a registered model.
157    ///
158    /// Attempts to auto-detect capabilities based on the model name,
159    /// or returns default capabilities if unknown.
160    pub fn get_model_capabilities(&self, model_name: &str) -> Option<ModelCapabilities> {
161        let model_config = self.get_model(model_name)?;
162        let provider_config = self.get_provider(&model_config.provider)?;
163
164        // Start with auto-detected capabilities based on model ID
165        let mut caps = ModelCapabilities::for_model(&model_config.model);
166
167        // Override with provider-specific info
168        match provider_config {
169            ProviderConfig::Ollama { .. } => {
170                caps.is_local = true;
171                caps.cost_tier = "free".to_string();
172            }
173            ProviderConfig::LlamaCpp { .. } => {
174                caps.is_local = true;
175                caps.cost_tier = "free".to_string();
176            }
177            ProviderConfig::OpenAI { .. } => {
178                caps.is_local = false;
179            }
180            ProviderConfig::Anthropic { .. } => {
181                caps.is_local = false;
182            }
183        }
184
185        Some(caps)
186    }
187
188    /// Get all models with their capabilities.
189    pub fn models_with_capabilities(&self) -> Vec<ModelWithCapabilities> {
190        self.models
191            .iter()
192            .filter_map(|(name, config)| {
193                let caps = self.get_model_capabilities(name)?;
194                Some(ModelWithCapabilities {
195                    name: name.clone(),
196                    provider: config.provider.clone(),
197                    model_id: config.model.clone(),
198                    capabilities: caps,
199                })
200            })
201            .collect()
202    }
203
204    /// Find models that satisfy the given capability requirements.
205    ///
206    /// Returns matching models sorted by score (best match first).
207    pub fn find_models(&self, requirements: &CapabilityRequirements) -> Vec<ModelWithCapabilities> {
208        let mut matches: Vec<_> = self
209            .models_with_capabilities()
210            .into_iter()
211            .filter(|m| m.capabilities.satisfies(requirements))
212            .collect();
213
214        // Sort by score (highest first)
215        matches.sort_by(|a, b| {
216            let score_a = a.capabilities.score(requirements);
217            let score_b = b.capabilities.score(requirements);
218            score_b.cmp(&score_a)
219        });
220
221        matches
222    }
223
224    /// Find the best model for the given requirements.
225    ///
226    /// Returns the highest-scoring model that satisfies all requirements,
227    /// or None if no model matches.
228    pub fn find_best_model(
229        &self,
230        requirements: &CapabilityRequirements,
231    ) -> Option<ModelWithCapabilities> {
232        self.find_models(requirements).into_iter().next()
233    }
234
235    /// Create an LLM client for the best model matching requirements.
236    ///
237    /// # Example
238    ///
239    /// ```rust,ignore
240    /// let requirements = CapabilityRequirements::builder()
241    ///     .requires_tools()
242    ///     .requires_vision()
243    ///     .build();
244    ///
245    /// let client = registry.create_client_for_requirements(&requirements).await?;
246    /// ```
247    pub async fn create_client_for_requirements(
248        &self,
249        requirements: &CapabilityRequirements,
250    ) -> Result<Box<dyn LLMClient>> {
251        let model = self.find_best_model(requirements).ok_or_else(|| {
252            AppError::Configuration(format!(
253                "No model found matching requirements: {:?}",
254                requirements
255            ))
256        })?;
257
258        self.create_client_for_model(&model.name).await
259    }
260
261    /// Find models suitable for agent tasks (tool calling required).
262    pub fn find_agent_models(&self) -> Vec<ModelWithCapabilities> {
263        self.find_models(&CapabilityRequirements::for_agent())
264    }
265
266    /// Find models suitable for vision tasks.
267    pub fn find_vision_models(&self) -> Vec<ModelWithCapabilities> {
268        self.find_models(&CapabilityRequirements::for_vision())
269    }
270
271    /// Find models suitable for coding tasks.
272    pub fn find_coding_models(&self) -> Vec<ModelWithCapabilities> {
273        self.find_models(&CapabilityRequirements::for_coding())
274    }
275
276    /// Find local-only models.
277    pub fn find_local_models(&self) -> Vec<ModelWithCapabilities> {
278        self.find_models(&CapabilityRequirements::for_local())
279    }
280
281    /// List all registered models with their provider info.
282    pub fn list_models(&self) -> Vec<ModelInfo> {
283        self.models.iter().map(|(name, config)| ModelInfo {
284            name: name.clone(),
285            provider: config.provider.clone(),
286            model: config.model.clone(),
287        }).collect()
288    }
289}
290
/// Model info for listing available models via API.
///
/// A serializable summary of a registered model: its registry name, the
/// provider configuration it resolves to, and the provider-specific model id.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ModelInfo {
    /// Registry key under which the model was registered.
    pub name: String,
    /// Name of the provider configuration this model references.
    pub provider: String,
    /// Provider-specific model identifier (e.g. the Ollama or API model id).
    pub model: String,
}

299impl Default for ProviderRegistry {
300    fn default() -> Self {
301        Self::new()
302    }
303}
304
/// Configuration-based LLM client factory using the provider registry
///
/// This is the new factory that uses TOML configuration instead of environment variables.
pub struct ConfigBasedLLMFactory {
    /// Shared registry used to resolve model names into clients.
    registry: Arc<ProviderRegistry>,
    /// Model name used by `create_default`.
    default_model: String,
}

313impl ConfigBasedLLMFactory {
314    /// Create a new factory from a provider registry
315    pub fn new(registry: Arc<ProviderRegistry>, default_model: &str) -> Self {
316        Self {
317            registry,
318            default_model: default_model.to_string(),
319        }
320    }
321
322    /// Create a factory from TOML configuration
323    pub fn from_config(config: &AresConfig) -> Result<Self> {
324        let registry = ProviderRegistry::from_config(config);
325
326        // Get the first model as default, or error if no models defined
327        let default_model =
328            config.models.keys().next().cloned().ok_or_else(|| {
329                AppError::Configuration("No models defined in configuration".into())
330            })?;
331
332        Ok(Self {
333            registry: Arc::new(registry),
334            default_model,
335        })
336    }
337
338    /// Get the provider registry
339    pub fn registry(&self) -> &Arc<ProviderRegistry> {
340        &self.registry
341    }
342
343    /// Create an LLM client for a specific model
344    pub async fn create_for_model(&self, model_name: &str) -> Result<Box<dyn LLMClient>> {
345        self.registry.create_client_for_model(model_name).await
346    }
347
348    /// Create an LLM client using the default model
349    pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
350        self.registry
351            .create_client_for_model(&self.default_model)
352            .await
353    }
354
355    /// Get the default model name
356    pub fn default_model(&self) -> &str {
357        &self.default_model
358    }
359
360    /// Set the default model name
361    pub fn set_default_model(&mut self, model_name: &str) {
362        self.default_model = model_name.to_string();
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::capabilities::CapabilityRequirements;

    #[test]
    fn test_empty_registry() {
        // A freshly constructed registry has no providers or models.
        let registry = ProviderRegistry::new();
        assert!(registry.provider_names().is_empty());
        assert!(registry.model_names().is_empty());
    }

    #[test]
    fn test_register_provider() {
        let mut registry = ProviderRegistry::new();
        registry.register_provider(
            "ollama-local",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );

        assert!(registry.has_provider("ollama-local"));
        assert!(!registry.has_provider("nonexistent"));
    }

    #[test]
    fn test_register_model() {
        let mut registry = ProviderRegistry::new();
        registry.register_provider(
            "ollama-local",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );
        registry.register_model(
            "fast",
            ModelConfig {
                provider: "ollama-local".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 256,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        assert!(registry.has_model("fast"));
        assert!(!registry.has_model("nonexistent"));
    }

    // ================== DIR-43: Capability Tests ==================

    /// Build a registry with two local (Ollama) and two cloud
    /// (Anthropic, OpenAI) models for the capability-selection tests.
    fn create_test_registry() -> ProviderRegistry {
        let mut registry = ProviderRegistry::new();

        // Register providers
        registry.register_provider(
            "ollama",
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "llama-3.3-70b-instruct".to_string(),
            },
        );

        registry.register_provider(
            "anthropic",
            ProviderConfig::Anthropic {
                api_key_env: "ANTHROPIC_API_KEY".to_string(),
                default_model: "claude-3-5-sonnet-20241022".to_string(),
            },
        );

        registry.register_provider(
            "openai",
            ProviderConfig::OpenAI {
                api_key_env: "OPENAI_API_KEY".to_string(),
                api_base: "https://api.openai.com/v1".to_string(),
                default_model: "gpt-4o".to_string(),
            },
        );

        // Register models
        registry.register_model(
            "fast-local",
            ModelConfig {
                provider: "ollama".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 512,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "powerful-local",
            ModelConfig {
                provider: "ollama".to_string(),
                model: "llama-3.3-70b-instruct".to_string(),
                temperature: 0.7,
                max_tokens: 2048,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "claude-sonnet",
            ModelConfig {
                provider: "anthropic".to_string(),
                model: "claude-3-5-sonnet-20241022".to_string(),
                temperature: 0.7,
                max_tokens: 4096,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry.register_model(
            "gpt4o",
            ModelConfig {
                provider: "openai".to_string(),
                model: "gpt-4o-2024-08-06".to_string(),
                temperature: 0.7,
                max_tokens: 4096,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        registry
    }

    #[test]
    fn test_get_model_capabilities() {
        let registry = create_test_registry();

        // Test local model capabilities
        // NOTE(review): exact expectations (supports_tools, context_window)
        // come from ModelCapabilities::for_model's per-model detection.
        let fast_caps = registry.get_model_capabilities("fast-local").unwrap();
        assert!(fast_caps.is_local);
        assert_eq!(fast_caps.cost_tier, "free");
        assert!(fast_caps.supports_tools);

        // Test cloud model capabilities
        let claude_caps = registry.get_model_capabilities("claude-sonnet").unwrap();
        assert!(!claude_caps.is_local);
        assert!(claude_caps.supports_tools);
        assert!(claude_caps.supports_vision);
        assert_eq!(claude_caps.context_window, 200_000);
    }

    #[test]
    fn test_models_with_capabilities() {
        let registry = create_test_registry();
        let models = registry.models_with_capabilities();

        assert_eq!(models.len(), 4);

        // Verify all models have capabilities
        for model in &models {
            assert!(!model.name.is_empty());
            assert!(!model.provider.is_empty());
            // All these models should support tools
            assert!(model.capabilities.supports_tools);
        }
    }

    #[test]
    fn test_find_local_models() {
        let registry = create_test_registry();
        let local_models = registry.find_local_models();

        // Should find the two Ollama models
        assert_eq!(local_models.len(), 2);
        for model in &local_models {
            assert!(model.capabilities.is_local);
            assert_eq!(model.capabilities.cost_tier, "free");
        }
    }

    #[test]
    fn test_find_vision_models() {
        let registry = create_test_registry();
        let vision_models = registry.find_vision_models();

        // Claude and GPT-4o support vision
        assert_eq!(vision_models.len(), 2);
        for model in &vision_models {
            assert!(model.capabilities.supports_vision);
        }
    }

    #[test]
    fn test_find_best_model_for_agent() {
        let registry = create_test_registry();

        let requirements = CapabilityRequirements::for_agent();
        let best = registry.find_best_model(&requirements);

        assert!(best.is_some());
        let best = best.unwrap();
        assert!(best.capabilities.supports_tools);
        assert!(best.capabilities.production_ready);
    }

    #[test]
    fn test_find_best_model_with_context_window() {
        let registry = create_test_registry();

        // Require large context window
        let requirements = CapabilityRequirements::builder()
            .min_context_window(100_000)
            .build();

        let matches = registry.find_models(&requirements);

        // Should match Claude (200k), GPT-4o (128k), and Llama (128k)
        assert!(matches.len() >= 2);
        for model in &matches {
            assert!(model.capabilities.context_window >= 100_000);
        }
    }

    #[test]
    fn test_find_best_model_prefers_cheaper() {
        let registry = create_test_registry();

        // Basic requirements that all models satisfy
        let requirements = CapabilityRequirements::builder().requires_tools().build();

        let best = registry.find_best_model(&requirements).unwrap();

        // Should prefer local/free models when all else is equal
        // (scoring penalizes cost)
        assert!(
            best.capabilities.is_local || best.capabilities.cost_tier == "free",
            "Expected best model to be local/free, got: {} (cost: {})",
            best.name,
            best.capabilities.cost_tier
        );
    }

    #[test]
    fn test_no_model_matches_impossible_requirements() {
        let registry = create_test_registry();

        // Impossible requirements: local + vision (no local vision models in test registry)
        let requirements = CapabilityRequirements::builder()
            .requires_local()
            .requires_vision()
            .build();

        let matches = registry.find_models(&requirements);
        assert!(matches.is_empty());
    }

    #[test]
    fn test_find_coding_models() {
        let registry = create_test_registry();
        let coding_models = registry.find_coding_models();

        // Should find models that support tools + reasoning + large context
        for model in &coding_models {
            assert!(model.capabilities.supports_tools);
            assert!(model.capabilities.supports_reasoning);
            assert!(model.capabilities.context_window >= 32_000);
        }
    }
}