use crate::linalg::LinalgError;
use crate::matrix::vector::Vector;
use crate::traits::{FloatScalar, LinalgScalar, MatrixMut, MatrixRef};
use crate::Matrix;
/// Computes the Cholesky factor of a square Hermitian positive-definite matrix in place.
///
/// On success the lower triangle of `a` (diagonal included) holds `L` with `A = L Lᴴ`.
/// Every diagonal entry of `L` is real and positive. The strict upper triangle is not
/// guaranteed to be zero.
///
/// # Panics
/// Panics if `a` is not square.
///
/// # Errors
/// Returns [`LinalgError::NotPositiveDefinite`] if a pivot is not strictly positive, where
/// zero, negative and NaN pivots are all rejected. On error, the columns processed so far
/// have already been overwritten.
#[inline]
pub fn cholesky_in_place<T: LinalgScalar>(a: &mut impl MatrixMut<T>) -> Result<(), LinalgError> {
    let n = a.nrows();
    assert_eq!(n, a.ncols(), "Cholesky decomposition requires a square matrix");
    let zero = <T::Real as num_traits::Zero>::zero();
    for j in 0..n {
        // Left-looking step: subtract the contribution of every finished column k < j
        // from column j (rows j..). The SIMD helper is assumed to compute
        // col_j -= ljk_conj * col_k, matching the scalar loop in `cholesky_direct`.
        for k in 0..j {
            let ljk_conj = (*a.get(j, k)).conj();
            let (col_j, col_k) = super::split_two_col_slices(a, j, k, j);
            crate::simd::axpy_neg_dispatch(col_j, ljk_conj, col_k);
        }
        let diag = *a.get(j, j);
        let pivot = diag.re();
        // `partial_cmp` also rejects NaN pivots, which a plain `<= 0` test lets through
        // (LAPACK's xPOTF2 performs the same NaN check).
        if pivot.partial_cmp(&zero) != Some(core::cmp::Ordering::Greater) {
            return Err(LinalgError::NotPositiveDefinite);
        }
        let ljj_t = T::from_real(pivot.lsqrt());
        *a.get_mut(j, j) = ljj_t;
        let inv_ljj = T::one() / ljj_t;
        // Scale the sub-diagonal part of column j by 1 / L[j, j].
        let col = a.col_as_mut_slice(j, j + 1);
        crate::simd::scale_in_place_dispatch(col, inv_ljj);
    }
    Ok(())
}
/// Solves `L x = b` by forward substitution, where `L` is lower triangular.
///
/// Only entries on or below the diagonal of `l` are read. Indexing panics if `b` or `x`
/// has fewer than `l.nrows()` elements.
#[inline]
pub fn forward_substitute<T: LinalgScalar>(
    l: &impl MatrixRef<T>,
    b: &[T],
    x: &mut [T],
) {
    for row in 0..l.nrows() {
        // Subtract the already-solved unknowns x[0..row], left to right.
        let residual = (0..row).fold(b[row], |acc, col| acc - *l.get(row, col) * x[col]);
        x[row] = residual / *l.get(row, row);
    }
}
/// Solves `Lᴴ x = b` by back substitution, reading only the lower-triangular `L`.
///
/// The conjugate transpose is never formed: entry `(r, c)` of `Lᴴ` is read as
/// `conj(L[c, r])`. Indexing panics if `b` or `x` has fewer than `l.nrows()` elements.
#[inline]
pub fn back_substitute_lt<T: LinalgScalar>(
    l: &impl MatrixRef<T>,
    b: &[T],
    x: &mut [T],
) {
    let dim = l.nrows();
    for row in (0..dim).rev() {
        // Subtract the already-solved unknowns x[row+1..dim], in ascending order.
        let residual = ((row + 1)..dim)
            .fold(b[row], |acc, k| acc - (*l.get(k, row)).conj() * x[k]);
        x[row] = residual / (*l.get(row, row)).conj();
    }
}
/// Updates a Cholesky factor in place so that `L Lᵀ` becomes `L Lᵀ + v vᵀ`.
///
/// `l` must be square and lower triangular with a positive diagonal. Only its lower
/// triangle is read or written. `v` must have `l.nrows()` entries and is overwritten as
/// scratch space.
///
/// # Errors
/// Forwards the result of the shared rank-one kernel.
pub fn cholesky_rank1_update<T: FloatScalar>(
    l: &mut impl MatrixMut<T>,
    v: &mut [T],
) -> Result<(), LinalgError> {
    let sign = T::one();
    cholesky_rank1_impl(l, v, sign)
}
/// Downdates a Cholesky factor in place so that `L Lᵀ` becomes `L Lᵀ - v vᵀ`.
///
/// `l` must be square and lower triangular with a positive diagonal. Only its lower
/// triangle is read or written. `v` must have `l.nrows()` entries and is overwritten as
/// scratch space.
///
/// # Errors
/// Returns [`LinalgError::NotPositiveDefinite`] when the downdated matrix would not be
/// positive definite. By then, `l` and `v` have already been partially modified.
pub fn cholesky_rank1_downdate<T: FloatScalar>(
    l: &mut impl MatrixMut<T>,
    v: &mut [T],
) -> Result<(), LinalgError> {
    let sign = T::one().neg();
    cholesky_rank1_impl(l, v, sign)
}
/// Shared kernel behind [`cholesky_rank1_update`] and [`cholesky_rank1_downdate`].
///
/// Given a lower-triangular factor `L` with a positive diagonal stored in `l`, this
/// overwrites `l` with the factor of `L Lᵀ + sign · v vᵀ`, sweeping one column at a time.
/// Only entries on or below the diagonal are read or written. `v` is consumed as scratch
/// space.
///
/// `sign` must be `+1` (update) or `-1` (downdate). Any positive value selects the update
/// branch; anything else selects the downdate branch.
///
/// # Errors
/// Returns [`LinalgError::NotPositiveDefinite`] when a downdate would make a diagonal
/// entry zero, negative or NaN. By that point, `l` and `v` have already been partially
/// overwritten.
fn cholesky_rank1_impl<T: FloatScalar>(
    l: &mut impl MatrixMut<T>,
    v: &mut [T],
    sign: T,
) -> Result<(), LinalgError> {
    let n = l.nrows();
    debug_assert_eq!(n, l.ncols());
    debug_assert_eq!(v.len(), n);
    let zero = T::zero();
    let is_update = sign > zero;
    for j in 0..n {
        let ljj = *l.get(j, j);
        let vj = v[j];
        let r = if is_update {
            // `hypot` computes sqrt(ljj² + vj²) without intermediate overflow/underflow.
            ljj.hypot(vj)
        } else {
            // Factored form of ljj² - vj². This avoids the catastrophic cancellation of
            // subtracting two rounded squares when |vj| ≈ ljj, which is exactly the
            // regime where a downdate approaches indefiniteness.
            let arg = (ljj - vj) * (ljj + vj);
            // Rejects zero, negative and NaN alike; a plain `<= 0` lets NaN through.
            if arg.partial_cmp(&zero) != Some(core::cmp::Ordering::Greater) {
                return Err(LinalgError::NotPositiveDefinite);
            }
            arg.sqrt()
        };
        // Rotation coefficients for column j.
        let c = r / ljj;
        let s = vj / ljj;
        *l.get_mut(j, j) = r;
        for i in (j + 1)..n {
            let new_lij = (*l.get(i, j) + sign * s * v[i]) / c;
            *l.get_mut(i, j) = new_lij;
            v[i] = c * v[i] - s * new_lij;
        }
    }
    Ok(())
}
/// Scalar Cholesky kernel working directly on the matrix storage. It is used for small
/// fixed sizes.
///
/// `l.data` is indexed `[column][row]`. On success the lower triangle (diagonal included)
/// holds `L` with `A = L Lᴴ`. The strict upper triangle is never written.
///
/// # Errors
/// Returns [`LinalgError::NotPositiveDefinite`] if a pivot is zero, negative or NaN. On
/// error, `l` is left partially overwritten.
#[inline(always)]
fn cholesky_direct<T: LinalgScalar, const N: usize>(
    l: &mut Matrix<T, N, N>,
) -> Result<(), LinalgError> {
    let zero_r = <T::Real as num_traits::Zero>::zero();
    for j in 0..N {
        // Left-looking update: column j (rows j..) -= conj(L[j, k]) * column k.
        for k in 0..j {
            let ljk_conj = l.data[k][j].conj();
            for i in j..N {
                l.data[j][i] = l.data[j][i] - ljk_conj * l.data[k][i];
            }
        }
        let pivot = l.data[j][j].re();
        // `partial_cmp` also rejects NaN pivots, which a plain `<= 0` test lets through.
        if pivot.partial_cmp(&zero_r) != Some(core::cmp::Ordering::Greater) {
            return Err(LinalgError::NotPositiveDefinite);
        }
        let ljj_t = T::from_real(pivot.lsqrt());
        l.data[j][j] = ljj_t;
        let inv_ljj = T::one() / ljj_t;
        for i in (j + 1)..N {
            l.data[j][i] = l.data[j][i] * inv_ljj;
        }
    }
    Ok(())
}
/// Cholesky factorization `A = L Lᴴ` of a Hermitian positive-definite `N × N` matrix.
///
/// Only the lower triangle (diagonal included) of the stored matrix is meaningful.
/// Because the factor is a plain `Matrix`, the decomposition is cheap to copy. This lets
/// callers snapshot it before an in-place rank-one update.
#[derive(Debug, Clone, Copy)]
pub struct CholeskyDecomposition<T, const N: usize> {
    l: Matrix<T, N, N>,
}
impl<T: LinalgScalar, const N: usize> CholeskyDecomposition<T, N> {
    /// Factors the Hermitian positive-definite matrix `a` as `a = L Lᴴ`.
    ///
    /// For `N <= 6`, a scalar kernel operates directly on the matrix storage. Larger sizes
    /// go through [`cholesky_in_place`].
    ///
    /// # Errors
    /// Returns [`LinalgError::NotPositiveDefinite`] if a diagonal pivot is not positive
    /// during factorization.
    #[inline]
    pub fn new(a: &Matrix<T, N, N>) -> Result<Self, LinalgError> {
        let mut l = *a;
        if N <= 6 {
            cholesky_direct(&mut l)?;
        } else {
            cholesky_in_place(&mut l)?;
        }
        Ok(Self { l })
    }

    /// Returns the stored factor.
    ///
    /// Only the lower triangle (diagonal included) holds `L`. The strict upper triangle is
    /// not guaranteed to be zero and may still contain entries of the input. Use
    /// [`l_full`](Self::l_full) to get a clean `L`.
    pub fn l(&self) -> &Matrix<T, N, N> {
        &self.l
    }

    /// Returns `L` with its strict upper triangle set to zero.
    pub fn l_full(&self) -> Matrix<T, N, N> {
        let mut out = self.l;
        // `data` is indexed `[column][row]`, so `row < col` selects the strict upper triangle.
        for col in 0..N {
            for row in 0..col {
                out.data[col][row] = T::zero();
            }
        }
        out
    }

    /// Solves `A X = B` column by column, using `L y = b` followed by `Lᴴ x = y`.
    pub fn solve_matrix<const P: usize>(&self, b: &Matrix<T, N, P>) -> Matrix<T, N, P> {
        let mut x = Matrix::<T, N, P>::zeros();
        for col in 0..P {
            let b_flat: [T; N] = core::array::from_fn(|i| b[(i, col)]);
            let mut y = [T::zero(); N];
            let mut x_col = [T::zero(); N];
            forward_substitute(&self.l, &b_flat, &mut y);
            back_substitute_lt(&self.l, &y, &mut x_col);
            for i in 0..N {
                x[(i, col)] = x_col[i];
            }
        }
        x
    }

    /// Solves `A x = b` using `L y = b` followed by `Lᴴ x = y`.
    pub fn solve(&self, b: &Vector<T, N>) -> Vector<T, N> {
        let b_flat: [T; N] = core::array::from_fn(|i| b[i]);
        let mut y = [T::zero(); N];
        let mut x = [T::zero(); N];
        forward_substitute(&self.l, &b_flat, &mut y);
        back_substitute_lt(&self.l, &y, &mut x);
        Vector::from_array(x)
    }

    /// Returns `det(A) = (∏ L_ii)²`.
    pub fn det(&self) -> T {
        let mut prod = T::one();
        for i in 0..N {
            prod = prod * self.l[(i, i)];
        }
        prod * prod
    }

    /// Returns `ln det(A) = 2 Σ ln L_ii`.
    ///
    /// This is less prone to overflow or underflow than `det().ln()` for large `N`.
    pub fn ln_det(&self) -> T {
        let two = T::one() + T::one();
        let mut sum = T::zero();
        for i in 0..N {
            sum = sum + self.l[(i, i)].lln();
        }
        sum * two
    }

    /// Updates the factorization in place so it factors `A + v vᵀ`.
    ///
    /// `v` is used as scratch space and is overwritten. The stored factor is replaced only
    /// after the whole update has succeeded.
    ///
    /// # Errors
    /// Propagates any error from [`cholesky_rank1_update`]. On error, `self` is unchanged.
    pub fn rank1_update(&mut self, v: &mut Vector<T, N>) -> Result<(), LinalgError>
    where
        T: FloatScalar,
    {
        // Work on a copy (O(N²), the same order as the update itself) so that a failure
        // cannot leave `self` half-updated.
        let mut l = self.l;
        cholesky_rank1_update(&mut l, v.as_mut_slice())?;
        self.l = l;
        Ok(())
    }

    /// Downdates the factorization in place so it factors `A - v vᵀ`.
    ///
    /// `v` is used as scratch space and is overwritten. The stored factor is replaced only
    /// after the whole downdate has succeeded.
    ///
    /// # Errors
    /// Returns [`LinalgError::NotPositiveDefinite`] if `A - v vᵀ` is not positive definite.
    /// In that case `self` is left unchanged instead of holding a partially downdated,
    /// inconsistent factor.
    pub fn rank1_downdate(&mut self, v: &mut Vector<T, N>) -> Result<(), LinalgError>
    where
        T: FloatScalar,
    {
        // The kernel writes column by column before it can detect failure, so only commit
        // the result on success.
        let mut l = self.l;
        cholesky_rank1_downdate(&mut l, v.as_mut_slice())?;
        self.l = l;
        Ok(())
    }

    /// Computes `A⁻¹` one column at a time by solving `A x = e_col`.
    pub fn inverse(&self) -> Matrix<T, N, N> {
        let mut inv = Matrix::<T, N, N>::zeros();
        let mut e = [T::zero(); N];
        let mut y = [T::zero(); N];
        let mut x = [T::zero(); N];
        for col in 0..N {
            // Turn the previous unit vector into the next one without re-zeroing `e`.
            if col > 0 {
                e[col - 1] = T::zero();
            }
            e[col] = T::one();
            forward_substitute(&self.l, &e, &mut y);
            back_substitute_lt(&self.l, &y, &mut x);
            for row in 0..N {
                inv.data[col][row] = x[row];
            }
        }
        inv
    }
}
impl<T: LinalgScalar, const N: usize> Matrix<T, N, N> {
    /// Computes the Cholesky factorization `self = L Lᴴ`.
    ///
    /// This is a shorthand for [`CholeskyDecomposition::new`].
    ///
    /// # Errors
    /// Returns [`LinalgError::NotPositiveDefinite`] when `self` is not positive definite.
    #[inline]
    pub fn cholesky(&self) -> Result<CholeskyDecomposition<T, N>, LinalgError> {
        CholeskyDecomposition::<T, N>::new(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn spd_2x2() -> Matrix<f64, 2, 2> {
        Matrix::new([[4.0, 2.0], [2.0, 3.0]])
    }

    fn spd_3x3() -> Matrix<f64, 3, 3> {
        Matrix::new([
            [4.0, 2.0, 1.0],
            [2.0, 10.0, 3.5],
            [1.0, 3.5, 4.5],
        ])
    }

    /// Symmetric, strictly diagonally dominant tridiagonal matrix, hence SPD.
    /// Seven rows is large enough for `CholeskyDecomposition::new` to take the
    /// `cholesky_in_place` path instead of `cholesky_direct`.
    fn spd_7x7() -> Matrix<f64, 7, 7> {
        let mut a = Matrix::<f64, 7, 7>::zeros();
        for i in 0..7 {
            a[(i, i)] = 4.0;
            if i + 1 < 7 {
                a[(i, i + 1)] = 1.0;
                a[(i + 1, i)] = 1.0;
            }
        }
        a
    }

    #[test]
    fn cholesky_2x2() {
        let a = spd_2x2();
        let chol = a.cholesky().unwrap();
        let l = chol.l_full();
        let reconstructed = l * l.transpose();
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (reconstructed[(i, j)] - a[(i, j)]).abs() < 1e-12,
                    "mismatch at ({},{})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn cholesky_3x3() {
        let a = spd_3x3();
        let chol = a.cholesky().unwrap();
        let l = chol.l_full();
        let reconstructed = l * l.transpose();
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (reconstructed[(i, j)] - a[(i, j)]).abs() < 1e-12,
                    "mismatch at ({},{})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn cholesky_7x7_in_place_path() {
        let a = spd_7x7();
        let chol = a.cholesky().unwrap();
        let l = chol.l_full();
        let reconstructed = l * l.transpose();
        for i in 0..7 {
            for j in 0..7 {
                assert!(
                    (reconstructed[(i, j)] - a[(i, j)]).abs() < 1e-12,
                    "mismatch at ({},{})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn cholesky_solve() {
        let a = spd_2x2();
        let b = Vector::from_array([8.0, 7.0]);
        let chol = a.cholesky().unwrap();
        let x = chol.solve(&b);
        for i in 0..2 {
            let mut sum = 0.0;
            for j in 0..2 {
                sum += a[(i, j)] * x[j];
            }
            assert!((sum - b[i]).abs() < 1e-12, "residual[{}] = {}", i, sum - b[i]);
        }
    }

    #[test]
    fn cholesky_solve_3x3() {
        let a = spd_3x3();
        let b = Vector::from_array([1.0, 2.0, 3.0]);
        let chol = a.cholesky().unwrap();
        let x = chol.solve(&b);
        for i in 0..3 {
            let mut sum = 0.0;
            for j in 0..3 {
                sum += a[(i, j)] * x[j];
            }
            assert!((sum - b[i]).abs() < 1e-10, "residual[{}] = {}", i, sum - b[i]);
        }
    }

    #[test]
    fn cholesky_solve_7x7() {
        let a = spd_7x7();
        let b = Vector::from_array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let chol = a.cholesky().unwrap();
        let x = chol.solve(&b);
        for i in 0..7 {
            let mut sum = 0.0;
            for j in 0..7 {
                sum += a[(i, j)] * x[j];
            }
            assert!((sum - b[i]).abs() < 1e-10, "residual[{}] = {}", i, sum - b[i]);
        }
    }

    #[test]
    fn cholesky_solve_matrix() {
        let a = spd_2x2();
        let b: Matrix<f64, 2, 2> = Matrix::new([[8.0, 1.0], [7.0, -2.0]]);
        let chol = a.cholesky().unwrap();
        let x = chol.solve_matrix(&b);
        let residual = a * x;
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (residual[(i, j)] - b[(i, j)]).abs() < 1e-12,
                    "mismatch at ({},{})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn cholesky_det() {
        let a = spd_2x2();
        let chol = a.cholesky().unwrap();
        let det_chol = chol.det();
        let det_lu = a.det();
        assert!((det_chol - det_lu).abs() < 1e-12);
    }

    #[test]
    fn cholesky_ln_det() {
        let a = spd_2x2();
        let chol = a.cholesky().unwrap();
        let ln_det = chol.ln_det();
        let expected = chol.det().ln();
        assert!((ln_det - expected).abs() < 1e-12);
    }

    #[test]
    fn cholesky_inverse() {
        let a = spd_3x3();
        let chol = a.cholesky().unwrap();
        let a_inv = chol.inverse();
        let id = a * a_inv;
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (id[(i, j)] - expected).abs() < 1e-10,
                    "id[({},{})] = {}, expected {}",
                    i,
                    j,
                    id[(i, j)],
                    expected
                );
            }
        }
    }

    #[test]
    fn cholesky_not_positive_definite() {
        let a = Matrix::new([[1.0_f64, 5.0], [5.0, 1.0]]);
        assert_eq!(a.cholesky().unwrap_err(), LinalgError::NotPositiveDefinite);
    }

    #[test]
    fn cholesky_in_place_generic() {
        let a = spd_2x2();
        let mut l = a;
        cholesky_in_place(&mut l).unwrap();
        // Only the lower triangle holds L; clear the strict upper triangle before
        // reconstructing.
        for i in 0..2 {
            for j in (i + 1)..2 {
                l[(i, j)] = 0.0;
            }
        }
        let reconstructed = l * l.transpose();
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (reconstructed[(i, j)] - a[(i, j)]).abs() < 1e-12,
                    "mismatch at ({},{})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn cholesky_identity() {
        let id: Matrix<f64, 3, 3> = Matrix::eye();
        let chol = id.cholesky().unwrap();
        let l = chol.l_full();
        assert_eq!(l, id);
    }

    #[test]
    fn rank1_update_2x2() {
        let a = spd_2x2();
        let mut l = a.cholesky().unwrap().l_full();
        let v_orig = [1.0_f64, 0.5];
        let mut v = v_orig;
        cholesky_rank1_update(&mut l, &mut v).unwrap();
        let p_new = l * l.transpose();
        assert!((p_new[(0, 0)] - 5.0).abs() < 1e-12);
        assert!((p_new[(0, 1)] - 2.5).abs() < 1e-12);
        assert!((p_new[(1, 0)] - 2.5).abs() < 1e-12);
        assert!((p_new[(1, 1)] - 3.25).abs() < 1e-12);
    }

    #[test]
    fn rank1_downdate_roundtrip() {
        let a = spd_2x2();
        let v_orig = [0.5_f64, 0.3];
        let v_col = Matrix::new([[0.5], [0.3_f64]]);
        let a_aug = a + v_col * v_col.transpose();
        let mut l = a_aug.cholesky().unwrap().l_full();
        let mut v = v_orig;
        cholesky_rank1_downdate(&mut l, &mut v).unwrap();
        let recovered = l * l.transpose();
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (recovered[(i, j)] - a[(i, j)]).abs() < 1e-10,
                    "mismatch at ({},{}): {} vs {}",
                    i,
                    j,
                    recovered[(i, j)],
                    a[(i, j)]
                );
            }
        }
    }

    #[test]
    fn rank1_downdate_fails_non_pd() {
        let mut l = Matrix::<f64, 2, 2>::eye();
        let mut v = [1.5_f64, 0.0];
        let result = cholesky_rank1_downdate(&mut l, &mut v);
        assert_eq!(result.unwrap_err(), LinalgError::NotPositiveDefinite);
    }

    #[test]
    fn rank1_update_3x3() {
        let a = spd_3x3();
        let mut l = a.cholesky().unwrap().l_full();
        let v_orig = [0.3_f64, 0.7, 0.1];
        let mut v = v_orig;
        cholesky_rank1_update(&mut l, &mut v).unwrap();
        let p_new = l * l.transpose();
        let v_col = Matrix::new([[0.3], [0.7], [0.1_f64]]);
        let p_expected = a + v_col * v_col.transpose();
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (p_new[(i, j)] - p_expected[(i, j)]).abs() < 1e-10,
                    "mismatch at ({},{})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn rank1_update_via_decomposition() {
        let a = spd_2x2();
        let mut chol = a.cholesky().unwrap();
        let mut v = Vector::from_array([1.0_f64, 0.5]);
        chol.rank1_update(&mut v).unwrap();
        let l = chol.l_full();
        let result = l * l.transpose();
        // A + v vᵀ for v = [1, 0.5].
        let expected: Matrix<f64, 2, 2> = Matrix::new([[5.0, 2.5], [2.5, 3.25]]);
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (result[(i, j)] - expected[(i, j)]).abs() < 1e-12,
                    "mismatch at ({},{})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn rank1_downdate_via_decomposition() {
        let a = spd_2x2();
        let v_col = Matrix::new([[0.5], [0.3_f64]]);
        let a_aug = a + v_col * v_col.transpose();
        let mut chol = a_aug.cholesky().unwrap();
        let mut v = Vector::from_array([0.5_f64, 0.3]);
        chol.rank1_downdate(&mut v).unwrap();
        let l = chol.l_full();
        let recovered = l * l.transpose();
        for i in 0..2 {
            for j in 0..2 {
                assert!((recovered[(i, j)] - a[(i, j)]).abs() < 1e-10);
            }
        }
    }
}