use rand::prelude::StdRng;
use crate::term::syntax::{
LanguageOperatorArity, LanguageTerm, LanguageTermNode, RewritableLanguageOperatorSymbol,
TermFactory,
};
use super::probas::TermGenerationSymbolsProbabilities;
use super::types::{
RandomTermGenerationConfig, TermGenerationSymbol, TermPatternForRandomGeneration,
};
/// Criterion that bounds the depth of randomly generated terms.
///
/// Once the recursive generator reaches `max_depth`, it stops drawing symbols
/// at random and produces the node from `symbol_at_end` instead.
pub struct RandomTermGenerationStopCriterion<CONF: RandomTermGenerationConfig> {
    /// Depth (root = 0) at which random symbol selection stops.
    pub max_depth: u32,
    /// Symbol used for nodes generated at `max_depth`. When this is a language
    /// symbol, generation panics unless its arity is `Fixed(0)`, since such a
    /// node cannot have children.
    pub symbol_at_end: TermGenerationSymbol<CONF::LOS, CONF::PATTERN>,
}
impl<CONF: RandomTermGenerationConfig> RandomTermGenerationStopCriterion<CONF> {
pub fn new(
max_depth: u32,
symbol_at_end: TermGenerationSymbol<CONF::LOS, CONF::PATTERN>,
) -> Self {
Self {
max_depth,
symbol_at_end,
}
}
}
/// Generates a random term, starting from the root at depth 0.
///
/// Symbols are drawn from `probas` using `rng`. Nodes reaching
/// `stop_crit.max_depth` are built from `stop_crit.symbol_at_end`. Terms are
/// created through `factory`, and patterns are expanded with `context`.
pub fn generate_random_term<CONF: RandomTermGenerationConfig>(
    probas: &TermGenerationSymbolsProbabilities<CONF>,
    stop_crit: &RandomTermGenerationStopCriterion<CONF>,
    context: &CONF::CONTEXT,
    rng: &mut StdRng,
    factory: &mut TermFactory<CONF::LOS>,
) -> LanguageTerm<CONF::LOS> {
    // The root of the generated term is the first node, at depth zero.
    let root_depth: u32 = 0;
    generate_random_term_rec::<CONF>(probas, root_depth, stop_crit, context, rng, factory)
}
/// Recursive worker behind `generate_random_term`.
///
/// `depth` is the depth of the node being generated (the root is at depth 0).
/// While `depth < stop_crit.max_depth`, a symbol is drawn from `probas`:
/// - a language symbol of fixed arity `n` gets `n` children, generated in order
///   at `depth + 1`;
/// - a pattern is expanded via `generate_term_from_pattern`.
///
/// Once `depth` reaches `stop_crit.max_depth`, the node is built from
/// `stop_crit.symbol_at_end` instead. This bounds the recursion.
///
/// # Panics
///
/// - If the depth limit is reached and `stop_crit.symbol_at_end` is a language
///   symbol whose arity is not `Fixed(0)`.
/// - If a language symbol with variadic arity is drawn.
fn generate_random_term_rec<CONF: RandomTermGenerationConfig>(
    probas: &TermGenerationSymbolsProbabilities<CONF>,
    depth: u32,
    stop_crit: &RandomTermGenerationStopCriterion<CONF>,
    context: &CONF::CONTEXT,
    rng: &mut StdRng,
    factory: &mut TermFactory<CONF::LOS>,
) -> LanguageTerm<CONF::LOS> {
    if depth >= stop_crit.max_depth {
        return match &stop_crit.symbol_at_end {
            TermGenerationSymbol::LanguageSymbol(s) => {
                // The node is built with no children, so only a nullary operator
                // yields a well-formed term here.
                assert!(
                    matches!(s.arity(), LanguageOperatorArity::Fixed(0)),
                    "the symbol used at max depth must be a nullary (arity 0) operator"
                );
                LanguageTermNode::build(s.clone(), vec![], factory)
            }
            TermGenerationSymbol::Pattern(p) => p.generate_term_from_pattern(rng, context, factory),
        };
    }
    let symbol = probas.get_random_symbol(rng);
    match symbol {
        TermGenerationSymbol::LanguageSymbol(s) => {
            let arity = match s.arity() {
                LanguageOperatorArity::Fixed(n) => n,
                LanguageOperatorArity::Variadic => {
                    panic!("variadic operators are not supported in random term generation")
                }
            };
            // Children are generated sequentially, left to right, so the RNG is
            // consumed in a deterministic order. The exact-size range lets
            // `collect` allocate the vector once.
            let sub_terms: Vec<_> = (0..arity)
                .map(|_| {
                    generate_random_term_rec(probas, depth + 1, stop_crit, context, rng, factory)
                })
                .collect();
            LanguageTermNode::build(s.clone(), sub_terms, factory)
        }
        TermGenerationSymbol::Pattern(p) => p.generate_term_from_pattern(rng, context, factory),
    }
}