1pub mod detect;
23pub mod gremlin;
24pub mod natural;
25pub mod sparql;
26
27pub use detect::{detect_mode, QueryMode};
28pub use gremlin::{GremlinParser, GremlinStep, GremlinTraversal};
29pub use natural::{NaturalParser, NaturalQuery, QueryIntent};
30pub use sparql::{SparqlParser, SparqlQuery, TriplePattern};
31
32use crate::ast::QueryExpr;
33
34pub fn parse_multi(input: &str) -> Result<QueryExpr, MultiParseError> {
36 let mode = detect_mode(input);
37
38 match mode {
39 QueryMode::Sql | QueryMode::Cypher | QueryMode::Path => {
40 crate::parser::parse(input)
46 .map(|q| q.query)
47 .map_err(|e| MultiParseError::Parse(e.to_string()))
48 }
49 QueryMode::Gremlin => {
50 let traversal = GremlinParser::parse(input)?;
51 Ok(traversal.to_query_expr())
52 }
53 QueryMode::Sparql => {
54 let sparql = SparqlParser::parse(input)?;
55 Ok(sparql.to_query_expr())
56 }
57 QueryMode::Natural => {
58 let natural = NaturalParser::parse(input)?;
59 Ok(natural.to_query_expr())
60 }
61 QueryMode::Unknown => Err(MultiParseError::UnknownMode(input.to_string())),
62 }
63}
64
/// Error returned by [`parse_multi`], tagged with the query front-end that failed.
///
/// Every variant carries the underlying error message as a `String`. The type
/// derives `PartialEq`/`Eq` so callers can compare errors directly instead of
/// pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultiParseError {
    /// Failure reported by `crate::parser::parse` (SQL, Cypher and PATH queries).
    Parse(String),
    /// Failure converted from a `gremlin::GremlinError`.
    Gremlin(String),
    /// Failure converted from a `sparql::SparqlError`.
    Sparql(String),
    /// Failure converted from a `natural::NaturalError`.
    Natural(String),
    /// No query language could be detected; holds the original input verbatim.
    UnknownMode(String),
}

impl std::fmt::Display for MultiParseError {
    /// Renders the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Parse(e) => write!(f, "Parse error: {}", e),
            Self::Gremlin(e) => write!(f, "Gremlin error: {}", e),
            Self::Sparql(e) => write!(f, "SPARQL error: {}", e),
            Self::Natural(e) => write!(f, "Natural language error: {}", e),
            Self::UnknownMode(q) => write!(f, "Unknown query mode for: {}", q),
        }
    }
}

impl std::error::Error for MultiParseError {}
88
89impl From<gremlin::GremlinError> for MultiParseError {
90 fn from(e: gremlin::GremlinError) -> Self {
91 Self::Gremlin(e.to_string())
92 }
93}
94
95impl From<sparql::SparqlError> for MultiParseError {
96 fn from(e: sparql::SparqlError) -> Self {
97 Self::Sparql(e.to_string())
98 }
99}
100
101impl From<natural::NaturalError> for MultiParseError {
102 fn from(e: natural::NaturalError) -> Self {
103 Self::Natural(e.to_string())
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `detect_mode` classifies every entry of `queries` as `expected`.
    fn assert_all_detected_as(expected: QueryMode, queries: &[&str]) {
        for &query in queries {
            assert_eq!(
                detect_mode(query),
                expected,
                "unexpected mode for {:?}",
                query
            );
        }
    }

    #[test]
    fn test_detect_sql() {
        assert_all_detected_as(
            QueryMode::Sql,
            &["SELECT * FROM users", "select name from hosts"],
        );
    }

    #[test]
    fn test_detect_gremlin() {
        assert_all_detected_as(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().has('name', 'alice')",
                "__.out('knows')",
            ],
        );
    }

    #[test]
    fn test_detect_cypher() {
        assert_all_detected_as(
            QueryMode::Cypher,
            &["MATCH (a)-[r]->(b) RETURN a", "match (n:Host) return n"],
        );
    }

    #[test]
    fn test_detect_sparql() {
        assert_all_detected_as(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x",
            ],
        );
    }

    #[test]
    fn test_detect_path() {
        assert_all_detected_as(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM user('admin') TO credential('root')",
            ],
        );
    }

    #[test]
    fn test_detect_natural() {
        assert_all_detected_as(
            QueryMode::Natural,
            &[
                "find all hosts with ssh",
                "show me credentials for user admin",
                "\"what vulnerabilities affect host 10.0.0.1?\"",
            ],
        );
    }

    #[test]
    fn parse_multi_routes_supported_modes_to_query_exprs() {
        let table = parse_multi("SELECT * FROM hosts").expect("sql");
        assert!(matches!(table, QueryExpr::Table(_)));

        // Every non-SQL front-end in this list lowers to a graph expression.
        let graph_cases = [
            ("g.V().hasLabel('Host').limit(2)", "gremlin"),
            ("SELECT ?s WHERE { ?s :name 'alice' }", "sparql"),
            ("find all hosts with ssh", "natural"),
        ];
        for (query, label) in graph_cases {
            let expr = parse_multi(query).expect(label);
            assert!(
                matches!(expr, QueryExpr::Graph(_)),
                "{} query should lower to a graph expression",
                label
            );
        }
    }

    #[test]
    fn parse_multi_surfaces_parse_and_unknown_errors() {
        let bad_sql = parse_multi("SELECT * FROM").expect_err("bad SQL should fail");
        assert!(matches!(bad_sql, MultiParseError::Parse(_)));
        assert!(bad_sql.to_string().starts_with("Parse error:"));

        let empty = parse_multi("").expect_err("empty should be unknown");
        assert!(matches!(&empty, MultiParseError::UnknownMode(q) if q.is_empty()));
        assert_eq!(empty.to_string(), "Unknown query mode for: ");
    }

    #[test]
    fn multi_parse_error_display_covers_all_variants() {
        // (variant constructor, payload, expected rendering)
        let cases: [(fn(String) -> MultiParseError, &str, &str); 5] = [
            (MultiParseError::Parse, "bad", "Parse error: bad"),
            (MultiParseError::Gremlin, "bad", "Gremlin error: bad"),
            (MultiParseError::Sparql, "bad", "SPARQL error: bad"),
            (MultiParseError::Natural, "bad", "Natural language error: bad"),
            (MultiParseError::UnknownMode, "???", "Unknown query mode for: ???"),
        ];

        for (make, payload, expected) in cases {
            assert_eq!(make(payload.to_string()).to_string(), expected);
        }
    }
}