1use crate::ast::{
18 CompareOp, EdgeDirection, EdgePattern, FieldRef, Filter, GraphPattern, GraphQuery, NodePattern,
19 Projection, PropertyFilter as AstPropertyFilter, QueryExpr,
20};
21use reddb_types::types::Value;
22
/// Error produced by [`NaturalParser::parse`] when the input cannot be turned into a query.
#[derive(Clone, Debug)]
pub struct NaturalError {
    /// Human-readable reason, e.g. `"Empty query"`.
    pub message: String,
}

impl std::fmt::Display for NaturalError {
    /// Renders the message behind a fixed `Natural language error:` prefix.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let NaturalError { message } = self;
        write!(formatter, "Natural language error: {message}")
    }
}

/// Marker impl so the error composes with `?` and `Box<dyn Error>`; it has no source.
impl std::error::Error for NaturalError {}
36
/// Structured result of parsing a natural-language query with [`NaturalParser::parse`].
///
/// Convert it into an executable graph query with [`NaturalQuery::to_query_expr`].
#[derive(Debug, Clone)]
pub struct NaturalQuery {
    /// What the user wants to do, chosen from the first word of the query.
    pub intent: QueryIntent,
    /// Type of the first extracted entity, or a keyword-based fallback when no
    /// entity was extracted.
    pub primary_entity: Option<EntityType>,
    /// Type of the second extracted entity, if any.
    pub secondary_entity: Option<EntityType>,
    /// Every entity mentioned in the query, in extraction order. Each entity's
    /// `alias` becomes its node alias in the generated graph pattern.
    pub entities: Vec<ExtractedEntity>,
    /// Property constraints (port, service, severity, ...). `to_query_expr`
    /// applies them to the first node.
    pub filters: Vec<PropertyFilter>,
    /// Relationship between the first two entities. It comes either from an
    /// explicit phrase ("connected to", "affected by", ...) or is inferred from
    /// the pair of entity types.
    pub relationship: Option<RelationshipType>,
    /// Maximum number of results, from "top N", "first N", "limit N" or "show N".
    pub limit: Option<u64>,
}
55
/// High-level intent of a query, chosen from its first word.
///
/// Only `Count` changes the projection built by `NaturalQuery::to_query_expr`;
/// the other intents currently lower to the same graph query.
#[derive(Debug, Clone, PartialEq)]
pub enum QueryIntent {
    /// Retrieve matching entities: "find", "search", "list", "get", "fetch",
    /// "retrieve", question words ("what", "which", ...) and any unrecognised
    /// first word.
    Find,
    /// Show entities: "show", "display", "view", "describe", "detail(s)".
    Show,
    /// Count matches: queries starting with "count", or "how" followed by "many".
    Count,
    /// Path / reachability questions: "path(s)", "route", "reach", "reachable".
    Path,
    /// Yes/no questions: "is", "are", "does", "can", "check".
    Check,
}
70
/// Kind of security entity a query can refer to.
///
/// `NaturalQuery::to_query_expr` gives extracted-entity nodes of most variants a
/// lowercase label of the same name (`host`, `service`, ...). `Port` and
/// `Network` entity nodes are emitted without a label.
#[derive(Debug, Clone, PartialEq)]
pub enum EntityType {
    Host,
    Service,
    Port,
    User,
    Credential,
    Vulnerability,
    Technology,
    Domain,
    Certificate,
    Network,
}
85
/// An entity mentioned in the query text.
#[derive(Debug, Clone)]
pub struct ExtractedEntity {
    /// What kind of entity was mentioned.
    pub entity_type: EntityType,
    /// Specific identifier when one was given (an IP address, a user name, a
    /// `CVE-…` id, ...). `to_query_expr` turns it into an `id = value`
    /// constraint on the entity's node.
    pub value: Option<String>,
    /// Node alias. The parser assigns `e0`, `e1`, ... in extraction order.
    pub alias: String,
}
93
/// A property constraint extracted from the query text.
///
/// `NaturalQuery::to_query_expr` applies every filter to the first node and
/// AND-s them together.
#[derive(Debug, Clone)]
pub struct PropertyFilter {
    /// Node property name (the parser emits `port`, `service`, `severity`,
    /// `strength` and `status`).
    pub property: String,
    /// Comparison operator.
    pub op: CompareOp,
    // NOTE(review): numeric thresholds such as `severity >= "7.0"` are emitted as
    // `Value::text`; confirm the executor compares them numerically.
    /// Right-hand side of the comparison, always kept as text.
    pub value: String,
}
101
/// Relationship between the first two entities of a query.
///
/// `NaturalQuery::to_query_expr` turns it into a single outgoing edge from the
/// first node to the second. Each variant's edge label is listed below.
#[derive(Debug, Clone, PartialEq)]
pub enum RelationshipType {
    /// Edge label `has_service`.
    HasService,
    /// Edge label `has_endpoint`.
    HasPort,
    /// Edge label `affected_by`.
    HasVuln,
    /// Edge label `auth_access`.
    HasCredential,
    /// Edge label `has_user`. The parser does not currently produce this variant.
    HasUser,
    /// Edge label `connects_to`.
    ConnectsTo,
    /// Edge label `affected_by`.
    Affects,
    /// Edge label `auth_access`.
    AuthAccess,
    /// Edge label `uses_tech`.
    Uses,
    /// Edge label `contains`.
    RunsOn,
    /// Edge label `has_endpoint`.
    Exposes,
}
117
/// Rule-based parser that turns English-like security queries ("find hosts with
/// ssh open", "how many hosts have port 22 open") into a [`NaturalQuery`].
///
/// The type carries no state; use the associated function [`NaturalParser::parse`].
pub struct NaturalParser;
120
121impl NaturalParser {
122 pub fn parse(input: &str) -> Result<NaturalQuery, NaturalError> {
124 let text = Self::normalize(input);
125 let tokens: Vec<&str> = text.split_whitespace().collect();
126
127 if tokens.is_empty() {
128 return Err(NaturalError {
129 message: "Empty query".to_string(),
130 });
131 }
132
133 let intent = Self::detect_intent(&tokens);
135
136 let entities = Self::extract_entities(&text);
138
139 let (primary, secondary) = Self::determine_entity_types(&entities, &text);
141
142 let filters = Self::extract_filters(&text);
144
145 let relationship = Self::detect_relationship(&text, &primary, &secondary);
147
148 let limit = Self::extract_limit(&text);
150
151 Ok(NaturalQuery {
152 intent,
153 primary_entity: primary,
154 secondary_entity: secondary,
155 entities,
156 filters,
157 relationship,
158 limit,
159 })
160 }
161
162 fn normalize(input: &str) -> String {
164 let trimmed = input.trim();
166 let unquoted = if (trimmed.starts_with('"') && trimmed.ends_with('"'))
167 || (trimmed.starts_with('\'') && trimmed.ends_with('\''))
168 {
169 &trimmed[1..trimmed.len() - 1]
170 } else {
171 trimmed
172 };
173
174 unquoted
176 .to_lowercase()
177 .chars()
178 .map(|c| {
179 if c.is_alphanumeric()
180 || c.is_whitespace()
181 || c == '.'
182 || c == ':'
183 || c == '-'
184 || c == '_'
185 {
186 c
187 } else {
188 ' '
189 }
190 })
191 .collect::<String>()
192 .split_whitespace()
193 .collect::<Vec<_>>()
194 .join(" ")
195 }
196
197 fn detect_intent(tokens: &[&str]) -> QueryIntent {
199 let first = tokens.first().copied().unwrap_or("");
200
201 match first {
202 "find" | "search" | "list" | "get" | "fetch" | "retrieve" => QueryIntent::Find,
203 "show" | "display" | "view" | "describe" | "detail" | "details" => QueryIntent::Show,
204 "count" | "how" => {
205 if tokens.contains(&"many") || tokens.contains(&"count") {
206 QueryIntent::Count
207 } else {
208 QueryIntent::Find
209 }
210 }
211 "path" | "paths" | "route" | "reach" | "reachable" => QueryIntent::Path,
212 "is" | "are" | "does" | "can" | "check" => QueryIntent::Check,
213 "what" | "which" | "where" | "who" => {
214 QueryIntent::Find
216 }
217 _ => QueryIntent::Find,
218 }
219 }
220
221 fn extract_entities(text: &str) -> Vec<ExtractedEntity> {
223 let mut entities = Vec::new();
224 let mut alias_counter = 0;
225
226 let entity_patterns: Vec<(EntityType, &[&str], Option<&str>)> = vec![
228 (
229 EntityType::Host,
230 &[
231 "host", "hosts", "server", "servers", "machine", "machines", "ip", "ips",
232 ],
233 None,
234 ),
235 (EntityType::Service, &["service", "services"], None),
236 (EntityType::Port, &["port", "ports"], None),
237 (
238 EntityType::User,
239 &[
240 "user",
241 "users",
242 "account",
243 "accounts",
244 "username",
245 "usernames",
246 ],
247 None,
248 ),
249 (
250 EntityType::Credential,
251 &[
252 "credential",
253 "credentials",
254 "password",
255 "passwords",
256 "cred",
257 "creds",
258 ],
259 None,
260 ),
261 (
262 EntityType::Vulnerability,
263 &[
264 "vulnerability",
265 "vulnerabilities",
266 "vuln",
267 "vulns",
268 "cve",
269 "cves",
270 ],
271 None,
272 ),
273 (
274 EntityType::Technology,
275 &[
276 "technology",
277 "technologies",
278 "tech",
279 "software",
280 "application",
281 "applications",
282 ],
283 None,
284 ),
285 (
286 EntityType::Domain,
287 &["domain", "domains", "subdomain", "subdomains"],
288 None,
289 ),
290 (
291 EntityType::Certificate,
292 &["certificate", "certificates", "cert", "certs", "ssl", "tls"],
293 None,
294 ),
295 (
296 EntityType::Network,
297 &[
298 "network", "networks", "subnet", "subnets", "segment", "segments",
299 ],
300 None,
301 ),
302 ];
303
304 for (entity_type, keywords, _) in entity_patterns {
305 for keyword in keywords.iter() {
306 if text.contains(keyword) {
307 let value = Self::extract_entity_value(text, keyword);
309
310 entities.push(ExtractedEntity {
311 entity_type: entity_type.clone(),
312 value,
313 alias: format!("e{}", alias_counter),
314 });
315 alias_counter += 1;
316 break; }
318 }
319 }
320
321 for word in text.split_whitespace() {
323 if Self::is_ip_address(word) {
324 let already_has_host = entities
325 .iter()
326 .any(|e| e.entity_type == EntityType::Host && e.value.as_deref() == Some(word));
327 if !already_has_host {
328 entities.push(ExtractedEntity {
329 entity_type: EntityType::Host,
330 value: Some(word.to_string()),
331 alias: format!("e{}", alias_counter),
332 });
333 alias_counter += 1;
334 }
335 }
336 }
337
338 for word in text.split_whitespace() {
340 if word.starts_with("cve-") || word.starts_with("cve:") {
341 let cve = word
342 .replace("cve:", "CVE-")
343 .replace("cve-", "CVE-")
344 .to_uppercase();
345 entities.push(ExtractedEntity {
346 entity_type: EntityType::Vulnerability,
347 value: Some(cve),
348 alias: format!("e{}", alias_counter),
349 });
350 alias_counter += 1;
351 }
352 }
353
354 entities
355 }
356
357 fn extract_entity_value(text: &str, keyword: &str) -> Option<String> {
359 let parts: Vec<&str> = text.split_whitespace().collect();
361
362 for (i, part) in parts.iter().enumerate() {
363 if *part == keyword {
364 if let Some(next) = parts.get(i + 1) {
366 if ![
368 "with", "that", "has", "have", "is", "are", "the", "a", "an", "for", "on",
369 "in",
370 ]
371 .contains(next)
372 {
373 return Some(next.to_string());
374 }
375 if let Some(next2) = parts.get(i + 2) {
377 if ![
378 "with", "that", "has", "have", "is", "are", "the", "a", "an", "for",
379 "on", "in",
380 ]
381 .contains(next2)
382 {
383 return Some(next2.to_string());
384 }
385 }
386 }
387 }
388 }
389
390 None
391 }
392
393 fn is_ip_address(s: &str) -> bool {
395 let parts: Vec<&str> = s.split('.').collect();
396 if parts.len() != 4 {
397 return false;
398 }
399 parts.iter().all(|p| p.parse::<u8>().is_ok())
400 }
401
402 fn determine_entity_types(
404 entities: &[ExtractedEntity],
405 text: &str,
406 ) -> (Option<EntityType>, Option<EntityType>) {
407 if entities.is_empty() {
408 if text.contains("host") || text.contains("server") || text.contains("ip") {
410 return (Some(EntityType::Host), None);
411 }
412 if text.contains("vuln") || text.contains("cve") {
413 return (Some(EntityType::Vulnerability), None);
414 }
415 if text.contains("user") || text.contains("account") {
416 return (Some(EntityType::User), None);
417 }
418 if text.contains("cred") || text.contains("password") {
419 return (Some(EntityType::Credential), None);
420 }
421 if text.contains("service") {
422 return (Some(EntityType::Service), None);
423 }
424 return (None, None);
425 }
426
427 let primary = entities.first().map(|e| e.entity_type.clone());
428 let secondary = entities.get(1).map(|e| e.entity_type.clone());
429
430 (primary, secondary)
431 }
432
433 fn extract_filters(text: &str) -> Vec<PropertyFilter> {
435 let mut filters = Vec::new();
436
437 if text.contains("port") {
439 for word in text.split_whitespace() {
440 if let Ok(port) = word.parse::<u16>() {
441 if port > 0 {
442 filters.push(PropertyFilter {
444 property: "port".to_string(),
445 op: CompareOp::Eq,
446 value: port.to_string(),
447 });
448 }
449 }
450 }
451 }
452
453 let services = [
455 "ssh", "http", "https", "ftp", "smtp", "mysql", "postgres", "redis", "mongodb", "rdp",
456 "vnc",
457 ];
458 for svc in services {
459 if text.contains(svc) {
460 filters.push(PropertyFilter {
461 property: "service".to_string(),
462 op: CompareOp::Eq,
463 value: svc.to_string(),
464 });
465 }
466 }
467
468 if text.contains("critical") {
470 filters.push(PropertyFilter {
471 property: "severity".to_string(),
472 op: CompareOp::Eq,
473 value: "critical".to_string(),
474 });
475 } else if text.contains("high") {
476 filters.push(PropertyFilter {
477 property: "severity".to_string(),
478 op: CompareOp::Ge,
479 value: "7.0".to_string(),
480 });
481 } else if text.contains("medium") {
482 filters.push(PropertyFilter {
483 property: "severity".to_string(),
484 op: CompareOp::Ge,
485 value: "4.0".to_string(),
486 });
487 }
488
489 if text.contains("weak") && (text.contains("password") || text.contains("credential")) {
491 filters.push(PropertyFilter {
492 property: "strength".to_string(),
493 op: CompareOp::Eq,
494 value: "weak".to_string(),
495 });
496 }
497
498 if text.contains("open") || text.contains("exposed") || text.contains("public") {
500 filters.push(PropertyFilter {
501 property: "status".to_string(),
502 op: CompareOp::Eq,
503 value: "open".to_string(),
504 });
505 }
506
507 filters
508 }
509
510 fn detect_relationship(
512 text: &str,
513 primary: &Option<EntityType>,
514 secondary: &Option<EntityType>,
515 ) -> Option<RelationshipType> {
516 if text.contains("connects to") || text.contains("connected to") || text.contains("reach") {
518 return Some(RelationshipType::ConnectsTo);
519 }
520 if text.contains("affects") || text.contains("affected by") || text.contains("vulnerable") {
521 return Some(RelationshipType::Affects);
522 }
523 if text.contains("has access")
524 || text.contains("can access")
525 || text.contains("authenticate")
526 {
527 return Some(RelationshipType::AuthAccess);
528 }
529 if text.contains("runs on") || text.contains("running on") {
530 return Some(RelationshipType::RunsOn);
531 }
532 if text.contains("uses") || text.contains("using") {
533 return Some(RelationshipType::Uses);
534 }
535 if text.contains("exposes") || text.contains("exposing") {
536 return Some(RelationshipType::Exposes);
537 }
538
539 match (primary, secondary) {
541 (Some(EntityType::Host), Some(EntityType::Service)) => {
542 Some(RelationshipType::HasService)
543 }
544 (Some(EntityType::Host), Some(EntityType::Port)) => Some(RelationshipType::HasPort),
545 (Some(EntityType::Host), Some(EntityType::Vulnerability)) => {
546 Some(RelationshipType::HasVuln)
547 }
548 (Some(EntityType::User), Some(EntityType::Credential)) => {
549 Some(RelationshipType::HasCredential)
550 }
551 (Some(EntityType::Credential), Some(EntityType::Host)) => {
552 Some(RelationshipType::AuthAccess)
553 }
554 (Some(EntityType::Vulnerability), Some(EntityType::Host)) => {
555 Some(RelationshipType::Affects)
556 }
557 _ => None,
558 }
559 }
560
561 fn extract_limit(text: &str) -> Option<u64> {
563 let patterns = [("top ", 4), ("first ", 6), ("limit ", 6), ("show ", 5)];
564
565 for (pattern, skip) in patterns {
566 if let Some(pos) = text.find(pattern) {
567 let after = &text[pos + skip..];
568 let num_str: String = after.chars().take_while(|c| c.is_ascii_digit()).collect();
569 if let Ok(n) = num_str.parse::<u64>() {
570 return Some(n);
571 }
572 }
573 }
574
575 None
576 }
577}
578
579impl NaturalQuery {
580 pub fn to_query_expr(&self) -> QueryExpr {
582 let mut nodes = Vec::new();
583 let mut edges = Vec::new();
584 let mut filters = Vec::new();
585
586 for entity in &self.entities {
588 let node_type = match entity.entity_type {
589 EntityType::Host => Some("host".to_string()),
590 EntityType::Service => Some("service".to_string()),
591 EntityType::User => Some("user".to_string()),
592 EntityType::Credential => Some("credential".to_string()),
593 EntityType::Vulnerability => Some("vulnerability".to_string()),
594 EntityType::Technology => Some("technology".to_string()),
595 EntityType::Domain => Some("domain".to_string()),
596 EntityType::Certificate => Some("certificate".to_string()),
597 _ => None,
598 };
599
600 let mut properties: Vec<AstPropertyFilter> = Vec::new();
601 if let Some(ref value) = entity.value {
602 properties.push(AstPropertyFilter {
603 name: "id".to_string(),
604 op: CompareOp::Eq,
605 value: Value::text(value.clone()),
606 });
607 }
608
609 nodes.push(NodePattern {
610 alias: entity.alias.clone(),
611 node_label: node_type.clone(),
612 properties,
613 });
614 }
615
616 if let Some(ref relationship) = self.relationship {
621 if nodes.len() >= 2 {
622 let edge_label = Some(
623 match relationship {
624 RelationshipType::HasService => "has_service",
625 RelationshipType::HasPort => "has_endpoint",
626 RelationshipType::HasVuln => "affected_by",
627 RelationshipType::HasCredential => "auth_access",
628 RelationshipType::HasUser => "has_user",
629 RelationshipType::ConnectsTo => "connects_to",
630 RelationshipType::Affects => "affected_by",
631 RelationshipType::AuthAccess => "auth_access",
632 RelationshipType::Uses => "uses_tech",
633 RelationshipType::RunsOn => "contains",
634 RelationshipType::Exposes => "has_endpoint",
635 }
636 .to_string(),
637 );
638
639 edges.push(EdgePattern {
640 alias: None,
641 from: nodes[0].alias.clone(),
642 to: nodes[1].alias.clone(),
643 edge_label,
644 direction: EdgeDirection::Outgoing,
645 min_hops: 1,
646 max_hops: 1,
647 });
648 }
649 }
650
651 let current_alias = nodes
653 .first()
654 .map(|n| n.alias.clone())
655 .unwrap_or_else(|| "n0".to_string());
656 for filter in &self.filters {
657 filters.push(Filter::Compare {
658 field: FieldRef::NodeProperty {
659 alias: current_alias.clone(),
660 property: filter.property.clone(),
661 },
662 op: filter.op,
663 value: Value::text(filter.value.clone()),
664 });
665 }
666
667 let projections = match self.intent {
669 QueryIntent::Count => vec![Projection::Field(
670 FieldRef::NodeId {
671 alias: current_alias.clone(),
672 },
673 Some("count".to_string()),
674 )],
675 _ => vec![Projection::from_field(FieldRef::NodeId {
676 alias: current_alias.clone(),
677 })],
678 };
679
680 if nodes.is_empty() {
682 if let Some(ref entity_type) = self.primary_entity {
683 let node_label = match entity_type {
684 EntityType::Host => Some("host".to_string()),
685 EntityType::Service => Some("service".to_string()),
686 EntityType::User => Some("user".to_string()),
687 EntityType::Credential => Some("credential".to_string()),
688 EntityType::Vulnerability => Some("vulnerability".to_string()),
689 _ => None,
690 };
691
692 nodes.push(NodePattern {
693 alias: "n0".to_string(),
694 node_label,
695 properties: Vec::new(),
696 });
697 }
698 }
699
700 let combined_filter = if filters.is_empty() {
702 None
703 } else {
704 let mut iter = filters.into_iter();
705 let first = iter.next().unwrap();
706 Some(iter.fold(first, |acc, f| Filter::And(Box::new(acc), Box::new(f))))
707 };
708
709 QueryExpr::Graph(GraphQuery {
710 alias: None,
711 pattern: GraphPattern { nodes, edges },
712 filter: combined_filter,
713 return_: projections,
714 limit: self.limit,
715 })
716 }
717}
718
// Unit tests for `NaturalParser::parse` (intent, entities, filters, relationships,
// limits) and for the lowering done by `NaturalQuery::to_query_expr`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a `QueryExpr::Graph`, panicking on any other variant.
    fn graph(expr: QueryExpr) -> GraphQuery {
        match expr {
            QueryExpr::Graph(graph) => graph,
            other => panic!("expected graph query, got {other:?}"),
        }
    }

    /// Builds a value-less entity with an explicit alias for hand-built queries.
    fn entity(entity_type: EntityType, alias: &str) -> ExtractedEntity {
        ExtractedEntity {
            entity_type,
            value: None,
            alias: alias.to_string(),
        }
    }

    #[test]
    fn test_parse_find_hosts() {
        let q = NaturalParser::parse("find all hosts with ssh open").unwrap();
        assert_eq!(q.intent, QueryIntent::Find);
        assert!(q.entities.iter().any(|e| e.entity_type == EntityType::Host));
        assert!(q
            .filters
            .iter()
            .any(|f| f.property == "service" && f.value == "ssh"));
    }

    #[test]
    fn test_parse_show_credentials() {
        let q = NaturalParser::parse("show me credentials for user admin").unwrap();
        assert_eq!(q.intent, QueryIntent::Show);
        assert!(q
            .entities
            .iter()
            .any(|e| e.entity_type == EntityType::Credential));
        assert!(q.entities.iter().any(|e| e.entity_type == EntityType::User));
    }

    // A host keyword followed by an IP yields a Host entity carrying that IP.
    #[test]
    fn test_parse_with_ip() {
        let q = NaturalParser::parse("what vulnerabilities affect host 10.0.0.1").unwrap();
        assert!(q
            .entities
            .iter()
            .any(|e| e.entity_type == EntityType::Host && e.value == Some("10.0.0.1".to_string())));
        assert!(q
            .entities
            .iter()
            .any(|e| e.entity_type == EntityType::Vulnerability));
    }

    #[test]
    fn test_parse_count() {
        let q = NaturalParser::parse("how many hosts have port 22 open").unwrap();
        assert_eq!(q.intent, QueryIntent::Count);
    }

    #[test]
    fn test_parse_weak_passwords() {
        let q = NaturalParser::parse("list users with weak passwords").unwrap();
        assert!(q
            .filters
            .iter()
            .any(|f| f.property == "strength" && f.value == "weak"));
    }

    #[test]
    fn test_parse_critical_vulns() {
        let q = NaturalParser::parse("show critical vulnerabilities").unwrap();
        assert!(q
            .filters
            .iter()
            .any(|f| f.property == "severity" && f.value == "critical"));
    }

    // Surrounding quotes are stripped before the query is interpreted.
    #[test]
    fn test_parse_quoted() {
        let q = NaturalParser::parse("\"find hosts connected to 10.0.0.1\"").unwrap();
        assert_eq!(q.intent, QueryIntent::Find);
        assert!(q.relationship == Some(RelationshipType::ConnectsTo));
    }

    #[test]
    fn test_parse_with_limit() {
        let q = NaturalParser::parse("show top 10 vulnerable hosts").unwrap();
        assert_eq!(q.limit, Some(10));
    }

    #[test]
    fn test_to_query_expr() {
        let q = NaturalParser::parse("find all hosts with ssh").unwrap();
        let expr = q.to_query_expr();
        assert!(matches!(expr, QueryExpr::Graph(_)));
    }

    #[test]
    fn test_detect_relationship() {
        let q = NaturalParser::parse("credentials that can access host 10.0.0.1").unwrap();
        assert_eq!(q.relationship, Some(RelationshipType::AuthAccess));
    }

    // Punctuation-only input normalizes to nothing and is rejected.
    #[test]
    fn test_parse_rejects_empty_after_normalization() {
        let err = NaturalParser::parse(" ?! ").unwrap_err();
        assert_eq!(err.message, "Empty query");
    }

    // One case per first-word family; "how" without "many" and unknown words map to Find.
    #[test]
    fn test_parse_intent_variants() {
        let cases = [
            ("search hosts", QueryIntent::Find),
            ("display users", QueryIntent::Show),
            ("count users", QueryIntent::Count),
            ("how are hosts", QueryIntent::Find),
            ("route between hosts", QueryIntent::Path),
            ("does user admin access host 10.0.0.1", QueryIntent::Check),
            ("which services are public", QueryIntent::Find),
            ("unexpected words", QueryIntent::Find),
        ];

        for (input, expected) in cases {
            let q = NaturalParser::parse(input).unwrap();
            assert_eq!(q.intent, expected, "{input}");
        }
    }

    // Values after a filler word, bare IPs, and `cve:` ids normalized to `CVE-`.
    #[test]
    fn test_parse_entities_from_values_and_identifiers() {
        let user = NaturalParser::parse("show user the admin").unwrap();
        assert!(user
            .entities
            .iter()
            .any(|e| e.entity_type == EntityType::User && e.value == Some("admin".to_string())));

        let host = NaturalParser::parse("find 192.168.1.10").unwrap();
        assert!(host.entities.iter().any(|e| {
            e.entity_type == EntityType::Host && e.value == Some("192.168.1.10".to_string())
        }));

        let cve = NaturalParser::parse("show cve:2024-1234").unwrap();
        assert!(cve.entities.iter().any(|e| {
            e.entity_type == EntityType::Vulnerability
                && e.value == Some("CVE-2024-1234".to_string())
        }));
    }

    // Severity thresholds, service + status filters, and port 0 being ignored.
    #[test]
    fn test_parse_filter_variants() {
        let high = NaturalParser::parse("find high vulnerabilities").unwrap();
        assert!(high
            .filters
            .iter()
            .any(|f| f.property == "severity" && f.op == CompareOp::Ge && f.value == "7.0"));

        let medium = NaturalParser::parse("find medium vulnerabilities").unwrap();
        assert!(medium
            .filters
            .iter()
            .any(|f| f.property == "severity" && f.op == CompareOp::Ge && f.value == "4.0"));

        let public_rdp = NaturalParser::parse("find public rdp services").unwrap();
        assert!(public_rdp
            .filters
            .iter()
            .any(|f| f.property == "service" && f.value == "rdp"));
        assert!(public_rdp
            .filters
            .iter()
            .any(|f| f.property == "status" && f.value == "open"));

        let zero_port = NaturalParser::parse("find hosts with port 0").unwrap();
        assert!(!zero_port.filters.iter().any(|f| f.property == "port"));
    }

    // Each limit keyword, plus keywords without a number yielding no limit.
    #[test]
    fn test_parse_limit_variants() {
        let cases = [
            ("top 3 hosts", Some(3)),
            ("first 4 hosts", Some(4)),
            ("limit 5 hosts", Some(5)),
            ("show 6 hosts", Some(6)),
            ("show hosts", None),
            ("top hosts", None),
        ];

        for (input, expected) in cases {
            let q = NaturalParser::parse(input).unwrap();
            assert_eq!(q.limit, expected, "{input}");
        }
    }

    // Explicit phrases decide the relationship regardless of the entity pair.
    #[test]
    fn test_parse_explicit_relationship_phrases() {
        let cases = [
            (
                "find hosts running on technology linux",
                RelationshipType::RunsOn,
            ),
            (
                "find services using certificate tls",
                RelationshipType::Uses,
            ),
            (
                "show host 10.0.0.1 exposes port 443",
                RelationshipType::Exposes,
            ),
            (
                "find hosts affected by cve-2024-1234",
                RelationshipType::Affects,
            ),
            (
                "check users authenticate to host 10.0.0.1",
                RelationshipType::AuthAccess,
            ),
        ];

        for (input, expected) in cases {
            let q = NaturalParser::parse(input).unwrap();
            assert_eq!(q.relationship, Some(expected), "{input}");
        }
    }

    // Without a phrase, the (primary, secondary) type pair selects the relationship.
    #[test]
    fn test_parse_inferred_relationships_from_entity_pairs() {
        let cases = [
            ("find hosts services", RelationshipType::HasService),
            ("find hosts port 443", RelationshipType::HasPort),
            ("find hosts vulnerabilities", RelationshipType::HasVuln),
            ("find users credentials", RelationshipType::HasCredential),
            (
                "find credentials for 10.0.0.1",
                RelationshipType::AuthAccess,
            ),
            ("find cves for 10.0.0.1", RelationshipType::Affects),
        ];

        for (input, expected) in cases {
            let q = NaturalParser::parse(input).unwrap();
            assert_eq!(q.relationship, Some(expected), "{input}");
        }

        let unrelated = NaturalParser::parse("find domains certificates").unwrap();
        assert_eq!(unrelated.relationship, None);
    }

    // Unrecognised text parses to an empty Find query and an empty graph pattern.
    #[test]
    fn test_unknown_text_falls_back_to_find_without_entities() {
        let q = NaturalParser::parse("unmapped gibberish").unwrap();
        assert_eq!(q.intent, QueryIntent::Find);
        assert_eq!(q.primary_entity, None);
        assert_eq!(q.secondary_entity, None);
        assert!(q.entities.is_empty());
        assert!(q.filters.is_empty());
        assert_eq!(q.relationship, None);

        let graph = graph(q.to_query_expr());
        assert!(graph.pattern.nodes.is_empty());
        assert!(graph.pattern.edges.is_empty());
        assert_eq!(graph.limit, None);
    }

    // Count projects the first node's id as "count"; two filters nest into an And.
    #[test]
    fn test_to_query_expr_builds_count_projection_limit_and_nested_filters() {
        let q = NaturalParser::parse("count top 2 hosts with ssh open").unwrap();
        let graph = graph(q.to_query_expr());

        assert_eq!(graph.limit, Some(2));
        assert!(matches!(graph.filter, Some(Filter::And(_, _))));
        match graph.return_.as_slice() {
            [Projection::Field(FieldRef::NodeId { alias }, Some(name))] => {
                assert_eq!(alias, "e0");
                assert_eq!(name, "count");
            }
            other => panic!("unexpected projection: {other:?}"),
        }
    }

    // Pins the edge label for every RelationshipType variant.
    #[test]
    fn test_to_query_expr_maps_all_relationship_edge_labels() {
        let cases = [
            (RelationshipType::HasService, "has_service"),
            (RelationshipType::HasPort, "has_endpoint"),
            (RelationshipType::HasVuln, "affected_by"),
            (RelationshipType::HasCredential, "auth_access"),
            (RelationshipType::HasUser, "has_user"),
            (RelationshipType::ConnectsTo, "connects_to"),
            (RelationshipType::Affects, "affected_by"),
            (RelationshipType::AuthAccess, "auth_access"),
            (RelationshipType::Uses, "uses_tech"),
            (RelationshipType::RunsOn, "contains"),
            (RelationshipType::Exposes, "has_endpoint"),
        ];

        for (relationship, expected_label) in cases {
            let debug_name = format!("{relationship:?}");
            let q = NaturalQuery {
                intent: QueryIntent::Find,
                primary_entity: Some(EntityType::Host),
                secondary_entity: Some(EntityType::Service),
                entities: vec![
                    entity(EntityType::Host, "source"),
                    entity(EntityType::Service, "target"),
                ],
                filters: Vec::new(),
                relationship: Some(relationship),
                limit: None,
            };
            let graph = graph(q.to_query_expr());

            assert_eq!(graph.pattern.edges.len(), 1, "{debug_name}");
            let edge = &graph.pattern.edges[0];
            assert_eq!(edge.from, "source", "{debug_name}");
            assert_eq!(edge.to, "target", "{debug_name}");
            assert_eq!(
                edge.edge_label.as_deref(),
                Some(expected_label),
                "{debug_name}"
            );
        }
    }

    // A relationship with a single node produces no edge.
    #[test]
    fn test_to_query_expr_skips_edge_without_two_nodes() {
        let q = NaturalQuery {
            intent: QueryIntent::Find,
            primary_entity: Some(EntityType::Host),
            secondary_entity: None,
            entities: vec![entity(EntityType::Host, "source")],
            filters: Vec::new(),
            relationship: Some(RelationshipType::ConnectsTo),
            limit: None,
        };

        let graph = graph(q.to_query_expr());
        assert_eq!(graph.pattern.nodes.len(), 1);
        assert!(graph.pattern.edges.is_empty());
    }

    // Every entity type's node label, and the `id` property added for a value.
    #[test]
    fn test_to_query_expr_maps_entity_labels_and_id_properties() {
        let cases = [
            (EntityType::Host, Some("host")),
            (EntityType::Service, Some("service")),
            (EntityType::Port, None),
            (EntityType::User, Some("user")),
            (EntityType::Credential, Some("credential")),
            (EntityType::Vulnerability, Some("vulnerability")),
            (EntityType::Technology, Some("technology")),
            (EntityType::Domain, Some("domain")),
            (EntityType::Certificate, Some("certificate")),
            (EntityType::Network, None),
        ];
        let entities: Vec<_> = cases
            .iter()
            .enumerate()
            .map(|(i, (entity_type, _))| ExtractedEntity {
                entity_type: entity_type.clone(),
                value: Some(format!("value{i}")),
                alias: format!("e{i}"),
            })
            .collect();
        let q = NaturalQuery {
            intent: QueryIntent::Find,
            primary_entity: Some(EntityType::Host),
            secondary_entity: None,
            entities,
            filters: Vec::new(),
            relationship: None,
            limit: None,
        };
        let graph = graph(q.to_query_expr());

        for (node, (_, expected_label)) in graph.pattern.nodes.iter().zip(cases.iter()) {
            assert_eq!(node.node_label.as_deref(), *expected_label);
            assert_eq!(node.properties.len(), 1);
            assert_eq!(node.properties[0].name, "id");
        }
    }

    // With no entities, a single `n0` node is built from `primary_entity`.
    #[test]
    fn test_to_query_expr_creates_default_node_from_primary_entity() {
        let cases = [
            (EntityType::Host, Some("host")),
            (EntityType::Service, Some("service")),
            (EntityType::User, Some("user")),
            (EntityType::Credential, Some("credential")),
            (EntityType::Vulnerability, Some("vulnerability")),
            (EntityType::Network, None),
        ];

        for (entity_type, expected_label) in cases {
            let q = NaturalQuery {
                intent: QueryIntent::Find,
                primary_entity: Some(entity_type),
                secondary_entity: None,
                entities: Vec::new(),
                filters: Vec::new(),
                relationship: None,
                limit: None,
            };
            let graph = graph(q.to_query_expr());

            assert_eq!(graph.pattern.nodes.len(), 1);
            assert_eq!(graph.pattern.nodes[0].alias, "n0");
            assert_eq!(graph.pattern.nodes[0].node_label.as_deref(), expected_label);
        }
    }
}
1129}