use crate::error::{StatsError, StatsResult};
use crate::sampling::SampleableDistribution;
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
use scirs2_core::random::prelude::*;
use scirs2_core::random::{ChiSquared, Distribution, Normal as RandNormal};
use std::fmt::Debug;
use super::normal::{compute_cholesky, compute_inverse_from_cholesky};
/// Natural logarithm of the gamma function, `ln Γ(x)`, for `x > 0`.
///
/// * Integer arguments up to 20 are evaluated exactly as `ln((x - 1)!)`.
/// * Arguments below `0.5` use the reflection formula
///   `Γ(x) Γ(1 - x) = π / sin(πx)`; because `1 - x > 0.5` the recursion is
///   exactly one level deep.
/// * Everything else uses the Lanczos approximation (`g = 7`, nine
///   coefficients), which needs constant stack and time regardless of `x`.
///
/// `NaN` propagates to the result and `+∞` maps to `+∞`.
///
/// # Panics
///
/// Panics if `x <= 0.0`.
#[allow(dead_code)]
fn lgamma(x: f64) -> f64 {
    const LANCZOS_G: f64 = 7.0;
    // The first entry is the constant term of the Lanczos series; entry `k`
    // (k >= 1) is divided by `(x - 1) + k`.
    const LANCZOS_COEFFS: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507343278686905,
        -0.13857109526572012,
        9.984_369_578_019_572e-6,
        1.5056327351493116e-7,
    ];

    if x <= 0.0 {
        panic!("lgamma requires positive input");
    }
    if x == f64::INFINITY {
        // The Lanczos expression would evaluate to `inf - inf = NaN`.
        return f64::INFINITY;
    }

    // Exact path for small integers: ln Γ(n) = Σ_{i=2}^{n-1} ln i.
    if x.fract() == 0.0 && x <= 20.0 {
        let n = x as usize;
        return (2..n).fold(0.0_f64, |acc, i| acc + (i as f64).ln());
    }

    let pi = std::f64::consts::PI;
    if x < 0.5 {
        // Reflect into [0.5, 1), where the Lanczos series is accurate.
        return (pi / (pi * x).sin()).ln() - lgamma(1.0 - x);
    }

    let x_adj = x - 1.0;
    let t = x_adj + LANCZOS_G + 0.5;
    let series = LANCZOS_COEFFS
        .iter()
        .enumerate()
        .skip(1)
        .fold(LANCZOS_COEFFS[0], |acc, (k, &coef)| {
            acc + coef / (x_adj + k as f64)
        });
    0.5 * (2.0 * pi).ln() + (x_adj + 0.5) * t.ln() - t + series.ln()
}
/// Logarithm of the multivariate gamma function evaluated at `n / 2`.
///
/// Returns
/// `p (p - 1) / 4 · ln π + Σ_{i = 1..=p} ln Γ((n + 1 - i) / 2)`,
/// i.e. `ln Γ_p(n / 2)` in the usual notation.
///
/// # Arguments
///
/// * `n` - twice the argument of the multivariate gamma function.
/// * `p` - dimension; `p == 0` yields `0.0` (empty product).
///
/// # Panics
///
/// Panics (inside `lgamma`) when any `(n + 1 - i) / 2` for `i` in `1..=p`
/// is not strictly positive, i.e. when `n <= p - 1`.
#[allow(dead_code)]
pub fn lmultigamma(n: f64, p: usize) -> f64 {
    // `saturating_sub` keeps `p == 0` from underflowing the unsigned subtraction.
    let pi_term = (p * p.saturating_sub(1)) as f64 / 4.0 * std::f64::consts::PI.ln();
    let gamma_term = (1..=p).fold(0.0_f64, |acc, i| {
        acc + lgamma((n + 1.0 - i as f64) / 2.0)
    });
    pi_term + gamma_term
}
/// Wishart distribution over square `f64` matrices, parameterised by a scale
/// matrix and a degrees-of-freedom value.
///
/// Instances are built with `Wishart::new`, which validates the parameters and
/// fills in the cached private fields.
#[derive(Debug, Clone)]
pub struct Wishart {
/// Scale matrix as supplied to the constructor (owned copy).
pub scale: Array2<f64>,
/// Degrees of freedom; the constructor rejects values smaller than `dim`.
pub df: f64,
/// Number of rows (and columns) of `scale`.
pub dim: usize,
/// Result of `compute_cholesky(&scale)`, cached at construction.
/// NOTE(review): used as a lower-triangular factor `L` with `scale = L Lᵀ` — confirm against `compute_cholesky`.
scale_chol: Array2<f64>,
/// Squared product of the diagonal of `scale_chol`, used as the determinant of `scale`.
scale_det: f64,
}
impl Wishart {
    /// Creates a Wishart distribution from a scale matrix and degrees of freedom.
    ///
    /// # Arguments
    ///
    /// * `scale` - square scale matrix; it must be accepted by `compute_cholesky`.
    /// * `df` - degrees of freedom; must be finite and at least the matrix dimension.
    ///
    /// # Errors
    ///
    /// * `StatsError::DimensionMismatch` if `scale` is not square.
    /// * `StatsError::DomainError` if `df` is not finite, if `df` is smaller
    ///   than the dimension, or if the Cholesky factorisation of `scale` fails.
    pub fn new<D>(scale: ArrayBase<D, Ix2>, df: f64) -> StatsResult<Self>
    where
        D: Data<Elem = f64>,
    {
        let scale_owned = scale.to_owned();
        let dim = scale_owned.shape()[0];
        if scale_owned.shape()[1] != dim {
            return Err(StatsError::DimensionMismatch(
                "Scale matrix must be square".to_string(),
            ));
        }
        // NaN compares false against everything, so the range check below
        // would silently accept it; reject non-finite values explicitly.
        if !df.is_finite() {
            return Err(StatsError::DomainError(format!(
                "Degrees of freedom ({}) must be finite",
                df
            )));
        }
        if df < dim as f64 {
            return Err(StatsError::DomainError(format!(
                "Degrees of freedom ({}) must be greater than or equal to dimension ({})",
                df, dim
            )));
        }
        let scale_chol = compute_cholesky(&scale_owned).map_err(|_| {
            StatsError::DomainError("Scale matrix must be positive definite".to_string())
        })?;
        // det(scale) is the squared product of the Cholesky factor's diagonal.
        let diag_product: f64 = (0..dim).map(|i| scale_chol[[i, i]]).product();
        let scale_det = diag_product * diag_product;
        Ok(Wishart {
            scale: scale_owned,
            df,
            dim,
            scale_chol,
            scale_det,
        })
    }

    /// Probability density at `x`.
    ///
    /// Returns `0.0` when `x` has the wrong shape or its Cholesky
    /// factorisation fails (those cases have log-density `-∞`).
    pub fn pdf<D>(&self, x: &ArrayBase<D, Ix2>) -> f64
    where
        D: Data<Elem = f64>,
    {
        // exp(-inf) == 0.0, so the rejection cases of `logpdf` map to zero density.
        self.logpdf(x).exp()
    }

    /// Log-density at `x`, given the already computed Cholesky factor of `x`.
    ///
    /// With `n = df`, `p = dim` and scale `V` this evaluates
    /// `(n - p - 1)/2 · ln|X| - tr(V⁻¹X)/2 - (n p / 2) · ln 2 - (n/2) · ln|V| - ln Γ_p(n/2)`.
    ///
    /// # Panics
    ///
    /// Panics if `compute_inverse_from_cholesky` fails for the cached scale factor.
    fn logpdf_with_cholesky<D>(&self, x: &ArrayBase<D, Ix2>, xchol: &Array2<f64>) -> f64
    where
        D: Data<Elem = f64>,
    {
        // ln|X| = 2 Σ ln L_ii; summing logs avoids overflow/underflow of the raw product.
        let log_det_x = 2.0
            * (0..self.dim)
                .map(|i| xchol[[i, i]].ln())
                .fold(0.0_f64, |acc, v| acc + v);

        let scale_inv = compute_inverse_from_cholesky(&self.scale_chol)
            .expect("Failed to compute matrix inverse");

        // tr(V⁻¹ X) = Σ_ij V⁻¹[i, j] · X[j, i]
        let mut trace = 0.0;
        for i in 0..self.dim {
            for j in 0..self.dim {
                trace += scale_inv[[i, j]] * x[[j, i]];
            }
        }

        let p = self.dim as f64;
        let n = self.df;
        let trace_term = -0.5 * trace;
        let scale_det_term = -0.5 * n * self.scale_det.ln();
        // The normalising constant is 2^{np/2}, so the exponent carries both n and p.
        let two_term = -0.5 * n * p * std::f64::consts::LN_2;
        let gamma_term = -lmultigamma(n, self.dim);
        let x_det_term = 0.5 * (n - p - 1.0) * log_det_x;
        trace_term + scale_det_term + two_term + gamma_term + x_det_term
    }

    /// Natural logarithm of the probability density at `x`.
    ///
    /// Returns `f64::NEG_INFINITY` when `x` is not `dim × dim` or when its
    /// Cholesky factorisation fails.
    pub fn logpdf<D>(&self, x: &ArrayBase<D, Ix2>) -> f64
    where
        D: Data<Elem = f64>,
    {
        if x.shape()[0] != self.dim || x.shape()[1] != self.dim {
            return f64::NEG_INFINITY;
        }
        let x_owned = x.to_owned();
        match compute_cholesky(&x_owned) {
            Ok(x_chol) => self.logpdf_with_cholesky(x, &x_chol),
            Err(_) => f64::NEG_INFINITY,
        }
    }

    /// Draws `size` random matrices from the distribution.
    ///
    /// For integer `df` each sample is the sum of `df` outer products
    /// `(L z)(L z)ᵀ` with standard-normal `z`, where `L` is the cached factor of
    /// the scale matrix. Otherwise a lower-triangular matrix `A` is built with
    /// `sqrt(χ²(df - i))` on the diagonal and standard normals below it, and the
    /// sample is `(L A)(L A)ᵀ`.
    ///
    /// # Errors
    ///
    /// Returns `StatsError::ComputationError` if a chi-square distribution
    /// cannot be constructed.
    pub fn rvs(&self, size: usize) -> StatsResult<Vec<Array2<f64>>> {
        let mut rng = thread_rng();
        let normal_dist = RandNormal::new(0.0, 1.0).expect("Operation failed");
        let p = self.dim;
        let integer_df = self.df.fract() == 0.0;
        let mut samples = Vec::with_capacity(size);

        for _ in 0..size {
            let sample = if integer_df {
                let n = self.df as usize;
                let mut x = Array2::<f64>::zeros((p, p));
                for _ in 0..n {
                    let mut z = Array1::<f64>::zeros(p);
                    for j in 0..p {
                        z[j] = normal_dist.sample(&mut rng);
                    }
                    let az = self.scale_chol.dot(&z);
                    for i in 0..p {
                        for j in 0..p {
                            x[[i, j]] += az[i] * az[j];
                        }
                    }
                }
                x
            } else {
                let mut a = Array2::<f64>::zeros((p, p));
                for i in 0..p {
                    let df_i = self.df - (i as f64);
                    let chi2_dist = ChiSquared::new(df_i).map_err(|_| {
                        StatsError::ComputationError(
                            "Failed to create chi-square distribution".to_string(),
                        )
                    })?;
                    a[[i, i]] = chi2_dist.sample(&mut rng).sqrt();
                }
                for i in 0..p {
                    for j in 0..i {
                        a[[i, j]] = normal_dist.sample(&mut rng);
                    }
                }
                let b = self.scale_chol.dot(&a);
                // X = B Bᵀ; only the lower triangle is computed and then mirrored.
                let mut x = Array2::<f64>::zeros((p, p));
                for i in 0..p {
                    for j in 0..=i {
                        let mut sum = 0.0;
                        for k in 0..p {
                            sum += b[[i, k]] * b[[j, k]];
                        }
                        x[[i, j]] = sum;
                        if i != j {
                            x[[j, i]] = sum;
                        }
                    }
                }
                x
            };
            samples.push(sample);
        }
        Ok(samples)
    }

    /// Draws a single random matrix from the distribution.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Wishart::rvs`].
    pub fn rvs_single(&self) -> StatsResult<Array2<f64>> {
        // Move the sample out of the vector instead of cloning it.
        self.rvs(1)?.pop().ok_or_else(|| {
            StatsError::ComputationError("Failed to draw a sample".to_string())
        })
    }

    /// Mean of the distribution: `df · scale`.
    pub fn mean(&self) -> Array2<f64> {
        let df = self.df;
        self.scale.mapv(|v| v * df)
    }

    /// Mode of the distribution: `(df - dim - 1) · scale`.
    ///
    /// Returns `None` when `df < dim + 1`.
    pub fn mode(&self) -> Option<Array2<f64>> {
        let p = self.dim as f64;
        if self.df < p + 1.0 {
            return None;
        }
        let factor = self.df - p - 1.0;
        Some(self.scale.mapv(|v| v * factor))
    }
}
/// Convenience constructor for a [`Wishart`] distribution.
///
/// Forwards `scale` and `df` unchanged to `Wishart::new` and returns its
/// result, including any validation error it reports.
#[allow(dead_code)]
pub fn wishart<D>(scale: ArrayBase<D, Ix2>, df: f64) -> StatsResult<Wishart>
where
D: Data<Elem = f64>,
{
Wishart::new(scale, df)
}
impl SampleableDistribution<Array2<f64>> for Wishart {
fn rvs(&self, size: usize) -> StatsResult<Vec<Array2<f64>>> {
self.rvs(size)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Valid parameters yield a distribution that records its inputs.
    #[test]
    fn test_wishart_creation() {
        let identity = array![[1.0, 0.0], [0.0, 1.0]];
        let df_2d = 5.0;
        let dist_2d = Wishart::new(identity.clone(), df_2d).expect("Operation failed");
        assert_eq!(dist_2d.dim, 2);
        assert_eq!(dist_2d.df, df_2d);
        assert_eq!(dist_2d.scale, identity);

        let dense = array![[2.0, 0.5, 0.3], [0.5, 1.0, 0.1], [0.3, 0.1, 1.5]];
        let df_3d = 10.0;
        let dist_3d = Wishart::new(dense.clone(), df_3d).expect("Operation failed");
        assert_eq!(dist_3d.dim, 3);
        assert_eq!(dist_3d.df, df_3d);
        assert_eq!(dist_3d.scale, dense);
    }

    /// Non-square scale, too-small `df` and non-positive-definite scale are rejected.
    #[test]
    fn test_wishart_creation_errors() {
        let rectangular = array![[1.0, 0.5, 0.3], [0.5, 1.0, 0.1]];
        assert!(Wishart::new(rectangular, 5.0).is_err());

        let identity = array![[1.0, 0.0], [0.0, 1.0]];
        assert!(Wishart::new(identity.clone(), 1.0).is_err());

        let indefinite = array![[1.0, 2.0], [2.0, 1.0]];
        assert!(Wishart::new(indefinite, 5.0).is_err());
    }

    /// The mean equals `df * scale` element by element.
    #[test]
    fn test_wishart_mean() {
        let scale = array![[1.0, 0.5], [0.5, 2.0]];
        let df = 5.0;
        let dist = Wishart::new(scale.clone(), df).expect("Operation failed");
        let mean = dist.mean();
        let expected = scale * df;
        for ((row, col), &want) in expected.indexed_iter() {
            assert_relative_eq!(mean[[row, col]], want, epsilon = 1e-10);
        }
    }

    /// The mode equals `(df - dim - 1) * scale` and is absent when `df < dim + 1`.
    #[test]
    fn test_wishart_mode() {
        let scale = array![[1.0, 0.5], [0.5, 2.0]];
        let df = 5.0;
        let dist = Wishart::new(scale.clone(), df).expect("Operation failed");
        let mode = dist.mode().expect("Operation failed");
        let expected = scale.clone() * (df - 3.0);
        for ((row, col), &want) in expected.indexed_iter() {
            assert_relative_eq!(mode[[row, col]], want, epsilon = 1e-10);
        }

        let low_df = Wishart::new(scale, 2.5).expect("Operation failed");
        assert!(low_df.mode().is_none());
    }

    /// The density is positive on positive-definite input and zero otherwise.
    #[test]
    fn test_wishart_pdf() {
        let dist =
            Wishart::new(array![[1.0, 0.0], [0.0, 1.0]], 5.0).expect("Operation failed");

        let at_identity = dist.pdf(&array![[1.0, 0.0], [0.0, 1.0]]);
        assert!(at_identity > 0.0);

        let at_general = dist.pdf(&array![[2.0, 0.5], [0.5, 3.0]]);
        assert!(at_general > 0.0);

        let indefinite = array![[1.0, 2.0], [2.0, 1.0]];
        assert_eq!(dist.pdf(&indefinite), 0.0);
    }

    /// `logpdf` agrees with `pdf` and is `-inf` off the support.
    #[test]
    fn test_wishart_logpdf() {
        let dist =
            Wishart::new(array![[1.0, 0.0], [0.0, 1.0]], 5.0).expect("Operation failed");

        let point = array![[2.0, 0.5], [0.5, 3.0]];
        let density = dist.pdf(&point);
        let log_density = dist.logpdf(&point);
        assert_relative_eq!(log_density.exp(), density, epsilon = 1e-10);

        let indefinite = array![[1.0, 2.0], [2.0, 1.0]];
        assert_eq!(dist.logpdf(&indefinite), f64::NEG_INFINITY);
    }

    /// Bulk sampling returns correctly shaped matrices whose average is near the mean.
    #[test]
    fn test_wishart_rvs() {
        let scale = array![[1.0, 0.5], [0.5, 2.0]];
        let df = 5.0;
        let dist = Wishart::new(scale.clone(), df).expect("Operation failed");

        let draw_count = 1000;
        let draws = dist.rvs(draw_count).expect("Operation failed");
        assert_eq!(draws.len(), draw_count);
        for draw in &draws {
            assert_eq!(draw.shape(), &[2, 2]);
        }

        let mut average = Array2::<f64>::zeros((2, 2));
        for draw in &draws {
            average += draw;
        }
        average /= draw_count as f64;

        let expected = scale * df;
        for ((row, col), &want) in expected.indexed_iter() {
            assert_relative_eq!(
                average[[row, col]],
                want,
                epsilon = 0.8,
                max_relative = 0.3
            );
        }
    }

    /// A single draw is a symmetric 2x2 matrix with a valid Cholesky factorisation.
    #[test]
    fn test_wishart_rvs_single() {
        let dist =
            Wishart::new(array![[1.0, 0.5], [0.5, 2.0]], 5.0).expect("Operation failed");
        let draw = dist.rvs_single().expect("Operation failed");
        assert_eq!(draw.shape(), &[2, 2]);
        assert_relative_eq!(draw[[0, 1]], draw[[1, 0]], epsilon = 1e-10);
        assert!(compute_cholesky(&draw).is_ok());
    }
}