use std::collections::HashMap;
use std::sync::Arc;
use crate::error::{EngramError, Result};
use super::Embedder;
/// Descriptive metadata a provider reports about itself.
///
/// The value is plain data: it can be cloned, compared for equality and
/// debug-printed, and carries no reference back to the provider.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbeddingProviderInfo {
    /// Identifier of the provider; the registry uses it as the lookup key.
    pub id: String,
    /// Name of the provider (presumably a human-readable label).
    pub name: String,
    /// Name of the model the provider reports using.
    pub model: String,
    /// Number of dimensions the provider reports for its embeddings.
    pub dimensions: usize,
    /// Whether the provider reports needing an API key.
    pub requires_api_key: bool,
    /// Whether the provider reports being local.
    pub is_local: bool,
}
/// An [`Embedder`] that can additionally describe itself.
///
/// Every implementor must also implement [`Embedder`]; this trait only adds
/// the self-description used by [`EmbeddingRegistry`].
pub trait EmbeddingProvider: Embedder {
    /// Returns an owned snapshot of this provider's metadata.
    fn provider_info(&self) -> EmbeddingProviderInfo;
}
/// Collection of embedding providers addressable by id.
///
/// Providers are stored behind `Arc`, so lookups hand out shared handles
/// rather than copies of the provider.
pub struct EmbeddingRegistry {
    /// Registered providers keyed by the id from their `provider_info()`.
    providers: HashMap<String, Arc<dyn EmbeddingProvider>>,
    /// Provider ids in the order they were first registered.
    order: Vec<String>,
    /// Id of the explicitly chosen default provider, if one has been set.
    default_id: Option<String>,
}
impl EmbeddingRegistry {
    /// Creates an empty registry with no providers and no explicit default.
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
            order: Vec::new(),
            default_id: None,
        }
    }

    /// Registers `provider` under the id reported by its `provider_info()`.
    ///
    /// A previously unseen id is appended to the listing order. Registering an
    /// id that is already present replaces the stored provider while keeping
    /// the id's original position in the order.
    pub fn register(&mut self, provider: Arc<dyn EmbeddingProvider>) {
        // `provider_info()` yields an owned value, so the id is moved out of
        // it directly instead of being cloned.
        let id = provider.provider_info().id;
        // `insert` returns `None` only when the key was not present before,
        // which is exactly when the id needs a slot in the ordering.
        if self.providers.insert(id.clone(), provider).is_none() {
            self.order.push(id);
        }
    }

    /// Looks up a provider by id, returning a shared handle to it.
    ///
    /// Returns `None` when no provider is registered under `id`.
    pub fn get(&self, id: &str) -> Option<Arc<dyn EmbeddingProvider>> {
        self.providers.get(id).cloned()
    }

    /// Returns the metadata of every registered provider, in the order the
    /// ids were first registered.
    pub fn list(&self) -> Vec<EmbeddingProviderInfo> {
        self.order
            .iter()
            .filter_map(|id| self.providers.get(id))
            .map(|provider| provider.provider_info())
            .collect()
    }

    /// Returns the default provider.
    ///
    /// The provider chosen via [`set_default`](Self::set_default) wins when
    /// one is set and present in the map; otherwise the first registered
    /// provider is used. Returns `None` for an empty registry.
    pub fn default_provider(&self) -> Option<Arc<dyn EmbeddingProvider>> {
        self.default_id
            .as_deref()
            .and_then(|id| self.providers.get(id))
            .or_else(|| {
                // Fallback: whichever id was registered first.
                self.order
                    .first()
                    .and_then(|id| self.providers.get(id.as_str()))
            })
            .cloned()
    }

    /// Marks the provider registered under `id` as the default.
    ///
    /// # Errors
    ///
    /// Returns [`EngramError::InvalidInput`] and leaves the current default
    /// untouched when no provider is registered under `id`.
    pub fn set_default(&mut self, id: &str) -> Result<()> {
        if !self.providers.contains_key(id) {
            return Err(EngramError::InvalidInput(format!(
                "No embedding provider registered with id '{id}'"
            )));
        }
        self.default_id = Some(id.to_string());
        Ok(())
    }

    /// Returns the number of distinct provider ids currently registered.
    pub fn count(&self) -> usize {
        self.providers.len()
    }
}
impl Default for EmbeddingRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;

    /// Test double that reports fixed metadata and embeds every input as an
    /// all-zero vector of `info.dimensions` elements.
    struct MockProvider {
        info: EmbeddingProviderInfo,
    }

    impl MockProvider {
        /// Builds a mock whose name and model strings are derived from `id`,
        /// flagged as local and not needing an API key.
        fn new(id: &str, dimensions: usize) -> Self {
            let info = EmbeddingProviderInfo {
                id: id.to_string(),
                name: format!("Mock ({id})"),
                model: format!("mock-{id}"),
                dimensions,
                requires_api_key: false,
                is_local: true,
            };
            Self { info }
        }
    }

    impl Embedder for MockProvider {
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            let zeros = vec![0.0_f32; self.info.dimensions];
            Ok(zeros)
        }

        fn dimensions(&self) -> usize {
            self.info.dimensions
        }

        fn model_name(&self) -> &str {
            self.info.model.as_str()
        }
    }

    impl EmbeddingProvider for MockProvider {
        fn provider_info(&self) -> EmbeddingProviderInfo {
            self.info.clone()
        }
    }

    /// Shared 64-dimensional mock provider registered under `id`.
    fn make_provider(id: &str) -> Arc<dyn EmbeddingProvider> {
        let mock = MockProvider::new(id, 64);
        Arc::new(mock)
    }

    /// A registered provider can be fetched back by its id.
    #[test]
    fn test_register_and_get_by_id() {
        let mut reg = EmbeddingRegistry::new();
        reg.register(make_provider("alpha"));

        let found = reg.get("alpha");
        assert!(found.is_some(), "registered provider should be found");
        assert_eq!(found.unwrap().provider_info().id, "alpha");
    }

    /// Looking up an id that was never registered yields `None`.
    #[test]
    fn test_get_unknown_returns_none() {
        let reg = EmbeddingRegistry::new();
        assert!(reg.get("nonexistent").is_none());
    }

    /// `list` reports every registered provider.
    #[test]
    fn test_list_returns_all_providers() {
        let mut reg = EmbeddingRegistry::new();
        for &id in &["alpha", "beta", "gamma"] {
            reg.register(make_provider(id));
        }

        let infos = reg.list();
        assert_eq!(infos.len(), 3);

        let ids: Vec<&str> = infos.iter().map(|info| info.id.as_str()).collect();
        assert!(ids.contains(&"alpha"));
        assert!(ids.contains(&"beta"));
        assert!(ids.contains(&"gamma"));
    }

    /// `list` follows registration order.
    #[test]
    fn test_list_preserves_insertion_order() {
        let mut reg = EmbeddingRegistry::new();
        for &id in &["first", "second", "third"] {
            reg.register(make_provider(id));
        }

        let listed: Vec<String> = reg.list().into_iter().map(|info| info.id).collect();
        assert_eq!(listed, vec!["first", "second", "third"]);
    }

    /// Without an explicit default, the first registered provider is used.
    #[test]
    fn test_default_returns_first_registered() {
        let mut reg = EmbeddingRegistry::new();
        assert!(
            reg.default_provider().is_none(),
            "empty registry has no default"
        );

        reg.register(make_provider("first"));
        reg.register(make_provider("second"));

        let chosen = reg.default_provider().expect("should have a default");
        assert_eq!(chosen.provider_info().id, "first");
    }

    /// `set_default` overrides the first-registered fallback.
    #[test]
    fn test_set_default_changes_default() {
        let mut reg = EmbeddingRegistry::new();
        reg.register(make_provider("alpha"));
        reg.register(make_provider("beta"));

        reg.set_default("beta").expect("beta is registered");

        let chosen = reg.default_provider().expect("should have a default");
        assert_eq!(chosen.provider_info().id, "beta");
    }

    /// `set_default` refuses ids that are not registered.
    #[test]
    fn test_set_default_unknown_returns_error() {
        let mut reg = EmbeddingRegistry::new();
        let outcome = reg.set_default("does-not-exist");
        assert!(outcome.is_err(), "unknown id should return an error");
    }

    /// `count` tracks the number of registered providers.
    #[test]
    fn test_count() {
        let mut reg = EmbeddingRegistry::new();
        assert_eq!(reg.count(), 0);

        reg.register(make_provider("a"));
        assert_eq!(reg.count(), 1);

        reg.register(make_provider("b"));
        assert_eq!(reg.count(), 2);
    }

    /// Registering an existing id swaps the provider instead of adding one.
    #[test]
    fn test_register_replaces_existing_id() {
        let mut reg = EmbeddingRegistry::new();
        reg.register(make_provider("a"));

        let replacement = MockProvider {
            info: EmbeddingProviderInfo {
                id: "a".to_string(),
                name: "Updated A".to_string(),
                model: "updated-model".to_string(),
                dimensions: 128,
                requires_api_key: true,
                is_local: false,
            },
        };
        reg.register(Arc::new(replacement));

        assert_eq!(reg.count(), 1);
        let stored = reg.get("a").unwrap().provider_info();
        assert_eq!(stored.name, "Updated A");
        assert_eq!(stored.dimensions, 128);
    }

    /// The metadata struct exposes the values it was built with.
    #[test]
    fn test_provider_info_fields() {
        let meta = EmbeddingProviderInfo {
            id: "test".to_string(),
            name: "Test Provider".to_string(),
            model: "test-model-v1".to_string(),
            dimensions: 256,
            requires_api_key: true,
            is_local: false,
        };

        assert_eq!(meta.id, "test");
        assert_eq!(meta.dimensions, 256);
        assert!(meta.requires_api_key);
        assert!(!meta.is_local);
    }

    /// A provider fetched from the registry can still produce embeddings.
    #[test]
    fn test_embed_via_registry_provider() {
        let mut reg = EmbeddingRegistry::new();
        reg.register(make_provider("mock"));

        let handle = reg.get("mock").expect("mock is registered");
        let vector = handle.embed("hello world").expect("embed should succeed");
        assert_eq!(vector.len(), 64);
    }
}