r_gen/trace.rs
1//! Trace objects that represent a run of a generative model.
2use std::{collections::HashMap, ops::Index};
3
4use probability::distribution::Sample;
5use rand::{FromEntropy, rngs::StdRng};
6
7use super::distributions::{self, Value};
8
9/// A struct to hold all of the random choices made during the execution of a generative model.
10#[derive(Clone, Debug)]
11pub struct Choicemap{
12 values : HashMap<String, Value>
13}
14
15//Implement standard functions for choice maps.
16impl Choicemap {
17 /// Create a new, blank choice map.
18 /// # Example
19 /// ```
20 /// use r_gen::trace::Choicemap;
21 /// use r_gen::distributions::Value;
22 /// let mut choicemap = Choicemap::new();
23 /// choicemap.add_choice("p", Value::Real(0.5));
24 /// ```
25 pub fn new() -> Choicemap {
26 Choicemap{ values : HashMap::new() }
27 }
28
29 /// Create a new choicemap with given choices in it.
30 /// # Example
31 /// ```
32 /// use r_gen::trace::Choicemap;
33 /// use r_gen::distributions::Value;
34 /// let mut choicemap = Choicemap::from(vec![("p", Value::Real(0.5))]);
35 /// ```
36 pub fn from(choices : Vec<(&str, Value)>) -> Choicemap {
37 let mut res = Choicemap::new();
38 choices.iter().for_each(|(s, v)| res.add_choice(*s, v.clone()));
39 res
40 }
41
42 /// Add a choice to this choicemap.
43 /// # Example
44 /// ```
45 /// use r_gen::trace::Choicemap;
46 /// use r_gen::distributions::Value;
47 /// let mut choicemap = Choicemap::new();
48 /// choicemap.add_choice("p", Value::Real(0.5));
49 /// ```
50 pub fn add_choice(&mut self, identifier : &str, value : Value) {
51 self.values.insert(identifier.to_string(), value);
52 }
53
54 /// Get a list of the choices that were made in the generative model.
55 pub fn get_choices(&self) -> Vec<(&str, Value)> {
56 self.values.keys().map(|k| (k.as_str(), self.values.get(k).unwrap().clone())).collect()
57 }
58
59 /// Check whether or not the given key is already in the choicemap.
60 pub fn contains_key(&self, key : &str) -> bool {
61 self.values.contains_key(key)
62 }
63}
64
65impl Index<&str> for Choicemap {
66 type Output = Value;
67
68 fn index(&self, index: &str) -> &Self::Output {
69 match self.values.get(index) {
70 Some(v) => v,
71 None => panic!("Value not present in choicemap.")
72 }
73 }
74}
75
76impl Index<&String> for Choicemap {
77 type Output = Value;
78
79 fn index(&self, index: &String) -> &Self::Output {
80 match self.values.get(index.as_str()) {
81 Some(v) => v,
82 None => panic!("Value not present in choicemap.")
83 }
84 }
85}
86
87/**
88The trace struct. This holds information about the execution of a gnerative model.
89*/
90#[derive(Debug, Clone)]
91pub struct Trace {
92 /// The log joint liklihood of all of the random decisions in the trace.
93 pub log_score : f64,
94 /// The Choicemap that holds the list of the actual decisions that were made in the execution of the generative model.
95 pub choices : Choicemap
96}
97
98
99impl Trace {
100 /**
101 Create a new blank trace. It begins with an empty choice map and a log score of 0 (which corresponds to a
102 probability of 1.0 when exponentiated.)
103 */
104 pub fn new() -> Trace {
105 Trace{ log_score : 0.0, choices : Choicemap::new() }
106 }
107
108 /**
109 Update the logscore of a given trace.
110 */
111 pub(crate) fn update_logscore(&mut self, new_value : f64) {
112 self.log_score = self.log_score + new_value;
113 }
114
115 /**
116 Return a string that discribes the random decisions made by the model in this trace.
117 # Example
118 ```
119 #[macro_use]
120 use r_gen::{sample, r_gen};
121 use r_gen::{simulate, distributions::{Value, Distribution}, trace::{Choicemap, Trace}};
122 use std::rc::Rc;
123
124 #[r_gen]
125 fn my_biased_coin_model(():()){
126 let p = sample!(format!("p"), Distribution::Beta(1.0, 1.0)); //Sample p from a uniform.
127 sample!(format!("num_heads"), Distribution::Binomial(100, p.into())); //Flip 100 coins where P(Heads)=p
128 }
129 let (trace, result) = simulate(&mut my_biased_coin_model, ());
130 println!("{}", trace.get_trace_string());
131 ```
132 */
133 pub fn get_trace_string(&self) -> String {
134 let mut s = String::new();
135 for (key, value) in &self.choices.get_choices() {
136 s.push_str(&format!("{} => {}\n", key, value));
137 }
138 s
139 }
140
141 /**
142 Sample a trace from a vector of traces according to a categorical distribution. The weights for the distribution are
143 the scores of the traces rescaled by a normalizing constant. This function is intended to be used in an importance
144 resampling algorithm.
145 */
146 pub fn sample_weighted_traces(traces : &Vec<Trace>) -> Option<Trace> {
147 if traces.len() == 0 {
148 None
149 } else {
150 let values : Vec<f64> = traces.iter().map(|x| x.log_score.exp()).collect();
151 let sum : f64 = values.iter().map(|x| x).sum();
152 let normalized_values : Vec<f64> = values.iter().map(|x| x / sum).collect();
153 let categorical = probability::distribution::Categorical::new(&normalized_values[..]);
154
155 Some(traces[categorical.sample(&mut distributions::Source(StdRng::from_entropy()))].clone())
156 }
157 }
158}
159
160//Implement equivelance for traces based on the log_score.
161impl PartialEq for Trace {
162 fn eq(&self, other: &Trace) -> bool {
163 self.log_score == other.log_score
164 }
165}
166
167impl PartialOrd for Trace {
168 fn partial_cmp(&self, other: &Trace) -> std::option::Option<std::cmp::Ordering> {
169 if self.log_score > other.log_score {
170 Some(std::cmp::Ordering::Greater)
171 } else if self.log_score < other.log_score {
172 Some(std::cmp::Ordering::Less)
173 } else {
174 Some(std::cmp::Ordering::Equal)
175 }
176 }
177}