1use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use crate::storage::engine::distance::DistanceMetric;
10use crate::storage::engine::graph_store::{GraphStore, StoredNode};
11use crate::storage::engine::graph_table_index::GraphTableIndex;
12use crate::storage::engine::unified_index::UnifiedIndex;
13use crate::storage::engine::vector_store::VectorStore;
14use crate::storage::query::unified::ExecutionError;
15
16use super::context::{ChunkSource, ContextChunk, RetrievalContext};
17use super::{EntityType, QueryAnalysis, RagConfig, SimilarEntity};
18
/// Selects which sources `MultiSourceRetriever::retrieve` queries and whether
/// the first batch of results is expanded through a second source.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetrievalStrategy {
    /// Run the vector retrieval first, then expand the hits through the graph.
    VectorFirst,
    /// Run the graph retrieval first, then expand the hits through the vector index.
    GraphFirst,
    /// Run both the vector retrieval and the graph retrieval, without an expansion step.
    Hybrid,
    /// Run the vector retrieval only.
    VectorOnly,
    /// Run the graph retrieval only.
    GraphOnly,
    /// Run the table retrieval only.
    TableOnly,
}
35
/// Retriever that assembles a `RetrievalContext` from the graph store, the
/// vector store and the cross-reference index it holds shared handles to.
pub struct MultiSourceRetriever {
    /// Graph store used for node lookups and outgoing-edge traversal.
    graph: Arc<GraphStore>,
    /// Graph/table index handle; stored by `new` but not read by the methods in this impl.
    index: Arc<GraphTableIndex>,
    /// Vector store, looked up per collection name.
    vector_store: Arc<VectorStore>,
    /// Cross-reference index linking vectors, graph nodes and table rows.
    unified_index: Arc<UnifiedIndex>,
}
47
48impl MultiSourceRetriever {
    /// Creates a retriever from shared handles to the backing stores.
    ///
    /// The handles are stored as given; no validation or I/O is performed here.
    pub fn new(
        graph: Arc<GraphStore>,
        index: Arc<GraphTableIndex>,
        vector_store: Arc<VectorStore>,
        unified_index: Arc<UnifiedIndex>,
    ) -> Self {
        Self {
            graph,
            index,
            vector_store,
            unified_index,
        }
    }
63
64 pub fn retrieve(
66 &self,
67 query: &str,
68 analysis: &QueryAnalysis,
69 config: &RagConfig,
70 ) -> Result<RetrievalContext, ExecutionError> {
71 let start = std::time::Instant::now();
72 let mut context = RetrievalContext::new(query);
73
74 match analysis.primary_strategy {
76 RetrievalStrategy::VectorFirst | RetrievalStrategy::VectorOnly => {
77 self.retrieve_vector(query, analysis, config, &mut context)?;
78
79 if analysis.primary_strategy != RetrievalStrategy::VectorOnly {
81 self.expand_with_graph(&mut context, config)?;
82 }
83 }
84 RetrievalStrategy::GraphFirst | RetrievalStrategy::GraphOnly => {
85 self.retrieve_graph(query, analysis, config, &mut context)?;
86
87 if analysis.primary_strategy != RetrievalStrategy::GraphOnly {
89 self.expand_with_vectors(&mut context, config)?;
90 }
91 }
92 RetrievalStrategy::Hybrid => {
93 self.retrieve_vector(query, analysis, config, &mut context)?;
95 self.retrieve_graph(query, analysis, config, &mut context)?;
96 }
97 RetrievalStrategy::TableOnly => {
98 self.retrieve_table(query, analysis, config, &mut context)?;
99 }
100 }
101
102 if config.expand_cross_refs {
104 self.expand_cross_refs(&mut context, config)?;
105 }
106
107 context.sort_by_relevance();
109 context.limit(config.max_total_chunks);
110 context.calculate_overall_relevance();
111 context.retrieval_time_us = start.elapsed().as_micros() as u64;
112
113 let explanation = format!(
115 "Retrieved {} chunks using {} strategy. Sources: {:?}",
116 context.len(),
117 match analysis.primary_strategy {
118 RetrievalStrategy::VectorFirst => "vector-first",
119 RetrievalStrategy::GraphFirst => "graph-first",
120 RetrievalStrategy::Hybrid => "hybrid",
121 RetrievalStrategy::VectorOnly => "vector-only",
122 RetrievalStrategy::GraphOnly => "graph-only",
123 RetrievalStrategy::TableOnly => "table-only",
124 },
125 context.sources_used
126 );
127 context.explanation = Some(explanation);
128
129 Ok(context)
130 }
131
132 fn retrieve_vector(
134 &self,
135 query: &str,
136 analysis: &QueryAnalysis,
137 config: &RagConfig,
138 context: &mut RetrievalContext,
139 ) -> Result<(), ExecutionError> {
140 let collections: Vec<&str> = if analysis.entity_types.is_empty() {
142 vec!["vulnerabilities", "hosts", "services"]
144 } else {
145 analysis
146 .entity_types
147 .iter()
148 .map(|t| t.collection_name())
149 .collect()
150 };
151
152 for collection in collections {
154 if let Some(coll) = self.vector_store.get(collection) {
156 let results = self.search_collection_by_keywords(
162 collection,
163 &analysis.keywords,
164 config.max_chunks_per_source,
165 );
166
167 for (id, content, relevance) in results {
168 let chunk = ContextChunk::from_vector(
169 content,
170 collection,
171 1.0 - relevance, id,
173 )
174 .with_entity_type(EntityType::from_str(collection));
175
176 context.add_chunk(chunk);
177 }
178 }
179 }
180
181 Ok(())
182 }
183
    /// Keyword lookup inside a vector collection, returning
    /// `(id, content, relevance)` tuples as consumed by `retrieve_vector`.
    ///
    /// Currently a stub: `collection`, `keywords` and `limit` are ignored and
    /// the result is always an empty vector.
    fn search_collection_by_keywords(
        &self,
        collection: &str,
        keywords: &[String],
        limit: usize,
    ) -> Vec<(u64, String, f32)> {
        Vec::new()
    }
200
201 fn retrieve_graph(
203 &self,
204 query: &str,
205 analysis: &QueryAnalysis,
206 config: &RagConfig,
207 context: &mut RetrievalContext,
208 ) -> Result<(), ExecutionError> {
209 let start_nodes = self.find_graph_start_nodes(analysis, config);
211
212 for (node_id, node_type) in start_nodes {
214 self.traverse_and_collect(
215 &node_id,
216 node_type,
217 config.graph_depth,
218 context,
219 &mut HashSet::new(),
220 )?;
221 }
222
223 Ok(())
224 }
225
226 fn find_graph_start_nodes(
228 &self,
229 analysis: &QueryAnalysis,
230 config: &RagConfig,
231 ) -> Vec<(String, EntityType)> {
232 let mut nodes = Vec::new();
233
234 for keyword in &analysis.keywords {
236 if keyword.to_uppercase().starts_with("CVE-") {
238 if let Some(node) = self.graph.get_node(&keyword.to_uppercase()) {
239 nodes.push((node.id.clone(), EntityType::Vulnerability));
240 }
241 }
242
243 if keyword.contains('.') && keyword.chars().all(|c| c.is_ascii_digit() || c == '.') {
245 if let Some(node) = self.graph.get_node(keyword) {
246 nodes.push((node.id.clone(), EntityType::Host));
247 }
248 }
249 }
250
251 nodes.truncate(config.max_chunks_per_source);
253 nodes
254 }
255
256 fn traverse_and_collect(
258 &self,
259 node_id: &str,
260 entity_type: EntityType,
261 max_depth: u32,
262 context: &mut RetrievalContext,
263 visited: &mut HashSet<String>,
264 ) -> Result<(), ExecutionError> {
265 if max_depth == 0 || visited.contains(node_id) {
266 return Ok(());
267 }
268
269 visited.insert(node_id.to_string());
270
271 if let Some(node) = self.graph.get_node(node_id) {
273 let content = self.node_to_content(&node);
275
276 let chunk = ContextChunk::from_graph(
277 content,
278 max_depth - 1, entity_type,
280 node_id,
281 );
282
283 context.add_chunk(chunk);
284
285 let edges = self.graph.outgoing_edges(node_id);
287 for (edge_type, target_id, _weight) in edges {
288 if !visited.contains(&target_id) {
289 let target_type = self.infer_entity_type_from_edge(edge_type.as_str());
291
292 self.traverse_and_collect(
293 &target_id,
294 target_type,
295 max_depth - 1,
296 context,
297 visited,
298 )?;
299 }
300 }
301 }
302
303 Ok(())
304 }
305
306 fn node_to_content(&self, node: &StoredNode) -> String {
308 format!(
311 "{}: {} (label: {})",
312 node.node_type.as_str(),
313 node.id,
314 node.label
315 )
316 }
317
318 fn infer_entity_type_from_edge(&self, edge_type: &str) -> EntityType {
320 match edge_type.to_lowercase().as_str() {
321 "runs" | "hosts" => EntityType::Service,
322 "has_vuln" | "affects" => EntityType::Vulnerability,
323 "uses" | "depends_on" => EntityType::Technology,
324 "owns" | "created_by" => EntityType::User,
325 "connects_to" | "routes_to" => EntityType::Network,
326 "has_cert" | "secured_by" => EntityType::Certificate,
327 "resolves_to" | "has_domain" => EntityType::Domain,
328 _ => EntityType::Unknown,
329 }
330 }
331
    /// Table retrieval step used by `RetrievalStrategy::TableOnly`.
    ///
    /// Currently a stub: all arguments are ignored, nothing is added to the
    /// context, and the call always succeeds.
    fn retrieve_table(
        &self,
        _query: &str,
        _analysis: &QueryAnalysis,
        _config: &RagConfig,
        _context: &mut RetrievalContext,
    ) -> Result<(), ExecutionError> {
        Ok(())
    }
344
345 fn expand_with_vectors(
347 &self,
348 context: &mut RetrievalContext,
349 _config: &RagConfig,
350 ) -> Result<(), ExecutionError> {
351 let entity_ids: Vec<(String, EntityType)> = context
353 .chunks
354 .iter()
355 .filter(|c| matches!(c.source, ChunkSource::Graph))
356 .filter_map(|c| {
357 c.entity_id
358 .as_ref()
359 .map(|id| (id.clone(), c.entity_type.unwrap_or(EntityType::Unknown)))
360 })
361 .collect();
362
363 for (entity_id, _entity_type) in entity_ids {
364 let vec_refs = self.unified_index.get_node_vectors(&entity_id);
366 for vec_ref in vec_refs {
367 if let Some(_coll) = self.vector_store.get(&vec_ref.collection) {
369 }
372 }
373 }
374
375 Ok(())
376 }
377
378 fn expand_with_graph(
380 &self,
381 context: &mut RetrievalContext,
382 _config: &RagConfig,
383 ) -> Result<(), ExecutionError> {
384 let vector_entities: Vec<(u64, String)> = context
386 .chunks
387 .iter()
388 .filter(|c| matches!(c.source, ChunkSource::Vector(_)))
389 .filter_map(|c| {
390 c.entity_id
391 .as_ref()
392 .and_then(|id| id.parse().ok())
393 .map(|id| (id, c.source.collection().unwrap_or("unknown").to_string()))
394 })
395 .collect();
396
397 for (vector_id, collection) in vector_entities {
398 if let Some(node_id) = self.unified_index.get_vector_node(&collection, vector_id) {
400 let _entity_type = EntityType::from_str(&collection);
401
402 let edges = self.graph.outgoing_edges(&node_id);
404 for (edge_type, target_id, _weight) in edges.into_iter().take(3) {
405 if let Some(target_node) = self.graph.get_node(&target_id) {
406 let content = self.node_to_content(&target_node);
407 let target_type = self.infer_entity_type_from_edge(edge_type.as_str());
408
409 let chunk = ContextChunk::from_graph(
410 format!("{} -> {}: {}", edge_type.as_str(), target_node.id, content),
411 1,
412 target_type,
413 &target_node.id,
414 );
415
416 context.add_chunk(chunk);
417 }
418 }
419 }
420 }
421
422 Ok(())
423 }
424
425 fn expand_cross_refs(
427 &self,
428 context: &mut RetrievalContext,
429 _config: &RagConfig,
430 ) -> Result<(), ExecutionError> {
431 let existing_ids: Vec<(String, ChunkSource)> = context
433 .chunks
434 .iter()
435 .filter_map(|c| {
436 c.entity_id
437 .as_ref()
438 .map(|id| (id.clone(), c.source.clone()))
439 })
440 .collect();
441
442 for (id, source) in existing_ids {
443 match source {
444 ChunkSource::Vector(collection) => {
445 if let Ok(id_num) = id.parse::<u64>() {
447 if let Some(row_key) =
448 self.unified_index.get_vector_row(&collection, id_num)
449 {
450 let chunk = ContextChunk::new(
451 format!("Linked row: {}:{}", row_key.table, row_key.row_id),
452 ChunkSource::CrossRef,
453 0.5,
454 );
455 context.add_chunk(chunk);
456 }
457 }
458 }
459 ChunkSource::Graph => {
460 let vec_refs = self.unified_index.get_node_vectors(&id);
462 if let Some(vec_ref) = vec_refs.first() {
463 let chunk = ContextChunk::new(
464 format!("Has embedding in collection: {}", vec_ref.collection),
465 ChunkSource::CrossRef,
466 0.5,
467 );
468 context.add_chunk(chunk);
469 }
470 }
471 _ => {}
472 }
473 }
474
475 Ok(())
476 }
477
478 pub fn retrieve_by_vector(
480 &self,
481 vector: &[f32],
482 collection: &str,
483 k: usize,
484 config: &RagConfig,
485 ) -> Result<RetrievalContext, ExecutionError> {
486 let start = std::time::Instant::now();
487 let mut context = RetrievalContext::new(format!("vector search in {}", collection));
488
489 if let Some(coll) = self.vector_store.get(collection) {
491 let results = coll.search_with_filter(vector, k, None);
492
493 for result in results {
494 let relevance = 1.0 / (1.0 + result.distance);
496 if relevance < config.min_relevance {
497 continue;
498 }
499
500 let content = result
502 .metadata
503 .as_ref()
504 .and_then(|m| m.strings.get("content").cloned())
505 .unwrap_or_else(|| format!("Vector {} in {}", result.id, collection));
506
507 let chunk =
508 ContextChunk::from_vector(content, collection, result.distance, result.id)
509 .with_entity_type(EntityType::from_str(collection));
510
511 context.add_chunk(chunk);
512 }
513 }
514
515 if config.expand_cross_refs {
517 self.expand_with_graph(&mut context, config)?;
518 }
519
520 context.sort_by_relevance();
521 context.calculate_overall_relevance();
522 context.retrieval_time_us = start.elapsed().as_micros() as u64;
523
524 Ok(context)
525 }
526
527 pub fn expand_context(
529 &self,
530 entity_id: &str,
531 entity_type: EntityType,
532 depth: u32,
533 config: &RagConfig,
534 ) -> Result<RetrievalContext, ExecutionError> {
535 let start = std::time::Instant::now();
536 let mut context = RetrievalContext::new(format!(
537 "expand {}:{}",
538 entity_type.collection_name(),
539 entity_id
540 ));
541
542 self.traverse_and_collect(
544 entity_id,
545 entity_type,
546 depth,
547 &mut context,
548 &mut HashSet::new(),
549 )?;
550
551 let vec_refs = self.unified_index.get_node_vectors(entity_id);
553 if !vec_refs.is_empty() {
554 }
557
558 context.sort_by_relevance();
559 context.calculate_overall_relevance();
560 context.retrieval_time_us = start.elapsed().as_micros() as u64;
561
562 Ok(context)
563 }
564
565 pub fn find_similar(
567 &self,
568 collection: &str,
569 entity_id: u64,
570 k: usize,
571 ) -> Result<Vec<SimilarEntity>, ExecutionError> {
572 let coll = self
574 .vector_store
575 .get(collection)
576 .ok_or_else(|| ExecutionError::new(format!("Collection not found: {}", collection)))?;
577
578 Ok(Vec::new())
581 }
582}
583
/// Retriever that keeps all of its data in process memory, with keyword and
/// brute-force vector search over what has been added to it.
pub struct InMemoryRetriever {
    /// Text chunks searched by `search_keywords`.
    chunks: Vec<StoredChunk>,
    /// Per-collection `(id, vector, content)` entries searched by `search_vector`.
    vectors: HashMap<String, Vec<(u64, Vec<f32>, String)>>,
}
595
/// A text chunk held by `InMemoryRetriever`.
struct StoredChunk {
    /// Chunk text; matched case-insensitively by keyword search.
    content: String,
    /// Source the chunk is attributed to in search results.
    source: ChunkSource,
    /// Optional entity type attached to results built from this chunk.
    entity_type: Option<EntityType>,
    /// Optional entity id; `InMemoryRetriever::add_chunk` always stores `None`.
    entity_id: Option<String>,
    /// Keyword tags; matched exactly (case-sensitively) by keyword search.
    keywords: Vec<String>,
}
603
604impl InMemoryRetriever {
    /// Creates an empty retriever with no chunks and no vector collections.
    pub fn new() -> Self {
        Self {
            chunks: Vec::new(),
            vectors: HashMap::new(),
        }
    }
611
612 pub fn add_chunk(
614 &mut self,
615 content: &str,
616 source: ChunkSource,
617 entity_type: Option<EntityType>,
618 keywords: Vec<String>,
619 ) {
620 self.chunks.push(StoredChunk {
621 content: content.to_string(),
622 source,
623 entity_type,
624 entity_id: None,
625 keywords,
626 });
627 }
628
629 pub fn add_vector(&mut self, collection: &str, id: u64, vector: Vec<f32>, content: &str) {
631 self.vectors
632 .entry(collection.to_string())
633 .or_default()
634 .push((id, vector, content.to_string()));
635 }
636
637 pub fn search_keywords(&self, keywords: &[String], limit: usize) -> RetrievalContext {
639 let mut context = RetrievalContext::new(keywords.join(" "));
640
641 for chunk in &self.chunks {
642 let matches: usize = keywords
643 .iter()
644 .filter(|kw| {
645 chunk.keywords.contains(kw)
646 || chunk.content.to_lowercase().contains(&kw.to_lowercase())
647 })
648 .count();
649
650 if matches > 0 {
651 let relevance = matches as f32 / keywords.len().max(1) as f32;
652 let ctx_chunk = ContextChunk::new(&chunk.content, chunk.source.clone(), relevance)
653 .with_entity_type(chunk.entity_type.unwrap_or(EntityType::Unknown));
654
655 context.add_chunk(ctx_chunk);
656 }
657 }
658
659 context.sort_by_relevance();
660 context.limit(limit);
661 context.calculate_overall_relevance();
662 context
663 }
664
665 pub fn search_vector(&self, collection: &str, query: &[f32], k: usize) -> RetrievalContext {
667 let mut context = RetrievalContext::new(format!("vector search {}", collection));
668
669 if let Some(vectors) = self.vectors.get(collection) {
670 let mut distances: Vec<(u64, f32, &str)> = vectors
671 .iter()
672 .map(|(id, vec, content)| {
673 let dist =
674 crate::storage::engine::distance::distance(query, vec, DistanceMetric::L2);
675 (*id, dist, content.as_str())
676 })
677 .collect();
678
679 distances.sort_by(|a, b| {
680 a.1.partial_cmp(&b.1)
681 .unwrap_or(std::cmp::Ordering::Equal)
682 .then_with(|| a.0.cmp(&b.0))
683 });
684
685 for (id, dist, content) in distances.into_iter().take(k) {
686 let chunk = ContextChunk::from_vector(content, collection, dist, id);
687 context.add_chunk(chunk);
688 }
689 }
690
691 context.calculate_overall_relevance();
692 context
693 }
694}
695
/// The default value is an empty retriever, identical to `InMemoryRetriever::new()`.
impl Default for InMemoryRetriever {
    fn default() -> Self {
        Self::new()
    }
}
701
#[cfg(test)]
mod tests {
    use super::*;

    /// Keyword search counts a chunk when any keyword matches its tags or content.
    #[test]
    fn test_in_memory_keyword_search() {
        let mut retriever = InMemoryRetriever::new();

        retriever.add_chunk(
            "CVE-2024-1234 is a critical SQL injection vulnerability in nginx",
            ChunkSource::Intelligence,
            Some(EntityType::Vulnerability),
            vec!["cve".to_string(), "sql".to_string(), "nginx".to_string()],
        );

        retriever.add_chunk(
            "Host 192.168.1.1 runs nginx web server",
            ChunkSource::Graph,
            Some(EntityType::Host),
            vec!["host".to_string(), "nginx".to_string()],
        );

        // "nginx" is tagged on, and contained in, both chunks.
        let context = retriever.search_keywords(&["nginx".to_string()], 10);
        assert_eq!(context.len(), 2);

        // Neither "cve" nor "sql" occurs in the host chunk, so only the first chunk matches.
        let context = retriever.search_keywords(&["cve".to_string(), "sql".to_string()], 10);
        assert_eq!(context.len(), 1);
    }

    /// Vector search returns at most `k` chunks for the requested collection.
    #[test]
    fn test_in_memory_vector_search() {
        let mut retriever = InMemoryRetriever::new();

        retriever.add_vector("vulns", 1, vec![1.0, 0.0, 0.0], "CVE-2024-1234");
        retriever.add_vector("vulns", 2, vec![0.9, 0.1, 0.0], "CVE-2024-5678");
        retriever.add_vector("vulns", 3, vec![0.0, 1.0, 0.0], "CVE-2024-9999");

        // Three vectors are stored, but k = 2 caps the result.
        let context = retriever.search_vector("vulns", &[1.0, 0.0, 0.0], 2);
        assert_eq!(context.len(), 2);

        // The query is identical to vector 1, so that entry is expected on top.
        let top = context.top_chunk().unwrap();
        assert!(top.content.contains("1234"));
    }

    /// Sanity check of the derived equality on `RetrievalStrategy`.
    #[test]
    fn test_retrieval_strategy() {
        assert_eq!(
            RetrievalStrategy::VectorFirst,
            RetrievalStrategy::VectorFirst
        );
        assert_ne!(
            RetrievalStrategy::VectorFirst,
            RetrievalStrategy::GraphFirst
        );
    }
}