use std::{collections::HashMap, ops::Index};
use probability::distribution::Sample;
use rand::{FromEntropy, rngs::StdRng};
use super::distributions::{self, Value};
/// A map from choice identifiers to the values chosen for them.
#[derive(Clone, Debug)]
pub struct Choicemap {
    // Keyed by the identifier passed to `add_choice`; re-adding an
    // identifier overwrites the previous value.
    values: HashMap<String, Value>,
}

impl Choicemap {
    /// Creates an empty choicemap.
    pub fn new() -> Choicemap {
        Choicemap { values: HashMap::new() }
    }

    /// Builds a choicemap from `(identifier, value)` pairs.
    ///
    /// Pairs are inserted in order, so if an identifier occurs more than once
    /// the last value given for it wins.
    pub fn from(choices: Vec<(&str, Value)>) -> Choicemap {
        let mut res = Choicemap::new();
        // `choices` is owned, so each value is moved in rather than cloned.
        for (identifier, value) in choices {
            res.add_choice(identifier, value);
        }
        res
    }

    /// Records `value` under `identifier`, replacing any existing value.
    pub fn add_choice(&mut self, identifier: &str, value: Value) {
        self.values.insert(identifier.to_string(), value);
    }

    /// Returns every `(identifier, value)` pair, with the values cloned.
    ///
    /// The order is the underlying map's iteration order, which is unspecified.
    pub fn get_choices(&self) -> Vec<(&str, Value)> {
        // Walk the entries directly instead of re-looking-up every key.
        self.values
            .iter()
            .map(|(identifier, value)| (identifier.as_str(), value.clone()))
            .collect()
    }

    /// Returns `true` if a value has been recorded under `key`.
    pub fn contains_key(&self, key: &str) -> bool {
        self.values.contains_key(key)
    }
}

impl Index<&str> for Choicemap {
    type Output = Value;

    /// Returns the value recorded under `index`.
    ///
    /// # Panics
    ///
    /// Panics if no value has been recorded under `index`.
    fn index(&self, index: &str) -> &Self::Output {
        match self.values.get(index) {
            Some(v) => v,
            None => panic!("Value not present in choicemap."),
        }
    }
}

impl Index<&String> for Choicemap {
    type Output = Value;

    /// Returns the value recorded under `index`; delegates to the `&str` index.
    ///
    /// # Panics
    ///
    /// Panics if no value has been recorded under `index`.
    fn index(&self, index: &String) -> &Self::Output {
        &self[index.as_str()]
    }
}
/// A set of recorded choices paired with a log score.
#[derive(Debug, Clone)]
pub struct Trace {
    /// Accumulated log score: `new` starts it at `0.0` and `update_logscore`
    /// adds to it.
    pub log_score: f64,
    /// The values recorded for this trace, keyed by identifier.
    pub choices: Choicemap,
}

impl Trace {
    /// Creates a trace with a log score of `0.0` and no choices.
    pub fn new() -> Trace {
        Trace { log_score: 0.0, choices: Choicemap::new() }
    }

    /// Adds `new_value` to the trace's log score.
    pub(crate) fn update_logscore(&mut self, new_value: f64) {
        self.log_score += new_value;
    }

    /// Renders every choice as one `key => value` line (each terminated by `\n`).
    ///
    /// Lines are sorted by key so the output is deterministic; the underlying
    /// map iterates in an unspecified order.
    pub fn get_trace_string(&self) -> String {
        use std::fmt::Write;

        let mut choices = self.choices.get_choices();
        // Keys in a map are unique, so an unstable sort is sufficient.
        choices.sort_unstable_by(|a, b| a.0.cmp(b.0));
        let mut s = String::new();
        for (key, value) in &choices {
            // Append in place instead of allocating a temporary String per line.
            writeln!(s, "{} => {}", key, value).expect("writing to a String never fails");
        }
        s
    }

    /// Samples one trace with probability proportional to `exp(log_score)`.
    ///
    /// Returns `None` when `traces` is empty; otherwise returns a clone of the
    /// sampled trace.
    ///
    /// The maximum log score is subtracted before exponentiating (the
    /// log-sum-exp shift). This leaves the normalized probabilities unchanged.
    /// It prevents every weight underflowing to `0.0` when all log scores are
    /// very negative (below about -745). Without it the normalization would
    /// divide by zero and produce NaN probabilities. It also stops weights
    /// overflowing to infinity for large scores.
    pub fn sample_weighted_traces(traces: &[Trace]) -> Option<Trace> {
        if traces.is_empty() {
            return None;
        }
        let max_log_score = traces
            .iter()
            .map(|t| t.log_score)
            .fold(f64::NEG_INFINITY, f64::max);
        // Only shift by a finite maximum: subtracting an infinite one would
        // turn weights into NaN. Degenerate inputs (every score -inf, or some
        // +inf) are passed through unshifted.
        let shift = if max_log_score.is_finite() { max_log_score } else { 0.0 };
        let weights: Vec<f64> = traces
            .iter()
            .map(|t| (t.log_score - shift).exp())
            .collect();
        let total: f64 = weights.iter().sum();
        let probabilities: Vec<f64> = weights.iter().map(|w| w / total).collect();
        let categorical = probability::distribution::Categorical::new(&probabilities[..]);
        let index = categorical.sample(&mut distributions::Source(StdRng::from_entropy()));
        Some(traces[index].clone())
    }
}
/// Traces are equal when their log scores are equal (`f64` equality, so a NaN
/// score never equals anything). The recorded choices are not compared.
impl PartialEq for Trace {
    fn eq(&self, other: &Trace) -> bool {
        let mine = self.log_score;
        let theirs = other.log_score;
        mine == theirs
    }
}
/// Orders traces by log score alone; the recorded choices are not compared.
impl PartialOrd for Trace {
    /// Delegates to `f64::partial_cmp` on the log scores.
    ///
    /// Returns `None` when either log score is NaN. This keeps the ordering
    /// consistent with `PartialEq`, under which a NaN score never compares
    /// equal. It also makes `<`, `<=`, `>` and `>=` all `false` for NaN,
    /// rather than treating NaN as equal to everything.
    fn partial_cmp(&self, other: &Trace) -> std::option::Option<std::cmp::Ordering> {
        self.log_score.partial_cmp(&other.log_score)
    }
}