// reddb_server/storage/query/modes/detect.rs
//
// Heuristic detection of which query language a raw input string is written in.

/// Query language assigned to an input string by [`detect_mode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryMode {
    /// SQL, together with the other statement families `detect_mode` routes to
    /// the SQL path (`VECTOR`, `GRAPH`, `QUEUE`, `VAULT`, `SHOW ...`, config
    /// commands, transaction keywords, ...).
    Sql,
    /// Gremlin traversal (input starting with `g.` or `__.`).
    Gremlin,
    /// Cypher pattern query (input starting with `MATCH`).
    Cypher,
    /// SPARQL query (`PREFIX` header or SPARQL-looking syntax).
    Sparql,
    /// Path query (input starting with `PATH` or `PATHS`).
    Path,
    /// Free-form natural-language request.
    Natural,
    /// No heuristic matched.
    Unknown,
}
23
24pub fn detect_mode(input: &str) -> QueryMode {
26 let trimmed = input.trim();
27 let lower = trimmed.to_lowercase();
28
29 if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31 return QueryMode::Natural;
32 }
33
34 if lower.starts_with("g.") || lower.starts_with("__.") {
36 return QueryMode::Gremlin;
37 }
38
39 if lower.starts_with("path ") || lower.starts_with("paths ") {
41 return QueryMode::Path;
42 }
43
44 if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46 return QueryMode::Sparql;
47 }
48
49 if lower.starts_with("match ") || lower.starts_with("match(") {
51 return QueryMode::Cypher;
52 }
53
54 let first_token = lower.split_whitespace().next().unwrap_or("");
59 if matches!(
60 first_token,
61 "begin"
62 | "start"
63 | "commit"
64 | "rollback"
65 | "savepoint"
66 | "release"
67 | "end"
68 | "vacuum"
69 | "analyze"
70 | "reset"
71 | "copy"
72 | "refresh"
73 | "explain"
74 | "grant"
75 | "revoke"
76 | "attach"
77 | "detach"
78 | "simulate"
79 | "apply"
80 | "events"
81 ) {
82 return QueryMode::Sql;
83 }
84 if lower.starts_with("select ")
85 || lower.starts_with("from ")
86 || lower.starts_with("insert ")
87 || lower.starts_with("update ")
88 || lower.starts_with("delete ")
89 || lower.starts_with("truncate ")
90 || lower.starts_with("create ")
91 || lower.starts_with("drop ")
92 || lower.starts_with("alter ")
93 || lower.starts_with("vector ")
94 || lower.starts_with("hybrid ")
95 || lower.starts_with("graph ")
96 || lower.starts_with("queue ")
97 || lower.starts_with("events ")
98 || lower.starts_with("tree ")
99 || lower.starts_with("vault ")
100 || lower.starts_with("unseal vault ")
101 || lower.starts_with("rotate vault ")
102 || lower.starts_with("history vault ")
103 || lower.starts_with("list vault ")
104 || lower.starts_with("watch vault ")
105 || lower.starts_with("delete vault ")
106 || lower.starts_with("purge vault ")
107 || lower.starts_with("search ")
108 || lower.starts_with("ask ")
109 || lower.starts_with("put config ")
110 || lower.starts_with("get config ")
111 || lower.starts_with("resolve config ")
112 || lower.starts_with("rotate config ")
113 || lower.starts_with("delete config ")
114 || lower.starts_with("history config ")
115 || lower.starts_with("list config ")
116 || lower.starts_with("watch config ")
117 || lower.starts_with("incr config ")
118 || lower.starts_with("decr config ")
119 || lower.starts_with("add config ")
120 || lower.starts_with("invalidate config ")
121 || lower.starts_with("invalidate tags ")
122 || lower.starts_with("set config ")
123 || lower.starts_with("set secret ")
124 || lower.starts_with("set tenant")
125 || lower.starts_with("show config")
126 || lower.starts_with("show collections")
127 || lower.starts_with("show tables")
128 || lower.starts_with("show queues")
129 || lower.starts_with("show vectors")
130 || lower.starts_with("show documents")
131 || lower.starts_with("show timeseries")
132 || lower.starts_with("show graphs")
133 || lower.starts_with("kv ")
134 || lower.starts_with("show kv")
135 || lower.starts_with("show configs")
136 || lower.starts_with("show vaults")
137 || lower.starts_with("show schema")
138 || lower.starts_with("show indices")
139 || lower.starts_with("show sample ")
140 || lower.starts_with("show secret")
141 || lower.starts_with("show stats")
142 || lower.starts_with("show tenant")
143 || lower.starts_with("show policies")
144 || lower.starts_with("show effective ")
145 {
146 if lower.starts_with("select ") && lower.contains(" ?") {
148 return QueryMode::Sparql;
149 }
150 return QueryMode::Sql;
151 }
152
153 if is_natural_language(&lower) {
155 return QueryMode::Natural;
156 }
157
158 QueryMode::Unknown
159}
160
/// Heuristically decide whether the lowercased query `lower` looks like SPARQL.
///
/// Returns `true` if any of these signals is present:
///
/// * a SPARQL variable — a space, then `?`, then a name character (`?s`, `?name`) —
///   provided the text contains neither `= ?` nor `> ?`, which are typical of SQL
///   comparisons against placeholders;
/// * a group graph pattern introduced by `where {` / `where{`;
/// * prefixed-name syntax: a `:` together with `:<`, `> :`, or ` :` plus a `?`.
///
/// A bare `?` (followed by `,`, `)`, whitespace or end of input) is a SQL positional
/// placeholder, not a SPARQL variable, so statements such as
/// `insert into t values (?, ?)` or `select * from t where a < ?` are not flagged.
fn has_sparql_pattern(lower: &str) -> bool {
    // SPARQL variable names start with a letter, digit or `_`.
    // `idx + 2` is a char boundary because " ?" is two ASCII bytes.
    let mentions_variable = lower.match_indices(" ?").any(|(idx, _)| {
        matches!(
            lower[idx + 2..].chars().next(),
            Some(c) if c.is_alphanumeric() || c == '_'
        )
    });
    let has_var = mentions_variable && !lower.contains("= ?") && !lower.contains("> ?");

    let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");

    let has_prefix_pattern = lower.contains(':')
        && (lower.contains(":<")
            || lower.contains("> :")
            || (lower.contains(" :") && lower.contains('?')));

    has_var || has_triple_pattern || has_prefix_pattern
}
180
/// Heuristic check for free-form, natural-language input.
///
/// `lower` is expected to be lowercased by the caller. The text counts as natural
/// language when it starts with a question/command word (`find `, `what `, ...),
/// or when at least two of the hint words below occur anywhere in it. Hints are
/// plain substring matches, so `user` also matches `users`.
fn is_natural_language(lower: &str) -> bool {
    const OPENERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];
    const HINTS: &[&str] = &[
        // Space-delimited connective / filler words.
        " with ", " for ", " that ", " have ", " has ", " can ", " are ", " is ", " all ",
        " me ", " the ", " from ", " to ", " on ", " in ",
        // Domain keywords.
        "vulnerable", "credential", "password", "user", "host", "service", "connected",
        "reachable", "exposed", "critical",
    ];

    if OPENERS.iter().any(|opener| lower.starts_with(*opener)) {
        return true;
    }

    // Stop scanning as soon as the second hint is found.
    let mut hits = 0;
    for hint in HINTS.iter() {
        if lower.contains(*hint) {
            hits += 1;
            if hits == 2 {
                return true;
            }
        }
    }
    false
}
230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sql_detection() {
        assert_eq!(
            detect_mode("SELECT * FROM users WHERE id = 1"),
            QueryMode::Sql
        );
        assert_eq!(detect_mode("select name, age from hosts"), QueryMode::Sql);
        assert_eq!(
            detect_mode("FROM hosts h WHERE h.os = 'Linux'"),
            QueryMode::Sql
        );
        assert_eq!(
            detect_mode("INSERT INTO users VALUES (1, 'alice')"),
            QueryMode::Sql
        );
        assert_eq!(
            detect_mode("UPDATE hosts SET status = 'active'"),
            QueryMode::Sql
        );
        assert_eq!(
            detect_mode("DELETE FROM logs WHERE age > 30"),
            QueryMode::Sql
        );
        assert_eq!(
            detect_mode("QUEUE GROUP CREATE tasks workers"),
            QueryMode::Sql
        );
        assert_eq!(
            detect_mode("EVENTS BACKFILL users TO audit"),
            QueryMode::Sql
        );
        assert_eq!(detect_mode("TREE VALIDATE forest.org"), QueryMode::Sql);
        assert_eq!(
            detect_mode("VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5"),
            QueryMode::Sql
        );
        assert_eq!(
            detect_mode("HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5"),
            QueryMode::Sql
        );
        assert_eq!(
            detect_mode("ASK 'what happened on host 10.0.0.1?' USING groq"),
            QueryMode::Sql
        );
        assert_eq!(
            detect_mode("SET SECRET red.secret.api = 'x'"),
            QueryMode::Sql
        );
        assert_eq!(detect_mode("SHOW SECRET red.secret"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW SECRETS"), QueryMode::Sql);
        assert_eq!(detect_mode("VAULT PUT secrets.api = 'x'"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW SAMPLE users"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW TABLES"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW QUEUES"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW VECTORS"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW DOCUMENTS"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW TIMESERIES"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW GRAPHS"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW KV"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW KVS"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW CONFIGS"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW VAULTS"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW SCHEMA users"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW INDICES"), QueryMode::Sql);
        assert_eq!(detect_mode("SHOW STATS users"), QueryMode::Sql);
    }

    #[test]
    fn test_keyword_statement_detection() {
        // Statements recognised by their first keyword alone.
        let statements = [
            "BEGIN",
            "COMMIT",
            "ROLLBACK",
            "SAVEPOINT sp1",
            "EXPLAIN SELECT * FROM users",
            "VACUUM",
            "ANALYZE users",
            "GRANT SELECT ON users TO bob",
        ];
        for stmt in statements.iter() {
            assert_eq!(detect_mode(stmt), QueryMode::Sql, "statement: {}", stmt);
        }
    }

    #[test]
    fn test_gremlin_detection() {
        assert_eq!(detect_mode("g.V()"), QueryMode::Gremlin);
        assert_eq!(detect_mode("g.V().hasLabel('host')"), QueryMode::Gremlin);
        assert_eq!(
            detect_mode("g.V().out('connects').in('has_service')"),
            QueryMode::Gremlin
        );
        assert_eq!(
            detect_mode("g.E().hasLabel('auth_access')"),
            QueryMode::Gremlin
        );
        assert_eq!(
            detect_mode("__.out('knows').has('name', 'bob')"),
            QueryMode::Gremlin
        );
        assert_eq!(
            detect_mode("g.V('host:10.0.0.1').repeat(out()).times(3)"),
            QueryMode::Gremlin
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_eq!(
            detect_mode("MATCH (a)-[r]->(b) RETURN a, b"),
            QueryMode::Cypher
        );
        assert_eq!(
            detect_mode("MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)"),
            QueryMode::Cypher
        );
        assert_eq!(
            detect_mode("match (n) where n.ip = '10.0.0.1' return n"),
            QueryMode::Cypher
        );
        assert_eq!(
            detect_mode("MATCH(a:User) RETURN a.name"),
            QueryMode::Cypher
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_eq!(
            detect_mode("SELECT ?name WHERE { ?s :name ?name }"),
            QueryMode::Sparql
        );
        assert_eq!(
            detect_mode("PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }"),
            QueryMode::Sparql
        );
        assert_eq!(
            detect_mode("SELECT ?host ?ip WHERE { ?host :hasIP ?ip }"),
            QueryMode::Sparql
        );
        assert_eq!(detect_mode("ASK { ?s ?p ?o }"), QueryMode::Sparql);
    }

    #[test]
    fn test_path_detection() {
        assert_eq!(
            detect_mode("PATH FROM host('10.0.0.1') TO host('10.0.0.2')"),
            QueryMode::Path
        );
        assert_eq!(
            detect_mode("PATHS ALL FROM credential('admin') TO host('db')"),
            QueryMode::Path
        );
        assert_eq!(
            detect_mode("path from user('root') to service('ssh') via auth_access"),
            QueryMode::Path
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_eq!(
            detect_mode("find all hosts with ssh open"),
            QueryMode::Natural
        );
        assert_eq!(
            detect_mode("show me vulnerable services"),
            QueryMode::Natural
        );
        assert_eq!(
            detect_mode("what credentials can reach the database?"),
            QueryMode::Natural
        );
        assert_eq!(
            detect_mode("list users with weak passwords"),
            QueryMode::Natural
        );
        assert_eq!(
            detect_mode("\"find hosts connected to 10.0.0.1\""),
            QueryMode::Natural
        );
        assert_eq!(
            detect_mode("which hosts have critical vulnerabilities?"),
            QueryMode::Natural
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty input.
        assert_eq!(detect_mode(""), QueryMode::Unknown);

        // Whitespace-only input.
        assert_eq!(detect_mode(" "), QueryMode::Unknown);

        // A bare `SELECT` with nothing after it matches no prefix.
        assert_eq!(detect_mode("SELECT"), QueryMode::Unknown);

        // Prefix checks are case-insensitive.
        assert_eq!(detect_mode("G.V()"), QueryMode::Gremlin);
        assert_eq!(detect_mode("Match (a) RETURN a"), QueryMode::Cypher);

        // Non-empty input that matches no heuristic.
        assert_eq!(detect_mode("xyzzy 42"), QueryMode::Unknown);
    }
}