use std::{fmt, ops::{Add, Div, Mul, Sub}};

use probability::prelude::*;
use rand::{self, FromEntropy, prelude::ThreadRng, rngs::StdRng};
use rand::distributions::Distribution as Distr;
use statrs::function::gamma::{gamma, ln_gamma};
9
/// A dynamically-typed value that distributions in this module produce when
/// sampled and accept when scored.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// A truth value (the outcome type of a Bernoulli draw).
    Boolean(bool),
    /// A signed 64-bit integer (Binomial counts, Categorical indices).
    Integer(i64),
    /// A 64-bit floating-point number.
    Real(f64),
    /// An ordered sequence of nested values (e.g. a Dirichlet draw).
    Vector(Vec<Value>),
}
24
25impl fmt::Display for Value {
26 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
27 match self {
28 Value::Boolean(b) => {
29 write!(formatter, "{}", b)
30 },
31 Value::Integer(i) => {
32 write!(formatter, "{}", i)
33 },
34 Value::Real(r) => {
35 write!(formatter, "{}", r)
36 },
37 Value::Vector(v) => {
38 write!(formatter, "{:?}", v)
39 },
40 }
41 }
42}
43
44impl From<bool> for Value {
46 fn from(b: bool) -> Self {
47 Value::Boolean(b)
48 }
49}
50impl From<f64> for Value {
51 fn from(f: f64) -> Self {
52 Value::Real(f)
53 }
54}
55impl From<f32> for Value {
56 fn from(f: f32) -> Self {
57 Value::Real(f as f64)
58 }
59}
60impl From<i8> for Value {
61 fn from(i: i8) -> Self {
62 Value::Integer(i as i64)
63 }
64}
65impl From<i16> for Value {
66 fn from(i: i16) -> Self {
67 Value::Integer(i as i64)
68 }
69}
70impl From<i32> for Value {
71 fn from(i: i32) -> Self {
72 Value::Integer(i as i64)
73 }
74}
75impl From<i64> for Value {
76 fn from(i: i64) -> Self {
77 Value::Integer(i)
78 }
79}
80impl From<u8> for Value {
81 fn from(i: u8) -> Self {
82 Value::Integer(i as i64)
83 }
84}
85impl From<u16> for Value {
86 fn from(i: u16) -> Self {
87 Value::Integer(i as i64)
88 }
89}
90impl From<u32> for Value {
91 fn from(i: u32) -> Self {
92 Value::Integer(i as i64)
93 }
94}
95impl From<u64> for Value {
96 fn from(i: u64) -> Self {
97 Value::Integer(i as i64)
98 }
99}
100impl From<usize> for Value {
101 fn from(i: usize) -> Self {
102 Value::Integer(i as i64)
103 }
104}
105
106impl Into<f64> for Value {
107 fn into(self) -> f64 {
108 match self {
109 Value::Real(r) => r,
110 _ => panic!("Cannot convert non-Real to f64.")
111 }
112 }
113}
114impl Into<i64> for Value {
115 fn into(self) -> i64 {
116 match self {
117 Value::Integer(i) => i,
118 _ => panic!("Cannot convert non-Integer to i64.")
119 }
120 }
121}
122impl Into<bool> for Value {
123 fn into(self) -> bool {
124 match self {
125 Value::Boolean(b) => b,
126 _ => panic!("Cannot convert non-Boolean to bool.")
127 }
128 }
129}
130
131
132impl Add<Value> for Value {
134 type Output = Value;
135
136 fn add(self, rhs: Value) -> Self::Output {
137 match self {
138 Self::Integer(i1) => {
139 match rhs {
140 Self::Integer(i2) => Self::Integer(i1+i2),
141 Self::Real(r2) => Self::Real(i1 as f64 + r2),
142 Self::Boolean(_) => panic!("Unable to add boolean values."),
143 Self::Vector(_) => panic!("Ubalbe to add integer to vector.")
144 }
145 },
146 Self::Real(r1) => {
147 match rhs {
148 Self::Integer(i2) => Self::Real(r1 + i2 as f64),
149 Self::Real(r2) => Self::Real(r1 + r2),
150 Self::Boolean(_) => panic!("Unable to add boolean values."),
151 Self::Vector(_) => panic!("Unable to add Real value to Vector")
152 }
153 },
154 Self::Boolean(_) => panic!("Unable to add boolean values."),
155 Self::Vector(vl) => {
156 match rhs {
157 Self::Vector(vr) => Self::Vector(vl.iter().zip(vr.iter()).map(|(l, r)| l.clone() + r.clone()).collect()),
158 _ => panic!("Unable to add Vector to non-Vector.")
159 }
160 }
161 }
162 }
163}
164
165impl Sub<Value> for Value {
166 type Output = Value;
167
168 fn sub(self, rhs: Value) -> Self::Output {
169 match self {
170 Self::Integer(i1) => {
171 match rhs {
172 Self::Integer(i2) => Self::Integer(i1 - i2),
173 Self::Real(r2) => Self::Real(i1 as f64 - r2),
174 Self::Boolean(_) => panic!("Unable to subtract boolean values."),
175 Self::Vector(_) => panic!("Ubalbe to subtract vector form integer.")
176 }
177 },
178 Self::Real(r1) => {
179 match rhs {
180 Self::Integer(i2) => Self::Real(r1 - i2 as f64),
181 Self::Real(r2) => Self::Real(r1 - r2),
182 Self::Boolean(_) => panic!("Unable to subtract boolean values."),
183 Self::Vector(_) => panic!("Unalbe to subtract Real from Vector.")
184 }
185 },
186 Self::Boolean(_) => panic!("Unable to subtract boolean values."),
187 Self::Vector(vl) => {
188 match rhs {
189 Self::Vector(vr) => Self::Vector(vl.iter().zip(vr.iter()).map(|(l, r)| l.clone() - r.clone()).collect()),
190 _ => panic!("Unable to subtract non-Vector from Vector.")
191 }
192 }
193 }
194 }
195}
196
197impl Mul<Value> for Value {
198 type Output = Value;
199
200 fn mul(self, rhs: Value) -> Self::Output {
201 match self {
202 Self::Integer(i1) => {
203 match rhs {
204 Self::Integer(i2) => Self::Integer(i1 * i2),
205 Self::Real(r2) => Self::Real(i1 as f64 * r2),
206 Self::Boolean(_) => panic!("Unable to multiply boolean values."),
207 Self::Vector(_) => panic!("Unable to multiply vectors.")
208 }
209 },
210 Self::Real(r1) => {
211 match rhs {
212 Self::Integer(i2) => Self::Real(r1 * i2 as f64),
213 Self::Real(r2) => Self::Real(r1 * r2),
214 Self::Boolean(_) => panic!("Unable to multiply boolean values."),
215 Self::Vector(_) => panic!("Unable to multiply vectors.")
216 }
217 },
218 Self::Boolean(_) => panic!("Unable to multiply boolean values."),
219 Self::Vector(_) => panic!("Unable to multiply vectors.")
220 }
221 }
222}
223
224impl Div<Value> for Value {
225 type Output = Value;
226
227 fn div(self, rhs: Value) -> Self::Output {
228 match self {
229 Self::Integer(i1) => {
230 match rhs {
231 Self::Integer(i2) => Self::Integer(i1 / i2),
232 Self::Real(r2) => Self::Real(i1 as f64 / r2),
233 Self::Boolean(_) => panic!("Unable to divide boolean values."),
234 Self::Vector(_) => panic!("Unable to divide vectors.")
235 }
236 },
237 Self::Real(r1) => {
238 match rhs {
239 Self::Integer(i2) => Self::Real(r1 / i2 as f64),
240 Self::Real(r2) => Self::Real(r1 / r2),
241 Self::Boolean(_) => panic!("Unable to divide boolean values."),
242 Self::Vector(_) => panic!("Unable to divide vectors.")
243 }
244 },
245 Self::Boolean(_) => panic!("Unable to divide boolean values."),
246 Self::Vector(_) => panic!("Unable to divide vectors.")
247 }
248 }
249}
250
251
252
253
254
/// A parametric probability distribution that can be sampled and scored via
/// the [`Sampleable`] trait.
///
/// All fields are plain numbers or vectors, so the type is cheap to clone and
/// compare.
#[derive(Debug, Clone, PartialEq)]
pub enum Distribution {
    /// Bernoulli(`p`): `true` with probability `p`, otherwise `false`.
    Bernoulli(f64),
    /// Binomial(`n`, `p`): number of successes in `n` trials with success
    /// probability `p`.
    Binomial(i64, f64),
    /// Normal(`mu`, `s`): the second field is handed to the `probability`
    /// crate's `Gaussian::new` as its standard deviation.
    /// NOTE(review): confirm callers do not pass a variance here.
    Normal(f64, f64),
    /// Gamma(`alpha`, `beta`): the fields are handed to the `probability`
    /// crate's `Gamma::new` as shape and *scale*.
    /// NOTE(review): `beta` is therefore a scale, not a rate.
    Gamma(f64, f64),
    /// Beta(`alpha`, `beta`) supported on `[0, 1]`.
    Beta(f64, f64),
    /// Log-normal(`mu`, `sigma`): parameters of the underlying normal, as
    /// accepted by the `probability` crate's `Lognormal::new`.
    LogNormal(f64, f64),
    /// Categorical over the indices `0..len` with the given probabilities.
    Categorical(Vec<f64>),
    /// Dirichlet with the given concentration parameters; draws are vectors
    /// with one component per concentration.
    Dirichlet(Vec<f64>),
}
275
276pub trait Sampleable {
281 fn sample(&self) -> Value;
283 fn liklihood(&self, value : &Value) -> Result<f64, &str>;
285}
286
287pub(crate) struct Source<T>(pub T);
289impl<T: rand::RngCore> source::Source for Source<T> {
290 fn read_u64(&mut self) -> u64 {
291 self.0.next_u64()
292 }
293}
294
295impl Sampleable for Distribution {
296 fn sample(&self) -> Value {
298 match self {
299 Distribution::Bernoulli(p) => {
300 let d = rand::distributions::Bernoulli::new(*p);
301 let v = d.sample(&mut ThreadRng::default());
302 Value::Boolean(v)
303 },
304 Distribution::Binomial(n, p) => {
305 let b = probability::distribution::Binomial::new(*n as usize, *p);
306 Value::Integer(b.sample(&mut Source(StdRng::from_entropy())) as i64)
307 },
308 Distribution::Normal(mu, sigma_squared) => {
309 let n = probability::distribution::Gaussian::new(*mu, *sigma_squared);
310 Value::Real(n.sample(&mut Source(StdRng::from_entropy())))
311 },
312 Distribution::Gamma(alpha, beta) => {
313 let g = probability::distribution::Gamma::new(*alpha, *beta);
314 Value::Real(g.sample(&mut Source(StdRng::from_entropy())))
315 },
316 Distribution::Beta(alpha, beta) => {
317 let b = probability::distribution::Beta::new(*alpha, *beta, 0.0, 1.0);
318 Value::Real(b.sample(&mut Source(StdRng::from_entropy())))
319 },
320 Distribution::LogNormal(mu, sigma) => {
321 let n = probability::distribution::Lognormal::new(*mu, *sigma);
322 Value::Real(n.sample(&mut Source(StdRng::from_entropy())))
323 },
324 Distribution::Categorical(v) => {
325 let c = probability::distribution::Categorical::new(&v[..]);
326 Value::Integer(c.sample(&mut Source(StdRng::from_entropy())) as i64)
327 },
328 Distribution::Dirichlet(xs) => {
329 let ys : Vec<f64> = xs.iter().map(|x|{
330 let beta = probability::distribution::Beta::new(*x, 1.0, 0.0, 1.0);
331 beta.sample(&mut Source(StdRng::from_entropy()))
332 }).collect();
333 let sum : f64 = ys.iter().sum();
334 let ys = ys.iter().map(|y| Value::Real(y / sum)).collect();
335 Value::Vector(ys)
336 }
337 }
338 }
339
340
341 fn liklihood(&self, value : &Value) -> Result<f64, &str> {
349 match self {
350 Distribution::Bernoulli(p) => {
351 match value {
352 Value::Boolean(b) => {
353 match b {
354 true => Ok(p.ln()),
355 false => Ok((1.0 - p).ln()),
356 }
357 },
358 _ => Err("Value of wrong type, expected Boolean.")
359 }
360 },
361 Distribution::Binomial(n, p) => {
362 match value {
363 Value::Integer(k) => {
364 let norm = probability::distribution::Binomial::new(*n as usize, *p);
365 Ok(norm.mass(*k as usize).ln())
366 },
367 _ => Err("Value of wrong type, expected Integer.")
368 }
369 },
370 Distribution::Normal(mu, sigma_squared) => {
371 match value {
372 Value::Real(n) => {
373 let norm = probability::distribution::Gaussian::new(*mu, *sigma_squared);
374 Ok(norm.density(*n).ln())
375 },
376 _ => Err("Value of wrong type, expected Real.")
377 }
378 },
379 Distribution::Gamma(alpha, beta) => {
380 match value {
381 Value::Real(n) => {
382 let g = probability::distribution::Gaussian::new(*alpha, *beta);
383 Ok(g.density(*n).ln())
384 },
385 _ => Err("Value of wrong type, expected Real.")
386 }
387 },
388 Distribution::Beta(alpha, beta) => {
389 match value {
390 Value::Real(n) => {
391 let b = probability::distribution::Beta::new(*alpha, *beta, 0.0, 1.0);
392 Ok(b.density(*n).ln())
393 },
394 _ => Err("Value of wrong type, expected Real.")
395 }
396 },
397 Distribution::LogNormal(mu, sigma) => {
398 match value {
399 Value::Real(n) => {
400 let l = probability::distribution::Lognormal::new(*mu, *sigma);
401 Ok(l.density(*n).ln())
402 },
403 _ => Err("Value of wrong type, expected Real.")
404 }
405 },
406 Distribution::Categorical(p) => {
407 match value {
408 Value::Integer(i) => {
409 let c = probability::distribution::Categorical::new(&p[..]);
410 Ok(c.mass(*i as usize).ln())
411 },
412 _ => Err("Value of wrong type, expected Integer.")
413 }
414 },
415 Distribution::Dirichlet(a) => {
416 match value {
417 Value::Vector(x) => {
418 let ba_numerator : f64 = a.iter().fold(1.0, |acc, x| acc * gamma(*x));
419 let ba_denominator : f64 = gamma(a.iter().sum());
420 let ba = ba_numerator / ba_denominator;
421
422 Ok(((1.0/ba) * x.iter().map(|x| {
423 match *x {
424 Value::Real(x) => x,
425 _ => 1.0,
426 }
427 }).zip(a.iter()).fold(1.0, |acc, (x, a)| {
428 acc * x.powf(a-1.0)
429 })).ln())
430 },
431 _ => Err("Value of wrong type, expected Vector.")
432 }
433 }
434 }
435 }
436}