use nalgebra::DMatrix;
use nalgebra::linalg::{Cholesky, SymmetricEigen};
pub fn matrix_square_root(matrix: &DMatrix<f64>) -> DMatrix<f64> {
assert!(
matrix.is_square(),
"matrix_square_root: matrix must be square"
);
const INITIAL_JITTER: f64 = 1e-12;
const MAX_JITTER: f64 = 1e-6;
const MAX_TRIES: usize = 6;
const EIGEN_FLOOR: f64 = 1e-12;
let p = symmetrize(matrix);
if let Some(s) = chol_sqrt(&p) {
return s;
}
if let Some(s) = chol_sqrt_with_jitter(&p, INITIAL_JITTER, MAX_JITTER, MAX_TRIES) {
return s;
}
evd_symmetric_sqrt_with_floor(&p, EIGEN_FLOOR)
}
/// Returns the symmetric part of `m`, i.e. `(m + mᵀ) / 2`, as a new matrix.
///
/// # Panics
///
/// Panics (shape mismatch in the matrix addition) if `m` is not square.
#[inline]
pub fn symmetrize(m: &DMatrix<f64>) -> DMatrix<f64> {
    // Start from the transpose and accumulate in place: each entry becomes
    // 0.5 * (m[j][i] + m[i][j]), identical to the out-of-place formulation.
    let mut sym = m.transpose();
    sym += m;
    sym *= 0.5;
    sym
}
/// Lower-triangular Cholesky factor `L` of `p` (so `L * Lᵀ = p`), or `None`
/// when nalgebra's Cholesky decomposition fails (matrix not numerically
/// positive definite).
fn chol_sqrt(p: &DMatrix<f64>) -> Option<DMatrix<f64>> {
    let factorization = Cholesky::new(p.clone())?;
    // `unpack` hands back the lower factor with its strict upper part zeroed.
    Some(factorization.unpack())
}
/// Cholesky factor of `p + εI` for the first diagonal jitter `ε` that works.
///
/// Tries `ε = initial_jitter, 10·initial_jitter, 100·initial_jitter, …`. The
/// search ends after `max_tries` attempts, or as soon as the next `ε` would
/// exceed `max_jitter`. The first attempt is made regardless of `max_jitter`.
///
/// Returns the lower-triangular `L` with `L * Lᵀ = p + εI`. Returns `None` when
/// every attempt fails, or immediately when `max_tries == 0`.
fn chol_sqrt_with_jitter(
    p: &DMatrix<f64>,
    initial_jitter: f64,
    max_jitter: f64,
    max_tries: usize,
) -> Option<DMatrix<f64>> {
    let dim = p.nrows();
    let mut shift = initial_jitter;
    let mut attempts_left = max_tries;
    while attempts_left > 0 {
        attempts_left -= 1;
        // `Cholesky::new` takes its matrix by value, so each attempt needs a fresh copy.
        let mut shifted = p.clone();
        for k in 0..dim {
            shifted[(k, k)] += shift;
        }
        if let Some(factor) = Cholesky::new(shifted) {
            return Some(factor.unpack());
        }
        shift *= 10.0;
        if shift > max_jitter {
            return None;
        }
    }
    None
}
/// Symmetric square root of `p` via eigen-decomposition with an eigenvalue floor.
///
/// Computes `p = U Λ Uᵀ` with nalgebra's [`SymmetricEigen`]. Per its documentation,
/// only the lower triangle of `p` is read; callers are expected to pass a
/// symmetrized matrix. Every eigenvalue below `floor` is raised to `floor`.
///
/// Returns `S = U Λ^{1/2} Uᵀ`. `S` is symmetric up to rounding, and `S * Sᵀ`
/// reconstructs `p` with the floored spectrum.
///
/// A negative or NaN `floor` is treated as `0.0`, so square roots are never taken
/// of negative eigenvalues. (Previously such a floor produced NaN output.)
fn evd_symmetric_sqrt_with_floor(p: &DMatrix<f64>, floor: f64) -> DMatrix<f64> {
    // `f64::max` returns the non-NaN operand, so a NaN floor also becomes 0.0.
    let floor = floor.max(0.0);

    let eigen = SymmetricEigen::new(p.clone());
    let u = eigen.eigenvectors;
    let lambdas = eigen.eigenvalues;

    // Form U Λ^{1/2} by scaling column j of U by sqrt(λ_j). This produces the
    // same entries as multiplying by a dense diagonal matrix, but costs O(n²)
    // instead of an extra O(n³) product.
    let mut u_scaled = u.clone();
    for (j, &lambda) in lambdas.iter().enumerate() {
        // Explicit comparison (not `f64::max`) so a NaN eigenvalue, which only
        // arises from non-finite input, propagates instead of being hidden.
        let lambda = if lambda < floor { floor } else { lambda };
        let mut col = u_scaled.column_mut(j);
        col *= lambda.sqrt();
    }
    &u_scaled * u.transpose()
}
/// Diagonal-jitter schedule used when a plain Cholesky factorization fails.
///
/// The first jittered attempt adds `initial_jitter` to the diagonal. Each
/// further attempt multiplies the jitter by 10. The search stops after
/// `max_tries` attempts, or as soon as the next jitter would exceed `max_jitter`.
#[derive(Debug, Clone, Copy)]
pub struct SolveOptions {
    /// Diagonal shift used on the first jittered attempt.
    pub initial_jitter: f64,
    /// Ceiling on the diagonal shift for the escalated (second and later) attempts.
    pub max_jitter: f64,
    /// Maximum number of jittered factorization attempts.
    pub max_tries: usize,
}

impl Default for SolveOptions {
    /// `initial_jitter = 1e-12`, `max_jitter = 1e-6`, `max_tries = 6`.
    fn default() -> Self {
        let (initial_jitter, max_jitter, max_tries) = (1e-12, 1e-6, 6);
        Self {
            initial_jitter,
            max_jitter,
            max_tries,
        }
    }
}
pub fn chol_solve_spd(
a: &DMatrix<f64>,
b: &DMatrix<f64>,
opt: SolveOptions,
) -> Option<DMatrix<f64>> {
assert!(a.is_square(), "chol_solve_spd: A must be square");
assert_eq!(a.nrows(), b.nrows(), "chol_solve_spd: A and B incompatible");
let a_sym = symmetrize(a);
if let Some(ch) = Cholesky::new(a_sym.clone()) {
return Some(ch.solve(b));
}
let n = a_sym.nrows();
let mut jitter = opt.initial_jitter;
for _ in 0..opt.max_tries {
let mut a_j = a_sym.clone();
for i in 0..n {
a_j[(i, i)] += jitter;
}
if let Some(ch) = Cholesky::new(a_j) {
return Some(ch.solve(b));
}
jitter *= 10.0;
if jitter > opt.max_jitter {
break;
}
}
None
}
/// Solves `A X = B`, preferring Cholesky and falling back to an LU solve.
///
/// First tries [`chol_solve_spd`] with [`SolveOptions::default`], which may
/// regularize `A` with a small diagonal jitter. If that fails, the symmetrized
/// system `((A + Aᵀ) / 2) X = B` is solved with a partially pivoted LU
/// factorization. This avoids forming an explicit inverse, which is both more
/// expensive and less accurate than a direct solve.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `a` is not square;
/// - `a` and `b` have different row counts;
/// - the LU factorization of the symmetrized matrix hits an exactly zero pivot
///   (a singular matrix).
pub fn robust_spd_solve(a: &DMatrix<f64>, b: &DMatrix<f64>) -> DMatrix<f64> {
    if let Some(x) = chol_solve_spd(a, b, SolveOptions::default()) {
        return x;
    }
    match symmetrize(a).lu().solve(b) {
        Some(x) => x,
        None => panic!("robust_spd_solve: A is not invertible (even after jitter)."),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `a` and `b` have the same shape and every pair of corresponding
    /// entries differs by at most `tol`.
    ///
    /// A NaN difference is never `<= tol`, so NaN entries make the comparison fail.
    /// A running `f64::max` reduction would silently skip them.
    fn approx_eq(a: &DMatrix<f64>, b: &DMatrix<f64>, tol: f64) -> bool {
        a.shape() == b.shape() && a.iter().zip(b.iter()).all(|(x, y)| (x - y).abs() <= tol)
    }

    /// `m + shift * I` for a square `m`.
    fn shifted(m: &DMatrix<f64>, shift: f64) -> DMatrix<f64> {
        let mut out = m.clone();
        for i in 0..out.nrows() {
            out[(i, i)] += shift;
        }
        out
    }

    #[test]
    fn t_symmetrize() {
        let m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 0.0, 3.0]);
        let s = symmetrize(&m);
        let s_expected = DMatrix::from_row_slice(2, 2, &[1.0, 1.0, 1.0, 3.0]);
        assert!(approx_eq(&s, &s_expected, 1e-15));
    }

    #[test]
    fn t_chol_sqrt_spd() {
        let a = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 0.5, 0.0, 1.0, -1.0, 0.0, 0.0, 0.2]);
        let p = &a * a.transpose();
        let s = chol_sqrt(&p).expect("Cholesky should succeed for SPD");
        let back = &s * s.transpose();
        assert!(approx_eq(&back, &p, 1e-12));
    }

    #[test]
    fn t_chol_sqrt_with_jitter() {
        let a = DMatrix::from_row_slice(3, 3, &[1.0, 0.2, 0.0, 0.0, 1.0, 0.2, 0.0, 0.0, 1.0]);
        let mut p = &a * a.transpose();
        p[(2, 2)] -= 1e-10;
        let s =
            chol_sqrt_with_jitter(&p, 1e-12, 1e-6, 6).expect("jittered Cholesky should succeed");
        let back = &s * s.transpose();
        let p_sym = symmetrize(&p);
        assert!(approx_eq(&back, &p_sym, 1e-8));
    }

    #[test]
    fn t_evd_floor() {
        // Eigenvalues are +1 (eigenvector [1, 1]/√2) and -1 (eigenvector [1, -1]/√2).
        // Flooring -1 to ~0 leaves the rank-one projector 0.5 * [[1, 1], [1, 1]].
        let p = DMatrix::from_row_slice(2, 2, &[0.0, 1.0, 1.0, 0.0]);
        let s = evd_symmetric_sqrt_with_floor(&p, 1e-12);
        let back = &s * s.transpose();
        assert_eq!(back.shape(), p.shape());
        assert!(approx_eq(&back, &back.transpose(), 1e-14));
        let expected = DMatrix::from_row_slice(2, 2, &[0.5, 0.5, 0.5, 0.5]);
        assert!(approx_eq(&back, &expected, 1e-9));
    }

    #[test]
    fn t_public_identity() {
        let i = DMatrix::<f64>::identity(4, 4);
        let s = matrix_square_root(&i);
        assert!(approx_eq(&s, &i, 1e-14));
        let back = &s * s.transpose();
        assert!(approx_eq(&back, &i, 1e-12));
    }

    #[test]
    fn t_public_nearly_spd() {
        let a = DMatrix::from_row_slice(3, 3, &[1.0, 0.1, 0.0, 0.0, 1.0, 0.2, 0.0, 0.0, 1.0]);
        let mut p = &a * a.transpose();
        p[(2, 2)] -= 1e-10;
        p[(0, 2)] += 1e-12;
        let s = matrix_square_root(&p);
        let back = &s * s.transpose();
        let p_sym = symmetrize(&p);
        assert!(approx_eq(&back, &p_sym, 1e-8));
    }

    #[test]
    #[should_panic(expected = "matrix must be square")]
    fn t_public_non_square_panics() {
        let m = DMatrix::<f64>::zeros(3, 2);
        let _ = matrix_square_root(&m);
    }

    #[test]
    fn t_chol_sqrt_none() {
        let m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 2.0, 1.0]);
        let result = chol_sqrt(&m);
        assert!(result.is_none(), "Cholesky should fail for non-PD matrix");
    }

    #[test]
    fn t_chol_sqrt_with_jitter_zero_pivot() {
        // diag(0, 1, 1) has a zero pivot, so plain Cholesky fails; the first jitter
        // (0.01) already makes it positive definite.
        let mut m = DMatrix::<f64>::identity(3, 3);
        m[(0, 0)] = 0.0;
        assert!(chol_sqrt(&m).is_none());
        let s = chol_sqrt_with_jitter(&m, 0.01, 2.0, 3).expect("jitter 0.01 makes m PD");
        let back = &s * s.transpose();
        assert!(approx_eq(&back, &shifted(&m, 0.01), 1e-12));
    }

    #[test]
    fn t_chol_sqrt_with_jitter_max_tries() {
        // Jitter ladder 1e-4, 1e-3, 1e-2: only the third rung lifts -5e-3 above zero.
        let mut m = DMatrix::<f64>::identity(3, 3);
        m[(0, 0)] = -5e-3;
        assert!(chol_sqrt_with_jitter(&m, 1e-4, 1.0, 2).is_none());
        let s = chol_sqrt_with_jitter(&m, 1e-4, 1.0, 3).expect("third attempt succeeds");
        let back = &s * s.transpose();
        assert!(approx_eq(&back, &shifted(&m, 1e-2), 1e-12));
    }

    #[test]
    fn t_chol_sqrt_with_jitter_max_jitter() {
        // Same ladder as above: a 5e-3 ceiling stops the search before 1e-2 is tried.
        let mut m = DMatrix::<f64>::identity(3, 3);
        m[(0, 0)] = -5e-3;
        assert!(chol_sqrt_with_jitter(&m, 1e-4, 5e-3, 10).is_none());
        assert!(chol_sqrt_with_jitter(&m, 1e-4, 2e-2, 10).is_some());
    }

    #[test]
    fn t_chol_sqrt_with_jitter_none() {
        let mut m = DMatrix::<f64>::identity(3, 3);
        m[(0, 0)] = -1e10;
        let result = chol_sqrt_with_jitter(&m, 1e-12, 1e-6, 6);
        assert!(result.is_none(), "tiny jitter cannot fix a hugely negative pivot");
    }

    #[test]
    fn t_evd_floor_negative_eigenvalues() {
        let m = DMatrix::from_row_slice(3, 3, &[-1.0, 0.0, 0.0, 0.0, -2.0, 0.0, 0.0, 0.0, 3.0]);
        let s = evd_symmetric_sqrt_with_floor(&m, 1e-6);
        let back = &s * s.transpose();
        assert!(approx_eq(&back, &back.transpose(), 1e-12));
        // -1 and -2 are raised to the floor; 3 is kept.
        let expected =
            DMatrix::from_row_slice(3, 3, &[1e-6, 0.0, 0.0, 0.0, 1e-6, 0.0, 0.0, 0.0, 3.0]);
        assert!(approx_eq(&back, &expected, 1e-12));
        let se = SymmetricEigen::new(back);
        for lambda in se.eigenvalues.iter() {
            assert!(
                *lambda >= -1e-10,
                "Eigenvalue should be non-negative after flooring"
            );
        }
    }

    #[test]
    fn t_matrix_square_root_evd_fallback() {
        // Eigenvalues 3 and -1: no jitter up to 1e-6 fixes this, so the EVD path runs
        // and the -1 eigenvalue is floored, leaving 1.5 * [[1, 1], [1, 1]].
        let m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 2.0, 1.0]);
        assert!(chol_sqrt(&m).is_none());
        let s = matrix_square_root(&m);
        let back = &s * s.transpose();
        assert!(approx_eq(&back, &back.transpose(), 1e-12));
        let expected = DMatrix::from_row_slice(2, 2, &[1.5, 1.5, 1.5, 1.5]);
        assert!(approx_eq(&back, &expected, 1e-9));
    }

    #[test]
    fn t_chol_solve_spd_basic() {
        let a = DMatrix::from_row_slice(2, 2, &[4.0, 2.0, 2.0, 3.0]);
        let b = DMatrix::from_row_slice(2, 1, &[6.0, 5.0]);
        let x = chol_solve_spd(&a, &b, SolveOptions::default()).expect("Should solve");
        let result = &a * &x;
        assert!(approx_eq(&result, &b, 1e-10));
    }

    #[test]
    fn t_chol_solve_spd_with_jitter() {
        // diag(1, 0) is singular, so plain Cholesky fails and jitter is required.
        // b lies in the range of A, so the jittered solution still satisfies A x ≈ b.
        let a = DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 0.0]);
        let b = DMatrix::from_row_slice(2, 1, &[1.0, 0.0]);
        assert!(
            Cholesky::new(symmetrize(&a)).is_none(),
            "plain Cholesky must fail for this singular A"
        );
        let x = chol_solve_spd(&a, &b, SolveOptions::default()).expect("Should solve with jitter");
        let result = &a * &x;
        assert!(approx_eq(&result, &b, 1e-8));
    }

    #[test]
    fn t_chol_solve_spd_tiny_but_pd() {
        // 1e-15 * I is still positive definite, so plain Cholesky succeeds and the
        // (deliberately tiny) jitter schedule is never used.
        let a = DMatrix::from_row_slice(2, 2, &[1e-15, 0.0, 0.0, 1e-15]);
        let b = DMatrix::from_row_slice(2, 1, &[1.0, 1.0]);
        let opts = SolveOptions {
            initial_jitter: 1e-20,
            max_jitter: 1e-18,
            max_tries: 2,
        };
        let x = chol_solve_spd(&a, &b, opts).expect("1e-15 * I is positive definite");
        let result = &a * &x;
        assert!(approx_eq(&result, &b, 1e-8));
    }

    #[test]
    fn t_robust_spd_solve_basic() {
        let a = DMatrix::from_row_slice(2, 2, &[4.0, 2.0, 2.0, 3.0]);
        let b = DMatrix::from_row_slice(2, 1, &[6.0, 5.0]);
        let x = robust_spd_solve(&a, &b);
        let result = &a * &x;
        assert!(approx_eq(&result, &b, 1e-10));
    }

    #[test]
    fn t_robust_spd_solve_fallback() {
        let mut a = DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]);
        a[(0, 1)] = 1e-8;
        let b = DMatrix::from_row_slice(2, 1, &[1.0, 2.0]);
        let x = robust_spd_solve(&a, &b);
        let a_sym = symmetrize(&a);
        let result = &a_sym * &x;
        assert!(approx_eq(&result, &b, 1e-8));
    }

    #[test]
    #[should_panic(expected = "A is not invertible")]
    fn t_robust_spd_solve_panic() {
        // diag(0, -1): indefinite, so jittered Cholesky fails, and singular, so the
        // general fallback fails too.
        let a = DMatrix::from_row_slice(2, 2, &[0.0, 0.0, 0.0, -1.0]);
        let b = DMatrix::from_row_slice(2, 1, &[1.0, 1.0]);
        let _ = robust_spd_solve(&a, &b);
    }

    #[test]
    fn t_robust_spd_solve_zero_matrix_regularized() {
        // The zero matrix plus the first jitter (1e-12 * I) factorizes, so the
        // result is b / 1e-12 rather than a panic.
        let a = DMatrix::<f64>::zeros(2, 2);
        let b = DMatrix::from_row_slice(2, 1, &[1.0, 1.0]);
        let x = robust_spd_solve(&a, &b);
        assert!(x.iter().all(|v| (v / 1e12 - 1.0).abs() < 1e-10));
    }

    #[test]
    fn t_chol_solve_spd_non_square_panic_guard() {
        let a = DMatrix::<f64>::identity(2, 2);
        let b = DMatrix::<f64>::zeros(2, 1);
        assert!(chol_solve_spd(&a, &b, SolveOptions::default()).is_some());
    }

    #[test]
    #[should_panic(expected = "chol_solve_spd: A must be square")]
    fn t_chol_solve_spd_non_square_panic() {
        let a = DMatrix::<f64>::zeros(3, 2);
        let b = DMatrix::<f64>::zeros(3, 1);
        let _ = chol_solve_spd(&a, &b, SolveOptions::default());
    }

    #[test]
    #[should_panic(expected = "chol_solve_spd: A and B incompatible")]
    fn t_chol_solve_spd_incompatible_panic() {
        let a = DMatrix::<f64>::identity(2, 2);
        let b = DMatrix::<f64>::zeros(3, 1);
        let _ = chol_solve_spd(&a, &b, SolveOptions::default());
    }

    #[test]
    fn t_chol_solve_spd_max_jitter_exceeded() {
        let a = DMatrix::from_row_slice(2, 2, &[-10.0, 0.0, 0.0, -10.0]);
        let b = DMatrix::from_row_slice(2, 1, &[1.0, 1.0]);
        let opts = SolveOptions {
            initial_jitter: 1e-6,
            max_jitter: 1e-5,
            max_tries: 10,
        };
        let result = chol_solve_spd(&a, &b, opts);
        assert!(
            result.is_none(),
            "Should return None when jitter limit exceeded"
        );
    }

    #[test]
    fn t_robust_spd_solve_inverse_fallback() {
        let mut a = DMatrix::from_row_slice(2, 2, &[2.0, 1.0, 1.0, 2.0]);
        a[(0, 0)] = -0.1;
        let b = DMatrix::from_row_slice(2, 1, &[1.0, 1.0]);
        let opts = SolveOptions {
            initial_jitter: 1e-20,
            max_jitter: 1e-19,
            max_tries: 1,
        };
        let chol_result = chol_solve_spd(&a, &b, opts);
        assert!(
            chol_result.is_none(),
            "Cholesky should fail with restrictive jitter"
        );
        let x = robust_spd_solve(&a, &b);
        let a_sym = symmetrize(&a);
        let result = &a_sym * &x;
        assert!(approx_eq(&result, &b, 1e-6));
    }

    #[test]
    fn t_robust_spd_solve_singular_handled() {
        // [[1, 1], [1, 1]] is singular PSD; the first jitter makes it factorizable,
        // so the solve succeeds with a finite result instead of panicking.
        let a = DMatrix::from_row_slice(2, 2, &[1.0, 1.0, 1.0, 1.0]);
        let b = DMatrix::from_row_slice(2, 1, &[1.0, 1.0]);
        let x = robust_spd_solve(&a, &b);
        assert!(x.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn t_solve_options_default() {
        let opts = SolveOptions::default();
        assert_eq!(opts.initial_jitter, 1e-12);
        assert_eq!(opts.max_jitter, 1e-6);
        assert_eq!(opts.max_tries, 6);
    }
}