use std::collections::hash_map::{DefaultHasher, Entry};
use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};

use async_trait::async_trait;

use cognis_core::embeddings::Embeddings;
use cognis_core::error::Result;
/// Snapshot of cache lookup counters, produced by `CachedEmbeddings::cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of lookups served from the cache.
    pub hits: usize,
    /// Number of lookups that fell through to the wrapped provider.
    pub misses: usize,
    /// `hits / (hits + misses)`; 0.0 when no lookups have been recorded.
    pub hit_rate: f64,
}
/// Key/value store for embedding vectors.
///
/// Implementations take `&self` for all mutating operations and must be
/// shareable across threads (`Send + Sync`), so interior synchronization is
/// the implementor's responsibility.
pub trait EmbeddingCache: Send + Sync {
    /// Returns the embedding stored under `key`, if any.
    fn get(&self, key: &str) -> Option<Vec<f32>>;
    /// Stores `embedding` under `key`, replacing any existing entry.
    fn put(&self, key: &str, embedding: Vec<f32>);
    /// Batched `get`; the result has the same length and order as `keys`.
    fn get_many(&self, keys: &[String]) -> Vec<Option<Vec<f32>>>;
    /// Batched `put`.
    fn put_many(&self, entries: &[(String, Vec<f32>)]);
    /// Removes every entry.
    fn clear(&self);
    /// Number of entries currently stored.
    fn len(&self) -> usize;
    /// True when the cache holds no entries.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Thread-safe in-memory embedding cache with an optional capacity bound.
///
/// When `max_size` is set, the oldest-inserted entries are evicted first
/// (insertion-order / FIFO — reads do not refresh an entry's position).
pub struct InMemoryEmbeddingCache {
    // Key -> embedding vector.
    // NOTE(review): the Arc wrappers look redundant — no handle to the inner
    // RwLocks is cloned out anywhere visible here; plain RwLock fields would
    // do. Confirm against the rest of the crate before simplifying.
    store: Arc<RwLock<HashMap<String, Vec<f32>>>>,
    // Keys in insertion order; front = oldest, popped on eviction.
    order: Arc<RwLock<VecDeque<String>>>,
    // When Some, eviction trims the store down to this many entries.
    max_size: Option<usize>,
}
impl InMemoryEmbeddingCache {
    /// Creates an unbounded cache.
    pub fn new() -> Self {
        Self::build(None)
    }

    /// Creates a cache that evicts oldest-inserted entries once it grows
    /// past `max_size` items.
    pub fn with_max_size(max_size: usize) -> Self {
        Self::build(Some(max_size))
    }

    // Single private constructor shared by both public entry points.
    fn build(max_size: Option<usize>) -> Self {
        Self {
            store: Arc::new(RwLock::new(HashMap::new())),
            order: Arc::new(RwLock::new(VecDeque::new())),
            max_size,
        }
    }

    // Pops oldest-inserted keys until the store fits within the bound.
    // Does nothing for unbounded caches. Lock order is store -> order,
    // matching every other site that takes both locks.
    fn evict_if_needed(&self) {
        let limit = match self.max_size {
            Some(limit) => limit,
            None => return,
        };
        let mut store = self.store.write().unwrap();
        let mut order = self.order.write().unwrap();
        while store.len() > limit {
            match order.pop_front() {
                Some(stale) => {
                    store.remove(&stale);
                }
                None => break,
            }
        }
    }
}
impl Default for InMemoryEmbeddingCache {
    /// Equivalent to [`InMemoryEmbeddingCache::new`]: an unbounded cache.
    fn default() -> Self {
        Self::new()
    }
}
impl EmbeddingCache for InMemoryEmbeddingCache {
    /// Returns a clone of the embedding stored under `key`, if any.
    fn get(&self, key: &str) -> Option<Vec<f32>> {
        let store = self.store.read().unwrap();
        store.get(key).cloned()
    }

    /// Inserts or overwrites `key`, then evicts if the cache is over its
    /// bound. Eviction order is insertion order (FIFO): re-inserting an
    /// existing key does not refresh its position in the queue.
    fn put(&self, key: &str, embedding: Vec<f32>) {
        {
            let mut store = self.store.write().unwrap();
            // Entry API: one hash lookup instead of contains_key + insert.
            match store.entry(key.to_string()) {
                Entry::Occupied(mut slot) => {
                    slot.insert(embedding);
                }
                Entry::Vacant(slot) => {
                    slot.insert(embedding);
                    // Only brand-new keys join the eviction queue; lock order
                    // is store -> order, same as evict_if_needed.
                    let mut order = self.order.write().unwrap();
                    order.push_back(key.to_string());
                }
            }
        }
        // Locks are released above before re-acquiring inside eviction.
        self.evict_if_needed();
    }

    /// Batched lookup; result is aligned index-for-index with `keys`.
    fn get_many(&self, keys: &[String]) -> Vec<Option<Vec<f32>>> {
        let store = self.store.read().unwrap();
        keys.iter().map(|k| store.get(k).cloned()).collect()
    }

    /// Batched insert. Both locks are held for the whole batch so the batch
    /// is applied atomically with respect to other writers.
    fn put_many(&self, entries: &[(String, Vec<f32>)]) {
        {
            let mut store = self.store.write().unwrap();
            let mut order = self.order.write().unwrap();
            for (key, embedding) in entries {
                // Entry API avoids the double lookup per entry.
                match store.entry(key.clone()) {
                    Entry::Occupied(mut slot) => {
                        slot.insert(embedding.clone());
                    }
                    Entry::Vacant(slot) => {
                        slot.insert(embedding.clone());
                        order.push_back(key.clone());
                    }
                }
            }
        }
        self.evict_if_needed();
    }

    /// Drops every entry and the eviction queue together.
    fn clear(&self) {
        let mut store = self.store.write().unwrap();
        let mut order = self.order.write().unwrap();
        store.clear();
        order.clear();
    }

    /// Number of entries currently stored.
    fn len(&self) -> usize {
        let store = self.store.read().unwrap();
        store.len()
    }
}
/// Derives the cache key for `text`: the 64-bit `DefaultHasher` digest of
/// the string, rendered as exactly 16 lowercase hex digits. Equal inputs
/// always produce equal keys within one build of the program.
pub fn cache_key(text: &str) -> String {
    let mut state = DefaultHasher::new();
    text.hash(&mut state);
    let digest = state.finish();
    format!("{:016x}", digest)
}
/// Decorator around an [`Embeddings`] provider that memoizes results in an
/// [`EmbeddingCache`] and tracks hit/miss counters.
pub struct CachedEmbeddings {
    // Wrapped provider, consulted only on cache misses.
    inner: Box<dyn Embeddings>,
    // Shared by embed_query and embed_documents (same cache_key scheme).
    cache: Box<dyn EmbeddingCache>,
    // Lookup counters; Relaxed ordering is sufficient — stats only.
    hits: AtomicUsize,
    misses: AtomicUsize,
}
impl CachedEmbeddings {
    /// Wraps `inner` so its results are memoized in `cache`, with counters
    /// starting at zero.
    pub fn new(inner: Box<dyn Embeddings>, cache: Box<dyn EmbeddingCache>) -> Self {
        Self {
            inner,
            cache,
            hits: AtomicUsize::new(0),
            misses: AtomicUsize::new(0),
        }
    }

    /// Returns a snapshot of the counters collected so far. The hit rate is
    /// defined as 0.0 before any lookup has happened.
    pub fn cache_stats(&self) -> CacheStats {
        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);
        let total = hits + misses;
        CacheStats {
            hits,
            misses,
            hit_rate: if total == 0 {
                0.0
            } else {
                hits as f64 / total as f64
            },
        }
    }

    /// Zeroes both counters without touching cached entries.
    pub fn reset_stats(&self) {
        for counter in [&self.hits, &self.misses] {
            counter.store(0, Ordering::Relaxed);
        }
    }

    /// Empties the underlying cache and zeroes the counters.
    pub fn clear(&self) {
        self.cache.clear();
        self.reset_stats();
    }

    /// Borrow of the underlying cache, e.g. for size inspection.
    pub fn cache(&self) -> &dyn EmbeddingCache {
        &*self.cache
    }
}
#[async_trait]
impl Embeddings for CachedEmbeddings {
    /// Embeds `texts`, serving as many as possible from the cache and
    /// forwarding only the misses to the wrapped provider in one batch.
    /// The output is aligned index-for-index with `texts`.
    async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let keys: Vec<String> = texts.iter().map(|t| cache_key(t)).collect();
        let cached = self.cache.get_many(&keys);
        // Collect (original index, text) for every position the cache missed.
        let mut pending: Vec<(usize, String)> = Vec::new();
        for (idx, slot) in cached.iter().enumerate() {
            if slot.is_none() {
                pending.push((idx, texts[idx].clone()));
            }
        }
        // Stats are recorded before calling the provider, so a provider
        // error still counts the attempted lookups.
        self.hits
            .fetch_add(texts.len() - pending.len(), Ordering::Relaxed);
        self.misses.fetch_add(pending.len(), Ordering::Relaxed);
        // One batched provider call for all misses (none if fully cached).
        let fresh = if pending.is_empty() {
            Vec::new()
        } else {
            let batch: Vec<String> = pending.iter().map(|(_, t)| t.clone()).collect();
            self.inner.embed_documents(batch).await?
        };
        // Write the new embeddings back under their precomputed keys.
        let to_store: Vec<(String, Vec<f32>)> = pending
            .iter()
            .zip(fresh.iter())
            .map(|((idx, _), emb)| (keys[*idx].clone(), emb.clone()))
            .collect();
        if !to_store.is_empty() {
            self.cache.put_many(&to_store);
        }
        // Stitch results back into input order: cache hits stay in place,
        // fresh embeddings fill the gaps in the same order they were missed.
        let mut fresh_iter = fresh.into_iter();
        let results: Vec<Vec<f32>> = cached
            .into_iter()
            .map(|slot| match slot {
                Some(hit) => hit,
                // Invariant: one fresh embedding per miss slot.
                None => fresh_iter.next().unwrap(),
            })
            .collect();
        Ok(results)
    }

    /// Embeds a single query string, consulting the cache first and storing
    /// the provider's result on a miss.
    async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        let key = cache_key(text);
        match self.cache.get(&key) {
            Some(hit) => {
                self.hits.fetch_add(1, Ordering::Relaxed);
                Ok(hit)
            }
            None => {
                self.misses.fetch_add(1, Ordering::Relaxed);
                let fresh = self.inner.embed_query(text).await?;
                self.cache.put(&key, fresh.clone());
                Ok(fresh)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::embeddings_fake::DeterministicFakeEmbedding;
    use std::sync::Arc;

    // Builds a CachedEmbeddings over an unbounded in-memory cache; `size`
    // is the embedding dimension of the deterministic fake provider.
    fn make_cached(size: usize) -> CachedEmbeddings {
        CachedEmbeddings::new(
            Box::new(DeterministicFakeEmbedding::new(size)),
            Box::new(InMemoryEmbeddingCache::new()),
        )
    }

    // Same, but the cache evicts once it holds more than `max_cache` entries.
    fn make_cached_bounded(size: usize, max_cache: usize) -> CachedEmbeddings {
        CachedEmbeddings::new(
            Box::new(DeterministicFakeEmbedding::new(size)),
            Box::new(InMemoryEmbeddingCache::with_max_size(max_cache)),
        )
    }

    // First-ever text must fall through to the provider: 1 miss, 0 hits.
    #[tokio::test]
    async fn test_cache_miss_calls_inner() {
        let cached = make_cached(8);
        let result = cached
            .embed_documents(vec!["hello".to_string()])
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 8);
        let stats = cached.cache_stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
    }

    // Re-embedding the same text yields the identical vector from the cache.
    #[tokio::test]
    async fn test_cache_hit_returns_cached_value() {
        let cached = make_cached(8);
        let first = cached
            .embed_documents(vec!["hello".to_string()])
            .await
            .unwrap();
        let second = cached
            .embed_documents(vec!["hello".to_string()])
            .await
            .unwrap();
        assert_eq!(first, second);
        let stats = cached.cache_stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    // A batch where one text is cached and one is new splits into 1 hit
    // ("hello", from the earlier call) and 1 new miss ("world").
    #[tokio::test]
    async fn test_mixed_hits_and_misses_in_batch() {
        let cached = make_cached(8);
        cached
            .embed_documents(vec!["hello".to_string()])
            .await
            .unwrap();
        let results = cached
            .embed_documents(vec!["hello".to_string(), "world".to_string()])
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 8);
        assert_eq!(results[1].len(), 8);
        let stats = cached.cache_stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 2);
    }

    // Two identical 3-text batches: 3 misses then 3 hits -> hit rate 0.5.
    #[tokio::test]
    async fn test_cache_stats_tracking() {
        let cached = make_cached(4);
        cached
            .embed_documents(vec!["a".to_string(), "b".to_string(), "c".to_string()])
            .await
            .unwrap();
        cached
            .embed_documents(vec!["a".to_string(), "b".to_string(), "c".to_string()])
            .await
            .unwrap();
        let stats = cached.cache_stats();
        assert_eq!(stats.hits, 3);
        assert_eq!(stats.misses, 3);
        assert!((stats.hit_rate - 0.5).abs() < f64::EPSILON);
    }

    // With capacity 2, inserting a, b, c evicts the oldest key (a).
    // NOTE(review): despite the test name, gets do not refresh recency, so
    // this exercises insertion-order (FIFO) eviction rather than true LRU.
    #[tokio::test]
    async fn test_lru_eviction_when_max_size_exceeded() {
        let cached = make_cached_bounded(4, 2);
        cached.embed_documents(vec!["a".to_string()]).await.unwrap();
        cached.embed_documents(vec!["b".to_string()]).await.unwrap();
        cached.embed_documents(vec!["c".to_string()]).await.unwrap();
        assert_eq!(cached.cache().len(), 2);
        let key_a = cache_key("a");
        assert!(cached.cache().get(&key_a).is_none());
        let key_b = cache_key("b");
        let key_c = cache_key("c");
        assert!(cached.cache().get(&key_b).is_some());
        assert!(cached.cache().get(&key_c).is_some());
    }

    // embed_query uses the same cache: second call is a hit, same vector.
    #[tokio::test]
    async fn test_embed_query_caching() {
        let cached = make_cached(8);
        let first = cached.embed_query("test query").await.unwrap();
        let second = cached.embed_query("test query").await.unwrap();
        assert_eq!(first, second);
        let stats = cached.cache_stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    // clear() empties the cache AND zeroes the hit/miss counters.
    #[tokio::test]
    async fn test_clear_cache_resets_stats() {
        let cached = make_cached(4);
        cached.embed_query("foo").await.unwrap();
        cached.embed_query("foo").await.unwrap();
        assert_eq!(cached.cache().len(), 1);
        assert_eq!(cached.cache_stats().hits, 1);
        cached.clear();
        assert_eq!(cached.cache().len(), 0);
        let stats = cached.cache_stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    // Ten concurrent tasks embedding distinct texts: all complete, each is
    // a miss, and all ten entries land in the shared cache.
    #[tokio::test]
    async fn test_thread_safety_concurrent_access() {
        let cached = Arc::new(make_cached(8));
        let mut handles = Vec::new();
        for i in 0..10 {
            let cached_clone = Arc::clone(&cached);
            handles.push(tokio::spawn(async move {
                let text = format!("text_{}", i);
                cached_clone.embed_query(&text).await.unwrap()
            }));
        }
        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap());
        }
        assert_eq!(results.len(), 10);
        assert_eq!(cached.cache().len(), 10);
        let stats = cached.cache_stats();
        assert_eq!(stats.misses, 10);
        assert_eq!(stats.hits, 0);
    }

    // Empty batch short-circuits: no provider call, no stats recorded.
    #[tokio::test]
    async fn test_empty_input_handling() {
        let cached = make_cached(8);
        let result = cached.embed_documents(vec![]).await.unwrap();
        assert!(result.is_empty());
        let stats = cached.cache_stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    // cache_key is deterministic for equal input, distinct for different input.
    #[tokio::test]
    async fn test_cache_key_consistency() {
        let key1 = cache_key("consistent text");
        let key2 = cache_key("consistent text");
        assert_eq!(key1, key2);
        let key3 = cache_key("different text");
        assert_ne!(key1, key3);
    }

    // embed_query and embed_documents key by the same cache_key, so a text
    // embedded via one path is a cache hit via the other.
    #[tokio::test]
    async fn test_embed_query_and_documents_share_cache() {
        let cached = make_cached(8);
        let query_result = cached.embed_query("shared text").await.unwrap();
        let doc_results = cached
            .embed_documents(vec!["shared text".to_string()])
            .await
            .unwrap();
        assert_eq!(query_result, doc_results[0]);
        let stats = cached.cache_stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);
    }

    // is_empty default impl flips after the first put.
    #[test]
    fn test_in_memory_cache_is_empty() {
        let cache = InMemoryEmbeddingCache::new();
        assert!(cache.is_empty());
        cache.put("key", vec![1.0, 2.0]);
        assert!(!cache.is_empty());
    }

    // With zero lookups the hit rate is defined as exactly 0.0.
    #[test]
    fn test_cache_stats_zero_lookups() {
        let cached = make_cached(4);
        let stats = cached.cache_stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate, 0.0);
    }
}