// codesearch/embed/cache.rs

1use super::batch::EmbeddedChunk;
2use crate::chunker::Chunk;
3use anyhow::Result;
4use moka::sync::Cache;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8/// Cache for embeddings keyed by chunk hash
9///
10/// Uses Moka for high-performance caching with automatic memory management.
11/// Automatically evicts entries when memory limit is reached using LRU policy.
12/// Chunks are identified by their SHA-256 content hash.
13pub struct EmbeddingCache {
14    cache: Cache<String, Arc<Vec<f32>>>,
15    hits: AtomicU64,
16    misses: AtomicU64,
17    #[allow(dead_code)] // Used in stats()
18    max_memory_mb: usize,
19}
20
21impl EmbeddingCache {
22    /// Create a new empty cache with default memory limit
23    pub fn new() -> Self {
24        Self::with_memory_limit_mb(crate::constants::DEFAULT_CACHE_MAX_MEMORY_MB)
25    }
26
27    /// Create a new cache with specified memory limit in MB
28    pub fn with_memory_limit_mb(max_memory_mb: usize) -> Self {
29        // max_capacity is used as MAX WEIGHT when weigher is provided
30        let max_weight = (max_memory_mb * 1024 * 1024) as u64;
31
32        let cache = Cache::builder()
33            .max_capacity(max_weight)
34            .weigher(|_key: &String, value: &Arc<Vec<f32>>| {
35                (value.len() * std::mem::size_of::<f32>()) as u32
36            })
37            .build();
38
39        Self {
40            cache,
41            hits: AtomicU64::new(0),
42            misses: AtomicU64::new(0),
43            max_memory_mb,
44        }
45    }
46
47    /// Get embedding from cache if available
48    pub fn get(&self, chunk: &Chunk) -> Option<Vec<f32>> {
49        if let Some(embedding) = self.cache.get(&chunk.hash) {
50            self.hits.fetch_add(1, Ordering::Relaxed);
51            Some(embedding.as_ref().clone())
52        } else {
53            self.misses.fetch_add(1, Ordering::Relaxed);
54            None
55        }
56    }
57
58    /// Store embedding in cache (with automatic eviction if needed)
59    #[allow(dead_code)] // Reserved for direct cache access
60    pub fn put(&self, chunk: &Chunk, embedding: Vec<f32>) {
61        self.cache.insert(chunk.hash.clone(), Arc::new(embedding));
62    }
63
64    /// Store an embedded chunk (with automatic eviction if needed)
65    pub fn put_embedded(&self, embedded: &EmbeddedChunk) {
66        self.cache.insert(
67            embedded.chunk.hash.clone(),
68            Arc::new(embedded.embedding.clone()),
69        );
70    }
71
72    /// Check if cache contains embedding for chunk
73    #[allow(dead_code)] // Reserved for cache probing
74    pub fn contains(&self, chunk: &Chunk) -> bool {
75        self.cache.contains_key(&chunk.hash)
76    }
77
78    /// Get cache statistics
79    #[allow(dead_code)] // Part of public API for debugging/monitoring
80    pub fn stats(&self) -> CacheStats {
81        CacheStats {
82            size: self.cache.entry_count() as usize,
83            hits: self.hits.load(Ordering::Relaxed),
84            misses: self.misses.load(Ordering::Relaxed),
85            max_memory_mb: self.max_memory_mb,
86            max_entries: (self.max_memory_mb * 1024 * 1024) / (384 * std::mem::size_of::<f32>()),
87        }
88    }
89
90    /// Clear cache
91    #[allow(dead_code)] // Reserved for cache management
92    pub fn clear(&self) {
93        self.cache.invalidate_all();
94        self.cache.run_pending_tasks();
95        self.hits.store(0, Ordering::Relaxed);
96        self.misses.store(0, Ordering::Relaxed);
97    }
98
99    /// Get cache size (note: Moka cache is eventually consistent)
100    #[allow(dead_code)] // Reserved for cache stats
101    pub fn len(&self) -> usize {
102        self.cache.run_pending_tasks();
103        self.cache.entry_count() as usize
104    }
105
106    /// Check if cache is empty
107    #[allow(dead_code)] // Reserved for cache stats
108    pub fn is_empty(&self) -> bool {
109        self.cache.run_pending_tasks();
110        self.cache.entry_count() == 0
111    }
112
113    /// Get current memory usage estimate (in bytes)
114    #[allow(dead_code)] // Part of public API for debugging/monitoring
115    pub fn memory_usage_bytes(&self) -> usize {
116        self.cache.run_pending_tasks();
117        self.cache.weighted_size() as usize
118    }
119
120    /// Get current memory usage estimate (in MB)
121    #[allow(dead_code)] // Part of public API for debugging/monitoring
122    pub fn memory_usage_mb(&self) -> f64 {
123        self.memory_usage_bytes() as f64 / (1024.0 * 1024.0)
124    }
125}
126
127impl Default for EmbeddingCache {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
133/// Query embedding cache for fast repeated searches
134///
135/// Caches query embeddings to avoid re-embedding the same queries.
136/// Query reuse is very high in interactive sessions (e.g., "authentication",
137/// "handle_file_modified"). Uses Moka LRU cache with automatic eviction.
138pub struct QueryCache {
139    cache: Cache<String, Arc<Vec<f32>>>,
140    hits: AtomicU64,
141    misses: AtomicU64,
142}
143
144impl QueryCache {
145    /// Create a new query cache with default limit (50MB)
146    pub fn new() -> Self {
147        Self::with_memory_limit_mb(50)
148    }
149
150    /// Create a query cache with specified memory limit in MB
151    pub fn with_memory_limit_mb(max_memory_mb: usize) -> Self {
152        let max_weight = (max_memory_mb * 1024 * 1024) as u64;
153
154        let cache = Cache::builder()
155            .max_capacity(max_weight)
156            .weigher(|_key: &String, value: &Arc<Vec<f32>>| {
157                (value.len() * std::mem::size_of::<f32>()) as u32
158            })
159            .build();
160
161        Self {
162            cache,
163            hits: AtomicU64::new(0),
164            misses: AtomicU64::new(0),
165        }
166    }
167
168    /// Get query embedding from cache
169    pub fn get(&self, query: &str) -> Option<Vec<f32>> {
170        if let Some(embedding) = self.cache.get(query) {
171            self.hits.fetch_add(1, Ordering::Relaxed);
172            Some(embedding.as_ref().clone())
173        } else {
174            self.misses.fetch_add(1, Ordering::Relaxed);
175            None
176        }
177    }
178
179    /// Store query embedding in cache
180    pub fn put(&self, query: &str, embedding: Vec<f32>) {
181        self.cache.insert(query.to_string(), Arc::new(embedding));
182    }
183
184    /// Check if cache contains query embedding
185    #[allow(dead_code)]
186    pub fn contains(&self, query: &str) -> bool {
187        self.cache.contains_key(query)
188    }
189
190    /// Get cache statistics
191    pub fn stats(&self) -> QueryCacheStats {
192        QueryCacheStats {
193            size: self.cache.entry_count() as usize,
194            hits: self.hits.load(Ordering::Relaxed),
195            misses: self.misses.load(Ordering::Relaxed),
196        }
197    }
198
199    /// Clear cache
200    #[allow(dead_code)]
201    pub fn clear(&self) {
202        self.cache.invalidate_all();
203        self.cache.run_pending_tasks();
204        self.hits.store(0, Ordering::Relaxed);
205        self.misses.store(0, Ordering::Relaxed);
206    }
207
208    /// Get cache size
209    #[allow(dead_code)]
210    pub fn len(&self) -> usize {
211        self.cache.run_pending_tasks();
212        self.cache.entry_count() as usize
213    }
214
215    /// Check if cache is empty
216    #[allow(dead_code)]
217    pub fn is_empty(&self) -> bool {
218        self.cache.run_pending_tasks();
219        self.cache.entry_count() == 0
220    }
221
222    /// Get memory usage in bytes
223    #[allow(dead_code)]
224    pub fn memory_usage_bytes(&self) -> usize {
225        self.cache.run_pending_tasks();
226        self.cache.weighted_size() as usize
227    }
228
229    /// Get memory usage in MB
230    #[allow(dead_code)]
231    pub fn memory_usage_mb(&self) -> f64 {
232        self.memory_usage_bytes() as f64 / (1024.0 * 1024.0)
233    }
234}
235
236impl Default for QueryCache {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
/// Snapshot of [`QueryCache`] counters taken at one point in time
#[derive(Debug, Clone)]
#[allow(dead_code)] // Reserved for debugging/monitoring API
pub struct QueryCacheStats {
    /// Entry count reported by the cache when the snapshot was taken.
    pub size: usize,
    /// Lookups that found a cached embedding.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
}

impl QueryCacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded, avoiding a
    /// division by zero.
    #[allow(dead_code)] // Part of debugging/monitoring API
    pub fn hit_rate(&self) -> f32 {
        match self.total_requests() {
            0 => 0.0,
            total => self.hits as f32 / total as f32,
        }
    }

    /// Total number of recorded lookups (hits plus misses).
    #[allow(dead_code)] // Part of debugging/monitoring API
    pub fn total_requests(&self) -> u64 {
        self.hits + self.misses
    }
}
266
/// Snapshot of [`EmbeddingCache`] counters and configuration
#[derive(Debug, Clone)]
#[allow(dead_code)] // Part of public API for debugging/monitoring
pub struct CacheStats {
    /// Entry count reported by the cache when the snapshot was taken.
    #[allow(dead_code)] // Part of public API for debugging/monitoring
    pub size: usize,
    /// Lookups that found a cached embedding.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Configured memory budget in MB.
    #[allow(dead_code)] // Part of public API for debugging/monitoring
    pub max_memory_mb: usize,
    /// Estimated number of entries that fit in the budget.
    #[allow(dead_code)] // Part of public API for debugging/monitoring
    pub max_entries: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded, avoiding a
    /// division by zero.
    #[allow(dead_code)] // Part of public API for debugging/monitoring
    pub fn hit_rate(&self) -> f32 {
        match self.total_requests() {
            0 => 0.0,
            total => self.hits as f32 / total as f32,
        }
    }

    /// Total number of recorded lookups (hits plus misses).
    #[allow(dead_code)] // Reserved for stats display
    pub fn total_requests(&self) -> u64 {
        self.hits + self.misses
    }
}
296
297/// Cached batch embedder that uses an embedding cache with memory limits
298pub struct CachedBatchEmbedder {
299    pub batch_embedder: super::batch::BatchEmbedder,
300    #[allow(dead_code)] // Part of public API for debugging/monitoring
301    cache: EmbeddingCache,
302}
303
304impl CachedBatchEmbedder {
305    /// Create a new cached batch embedder with default memory limit
306    #[allow(dead_code)] // Reserved for cached embedding mode
307    pub fn new(batch_embedder: super::batch::BatchEmbedder) -> Self {
308        Self {
309            batch_embedder,
310            cache: EmbeddingCache::new(),
311        }
312    }
313
314    /// Create with custom memory limit (in MB)
315    pub fn with_memory_limit(
316        batch_embedder: super::batch::BatchEmbedder,
317        max_memory_mb: usize,
318    ) -> Self {
319        Self {
320            batch_embedder,
321            cache: EmbeddingCache::with_memory_limit_mb(max_memory_mb),
322        }
323    }
324
325    /// Embed chunks using cache when possible
326    pub fn embed_chunks(&mut self, chunks: Vec<Chunk>) -> Result<Vec<EmbeddedChunk>> {
327        if chunks.is_empty() {
328            return Ok(Vec::new());
329        }
330
331        let total = chunks.len();
332        let mut embedded_chunks = Vec::with_capacity(total);
333        let mut chunks_to_embed = Vec::new();
334        let mut cache_indices = Vec::new();
335
336        // Check cache first (silent - no verbose output)
337        for (idx, chunk) in chunks.iter().enumerate() {
338            if let Some(embedding) = self.cache.get(chunk) {
339                embedded_chunks.push(EmbeddedChunk::new(chunk.clone(), embedding));
340            } else {
341                chunks_to_embed.push(chunk.clone());
342                cache_indices.push(idx);
343            }
344        }
345
346        // Embed remaining chunks
347        if !chunks_to_embed.is_empty() {
348            let newly_embedded = self.batch_embedder.embed_chunks(chunks_to_embed)?;
349
350            // Store in cache (automatic eviction if memory limit reached)
351            for embedded in &newly_embedded {
352                self.cache.put_embedded(embedded);
353            }
354
355            embedded_chunks.extend(newly_embedded);
356        }
357
358        Ok(embedded_chunks)
359    }
360
361    /// Embed a single chunk with caching
362    #[allow(dead_code)] // Reserved for single-chunk caching
363    pub fn embed_chunk(&mut self, chunk: Chunk) -> Result<EmbeddedChunk> {
364        if let Some(embedding) = self.cache.get(&chunk) {
365            return Ok(EmbeddedChunk::new(chunk, embedding));
366        }
367
368        let embedded = self.batch_embedder.embed_chunk(chunk)?;
369        self.cache.put_embedded(&embedded);
370
371        Ok(embedded)
372    }
373
374    /// Get cache statistics
375    #[allow(dead_code)] // Part of public API for debugging/monitoring
376    pub fn cache_stats(&self) -> CacheStats {
377        self.cache.stats()
378    }
379
380    /// Clear cache
381    #[allow(dead_code)] // Reserved for cache reset
382    pub fn clear_cache(&self) {
383        self.cache.clear();
384    }
385
386    /// Get embedding dimensions
387    pub fn dimensions(&self) -> usize {
388        self.batch_embedder.dimensions()
389    }
390
391    /// Get cache reference
392    #[allow(dead_code)] // Part of public API for debugging/monitoring
393    pub fn cache(&self) -> &EmbeddingCache {
394        &self.cache
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::ChunkKind;

    /// Build a one-line function chunk in `test.rs` with the given content.
    fn sample_chunk(content: &str) -> Chunk {
        Chunk::new(
            content.to_string(),
            0,
            1,
            ChunkKind::Function,
            "test.rs".to_string(),
        )
    }

    #[test]
    fn test_cache_creation() {
        let cache = EmbeddingCache::new();
        assert_eq!(
            cache.max_memory_mb,
            crate::constants::DEFAULT_CACHE_MAX_MEMORY_MB
        );
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_with_memory_limit() {
        let cache = EmbeddingCache::with_memory_limit_mb(100);
        assert_eq!(cache.max_memory_mb, 100);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cache_put_get() {
        let cache = EmbeddingCache::new();
        let chunk = sample_chunk("fn test() {}");
        let embedding = vec![1.0, 2.0, 3.0];

        // Initially not in cache
        assert!(cache.get(&chunk).is_none());

        // Put in cache
        cache.put(&chunk, embedding.clone());

        // Now should be in cache
        assert!(cache.contains(&chunk));
        let retrieved = cache.get(&chunk).unwrap();
        assert_eq!(retrieved, embedding);

        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache_stats() {
        let cache = EmbeddingCache::new();

        let chunk1 = sample_chunk("fn test1() {}");
        let chunk2 = Chunk::new(
            "fn test2() {}".to_string(),
            2,
            3,
            ChunkKind::Function,
            "test.rs".to_string(),
        );

        cache.put(&chunk1, vec![1.0, 2.0, 3.0]);

        // Hit
        cache.get(&chunk1);

        // Miss
        cache.get(&chunk2);

        // Hit
        cache.get(&chunk1);

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.total_requests(), 3);
        assert!((stats.hit_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_cache_clear() {
        let cache = EmbeddingCache::new();
        let chunk = sample_chunk("fn test() {}");

        cache.put(&chunk, vec![1.0, 2.0, 3.0]);
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_embedded_chunk_put() {
        let cache = EmbeddingCache::new();
        let chunk = sample_chunk("fn test() {}");

        let embedded = EmbeddedChunk::new(chunk.clone(), vec![1.0, 2.0, 3.0]);

        cache.put_embedded(&embedded);

        assert!(cache.contains(&chunk));
        let retrieved = cache.get(&chunk).unwrap();
        assert_eq!(retrieved, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_cache_deduplication() {
        let cache = EmbeddingCache::new();

        // Same content = same hash, regardless of position or file
        let chunk1 = sample_chunk("fn test() {}");
        let chunk2 = Chunk::new(
            "fn test() {}".to_string(),
            10,
            11,
            ChunkKind::Function,
            "other.rs".to_string(),
        );

        // Both should have same hash
        assert_eq!(chunk1.hash, chunk2.hash);

        // Put with chunk1
        cache.put(&chunk1, vec![1.0, 2.0, 3.0]);

        // Should be able to retrieve with chunk2 (same content hash)
        assert!(cache.contains(&chunk2));
        let retrieved = cache.get(&chunk2).unwrap();
        assert_eq!(retrieved, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_memory_usage_tracking() {
        let cache = EmbeddingCache::new();
        let chunk = sample_chunk("fn test() {}");

        // Add embedding with 3 floats = 12 bytes
        cache.put(&chunk, vec![1.0, 2.0, 3.0]);

        let bytes = cache.memory_usage_bytes();
        assert!(bytes > 0);

        let mb = cache.memory_usage_mb();
        assert!(mb > 0.0 && mb < 1.0); // Should be < 1 MB
    }

    #[test]
    fn test_cache_with_memory_limit_eviction() {
        const DIMS: usize = 384;
        const BUDGET_BYTES: usize = 1024 * 1024;

        // Smallest configurable budget: 1 MB
        let cache = EmbeddingCache::with_memory_limit_mb(1);

        // Each 384-dim embedding weighs 1536 bytes, so about 682 fit in 1 MB.
        // Insert twice that many so the budget is genuinely exceeded.
        let bytes_per_entry = DIMS * std::mem::size_of::<f32>();
        let inserted = 2 * (BUDGET_BYTES / bytes_per_entry);

        for i in 0..inserted {
            let chunk = sample_chunk(&format!("fn test{}() {{}}", i));
            let embedding: Vec<f32> = (0..DIMS).map(|x| x as f32).collect();
            cache.put(&chunk, embedding);
        }

        // len() and memory_usage_bytes() flush Moka's pending tasks first, so
        // these checks observe the settled state rather than a stale count.
        assert!(
            cache.len() < inserted,
            "Cache should have evicted entries"
        );
        assert!(
            cache.memory_usage_bytes() <= BUDGET_BYTES,
            "Cache should stay within its memory budget"
        );
    }
}
602
603        // Cache should have automatically evicted old entries to stay within limit
604        let stats = cache.stats();
605        assert!(stats.size < 10, "Cache should have evicted entries");
606    }
607}