r_gen/
distributions.rs

1//! Distributions a generative model can sample from. 
2
3use std::{fmt, ops::{Add, Div, Mul, Sub}};
4
5use probability::prelude::*;
6use rand::{self, FromEntropy, prelude::ThreadRng, rngs::StdRng}; 
7use rand::distributions::Distribution as Distr;
8use statrs::function::gamma::gamma;
9
/**
The dynamically-typed result of sampling a distribution, and the input when scoring a
likelihood.

This is an enum: each variant wraps one of the primitive kinds that a distribution in this
module can produce.
*/
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// A `true`/`false` outcome, e.g. a Bernoulli draw.
    Boolean(bool),
    /// A signed 64-bit integer, e.g. a Binomial count or a Categorical index.
    Integer(i64),
    /// A 64-bit floating-point number.
    Real(f64),
    /// An ordered collection of values; the type does not force elements to share a variant.
    Vector(Vec<Value>),
}
24
25impl fmt::Display for Value {
26    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
27        match self {
28            Value::Boolean(b) => {
29                write!(formatter, "{}", b)
30            }, 
31            Value::Integer(i) => {
32                write!(formatter, "{}", i)
33            }, 
34            Value::Real(r) => {
35                write!(formatter, "{}", r)
36            },
37            Value::Vector(v) => {
38                write!(formatter, "{:?}", v)
39            }, 
40        }
41    }
42}
43
44//Implement type conversions
45impl From<bool> for Value {
46    fn from(b: bool) -> Self {
47        Value::Boolean(b)
48    }
49}
50impl From<f64> for Value {
51    fn from(f: f64) -> Self {
52        Value::Real(f)
53    }
54}
55impl From<f32> for Value {
56    fn from(f: f32) -> Self {
57        Value::Real(f as f64)
58    }
59}
60impl From<i8> for Value {
61    fn from(i: i8) -> Self {
62        Value::Integer(i as i64)
63    }
64}
65impl From<i16> for Value {
66    fn from(i: i16) -> Self {
67        Value::Integer(i as i64)
68    }
69}
70impl From<i32> for Value {
71    fn from(i: i32) -> Self {
72        Value::Integer(i as i64)
73    }
74}
75impl From<i64> for Value {
76    fn from(i: i64) -> Self {
77        Value::Integer(i)
78    }
79}
80impl From<u8> for Value {
81    fn from(i: u8) -> Self {
82        Value::Integer(i as i64)
83    }
84}
85impl From<u16> for Value {
86    fn from(i: u16) -> Self {
87        Value::Integer(i as i64)
88    }
89}
90impl From<u32> for Value {
91    fn from(i: u32) -> Self {
92        Value::Integer(i as i64)
93    }
94}
95impl From<u64> for Value {
96    fn from(i: u64) -> Self {
97        Value::Integer(i as i64)
98    }
99}
100impl From<usize> for Value {
101    fn from(i: usize) -> Self {
102        Value::Integer(i as i64)
103    }
104}
105
106impl Into<f64> for Value {
107    fn into(self) -> f64 {
108        match self {
109            Value::Real(r) => r, 
110            _ => panic!("Cannot convert non-Real to f64.")
111        }
112    }
113}
114impl Into<i64> for Value {
115    fn into(self) -> i64 {
116        match self {
117            Value::Integer(i) => i, 
118            _ => panic!("Cannot convert non-Integer to i64.")
119        }
120    }
121}
122impl Into<bool> for Value {
123    fn into(self) -> bool {
124        match self {
125            Value::Boolean(b) => b, 
126            _ => panic!("Cannot convert non-Boolean to bool.")
127        }
128    }
129}
130
131
132//Implement mathematical operations for the functions.
133impl Add<Value> for Value {
134    type Output = Value;
135
136    fn add(self, rhs: Value) -> Self::Output {
137        match self {
138            Self::Integer(i1) => {
139                match rhs {
140                    Self::Integer(i2) => Self::Integer(i1+i2), 
141                    Self::Real(r2) => Self::Real(i1 as f64 + r2), 
142                    Self::Boolean(_) => panic!("Unable to add boolean values."), 
143                    Self::Vector(_) => panic!("Ubalbe to add integer to vector.")
144                }
145            }, 
146            Self::Real(r1) => {
147                match rhs {
148                    Self::Integer(i2) => Self::Real(r1 + i2 as f64), 
149                    Self::Real(r2) => Self::Real(r1 + r2), 
150                    Self::Boolean(_) => panic!("Unable to add boolean values."), 
151                    Self::Vector(_) => panic!("Unable to add Real value to Vector")
152                }
153            }, 
154            Self::Boolean(_) => panic!("Unable to add boolean values."),
155            Self::Vector(vl) => {
156                match rhs {
157                    Self::Vector(vr) => Self::Vector(vl.iter().zip(vr.iter()).map(|(l, r)| l.clone() + r.clone()).collect()), 
158                    _ => panic!("Unable to add Vector to non-Vector.")
159                }
160            }
161        }
162    }
163}
164
165impl Sub<Value> for Value {
166    type Output = Value;
167
168    fn sub(self, rhs: Value) -> Self::Output {
169        match self {
170            Self::Integer(i1) => {
171                match rhs {
172                    Self::Integer(i2) => Self::Integer(i1 - i2), 
173                    Self::Real(r2) => Self::Real(i1 as f64 - r2), 
174                    Self::Boolean(_) => panic!("Unable to subtract boolean values."), 
175                    Self::Vector(_) => panic!("Ubalbe to subtract vector form integer.")
176                }
177            }, 
178            Self::Real(r1) => {
179                match rhs {
180                    Self::Integer(i2) => Self::Real(r1 - i2 as f64), 
181                    Self::Real(r2) => Self::Real(r1 - r2), 
182                    Self::Boolean(_) => panic!("Unable to subtract boolean values."),
183                    Self::Vector(_) => panic!("Unalbe to subtract Real from Vector.")
184                }
185            }, 
186            Self::Boolean(_) => panic!("Unable to subtract boolean values."),
187            Self::Vector(vl) => {
188                match rhs {
189                    Self::Vector(vr) => Self::Vector(vl.iter().zip(vr.iter()).map(|(l, r)| l.clone() - r.clone()).collect()), 
190                    _ => panic!("Unable to subtract non-Vector from Vector.")
191                }
192            }
193        }
194    }
195}
196
197impl Mul<Value> for Value {
198    type Output = Value;
199
200    fn mul(self, rhs: Value) -> Self::Output {
201        match self {
202            Self::Integer(i1) => {
203                match rhs {
204                    Self::Integer(i2) => Self::Integer(i1 * i2), 
205                    Self::Real(r2) => Self::Real(i1 as f64 * r2), 
206                    Self::Boolean(_) => panic!("Unable to multiply boolean values."), 
207                    Self::Vector(_) => panic!("Unable to multiply vectors.")
208                }
209            }, 
210            Self::Real(r1) => {
211                match rhs {
212                    Self::Integer(i2) => Self::Real(r1 * i2 as f64), 
213                    Self::Real(r2) => Self::Real(r1 * r2), 
214                    Self::Boolean(_) => panic!("Unable to multiply boolean values."),
215                    Self::Vector(_) => panic!("Unable to multiply vectors.")
216                }
217            }, 
218            Self::Boolean(_) => panic!("Unable to multiply boolean values."),
219            Self::Vector(_) => panic!("Unable to multiply vectors.")
220        }
221    }
222}
223
224impl Div<Value> for Value {
225    type Output = Value;
226
227    fn div(self, rhs: Value) -> Self::Output {
228        match self {
229            Self::Integer(i1) => {
230                match rhs {
231                    Self::Integer(i2) => Self::Integer(i1 / i2), 
232                    Self::Real(r2) => Self::Real(i1 as f64 / r2), 
233                    Self::Boolean(_) => panic!("Unable to divide boolean values."),
234                    Self::Vector(_) => panic!("Unable to divide vectors.")
235                }
236            }, 
237            Self::Real(r1) => {
238                match rhs {
239                    Self::Integer(i2) => Self::Real(r1 / i2 as f64), 
240                    Self::Real(r2) => Self::Real(r1 / r2), 
241                    Self::Boolean(_) => panic!("Unable to divide boolean values."),
242                    Self::Vector(_) => panic!("Unable to divide vectors.")
243                }
244            }, 
245            Self::Boolean(_) => panic!("Unable to divide boolean values."),
246            Self::Vector(_) => panic!("Unable to divide vectors.")
247        }
248    }
249}
250
251
252
253
254
/// Enum that uniquely describes a distribution together with its parameters.
///
/// Sampling and scoring delegate to the `rand` and `probability` crates; where those crates
/// fix a particular parameterisation it is noted on the variant.
#[derive(Debug, Clone, PartialEq)]
pub enum Distribution {
    /// A Bernoulli distribution with success probability `p`; yields `Value::Boolean`.
    Bernoulli(f64),
    /// A Binomial distribution with `n` trials and success probability `p`; yields
    /// `Value::Integer`.
    Binomial(i64, f64),
    /// A Normal distribution with mean `mu` and standard deviation `sigma` (not variance).
    Normal(f64, f64),
    /// A Gamma distribution with shape `alpha` and scale `beta`, i.e. the `(k, theta)`
    /// parameterisation of `probability::distribution::Gamma`.
    Gamma(f64, f64),
    /// A Beta distribution with shape parameters `alpha` and `beta`, supported on `[0, 1]`.
    Beta(f64, f64),
    /// A log-normal distribution whose logarithm is Normal with mean `mu` and standard
    /// deviation `sigma`.
    LogNormal(f64, f64),
    /// A Categorical distribution over indices `0..p.len()` with probabilities `p`
    /// (which should sum to 1); yields `Value::Integer`.
    Categorical(Vec<f64>),
    /// A Dirichlet distribution with concentration parameters `alpha`; yields a
    /// `Value::Vector` of `Value::Real`s with one entry per parameter.
    Dirichlet(Vec<f64>),
}
275
276/**
277A trait that you can implement to create your own distributions to sample from in a 
278generative model. 
279*/
280pub trait Sampleable {
281    /// Sample a value from the distribution. 
282    fn sample(&self) -> Value; 
283    /// Compute the liklihood of a given value being sampled from the distribution.
284    fn liklihood(&self, value : &Value) -> Result<f64, &str>; 
285}
286
287/// A struct that holds a source of randomness for the various distributions.
288pub(crate) struct Source<T>(pub T);
289impl<T: rand::RngCore> source::Source for Source<T> {
290    fn read_u64(&mut self) -> u64 {
291        self.0.next_u64()
292    }
293}
294
295impl Sampleable for Distribution {
296    /// Sample from the distribution and return the value sampled. 
297    fn sample(&self) -> Value {
298        match self {
299            Distribution::Bernoulli(p) => {
300                let d = rand::distributions::Bernoulli::new(*p);
301                let v = d.sample(&mut ThreadRng::default());
302                Value::Boolean(v)
303            }, 
304            Distribution::Binomial(n, p) => {
305                let b = probability::distribution::Binomial::new(*n as usize, *p); 
306                Value::Integer(b.sample(&mut Source(StdRng::from_entropy())) as i64)
307            }, 
308            Distribution::Normal(mu, sigma_squared) => {
309                let n = probability::distribution::Gaussian::new(*mu, *sigma_squared); 
310                Value::Real(n.sample(&mut Source(StdRng::from_entropy())))
311            }, 
312            Distribution::Gamma(alpha, beta) => {
313                let g = probability::distribution::Gamma::new(*alpha, *beta); 
314                Value::Real(g.sample(&mut Source(StdRng::from_entropy())))
315            }, 
316            Distribution::Beta(alpha, beta) => {
317                let b = probability::distribution::Beta::new(*alpha, *beta, 0.0, 1.0);
318                Value::Real(b.sample(&mut Source(StdRng::from_entropy()))) 
319            }, 
320            Distribution::LogNormal(mu, sigma) => {
321                let n = probability::distribution::Lognormal::new(*mu, *sigma); 
322                Value::Real(n.sample(&mut Source(StdRng::from_entropy())))
323            }, 
324            Distribution::Categorical(v) => {
325                let c = probability::distribution::Categorical::new(&v[..]); 
326                Value::Integer(c.sample(&mut Source(StdRng::from_entropy())) as i64)
327            }, 
328            Distribution::Dirichlet(xs) => {
329                let ys : Vec<f64> = xs.iter().map(|x|{ 
330                    let beta = probability::distribution::Beta::new(*x, 1.0, 0.0, 1.0); 
331                    beta.sample(&mut Source(StdRng::from_entropy()))
332                }).collect(); 
333                let sum : f64 = ys.iter().sum(); 
334                let ys = ys.iter().map(|y| Value::Real(y / sum)).collect(); 
335                Value::Vector(ys)
336            }
337        }
338    }
339
340    
341    /**
342    Compute the liklihood of a value given a distribution (returns the log liklihood.) 
343    # Errors 
344    This function will return an Err if you try to determine the liklihood of a variant of the ```Value``` enum that 
345    the distribution does not produce. For example, trying to get the liklihood of a real number from a bernoulli 
346    distribution will return an Err.
347    */
348    fn liklihood(&self, value : &Value) -> Result<f64, &str> {
349        match self {
350            Distribution::Bernoulli(p) => {
351                match value {
352                    Value::Boolean(b) => {
353                        match b {
354                            true  => Ok(p.ln()), 
355                            false => Ok((1.0 - p).ln()), 
356                        }
357                    }, 
358                    _ => Err("Value of wrong type, expected Boolean.")
359                }
360            }, 
361            Distribution::Binomial(n, p) => {
362                match value {
363                    Value::Integer(k) => {
364                        let norm = probability::distribution::Binomial::new(*n as usize, *p);
365                        Ok(norm.mass(*k as usize).ln())
366                    }, 
367                    _ => Err("Value of wrong type, expected Integer.")
368                }
369            }, 
370            Distribution::Normal(mu, sigma_squared) => {
371                match value {
372                    Value::Real(n) => {
373                        let norm = probability::distribution::Gaussian::new(*mu, *sigma_squared); 
374                        Ok(norm.density(*n).ln())
375                    }, 
376                    _ => Err("Value of wrong type, expected Real.")
377                }
378            },
379            Distribution::Gamma(alpha, beta) => {
380                match value {
381                    Value::Real(n) => {
382                        let g = probability::distribution::Gaussian::new(*alpha, *beta); 
383                        Ok(g.density(*n).ln())
384                    }, 
385                    _ => Err("Value of wrong type, expected Real.")
386                }
387            }, 
388            Distribution::Beta(alpha, beta) => {
389                match value {
390                    Value::Real(n) => {
391                        let b = probability::distribution::Beta::new(*alpha, *beta, 0.0, 1.0);
392                        Ok(b.density(*n).ln())
393                    }, 
394                    _ => Err("Value of wrong type, expected Real.")
395                }
396            }, 
397            Distribution::LogNormal(mu, sigma) => {
398                match value {
399                    Value::Real(n) => {
400                        let l = probability::distribution::Lognormal::new(*mu, *sigma); 
401                        Ok(l.density(*n).ln())
402                    }, 
403                    _ => Err("Value of wrong type, expected Real.")
404                }
405            }, 
406            Distribution::Categorical(p) => {
407                match value {
408                    Value::Integer(i) => {
409                        let c = probability::distribution::Categorical::new(&p[..]); 
410                        Ok(c.mass(*i as usize).ln())
411                    }, 
412                    _ => Err("Value of wrong type, expected Integer.")
413                }
414            }, 
415            Distribution::Dirichlet(a) => {
416                match value {
417                    Value::Vector(x) => {
418                        let ba_numerator : f64 = a.iter().fold(1.0, |acc, x| acc * gamma(*x)); 
419                        let ba_denominator : f64 = gamma(a.iter().sum());  
420                        let ba = ba_numerator / ba_denominator; 
421
422                        Ok(((1.0/ba) * x.iter().map(|x| {
423                            match *x {
424                                Value::Real(x) => x, 
425                                _ => 1.0, 
426                            }
427                        }).zip(a.iter()).fold(1.0, |acc, (x, a)| {
428                            acc * x.powf(a-1.0)
429                        })).ln())
430                    }, 
431                    _ => Err("Value of wrong type, expected Vector.")
432                }
433            }
434        }
435    }
436}