use super::batch::EmbeddedChunk;
use crate::chunker::Chunk;
use anyhow::Result;
use moka::sync::Cache;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Memory-bounded cache of chunk embeddings, keyed by `Chunk::hash`.
///
/// Chunks with equal `hash` values share a single entry. Capacity is enforced by `moka`,
/// which weighs each entry by the byte size of its `f32` embedding. Lookups through
/// [`get`](Self::get) update the hit/miss counters reported by [`stats`](Self::stats).
pub struct EmbeddingCache {
    cache: Cache<String, Arc<Vec<f32>>>,
    hits: AtomicU64,
    misses: AtomicU64,
    // Only read by `stats()` and the unit tests.
    #[allow(dead_code)]
    max_memory_mb: usize,
}

impl EmbeddingCache {
    /// Embedding width assumed when estimating `CacheStats::max_entries` in `stats()`.
    const ASSUMED_EMBEDDING_DIMS: usize = 384;

    /// Creates a cache limited to `crate::constants::DEFAULT_CACHE_MAX_MEMORY_MB` MiB.
    pub fn new() -> Self {
        Self::with_memory_limit_mb(crate::constants::DEFAULT_CACHE_MAX_MEMORY_MB)
    }

    /// Creates a cache whose embedding payloads may total at most `max_memory_mb` MiB
    /// (1 MiB = 1024 * 1024 bytes).
    ///
    /// Only the `f32` payload is weighed. Keys and per-entry bookkeeping are not counted,
    /// so actual memory use is somewhat higher than the limit.
    pub fn with_memory_limit_mb(max_memory_mb: usize) -> Self {
        Self {
            cache: Self::build_weighted_cache(max_memory_mb),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            max_memory_mb,
        }
    }

    /// Builds a `moka` cache whose summed entry weights (embedding bytes) are capped at
    /// `max_memory_mb` MiB.
    fn build_weighted_cache(max_memory_mb: usize) -> Cache<String, Arc<Vec<f32>>> {
        // Multiply in u64 and saturate. The plain `usize` product could overflow on
        // 32-bit targets, and would panic in debug builds, for large limits.
        let max_weight = (max_memory_mb as u64).saturating_mul(1024 * 1024);
        Cache::builder()
            .max_capacity(max_weight)
            .weigher(|_key: &String, value: &Arc<Vec<f32>>| {
                // moka weights are `u32`; clamp oversized values instead of wrapping.
                let bytes = value.len().saturating_mul(std::mem::size_of::<f32>());
                bytes.min(u32::MAX as usize) as u32
            })
            .build()
    }

    /// Returns a copy of the cached embedding for `chunk` (looked up by `chunk.hash`),
    /// recording a hit or a miss.
    pub fn get(&self, chunk: &Chunk) -> Option<Vec<f32>> {
        if let Some(embedding) = self.cache.get(&chunk.hash) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Some(embedding.as_ref().clone())
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            None
        }
    }

    /// Stores `embedding` under `chunk.hash`, replacing any existing entry.
    #[allow(dead_code)]
    pub fn put(&self, chunk: &Chunk, embedding: Vec<f32>) {
        self.cache.insert(chunk.hash.clone(), Arc::new(embedding));
    }

    /// Stores a copy of `embedded.embedding` under `embedded.chunk.hash`.
    pub fn put_embedded(&self, embedded: &EmbeddedChunk) {
        self.cache.insert(
            embedded.chunk.hash.clone(),
            Arc::new(embedded.embedding.clone()),
        );
    }

    /// Returns `true` if an entry exists for `chunk.hash`. Does not touch the hit/miss counters.
    #[allow(dead_code)]
    pub fn contains(&self, chunk: &Chunk) -> bool {
        self.cache.contains_key(&chunk.hash)
    }

    /// Returns a snapshot of the counters and the configured capacity.
    ///
    /// `size` is read from `entry_count()` without first running pending maintenance,
    /// so it can lag behind recent writes. Use [`len`](Self::len) for a synced count.
    /// `max_entries` is an estimate that assumes 384-dimensional embeddings.
    #[allow(dead_code)]
    pub fn stats(&self) -> CacheStats {
        let assumed_entry_bytes = Self::ASSUMED_EMBEDDING_DIMS * std::mem::size_of::<f32>();
        CacheStats {
            size: self.cache.entry_count() as usize,
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            max_memory_mb: self.max_memory_mb,
            // Saturate for the same overflow reason as in `build_weighted_cache`.
            max_entries: self.max_memory_mb.saturating_mul(1024 * 1024) / assumed_entry_bytes,
        }
    }

    /// Removes every entry and resets the hit/miss counters.
    #[allow(dead_code)]
    pub fn clear(&self) {
        self.cache.invalidate_all();
        self.cache.run_pending_tasks();
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
    }

    /// Number of entries after pending maintenance (inserts, evictions) has been applied.
    #[allow(dead_code)]
    pub fn len(&self) -> usize {
        self.cache.run_pending_tasks();
        self.cache.entry_count() as usize
    }

    /// Returns `true` when no entries remain after pending maintenance has been applied.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.cache.run_pending_tasks();
        self.cache.entry_count() == 0
    }

    /// Total weight of cached entries, i.e. the embedding payload bytes.
    #[allow(dead_code)]
    pub fn memory_usage_bytes(&self) -> usize {
        self.cache.run_pending_tasks();
        self.cache.weighted_size() as usize
    }

    /// [`memory_usage_bytes`](Self::memory_usage_bytes) expressed in MiB.
    #[allow(dead_code)]
    pub fn memory_usage_mb(&self) -> f64 {
        self.memory_usage_bytes() as f64 / (1024.0 * 1024.0)
    }
}

impl Default for EmbeddingCache {
    fn default() -> Self {
        Self::new()
    }
}
/// Memory-bounded cache mapping query strings to their embeddings.
///
/// Capacity is enforced by `moka`, which weighs each entry by the byte size of its `f32`
/// embedding. Lookups through [`get`](Self::get) update the hit/miss counters reported
/// by [`stats`](Self::stats).
pub struct QueryCache {
    cache: Cache<String, Arc<Vec<f32>>>,
    hits: AtomicU64,
    misses: AtomicU64,
}

impl QueryCache {
    /// Memory budget used by [`QueryCache::new`], in MiB.
    const DEFAULT_MAX_MEMORY_MB: usize = 50;

    /// Creates a cache limited to 50 MiB of embedding payload.
    pub fn new() -> Self {
        Self::with_memory_limit_mb(Self::DEFAULT_MAX_MEMORY_MB)
    }

    /// Creates a cache whose embedding payloads may total at most `max_memory_mb` MiB
    /// (1 MiB = 1024 * 1024 bytes). Keys are not counted toward the limit.
    pub fn with_memory_limit_mb(max_memory_mb: usize) -> Self {
        Self {
            cache: Self::build_weighted_cache(max_memory_mb),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    /// Builds a `moka` cache whose summed entry weights (embedding bytes) are capped at
    /// `max_memory_mb` MiB.
    fn build_weighted_cache(max_memory_mb: usize) -> Cache<String, Arc<Vec<f32>>> {
        // Multiply in u64 and saturate. The plain `usize` product could overflow on
        // 32-bit targets, and would panic in debug builds, for large limits.
        let max_weight = (max_memory_mb as u64).saturating_mul(1024 * 1024);
        Cache::builder()
            .max_capacity(max_weight)
            .weigher(|_key: &String, value: &Arc<Vec<f32>>| {
                // moka weights are `u32`; clamp oversized values instead of wrapping.
                let bytes = value.len().saturating_mul(std::mem::size_of::<f32>());
                bytes.min(u32::MAX as usize) as u32
            })
            .build()
    }

    /// Returns a copy of the cached embedding for `query`, recording a hit or a miss.
    pub fn get(&self, query: &str) -> Option<Vec<f32>> {
        if let Some(embedding) = self.cache.get(query) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Some(embedding.as_ref().clone())
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            None
        }
    }

    /// Stores `embedding` for `query`, replacing any existing entry.
    pub fn put(&self, query: &str, embedding: Vec<f32>) {
        self.cache.insert(query.to_string(), Arc::new(embedding));
    }

    /// Returns `true` if `query` has an entry. Does not touch the hit/miss counters.
    #[allow(dead_code)]
    pub fn contains(&self, query: &str) -> bool {
        self.cache.contains_key(query)
    }

    /// Returns a snapshot of the hit/miss counters and entry count.
    ///
    /// `size` is read from `entry_count()` without first running pending maintenance,
    /// so it can lag behind recent writes. Use [`len`](Self::len) for a synced count.
    pub fn stats(&self) -> QueryCacheStats {
        QueryCacheStats {
            size: self.cache.entry_count() as usize,
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
        }
    }

    /// Removes every entry and resets the hit/miss counters.
    #[allow(dead_code)]
    pub fn clear(&self) {
        self.cache.invalidate_all();
        self.cache.run_pending_tasks();
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
    }

    /// Number of entries after pending maintenance (inserts, evictions) has been applied.
    #[allow(dead_code)]
    pub fn len(&self) -> usize {
        self.cache.run_pending_tasks();
        self.cache.entry_count() as usize
    }

    /// Returns `true` when no entries remain after pending maintenance has been applied.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.cache.run_pending_tasks();
        self.cache.entry_count() == 0
    }

    /// Total weight of cached entries, i.e. the embedding payload bytes.
    #[allow(dead_code)]
    pub fn memory_usage_bytes(&self) -> usize {
        self.cache.run_pending_tasks();
        self.cache.weighted_size() as usize
    }

    /// [`memory_usage_bytes`](Self::memory_usage_bytes) expressed in MiB.
    #[allow(dead_code)]
    pub fn memory_usage_mb(&self) -> f64 {
        self.memory_usage_bytes() as f64 / (1024.0 * 1024.0)
    }
}

impl Default for QueryCache {
    fn default() -> Self {
        Self::new()
    }
}
/// Point-in-time counters taken from a `QueryCache`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct QueryCacheStats {
    /// Entry count reported by the cache when the snapshot was taken.
    pub size: usize,
    /// Lookups that found an embedding.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
}

impl QueryCacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`. Returns `0.0` when no lookups happened.
    #[allow(dead_code)]
    pub fn hit_rate(&self) -> f32 {
        match self.total_requests() {
            0 => 0.0,
            lookups => self.hits as f32 / lookups as f32,
        }
    }

    /// Number of lookups recorded: hits plus misses.
    #[allow(dead_code)]
    pub fn total_requests(&self) -> u64 {
        self.hits + self.misses
    }
}
/// Point-in-time counters and capacity figures taken from an `EmbeddingCache`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct CacheStats {
    /// Entry count reported by the cache when the snapshot was taken.
    #[allow(dead_code)]
    pub size: usize,
    /// Lookups that found an embedding.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Configured memory budget in MiB.
    #[allow(dead_code)]
    pub max_memory_mb: usize,
    /// Estimated number of entries that fit in `max_memory_mb`.
    #[allow(dead_code)]
    pub max_entries: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`. Returns `0.0` when no lookups happened.
    #[allow(dead_code)]
    pub fn hit_rate(&self) -> f32 {
        let lookups = self.total_requests();
        if lookups == 0 {
            0.0
        } else {
            self.hits as f32 / lookups as f32
        }
    }

    /// Number of lookups recorded: hits plus misses.
    #[allow(dead_code)]
    pub fn total_requests(&self) -> u64 {
        self.hits + self.misses
    }
}
/// Wraps a `BatchEmbedder` with an [`EmbeddingCache`] so that chunks whose content hash
/// is still cached are not sent to the batch embedder again.
pub struct CachedBatchEmbedder {
    pub batch_embedder: super::batch::BatchEmbedder,
    #[allow(dead_code)]
    cache: EmbeddingCache,
}

impl CachedBatchEmbedder {
    /// Wraps `batch_embedder` with a cache using the default memory limit.
    #[allow(dead_code)]
    pub fn new(batch_embedder: super::batch::BatchEmbedder) -> Self {
        Self {
            batch_embedder,
            cache: EmbeddingCache::new(),
        }
    }

    /// Wraps `batch_embedder` with a cache limited to `max_memory_mb` MiB.
    pub fn with_memory_limit(
        batch_embedder: super::batch::BatchEmbedder,
        max_memory_mb: usize,
    ) -> Self {
        Self {
            batch_embedder,
            cache: EmbeddingCache::with_memory_limit_mb(max_memory_mb),
        }
    }

    /// Embeds `chunks`, serving cached embeddings where possible and batching the rest.
    ///
    /// The result follows the input order: element `i` corresponds to `chunks[i]`, as long
    /// as the batch embedder returns one result per submitted chunk, in submission order
    /// (NOTE(review): confirm this contract of `BatchEmbedder::embed_chunks`). Fresh
    /// embeddings are added to the cache.
    ///
    /// # Errors
    /// Propagates any error from the underlying batch embedder.
    pub fn embed_chunks(&mut self, chunks: Vec<Chunk>) -> Result<Vec<EmbeddedChunk>> {
        if chunks.is_empty() {
            return Ok(Vec::new());
        }
        let total = chunks.len();

        // One slot per input: `Some` for a cache hit, `None` for a chunk still to be
        // embedded. The slots are what let the final result keep the input order.
        let mut slots: Vec<Option<EmbeddedChunk>> = Vec::with_capacity(total);
        let mut chunks_to_embed = Vec::new();
        // `chunks` is owned, so move each chunk instead of cloning it.
        for chunk in chunks {
            match self.cache.get(&chunk) {
                Some(embedding) => slots.push(Some(EmbeddedChunk::new(chunk, embedding))),
                None => {
                    chunks_to_embed.push(chunk);
                    slots.push(None);
                }
            }
        }

        if chunks_to_embed.is_empty() {
            return Ok(slots.into_iter().flatten().collect());
        }

        let newly_embedded = self.batch_embedder.embed_chunks(chunks_to_embed)?;
        for embedded in &newly_embedded {
            self.cache.put_embedded(embedded);
        }

        // Fill each miss slot with the next fresh result, in order.
        let mut fresh = newly_embedded.into_iter();
        let mut embedded_chunks = Vec::with_capacity(total);
        for slot in slots {
            match slot {
                Some(hit) => embedded_chunks.push(hit),
                None => {
                    if let Some(embedded) = fresh.next() {
                        embedded_chunks.push(embedded);
                    }
                }
            }
        }
        // If the embedder returned more results than requested, keep them rather than
        // dropping them silently.
        embedded_chunks.extend(fresh);
        Ok(embedded_chunks)
    }

    /// Embeds a single chunk, returning the cached embedding when available and
    /// caching a freshly computed one otherwise.
    ///
    /// # Errors
    /// Propagates any error from the underlying batch embedder.
    #[allow(dead_code)]
    pub fn embed_chunk(&mut self, chunk: Chunk) -> Result<EmbeddedChunk> {
        if let Some(embedding) = self.cache.get(&chunk) {
            return Ok(EmbeddedChunk::new(chunk, embedding));
        }
        let embedded = self.batch_embedder.embed_chunk(chunk)?;
        self.cache.put_embedded(&embedded);
        Ok(embedded)
    }

    /// Snapshot of the underlying cache's statistics.
    #[allow(dead_code)]
    pub fn cache_stats(&self) -> CacheStats {
        self.cache.stats()
    }

    /// Empties the cache and resets its hit/miss counters.
    #[allow(dead_code)]
    pub fn clear_cache(&self) {
        self.cache.clear();
    }

    /// Embedding dimensionality reported by the wrapped batch embedder.
    pub fn dimensions(&self) -> usize {
        self.batch_embedder.dimensions()
    }

    /// Borrows the underlying cache.
    #[allow(dead_code)]
    pub fn cache(&self) -> &EmbeddingCache {
        &self.cache
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::ChunkKind;

    /// Builds a function chunk with the given content and location.
    fn function_chunk(content: &str, start: usize, end: usize, path: &str) -> Chunk {
        Chunk::new(
            content.to_string(),
            start,
            end,
            ChunkKind::Function,
            path.to_string(),
        )
    }

    #[test]
    fn test_cache_creation() {
        let cache = EmbeddingCache::new();
        assert_eq!(
            cache.max_memory_mb,
            crate::constants::DEFAULT_CACHE_MAX_MEMORY_MB
        );
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_with_memory_limit() {
        let cache = EmbeddingCache::with_memory_limit_mb(100);
        assert_eq!(cache.max_memory_mb, 100);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cache_put_get() {
        let cache = EmbeddingCache::new();
        let chunk = function_chunk("fn test() {}", 0, 1, "test.rs");
        let embedding = vec![1.0, 2.0, 3.0];
        assert!(cache.get(&chunk).is_none());
        cache.put(&chunk, embedding.clone());
        assert!(cache.contains(&chunk));
        let retrieved = cache.get(&chunk).unwrap();
        assert_eq!(retrieved, embedding);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache_stats() {
        let cache = EmbeddingCache::new();
        let chunk1 = function_chunk("fn test1() {}", 0, 1, "test.rs");
        let chunk2 = function_chunk("fn test2() {}", 2, 3, "test.rs");
        cache.put(&chunk1, vec![1.0, 2.0, 3.0]);
        cache.get(&chunk1);
        cache.get(&chunk2);
        cache.get(&chunk1);
        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.total_requests(), 3);
        assert!((stats.hit_rate() - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_cache_clear() {
        let cache = EmbeddingCache::new();
        let chunk = function_chunk("fn test() {}", 0, 1, "test.rs");
        cache.put(&chunk, vec![1.0, 2.0, 3.0]);
        assert_eq!(cache.len(), 1);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_embedded_chunk_put() {
        let cache = EmbeddingCache::new();
        let chunk = function_chunk("fn test() {}", 0, 1, "test.rs");
        let embedded = EmbeddedChunk::new(chunk.clone(), vec![1.0, 2.0, 3.0]);
        cache.put_embedded(&embedded);
        assert!(cache.contains(&chunk));
        let retrieved = cache.get(&chunk).unwrap();
        assert_eq!(retrieved, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_cache_deduplication() {
        let cache = EmbeddingCache::new();
        let chunk1 = function_chunk("fn test() {}", 0, 1, "test.rs");
        let chunk2 = function_chunk("fn test() {}", 10, 11, "other.rs");
        assert_eq!(chunk1.hash, chunk2.hash);
        cache.put(&chunk1, vec![1.0, 2.0, 3.0]);
        assert!(cache.contains(&chunk2));
        let retrieved = cache.get(&chunk2).unwrap();
        assert_eq!(retrieved, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_memory_usage_tracking() {
        let cache = EmbeddingCache::new();
        let chunk = function_chunk("fn test() {}", 0, 1, "test.rs");
        cache.put(&chunk, vec![1.0, 2.0, 3.0]);
        let bytes = cache.memory_usage_bytes();
        assert!(bytes > 0);
        let mb = cache.memory_usage_mb();
        assert!(mb > 0.0 && mb < 1.0);
    }

    #[test]
    fn test_cache_with_memory_limit_eviction() {
        // Each 384-dim f32 embedding weighs 1536 bytes, so only about 682 fit in 1 MiB.
        // Inserting 1000 distinct entries must therefore force evictions.
        let cache = EmbeddingCache::with_memory_limit_mb(1);
        let inserted = 1000;
        for i in 0..inserted {
            let chunk = function_chunk(&format!("fn test{}() {{}}", i), 0, 1, "test.rs");
            let embedding: Vec<f32> = (0..384).map(|x| x as f32).collect();
            cache.put(&chunk, embedding);
        }
        // `len()` runs pending maintenance first, so the count reflects evictions
        // rather than a stale, not-yet-synced entry count.
        assert!(cache.len() < inserted, "Cache should have evicted entries");
    }

    #[test]
    fn test_query_cache_put_get() {
        let cache = QueryCache::new();
        assert!(cache.get("find main").is_none());
        cache.put("find main", vec![0.5, 1.5]);
        assert!(cache.contains("find main"));
        assert_eq!(cache.get("find main").unwrap(), vec![0.5, 1.5]);
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(cache.len(), 1);
    }
}