use std::collections::HashMap;
use std::sync::Arc;
use crate::error::Result;
use crate::traits::{EmbeddingProvider, LLMProvider};
use crate::{GeminiProvider, JinaProvider, MockProvider, ProviderFactory, ProviderType};
/// Name-keyed registry of LLM providers and embedding providers.
///
/// The two kinds of provider live in independent maps, so one name (for example
/// `"mock"`) may identify an LLM provider, an embedding provider, or both.
/// Registering under a name that is already present replaces the previous entry.
///
/// [`ProviderRegistry::default`] produces an empty registry, while
/// [`ProviderRegistry::new`] pre-populates it with every provider that can be built.
#[derive(Default)]
pub struct ProviderRegistry {
    /// LLM providers keyed by their registration name.
    llm_providers: HashMap<String, Arc<dyn LLMProvider>>,
    /// Embedding providers keyed by their registration name.
    embedding_providers: HashMap<String, Arc<dyn EmbeddingProvider>>,
}

impl ProviderRegistry {
    /// Creates a registry pre-populated with every provider that can be constructed.
    ///
    /// Providers are attempted in this order, and any that fail to construct are
    /// skipped rather than reported:
    /// - `openai`, `ollama`, `lmstudio` and `vscode-copilot` via
    ///   [`ProviderFactory::create`], each registered as both an LLM and an
    ///   embedding provider;
    /// - `gemini` via `GeminiProvider::from_env` (LLM only);
    /// - `jina` via `JinaProvider::from_env` (embedding only).
    ///
    /// A [`MockProvider`] is always registered as `mock` for both roles, so the
    /// returned registry is never empty.
    ///
    /// # Errors
    ///
    /// Currently this always returns `Ok`: provider construction failures are
    /// ignored instead of being propagated.
    pub fn new() -> Result<Self> {
        let mut registry = Self::default();

        // Each factory call yields an (LLM, embedding) pair registered under one name.
        let factory_providers = [
            (ProviderType::OpenAI, "openai"),
            (ProviderType::Ollama, "ollama"),
            (ProviderType::LMStudio, "lmstudio"),
            (ProviderType::VsCodeCopilot, "vscode-copilot"),
        ];
        for (kind, name) in factory_providers {
            // Best effort: a backend that cannot be built simply stays unregistered.
            if let Ok((llm, embed)) = ProviderFactory::create(kind) {
                registry.register_llm(name, llm);
                registry.register_embedding(name, embed);
            }
        }

        if let Ok(provider) = GeminiProvider::from_env() {
            registry.register_llm("gemini", Arc::new(provider));
        }
        if let Ok(provider) = JinaProvider::from_env() {
            registry.register_embedding("jina", Arc::new(provider));
        }

        // One mock instance backs both roles.
        let mock = Arc::new(MockProvider::new());
        registry.register_llm("mock", mock.clone());
        registry.register_embedding("mock", mock);

        Ok(registry)
    }

    /// Registers `provider` as the LLM provider called `name`, replacing any
    /// LLM provider previously registered under that name.
    pub fn register_llm(&mut self, name: impl Into<String>, provider: Arc<dyn LLMProvider>) {
        self.llm_providers.insert(name.into(), provider);
    }

    /// Registers `provider` as the embedding provider called `name`, replacing
    /// any embedding provider previously registered under that name.
    pub fn register_embedding(
        &mut self,
        name: impl Into<String>,
        provider: Arc<dyn EmbeddingProvider>,
    ) {
        self.embedding_providers.insert(name.into(), provider);
    }

    /// Returns a shared handle (an `Arc` clone) to the LLM provider called
    /// `name`, or `None` if no such provider is registered.
    pub fn get_llm(&self, name: &str) -> Option<Arc<dyn LLMProvider>> {
        self.llm_providers.get(name).cloned()
    }

    /// Returns a shared handle (an `Arc` clone) to the embedding provider called
    /// `name`, or `None` if no such provider is registered.
    pub fn get_embedding(&self, name: &str) -> Option<Arc<dyn EmbeddingProvider>> {
        self.embedding_providers.get(name).cloned()
    }

    /// Returns the names of all registered LLM providers, sorted ascending.
    pub fn list_llm(&self) -> Vec<String> {
        Self::sorted_names(&self.llm_providers)
    }

    /// Returns the names of all registered embedding providers, sorted ascending.
    pub fn list_embedding(&self) -> Vec<String> {
        Self::sorted_names(&self.embedding_providers)
    }

    /// Returns `true` if an LLM provider is registered under `name`.
    pub fn has_llm(&self, name: &str) -> bool {
        self.llm_providers.contains_key(name)
    }

    /// Returns `true` if an embedding provider is registered under `name`.
    pub fn has_embedding(&self, name: &str) -> bool {
        self.embedding_providers.contains_key(name)
    }

    /// Unregisters the LLM provider called `name`, returning it if it was present.
    pub fn remove_llm(&mut self, name: &str) -> Option<Arc<dyn LLMProvider>> {
        self.llm_providers.remove(name)
    }

    /// Unregisters the embedding provider called `name`, returning it if it was present.
    pub fn remove_embedding(&mut self, name: &str) -> Option<Arc<dyn EmbeddingProvider>> {
        self.embedding_providers.remove(name)
    }

    /// Number of registered LLM providers.
    pub fn llm_count(&self) -> usize {
        self.llm_providers.len()
    }

    /// Number of registered embedding providers.
    pub fn embedding_count(&self) -> usize {
        self.embedding_providers.len()
    }

    /// Removes every LLM provider; embedding providers are left untouched.
    pub fn clear_llm(&mut self) {
        self.llm_providers.clear();
    }

    /// Removes every embedding provider; LLM providers are left untouched.
    pub fn clear_embedding(&mut self) {
        self.embedding_providers.clear();
    }

    /// Collects the keys of `map` into a sorted `Vec`.
    ///
    /// `HashMap` iteration order is arbitrary and differs between runs, so
    /// sorting makes listings deterministic for callers and for display.
    fn sorted_names<V>(map: &HashMap<String, V>) -> Vec<String> {
        let mut names: Vec<String> = map.keys().cloned().collect();
        names.sort_unstable();
        names
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_registry_new() {
        let registry = ProviderRegistry::new();
        assert!(registry.is_ok());
        let registry = registry.unwrap();
        assert!(registry.has_llm("mock"));
        assert!(registry.has_embedding("mock"));
    }

    #[tokio::test]
    async fn test_get_mock_provider() {
        let registry = ProviderRegistry::new().unwrap();
        let mock = registry.get_llm("mock");
        assert!(mock.is_some());
        let embed_mock = registry.get_embedding("mock");
        assert!(embed_mock.is_some());
    }

    #[tokio::test]
    async fn test_get_nonexistent_provider() {
        let registry = ProviderRegistry::new().unwrap();
        let unknown = registry.get_llm("nonexistent");
        assert!(unknown.is_none());
        let unknown_embed = registry.get_embedding("nonexistent");
        assert!(unknown_embed.is_none());
    }

    #[tokio::test]
    async fn test_register_custom_llm_provider() {
        let mut registry = ProviderRegistry::new().unwrap();
        let mock = Arc::new(MockProvider::new());
        registry.register_llm("custom", mock);
        assert!(registry.has_llm("custom"));
        assert!(registry.get_llm("custom").is_some());
    }

    #[tokio::test]
    async fn test_register_custom_embedding_provider() {
        let mut registry = ProviderRegistry::new().unwrap();
        let mock = Arc::new(MockProvider::new());
        registry.register_embedding("custom_embed", mock);
        assert!(registry.has_embedding("custom_embed"));
        assert!(registry.get_embedding("custom_embed").is_some());
    }

    #[tokio::test]
    async fn test_list_providers() {
        let registry = ProviderRegistry::new().unwrap();
        let llm_names = registry.list_llm();
        assert!(!llm_names.is_empty());
        assert!(llm_names.contains(&"mock".to_string()));
        let embed_names = registry.list_embedding();
        assert!(!embed_names.is_empty());
        assert!(embed_names.contains(&"mock".to_string()));
    }

    #[tokio::test]
    async fn test_remove_provider() {
        let mut registry = ProviderRegistry::new().unwrap();
        let mock = Arc::new(MockProvider::new());
        registry.register_llm("to_remove", mock);
        assert!(registry.has_llm("to_remove"));
        let removed = registry.remove_llm("to_remove");
        assert!(removed.is_some());
        assert!(!registry.has_llm("to_remove"));
    }

    #[tokio::test]
    async fn test_provider_count() {
        let registry = ProviderRegistry::new().unwrap();
        assert!(registry.llm_count() > 0);
        assert!(registry.embedding_count() > 0);
    }

    #[tokio::test]
    async fn test_overwrite_provider() {
        let mut registry = ProviderRegistry::new().unwrap();
        let original = registry.get_llm("mock").unwrap();
        let count_before = registry.llm_count();

        // Coerce once and register a clone, so the stored handle shares this
        // exact allocation and can be identified with `Arc::ptr_eq`.
        let new_mock: Arc<dyn LLMProvider> = Arc::new(MockProvider::new());
        registry.register_llm("mock", Arc::clone(&new_mock));

        // Re-registering an existing name replaces the entry instead of adding one.
        let updated = registry.get_llm("mock").unwrap();
        assert_eq!(registry.llm_count(), count_before);
        assert!(Arc::ptr_eq(&updated, &new_mock));
        assert!(!Arc::ptr_eq(&updated, &original));
    }

    #[test]
    fn test_registry_default_is_empty() {
        let registry = ProviderRegistry::default();
        assert_eq!(registry.llm_count(), 0);
        assert_eq!(registry.embedding_count(), 0);
        assert!(registry.list_llm().is_empty());
        assert!(registry.list_embedding().is_empty());
    }

    #[test]
    fn test_has_on_empty_registry() {
        let registry = ProviderRegistry::default();
        assert!(!registry.has_llm("anything"));
        assert!(!registry.has_embedding("anything"));
    }

    #[test]
    fn test_get_on_empty_registry() {
        let registry = ProviderRegistry::default();
        assert!(registry.get_llm("mock").is_none());
        assert!(registry.get_embedding("mock").is_none());
    }

    #[test]
    fn test_remove_nonexistent_llm() {
        let mut registry = ProviderRegistry::default();
        let removed = registry.remove_llm("nonexistent");
        assert!(removed.is_none());
    }

    #[test]
    fn test_remove_nonexistent_embedding() {
        let mut registry = ProviderRegistry::default();
        let removed = registry.remove_embedding("nonexistent");
        assert!(removed.is_none());
    }

    #[test]
    fn test_remove_embedding_provider() {
        let mut registry = ProviderRegistry::default();
        let mock = Arc::new(MockProvider::new());
        registry.register_embedding("emb1", mock);
        assert!(registry.has_embedding("emb1"));
        let removed = registry.remove_embedding("emb1");
        assert!(removed.is_some());
        assert!(!registry.has_embedding("emb1"));
    }

    #[test]
    fn test_clear_llm() {
        let mut registry = ProviderRegistry::default();
        let mock = Arc::new(MockProvider::new());
        registry.register_llm("a", mock.clone());
        registry.register_llm("b", mock);
        assert_eq!(registry.llm_count(), 2);
        registry.clear_llm();
        assert_eq!(registry.llm_count(), 0);
        assert!(registry.list_llm().is_empty());
    }

    #[test]
    fn test_clear_embedding() {
        let mut registry = ProviderRegistry::default();
        let mock = Arc::new(MockProvider::new());
        registry.register_embedding("x", mock.clone());
        registry.register_embedding("y", mock);
        assert_eq!(registry.embedding_count(), 2);
        registry.clear_embedding();
        assert_eq!(registry.embedding_count(), 0);
        assert!(registry.list_embedding().is_empty());
    }

    #[test]
    fn test_register_multiple_custom_providers() {
        let mut registry = ProviderRegistry::default();
        for i in 0..5 {
            let mock = Arc::new(MockProvider::new());
            registry.register_llm(format!("custom_{}", i), mock);
        }
        assert_eq!(registry.llm_count(), 5);
        for i in 0..5 {
            assert!(registry.has_llm(&format!("custom_{}", i)));
        }
    }

    #[test]
    fn test_clear_llm_does_not_affect_embedding() {
        let mut registry = ProviderRegistry::default();
        let mock = Arc::new(MockProvider::new());
        registry.register_llm("shared", mock.clone());
        registry.register_embedding("shared", mock);
        assert_eq!(registry.llm_count(), 1);
        assert_eq!(registry.embedding_count(), 1);
        registry.clear_llm();
        assert_eq!(registry.llm_count(), 0);
        assert_eq!(registry.embedding_count(), 1);
    }

    #[test]
    fn test_clear_embedding_does_not_affect_llm() {
        let mut registry = ProviderRegistry::default();
        let mock = Arc::new(MockProvider::new());
        registry.register_llm("shared", mock.clone());
        registry.register_embedding("shared", mock);
        registry.clear_embedding();
        assert_eq!(registry.llm_count(), 1);
        assert_eq!(registry.embedding_count(), 0);
    }

    #[test]
    fn test_get_returns_cloned_arc() {
        let mut registry = ProviderRegistry::default();
        let mock = Arc::new(MockProvider::new());
        registry.register_llm("test", mock);
        let p1 = registry.get_llm("test").unwrap();
        let p2 = registry.get_llm("test").unwrap();
        assert!(Arc::ptr_eq(&p1, &p2));
    }

    #[test]
    fn test_llm_count_after_removal() {
        let mut registry = ProviderRegistry::default();
        let mock = Arc::new(MockProvider::new());
        registry.register_llm("a", mock.clone());
        registry.register_llm("b", mock);
        assert_eq!(registry.llm_count(), 2);
        registry.remove_llm("a");
        assert_eq!(registry.llm_count(), 1);
    }
}