use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use lru::LruCache;
use sha2::{Digest, Sha256};
use crate::workspace::embeddings::{EmbeddingError, EmbeddingProvider};
/// Configuration for [`CachedEmbeddingProvider`].
#[derive(Debug, Clone)]
pub struct EmbeddingCacheConfig {
    // Maximum number of cached embeddings. A value of 0 is clamped to 1 by
    // `CachedEmbeddingProvider::new`.
    pub max_entries: usize,
}
impl Default for EmbeddingCacheConfig {
fn default() -> Self {
Self {
max_entries: crate::config::DEFAULT_EMBEDDING_CACHE_SIZE,
}
}
}
/// Decorator that caches embeddings produced by an inner [`EmbeddingProvider`]
/// in an in-memory LRU map.
pub struct CachedEmbeddingProvider {
    // The wrapped provider that actually computes embeddings on a cache miss.
    inner: Arc<dyn EmbeddingProvider>,
    // Maps SHA-256(model_name \0 text) to the embedding vector. Guarded by a
    // std Mutex; the embed paths scope the lock so it is not held across awaits.
    cache: Mutex<LruCache<[u8; 32], Vec<f32>>>,
}
impl CachedEmbeddingProvider {
    /// Wraps `inner` with an LRU embedding cache sized according to `config`.
    ///
    /// A `max_entries` of 0 is clamped to 1. Very large sizes are allowed but
    /// logged, since each entry holds a full embedding vector.
    pub fn new(inner: Arc<dyn EmbeddingProvider>, config: EmbeddingCacheConfig) -> Self {
        let capacity = NonZeroUsize::new(config.max_entries).unwrap_or(NonZeroUsize::MIN);
        if capacity.get() > 100_000 {
            tracing::warn!(
                max_entries = capacity.get(),
                "Embedding cache size exceeds 100,000 entries; memory usage may be significant"
            );
        }
        Self {
            inner,
            cache: Mutex::new(LruCache::new(capacity)),
        }
    }

    /// Number of embeddings currently cached.
    pub fn len(&self) -> usize {
        let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
        guard.len()
    }

    /// Returns `true` when no embeddings are cached.
    pub fn is_empty(&self) -> bool {
        let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
        guard.is_empty()
    }

    /// Removes every cached embedding.
    pub fn clear(&self) {
        let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
        guard.clear();
    }

    /// SHA-256 over `model_name`, a NUL separator, and `text`, so entries from
    /// different models can never collide on the same input text.
    fn cache_key(&self, text: &str) -> [u8; 32] {
        let digest = Sha256::new()
            .chain_update(self.inner.model_name().as_bytes())
            .chain_update(b"\0")
            .chain_update(text.as_bytes())
            .finalize();
        digest.into()
    }
}
#[async_trait]
impl EmbeddingProvider for CachedEmbeddingProvider {
    fn dimension(&self) -> usize {
        self.inner.dimension()
    }
    fn model_name(&self) -> &str {
        self.inner.model_name()
    }
    fn max_input_length(&self) -> usize {
        self.inner.max_input_length()
    }
    /// Embeds `text`, serving from the cache when possible.
    ///
    /// The cache lock is scoped so it is never held across the `.await` on the
    /// inner provider; concurrent misses for the same text may therefore each
    /// call the inner provider, and the re-check before insert just avoids a
    /// redundant write in that case.
    ///
    /// # Errors
    /// Propagates any error from the inner provider; failed embeddings are
    /// never cached.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        let key = self.cache_key(text);
        {
            let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
            if let Some(embedding) = guard.get(&key) {
                tracing::trace!("embedding cache hit");
                return Ok(embedding.clone());
            }
        }
        let embedding = self.inner.embed(text).await?;
        {
            let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
            // A concurrent task may have inserted this key while we awaited.
            if guard.get(&key).is_some() {
                tracing::trace!("embedding cache: concurrent insert, skipping clone");
            } else {
                guard.push(key, embedding.clone());
            }
        }
        tracing::trace!("embedding cache miss");
        Ok(embedding)
    }
    /// Embeds a batch, fetching only the cache misses from the inner provider.
    ///
    /// The output order matches `texts`. When the number of misses exceeds the
    /// cache capacity, only the last `cap` fetched embeddings are inserted so a
    /// single oversized batch does not pointlessly churn the cache.
    ///
    /// # Errors
    /// Returns [`EmbeddingError::InvalidResponse`] if the inner provider
    /// returns a different number of embeddings than requested; propagates any
    /// inner-provider error (in which case nothing new is cached).
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let keys: Vec<[u8; 32]> = texts.iter().map(|t| self.cache_key(t)).collect();
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut miss_indices: Vec<usize> = Vec::new();
        {
            let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
            for (i, key) in keys.iter().enumerate() {
                if let Some(embedding) = guard.get(key) {
                    results[i] = Some(embedding.clone());
                } else {
                    miss_indices.push(i);
                }
            }
        }
        if miss_indices.is_empty() {
            // Everything was served from cache; skip the inner provider call.
            tracing::trace!(count = texts.len(), "embedding batch: all cache hits");
        } else {
            let miss_texts: Vec<String> = miss_indices.iter().map(|&i| texts[i].clone()).collect();
            let new_embeddings = self.inner.embed_batch(&miss_texts).await?;
            if new_embeddings.len() != miss_indices.len() {
                return Err(EmbeddingError::InvalidResponse(format!(
                    "embed_batch returned {} embeddings, expected {}",
                    new_embeddings.len(),
                    miss_indices.len()
                )));
            }
            tracing::trace!(
                hits = texts.len() - miss_indices.len(),
                misses = miss_indices.len(),
                "embedding batch: partial cache"
            );
            {
                let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
                let cap = guard.cap().get();
                // Inserting more than `cap` entries would evict the earlier
                // ones immediately, so only the last `cap` misses are cached.
                let skip = miss_indices.len().saturating_sub(cap);
                for (&orig_idx, emb) in miss_indices[skip..].iter().zip(&new_embeddings[skip..]) {
                    guard.push(keys[orig_idx], emb.clone());
                }
            }
            for (orig_idx, emb) in miss_indices.iter().copied().zip(new_embeddings) {
                results[orig_idx] = Some(emb);
            }
        }
        // Every slot must have been filled by either a cache hit or a fetched
        // miss; a `None` here indicates an internal bookkeeping bug.
        results
            .into_iter()
            .enumerate()
            .map(|(i, slot)| {
                slot.ok_or_else(|| {
                    EmbeddingError::InvalidResponse(format!("embedding slot {i} was not populated"))
                })
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    // Mock provider that counts calls to each embed entry point, so tests can
    // assert whether the cache avoided hitting the inner provider.
    struct CountingMock {
        dimension: usize,
        model: String,
        embed_calls: AtomicU32,
        batch_calls: AtomicU32,
    }
    impl CountingMock {
        fn new(dimension: usize, model: &str) -> Self {
            Self {
                dimension,
                model: model.to_string(),
                embed_calls: AtomicU32::new(0),
                batch_calls: AtomicU32::new(0),
            }
        }
        fn embed_calls(&self) -> u32 {
            self.embed_calls.load(Ordering::SeqCst)
        }
        fn batch_calls(&self) -> u32 {
            self.batch_calls.load(Ordering::SeqCst)
        }
    }
    #[async_trait]
    impl EmbeddingProvider for CountingMock {
        fn dimension(&self) -> usize {
            self.dimension
        }
        fn model_name(&self) -> &str {
            &self.model
        }
        fn max_input_length(&self) -> usize {
            10_000
        }
        // Embedding values derive from the text length, so different-length
        // texts get distinct, deterministic vectors.
        async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
            self.embed_calls.fetch_add(1, Ordering::SeqCst);
            let val = text.len() as f32 / 100.0;
            Ok(vec![val; self.dimension])
        }
        async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            self.batch_calls.fetch_add(1, Ordering::SeqCst);
            texts
                .iter()
                .map(|t| {
                    let val = t.len() as f32 / 100.0;
                    Ok(vec![val; self.dimension])
                })
                .collect()
        }
    }

    // A repeated embed of the same text must be served from cache.
    #[tokio::test]
    async fn cache_hit_avoids_inner_call() {
        let inner = Arc::new(CountingMock::new(4, "test-model"));
        let cached =
            CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 });
        let r1 = cached.embed("hello").await.unwrap();
        assert_eq!(inner.embed_calls(), 1);
        let r2 = cached.embed("hello").await.unwrap();
        assert_eq!(inner.embed_calls(), 1);
        assert_eq!(r1, r2);
        assert_eq!(cached.len(), 1);
    }

    // Distinct texts each reach the inner provider and are cached separately.
    #[tokio::test]
    async fn cache_miss_calls_inner() {
        let inner = Arc::new(CountingMock::new(4, "test-model"));
        let cached =
            CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 });
        cached.embed("hello").await.unwrap();
        cached.embed("world").await.unwrap();
        assert_eq!(inner.embed_calls(), 2);
        assert_eq!(cached.len(), 2);
    }

    // The same text under different model names must hash to different keys.
    #[tokio::test]
    async fn cache_key_includes_model() {
        let inner_a = Arc::new(CountingMock::new(4, "model-a"));
        let inner_b = Arc::new(CountingMock::new(4, "model-b"));
        let cached_a = CachedEmbeddingProvider::new(
            inner_a.clone(),
            EmbeddingCacheConfig { max_entries: 100 },
        );
        let cached_b = CachedEmbeddingProvider::new(
            inner_b.clone(),
            EmbeddingCacheConfig { max_entries: 100 },
        );
        let key_a = cached_a.cache_key("hello");
        let key_b = cached_b.cache_key("hello");
        assert_ne!(key_a, key_b);
    }

    // Exceeding capacity evicts the least-recently-used entry, so re-embedding
    // the evicted text calls the inner provider again.
    #[tokio::test]
    async fn lru_eviction() {
        let inner = Arc::new(CountingMock::new(4, "test-model"));
        let cached =
            CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 2 });
        cached.embed("first").await.unwrap();
        cached.embed("second").await.unwrap();
        assert_eq!(cached.len(), 2);
        cached.embed("third").await.unwrap();
        assert_eq!(cached.len(), 2);
        assert_eq!(inner.embed_calls(), 3);
        cached.embed("first").await.unwrap();
        assert_eq!(inner.embed_calls(), 4);
    }

    // A batch with one cached and two new texts issues exactly one inner
    // batch call and caches the new results.
    #[tokio::test]
    async fn embed_batch_partial_hits() {
        let inner = Arc::new(CountingMock::new(4, "test-model"));
        let cached =
            CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 });
        cached.embed("cached").await.unwrap();
        assert_eq!(inner.embed_calls(), 1);
        let texts = vec![
            "cached".to_string(),
            "new_one".to_string(),
            "new_two".to_string(),
        ];
        let results = cached.embed_batch(&texts).await.unwrap();
        assert_eq!(inner.batch_calls(), 1);
        assert_eq!(results.len(), 3);
        assert_eq!(cached.len(), 3);
    }

    // Batch results must line up with input order even when hits and misses
    // are interleaved ("bb" is pre-cached; "a" and "ccc" are misses).
    #[tokio::test]
    async fn batch_preserves_order() {
        let inner = Arc::new(CountingMock::new(4, "test-model"));
        let cached =
            CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 });
        cached.embed("bb").await.unwrap();
        let texts = vec!["a".to_string(), "bb".to_string(), "ccc".to_string()];
        let results = cached.embed_batch(&texts).await.unwrap();
        assert_eq!(results.len(), 3);
        // CountingMock embeds text of length n as vec![n / 100.0; dimension].
        let expected_a = vec![1.0_f32 / 100.0; 4];
        let expected_bb = vec![2.0_f32 / 100.0; 4];
        let expected_ccc = vec![3.0_f32 / 100.0; 4];
        assert_eq!(results[0], expected_a);
        assert_eq!(results[1], expected_bb);
        assert_eq!(results[2], expected_ccc);
    }

    // A batch larger than the cache capacity must not grow the cache past
    // max_entries.
    #[tokio::test]
    async fn batch_exceeding_capacity_respects_max_entries() {
        let inner = Arc::new(CountingMock::new(4, "test-model"));
        let cached =
            CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 3 });
        let texts: Vec<String> = (0..5).map(|i| format!("text_{i}")).collect();
        let results = cached.embed_batch(&texts).await.unwrap();
        assert_eq!(results.len(), 5);
        let len = cached.len();
        assert!(len <= 3, "cache len {len} exceeds max 3");
    }

    // Mock provider whose first `fail_count` calls (embed or batch) return an
    // error, then succeed; used to verify errors are not cached.
    struct FailThenSucceedMock {
        dimension: usize,
        model: String,
        remaining_failures: AtomicU32,
    }
    impl FailThenSucceedMock {
        fn new(dimension: usize, fail_count: u32) -> Self {
            Self {
                dimension,
                model: "fail-mock".to_string(),
                remaining_failures: AtomicU32::new(fail_count),
            }
        }
    }
    #[async_trait]
    impl EmbeddingProvider for FailThenSucceedMock {
        fn dimension(&self) -> usize {
            self.dimension
        }
        fn model_name(&self) -> &str {
            &self.model
        }
        fn max_input_length(&self) -> usize {
            10_000
        }
        async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
            // fetch_update returns Ok only while failures remain to consume.
            let prev =
                self.remaining_failures
                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
                        if v > 0 { Some(v - 1) } else { None }
                    });
            if prev.is_ok() {
                return Err(EmbeddingError::HttpError("simulated failure".to_string()));
            }
            let val = text.len() as f32 / 100.0;
            Ok(vec![val; self.dimension])
        }
        async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            let prev =
                self.remaining_failures
                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
                        if v > 0 { Some(v - 1) } else { None }
                    });
            if prev.is_ok() {
                return Err(EmbeddingError::HttpError("simulated failure".to_string()));
            }
            texts
                .iter()
                .map(|t| {
                    let val = t.len() as f32 / 100.0;
                    Ok(vec![val; self.dimension])
                })
                .collect()
        }
    }

    // A failed embed must leave the cache empty; the retry succeeds and only
    // then is the result cached.
    #[tokio::test]
    async fn error_does_not_pollute_cache() {
        let inner = Arc::new(FailThenSucceedMock::new(4, 1));
        let cached =
            CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 });
        let err = cached.embed("hello").await;
        assert!(err.is_err());
        assert!(cached.is_empty(), "cache should be empty after error");
        let result = cached.embed("hello").await;
        assert!(result.is_ok());
        assert_eq!(cached.len(), 1);
    }

    // An empty batch short-circuits without touching the inner provider.
    #[tokio::test]
    async fn embed_batch_empty_input() {
        let inner = Arc::new(CountingMock::new(4, "test-model"));
        let cached =
            CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 });
        let results = cached.embed_batch(&[]).await.unwrap();
        assert!(results.is_empty());
        assert_eq!(inner.batch_calls(), 0);
    }

    // An all-miss batch calls inner once and caches everything; repeating the
    // same batch is then served entirely from cache.
    #[tokio::test]
    async fn embed_batch_all_misses() {
        let inner = Arc::new(CountingMock::new(4, "test-model"));
        let cached =
            CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 100 });
        let texts: Vec<String> = vec!["alpha".into(), "beta".into(), "gamma".into()];
        let results = cached.embed_batch(&texts).await.unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(inner.batch_calls(), 1, "inner called once for misses");
        assert_eq!(cached.len(), 3, "all results should be cached");
        let results2 = cached.embed_batch(&texts).await.unwrap();
        assert_eq!(results2.len(), 3);
        assert_eq!(inner.batch_calls(), 1, "no new inner calls");
    }

    // max_entries = 0 is clamped to a capacity of 1, so the cache always holds
    // at most (and here exactly) one entry.
    #[tokio::test]
    async fn zero_max_entries_clamped_to_one() {
        let inner = Arc::new(CountingMock::new(4, "test-model"));
        let cached =
            CachedEmbeddingProvider::new(inner.clone(), EmbeddingCacheConfig { max_entries: 0 });
        cached.embed("hello").await.unwrap();
        assert_eq!(cached.len(), 1);
        cached.embed("world").await.unwrap();
        assert_eq!(cached.len(), 1);
        assert_eq!(inner.embed_calls(), 2);
    }
}