1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use crate::{Distribution, DistributionError, RandomVariable, SampleableDistribution};
use rand::prelude::*;
/// Metropolis sampler over parameter states of type `B`.
///
/// Candidates are accepted by the ratio of `likelihood(value | state) * prior(state)`
/// between the candidate and the current state.
///
/// The bounds are declared on the struct itself rather than only on the `impl`.
/// `B` is not used by any field; it is tied to the other parameters solely
/// through the associated-type bindings of `L`, `P` and `PD`.
pub struct MetropolisSampler<
    'a,
    L: Distribution<Value = A, Condition = B>,
    P: Distribution<Value = B, Condition = ()>,
    A: RandomVariable,
    B: RandomVariable,
    PD: SampleableDistribution<Value = B, Condition = B>,
> {
    /// Observation passed to `likelihood` as its value.
    value: &'a A,
    /// Kernel of the observation conditioned on a state.
    likelihood: &'a L,
    /// Unconditional kernel over states.
    prior: &'a P,
    /// Draws a candidate state conditioned on the current state.
    proposal: &'a PD,
}
impl<'a, L, P, A, B, PD> MetropolisSampler<'a, L, P, A, B, PD>
where
    L: Distribution<Value = A, Condition = B>,
    P: Distribution<Value = B, Condition = ()>,
    A: RandomVariable,
    B: RandomVariable,
    PD: SampleableDistribution<Value = B, Condition = B>,
{
    /// Creates a sampler from borrowed model components.
    ///
    /// * `value` - the observation, passed to `likelihood` as its value.
    /// * `likelihood` - kernel of `value` conditioned on a state.
    /// * `prior` - unconditional kernel over states.
    /// * `proposal` - draws a candidate state conditioned on the current state.
    pub fn new(value: &'a A, likelihood: &'a L, prior: &'a P, proposal: &'a PD) -> Self {
        Self {
            value,
            likelihood,
            prior,
            proposal,
        }
    }

    /// Runs `iter` Metropolis steps starting from `initial` and returns the final state.
    ///
    /// Each step:
    /// 1. draws a candidate from `proposal` conditioned on the current state;
    /// 2. accepts it with probability `min(1, f(candidate) / f(current))`, where
    ///    `f(s) = likelihood(value | s) * prior(s)`.
    ///
    /// A rejected candidate still counts as a step; the chain simply stays
    /// where it is. This is what makes the target posterior the chain's
    /// stationary distribution.
    ///
    /// The acceptance ratio does not include the proposal density. The proposal
    /// is therefore assumed to be symmetric (plain Metropolis, not
    /// Metropolis–Hastings).
    ///
    /// If `iter == 0`, `initial` is returned without evaluating any kernel.
    ///
    /// # Errors
    ///
    /// Propagates any `DistributionError` from `proposal.sample`, or from
    /// `p_kernel` on the likelihood or prior.
    pub fn sample(
        &self,
        iter: usize,
        initial: B,
        rng: &mut dyn RngCore,
    ) -> Result<B, DistributionError> {
        if iter == 0 {
            return Ok(initial);
        }

        // Unnormalized posterior kernel of a state.
        let density = |s: &B| -> Result<_, DistributionError> {
            Ok(self.likelihood.p_kernel(self.value, s)? * self.prior.p_kernel(s, &())?)
        };

        let mut state = initial;
        // Cached so the current state's kernel is evaluated only when the state changes.
        let mut state_density = density(&state)?;

        for _ in 0..iter {
            let candidate = self.proposal.sample(&state, rng)?;
            let candidate_density = density(&candidate)?;

            // If `state_density` is 0, the quotient is inf or NaN. `min` returns
            // the non-NaN operand, so both cases become 1.0 and the move is accepted.
            let r = (candidate_density / state_density).min(1.0);

            // Half-open [0, 1) makes `u < r` succeed with probability exactly `r`.
            let u = rng.gen_range(0.0..1.0);
            if u < r {
                state = candidate;
                state_density = candidate_density;
            }
        }
        Ok(state)
    }
}