1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use crate::{Distribution, DistributionError, RandomVariable};
use rand::prelude::*;

/// Sampler that draws `b` from the posterior `p(b|a)`, given the likelihood
/// `p(a|b)` and the prior `p(b)`.
///
/// Unbounded slice sampling; see
/// <http://chasen.org/~daiti-m/diary/?201510>.
pub struct SliceSampler<L, P, A>
where
    L: Distribution<Value = A, Condition = f64>,
    P: Distribution<Value = f64, Condition = ()>,
    A: RandomVariable,
{
    /// Value `a` handed to the likelihood kernel as its value argument.
    value: A,
    /// Likelihood `p(a|b)`: a distribution over `A` conditioned on an `f64`.
    likelihood: L,
    /// Prior `p(b)`: an unconditioned distribution over `f64`.
    prior: P,
}

/// Error values associated with slice sampling.
///
/// NOTE(review): nothing in the visible part of this file constructs these
/// variants — confirm whether they are used elsewhere.
#[derive(thiserror::Error, Debug)]
pub enum SliceSamplingError {
    /// Rendered as "out of range".
    #[error("out of range")]
    OutOfRange,
    /// Rendered as "Unknown error".
    #[error("Unknown error")]
    Unknown,
}

impl<L, P, A> SliceSampler<L, P, A>
where
    L: Distribution<Value = A, Condition = f64>,
    P: Distribution<Value = f64, Condition = ()>,
    A: RandomVariable,
{
    /// Builds a sampler for the posterior of `b` given `value`, `likelihood`
    /// and `prior`.
    ///
    /// # Errors
    ///
    /// This constructor performs no validation and always returns `Ok`; the
    /// `Result` return type is part of the existing interface.
    pub fn new(value: A, likelihood: L, prior: P) -> Result<Self, DistributionError> {
        Ok(Self {
            value,
            likelihood,
            prior,
        })
    }

    /// Draws one value of `b` from the posterior, starting from the current
    /// state `x`.
    ///
    /// The real line is mapped onto the unit interval with `shrink`, and the
    /// slice-sampling shrinkage procedure is run there: a log-scale slice
    /// level is drawn below the (transformed) density at `x`, then proposals
    /// are drawn uniformly from an interval that starts as `0..1` and is
    /// narrowed towards `shrink(x)` after every rejection.
    ///
    /// * `x` - current state of the chain.
    /// * `max_iter` - maximum number of proposals to try.
    /// * `rng` - source of randomness.
    ///
    /// Returns the first accepted proposal (mapped back with `expand`), or `x`
    /// unchanged when no proposal is accepted within `max_iter` attempts.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the likelihood or prior `p_kernel`,
    /// or by `shrink` / `expand`.
    pub fn sample(
        &self,
        x: f64,
        max_iter: usize,
        rng: &mut dyn RngCore,
    ) -> Result<f64, DistributionError> {
        let mut st = 0.0;
        let mut ed = 1.0;

        let r = shrink(x)?;
        // Slice level on the log scale: ln g(r) + ln(u) with u ~ U(0, 1), i.e. a
        // uniform draw below the transformed density at the current state.
        let slice = self.ln_transformed_kernel(r, x)? + rng.gen_range(0.0f64..1.0f64).ln();

        for _ in 0..max_iter {
            let rnew = rng.gen_range(st..ed);
            let expanded = expand(rnew)?;

            // Both sides of the comparison are log densities in the unit
            // interval, so they are directly comparable.
            if self.ln_transformed_kernel(rnew, expanded)? > slice {
                return Ok(expanded);
            }

            // Rejected: shrink the bracket towards the current state so that
            // `r` always stays strictly inside `st..ed`.
            if rnew > r {
                ed = rnew;
            } else if rnew < r {
                st = rnew;
            } else {
                return Ok(x);
            }
        }
        Ok(x)
    }

    /// Log of the unnormalized posterior density expressed in the unit
    /// interval: `ln(likelihood * prior) - ln(p * (1 - p))`, where `p` is the
    /// unit-interval coordinate of `x`.
    ///
    /// The second term is the log-Jacobian of the logistic change of
    /// variables; its constant scale factor is dropped because it cancels in
    /// the acceptance comparison.
    ///
    /// NOTE(review): assumes `p_kernel` returns a density (not a log-density),
    /// which is why the two kernels are multiplied — confirm against the
    /// `Distribution` trait.
    fn ln_transformed_kernel(&self, p: f64, x: f64) -> Result<f64, DistributionError> {
        let kernel = self.likelihood.p_kernel(&self.value, &x)? * self.prior.p_kernel(&x, &())?;
        Ok(kernel.ln() - (p * (1.0 - p)).ln())
    }
}

/// Maps a unit-interval coordinate `p` back onto the real line.
///
/// This is the inverse of `shrink`: `expand(shrink(x)) == x` up to rounding,
/// i.e. `-100 * ln(1 / p - 1)`. `p == 0.0` maps to negative infinity and
/// `p == 1.0` to positive infinity.
///
/// # Errors
///
/// Never fails; the `Result` return type is part of the existing interface.
fn expand(p: f64) -> Result<f64, DistributionError> {
    // `1 / p - 1` (not `1 / (p - 1)`): the latter is negative on (0, 1) and
    // its logarithm is NaN.
    Ok(-100.0 * (1.0 / p - 1.0).ln())
}

/// Squashes a real `x` into the unit interval with a logistic curve of
/// scale 100: `1 / (1 + exp(-x / 100))`.
///
/// # Errors
///
/// Never fails; the `Result` return type is part of the existing interface.
fn shrink(x: f64) -> Result<f64, DistributionError> {
    let decay = (-x / 100.0).exp();
    Ok((1.0 + decay).recip())
}