use ndarray::Array2;
/// Computes the lower-triangular Cholesky factor `L` of a symmetric
/// positive-definite matrix `a`, so that `a = L · Lᵀ`.
///
/// Only the lower triangle of `a` (entries `a[[i, j]]` with `j <= i`) is read;
/// the upper triangle is assumed to mirror it and is never inspected. Entries
/// of the returned factor above the diagonal are exactly `0.0`.
///
/// Returns `None` when a pivot (the value whose square root becomes
/// `L[[i, i]]`) is not strictly positive, meaning `a` is not positive
/// definite. It also returns `None` when a pivot is NaN, which happens
/// whenever the lower triangle of `a` contains a NaN.
///
/// # Panics
///
/// In debug builds, panics if `a` is not square.
pub fn cholesky(a: &Array2<f32>) -> Option<Array2<f32>> {
    let n = a.nrows();
    debug_assert_eq!(n, a.ncols(), "Cholesky requires a square matrix");
    let mut l = Array2::<f32>::zeros((n, n));
    for i in 0..n {
        for j in 0..=i {
            // Dot product of the already-computed prefixes of rows i and j of L.
            let mut sum = 0.0_f32;
            for k in 0..j {
                sum += l[[i, k]] * l[[j, k]];
            }
            if i == j {
                let diag = a[[i, i]] - sum;
                // `NaN <= 0.0` is false, so NaN must be rejected explicitly;
                // otherwise it would be propagated into the factor as a "success".
                if diag.is_nan() || diag <= 0.0 {
                    return None;
                }
                l[[i, j]] = diag.sqrt();
            } else {
                l[[i, j]] = (a[[i, j]] - sum) / l[[j, j]];
            }
        }
    }
    Some(l)
}
/// Returns the natural log of the determinant of `a`, computed via its
/// Cholesky factor `L`: `ln det(a) = 2 · Σᵢ ln L[i, i]`.
///
/// Returns `f32::NEG_INFINITY` whenever [`cholesky`] returns `None`, i.e.
/// when `a` is not positive definite. A 0×0 matrix yields `0.0`.
pub fn log_det(a: &Array2<f32>) -> f32 {
    cholesky(a).map_or(f32::NEG_INFINITY, |factor| {
        // Sum ln of the diagonal in index order, starting from +0.0.
        let half_log_det = factor
            .diag()
            .iter()
            .fold(0.0_f32, |acc, &pivot| acc + pivot.ln());
        2.0 * half_log_det
    })
}
/// Computes `ln det` of a symmetric positive-definite matrix `S` bordered by
/// one extra row/column, without refactorizing:
///
/// ```text
/// [ S    c ]
/// [ cᵀ   d ]
/// ```
///
/// The determinant of the bordered matrix is `det(S) · (d − cᵀ S⁻¹ c)`. With
/// `S = L · Lᵀ`, the quadratic form `cᵀ S⁻¹ c` equals `‖a‖²` for `a = L⁻¹ c`,
/// so only a single forward substitution is needed.
///
/// # Arguments
///
/// * `chol_s` – lower-triangular Cholesky factor `L` of `S`. Only the lower
///   triangle, including the diagonal, is read.
/// * `cross` – the new column `c`; must have length `chol_s.nrows()`.
/// * `diag_cc` – the new diagonal entry `d`.
/// * `current_log_det` – `ln det(S)`, which is added unchanged to the result.
///
/// Returns `current_log_det + ln(d − ‖a‖²)`. Returns `f32::NEG_INFINITY` when
/// that Schur complement is not strictly positive or is NaN, for example from
/// NaN inputs or a zero pivot in `chol_s`.
pub fn log_det_incremental(
    chol_s: &Array2<f32>,
    cross: &[f32],
    diag_cc: f32,
    current_log_det: f32,
) -> f32 {
    let m = chol_s.nrows();
    debug_assert_eq!(cross.len(), m);
    // Forward substitution: solve L · a = c for a.
    let mut a = vec![0.0_f32; m];
    for i in 0..m {
        let mut sum = 0.0_f32;
        for j in 0..i {
            sum += chol_s[[i, j]] * a[j];
        }
        a[i] = (cross[i] - sum) / chol_s[[i, i]];
    }
    let norm_sq: f32 = a.iter().map(|x| x * x).sum();
    // Schur complement d − cᵀ S⁻¹ c: the determinant ratio det(bordered) / det(S).
    let schur = diag_cc - norm_sq;
    // `NaN <= 0.0` is false, so test NaN explicitly; otherwise a NaN would be
    // returned instead of the NEG_INFINITY "not positive definite" sentinel.
    if schur.is_nan() || schur <= 0.0 {
        return f32::NEG_INFINITY;
    }
    current_log_det + schur.ln()
}
/// Extends the Cholesky factor `L` of `S` to the Cholesky factor of the
/// bordered matrix
///
/// ```text
/// [ S    c ]
/// [ cᵀ   d ]
/// ```
///
/// using one forward substitution plus a copy, both O(m²), instead of
/// refactorizing. The new last row of the factor is `[aᵀ, √(d − ‖a‖²)]`,
/// where `a = L⁻¹ c`.
///
/// # Arguments
///
/// * `chol_s` – lower-triangular Cholesky factor `L` of `S` (expected m×m).
///   Only its lower triangle, including the diagonal, is read and copied.
/// * `cross` – the new column `c`; must have length `chol_s.nrows()`.
/// * `diag_cc` – the new diagonal entry `d`.
///
/// Returns the (m+1)×(m+1) lower-triangular factor. Returns `None` when the
/// Schur complement `d − ‖a‖²` is not strictly positive or is NaN.
pub fn cholesky_extend(chol_s: &Array2<f32>, cross: &[f32], diag_cc: f32) -> Option<Array2<f32>> {
    let m = chol_s.nrows();
    debug_assert_eq!(cross.len(), m);
    // Forward substitution: solve L · a = c for a.
    let mut a = vec![0.0_f32; m];
    for i in 0..m {
        let mut sum = 0.0_f32;
        for j in 0..i {
            sum += chol_s[[i, j]] * a[j];
        }
        a[i] = (cross[i] - sum) / chol_s[[i, i]];
    }
    let norm_sq: f32 = a.iter().map(|x| x * x).sum();
    let schur = diag_cc - norm_sq;
    // `NaN <= 0.0` is false, so test NaN explicitly; otherwise a NaN pivot
    // would be written into the returned factor.
    if schur.is_nan() || schur <= 0.0 {
        return None;
    }
    let mut new_l = Array2::<f32>::zeros((m + 1, m + 1));
    // Copy the existing factor's lower triangle into the top-left m×m block.
    for i in 0..m {
        for j in 0..=i {
            new_l[[i, j]] = chol_s[[i, j]];
        }
    }
    // Append the new bottom row [aᵀ, √schur].
    for (j, &aj) in a.iter().enumerate() {
        new_l[[m, j]] = aj;
    }
    new_l[[m, m]] = schur.sqrt();
    Some(new_l)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_cholesky_identity() {
        let eye = Array2::<f32>::eye(3);
        let l = cholesky(&eye).expect("identity is positive definite");
        for i in 0..3 {
            for j in 0..3 {
                if i == j {
                    assert!((l[[i, j]] - 1.0).abs() < 1e-6);
                } else {
                    assert!(l[[i, j]].abs() < 1e-6);
                }
            }
        }
    }

    #[test]
    fn test_cholesky_known_matrix() {
        let a = array![[4.0_f32, 2.0], [2.0, 3.0]];
        let l = cholesky(&a).expect("known positive definite");
        assert!((l[[0, 0]] - 2.0).abs() < 1e-6);
        assert!((l[[1, 0]] - 1.0).abs() < 1e-6);
        assert!((l[[1, 1]] - 2.0_f32.sqrt()).abs() < 1e-6);
        assert!(l[[0, 1]].abs() < 1e-6);
    }

    /// L · Lᵀ must reproduce the input, and L must be strictly lower-triangular
    /// above the diagonal.
    #[test]
    fn test_cholesky_reconstructs_input() {
        let a = array![[4.0_f32, 2.0, 1.0], [2.0, 5.0, 3.0], [1.0, 3.0, 6.0]];
        let l = cholesky(&a).expect("positive definite");
        let reconstructed = l.dot(&l.t());
        for ((i, j), &v) in reconstructed.indexed_iter() {
            assert!(
                (v - a[[i, j]]).abs() < 1e-5,
                "reconstruction mismatch at [{},{}]: {} vs {}",
                i,
                j,
                v,
                a[[i, j]]
            );
            if j > i {
                assert_eq!(l[[i, j]], 0.0);
            }
        }
    }

    #[test]
    fn test_log_det_known() {
        let a = array![[4.0_f32, 2.0], [2.0, 3.0]];
        let ld = log_det(&a);
        assert!((ld - 8.0_f32.ln()).abs() < 1e-5);
    }

    #[test]
    fn test_log_det_identity() {
        let eye = Array2::<f32>::eye(4);
        let ld = log_det(&eye);
        assert!(ld.abs() < 1e-6);
    }

    #[test]
    fn test_not_positive_definite() {
        let a = array![[-1.0_f32, 0.0], [0.0, 1.0]];
        assert!(cholesky(&a).is_none());
        assert_eq!(log_det(&a), f32::NEG_INFINITY);
    }

    /// A 0×0 matrix factors trivially and can be grown from scratch.
    #[test]
    fn test_empty_matrix() {
        let empty = Array2::<f32>::zeros((0, 0));
        let l = cholesky(&empty).expect("empty matrix factors trivially");
        assert_eq!(l.dim(), (0, 0));
        assert_eq!(log_det(&empty), 0.0);

        let grown = cholesky_extend(&l, &[], 9.0).expect("1x1 positive definite");
        assert_eq!(grown.dim(), (1, 1));
        assert!((grown[[0, 0]] - 3.0).abs() < 1e-6);
        let ld = log_det_incremental(&l, &[], 9.0, 0.0);
        assert!((ld - 9.0_f32.ln()).abs() < 1e-6);
    }

    #[test]
    fn test_cholesky_extend_matches_full() {
        let full = array![[4.0_f32, 2.0, 1.0], [2.0, 5.0, 3.0], [1.0, 3.0, 6.0]];
        let sub = array![[4.0_f32, 2.0], [2.0, 5.0]];
        let chol_sub = cholesky(&sub).unwrap();
        let cross = vec![1.0_f32, 3.0];
        let diag_cc = 6.0_f32;
        let extended = cholesky_extend(&chol_sub, &cross, diag_cc).unwrap();
        let full_chol = cholesky(&full).unwrap();
        for i in 0..3 {
            for j in 0..=i {
                assert!(
                    (extended[[i, j]] - full_chol[[i, j]]).abs() < 1e-5,
                    "mismatch at [{},{}]: {} vs {}",
                    i,
                    j,
                    extended[[i, j]],
                    full_chol[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_log_det_incremental_matches_full() {
        let full = array![[4.0_f32, 2.0, 1.0], [2.0, 5.0, 3.0], [1.0, 3.0, 6.0]];
        let sub = array![[4.0_f32, 2.0], [2.0, 5.0]];
        let chol_sub = cholesky(&sub).unwrap();
        let ld_sub = log_det(&sub);
        let cross = vec![1.0_f32, 3.0];
        let diag_cc = 6.0_f32;
        let ld_incr = log_det_incremental(&chol_sub, &cross, diag_cc, ld_sub);
        let ld_full = log_det(&full);
        assert!(
            (ld_incr - ld_full).abs() < 1e-4,
            "incremental {} vs full {}",
            ld_incr,
            ld_full
        );
    }

    /// Bordering with a diagonal entry that is too small must be rejected by
    /// both incremental routines.
    #[test]
    fn test_extension_rejects_non_positive_definite() {
        let sub = array![[4.0_f32, 2.0], [2.0, 5.0]];
        let chol_sub = cholesky(&sub).unwrap();
        let cross = [1.0_f32, 3.0];
        // For this S and c, ‖L⁻¹ c‖² = 0.25 + 1.5625 = 1.8125, so d = 1.0 leaves
        // a negative Schur complement.
        let diag_cc = 1.0_f32;
        assert!(cholesky_extend(&chol_sub, &cross, diag_cc).is_none());
        assert_eq!(
            log_det_incremental(&chol_sub, &cross, diag_cc, log_det(&sub)),
            f32::NEG_INFINITY
        );
    }
}