#![cfg_attr(feature = "strict_docs", allow(missing_docs))]
use adze_ir::{Grammar, SymbolId};
use std::collections::HashMap;
use std::fmt;
/// A syntax-tree node that query patterns are matched against.
#[derive(Debug, Clone)]
pub struct Subtree {
/// Grammar symbol of this node; the matcher compares it with a pattern's expected symbol.
pub symbol: SymbolId,
/// Child nodes, in the order the matcher walks them.
pub children: Vec<Subtree>,
/// Start of the node's span — presumably a byte offset into the parsed source; TODO confirm.
// Nothing in this file reads the field, hence the lint allowance.
#[allow(dead_code)]
pub start_byte: usize,
/// End of the node's span — presumably a byte offset (inclusive vs. exclusive not visible here).
// Nothing in this file reads the field, hence the lint allowance.
#[allow(dead_code)]
pub end_byte: usize,
}
/// A parsed query: one or more patterns plus the captures and predicates they refer to.
#[derive(Debug, Clone)]
pub struct Query {
/// Top-level patterns in source order; a parsed query always has at least one.
pub patterns: Vec<Pattern>,
/// Maps each capture name (without the `@`) to its numeric id.
/// Ids are handed out from 0 in order of first appearance across the whole query.
pub capture_names: HashMap<String, u32>,
/// All predicates of all patterns in one flat list; patterns refer to entries by index.
pub predicates: Vec<Predicate>,
}
/// One top-level pattern of a query.
#[derive(Debug, Clone)]
pub struct Pattern {
/// Root node of the pattern tree.
pub root: PatternNode,
/// Indices into [`Query::predicates`] of the predicates written directly after this pattern.
pub predicate_indices: Vec<usize>,
}
/// A single node inside a pattern tree.
#[derive(Debug, Clone)]
pub struct PatternNode {
// Symbol the tree node must have; `None` is the `_` wildcard and matches any symbol.
symbol: Option<SymbolId>,
// Capture name (without `@`) attached to this node, if any.
capture: Option<String>,
// Child patterns, in source order.
children: Vec<PatternChild>,
// Set when the node was written with a leading `.`; the matcher in this file
// never reads it, hence the lint allowance.
#[allow(dead_code)]
is_anchor: bool,
}
/// A child slot of a [`PatternNode`]: the child pattern plus how often it may occur.
#[derive(Debug, Clone)]
pub struct PatternChild {
// The child pattern itself.
node: PatternNode,
// Repetition marker parsed after the child (`?`, `*`, `+`, or none).
quantifier: Quantifier,
// Name given by a `[name]:` prefix, if present; the matcher in this file
// never reads it, hence the lint allowance.
#[allow(dead_code)]
field_name: Option<String>,
}
/// How many times a child pattern may occur among a node's children.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Quantifier {
/// No suffix: the child pattern must match once.
One,
/// `?` suffix: the child pattern may be absent.
ZeroOrOne,
/// `*` suffix: the child pattern may repeat any number of times, including zero.
ZeroOrMore,
/// `+` suffix: the child pattern must match at least once.
OneOrMore,
}
/// A condition on captured nodes that must hold for a match to be reported.
///
/// Captures are referred to by their numeric id (see [`Query::capture_names`]).
#[derive(Debug, Clone)]
// NOTE(review): the allowance presumably covers variants that are matched on
// but not constructed in every build — confirm before removing it.
#[allow(dead_code)]
pub enum Predicate {
/// Every listed capture must have the same text as the first one.
Equal(Vec<u32>),
/// No later capture may have the same text as the first one.
NotEqual(Vec<u32>),
/// The capture's text must match the regular expression; an invalid expression never matches.
Match(u32, String),
/// The capture's text must not match the regular expression; an invalid expression always passes.
NotMatch(u32, String),
/// The capture's text must be equal to one of the listed strings.
AnyOf(u32, Vec<String>),
}
/// One successful match of a pattern against a tree node.
#[derive(Debug, Clone)]
pub struct QueryMatch {
/// Index into [`Query::patterns`] of the pattern that matched.
pub pattern_index: usize,
/// Captures recorded while matching, in the order the matcher recorded them.
pub captures: Vec<QueryCapture>,
}
/// A node recorded under a capture name during matching.
#[derive(Debug, Clone)]
pub struct QueryCapture {
/// Numeric id of the capture (see [`Query::capture_names`]).
pub index: u32,
/// Owned copy of the captured node (the matcher clones it out of the tree).
pub subtree: Subtree,
}
/// Parser that turns query source text into a [`Query`].
pub struct QueryParser<'a> {
// Grammar used to resolve node-type names to symbol ids.
grammar: &'a Grammar,
// Query source text being parsed.
input: &'a str,
// Current cursor position within `input`; also reported in positional errors.
position: usize,
}
impl<'a> QueryParser<'a> {
pub fn new(grammar: &'a Grammar, input: &'a str) -> Self {
Self {
grammar,
input,
position: 0,
}
}
/// Parses the whole input into a [`Query`], consuming the parser.
///
/// Patterns are read one after another until the input is exhausted. Each
/// pattern's predicates are appended to the shared predicate list and the
/// pattern remembers the index range they occupy.
///
/// # Errors
///
/// Propagates any error from parsing a pattern, and returns
/// [`QueryError::EmptyQuery`] when the input contains no pattern at all.
pub fn parse(mut self) -> Result<Query, QueryError> {
    let mut query = Query {
        patterns: Vec::new(),
        capture_names: HashMap::new(),
        predicates: Vec::new(),
    };
    // Capture ids are numbered across the whole query, not per pattern.
    let mut next_capture_id = 0;

    self.skip_whitespace();
    while !self.is_at_end() {
        let (root, parsed_predicates) =
            self.parse_pattern(&mut query.capture_names, &mut next_capture_id)?;

        let first_index = query.predicates.len();
        query.predicates.extend(parsed_predicates);
        let predicate_indices = (first_index..query.predicates.len()).collect();

        query.patterns.push(Pattern {
            root,
            predicate_indices,
        });
        self.skip_whitespace();
    }

    if query.patterns.is_empty() {
        Err(QueryError::EmptyQuery)
    } else {
        Ok(query)
    }
}
/// Parses one top-level pattern followed by any predicates attached to it.
///
/// A predicate is recognised by the two-character lookahead `(#`; anything
/// else ends the pattern.
///
/// # Errors
///
/// Returns [`QueryError::ExpectedOpenParen`] when the pattern does not start
/// with `(`, and propagates errors from the node and predicate parsers.
fn parse_pattern(
    &mut self,
    capture_names: &mut HashMap<String, u32>,
    next_capture_id: &mut u32,
) -> Result<(PatternNode, Vec<Predicate>), QueryError> {
    self.skip_whitespace();
    if !self.consume_char('(') {
        return Err(QueryError::ExpectedOpenParen(self.position));
    }
    let root = self.parse_pattern_node(capture_names, next_capture_id)?;

    let mut attached = Vec::new();
    loop {
        self.skip_whitespace();
        let starts_predicate =
            self.peek_char() == Some('(') && self.peek_ahead(1) == Some('#');
        if !starts_predicate {
            break;
        }
        attached.push(self.parse_predicate(capture_names)?);
    }
    Ok((root, attached))
}
/// Parses the inside of one parenthesised node pattern; the opening `(` has
/// already been consumed by the caller.
///
/// Accepted, in this order:
/// - an optional leading `.`, recorded as `is_anchor`;
/// - the node type: a lone `_` (wildcard, symbol `None`) or an identifier that
///   is resolved through `find_symbol`;
/// - zero or more children, each an optional `[field]:` prefix, a nested
///   `( ... )` pattern, and an optional `?`, `*` or `+`;
/// - the closing `)`;
/// - an optional `@name` capture.
///
/// Because a child's `@name` is consumed by the recursive call, a quantifier
/// on a captured child is written after the capture.
///
/// New capture names are registered in `capture_names` with the next free id
/// taken from `next_capture_id`; a name that is already known keeps its id.
///
/// # Errors
///
/// Returns the matching `Expected…` error when a required delimiter is
/// missing, and propagates identifier and symbol-lookup errors.
fn parse_pattern_node(
    &mut self,
    capture_names: &mut HashMap<String, u32>,
    next_capture_id: &mut u32,
) -> Result<PatternNode, QueryError> {
    self.skip_whitespace();
    let is_anchor = self.consume_char('.');

    // `_` is the wildcard only when it stands alone. An underscore that is
    // followed by an identifier character starts a regular name such as
    // `_hidden_rule` and must be left for `parse_identifier`.
    let is_wildcard = self.peek_char() == Some('_')
        && !self
            .peek_ahead(1)
            .is_some_and(|c| c.is_alphanumeric() || c == '_' || c == '-');
    let symbol = if is_wildcard {
        self.advance();
        None
    } else {
        let node_type = self.parse_identifier()?;
        self.find_symbol(&node_type)?
    };

    let mut children = Vec::new();
    self.skip_whitespace();
    // Keep reading children while there is a character and it is not `)`.
    // Running out of input falls through to the `)` check below.
    while matches!(self.peek_char(), Some(c) if c != ')') {
        let field_name = if self.consume_char('[') {
            let name = self.parse_identifier()?;
            if !self.consume_char(']') {
                return Err(QueryError::ExpectedCloseBracket(self.position));
            }
            self.skip_whitespace();
            if !self.consume_char(':') {
                return Err(QueryError::ExpectedColon(self.position));
            }
            Some(name)
        } else {
            None
        };

        self.skip_whitespace();
        if !self.consume_char('(') {
            return Err(QueryError::ExpectedOpenParen(self.position));
        }
        let child_node = self.parse_pattern_node(capture_names, next_capture_id)?;

        self.skip_whitespace();
        let quantifier = match self.peek_char() {
            Some('?') => Quantifier::ZeroOrOne,
            Some('*') => Quantifier::ZeroOrMore,
            Some('+') => Quantifier::OneOrMore,
            _ => Quantifier::One,
        };
        if quantifier != Quantifier::One {
            // Consume the suffix character that selected the quantifier.
            self.advance();
        }

        children.push(PatternChild {
            node: child_node,
            quantifier,
            field_name,
        });
        self.skip_whitespace();
    }

    if !self.consume_char(')') {
        return Err(QueryError::ExpectedCloseParen(self.position));
    }

    self.skip_whitespace();
    let capture = if self.consume_char('@') {
        let name = self.parse_identifier()?;
        // Only a name seen for the first time consumes a fresh id.
        capture_names.entry(name.clone()).or_insert_with(|| {
            let id = *next_capture_id;
            *next_capture_id += 1;
            id
        });
        Some(name)
    } else {
        None
    };

    Ok(PatternNode {
        symbol,
        capture,
        children,
        is_anchor,
    })
}
/// Parses one predicate of the form `(#name? args...)`.
///
/// Supported predicates:
/// - `(#eq? @a @b ...)` and `(#not-eq? @a @b ...)` — two or more captures;
/// - `(#match? @a "regex")` and `(#not-match? @a "regex")`;
/// - `(#any-of? @a "x" "y" ...)` — one capture and at least one string.
///
/// Every capture must already be registered in `capture_names`.
///
/// # Errors
///
/// Returns an `Expected…` error for a missing delimiter,
/// [`QueryError::UnknownCapture`] for an unregistered capture,
/// [`QueryError::InvalidPredicate`] for too few arguments, and
/// [`QueryError::UnknownPredicate`] for any other predicate name.
fn parse_predicate(
    &mut self,
    capture_names: &HashMap<String, u32>,
) -> Result<Predicate, QueryError> {
    self.skip_whitespace();
    if !self.consume_char('(') {
        return Err(QueryError::ExpectedOpenParen(self.position));
    }
    if !self.consume_char('#') {
        return Err(QueryError::ExpectedHash(self.position));
    }
    let predicate_name = self.parse_identifier()?;
    if !self.consume_char('?') {
        return Err(QueryError::ExpectedQuestionMark(self.position));
    }
    self.skip_whitespace();

    let predicate = match predicate_name.as_str() {
        "eq" => Predicate::Equal(self.parse_capture_list(capture_names, "eq?")?),
        "not-eq" => Predicate::NotEqual(self.parse_capture_list(capture_names, "not-eq?")?),
        "match" => {
            let capture_id = self.parse_capture_ref(capture_names)?;
            self.skip_whitespace();
            Predicate::Match(capture_id, self.parse_string()?)
        }
        "not-match" => {
            let capture_id = self.parse_capture_ref(capture_names)?;
            self.skip_whitespace();
            Predicate::NotMatch(capture_id, self.parse_string()?)
        }
        "any-of" => {
            let capture_id = self.parse_capture_ref(capture_names)?;
            self.skip_whitespace();
            let mut values = Vec::new();
            while self.peek_char() == Some('"') {
                values.push(self.parse_string()?);
                self.skip_whitespace();
            }
            if values.is_empty() {
                return Err(QueryError::InvalidPredicate(
                    "any-of? requires at least 1 string".into(),
                ));
            }
            Predicate::AnyOf(capture_id, values)
        }
        _ => return Err(QueryError::UnknownPredicate(predicate_name)),
    };

    self.skip_whitespace();
    if !self.consume_char(')') {
        return Err(QueryError::ExpectedCloseParen(self.position));
    }
    Ok(predicate)
}

/// Parses a single `@name` reference and returns the id registered for it.
fn parse_capture_ref(
    &mut self,
    capture_names: &HashMap<String, u32>,
) -> Result<u32, QueryError> {
    if !self.consume_char('@') {
        return Err(QueryError::ExpectedAt(self.position));
    }
    let name = self.parse_identifier()?;
    capture_names
        .get(&name)
        .copied()
        .ok_or_else(|| QueryError::UnknownCapture(name))
}

/// Parses a whitespace-separated run of `@name` references; `label` names the
/// predicate in the error raised when fewer than two captures are present.
fn parse_capture_list(
    &mut self,
    capture_names: &HashMap<String, u32>,
    label: &str,
) -> Result<Vec<u32>, QueryError> {
    let mut captures = Vec::new();
    while self.peek_char() == Some('@') {
        captures.push(self.parse_capture_ref(capture_names)?);
        self.skip_whitespace();
    }
    if captures.len() < 2 {
        return Err(QueryError::InvalidPredicate(format!(
            "{label} requires at least 2 captures"
        )));
    }
    Ok(captures)
}
/// Resolves a node-type name to its symbol id.
///
/// Token names are searched first, then rule names; the first entry whose
/// name equals `name` wins. On success the result is always `Some`.
///
/// # Errors
///
/// Returns [`QueryError::UnknownNodeType`] when neither table has the name.
fn find_symbol(&self, name: &str) -> Result<Option<SymbolId>, QueryError> {
    let from_tokens = self
        .grammar
        .tokens
        .iter()
        .find_map(|(id, token)| (token.name == name).then_some(*id));

    let resolved = from_tokens.or_else(|| {
        self.grammar
            .rule_names
            .iter()
            .find_map(|(id, rule_name)| (rule_name == name).then_some(*id))
    });

    match resolved {
        Some(id) => Ok(Some(id)),
        None => Err(QueryError::UnknownNodeType(name.to_string())),
    }
}
/// Advances the cursor past whitespace and `;` line comments.
///
/// A comment runs from `;` up to and including the next `\n`, or to the end
/// of the input when there is no newline.
fn skip_whitespace(&mut self) {
    loop {
        match self.peek_char() {
            Some(';') => {
                self.advance();
                // Drop the rest of the comment line, newline included.
                while let Some(skipped) = self.advance() {
                    if skipped == '\n' {
                        break;
                    }
                }
            }
            Some(c) if c.is_whitespace() => {
                self.advance();
            }
            _ => break,
        }
    }
}
/// Returns the character at the cursor without consuming it, or `None` at the
/// end of the input.
///
/// `position` is kept as a byte offset that always sits on a character
/// boundary (`advance` is the only place that moves it), so the remaining
/// input can be sliced directly instead of re-scanning from the start of the
/// string on every call.
fn peek_char(&self) -> Option<char> {
    self.input[self.position..].chars().next()
}
/// Returns the character `n` characters after the cursor (`n == 0` is the
/// character at the cursor) without consuming anything.
fn peek_ahead(&self, n: usize) -> Option<char> {
    self.input[self.position..].chars().nth(n)
}
/// Consumes and returns the character at the cursor, or returns `None` and
/// leaves the cursor untouched at the end of the input.
fn advance(&mut self) -> Option<char> {
    let ch = self.peek_char()?;
    // Step by the encoded width so `position` stays a valid byte offset.
    self.position += ch.len_utf8();
    Some(ch)
}
/// Consumes the next character if it equals `expected`.
///
/// Returns `true` when a character was consumed; otherwise the cursor does
/// not move and `false` is returned.
fn consume_char(&mut self, expected: char) -> bool {
    let is_expected = self.peek_char() == Some(expected);
    if is_expected {
        self.advance();
    }
    is_expected
}
/// Returns `true` when no character is left at the cursor.
///
/// The check asks `peek_char` instead of comparing `position` with the byte
/// length of the input, so it gives the right answer for input containing
/// multi-byte characters regardless of the unit `position` is counted in.
fn is_at_end(&self) -> bool {
    self.peek_char().is_none()
}
/// Parses an identifier at the cursor and returns it as an owned string.
///
/// The first character must be alphabetic or `_`; following characters may be
/// alphanumeric, `_` or `-`.
///
/// The identifier is assembled from the characters actually consumed rather
/// than by slicing `input` with cursor positions, so the result does not
/// depend on `position` being a valid byte offset when the input contains
/// multi-byte characters.
///
/// # Errors
///
/// Returns [`QueryError::ExpectedIdentifier`] with the current position when
/// the cursor is not on a valid first character; nothing is consumed then.
fn parse_identifier(&mut self) -> Result<String, QueryError> {
    let mut identifier = String::new();

    match self.peek_char() {
        Some(first) if first.is_alphabetic() || first == '_' => {
            identifier.push(first);
            self.advance();
        }
        _ => return Err(QueryError::ExpectedIdentifier(self.position)),
    }

    while let Some(ch) = self.peek_char() {
        if !(ch.is_alphanumeric() || ch == '_' || ch == '-') {
            break;
        }
        identifier.push(ch);
        self.advance();
    }

    Ok(identifier)
}
/// Parses a double-quoted string literal at the cursor and returns its
/// contents with escapes resolved.
///
/// Recognised escapes are `\n`, `\t`, `\r`, `\\` and `\"`. Any other escaped
/// character is kept together with its backslash, so sequences such as `\d`
/// reach the caller unchanged.
///
/// # Errors
///
/// Returns [`QueryError::ExpectedString`] when the cursor is not on `"`, and
/// [`QueryError::UnterminatedString`] when the input ends before the closing
/// quote.
fn parse_string(&mut self) -> Result<String, QueryError> {
    if !self.consume_char('"') {
        return Err(QueryError::ExpectedString(self.position));
    }

    let mut text = String::new();
    loop {
        let Some(ch) = self.advance() else {
            return Err(QueryError::UnterminatedString(self.position));
        };
        match ch {
            '"' => return Ok(text),
            '\\' => {
                // An escape needs one more character; running out here means
                // the literal was never closed.
                let Some(escaped) = self.advance() else {
                    return Err(QueryError::UnterminatedString(self.position));
                };
                match escaped {
                    'n' => text.push('\n'),
                    't' => text.push('\t'),
                    'r' => text.push('\r'),
                    '\\' => text.push('\\'),
                    '"' => text.push('"'),
                    other => {
                        text.push('\\');
                        text.push(other);
                    }
                }
            }
            plain => text.push(plain),
        }
    }
}
}
/// Configuration for running a [`Query`] over a syntax tree.
pub struct QueryCursor {
    // Deepest tree level that is still examined (the root is level 0);
    // `None` means no limit.
    max_depth: Option<usize>,
}

impl Default for QueryCursor {
    /// A cursor without a depth limit.
    fn default() -> Self {
        Self { max_depth: None }
    }
}

impl QueryCursor {
    /// Creates a cursor without a depth limit.
    pub fn new() -> Self {
        Self::default()
    }

    /// Limits matching to nodes whose depth is at most `depth`; the root node
    /// has depth 0.
    pub fn set_max_depth(&mut self, depth: usize) {
        self.max_depth = Some(depth);
    }

    /// Returns a lazy iterator over all matches of `query` in the tree rooted
    /// at `root`.
    ///
    /// Patterns are processed in order; for each pattern the tree is walked
    /// from `root` and every node within the depth limit is tried as the
    /// pattern's root.
    pub fn matches<'a>(
        &self,
        query: &'a Query,
        root: &'a Subtree,
    ) -> impl Iterator<Item = QueryMatch> + 'a {
        let max_depth = self.max_depth;
        QueryMatches {
            query,
            root,
            pattern_index: 0,
            // The walk starts with just the root, at depth 0.
            node_stack: vec![(root, 0)],
            max_depth,
            captures: Vec::new(),
        }
    }
}
/// Iterator state behind [`QueryCursor::matches`].
struct QueryMatches<'a> {
// Query being evaluated.
query: &'a Query,
// Root of the tree; the walk restarts here for every pattern.
root: &'a Subtree,
// Index of the pattern currently being searched for.
pattern_index: usize,
// Nodes still to be tried for the current pattern, each paired with its depth.
node_stack: Vec<(&'a Subtree, usize)>,
// Nodes deeper than this are skipped; `None` means no limit.
max_depth: Option<usize>,
// Scratch buffer of captures for the match attempt in progress.
captures: Vec<QueryCapture>,
}
impl<'a> Iterator for QueryMatches<'a> {
    type Item = QueryMatch;

    /// Yields the next match, exhausting one pattern over the whole tree
    /// before moving on to the next pattern.
    fn next(&mut self) -> Option<Self::Item> {
        // Copy the shared reference so the pattern borrow is independent of `self`.
        let query = self.query;
        loop {
            // Running past the last pattern ends the iteration.
            let pattern = query.patterns.get(self.pattern_index)?;
            if let Some(found) = self.find_next_match(pattern) {
                return Some(found);
            }
            // This pattern is exhausted: restart the walk for the next one.
            self.pattern_index += 1;
            self.node_stack.clear();
            self.node_stack.push((self.root, 0));
        }
    }
}
impl<'a> QueryMatches<'a> {
/// Continues the tree walk for `pattern` and returns the next match, or
/// `None` once every pending node has been tried.
///
/// Each popped node is tried as the pattern's root. Its children are queued
/// afterwards whether or not it matched, so matches nested inside a matching
/// node are still found on later calls. Nodes deeper than `max_depth` are
/// dropped without queuing their children.
fn find_next_match(&mut self, pattern: &Pattern) -> Option<QueryMatch> {
    while let Some((node, depth)) = self.node_stack.pop() {
        if self.max_depth.is_some_and(|max| depth > max) {
            continue;
        }

        self.captures.clear();
        let is_match =
            self.match_pattern_node(&pattern.root, node, depth) && self.check_predicates(pattern);
        // Snapshot the captures before the walk moves on.
        let found = is_match.then(|| QueryMatch {
            pattern_index: self.pattern_index,
            captures: self.captures.clone(),
        });

        self.add_children_to_stack(node, depth + 1);
        if found.is_some() {
            return found;
        }
    }
    None
}
/// Tests whether `node` matches `pattern`, recording captures on success.
///
/// The node matches when its symbol equals the pattern's symbol (a wildcard
/// pattern accepts any symbol) and its children satisfy the pattern's child
/// patterns. `_depth` is accepted for call-site symmetry and is not used.
///
/// When the node does not match, every capture recorded during the attempt —
/// this node's own and any recorded while examining its children — is
/// discarded, so a failed attempt leaves `self.captures` as it found it.
fn match_pattern_node(
    &mut self,
    pattern: &PatternNode,
    node: &'a Subtree,
    _depth: usize,
) -> bool {
    if let Some(expected_symbol) = pattern.symbol
        && node.symbol != expected_symbol
    {
        return false;
    }

    // Remember where this attempt starts so it can be undone on failure.
    let checkpoint = self.captures.len();

    if let Some(ref capture_name) = pattern.capture
        && let Some(&capture_id) = self.query.capture_names.get(capture_name)
    {
        self.captures.push(QueryCapture {
            index: capture_id,
            subtree: node.clone(),
        });
    }

    if self.match_children(&pattern.children, &node.children) {
        return true;
    }

    self.captures.truncate(checkpoint);
    false
}
/// Tests whether `node_children` satisfy `pattern_children`.
///
/// A pattern without child patterns accepts any children; otherwise the
/// child patterns are matched in order starting from the first child.
fn match_children(
    &mut self,
    pattern_children: &[PatternChild],
    node_children: &'a [Subtree],
) -> bool {
    pattern_children.is_empty()
        || self.match_children_subsequence(pattern_children, node_children, 0, 0)
}
/// Matches `pattern_children[pattern_idx..]` against
/// `node_children[node_idx..]`, in order.
///
/// Per quantifier:
/// - `One`: some child at or after `node_idx` must match, and the remaining
///   patterns must match after it. Every candidate child is tried, not just
///   the first one that matches on its own.
/// - `ZeroOrOne`: like `One`, but if no candidate works the pattern is
///   skipped and the remaining patterns are matched from `node_idx`.
/// - `ZeroOrMore`: consumes a run of matching children starting exactly at
///   `node_idx`, trying the remaining patterns after each run length
///   (shortest first, starting with the empty run).
/// - `OneOrMore`: like `ZeroOrMore`, but the run must contain at least one
///   child.
///
/// Captures recorded along an alternative that is later abandoned are
/// discarded, and on a `false` result `self.captures` is restored to its
/// length at entry, so only the successful path contributes captures.
fn match_children_subsequence(
    &mut self,
    pattern_children: &[PatternChild],
    node_children: &'a [Subtree],
    pattern_idx: usize,
    node_idx: usize,
) -> bool {
    // No patterns left: everything required has been matched.
    let Some(pattern_child) = pattern_children.get(pattern_idx) else {
        return true;
    };
    let checkpoint = self.captures.len();

    match pattern_child.quantifier {
        Quantifier::One | Quantifier::ZeroOrOne => {
            for (i, candidate) in node_children.iter().enumerate().skip(node_idx) {
                if self.match_pattern_node(&pattern_child.node, candidate, 0)
                    && self.match_children_subsequence(
                        pattern_children,
                        node_children,
                        pattern_idx + 1,
                        i + 1,
                    )
                {
                    return true;
                }
                // This candidate did not lead to a full match; drop what it captured.
                self.captures.truncate(checkpoint);
            }
            // Only an optional pattern may be skipped altogether.
            pattern_child.quantifier == Quantifier::ZeroOrOne
                && self.match_children_subsequence(
                    pattern_children,
                    node_children,
                    pattern_idx + 1,
                    node_idx,
                )
        }
        Quantifier::ZeroOrMore => {
            let mut run_end = node_idx;
            loop {
                let before_rest = self.captures.len();
                if self.match_children_subsequence(
                    pattern_children,
                    node_children,
                    pattern_idx + 1,
                    run_end,
                ) {
                    return true;
                }
                self.captures.truncate(before_rest);

                // Try to extend the run by one more matching child.
                let extended = match node_children.get(run_end) {
                    Some(candidate) => {
                        self.match_pattern_node(&pattern_child.node, candidate, 0)
                    }
                    None => false,
                };
                if !extended {
                    break;
                }
                run_end += 1;
            }
            self.captures.truncate(checkpoint);
            false
        }
        Quantifier::OneOrMore => {
            for (i, candidate) in node_children.iter().enumerate().skip(node_idx) {
                // The run is contiguous: the first non-matching child ends it.
                // The remaining patterns were already tried from this index
                // after the previous child, so there is nothing left to try.
                if !self.match_pattern_node(&pattern_child.node, candidate, 0) {
                    break;
                }
                let before_rest = self.captures.len();
                if self.match_children_subsequence(
                    pattern_children,
                    node_children,
                    pattern_idx + 1,
                    i + 1,
                ) {
                    return true;
                }
                self.captures.truncate(before_rest);
            }
            self.captures.truncate(checkpoint);
            false
        }
    }
}
fn check_predicates(&self, pattern: &Pattern) -> bool {
for &pred_index in &pattern.predicate_indices {
if let Some(predicate) = self.query.predicates.get(pred_index)
&& !self.check_predicate(predicate)
{
return false;
}
}
true
}
/// Evaluates a single predicate against the captures of the current match
/// attempt.
///
/// - `Equal` / `NotEqual` compare every later capture with the first one and
///   hold trivially for fewer than two captures.
/// - `Match` is `false` and `NotMatch` is `true` when the regular expression
///   fails to compile.
/// - `AnyOf` holds when the capture text equals one of the listed strings.
fn check_predicate(&self, predicate: &Predicate) -> bool {
    match predicate {
        Predicate::Equal(capture_ids) => match capture_ids.split_first() {
            Some((&first, others)) => {
                let reference = self.get_capture_text(first);
                others
                    .iter()
                    .all(|&id| self.get_capture_text(id) == reference)
            }
            None => true,
        },
        Predicate::NotEqual(capture_ids) => match capture_ids.split_first() {
            Some((&first, others)) => {
                let reference = self.get_capture_text(first);
                others
                    .iter()
                    .all(|&id| self.get_capture_text(id) != reference)
            }
            None => true,
        },
        Predicate::Match(capture_id, pattern) => {
            let text = self.get_capture_text(*capture_id);
            match regex::Regex::new(pattern) {
                Ok(regex) => regex.is_match(&text),
                Err(_) => false,
            }
        }
        Predicate::NotMatch(capture_id, pattern) => {
            let text = self.get_capture_text(*capture_id);
            match regex::Regex::new(pattern) {
                Ok(regex) => !regex.is_match(&text),
                Err(_) => true,
            }
        }
        Predicate::AnyOf(capture_id, values) => {
            let text = self.get_capture_text(*capture_id);
            values.contains(&text)
        }
    }
}
/// Returns the text that predicates compare for the capture `capture_id`.
///
/// The first recorded capture with that id is used; an id with no recorded
/// capture yields an empty string.
///
/// NOTE(review): the value is the `Debug` rendering of the captured node's
/// symbol, not a slice of source text — this iterator holds no source
/// buffer. Predicates therefore compare symbols, not contents.
fn get_capture_text(&self, capture_id: u32) -> String {
    let recorded = self
        .captures
        .iter()
        .find(|capture| capture.index == capture_id);
    match recorded {
        Some(capture) => format!("{:?}", capture.subtree.symbol),
        None => String::new(),
    }
}
/// Queues the children of `node` for the tree walk, each tagged with `depth`.
///
/// Children are pushed in reverse so that popping the stack visits them in
/// their original order.
fn add_children_to_stack(&mut self, node: &'a Subtree, depth: usize) {
    let queued = node.children.iter().rev().map(|child| (child, depth));
    self.node_stack.extend(queued);
}
}
/// Errors produced while parsing query source text.
///
/// Variants carrying a `usize` report the parser's cursor position at the
/// point of failure; variants carrying a `String` report the offending name
/// or an explanatory message.
#[derive(Debug, Clone)]
pub enum QueryError {
    /// The input contained no pattern.
    EmptyQuery,
    /// A `(` was required at the given position.
    ExpectedOpenParen(usize),
    /// A `)` was required at the given position.
    ExpectedCloseParen(usize),
    /// A `]` was required at the given position.
    ExpectedCloseBracket(usize),
    /// A `:` was required at the given position.
    ExpectedColon(usize),
    /// A `#` was required at the given position.
    ExpectedHash(usize),
    /// A `?` was required at the given position.
    ExpectedQuestionMark(usize),
    /// An `@` was required at the given position.
    ExpectedAt(usize),
    /// An identifier was required at the given position.
    ExpectedIdentifier(usize),
    /// A string literal was required at the given position.
    ExpectedString(usize),
    /// The input ended inside a string literal.
    UnterminatedString(usize),
    /// The node type is neither a token nor a rule of the grammar.
    UnknownNodeType(String),
    /// A predicate referred to a capture that was never declared.
    UnknownCapture(String),
    /// The predicate name is not supported.
    UnknownPredicate(String),
    /// The predicate is known but its arguments are not acceptable.
    InvalidPredicate(String),
}

impl fmt::Display for QueryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Most variants share the "Expected X at position N" shape; the
        // others are written out directly and return early.
        let (expected, pos) = match self {
            Self::EmptyQuery => return f.write_str("Query cannot be empty"),
            Self::UnterminatedString(pos) => {
                return write!(f, "Unterminated string at position {pos}");
            }
            Self::UnknownNodeType(name) => return write!(f, "Unknown node type: {name}"),
            Self::UnknownCapture(name) => return write!(f, "Unknown capture: @{name}"),
            Self::UnknownPredicate(name) => return write!(f, "Unknown predicate: #{name}?"),
            Self::InvalidPredicate(msg) => return write!(f, "Invalid predicate: {msg}"),
            Self::ExpectedOpenParen(pos) => ("'('", pos),
            Self::ExpectedCloseParen(pos) => ("')'", pos),
            Self::ExpectedCloseBracket(pos) => ("']'", pos),
            Self::ExpectedColon(pos) => ("':'", pos),
            Self::ExpectedHash(pos) => ("'#'", pos),
            Self::ExpectedQuestionMark(pos) => ("'?'", pos),
            Self::ExpectedAt(pos) => ("'@'", pos),
            Self::ExpectedIdentifier(pos) => ("identifier", pos),
            Self::ExpectedString(pos) => ("string", pos),
        };
        write!(f, "Expected {expected} at position {pos}")
    }
}

impl std::error::Error for QueryError {}
// Parser-level tests: each builds a small grammar by hand and checks the
// shape of the parsed query. The matcher is not exercised here.
#[cfg(test)]
mod tests {
use super::*;
// A nested pattern without captures parses into a single pattern.
#[test]
fn test_query_parser_simple() {
let mut grammar = Grammar::new("test".to_string());
let expr_id = SymbolId(0);
grammar.rule_names.insert(expr_id, "expression".to_string());
let add_id = SymbolId(1);
grammar.tokens.insert(
add_id,
adze_ir::Token {
name: "plus".to_string(),
pattern: adze_ir::TokenPattern::String("+".to_string()),
fragile: false,
},
);
let parser = QueryParser::new(&grammar, "(expression (plus))");
let query = parser.parse().unwrap();
assert_eq!(query.patterns.len(), 1);
assert_eq!(query.capture_names.len(), 0);
}
// A trailing `@name` registers a capture, and the first capture gets id 0.
#[test]
fn test_query_parser_with_captures() {
let mut grammar = Grammar::new("test".to_string());
let expr_id = SymbolId(0);
grammar.rule_names.insert(expr_id, "expression".to_string());
let parser = QueryParser::new(&grammar, "(expression) @expr");
let query = parser.parse().unwrap();
assert_eq!(query.patterns.len(), 1);
assert_eq!(query.capture_names.len(), 1);
assert_eq!(query.capture_names.get("expr"), Some(&0));
}
// A `*` suffix on a child is recorded as `Quantifier::ZeroOrMore`.
#[test]
fn test_query_parser_with_quantifiers() {
let mut grammar = Grammar::new("test".to_string());
let list_id = SymbolId(0);
grammar.rule_names.insert(list_id, "list".to_string());
let item_id = SymbolId(1);
grammar.rule_names.insert(item_id, "item".to_string());
let parser = QueryParser::new(&grammar, "(list (item)*)");
let query = parser.parse().unwrap();
assert_eq!(query.patterns.len(), 1);
let pattern = &query.patterns[0];
assert_eq!(pattern.root.children.len(), 1);
assert_eq!(pattern.root.children[0].quantifier, Quantifier::ZeroOrMore);
}
}