use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS};
use crate::error::Result;
use crate::model::EmbeddingModel;
use async_trait::async_trait;
use std::sync::Arc;
use tracing::debug;
/// Pairs an inner service `S` with an embedding cache.
///
/// The `EmbeddingService` implementation for this type consults the cache
/// before forwarding work to the wrapped service.
pub struct CachedEmbeddingService<S> {
/// The wrapped service, held behind an `Arc`.
inner: Arc<S>,
/// Cache consulted and populated by `embed`.
cache: crate::cache::EmbeddingCache,
}
impl<S: EmbeddingService> CachedEmbeddingService<S> {
    /// Wraps `inner` together with a cache created by
    /// `EmbeddingCache::new(cache_capacity)`.
    pub fn new(inner: Arc<S>, cache_capacity: usize) -> Self {
        let cache = crate::cache::EmbeddingCache::new(cache_capacity);
        Self { inner, cache }
    }

    /// Wraps `inner` together with a cache created by
    /// `EmbeddingCache::with_default_capacity()`.
    pub fn with_default_cache(inner: Arc<S>) -> Self {
        let cache = crate::cache::EmbeddingCache::with_default_capacity();
        Self { inner, cache }
    }

    /// Returns whatever statistics the cache currently reports.
    pub fn cache_stats(&self) -> crate::cache::CacheStats {
        let Self { cache, .. } = self;
        cache.stats()
    }

    /// Asks the cache to clear itself; the wrapped service is not touched.
    pub fn clear_cache(&self) {
        let Self { cache, .. } = self;
        cache.clear();
    }
}
/// Caching front for the wrapped service: `embed` is answered from the cache
/// where possible and only the misses are forwarded to `inner`.
#[async_trait]
impl<S: EmbeddingService + 'static> EmbeddingService for CachedEmbeddingService<S> {
    /// Embeds `texts` with `model`, returning one vector per input in input order.
    ///
    /// # Errors
    ///
    /// * `EmbedError::InvalidInput` if `texts` is empty or holds more than
    ///   `DEFAULT_MAX_BATCH_SIZE` entries.
    /// * `EmbedError::TextTooLong` if any text's `len()` exceeds `MAX_TEXT_CHARS`.
    /// * `EmbedError::InferenceFailed` if the wrapped service returns a different
    ///   number of vectors than it was asked for.
    /// * Any error returned by the wrapped service itself.
    async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
        use crate::error::EmbedError;

        // --- Input validation -------------------------------------------------
        if texts.is_empty() {
            return Err(EmbedError::InvalidInput("no texts provided".into()));
        }
        if texts.len() > DEFAULT_MAX_BATCH_SIZE {
            return Err(EmbedError::InvalidInput(format!(
                "batch size {} exceeds maximum {}",
                texts.len(),
                DEFAULT_MAX_BATCH_SIZE
            )));
        }
        // NOTE(review): `String::len` is the UTF-8 byte length, while the limit is
        // named in chars — confirm which unit is intended before changing this.
        if let Some(too_long) = texts.iter().find(|t| t.len() > MAX_TEXT_CHARS) {
            return Err(EmbedError::TextTooLong {
                length: too_long.len(),
                max: MAX_TEXT_CHARS,
            });
        }

        // With the cache switched off this type is a plain pass-through.
        if !self.cache.is_enabled() {
            return self.inner.embed(texts, model).await;
        }

        let model_config = self.inner.model_config(model);
        let keys: Vec<_> = texts
            .iter()
            .map(|t| self.cache.compute_key(t, model_config))
            .collect();

        // --- Cache lookup -----------------------------------------------------
        // `results[i]` is filled for hits; `miss_indices` records every position
        // that still needs a freshly computed embedding.
        let mut results: Vec<Option<Vec<f32>>> = Vec::with_capacity(texts.len());
        let mut miss_indices: Vec<usize> = Vec::new();
        {
            // Scoped so the lookup iterator is dropped before the `.await` below.
            let cached = self.cache.get_many(&keys);
            let mut lookups = cached.into_iter();
            for index in 0..texts.len() {
                match lookups.next() {
                    Some(Some(hit)) => results.push(Some(hit.to_vec())),
                    // A `None` entry is a miss. A lookup result that is shorter
                    // than `keys` is treated as a miss as well, so no input is
                    // ever silently dropped from the output.
                    _ => {
                        miss_indices.push(index);
                        results.push(None);
                    }
                }
            }
        }

        if miss_indices.is_empty() {
            debug!("all {} texts found in cache", texts.len());
            // Every slot was filled by a hit, so `flatten` removes nothing.
            return Ok(results.into_iter().flatten().collect());
        }
        debug!(
            "{} texts cached, {} need embedding",
            texts.len() - miss_indices.len(),
            miss_indices.len()
        );

        // --- Embed the misses -------------------------------------------------
        // The wrapped service takes `&[String]`, so the missed texts are cloned
        // into a contiguous batch.
        let texts_to_embed: Vec<String> =
            miss_indices.iter().map(|&i| texts[i].clone()).collect();
        let new_embeddings = self.inner.embed(&texts_to_embed, model).await?;
        if new_embeddings.len() != miss_indices.len() {
            return Err(EmbedError::InferenceFailed(format!(
                "embedding service returned {} vectors for {} inputs",
                new_embeddings.len(),
                miss_indices.len()
            )));
        }

        // --- Merge and store --------------------------------------------------
        let mut cache_entries = Vec::with_capacity(miss_indices.len());
        for (index, embedding) in miss_indices.into_iter().zip(new_embeddings) {
            // One copy goes to the cache, the original goes to the caller.
            cache_entries.push((keys[index], embedding.clone()));
            results[index] = Some(embedding);
        }
        self.cache.put_many(cache_entries);

        // Every slot is now a hit or a fresh embedding (the length check above
        // guarantees one vector per miss), so `flatten` removes nothing.
        Ok(results.into_iter().flatten().collect())
    }

    /// Reports whether the wrapped service supports `model`.
    fn supports_model(&self, model: EmbeddingModel) -> bool {
        self.inner.supports_model(model)
    }

    /// Fixed identifier of this caching layer.
    fn name(&self) -> &'static str {
        "cached-embedding"
    }
}
// References both imported limits in a constant expression whose value is
// discarded; nothing is produced at runtime.
const _: () = {
    let _ = (DEFAULT_MAX_BATCH_SIZE, MAX_TEXT_CHARS);
};