#![allow(clippy::needless_range_loop)]
/// Reasons why [`formk`] can fail.
///
/// Each variant identifies which of the two Cholesky factorizations performed
/// by `formk` encountered a pivot that was non-finite or not strictly
/// positive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FormkError {
    /// The factorization of the leading `col × col` block of `wn` failed.
    NotPositiveDefiniteFirst,
    /// The factorization of the trailing `col × col` block of `wn` failed.
    NotPositiveDefiniteSecond,
}

impl std::fmt::Display for FormkError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NotPositiveDefiniteFirst => {
                "formk: first Cholesky factorization failed (leading block is not positive definite)"
            }
            Self::NotPositiveDefiniteSecond => {
                "formk: second Cholesky factorization failed (trailing block is not positive definite)"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for FormkError {}
/// Updates the work matrix `wn1` and writes a two-stage Cholesky
/// factorization of the matrix assembled from it into `wn`.
///
/// NOTE(review): the structure matches the `formk` routine of the L-BFGS-B
/// reference implementation; this is presumably a port of it — confirm.
///
/// Both `wn` and `wn1` are `2m × 2m` matrices stored row-major with a row
/// stride of `2 * m`. `wn1` is viewed as four `m × m` blocks; `wn` only uses
/// its leading `2*col × 2*col` part, viewed as four `col × col` blocks.
///
/// The routine proceeds in these steps:
///
/// 1. If `updatd` is set, the row/column belonging to the newest pair
///    (index `col - 1`) is written into `wn1`. When additionally
///    `iupdat > m`, the existing entries are first shifted one position up
///    and to the left inside each block, dropping the oldest row/column.
///    * block (1,1): `Σ wy_last[i] * wy_j[i]` over `i = ind[k]`, `k < nsub`
///    * block (2,2): `Σ ws_last[i] * ws_j[i]` over `i = ind[k]`, `k >= nsub`
///    * block (2,1): row `last` uses `ws_last · wy_j` over `k >= nsub`, then
///      column `last` is overwritten with `ws_i · wy_last` over `k < nsub`.
/// 2. The remaining `upcl` older rows are corrected for the indices listed in
///    `indx2[..nenter]` and `indx2[ileave..n]`; the two groups enter with
///    opposite signs.
/// 3. The upper triangle of `wn` is assembled: block (1,1) is
///    `wn1(1,1) / theta` plus the diagonal of `sy`, block (2,2) is
///    `wn1(2,2) * theta`, and block (1,2) is the transpose of `wn1(2,1)` with
///    the strictly-lower part negated.
/// 4. Block (1,1) is Cholesky-factorized in place, block (1,2) is replaced by
///    the solution of the corresponding lower-triangular systems, the
///    products of those solved columns are added to block (2,2), and block
///    (2,2) is Cholesky-factorized in place.
///
/// # Parameters
///
/// * `wn`, `wn1` – work matrices described above (at least `4 * m * m` long).
/// * `m` – block size / row stride parameter; `col` – number of stored pairs.
/// * `theta` – scaling applied to blocks (1,1) and (2,2) when forming `wn`.
/// * `sy` – row-major matrix with stride `m`; only its diagonal is read.
/// * `ws_cols`, `wy_cols` – one slice per stored pair, indexed by variable.
/// * `ind`, `nsub` – index list split at `nsub`; `n` is `ind.len()`.
///   (Presumably free variables first, then active ones — confirm with caller.)
/// * `indx2`, `nenter`, `ileave` – index list whose first `nenter` entries and
///   entries `ileave..n` drive the correction in step 2.
/// * `iupdat`, `updatd` – control step 1 as described.
///
/// # Errors
///
/// * [`FormkError::NotPositiveDefiniteFirst`] if a pivot of the first
///   factorization is non-finite or not strictly positive.
/// * [`FormkError::NotPositiveDefiniteSecond`] likewise for the second one.
///
/// `wn1` has already been updated and `wn` partially overwritten when an
/// error is returned.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if any slice is shorter than the accesses
/// above require; the size relations are additionally checked with
/// `debug_assert!` in debug builds.
#[allow(clippy::too_many_arguments)] // flat workspace-style interface is relied on by callers
pub(crate) fn formk(
    wn: &mut [f64],
    wn1: &mut [f64],
    m: usize,
    col: usize,
    theta: f64,
    sy: &[f64],
    ws_cols: &[&[f64]],
    wy_cols: &[&[f64]],
    nsub: usize,
    ind: &[usize],
    nenter: usize,
    ileave: usize,
    indx2: &[usize],
    iupdat: u32,
    updatd: bool,
) -> Result<(), FormkError> {
    debug_assert!(col <= m);
    debug_assert!(col >= 1, "formk requires col ≥ 1");
    let n = ind.len();
    debug_assert!(nsub <= n);
    debug_assert!(ileave <= n);
    debug_assert!(nenter <= n);
    let two_m = 2 * m;
    debug_assert!(wn.len() >= two_m * two_m);
    debug_assert!(wn1.len() >= two_m * two_m);
    debug_assert_eq!(ws_cols.len(), col);
    debug_assert_eq!(wy_cols.len(), col);

    // Row-major offset of entry (row, column) in either work matrix.
    let at = |row: usize, column: usize| row * two_m + column;

    // Step 1: bring in the newest pair. `upcl` is the number of older rows
    // that still need the set-change correction of step 2.
    let upcl = if updatd {
        if iupdat as usize > m {
            shift_wn1_blocks(wn1, m);
        }
        let last = col - 1;
        let s_last = ws_cols[last];
        let y_last = wy_cols[last];
        for jy in 0..col {
            let s_j = ws_cols[jy];
            let y_j = wy_cols[jy];
            wn1[at(last, jy)] = indexed_dot(y_last, y_j, ind, 0..nsub);
            wn1[at(m + last, m + jy)] = indexed_dot(s_last, s_j, ind, nsub..n);
            wn1[at(m + last, jy)] = indexed_dot(s_last, y_j, ind, nsub..n);
        }
        // The column pass runs second on purpose: for `i == last` it replaces
        // the value the row pass just stored at (m + last, last).
        for i in 0..col {
            wn1[at(m + i, last)] = indexed_dot(ws_cols[i], y_last, ind, 0..nsub);
        }
        last
    } else {
        col
    };

    // Step 2a: correct the lower triangles of blocks (1,1) and (2,2).
    for iy in 0..upcl {
        for jy in 0..=iy {
            let yy_in = indexed_dot(wy_cols[iy], wy_cols[jy], indx2, 0..nenter);
            let ss_in = indexed_dot(ws_cols[iy], ws_cols[jy], indx2, 0..nenter);
            let yy_out = indexed_dot(wy_cols[iy], wy_cols[jy], indx2, ileave..n);
            let ss_out = indexed_dot(ws_cols[iy], ws_cols[jy], indx2, ileave..n);
            wn1[at(iy, jy)] += yy_in - yy_out;
            wn1[at(m + iy, m + jy)] += -ss_in + ss_out;
        }
    }

    // Step 2b: correct block (2,1); the sign flips below the diagonal.
    for iy in 0..upcl {
        for jy in 0..upcl {
            let cross_in = indexed_dot(ws_cols[iy], wy_cols[jy], indx2, 0..nenter);
            let cross_out = indexed_dot(ws_cols[iy], wy_cols[jy], indx2, ileave..n);
            let delta = if iy <= jy {
                cross_in - cross_out
            } else {
                -cross_in + cross_out
            };
            wn1[at(m + iy, jy)] += delta;
        }
    }

    // Step 3: assemble the upper triangle of `wn` from `wn1`.
    let col2 = 2 * col;
    for iy in 0..col {
        for jy in 0..=iy {
            wn[at(jy, iy)] = wn1[at(iy, jy)] / theta;
            wn[at(col + jy, col + iy)] = wn1[at(m + iy, m + jy)] * theta;
        }
        let is = col + iy;
        let is1 = m + iy;
        for jy in 0..col {
            let value = wn1[at(is1, jy)];
            wn[at(jy, is)] = if jy < iy { -value } else { value };
        }
        wn[at(iy, iy)] += sy[iy * m + iy];
    }

    // Step 4a: factorize block (1,1); the factor stays in the upper triangle.
    if !cholesky_upper_block(wn, two_m, 0, col) {
        return Err(FormkError::NotPositiveDefiniteFirst);
    }

    // Step 4b: forward substitution with the transposed factor, applied to
    // every column of block (1,2).
    for js in col..col2 {
        for i in 0..col {
            let rhs = (0..i).fold(wn[at(i, js)], |acc, k| acc - wn[at(k, i)] * wn[at(k, js)]);
            wn[at(i, js)] = rhs / wn[at(i, i)];
        }
    }

    // Step 4c: add the products of the solved columns to block (2,2).
    for is in col..col2 {
        for js in is..col2 {
            let acc = (0..col).fold(0.0_f64, |acc, k| acc + wn[at(k, is)] * wn[at(k, js)]);
            wn[at(is, js)] += acc;
        }
    }

    // Step 4d: factorize the updated block (2,2).
    if !cholesky_upper_block(wn, two_m, col, col) {
        return Err(FormkError::NotPositiveDefiniteSecond);
    }
    Ok(())
}

/// Sums `a[i] * b[i]` for `i = idx[k]`, with `k` running over `span`.
///
/// The accumulation starts at `0.0` and proceeds in increasing `k`; an empty
/// span (including `start >= end`) touches no slice and yields `0.0`.
fn indexed_dot(a: &[f64], b: &[f64], idx: &[usize], span: std::ops::Range<usize>) -> f64 {
    span.fold(0.0_f64, |acc, k| {
        let i = idx[k];
        acc + a[i] * b[i]
    })
}

/// Moves the stored entries of `wn1` (row stride `2 * m`) one position up and
/// to the left inside blocks (1,1), (2,2) and (2,1), overwriting the first
/// row/column of each block.
fn shift_wn1_blocks(wn1: &mut [f64], m: usize) {
    let two_m = 2 * m;
    for jy in 0..m - 1 {
        let js = m + jy;
        // Columns `jy` of block (1,1) and `js` of block (2,2), from the
        // diagonal downwards. The two copies touch disjoint entries.
        for k in 0..m - 1 - jy {
            wn1[(jy + k) * two_m + jy] = wn1[(jy + 1 + k) * two_m + (jy + 1)];
            wn1[(js + k) * two_m + js] = wn1[(js + 1 + k) * two_m + (js + 1)];
        }
        // Column `jy` of block (2,1).
        for k in 0..m - 1 {
            wn1[(m + k) * two_m + jy] = wn1[(m + 1 + k) * two_m + (jy + 1)];
        }
    }
}

/// Cholesky-factorizes, in place, the `dim × dim` block of `a` whose top-left
/// entry is (`origin`, `origin`); `a` is row-major with row stride `stride`.
///
/// Only the upper triangle of the block is read and written; on success it
/// holds the upper-triangular factor. Returns `false` as soon as a pivot is
/// non-finite or not strictly positive, leaving the block partially
/// factorized.
fn cholesky_upper_block(a: &mut [f64], stride: usize, origin: usize, dim: usize) -> bool {
    let at = |row: usize, column: usize| (origin + row) * stride + (origin + column);
    for j in 0..dim {
        let pivot = (0..j).fold(a[at(j, j)], |acc, k| {
            let r = a[at(k, j)];
            acc - r * r
        });
        if !pivot.is_finite() || pivot <= 0.0 {
            return false;
        }
        let djj = pivot.sqrt();
        a[at(j, j)] = djj;
        for i in (j + 1)..dim {
            let reduced = (0..j).fold(a[at(j, i)], |acc, k| acc - a[at(k, j)] * a[at(k, i)]);
            a[at(j, i)] = reduced / djj;
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-12;

    /// Row-major offset into the `2 × 2` work matrices used when `m == 1`.
    const fn at(row: usize, column: usize) -> usize {
        row * 2 + column
    }

    /// Asserts that `actual` lies within [`TOL`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Return value of one `formk` call together with the buffers it wrote.
    struct Outcome {
        result: Result<(), FormkError>,
        wn: Vec<f64>,
        wn1: Vec<f64>,
    }

    /// Runs `formk` with `m = col = 1` on two variables and a single freshly
    /// added pair (`s`, `y`).
    ///
    /// `ind` is `[0, 1]` split at `nsub`; no index enters or leaves
    /// (`nenter = 0`, `ileave = 2`), and both work matrices start zeroed.
    fn run_single_pair(s: [f64; 2], y: [f64; 2], theta: f64, nsub: usize) -> Outcome {
        let m = 1;
        let two_m = 2 * m;
        let sy = vec![s[0] * y[0] + s[1] * y[1]];
        let ws_cols: Vec<&[f64]> = vec![&s[..]];
        let wy_cols: Vec<&[f64]> = vec![&y[..]];
        let ind = [0_usize, 1];
        let indx2 = [0_usize; 2];
        let mut wn = vec![0.0_f64; two_m * two_m];
        let mut wn1 = vec![0.0_f64; two_m * two_m];
        let result = formk(
            &mut wn, &mut wn1, m, 1, theta, &sy, &ws_cols, &wy_cols, nsub, &ind, 0, 2, &indx2, 1,
            true,
        );
        Outcome { result, wn, wn1 }
    }

    #[test]
    fn col_one_two_vars_one_free_one_active_matches_hand_fixture() {
        let out = run_single_pair([1.0, 2.0], [1.0, 1.0], 1.0, 1);
        out.result.unwrap();
        assert_close(out.wn1[at(0, 0)], 1.0);
        assert_close(out.wn1[at(1, 1)], 4.0);
        assert_close(out.wn1[at(1, 0)], 1.0);
        assert_close(out.wn[at(0, 0)], 2.0);
        assert_close(out.wn[at(0, 1)], 0.5);
        assert_close(out.wn[at(1, 1)], 4.25_f64.sqrt());
    }

    #[test]
    fn col_one_both_free_zeroes_active_blocks() {
        let out = run_single_pair([1.0, 2.0], [1.0, 1.0], 1.0, 2);
        out.result.unwrap();
        assert_close(out.wn1[at(0, 0)], 2.0);
        assert_close(out.wn1[at(1, 1)], 0.0);
        assert_close(out.wn1[at(1, 0)], 3.0);
        assert_close(out.wn[at(0, 0)], 5.0_f64.sqrt());
        assert_close(out.wn[at(0, 1)], 3.0 / 5.0_f64.sqrt());
        assert_close(out.wn[at(1, 1)], (9.0_f64 / 5.0).sqrt());
    }

    #[test]
    fn singular_first_block_returns_error() {
        let out = run_single_pair([0.0, 0.0], [0.0, 0.0], 1.0, 2);
        assert_eq!(out.result, Err(FormkError::NotPositiveDefiniteFirst));
    }

    #[test]
    fn indefinite_second_block_returns_error() {
        // With theta = -1 the leading entry is 1 / (-1) + 3 = 2 > 0, so the
        // first factorization succeeds, but the trailing entry becomes
        // 4 * (-1) + (1 / sqrt(2))^2 = -3.5, which must be rejected.
        let out = run_single_pair([1.0, 2.0], [1.0, 1.0], -1.0, 1);
        assert_eq!(out.result, Err(FormkError::NotPositiveDefiniteSecond));
    }
}