Skip to main content

reddb_server/storage/query/rag/
mod.rs

1//! RAG (Retrieval-Augmented Generation) Engine
2//!
3//! This module provides intelligent context retrieval by combining:
4//! - Vector similarity search (semantic matching)
5//! - Graph traversal (relationship-based context)
6//! - Table queries (structured data filtering)
7//!
8//! The RAG engine is designed for security intelligence use cases:
9//! - "What vulnerabilities affect this host?" → Vector + Graph
10//! - "Similar CVEs to CVE-2024-1234" → Pure vector search
11//! - "Attack paths to database servers" → Graph + Vector ranking
12//!
13//! # Architecture
14//!
15//! ```text
16//! Query → Analyzer → Strategy Selection → Parallel Retrieval → Fusion → Context
17//!                         │
18//!                    ┌────┼────┐
19//!                    ▼    ▼    ▼
20//!                 Vector Graph Table
21//!                    │    │    │
22//!                    └────┼────┘
23//!                         ▼
24//!                   Context Fusion
25//!                         │
26//!                         ▼
27//!                 Ranked Results + Explanations
28//! ```
29
30pub mod context;
31pub mod fusion;
32pub mod retriever;
33pub mod unified_adapter;
34
35use std::collections::HashMap;
36use std::sync::Arc;
37
38use crate::storage::engine::graph_store::GraphStore;
39use crate::storage::engine::graph_table_index::GraphTableIndex;
40use crate::storage::engine::unified_index::UnifiedIndex;
41use crate::storage::engine::vector_store::VectorStore;
42use crate::storage::query::unified::ExecutionError;
43use crate::storage::schema::Value;
44
45pub use context::{ChunkSource, ContextChunk, RetrievalContext};
46pub use fusion::{ContextFusion, FusionConfig, ResultReranker};
47pub use retriever::{MultiSourceRetriever, RetrievalStrategy};
48pub use unified_adapter::{
49    EdgeDirection, EdgePatternSpec, GraphQueryPattern, MatchSource, MatchedEntity, MetadataQuery,
50    MultiModalQuery, NodePattern, QueryCondition, QueryValue, UnifiedQueryResult,
51    UnifiedQueryStats, UnifiedStoreAdapter,
52};
53
/// Tunable knobs for the RAG engine, passed to the retriever on every call.
#[derive(Debug, Clone)]
pub struct RagConfig {
    /// Upper bound on chunks pulled from any single source (vector, graph, table).
    pub max_chunks_per_source: usize,
    /// Upper bound on chunks in the final, fused result.
    pub max_total_chunks: usize,
    /// Vector-similarity cut-off used when no explicit threshold is given.
    pub similarity_threshold: f32,
    /// How many hops of graph traversal to use when expanding context.
    pub graph_depth: u32,
    /// Whether cross-references should be followed to expand context.
    pub expand_cross_refs: bool,
    /// Chunks scoring below this relevance are left out of results.
    pub min_relevance: f32,
}
70
71impl Default for RagConfig {
72    fn default() -> Self {
73        Self {
74            max_chunks_per_source: 10,
75            max_total_chunks: 25,
76            similarity_threshold: 0.8,
77            graph_depth: 2,
78            expand_cross_refs: true,
79            min_relevance: 0.3,
80        }
81    }
82}
83
84/// The main RAG Engine that orchestrates retrieval
85pub struct RagEngine {
86    /// Configuration
87    config: RagConfig,
88    /// Multi-source retriever
89    retriever: MultiSourceRetriever,
90    /// Query analyzer for strategy selection
91    analyzer: QueryAnalyzer,
92}
93
94impl RagEngine {
95    /// Create a new RAG engine with all storage backends
96    pub fn new(
97        graph: Arc<GraphStore>,
98        index: Arc<GraphTableIndex>,
99        vector_store: Arc<VectorStore>,
100        unified_index: Arc<UnifiedIndex>,
101    ) -> Self {
102        Self {
103            config: RagConfig::default(),
104            retriever: MultiSourceRetriever::new(graph, index, vector_store, unified_index),
105            analyzer: QueryAnalyzer::new(),
106        }
107    }
108
109    /// Configure the RAG engine
110    pub fn with_config(mut self, config: RagConfig) -> Self {
111        self.config = config;
112        self
113    }
114
115    /// Retrieve context for a query
116    pub fn retrieve(&self, query: &str) -> Result<RetrievalContext, ExecutionError> {
117        // 1. Analyze the query to determine best retrieval strategy
118        let analysis = self.analyzer.analyze(query);
119
120        // 2. Execute retrieval with determined strategy
121        let context = self.retriever.retrieve(query, &analysis, &self.config)?;
122
123        Ok(context)
124    }
125
126    /// Retrieve with explicit strategy override
127    pub fn retrieve_with_strategy(
128        &self,
129        query: &str,
130        strategy: RetrievalStrategy,
131    ) -> Result<RetrievalContext, ExecutionError> {
132        let analysis = QueryAnalysis {
133            primary_strategy: strategy,
134            ..self.analyzer.analyze(query)
135        };
136
137        self.retriever.retrieve(query, &analysis, &self.config)
138    }
139
140    /// Retrieve with a query vector (for embedding-based queries)
141    pub fn retrieve_by_vector(
142        &self,
143        vector: &[f32],
144        collection: &str,
145        k: usize,
146    ) -> Result<RetrievalContext, ExecutionError> {
147        self.retriever
148            .retrieve_by_vector(vector, collection, k, &self.config)
149    }
150
151    /// Expand context around a known entity
152    pub fn expand_context(
153        &self,
154        entity_id: &str,
155        entity_type: EntityType,
156        depth: u32,
157    ) -> Result<RetrievalContext, ExecutionError> {
158        self.retriever
159            .expand_context(entity_id, entity_type, depth, &self.config)
160    }
161
162    /// Get similar entities by vector
163    pub fn find_similar(
164        &self,
165        collection: &str,
166        entity_id: u64,
167        k: usize,
168    ) -> Result<Vec<SimilarEntity>, ExecutionError> {
169        self.retriever.find_similar(collection, entity_id, k)
170    }
171}
172
173// ============================================================================
174// Query Analysis
175// ============================================================================
176
177/// Analyzed query with strategy recommendations
178#[derive(Debug, Clone)]
179pub struct QueryAnalysis {
180    /// Primary retrieval strategy
181    pub primary_strategy: RetrievalStrategy,
182    /// Secondary strategies to combine
183    pub secondary_strategies: Vec<RetrievalStrategy>,
184    /// Detected entity types of interest
185    pub entity_types: Vec<EntityType>,
186    /// Detected keywords/concepts
187    pub keywords: Vec<String>,
188    /// Query intent classification
189    pub intent: QueryIntent,
190    /// Confidence in the analysis (0.0-1.0)
191    pub confidence: f32,
192}
193
/// What the user is trying to accomplish with a query.
#[derive(Debug, Clone, PartialEq)]
pub enum QueryIntent {
    /// Looking for items like a given one, e.g. "similar CVEs".
    Similarity,
    /// Looking for routes or relationships, e.g. "how to reach X".
    PathFinding,
    /// Listing or filtering entities, e.g. "all hosts with port 22".
    Enumeration,
    /// Asking for details about one specific entity.
    Lookup,
    /// Asking about connections or impact.
    Analysis,
    /// Nothing more specific could be determined.
    General,
}
210
/// Kinds of security-domain entities the analyzer can recognize in a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntityType {
    /// A machine / IP address.
    Host,
    /// A network service.
    Service,
    /// A network port.
    Port,
    /// A vulnerability such as a CVE.
    Vulnerability,
    /// A credential or password.
    Credential,
    /// A user account.
    User,
    /// A TLS/SSL certificate.
    Certificate,
    /// A DNS domain.
    Domain,
    /// A network or subnet.
    Network,
    /// A technology or software product.
    Technology,
    /// A URL / API endpoint.
    Endpoint,
    /// Not recognized as any of the above.
    Unknown,
}
227
228impl EntityType {
229    /// Convert from string
230    pub fn from_str(s: &str) -> Self {
231        match s.to_lowercase().as_str() {
232            "host" | "hosts" | "ip" | "ips" | "server" | "servers" | "machine" => Self::Host,
233            "service" | "services" => Self::Service,
234            "port" | "ports" => Self::Port,
235            "vuln" | "vulnerability" | "vulnerabilities" | "cve" | "cves" => Self::Vulnerability,
236            "cred" | "credential" | "credentials" | "password" | "passwords" => Self::Credential,
237            "user" | "users" | "account" | "accounts" => Self::User,
238            "cert" | "certificate" | "certificates" | "ssl" | "tls" => Self::Certificate,
239            "domain" | "domains" | "dns" => Self::Domain,
240            "network" | "networks" | "subnet" | "subnets" => Self::Network,
241            "tech" | "technology" | "technologies" | "software" => Self::Technology,
242            "endpoint" | "endpoints" | "url" | "urls" | "api" => Self::Endpoint,
243            _ => Self::Unknown,
244        }
245    }
246
247    /// Get the vector collection name for this entity type
248    pub fn collection_name(&self) -> &'static str {
249        match self {
250            Self::Host => "hosts",
251            Self::Service => "services",
252            Self::Port => "ports",
253            Self::Vulnerability => "vulnerabilities",
254            Self::Credential => "credentials",
255            Self::User => "users",
256            Self::Certificate => "certificates",
257            Self::Domain => "domains",
258            Self::Network => "networks",
259            Self::Technology => "technologies",
260            Self::Endpoint => "endpoints",
261            Self::Unknown => "general",
262        }
263    }
264}
265
/// Keyword-driven classifier that turns a free-text query into a [`QueryAnalysis`].
pub struct QueryAnalyzer {
    /// Words/phrases indicating a similarity search.
    similarity_keywords: Vec<&'static str>,
    /// Words/phrases indicating a path or relationship search.
    path_keywords: Vec<&'static str>,
    /// Words/phrases indicating a list/filter request.
    enum_keywords: Vec<&'static str>,
}
275
276impl QueryAnalyzer {
277    pub fn new() -> Self {
278        Self {
279            similarity_keywords: vec![
280                "similar",
281                "like",
282                "related",
283                "comparable",
284                "equivalent",
285                "matching",
286                "resembling",
287                "analogous",
288                "close to",
289            ],
290            path_keywords: vec![
291                "path",
292                "route",
293                "reach",
294                "connect",
295                "between",
296                "from",
297                "to",
298                "via",
299                "through",
300                "attack path",
301                "lateral",
302            ],
303            enum_keywords: vec![
304                "all", "list", "find", "show", "get", "which", "what", "where", "filter", "having",
305                "with",
306            ],
307        }
308    }
309
310    /// Analyze a query to determine optimal retrieval strategy
311    pub fn analyze(&self, query: &str) -> QueryAnalysis {
312        let query_lower = query.to_lowercase();
313        let words: Vec<&str> = query_lower.split_whitespace().collect();
314
315        // Detect intent
316        let intent = self.detect_intent(&query_lower);
317
318        // Detect entity types
319        let entity_types = self.detect_entity_types(&words);
320
321        // Extract keywords
322        let keywords = self.extract_keywords(&query_lower);
323
324        // Determine primary strategy based on intent
325        let primary_strategy = match intent {
326            QueryIntent::Similarity => RetrievalStrategy::VectorFirst,
327            QueryIntent::PathFinding => RetrievalStrategy::GraphFirst,
328            QueryIntent::Enumeration => RetrievalStrategy::Hybrid,
329            QueryIntent::Lookup => RetrievalStrategy::GraphFirst,
330            QueryIntent::Analysis => RetrievalStrategy::Hybrid,
331            QueryIntent::General => RetrievalStrategy::Hybrid,
332        };
333
334        // Determine secondary strategies
335        let mut secondary_strategies = Vec::new();
336        if primary_strategy != RetrievalStrategy::VectorFirst {
337            secondary_strategies.push(RetrievalStrategy::VectorFirst);
338        }
339        if primary_strategy != RetrievalStrategy::GraphFirst {
340            secondary_strategies.push(RetrievalStrategy::GraphFirst);
341        }
342
343        // Calculate confidence
344        let confidence = if intent != QueryIntent::General {
345            0.8
346        } else if !entity_types.is_empty() {
347            0.6
348        } else {
349            0.4
350        };
351
352        QueryAnalysis {
353            primary_strategy,
354            secondary_strategies,
355            entity_types,
356            keywords,
357            intent,
358            confidence,
359        }
360    }
361
362    fn detect_intent(&self, query: &str) -> QueryIntent {
363        // Check for similarity keywords
364        if self.similarity_keywords.iter().any(|k| query.contains(k)) {
365            return QueryIntent::Similarity;
366        }
367
368        // Check for path keywords
369        if self.path_keywords.iter().any(|k| query.contains(k)) {
370            return QueryIntent::PathFinding;
371        }
372
373        // Check for enumeration keywords
374        if self.enum_keywords.iter().any(|k| query.contains(k)) {
375            return QueryIntent::Enumeration;
376        }
377
378        // Check for lookup patterns (specific IDs, IPs, CVEs)
379        if query.contains("cve-") || query.contains("192.") || query.contains("10.") {
380            return QueryIntent::Lookup;
381        }
382
383        // Check for analysis keywords
384        if query.contains("impact") || query.contains("affect") || query.contains("analyze") {
385            return QueryIntent::Analysis;
386        }
387
388        QueryIntent::General
389    }
390
391    fn detect_entity_types(&self, words: &[&str]) -> Vec<EntityType> {
392        let mut types = Vec::new();
393        for word in words {
394            let entity_type = EntityType::from_str(word);
395            if entity_type != EntityType::Unknown && !types.contains(&entity_type) {
396                types.push(entity_type);
397            }
398        }
399        types
400    }
401
402    fn extract_keywords(&self, query: &str) -> Vec<String> {
403        // Simple keyword extraction - filter common words
404        let stop_words = [
405            "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
406            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
407            "can", "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into",
408            "about", "i", "me", "my", "we", "our",
409        ];
410
411        query
412            .split_whitespace()
413            .filter(|w| w.len() > 2)
414            .filter(|w| !stop_words.contains(&w.to_lowercase().as_str()))
415            .map(|w| w.to_string())
416            .collect()
417    }
418}
419
420impl Default for QueryAnalyzer {
421    fn default() -> Self {
422        Self::new()
423    }
424}
425
426// ============================================================================
427// Similar Entity Result
428// ============================================================================
429
430/// A similar entity found via vector search
431#[derive(Debug, Clone)]
432pub struct SimilarEntity {
433    /// Entity ID
434    pub id: u64,
435    /// Collection/type
436    pub collection: String,
437    /// Similarity score (0-1, higher is more similar)
438    pub similarity: f32,
439    /// Entity label/name
440    pub label: Option<String>,
441    /// Additional properties
442    pub properties: HashMap<String, Value>,
443}
444
445impl SimilarEntity {
446    pub fn new(id: u64, collection: &str, similarity: f32) -> Self {
447        Self {
448            id,
449            collection: collection.to_string(),
450            similarity,
451            label: None,
452            properties: HashMap::new(),
453        }
454    }
455
456    pub fn with_label(mut self, label: impl Into<String>) -> Self {
457        self.label = Some(label.into());
458        self
459    }
460
461    pub fn with_property(mut self, key: impl Into<String>, value: Value) -> Self {
462        self.properties.insert(key.into(), value);
463        self
464    }
465}
466
467// ============================================================================
468// Tests
469// ============================================================================
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Run a freshly built analyzer over `query`.
    fn analyze(query: &str) -> QueryAnalysis {
        QueryAnalyzer::new().analyze(query)
    }

    #[test]
    fn test_query_analyzer_similarity_intent() {
        let result = analyze("find similar CVEs to CVE-2024-1234");
        assert_eq!(result.intent, QueryIntent::Similarity);
        assert_eq!(result.primary_strategy, RetrievalStrategy::VectorFirst);
    }

    #[test]
    fn test_query_analyzer_path_intent() {
        let result = analyze("attack path from webserver to database");
        assert_eq!(result.intent, QueryIntent::PathFinding);
        assert_eq!(result.primary_strategy, RetrievalStrategy::GraphFirst);
    }

    #[test]
    fn test_query_analyzer_enumeration_intent() {
        let result = analyze("list all hosts with port 22 open");
        assert_eq!(result.intent, QueryIntent::Enumeration);
        assert_eq!(result.primary_strategy, RetrievalStrategy::Hybrid);
    }

    #[test]
    fn test_entity_type_detection() {
        let types = analyze("show vulnerabilities affecting hosts").entity_types;
        for expected in [EntityType::Vulnerability, EntityType::Host] {
            assert!(types.contains(&expected), "missing entity type {:?}", expected);
        }
    }

    #[test]
    fn test_keyword_extraction() {
        let keywords = analyze("find critical vulnerabilities in production servers").keywords;
        for expected in ["critical", "production"] {
            assert!(keywords.iter().any(|k| k == expected), "missing keyword {}", expected);
        }
        // "in" is a stop word and must be dropped.
        assert!(!keywords.iter().any(|k| k == "in"));
    }
}