use super::ast::*;
use super::predicate_eval::PredicateContext;
use crate::parser_v4::ParseNode;
use adze_glr_core::SymbolMetadata;
use std::collections::HashMap;
/// One successful match of a query pattern against a subtree.
#[derive(Debug, Clone)]
pub struct QueryMatch {
/// Position of the matched pattern within the query's pattern list.
pub pattern_index: usize,
/// Nodes captured by the pattern, ordered by ascending capture index.
pub captures: Vec<QueryCapture>,
}
/// A single captured node together with the capture id it was bound to.
#[derive(Debug, Clone)]
pub struct QueryCapture {
/// Capture id taken from the pattern node that matched.
pub index: u32,
/// Owned copy of the parse node that was captured.
pub node: ParseNode,
}
/// Scratch state accumulated while one pattern is matched at one start node.
#[derive(Debug)]
struct MatchState {
/// Capture id -> captured node; a later insert with the same id replaces
/// the earlier one.
captures: HashMap<u32, ParseNode>,
}
/// Matches the patterns of a compiled query against a parse tree.
pub struct QueryMatcher<'a> {
/// Query whose patterns are evaluated.
query: &'a Query,
/// Source text handed to the predicate evaluator.
source: &'a str,
/// Per-symbol metadata, looked up by the numeric value of a node's symbol
/// to answer "is named" / "is extra" questions.
symbol_metadata: &'a [SymbolMetadata],
}
impl<'a> QueryMatcher<'a> {
pub fn new(query: &'a Query, source: &'a str, symbol_metadata: &'a [SymbolMetadata]) -> Self {
QueryMatcher {
query,
source,
symbol_metadata,
}
}
pub fn matches(&self, root: &ParseNode) -> Vec<QueryMatch> {
let mut matches = Vec::new();
for (pattern_index, pattern) in self.query.patterns.iter().enumerate() {
self.match_pattern(pattern_index, pattern, root, &mut matches);
}
matches
}
fn match_pattern(
&self,
pattern_index: usize,
pattern: &Pattern,
root: &ParseNode,
matches: &mut Vec<QueryMatch>,
) {
self.match_pattern_at_node(pattern_index, pattern, root, matches);
}
fn match_pattern_at_node(
&self,
pattern_index: usize,
pattern: &Pattern,
node: &ParseNode,
matches: &mut Vec<QueryMatch>,
) {
let mut state = MatchState {
captures: HashMap::new(),
};
if self.match_node(&pattern.root, node, &mut state) {
let predicate_ctx = PredicateContext::new(self.source);
if pattern
.predicates
.iter()
.all(|pred| predicate_ctx.evaluate(pred, &state.captures))
{
let mut captures: Vec<_> = state
.captures
.into_iter()
.map(|(index, node)| QueryCapture { index, node })
.collect();
captures.sort_by_key(|c| c.index);
matches.push(QueryMatch {
pattern_index,
captures,
});
}
}
for child in &node.children {
self.match_pattern_at_node(pattern_index, pattern, child, matches);
}
}
fn match_node(&self, pattern: &PatternNode, node: &ParseNode, state: &mut MatchState) -> bool {
if pattern.symbol != node.symbol {
return false;
}
if self.node_is_named(node) != pattern.is_named {
return false;
}
if let Some(capture_id) = pattern.capture {
state.captures.insert(capture_id, node.clone());
}
match pattern.quantifier {
Quantifier::One => self.match_children_one(pattern, node, state),
Quantifier::Optional => self.match_children_optional(pattern, node, state),
Quantifier::Plus => self.match_children_plus(pattern, node, state),
Quantifier::Star => self.match_children_star(pattern, node, state),
}
}
fn match_children_one(
&self,
pattern: &PatternNode,
node: &ParseNode,
state: &mut MatchState,
) -> bool {
for (field_name, field_pattern) in &pattern.fields {
let field_node = node
.children
.iter()
.find(|child| child.field_name.as_ref() == Some(field_name));
if let Some(field_node) = field_node {
if !self.match_node(field_pattern, field_node, state) {
return false;
}
} else {
return false; }
}
if !pattern.children.is_empty() {
return self.match_child_sequence(&pattern.children, &node.children, 0, 0, state);
}
true
}
fn match_children_optional(
&self,
pattern: &PatternNode,
node: &ParseNode,
state: &mut MatchState,
) -> bool {
self.match_children_one(pattern, node, state);
true
}
fn match_children_plus(
&self,
pattern: &PatternNode,
node: &ParseNode,
state: &mut MatchState,
) -> bool {
if !self.match_children_one(pattern, node, state) {
return false;
}
true
}
fn match_children_star(
&self,
pattern: &PatternNode,
node: &ParseNode,
state: &mut MatchState,
) -> bool {
self.match_children_plus(pattern, node, state);
true
}
fn node_is_named(&self, node: &ParseNode) -> bool {
self.symbol_metadata
.get(node.symbol.0 as usize)
.map(|m| m.is_named)
.unwrap_or(true)
}
fn node_is_extra(&self, node: &ParseNode) -> bool {
self.symbol_metadata
.get(node.symbol.0 as usize)
.map(|m| m.is_extra)
.unwrap_or(false)
}
fn match_child_sequence(
&self,
pattern_children: &[PatternChild],
node_children: &[ParseNode],
pattern_idx: usize,
node_idx: usize,
state: &mut MatchState,
) -> bool {
if pattern_idx >= pattern_children.len() {
return node_children[node_idx..]
.iter()
.all(|n| self.node_is_extra(n));
}
let mut node_idx = node_idx;
while node_idx < node_children.len() && self.node_is_extra(&node_children[node_idx]) {
node_idx += 1;
}
if node_idx >= node_children.len() {
return pattern_children[pattern_idx..]
.iter()
.all(|p| matches!(p, PatternChild::Node(n) if n.quantifier != Quantifier::One));
}
match &pattern_children[pattern_idx] {
PatternChild::Node(pattern_node) => {
if self.match_node(pattern_node, &node_children[node_idx], state) {
self.match_child_sequence(
pattern_children,
node_children,
pattern_idx + 1,
node_idx + 1,
state,
)
} else if pattern_node.quantifier != Quantifier::One {
self.match_child_sequence(
pattern_children,
node_children,
pattern_idx + 1,
node_idx,
state,
)
} else {
false
}
}
PatternChild::Token(_token) => {
self.match_child_sequence(
pattern_children,
node_children,
pattern_idx + 1,
node_idx + 1,
state,
)
}
}
}
}
/// Iterator over the matches of a query, computed up front at construction.
pub struct QueryMatches<'a> {
/// Matcher that produced `matches`; not read again afterwards.
#[allow(dead_code)]
matcher: QueryMatcher<'a>,
/// Root node the query was run against.
#[allow(dead_code)]
root: &'a ParseNode,
/// Initialised to 0 by `new` and not otherwise used in the visible code.
#[allow(dead_code)]
pattern_index: usize,
/// All matches, in the order the matcher returned them.
matches: Vec<QueryMatch>,
/// Index into `matches` of the next item to yield.
current_index: usize,
}
impl<'a> QueryMatches<'a> {
    /// Runs `query` against the tree rooted at `root` and wraps the complete
    /// result list in an iterator positioned at the first match.
    pub fn new(
        query: &'a Query,
        root: &'a ParseNode,
        source: &'a str,
        symbol_metadata: &'a [SymbolMetadata],
    ) -> Self {
        let engine = QueryMatcher::new(query, source, symbol_metadata);
        // All matching happens here; iteration afterwards only walks the list.
        let found = engine.matches(root);
        Self {
            matcher: engine,
            root,
            pattern_index: 0,
            matches: found,
            current_index: 0,
        }
    }
}
impl<'a> Iterator for QueryMatches<'a> {
    type Item = QueryMatch;

    /// Yields a clone of the next precomputed match, or `None` once the list
    /// is exhausted. The cursor only advances when an item is produced.
    fn next(&mut self) -> Option<Self::Item> {
        let upcoming = self.matches.get(self.current_index).cloned()?;
        self.current_index += 1;
        Some(upcoming)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::compile_query;
    use adze_glr_core::SymbolMetadata;
    use adze_ir::{Grammar, SymbolId, Token, TokenPattern};

    /// Builds a childless node for `symbol` spanning `start..end`.
    fn make_node(symbol: u16, start: usize, end: usize) -> ParseNode {
        let symbol_id = SymbolId(symbol);
        ParseNode {
            symbol: symbol_id,
            symbol_id,
            children: vec![],
            start_byte: start,
            end_byte: end,
            field_name: None,
        }
    }

    /// Builds a root node (symbol 0) covering `0..end_byte` whose children
    /// are identifier nodes (symbol 1) at the given byte spans.
    fn root_with_identifiers(spans: &[(usize, usize)], end_byte: usize) -> ParseNode {
        let mut root = make_node(0, 0, end_byte);
        root.children = spans
            .iter()
            .map(|&(start, end)| make_node(1, start, end))
            .collect();
        root
    }

    /// Grammar with a single `identifier` token registered as symbol 1.
    fn create_test_grammar() -> Grammar {
        let mut grammar = Grammar::new("test".to_string());
        let identifier = Token {
            name: "identifier".to_string(),
            pattern: TokenPattern::Regex("[a-zA-Z]+".to_string()),
            fragile: false,
        };
        grammar.tokens.insert(SymbolId(1), identifier);
        grammar
    }

    /// Metadata for one visible, named, non-extra symbol.
    fn named_symbol(name: &str, id: u16, is_terminal: bool) -> SymbolMetadata {
        SymbolMetadata {
            name: name.to_string(),
            is_visible: true,
            is_named: true,
            is_supertype: false,
            is_terminal,
            is_extra: false,
            is_fragile: false,
            symbol_id: SymbolId(id),
        }
    }

    /// Metadata table: index 0 is `root`, index 1 is `identifier`.
    fn test_symbol_metadata() -> Vec<SymbolMetadata> {
        vec![
            named_symbol("root", 0, false),
            named_symbol("identifier", 1, true),
        ]
    }

    /// Compiles `query_str` against the test grammar and runs it over `root`.
    fn run_query(query_str: &str, source: &str, root: &ParseNode) -> Vec<QueryMatch> {
        let grammar = create_test_grammar();
        let query = compile_query(query_str, &grammar).unwrap();
        let metadata = test_symbol_metadata();
        QueryMatcher::new(&query, source, &metadata).matches(root)
    }

    /// Start byte of the first capture of every match, in result order.
    fn first_capture_starts(matches: &[QueryMatch]) -> Vec<usize> {
        matches
            .iter()
            .map(|m| m.captures[0].node.start_byte)
            .collect()
    }

    #[test]
    fn test_predicate_matching() {
        let query_str = r#"
(identifier @name)
(#eq? @name "test")
"#;
        let root = root_with_identifiers(&[(0, 4), (5, 10), (11, 15)], 15);

        let matches = run_query(query_str, "test other test", &root);

        // Only the two identifiers whose text is "test" satisfy the predicate.
        assert_eq!(first_capture_starts(&matches), vec![0, 11]);
    }

    #[test]
    fn test_query_without_predicates() {
        let root = root_with_identifiers(&[(0, 3), (4, 7), (8, 11)], 11);

        let matches = run_query("(identifier @name)", "foo bar baz", &root);

        // Every identifier matches, in document order.
        assert_eq!(first_capture_starts(&matches), vec![0, 4, 8]);
    }

    #[test]
    fn test_empty_query_result() {
        let query_str = r#"
(identifier @name)
(#eq? @name "nonexistent")
"#;
        let root = root_with_identifiers(&[(0, 4), (5, 10), (11, 15)], 15);

        let matches = run_query(query_str, "test other test", &root);

        assert!(matches.is_empty());
    }
}