use core::hint::cold_path;
use crate::LaError;
use crate::matrix::Matrix;
use crate::vector::Vector;
/// Packed result of [`Ldlt::factor`] for a `D`×`D` matrix, together with the
/// pivot tolerance that was used to produce it.
///
/// The value is `Copy`, so it can be passed around freely and reused for any
/// number of [`Ldlt::det`] / [`Ldlt::solve_vec`] calls.
#[must_use]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Ldlt<const D: usize> {
/// In-place factor storage written by `factor`: entry `(j, j)` holds the
/// pivot for column `j`, and entry `(i, j)` with `i > j` holds the multiplier
/// obtained by dividing by that pivot. `factor` never writes the strict upper
/// triangle, and neither `det` nor `solve_vec` reads it.
factors: Matrix<D>,
/// Pivot threshold supplied to `factor`; a pivot `<= tol` is rejected as
/// singular. `solve_vec` re-applies the same threshold to every diagonal
/// entry before dividing by it.
tol: f64,
}
impl<const D: usize> Ldlt<D> {
    /// Validates one diagonal pivot against `tol`.
    ///
    /// Returns the pivot unchanged when it is finite and strictly greater than
    /// `tol`.
    ///
    /// # Errors
    ///
    /// * `LaError::non_finite_cell(index, index)` when the pivot is NaN or
    ///   infinite. This test runs first, so `-inf` is reported as non-finite
    ///   rather than as singular.
    /// * `LaError::Singular { pivot_col: index }` when the pivot is `<= tol`.
    #[inline]
    const fn checked_pivot(pivot: f64, tol: f64, index: usize) -> Result<f64, LaError> {
        if !pivot.is_finite() {
            cold_path();
            return Err(LaError::non_finite_cell(index, index));
        }
        if pivot <= tol {
            cold_path();
            return Err(LaError::Singular { pivot_col: index });
        }
        Ok(pivot)
    }

    /// Factors `a` column by column, working only on the diagonal and the
    /// strict lower triangle of a copy of the input.
    ///
    /// For each column the diagonal pivot is validated, every entry below it
    /// is divided by the pivot (the multipliers), and the remaining
    /// lower-triangular entries are reduced by `multiplier_i * pivot *
    /// multiplier_k` using a fused multiply-add.
    ///
    /// # Errors
    ///
    /// * `LaError::non_finite_cell(r, c)` for the first pivot, multiplier or
    ///   updated entry that is NaN or infinite.
    /// * `LaError::Singular { pivot_col }` for the first pivot that is
    ///   `<= tol`.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled: when `tol >= 0.0` does not
    /// hold, or when `debug_assert_symmetric` rejects `a`.
    #[inline]
    pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
        debug_assert!(tol >= 0.0, "tol must be non-negative");

        #[cfg(debug_assertions)]
        debug_assert_symmetric(&a);

        let mut work = a;
        for col in 0..D {
            let pivot = Self::checked_pivot(work.rows[col][col], tol, col)?;
            let below = (col + 1)..D;

            // Pass 1: turn column `col` below the diagonal into multipliers.
            // This pass completes (or fails) before any trailing update runs,
            // which fixes which error is reported first.
            for row in below.clone() {
                let multiplier = work.rows[row][col] / pivot;
                if !multiplier.is_finite() {
                    cold_path();
                    return Err(LaError::non_finite_cell(row, col));
                }
                work.rows[row][col] = multiplier;
            }

            // Pass 2: update the trailing lower triangle (diagonal included).
            for row in below {
                let scaled = work.rows[row][col] * pivot;
                for inner in (col + 1)..=row {
                    let updated =
                        (-scaled).mul_add(work.rows[inner][col], work.rows[row][inner]);
                    if !updated.is_finite() {
                        cold_path();
                        return Err(LaError::non_finite_cell(row, inner));
                    }
                    work.rows[row][inner] = updated;
                }
            }
        }

        Ok(Self { factors: work, tol })
    }

    /// Returns the product of the stored diagonal entries, multiplied in
    /// ascending index order starting from `1.0` (so `D == 0` yields `1.0`).
    #[inline]
    #[must_use]
    pub const fn det(&self) -> f64 {
        let mut product = 1.0_f64;
        let mut idx = 0;
        // `loop` rather than an iterator: iterators are unavailable in `const fn`.
        loop {
            if idx >= D {
                break product;
            }
            product *= self.factors.rows[idx][idx];
            idx += 1;
        }
    }

    /// Solves for `x` using the stored factors in three sequential stages:
    /// forward substitution with the unit-diagonal lower multipliers, division
    /// by each diagonal entry, then back substitution with the same
    /// multipliers read transposed.
    ///
    /// # Errors
    ///
    /// * `LaError::non_finite_at(i)` when an intermediate or final component
    ///   `i` becomes NaN or infinite.
    /// * `LaError::non_finite_cell(i, i)` when a stored diagonal entry is NaN
    ///   or infinite.
    /// * `LaError::Singular { pivot_col: i }` when a stored diagonal entry is
    ///   `<= tol`.
    ///
    /// Each stage finishes for all rows before the next begins, so an error
    /// from an earlier stage always takes precedence.
    #[inline]
    pub const fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
        let mut sol = b.data;

        // Stage 1: forward substitution (implicit unit diagonal).
        let mut row = 0;
        while row < D {
            let mut acc = sol[row];
            let mut col = 0;
            while col < row {
                acc = (-self.factors.rows[row][col]).mul_add(sol[col], acc);
                col += 1;
            }
            if !acc.is_finite() {
                cold_path();
                return Err(LaError::non_finite_at(row));
            }
            sol[row] = acc;
            row += 1;
        }

        // Stage 2: divide by the diagonal, re-validating every pivot first.
        let mut idx = 0;
        while idx < D {
            // Explicit `match`: the `?` operator is not usable in `const fn`.
            let pivot = match Self::checked_pivot(self.factors.rows[idx][idx], self.tol, idx) {
                Ok(value) => value,
                Err(err) => return Err(err),
            };
            let scaled = sol[idx] / pivot;
            if !scaled.is_finite() {
                cold_path();
                return Err(LaError::non_finite_at(idx));
            }
            sol[idx] = scaled;
            idx += 1;
        }

        // Stage 3: back substitution, visiting rows D-1 down to 0 and reading
        // the multipliers column-wise (i.e. transposed).
        let mut row = D;
        while row > 0 {
            row -= 1;
            let mut acc = sol[row];
            let mut below = row + 1;
            while below < D {
                acc = (-self.factors.rows[below][row]).mul_add(sol[below], acc);
                below += 1;
            }
            if !acc.is_finite() {
                cold_path();
                return Err(LaError::non_finite_at(row));
            }
            sol[row] = acc;
        }

        Ok(Vector::new(sol))
    }
}
/// Debug-only guard that panics when `a` is reported as asymmetric.
///
/// Asks `Matrix::first_asymmetry` for an offending `(row, col)` pair and, if
/// one is returned, fails a debug assertion whose message includes the
/// absolute difference of the two mirrored entries and an `eps` value.
///
/// NOTE(review): `eps` is recomputed here as `1e-12 * max(inf_norm, 1.0)`
/// purely for the message; presumably it mirrors the threshold used inside
/// `first_asymmetry` — confirm against that method.
#[cfg(debug_assertions)]
fn debug_assert_symmetric<const D: usize>(a: &Matrix<D>) {
    // Single source for the tolerance passed to the probe and used for `eps`.
    const REL_TOL: f64 = 1e-12;

    let Some((r, c)) = a.first_asymmetry(REL_TOL) else {
        return;
    };

    let eps = REL_TOL * a.inf_norm().max(1.0);
    let diff = (a.rows[r][c] - a.rows[c][r]).abs();
    debug_assert!(
        false,
        "matrix must be symmetric (diff={diff}, eps={eps}) at ({r}, {c}); \
         pre-validate with Matrix::is_symmetric before calling ldlt"
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_SINGULAR_TOL;
    use core::hint::black_box;
    use approx::assert_abs_diff_eq;
    use pastey::paste;

    // Sample values shared by the dimension-generic tests below.
    const ASCENDING: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
    const DESCENDING: [f64; 5] = [5.0, 4.0, 3.0, 2.0, 1.0];

    // Copies the leading entries of `values` into a zero-initialised
    // `[f64; N]`; slots past `values.len()` stay `0.0`.
    fn leading<const N: usize>(values: &[f64]) -> [f64; N] {
        let mut out = [0.0f64; N];
        for (slot, &value) in out.iter_mut().zip(values) {
            *slot = value;
        }
        out
    }

    // Identity matrix through the public `Matrix::ldlt` entry point:
    // determinant is 1 and solving returns the right-hand side unchanged.
    macro_rules! gen_public_api_ldlt_identity_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_identity_ $d d>]() {
                    let ldlt = Matrix::<$d>::identity()
                        .ldlt(DEFAULT_SINGULAR_TOL)
                        .unwrap();
                    assert_abs_diff_eq!(ldlt.det(), 1.0, epsilon = 1e-12);

                    let rhs = leading::<$d>(&ASCENDING);
                    let solution = ldlt
                        .solve_vec(Vector::<$d>::new(black_box(rhs)))
                        .unwrap()
                        .into_array();
                    for (got, want) in solution.iter().zip(rhs.iter()) {
                        assert_abs_diff_eq!(*got, *want, epsilon = 1e-12);
                    }
                }
            }
        };
    }
    gen_public_api_ldlt_identity_tests!(2);
    gen_public_api_ldlt_identity_tests!(3);
    gen_public_api_ldlt_identity_tests!(4);
    gen_public_api_ldlt_identity_tests!(5);

    // Positive diagonal matrix: determinant is the product of the diagonal and
    // each solution component is `b[i] / diag[i]`.
    macro_rules! gen_public_api_ldlt_diagonal_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_diagonal_spd_ $d d>]() {
                    let diag = leading::<$d>(&ASCENDING);
                    let mut rows = [[0.0f64; $d]; $d];
                    for (i, &entry) in diag.iter().enumerate() {
                        rows[i][i] = entry;
                    }
                    let ldlt = Matrix::<$d>::from_rows(black_box(rows))
                        .ldlt(DEFAULT_SINGULAR_TOL)
                        .unwrap();

                    let expected_det = diag.iter().fold(1.0_f64, |acc, &entry| acc * entry);
                    assert_abs_diff_eq!(ldlt.det(), expected_det, epsilon = 1e-12);

                    let rhs = leading::<$d>(&DESCENDING);
                    let solution = ldlt
                        .solve_vec(Vector::<$d>::new(black_box(rhs)))
                        .unwrap()
                        .into_array();
                    for ((got, rhs_i), diag_i) in solution.iter().zip(&rhs).zip(&diag) {
                        assert_abs_diff_eq!(*got, rhs_i / diag_i, epsilon = 1e-12);
                    }
                }
            }
        };
    }
    gen_public_api_ldlt_diagonal_tests!(2);
    gen_public_api_ldlt_diagonal_tests!(3);
    gen_public_api_ldlt_diagonal_tests!(4);
    gen_public_api_ldlt_diagonal_tests!(5);

    // Hand-checked 2x2 system, calling `Ldlt::factor` directly.
    #[test]
    fn solve_2x2_known_spd() {
        let a = Matrix::<2>::from_rows(black_box([[4.0, 2.0], [2.0, 3.0]]));
        let factor = black_box(Ldlt::<2>::factor);
        let ldlt = factor(a, DEFAULT_SINGULAR_TOL).unwrap();

        let [x0, x1] = ldlt
            .solve_vec(Vector::<2>::new(black_box([1.0, 2.0])))
            .unwrap()
            .into_array();
        assert_abs_diff_eq!(x0, -0.125, epsilon = 1e-12);
        assert_abs_diff_eq!(x1, 0.75, epsilon = 1e-12);
        assert_abs_diff_eq!(ldlt.det(), 8.0, epsilon = 1e-12);
    }

    // Tridiagonal (2, -1) matrix with b = [1, 0, 1]; the solution is all ones.
    #[test]
    fn solve_3x3_spd_tridiagonal_smoke() {
        let a = Matrix::<3>::from_rows(black_box([
            [2.0, -1.0, 0.0],
            [-1.0, 2.0, -1.0],
            [0.0, -1.0, 2.0],
        ]));
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();
        let solution = ldlt
            .solve_vec(Vector::<3>::new(black_box([1.0, 0.0, 1.0])))
            .unwrap()
            .into_array();
        for x_i in solution.iter().copied() {
            assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
        }
    }

    // Rank-1 matrix: the second pivot collapses to zero.
    #[test]
    fn singular_detected_for_degenerate_psd() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 1.0], [1.0, 1.0]]));
        assert_eq!(
            a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(),
            LaError::Singular { pivot_col: 1 }
        );
    }

    // NaN in the very first pivot.
    #[test]
    fn nonfinite_detected() {
        let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
        assert_eq!(
            a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(),
            LaError::NonFinite {
                row: Some(0),
                col: 0
            }
        );
    }

    // Tiny pivot with a huge off-diagonal entry: the multiplier overflows.
    #[test]
    fn nonfinite_l_multiplier_overflow() {
        let a = Matrix::<2>::from_rows([[1e-11, 1e300], [1e300, 1.0]]);
        assert_eq!(
            a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(),
            LaError::NonFinite {
                row: Some(1),
                col: 0
            }
        );
    }

    // Finite multiplier whose square overflows during the trailing update.
    #[test]
    fn nonfinite_trailing_submatrix_overflow() {
        let a = Matrix::<2>::from_rows([[1.0, 1e200], [1e200, 1.0]]);
        assert_eq!(
            a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(),
            LaError::NonFinite {
                row: Some(1),
                col: 1
            }
        );
    }

    // Factorization succeeds, but forward substitution overflows at index 1.
    #[test]
    fn nonfinite_solve_vec_forward_substitution_overflow() {
        let a = Matrix::<3>::from_rows([
            [1.0, 1e153, 0.0],
            [1e153, 1e306 + 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]);
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();
        let outcome = ldlt.solve_vec(Vector::<3>::new([1e156, 0.0, 0.0]));
        assert_eq!(
            outcome.unwrap_err(),
            LaError::NonFinite { row: None, col: 1 }
        );
    }

    // Back substitution overflows at index 1.
    #[test]
    fn nonfinite_solve_vec_back_substitution_overflow() {
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 2.0], [0.0, 2.0, 5.0]]);
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();
        let outcome = ldlt.solve_vec(Vector::<3>::new([0.0, 0.0, 1e308]));
        assert_eq!(
            outcome.unwrap_err(),
            LaError::NonFinite { row: None, col: 1 }
        );
    }

    // Division by a small (but accepted) diagonal entry overflows at index 1.
    #[test]
    fn nonfinite_solve_vec_diagonal_solve_overflow() {
        let a = Matrix::<2>::from_rows([[1.0, 0.0], [0.0, 1.0e-11]]);
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();
        let outcome = ldlt.solve_vec(Vector::<2>::new([0.0, 1.0e300]));
        assert_eq!(
            outcome.unwrap_err(),
            LaError::NonFinite { row: None, col: 1 }
        );
    }

    // The symmetry guard only exists when debug assertions are enabled.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "matrix must be symmetric")]
    fn debug_asymmetric_input_panics() {
        let a = Matrix::<3>::from_rows([[4.0, 2.0, 0.0], [-2.0, 5.0, 1.0], [0.0, 1.0, 3.0]]);
        let _ = a.ldlt(DEFAULT_SINGULAR_TOL);
    }

    // Hand-built `Ldlt` values with a corrupted last diagonal entry, to reach
    // the defensive checks in `solve_vec` that `factor` would normally prevent.
    macro_rules! gen_solve_vec_defensive_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<solve_vec_defensive_non_finite_diagonal_ $d d>]() {
                    const LAST: usize = $d - 1;
                    let mut factors = Matrix::<$d>::identity();
                    factors.rows[LAST][LAST] = f64::NAN;
                    let ldlt = Ldlt::<$d> {
                        factors,
                        tol: DEFAULT_SINGULAR_TOL,
                    };
                    let outcome = ldlt.solve_vec(Vector::<$d>::new([1.0; $d]));
                    assert_eq!(
                        outcome.unwrap_err(),
                        LaError::NonFinite {
                            row: Some(LAST),
                            col: LAST,
                        }
                    );
                }

                #[test]
                fn [<solve_vec_defensive_sub_tolerance_diagonal_ $d d>]() {
                    const LAST: usize = $d - 1;
                    let mut factors = Matrix::<$d>::identity();
                    factors.rows[LAST][LAST] = 0.0;
                    let ldlt = Ldlt::<$d> {
                        factors,
                        tol: DEFAULT_SINGULAR_TOL,
                    };
                    let outcome = ldlt.solve_vec(Vector::<$d>::new([1.0; $d]));
                    assert_eq!(outcome.unwrap_err(), LaError::Singular { pivot_col: LAST });
                }
            }
        };
    }
    gen_solve_vec_defensive_tests!(2);
    gen_solve_vec_defensive_tests!(3);
    gen_solve_vec_defensive_tests!(4);
    gen_solve_vec_defensive_tests!(5);

    // `det` and `solve_vec` must be usable in constant evaluation.
    macro_rules! gen_ldlt_const_eval_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<ldlt_det_const_eval_ $d d>]() {
                    const DET: f64 = {
                        let mut packed = Matrix::<$d>::identity();
                        packed.rows[0][0] = 2.0;
                        let ldlt = Ldlt::<$d> {
                            factors: packed,
                            tol: DEFAULT_SINGULAR_TOL,
                        };
                        ldlt.det()
                    };
                    assert!((DET - 2.0).abs() <= 1e-12);
                }

                #[test]
                fn [<ldlt_solve_vec_const_eval_ $d d>]() {
                    #[allow(clippy::cast_precision_loss)]
                    const X: [f64; $d] = {
                        let ldlt = Ldlt::<$d> {
                            factors: Matrix::<$d>::identity(),
                            tol: DEFAULT_SINGULAR_TOL,
                        };
                        // `while` loop: iterators are unavailable in const contexts.
                        let mut rhs = [0.0f64; $d];
                        let mut k = 0;
                        while k < $d {
                            rhs[k] = k as f64 + 1.0;
                            k += 1;
                        }
                        match ldlt.solve_vec(Vector::<$d>::new(rhs)) {
                            Ok(v) => v.into_array(),
                            // An all-zero fallback fails the comparison below.
                            Err(_) => [0.0f64; $d],
                        }
                    };
                    #[allow(clippy::cast_precision_loss)]
                    for (i, actual) in X.iter().enumerate() {
                        let expected = i as f64 + 1.0;
                        assert!((actual - expected).abs() <= 1e-12);
                    }
                }
            }
        };
    }
    gen_ldlt_const_eval_tests!(2);
    gen_ldlt_const_eval_tests!(3);
    gen_ldlt_const_eval_tests!(4);
    gen_ldlt_const_eval_tests!(5);
}