1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
use std::f64::consts::LN_2;

use special::Gamma as SGamma;

use crate::consts::*;
use crate::data::{DataOrSuffStat, GaussianSuffStat};
use crate::dist::{Gaussian, NormalGamma};
use crate::traits::*;

/// Collapses a `DataOrSuffStat` into an owned `GaussianSuffStat`.
///
/// * `SuffStat` – the referenced statistic is cloned.
/// * `Data` – every datum is observed into a freshly created statistic.
/// * `None` – a freshly created statistic is returned.
fn extract_stat(x: &DataOrSuffStat<f64, Gaussian>) -> GaussianSuffStat {
    match x {
        DataOrSuffStat::SuffStat(s) => (*s).clone(),
        DataOrSuffStat::Data(xs) => {
            let mut stat = GaussianSuffStat::new();
            for y in xs.iter() {
                stat.observe(y);
            }
            stat
        }
        DataOrSuffStat::None => GaussianSuffStat::new(),
    }
}

/// Log normalizing term for the parameters `r`, `s` and `v`:
///
/// `(v + 1)/2 * ln 2 + ln(pi)/2 - ln(r)/2 - (v/2) * ln(s) + ln Gamma(v/2)`
///
/// `ln_m` and `ln_pp` are both computed as differences of this quantity.
fn ln_z(r: f64, s: f64, v: f64) -> f64 {
    // Term order is kept as-is so floating-point results stay bit-for-bit
    // stable against the reference values in the tests below.
    (v + 1.0) / 2.0 * LN_2 + HALF_LN_PI - 0.5 * r.ln() - (v / 2.0) * s.ln()
        + (v / 2.0).ln_gamma().0
}

/// Builds the updated `NormalGamma` from the prior `ng` and the statistic
/// `stat` (count `n`, sum `sum_x`, sum of squares `sum_x_sq`):
///
/// * `r' = r + n`
/// * `v' = v + n`
/// * `m' = (m * r + sum_x) / r'`
/// * `s' = s + sum_x_sq + r * m^2 - r' * m'^2`
///
/// # Panics
///
/// Panics with "Invalid posterior params." if `NormalGamma::new` rejects the
/// updated parameters.
fn posterior_from_stat(
    ng: &NormalGamma,
    stat: &GaussianSuffStat,
) -> NormalGamma {
    let r = ng.r + stat.n as f64;
    let v = ng.v + stat.n as f64;
    let m = (ng.m * ng.r + stat.sum_x) / r;
    let s = ng.s + stat.sum_x_sq + ng.r * ng.m * ng.m - r * m * m;
    NormalGamma::new(m, r, s, v).expect("Invalid posterior params.")
}

impl ConjugatePrior<f64, Gaussian> for NormalGamma {
    type Posterior = Self;

    /// Returns the `NormalGamma` obtained by updating `self` with `x`.
    ///
    /// # Panics
    ///
    /// Panics if the updated parameters are rejected by `NormalGamma::new`.
    fn posterior(&self, x: &DataOrSuffStat<f64, Gaussian>) -> Self {
        let stat = extract_stat(x);
        posterior_from_stat(self, &stat)
    }

    /// Log marginal likelihood of `x`, computed as
    /// `-n * HALF_LN_2PI + ln_z(posterior) - ln_z(prior)` where `n` is the
    /// number of observations in `x`.
    fn ln_m(&self, x: &DataOrSuffStat<f64, Gaussian>) -> f64 {
        let stat = extract_stat(x);
        let post = posterior_from_stat(self, &stat);
        let lnz_0 = ln_z(self.r, self.s, self.v);
        let lnz_n = ln_z(post.r, post.s, post.v);
        -(stat.n as f64) * HALF_LN_2PI + lnz_n - lnz_0
    }

    /// Log posterior predictive of `y` given `x`, computed as
    /// `-HALF_LN_2PI + ln_z(posterior given x and y) - ln_z(posterior given x)`.
    fn ln_pp(&self, y: &f64, x: &DataOrSuffStat<f64, Gaussian>) -> f64 {
        let mut stat = extract_stat(x);
        // Posterior conditioned on `x` only.
        let post_n = posterior_from_stat(self, &stat);
        // Fold `y` into the same statistic to get the posterior given `x` and `y`.
        stat.observe(y);
        let post_m = posterior_from_stat(self, &stat);

        let lnz_n = ln_z(post_n.r, post_n.s, post_n.v);
        let lnz_m = ln_z(post_m.r, post_m.s, post_m.v);

        -HALF_LN_2PI + lnz_m - lnz_n
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const TOL: f64 = 1E-12;

    #[test]
    fn ln_z_all_ones() {
        let z = ln_z(1.0, 1.0, 1.0);
        assert::close(z, 1.83787706640935, TOL);
    }

    #[test]
    fn ln_z_not_all_ones() {
        let z = ln_z(1.2, 0.4, 5.2);
        assert::close(z, 5.36972819068534, TOL);
    }

    #[test]
    fn ln_marginal_likelihood_vec_data() {
        let ng = NormalGamma::new(2.1, 1.2, 1.3, 1.4).unwrap();
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let x = DataOrSuffStat::Data(&data);
        let m = ng.ln_m(&x);
        assert::close(m, -7.69707018344038, TOL);
    }

    #[test]
    fn ln_marginal_likelihood_suffstat() {
        let ng = NormalGamma::new(2.1, 1.2, 1.3, 1.4).unwrap();
        let mut stat = GaussianSuffStat::new();
        stat.observe(&1.0);
        stat.observe(&2.0);
        stat.observe(&3.0);
        stat.observe(&4.0);
        let x = DataOrSuffStat::SuffStat(&stat);
        let m = ng.ln_m(&x);
        assert::close(m, -7.69707018344038, TOL);
    }

    #[test]
    fn posterior_predictive_positive_value() {
        let ng = NormalGamma::new(2.1, 1.2, 1.3, 1.4).unwrap();
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let x = DataOrSuffStat::Data(&data);
        let pp = ng.ln_pp(&3.0, &x);
        assert::close(pp, -1.28438638499611, TOL);
    }

    #[test]
    fn posterior_predictive_negative_value() {
        let ng = NormalGamma::new(2.1, 1.2, 1.3, 1.4).unwrap();
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let x = DataOrSuffStat::Data(&data);
        let pp = ng.ln_pp(&-3.0, &x);
        assert::close(pp, -6.1637698862186, TOL);
    }
}