// reddb_server/storage/query/modes/detect.rs

/// Query dialect that `detect_mode` assigns to a raw query string.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QueryMode {
    /// SQL-style statement (recognised by its leading keyword or keyword pair).
    Sql,
    /// Gremlin traversal (input starting with `g.` or `__.`).
    Gremlin,
    /// Cypher query (input starting with `MATCH`).
    Cypher,
    /// SPARQL query (`PREFIX` declaration, `?var` variables, or `WHERE {` group).
    Sparql,
    /// Path query (input starting with `PATH` or `PATHS`).
    Path,
    /// Free-form natural-language request (quoted text or question-like wording).
    Natural,
    /// Nothing matched; the dialect could not be determined.
    Unknown,
}
23
24pub fn detect_mode(input: &str) -> QueryMode {
26 let trimmed = input.trim();
27 let lower = trimmed.to_lowercase();
28
29 if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31 return QueryMode::Natural;
32 }
33
34 if lower.starts_with("g.") || lower.starts_with("__.") {
36 return QueryMode::Gremlin;
37 }
38
39 if lower.starts_with("path ") || lower.starts_with("paths ") {
41 return QueryMode::Path;
42 }
43
44 if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46 return QueryMode::Sparql;
47 }
48
49 if lower.starts_with("match ") || lower.starts_with("match(") {
51 return QueryMode::Cypher;
52 }
53
54 let first_token = lower.split_whitespace().next().unwrap_or("");
59 if matches!(
60 first_token,
61 "begin"
62 | "start"
63 | "commit"
64 | "rollback"
65 | "savepoint"
66 | "release"
67 | "end"
68 | "vacuum"
69 | "analyze"
70 | "reset"
71 | "copy"
72 | "refresh"
73 | "explain"
74 | "grant"
75 | "revoke"
76 | "attach"
77 | "detach"
78 | "simulate"
79 | "lint"
80 | "migrate"
81 | "apply"
82 | "events"
83 | "describe"
84 | "desc"
85 ) {
86 return QueryMode::Sql;
87 }
88 if lower.starts_with("select ")
89 || lower.starts_with("from ")
90 || lower.starts_with("insert ")
91 || lower.starts_with("update ")
92 || lower.starts_with("delete ")
93 || lower.starts_with("truncate ")
94 || lower.starts_with("create ")
95 || lower.starts_with("drop ")
96 || lower.starts_with("alter ")
97 || lower.starts_with("vector ")
98 || lower.starts_with("hybrid ")
99 || lower.starts_with("graph ")
100 || lower.starts_with("queue ")
101 || lower.starts_with("events ")
102 || lower.starts_with("tree ")
103 || lower.starts_with("hll ")
104 || lower.starts_with("sketch ")
105 || lower.starts_with("filter ")
106 || lower.starts_with("vault ")
107 || lower.starts_with("unseal vault ")
108 || lower.starts_with("rotate vault ")
109 || lower.starts_with("history vault ")
110 || lower.starts_with("list vault ")
111 || lower.starts_with("watch vault ")
112 || lower.starts_with("delete vault ")
113 || lower.starts_with("purge vault ")
114 || lower.starts_with("search ")
115 || lower.starts_with("ask ")
116 || lower.starts_with("put config ")
117 || lower.starts_with("get config ")
118 || lower.starts_with("resolve config ")
119 || lower.starts_with("rotate config ")
120 || lower.starts_with("delete config ")
121 || lower.starts_with("history config ")
122 || lower.starts_with("list config ")
123 || lower.starts_with("watch config ")
124 || lower.starts_with("incr config ")
125 || lower.starts_with("decr config ")
126 || lower.starts_with("add config ")
127 || lower.starts_with("invalidate config ")
128 || lower.starts_with("invalidate tags ")
129 || lower.starts_with("set config ")
130 || lower.starts_with("set secret ")
131 || lower.starts_with("set tenant")
132 || lower.starts_with("show create ")
133 || lower.starts_with("show config")
134 || lower.starts_with("show collections")
135 || lower.starts_with("show tables")
136 || lower.starts_with("show queues")
137 || lower.starts_with("show vectors")
138 || lower.starts_with("show documents")
139 || lower.starts_with("show timeseries")
140 || lower.starts_with("show graphs")
141 || lower.starts_with("kv ")
142 || lower.starts_with("show kv")
143 || lower.starts_with("show configs")
144 || lower.starts_with("show vaults")
145 || lower.starts_with("show schema")
146 || lower.starts_with("show indices")
147 || lower.starts_with("show indexes")
148 || lower.starts_with("show sample ")
149 || lower.starts_with("show secret")
150 || lower.starts_with("show stats")
151 || lower.starts_with("show tenant")
152 || lower.starts_with("show policies")
153 || lower.starts_with("show effective ")
154 || lower.starts_with("describe ")
155 || lower.starts_with("desc ")
156 {
157 if lower.starts_with("select ") && has_sparql_variable(&lower) {
161 return QueryMode::Sparql;
162 }
163 return QueryMode::Sql;
164 }
165
166 if is_natural_language(&lower) {
168 return QueryMode::Natural;
169 }
170
171 QueryMode::Unknown
172}
173
174fn has_sparql_pattern(lower: &str) -> bool {
176 let has_var = has_sparql_variable(lower);
182
183 let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
185
186 let has_prefix_pattern = lower.contains(":")
188 && (lower.contains(":<")
189 || lower.contains("> :")
190 || lower.contains(" :") && lower.contains("?"));
191
192 has_var || has_triple_pattern || has_prefix_pattern
193}
194
195fn has_sparql_variable(input: &str) -> bool {
196 let bytes = input.as_bytes();
197 bytes
198 .windows(2)
199 .any(|pair| pair[0] == b'?' && is_sparql_variable_start(pair[1]))
200}
201
/// Returns `true` if `byte` is an ASCII letter or an underscore, i.e. a byte
/// that may directly follow `?` for the pair to count as a variable.
fn is_sparql_variable_start(byte: u8) -> bool {
    matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
205
/// Heuristic check for free-form, natural-language requests.
///
/// `lower` is matched against lower-case needles, so it should already be
/// lower-cased. The text counts as natural language when it either
///
/// * begins with one of the question/command openers (`find `, `show `, ...), or
/// * contains at least two distinct entries from the hint list (common
///   function words such as ` with ` plus a few domain terms such as `host`).
fn is_natural_language(lower: &str) -> bool {
    const OPENERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];

    const HINTS: &[&str] = &[
        " with ",
        " for ",
        " that ",
        " have ",
        " has ",
        " can ",
        " are ",
        " is ",
        " all ",
        " me ",
        " the ",
        " from ",
        " to ",
        " on ",
        " in ",
        "vulnerable",
        "credential",
        "password",
        "user",
        "host",
        "service",
        "connected",
        "reachable",
        "exposed",
        "critical",
    ];

    if OPENERS.iter().any(|opener| lower.starts_with(opener)) {
        return true;
    }

    // `nth(1)` yields the second matching hint, so this is "two or more hints".
    HINTS
        .iter()
        .filter(|hint| lower.contains(**hint))
        .nth(1)
        .is_some()
}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every query in `inputs` is classified as `expected`.
    fn assert_mode(expected: QueryMode, inputs: &[&str]) {
        for input in inputs {
            assert_eq!(detect_mode(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_mode(
            QueryMode::Sql,
            &[
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // Positional placeholders must not be mistaken for SPARQL variables.
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "SHOW CREATE TABLE users",
                "DESCRIBE users",
                "DESC users",
                "SHOW INDICES",
                "SHOW INDEXES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_mode(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_mode(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_mode(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_mode(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_mode(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty input, whitespace-only input, and a bare keyword are all undetermined.
        assert_mode(QueryMode::Unknown, &["", " ", "SELECT"]);

        // Detection is case-insensitive.
        assert_mode(QueryMode::Gremlin, &["G.V()"]);
        assert_mode(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}