use super::ast::{AstNode, ParsedCode};
use super::language::CodeLanguage;
use std::collections::HashMap;
use std::hash::Hash;
/// A single grammar rule rewriting the non-terminal `lhs` into the symbol
/// sequence `rhs`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Production {
    /// Name of the non-terminal on the left-hand side of the rule.
    pub lhs: String,
    /// Right-hand-side symbols in order; empty for an epsilon rule.
    pub rhs: Vec<Symbol>,
}

impl Production {
    /// Builds a production from anything convertible into the left-hand-side
    /// name and an already assembled right-hand side.
    pub fn new(lhs: impl Into<String>, rhs: Vec<Symbol>) -> Self {
        let lhs = lhs.into();
        Production { lhs, rhs }
    }

    /// Returns `true` when the right-hand side contains no symbols.
    pub fn is_epsilon(&self) -> bool {
        self.arity() == 0
    }

    /// Number of symbols on the right-hand side.
    pub fn arity(&self) -> usize {
        self.rhs.len()
    }
}

/// Renders the rule as `lhs -> sym sym ...`, each symbol preceded by a single
/// space; an epsilon rule therefore prints as `lhs ->`.
impl std::fmt::Display for Production {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} ->", self.lhs)?;
        self.rhs
            .iter()
            .try_for_each(|symbol| write!(f, " {}", symbol))
    }
}
/// One grammar symbol, identified by name and tagged as either a
/// non-terminal or a terminal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Symbol {
    /// A symbol that is displayed as `<name>`.
    NonTerminal(String),
    /// A symbol that is displayed as `'name'`.
    Terminal(String),
}

impl Symbol {
    /// Creates a non-terminal with the given name.
    pub fn non_terminal(s: impl Into<String>) -> Self {
        Self::NonTerminal(s.into())
    }

    /// Creates a terminal with the given name.
    pub fn terminal(s: impl Into<String>) -> Self {
        Self::Terminal(s.into())
    }

    /// Returns `true` for the `NonTerminal` variant.
    pub fn is_non_terminal(&self) -> bool {
        match self {
            Self::NonTerminal(_) => true,
            Self::Terminal(_) => false,
        }
    }

    /// Returns `true` for the `Terminal` variant.
    pub fn is_terminal(&self) -> bool {
        match self {
            Self::Terminal(_) => true,
            Self::NonTerminal(_) => false,
        }
    }

    /// The symbol's name, without any display decoration.
    pub fn name(&self) -> &str {
        match self {
            Self::NonTerminal(name) => name,
            Self::Terminal(name) => name,
        }
    }
}

/// Non-terminals print as `<name>`, terminals as `'name'`.
impl std::fmt::Display for Symbol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (open, close) = match self {
            Self::NonTerminal(_) => ("<", ">"),
            Self::Terminal(_) => ("'", "'"),
        };
        write!(f, "{}{}{}", open, self.name(), close)
    }
}
#[derive(Debug, Clone)]
pub struct WeightedCFG {
rules: HashMap<Production, f64>,
rules_by_lhs: HashMap<String, Vec<(Production, f64)>>,
start_symbol: String,
lhs_totals: HashMap<String, f64>,
}
impl WeightedCFG {
pub fn new(start_symbol: impl Into<String>) -> Self {
Self {
rules: HashMap::new(),
rules_by_lhs: HashMap::new(),
start_symbol: start_symbol.into(),
lhs_totals: HashMap::new(),
}
}
pub fn add_rule(&mut self, production: Production, weight: f64) {
let lhs = production.lhs.clone();
*self.lhs_totals.entry(lhs.clone()).or_insert(0.0) += weight;
*self.rules.entry(production.clone()).or_insert(0.0) += weight;
self.rules_by_lhs
.entry(lhs)
.or_default()
.push((production, weight));
}
pub fn weight(&self, production: &Production) -> f64 {
self.rules.get(production).copied().unwrap_or(0.0)
}
pub fn probability(&self, production: &Production) -> f64 {
let weight = self.weight(production);
let total = self.lhs_totals.get(&production.lhs).copied().unwrap_or(1.0);
if total > 0.0 {
weight / total
} else {
0.0
}
}
pub fn log_probability(&self, production: &Production) -> f64 {
let prob = self.probability(production);
if prob > 0.0 {
prob.ln()
} else {
f64::NEG_INFINITY
}
}
pub fn rules_for(&self, lhs: &str) -> Vec<(&Production, f64)> {
self.rules_by_lhs
.get(lhs)
.map(|rules| rules.iter().map(|(p, w)| (p, *w)).collect())
.unwrap_or_default()
}
pub fn start_symbol(&self) -> &str {
&self.start_symbol
}
pub fn non_terminals(&self) -> impl Iterator<Item = &str> {
self.rules_by_lhs.keys().map(|s| s.as_str())
}
pub fn terminals(&self) -> impl Iterator<Item = &str> {
self.rules
.keys()
.flat_map(|p| p.rhs.iter())
.filter_map(|s| match s {
Symbol::Terminal(t) => Some(t.as_str()),
_ => None,
})
}
pub fn rule_count(&self) -> usize {
self.rules.len()
}
pub fn iter_rules(&self) -> impl Iterator<Item = (&Production, &f64)> {
self.rules.iter()
}
pub fn rules(&self) -> &HashMap<Production, f64> {
&self.rules
}
pub fn normalize(&mut self) {
let mut normalized_rules = HashMap::new();
for (production, weight) in &self.rules {
let total = self.lhs_totals.get(&production.lhs).copied().unwrap_or(1.0);
let prob = if total > 0.0 { weight / total } else { 0.0 };
normalized_rules.insert(production.clone(), prob);
}
self.rules = normalized_rules;
for total in self.lhs_totals.values_mut() {
*total = 1.0;
}
self.rules_by_lhs.clear();
for (production, weight) in &self.rules {
self.rules_by_lhs
.entry(production.lhs.clone())
.or_default()
.push((production.clone(), *weight));
}
}
}
/// Accumulates production counts from parsed syntax trees so they can be
/// turned into a [`WeightedCFG`].
pub struct PcfgTrainer<'a, L: CodeLanguage> {
    /// Language handle supplied at construction; exposed via `language()`.
    language: &'a L,
    /// How many times each production has been observed so far.
    rule_counts: HashMap<Production, u64>,
    /// Start symbol handed to the grammar built by `to_weighted_cfg`.
    start_symbol: String,
}

impl<'a, L: CodeLanguage> PcfgTrainer<'a, L> {
    /// Creates a trainer with no observations and the start symbol
    /// `"source_file"`.
    pub fn new(language: &'a L) -> Self {
        PcfgTrainer {
            language,
            rule_counts: HashMap::new(),
            start_symbol: String::from("source_file"),
        }
    }

    /// Builder-style override of the start symbol.
    pub fn with_start_symbol(self, symbol: impl Into<String>) -> Self {
        PcfgTrainer {
            start_symbol: symbol.into(),
            ..self
        }
    }

    /// The language reference this trainer was created with.
    pub fn language(&self) -> &L {
        self.language
    }

    /// Converts the parse result's root into an [`AstNode`] tree and counts
    /// the productions found in it.
    pub fn train_from_parsed(&mut self, parsed: &ParsedCode) {
        let root = AstNode::from_ts_node(parsed.root(), &parsed.source);
        self.extract_rules(&root);
    }

    /// Runs [`PcfgTrainer::train_from_parsed`] on every item of the iterator.
    pub fn train_from_parsed_iter<'b, I>(&mut self, parsed_iter: I)
    where
        I: Iterator<Item = &'b ParsedCode>,
    {
        parsed_iter.for_each(|parsed| self.train_from_parsed(parsed));
    }

    /// Recursively counts one production per named node that has children.
    ///
    /// Nodes flagged as error or missing are skipped together with their
    /// whole subtree. Only named children contribute to a right-hand side; a
    /// named child with no children of its own and some text becomes a
    /// terminal, every other named child a non-terminal. Nodes whose
    /// right-hand side ends up empty are not counted.
    fn extract_rules(&mut self, node: &AstNode) {
        if node.is_error || node.is_missing {
            return;
        }

        let expands = node.is_named && !node.children.is_empty();
        if expands {
            let mut rhs = Vec::new();
            for child in &node.children {
                // Only `is_named` is consulted here, so a named error or
                // missing child still shows up as a symbol of its parent.
                if !child.is_named {
                    continue;
                }
                let is_leaf_token = child.children.is_empty() && child.text.is_some();
                let kind = child.kind.clone();
                rhs.push(if is_leaf_token {
                    Symbol::Terminal(kind)
                } else {
                    Symbol::NonTerminal(kind)
                });
            }

            if !rhs.is_empty() {
                let observed = Production::new(node.kind.clone(), rhs);
                *self.rule_counts.entry(observed).or_insert(0) += 1;
            }
        }

        for child in &node.children {
            self.extract_rules(child);
        }
    }

    /// Builds a grammar whose rule weights are the observed counts.
    pub fn to_weighted_cfg(&self) -> WeightedCFG {
        let mut grammar = WeightedCFG::new(self.start_symbol.clone());
        self.rule_counts.iter().for_each(|(production, &count)| {
            // u64 -> f64 is exact up to 2^53 observations of a single rule.
            grammar.add_rule(production.clone(), count as f64);
        });
        grammar
    }

    /// Read-only access to the production-to-count table.
    pub fn rule_counts(&self) -> &HashMap<Production, u64> {
        &self.rule_counts
    }

    /// Number of distinct productions observed.
    pub fn unique_rule_count(&self) -> usize {
        self.rule_counts.len()
    }

    /// Sum of all observation counts.
    pub fn total_rule_count(&self) -> u64 {
        self.rule_counts.values().copied().sum()
    }

    /// Forgets all observations; the start symbol and language are kept.
    pub fn clear(&mut self) {
        self.rule_counts.clear();
    }
}
/// Options for building a WFST from a PCFG.
///
/// NOTE(review): nothing in this file reads these fields; the per-field
/// descriptions below are inferred from the names and should be confirmed
/// against the code that consumes this struct.
#[derive(Debug, Clone)]
pub struct PcfgWfstConfig {
    /// Presumably whether epsilon productions are included. Defaults to `true`.
    pub include_epsilon: bool,
    /// Presumably the probability below which rules are dropped. Defaults to `1e-10`.
    pub min_probability: f64,
    /// Presumably an optional cap on the number of rules. Defaults to `None`.
    pub max_rules: Option<usize>,
}

impl Default for PcfgWfstConfig {
    /// Epsilon rules enabled, a `1e-10` probability floor and no rule cap.
    fn default() -> Self {
        let include_epsilon = true;
        let min_probability = 1e-10;
        let max_rules = None;
        PcfgWfstConfig {
            include_epsilon,
            min_probability,
            max_rules,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the single-symbol rule `S -> <name>`.
    fn start_rule(name: &str) -> Production {
        Production::new("S", vec![Symbol::NonTerminal(name.to_string())])
    }

    /// A production prints its left-hand side, an arrow and decorated symbols.
    #[test]
    fn test_production_display() {
        let rhs = vec![
            Symbol::NonTerminal(String::from("term")),
            Symbol::Terminal(String::from("+")),
            Symbol::NonTerminal(String::from("expr")),
        ];
        let rendered = format!("{}", Production::new("expr", rhs));
        assert_eq!(rendered, "expr -> <term> '+' <expr>");
    }

    /// Probabilities are each rule's share of its left-hand side's weight.
    #[test]
    fn test_weighted_cfg_probability() {
        let mut cfg = WeightedCFG::new("S");
        cfg.add_rule(start_rule("A"), 3.0);
        cfg.add_rule(start_rule("B"), 1.0);

        let expected = [("A", 0.75), ("B", 0.25)];
        for &(name, want) in &expected {
            let got = cfg.probability(&start_rule(name));
            assert!((got - want).abs() < 1e-6);
        }
    }

    /// Constructors, variant predicates and `name` agree with each other.
    #[test]
    fn test_symbol_types() {
        let nt = Symbol::non_terminal("expr");
        assert!(nt.is_non_terminal());
        assert!(!nt.is_terminal());
        assert_eq!(nt.name(), "expr");

        let t = Symbol::terminal("+");
        assert!(!t.is_non_terminal());
        assert!(t.is_terminal());
        assert_eq!(t.name(), "+");
    }
}