/// The query language that [`detect_mode`] assigns to a raw input string.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryMode {
    /// SQL, plus the engine's SQL-style commands (`VECTOR`, `QUEUE`, `VAULT`, `SHOW ...`, ...).
    Sql,
    /// Gremlin traversal (input beginning with `g.` or `__.`).
    Gremlin,
    /// Cypher pattern query (input beginning with `MATCH`).
    Cypher,
    /// SPARQL query (`PREFIX ...`, `?variables`, `WHERE { ... }`).
    Sparql,
    /// Path query (input beginning with `PATH` or `PATHS`).
    Path,
    /// Free-form natural-language request, including quoted text.
    Natural,
    /// Input that matched none of the detection rules.
    Unknown,
}
/// First tokens that always denote a SQL (or SQL-style engine) statement.
const SQL_LEADING_KEYWORDS: &[&str] = &[
    "begin", "start", "commit", "rollback", "savepoint", "release", "end", "vacuum", "analyze",
    "reset", "copy", "refresh", "explain", "grant", "revoke", "attach", "detach", "simulate",
    "apply", "events",
];

/// Prefixes (matched against the trimmed, lowercased input) of SQL statements and the
/// engine's SQL-style commands. Some entries have no trailing space on purpose so that
/// plural forms also match (e.g. `show secret` covers `show secrets`).
const SQL_PREFIXES: &[&str] = &[
    "select ",
    "from ",
    "insert ",
    "update ",
    "delete ",
    "truncate ",
    "create ",
    "drop ",
    "alter ",
    "vector ",
    "hybrid ",
    "graph ",
    "queue ",
    "events ",
    "tree ",
    "vault ",
    "unseal vault ",
    "rotate vault ",
    "history vault ",
    "list vault ",
    "watch vault ",
    "delete vault ",
    "purge vault ",
    "search ",
    "ask ",
    "put config ",
    "get config ",
    "resolve config ",
    "rotate config ",
    "delete config ",
    "history config ",
    "list config ",
    "watch config ",
    "incr config ",
    "decr config ",
    "add config ",
    "invalidate config ",
    "invalidate tags ",
    "set config ",
    "set secret ",
    "set tenant",
    "show config",
    "show collections",
    "show tables",
    "show queues",
    "show vectors",
    "show documents",
    "show timeseries",
    "show graphs",
    "kv ",
    "show kv",
    "show configs",
    "show vaults",
    "show schema",
    "show indices",
    "show sample ",
    "show secret",
    "show stats",
    "show tenant",
    "show policies",
    "show effective ",
];

/// Classifies a raw query string into the [`QueryMode`] it is written in.
///
/// The input is trimmed and matched case-insensitively. Rules are tried in order and the
/// first match wins:
///
/// 1. text starting with `"` or `'` → [`QueryMode::Natural`];
/// 2. `g.` / `__.` → [`QueryMode::Gremlin`];
/// 3. `path ` / `paths ` → [`QueryMode::Path`];
/// 4. `prefix `, or a SPARQL-looking body (`has_sparql_pattern`) → [`QueryMode::Sparql`];
/// 5. `match ` / `match(` → [`QueryMode::Cypher`];
/// 6. a SQL leading keyword or SQL-style command prefix → [`QueryMode::Sql`], except a
///    `select` whose first projected item is a `?variable`, which is [`QueryMode::Sparql`];
/// 7. text that reads like an English request (`is_natural_language`) → [`QueryMode::Natural`];
/// 8. anything else → [`QueryMode::Unknown`].
pub fn detect_mode(input: &str) -> QueryMode {
    let trimmed = input.trim();
    let lower = trimmed.to_lowercase();

    // Quoting the input is an explicit request to treat it as natural language.
    if trimmed.starts_with('"') || trimmed.starts_with('\'') {
        return QueryMode::Natural;
    }
    if lower.starts_with("g.") || lower.starts_with("__.") {
        return QueryMode::Gremlin;
    }
    if lower.starts_with("path ") || lower.starts_with("paths ") {
        return QueryMode::Path;
    }
    // Checked before the SQL rules so `SELECT ?x WHERE { ... }` is not taken for SQL.
    if lower.starts_with("prefix ") || has_sparql_pattern(&lower) {
        return QueryMode::Sparql;
    }
    if lower.starts_with("match ") || lower.starts_with("match(") {
        return QueryMode::Cypher;
    }

    let first_token = lower.split_whitespace().next().unwrap_or("");
    if SQL_LEADING_KEYWORDS.contains(&first_token) {
        return QueryMode::Sql;
    }
    if SQL_PREFIXES.iter().any(|prefix| lower.starts_with(*prefix)) {
        // Only a SELECT that projects `?variables` is SPARQL; a SELECT that merely uses
        // `?` positional placeholders (`WHERE id = ?`) must stay SQL.
        return if projects_sparql_variables(&lower) {
            QueryMode::Sparql
        } else {
            QueryMode::Sql
        };
    }

    if is_natural_language(&lower) {
        QueryMode::Natural
    } else {
        QueryMode::Unknown
    }
}

/// Returns `true` when `lower` is a `select` whose first projected item (after an optional
/// `distinct` / `reduced`) is a SPARQL variable such as `?name`.
///
/// A bare `?` (SQL positional placeholder), or a projection that does not start with `?`
/// (`*`, a column name, an expression), yields `false`.
fn projects_sparql_variables(lower: &str) -> bool {
    let mut tokens = lower.split_whitespace();
    if tokens.next() != Some("select") {
        return false;
    }
    let first_item = tokens.find(|token| !matches!(*token, "distinct" | "reduced"));
    matches!(
        first_item
            .and_then(|item| item.strip_prefix('?'))
            .and_then(|name| name.chars().next()),
        Some(c) if c.is_alphanumeric() || c == '_'
    )
}
/// Heuristically decides whether `lower` (an already-lowercased query) looks like SPARQL.
///
/// Any one of these signals is enough:
/// * a variable: ` ?` immediately followed by a letter or `_` (e.g. ` ?name`). A bare `?`
///   (SQL positional placeholder such as `= ?`, `(1, ?)`, `limit ?`) or `?` followed by a
///   digit (numbered placeholder like `?1`) does not count;
/// * a graph pattern opened right after `where` (` where {` or ` where{`);
/// * prefixed-name syntax: `:<`, `> :`, or ` :` in a query that also contains `?`.
fn has_sparql_pattern(lower: &str) -> bool {
    // Inspect what follows each " ?" instead of excluding the whole query when it contains
    // "= ?" / "> ?": a single SQL placeholder must not hide real variables, and
    // placeholders in lists or after `<` / `limit` must not look like variables.
    let has_var = lower.match_indices(" ?").any(|(idx, _)| {
        // " ?" is two ASCII bytes, so idx + 2 is always a char boundary.
        matches!(
            lower[idx + 2..].chars().next(),
            Some(c) if c.is_alphabetic() || c == '_'
        )
    });
    let has_triple_pattern = lower.contains(" where {") || lower.contains(" where{");
    let has_prefix_pattern = lower.contains(":<")
        || lower.contains("> :")
        || (lower.contains(" :") && lower.contains('?'));
    has_var || has_triple_pattern || has_prefix_pattern
}
/// Heuristic check for free-form English requests.
///
/// `lower` is expected to be lowercased already. Returns `true` when the text begins with a
/// request/question word (`find `, `show `, `what `, ...), or when at least two distinct
/// hint terms occur anywhere in it as plain substrings. The hints are connective words and
/// security-domain nouns; each term is counted at most once.
fn is_natural_language(lower: &str) -> bool {
    const QUESTION_STARTERS: &[&str] = &[
        "find ", "show ", "list ", "what ", "which ", "where ", "how ", "who ", "get ", "give ",
        "tell ", "display ", "search ", "look ",
    ];
    const NL_HINTS: &[&str] = &[
        // Connective words, padded with spaces so they match whole words only.
        " with ", " for ", " that ", " have ", " has ", " can ", " are ", " is ", " all ",
        " me ", " the ", " from ", " to ", " on ", " in ",
        // Domain vocabulary, matched as bare substrings (so "user" also hits "users").
        "vulnerable", "credential", "password", "user", "host", "service", "connected",
        "reachable", "exposed", "critical",
    ];

    let opens_like_question = QUESTION_STARTERS
        .iter()
        .any(|starter| lower.starts_with(*starter));

    opens_like_question || NL_HINTS.iter().filter(|hint| lower.contains(**hint)).count() >= 2
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every `(input, expected)` pair is classified as `expected`, naming the
    /// offending input in the failure message.
    fn assert_modes(cases: &[(&str, QueryMode)]) {
        for &(input, expected) in cases {
            assert_eq!(detect_mode(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_sql_detection() {
        assert_modes(&[
            ("SELECT * FROM users WHERE id = 1", QueryMode::Sql),
            ("select name, age from hosts", QueryMode::Sql),
            ("FROM hosts h WHERE h.os = 'Linux'", QueryMode::Sql),
            ("INSERT INTO users VALUES (1, 'alice')", QueryMode::Sql),
            ("UPDATE hosts SET status = 'active'", QueryMode::Sql),
            ("DELETE FROM logs WHERE age > 30", QueryMode::Sql),
            ("QUEUE GROUP CREATE tasks workers", QueryMode::Sql),
            ("EVENTS BACKFILL users TO audit", QueryMode::Sql),
            ("TREE VALIDATE forest.org", QueryMode::Sql),
            (
                "VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                QueryMode::Sql,
            ),
            (
                "HYBRID FROM hosts VECTOR SEARCH embeddings SIMILAR TO [1.0, 0.0] LIMIT 5",
                QueryMode::Sql,
            ),
            (
                "ASK 'what happened on host 10.0.0.1?' USING groq",
                QueryMode::Sql,
            ),
            ("SET SECRET red.secret.api = 'x'", QueryMode::Sql),
            ("SHOW SECRET red.secret", QueryMode::Sql),
            ("SHOW SECRETS", QueryMode::Sql),
            ("VAULT PUT secrets.api = 'x'", QueryMode::Sql),
            ("SHOW SAMPLE users", QueryMode::Sql),
            ("SHOW TABLES", QueryMode::Sql),
            ("SHOW QUEUES", QueryMode::Sql),
            ("SHOW VECTORS", QueryMode::Sql),
            ("SHOW DOCUMENTS", QueryMode::Sql),
            ("SHOW TIMESERIES", QueryMode::Sql),
            ("SHOW GRAPHS", QueryMode::Sql),
            ("SHOW KV", QueryMode::Sql),
            ("SHOW KVS", QueryMode::Sql),
            ("SHOW CONFIGS", QueryMode::Sql),
            ("SHOW VAULTS", QueryMode::Sql),
            ("SHOW SCHEMA users", QueryMode::Sql),
            ("SHOW INDICES", QueryMode::Sql),
            ("SHOW STATS users", QueryMode::Sql),
        ]);
    }

    #[test]
    fn test_gremlin_detection() {
        assert_modes(&[
            ("g.V()", QueryMode::Gremlin),
            ("g.V().hasLabel('host')", QueryMode::Gremlin),
            ("g.V().out('connects').in('has_service')", QueryMode::Gremlin),
            ("g.E().hasLabel('auth_access')", QueryMode::Gremlin),
            ("__.out('knows').has('name', 'bob')", QueryMode::Gremlin),
            ("g.V('host:10.0.0.1').repeat(out()).times(3)", QueryMode::Gremlin),
        ]);
    }

    #[test]
    fn test_cypher_detection() {
        assert_modes(&[
            ("MATCH (a)-[r]->(b) RETURN a, b", QueryMode::Cypher),
            ("MATCH (h:Host)-[:HAS_SERVICE]->(s:Service)", QueryMode::Cypher),
            ("match (n) where n.ip = '10.0.0.1' return n", QueryMode::Cypher),
            ("MATCH(a:User) RETURN a.name", QueryMode::Cypher),
        ]);
    }

    #[test]
    fn test_sparql_detection() {
        assert_modes(&[
            ("SELECT ?name WHERE { ?s :name ?name }", QueryMode::Sparql),
            (
                "PREFIX ex: <http://example.org/> SELECT ?x WHERE { ?x ex:type ?t }",
                QueryMode::Sparql,
            ),
            ("SELECT ?host ?ip WHERE { ?host :hasIP ?ip }", QueryMode::Sparql),
        ]);
    }

    #[test]
    fn test_path_detection() {
        assert_modes(&[
            ("PATH FROM host('10.0.0.1') TO host('10.0.0.2')", QueryMode::Path),
            ("PATHS ALL FROM credential('admin') TO host('db')", QueryMode::Path),
            (
                "path from user('root') to service('ssh') via auth_access",
                QueryMode::Path,
            ),
        ]);
    }

    #[test]
    fn test_natural_detection() {
        assert_modes(&[
            ("find all hosts with ssh open", QueryMode::Natural),
            ("show me vulnerable services", QueryMode::Natural),
            ("what credentials can reach the database?", QueryMode::Natural),
            ("list users with weak passwords", QueryMode::Natural),
            ("\"find hosts connected to 10.0.0.1\"", QueryMode::Natural),
            ("which hosts have critical vulnerabilities?", QueryMode::Natural),
        ]);
    }

    #[test]
    fn test_leading_keywords_and_whitespace() {
        assert_modes(&[
            ("COMMIT", QueryMode::Sql),
            ("EXPLAIN SELECT * FROM hosts", QueryMode::Sql),
            ("   g.V()   ", QueryMode::Gremlin),
            ("'find hosts'", QueryMode::Natural),
            ("PREFIX ex: <http://example.org/>", QueryMode::Sparql),
        ]);
    }

    #[test]
    fn test_edge_cases() {
        assert_modes(&[
            ("", QueryMode::Unknown),
            (" ", QueryMode::Unknown),
            ("SELECT", QueryMode::Unknown),
            ("G.V()", QueryMode::Gremlin),
            ("Match (a) RETURN a", QueryMode::Cypher),
        ]);
    }
}