use crate::error::{StatsError, StatsResult};
use crate::sampling::SampleableDistribution;
use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
use std::fmt::Debug;
use super::normal::{compute_cholesky, compute_inverse_from_cholesky};
use super::wishart::Wishart;
/// Inverse Wishart distribution over square `f64` matrices.
///
/// Instances are built through `InverseWishart::new`, which validates the
/// inputs and fills in the cached private fields.
#[derive(Debug, Clone)]
pub struct InverseWishart {
/// Scale matrix of the distribution (`dim` x `dim`).
pub scale: Array2<f64>,
/// Degrees of freedom.
pub df: f64,
/// Number of rows (and columns) of `scale`.
pub dim: usize,
/// Cholesky factor of `scale`, as returned by `compute_cholesky` in `new`.
#[allow(dead_code)]
scale_chol: Array2<f64>,
/// Determinant of `scale`, computed in `new` as the squared product of the
/// diagonal entries of `scale_chol`.
scale_det: f64,
}
impl InverseWishart {
/// Creates an inverse Wishart distribution from a scale matrix and degrees of freedom.
///
/// # Arguments
///
/// * `scale` - Scale matrix; must be square and must be accepted by
///   `compute_cholesky` (i.e. positive definite).
/// * `df` - Degrees of freedom; must be finite and strictly greater than `dim + 1`,
///   where `dim` is the number of rows of `scale`.
///
/// # Errors
///
/// * `StatsError::DimensionMismatch` if `scale` is not square.
/// * `StatsError::DomainError` if `df` is NaN or infinite, if `df <= dim + 1`,
///   or if the Cholesky factorisation of `scale` fails.
pub fn new<D>(scale: ArrayBase<D, Ix2>, df: f64) -> StatsResult<Self>
where
    D: Data<Elem = f64>,
{
    let scale_owned = scale.to_owned();
    let dim = scale_owned.shape()[0];
    if scale_owned.shape()[1] != dim {
        return Err(StatsError::DimensionMismatch(
            "Scale matrix must be square".to_string(),
        ));
    }

    // NaN compares false against everything, so it would slip through the
    // `<=` bound check below; an infinite value would pass it as well.
    // Reject both explicitly.
    if !df.is_finite() {
        return Err(StatsError::DomainError(format!(
            "Degrees of freedom must be finite, got {}",
            df
        )));
    }
    if df <= dim as f64 + 1.0 {
        return Err(StatsError::DomainError(format!(
            "Degrees of freedom ({}) must be greater than dimension + 1 ({})",
            df,
            dim + 1
        )));
    }

    let scale_chol = compute_cholesky(&scale_owned).map_err(|_| {
        StatsError::DomainError("Scale matrix must be positive definite".to_string())
    })?;

    // The determinant is cached as the squared product of the factor's diagonal
    // (presumably `scale = L * L^T` with triangular `L` — confirm against
    // `compute_cholesky`).
    let diag_product: f64 = (0..dim).map(|i| scale_chol[[i, i]]).product();
    let scale_det = diag_product * diag_product;

    Ok(InverseWishart {
        scale: scale_owned,
        df,
        dim,
        scale_chol,
        scale_det,
    })
}
/// Evaluates the probability density at `x`.
///
/// Returns `0.0` when `x` is not `dim` x `dim` or when its Cholesky
/// factorisation fails (i.e. `x` is not positive definite).
pub fn pdf<D>(&self, x: &ArrayBase<D, Ix2>) -> f64
where
    D: Data<Elem = f64>,
{
    // `logpdf` reports both rejection cases as negative infinity, and
    // `exp(-inf)` is exactly `0.0`, so the density is its exponential.
    self.logpdf(x).exp()
}
/// Evaluates the log-density at a matrix whose Cholesky factor is already known.
///
/// With scale `Ψ`, degrees of freedom `ν` and dimension `p` the log-density is
///
/// ```text
/// (ν/2)·ln|Ψ| − (νp/2)·ln 2 − ln Γ_p(ν/2) − ((ν+p+1)/2)·ln|X| − ½·tr(Ψ·X⁻¹)
/// ```
///
/// # Arguments
///
/// * `_x` - The evaluation point. It is not read: everything needed is derived
///   from `xchol`. The parameter is kept so existing call sites stay unchanged.
/// * `xchol` - Cholesky factor of the evaluation point, as produced by
///   `compute_cholesky`.
///
/// Returns negative infinity (zero density) if the inverse cannot be formed
/// from `xchol`, instead of panicking.
fn logpdf_with_cholesky<D>(&self, _x: &ArrayBase<D, Ix2>, xchol: &Array2<f64>) -> f64
where
    D: Data<Elem = f64>,
{
    // ln|X| as twice the sum of the logs of the factor's diagonal. Summing logs
    // avoids the overflow/underflow that forming the determinant first can hit.
    let ln_det_x = 2.0 * (0..self.dim).map(|i| xchol[[i, i]].ln()).sum::<f64>();

    let x_inv = match compute_inverse_from_cholesky(xchol) {
        Ok(inverse) => inverse,
        Err(_) => return f64::NEG_INFINITY,
    };

    // tr(Ψ·X⁻¹) = Σ_ij Ψ[i, j] · X⁻¹[j, i]
    let mut trace = 0.0;
    for i in 0..self.dim {
        for j in 0..self.dim {
            trace += self.scale[[i, j]] * x_inv[[j, i]];
        }
    }

    let p = self.dim as f64;
    let nu = self.df;

    let scale_term = 0.5 * nu * self.scale_det.ln();
    let two_term = -0.5 * nu * p * std::f64::consts::LN_2;
    // NOTE(review): assumes `lmultigamma(n, p)` returns ln Γ_p(n/2), i.e. it takes
    // the degrees of freedom rather than the gamma argument. Confirm against
    // `wishart.rs`; if it expects the argument itself, pass `nu / 2.0`.
    let gamma_term = -super::wishart::lmultigamma(nu, self.dim);
    let det_term = -0.5 * (nu + p + 1.0) * ln_det_x;
    let trace_term = -0.5 * trace;

    scale_term + two_term + gamma_term + det_term + trace_term
}
/// Evaluates the natural logarithm of the probability density at `x`.
///
/// Returns negative infinity when `x` is not `dim` x `dim` or when its
/// Cholesky factorisation fails (i.e. `x` is not positive definite).
pub fn logpdf<D>(&self, x: &ArrayBase<D, Ix2>) -> f64
where
    D: Data<Elem = f64>,
{
    if x.dim() != (self.dim, self.dim) {
        return f64::NEG_INFINITY;
    }

    // `compute_cholesky` is called on an owned copy of the argument.
    let owned_point = x.to_owned();
    compute_cholesky(&owned_point)
        .map(|factor| self.logpdf_with_cholesky(x, &factor))
        .unwrap_or(f64::NEG_INFINITY)
}
/// Draws `size` random matrices from the distribution.
///
/// If `A ~ Wishart(Ψ⁻¹, ν)` then `A⁻¹ ~ InverseWishart(Ψ, ν)`, so each draw is
/// the inverse of a Wishart sample generated with the *inverse* of this
/// distribution's scale matrix.
///
/// # Errors
///
/// * Any error returned by `Wishart::new` or `Wishart::rvs`.
/// * `StatsError::ComputationError` if the scale matrix or a Wishart sample
///   cannot be factorised or inverted.
pub fn rvs(&self, size: usize) -> StatsResult<Vec<Array2<f64>>> {
    // Ψ⁻¹ from the Cholesky factor cached at construction time.
    let inv_scale = compute_inverse_from_cholesky(&self.scale_chol).map_err(|_| {
        StatsError::ComputationError("Failed to compute matrix inverse".to_string())
    })?;
    let wishart = Wishart::new(inv_scale, self.df)?;
    let wishart_samples = wishart.rvs(size)?;

    let mut inv_wishart_samples = Vec::with_capacity(size);
    for sample in wishart_samples {
        let sample_chol = compute_cholesky(&sample).map_err(|_| {
            StatsError::ComputationError("Failed to compute Cholesky decomposition".to_string())
        })?;
        let inv_sample = compute_inverse_from_cholesky(&sample_chol).map_err(|_| {
            StatsError::ComputationError("Failed to compute matrix inverse".to_string())
        })?;
        inv_wishart_samples.push(inv_sample);
    }
    Ok(inv_wishart_samples)
}
pub fn rvs_single(&self) -> StatsResult<Array2<f64>> {
let samples = self.rvs(1)?;
Ok(samples[0].clone())
}
/// Returns the mean of the distribution, `scale / (df - dim - 1)`.
///
/// # Errors
///
/// Returns `StatsError::DomainError` when `df <= dim + 1`, where the mean is
/// undefined.
pub fn mean(&self) -> StatsResult<Array2<f64>> {
    let dim_f = self.dim as f64;
    if self.df <= dim_f + 1.0 {
        return Err(StatsError::DomainError(
            "Mean is undefined for degrees of freedom <= dimension + 1".to_string(),
        ));
    }

    let divisor = self.df - dim_f - 1.0;
    Ok(self.scale.mapv(|entry| entry / divisor))
}
/// Returns the mode of the distribution, `scale / (df + dim + 1)`.
pub fn mode(&self) -> Array2<f64> {
    let divisor = self.df + self.dim as f64 + 1.0;
    self.scale.mapv(|entry| entry / divisor)
}
}
/// Convenience constructor: forwards `scale` and `df` to `InverseWishart::new`
/// and returns its result unchanged (including any validation error).
#[allow(dead_code)]
pub fn inverse_wishart<D>(scale: ArrayBase<D, Ix2>, df: f64) -> StatsResult<InverseWishart>
where
D: Data<Elem = f64>,
{
InverseWishart::new(scale, df)
}
impl SampleableDistribution<Array2<f64>> for InverseWishart {
fn rvs(&self, size: usize) -> StatsResult<Vec<Array2<f64>>> {
self.rvs(size)
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray::array;
/// Construction succeeds for valid 2x2 and 3x3 scale matrices and stores the inputs.
#[test]
fn test_inverse_wishart_creation() {
    let cases = vec![
        (array![[1.0, 0.0], [0.0, 1.0]], 5.0, 2),
        (
            array![[2.0, 0.5, 0.3], [0.5, 1.0, 0.1], [0.3, 0.1, 1.5]],
            10.0,
            3,
        ),
    ];

    for (scale, df, expected_dim) in cases {
        let dist = InverseWishart::new(scale.clone(), df).expect("Operation failed");
        assert_eq!(dist.dim, expected_dim);
        assert_eq!(dist.df, df);
        assert_eq!(dist.scale, scale);
    }
}
/// Construction fails for a non-square scale, too few degrees of freedom,
/// and a scale matrix that is not positive definite.
#[test]
fn test_inverse_wishart_creation_errors() {
    // 2x3 matrix: not square.
    let rectangular = array![[1.0, 0.5, 0.3], [0.5, 1.0, 0.1]];
    let rectangular_result = InverseWishart::new(rectangular, 5.0);
    assert!(rectangular_result.is_err());

    // df == dim + 1 is on the rejected side of the bound.
    let identity = array![[1.0, 0.0], [0.0, 1.0]];
    let low_df_result = InverseWishart::new(identity.clone(), 3.0);
    assert!(low_df_result.is_err());

    // Symmetric but indefinite.
    let indefinite = array![[1.0, 2.0], [2.0, 1.0]];
    let indefinite_result = InverseWishart::new(indefinite, 5.0);
    assert!(indefinite_result.is_err());
}
/// The mean equals `scale / (df - dim - 1)`; here `dim = 2`, so the divisor is `df - 3`.
#[test]
fn test_inverse_wishart_mean() {
    let scale = array![[1.0, 0.5], [0.5, 2.0]];
    let df = 5.0;
    let dist = InverseWishart::new(scale.clone(), df).expect("Operation failed");

    let actual = dist.mean().expect("Operation failed");
    let expected = scale / (df - 3.0);

    for ((row, col), &want) in expected.indexed_iter() {
        assert_relative_eq!(actual[[row, col]], want, epsilon = 1e-10);
    }
}
/// The mode equals `scale / (df + dim + 1)`; here `dim = 2`, so the divisor is `df + 3`.
#[test]
fn test_inverse_wishart_mode() {
    let scale = array![[1.0, 0.5], [0.5, 2.0]];
    let df = 5.0;
    let dist = InverseWishart::new(scale.clone(), df).expect("Operation failed");

    let actual = dist.mode();
    let expected = scale / (df + 3.0);

    for (got, want) in actual.iter().zip(expected.iter()) {
        assert_relative_eq!(*got, *want, epsilon = 1e-10);
    }
}
/// The density is strictly positive at positive-definite points and exactly
/// zero at a matrix that is not positive definite.
#[test]
fn test_inverse_wishart_pdf() {
    let dist =
        InverseWishart::new(array![[1.0, 0.0], [0.0, 1.0]], 5.0).expect("Operation failed");

    let inside_support = vec![
        array![[1.0, 0.0], [0.0, 1.0]],
        array![[0.5, 0.1], [0.1, 0.8]],
    ];
    for point in &inside_support {
        assert!(dist.pdf(point) > 0.0);
    }

    let indefinite = array![[1.0, 2.0], [2.0, 1.0]];
    assert_eq!(dist.pdf(&indefinite), 0.0);
}
/// `logpdf` agrees with `ln(pdf)` at a positive-definite point and is negative
/// infinity at a matrix that is not positive definite.
#[test]
fn test_inverse_wishart_logpdf() {
    let dist =
        InverseWishart::new(array![[1.0, 0.0], [0.0, 1.0]], 5.0).expect("Operation failed");

    let point = array![[0.5, 0.1], [0.1, 0.8]];
    let density = dist.pdf(&point);
    let log_density = dist.logpdf(&point);
    assert_relative_eq!(log_density.exp(), density, epsilon = 1e-10);

    let indefinite = array![[1.0, 2.0], [2.0, 1.0]];
    assert_eq!(dist.logpdf(&indefinite), f64::NEG_INFINITY);
}
/// Sampling returns the requested number of 2x2 matrices whose average is
/// within a deliberately loose tolerance of `scale / (df - 3)`.
#[test]
fn test_inverse_wishart_rvs() {
    let scale = array![[1.0, 0.5], [0.5, 2.0]];
    let df = 5.0;
    let dist = InverseWishart::new(scale.clone(), df).expect("Operation failed");

    let draw_count = 100;
    let draws = dist.rvs(draw_count).expect("Operation failed");
    assert_eq!(draws.len(), draw_count);
    for draw in &draws {
        assert_eq!(draw.shape(), &[2, 2]);
    }

    // Accumulate the draws in order, then divide by the count.
    let total = draws
        .iter()
        .fold(Array2::<f64>::zeros((2, 2)), |acc, draw| acc + draw);
    let sample_mean = total / draw_count as f64;
    let expected_mean = scale / (df - 3.0);

    // The tolerance is wide because the check is statistical, on only 100 draws.
    for ((row, col), &want) in expected_mean.indexed_iter() {
        assert_relative_eq!(
            sample_mean[[row, col]],
            want,
            epsilon = 1.0,
            max_relative = 0.8
        );
    }
}
/// A single draw is a 2x2 matrix with equal off-diagonal entries that
/// `compute_cholesky` accepts.
#[test]
fn test_inverse_wishart_rvs_single() {
    let dist =
        InverseWishart::new(array![[1.0, 0.5], [0.5, 2.0]], 5.0).expect("Operation failed");

    let draw = dist.rvs_single().expect("Operation failed");

    assert_eq!(draw.shape(), &[2, 2]);
    let upper = draw[[0, 1]];
    let lower = draw[[1, 0]];
    assert_relative_eq!(upper, lower, epsilon = 1e-10);
    assert!(compute_cholesky(&draw).is_ok());
}
}