use crate::matrix::DenseMatrix;
use crate::Scalar;
use faer::linalg::solvers::Cholesky;
use faer::prelude::*;
use faer::{ComplexField, Conjugate, Entity, Mat, SimpleEntity};
use numra_core::LinalgError;
/// A Cholesky factorization of a square matrix, kept so that several
/// right-hand sides can be solved against the same factorization.
///
/// Wraps a `faer` [`Cholesky`] solver together with the dimension of the
/// matrix it was built from.
pub struct CholeskyFactorization<S>
where
    S: Scalar + Entity,
{
    /// The underlying `faer` factorization object.
    chol: Cholesky<S>,
    /// Number of rows (equal to the number of columns) of the factored matrix.
    n: usize,
}
impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> CholeskyFactorization<S> {
pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
let nrows = matrix.rows();
let ncols = matrix.cols();
if nrows != ncols {
return Err(LinalgError::NotSquare { nrows, ncols });
}
let n = nrows;
let chol = Cholesky::try_new(matrix.as_faer(), faer::Side::Lower)
.map_err(|_| LinalgError::NotPositiveDefinite)?;
Ok(Self { chol, n })
}
pub fn dim(&self) -> usize {
self.n
}
pub fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
if b.len() != self.n {
return Err(LinalgError::DimensionMismatch {
expected: (self.n, 1),
actual: (b.len(), 1),
});
}
let mut b_mat = Mat::zeros(self.n, 1);
for (i, &val) in b.iter().enumerate() {
b_mat.write(i, 0, val);
}
let x_mat = self.chol.solve(&b_mat);
let mut x = Vec::with_capacity(self.n);
for i in 0..self.n {
x.push(x_mat.read(i, 0));
}
Ok(x)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;

    /// Absolute tolerance used when comparing computed solutions.
    const TOL: f64 = 1e-10;

    /// Builds a dense `f64` matrix from row slices, setting entries in
    /// row-major order.
    fn dense_from_rows(rows: &[&[f64]]) -> DenseMatrix<f64> {
        let ncols = rows.first().map_or(0, |row| row.len());
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(rows.len(), ncols);
        for (i, row) in rows.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                m.set(i, j, value);
            }
        }
        m
    }

    /// Asserts that every entry of `actual` is within `TOL` of the entry at
    /// the same index in `expected`.
    fn assert_solution(actual: &[f64], expected: &[f64]) {
        for (idx, &want) in expected.iter().enumerate() {
            assert!((actual[idx] - want).abs() < TOL);
        }
    }

    /// Solves a 2x2 system whose exact solution is `[1, 1]`.
    #[test]
    fn test_cholesky_solve() {
        let a = dense_from_rows(&[&[4.0, 2.0], &[2.0, 3.0]]);

        let chol = CholeskyFactorization::new(&a).unwrap();
        assert_eq!(chol.dim(), 2);

        let x = chol.solve(&[6.0, 5.0]).unwrap();
        assert_solution(&x, &[1.0, 1.0]);
    }

    /// Solves a 3x3 system whose exact solution is `[1, 1, 1]`.
    #[test]
    fn test_cholesky_3x3() {
        let a = dense_from_rows(&[
            &[25.0, 15.0, -5.0],
            &[15.0, 18.0, 0.0],
            &[-5.0, 0.0, 11.0],
        ]);

        let chol = CholeskyFactorization::new(&a).unwrap();
        assert_eq!(chol.dim(), 3);

        let x = chol.solve(&[35.0, 33.0, 6.0]).unwrap();
        assert_solution(&x, &[1.0, 1.0, 1.0]);
    }

    /// A symmetric matrix that is not positive definite must be rejected.
    #[test]
    fn test_cholesky_not_positive_definite() {
        let a = dense_from_rows(&[&[1.0, 2.0], &[2.0, 1.0]]);
        assert!(CholeskyFactorization::new(&a).is_err());
    }

    /// A non-square matrix must be rejected.
    #[test]
    fn test_cholesky_not_square() {
        let a: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        assert!(CholeskyFactorization::new(&a).is_err());
    }

    /// One factorization can be reused for several right-hand sides.
    #[test]
    fn test_cholesky_repeated_solve() {
        let a = dense_from_rows(&[&[4.0, 2.0], &[2.0, 3.0]]);
        let chol = CholeskyFactorization::new(&a).unwrap();

        let x1 = chol.solve(&[6.0, 5.0]).unwrap();
        assert_solution(&x1, &[1.0, 1.0]);

        let x2 = chol.solve(&[6.0, 1.0]).unwrap();
        assert_solution(&x2, &[2.0, -1.0]);
    }
}