1use crate::storage::query::ast::{
18 CompareOp, EdgeDirection, EdgePattern, FieldRef, Filter, GraphPattern, GraphQuery, NodePattern,
19 Projection, PropertyFilter as AstPropertyFilter, QueryExpr,
20};
21use crate::storage::schema::Value;
22
/// Error produced when a natural-language query cannot be parsed.
#[derive(Debug, Clone)]
pub struct NaturalError {
    /// Human-readable explanation of the failure (e.g. `"Empty query"`).
    pub message: String,
}

impl std::fmt::Display for NaturalError {
    /// Renders the error as `Natural language error: <message>`.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str("Natural language error: ")?;
        out.write_str(&self.message)
    }
}

/// Allows `NaturalError` to be boxed as `dyn std::error::Error` and propagated with `?`.
impl std::error::Error for NaturalError {}
36
/// Structured result of parsing a natural-language query with `NaturalParser::parse`.
///
/// Can be lowered to a graph query with `to_query_expr`.
#[derive(Debug, Clone)]
pub struct NaturalQuery {
    /// What the user wants to do, inferred mainly from the first word of the query.
    pub intent: QueryIntent,
    /// Type of the first extracted entity, or a keyword-based guess when none was extracted.
    pub primary_entity: Option<EntityType>,
    /// Type of the second extracted entity, if any.
    pub secondary_entity: Option<EntityType>,
    /// All entity mentions found in the query, in extraction order (aliases `e0`, `e1`, ...).
    pub entities: Vec<ExtractedEntity>,
    /// Property constraints recognised in the query (port, service, severity, ...).
    pub filters: Vec<PropertyFilter>,
    /// Relationship between the first two entities, detected from phrases such as
    /// "connected to" or inferred from the primary/secondary entity pair.
    pub relationship: Option<RelationshipType>,
    /// Maximum number of results requested (e.g. "top 10"), if any.
    pub limit: Option<u64>,
}
55
/// What the user is asking for, inferred mainly from the first word of the query.
///
/// Derives `Eq` and `Hash` (the enum is fieldless) so intents can be used as keys in hash
/// maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum QueryIntent {
    /// Search for matching entities ("find", "list", "search", "what ..."); also the
    /// fallback when the first word is not recognised.
    Find,
    /// Display details of entities ("show", "display", "view", "describe").
    Show,
    /// Count matching entities ("how many ...", "count ...").
    Count,
    /// Find paths or reachability ("path", "route", "reach", "reachable").
    Path,
    /// Yes/no question ("is", "are", "does", "can", "check").
    Check,
}
70
/// Kinds of entities a natural-language query can mention.
///
/// Derives `Eq` and `Hash` (the enum is fieldless) so entity types can be used as keys in
/// hash maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EntityType {
    /// A host, server, machine or IP address.
    Host,
    /// A network service.
    Service,
    /// A port.
    Port,
    /// A user account.
    User,
    /// A credential or password.
    Credential,
    /// A vulnerability, e.g. a CVE.
    Vulnerability,
    /// A technology, software product or application.
    Technology,
    /// A domain or subdomain.
    Domain,
    /// A certificate (queries may also say "ssl" or "tls").
    Certificate,
    /// A network, subnet or segment.
    Network,
}
85
/// A single entity mention extracted from the query text.
#[derive(Debug, Clone)]
pub struct ExtractedEntity {
    /// Kind of entity mentioned.
    pub entity_type: EntityType,
    /// Concrete identifier given for the entity (e.g. an IP address or `CVE-...`), if any.
    pub value: Option<String>,
    /// Pattern alias (`e0`, `e1`, ...) used for this entity's node in the generated query.
    pub alias: String,
}
93
/// A property constraint recognised in the query text, e.g. `port = 22` or
/// `severity >= 7.0`.
#[derive(Debug, Clone)]
pub struct PropertyFilter {
    /// Property name on the target node (e.g. `port`, `service`, `severity`, `status`).
    pub property: String,
    /// Comparison operator applied between the property and `value`.
    pub op: CompareOp,
    /// Right-hand side of the comparison, kept as text.
    pub value: String,
}
101
/// Relationship between the first and second entity of a query.
///
/// Detected from phrases in the query ("connected to", "can access", ...) or inferred from
/// the (primary, secondary) entity pair. Derives `Eq` and `Hash` (the enum is fieldless) so
/// relationships can be used as keys in hash maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RelationshipType {
    /// Host to service (inferred for a host/service pair).
    HasService,
    /// Host to port (inferred for a host/port pair).
    HasPort,
    /// Host to vulnerability (inferred for a host/vulnerability pair).
    HasVuln,
    /// User to credential (inferred for a user/credential pair).
    HasCredential,
    /// Entity to user; not produced by `NaturalParser` itself.
    HasUser,
    /// Network connectivity ("connects to", "connected to", "reach").
    ConnectsTo,
    /// Vulnerability impact ("affects", "affected by", "vulnerable").
    Affects,
    /// Authentication access ("has access", "can access", "authenticate").
    AuthAccess,
    /// Technology usage ("uses", "using").
    Uses,
    /// Placement ("runs on", "running on").
    RunsOn,
    /// Exposure ("exposes", "exposing").
    Exposes,
}
117
/// Keyword-heuristic parser that turns free-form English queries into a `NaturalQuery`.
///
/// The type carries no state; all functionality is exposed through associated functions
/// (see `NaturalParser::parse`).
pub struct NaturalParser;
120
121impl NaturalParser {
122 pub fn parse(input: &str) -> Result<NaturalQuery, NaturalError> {
124 let text = Self::normalize(input);
125 let tokens: Vec<&str> = text.split_whitespace().collect();
126
127 if tokens.is_empty() {
128 return Err(NaturalError {
129 message: "Empty query".to_string(),
130 });
131 }
132
133 let intent = Self::detect_intent(&tokens);
135
136 let entities = Self::extract_entities(&text);
138
139 let (primary, secondary) = Self::determine_entity_types(&entities, &text);
141
142 let filters = Self::extract_filters(&text);
144
145 let relationship = Self::detect_relationship(&text, &primary, &secondary);
147
148 let limit = Self::extract_limit(&text);
150
151 Ok(NaturalQuery {
152 intent,
153 primary_entity: primary,
154 secondary_entity: secondary,
155 entities,
156 filters,
157 relationship,
158 limit,
159 })
160 }
161
162 fn normalize(input: &str) -> String {
164 let trimmed = input.trim();
166 let unquoted = if (trimmed.starts_with('"') && trimmed.ends_with('"'))
167 || (trimmed.starts_with('\'') && trimmed.ends_with('\''))
168 {
169 &trimmed[1..trimmed.len() - 1]
170 } else {
171 trimmed
172 };
173
174 unquoted
176 .to_lowercase()
177 .chars()
178 .map(|c| {
179 if c.is_alphanumeric()
180 || c.is_whitespace()
181 || c == '.'
182 || c == ':'
183 || c == '-'
184 || c == '_'
185 {
186 c
187 } else {
188 ' '
189 }
190 })
191 .collect::<String>()
192 .split_whitespace()
193 .collect::<Vec<_>>()
194 .join(" ")
195 }
196
197 fn detect_intent(tokens: &[&str]) -> QueryIntent {
199 let first = tokens.first().copied().unwrap_or("");
200
201 match first {
202 "find" | "search" | "list" | "get" | "fetch" | "retrieve" => QueryIntent::Find,
203 "show" | "display" | "view" | "describe" | "detail" | "details" => QueryIntent::Show,
204 "count" | "how" => {
205 if tokens.contains(&"many") || tokens.contains(&"count") {
206 QueryIntent::Count
207 } else {
208 QueryIntent::Find
209 }
210 }
211 "path" | "paths" | "route" | "reach" | "reachable" => QueryIntent::Path,
212 "is" | "are" | "does" | "can" | "check" => QueryIntent::Check,
213 "what" | "which" | "where" | "who" => {
214 QueryIntent::Find
216 }
217 _ => QueryIntent::Find,
218 }
219 }
220
221 fn extract_entities(text: &str) -> Vec<ExtractedEntity> {
223 let mut entities = Vec::new();
224 let mut alias_counter = 0;
225
226 let entity_patterns: Vec<(EntityType, &[&str], Option<&str>)> = vec![
228 (
229 EntityType::Host,
230 &[
231 "host", "hosts", "server", "servers", "machine", "machines", "ip", "ips",
232 ],
233 None,
234 ),
235 (EntityType::Service, &["service", "services"], None),
236 (EntityType::Port, &["port", "ports"], None),
237 (
238 EntityType::User,
239 &[
240 "user",
241 "users",
242 "account",
243 "accounts",
244 "username",
245 "usernames",
246 ],
247 None,
248 ),
249 (
250 EntityType::Credential,
251 &[
252 "credential",
253 "credentials",
254 "password",
255 "passwords",
256 "cred",
257 "creds",
258 ],
259 None,
260 ),
261 (
262 EntityType::Vulnerability,
263 &[
264 "vulnerability",
265 "vulnerabilities",
266 "vuln",
267 "vulns",
268 "cve",
269 "cves",
270 ],
271 None,
272 ),
273 (
274 EntityType::Technology,
275 &[
276 "technology",
277 "technologies",
278 "tech",
279 "software",
280 "application",
281 "applications",
282 ],
283 None,
284 ),
285 (
286 EntityType::Domain,
287 &["domain", "domains", "subdomain", "subdomains"],
288 None,
289 ),
290 (
291 EntityType::Certificate,
292 &["certificate", "certificates", "cert", "certs", "ssl", "tls"],
293 None,
294 ),
295 (
296 EntityType::Network,
297 &[
298 "network", "networks", "subnet", "subnets", "segment", "segments",
299 ],
300 None,
301 ),
302 ];
303
304 for (entity_type, keywords, _) in entity_patterns {
305 for keyword in keywords.iter() {
306 if text.contains(keyword) {
307 let value = Self::extract_entity_value(text, keyword);
309
310 entities.push(ExtractedEntity {
311 entity_type: entity_type.clone(),
312 value,
313 alias: format!("e{}", alias_counter),
314 });
315 alias_counter += 1;
316 break; }
318 }
319 }
320
321 for word in text.split_whitespace() {
323 if Self::is_ip_address(word) {
324 let already_has_host = entities
325 .iter()
326 .any(|e| e.entity_type == EntityType::Host && e.value.as_deref() == Some(word));
327 if !already_has_host {
328 entities.push(ExtractedEntity {
329 entity_type: EntityType::Host,
330 value: Some(word.to_string()),
331 alias: format!("e{}", alias_counter),
332 });
333 alias_counter += 1;
334 }
335 }
336 }
337
338 for word in text.split_whitespace() {
340 if word.starts_with("cve-") || word.starts_with("cve:") {
341 let cve = word
342 .replace("cve:", "CVE-")
343 .replace("cve-", "CVE-")
344 .to_uppercase();
345 entities.push(ExtractedEntity {
346 entity_type: EntityType::Vulnerability,
347 value: Some(cve),
348 alias: format!("e{}", alias_counter),
349 });
350 alias_counter += 1;
351 }
352 }
353
354 entities
355 }
356
357 fn extract_entity_value(text: &str, keyword: &str) -> Option<String> {
359 let parts: Vec<&str> = text.split_whitespace().collect();
361
362 for (i, part) in parts.iter().enumerate() {
363 if *part == keyword {
364 if let Some(next) = parts.get(i + 1) {
366 if ![
368 "with", "that", "has", "have", "is", "are", "the", "a", "an", "for", "on",
369 "in",
370 ]
371 .contains(next)
372 {
373 return Some(next.to_string());
374 }
375 if let Some(next2) = parts.get(i + 2) {
377 if ![
378 "with", "that", "has", "have", "is", "are", "the", "a", "an", "for",
379 "on", "in",
380 ]
381 .contains(next2)
382 {
383 return Some(next2.to_string());
384 }
385 }
386 }
387 }
388 }
389
390 None
391 }
392
393 fn is_ip_address(s: &str) -> bool {
395 let parts: Vec<&str> = s.split('.').collect();
396 if parts.len() != 4 {
397 return false;
398 }
399 parts.iter().all(|p| p.parse::<u8>().is_ok())
400 }
401
402 fn determine_entity_types(
404 entities: &[ExtractedEntity],
405 text: &str,
406 ) -> (Option<EntityType>, Option<EntityType>) {
407 if entities.is_empty() {
408 if text.contains("host") || text.contains("server") || text.contains("ip") {
410 return (Some(EntityType::Host), None);
411 }
412 if text.contains("vuln") || text.contains("cve") {
413 return (Some(EntityType::Vulnerability), None);
414 }
415 if text.contains("user") || text.contains("account") {
416 return (Some(EntityType::User), None);
417 }
418 if text.contains("cred") || text.contains("password") {
419 return (Some(EntityType::Credential), None);
420 }
421 if text.contains("service") {
422 return (Some(EntityType::Service), None);
423 }
424 return (None, None);
425 }
426
427 let primary = entities.first().map(|e| e.entity_type.clone());
428 let secondary = entities.get(1).map(|e| e.entity_type.clone());
429
430 (primary, secondary)
431 }
432
433 fn extract_filters(text: &str) -> Vec<PropertyFilter> {
435 let mut filters = Vec::new();
436
437 if text.contains("port") {
439 for word in text.split_whitespace() {
440 if let Ok(port) = word.parse::<u16>() {
441 if port > 0 {
442 filters.push(PropertyFilter {
444 property: "port".to_string(),
445 op: CompareOp::Eq,
446 value: port.to_string(),
447 });
448 }
449 }
450 }
451 }
452
453 let services = [
455 "ssh", "http", "https", "ftp", "smtp", "mysql", "postgres", "redis", "mongodb", "rdp",
456 "vnc",
457 ];
458 for svc in services {
459 if text.contains(svc) {
460 filters.push(PropertyFilter {
461 property: "service".to_string(),
462 op: CompareOp::Eq,
463 value: svc.to_string(),
464 });
465 }
466 }
467
468 if text.contains("critical") {
470 filters.push(PropertyFilter {
471 property: "severity".to_string(),
472 op: CompareOp::Eq,
473 value: "critical".to_string(),
474 });
475 } else if text.contains("high") {
476 filters.push(PropertyFilter {
477 property: "severity".to_string(),
478 op: CompareOp::Ge,
479 value: "7.0".to_string(),
480 });
481 } else if text.contains("medium") {
482 filters.push(PropertyFilter {
483 property: "severity".to_string(),
484 op: CompareOp::Ge,
485 value: "4.0".to_string(),
486 });
487 }
488
489 if text.contains("weak") && (text.contains("password") || text.contains("credential")) {
491 filters.push(PropertyFilter {
492 property: "strength".to_string(),
493 op: CompareOp::Eq,
494 value: "weak".to_string(),
495 });
496 }
497
498 if text.contains("open") || text.contains("exposed") || text.contains("public") {
500 filters.push(PropertyFilter {
501 property: "status".to_string(),
502 op: CompareOp::Eq,
503 value: "open".to_string(),
504 });
505 }
506
507 filters
508 }
509
510 fn detect_relationship(
512 text: &str,
513 primary: &Option<EntityType>,
514 secondary: &Option<EntityType>,
515 ) -> Option<RelationshipType> {
516 if text.contains("connects to") || text.contains("connected to") || text.contains("reach") {
518 return Some(RelationshipType::ConnectsTo);
519 }
520 if text.contains("affects") || text.contains("affected by") || text.contains("vulnerable") {
521 return Some(RelationshipType::Affects);
522 }
523 if text.contains("has access")
524 || text.contains("can access")
525 || text.contains("authenticate")
526 {
527 return Some(RelationshipType::AuthAccess);
528 }
529 if text.contains("runs on") || text.contains("running on") {
530 return Some(RelationshipType::RunsOn);
531 }
532 if text.contains("uses") || text.contains("using") {
533 return Some(RelationshipType::Uses);
534 }
535 if text.contains("exposes") || text.contains("exposing") {
536 return Some(RelationshipType::Exposes);
537 }
538
539 match (primary, secondary) {
541 (Some(EntityType::Host), Some(EntityType::Service)) => {
542 Some(RelationshipType::HasService)
543 }
544 (Some(EntityType::Host), Some(EntityType::Port)) => Some(RelationshipType::HasPort),
545 (Some(EntityType::Host), Some(EntityType::Vulnerability)) => {
546 Some(RelationshipType::HasVuln)
547 }
548 (Some(EntityType::User), Some(EntityType::Credential)) => {
549 Some(RelationshipType::HasCredential)
550 }
551 (Some(EntityType::Credential), Some(EntityType::Host)) => {
552 Some(RelationshipType::AuthAccess)
553 }
554 (Some(EntityType::Vulnerability), Some(EntityType::Host)) => {
555 Some(RelationshipType::Affects)
556 }
557 _ => None,
558 }
559 }
560
561 fn extract_limit(text: &str) -> Option<u64> {
563 let patterns = [("top ", 4), ("first ", 6), ("limit ", 6), ("show ", 5)];
564
565 for (pattern, skip) in patterns {
566 if let Some(pos) = text.find(pattern) {
567 let after = &text[pos + skip..];
568 let num_str: String = after.chars().take_while(|c| c.is_ascii_digit()).collect();
569 if let Ok(n) = num_str.parse::<u64>() {
570 return Some(n);
571 }
572 }
573 }
574
575 None
576 }
577}
578
579impl NaturalQuery {
580 pub fn to_query_expr(&self) -> QueryExpr {
582 let mut nodes = Vec::new();
583 let mut edges = Vec::new();
584 let mut filters = Vec::new();
585
586 for entity in &self.entities {
588 let node_type = match entity.entity_type {
589 EntityType::Host => Some("host".to_string()),
590 EntityType::Service => Some("service".to_string()),
591 EntityType::User => Some("user".to_string()),
592 EntityType::Credential => Some("credential".to_string()),
593 EntityType::Vulnerability => Some("vulnerability".to_string()),
594 EntityType::Technology => Some("technology".to_string()),
595 EntityType::Domain => Some("domain".to_string()),
596 EntityType::Certificate => Some("certificate".to_string()),
597 _ => None,
598 };
599
600 let mut properties: Vec<AstPropertyFilter> = Vec::new();
601 if let Some(ref value) = entity.value {
602 properties.push(AstPropertyFilter {
603 name: "id".to_string(),
604 op: CompareOp::Eq,
605 value: Value::text(value.clone()),
606 });
607 }
608
609 nodes.push(NodePattern {
610 alias: entity.alias.clone(),
611 node_label: node_type.clone(),
612 properties,
613 });
614 }
615
616 if let Some(ref relationship) = self.relationship {
621 if nodes.len() >= 2 {
622 let edge_label = Some(
623 match relationship {
624 RelationshipType::HasService => "has_service",
625 RelationshipType::HasPort => "has_endpoint",
626 RelationshipType::HasVuln => "affected_by",
627 RelationshipType::HasCredential => "auth_access",
628 RelationshipType::HasUser => "has_user",
629 RelationshipType::ConnectsTo => "connects_to",
630 RelationshipType::Affects => "affected_by",
631 RelationshipType::AuthAccess => "auth_access",
632 RelationshipType::Uses => "uses_tech",
633 RelationshipType::RunsOn => "contains",
634 RelationshipType::Exposes => "has_endpoint",
635 }
636 .to_string(),
637 );
638
639 edges.push(EdgePattern {
640 alias: None,
641 from: nodes[0].alias.clone(),
642 to: nodes[1].alias.clone(),
643 edge_label,
644 direction: EdgeDirection::Outgoing,
645 min_hops: 1,
646 max_hops: 1,
647 });
648 }
649 }
650
651 let current_alias = nodes
653 .first()
654 .map(|n| n.alias.clone())
655 .unwrap_or_else(|| "n0".to_string());
656 for filter in &self.filters {
657 filters.push(Filter::Compare {
658 field: FieldRef::NodeProperty {
659 alias: current_alias.clone(),
660 property: filter.property.clone(),
661 },
662 op: filter.op,
663 value: Value::text(filter.value.clone()),
664 });
665 }
666
667 let projections = match self.intent {
669 QueryIntent::Count => vec![Projection::Field(
670 FieldRef::NodeId {
671 alias: current_alias.clone(),
672 },
673 Some("count".to_string()),
674 )],
675 _ => vec![Projection::from_field(FieldRef::NodeId {
676 alias: current_alias.clone(),
677 })],
678 };
679
680 if nodes.is_empty() {
682 if let Some(ref entity_type) = self.primary_entity {
683 let node_label = match entity_type {
684 EntityType::Host => Some("host".to_string()),
685 EntityType::Service => Some("service".to_string()),
686 EntityType::User => Some("user".to_string()),
687 EntityType::Credential => Some("credential".to_string()),
688 EntityType::Vulnerability => Some("vulnerability".to_string()),
689 _ => None,
690 };
691
692 nodes.push(NodePattern {
693 alias: "n0".to_string(),
694 node_label,
695 properties: Vec::new(),
696 });
697 }
698 }
699
700 let combined_filter = if filters.is_empty() {
702 None
703 } else {
704 let mut iter = filters.into_iter();
705 let first = iter.next().unwrap();
706 Some(iter.fold(first, |acc, f| Filter::And(Box::new(acc), Box::new(f))))
707 };
708
709 QueryExpr::Graph(GraphQuery {
710 alias: None,
711 pattern: GraphPattern { nodes, edges },
712 filter: combined_filter,
713 return_: projections,
714 limit: self.limit,
715 })
716 }
717}
718
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, failing the test (and naming the input) if parsing errors.
    fn parsed(input: &str) -> NaturalQuery {
        match NaturalParser::parse(input) {
            Ok(query) => query,
            Err(err) => panic!("could not parse {:?}: {}", input, err),
        }
    }

    /// True if the query extracted at least one entity of type `wanted`.
    fn mentions(query: &NaturalQuery, wanted: EntityType) -> bool {
        query.entities.iter().any(|entity| entity.entity_type == wanted)
    }

    /// True if the query carries a filter `property` with exactly `value`.
    fn has_filter(query: &NaturalQuery, property: &str, value: &str) -> bool {
        query
            .filters
            .iter()
            .any(|f| f.property == property && f.value == value)
    }

    #[test]
    fn test_parse_find_hosts() {
        let query = parsed("find all hosts with ssh open");
        assert_eq!(query.intent, QueryIntent::Find);
        assert!(mentions(&query, EntityType::Host));
        assert!(has_filter(&query, "service", "ssh"));
    }

    #[test]
    fn test_parse_show_credentials() {
        let query = parsed("show me credentials for user admin");
        assert_eq!(query.intent, QueryIntent::Show);
        assert!(mentions(&query, EntityType::Credential));
        assert!(mentions(&query, EntityType::User));
    }

    #[test]
    fn test_parse_with_ip() {
        let query = parsed("what vulnerabilities affect host 10.0.0.1");
        let host_with_ip = query.entities.iter().any(|entity| {
            entity.entity_type == EntityType::Host && entity.value.as_deref() == Some("10.0.0.1")
        });
        assert!(host_with_ip);
        assert!(mentions(&query, EntityType::Vulnerability));
    }

    #[test]
    fn test_parse_count() {
        let query = parsed("how many hosts have port 22 open");
        assert_eq!(query.intent, QueryIntent::Count);
    }

    #[test]
    fn test_parse_weak_passwords() {
        let query = parsed("list users with weak passwords");
        assert!(has_filter(&query, "strength", "weak"));
    }

    #[test]
    fn test_parse_critical_vulns() {
        let query = parsed("show critical vulnerabilities");
        assert!(has_filter(&query, "severity", "critical"));
    }

    #[test]
    fn test_parse_quoted() {
        let query = parsed("\"find hosts connected to 10.0.0.1\"");
        assert_eq!(query.intent, QueryIntent::Find);
        assert_eq!(query.relationship, Some(RelationshipType::ConnectsTo));
    }

    #[test]
    fn test_parse_with_limit() {
        let query = parsed("show top 10 vulnerable hosts");
        assert_eq!(query.limit, Some(10));
    }

    #[test]
    fn test_to_query_expr() {
        let expr = parsed("find all hosts with ssh").to_query_expr();
        assert!(matches!(expr, QueryExpr::Graph(_)));
    }

    #[test]
    fn test_detect_relationship() {
        let query = parsed("credentials that can access host 10.0.0.1");
        assert_eq!(query.relationship, Some(RelationshipType::AuthAccess));
    }
}
807}