Skip to main content

reddb_server/storage/query/modes/
detect.rs

1//! Query Mode Detection
2//!
3//! Automatically detects the query language based on syntax patterns.
4
/// The query language that a piece of input text appears to be written in.
///
/// Values are produced by [`detect_mode`]. Besides the usual comparison
/// traits, `Hash` is derived so a mode can be used as a `HashMap`/`HashSet`
/// key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryMode {
    /// SQL-style: SELECT ... FROM ... WHERE
    Sql,
    /// Gremlin traversal: g.V().out().has(...)
    Gremlin,
    /// Cypher pattern matching: MATCH (a)-[r]->(b) RETURN
    Cypher,
    /// SPARQL RDF queries: SELECT ?var WHERE { ... }
    Sparql,
    /// Path queries: PATH FROM ... TO ... VIA
    Path,
    /// Natural language queries
    Natural,
    /// The input matched none of the detection heuristics.
    Unknown,
}
23
24/// Detect the query mode from input string
25pub fn detect_mode(input: &str) -> QueryMode {
26    let trimmed = input.trim();
27    let lower = trimmed.to_lowercase();
28
29    // Check for quoted natural language (starts with quote)
30    if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31        return QueryMode::Natural;
32    }
33
34    // Gremlin: starts with g. or __.
35    if lower.starts_with("g.") || lower.starts_with("__.") {
36        return QueryMode::Gremlin;
37    }
38
39    // Path: PATH or PATHS keyword at start
40    if lower.starts_with("path ") || lower.starts_with("paths ") {
41        return QueryMode::Path;
42    }
43
44    // SPARQL: has ?variable pattern or PREFIX keyword
45    if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46        return QueryMode::Sparql;
47    }
48
49    // Cypher: MATCH keyword at start
50    if lower.starts_with("match ") || lower.starts_with("match(") {
51        return QueryMode::Cypher;
52    }
53
54    // SQL: SELECT, FROM, INSERT, UPDATE, DELETE, CREATE, DROP, ALTER, GRAPH, SEARCH at start
55    // Plus transaction / admin one-word commands that have no trailing
56    // clause (BEGIN/COMMIT/ROLLBACK/SAVEPOINT/RELEASE/VACUUM/ANALYZE/
57    // RESET/TENANT/etc.) — matched by equality on the trimmed token.
58    let first_token = lower.split_whitespace().next().unwrap_or("");
59    if matches!(
60        first_token,
61        "begin"
62            | "start"
63            | "commit"
64            | "rollback"
65            | "savepoint"
66            | "release"
67            | "end"
68            | "vacuum"
69            | "analyze"
70            | "reset"
71            | "copy"
72            | "refresh"
73            | "explain"
74            | "grant"
75            | "revoke"
76            | "attach"
77            | "detach"
78            | "simulate"
79            | "apply"
80            | "events"
81    ) {
82        return QueryMode::Sql;
83    }
84    if lower.starts_with("select ")
85        || lower.starts_with("from ")
86        || lower.starts_with("insert ")
87        || lower.starts_with("update ")
88        || lower.starts_with("delete ")
89        || lower.starts_with("truncate ")
90        || lower.starts_with("create ")
91        || lower.starts_with("drop ")
92        || lower.starts_with("alter ")
93        || lower.starts_with("vector ")
94        || lower.starts_with("hybrid ")
95        || lower.starts_with("graph ")
96        || lower.starts_with("queue ")
97        || lower.starts_with("events ")
98        || lower.starts_with("tree ")
99        || lower.starts_with("hll ")
100        || lower.starts_with("sketch ")
101        || lower.starts_with("filter ")
102        || lower.starts_with("vault ")
103        || lower.starts_with("unseal vault ")
104        || lower.starts_with("rotate vault ")
105        || lower.starts_with("history vault ")
106        || lower.starts_with("list vault ")
107        || lower.starts_with("watch vault ")
108        || lower.starts_with("delete vault ")
109        || lower.starts_with("purge vault ")
110        || lower.starts_with("search ")
111        || lower.starts_with("ask ")
112        || lower.starts_with("put config ")
113        || lower.starts_with("get config ")
114        || lower.starts_with("resolve config ")
115        || lower.starts_with("rotate config ")
116        || lower.starts_with("delete config ")
117        || lower.starts_with("history config ")
118        || lower.starts_with("list config ")
119        || lower.starts_with("watch config ")
120        || lower.starts_with("incr config ")
121        || lower.starts_with("decr config ")
122        || lower.starts_with("add config ")
123        || lower.starts_with("invalidate config ")
124        || lower.starts_with("invalidate tags ")
125        || lower.starts_with("set config ")
126        || lower.starts_with("set secret ")
127        || lower.starts_with("set tenant")
128        || lower.starts_with("show config")
129        || lower.starts_with("show collections")
130        || lower.starts_with("show tables")
131        || lower.starts_with("show queues")
132        || lower.starts_with("show vectors")
133        || lower.starts_with("show documents")
134        || lower.starts_with("show timeseries")
135        || lower.starts_with("show graphs")
136        || lower.starts_with("kv ")
137        || lower.starts_with("show kv")
138        || lower.starts_with("show configs")
139        || lower.starts_with("show vaults")
140        || lower.starts_with("show schema")
141        || lower.starts_with("show indices")
142        || lower.starts_with("show sample ")
143        || lower.starts_with("show secret")
144        || lower.starts_with("show stats")
145        || lower.starts_with("show tenant")
146        || lower.starts_with("show policies")
147        || lower.starts_with("show effective ")
148    {
149        // But check if it's SPARQL-style SELECT with ?variable.
150        // Bare SQL placeholders (`?`) and numbered placeholders (`?1`)
151        // must stay in SQL mode for parameter binding.
152        if lower.starts_with("select ") && has_sparql_variable(&lower) {
153            return QueryMode::Sparql;
154        }
155        return QueryMode::Sql;
156    }
157
158    // Natural language detection: common question words and patterns
159    if is_natural_language(&lower) {
160        return QueryMode::Natural;
161    }
162
163    QueryMode::Unknown
164}
165
166/// Check for SPARQL-specific patterns
167fn has_sparql_pattern(lower: &str) -> bool {
168    // SPARQL variables start with ? or $
169    // SPARQL has WHERE { } with triple patterns
170
171    // Check for ?variable pattern. Bare SQL placeholders (`?`) and
172    // numbered placeholders (`?1`) are not SPARQL variables.
173    let has_var = has_sparql_variable(lower);
174
175    // Check for typical SPARQL structure
176    let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
177
178    // Check for RDF predicates (prefixed URIs like :predicate or prefix:pred)
179    let has_prefix_pattern = lower.contains(":")
180        && (lower.contains(":<")
181            || lower.contains("> :")
182            || lower.contains(" :") && lower.contains("?"));
183
184    has_var || has_triple_pattern || has_prefix_pattern
185}
186
/// Return `true` if `input` contains a SPARQL variable outside quoted literals.
///
/// A variable is a `?` immediately followed by an ASCII letter or `_`. Bare
/// placeholders (`?`) and numbered placeholders (`?1`) are therefore not
/// variables.
///
/// Text inside `'...'` or `"..."` is skipped. A `?` in a string literal, such
/// as the query string of `'https://example.com/page?id=5'`, therefore does
/// not turn a SQL or Cypher statement into SPARQL.
///
/// Quote handling has these limits:
/// - A doubled quote (`'it''s'`) closes and reopens the literal, so it is
///   handled correctly.
/// - Backslash escapes are not interpreted.
/// - An unterminated quote (e.g. a stray apostrophe) suppresses detection
///   for the rest of the input.
fn has_sparql_variable(input: &str) -> bool {
    let bytes = input.as_bytes();
    // The quote byte that opened the literal we are currently inside, if any.
    let mut open_quote: Option<u8> = None;

    for (i, &byte) in bytes.iter().enumerate() {
        match open_quote {
            Some(quote) => {
                if byte == quote {
                    open_quote = None;
                }
            }
            None => match byte {
                b'\'' | b'"' => open_quote = Some(byte),
                b'?' => {
                    if matches!(bytes.get(i + 1), Some(&next) if is_sparql_variable_start(next)) {
                        return true;
                    }
                }
                _ => {}
            },
        }
    }

    false
}

/// Whether `byte` may start a SPARQL variable name: an ASCII letter or `_`.
fn is_sparql_variable_start(byte: u8) -> bool {
    byte.is_ascii_alphabetic() || byte == b'_'
}
197
/// Heuristically decide whether `lower` (already lower-cased) reads like an
/// English question or request.
///
/// Returns `true` in either of two cases:
/// - the input begins with one of a fixed set of request/question words
///   (`"find "`, `"show "`, `"what "`, ...);
/// - it contains at least two distinct entries from a list of connective
///   words and domain terms (`" with "`, `"password"`, `"host"`, ...).
///
/// Term matching is plain substring search, so `"host"` also matches inside
/// `"hosts"`.
fn is_natural_language(lower: &str) -> bool {
    const OPENING_WORDS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];
    const HINT_TERMS: &[&str] = &[
        // Connective words, wrapped in spaces so they only match between other words
        " with ", " for ", " that ", " have ", " has ", " can ", " are ", " is ", " all ", " me ",
        " the ", " from ", " to ", " on ", " in ",
        // Domain vocabulary
        "vulnerable", "credential", "password", "user", "host", "service", "connected",
        "reachable", "exposed", "critical",
    ];

    if OPENING_WORDS.iter().any(|&word| lower.starts_with(word)) {
        return true;
    }

    // Two independent hints are required; stop scanning once both are found.
    HINT_TERMS
        .iter()
        .filter(|&&term| lower.contains(term))
        .take(2)
        .count()
        == 2
}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `detect_mode` classifies every entry of `inputs` as `expected`.
    ///
    /// The offending input is included in the failure message, so a
    /// table-driven failure is as easy to locate as a single assertion.
    fn assert_all(expected: QueryMode, inputs: &[&str]) {
        for &input in inputs {
            assert_eq!(detect_mode(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_all(
            QueryMode::Sql,
            &[
                // Core statements
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                // Extended statement keywords
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // Parameter placeholders stay SQL
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                // Secrets / vault
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                // SHOW commands
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "SHOW INDICES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_all(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_all(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_all(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_all(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_all(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty and whitespace-only input
        assert_all(QueryMode::Unknown, &["", "   "]);

        // A bare keyword with no trailing space does not match the "select " prefix
        assert_all(QueryMode::Unknown, &["SELECT"]);

        // Keyword matching is case-insensitive
        assert_all(QueryMode::Gremlin, &["G.V()"]);
        assert_all(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}