use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use super::ast::{Grammar, NonTerminalId, RuleId, Symbol};
/// A dotted grammar rule stored in one set of the Earley chart.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EarleyItem {
/// Identifier of the rule being tracked; the recognizer uses it to index `grammar.rules`.
pub rule: RuleId,
/// Dot position: number of right-hand-side symbols already matched.
pub dot: usize,
/// Index of the chart set in which this item was first predicted.
pub origin: usize,
}
impl EarleyItem {
#[inline]
pub fn new(rule: RuleId, dot: usize, origin: usize) -> Self {
Self { rule, dot, origin }
}
#[inline]
pub fn is_complete(&self, rhs_len: usize) -> bool {
self.dot >= rhs_len
}
}
/// Per-non-terminal FIRST byte sets and nullability flags for a grammar.
pub struct FirstSets {
/// For each non-terminal, the set of bytes that can begin a string derived from it.
pub first: HashMap<NonTerminalId, HashSet<u8>>,
/// For each non-terminal, whether it can derive the empty string.
pub nullable: HashMap<NonTerminalId, bool>,
}
impl FirstSets {
/// Computes FIRST sets and nullability for every non-terminal `0..grammar.nt_count`
/// by iterating over all rules until nothing changes any more.
///
/// A terminal contributes its first byte; an empty terminal is treated as
/// epsilon and lets the scan continue to the next symbol. A non-terminal
/// contributes its current FIRST set and lets the scan continue only while it
/// is known to be nullable.
///
/// # Panics
///
/// Panics if a rule's `lhs` is not one of the keys created during
/// initialization and that entry has to be read or updated.
pub fn compute(grammar: &Grammar) -> Self {
    let nt_total = grammar.nt_count;
    let mut first: HashMap<NonTerminalId, HashSet<u8>> = HashMap::new();
    let mut nullable: HashMap<NonTerminalId, bool> = HashMap::new();
    for id in 0..nt_total {
        first.insert(id, HashSet::new());
        nullable.insert(id, false);
    }

    let mut changed = true;
    while changed {
        changed = false;
        for rule in &grammar.rules {
            let lhs = rule.lhs;
            // Stays true only while every symbol scanned so far can vanish;
            // an empty right-hand side therefore counts as nullable.
            let mut derives_empty = true;
            for sym in &rule.rhs {
                // Bytes this symbol contributes, and whether the scan may
                // continue past it.
                let (leading, can_vanish): (Vec<u8>, bool) = match sym {
                    Symbol::Terminal(bytes) => (
                        bytes.first().copied().into_iter().collect(),
                        bytes.is_empty(),
                    ),
                    Symbol::NonTerminal(nt) => (
                        first
                            .get(nt)
                            .map(|set| set.iter().copied().collect())
                            .unwrap_or_default(),
                        nullable.get(nt).copied().unwrap_or(false),
                    ),
                };
                for b in leading {
                    let grew = first
                        .get_mut(&lhs)
                        .expect("lhs key inserted during initialization")
                        .insert(b);
                    if grew {
                        changed = true;
                    }
                }
                if !can_vanish {
                    derives_empty = false;
                    break;
                }
            }
            if derives_empty && !nullable[&lhs] {
                nullable.insert(lhs, true);
                changed = true;
            }
        }
    }

    Self { first, nullable }
}
/// Returns the FIRST byte set of a single symbol.
///
/// A terminal yields its first byte (or the empty set when the terminal is
/// empty); a non-terminal yields a copy of its precomputed set, or the empty
/// set when it has no entry.
pub fn first_of_symbol(&self, sym: &Symbol) -> HashSet<u8> {
    match sym {
        Symbol::NonTerminal(nt) => self.first.get(nt).cloned().unwrap_or_default(),
        Symbol::Terminal(bytes) => bytes.first().copied().into_iter().collect(),
    }
}
}
/// Incremental, byte-at-a-time Earley recognizer over a shared grammar.
pub struct EarleyRecognizer {
// Grammar being recognized; shared via `Arc` so `clone_state` does not copy it.
grammar: Arc<Grammar>,
// FIRST/nullable tables computed once from `grammar` in `new`.
first_sets: Arc<FirstSets>,
// Earley chart: `chart[i]` is the item set reached after `i` calls to `feed_byte`.
chart: Vec<HashSet<EarleyItem>>,
/// Number of bytes fed so far; also the index of the current chart set.
pub input_pos: usize,
// Rule ids grouped by left-hand-side non-terminal, used for prediction.
rule_index: Arc<HashMap<NonTerminalId, Vec<RuleId>>>,
}
impl EarleyRecognizer {
/// Creates a recognizer for `grammar`, positioned before any input.
///
/// The FIRST/nullable tables and the rule index are computed once here, and
/// chart set 0 is seeded with the start symbol's rules and closed.
pub fn new(grammar: Arc<Grammar>) -> Self {
    let first_sets = Arc::new(FirstSets::compute(&grammar));
    let rule_index = Arc::new(build_rule_index(&grammar));
    let mut this = Self {
        chart: vec![HashSet::new()],
        input_pos: 0,
        grammar,
        first_sets,
        rule_index,
    };
    this.init_chart_zero();
    this
}
fn init_chart_zero(&mut self) {
let start = self.grammar.start();
if let Some(rule_ids) = self.rule_index.get(&start).cloned() {
for rule_id in rule_ids {
self.chart[0].insert(EarleyItem::new(rule_id, 0, 0));
}
}
self.closure(0);
}
fn closure(&mut self, k: usize) {
let mut worklist: VecDeque<EarleyItem> = self.chart[k].iter().cloned().collect();
while let Some(item) = worklist.pop_front() {
let rule = &self.grammar.rules[item.rule];
let rhs_len = rule.rhs.len();
if item.dot >= rhs_len {
let completed_nt = rule.lhs;
let origin = item.origin;
let origin_items: Vec<EarleyItem> = self.chart[origin].iter().cloned().collect();
for orig_item in origin_items {
let orig_rule = &self.grammar.rules[orig_item.rule];
let orig_rhs_len = orig_rule.rhs.len();
if orig_item.dot < orig_rhs_len {
if let Symbol::NonTerminal(nt) = &orig_rule.rhs[orig_item.dot] {
if *nt == completed_nt {
let advanced = EarleyItem::new(
orig_item.rule,
orig_item.dot + 1,
orig_item.origin,
);
if self.chart[k].insert(advanced.clone()) {
worklist.push_back(advanced);
}
}
}
}
}
} else {
match &self.grammar.rules[item.rule].rhs[item.dot] {
Symbol::NonTerminal(nt) => {
let nt = *nt;
if let Some(rule_ids) = self.rule_index.get(&nt).cloned() {
for rule_id in rule_ids {
let new_item = EarleyItem::new(rule_id, 0, k);
if self.chart[k].insert(new_item.clone()) {
worklist.push_back(new_item);
}
}
}
}
Symbol::Terminal(bytes) => {
if bytes.is_empty() {
let advanced = EarleyItem::new(item.rule, item.dot + 1, item.origin);
if self.chart[k].insert(advanced.clone()) {
worklist.push_back(advanced);
}
}
}
}
}
}
}
pub fn feed_byte(&mut self, byte: u8) -> bool {
let k = self.input_pos;
let mut next_set: HashSet<EarleyItem> = HashSet::new();
for item in &self.chart[k] {
let rule = &self.grammar.rules[item.rule];
let rhs_len = rule.rhs.len();
if item.dot < rhs_len {
if let Symbol::Terminal(bytes) = &rule.rhs[item.dot] {
if bytes.len() == 1 && bytes[0] == byte {
next_set.insert(EarleyItem::new(item.rule, item.dot + 1, item.origin));
}
}
}
}
if next_set.is_empty() {
self.chart.push(HashSet::new());
self.input_pos += 1;
return false;
}
self.chart.push(next_set);
self.input_pos += 1;
self.closure(self.input_pos);
true
}
/// Returns `true` if the input consumed so far is a complete sentence, i.e.
/// the current set holds a finished rule of the start symbol with origin 0.
pub fn is_accepting(&self) -> bool {
    let goal = self.grammar.start();
    for item in &self.chart[self.input_pos] {
        if item.origin != 0 {
            continue;
        }
        let rule = &self.grammar.rules[item.rule];
        if rule.lhs == goal && item.dot == rule.rhs.len() {
            return true;
        }
    }
    false
}
/// Returns `true` while the current chart set still contains at least one item.
pub fn is_live(&self) -> bool {
    let current = &self.chart[self.input_pos];
    !current.is_empty()
}
/// Collects the bytes that items in the current set are waiting for.
///
/// An item before a single-byte terminal contributes that byte; an item
/// before a non-terminal contributes that non-terminal's precomputed FIRST
/// set. Completed items and terminals of any other length contribute nothing.
pub fn next_byte_set(&self) -> HashSet<u8> {
    let mut allowed = HashSet::new();
    for item in &self.chart[self.input_pos] {
        match self.grammar.rules[item.rule].rhs.get(item.dot) {
            Some(Symbol::Terminal(bytes)) if bytes.len() == 1 => {
                allowed.insert(bytes[0]);
            }
            Some(Symbol::NonTerminal(nt)) => {
                if let Some(firsts) = self.first_sets.first.get(nt) {
                    allowed.extend(firsts.iter().copied());
                }
            }
            // Dot at the end, or an empty / multi-byte terminal.
            _ => {}
        }
    }
    allowed
}
/// Discards all consumed input and rebuilds the initial chart set, returning
/// the recognizer to the state it had right after construction.
pub fn reset(&mut self) {
    self.input_pos = 0;
    self.chart.truncate(0);
    self.chart.push(HashSet::default());
    self.init_chart_zero();
}
/// Returns an independent copy of the recognizer.
///
/// The chart and input position are copied; the grammar, FIRST sets and rule
/// index are shared through their `Arc` handles.
pub fn clone_state(&self) -> Self {
    let Self {
        grammar,
        first_sets,
        chart,
        input_pos,
        rule_index,
    } = self;
    Self {
        chart: chart.clone(),
        input_pos: *input_pos,
        grammar: Arc::clone(grammar),
        first_sets: Arc::clone(first_sets),
        rule_index: Arc::clone(rule_index),
    }
}
/// Feeds `bytes` one at a time, stopping at the first byte that is rejected.
///
/// Returns `true` if every byte was accepted (trivially so for an empty slice).
pub fn feed_bytes(&mut self, bytes: &[u8]) -> bool {
    // `all` short-circuits, so nothing after a rejected byte is fed.
    bytes.iter().all(|&b| self.feed_byte(b))
}
/// Borrows the grammar this recognizer was built from.
pub fn grammar(&self) -> &Grammar {
    self.grammar.as_ref()
}
/// Number of items in the current chart set.
pub fn active_item_count(&self) -> usize {
    let current = &self.chart[self.input_pos];
    current.len()
}
/// Hashes the current chart set together with the input position.
///
/// Items are flattened to `(rule, dot, origin)` triples and sorted first, so
/// the result does not depend on `HashSet` iteration order.
pub fn state_hash(&self) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let current = &self.chart[self.input_pos];
    let mut triples: Vec<(usize, usize, usize)> = Vec::with_capacity(current.len());
    for item in current {
        triples.push((item.rule, item.dot, item.origin));
    }
    triples.sort_unstable();

    let mut state = DefaultHasher::new();
    triples.hash(&mut state);
    self.input_pos.hash(&mut state);
    state.finish()
}
}
/// Groups rule ids by their left-hand-side non-terminal.
///
/// A rule's id is its position in `grammar.rules`; ids within each group keep
/// that order.
fn build_rule_index(grammar: &Grammar) -> HashMap<NonTerminalId, Vec<RuleId>> {
    grammar
        .rules
        .iter()
        .enumerate()
        .fold(HashMap::new(), |mut by_lhs, (rule_id, rule)| {
            by_lhs
                .entry(rule.lhs)
                .or_insert_with(Vec::new)
                .push(rule_id);
            by_lhs
        })
}
// Unit tests for `EarleyRecognizer`, driven by small grammars written in BNF.
#[cfg(test)]
mod tests {
use super::*;
use crate::grammar::bnf_parser::parse_bnf;
// Parses `bnf`, calls `normalise_terminals()` on the grammar, and wraps it in a recognizer.
fn recognizer_from_bnf(bnf: &str) -> EarleyRecognizer {
let mut g = parse_bnf(bnf).expect("valid BNF");
g.normalise_terminals();
EarleyRecognizer::new(Arc::new(g))
}
// Feeds `s` byte by byte; returns `false` as soon as a byte is rejected.
fn feed_str(rec: &mut EarleyRecognizer, s: &str) -> bool {
for b in s.bytes() {
if !rec.feed_byte(b) {
return false;
}
}
true
}
// A one-terminal grammar accepts exactly that terminal.
#[test]
fn earley_accepts_single_terminal() {
let mut r = recognizer_from_bnf(r#"<S> ::= "x""#);
assert!(feed_str(&mut r, "x"));
assert!(r.is_accepting());
}
// A byte that matches no expected terminal is rejected and leaves the recognizer non-accepting.
#[test]
fn earley_rejects_wrong_terminal() {
let mut r = recognizer_from_bnf(r#"<S> ::= "x""#);
assert!(!feed_str(&mut r, "y"));
assert!(!r.is_accepting());
}
// An epsilon-only grammar yields a live recognizer before any input.
// NOTE(review): only liveness is asserted; the test name suggests `is_accepting()` was meant to be checked too.
#[test]
fn earley_accepts_epsilon_rule_empty_input() {
let r = recognizer_from_bnf(r#"<S> ::= """#);
assert!(r.is_live());
}
// First alternative of a two-way choice.
#[test]
fn earley_accepts_alternation_first() {
let mut r = recognizer_from_bnf(r#"<S> ::= "a" | "b""#);
assert!(feed_str(&mut r, "a"));
assert!(r.is_accepting());
}
// Second alternative of a two-way choice.
#[test]
fn earley_accepts_alternation_second() {
let mut r = recognizer_from_bnf(r#"<S> ::= "a" | "b""#);
assert!(feed_str(&mut r, "b"));
assert!(r.is_accepting());
}
// A byte matching neither alternative is rejected.
#[test]
fn earley_rejects_alternation_neither() {
let mut r = recognizer_from_bnf(r#"<S> ::= "a" | "b""#);
assert!(!feed_str(&mut r, "c"));
}
// Right recursion: one or more "a".
#[test]
fn earley_accepts_right_recursive() {
let mut r = recognizer_from_bnf(r#"<S> ::= "a" <S> | "a""#);
assert!(feed_str(&mut r, "aaa"));
assert!(r.is_accepting());
}
// Balanced a^n b^n strings are accepted; `reset` allows reuse between inputs.
#[test]
fn earley_accepts_simple_ab_grammar() {
let mut r = recognizer_from_bnf(r#"<S> ::= "a" <S> "b" | "ab""#);
assert!(feed_str(&mut r, "ab"));
assert!(r.is_accepting());
r.reset();
assert!(feed_str(&mut r, "aabb"));
assert!(r.is_accepting());
r.reset();
assert!(feed_str(&mut r, "aaabbb"));
assert!(r.is_accepting());
}
// Input starting with "b" cannot be fed at all.
#[test]
fn earley_rejects_ab_grammar_wrong() {
let mut r = recognizer_from_bnf(r#"<S> ::= "a" <S> "b" | "ab""#);
assert!(!feed_str(&mut r, "ba"));
}
// "aab" is a valid prefix, so feeding may succeed, but it must not be accepted.
#[test]
fn earley_rejects_ab_grammar_unbalanced() {
let mut r = recognizer_from_bnf(r#"<S> ::= "a" <S> "b" | "ab""#);
let ok = feed_str(&mut r, "aab");
if ok {
assert!(!r.is_accepting());
}
}
// Recognizer for single-digit arithmetic with + - * / and parentheses.
fn arithmetic_recognizer() -> EarleyRecognizer {
recognizer_from_bnf(
r#"
<expr> ::= <term> "+" <expr> | <term> "-" <expr> | <term>
<term> ::= <factor> "*" <term> | <factor> "/" <term> | <factor>
<factor> ::= "(" <expr> ")" | <number>
<number> ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
"#,
)
}
// A lone digit is a complete expression.
#[test]
fn earley_accepts_arithmetic_single_digit() {
let mut r = arithmetic_recognizer();
assert!(feed_str(&mut r, "9"));
assert!(r.is_accepting());
}
// Simple addition.
#[test]
fn earley_accepts_arithmetic_1plus2() {
let mut r = arithmetic_recognizer();
assert!(feed_str(&mut r, "1+2"));
assert!(r.is_accepting());
}
// Mixed multiplication and addition.
#[test]
fn earley_accepts_arithmetic_1times2plus3() {
let mut r = arithmetic_recognizer();
assert!(feed_str(&mut r, "1*2+3"));
assert!(r.is_accepting());
}
// Parenthesised sub-expression.
#[test]
fn earley_accepts_arithmetic_paren() {
let mut r = arithmetic_recognizer();
assert!(feed_str(&mut r, "(1+2)*3"));
assert!(r.is_accepting());
}
// An operator cannot start an expression.
#[test]
fn earley_rejects_arithmetic_plus_at_start() {
let mut r = arithmetic_recognizer();
assert!(!feed_str(&mut r, "+1"));
}
// Two consecutive operators: either feeding fails or the result is not accepted.
#[test]
fn earley_rejects_arithmetic_double_plus() {
let mut r = arithmetic_recognizer();
let ok = feed_str(&mut r, "1++2");
if ok {
assert!(!r.is_accepting());
}
}
// Before any input, digits and '(' are offered but '+' is not.
#[test]
fn earley_next_byte_set_at_start_arithmetic() {
let r = arithmetic_recognizer();
let nbs = r.next_byte_set();
for d in b'0'..=b'9' {
assert!(nbs.contains(&d), "digit {d} should be in next_byte_set");
}
assert!(nbs.contains(&b'('), "'(' should be in next_byte_set");
assert!(
!nbs.contains(&b'+'),
"'+' should not be in next_byte_set at start"
);
}
// After one digit, all four operators are offered.
#[test]
fn earley_next_byte_set_after_digit() {
let mut r = arithmetic_recognizer();
feed_str(&mut r, "1");
let nbs = r.next_byte_set();
assert!(nbs.contains(&b'+'), "'+' should be valid after a digit");
assert!(nbs.contains(&b'-'), "'-' should be valid after a digit");
assert!(nbs.contains(&b'*'), "'*' should be valid after a digit");
assert!(nbs.contains(&b'/'), "'/' should be valid after a digit");
}
// A dangling operator leaves the recognizer in a non-accepting state.
#[test]
fn earley_not_accepting_mid_input() {
let mut r = arithmetic_recognizer();
feed_str(&mut r, "1+");
assert!(!r.is_accepting(), "should not accept after '1+'");
}
// `reset` rewinds `input_pos` to 0 and the recognizer works again afterwards.
#[test]
fn earley_reset_restores_initial_state() {
let mut r = arithmetic_recognizer();
feed_str(&mut r, "1+2");
assert!(r.is_accepting());
r.reset();
assert_eq!(r.input_pos, 0);
assert!(!r.is_accepting());
feed_str(&mut r, "9");
assert!(r.is_accepting());
}
// Feeding the original after `clone_state` must not move the clone, and vice versa.
#[test]
fn earley_clone_state_is_independent() {
let mut r = arithmetic_recognizer();
feed_str(&mut r, "1");
let mut clone = r.clone_state();
feed_str(&mut r, "+2");
assert!(r.is_accepting());
assert_eq!(clone.input_pos, 1);
feed_str(&mut r, "*3");
feed_str(&mut clone, "*3");
assert!(clone.is_accepting());
}
// Left recursion is recognized and terminates.
#[test]
fn earley_handles_left_recursive_grammar() {
let mut r = recognizer_from_bnf(r#"<E> ::= <E> "+" "1" | "1""#);
assert!(feed_str(&mut r, "1"));
assert!(r.is_accepting());
r.reset();
assert!(feed_str(&mut r, "1+1+1"));
assert!(r.is_accepting());
r.reset();
assert!(!feed_str(&mut r, "+1"));
}
// A nullable non-terminal may be skipped ("x") or matched ("ax").
#[test]
fn earley_handles_nullable_productions() {
let mut r = recognizer_from_bnf(
r#"
<S> ::= <A> "x"
<A> ::= "" | "a"
"#,
);
assert!(feed_str(&mut r, "x"));
assert!(r.is_accepting());
r.reset();
assert!(feed_str(&mut r, "ax"));
assert!(r.is_accepting());
}
// Once a byte has been rejected the recognizer reports it is no longer live.
#[test]
fn earley_is_not_live_after_rejection() {
let mut r = recognizer_from_bnf(r#"<S> ::= "abc""#);
r.feed_byte(b'a');
r.feed_byte(b'b');
r.feed_byte(b'z'); assert!(!r.is_live());
}
}