use crate::error::{StatsError, StatsResult as Result};
use scirs2_core::random::{rngs::StdRng, Beta as RandBeta, Distribution, Gamma, Normal, SeedableRng};
use std::f64::consts::PI;
/// Common interface for conjugate hyper-prior distributions.
pub trait HyperPrior: Clone + std::fmt::Debug {
    /// Log normalization constant of the prior.
    ///
    /// NOTE(review): no caller is visible in this file, so the exact sign
    /// convention (log Z vs. a constant-shifted variant) should be confirmed
    /// against call sites before relying on absolute values.
    fn log_norm_const(&self) -> f64;
    /// Short static identifier of the prior family.
    fn name(&self) -> &'static str;
}
/// Normal-Inverse-Gamma conjugate prior for the mean and variance of a
/// univariate normal likelihood:
/// sigma^2 ~ InvGamma(alpha0, beta0), mu | sigma^2 ~ N(mu0, sigma^2 / kappa0).
#[derive(Debug, Clone)]
pub struct NormalInverseGamma {
    /// Prior mean of mu (must be finite).
    pub mu0: f64,
    /// Pseudo-count controlling confidence in `mu0` (> 0).
    pub kappa0: f64,
    /// Inverse-gamma shape parameter (> 0).
    pub alpha0: f64,
    /// Inverse-gamma scale parameter (> 0).
    pub beta0: f64,
}
impl NormalInverseGamma {
    /// Builds a validated Normal-Inverse-Gamma prior.
    ///
    /// # Errors
    /// `DomainError` when `kappa0`, `alpha0`, or `beta0` is not strictly
    /// positive, or when `mu0` is not finite.
    pub fn new(mu0: f64, kappa0: f64, alpha0: f64, beta0: f64) -> Result<Self> {
        if kappa0 <= 0.0 {
            return Err(StatsError::DomainError(format!(
                "kappa0 must be > 0, got {kappa0}"
            )));
        }
        if alpha0 <= 0.0 {
            return Err(StatsError::DomainError(format!(
                "alpha0 must be > 0, got {alpha0}"
            )));
        }
        if beta0 <= 0.0 {
            return Err(StatsError::DomainError(format!(
                "beta0 must be > 0, got {beta0}"
            )));
        }
        if !mu0.is_finite() {
            return Err(StatsError::DomainError(format!(
                "mu0 must be finite, got {mu0}"
            )));
        }
        Ok(Self { mu0, kappa0, alpha0, beta0 })
    }

    /// Conjugate posterior after observing `obs`.
    /// An empty slice returns the prior unchanged.
    pub fn update(&self, obs: &[f64]) -> Result<Self> {
        if obs.is_empty() {
            return Ok(self.clone());
        }
        let count = obs.len() as f64;
        let sample_mean = obs.iter().sum::<f64>() / count;
        let sum_sq_dev: f64 = obs.iter().map(|&x| (x - sample_mean).powi(2)).sum();
        let kappa_post = self.kappa0 + count;
        let mu_post = (self.kappa0 * self.mu0 + count * sample_mean) / kappa_post;
        let alpha_post = self.alpha0 + count / 2.0;
        // beta_n = beta_0 + S/2 + shrinkage term for the prior-data mean gap.
        let beta_post = self.beta0
            + 0.5 * sum_sq_dev
            + 0.5 * (self.kappa0 * count / kappa_post) * (sample_mean - self.mu0).powi(2);
        Self::new(mu_post, kappa_post, alpha_post, beta_post)
    }

    /// Log marginal likelihood (model evidence) of `obs` under this prior.
    /// Returns 0.0 for an empty slice.
    pub fn log_marginal_likelihood(&self, obs: &[f64]) -> Result<f64> {
        if obs.is_empty() {
            return Ok(0.0);
        }
        let posterior = self.update(obs)?;
        let count = obs.len() as f64;
        // Ratio of posterior to prior normalization constants, minus the
        // Gaussian likelihood constant (2*pi)^{n/2}.
        Ok(lgamma(posterior.alpha0) - lgamma(self.alpha0)
            + self.alpha0 * self.beta0.ln()
            - posterior.alpha0 * posterior.beta0.ln()
            + 0.5 * (self.kappa0 / posterior.kappa0).ln()
            - (count / 2.0) * (2.0 * PI).ln())
    }

    /// Prior-predictive density at `x`: a Student-t with 2*alpha0 degrees of
    /// freedom, location `mu0`, scale^2 = beta0 (kappa0 + 1) / (kappa0 alpha0).
    pub fn posterior_predictive_pdf(&self, x: f64) -> f64 {
        let df = 2.0 * self.alpha0;
        let scale_sq = self.beta0 * (self.kappa0 + 1.0) / (self.kappa0 * self.alpha0);
        let scale = scale_sq.sqrt();
        let standardized = (x - self.mu0) / scale;
        // Evaluate in log space for numerical stability, then exponentiate.
        (lgamma((df + 1.0) / 2.0)
            - lgamma(df / 2.0)
            - 0.5 * (df * PI).ln()
            - scale.ln()
            - ((df + 1.0) / 2.0) * (1.0 + standardized * standardized / df).ln())
        .exp()
    }

    /// Draws one `(mu, sigma^2)` pair: the precision tau = 1/sigma^2 from a
    /// Gamma(alpha0, scale = 1/beta0), then mu | sigma^2 from a normal.
    pub fn sample(&self, rng: &mut StdRng) -> Result<(f64, f64)> {
        let precision_dist = Gamma::new(self.alpha0, 1.0 / self.beta0).map_err(|e| {
            StatsError::ComputationError(format!("NIG Gamma sampling error: {e}"))
        })?;
        let tau = precision_dist.sample(rng);
        // Guard against a zero precision draw to avoid division by zero.
        let sigma2 = match tau {
            t if t > 0.0 => 1.0 / t,
            _ => f64::MAX,
        };
        let mean_dist = Normal::new(self.mu0, (sigma2 / self.kappa0).sqrt()).map_err(|e| {
            StatsError::ComputationError(format!("NIG Normal sampling error: {e}"))
        })?;
        Ok((mean_dist.sample(rng), sigma2))
    }

    /// Mode of the inverse-gamma marginal over sigma^2.
    pub fn sigma2_mode(&self) -> f64 {
        self.beta0 / (self.alpha0 + 1.0)
    }

    /// Mean of the inverse-gamma marginal over sigma^2.
    ///
    /// # Errors
    /// `DomainError` unless `alpha0 > 1` (the mean is undefined otherwise).
    pub fn sigma2_mean(&self) -> Result<f64> {
        if self.alpha0 <= 1.0 {
            return Err(StatsError::DomainError(
                "sigma2_mean requires alpha0 > 1".into(),
            ));
        }
        Ok(self.beta0 / (self.alpha0 - 1.0))
    }
}
impl HyperPrior for NormalInverseGamma {
    /// Log normalization constant of the NIG prior,
    /// Z = Gamma(alpha0) * beta0^{-alpha0} * sqrt(2*pi / kappa0):
    ///
    ///   log Z = lgamma(alpha0) - alpha0 ln(beta0) - 0.5 ln(kappa0) + 0.5 ln(2*pi)
    ///
    /// Fix: the previous version flipped the signs of the kappa0 and 2*pi
    /// terms. The corrected form is consistent with
    /// `log_marginal_likelihood`, whose +0.5 ln(kappa0 / kappa_n) term
    /// requires log Z to carry -0.5 ln(kappa).
    fn log_norm_const(&self) -> f64 {
        lgamma(self.alpha0)
            - self.alpha0 * self.beta0.ln()
            - 0.5 * self.kappa0.ln()
            + 0.5 * (2.0 * PI).ln()
    }

    fn name(&self) -> &'static str {
        "NormalInverseGamma"
    }
}
/// Normal-Inverse-Wishart conjugate prior for the mean vector and covariance
/// matrix of a multivariate normal likelihood:
/// Sigma ~ InvWishart(psi0, nu0), mu | Sigma ~ N(mu0, Sigma / kappa0).
#[derive(Debug, Clone)]
pub struct NormalInverseWishart {
    /// Prior mean vector of mu (length `dim`).
    pub mu0: Vec<f64>,
    /// Pseudo-count controlling confidence in `mu0` (> 0).
    pub kappa0: f64,
    /// Degrees of freedom of the inverse-Wishart (>= `dim`, enforced in `new`).
    pub nu0: f64,
    /// `dim` x `dim` scale matrix of the inverse-Wishart.
    pub psi0: Vec<Vec<f64>>,
    /// Dimensionality, cached from `mu0.len()`.
    pub dim: usize,
}
impl NormalInverseWishart {
    /// Builds a validated NIW prior; the dimension is inferred from `mu0`.
    ///
    /// # Errors
    /// `DomainError` for empty `mu0`, non-positive `kappa0`, or `nu0 < dim`;
    /// `DimensionMismatch` when `psi0` is not dim x dim.
    pub fn new(mu0: Vec<f64>, kappa0: f64, nu0: f64, psi0: Vec<Vec<f64>>) -> Result<Self> {
        let dim = mu0.len();
        if dim == 0 {
            return Err(StatsError::DomainError(
                "mu0 must be non-empty".into(),
            ));
        }
        if kappa0 <= 0.0 {
            return Err(StatsError::DomainError(format!(
                "kappa0 must be > 0, got {kappa0}"
            )));
        }
        if nu0 < dim as f64 {
            return Err(StatsError::DomainError(format!(
                "nu0 ({nu0}) must be >= dim ({dim})"
            )));
        }
        let is_square = psi0.len() == dim && psi0.iter().all(|row| row.len() == dim);
        if !is_square {
            return Err(StatsError::DimensionMismatch(format!(
                "psi0 must be {dim}×{dim}"
            )));
        }
        Ok(Self { mu0, kappa0, nu0, psi0, dim })
    }

    /// Conjugate posterior after observing `obs` (each row a d-vector).
    /// An empty slice returns the prior unchanged.
    pub fn update(&self, obs: &[Vec<f64>]) -> Result<Self> {
        if obs.is_empty() {
            return Ok(self.clone());
        }
        let d = self.dim;
        // Reject ragged input up front, reporting the first offending row.
        for (i, row) in obs.iter().enumerate() {
            if row.len() != d {
                return Err(StatsError::DimensionMismatch(format!(
                    "obs[{i}] has length {}, expected {d}",
                    row.len()
                )));
            }
        }
        let count = obs.len() as f64;
        // Sample mean vector.
        let mut sample_mean = vec![0.0_f64; d];
        for row in obs {
            for (acc, &v) in sample_mean.iter_mut().zip(row) {
                *acc += v;
            }
        }
        for m in &mut sample_mean {
            *m /= count;
        }
        let kappa_post = self.kappa0 + count;
        let nu_post = self.nu0 + count;
        let mu_post: Vec<f64> = (0..d)
            .map(|k| (self.kappa0 * self.mu0[k] + count * sample_mean[k]) / kappa_post)
            .collect();
        // Psi_n = Psi_0 + scatter matrix + prior-mean shrinkage term.
        let mut psi_post = self.psi0.clone();
        for row in obs {
            for i in 0..d {
                for j in 0..d {
                    psi_post[i][j] += (row[i] - sample_mean[i]) * (row[j] - sample_mean[j]);
                }
            }
        }
        let shrink = self.kappa0 * count / kappa_post;
        for i in 0..d {
            for j in 0..d {
                psi_post[i][j] +=
                    shrink * (sample_mean[i] - self.mu0[i]) * (sample_mean[j] - self.mu0[j]);
            }
        }
        Self::new(mu_post, kappa_post, nu_post, psi_post)
    }

    /// Log marginal likelihood of `obs` under this prior; 0.0 when empty.
    pub fn log_marginal_likelihood(&self, obs: &[Vec<f64>]) -> Result<f64> {
        if obs.is_empty() {
            return Ok(0.0);
        }
        let count = obs.len() as f64;
        let d = self.dim as f64;
        let posterior = self.update(obs)?;
        let log_z_prior = log_niw_norm_const(&self.psi0, self.nu0, self.kappa0, self.dim)?;
        let log_z_post =
            log_niw_norm_const(&posterior.psi0, posterior.nu0, posterior.kappa0, self.dim)?;
        Ok(log_z_post - log_z_prior - count * d / 2.0 * PI.ln())
    }

    /// Borrowed view of the prior mean of mu.
    pub fn mean_of_mu(&self) -> &[f64] {
        &self.mu0
    }
}
impl HyperPrior for NormalInverseWishart {
    /// Log normalization constant; maps a non-positive-definite `psi0`
    /// (Cholesky failure) to negative infinity instead of erroring.
    fn log_norm_const(&self) -> f64 {
        match log_niw_norm_const(&self.psi0, self.nu0, self.kappa0, self.dim) {
            Ok(v) => v,
            Err(_) => f64::NEG_INFINITY,
        }
    }

    fn name(&self) -> &'static str {
        "NormalInverseWishart"
    }
}
/// Parameter-dependent part of the log normalization constant of a
/// Normal-Inverse-Wishart prior with scale matrix `psi`, degrees of freedom
/// `nu`, and mean pseudo-count `kappa`:
///
///   ln Gamma_d(nu/2) - (nu/2) ln|psi| - (d/2) ln(kappa)
///
/// The terms (d/2) ln(2*pi) and (nu d/2) ln 2 of the full constant are
/// deliberately omitted: in `log_marginal_likelihood` the 2*pi term cancels
/// in the posterior-minus-prior difference and the ln 2 terms cancel against
/// the (2*pi)^{nd/2} likelihood constant, leaving the -n d/2 ln(pi) term
/// used there.
///
/// Fix: the kappa term previously entered with a `+` sign, which flipped the
/// (d/2) ln(kappa0/kappa_n) contribution to the NIW marginal likelihood —
/// inconsistent with the (correct) univariate NIG marginal, whose kappa term
/// is +0.5 ln(kappa0/kappa_n).
fn log_niw_norm_const(psi: &[Vec<f64>], nu: f64, kappa: f64, d: usize) -> Result<f64> {
    let log_det = log_det_chol(psi, d)?;
    let log_gamma_d = multivariate_lgamma(nu / 2.0, d);
    Ok(log_gamma_d - nu / 2.0 * log_det - d as f64 / 2.0 * kappa.ln())
}
/// Multivariate log-gamma function:
/// ln Gamma_d(x) = (d(d-1)/4) ln(pi) + sum_{j=1..d} ln Gamma(x + (1-j)/2).
fn multivariate_lgamma(x: f64, d: usize) -> f64 {
    let base = (d * (d - 1)) as f64 / 4.0 * PI.ln();
    (1..=d).fold(base, |acc, j| acc + lgamma(x + (1.0 - j as f64) / 2.0))
}
/// Log-determinant of a symmetric positive-definite d x d matrix via a
/// Cholesky factorization m = L L^T: ln|m| = 2 * sum_i ln(L[i][i]).
///
/// # Errors
/// `ComputationError` when a non-positive pivot is encountered, i.e. the
/// matrix is not positive definite.
fn log_det_chol(m: &[Vec<f64>], d: usize) -> Result<f64> {
    let mut chol = vec![vec![0.0_f64; d]; d];
    for row in 0..d {
        for col in 0..=row {
            let mut acc = m[row][col];
            for k in 0..col {
                acc -= chol[row][k] * chol[col][k];
            }
            if row == col {
                // Diagonal pivot: must be strictly positive for a PD matrix.
                if acc <= 0.0 {
                    return Err(StatsError::ComputationError(
                        "Matrix is not positive definite".into(),
                    ));
                }
                chol[row][col] = acc.sqrt();
            } else {
                chol[row][col] = acc / chol[col][col];
            }
        }
    }
    Ok((0..d).map(|i| chol[i][i].ln()).sum::<f64>() * 2.0)
}
/// Natural log of |Gamma(x)|.
/// Thin wrapper so callers are insulated from the particular approximation
/// used (currently the Lanczos series below).
pub(crate) fn lgamma(x: f64) -> f64 {
    lanczos_lgamma(x)
}
/// Lanczos approximation (g = 7, 9 coefficients) of ln Gamma(x).
///
/// For x < 0.5 the reflection formula
/// ln Gamma(x) = ln(pi) - ln|sin(pi x)| - ln Gamma(1 - x)
/// is applied first (this yields ln of the absolute value for negative x).
fn lanczos_lgamma(x: f64) -> f64 {
    if x < 0.5 {
        // Reflection keeps the series argument at >= 0.5 where it converges.
        return PI.ln() - (PI * x).sin().abs().ln() - lanczos_lgamma(1.0 - x);
    }
    const G: f64 = 7.0;
    const COEFFS: [f64; 9] = [
        0.999_999_999_999_809_3_f64,
        676.520_368_121_885_1,
        -1_259.139_216_722_403,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_312_4e-7,
    ];
    let z = x - 1.0;
    let mut series = COEFFS[0];
    for (idx, &coeff) in COEFFS[1..].iter().enumerate() {
        series += coeff / (z + (idx as f64 + 1.0));
    }
    let t = z + G + 0.5;
    (2.0 * PI).sqrt().ln() + series.ln() + (z + 0.5) * t.ln() - t
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nig_construction() {
        let prior = NormalInverseGamma::new(0.0, 1.0, 2.0, 1.0).unwrap();
        assert_eq!(prior.mu0, 0.0);
        assert_eq!(prior.kappa0, 1.0);
        assert_eq!(prior.alpha0, 2.0);
        assert_eq!(prior.beta0, 1.0);
    }

    #[test]
    fn test_nig_invalid() {
        // kappa0, alpha0, beta0 must be positive; mu0 must be finite.
        assert!(NormalInverseGamma::new(0.0, 0.0, 2.0, 1.0).is_err());
        assert!(NormalInverseGamma::new(0.0, 1.0, 0.0, 1.0).is_err());
        assert!(NormalInverseGamma::new(0.0, 1.0, 2.0, 0.0).is_err());
        assert!(NormalInverseGamma::new(f64::NAN, 1.0, 2.0, 1.0).is_err());
    }

    #[test]
    fn test_nig_update() {
        let prior = NormalInverseGamma::new(0.0, 1.0, 1.0, 1.0).unwrap();
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        let post = prior.update(&data).unwrap();
        assert!((post.kappa0 - 6.0).abs() < 1e-10);
        assert!((post.mu0 - 2.5).abs() < 1e-10);
        assert!((post.alpha0 - 3.5).abs() < 1e-10);
    }

    #[test]
    fn test_nig_update_empty() {
        // No data: posterior must equal the prior.
        let prior = NormalInverseGamma::new(0.0, 1.0, 1.0, 1.0).unwrap();
        let post = prior.update(&[]).unwrap();
        assert_eq!(post.mu0, prior.mu0);
        assert_eq!(post.kappa0, prior.kappa0);
    }

    #[test]
    fn test_nig_sample() {
        let prior = NormalInverseGamma::new(0.0, 1.0, 3.0, 2.0).unwrap();
        let mut rng = StdRng::seed_from_u64(42);
        let (mu, sigma2) = prior.sample(&mut rng).unwrap();
        assert!(mu.is_finite());
        assert!(sigma2 > 0.0);
    }

    #[test]
    fn test_nig_posterior_predictive() {
        let prior = NormalInverseGamma::new(0.0, 1.0, 2.0, 1.0).unwrap();
        // The predictive density peaks near the prior mean and decays in the tails.
        let near = prior.posterior_predictive_pdf(0.0);
        let far = prior.posterior_predictive_pdf(10.0);
        assert!(near > far);
        assert!(near > 0.0);
    }

    #[test]
    fn test_nig_log_marginal_likelihood() {
        let prior = NormalInverseGamma::new(0.0, 1.0, 1.0, 1.0).unwrap();
        let near_zero = [0.0, 0.1, -0.1, 0.2, -0.2];
        let lml = prior.log_marginal_likelihood(&near_zero).unwrap();
        assert!(lml.is_finite());
        // Data far from the prior mean must have lower evidence.
        let far_away = [10.0, 10.1, 9.9, 10.2, 9.8];
        let lml_far = prior.log_marginal_likelihood(&far_away).unwrap();
        assert!(lml > lml_far);
    }

    #[test]
    fn test_niw_construction() {
        let niw = NormalInverseWishart::new(
            vec![0.0, 0.0],
            1.0,
            3.0,
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
        )
        .unwrap();
        assert_eq!(niw.dim, 2);
    }

    #[test]
    fn test_niw_invalid() {
        let mu0 = vec![0.0, 0.0];
        let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        // nu0 below the dimension.
        assert!(NormalInverseWishart::new(mu0.clone(), 1.0, 1.0, identity.clone()).is_err());
        // Non-positive kappa0.
        assert!(NormalInverseWishart::new(mu0.clone(), 0.0, 3.0, identity.clone()).is_err());
        // psi0 has the wrong shape.
        assert!(NormalInverseWishart::new(mu0, 1.0, 3.0, vec![vec![1.0]]).is_err());
    }

    #[test]
    fn test_niw_update() {
        let prior = NormalInverseWishart::new(
            vec![0.0, 0.0],
            1.0,
            3.0,
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
        )
        .unwrap();
        let data = vec![
            vec![1.0, 0.5],
            vec![2.0, 1.5],
            vec![-1.0, 0.0],
        ];
        let post = prior.update(&data).unwrap();
        assert!((post.kappa0 - 4.0).abs() < 1e-10);
        assert!((post.nu0 - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_niw_log_marginal_likelihood() {
        let prior = NormalInverseWishart::new(
            vec![0.0, 0.0],
            1.0,
            4.0,
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
        )
        .unwrap();
        let data = vec![
            vec![0.1, -0.1],
            vec![-0.1, 0.1],
            vec![0.0, 0.0],
        ];
        let lml = prior.log_marginal_likelihood(&data).unwrap();
        assert!(lml.is_finite());
    }

    #[test]
    fn test_lgamma() {
        // ln Gamma(1) = ln Gamma(2) = 0; ln Gamma(1/2) = ln sqrt(pi).
        assert!((lgamma(1.0) - 0.0).abs() < 1e-10);
        assert!((lgamma(2.0) - 0.0).abs() < 1e-10);
        assert!((lgamma(0.5) - (PI.sqrt().ln())).abs() < 1e-6);
    }
}