use crate::{Cfg, Pr, Symbol, Terminal};
use anyhow::{Result, anyhow};
use parol_runtime::log::trace;
use rand::RngExt;
use std::collections::HashMap;
use thiserror::Error;
/// Default upper bound, in bytes, for a generated sentence when the caller passes no limit.
const MAX_RESULT_SIZE: usize = 100_000;
/// Repetition cap handed to `rand_regex::Regex::with_hir` when sampling terminal regexes.
const MAX_REPEAT: u32 = 8;
#[derive(Error, Debug)]
#[error("Stopping generation to prevent endless recursion at size {len}")]
pub struct SourceSizeExceeded {
len: usize,
}
/// Produces random sentences of the language described by a [`Cfg`].
///
/// Derivation is driven by an explicit stack of pending grammar symbols instead of recursion.
#[derive(Debug)]
pub struct LanguageGenerator<'a> {
/// Symbols still to be expanded or emitted; the top of the stack is processed next.
generator_stack: Vec<Symbol>,
/// The grammar whose language is generated.
cfg: &'a Cfg,
/// Compiled random-string generators, keyed by the expanded terminal regex text.
cache: HashMap<String, rand_regex::Regex>,
}
impl<'a> LanguageGenerator<'a> {
pub fn new(cfg: &'a Cfg) -> Self {
Self {
generator_stack: Vec::new(),
cfg,
cache: HashMap::new(),
}
}
pub fn generate(&mut self, max_result_length: Option<usize>) -> Result<String> {
let mut result = String::new();
let termination_threshold = max_result_length.unwrap_or(MAX_RESULT_SIZE) / 2;
trace!("Try to terminate at result length {termination_threshold}");
self.process_non_terminal(self.cfg.get_start_symbol(), false)?;
while let Some(symbol) = self.generator_stack.pop() {
match symbol {
Symbol::N(n, ..) => {
self.process_non_terminal(&n, result.len() > termination_threshold)
}
Symbol::T(Terminal::Trm(t, k, ..)) => {
self.process_terminal(k.expand(&t), &mut result, max_result_length)
}
_ => Ok(()),
}?
}
Ok(result)
}
fn process_non_terminal(&mut self, non_terminal: &str, terminate: bool) -> Result<()> {
let productions_of_nt = self.cfg.matching_productions(non_terminal);
let chosen_index = if terminate {
Self::chose_minimal_expanding_production(&productions_of_nt)
} else {
rand::rng().random_range(0..productions_of_nt.len())
};
trace!(
"/* {} */ {} {}/{} {}",
productions_of_nt[chosen_index].0,
productions_of_nt[chosen_index].1,
chosen_index + 1,
productions_of_nt.len(),
if terminate { "term" } else { "" }
);
productions_of_nt[chosen_index]
.1
.get_r()
.iter()
.rev()
.for_each(|s| self.generator_stack.push(s.clone()));
Ok(())
}
fn process_terminal(
&mut self,
terminal: String,
result: &mut String,
max_result_length: Option<usize>,
) -> Result<()> {
let mut rng = rand::rng();
let utf8_gen = self.get_regex(terminal)?;
let generated: String = rng.sample(utf8_gen);
trace!("gen: {generated}");
result.push_str(&generated);
result.push(' ');
let len = result.len();
if len > max_result_length.unwrap_or(MAX_RESULT_SIZE) {
Err(anyhow!(SourceSizeExceeded { len }))
} else {
Ok(())
}
}
fn get_regex<'b, 'c>(&'b mut self, terminal: String) -> Result<&'c rand_regex::Regex>
where
'b: 'c,
{
let exist = self.cache.contains_key(&terminal);
if exist {
let regex = self.cache.get(&terminal).unwrap();
trace!("Reusing cached regex for: {terminal}");
return Ok(regex);
}
match regex_syntax::ParserBuilder::new().build().parse(&terminal) {
Ok(utf8_hir) => match rand_regex::Regex::with_hir(utf8_hir, MAX_REPEAT) {
Ok(utf8_gen) => {
trace!("Caching regex for: {terminal}");
self.cache.insert(terminal.clone(), utf8_gen);
self.get_regex(terminal)
}
Err(e) => Err(anyhow!(e).context(format!(
"rand_regex can't generate a sentence for this: /{terminal}/"
))),
},
Err(err) => Err(anyhow!(err).context(format!("regex_syntax can't parse /{terminal}/"))),
}
}
fn chose_minimal_expanding_production(productions_of_nt: &[(usize, &Pr)]) -> usize {
let production_index = productions_of_nt
.iter()
.min_by(|(_, a), (_, b)| {
let a_nt_count = a.get_r().iter().fold(0, |mut acc, s| {
if s.is_n() {
acc += 1
}
acc
});
let b_nt_count = b.get_r().iter().fold(0, |mut acc, s| {
if s.is_n() {
acc += 1
}
acc
});
a_nt_count.cmp(&b_nt_count)
})
.unwrap()
.0;
productions_of_nt
.iter()
.position(|(idx, _)| *idx == production_index)
.unwrap()
}
}