use crate::array::Array;
use crate::new_modules::probabilistic::{
validate_positive, validate_probability, BetaDistribution, DirichletDistribution,
GammaDistribution, ProbabilisticError, Result,
};
use scirs2_core::random::{thread_rng, Rng};
use std::f64::consts::PI;
/// Conjugate model pairing a Beta prior with Binomial observations.
///
/// The stored prior is combined with success/failure counts to produce a
/// Beta posterior.
#[derive(Debug, Clone)]
pub struct BetaBinomialConjugate {
    /// Beta prior that `update` combines with observed counts.
    prior: BetaDistribution,
}

impl BetaBinomialConjugate {
    /// Creates the model from the prior's `alpha` and `beta` parameters.
    ///
    /// # Errors
    ///
    /// Returns whatever error `BetaDistribution::new` reports for the given
    /// parameters.
    pub fn new(alpha: f64, beta: f64) -> Result<Self> {
        BetaDistribution::new(alpha, beta).map(|prior| Self { prior })
    }

    /// Returns the posterior after observing `n_successes` out of `n_trials`.
    ///
    /// The posterior parameters are `alpha + n_successes` and
    /// `beta + (n_trials - n_successes)`.
    ///
    /// # Errors
    ///
    /// Returns `ProbabilisticError::InvalidParameter` when `n_successes`
    /// exceeds `n_trials`, and propagates any error from
    /// `BetaDistribution::new`.
    pub fn update(&self, n_successes: usize, n_trials: usize) -> Result<BetaDistribution> {
        // `checked_sub` yields `None` exactly when n_successes > n_trials.
        let n_failures = match n_trials.checked_sub(n_successes) {
            Some(failures) => failures,
            None => {
                return Err(ProbabilisticError::InvalidParameter {
                    parameter: "n_successes".to_string(),
                    message: "n_successes cannot exceed n_trials".to_string(),
                })
            }
        };
        BetaDistribution::new(
            self.prior.alpha() + n_successes as f64,
            self.prior.beta() + n_failures as f64,
        )
    }

    /// Returns the mean of the posterior produced by [`Self::update`] for
    /// the same counts.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`Self::update`].
    pub fn posterior_predictive(&self, n_successes: usize, n_trials: usize) -> Result<f64> {
        self.update(n_successes, n_trials)
            .map(|posterior| posterior.mean())
    }
}
#[derive(Debug, Clone)]
pub struct GammaPoissonConjugate {
prior: GammaDistribution,
}
impl GammaPoissonConjugate {
pub fn new(alpha: f64, beta: f64) -> Result<Self> {
let prior = GammaDistribution::new(alpha, beta)?;
Ok(Self { prior })
}
pub fn update(&self, data: &[f64]) -> Result<GammaDistribution> {
let sum_x: f64 = data.iter().sum();
let n = data.len() as f64;
let alpha_post = self.prior.alpha() + sum_x;
let beta_post = self.prior.beta() + n;
GammaDistribution::new(alpha_post, beta_post)
}
pub fn posterior_predictive_mean(&self, data: &[f64]) -> Result<f64> {
let posterior = self.update(data)?;
Ok(posterior.mean())
}
}
/// Conjugate model for a Normal mean with a Normal prior and a fixed
/// likelihood variance.
#[derive(Debug, Clone)]
pub struct NormalNormalConjugate {
    /// Mean of the prior on the unknown mean.
    prior_mean: f64,
    /// Variance of the prior on the unknown mean (validated positive).
    prior_variance: f64,
    /// Variance of each observation (validated positive).
    likelihood_variance: f64,
}

impl NormalNormalConjugate {
    /// Creates the model after checking that both variances are positive.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_positive` when `prior_variance`
    /// or `likelihood_variance` fails validation.
    pub fn new(prior_mean: f64, prior_variance: f64, likelihood_variance: f64) -> Result<Self> {
        validate_positive(prior_variance, "prior_variance")?;
        validate_positive(likelihood_variance, "likelihood_variance")?;
        let model = Self {
            prior_mean,
            prior_variance,
            likelihood_variance,
        };
        Ok(model)
    }

    /// Returns `(posterior_mean, posterior_variance)` after observing `data`.
    ///
    /// Precisions (inverse variances) add: the posterior precision is
    /// `1 / prior_variance + n / likelihood_variance`, and the posterior mean
    /// is the precision-weighted average of the prior mean and sample mean.
    /// With no data the prior mean and variance are returned unchanged.
    pub fn update(&self, data: &[f64]) -> Result<(f64, f64)> {
        let count = data.len();
        if count == 0 {
            return Ok((self.prior_mean, self.prior_variance));
        }
        let count = count as f64;
        let total: f64 = data.iter().sum();
        let sample_mean = total / count;

        let prior_precision = 1.0 / self.prior_variance;
        let data_precision = count / self.likelihood_variance;
        let posterior_precision = prior_precision + data_precision;

        let weighted_sum = prior_precision * self.prior_mean + data_precision * sample_mean;
        Ok((weighted_sum / posterior_precision, 1.0 / posterior_precision))
    }

    /// Returns the interval `mean ± 1.96 * sd` of the posterior from
    /// [`Self::update`].
    pub fn credible_interval_95(&self, data: &[f64]) -> Result<(f64, f64)> {
        // Multiplier applied to the posterior standard deviation.
        const Z_95: f64 = 1.96;
        let (center, variance) = self.update(data)?;
        let half_width = Z_95 * variance.sqrt();
        Ok((center - half_width, center + half_width))
    }
}
/// Conjugate model pairing a Dirichlet prior with Multinomial counts.
#[derive(Debug, Clone)]
pub struct DirichletMultinomialConjugate {
    /// Dirichlet prior that `update` combines with observed counts.
    prior: DirichletDistribution,
}

impl DirichletMultinomialConjugate {
    /// Creates the model from the prior concentration parameters `alpha`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `DirichletDistribution::new` reports.
    pub fn new(alpha: Vec<f64>) -> Result<Self> {
        DirichletDistribution::new(alpha).map(|prior| Self { prior })
    }

    /// Returns the posterior obtained by adding each category count to the
    /// matching prior concentration.
    ///
    /// # Errors
    ///
    /// Returns `ProbabilisticError::DimensionMismatch` when `counts` does not
    /// have one entry per prior concentration, and propagates any error from
    /// `DirichletDistribution::new`.
    pub fn update(&self, counts: &[usize]) -> Result<DirichletDistribution> {
        let n_categories = self.prior.alpha().len();
        if counts.len() != n_categories {
            return Err(ProbabilisticError::DimensionMismatch {
                expected: vec![n_categories],
                actual: vec![counts.len()],
                operation: "Dirichlet-Multinomial update".to_string(),
            });
        }
        let mut alpha_post = self.prior.alpha().clone();
        // Lengths were checked above, so the zip visits every category.
        for (concentration, &observed) in alpha_post.iter_mut().zip(counts) {
            *concentration += observed as f64;
        }
        DirichletDistribution::new(alpha_post)
    }

    /// Returns the mean of the posterior produced by [`Self::update`] for
    /// the same counts.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`Self::update`].
    pub fn posterior_predictive(&self, counts: &[usize]) -> Result<Vec<f64>> {
        self.update(counts).map(|posterior| posterior.mean())
    }
}
/// Bayesian Information Criterion: `-2 * log_likelihood + k * ln(n)`.
///
/// `n_parameters` is `k` and `n_observations` is `n`. No validation is
/// performed; `n_observations == 0` makes `ln(n)` negative infinity.
pub fn bic(log_likelihood: f64, n_parameters: usize, n_observations: usize) -> f64 {
    let k = n_parameters as f64;
    let log_n = (n_observations as f64).ln();
    let fit_term = -2.0 * log_likelihood;
    fit_term + k * log_n
}
/// Akaike Information Criterion: `-2 * log_likelihood + 2 * k`, where `k`
/// is `n_parameters`.
pub fn aic(log_likelihood: f64, n_parameters: usize) -> f64 {
    let fit_term = -2.0 * log_likelihood;
    let penalty = 2.0 * (n_parameters as f64);
    fit_term + penalty
}
/// Output of the `dic` function (Deviance Information Criterion).
#[derive(Debug, Clone)]
pub struct DICResult {
/// The criterion itself; `dic` computes it as `d_bar + p_d`.
pub dic: f64,
/// Mean deviance: the average of `-2 * log_likelihood` over the samples.
pub d_bar: f64,
/// `d_bar` minus the deviance at the supplied point estimate
/// (`-2 * log_likelihood_at_mean`).
pub p_d: f64,
}
/// Computes the Deviance Information Criterion from posterior
/// log-likelihood samples.
///
/// * `log_likelihood_samples` - log-likelihood evaluated at each sample.
/// * `log_likelihood_at_mean` - log-likelihood evaluated at the point
///   estimate; its deviance is subtracted from the mean deviance.
///
/// The result holds `d_bar` (mean of `-2 * ll`), `p_d = d_bar - d_hat` with
/// `d_hat = -2 * log_likelihood_at_mean`, and `dic = d_bar + p_d`.
///
/// # Errors
///
/// Returns `ProbabilisticError::InvalidParameter` when
/// `log_likelihood_samples` is empty.
pub fn dic(log_likelihood_samples: &[f64], log_likelihood_at_mean: f64) -> Result<DICResult> {
    let n_draws = log_likelihood_samples.len();
    if n_draws == 0 {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: "log_likelihood_samples".to_string(),
            message: "samples cannot be empty".to_string(),
        });
    }

    // Deviances are summed in input order.
    let deviance_total: f64 = log_likelihood_samples.iter().map(|&ll| -2.0 * ll).sum();
    let d_bar = deviance_total / n_draws as f64;
    let d_hat = -2.0 * log_likelihood_at_mean;
    let p_d = d_bar - d_hat;

    Ok(DICResult {
        dic: d_bar + p_d,
        d_bar,
        p_d,
    })
}
/// Output of the `waic` function (Widely Applicable Information Criterion).
#[derive(Debug, Clone)]
pub struct WAICResult {
/// The criterion itself; `waic` computes it as `-2 * (lppd - p_waic)`.
pub waic: f64,
/// Sum over observations of the log of the mean likelihood across samples.
pub lppd: f64,
/// Sum over observations of the variance of the log-likelihood across samples.
pub p_waic: f64,
}
/// Computes the Widely Applicable Information Criterion (WAIC).
///
/// `pointwise_log_likelihood[s][i]` is the log-likelihood of observation `i`
/// under posterior sample `s`; every row must have the same length.
///
/// For each observation the function accumulates
/// * `lppd`: the log of the mean likelihood across samples, evaluated with
///   the log-sum-exp trick so that very negative log-likelihoods do not
///   underflow to `ln(0)`;
/// * `p_waic`: the variance of the log-likelihood across samples, using the
///   population form (divisor = number of samples).
///
/// The returned `waic` is `-2 * (lppd - p_waic)`.
///
/// # Errors
///
/// * `ProbabilisticError::InvalidParameter` when there are no samples.
/// * `ProbabilisticError::DimensionMismatch` when a row's length differs from
///   the first row's (previously this caused an out-of-bounds panic or
///   silently ignored trailing entries).
pub fn waic(pointwise_log_likelihood: &[Vec<f64>]) -> Result<WAICResult> {
    if pointwise_log_likelihood.is_empty() {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: "pointwise_log_likelihood".to_string(),
            message: "samples cannot be empty".to_string(),
        });
    }
    let n_samples = pointwise_log_likelihood.len();
    let n_obs = pointwise_log_likelihood[0].len();

    // Reject ragged input up front so the indexing below cannot panic.
    if let Some(ragged) = pointwise_log_likelihood
        .iter()
        .find(|draw| draw.len() != n_obs)
    {
        return Err(ProbabilisticError::DimensionMismatch {
            expected: vec![n_obs],
            actual: vec![ragged.len()],
            operation: "WAIC computation".to_string(),
        });
    }

    let n = n_samples as f64;
    let mut lppd = 0.0;
    let mut p_waic = 0.0;
    for i in 0..n_obs {
        let column = || pointwise_log_likelihood.iter().map(move |draw| draw[i]);

        // log(mean(exp(ll))) = max + log(mean(exp(ll - max))).
        let max_ll = column().fold(f64::NEG_INFINITY, f64::max);
        let log_mean_likelihood = if max_ll.is_finite() {
            let shifted_sum: f64 = column().map(|ll| (ll - max_ll).exp()).sum();
            max_ll + (shifted_sum / n).ln()
        } else {
            // Degenerate column (infinite or all-NaN values): the shift would
            // produce `inf - inf`, so use the direct formula instead.
            let direct_sum: f64 = column().map(f64::exp).sum();
            (direct_sum / n).ln()
        };
        lppd += log_mean_likelihood;

        let mean_ll = column().sum::<f64>() / n;
        let var_ll = column().map(|ll| (ll - mean_ll).powi(2)).sum::<f64>() / n;
        p_waic += var_ll;
    }

    let waic = -2.0 * (lppd - p_waic);
    Ok(WAICResult { waic, lppd, p_waic })
}
/// Returns an equal-tailed interval `(lower, upper)` from `samples`.
///
/// After sorting, `lower` is the element at index `floor(alpha / 2 * n)` and
/// `upper` the element at index `ceil((1 - alpha / 2) * n)`, each clamped to
/// the last valid index.
///
/// Samples are ordered with `f64::total_cmp`, which is a total order even in
/// the presence of NaN (NaNs sort to the extremes according to their sign).
/// The previous comparator treated NaN as equal to everything, which is not
/// a consistent order and leaves the sort result unspecified.
///
/// # Errors
///
/// * Propagates the error from `validate_probability` for `alpha`.
/// * `ProbabilisticError::InvalidParameter` when `samples` is empty.
pub fn equal_tailed_interval(samples: &[f64], alpha: f64) -> Result<(f64, f64)> {
    validate_probability(alpha, "alpha")?;
    if samples.is_empty() {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: "samples".to_string(),
            message: "samples cannot be empty".to_string(),
        });
    }
    let mut sorted = samples.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);

    let n = sorted.len();
    let last = n - 1;
    let tail = alpha / 2.0;
    let lower_idx = (tail * n as f64).floor() as usize;
    let upper_idx = ((1.0 - tail) * n as f64).ceil() as usize;
    Ok((sorted[lower_idx.min(last)], sorted[upper_idx.min(last)]))
}
/// Returns the narrowest interval `(lower, upper)` spanning
/// `ceil((1 - alpha) * n)` consecutive sorted samples.
///
/// When that count covers all `n` samples, the full range is returned. When
/// several windows tie for the smallest width, the lowest one is kept.
///
/// The window size is clamped to at least one sample: with `alpha == 1.0`
/// the raw size is zero, which previously underflowed the upper index and
/// panicked. In that case the result is the degenerate interval at the
/// smallest sample.
///
/// Samples are ordered with `f64::total_cmp`, a total order even when NaN is
/// present (NaNs sort to the extremes according to their sign).
///
/// # Errors
///
/// * Propagates the error from `validate_probability` for `alpha`.
/// * `ProbabilisticError::InvalidParameter` when `samples` is empty.
pub fn hpd_interval(samples: &[f64], alpha: f64) -> Result<(f64, f64)> {
    validate_probability(alpha, "alpha")?;
    if samples.is_empty() {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: "samples".to_string(),
            message: "samples cannot be empty".to_string(),
        });
    }
    let mut sorted = samples.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);

    let n = sorted.len();
    // At least one sample per window; guards the `alpha == 1.0` case.
    let interval_size = (((1.0 - alpha) * n as f64).ceil() as usize).max(1);
    if interval_size >= n {
        return Ok((sorted[0], sorted[n - 1]));
    }

    let mut min_width = f64::INFINITY;
    let mut best = (sorted[0], sorted[n - 1]);
    for window in sorted.windows(interval_size) {
        let lower = window[0];
        let upper = window[interval_size - 1];
        let width = upper - lower;
        // Strict `<` keeps the first of equally narrow windows and skips NaN widths.
        if width < min_width {
            min_width = width;
            best = (lower, upper);
        }
    }
    Ok(best)
}
/// Bayes factor of model 1 over model 0, computed as the exponential of the
/// difference of their log marginal likelihoods.
pub fn bayes_factor(log_marginal_likelihood_m1: f64, log_marginal_likelihood_m0: f64) -> f64 {
    let log_ratio = log_marginal_likelihood_m1 - log_marginal_likelihood_m0;
    log_ratio.exp()
}
/// Harmonic-mean estimate of the log marginal likelihood.
///
/// Returns `ln(n / sum_i exp(-ll_i))` for the `n` supplied log-likelihood
/// samples. The sum is evaluated with the log-sum-exp trick: a direct
/// `exp(-ll)` overflows once a log-likelihood falls below roughly `-709`,
/// which made the previous implementation return negative infinity for
/// ordinary inputs.
///
/// # Errors
///
/// Returns `ProbabilisticError::InvalidParameter` when
/// `log_likelihood_samples` is empty.
pub fn harmonic_mean_marginal_likelihood(log_likelihood_samples: &[f64]) -> Result<f64> {
    if log_likelihood_samples.is_empty() {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: "log_likelihood_samples".to_string(),
            message: "samples cannot be empty".to_string(),
        });
    }
    let n = log_likelihood_samples.len() as f64;

    // Largest exponent among the `-ll` terms; used as the log-sum-exp shift.
    let max_neg_ll = log_likelihood_samples
        .iter()
        .map(|&ll| -ll)
        .fold(f64::NEG_INFINITY, f64::max);

    if !max_neg_ll.is_finite() {
        // Degenerate input (infinite or all-NaN values): shifting would yield
        // `inf - inf`, so evaluate the direct formula instead.
        let sum_inv: f64 = log_likelihood_samples.iter().map(|&ll| (-ll).exp()).sum();
        return Ok((n / sum_inv).ln());
    }

    let shifted_sum: f64 = log_likelihood_samples
        .iter()
        .map(|&ll| (-ll - max_neg_ll).exp())
        .sum();
    let log_sum_inv = max_neg_ll + shifted_sum.ln();
    Ok(n.ln() - log_sum_inv)
}
/// Posterior predictive p-value: the fraction of replicated test statistics
/// that are greater than or equal to the observed one.
///
/// # Errors
///
/// Returns `ProbabilisticError::InvalidParameter` when
/// `test_statistic_replicated` is empty.
pub fn posterior_predictive_pvalue(
    test_statistic_observed: f64,
    test_statistic_replicated: &[f64],
) -> Result<f64> {
    let n_replicates = test_statistic_replicated.len();
    if n_replicates == 0 {
        return Err(ProbabilisticError::InvalidParameter {
            parameter: "test_statistic_replicated".to_string(),
            message: "replicated statistics cannot be empty".to_string(),
        });
    }

    let mut n_at_least = 0usize;
    for &replicate in test_statistic_replicated {
        if replicate >= test_statistic_observed {
            n_at_least += 1;
        }
    }
    Ok(n_at_least as f64 / n_replicates as f64)
}
// Unit tests for the conjugate models, information criteria, interval
// estimators and model-comparison helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
// A Beta(1, 1) prior updated with 7 successes in 10 trials must give
// Beta(8, 4), whose mean is 2/3.
#[test]
fn test_beta_binomial_conjugate() {
let prior = BetaBinomialConjugate::new(1.0, 1.0)
.expect("test: valid Beta-Binomial prior parameters");
let posterior = prior.update(7, 10).expect("test: valid posterior update");
assert_relative_eq!(posterior.alpha(), 8.0, epsilon = 1e-10);
assert_relative_eq!(posterior.beta(), 4.0, epsilon = 1e-10);
assert_relative_eq!(posterior.mean(), 2.0 / 3.0, epsilon = 1e-10);
}
// A Gamma(1, 1) prior updated with five observations summing to 14 must
// give alpha = 1 + 14 and beta = 1 + 5.
#[test]
fn test_gamma_poisson_conjugate() {
let prior = GammaPoissonConjugate::new(1.0, 1.0)
.expect("test: valid Gamma-Poisson prior parameters");
let data = vec![2.0, 3.0, 4.0, 3.0, 2.0];
let posterior = prior
.update(&data)
.expect("test: valid Gamma-Poisson posterior update");
assert_relative_eq!(posterior.alpha(), 15.0, epsilon = 1e-10);
assert_relative_eq!(posterior.beta(), 6.0, epsilon = 1e-10);
}
// With a prior mean of 0 and data centred near 2, the posterior mean must
// lie strictly between them and the posterior variance must shrink below
// the prior variance of 1.
#[test]
fn test_normal_normal_conjugate() {
let prior = NormalNormalConjugate::new(0.0, 1.0, 1.0)
.expect("test: valid Normal-Normal prior parameters");
let data = vec![1.8, 2.0, 2.2, 1.9, 2.1];
let (mean_post, var_post) = prior
.update(&data)
.expect("test: valid Normal-Normal posterior update");
assert!(mean_post > 0.0 && mean_post < 2.0);
assert!(var_post < 1.0);
}
// Each posterior concentration must equal the prior concentration (1) plus
// the matching category count.
#[test]
fn test_dirichlet_multinomial_conjugate() {
let prior = DirichletMultinomialConjugate::new(vec![1.0, 1.0, 1.0])
.expect("test: valid Dirichlet-Multinomial prior parameters");
let counts = vec![10, 20, 15];
let posterior = prior
.update(&counts)
.expect("test: valid Dirichlet-Multinomial posterior update");
assert_relative_eq!(posterior.alpha()[0], 11.0, epsilon = 1e-10);
assert_relative_eq!(posterior.alpha()[1], 21.0, epsilon = 1e-10);
assert_relative_eq!(posterior.alpha()[2], 16.0, epsilon = 1e-10);
}
// Expected value is 200 + 5 * ln(100), roughly 223.03.
#[test]
fn test_bic() {
let log_lik = -100.0;
let bic_value = bic(log_lik, 5, 100);
assert!(bic_value > 220.0 && bic_value < 225.0);
}
// Expected value is 200 + 2 * 5 = 210.
#[test]
fn test_aic() {
let log_lik = -100.0;
let aic_value = aic(log_lik, 5);
assert_relative_eq!(aic_value, 210.0, epsilon = 1e-10);
}
// Smoke test: DIC components are finite and the mean deviance is positive
// for negative log-likelihood samples.
#[test]
fn test_dic() {
let log_lik_samples = vec![-10.0, -12.0, -11.0, -10.5, -11.5];
let log_lik_at_mean = -11.0;
let dic_result =
dic(&log_lik_samples, log_lik_at_mean).expect("test: valid DIC computation");
assert!(dic_result.dic.is_finite());
assert!(dic_result.d_bar > 0.0);
assert!(dic_result.p_d.is_finite());
}
// Smoke test on 3 samples x 4 observations: WAIC and lppd are finite and
// the variance-based penalty is non-negative.
#[test]
fn test_waic() {
let pointwise_ll = vec![
vec![-1.0, -2.0, -1.5, -2.5],
vec![-1.2, -1.8, -1.4, -2.3],
vec![-0.9, -2.1, -1.6, -2.4],
];
let waic_result = waic(&pointwise_ll).expect("test: valid WAIC computation");
assert!(waic_result.waic.is_finite());
assert!(waic_result.lppd.is_finite());
assert!(waic_result.p_waic >= 0.0);
}
// The interval bounds must stay within the sample range and be ordered.
#[test]
fn test_equal_tailed_interval() {
let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let (lower, upper) = equal_tailed_interval(&samples, 0.1)
.expect("test: valid equal-tailed interval computation");
assert!(lower >= 1.0);
assert!(upper <= 10.0);
assert!(lower < upper);
}
// The HPD bounds must stay within the sample range and be ordered.
#[test]
fn test_hpd_interval() {
let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let (lower, upper) =
hpd_interval(&samples, 0.1).expect("test: valid HPD interval computation");
assert!(lower < upper);
assert!(lower >= 1.0);
assert!(upper <= 10.0);
}
// A log marginal likelihood difference of 10 must give a factor of e^10.
#[test]
fn test_bayes_factor() {
let log_ml_m1 = -100.0;
let log_ml_m0 = -110.0;
let bf = bayes_factor(log_ml_m1, log_ml_m0);
assert_relative_eq!(bf, 10.0_f64.exp(), epsilon = 1e-6);
}
// For all-negative log-likelihood samples the estimate must be finite and
// negative.
#[test]
fn test_harmonic_mean_marginal_likelihood() {
let log_lik_samples = vec![-10.0, -11.0, -12.0, -10.5, -11.5];
let log_ml = harmonic_mean_marginal_likelihood(&log_lik_samples)
.expect("test: valid marginal likelihood computation");
assert!(log_ml.is_finite());
assert!(log_ml < 0.0); }
// Six of the eight replicated statistics (5 through 10) are >= the
// observed value 5, so the p-value must be 6/8.
#[test]
fn test_posterior_predictive_pvalue() {
let observed = 5.0;
let replicated = vec![3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let pvalue = posterior_predictive_pvalue(observed, &replicated)
.expect("test: valid posterior predictive p-value");
assert_relative_eq!(pvalue, 6.0 / 8.0, epsilon = 1e-10);
}
// The 95% interval must be ordered and must strictly contain the
// posterior mean returned by `update` for the same data.
#[test]
fn test_credible_interval_95() {
let prior = NormalNormalConjugate::new(0.0, 1.0, 1.0)
.expect("test: valid Normal-Normal prior parameters");
let data = vec![1.0, 2.0, 1.5, 1.8, 2.1];
let (lower, upper) = prior
.credible_interval_95(&data)
.expect("test: valid 95% credible interval");
assert!(lower < upper);
let (mean, _) = prior
.update(&data)
.expect("test: valid Normal-Normal posterior update");
assert!(lower < mean && mean < upper);
}
}