use num::Float;
/// Reusable Cholesky factorizer for dense `n × n` matrices.
///
/// The lower-triangular factor `L` is kept in one row-major buffer of
/// `n * n` entries (strictly-upper triangle zero). The buffer is allocated
/// once and overwritten by every factorization.
#[derive(Debug, Clone)]
pub struct CholeskyFactorizer<T> {
    /// Dimension of the square matrices handled by this factorizer.
    n: usize,
    /// Row-major storage for `L`; entry `(i, j)` lives at index `i * n + j`.
    cholesky_factor: Vec<T>,
    /// `true` only while the most recent factorization attempt has succeeded.
    is_factorized: bool,
}
/// Errors reported by [`CholeskyFactorizer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CholeskyError {
    /// A non-positive pivot was met during factorization, so the matrix is not
    /// (numerically) positive definite.
    NotPositiveDefinite,
    /// An input slice does not have the length implied by the factorizer's dimension.
    DimensionMismatch,
    /// A solve was requested while no successful factorization is available.
    NotFactorized,
}

impl std::fmt::Display for CholeskyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            CholeskyError::NotPositiveDefinite => "matrix is not positive definite",
            CholeskyError::DimensionMismatch => "dimension mismatch",
            CholeskyError::NotFactorized => "matrix has not been factorized",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` / error-handling crates.
impl std::error::Error for CholeskyError {}
impl<T: Float> CholeskyFactorizer<T> {
    /// Creates a factorizer for `n × n` matrices.
    ///
    /// The factor buffer (`n * n` zeros) is allocated here and reused by every
    /// call to [`factorize`](Self::factorize). Until a factorization succeeds,
    /// [`solve`](Self::solve) returns [`CholeskyError::NotFactorized`].
    pub fn new(n: usize) -> Self {
        Self {
            n,
            cholesky_factor: vec![T::zero(); n * n],
            is_factorized: false,
        }
    }

    /// Computes the lower-triangular Cholesky factor `L` such that `A = L Lᵀ`.
    ///
    /// `a` is an `n × n` matrix in row-major order (element `(i, j)` at
    /// `a[i * n + j]`). Only its lower triangle, including the diagonal, is
    /// read; symmetry is assumed, not checked.
    ///
    /// # Errors
    ///
    /// - [`CholeskyError::DimensionMismatch`] if `a.len() != n * n`.
    /// - [`CholeskyError::NotPositiveDefinite`] if some pivot is not strictly
    ///   positive, including a NaN pivot (e.g. when `a` contains NaN).
    ///
    /// On any error the factorizer is marked as not factorized, so
    /// [`solve`](Self::solve) fails until a later call succeeds.
    pub fn factorize(&mut self, a: &[T]) -> Result<(), CholeskyError> {
        // Invalidate first so every early return leaves a consistent state.
        self.is_factorized = false;
        if a.len() != self.n * self.n {
            return Err(CholeskyError::DimensionMismatch);
        }
        let n = self.n;
        let l = &mut self.cholesky_factor;
        // Only entries with j <= i are written below; clearing the buffer keeps the
        // strictly-upper triangle zero and discards values from a previous run.
        l.fill(T::zero());
        for i in 0..n {
            let row_i = i * n;
            for j in 0..=i {
                let row_j = j * n;
                // sum = a_ij - Σ_{k<j} L_ik · L_jk, subtracted term by term in k order.
                let sum = l[row_i..row_i + j]
                    .iter()
                    .zip(&l[row_j..row_j + j])
                    .fold(a[row_i + j], |acc, (&lik, &ljk)| acc - lik * ljk);
                if i == j {
                    // `sum <= 0` is false for NaN, so NaN must be rejected explicitly;
                    // otherwise `sqrt(NaN)` would silently poison the factor.
                    if sum.is_nan() || sum <= T::zero() {
                        return Err(CholeskyError::NotPositiveDefinite);
                    }
                    l[row_i + i] = sum.sqrt();
                } else {
                    l[row_i + j] = sum / l[row_j + j];
                }
            }
        }
        self.is_factorized = true;
        Ok(())
    }

    /// Returns the dimension `n` of the matrices handled by this factorizer.
    #[inline]
    #[must_use]
    pub fn dimension(&self) -> usize {
        self.n
    }

    /// Returns the Cholesky factor `L` as a row-major `n * n` slice.
    ///
    /// The strictly-upper triangle is zero. The contents are only meaningful
    /// after a successful [`factorize`](Self::factorize).
    #[inline]
    #[must_use]
    pub fn cholesky_factor(&self) -> &[T] {
        &self.cholesky_factor
    }

    /// Solves `A x = b` using the stored factor (`L y = b`, then `Lᵀ x = y`).
    ///
    /// # Errors
    ///
    /// - [`CholeskyError::NotFactorized`] if no successful factorization is
    ///   available. This is checked before the length of `b`.
    /// - [`CholeskyError::DimensionMismatch`] if `b.len() != n`.
    pub fn solve(&self, b: &[T]) -> Result<Vec<T>, CholeskyError> {
        if !self.is_factorized {
            return Err(CholeskyError::NotFactorized);
        }
        if b.len() != self.n {
            return Err(CholeskyError::DimensionMismatch);
        }
        let n = self.n;
        let l = &self.cholesky_factor;

        // Forward substitution: L y = b.
        let mut z = vec![T::zero(); n];
        for i in 0..n {
            let row_i = &l[i * n..(i + 1) * n];
            let sum = row_i[..i]
                .iter()
                .zip(&z[..i])
                .fold(b[i], |acc, (&lij, &yj)| acc - lij * yj);
            z[i] = sum / row_i[i];
        }

        // Backward substitution: Lᵀ x = y, performed in place on `z`. When x[i] is
        // computed, z[i] still holds y[i] and z[i+1..] already holds x[i+1..].
        // Row i of Lᵀ is column i of L, read as `row_j[i]` for j > i.
        for i in (0..n).rev() {
            let sum = l
                .chunks_exact(n)
                .skip(i + 1)
                .zip(&z[i + 1..])
                .fold(z[i], |acc, (row_j, &xj)| acc - row_j[i] * xj);
            z[i] = sum / l[i * n + i];
        }
        Ok(z)
    }
}
#[cfg(test)]
mod tests {
    use crate::*;

    /// Symmetric positive-definite 3×3 matrix (row-major) whose Cholesky factor
    /// has integer entries: L = [[2, 0, 0], [6, 1, 0], [-8, 5, 3]].
    const SPD_3X3_F64: [f64; 9] = [4.0, 12.0, -16.0, 12.0, 37.0, -43.0, -16.0, -43.0, 98.0];

    #[test]
    fn t_cholesky_basic() {
        let mut factorizer = CholeskyFactorizer::new(3);
        factorizer
            .factorize(&SPD_3X3_F64)
            .expect("test matrix is positive definite");
        assert_eq!(3, factorizer.dimension(), "wrong dimension");
        let expected_l = [2.0, 0.0, 0.0, 6.0, 1.0, 0.0, -8.0, 5.0, 3.0];
        unit_test_utils::nearly_equal_array(
            &expected_l,
            factorizer.cholesky_factor(),
            1e-10,
            1e-12,
        );
    }

    #[test]
    fn t_cholesky_solve_linear_system() {
        let mut factorizer = CholeskyFactorizer::new(3);
        factorizer
            .factorize(&SPD_3X3_F64)
            .expect("test matrix is positive definite");
        let rhs = [-5.0_f64, 2.0, -3.0];
        let x = factorizer.solve(&rhs).unwrap();
        let expected_sol = [-280.25_f64, 77., -12.];
        unit_test_utils::nearly_equal_array(&expected_sol, &x, 1e-10, 1e-12);
    }

    #[test]
    fn t_cholesky_f32() {
        let a = [4.0_f32, 12.0, -16.0, 12.0, 37.0, -43.0, -16.0, -43.0, 98.0];
        let mut factorizer = CholeskyFactorizer::new(3);
        factorizer
            .factorize(&a)
            .expect("test matrix is positive definite");
        let expected_l = [2.0_f32, 0.0, 0.0, 6.0, 1.0, 0.0, -8.0, 5.0, 3.0];
        unit_test_utils::nearly_equal_array(
            &expected_l,
            factorizer.cholesky_factor(),
            1e-5,
            1e-6,
        );
        let rhs = [-5.0_f32, 2.0, -3.0];
        let x = factorizer.solve(&rhs).unwrap();
        let expected_sol = [-280.25_f32, 77.0, -12.0];
        unit_test_utils::nearly_equal_array(&expected_sol, &x, 1e-4, 1e-5);
    }

    #[test]
    fn t_cholesky_not_square_matrix() {
        let a = [1.0_f64, 2., 7., 5., 9.];
        let mut factorizer = CholeskyFactorizer::new(3);
        let result = factorizer.factorize(&a);
        assert_eq!(result, Err(CholeskyError::DimensionMismatch));
    }

    #[test]
    fn t_cholesky_solve_wrong_dimension_rhs() {
        let mut factorizer = CholeskyFactorizer::new(3);
        factorizer
            .factorize(&SPD_3X3_F64)
            .expect("test matrix is positive definite");
        let rhs = [-5.0_f64, 2.0];
        let result = factorizer.solve(&rhs);
        assert_eq!(result, Err(CholeskyError::DimensionMismatch));
    }

    #[test]
    fn t_cholesky_solve_not_factorized_1() {
        let factorizer = CholeskyFactorizer::new(3);
        // Deliberately wrong length: NotFactorized must take precedence.
        let rhs = [-5.0_f64, 2.0];
        let result = factorizer.solve(&rhs);
        assert_eq!(result, Err(CholeskyError::NotFactorized));
    }

    #[test]
    fn t_cholesky_solve_not_factorized_2() {
        let a = [1.0_f64, 1.0, 1.0, 1.0];
        let mut factorizer = CholeskyFactorizer::new(2);
        let factorization_result = factorizer.factorize(&a);
        assert_eq!(
            factorization_result,
            Err(CholeskyError::NotPositiveDefinite)
        );
        let rhs = [-5.0_f64, 2.0];
        let result = factorizer.solve(&rhs);
        assert_eq!(result, Err(CholeskyError::NotFactorized));
    }

    #[test]
    fn t_cholesky_refactorize_after_failure() {
        let mut factorizer = CholeskyFactorizer::new(3);
        // Pivot of row 1 is 1 - 2² = -3 < 0.
        let not_pd = [1.0_f64, 2.0, 0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        assert_eq!(
            factorizer.factorize(&not_pd),
            Err(CholeskyError::NotPositiveDefinite)
        );
        // A later successful factorization must fully replace the stale state.
        factorizer
            .factorize(&SPD_3X3_F64)
            .expect("test matrix is positive definite");
        let expected_l = [2.0, 0.0, 0.0, 6.0, 1.0, 0.0, -8.0, 5.0, 3.0];
        unit_test_utils::nearly_equal_array(
            &expected_l,
            factorizer.cholesky_factor(),
            1e-10,
            1e-12,
        );
        let rhs = [-5.0_f64, 2.0, -3.0];
        let x = factorizer.solve(&rhs).unwrap();
        let expected_sol = [-280.25_f64, 77., -12.];
        unit_test_utils::nearly_equal_array(&expected_sol, &x, 1e-10, 1e-12);
    }
}