1use super::batch::EmbeddedChunk;
2use crate::chunker::Chunk;
3use anyhow::Result;
4use moka::sync::Cache;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8pub struct EmbeddingCache {
14 cache: Cache<String, Arc<Vec<f32>>>,
15 hits: AtomicU64,
16 misses: AtomicU64,
17 #[allow(dead_code)] max_memory_mb: usize,
19}
20
21impl EmbeddingCache {
22 pub fn new() -> Self {
24 Self::with_memory_limit_mb(crate::constants::DEFAULT_CACHE_MAX_MEMORY_MB)
25 }
26
27 pub fn with_memory_limit_mb(max_memory_mb: usize) -> Self {
29 let max_weight = (max_memory_mb * 1024 * 1024) as u64;
31
32 let cache = Cache::builder()
33 .max_capacity(max_weight)
34 .weigher(|_key: &String, value: &Arc<Vec<f32>>| {
35 (value.len() * std::mem::size_of::<f32>()) as u32
36 })
37 .build();
38
39 Self {
40 cache,
41 hits: AtomicU64::new(0),
42 misses: AtomicU64::new(0),
43 max_memory_mb,
44 }
45 }
46
47 pub fn get(&self, chunk: &Chunk) -> Option<Vec<f32>> {
49 if let Some(embedding) = self.cache.get(&chunk.hash) {
50 self.hits.fetch_add(1, Ordering::Relaxed);
51 Some(embedding.as_ref().clone())
52 } else {
53 self.misses.fetch_add(1, Ordering::Relaxed);
54 None
55 }
56 }
57
58 #[allow(dead_code)] pub fn put(&self, chunk: &Chunk, embedding: Vec<f32>) {
61 self.cache.insert(chunk.hash.clone(), Arc::new(embedding));
62 }
63
64 pub fn put_embedded(&self, embedded: &EmbeddedChunk) {
66 self.cache.insert(
67 embedded.chunk.hash.clone(),
68 Arc::new(embedded.embedding.clone()),
69 );
70 }
71
72 #[allow(dead_code)] pub fn contains(&self, chunk: &Chunk) -> bool {
75 self.cache.contains_key(&chunk.hash)
76 }
77
78 #[allow(dead_code)] pub fn stats(&self) -> CacheStats {
81 CacheStats {
82 size: self.cache.entry_count() as usize,
83 hits: self.hits.load(Ordering::Relaxed),
84 misses: self.misses.load(Ordering::Relaxed),
85 max_memory_mb: self.max_memory_mb,
86 max_entries: (self.max_memory_mb * 1024 * 1024) / (384 * std::mem::size_of::<f32>()),
87 }
88 }
89
90 #[allow(dead_code)] pub fn clear(&self) {
93 self.cache.invalidate_all();
94 self.cache.run_pending_tasks();
95 self.hits.store(0, Ordering::Relaxed);
96 self.misses.store(0, Ordering::Relaxed);
97 }
98
99 #[allow(dead_code)] pub fn len(&self) -> usize {
102 self.cache.run_pending_tasks();
103 self.cache.entry_count() as usize
104 }
105
106 #[allow(dead_code)] pub fn is_empty(&self) -> bool {
109 self.cache.run_pending_tasks();
110 self.cache.entry_count() == 0
111 }
112
113 #[allow(dead_code)] pub fn memory_usage_bytes(&self) -> usize {
116 self.cache.run_pending_tasks();
117 self.cache.weighted_size() as usize
118 }
119
120 #[allow(dead_code)] pub fn memory_usage_mb(&self) -> f64 {
123 self.memory_usage_bytes() as f64 / (1024.0 * 1024.0)
124 }
125}
126
127impl Default for EmbeddingCache {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133pub struct QueryCache {
139 cache: Cache<String, Arc<Vec<f32>>>,
140 hits: AtomicU64,
141 misses: AtomicU64,
142}
143
144impl QueryCache {
145 pub fn new() -> Self {
147 Self::with_memory_limit_mb(50)
148 }
149
150 pub fn with_memory_limit_mb(max_memory_mb: usize) -> Self {
152 let max_weight = (max_memory_mb * 1024 * 1024) as u64;
153
154 let cache = Cache::builder()
155 .max_capacity(max_weight)
156 .weigher(|_key: &String, value: &Arc<Vec<f32>>| {
157 (value.len() * std::mem::size_of::<f32>()) as u32
158 })
159 .build();
160
161 Self {
162 cache,
163 hits: AtomicU64::new(0),
164 misses: AtomicU64::new(0),
165 }
166 }
167
168 pub fn get(&self, query: &str) -> Option<Vec<f32>> {
170 if let Some(embedding) = self.cache.get(query) {
171 self.hits.fetch_add(1, Ordering::Relaxed);
172 Some(embedding.as_ref().clone())
173 } else {
174 self.misses.fetch_add(1, Ordering::Relaxed);
175 None
176 }
177 }
178
179 pub fn put(&self, query: &str, embedding: Vec<f32>) {
181 self.cache.insert(query.to_string(), Arc::new(embedding));
182 }
183
184 #[allow(dead_code)]
186 pub fn contains(&self, query: &str) -> bool {
187 self.cache.contains_key(query)
188 }
189
190 pub fn stats(&self) -> QueryCacheStats {
192 QueryCacheStats {
193 size: self.cache.entry_count() as usize,
194 hits: self.hits.load(Ordering::Relaxed),
195 misses: self.misses.load(Ordering::Relaxed),
196 }
197 }
198
199 #[allow(dead_code)]
201 pub fn clear(&self) {
202 self.cache.invalidate_all();
203 self.cache.run_pending_tasks();
204 self.hits.store(0, Ordering::Relaxed);
205 self.misses.store(0, Ordering::Relaxed);
206 }
207
208 #[allow(dead_code)]
210 pub fn len(&self) -> usize {
211 self.cache.run_pending_tasks();
212 self.cache.entry_count() as usize
213 }
214
215 #[allow(dead_code)]
217 pub fn is_empty(&self) -> bool {
218 self.cache.run_pending_tasks();
219 self.cache.entry_count() == 0
220 }
221
222 #[allow(dead_code)]
224 pub fn memory_usage_bytes(&self) -> usize {
225 self.cache.run_pending_tasks();
226 self.cache.weighted_size() as usize
227 }
228
229 #[allow(dead_code)]
231 pub fn memory_usage_mb(&self) -> f64 {
232 self.memory_usage_bytes() as f64 / (1024.0 * 1024.0)
233 }
234}
235
236impl Default for QueryCache {
237 fn default() -> Self {
238 Self::new()
239 }
240}
241
/// Point-in-time counters reported by `QueryCache::stats`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct QueryCacheStats {
    /// Entry count reported by the cache when the snapshot was taken.
    pub size: usize,
    /// Lookups that found a cached embedding.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
}

impl QueryCacheStats {
    /// Share of lookups that were hits; `0.0` when no lookups happened.
    #[allow(dead_code)]
    pub fn hit_rate(&self) -> f32 {
        match self.total_requests() {
            0 => 0.0,
            total => self.hits as f32 / total as f32,
        }
    }

    /// Total lookups observed (`hits + misses`).
    #[allow(dead_code)]
    pub fn total_requests(&self) -> u64 {
        self.hits + self.misses
    }
}
266
/// Point-in-time counters and limits reported by `EmbeddingCache::stats`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct CacheStats {
    /// Entry count reported by the cache when the snapshot was taken.
    #[allow(dead_code)]
    pub size: usize,
    /// Lookups that found a cached embedding.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Configured memory budget in MiB.
    #[allow(dead_code)]
    pub max_memory_mb: usize,
    /// Estimated number of entries that fit in the budget.
    #[allow(dead_code)]
    pub max_entries: usize,
}

impl CacheStats {
    /// Share of lookups that were hits; `0.0` when no lookups happened.
    #[allow(dead_code)]
    pub fn hit_rate(&self) -> f32 {
        let total = self.total_requests();
        if total == 0 {
            0.0
        } else {
            self.hits as f32 / total as f32
        }
    }

    /// Total lookups observed (`hits + misses`).
    #[allow(dead_code)]
    pub fn total_requests(&self) -> u64 {
        self.hits + self.misses
    }
}
296
297pub struct CachedBatchEmbedder {
299 pub batch_embedder: super::batch::BatchEmbedder,
300 #[allow(dead_code)] cache: EmbeddingCache,
302}
303
304impl CachedBatchEmbedder {
305 #[allow(dead_code)] pub fn new(batch_embedder: super::batch::BatchEmbedder) -> Self {
308 Self {
309 batch_embedder,
310 cache: EmbeddingCache::new(),
311 }
312 }
313
314 pub fn with_memory_limit(
316 batch_embedder: super::batch::BatchEmbedder,
317 max_memory_mb: usize,
318 ) -> Self {
319 Self {
320 batch_embedder,
321 cache: EmbeddingCache::with_memory_limit_mb(max_memory_mb),
322 }
323 }
324
325 pub fn embed_chunks(&mut self, chunks: Vec<Chunk>) -> Result<Vec<EmbeddedChunk>> {
327 if chunks.is_empty() {
328 return Ok(Vec::new());
329 }
330
331 let total = chunks.len();
332 let mut embedded_chunks = Vec::with_capacity(total);
333 let mut chunks_to_embed = Vec::new();
334 let mut cache_indices = Vec::new();
335
336 for (idx, chunk) in chunks.iter().enumerate() {
338 if let Some(embedding) = self.cache.get(chunk) {
339 embedded_chunks.push(EmbeddedChunk::new(chunk.clone(), embedding));
340 } else {
341 chunks_to_embed.push(chunk.clone());
342 cache_indices.push(idx);
343 }
344 }
345
346 if !chunks_to_embed.is_empty() {
348 let newly_embedded = self.batch_embedder.embed_chunks(chunks_to_embed)?;
349
350 for embedded in &newly_embedded {
352 self.cache.put_embedded(embedded);
353 }
354
355 embedded_chunks.extend(newly_embedded);
356 }
357
358 Ok(embedded_chunks)
359 }
360
361 #[allow(dead_code)] pub fn embed_chunk(&mut self, chunk: Chunk) -> Result<EmbeddedChunk> {
364 if let Some(embedding) = self.cache.get(&chunk) {
365 return Ok(EmbeddedChunk::new(chunk, embedding));
366 }
367
368 let embedded = self.batch_embedder.embed_chunk(chunk)?;
369 self.cache.put_embedded(&embedded);
370
371 Ok(embedded)
372 }
373
374 #[allow(dead_code)] pub fn cache_stats(&self) -> CacheStats {
377 self.cache.stats()
378 }
379
380 #[allow(dead_code)] pub fn clear_cache(&self) {
383 self.cache.clear();
384 }
385
386 pub fn dimensions(&self) -> usize {
388 self.batch_embedder.dimensions()
389 }
390
391 #[allow(dead_code)] pub fn cache(&self) -> &EmbeddingCache {
394 &self.cache
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::ChunkKind;

    #[test]
    fn test_cache_creation() {
        let cache = EmbeddingCache::new();
        assert_eq!(
            cache.max_memory_mb,
            crate::constants::DEFAULT_CACHE_MAX_MEMORY_MB
        );
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_with_memory_limit() {
        let cache = EmbeddingCache::with_memory_limit_mb(100);
        assert_eq!(cache.max_memory_mb, 100);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cache_put_get() {
        let cache = EmbeddingCache::new();

        let chunk = Chunk::new(
            "fn test() {}".to_string(),
            0,
            1,
            ChunkKind::Function,
            "test.rs".to_string(),
        );

        let embedding = vec![1.0, 2.0, 3.0];

        assert!(cache.get(&chunk).is_none());

        cache.put(&chunk, embedding.clone());

        assert!(cache.contains(&chunk));
        let retrieved = cache.get(&chunk).unwrap();
        assert_eq!(retrieved, embedding);

        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache_stats() {
        let cache = EmbeddingCache::new();

        let chunk1 = Chunk::new(
            "fn test1() {}".to_string(),
            0,
            1,
            ChunkKind::Function,
            "test.rs".to_string(),
        );

        let chunk2 = Chunk::new(
            "fn test2() {}".to_string(),
            2,
            3,
            ChunkKind::Function,
            "test.rs".to_string(),
        );

        cache.put(&chunk1, vec![1.0, 2.0, 3.0]);

        cache.get(&chunk1); // hit
        cache.get(&chunk2); // miss
        cache.get(&chunk1); // hit

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.total_requests(), 3);
        assert!((stats.hit_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_cache_clear() {
        let cache = EmbeddingCache::new();

        let chunk = Chunk::new(
            "fn test() {}".to_string(),
            0,
            1,
            ChunkKind::Function,
            "test.rs".to_string(),
        );

        cache.put(&chunk, vec![1.0, 2.0, 3.0]);
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_embedded_chunk_put() {
        let cache = EmbeddingCache::new();

        let chunk = Chunk::new(
            "fn test() {}".to_string(),
            0,
            1,
            ChunkKind::Function,
            "test.rs".to_string(),
        );

        let embedded = EmbeddedChunk::new(chunk.clone(), vec![1.0, 2.0, 3.0]);

        cache.put_embedded(&embedded);

        assert!(cache.contains(&chunk));
        let retrieved = cache.get(&chunk).unwrap();
        assert_eq!(retrieved, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_cache_deduplication() {
        let cache = EmbeddingCache::new();

        // Same content, different location and file.
        let chunk1 = Chunk::new(
            "fn test() {}".to_string(),
            0,
            1,
            ChunkKind::Function,
            "test.rs".to_string(),
        );

        let chunk2 = Chunk::new(
            "fn test() {}".to_string(),
            10,
            11,
            ChunkKind::Function,
            "other.rs".to_string(),
        );

        assert_eq!(chunk1.hash, chunk2.hash);

        cache.put(&chunk1, vec![1.0, 2.0, 3.0]);

        // Lookup through the second chunk hits the first chunk's entry.
        assert!(cache.contains(&chunk2));
        let retrieved = cache.get(&chunk2).unwrap();
        assert_eq!(retrieved, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_memory_usage_tracking() {
        let cache = EmbeddingCache::new();

        let chunk = Chunk::new(
            "fn test() {}".to_string(),
            0,
            1,
            ChunkKind::Function,
            "test.rs".to_string(),
        );

        cache.put(&chunk, vec![1.0, 2.0, 3.0]);

        let bytes = cache.memory_usage_bytes();
        assert!(bytes > 0);

        let mb = cache.memory_usage_mb();
        assert!(mb > 0.0 && mb < 1.0);
    }

    #[test]
    fn test_cache_with_memory_limit_eviction() {
        let cache = EmbeddingCache::with_memory_limit_mb(1);

        // Each vector is 64 Ki f32s = 256 KiB, so ten of them (2.5 MiB) cannot
        // all fit under the 1 MiB limit and eviction/rejection must occur.
        const DIMS: usize = 64 * 1024;
        for i in 0..10 {
            let chunk = Chunk::new(
                format!("fn test{}() {{}}", i),
                0,
                1,
                ChunkKind::Function,
                "test.rs".to_string(),
            );

            let embedding: Vec<f32> = (0..DIMS).map(|x| x as f32).collect();
            cache.put(&chunk, embedding);
        }

        // `len` and `memory_usage_bytes` flush moka's pending maintenance first,
        // so they reflect evictions; `stats().size` may still be stale here.
        assert!(cache.len() < 10, "Cache should have evicted entries");
        assert!(
            cache.memory_usage_bytes() <= 1024 * 1024,
            "Cache weight should stay within the 1 MiB limit"
        );
    }
}