Skip to main content

reddb_server/storage/query/modes/
detect.rs

1//! Query Mode Detection
2//!
3//! Automatically detects the query language based on syntax patterns.
4
/// Query languages recognised by [`detect_mode`].
///
/// `Hash` is derived alongside `Eq` so modes can be used as keys in hash
/// maps and sets (e.g. per-mode statistics or dispatch tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryMode {
    /// SQL-style statements: `SELECT ... FROM ... WHERE`, plus the SQL-style
    /// command family (`INSERT`, `SHOW ...`, `VAULT ...`, ...).
    Sql,
    /// Gremlin traversal: `g.V().out().has(...)`
    Gremlin,
    /// Cypher pattern matching: `MATCH (a)-[r]->(b) RETURN`
    Cypher,
    /// SPARQL RDF queries: `SELECT ?var WHERE { ... }`
    Sparql,
    /// Path queries: `PATH FROM ... TO ... VIA`
    Path,
    /// Natural-language questions, e.g. "find all hosts with ssh open".
    Natural,
    /// The input matched none of the detection rules.
    Unknown,
}
23
24/// Detect the query mode from input string
25pub fn detect_mode(input: &str) -> QueryMode {
26    let trimmed = input.trim();
27    let lower = trimmed.to_lowercase();
28
29    // Check for quoted natural language (starts with quote)
30    if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31        return QueryMode::Natural;
32    }
33
34    // Gremlin: starts with g. or __.
35    if lower.starts_with("g.") || lower.starts_with("__.") {
36        return QueryMode::Gremlin;
37    }
38
39    // Path: PATH or PATHS keyword at start
40    if lower.starts_with("path ") || lower.starts_with("paths ") {
41        return QueryMode::Path;
42    }
43
44    // SPARQL: has ?variable pattern or PREFIX keyword
45    if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46        return QueryMode::Sparql;
47    }
48
49    // Cypher: MATCH keyword at start
50    if lower.starts_with("match ") || lower.starts_with("match(") {
51        return QueryMode::Cypher;
52    }
53
54    // SQL: SELECT, FROM, INSERT, UPDATE, DELETE, CREATE, DROP, ALTER, GRAPH, SEARCH at start
55    // Plus transaction / admin one-word commands that have no trailing
56    // clause (BEGIN/COMMIT/ROLLBACK/SAVEPOINT/RELEASE/VACUUM/ANALYZE/
57    // RESET/TENANT/etc.) — matched by equality on the trimmed token.
58    let first_token = lower.split_whitespace().next().unwrap_or("");
59    if matches!(
60        first_token,
61        "begin"
62            | "start"
63            | "commit"
64            | "rollback"
65            | "savepoint"
66            | "release"
67            | "end"
68            | "vacuum"
69            | "analyze"
70            | "reset"
71            | "copy"
72            | "refresh"
73            | "explain"
74            | "grant"
75            | "revoke"
76            | "attach"
77            | "detach"
78            | "simulate"
79            | "lint"
80            | "migrate"
81            | "apply"
82            | "events"
83            | "describe"
84            | "desc"
85    ) {
86        return QueryMode::Sql;
87    }
88    if lower.starts_with("select ")
89        || lower.starts_with("from ")
90        || lower.starts_with("insert ")
91        || lower.starts_with("update ")
92        || lower.starts_with("delete ")
93        || lower.starts_with("truncate ")
94        || lower.starts_with("create ")
95        || lower.starts_with("drop ")
96        || lower.starts_with("alter ")
97        || lower.starts_with("vector ")
98        || lower.starts_with("hybrid ")
99        || lower.starts_with("graph ")
100        || lower.starts_with("queue ")
101        || lower.starts_with("events ")
102        || lower.starts_with("tree ")
103        || lower.starts_with("hll ")
104        || lower.starts_with("sketch ")
105        || lower.starts_with("filter ")
106        || lower.starts_with("vault ")
107        || lower.starts_with("unseal vault ")
108        || lower.starts_with("rotate vault ")
109        || lower.starts_with("history vault ")
110        || lower.starts_with("list vault ")
111        || lower.starts_with("watch vault ")
112        || lower.starts_with("delete vault ")
113        || lower.starts_with("purge vault ")
114        || lower.starts_with("search ")
115        || lower.starts_with("ask ")
116        || lower.starts_with("put config ")
117        || lower.starts_with("get config ")
118        || lower.starts_with("resolve config ")
119        || lower.starts_with("rotate config ")
120        || lower.starts_with("delete config ")
121        || lower.starts_with("history config ")
122        || lower.starts_with("list config ")
123        || lower.starts_with("watch config ")
124        || lower.starts_with("incr config ")
125        || lower.starts_with("decr config ")
126        || lower.starts_with("add config ")
127        || lower.starts_with("invalidate config ")
128        || lower.starts_with("invalidate tags ")
129        || lower.starts_with("set config ")
130        || lower.starts_with("set secret ")
131        || lower.starts_with("set tenant")
132        || lower.starts_with("show create ")
133        || lower.starts_with("show config")
134        || lower.starts_with("show collections")
135        || lower.starts_with("show tables")
136        || lower.starts_with("show queues")
137        || lower.starts_with("show vectors")
138        || lower.starts_with("show documents")
139        || lower.starts_with("show timeseries")
140        || lower.starts_with("show graphs")
141        || lower.starts_with("kv ")
142        || lower.starts_with("show kv")
143        || lower.starts_with("show configs")
144        || lower.starts_with("show vaults")
145        || lower.starts_with("show schema")
146        || lower.starts_with("show indices")
147        || lower.starts_with("show indexes")
148        || lower.starts_with("show sample ")
149        || lower.starts_with("show secret")
150        || lower.starts_with("show stats")
151        || lower.starts_with("show tenant")
152        || lower.starts_with("show policies")
153        || lower.starts_with("show effective ")
154        || lower.starts_with("rank of ")
155        || lower.starts_with("rank range ")
156        || lower.starts_with("approx rank of ")
157        || lower.starts_with("approximate rank of ")
158        || lower.starts_with("zrank ")
159        || lower.starts_with("zrange ")
160        || lower.starts_with("describe ")
161        || lower.starts_with("desc ")
162    {
163        // But check if it's SPARQL-style SELECT with ?variable.
164        // Bare SQL placeholders (`?`) and numbered placeholders (`?1`)
165        // must stay in SQL mode for parameter binding.
166        if lower.starts_with("select ") && has_sparql_variable(&lower) {
167            return QueryMode::Sparql;
168        }
169        return QueryMode::Sql;
170    }
171
172    // Natural language detection: common question words and patterns
173    if is_natural_language(&lower) {
174        return QueryMode::Natural;
175    }
176
177    QueryMode::Unknown
178}
179
180/// Check for SPARQL-specific patterns
181fn has_sparql_pattern(lower: &str) -> bool {
182    // SPARQL variables start with ? or $
183    // SPARQL has WHERE { } with triple patterns
184
185    // Check for ?variable pattern. Bare SQL placeholders (`?`) and
186    // numbered placeholders (`?1`) are not SPARQL variables.
187    let has_var = has_sparql_variable(lower);
188
189    // Check for typical SPARQL structure
190    let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
191
192    // Check for RDF predicates (prefixed URIs like :predicate or prefix:pred)
193    let has_prefix_pattern = lower.contains(":")
194        && (lower.contains(":<")
195            || lower.contains("> :")
196            || lower.contains(" :") && lower.contains("?"));
197
198    has_var || has_triple_pattern || has_prefix_pattern
199}
200
/// Return `true` if `input` contains a SPARQL-style variable outside any quoted
/// string literal.
///
/// A variable is a `?` immediately followed by a byte accepted by
/// [`is_sparql_variable_start`] (e.g. `?name`, `?_x`). SQL's bare `?` and
/// numbered `?1` placeholders therefore do not count.
///
/// Bytes between matching single (`'`) or double (`"`) quotes are skipped, so
/// a literal such as `'https://example.com/page?id=1'` is not mistaken for a
/// variable. A quote only closes a literal opened by the same quote character.
/// Backslash escapes are not interpreted. SQL's doubled-quote escape
/// (`'it''s'`) still works, because it closes and immediately reopens the
/// literal. An unterminated literal hides everything after it.
fn has_sparql_variable(input: &str) -> bool {
    // Scanning raw bytes is safe for UTF-8 input: the ASCII bytes we look for
    // (`'`, `"`, `?`) never occur inside a multi-byte sequence.
    let bytes = input.as_bytes();
    // Quote byte of the string literal we are currently inside, if any.
    let mut open_quote: Option<u8> = None;

    for (idx, &byte) in bytes.iter().enumerate() {
        if let Some(quote) = open_quote {
            if byte == quote {
                open_quote = None;
            }
            continue;
        }
        match byte {
            b'\'' | b'"' => open_quote = Some(byte),
            b'?' => {
                if matches!(bytes.get(idx + 1), Some(&next) if is_sparql_variable_start(next)) {
                    return true;
                }
            }
            _ => {}
        }
    }

    false
}

/// Return `true` if `byte` may follow `?` to start a SPARQL variable name: an
/// ASCII letter or `_`.
///
/// Digits are deliberately rejected so that SQL's numbered placeholders
/// (`?1`) are not taken for SPARQL variables.
fn is_sparql_variable_start(byte: u8) -> bool {
    byte.is_ascii_alphabetic() || byte == b'_'
}
211
/// Heuristically decide whether `lower` (an already-lowercased query) reads
/// like an English request rather than a formal query.
///
/// Returns `true` in either of two cases:
///
/// - The text starts with one of a fixed set of question/command words, each
///   followed by a space (`find `, `show `, `what `, ...).
/// - At least two *distinct* entries from a list of connective words and
///   domain terms appear anywhere in the text. Each entry counts once, however
///   often it occurs.
fn is_natural_language(lower: &str) -> bool {
    const QUESTION_STARTERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];

    // Space-delimited connectives first, then bare domain vocabulary
    // (matched as plain substrings, so "user" also hits "users").
    const NL_HINTS: &[&str] = &[
        " with ", " for ", " that ", " have ", " has ", " can ", " are ", " is ", " all ", " me ",
        " the ", " from ", " to ", " on ", " in ", "vulnerable", "credential", "password", "user",
        "host", "service", "connected", "reachable", "exposed", "critical",
    ];

    if QUESTION_STARTERS
        .iter()
        .any(|&starter| lower.starts_with(starter))
    {
        return true;
    }

    let mut hits = 0usize;
    for &hint in NL_HINTS {
        if lower.contains(hint) {
            hits += 1;
        }
    }
    hits >= 2
}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every input in `inputs` is classified as `expected`,
    /// naming the offending input when an assertion fails.
    fn expect_mode(expected: QueryMode, inputs: &[&str]) {
        for &input in inputs {
            assert_eq!(detect_mode(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_sql_detection() {
        expect_mode(
            QueryMode::Sql,
            &[
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // SQL placeholders must not be mistaken for SPARQL variables.
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "SHOW CREATE TABLE users",
                "DESCRIBE users",
                "DESC users",
                "SHOW INDICES",
                "SHOW INDEXES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        expect_mode(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        expect_mode(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        expect_mode(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        expect_mode(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        expect_mode(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty input and whitespace-only input.
        expect_mode(QueryMode::Unknown, &["", "   "]);

        // A bare keyword with no following space is not enough for SQL.
        expect_mode(QueryMode::Unknown, &["SELECT"]);

        // Case insensitivity.
        expect_mode(QueryMode::Gremlin, &["G.V()"]);
        expect_mode(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}
468}