use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::numeric::{Float, One, Zero};
use std::f64::consts::PI;
use crate::basic::{det, inv};
use crate::decomposition::cholesky;
use crate::error::{LinalgError, LinalgResult};
use crate::random::random_normalmatrix;
/// Parameters of a matrix normal distribution `MN(M, U, V)` over `m × n` matrices.
#[derive(Debug, Clone)]
pub struct MatrixNormalParams<F: Float> {
    /// Mean matrix `M` (`m × n`).
    pub mean: Array2<F>,
    /// Row covariance `U` (`m × m`).
    pub row_cov: Array2<F>,
    /// Column covariance `V` (`n × n`).
    pub col_cov: Array2<F>,
}
impl<F: Float + Zero + One + Copy + std::fmt::Debug + std::fmt::Display> MatrixNormalParams<F> {
    /// Builds matrix normal parameters after checking that the covariance shapes
    /// agree with the shape of `mean`.
    ///
    /// Only shapes are validated; positive definiteness of the covariances is not
    /// checked here.
    ///
    /// # Errors
    /// `LinalgError::ShapeError` if `row_cov` is not `m × m` or `col_cov` is not
    /// `n × n`, where `(m, n)` is the shape of `mean`.
    pub fn new(mean: Array2<F>, row_cov: Array2<F>, col_cov: Array2<F>) -> LinalgResult<Self> {
        let (m, n) = mean.dim();
        if row_cov.dim() != (m, m) {
            return Err(LinalgError::ShapeError(format!(
                "Row covariance must be {}x{}, got {:?}",
                m,
                m,
                row_cov.dim()
            )));
        }
        if col_cov.dim() != (n, n) {
            return Err(LinalgError::ShapeError(format!(
                "Column covariance must be {}x{}, got {:?}",
                n,
                n,
                col_cov.dim()
            )));
        }
        Ok(Self {
            mean,
            row_cov,
            col_cov,
        })
    }
}
/// Parameters of a Wishart distribution `W_p(V, ν)` over `p × p` matrices.
#[derive(Debug, Clone)]
pub struct WishartParams<F: Float> {
    /// Scale matrix `V` (`p × p`).
    pub scale: Array2<F>,
    /// Degrees of freedom `ν`; must exceed `p − 1`.
    pub dof: F,
}
impl<F: Float + Zero + One + Copy + std::fmt::Debug + std::fmt::Display> WishartParams<F> {
    /// Builds Wishart parameters, validating the scale shape and the degrees of freedom.
    ///
    /// # Errors
    /// * `LinalgError::ShapeError` if `scale` is not square.
    /// * `LinalgError::InvalidInputError` if `dof` is NaN or `dof ≤ p − 1`.
    pub fn new(scale: Array2<F>, dof: F) -> LinalgResult<Self> {
        let p = scale.nrows();
        if scale.nrows() != scale.ncols() {
            return Err(LinalgError::ShapeError(
                "Scale matrix must be square".to_string(),
            ));
        }
        let min_dof = F::from(p).expect("Failed to convert to float") - F::one();
        // `dof <= min_dof` is false for NaN, so NaN has to be rejected explicitly.
        if dof.is_nan() || dof <= min_dof {
            return Err(LinalgError::InvalidInputError(format!(
                "Degrees of freedom must be > {min_dof}, got {dof:?}"
            )));
        }
        Ok(Self { scale, dof })
    }
}
/// Log-density of the matrix normal distribution `MN(M, U, V)` at `x`.
///
/// ```text
/// ln p(X) = −(mn/2) ln(2π) − (n/2) ln|U| − (m/2) ln|V|
///           − tr(V⁻¹ (X − M)ᵀ U⁻¹ (X − M)) / 2
/// ```
///
/// where `x` and `M = params.mean` are `m × n`, `U = params.row_cov` is `m × m`
/// and `V = params.col_cov` is `n × n`.
///
/// # Errors
/// * `ShapeError` if `x`, the mean, or the covariances have inconsistent shapes.
/// * `ComputationError` if `ln|U|` or `ln|V|` is not finite (a zero or negative
///   determinant, which rules out positive definiteness). A positive determinant
///   is necessary but not sufficient for positive definiteness.
/// * Errors from `det` / `inv` are propagated.
#[allow(dead_code)]
pub fn matrix_normal_logpdf<F>(x: &ArrayView2<F>, params: &MatrixNormalParams<F>) -> LinalgResult<F>
where
    F: Float
        + Zero
        + One
        + Copy
        + std::fmt::Debug
        + scirs2_core::ndarray::ScalarOperand
        + scirs2_core::numeric::FromPrimitive
        + scirs2_core::numeric::NumAssign
        + std::iter::Sum
        + Send
        + Sync
        + 'static,
{
    let (m, n) = x.dim();
    if params.mean.dim() != (m, n) {
        return Err(LinalgError::ShapeError(format!(
            "Matrix dimensions don't match: x is {}x{}, mean is {:?}",
            m,
            n,
            params.mean.dim()
        )));
    }
    // The fields are public, so re-check the covariance shapes instead of letting
    // the matrix products below panic on a mismatch.
    if params.row_cov.dim() != (m, m) || params.col_cov.dim() != (n, n) {
        return Err(LinalgError::ShapeError(format!(
            "Covariance shapes must be {m}x{m} and {n}x{n}, got {:?} and {:?}",
            params.row_cov.dim(),
            params.col_cov.dim()
        )));
    }

    let log_det_u = det(&params.row_cov.view(), None)?.ln();
    if !log_det_u.is_finite() {
        return Err(LinalgError::ComputationError(
            "Row covariance must be positive definite".to_string(),
        ));
    }
    let log_det_v = det(&params.col_cov.view(), None)?.ln();
    if !log_det_v.is_finite() {
        return Err(LinalgError::ComputationError(
            "Column covariance must be positive definite".to_string(),
        ));
    }

    let centered = x - &params.mean;
    let u_inv = inv(&params.row_cov.view(), None)?;
    let v_inv = inv(&params.col_cov.view(), None)?;
    // tr(V⁻¹ (X − M)ᵀ U⁻¹ (X − M)); the inner product is n × n.
    let quad_form = v_inv
        .dot(&centered.t().dot(&u_inv).dot(&centered))
        .diag()
        .sum();

    let half = F::from(0.5).expect("Failed to convert constant to float");
    let m_f = F::from(m).expect("Failed to convert to float");
    let n_f = F::from(n).expect("Failed to convert to float");
    let log_2pi = F::from(2.0 * PI).expect("Failed to convert to float").ln();
    let normalizer = -half * (m_f * n_f * log_2pi + n_f * log_det_u + m_f * log_det_v);
    Ok(normalizer - half * quad_form)
}
/// Log-density of the Wishart distribution `W_p(V, ν)` at `x`.
///
/// ```text
/// ln p(X) = ((ν − p − 1)/2) ln|X| − tr(V⁻¹ X)/2
///           − (νp/2) ln 2 − (ν/2) ln|V| − ln Γ_p(ν/2)
/// ```
///
/// where `V = params.scale` and `ν = params.dof`. `ln Γ_p` (including its
/// `p(p − 1)/4 · ln π` term) comes from `multivariate_log_gamma`.
///
/// # Errors
/// * `ShapeError` if `x` is not square or `params.scale` does not match it.
/// * `ComputationError` if `ln|X|` or `ln|V|` is not finite (zero or negative
///   determinant).
/// * Errors from `det`, `inv` and `multivariate_log_gamma` are propagated.
#[allow(dead_code)]
pub fn wishart_logpdf<F>(x: &ArrayView2<F>, params: &WishartParams<F>) -> LinalgResult<F>
where
    F: Float
        + Zero
        + One
        + Copy
        + std::fmt::Debug
        + scirs2_core::ndarray::ScalarOperand
        + scirs2_core::numeric::FromPrimitive
        + scirs2_core::numeric::NumAssign
        + std::iter::Sum
        + Send
        + Sync
        + 'static,
{
    let p = x.nrows();
    if x.nrows() != x.ncols() {
        return Err(LinalgError::ShapeError(
            "Matrix must be square for Wishart distribution".to_string(),
        ));
    }
    if params.scale.dim() != (p, p) {
        return Err(LinalgError::ShapeError(format!(
            "Scale matrix dimension mismatch: expected {}x{}, got {:?}",
            p,
            p,
            params.scale.dim()
        )));
    }

    let log_det_x = det(x, None)?.ln();
    if !log_det_x.is_finite() {
        return Err(LinalgError::ComputationError(
            "Matrix must be positive definite".to_string(),
        ));
    }
    let log_det_v = det(&params.scale.view(), None)?.ln();
    if !log_det_v.is_finite() {
        return Err(LinalgError::ComputationError(
            "Scale matrix must be positive definite".to_string(),
        ));
    }
    let v_inv = inv(&params.scale.view(), None)?;
    let trace_term = v_inv.dot(x).diag().sum();

    let half = F::from(0.5).expect("Failed to convert constant to float");
    let p_f = F::from(p).expect("Failed to convert to float");
    let nu = params.dof;
    // The multivariate gamma is evaluated at ν/2, and it already contains the
    // p(p − 1)/4 · ln π term, so that term must not be added again here.
    let log_gamma_p = multivariate_log_gamma(nu * half, p)?;
    let log_2 = F::from(2.0)
        .expect("Failed to convert constant to float")
        .ln();
    let log_normalizer = half * nu * p_f * log_2 + half * nu * log_det_v + log_gamma_p;
    let main_term = half * (nu - p_f - F::one()) * log_det_x - half * trace_term;
    Ok(main_term - log_normalizer)
}
/// Draws one sample from the matrix normal distribution described by `params`.
///
/// A matrix `Z` of normal variates, shaped like `params.mean`, is produced by
/// `random_normalmatrix` and then coloured as `M + L_U · Z · L_Vᵀ`. Here `L_U` and
/// `L_V` are the factors that `cholesky` returns for the row and column covariances.
///
/// # Errors
/// Propagates any error from `random_normalmatrix` or `cholesky`.
#[allow(dead_code)]
pub fn samplematrix_normal<F>(
    params: &MatrixNormalParams<F>,
    rng_seed: Option<u64>,
) -> LinalgResult<Array2<F>>
where
    F: Float
        + Zero
        + One
        + Copy
        + std::fmt::Debug
        + scirs2_core::ndarray::ScalarOperand
        + scirs2_core::numeric::FromPrimitive
        + scirs2_core::numeric::NumAssign
        + std::iter::Sum
        + Send
        + Sync
        + 'static,
{
    let shape = params.mean.dim();
    let white = random_normalmatrix(shape, rng_seed)?;
    let row_factor = cholesky(&params.row_cov.view(), None)?;
    let col_factor = cholesky(&params.col_cov.view(), None)?;
    // Correlate rows on the left and columns on the right.
    let coloured = row_factor.dot(&white).dot(&col_factor.t());
    Ok(&params.mean + &coloured)
}
#[allow(dead_code)]
pub fn sample_wishart<F>(
params: &WishartParams<F>,
rng_seed: Option<u64>,
) -> LinalgResult<Array2<F>>
where
F: Float
+ Zero
+ One
+ Copy
+ std::fmt::Debug
+ scirs2_core::ndarray::ScalarOperand
+ scirs2_core::numeric::FromPrimitive
+ scirs2_core::numeric::NumAssign
+ std::iter::Sum
+ Send
+ Sync
+ 'static,
{
let p = params.scale.nrows();
let mut a = Array2::zeros((p, p));
let z = random_normalmatrix::<F>((p, p), rng_seed)?;
for i in 0..p {
for j in 0..=i {
if i == j {
let chi_approx = z[[i, j]].abs()
* (params.dof - F::from(i).expect("Failed to convert to float")).sqrt();
a[[i, j]] = chi_approx;
} else {
a[[i, j]] = z[[i, j]];
}
}
}
let l = cholesky(¶ms.scale.view(), None)?;
let temp = l.dot(&a);
let sample = temp.dot(&temp.t());
Ok(sample)
}
#[allow(dead_code)]
fn multivariate_log_gamma<F>(x: F, p: usize) -> LinalgResult<F>
where
F: Float + Zero + One + Copy + std::fmt::Debug + scirs2_core::numeric::FromPrimitive,
{
let log_pi = F::from(PI).expect("Failed to convert to float").ln();
let mut result = F::from(p * (p - 1)).expect("Operation failed")
* F::from(0.25).expect("Failed to convert constant to float")
* log_pi;
for j in 1..=p {
let arg = x
+ (F::one() - F::from(j).expect("Failed to convert to float"))
* F::from(0.5).expect("Failed to convert constant to float");
let log_gamma_approx = if arg > F::one() {
(arg - F::from(0.5).expect("Failed to convert constant to float")) * arg.ln() - arg
+ F::from(0.5).expect("Failed to convert constant to float")
* F::from(2.0 * PI).expect("Failed to convert to float").ln()
} else {
F::zero() };
result = result + log_gamma_approx;
}
Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn testmatrix_normal_params() {
        let mean = array![[1.0, 2.0], [3.0, 4.0]];
        let row_cov = array![[1.0, 0.0], [0.0, 1.0]];
        let col_cov = array![[2.0, 0.0], [0.0, 2.0]];
        let params = MatrixNormalParams::new(mean, row_cov, col_cov).expect("Operation failed");
        assert_eq!(params.mean.dim(), (2, 2));
        assert_eq!(params.row_cov.dim(), (2, 2));
        assert_eq!(params.col_cov.dim(), (2, 2));
    }

    #[test]
    fn testmatrix_normal_params_rejects_mismatched_covariances() {
        let mean = array![[1.0, 2.0], [3.0, 4.0]];
        let identity2 = array![[1.0, 0.0], [0.0, 1.0]];
        let identity3 = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        assert!(
            MatrixNormalParams::new(mean.clone(), identity3.clone(), identity2.clone()).is_err()
        );
        assert!(MatrixNormalParams::new(mean, identity2, identity3).is_err());
    }

    #[test]
    fn test_wishart_params() {
        let scale = array![[2.0, 0.0], [0.0, 2.0]];
        let dof = 3.0;
        let params = WishartParams::new(scale, dof).expect("Operation failed");
        assert_abs_diff_eq!(params.dof, 3.0, epsilon = 1e-10);
        assert_eq!(params.scale.dim(), (2, 2));
        // ν must exceed p − 1 = 1.
        assert!(WishartParams::new(params.scale.clone(), 1.0).is_err());
    }

    #[test]
    fn testmatrix_normal_logpdf() {
        let x = array![[1.0, 0.0], [0.0, 1.0]];
        let mean = array![[0.0, 0.0], [0.0, 0.0]];
        let row_cov = array![[1.0, 0.0], [0.0, 1.0]];
        let col_cov = array![[1.0, 0.0], [0.0, 1.0]];
        let params = MatrixNormalParams::new(mean, row_cov, col_cov).expect("Operation failed");
        let logpdf = matrix_normal_logpdf(&x.view(), &params).expect("Operation failed");
        // M = 0, U = V = I₂, X = I₂: ln p = −(mn/2) ln(2π) − tr(XᵀX)/2 = −2 ln(2π) − 1.
        let expected = -2.0 * (2.0 * std::f64::consts::PI).ln() - 1.0;
        assert_abs_diff_eq!(logpdf, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_samplematrix_normal() {
        let mean = array![[0.0, 0.0], [0.0, 0.0]];
        let row_cov = array![[1.0, 0.0], [0.0, 1.0]];
        let col_cov = array![[1.0, 0.0], [0.0, 1.0]];
        let params = MatrixNormalParams::new(mean, row_cov, col_cov).expect("Operation failed");
        let sample = samplematrix_normal(&params, Some(42)).expect("Operation failed");
        assert_eq!(sample.dim(), (2, 2));
        assert!(sample.iter().all(|&x| x.is_finite()));
    }
}
/// Parameters of an inverse Wishart distribution `W⁻¹_p(Ψ, ν)` over `p × p` matrices.
#[derive(Debug, Clone)]
pub struct InverseWishartParams<F: Float> {
    /// Scale matrix `Ψ` (`p × p`).
    pub scale: Array2<F>,
    /// Degrees of freedom `ν`; must exceed `p − 1`.
    pub dof: F,
}
impl<F: Float + Zero + One + Copy + std::fmt::Debug + std::fmt::Display> InverseWishartParams<F> {
    /// Builds inverse Wishart parameters, validating the scale shape and the
    /// degrees of freedom.
    ///
    /// # Errors
    /// * `LinalgError::ShapeError` if `scale` is not square.
    /// * `LinalgError::InvalidInputError` if `dof` is NaN or `dof ≤ p − 1`.
    pub fn new(scale: Array2<F>, dof: F) -> LinalgResult<Self> {
        if scale.nrows() != scale.ncols() {
            return Err(LinalgError::ShapeError(format!(
                "Scale matrix must be square, got shape {:?}",
                scale.shape()
            )));
        }
        let p = F::from(scale.nrows()).expect("Operation failed");
        // `dof <= p - 1` is false for NaN, so NaN has to be rejected explicitly.
        if dof.is_nan() || dof <= p - F::one() {
            return Err(LinalgError::InvalidInputError(format!(
                "Degrees of freedom must be > dimension - 1, got dof = {} for dimension {}",
                dof,
                scale.nrows()
            )));
        }
        Ok(InverseWishartParams { scale, dof })
    }
}
/// Log-density of the inverse Wishart distribution `W⁻¹_p(Ψ, ν)` at `x`.
///
/// ```text
/// ln p(X) = (ν/2) ln|Ψ| − (νp/2) ln 2 − ln Γ_p(ν/2)
///           − ((ν + p + 1)/2) ln|X| − tr(Ψ X⁻¹)/2
/// ```
///
/// where `Ψ = params.scale` and `ν = params.dof`. The univariate `ln Γ` terms of
/// `ln Γ_p` use a Lanczos approximation.
///
/// # Errors
/// * `ShapeError` if `x` is not square or `params.scale` does not match its shape.
/// * `InvalidInputError` if `ν` is NaN or `ν ≤ p − 1`.
/// * `ComputationError` if `ln|X|` or `ln|Ψ|` is not finite (zero or negative
///   determinant).
/// * Errors from `det` / `inv` are propagated.
#[allow(dead_code)]
pub fn inverse_wishart_logpdf<F>(
    x: &ArrayView2<F>,
    params: &InverseWishartParams<F>,
) -> LinalgResult<F>
where
    F: Float
        + Zero
        + One
        + Copy
        + std::fmt::Debug
        + std::fmt::Display
        + scirs2_core::numeric::NumAssign
        + std::iter::Sum
        + 'static
        + Send
        + Sync
        + scirs2_core::ndarray::ScalarOperand,
{
    // ln Γ(z) via Lanczos (g = 7, n = 9), with reflection for z < 1/2.
    fn ln_gamma<T: Float>(z: T) -> T {
        // Published Lanczos coefficients, kept verbatim.
        #[allow(clippy::excessive_precision)]
        const LANCZOS: [f64; 9] = [
            0.99999999999980993,
            676.5203681218851,
            -1259.1392167224028,
            771.32342877765313,
            -176.61502916214059,
            12.507343278686905,
            -0.13857109526572012,
            9.9843695780195716e-6,
            1.5056327351493116e-7,
        ];
        const G: f64 = 7.0;
        let c = |v: f64| T::from(v).expect("constant must fit the float type");
        let half = c(0.5);
        if z < half {
            // Reflection formula: Γ(z) Γ(1 − z) = π / sin(πz).
            let pi = c(PI);
            return pi.ln() - (pi * z).sin().abs().ln() - ln_gamma(T::one() - z);
        }
        let w = z - T::one();
        let series = LANCZOS
            .iter()
            .enumerate()
            .skip(1)
            .fold(c(LANCZOS[0]), |acc, (k, &coef)| acc + c(coef) / (w + c(k as f64)));
        let t = w + c(G) + half;
        half * c(2.0 * PI).ln() + (w + half) * t.ln() - t + series.ln()
    }

    let dim = x.nrows();
    if x.ncols() != dim {
        return Err(LinalgError::ShapeError(format!(
            "Matrix must be square for inverse Wishart distribution, got {:?}",
            x.dim()
        )));
    }
    // Without this check `params.scale.dot(&x_inv)` below panics on a mismatch.
    if params.scale.dim() != (dim, dim) {
        return Err(LinalgError::ShapeError(format!(
            "Scale matrix dimension mismatch: expected {dim}x{dim}, got {:?}",
            params.scale.dim()
        )));
    }
    let p = F::from(dim).expect("Operation failed");
    let nu = params.dof;
    // The fields are public; ν > p − 1 keeps every ln Γ argument below positive.
    if nu.is_nan() || nu <= p - F::one() {
        return Err(LinalgError::InvalidInputError(format!(
            "Degrees of freedom must be > dimension - 1, got dof = {} for dimension {}",
            nu, dim
        )));
    }

    let log_det_x = det(x, None)?.ln();
    if !log_det_x.is_finite() {
        return Err(LinalgError::ComputationError(
            "Matrix must be positive definite".to_string(),
        ));
    }
    let log_det_psi = det(&params.scale.view(), None)?.ln();
    if !log_det_psi.is_finite() {
        return Err(LinalgError::ComputationError(
            "Scale matrix must be positive definite".to_string(),
        ));
    }
    let x_inv = inv(x, None)?;
    let trace_psi_x_inv = params.scale.dot(&x_inv).diag().sum();

    let half = F::from(0.5).expect("Failed to convert constant to float");
    let two = F::from(2.0).expect("Failed to convert constant to float");
    let pi = F::from(PI).expect("Failed to convert to float");
    let log_norm = half * nu * log_det_psi
        - half * nu * p * two.ln()
        - F::from(0.25).expect("Failed to convert constant to float")
            * p
            * (p - F::one())
            * pi.ln();
    // Σ_{j=0}^{p−1} ln Γ((ν − j)/2). Together with the π term in `log_norm`
    // this gives ln Γ_p(ν/2).
    let log_gamma_p: F = (0..dim)
        .map(|j| ln_gamma(half * (nu - F::from(j).expect("Failed to convert to float"))))
        .sum();
    Ok(log_norm - log_gamma_p - half * (nu + p + F::one()) * log_det_x - half * trace_psi_x_inv)
}
/// Parameters of a matrix-variate t distribution over `n × p` matrices.
#[derive(Debug, Clone)]
pub struct MatrixTParams<F: Float> {
    /// Location matrix `M` (`n × p`).
    pub location: Array2<F>,
    /// Row scale matrix `U` (`n × n`).
    pub scale_u: Array2<F>,
    /// Column scale matrix `V` (`p × p`).
    pub scale_v: Array2<F>,
    /// Degrees of freedom `ν` (must be positive).
    pub dof: F,
}
impl<F: Float + Zero + One + Copy + std::fmt::Debug + std::fmt::Display> MatrixTParams<F> {
    /// Builds matrix t parameters, validating shapes and degrees of freedom.
    ///
    /// # Errors
    /// * `LinalgError::ShapeError` if the scale matrices do not match the location's
    ///   row/column counts or are not square.
    /// * `LinalgError::InvalidInputError` if `dof` is NaN or not positive.
    pub fn new(
        location: Array2<F>,
        scale_u: Array2<F>,
        scale_v: Array2<F>,
        dof: F,
    ) -> LinalgResult<Self> {
        if location.nrows() != scale_u.nrows() || location.ncols() != scale_v.nrows() {
            return Err(LinalgError::ShapeError(
                "Incompatible matrix dimensions".to_string(),
            ));
        }
        if scale_u.nrows() != scale_u.ncols() || scale_v.nrows() != scale_v.ncols() {
            return Err(LinalgError::ShapeError(
                "Scale matrices must be square".to_string(),
            ));
        }
        // `dof <= 0` is false for NaN, so NaN has to be rejected explicitly.
        if dof.is_nan() || dof <= F::zero() {
            return Err(LinalgError::InvalidInputError(
                "Degrees of freedom must be positive".to_string(),
            ));
        }
        Ok(MatrixTParams {
            location,
            scale_u,
            scale_v,
            dof,
        })
    }
}
#[allow(dead_code)]
pub fn matrix_t_logpdf<F>(x: &ArrayView2<F>, params: &MatrixTParams<F>) -> LinalgResult<F>
where
F: Float
+ Zero
+ One
+ Copy
+ std::fmt::Debug
+ std::fmt::Display
+ scirs2_core::numeric::NumAssign
+ std::iter::Sum
+ 'static
+ Send
+ Sync
+ scirs2_core::ndarray::ScalarOperand,
{
let (n, p) = (x.nrows(), x.ncols());
let nu = params.dof;
let residual = x - ¶ms.location;
let u_inv = inv(¶ms.scale_u.view(), None)?;
let v_inv = inv(¶ms.scale_v.view(), None)?;
let temp1 = u_inv.dot(&residual);
let temp2 = temp1.t().dot(&residual);
let temp3 = temp2.dot(&v_inv);
let quadratic_form = (0..temp3.nrows()).map(|i| temp3[[i, i]]).sum::<F>();
let log_det_u = det(¶ms.scale_u.view(), None)?.ln();
let log_det_v = det(¶ms.scale_v.view(), None)?.ln();
let half = F::from(0.5).expect("Failed to convert constant to float");
let pi = F::from(PI).expect("Failed to convert to float");
let n_f = F::from(n).expect("Failed to convert to float");
let p_f = F::from(p).expect("Failed to convert to float");
let log_norm = -half * n_f * log_det_u - half * p_f * log_det_v - half * n_f * p_f * pi.ln();
let log_density =
log_norm - half * (nu + n_f + p_f - F::one()) * (F::one() + quadratic_form / nu).ln();
Ok(log_density)
}
#[cfg(test)]
mod extended_tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_inverse_wishart_params() {
        let scale = array![[2.0, 0.5], [0.5, 1.0]];
        let dof = 5.0;
        let params = InverseWishartParams::new(scale, dof).expect("Operation failed");
        assert_eq!(params.dof, 5.0);
        let invalid_params = InverseWishartParams::new(params.scale.clone(), 1.0);
        assert!(invalid_params.is_err());
    }

    #[test]
    fn test_inverse_wishart_params_rejects_non_square_scale() {
        let scale = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        assert!(InverseWishartParams::new(scale, 5.0).is_err());
    }

    #[test]
    fn testmatrix_t_params() {
        let location = array![[0.0, 0.0], [0.0, 0.0]];
        let scale_u = array![[1.0, 0.0], [0.0, 1.0]];
        let scale_v = array![[1.0, 0.0], [0.0, 1.0]];
        let dof = 3.0;
        let params = MatrixTParams::new(location, scale_u, scale_v, dof).expect("Operation failed");
        assert_eq!(params.dof, 3.0);
        let invalid_params = MatrixTParams::new(
            params.location.clone(),
            params.scale_u.clone(),
            params.scale_v.clone(),
            -1.0,
        );
        assert!(invalid_params.is_err());
    }

    #[test]
    fn testmatrix_t_params_rejects_incompatible_shapes() {
        let location = array![[0.0, 0.0], [0.0, 0.0]];
        let scale_u = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let scale_v = array![[1.0, 0.0], [0.0, 1.0]];
        assert!(MatrixTParams::new(location, scale_u, scale_v, 3.0).is_err());
    }

    #[test]
    fn test_inverse_wishart_logpdf() {
        let x = array![[2.0, 0.5], [0.5, 1.5]];
        let scale = array![[1.0, 0.0], [0.0, 1.0]];
        let params = InverseWishartParams::new(scale, 5.0).expect("Operation failed");
        let logpdf = inverse_wishart_logpdf(&x.view(), &params).expect("Operation failed");
        assert!(logpdf.is_finite());
    }

    #[test]
    fn testmatrix_t_logpdf() {
        let x = array![[1.0, 0.5], [0.5, 1.0]];
        let location = array![[0.0, 0.0], [0.0, 0.0]];
        let scale_u = array![[1.0, 0.0], [0.0, 1.0]];
        let scale_v = array![[1.0, 0.0], [0.0, 1.0]];
        let params = MatrixTParams::new(location, scale_u, scale_v, 3.0).expect("Operation failed");
        let logpdf = matrix_t_logpdf(&x.view(), &params).expect("Operation failed");
        assert!(logpdf.is_finite());
    }
}