1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use crate::{Distribution, *};
use rand::distributions::WeightedIndex;
use rand_distr::Distribution as RandDistribution;
use rayon::prelude::*;
use std::{collections::HashSet, hash::Hash, marker::PhantomData};
/// Posterior over a finite, explicitly enumerated set of candidate values.
///
/// For a candidate `b` and an observation `a`, the posterior kernel is proportional to
/// `likelihood(a | b) * prior(b)`. `range` lists every candidate `b` that is weighed when
/// sampling.
#[derive(Clone, Debug)]
pub struct DiscretePosterior<
    L: Distribution<Value = A, Condition = B>,
    P: Distribution<Value = B, Condition = ()>,
    A: RandomVariable,
    B: RandomVariable + Eq + Hash,
> {
    /// Distribution of an observation of type `A`, conditioned on a candidate of type `B`.
    likelihood: L,
    /// Unconditioned distribution over the candidates of type `B`.
    prior: P,
    /// Finite set of candidate values the posterior is evaluated over.
    range: HashSet<B>,
    /// Records the observation type `A`, which no field stores directly.
    phantom: PhantomData<A>,
}
impl<L, P, A, B> DiscretePosterior<L, P, A, B>
where
L: Distribution<Value = A, Condition = B>,
P: Distribution<Value = B, Condition = ()>,
A: RandomVariable,
B: RandomVariable + Eq + Hash,
{
pub fn new(likelihood: L, prior: P, range: HashSet<B>) -> Self {
Self {
likelihood,
prior,
range,
phantom: PhantomData,
}
}
fn weighted(&self, theta: &A) -> Result<Vec<(f64, &B)>, DistributionError> {
let weighted = self
.range
.par_iter()
.map(|u| -> Result<_, DistributionError> {
Ok((
self.likelihood.p_kernel(theta, u)? * self.prior.p_kernel(u, &())?,
u,
))
})
.collect::<Result<Vec<(f64, &B)>, _>>()?;
Ok(weighted)
}
fn index(&self, weighted: &Vec<(f64, &B)>) -> Result<WeightedIndex<f64>, DistributionError> {
let index = match WeightedIndex::new(weighted.iter().map(|(w, _)| *w)) {
Ok(v) => v,
Err(_) => WeightedIndex::new(vec![1.0; weighted.len()]).unwrap(),
};
Ok(index)
}
}
impl<L, P, A, B> Distribution for DiscretePosterior<L, P, A, B>
where
    L: Distribution<Value = A, Condition = B>,
    P: Distribution<Value = B, Condition = ()>,
    A: RandomVariable,
    B: RandomVariable + Eq + Hash,
{
    type Value = B;
    type Condition = A;

    /// Unnormalized posterior kernel of candidate `x` given the observation `theta`.
    ///
    /// This is the product of the likelihood of `theta` under `x` and the prior of `x`. No
    /// normalizing constant over the candidate range is applied.
    ///
    /// # Errors
    /// Propagates any error from the likelihood kernel (evaluated first) or the prior kernel.
    fn p_kernel(&self, x: &Self::Value, theta: &Self::Condition) -> Result<f64, DistributionError> {
        let likelihood = self.likelihood.p_kernel(theta, x)?;
        let prior = self.prior.p_kernel(x, &())?;
        Ok(likelihood * prior)
    }
}
impl<L, P, A, B> SampleableDistribution for DiscretePosterior<L, P, A, B>
where
    L: Distribution<Value = A, Condition = B>,
    P: Distribution<Value = B, Condition = ()>,
    A: RandomVariable,
    B: RandomVariable + Eq + Hash,
{
    /// Draws one candidate from the posterior's range given the observation `theta`.
    ///
    /// Each candidate is chosen with probability proportional to its posterior weight.
    ///
    /// # Errors
    /// Propagates errors from computing the weights or from building the weighted sampler.
    fn sample(
        &self,
        theta: &Self::Condition,
        rng: &mut dyn rand::RngCore,
    ) -> Result<Self::Value, DistributionError> {
        let candidates = self.weighted(theta)?;
        let sampler = self.index(&candidates)?;
        let position = sampler.sample(rng);
        // `(f64, &B)` is `Copy`, so this copies the pair out of the vector.
        let (_, chosen) = candidates[position];
        Ok(chosen.clone())
    }
}
#[cfg(test)]
mod tests {
    use crate::distribution::Distribution;
    use crate::*;
    use std::collections::HashSet;

    /// Setup:
    /// - The likelihood is a unit-variance normal centred at 10.0 for `true` and at 0.0 for
    ///   `false`.
    /// - The prior is a fair Bernoulli.
    ///
    /// An observation of 1.0 should therefore give `false` the larger posterior kernel.
    #[test]
    fn it_works() {
        let range: HashSet<bool> = [true, false].iter().copied().collect();

        let likelihood = Normal.condition(|flag: &bool| {
            let mu = if *flag { 10.0 } else { 0.0 };
            NormalParams::new(mu, 1.0)
        });
        let prior = Bernoulli.condition(|_unit: &()| BernoulliParams::new(0.5));
        let model = DiscretePosterior::new(likelihood, prior, range);

        let observed = 1.0;
        let kernel_true = model.p_kernel(&true, &observed).unwrap();
        let kernel_false = model.p_kernel(&false, &observed).unwrap();
        assert!(kernel_true < kernel_false);
    }
}