// reddb_server/storage/query/modes/detect.rs

//! Query Mode Detection
//!
//! Classifies a raw query string into one of the supported query languages by looking at
//! its leading keywords and other characteristic syntax.

/// The query languages (plus fallbacks) that mode detection can report.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryMode {
    /// SQL and SQL-family statements, e.g. `SELECT ... FROM ... WHERE`.
    Sql,
    /// Gremlin traversals, e.g. `g.V().out().has(...)`.
    Gremlin,
    /// Cypher pattern matching, e.g. `MATCH (a)-[r]->(b) RETURN ...`.
    Cypher,
    /// SPARQL RDF queries, e.g. `SELECT ?var WHERE { ... }`.
    Sparql,
    /// Path queries, e.g. `PATH FROM ... TO ... VIA ...`.
    Path,
    /// Free-form natural-language requests.
    Natural,
    /// Input that matched none of the other modes.
    Unknown,
}

24/// Detect the query mode from input string
25pub fn detect_mode(input: &str) -> QueryMode {
26    let trimmed = input.trim();
27    let lower = trimmed.to_lowercase();
28
29    // Check for quoted natural language (starts with quote)
30    if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31        return QueryMode::Natural;
32    }
33
34    // Gremlin: starts with g. or __.
35    if lower.starts_with("g.") || lower.starts_with("__.") {
36        return QueryMode::Gremlin;
37    }
38
39    // Path: PATH or PATHS keyword at start
40    if lower.starts_with("path ") || lower.starts_with("paths ") {
41        return QueryMode::Path;
42    }
43
44    // SPARQL: has ?variable pattern or PREFIX keyword
45    if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46        return QueryMode::Sparql;
47    }
48
49    // Cypher: MATCH keyword at start
50    if lower.starts_with("match ") || lower.starts_with("match(") {
51        return QueryMode::Cypher;
52    }
53
54    // SQL: SELECT, FROM, INSERT, UPDATE, DELETE, CREATE, DROP, ALTER, GRAPH, SEARCH at start
55    // Plus transaction / admin one-word commands that have no trailing
56    // clause (BEGIN/COMMIT/ROLLBACK/SAVEPOINT/RELEASE/VACUUM/ANALYZE/
57    // RESET/TENANT/etc.) — matched by equality on the trimmed token.
58    let first_token = lower.split_whitespace().next().unwrap_or("");
59    if matches!(
60        first_token,
61        "begin"
62            | "start"
63            | "commit"
64            | "rollback"
65            | "savepoint"
66            | "release"
67            | "end"
68            | "vacuum"
69            | "analyze"
70            | "reset"
71            | "copy"
72            | "refresh"
73            | "explain"
74            | "grant"
75            | "revoke"
76            | "attach"
77            | "detach"
78            | "simulate"
79            | "apply"
80            | "events"
81            | "describe"
82            | "desc"
83    ) {
84        return QueryMode::Sql;
85    }
86    if lower.starts_with("select ")
87        || lower.starts_with("from ")
88        || lower.starts_with("insert ")
89        || lower.starts_with("update ")
90        || lower.starts_with("delete ")
91        || lower.starts_with("truncate ")
92        || lower.starts_with("create ")
93        || lower.starts_with("drop ")
94        || lower.starts_with("alter ")
95        || lower.starts_with("vector ")
96        || lower.starts_with("hybrid ")
97        || lower.starts_with("graph ")
98        || lower.starts_with("queue ")
99        || lower.starts_with("events ")
100        || lower.starts_with("tree ")
101        || lower.starts_with("hll ")
102        || lower.starts_with("sketch ")
103        || lower.starts_with("filter ")
104        || lower.starts_with("vault ")
105        || lower.starts_with("unseal vault ")
106        || lower.starts_with("rotate vault ")
107        || lower.starts_with("history vault ")
108        || lower.starts_with("list vault ")
109        || lower.starts_with("watch vault ")
110        || lower.starts_with("delete vault ")
111        || lower.starts_with("purge vault ")
112        || lower.starts_with("search ")
113        || lower.starts_with("ask ")
114        || lower.starts_with("put config ")
115        || lower.starts_with("get config ")
116        || lower.starts_with("resolve config ")
117        || lower.starts_with("rotate config ")
118        || lower.starts_with("delete config ")
119        || lower.starts_with("history config ")
120        || lower.starts_with("list config ")
121        || lower.starts_with("watch config ")
122        || lower.starts_with("incr config ")
123        || lower.starts_with("decr config ")
124        || lower.starts_with("add config ")
125        || lower.starts_with("invalidate config ")
126        || lower.starts_with("invalidate tags ")
127        || lower.starts_with("set config ")
128        || lower.starts_with("set secret ")
129        || lower.starts_with("set tenant")
130        || lower.starts_with("show create ")
131        || lower.starts_with("show config")
132        || lower.starts_with("show collections")
133        || lower.starts_with("show tables")
134        || lower.starts_with("show queues")
135        || lower.starts_with("show vectors")
136        || lower.starts_with("show documents")
137        || lower.starts_with("show timeseries")
138        || lower.starts_with("show graphs")
139        || lower.starts_with("kv ")
140        || lower.starts_with("show kv")
141        || lower.starts_with("show configs")
142        || lower.starts_with("show vaults")
143        || lower.starts_with("show schema")
144        || lower.starts_with("show indices")
145        || lower.starts_with("show indexes")
146        || lower.starts_with("show sample ")
147        || lower.starts_with("show secret")
148        || lower.starts_with("show stats")
149        || lower.starts_with("show tenant")
150        || lower.starts_with("show policies")
151        || lower.starts_with("show effective ")
152        || lower.starts_with("describe ")
153        || lower.starts_with("desc ")
154    {
155        // But check if it's SPARQL-style SELECT with ?variable.
156        // Bare SQL placeholders (`?`) and numbered placeholders (`?1`)
157        // must stay in SQL mode for parameter binding.
158        if lower.starts_with("select ") && has_sparql_variable(&lower) {
159            return QueryMode::Sparql;
160        }
161        return QueryMode::Sql;
162    }
163
164    // Natural language detection: common question words and patterns
165    if is_natural_language(&lower) {
166        return QueryMode::Natural;
167    }
168
169    QueryMode::Unknown
170}
171
/// Check lowercased query text for SPARQL-specific syntax.
///
/// Any one of three signals is enough:
/// - a `?name` variable (bare `?` and numbered `?1` placeholders do not count);
/// - a group graph pattern introduced by ` where {` / ` where{`;
/// - prefixed-name / IRI punctuation: `:<`, `> :`, or ` :` together with a `?`.
///
/// The contents of `'...'` and `"..."` literals are ignored. This keeps text inside an SQL
/// string, such as the URL query string in `'https://x.org/?q=1'`, from making an SQL
/// statement look like SPARQL.
fn has_sparql_pattern(lower: &str) -> bool {
    let code = strip_quoted_literals(lower);

    let has_var = contains_sparql_variable(&code);
    let has_triple_pattern = code.contains(" where {") || code.contains(" where{");
    let has_prefix_pattern = code.contains(':')
        && (code.contains(":<")
            || code.contains("> :")
            || (code.contains(" :") && code.contains('?')));

    has_var || has_triple_pattern || has_prefix_pattern
}

/// Return `true` if `input` contains a SPARQL variable (`?name`) outside quoted literals.
///
/// Only the `?` sigil is recognised. SPARQL also allows `$name`, but `$name` is a common
/// SQL and Cypher parameter form, so treating it as SPARQL would likely misroute those
/// queries. Bare SQL placeholders (`?`) and numbered placeholders (`?1`) are not variables.
fn has_sparql_variable(input: &str) -> bool {
    contains_sparql_variable(&strip_quoted_literals(input))
}

/// Scan text whose literals are already stripped for `?` followed by a variable-name start.
fn contains_sparql_variable(code: &str) -> bool {
    code.as_bytes()
        .windows(2)
        .any(|pair| pair[0] == b'?' && is_sparql_variable_start(pair[1]))
}

/// Whether `byte` may begin a SPARQL variable name (an ASCII letter or `_`).
fn is_sparql_variable_start(byte: u8) -> bool {
    byte.is_ascii_alphabetic() || byte == b'_'
}

/// Return `input` with the contents of every `'...'` and `"..."` literal removed.
///
/// The quote characters themselves are kept so surrounding tokens stay separated. A literal
/// is closed only by the quote kind that opened it, so `"O'Brien"` is handled correctly.
/// SQL doubled quotes (`'it''s'`) simply close and reopen the literal. Backslash escapes
/// are not interpreted, and an unterminated literal swallows the rest of the input.
fn strip_quoted_literals(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    let mut open_quote: Option<char> = None;
    for ch in input.chars() {
        match open_quote {
            Some(quote) if ch == quote => {
                open_quote = None;
                out.push(ch);
            }
            // Inside a literal: drop the character.
            Some(_) => {}
            None => {
                if ch == '\'' || ch == '"' {
                    open_quote = Some(ch);
                }
                out.push(ch);
            }
        }
    }
    out
}

/// Heuristically decide whether already-lowercased text reads like a natural-language request.
///
/// Returns `true` when `lower` opens with a question or command word such as `"find "` or
/// `"what "`. It also returns `true` when the text contains at least two of the connective
/// words and domain terms in `HINTS`. Matching is plain substring search, so `"user"` also
/// matches `"users"`.
fn is_natural_language(lower: &str) -> bool {
    // Leading words typical of questions and imperative requests.
    const OPENERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];

    // Connective words (space-delimited, so they only match as whole words between spaces)
    // followed by domain vocabulary (matched as bare substrings).
    const HINTS: &[&str] = &[
        " with ", " for ", " that ", " have ", " has ", " can ", " are ", " is ", " all ",
        " me ", " the ", " from ", " to ", " on ", " in ", "vulnerable", "credential",
        "password", "user", "host", "service", "connected", "reachable", "exposed", "critical",
    ];

    if OPENERS.iter().any(|&opener| lower.starts_with(opener)) {
        return true;
    }

    let hint_hits = HINTS.iter().filter(|&&hint| lower.contains(hint)).count();
    hint_hits >= 2
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every query in `queries` is classified as `expected`, naming the
    /// offending query when an assertion fails.
    fn assert_mode(expected: QueryMode, queries: &[&str]) {
        for query in queries {
            assert_eq!(detect_mode(query), expected, "query: {:?}", query);
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_mode(
            QueryMode::Sql,
            &[
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // Bare and numbered placeholders must stay SQL for parameter binding.
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "SHOW CREATE TABLE users",
                "DESCRIBE users",
                "DESC users",
                "SHOW INDICES",
                "SHOW INDEXES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_mode(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_mode(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_mode(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_mode(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_mode(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty and whitespace-only input.
        assert_mode(QueryMode::Unknown, &["", "   "]);

        // A bare keyword with nothing after it is not enough.
        assert_mode(QueryMode::Unknown, &["SELECT"]);

        // Detection is case-insensitive.
        assert_mode(QueryMode::Gremlin, &["G.V()"]);
        assert_mode(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}