// reddb_server/storage/query/modes/detect.rs

/// Query language that a raw input string was classified as by [`detect_mode`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryMode {
    /// SQL, including RedDB's SQL-style commands (`VAULT …`, `QUEUE …`, `SHOW …`, …).
    Sql,
    /// Gremlin traversal, e.g. `g.V().out('connects')`.
    Gremlin,
    /// Cypher pattern query, e.g. `MATCH (a)-[r]->(b) RETURN a`.
    Cypher,
    /// SPARQL query, e.g. `SELECT ?x WHERE { ?x :p ?y }`.
    Sparql,
    /// Path query, e.g. `PATH FROM host('a') TO host('b')`.
    Path,
    /// Free-form natural-language question or request.
    Natural,
    /// None of the detection heuristics matched.
    Unknown,
}

24pub fn detect_mode(input: &str) -> QueryMode {
26 let trimmed = input.trim();
27 let lower = trimmed.to_lowercase();
28
29 if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31 return QueryMode::Natural;
32 }
33
34 if lower.starts_with("g.") || lower.starts_with("__.") {
36 return QueryMode::Gremlin;
37 }
38
39 if lower.starts_with("path ") || lower.starts_with("paths ") {
41 return QueryMode::Path;
42 }
43
44 if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46 return QueryMode::Sparql;
47 }
48
49 if lower.starts_with("match ") || lower.starts_with("match(") {
51 return QueryMode::Cypher;
52 }
53
54 let first_token = lower.split_whitespace().next().unwrap_or("");
59 if matches!(
60 first_token,
61 "begin"
62 | "start"
63 | "commit"
64 | "rollback"
65 | "savepoint"
66 | "release"
67 | "end"
68 | "vacuum"
69 | "analyze"
70 | "reset"
71 | "copy"
72 | "refresh"
73 | "explain"
74 | "grant"
75 | "revoke"
76 | "attach"
77 | "detach"
78 | "simulate"
79 | "apply"
80 | "events"
81 | "describe"
82 | "desc"
83 ) {
84 return QueryMode::Sql;
85 }
86 if lower.starts_with("select ")
87 || lower.starts_with("from ")
88 || lower.starts_with("insert ")
89 || lower.starts_with("update ")
90 || lower.starts_with("delete ")
91 || lower.starts_with("truncate ")
92 || lower.starts_with("create ")
93 || lower.starts_with("drop ")
94 || lower.starts_with("alter ")
95 || lower.starts_with("vector ")
96 || lower.starts_with("hybrid ")
97 || lower.starts_with("graph ")
98 || lower.starts_with("queue ")
99 || lower.starts_with("events ")
100 || lower.starts_with("tree ")
101 || lower.starts_with("hll ")
102 || lower.starts_with("sketch ")
103 || lower.starts_with("filter ")
104 || lower.starts_with("vault ")
105 || lower.starts_with("unseal vault ")
106 || lower.starts_with("rotate vault ")
107 || lower.starts_with("history vault ")
108 || lower.starts_with("list vault ")
109 || lower.starts_with("watch vault ")
110 || lower.starts_with("delete vault ")
111 || lower.starts_with("purge vault ")
112 || lower.starts_with("search ")
113 || lower.starts_with("ask ")
114 || lower.starts_with("put config ")
115 || lower.starts_with("get config ")
116 || lower.starts_with("resolve config ")
117 || lower.starts_with("rotate config ")
118 || lower.starts_with("delete config ")
119 || lower.starts_with("history config ")
120 || lower.starts_with("list config ")
121 || lower.starts_with("watch config ")
122 || lower.starts_with("incr config ")
123 || lower.starts_with("decr config ")
124 || lower.starts_with("add config ")
125 || lower.starts_with("invalidate config ")
126 || lower.starts_with("invalidate tags ")
127 || lower.starts_with("set config ")
128 || lower.starts_with("set secret ")
129 || lower.starts_with("set tenant")
130 || lower.starts_with("show create ")
131 || lower.starts_with("show config")
132 || lower.starts_with("show collections")
133 || lower.starts_with("show tables")
134 || lower.starts_with("show queues")
135 || lower.starts_with("show vectors")
136 || lower.starts_with("show documents")
137 || lower.starts_with("show timeseries")
138 || lower.starts_with("show graphs")
139 || lower.starts_with("kv ")
140 || lower.starts_with("show kv")
141 || lower.starts_with("show configs")
142 || lower.starts_with("show vaults")
143 || lower.starts_with("show schema")
144 || lower.starts_with("show indices")
145 || lower.starts_with("show indexes")
146 || lower.starts_with("show sample ")
147 || lower.starts_with("show secret")
148 || lower.starts_with("show stats")
149 || lower.starts_with("show tenant")
150 || lower.starts_with("show policies")
151 || lower.starts_with("show effective ")
152 || lower.starts_with("describe ")
153 || lower.starts_with("desc ")
154 {
155 if lower.starts_with("select ") && has_sparql_variable(&lower) {
159 return QueryMode::Sparql;
160 }
161 return QueryMode::Sql;
162 }
163
164 if is_natural_language(&lower) {
166 return QueryMode::Natural;
167 }
168
169 QueryMode::Unknown
170}
171
172fn has_sparql_pattern(lower: &str) -> bool {
174 let has_var = has_sparql_variable(lower);
180
181 let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
183
184 let has_prefix_pattern = lower.contains(":")
186 && (lower.contains(":<")
187 || lower.contains("> :")
188 || lower.contains(" :") && lower.contains("?"));
189
190 has_var || has_triple_pattern || has_prefix_pattern
191}
192
/// Returns `true` when `input` contains a SPARQL-style variable outside any string literal.
///
/// A variable is a `?` immediately followed by an ASCII letter or `_` (`?name`, `?_x`).
/// This deliberately excludes SQL positional placeholders such as `?` and `?1`.
///
/// Text inside `'…'` or `"…"` is skipped, so a SQL literal like `'/api?user=1'` is not
/// mistaken for a variable. Inside a literal, a backslash escapes the following byte
/// (SPARQL string escapes). SQL's doubled-quote escape (`''`) needs no special case: it
/// closes the literal and immediately reopens it.
fn has_sparql_variable(input: &str) -> bool {
    let mut bytes = input.bytes().peekable();
    // Quote byte (`'` or `"`) of the literal currently being skipped, if any.
    let mut open_quote: Option<u8> = None;

    while let Some(byte) = bytes.next() {
        match open_quote {
            Some(_) if byte == b'\\' => {
                // Escaped byte inside a literal: it can neither close the literal nor
                // start a variable.
                bytes.next();
            }
            Some(quote) if byte == quote => open_quote = None,
            Some(_) => {}
            None if byte == b'\'' || byte == b'"' => open_quote = Some(byte),
            None if byte == b'?' => {
                if matches!(bytes.peek(), Some(&next) if is_sparql_variable_start(next)) {
                    return true;
                }
            }
            None => {}
        }
    }

    false
}

/// Returns `true` if `byte` may start a SPARQL variable name (ASCII letter or `_`).
fn is_sparql_variable_start(byte: u8) -> bool {
    byte.is_ascii_alphabetic() || byte == b'_'
}

/// Heuristically decides whether lower-cased text reads like an English question or
/// request rather than a structured query.
///
/// Returns `true` in either of two cases:
/// * the text opens with a typical question or command word (`find `, `what `, `show `, …);
/// * the text contains at least two hints from a fixed list of space-delimited English
///   connectives (` with `, ` the `, …) and domain words (`vulnerable`, `credential`,
///   `host`, …).
///
/// The domain words match as plain substrings, so `hosts` counts as a `host` hit.
fn is_natural_language(lower: &str) -> bool {
    const QUESTION_STARTERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];
    const NL_HINTS: &[&str] = &[
        " with ",
        " for ",
        " that ",
        " have ",
        " has ",
        " can ",
        " are ",
        " is ",
        " all ",
        " me ",
        " the ",
        " from ",
        " to ",
        " on ",
        " in ",
        "vulnerable",
        "credential",
        "password",
        "user",
        "host",
        "service",
        "connected",
        "reachable",
        "exposed",
        "critical",
    ];

    if QUESTION_STARTERS
        .iter()
        .any(|&starter| lower.starts_with(starter))
    {
        return true;
    }

    // "At least two hints": pull matches lazily and stop as soon as a second one appears.
    let mut hits = NL_HINTS.iter().filter(|&&hint| lower.contains(hint));
    hits.next().is_some() && hits.next().is_some()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every entry of `inputs` is detected as `expected`.
    fn assert_all(expected: QueryMode, inputs: &[&str]) {
        for &input in inputs {
            assert_eq!(detect_mode(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_all(
            QueryMode::Sql,
            &[
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // Positional placeholders must not be mistaken for SPARQL variables.
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "SHOW CREATE TABLE users",
                "DESCRIBE users",
                "DESC users",
                "SHOW INDICES",
                "SHOW INDEXES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_all(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_all(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_all(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_all(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_all(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty and whitespace-only input match nothing.
        assert_all(QueryMode::Unknown, &["", " "]);
        // A bare `SELECT` has no trailing space, so no SQL prefix matches it.
        assert_all(QueryMode::Unknown, &["SELECT"]);
        // Detection is case-insensitive.
        assert_all(QueryMode::Gremlin, &["G.V()"]);
        assert_all(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}