1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use crate::{Distribution, DistributionError, RandomVariable, SampleableDistribution};
use rand::prelude::*;

/// Sample b from posterior p(b|a) with likelihood p(a|b) and prior p(b)
pub struct MetropolisSampler<'a, L, P, A, B, PD>
where
    L: Distribution<Value = A, Condition = B>,
    P: Distribution<Value = B, Condition = ()>,
    A: RandomVariable,
    B: RandomVariable,
    PD: SampleableDistribution<Value = B, Condition = B>,
{
    value: &'a A,
    likelihood: &'a L,
    prior: &'a P,
    proposal: &'a PD,
}

impl<'a, L, P, A, B, PD> MetropolisSampler<'a, L, P, A, B, PD>
where
    L: Distribution<Value = A, Condition = B>,
    P: Distribution<Value = B, Condition = ()>,
    A: RandomVariable,
    B: RandomVariable,
    PD: SampleableDistribution<Value = B, Condition = B>,
{
    pub fn new(value: &'a A, likelihood: &'a L, prior: &'a P, proposal: &'a PD) -> Self {
        Self {
            value,
            likelihood,
            prior,
            proposal,
        }
    }

    pub fn sample(
        &self,
        iter: usize,
        initial: B,
        rng: &mut dyn RngCore,
    ) -> Result<B, DistributionError> {
        let mut state = initial;
        let mut count = 0;

        while count < iter {
            let candidate = self.proposal.sample(&state, rng)?;
            let r = (self.likelihood.p_kernel(self.value, &candidate)?
                * self.prior.p_kernel(&candidate, &())?)
                / (self.likelihood.p_kernel(self.value, &state)?
                    * self.prior.p_kernel(&state, &())?);
            let r = r.min(1.0);
            let p = rng.gen_range(0.0..=1.0);

            if p < r {
                state = candidate;
                count += 1;
            }
        }

        Ok(state)
    }
}