use std::collections::BTreeMap;
use crate::consts::HALF_LN_2PI;
use crate::data::GaussianSuffStat;
use crate::dist::normal_gamma::dos_to_post;
use crate::dist::{Gaussian, NormalInvGamma};
use crate::gaussian_prior_geweke_testable;
use crate::misc::ln_gammafn;
use crate::test::GewekeTestable;
use crate::traits::{
ConjugatePrior, DataOrSuffStat, HasSuffStat, Parameterized, Sampleable,
SuffStat,
};
/// Log normalizing constant of the Normal-Inverse-Gamma distribution:
/// ln Z(v, a, b) = ½·ln v + ln Γ(a) − a·ln b.
#[inline]
fn ln_z(v: f64, a: f64, b: f64) -> f64 {
    // Accumulate ½·ln v + ln Γ(a) with one fused multiply-add, then
    // subtract a·ln b with a second. `mul_add(-a, t)` is bitwise
    // equivalent to `-(mul_add(a, -t))` because round-to-nearest is
    // symmetric under sign flip, so this preserves the exact result.
    let half_ln_v_plus_ln_gamma = v.ln().mul_add(0.5, ln_gammafn(a));
    b.ln().mul_add(-a, half_ln_v_plus_ln_gamma)
}
/// Parameters of a (posterior) Normal-Inverse-Gamma distribution.
///
/// Kept as a plain struct so intermediate results (e.g. in `ln_pp_cache`)
/// can be carried around without constructing a full `NormalInvGamma`.
pub struct PosteriorParameters {
    // Mean of the normal component.
    m: f64,
    // Variance scale of the normal component.
    v: f64,
    // Inverse-gamma shape.
    a: f64,
    // Inverse-gamma scale.
    b: f64,
}
impl From<PosteriorParameters> for NormalInvGamma {
    /// Build a `NormalInvGamma` from posterior parameters.
    fn from(params: PosteriorParameters) -> Self {
        // Parameters coming out of a conjugate update are finite and in
        // range, so construction is expected to succeed.
        NormalInvGamma::new(params.m, params.v, params.a, params.b).unwrap()
    }
}
/// Compute the posterior Normal-Inverse-Gamma parameters given a prior
/// `nig` and the data summarized by `stat`.
///
/// Standard conjugate update:
///   1/vn = 1/v + n
///   mn   = (m/v + Σx) / (1/v + n)
///   an   = a + n/2
///   bn   = b + ½·(m²/v + Σx² − mn²·(1/vn))
#[allow(clippy::many_single_char_names)]
fn posterior_from_stat(
    nig: &NormalInvGamma,
    stat: &GaussianSuffStat,
) -> PosteriorParameters {
    let n = stat.n() as f64;
    // Destructure the prior's parameters (path pattern on the params struct).
    let super::NormalInvGammaParameters { m, v, a, b } = nig.emit_params();
    let v_inv = v.recip();
    // Posterior inverse variance-scale: 1/vn = 1/v + n.
    let vn_inv = v_inv + n;
    let vn = vn_inv.recip();
    // Posterior mean: mn = (m/v + Σx) / (1/v + n).
    let mn = v_inv.mul_add(m, stat.sum_x()) / vn_inv;
    // Posterior shape: an = a + n/2.
    let an = n.mul_add(0.5, a);
    // Posterior scale: bn = b + ½·(m²/v + Σx² − mn²·(1/vn)).
    // Fused multiply-adds keep the rounding error tight.
    let p1 = (m * m).mul_add(v_inv, stat.sum_x_sq());
    let bn = (-mn * mn).mul_add(vn_inv, p1).mul_add(0.5, b);
    PosteriorParameters {
        m: mn,
        v: vn,
        a: an,
        b: bn,
    }
}
impl ConjugatePrior<f64, Gaussian> for NormalInvGamma {
    // The NIG prior is conjugate to the Gaussian, so the posterior has the
    // same form as the prior.
    type Posterior = Self;
    // Cache for `ln_m`: the prior's log normalizer, ln Z(v, a, b).
    type MCache = f64;
    // Cache for `ln_pp`: the posterior parameters plus the part of the
    // log posterior-predictive that does not depend on the new datum.
    type PpCache = (PosteriorParameters, f64);
    fn empty_stat(&self) -> <Gaussian as HasSuffStat<f64>>::Stat {
        GaussianSuffStat::new()
    }
    fn posterior(&self, x: &DataOrSuffStat<f64, Gaussian>) -> Self {
        // `dos_to_post!` reduces `x` to a sufficient statistic and runs
        // `posterior_from_stat`; `.into()` validates and wraps the result.
        dos_to_post!(self, x).into()
    }
    #[inline]
    fn ln_m_cache(&self) -> Self::MCache {
        // Prior log normalizer; reused across ln_m evaluations.
        ln_z(self.v, self.a, self.b)
    }
    fn ln_m_with_cache(
        &self,
        cache: &Self::MCache,
        x: &DataOrSuffStat<f64, Gaussian>,
    ) -> f64 {
        // ln m(x) = ln Z(posterior) − ln Z(prior) − (n/2)·ln 2π
        let (n, post) = dos_to_post!(# self, x);
        let lnz_n = ln_z(post.v, post.a, post.b);
        (n as f64).mul_add(-HALF_LN_2PI, lnz_n - cache)
    }
    fn ln_pp_cache(&self, x: &DataOrSuffStat<f64, Gaussian>) -> Self::PpCache {
        let params = dos_to_post!(self, x);
        let PosteriorParameters { v, a, b, .. } = params;
        // y-independent part of ln pp(y):
        //   −½·ln(1 + v) + a·ln b + ln Γ(a + ½) − ln Γ(a) − ½·ln 2π
        // (the −½·ln(1 + v) term is ½·ln(v'/v) for the one-step update
        // v' = 1/(1/v + 1)).
        let gamma_ratio = ln_gammafn(a + 0.5) - ln_gammafn(a);
        let z = (-0.5_f64).mul_add(v.ln_1p(), a * b.ln()) + gamma_ratio
            - HALF_LN_2PI;
        (params, z)
    }
    fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 {
        // Perform the single-observation update of (m, v, b) with datum y,
        // then ln pp(y) = z − (a + ½)·ln bn.
        let PosteriorParameters { m, v, a, b } = cache.0;
        let y = *y;
        let v_recip = v.recip();
        // One more observation: 1/vn = 1/v + 1.
        let vn_recip = v_recip + 1.0;
        let mn = v_recip.mul_add(m, y) / vn_recip;
        // bn = b + ½·(m²/v + y² − mn²·(1/vn))
        let bn = 0.5_f64.mul_add(
            (mn * mn).mul_add(-vn_recip, (m * m).mul_add(v_recip, y * y)),
            b,
        );
        (a + 0.5).mul_add(-bn.ln(), cache.1)
    }
}
// Generate the `GewekeTestable` implementation for the
// NormalInvGamma/Gaussian conjugate pair (used by the `geweke` test below).
gaussian_prior_geweke_testable!(NormalInvGamma, Gaussian);
#[cfg(test)]
mod test {
    use super::*;
    use crate::consts::LN_2PI;
    use crate::dist::normal_inv_gamma::NormalInvGammaParameters;
    use crate::test_conjugate_prior;
    const TOL: f64 = 1E-12;
    test_conjugate_prior!(
        f64,
        Gaussian,
        NormalInvGamma,
        NormalInvGamma::new(0.1, 1.2, 0.5, 1.8).unwrap()
    );
    /// Joint-distribution (Geweke) test. The test is stochastic, so we run
    /// it five times and require more than one pass.
    #[test]
    fn geweke() {
        use crate::test::GewekeTester;
        let mut rng = rand::rng();
        let pr = NormalInvGamma::new(0.1, 1.2, 0.5, 1.8).unwrap();
        let n_passes = (0..5)
            .map(|_| {
                let mut tester = GewekeTester::new(pr.clone(), 20);
                tester.run_chains(5_000, 20, &mut rng);
                u8::from(tester.eval(0.025).is_ok())
            })
            .sum::<u8>();
        assert!(n_passes > 1);
    }
    /// Textbook form of the Normal-Inverse-Gamma log density:
    /// ln f(μ, σ²) = ln Z⁻¹ − ½·ln σ² − (μ−m)²/(2vσ²)
    ///               − (a+1)·ln σ² − b/σ²
    fn ln_f_ref(gauss: &Gaussian, nig: &NormalInvGamma) -> f64 {
        let NormalInvGammaParameters { m, v, a, b } = nig.emit_params();
        let mu = gauss.mu();
        let sigma = gauss.sigma();
        let sig2 = sigma * sigma;
        // ln Z⁻¹ = a·ln b − ½·(ln v + ln 2π) − ln Γ(a)
        let lz_inv = a.mul_add(
            b.ln(),
            -(0.5_f64.mul_add(v.ln() + LN_2PI, ln_gammafn(a))),
        );
        // FIX: the quadratic term is −(μ−m)²/(2vσ²), so the first
        // `mul_add` argument must be −(μ − m). The previous `-mu - m`
        // produced −(μ−m)(μ+m)/(2vσ²), which only agrees when m == 0 —
        // exactly the parameterization the old test happened to use.
        (0.5 / (sig2 * v) * (mu - m)).mul_add(
            -(mu - m),
            (a + 1.).mul_add(-sig2.ln(), 0.5_f64.mul_add(-sig2.ln(), lz_inv))
                - b / sig2,
        )
    }
    /// Naive posterior-parameter computation used to cross-check
    /// `posterior_from_stat` and to build reference predictives.
    fn post_params(
        xs: &[f64],
        m: f64,
        v: f64,
        a: f64,
        b: f64,
    ) -> (f64, f64, f64, f64) {
        let n = xs.len() as f64;
        let sum_x: f64 = xs.iter().sum();
        let sum_x_sq: f64 = xs.iter().map(|&x| x * x).sum();
        let v_inv = v.recip();
        let vn_inv = v_inv + n;
        let vn = vn_inv.recip();
        let mn = v_inv.mul_add(m, sum_x) * vn;
        let an = a + n / 2.0;
        let bn = 0.5_f64.mul_add(
            (mn * mn).mul_add(-vn_inv, (m * m).mul_add(v_inv, sum_x_sq)),
            b,
        );
        (mn, vn, an, bn)
    }
    /// Marginal likelihood written directly as the ratio of posterior and
    /// prior normalizers times (2π)^(−n/2).
    fn alternate_ln_marginal(
        xs: &[f64],
        m: f64,
        v: f64,
        a: f64,
        b: f64,
    ) -> f64 {
        let n = xs.len() as f64;
        let (_, vn, an, bn) = post_params(xs, m, v, a, b);
        let numer = 0.5_f64.mul_add(vn.ln(), a * b.ln()) + ln_gammafn(an);
        let denom = (n / 2.0).mul_add(
            LN_2PI,
            0.5_f64.mul_add(v.ln(), an * bn.ln()) + ln_gammafn(a),
        );
        numer - denom
    }
    #[test]
    fn ln_f_vs_reference() {
        use crate::traits::HasDensity;
        // Check both a zero and a nonzero mean: with m == 0 the (μ − m)²
        // term degenerates and cannot catch sign errors in the reference.
        for &(m, v, a, b) in &[(0.0, 1.2, 2.3, 3.4), (0.5, 1.2, 2.3, 3.4)] {
            let nig = NormalInvGamma::new(m, v, a, b).unwrap();
            let mut rng = rand::rng();
            for _ in 0..100 {
                let gauss = nig.draw(&mut rng);
                let ln_f = nig.ln_f(&gauss);
                let reference = ln_f_ref(&gauss, &nig);
                assert::close(reference, ln_f, TOL);
            }
        }
    }
    #[test]
    fn ln_m_vs_reference() {
        let (m, v, a, b) = (0.0, 1.2, 2.3, 3.4);
        let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let reference = alternate_ln_marginal(&xs, m, v, a, b);
        let nig = NormalInvGamma::new(m, v, a, b).unwrap();
        let ln_m = nig.ln_m(&DataOrSuffStat::<f64, Gaussian>::from(&xs));
        assert::close(reference, ln_m, TOL);
    }
    /// Monte Carlo check: ln m(x) ≈ ln mean_θ∼prior f(x | θ).
    #[test]
    fn ln_m_vs_monte_carlo() {
        use crate::misc::LogSumExp;
        use crate::traits::HasDensity;
        let n_samples = 1_000_000;
        let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (m, v, a, b) = (1.0, 2.2, 3.3, 4.4);
        let nig = NormalInvGamma::new(m, v, a, b).unwrap();
        let ln_m = nig.ln_m(&DataOrSuffStat::<f64, Gaussian>::from(&xs));
        let mc_est = {
            nig.sample_stream(&mut rand::rng())
                .take(n_samples)
                .map(|gauss: Gaussian| {
                    xs.iter().map(|x| gauss.ln_f(x)).sum::<f64>()
                })
                .logsumexp()
                - (n_samples as f64).ln()
        };
        assert::close(ln_m, mc_est, 1e-2);
    }
    /// Importance-sampling check of ln m(x) using a wide Gaussian × Gamma
    /// proposal over (μ, σ²).
    #[test]
    fn ln_m_vs_importance() {
        use crate::dist::Gamma;
        use crate::misc::LogSumExp;
        use crate::traits::HasDensity;
        let n_samples = 1_000_000;
        let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (m, v, a, b) = (1.0, 2.2, 3.3, 4.4);
        let nig = NormalInvGamma::new(m, v, a, b).unwrap();
        let ln_m = nig.ln_m(&DataOrSuffStat::<f64, Gaussian>::from(&xs));
        let mc_est = {
            let mut rng = rand::rng();
            let pr_m = Gaussian::new(1.0, 8.0).unwrap();
            let pr_s = Gamma::new(2.0, 0.4).unwrap();
            let ln_fs = (0..n_samples).map(|_| {
                let mu: f64 = pr_m.draw(&mut rng);
                let var: f64 = pr_s.draw(&mut rng);
                let gauss = Gaussian::new(mu, var.sqrt()).unwrap();
                // Importance weight: f(x|θ)·prior(θ) / proposal(θ)
                let ln_f = xs.iter().map(|x| gauss.ln_f(x)).sum::<f64>();
                ln_f + nig.ln_f(&gauss) - pr_m.ln_f(&mu) - pr_s.ln_f(&var)
            });
            ln_fs.logsumexp() - f64::from(n_samples).ln()
        };
        assert::close(ln_m, mc_est, 1e-2);
    }
    /// Monte Carlo check: ln pp(y | x) ≈ ln mean_θ∼posterior f(y | θ).
    #[test]
    fn ln_pp_vs_monte_carlo() {
        use crate::misc::LogSumExp;
        use crate::traits::HasDensity;
        let n_samples = 1_000_000;
        let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let y: f64 = -0.3;
        let (m, v, a, b) = (1.0, 2.2, 3.3, 4.4);
        let nig = NormalInvGamma::new(m, v, a, b).unwrap();
        let post = nig.posterior(&DataOrSuffStat::<f64, Gaussian>::from(&xs));
        let ln_pp = nig.ln_pp(&y, &DataOrSuffStat::<f64, Gaussian>::from(&xs));
        let mc_est = {
            post.sample_stream(&mut rand::rng())
                .take(n_samples)
                .map(|gauss: Gaussian| gauss.ln_f(&y))
                .logsumexp()
                - (n_samples as f64).ln()
        };
        assert::close(ln_pp, mc_est, 1e-2);
    }
    /// With no observed data, the posterior predictive of a single point
    /// equals its marginal likelihood.
    #[test]
    fn ln_pp_vs_ln_m_single() {
        let y: f64 = -0.3;
        let (m, v, a, b) = (0.0, 1.2, 2.3, 3.4);
        let nig = NormalInvGamma::new(m, v, a, b).unwrap();
        let ln_pp = nig.ln_pp(&y, &DataOrSuffStat::from(&vec![]));
        let ln_m = nig.ln_m(&DataOrSuffStat::from(&vec![y]));
        assert::close(ln_pp, ln_m, TOL);
    }
    /// The NIG posterior predictive is a Student's t with 2·an degrees of
    /// freedom, location mn, and scale² = bn·(1 + vn)/an.
    #[test]
    fn ln_pp_vs_t() {
        use crate::dist::StudentsT;
        use crate::traits::HasDensity;
        let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y: f64 = -0.3;
        let (m, v, a, b) = (0.0, 1.2, 2.3, 3.4);
        let (mn, vn, an, bn) = post_params(&xs, m, v, a, b);
        let ln_f_t = {
            let t = StudentsT::new(2.0 * an).unwrap();
            let t_scale = bn * (1.0 + vn) / an;
            let t_shift = mn;
            // Standardize y, then correct the density for the scale change.
            let y_adj = (y - t_shift) / t_scale.sqrt();
            0.5_f64.mul_add(-t_scale.ln(), t.ln_f(&y_adj))
        };
        let ln_pp = {
            let nig = NormalInvGamma::new(m, v, a, b).unwrap();
            nig.ln_pp(&y, &DataOrSuffStat::<f64, Gaussian>::from(&xs))
        };
        assert::close(ln_f_t, ln_pp, TOL);
    }
}