r_gen/
trace.rs

1//! Trace objects that represent a run of a generative model.
2use std::{collections::HashMap, ops::Index};
3
4use probability::distribution::Sample;
5use rand::{FromEntropy, rngs::StdRng};
6
7use super::distributions::{self, Value};
8
9/// A struct to hold all of the random choices made during the execution of a generative model. 
10#[derive(Clone, Debug)]
11pub struct Choicemap{
12    values : HashMap<String, Value>
13}
14
15//Implement standard functions for choice maps. 
16impl Choicemap {
17    /// Create a new, blank choice map. 
18    /// # Example
19    /// ```
20    /// use r_gen::trace::Choicemap; 
21    /// use r_gen::distributions::Value; 
22    /// let mut choicemap = Choicemap::new(); 
23    /// choicemap.add_choice("p", Value::Real(0.5)); 
24    /// ```
25    pub fn new() -> Choicemap {
26        Choicemap{ values : HashMap::new() }
27    }
28
29    /// Create a new choicemap with given choices in it. 
30    /// # Example 
31    /// ```
32    /// use r_gen::trace::Choicemap; 
33    /// use r_gen::distributions::Value; 
34    /// let mut choicemap = Choicemap::from(vec![("p", Value::Real(0.5))]); 
35    /// ```
36    pub fn from(choices : Vec<(&str, Value)>) -> Choicemap {
37        let mut res = Choicemap::new(); 
38        choices.iter().for_each(|(s, v)| res.add_choice(*s, v.clone())); 
39        res
40    }
41
42    /// Add a choice to this choicemap. 
43    /// # Example
44    /// ```
45    /// use r_gen::trace::Choicemap; 
46    /// use r_gen::distributions::Value; 
47    /// let mut choicemap = Choicemap::new(); 
48    /// choicemap.add_choice("p", Value::Real(0.5)); 
49    /// ```
50    pub fn add_choice(&mut self, identifier : &str, value : Value) {
51        self.values.insert(identifier.to_string(), value); 
52    }
53
54    /// Get a list of the choices that were made in the generative model. 
55    pub fn get_choices(&self) -> Vec<(&str, Value)> {
56        self.values.keys().map(|k| (k.as_str(), self.values.get(k).unwrap().clone())).collect() 
57    }
58
59    /// Check whether or not the given key is already in the choicemap. 
60    pub fn contains_key(&self, key : &str) -> bool {
61        self.values.contains_key(key)
62    }
63}
64
65impl Index<&str> for Choicemap {
66    type Output = Value;
67
68    fn index(&self, index: &str) -> &Self::Output {
69        match self.values.get(index) {
70            Some(v) => v, 
71            None => panic!("Value not present in choicemap.")
72        }
73    }
74}
75
76impl Index<&String> for Choicemap {
77    type Output = Value;
78
79    fn index(&self, index: &String) -> &Self::Output {
80        match self.values.get(index.as_str()) {
81            Some(v) => v, 
82            None => panic!("Value not present in choicemap.")
83        }
84    }
85}
86
87/**
88The trace struct. This holds information about the execution of a gnerative model. 
89*/
90#[derive(Debug, Clone)]
91pub struct Trace {
92    /// The log joint liklihood of all of the random decisions in the trace. 
93    pub log_score : f64, 
94    /// The Choicemap that holds the list of the actual decisions that were made in the execution of the generative model.
95    pub choices : Choicemap
96}
97
98
99impl Trace {
100    /**
101    Create a new blank trace. It begins with an empty choice map and a log score of 0 (which corresponds to a 
102    probability of 1.0 when exponentiated.)
103    */
104    pub fn new() -> Trace {
105        Trace{ log_score : 0.0, choices : Choicemap::new() }
106    }
107
108    /**
109    Update the logscore of a given trace. 
110    */
111    pub(crate) fn update_logscore(&mut self, new_value : f64) {
112        self.log_score = self.log_score + new_value; 
113    }
114
115    /**
116    Return a string that discribes the random decisions made by the model in this trace.
117    # Example 
118    ```
119    #[macro_use]
120    use r_gen::{sample, r_gen}; 
121    use r_gen::{simulate, distributions::{Value, Distribution}, trace::{Choicemap, Trace}}; 
122    use std::rc::Rc;
123
124    #[r_gen]
125    fn my_biased_coin_model(():()){
126        let p = sample!(format!("p"), Distribution::Beta(1.0, 1.0));            //Sample p from a uniform. 
127        sample!(format!("num_heads"), Distribution::Binomial(100, p.into()));   //Flip 100 coins where P(Heads)=p
128    }
129    let (trace, result) = simulate(&mut my_biased_coin_model, ()); 
130    println!("{}", trace.get_trace_string()); 
131    ```
132    */
133    pub fn get_trace_string(&self) -> String {
134        let mut s = String::new(); 
135        for (key, value) in &self.choices.get_choices() {
136            s.push_str(&format!("{} => {}\n", key, value));
137        }
138        s
139    }
140
141    /**
142    Sample a trace from a vector of traces according to a categorical distribution. The weights for the distribution are 
143    the scores of the traces rescaled by a normalizing constant. This function is intended to be used in an importance
144    resampling algorithm.
145    */
146    pub fn sample_weighted_traces(traces : &Vec<Trace>) -> Option<Trace> {
147        if traces.len() == 0 {
148            None
149        } else {
150            let values : Vec<f64> = traces.iter().map(|x| x.log_score.exp()).collect();
151            let sum : f64 = values.iter().map(|x| x).sum(); 
152            let normalized_values : Vec<f64> = values.iter().map(|x| x / sum).collect(); 
153            let categorical = probability::distribution::Categorical::new(&normalized_values[..]); 
154            
155            Some(traces[categorical.sample(&mut distributions::Source(StdRng::from_entropy()))].clone())
156        }
157    }
158}
159
160//Implement equivelance for traces based on the log_score. 
161impl PartialEq for Trace {
162    fn eq(&self, other: &Trace) -> bool { 
163        self.log_score == other.log_score
164        }
165}
166
167impl PartialOrd for Trace {
168    fn partial_cmp(&self, other: &Trace) -> std::option::Option<std::cmp::Ordering> { 
169        if self.log_score > other.log_score {
170            Some(std::cmp::Ordering::Greater)
171        } else if self.log_score < other.log_score {
172            Some(std::cmp::Ordering::Less)
173        } else {
174            Some(std::cmp::Ordering::Equal)
175        }
176    }
177}