use crate::error::{StatsError, StatsResult};
/// Inverts a square matrix given as a flat slice indexed `matrix[r * dim + c]`.
///
/// The input is copied into the left half of a `dim × 2·dim` augmented
/// buffer whose right half is set to the identity. That buffer is then handed
/// to `invert_augmented`, which receives `eps` as its pivot threshold.
///
/// # Errors
///
/// Returns `StatsError::invalid_input` when `matrix.len() != dim * dim`.
/// Any error reported by `invert_augmented` is returned unchanged.
pub fn invert(matrix: &[f64], dim: usize, eps: f64) -> StatsResult<Vec<f64>> {
    let expected = dim * dim;
    if matrix.len() != expected {
        return Err(StatsError::invalid_input(format!(
            "linalg::invert: expected {} elements for {}×{} matrix, got {}",
            expected,
            dim,
            dim,
            matrix.len()
        )));
    }

    // Each augmented row is `[ matrix row | identity row ]`.
    let width = 2 * dim;
    let mut augmented = vec![0.0; dim * width];
    for row in 0..dim {
        let dst = row * width;
        let src = row * dim;
        augmented[dst..dst + dim].copy_from_slice(&matrix[src..src + dim]);
        augmented[dst + dim + row] = 1.0;
    }

    invert_augmented(augmented, dim, eps)
}
/// Inverts `matrix + λ·I`, where `λ` is derived from the matrix's own diagonal.
///
/// `λ` is the mean diagonal entry (`trace / dim`) divided by
/// `ridge_factor.max(1e-9)`, and the quotient is floored at `1e-12`. The
/// shifted matrix is placed next to an identity block and passed to
/// `invert_augmented` with a fixed pivot threshold of `1e-9`.
///
/// # Errors
///
/// Returns `StatsError::invalid_input` when `matrix.len() != dim * dim`.
/// Any error reported by `invert_augmented` is returned unchanged.
pub fn invert_with_ridge(matrix: &[f64], dim: usize, ridge_factor: f64) -> StatsResult<Vec<f64>> {
    let expected = dim * dim;
    if matrix.len() != expected {
        return Err(StatsError::invalid_input(format!(
            "linalg::invert_with_ridge: expected {} elements, got {}",
            expected,
            matrix.len()
        )));
    }

    // Sum of the diagonal, accumulated from 0.0 in index order.
    let diag_sum = (0..dim).fold(0.0, |acc, i| acc + matrix[i * dim + i]);
    let shift = (diag_sum / dim as f64 / ridge_factor.max(1e-9)).max(1e-12);

    // Each augmented row is `[ matrix row with shifted diagonal | identity row ]`.
    let width = 2 * dim;
    let mut augmented = vec![0.0; dim * width];
    for row in 0..dim {
        let dst = row * width;
        let src = row * dim;
        augmented[dst..dst + dim].copy_from_slice(&matrix[src..src + dim]);
        augmented[dst + row] += shift;
        augmented[dst + dim + row] = 1.0;
    }

    invert_augmented(augmented, dim, 1e-9)
}
/// Runs Gauss-Jordan elimination with partial (row) pivoting on a
/// `dim × 2·dim` augmented buffer and returns its right-hand `dim × dim` half.
///
/// `aug` is indexed as `aug[r * 2 * dim + c]`. When the left half holds a
/// matrix `A` and the right half the identity, the returned vector is `A⁻¹`
/// in the same flat layout.
///
/// # Errors
///
/// Returns `StatsError::numerical_error` when:
/// - any entry of `aug` is NaN or infinite,
/// - a pivot becomes non-finite during elimination,
/// - the largest available pivot in a column is exactly zero, is smaller than
///   `eps`, or `eps` itself is NaN (no meaningful comparison is possible).
///
/// These cases are rejected explicitly because a plain `pivot < eps` test is
/// false for NaN operands and for a zero pivot with `eps <= 0`, which would
/// otherwise let the division below produce NaN/inf inside an `Ok` result.
fn invert_augmented(mut aug: Vec<f64>, dim: usize, eps: f64) -> StatsResult<Vec<f64>> {
    let w = 2 * dim;

    // Non-finite input can never yield a meaningful inverse; fail fast.
    if let Some(bad) = aug.iter().copied().find(|v| !v.is_finite()) {
        return Err(StatsError::numerical_error(format!(
            "linalg::invert: matrix contains a non-finite entry ({})",
            bad
        )));
    }

    for col in 0..dim {
        // Partial pivoting: pick the row at or below the diagonal whose entry
        // in this column has the largest magnitude.
        let mut pivot_row = col;
        let mut pivot_val = aug[col * w + col].abs();
        for r in (col + 1)..dim {
            let v = aug[r * w + col].abs();
            if v > pivot_val {
                pivot_val = v;
                pivot_row = r;
            }
        }

        // Intermediate overflow can still create inf/NaN from finite input.
        if !pivot_val.is_finite() {
            return Err(StatsError::numerical_error(format!(
                "linalg::invert: non-finite pivot {} in column {}",
                pivot_val, col
            )));
        }
        // A zero pivot is singular regardless of `eps`; a NaN `eps` makes the
        // threshold comparison meaningless, so it is treated as a failure too.
        if pivot_val == 0.0 || pivot_val < eps || eps.is_nan() {
            return Err(StatsError::numerical_error(format!(
                "linalg::invert: matrix is singular (pivot {} < eps {})",
                pivot_val, eps
            )));
        }

        if pivot_row != col {
            for c in 0..w {
                aug.swap(col * w + c, pivot_row * w + c);
            }
        }

        // Scale the pivot row so the pivot becomes 1.
        let inv_pivot = 1.0 / aug[col * w + col];
        for c in 0..w {
            aug[col * w + c] *= inv_pivot;
        }

        // Clear this column in every other row.
        for r in 0..dim {
            if r == col {
                continue;
            }
            let factor = aug[r * w + col];
            if factor == 0.0 {
                continue;
            }
            for c in 0..w {
                aug[r * w + c] -= factor * aug[col * w + c];
            }
        }
    }

    // The left half is now the identity; the right half is the result.
    let mut inv = Vec::with_capacity(dim * dim);
    for r in 0..dim {
        let start = r * w + dim;
        inv.extend_from_slice(&aug[start..start + dim]);
    }
    Ok(inv)
}
/// Allocating convenience wrapper around `mahalanobis_sq_into`.
///
/// Creates two zero-filled scratch vectors of length `x.len()` and forwards
/// every argument to `mahalanobis_sq_into`, returning its result unchanged
/// (including any validation error it reports).
pub fn mahalanobis_sq(x: &[f64], mean: &[f64], m_inv: &[f64]) -> StatsResult<f64> {
    let n = x.len();
    let (mut diff_buf, mut prod_buf) = (vec![0.0; n], vec![0.0; n]);
    mahalanobis_sq_into(x, mean, m_inv, &mut diff_buf, &mut prod_buf)
}
/// Computes the quadratic form `(x - mean)ᵀ · m_inv · (x - mean)` using
/// caller-provided scratch buffers instead of allocating.
///
/// `m_inv` is a flat `dim × dim` matrix indexed `m_inv[r * dim + c]`, where
/// `dim = x.len()`. On success `scratch_diff` holds `x - mean` and
/// `scratch_md` holds `m_inv · (x - mean)`.
///
/// # Errors
///
/// Returns `StatsError::invalid_input` when `mean.len() != x.len()`, when
/// `m_inv.len() != dim * dim`, or when either scratch buffer does not have
/// length `dim`. The checks run in that order.
pub fn mahalanobis_sq_into(
    x: &[f64],
    mean: &[f64],
    m_inv: &[f64],
    scratch_diff: &mut [f64],
    scratch_md: &mut [f64],
) -> StatsResult<f64> {
    let dim = x.len();

    if mean.len() != dim {
        return Err(StatsError::invalid_input(format!(
            "linalg::mahalanobis_sq_into: mean dim {} != x dim {}",
            mean.len(),
            dim
        )));
    }
    if m_inv.len() != dim * dim {
        return Err(StatsError::invalid_input(format!(
            "linalg::mahalanobis_sq_into: m_inv dim {} != expected {}",
            m_inv.len(),
            dim * dim
        )));
    }
    if scratch_diff.len() != dim || scratch_md.len() != dim {
        return Err(StatsError::invalid_input(format!(
            "linalg::mahalanobis_sq_into: scratch buffers must have len {}",
            dim
        )));
    }

    // Centered vector: x - mean.
    for ((slot, &xi), &mi) in scratch_diff.iter_mut().zip(x).zip(mean) {
        *slot = xi - mi;
    }

    // Matrix-vector product, one output entry per matrix row. Rows are sliced
    // by index so that `dim == 0` simply performs no iterations.
    for (r, slot) in scratch_md.iter_mut().enumerate() {
        let row = &m_inv[r * dim..(r + 1) * dim];
        *slot = row
            .iter()
            .zip(scratch_diff.iter())
            .fold(0.0_f64, |acc, (&m, &d)| acc + m * d);
    }

    // Dot product of the centered vector with the matrix-vector product.
    let score = scratch_diff
        .iter()
        .zip(scratch_md.iter())
        .fold(0.0_f64, |acc, (&d, &p)| acc + d * p);
    Ok(score)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Flat 3×3 identity matrix shared by several tests.
    const IDENTITY_3: [f64; 9] = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];

    /// True when `a` and `b` differ by strictly less than `tol`.
    fn approx(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Inverting the identity must give the identity back.
    #[test]
    fn invert_identity_is_identity() {
        let inv = invert(&IDENTITY_3, 3, 1e-9).unwrap();
        for (want, got) in IDENTITY_3.iter().zip(inv.iter()) {
            assert!(approx(*want, *got, 1e-12));
        }
    }

    /// Checks a 2×2 inverse against hand-computed values.
    #[test]
    fn invert_2x2() {
        let a = [4.0, 7.0, 2.0, 6.0];
        let inv = invert(&a, 2, 1e-9).unwrap();
        let expected = [0.6, -0.7, -0.2, 0.4];
        for (idx, want) in expected.iter().enumerate() {
            assert!(approx(inv[idx], *want, 1e-12));
        }
    }

    /// A rank-deficient matrix must be reported as an error.
    #[test]
    fn invert_singular_errors() {
        let a = [1.0, 2.0, 2.0, 4.0];
        assert!(invert(&a, 2, 1e-9).is_err());
    }

    /// `a · inv(a)` must be the identity within tolerance.
    #[test]
    fn invert_a_times_inv_is_identity() {
        let a = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 2.0, 4.0];
        let inv = invert(&a, 3, 1e-9).unwrap();
        for r in 0..3 {
            for c in 0..3 {
                let mut entry = 0.0;
                for k in 0..3 {
                    entry += a[r * 3 + k] * inv[k * 3 + c];
                }
                assert!(approx(entry, IDENTITY_3[r * 3 + c], 1e-9));
            }
        }
    }

    /// A slice whose length is not `dim * dim` is rejected.
    #[test]
    fn invert_wrong_size_errors() {
        let a = [1.0; 5];
        assert!(invert(&a, 2, 1e-9).is_err());
    }

    /// The ridge shift makes an otherwise singular matrix invertible.
    #[test]
    fn invert_with_ridge_handles_singular() {
        let a = [1.0, 2.0, 2.0, 4.0];
        let result = invert_with_ridge(&a, 2, 10.0);
        assert!(result.is_ok());
    }

    /// With an identity matrix the score is the squared Euclidean norm.
    #[test]
    fn mahalanobis_identity_is_l2() {
        let x = [1.0, 2.0, 3.0];
        let mean = [0.0; 3];
        let score = mahalanobis_sq(&x, &mean, &IDENTITY_3).unwrap();
        assert!(approx(score, 1.0 + 4.0 + 9.0, 1e-12));
    }

    /// A diagonal matrix weights each squared component.
    #[test]
    fn mahalanobis_diag_weighted() {
        let x = [2.0, 2.0];
        let mean = [0.0; 2];
        let weights = [1.0, 0.0, 0.0, 4.0];
        let score = mahalanobis_sq(&x, &mean, &weights).unwrap();
        assert!(approx(score, 20.0, 1e-12));
    }

    /// The scratch-buffer variant must agree with the allocating one.
    #[test]
    fn mahalanobis_sq_into_matches_owning_variant() {
        let x = [1.0, 2.0, 3.0];
        let mean = [0.5; 3];
        let m = [2.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.5];
        let owned = mahalanobis_sq(&x, &mean, &m).unwrap();
        let mut diff = [0.0; 3];
        let mut md = [0.0; 3];
        let scratched = mahalanobis_sq_into(&x, &mean, &m, &mut diff, &mut md).unwrap();
        assert!(approx(owned, scratched, 1e-15));
    }

    /// A scratch buffer of the wrong length is rejected.
    #[test]
    fn mahalanobis_sq_into_wrong_scratch_errors() {
        let x = [1.0, 2.0];
        let mean = [0.0; 2];
        let m = [1.0, 0.0, 0.0, 1.0];
        let mut diff = [0.0; 1];
        let mut md = [0.0; 2];
        assert!(mahalanobis_sq_into(&x, &mean, &m, &mut diff, &mut md).is_err());
    }
}