use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use lattice_embed::{
CachedEmbeddingService, EmbeddingModel, EmbeddingService, NativeEmbeddingService,
};
use tokio::sync::OnceCell;
use crate::error::{RuntimeError, RuntimeResult};
/// A named factory for an [`EmbeddingService`].
///
/// Implementors must be `Send + Sync`; the registry stores them as
/// `Arc<dyn EmbedderProvider>`.
#[async_trait]
pub trait EmbedderProvider: Send + Sync {
/// Name of this provider. [`EmbedderRegistry::register`] uses it as the
/// lookup key.
fn name(&self) -> &str;
/// Dimension count reported by this provider (presumably the length of the
/// embedding vectors the built service produces — confirm with implementors).
fn dimensions(&self) -> usize;
/// Builds the embedding service.
///
/// Registry entries cache the first successfully built service, so
/// implementations should not assume this runs on every lookup.
async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>>;
}
/// A registered provider paired with the lazily-initialised service it builds.
///
/// Both fields are `Arc`s, so cloning an entry is cheap and every clone shares
/// the same provider and the same `OnceCell`: a service stored through one
/// clone is visible through all the others.
///
/// `Clone` is derived; it clones each `Arc` field, which is exactly what the
/// previous hand-written impl did.
#[derive(Clone)]
pub(crate) struct EmbedderEntry {
    /// Provider used to build the service the first time it is needed.
    provider: Arc<dyn EmbedderProvider>,
    /// Cache slot for the built service, shared between clones of this entry.
    cell: Arc<OnceCell<Arc<dyn EmbeddingService>>>,
}
/// Registry mapping provider names to their [`EmbedderEntry`].
///
/// Cloning the registry clones the map; each cloned entry still shares its
/// provider and service cache with the original, because `EmbedderEntry`
/// holds both behind `Arc`s.
#[derive(Clone, Default)]
pub struct EmbedderRegistry {
/// Entries keyed by the name the provider reported at registration time.
entries: HashMap<String, EmbedderEntry>,
}
impl EmbedderRegistry {
pub fn new() -> Self {
Self {
entries: HashMap::new(),
}
}
pub fn register<P: EmbedderProvider + 'static>(&mut self, provider: P) {
let name = provider.name().to_owned();
self.entries.insert(
name,
EmbedderEntry {
provider: Arc::new(provider),
cell: Arc::new(OnceCell::new()),
},
);
}
pub fn get_provider(&self, name: &str) -> Option<&dyn EmbedderProvider> {
self.entries.get(name).map(|e| e.provider.as_ref())
}
pub fn contains(&self, name: &str) -> bool {
self.entries.contains_key(name)
}
pub fn names(&self) -> Vec<String> {
self.entries.keys().cloned().collect()
}
pub(crate) fn get_entry(&self, name: &str) -> Option<EmbedderEntry> {
self.entries.get(name).cloned()
}
pub async fn get_service(&self, name: &str) -> RuntimeResult<Arc<dyn EmbeddingService>> {
let entry = self
.entries
.get(name)
.ok_or_else(|| RuntimeError::UnknownModel(name.to_string()))?
.clone();
entry.resolve().await
}
}
impl EmbedderEntry {
    /// Returns the service for this entry, building it on first use.
    ///
    /// Initialisation goes through `OnceCell::get_or_try_init`, so when several
    /// tasks resolve the same entry concurrently only one of them runs
    /// `build()` at a time; the others wait and then receive the very same
    /// `Arc` that was stored in the cell. (A separate "check, build, set"
    /// sequence would let racing callers each build their own instance and
    /// hand back services that are not the cached one.)
    ///
    /// # Errors
    ///
    /// If the provider's `build()` fails, the error is wrapped in
    /// `RuntimeError::Internal` with the provider name. The cell is left empty
    /// in that case, so a later call attempts the build again.
    pub(crate) async fn resolve(self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
        // Borrow the provider so the initialiser can use it without cloning the Arc.
        let provider = &self.provider;
        let svc = self
            .cell
            .get_or_try_init(|| async move {
                provider.build().await.map_err(|e| {
                    crate::error::RuntimeError::Internal(format!(
                        "EmbedderProvider '{}' build() failed: {e}",
                        provider.name()
                    ))
                })
            })
            .await?;
        // Hand out a new handle to the cached service; the cell keeps its own.
        Ok(Arc::clone(svc))
    }
}
/// [`EmbedderProvider`] for a `lattice_embed` [`EmbeddingModel`].
pub struct LatticeEmbedderProvider {
/// Model handed to `NativeEmbeddingService::with_model` when building.
model: EmbeddingModel,
/// Provider name; `new` sets it to the model's `to_string()` output.
name: String,
}
impl LatticeEmbedderProvider {
    /// Creates a provider for `model`, named after the model's `to_string()`
    /// output.
    pub fn new(model: EmbeddingModel) -> Self {
        Self {
            // Evaluated before `model` is moved into its own field below.
            name: model.to_string(),
            model,
        }
    }
}
#[async_trait]
impl EmbedderProvider for LatticeEmbedderProvider {
    /// Name derived from the model when the provider was created.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Dimension count as reported by the underlying model.
    fn dimensions(&self) -> usize {
        self.model.dimensions()
    }

    /// Builds a `NativeEmbeddingService` for the model and wraps it in a
    /// `CachedEmbeddingService` using that type's default cache.
    ///
    /// This implementation has no failure path of its own and always returns
    /// `Ok`.
    async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
        let inner = Arc::new(NativeEmbeddingService::with_model(self.model));
        let service: Arc<dyn EmbeddingService> =
            Arc::new(CachedEmbeddingService::with_default_cache(inner));
        Ok(service)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test provider with a fixed name and dimension count that records how
    /// many times `build` has been invoked.
    struct ConstVecProvider {
        name: String,
        dims: usize,
        build_calls: Arc<AtomicUsize>,
    }

    impl ConstVecProvider {
        /// Creates a provider whose build counter starts at zero.
        fn new(name: &str, dims: usize) -> Self {
            Self {
                name: String::from(name),
                dims,
                build_calls: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    /// Service that answers every text with a vector of `dims` ones.
    struct ConstVecService {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingService for ConstVecService {
        async fn embed(
            &self,
            texts: &[String],
            _model: EmbeddingModel,
        ) -> std::result::Result<Vec<Vec<f32>>, lattice_embed::EmbedError> {
            // One identical row per input text.
            let row = vec![1.0_f32; self.dims];
            Ok(vec![row; texts.len()])
        }

        fn supports_model(&self, _model: EmbeddingModel) -> bool {
            true
        }

        fn name(&self) -> &'static str {
            "const-vec-service"
        }
    }

    #[async_trait]
    impl EmbedderProvider for ConstVecProvider {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn dimensions(&self) -> usize {
            self.dims
        }

        async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
            // Count the call before handing out a fresh service.
            self.build_calls.fetch_add(1, Ordering::SeqCst);
            let service: Arc<dyn EmbeddingService> = Arc::new(ConstVecService { dims: self.dims });
            Ok(service)
        }
    }

    /// A registered provider can be found again with its name and dimensions intact.
    #[test]
    fn register_and_get_provider_round_trip() {
        let mut registry = EmbedderRegistry::new();
        registry.register(ConstVecProvider::new("mock-384", 384));

        assert!(
            registry.contains("mock-384"),
            "registered name must be present"
        );
        let found = registry
            .get_provider("mock-384")
            .expect("provider must exist");
        assert_eq!(found.name(), "mock-384");
        assert_eq!(found.dimensions(), 384);
    }

    /// Registering the same name twice keeps only the later provider.
    #[test]
    fn duplicate_name_last_wins() {
        let mut registry = EmbedderRegistry::new();
        for dims in [128, 256] {
            registry.register(ConstVecProvider::new("shared", dims));
        }

        let found = registry
            .get_provider("shared")
            .expect("provider must exist");
        assert_eq!(
            found.dimensions(),
            256,
            "last registration must win; expected dims=256"
        );
    }

    /// `names` lists every registered provider (compared after sorting).
    #[test]
    fn names_returns_all_registered() {
        let mut registry = EmbedderRegistry::new();
        for (name, dims) in [("model-a", 64), ("model-b", 128), ("model-c", 256)] {
            registry.register(ConstVecProvider::new(name, dims));
        }

        let mut listed = registry.names();
        listed.sort();
        assert_eq!(listed, vec!["model-a", "model-b", "model-c"]);
    }

    /// Looking up an unregistered name yields `UnknownModel` carrying that name.
    #[tokio::test]
    async fn get_service_unknown_name_returns_error() {
        let registry = EmbedderRegistry::new();

        let err = match registry.get_service("does-not-exist").await {
            Err(err) => err,
            Ok(_) => panic!("expected Err for unknown name, got Ok"),
        };
        assert!(
            matches!(err, RuntimeError::UnknownModel(ref n) if n == "does-not-exist"),
            "expected UnknownModel, got {err:?}"
        );
    }

    /// Repeated sequential lookups build the service only once.
    #[tokio::test]
    async fn get_service_calls_build_once() {
        let build_count = Arc::new(AtomicUsize::new(0));
        let mut registry = EmbedderRegistry::new();
        registry.register(ConstVecProvider {
            name: "cached-model".to_owned(),
            dims: 32,
            build_calls: Arc::clone(&build_count),
        });

        for _ in 0..3 {
            let _service = registry.get_service("cached-model").await.unwrap();
        }

        assert_eq!(
            build_count.load(Ordering::SeqCst),
            1,
            "build must be called exactly once regardless of get_service call count"
        );
    }
}