1pub mod context;
31pub mod fusion;
32pub mod retriever;
33pub mod unified_adapter;
34
35use std::collections::HashMap;
36use std::sync::Arc;
37
38use crate::storage::engine::graph_store::GraphStore;
39use crate::storage::engine::graph_table_index::GraphTableIndex;
40use crate::storage::engine::unified_index::UnifiedIndex;
41use crate::storage::engine::vector_store::VectorStore;
42use crate::storage::query::unified::ExecutionError;
43use crate::storage::schema::Value;
44
45pub use context::{ChunkSource, ContextChunk, RetrievalContext};
46pub use fusion::{ContextFusion, FusionConfig, ResultReranker};
47pub use retriever::{MultiSourceRetriever, RetrievalStrategy};
48pub use unified_adapter::{
49 EdgeDirection, EdgePatternSpec, GraphQueryPattern, MatchSource, MatchedEntity, MetadataQuery,
50 MultiModalQuery, NodePattern, QueryCondition, QueryValue, UnifiedQueryResult,
51 UnifiedQueryStats, UnifiedStoreAdapter,
52};
53
/// Limits and thresholds handed to the retriever on every retrieval call.
///
/// The code that consumes these values lives outside this file, so the per-field
/// notes below are read off the field names — confirm against the retriever
/// before relying on exact semantics.
#[derive(Clone, Debug)]
pub struct RagConfig {
    /// Cap on chunks taken from one source (presumably per vector/graph/table source — TODO confirm).
    pub max_chunks_per_source: usize,
    /// Cap on chunks in the final context across all sources (assumed; verify in the retriever).
    pub max_total_chunks: usize,
    /// Similarity cut-off for vector matches (scale and direction not visible here — TODO confirm).
    pub similarity_threshold: f32,
    /// Traversal depth for graph expansion (presumably in hops — TODO confirm).
    pub graph_depth: u32,
    /// Whether cross-references between entities should be followed during retrieval.
    pub expand_cross_refs: bool,
    /// Relevance floor for keeping a chunk (assumed to filter fused results — TODO confirm).
    pub min_relevance: f32,
}
70
71impl Default for RagConfig {
72 fn default() -> Self {
73 Self {
74 max_chunks_per_source: 10,
75 max_total_chunks: 25,
76 similarity_threshold: 0.8,
77 graph_depth: 2,
78 expand_cross_refs: true,
79 min_relevance: 0.3,
80 }
81 }
82}
83
84pub struct RagEngine {
86 config: RagConfig,
88 retriever: MultiSourceRetriever,
90 analyzer: QueryAnalyzer,
92}
93
94impl RagEngine {
95 pub fn new(
97 graph: Arc<GraphStore>,
98 index: Arc<GraphTableIndex>,
99 vector_store: Arc<VectorStore>,
100 unified_index: Arc<UnifiedIndex>,
101 ) -> Self {
102 Self {
103 config: RagConfig::default(),
104 retriever: MultiSourceRetriever::new(graph, index, vector_store, unified_index),
105 analyzer: QueryAnalyzer::new(),
106 }
107 }
108
109 pub fn with_config(mut self, config: RagConfig) -> Self {
111 self.config = config;
112 self
113 }
114
115 pub fn retrieve(&self, query: &str) -> Result<RetrievalContext, ExecutionError> {
117 let analysis = self.analyzer.analyze(query);
119
120 let context = self.retriever.retrieve(query, &analysis, &self.config)?;
122
123 Ok(context)
124 }
125
126 pub fn retrieve_with_strategy(
128 &self,
129 query: &str,
130 strategy: RetrievalStrategy,
131 ) -> Result<RetrievalContext, ExecutionError> {
132 let analysis = QueryAnalysis {
133 primary_strategy: strategy,
134 ..self.analyzer.analyze(query)
135 };
136
137 self.retriever.retrieve(query, &analysis, &self.config)
138 }
139
140 pub fn retrieve_by_vector(
142 &self,
143 vector: &[f32],
144 collection: &str,
145 k: usize,
146 ) -> Result<RetrievalContext, ExecutionError> {
147 self.retriever
148 .retrieve_by_vector(vector, collection, k, &self.config)
149 }
150
151 pub fn expand_context(
153 &self,
154 entity_id: &str,
155 entity_type: EntityType,
156 depth: u32,
157 ) -> Result<RetrievalContext, ExecutionError> {
158 self.retriever
159 .expand_context(entity_id, entity_type, depth, &self.config)
160 }
161
162 pub fn find_similar(
164 &self,
165 collection: &str,
166 entity_id: u64,
167 k: usize,
168 ) -> Result<Vec<SimilarEntity>, ExecutionError> {
169 self.retriever.find_similar(collection, entity_id, k)
170 }
171}
172
173#[derive(Debug, Clone)]
179pub struct QueryAnalysis {
180 pub primary_strategy: RetrievalStrategy,
182 pub secondary_strategies: Vec<RetrievalStrategy>,
184 pub entity_types: Vec<EntityType>,
186 pub keywords: Vec<String>,
188 pub intent: QueryIntent,
190 pub confidence: f32,
192}
193
/// What the user appears to be asking for, as detected by `QueryAnalyzer`.
#[derive(Clone, Debug, PartialEq)]
pub enum QueryIntent {
    /// The query asks for things resembling something else; served vector-first.
    Similarity,
    /// The query asks how entities connect or can be reached; served graph-first.
    PathFinding,
    /// The query asks to list or filter entities; served with the hybrid strategy.
    Enumeration,
    /// The query names a specific identifier (a CVE id or an IP-like prefix); served graph-first.
    Lookup,
    /// The query asks about impact or effects; served with the hybrid strategy.
    Analysis,
    /// No specific intent was recognised; served with the hybrid strategy.
    General,
}
210
/// Kind of entity a query can refer to.
///
/// Each variant maps to a collection name via [`EntityType::collection_name`];
/// [`EntityType::Unknown`] is the fallback for unrecognised words.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntityType {
    /// A host, IP address, server or machine.
    Host,
    /// A service.
    Service,
    /// A port.
    Port,
    /// A vulnerability or CVE.
    Vulnerability,
    /// A credential or password.
    Credential,
    /// A user or account.
    User,
    /// A certificate (also matched by "ssl" / "tls").
    Certificate,
    /// A domain (also matched by "dns").
    Domain,
    /// A network or subnet.
    Network,
    /// A technology or piece of software.
    Technology,
    /// An endpoint, URL or API.
    Endpoint,
    /// Anything that did not match a known alias.
    Unknown,
}

impl EntityType {
    /// Maps a single word to the entity kind it names.
    ///
    /// Matching is case-insensitive and ignores any non-alphanumeric characters
    /// attached to either end of `s`, so tokens taken straight from a
    /// whitespace-split sentence (`"hosts,"`, `"(servers)"`, `"creds?"`) are
    /// recognised. Singular and plural forms of every alias are accepted.
    ///
    /// Returns [`EntityType::Unknown`] when the word is not a known alias
    /// (including the empty string).
    // Kept as an inherent method because callers rely on the infallible
    // `EntityType::from_str(..) -> EntityType` signature.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        let token = s
            .trim_matches(|c: char| !c.is_alphanumeric())
            .to_lowercase();

        match token.as_str() {
            "host" | "hosts" | "ip" | "ips" | "server" | "servers" | "machine" | "machines" => {
                Self::Host
            }
            "service" | "services" => Self::Service,
            "port" | "ports" => Self::Port,
            "vuln" | "vulns" | "vulnerability" | "vulnerabilities" | "cve" | "cves" => {
                Self::Vulnerability
            }
            "cred" | "creds" | "credential" | "credentials" | "password" | "passwords" => {
                Self::Credential
            }
            "user" | "users" | "account" | "accounts" => Self::User,
            "cert" | "certs" | "certificate" | "certificates" | "ssl" | "tls" => {
                Self::Certificate
            }
            "domain" | "domains" | "dns" => Self::Domain,
            "network" | "networks" | "subnet" | "subnets" => Self::Network,
            "tech" | "technology" | "technologies" | "software" => Self::Technology,
            "endpoint" | "endpoints" | "url" | "urls" | "api" | "apis" => Self::Endpoint,
            _ => Self::Unknown,
        }
    }

    /// Returns the name of the collection associated with this entity kind.
    ///
    /// [`EntityType::Unknown`] maps to `"general"`.
    pub fn collection_name(&self) -> &'static str {
        match self {
            Self::Host => "hosts",
            Self::Service => "services",
            Self::Port => "ports",
            Self::Vulnerability => "vulnerabilities",
            Self::Credential => "credentials",
            Self::User => "users",
            Self::Certificate => "certificates",
            Self::Domain => "domains",
            Self::Network => "networks",
            Self::Technology => "technologies",
            Self::Endpoint => "endpoints",
            Self::Unknown => "general",
        }
    }
}
265
/// Keyword-driven classifier for free-text queries.
///
/// Holds the three keyword lists consulted when deciding a query's intent.
pub struct QueryAnalyzer {
    /// Words and phrases that signal a similarity search.
    similarity_keywords: Vec<&'static str>,
    /// Words and phrases that signal a path or reachability question.
    path_keywords: Vec<&'static str>,
    /// Words that signal a listing or filtering request.
    enum_keywords: Vec<&'static str>,
}
275
276impl QueryAnalyzer {
277 pub fn new() -> Self {
278 Self {
279 similarity_keywords: vec![
280 "similar",
281 "like",
282 "related",
283 "comparable",
284 "equivalent",
285 "matching",
286 "resembling",
287 "analogous",
288 "close to",
289 ],
290 path_keywords: vec![
291 "path",
292 "route",
293 "reach",
294 "connect",
295 "between",
296 "from",
297 "to",
298 "via",
299 "through",
300 "attack path",
301 "lateral",
302 ],
303 enum_keywords: vec![
304 "all", "list", "find", "show", "get", "which", "what", "where", "filter", "having",
305 "with",
306 ],
307 }
308 }
309
310 pub fn analyze(&self, query: &str) -> QueryAnalysis {
312 let query_lower = query.to_lowercase();
313 let words: Vec<&str> = query_lower.split_whitespace().collect();
314
315 let intent = self.detect_intent(&query_lower);
317
318 let entity_types = self.detect_entity_types(&words);
320
321 let keywords = self.extract_keywords(&query_lower);
323
324 let primary_strategy = match intent {
326 QueryIntent::Similarity => RetrievalStrategy::VectorFirst,
327 QueryIntent::PathFinding => RetrievalStrategy::GraphFirst,
328 QueryIntent::Enumeration => RetrievalStrategy::Hybrid,
329 QueryIntent::Lookup => RetrievalStrategy::GraphFirst,
330 QueryIntent::Analysis => RetrievalStrategy::Hybrid,
331 QueryIntent::General => RetrievalStrategy::Hybrid,
332 };
333
334 let mut secondary_strategies = Vec::new();
336 if primary_strategy != RetrievalStrategy::VectorFirst {
337 secondary_strategies.push(RetrievalStrategy::VectorFirst);
338 }
339 if primary_strategy != RetrievalStrategy::GraphFirst {
340 secondary_strategies.push(RetrievalStrategy::GraphFirst);
341 }
342
343 let confidence = if intent != QueryIntent::General {
345 0.8
346 } else if !entity_types.is_empty() {
347 0.6
348 } else {
349 0.4
350 };
351
352 QueryAnalysis {
353 primary_strategy,
354 secondary_strategies,
355 entity_types,
356 keywords,
357 intent,
358 confidence,
359 }
360 }
361
362 fn detect_intent(&self, query: &str) -> QueryIntent {
363 if self.similarity_keywords.iter().any(|k| query.contains(k)) {
365 return QueryIntent::Similarity;
366 }
367
368 if self.path_keywords.iter().any(|k| query.contains(k)) {
370 return QueryIntent::PathFinding;
371 }
372
373 if self.enum_keywords.iter().any(|k| query.contains(k)) {
375 return QueryIntent::Enumeration;
376 }
377
378 if query.contains("cve-") || query.contains("192.") || query.contains("10.") {
380 return QueryIntent::Lookup;
381 }
382
383 if query.contains("impact") || query.contains("affect") || query.contains("analyze") {
385 return QueryIntent::Analysis;
386 }
387
388 QueryIntent::General
389 }
390
391 fn detect_entity_types(&self, words: &[&str]) -> Vec<EntityType> {
392 let mut types = Vec::new();
393 for word in words {
394 let entity_type = EntityType::from_str(word);
395 if entity_type != EntityType::Unknown && !types.contains(&entity_type) {
396 types.push(entity_type);
397 }
398 }
399 types
400 }
401
402 fn extract_keywords(&self, query: &str) -> Vec<String> {
403 let stop_words = [
405 "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
406 "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
407 "can", "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into",
408 "about", "i", "me", "my", "we", "our",
409 ];
410
411 query
412 .split_whitespace()
413 .filter(|w| w.len() > 2)
414 .filter(|w| !stop_words.contains(&w.to_lowercase().as_str()))
415 .map(|w| w.to_string())
416 .collect()
417 }
418}
419
420impl Default for QueryAnalyzer {
421 fn default() -> Self {
422 Self::new()
423 }
424}
425
426#[derive(Debug, Clone)]
432pub struct SimilarEntity {
433 pub id: u64,
435 pub collection: String,
437 pub similarity: f32,
439 pub label: Option<String>,
441 pub properties: HashMap<String, Value>,
443}
444
445impl SimilarEntity {
446 pub fn new(id: u64, collection: &str, similarity: f32) -> Self {
447 Self {
448 id,
449 collection: collection.to_string(),
450 similarity,
451 label: None,
452 properties: HashMap::new(),
453 }
454 }
455
456 pub fn with_label(mut self, label: impl Into<String>) -> Self {
457 self.label = Some(label.into());
458 self
459 }
460
461 pub fn with_property(mut self, key: impl Into<String>, value: Value) -> Self {
462 self.properties.insert(key.into(), value);
463 self
464 }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a freshly constructed analyzer over `query`.
    fn analyze(query: &str) -> QueryAnalysis {
        QueryAnalyzer::new().analyze(query)
    }

    /// Returns `true` when `word` is among the extracted keywords.
    fn has_keyword(analysis: &QueryAnalysis, word: &str) -> bool {
        analysis.keywords.contains(&word.to_string())
    }

    // A similarity keyword selects the vector-first strategy.
    #[test]
    fn test_query_analyzer_similarity_intent() {
        let result = analyze("find similar CVEs to CVE-2024-1234");

        assert_eq!(result.intent, QueryIntent::Similarity);
        assert_eq!(result.primary_strategy, RetrievalStrategy::VectorFirst);
    }

    // Path wording selects the graph-first strategy.
    #[test]
    fn test_query_analyzer_path_intent() {
        let result = analyze("attack path from webserver to database");

        assert_eq!(result.intent, QueryIntent::PathFinding);
        assert_eq!(result.primary_strategy, RetrievalStrategy::GraphFirst);
    }

    // Listing wording selects the hybrid strategy.
    #[test]
    fn test_query_analyzer_enumeration_intent() {
        let result = analyze("list all hosts with port 22 open");

        assert_eq!(result.intent, QueryIntent::Enumeration);
        assert_eq!(result.primary_strategy, RetrievalStrategy::Hybrid);
    }

    // Every entity kind named in the query is reported.
    #[test]
    fn test_entity_type_detection() {
        let result = analyze("show vulnerabilities affecting hosts");

        for expected in [EntityType::Vulnerability, EntityType::Host] {
            assert!(result.entity_types.contains(&expected));
        }
    }

    // Content words are kept; short stop words are dropped.
    #[test]
    fn test_keyword_extraction() {
        let result = analyze("find critical vulnerabilities in production servers");

        assert!(has_keyword(&result, "critical"));
        assert!(has_keyword(&result, "production"));
        assert!(!has_keyword(&result, "in"));
    }
}