use rand::prelude::StdRng;
use rand::RngExt;
use std::collections::HashMap;
use super::types::{RandomTermGenerationConfig, TermGenerationSymbol};
/// Discrete probability distribution over term-generation symbols, stored as a cumulative table.
///
/// `ordered_bounds` is expected to hold one more entry than `ordered_symbols`: symbol `i` owns
/// the cumulative interval from `ordered_bounds[i]` to `ordered_bounds[i + 1]`.
pub struct TermGenerationSymbolsProbabilities<CONF>
where
    CONF: RandomTermGenerationConfig,
{
    /// Symbols, in the order in which their probability mass was accumulated.
    pub ordered_symbols: Vec<TermGenerationSymbol<CONF::LOS, CONF::PATTERN>>,
    /// Cumulative probability bounds; `ordered_bounds[i + 1] - ordered_bounds[i]` is the
    /// probability of `ordered_symbols[i]`.
    pub ordered_bounds: Vec<f32>,
}
/// Reasons why a symbol → probability map can be rejected when building a
/// `TermGenerationSymbolsProbabilities`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum InteractionSymbolsProbabilitiesError {
    /// A symbol was given a probability outside `[0, 1]` (beyond a small tolerance), or NaN.
    ///
    /// NOTE: the variant name spells the digit `0` with the letter `O`; it is kept unchanged
    /// for compatibility with existing callers.
    SymbolProbabilityMustBeBetweenOAnd1,
    /// The probabilities do not add up to 1 (beyond a small tolerance).
    SumOfProbabilitiesMustBe1,
}
impl<CONF: RandomTermGenerationConfig> TermGenerationSymbolsProbabilities<CONF> {
    /// Builds a cumulative distribution from a symbol → probability map.
    ///
    /// Symbols are stored in the map's iteration order. `ordered_bounds[0]` is `0.0` and
    /// `ordered_bounds[i + 1]` is the running sum of the probabilities of
    /// `ordered_symbols[0..=i]`.
    ///
    /// Note: `HashMap` iterates in an arbitrary order, which with the default hasher differs
    /// between map instances and program runs. The resulting symbol order, and therefore which
    /// symbol a given random draw maps to, is not reproducible across runs even with a seeded
    /// RNG.
    ///
    /// # Errors
    ///
    /// * `SymbolProbabilityMustBeBetweenOAnd1` if any probability lies outside `[0, 1]`
    ///   (a tolerance of `1e-6` is allowed on either side) or is NaN.
    /// * `SumOfProbabilitiesMustBe1` if the probabilities do not sum to `1` within `1e-6`.
    ///   This includes an empty map, whose sum is `0`.
    pub fn from_map(
        map: HashMap<TermGenerationSymbol<CONF::LOS, CONF::PATTERN>, f32>,
    ) -> Result<Self, InteractionSymbolsProbabilitiesError> {
        const TOLERANCE: f64 = 1e-6;

        let mut ordered_symbols = Vec::with_capacity(map.len());
        let mut ordered_bounds = Vec::with_capacity(map.len() + 1);
        ordered_bounds.push(0.0_f32);

        // Accumulate in f64 so that rounding error from summing many f32 values cannot push a
        // valid distribution outside the tolerance window checked below.
        let mut sum = 0.0_f64;
        for (symbol, probability) in map {
            let probability = f64::from(probability);
            // `contains` is false for NaN, so NaN probabilities are rejected here as well.
            if !(-TOLERANCE..=1.0 + TOLERANCE).contains(&probability) {
                return Err(
                    InteractionSymbolsProbabilitiesError::SymbolProbabilityMustBeBetweenOAnd1,
                );
            }
            ordered_symbols.push(symbol);
            sum += probability;
            ordered_bounds.push(sum as f32);
        }

        if !(1.0 - TOLERANCE..=1.0 + TOLERANCE).contains(&sum) {
            return Err(InteractionSymbolsProbabilitiesError::SumOfProbabilitiesMustBe1);
        }
        assert_eq!(ordered_bounds.len(), ordered_symbols.len() + 1);

        Ok(Self {
            ordered_symbols,
            ordered_bounds,
        })
    }

    /// Draws one symbol according to the stored distribution.
    ///
    /// A uniform value `u` in `[0, 1)` is drawn. The returned symbol is the first `i` with
    /// `u < ordered_bounds[i + 1]`; because earlier intervals were skipped, this also gives
    /// `ordered_bounds[i] <= u`. The comparison is strict, so symbols with zero probability
    /// (or a slightly negative one, within the tolerance accepted by `from_map`) are never
    /// returned. If the probabilities sum to slightly less than 1 and `u` lands past the last
    /// bound, the last symbol with positive probability is returned.
    ///
    /// # Panics
    ///
    /// Panics if no symbol has a positive probability. This cannot happen for a value built
    /// by `from_map`, whose probabilities sum to approximately 1.
    pub fn get_random_symbol(
        &self,
        rng: &mut StdRng,
    ) -> TermGenerationSymbol<CONF::LOS, CONF::PATTERN> {
        let draw = rng.random_range(0.0_f32..1.0_f32);
        // Each window `[lower, upper]` is the cumulative interval owned by one symbol.
        let idx = self
            .ordered_bounds
            .windows(2)
            .position(|interval| draw < interval[1])
            .or_else(|| {
                // Floating-point slack: the total mass may be just below 1.
                self.ordered_bounds
                    .windows(2)
                    .rposition(|interval| interval[1] > interval[0])
            })
            .expect("distribution must contain at least one symbol with positive probability");
        self.ordered_symbols[idx].clone()
    }
}