// reddb_server/storage/query/modes/detect.rs
/// Query language that a raw input string is classified as by `detect_mode`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QueryMode {
    /// SQL-style statements (`SELECT`, `INSERT`, `SHOW TABLES`, ...).
    Sql,
    /// Gremlin traversals, i.e. input beginning with `g.` or `__.`.
    Gremlin,
    /// Cypher queries, i.e. input beginning with `MATCH`.
    Cypher,
    /// SPARQL queries (`PREFIX ...`, `?variable` usage, `WHERE { ... }` blocks).
    Sparql,
    /// Path queries, i.e. input beginning with `PATH` or `PATHS`.
    Path,
    /// Natural-language requests, either quoted or recognised heuristically.
    Natural,
    /// Input that matched none of the detection rules.
    Unknown,
}
23
24pub fn detect_mode(input: &str) -> QueryMode {
26 let trimmed = input.trim();
27 let lower = trimmed.to_lowercase();
28
29 if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31 return QueryMode::Natural;
32 }
33
34 if lower.starts_with("g.") || lower.starts_with("__.") {
36 return QueryMode::Gremlin;
37 }
38
39 if lower.starts_with("path ") || lower.starts_with("paths ") {
41 return QueryMode::Path;
42 }
43
44 if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46 return QueryMode::Sparql;
47 }
48
49 if lower.starts_with("match ") || lower.starts_with("match(") {
51 return QueryMode::Cypher;
52 }
53
54 let first_token = lower.split_whitespace().next().unwrap_or("");
59 if matches!(
60 first_token,
61 "begin"
62 | "start"
63 | "commit"
64 | "rollback"
65 | "savepoint"
66 | "release"
67 | "end"
68 | "vacuum"
69 | "analyze"
70 | "reset"
71 | "copy"
72 | "refresh"
73 | "explain"
74 | "grant"
75 | "revoke"
76 | "attach"
77 | "detach"
78 | "simulate"
79 | "apply"
80 | "events"
81 | "describe"
82 | "desc"
83 ) {
84 return QueryMode::Sql;
85 }
86 if lower.starts_with("select ")
87 || lower.starts_with("from ")
88 || lower.starts_with("insert ")
89 || lower.starts_with("update ")
90 || lower.starts_with("delete ")
91 || lower.starts_with("truncate ")
92 || lower.starts_with("create ")
93 || lower.starts_with("drop ")
94 || lower.starts_with("alter ")
95 || lower.starts_with("vector ")
96 || lower.starts_with("hybrid ")
97 || lower.starts_with("graph ")
98 || lower.starts_with("queue ")
99 || lower.starts_with("events ")
100 || lower.starts_with("tree ")
101 || lower.starts_with("hll ")
102 || lower.starts_with("sketch ")
103 || lower.starts_with("filter ")
104 || lower.starts_with("vault ")
105 || lower.starts_with("unseal vault ")
106 || lower.starts_with("rotate vault ")
107 || lower.starts_with("history vault ")
108 || lower.starts_with("list vault ")
109 || lower.starts_with("watch vault ")
110 || lower.starts_with("delete vault ")
111 || lower.starts_with("purge vault ")
112 || lower.starts_with("search ")
113 || lower.starts_with("ask ")
114 || lower.starts_with("put config ")
115 || lower.starts_with("get config ")
116 || lower.starts_with("resolve config ")
117 || lower.starts_with("rotate config ")
118 || lower.starts_with("delete config ")
119 || lower.starts_with("history config ")
120 || lower.starts_with("list config ")
121 || lower.starts_with("watch config ")
122 || lower.starts_with("incr config ")
123 || lower.starts_with("decr config ")
124 || lower.starts_with("add config ")
125 || lower.starts_with("invalidate config ")
126 || lower.starts_with("invalidate tags ")
127 || lower.starts_with("set config ")
128 || lower.starts_with("set secret ")
129 || lower.starts_with("set tenant")
130 || lower.starts_with("show config")
131 || lower.starts_with("show collections")
132 || lower.starts_with("show tables")
133 || lower.starts_with("show queues")
134 || lower.starts_with("show vectors")
135 || lower.starts_with("show documents")
136 || lower.starts_with("show timeseries")
137 || lower.starts_with("show graphs")
138 || lower.starts_with("kv ")
139 || lower.starts_with("show kv")
140 || lower.starts_with("show configs")
141 || lower.starts_with("show vaults")
142 || lower.starts_with("show schema")
143 || lower.starts_with("show indices")
144 || lower.starts_with("show sample ")
145 || lower.starts_with("show secret")
146 || lower.starts_with("show stats")
147 || lower.starts_with("show tenant")
148 || lower.starts_with("show policies")
149 || lower.starts_with("show effective ")
150 || lower.starts_with("describe ")
151 || lower.starts_with("desc ")
152 {
153 if lower.starts_with("select ") && has_sparql_variable(&lower) {
157 return QueryMode::Sparql;
158 }
159 return QueryMode::Sql;
160 }
161
162 if is_natural_language(&lower) {
164 return QueryMode::Natural;
165 }
166
167 QueryMode::Unknown
168}
169
170fn has_sparql_pattern(lower: &str) -> bool {
172 let has_var = has_sparql_variable(lower);
178
179 let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
181
182 let has_prefix_pattern = lower.contains(":")
184 && (lower.contains(":<")
185 || lower.contains("> :")
186 || lower.contains(" :") && lower.contains("?"));
187
188 has_var || has_triple_pattern || has_prefix_pattern
189}
190
191fn has_sparql_variable(input: &str) -> bool {
192 let bytes = input.as_bytes();
193 bytes
194 .windows(2)
195 .any(|pair| pair[0] == b'?' && is_sparql_variable_start(pair[1]))
196}
197
/// Reports whether `byte` is an ASCII letter or an underscore, i.e. a byte
/// accepted as the first character of a variable name after `?`.
fn is_sparql_variable_start(byte: u8) -> bool {
    matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
201
/// Heuristically decides whether `lower` (an already lowercased query) reads
/// like a plain-language request.
///
/// Returns `true` when the text begins with one of the known question or
/// command openers, or when at least two distinct hint patterns (common
/// function words and domain terms) occur anywhere in it.
fn is_natural_language(lower: &str) -> bool {
    // Words that typically open a question or an imperative request.
    const OPENERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];

    // Function words (space-delimited) and domain vocabulary (substring match).
    const HINTS: &[&str] = &[
        " with ",
        " for ",
        " that ",
        " have ",
        " has ",
        " can ",
        " are ",
        " is ",
        " all ",
        " me ",
        " the ",
        " from ",
        " to ",
        " on ",
        " in ",
        "vulnerable",
        "credential",
        "password",
        "user",
        "host",
        "service",
        "connected",
        "reachable",
        "exposed",
        "critical",
    ];

    if OPENERS.iter().any(|opener| lower.starts_with(*opener)) {
        return true;
    }

    // Each hint counts at most once; two different hints are enough.
    let mut matched_hints = 0usize;
    for hint in HINTS {
        if lower.contains(*hint) {
            matched_hints += 1;
            if matched_hints >= 2 {
                return true;
            }
        }
    }
    false
}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every query in `queries` is classified as `expected`.
    fn assert_all(expected: QueryMode, queries: &[&str]) {
        for &query in queries {
            assert_eq!(detect_mode(query), expected, "query: {:?}", query);
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_all(
            QueryMode::Sql,
            &[
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // Positional placeholders must not be mistaken for SPARQL variables.
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "DESCRIBE users",
                "DESC users",
                "SHOW INDICES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_all(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_all(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_all(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_all(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_all(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty and whitespace-only input match no rule.
        assert_all(QueryMode::Unknown, &["", " "]);

        // A bare keyword with nothing after it is not recognised as SQL.
        assert_all(QueryMode::Unknown, &["SELECT"]);

        // Detection ignores letter case.
        assert_all(QueryMode::Gremlin, &["G.V()"]);
        assert_all(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}