use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use anyhow;
use super::{
ProviderRegistry, ModelConfig, ProviderConfig, ProviderSource,
AnthropicProvider, OpenAIProvider, GitHubCopilotProvider,
AnthropicModelWithProvider, OpenAIModelWithProvider, GitHubCopilotModelWithProvider,
LanguageModel,
};
use crate::auth::{AuthStorage, AnthropicAuth, GitHubCopilotAuth};
/// Entry point for resolving language models across registered providers.
///
/// Wraps a [`ProviderRegistry`] and owns a shared, lock-protected cache of
/// resolved [`LanguageModel`] handles keyed by `"provider_id:model_id"`.
pub struct LLMRegistry {
    /// Provider registry this wrapper delegates discovery and lookup to.
    provider_registry: ProviderRegistry,
    /// Resolved models keyed by `"provider_id:model_id"`.
    model_cache: Arc<RwLock<HashMap<String, Arc<dyn LanguageModel>>>>,
}
impl LLMRegistry {
pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
Self {
provider_registry: ProviderRegistry::new(storage),
model_cache: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn initialize(&mut self) -> crate::Result<()> {
self.load_default_configs().await?;
self.provider_registry.discover_from_env().await?;
self.provider_registry.discover_from_storage().await?;
self.provider_registry.initialize_all().await?;
Ok(())
}
pub async fn load_models_dev_configs(&mut self) -> crate::Result<()> {
self.provider_registry.load_models_dev().await
}
pub async fn load_config_file(&mut self, path: &str) -> crate::Result<()> {
self.provider_registry.load_configs(path).await
}
pub async fn get_model(&self, provider_id: &str, model_id: &str) -> crate::Result<Arc<dyn LanguageModel>> {
let cache_key = format!("{}:{}", provider_id, model_id);
{
let cache = self.model_cache.read().await;
if let Some(model) = cache.get(&cache_key) {
return Ok(model.clone());
}
}
let provider = self.provider_registry.get(provider_id).await
.ok_or_else(|| crate::Error::Other(anyhow::anyhow!("Provider not found: {}", provider_id)))?;
let model = provider.get_model(model_id).await?;
return Err(crate::Error::Other(anyhow::anyhow!(
"Model trait and LanguageModel trait are incompatible - cannot cast between them"
)));
}
pub async fn get_model_from_string(&self, model_str: &str) -> crate::Result<Arc<dyn LanguageModel>> {
let (provider_id, model_id) = ProviderRegistry::parse_model(model_str);
self.get_model(&provider_id, &model_id).await
}
pub async fn get_default_model(&self, provider_id: &str) -> crate::Result<Arc<dyn LanguageModel>> {
let model = self.provider_registry.get_default_model(provider_id).await?;
return Err(crate::Error::Other(anyhow::anyhow!(
"Model trait and LanguageModel trait are incompatible - cannot cast between them"
)));
}
pub async fn get_best_model(&self) -> crate::Result<Arc<dyn LanguageModel>> {
let available_providers = self.provider_registry.available().await;
if available_providers.is_empty() {
return Err(crate::Error::Other(anyhow::anyhow!("No providers available")));
}
let provider_priority = ["anthropic", "openai", "github-copilot"];
for provider_id in provider_priority {
if available_providers.contains(&provider_id.to_string()) {
if let Ok(model) = self.get_default_model(provider_id).await {
return Ok(model);
}
}
}
self.get_default_model(&available_providers[0]).await
}
pub async fn list_providers(&self) -> Vec<String> {
self.provider_registry.list().await
}
pub async fn list_available_providers(&self) -> Vec<String> {
self.provider_registry.available().await
}
pub async fn list_models(&self, provider_id: &str) -> crate::Result<Vec<ModelConfig>> {
let provider = self.provider_registry.get(provider_id).await
.ok_or_else(|| crate::Error::Other(anyhow::anyhow!("Provider not found: {}", provider_id)))?;
let model_infos = provider.list_models().await?;
Ok(model_infos.into_iter().map(|info| ModelConfig {
model_id: info.id,
..Default::default()
}).collect())
}
pub async fn clear_cache(&self) {
let mut cache = self.model_cache.write().await;
cache.clear();
}
pub async fn cache_stats(&self) -> HashMap<String, usize> {
let cache = self.model_cache.read().await;
let mut stats = HashMap::new();
stats.insert("cached_models".to_string(), cache.len());
stats
}
async fn load_default_configs(&mut self) -> crate::Result<()> {
Ok(())
}
pub async fn register_provider(&mut self, provider: Arc<dyn super::Provider>) {
self.provider_registry.register(provider).await;
}
}
/// Builds an [`LLMRegistry`] backed by
/// `crate::auth::FileAuthStorage::default_with_result()` and runs
/// [`LLMRegistry::initialize`] on it.
///
/// # Errors
/// Fails if the default `FileAuthStorage` cannot be constructed or if
/// initialization fails.
pub async fn create_default_registry() -> crate::Result<LLMRegistry> {
    let file_store = crate::auth::FileAuthStorage::default_with_result()?;
    // Annotated binding performs the unsizing coercion to the trait object.
    let auth: Arc<dyn AuthStorage> = Arc::new(file_store);

    let mut initialized = LLMRegistry::new(auth);
    initialized.initialize().await?;
    Ok(initialized)
}
/// Builds the default registry (see `create_default_registry`) and then
/// calls [`LLMRegistry::load_models_dev_configs`] on it.
///
/// # Errors
/// Propagates failures from either step.
pub async fn create_registry_with_models_dev() -> crate::Result<LLMRegistry> {
    let mut enriched = create_default_registry().await?;
    enriched.load_models_dev_configs().await?;
    Ok(enriched)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::storage::FileAuthStorage;
    use tempfile::tempdir;

    /// Builds a registry whose auth file lives in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the registry so the directory
    /// stays on disk for as long as the caller keeps the guard bound.
    fn scratch_registry() -> (tempfile::TempDir, LLMRegistry) {
        let dir = tempdir().unwrap();
        let storage = Arc::new(FileAuthStorage::new(dir.path().join("auth.json")));
        (dir, LLMRegistry::new(storage))
    }

    /// Reads the `"cached_models"` entry from `cache_stats`.
    async fn cached_model_count(registry: &LLMRegistry) -> Option<usize> {
        registry.cache_stats().await.get("cached_models").copied()
    }

    #[tokio::test]
    async fn test_registry_creation() {
        // `_dir` (not `_`) keeps the temp directory alive until the test ends.
        let (_dir, registry) = scratch_registry();
        assert!(registry.list_providers().await.is_empty());
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let (_dir, registry) = scratch_registry();
        assert_eq!(cached_model_count(&registry).await, Some(0));

        registry.clear_cache().await;
        assert_eq!(cached_model_count(&registry).await, Some(0));
    }
}