use rand::prelude::StdRng;
use crate::core::terms::term::LanguageTerm;
use super::probas::TermGenerationSymbolsProbabilities;
use super::types::{RandomTermGenerationConfig, TermGenerationSymbol, TermPatternForRandomGeneration};
/// Stopping rule for random term generation.
///
/// Bounds the depth of generated terms: once a branch reaches `max_depth`, no further
/// symbol is drawn at random and `symbol_at_end` is used to close that branch instead.
pub struct RandomTermGenerationStopCriterion<CONF>
where
    CONF: RandomTermGenerationConfig,
{
    /// Depth at which random symbol selection stops (generation starts the root at depth 0).
    pub max_depth: u32,
    /// Symbol used to close every branch that reaches `max_depth`.
    /// When it is a language symbol, generation asserts that its arity is 0.
    pub symbol_at_end: TermGenerationSymbol<CONF::LOS, CONF::PATTERN>,
}
impl<CONF : RandomTermGenerationConfig> RandomTermGenerationStopCriterion<CONF> {
pub fn new(max_depth : u32,symbol_at_end : TermGenerationSymbol<CONF::LOS,CONF::PATTERN>) -> Self {
Self { max_depth, symbol_at_end }
}
}
/// Generates a random [`LanguageTerm`] by drawing symbols from `probas`.
///
/// Generation starts with the root at depth 0. Once a branch reaches
/// `stop_crit.max_depth`, it is closed with `stop_crit.symbol_at_end` instead of
/// drawing another symbol. All randomness is taken from `rng`, and `context` is
/// handed to pattern symbols when they generate their terms.
///
/// # Panics
///
/// Panics (via an assertion) if the depth limit is reached and
/// `stop_crit.symbol_at_end` is a language symbol whose arity is not 0.
pub fn generate_random_term<CONF>(
    probas: &TermGenerationSymbolsProbabilities<CONF>,
    stop_crit: &RandomTermGenerationStopCriterion<CONF>,
    context: &CONF::CONTEXT,
    rng: &mut StdRng,
) -> LanguageTerm<CONF::LOS>
where
    CONF: RandomTermGenerationConfig,
{
    const ROOT_DEPTH: u32 = 0;
    generate_random_term_rec(probas, ROOT_DEPTH, stop_crit, context, rng)
}
/// Recursive worker behind `generate_random_term`.
///
/// * `depth` — depth of the term being generated (the public entry point starts at 0).
///
/// Once `depth` reaches `stop_crit.max_depth`, no symbol is drawn from `probas`; the
/// term is built from `stop_crit.symbol_at_end` instead. Otherwise a symbol is drawn
/// with `probas.get_random_symbol`:
/// * a language symbol gets one freshly generated sub-term, at `depth + 1`, per unit
///   of its arity (as reported by `CONF::get_arity`);
/// * a pattern produces its term directly through `generate_term_from_pattern`.
///
/// # Panics
///
/// Panics if the depth limit is reached and `stop_crit.symbol_at_end` is a language
/// symbol whose arity is not 0.
fn generate_random_term_rec<CONF : RandomTermGenerationConfig>(
    probas : &TermGenerationSymbolsProbabilities<CONF>,
    depth : u32,
    stop_crit : &RandomTermGenerationStopCriterion<CONF>,
    context : &CONF::CONTEXT,
    rng : &mut StdRng
) -> LanguageTerm<CONF::LOS> {
    if depth >= stop_crit.max_depth {
        return match &stop_crit.symbol_at_end {
            TermGenerationSymbol::LanguageSymbol(s) => {
                // A leaf is required here: any sub-term would exceed `max_depth`.
                assert!(
                    CONF::get_arity(s) == 0,
                    "RandomTermGenerationStopCriterion: the language symbol used at maximum depth must have arity 0"
                );
                LanguageTerm::new(s.clone(), vec![])
            },
            TermGenerationSymbol::Pattern(p) => p.generate_term_from_pattern(rng, context),
        };
    }
    let symbol = probas.get_random_symbol(rng);
    match symbol {
        TermGenerationSymbol::LanguageSymbol(s) => {
            // Sub-terms are generated left to right, each consuming the shared `rng` in turn;
            // `collect` reserves capacity from the arity range's exact size hint.
            let sub_terms : Vec<_> = (0..CONF::get_arity(&s))
                .map(|_| generate_random_term_rec(probas, depth + 1, stop_crit, context, rng))
                .collect();
            LanguageTerm::new(s.clone(), sub_terms)
        },
        TermGenerationSymbol::Pattern(p) => p.generate_term_from_pattern(rng, context),
    }
}