use crate::error::{StatsError, StatsResult as Result};
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::validation::*;
use statrs::statistics::Statistics;
/// Beta-Binomial conjugate model: a Beta(alpha, beta) prior over the success
/// probability of a binomial likelihood.
#[derive(Debug, Clone)]
pub struct BetaBinomial {
/// Prior pseudo-count of successes; kept strictly positive by `new`.
pub alpha: f64,
/// Prior pseudo-count of failures; kept strictly positive by `new`.
pub beta: f64,
}
/// Conjugate-update and posterior-summary operations for the Beta-Binomial model.
impl BetaBinomial {
    /// Creates a Beta(alpha, beta) prior; both shape parameters must be > 0.
    pub fn new(alpha: f64, beta: f64) -> Result<Self> {
        check_positive(alpha, "alpha")?;
        check_positive(beta, "beta")?;
        Ok(Self { alpha, beta })
    }

    /// Returns the posterior after observing `successes` and `failures`:
    /// each shape parameter simply absorbs the matching count.
    pub fn update(&self, successes: usize, failures: usize) -> Self {
        let alpha = self.alpha + successes as f64;
        let beta = self.beta + failures as f64;
        Self { alpha, beta }
    }

    /// Posterior mean of the success probability: alpha / (alpha + beta).
    pub fn posterior_mean(&self) -> Result<f64> {
        let concentration = self.alpha + self.beta;
        if concentration.abs() >= f64::EPSILON {
            Ok(self.alpha / concentration)
        } else {
            Err(StatsError::domain(
                "Cannot compute posterior mean: alpha + beta too close to zero",
            ))
        }
    }

    /// Posterior variance: alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1)).
    pub fn posterior_variance(&self) -> Result<f64> {
        let concentration = self.alpha + self.beta;
        if concentration.abs() < f64::EPSILON {
            return Err(StatsError::domain(
                "Cannot compute posterior variance: alpha + beta too close to zero",
            ));
        }
        let denom = concentration * concentration * (concentration + 1.0);
        if denom.abs() < f64::EPSILON {
            return Err(StatsError::domain(
                "Cannot compute posterior variance: denominator too close to zero",
            ));
        }
        Ok((self.alpha * self.beta) / denom)
    }

    /// Posterior mode (alpha - 1) / (alpha + beta - 2); only defined when
    /// both parameters exceed 1, otherwise `None`.
    pub fn posterior_mode(&self) -> Result<Option<f64>> {
        if !(self.alpha > 1.0 && self.beta > 1.0) {
            return Ok(None);
        }
        let denom = self.alpha + self.beta - 2.0;
        if denom.abs() < f64::EPSILON {
            return Err(StatsError::domain(
                "Cannot compute posterior mode: alpha + beta - 2 too close to zero",
            ));
        }
        Ok(Some((self.alpha - 1.0) / denom))
    }

    /// Equal-tailed credible interval at the given confidence level, taken
    /// from the quantiles of the posterior Beta distribution.
    pub fn credible_interval(&self, confidence: f64) -> Result<(f64, f64)> {
        check_probability(confidence, "confidence")?;
        use crate::distributions::beta::Beta;
        let posterior = Beta::new(self.alpha, self.beta, 0.0, 1.0)?;
        let tail = (1.0 - confidence) / 2.0;
        let lower = posterior.ppf(tail)?;
        let upper = posterior.ppf(1.0 - tail)?;
        Ok((lower, upper))
    }
}
/// Gamma-Poisson conjugate model: a Gamma(alpha, beta) prior over the rate
/// of a Poisson likelihood (beta is the rate parameter of the Gamma).
#[derive(Debug, Clone)]
pub struct GammaPoisson {
/// Shape parameter; kept strictly positive by `new`.
pub alpha: f64,
/// Rate parameter; kept strictly positive by `new`.
pub beta: f64,
}
/// Conjugate-update and posterior-summary operations for the Gamma-Poisson model.
impl GammaPoisson {
    /// Creates a Gamma(alpha, beta) prior; both parameters must be > 0.
    pub fn new(alpha: f64, beta: f64) -> Result<Self> {
        check_positive(alpha, "alpha")?;
        check_positive(beta, "beta")?;
        Ok(Self { alpha, beta })
    }

    /// Returns the posterior after observing the counts in `data`:
    /// alpha absorbs the total count, beta absorbs the sample size.
    pub fn update(&self, data: ArrayView1<f64>) -> Result<Self> {
        checkarray_finite(&data, "data")?;
        let total: f64 = data.sum();
        let sample_size = data.len() as f64;
        Ok(Self {
            alpha: self.alpha + total,
            beta: self.beta + sample_size,
        })
    }

    /// Posterior mean of the rate: alpha / beta.
    pub fn posterior_mean(&self) -> Result<f64> {
        if self.beta.abs() >= f64::EPSILON {
            Ok(self.alpha / self.beta)
        } else {
            Err(StatsError::domain(
                "Cannot compute posterior mean: beta too close to zero",
            ))
        }
    }

    /// Posterior variance of the rate: alpha / beta^2.
    pub fn posterior_variance(&self) -> Result<f64> {
        if self.beta.abs() >= f64::EPSILON {
            Ok(self.alpha / (self.beta * self.beta))
        } else {
            Err(StatsError::domain(
                "Cannot compute posterior variance: beta too close to zero",
            ))
        }
    }

    /// Posterior mode (alpha - 1) / beta; only defined when alpha >= 1,
    /// otherwise `None`.
    pub fn posterior_mode(&self) -> Result<Option<f64>> {
        if !(self.alpha >= 1.0) {
            return Ok(None);
        }
        if self.beta.abs() < f64::EPSILON {
            return Err(StatsError::domain(
                "Cannot compute posterior mode: beta too close to zero",
            ));
        }
        Ok(Some((self.alpha - 1.0) / self.beta))
    }

    /// Equal-tailed credible interval for the rate at the given confidence
    /// level, taken from the posterior Gamma quantiles (scale = 1 / beta).
    pub fn credible_interval(&self, confidence: f64) -> Result<(f64, f64)> {
        check_probability(confidence, "confidence")?;
        use crate::distributions::gamma::Gamma;
        let posterior = Gamma::new(self.alpha, 1.0 / self.beta, 0.0)?;
        let tail = (1.0 - confidence) / 2.0;
        let lower = posterior.ppf(tail)?;
        let upper = posterior.ppf(1.0 - tail)?;
        Ok((lower, upper))
    }
}
/// Conjugate model for the mean of a normal likelihood with known observation
/// variance, using a normal prior on the mean. After `update`, `prior_mean`
/// and `prior_variance` hold the posterior parameters.
#[derive(Debug, Clone)]
pub struct NormalKnownVariance {
/// Mean of the (current) normal prior on the unknown mean.
pub prior_mean: f64,
/// Variance of the (current) normal prior; kept strictly positive by `new`.
pub prior_variance: f64,
/// Known variance of the observations; kept strictly positive by `new`.
pub data_variance: f64,
}
/// Conjugate-update and posterior-summary operations for a normal mean with
/// known observation variance.
impl NormalKnownVariance {
    /// Creates the prior N(prior_mean, prior_variance) for the unknown mean,
    /// with known observation variance `data_variance`.
    ///
    /// # Errors
    /// Returns an error if either variance is not strictly positive.
    pub fn new(prior_mean: f64, prior_variance: f64, data_variance: f64) -> Result<Self> {
        check_positive(prior_variance, "prior_variance")?;
        check_positive(data_variance, "data_variance")?;
        Ok(Self {
            prior_mean,
            prior_variance,
            data_variance,
        })
    }

    /// Returns the posterior after observing `data`, via the standard
    /// precision-weighted combination of prior and sample information.
    ///
    /// # Errors
    /// Returns an error if `data` is empty or non-finite, or if any of the
    /// involved precisions degenerate to zero.
    pub fn update(&self, data: ArrayView1<f64>) -> Result<Self> {
        checkarray_finite(&data, "data")?;
        // An empty sample would make the sample mean 0/0 = NaN and silently
        // poison the posterior parameters; reject it explicitly.
        if data.is_empty() {
            return Err(StatsError::domain("Cannot update: data must not be empty"));
        }
        let n = data.len() as f64;
        // sum / n instead of a `mean()` method call: unambiguous between the
        // ndarray inherent method and the statrs `Statistics` trait, and
        // identical for non-empty data.
        let data_mean = data.sum() / n;
        if self.prior_variance.abs() < f64::EPSILON {
            return Err(StatsError::domain(
                "Cannot update: prior_variance too close to zero",
            ));
        }
        if self.data_variance.abs() < f64::EPSILON {
            return Err(StatsError::domain(
                "Cannot update: data_variance too close to zero",
            ));
        }
        let precision_prior = 1.0 / self.prior_variance;
        let precision_data = n / self.data_variance;
        let precision_posterior = precision_prior + precision_data;
        if precision_posterior.abs() < f64::EPSILON {
            return Err(StatsError::domain(
                "Cannot update: precision_posterior too close to zero",
            ));
        }
        let posterior_variance = 1.0 / precision_posterior;
        // Posterior mean is the precision-weighted average of the prior mean
        // and the sample mean.
        let posterior_mean =
            (precision_prior * self.prior_mean + precision_data * data_mean) / precision_posterior;
        Ok(Self {
            prior_mean: posterior_mean,
            prior_variance: posterior_variance,
            data_variance: self.data_variance,
        })
    }

    /// Posterior mean of the unknown mean (the current `prior_mean`).
    pub fn posterior_mean(&self) -> f64 {
        self.prior_mean
    }

    /// Posterior variance of the unknown mean (the current `prior_variance`).
    pub fn posterior_variance(&self) -> f64 {
        self.prior_variance
    }

    /// Equal-tailed credible interval for the mean at the given confidence
    /// level, taken from the posterior normal quantiles.
    pub fn credible_interval(&self, confidence: f64) -> Result<(f64, f64)> {
        check_probability(confidence, "confidence")?;
        use crate::distributions::normal::Normal;
        // Defensive: `new` enforces positivity, but sqrt of a negative value
        // would silently produce NaN bounds.
        if self.prior_variance < 0.0 {
            return Err(StatsError::domain(
                "Cannot compute credible interval: prior_variance must be non-negative",
            ));
        }
        let dist = Normal::new(self.prior_mean, self.prior_variance.sqrt())?;
        let alpha_level = (1.0 - confidence) / 2.0;
        Ok((dist.ppf(alpha_level)?, dist.ppf(1.0 - alpha_level)?))
    }

    /// Parameters (mean, variance) of the posterior predictive distribution
    /// for a single new observation: predictive variance adds the known
    /// observation variance to the posterior variance of the mean.
    pub fn predictive_params(&self) -> (f64, f64) {
        (self.prior_mean, self.prior_variance + self.data_variance)
    }
}
/// Dirichlet-Multinomial conjugate model: a Dirichlet prior over the
/// category probabilities of a multinomial likelihood.
#[derive(Debug, Clone)]
pub struct DirichletMultinomial {
/// Concentration parameters, one per category; each kept strictly
/// positive by `new`.
pub alpha: Array1<f64>,
}
/// Conjugate-update and posterior-summary operations for the
/// Dirichlet-Multinomial model.
impl DirichletMultinomial {
    /// Creates a Dirichlet prior from the concentration parameters `alpha`.
    ///
    /// # Errors
    /// Returns an error if any element is non-finite or not strictly positive.
    pub fn new(alpha: Array1<f64>) -> Result<Self> {
        checkarray_finite(&alpha, "alpha")?;
        for &a in alpha.iter() {
            // Label was "_alpha element" — a leftover underscore from a
            // renamed parameter that leaked into the user-facing error.
            check_positive(a, "alpha element")?;
        }
        Ok(Self { alpha })
    }

    /// Creates the uniform (flat) Dirichlet prior over `k` categories,
    /// i.e. every concentration parameter equal to 1.
    ///
    /// # Errors
    /// Returns an error if `k` is zero.
    pub fn uniform(k: usize) -> Result<Self> {
        check_positive(k, "k")?;
        Ok(Self {
            alpha: Array1::from_elem(k, 1.0),
        })
    }

    /// Returns the posterior after observing category `counts`: each
    /// concentration parameter absorbs the matching count.
    ///
    /// # Errors
    /// Returns an error if `counts` has the wrong length, contains
    /// non-finite values, or contains negative values.
    pub fn update(&self, counts: ArrayView1<f64>) -> Result<Self> {
        if counts.len() != self.alpha.len() {
            return Err(StatsError::DimensionMismatch(format!(
                "counts length ({}) must match alpha length ({})",
                counts.len(),
                self.alpha.len()
            )));
        }
        checkarray_finite(&counts, "counts")?;
        // A negative count could push a posterior concentration to zero or
        // below, breaking the positivity invariant established by `new`.
        if counts.iter().any(|&c| c < 0.0) {
            return Err(StatsError::domain(
                "Cannot update: counts must be non-negative",
            ));
        }
        Ok(Self {
            alpha: &self.alpha + &counts,
        })
    }

    /// Posterior mean vector: alpha_i / sum(alpha).
    pub fn posterior_mean(&self) -> Result<Array1<f64>> {
        let sum = self.alpha.sum();
        if sum.abs() < f64::EPSILON {
            return Err(StatsError::domain(
                "Cannot compute posterior mean: sum of alpha parameters too close to zero",
            ));
        }
        Ok(&self.alpha / sum)
    }

    /// Posterior mode (alpha_i - 1) / (sum(alpha) - k); only defined when
    /// every concentration parameter exceeds 1, otherwise `None`.
    pub fn posterior_mode(&self) -> Result<Option<Array1<f64>>> {
        let k = self.alpha.len() as f64;
        if self.alpha.iter().all(|&a| a > 1.0) {
            let sum = self.alpha.sum();
            let denominator = sum - k;
            if denominator.abs() < f64::EPSILON {
                return Err(StatsError::domain(
                    "Cannot compute posterior mode: sum - k too close to zero",
                ));
            }
            Ok(Some((&self.alpha - 1.0) / denominator))
        } else {
            Ok(None)
        }
    }

    /// Marginal posterior variances: p_i (1 - p_i) / (sum(alpha) + 1),
    /// where p is the posterior mean vector.
    pub fn posterior_variance(&self) -> Result<Array1<f64>> {
        let sum = self.alpha.sum();
        let denominator = sum + 1.0;
        if denominator.abs() < f64::EPSILON {
            return Err(StatsError::domain(
                "Cannot compute posterior variance: sum + 1 too close to zero",
            ));
        }
        let mean = self.posterior_mean()?;
        Ok(mean.mapv(|p| p * (1.0 - p) / denominator))
    }
}
/// Normal-Inverse-Gamma conjugate model for a normal likelihood with both
/// mean and variance unknown.
#[derive(Debug, Clone)]
pub struct NormalInverseGamma {
/// Prior mean of the unknown mean.
pub mu0: f64,
/// Prior pseudo-sample-size scaling the mean's precision; kept strictly
/// positive by `new`.
pub lambda: f64,
/// Inverse-gamma shape parameter for the variance; kept strictly positive
/// by `new`.
pub alpha: f64,
/// Inverse-gamma scale parameter for the variance; kept strictly positive
/// by `new`.
pub beta: f64,
}
/// Conjugate-update and posterior-summary operations for the
/// Normal-Inverse-Gamma model.
impl NormalInverseGamma {
    /// Creates the prior with location `mu0`, pseudo-sample-size `lambda`,
    /// and inverse-gamma parameters `alpha`, `beta`.
    ///
    /// # Errors
    /// Returns an error if `lambda`, `alpha`, or `beta` is not strictly
    /// positive.
    pub fn new(mu0: f64, lambda: f64, alpha: f64, beta: f64) -> Result<Self> {
        check_positive(lambda, "lambda")?;
        check_positive(alpha, "alpha")?;
        check_positive(beta, "beta")?;
        Ok(Self {
            mu0,
            lambda,
            alpha,
            beta,
        })
    }

    /// Returns the posterior after observing `data`, using the standard
    /// Normal-Inverse-Gamma update equations.
    ///
    /// # Errors
    /// Returns an error if `data` is empty or non-finite, or if the updated
    /// pseudo-sample-size degenerates to zero.
    pub fn update(&self, data: ArrayView1<f64>) -> Result<Self> {
        checkarray_finite(&data, "data")?;
        // An empty sample would make the sample mean 0/0 = NaN and silently
        // poison every posterior parameter; reject it explicitly.
        if data.is_empty() {
            return Err(StatsError::domain("Cannot update: data must not be empty"));
        }
        let n = data.len() as f64;
        // sum / n instead of a `mean()` method call: unambiguous between the
        // ndarray inherent method and the statrs `Statistics` trait, and
        // identical for non-empty data.
        let data_mean = data.sum() / n;
        // Sum of squared deviations from the sample mean.
        let ss = data.mapv(|x| (x - data_mean).powi(2)).sum();
        let lambda_n = self.lambda + n;
        if lambda_n.abs() < f64::EPSILON {
            return Err(StatsError::domain(
                "Cannot update: lambda_n too close to zero",
            ));
        }
        // Posterior location: precision-weighted average of prior location
        // and sample mean.
        let mu_n = (self.lambda * self.mu0 + n * data_mean) / lambda_n;
        let alpha_n = self.alpha + n / 2.0;
        // Posterior scale absorbs the within-sample scatter plus a shrinkage
        // term for the disagreement between prior location and sample mean.
        let beta_n = self.beta
            + 0.5 * ss
            + 0.5 * self.lambda * n * (data_mean - self.mu0).powi(2) / lambda_n;
        Ok(Self {
            mu0: mu_n,
            lambda: lambda_n,
            alpha: alpha_n,
            beta: beta_n,
        })
    }

    /// Posterior mean of the unknown mean (the current location `mu0`).
    pub fn posterior_mean_mu(&self) -> f64 {
        self.mu0
    }

    /// Posterior mean of the variance, E[sigma^2] = beta / (alpha - 1).
    ///
    /// # Errors
    /// The inverse-gamma mean exists only for alpha > 1; the previous
    /// `alpha - 1 ~ 0` check let alpha < 1 slip through and yield a negative
    /// "mean", so alpha <= 1 is now rejected outright.
    pub fn posterior_mean_variance(&self) -> Result<f64> {
        if self.alpha <= 1.0 {
            return Err(StatsError::domain(
                "Cannot compute posterior mean variance: alpha must be greater than 1",
            ));
        }
        Ok(self.beta / (self.alpha - 1.0))
    }

    /// Marginal posterior variance of the mean: beta / (lambda * (alpha - 1)).
    ///
    /// # Errors
    /// Defined only for alpha > 1 (same moment-existence condition as
    /// [`posterior_mean_variance`]); alpha <= 1 previously produced a
    /// negative variance and is now rejected.
    pub fn posterior_variance_mu(&self) -> Result<f64> {
        if self.alpha <= 1.0 {
            return Err(StatsError::domain(
                "Cannot compute posterior variance mu: alpha must be greater than 1",
            ));
        }
        Ok(self.beta / (self.lambda * (self.alpha - 1.0)))
    }
}