use super::Matrix;
use crate::{to_i32, CcBool, StrError, C_FALSE, C_TRUE};
extern "C" {
    /// Binding to a C wrapper around LAPACK's `dpotrf` (Cholesky factorization).
    ///
    /// The wrapper itself is defined outside this file — NOTE(review): confirm its location.
    ///
    /// * `upper` — which triangle of `a` to factorize (`C_TRUE` = upper, `C_FALSE` = lower)
    /// * `n` — order of the square matrix
    /// * `a` — matrix entries, overwritten in place with the triangular factor
    /// * `lda` — leading dimension of `a`
    /// * `info` — output status: `0` on success; `< 0` if argument number `-info` had an
    ///   illegal value; `> 0` if the leading minor of order `info` is not positive definite
    fn c_dpotrf(upper: CcBool, n: *const i32, a: *mut f64, lda: *const i32, info: *mut i32);
}
/// Computes, in place, the Cholesky factorization of a real symmetric positive-definite matrix.
///
/// Only the triangle of `a` selected by `upper` is read and overwritten:
///
/// * `upper == false` — the lower triangle (diagonal included) becomes `L`, with `A = L · Lᵀ`
/// * `upper == true` — the upper triangle (diagonal included) becomes `U`, with `A = Uᵀ · U`
///
/// The opposite strictly-triangular part is left untouched. It may therefore hold either the
/// remaining entries of a full symmetric matrix or anything else, such as zeros.
///
/// A `0 × 0` matrix is trivially factorized: `Ok(())` is returned without calling LAPACK.
///
/// # Errors
///
/// * `"the matrix must be square"` if `a` is not square
/// * `"LAPACK ERROR (dpotrf): An argument had an illegal value"` if LAPACK rejects an argument
/// * `"LAPACK ERROR (dpotrf): Positive definite check failed"` if a leading minor is not
///   positive definite. In that case `a` may have been partially overwritten.
///
/// For both LAPACK errors, a more detailed diagnostic is also printed to stdout.
pub fn mat_cholesky(a: &mut Matrix, upper: bool) -> Result<(), StrError> {
    let (m, n) = a.dims();
    if m != n {
        return Err("the matrix must be square");
    }
    // An empty matrix is its own (trivial) factor. LAPACK would reject it anyway,
    // because it requires `lda >= max(1, n)`, which `lda = 0` violates.
    if n == 0 {
        return Ok(());
    }
    let c_upper = if upper { C_TRUE } else { C_FALSE };
    let n_i32 = to_i32(n);
    // Assumes the matrix data is stored contiguously without padding, so the leading
    // dimension equals the order — NOTE(review): confirm against `Matrix`'s layout.
    let lda = n_i32;
    let mut info = 0;
    // SAFETY: `a` is exclusively borrowed for the whole call, so its data pointer stays valid
    // and unaliased. `n_i32`, `lda` and `info` are live locals whose addresses remain valid
    // until the call returns. `n > 0` here, so LAPACK's argument preconditions on `n`/`lda` hold.
    unsafe { c_dpotrf(c_upper, &n_i32, a.as_mut_data().as_mut_ptr(), &lda, &mut info) }
    if info < 0 {
        println!("LAPACK ERROR (dpotrf): Argument #{} had an illegal value", -info);
        return Err("LAPACK ERROR (dpotrf): An argument had an illegal value");
    }
    if info > 0 {
        println!(
            "LAPACK ERROR (dpotrf): The leading minor of order {} is not positive definite",
            info
        );
        return Err("LAPACK ERROR (dpotrf): Positive definite check failed");
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::{mat_cholesky, Matrix};
    use crate::{mat_approx_eq, math};

    /// Computes `L · Lᵀ` from the lower triangle (diagonal included) of `l_and_a`.
    ///
    /// The strictly upper part of `l_and_a` is never read. It may therefore still hold
    /// entries of the original matrix, as left behind by `mat_cholesky`.
    fn calc_l_times_lt(l_and_a: &Matrix) -> Matrix {
        let m = l_and_a.nrow();
        let mut l_lt = Matrix::new(m, m);
        for i in 0..m {
            for j in 0..m {
                // L[i][k] and L[j][k] can both be non-zero only when k <= min(i, j)
                for k in 0..=usize::min(i, j) {
                    l_lt.add(i, j, l_and_a.get(i, k) * l_and_a.get(j, k));
                }
            }
        }
        l_lt
    }

    /// Computes `Uᵀ · U` from the upper triangle (diagonal included) of `u_and_a`.
    ///
    /// The strictly lower part of `u_and_a` is never read. It may therefore still hold
    /// entries of the original matrix, as left behind by `mat_cholesky`.
    fn calc_ut_times_u(u_and_a: &Matrix) -> Matrix {
        let m = u_and_a.nrow();
        let mut ut_u = Matrix::new(m, m);
        for i in 0..m {
            for j in 0..m {
                // U[k][i] and U[k][j] can both be non-zero only when k <= min(i, j)
                for k in 0..=usize::min(i, j) {
                    ut_u.add(i, j, u_and_a.get(k, i) * u_and_a.get(k, j));
                }
            }
        }
        ut_u
    }

    #[test]
    fn mat_cholesky_fails_on_wrong_dims() {
        let mut a_wrong = Matrix::new(2, 3);
        assert_eq!(mat_cholesky(&mut a_wrong, false), Err("the matrix must be square"));
        assert_eq!(mat_cholesky(&mut a_wrong, true), Err("the matrix must be square"));
    }

    #[test]
    fn mat_cholesky_3x3_lower_works() {
        let (a01, a02) = (15.0, -5.0);
        let a12 = 0.0;
        #[rustfmt::skip]
        let a_full = Matrix::from(&[
            [25.0,  a01,  a02],
            [ a01, 18.0,  a12],
            [ a02,  a12, 11.0],
        ]);
        #[rustfmt::skip]
        let a_lower = Matrix::from(&[
            [25.0,  0.0,  0.0],
            [ a01, 18.0,  0.0],
            [ a02,  a12, 11.0],
        ]);

        // full symmetric input: L overwrites the lower triangle, the strictly upper part is untouched
        let mut l_and_a = a_full.clone();
        mat_cholesky(&mut l_and_a, false).unwrap();
        #[rustfmt::skip]
        let l_and_a_correct = Matrix::from(&[
            [ 5.0, a01, a02],
            [ 3.0, 3.0, a12],
            [-1.0, 1.0, 3.0],
        ]);
        mat_approx_eq(&l_and_a, &l_and_a_correct, 1e-15);
        let l_lt = calc_l_times_lt(&l_and_a);
        mat_approx_eq(&l_lt, &a_full, 1e-15);

        // lower-triangle-only input yields the same factor
        let mut l = a_lower;
        mat_cholesky(&mut l, false).unwrap();
        let nil = 0.0;
        #[rustfmt::skip]
        let l_correct = Matrix::from(&[
            [ 5.0, nil, nil],
            [ 3.0, 3.0, nil],
            [-1.0, 1.0, 3.0],
        ]);
        mat_approx_eq(&l, &l_correct, 1e-15);
        let l_lt = calc_l_times_lt(&l);
        mat_approx_eq(&l_lt, &a_full, 1e-15);
    }

    #[test]
    fn mat_cholesky_3x3_upper_works() {
        let (a01, a02) = (15.0, -5.0);
        let a12 = 0.0;
        #[rustfmt::skip]
        let a_full = Matrix::from(&[
            [25.0,  a01,  a02],
            [ a01, 18.0,  a12],
            [ a02,  a12, 11.0],
        ]);
        #[rustfmt::skip]
        let a_upper = Matrix::from(&[
            [25.0,  a01,  a02],
            [ 0.0, 18.0,  a12],
            [ 0.0,  0.0, 11.0],
        ]);

        // full symmetric input: U overwrites the upper triangle, the strictly lower part is untouched
        let mut u_and_a = a_full.clone();
        mat_cholesky(&mut u_and_a, true).unwrap();
        #[rustfmt::skip]
        let u_and_a_correct = Matrix::from(&[
            [5.0, 3.0,-1.0],
            [a01, 3.0, 1.0],
            [a02, a12, 3.0],
        ]);
        mat_approx_eq(&u_and_a, &u_and_a_correct, 1e-15);
        let ut_u = calc_ut_times_u(&u_and_a);
        mat_approx_eq(&ut_u, &a_full, 1e-15);

        // upper-triangle-only input yields the same factor
        let mut u = a_upper;
        mat_cholesky(&mut u, true).unwrap();
        let nil = 0.0;
        #[rustfmt::skip]
        let u_correct = Matrix::from(&[
            [5.0, 3.0,-1.0],
            [nil, 3.0, 1.0],
            [nil, nil, 3.0],
        ]);
        mat_approx_eq(&u, &u_correct, 1e-15);
        let ut_u = calc_ut_times_u(&u);
        mat_approx_eq(&ut_u, &a_full, 1e-15);
    }

    #[test]
    fn mat_cholesky_5x5_lower_works() {
        let nil = 0.0;
        let (a01, a02, a03, a04) = (1.0, 1.0, 3.0, 2.0);
        let (___, a12, a13, a14) = (nil, 2.0, 1.0, 1.0);
        let (___, __p, a23, a24) = (nil, nil, 1.0, 5.0);
        let (___, __p, __q, a34) = (nil, nil, nil, 1.0);
        #[rustfmt::skip]
        let a_full = Matrix::from(&[
            [2.0, a01, a02, a03, a04],
            [a01, 2.0, a12, a13, a14],
            [a02, a12, 9.0, a23, a24],
            [a03, a13, a23, 7.0, a34],
            [a04, a14, a24, a34, 8.0],
        ]);
        #[rustfmt::skip]
        let a_lower = Matrix::from(&[
            [2.0, 0.0, 0.0, 0.0, 0.0],
            [a01, 2.0, 0.0, 0.0, 0.0],
            [a02, a12, 9.0, 0.0, 0.0],
            [a03, a13, a23, 7.0, 0.0],
            [a04, a14, a24, a34, 8.0],
        ]);

        // full symmetric input
        let mut l_and_a = a_full.clone();
        mat_cholesky(&mut l_and_a, false).unwrap();
        let sqrt2 = math::SQRT_2;
        #[rustfmt::skip]
        let l_and_a_correct = Matrix::from(&[
            [    sqrt2,                a01,                a02,                     a03,   a04],
            [1.0/sqrt2, f64::sqrt(3.0/2.0),                a12,                     a13,   a14],
            [1.0/sqrt2, f64::sqrt(3.0/2.0),     f64::sqrt(7.0),                     a23,   a24],
            [3.0/sqrt2, -1.0/f64::sqrt(6.0),               0.0,      f64::sqrt(7.0/3.0),   a34],
            [    sqrt2,                0.0, 4.0/f64::sqrt(7.0), -2.0*f64::sqrt(3.0/7.0), sqrt2],
        ]);
        mat_approx_eq(&l_and_a, &l_and_a_correct, 1e-14);
        let l_lt = calc_l_times_lt(&l_and_a);
        mat_approx_eq(&l_lt, &a_full, 1e-14);

        // lower-triangle-only input
        let mut l = a_lower;
        mat_cholesky(&mut l, false).unwrap();
        let l_lt = calc_l_times_lt(&l);
        mat_approx_eq(&l_lt, &a_full, 1e-14);
    }

    #[test]
    fn mat_cholesky_5x5_upper_works() {
        let nil = 0.0;
        let (a01, a02, a03, a04) = (1.0, 1.0, 3.0, 2.0);
        let (___, a12, a13, a14) = (nil, 2.0, 1.0, 1.0);
        let (___, __p, a23, a24) = (nil, nil, 1.0, 5.0);
        let (___, __p, __q, a34) = (nil, nil, nil, 1.0);
        #[rustfmt::skip]
        let a_full = Matrix::from(&[
            [2.0, a01, a02, a03, a04],
            [a01, 2.0, a12, a13, a14],
            [a02, a12, 9.0, a23, a24],
            [a03, a13, a23, 7.0, a34],
            [a04, a14, a24, a34, 8.0],
        ]);
        #[rustfmt::skip]
        let a_upper = Matrix::from(&[
            [2.0, a01, a02, a03, a04],
            [0.0, 2.0, a12, a13, a14],
            [0.0, 0.0, 9.0, a23, a24],
            [0.0, 0.0, 0.0, 7.0, a34],
            [0.0, 0.0, 0.0, 0.0, 8.0],
        ]);

        // full symmetric input
        let mut u_and_a = a_full.clone();
        mat_cholesky(&mut u_and_a, true).unwrap();
        let ut_u = calc_ut_times_u(&u_and_a);
        mat_approx_eq(&ut_u, &a_full, 1e-14);

        // upper-triangle-only input
        let mut u = a_upper;
        mat_cholesky(&mut u, true).unwrap();
        let ut_u = calc_ut_times_u(&u);
        mat_approx_eq(&ut_u, &a_full, 1e-14);
    }

    #[test]
    fn mat_cholesky_captures_non_positive_definite() {
        let (a01, a02) = (15.0, -5.0);
        let a12 = 0.0;
        #[rustfmt::skip]
        let a_full = Matrix::from(&[
            [25.0,   a01,  a02],
            [ a01, -18.0,  a12],
            [ a02,   a12, 11.0],
        ]);
        // the leading minor of order 2 is negative, so both triangle choices must fail
        for upper in [false, true] {
            let mut res = a_full.clone();
            assert_eq!(
                mat_cholesky(&mut res, upper).err(),
                Some("LAPACK ERROR (dpotrf): Positive definite check failed")
            );
        }
    }
}