use nalgebra::{DMatrix, DVector};
use crate::utils::errors::BraheError;
/// Builds a `dim x dim` isotropic covariance matrix.
///
/// Every diagonal entry is the variance `sigma * sigma` and every off-diagonal entry is zero.
/// `sigma` is interpreted as a standard deviation, so its sign does not affect the result.
pub fn isotropic_covariance(dim: usize, sigma: f64) -> DMatrix<f64> {
    let variance = sigma * sigma;
    DMatrix::from_diagonal_element(dim, dim, variance)
}
/// Builds a diagonal covariance matrix from per-axis standard deviations.
///
/// The result is `n x n` with `n = sigmas.len()`. Diagonal entry `i` is
/// `sigmas[i] * sigmas[i]`, and all off-diagonal entries are zero.
pub fn diagonal_covariance(sigmas: &[f64]) -> DMatrix<f64> {
    let squared = sigmas.iter().map(|&sigma| sigma * sigma);
    let diagonal = DVector::from_iterator(sigmas.len(), squared);
    DMatrix::from_diagonal(&diagonal)
}
/// Checks that `matrix` is a usable covariance matrix and returns it unchanged on success.
///
/// The following conditions are enforced, in order:
/// 1. the matrix is square;
/// 2. every element is finite (NaN and ±infinity are rejected);
/// 3. the matrix is symmetric: for every off-diagonal pair,
///    `|a_ij - a_ji| <= 1e-10 * max(|a_ij|, |a_ji|, 1.0)`. This is an absolute tolerance
///    for small entries and a relative one for large entries.
///
/// Positive (semi-)definiteness is not checked.
///
/// # Errors
///
/// Returns `BraheError::Error` describing the first violated condition.
pub fn validate_covariance(matrix: DMatrix<f64>) -> Result<DMatrix<f64>, BraheError> {
    const SYMMETRY_TOL: f64 = 1e-10;

    let n = matrix.nrows();
    let m = matrix.ncols();
    if n != m {
        return Err(BraheError::Error(format!(
            "Covariance matrix must be square, got {}x{}",
            n, m
        )));
    }

    // Reject NaN/inf up front. IEEE comparisons involving NaN are always false, so the
    // `>` test in the symmetry loop would silently accept NaN entries (and `inf - inf`).
    // An `inf` paired with a finite value would yield `inf > inf` == false and also pass.
    for i in 0..n {
        for j in 0..n {
            let value = matrix[(i, j)];
            if !value.is_finite() {
                return Err(BraheError::Error(format!(
                    "Covariance matrix contains non-finite element ({},{})={}",
                    i, j, value
                )));
            }
        }
    }

    // Only the strict upper triangle needs visiting; each pair is compared once.
    for i in 0..n {
        for j in (i + 1)..n {
            let aij = matrix[(i, j)];
            let aji = matrix[(j, i)];
            // Floor the scale at 1.0 so entries near zero use an absolute tolerance.
            let scale = aij.abs().max(aji.abs()).max(1.0);
            if (aij - aji).abs() > SYMMETRY_TOL * scale {
                return Err(BraheError::Error(format!(
                    "Covariance matrix is not symmetric: element ({},{})={} != ({},{})={}",
                    i, j, aij, j, i, aji
                )));
            }
        }
    }

    Ok(matrix)
}
/// Builds a symmetric `dim x dim` covariance matrix from its packed upper triangle.
///
/// `upper` holds the upper-triangular entries, diagonal included, in row-major order:
/// `(0,0), (0,1), ..., (0,dim-1), (1,1), (1,2), ..., (dim-1,dim-1)`. That is exactly
/// `dim * (dim + 1) / 2` values. Each value is also mirrored into the lower triangle.
///
/// # Errors
///
/// Returns `BraheError::Error` if `upper.len()` is not `dim * (dim + 1) / 2`.
pub fn covariance_from_upper_triangular(
    dim: usize,
    upper: &[f64],
) -> Result<DMatrix<f64>, BraheError> {
    let required = dim * (dim + 1) / 2;
    if upper.len() != required {
        return Err(BraheError::Error(format!(
            "Upper-triangular covariance for {}x{} matrix requires {} elements, got {}",
            dim,
            dim,
            required,
            upper.len()
        )));
    }

    // Upper-triangle coordinates in the same row-major order as the packed slice.
    let coords = (0..dim).flat_map(|row| (row..dim).map(move |col| (row, col)));

    let mut cov = DMatrix::zeros(dim, dim);
    for ((row, col), &value) in coords.zip(upper) {
        cov[(row, col)] = value;
        cov[(col, row)] = value;
    }
    Ok(cov)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_isotropic_covariance() {
        let r = isotropic_covariance(3, 10.0);
        assert_eq!(r.nrows(), 3);
        assert_eq!(r.ncols(), 3);
        assert_abs_diff_eq!(r[(0, 0)], 100.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(1, 1)], 100.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(2, 2)], 100.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(0, 1)], 0.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(1, 2)], 0.0, epsilon = 1e-15);
    }

    #[test]
    fn test_isotropic_covariance_1d() {
        let r = isotropic_covariance(1, 5.0);
        assert_eq!(r.nrows(), 1);
        assert_abs_diff_eq!(r[(0, 0)], 25.0, epsilon = 1e-15);
    }

    #[test]
    fn test_isotropic_covariance_0d() {
        let r = isotropic_covariance(0, 5.0);
        assert_eq!(r.nrows(), 0);
        assert_eq!(r.ncols(), 0);
    }

    #[test]
    fn test_diagonal_covariance() {
        let r = diagonal_covariance(&[5.0, 10.0, 15.0]);
        assert_eq!(r.nrows(), 3);
        assert_abs_diff_eq!(r[(0, 0)], 25.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(1, 1)], 100.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(2, 2)], 225.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(0, 1)], 0.0, epsilon = 1e-15);
    }

    #[test]
    fn test_diagonal_covariance_6d() {
        let sigmas = [5.0, 10.0, 15.0, 0.05, 0.1, 0.15];
        let r = diagonal_covariance(&sigmas);
        assert_eq!(r.nrows(), 6);
        assert_abs_diff_eq!(r[(0, 0)], 25.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(3, 3)], 0.0025, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(5, 5)], 0.0225, epsilon = 1e-15);
    }

    #[test]
    fn test_validate_covariance_valid() {
        let r = DMatrix::from_diagonal_element(3, 3, 100.0);
        assert!(validate_covariance(r).is_ok());
    }

    #[test]
    fn test_validate_covariance_symmetric_with_offdiag() {
        let mut r = DMatrix::zeros(3, 3);
        r[(0, 0)] = 100.0;
        r[(1, 1)] = 200.0;
        r[(2, 2)] = 300.0;
        r[(0, 1)] = 5.0;
        r[(1, 0)] = 5.0;
        assert!(validate_covariance(r).is_ok());
    }

    #[test]
    fn test_validate_covariance_relative_tolerance() {
        // A 1e-5 mismatch on entries of magnitude 1e6 is within the relative
        // tolerance (1e-10 * 1e6 = 1e-4), so the matrix is accepted as symmetric.
        let mut r = DMatrix::from_diagonal_element(2, 2, 1e7);
        r[(0, 1)] = 1e6;
        r[(1, 0)] = 1e6 + 1e-5;
        assert!(validate_covariance(r).is_ok());
    }

    #[test]
    fn test_validate_covariance_non_square() {
        let r = DMatrix::zeros(3, 4);
        let result = validate_covariance(r);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("square"));
    }

    #[test]
    fn test_validate_covariance_asymmetric() {
        let mut r = DMatrix::zeros(3, 3);
        r[(0, 0)] = 100.0;
        r[(1, 1)] = 200.0;
        r[(2, 2)] = 300.0;
        r[(0, 1)] = 5.0;
        // Deliberately mismatched mirror element.
        r[(1, 0)] = 50.0;
        let result = validate_covariance(r);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("symmetric"));
    }

    #[test]
    fn test_covariance_from_upper_triangular_3d() {
        let upper = [100.0, 5.0, 0.0, 225.0, 10.0, 400.0];
        let r = covariance_from_upper_triangular(3, &upper).unwrap();
        assert_eq!(r.nrows(), 3);
        assert_eq!(r.ncols(), 3);
        assert_abs_diff_eq!(r[(0, 0)], 100.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(1, 1)], 225.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(2, 2)], 400.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(0, 1)], 5.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(1, 0)], 5.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(0, 2)], 0.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(2, 0)], 0.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(1, 2)], 10.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(2, 1)], 10.0, epsilon = 1e-15);
    }

    #[test]
    fn test_covariance_from_upper_triangular_6d() {
        // Packed row-major upper triangle of a 6x6 matrix: row `i` starts right after
        // the previous rows' (6 - k) entries, so the diagonal sits at 0, 6, 11, 15, 18, 20.
        let mut upper = vec![0.0; 21];
        upper[0] = 1.0; // (0,0)
        upper[6] = 4.0; // (1,1)
        upper[11] = 9.0; // (2,2)
        upper[15] = 16.0; // (3,3)
        upper[18] = 25.0; // (4,4)
        upper[20] = 36.0; // (5,5)
        upper[3] = 7.5; // (0,3)
        let r = covariance_from_upper_triangular(6, &upper).unwrap();
        assert_eq!(r.nrows(), 6);
        assert_abs_diff_eq!(r[(0, 0)], 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(1, 1)], 4.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(2, 2)], 9.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(3, 3)], 16.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(4, 4)], 25.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(5, 5)], 36.0, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(0, 3)], 7.5, epsilon = 1e-15);
        assert_abs_diff_eq!(r[(3, 0)], 7.5, epsilon = 1e-15);
    }

    #[test]
    fn test_covariance_from_upper_triangular_wrong_count() {
        let result = covariance_from_upper_triangular(3, &[1.0, 2.0, 3.0]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("6 elements"));
    }

    #[test]
    fn test_covariance_from_upper_triangular_1d() {
        let r = covariance_from_upper_triangular(1, &[42.0]).unwrap();
        assert_eq!(r.nrows(), 1);
        assert_abs_diff_eq!(r[(0, 0)], 42.0, epsilon = 1e-15);
    }

    #[test]
    fn test_covariance_from_upper_triangular_0d() {
        let r = covariance_from_upper_triangular(0, &[]).unwrap();
        assert_eq!(r.nrows(), 0);
        assert_eq!(r.ncols(), 0);
    }
}