Skip to main content

reddb_server/storage/query/modes/
detect.rs

//! Query Mode Detection
//!
//! Heuristically detects a query's language from its leading keywords and syntax patterns.
4
/// The query language an input string is written in, as inferred by [`detect_mode`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryMode {
    /// SQL-style statements (`SELECT ... FROM ... WHERE`), plus RedDB's
    /// SQL-like command set (transactions, `SHOW ...`, `VAULT ...`, `QUEUE ...`).
    Sql,
    /// Gremlin traversals, e.g. `g.V().out().has(...)`.
    Gremlin,
    /// Cypher pattern matching, e.g. `MATCH (a)-[r]->(b) RETURN a`.
    Cypher,
    /// SPARQL RDF queries, e.g. `SELECT ?var WHERE { ... }`.
    Sparql,
    /// Path queries, e.g. `PATH FROM ... TO ... VIA ...`.
    Path,
    /// Free-form natural-language requests.
    Natural,
    /// No detection rule matched.
    Unknown,
}
23
24/// Detect the query mode from input string
25pub fn detect_mode(input: &str) -> QueryMode {
26    let trimmed = input.trim();
27    let lower = trimmed.to_lowercase();
28
29    // Check for quoted natural language (starts with quote)
30    if trimmed.starts_with('"') || trimmed.starts_with('\'') {
31        return QueryMode::Natural;
32    }
33
34    // Gremlin: starts with g. or __.
35    if lower.starts_with("g.") || lower.starts_with("__.") {
36        return QueryMode::Gremlin;
37    }
38
39    // Path: PATH or PATHS keyword at start
40    if lower.starts_with("path ") || lower.starts_with("paths ") {
41        return QueryMode::Path;
42    }
43
44    // SPARQL: has ?variable pattern or PREFIX keyword
45    if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
46        return QueryMode::Sparql;
47    }
48
49    // Cypher: MATCH keyword at start
50    if lower.starts_with("match ") || lower.starts_with("match(") {
51        return QueryMode::Cypher;
52    }
53
54    // SQL: SELECT, FROM, INSERT, UPDATE, DELETE, CREATE, DROP, ALTER, GRAPH, SEARCH at start
55    // Plus transaction / admin one-word commands that have no trailing
56    // clause (BEGIN/COMMIT/ROLLBACK/SAVEPOINT/RELEASE/VACUUM/ANALYZE/
57    // RESET/TENANT/etc.) — matched by equality on the trimmed token.
58    let first_token = lower.split_whitespace().next().unwrap_or("");
59    if matches!(
60        first_token,
61        "begin"
62            | "start"
63            | "commit"
64            | "rollback"
65            | "savepoint"
66            | "release"
67            | "end"
68            | "vacuum"
69            | "analyze"
70            | "reset"
71            | "copy"
72            | "refresh"
73            | "explain"
74            | "grant"
75            | "revoke"
76            | "attach"
77            | "detach"
78            | "simulate"
79            | "apply"
80            | "events"
81            | "describe"
82            | "desc"
83    ) {
84        return QueryMode::Sql;
85    }
86    if lower.starts_with("select ")
87        || lower.starts_with("from ")
88        || lower.starts_with("insert ")
89        || lower.starts_with("update ")
90        || lower.starts_with("delete ")
91        || lower.starts_with("truncate ")
92        || lower.starts_with("create ")
93        || lower.starts_with("drop ")
94        || lower.starts_with("alter ")
95        || lower.starts_with("vector ")
96        || lower.starts_with("hybrid ")
97        || lower.starts_with("graph ")
98        || lower.starts_with("queue ")
99        || lower.starts_with("events ")
100        || lower.starts_with("tree ")
101        || lower.starts_with("hll ")
102        || lower.starts_with("sketch ")
103        || lower.starts_with("filter ")
104        || lower.starts_with("vault ")
105        || lower.starts_with("unseal vault ")
106        || lower.starts_with("rotate vault ")
107        || lower.starts_with("history vault ")
108        || lower.starts_with("list vault ")
109        || lower.starts_with("watch vault ")
110        || lower.starts_with("delete vault ")
111        || lower.starts_with("purge vault ")
112        || lower.starts_with("search ")
113        || lower.starts_with("ask ")
114        || lower.starts_with("put config ")
115        || lower.starts_with("get config ")
116        || lower.starts_with("resolve config ")
117        || lower.starts_with("rotate config ")
118        || lower.starts_with("delete config ")
119        || lower.starts_with("history config ")
120        || lower.starts_with("list config ")
121        || lower.starts_with("watch config ")
122        || lower.starts_with("incr config ")
123        || lower.starts_with("decr config ")
124        || lower.starts_with("add config ")
125        || lower.starts_with("invalidate config ")
126        || lower.starts_with("invalidate tags ")
127        || lower.starts_with("set config ")
128        || lower.starts_with("set secret ")
129        || lower.starts_with("set tenant")
130        || lower.starts_with("show config")
131        || lower.starts_with("show collections")
132        || lower.starts_with("show tables")
133        || lower.starts_with("show queues")
134        || lower.starts_with("show vectors")
135        || lower.starts_with("show documents")
136        || lower.starts_with("show timeseries")
137        || lower.starts_with("show graphs")
138        || lower.starts_with("kv ")
139        || lower.starts_with("show kv")
140        || lower.starts_with("show configs")
141        || lower.starts_with("show vaults")
142        || lower.starts_with("show schema")
143        || lower.starts_with("show indices")
144        || lower.starts_with("show sample ")
145        || lower.starts_with("show secret")
146        || lower.starts_with("show stats")
147        || lower.starts_with("show tenant")
148        || lower.starts_with("show policies")
149        || lower.starts_with("show effective ")
150        || lower.starts_with("describe ")
151        || lower.starts_with("desc ")
152    {
153        // But check if it's SPARQL-style SELECT with ?variable.
154        // Bare SQL placeholders (`?`) and numbered placeholders (`?1`)
155        // must stay in SQL mode for parameter binding.
156        if lower.starts_with("select ") && has_sparql_variable(&lower) {
157            return QueryMode::Sparql;
158        }
159        return QueryMode::Sql;
160    }
161
162    // Natural language detection: common question words and patterns
163    if is_natural_language(&lower) {
164        return QueryMode::Natural;
165    }
166
167    QueryMode::Unknown
168}
169
170/// Check for SPARQL-specific patterns
171fn has_sparql_pattern(lower: &str) -> bool {
172    // SPARQL variables start with ? or $
173    // SPARQL has WHERE { } with triple patterns
174
175    // Check for ?variable pattern. Bare SQL placeholders (`?`) and
176    // numbered placeholders (`?1`) are not SPARQL variables.
177    let has_var = has_sparql_variable(lower);
178
179    // Check for typical SPARQL structure
180    let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
181
182    // Check for RDF predicates (prefixed URIs like :predicate or prefix:pred)
183    let has_prefix_pattern = lower.contains(":")
184        && (lower.contains(":<")
185            || lower.contains("> :")
186            || lower.contains(" :") && lower.contains("?"));
187
188    has_var || has_triple_pattern || has_prefix_pattern
189}
190
/// Returns `true` if `input` contains a SPARQL-style variable (`?name`)
/// outside of quoted string literals.
///
/// A `?` counts only when it is immediately followed by a byte accepted by
/// [`is_sparql_variable_start`]. Bare SQL placeholders (`?`) and numbered
/// placeholders (`?1`) are therefore not variables.
///
/// Text inside `'...'` or `"..."` is skipped. This stops a question mark in a
/// string value (e.g. `'alice?bob'`, or a Cypher regex `"ab?c"`) from turning
/// a SQL or Cypher query into SPARQL. A literal is closed only by the same
/// quote character that opened it. A doubled quote (`''`) simply closes and
/// reopens the literal, so SQL-style escaping is handled. Backslash escapes
/// are not interpreted.
fn has_sparql_variable(input: &str) -> bool {
    let bytes = input.as_bytes();
    // Quote byte of the string literal we are currently inside, if any.
    let mut open_quote: Option<u8> = None;

    for (i, &byte) in bytes.iter().enumerate() {
        if let Some(quote) = open_quote {
            if byte == quote {
                open_quote = None;
            }
            continue;
        }
        match byte {
            b'\'' | b'"' => open_quote = Some(byte),
            b'?' => {
                if matches!(bytes.get(i + 1), Some(&next) if is_sparql_variable_start(next)) {
                    return true;
                }
            }
            _ => {}
        }
    }
    false
}

/// Whether `byte` may begin a SPARQL variable name (the byte after `?`).
///
/// Only ASCII letters and `_` are accepted. Digits are excluded so numbered
/// SQL placeholders such as `?1` are not mistaken for variables.
fn is_sparql_variable_start(byte: u8) -> bool {
    byte.is_ascii_alphabetic() || byte == b'_'
}
201
/// Heuristically decide whether `lower` (already lower-cased) reads like a
/// natural-language request rather than a formal query.
///
/// Returns `true` when the text opens with a question or command word
/// (`find `, `what `, `show `, ...). It also returns `true` when the text
/// contains at least two cue substrings: common connectives such as ` with `
/// or ` the `, or domain words such as `password` or `host`. Cues are plain
/// substring matches, so `users` counts as containing `user`.
fn is_natural_language(lower: &str) -> bool {
    const STARTERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];

    const CUES: &[&str] = &[
        // Connectives and function words (space-delimited).
        " with ", " for ", " that ", " have ", " has ", " can ", " are ", " is ", " all ", " me ",
        " the ", " from ", " to ", " on ", " in ",
        // Domain vocabulary (bare substrings).
        "vulnerable", "credential", "password", "user", "host", "service", "connected",
        "reachable", "exposed", "critical",
    ];

    let opens_with_starter = STARTERS.iter().any(|starter| lower.starts_with(*starter));
    if opens_with_starter {
        return true;
    }

    let cue_hits = CUES.iter().filter(|cue| lower.contains(**cue)).count();
    cue_hits >= 2
}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every input in `cases` is detected as `expected`.
    ///
    /// The failure message names the offending input, so a single
    /// table-driven test still pinpoints exactly which case regressed.
    fn assert_detects(expected: QueryMode, cases: &[&str]) {
        for &input in cases {
            assert_eq!(
                detect_mode(input),
                expected,
                "detect_mode({:?}) should be {:?}",
                input,
                expected
            );
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_detects(
            QueryMode::Sql,
            &[
                // Core DML.
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                // RedDB multi-model commands.
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                // A `?` followed by a quote is not a SPARQL variable.
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                // Bare and numbered placeholders stay SQL for parameter binding.
                "SELECT name FROM t WHERE id = ?",
                "SELECT name FROM t WHERE id = ?1",
                "INSERT INTO t (id, name) VALUES (?, ?)",
                // Secrets and vaults.
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                // Catalog / introspection.
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "DESCRIBE users",
                "DESC users",
                "SHOW INDICES",
                "SHOW STATS users",
                // One-word transaction / admin commands.
                "BEGIN",
                "COMMIT",
                "ROLLBACK",
                "EXPLAIN SELECT 1",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_detects(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_detects(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_detects(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
                "SELECT ?x WHERE { ?x rdf:type :Foo }",
                "SELECT * WHERE { ?s ?p ?o }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_detects(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_detects(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                // Quoted input is always treated as natural language.
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty and whitespace-only input.
        assert_detects(QueryMode::Unknown, &["", "   "]);

        // A bare keyword with no space / trailing clause is not enough for SQL.
        assert_detects(QueryMode::Unknown, &["SELECT"]);

        // Keyword matching is case-insensitive and ignores surrounding whitespace.
        assert_detects(QueryMode::Gremlin, &["G.V()", "\t g.V() \n"]);
        assert_detects(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}