// reddb_server/storage/query/modes/detect.rs

//! Query Mode Detection
//!
//! Heuristically detects the query language (SQL, Gremlin, Cypher, SPARQL,
//! path queries, or natural language) from the input's syntax patterns.

/// The query languages that [`detect_mode`] can recognise.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QueryMode {
    /// SQL and its SQL-family extensions, e.g. `SELECT ... FROM ... WHERE`.
    Sql,
    /// Gremlin traversals, e.g. `g.V().out().has(...)`.
    Gremlin,
    /// Cypher pattern matching, e.g. `MATCH (a)-[r]->(b) RETURN`.
    Cypher,
    /// SPARQL RDF queries, e.g. `SELECT ?var WHERE { ... }`.
    Sparql,
    /// Path queries, e.g. `PATH FROM ... TO ... VIA`.
    Path,
    /// Free-form natural-language questions.
    Natural,
    /// Input that matched none of the recognised languages.
    Unknown,
}

24/// Detect the query mode from input string
25pub fn detect_mode(input: &str) -> QueryMode {
26    let trimmed = input.trim();
27    let lower = trimmed.to_lowercase();
28
29    // Check for quoted natural language (starts with quote)
30    if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31        return QueryMode::Natural;
32    }
33
34    // Gremlin: starts with g. or __.
35    if lower.starts_with("g.") || lower.starts_with("__.") {
36        return QueryMode::Gremlin;
37    }
38
39    // Path: PATH or PATHS keyword at start
40    if lower.starts_with("path ") || lower.starts_with("paths ") {
41        return QueryMode::Path;
42    }
43
44    // SPARQL: has ?variable pattern or PREFIX keyword
45    if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46        return QueryMode::Sparql;
47    }
48
49    // Cypher: MATCH keyword at start
50    if lower.starts_with("match ") || lower.starts_with("match(") {
51        return QueryMode::Cypher;
52    }
53
54    // SQL: SELECT, FROM, INSERT, UPDATE, DELETE, CREATE, DROP, ALTER, GRAPH, SEARCH at start
55    // Plus transaction / admin one-word commands that have no trailing
56    // clause (BEGIN/COMMIT/ROLLBACK/SAVEPOINT/RELEASE/VACUUM/ANALYZE/
57    // RESET/TENANT/etc.) — matched by equality on the trimmed token.
58    let first_token = lower.split_whitespace().next().unwrap_or("");
59    if matches!(
60        first_token,
61        "begin"
62            | "start"
63            | "commit"
64            | "rollback"
65            | "savepoint"
66            | "release"
67            | "end"
68            | "vacuum"
69            | "analyze"
70            | "reset"
71            | "copy"
72            | "refresh"
73            | "explain"
74            | "grant"
75            | "revoke"
76            | "attach"
77            | "detach"
78            | "simulate"
79            | "lint"
80            | "migrate"
81            | "apply"
82            | "events"
83            | "describe"
84            | "desc"
85    ) {
86        return QueryMode::Sql;
87    }
88    if lower.starts_with("select ")
89        || lower.starts_with("from ")
90        || lower.starts_with("insert ")
91        || lower.starts_with("update ")
92        || lower.starts_with("delete ")
93        || lower.starts_with("truncate ")
94        || lower.starts_with("create ")
95        || lower.starts_with("drop ")
96        || lower.starts_with("alter ")
97        || lower.starts_with("vector ")
98        || lower.starts_with("hybrid ")
99        || lower.starts_with("graph ")
100        || lower.starts_with("queue ")
101        || lower.starts_with("events ")
102        || lower.starts_with("tree ")
103        || lower.starts_with("hll ")
104        || lower.starts_with("sketch ")
105        || lower.starts_with("filter ")
106        || lower.starts_with("vault ")
107        || lower.starts_with("unseal vault ")
108        || lower.starts_with("rotate vault ")
109        || lower.starts_with("history vault ")
110        || lower.starts_with("list vault ")
111        || lower.starts_with("watch vault ")
112        || lower.starts_with("delete vault ")
113        || lower.starts_with("purge vault ")
114        || lower.starts_with("search ")
115        || lower.starts_with("ask ")
116        || lower.starts_with("put config ")
117        || lower.starts_with("get config ")
118        || lower.starts_with("resolve config ")
119        || lower.starts_with("rotate config ")
120        || lower.starts_with("delete config ")
121        || lower.starts_with("history config ")
122        || lower.starts_with("list config ")
123        || lower.starts_with("watch config ")
124        || lower.starts_with("incr config ")
125        || lower.starts_with("decr config ")
126        || lower.starts_with("add config ")
127        || lower.starts_with("invalidate config ")
128        || lower.starts_with("invalidate tags ")
129        || lower.starts_with("set config ")
130        || lower.starts_with("set secret ")
131        || lower.starts_with("set tenant")
132        || lower.starts_with("show create ")
133        || lower.starts_with("show config")
134        || lower.starts_with("show collections")
135        || lower.starts_with("show tables")
136        || lower.starts_with("show queues")
137        || lower.starts_with("show vectors")
138        || lower.starts_with("show documents")
139        || lower.starts_with("show timeseries")
140        || lower.starts_with("show graphs")
141        || lower.starts_with("kv ")
142        || lower.starts_with("show kv")
143        || lower.starts_with("show configs")
144        || lower.starts_with("show vaults")
145        || lower.starts_with("show schema")
146        || lower.starts_with("show indices")
147        || lower.starts_with("show indexes")
148        || lower.starts_with("show sample ")
149        || lower.starts_with("show secret")
150        || lower.starts_with("show stats")
151        || lower.starts_with("show tenant")
152        || lower.starts_with("show policies")
153        || lower.starts_with("show effective ")
154        || lower.starts_with("describe ")
155        || lower.starts_with("desc ")
156    {
157        // But check if it's SPARQL-style SELECT with ?variable.
158        // Bare SQL placeholders (`?`) and numbered placeholders (`?1`)
159        // must stay in SQL mode for parameter binding.
160        if lower.starts_with("select ") && has_sparql_variable(&lower) {
161            return QueryMode::Sparql;
162        }
163        return QueryMode::Sql;
164    }
165
166    // Natural language detection: common question words and patterns
167    if is_natural_language(&lower) {
168        return QueryMode::Natural;
169    }
170
171    QueryMode::Unknown
172}
173
174/// Check for SPARQL-specific patterns
175fn has_sparql_pattern(lower: &str) -> bool {
176    // SPARQL variables start with ? or $
177    // SPARQL has WHERE { } with triple patterns
178
179    // Check for ?variable pattern. Bare SQL placeholders (`?`) and
180    // numbered placeholders (`?1`) are not SPARQL variables.
181    let has_var = has_sparql_variable(lower);
182
183    // Check for typical SPARQL structure
184    let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
185
186    // Check for RDF predicates (prefixed URIs like :predicate or prefix:pred)
187    let has_prefix_pattern = lower.contains(":")
188        && (lower.contains(":<")
189            || lower.contains("> :")
190            || lower.contains(" :") && lower.contains("?"));
191
192    has_var || has_triple_pattern || has_prefix_pattern
193}
194
/// Return `true` if `input` contains a SPARQL-style variable (`?name`)
/// outside of any quoted string literal.
///
/// A variable is a `?` immediately followed by a byte accepted by
/// [`is_sparql_variable_start`]. Bare SQL placeholders (`?`) and numbered
/// placeholders (`?1`) therefore do not count.
///
/// Text inside `'...'` or `"..."` literals is skipped. A value such as
/// `'https://example.com/search?q=1'` therefore does not make a SQL
/// statement look like SPARQL. Inside a literal, a backslash escapes the
/// following byte, so `\'` and `\"` do not close it.
fn has_sparql_variable(input: &str) -> bool {
    let bytes = input.as_bytes();
    // Quote byte that opened the literal we are currently inside, if any.
    let mut open_quote: Option<u8> = None;
    let mut i = 0;
    while i < bytes.len() {
        let byte = bytes[i];
        match open_quote {
            Some(quote) => {
                if byte == b'\\' {
                    // Skip the escaped byte so it cannot terminate the literal.
                    i += 1;
                } else if byte == quote {
                    open_quote = None;
                }
            }
            None => match byte {
                b'\'' | b'"' => open_quote = Some(byte),
                b'?' if matches!(
                    bytes.get(i + 1),
                    Some(&next) if is_sparql_variable_start(next)
                ) =>
                {
                    return true;
                }
                _ => {}
            },
        }
        i += 1;
    }
    false
}

/// Return `true` if `byte` may begin a SPARQL variable name after the `?`:
/// an ASCII letter or `_`.
fn is_sparql_variable_start(byte: u8) -> bool {
    byte.is_ascii_alphabetic() || byte == b'_'
}

/// Heuristically decide whether `lower` reads like an English question or
/// request rather than a formal query language.
///
/// Returns `true` when the text starts with a question/command word such as
/// `"find "` or `"which "`. It also returns `true` when the text contains at
/// least two entries from a fixed list of connective words and
/// security-domain terms.
fn is_natural_language(lower: &str) -> bool {
    // Leading question / command words.
    const QUESTION_STARTERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];

    // Connective words (space-delimited) and domain vocabulary (substrings).
    const NL_MARKERS: &[&str] = &[
        " with ",
        " for ",
        " that ",
        " have ",
        " has ",
        " can ",
        " are ",
        " is ",
        " all ",
        " me ",
        " the ",
        " from ",
        " to ",
        " on ",
        " in ",
        "vulnerable",
        "credential",
        "password",
        "user",
        "host",
        "service",
        "connected",
        "reachable",
        "exposed",
        "critical",
    ];

    if QUESTION_STARTERS
        .iter()
        .any(|starter| lower.starts_with(*starter))
    {
        return true;
    }

    let marker_hits = NL_MARKERS
        .iter()
        .filter(|marker| lower.contains(**marker))
        .count();
    marker_hits >= 2
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every string in `inputs` is classified as `expected`,
    /// naming the offending input on failure.
    fn assert_mode(expected: QueryMode, inputs: &[&str]) {
        for input in inputs {
            assert_eq!(detect_mode(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_mode(
            QueryMode::Sql,
            &[
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // Positional placeholders must not trigger SPARQL detection.
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "SHOW CREATE TABLE users",
                "DESCRIBE users",
                "DESC users",
                "SHOW INDICES",
                "SHOW INDEXES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_mode(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_mode(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_mode(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_mode(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_mode(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty and whitespace-only input.
        assert_mode(QueryMode::Unknown, &["", "   "]);

        // A bare keyword with no following space is not a statement.
        assert_mode(QueryMode::Unknown, &["SELECT"]);

        // Case insensitivity.
        assert_mode(QueryMode::Gremlin, &["G.V()"]);
        assert_mode(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}
458        assert_eq!(detect_mode("SELECT"), QueryMode::Unknown); // No space after
459        assert_eq!(detect_mode("G.V()"), QueryMode::Gremlin);
460        assert_eq!(detect_mode("Match (a) RETURN a"), QueryMode::Cypher);
461    }
462}