use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Magnitude at or below which a diagonal pivot is treated as zero (rank
/// deficient) by `back_substitution_lower_transpose_guarded_into`, which then
/// sets the corresponding solution component to `0.0` instead of dividing.
const RANK_DEFICIENT_PIVOT_FLOOR: f64 = 1.0e-14;
/// Pivot-acceptance policy used by `cholesky_factor_in_place`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CholeskyGuard {
    /// Reject a diagonal pivot when it is `<= 0.0` (so zero pivots are
    /// rejected too). Input entries are not screened for finiteness, and a
    /// NaN pivot is not rejected by this comparison.
    NonnegativePivot,
    /// Reject the whole input up front if any entry is non-finite, and reject
    /// any diagonal pivot that is not both finite and strictly positive.
    FiniteStrict,
}
/// Computes the lower-triangular Cholesky factor `L` of the square matrix `a`,
/// so that `a = L * L^T`.
///
/// Despite the name, `a` is only borrowed read-only. The factor is returned
/// as a freshly allocated matrix whose strict upper triangle is zero. The
/// factorization itself reads only the lower triangle (including the diagonal)
/// of `a`.
///
/// Returns `None` when:
/// - `a` is not square;
/// - `guard` is `FiniteStrict` and any entry of `a` (either triangle) is
///   non-finite;
/// - a diagonal pivot is rejected by `guard`. `NonnegativePivot` rejects
///   `<= 0.0`. `FiniteStrict` rejects anything that is not finite and
///   strictly positive.
pub(crate) fn cholesky_factor_in_place(
    a: ArrayView2<'_, f64>,
    guard: CholeskyGuard,
) -> Option<Array2<f64>> {
    let dim = a.nrows();
    if a.ncols() != dim {
        return None;
    }
    let screen_entries = matches!(guard, CholeskyGuard::FiniteStrict);
    if screen_entries && !a.iter().all(|entry| entry.is_finite()) {
        return None;
    }
    let mut factor = Array2::<f64>::zeros((dim, dim));
    for row in 0..dim {
        for col in 0..=row {
            // a[row, col] minus the dot product of the already-computed
            // prefixes of factor rows `row` and `col`, accumulated left to right.
            let residual = (0..col).fold(a[[row, col]], |acc, k| {
                acc - factor[[row, k]] * factor[[col, k]]
            });
            let entry = if row != col {
                // Divisor is an accepted (positive) pivot from an earlier row.
                residual / factor[[col, col]]
            } else {
                let rejected = match guard {
                    CholeskyGuard::NonnegativePivot => residual <= 0.0,
                    CholeskyGuard::FiniteStrict => !(residual.is_finite() && residual > 0.0),
                };
                if rejected {
                    return None;
                }
                residual.sqrt()
            };
            factor[[row, col]] = entry;
        }
    }
    Some(factor)
}
/// Solves `L * y = b` for `y` by forward substitution, where `l` is a square
/// lower-triangular matrix.
///
/// Only the lower triangle (including the diagonal) of `l` is read. No pivot
/// guarding is done, so a zero diagonal entry produces non-finite results.
///
/// # Panics
/// Panics if `l` is not square, or if `b.len()` differs from `l.nrows()`.
fn forward_kernel(l: ArrayView2<'_, f64>, b: ArrayView1<'_, f64>) -> Array1<f64> {
    let n = l.nrows();
    // Validate shapes up front. Otherwise an over-long `b` would be silently
    // truncated to `n` entries, and a short one would fail with an opaque
    // out-of-bounds index panic.
    assert_eq!(l.ncols(), n, "forward_kernel: factor must be square");
    assert_eq!(
        b.len(),
        n,
        "forward_kernel: rhs length must match factor dimension"
    );
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let mut sum = b[i];
        for k in 0..i {
            sum -= l[[i, k]] * y[k];
        }
        y[i] = sum / l[[i, i]];
    }
    y
}
/// Solves `L^T * x = y` for `x` by back substitution, where `l` is a square
/// lower-triangular matrix.
///
/// `L^T` is never materialized: entry `(i, k)` of `L^T` is read as `l[[k, i]]`.
/// No pivot guarding is done, so a zero diagonal entry produces non-finite
/// results.
///
/// # Panics
/// Panics if `l` is not square, or if `y.len()` differs from `l.nrows()`.
fn back_kernel(l: ArrayView2<'_, f64>, y: ArrayView1<'_, f64>) -> Array1<f64> {
    let n = l.nrows();
    // Validate shapes up front. Otherwise an over-long `y` would be silently
    // truncated to `n` entries, and a short one would fail with an opaque
    // out-of-bounds index panic.
    assert_eq!(l.ncols(), n, "back_kernel: factor must be square");
    assert_eq!(
        y.len(),
        n,
        "back_kernel: rhs length must match factor dimension"
    );
    let mut x = Array1::<f64>::zeros(n);
    for i in (0..n).rev() {
        let mut sum = y[i];
        for k in (i + 1)..n {
            sum -= l[[k, i]] * x[k];
        }
        x[i] = sum / l[[i, i]];
    }
    x
}
/// Solves `L^T * x = y` for `x` by back substitution, given the lower-triangular
/// factor `l`.
///
/// Both arguments accept anything convertible into ndarray views, such as
/// `&Array2<f64>` / `&Array1<f64>`. No pivot guarding is applied; see
/// `back_substitution_lower_transpose_guarded_into` for a variant that zeroes
/// components at near-zero pivots.
pub(crate) fn back_substitution_lower_transpose<'l, 'y>(
    l: impl Into<ArrayView2<'l, f64>>,
    y: impl Into<ArrayView1<'y, f64>>,
) -> Array1<f64> {
    let factor: ArrayView2<'l, f64> = l.into();
    let rhs: ArrayView1<'y, f64> = y.into();
    back_kernel(factor, rhs)
}
/// Solves `L^T * x = rhs` by back substitution and writes `x` into `out`,
/// tolerating rank deficiency.
///
/// Any component whose diagonal pivot `l[[i, i]]` has magnitude at or below
/// `RANK_DEFICIENT_PIVOT_FLOOR` (or is NaN) is set to `0.0` instead of being
/// divided. Every entry of `out` is overwritten. Its prior contents are never
/// read before being written.
///
/// # Panics
/// Panics if `l` is not `p x p` or `out.len() != p`, where `p = rhs.len()`.
pub(crate) fn back_substitution_lower_transpose_guarded_into(
    l: &Array2<f64>,
    rhs: &Array1<f64>,
    out: &mut Array1<f64>,
) {
    let dim = rhs.len();
    assert_eq!(l.nrows(), dim);
    assert_eq!(l.ncols(), dim);
    assert_eq!(out.len(), dim);
    for row in (0..dim).rev() {
        // Column `row` of L below the diagonal is row `row` of L^T to the right
        // of the diagonal; those components of `out` are already solved.
        let residual = ((row + 1)..dim).fold(rhs[row], |acc, k| acc - l[[k, row]] * out[k]);
        let pivot = l[[row, row]];
        let usable = pivot.abs() > RANK_DEFICIENT_PIVOT_FLOOR;
        out[row] = if usable { residual / pivot } else { 0.0 };
    }
}
/// Solves `A * x = b` given the lower Cholesky factor `l` of `A` (`A = L * L^T`).
///
/// First solves `L * y = b` by forward substitution, then `L^T * x = y` by back
/// substitution. Both arguments accept anything convertible into ndarray views.
pub(crate) fn cholesky_solve_vector<'l, 'b>(
    l: impl Into<ArrayView2<'l, f64>>,
    b: impl Into<ArrayView1<'b, f64>>,
) -> Array1<f64> {
    let factor: ArrayView2<'l, f64> = l.into();
    let rhs: ArrayView1<'b, f64> = b.into();
    back_kernel(factor, forward_kernel(factor, rhs).view())
}
/// Solves `A * X = B` given the lower Cholesky factor `l` of `A` (`A = L * L^T`).
///
/// Each column of `b` is solved independently, by forward then back
/// substitution. The result has shape `(l.nrows(), b.ncols())`.
pub(crate) fn cholesky_solve_matrix<'l, 'b>(
    l: impl Into<ArrayView2<'l, f64>>,
    b: impl Into<ArrayView2<'b, f64>>,
) -> Array2<f64> {
    let factor: ArrayView2<'l, f64> = l.into();
    let rhs: ArrayView2<'b, f64> = b.into();
    let (rows, cols) = (factor.nrows(), rhs.ncols());
    let mut solution = Array2::<f64>::zeros((rows, cols));
    for col in 0..cols {
        let intermediate = forward_kernel(factor, rhs.column(col));
        let solved = back_kernel(factor, intermediate.view());
        solution.column_mut(col).assign(&solved);
    }
    solution
}
/// Solves `L * Y = B` for `Y` by forward substitution with the lower-triangular
/// matrix `l`, one column of `b` at a time.
///
/// The result has shape `(l.nrows(), b.ncols())`.
pub(crate) fn forward_substitution_lower_matrix<'l, 'b>(
    l: impl Into<ArrayView2<'l, f64>>,
    b: impl Into<ArrayView2<'b, f64>>,
) -> Array2<f64> {
    let factor: ArrayView2<'l, f64> = l.into();
    let rhs: ArrayView2<'b, f64> = b.into();
    let (rows, cols) = (factor.nrows(), rhs.ncols());
    let mut solved = Array2::<f64>::zeros((rows, cols));
    for col in 0..cols {
        let column = forward_kernel(factor, rhs.column(col));
        solved.column_mut(col).assign(&column);
    }
    solved
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2, array};

    /// A well-conditioned 3x3 lower-triangular factor with small integer entries.
    fn fixture_factor() -> Array2<f64> {
        array![[2.0, 0.0, 0.0], [1.0, 2.0, 0.0], [1.0, 3.0, 3.0]]
    }

    /// Builds the SPD matrix `L * L^T` from a lower-triangular factor.
    fn reconstruct_spd(l: &Array2<f64>) -> Array2<f64> {
        l.dot(&l.t())
    }

    #[test]
    fn forward_then_back_solves_the_spd_system() {
        let l = fixture_factor();
        let a = reconstruct_spd(&l);
        let x_true = array![1.5, -2.0, 0.75];
        let b = a.dot(&x_true);
        let x = cholesky_solve_vector(&l, &b);
        for i in 0..3 {
            assert!((x[i] - x_true[i]).abs() < 1e-12, "x[{i}] = {}", x[i]);
        }
    }

    #[test]
    fn forward_substitution_solves_lower_system() {
        let l = fixture_factor();
        let y_true = array![3.0, -1.0, 2.0];
        let b = l.dot(&y_true);
        let y = forward_kernel(l.view(), b.view());
        for i in 0..3 {
            assert!((y[i] - y_true[i]).abs() < 1e-12, "y[{i}] = {}", y[i]);
        }
    }

    #[test]
    fn back_substitution_solves_upper_system() {
        let l = fixture_factor();
        let x_true = array![0.5, 4.0, -3.0];
        let rhs = l.t().dot(&x_true);
        let x = back_substitution_lower_transpose(&l, &rhs);
        for i in 0..3 {
            assert!((x[i] - x_true[i]).abs() < 1e-12, "x[{i}] = {}", x[i]);
        }
    }

    #[test]
    fn full_solve_is_forward_then_back_composed() {
        let l = fixture_factor();
        let b = array![7.0, -2.5, 11.0];
        let y = forward_kernel(l.view(), b.view());
        let expected = back_substitution_lower_transpose(&l, &y);
        let got = cholesky_solve_vector(&l, &b);
        assert_eq!(got, expected);
    }

    #[test]
    fn matrix_solve_matches_per_column_vector_solve() {
        let l = fixture_factor();
        let b = array![[1.0, 0.0, 2.0], [0.0, 1.0, -1.0], [3.0, -2.0, 0.5]];
        let x = cholesky_solve_matrix(&l, &b);
        for c in 0..b.ncols() {
            let col = cholesky_solve_vector(&l, b.column(c));
            for r in 0..3 {
                assert_eq!(x[[r, c]], col[r], "mismatch at ({r},{c})");
            }
        }
    }

    #[test]
    fn matrix_solve_recovers_inverse() {
        let l = fixture_factor();
        let a = reconstruct_spd(&l);
        let inv = cholesky_solve_matrix(&l, &Array2::<f64>::eye(3));
        let prod = a.dot(&inv);
        for i in 0..3 {
            for j in 0..3 {
                let expect = if i == j { 1.0 } else { 0.0 };
                assert!((prod[[i, j]] - expect).abs() < 1e-12, "prod[{i},{j}]");
            }
        }
    }

    #[test]
    fn forward_matrix_matches_per_column_forward_solve() {
        let l = fixture_factor();
        let b = array![[2.0, -1.0], [5.0, 0.0], [-3.0, 4.0]];
        let y = forward_substitution_lower_matrix(&l, &b);
        for c in 0..b.ncols() {
            let col = forward_kernel(l.view(), b.column(c));
            for r in 0..3 {
                assert_eq!(y[[r, c]], col[r], "mismatch at ({r},{c})");
            }
        }
        let recon = l.dot(&y);
        for i in 0..3 {
            for c in 0..b.ncols() {
                assert!((recon[[i, c]] - b[[i, c]]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn one_by_one_system() {
        let l = array![[2.0_f64]];
        let b = array![6.0_f64];
        let x = cholesky_solve_vector(&l, &b);
        assert!((x[0] - 1.5).abs() < 1e-15);
    }

    #[test]
    fn empty_system_returns_empty() {
        let l = Array2::<f64>::zeros((0, 0));
        let b = Array1::<f64>::zeros(0);
        let x = cholesky_solve_vector(&l, &b);
        assert_eq!(x.len(), 0);
    }

    #[test]
    fn factor_recovers_known_lower_factor() {
        let expected = fixture_factor();
        let a = reconstruct_spd(&expected);
        for &guard in &[CholeskyGuard::NonnegativePivot, CholeskyGuard::FiniteStrict] {
            let l = cholesky_factor_in_place(a.view(), guard).expect("SPD input must factor");
            for i in 0..3 {
                for j in 0..3 {
                    assert!(
                        (l[[i, j]] - expected[[i, j]]).abs() < 1e-12,
                        "l[{i},{j}] = {}",
                        l[[i, j]]
                    );
                }
            }
        }
    }

    #[test]
    fn factor_rejects_non_square_input() {
        let a = Array2::<f64>::zeros((2, 3));
        assert!(cholesky_factor_in_place(a.view(), CholeskyGuard::NonnegativePivot).is_none());
        assert!(cholesky_factor_in_place(a.view(), CholeskyGuard::FiniteStrict).is_none());
    }

    #[test]
    fn factor_rejects_indefinite_matrix() {
        // Second pivot is 1 - 2*2 = -3.
        let a = array![[1.0, 2.0], [2.0, 1.0]];
        assert!(cholesky_factor_in_place(a.view(), CholeskyGuard::NonnegativePivot).is_none());
        assert!(cholesky_factor_in_place(a.view(), CholeskyGuard::FiniteStrict).is_none());
    }

    #[test]
    fn zero_pivot_is_rejected_by_both_guards() {
        let a = array![[0.0_f64]];
        assert!(cholesky_factor_in_place(a.view(), CholeskyGuard::NonnegativePivot).is_none());
        assert!(cholesky_factor_in_place(a.view(), CholeskyGuard::FiniteStrict).is_none());
    }

    #[test]
    fn finite_strict_rejects_non_finite_entries() {
        let a = array![[1.0, 0.0], [f64::NAN, 1.0]];
        assert!(cholesky_factor_in_place(a.view(), CholeskyGuard::FiniteStrict).is_none());
    }

    #[test]
    fn factor_of_empty_matrix_is_empty() {
        let a = Array2::<f64>::zeros((0, 0));
        let l = cholesky_factor_in_place(a.view(), CholeskyGuard::FiniteStrict)
            .expect("empty matrix is trivially factorable");
        assert_eq!(l.dim(), (0, 0));
    }

    #[test]
    fn guarded_back_substitution_matches_unguarded_on_full_rank() {
        let l = fixture_factor();
        let rhs = array![1.0, -2.0, 4.0];
        let mut out = Array1::<f64>::zeros(3);
        back_substitution_lower_transpose_guarded_into(&l, &rhs, &mut out);
        let expected = back_substitution_lower_transpose(&l, &rhs);
        assert_eq!(out, expected);
    }

    #[test]
    fn guarded_back_substitution_zeroes_rank_deficient_component() {
        // l[[1, 1]] == 0.0 is below the floor, so x[1] is forced to 0 and
        // x[0] = 4 / 2 = 2. `out` starts with junk to show it is overwritten.
        let l = array![[2.0, 0.0], [1.0, 0.0]];
        let rhs = array![4.0, 3.0];
        let mut out = Array1::<f64>::from_elem(2, 99.0);
        back_substitution_lower_transpose_guarded_into(&l, &rhs, &mut out);
        assert_eq!(out, array![2.0, 0.0]);
    }
}