pub mod context;
pub mod fusion;
pub mod retriever;
pub mod unified_adapter;
use std::collections::HashMap;
use std::sync::Arc;
use crate::storage::engine::graph_store::GraphStore;
use crate::storage::engine::graph_table_index::GraphTableIndex;
use crate::storage::engine::unified_index::UnifiedIndex;
use crate::storage::engine::vector_store::VectorStore;
use crate::storage::query::unified::ExecutionError;
use crate::storage::schema::Value;
pub use context::{ChunkSource, ContextChunk, RetrievalContext};
pub use fusion::{ContextFusion, FusionConfig, ResultReranker};
pub use retriever::{MultiSourceRetriever, RetrievalStrategy};
pub use unified_adapter::{
EdgeDirection, EdgePatternSpec, GraphQueryPattern, MatchSource, MatchedEntity, MetadataQuery,
MultiModalQuery, NodePattern, QueryCondition, QueryValue, UnifiedQueryResult,
UnifiedQueryStats, UnifiedStoreAdapter,
};
/// Tuning knobs that `RagEngine` passes to the retriever on every retrieval call.
///
/// NOTE(review): how each field is enforced lives in `MultiSourceRetriever`, which
/// is not visible here; the field descriptions below follow the field names.
#[derive(Clone, Debug)]
pub struct RagConfig {
    /// Cap on the number of chunks taken from any single source.
    pub max_chunks_per_source: usize,
    /// Cap on the number of chunks in the final context.
    pub max_total_chunks: usize,
    /// Similarity cut-off for vector matches (presumably in `[0, 1]` — confirm).
    pub similarity_threshold: f32,
    /// How far graph expansion may go (presumably in hops — confirm).
    pub graph_depth: u32,
    /// Whether cross-references should be followed while building context.
    pub expand_cross_refs: bool,
    /// Minimum relevance a chunk needs to be kept.
    pub min_relevance: f32,
}

impl Default for RagConfig {
    /// Returns the stock tuning: 10 chunks per source, 25 chunks overall,
    /// similarity threshold 0.8, graph depth 2, cross-reference expansion
    /// enabled and minimum relevance 0.3.
    fn default() -> Self {
        const MAX_PER_SOURCE: usize = 10;
        const MAX_TOTAL: usize = 25;
        const SIMILARITY_THRESHOLD: f32 = 0.8;
        const GRAPH_DEPTH: u32 = 2;
        const MIN_RELEVANCE: f32 = 0.3;

        RagConfig {
            max_chunks_per_source: MAX_PER_SOURCE,
            max_total_chunks: MAX_TOTAL,
            similarity_threshold: SIMILARITY_THRESHOLD,
            graph_depth: GRAPH_DEPTH,
            expand_cross_refs: true,
            min_relevance: MIN_RELEVANCE,
        }
    }
}
/// Retrieval-augmented-generation front end.
///
/// Free-text queries are first classified by a [`QueryAnalyzer`]. The resulting
/// [`QueryAnalysis`] and the engine's [`RagConfig`] are then handed to a
/// [`MultiSourceRetriever`], to which every retrieval call delegates.
pub struct RagEngine {
    config: RagConfig,
    retriever: MultiSourceRetriever,
    analyzer: QueryAnalyzer,
}

impl RagEngine {
    /// Builds an engine over the given stores, using [`RagConfig::default`].
    pub fn new(
        graph: Arc<GraphStore>,
        index: Arc<GraphTableIndex>,
        vector_store: Arc<VectorStore>,
        unified_index: Arc<UnifiedIndex>,
    ) -> Self {
        Self {
            config: RagConfig::default(),
            retriever: MultiSourceRetriever::new(graph, index, vector_store, unified_index),
            analyzer: QueryAnalyzer::new(),
        }
    }

    /// Replaces the engine's configuration (builder style).
    pub fn with_config(mut self, config: RagConfig) -> Self {
        self.config = config;
        self
    }

    /// Analyzes `query` and retrieves context using the strategies the analyzer chose.
    ///
    /// # Errors
    /// Propagates any [`ExecutionError`] returned by the retriever.
    pub fn retrieve(&self, query: &str) -> Result<RetrievalContext, ExecutionError> {
        let analysis = self.analyzer.analyze(query);
        self.retriever.retrieve(query, &analysis, &self.config)
    }

    /// Like [`RagEngine::retrieve`], but forces `strategy` as the primary strategy.
    ///
    /// The rest of the analysis (intent, entity types, keywords, confidence) still
    /// comes from the analyzer. The forced strategy is removed from the secondary
    /// strategies so it is not listed twice.
    ///
    /// # Errors
    /// Propagates any [`ExecutionError`] returned by the retriever.
    pub fn retrieve_with_strategy(
        &self,
        query: &str,
        strategy: RetrievalStrategy,
    ) -> Result<RetrievalContext, ExecutionError> {
        let mut analysis = self.analyzer.analyze(query);
        // `analyze` never lists its own primary among the secondaries. Keep that
        // invariant when the caller overrides the primary; otherwise e.g. forcing
        // `VectorFirst` on a Hybrid analysis would leave `VectorFirst` in both places.
        analysis.secondary_strategies.retain(|s| *s != strategy);
        analysis.primary_strategy = strategy;
        self.retriever.retrieve(query, &analysis, &self.config)
    }

    /// Retrieves context for a raw embedding `vector` from `collection`.
    ///
    /// `k` is forwarded to the retriever (presumably the number of nearest
    /// neighbours requested — confirm in `MultiSourceRetriever`).
    ///
    /// # Errors
    /// Propagates any [`ExecutionError`] returned by the retriever.
    pub fn retrieve_by_vector(
        &self,
        vector: &[f32],
        collection: &str,
        k: usize,
    ) -> Result<RetrievalContext, ExecutionError> {
        self.retriever
            .retrieve_by_vector(vector, collection, k, &self.config)
    }

    /// Builds context around the entity `entity_id` of type `entity_type`,
    /// expanding up to `depth` (interpreted by the retriever).
    ///
    /// # Errors
    /// Propagates any [`ExecutionError`] returned by the retriever.
    pub fn expand_context(
        &self,
        entity_id: &str,
        entity_type: EntityType,
        depth: u32,
    ) -> Result<RetrievalContext, ExecutionError> {
        self.retriever
            .expand_context(entity_id, entity_type, depth, &self.config)
    }

    /// Finds up to `k` entities in `collection` similar to `entity_id`.
    ///
    /// Unlike the other retrieval calls, this does not pass the engine's
    /// [`RagConfig`] to the retriever.
    ///
    /// # Errors
    /// Propagates any [`ExecutionError`] returned by the retriever.
    pub fn find_similar(
        &self,
        collection: &str,
        entity_id: u64,
        k: usize,
    ) -> Result<Vec<SimilarEntity>, ExecutionError> {
        self.retriever.find_similar(collection, entity_id, k)
    }
}
/// Outcome of [`QueryAnalyzer::analyze`]: how a query should be retrieved.
#[derive(Clone, Debug)]
pub struct QueryAnalysis {
    /// Strategy to try first; `analyze` derives it from `intent`.
    pub primary_strategy: RetrievalStrategy,
    /// Further strategies besides the primary one.
    pub secondary_strategies: Vec<RetrievalStrategy>,
    /// Entity types mentioned in the query, in first-mention order, without duplicates.
    pub entity_types: Vec<EntityType>,
    /// Query terms longer than two characters that are not stop words.
    pub keywords: Vec<String>,
    /// Coarse intent detected from the query wording.
    pub intent: QueryIntent,
    /// Heuristic confidence; `analyze` produces 0.8, 0.6 or 0.4.
    pub confidence: f32,
}
/// Coarse intent inferred from a query's wording.
///
/// `QueryAnalyzer::analyze` maps each intent to a primary retrieval strategy.
/// The enum is fieldless, so it is `Copy` and can be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryIntent {
    /// The query asks for entities similar to something.
    Similarity,
    /// The query asks for paths/connections between entities.
    PathFinding,
    /// The query asks to list or filter entities.
    Enumeration,
    /// The query names a specific entity (e.g. a CVE id or an IP address).
    Lookup,
    /// The query asks about impact or effects.
    Analysis,
    /// No more specific intent was recognised.
    General,
}
/// Kind of entity a query can refer to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntityType {
    /// Hosts, IPs, servers, machines.
    Host,
    /// Services.
    Service,
    /// Ports.
    Port,
    /// Vulnerabilities / CVEs.
    Vulnerability,
    /// Credentials / passwords.
    Credential,
    /// Users / accounts.
    User,
    /// Certificates (also matched by "ssl" / "tls").
    Certificate,
    /// Domains / DNS.
    Domain,
    /// Networks / subnets.
    Network,
    /// Technologies / software.
    Technology,
    /// Endpoints, URLs, APIs.
    Endpoint,
    /// Fallback for words that name no known entity type.
    Unknown,
}

impl EntityType {
    /// Maps a word, case-insensitively, to an entity type.
    ///
    /// Accepts singular and plural forms plus a few synonyms (e.g. `"ips"`,
    /// `"cves"`, `"subnet"`, `"url"`). Never fails: unrecognised input yields
    /// [`EntityType::Unknown`]. The same mapping is available through the
    /// standard `FromStr` trait, so `"hosts".parse::<EntityType>()` also works.
    // Kept as an inherent, infallible function for existing callers; the
    // `FromStr` impl below delegates to it.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        match s.to_lowercase().as_str() {
            "host" | "hosts" | "ip" | "ips" | "server" | "servers" | "machine" => Self::Host,
            "service" | "services" => Self::Service,
            "port" | "ports" => Self::Port,
            "vuln" | "vulnerability" | "vulnerabilities" | "cve" | "cves" => Self::Vulnerability,
            "cred" | "credential" | "credentials" | "password" | "passwords" => Self::Credential,
            "user" | "users" | "account" | "accounts" => Self::User,
            "cert" | "certificate" | "certificates" | "ssl" | "tls" => Self::Certificate,
            "domain" | "domains" | "dns" => Self::Domain,
            "network" | "networks" | "subnet" | "subnets" => Self::Network,
            "tech" | "technology" | "technologies" | "software" => Self::Technology,
            "endpoint" | "endpoints" | "url" | "urls" | "api" => Self::Endpoint,
            _ => Self::Unknown,
        }
    }

    /// Collection name associated with this entity type.
    ///
    /// The name is the lowercase plural of the type; `Unknown` maps to `"general"`.
    pub fn collection_name(&self) -> &'static str {
        match self {
            Self::Host => "hosts",
            Self::Service => "services",
            Self::Port => "ports",
            Self::Vulnerability => "vulnerabilities",
            Self::Credential => "credentials",
            Self::User => "users",
            Self::Certificate => "certificates",
            Self::Domain => "domains",
            Self::Network => "networks",
            Self::Technology => "technologies",
            Self::Endpoint => "endpoints",
            Self::Unknown => "general",
        }
    }
}

impl std::str::FromStr for EntityType {
    /// Parsing never fails; unrecognised names become [`EntityType::Unknown`].
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Resolves to the inherent `EntityType::from_str` (inherent items win).
        Ok(EntityType::from_str(s))
    }
}
/// Heuristic, keyword-based classifier for free-text retrieval queries.
///
/// Holds the keyword lists consulted when deciding a query's [`QueryIntent`].
#[derive(Debug, Clone)]
pub struct QueryAnalyzer {
    /// Words/phrases signalling a similarity search.
    similarity_keywords: Vec<&'static str>,
    /// Words/phrases signalling path finding between entities.
    path_keywords: Vec<&'static str>,
    /// Words signalling an enumeration / listing request.
    enum_keywords: Vec<&'static str>,
}
impl QueryAnalyzer {
    /// Creates an analyzer loaded with the built-in English keyword lists.
    pub fn new() -> Self {
        Self {
            similarity_keywords: vec![
                "similar",
                "like",
                "related",
                "comparable",
                "equivalent",
                "matching",
                "resembling",
                "analogous",
                "close to",
            ],
            path_keywords: vec![
                "path",
                "route",
                "reach",
                "connect",
                "between",
                "from",
                "to",
                "via",
                "through",
                "attack path",
                "lateral",
            ],
            enum_keywords: vec![
                "all", "list", "find", "show", "get", "which", "what", "where", "filter", "having",
                "with",
            ],
        }
    }

    /// Classifies `query` and picks retrieval strategies for it.
    ///
    /// The query is lowercased and split into punctuation-trimmed word tokens.
    /// Intent, mentioned entity types and keywords are all derived from those
    /// tokens. The primary strategy follows from the intent. The secondary
    /// strategies are `VectorFirst` and/or `GraphFirst`, whichever is not the
    /// primary. `confidence` is 0.8 when a specific intent was found, 0.6 when
    /// only entity types were recognised, and 0.4 otherwise.
    pub fn analyze(&self, query: &str) -> QueryAnalysis {
        let query_lower = query.to_lowercase();
        let tokens = Self::tokenize(&query_lower);
        let intent = self.detect_intent(&tokens);
        let entity_types = self.detect_entity_types(&tokens);
        let keywords = self.extract_keywords(&tokens);
        let primary_strategy = match intent {
            QueryIntent::Similarity => RetrievalStrategy::VectorFirst,
            QueryIntent::PathFinding => RetrievalStrategy::GraphFirst,
            QueryIntent::Enumeration => RetrievalStrategy::Hybrid,
            QueryIntent::Lookup => RetrievalStrategy::GraphFirst,
            QueryIntent::Analysis => RetrievalStrategy::Hybrid,
            QueryIntent::General => RetrievalStrategy::Hybrid,
        };
        // The primary strategy is never repeated among the secondaries.
        let mut secondary_strategies = Vec::new();
        if primary_strategy != RetrievalStrategy::VectorFirst {
            secondary_strategies.push(RetrievalStrategy::VectorFirst);
        }
        if primary_strategy != RetrievalStrategy::GraphFirst {
            secondary_strategies.push(RetrievalStrategy::GraphFirst);
        }
        let confidence = if intent != QueryIntent::General {
            0.8
        } else if !entity_types.is_empty() {
            0.6
        } else {
            0.4
        };
        QueryAnalysis {
            primary_strategy,
            secondary_strategies,
            entity_types,
            keywords,
            intent,
            confidence,
        }
    }

    /// Splits an already-lowercased query into word tokens.
    ///
    /// Non-alphanumeric characters are trimmed from both ends of each word, so
    /// `hosts?` or `(cve-2024-1234),` still match. Interior punctuation (`-`, `.`,
    /// `/`) is kept so CVE ids, IP addresses and CIDR ranges stay intact.
    fn tokenize(query: &str) -> Vec<&str> {
        query
            .split_whitespace()
            .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|word| !word.is_empty())
            .collect()
    }

    /// Returns true if `keyword` occurs in `tokens` as a word rather than a mere substring.
    ///
    /// Multi-word keywords (e.g. `"attack path"`) must appear as consecutive tokens.
    /// Single-word keywords of four or more characters also match tokens that start
    /// with them, so simple inflections ("connected", "routes", "affecting") are
    /// recognised. Shorter keywords ("to", "all", "via") must match a token exactly,
    /// because substring matching on them fires inside unrelated words such as
    /// "tomcat", "protocol" or "firewall".
    fn matches_keyword(tokens: &[&str], keyword: &str) -> bool {
        let parts: Vec<&str> = keyword.split_whitespace().collect();
        if parts.len() > 1 {
            return tokens
                .windows(parts.len())
                .any(|window| window == parts.as_slice());
        }
        tokens
            .iter()
            .any(|token| *token == keyword || (keyword.len() >= 4 && token.starts_with(keyword)))
    }

    /// Returns true if any of `keywords` matches `tokens` (see `matches_keyword`).
    fn matches_any(tokens: &[&str], keywords: &[&str]) -> bool {
        keywords
            .iter()
            .any(|keyword| Self::matches_keyword(tokens, keyword))
    }

    /// Returns true if `token` is an IPv4 address, optionally with a `/prefix` suffix.
    fn looks_like_ipv4(token: &str) -> bool {
        let address = token.split('/').next().unwrap_or(token);
        address.parse::<std::net::Ipv4Addr>().is_ok()
    }

    /// Picks the first matching intent in fixed priority order.
    ///
    /// The order is: similarity, path finding, enumeration, lookup (CVE id or IPv4
    /// address), analysis. Returns `General` if none match.
    // NOTE(review): the enumeration list contains very common words ("what",
    // "with", "get"), so lookup/analysis queries phrased with them are classified
    // as Enumeration.
    fn detect_intent(&self, tokens: &[&str]) -> QueryIntent {
        if Self::matches_any(tokens, &self.similarity_keywords) {
            return QueryIntent::Similarity;
        }
        if Self::matches_any(tokens, &self.path_keywords) {
            return QueryIntent::PathFinding;
        }
        if Self::matches_any(tokens, &self.enum_keywords) {
            return QueryIntent::Enumeration;
        }
        if tokens
            .iter()
            .any(|token| token.contains("cve-") || Self::looks_like_ipv4(token))
        {
            return QueryIntent::Lookup;
        }
        if Self::matches_any(tokens, &["impact", "affect", "analyze"]) {
            return QueryIntent::Analysis;
        }
        QueryIntent::General
    }

    /// Collects the known entity types named by `tokens`.
    ///
    /// Types are kept in first-mention order, without duplicates; unknown words
    /// are ignored.
    fn detect_entity_types(&self, tokens: &[&str]) -> Vec<EntityType> {
        let mut types = Vec::new();
        for token in tokens {
            let entity_type = EntityType::from_str(token);
            if entity_type != EntityType::Unknown && !types.contains(&entity_type) {
                types.push(entity_type);
            }
        }
        types
    }

    /// Returns the tokens longer than two characters that are not stop words.
    ///
    /// Tokens are expected to be lowercase already (see `analyze`).
    fn extract_keywords(&self, tokens: &[&str]) -> Vec<String> {
        const STOP_WORDS: &[&str] = &[
            "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
            "can", "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into",
            "about", "i", "me", "my", "we", "our",
        ];
        tokens
            .iter()
            .copied()
            // Count characters, not bytes, so short non-ASCII words are filtered consistently.
            .filter(|token| token.chars().count() > 2)
            .filter(|token| !STOP_WORDS.contains(token))
            .map(String::from)
            .collect()
    }
}
impl Default for QueryAnalyzer {
fn default() -> Self {
Self::new()
}
}
/// An entity returned by a similarity search.
///
/// It carries an optional label and arbitrary properties, both attached through
/// the builder-style `with_*` methods.
#[derive(Debug, Clone)]
pub struct SimilarEntity {
    pub id: u64,
    pub collection: String,
    pub similarity: f32,
    pub label: Option<String>,
    pub properties: HashMap<String, Value>,
}

impl SimilarEntity {
    /// Creates an entity with no label and no properties.
    pub fn new(id: u64, collection: &str, similarity: f32) -> Self {
        let collection = String::from(collection);
        SimilarEntity {
            id,
            similarity,
            collection,
            properties: HashMap::new(),
            label: None,
        }
    }

    /// Sets the label, replacing any previous one.
    pub fn with_label(self, label: impl Into<String>) -> Self {
        SimilarEntity {
            label: Some(label.into()),
            ..self
        }
    }

    /// Adds a property; an existing value under the same key is overwritten.
    pub fn with_property(self, key: impl Into<String>, value: Value) -> Self {
        let mut entity = self;
        entity.properties.insert(key.into(), value);
        entity
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a freshly constructed analyzer over `query`.
    fn analyze(query: &str) -> QueryAnalysis {
        QueryAnalyzer::new().analyze(query)
    }

    #[test]
    fn test_query_analyzer_similarity_intent() {
        let result = analyze("find similar CVEs to CVE-2024-1234");
        assert_eq!(result.intent, QueryIntent::Similarity);
        assert_eq!(result.primary_strategy, RetrievalStrategy::VectorFirst);
    }

    #[test]
    fn test_query_analyzer_path_intent() {
        let result = analyze("attack path from webserver to database");
        assert_eq!(result.intent, QueryIntent::PathFinding);
        assert_eq!(result.primary_strategy, RetrievalStrategy::GraphFirst);
    }

    #[test]
    fn test_query_analyzer_enumeration_intent() {
        let result = analyze("list all hosts with port 22 open");
        assert_eq!(result.intent, QueryIntent::Enumeration);
        assert_eq!(result.primary_strategy, RetrievalStrategy::Hybrid);
    }

    #[test]
    fn test_entity_type_detection() {
        let result = analyze("show vulnerabilities affecting hosts");
        let found = &result.entity_types;
        assert!(found.contains(&EntityType::Vulnerability));
        assert!(found.contains(&EntityType::Host));
    }

    #[test]
    fn test_keyword_extraction() {
        let result = analyze("find critical vulnerabilities in production servers");
        let has = |word: &str| result.keywords.iter().any(|k| k == word);
        assert!(has("critical"));
        assert!(has("production"));
        assert!(!has("in"));
    }
}