// reddb_server/storage/query/modes/detect.rs

//! Query Mode Detection
//!
//! Automatically detects the query language based on syntax patterns.

/// Supported query modes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryMode {
    /// SQL-style: SELECT ... FROM ... WHERE
    Sql,
    /// Gremlin traversal: g.V().out().has(...)
    Gremlin,
    /// Cypher pattern matching: MATCH (a)-[r]->(b) RETURN
    Cypher,
    /// SPARQL RDF queries: SELECT ?var WHERE { ... }
    Sparql,
    /// Path queries: PATH FROM ... TO ... VIA
    Path,
    /// Natural language queries
    Natural,
    /// Unknown mode
    Unknown,
}

/// Detect the query mode from an input string.
///
/// Detection is purely syntactic and case-insensitive (the input is trimmed
/// and lowercased first). The checks run in this order and the first match
/// wins:
///
/// 1. a leading `"` or `'` → [`QueryMode::Natural`] (explicitly quoted prose)
/// 2. `g.` / `__.` → [`QueryMode::Gremlin`]
/// 3. `PATH ` / `PATHS ` → [`QueryMode::Path`]
/// 4. `MATCH ` / `MATCH(` → [`QueryMode::Cypher`]
/// 5. `PREFIX ` or SPARQL syntax (`?var`, `WHERE {`, IRIs next to `:`) →
///    [`QueryMode::Sparql`]
/// 6. a known SQL/admin leading keyword or statement prefix → [`QueryMode::Sql`]
/// 7. natural-language heuristics → [`QueryMode::Natural`]
///
/// Anything else (including empty or whitespace-only input) is
/// [`QueryMode::Unknown`].
///
/// SQL positional placeholders (`?`, `?1`) are *not* treated as SPARQL
/// variables, so parameterized SQL such as `SELECT * FROM t WHERE id = ?`
/// is detected as [`QueryMode::Sql`].
pub fn detect_mode(input: &str) -> QueryMode {
    // One-word commands (transaction control, maintenance, admin) that go to
    // the SQL parser whenever they are the first whitespace-separated token,
    // whatever follows (or even if nothing follows).
    const SQL_LEADING_KEYWORDS: &[&str] = &[
        "begin", "start", "commit", "rollback", "savepoint", "release", "end", "vacuum",
        "analyze", "reset", "copy", "refresh", "explain", "grant", "revoke", "attach",
        "detach", "simulate", "apply", "events",
    ];

    // Statement prefixes routed to the SQL parser. Entries without a trailing
    // space intentionally also match longer forms, e.g. `show secret` covers
    // `show secrets` and `show kv` covers `show kvs`.
    const SQL_PREFIXES: &[&str] = &[
        "select ",
        "from ",
        "insert ",
        "update ",
        "delete ",
        "truncate ",
        "create ",
        "drop ",
        "alter ",
        "vector ",
        "hybrid ",
        "graph ",
        "queue ",
        "events ",
        "tree ",
        "vault ",
        "unseal vault ",
        "rotate vault ",
        "history vault ",
        "list vault ",
        "watch vault ",
        "delete vault ",
        "purge vault ",
        "search ",
        "ask ",
        "put config ",
        "get config ",
        "resolve config ",
        "rotate config ",
        "delete config ",
        "history config ",
        "list config ",
        "watch config ",
        "incr config ",
        "decr config ",
        "add config ",
        "invalidate config ",
        "invalidate tags ",
        "set config ",
        "set secret ",
        "set tenant",
        "show config",
        "show collections",
        "show tables",
        "show queues",
        "show vectors",
        "show documents",
        "show timeseries",
        "show graphs",
        "kv ",
        "show kv",
        "show configs",
        "show vaults",
        "show schema",
        "show indices",
        "show sample ",
        "show secret",
        "show stats",
        "show tenant",
        "show policies",
        "show effective ",
    ];

    let trimmed = input.trim();
    let lower = trimmed.to_lowercase();

    // A leading quote marks an explicitly quoted natural-language request.
    if trimmed.starts_with('"') || trimmed.starts_with('\'') {
        return QueryMode::Natural;
    }

    // Gremlin: starts with g. or __.
    if lower.starts_with("g.") || lower.starts_with("__.") {
        return QueryMode::Gremlin;
    }

    // Path: PATH or PATHS keyword at start.
    if lower.starts_with("path ") || lower.starts_with("paths ") {
        return QueryMode::Path;
    }

    // Cypher: a leading MATCH is unambiguous (SPARQL never starts with it), so
    // check it before the SPARQL heuristics, which could otherwise fire on
    // e.g. `MATCH (n :Label) WHERE n.note = 'why?'`.
    if lower.starts_with("match ") || lower.starts_with("match(") {
        return QueryMode::Cypher;
    }

    // SPARQL: PREFIX declaration or SPARQL-specific syntax. This runs before
    // the SQL checks because SPARQL also uses SELECT / INSERT / DELETE.
    if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
        return QueryMode::Sparql;
    }

    // SQL: known one-word commands, or a known statement prefix.
    let first_token = lower.split_whitespace().next().unwrap_or("");
    if SQL_LEADING_KEYWORDS.contains(&first_token)
        || SQL_PREFIXES.iter().any(|prefix| lower.starts_with(prefix))
    {
        return QueryMode::Sql;
    }

    // Natural language detection: common question words and patterns.
    if is_natural_language(&lower) {
        return QueryMode::Natural;
    }

    QueryMode::Unknown
}

/// Check an already-lowercased query for SPARQL-specific syntax.
fn has_sparql_pattern(lower: &str) -> bool {
    // `?name` variables; SQL `?` placeholders and prose question marks do not count.
    let has_var = has_sparql_variable(lower);

    // Typical SPARQL group graph pattern: `WHERE { ... }`.
    let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");

    // IRIs directly adjacent to a prefixed name, e.g. `ex:<...>` or `<iri> :pred`.
    let has_prefix_pattern = lower.contains(":<") || lower.contains("> :");

    has_var || has_triple_pattern || has_prefix_pattern
}

/// Returns `true` if `lower` contains a SPARQL-style variable.
///
/// A variable is a `?` that starts a token (at the start of the input, or
/// after whitespace, `(` or `{`) and is immediately followed by an ASCII
/// letter or `_`. This deliberately rejects SQL positional placeholders
/// (`= ?`, `(?, ?)`, `?1`) and trailing question marks in prose
/// (`... database?`). SPARQL technically allows digit-initial variable
/// names; those are treated as SQL `?NNN` placeholders here.
fn has_sparql_variable(lower: &str) -> bool {
    // Byte scan is UTF-8 safe: `?`, `(`, `{` and ASCII whitespace never occur
    // inside a multi-byte sequence.
    let bytes = lower.as_bytes();
    bytes.iter().enumerate().any(|(i, &byte)| {
        if byte != b'?' {
            return false;
        }
        let at_token_start = i == 0
            || matches!(bytes[i - 1], b'(' | b'{')
            || bytes[i - 1].is_ascii_whitespace();
        let names_variable = matches!(
            bytes.get(i + 1),
            Some(&next) if next.is_ascii_alphabetic() || next == b'_'
        );
        at_token_start && names_variable
    })
}

/// Heuristically decide whether an already-lowercased query is natural
/// language.
///
/// Returns `true` if the query starts with a question/command word, or if it
/// contains at least two of the common natural-language fragments below.
fn is_natural_language(lower: &str) -> bool {
    // Question words / imperative verbs that typically open a request.
    const QUESTION_STARTERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];

    // Common natural-language connectives and domain words.
    const NL_PATTERNS: &[&str] = &[
        " with ",
        " for ",
        " that ",
        " have ",
        " has ",
        " can ",
        " are ",
        " is ",
        " all ",
        " me ",
        " the ",
        " from ",
        " to ",
        " on ",
        " in ",
        "vulnerable",
        "credential",
        "password",
        "user",
        "host",
        "service",
        "connected",
        "reachable",
        "exposed",
        "critical",
    ];

    if QUESTION_STARTERS
        .iter()
        .any(|starter| lower.starts_with(*starter))
    {
        return true;
    }

    // Require at least two fragments to avoid flagging terse DSL queries.
    NL_PATTERNS
        .iter()
        .filter(|pattern| lower.contains(**pattern))
        .count()
        >= 2
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every query in `queries` is detected as `expected`,
    /// naming the offending query on failure.
    fn assert_all(expected: QueryMode, queries: &[&str]) {
        for query in queries {
            assert_eq!(detect_mode(query), expected, "query: {:?}", query);
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_all(
            QueryMode::Sql,
            &[
                "SELECT * FROM users WHERE id = 1",
                "select name, age from hosts",
                "FROM hosts h WHERE h.os = 'Linux'",
                "INSERT INTO users VALUES (1, 'alice')",
                "UPDATE hosts SET status = 'active'",
                "DELETE FROM logs WHERE age > 30",
                "QUEUE GROUP CREATE tasks workers",
                "EVENTS BACKFILL users TO audit",
                "TREE VALIDATE forest.org",
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                "SET SECRET red.secret.api = 'x'",
                "SHOW SECRET red.secret",
                "SHOW SECRETS",
                "VAULT PUT secrets.api = 'x'",
                "SHOW SAMPLE users",
                "SHOW TABLES",
                "SHOW QUEUES",
                "SHOW VECTORS",
                "SHOW DOCUMENTS",
                "SHOW TIMESERIES",
                "SHOW GRAPHS",
                "SHOW KV",
                "SHOW KVS",
                "SHOW CONFIGS",
                "SHOW VAULTS",
                "SHOW SCHEMA users",
                "SHOW INDICES",
                "SHOW STATS users",
            ],
        );
    }

    #[test]
    fn test_gremlin_detection() {
        assert_all(
            QueryMode::Gremlin,
            &[
                "g.V()",
                "g.V().hasLabel('host')",
                "g.V().out('connects').in('has_service')",
                "g.E().hasLabel('auth_access')",
                "__.out('knows').has('name', 'bob')",
                "g.V('host:10.0.0.1').repeat(out()).times(3)",
            ],
        );
    }

    #[test]
    fn test_cypher_detection() {
        assert_all(
            QueryMode::Cypher,
            &[
                "MATCH (a)-[r]->(b) RETURN a, b",
                "MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)",
                "match (n) where n.ip = '10.0.0.1' return n",
                "MATCH(a:User) RETURN a.name",
            ],
        );
    }

    #[test]
    fn test_sparql_detection() {
        assert_all(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                "SELECT ?host ?ip WHERE { ?host :hasIP ?ip }",
            ],
        );
    }

    #[test]
    fn test_path_detection() {
        assert_all(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM credential('admin') TO host('db')",
                "path from user('root') to service('ssh') via auth_access",
            ],
        );
    }

    #[test]
    fn test_natural_detection() {
        assert_all(
            QueryMode::Natural,
            &[
                "find all hosts with ssh open",
                "show me vulnerable services",
                "what credentials can reach the database?",
                "list users with weak passwords",
                "\"find hosts connected to 10.0.0.1\"",
                "which hosts have critical vulnerabilities?",
            ],
        );
    }

    #[test]
    fn test_edge_cases() {
        // Empty input and bare whitespace carry no signal at all.
        assert_all(QueryMode::Unknown, &["", "   "]);

        // A bare keyword without a trailing space is not a statement.
        assert_all(QueryMode::Unknown, &["SELECT"]);

        // Detection is case-insensitive.
        assert_all(QueryMode::Gremlin, &["G.V()"]);
        assert_all(QueryMode::Cypher, &["Match (a) RETURN a"]);
    }
}