use pest::Parser;
use pest_derive::Parser;
use std::error::Error;
use std::fmt;
use crate::query::Query;
/// Parser for the query DSL, generated by `pest_derive` from the grammar file
/// `query/grammar/query.pest`.
///
/// The functions in this file drive it through `QueryDSLParser::parse` and match on
/// the `Rule` values it yields.
#[derive(Parser)]
#[grammar = "query/grammar/query.pest"]
pub struct QueryDSLParser;
/// Errors reported while turning a query string into a `Query` AST.
#[derive(Debug, Clone)]
pub enum QueryParseError {
    /// The input contained something the parser could not accept. The payload is
    /// either the grammar error text, the offending token, or a short description.
    UnexpectedToken(String),
    /// A rule that requires a child element had none.
    UnexpectedEndOfInput,
}

impl Error for QueryParseError {}

impl fmt::Display for QueryParseError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            QueryParseError::UnexpectedEndOfInput => write!(f, "Unexpected end of input"),
            QueryParseError::UnexpectedToken(ref token) => {
                write!(f, "Unexpected token: {token}")
            }
        }
    }
}
/// Parses a query string into its `Query` AST.
///
/// An input whose top-level rule contains only the end-of-input marker yields an
/// empty `Query::Sequence`.
///
/// # Errors
///
/// Returns `QueryParseError::UnexpectedToken` carrying the grammar error text when
/// the input does not match the `query` rule, `QueryParseError::UnexpectedEndOfInput`
/// when the matched rule has no children, and otherwise any error produced while
/// converting the disjunction.
///
/// # Panics
///
/// Panics if a successful parse yields no top-level pair.
pub fn parse_query(input: &str) -> Result<Query, QueryParseError> {
    let top_level = QueryDSLParser::parse(Rule::query, input)
        .map_err(|err| QueryParseError::UnexpectedToken(err.to_string()))?
        .next()
        .expect("Empty query string");

    let Some(first_child) = top_level.into_inner().next() else {
        return Err(QueryParseError::UnexpectedEndOfInput);
    };

    // A bare end-of-input marker means the query text was empty.
    let ast = if first_child.as_rule() == Rule::EOI {
        Query::Sequence(Vec::new())
    } else {
        parse_disjunction(first_child)?
    };

    #[cfg(test)]
    println!("Constructed query AST:\n{ast:?}");

    Ok(ast)
}
/// Converts a `disjunction` pair into a `Query`.
///
/// Every child is parsed as a sequence. A disjunction with exactly one alternative
/// collapses to that alternative; otherwise the alternatives are wrapped in
/// `Query::Disjunction`.
///
/// # Errors
///
/// Returns `QueryParseError::UnexpectedToken` if `pair` is not a `disjunction` rule,
/// or the first error produced while parsing a child sequence.
fn parse_disjunction(pair: pest::iterators::Pair<Rule>) -> Result<Query, QueryParseError> {
    if pair.as_rule() != Rule::disjunction {
        return Err(QueryParseError::UnexpectedToken(format!(
            "Expected disjunction rule, got {:?}",
            pair.as_rule()
        )));
    }

    let mut sequences: Vec<Query> = pair
        .into_inner()
        .map(parse_sequence)
        .collect::<Result<Vec<Query>, _>>()?;

    if sequences.len() == 1 {
        // Move the sole alternative out of the vector rather than deep-cloning the
        // whole sub-tree just to drop the original afterwards.
        Ok(sequences.swap_remove(0))
    } else {
        Ok(Query::Disjunction(sequences))
    }
}
/// Converts a `sequence` pair into a `Query::Sequence` holding one entry per step.
///
/// # Errors
///
/// Returns `QueryParseError::UnexpectedToken` if `pair` is not a `sequence` rule, or
/// the first error produced while parsing a step.
fn parse_sequence(pair: pest::iterators::Pair<Rule>) -> Result<Query, QueryParseError> {
    match pair.as_rule() {
        Rule::sequence => pair
            .into_inner()
            .map(parse_step)
            // Collecting into `Result` stops at the first failing step.
            .collect::<Result<Vec<Query>, QueryParseError>>()
            .map(Query::Sequence),
        other => Err(QueryParseError::UnexpectedToken(format!(
            "Expected sequence rule, got {:?}",
            other
        ))),
    }
}
fn parse_step(
pair: pest::iterators::Pair<Rule>,
) -> Result<Query, QueryParseError> {
if pair.as_rule() != Rule::step {
return Err(QueryParseError::UnexpectedToken(format!(
"Expected step rule, got {:?}",
pair.as_rule()
)));
}
let mut inner = pair.into_inner();
let mut queries: Vec<Query> = vec![];
let first_pair =
inner.next().ok_or(QueryParseError::UnexpectedEndOfInput)?;
match first_pair.as_rule() {
Rule::field => {
let field = parse_field(&first_pair)?;
queries.push(field);
}
Rule::index => {
queries.push(parse_index(first_pair)?);
}
Rule::range => {
queries.push(parse_range(first_pair)?);
}
Rule::array_wildcard => {
queries.push(Query::ArrayWildcard);
}
Rule::field_wildcard => {
queries.push(Query::FieldWildcard);
}
Rule::regex => {
let regex = parse_regex(&first_pair)?;
queries.push(regex);
}
Rule::group => {
let group_query = parse_group(first_pair)?;
queries.push(group_query);
}
_ => {
return Err(QueryParseError::UnexpectedToken(format!(
"Unexpected start of step: {:?}",
first_pair.as_rule()
)));
}
}
while let Some(pair) = inner.peek() {
if matches!(
pair.as_rule(),
Rule::index | Rule::range | Rule::array_wildcard
) {
let pair = inner.next().unwrap();
match pair.as_rule() {
Rule::index => {
queries.push(parse_index(pair)?);
}
Rule::range => {
queries.push(parse_range(pair)?);
}
Rule::array_wildcard => {
queries.push(Query::ArrayWildcard);
}
_ => unreachable!(),
}
} else {
break;
}
}
if let Some(modifier_pair) = inner.next() {
if modifier_pair.as_rule() == Rule::modifier {
let last_query = queries.pop().ok_or_else(|| {
QueryParseError::UnexpectedToken(
"No query to apply modifier to".to_string(),
)
})?;
let modified_query = match modifier_pair.as_str() {
"*" => Query::KleeneStar(Box::new(last_query)),
"?" => Query::Optional(Box::new(last_query)),
_ => {
return Err(QueryParseError::UnexpectedToken(format!(
"Unknown modifier: {}",
modifier_pair.as_str()
)));
}
};
queries.push(modified_query);
} else {
return Err(QueryParseError::UnexpectedToken(format!(
"Expected modifier, got {:?}",
modifier_pair.as_rule()
)));
}
}
Ok(if queries.len() == 1 {
queries.into_iter().next().unwrap()
} else {
Query::Sequence(queries)
})
}
/// Converts a `field` pair into `Query::Field`.
///
/// Text wrapped in double quotes has the quotes removed and its escapes decoded via
/// `unescape_json_string`; any other text is used as the field name unchanged.
///
/// # Errors
///
/// Returns `QueryParseError::UnexpectedToken` if `pair` is not a `field` rule.
fn parse_field(pair: &pest::iterators::Pair<Rule>) -> Result<Query, QueryParseError> {
    let rule = pair.as_rule();
    if rule != Rule::field {
        return Err(QueryParseError::UnexpectedToken(format!(
            "Expected field rule, got {:?}",
            rule
        )));
    }

    let text = pair.as_str();
    // The length check keeps a lone `"` from being treated as an empty quoted name.
    let is_quoted = text.len() >= 2 && text.starts_with('"') && text.ends_with('"');
    if !is_quoted {
        return Ok(Query::Field(text.to_string()));
    }

    let between_quotes = &text[1..text.len() - 1];
    Ok(Query::Field(unescape_json_string(between_quotes)))
}
/// Decodes JSON-style backslash escapes in `s` and returns the resulting string.
///
/// Recognised escapes are `\"`, `\\`, `\/`, `\b`, `\f`, `\n`, `\r`, `\t` and `\uXXXX`
/// with exactly four hexadecimal digits. A high/low surrogate pair written as two
/// consecutive `\uXXXX` escapes is combined into a single code point.
///
/// Input that is not a well-formed escape is never dropped:
/// - an unknown escape (`\x`), a trailing backslash, or a `\u` that is not followed
///   by four hex digits is copied through verbatim;
/// - a surrogate without its partner is replaced by U+FFFD.
fn unescape_json_string(s: &str) -> String {
    /// Reads exactly four hexadecimal digits from the front of `chars`.
    ///
    /// The iterator is advanced only when all four digits are present, so after a
    /// malformed escape the following characters are still processed normally.
    fn read_hex4(chars: &mut std::str::Chars<'_>) -> Option<u32> {
        let mut lookahead = chars.clone();
        let mut value = 0u32;
        for _ in 0..4 {
            // `to_digit` rejects signs and whitespace, unlike `from_str_radix`.
            value = value * 16 + lookahead.next()?.to_digit(16)?;
        }
        *chars = lookahead;
        Some(value)
    }

    let mut result = String::with_capacity(s.len());
    let mut chars = s.chars();

    while let Some(c) = chars.next() {
        if c != '\\' {
            result.push(c);
            continue;
        }

        match chars.next() {
            Some('"') => result.push('"'),
            Some('\\') => result.push('\\'),
            Some('/') => result.push('/'),
            Some('b') => result.push('\u{0008}'),
            Some('f') => result.push('\u{000C}'),
            Some('n') => result.push('\n'),
            Some('r') => result.push('\r'),
            Some('t') => result.push('\t'),
            Some('u') => match read_hex4(&mut chars) {
                // High surrogate: only meaningful together with a following low one.
                Some(high @ 0xD800..=0xDBFF) => {
                    let mut lookahead = chars.clone();
                    let low = if lookahead.next() == Some('\\') && lookahead.next() == Some('u')
                    {
                        read_hex4(&mut lookahead).filter(|low| (0xDC00..=0xDFFF).contains(low))
                    } else {
                        None
                    };
                    match low {
                        Some(low) => {
                            // Commit the lookahead only once the pair is confirmed.
                            chars = lookahead;
                            let code_point = 0x10000 + ((high - 0xD800) << 10) + (low - 0xDC00);
                            result.push(
                                char::from_u32(code_point).unwrap_or(char::REPLACEMENT_CHARACTER),
                            );
                        }
                        None => result.push(char::REPLACEMENT_CHARACTER),
                    }
                }
                // Any other value; `from_u32` fails only for a lone low surrogate.
                Some(code_point) => {
                    result.push(char::from_u32(code_point).unwrap_or(char::REPLACEMENT_CHARACTER));
                }
                // Not four hex digits: keep the text as written.
                None => result.push_str("\\u"),
            },
            Some(other) => {
                result.push('\\');
                result.push(other);
            }
            None => result.push('\\'),
        }
    }

    result
}
fn parse_group(
pair: pest::iterators::Pair<Rule>,
) -> Result<Query, QueryParseError> {
if pair.as_rule() != Rule::group {
return Err(QueryParseError::UnexpectedToken(format!(
"Expected group rule, got {:?}",
pair.as_rule()
)));
}
let mut inner = pair.into_inner();
let disjunction_pair =
inner.next().ok_or(QueryParseError::UnexpectedEndOfInput)?;
parse_disjunction(disjunction_pair)
}
/// Converts an `index` pair into `Query::Index`.
///
/// The first child of the pair is parsed as a `usize`.
///
/// # Errors
///
/// Returns `QueryParseError::UnexpectedToken` if `pair` is not an `index` rule or the
/// child text is not a valid `usize` (the payload is then that text), and
/// `QueryParseError::UnexpectedEndOfInput` if the pair has no children.
fn parse_index(pair: pest::iterators::Pair<Rule>) -> Result<Query, QueryParseError> {
    let rule = pair.as_rule();
    if rule != Rule::index {
        return Err(QueryParseError::UnexpectedToken(format!(
            "Expected index rule, got {:?}",
            rule
        )));
    }

    let mut children = pair.into_inner();
    let number = children
        .next()
        .ok_or(QueryParseError::UnexpectedEndOfInput)?;
    let digits = number.as_str();

    match digits.parse::<usize>() {
        Ok(position) => Ok(Query::Index(position)),
        Err(_) => Err(QueryParseError::UnexpectedToken(digits.to_string())),
    }
}
/// Converts a `range` pair into a `Query`.
///
/// The result depends on which bounds are present:
/// - neither bound: `Query::ArrayWildcard`
/// - start only: `Query::RangeFrom(start)`
/// - end only: `Query::Range(Some(0), Some(end))`
/// - both: `Query::Range(Some(start), Some(end))`
///
/// Children other than `range_start` / `range_end` are ignored.
///
/// # Errors
///
/// Returns `QueryParseError::UnexpectedToken` if `pair` is not a `range` rule or a
/// bound is not a valid `usize` (the payload is then the bound's text).
fn parse_range(pair: pest::iterators::Pair<Rule>) -> Result<Query, QueryParseError> {
    let rule = pair.as_rule();
    if rule != Rule::range {
        return Err(QueryParseError::UnexpectedToken(format!(
            "Expected range rule, got {:?}",
            rule
        )));
    }

    let mut lower: Option<usize> = None;
    let mut upper: Option<usize> = None;

    for part in pair.into_inner() {
        // Pick the bound this child fills; skip anything else.
        let slot = match part.as_rule() {
            Rule::range_start => &mut lower,
            Rule::range_end => &mut upper,
            _ => continue,
        };
        let text = part.as_str();
        let value = text
            .parse::<usize>()
            .map_err(|_| QueryParseError::UnexpectedToken(text.to_string()))?;
        *slot = Some(value);
    }

    Ok(match (lower, upper) {
        (None, None) => Query::ArrayWildcard,
        (Some(from), None) => Query::RangeFrom(from),
        // A missing start defaults to 0 whenever an end bound is present.
        (from, Some(to)) => Query::Range(Some(from.unwrap_or(0)), Some(to)),
    })
}
/// Converts a `regex` pair into `Query::Regex`.
///
/// The surrounding `/` delimiters are removed and every escaped delimiter (`\/`)
/// inside the pattern is replaced by a plain `/`. The pattern is stored as text; it
/// is not compiled here.
///
/// # Errors
///
/// Returns `QueryParseError::UnexpectedToken` if `pair` is not a `regex` rule, or if
/// its text is shorter than two characters or not wrapped in `/` (the payload is then
/// that text).
fn parse_regex(pair: &pest::iterators::Pair<Rule>) -> Result<Query, QueryParseError> {
    if pair.as_rule() != Rule::regex {
        return Err(QueryParseError::UnexpectedToken(format!(
            "Expected regex rule, got {:?}",
            pair.as_rule()
        )));
    }

    let regex_str = pair.as_str();
    if regex_str.len() < 2 || !regex_str.starts_with('/') || !regex_str.ends_with('/') {
        return Err(QueryParseError::UnexpectedToken(regex_str.to_string()));
    }

    // Both delimiters are the single-byte `/`, so trimming one byte from each end
    // always lands on a character boundary.
    let pattern = &regex_str[1..regex_str.len() - 1];
    let unescaped_pattern = pattern.replace("\\/", "/");
    Ok(Query::Regex(unescaped_pattern))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `query` and asserts that rendering the AST yields `rendered`.
    #[track_caller]
    fn assert_renders_as(query: &str, rendered: &str) {
        let ast = parse_query(query).unwrap();
        assert_eq!(rendered, ast.to_string());
    }

    /// Parses `query` and asserts that rendering the AST reproduces the input.
    #[track_caller]
    fn assert_round_trips(query: &str) {
        assert_renders_as(query, query);
    }

    /// Asserts that parsing `query` fails with `QueryParseError::UnexpectedToken`.
    #[track_caller]
    fn assert_unexpected_token(query: &str) {
        let outcome = parse_query(query);
        assert!(matches!(outcome, Err(QueryParseError::UnexpectedToken(_))));
    }

    /// Parses `query` and asserts both the exact AST and its rendered text.
    #[track_caller]
    fn assert_ast(query: &str, ast: Query, rendered: &str) {
        let parsed = parse_query(query).unwrap();
        assert_eq!(parsed, ast);
        assert_eq!(rendered, parsed.to_string());
    }

    // --- round trips -------------------------------------------------------------

    #[test]
    fn parse_field() {
        assert_round_trips("foo");
    }

    #[test]
    fn parse_field_and_number() {
        assert_round_trips("foo123[42]");
    }

    #[test]
    fn parse_single_regex() {
        assert_round_trips("/foo.bar/");
    }

    #[test]
    fn parse_disjunction() {
        assert_round_trips("foo | bar");
    }

    #[test]
    fn parse_kleene_star() {
        assert_round_trips("a*");
    }

    #[test]
    fn parse_optional() {
        assert_round_trips("b?");
    }

    #[test]
    fn parse_complex_query() {
        assert_round_trips("foo.bar[0]?.baz*");
    }

    #[test]
    fn parse_multiple_optional() {
        assert_round_trips("c*.c?.c?");
    }

    #[test]
    fn parse_simple_disjunction_group() {
        assert_round_trips("(foo | bar).baz");
    }

    #[test]
    fn parse_any_path_group() {
        assert_round_trips("(* | [*])*");
    }

    #[test]
    fn parse_any_path_group_in_query() {
        assert_round_trips("a.(* | [*])*.b?");
    }

    #[test]
    fn parse_nested_groups_trivial() {
        assert_renders_as("((foo))", "foo");
    }

    #[test]
    fn parse_nested_groups() {
        assert_round_trips("((foo.bar)* | bar)");
    }

    #[test]
    fn parse_group_sequence() {
        assert_round_trips("(foo.bar.baz)?");
    }

    // --- rejected input and quoting ----------------------------------------------

    #[test]
    fn parse_invalid_number() {
        let result = parse_query("foo[abc]");
        assert!(
            matches!(result, Err(QueryParseError::UnexpectedToken(_))),
            "Actual result: {result:?}"
        );
    }

    #[test]
    fn parse_invalid_regex() {
        assert_unexpected_token("/unclosed");
    }

    #[test]
    fn parse_empty() {
        assert_round_trips("");
    }

    #[test]
    fn reserved_chars_in_double_quotes() {
        assert_round_trips(r#"".|*?[]()/""#);
    }

    #[test]
    fn group_any_reserved_chars_in_double_quotes() {
        assert_round_trips(r#"("." | "|" | "*" | "?" | "[" | "]" | "(" | ")" | "/")*"#);
    }

    #[test]
    fn parse_unclosed_double_quotes() {
        assert_unexpected_token(r#"""#);
    }

    #[test]
    fn parse_valid_key_with_spaces() {
        assert_round_trips(r#""key space".foo"#);
    }

    #[test]
    fn parse_invalid_key_with_spaces() {
        assert_unexpected_token(r"spaces not allowed without double quotes");
    }

    #[test]
    fn parse_invalid_key_with_reserved_chars() {
        assert_unexpected_token(r"][");
    }

    // --- ranges ------------------------------------------------------------------

    #[test]
    fn parse_range_both_bounds() {
        assert_ast(
            "[1:3]",
            Query::Sequence(vec![Query::Range(Some(1), Some(3))]),
            "[1:3]",
        );
    }

    #[test]
    fn parse_range_start_only() {
        assert_ast("[2:]", Query::Sequence(vec![Query::RangeFrom(2)]), "[2:]");
    }

    #[test]
    fn parse_range_end_only() {
        assert_ast(
            "[:3]",
            Query::Sequence(vec![Query::Range(Some(0), Some(3))]),
            "[0:3]",
        );
    }

    #[test]
    fn parse_range_unbounded() {
        assert_ast("[:]", Query::Sequence(vec![Query::ArrayWildcard]), "[*]");
    }

    #[test]
    fn parse_range_zero_start() {
        assert_ast(
            "[0:5]",
            Query::Sequence(vec![Query::Range(Some(0), Some(5))]),
            "[0:5]",
        );
    }

    #[test]
    fn parse_range_on_field() {
        assert_ast(
            "foo[1:3]",
            Query::Sequence(vec![Query::Sequence(vec![
                Query::Field("foo".into()),
                Query::Range(Some(1), Some(3)),
            ])]),
            "foo[1:3]",
        );
    }

    // --- quoted fields -----------------------------------------------------------

    #[test]
    fn quoted_field_strips_quotes() {
        assert_ast(
            r#""foo""#,
            Query::Sequence(vec![Query::Field("foo".into())]),
            "foo",
        );
    }

    #[test]
    fn quoted_field_with_slash() {
        assert_ast(
            r#""/activities""#,
            Query::Sequence(vec![Query::Field("/activities".into())]),
            r#""/activities""#,
        );
    }

    #[test]
    fn quoted_field_with_dot() {
        assert_ast(
            r#""a.b""#,
            Query::Sequence(vec![Query::Field("a.b".into())]),
            r#""a.b""#,
        );
    }

    #[test]
    fn quoted_field_unescape_backslash() {
        assert_ast(
            r#""a\\b""#,
            Query::Sequence(vec![Query::Field(r"a\b".into())]),
            r#""a\\b""#,
        );
    }

    #[test]
    fn quoted_field_unescape_inner_quote() {
        assert_ast(
            r#""a\"b""#,
            Query::Sequence(vec![Query::Field("a\"b".into())]),
            r#""a\"b""#,
        );
    }

    #[test]
    fn quoted_field_unescape_unicode() {
        assert_ast(
            r#""\u0041""#,
            Query::Sequence(vec![Query::Field("A".into())]),
            "A",
        );
    }

    #[test]
    fn quoted_field_in_sequence() {
        assert_ast(
            r#"paths."/activities""#,
            Query::Sequence(vec![
                Query::Field("paths".into()),
                Query::Field("/activities".into()),
            ]),
            r#"paths."/activities""#,
        );
    }
}