use crate::dense::DenseMat;
use crate::Inertia;
/// Computes `y[i] -= alpha * x[i]` for every `i < x.len()` using NEON fused
/// multiply-subtract (`vfmsq_f64`), unrolled to 8 lanes with a 2-lane and a scalar tail.
///
/// Entries of `y` past `x.len()` are left untouched.
///
/// # Panics
/// Panics if `y.len() < x.len()`.
#[cfg(target_arch = "aarch64")]
#[inline]
fn axpy_neg(alpha: f64, x: &[f64], y: &mut [f64]) {
    use std::arch::aarch64::*;
    let n = x.len();
    // A hard length check: a `debug_assert!` disappears in release builds, which
    // would let the raw-pointer stores below write past the end of a short `y`.
    let y = &mut y[..n];
    let xp = x.as_ptr();
    let yp = y.as_mut_ptr();
    let mut i = 0;
    // SAFETY: the unrolled loop only runs while `i + 8 <= n` and touches lanes
    // `i..i + 8`; the pairwise loop only runs while `i + 2 <= n` and touches lanes
    // `i..i + 2`. `x` has exactly `n` elements and `y` was re-sliced to exactly `n`,
    // with `xp`/`yp` pointing at their starts, so every access is in bounds and
    // `f64`-aligned. `x` (shared) and `y` (unique) cannot overlap. The intrinsics
    // need the `neon` target feature, which Rust's standard aarch64 targets enable
    // by default.
    unsafe {
        let av = vdupq_n_f64(alpha);
        while i + 8 <= n {
            let x0 = vld1q_f64(xp.add(i));
            let x1 = vld1q_f64(xp.add(i + 2));
            let x2 = vld1q_f64(xp.add(i + 4));
            let x3 = vld1q_f64(xp.add(i + 6));
            let y0 = vld1q_f64(yp.add(i));
            let y1 = vld1q_f64(yp.add(i + 2));
            let y2 = vld1q_f64(yp.add(i + 4));
            let y3 = vld1q_f64(yp.add(i + 6));
            vst1q_f64(yp.add(i), vfmsq_f64(y0, x0, av));
            vst1q_f64(yp.add(i + 2), vfmsq_f64(y1, x1, av));
            vst1q_f64(yp.add(i + 4), vfmsq_f64(y2, x2, av));
            vst1q_f64(yp.add(i + 6), vfmsq_f64(y3, x3, av));
            i += 8;
        }
        while i + 2 <= n {
            let xv = vld1q_f64(xp.add(i));
            let yv = vld1q_f64(yp.add(i));
            vst1q_f64(yp.add(i), vfmsq_f64(yv, xv, av));
            i += 2;
        }
    }
    // At most one element remains.
    if i < n {
        y[i] -= alpha * x[i];
    }
}
/// Computes `y[i] -= a0 * x0[i] + a1 * x1[i]` for every `i < x0.len()`.
///
/// Uses NEON fused multiply-subtract four lanes at a time. The vector path
/// subtracts the two products one after the other; the scalar tail subtracts
/// their sum. Entries of `x1`/`y` past `x0.len()` are ignored or left untouched.
///
/// # Panics
/// Panics if `x1.len() < x0.len()` or `y.len() < x0.len()`.
#[cfg(target_arch = "aarch64")]
#[inline]
fn axpy2_neg(a0: f64, a1: f64, x0: &[f64], x1: &[f64], y: &mut [f64]) {
    use std::arch::aarch64::*;
    let n = x0.len();
    // Hard length checks: a `debug_assert!` vanishes in release builds, which would
    // let the raw-pointer loads/stores below run past the end of `x1` or `y`.
    let x1 = &x1[..n];
    let y = &mut y[..n];
    let (p0, p1, py) = (x0.as_ptr(), x1.as_ptr(), y.as_mut_ptr());
    let mut i = 0;
    // SAFETY: the loop only runs while `i + 4 <= n` and touches lanes `i..i + 4`
    // of slices that all have exactly `n` elements (after the re-slicing above), so
    // every access is in bounds and `f64`-aligned. `y` is uniquely borrowed and so
    // cannot alias `x0`/`x1`. The intrinsics need the `neon` target feature, which
    // Rust's standard aarch64 targets enable by default.
    unsafe {
        let av0 = vdupq_n_f64(a0);
        let av1 = vdupq_n_f64(a1);
        while i + 4 <= n {
            let x0a = vld1q_f64(p0.add(i));
            let x0b = vld1q_f64(p0.add(i + 2));
            let x1a = vld1q_f64(p1.add(i));
            let x1b = vld1q_f64(p1.add(i + 2));
            let ya = vld1q_f64(py.add(i));
            let yb = vld1q_f64(py.add(i + 2));
            let ya = vfmsq_f64(vfmsq_f64(ya, x0a, av0), x1a, av1);
            let yb = vfmsq_f64(vfmsq_f64(yb, x0b, av0), x1b, av1);
            vst1q_f64(py.add(i), ya);
            vst1q_f64(py.add(i + 2), yb);
            i += 4;
        }
    }
    // Scalar tail: fewer than four elements remain.
    for j in i..n {
        y[j] -= a0 * x0[j] + a1 * x1[j];
    }
}
/// Computes `y[i] -= alpha * x[i]` for every `i < x.len()` (portable fallback).
///
/// Entries of `y` past `x.len()` are left untouched.
///
/// # Panics
/// Panics if `y.len() < x.len()`.
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn axpy_neg(alpha: f64, x: &[f64], y: &mut [f64]) {
    // Re-slicing performs the single length check up front (instead of one bounds
    // check per iteration), which also lets LLVM vectorize the loop.
    let y = &mut y[..x.len()];
    for (yi, &xi) in y.iter_mut().zip(x) {
        *yi -= alpha * xi;
    }
}
/// Computes `y[i] -= a0 * x0[i] + a1 * x1[i]` for every `i < x0.len()` (portable fallback).
///
/// Entries of `x1`/`y` past `x0.len()` are ignored or left untouched.
///
/// # Panics
/// Panics if `x1.len() < x0.len()` or `y.len() < x0.len()`.
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn axpy2_neg(a0: f64, a1: f64, x0: &[f64], x1: &[f64], y: &mut [f64]) {
    let n = x0.len();
    // Hoisted length checks replace per-iteration bounds checks.
    let (x1, y) = (&x1[..n], &mut y[..n]);
    for ((yi, &u), &v) in y.iter_mut().zip(x0).zip(x1) {
        *yi -= a0 * u + a1 * v;
    }
}
/// Factorization produced by `dense_ldlt_bunch_kaufman`.
///
/// As produced there, the fields satisfy `A[perm[i]][perm[j]] ≈ (L·D·Lᵀ)[i][j]`,
/// where `D` is block diagonal with 1×1 and symmetric 2×2 blocks.
#[derive(Debug, Clone)]
pub struct BunchKaufmanResult {
/// Dense `n × n` factor `L` (row-major in `l.data`): unit diagonal, entries only
/// below the diagonal, and a zero sub-diagonal entry inside each 2×2 block.
pub l: DenseMat,
/// Diagonal of `D`; for a 2×2 block at `k` this holds both `D[k][k]` and `D[k+1][k+1]`.
pub d_diag: Vec<f64>,
/// `d_offdiag[k]` is the off-diagonal of a 2×2 block occupying positions `k, k + 1`
/// and is `0.0` everywhere else. Readers in this file treat
/// `|d_offdiag[k]| > ZERO_PIVOT_TOL` as the marker for such a block.
pub d_offdiag: Vec<f64>,
/// `perm[i]` is the row/column of the original matrix now at position `i`.
pub perm: Vec<usize>,
/// Inverse of `perm`: `perm_inv[perm[i]] == i`.
pub perm_inv: Vec<usize>,
/// Eigenvalue sign counts of `D`, as computed by `compute_inertia`.
pub inertia: Inertia,
}
/// Bunch–Kaufman pivot-growth constant; 0.6404 ≈ (1 + √17) / 8, the value used by
/// the classic algorithm.
const BK_ALPHA: f64 = 0.6404;
/// Absolute (not norm-scaled) cutoff below which pivots, 2×2 determinants,
/// 2×2 block off-diagonals, and eigenvalues are treated as numerically zero.
const ZERO_PIVOT_TOL: f64 = 1e-12;
/// Default relative pivot threshold. NOTE(review): not referenced in this file;
/// presumably the intended `threshold` argument for `find_pivot_threshold` — confirm at call sites.
pub const DEFAULT_PIVOT_THRESHOLD: f64 = 0.01;
/// Pivot decision returned by `find_pivot_threshold` for one elimination step.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PivotResult {
/// 1×1 pivot on the diagonal entry at this index (step `k` itself, or row `r`
/// holding the largest sub-diagonal entry of column `k`).
OneByOne(usize),
/// 2×2 pivot formed by indices `(k, r)`.
TwoByTwo(usize, usize),
/// No candidate passed the stability tests; presumably the caller defers this
/// column (handled outside this file — confirm).
Delayed,
}
/// Selects the Bunch–Kaufman pivot for elimination step `k` of the dense,
/// row-major `n × n` matrix `a` (both triangles of the trailing block are read).
///
/// Let `lambda = max_{i>k} |a[i][k]|`, attained first at row `r`, and
/// `sigma = max_{j>=k, j!=r} |a[r][j]|`. Returns `(size, p1, p2)`:
///
/// * `(1, k, k)` for the last column, an all-zero column,
///   `|a[k][k]| >= BK_ALPHA·lambda`, or `|a[k][k]|·sigma >= BK_ALPHA·lambda²`;
/// * otherwise `(1, r, r)` when `|a[r][r]| >= BK_ALPHA·sigma`;
/// * otherwise `(2, k, r)`, a 2×2 pivot on indices `k` and `r`.
fn find_pivot(a: &[f64], n: usize, k: usize) -> (usize, usize, usize) {
    let keep_k = (1, k, k);
    if k == n - 1 {
        return keep_k;
    }
    let diag_k = a[k * n + k].abs();

    // Largest sub-diagonal magnitude in column k; ties keep the earliest row.
    let (lambda, r) = ((k + 1)..n).fold((0.0f64, k), |(best, at), row| {
        let mag = a[row * n + k].abs();
        if mag > best {
            (mag, row)
        } else {
            (best, at)
        }
    });

    // An all-zero column, or a sufficiently dominant diagonal, is pivoted in place.
    let zero_column = lambda == 0.0 && diag_k == 0.0;
    if zero_column || diag_k >= BK_ALPHA * lambda {
        return keep_k;
    }

    // Largest off-diagonal magnitude in row r across the trailing columns.
    let sigma = (k..n)
        .filter(|&col| col != r)
        .map(|col| a[r * n + col].abs())
        .fold(0.0f64, |best, mag| if mag > best { mag } else { best });

    if diag_k * sigma >= BK_ALPHA * lambda * lambda {
        return keep_k;
    }
    if a[r * n + r].abs() >= BK_ALPHA * sigma {
        (1, r, r)
    } else {
        (2, k, r)
    }
}
/// Threshold-pivoting variant of Bunch–Kaufman pivot selection for step `k` of
/// the dense, row-major `n × n` symmetric matrix `a` (both triangles are read).
///
/// With `lambda = max_{i>k} |a[i][k]|`, attained first at row `r`, and
/// `sigma = max_{j>=k, j!=r} |a[r][j]|`, the decision ladder is:
///
/// 1. Last column (`k + 1 == n`): `OneByOne(k)` unless `|a[k][k]| < ZERO_PIVOT_TOL`,
///    in which case `Delayed`.
/// 2. Column `k` is exactly zero from the diagonal down: `Delayed`.
/// 3. `|a[k][k]| >= threshold·lambda`: `OneByOne(k)`.
/// 4. `|a[k][k]|·sigma >= BK_ALPHA·lambda²`: `Delayed`. Classic Bunch–Kaufman
///    would accept `k` here, but step 3 has already rejected it on the threshold test.
/// 5. `|a[r][r]| >= threshold·sigma`: `OneByOne(r)`.
/// 6. `TwoByTwo(k, r)` when `max(|a[k][r]|, |a[r][k]|) > ZERO_PIVOT_TOL` and the
///    block determinant satisfies `|det| >= threshold·m²`, where `m` is the largest
///    magnitude among `|a[k][k]|`, `|a[r][r]|`, and that off-diagonal; otherwise `Delayed`.
///
/// Expects `k < n` and `a.len() >= n * n`.
///
/// NOTE(review): step 2 tests for an exactly zero column, while step 1 uses
/// `ZERO_PIVOT_TOL`; confirm that asymmetry is intended.
pub fn find_pivot_threshold(
    a: &[f64],
    n: usize,
    k: usize,
    threshold: f64,
) -> PivotResult {
    let akk = a[k * n + k].abs();

    // Written as `k + 1 == n` so `n == 0` cannot underflow.
    if k + 1 == n {
        // Nothing below the diagonal to pair with.
        return if akk < ZERO_PIVOT_TOL {
            PivotResult::Delayed
        } else {
            PivotResult::OneByOne(k)
        };
    }

    // lambda: largest sub-diagonal magnitude in column k; r: first row attaining it.
    let mut lambda = 0.0f64;
    let mut r = k;
    for i in (k + 1)..n {
        let v = a[i * n + k].abs();
        if v > lambda {
            lambda = v;
            r = i;
        }
    }

    if lambda == 0.0 && akk == 0.0 {
        return PivotResult::Delayed;
    }
    if akk >= threshold * lambda {
        return PivotResult::OneByOne(k);
    }

    // sigma: largest off-diagonal magnitude in row r over the active columns.
    let sigma = (k..n)
        .filter(|&j| j != r)
        .map(|j| a[r * n + j].abs())
        .fold(0.0f64, |best, v| if v > best { v } else { best });

    if akk * sigma >= BK_ALPHA * lambda * lambda {
        // The only pivot Bunch–Kaufman would accept here is `k`, which already
        // failed `akk >= threshold * lambda` above, so the column is delayed.
        return PivotResult::Delayed;
    }

    let arr = a[r * n + r].abs();
    if arr >= threshold * sigma {
        return PivotResult::OneByOne(r);
    }

    let akr = a[k * n + r].abs().max(a[r * n + k].abs());
    if akr > ZERO_PIVOT_TOL {
        let d_kk = a[k * n + k];
        let d_kr = a[r * n + k];
        let d_rr = a[r * n + r];
        let det = (d_kk * d_rr - d_kr * d_kr).abs();
        let max_elem = akk.max(arr).max(akr);
        if det >= threshold * max_elem * max_elem {
            return PivotResult::TwoByTwo(k, r);
        }
    }
    PivotResult::Delayed
}
/// Applies the symmetric exchange of indices `p` and `q` to the trailing part of
/// a dense, row-major `n × n` matrix.
///
/// Rows `p` and `q` are swapped over columns `start..n`, then columns `p` and `q`
/// are swapped over rows `start..n`. Rows and columns before `start` are left
/// untouched. Does nothing when `p == q`.
#[inline]
fn swap_rows_cols_from(a: &mut [f64], n: usize, p: usize, q: usize, start: usize) {
    if p != q {
        let (row_p, row_q) = (p * n, q * n);
        (start..n).for_each(|col| a.swap(row_p + col, row_q + col));
        (start..n).for_each(|row| {
            let base = row * n;
            a.swap(base + p, base + q);
        });
    }
}
/// Counts the positive, negative, and numerically-zero eigenvalues of the
/// block-diagonal matrix `D` described by `d_diag` and `d_offdiag`.
///
/// Position `k` starts the symmetric 2×2 block
/// `[[d_diag[k], d_offdiag[k]], [d_offdiag[k], d_diag[k + 1]]]` when `k + 1 < n`
/// and `|d_offdiag[k]| > ZERO_PIVOT_TOL`; otherwise `d_diag[k]` is a 1×1 block.
/// An eigenvalue inside `[-ZERO_PIVOT_TOL, ZERO_PIVOT_TOL]` is counted as zero.
///
/// The 2×2 eigenvalues are computed without cancellation:
/// - The discriminant is `hypot(a − c, 2b)` rather than `sqrt(trace² − 4·det)`.
/// - The smaller-magnitude eigenvalue is `det / λ_big` rather than `(trace ∓ disc) / 2`.
///
/// Its sign therefore stays reliable when it is many orders of magnitude smaller than
/// the other eigenvalue. For example, `a = 1e8, b = 1e-5, c = 1e-10` has two positive
/// eigenvalues, and the naive formula rounds the small one to exactly zero.
///
/// Expects `d_diag.len() >= n` and, for `n > 0`, `d_offdiag.len() >= n - 1`;
/// indexing panics on any entry it needs that is missing.
pub fn compute_inertia(d_diag: &[f64], d_offdiag: &[f64], n: usize) -> Inertia {
    let mut positive = 0;
    let mut negative = 0;
    let mut zero = 0;
    let mut count = |eig: f64| {
        if eig > ZERO_PIVOT_TOL {
            positive += 1;
        } else if eig < -ZERO_PIVOT_TOL {
            negative += 1;
        } else {
            zero += 1;
        }
    };

    let mut k = 0;
    while k < n {
        if k + 1 < n && d_offdiag[k].abs() > ZERO_PIVOT_TOL {
            let (a, b, c) = (d_diag[k], d_offdiag[k], d_diag[k + 1]);
            let trace = a + c;
            let det = a * c - b * b;
            // sqrt((a - c)² + 4b²) equals sqrt(trace² - 4·det), but it avoids both the
            // cancellation of two nearly equal large squares and overflow of trace².
            let disc = (a - c).hypot(2.0 * b);
            // Adding `disc` with the trace's sign never cancels, and
            // |eig_big| >= disc / 2 >= |b| > ZERO_PIVOT_TOL, so it is nonzero.
            let eig_big = 0.5 * if trace >= 0.0 { trace + disc } else { trace - disc };
            // The two eigenvalues multiply to `det`.
            let eig_small = det / eig_big;
            count(eig_big);
            count(eig_small);
            k += 2;
        } else {
            count(d_diag[k]);
            k += 1;
        }
    }
    Inertia { positive, negative, zero }
}
/// Factors the dense symmetric matrix `a` with Bunch–Kaufman pivoting, using the
/// 1×1 / 2×2 pivots chosen by `find_pivot`.
///
/// On return, `A[perm[i]][perm[j]] ≈ (L·D·Lᵀ)[i][j]`, where `L` is unit lower
/// triangular and `D` is block diagonal (see `BunchKaufmanResult`).
///
/// `a` must be square (checked only by `debug_assert_eq!`) and must store both
/// triangles: pivot selection reads column `k` and row `r`, and every update writes
/// the whole trailing square block. `a.data` serves as workspace and is overwritten.
///
/// If a 1×1 pivot has `|d| <= ZERO_PIVOT_TOL`, or a 2×2 pivot has
/// `|det| <= ZERO_PIVOT_TOL`, that step records its `D` entries but performs no
/// elimination. The corresponding sub-diagonal entries of `L` stay zero and the
/// trailing block is not updated.
///
/// NOTE(review): `compute_inertia` and `bunch_kaufman_solve` recognise a 2×2 block
/// only when `|d_offdiag[k]| > ZERO_PIVOT_TOL`. A 2×2 pivot whose off-diagonal is
/// smaller than that would be read back as two 1×1 pivots.
pub fn dense_ldlt_bunch_kaufman(a: &mut DenseMat) -> BunchKaufmanResult {
let n = a.nrows;
debug_assert_eq!(a.ncols, n);
let mut l = DenseMat::zeros(n, n);
let mut d_diag = vec![0.0; n];
let mut d_offdiag = vec![0.0; n];
let mut perm: Vec<usize> = (0..n).collect();
// Scratch for the current step's L column(s): work[..m] is column k and,
// on 2×2 steps, work[m..2m] is column k + 1.
let mut work = vec![0.0f64; 2 * n];
let aa = &mut a.data;
let mut k = 0;
while k < n {
// pivot_type is 1 (1×1 pivot at p1 == p2) or 2 (2×2 pivot on p1, p2).
let (pivot_type, p1, p2) = find_pivot(aa, n, k);
if pivot_type == 1 {
if p1 != k {
// Move the pivot to position k: a symmetric swap of the active block,
// plus the same swap of the already-computed L rows (columns 0..k) so
// that L stays consistent with perm.
swap_rows_cols_from(aa, n, k, p1, k);
perm.swap(k, p1);
for j in 0..k {
l.data.swap(k * n + j, p1 * n + j);
}
}
let akk = aa[k * n + k];
d_diag[k] = akk;
// A numerically zero pivot skips elimination.
if akk.abs() > ZERO_PIVOT_TOL {
let m = n - k - 1;
// L[k+1+i][k] = A[k+1+i][k] / d
for i in 0..m {
work[i] = aa[(k + 1 + i) * n + k] / akk;
l.data[(k + 1 + i) * n + k] = work[i];
}
// Rank-1 update of the full trailing square:
// A[k+1+i][k+1+j] -= (L[i]·d)·L[j].
for i in 0..m {
let si = work[i] * akk;
let base = (k + 1 + i) * n + (k + 1);
axpy_neg(si, &work[..m], &mut aa[base..base + m]);
}
}
l.data[k * n + k] = 1.0;
k += 1;
} else {
// Move the second pivot index to k + 1.
if p2 != k + 1 {
swap_rows_cols_from(aa, n, k + 1, p2, k);
perm.swap(k + 1, p2);
for j in 0..k {
l.data.swap((k + 1) * n + j, p2 * n + j);
}
}
// find_pivot always returns p1 == k for a 2×2 pivot, so this swap
// currently never fires.
if p1 != k {
swap_rows_cols_from(aa, n, k, p1, k);
perm.swap(k, p1);
for j in 0..k {
l.data.swap(k * n + j, p1 * n + j);
}
}
// The 2×2 block [[akk, ak1k], [ak1k, ak1k1]] becomes this step's D block.
let akk = aa[k * n + k];
let ak1k = aa[(k + 1) * n + k];
let ak1k1 = aa[(k + 1) * n + (k + 1)];
d_diag[k] = akk;
d_diag[k + 1] = ak1k1;
d_offdiag[k] = ak1k;
let det = akk * ak1k1 - ak1k * ak1k;
// A numerically singular block skips elimination.
if det.abs() > ZERO_PIVOT_TOL {
// Explicit inverse of the symmetric 2×2 block.
let d_inv_00 = ak1k1 / det;
let d_inv_01 = -ak1k / det;
let d_inv_11 = akk / det;
let m = n - k - 2;
// [L[i][k], L[i][k+1]] = [A[i][k], A[i][k+1]]·D⁻¹ for the rows below the block.
for i in 0..m {
let aik = aa[(k + 2 + i) * n + k];
let aik1 = aa[(k + 2 + i) * n + (k + 1)];
work[i] = aik * d_inv_00 + aik1 * d_inv_01;
work[m + i] = aik * d_inv_01 + aik1 * d_inv_11;
l.data[(k + 2 + i) * n + k] = work[i];
l.data[(k + 2 + i) * n + (k + 1)] = work[m + i];
}
// Rank-2 update of the full trailing square: row i subtracts
// s_i · [L[j][k], L[j][k+1]]ᵀ, with s_i = [L[i][k], L[i][k+1]]·D.
for i in 0..m {
let li0 = work[i];
let li1 = work[m + i];
let si0 = li0 * akk + li1 * ak1k;
let si1 = li0 * ak1k + li1 * ak1k1;
let base = (k + 2 + i) * n + (k + 2);
axpy2_neg(si0, si1, &work[..m], &work[m..m + m], &mut aa[base..base + m]);
}
}
l.data[k * n + k] = 1.0;
l.data[(k + 1) * n + (k + 1)] = 1.0;
k += 2;
}
}
// perm_inv[original index] = factored position.
let mut perm_inv = vec![0; n];
for i in 0..n {
perm_inv[perm[i]] = i;
}
let inertia = compute_inertia(&d_diag, &d_offdiag, n);
BunchKaufmanResult { l, d_diag, d_offdiag, perm, perm_inv, inertia }
}
/// Solves `A·x = rhs` using a factorization produced by `dense_ldlt_bunch_kaufman`.
///
/// Steps:
/// 1. Gather `rhs` into factored order through `bk.perm`.
/// 2. Forward-substitute with unit-lower `L`, reading only the strictly-lower entries.
/// 3. Apply `D⁻¹` block by block. A 2×2 block at `k` is recognised by
///    `|d_offdiag[k]| > ZERO_PIVOT_TOL`.
/// 4. Back-substitute with `Lᵀ`.
/// 5. Scatter the result into `solution` in original order.
///
/// No block of `D` is checked for singularity: a zero 1×1 pivot or a zero 2×2
/// determinant produces non-finite entries rather than an error.
///
/// Expects `bk.perm` to hold at least `n = bk.l.nrows` entries, each a valid index
/// into both `rhs` and `solution`; indexing panics otherwise.
pub fn bunch_kaufman_solve(bk: &BunchKaufmanResult, rhs: &[f64], solution: &mut [f64]) {
    let n = bk.l.nrows;
    let lower = &bk.l.data;
    let perm = &bk.perm[..n];

    // z = P·rhs
    let mut z: Vec<f64> = perm.iter().map(|&src| rhs[src]).collect();

    // Forward substitution with L (implicit unit diagonal).
    for row in 0..n {
        let coeffs = &lower[row * n..row * n + row];
        let mut acc = z[row];
        for (col, &l_rc) in coeffs.iter().enumerate() {
            acc -= l_rc * z[col];
        }
        z[row] = acc;
    }

    // Block-diagonal solve with D, in place.
    let mut k = 0;
    while k < n {
        let two_by_two = k + 1 < n && bk.d_offdiag[k].abs() > ZERO_PIVOT_TOL;
        if two_by_two {
            let (a, b, c) = (bk.d_diag[k], bk.d_offdiag[k], bk.d_diag[k + 1]);
            let det = a * c - b * b;
            let (z0, z1) = (z[k], z[k + 1]);
            z[k] = (c * z0 - b * z1) / det;
            z[k + 1] = (a * z1 - b * z0) / det;
            k += 2;
        } else {
            z[k] /= bk.d_diag[k];
            k += 1;
        }
    }

    // Back substitution with Lᵀ: column `col` of L supplies row `col` of Lᵀ.
    for col in (0..n).rev() {
        let mut acc = z[col];
        for row in (col + 1)..n {
            acc -= lower[row * n + col] * z[row];
        }
        z[col] = acc;
    }

    // x = Pᵀ·z
    for (&dst, &value) in perm.iter().zip(&z) {
        solution[dst] = value;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dense matrix from row slices, storing every entry (both triangles),
    /// which is the layout `dense_ldlt_bunch_kaufman` reads.
    fn make_full_symmetric(vals: &[&[f64]]) -> DenseMat {
        let n = vals.len();
        let mut m = DenseMat::zeros(n, n);
        for i in 0..n {
            for j in 0..n {
                m.set(i, j, vals[i][j]);
            }
        }
        m
    }

    /// Rebuilds `L·D·Lᵀ` from `bk` and checks it equals `orig` permuted by `bk.perm`.
    fn verify_factorization(orig: &DenseMat, bk: &BunchKaufmanResult) {
        let n = orig.nrows;
        let mut ldlt = DenseMat::zeros(n, n);
        for i in 0..n {
            for j in 0..n {
                let mut val = 0.0;
                let mut kk = 0;
                while kk < n {
                    if kk + 1 < n && bk.d_offdiag[kk].abs() > ZERO_PIVOT_TOL {
                        let lik0 = bk.l.data[i * n + kk];
                        let lik1 = bk.l.data[i * n + kk + 1];
                        let ljk0 = bk.l.data[j * n + kk];
                        let ljk1 = bk.l.data[j * n + kk + 1];
                        let d00 = bk.d_diag[kk];
                        let d01 = bk.d_offdiag[kk];
                        let d11 = bk.d_diag[kk + 1];
                        val += (lik0 * d00 + lik1 * d01) * ljk0
                            + (lik0 * d01 + lik1 * d11) * ljk1;
                        kk += 2;
                    } else {
                        let lik = bk.l.data[i * n + kk];
                        let ljk = bk.l.data[j * n + kk];
                        val += lik * bk.d_diag[kk] * ljk;
                        kk += 1;
                    }
                }
                ldlt.set(i, j, val);
            }
        }
        for i in 0..n {
            for j in 0..n {
                let expected = orig.get(bk.perm[i], bk.perm[j]);
                let got = ldlt.get(i, j);
                assert!(
                    (expected - got).abs() < 1e-10,
                    "P*L*D*L^T*P^T mismatch at ({},{}): expected {} got {}",
                    i, j, expected, got
                );
            }
        }
    }

    /// Solves `orig·x = b` with `bk` and asserts every residual component is below 1e-10.
    fn assert_solves(orig: &DenseMat, bk: &BunchKaufmanResult, b: &[f64]) {
        let n = orig.nrows;
        let mut x = vec![0.0; n];
        bunch_kaufman_solve(bk, b, &mut x);
        for i in 0..n {
            let ax: f64 = (0..n).map(|j| orig.get(i, j) * x[j]).sum();
            let residual = (ax - b[i]).abs();
            assert!(residual < 1e-10, "residual at {}: {}", i, residual);
        }
    }

    #[test]
    fn test_bk_spd_3x3() {
        let mut a = make_full_symmetric(&[
            &[4.0, 2.0, 1.0],
            &[2.0, 5.0, 3.0],
            &[1.0, 3.0, 6.0],
        ]);
        let orig = a.clone();
        let bk = dense_ldlt_bunch_kaufman(&mut a);
        assert_eq!(bk.inertia, Inertia { positive: 3, negative: 0, zero: 0 });
        verify_factorization(&orig, &bk);
    }

    #[test]
    fn test_bk_indefinite_2x2() {
        let mut a = make_full_symmetric(&[&[1.0, 2.0], &[2.0, 1.0]]);
        let orig = a.clone();
        let bk = dense_ldlt_bunch_kaufman(&mut a);
        assert_eq!(bk.inertia, Inertia { positive: 1, negative: 1, zero: 0 });
        verify_factorization(&orig, &bk);
    }

    #[test]
    fn test_bk_kkt_like() {
        let mut a = make_full_symmetric(&[
            &[2.0, 0.0, 1.0],
            &[0.0, 2.0, 1.0],
            &[1.0, 1.0, 0.0],
        ]);
        let orig = a.clone();
        let bk = dense_ldlt_bunch_kaufman(&mut a);
        assert_eq!(bk.inertia.positive, 2);
        assert_eq!(bk.inertia.negative, 1);
        assert_eq!(bk.inertia.zero, 0);
        verify_factorization(&orig, &bk);
    }

    #[test]
    fn test_bk_solve_spd() {
        let mut a = make_full_symmetric(&[
            &[4.0, 2.0, 1.0],
            &[2.0, 5.0, 3.0],
            &[1.0, 3.0, 6.0],
        ]);
        let orig = a.clone();
        let bk = dense_ldlt_bunch_kaufman(&mut a);
        assert_solves(&orig, &bk, &[8.0, 18.0, 25.0]);
    }

    #[test]
    fn test_bk_solve_indefinite() {
        let mut a = make_full_symmetric(&[
            &[2.0, 0.0, 1.0],
            &[0.0, 2.0, 1.0],
            &[1.0, 1.0, 0.0],
        ]);
        let orig = a.clone();
        let bk = dense_ldlt_bunch_kaufman(&mut a);
        assert_solves(&orig, &bk, &[3.0, 5.0, 2.0]);
    }

    #[test]
    fn test_bk_solve_larger_kkt() {
        let mut a = make_full_symmetric(&[
            &[4.0, 0.0, 0.0, 1.0, 0.0],
            &[0.0, 5.0, 0.0, 0.0, 1.0],
            &[0.0, 0.0, 6.0, 1.0, 1.0],
            &[1.0, 0.0, 1.0, 0.0, 0.0],
            &[0.0, 1.0, 1.0, 0.0, 0.0],
        ]);
        let orig = a.clone();
        let bk = dense_ldlt_bunch_kaufman(&mut a);
        assert_eq!(bk.inertia.positive, 3);
        assert_eq!(bk.inertia.negative, 2);
        assert_solves(&orig, &bk, &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_bk_2x2_pivot_solve() {
        let mut a = make_full_symmetric(&[&[1.0, 2.0], &[2.0, 1.0]]);
        let orig = a.clone();
        let bk = dense_ldlt_bunch_kaufman(&mut a);
        // A small diagonal against a large off-diagonal forces a single 2×2 pivot.
        assert_eq!(bk.d_offdiag[0], 2.0);
        assert_solves(&orig, &bk, &[3.0, 3.0]);
    }

    #[test]
    fn test_bk_swapped_1x1_pivot() {
        let mut a = make_full_symmetric(&[&[0.1, 1.0], &[1.0, 5.0]]);
        let orig = a.clone();
        let bk = dense_ldlt_bunch_kaufman(&mut a);
        // a[1][1] is taken as the first 1×1 pivot, so positions 0 and 1 are exchanged.
        assert_eq!(bk.perm, vec![1, 0]);
        assert_eq!(bk.inertia, Inertia { positive: 1, negative: 1, zero: 0 });
        verify_factorization(&orig, &bk);
        assert_solves(&orig, &bk, &[1.2, 7.0]);
    }

    #[test]
    fn test_bk_single_entry() {
        let mut a = make_full_symmetric(&[&[-3.0]]);
        let orig = a.clone();
        let bk = dense_ldlt_bunch_kaufman(&mut a);
        assert_eq!(bk.inertia, Inertia { positive: 0, negative: 1, zero: 0 });
        verify_factorization(&orig, &bk);
        assert_solves(&orig, &bk, &[6.0]);
    }
}