use std::collections::HashMap;
/// Identifier of a non-terminal in a [`Grammar`].
pub type NonTerminalId = usize;

/// Position of a rule in [`Grammar::rules`] (as yielded by [`Grammar::rules_for`]).
pub type RuleId = usize;

/// Reserved non-terminal id equal to `usize::MAX`.
///
/// NOTE(review): not referenced in this file; presumably marks "no non-terminal" —
/// confirm against callers.
pub const NULL_NT: NonTerminalId = NonTerminalId::MAX;
/// One grammar symbol: a literal byte string or a reference to a non-terminal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Symbol {
    /// Literal bytes, matched verbatim. May hold any number of bytes; multi-byte
    /// terminals inside grammar rules are split by `Grammar::normalise_terminals`.
    Terminal(Vec<u8>),
    /// Reference to the non-terminal with this id.
    NonTerminal(NonTerminalId),
}

impl Symbol {
    /// `true` when this symbol is a [`Symbol::Terminal`].
    #[inline]
    pub fn is_terminal(&self) -> bool {
        self.terminal_bytes().is_some()
    }

    /// `true` when this symbol is a [`Symbol::NonTerminal`].
    #[inline]
    pub fn is_non_terminal(&self) -> bool {
        self.non_terminal_id().is_some()
    }

    /// Borrows the literal bytes of a terminal; `None` for a non-terminal.
    #[inline]
    pub fn terminal_bytes(&self) -> Option<&[u8]> {
        if let Symbol::Terminal(literal) = self {
            Some(literal.as_slice())
        } else {
            None
        }
    }

    /// The id of a non-terminal; `None` for a terminal.
    #[inline]
    pub fn non_terminal_id(&self) -> Option<NonTerminalId> {
        if let Symbol::NonTerminal(nt) = self {
            Some(*nt)
        } else {
            None
        }
    }
}
/// One production of a [`Grammar`]: `lhs -> rhs`.
#[derive(Debug, Clone)]
pub struct Rule {
    /// The non-terminal this production defines.
    pub lhs: NonTerminalId,
    /// Symbols produced, in order; an empty vector is an epsilon production.
    pub rhs: Vec<Symbol>,
}

impl Rule {
    /// Creates the production `lhs -> rhs`.
    pub fn new(lhs: NonTerminalId, rhs: Vec<Symbol>) -> Self {
        Rule { rhs, lhs }
    }

    /// Number of symbols on the right-hand side.
    pub fn rhs_len(&self) -> usize {
        let Rule { rhs, .. } = self;
        rhs.len()
    }

    /// `true` when the right-hand side is empty (an epsilon production).
    pub fn is_epsilon(&self) -> bool {
        self.rhs_len() == 0
    }
}
#[derive(Debug, Clone)]
pub struct Grammar {
pub rules: Vec<Rule>,
pub start: NonTerminalId,
pub nt_names: HashMap<NonTerminalId, String>,
pub nt_count: usize,
}
impl Grammar {
pub fn new(start: NonTerminalId) -> Self {
Self {
rules: Vec::new(),
start,
nt_names: HashMap::new(),
nt_count: 0,
}
}
pub fn add_rule(&mut self, rule: Rule) {
self.rules.push(rule);
}
pub fn rules_for(&self, nt: NonTerminalId) -> impl Iterator<Item = (RuleId, &Rule)> + '_ {
self.rules
.iter()
.enumerate()
.filter(move |(_, r)| r.lhs == nt)
}
pub fn start(&self) -> NonTerminalId {
self.start
}
pub fn nt_name(&self, id: NonTerminalId) -> &str {
self.nt_names
.get(&id)
.map(|s| s.as_str())
.unwrap_or("<unknown>")
}
pub fn alloc_nt(&mut self, name: impl Into<String>) -> NonTerminalId {
let id = self.nt_count;
self.nt_count += 1;
self.nt_names.insert(id, name.into());
id
}
pub fn normalise_terminals(&mut self) {
let mut cache: HashMap<Vec<u8>, NonTerminalId> = HashMap::new();
let rule_count = self.rules.len();
for rule_idx in 0..rule_count {
let mut new_rhs: Vec<Symbol> = Vec::new();
let mut changed = false;
let rhs: Vec<Symbol> = self.rules[rule_idx].rhs.clone();
for symbol in rhs {
match &symbol {
Symbol::Terminal(bytes) if bytes.len() > 1 => {
changed = true;
let chain_nt = Self::intern_byte_chain(self, bytes.clone(), &mut cache);
new_rhs.push(Symbol::NonTerminal(chain_nt));
}
other => {
new_rhs.push(other.clone());
}
}
}
if changed {
self.rules[rule_idx].rhs = new_rhs;
}
}
}
fn intern_byte_chain(
grammar: &mut Grammar,
bytes: Vec<u8>,
cache: &mut HashMap<Vec<u8>, NonTerminalId>,
) -> NonTerminalId {
if let Some(&id) = cache.get(&bytes) {
return id;
}
let name = {
let hex: Vec<String> = bytes.iter().map(|b| format!("{b:02x}")).collect();
format!("__T_{}", hex.join("_"))
};
let nt = grammar.alloc_nt(&name);
cache.insert(bytes.clone(), nt);
if bytes.len() == 1 {
grammar.rules.push(Rule {
lhs: nt,
rhs: vec![Symbol::Terminal(bytes)],
});
} else {
let rest = bytes[1..].to_vec();
let rest_nt = Self::intern_byte_chain(grammar, rest, cache);
grammar.rules.push(Rule {
lhs: nt,
rhs: vec![
Symbol::Terminal(vec![bytes[0]]),
Symbol::NonTerminal(rest_nt),
],
});
}
nt
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `S -> "ab" | "c"`, with `S` allocated as the start symbol (id 0).
    fn make_simple_grammar() -> Grammar {
        let mut g = Grammar::new(0);
        let s = g.alloc_nt("S");
        assert_eq!(s, g.start());
        g.add_rule(Rule::new(s, vec![Symbol::Terminal(b"ab".to_vec())]));
        g.add_rule(Rule::new(s, vec![Symbol::Terminal(b"c".to_vec())]));
        g
    }

    /// Returns the single rule defining `nt`, failing the test unless there is exactly one.
    fn only_rule(g: &Grammar, nt: NonTerminalId) -> &Rule {
        let rules: Vec<&Rule> = g.rules_for(nt).map(|(_, r)| r).collect();
        assert_eq!(rules.len(), 1, "expected one rule for {}", g.nt_name(nt));
        rules[0]
    }

    #[test]
    fn symbol_is_terminal() {
        let t = Symbol::Terminal(vec![65]);
        assert!(t.is_terminal());
        assert!(!t.is_non_terminal());
        assert_eq!(t.terminal_bytes(), Some([65u8].as_ref()));
        assert_eq!(t.non_terminal_id(), None);
    }

    #[test]
    fn symbol_is_non_terminal() {
        let nt = Symbol::NonTerminal(3);
        assert!(nt.is_non_terminal());
        assert!(!nt.is_terminal());
        assert_eq!(nt.non_terminal_id(), Some(3));
        assert_eq!(nt.terminal_bytes(), None);
    }

    #[test]
    fn grammar_rules_for() {
        let g = make_simple_grammar();
        let ids: Vec<RuleId> = g.rules_for(0).map(|(id, _)| id).collect();
        assert_eq!(ids, vec![0, 1]);
        assert_eq!(g.rules_for(1).count(), 0);
    }

    #[test]
    fn grammar_nt_name_unknown() {
        let g = Grammar::new(0);
        assert_eq!(g.nt_name(99), "<unknown>");
    }

    #[test]
    fn grammar_nt_name_known() {
        let g = make_simple_grammar();
        assert_eq!(g.nt_name(0), "S");
    }

    #[test]
    fn grammar_normalise_multi_byte_terminal() {
        let mut g = make_simple_grammar();
        g.normalise_terminals();
        // Two original rules plus one chain rule per byte of "ab".
        assert_eq!(g.rules.len(), 4);

        let first_rhs = &g.rules[0].rhs;
        assert_eq!(first_rhs.len(), 1);
        let head = first_rhs[0]
            .non_terminal_id()
            .expect("\"ab\" should be replaced by a non-terminal");
        assert_eq!(g.nt_name(head), "__T_61_62");

        let head_rule = only_rule(&g, head);
        assert_eq!(head_rule.rhs.len(), 2);
        assert_eq!(head_rule.rhs[0], Symbol::Terminal(b"a".to_vec()));
        let tail = head_rule.rhs[1]
            .non_terminal_id()
            .expect("chain should continue with a non-terminal");
        assert_eq!(g.nt_name(tail), "__T_62");
        assert_eq!(only_rule(&g, tail).rhs, vec![Symbol::Terminal(b"b".to_vec())]);

        // Single-byte terminals are left as they are.
        assert_eq!(g.rules[1].rhs, vec![Symbol::Terminal(b"c".to_vec())]);
    }

    #[test]
    fn grammar_normalise_shares_suffixes() {
        let mut g = make_simple_grammar();
        g.add_rule(Rule::new(0, vec![Symbol::Terminal(b"cb".to_vec())]));
        g.normalise_terminals();
        // "ab" adds __T_61_62 and __T_62; "cb" adds only __T_63_62, reusing __T_62.
        assert_eq!(g.rules.len(), 6);
        let b_chains = g
            .nt_names
            .values()
            .filter(|name| name.as_str() == "__T_62")
            .count();
        assert_eq!(b_chains, 1);
    }

    #[test]
    fn grammar_normalise_idempotent() {
        let mut g = make_simple_grammar();
        g.normalise_terminals();
        let count_after_first = g.rules.len();
        g.normalise_terminals();
        assert_eq!(g.rules.len(), count_after_first);
    }

    #[test]
    fn rule_is_epsilon() {
        let r = Rule::new(0, vec![]);
        assert!(r.is_epsilon());
        assert_eq!(r.rhs_len(), 0);
        let r2 = Rule::new(0, vec![Symbol::Terminal(vec![65])]);
        assert!(!r2.is_epsilon());
    }
}