Skip to main content

reddb_server/storage/query/rag/
retriever.rs

1//! Multi-Source Retriever
2//!
3//! Implements retrieval strategies that combine vector search,
4//! graph traversal, and table queries for comprehensive context.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use crate::storage::engine::distance::DistanceMetric;
10use crate::storage::engine::graph_store::{GraphStore, StoredNode};
11use crate::storage::engine::graph_table_index::GraphTableIndex;
12use crate::storage::engine::unified_index::UnifiedIndex;
13use crate::storage::engine::vector_store::VectorStore;
14use crate::storage::query::unified::ExecutionError;
15
16use super::context::{ChunkSource, ContextChunk, RetrievalContext};
17use super::{EntityType, QueryAnalysis, RagConfig, SimilarEntity};
18
/// Strategy that decides which sources `MultiSourceRetriever` queries, and in
/// which order.
///
/// Derives `Hash` (in addition to `Eq`) so strategies can be used as keys in
/// hash-based collections, e.g. for per-strategy statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RetrievalStrategy {
    /// Vector search first, then expand the hits through the graph.
    VectorFirst,
    /// Graph traversal first, then expand the hits with vector similarity.
    GraphFirst,
    /// Run both vector search and graph traversal and merge the results.
    Hybrid,
    /// Vector search only, without graph expansion.
    VectorOnly,
    /// Graph traversal only, without vector expansion.
    GraphOnly,
    /// Structured table queries only.
    TableOnly,
}
35
36/// Multi-source retriever that combines vector, graph, and table queries
37pub struct MultiSourceRetriever {
38    /// Graph store
39    graph: Arc<GraphStore>,
40    /// Graph-table index
41    index: Arc<GraphTableIndex>,
42    /// Vector store
43    vector_store: Arc<VectorStore>,
44    /// Cross-reference index
45    unified_index: Arc<UnifiedIndex>,
46}
47
48impl MultiSourceRetriever {
49    /// Create a new multi-source retriever
50    pub fn new(
51        graph: Arc<GraphStore>,
52        index: Arc<GraphTableIndex>,
53        vector_store: Arc<VectorStore>,
54        unified_index: Arc<UnifiedIndex>,
55    ) -> Self {
56        Self {
57            graph,
58            index,
59            vector_store,
60            unified_index,
61        }
62    }
63
64    /// Retrieve context based on query analysis
65    pub fn retrieve(
66        &self,
67        query: &str,
68        analysis: &QueryAnalysis,
69        config: &RagConfig,
70    ) -> Result<RetrievalContext, ExecutionError> {
71        let start = std::time::Instant::now();
72        let mut context = RetrievalContext::new(query);
73
74        // Execute based on primary strategy
75        match analysis.primary_strategy {
76            RetrievalStrategy::VectorFirst | RetrievalStrategy::VectorOnly => {
77                self.retrieve_vector(query, analysis, config, &mut context)?;
78
79                // Add graph context if not vector-only
80                if analysis.primary_strategy != RetrievalStrategy::VectorOnly {
81                    self.expand_with_graph(&mut context, config)?;
82                }
83            }
84            RetrievalStrategy::GraphFirst | RetrievalStrategy::GraphOnly => {
85                self.retrieve_graph(query, analysis, config, &mut context)?;
86
87                // Add vector context if not graph-only
88                if analysis.primary_strategy != RetrievalStrategy::GraphOnly {
89                    self.expand_with_vectors(&mut context, config)?;
90                }
91            }
92            RetrievalStrategy::Hybrid => {
93                // Execute both in parallel conceptually, then merge
94                self.retrieve_vector(query, analysis, config, &mut context)?;
95                self.retrieve_graph(query, analysis, config, &mut context)?;
96            }
97            RetrievalStrategy::TableOnly => {
98                self.retrieve_table(query, analysis, config, &mut context)?;
99            }
100        }
101
102        // Cross-reference expansion if enabled
103        if config.expand_cross_refs {
104            self.expand_cross_refs(&mut context, config)?;
105        }
106
107        // Finalize
108        context.sort_by_relevance();
109        context.limit(config.max_total_chunks);
110        context.calculate_overall_relevance();
111        context.retrieval_time_us = start.elapsed().as_micros() as u64;
112
113        // Add explanation
114        let explanation = format!(
115            "Retrieved {} chunks using {} strategy. Sources: {:?}",
116            context.len(),
117            match analysis.primary_strategy {
118                RetrievalStrategy::VectorFirst => "vector-first",
119                RetrievalStrategy::GraphFirst => "graph-first",
120                RetrievalStrategy::Hybrid => "hybrid",
121                RetrievalStrategy::VectorOnly => "vector-only",
122                RetrievalStrategy::GraphOnly => "graph-only",
123                RetrievalStrategy::TableOnly => "table-only",
124            },
125            context.sources_used
126        );
127        context.explanation = Some(explanation);
128
129        Ok(context)
130    }
131
132    /// Retrieve context using vector search
133    fn retrieve_vector(
134        &self,
135        query: &str,
136        analysis: &QueryAnalysis,
137        config: &RagConfig,
138        context: &mut RetrievalContext,
139    ) -> Result<(), ExecutionError> {
140        // Determine which collections to search based on entity types
141        let collections: Vec<&str> = if analysis.entity_types.is_empty() {
142            // Search all relevant collections
143            vec!["vulnerabilities", "hosts", "services"]
144        } else {
145            analysis
146                .entity_types
147                .iter()
148                .map(|t| t.collection_name())
149                .collect()
150        };
151
152        // For each collection, execute vector search
153        for collection in collections {
154            // Check if collection exists
155            if let Some(coll) = self.vector_store.get(collection) {
156                // Note: In a real implementation, we'd need to embed the query text
157                // For now, we'll look for pre-embedded entities that might match
158
159                // Get recent/relevant vectors from the collection
160                // This is a simplified approach - real RAG would embed the query
161                let results = self.search_collection_by_keywords(
162                    collection,
163                    &analysis.keywords,
164                    config.max_chunks_per_source,
165                );
166
167                for (id, content, relevance) in results {
168                    let chunk = ContextChunk::from_vector(
169                        content,
170                        collection,
171                        1.0 - relevance, // Convert relevance to distance
172                        id,
173                    )
174                    .with_entity_type(EntityType::from_str(collection));
175
176                    context.add_chunk(chunk);
177                }
178            }
179        }
180
181        Ok(())
182    }
183
184    /// Search a collection by keywords (simplified - would use embeddings in real impl)
185    fn search_collection_by_keywords(
186        &self,
187        collection: &str,
188        keywords: &[String],
189        limit: usize,
190    ) -> Vec<(u64, String, f32)> {
191        // This is a placeholder - in a real implementation:
192        // 1. Embed the keywords using an embedding model
193        // 2. Search the vector collection
194        // 3. Return results with actual content
195
196        // For now, return empty - the vector store would need
197        // a metadata-based search or we'd need embeddings
198        Vec::new()
199    }
200
201    /// Retrieve context using graph traversal
202    fn retrieve_graph(
203        &self,
204        query: &str,
205        analysis: &QueryAnalysis,
206        config: &RagConfig,
207        context: &mut RetrievalContext,
208    ) -> Result<(), ExecutionError> {
209        // Find starting nodes based on entity types and keywords
210        let start_nodes = self.find_graph_start_nodes(analysis, config);
211
212        // Traverse from each start node
213        for (node_id, node_type) in start_nodes {
214            self.traverse_and_collect(
215                &node_id,
216                node_type,
217                config.graph_depth,
218                context,
219                &mut HashSet::new(),
220            )?;
221        }
222
223        Ok(())
224    }
225
226    /// Find starting nodes for graph traversal
227    fn find_graph_start_nodes(
228        &self,
229        analysis: &QueryAnalysis,
230        config: &RagConfig,
231    ) -> Vec<(String, EntityType)> {
232        let mut nodes = Vec::new();
233
234        // Look for nodes matching keywords
235        for keyword in &analysis.keywords {
236            // Check if keyword looks like a CVE
237            if keyword.to_uppercase().starts_with("CVE-") {
238                if let Some(node) = self.graph.get_node(&keyword.to_uppercase()) {
239                    nodes.push((node.id.clone(), EntityType::Vulnerability));
240                }
241            }
242
243            // Check if keyword looks like an IP
244            if keyword.contains('.') && keyword.chars().all(|c| c.is_ascii_digit() || c == '.') {
245                if let Some(node) = self.graph.get_node(keyword) {
246                    nodes.push((node.id.clone(), EntityType::Host));
247                }
248            }
249        }
250
251        // Limit number of start nodes
252        nodes.truncate(config.max_chunks_per_source);
253        nodes
254    }
255
256    /// Traverse graph from a node and collect context
257    fn traverse_and_collect(
258        &self,
259        node_id: &str,
260        entity_type: EntityType,
261        max_depth: u32,
262        context: &mut RetrievalContext,
263        visited: &mut HashSet<String>,
264    ) -> Result<(), ExecutionError> {
265        if max_depth == 0 || visited.contains(node_id) {
266            return Ok(());
267        }
268
269        visited.insert(node_id.to_string());
270
271        // Get node information
272        if let Some(node) = self.graph.get_node(node_id) {
273            // Create content string from node
274            let content = self.node_to_content(&node);
275
276            let chunk = ContextChunk::from_graph(
277                content,
278                max_depth - 1, // Depth from start (lower = closer)
279                entity_type,
280                node_id,
281            );
282
283            context.add_chunk(chunk);
284
285            // Get outgoing edges and continue traversal
286            let edges = self.graph.outgoing_edges(node_id);
287            for (edge_type, target_id, _weight) in edges {
288                if !visited.contains(&target_id) {
289                    // Determine target entity type from edge type
290                    let target_type = self.infer_entity_type_from_edge(edge_type.as_str());
291
292                    self.traverse_and_collect(
293                        &target_id,
294                        target_type,
295                        max_depth - 1,
296                        context,
297                        visited,
298                    )?;
299                }
300            }
301        }
302
303        Ok(())
304    }
305
306    /// Convert node to content string
307    fn node_to_content(&self, node: &StoredNode) -> String {
308        // StoredNode has id, label, node_type but no properties HashMap
309        // Just use the available fields
310        format!(
311            "{}: {} (label: {})",
312            node.node_type.as_str(),
313            node.id,
314            node.label
315        )
316    }
317
318    /// Infer entity type from edge type
319    fn infer_entity_type_from_edge(&self, edge_type: &str) -> EntityType {
320        match edge_type.to_lowercase().as_str() {
321            "runs" | "hosts" => EntityType::Service,
322            "has_vuln" | "affects" => EntityType::Vulnerability,
323            "uses" | "depends_on" => EntityType::Technology,
324            "owns" | "created_by" => EntityType::User,
325            "connects_to" | "routes_to" => EntityType::Network,
326            "has_cert" | "secured_by" => EntityType::Certificate,
327            "resolves_to" | "has_domain" => EntityType::Domain,
328            _ => EntityType::Unknown,
329        }
330    }
331
332    /// Retrieve from table queries
333    fn retrieve_table(
334        &self,
335        _query: &str,
336        _analysis: &QueryAnalysis,
337        _config: &RagConfig,
338        _context: &mut RetrievalContext,
339    ) -> Result<(), ExecutionError> {
340        // Table retrieval would use the GraphTableIndex to find relevant rows
341        // This is a placeholder for the full implementation
342        Ok(())
343    }
344
345    /// Expand context with vector similarity
346    fn expand_with_vectors(
347        &self,
348        context: &mut RetrievalContext,
349        _config: &RagConfig,
350    ) -> Result<(), ExecutionError> {
351        // For entities found via graph, find similar vectors
352        let entity_ids: Vec<(String, EntityType)> = context
353            .chunks
354            .iter()
355            .filter(|c| matches!(c.source, ChunkSource::Graph))
356            .filter_map(|c| {
357                c.entity_id
358                    .as_ref()
359                    .map(|id| (id.clone(), c.entity_type.unwrap_or(EntityType::Unknown)))
360            })
361            .collect();
362
363        for (entity_id, _entity_type) in entity_ids {
364            // Check if this entity has vectors in unified index
365            let vec_refs = self.unified_index.get_node_vectors(&entity_id);
366            for vec_ref in vec_refs {
367                // Search for similar vectors
368                if let Some(_coll) = self.vector_store.get(&vec_ref.collection) {
369                    // Would search for similar vectors here
370                    // This requires the vector data which we'd get from the collection
371                }
372            }
373        }
374
375        Ok(())
376    }
377
378    /// Expand context with graph relationships
379    fn expand_with_graph(
380        &self,
381        context: &mut RetrievalContext,
382        _config: &RagConfig,
383    ) -> Result<(), ExecutionError> {
384        // For entities found via vector search, traverse graph relationships
385        let vector_entities: Vec<(u64, String)> = context
386            .chunks
387            .iter()
388            .filter(|c| matches!(c.source, ChunkSource::Vector(_)))
389            .filter_map(|c| {
390                c.entity_id
391                    .as_ref()
392                    .and_then(|id| id.parse().ok())
393                    .map(|id| (id, c.source.collection().unwrap_or("unknown").to_string()))
394            })
395            .collect();
396
397        for (vector_id, collection) in vector_entities {
398            // Check if this vector is linked to a graph node
399            if let Some(node_id) = self.unified_index.get_vector_node(&collection, vector_id) {
400                let _entity_type = EntityType::from_str(&collection);
401
402                // Get immediate neighbors via outgoing edges
403                let edges = self.graph.outgoing_edges(&node_id);
404                for (edge_type, target_id, _weight) in edges.into_iter().take(3) {
405                    if let Some(target_node) = self.graph.get_node(&target_id) {
406                        let content = self.node_to_content(&target_node);
407                        let target_type = self.infer_entity_type_from_edge(edge_type.as_str());
408
409                        let chunk = ContextChunk::from_graph(
410                            format!("{} -> {}: {}", edge_type.as_str(), target_node.id, content),
411                            1,
412                            target_type,
413                            &target_node.id,
414                        );
415
416                        context.add_chunk(chunk);
417                    }
418                }
419            }
420        }
421
422        Ok(())
423    }
424
425    /// Expand context using cross-references
426    fn expand_cross_refs(
427        &self,
428        context: &mut RetrievalContext,
429        _config: &RagConfig,
430    ) -> Result<(), ExecutionError> {
431        // Find cross-references for existing chunks
432        let existing_ids: Vec<(String, ChunkSource)> = context
433            .chunks
434            .iter()
435            .filter_map(|c| {
436                c.entity_id
437                    .as_ref()
438                    .map(|id| (id.clone(), c.source.clone()))
439            })
440            .collect();
441
442        for (id, source) in existing_ids {
443            match source {
444                ChunkSource::Vector(collection) => {
445                    // Vector -> check for linked node and row
446                    if let Ok(id_num) = id.parse::<u64>() {
447                        if let Some(row_key) =
448                            self.unified_index.get_vector_row(&collection, id_num)
449                        {
450                            let chunk = ContextChunk::new(
451                                format!("Linked row: {}:{}", row_key.table, row_key.row_id),
452                                ChunkSource::CrossRef,
453                                0.5,
454                            );
455                            context.add_chunk(chunk);
456                        }
457                    }
458                }
459                ChunkSource::Graph => {
460                    // Graph -> check for linked vectors (returns Vec)
461                    let vec_refs = self.unified_index.get_node_vectors(&id);
462                    if let Some(vec_ref) = vec_refs.first() {
463                        let chunk = ContextChunk::new(
464                            format!("Has embedding in collection: {}", vec_ref.collection),
465                            ChunkSource::CrossRef,
466                            0.5,
467                        );
468                        context.add_chunk(chunk);
469                    }
470                }
471                _ => {}
472            }
473        }
474
475        Ok(())
476    }
477
478    /// Retrieve context by vector directly
479    pub fn retrieve_by_vector(
480        &self,
481        vector: &[f32],
482        collection: &str,
483        k: usize,
484        config: &RagConfig,
485    ) -> Result<RetrievalContext, ExecutionError> {
486        let start = std::time::Instant::now();
487        let mut context = RetrievalContext::new(format!("vector search in {}", collection));
488
489        // Execute vector search
490        if let Some(coll) = self.vector_store.get(collection) {
491            let results = coll.search_with_filter(vector, k, None);
492
493            for result in results {
494                // Skip if below threshold
495                let relevance = 1.0 / (1.0 + result.distance);
496                if relevance < config.min_relevance {
497                    continue;
498                }
499
500                // Get content from metadata or generate placeholder
501                let content = result
502                    .metadata
503                    .as_ref()
504                    .and_then(|m| m.strings.get("content").cloned())
505                    .unwrap_or_else(|| format!("Vector {} in {}", result.id, collection));
506
507                let chunk =
508                    ContextChunk::from_vector(content, collection, result.distance, result.id)
509                        .with_entity_type(EntityType::from_str(collection));
510
511                context.add_chunk(chunk);
512            }
513        }
514
515        // Expand with graph context if enabled
516        if config.expand_cross_refs {
517            self.expand_with_graph(&mut context, config)?;
518        }
519
520        context.sort_by_relevance();
521        context.calculate_overall_relevance();
522        context.retrieval_time_us = start.elapsed().as_micros() as u64;
523
524        Ok(context)
525    }
526
527    /// Expand context around a known entity
528    pub fn expand_context(
529        &self,
530        entity_id: &str,
531        entity_type: EntityType,
532        depth: u32,
533        config: &RagConfig,
534    ) -> Result<RetrievalContext, ExecutionError> {
535        let start = std::time::Instant::now();
536        let mut context = RetrievalContext::new(format!(
537            "expand {}:{}",
538            entity_type.collection_name(),
539            entity_id
540        ));
541
542        // Traverse graph from entity
543        self.traverse_and_collect(
544            entity_id,
545            entity_type,
546            depth,
547            &mut context,
548            &mut HashSet::new(),
549        )?;
550
551        // Add vector similarity if entity has embedding
552        let vec_refs = self.unified_index.get_node_vectors(entity_id);
553        if !vec_refs.is_empty() {
554            // Would search for similar vectors here
555            // Requires getting the vector data first
556        }
557
558        context.sort_by_relevance();
559        context.calculate_overall_relevance();
560        context.retrieval_time_us = start.elapsed().as_micros() as u64;
561
562        Ok(context)
563    }
564
565    /// Find similar entities by vector
566    pub fn find_similar(
567        &self,
568        collection: &str,
569        entity_id: u64,
570        k: usize,
571    ) -> Result<Vec<SimilarEntity>, ExecutionError> {
572        // Get the vector for this entity
573        let coll = self
574            .vector_store
575            .get(collection)
576            .ok_or_else(|| ExecutionError::new(format!("Collection not found: {}", collection)))?;
577
578        // Would need to get vector by ID - this requires extending VectorCollection
579        // For now, return empty
580        Ok(Vec::new())
581    }
582}
583
584// ============================================================================
585// In-Memory Retriever for Testing
586// ============================================================================
587
588/// In-memory retriever for testing without full storage backends
589pub struct InMemoryRetriever {
590    /// Stored chunks
591    chunks: Vec<StoredChunk>,
592    /// Simple vector index
593    vectors: HashMap<String, Vec<(u64, Vec<f32>, String)>>,
594}
595
596struct StoredChunk {
597    content: String,
598    source: ChunkSource,
599    entity_type: Option<EntityType>,
600    entity_id: Option<String>,
601    keywords: Vec<String>,
602}
603
604impl InMemoryRetriever {
605    pub fn new() -> Self {
606        Self {
607            chunks: Vec::new(),
608            vectors: HashMap::new(),
609        }
610    }
611
612    /// Add a chunk
613    pub fn add_chunk(
614        &mut self,
615        content: &str,
616        source: ChunkSource,
617        entity_type: Option<EntityType>,
618        keywords: Vec<String>,
619    ) {
620        self.chunks.push(StoredChunk {
621            content: content.to_string(),
622            source,
623            entity_type,
624            entity_id: None,
625            keywords,
626        });
627    }
628
629    /// Add a vector
630    pub fn add_vector(&mut self, collection: &str, id: u64, vector: Vec<f32>, content: &str) {
631        self.vectors
632            .entry(collection.to_string())
633            .or_default()
634            .push((id, vector, content.to_string()));
635    }
636
637    /// Search by keywords
638    pub fn search_keywords(&self, keywords: &[String], limit: usize) -> RetrievalContext {
639        let mut context = RetrievalContext::new(keywords.join(" "));
640
641        for chunk in &self.chunks {
642            let matches: usize = keywords
643                .iter()
644                .filter(|kw| {
645                    chunk.keywords.contains(kw)
646                        || chunk.content.to_lowercase().contains(&kw.to_lowercase())
647                })
648                .count();
649
650            if matches > 0 {
651                let relevance = matches as f32 / keywords.len().max(1) as f32;
652                let ctx_chunk = ContextChunk::new(&chunk.content, chunk.source.clone(), relevance)
653                    .with_entity_type(chunk.entity_type.unwrap_or(EntityType::Unknown));
654
655                context.add_chunk(ctx_chunk);
656            }
657        }
658
659        context.sort_by_relevance();
660        context.limit(limit);
661        context.calculate_overall_relevance();
662        context
663    }
664
665    /// Vector search
666    pub fn search_vector(&self, collection: &str, query: &[f32], k: usize) -> RetrievalContext {
667        let mut context = RetrievalContext::new(format!("vector search {}", collection));
668
669        if let Some(vectors) = self.vectors.get(collection) {
670            let mut distances: Vec<(u64, f32, &str)> = vectors
671                .iter()
672                .map(|(id, vec, content)| {
673                    let dist =
674                        crate::storage::engine::distance::distance(query, vec, DistanceMetric::L2);
675                    (*id, dist, content.as_str())
676                })
677                .collect();
678
679            distances.sort_by(|a, b| {
680                a.1.partial_cmp(&b.1)
681                    .unwrap_or(std::cmp::Ordering::Equal)
682                    .then_with(|| a.0.cmp(&b.0))
683            });
684
685            for (id, dist, content) in distances.into_iter().take(k) {
686                let chunk = ContextChunk::from_vector(content, collection, dist, id);
687                context.add_chunk(chunk);
688            }
689        }
690
691        context.calculate_overall_relevance();
692        context
693    }
694}
695
696impl Default for InMemoryRetriever {
697    fn default() -> Self {
698        Self::new()
699    }
700}
701
702// ============================================================================
703// Tests
704// ============================================================================
705
#[cfg(test)]
mod tests {
    use super::*;

    /// Retriever with one vulnerability chunk and one host chunk; both carry
    /// the `nginx` tag.
    fn sample_retriever() -> InMemoryRetriever {
        let mut retriever = InMemoryRetriever::new();

        retriever.add_chunk(
            "CVE-2024-1234 is a critical SQL injection vulnerability in nginx",
            ChunkSource::Intelligence,
            Some(EntityType::Vulnerability),
            vec!["cve".to_string(), "sql".to_string(), "nginx".to_string()],
        );

        retriever.add_chunk(
            "Host 192.168.1.1 runs nginx web server",
            ChunkSource::Graph,
            Some(EntityType::Host),
            vec!["host".to_string(), "nginx".to_string()],
        );

        retriever
    }

    /// Turn string literals into an owned keyword list.
    fn keywords(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn test_in_memory_keyword_search() {
        let retriever = sample_retriever();

        let context = retriever.search_keywords(&keywords(&["nginx"]), 10);
        assert_eq!(context.len(), 2);

        let context = retriever.search_keywords(&keywords(&["cve", "sql"]), 10);
        assert_eq!(context.len(), 1);
    }

    #[test]
    fn test_in_memory_keyword_search_ranks_by_match_fraction() {
        let retriever = sample_retriever();

        // The vulnerability chunk matches both keywords, the host chunk only one.
        let context = retriever.search_keywords(&keywords(&["nginx", "cve"]), 10);
        assert_eq!(context.len(), 2);
        assert!(context.top_chunk().unwrap().content.contains("CVE-2024-1234"));
    }

    #[test]
    fn test_in_memory_keyword_search_respects_limit() {
        let retriever = sample_retriever();

        let context = retriever.search_keywords(&keywords(&["nginx"]), 1);
        assert_eq!(context.len(), 1);
    }

    #[test]
    fn test_in_memory_keyword_search_without_match() {
        let retriever = sample_retriever();

        let context = retriever.search_keywords(&keywords(&["postgres"]), 10);
        assert_eq!(context.len(), 0);
    }

    #[test]
    fn test_in_memory_vector_search() {
        let mut retriever = InMemoryRetriever::new();

        retriever.add_vector("vulns", 1, vec![1.0, 0.0, 0.0], "CVE-2024-1234");
        retriever.add_vector("vulns", 2, vec![0.9, 0.1, 0.0], "CVE-2024-5678");
        retriever.add_vector("vulns", 3, vec![0.0, 1.0, 0.0], "CVE-2024-9999");

        let context = retriever.search_vector("vulns", &[1.0, 0.0, 0.0], 2);
        assert_eq!(context.len(), 2);

        // The exact match comes first, then the nearest neighbour; k cuts the far vector.
        let top = context.top_chunk().unwrap();
        assert!(top.content.contains("1234"));
        assert!(context.chunks.get(1).unwrap().content.contains("5678"));
    }

    #[test]
    fn test_in_memory_vector_search_unknown_collection() {
        let retriever = InMemoryRetriever::new();

        let context = retriever.search_vector("missing", &[1.0, 0.0], 5);
        assert_eq!(context.len(), 0);
    }

    #[test]
    fn test_retrieval_strategy() {
        let all = [
            RetrievalStrategy::VectorFirst,
            RetrievalStrategy::GraphFirst,
            RetrievalStrategy::Hybrid,
            RetrievalStrategy::VectorOnly,
            RetrievalStrategy::GraphOnly,
            RetrievalStrategy::TableOnly,
        ];

        // Every variant equals itself and differs from every other variant.
        for (i, a) in all.iter().enumerate() {
            for (j, b) in all.iter().enumerate() {
                assert_eq!(i == j, a == b, "{:?} vs {:?}", a, b);
            }
        }
    }
}