Skip to main content

heliosdb_proxy/distribcache/ai/
rag.rs

1//! RAG chunk cache
2//!
3//! Caches document chunks for RAG (Retrieval-Augmented Generation) pipelines.
4//! Optimized for embedding-based retrieval and document fetching.
5
6use dashmap::DashMap;
7use std::collections::HashSet;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::Instant;
10
/// Identifier of a cached [`Chunk`]; the cache's primary map is keyed by it.
pub type ChunkId = u64;

/// Hash of a quantized embedding vector as produced by [`hash_embedding`];
/// used as the key of the cache's embedding lookup index.
pub type EmbeddingHash = u64;
16
17/// Document chunk
18#[derive(Debug, Clone)]
19pub struct Chunk {
20    /// Chunk ID
21    pub id: ChunkId,
22    /// Parent document ID
23    pub document_id: String,
24    /// Chunk content
25    pub content: String,
26    /// Chunk embedding (optional, for similarity)
27    pub embedding: Option<Vec<f32>>,
28    /// Chunk position in document
29    pub position: usize,
30    /// Metadata
31    pub metadata: Option<serde_json::Value>,
32    /// Creation time
33    pub created_at: Instant,
34}
35
36impl Chunk {
37    /// Create a new chunk
38    pub fn new(id: ChunkId, document_id: impl Into<String>, content: impl Into<String>) -> Self {
39        Self {
40            id,
41            document_id: document_id.into(),
42            content: content.into(),
43            embedding: None,
44            position: 0,
45            metadata: None,
46            created_at: Instant::now(),
47        }
48    }
49
50    /// Add embedding
51    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
52        self.embedding = Some(embedding);
53        self
54    }
55
56    /// Set position
57    pub fn with_position(mut self, position: usize) -> Self {
58        self.position = position;
59        self
60    }
61
62    /// Add metadata
63    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
64        self.metadata = Some(metadata);
65        self
66    }
67
68    /// Approximate size in bytes
69    pub fn size(&self) -> usize {
70        self.content.len() +
71        self.document_id.len() +
72        self.embedding.as_ref().map(|e| e.len() * 4).unwrap_or(0) +
73        64
74    }
75}
76
77/// Hash an embedding vector
78pub fn hash_embedding(embedding: &[f32]) -> EmbeddingHash {
79    use std::hash::{Hash, Hasher};
80    use std::collections::hash_map::DefaultHasher;
81
82    let mut hasher = DefaultHasher::new();
83
84    // Quantize and hash
85    for val in embedding {
86        let quantized = (val * 1000.0) as i32;
87        quantized.hash(&mut hasher);
88    }
89
90    hasher.finish()
91}
92
93/// RAG chunk cache
94pub struct RagChunkCache {
95    /// Chunk storage (id -> chunk)
96    chunks: DashMap<ChunkId, Chunk>,
97
98    /// Embedding to chunk IDs mapping
99    embedding_to_chunks: DashMap<EmbeddingHash, Vec<ChunkId>>,
100
101    /// Document to chunk IDs mapping
102    document_to_chunks: DashMap<String, HashSet<ChunkId>>,
103
104    /// Maximum cache size in MB
105    max_size_mb: usize,
106
107    /// Current size in bytes
108    current_size: AtomicU64,
109
110    /// Statistics
111    stats: RagCacheStats,
112}
113
/// Internal lookup counters for [`RagChunkCache`]; every counter starts at zero.
#[derive(Debug, Default)]
struct RagCacheStats {
    /// `get_chunk` calls that found the requested chunk.
    hits: AtomicU64,
    /// `get_chunk` calls that did not find the requested chunk.
    misses: AtomicU64,
    /// Total calls to `get_chunks_by_embedding`.
    embedding_lookups: AtomicU64,
    /// Embedding lookups whose hash was present in the embedding index.
    embedding_cache_hits: AtomicU64,
}
122
123impl RagChunkCache {
124    /// Create a new RAG chunk cache
125    pub fn new(max_size_mb: usize) -> Self {
126        Self {
127            chunks: DashMap::new(),
128            embedding_to_chunks: DashMap::new(),
129            document_to_chunks: DashMap::new(),
130            max_size_mb,
131            current_size: AtomicU64::new(0),
132            stats: RagCacheStats::default(),
133        }
134    }
135
136    /// Get a chunk by ID
137    pub fn get_chunk(&self, id: ChunkId) -> Option<Chunk> {
138        if let Some(chunk) = self.chunks.get(&id) {
139            self.stats.hits.fetch_add(1, Ordering::Relaxed);
140            Some(chunk.clone())
141        } else {
142            self.stats.misses.fetch_add(1, Ordering::Relaxed);
143            None
144        }
145    }
146
147    /// Get chunks by embedding similarity
148    pub fn get_chunks_by_embedding(&self, embedding: &[f32], k: usize) -> Vec<Chunk> {
149        self.stats.embedding_lookups.fetch_add(1, Ordering::Relaxed);
150
151        let hash = hash_embedding(embedding);
152
153        if let Some(chunk_ids) = self.embedding_to_chunks.get(&hash) {
154            self.stats.embedding_cache_hits.fetch_add(1, Ordering::Relaxed);
155
156            let chunks: Vec<_> = chunk_ids.iter()
157                .filter_map(|id| self.chunks.get(id).map(|c| c.clone()))
158                .take(k)
159                .collect();
160
161            return chunks;
162        }
163
164        Vec::new()
165    }
166
167    /// Get all chunks for a document
168    pub fn get_document_chunks(&self, document_id: &str) -> Vec<Chunk> {
169        if let Some(ids) = self.document_to_chunks.get(document_id) {
170            ids.iter()
171                .filter_map(|id| self.chunks.get(id).map(|c| c.clone()))
172                .collect()
173        } else {
174            Vec::new()
175        }
176    }
177
178    /// Insert a chunk
179    pub fn insert_chunk(&self, chunk: Chunk) {
180        let size = chunk.size() as u64;
181        let max_bytes = (self.max_size_mb * 1024 * 1024) as u64;
182
183        // Evict if needed
184        while self.current_size.load(Ordering::Relaxed) + size > max_bytes {
185            if !self.evict_one() {
186                break;
187            }
188        }
189
190        // Index by document
191        self.document_to_chunks
192            .entry(chunk.document_id.clone())
193            .or_default()
194            .insert(chunk.id);
195
196        // Index by embedding if available
197        if let Some(ref embedding) = chunk.embedding {
198            let hash = hash_embedding(embedding);
199            self.embedding_to_chunks
200                .entry(hash)
201                .or_default()
202                .push(chunk.id);
203        }
204
205        // Store chunk
206        self.chunks.insert(chunk.id, chunk);
207        self.current_size.fetch_add(size, Ordering::Relaxed);
208    }
209
210    /// Insert multiple chunks (batch)
211    pub fn insert_chunks(&self, chunks: Vec<Chunk>) {
212        for chunk in chunks {
213            self.insert_chunk(chunk);
214        }
215    }
216
217    /// Cache embedding to chunk ID mapping
218    pub fn cache_embedding_result(&self, embedding: &[f32], chunk_ids: Vec<ChunkId>) {
219        let hash = hash_embedding(embedding);
220        self.embedding_to_chunks.insert(hash, chunk_ids);
221    }
222
223    /// Remove a chunk
224    pub fn remove_chunk(&self, id: ChunkId) {
225        if let Some((_, chunk)) = self.chunks.remove(&id) {
226            self.current_size.fetch_sub(chunk.size() as u64, Ordering::Relaxed);
227
228            // Remove from document index
229            if let Some(mut ids) = self.document_to_chunks.get_mut(&chunk.document_id) {
230                ids.remove(&id);
231            }
232        }
233    }
234
235    /// Remove all chunks for a document
236    pub fn remove_document(&self, document_id: &str) {
237        if let Some((_, ids)) = self.document_to_chunks.remove(document_id) {
238            for id in ids {
239                self.remove_chunk(id);
240            }
241        }
242    }
243
244    /// Evict one chunk (oldest by creation time)
245    fn evict_one(&self) -> bool {
246        let mut oldest_id = None;
247        let mut oldest_time = Instant::now();
248
249        for entry in self.chunks.iter() {
250            if entry.created_at < oldest_time {
251                oldest_time = entry.created_at;
252                oldest_id = Some(*entry.key());
253            }
254        }
255
256        if let Some(id) = oldest_id {
257            self.remove_chunk(id);
258            return true;
259        }
260
261        false
262    }
263
264    /// Get cache statistics
265    pub fn stats(&self) -> RagCacheStatsSnapshot {
266        RagCacheStatsSnapshot {
267            chunk_count: self.chunks.len(),
268            document_count: self.document_to_chunks.len(),
269            size_bytes: self.current_size.load(Ordering::Relaxed),
270            max_size_bytes: (self.max_size_mb * 1024 * 1024) as u64,
271            hits: self.stats.hits.load(Ordering::Relaxed),
272            misses: self.stats.misses.load(Ordering::Relaxed),
273            embedding_lookups: self.stats.embedding_lookups.load(Ordering::Relaxed),
274            embedding_cache_hit_rate: {
275                let lookups = self.stats.embedding_lookups.load(Ordering::Relaxed);
276                let hits = self.stats.embedding_cache_hits.load(Ordering::Relaxed);
277                if lookups > 0 { hits as f64 / lookups as f64 } else { 0.0 }
278            },
279        }
280    }
281
282    /// Clear all cached chunks
283    pub fn clear(&self) {
284        self.chunks.clear();
285        self.embedding_to_chunks.clear();
286        self.document_to_chunks.clear();
287        self.current_size.store(0, Ordering::Relaxed);
288    }
289}
290
/// Plain-data copy of a [`RagChunkCache`]'s sizes and counters at one moment.
#[derive(Debug, Clone)]
pub struct RagCacheStatsSnapshot {
    /// Number of cached chunks.
    pub chunk_count: usize,
    /// Number of entries in the document index.
    pub document_count: usize,
    /// Tracked size of cached chunks, in bytes (from `Chunk::size` estimates).
    pub size_bytes: u64,
    /// Configured size budget, in bytes.
    pub max_size_bytes: u64,
    /// `get_chunk` calls that found the chunk.
    pub hits: u64,
    /// `get_chunk` calls that did not find the chunk.
    pub misses: u64,
    /// Total `get_chunks_by_embedding` calls.
    pub embedding_lookups: u64,
    /// Fraction of embedding lookups that hit the index (0.0 when there were none).
    pub embedding_cache_hit_rate: f64,
}
303
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chunk_creation() {
        let chunk = Chunk::new(1, "doc-1", "This is a test chunk").with_position(0);

        assert_eq!(
            (chunk.id, chunk.document_id.as_str(), chunk.position),
            (1, "doc-1", 0)
        );
    }

    #[test]
    fn test_insert_and_get() {
        let cache = RagChunkCache::new(10);
        cache.insert_chunk(Chunk::new(1, "doc-1", "Test content"));

        let retrieved = cache.get_chunk(1).expect("chunk 1 should be cached");
        assert_eq!(retrieved.content, "Test content");
    }

    #[test]
    fn test_document_chunks() {
        let cache = RagChunkCache::new(10);

        let first = Chunk::new(1, "doc-1", "Chunk 1").with_position(0);
        let second = Chunk::new(2, "doc-1", "Chunk 2").with_position(1);
        let other = Chunk::new(3, "doc-2", "Chunk 3").with_position(0);
        cache.insert_chunks(vec![first, second, other]);

        assert_eq!(cache.get_document_chunks("doc-1").len(), 2);
        assert_eq!(cache.get_document_chunks("doc-2").len(), 1);
    }

    #[test]
    fn test_embedding_lookup() {
        let cache = RagChunkCache::new(10);
        let embedding = vec![0.1, 0.2, 0.3];

        cache.insert_chunk(
            Chunk::new(1, "doc-1", "Embedded content").with_embedding(embedding.clone()),
        );
        // Also record the embedding -> chunk mapping explicitly.
        cache.cache_embedding_result(&embedding, vec![1]);

        let ids: Vec<ChunkId> = cache
            .get_chunks_by_embedding(&embedding, 10)
            .iter()
            .map(|chunk| chunk.id)
            .collect();
        assert_eq!(ids, vec![1]);
    }

    #[test]
    fn test_remove_document() {
        let cache = RagChunkCache::new(10);
        cache.insert_chunk(Chunk::new(1, "doc-1", "Chunk 1"));
        cache.insert_chunk(Chunk::new(2, "doc-1", "Chunk 2"));

        cache.remove_document("doc-1");

        for id in 1..=2 {
            assert!(cache.get_chunk(id).is_none());
        }
    }

    #[test]
    fn test_stats() {
        let cache = RagChunkCache::new(10);
        cache.insert_chunk(Chunk::new(1, "doc-1", "Content"));

        let _hit = cache.get_chunk(1);
        let _miss = cache.get_chunk(2);

        let stats = cache.stats();
        assert_eq!((stats.chunk_count, stats.hits, stats.misses), (1, 1, 1));
    }
}