Skip to main content

reddb_server/storage/query/modes/
natural.rs

//! Natural Language Query Parser
//!
//! Translates natural language queries to graph patterns:
//! - "find all hosts with ssh open"
//! - "show me credentials for user admin"
//! - "what vulnerabilities affect host 10.0.0.1?"
//! - "list users with weak passwords"
//!
//! # Approach
//!
//! 1. Intent classification (find, show, count, path, check)
//! 2. Entity extraction (hosts, users, credentials, vulnerabilities)
//! 3. Property extraction (ip, name, port, cve)
//! 4. Relationship inference (connects, has, affects)
//! 5. Generate the equivalent graph query
16
17use crate::storage::query::ast::{
18    CompareOp, EdgeDirection, EdgePattern, FieldRef, Filter, GraphPattern, GraphQuery, NodePattern,
19    Projection, PropertyFilter as AstPropertyFilter, QueryExpr,
20};
21use crate::storage::schema::Value;
22
/// Error produced when a natural language query cannot be parsed.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare parse results
/// directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NaturalError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl std::fmt::Display for NaturalError {
    /// Renders the error as `Natural language error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Natural language error: {}", self.message)
    }
}

impl std::error::Error for NaturalError {}
36
37/// A parsed natural language query
38#[derive(Debug, Clone)]
39pub struct NaturalQuery {
40    /// The detected intent
41    pub intent: QueryIntent,
42    /// Primary entity type
43    pub primary_entity: Option<EntityType>,
44    /// Secondary entity (for relationships)
45    pub secondary_entity: Option<EntityType>,
46    /// Extracted entities with values
47    pub entities: Vec<ExtractedEntity>,
48    /// Property filters
49    pub filters: Vec<PropertyFilter>,
50    /// Relationship type (if any)
51    pub relationship: Option<RelationshipType>,
52    /// Limit on results
53    pub limit: Option<u64>,
54}
55
/// What the user is asking the query to do.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so intents can be used
/// as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum QueryIntent {
    /// Find/list entities
    Find,
    /// Show details
    Show,
    /// Count entities
    Count,
    /// Find path between entities
    Path,
    /// Check if relationship exists
    Check,
}
70
/// Entity types in the security domain.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so entity types can be
/// used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EntityType {
    Host,
    Service,
    Port,
    User,
    Credential,
    Vulnerability,
    Technology,
    Domain,
    Certificate,
    Network,
}
85
86/// An extracted entity mention
87#[derive(Debug, Clone)]
88pub struct ExtractedEntity {
89    pub entity_type: EntityType,
90    pub value: Option<String>,
91    pub alias: String,
92}
93
94/// Property filter from natural language
95#[derive(Debug, Clone)]
96pub struct PropertyFilter {
97    pub property: String,
98    pub op: CompareOp,
99    pub value: String,
100}
101
/// Relationship types that can link two entities.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so relationships can be
/// used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RelationshipType {
    HasService,
    HasPort,
    HasVuln,
    HasCredential,
    HasUser,
    ConnectsTo,
    Affects,
    AuthAccess,
    Uses,
    RunsOn,
    Exposes,
}
117
/// Natural language parser.
///
/// A stateless unit type: all functionality is exposed through associated
/// functions such as `NaturalParser::parse`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NaturalParser;
120
121impl NaturalParser {
122    /// Parse a natural language query
123    pub fn parse(input: &str) -> Result<NaturalQuery, NaturalError> {
124        let text = Self::normalize(input);
125        let tokens: Vec<&str> = text.split_whitespace().collect();
126
127        if tokens.is_empty() {
128            return Err(NaturalError {
129                message: "Empty query".to_string(),
130            });
131        }
132
133        // Detect intent
134        let intent = Self::detect_intent(&tokens);
135
136        // Extract entities
137        let entities = Self::extract_entities(&text);
138
139        // Determine primary and secondary entity types
140        let (primary, secondary) = Self::determine_entity_types(&entities, &text);
141
142        // Extract property filters
143        let filters = Self::extract_filters(&text);
144
145        // Detect relationship
146        let relationship = Self::detect_relationship(&text, &primary, &secondary);
147
148        // Extract limit
149        let limit = Self::extract_limit(&text);
150
151        Ok(NaturalQuery {
152            intent,
153            primary_entity: primary,
154            secondary_entity: secondary,
155            entities,
156            filters,
157            relationship,
158            limit,
159        })
160    }
161
162    /// Normalize input text
163    fn normalize(input: &str) -> String {
164        // Remove quotes if present
165        let trimmed = input.trim();
166        let unquoted = if (trimmed.starts_with('"') && trimmed.ends_with('"'))
167            || (trimmed.starts_with('\'') && trimmed.ends_with('\''))
168        {
169            &trimmed[1..trimmed.len() - 1]
170        } else {
171            trimmed
172        };
173
174        // Convert to lowercase and remove punctuation (except relevant chars)
175        unquoted
176            .to_lowercase()
177            .chars()
178            .map(|c| {
179                if c.is_alphanumeric()
180                    || c.is_whitespace()
181                    || c == '.'
182                    || c == ':'
183                    || c == '-'
184                    || c == '_'
185                {
186                    c
187                } else {
188                    ' '
189                }
190            })
191            .collect::<String>()
192            .split_whitespace()
193            .collect::<Vec<_>>()
194            .join(" ")
195    }
196
197    /// Detect query intent from tokens
198    fn detect_intent(tokens: &[&str]) -> QueryIntent {
199        let first = tokens.first().copied().unwrap_or("");
200
201        match first {
202            "find" | "search" | "list" | "get" | "fetch" | "retrieve" => QueryIntent::Find,
203            "show" | "display" | "view" | "describe" | "detail" | "details" => QueryIntent::Show,
204            "count" | "how" => {
205                if tokens.contains(&"many") || tokens.contains(&"count") {
206                    QueryIntent::Count
207                } else {
208                    QueryIntent::Find
209                }
210            }
211            "path" | "paths" | "route" | "reach" | "reachable" => QueryIntent::Path,
212            "is" | "are" | "does" | "can" | "check" => QueryIntent::Check,
213            "what" | "which" | "where" | "who" => {
214                // Question words usually mean find
215                QueryIntent::Find
216            }
217            _ => QueryIntent::Find,
218        }
219    }
220
221    /// Extract entities from text
222    fn extract_entities(text: &str) -> Vec<ExtractedEntity> {
223        let mut entities = Vec::new();
224        let mut alias_counter = 0;
225
226        // Entity patterns with regex-like matching
227        let entity_patterns: Vec<(EntityType, &[&str], Option<&str>)> = vec![
228            (
229                EntityType::Host,
230                &[
231                    "host", "hosts", "server", "servers", "machine", "machines", "ip", "ips",
232                ],
233                None,
234            ),
235            (EntityType::Service, &["service", "services"], None),
236            (EntityType::Port, &["port", "ports"], None),
237            (
238                EntityType::User,
239                &[
240                    "user",
241                    "users",
242                    "account",
243                    "accounts",
244                    "username",
245                    "usernames",
246                ],
247                None,
248            ),
249            (
250                EntityType::Credential,
251                &[
252                    "credential",
253                    "credentials",
254                    "password",
255                    "passwords",
256                    "cred",
257                    "creds",
258                ],
259                None,
260            ),
261            (
262                EntityType::Vulnerability,
263                &[
264                    "vulnerability",
265                    "vulnerabilities",
266                    "vuln",
267                    "vulns",
268                    "cve",
269                    "cves",
270                ],
271                None,
272            ),
273            (
274                EntityType::Technology,
275                &[
276                    "technology",
277                    "technologies",
278                    "tech",
279                    "software",
280                    "application",
281                    "applications",
282                ],
283                None,
284            ),
285            (
286                EntityType::Domain,
287                &["domain", "domains", "subdomain", "subdomains"],
288                None,
289            ),
290            (
291                EntityType::Certificate,
292                &["certificate", "certificates", "cert", "certs", "ssl", "tls"],
293                None,
294            ),
295            (
296                EntityType::Network,
297                &[
298                    "network", "networks", "subnet", "subnets", "segment", "segments",
299                ],
300                None,
301            ),
302        ];
303
304        for (entity_type, keywords, _) in entity_patterns {
305            for keyword in keywords.iter() {
306                if text.contains(keyword) {
307                    // Try to extract associated value
308                    let value = Self::extract_entity_value(text, keyword);
309
310                    entities.push(ExtractedEntity {
311                        entity_type: entity_type.clone(),
312                        value,
313                        alias: format!("e{}", alias_counter),
314                    });
315                    alias_counter += 1;
316                    break; // Only add once per entity type
317                }
318            }
319        }
320
321        // Extract IP addresses
322        for word in text.split_whitespace() {
323            if Self::is_ip_address(word) {
324                let already_has_host = entities
325                    .iter()
326                    .any(|e| e.entity_type == EntityType::Host && e.value.as_deref() == Some(word));
327                if !already_has_host {
328                    entities.push(ExtractedEntity {
329                        entity_type: EntityType::Host,
330                        value: Some(word.to_string()),
331                        alias: format!("e{}", alias_counter),
332                    });
333                    alias_counter += 1;
334                }
335            }
336        }
337
338        // Extract CVE IDs
339        for word in text.split_whitespace() {
340            if word.starts_with("cve-") || word.starts_with("cve:") {
341                let cve = word
342                    .replace("cve:", "CVE-")
343                    .replace("cve-", "CVE-")
344                    .to_uppercase();
345                entities.push(ExtractedEntity {
346                    entity_type: EntityType::Vulnerability,
347                    value: Some(cve),
348                    alias: format!("e{}", alias_counter),
349                });
350                alias_counter += 1;
351            }
352        }
353
354        entities
355    }
356
357    /// Extract value associated with an entity keyword
358    fn extract_entity_value(text: &str, keyword: &str) -> Option<String> {
359        // Look for patterns like "host 10.0.0.1" or "user admin"
360        let parts: Vec<&str> = text.split_whitespace().collect();
361
362        for (i, part) in parts.iter().enumerate() {
363            if *part == keyword {
364                // Check next word
365                if let Some(next) = parts.get(i + 1) {
366                    // Skip common words
367                    if ![
368                        "with", "that", "has", "have", "is", "are", "the", "a", "an", "for", "on",
369                        "in",
370                    ]
371                    .contains(next)
372                    {
373                        return Some(next.to_string());
374                    }
375                    // Check word after that
376                    if let Some(next2) = parts.get(i + 2) {
377                        if ![
378                            "with", "that", "has", "have", "is", "are", "the", "a", "an", "for",
379                            "on", "in",
380                        ]
381                        .contains(next2)
382                        {
383                            return Some(next2.to_string());
384                        }
385                    }
386                }
387            }
388        }
389
390        None
391    }
392
393    /// Check if a string looks like an IP address
394    fn is_ip_address(s: &str) -> bool {
395        let parts: Vec<&str> = s.split('.').collect();
396        if parts.len() != 4 {
397            return false;
398        }
399        parts.iter().all(|p| p.parse::<u8>().is_ok())
400    }
401
402    /// Determine primary and secondary entity types
403    fn determine_entity_types(
404        entities: &[ExtractedEntity],
405        text: &str,
406    ) -> (Option<EntityType>, Option<EntityType>) {
407        if entities.is_empty() {
408            // Infer from text
409            if text.contains("host") || text.contains("server") || text.contains("ip") {
410                return (Some(EntityType::Host), None);
411            }
412            if text.contains("vuln") || text.contains("cve") {
413                return (Some(EntityType::Vulnerability), None);
414            }
415            if text.contains("user") || text.contains("account") {
416                return (Some(EntityType::User), None);
417            }
418            if text.contains("cred") || text.contains("password") {
419                return (Some(EntityType::Credential), None);
420            }
421            if text.contains("service") {
422                return (Some(EntityType::Service), None);
423            }
424            return (None, None);
425        }
426
427        let primary = entities.first().map(|e| e.entity_type.clone());
428        let secondary = entities.get(1).map(|e| e.entity_type.clone());
429
430        (primary, secondary)
431    }
432
433    /// Extract property filters from text
434    fn extract_filters(text: &str) -> Vec<PropertyFilter> {
435        let mut filters = Vec::new();
436
437        // Port number patterns
438        if text.contains("port") {
439            for word in text.split_whitespace() {
440                if let Ok(port) = word.parse::<u16>() {
441                    if port > 0 {
442                        // u16 already constrains to 0-65535
443                        filters.push(PropertyFilter {
444                            property: "port".to_string(),
445                            op: CompareOp::Eq,
446                            value: port.to_string(),
447                        });
448                    }
449                }
450            }
451        }
452
453        // Common service names
454        let services = [
455            "ssh", "http", "https", "ftp", "smtp", "mysql", "postgres", "redis", "mongodb", "rdp",
456            "vnc",
457        ];
458        for svc in services {
459            if text.contains(svc) {
460                filters.push(PropertyFilter {
461                    property: "service".to_string(),
462                    op: CompareOp::Eq,
463                    value: svc.to_string(),
464                });
465            }
466        }
467
468        // Critical/high/medium/low severity
469        if text.contains("critical") {
470            filters.push(PropertyFilter {
471                property: "severity".to_string(),
472                op: CompareOp::Eq,
473                value: "critical".to_string(),
474            });
475        } else if text.contains("high") {
476            filters.push(PropertyFilter {
477                property: "severity".to_string(),
478                op: CompareOp::Ge,
479                value: "7.0".to_string(),
480            });
481        } else if text.contains("medium") {
482            filters.push(PropertyFilter {
483                property: "severity".to_string(),
484                op: CompareOp::Ge,
485                value: "4.0".to_string(),
486            });
487        }
488
489        // Weak passwords
490        if text.contains("weak") && (text.contains("password") || text.contains("credential")) {
491            filters.push(PropertyFilter {
492                property: "strength".to_string(),
493                op: CompareOp::Eq,
494                value: "weak".to_string(),
495            });
496        }
497
498        // Open/exposed
499        if text.contains("open") || text.contains("exposed") || text.contains("public") {
500            filters.push(PropertyFilter {
501                property: "status".to_string(),
502                op: CompareOp::Eq,
503                value: "open".to_string(),
504            });
505        }
506
507        filters
508    }
509
510    /// Detect relationship type from text
511    fn detect_relationship(
512        text: &str,
513        primary: &Option<EntityType>,
514        secondary: &Option<EntityType>,
515    ) -> Option<RelationshipType> {
516        // Explicit relationship keywords
517        if text.contains("connects to") || text.contains("connected to") || text.contains("reach") {
518            return Some(RelationshipType::ConnectsTo);
519        }
520        if text.contains("affects") || text.contains("affected by") || text.contains("vulnerable") {
521            return Some(RelationshipType::Affects);
522        }
523        if text.contains("has access")
524            || text.contains("can access")
525            || text.contains("authenticate")
526        {
527            return Some(RelationshipType::AuthAccess);
528        }
529        if text.contains("runs on") || text.contains("running on") {
530            return Some(RelationshipType::RunsOn);
531        }
532        if text.contains("uses") || text.contains("using") {
533            return Some(RelationshipType::Uses);
534        }
535        if text.contains("exposes") || text.contains("exposing") {
536            return Some(RelationshipType::Exposes);
537        }
538
539        // Infer from entity types
540        match (primary, secondary) {
541            (Some(EntityType::Host), Some(EntityType::Service)) => {
542                Some(RelationshipType::HasService)
543            }
544            (Some(EntityType::Host), Some(EntityType::Port)) => Some(RelationshipType::HasPort),
545            (Some(EntityType::Host), Some(EntityType::Vulnerability)) => {
546                Some(RelationshipType::HasVuln)
547            }
548            (Some(EntityType::User), Some(EntityType::Credential)) => {
549                Some(RelationshipType::HasCredential)
550            }
551            (Some(EntityType::Credential), Some(EntityType::Host)) => {
552                Some(RelationshipType::AuthAccess)
553            }
554            (Some(EntityType::Vulnerability), Some(EntityType::Host)) => {
555                Some(RelationshipType::Affects)
556            }
557            _ => None,
558        }
559    }
560
561    /// Extract limit from text
562    fn extract_limit(text: &str) -> Option<u64> {
563        let patterns = [("top ", 4), ("first ", 6), ("limit ", 6), ("show ", 5)];
564
565        for (pattern, skip) in patterns {
566            if let Some(pos) = text.find(pattern) {
567                let after = &text[pos + skip..];
568                let num_str: String = after.chars().take_while(|c| c.is_ascii_digit()).collect();
569                if let Ok(n) = num_str.parse::<u64>() {
570                    return Some(n);
571                }
572            }
573        }
574
575        None
576    }
577}
578
579impl NaturalQuery {
580    /// Convert to QueryExpr
581    pub fn to_query_expr(&self) -> QueryExpr {
582        let mut nodes = Vec::new();
583        let mut edges = Vec::new();
584        let mut filters = Vec::new();
585
586        // Create nodes from extracted entities
587        for entity in &self.entities {
588            let node_type = match entity.entity_type {
589                EntityType::Host => Some("host".to_string()),
590                EntityType::Service => Some("service".to_string()),
591                EntityType::User => Some("user".to_string()),
592                EntityType::Credential => Some("credential".to_string()),
593                EntityType::Vulnerability => Some("vulnerability".to_string()),
594                EntityType::Technology => Some("technology".to_string()),
595                EntityType::Domain => Some("domain".to_string()),
596                EntityType::Certificate => Some("certificate".to_string()),
597                _ => None,
598            };
599
600            let mut properties: Vec<AstPropertyFilter> = Vec::new();
601            if let Some(ref value) = entity.value {
602                properties.push(AstPropertyFilter {
603                    name: "id".to_string(),
604                    op: CompareOp::Eq,
605                    value: Value::text(value.clone()),
606                });
607            }
608
609            nodes.push(NodePattern {
610                alias: entity.alias.clone(),
611                node_label: node_type.clone(),
612                properties,
613            });
614        }
615
616        // Add edges based on relationships. Map the natural-language
617        // relationship enum to the canonical edge label string used by
618        // the legacy reserved range; users can introduce new relationship
619        // types by extending this match.
620        if let Some(ref relationship) = self.relationship {
621            if nodes.len() >= 2 {
622                let edge_label = Some(
623                    match relationship {
624                        RelationshipType::HasService => "has_service",
625                        RelationshipType::HasPort => "has_endpoint",
626                        RelationshipType::HasVuln => "affected_by",
627                        RelationshipType::HasCredential => "auth_access",
628                        RelationshipType::HasUser => "has_user",
629                        RelationshipType::ConnectsTo => "connects_to",
630                        RelationshipType::Affects => "affected_by",
631                        RelationshipType::AuthAccess => "auth_access",
632                        RelationshipType::Uses => "uses_tech",
633                        RelationshipType::RunsOn => "contains",
634                        RelationshipType::Exposes => "has_endpoint",
635                    }
636                    .to_string(),
637                );
638
639                edges.push(EdgePattern {
640                    alias: None,
641                    from: nodes[0].alias.clone(),
642                    to: nodes[1].alias.clone(),
643                    edge_label,
644                    direction: EdgeDirection::Outgoing,
645                    min_hops: 1,
646                    max_hops: 1,
647                });
648            }
649        }
650
651        // Convert property filters
652        let current_alias = nodes
653            .first()
654            .map(|n| n.alias.clone())
655            .unwrap_or_else(|| "n0".to_string());
656        for filter in &self.filters {
657            filters.push(Filter::Compare {
658                field: FieldRef::NodeProperty {
659                    alias: current_alias.clone(),
660                    property: filter.property.clone(),
661                },
662                op: filter.op,
663                value: Value::text(filter.value.clone()),
664            });
665        }
666
667        // Build projections based on intent
668        let projections = match self.intent {
669            QueryIntent::Count => vec![Projection::Field(
670                FieldRef::NodeId {
671                    alias: current_alias.clone(),
672                },
673                Some("count".to_string()),
674            )],
675            _ => vec![Projection::from_field(FieldRef::NodeId {
676                alias: current_alias.clone(),
677            })],
678        };
679
680        // If no nodes were created, create a default based on primary entity
681        if nodes.is_empty() {
682            if let Some(ref entity_type) = self.primary_entity {
683                let node_label = match entity_type {
684                    EntityType::Host => Some("host".to_string()),
685                    EntityType::Service => Some("service".to_string()),
686                    EntityType::User => Some("user".to_string()),
687                    EntityType::Credential => Some("credential".to_string()),
688                    EntityType::Vulnerability => Some("vulnerability".to_string()),
689                    _ => None,
690                };
691
692                nodes.push(NodePattern {
693                    alias: "n0".to_string(),
694                    node_label,
695                    properties: Vec::new(),
696                });
697            }
698        }
699
700        // Fold multiple filters into nested And
701        let combined_filter = if filters.is_empty() {
702            None
703        } else {
704            let mut iter = filters.into_iter();
705            let first = iter.next().unwrap();
706            Some(iter.fold(first, |acc, f| Filter::And(Box::new(acc), Box::new(f))))
707        };
708
709        QueryExpr::Graph(GraphQuery {
710            alias: None,
711            pattern: GraphPattern { nodes, edges },
712            filter: combined_filter,
713            return_: projections,
714        })
715    }
716}
717
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a query that is expected to be valid.
    fn parsed(query: &str) -> NaturalQuery {
        NaturalParser::parse(query).unwrap()
    }

    /// Whether the query mentions an entity of the given type.
    fn mentions(q: &NaturalQuery, entity_type: EntityType) -> bool {
        q.entities.iter().any(|e| e.entity_type == entity_type)
    }

    /// Whether the query carries a filter on `property` with exactly `value`.
    fn filters_on(q: &NaturalQuery, property: &str, value: &str) -> bool {
        q.filters
            .iter()
            .any(|f| f.property == property && f.value == value)
    }

    #[test]
    fn test_parse_find_hosts() {
        let q = parsed("find all hosts with ssh open");
        assert_eq!(q.intent, QueryIntent::Find);
        assert!(mentions(&q, EntityType::Host));
        assert!(filters_on(&q, "service", "ssh"));
    }

    #[test]
    fn test_parse_show_credentials() {
        let q = parsed("show me credentials for user admin");
        assert_eq!(q.intent, QueryIntent::Show);
        assert!(mentions(&q, EntityType::Credential));
        assert!(mentions(&q, EntityType::User));
    }

    #[test]
    fn test_parse_with_ip() {
        let q = parsed("what vulnerabilities affect host 10.0.0.1");
        let host_with_ip = q
            .entities
            .iter()
            .any(|e| e.entity_type == EntityType::Host && e.value.as_deref() == Some("10.0.0.1"));
        assert!(host_with_ip);
        assert!(mentions(&q, EntityType::Vulnerability));
    }

    #[test]
    fn test_parse_count() {
        let q = parsed("how many hosts have port 22 open");
        assert_eq!(q.intent, QueryIntent::Count);
    }

    #[test]
    fn test_parse_weak_passwords() {
        let q = parsed("list users with weak passwords");
        assert!(filters_on(&q, "strength", "weak"));
    }

    #[test]
    fn test_parse_critical_vulns() {
        let q = parsed("show critical vulnerabilities");
        assert!(filters_on(&q, "severity", "critical"));
    }

    #[test]
    fn test_parse_quoted() {
        // Surrounding double quotes must be ignored by the parser.
        let q = parsed("\"find hosts connected to 10.0.0.1\"");
        assert_eq!(q.intent, QueryIntent::Find);
        assert!(q.relationship == Some(RelationshipType::ConnectsTo));
    }

    #[test]
    fn test_parse_with_limit() {
        let q = parsed("show top 10 vulnerable hosts");
        assert_eq!(q.limit, Some(10));
    }

    #[test]
    fn test_to_query_expr() {
        let expr = parsed("find all hosts with ssh").to_query_expr();
        assert!(matches!(expr, QueryExpr::Graph(_)));
    }

    #[test]
    fn test_detect_relationship() {
        let q = parsed("credentials that can access host 10.0.0.1");
        assert_eq!(q.relationship, Some(RelationshipType::AuthAccess));
    }
}