code_mesh_core/llm/
registry.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::RwLock;
4use anyhow;
5
6use super::{
7    ProviderRegistry, ModelConfig, ProviderConfig, ProviderSource, 
8    AnthropicProvider, OpenAIProvider, GitHubCopilotProvider,
9    AnthropicModelWithProvider, OpenAIModelWithProvider, GitHubCopilotModelWithProvider,
10    LanguageModel,
11};
12use crate::auth::{AuthStorage, AnthropicAuth, GitHubCopilotAuth};
13
14/// Central registry for managing LLM providers and models
15pub struct LLMRegistry {
16    provider_registry: ProviderRegistry,
17    model_cache: Arc<RwLock<HashMap<String, Arc<dyn LanguageModel>>>>,
18}
19
20impl LLMRegistry {
21    /// Create new LLM registry with authentication storage
22    pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
23        Self {
24            provider_registry: ProviderRegistry::new(storage),
25            model_cache: Arc::new(RwLock::new(HashMap::new())),
26        }
27    }
28    
29    /// Initialize the registry with default configurations
30    pub async fn initialize(&mut self) -> crate::Result<()> {
31        // Load default model configurations
32        self.load_default_configs().await?;
33        
34        // Discover providers from environment and storage
35        self.provider_registry.discover_from_env().await?;
36        self.provider_registry.discover_from_storage().await?;
37        
38        // Initialize all discovered providers
39        self.provider_registry.initialize_all().await?;
40        
41        Ok(())
42    }
43    
44    /// Load configurations from models.dev API
45    pub async fn load_models_dev_configs(&mut self) -> crate::Result<()> {
46        self.provider_registry.load_models_dev().await
47    }
48    
49    /// Load configurations from file
50    pub async fn load_config_file(&mut self, path: &str) -> crate::Result<()> {
51        self.provider_registry.load_configs(path).await
52    }
53    
54    /// Get a model by provider and model ID
55    pub async fn get_model(&self, provider_id: &str, model_id: &str) -> crate::Result<Arc<dyn LanguageModel>> {
56        let cache_key = format!("{}:{}", provider_id, model_id);
57        
58        // Check cache first
59        {
60            let cache = self.model_cache.read().await;
61            if let Some(model) = cache.get(&cache_key) {
62                return Ok(model.clone());
63            }
64        }
65        
66        // Get provider and create model
67        let provider = self.provider_registry.get(provider_id).await
68            .ok_or_else(|| crate::Error::Other(anyhow::anyhow!("Provider not found: {}", provider_id)))?;
69            
70        let model = provider.get_model(model_id).await?;
71        
72        // We can't cast Arc<dyn Model> to Arc<dyn LanguageModel> directly
73        // For now, return an error indicating this design issue
74        return Err(crate::Error::Other(anyhow::anyhow!(
75            "Model trait and LanguageModel trait are incompatible - cannot cast between them"
76        )));
77    }
78    
79    /// Get model from string (provider/model or just model)
80    pub async fn get_model_from_string(&self, model_str: &str) -> crate::Result<Arc<dyn LanguageModel>> {
81        let (provider_id, model_id) = ProviderRegistry::parse_model(model_str);
82        self.get_model(&provider_id, &model_id).await
83    }
84    
85    /// Get default model for a provider
86    pub async fn get_default_model(&self, provider_id: &str) -> crate::Result<Arc<dyn LanguageModel>> {
87        let model = self.provider_registry.get_default_model(provider_id).await?;
88        // We can't cast Arc<dyn Model> to Arc<dyn LanguageModel> directly
89        // For now, return an error indicating this design issue
90        return Err(crate::Error::Other(anyhow::anyhow!(
91            "Model trait and LanguageModel trait are incompatible - cannot cast between them"
92        )));
93    }
94    
95    /// Get the best available model across all providers
96    pub async fn get_best_model(&self) -> crate::Result<Arc<dyn LanguageModel>> {
97        let available_providers = self.provider_registry.available().await;
98        
99        if available_providers.is_empty() {
100            return Err(crate::Error::Other(anyhow::anyhow!("No providers available")));
101        }
102        
103        // Priority order for providers
104        let provider_priority = ["anthropic", "openai", "github-copilot"];
105        
106        for provider_id in provider_priority {
107            if available_providers.contains(&provider_id.to_string()) {
108                if let Ok(model) = self.get_default_model(provider_id).await {
109                    return Ok(model);
110                }
111            }
112        }
113        
114        // Fall back to first available provider
115        self.get_default_model(&available_providers[0]).await
116    }
117    
118    /// List all available providers
119    pub async fn list_providers(&self) -> Vec<String> {
120        self.provider_registry.list().await
121    }
122    
123    /// List available providers (those that can authenticate)
124    pub async fn list_available_providers(&self) -> Vec<String> {
125        self.provider_registry.available().await
126    }
127    
128    /// List models for a provider
129    pub async fn list_models(&self, provider_id: &str) -> crate::Result<Vec<ModelConfig>> {
130        let provider = self.provider_registry.get(provider_id).await
131            .ok_or_else(|| crate::Error::Other(anyhow::anyhow!("Provider not found: {}", provider_id)))?;
132            
133        let model_infos = provider.list_models().await?;
134        
135        // Convert ModelInfo to ModelConfig
136        Ok(model_infos.into_iter().map(|info| ModelConfig {
137            model_id: info.id,
138            ..Default::default()
139        }).collect())
140    }
141    
142    /// Clear model cache
143    pub async fn clear_cache(&self) {
144        let mut cache = self.model_cache.write().await;
145        cache.clear();
146    }
147    
148    /// Get cache statistics
149    pub async fn cache_stats(&self) -> HashMap<String, usize> {
150        let cache = self.model_cache.read().await;
151        let mut stats = HashMap::new();
152        stats.insert("cached_models".to_string(), cache.len());
153        stats
154    }
155    
156    /// Load default provider configurations
157    async fn load_default_configs(&mut self) -> crate::Result<()> {
158        // This would load built-in configurations for known providers
159        // For now, providers have their default models built-in
160        Ok(())
161    }
162    
163    /// Register a custom provider
164    pub async fn register_provider(&mut self, provider: Arc<dyn super::Provider>) {
165        self.provider_registry.register(provider).await;
166    }
167}
168
169/// Helper function to create an LLM registry with file-based auth storage
170pub async fn create_default_registry() -> crate::Result<LLMRegistry> {
171    let storage = Arc::new(crate::auth::FileAuthStorage::default_with_result()?) as Arc<dyn AuthStorage>;
172    let mut registry = LLMRegistry::new(storage);
173    registry.initialize().await?;
174    Ok(registry)
175}
176
177/// Helper function to create registry with models.dev configurations
178pub async fn create_registry_with_models_dev() -> crate::Result<LLMRegistry> {
179    let mut registry = create_default_registry().await?;
180    registry.load_models_dev_configs().await?;
181    Ok(registry)
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::storage::FileAuthStorage;
    use tempfile::tempdir;

    /// Builds an uninitialized registry whose auth file is `auth.json` inside `dir`.
    fn registry_in(dir: &std::path::Path) -> LLMRegistry {
        let storage = Arc::new(FileAuthStorage::new(dir.join("auth.json")));
        LLMRegistry::new(storage)
    }

    #[tokio::test]
    async fn test_registry_creation() {
        let dir = tempdir().unwrap();
        let registry = registry_in(dir.path());

        // Nothing has been discovered yet, so no providers are listed.
        assert!(registry.list_providers().await.is_empty());
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let dir = tempdir().unwrap();
        let registry = registry_in(dir.path());
        let cached_models =
            |stats: HashMap<String, usize>| stats.get("cached_models").copied();

        // A fresh registry starts with an empty cache.
        assert_eq!(cached_models(registry.cache_stats().await), Some(0));

        // Clearing an already-empty cache keeps it empty.
        registry.clear_cache().await;
        assert_eq!(cached_models(registry.cache_stats().await), Some(0));
    }
}