1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use std::f64::consts::LN_2;

use special::Gamma as SGamma;

use crate::consts::*;
use crate::data::{DataOrSuffStat, GaussianSuffStat};
use crate::dist::{Gaussian, NormalGamma};
use crate::traits::*;

/// Produce an owned `GaussianSuffStat` from whichever form of input was given.
///
/// * `SuffStat` — the referenced statistic is cloned.
/// * `Data` — a fresh statistic is created and every datum is observed in
///   iteration order.
/// * `None` — a fresh statistic with nothing observed is returned.
fn extract_stat(x: &DataOrSuffStat<f64, Gaussian>) -> GaussianSuffStat {
    match x {
        DataOrSuffStat::None => GaussianSuffStat::new(),
        DataOrSuffStat::SuffStat(s) => (*s).clone(),
        DataOrSuffStat::Data(xs) => {
            xs.iter().fold(GaussianSuffStat::new(), |mut acc, datum| {
                acc.observe(datum);
                acc
            })
        }
    }
}

/// Log-normalizer term used by `ln_m` and `ln_pp`, evaluated as
///
/// ```text
/// (v + 1)/2 * ln(2) + ln(pi)/2 - ln(r)/2 - (v/2) * ln(s) + ln_gamma(v/2)
/// ```
///
/// Only the value component of `ln_gamma` is used; its sign component is
/// discarded.
fn ln_z(r: f64, s: f64, v: f64) -> f64 {
    let half_v = v / 2.0;

    let ln_two_term = (v + 1.0) / 2.0 * LN_2;
    let ln_r_term = 0.5 * r.ln();
    let ln_s_term = half_v * s.ln();
    let (ln_gamma_half_v, _sign) = half_v.ln_gamma();

    // Terms are combined strictly left to right so the floating-point result
    // does not depend on how the sum is grouped.
    ln_two_term + HALF_LN_PI - ln_r_term - ln_s_term + ln_gamma_half_v
}

/// Build the `NormalGamma` obtained by updating `ng` with the sufficient
/// statistic `stat`.
///
/// With `n = stat.n`, the updated parameters are
///
/// ```text
/// r' = r + n
/// v' = v + n
/// m' = (m * r + sum_x) / r'
/// s' = s + sum_x_sq + r * m^2 - r' * m'^2
/// ```
///
/// # Panics
///
/// Panics with "Invalid posterior params." if `NormalGamma::new` rejects the
/// updated parameters.
fn posterior_from_stat(
    ng: &NormalGamma,
    stat: &GaussianSuffStat,
) -> NormalGamma {
    let n = stat.n as f64;

    let r_post = ng.r + n;
    let v_post = ng.v + n;
    let m_post = (ng.m * ng.r + stat.sum_x) / r_post;

    let prior_moment = ng.r * ng.m * ng.m;
    let post_moment = r_post * m_post * m_post;
    let s_post = ng.s + stat.sum_x_sq + prior_moment - post_moment;

    NormalGamma::new(m_post, r_post, s_post, v_post)
        .expect("Invalid posterior params.")
}

impl ConjugatePrior<f64, Gaussian> for NormalGamma {
    type Posterior = Self;
    fn posterior(&self, x: &DataOrSuffStat<f64, Gaussian>) -> Self {
        let stat = extract_stat(&x);
        posterior_from_stat(&self, &stat)
    }

    fn ln_m(&self, x: &DataOrSuffStat<f64, Gaussian>) -> f64 {
        let stat = extract_stat(&x);
        let post = posterior_from_stat(&self, &stat);
        let lnz_0 = ln_z(self.r, self.s, self.v);
        let lnz_n = ln_z(post.r, post.s, post.v);
        -(stat.n as f64) * HALF_LN_2PI + lnz_n - lnz_0
    }

    fn ln_pp(&self, y: &f64, x: &DataOrSuffStat<f64, Gaussian>) -> f64 {
        let mut stat = extract_stat(&x);
        let post_n = posterior_from_stat(&self, &stat);
        stat.observe(y);
        let post_m = posterior_from_stat(&self, &stat);

        let lnz_n = ln_z(post_n.r, post_n.s, post_n.v);
        let lnz_m = ln_z(post_m.r, post_m.s, post_m.v);

        -HALF_LN_2PI + lnz_m - lnz_n
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance passed to every `assert::close` comparison in this module.
    const TOL: f64 = 1E-12;

    /// Prior shared by the marginal-likelihood and predictive tests.
    fn fixture_prior() -> NormalGamma {
        NormalGamma::new(2.1, 1.2, 1.3, 1.4).unwrap()
    }

    /// Observations shared by the marginal-likelihood and predictive tests.
    fn fixture_data() -> Vec<f64> {
        vec![1.0, 2.0, 3.0, 4.0]
    }

    /// Log posterior predictive of `y` under the shared prior and data.
    fn fixture_ln_pp(y: f64) -> f64 {
        let data = fixture_data();
        let x = DataOrSuffStat::Data(&data);
        fixture_prior().ln_pp(&y, &x)
    }

    #[test]
    fn ln_z_all_ones() {
        assert::close(ln_z(1.0, 1.0, 1.0), 1.83787706640935, TOL);
    }

    #[test]
    fn ln_z_not_all_ones() {
        assert::close(ln_z(1.2, 0.4, 5.2), 5.36972819068534, TOL);
    }

    #[test]
    fn ln_marginal_likelihood_vec_data() {
        let data = fixture_data();
        let x = DataOrSuffStat::Data(&data);
        let m = fixture_prior().ln_m(&x);
        assert::close(m, -7.69707018344038, TOL);
    }

    #[test]
    fn ln_marginal_likelihood_suffstat() {
        // Same observations as the raw-data test, supplied as a pre-built
        // sufficient statistic; the expected value is identical.
        let mut stat = GaussianSuffStat::new();
        for y in fixture_data().iter() {
            stat.observe(y);
        }
        let x = DataOrSuffStat::SuffStat(&stat);
        let m = fixture_prior().ln_m(&x);
        assert::close(m, -7.69707018344038, TOL);
    }

    #[test]
    fn posterior_predictive_positive_value() {
        assert::close(fixture_ln_pp(3.0), -1.28438638499611, TOL);
    }

    #[test]
    fn posterior_predictive_negative_value() {
        assert::close(fixture_ln_pp(-3.0), -6.1637698862186, TOL);
    }
}