Skip to main content

reddb_server/storage/query/executors/
natural.rs

//! Natural Language Query Executor
//!
//! Parses natural language queries, executes them directly against the
//! graph store, and explains how each query was interpreted, including an
//! equivalent RQL query.
//!
//! # Features
//!
//! - Intent classification: find (which also covers "list"), path, count, show, check
//! - Entity extraction: hosts, users, credentials, services/ports, vulnerabilities,
//!   technologies, domains, certificates, networks
//! - Query generation with confidence scoring
//! - Execution explanation for user understanding
12
13use std::sync::Arc;
14
15use crate::storage::engine::graph_store::GraphStore;
16use crate::storage::query::modes::natural::{
17    EntityType, ExtractedEntity, NaturalParser, NaturalQuery, QueryIntent,
18};
19use crate::storage::query::unified::{
20    ExecutionError, MatchedNode, QueryStats, UnifiedRecord, UnifiedResult,
21};
22
23/// Natural language executor with translation explanation
24pub struct NaturalExecutor {
25    graph: Arc<GraphStore>,
26}
27
28impl NaturalExecutor {
29    /// Create a new natural language executor
30    pub fn new(graph: Arc<GraphStore>) -> Self {
31        Self { graph }
32    }
33
34    /// Execute a natural language query and return explanation
35    pub fn execute_with_explanation(
36        &self,
37        query: &str,
38    ) -> Result<(UnifiedResult, String), ExecutionError> {
39        // Parse natural language
40        let parsed = NaturalParser::parse(query).map_err(|e| ExecutionError::new(e.to_string()))?;
41
42        // Generate explanation
43        let explanation = self.generate_explanation(&parsed, query);
44
45        // Execute
46        let result = self.execute_natural(&parsed)?;
47
48        Ok((result, explanation))
49    }
50
51    /// Execute a natural language query
52    pub fn execute(&self, query: &str) -> Result<UnifiedResult, ExecutionError> {
53        let parsed = NaturalParser::parse(query).map_err(|e| ExecutionError::new(e.to_string()))?;
54        self.execute_natural(&parsed)
55    }
56
57    /// Execute a parsed natural language query
58    fn execute_natural(&self, query: &NaturalQuery) -> Result<UnifiedResult, ExecutionError> {
59        let mut stats = QueryStats::default();
60        let mut result = UnifiedResult::empty();
61
62        match query.intent {
63            QueryIntent::Find => {
64                // Find handles both "find" and "list" semantics
65                self.execute_find(query, &mut result, &mut stats)?;
66            }
67            QueryIntent::Path => {
68                self.execute_path(query, &mut result, &mut stats)?;
69            }
70            QueryIntent::Count => {
71                self.execute_count(query, &mut result, &mut stats)?;
72            }
73            QueryIntent::Show => {
74                self.execute_show(query, &mut result, &mut stats)?;
75            }
76            QueryIntent::Check => {
77                self.execute_check(query, &mut result, &mut stats)?;
78            }
79        }
80
81        result.stats = stats;
82        Ok(result)
83    }
84
85    /// Execute FIND intent (also handles LIST semantics)
86    fn execute_find(
87        &self,
88        query: &NaturalQuery,
89        result: &mut UnifiedResult,
90        stats: &mut QueryStats,
91    ) -> Result<(), ExecutionError> {
92        let entity_label = self.primary_entity_label(query);
93
94        for node in self.graph.iter_nodes() {
95            stats.nodes_scanned += 1;
96
97            // Check label match — compare against the legacy enum's string form.
98            let type_matches = match entity_label {
99                Some(label) => node.node_type.as_str() == label,
100                None => true,
101            };
102
103            if !type_matches {
104                continue;
105            }
106
107            // Check entity filters
108            if !self.node_matches_filters(&node, &query.entities) {
109                continue;
110            }
111
112            // Check relationship constraints
113            let mut rel_match = true;
114            for entity in &query.entities {
115                if let Some(ref value) = entity.value {
116                    // Check if node has relationship to this entity
117                    if !self.has_relationship_to(&node.id, value, stats) {
118                        rel_match = false;
119                        break;
120                    }
121                }
122            }
123
124            if rel_match {
125                let mut record = UnifiedRecord::new();
126                record.set_node("_", MatchedNode::from_stored(&node));
127                result.push(record);
128            }
129        }
130
131        // Apply limit if specified
132        if let Some(limit) = query.limit {
133            if result.len() > limit as usize {
134                result.records.truncate(limit as usize);
135            }
136        }
137
138        Ok(())
139    }
140
141    /// Execute PATH intent
142    fn execute_path(
143        &self,
144        query: &NaturalQuery,
145        result: &mut UnifiedResult,
146        stats: &mut QueryStats,
147    ) -> Result<(), ExecutionError> {
148        // Extract source and target from entities
149        let (source, target) = self.extract_path_endpoints(query)?;
150
151        // BFS to find path
152        use crate::storage::query::unified::{GraphPath, MatchedEdge};
153        use std::collections::{HashSet, VecDeque};
154
155        let mut queue: VecDeque<(String, GraphPath)> = VecDeque::new();
156        let mut visited: HashSet<String> = HashSet::new();
157
158        queue.push_back((source.clone(), GraphPath::start(&source)));
159        visited.insert(source.clone());
160
161        let max_hops = query.limit.unwrap_or(10) as usize;
162
163        while let Some((current, path)) = queue.pop_front() {
164            if path.len() > max_hops {
165                continue;
166            }
167
168            if current == target {
169                let mut record = UnifiedRecord::new();
170                record.paths.push(path);
171                result.push(record);
172                break; // Found shortest path
173            }
174
175            for (edge_type, neighbor, weight) in self.graph.outgoing_edges(&current) {
176                stats.edges_scanned += 1;
177
178                if !visited.contains(&neighbor) {
179                    visited.insert(neighbor.clone());
180                    let edge = MatchedEdge::from_tuple(&current, edge_type, &neighbor, weight);
181                    let new_path = path.extend(edge, &neighbor);
182                    queue.push_back((neighbor, new_path));
183                }
184            }
185        }
186
187        if result.is_empty() {
188            return Err(ExecutionError::new(format!(
189                "No path found from {} to {}",
190                source, target
191            )));
192        }
193
194        Ok(())
195    }
196
197    /// Execute COUNT intent
198    fn execute_count(
199        &self,
200        query: &NaturalQuery,
201        result: &mut UnifiedResult,
202        stats: &mut QueryStats,
203    ) -> Result<(), ExecutionError> {
204        let entity_label = self.primary_entity_label(query);
205        let mut count = 0u64;
206
207        for node in self.graph.iter_nodes() {
208            stats.nodes_scanned += 1;
209
210            let type_matches = match entity_label {
211                Some(label) => node.node_type.as_str() == label,
212                None => true,
213            };
214
215            if type_matches && self.node_matches_filters(&node, &query.entities) {
216                count += 1;
217            }
218        }
219
220        let mut record = UnifiedRecord::new();
221        record.set(
222            "count",
223            crate::storage::schema::Value::Integer(count as i64),
224        );
225        result.push(record);
226        result.columns.push("count".to_string());
227
228        Ok(())
229    }
230
231    /// Execute SHOW intent
232    fn execute_show(
233        &self,
234        query: &NaturalQuery,
235        result: &mut UnifiedResult,
236        stats: &mut QueryStats,
237    ) -> Result<(), ExecutionError> {
238        // SHOW is like FIND but includes more details
239        self.execute_find(query, result, stats)?;
240
241        // Add neighbors for context
242        if result.len() == 1 {
243            if let Some(node) = result.records.first().and_then(|r| r.nodes.get("_")) {
244                // Add outgoing connections
245                for (edge_type, target, _) in self.graph.outgoing_edges(&node.id) {
246                    stats.edges_scanned += 1;
247                    if let Some(target_node) = self.graph.get_node(&target) {
248                        let mut record = UnifiedRecord::new();
249                        record.set_node("related", MatchedNode::from_stored(&target_node));
250                        record.set(
251                            "relationship",
252                            crate::storage::schema::Value::text(format!("{:?}", edge_type)),
253                        );
254                        result.push(record);
255                    }
256                }
257            }
258        }
259
260        Ok(())
261    }
262
263    /// Execute CHECK intent - verify if a relationship exists
264    fn execute_check(
265        &self,
266        query: &NaturalQuery,
267        result: &mut UnifiedResult,
268        stats: &mut QueryStats,
269    ) -> Result<(), ExecutionError> {
270        // Check requires two entities with a relationship
271        let (source, target) = self.extract_path_endpoints(query)?;
272
273        // Check if direct connection exists
274        let mut found = false;
275        for (edge_type, neighbor, weight) in self.graph.outgoing_edges(&source) {
276            stats.edges_scanned += 1;
277            if neighbor == target || neighbor.contains(&target) {
278                found = true;
279                // Add the relationship to result
280                let mut record = UnifiedRecord::new();
281                if let Some(src_node) = self.graph.get_node(&source) {
282                    record.set_node("source", MatchedNode::from_stored(&src_node));
283                }
284                if let Some(tgt_node) = self.graph.get_node(&neighbor) {
285                    record.set_node("target", MatchedNode::from_stored(&tgt_node));
286                }
287                record.set(
288                    "relationship",
289                    crate::storage::schema::Value::text(format!("{:?}", edge_type)),
290                );
291                record.set("exists", crate::storage::schema::Value::Boolean(true));
292                record.set(
293                    "weight",
294                    crate::storage::schema::Value::Float(weight as f64),
295                );
296                result.push(record);
297                break;
298            }
299        }
300
301        if !found {
302            // Report that no relationship was found
303            let mut record = UnifiedRecord::new();
304            record.set("exists", crate::storage::schema::Value::Boolean(false));
305            record.set("source", crate::storage::schema::Value::text(source));
306            record.set("target", crate::storage::schema::Value::text(target));
307            result.push(record);
308        }
309
310        result.columns = vec![
311            "source".into(),
312            "target".into(),
313            "relationship".into(),
314            "exists".into(),
315        ];
316        Ok(())
317    }
318
319    /// Get the primary entity label from query — canonical lower-snake-case
320    /// label string corresponding to the first matched entity.
321    fn primary_entity_label(&self, query: &NaturalQuery) -> Option<&'static str> {
322        for entity in &query.entities {
323            match entity.entity_type {
324                EntityType::Host => return Some("host"),
325                EntityType::User => return Some("user"),
326                EntityType::Credential => return Some("credential"),
327                EntityType::Service | EntityType::Port => return Some("service"),
328                EntityType::Vulnerability => return Some("vulnerability"),
329                EntityType::Technology => return Some("technology"),
330                EntityType::Domain => return Some("domain"),
331                EntityType::Certificate => return Some("certificate"),
332                // Network has no canonical legacy label.
333                EntityType::Network => continue,
334            }
335        }
336        None
337    }
338
339    /// Check if node matches entity filters
340    fn node_matches_filters(
341        &self,
342        node: &crate::storage::engine::graph_store::StoredNode,
343        entities: &[ExtractedEntity],
344    ) -> bool {
345        for entity in entities {
346            if let Some(ref value) = entity.value {
347                // Check if value matches node ID or label
348                let matches = node.id.contains(value)
349                    || node.label.to_lowercase().contains(&value.to_lowercase())
350                    || value.to_lowercase().contains(&node.label.to_lowercase());
351                if matches {
352                    return true;
353                }
354            }
355        }
356        // If no values to match, accept all
357        entities.iter().all(|e| e.value.is_none())
358    }
359
360    /// Check if node has relationship to target
361    fn has_relationship_to(&self, node_id: &str, target: &str, stats: &mut QueryStats) -> bool {
362        for (_, neighbor, _) in self.graph.outgoing_edges(node_id) {
363            stats.edges_scanned += 1;
364            if neighbor.contains(target) {
365                return true;
366            }
367            // Check neighbor's label
368            if let Some(neighbor_node) = self.graph.get_node(&neighbor) {
369                if neighbor_node
370                    .label
371                    .to_lowercase()
372                    .contains(&target.to_lowercase())
373                {
374                    return true;
375                }
376            }
377        }
378        false
379    }
380
381    /// Extract path endpoints from query
382    fn extract_path_endpoints(
383        &self,
384        query: &NaturalQuery,
385    ) -> Result<(String, String), ExecutionError> {
386        // Look for "from X to Y" pattern in entities
387        let mut source = None;
388        let mut target = None;
389
390        for entity in &query.entities {
391            if let Some(ref value) = entity.value {
392                // Find nodes matching this value
393                for node in self.graph.iter_nodes() {
394                    if node.id.contains(value)
395                        || node.label.to_lowercase().contains(&value.to_lowercase())
396                    {
397                        if source.is_none() {
398                            source = Some(node.id.clone());
399                        } else if target.is_none() && Some(&node.id) != source.as_ref() {
400                            target = Some(node.id.clone());
401                        }
402                    }
403                }
404            }
405        }
406
407        match (source, target) {
408            (Some(s), Some(t)) => Ok((s, t)),
409            (Some(s), None) => Err(ExecutionError::new(format!(
410                "Path query needs a target. Found source: {}",
411                s
412            ))),
413            _ => Err(ExecutionError::new(
414                "Path query needs source and target. Try: 'path from host X to host Y'",
415            )),
416        }
417    }
418
419    /// Generate explanation of query translation
420    fn generate_explanation(&self, query: &NaturalQuery, original: &str) -> String {
421        let mut explanation = Vec::new();
422
423        explanation.push(format!("Query: \"{}\"", original));
424        explanation.push(format!("Intent: {:?}", query.intent));
425
426        if !query.entities.is_empty() {
427            let entities: Vec<String> = query
428                .entities
429                .iter()
430                .map(|e| {
431                    if let Some(ref val) = e.value {
432                        format!("{:?}({})", e.entity_type, val)
433                    } else {
434                        format!("{:?}", e.entity_type)
435                    }
436                })
437                .collect();
438            explanation.push(format!("Entities: {}", entities.join(", ")));
439        }
440
441        // Generate equivalent RQL
442        let rql = self.to_rql(query);
443        explanation.push(format!("Equivalent RQL: {}", rql));
444
445        explanation.join("\n")
446    }
447
448    /// Convert natural query to RQL string
449    fn to_rql(&self, query: &NaturalQuery) -> String {
450        match query.intent {
451            QueryIntent::Find => {
452                let node_type = self.primary_entity_label(query).unwrap_or("*");
453
454                let filters: Vec<String> = query
455                    .entities
456                    .iter()
457                    .filter_map(|e| {
458                        e.value
459                            .as_ref()
460                            .map(|v| format!("n.label CONTAINS '{}'", v))
461                    })
462                    .collect();
463
464                if filters.is_empty() {
465                    format!("MATCH (n:{}) RETURN n", node_type)
466                } else {
467                    format!(
468                        "MATCH (n:{}) WHERE {} RETURN n",
469                        node_type,
470                        filters.join(" AND ")
471                    )
472                }
473            }
474            QueryIntent::Path => {
475                let endpoints: Vec<&str> = query
476                    .entities
477                    .iter()
478                    .filter_map(|e| e.value.as_deref())
479                    .collect();
480                if endpoints.len() >= 2 {
481                    format!("PATH FROM '{}' TO '{}'", endpoints[0], endpoints[1])
482                } else {
483                    "PATH FROM ? TO ?".to_string()
484                }
485            }
486            QueryIntent::Count => {
487                let node_type = self.primary_entity_label(query).unwrap_or("*");
488                format!("MATCH (n:{}) RETURN COUNT(n)", node_type)
489            }
490            QueryIntent::Show => {
491                let filters: Vec<String> = query
492                    .entities
493                    .iter()
494                    .filter_map(|e| e.value.as_ref().map(|v| format!("n.id = '{}'", v)))
495                    .collect();
496                if filters.is_empty() {
497                    "MATCH (n) RETURN n".to_string()
498                } else {
499                    format!(
500                        "MATCH (n) WHERE {} RETURN n, n.neighbors",
501                        filters.first().unwrap()
502                    )
503                }
504            }
505            QueryIntent::Check => {
506                let endpoints: Vec<&str> = query
507                    .entities
508                    .iter()
509                    .filter_map(|e| e.value.as_deref())
510                    .collect();
511                if endpoints.len() >= 2 {
512                    format!(
513                        "MATCH (a)-[r]->(b) WHERE a.id = '{}' AND b.id = '{}' RETURN EXISTS(r)",
514                        endpoints[0], endpoints[1]
515                    )
516                } else {
517                    "MATCH (a)-[r]->(b) RETURN EXISTS(r)".to_string()
518                }
519            }
520        }
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::query::test_support::service_graph_with_user;

    /// Shared fixture: the service graph with a user node from `test_support`.
    fn create_test_graph() -> Arc<GraphStore> {
        service_graph_with_user()
    }

    #[test]
    fn test_list_hosts() {
        let graph = create_test_graph();
        let executor = NaturalExecutor::new(graph);

        let (result, explanation) = executor.execute_with_explanation("list all hosts").unwrap();
        assert_eq!(result.records.len(), 2);
        // "list" maps to Find intent in NaturalParser
        assert!(explanation.contains("Intent: Find"));
    }

    #[test]
    fn test_find_services() {
        let graph = create_test_graph();
        let executor = NaturalExecutor::new(graph);

        let (result, explanation) = executor.execute_with_explanation("find services").unwrap();
        assert_eq!(result.records.len(), 2);
        assert!(explanation.contains("Service"));
    }

    #[test]
    fn test_count_hosts() {
        let graph = create_test_graph();
        let executor = NaturalExecutor::new(graph);

        let (result, _) = executor.execute_with_explanation("how many hosts").unwrap();
        assert_eq!(result.records.len(), 1);
        let count = result.records[0].get("count");
        assert!(count.is_some());
    }

    #[test]
    fn test_explanation_includes_rql() {
        let graph = create_test_graph();
        let executor = NaturalExecutor::new(graph);

        let (_, explanation) = executor
            .execute_with_explanation("find hosts with SSH")
            .unwrap();
        assert!(explanation.contains("Equivalent RQL:"));
        assert!(explanation.contains("MATCH"));
    }

    #[test]
    fn test_path_query() {
        let graph = create_test_graph();
        let executor = NaturalExecutor::new(graph);

        let (result, explanation) = executor
            .execute_with_explanation("path from host 10.0.0.1 to host 10.0.0.2")
            .unwrap();
        assert!(!result.is_empty());
        assert!(explanation.contains("Path"));
    }
}