// reddb_server/storage/query/modes/detect.rs

/// The query language (mode) that [`detect_mode`] assigned to an input string.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryMode {
    /// SQL, including RedDB's SQL-style extension commands
    /// (`VECTOR ...`, `QUEUE ...`, `VAULT ...`, `SHOW ...`, and so on).
    Sql,
    /// Gremlin traversal: the input begins with `g.` or `__.`.
    Gremlin,
    /// Cypher pattern query: the input begins with `MATCH`.
    Cypher,
    /// SPARQL query: a `PREFIX` header or a SPARQL-looking pattern.
    Sparql,
    /// Path query: the input begins with `PATH` or `PATHS`.
    Path,
    /// Natural-language request, either quoted input or text matching the
    /// natural-language heuristics.
    Natural,
    /// The input matched none of the recognised modes.
    Unknown,
}
23
24pub fn detect_mode(input: &str) -> QueryMode {
26 let trimmed = input.trim();
27 let lower = trimmed.to_lowercase();
28
29 if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31 return QueryMode::Natural;
32 }
33
34 if lower.starts_with("g.") || lower.starts_with("__.") {
36 return QueryMode::Gremlin;
37 }
38
39 if lower.starts_with("path ") || lower.starts_with("paths ") {
41 return QueryMode::Path;
42 }
43
44 if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46 return QueryMode::Sparql;
47 }
48
49 if lower.starts_with("match ") || lower.starts_with("match(") {
51 return QueryMode::Cypher;
52 }
53
54 let first_token = lower.split_whitespace().next().unwrap_or("");
59 if matches!(
60 first_token,
61 "begin"
62 | "start"
63 | "commit"
64 | "rollback"
65 | "savepoint"
66 | "release"
67 | "end"
68 | "vacuum"
69 | "analyze"
70 | "reset"
71 | "copy"
72 | "refresh"
73 | "explain"
74 | "grant"
75 | "revoke"
76 | "attach"
77 | "detach"
78 | "simulate"
79 | "lint"
80 | "migrate"
81 | "apply"
82 | "events"
83 | "describe"
84 | "desc"
85 ) {
86 return QueryMode::Sql;
87 }
88 if lower.starts_with("select ")
89 || lower.starts_with("from ")
90 || lower.starts_with("insert ")
91 || lower.starts_with("update ")
92 || lower.starts_with("delete ")
93 || lower.starts_with("truncate ")
94 || lower.starts_with("create ")
95 || lower.starts_with("drop ")
96 || lower.starts_with("alter ")
97 || lower.starts_with("vector ")
98 || lower.starts_with("hybrid ")
99 || lower.starts_with("graph ")
100 || lower.starts_with("queue ")
101 || lower.starts_with("events ")
102 || lower.starts_with("tree ")
103 || lower.starts_with("hll ")
104 || lower.starts_with("sketch ")
105 || lower.starts_with("filter ")
106 || lower.starts_with("vault ")
107 || lower.starts_with("unseal vault ")
108 || lower.starts_with("rotate vault ")
109 || lower.starts_with("history vault ")
110 || lower.starts_with("list vault ")
111 || lower.starts_with("watch vault ")
112 || lower.starts_with("delete vault ")
113 || lower.starts_with("purge vault ")
114 || lower.starts_with("search ")
115 || lower.starts_with("ask ")
116 || lower.starts_with("put config ")
117 || lower.starts_with("get config ")
118 || lower.starts_with("resolve config ")
119 || lower.starts_with("rotate config ")
120 || lower.starts_with("delete config ")
121 || lower.starts_with("history config ")
122 || lower.starts_with("list config ")
123 || lower.starts_with("watch config ")
124 || lower.starts_with("incr config ")
125 || lower.starts_with("decr config ")
126 || lower.starts_with("add config ")
127 || lower.starts_with("invalidate config ")
128 || lower.starts_with("invalidate tags ")
129 || lower.starts_with("set config ")
130 || lower.starts_with("set secret ")
131 || lower.starts_with("set tenant")
132 || lower.starts_with("show create ")
133 || lower.starts_with("show config")
134 || lower.starts_with("show collections")
135 || lower.starts_with("show tables")
136 || lower.starts_with("show queues")
137 || lower.starts_with("show vectors")
138 || lower.starts_with("show documents")
139 || lower.starts_with("show timeseries")
140 || lower.starts_with("show graphs")
141 || lower.starts_with("kv ")
142 || lower.starts_with("show kv")
143 || lower.starts_with("show configs")
144 || lower.starts_with("show vaults")
145 || lower.starts_with("show schema")
146 || lower.starts_with("show indices")
147 || lower.starts_with("show indexes")
148 || lower.starts_with("show sample ")
149 || lower.starts_with("show secret")
150 || lower.starts_with("show stats")
151 || lower.starts_with("show tenant")
152 || lower.starts_with("show policies")
153 || lower.starts_with("show effective ")
154 || lower.starts_with("rank of ")
155 || lower.starts_with("rank range ")
156 || lower.starts_with("approx rank of ")
157 || lower.starts_with("approximate rank of ")
158 || lower.starts_with("zrank ")
159 || lower.starts_with("zrange ")
160 || lower.starts_with("describe ")
161 || lower.starts_with("desc ")
162 {
163 if lower.starts_with("select ") && has_sparql_variable(&lower) {
167 return QueryMode::Sparql;
168 }
169 return QueryMode::Sql;
170 }
171
172 if is_natural_language(&lower) {
174 return QueryMode::Natural;
175 }
176
177 QueryMode::Unknown
178}
179
180fn has_sparql_pattern(lower: &str) -> bool {
182 let has_var = has_sparql_variable(lower);
188
189 let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
191
192 let has_prefix_pattern = lower.contains(":")
194 && (lower.contains(":<")
195 || lower.contains("> :")
196 || lower.contains(" :") && lower.contains("?"));
197
198 has_var || has_triple_pattern || has_prefix_pattern
199}
200
201fn has_sparql_variable(input: &str) -> bool {
202 let bytes = input.as_bytes();
203 bytes
204 .windows(2)
205 .any(|pair| pair[0] == b'?' && is_sparql_variable_start(pair[1]))
206}
207
/// Reports whether `byte` may follow `?` to begin a SPARQL variable name:
/// an ASCII letter or an underscore. Digits are excluded so that SQL
/// positional parameters such as `?1` are not treated as variables.
fn is_sparql_variable_start(byte: u8) -> bool {
    matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
211
/// Heuristically decides whether the already-lowercased `lower` reads like a
/// natural-language request rather than a formal query.
///
/// Returns `true` in either of two cases:
/// - the text starts with a question/command word such as `"find "` or `"what "`;
/// - the text contains at least two distinct entries from a list of connective
///   words (`" with "`, `" the "`, ...) and domain terms (`"host"`,
///   `"password"`, ...).
///
/// Matching is plain substring search, so e.g. `"hosts"` counts as
/// containing `"host"`.
fn is_natural_language(lower: &str) -> bool {
    const QUESTION_STARTERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];
    const NL_PATTERNS: &[&str] = &[
        // Connective words. They are space-padded, so they only match between other words.
        " with ",
        " for ",
        " that ",
        " have ",
        " has ",
        " can ",
        " are ",
        " is ",
        " all ",
        " me ",
        " the ",
        " from ",
        " to ",
        " on ",
        " in ",
        // Domain vocabulary. These are bare substrings, so plurals match too.
        "vulnerable",
        "credential",
        "password",
        "user",
        "host",
        "service",
        "connected",
        "reachable",
        "exposed",
        "critical",
    ];

    if QUESTION_STARTERS
        .iter()
        .any(|&starter| lower.starts_with(starter))
    {
        return true;
    }

    // Same answer as counting every hit and comparing with 2, but stops at the second.
    NL_PATTERNS
        .iter()
        .filter(|&&pattern| lower.contains(pattern))
        .take(2)
        .count()
        == 2
}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `detect_mode` classifies every entry of `queries` as `expected`.
    fn assert_mode(expected: QueryMode, queries: &[&str]) {
        for &query in queries {
            assert_eq!(detect_mode(query), expected, "query: {:?}", query);
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_mode(
            QueryMode::Sql,
            &[
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // Positional placeholders must not be mistaken for SPARQL variables.
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "SHOW CREATE TABLE users",
                "DESCRIBE users",
                "DESC users",
                "SHOW INDICES",
                "SHOW INDEXES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_mode(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_mode(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_mode(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_mode(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_mode(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty, blank, and keyword-only inputs are not classifiable.
        assert_mode(QueryMode::Unknown, &["", " ", "SELECT"]);

        // Detection is case-insensitive.
        assert_mode(QueryMode::Gremlin, &["G.V()"]);
        assert_mode(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}