use crate::error::{StatsError, StatsResult};
use crate::sampling::SampleableDistribution;
use crate::traits::{Distribution as DistributionTrait, MultivariateDistribution};
use scirs2_core::ndarray::{
s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix1, Ix2,
};
use scirs2_core::random::prelude::*;
use scirs2_core::random::{Distribution as RandDistribution, Normal as RandNormal};
use statrs::statistics::Statistics;
use std::fmt::Debug;
/// A multivariate normal (Gaussian) distribution parameterized by a mean
/// vector and a positive-definite covariance matrix.
///
/// The Cholesky factor, determinant, and inverse of the covariance are
/// precomputed at construction so density evaluation and sampling are cheap.
#[derive(Debug, Clone)]
pub struct MultivariateNormal {
    // Mean vector (length `dim`).
    pub mean: Array1<f64>,
    // Covariance matrix (`dim` x `dim`, symmetric positive definite).
    pub cov: Array2<f64>,
    // Dimensionality of the distribution.
    pub dim: usize,
    // Lower-triangular Cholesky factor L with cov = L * L^T; used for
    // sampling and for deriving the determinant and inverse.
    cholesky_l: Array2<f64>,
    // Cached determinant of `cov`, used by pdf/logpdf normalizing constants.
    cov_det: f64,
    // Cached inverse of `cov`, used for Mahalanobis distances.
    cov_inv: Array2<f64>,
}
impl MultivariateNormal {
    /// Constructs a multivariate normal distribution from a mean vector and
    /// a covariance matrix.
    ///
    /// The Cholesky factor, determinant, and inverse of the covariance are
    /// precomputed here so that `pdf`, `logpdf`, and `rvs` are cheap.
    ///
    /// # Errors
    ///
    /// * `StatsError::DimensionMismatch` if `cov` is not square with side
    ///   `mean.len()`.
    /// * `StatsError::DomainError` if `cov` is not positive definite.
    /// * `StatsError::ComputationError` if the inverse cannot be computed.
    pub fn new<D1, D2>(mean: ArrayBase<D1, Ix1>, cov: ArrayBase<D2, Ix2>) -> StatsResult<Self>
    where
        D1: Data<Elem = f64>,
        D2: Data<Elem = f64>,
    {
        let dim = mean.len();
        if cov.shape()[0] != dim || cov.shape()[1] != dim {
            return Err(StatsError::DimensionMismatch(format!(
                "Covariance matrix shape ({:?}) must match mean vector length ({})",
                cov.shape(),
                dim
            )));
        }
        let mean = mean.to_owned();
        let cov = cov.to_owned();
        let cholesky_l = compute_cholesky(&cov).map_err(|_| {
            StatsError::DomainError("Covariance matrix must be positive definite".to_string())
        })?;
        // det(cov) = det(L)^2, and det(L) is the product of its diagonal.
        let diag_prod: f64 = (0..dim).map(|i| cholesky_l[[i, i]]).product();
        let cov_det = diag_prod * diag_prod;
        let cov_inv = compute_inverse_from_cholesky(&cholesky_l).map_err(|_| {
            StatsError::ComputationError("Failed to compute matrix inverse".to_string())
        })?;
        Ok(MultivariateNormal {
            mean,
            cov,
            dim,
            cholesky_l,
            cov_det,
            cov_inv,
        })
    }
    /// Probability density function evaluated at `x`.
    ///
    /// Returns `0.0` if `x` has the wrong dimensionality.
    pub fn pdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
    where
        D: Data<Elem = f64>,
    {
        if x.len() != self.dim {
            return 0.0;
        }
        let two_pi = 2.0 * std::f64::consts::PI;
        // Normalizing constant: (2*pi)^(-k/2) * |cov|^(-1/2).
        let norm_const = 1.0 / (two_pi.powf(self.dim as f64 / 2.0) * self.cov_det.sqrt());
        let diff = x - &self.mean;
        let maha_sq = self.mahalanobis_distance_squared(&diff.view());
        norm_const * (-maha_sq / 2.0).exp()
    }
    /// Squared Mahalanobis distance (x - mu)^T cov^{-1} (x - mu) for a
    /// centered difference vector `diff`.
    fn mahalanobis_distance_squared(&self, diff: &ArrayView1<f64>) -> f64 {
        diff.dot(&self.cov_inv.dot(diff))
    }
    /// Draws `size` samples; each row of the returned matrix is one sample.
    ///
    /// Uses the standard transformation x = mean + L * z, where z is a
    /// vector of independent standard normals and L is the Cholesky factor
    /// of the covariance.
    pub fn rvs(&self, size: usize) -> StatsResult<Array2<f64>> {
        let mut rng = thread_rng();
        let normal =
            RandNormal::new(0.0, 1.0).expect("standard normal parameters (0, 1) are valid");
        let mut samples = Array2::<f64>::zeros((size, self.dim));
        for i in 0..size {
            // Independent standard normal draws for this sample.
            let z: Array1<f64> = (0..self.dim).map(|_| normal.sample(&mut rng)).collect();
            // Multiply by the lower-triangular L and shift by the mean;
            // only k <= j contributes because L is lower triangular.
            for j in 0..self.dim {
                let mut acc = 0.0;
                for k in 0..=j {
                    acc += self.cholesky_l[[j, k]] * z[k];
                }
                samples[[i, j]] = self.mean[j] + acc;
            }
        }
        Ok(samples)
    }
    /// Draws a single sample as a vector.
    pub fn rvs_single(&self) -> StatsResult<Array1<f64>> {
        let samples = self.rvs(1)?;
        Ok(samples.index_axis(Axis(0), 0).to_owned())
    }
    /// Natural logarithm of the density at `x`.
    ///
    /// Returns negative infinity if `x` has the wrong dimensionality.
    pub fn logpdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
    where
        D: Data<Elem = f64>,
    {
        if x.len() != self.dim {
            return f64::NEG_INFINITY;
        }
        let two_pi = 2.0 * std::f64::consts::PI;
        // Log normalizing constant: -k/2 * ln(2*pi) - 1/2 * ln|cov|.
        let log_const = -(self.dim as f64) / 2.0 * two_pi.ln() - self.cov_det.ln() / 2.0;
        let diff = x - &self.mean;
        log_const - self.mahalanobis_distance_squared(&diff.view()) / 2.0
    }
    /// Dimensionality of the distribution.
    pub fn dim(&self) -> usize {
        self.dim
    }
    /// View of the covariance matrix.
    pub fn cov(&self) -> ArrayView2<f64> {
        self.cov.view()
    }
    /// View of the mean vector.
    pub fn mean(&self) -> ArrayView1<f64> {
        self.mean.view()
    }
}
/// Computes the lower-triangular Cholesky factor L of a symmetric
/// positive-definite matrix `a`, so that a = L * L^T.
///
/// Uses the Cholesky-Banachiewicz scheme, filling L row by row.
///
/// # Errors
///
/// Returns an error if a non-positive pivot is encountered, i.e. the
/// matrix is not positive definite.
#[allow(dead_code)]
pub fn compute_cholesky(a: &Array2<f64>) -> Result<Array2<f64>, String> {
    let n = a.shape()[0];
    let mut l = Array2::<f64>::zeros((n, n));
    for row in 0..n {
        // Off-diagonal entries of this row, using previously computed rows.
        for col in 0..row {
            let dot: f64 = (0..col).map(|k| l[[row, k]] * l[[col, k]]).sum();
            l[[row, col]] = (a[[row, col]] - dot) / l[[col, col]];
        }
        // Diagonal entry; the pivot must be strictly positive.
        let dot: f64 = (0..row).map(|k| l[[row, k]] * l[[row, k]]).sum();
        let pivot = a[[row, row]] - dot;
        if pivot <= 0.0 {
            return Err("Matrix is not positive definite".to_string());
        }
        l[[row, row]] = pivot.sqrt();
    }
    Ok(l)
}
/// Computes the inverse of A = L * L^T given its lower-triangular Cholesky
/// factor `l`, via A^{-1} = (L^{-1})^T * L^{-1}.
///
/// NOTE(review): this function never actually returns `Err`; the `Result`
/// appears to exist for API symmetry with `compute_cholesky`. A zero
/// diagonal in `l` would produce infinities rather than an error.
#[allow(dead_code)]
pub fn compute_inverse_from_cholesky(l: &Array2<f64>) -> Result<Array2<f64>, String> {
    let n = l.shape()[0];
    let mut inv = Array2::<f64>::zeros((n, n));
    let mut l_inv = Array2::<f64>::zeros((n, n));
    // Diagonal of L^{-1}: reciprocals of L's diagonal entries.
    for i in 0..n {
        l_inv[[i, i]] = 1.0 / l[[i, i]];
    }
    // Below-diagonal entries of L^{-1} by forward substitution.
    for i in 1..n {
        for j in 0..i {
            let mut sum = 0.0;
            for k in j..i {
                sum += l[[i, k]] * l_inv[[k, j]];
            }
            l_inv[[i, j]] = -sum / l[[i, i]];
        }
    }
    // A^{-1}[i, j] = sum_k L^{-1}[k, i] * L^{-1}[k, j]; only k >= max(i, j)
    // contributes because L^{-1} is lower triangular.
    for i in 0..n {
        for j in 0..n {
            let mut sum = 0.0;
            let max_idx = i.max(j);
            for k in max_idx..n {
                sum += l_inv[[k, i]] * l_inv[[k, j]];
            }
            inv[[i, j]] = sum;
        }
    }
    Ok(inv)
}
/// Convenience constructor: builds a [`MultivariateNormal`] from a mean
/// vector and covariance matrix.
///
/// # Errors
///
/// Propagates any error from `MultivariateNormal::new` (dimension mismatch,
/// non-positive-definite covariance, or inversion failure).
#[allow(dead_code)]
pub fn multivariate_normal<D1, D2>(
    mean: ArrayBase<D1, Ix1>,
    cov: ArrayBase<D2, Ix2>,
) -> StatsResult<MultivariateNormal>
where
    D1: Data<Elem = f64>,
    D2: Data<Elem = f64>,
{
    MultivariateNormal::new(mean, cov)
}
impl DistributionTrait<f64> for MultivariateNormal {
fn mean(&self) -> f64 {
if self.dim > 0 {
self.mean[0]
} else {
0.0
}
}
fn var(&self) -> f64 {
if self.dim > 0 {
self.cov[[0, 0]]
} else {
0.0
}
}
fn std(&self) -> f64 {
self.var().sqrt()
}
fn rvs(&self, size: usize) -> StatsResult<Array1<f64>> {
let samples_matrix = self.rvs(size)?;
Ok(samples_matrix.column(0).to_owned())
}
fn entropy(&self) -> f64 {
let k = self.dim as f64;
let pi = std::f64::consts::PI;
k / 2.0 + k / 2.0 * (2.0 * pi).ln() + 0.5 * self.cov_det.ln()
}
}
impl MultivariateDistribution<f64> for MultivariateNormal {
fn pdf(&self, x: &Array1<f64>) -> f64 {
self.pdf(x)
}
fn rvs(&self, size: usize) -> StatsResult<scirs2_core::ndarray::Array2<f64>> {
self.rvs(size)
}
fn mean(&self) -> Array1<f64> {
self.mean.clone()
}
fn cov(&self) -> scirs2_core::ndarray::Array2<f64> {
self.cov.clone()
}
fn dim(&self) -> usize {
self.dim
}
fn logpdf(&self, x: &Array1<f64>) -> f64 {
self.logpdf(x)
}
fn rvs_single(&self) -> StatsResult<Vec<f64>> {
let sample = self.rvs(1)?;
Ok(sample.row(0).to_vec())
}
}
impl SampleableDistribution<Array1<f64>> for MultivariateNormal {
fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
let samples_matrix = self.rvs(size)?;
let mut result = Vec::with_capacity(size);
for i in 0..size {
let row = samples_matrix.slice(s![i, ..]).to_owned();
result.push(row);
}
Ok(result)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{array, Axis};
    /// Valid parameters are stored unchanged for 2-D and 3-D cases.
    #[test]
    fn test_mvn_creation() {
        let mean = array![0.0, 0.0];
        let cov = array![[1.0, 0.0], [0.0, 1.0]];
        let mvn = MultivariateNormal::new(mean.clone(), cov.clone()).expect("Operation failed");
        assert_eq!(mvn.dim, 2);
        assert_eq!(mvn.mean, mean);
        assert_eq!(mvn.cov, cov);
        let mean3 = array![1.0, 2.0, 3.0];
        let cov3 = array![[1.0, 0.5, 0.3], [0.5, 2.0, 0.2], [0.3, 0.2, 1.5]];
        let mvn3 = MultivariateNormal::new(mean3.clone(), cov3.clone()).expect("Operation failed");
        assert_eq!(mvn3.dim, 3);
        assert_eq!(mvn3.mean, mean3);
        assert_eq!(mvn3.cov, cov3);
    }
    /// Construction fails on shape mismatch and on a non-positive-definite
    /// covariance matrix.
    #[test]
    fn test_mvn_creation_errors() {
        // 3-element mean with a 2x2 covariance: dimension mismatch.
        let mean = array![0.0, 0.0, 0.0];
        let cov = array![[1.0, 0.0], [0.0, 1.0]];
        assert!(MultivariateNormal::new(mean, cov).is_err());
        // Symmetric but indefinite (eigenvalues 3 and -1): Cholesky fails.
        let mean = array![0.0, 0.0];
        let cov = array![[1.0, 2.0], [2.0, 1.0]];
        assert!(MultivariateNormal::new(mean, cov).is_err());
    }
    /// pdf of the standard bivariate normal matches known values:
    /// 1/(2*pi) at the origin.
    #[test]
    fn test_mvn_pdf() {
        let mean = array![0.0, 0.0];
        let cov = array![[1.0, 0.0], [0.0, 1.0]];
        let mvn = MultivariateNormal::new(mean, cov).expect("Operation failed");
        let pdf_at_origin = mvn.pdf(&array![0.0, 0.0]);
        assert_relative_eq!(pdf_at_origin, 0.15915494, epsilon = 1e-7);
        let pdf_at_one = mvn.pdf(&array![1.0, 1.0]);
        assert_relative_eq!(pdf_at_one, 0.05854983, epsilon = 1e-7);
    }
    /// logpdf agrees with ln(pdf).
    #[test]
    fn test_mvn_logpdf() {
        let mean = array![0.0, 0.0];
        let cov = array![[1.0, 0.0], [0.0, 1.0]];
        let mvn = MultivariateNormal::new(mean, cov).expect("Operation failed");
        let logpdf_at_origin = mvn.logpdf(&array![0.0, 0.0]);
        assert_relative_eq!(logpdf_at_origin, -1.837877, epsilon = 1e-6);
        let x = array![1.0, 1.0];
        let pdf = mvn.pdf(&x);
        let logpdf = mvn.logpdf(&x);
        assert_relative_eq!(logpdf.exp(), pdf, epsilon = 1e-7);
    }
    /// Sample moments of a moderately large draw approximate the true
    /// mean and covariance (loose tolerances: the test is stochastic).
    #[test]
    fn test_mvn_rvs() {
        let mean = array![1.0, 2.0];
        let cov = array![[1.0, 0.5], [0.5, 2.0]];
        let mvn = MultivariateNormal::new(mean, cov).expect("Operation failed");
        let n_samples_ = 500;
        let samples = mvn.rvs(n_samples_).expect("Operation failed");
        assert_eq!(samples.shape(), &[n_samples_, 2]);
        let sample_mean = samples.mean_axis(Axis(0)).expect("Operation failed");
        assert_relative_eq!(sample_mean[0], 1.0, epsilon = 0.3);
        assert_relative_eq!(sample_mean[1], 2.0, epsilon = 0.3);
        // Center the samples; broadcasting subtracts the mean from each row.
        let centered = &samples - &sample_mean;
        let sample_cov = centered.t().dot(&centered) / (n_samples_ as f64 - 1.0);
        assert_relative_eq!(sample_cov[[0, 0]], 1.0, epsilon = 0.5);
        assert_relative_eq!(sample_cov[[1, 1]], 2.0, epsilon = 0.5);
        assert_relative_eq!(sample_cov[[0, 1]].abs(), 0.5, epsilon = 0.3);
    }
    /// A single draw has the right length.
    #[test]
    fn test_mvn_rvs_single() {
        let mean = array![1.0, 2.0];
        let cov = array![[1.0, 0.5], [0.5, 2.0]];
        let mvn = MultivariateNormal::new(mean, cov).expect("Operation failed");
        let sample = mvn.rvs_single().expect("Operation failed");
        assert_eq!(sample.len(), 2);
    }
    /// L * L^T reconstructs the original matrix.
    #[test]
    fn test_cholesky() {
        let a = array![[4.0, 2.0, 2.0], [2.0, 5.0, 3.0], [2.0, 3.0, 6.0]];
        let l = compute_cholesky(&a).expect("Operation failed");
        let mut a_reconstructed = Array2::<f64>::zeros((3, 3));
        for i in 0..3 {
            for j in 0..3 {
                // Only k <= min(i, j) contributes: L is lower triangular.
                for k in 0..=j.min(i) {
                    a_reconstructed[[i, j]] += l[[i, k]] * l[[j, k]];
                }
            }
        }
        for i in 0..3 {
            for j in 0..3 {
                assert_relative_eq!(a[[i, j]], a_reconstructed[[i, j]], epsilon = 1e-10);
            }
        }
    }
    /// A * A^{-1} is the identity.
    #[test]
    fn test_inverse() {
        let a = array![[4.0, 2.0, 2.0], [2.0, 5.0, 3.0], [2.0, 3.0, 6.0]];
        let l = compute_cholesky(&a).expect("Operation failed");
        let a_inv = compute_inverse_from_cholesky(&l).expect("Operation failed");
        let identity = a.dot(&a_inv);
        for i in 0..3 {
            for j in 0..3 {
                if i == j {
                    assert_relative_eq!(identity[[i, j]], 1.0, epsilon = 1e-10);
                } else {
                    assert_relative_eq!(identity[[i, j]], 0.0, epsilon = 1e-10);
                }
            }
        }
    }
}