Skip to main content

reddb_server/storage/query/rag/
context.rs

1//! Retrieval Context
2//!
3//! Represents the contextual information retrieved for a query,
4//! including chunks from various sources with relevance scoring.
5
6use std::collections::HashMap;
7
8use super::EntityType;
9use crate::storage::schema::Value;
10use std::cmp::Ordering;
11
12/// Complete context retrieved for a query
13#[derive(Debug, Clone)]
14pub struct RetrievalContext {
15    /// Original query
16    pub query: String,
17    /// Retrieved context chunks
18    pub chunks: Vec<ContextChunk>,
19    /// Overall relevance score
20    pub overall_relevance: f32,
21    /// Sources used in retrieval
22    pub sources_used: Vec<ChunkSource>,
23    /// Total retrieval time in microseconds
24    pub retrieval_time_us: u64,
25    /// Explanation of retrieval strategy
26    pub explanation: Option<String>,
27}
28
29impl RetrievalContext {
30    /// Create a new retrieval context
31    pub fn new(query: impl Into<String>) -> Self {
32        Self {
33            query: query.into(),
34            chunks: Vec::new(),
35            overall_relevance: 0.0,
36            sources_used: Vec::new(),
37            retrieval_time_us: 0,
38            explanation: None,
39        }
40    }
41
42    /// Add a chunk to the context
43    pub fn add_chunk(&mut self, chunk: ContextChunk) {
44        if !self.sources_used.contains(&chunk.source) {
45            self.sources_used.push(chunk.source.clone());
46        }
47        self.chunks.push(chunk);
48    }
49
50    /// Sort chunks by relevance (descending)
51    pub fn sort_by_relevance(&mut self) {
52        self.chunks.sort_by(|a, b| {
53            b.relevance
54                .partial_cmp(&a.relevance)
55                .unwrap_or(Ordering::Equal)
56                .then_with(|| {
57                    let a_entity = a.entity_id.as_deref().unwrap_or("");
58                    let b_entity = b.entity_id.as_deref().unwrap_or("");
59                    a_entity.cmp(b_entity)
60                })
61                .then_with(|| a.source.name().cmp(b.source.name()))
62                .then_with(|| a.content.cmp(&b.content))
63        });
64    }
65
66    /// Calculate overall relevance from chunks
67    pub fn calculate_overall_relevance(&mut self) {
68        if self.chunks.is_empty() {
69            self.overall_relevance = 0.0;
70            return;
71        }
72
73        // Use weighted average, with top results weighted higher
74        let total_weight: f32 = (1..=self.chunks.len()).map(|i| 1.0 / i as f32).sum();
75
76        let weighted_sum: f32 = self
77            .chunks
78            .iter()
79            .enumerate()
80            .map(|(i, c)| c.relevance * (1.0 / (i + 1) as f32))
81            .sum();
82
83        self.overall_relevance = weighted_sum / total_weight;
84    }
85
86    /// Limit to top N chunks
87    pub fn limit(&mut self, n: usize) {
88        self.sort_by_relevance();
89        self.chunks.truncate(n);
90    }
91
92    /// Get chunks for a specific entity type
93    pub fn chunks_for_type(&self, entity_type: EntityType) -> Vec<&ContextChunk> {
94        self.chunks
95            .iter()
96            .filter(|c| c.entity_type == Some(entity_type))
97            .collect()
98    }
99
100    /// Get chunks from a specific source
101    pub fn chunks_from_source(&self, source: &ChunkSource) -> Vec<&ContextChunk> {
102        self.chunks.iter().filter(|c| &c.source == source).collect()
103    }
104
105    /// Check if context is empty
106    pub fn is_empty(&self) -> bool {
107        self.chunks.is_empty()
108    }
109
110    /// Number of chunks
111    pub fn len(&self) -> usize {
112        self.chunks.len()
113    }
114
115    /// Get the top chunk
116    pub fn top_chunk(&self) -> Option<&ContextChunk> {
117        self.chunks.first()
118    }
119
120    /// Convert to a text representation for LLM context
121    pub fn to_context_string(&self) -> String {
122        let mut s = String::new();
123
124        for (i, chunk) in self.chunks.iter().enumerate() {
125            s.push_str(&format!("[{}] ", i + 1));
126            s.push_str(&chunk.to_text());
127            s.push('\n');
128        }
129
130        s
131    }
132
133    /// Get entity IDs mentioned in the context
134    pub fn entity_ids(&self) -> Vec<&str> {
135        self.chunks
136            .iter()
137            .filter_map(|c| c.entity_id.as_deref())
138            .collect()
139    }
140
141    /// Merge another context into this one
142    pub fn merge(&mut self, other: RetrievalContext) {
143        for chunk in other.chunks {
144            self.add_chunk(chunk);
145        }
146        self.retrieval_time_us += other.retrieval_time_us;
147    }
148
149    /// Set explanation
150    pub fn with_explanation(mut self, explanation: impl Into<String>) -> Self {
151        self.explanation = Some(explanation.into());
152        self
153    }
154}
155
156/// A single chunk of context
157#[derive(Debug, Clone)]
158pub struct ContextChunk {
159    /// The content of this chunk
160    pub content: String,
161    /// Source of this chunk
162    pub source: ChunkSource,
163    /// Relevance score (0.0-1.0)
164    pub relevance: f32,
165    /// Entity type if applicable
166    pub entity_type: Option<EntityType>,
167    /// Entity ID if applicable
168    pub entity_id: Option<String>,
169    /// Additional metadata
170    pub metadata: HashMap<String, Value>,
171    /// Distance/similarity score from vector search (if applicable)
172    pub vector_distance: Option<f32>,
173    /// Graph depth from query entity (if applicable)
174    pub graph_depth: Option<u32>,
175}
176
177impl ContextChunk {
178    /// Create a new chunk
179    pub fn new(content: impl Into<String>, source: ChunkSource, relevance: f32) -> Self {
180        Self {
181            content: content.into(),
182            source,
183            relevance,
184            entity_type: None,
185            entity_id: None,
186            metadata: HashMap::new(),
187            vector_distance: None,
188            graph_depth: None,
189        }
190    }
191
192    /// Create from vector search result
193    pub fn from_vector(
194        content: impl Into<String>,
195        collection: impl Into<String>,
196        distance: f32,
197        id: u64,
198    ) -> Self {
199        let relevance = 1.0 / (1.0 + distance); // Convert distance to relevance
200        let mut chunk = Self::new(content, ChunkSource::Vector(collection.into()), relevance);
201        chunk.vector_distance = Some(distance);
202        chunk.entity_id = Some(id.to_string());
203        chunk
204    }
205
206    /// Create from graph traversal
207    pub fn from_graph(
208        content: impl Into<String>,
209        depth: u32,
210        entity_type: EntityType,
211        entity_id: impl Into<String>,
212    ) -> Self {
213        // Relevance decreases with depth
214        let relevance = 1.0 / (1.0 + depth as f32);
215        let mut chunk = Self::new(content, ChunkSource::Graph, relevance);
216        chunk.graph_depth = Some(depth);
217        chunk.entity_type = Some(entity_type);
218        chunk.entity_id = Some(entity_id.into());
219        chunk
220    }
221
222    /// Create from table query
223    pub fn from_table(
224        content: impl Into<String>,
225        table: impl Into<String>,
226        row_id: u64,
227        relevance: f32,
228    ) -> Self {
229        let mut chunk = Self::new(content, ChunkSource::Table(table.into()), relevance);
230        chunk.entity_id = Some(row_id.to_string());
231        chunk
232    }
233
234    /// Set entity type
235    pub fn with_entity_type(mut self, entity_type: EntityType) -> Self {
236        self.entity_type = Some(entity_type);
237        self
238    }
239
240    /// Set entity ID
241    pub fn with_entity_id(mut self, id: impl Into<String>) -> Self {
242        self.entity_id = Some(id.into());
243        self
244    }
245
246    /// Add metadata
247    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
248        self.metadata.insert(key.into(), value);
249        self
250    }
251
252    /// Convert to text representation
253    pub fn to_text(&self) -> String {
254        let mut parts = Vec::new();
255
256        // Source info
257        parts.push(format!("[{}]", self.source.name()));
258
259        // Entity info
260        if let Some(ref id) = self.entity_id {
261            if let Some(entity_type) = self.entity_type {
262                parts.push(format!("{:?}:{}", entity_type, id));
263            } else {
264                parts.push(format!("id:{}", id));
265            }
266        }
267
268        // Score info
269        parts.push(format!("relevance:{:.2}", self.relevance));
270
271        // Content
272        format!("{}: {}", parts.join(" "), self.content)
273    }
274}
275
/// Where a context chunk came from.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ChunkSource {
    /// Vector similarity search; carries the collection name
    Vector(String),
    /// Graph traversal
    Graph,
    /// Table/structured query; carries the table name
    Table(String),
    /// Cross-reference expansion
    CrossRef,
    /// Intelligence layer
    Intelligence,
    /// Cached/previously retrieved
    Cache,
}

impl ChunkSource {
    /// Short label for this kind of source; it does not include the
    /// collection or table name.
    pub fn name(&self) -> &str {
        let label = match self {
            Self::Cache => "cache",
            Self::Intelligence => "intel",
            Self::CrossRef => "cross-ref",
            Self::Table(_) => "table",
            Self::Graph => "graph",
            Self::Vector(_) => "vector",
        };
        label
    }

    /// The collection name (vector sources) or table name (table sources);
    /// `None` for every other kind of source.
    pub fn collection(&self) -> Option<&str> {
        match self {
            Self::Vector(named) | Self::Table(named) => Some(named.as_str()),
            Self::Graph | Self::CrossRef | Self::Intelligence | Self::Cache => None,
        }
    }
}
314
315// ============================================================================
316// Context Builder
317// ============================================================================
318
319/// Builder for creating retrieval contexts
320pub struct ContextBuilder {
321    context: RetrievalContext,
322}
323
324impl ContextBuilder {
325    /// Start building a new context
326    pub fn new(query: impl Into<String>) -> Self {
327        Self {
328            context: RetrievalContext::new(query),
329        }
330    }
331
332    /// Add a chunk
333    pub fn chunk(mut self, chunk: ContextChunk) -> Self {
334        self.context.add_chunk(chunk);
335        self
336    }
337
338    /// Add a vector result
339    pub fn vector_result(
340        mut self,
341        content: impl Into<String>,
342        collection: impl Into<String>,
343        distance: f32,
344        id: u64,
345    ) -> Self {
346        self.context
347            .add_chunk(ContextChunk::from_vector(content, collection, distance, id));
348        self
349    }
350
351    /// Add a graph result
352    pub fn graph_result(
353        mut self,
354        content: impl Into<String>,
355        depth: u32,
356        entity_type: EntityType,
357        entity_id: impl Into<String>,
358    ) -> Self {
359        self.context.add_chunk(ContextChunk::from_graph(
360            content,
361            depth,
362            entity_type,
363            entity_id,
364        ));
365        self
366    }
367
368    /// Add a table result
369    pub fn table_result(
370        mut self,
371        content: impl Into<String>,
372        table: impl Into<String>,
373        row_id: u64,
374        relevance: f32,
375    ) -> Self {
376        self.context
377            .add_chunk(ContextChunk::from_table(content, table, row_id, relevance));
378        self
379    }
380
381    /// Set retrieval time
382    pub fn time_us(mut self, time: u64) -> Self {
383        self.context.retrieval_time_us = time;
384        self
385    }
386
387    /// Set explanation
388    pub fn explanation(mut self, explanation: impl Into<String>) -> Self {
389        self.context.explanation = Some(explanation.into());
390        self
391    }
392
393    /// Build the context
394    pub fn build(mut self) -> RetrievalContext {
395        self.context.sort_by_relevance();
396        self.context.calculate_overall_relevance();
397        self.context
398    }
399}
400
401// ============================================================================
402// Tests
403// ============================================================================
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder output: three chunks, a positive aggregate score, and the
    /// nearest vector hit ranked first.
    #[test]
    fn test_context_builder() {
        let builder = ContextBuilder::new("test query");
        let builder = builder.vector_result(
            "CVE-2024-1234: SQL injection vulnerability",
            "vulns",
            0.1,
            1,
        );
        let builder = builder.vector_result("CVE-2024-5678: XSS vulnerability", "vulns", 0.3, 2);
        let builder = builder.graph_result("Host 192.168.1.1 runs nginx", 1, EntityType::Host, "h1");
        let context = builder.time_us(1000).build();

        assert_eq!(context.len(), 3);
        assert!(context.overall_relevance > 0.0);

        // Top chunk should be the closest vector result
        let top = context.top_chunk().unwrap();
        assert!(matches!(top.source, ChunkSource::Vector(_)));
    }

    /// The weighted average must land strictly between the lowest and the
    /// highest chunk score.
    #[test]
    fn test_relevance_calculation() {
        let mut context = RetrievalContext::new("test");
        for (content, relevance) in [("A", 1.0), ("B", 0.5), ("C", 0.25)] {
            context.add_chunk(ContextChunk::new(content, ChunkSource::Graph, relevance));
        }

        context.calculate_overall_relevance();

        assert!(context.overall_relevance > 0.25);
        assert!(context.overall_relevance < 1.0);
    }

    /// Filtering by entity type returns only the matching chunk.
    #[test]
    fn test_context_filtering() {
        let host_chunk = ContextChunk::new("Host info", ChunkSource::Graph, 0.9)
            .with_entity_type(EntityType::Host);
        let vuln_chunk = ContextChunk::new("Vuln info", ChunkSource::Graph, 0.8)
            .with_entity_type(EntityType::Vulnerability);

        let mut context = RetrievalContext::new("test");
        context.add_chunk(host_chunk);
        context.add_chunk(vuln_chunk);

        let hosts = context.chunks_for_type(EntityType::Host);
        assert_eq!(hosts.len(), 1);
        assert!(hosts[0].content.contains("Host"));
    }

    /// Merging moves the chunks across and sums the retrieval times.
    #[test]
    fn test_context_merge() {
        let mut target = RetrievalContext::new("test");
        target.add_chunk(ContextChunk::new("A", ChunkSource::Graph, 0.9));
        target.retrieval_time_us = 100;

        let mut incoming = RetrievalContext::new("test");
        let vector_source = ChunkSource::Vector("v".to_string());
        incoming.add_chunk(ContextChunk::new("B", vector_source, 0.8));
        incoming.retrieval_time_us = 200;

        target.merge(incoming);

        assert_eq!(target.len(), 2);
        assert_eq!(target.retrieval_time_us, 300);
    }

    /// The rendered text carries the index marker, the source label and the
    /// chunk content.
    #[test]
    fn test_to_context_string() {
        let context = ContextBuilder::new("test")
            .vector_result("Important finding", "vulns", 0.1, 1)
            .build();

        let text = context.to_context_string();
        for expected in ["[1]", "vector", "Important finding"] {
            assert!(text.contains(expected));
        }
    }
}