// reddb_server/storage/query/modes/detect.rs

/// Query language that [`detect_mode`] attributes to a piece of query text.
///
/// The type is a plain fieldless enum: it is `Copy`, comparable, and hashable,
/// so it can be used directly as a `HashMap`/`HashSet` key (for example to
/// dispatch to a per-language parser or to count queries per language).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryMode {
    /// SQL statement or one of the SQL-like commands `detect_mode` recognizes.
    Sql,
    /// Gremlin traversal (text starting with `g.` or `__.`).
    Gremlin,
    /// Cypher query (text starting with `MATCH`).
    Cypher,
    /// SPARQL query (`PREFIX ...`, `?variables`, `WHERE { ... }` patterns).
    Sparql,
    /// Path query (text starting with `PATH` or `PATHS`).
    Path,
    /// Free-form natural-language question, or quoted text.
    Natural,
    /// None of the detection heuristics matched.
    Unknown,
}
23
24pub fn detect_mode(input: &str) -> QueryMode {
26 let trimmed = input.trim();
27 let lower = trimmed.to_lowercase();
28
29 if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31 return QueryMode::Natural;
32 }
33
34 if lower.starts_with("g.") || lower.starts_with("__.") {
36 return QueryMode::Gremlin;
37 }
38
39 if lower.starts_with("path ") || lower.starts_with("paths ") {
41 return QueryMode::Path;
42 }
43
44 if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46 return QueryMode::Sparql;
47 }
48
49 if lower.starts_with("match ") || lower.starts_with("match(") {
51 return QueryMode::Cypher;
52 }
53
54 let first_token = lower.split_whitespace().next().unwrap_or("");
59 if matches!(
60 first_token,
61 "begin"
62 | "start"
63 | "commit"
64 | "rollback"
65 | "savepoint"
66 | "release"
67 | "end"
68 | "vacuum"
69 | "analyze"
70 | "reset"
71 | "copy"
72 | "refresh"
73 | "explain"
74 | "grant"
75 | "revoke"
76 | "attach"
77 | "detach"
78 | "simulate"
79 | "apply"
80 | "events"
81 ) {
82 return QueryMode::Sql;
83 }
84 if lower.starts_with("select ")
85 || lower.starts_with("from ")
86 || lower.starts_with("insert ")
87 || lower.starts_with("update ")
88 || lower.starts_with("delete ")
89 || lower.starts_with("truncate ")
90 || lower.starts_with("create ")
91 || lower.starts_with("drop ")
92 || lower.starts_with("alter ")
93 || lower.starts_with("vector ")
94 || lower.starts_with("hybrid ")
95 || lower.starts_with("graph ")
96 || lower.starts_with("queue ")
97 || lower.starts_with("events ")
98 || lower.starts_with("tree ")
99 || lower.starts_with("hll ")
100 || lower.starts_with("sketch ")
101 || lower.starts_with("filter ")
102 || lower.starts_with("vault ")
103 || lower.starts_with("unseal vault ")
104 || lower.starts_with("rotate vault ")
105 || lower.starts_with("history vault ")
106 || lower.starts_with("list vault ")
107 || lower.starts_with("watch vault ")
108 || lower.starts_with("delete vault ")
109 || lower.starts_with("purge vault ")
110 || lower.starts_with("search ")
111 || lower.starts_with("ask ")
112 || lower.starts_with("put config ")
113 || lower.starts_with("get config ")
114 || lower.starts_with("resolve config ")
115 || lower.starts_with("rotate config ")
116 || lower.starts_with("delete config ")
117 || lower.starts_with("history config ")
118 || lower.starts_with("list config ")
119 || lower.starts_with("watch config ")
120 || lower.starts_with("incr config ")
121 || lower.starts_with("decr config ")
122 || lower.starts_with("add config ")
123 || lower.starts_with("invalidate config ")
124 || lower.starts_with("invalidate tags ")
125 || lower.starts_with("set config ")
126 || lower.starts_with("set secret ")
127 || lower.starts_with("set tenant")
128 || lower.starts_with("show config")
129 || lower.starts_with("show collections")
130 || lower.starts_with("show tables")
131 || lower.starts_with("show queues")
132 || lower.starts_with("show vectors")
133 || lower.starts_with("show documents")
134 || lower.starts_with("show timeseries")
135 || lower.starts_with("show graphs")
136 || lower.starts_with("kv ")
137 || lower.starts_with("show kv")
138 || lower.starts_with("show configs")
139 || lower.starts_with("show vaults")
140 || lower.starts_with("show schema")
141 || lower.starts_with("show indices")
142 || lower.starts_with("show sample ")
143 || lower.starts_with("show secret")
144 || lower.starts_with("show stats")
145 || lower.starts_with("show tenant")
146 || lower.starts_with("show policies")
147 || lower.starts_with("show effective ")
148 {
149 if lower.starts_with("select ") && has_sparql_variable(&lower) {
153 return QueryMode::Sparql;
154 }
155 return QueryMode::Sql;
156 }
157
158 if is_natural_language(&lower) {
160 return QueryMode::Natural;
161 }
162
163 QueryMode::Unknown
164}
165
166fn has_sparql_pattern(lower: &str) -> bool {
168 let has_var = has_sparql_variable(lower);
174
175 let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
177
178 let has_prefix_pattern = lower.contains(":")
180 && (lower.contains(":<")
181 || lower.contains("> :")
182 || lower.contains(" :") && lower.contains("?"));
183
184 has_var || has_triple_pattern || has_prefix_pattern
185}
186
187fn has_sparql_variable(input: &str) -> bool {
188 let bytes = input.as_bytes();
189 bytes
190 .windows(2)
191 .any(|pair| pair[0] == b'?' && is_sparql_variable_start(pair[1]))
192}
193
/// Returns `true` if `byte` is an ASCII letter (`a-z`, `A-Z`) or an
/// underscore, i.e. a byte allowed directly after the `?` of a variable.
fn is_sparql_variable_start(byte: u8) -> bool {
    matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
197
/// Heuristically decides whether `lower` reads like a natural-language
/// request.
///
/// `lower` is expected to be lower-cased already: all literals below are
/// lower-case. The text qualifies when it either
///
/// * starts with one of the question/command openers (each followed by a
///   space), or
/// * contains at least two distinct entries of the hint list (connective
///   words surrounded by spaces, or domain terms matched as plain
///   substrings).
fn is_natural_language(lower: &str) -> bool {
    // Openers that mark a request on their own.
    const OPENERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];

    // Weaker hints; two different ones must be present.
    const HINTS: &[&str] = &[
        " with ",
        " for ",
        " that ",
        " have ",
        " has ",
        " can ",
        " are ",
        " is ",
        " all ",
        " me ",
        " the ",
        " from ",
        " to ",
        " on ",
        " in ",
        "vulnerable",
        "credential",
        "password",
        "user",
        "host",
        "service",
        "connected",
        "reachable",
        "exposed",
        "critical",
    ];

    OPENERS.iter().any(|opener| lower.starts_with(*opener))
        || HINTS.iter().filter(|hint| lower.contains(**hint)).count() >= 2
}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `detect_mode` classifies every entry of `queries` as
    /// `expected`, naming the offending query on failure.
    fn assert_mode(expected: QueryMode, queries: &[&str]) {
        for query in queries {
            assert_eq!(detect_mode(query), expected, "query: {:?}", query);
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_mode(
            QueryMode::Sql,
            &[
                // Core statements.
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                // SQL-like commands.
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // `?` placeholders must not be mistaken for SPARQL variables.
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                // Secrets and vault.
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                // SHOW commands.
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "SHOW INDICES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_mode(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_mode(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_mode(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_mode(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_mode(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty input, whitespace-only input, and a keyword with nothing
        // after it are all unclassifiable.
        assert_mode(QueryMode::Unknown, &["", " ", "SELECT"]);

        // Detection is case-insensitive.
        assert_mode(QueryMode::Gremlin, &["G.V()"]);
        assert_mode(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}