// reddb_rql/modes/detect.rs

//! Query Mode Detection
//!
//! Automatically detects the query language of an input string based on its
//! syntax patterns.

/// Query languages that mode detection can report.
///
/// The type is a plain tag: it is `Copy`, comparable, and hashable, so it can
/// be passed by value and used as a key in maps and sets (for example a
/// per-mode dispatch table).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryMode {
    /// SQL-style statements, e.g. `SELECT ... FROM ... WHERE`.
    Sql,
    /// Gremlin traversals, e.g. `g.V().out().has(...)`.
    Gremlin,
    /// Cypher pattern matching, e.g. `MATCH (a)-[r]->(b) RETURN`.
    Cypher,
    /// SPARQL RDF queries, e.g. `SELECT ?var WHERE { ... }`.
    Sparql,
    /// Path queries, e.g. `PATH FROM ... TO ... VIA`.
    Path,
    /// Natural language queries.
    Natural,
    /// The input matched none of the other modes.
    Unknown,
}

24/// Detect the query mode from input string
25pub fn detect_mode(input: &str) -> QueryMode {
26    let trimmed = input.trim();
27    let lower = trimmed.to_lowercase();
28
29    // Check for quoted natural language (starts with quote)
30    if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31        return QueryMode::Natural;
32    }
33
34    // Gremlin: starts with g. or __.
35    if lower.starts_with("g.") || lower.starts_with("__.") {
36        return QueryMode::Gremlin;
37    }
38
39    // Path: PATH or PATHS keyword at start
40    if lower.starts_with("path ") || lower.starts_with("paths ") {
41        return QueryMode::Path;
42    }
43
44    // SPARQL: has ?variable pattern or PREFIX keyword
45    if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46        return QueryMode::Sparql;
47    }
48
49    // Cypher: MATCH keyword at start
50    if lower.starts_with("match ") || lower.starts_with("match(") {
51        return QueryMode::Cypher;
52    }
53
54    // SQL: SELECT, FROM, INSERT, UPDATE, DELETE, CREATE, DROP, ALTER, GRAPH, SEARCH at start
55    // Plus transaction / admin one-word commands that have no trailing
56    // clause (BEGIN/COMMIT/ROLLBACK/SAVEPOINT/RELEASE/VACUUM/ANALYZE/
57    // RESET/TENANT/etc.) — matched by equality on the trimmed token.
58    let first_token = lower.split_whitespace().next().unwrap_or("");
59    if matches!(
60        first_token,
61        "begin"
62            | "start"
63            | "commit"
64            | "rollback"
65            | "savepoint"
66            | "release"
67            | "end"
68            | "vacuum"
69            | "analyze"
70            | "reset"
71            | "copy"
72            | "refresh"
73            | "explain"
74            | "grant"
75            | "revoke"
76            | "attach"
77            | "detach"
78            | "simulate"
79            | "lint"
80            | "migrate"
81            | "apply"
82            | "events"
83            | "describe"
84            | "desc"
85    ) {
86        return QueryMode::Sql;
87    }
88    if lower.starts_with("select ")
89        || lower.starts_with("from ")
90        || lower.starts_with("insert ")
91        || lower.starts_with("update ")
92        || lower.starts_with("delete ")
93        || lower.starts_with("truncate ")
94        || lower.starts_with("create ")
95        || lower.starts_with("drop ")
96        || lower.starts_with("alter ")
97        || lower.starts_with("vector ")
98        || lower.starts_with("hybrid ")
99        || lower.starts_with("graph ")
100        || lower.starts_with("queue ")
101        || lower.starts_with("events ")
102        || lower.starts_with("tree ")
103        || lower.starts_with("hll ")
104        || lower.starts_with("sketch ")
105        || lower.starts_with("filter ")
106        || lower.starts_with("vault ")
107        || lower.starts_with("unseal vault ")
108        || lower.starts_with("rotate vault ")
109        || lower.starts_with("history vault ")
110        || lower.starts_with("list vault ")
111        || lower.starts_with("list kv ")
112        || lower.starts_with("watch vault ")
113        || lower.starts_with("delete vault ")
114        || lower.starts_with("purge vault ")
115        || lower.starts_with("search ")
116        || lower.starts_with("ask ")
117        || lower.starts_with("put config ")
118        || lower.starts_with("get config ")
119        || lower.starts_with("resolve config ")
120        || lower.starts_with("rotate config ")
121        || lower.starts_with("delete config ")
122        || lower.starts_with("history config ")
123        || lower.starts_with("list config ")
124        || lower.starts_with("watch config ")
125        || lower.starts_with("incr config ")
126        || lower.starts_with("decr config ")
127        || lower.starts_with("add config ")
128        || lower.starts_with("invalidate config ")
129        || lower.starts_with("invalidate tags ")
130        || lower.starts_with("set config ")
131        || lower.starts_with("set secret ")
132        || lower.starts_with("set tenant")
133        || lower.starts_with("show create ")
134        || lower.starts_with("show config")
135        || lower.starts_with("show collections")
136        || lower.starts_with("show tables")
137        || lower.starts_with("show queues")
138        || lower.starts_with("show vectors")
139        || lower.starts_with("show documents")
140        || lower.starts_with("show timeseries")
141        || lower.starts_with("show graphs")
142        || lower.starts_with("kv ")
143        || lower.starts_with("show kv")
144        || lower.starts_with("show configs")
145        || lower.starts_with("show vaults")
146        || lower.starts_with("show schema")
147        || lower.starts_with("show indices")
148        || lower.starts_with("show indexes")
149        || lower.starts_with("show sample ")
150        || lower.starts_with("show secret")
151        || lower.starts_with("show stats")
152        || lower.starts_with("show tenant")
153        || lower.starts_with("show policies")
154        || lower.starts_with("show effective ")
155        || lower.starts_with("rank of ")
156        || lower.starts_with("rank range ")
157        || lower.starts_with("approx rank of ")
158        || lower.starts_with("approximate rank of ")
159        || lower.starts_with("zrank ")
160        || lower.starts_with("zrange ")
161        || lower.starts_with("describe ")
162        || lower.starts_with("desc ")
163    {
164        // But check if it's SPARQL-style SELECT with ?variable.
165        // Bare SQL placeholders (`?`) and numbered placeholders (`?1`)
166        // must stay in SQL mode for parameter binding.
167        if lower.starts_with("select ") && has_sparql_variable(&lower) {
168            return QueryMode::Sparql;
169        }
170        return QueryMode::Sql;
171    }
172
173    // Natural language detection: common question words and patterns
174    if is_natural_language(&lower) {
175        return QueryMode::Natural;
176    }
177
178    QueryMode::Unknown
179}
180
181/// Check for SPARQL-specific patterns
182fn has_sparql_pattern(lower: &str) -> bool {
183    // SPARQL variables start with ? or $
184    // SPARQL has WHERE { } with triple patterns
185
186    // Check for ?variable pattern. Bare SQL placeholders (`?`) and
187    // numbered placeholders (`?1`) are not SPARQL variables.
188    let has_var = has_sparql_variable(lower);
189
190    // Check for typical SPARQL structure
191    let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
192
193    // Check for RDF predicates (prefixed URIs like :predicate or prefix:pred)
194    let has_prefix_pattern = lower.contains(":")
195        && (lower.contains(":<")
196            || lower.contains("> :")
197            || lower.contains(" :") && lower.contains("?"));
198
199    has_var || has_triple_pattern || has_prefix_pattern
200}
201
/// Report whether `input` contains a SPARQL-style `?name` variable.
///
/// A variable is a `?` immediately followed by an ASCII letter or `_`, so
/// bare SQL placeholders (`?`) and numbered placeholders (`?1`) never match.
///
/// Text inside single- or double-quoted literals is skipped, so a `?` that is
/// part of string data (for example the query string of a URL stored by an
/// SQL `INSERT`) is not mistaken for a variable. Inside a literal a backslash
/// escapes the following byte, and a doubled quote (`''`) simply closes and
/// reopens the literal. An unterminated literal hides the rest of the input.
fn has_sparql_variable(input: &str) -> bool {
    // Quote byte of the literal currently being skipped, if any. Quotes and
    // `?` are ASCII, so scanning bytes is safe for UTF-8 input.
    let mut open_quote: Option<u8> = None;
    let mut bytes = input.bytes().peekable();

    while let Some(byte) = bytes.next() {
        match open_quote {
            // Escaped byte inside a literal: consume it without inspecting it.
            Some(_) if byte == b'\\' => {
                bytes.next();
            }
            // Matching quote closes the literal.
            Some(quote) if byte == quote => open_quote = None,
            Some(_) => {}
            None if byte == b'\'' || byte == b'"' => open_quote = Some(byte),
            None if byte == b'?' => {
                if matches!(bytes.peek(), Some(&next) if is_sparql_variable_start(next)) {
                    return true;
                }
            }
            None => {}
        }
    }

    false
}

/// Report whether `byte` may begin a SPARQL variable name after `?`:
/// an ASCII letter or an underscore.
fn is_sparql_variable_start(byte: u8) -> bool {
    byte.is_ascii_alphabetic() || byte == b'_'
}

/// Decide whether a lowercased input reads like a natural-language request.
///
/// The input qualifies when it either begins with one of the known leading
/// words (each followed by a space), or contains at least two distinct
/// entries from the hint list of connective words and domain terms.
fn is_natural_language(lower: &str) -> bool {
    // Words that open a question or request.
    const LEADING_WORDS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];

    // Connectives (space-delimited) and domain vocabulary (substring match).
    const HINTS: &[&str] = &[
        " with ",
        " for ",
        " that ",
        " have ",
        " has ",
        " can ",
        " are ",
        " is ",
        " all ",
        " me ",
        " the ",
        " from ",
        " to ",
        " on ",
        " in ",
        "vulnerable",
        "credential",
        "password",
        "user",
        "host",
        "service",
        "connected",
        "reachable",
        "exposed",
        "critical",
    ];

    if LEADING_WORDS.iter().any(|word| lower.starts_with(word)) {
        return true;
    }

    // Two distinct hints are enough; stop scanning once the second is found.
    let mut hits = 0usize;
    for hint in HINTS.iter() {
        if lower.contains(hint) {
            hits += 1;
            if hits == 2 {
                return true;
            }
        }
    }

    false
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `detect_mode` classifies every input as `expected`.
    fn assert_all(expected: QueryMode, inputs: &[&str]) {
        for input in inputs {
            assert_eq!(detect_mode(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_all(
            QueryMode::Sql,
            &[
                // Core statements
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                // Engine-specific statements
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // Bare and numbered placeholders must stay in SQL mode
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                // Secrets, vault and KV
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                "LIST KV settings PREFIX feature LIMIT 10",
                // SHOW / DESCRIBE family
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "SHOW CREATE TABLE users",
                "DESCRIBE users",
                "DESC users",
                "SHOW INDICES",
                "SHOW INDEXES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_all(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_all(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_all(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_all(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_all(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty input, whitespace-only input, and a bare SELECT keyword with
        // nothing after it are all unclassifiable.
        assert_all(QueryMode::Unknown, &["", "   ", "SELECT"]);

        // Keyword matching is case-insensitive.
        assert_all(QueryMode::Gremlin, &["G.V()"]);
        assert_all(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}