1use std::sync::Arc;
14
15use crate::storage::engine::graph_store::GraphStore;
16use crate::storage::query::modes::natural::{
17 EntityType, ExtractedEntity, NaturalParser, NaturalQuery, QueryIntent,
18};
19use crate::storage::query::unified::{
20 ExecutionError, MatchedNode, QueryStats, UnifiedRecord, UnifiedResult,
21};
22
23pub struct NaturalExecutor {
25 graph: Arc<GraphStore>,
26}
27
28impl NaturalExecutor {
29 pub fn new(graph: Arc<GraphStore>) -> Self {
31 Self { graph }
32 }
33
34 pub fn execute_with_explanation(
36 &self,
37 query: &str,
38 ) -> Result<(UnifiedResult, String), ExecutionError> {
39 let parsed = NaturalParser::parse(query).map_err(|e| ExecutionError::new(e.to_string()))?;
41
42 let explanation = self.generate_explanation(&parsed, query);
44
45 let result = self.execute_natural(&parsed)?;
47
48 Ok((result, explanation))
49 }
50
51 pub fn execute(&self, query: &str) -> Result<UnifiedResult, ExecutionError> {
53 let parsed = NaturalParser::parse(query).map_err(|e| ExecutionError::new(e.to_string()))?;
54 self.execute_natural(&parsed)
55 }
56
57 fn execute_natural(&self, query: &NaturalQuery) -> Result<UnifiedResult, ExecutionError> {
59 let mut stats = QueryStats::default();
60 let mut result = UnifiedResult::empty();
61
62 match query.intent {
63 QueryIntent::Find => {
64 self.execute_find(query, &mut result, &mut stats)?;
66 }
67 QueryIntent::Path => {
68 self.execute_path(query, &mut result, &mut stats)?;
69 }
70 QueryIntent::Count => {
71 self.execute_count(query, &mut result, &mut stats)?;
72 }
73 QueryIntent::Show => {
74 self.execute_show(query, &mut result, &mut stats)?;
75 }
76 QueryIntent::Check => {
77 self.execute_check(query, &mut result, &mut stats)?;
78 }
79 }
80
81 result.stats = stats;
82 Ok(result)
83 }
84
85 fn execute_find(
87 &self,
88 query: &NaturalQuery,
89 result: &mut UnifiedResult,
90 stats: &mut QueryStats,
91 ) -> Result<(), ExecutionError> {
92 let entity_label = self.primary_entity_label(query);
93
94 for node in self.graph.iter_nodes() {
95 stats.nodes_scanned += 1;
96
97 let type_matches = match entity_label {
99 Some(label) => node.node_type.as_str() == label,
100 None => true,
101 };
102
103 if !type_matches {
104 continue;
105 }
106
107 if !self.node_matches_filters(&node, &query.entities) {
109 continue;
110 }
111
112 let mut rel_match = true;
114 for entity in &query.entities {
115 if let Some(ref value) = entity.value {
116 if !self.has_relationship_to(&node.id, value, stats) {
118 rel_match = false;
119 break;
120 }
121 }
122 }
123
124 if rel_match {
125 let mut record = UnifiedRecord::new();
126 record.set_node("_", MatchedNode::from_stored(&node));
127 result.push(record);
128 }
129 }
130
131 if let Some(limit) = query.limit {
133 if result.len() > limit as usize {
134 result.records.truncate(limit as usize);
135 }
136 }
137
138 Ok(())
139 }
140
141 fn execute_path(
143 &self,
144 query: &NaturalQuery,
145 result: &mut UnifiedResult,
146 stats: &mut QueryStats,
147 ) -> Result<(), ExecutionError> {
148 let (source, target) = self.extract_path_endpoints(query)?;
150
151 use crate::storage::query::unified::{GraphPath, MatchedEdge};
153 use std::collections::{HashSet, VecDeque};
154
155 let mut queue: VecDeque<(String, GraphPath)> = VecDeque::new();
156 let mut visited: HashSet<String> = HashSet::new();
157
158 queue.push_back((source.clone(), GraphPath::start(&source)));
159 visited.insert(source.clone());
160
161 let max_hops = query.limit.unwrap_or(10) as usize;
162
163 while let Some((current, path)) = queue.pop_front() {
164 if path.len() > max_hops {
165 continue;
166 }
167
168 if current == target {
169 let mut record = UnifiedRecord::new();
170 record.paths.push(path);
171 result.push(record);
172 break; }
174
175 for (edge_type, neighbor, weight) in self.graph.outgoing_edges(¤t) {
176 stats.edges_scanned += 1;
177
178 if !visited.contains(&neighbor) {
179 visited.insert(neighbor.clone());
180 let edge = MatchedEdge::from_tuple(¤t, edge_type, &neighbor, weight);
181 let new_path = path.extend(edge, &neighbor);
182 queue.push_back((neighbor, new_path));
183 }
184 }
185 }
186
187 if result.is_empty() {
188 return Err(ExecutionError::new(format!(
189 "No path found from {} to {}",
190 source, target
191 )));
192 }
193
194 Ok(())
195 }
196
197 fn execute_count(
199 &self,
200 query: &NaturalQuery,
201 result: &mut UnifiedResult,
202 stats: &mut QueryStats,
203 ) -> Result<(), ExecutionError> {
204 let entity_label = self.primary_entity_label(query);
205 let mut count = 0u64;
206
207 for node in self.graph.iter_nodes() {
208 stats.nodes_scanned += 1;
209
210 let type_matches = match entity_label {
211 Some(label) => node.node_type.as_str() == label,
212 None => true,
213 };
214
215 if type_matches && self.node_matches_filters(&node, &query.entities) {
216 count += 1;
217 }
218 }
219
220 let mut record = UnifiedRecord::new();
221 record.set(
222 "count",
223 crate::storage::schema::Value::Integer(count as i64),
224 );
225 result.push(record);
226 result.columns.push("count".to_string());
227
228 Ok(())
229 }
230
231 fn execute_show(
233 &self,
234 query: &NaturalQuery,
235 result: &mut UnifiedResult,
236 stats: &mut QueryStats,
237 ) -> Result<(), ExecutionError> {
238 self.execute_find(query, result, stats)?;
240
241 if result.len() == 1 {
243 if let Some(node) = result.records.first().and_then(|r| r.nodes.get("_")) {
244 for (edge_type, target, _) in self.graph.outgoing_edges(&node.id) {
246 stats.edges_scanned += 1;
247 if let Some(target_node) = self.graph.get_node(&target) {
248 let mut record = UnifiedRecord::new();
249 record.set_node("related", MatchedNode::from_stored(&target_node));
250 record.set(
251 "relationship",
252 crate::storage::schema::Value::text(format!("{:?}", edge_type)),
253 );
254 result.push(record);
255 }
256 }
257 }
258 }
259
260 Ok(())
261 }
262
263 fn execute_check(
265 &self,
266 query: &NaturalQuery,
267 result: &mut UnifiedResult,
268 stats: &mut QueryStats,
269 ) -> Result<(), ExecutionError> {
270 let (source, target) = self.extract_path_endpoints(query)?;
272
273 let mut found = false;
275 for (edge_type, neighbor, weight) in self.graph.outgoing_edges(&source) {
276 stats.edges_scanned += 1;
277 if neighbor == target || neighbor.contains(&target) {
278 found = true;
279 let mut record = UnifiedRecord::new();
281 if let Some(src_node) = self.graph.get_node(&source) {
282 record.set_node("source", MatchedNode::from_stored(&src_node));
283 }
284 if let Some(tgt_node) = self.graph.get_node(&neighbor) {
285 record.set_node("target", MatchedNode::from_stored(&tgt_node));
286 }
287 record.set(
288 "relationship",
289 crate::storage::schema::Value::text(format!("{:?}", edge_type)),
290 );
291 record.set("exists", crate::storage::schema::Value::Boolean(true));
292 record.set(
293 "weight",
294 crate::storage::schema::Value::Float(weight as f64),
295 );
296 result.push(record);
297 break;
298 }
299 }
300
301 if !found {
302 let mut record = UnifiedRecord::new();
304 record.set("exists", crate::storage::schema::Value::Boolean(false));
305 record.set("source", crate::storage::schema::Value::text(source));
306 record.set("target", crate::storage::schema::Value::text(target));
307 result.push(record);
308 }
309
310 result.columns = vec![
311 "source".into(),
312 "target".into(),
313 "relationship".into(),
314 "exists".into(),
315 ];
316 Ok(())
317 }
318
319 fn primary_entity_label(&self, query: &NaturalQuery) -> Option<&'static str> {
322 for entity in &query.entities {
323 match entity.entity_type {
324 EntityType::Host => return Some("host"),
325 EntityType::User => return Some("user"),
326 EntityType::Credential => return Some("credential"),
327 EntityType::Service | EntityType::Port => return Some("service"),
328 EntityType::Vulnerability => return Some("vulnerability"),
329 EntityType::Technology => return Some("technology"),
330 EntityType::Domain => return Some("domain"),
331 EntityType::Certificate => return Some("certificate"),
332 EntityType::Network => continue,
334 }
335 }
336 None
337 }
338
339 fn node_matches_filters(
341 &self,
342 node: &crate::storage::engine::graph_store::StoredNode,
343 entities: &[ExtractedEntity],
344 ) -> bool {
345 for entity in entities {
346 if let Some(ref value) = entity.value {
347 let matches = node.id.contains(value)
349 || node.label.to_lowercase().contains(&value.to_lowercase())
350 || value.to_lowercase().contains(&node.label.to_lowercase());
351 if matches {
352 return true;
353 }
354 }
355 }
356 entities.iter().all(|e| e.value.is_none())
358 }
359
360 fn has_relationship_to(&self, node_id: &str, target: &str, stats: &mut QueryStats) -> bool {
362 for (_, neighbor, _) in self.graph.outgoing_edges(node_id) {
363 stats.edges_scanned += 1;
364 if neighbor.contains(target) {
365 return true;
366 }
367 if let Some(neighbor_node) = self.graph.get_node(&neighbor) {
369 if neighbor_node
370 .label
371 .to_lowercase()
372 .contains(&target.to_lowercase())
373 {
374 return true;
375 }
376 }
377 }
378 false
379 }
380
381 fn extract_path_endpoints(
383 &self,
384 query: &NaturalQuery,
385 ) -> Result<(String, String), ExecutionError> {
386 let mut source = None;
388 let mut target = None;
389
390 for entity in &query.entities {
391 if let Some(ref value) = entity.value {
392 for node in self.graph.iter_nodes() {
394 if node.id.contains(value)
395 || node.label.to_lowercase().contains(&value.to_lowercase())
396 {
397 if source.is_none() {
398 source = Some(node.id.clone());
399 } else if target.is_none() && Some(&node.id) != source.as_ref() {
400 target = Some(node.id.clone());
401 }
402 }
403 }
404 }
405 }
406
407 match (source, target) {
408 (Some(s), Some(t)) => Ok((s, t)),
409 (Some(s), None) => Err(ExecutionError::new(format!(
410 "Path query needs a target. Found source: {}",
411 s
412 ))),
413 _ => Err(ExecutionError::new(
414 "Path query needs source and target. Try: 'path from host X to host Y'",
415 )),
416 }
417 }
418
419 fn generate_explanation(&self, query: &NaturalQuery, original: &str) -> String {
421 let mut explanation = Vec::new();
422
423 explanation.push(format!("Query: \"{}\"", original));
424 explanation.push(format!("Intent: {:?}", query.intent));
425
426 if !query.entities.is_empty() {
427 let entities: Vec<String> = query
428 .entities
429 .iter()
430 .map(|e| {
431 if let Some(ref val) = e.value {
432 format!("{:?}({})", e.entity_type, val)
433 } else {
434 format!("{:?}", e.entity_type)
435 }
436 })
437 .collect();
438 explanation.push(format!("Entities: {}", entities.join(", ")));
439 }
440
441 let rql = self.to_rql(query);
443 explanation.push(format!("Equivalent RQL: {}", rql));
444
445 explanation.join("\n")
446 }
447
448 fn to_rql(&self, query: &NaturalQuery) -> String {
450 match query.intent {
451 QueryIntent::Find => {
452 let node_type = self.primary_entity_label(query).unwrap_or("*");
453
454 let filters: Vec<String> = query
455 .entities
456 .iter()
457 .filter_map(|e| {
458 e.value
459 .as_ref()
460 .map(|v| format!("n.label CONTAINS '{}'", v))
461 })
462 .collect();
463
464 if filters.is_empty() {
465 format!("MATCH (n:{}) RETURN n", node_type)
466 } else {
467 format!(
468 "MATCH (n:{}) WHERE {} RETURN n",
469 node_type,
470 filters.join(" AND ")
471 )
472 }
473 }
474 QueryIntent::Path => {
475 let endpoints: Vec<&str> = query
476 .entities
477 .iter()
478 .filter_map(|e| e.value.as_deref())
479 .collect();
480 if endpoints.len() >= 2 {
481 format!("PATH FROM '{}' TO '{}'", endpoints[0], endpoints[1])
482 } else {
483 "PATH FROM ? TO ?".to_string()
484 }
485 }
486 QueryIntent::Count => {
487 let node_type = self.primary_entity_label(query).unwrap_or("*");
488 format!("MATCH (n:{}) RETURN COUNT(n)", node_type)
489 }
490 QueryIntent::Show => {
491 let filters: Vec<String> = query
492 .entities
493 .iter()
494 .filter_map(|e| e.value.as_ref().map(|v| format!("n.id = '{}'", v)))
495 .collect();
496 if filters.is_empty() {
497 "MATCH (n) RETURN n".to_string()
498 } else {
499 format!(
500 "MATCH (n) WHERE {} RETURN n, n.neighbors",
501 filters.first().unwrap()
502 )
503 }
504 }
505 QueryIntent::Check => {
506 let endpoints: Vec<&str> = query
507 .entities
508 .iter()
509 .filter_map(|e| e.value.as_deref())
510 .collect();
511 if endpoints.len() >= 2 {
512 format!(
513 "MATCH (a)-[r]->(b) WHERE a.id = '{}' AND b.id = '{}' RETURN EXISTS(r)",
514 endpoints[0], endpoints[1]
515 )
516 } else {
517 "MATCH (a)-[r]->(b) RETURN EXISTS(r)".to_string()
518 }
519 }
520 }
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::query::test_support::service_graph_with_user;

    /// Fixture graph from `test_support`. The assertions below rely on it
    /// containing two hosts and two services.
    fn create_test_graph() -> Arc<GraphStore> {
        service_graph_with_user()
    }

    #[test]
    fn test_list_hosts() {
        let executor = NaturalExecutor::new(create_test_graph());

        let (result, explanation) = executor.execute_with_explanation("list all hosts").unwrap();
        assert_eq!(result.records.len(), 2);
        assert!(explanation.contains("Intent: Find"));
    }

    #[test]
    fn test_find_services() {
        let executor = NaturalExecutor::new(create_test_graph());

        let (result, explanation) = executor.execute_with_explanation("find services").unwrap();
        assert_eq!(result.records.len(), 2);
        assert!(explanation.contains("Service"));
    }

    #[test]
    fn test_count_hosts() {
        let executor = NaturalExecutor::new(create_test_graph());

        let (result, _) = executor.execute_with_explanation("how many hosts").unwrap();
        assert_eq!(result.records.len(), 1);
        assert!(result.records[0].get("count").is_some());
    }

    #[test]
    fn test_explanation_includes_rql() {
        let executor = NaturalExecutor::new(create_test_graph());

        let (_, explanation) = executor
            .execute_with_explanation("find hosts with SSH")
            .unwrap();
        assert!(explanation.contains("Equivalent RQL:"));
        assert!(explanation.contains("MATCH"));
    }

    #[test]
    fn test_path_query() {
        let executor = NaturalExecutor::new(create_test_graph());

        let (result, explanation) = executor
            .execute_with_explanation("path from host 10.0.0.1 to host 10.0.0.2")
            .unwrap();
        assert!(!result.is_empty());
        assert!(explanation.contains("Path"));
    }
}