use super::InputRange;
use crate::parser::grammar::{Production, ProductionId};
use crate::{
Term,
append_vec::{AppendOnlyVec, append_only_vec_id},
tracing,
};
// Typed index for traversals: `TraversalArena` (an `AppendOnlyVec` of
// `Traversal`s) is keyed by this id, and tree nodes refer to each other by it.
append_only_vec_id!(pub(crate) TraversalId);
/// A term that a traversal has matched.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
pub(crate) enum TermMatch<'gram> {
    /// A matched terminal; `TraversalTree::match_term` advances the input
    /// range by this string's byte length.
    Terminal(&'gram str),
    /// A matched nonterminal, identified by the id of the traversal that
    /// matched it.
    Nonterminal(TraversalId),
}
#[derive(Debug)]
pub(crate) struct Traversal<'gram> {
pub id: TraversalId,
pub unmatched: &'gram [crate::Term],
pub input_range: InputRange<'gram>,
pub production_id: ProductionId,
pub is_starting: bool,
from: Option<TraversalEdge<'gram>>,
}
impl<'gram> Traversal<'gram> {
    /// Returns the next term this traversal expects to match, or `None` once
    /// every term of the right-hand side has been matched.
    pub const fn next_unmatched(&self) -> Option<&'gram Term> {
        let remaining: &'gram [Term] = self.unmatched;
        match remaining {
            [head, ..] => Some(head),
            [] => None,
        }
    }
}
/// Link from a traversal to the parent it was derived from by matching `term`.
///
/// `TraversalTree` also uses it as the key that deduplicates matches.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
struct TraversalEdge<'gram> {
    /// Term the parent matched to produce the child.
    pub term: TermMatch<'gram>,
    /// Traversal the match was made from.
    pub parent_id: TraversalId,
}
/// Key that deduplicates predicted (root) traversals: at most one root per
/// production per input start.
#[derive(Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct TraversalRoot {
    /// Production being predicted.
    production_id: ProductionId,
    /// Input position of the prediction (`TraversalTree` fills this from
    /// `input_range.offset.total_len()`).
    input_start: usize,
}
/// Index from `(matched term, parent)` edges to the child traversal they produced.
type TreeEdgeMap<'gram> = crate::HashMap<TraversalEdge<'gram>, TraversalId>;
/// Index from `(production, input start)` keys to the predicted root traversal.
type TreeRootMap = crate::HashMap<TraversalRoot, TraversalId>;
/// Append-only storage for every traversal, addressed by `TraversalId`.
type TraversalArena<'gram> = AppendOnlyVec<Traversal<'gram>, TraversalId>;
/// Iterator over the terms matched along a single traversal path, yielded in
/// order from the path's root toward its final traversal.
#[derive(Debug)]
pub(crate) struct TraversalMatchIter<'gram, 'tree> {
    /// Tree that owns every traversal on the path.
    tree: &'tree TraversalTree<'gram>,
    /// Traversal reached so far; starts at the root of the path.
    current: TraversalId,
    /// Traversal the path ends at; iteration stops once `current` reaches it.
    last: TraversalId,
}
impl<'gram, 'tree> TraversalMatchIter<'gram, 'tree> {
    /// Creates an iterator over the terms matched on the path ending at `last`.
    ///
    /// Parent edges are followed from `last` back to the root (the traversal
    /// without a `from` edge); iteration then proceeds from that root toward
    /// `last`.
    ///
    /// # Panics
    ///
    /// Panics if an id on the path is not present in `tree`.
    pub fn new(last: TraversalId, tree: &'tree TraversalTree<'gram>) -> Self {
        let _span = tracing::span!(DEBUG, "match_iter_new").entered();
        let mut root = last;
        while let Some(parent_id) = tree.get(root).from.as_ref().map(|edge| edge.parent_id) {
            root = parent_id;
        }
        Self {
            tree,
            current: root,
            last,
        }
    }
}
impl<'gram, 'tree> Iterator for TraversalMatchIter<'gram, 'tree> {
type Item = &'tree TermMatch<'gram>;
fn next(&mut self) -> Option<Self::Item> {
let _span = tracing::span!(DEBUG, "match_iter_next").entered();
if self.current == self.last {
return None;
}
let mut scan = self.last;
while let Some(edge) = &self.tree.get(scan).from {
if self.current == edge.parent_id {
self.current = scan;
return Some(&edge.term);
}
scan = edge.parent_id;
}
None
}
}
/// Arena of traversals together with the indexes that deduplicate them.
#[derive(Debug)]
#[derive(Default)]
pub(crate) struct TraversalTree<'gram> {
    /// Storage for every traversal created by this tree.
    arena: TraversalArena<'gram>,
    /// Predicted roots, keyed by production and input start.
    tree_roots: TreeRootMap,
    /// Matched children, keyed by the edge (term + parent) that produced them.
    edges: TreeEdgeMap<'gram>,
}
impl<'gram> TraversalTree<'gram> {
    /// Returns the traversal stored under `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not present in this tree's arena.
    pub fn get(&self, id: TraversalId) -> &Traversal<'gram> {
        self.arena.get(id).expect("valid traversal ID")
    }

    /// Returns the next term traversal `id` still has to match, or `None` once
    /// every term of its right-hand side has been matched.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not present in this tree's arena.
    pub fn get_matching(&self, id: TraversalId) -> Option<&'gram Term> {
        self.get(id).next_unmatched()
    }

    /// Iterates over the terms already matched on the path ending at `id`, in
    /// the order they were matched (from the predicted root toward `id`).
    pub fn get_matched(&self, id: TraversalId) -> impl Iterator<Item = &TermMatch<'gram>> {
        TraversalMatchIter::new(id, self)
    }

    /// Shared implementation of [`Self::predict`] and [`Self::predict_starting`].
    ///
    /// Roots are deduplicated by production and `input_range.offset.total_len()`.
    /// If such a root already exists, its id is returned and nothing is created.
    /// `is_starting` is not part of that key, so the flag on a shared root is
    /// whatever the call that first created it passed.
    fn predict_is_starting(
        &mut self,
        production: &Production<'gram>,
        input_range: &InputRange<'gram>,
        is_starting: bool,
    ) -> TraversalId {
        let _span = tracing::span!(DEBUG, "traversal_tree_predict_is_starting").entered();
        let production_id = production.id;
        let traversal_root_key = TraversalRoot {
            production_id,
            input_start: input_range.offset.total_len(),
        };
        *self
            .tree_roots
            .entry(traversal_root_key)
            .or_insert_with(|| {
                let traversal = self.arena.push_with_id(|id| Traversal {
                    id,
                    production_id,
                    // Nothing matched yet: the whole right-hand side remains.
                    unmatched: &production.rhs.terms,
                    input_range: input_range.after(),
                    is_starting,
                    from: None,
                });
                traversal.id
            })
    }

    /// Predicts a root traversal for `production` at the position described by
    /// `input_range`, flagged as a starting traversal.
    ///
    /// Returns the existing root if one was already predicted for this
    /// production at this position.
    pub fn predict_starting(
        &mut self,
        production: &Production<'gram>,
        input_range: &InputRange<'gram>,
    ) -> TraversalId {
        self.predict_is_starting(production, input_range, true)
    }

    /// Predicts a root traversal for `production` at the position described by
    /// `input_range`, not flagged as starting.
    ///
    /// Returns the existing root if one was already predicted for this
    /// production at this position.
    pub fn predict(
        &mut self,
        production: &Production<'gram>,
        input_range: &InputRange<'gram>,
    ) -> TraversalId {
        self.predict_is_starting(production, input_range, false)
    }

    /// Records that traversal `parent` matched `term`, returning the child
    /// traversal that continues with one fewer unmatched term.
    ///
    /// The child's input range is the parent's advanced by the terminal's byte
    /// length, or by `input_range.offset.len` of the traversal a nonterminal
    /// refers to. Children are deduplicated by `(term, parent)`, so repeating a
    /// match returns the same id.
    ///
    /// # Panics
    ///
    /// Panics if `parent` or a nonterminal's traversal id is not in this tree,
    /// or if `parent` has no unmatched term left.
    pub fn match_term(&mut self, parent: TraversalId, term: TermMatch<'gram>) -> TraversalId {
        let _span = tracing::span!(DEBUG, "match_term").entered();
        let parent = self.arena.get(parent).expect("valid parent traversal ID");
        let input_range = match term {
            TermMatch::Terminal(term) => parent.input_range.advance_by(term.len()),
            TermMatch::Nonterminal(nonterminal_traversal_id) => {
                let nonterminal_traversal = self
                    .arena
                    .get(nonterminal_traversal_id)
                    .expect("valid completed traversal ID");
                parent
                    .input_range
                    .advance_by(nonterminal_traversal.input_range.offset.len)
            }
        };
        let parent_id = parent.id;
        let production_id = parent.production_id;
        // The matched term is dropped from the front of the remaining terms.
        let unmatched = parent
            .unmatched
            .get(1..)
            .expect("parent traversal has at least one unmatched term");
        let is_starting = parent.is_starting;
        let from = TraversalEdge { term, parent_id };
        *self.edges.entry(from).or_insert_with_key(|from| {
            let traversal = self.arena.push_with_id(|id| Traversal {
                id,
                production_id,
                unmatched,
                input_range,
                is_starting,
                // The map owns the key, so the child stores its own copy.
                from: Some(from.clone()),
            });
            traversal.id
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Grammar;
    use crate::parser::grammar::ParseGrammar;

    /// Two-nonterminal DNA grammar used by the prediction tests.
    fn dna_grammar() -> Grammar {
        "<dna> ::= <base> | <base> <dna>\n<base> ::= 'A' | 'C' | 'G' | 'T'"
            .parse()
            .unwrap()
    }

    /// Builds the parse grammar, input range and empty tree shared by the tests.
    fn traversal_test_setup<'a>(
        grammar: &'a Grammar,
        input: &'static str,
    ) -> (ParseGrammar<'a>, InputRange<'static>, TraversalTree<'a>) {
        let matching = ParseGrammar::new(grammar).unwrap();
        let input = InputRange::new(input);
        let tree = TraversalTree::default();
        (matching, input, tree)
    }

    #[test]
    fn predict() {
        let grammar = dna_grammar();
        let (grammar, input, mut tree) = traversal_test_setup(&grammar, "GATTACA");
        let production = grammar.productions_iter().next().unwrap();
        let id = tree.predict(production, &input);
        let predicted = tree.get(id);
        assert_eq!(&production.rhs.terms, predicted.unmatched);
        assert!(!predicted.is_starting);
    }

    #[test]
    fn predict_starting_sets_flag() {
        let grammar = dna_grammar();
        let (grammar, input, mut tree) = traversal_test_setup(&grammar, "GATTACA");
        let production = grammar.productions_iter().next().unwrap();
        let id = tree.predict_starting(production, &input);
        let predicted = tree.get(id);
        assert_eq!(&production.rhs.terms, predicted.unmatched);
        assert!(predicted.is_starting);
    }

    #[test]
    fn predict_again() {
        let grammar = dna_grammar();
        let (grammar, input, mut tree) = traversal_test_setup(&grammar, "GATTACA");
        let production = grammar.productions_iter().next().unwrap();
        let first = tree.predict(production, &input);
        let again = tree.predict(production, &input);
        assert_eq!(first, again);
    }

    #[test]
    fn match_term() {
        let grammar = "<start> ::= 'A'".parse().unwrap();
        let (grammar, input, mut tree) = traversal_test_setup(&grammar, "AAAA");
        let production = grammar.productions_iter().next().unwrap();
        let prediction = tree.predict(production, &input);
        let term_match = TermMatch::Terminal("A");
        let id = tree.match_term(prediction, term_match);
        let scanned = tree.get(id);
        assert_eq!(scanned.unmatched, production.rhs.terms.get(1..).unwrap());
    }

    #[test]
    fn match_term_again() {
        let grammar = "<start> ::= 'A'".parse().unwrap();
        let (grammar, input, mut tree) = traversal_test_setup(&grammar, "AAAA");
        let production = grammar.productions_iter().next().unwrap();
        let prediction = tree.predict(production, &input);
        let term_match = TermMatch::Terminal("A");
        let first = tree.match_term(prediction, term_match.clone());
        let again = tree.match_term(prediction, term_match.clone());
        assert_eq!(first, again);
    }

    #[test]
    fn match_term_distinct_terms_are_distinct_traversals() {
        let grammar = "<start> ::= 'A' | 'B'".parse().unwrap();
        let (grammar, input, mut tree) = traversal_test_setup(&grammar, "ABAB");
        let production = grammar.productions_iter().next().unwrap();
        let prediction = tree.predict(production, &input);
        let a = tree.match_term(prediction, TermMatch::Terminal("A"));
        let b = tree.match_term(prediction, TermMatch::Terminal("B"));
        assert_ne!(a, b);
    }

    #[test]
    fn match_term_complete() {
        let grammar = "<start> ::= 'A' | 'B' | 'C'".parse().unwrap();
        let (grammar, input, mut tree) = traversal_test_setup(&grammar, "ABC");
        let production = grammar.productions_iter().next().unwrap();
        let prediction = tree.predict(production, &input);
        for term_match in ["A", "B", "C"] {
            let term_match = TermMatch::Terminal(term_match);
            let id = tree.match_term(prediction, term_match);
            let traversal = tree.get(id);
            assert_eq!(traversal.next_unmatched(), None);
        }
    }

    #[test]
    fn get_matched_is_empty_for_prediction() {
        let grammar = "<start> ::= 'A'".parse().unwrap();
        let (grammar, input, mut tree) = traversal_test_setup(&grammar, "AAAA");
        let production = grammar.productions_iter().next().unwrap();
        let prediction = tree.predict(production, &input);
        assert_eq!(tree.get_matched(prediction).count(), 0);
    }

    #[test]
    fn get_matched_yields_terms_in_match_order() {
        let grammar = "<start> ::= 'A' 'B' 'C'".parse().unwrap();
        let (grammar, input, mut tree) = traversal_test_setup(&grammar, "ABCD");
        let production = grammar.productions_iter().next().unwrap();
        let prediction = tree.predict(production, &input);
        let a = tree.match_term(prediction, TermMatch::Terminal("A"));
        let b = tree.match_term(a, TermMatch::Terminal("B"));
        let c = tree.match_term(b, TermMatch::Terminal("C"));
        let matched: Vec<TermMatch<'_>> = tree.get_matched(c).cloned().collect();
        assert_eq!(
            matched,
            vec![
                TermMatch::Terminal("A"),
                TermMatch::Terminal("B"),
                TermMatch::Terminal("C"),
            ]
        );
        assert_eq!(tree.get_matching(c), None);
    }
}