#![forbid(unsafe_code)]
use core::hint::cold_path;
use crate::matrix::{Matrix, SymmetricMatrix};
use crate::vector::Vector;
use crate::{LaError, Tolerance};
/// LDL^T factorization of a `D`x`D` symmetric matrix.
///
/// Instances are built by the crate-internal `factor_symmetric` constructor and are
/// consumed by `det` and `solve`. The type is `Copy`, so a factorization can be reused
/// for several right-hand sides.
#[must_use]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Ldlt<const D: usize> {
// The factors, stored together in a single `D`x`D` matrix.
factors: LdltFactors<D>,
}
/// Storage for both factors in one matrix.
///
/// As used by the solver in this file: the diagonal holds the entries of `D`, and the
/// strict lower triangle holds the multipliers of the unit-lower-triangular `L`.
/// The strict upper triangle is never read by `det` or `solve_finite`.
#[derive(Clone, Copy, Debug, PartialEq)]
struct LdltFactors<const D: usize> {
// Backing matrix; contents are not validated by `new_unchecked`.
storage: Matrix<D>,
}
// Read-only, const-evaluable accessors over the packed factor storage.
impl<const D: usize> LdltFactors<D> {
    /// Wraps `storage` as factor storage without inspecting its contents.
    #[inline]
    const fn new_unchecked(storage: Matrix<D>) -> Self {
        Self { storage }
    }

    /// Borrows row `index` of the stored matrix.
    ///
    /// Panics if `index` is out of bounds.
    #[inline]
    #[must_use]
    const fn row(&self, index: usize) -> &[f64; D] {
        let all_rows = self.storage.rows();
        &all_rows[index]
    }

    /// Returns the stored value at (`row`, `col`).
    ///
    /// Panics if either index is out of bounds.
    #[inline]
    #[must_use]
    const fn entry(&self, row: usize, col: usize) -> f64 {
        self.row(row)[col]
    }

    /// Returns the diagonal value at (`index`, `index`).
    ///
    /// Panics if `index` is out of bounds.
    #[inline]
    #[must_use]
    const fn diag(&self, index: usize) -> f64 {
        self.entry(index, index)
    }
}
impl<const D: usize> Ldlt<D> {
/// Factors the symmetric matrix `a` in place as `A = L D L^T`, without pivoting.
///
/// Only the diagonal and the strict lower triangle of the working matrix are read and
/// written. After the loop the diagonal holds `D` and the strict lower triangle holds
/// the multipliers of `L`; the strict upper triangle keeps its input values.
///
/// # Errors
///
/// Checked per pivot column `j`, in this order:
/// - `LaError::non_finite_cell(j, j)` if the pivot is NaN or infinite (this is also
///   where an overflow in an earlier trailing update of the diagonal surfaces);
/// - `LaError::not_positive_semidefinite(j, d)` if the pivot `d` is negative;
/// - `LaError::Singular { pivot_col: j }` if the pivot is `<= tol.get()`;
/// - `LaError::non_finite_cell(i, j)` if the multiplier `rows[i][j] / d` is not finite.
#[inline]
// Index loops are kept deliberately: the update reads `rows[k][j]` and writes
// `rows[i][k]`, so several rows are addressed by index in the same statement.
#[allow(clippy::needless_range_loop)]
pub(crate) fn factor_symmetric(a: SymmetricMatrix<D>, tol: Tolerance) -> Result<Self, LaError> {
let mut f = a.into_matrix();
let tol = tol.get();
// Inner scope: the mutable view `rows` is only needed while factoring; `f` is moved
// into the result afterwards.
{
let rows = f.rows_mut_unchecked();
for j in 0..D {
// Pivot for column j, already updated by all previous columns.
let d = rows[j][j];
if !d.is_finite() {
cold_path();
return Err(LaError::non_finite_cell(j, j));
}
if d < 0.0 {
cold_path();
return Err(LaError::not_positive_semidefinite(j, d));
}
if d <= tol {
cold_path();
return Err(LaError::Singular { pivot_col: j });
}
// Both branches below perform the same arithmetic on the same values; they only
// differ in loop structure. NOTE(review): the `D <= 5` cut-off looks like a
// performance tuning choice — confirm with benchmarks.
if D <= 5 {
// Pass 1: scale column j below the diagonal into multipliers l = a_ij / d.
for i in (j + 1)..D {
let l = rows[i][j] / d;
if !l.is_finite() {
cold_path();
return Err(LaError::non_finite_cell(i, j));
}
rows[i][j] = l;
}
// Pass 2: update the trailing lower triangle,
// rows[i][k] -= l_i * d * l_k for j < k <= i.
for i in (j + 1)..D {
let l_i = rows[i][j];
let l_i_d = l_i * d;
for k in (j + 1)..=i {
let l_k = rows[k][j];
let new_val = (-l_i_d).mul_add(l_k, rows[i][k]);
rows[i][k] = new_val;
}
}
} else {
// Fused variant: scale row i's multiplier and immediately update row i.
// Every `rows[k][j]` read here has k <= i, so it has already been scaled.
for i in (j + 1)..D {
let l_i = rows[i][j] / d;
if !l_i.is_finite() {
cold_path();
return Err(LaError::non_finite_cell(i, j));
}
rows[i][j] = l_i;
let l_i_d = l_i * d;
for k in (j + 1)..=i {
let l_k = rows[k][j];
let new_val = (-l_i_d).mul_add(l_k, rows[i][k]);
rows[i][k] = new_val;
}
}
}
}
}
Ok(Self {
factors: LdltFactors::new_unchecked(f),
})
}
/// Computes the determinant as the running product of the stored diagonal entries.
///
/// An empty factorization (`D == 0`) yields `Ok(1.0)`.
///
/// # Errors
///
/// Returns `LaError::non_finite_at(i)` for the first index `i` at which the running
/// product stops being finite.
#[inline]
pub const fn det(&self) -> Result<f64, LaError> {
    let mut product = 1.0;
    let mut idx = 0;
    while idx < D {
        product *= self.factors.diag(idx);
        if product.is_finite() {
            idx += 1;
        } else {
            cold_path();
            return Err(LaError::non_finite_at(idx));
        }
    }
    Ok(product)
}
/// Solves the linear system for right-hand side `b` using the stored factors.
///
/// Thin public entry point: forwards directly to `solve_finite`.
///
/// # Errors
///
/// Propagates the error from `solve_finite`, which reports the index of the first
/// intermediate value that is not finite.
#[inline]
pub const fn solve(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
self.solve_finite(b)
}
/// Solves `L D L^T x = b` in three stages, overwriting a local copy of `b`.
///
/// 1. Forward substitution with the unit-lower-triangular `L`.
/// 2. Division by the diagonal `D`.
/// 3. Back substitution with `L^T`: a row-gather loop for `D <= 4`, and a
///    column-scatter loop otherwise. Both compute the same unknowns.
///
/// # Errors
///
/// Returns `LaError::non_finite_at(i)` for the first component `i` whose intermediate
/// or final value is NaN or infinite.
#[inline]
pub(crate) const fn solve_finite(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
    let mut x = b.into_array();

    // Stage 1: forward substitution (the diagonal of L is implicitly 1).
    let mut r = 0;
    while r < D {
        let l_row = self.factors.row(r);
        let mut acc = x[r];
        let mut c = 0;
        while c < r {
            acc = (-l_row[c]).mul_add(x[c], acc);
            c += 1;
        }
        if !acc.is_finite() {
            cold_path();
            return Err(LaError::non_finite_at(r));
        }
        x[r] = acc;
        r += 1;
    }

    // Stage 2: divide each component by its diagonal entry.
    let mut r = 0;
    while r < D {
        let scaled = x[r] / self.factors.diag(r);
        if !scaled.is_finite() {
            cold_path();
            return Err(LaError::non_finite_at(r));
        }
        x[r] = scaled;
        r += 1;
    }

    // Stage 3: back substitution with L^T, from the last component to the first.
    if D > 4 {
        // Scatter form: once x[col] is final, subtract its contribution from every
        // earlier component using row `col` of L.
        let mut remaining = D;
        while remaining > 0 {
            let col = remaining - 1;
            let solved = x[col];
            if !solved.is_finite() {
                cold_path();
                return Err(LaError::non_finite_at(col));
            }
            let l_row = self.factors.row(col);
            let mut r = 0;
            while r < col {
                x[r] = (-l_row[r]).mul_add(solved, x[r]);
                r += 1;
            }
            remaining = col;
        }
    } else {
        // Gather form: each component collects contributions from the already
        // solved components below it (column `r` of L).
        let mut remaining = D;
        while remaining > 0 {
            let r = remaining - 1;
            let mut acc = x[r];
            let mut below = remaining;
            while below < D {
                acc = (-self.factors.entry(below, r)).mul_add(x[below], acc);
                below += 1;
            }
            if !acc.is_finite() {
                cold_path();
                return Err(LaError::non_finite_at(r));
            }
            x[r] = acc;
            remaining = r;
        }
    }

    Ok(Vector::new_unchecked(x))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::DEFAULT_SINGULAR_TOL;
use core::hint::black_box;
use approx::assert_abs_diff_eq;
use pastey::paste;
// Generates, for dimension `$d`, the test `public_api_ldlt_det_and_solve_identity_<d>d`.
// It factors the identity matrix through `Matrix::ldlt` and checks that the determinant
// is 1 and that solving returns the right-hand side unchanged.
macro_rules! gen_public_api_ldlt_identity_tests {
($d:literal) => {
paste! {
#[test]
fn [<public_api_ldlt_det_and_solve_identity_ $d d>]() {
let a = Matrix::<$d>::identity();
let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();
assert_abs_diff_eq!(ldlt.det().unwrap(), 1.0, epsilon = 1e-12);
// Right-hand side: the first `$d` values of 1..=5 (`zip` stops at the shorter side).
let b_arr = {
let mut arr = [0.0f64; $d];
let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
for (dst, src) in arr.iter_mut().zip(values.iter()) {
*dst = *src;
}
arr
};
// `black_box` hides the input from the optimizer.
let b = Vector::<$d>::new(black_box(b_arr));
let x = ldlt.solve(b).unwrap().into_array();
for i in 0..$d {
assert_abs_diff_eq!(x[i], b_arr[i], epsilon = 1e-12);
}
}
}
};
}
// Instantiate the identity tests for dimensions 2 through 5.
gen_public_api_ldlt_identity_tests!(2);
gen_public_api_ldlt_identity_tests!(3);
gen_public_api_ldlt_identity_tests!(4);
gen_public_api_ldlt_identity_tests!(5);
// Generates, for dimension `$d`, the test `public_api_ldlt_det_and_solve_diagonal_spd_<d>d`.
// It factors a positive diagonal matrix and checks that the determinant equals the
// product of the diagonal and that each solution component is `b[i] / diag[i]`.
macro_rules! gen_public_api_ldlt_diagonal_tests {
($d:literal) => {
paste! {
#[test]
fn [<public_api_ldlt_det_and_solve_diagonal_spd_ $d d>]() {
// Diagonal entries: the first `$d` values of 1..=5.
let diag = {
let mut arr = [0.0f64; $d];
let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
for (dst, src) in arr.iter_mut().zip(values.iter()) {
*dst = *src;
}
arr
};
let mut rows = [[0.0f64; $d]; $d];
for i in 0..$d {
rows[i][i] = diag[i];
}
let a = Matrix::<$d>::try_from_rows(black_box(rows)).unwrap();
let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();
// Reference determinant: plain product of the diagonal entries.
let expected_det = {
let mut acc = 1.0;
for i in 0..$d {
acc *= diag[i];
}
acc
};
assert_abs_diff_eq!(ldlt.det().unwrap(), expected_det, epsilon = 1e-12);
// Right-hand side: the first `$d` values of 5, 4, 3, 2, 1.
let b_arr = {
let mut arr = [0.0f64; $d];
let values = [5.0f64, 4.0, 3.0, 2.0, 1.0];
for (dst, src) in arr.iter_mut().zip(values.iter()) {
*dst = *src;
}
arr
};
let b = Vector::<$d>::new(black_box(b_arr));
let x = ldlt.solve(b).unwrap().into_array();
for i in 0..$d {
assert_abs_diff_eq!(x[i], b_arr[i] / diag[i], epsilon = 1e-12);
}
}
}
};
}
// Instantiate the diagonal tests for dimensions 2 through 5.
gen_public_api_ldlt_diagonal_tests!(2);
gen_public_api_ldlt_diagonal_tests!(3);
gen_public_api_ldlt_diagonal_tests!(4);
gen_public_api_ldlt_diagonal_tests!(5);
// Degenerate dimension: a 0x0 matrix factors successfully, its determinant is the
// empty product (1), and solving yields an empty vector.
#[test]
fn solve_0x0_returns_empty_vector_and_unit_det() {
    let empty = Matrix::<0>::zero();
    let factorization = empty.ldlt(DEFAULT_SINGULAR_TOL).unwrap();
    assert_eq!(factorization.det(), Ok(1.0));

    let solution = factorization
        .solve(Vector::<0>::zero())
        .unwrap()
        .into_array();
    assert!(solution.is_empty());
}
// Hand-checked 2x2 case: A = [[4, 2], [2, 3]], b = [1, 2]
// gives x = [-1/8, 3/4] and det(A) = 8.
#[test]
fn solve_2x2_known_spd() {
    let spd = Matrix::<2>::try_from_rows(black_box([[4.0, 2.0], [2.0, 3.0]])).unwrap();
    let factorization = spd.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

    let rhs = Vector::<2>::new(black_box([1.0, 2.0]));
    let [x0, x1] = factorization.solve(rhs).unwrap().into_array();

    assert_abs_diff_eq!(x0, -0.125, epsilon = 1e-12);
    assert_abs_diff_eq!(x1, 0.75, epsilon = 1e-12);
    assert_abs_diff_eq!(factorization.det().unwrap(), 8.0, epsilon = 1e-12);
}
// Smoke test on the 3x3 tridiagonal (2, -1) matrix: with b = [1, 0, 1] the exact
// solution is the all-ones vector.
#[test]
fn solve_3x3_spd_tridiagonal_smoke() {
    let tridiagonal = Matrix::<3>::try_from_rows(black_box([
        [2.0, -1.0, 0.0],
        [-1.0, 2.0, -1.0],
        [0.0, -1.0, 2.0],
    ]))
    .unwrap();
    let factorization = tridiagonal.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

    let rhs = Vector::<3>::new(black_box([1.0, 0.0, 1.0]));
    let solution = factorization.solve(rhs).unwrap().into_array();

    for component in solution.iter().copied() {
        assert_abs_diff_eq!(component, 1.0, epsilon = 1e-9);
    }
}
// The rank-1 matrix [[1, 1], [1, 1]] leaves a zero pivot in column 1 after
// eliminating column 0, which must be reported as singular.
#[test]
fn singular_detected_for_degenerate_psd() {
    let rank_one = Matrix::<2>::try_from_rows(black_box([[1.0, 1.0], [1.0, 1.0]])).unwrap();
    let failure = rank_one.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
    let expected = LaError::Singular { pivot_col: 1 };
    assert_eq!(failure, expected);
}
// A negative entry already present on the input diagonal is rejected at pivot 0,
// and the error carries the offending value.
#[test]
fn negative_initial_diagonal_reports_not_positive_semidefinite() {
    let indefinite = Matrix::<2>::try_from_rows(black_box([[-1.0, 0.0], [0.0, 1.0]])).unwrap();
    let failure = indefinite.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
    let expected = LaError::NotPositiveSemidefinite {
        pivot_col: 0,
        value: -1.0,
    };
    assert_eq!(failure, expected);
}
// The second pivot only becomes negative after elimination: 1 - 2 * 1 * 2 = -3.
// The error must report the updated value at pivot 1.
#[test]
fn negative_updated_diagonal_reports_not_positive_semidefinite() {
    let indefinite = Matrix::<2>::try_from_rows(black_box([[1.0, 2.0], [2.0, 1.0]])).unwrap();
    let failure = indefinite.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
    let expected = LaError::NotPositiveSemidefinite {
        pivot_col: 1,
        value: -3.0,
    };
    assert_eq!(failure, expected);
}
// A NaN on the diagonal never reaches the factorization: the matrix constructor
// rejects it and reports its position.
#[test]
fn matrix_constructor_rejects_nonfinite_diagonal() {
    let rows = [[f64::NAN, 0.0], [0.0, 1.0]];
    let rejection = Matrix::<2>::try_from_rows(rows).unwrap_err();
    let expected = LaError::NonFinite {
        row: Some(0),
        col: 0,
    };
    assert_eq!(rejection, expected);
}
// The input has a NaN at (0, 1) and a mismatching 0 at (1, 0); the constructor
// reports the non-finite cell.
#[test]
fn matrix_constructor_rejects_nonfinite_offdiagonal_before_asymmetry() {
    let rows = [[1.0, f64::NAN], [0.0, 1.0]];
    let rejection = Matrix::<2>::try_from_rows(rows).unwrap_err();
    let expected = LaError::NonFinite {
        row: Some(0),
        col: 1,
    };
    assert_eq!(rejection, expected);
}
// Dividing the off-diagonal 1e300 by the tiny pivot 1e-11 overflows, so the
// multiplier at (1, 0) is reported as non-finite.
#[test]
fn nonfinite_l_multiplier_overflow() {
    let rows = [[1e-11, 1e300], [1e300, 1.0]];
    let ill_scaled = Matrix::<2>::try_from_rows(rows).unwrap();
    let failure = ill_scaled.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
    let expected = LaError::NonFinite {
        row: Some(1),
        col: 0,
    };
    assert_eq!(failure, expected);
}
// Same multiplier overflow as the 2x2 case, at dimension 6 so that the factorization
// takes its fused (`D > 5`) branch.
#[test]
fn nonfinite_l_multiplier_overflow_fused_branch_6d() {
    // Start from the 6x6 identity, then plant the tiny pivot and the huge coupling.
    let mut rows: [[f64; 6]; 6] =
        core::array::from_fn(|r| core::array::from_fn(|c| if r == c { 1.0 } else { 0.0 }));
    rows[0][0] = 1e-11;
    rows[0][5] = 1e300;
    rows[5][0] = 1e300;

    let ill_scaled = Matrix::<6>::try_from_rows(rows).unwrap();
    let failure = ill_scaled.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
    let expected = LaError::NonFinite {
        row: Some(5),
        col: 0,
    };
    assert_eq!(failure, expected);
}
// The multiplier 1e200 is finite, but the trailing update 1 - 1e200 * 1e200
// overflows; the failure is detected when (1, 1) is read as the next pivot.
#[test]
fn nonfinite_trailing_submatrix_overflow() {
    let rows = [[1.0, 1e200], [1e200, 1.0]];
    let overflowing = Matrix::<2>::try_from_rows(rows).unwrap();
    let failure = overflowing.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
    let expected = LaError::NonFinite {
        row: Some(1),
        col: 1,
    };
    assert_eq!(failure, expected);
}
// Trailing-update overflow at dimension 6 (fused `D > 5` branch): the coupling between
// rows 0 and 5 makes the (5, 5) entry overflow, which is reported at the last pivot.
#[test]
fn nonfinite_trailing_submatrix_overflow_fused_branch_6d() {
    // 6x6 identity with a huge symmetric coupling between rows 0 and 5.
    let mut rows: [[f64; 6]; 6] =
        core::array::from_fn(|r| core::array::from_fn(|c| if r == c { 1.0 } else { 0.0 }));
    rows[0][5] = 1e200;
    rows[5][0] = 1e200;

    let overflowing = Matrix::<6>::try_from_rows(rows).unwrap();
    let failure = overflowing.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
    let expected = LaError::NonFinite {
        row: Some(5),
        col: 5,
    };
    assert_eq!(failure, expected);
}
// The factorization succeeds with a multiplier of 1e153 at (1, 0); forward substitution
// then multiplies it by b[0] = 1e156, which overflows at component 1.
#[test]
fn nonfinite_solve_forward_substitution_overflow() {
    let rows = [
        [1.0, 1e153, 0.0],
        [1e153, 1e306 + 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ];
    let matrix = Matrix::<3>::try_from_rows(rows).unwrap();
    let factorization = matrix.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

    let rhs = Vector::<3>::new([1e156, 0.0, 0.0]);
    let failure = factorization.solve(rhs).unwrap_err();
    assert_eq!(failure, LaError::NonFinite { row: None, col: 1 });
}
// Dimension 3 uses the gather form of back substitution. The multiplier 2 at (2, 1)
// times x[2] = 1e308 overflows while computing component 1.
#[test]
fn nonfinite_solve_back_substitution_overflow() {
    let rows = [[1.0, 0.0, 0.0], [0.0, 1.0, 2.0], [0.0, 2.0, 5.0]];
    let matrix = Matrix::<3>::try_from_rows(rows).unwrap();
    let factorization = matrix.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

    let rhs = Vector::<3>::new([0.0, 0.0, 1e308]);
    let failure = factorization.solve(rhs).unwrap_err();
    assert_eq!(failure, LaError::NonFinite { row: None, col: 1 });
}
// Dimension 5 uses the scatter form of back substitution (`D > 4`). The multiplier 2
// at (4, 3) times x[4] = 1e308 overflows component 3, which is detected when that
// component is finalized.
#[test]
fn nonfinite_solve_back_substitution_overflow_scatter_branch_5d() {
    let rows = [
        [1.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0, 2.0],
        [0.0, 0.0, 0.0, 2.0, 5.0],
    ];
    let matrix = Matrix::<5>::try_from_rows(rows).unwrap();
    let factorization = matrix.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

    let rhs = Vector::<5>::new([0.0, 0.0, 0.0, 0.0, 1e308]);
    let failure = factorization.solve(rhs).unwrap_err();
    assert_eq!(failure, LaError::NonFinite { row: None, col: 3 });
}
// The diagonal stage divides 1e300 by the pivot 1e-11, which overflows at component 1.
#[test]
fn nonfinite_solve_diagonal_solve_overflow() {
    let rows = [[1.0, 0.0], [0.0, 1.0e-11]];
    let matrix = Matrix::<2>::try_from_rows(rows).unwrap();
    let factorization = matrix.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

    let rhs = Vector::<2>::new([0.0, 1.0e300]);
    let failure = factorization.solve(rhs).unwrap_err();
    assert_eq!(failure, LaError::NonFinite { row: None, col: 1 });
}
// Five pivots of 1e100: the running product reaches 1e400 (infinite in f64) at the
// fourth factor, so `det` reports index 3.
#[test]
fn det_rejects_product_overflow() {
    let rows = [
        [1.0e100, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0e100, 0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0e100, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0e100, 0.0],
        [0.0, 0.0, 0.0, 0.0, 1.0e100],
    ];
    let huge_diagonal = Matrix::<5>::try_from_rows(rows).unwrap();
    let factorization = huge_diagonal.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

    let expected = Err(LaError::NonFinite { row: None, col: 3 });
    assert_eq!(factorization.det(), expected);
}
// The constructor accepts an asymmetric matrix, but `ldlt` must refuse it with a typed
// error naming the mismatching cell (0, 1) and the dimension.
#[test]
fn asymmetric_input_returns_typed_error() {
    let rows = [[4.0, 2.0, 0.0], [-2.0, 5.0, 1.0], [0.0, 1.0, 3.0]];
    let asymmetric = Matrix::<3>::try_from_rows(rows).unwrap();

    let expected = Err(LaError::Asymmetric {
        row: 0,
        col: 1,
        dim: 3,
    });
    assert_eq!(asymmetric.ldlt(DEFAULT_SINGULAR_TOL), expected);
}
// Generates, for dimension `$d`, the test `solve_rhs_constructor_rejects_non_finite_<d>d`.
// It checks that `Vector::try_new` rejects a right-hand side whose last component is NaN
// and reports that component's index (with no row).
macro_rules! gen_solve_boundary_tests {
($d:literal) => {
paste! {
#[test]
fn [<solve_rhs_constructor_rejects_non_finite_ $d d>]() {
// All ones except a NaN in the last slot.
let mut rhs = [1.0; $d];
rhs[$d - 1] = f64::NAN;
assert_eq!(
Vector::<$d>::try_new(rhs),
Err(LaError::NonFinite {
row: None,
col: $d - 1,
})
);
}
}
};
}
// Instantiate the boundary tests for dimensions 2 through 5.
gen_solve_boundary_tests!(2);
gen_solve_boundary_tests!(3);
gen_solve_boundary_tests!(4);
gen_solve_boundary_tests!(5);
// Generates, for dimension `$d`, two tests proving that `det` and `solve` can be
// evaluated in a `const` context. The factor storage is built directly with the
// unchecked constructors instead of going through `factor_symmetric`, which is not a
// `const fn`.
macro_rules! gen_ldlt_const_eval_tests {
($d:literal) => {
paste! {
#[test]
fn [<ldlt_det_const_eval_ $d d>]() {
// Identity factors with the first diagonal entry set to 2, so the product is 2.
const DET: Result<f64, LaError> = {
let mut rows = [[0.0f64; $d]; $d];
// `while` loop: iterator-based `for` loops are not usable in const evaluation.
let mut i = 0;
while i < $d {
rows[i][i] = 1.0;
i += 1;
}
rows[0][0] = 2.0;
let factors = Matrix::<$d>::from_rows_unchecked(rows);
let ldlt = Ldlt::<$d> {
factors: LdltFactors::new_unchecked(factors),
};
ldlt.det()
};
assert_eq!(DET, Ok(2.0));
}
#[test]
fn [<ldlt_solve_const_eval_ $d d>]() {
// Identity factors: the solution must equal the right-hand side 1, 2, ..., `$d`.
#[allow(clippy::cast_precision_loss)]
const X: [f64; $d] = {
let ldlt = Ldlt::<$d> {
factors: LdltFactors::new_unchecked(Matrix::<$d>::identity()),
};
let mut b_arr = [0.0f64; $d];
let mut i = 0;
while i < $d {
b_arr[i] = i as f64 + 1.0;
i += 1;
}
let b = Vector::<$d>::new(b_arr);
// On error, fall back to all zeros so the assertions below fail.
match ldlt.solve(b) {
Ok(v) => v.into_array(),
Err(_) => [0.0f64; $d],
}
};
#[allow(clippy::cast_precision_loss)]
for i in 0..$d {
let expected = i as f64 + 1.0;
assert!((X[i] - expected).abs() <= 1e-12);
}
}
}
};
}
// Instantiate the const-evaluation tests for dimensions 2 through 5.
gen_ldlt_const_eval_tests!(2);
gen_ldlt_const_eval_tests!(3);
gen_ldlt_const_eval_tests!(4);
gen_ldlt_const_eval_tests!(5);
}