code_mesh_core/llm/
registry.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::RwLock;
4use anyhow;
5
6use super::{
7 ProviderRegistry, ModelConfig, ProviderConfig, ProviderSource,
8 AnthropicProvider, OpenAIProvider, GitHubCopilotProvider,
9 AnthropicModelWithProvider, OpenAIModelWithProvider, GitHubCopilotModelWithProvider,
10 LanguageModel,
11};
12use crate::auth::{AuthStorage, AnthropicAuth, GitHubCopilotAuth};
13
14pub struct LLMRegistry {
16 provider_registry: ProviderRegistry,
17 model_cache: Arc<RwLock<HashMap<String, Arc<dyn LanguageModel>>>>,
18}
19
20impl LLMRegistry {
21 pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
23 Self {
24 provider_registry: ProviderRegistry::new(storage),
25 model_cache: Arc::new(RwLock::new(HashMap::new())),
26 }
27 }
28
29 pub async fn initialize(&mut self) -> crate::Result<()> {
31 self.load_default_configs().await?;
33
34 self.provider_registry.discover_from_env().await?;
36 self.provider_registry.discover_from_storage().await?;
37
38 self.provider_registry.initialize_all().await?;
40
41 Ok(())
42 }
43
44 pub async fn load_models_dev_configs(&mut self) -> crate::Result<()> {
46 self.provider_registry.load_models_dev().await
47 }
48
49 pub async fn load_config_file(&mut self, path: &str) -> crate::Result<()> {
51 self.provider_registry.load_configs(path).await
52 }
53
54 pub async fn get_model(&self, provider_id: &str, model_id: &str) -> crate::Result<Arc<dyn LanguageModel>> {
56 let cache_key = format!("{}:{}", provider_id, model_id);
57
58 {
60 let cache = self.model_cache.read().await;
61 if let Some(model) = cache.get(&cache_key) {
62 return Ok(model.clone());
63 }
64 }
65
66 let provider = self.provider_registry.get(provider_id).await
68 .ok_or_else(|| crate::Error::Other(anyhow::anyhow!("Provider not found: {}", provider_id)))?;
69
70 let model = provider.get_model(model_id).await?;
71
72 return Err(crate::Error::Other(anyhow::anyhow!(
75 "Model trait and LanguageModel trait are incompatible - cannot cast between them"
76 )));
77 }
78
79 pub async fn get_model_from_string(&self, model_str: &str) -> crate::Result<Arc<dyn LanguageModel>> {
81 let (provider_id, model_id) = ProviderRegistry::parse_model(model_str);
82 self.get_model(&provider_id, &model_id).await
83 }
84
85 pub async fn get_default_model(&self, provider_id: &str) -> crate::Result<Arc<dyn LanguageModel>> {
87 let model = self.provider_registry.get_default_model(provider_id).await?;
88 return Err(crate::Error::Other(anyhow::anyhow!(
91 "Model trait and LanguageModel trait are incompatible - cannot cast between them"
92 )));
93 }
94
95 pub async fn get_best_model(&self) -> crate::Result<Arc<dyn LanguageModel>> {
97 let available_providers = self.provider_registry.available().await;
98
99 if available_providers.is_empty() {
100 return Err(crate::Error::Other(anyhow::anyhow!("No providers available")));
101 }
102
103 let provider_priority = ["anthropic", "openai", "github-copilot"];
105
106 for provider_id in provider_priority {
107 if available_providers.contains(&provider_id.to_string()) {
108 if let Ok(model) = self.get_default_model(provider_id).await {
109 return Ok(model);
110 }
111 }
112 }
113
114 self.get_default_model(&available_providers[0]).await
116 }
117
118 pub async fn list_providers(&self) -> Vec<String> {
120 self.provider_registry.list().await
121 }
122
123 pub async fn list_available_providers(&self) -> Vec<String> {
125 self.provider_registry.available().await
126 }
127
128 pub async fn list_models(&self, provider_id: &str) -> crate::Result<Vec<ModelConfig>> {
130 let provider = self.provider_registry.get(provider_id).await
131 .ok_or_else(|| crate::Error::Other(anyhow::anyhow!("Provider not found: {}", provider_id)))?;
132
133 let model_infos = provider.list_models().await?;
134
135 Ok(model_infos.into_iter().map(|info| ModelConfig {
137 model_id: info.id,
138 ..Default::default()
139 }).collect())
140 }
141
142 pub async fn clear_cache(&self) {
144 let mut cache = self.model_cache.write().await;
145 cache.clear();
146 }
147
148 pub async fn cache_stats(&self) -> HashMap<String, usize> {
150 let cache = self.model_cache.read().await;
151 let mut stats = HashMap::new();
152 stats.insert("cached_models".to_string(), cache.len());
153 stats
154 }
155
156 async fn load_default_configs(&mut self) -> crate::Result<()> {
158 Ok(())
161 }
162
163 pub async fn register_provider(&mut self, provider: Arc<dyn super::Provider>) {
165 self.provider_registry.register(provider).await;
166 }
167}
168
169pub async fn create_default_registry() -> crate::Result<LLMRegistry> {
171 let storage = Arc::new(crate::auth::FileAuthStorage::default_with_result()?) as Arc<dyn AuthStorage>;
172 let mut registry = LLMRegistry::new(storage);
173 registry.initialize().await?;
174 Ok(registry)
175}
176
177pub async fn create_registry_with_models_dev() -> crate::Result<LLMRegistry> {
179 let mut registry = create_default_registry().await?;
180 registry.load_models_dev_configs().await?;
181 Ok(registry)
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::storage::FileAuthStorage;
    use tempfile::tempdir;

    /// Builds a registry backed by an `auth.json` inside a fresh temporary directory.
    ///
    /// The `TempDir` is returned too, so the directory is not deleted while the registry
    /// is still in use.
    fn scratch_registry() -> (tempfile::TempDir, LLMRegistry) {
        let dir = tempdir().unwrap();
        let storage = Arc::new(FileAuthStorage::new(dir.path().join("auth.json")));
        (dir, LLMRegistry::new(storage))
    }

    #[tokio::test]
    async fn test_registry_creation() {
        let (_dir, registry) = scratch_registry();

        let providers = registry.list_providers().await;
        assert!(providers.is_empty(), "unexpected providers: {:?}", providers);
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let (_dir, registry) = scratch_registry();

        // A fresh registry starts with an empty model cache...
        assert_eq!(
            registry.cache_stats().await.get("cached_models").copied(),
            Some(0)
        );

        // ...and clearing it leaves it empty.
        registry.clear_cache().await;
        assert_eq!(
            registry.cache_stats().await.get("cached_models").copied(),
            Some(0)
        );
    }
}