pub mod detect;
pub mod gremlin;
pub mod natural;
pub mod sparql;
pub use detect::{detect_mode, QueryMode};
pub use gremlin::{GremlinParser, GremlinStep, GremlinTraversal};
pub use natural::{NaturalParser, NaturalQuery, QueryIntent};
pub use sparql::{SparqlParser, SparqlQuery, TriplePattern};
use crate::ast::QueryExpr;
/// Parses `input` in whichever query language [`detect_mode`] recognises and returns the
/// resulting [`QueryExpr`].
///
/// SQL, Cypher and PATH queries all go through the shared [`crate::parser::parse`] entry
/// point. Gremlin, SPARQL and natural-language queries are handled by their dedicated
/// parsers and then lowered with their `to_query_expr` method.
///
/// # Errors
///
/// * [`MultiParseError::Parse`] when the shared parser rejects the input (its message is
///   the parser error's `Display` text).
/// * Whatever the dedicated parser's error converts into via `MultiParseError::from`
///   when a Gremlin, SPARQL or natural-language parse fails.
/// * [`MultiParseError::UnknownMode`], carrying the full input text, when no query mode
///   could be detected.
pub fn parse_multi(input: &str) -> Result<QueryExpr, MultiParseError> {
    match detect_mode(input) {
        QueryMode::Sql | QueryMode::Cypher | QueryMode::Path => {
            let parsed = crate::parser::parse(input)
                .map_err(|err| MultiParseError::Parse(err.to_string()))?;
            Ok(parsed.query)
        }
        QueryMode::Gremlin => GremlinParser::parse(input)
            .map(|traversal| traversal.to_query_expr())
            .map_err(MultiParseError::from),
        QueryMode::Sparql => SparqlParser::parse(input)
            .map(|sparql| sparql.to_query_expr())
            .map_err(MultiParseError::from),
        QueryMode::Natural => NaturalParser::parse(input)
            .map(|natural| natural.to_query_expr())
            .map_err(MultiParseError::from),
        QueryMode::Unknown => Err(MultiParseError::UnknownMode(String::from(input))),
    }
}
/// Error returned by [`parse_multi`], tagged with the query front-end that failed.
///
/// Every variant stores a rendered message instead of the underlying error value. This
/// keeps the type cheap to clone and comparable with `==` regardless of the individual
/// parsers' error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultiParseError {
    /// The shared SQL / Cypher / PATH parser rejected the input.
    Parse(String),
    /// The Gremlin traversal parser rejected the input.
    Gremlin(String),
    /// The SPARQL parser rejected the input.
    Sparql(String),
    /// The natural-language parser rejected the input.
    Natural(String),
    /// Mode detection could not classify the input; holds the original query text.
    UnknownMode(String),
}
impl std::fmt::Display for MultiParseError {
    /// Renders the error as `"<label>: <detail>"`, where the label names the failing
    /// front-end and the detail is the stored message (or the query text for
    /// `UnknownMode`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::Parse(message) => ("Parse error", message),
            Self::Gremlin(message) => ("Gremlin error", message),
            Self::Sparql(message) => ("SPARQL error", message),
            Self::Natural(message) => ("Natural language error", message),
            Self::UnknownMode(query) => ("Unknown query mode for", query),
        };
        write!(f, "{}: {}", label, detail)
    }
}

/// Lets `MultiParseError` be used wherever a `std::error::Error` is expected (for example
/// boxed as `dyn Error`). `source()` keeps its default of `None` because the variants only
/// hold strings.
impl std::error::Error for MultiParseError {}
impl From<gremlin::GremlinError> for MultiParseError {
fn from(e: gremlin::GremlinError) -> Self {
Self::Gremlin(e.to_string())
}
}
impl From<sparql::SparqlError> for MultiParseError {
fn from(e: sparql::SparqlError) -> Self {
Self::Sparql(e.to_string())
}
}
impl From<natural::NaturalError> for MultiParseError {
fn from(e: natural::NaturalError) -> Self {
Self::Natural(e.to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that [`detect_mode`] classifies every entry of `inputs` as `expected`.
    fn assert_detects(expected: QueryMode, inputs: &[&str]) {
        for &input in inputs {
            assert_eq!(detect_mode(input), expected, "mode detection for {:?}", input);
        }
    }

    #[test]
    fn test_detect_sql() {
        assert_detects(
            QueryMode::Sql,
            &["SELECT * FROM users", "select name from hosts"],
        );
    }

    #[test]
    fn test_detect_gremlin() {
        assert_detects(
            QueryMode::Gremlin,
            &["g.V()", "g.V().has('name', 'alice')", "__.out('knows')"],
        );
    }

    #[test]
    fn test_detect_cypher() {
        assert_detects(
            QueryMode::Cypher,
            &["MATCH (a)-[r]->(b) RETURN a", "match (n:Host) return n"],
        );
    }

    #[test]
    fn test_detect_sparql() {
        assert_detects(
            QueryMode::Sparql,
            &[
                "SELECT ?name WHERE { ?s :name ?name }",
                "PREFIX ex: <http://example.org/> SELECT ?x",
            ],
        );
    }

    #[test]
    fn test_detect_path() {
        assert_detects(
            QueryMode::Path,
            &[
                "PATH FROM host('10.0.0.1') TO host('10.0.0.2')",
                "PATHS ALL FROM user('admin') TO credential('root')",
            ],
        );
    }

    #[test]
    fn test_detect_natural() {
        assert_detects(
            QueryMode::Natural,
            &[
                "find all hosts with ssh",
                "show me credentials for user admin",
                "\"what vulnerabilities affect host 10.0.0.1?\"",
            ],
        );
    }

    #[test]
    fn parse_multi_routes_supported_modes_to_query_exprs() {
        let is_table: fn(&QueryExpr) -> bool = |expr| matches!(expr, QueryExpr::Table(_));
        let is_graph: fn(&QueryExpr) -> bool = |expr| matches!(expr, QueryExpr::Graph(_));
        let cases = [
            ("sql", "SELECT * FROM hosts", is_table),
            ("gremlin", "g.V().hasLabel('Host').limit(2)", is_graph),
            ("sparql", "SELECT ?s WHERE { ?s :name 'alice' }", is_graph),
            ("natural", "find all hosts with ssh", is_graph),
        ];
        for (label, input, has_expected_shape) in cases {
            let expr = parse_multi(input).expect(label);
            assert!(has_expected_shape(&expr), "{} query produced an unexpected expr", label);
        }
    }

    #[test]
    fn parse_multi_surfaces_parse_and_unknown_errors() {
        let bad_sql = parse_multi("SELECT * FROM").expect_err("bad SQL should fail");
        assert!(matches!(bad_sql, MultiParseError::Parse(_)));
        assert!(bad_sql.to_string().starts_with("Parse error:"));

        let empty = parse_multi("").expect_err("empty should be unknown");
        match &empty {
            MultiParseError::UnknownMode(query) => assert!(query.is_empty()),
            other => panic!("expected UnknownMode, got {:?}", other),
        }
        assert_eq!(empty.to_string(), "Unknown query mode for: ");
    }

    #[test]
    fn multi_parse_error_display_covers_all_variants() {
        let cases = [
            (MultiParseError::Parse("bad".into()), "Parse error: bad"),
            (MultiParseError::Gremlin("bad".into()), "Gremlin error: bad"),
            (MultiParseError::Sparql("bad".into()), "SPARQL error: bad"),
            (MultiParseError::Natural("bad".into()), "Natural language error: bad"),
            (MultiParseError::UnknownMode("???".into()), "Unknown query mode for: ???"),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
        }
    }
}