heliosdb_proxy/distribcache/ai/
rag.rs1use dashmap::DashMap;
7use std::collections::HashSet;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::Instant;
10
/// Identifier under which a [`Chunk`] is stored in the cache.
pub type ChunkId = u64;

/// Hash value produced by [`hash_embedding`] for an embedding vector.
pub type EmbeddingHash = u64;
16
17#[derive(Debug, Clone)]
19pub struct Chunk {
20 pub id: ChunkId,
22 pub document_id: String,
24 pub content: String,
26 pub embedding: Option<Vec<f32>>,
28 pub position: usize,
30 pub metadata: Option<serde_json::Value>,
32 pub created_at: Instant,
34}
35
36impl Chunk {
37 pub fn new(id: ChunkId, document_id: impl Into<String>, content: impl Into<String>) -> Self {
39 Self {
40 id,
41 document_id: document_id.into(),
42 content: content.into(),
43 embedding: None,
44 position: 0,
45 metadata: None,
46 created_at: Instant::now(),
47 }
48 }
49
50 pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
52 self.embedding = Some(embedding);
53 self
54 }
55
56 pub fn with_position(mut self, position: usize) -> Self {
58 self.position = position;
59 self
60 }
61
62 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
64 self.metadata = Some(metadata);
65 self
66 }
67
68 pub fn size(&self) -> usize {
70 self.content.len() +
71 self.document_id.len() +
72 self.embedding.as_ref().map(|e| e.len() * 4).unwrap_or(0) +
73 64
74 }
75}
76
77pub fn hash_embedding(embedding: &[f32]) -> EmbeddingHash {
79 use std::hash::{Hash, Hasher};
80 use std::collections::hash_map::DefaultHasher;
81
82 let mut hasher = DefaultHasher::new();
83
84 for val in embedding {
86 let quantized = (val * 1000.0) as i32;
87 quantized.hash(&mut hasher);
88 }
89
90 hasher.finish()
91}
92
93pub struct RagChunkCache {
95 chunks: DashMap<ChunkId, Chunk>,
97
98 embedding_to_chunks: DashMap<EmbeddingHash, Vec<ChunkId>>,
100
101 document_to_chunks: DashMap<String, HashSet<ChunkId>>,
103
104 max_size_mb: usize,
106
107 current_size: AtomicU64,
109
110 stats: RagCacheStats,
112}
113
/// Internal counters behind [`RagChunkCache::stats`]; all start at zero.
#[derive(Debug, Default)]
struct RagCacheStats {
    /// Lookups by chunk id that found a chunk.
    hits: AtomicU64,
    /// Lookups by chunk id that found nothing.
    misses: AtomicU64,
    /// Total lookups by embedding.
    embedding_lookups: AtomicU64,
    /// Lookups by embedding whose hash had an index entry.
    embedding_cache_hits: AtomicU64,
}
122
123impl RagChunkCache {
124 pub fn new(max_size_mb: usize) -> Self {
126 Self {
127 chunks: DashMap::new(),
128 embedding_to_chunks: DashMap::new(),
129 document_to_chunks: DashMap::new(),
130 max_size_mb,
131 current_size: AtomicU64::new(0),
132 stats: RagCacheStats::default(),
133 }
134 }
135
136 pub fn get_chunk(&self, id: ChunkId) -> Option<Chunk> {
138 if let Some(chunk) = self.chunks.get(&id) {
139 self.stats.hits.fetch_add(1, Ordering::Relaxed);
140 Some(chunk.clone())
141 } else {
142 self.stats.misses.fetch_add(1, Ordering::Relaxed);
143 None
144 }
145 }
146
147 pub fn get_chunks_by_embedding(&self, embedding: &[f32], k: usize) -> Vec<Chunk> {
149 self.stats.embedding_lookups.fetch_add(1, Ordering::Relaxed);
150
151 let hash = hash_embedding(embedding);
152
153 if let Some(chunk_ids) = self.embedding_to_chunks.get(&hash) {
154 self.stats.embedding_cache_hits.fetch_add(1, Ordering::Relaxed);
155
156 let chunks: Vec<_> = chunk_ids.iter()
157 .filter_map(|id| self.chunks.get(id).map(|c| c.clone()))
158 .take(k)
159 .collect();
160
161 return chunks;
162 }
163
164 Vec::new()
165 }
166
167 pub fn get_document_chunks(&self, document_id: &str) -> Vec<Chunk> {
169 if let Some(ids) = self.document_to_chunks.get(document_id) {
170 ids.iter()
171 .filter_map(|id| self.chunks.get(id).map(|c| c.clone()))
172 .collect()
173 } else {
174 Vec::new()
175 }
176 }
177
178 pub fn insert_chunk(&self, chunk: Chunk) {
180 let size = chunk.size() as u64;
181 let max_bytes = (self.max_size_mb * 1024 * 1024) as u64;
182
183 while self.current_size.load(Ordering::Relaxed) + size > max_bytes {
185 if !self.evict_one() {
186 break;
187 }
188 }
189
190 self.document_to_chunks
192 .entry(chunk.document_id.clone())
193 .or_default()
194 .insert(chunk.id);
195
196 if let Some(ref embedding) = chunk.embedding {
198 let hash = hash_embedding(embedding);
199 self.embedding_to_chunks
200 .entry(hash)
201 .or_default()
202 .push(chunk.id);
203 }
204
205 self.chunks.insert(chunk.id, chunk);
207 self.current_size.fetch_add(size, Ordering::Relaxed);
208 }
209
210 pub fn insert_chunks(&self, chunks: Vec<Chunk>) {
212 for chunk in chunks {
213 self.insert_chunk(chunk);
214 }
215 }
216
217 pub fn cache_embedding_result(&self, embedding: &[f32], chunk_ids: Vec<ChunkId>) {
219 let hash = hash_embedding(embedding);
220 self.embedding_to_chunks.insert(hash, chunk_ids);
221 }
222
223 pub fn remove_chunk(&self, id: ChunkId) {
225 if let Some((_, chunk)) = self.chunks.remove(&id) {
226 self.current_size.fetch_sub(chunk.size() as u64, Ordering::Relaxed);
227
228 if let Some(mut ids) = self.document_to_chunks.get_mut(&chunk.document_id) {
230 ids.remove(&id);
231 }
232 }
233 }
234
235 pub fn remove_document(&self, document_id: &str) {
237 if let Some((_, ids)) = self.document_to_chunks.remove(document_id) {
238 for id in ids {
239 self.remove_chunk(id);
240 }
241 }
242 }
243
244 fn evict_one(&self) -> bool {
246 let mut oldest_id = None;
247 let mut oldest_time = Instant::now();
248
249 for entry in self.chunks.iter() {
250 if entry.created_at < oldest_time {
251 oldest_time = entry.created_at;
252 oldest_id = Some(*entry.key());
253 }
254 }
255
256 if let Some(id) = oldest_id {
257 self.remove_chunk(id);
258 return true;
259 }
260
261 false
262 }
263
264 pub fn stats(&self) -> RagCacheStatsSnapshot {
266 RagCacheStatsSnapshot {
267 chunk_count: self.chunks.len(),
268 document_count: self.document_to_chunks.len(),
269 size_bytes: self.current_size.load(Ordering::Relaxed),
270 max_size_bytes: (self.max_size_mb * 1024 * 1024) as u64,
271 hits: self.stats.hits.load(Ordering::Relaxed),
272 misses: self.stats.misses.load(Ordering::Relaxed),
273 embedding_lookups: self.stats.embedding_lookups.load(Ordering::Relaxed),
274 embedding_cache_hit_rate: {
275 let lookups = self.stats.embedding_lookups.load(Ordering::Relaxed);
276 let hits = self.stats.embedding_cache_hits.load(Ordering::Relaxed);
277 if lookups > 0 { hits as f64 / lookups as f64 } else { 0.0 }
278 },
279 }
280 }
281
282 pub fn clear(&self) {
284 self.chunks.clear();
285 self.embedding_to_chunks.clear();
286 self.document_to_chunks.clear();
287 self.current_size.store(0, Ordering::Relaxed);
288 }
289}
290
/// Point-in-time copy of the cache statistics, as returned by
/// `RagChunkCache::stats`.
#[derive(Debug, Clone)]
pub struct RagCacheStatsSnapshot {
    /// Number of chunks currently cached.
    pub chunk_count: usize,
    /// Number of entries in the document index.
    pub document_count: usize,
    /// Estimated bytes currently used by cached chunks.
    pub size_bytes: u64,
    /// Size budget in bytes.
    pub max_size_bytes: u64,
    /// Lookups by chunk id that found a chunk.
    pub hits: u64,
    /// Lookups by chunk id that found nothing.
    pub misses: u64,
    /// Total lookups by embedding.
    pub embedding_lookups: u64,
    /// Fraction of embedding lookups whose hash had an index entry
    /// (`0.0` when there have been no lookups).
    pub embedding_cache_hit_rate: f64,
}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps the id, document id and position it was given.
    #[test]
    fn test_chunk_creation() {
        let built = Chunk::new(1, "doc-1", "This is a test chunk").with_position(0);

        assert_eq!(built.id, 1);
        assert_eq!(built.document_id, "doc-1");
        assert_eq!(built.position, 0);
    }

    /// An inserted chunk can be read back by id.
    #[test]
    fn test_insert_and_get() {
        let cache = RagChunkCache::new(10);
        cache.insert_chunk(Chunk::new(1, "doc-1", "Test content"));

        let fetched = cache.get_chunk(1);
        assert!(fetched.is_some());
        assert_eq!(fetched.map(|chunk| chunk.content).as_deref(), Some("Test content"));
    }

    /// Chunks are grouped by the document they belong to.
    #[test]
    fn test_document_chunks() {
        let cache = RagChunkCache::new(10);

        let fixtures = [
            (1, "doc-1", "Chunk 1", 0),
            (2, "doc-1", "Chunk 2", 1),
            (3, "doc-2", "Chunk 3", 0),
        ];
        for &(id, document, text, position) in fixtures.iter() {
            cache.insert_chunk(Chunk::new(id, document, text).with_position(position));
        }

        assert_eq!(cache.get_document_chunks("doc-1").len(), 2);
        assert_eq!(cache.get_document_chunks("doc-2").len(), 1);
    }

    /// A cached embedding result resolves to the chunks it lists.
    #[test]
    fn test_embedding_lookup() {
        let cache = RagChunkCache::new(10);
        let embedding = vec![0.1, 0.2, 0.3];

        cache.insert_chunk(
            Chunk::new(1, "doc-1", "Embedded content").with_embedding(embedding.clone()),
        );
        cache.cache_embedding_result(&embedding, vec![1]);

        let matches = cache.get_chunks_by_embedding(&embedding, 10);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].id, 1);
    }

    /// Removing a document removes all of its chunks.
    #[test]
    fn test_remove_document() {
        let cache = RagChunkCache::new(10);
        cache.insert_chunks(vec![
            Chunk::new(1, "doc-1", "Chunk 1"),
            Chunk::new(2, "doc-1", "Chunk 2"),
        ]);

        cache.remove_document("doc-1");

        for id in 1..=2 {
            assert!(cache.get_chunk(id).is_none());
        }
    }

    /// One successful and one failed lookup are counted as a hit and a miss.
    #[test]
    fn test_stats() {
        let cache = RagChunkCache::new(10);
        cache.insert_chunk(Chunk::new(1, "doc-1", "Content"));

        let _ = cache.get_chunk(1); // hit
        let _ = cache.get_chunk(2); // miss

        let snapshot = cache.stats();
        assert_eq!(snapshot.chunk_count, 1);
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
    }
}