#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::consts::HALF_LN_2PI_E;
use crate::consts::LN_2PI;
use crate::data::MvGaussianSuffStat;
use crate::impl_display;
use crate::traits::*;
use nalgebra::linalg::Cholesky;
use nalgebra::{DMatrix, DVector, Dyn};
use rand::Rng;
use std::fmt;
use std::sync::OnceLock;
/// Quantities derived from the covariance matrix, computed once and cached
/// so repeated likelihood/entropy evaluations don't refactorize Σ.
#[derive(Clone, Debug)]
struct MvgCache {
    /// Cholesky factorization of the covariance matrix Σ
    pub cov_chol: Cholesky<f64, Dyn>,
    /// Σ⁻¹, computed from the Cholesky factorization
    pub cov_inv: DMatrix<f64>,
}
impl MvgCache {
    /// Build the cache from a covariance matrix.
    ///
    /// Fails with `CovNotPositiveSemiDefinite` when the matrix has no
    /// Cholesky factorization.
    pub fn from_cov(cov: &DMatrix<f64>) -> Result<Self, MvGaussianError> {
        let cov_chol = cov
            .clone()
            .cholesky()
            .ok_or(MvGaussianError::CovNotPositiveSemiDefinite)?;
        let cov_inv = cov_chol.inverse();
        Ok(MvgCache { cov_chol, cov_inv })
    }

    /// Build the cache directly from an existing Cholesky factorization.
    #[inline]
    pub fn from_chol(cov_chol: Cholesky<f64, Dyn>) -> Self {
        MvgCache {
            cov_inv: cov_chol.inverse(),
            cov_chol,
        }
    }

    /// Reconstruct the covariance matrix as L·Lᵀ from its Cholesky factor.
    #[inline]
    pub fn cov(&self) -> DMatrix<f64> {
        let l = self.cov_chol.l();
        &l * &l.transpose()
    }
}
/// Multivariate Gaussian (normal) distribution, 𝒩(μ, Σ).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct MvGaussian {
    /// Mean vector, μ
    mu: DVector<f64>,
    /// Covariance matrix, Σ
    cov: DMatrix<f64>,
    /// Lazily computed Cholesky factorization and inverse of Σ. Skipped by
    /// serde and rebuilt on demand after deserialization.
    #[cfg_attr(
        feature = "serde1",
        serde(skip, default = "default_cache_none")
    )]
    cache: OnceLock<MvgCache>,
}
/// Serde default for the skipped `cache` field: an empty, uninitialized
/// lock, so the cache is recomputed on first use after deserialization.
#[allow(dead_code)]
#[cfg(feature = "serde1")]
fn default_cache_none() -> OnceLock<MvgCache> {
    OnceLock::default()
}
impl PartialEq for MvGaussian {
    /// Equality compares only the parameters (μ, Σ); the derived `cache`
    /// field is deliberately excluded.
    fn eq(&self, other: &MvGaussian) -> bool {
        self.mu == other.mu && self.cov == other.cov
    }
}
/// Errors produced when constructing or mutating a [`MvGaussian`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum MvGaussianError {
    /// The mean vector and covariance matrix disagree on dimensionality.
    MuCovDimensionMismatch {
        /// Length of the mean vector
        n_mu: usize,
        /// Number of rows of the covariance matrix
        n_cov: usize,
    },
    /// The covariance matrix is not square.
    CovNotSquare {
        nrows: usize,
        ncols: usize,
    },
    /// The covariance matrix admits no Cholesky factorization.
    CovNotPositiveSemiDefinite,
    /// A zero-dimensional distribution was requested.
    ZeroDimension,
}
impl MvGaussian {
    /// Create a new multivariate Gaussian distribution 𝒩(μ, Σ).
    ///
    /// # Errors
    /// - `CovNotSquare` if `cov` is not square
    /// - `MuCovDimensionMismatch` if `mu` and `cov` disagree on
    ///   dimensionality
    /// - `CovNotPositiveSemiDefinite` if `cov` has no Cholesky
    ///   factorization
    pub fn new(
        mu: DVector<f64>,
        cov: DMatrix<f64>,
    ) -> Result<Self, MvGaussianError> {
        let cov_rows = cov.nrows();
        let cov_cols = cov.ncols();
        if cov_rows != cov_cols {
            Err(MvGaussianError::CovNotSquare {
                nrows: cov_rows,
                ncols: cov_cols,
            })
        } else if mu.len() != cov_rows {
            Err(MvGaussianError::MuCovDimensionMismatch {
                n_mu: mu.len(),
                n_cov: cov_rows,
            })
        } else {
            let cache = OnceLock::from(MvgCache::from_cov(&cov)?);
            Ok(MvGaussian { mu, cov, cache })
        }
    }

    /// Create a new multivariate Gaussian from μ and the Cholesky
    /// factorization of Σ.
    ///
    /// # Errors
    /// - `MuCovDimensionMismatch` if `mu` and the factorized covariance
    ///   disagree on dimensionality
    pub fn new_cholesky(
        mu: DVector<f64>,
        cov_chol: Cholesky<f64, Dyn>,
    ) -> Result<Self, MvGaussianError> {
        // Reconstruct Σ = L·Lᵀ so `cov()` can hand out a reference later.
        let l = cov_chol.l();
        let cov = &l * &l.transpose();
        if mu.len() != cov.nrows() {
            Err(MvGaussianError::MuCovDimensionMismatch {
                n_mu: mu.len(),
                n_cov: cov.nrows(),
            })
        } else {
            let cache = OnceLock::from(MvgCache::from_chol(cov_chol));
            Ok(MvGaussian { mu, cov, cache })
        }
    }

    /// Create a new multivariate Gaussian without validity checks.
    ///
    /// # Panics
    /// Panics if `cov` has no Cholesky factorization.
    #[inline]
    pub fn new_unchecked(mu: DVector<f64>, cov: DMatrix<f64>) -> Self {
        let cache = OnceLock::from(
            MvgCache::from_cov(&cov)
                .expect("new_unchecked requires a positive-definite cov"),
        );
        MvGaussian { mu, cov, cache }
    }

    /// Create a new multivariate Gaussian from μ and a Cholesky
    /// factorization of Σ, without dimension checks.
    #[inline]
    pub fn new_cholesky_unchecked(
        mu: DVector<f64>,
        cov_chol: Cholesky<f64, Dyn>,
    ) -> Self {
        let cache = OnceLock::from(MvgCache::from_chol(cov_chol));
        // The cache was just populated, so `get` cannot fail.
        let cov = cache.get().unwrap().cov();
        MvGaussian { mu, cov, cache }
    }

    /// Create a standard Gaussian: zero mean and identity covariance.
    ///
    /// # Errors
    /// - `ZeroDimension` if `dims` is zero
    #[inline]
    pub fn standard(dims: usize) -> Result<Self, MvGaussianError> {
        if dims == 0 {
            Err(MvGaussianError::ZeroDimension)
        } else {
            let mu = DVector::zeros(dims);
            let cov = DMatrix::identity(dims, dims);
            // The identity matrix is positive definite, so factorization
            // cannot fail.
            let cov_chol = cov
                .clone()
                .cholesky()
                .expect("identity matrix is positive definite");
            let cache = OnceLock::from(MvgCache::from_chol(cov_chol));
            Ok(MvGaussian { mu, cov, cache })
        }
    }

    /// The number of dimensions, k.
    #[inline]
    pub fn ndims(&self) -> usize {
        self.mu.len()
    }

    /// The mean vector, μ.
    #[inline]
    pub fn mu(&self) -> &DVector<f64> {
        &self.mu
    }

    /// The covariance matrix, Σ.
    #[inline]
    pub fn cov(&self) -> &DMatrix<f64> {
        &self.cov
    }

    /// Set the mean vector.
    ///
    /// # Errors
    /// - `MuCovDimensionMismatch` if `mu` does not conform with the
    ///   current covariance matrix
    #[inline]
    pub fn set_mu(&mut self, mu: DVector<f64>) -> Result<(), MvGaussianError> {
        if mu.len() != self.cov.nrows() {
            Err(MvGaussianError::MuCovDimensionMismatch {
                n_mu: mu.len(),
                n_cov: self.cov.nrows(),
            })
        } else {
            self.mu = mu;
            Ok(())
        }
    }

    /// Set the mean vector without a dimension check.
    #[inline]
    pub fn set_mu_unchecked(&mut self, mu: DVector<f64>) {
        self.mu = mu;
    }

    /// Set the covariance matrix, rebuilding the cached factorization.
    ///
    /// # Errors
    /// - `MuCovDimensionMismatch` if `cov` does not conform with μ
    /// - `CovNotSquare` if `cov` is not square
    /// - `CovNotPositiveSemiDefinite` if `cov` has no Cholesky
    ///   factorization
    pub fn set_cov(
        &mut self,
        cov: DMatrix<f64>,
    ) -> Result<(), MvGaussianError> {
        let cov_rows = cov.nrows();
        if self.mu.len() != cov_rows {
            Err(MvGaussianError::MuCovDimensionMismatch {
                n_mu: self.mu.len(),
                n_cov: cov.nrows(),
            })
        } else if cov_rows != cov.ncols() {
            Err(MvGaussianError::CovNotSquare {
                nrows: cov_rows,
                ncols: cov.ncols(),
            })
        } else {
            // Build the cache before touching `self` so a factorization
            // failure leaves the distribution unchanged.
            let cache = MvgCache::from_cov(&cov)?;
            self.cov = cov;
            // Replace the lock wholesale (consistent with
            // `set_cov_unchecked`) instead of re-creating it empty and
            // `set`-ting in a second step.
            self.cache = OnceLock::from(cache);
            Ok(())
        }
    }

    /// Set the covariance matrix without validity checks.
    ///
    /// # Panics
    /// Panics if `cov` has no Cholesky factorization.
    #[inline]
    pub fn set_cov_unchecked(&mut self, cov: DMatrix<f64>) {
        let cache = MvgCache::from_cov(&cov)
            .expect("set_cov_unchecked requires a positive-definite cov");
        self.cov = cov;
        self.cache = OnceLock::from(cache);
    }

    /// Cached Cholesky factorization and inverse of Σ, computed on first
    /// access.
    #[inline]
    fn cache(&self) -> &MvgCache {
        // `cov` was validated (or the cache pre-built) at construction;
        // an unwrap failure here would indicate a deserialized invalid
        // covariance — NOTE(review): confirm that is acceptable.
        self.cache
            .get_or_init(|| MvgCache::from_cov(&self.cov).unwrap())
    }
}
impl From<&MvGaussian> for String {
    /// Human-readable summary of the distribution.
    fn from(mvg: &MvGaussian) -> String {
        // Fixed: the format string previously ended with an unmatched ')'
        // after the covariance (the '(' is closed right after ndims).
        format!("Nₖ({})\n μ: {}\n σ: {}", mvg.ndims(), mvg.mu, mvg.cov)
    }
}
impl_display!(MvGaussian);
impl Rv<DVector<f64>> for MvGaussian {
    /// Log density: -½ (ln|Σ| + k·ln 2π + (x-μ)ᵀ Σ⁻¹ (x-μ)).
    fn ln_f(&self, x: &DVector<f64>) -> f64 {
        let diff = x - &self.mu;
        // ln|Σ| straight from the Cholesky factor (consistent with
        // `ln_f_stat`). Summing logs is robust where multiplying the
        // diagonal entries and then taking the log of the squared product
        // can under/overflow for high-dimensional covariances.
        let ln_det = self.cache().cov_chol.ln_determinant();
        let inv = &(self.cache().cov_inv);
        // Mahalanobis quadratic form (x-μ)ᵀ Σ⁻¹ (x-μ); a 1×1 matrix.
        let term: f64 = (diff.transpose() * inv * &diff)[0];
        -0.5 * (ln_det + (diff.nrows() as f64).mul_add(LN_2PI, term))
    }

    /// Draw x = μ + L·z with z ~ 𝒩(0, I) and L the Cholesky factor of Σ.
    fn draw<R: Rng>(&self, rng: &mut R) -> DVector<f64> {
        let dims = self.mu.len();
        let norm = rand_distr::StandardNormal;
        let vals: Vec<f64> = (0..dims).map(|_| rng.sample(norm)).collect();
        let a = self.cache().cov_chol.l_dirty();
        let z: DVector<f64> = DVector::from_column_slice(&vals);
        DVector::from_fn(dims, |i, _| {
            let mut out: f64 = self.mu[i];
            // L is lower-triangular, so only columns 0..=i contribute.
            for j in 0..=i {
                out += a[(i, j)] * z[j];
            }
            out
        })
    }
}
impl Support<DVector<f64>> for MvGaussian {
    /// A vector is supported iff its dimensionality matches μ's.
    fn supports(&self, x: &DVector<f64>) -> bool {
        self.mu.len() == x.len()
    }
}
// Marker impl: MvGaussian is a continuous distribution over DVector<f64>.
impl ContinuousDistr<DVector<f64>> for MvGaussian {}

impl Mean<DVector<f64>> for MvGaussian {
    /// The mean of a Gaussian is μ.
    fn mean(&self) -> Option<DVector<f64>> {
        Some(self.mu.clone())
    }
}

impl Mode<DVector<f64>> for MvGaussian {
    /// The mode of a Gaussian coincides with its mean, μ.
    fn mode(&self) -> Option<DVector<f64>> {
        Some(self.mu.clone())
    }
}

impl Variance<DMatrix<f64>> for MvGaussian {
    /// The (co)variance of a Gaussian is Σ.
    fn variance(&self) -> Option<DMatrix<f64>> {
        Some(self.cov.clone())
    }
}
impl Entropy for MvGaussian {
    /// Differential entropy: ½ ln|Σ| + k·½ ln(2πe).
    fn entropy(&self) -> f64 {
        // ln|Σ| from the cached Cholesky factor (consistent with
        // `ln_f_stat`). Avoids forming the determinant itself, whose
        // diagonal product can under/overflow for high dimensions.
        let ln_det = self.cache().cov_chol.ln_determinant();
        ln_det.mul_add(0.5, HALF_LN_2PI_E * (self.cov.nrows() as f64))
    }
}
impl HasSuffStat<DVector<f64>> for MvGaussian {
    type Stat = MvGaussianSuffStat;

    fn empty_suffstat(&self) -> Self::Stat {
        MvGaussianSuffStat::new(self.mu.len())
    }

    /// Joint log likelihood of all observations summarized by `stat`:
    /// -n/2 [k·ln 2π + ln|Σ| + (x̄-μ)ᵀΣ⁻¹(x̄-μ)] - tr(Σ⁻¹·Σ̂)/2.
    fn ln_f_stat(&self, stat: &Self::Stat) -> f64 {
        // An empty stat is an empty product of likelihoods ⇒ ln f = 0.
        // Without this guard the divisions by n below would produce NaN.
        if stat.n() == 0 {
            return 0.0;
        }
        let n = stat.n() as f64;
        let k = stat.sum_x().len() as f64;
        let x_bar = stat.sum_x() / n;
        // Scatter about the sample mean: Σxxᵀ - (Σx)(Σx)ᵀ/n
        let sigma_hat =
            stat.sum_x_sq() - (stat.sum_x() * stat.sum_x().transpose()) / n;
        let sigma_inv = &self.cache().cov_inv;
        let ln_cov_det = self.cache().cov_chol.ln_determinant();
        let neg_half_n = -0.5 * n;
        neg_half_n.mul_add(
            LN_2PI.mul_add(k, ln_cov_det)
                + ((&x_bar - &self.mu).transpose()
                    * sigma_inv
                    * (&x_bar - &self.mu))[0],
            -(sigma_inv * sigma_hat).trace() / 2.0,
        )
    }
}
impl std::error::Error for MvGaussianError {}
impl fmt::Display for MvGaussianError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ZeroDimension => write!(f, "requested dimension is too low"),
Self::CovNotPositiveSemiDefinite => {
write!(f, "covariance is not positive semi-definite")
}
Self::MuCovDimensionMismatch { n_mu, n_cov } => write!(
f,
"mean vector and covariance matrix do not align. mu is {} \
dimensions but cov is {} dimensions",
n_mu, n_cov
),
Self::CovNotSquare { nrows, ncols } => write!(
f,
"covariance matrix is not square ({} x {})",
nrows, ncols
),
}
}
}
// Unit and statistical tests for `MvGaussian`.
#[cfg(test)]
mod tests {
    use nalgebra::{dmatrix, dvector};
    use rand::{thread_rng, SeedableRng};

    use super::*;
    use crate::dist::Gaussian;
    use crate::misc::{ks_test, mardia};
    use crate::test_basic_impls;

    /// Absolute tolerance for exact-value comparisons.
    const TOL: f64 = 1E-12;
    /// Number of attempts for the stochastic tests before declaring failure.
    const NTRIES: usize = 5;
    /// p-value threshold for the Kolmogorov–Smirnov marginal tests.
    const KS_PVAL: f64 = 0.2;
    /// p-value threshold for the Mardia multivariate-normality tests.
    const MARDIA_PVAL: f64 = 0.2;

    test_basic_impls!(MvGaussian::standard(3).unwrap(), DVector::zeros(3));

    #[test]
    fn new() {
        let mu = DVector::zeros(3);
        let cov = DMatrix::identity(3, 3);
        assert!(MvGaussian::new(mu, cov).is_ok());
    }

    #[test]
    fn new_should_reject_cov_too_big() {
        let mu = DVector::zeros(3);
        let cov = DMatrix::identity(4, 4);
        let mvg = MvGaussian::new(mu, cov);
        assert_eq!(
            mvg,
            Err(MvGaussianError::MuCovDimensionMismatch { n_mu: 3, n_cov: 4 })
        )
    }

    #[test]
    fn new_should_reject_cov_too_small() {
        let mu = DVector::zeros(3);
        let cov = DMatrix::identity(2, 2);
        let mvg = MvGaussian::new(mu, cov);
        assert_eq!(
            mvg,
            Err(MvGaussianError::MuCovDimensionMismatch { n_mu: 3, n_cov: 2 })
        )
    }

    #[test]
    fn new_should_reject_cov_not_square() {
        let mu = DVector::zeros(3);
        let cov = DMatrix::identity(3, 2);
        let mvg = MvGaussian::new(mu, cov);
        assert_eq!(
            mvg,
            Err(MvGaussianError::CovNotSquare { nrows: 3, ncols: 2 })
        );
    }

    #[test]
    fn ln_f_standard_x_zeros() {
        let mvg = MvGaussian::standard(3).unwrap();
        let x = DVector::<f64>::zeros(3);
        assert::close(mvg.ln_f(&x), -2.756_815_599_614_018, TOL);
    }

    #[test]
    fn ln_f_standard_x_nonzeros() {
        let mvg = MvGaussian::standard(3).unwrap();
        let x = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        assert::close(mvg.ln_f(&x), -26.906_815_599_614_02, TOL);
    }

    #[test]
    fn ln_f_nonstandard_zeros() {
        // Shared 3x3 positive-definite covariance fixture (row-major).
        let cov_vals = vec![
            1.017_427_88,
            0.365_866_52,
            -0.656_204_86,
            0.365_866_52,
            1.005_645_53,
            -0.425_972_61,
            -0.656_204_86,
            -0.425_972_61,
            1.272_479_72,
        ];
        let cov: DMatrix<f64> = DMatrix::from_row_slice(3, 3, &cov_vals);
        let mu = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        let mvg = MvGaussian::new(mu, cov).unwrap();
        let x = DVector::<f64>::zeros(3);
        assert::close(mvg.ln_f(&x), -24.602_370_253_215_66, TOL);
    }

    #[test]
    fn ln_f_nonstandard_nonzeros() {
        let cov_vals = vec![
            1.017_427_88,
            0.365_866_52,
            -0.656_204_86,
            0.365_866_52,
            1.005_645_53,
            -0.425_972_61,
            -0.656_204_86,
            -0.425_972_61,
            1.272_479_72,
        ];
        let cov: DMatrix<f64> = DMatrix::from_row_slice(3, 3, &cov_vals);
        let mu = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        let mvg = MvGaussian::new(mu, cov).unwrap();
        // x == μ, so the quadratic form vanishes.
        let x = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        assert::close(mvg.ln_f(&x), -2.591_535_053_811_229_6, TOL);
    }

    #[test]
    fn sample_returns_proper_number_of_draws() {
        let cov_vals = vec![
            1.017_427_88,
            0.365_866_52,
            -0.656_204_86,
            0.365_866_52,
            1.005_645_53,
            -0.425_972_61,
            -0.656_204_86,
            -0.425_972_61,
            1.272_479_72,
        ];
        let cov: DMatrix<f64> = DMatrix::from_row_slice(3, 3, &cov_vals);
        let mu = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        let mvg = MvGaussian::new(mu, cov).unwrap();
        let mut rng = rand::thread_rng();
        let xs: Vec<DVector<f64>> = mvg.sample(103, &mut rng);
        assert_eq!(xs.len(), 103);
    }

    #[test]
    fn standard_entropy() {
        let mvg = MvGaussian::standard(3).unwrap();
        assert::close(mvg.entropy(), 4.256_815_599_614_018_5, TOL);
    }

    #[test]
    fn nonstandard_entropy() {
        let cov_vals = vec![
            1.017_427_88,
            0.365_866_52,
            -0.656_204_86,
            0.365_866_52,
            1.005_645_53,
            -0.425_972_61,
            -0.656_204_86,
            -0.425_972_61,
            1.272_479_72,
        ];
        let cov: DMatrix<f64> = DMatrix::from_row_slice(3, 3, &cov_vals);
        let mu = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        let mvg = MvGaussian::new(mu, cov).unwrap();
        assert::close(mvg.entropy(), 4.091_535_053_811_230_5, TOL);
    }

    // Stochastic test: each marginal of a standard bivariate Gaussian
    // should pass a KS test against the univariate standard Gaussian CDF.
    // Retried up to NTRIES times to keep flakiness low.
    #[test]
    fn standard_draw_marginals() {
        let mut rng = rand::thread_rng();
        let mvg = MvGaussian::standard(2).unwrap();
        let g = Gaussian::standard();
        let cdf = |x: f64| g.cdf(&x);
        let passed = (0..NTRIES).fold(false, |acc, _| {
            if acc {
                acc
            } else {
                let xys = mvg.sample(500, &mut rng);
                let xs: Vec<f64> =
                    xys.iter().map(|xy: &DVector<f64>| xy[0]).collect();
                let ys: Vec<f64> =
                    xys.iter().map(|xy: &DVector<f64>| xy[1]).collect();
                let (_, px) = ks_test(&xs, cdf);
                let (_, py) = ks_test(&ys, cdf);
                px > KS_PVAL && py > KS_PVAL
            }
        });
        assert!(passed);
    }

    // Stochastic test: Mardia's multivariate-normality test on draws from
    // the standard 4-dimensional Gaussian.
    #[test]
    fn standard_draw_mardia() {
        let mut rng = rand::thread_rng();
        let mvg = MvGaussian::standard(4).unwrap();
        let passed = (0..NTRIES).fold(false, |acc, _| {
            if acc {
                acc
            } else {
                let xys = mvg.sample(500, &mut rng);
                let (pa, pb) = mardia(&xys);
                pa > MARDIA_PVAL && pb > MARDIA_PVAL
            }
        });
        assert!(passed);
    }

    // Same Mardia test, but on a correlated (non-identity) covariance.
    #[test]
    fn nonstandard_draw_mardia() {
        let mut rng = rand::thread_rng();
        let cov_vals = vec![
            1.017_427_88,
            0.365_866_52,
            -0.656_204_86,
            0.365_866_52,
            1.005_645_53,
            -0.425_972_61,
            -0.656_204_86,
            -0.425_972_61,
            1.272_479_72,
        ];
        let cov: DMatrix<f64> = DMatrix::from_row_slice(3, 3, &cov_vals);
        let mu = DVector::<f64>::from_column_slice(&[0.5, 3.1, -6.2]);
        let mvg = MvGaussian::new(mu, cov).unwrap();
        let passed = (0..NTRIES).fold(false, |acc, _| {
            if acc {
                acc
            } else {
                let xys = mvg.sample(500, &mut rng);
                let (pa, pb) = mardia(&xys);
                pa > MARDIA_PVAL && pb > MARDIA_PVAL
            }
        });
        assert!(passed);
    }

    #[test]
    fn suff_stat_ln_f() {
        let f = MvGaussian::new(
            dvector![1.0, 2.0],
            dmatrix![1.0, 3.0/5.0; 3.0/5.0, 2.0;],
        )
        .unwrap();
        let mut stat = f.empty_suffstat();
        stat.observe_many(&[
            dvector![1.0, 2.0],
            dvector![3.0, 4.0],
            dvector![5.0, 6.0],
        ]);
        assert::close(f.ln_f_stat(&stat), -17.231_285_318_08, 1E-12);
    }

    // Property test: the sufficient-statistic likelihood must match the
    // sum of pointwise log likelihoods on random data.
    #[test]
    fn suff_stat_ln_f_fuzzy() {
        let f = MvGaussian::new(
            dvector![1.0, 2.0],
            dmatrix![1.0, 3.0/5.0; 3.0/5.0, 2.0;],
        )
        .unwrap();
        let g =
            MvGaussian::new(dvector![0.0, 0.0], dmatrix![1.0, 0.0; 0.0, 1.0;])
                .unwrap();
        let seed: u64 = thread_rng().gen();
        // `dbg!` prints the seed so a failing run can be reproduced.
        dbg!(&seed);
        let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
        for _ in 0..100 {
            let data: Vec<DVector<f64>> = g.sample(11, &mut rng);
            let mut stat = f.empty_suffstat();
            stat.observe_many(&data);
            let ln_f_sum: f64 = data.iter().map(|x| f.ln_f(x)).sum();
            assert::close(f.ln_f_stat(&stat), ln_f_sum, 1E-13);
        }
    }
}