/// Query language (or input style) that [`detect_mode`] assigns to a raw query string.
///
/// Variant order is significant for `as` casts and is kept stable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryMode {
    /// SQL plus the engine's extended command set (vector, graph, queue,
    /// vault, config, `SHOW …`, transaction control, …).
    Sql,
    /// Gremlin traversals, i.e. input starting with `g.` or `__.`.
    Gremlin,
    /// Cypher pattern queries, i.e. input starting with `MATCH`.
    Cypher,
    /// SPARQL: a `PREFIX` declaration or SPARQL-looking syntax such as `?var`.
    Sparql,
    /// Path-finding queries starting with `PATH` or `PATHS`.
    Path,
    /// Free-text questions, or input that starts with a quote character.
    Natural,
    /// None of the detection heuristics matched.
    Unknown,
}
23
24pub fn detect_mode(input: &str) -> QueryMode {
26 let trimmed = input.trim();
27 let lower = trimmed.to_lowercase();
28
29 if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31 return QueryMode::Natural;
32 }
33
34 if lower.starts_with("g.") || lower.starts_with("__.") {
36 return QueryMode::Gremlin;
37 }
38
39 if lower.starts_with("path ") || lower.starts_with("paths ") {
41 return QueryMode::Path;
42 }
43
44 if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46 return QueryMode::Sparql;
47 }
48
49 if lower.starts_with("match ") || lower.starts_with("match(") {
51 return QueryMode::Cypher;
52 }
53
54 let first_token = lower.split_whitespace().next().unwrap_or("");
59 if matches!(
60 first_token,
61 "begin"
62 | "start"
63 | "commit"
64 | "rollback"
65 | "savepoint"
66 | "release"
67 | "end"
68 | "vacuum"
69 | "analyze"
70 | "reset"
71 | "copy"
72 | "refresh"
73 | "explain"
74 | "grant"
75 | "revoke"
76 | "attach"
77 | "detach"
78 | "simulate"
79 | "lint"
80 | "migrate"
81 | "apply"
82 | "events"
83 | "describe"
84 | "desc"
85 ) {
86 return QueryMode::Sql;
87 }
88 if lower.starts_with("select ")
89 || lower.starts_with("from ")
90 || lower.starts_with("insert ")
91 || lower.starts_with("update ")
92 || lower.starts_with("delete ")
93 || lower.starts_with("truncate ")
94 || lower.starts_with("create ")
95 || lower.starts_with("drop ")
96 || lower.starts_with("alter ")
97 || lower.starts_with("vector ")
98 || lower.starts_with("hybrid ")
99 || lower.starts_with("graph ")
100 || lower.starts_with("queue ")
101 || lower.starts_with("events ")
102 || lower.starts_with("tree ")
103 || lower.starts_with("hll ")
104 || lower.starts_with("sketch ")
105 || lower.starts_with("filter ")
106 || lower.starts_with("vault ")
107 || lower.starts_with("unseal vault ")
108 || lower.starts_with("rotate vault ")
109 || lower.starts_with("history vault ")
110 || lower.starts_with("list vault ")
111 || lower.starts_with("list kv ")
112 || lower.starts_with("watch vault ")
113 || lower.starts_with("delete vault ")
114 || lower.starts_with("purge vault ")
115 || lower.starts_with("search ")
116 || lower.starts_with("ask ")
117 || lower.starts_with("put config ")
118 || lower.starts_with("get config ")
119 || lower.starts_with("resolve config ")
120 || lower.starts_with("rotate config ")
121 || lower.starts_with("delete config ")
122 || lower.starts_with("history config ")
123 || lower.starts_with("list config ")
124 || lower.starts_with("watch config ")
125 || lower.starts_with("incr config ")
126 || lower.starts_with("decr config ")
127 || lower.starts_with("add config ")
128 || lower.starts_with("invalidate config ")
129 || lower.starts_with("invalidate tags ")
130 || lower.starts_with("set config ")
131 || lower.starts_with("set secret ")
132 || lower.starts_with("set tenant")
133 || lower.starts_with("show create ")
134 || lower.starts_with("show config")
135 || lower.starts_with("show collections")
136 || lower.starts_with("show tables")
137 || lower.starts_with("show queues")
138 || lower.starts_with("show vectors")
139 || lower.starts_with("show documents")
140 || lower.starts_with("show timeseries")
141 || lower.starts_with("show graphs")
142 || lower.starts_with("kv ")
143 || lower.starts_with("show kv")
144 || lower.starts_with("show configs")
145 || lower.starts_with("show vaults")
146 || lower.starts_with("show schema")
147 || lower.starts_with("show indices")
148 || lower.starts_with("show indexes")
149 || lower.starts_with("show sample ")
150 || lower.starts_with("show secret")
151 || lower.starts_with("show stats")
152 || lower.starts_with("show tenant")
153 || lower.starts_with("show policies")
154 || lower.starts_with("show effective ")
155 || lower.starts_with("rank of ")
156 || lower.starts_with("rank range ")
157 || lower.starts_with("approx rank of ")
158 || lower.starts_with("approximate rank of ")
159 || lower.starts_with("zrank ")
160 || lower.starts_with("zrange ")
161 || lower.starts_with("describe ")
162 || lower.starts_with("desc ")
163 {
164 if lower.starts_with("select ") && has_sparql_variable(&lower) {
168 return QueryMode::Sparql;
169 }
170 return QueryMode::Sql;
171 }
172
173 if is_natural_language(&lower) {
175 return QueryMode::Natural;
176 }
177
178 QueryMode::Unknown
179}
180
181fn has_sparql_pattern(lower: &str) -> bool {
183 let has_var = has_sparql_variable(lower);
189
190 let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
192
193 let has_prefix_pattern = lower.contains(":")
195 && (lower.contains(":<")
196 || lower.contains("> :")
197 || lower.contains(" :") && lower.contains("?"));
198
199 has_var || has_triple_pattern || has_prefix_pattern
200}
201
202fn has_sparql_variable(input: &str) -> bool {
203 let bytes = input.as_bytes();
204 bytes
205 .windows(2)
206 .any(|pair| pair[0] == b'?' && is_sparql_variable_start(pair[1]))
207}
208
/// Reports whether `byte` may begin a SPARQL variable name right after the
/// `?` sigil: an ASCII letter or an underscore. Digits are rejected, which
/// keeps SQL numbered placeholders such as `?1` from looking like variables.
fn is_sparql_variable_start(byte: u8) -> bool {
    matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
212
/// Returns `true` when `lower` (already lowercased by the caller) reads like
/// a free-text question rather than a formal query language.
///
/// Either signal is enough:
/// * the text opens with a question or command word such as `find `,
///   `show `, `what ` or `list `; or
/// * at least two entries from a set of connective words (` with `, ` the `,
///   …) and domain keywords (`host`, `password`, `vulnerable`, …) occur
///   anywhere in it. Keywords match as substrings, so `users` counts as `user`.
fn is_natural_language(lower: &str) -> bool {
    const OPENERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];
    const HINTS: &[&str] = &[
        // Connective words, padded with spaces so they only match whole words.
        " with ", " for ", " that ", " have ", " has ", " can ", " are ", " is ", " all ",
        " me ", " the ", " from ", " to ", " on ", " in ",
        // Domain vocabulary, matched as bare substrings.
        "vulnerable", "credential", "password", "user", "host", "service", "connected",
        "reachable", "exposed", "critical",
    ];

    if OPENERS.iter().any(|opener| lower.starts_with(*opener)) {
        return true;
    }

    let mut hits = 0usize;
    for hint in HINTS {
        if lower.contains(*hint) {
            hits += 1;
            if hits == 2 {
                return true;
            }
        }
    }
    false
}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every input in `inputs` is classified as `expected`,
    /// naming the offending input when one is not.
    fn assert_all(expected: QueryMode, inputs: &[&str]) {
        for input in inputs {
            assert_eq!(detect_mode(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_all(
            QueryMode::Sql,
            &[
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // Positional placeholders must not be mistaken for SPARQL variables.
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                "LIST KV settings PREFIX feature LIMIT 10",
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "SHOW CREATE TABLE users",
                "DESCRIBE users",
                "DESC users",
                "SHOW INDICES",
                "SHOW INDEXES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_all(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_all(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_all(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_all(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_all(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty or whitespace-only input, and a bare keyword, match nothing.
        assert_all(QueryMode::Unknown, &["", " ", "SELECT"]);
        // Detection is case-insensitive.
        assert_all(QueryMode::Gremlin, &["G.V()"]);
        assert_all(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}