use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::RwLock;
use std::sync::atomic::{AtomicU64, Ordering};
use async_trait::async_trait;
use crate::error::EmbeddingError;
use crate::layer1_echo::traits::EmbeddingProvider;
/// Tunable settings for [`CachedEmbeddingProvider`].
#[derive(Debug, Clone)]
pub struct EmbeddingCacheConfig {
    /// Capacity threshold: once the cache holds this many embeddings, older
    /// entries are evicted before a new one is stored.
    pub max_entries: usize,
    /// Whether the hit, miss and eviction counters are updated.
    pub track_stats: bool,
}

impl Default for EmbeddingCacheConfig {
    /// Returns a configuration with room for 10 000 entries and statistics
    /// tracking enabled.
    fn default() -> Self {
        Self::new(10_000)
    }
}

impl EmbeddingCacheConfig {
    /// Creates a configuration with the given capacity; statistics tracking
    /// starts out enabled.
    #[must_use]
    pub fn new(max_entries: usize) -> Self {
        let track_stats = true;
        Self {
            max_entries,
            track_stats,
        }
    }

    /// Builder-style setter: returns the configuration with statistics
    /// tracking switched on (`true`) or off (`false`).
    #[must_use]
    pub fn with_stats_tracking(self, track: bool) -> Self {
        Self {
            track_stats: track,
            ..self
        }
    }
}
/// Point-in-time snapshot of cache counters.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of lookups answered from the cache.
    pub hits: u64,
    /// Number of lookups that were not found in the cache.
    pub misses: u64,
    /// Number of entries removed to make room for new ones.
    pub evictions: u64,
    /// Number of entries held when the snapshot was taken.
    pub entries: usize,
}

impl CacheStats {
    /// Returns the share of lookups that were hits, as a percentage in
    /// `0.0..=100.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    #[must_use]
    // Converting counters to `f64` may round for values above 2^53; that is
    // acceptable for a percentage.
    #[allow(clippy::cast_precision_loss)]
    pub fn hit_rate(&self) -> f64 {
        // Sum in `u128` so that very large counters cannot overflow `u64`
        // (which would panic in debug builds and wrap in release builds).
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            return 0.0;
        }
        (self.hits as f64 / total as f64) * 100.0
    }
}
/// One cached embedding together with the logical timestamp of its last use.
struct CacheEntry {
/// The embedding vector stored for this key.
embedding: Vec<f32>,
/// Value taken from the owning cache's access counter when this entry was
/// inserted or last read; eviction removes the entry with the smallest value.
last_access: u64,
}
/// Wrapper around an [`EmbeddingProvider`] that keeps previously computed
/// embeddings in an in-memory map and evicts the least recently used entry
/// once the configured capacity is reached.
pub struct CachedEmbeddingProvider<P: EmbeddingProvider> {
/// Wrapped provider that computes embeddings on a cache miss.
provider: P,
/// Capacity limit and statistics switch.
config: EmbeddingCacheConfig,
/// Cached entries keyed by the 64-bit hash of the input text (the text
/// itself is not stored).
cache: RwLock<HashMap<u64, CacheEntry>>,
/// Monotonically increasing counter used as a logical clock for
/// `CacheEntry::last_access`.
access_counter: AtomicU64,
/// Lookups answered from the cache (updated only when `track_stats` is set).
hits: AtomicU64,
/// Lookups not found in the cache (updated only when `track_stats` is set).
misses: AtomicU64,
/// Entries removed to make room (updated only when `track_stats` is set).
evictions: AtomicU64,
}
impl<P: EmbeddingProvider> CachedEmbeddingProvider<P> {
    /// Wraps `provider` in an initially empty cache governed by `config`.
    #[must_use]
    pub fn new(provider: P, config: EmbeddingCacheConfig) -> Self {
        Self {
            provider,
            config,
            cache: RwLock::new(HashMap::new()),
            access_counter: AtomicU64::new(0),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            evictions: AtomicU64::new(0),
        }
    }

    /// Wraps `provider` using [`EmbeddingCacheConfig::default`].
    #[must_use]
    pub fn with_defaults(provider: P) -> Self {
        Self::new(provider, EmbeddingCacheConfig::default())
    }

    /// Returns a reference to the wrapped provider.
    #[must_use]
    pub fn provider(&self) -> &P {
        &self.provider
    }

    /// Returns a snapshot of the hit/miss/eviction counters and the current
    /// number of cached entries.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock has been poisoned.
    #[must_use]
    pub fn stats(&self) -> CacheStats {
        let cache = self.cache.read().expect("cache lock poisoned");
        CacheStats {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            evictions: self.evictions.load(Ordering::Relaxed),
            entries: cache.len(),
        }
    }

    /// Removes every cached embedding. The hit, miss and eviction counters
    /// are left as they are.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock has been poisoned.
    pub fn clear_cache(&self) {
        let mut cache = self.cache.write().expect("cache lock poisoned");
        cache.clear();
    }

    /// Computes the 64-bit cache key for `text`.
    ///
    /// Only this hash is stored, not the text, so two different texts whose
    /// hashes collide would share a single cache entry.
    fn hash_text(text: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        hasher.finish()
    }

    /// Looks up `text_hash`, returning a copy of the cached embedding on a hit.
    ///
    /// A hit refreshes the entry's recency stamp; hits and misses are counted
    /// when statistics tracking is enabled.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock has been poisoned.
    fn get_cached(&self, text_hash: u64) -> Option<Vec<f32>> {
        // A write lock is required even for lookups because a hit mutates
        // the entry's `last_access` stamp.
        let mut cache = self.cache.write().expect("cache lock poisoned");
        if let Some(entry) = cache.get_mut(&text_hash) {
            entry.last_access = self.access_counter.fetch_add(1, Ordering::Relaxed);
            if self.config.track_stats {
                self.hits.fetch_add(1, Ordering::Relaxed);
            }
            Some(entry.embedding.clone())
        } else {
            if self.config.track_stats {
                self.misses.fetch_add(1, Ordering::Relaxed);
            }
            None
        }
    }

    /// Stores `embedding` under `text_hash`, evicting least recently used
    /// entries first if the cache is at capacity.
    ///
    /// With `max_entries == 0` nothing is stored. If the key is already
    /// present its value and recency stamp are refreshed in place and no
    /// eviction takes place.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock has been poisoned.
    fn insert_cached(&self, text_hash: u64, embedding: Vec<f32>) {
        // A zero-capacity cache must stay empty; without this guard the
        // eviction loop below would break out on the empty map and the
        // insert would push the cache above its limit.
        if self.config.max_entries == 0 {
            return;
        }

        let mut cache = self.cache.write().expect("cache lock poisoned");

        // The key may already be present (two callers missed on the same text
        // concurrently, or one batch contained the text twice). Replacing the
        // value does not grow the map, so evicting another entry would be
        // both wasteful and wrong.
        if let Some(entry) = cache.get_mut(&text_hash) {
            entry.embedding = embedding;
            entry.last_access = self.access_counter.fetch_add(1, Ordering::Relaxed);
            return;
        }

        while cache.len() >= self.config.max_entries {
            // Linear scan for the entry with the oldest recency stamp.
            let lru_key = cache
                .iter()
                .min_by_key(|(_, entry)| entry.last_access)
                .map(|(key, _)| *key);
            match lru_key {
                Some(key) => {
                    cache.remove(&key);
                    if self.config.track_stats {
                        self.evictions.fetch_add(1, Ordering::Relaxed);
                    }
                }
                // Map is empty; nothing left to evict.
                None => break,
            }
        }

        let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
        cache.insert(
            text_hash,
            CacheEntry {
                embedding,
                last_access: access_time,
            },
        );
    }
}
#[async_trait]
impl<P: EmbeddingProvider> EmbeddingProvider for CachedEmbeddingProvider<P> {
    /// Returns the embedding for `text`, consulting the cache first and
    /// falling back to the wrapped provider on a miss.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped provider; nothing is
    /// cached in that case.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        let text_hash = Self::hash_text(text);
        if let Some(embedding) = self.get_cached(text_hash) {
            return Ok(embedding);
        }
        // No lock is held across this await: `get_cached` and `insert_cached`
        // each acquire and release the lock internally.
        let embedding = self.provider.embed(text).await?;
        self.insert_cached(text_hash, embedding.clone());
        Ok(embedding)
    }

    /// Returns one embedding per entry of `texts`, in input order.
    ///
    /// Cached texts are answered locally; the remaining ones are forwarded to
    /// the wrapped provider in a single batch call and then cached.
    ///
    /// # Errors
    ///
    /// Returns [`EmbeddingError::EmptyInput`] when `texts` is empty, and
    /// propagates any error from the wrapped provider.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        if texts.is_empty() {
            return Err(EmbeddingError::EmptyInput);
        }

        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        // For every miss: the output slot to fill and the already computed
        // hash, so the text does not have to be hashed a second time.
        let mut pending: Vec<(usize, u64)> = Vec::new();
        let mut uncached_texts: Vec<&str> = Vec::new();

        for (i, &text) in texts.iter().enumerate() {
            let text_hash = Self::hash_text(text);
            match self.get_cached(text_hash) {
                Some(embedding) => results[i] = Some(embedding),
                None => {
                    pending.push((i, text_hash));
                    uncached_texts.push(text);
                }
            }
        }

        // Only call the provider when something is missing; an empty batch
        // must not be forwarded.
        if !uncached_texts.is_empty() {
            let new_embeddings = self.provider.embed_batch(&uncached_texts).await?;
            // NOTE(review): if the provider returns fewer embeddings than
            // requested, `zip` stops early and `flatten` below drops the
            // unfilled slots, so the output would be shorter than `texts`.
            // Confirm the provider contract guarantees one result per input.
            for ((i, text_hash), embedding) in pending.into_iter().zip(new_embeddings) {
                self.insert_cached(text_hash, embedding.clone());
                results[i] = Some(embedding);
            }
        }

        Ok(results.into_iter().flatten().collect())
    }

    /// Returns the embedding dimension reported by the wrapped provider.
    fn dimension(&self) -> usize {
        self.provider.dimension()
    }

    /// Returns the model identifier reported by the wrapped provider.
    fn model_id(&self) -> &str {
        self.provider.model_id()
    }
}
// Compiled only when the crate is built with `--cfg disabled`.
#[cfg(disabled)]
mod tests {
    use super::*;
    use crate::layer1_echo::embedding::MockEmbeddingProvider;

    /// A second `embed` of the same text is served from the cache.
    #[tokio::test]
    async fn test_cache_hit() {
        let provider = MockEmbeddingProvider::new(64);
        let cached = CachedEmbeddingProvider::with_defaults(provider);

        let emb1 = cached.embed("test text").await.unwrap();
        assert_eq!(cached.stats().hits, 0);
        assert_eq!(cached.stats().misses, 1);

        let emb2 = cached.embed("test text").await.unwrap();
        assert_eq!(cached.stats().hits, 1);
        assert_eq!(cached.stats().misses, 1);
        assert_eq!(emb1, emb2);
    }

    /// Distinct texts occupy distinct entries.
    #[tokio::test]
    async fn test_cache_different_texts() {
        let provider = MockEmbeddingProvider::new(64);
        let cached = CachedEmbeddingProvider::with_defaults(provider);

        cached.embed("text 1").await.unwrap();
        cached.embed("text 2").await.unwrap();

        assert_eq!(cached.stats().misses, 2);
        assert_eq!(cached.stats().entries, 2);
    }

    /// A recently read entry survives eviction when the cache is full.
    #[tokio::test]
    async fn test_cache_lru_eviction() {
        let provider = MockEmbeddingProvider::new(32);
        let config = EmbeddingCacheConfig::new(3);
        let cached = CachedEmbeddingProvider::new(provider, config);

        cached.embed("text 1").await.unwrap();
        cached.embed("text 2").await.unwrap();
        cached.embed("text 3").await.unwrap();
        assert_eq!(cached.stats().entries, 3);
        assert_eq!(cached.stats().evictions, 0);

        // Touch "text 1" so it is no longer the oldest, then overflow.
        cached.embed("text 1").await.unwrap();
        cached.embed("text 4").await.unwrap();
        assert_eq!(cached.stats().entries, 3);
        assert_eq!(cached.stats().evictions, 1);

        // "text 1" must still be cached.
        let stats_before = cached.stats();
        cached.embed("text 1").await.unwrap();
        assert_eq!(cached.stats().hits, stats_before.hits + 1);
    }

    /// A batch mixing cached and new texts counts one hit and the new misses.
    #[tokio::test]
    async fn test_cache_batch_partial_hit() {
        let provider = MockEmbeddingProvider::new(32);
        let cached = CachedEmbeddingProvider::with_defaults(provider);

        cached.embed("cached text").await.unwrap();

        let texts = vec!["cached text", "new text 1", "new text 2"];
        let embeddings = cached.embed_batch(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 3);
        assert_eq!(cached.stats().hits, 1);
        // One miss from the initial `embed` plus two new texts in the batch.
        assert_eq!(cached.stats().misses, 3);
    }

    /// A fully cached batch produces only hits.
    #[tokio::test]
    async fn test_cache_batch_all_cached() {
        let provider = MockEmbeddingProvider::new(32);
        let cached = CachedEmbeddingProvider::with_defaults(provider);

        cached.embed("text 1").await.unwrap();
        cached.embed("text 2").await.unwrap();

        let stats_before = cached.stats();
        let embeddings = cached.embed_batch(&["text 1", "text 2"]).await.unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(cached.stats().hits, stats_before.hits + 2);
        assert_eq!(cached.stats().misses, stats_before.misses);
    }

    /// `clear_cache` empties the cache, so the next lookup is a miss again.
    #[tokio::test]
    async fn test_cache_clear() {
        let provider = MockEmbeddingProvider::new(32);
        let cached = CachedEmbeddingProvider::with_defaults(provider);

        cached.embed("text 1").await.unwrap();
        cached.embed("text 2").await.unwrap();
        assert_eq!(cached.stats().entries, 2);

        cached.clear_cache();
        assert_eq!(cached.stats().entries, 0);

        cached.embed("text 1").await.unwrap();
        // Two initial misses plus one after clearing.
        assert_eq!(cached.stats().misses, 3);
    }

    /// 75 hits out of 100 lookups is a 75 % hit rate.
    #[tokio::test]
    async fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 75,
            misses: 25,
            evictions: 0,
            entries: 100,
        };
        assert!((stats.hit_rate() - 75.0).abs() < 0.001);
    }

    /// With no lookups recorded the hit rate is zero.
    #[tokio::test]
    async fn test_cache_stats_hit_rate_zero() {
        let stats = CacheStats::default();
        assert!((stats.hit_rate() - 0.0).abs() < 0.001);
    }

    /// The builder sets both configuration fields.
    #[tokio::test]
    async fn test_cache_config_builder() {
        let config = EmbeddingCacheConfig::new(5000).with_stats_tracking(false);
        assert_eq!(config.max_entries, 5000);
        assert!(!config.track_stats);
    }

    // Property-based tests. Every `block_on(...)` below ends in `?`: the
    // `prop_assert*` macros report failure by returning `Err` from the async
    // block, and that `Result` has to be propagated out of the property body
    // or the failure is silently discarded.
    #[cfg(disabled)]
    mod proptest_tests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // The cache never holds more than `max_entries` entries.
            #[test]
            fn cache_size_never_exceeds_max(
                max_entries in 1usize..50,
                num_insertions in 1usize..100
            ) {
                tokio_test::block_on(async {
                    let provider = MockEmbeddingProvider::new(32);
                    let config = EmbeddingCacheConfig::new(max_entries);
                    let cached = CachedEmbeddingProvider::new(provider, config);
                    for i in 0..num_insertions {
                        let text = format!("text_{}", i);
                        let _ = cached.embed(&text).await;
                    }
                    let stats = cached.stats();
                    prop_assert!(stats.entries <= max_entries,
                        "Cache size {} exceeds max_entries {}", stats.entries, max_entries);
                    Ok(())
                })?;
            }

            // Repeated lookups of one text return the same value and count as hits.
            #[test]
            fn cache_hits_return_same_value(
                text in "[a-z]{5,20}",
                num_accesses in 2usize..10
            ) {
                tokio_test::block_on(async {
                    let provider = MockEmbeddingProvider::new(64);
                    let cached = CachedEmbeddingProvider::with_defaults(provider);
                    let first = cached.embed(&text).await.ok();
                    prop_assert!(first.is_some(), "First embed should succeed");
                    for _ in 1..num_accesses {
                        let current = cached.embed(&text).await.ok();
                        prop_assert_eq!(first.clone(), current,
                            "Cache hit should return identical value");
                    }
                    let stats = cached.stats();
                    prop_assert_eq!(stats.hits as usize, num_accesses - 1,
                        "Should have {} cache hits, got {}", num_accesses - 1, stats.hits);
                    Ok(())
                })?;
            }

            // The number of entries never exceeds the number of unique texts.
            #[test]
            fn different_texts_different_entries(
                texts in prop::collection::hash_set("[a-z]{3,10}", 2..20)
            ) {
                tokio_test::block_on(async {
                    let provider = MockEmbeddingProvider::new(32);
                    let cached = CachedEmbeddingProvider::with_defaults(provider);
                    let text_vec: Vec<_> = texts.iter().collect();
                    for text in &text_vec {
                        let _ = cached.embed(text).await;
                    }
                    let stats = cached.stats();
                    prop_assert!(stats.entries <= text_vec.len(),
                        "Cache entries {} should not exceed unique texts {}",
                        stats.entries, text_vec.len());
                    Ok(())
                })?;
            }

            // After filling the cache, a freshly read entry survives the next eviction.
            #[test]
            fn lru_evicts_least_recently_used(
                max_entries in 3usize..10
            ) {
                tokio_test::block_on(async {
                    let provider = MockEmbeddingProvider::new(32);
                    let config = EmbeddingCacheConfig::new(max_entries);
                    let cached = CachedEmbeddingProvider::new(provider, config);
                    for i in 0..max_entries {
                        let _ = cached.embed(&format!("text_{}", i)).await;
                    }
                    let _ = cached.embed("text_0").await;
                    let _ = cached.embed("new_text").await;
                    let stats = cached.stats();
                    prop_assert!(stats.evictions >= 1,
                        "Should have at least 1 eviction, got {}", stats.evictions);
                    prop_assert_eq!(stats.entries, max_entries,
                        "Cache should maintain max_entries size");
                    let hits_before = cached.stats().hits;
                    let _ = cached.embed("text_0").await;
                    let hits_after = cached.stats().hits;
                    prop_assert!(hits_after > hits_before,
                        "text_0 should still be cached (recently accessed)");
                    Ok(())
                })?;
            }

            // Hit and miss counters, and the derived rate, match the access pattern.
            #[test]
            fn hit_rate_calculation_correct(
                num_unique in 1usize..20,
                repeats in 1usize..5
            ) {
                tokio_test::block_on(async {
                    let provider = MockEmbeddingProvider::new(32);
                    let cached = CachedEmbeddingProvider::with_defaults(provider);
                    for i in 0..num_unique {
                        let _ = cached.embed(&format!("text_{}", i)).await;
                    }
                    for _ in 0..repeats {
                        for i in 0..num_unique {
                            let _ = cached.embed(&format!("text_{}", i)).await;
                        }
                    }
                    let stats = cached.stats();
                    let expected_hits = num_unique * repeats;
                    let expected_misses = num_unique;
                    prop_assert_eq!(stats.hits as usize, expected_hits,
                        "Expected {} hits, got {}", expected_hits, stats.hits);
                    prop_assert_eq!(stats.misses as usize, expected_misses,
                        "Expected {} misses, got {}", expected_misses, stats.misses);
                    #[allow(clippy::cast_precision_loss)]
                    let expected_rate = (expected_hits as f64 / (expected_hits + expected_misses) as f64) * 100.0;
                    prop_assert!((stats.hit_rate() - expected_rate).abs() < 0.1,
                        "Hit rate should be ~{:.2}%, got {:.2}%", expected_rate, stats.hit_rate());
                    Ok(())
                })?;
            }

            // Embeddings returned by a batch call match what is cached afterwards.
            #[test]
            fn batch_embed_cache_consistency(
                texts in prop::collection::vec("[a-z]{3,10}", 1..15)
            ) {
                tokio_test::block_on(async {
                    let provider = MockEmbeddingProvider::new(32);
                    let cached = CachedEmbeddingProvider::with_defaults(provider);
                    let mut individual_results = Vec::new();
                    for text in &texts {
                        individual_results.push(cached.embed(text).await.ok());
                    }
                    cached.clear_cache();
                    let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
                    let batch_results = cached.embed_batch(&text_refs).await.ok();
                    if let Some(batch) = batch_results {
                        prop_assert_eq!(batch.len(), texts.len(),
                            "Batch should return same number of embeddings");
                        for (i, text) in texts.iter().enumerate() {
                            let cached_result = cached.embed(text).await.ok();
                            prop_assert_eq!(Some(batch[i].clone()), cached_result,
                                "Cached embedding should match batch result at index {}", i);
                        }
                    }
                    Ok(())
                })?;
            }
        }
    }

    /// `dimension` is forwarded from the wrapped provider.
    #[tokio::test]
    async fn test_cache_dimension_passthrough() {
        let provider = MockEmbeddingProvider::new(256);
        let cached = CachedEmbeddingProvider::with_defaults(provider);
        assert_eq!(cached.dimension(), 256);
    }

    /// `model_id` is forwarded from the wrapped provider.
    #[tokio::test]
    async fn test_cache_model_id_passthrough() {
        let provider = MockEmbeddingProvider::new(64).with_model_id("test-model");
        let cached = CachedEmbeddingProvider::with_defaults(provider);
        assert_eq!(cached.model_id(), "test-model");
    }

    /// An empty batch is rejected with `EmptyInput`.
    #[tokio::test]
    async fn test_cache_empty_batch() {
        let provider = MockEmbeddingProvider::new(32);
        let cached = CachedEmbeddingProvider::with_defaults(provider);
        let result = cached.embed_batch(&[]).await;
        assert!(matches!(result, Err(EmbeddingError::EmptyInput)));
    }
}