use std::collections::VecDeque;
use std::fmt;
use std::hash::{Hash, Hasher};
use rustc_hash::{FxHashMap, FxHashSet};
use smallvec::SmallVec;
use super::forest::{ForestChild, ForestNode, ForestNodeId, ParseForest};
use super::grammar::Grammar;
use super::types::{NonTerminal, RuleId, Symbol, Terminal};
use crate::backend::LatticeBackend;
use crate::lattice::{EdgeId, Lattice, NodeId};
use crate::semiring::Semiring;
/// A dotted Earley item `[rule, dot, start]` together with derivation data gathered while
/// the item was advanced.
///
/// Equality and hashing consider only `rule`, `dot`, and `start`; the remaining fields are
/// bookkeeping and do not affect state identity.
#[derive(Clone, Debug)]
pub struct EarleyState {
/// Production this item is recognizing.
pub rule: RuleId,
/// Number of right-hand-side symbols already matched (position of the dot).
pub dot: usize,
/// Lattice node at which recognition of `rule` began.
pub start: NodeId,
/// Optional forest node attached to this item. `with_forest` presets it, and
/// `advance_with_nonterminal` replaces it with the completed child's node.
pub forest_node: Option<ForestNodeId>,
/// Lattice edges consumed by terminal scans so far.
pub terminal_edges: SmallVec<[EdgeId; 4]>,
/// Forest children collected so far, appended as the dot advances over terminals and
/// completed nonterminals.
pub child_nodes: SmallVec<[ForestChild; 4]>,
}
impl PartialEq for EarleyState {
    /// Compares only the item identity `(rule, dot, start)`; forest node, consumed edges,
    /// and children are ignored.
    fn eq(&self, other: &Self) -> bool {
        (&self.rule, &self.dot, &self.start) == (&other.rule, &other.dot, &other.start)
    }
}

impl Eq for EarleyState {}

impl Hash for EarleyState {
    /// Hashes the same `(rule, dot, start)` triple used by `eq`, in that order, so the
    /// `Hash`/`Eq` contract holds.
    fn hash<H: Hasher>(&self, state: &mut H) {
        (&self.rule, &self.dot, &self.start).hash(state);
    }
}
impl EarleyState {
    /// Builds an item for `rule` with the dot at `dot`, begun at lattice node `start`, with
    /// no forest node, consumed edges, or children.
    pub fn new(rule: RuleId, dot: usize, start: NodeId) -> Self {
        Self {
            rule,
            dot,
            start,
            forest_node: None,
            terminal_edges: SmallVec::new(),
            child_nodes: SmallVec::new(),
        }
    }

    /// Same as [`EarleyState::new`] but with `forest_node` preset to `Some(forest)`.
    pub fn with_forest(rule: RuleId, dot: usize, start: NodeId, forest: ForestNodeId) -> Self {
        Self {
            forest_node: Some(forest),
            ..Self::new(rule, dot, start)
        }
    }

    /// Returns a copy of this item with the dot moved one symbol to the right; all
    /// bookkeeping fields are carried over unchanged.
    pub fn advance(&self) -> Self {
        let mut next = self.clone();
        next.dot += 1;
        next
    }

    /// Advances the dot over a terminal matched by lattice edge `edge_id`, recording the
    /// edge both in `terminal_edges` and as a terminal forest child.
    pub fn advance_with_terminal(&self, edge_id: EdgeId) -> Self {
        let mut next = self.advance();
        next.terminal_edges.push(edge_id);
        next.child_nodes.push(ForestChild::Terminal(edge_id));
        next
    }

    /// Advances the dot over a completed nonterminal whose forest node is
    /// `child_forest_node`, appending it as a single-alternative derivation child.
    ///
    /// NOTE(review): unlike the other advance methods, this replaces `forest_node` with the
    /// child's node instead of keeping this item's own one. Confirm this is intended.
    pub fn advance_with_nonterminal(&self, child_forest_node: ForestNodeId) -> Self {
        let mut next = self.advance();
        next.forest_node = Some(child_forest_node);
        next.child_nodes
            .push(ForestChild::Derivation(smallvec::smallvec![child_forest_node]));
        next
    }

    /// Whether the dot has reached (or passed) the end of the production's right-hand side.
    ///
    /// # Panics
    ///
    /// Panics if `self.rule` is not a production of `grammar`.
    pub fn is_complete(&self, grammar: &Grammar) -> bool {
        let rhs_len = grammar.production(self.rule).expect("valid rule").rhs_len();
        self.dot >= rhs_len
    }

    /// The symbol right after the dot, or `None` if the item is complete or the rule is
    /// unknown to `grammar`.
    pub fn next_symbol<'a>(&self, grammar: &'a Grammar) -> Option<&'a Symbol> {
        grammar
            .production(self.rule)
            .and_then(|prod| prod.rhs_at(self.dot))
    }

    /// Left-hand-side nonterminal of this item's production.
    ///
    /// # Panics
    ///
    /// Panics if `self.rule` is not a production of `grammar`.
    pub fn lhs(&self, grammar: &Grammar) -> NonTerminal {
        let prod = grammar.production(self.rule).expect("valid rule");
        prod.lhs
    }
}
impl fmt::Display for EarleyState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"[{}, {}, {}, {:?}, edges: {}]",
self.rule,
self.dot,
self.start.0,
self.forest_node,
self.terminal_edges.len()
)
}
}
/// Identity of an Earley item (production, dot position, start node), used by the chart to
/// deduplicate states at each lattice node.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct EarleyKey {
rule: RuleId,
dot: usize,
start: NodeId,
}
impl From<&EarleyState> for EarleyKey {
fn from(state: &EarleyState) -> Self {
Self {
rule: state.rule,
dot: state.dot,
start: state.start,
}
}
}
/// Earley chart indexed by lattice node, plus the agenda of items still to be processed.
#[derive(Clone, Debug)]
pub struct EarleyChart {
/// States stored per lattice node, deduplicated by `(rule, dot, start)`.
positions: FxHashMap<NodeId, FxHashMap<EarleyKey, EarleyState>>,
/// FIFO queue of `(node, state)` pairs awaiting processing. Each entry is a clone taken
/// when the state was first inserted, so later merges into the chart copy are not seen here.
agenda: VecDeque<(NodeId, EarleyState)>,
/// Every `(node, key)` pair ever queued, so each item is enqueued at most once.
processed: FxHashSet<(NodeId, EarleyKey)>,
}
impl EarleyChart {
pub fn new() -> Self {
Self {
positions: FxHashMap::default(),
agenda: VecDeque::new(),
processed: FxHashSet::default(),
}
}
pub fn add(&mut self, pos: NodeId, state: EarleyState) -> bool {
let key = EarleyKey::from(&state);
let map = self.positions.entry(pos).or_default();
if let Some(existing) = map.get_mut(&key) {
for child in state.child_nodes {
if !existing.child_nodes.contains(&child) {
existing.child_nodes.push(child);
}
}
for edge in state.terminal_edges {
if !existing.terminal_edges.contains(&edge) {
existing.terminal_edges.push(edge);
}
}
if state.forest_node.is_some() && existing.forest_node.is_none() {
existing.forest_node = state.forest_node;
}
false } else {
map.insert(key.clone(), state.clone());
if !self.processed.contains(&(pos, key.clone())) {
self.processed.insert((pos, key));
self.agenda.push_back((pos, state));
}
true
}
}
pub fn at(&self, pos: NodeId) -> impl Iterator<Item = &EarleyState> {
self.positions
.get(&pos)
.into_iter()
.flat_map(|s| s.values())
}
pub fn pop(&mut self) -> Option<(NodeId, EarleyState)> {
self.agenda.pop_front()
}
pub fn is_agenda_empty(&self) -> bool {
self.agenda.is_empty()
}
pub fn len(&self) -> usize {
self.positions.values().map(|s| s.len()).sum()
}
pub fn is_empty(&self) -> bool {
self.positions.is_empty()
}
}
impl Default for EarleyChart {
fn default() -> Self {
Self::new()
}
}
/// Failures reported by the Earley parser.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ParseError {
    /// No complete parse of the lattice was found.
    NoParse,
    /// The lattice has no edges although its start and end nodes differ.
    EmptyLattice,
    /// A grammar-related failure, carrying a human-readable description.
    GrammarError(String),
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let fixed = match self {
            ParseError::NoParse => "no complete parse found",
            ParseError::EmptyLattice => "empty lattice",
            ParseError::GrammarError(msg) => return write!(f, "grammar error: {}", msg),
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for ParseError {}
/// Earley parser that recognizes a [`Grammar`] over a word lattice and builds a
/// [`ParseForest`].
pub struct EarleyParser<'g> {
/// Grammar being parsed, borrowed for the parser's lifetime.
grammar: &'g Grammar,
/// Nullability flags from `Grammar::compute_nullable`, indexed by `NonTerminal::index()`.
nullable: Vec<bool>,
}
impl<'g> EarleyParser<'g> {
pub fn new(grammar: &'g Grammar) -> Self {
let nullable = grammar.compute_nullable();
Self { grammar, nullable }
}
pub fn parse_lattice<W: Semiring, B: LatticeBackend>(
&self,
lattice: &Lattice<W, B>,
) -> Result<ParseForest, ParseError> {
if lattice.is_empty() && lattice.start() != lattice.end() {
return Err(ParseError::EmptyLattice);
}
let mut chart = EarleyChart::new();
let mut forest = ParseForest::new();
let start = lattice.start();
self.predict(&mut chart, start, self.grammar.start());
while let Some((pos, state)) = chart.pop() {
if state.is_complete(self.grammar) {
self.complete(&mut chart, &mut forest, lattice, pos, &state);
} else {
let next_sym = state.next_symbol(self.grammar);
match next_sym {
Some(Symbol::NonTerminal(nt)) => {
self.predict(&mut chart, pos, *nt);
if self.nullable[nt.index() as usize] {
let advanced = state.advance();
chart.add(pos, advanced);
}
}
Some(Symbol::Terminal(terminal)) => {
self.scan(&mut chart, lattice, pos, &state, *terminal);
}
Some(Symbol::Epsilon) => {
let advanced = state.advance();
chart.add(pos, advanced);
}
None => {
}
}
}
}
if !forest.is_empty() {
Ok(forest)
} else {
Err(ParseError::NoParse)
}
}
fn predict(&self, chart: &mut EarleyChart, pos: NodeId, nt: NonTerminal) {
for prod in self.grammar.productions_for(nt) {
let state = EarleyState::new(prod.id, 0, pos);
chart.add(pos, state);
}
}
fn scan<W: Semiring, B: LatticeBackend>(
&self,
chart: &mut EarleyChart,
lattice: &Lattice<W, B>,
pos: NodeId,
state: &EarleyState,
terminal: Terminal,
) {
for edge in lattice.outgoing_edges(pos) {
if edge.label == terminal.vocab_id() {
let advanced = state.advance_with_terminal(edge.id);
chart.add(edge.target, advanced);
}
}
}
fn complete<W: Semiring, B: LatticeBackend>(
&self,
chart: &mut EarleyChart,
forest: &mut ParseForest,
lattice: &Lattice<W, B>,
pos: NodeId,
completed: &EarleyState,
) {
let completed_nt = completed.lhs(self.grammar);
let mut node = ForestNode::new(completed.rule, completed.start, pos);
node.children = completed.child_nodes.clone();
let forest_node = forest.add_node(node);
if completed_nt == self.grammar.start()
&& completed.start == lattice.start()
&& pos == lattice.end()
{
forest.add_root(forest_node);
}
let waiting: Vec<_> = chart.at(completed.start)
.filter(|s| {
!s.is_complete(self.grammar) &&
matches!(s.next_symbol(self.grammar), Some(Symbol::NonTerminal(nt)) if *nt == completed_nt)
})
.cloned()
.collect();
for waiter in waiting {
let advanced = waiter.advance_with_nonterminal(forest_node);
chart.add(pos, advanced);
}
}
pub fn accepts<W: Semiring, B: LatticeBackend>(&self, lattice: &Lattice<W, B>) -> bool {
self.parse_lattice(lattice).is_ok()
}
pub fn grammar(&self) -> &Grammar {
self.grammar
}
}
// Unit tests for the Earley item/chart types and end-to-end lattice parsing.
#[cfg(test)]
mod tests {
use super::*;
use crate::backend::HashMapBackend;
use crate::cfg::GrammarBuilder;
use crate::lattice::{EdgeMetadata, LatticeBuilder};
use crate::semiring::TropicalWeight;
// Toy English grammar: S -> NP VP, NP -> Det N, VP -> V NP | V, plus a tiny lexicon.
fn simple_grammar() -> Grammar {
GrammarBuilder::new()
.start("S")
.rule("S", &["NP", "VP"])
.rule("NP", &["Det", "N"])
.rule("VP", &["V", "NP"])
.rule("VP", &["V"])
.rule("Det", &["the"])
.rule("Det", &["a"])
.rule("N", &["dog"])
.rule("N", &["cat"])
.rule("V", &["saw"])
.rule("V", &["chased"])
.build()
.expect("valid grammar")
}
// Builds a linear lattice with one edge per word (node i -> i + 1), labelled with the
// word's grammar vocab id. Panics if a word is not a terminal of `grammar`.
// NOTE(review): `expect(&format!(..))` builds the message even on success
// (clippy::expect_fun_call); consider a lazily formatted panic instead.
fn build_lattice(words: &[&str], grammar: &Grammar) -> Lattice<TropicalWeight, HashMapBackend> {
let mut backend = HashMapBackend::new();
let word_ids: Vec<_> = words
.iter()
.map(|w| {
let t = grammar
.terminal_by_name(w)
.expect(&format!("unknown word: {}", w));
let _id = backend.intern(w);
t.vocab_id()
})
.collect();
let mut builder: LatticeBuilder<TropicalWeight, _> = LatticeBuilder::new(backend);
for (i, &id) in word_ids.iter().enumerate() {
builder.add_correction_by_id(
i,
i + 1,
id,
TropicalWeight::one(),
EdgeMetadata::default(),
);
}
builder.build(words.len())
}
// Constructor stores the identity triple and `advance` moves the dot by one.
#[test]
fn test_earley_state() {
let state = EarleyState::new(RuleId::new(0), 1, NodeId(0));
assert_eq!(state.rule, RuleId::new(0));
assert_eq!(state.dot, 1);
assert_eq!(state.start, NodeId(0));
assert!(state.forest_node.is_none());
let advanced = state.advance();
assert_eq!(advanced.dot, 2);
}
// Duplicate items are rejected by `add` and the first one is queued on the agenda.
#[test]
fn test_earley_chart() {
let mut chart = EarleyChart::new();
assert!(chart.is_empty());
let state = EarleyState::new(RuleId::new(0), 0, NodeId(0));
assert!(chart.add(NodeId(0), state.clone()));
assert!(!chart.add(NodeId(0), state.clone()));
assert_eq!(chart.len(), 1);
assert!(!chart.is_agenda_empty());
let (pos, s) = chart.pop().expect("item");
assert_eq!(pos, NodeId(0));
assert_eq!(s.rule, RuleId::new(0));
}
// Transitive sentence: VP -> V NP.
#[test]
fn test_parse_simple_sentence() {
let grammar = simple_grammar();
let parser = EarleyParser::new(&grammar);
let lattice = build_lattice(&["the", "dog", "saw", "a", "cat"], &grammar);
let result = parser.parse_lattice(&lattice);
assert!(result.is_ok(), "Parse should succeed: {:?}", result);
}
// Intransitive sentence: VP -> V.
#[test]
fn test_parse_intransitive() {
let grammar = simple_grammar();
let parser = EarleyParser::new(&grammar);
let lattice = build_lattice(&["the", "dog", "saw"], &grammar);
let result = parser.parse_lattice(&lattice);
assert!(result.is_ok(), "Parse should succeed: {:?}", result);
}
// "saw the" cannot start an NP, so parsing must fail with NoParse.
#[test]
fn test_parse_failure() {
let grammar = simple_grammar();
let parser = EarleyParser::new(&grammar);
let mut backend = HashMapBackend::new();
let _saw_interned = backend.intern("saw");
let _the_interned = backend.intern("the");
let saw_id = grammar.terminal_by_name("saw").expect("saw").vocab_id();
let the_id = grammar.terminal_by_name("the").expect("the").vocab_id();
let mut builder: LatticeBuilder<TropicalWeight, _> = LatticeBuilder::new(backend);
builder.add_correction_by_id(0, 1, saw_id, TropicalWeight::one(), EdgeMetadata::default());
builder.add_correction_by_id(1, 2, the_id, TropicalWeight::one(), EdgeMetadata::default());
let lattice = builder.build(2);
let result = parser.parse_lattice(&lattice);
assert!(matches!(result, Err(ParseError::NoParse)));
}
// `accepts` mirrors a successful `parse_lattice`.
#[test]
fn test_accepts() {
let grammar = simple_grammar();
let parser = EarleyParser::new(&grammar);
let lattice = build_lattice(&["the", "dog", "saw"], &grammar);
assert!(parser.accepts(&lattice));
}
// S -> A B with A nullable: the input "b" alone must parse via the empty A.
#[test]
fn test_nullable_handling() {
let grammar = GrammarBuilder::new()
.start("S")
.rule("S", &["A", "B"])
.epsilon_rule("A")
.rule("A", &["a"])
.rule("B", &["b"])
.build()
.expect("valid grammar");
let parser = EarleyParser::new(&grammar);
let mut backend = HashMapBackend::new();
let _b_interned = backend.intern("b");
let b_id = grammar.terminal_by_name("b").expect("b").vocab_id();
let mut builder: LatticeBuilder<TropicalWeight, _> = LatticeBuilder::new(backend);
builder.add_correction_by_id(0, 1, b_id, TropicalWeight::one(), EdgeMetadata::default());
let lattice = builder.build(1);
let result = parser.parse_lattice(&lattice);
assert!(
result.is_ok(),
"Parse with nullable should succeed: {:?}",
result
);
}
}