#![allow(non_snake_case)]
use std::fmt::Debug;
use std::marker::PhantomData;
use crate::error::{Failed, FailedError};
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
/// Result of a Cholesky decomposition `A = L * L^T`.
///
/// The factor `L` is kept in the lower triangle (diagonal included) of the
/// wrapped matrix; `L()` and `U()` extract clean triangular copies of it.
#[derive(Debug, Clone)]
pub struct Cholesky<T, M>
where
    T: RealNumber,
    M: BaseMatrix<T>,
{
    /// Matrix whose lower triangle (diagonal included) holds the factor `L`.
    R: M,
    /// Zero-sized marker binding the element type `T` to this struct.
    t: PhantomData<T>,
}
impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
    /// Wraps a matrix whose lower triangle (diagonal included) already holds
    /// the Cholesky factor `L`.
    pub(crate) fn new(R: M) -> Cholesky<T, M> {
        Cholesky {
            R,
            t: PhantomData,
        }
    }

    /// Returns the lower-triangular factor `L` of `A = L * L^T`.
    ///
    /// Entries above the diagonal of the returned matrix are zero.
    pub fn L(&self) -> M {
        let (size, _) = self.R.shape();
        let mut lower = M::zeros(size, size);
        for row in 0..size {
            // Copy only the lower triangle, diagonal included.
            for col in 0..=row {
                lower.set(row, col, self.R.get(row, col));
            }
        }
        lower
    }

    /// Returns the upper-triangular factor `U = L^T` of `A = U^T * U`.
    ///
    /// Entries below the diagonal of the returned matrix are zero.
    pub fn U(&self) -> M {
        let (size, _) = self.R.shape();
        let mut upper = M::zeros(size, size);
        for row in 0..size {
            for col in 0..=row {
                // Mirror the stored lower triangle across the diagonal.
                upper.set(col, row, self.R.get(row, col));
            }
        }
        upper
    }

    /// Solves `A * X = B` for `X`, where `A = L * L^T` is the factorized matrix.
    ///
    /// `b` is overwritten in place and returned; each of its columns is an
    /// independent right-hand side.
    ///
    /// # Errors
    /// Returns `FailedError::SolutionFailed` when `b` does not have the same
    /// number of rows as the factor.
    pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
        let (b_rows, b_cols) = b.shape();
        let (r_rows, _) = self.R.shape();

        if b_rows != r_rows {
            return Err(Failed::because(
                FailedError::SolutionFailed,
                &"Can\'t solve Ax = b for x. Number of rows in b != number of rows in R."
                    .to_string(),
            ));
        }

        // Forward substitution: solve L * Y = B; rows above `row` are already final.
        for row in 0..b_rows {
            for col in 0..b_cols {
                for prev in 0..row {
                    b.sub_element_mut(row, col, b.get(prev, col) * self.R.get(row, prev));
                }
                b.div_element_mut(row, col, self.R.get(row, row));
            }
        }

        // Back substitution: solve L^T * X = Y, using L^T[row][next] == L[next][row].
        for row in (0..b_rows).rev() {
            for col in 0..b_cols {
                for next in row + 1..b_rows {
                    b.sub_element_mut(row, col, b.get(next, col) * self.R.get(next, row));
                }
                b.div_element_mut(row, col, self.R.get(row, row));
            }
        }

        Ok(b)
    }
}
/// Matrices that can be factorized as `A = L * L^T` by Cholesky decomposition.
pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
    /// Computes the Cholesky decomposition of a copy of `self`.
    ///
    /// # Errors
    /// Same conditions as `cholesky_mut`.
    fn cholesky(&self) -> Result<Cholesky<T, Self>, Failed> {
        self.clone().cholesky_mut()
    }

    /// Computes the Cholesky decomposition in place, consuming `self`.
    ///
    /// Only the lower triangle (diagonal included) of `self` is read, and it is
    /// overwritten with the factor `L`. The strict upper triangle is left as-is.
    ///
    /// # Errors
    /// Returns `FailedError::DecompositionFailed` if the matrix is not square,
    /// or if any pivot is not strictly positive (zero, negative or NaN), i.e.
    /// the matrix is not positive definite.
    fn cholesky_mut(mut self) -> Result<Cholesky<T, Self>, Failed> {
        let (m, n) = self.shape();
        if m != n {
            return Err(Failed::because(
                FailedError::DecompositionFailed,
                &"Can\'t do Cholesky decomposition on a non-square matrix".to_string(),
            ));
        }

        for j in 0..n {
            let mut d = T::zero();
            for k in 0..j {
                let mut s = T::zero();
                for i in 0..k {
                    s += self.get(k, i) * self.get(j, i);
                }
                s = (self.get(j, k) - s) / self.get(k, k);
                self.set(j, k, s);
                d += s * s;
            }
            d = self.get(j, j) - d;

            // A positive definite matrix has strictly positive pivots. A zero
            // pivot would later be used as a divisor, here and in `solve`,
            // silently yielding Inf/NaN. Testing `d > 0` positively also
            // rejects NaN, since every comparison with NaN is false.
            if d > T::zero() {
                self.set(j, j, d.sqrt());
            } else {
                return Err(Failed::because(
                    FailedError::DecompositionFailed,
                    &"The matrix is not positive definite.".to_string(),
                ));
            }
        }

        Ok(Cholesky::new(self))
    }

    /// Factorizes `self` in place and solves `self * X = b` for `X`.
    ///
    /// # Errors
    /// Propagates any decomposition error from `cholesky_mut`, or a row
    /// mismatch error from the solve step.
    fn cholesky_solve_mut(self, b: Self) -> Result<Self, Failed> {
        self.cholesky_mut().and_then(|chol| chol.solve(b))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::naive::dense_matrix::*;

    #[test]
    fn cholesky_decompose() {
        let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
        let l =
            DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]]);
        let u =
            DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]]);
        let cholesky = a.cholesky().unwrap();
        assert!(cholesky.L().abs().approximate_eq(&l.abs(), 1e-4));
        assert!(cholesky.U().abs().approximate_eq(&u.abs(), 1e-4));
        assert!(cholesky
            .L()
            .matmul(&cholesky.U())
            .abs()
            .approximate_eq(&a.abs(), 1e-4));
    }

    #[test]
    fn cholesky_solve_mut() {
        let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
        let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
        let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
        let cholesky = a.cholesky().unwrap();
        assert!(cholesky
            .solve(b.transpose())
            .unwrap()
            .transpose()
            .approximate_eq(&expected, 1e-4));
    }

    /// A 3x2 matrix must be rejected before any factorization work starts.
    #[test]
    fn cholesky_rejects_non_square_matrix() {
        let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
        assert!(a.cholesky().is_err());
    }

    /// A negative leading diagonal entry means the matrix is not positive definite.
    #[test]
    fn cholesky_rejects_negative_pivot() {
        let a = DenseMatrix::from_2d_array(&[&[-1., 0.], &[0., 1.]]);
        assert!(a.cholesky().is_err());
    }

    /// The right-hand side must have as many rows as the factorized matrix.
    #[test]
    fn cholesky_solve_rejects_row_mismatch() {
        let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
        let b = DenseMatrix::from_2d_array(&[&[1.], &[2.]]);
        let cholesky = a.cholesky().unwrap();
        assert!(cholesky.solve(b).is_err());
    }
}