use crate::dense::factor::{factor_frontal, BunchKaufmanParams, FrontalFactors};
use crate::dense::matrix::SymmetricMatrix;
use crate::dense::schur_kernel;
use crate::error::FeralError;
/// Edge length of the fixed-size dense block handled by this module.
///
/// `factor_block32` rejects matrices whose `n` differs from this value, and
/// the in-place update routines expect a flat buffer of at least
/// `BLOCK_SIZE * BLOCK_SIZE` entries.
pub const BLOCK_SIZE: usize = 32;
/// Factors a frontal matrix that must be exactly `BLOCK_SIZE x BLOCK_SIZE`.
///
/// This is a size-checked entry point: once the dimension is validated the
/// work is delegated unchanged to [`factor_frontal`], so results are whatever
/// that routine produces for the same arguments.
///
/// # Errors
///
/// Returns [`FeralError::InvalidInput`] when `matrix.n != BLOCK_SIZE`;
/// otherwise propagates the result of `factor_frontal`.
pub(crate) fn factor_block32(
    matrix: &SymmetricMatrix,
    ncol: usize,
    may_delay: bool,
    params: &BunchKaufmanParams,
) -> Result<FrontalFactors, FeralError> {
    let size = matrix.n;
    if size == BLOCK_SIZE {
        // Correct shape: hand straight over to the general frontal factorization.
        return factor_frontal(matrix, ncol, may_delay, params);
    }
    Err(FeralError::InvalidInput(format!(
        "factor_block32: matrix size {} != BLOCK_SIZE {}",
        size, BLOCK_SIZE
    )))
}
/// Applies, in place, the elimination step for a 1x1 pivot at column `p` of a
/// `BLOCK_SIZE x BLOCK_SIZE` block stored as a flat buffer.
///
/// Element `(i, j)` is addressed as `a[j * BLOCK_SIZE + i]`; the accesses made
/// directly by this function all have `i >= j`.
///
/// Steps:
/// 1. Read the pivot `d = a[p * n + p]`; if it is zero, return without
///    modifying `a`.
/// 2. Multiply the entries below the pivot in column `p` by `1 / d`.
/// 3. For every trailing column `j > p`, pass rows `j..n` of that column to a
///    `schur_kernel` routine together with the multiplier `a[p * n + j] * d`.
///    Columns are handled four at a time, then a pair, then a single column.
///
/// `fma` chooses between the FMA and `_nofma` kernel variants.
///
/// NOTE(review): the kernels live in `crate::dense::schur_kernel` and are not
/// visible here; presumably each computes `dst[i] -= alpha * a[p * n + i]`
/// over the destination rows — confirm against that module.
///
/// # Panics
///
/// Debug builds assert `a.len() >= BLOCK_SIZE * BLOCK_SIZE` and
/// `p < BLOCK_SIZE`; without those assertions an undersized buffer is caught
/// only by ordinary slice bounds checks.
pub(crate) fn update_1x1_block32(a: &mut [f64], p: usize, fma: bool) {
debug_assert!(a.len() >= BLOCK_SIZE * BLOCK_SIZE);
debug_assert!(p < BLOCK_SIZE);
let n = BLOCK_SIZE;
let d = a[p * n + p];
// A zero pivot cannot be inverted: leave the block untouched.
if d.abs() == 0.0 {
return;
}
let inv_d = 1.0 / d;
// Scale the sub-diagonal part of column p by 1/d.
for i in (p + 1)..n {
a[p * n + i] *= inv_d;
}
let mut j = p + 1;
// Main loop: four trailing columns (j .. j+3) per iteration.
while j + 3 < n {
// Multipliers are the scaled entries times d, i.e. the pre-scaling
// values up to rounding.
let alpha0 = a[p * n + j] * d;
let alpha1 = a[p * n + (j + 1)] * d;
let alpha2 = a[p * n + (j + 2)] * d;
let alpha3 = a[p * n + (j + 3)] * d;
// Skip the kernel call entirely when all four multipliers are zero.
if alpha0 != 0.0 || alpha1 != 0.0 || alpha2 != 0.0 || alpha3 != 0.0 {
// Carve the buffer into disjoint pieces: `before` holds columns < j
// (which includes pivot column p, since p < j); the next four
// n-length pieces are columns j, j+1, j+2 and j+3 (the last piece
// also carries everything after column j+3).
let (before, rest) = a.split_at_mut(j * n);
let (col_j, rest1) = rest.split_at_mut(n);
let (col_j1, rest2) = rest1.split_at_mut(n);
let (col_j2, col_j3_and_after) = rest2.split_at_mut(n);
// Each destination starts at its own diagonal row.
let dst0 = &mut col_j[j..n];
let dst1 = &mut col_j1[(j + 1)..n];
let dst2 = &mut col_j2[(j + 2)..n];
let dst3 = &mut col_j3_and_after[(j + 3)..n];
// NOTE(review): the positional arguments `p, 1, n, j` presumably mean
// first pivot column, pivot count, column stride and first
// destination row — confirm against the kernel signatures.
if fma {
schur_kernel::schur_panel_minus_fma_strided_quad(
dst0,
dst1,
dst2,
dst3,
before,
p,
1,
n,
j,
&[alpha0],
&[alpha1],
&[alpha2],
&[alpha3],
);
} else {
schur_kernel::schur_panel_minus_nofma_strided_quad(
dst0,
dst1,
dst2,
dst3,
before,
p,
1,
n,
j,
&[alpha0],
&[alpha1],
&[alpha2],
&[alpha3],
);
}
}
j += 4;
}
// Remainder: at least two columns left, so process one pair (j, j+1).
if j + 1 < n {
let alpha0 = a[p * n + j] * d;
let alpha1 = a[p * n + (j + 1)] * d;
if alpha0 != 0.0 || alpha1 != 0.0 {
// `before` = columns < j, `col_j` = column j, `after_j` starts at
// column j+1.
let (before, rest) = a.split_at_mut(j * n);
let (col_j, after_j) = rest.split_at_mut(n);
let dst0 = &mut col_j[j..n];
let dst1 = &mut after_j[(j + 1)..n];
if fma {
schur_kernel::schur_panel_minus_fma_strided_dual(
dst0,
dst1,
before,
p,
1,
n,
j,
&[alpha0],
&[alpha1],
);
} else {
schur_kernel::schur_panel_minus_nofma_strided_dual(
dst0,
dst1,
before,
p,
1,
n,
j,
&[alpha0],
&[alpha1],
);
}
}
j += 2;
}
// Remainder: one last column.
if j < n {
let alpha = a[p * n + j] * d;
if alpha != 0.0 {
let (before, rest) = a.split_at_mut(j * n);
// Rows j..n of pivot column p (source) and of column j (destination).
let src = &before[p * n + j..p * n + n];
let dst = &mut rest[j..n];
if fma {
schur_kernel::axpy_minus_unroll4(dst, src, alpha);
} else {
schur_kernel::axpy_minus_unroll4_nofma(dst, src, alpha);
}
}
}
}
/// Applies, in place, the elimination step for a 2x2 pivot occupying columns
/// `p` and `p + 1` of a `BLOCK_SIZE x BLOCK_SIZE` block stored as a flat buffer
/// (element `(i, j)` at `a[j * BLOCK_SIZE + i]`).
///
/// The pivot block is supplied by the caller as `[[d11, d21], [d21, d22]]`.
/// If its determinant is zero the function returns without touching `a`.
/// Otherwise:
/// 1. rows `p + 2 .. n` of the two pivot columns are replaced by those rows
///    multiplied by the inverse of the pivot block;
/// 2. for every trailing column `j >= p + 2`, rows `j..n` of that column are
///    passed to a two-vector `schur_kernel` routine with the weights
///    `D * [l_j0, l_j1]`.
///
/// `fma` chooses between the FMA and `_nofma` kernel variants.
///
/// NOTE(review): `axpy2_minus_unroll4*` is defined in
/// `crate::dense::schur_kernel`; presumably it computes
/// `dst -= w0 * src0 + w1 * src1` — confirm there.
///
/// # Panics
///
/// Debug builds assert `a.len() >= BLOCK_SIZE * BLOCK_SIZE` and
/// `p + 1 < BLOCK_SIZE`.
pub(crate) fn update_2x2_block32(
    a: &mut [f64],
    p: usize,
    d11: f64,
    d21: f64,
    d22: f64,
    fma: bool,
) {
    debug_assert!(a.len() >= BLOCK_SIZE * BLOCK_SIZE);
    debug_assert!(p + 1 < BLOCK_SIZE);

    const N: usize = BLOCK_SIZE;

    let determinant = d11 * d22 - d21 * d21;
    if determinant.abs() == 0.0 {
        // Singular pivot block: nothing can be eliminated.
        return;
    }
    let det_recip = 1.0 / determinant;

    // Flat offsets of the two pivot columns.
    let col0 = p * N;
    let col1 = (p + 1) * N;
    let first_trailing = p + 2;

    // Turn the two pivot columns into the corresponding L columns.
    for row in first_trailing..N {
        let (x0, x1) = (a[col0 + row], a[col1 + row]);
        a[col0 + row] = (d22 * x0 - d21 * x1) * det_recip;
        a[col1 + row] = (d11 * x1 - d21 * x0) * det_recip;
    }

    // Update each trailing column from the two L columns.
    for col in first_trailing..N {
        let (l0, l1) = (a[col0 + col], a[col1 + col]);
        let w0 = d11 * l0 + d21 * l1;
        let w1 = d21 * l0 + d22 * l1;

        // Columns < col (both pivot columns among them) are read-only here;
        // the destination column starts at the split point.
        let (pivot_side, trailing) = a.split_at_mut(col * N);
        let src0 = &pivot_side[col0 + col..col0 + N];
        let src1 = &pivot_side[col1 + col..col1 + N];
        let dst = &mut trailing[col..N];

        if fma {
            schur_kernel::axpy2_minus_unroll4(dst, src0, w0, src1, w1);
        } else {
            schur_kernel::axpy2_minus_unroll4_nofma(dst, src0, w0, src1, w1);
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds a `SymmetricMatrix` of size `BLOCK_SIZE` from the lower triangle of
/// a row-indexed table: `rows[i][j]` with `j <= i` is written to
/// `data[j * BLOCK_SIZE + i]`. Entries with `j > i` are ignored and the
/// corresponding slots stay `0.0`.
fn from_lower(rows: &[[f64; BLOCK_SIZE]; BLOCK_SIZE]) -> SymmetricMatrix {
    let mut data = vec![0.0f64; BLOCK_SIZE * BLOCK_SIZE];
    for (i, row) in rows.iter().enumerate() {
        // Only the first i + 1 entries of row i belong to the lower triangle.
        for (j, &value) in row.iter().enumerate().take(i + 1) {
            data[j * BLOCK_SIZE + i] = value;
        }
    }
    SymmetricMatrix {
        n: BLOCK_SIZE,
        data,
    }
}
/// Deterministic pseudo-random symmetric test matrix.
///
/// Values come from a SplitMix64 generator with a fixed seed, mapped to
/// `[0, 1)`. Diagonal entries are drawn from `[-1, 1)` and strictly-lower
/// entries from `[-0.5, 0.5)`, filled row by row (so the draw order is
/// `(0,0), (1,0), (1,1), (2,0), ...`). The matrix is intended to be
/// indefinite, as the name says; nothing here enforces that.
fn seeded_indefinite_32x32() -> SymmetricMatrix {
    const GOLDEN_GAMMA: u64 = 0x9E3779B97F4A7C15;
    let mut counter: u64 = GOLDEN_GAMMA;
    let mut uniform = || -> f64 {
        counter = counter.wrapping_add(GOLDEN_GAMMA);
        let mut bits = counter;
        bits = (bits ^ (bits >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
        bits = (bits ^ (bits >> 27)).wrapping_mul(0x94D049BB133111EB);
        bits ^= bits >> 31;
        // Top 53 bits scaled by 2^-53 give a double in [0, 1).
        ((bits >> 11) as f64) * f64::from_bits(0x3CA0_0000_0000_0000)
    };
    let mut rows = [[0.0f64; BLOCK_SIZE]; BLOCK_SIZE];
    for (i, row) in rows.iter_mut().enumerate() {
        for (j, entry) in row.iter_mut().enumerate().take(i + 1) {
            let u = uniform();
            *entry = if i == j { 2.0 * u - 1.0 } else { u - 0.5 };
        }
    }
    from_lower(&rows)
}
/// Asserts that two `FrontalFactors` agree exactly.
///
/// Dimension, inertia and permutation fields are compared with `assert_eq!`;
/// the floating-point arrays (`l`, `d_diag`, `d_subdiag`, `contrib`) are
/// compared by length and then element by element on their raw bit patterns,
/// so `0.0` vs `-0.0` or differing NaN payloads count as mismatches.
fn assert_factors_bit_equal(actual: &FrontalFactors, expected: &FrontalFactors) {
    assert_eq!(actual.nrow, expected.nrow, "nrow");
    assert_eq!(actual.ncol, expected.ncol, "ncol");
    assert_eq!(actual.nelim, expected.nelim, "nelim");
    assert_eq!(actual.n_delayed, expected.n_delayed, "n_delayed");
    assert_eq!(actual.inertia, expected.inertia, "inertia");
    assert_eq!(actual.perm, expected.perm, "perm");
    assert_eq!(actual.perm_inv, expected.perm_inv, "perm_inv");

    assert_eq!(actual.l.len(), expected.l.len(), "L length");
    for (k, (got, want)) in actual.l.iter().zip(expected.l.iter()).enumerate() {
        assert_eq!(
            got.to_bits(),
            want.to_bits(),
            "L[{k}] mismatch: actual={} expected={}",
            got,
            want
        );
    }

    assert_eq!(actual.d_diag.len(), expected.d_diag.len(), "d_diag length");
    for (k, (got, want)) in actual.d_diag.iter().zip(expected.d_diag.iter()).enumerate() {
        assert_eq!(got.to_bits(), want.to_bits(), "d_diag[{k}] mismatch");
    }

    assert_eq!(
        actual.d_subdiag.len(),
        expected.d_subdiag.len(),
        "d_subdiag length"
    );
    for (k, (got, want)) in actual
        .d_subdiag
        .iter()
        .zip(expected.d_subdiag.iter())
        .enumerate()
    {
        assert_eq!(got.to_bits(), want.to_bits(), "d_subdiag[{k}] mismatch");
    }

    assert_eq!(
        actual.contrib.len(),
        expected.contrib.len(),
        "contrib length"
    );
    for (k, (got, want)) in actual
        .contrib
        .iter()
        .zip(expected.contrib.iter())
        .enumerate()
    {
        assert_eq!(got.to_bits(), want.to_bits(), "contrib[{k}] mismatch");
    }
}
/// Returns two independent copies of `src` (same `n`, cloned `data`) so the
/// reference and block code paths can each be given their own matrix.
fn dup_lower(src: &SymmetricMatrix) -> (SymmetricMatrix, SymmetricMatrix) {
    let copy = || SymmetricMatrix {
        n: src.n,
        data: src.data.clone(),
    };
    (copy(), copy())
}
/// Deterministic pseudo-random `BLOCK_SIZE x BLOCK_SIZE` block as a flat
/// buffer (element `(i, j)` at index `j * BLOCK_SIZE + i`).
///
/// Values come from a SplitMix64 generator started at `seed`, mapped to
/// `[0, 1)`. Columns are filled in order, each from its diagonal row downward:
/// diagonal entries land in `[-1, 1)`, entries below the diagonal in
/// `[-0.5, 0.5)`, and entries above the diagonal stay `0.0`.
fn seeded_block_1024(seed: u64) -> Vec<f64> {
    let mut counter: u64 = seed;
    let mut uniform = || -> f64 {
        counter = counter.wrapping_add(0x9E3779B97F4A7C15);
        let mut bits = counter;
        bits = (bits ^ (bits >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
        bits = (bits ^ (bits >> 27)).wrapping_mul(0x94D049BB133111EB);
        bits ^= bits >> 31;
        // Top 53 bits scaled by 2^-53 give a double in [0, 1).
        ((bits >> 11) as f64) * f64::from_bits(0x3CA0_0000_0000_0000)
    };
    let mut block = vec![0.0f64; BLOCK_SIZE * BLOCK_SIZE];
    for (col, column) in block.chunks_exact_mut(BLOCK_SIZE).enumerate() {
        // Skip the rows above the diagonal; they keep their initial 0.0.
        for (row, slot) in column.iter_mut().enumerate().skip(col) {
            let u = uniform();
            *slot = if row == col { 2.0 * u - 1.0 } else { u - 0.5 };
        }
    }
    block
}
/// Asserts that two `f64` buffers have the same length and identical bit
/// patterns element by element (so `0.0` and `-0.0` differ, and equal NaN
/// encodings match). `context` is prefixed to every failure message.
fn assert_blocks_bit_equal(actual: &[f64], expected: &[f64], context: &str) {
    assert_eq!(actual.len(), expected.len(), "{}: block length", context);
    for (k, (got, want)) in actual.iter().zip(expected).enumerate() {
        assert_eq!(
            got.to_bits(),
            want.to_bits(),
            "{}: a[{k}] mismatch (actual={}, expected={})",
            context,
            got,
            want,
        );
    }
}
/// The block32 1x1 update at the first pivot must reproduce the generic
/// `do_1x1_update` bit for bit (non-FMA path).
#[test]
fn update_1x1_block32_matches_do_1x1_update_at_p0() {
    let seed_block = seeded_block_1024(0xA5A5_5A5A_DEAD_BEEF);
    let (mut reference, mut blocked) = (seed_block.clone(), seed_block);
    crate::dense::factor::do_1x1_update(&mut reference, BLOCK_SIZE, 0, false);
    update_1x1_block32(&mut blocked, 0, false);
    assert_blocks_bit_equal(&blocked, &reference, "update_1x1_block32 at p=0");
}
/// After eliminating pivots 0..5 with the generic routine, the block32 update
/// at p = 5 must match `do_1x1_update` bit for bit (non-FMA path).
#[test]
fn update_1x1_block32_matches_do_1x1_update_at_p5() {
    // Bring the block to the state it has just before pivot 5.
    let mut staged = seeded_block_1024(0x1234_5678_9ABC_DEF0);
    for earlier in 0..5 {
        crate::dense::factor::do_1x1_update(&mut staged, BLOCK_SIZE, earlier, false);
    }
    let (mut reference, mut blocked) = (staged.clone(), staged);
    crate::dense::factor::do_1x1_update(&mut reference, BLOCK_SIZE, 5, false);
    update_1x1_block32(&mut blocked, 5, false);
    assert_blocks_bit_equal(&blocked, &reference, "update_1x1_block32 at p=5");
}
/// Near the end of the block (p = 30, a single trailing column) the block32
/// update must still match `do_1x1_update` bit for bit (non-FMA path).
#[test]
fn update_1x1_block32_matches_do_1x1_update_at_p30() {
    // Bring the block to the state it has just before pivot 30.
    let mut staged = seeded_block_1024(0xF00D_FACE_C0FF_EE00);
    for earlier in 0..30 {
        crate::dense::factor::do_1x1_update(&mut staged, BLOCK_SIZE, earlier, false);
    }
    let (mut reference, mut blocked) = (staged.clone(), staged);
    crate::dense::factor::do_1x1_update(&mut reference, BLOCK_SIZE, 30, false);
    update_1x1_block32(&mut blocked, 30, false);
    assert_blocks_bit_equal(&blocked, &reference, "update_1x1_block32 at p=30");
}
/// With the pivot entry forced to zero, the block32 update must behave exactly
/// like `do_1x1_update` on the same input (non-FMA path).
#[test]
fn update_1x1_block32_zero_pivot_is_noop() {
    let mut seed_block = seeded_block_1024(0xBADD_F00D_DEAD_BEEF);
    // Zero the diagonal entry of column 2 before taking the two working copies.
    seed_block[2 * BLOCK_SIZE + 2] = 0.0;
    let (mut reference, mut blocked) = (seed_block.clone(), seed_block);
    crate::dense::factor::do_1x1_update(&mut reference, BLOCK_SIZE, 2, false);
    update_1x1_block32(&mut blocked, 2, false);
    assert_blocks_bit_equal(&blocked, &reference, "update_1x1_block32 zero pivot");
}
/// The block32 2x2 update at the first pivot pair must reproduce the generic
/// `do_2x2_update` bit for bit (non-FMA path).
#[test]
fn update_2x2_block32_matches_do_2x2_update_at_p0() {
    let seed_block = seeded_block_1024(0xCAFE_BABE_1234_5678);
    let (mut reference, mut blocked) = (seed_block.clone(), seed_block);
    let (d11, d21, d22) = (2.5, -0.75, 1.125);
    crate::dense::factor::do_2x2_update(&mut reference, BLOCK_SIZE, 0, d11, d21, d22, false);
    update_2x2_block32(&mut blocked, 0, d11, d21, d22, false);
    assert_blocks_bit_equal(&blocked, &reference, "update_2x2_block32 at p=0");
}
/// The block32 2x2 update must match `do_2x2_update` bit for bit at an
/// interior pivot (p = 10) and near the end of the block (p = 28), each
/// starting from the same fresh seeded block (non-FMA path).
#[test]
fn update_2x2_block32_matches_do_2x2_update_at_p10_and_p28() {
    let base = seeded_block_1024(0x0123_4567_89AB_CDEF);
    let (d11, d21, d22) = (-1.5, 0.25, 0.875);

    {
        let (mut reference, mut blocked) = (base.clone(), base.clone());
        crate::dense::factor::do_2x2_update(&mut reference, BLOCK_SIZE, 10, d11, d21, d22, false);
        update_2x2_block32(&mut blocked, 10, d11, d21, d22, false);
        assert_blocks_bit_equal(&blocked, &reference, "update_2x2_block32 at p=10");
    }

    {
        let (mut reference, mut blocked) = (base.clone(), base);
        crate::dense::factor::do_2x2_update(&mut reference, BLOCK_SIZE, 28, d11, d21, d22, false);
        update_2x2_block32(&mut blocked, 28, d11, d21, d22, false);
        assert_blocks_bit_equal(&blocked, &reference, "update_2x2_block32 at p=28");
    }
}
/// With a singular pivot block (`d11 * d22 - d21^2 == 0`), the block32 update
/// must behave exactly like `do_2x2_update` on the same input (non-FMA path).
#[test]
fn update_2x2_block32_singular_is_noop() {
    let seed_block = seeded_block_1024(0xDEAD_BEEF_FEED_FACE);
    let (mut reference, mut blocked) = (seed_block.clone(), seed_block);
    // All-ones pivot block: determinant is exactly zero.
    let (d11, d21, d22) = (1.0, 1.0, 1.0);
    crate::dense::factor::do_2x2_update(&mut reference, BLOCK_SIZE, 5, d11, d21, d22, false);
    update_2x2_block32(&mut blocked, 5, d11, d21, d22, false);
    assert_blocks_bit_equal(&blocked, &reference, "update_2x2_block32 singular");
}
/// `factor_block32` must refuse a matrix whose size is not `BLOCK_SIZE`.
#[test]
fn factor_block32_rejects_wrong_size() {
    let m = SymmetricMatrix::zeros(16);
    let params = BunchKaufmanParams::default();
    // Pass the parameters by reference (the original text had a mis-encoded
    // character in place of `&para`, which does not lex as Rust).
    let res = factor_block32(&m, 16, false, &params);
    assert!(res.is_err());
}
/// On a diagonal matrix with entries `1, 2, ..., BLOCK_SIZE`, `factor_block32`
/// must return factors bit-identical to calling `factor_frontal` directly.
#[test]
fn factor_block32_diagonal_spd_matches_scalar() {
    let mut rows = [[0.0f64; BLOCK_SIZE]; BLOCK_SIZE];
    for (i, row) in rows.iter_mut().enumerate() {
        row[i] = (i as f64) + 1.0;
    }
    let src = from_lower(&rows);
    let (a, b) = dup_lower(&src);
    let params = BunchKaufmanParams::default();
    let scalar = factor_frontal(&a, BLOCK_SIZE, false, &params).expect("scalar");
    let block = factor_block32(&b, BLOCK_SIZE, false, &params).expect("block32");
    assert_factors_bit_equal(&block, &scalar);
}
/// On the seeded pseudo-random symmetric matrix, `factor_block32` must return
/// factors bit-identical to calling `factor_frontal` directly.
#[test]
fn factor_block32_seeded_indefinite_matches_scalar() {
    let src = seeded_indefinite_32x32();
    let (a, b) = dup_lower(&src);
    let params = BunchKaufmanParams::default();
    let scalar = factor_frontal(&a, BLOCK_SIZE, false, &params).expect("scalar");
    let block = factor_block32(&b, BLOCK_SIZE, false, &params).expect("block32");
    assert_factors_bit_equal(&block, &scalar);
}
}