// heliosdb_proxy/distribcache/ai/rag.rs

use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;

use dashmap::DashMap;

/// Identifier under which a chunk is cached.
pub type ChunkId = u64;

/// Hash of a quantized embedding vector, as produced by `hash_embedding`.
pub type EmbeddingHash = u64;

17#[derive(Debug, Clone)]
19pub struct Chunk {
20 pub id: ChunkId,
22 pub document_id: String,
24 pub content: String,
26 pub embedding: Option<Vec<f32>>,
28 pub position: usize,
30 pub metadata: Option<serde_json::Value>,
32 pub created_at: Instant,
34}
35
36impl Chunk {
37 pub fn new(id: ChunkId, document_id: impl Into<String>, content: impl Into<String>) -> Self {
39 Self {
40 id,
41 document_id: document_id.into(),
42 content: content.into(),
43 embedding: None,
44 position: 0,
45 metadata: None,
46 created_at: Instant::now(),
47 }
48 }
49
50 pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
52 self.embedding = Some(embedding);
53 self
54 }
55
56 pub fn with_position(mut self, position: usize) -> Self {
58 self.position = position;
59 self
60 }
61
62 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
64 self.metadata = Some(metadata);
65 self
66 }
67
68 pub fn size(&self) -> usize {
70 self.content.len()
71 + self.document_id.len()
72 + self.embedding.as_ref().map(|e| e.len() * 4).unwrap_or(0)
73 + 64
74 }
75}
76
77pub fn hash_embedding(embedding: &[f32]) -> EmbeddingHash {
79 use std::collections::hash_map::DefaultHasher;
80 use std::hash::{Hash, Hasher};
81
82 let mut hasher = DefaultHasher::new();
83
84 for val in embedding {
86 let quantized = (val * 1000.0) as i32;
87 quantized.hash(&mut hasher);
88 }
89
90 hasher.finish()
91}
92
93pub struct RagChunkCache {
95 chunks: DashMap<ChunkId, Chunk>,
97
98 embedding_to_chunks: DashMap<EmbeddingHash, Vec<ChunkId>>,
100
101 document_to_chunks: DashMap<String, HashSet<ChunkId>>,
103
104 max_size_mb: usize,
106
107 current_size: AtomicU64,
109
110 stats: RagCacheStats,
112}
113
/// Internal counters for `RagChunkCache`, updated with relaxed atomics.
#[derive(Default, Debug)]
struct RagCacheStats {
    /// `get_chunk` calls that found the chunk.
    hits: AtomicU64,
    /// `get_chunk` calls that found nothing.
    misses: AtomicU64,
    /// Calls to `get_chunks_by_embedding`.
    embedding_lookups: AtomicU64,
    /// Embedding lookups whose hash had a recorded id list.
    embedding_cache_hits: AtomicU64,
}

123impl RagChunkCache {
124 pub fn new(max_size_mb: usize) -> Self {
126 Self {
127 chunks: DashMap::new(),
128 embedding_to_chunks: DashMap::new(),
129 document_to_chunks: DashMap::new(),
130 max_size_mb,
131 current_size: AtomicU64::new(0),
132 stats: RagCacheStats::default(),
133 }
134 }
135
136 pub fn get_chunk(&self, id: ChunkId) -> Option<Chunk> {
138 if let Some(chunk) = self.chunks.get(&id) {
139 self.stats.hits.fetch_add(1, Ordering::Relaxed);
140 Some(chunk.clone())
141 } else {
142 self.stats.misses.fetch_add(1, Ordering::Relaxed);
143 None
144 }
145 }
146
147 pub fn get_chunks_by_embedding(&self, embedding: &[f32], k: usize) -> Vec<Chunk> {
149 self.stats.embedding_lookups.fetch_add(1, Ordering::Relaxed);
150
151 let hash = hash_embedding(embedding);
152
153 if let Some(chunk_ids) = self.embedding_to_chunks.get(&hash) {
154 self.stats
155 .embedding_cache_hits
156 .fetch_add(1, Ordering::Relaxed);
157
158 let chunks: Vec<_> = chunk_ids
159 .iter()
160 .filter_map(|id| self.chunks.get(id).map(|c| c.clone()))
161 .take(k)
162 .collect();
163
164 return chunks;
165 }
166
167 Vec::new()
168 }
169
170 pub fn get_document_chunks(&self, document_id: &str) -> Vec<Chunk> {
172 if let Some(ids) = self.document_to_chunks.get(document_id) {
173 ids.iter()
174 .filter_map(|id| self.chunks.get(id).map(|c| c.clone()))
175 .collect()
176 } else {
177 Vec::new()
178 }
179 }
180
181 pub fn insert_chunk(&self, chunk: Chunk) {
183 let size = chunk.size() as u64;
184 let max_bytes = (self.max_size_mb * 1024 * 1024) as u64;
185
186 while self.current_size.load(Ordering::Relaxed) + size > max_bytes {
188 if !self.evict_one() {
189 break;
190 }
191 }
192
193 self.document_to_chunks
195 .entry(chunk.document_id.clone())
196 .or_default()
197 .insert(chunk.id);
198
199 if let Some(ref embedding) = chunk.embedding {
201 let hash = hash_embedding(embedding);
202 self.embedding_to_chunks
203 .entry(hash)
204 .or_default()
205 .push(chunk.id);
206 }
207
208 self.chunks.insert(chunk.id, chunk);
210 self.current_size.fetch_add(size, Ordering::Relaxed);
211 }
212
213 pub fn insert_chunks(&self, chunks: Vec<Chunk>) {
215 for chunk in chunks {
216 self.insert_chunk(chunk);
217 }
218 }
219
220 pub fn cache_embedding_result(&self, embedding: &[f32], chunk_ids: Vec<ChunkId>) {
222 let hash = hash_embedding(embedding);
223 self.embedding_to_chunks.insert(hash, chunk_ids);
224 }
225
226 pub fn remove_chunk(&self, id: ChunkId) {
228 if let Some((_, chunk)) = self.chunks.remove(&id) {
229 self.current_size
230 .fetch_sub(chunk.size() as u64, Ordering::Relaxed);
231
232 if let Some(mut ids) = self.document_to_chunks.get_mut(&chunk.document_id) {
234 ids.remove(&id);
235 }
236 }
237 }
238
239 pub fn remove_document(&self, document_id: &str) {
241 if let Some((_, ids)) = self.document_to_chunks.remove(document_id) {
242 for id in ids {
243 self.remove_chunk(id);
244 }
245 }
246 }
247
248 fn evict_one(&self) -> bool {
250 let mut oldest_id = None;
251 let mut oldest_time = Instant::now();
252
253 for entry in self.chunks.iter() {
254 if entry.created_at < oldest_time {
255 oldest_time = entry.created_at;
256 oldest_id = Some(*entry.key());
257 }
258 }
259
260 if let Some(id) = oldest_id {
261 self.remove_chunk(id);
262 return true;
263 }
264
265 false
266 }
267
268 pub fn stats(&self) -> RagCacheStatsSnapshot {
270 RagCacheStatsSnapshot {
271 chunk_count: self.chunks.len(),
272 document_count: self.document_to_chunks.len(),
273 size_bytes: self.current_size.load(Ordering::Relaxed),
274 max_size_bytes: (self.max_size_mb * 1024 * 1024) as u64,
275 hits: self.stats.hits.load(Ordering::Relaxed),
276 misses: self.stats.misses.load(Ordering::Relaxed),
277 embedding_lookups: self.stats.embedding_lookups.load(Ordering::Relaxed),
278 embedding_cache_hit_rate: {
279 let lookups = self.stats.embedding_lookups.load(Ordering::Relaxed);
280 let hits = self.stats.embedding_cache_hits.load(Ordering::Relaxed);
281 if lookups > 0 {
282 hits as f64 / lookups as f64
283 } else {
284 0.0
285 }
286 },
287 }
288 }
289
290 pub fn clear(&self) {
292 self.chunks.clear();
293 self.embedding_to_chunks.clear();
294 self.document_to_chunks.clear();
295 self.current_size.store(0, Ordering::Relaxed);
296 }
297}
298
/// Point-in-time copy of a `RagChunkCache`'s occupancy and counters, as
/// returned by `RagChunkCache::stats`.
#[derive(Clone, Debug)]
pub struct RagCacheStatsSnapshot {
    /// Number of chunks currently stored.
    pub chunk_count: usize,
    /// Number of document ids present in the document index.
    pub document_count: usize,
    /// Sum of the stored chunks' size estimates, in bytes.
    pub size_bytes: u64,
    /// Configured byte budget.
    pub max_size_bytes: u64,
    /// `get_chunk` calls that found the chunk.
    pub hits: u64,
    /// `get_chunk` calls that found nothing.
    pub misses: u64,
    /// Calls to `get_chunks_by_embedding`.
    pub embedding_lookups: u64,
    /// Fraction of embedding lookups that found a recorded id list; 0.0 when
    /// no lookups have been made.
    pub embedding_cache_hit_rate: f64,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chunk_creation() {
        let Chunk {
            id,
            document_id,
            position,
            ..
        } = Chunk::new(1, "doc-1", "This is a test chunk").with_position(0);

        assert_eq!(id, 1);
        assert_eq!(document_id, "doc-1");
        assert_eq!(position, 0);
    }

    #[test]
    fn test_insert_and_get() {
        let cache = RagChunkCache::new(10);
        cache.insert_chunk(Chunk::new(1, "doc-1", "Test content"));

        let fetched = cache
            .get_chunk(1)
            .expect("chunk 1 should be retrievable right after insertion");
        assert_eq!(fetched.content, "Test content");
    }

    #[test]
    fn test_document_chunks() {
        let cache = RagChunkCache::new(10);
        let fixtures = [
            (1, "doc-1", "Chunk 1", 0),
            (2, "doc-1", "Chunk 2", 1),
            (3, "doc-2", "Chunk 3", 0),
        ];
        for (id, doc, text, pos) in fixtures {
            cache.insert_chunk(Chunk::new(id, doc, text).with_position(pos));
        }

        assert_eq!(cache.get_document_chunks("doc-1").len(), 2);
        assert_eq!(cache.get_document_chunks("doc-2").len(), 1);
    }

    #[test]
    fn test_embedding_lookup() {
        let cache = RagChunkCache::new(10);
        let vector = vec![0.1, 0.2, 0.3];

        let embedded = Chunk::new(1, "doc-1", "Embedded content").with_embedding(vector.clone());
        cache.insert_chunk(embedded);
        cache.cache_embedding_result(&vector, vec![1]);

        let found = cache.get_chunks_by_embedding(&vector, 10);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].id, 1);
    }

    #[test]
    fn test_remove_document() {
        let cache = RagChunkCache::new(10);
        for id in 1..=2 {
            cache.insert_chunk(Chunk::new(id, "doc-1", format!("Chunk {}", id)));
        }

        cache.remove_document("doc-1");

        for id in [1, 2] {
            assert!(cache.get_chunk(id).is_none(), "chunk {} should be gone", id);
        }
    }

    #[test]
    fn test_stats() {
        let cache = RagChunkCache::new(10);
        cache.insert_chunk(Chunk::new(1, "doc-1", "Content"));

        let _hit = cache.get_chunk(1);
        let _miss = cache.get_chunk(2);

        let snapshot = cache.stats();
        assert_eq!(snapshot.chunk_count, 1);
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
    }
}