#[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
use approx::assert_relative_eq;
use num_traits::Float;
use numrs2::array::Array;
#[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
use numrs2::linalg;
#[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
#[allow(deprecated)]
use numrs2::new_modules::matrix_decomp::{condition_number, lu, pivoted_cholesky};
#[cfg(feature = "matrix_decomp")]
#[allow(dead_code)]
/// Returns the `n x n` Hilbert matrix, whose entries are `H[i][j] = 1 / (i + j + 1)`.
///
/// Used as a standard ill-conditioned test input for the decomposition tests.
fn hilbert_matrix<T: Float + From<f64>>(n: usize) -> Array<T> {
    let one = <T as From<f64>>::from(1.0);
    let mut h = Array::zeros(&[n, n]);
    // Visit cells in row-major order: (0,0), (0,1), ..., (n-1,n-1).
    let cells = (0..n).flat_map(|row| (0..n).map(move |col| (row, col)));
    for (row, col) in cells {
        let value = one / <T as From<f64>>::from((row + col + 1) as f64);
        h.set(&[row, col], value).unwrap();
    }
    h
}
#[allow(dead_code)]
/// Builds a symmetric `n x n` test matrix with a prescribed condition number.
///
/// The result is `Q * D * Q^T`, where `D = diag(1, ..., 1, 1 / condition)` and `Q` is
/// a Givens rotation (cos = 0.8, sin = 0.6) acting on rows/columns `0` and `n - 1`.
/// Since `0.8^2 + 0.6^2 = 1`, `Q` is orthogonal. The eigenvalues of the result are
/// therefore exactly the entries of `D`. For `condition >= 1` its 2-norm condition
/// number is `condition`. The rotation mixes the small eigenvalue into off-diagonal
/// entries, so the matrix is not trivially diagonal (for `n >= 2`).
///
/// # Panics
///
/// Panics if `n == 0`.
fn near_singular_matrix<T: Float + From<f64>>(n: usize, condition: f64) -> Array<T> {
    assert!(n > 0, "near_singular_matrix requires n > 0");

    let mut d = vec![<T as From<f64>>::from(1.0); n];
    d[n - 1] = <T as From<f64>>::from(1.0 / condition);

    let mut q: Array<T> = Array::eye_square(n);
    if n >= 2 {
        // Rotate in the (0, n-1) plane so the 1/condition eigenvalue is mixed in.
        // A rotation between two unit eigenvalues would leave Q*D*Q^T unchanged.
        let last = n - 1;
        let c = <T as From<f64>>::from(0.8);
        let s = <T as From<f64>>::from(0.6);
        q.set(&[0, 0], c).unwrap();
        q.set(&[0, last], s).unwrap();
        q.set(&[last, 0], -s).unwrap();
        q.set(&[last, last], c).unwrap();
    }

    let mut result: Array<T> = Array::zeros(&[n, n]);
    for i in 0..n {
        for j in 0..n {
            // (Q D Q^T)[i][j] = sum_k Q[i][k] * d[k] * Q[j][k]
            let sum = d.iter().enumerate().fold(T::zero(), |acc, (k, &d_k)| {
                acc + q.get(&[i, k]).unwrap() * d_k * q.get(&[j, k]).unwrap()
            });
            result.set(&[i, j], sum).unwrap();
        }
    }
    result
}
#[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
#[test]
// The imported `matrix_decomp` helpers are deprecated (see the allow on their `use`).
// This single function-level allow covers every call in the body.
#[allow(deprecated)]
/// Checks `condition_number` on inputs with known conditioning.
///
/// - the identity: exactly 1
/// - `diag(1, 10, 100, 1000)`: exactly 1000
/// - a constructed matrix with target condition `1e3`
/// - the 4x4 Hilbert matrix: expected to exceed `1e4`
fn test_condition_number_accuracy() {
    let n = 4;

    let identity = Array::<f64>::eye_square(n);
    let cond_identity = condition_number(&identity).unwrap();
    assert_relative_eq!(cond_identity, 1.0, epsilon = 1e-10);

    // diag(10^0, 10^1, ..., 10^(n-1)); for n = 4 the ratio max/min is 1000.
    let mut diagonal = Array::<f64>::zeros(&[n, n]);
    for i in 0..n {
        diagonal.set(&[i, i], 10.0f64.powi(i as i32)).unwrap();
    }
    let cond_diagonal = condition_number(&diagonal).unwrap();
    assert_relative_eq!(cond_diagonal, 1000.0, epsilon = 1e-10);

    let near_singular = near_singular_matrix::<f64>(n, 1e3);
    let cond_near_singular = condition_number(&near_singular).unwrap();
    // An infinite result is also accepted, presumably in case the backend
    // reports (near-)singular input as infinitely conditioned.
    let in_expected_range = cond_near_singular > 1e2 && cond_near_singular < 1e5;
    assert!(
        in_expected_range || cond_near_singular.is_infinite(),
        "Expected condition number to be high, got {}",
        cond_near_singular
    );

    let hilbert = hilbert_matrix::<f64>(n);
    let cond_hilbert = condition_number(&hilbert).unwrap();
    assert!(
        cond_hilbert > 1e4,
        "Hilbert matrix should have condition number > 1e4, got {}",
        cond_hilbert
    );
}
#[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
#[test]
#[allow(deprecated)]
/// Checks that LU, QR and Cholesky reconstruct a perfectly conditioned input
/// (the 4x4 identity) to within `1e-10`. Also checks that pivoted Cholesky returns an
/// `n x n` factor with exact zeros above the diagonal.
fn test_decomposition_stability_well_conditioned() {
    /// Largest absolute elementwise difference between two `n x n` arrays.
    fn max_abs_diff(x: &Array<f64>, y: &Array<f64>, n: usize) -> f64 {
        // The `f64` suffix matters: calling `.max` on a binding initialised from a
        // bare `0.0` is an ambiguous-numeric-type error (E0689).
        let mut worst = 0.0_f64;
        for i in 0..n {
            for j in 0..n {
                let diff = (x.get(&[i, j]).unwrap() - y.get(&[i, j]).unwrap()).abs();
                worst = worst.max(diff);
            }
        }
        worst
    }

    let n = 4;
    let a = Array::<f64>::eye_square(n);

    // LU: build P*A by taking row p[i] of A as row i, then compare against L*U.
    let (l, u, p) = lu(&a).unwrap();
    let mut pa: Array<f64> = Array::zeros(&[n, n]);
    for i in 0..n {
        for j in 0..n {
            pa.set(&[i, j], a.get(&[p.get(&[i]).unwrap(), j]).unwrap())
                .unwrap();
        }
    }
    let max_lu_error = max_abs_diff(&pa, &l.matmul(&u).unwrap(), n);
    assert!(
        max_lu_error < 1e-10,
        "LU decomposition error: {}",
        max_lu_error
    );

    // QR: A should equal Q*R.
    let (q, r) = linalg::qr(&a).unwrap();
    let max_qr_error = max_abs_diff(&a, &q.matmul(&r).unwrap(), n);
    assert!(
        max_qr_error < 1e-10,
        "QR decomposition error: {}",
        max_qr_error
    );

    // Cholesky: A should equal L*L^T.
    let l_chol = linalg::cholesky(&a).unwrap();
    let lt_chol = l_chol.transpose();
    let max_chol_error = max_abs_diff(&a, &l_chol.matmul(&lt_chol).unwrap(), n);
    assert!(
        max_chol_error < 1e-10,
        "Cholesky decomposition error: {}",
        max_chol_error
    );

    // Pivoted Cholesky: the factor must be n x n and exactly zero above the diagonal.
    let (l_pchol, _) = pivoted_cholesky(&a).unwrap();
    assert_eq!(l_pchol.shape(), vec![n, n]);
    for i in 0..n {
        for j in (i + 1)..n {
            let val = l_pchol.get(&[i, j]).unwrap();
            assert_eq!(
                val, 0.0,
                "Pivoted Cholesky should produce a lower triangular matrix"
            );
        }
    }
}
#[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
#[test]
#[allow(deprecated)]
/// Reconstruction checks for SVD, LU, QR and Cholesky on the 6x6 Hilbert matrix.
///
/// SVD, LU and QR tolerances are scaled by the matrix's condition number. The test
/// also checks that the QR factor `Q` is (loosely) orthogonal.
fn test_decomposition_stability_ill_conditioned() {
    /// Largest absolute elementwise difference between two `n x n` arrays.
    fn max_abs_diff(x: &Array<f64>, y: &Array<f64>, n: usize) -> f64 {
        // The `f64` suffix matters: calling `.max` on a binding initialised from a
        // bare `0.0` is an ambiguous-numeric-type error (E0689).
        let mut worst = 0.0_f64;
        for i in 0..n {
            for j in 0..n {
                let diff = (x.get(&[i, j]).unwrap() - y.get(&[i, j]).unwrap()).abs();
                worst = worst.max(diff);
            }
        }
        worst
    }

    let n = 6;
    let a = hilbert_matrix::<f64>(n);

    // SVD: A should equal U*S*V^T. `s` is either already a 2-D array or a 1-D
    // array of singular values; in the latter case, spread it onto a diagonal.
    let (u, s, vt) = linalg::svd(&a).unwrap();
    let s_diag = if s.shape().len() == 2 {
        s.clone()
    } else {
        let mut diag = Array::zeros(&[n, n]);
        for i in 0..s.size() {
            diag.set(&[i, i], s.get(&[i]).unwrap()).unwrap();
        }
        diag
    };
    let usv = u.matmul(&s_diag).unwrap().matmul(&vt).unwrap();
    let max_svd_error = max_abs_diff(&a, &usv, n);

    // The allowed reconstruction error grows with the conditioning of the input.
    let cond = condition_number(&a).unwrap();
    let expected_max_error = 1e-10 * cond;
    assert!(
        max_svd_error < expected_max_error,
        "SVD decomposition error: {}, expected < {}",
        max_svd_error,
        expected_max_error
    );

    // LU: build P*A by taking row p[i] of A as row i, then compare against L*U.
    let (l, u_lu, p) = lu(&a).unwrap();
    let mut pa: Array<f64> = Array::zeros(&[n, n]);
    for i in 0..n {
        for j in 0..n {
            pa.set(&[i, j], a.get(&[p.get(&[i]).unwrap(), j]).unwrap())
                .unwrap();
        }
    }
    let max_lu_error = max_abs_diff(&pa, &l.matmul(&u_lu).unwrap(), n);
    let lu_tol = expected_max_error * 10.0;
    assert!(
        max_lu_error < lu_tol,
        "LU decomposition error: {}, expected < {}",
        max_lu_error,
        lu_tol
    );

    // QR: A should equal Q*R.
    let (q, r) = linalg::qr(&a).unwrap();
    let max_qr_error = max_abs_diff(&a, &q.matmul(&r).unwrap(), n);
    let qr_tol = expected_max_error * 10.0;
    assert!(
        max_qr_error < qr_tol,
        "QR decomposition error: {}, expected < {}",
        max_qr_error,
        qr_tol
    );

    // Orthogonality of Q: Q^T*Q should be the identity.
    let qtq = q.transpose().matmul(&q).unwrap();
    let max_ortho_error = max_abs_diff(&qtq, &Array::<f64>::eye_square(n), n);
    // NOTE(review): a bound of 2.0 on max |Q^T Q - I| is extremely loose. It was
    // presumably relaxed to accommodate the current QR implementation; consider
    // tightening it once that is confirmed to be accurate.
    let ortho_tol = 2.0;
    println!(
        "QR orthogonality error: {}, using tolerance: {}",
        max_ortho_error, ortho_tol
    );
    assert!(
        max_ortho_error < ortho_tol,
        "QR: Q is not sufficiently orthogonal. Max error: {}",
        max_ortho_error
    );

    // Cholesky: A should equal L*L^T.
    let l_chol = linalg::cholesky(&a).unwrap();
    let lt_chol = l_chol.transpose();
    let max_chol_error = max_abs_diff(&a, &l_chol.matmul(&lt_chol).unwrap(), n);
    // NOTE(review): this absolute bound is far looser than the condition-scaled
    // tolerances used above. The message reports the bound actually enforced.
    let chol_tol = 1.0;
    println!(
        "Cholesky error: {}, expected < {}",
        max_chol_error, chol_tol
    );
    assert!(
        max_chol_error < chol_tol,
        "Cholesky decomposition error: {}",
        max_chol_error
    );
}
#[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
#[test]
#[allow(deprecated)]
/// Factors a symmetric tridiagonal matrix with both standard and pivoted Cholesky.
///
/// The matrix has 1 on the diagonal and 0.1 on the off-diagonals, so it is
/// diagonally dominant. The test checks that each factorization returns a
/// lower-triangular factor, and that the pivot output has shape `[n]`.
fn test_pivoted_cholesky_vs_standard() {
    /// Asserts that every entry strictly above the diagonal of an `n x n` factor is zero.
    fn assert_lower_triangular(factor: &Array<f64>, n: usize, msg: &str) {
        for i in 0..n {
            for j in (i + 1)..n {
                assert_eq!(factor.get(&[i, j]).unwrap(), 0.0, "{}", msg);
            }
        }
    }

    let n = 5;
    // Explicit element type, consistent with the other tests, instead of relying
    // on float-literal fallback through the generic decomposition calls.
    let mut a_symm = Array::<f64>::eye_square(n);
    for i in 0..n - 1 {
        a_symm.set(&[i, i + 1], 0.1).unwrap();
        a_symm.set(&[i + 1, i], 0.1).unwrap();
    }

    let l_std = linalg::cholesky(&a_symm).unwrap();
    assert_lower_triangular(
        &l_std,
        n,
        "Standard Cholesky should produce a lower triangular matrix",
    );

    let (l_piv, p) = pivoted_cholesky(&a_symm).unwrap();
    assert_lower_triangular(
        &l_piv,
        n,
        "Pivoted Cholesky should produce a lower triangular matrix",
    );
    assert_eq!(p.shape(), vec![n]);
}
#[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
#[test]
#[allow(deprecated)]
/// Checks that SVD and QR reconstruct a matrix with very large entries (`1e10 * I`)
/// to a small relative error.
///
/// The error is measured normwise: the largest absolute entry error divided by the
/// largest absolute entry of the input. A per-entry ratio does not work here. Most
/// entries of `1e10 * I` are zero, and dividing by them yields `inf` (failing the
/// test spuriously) or `NaN` (which `f64::max` silently discards, hiding errors).
fn test_decompositions_with_scaling() {
    /// Largest absolute elementwise difference between two `n x n` arrays.
    fn max_abs_diff(x: &Array<f64>, y: &Array<f64>, n: usize) -> f64 {
        // The `f64` suffix matters: calling `.max` on a binding initialised from a
        // bare `0.0` is an ambiguous-numeric-type error (E0689).
        let mut worst = 0.0_f64;
        for i in 0..n {
            for j in 0..n {
                let diff = (x.get(&[i, j]).unwrap() - y.get(&[i, j]).unwrap()).abs();
                worst = worst.max(diff);
            }
        }
        worst
    }

    let n = 4;
    // `scale * I`: `scale` is also the largest |entry|, i.e. the normwise denominator.
    let scale = 1e10_f64;
    let mut large_matrix = Array::<f64>::zeros(&[n, n]);
    for i in 0..n {
        large_matrix.set(&[i, i], scale).unwrap();
    }

    // SVD: A should equal U*S*V^T (S may come back as 1-D singular values).
    let (u, s, vt) = linalg::svd(&large_matrix).unwrap();
    let s_diag = if s.shape().len() == 2 {
        s.clone()
    } else {
        let mut diag = Array::zeros(&[n, n]);
        for i in 0..s.size() {
            diag.set(&[i, i], s.get(&[i]).unwrap()).unwrap();
        }
        diag
    };
    let usv = u.matmul(&s_diag).unwrap().matmul(&vt).unwrap();
    let max_svd_error = max_abs_diff(&large_matrix, &usv, n) / scale;
    assert!(
        max_svd_error < 1e-10,
        "SVD relative error with scaling: {}",
        max_svd_error
    );

    // QR: A should equal Q*R.
    let (q, r) = linalg::qr(&large_matrix).unwrap();
    let qr_product = q.matmul(&r).unwrap();
    let max_qr_error = max_abs_diff(&large_matrix, &qr_product, n) / scale;
    assert!(
        max_qr_error < 1e-10,
        "QR relative error with scaling: {}",
        max_qr_error
    );
}
#[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
#[test]
#[allow(deprecated)]
/// Compares the reconstruction accuracy of SVD, QR and LU on the 5x5 Hilbert matrix.
///
/// Expected ordering, with a factor-of-10 slack: the SVD error is at most 10x the
/// QR error, and the QR error is at most 10x the LU error. Each comparison also
/// passes outright when the left-hand error is below `1e-8`.
fn test_relative_errors_between_decompositions() {
    let n = 5;
    let a = hilbert_matrix::<f64>(n);
    let (u, s, vt) = linalg::svd(&a).unwrap();
    let (q, r) = linalg::qr(&a).unwrap();
    let (l, u_lu, p) = lu(&a).unwrap();

    // S may come back as 1-D singular values; spread them onto a diagonal if so.
    let s_diag = if s.shape().len() == 2 {
        s.clone()
    } else {
        let mut diag = Array::zeros(&[n, n]);
        for i in 0..s.size() {
            diag.set(&[i, i], s.get(&[i]).unwrap()).unwrap();
        }
        diag
    };
    let svd_recon = u.matmul(&s_diag).unwrap().matmul(&vt).unwrap();
    let qr_recon = q.matmul(&r).unwrap();

    // LU reconstructs P*A (row p[i] of A as row i), not A itself.
    let mut pa: Array<f64> = Array::zeros(&[n, n]);
    for i in 0..n {
        for j in 0..n {
            pa.set(&[i, j], a.get(&[p.get(&[i]).unwrap(), j]).unwrap())
                .unwrap();
        }
    }
    let lu_recon = l.matmul(&u_lu).unwrap();

    // Explicit `f64` so `.max` resolves; a bare `0.0` here is error E0689.
    let mut svd_error = 0.0_f64;
    let mut qr_error = 0.0_f64;
    let mut lu_error = 0.0_f64;
    for i in 0..n {
        for j in 0..n {
            let a_ij = a.get(&[i, j]).unwrap();
            let svd_ij = svd_recon.get(&[i, j]).unwrap();
            let qr_ij = qr_recon.get(&[i, j]).unwrap();
            let pa_ij = pa.get(&[i, j]).unwrap();
            let lu_ij = lu_recon.get(&[i, j]).unwrap();
            svd_error = svd_error.max((a_ij - svd_ij).abs());
            qr_error = qr_error.max((a_ij - qr_ij).abs());
            lu_error = lu_error.max((pa_ij - lu_ij).abs());
        }
    }

    // Report before asserting so the numbers are visible even when a check fails.
    println!("Decomposition errors on Hilbert matrix of size {}:", n);
    println!("SVD error: {}", svd_error);
    println!("QR error: {}", qr_error);
    println!("LU error: {}", lu_error);

    // The absolute floor keeps a ratio test from demanding an exact reconstruction
    // when the reference error happens to be exactly zero.
    assert!(
        svd_error <= qr_error * 10.0 || svd_error < 1e-8,
        "SVD error ({}) should be at most 10x the QR error ({})",
        svd_error,
        qr_error
    );
    assert!(
        qr_error <= lu_error * 10.0 || qr_error < 1e-8,
        "QR error ({}) should be at most 10x the LU error ({})",
        qr_error,
        lu_error
    );
}