r_gen/
lib.rs

//! A framework for writing generative models in the Rust programming language.
2/*!
3# Example of Importance Sampling
4```rust
5use r_gen::{sample, r_gen}; 
6use r_gen::{simulate, generate, distributions::{Value, Distribution}, trace::{Choicemap, Trace}}; 
7use std::rc::Rc;
8fn main() {
9    //Define our generative model. 
10    #[r_gen]
11    fn my_model(():()){
12        let p = sample!(format!("p"), Distribution::Beta(1.0, 1.0)); 
13        sample!(format!("num_heads"), Distribution::Binomial(100, p.into()));
14    }
15
16    //Run the model once in the forward direction and record the observations. 
17    let (t, _) : (Trace, _)= simulate(&mut my_model, ());
18    let choices = Choicemap::from(vec![("num_heads", t.choices["num_heads"].clone())]);
19
20    //Perform importance resampling to get an estimate for the value of p. 
21    let mut traces = Vec::new();
22    for _ in 0..1000 {
23        let (gt, _) : (Trace, _)= generate(&mut my_model, (), &choices);
24        traces.push(gt); 
25    }
26    
27    println!("Actual value for p:\t {}", t.choices["p"]); 
28    println!("Generated value for p:\t {}", Trace::sample_weighted_traces(&traces).unwrap().choices["p"]); 
29}
30```
31Outputs:
32```shell
33Actual value for p:      0.8011431168181488
34Generated value for p:   0.7879998086169554
35```
36*/
37
38#![warn(missing_docs)]
39
40#[allow(unused_attributes)]
41#[macro_use]
42pub use r_gen_macro::r_gen; 
43
44#[allow(unused_imports)]
45#[macro_use]
46extern crate r_gen_macro;
47
/**
Sample a value from a distribution and record it in the current trace.

Inside a function annotated with `#[r_gen]` the macro is written as
`sample!(name, distribution)`:

* `name` is an expression that evaluates to a `String` (typically built with `format!`).
  It becomes the address of the choice in the trace.
* `distribution` is the distribution to sample from.

The macro evaluates to the sampled value, which has type `Value`. The result can be bound
with `let`, stored in a collection for later use, or discarded.

In its expanded form the macro also takes two leading identifiers, which the examples below
do not write because `#[r_gen]` supplies them:

* the sampler closure (an `Rc<dyn FnMut(&String, _, &mut Trace) -> Value>`);
* the trace, which is handed to that closure.

# Panics
Panics if the sampler `Rc` is shared (another strong or weak reference to it exists), because
a unique mutable borrow of the closure is required to call it.

# Example
```
use r_gen::{sample, r_gen}; 
use r_gen::{simulate, distributions::{Value, Distribution}, trace::{Choicemap, Trace}}; 
use std::rc::Rc;

#[r_gen]
fn my_model(():()) {
    let p = sample!(format!("p"), Distribution::Bernoulli(0.5)); 
    print!("p: {}", p);
}
simulate(&mut my_model, ()); 
```
Here `p` holds the sampled value (of type `Value`) and can be used later in the model.

# Example (Store results in an array)
```
use r_gen::{sample, r_gen}; 
use r_gen::{simulate, distributions::{Value, Distribution}, trace::{Choicemap, Trace}}; 
use std::rc::Rc;

#[r_gen] 
fn flip_my_biased_coins((n, p) : (usize, f64)) {
    let mut flips = vec![Value::Integer(0); n]; 
    for i in 0..n {
        flips[i] = sample!(format!("flip_{}", i), Distribution::Bernoulli(p)); 
    }
}
let (tr, _) = simulate(&mut flip_my_biased_coins, (10, 0.5));
println!("{}", tr.get_trace_string()); 
```
*/
#[macro_export]
macro_rules! sample {
    ($sample_ident:ident $trace_ident:ident $name:expr, $dist:expr) => {
        // No trailing `;`: the expansion must be a plain expression so that
        // `let x = sample!(...)` yields the sampled value. A trailing semicolon in
        // expression position triggers the `semicolon_in_expressions_from_macros` lint.
        // The fully qualified path keeps the expansion independent of the caller's imports.
        (::std::rc::Rc::get_mut(&mut $sample_ident)
            .expect("sampler Rc must be uniquely owned to be called mutably"))(
            &$name,
            $dist,
            $trace_ident,
        )
    };
}
88
89
#[allow(unused_variables)]
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use crate::{
        distributions::{Distribution, Value},
        generate, simulate,
        trace::{Choicemap, Trace},
    };

    /// Hand-written generative functions (no `#[r_gen]` sugar) can be simulated.
    #[test]
    fn test_simulate() {
        // flip ~ Bernoulli(p)
        fn flip_biased_coin(
            mut sample: Rc<dyn FnMut(&String, Distribution, &mut Trace) -> Value>,
            trace: &mut Trace,
            p: f64,
        ) {
            let draw = Rc::get_mut(&mut sample).unwrap();
            let flip = draw(&String::from("flip"), Distribution::Bernoulli(p), trace);
        }
        let (single_flip, _): (Trace, _) = simulate(&mut flip_biased_coin, 0.2);
        println!("test_simulate flip_biased_coin trace: {:?}", single_flip);

        // heads ~ Binomial(n, p)
        fn flip_multiple_biased_coins(
            mut sample: Rc<dyn FnMut(&String, Distribution, &mut Trace) -> Value>,
            trace: &mut Trace,
            (n, p): (i64, f64),
        ) {
            let draw = Rc::get_mut(&mut sample).unwrap();
            let heads = draw(&String::from("heads"), Distribution::Binomial(n, p), trace);
        }
        let (many_flips, _): (Trace, _) = simulate(&mut flip_multiple_biased_coins, (5, 0.7));
        println!("test_simulate flip_multiple_biased_coin trace: {:?}", many_flips);
    }

    /// `generate` runs a model with the "heads" address constrained to a fixed value.
    #[test]
    fn test_generate() {
        #[r_gen]
        fn flip_multiple_biased_coins((n, p) : (i64, f64)) {
            let heads = sample!(format!("heads"), Distribution::Binomial(n, p));
            println!("Result of flips: {:?}", heads);
        }
        let mut observations = Choicemap::new();
        observations.add_choice("heads", Value::Integer(4));
        let (constrained, _): (Trace, _) =
            generate(&mut flip_multiple_biased_coins, (5, 0.7), &observations);
        println!("Trace from generate: {:?}", constrained);
    }

    /// Several `#[r_gen]` model shapes: scalar argument, tuple argument, and choices
    /// stored into a vector inside a loop.
    #[test]
    fn test_macros() {
        #[r_gen]
        fn my_coin_model(p : f64) {
            let flip = sample!(format!("flip"), Distribution::Bernoulli(p));
            println!("Result of flip: {:?}", flip);
        }
        let (coin_trace, _) = simulate(&mut my_coin_model, 0.2);
        println!("testing macro: {:?}", coin_trace);

        #[r_gen]
        fn flip_multiple_biased_coins((n, p) : (i64, f64)) {
            let heads = sample!(format!("heads"), Distribution::Binomial(n, p));
            println!("Result of flips: {:?}", heads);
        }
        let (binomial_trace, _): (Trace, _) = simulate(&mut flip_multiple_biased_coins, (5, 0.7));
        println!("tesing macro: {:?}", binomial_trace);

        #[r_gen]
        fn flip_my_biased_coin((n, p) : (usize, f64)) {
            let mut flips = vec![Value::Integer(0); n];
            for (i, slot) in flips.iter_mut().enumerate() {
                *slot = sample!(format!("flips_{}", i), Distribution::Bernoulli(p));
            }
        }
        let (array_trace, _): (Trace, _) = simulate(&mut flip_my_biased_coin, (5usize, 0.7));
        println!("my flip coin trace: {:?}", array_trace);

        #[r_gen]
        fn flip_my_biased_coin2((n, p) : (usize, f64)) {
            let mut flips = vec![Value::Integer(0); n];
            for (i, slot) in flips.iter_mut().enumerate() {
                *slot = sample!(format!("flips_{}", i), Distribution::Bernoulli(p));
            }
            println!("flips: {:?}", flips);
        }
        let (printed_trace, _): (Trace, _) = simulate(&mut flip_my_biased_coin2, (5usize, 0.7));
        println!("my flip coin 2 trace: {:?}", printed_trace);
    }

    /// Prints the empirical success rate of 100 Bernoulli(p) draws next to p.
    #[test]
    fn test_bernoulli() {
        #[r_gen]
        fn my_bernoulli(p : f64) {
            let mut tests = vec![Value::Real(0.0); 100];
            for (i, slot) in tests.iter_mut().enumerate() {
                *slot = sample!(format!("tests_{}", i), Distribution::Bernoulli(p));
            }
            let successes = tests
                .iter()
                .filter(|outcome| matches!(outcome, Value::Boolean(true)))
                .count();
            let tot = successes as f64;
            println!("P: {}\nResult of tests:{:?}", p, tot / 100.0);
        }
        let _ = simulate(&mut my_bernoulli, 0.5);
    }

    /// Prints the empirical mean of 100 Binomial(n, p) draws next to n * p.
    #[test]
    fn test_binom() {
        #[r_gen]
        fn my_binomial((n, p): (i64, f64)) {
            let mut tests = vec![Value::Real(0.0); 100];
            for (i, slot) in tests.iter_mut().enumerate() {
                *slot = sample!(format!("tests_{}", i), Distribution::Binomial(n, p));
            }
            // Same left-to-right summation order as an explicit accumulator loop.
            let tot = tests
                .iter()
                .filter_map(|outcome| match outcome {
                    Value::Integer(k) => Some(*k as f64),
                    _ => None,
                })
                .fold(0.0, |acc: f64, x| acc + x);
            println!("N*P: {}\nResult of tests:{:?}", ((n as f64)*p), tot / 100.0);
        }
        let _ = simulate(&mut my_binomial, (100, 0.5));
    }

    /// Prints the empirical mean of 100 Normal(m, s) draws next to m.
    #[test]
    fn test_normal() {
        #[r_gen]
        fn my_normal((m, s): (f64, f64)) {
            let mut tests = vec![Value::Real(0.0); 100];
            for (i, slot) in tests.iter_mut().enumerate() {
                *slot = sample!(format!("tests_{}", i), Distribution::Normal(m, s));
            }
            // Same left-to-right summation order as an explicit accumulator loop.
            let tot = tests
                .iter()
                .filter_map(|outcome| match outcome {
                    Value::Real(r) => Some(*r),
                    _ => None,
                })
                .fold(0.0, |acc: f64, x| acc + x);
            println!("Mean: {}\nResult of tests:{:?}", m, tot / 100.0);
        }
        let _ = simulate(&mut my_normal, (60.0, 10.0));
    }

    /// Importance resampling: recover p from an observed number of heads.
    #[test]
    fn test_importance_resampling() {
        // Model: p ~ Beta(1, 1); num_heads ~ Binomial(100, p).
        #[r_gen]
        fn my_model(():()){
            let p = sample!(format!("p"), Distribution::Beta(1.0, 1.0));
            sample!(format!("num_heads"), Distribution::Binomial(100, p.into()));
        }

        // One forward run supplies the ground-truth p and the observation to condition on.
        let (truth, _): (Trace, _) = simulate(&mut my_model, ());
        let observations =
            Choicemap::from(vec![("num_heads", truth.choices["num_heads"].clone())]);

        // Draw 1000 traces constrained on the observation, then resample one by weight.
        let candidates: Vec<Trace> = (0..1000)
            .map(|_| generate(&mut my_model, (), &observations).0)
            .collect();

        println!("Actual value for p:\t {}", truth.choices["p"]);
        println!(
            "Generated value for p:\t {}",
            Trace::sample_weighted_traces(&candidates).unwrap().choices["p"]
        );
    }

    /// Renders a simulated trace with `get_trace_string`.
    #[test]
    fn test_trace_string() {
        #[r_gen]
        fn my_biased_coin_model(():()){
            let p = sample!(format!("p"), Distribution::Beta(1.0, 1.0)); // p ~ Uniform(0, 1)
            sample!(format!("num_heads"), Distribution::Binomial(100, p.into())); // 100 flips, P(heads) = p
        }
        println!("GO");
        let (coin_trace, _result) = simulate(&mut my_biased_coin_model, ());
        println!("Trace String: \n{}", coin_trace.get_trace_string());
    }
}
272
273
274use std::rc::Rc;
275
276use self::{distributions::{Value, Sampleable}, trace::{Choicemap, Trace}};
277
278//Re-export the other sub modules. 
279pub mod distributions; 
280pub mod trace; 
281
282/**
283Run the given generative model in the forward direction.
284As input, it takes a generative model (function with the #[r_gen] tag) and the arguments to that function. 
285Returns a tuple of the trace generated by running the function and the return value of the function itself.
286# Example
287```
288use r_gen::{sample, r_gen}; 
289use r_gen::{simulate, distributions::{Value, Distribution}, trace::{Choicemap, Trace}}; 
290use std::rc::Rc;
291#[r_gen]
292fn my_biased_coin_model(():()){
293    let p = sample!(format!("p"), Distribution::Beta(1.0, 1.0));            //Sample p from a uniform. 
294    sample!(format!("num_heads"), Distribution::Binomial(100, p.into()));   //Flip 100 coins where P(Heads)=p
295}
296println!("GO"); 
297let (trace, result) = simulate(&mut my_biased_coin_model, ()); 
298println!("Trace String: \n{}", trace.get_trace_string());
299```
300Outputs: 
301```shell
302Trace String: 
303num_heads => 37
304p => 0.38724904991570935
305```
306*/
307pub fn simulate<F, A, R, S : Sampleable>(generative_function : &mut F, arguments : A) -> (Trace, R) 
308where 
309F : FnMut(Rc<dyn FnMut(&String, S, &mut Trace) -> Value>, &mut Trace, A) -> R, 
310{
311    let sample = |name : &String, dist : S, trace : &mut Trace| {
312        let value = dist.sample();                              //Sample a value. 
313        let prob = dist.liklihood(&value).unwrap();               //Compute the probability of this value. 
314        trace.update_logscore(prob);                         //Update the log score with the pdf. 
315        trace.choices.add_choice(&name, value.clone());     //Add the choice to the hashmap.
316        value
317    }; 
318    let mut trace = Trace::new(); 
319    let return_value = generative_function(Rc::new(sample), &mut trace, arguments); 
320    (trace, return_value)
321}
322
323/**
324Run a generative model in the forward direction, fixing certian decisions or observations.
325As input, it takes a generative model (function with the #[r_gen] tag), the arguments to that function, and a choicemap of the observed variables. 
326Returns a tuple of the trace generated by running the function and the return value of the function itself.
327# Example
328```
329use r_gen::{sample, r_gen}; 
330use r_gen::{generate, distributions::{Value, Distribution}, trace::{Choicemap, Trace}}; 
331use std::rc::Rc;
332#[r_gen]
333fn my_biased_coin_model(():()){
334    let p = sample!(format!("p"), Distribution::Beta(1.0, 1.0));            //Sample p from a uniform. 
335    sample!(format!("num_heads"), Distribution::Binomial(100, p.into()));   //Flip 100 coins where P(Heads)=p
336}
337let choices = Choicemap::from(vec![("p", Value::Real(0.1))]);     //Fix the value p=0.1
338let (trace, result) = generate(&mut my_biased_coin_model, (), &choices); 
339```
340*/
341pub fn generate<F, A, R, S: Sampleable>(generative_function : &mut F, arguments : A, conditions : &Choicemap) -> (Trace, R) 
342where 
343F : FnMut(Rc<dyn FnMut(&String, S, &mut Trace) -> Value>, &mut Trace, A) -> R, 
344{   
345    let sample = |name : &String, dist : S, trace : &mut Trace| {
346        let mut _value = Value::Real(0.0); 
347        _value = if trace.choices.contains_key(name) {
348            trace.choices[name.as_str()].clone()
349        } else {
350            dist.sample()
351        };
352        let prob = dist.liklihood(&_value).unwrap();                    //Compute the probability of this value. 
353        trace.update_logscore(prob);                               //Update the log score with the pdf. 
354        trace.choices.add_choice(name.as_str(), _value.clone());  //Add the choice to the hashmap.
355        _value 
356    };
357    let mut trace = Trace::new(); 
358    for (k, v) in conditions.get_choices() {
359        trace.choices.add_choice(k, v); 
360    }
361    let return_value = generative_function(Rc::new(sample), &mut trace, arguments);
362    (trace, return_value)
363}