use std::collections::BTreeMap;
use std::f64::consts::LN_2;
use super::dos_to_post;
use crate::consts::{HALF_LN_2PI, HALF_LN_PI};
use crate::data::{GaussianSuffStat, extract_stat_then};
use crate::dist::{Gaussian, NormalGamma};
use crate::gaussian_prior_geweke_testable;
use crate::misc::ln_gammafn;
use crate::test::GewekeTestable;
use crate::traits::{
ConjugatePrior, DataOrSuffStat, HasSuffStat, Sampleable, SuffStat,
};
/// Natural log of the normal-gamma normalizing constant `Z(r, s, v)`:
///
/// `ln Z = (v/2 + 1/2) ln 2 + (1/2) ln π - (1/2) ln r - (v/2) ln s + ln Γ(v/2)`
#[inline]
fn ln_z(r: f64, s: f64, v: f64) -> f64 {
    let hv = 0.5 * v;
    // Positive part: (v/2 + 1/2) * ln(2) + ln(π)/2, fused for accuracy.
    let pos = (hv + 0.5).mul_add(LN_2, HALF_LN_PI);
    // Negative part: ln(r)/2 + (v/2) * ln(s) - ln Γ(v/2).
    // Same fused multiply-adds as before, so results are bit-identical.
    let neg = 0.5_f64.mul_add(r.ln(), hv.mul_add(s.ln(), -ln_gammafn(hv)));
    pos - neg
}
/// Hyperparameters of a normal-gamma distribution in its `(m, r, s, v)`
/// parameterization, as produced by `posterior_from_stat`.
///
/// All fields are `f64`, so the struct derives `Copy` — `ln_pp_with_cache`
/// already depends on copy-out destructuring of these fields from behind a
/// shared reference. `Debug` and `PartialEq` are derived so the type can be
/// printed in diagnostics and compared in tests.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PosteriorParameters {
    /// Mean hyperparameter; updated as `(r·m + Σx) / (r + n)`.
    pub m: f64,
    /// Relative-precision hyperparameter; updated as `r + n`.
    pub r: f64,
    /// Scale accumulator; updated from `s`, `Σx²`, and the mean terms.
    pub s: f64,
    /// Degrees-of-freedom hyperparameter; updated as `v + n`.
    pub v: f64,
}
impl From<PosteriorParameters> for NormalGamma {
    /// Build a `NormalGamma` distribution from posterior hyperparameters.
    ///
    /// # Panics
    ///
    /// Panics if `NormalGamma::new` rejects the parameters. Parameters coming
    /// from `posterior_from_stat` with a valid prior are expected to be in
    /// the valid domain, so a panic here indicates a broken invariant
    /// upstream rather than a recoverable condition.
    fn from(PosteriorParameters { m, r, s, v }: PosteriorParameters) -> Self {
        NormalGamma::new(m, r, s, v)
            .expect("posterior parameters must be valid NormalGamma hyperparameters")
    }
}
/// Compute the normal-gamma posterior hyperparameters for prior `ng` given a
/// Gaussian sufficient statistic `stat` (count `n`, `Σx`, and `Σx²`).
///
/// Standard conjugate update:
///   r' = r + n
///   v' = v + n
///   m' = (r·m + Σx) / r'
///   s' = s + Σx² + r·m² − r'·m'²
fn posterior_from_stat(
    ng: &NormalGamma,
    stat: &GaussianSuffStat,
) -> PosteriorParameters {
    let nf = stat.n() as f64;
    let r = ng.r() + nf;
    let v = ng.v() + nf;
    // Posterior mean: prior mean and data sum weighted by relative precision.
    let m = ng.m().mul_add(ng.r(), stat.sum_x()) / r;
    // mul_add fuses the cancellation-prone r·m² term to limit rounding error.
    let s =
        ng.s() + stat.sum_x_sq() + ng.r().mul_add(ng.m() * ng.m(), -r * m * m);
    PosteriorParameters { m, r, s, v }
}
/// Implement `ConjugatePrior<$kind, Gaussian>` for `NormalGamma`, where
/// `$kind` is a float type convertible to `f64` (`f16`/`f32`/`f64`).
macro_rules! impl_traits {
    ($kind: ty) => {
        impl ConjugatePrior<$kind, Gaussian> for NormalGamma {
            type Posterior = Self;
            /// Cached prior log-normalizer `ln Z(r, s, v)`.
            type MCache = f64;
            /// Posterior hyperparameters plus the y-independent part of
            /// the log posterior predictive.
            type PpCache = (PosteriorParameters, f64);

            fn empty_stat(&self) -> <Gaussian as HasSuffStat<$kind>>::Stat {
                GaussianSuffStat::new()
            }

            fn posterior(&self, x: &DataOrSuffStat<$kind, Gaussian>) -> Self {
                dos_to_post!(self, x).into()
            }

            #[inline]
            fn ln_m_cache(&self) -> Self::MCache {
                // Use the accessor methods consistently; the original mixed
                // `self.r()` with direct field access `self.s`/`self.v`.
                ln_z(self.r(), self.s(), self.v())
            }

            /// Log marginal likelihood:
            /// `ln Z_n − (n/2) ln(2π) − ln Z_0`, with `ln Z_0` cached.
            fn ln_m_with_cache(
                &self,
                cache: &Self::MCache,
                x: &DataOrSuffStat<$kind, Gaussian>,
            ) -> f64 {
                let (n, post) = dos_to_post!(# self, x);
                let lnz_n = ln_z(post.r, post.s, post.v);
                (-(n as f64)).mul_add(HALF_LN_2PI, lnz_n) - cache
            }

            /// Precompute everything in the log posterior predictive that
            /// does not depend on the query point `y`:
            /// `(1/2)ln 2 − (1/2)ln 2π + (1/2)ln(r/(r+1)) + (v/2)ln s
            ///  + ln Γ((v+1)/2) − ln Γ(v/2)`.
            fn ln_pp_cache(&self, x: &DataOrSuffStat<$kind, Gaussian>) -> Self::PpCache {
                extract_stat_then(self, x, |stat| {
                    let params = posterior_from_stat(self, stat);
                    let PosteriorParameters { r, s, v, .. } = params;
                    let half_v = v / 2.0;
                    let g_ratio = ln_gammafn(half_v + 0.5) - ln_gammafn(half_v);
                    let term = 0.5_f64.mul_add(LN_2, -HALF_LN_2PI)
                        + 0.5_f64.mul_add(
                            (r / (r + 1_f64)).ln(),
                            half_v.mul_add(s.ln(), g_ratio),
                        );
                    (params, term)
                })
            }

            fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &$kind) -> f64 {
                // Copy the posterior hyperparameters out of the cache
                // (fields are f64, so this destructure copies).
                let PosteriorParameters { m, r, s, v } = cache.0;
                let y = f64::from(*y);
                // One-step posterior update absorbing the single point y:
                // r_{n+1} = r + 1; m_{n+1} = (r·m + y)/r_{n+1};
                // s_{n+1} = s + y² + r·m² − r_{n+1}·m_{n+1}².
                let rn = r + 1.0;
                let mn = r.mul_add(m, y) / rn;
                let sn = (rn * mn).mul_add(-mn, (r * m).mul_add(m, y.mul_add(y, s)));
                // ln pp = cached term − ((v+1)/2) ln s_{n+1}.
                ((v + 1.0) / 2.0).mul_add(-sn.ln(), cache.1)
            }
        }
    };
}
// `f16` support is gated behind the `experimental` cargo feature.
#[cfg(feature = "experimental")]
impl_traits!(f16);
impl_traits!(f32);
impl_traits!(f64);
// Wire NormalGamma/Gaussian into the shared Geweke test harness.
gaussian_prior_geweke_testable!(NormalGamma, Gaussian);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::GaussianData;
    use crate::test_conjugate_prior;

    // Tolerance for comparisons against precomputed reference values.
    const TOL: f64 = 1E-12;

    // Shared conjugate-prior consistency suite for NormalGamma over f64 data.
    test_conjugate_prior!(
        f64,
        Gaussian,
        NormalGamma,
        NormalGamma::new(0.1, 1.2, 0.5, 1.8).unwrap()
    );

    // Geweke joint-distribution test. The test is stochastic, so require at
    // least 2 of 5 independent runs to pass at the 0.025 level rather than
    // demanding all 5 succeed.
    #[test]
    fn geweke() {
        use crate::test::GewekeTester;
        let mut rng = rand::rng();
        let pr = NormalGamma::new(0.1, 1.2, 0.5, 1.8).unwrap();
        let n_passes = (0..5)
            .map(|_| {
                let mut tester = GewekeTester::new(pr.clone(), 20);
                tester.run_chains(5_000, 20, &mut rng);
                u8::from(tester.eval(0.025).is_ok())
            })
            .sum::<u8>();
        assert!(n_passes > 1);
    }

    // ln_z(1, 1, 1) reduces to ln 2 + ln Γ(1/2) + (1/2) ln π = ln(2π).
    #[test]
    fn ln_z_all_ones() {
        let z = ln_z(1.0, 1.0, 1.0);
        assert::close(z, 1.837_877_066_409_35, TOL);
    }

    // Reference value for a non-degenerate parameter combination.
    #[test]
    fn ln_z_not_all_ones() {
        let z = ln_z(1.2, 0.4, 5.2);
        assert::close(z, 5.369_728_190_685_34, TOL);
    }

    // ln_m computed from raw data.
    #[test]
    fn ln_marginal_likelihood_vec_data() {
        let ng = NormalGamma::new(2.1, 1.2, 1.3, 1.4).unwrap();
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let x = GaussianData::<f64>::Data(&data);
        let m = ng.ln_m(&x);
        assert::close(m, -7.697_070_183_440_38, TOL);
    }

    // ln_m from a sufficient statistic must match the raw-data result above.
    #[test]
    fn ln_marginal_likelihood_suffstat() {
        let ng = NormalGamma::new(2.1, 1.2, 1.3, 1.4).unwrap();
        let mut stat = GaussianSuffStat::new();
        stat.observe(&1.0);
        stat.observe(&2.0);
        stat.observe(&3.0);
        stat.observe(&4.0);
        let x = GaussianData::<f64>::SuffStat(&stat);
        let m = ng.ln_m(&x);
        assert::close(m, -7.697_070_183_440_38, TOL);
    }

    // observe then forget must leave the statistic (and thus ln_m) unchanged.
    #[test]
    fn ln_marginal_likelihood_suffstat_forgotten() {
        let ng = NormalGamma::new(2.1, 1.2, 1.3, 1.4).unwrap();
        let mut stat = GaussianSuffStat::new();
        stat.observe(&1.0);
        stat.observe(&2.0);
        stat.observe(&3.0);
        stat.observe(&4.0);
        stat.observe(&5.0);
        stat.forget(&5.0);
        let x = GaussianData::<f64>::SuffStat(&stat);
        let m = ng.ln_m(&x);
        assert::close(m, -7.697_070_183_440_38, TOL);
    }

    // Posterior predictive at a point inside the observed data range.
    #[test]
    fn posterior_predictive_positive_value() {
        let ng = NormalGamma::new(2.1, 1.2, 1.3, 1.4).unwrap();
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let x = GaussianData::<f64>::Data(&data);
        let pp = ng.ln_pp(&3.0, &x);
        assert::close(pp, -1.284_386_384_996_11, TOL);
    }

    // Posterior predictive at a point far from the observed data.
    #[test]
    fn posterior_predictive_negative_value() {
        let ng = NormalGamma::new(2.1, 1.2, 1.3, 1.4).unwrap();
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let x = GaussianData::<f64>::Data(&data);
        let pp = ng.ln_pp(&-3.0, &x);
        assert::close(pp, -6.163_769_886_218_6, TOL);
    }

    // Cross-check the closed-form ln_m against a brute-force Monte Carlo
    // estimate: average the data likelihood over draws from the prior.
    #[test]
    fn ln_m_vs_monte_carlo() {
        use crate::misc::LogSumExp;
        use crate::traits::HasDensity;
        let n_samples = 8_000_000;
        let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (m, r, s, v) = (0.0, 1.2, 2.3, 3.4);
        let ng = NormalGamma::new(m, r, s, v).unwrap();
        let ln_m = ng.ln_m(&DataOrSuffStat::<f64, Gaussian>::from(&xs));
        let mc_est = {
            ng.sample_stream(&mut rand::rng())
                .take(n_samples)
                .map(|gauss: Gaussian| {
                    xs.iter().map(|x| gauss.ln_f(x)).sum::<f64>()
                })
                .logsumexp()
                // log-mean-exp: subtract ln(n) from the logsumexp.
                - (n_samples as f64).ln()
        };
        assert::close(ln_m, mc_est, 1e-2);
    }

    // Same cross-check via importance sampling with the posterior as the
    // proposal distribution, which needs far fewer samples to converge.
    #[test]
    fn ln_m_vs_importance() {
        use crate::misc::LogSumExp;
        use crate::traits::HasDensity;
        let n_samples = 2_000_000;
        let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (m, r, s, v) = (1.0, 2.2, 3.3, 4.4);
        let ng = NormalGamma::new(m, r, s, v).unwrap();
        let ln_m = ng.ln_m(&DataOrSuffStat::<f64, Gaussian>::from(&xs));
        let post = ng.posterior(&DataOrSuffStat::<f64, Gaussian>::from(&xs));
        let mc_est = {
            let mut rng = rand::rng();
            let ln_fs = (0..n_samples).map(|_| {
                let gauss: Gaussian = post.draw(&mut rng);
                let ln_f = xs.iter().map(|x| gauss.ln_f(x)).sum::<f64>();
                // Importance weight: likelihood × prior / proposal.
                ln_f + ng.ln_f(&gauss) - post.ln_f(&gauss)
            });
            ln_fs.logsumexp() - f64::from(n_samples).ln()
        };
        assert::close(ln_m, mc_est, 1e-2);
    }
}