use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use super::grammar::{Grammar, GrammarError, Production};
use super::types::{NonTerminal, RuleId, Symbol, Terminal};
#[derive(Debug, Default)]
pub struct GrammarBuilder {
nt_names: FxHashMap<String, NonTerminal>,
terminal_names: FxHashMap<String, Terminal>,
next_nt: u16,
next_terminal: u32,
productions: Vec<Production>,
start: Option<String>,
}
impl GrammarBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn non_terminal(mut self, name: &str) -> Self {
if !self.nt_names.contains_key(name) {
let nt = NonTerminal::new(self.next_nt);
self.nt_names.insert(name.to_string(), nt);
self.next_nt += 1;
}
self
}
pub fn non_terminals(mut self, names: &[&str]) -> Self {
for name in names {
self = self.non_terminal(name);
}
self
}
pub fn terminal(mut self, name: &str) -> Self {
if !self.terminal_names.contains_key(name) {
let t = Terminal::new(self.next_terminal);
self.terminal_names.insert(name.to_string(), t);
self.next_terminal += 1;
}
self
}
pub fn terminal_with_id(mut self, name: &str, vocab_id: u32) -> Self {
let t = Terminal::new(vocab_id);
self.terminal_names.insert(name.to_string(), t);
self
}
pub fn start(mut self, name: &str) -> Self {
self.start = Some(name.to_string());
self
}
pub fn rule(self, lhs: &str, rhs: &[&str]) -> Self {
self.rule_with_prob(lhs, rhs, 0.0)
}
pub fn rule_with_prob(mut self, lhs: &str, rhs: &[&str], log_prob: f32) -> Self {
self = self.non_terminal(lhs);
let lhs_nt = self.nt_names[lhs];
let mut rhs_symbols: SmallVec<[Symbol; 4]> = SmallVec::new();
for sym in rhs {
if sym.is_empty() || *sym == "ε" {
rhs_symbols.push(Symbol::epsilon());
} else if self.is_non_terminal_name(sym) {
self = self.non_terminal(sym);
rhs_symbols.push(Symbol::NonTerminal(self.nt_names[*sym]));
} else {
self = self.terminal(sym);
rhs_symbols.push(Symbol::Terminal(self.terminal_names[*sym]));
}
}
let rule_id = RuleId::new(self.productions.len() as u32);
self.productions.push(Production::with_prob(
rule_id,
lhs_nt,
rhs_symbols,
log_prob,
));
self
}
pub fn epsilon_rule(self, lhs: &str) -> Self {
self.rule(lhs, &["ε"])
}
pub fn build(self) -> Result<Grammar, GrammarError> {
let start_name = self.start.as_deref().ok_or(GrammarError::NoStartSymbol)?;
let start_nt = self
.nt_names
.get(start_name)
.copied()
.ok_or_else(|| GrammarError::UndefinedNonTerminal(NonTerminal::new(0)))?;
let mut grammar = Grammar::new(start_nt, self.productions, self.next_nt as usize)?;
for (name, nt) in &self.nt_names {
grammar.set_nt_name(*nt, name.clone());
}
for (name, terminal) in &self.terminal_names {
grammar.register_terminal(name.clone(), *terminal);
}
Ok(grammar)
}
fn is_non_terminal_name(&self, name: &str) -> bool {
name.chars()
.next()
.map(|c| c.is_uppercase())
.unwrap_or(false)
}
pub fn get_non_terminal(&self, name: &str) -> Option<NonTerminal> {
self.nt_names.get(name).copied()
}
pub fn get_terminal(&self, name: &str) -> Option<Terminal> {
self.terminal_names.get(name).copied()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_builder_simple() {
let grammar = GrammarBuilder::new()
.start("S")
.rule("S", &["NP", "VP"])
.rule("NP", &["Det", "N"])
.rule("VP", &["V", "NP"])
.rule("Det", &["the"])
.rule("N", &["dog"])
.rule("V", &["saw"])
.build()
.expect("valid grammar");
assert_eq!(grammar.num_productions(), 6);
assert_eq!(grammar.nt_name(grammar.start()), Some("S"));
}
#[test]
fn test_builder_non_terminals() {
let builder = GrammarBuilder::new().non_terminals(&["S", "A", "B"]);
assert!(builder.get_non_terminal("S").is_some());
assert!(builder.get_non_terminal("A").is_some());
assert!(builder.get_non_terminal("B").is_some());
assert!(builder.get_non_terminal("C").is_none());
}
#[test]
fn test_builder_epsilon_rule() {
let grammar = GrammarBuilder::new()
.start("S")
.rule("S", &["A"])
.epsilon_rule("A")
.build()
.expect("valid grammar");
assert_eq!(grammar.num_productions(), 2);
let a_nt = grammar.terminal_by_name("A");
assert!(a_nt.is_none()); }
#[test]
fn test_builder_with_probability() {
let grammar = GrammarBuilder::new()
.start("S")
.rule_with_prob("S", &["a"], -0.5)
.rule_with_prob("S", &["b"], -1.0)
.build()
.expect("valid grammar");
let prods: Vec<_> = grammar.productions_for(grammar.start()).collect();
assert_eq!(prods[0].log_prob, -0.5);
assert_eq!(prods[1].log_prob, -1.0);
}
#[test]
fn test_builder_terminal_with_id() {
let grammar = GrammarBuilder::new()
.terminal_with_id("the", 100)
.terminal_with_id("dog", 200)
.start("S")
.rule("S", &["the", "dog"])
.build()
.expect("valid grammar");
assert_eq!(grammar.terminal_by_name("the"), Some(Terminal::new(100)));
assert_eq!(grammar.terminal_by_name("dog"), Some(Terminal::new(200)));
}
#[test]
fn test_builder_auto_declaration() {
let grammar = GrammarBuilder::new()
.start("S")
.rule("S", &["A", "B"])
.rule("A", &["a"])
.rule("B", &["b"])
.build()
.expect("valid grammar");
assert!(
grammar.nt_name(NonTerminal::new(1)).is_some()
|| grammar.nt_name(NonTerminal::new(2)).is_some()
);
}
#[test]
fn test_builder_no_start_error() {
let result = GrammarBuilder::new().rule("S", &["a"]).build();
assert!(matches!(result, Err(GrammarError::NoStartSymbol)));
}
#[test]
fn test_builder_terminal_names() {
let grammar = GrammarBuilder::new()
.start("S")
.rule("S", &["the", "dog"])
.build()
.expect("valid grammar");
assert!(grammar.terminal_by_name("the").is_some());
assert!(grammar.terminal_by_name("dog").is_some());
}
}