Skip to main content

heliosdb_proxy/distribcache/ai/
rag.rs

//! RAG chunk cache
//!
//! Caches document chunks for RAG (Retrieval-Augmented Generation) pipelines.
//! Optimized for embedding-based retrieval and document fetching: chunks can be
//! looked up by chunk ID, by parent document, or by a quantized embedding hash.

6use dashmap::DashMap;
7use std::collections::HashSet;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::Instant;
10
11/// Chunk identifier
12pub type ChunkId = u64;
13
14/// Embedding hash for cache lookup
15pub type EmbeddingHash = u64;
16
17/// Document chunk
18#[derive(Debug, Clone)]
19pub struct Chunk {
20    /// Chunk ID
21    pub id: ChunkId,
22    /// Parent document ID
23    pub document_id: String,
24    /// Chunk content
25    pub content: String,
26    /// Chunk embedding (optional, for similarity)
27    pub embedding: Option<Vec<f32>>,
28    /// Chunk position in document
29    pub position: usize,
30    /// Metadata
31    pub metadata: Option<serde_json::Value>,
32    /// Creation time
33    pub created_at: Instant,
34}
35
36impl Chunk {
37    /// Create a new chunk
38    pub fn new(id: ChunkId, document_id: impl Into<String>, content: impl Into<String>) -> Self {
39        Self {
40            id,
41            document_id: document_id.into(),
42            content: content.into(),
43            embedding: None,
44            position: 0,
45            metadata: None,
46            created_at: Instant::now(),
47        }
48    }
49
50    /// Add embedding
51    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
52        self.embedding = Some(embedding);
53        self
54    }
55
56    /// Set position
57    pub fn with_position(mut self, position: usize) -> Self {
58        self.position = position;
59        self
60    }
61
62    /// Add metadata
63    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
64        self.metadata = Some(metadata);
65        self
66    }
67
68    /// Approximate size in bytes
69    pub fn size(&self) -> usize {
70        self.content.len()
71            + self.document_id.len()
72            + self.embedding.as_ref().map(|e| e.len() * 4).unwrap_or(0)
73            + 64
74    }
75}
76
77/// Hash an embedding vector
78pub fn hash_embedding(embedding: &[f32]) -> EmbeddingHash {
79    use std::collections::hash_map::DefaultHasher;
80    use std::hash::{Hash, Hasher};
81
82    let mut hasher = DefaultHasher::new();
83
84    // Quantize and hash
85    for val in embedding {
86        let quantized = (val * 1000.0) as i32;
87        quantized.hash(&mut hasher);
88    }
89
90    hasher.finish()
91}
92
93/// RAG chunk cache
94pub struct RagChunkCache {
95    /// Chunk storage (id -> chunk)
96    chunks: DashMap<ChunkId, Chunk>,
97
98    /// Embedding to chunk IDs mapping
99    embedding_to_chunks: DashMap<EmbeddingHash, Vec<ChunkId>>,
100
101    /// Document to chunk IDs mapping
102    document_to_chunks: DashMap<String, HashSet<ChunkId>>,
103
104    /// Maximum cache size in MB
105    max_size_mb: usize,
106
107    /// Current size in bytes
108    current_size: AtomicU64,
109
110    /// Statistics
111    stats: RagCacheStats,
112}
113
114/// RAG cache statistics
115#[derive(Debug, Default)]
116struct RagCacheStats {
117    hits: AtomicU64,
118    misses: AtomicU64,
119    embedding_lookups: AtomicU64,
120    embedding_cache_hits: AtomicU64,
121}
122
123impl RagChunkCache {
124    /// Create a new RAG chunk cache
125    pub fn new(max_size_mb: usize) -> Self {
126        Self {
127            chunks: DashMap::new(),
128            embedding_to_chunks: DashMap::new(),
129            document_to_chunks: DashMap::new(),
130            max_size_mb,
131            current_size: AtomicU64::new(0),
132            stats: RagCacheStats::default(),
133        }
134    }
135
136    /// Get a chunk by ID
137    pub fn get_chunk(&self, id: ChunkId) -> Option<Chunk> {
138        if let Some(chunk) = self.chunks.get(&id) {
139            self.stats.hits.fetch_add(1, Ordering::Relaxed);
140            Some(chunk.clone())
141        } else {
142            self.stats.misses.fetch_add(1, Ordering::Relaxed);
143            None
144        }
145    }
146
147    /// Get chunks by embedding similarity
148    pub fn get_chunks_by_embedding(&self, embedding: &[f32], k: usize) -> Vec<Chunk> {
149        self.stats.embedding_lookups.fetch_add(1, Ordering::Relaxed);
150
151        let hash = hash_embedding(embedding);
152
153        if let Some(chunk_ids) = self.embedding_to_chunks.get(&hash) {
154            self.stats
155                .embedding_cache_hits
156                .fetch_add(1, Ordering::Relaxed);
157
158            let chunks: Vec<_> = chunk_ids
159                .iter()
160                .filter_map(|id| self.chunks.get(id).map(|c| c.clone()))
161                .take(k)
162                .collect();
163
164            return chunks;
165        }
166
167        Vec::new()
168    }
169
170    /// Get all chunks for a document
171    pub fn get_document_chunks(&self, document_id: &str) -> Vec<Chunk> {
172        if let Some(ids) = self.document_to_chunks.get(document_id) {
173            ids.iter()
174                .filter_map(|id| self.chunks.get(id).map(|c| c.clone()))
175                .collect()
176        } else {
177            Vec::new()
178        }
179    }
180
181    /// Insert a chunk
182    pub fn insert_chunk(&self, chunk: Chunk) {
183        let size = chunk.size() as u64;
184        let max_bytes = (self.max_size_mb * 1024 * 1024) as u64;
185
186        // Evict if needed
187        while self.current_size.load(Ordering::Relaxed) + size > max_bytes {
188            if !self.evict_one() {
189                break;
190            }
191        }
192
193        // Index by document
194        self.document_to_chunks
195            .entry(chunk.document_id.clone())
196            .or_default()
197            .insert(chunk.id);
198
199        // Index by embedding if available
200        if let Some(ref embedding) = chunk.embedding {
201            let hash = hash_embedding(embedding);
202            self.embedding_to_chunks
203                .entry(hash)
204                .or_default()
205                .push(chunk.id);
206        }
207
208        // Store chunk
209        self.chunks.insert(chunk.id, chunk);
210        self.current_size.fetch_add(size, Ordering::Relaxed);
211    }
212
213    /// Insert multiple chunks (batch)
214    pub fn insert_chunks(&self, chunks: Vec<Chunk>) {
215        for chunk in chunks {
216            self.insert_chunk(chunk);
217        }
218    }
219
220    /// Cache embedding to chunk ID mapping
221    pub fn cache_embedding_result(&self, embedding: &[f32], chunk_ids: Vec<ChunkId>) {
222        let hash = hash_embedding(embedding);
223        self.embedding_to_chunks.insert(hash, chunk_ids);
224    }
225
226    /// Remove a chunk
227    pub fn remove_chunk(&self, id: ChunkId) {
228        if let Some((_, chunk)) = self.chunks.remove(&id) {
229            self.current_size
230                .fetch_sub(chunk.size() as u64, Ordering::Relaxed);
231
232            // Remove from document index
233            if let Some(mut ids) = self.document_to_chunks.get_mut(&chunk.document_id) {
234                ids.remove(&id);
235            }
236        }
237    }
238
239    /// Remove all chunks for a document
240    pub fn remove_document(&self, document_id: &str) {
241        if let Some((_, ids)) = self.document_to_chunks.remove(document_id) {
242            for id in ids {
243                self.remove_chunk(id);
244            }
245        }
246    }
247
248    /// Evict one chunk (oldest by creation time)
249    fn evict_one(&self) -> bool {
250        let mut oldest_id = None;
251        let mut oldest_time = Instant::now();
252
253        for entry in self.chunks.iter() {
254            if entry.created_at < oldest_time {
255                oldest_time = entry.created_at;
256                oldest_id = Some(*entry.key());
257            }
258        }
259
260        if let Some(id) = oldest_id {
261            self.remove_chunk(id);
262            return true;
263        }
264
265        false
266    }
267
268    /// Get cache statistics
269    pub fn stats(&self) -> RagCacheStatsSnapshot {
270        RagCacheStatsSnapshot {
271            chunk_count: self.chunks.len(),
272            document_count: self.document_to_chunks.len(),
273            size_bytes: self.current_size.load(Ordering::Relaxed),
274            max_size_bytes: (self.max_size_mb * 1024 * 1024) as u64,
275            hits: self.stats.hits.load(Ordering::Relaxed),
276            misses: self.stats.misses.load(Ordering::Relaxed),
277            embedding_lookups: self.stats.embedding_lookups.load(Ordering::Relaxed),
278            embedding_cache_hit_rate: {
279                let lookups = self.stats.embedding_lookups.load(Ordering::Relaxed);
280                let hits = self.stats.embedding_cache_hits.load(Ordering::Relaxed);
281                if lookups > 0 {
282                    hits as f64 / lookups as f64
283                } else {
284                    0.0
285                }
286            },
287        }
288    }
289
290    /// Clear all cached chunks
291    pub fn clear(&self) {
292        self.chunks.clear();
293        self.embedding_to_chunks.clear();
294        self.document_to_chunks.clear();
295        self.current_size.store(0, Ordering::Relaxed);
296    }
297}
298
/// Point-in-time copy of a RAG chunk cache's sizes and counters.
#[derive(Debug, Clone)]
pub struct RagCacheStatsSnapshot {
    /// Number of chunks currently stored.
    pub chunk_count: usize,
    /// Number of entries in the document index.
    pub document_count: usize,
    /// Approximate bytes held, summed from `Chunk::size`.
    pub size_bytes: u64,
    /// Configured byte budget.
    pub max_size_bytes: u64,
    /// `get_chunk` calls that found the chunk.
    pub hits: u64,
    /// `get_chunk` calls that did not.
    pub misses: u64,
    /// Calls to `get_chunks_by_embedding`.
    pub embedding_lookups: u64,
    /// Fraction of embedding lookups whose hash was indexed (0.0 if there were none).
    pub embedding_cache_hit_rate: f64,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chunk_creation() {
        let chunk = Chunk::new(1, "doc-1", "This is a test chunk").with_position(0);

        assert_eq!(chunk.id, 1);
        assert_eq!(chunk.document_id, "doc-1");
        assert_eq!(chunk.position, 0);
    }

    #[test]
    fn test_chunk_size_counts_text_and_embedding() {
        // 3 content bytes + 3 document-id bytes + 2 * 4 embedding bytes + 64 overhead.
        let chunk = Chunk::new(1, "doc", "abc").with_embedding(vec![0.0, 0.0]);
        assert_eq!(chunk.size(), 78);
    }

    #[test]
    fn test_hash_embedding_quantizes_to_thousandths() {
        assert_eq!(hash_embedding(&[0.1001]), hash_embedding(&[0.1004]));
        assert_ne!(hash_embedding(&[0.1001]), hash_embedding(&[0.2]));
    }

    #[test]
    fn test_insert_and_get() {
        let cache = RagChunkCache::new(10);

        let chunk = Chunk::new(1, "doc-1", "Test content");
        cache.insert_chunk(chunk);

        let retrieved = cache.get_chunk(1);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test content");
    }

    #[test]
    fn test_document_chunks() {
        let cache = RagChunkCache::new(10);

        cache.insert_chunk(Chunk::new(1, "doc-1", "Chunk 1").with_position(0));
        cache.insert_chunk(Chunk::new(2, "doc-1", "Chunk 2").with_position(1));
        cache.insert_chunk(Chunk::new(3, "doc-2", "Chunk 3").with_position(0));

        let doc1_chunks = cache.get_document_chunks("doc-1");
        assert_eq!(doc1_chunks.len(), 2);

        let doc2_chunks = cache.get_document_chunks("doc-2");
        assert_eq!(doc2_chunks.len(), 1);

        assert!(cache.get_document_chunks("unknown").is_empty());
    }

    #[test]
    fn test_embedding_lookup() {
        let cache = RagChunkCache::new(10);

        let embedding = vec![0.1, 0.2, 0.3];
        let chunk = Chunk::new(1, "doc-1", "Embedded content").with_embedding(embedding.clone());

        cache.insert_chunk(chunk);

        // Cache the embedding result
        cache.cache_embedding_result(&embedding, vec![1]);

        // Lookup by embedding
        let results = cache.get_chunks_by_embedding(&embedding, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_embedding_lookup_respects_k_and_skips_unknown_ids() {
        let cache = RagChunkCache::new(10);
        for id in 1..=3 {
            cache.insert_chunk(Chunk::new(id, "doc", "text"));
        }
        let query = vec![0.7, 0.8];

        cache.cache_embedding_result(&query, vec![1, 99, 2, 3]);
        let ids: Vec<ChunkId> = cache
            .get_chunks_by_embedding(&query, 2)
            .iter()
            .map(|c| c.id)
            .collect();
        assert_eq!(ids, vec![1, 2]);
    }

    #[test]
    fn test_embedding_miss_is_counted() {
        let cache = RagChunkCache::new(10);

        assert!(cache.get_chunks_by_embedding(&[1.0], 5).is_empty());

        let stats = cache.stats();
        assert_eq!(stats.embedding_lookups, 1);
        assert_eq!(stats.embedding_cache_hit_rate, 0.0);
    }

    #[test]
    fn test_remove_chunk_releases_size() {
        let cache = RagChunkCache::new(10);

        cache.insert_chunk(Chunk::new(1, "doc-1", "Chunk 1"));
        cache.remove_chunk(1);

        let stats = cache.stats();
        assert_eq!(stats.chunk_count, 0);
        assert_eq!(stats.size_bytes, 0);
        assert!(cache.get_chunk(1).is_none());
    }

    #[test]
    fn test_remove_document() {
        let cache = RagChunkCache::new(10);

        cache.insert_chunk(Chunk::new(1, "doc-1", "Chunk 1"));
        cache.insert_chunk(Chunk::new(2, "doc-1", "Chunk 2"));

        cache.remove_document("doc-1");

        assert!(cache.get_chunk(1).is_none());
        assert!(cache.get_chunk(2).is_none());
    }

    #[test]
    fn test_clear() {
        let cache = RagChunkCache::new(10);

        cache.insert_chunk(Chunk::new(1, "doc-1", "a").with_embedding(vec![0.1]));
        cache.insert_chunk(Chunk::new(2, "doc-2", "b").with_embedding(vec![0.2]));
        cache.clear();

        let stats = cache.stats();
        assert_eq!(stats.chunk_count, 0);
        assert_eq!(stats.document_count, 0);
        assert_eq!(stats.size_bytes, 0);
        assert!(cache.get_chunks_by_embedding(&[0.1], 10).is_empty());
    }

    #[test]
    fn test_stats() {
        let cache = RagChunkCache::new(10);

        cache.insert_chunk(Chunk::new(1, "doc-1", "Content"));
        cache.get_chunk(1); // Hit
        cache.get_chunk(2); // Miss

        let stats = cache.stats();
        assert_eq!(stats.chunk_count, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.max_size_bytes, 10 * 1024 * 1024);
    }
}