reddb_server/storage/query/rag/
context.rs1use std::collections::HashMap;
7
8use super::EntityType;
9use crate::storage::schema::Value;
10use std::cmp::Ordering;
11
12#[derive(Debug, Clone)]
14pub struct RetrievalContext {
15 pub query: String,
17 pub chunks: Vec<ContextChunk>,
19 pub overall_relevance: f32,
21 pub sources_used: Vec<ChunkSource>,
23 pub retrieval_time_us: u64,
25 pub explanation: Option<String>,
27}
28
29impl RetrievalContext {
30 pub fn new(query: impl Into<String>) -> Self {
32 Self {
33 query: query.into(),
34 chunks: Vec::new(),
35 overall_relevance: 0.0,
36 sources_used: Vec::new(),
37 retrieval_time_us: 0,
38 explanation: None,
39 }
40 }
41
42 pub fn add_chunk(&mut self, chunk: ContextChunk) {
44 if !self.sources_used.contains(&chunk.source) {
45 self.sources_used.push(chunk.source.clone());
46 }
47 self.chunks.push(chunk);
48 }
49
50 pub fn sort_by_relevance(&mut self) {
52 self.chunks.sort_by(|a, b| {
53 b.relevance
54 .partial_cmp(&a.relevance)
55 .unwrap_or(Ordering::Equal)
56 .then_with(|| {
57 let a_entity = a.entity_id.as_deref().unwrap_or("");
58 let b_entity = b.entity_id.as_deref().unwrap_or("");
59 a_entity.cmp(b_entity)
60 })
61 .then_with(|| a.source.name().cmp(b.source.name()))
62 .then_with(|| a.content.cmp(&b.content))
63 });
64 }
65
66 pub fn calculate_overall_relevance(&mut self) {
68 if self.chunks.is_empty() {
69 self.overall_relevance = 0.0;
70 return;
71 }
72
73 let total_weight: f32 = (1..=self.chunks.len()).map(|i| 1.0 / i as f32).sum();
75
76 let weighted_sum: f32 = self
77 .chunks
78 .iter()
79 .enumerate()
80 .map(|(i, c)| c.relevance * (1.0 / (i + 1) as f32))
81 .sum();
82
83 self.overall_relevance = weighted_sum / total_weight;
84 }
85
86 pub fn limit(&mut self, n: usize) {
88 self.sort_by_relevance();
89 self.chunks.truncate(n);
90 }
91
92 pub fn chunks_for_type(&self, entity_type: EntityType) -> Vec<&ContextChunk> {
94 self.chunks
95 .iter()
96 .filter(|c| c.entity_type == Some(entity_type))
97 .collect()
98 }
99
100 pub fn chunks_from_source(&self, source: &ChunkSource) -> Vec<&ContextChunk> {
102 self.chunks.iter().filter(|c| &c.source == source).collect()
103 }
104
105 pub fn is_empty(&self) -> bool {
107 self.chunks.is_empty()
108 }
109
110 pub fn len(&self) -> usize {
112 self.chunks.len()
113 }
114
115 pub fn top_chunk(&self) -> Option<&ContextChunk> {
117 self.chunks.first()
118 }
119
120 pub fn to_context_string(&self) -> String {
122 let mut s = String::new();
123
124 for (i, chunk) in self.chunks.iter().enumerate() {
125 s.push_str(&format!("[{}] ", i + 1));
126 s.push_str(&chunk.to_text());
127 s.push('\n');
128 }
129
130 s
131 }
132
133 pub fn entity_ids(&self) -> Vec<&str> {
135 self.chunks
136 .iter()
137 .filter_map(|c| c.entity_id.as_deref())
138 .collect()
139 }
140
141 pub fn merge(&mut self, other: RetrievalContext) {
143 for chunk in other.chunks {
144 self.add_chunk(chunk);
145 }
146 self.retrieval_time_us += other.retrieval_time_us;
147 }
148
149 pub fn with_explanation(mut self, explanation: impl Into<String>) -> Self {
151 self.explanation = Some(explanation.into());
152 self
153 }
154}
155
156#[derive(Debug, Clone)]
158pub struct ContextChunk {
159 pub content: String,
161 pub source: ChunkSource,
163 pub relevance: f32,
165 pub entity_type: Option<EntityType>,
167 pub entity_id: Option<String>,
169 pub metadata: HashMap<String, Value>,
171 pub vector_distance: Option<f32>,
173 pub graph_depth: Option<u32>,
175}
176
177impl ContextChunk {
178 pub fn new(content: impl Into<String>, source: ChunkSource, relevance: f32) -> Self {
180 Self {
181 content: content.into(),
182 source,
183 relevance,
184 entity_type: None,
185 entity_id: None,
186 metadata: HashMap::new(),
187 vector_distance: None,
188 graph_depth: None,
189 }
190 }
191
192 pub fn from_vector(
194 content: impl Into<String>,
195 collection: impl Into<String>,
196 distance: f32,
197 id: u64,
198 ) -> Self {
199 let relevance = 1.0 / (1.0 + distance); let mut chunk = Self::new(content, ChunkSource::Vector(collection.into()), relevance);
201 chunk.vector_distance = Some(distance);
202 chunk.entity_id = Some(id.to_string());
203 chunk
204 }
205
206 pub fn from_graph(
208 content: impl Into<String>,
209 depth: u32,
210 entity_type: EntityType,
211 entity_id: impl Into<String>,
212 ) -> Self {
213 let relevance = 1.0 / (1.0 + depth as f32);
215 let mut chunk = Self::new(content, ChunkSource::Graph, relevance);
216 chunk.graph_depth = Some(depth);
217 chunk.entity_type = Some(entity_type);
218 chunk.entity_id = Some(entity_id.into());
219 chunk
220 }
221
222 pub fn from_table(
224 content: impl Into<String>,
225 table: impl Into<String>,
226 row_id: u64,
227 relevance: f32,
228 ) -> Self {
229 let mut chunk = Self::new(content, ChunkSource::Table(table.into()), relevance);
230 chunk.entity_id = Some(row_id.to_string());
231 chunk
232 }
233
234 pub fn with_entity_type(mut self, entity_type: EntityType) -> Self {
236 self.entity_type = Some(entity_type);
237 self
238 }
239
240 pub fn with_entity_id(mut self, id: impl Into<String>) -> Self {
242 self.entity_id = Some(id.into());
243 self
244 }
245
246 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
248 self.metadata.insert(key.into(), value);
249 self
250 }
251
252 pub fn to_text(&self) -> String {
254 let mut parts = Vec::new();
255
256 parts.push(format!("[{}]", self.source.name()));
258
259 if let Some(ref id) = self.entity_id {
261 if let Some(entity_type) = self.entity_type {
262 parts.push(format!("{:?}:{}", entity_type, id));
263 } else {
264 parts.push(format!("id:{}", id));
265 }
266 }
267
268 parts.push(format!("relevance:{:.2}", self.relevance));
270
271 format!("{}: {}", parts.join(" "), self.content)
273 }
274}
275
/// Where a `ContextChunk` was retrieved from.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ChunkSource {
    /// Vector search; the payload is the collection name.
    Vector(String),
    /// Graph lookup.
    Graph,
    /// Table lookup; the payload is the table name.
    Table(String),
    /// Labelled `"cross-ref"` by `name()`.
    CrossRef,
    /// Labelled `"intel"` by `name()`.
    Intelligence,
    /// Labelled `"cache"` by `name()`.
    Cache,
}
292
293impl ChunkSource {
294 pub fn name(&self) -> &str {
296 match self {
297 Self::Vector(_) => "vector",
298 Self::Graph => "graph",
299 Self::Table(_) => "table",
300 Self::CrossRef => "cross-ref",
301 Self::Intelligence => "intel",
302 Self::Cache => "cache",
303 }
304 }
305
306 pub fn collection(&self) -> Option<&str> {
308 match self {
309 Self::Vector(c) | Self::Table(c) => Some(c),
310 _ => None,
311 }
312 }
313}
314
315pub struct ContextBuilder {
321 context: RetrievalContext,
322}
323
324impl ContextBuilder {
325 pub fn new(query: impl Into<String>) -> Self {
327 Self {
328 context: RetrievalContext::new(query),
329 }
330 }
331
332 pub fn chunk(mut self, chunk: ContextChunk) -> Self {
334 self.context.add_chunk(chunk);
335 self
336 }
337
338 pub fn vector_result(
340 mut self,
341 content: impl Into<String>,
342 collection: impl Into<String>,
343 distance: f32,
344 id: u64,
345 ) -> Self {
346 self.context
347 .add_chunk(ContextChunk::from_vector(content, collection, distance, id));
348 self
349 }
350
351 pub fn graph_result(
353 mut self,
354 content: impl Into<String>,
355 depth: u32,
356 entity_type: EntityType,
357 entity_id: impl Into<String>,
358 ) -> Self {
359 self.context.add_chunk(ContextChunk::from_graph(
360 content,
361 depth,
362 entity_type,
363 entity_id,
364 ));
365 self
366 }
367
368 pub fn table_result(
370 mut self,
371 content: impl Into<String>,
372 table: impl Into<String>,
373 row_id: u64,
374 relevance: f32,
375 ) -> Self {
376 self.context
377 .add_chunk(ContextChunk::from_table(content, table, row_id, relevance));
378 self
379 }
380
381 pub fn time_us(mut self, time: u64) -> Self {
383 self.context.retrieval_time_us = time;
384 self
385 }
386
387 pub fn explanation(mut self, explanation: impl Into<String>) -> Self {
389 self.context.explanation = Some(explanation.into());
390 self
391 }
392
393 pub fn build(mut self) -> RetrievalContext {
395 self.context.sort_by_relevance();
396 self.context.calculate_overall_relevance();
397 self.context
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder collects every result, scores the context, and puts the
    /// closest vector hit on top.
    #[test]
    fn test_context_builder() {
        let builder = ContextBuilder::new("test query")
            .vector_result(
                "CVE-2024-1234: SQL injection vulnerability",
                "vulns",
                0.1,
                1,
            )
            .vector_result("CVE-2024-5678: XSS vulnerability", "vulns", 0.3, 2)
            .graph_result("Host 192.168.1.1 runs nginx", 1, EntityType::Host, "h1")
            .time_us(1000);
        let context = builder.build();

        assert_eq!(context.len(), 3);
        assert!(context.overall_relevance > 0.0);

        // Both vector hits (distances 0.1 and 0.3) outrank the depth-1 graph hit.
        let top_source = context.top_chunk().map(|chunk| &chunk.source);
        assert!(matches!(top_source, Some(ChunkSource::Vector(_))));
    }

    /// The rank-weighted mean lies strictly between the lowest and the
    /// highest chunk relevance.
    #[test]
    fn test_relevance_calculation() {
        let mut context = RetrievalContext::new("test");
        for (content, relevance) in [("A", 1.0), ("B", 0.5), ("C", 0.25)] {
            context.add_chunk(ContextChunk::new(content, ChunkSource::Graph, relevance));
        }

        context.calculate_overall_relevance();

        let overall = context.overall_relevance;
        assert!(overall > 0.25);
        assert!(overall < 1.0);
    }

    /// Filtering by entity type returns only the matching chunks.
    #[test]
    fn test_context_filtering() {
        let host_chunk = ContextChunk::new("Host info", ChunkSource::Graph, 0.9)
            .with_entity_type(EntityType::Host);
        let vuln_chunk = ContextChunk::new("Vuln info", ChunkSource::Graph, 0.8)
            .with_entity_type(EntityType::Vulnerability);

        let mut context = RetrievalContext::new("test");
        context.add_chunk(host_chunk);
        context.add_chunk(vuln_chunk);

        let hosts = context.chunks_for_type(EntityType::Host);
        assert_eq!(hosts.len(), 1);
        assert!(hosts[0].content.contains("Host"));
    }

    /// Merging moves the chunks across and sums the retrieval times.
    #[test]
    fn test_context_merge() {
        let mut target = RetrievalContext::new("test");
        target.add_chunk(ContextChunk::new("A", ChunkSource::Graph, 0.9));
        target.retrieval_time_us = 100;

        let mut incoming = RetrievalContext::new("test");
        let vector_source = ChunkSource::Vector("v".to_string());
        incoming.add_chunk(ContextChunk::new("B", vector_source, 0.8));
        incoming.retrieval_time_us = 200;

        target.merge(incoming);

        assert_eq!(target.len(), 2);
        assert_eq!(target.retrieval_time_us, 300);
    }

    /// The rendered context carries the index, the source label and the content.
    #[test]
    fn test_to_context_string() {
        let text = ContextBuilder::new("test")
            .vector_result("Important finding", "vulns", 0.1, 1)
            .build()
            .to_context_string();

        for needle in ["[1]", "vector", "Important finding"] {
            assert!(text.contains(needle));
        }
    }
}