/// Assembles `T = theta * SS + L * D^{-1} * L^T` in the upper triangle of `wt`, then
/// Cholesky-factors it in place.
///
/// Every matrix is row-major with leading dimension `m` (entry `(i, j)` lives at
/// `i * m + j`), and only the leading `col x col` corner is used:
/// - `D` is the diagonal of `sy` and `L` its strictly lower triangle; the strict upper
///   triangle of `sy` is never read.
/// - `SS` is read from the upper triangle of `ss` (entries with `j >= i`) only.
/// - On success the upper triangle of `wt` holds the upper-triangular factor `R` with
///   `R^T R = T`. The strict lower triangle of `wt` is never read or written.
///
/// When `col == 0` this is a no-op that returns `Ok(())`.
///
/// # Errors
/// Returns [`FormtError::NotPositiveDefinite`] when the factorization meets a pivot that
/// is not a finite positive number. `wt` is then left partially overwritten.
pub(crate) fn formt(
    theta: f64,
    sy: &[f64],
    ss: &[f64],
    col: usize,
    m: usize,
    wt: &mut [f64],
) -> Result<(), FormtError> {
    debug_assert!(col <= m, "col must be ≤ m");
    debug_assert!(sy.len() >= m * m && ss.len() >= m * m && wt.len() >= m * m);
    if col == 0 {
        return Ok(());
    }

    // Row 0 of T gets no L * D^{-1} * L^T contribution (L's first row is empty).
    for (dst, &s) in wt[..col].iter_mut().zip(&ss[..col]) {
        *dst = theta * s;
    }

    // Remaining upper-triangle rows. Since c >= row, min(row, c) == row, so the inner
    // sum runs over k < row and reads only strictly-lower entries of `sy`.
    for row in 1..col {
        for c in row..col {
            let ld_lt = (0..row).fold(0.0, |acc, k| {
                acc + sy[row * m + k] * sy[c * m + k] / sy[k * m + k]
            });
            wt[row * m + c] = ld_lt + theta * ss[row * m + c];
        }
    }

    if cholesky_upper_in_place(wt, col, m) {
        Ok(())
    } else {
        Err(FormtError::NotPositiveDefinite)
    }
}
/// Failure reported by [`formt`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum FormtError {
    /// The assembled matrix `T` could not be Cholesky-factored: a pivot was zero,
    /// negative, infinite or NaN.
    NotPositiveDefinite,
}
/// Factors the symmetric matrix `A`, whose upper triangle occupies the leading
/// `col x col` corner of `t`, as `A = R^T R`, and overwrites that upper triangle with `R`.
///
/// `t` is row-major with leading dimension `m` (entry `(i, j)` at `i * m + j`). Only
/// entries with `j >= i` are read or written; the strict lower triangle is left alone.
///
/// Returns `true` on success, and trivially for `col == 0`. Returns `false` as soon as a
/// pivot is not a finite positive number. In that case rows before the failing pivot
/// have already been overwritten with rows of `R`.
pub(crate) fn cholesky_upper_in_place(t: &mut [f64], col: usize, m: usize) -> bool {
    for piv in 0..col {
        // Squared pivot: A[piv][piv] minus the squares of the column above it in R.
        let pivot_sq = (0..piv).fold(t[piv * m + piv], |acc, k| {
            let r = t[k * m + piv];
            acc - r * r
        });
        // Rejects zero, negatives, infinities and NaN in one test.
        if !(pivot_sq.is_finite() && pivot_sq > 0.0) {
            return false;
        }
        let pivot = pivot_sq.sqrt();
        t[piv * m + piv] = pivot;

        // Fill the rest of row `piv` of R.
        for c in (piv + 1)..col {
            let num = (0..piv).fold(t[piv * m + c], |acc, k| {
                acc - t[k * m + piv] * t[k * m + c]
            });
            t[piv * m + c] = num / pivot;
        }
    }
    true
}
/// Solves `R x = b` in place by back substitution.
///
/// `R` is the upper triangle of the leading `col x col` corner of `j_upper`, stored
/// row-major with leading dimension `m`. Its strict lower triangle is never read. Only
/// `b[..col]` is read and overwritten.
///
/// The diagonal is not checked: a zero pivot yields infinities/NaN rather than an error,
/// so callers must validate it first.
pub(crate) fn solve_upper_tri(j_upper: &[f64], col: usize, m: usize, b: &mut [f64]) {
    // Walk rows bottom-up so every b[k] with k > row is already a solution component.
    for row in (0..col).rev() {
        let rhs = ((row + 1)..col).fold(b[row], |acc, k| acc - j_upper[row * m + k] * b[k]);
        b[row] = rhs / j_upper[row * m + row];
    }
}
/// Solves `R^T x = b` in place by forward substitution.
///
/// `R` is the upper triangle of the leading `col x col` corner of `j_upper`, stored
/// row-major with leading dimension `m`. `R^T[i][k]` is read as `j_upper[k * m + i]`, so
/// the strict lower triangle is never touched. Only `b[..col]` is read and overwritten.
///
/// The diagonal is not checked: a zero pivot yields infinities/NaN rather than an error,
/// so callers must validate it first.
pub(crate) fn solve_upper_tri_transposed(j_upper: &[f64], col: usize, m: usize, b: &mut [f64]) {
    // Walk rows top-down so every b[k] with k < row is already a solution component.
    for row in 0..col {
        let rhs = (0..row).fold(b[row], |acc, k| acc - j_upper[k * m + row] * b[k]);
        b[row] = rhs / j_upper[row * m + row];
    }
}
/// Applies the `2 * col` square "middle matrix" to `v`, writing the product into
/// `p[..2 * col]`.
///
/// Let `D` be the diagonal of `sy` and `L` its strictly lower triangle. `sy` is row-major
/// with leading dimension `m`, and its strict upper triangle is not read. Assume `wt`
/// holds the upper factor `R` produced by [`formt`], so that
/// `R^T R = theta * SS + L * D^{-1} * L^T`. Then the result satisfies
/// `[[-D, L^T], [L, theta * SS]] * p = v`. The computation is two block-triangular solves:
///
/// 1. `[[D^{1/2}, 0], [-L D^{-1/2}, R^T]] * y = v`
/// 2. `[[-D^{1/2}, D^{-1/2} L^T], [0, R]] * p = y`
///
/// The diagonal of `sy` is square-rooted without a check, so it is assumed positive.
///
/// When `col == 0` nothing is read or written.
///
/// # Errors
/// Returns [`BmvError::SingularJ`] if any diagonal entry of `wt` is zero or non-finite.
/// This is detected before `p` is modified.
pub(crate) fn bmv(
    sy: &[f64],
    wt: &[f64],
    col: usize,
    m: usize,
    v: &[f64],
    p: &mut [f64],
) -> Result<(), BmvError> {
    if col == 0 {
        return Ok(());
    }
    debug_assert!(v.len() >= 2 * col && p.len() >= 2 * col);

    // Validate the factor up front so that `p` is untouched on error.
    let singular = (0..col)
        .map(|i| wt[i * m + i])
        .any(|diag| diag == 0.0 || !diag.is_finite());
    if singular {
        return Err(BmvError::SingularJ);
    }

    let (v_top, v_bot) = (&v[..col], &v[col..2 * col]);
    let (p_top, p_tail) = p.split_at_mut(col);
    let p_bot = &mut p_tail[..col];
    let d = |i: usize| sy[i * m + i];

    // Part I, lower block row: R^T * y_bot = v_bot + L * D^{-1} * v_top.
    p_bot[0] = v_bot[0];
    for row in 1..col {
        let ld_v = (0..row).fold(0.0, |acc, k| acc + sy[row * m + k] * v_top[k] / d(k));
        p_bot[row] = v_bot[row] + ld_v;
    }
    solve_upper_tri_transposed(wt, col, m, p_bot);

    // Part I, upper block row: y_top = D^{-1/2} * v_top.
    for (row, (dst, &src)) in p_top.iter_mut().zip(v_top).enumerate() {
        *dst = src / d(row).sqrt();
    }

    // Part II, lower block row: R * p_bot = y_bot.
    solve_upper_tri(wt, col, m, p_bot);

    // Part II, upper block row: p_top = -D^{-1/2} * y_top + D^{-1} * L^T * p_bot.
    for (row, dst) in p_top.iter_mut().enumerate() {
        *dst = -*dst / d(row).sqrt();
    }
    for row in 0..col {
        let lt_p = ((row + 1)..col).fold(0.0, |acc, k| {
            acc + sy[k * m + row] * p_bot[k] / d(row)
        });
        p_top[row] += lt_p;
    }
    Ok(())
}
/// Failure reported by [`bmv`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum BmvError {
    /// A diagonal entry of the triangular factor `wt` was zero, infinite or NaN, so the
    /// triangular solves cannot proceed.
    SingularJ,
}
#[cfg(test)]
// The tests spell row-major offsets as `row * m + col` even when `row` is 0 or 1, so the
// intended (row, col) entry stays obvious. Clippy would otherwise flag `0 * m` and `1 * m`.
#[allow(clippy::identity_op, clippy::erasing_op)]
mod tests {
    use super::*;

    #[test]
    fn cholesky_round_trip_2x2() {
        let m = 3;
        let mut t = vec![0.0; m * m];
        t[0 * m + 0] = 4.0;
        t[0 * m + 1] = 6.0;
        t[1 * m + 1] = 13.0;
        let ok = cholesky_upper_in_place(&mut t, 2, m);
        assert!(ok);
        assert!((t[0 * m + 0] - 2.0).abs() < 1e-12);
        assert!((t[0 * m + 1] - 3.0).abs() < 1e-12);
        assert!((t[1 * m + 1] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn cholesky_3x3_exact_and_lower_triangle_untouched() {
        // A = R^T R with R = [[2, 1, 1], [0, 3, 2], [0, 0, 1]]; every step is exact.
        let m = 4;
        let mut t = vec![0.0; m * m];
        t[0 * m + 0] = 4.0;
        t[0 * m + 1] = 2.0;
        t[0 * m + 2] = 2.0;
        t[1 * m + 1] = 10.0;
        t[1 * m + 2] = 7.0;
        t[2 * m + 2] = 6.0;
        // Sentinels in the strict lower triangle must survive.
        t[1 * m + 0] = -7.0;
        t[2 * m + 0] = -7.0;
        t[2 * m + 1] = -7.0;
        assert!(cholesky_upper_in_place(&mut t, 3, m));
        assert_eq!(t[0 * m + 0], 2.0);
        assert_eq!(t[0 * m + 1], 1.0);
        assert_eq!(t[0 * m + 2], 1.0);
        assert_eq!(t[1 * m + 1], 3.0);
        assert_eq!(t[1 * m + 2], 2.0);
        assert_eq!(t[2 * m + 2], 1.0);
        assert_eq!(t[1 * m + 0], -7.0);
        assert_eq!(t[2 * m + 0], -7.0);
        assert_eq!(t[2 * m + 1], -7.0);
    }

    #[test]
    fn cholesky_rejects_non_pd() {
        let m = 2;
        let mut t = vec![0.0; m * m];
        t[0 * m + 0] = 1.0;
        t[0 * m + 1] = 2.0;
        t[1 * m + 1] = 1.0;
        assert!(!cholesky_upper_in_place(&mut t, 2, m));
    }

    #[test]
    fn cholesky_rejects_nan_pivot() {
        let m = 2;
        let mut t = vec![0.0; m * m];
        t[0 * m + 0] = f64::NAN;
        t[1 * m + 1] = 1.0;
        assert!(!cholesky_upper_in_place(&mut t, 2, m));
    }

    #[test]
    fn solve_upper_tri_inverts_apply() {
        let m = 2;
        let mut j_upper = vec![0.0; m * m];
        j_upper[0 * m + 0] = 2.0;
        j_upper[0 * m + 1] = 3.0;
        j_upper[1 * m + 1] = 2.0;
        let mut b = vec![5.0, 4.0];
        solve_upper_tri(&j_upper, 2, m, &mut b);
        assert!((b[0] - (-0.5)).abs() < 1e-12);
        assert!((b[1] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn solve_upper_tri_transposed_matches_forward_sub() {
        let m = 2;
        let mut j_upper = vec![0.0; m * m];
        j_upper[0 * m + 0] = 2.0;
        j_upper[0 * m + 1] = 3.0;
        j_upper[1 * m + 1] = 2.0;
        let mut b = vec![4.0, 11.0];
        solve_upper_tri_transposed(&j_upper, 2, m, &mut b);
        assert!((b[0] - 2.0).abs() < 1e-12);
        assert!((b[1] - 2.5).abs() < 1e-12);
    }

    #[test]
    fn formt_col_one_gives_theta_ss_then_sqrt() {
        let m = 3;
        let mut sy = vec![0.0; m * m];
        let mut ss = vec![0.0; m * m];
        sy[0] = 11.0;
        ss[0] = 5.0;
        let theta = 25.0 / 11.0;
        let mut wt = vec![0.0; m * m];
        formt(theta, &sy, &ss, 1, m, &mut wt).unwrap();
        assert!((wt[0] - (theta * 5.0).sqrt()).abs() < 1e-12);
    }

    #[test]
    fn formt_reports_not_positive_definite() {
        // L = 0 here, so T = theta * SS = [[1, 2], [2, 1]], which is indefinite.
        let m = 2;
        let mut sy = vec![0.0; m * m];
        sy[0 * m + 0] = 1.0;
        sy[1 * m + 1] = 1.0;
        let mut ss = vec![0.0; m * m];
        ss[0 * m + 0] = 1.0;
        ss[0 * m + 1] = 2.0;
        ss[1 * m + 1] = 1.0;
        let mut wt = vec![0.0; m * m];
        assert_eq!(
            formt(1.0, &sy, &ss, 2, m, &mut wt),
            Err(FormtError::NotPositiveDefinite)
        );
    }

    #[test]
    fn bmv_returns_zero_for_col_zero() {
        let sy = vec![0.0; 4];
        let wt = vec![0.0; 4];
        let v = vec![1.0, 2.0];
        let mut p = vec![99.0, 99.0];
        assert!(bmv(&sy, &wt, 0, 2, &v, &mut p).is_ok());
        assert_eq!(p, vec![99.0, 99.0]);
    }

    #[test]
    fn bmv_rejects_singular_factor_without_touching_p() {
        let m = 2;
        let mut sy = vec![0.0; m * m];
        sy[0 * m + 0] = 1.0;
        sy[1 * m + 1] = 1.0;
        let mut wt = vec![0.0; m * m];
        wt[0 * m + 0] = 1.0; // wt[1 * m + 1] stays 0.0 -> singular
        let v = vec![1.0; 2 * m];
        let mut p = vec![99.0; 2 * m];
        assert_eq!(bmv(&sy, &wt, 2, m, &v, &mut p), Err(BmvError::SingularJ));
        assert_eq!(p, vec![99.0; 2 * m]);
    }

    #[test]
    fn bmv_col_one_matches_2x2_inverse() {
        let m = 2;
        let mut sy = vec![0.0; m * m];
        let mut ss = vec![0.0; m * m];
        sy[0] = 11.0;
        ss[0] = 5.0;
        let theta = 25.0 / 11.0;
        let mut wt = vec![0.0; m * m];
        formt(theta, &sy, &ss, 1, m, &mut wt).unwrap();
        let v = vec![7.0, 9.0];
        let mut p = vec![0.0; 2];
        bmv(&sy, &wt, 1, m, &v, &mut p).unwrap();
        let d = 11.0;
        let exp_p1 = -v[0] / d;
        let exp_p2 = v[1] / (theta * 5.0);
        assert!((p[0] - exp_p1).abs() < 1e-12, "p1 = {} vs {}", p[0], exp_p1);
        assert!((p[1] - exp_p2).abs() < 1e-12, "p2 = {} vs {}", p[1], exp_p2);
    }

    #[test]
    fn bmv_col_two_inverts_middle_matrix() {
        // With wt from formt, bmv returns p such that K * p = v for
        // K = [[-D, L^T], [L, theta * SS]], D = diag(SY), L = strictly lower SY.
        let m = 2;
        let col = 2;
        // sy[0 * m + 1] (upper) and ss[1 * m + 0] (lower) are never read; poison them.
        let sy = vec![2.0, 100.0, 0.5, 3.0];
        let ss = vec![1.0, 0.2, -100.0, 1.5];
        let theta = 1.0;
        let mut wt = vec![0.0; m * m];
        formt(theta, &sy, &ss, col, m, &mut wt).unwrap();
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let mut p = vec![0.0; 2 * col];
        bmv(&sy, &wt, col, m, &v, &mut p).unwrap();

        let l = |i: usize, j: usize| if i > j { sy[i * m + j] } else { 0.0 };
        let ss_sym = |i: usize, j: usize| if i <= j { ss[i * m + j] } else { ss[j * m + i] };
        for r in 0..col {
            let top = -sy[r * m + r] * p[r]
                + (0..col).map(|c| l(c, r) * p[col + c]).sum::<f64>();
            let bottom = (0..col)
                .map(|c| l(r, c) * p[c] + theta * ss_sym(r, c) * p[col + c])
                .sum::<f64>();
            assert!((top - v[r]).abs() < 1e-12, "row {}: {} vs {}", r, top, v[r]);
            assert!(
                (bottom - v[col + r]).abs() < 1e-12,
                "row {}: {} vs {}",
                col + r,
                bottom,
                v[col + r]
            );
        }
    }
}