1use super::{Cache, CacheEntryMetadata, CacheStats, EmbeddingCacheConfig, EmbeddingCacheEntry};
6use crate::RragResult;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::hash::{DefaultHasher, Hash, Hasher};
10
/// In-memory cache of text embeddings with text-hash deduplication and
/// optional 8-bit quantized storage.
pub struct EmbeddingCache {
    /// Cache settings; the code in this file reads `compression_enabled`
    /// and `max_size`.
    config: EmbeddingCacheConfig,

    /// Cached entries keyed by the `"{model}:{text}"` string built by
    /// `make_key`. When `cache_embedding` stores a vector with compression
    /// enabled, the entry kept here has an empty `embedding`.
    storage: HashMap<String, EmbeddingCacheEntry>,

    /// Maps a text hash (see `hash_text`) to the storage key of the entry
    /// that was first cached for that text.
    deduplication_map: HashMap<String, String>,

    /// Quantized embeddings keyed by storage key. `new` sets this to `Some`
    /// only when `config.compression_enabled` is true.
    compressed_storage: Option<HashMap<String, CompressedEmbedding>>,

    /// Cache counters; the code in this file only increments `evictions`.
    stats: CacheStats,
}
28
/// An embedding stored as one byte per dimension plus the linear mapping
/// needed to decode it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedEmbedding {
    /// Quantized components; component `q` decodes to `q * scale + offset`.
    pub quantized_values: Vec<u8>,

    /// Width of one quantization step (value range divided by 255).
    pub scale: f32,
    /// Value represented by the quantized byte `0` (the minimum input value).
    pub offset: f32,

    /// Number of components in the original embedding.
    pub dimension: usize,

    /// Original size in bytes divided by compressed size in bytes, where the
    /// compressed size counts the quantized bytes plus `scale` and `offset`.
    pub compression_ratio: f32,
}
45
/// Names of embedding compression schemes.
///
/// NOTE(review): no code visible in this file reads this enum;
/// `EmbeddingCache::compress_embedding` always applies 8-bit quantization.
#[derive(Debug, Clone, Copy)]
pub enum CompressionMethod {
    /// No compression.
    None,

    /// 8-bit scalar quantization — presumably the scheme implemented by
    /// `compress_embedding`; confirm against callers.
    Quantization8Bit,

    /// Presumably principal-component-analysis based reduction — not
    /// implemented in this file.
    PCA,

    /// Presumably product quantization — not implemented in this file.
    ProductQuantization,

    /// Presumably one-bit-per-component quantization — not implemented in
    /// this file.
    BinaryQuantization,
}
64
/// Summary of how many cache entries share a text with another entry, as
/// reported by `EmbeddingCache::get_deduplication_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeduplicationStats {
    /// Stored entries in excess of the number of distinct text hashes.
    pub deduplicated_count: usize,

    /// Estimated bytes saved, computed as `deduplicated_count` times the size
    /// of a 1536-component `f32` vector (an assumed dimension, not measured).
    pub memory_saved: usize,

    /// `deduplicated_count` divided by the total number of stored entries;
    /// `0.0` when the cache is empty.
    pub deduplication_ratio: f32,
}
77
78impl EmbeddingCache {
79 pub fn new(config: EmbeddingCacheConfig) -> RragResult<Self> {
81 let compressed_storage = if config.compression_enabled {
82 Some(HashMap::new())
83 } else {
84 None
85 };
86
87 Ok(Self {
88 config,
89 storage: HashMap::new(),
90 deduplication_map: HashMap::new(),
91 compressed_storage,
92 stats: CacheStats::default(),
93 })
94 }
95
96 pub fn get_embedding(&self, text: &str, model: &str) -> Option<Vec<f32>> {
98 let key = self.make_key(text, model);
99
100 if let Some(entry) = self.storage.get(&key) {
102 return Some(entry.embedding.clone());
103 }
104
105 let text_hash = self.hash_text(text);
107 if let Some(canonical_key) = self.deduplication_map.get(&text_hash) {
108 if let Some(entry) = self.storage.get(canonical_key) {
109 return Some(entry.embedding.clone());
110 }
111 }
112
113 if let Some(compressed_storage) = &self.compressed_storage {
115 if let Some(compressed) = compressed_storage.get(&key) {
116 return Some(self.decompress_embedding(compressed));
117 }
118 }
119
120 None
121 }
122
123 pub fn cache_embedding(
125 &mut self,
126 text: String,
127 model: String,
128 embedding: Vec<f32>,
129 ) -> RragResult<()> {
130 let key = self.make_key(&text, &model);
131 let text_hash = self.hash_text(&text);
132
133 if let Some(existing_key) = self.deduplication_map.get(&text_hash) {
135 if !self.storage.contains_key(&key) {
137 if let Some(existing_entry) = self.storage.get(existing_key).cloned() {
138 self.storage.insert(key, existing_entry);
139 }
140 }
141 return Ok(());
142 }
143
144 if self.storage.len() >= self.config.max_size {
146 self.evict_entry()?;
147 }
148
149 let entry = EmbeddingCacheEntry {
151 text: text.clone(),
152 text_hash: text_hash.clone(),
153 embedding: embedding.clone(),
154 model: model.clone(),
155 metadata: CacheEntryMetadata::new(),
156 };
157
158 if self.config.compression_enabled {
160 let compressed = self.compress_embedding(&embedding);
161 if let Some(compressed_storage) = &mut self.compressed_storage {
162 compressed_storage.insert(key.clone(), compressed);
163 }
164
165 let mut metadata_entry = entry;
167 metadata_entry.embedding = Vec::new(); self.storage.insert(key.clone(), metadata_entry);
169 } else {
170 self.storage.insert(key.clone(), entry);
171 }
172
173 self.deduplication_map.insert(text_hash, key);
175
176 Ok(())
177 }
178
179 fn compress_embedding(&self, embedding: &[f32]) -> CompressedEmbedding {
181 let (min_val, max_val) = embedding
183 .iter()
184 .fold((f32::INFINITY, f32::NEG_INFINITY), |(min, max), &val| {
185 (min.min(val), max.max(val))
186 });
187
188 let range = max_val - min_val;
189 let scale = range / 255.0;
190 let offset = min_val;
191
192 let quantized_values: Vec<u8> = embedding
193 .iter()
194 .map(|&val| {
195 let normalized = (val - offset) / scale;
196 normalized.round().clamp(0.0, 255.0) as u8
197 })
198 .collect();
199
200 let original_size = embedding.len() * std::mem::size_of::<f32>();
201 let compressed_size =
202 quantized_values.len() * std::mem::size_of::<u8>() + std::mem::size_of::<f32>() * 2; CompressedEmbedding {
205 quantized_values,
206 scale,
207 offset,
208 dimension: embedding.len(),
209 compression_ratio: original_size as f32 / compressed_size as f32,
210 }
211 }
212
213 fn decompress_embedding(&self, compressed: &CompressedEmbedding) -> Vec<f32> {
215 compressed
216 .quantized_values
217 .iter()
218 .map(|&val| (val as f32) * compressed.scale + compressed.offset)
219 .collect()
220 }
221
222 fn make_key(&self, text: &str, model: &str) -> String {
224 format!("{}:{}", model, text)
225 }
226
227 fn hash_text(&self, text: &str) -> String {
229 let mut hasher = DefaultHasher::new();
230 text.hash(&mut hasher);
231 format!("{:x}", hasher.finish())
232 }
233
234 fn evict_entry(&mut self) -> RragResult<()> {
236 if self.storage.is_empty() {
237 return Ok(());
238 }
239
240 let mut candidate_key: Option<String> = None;
242 let mut min_access_count = u64::MAX;
243 let mut oldest_time = std::time::SystemTime::now();
244
245 for (key, entry) in &self.storage {
246 if entry.metadata.access_count < min_access_count
247 || (entry.metadata.access_count == min_access_count
248 && entry.metadata.last_accessed < oldest_time)
249 {
250 min_access_count = entry.metadata.access_count;
251 oldest_time = entry.metadata.last_accessed;
252 candidate_key = Some(key.clone());
253 }
254 }
255
256 if let Some(key) = candidate_key {
257 if let Some(entry) = self.storage.remove(&key) {
258 self.deduplication_map.remove(&entry.text_hash);
260
261 if let Some(compressed_storage) = &mut self.compressed_storage {
263 compressed_storage.remove(&key);
264 }
265
266 self.stats.evictions += 1;
267 }
268 }
269
270 Ok(())
271 }
272
273 pub fn get_deduplication_stats(&self) -> DeduplicationStats {
275 let total_entries = self.storage.len();
276 let unique_texts = self.deduplication_map.len();
277 let deduplicated_count = if total_entries > unique_texts {
278 total_entries - unique_texts
279 } else {
280 0
281 };
282
283 let embedding_size = 1536 * std::mem::size_of::<f32>(); let memory_saved = deduplicated_count * embedding_size;
285
286 let deduplication_ratio = if total_entries > 0 {
287 deduplicated_count as f32 / total_entries as f32
288 } else {
289 0.0
290 };
291
292 DeduplicationStats {
293 deduplicated_count,
294 memory_saved,
295 deduplication_ratio,
296 }
297 }
298
299 pub fn get_compression_stats(&self) -> Option<CompressionStats> {
301 if !self.config.compression_enabled {
302 return None;
303 }
304
305 let compressed_storage = self.compressed_storage.as_ref()?;
306
307 let mut total_original_size = 0;
308 let mut total_compressed_size = 0;
309 let mut compression_ratios = Vec::new();
310
311 for compressed in compressed_storage.values() {
312 let original_size = compressed.dimension * std::mem::size_of::<f32>();
313 let compressed_size =
314 compressed.quantized_values.len() + std::mem::size_of::<f32>() * 2; total_original_size += original_size;
317 total_compressed_size += compressed_size;
318 compression_ratios.push(compressed.compression_ratio);
319 }
320
321 let overall_ratio = if total_compressed_size > 0 {
322 total_original_size as f32 / total_compressed_size as f32
323 } else {
324 1.0
325 };
326
327 let avg_ratio = if !compression_ratios.is_empty() {
328 compression_ratios.iter().sum::<f32>() / compression_ratios.len() as f32
329 } else {
330 1.0
331 };
332
333 Some(CompressionStats {
334 total_entries: compressed_storage.len(),
335 total_original_size,
336 total_compressed_size,
337 overall_compression_ratio: overall_ratio,
338 average_compression_ratio: avg_ratio,
339 memory_saved: total_original_size - total_compressed_size,
340 })
341 }
342}
343
344impl Cache<String, EmbeddingCacheEntry> for EmbeddingCache {
345 fn get(&self, key: &String) -> Option<EmbeddingCacheEntry> {
346 self.storage.get(key).cloned()
347 }
348
349 fn put(&mut self, key: String, value: EmbeddingCacheEntry) -> RragResult<()> {
350 if self.storage.len() >= self.config.max_size {
352 self.evict_entry()?;
353 }
354
355 self.storage.insert(key, value);
356 Ok(())
357 }
358
359 fn remove(&mut self, key: &String) -> Option<EmbeddingCacheEntry> {
360 let entry = self.storage.remove(key);
361
362 if let Some(ref entry_val) = entry {
363 self.deduplication_map.remove(&entry_val.text_hash);
364
365 if let Some(compressed_storage) = &mut self.compressed_storage {
366 compressed_storage.remove(key);
367 }
368 }
369
370 entry
371 }
372
373 fn contains(&self, key: &String) -> bool {
374 self.storage.contains_key(key)
375 || (self
376 .compressed_storage
377 .as_ref()
378 .map_or(false, |cs| cs.contains_key(key)))
379 }
380
381 fn clear(&mut self) {
382 self.storage.clear();
383 self.deduplication_map.clear();
384 if let Some(compressed_storage) = &mut self.compressed_storage {
385 compressed_storage.clear();
386 }
387 self.stats = CacheStats::default();
388 }
389
390 fn size(&self) -> usize {
391 self.storage.len()
392 }
393
394 fn stats(&self) -> CacheStats {
395 self.stats.clone()
396 }
397}
398
/// Aggregate figures for the compressed store, as reported by
/// `EmbeddingCache::get_compression_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionStats {
    /// Number of compressed embeddings.
    pub total_entries: usize,

    /// Sum over all entries of `dimension * size_of::<f32>()`, in bytes.
    pub total_original_size: usize,

    /// Sum over all entries of the quantized bytes plus the two `f32`
    /// parameters (`scale`, `offset`), in bytes.
    pub total_compressed_size: usize,

    /// `total_original_size / total_compressed_size`; `1.0` when nothing is
    /// stored.
    pub overall_compression_ratio: f32,

    /// Mean of the per-entry `compression_ratio` values; `1.0` when nothing
    /// is stored.
    pub average_compression_ratio: f32,

    /// Difference between the original and compressed totals, in bytes.
    pub memory_saved: usize,
}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration shared by the tests: compression on, room for 100
    /// entries.
    fn create_test_config() -> EmbeddingCacheConfig {
        EmbeddingCacheConfig {
            enabled: true,
            max_size: 100,
            ttl: std::time::Duration::from_secs(3600),
            eviction_policy: super::super::EvictionPolicy::LFU,
            compression_enabled: true,
        }
    }

    /// Asserts that two vectors have the same length and agree
    /// component-wise within `tolerance` (quantization is lossy, so exact
    /// equality is not expected).
    fn assert_close(actual: &[f32], expected: &[f32], tolerance: f32) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!(
                (a - e).abs() <= tolerance,
                "component {} differs from {} by more than {}",
                a,
                e,
                tolerance
            );
        }
    }

    #[test]
    fn test_embedding_cache_creation() {
        let cache = EmbeddingCache::new(create_test_config()).unwrap();

        assert_eq!(cache.size(), 0);
        assert!(cache.compressed_storage.is_some());
    }

    #[test]
    fn test_basic_operations() {
        let mut cache = EmbeddingCache::new(create_test_config()).unwrap();

        let text = "test text".to_string();
        let model = "test-model".to_string();
        let embedding = vec![1.0, 2.0, 3.0];

        cache
            .cache_embedding(text.clone(), model.clone(), embedding.clone())
            .unwrap();
        assert_eq!(cache.size(), 1);

        let retrieved = cache.get_embedding(&text, &model);
        assert!(retrieved.is_some());

        // The values, not just the length, must survive the round trip.
        let retrieved_embedding = retrieved.unwrap();
        assert_eq!(retrieved_embedding.len(), embedding.len());
        assert_close(&retrieved_embedding, &embedding, 0.01);
    }

    #[test]
    fn test_compression() {
        let cache = EmbeddingCache::new(create_test_config()).unwrap();

        let embedding = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let compressed = cache.compress_embedding(&embedding);

        assert_eq!(compressed.dimension, 5);
        assert_eq!(compressed.quantized_values.len(), 5);
        assert!(compressed.compression_ratio > 1.0);

        // Rounding error is at most half a quantization step, so one full
        // step is a safe bound.
        let decompressed = cache.decompress_embedding(&compressed);
        assert_eq!(decompressed.len(), embedding.len());
        assert_close(&decompressed, &embedding, compressed.scale);
    }

    #[test]
    fn test_deduplication() {
        let mut cache = EmbeddingCache::new(create_test_config()).unwrap();

        let text = "same text".to_string();
        let embedding = vec![1.0, 2.0, 3.0];

        cache
            .cache_embedding(text.clone(), "model1".to_string(), embedding.clone())
            .unwrap();
        cache
            .cache_embedding(text.clone(), "model2".to_string(), embedding.clone())
            .unwrap();
        assert_eq!(cache.size(), 2);

        let stats = cache.get_deduplication_stats();
        assert_eq!(stats.deduplicated_count, 1);
        assert!(stats.deduplication_ratio > 0.0);
    }

    #[test]
    fn test_hash_text() {
        let cache = EmbeddingCache::new(create_test_config()).unwrap();

        let hash1 = cache.hash_text("hello world");
        let hash2 = cache.hash_text("hello world");
        let hash3 = cache.hash_text("goodbye world");

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }
}