rexis_rag/caching/
embedding_cache.rs

//! # Embedding Cache Implementation
//!
//! High-performance caching for embedding computations, with optional compression and deduplication.
4
5use super::{Cache, CacheEntryMetadata, CacheStats, EmbeddingCacheConfig, EmbeddingCacheEntry};
6use crate::RragResult;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::hash::{DefaultHasher, Hash, Hasher};
10
11/// Embedding cache with compression and deduplication
12pub struct EmbeddingCache {
13    /// Configuration
14    config: EmbeddingCacheConfig,
15
16    /// Main storage
17    storage: HashMap<String, EmbeddingCacheEntry>,
18
19    /// Text hash to full key mapping for deduplication
20    deduplication_map: HashMap<String, String>,
21
22    /// Compressed embeddings storage
23    compressed_storage: Option<HashMap<String, CompressedEmbedding>>,
24
25    /// Cache statistics
26    stats: CacheStats,
27}
28
29/// Compressed embedding representation
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct CompressedEmbedding {
32    /// Quantized embedding values
33    pub quantized_values: Vec<u8>,
34
35    /// Quantization parameters
36    pub scale: f32,
37    pub offset: f32,
38
39    /// Original dimension
40    pub dimension: usize,
41
42    /// Compression ratio achieved
43    pub compression_ratio: f32,
44}
45
/// Strategies for shrinking a stored embedding.
///
/// In the code visible in this file the cache does not consult this enum;
/// it always applies 8-bit quantization.
#[derive(Clone, Copy, Debug)]
pub enum CompressionMethod {
    /// Store the embedding unchanged.
    None,

    /// Quantize every component to an 8-bit integer.
    Quantization8Bit,

    /// Reduce dimensionality via Principal Component Analysis.
    PCA,

    /// Product quantization.
    ProductQuantization,

    /// Binary quantization.
    BinaryQuantization,
}
64
65/// Embedding deduplication statistics
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct DeduplicationStats {
68    /// Number of deduplicated entries
69    pub deduplicated_count: usize,
70
71    /// Memory saved through deduplication (bytes)
72    pub memory_saved: usize,
73
74    /// Deduplication ratio (0.0 to 1.0)
75    pub deduplication_ratio: f32,
76}
77
78impl EmbeddingCache {
79    /// Create new embedding cache
80    pub fn new(config: EmbeddingCacheConfig) -> RragResult<Self> {
81        let compressed_storage = if config.compression_enabled {
82            Some(HashMap::new())
83        } else {
84            None
85        };
86
87        Ok(Self {
88            config,
89            storage: HashMap::new(),
90            deduplication_map: HashMap::new(),
91            compressed_storage,
92            stats: CacheStats::default(),
93        })
94    }
95
96    /// Get embedding with automatic decompression
97    pub fn get_embedding(&self, text: &str, model: &str) -> Option<Vec<f32>> {
98        let key = self.make_key(text, model);
99
100        // Try direct lookup
101        if let Some(entry) = self.storage.get(&key) {
102            return Some(entry.embedding.clone());
103        }
104
105        // Try deduplication lookup
106        let text_hash = self.hash_text(text);
107        if let Some(canonical_key) = self.deduplication_map.get(&text_hash) {
108            if let Some(entry) = self.storage.get(canonical_key) {
109                return Some(entry.embedding.clone());
110            }
111        }
112
113        // Try compressed storage
114        if let Some(compressed_storage) = &self.compressed_storage {
115            if let Some(compressed) = compressed_storage.get(&key) {
116                return Some(self.decompress_embedding(compressed));
117            }
118        }
119
120        None
121    }
122
123    /// Cache embedding with compression and deduplication
124    pub fn cache_embedding(
125        &mut self,
126        text: String,
127        model: String,
128        embedding: Vec<f32>,
129    ) -> RragResult<()> {
130        let key = self.make_key(&text, &model);
131        let text_hash = self.hash_text(&text);
132
133        // Check for deduplication opportunity
134        if let Some(existing_key) = self.deduplication_map.get(&text_hash) {
135            // Text already cached, just add reference
136            if !self.storage.contains_key(&key) {
137                if let Some(existing_entry) = self.storage.get(existing_key).cloned() {
138                    self.storage.insert(key, existing_entry);
139                }
140            }
141            return Ok(());
142        }
143
144        // Check capacity
145        if self.storage.len() >= self.config.max_size {
146            self.evict_entry()?;
147        }
148
149        // Create entry
150        let entry = EmbeddingCacheEntry {
151            text: text.clone(),
152            text_hash: text_hash.clone(),
153            embedding: embedding.clone(),
154            model: model.clone(),
155            metadata: CacheEntryMetadata::new(),
156        };
157
158        // Store with or without compression
159        if self.config.compression_enabled {
160            let compressed = self.compress_embedding(&embedding);
161            if let Some(compressed_storage) = &mut self.compressed_storage {
162                compressed_storage.insert(key.clone(), compressed);
163            }
164
165            // Store metadata only in main storage
166            let mut metadata_entry = entry;
167            metadata_entry.embedding = Vec::new(); // Clear to save memory
168            self.storage.insert(key.clone(), metadata_entry);
169        } else {
170            self.storage.insert(key.clone(), entry);
171        }
172
173        // Update deduplication map
174        self.deduplication_map.insert(text_hash, key);
175
176        Ok(())
177    }
178
179    /// Compress embedding using configured method
180    fn compress_embedding(&self, embedding: &[f32]) -> CompressedEmbedding {
181        // Simple 8-bit quantization for now
182        let (min_val, max_val) = embedding
183            .iter()
184            .fold((f32::INFINITY, f32::NEG_INFINITY), |(min, max), &val| {
185                (min.min(val), max.max(val))
186            });
187
188        let range = max_val - min_val;
189        let scale = range / 255.0;
190        let offset = min_val;
191
192        let quantized_values: Vec<u8> = embedding
193            .iter()
194            .map(|&val| {
195                let normalized = (val - offset) / scale;
196                normalized.round().clamp(0.0, 255.0) as u8
197            })
198            .collect();
199
200        let original_size = embedding.len() * std::mem::size_of::<f32>();
201        let compressed_size =
202            quantized_values.len() * std::mem::size_of::<u8>() + std::mem::size_of::<f32>() * 2; // scale + offset
203
204        CompressedEmbedding {
205            quantized_values,
206            scale,
207            offset,
208            dimension: embedding.len(),
209            compression_ratio: original_size as f32 / compressed_size as f32,
210        }
211    }
212
213    /// Decompress embedding
214    fn decompress_embedding(&self, compressed: &CompressedEmbedding) -> Vec<f32> {
215        compressed
216            .quantized_values
217            .iter()
218            .map(|&val| (val as f32) * compressed.scale + compressed.offset)
219            .collect()
220    }
221
222    /// Make cache key
223    fn make_key(&self, text: &str, model: &str) -> String {
224        format!("{}:{}", model, text)
225    }
226
227    /// Hash text for deduplication
228    fn hash_text(&self, text: &str) -> String {
229        let mut hasher = DefaultHasher::new();
230        text.hash(&mut hasher);
231        format!("{:x}", hasher.finish())
232    }
233
234    /// Evict least frequently used entry
235    fn evict_entry(&mut self) -> RragResult<()> {
236        if self.storage.is_empty() {
237            return Ok(());
238        }
239
240        // Find LFU entry
241        let mut candidate_key: Option<String> = None;
242        let mut min_access_count = u64::MAX;
243        let mut oldest_time = std::time::SystemTime::now();
244
245        for (key, entry) in &self.storage {
246            if entry.metadata.access_count < min_access_count
247                || (entry.metadata.access_count == min_access_count
248                    && entry.metadata.last_accessed < oldest_time)
249            {
250                min_access_count = entry.metadata.access_count;
251                oldest_time = entry.metadata.last_accessed;
252                candidate_key = Some(key.clone());
253            }
254        }
255
256        if let Some(key) = candidate_key {
257            if let Some(entry) = self.storage.remove(&key) {
258                // Remove from deduplication map
259                self.deduplication_map.remove(&entry.text_hash);
260
261                // Remove from compressed storage
262                if let Some(compressed_storage) = &mut self.compressed_storage {
263                    compressed_storage.remove(&key);
264                }
265
266                self.stats.evictions += 1;
267            }
268        }
269
270        Ok(())
271    }
272
273    /// Get deduplication statistics
274    pub fn get_deduplication_stats(&self) -> DeduplicationStats {
275        let total_entries = self.storage.len();
276        let unique_texts = self.deduplication_map.len();
277        let deduplicated_count = if total_entries > unique_texts {
278            total_entries - unique_texts
279        } else {
280            0
281        };
282
283        let embedding_size = 1536 * std::mem::size_of::<f32>(); // Assume typical size
284        let memory_saved = deduplicated_count * embedding_size;
285
286        let deduplication_ratio = if total_entries > 0 {
287            deduplicated_count as f32 / total_entries as f32
288        } else {
289            0.0
290        };
291
292        DeduplicationStats {
293            deduplicated_count,
294            memory_saved,
295            deduplication_ratio,
296        }
297    }
298
299    /// Get compression statistics
300    pub fn get_compression_stats(&self) -> Option<CompressionStats> {
301        if !self.config.compression_enabled {
302            return None;
303        }
304
305        let compressed_storage = self.compressed_storage.as_ref()?;
306
307        let mut total_original_size = 0;
308        let mut total_compressed_size = 0;
309        let mut compression_ratios = Vec::new();
310
311        for compressed in compressed_storage.values() {
312            let original_size = compressed.dimension * std::mem::size_of::<f32>();
313            let compressed_size =
314                compressed.quantized_values.len() + std::mem::size_of::<f32>() * 2; // scale + offset
315
316            total_original_size += original_size;
317            total_compressed_size += compressed_size;
318            compression_ratios.push(compressed.compression_ratio);
319        }
320
321        let overall_ratio = if total_compressed_size > 0 {
322            total_original_size as f32 / total_compressed_size as f32
323        } else {
324            1.0
325        };
326
327        let avg_ratio = if !compression_ratios.is_empty() {
328            compression_ratios.iter().sum::<f32>() / compression_ratios.len() as f32
329        } else {
330            1.0
331        };
332
333        Some(CompressionStats {
334            total_entries: compressed_storage.len(),
335            total_original_size,
336            total_compressed_size,
337            overall_compression_ratio: overall_ratio,
338            average_compression_ratio: avg_ratio,
339            memory_saved: total_original_size - total_compressed_size,
340        })
341    }
342}
343
344impl Cache<String, EmbeddingCacheEntry> for EmbeddingCache {
345    fn get(&self, key: &String) -> Option<EmbeddingCacheEntry> {
346        self.storage.get(key).cloned()
347    }
348
349    fn put(&mut self, key: String, value: EmbeddingCacheEntry) -> RragResult<()> {
350        // Check capacity
351        if self.storage.len() >= self.config.max_size {
352            self.evict_entry()?;
353        }
354
355        self.storage.insert(key, value);
356        Ok(())
357    }
358
359    fn remove(&mut self, key: &String) -> Option<EmbeddingCacheEntry> {
360        let entry = self.storage.remove(key);
361
362        if let Some(ref entry_val) = entry {
363            self.deduplication_map.remove(&entry_val.text_hash);
364
365            if let Some(compressed_storage) = &mut self.compressed_storage {
366                compressed_storage.remove(key);
367            }
368        }
369
370        entry
371    }
372
373    fn contains(&self, key: &String) -> bool {
374        self.storage.contains_key(key)
375            || (self
376                .compressed_storage
377                .as_ref()
378                .map_or(false, |cs| cs.contains_key(key)))
379    }
380
381    fn clear(&mut self) {
382        self.storage.clear();
383        self.deduplication_map.clear();
384        if let Some(compressed_storage) = &mut self.compressed_storage {
385            compressed_storage.clear();
386        }
387        self.stats = CacheStats::default();
388    }
389
390    fn size(&self) -> usize {
391        self.storage.len()
392    }
393
394    fn stats(&self) -> CacheStats {
395        self.stats.clone()
396    }
397}
398
399/// Compression statistics
400#[derive(Debug, Clone, Serialize, Deserialize)]
401pub struct CompressionStats {
402    /// Number of compressed entries
403    pub total_entries: usize,
404
405    /// Total original size in bytes
406    pub total_original_size: usize,
407
408    /// Total compressed size in bytes
409    pub total_compressed_size: usize,
410
411    /// Overall compression ratio
412    pub overall_compression_ratio: f32,
413
414    /// Average compression ratio
415    pub average_compression_ratio: f32,
416
417    /// Memory saved through compression
418    pub memory_saved: usize,
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest per-component error accepted after an 8-bit round trip of the
    /// small-range vectors used below (half a quantization step is < 0.008).
    const TOLERANCE: f32 = 0.01;

    /// Compression-enabled configuration shared by all tests.
    fn create_test_config() -> EmbeddingCacheConfig {
        EmbeddingCacheConfig {
            enabled: true,
            max_size: 100,
            ttl: std::time::Duration::from_secs(3600),
            eviction_policy: super::super::EvictionPolicy::LFU,
            compression_enabled: true,
        }
    }

    /// Asserts that two vectors have equal length and agree component-wise
    /// within `TOLERANCE`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (index, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() <= TOLERANCE,
                "component {index}: got {got}, expected {want}"
            );
        }
    }

    #[test]
    fn test_embedding_cache_creation() {
        let cache = EmbeddingCache::new(create_test_config()).unwrap();

        assert_eq!(cache.size(), 0);
        assert!(cache.compressed_storage.is_some());
    }

    #[test]
    fn test_basic_operations() {
        let mut cache = EmbeddingCache::new(create_test_config()).unwrap();

        let text = "test text".to_string();
        let model = "test-model".to_string();
        let embedding = vec![1.0, 2.0, 3.0];

        // Cache embedding
        cache
            .cache_embedding(text.clone(), model.clone(), embedding.clone())
            .unwrap();
        assert_eq!(cache.size(), 1);

        // Retrieve embedding; values are only approximately equal because the
        // stored form is quantized.
        let retrieved = cache
            .get_embedding(&text, &model)
            .expect("embedding was cached just above");
        assert_close(&retrieved, &embedding);
    }

    #[test]
    fn test_compression() {
        let cache = EmbeddingCache::new(create_test_config()).unwrap();

        let embedding = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let compressed = cache.compress_embedding(&embedding);

        assert_eq!(compressed.dimension, 5);
        assert_eq!(compressed.quantized_values.len(), 5);
        assert!(compressed.compression_ratio > 1.0);

        let decompressed = cache.decompress_embedding(&compressed);
        assert_close(&decompressed, &embedding);
    }

    #[test]
    fn test_deduplication() {
        let mut cache = EmbeddingCache::new(create_test_config()).unwrap();

        let text = "same text".to_string();
        let embedding = vec![1.0, 2.0, 3.0];

        // Cache same text with different models
        for model in ["model1", "model2"] {
            cache
                .cache_embedding(text.clone(), model.to_string(), embedding.clone())
                .unwrap();
        }

        let stats = cache.get_deduplication_stats();
        assert_eq!(stats.deduplicated_count, 1);
        assert!(stats.deduplication_ratio > 0.0);
    }

    #[test]
    fn test_hash_text() {
        let cache = EmbeddingCache::new(create_test_config()).unwrap();

        let hash1 = cache.hash_text("hello world");
        let hash2 = cache.hash_text("hello world");
        let hash3 = cache.hash_text("goodbye world");

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }
}