use crate::dense::rook::{rook_rescue, RookKind};
use crate::error::FeralError;
use crate::inertia::Inertia;
use crate::dense::schur_kernel;
/// When set, `factor_frontal_blocked_in_place` bypasses the panel path and
/// delegates the whole front to `factor_frontal`.
pub static FORCE_SCALAR_FRONTAL: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
/// When set, `lblt_panel_frontal` never commits a 2x2 pivot inside a panel;
/// it returns a scalar-fallback status for that column instead.
pub static DISABLE_PANEL_INLINE_2X2: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
/// Gates the `panel_diag` counters: `diag_inc` / `diag_add` only touch a
/// counter while this flag is set.
pub static PANEL_DIAG_ENABLED: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
/// Event counters describing how the panel factorization path behaved.
///
/// The counters are plain relaxed atomics; they are meant for diagnostics
/// only and are incremented by the factorization code in this file.
pub mod panel_diag {
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Panels that ran to their full width.
    pub static PANEL_FULL: AtomicU64 = AtomicU64::new(0);
    /// Panels that stopped early and handed a column to the scalar step.
    pub static PANEL_PARTIAL: AtomicU64 = AtomicU64::new(0);
    /// Panels that stopped because a pivot was delayed.
    pub static PANEL_DELAYED: AtomicU64 = AtomicU64::new(0);
    /// Inline 2x2 refused because it needed a swap or hit a panel/front bound.
    pub static FALLBACK_2X2_NEED_SWAP_OR_BOUND: AtomicU64 = AtomicU64::new(0);
    /// Inline 2x2 refused because the candidate row's diagonal qualified as a 1x1 pivot.
    pub static FALLBACK_2X2_SWAP_1X1_WINS: AtomicU64 = AtomicU64::new(0);
    /// Inline 2x2 refused because the un-swapped 1x1 test succeeded.
    pub static FALLBACK_2X2_LAPACK_1X1_WINS: AtomicU64 = AtomicU64::new(0);
    /// Inline 2x2 refused by the growth or determinant checks.
    pub static FALLBACK_2X2_GROWTH_OR_DET: AtomicU64 = AtomicU64::new(0);
    /// Scalar steps taken because fewer columns than the panel minimum remained.
    pub static SCALAR_TAIL_STEPS: AtomicU64 = AtomicU64::new(0);
    /// Pivot columns eliminated inside panels.
    pub static PIVOTS_INLINE: AtomicU64 = AtomicU64::new(0);
    /// Pivot columns eliminated by the scalar step.
    pub static PIVOTS_SCALAR: AtomicU64 = AtomicU64::new(0);
    /// Inline 2x2 pivots accepted after a row/column swap.
    pub static INLINE_2X2_SWAP_OK: AtomicU64 = AtomicU64::new(0);

    /// Every counter paired with its snapshot label, in reporting order.
    fn labelled() -> [(&'static str, &'static AtomicU64); 11] {
        [
            ("panel_full", &PANEL_FULL),
            ("panel_partial", &PANEL_PARTIAL),
            ("panel_delayed", &PANEL_DELAYED),
            (
                "fallback_2x2_need_swap_or_bound",
                &FALLBACK_2X2_NEED_SWAP_OR_BOUND,
            ),
            ("fallback_2x2_swap_1x1_wins", &FALLBACK_2X2_SWAP_1X1_WINS),
            ("fallback_2x2_lapack_1x1_wins", &FALLBACK_2X2_LAPACK_1X1_WINS),
            ("fallback_2x2_growth_or_det", &FALLBACK_2X2_GROWTH_OR_DET),
            ("scalar_tail_steps", &SCALAR_TAIL_STEPS),
            ("pivots_inline", &PIVOTS_INLINE),
            ("pivots_scalar", &PIVOTS_SCALAR),
            ("inline_2x2_swap_ok", &INLINE_2X2_SWAP_OK),
        ]
    }

    /// Sets every counter back to zero.
    pub fn reset() {
        for (_, counter) in labelled() {
            counter.store(0, Ordering::Relaxed);
        }
    }

    /// Returns the current value of every counter as `(label, value)` pairs.
    pub fn snapshot() -> [(&'static str, u64); 11] {
        labelled().map(|(label, counter)| (label, counter.load(Ordering::Relaxed)))
    }
}
/// Adds one to `counter`, but only while `PANEL_DIAG_ENABLED` is set.
#[inline(always)]
fn diag_inc(counter: &std::sync::atomic::AtomicU64) {
    use std::sync::atomic::Ordering::Relaxed;
    if !PANEL_DIAG_ENABLED.load(Relaxed) {
        return;
    }
    counter.fetch_add(1, Relaxed);
}
/// Adds `n` to `counter`, but only while `PANEL_DIAG_ENABLED` is set.
#[inline(always)]
fn diag_add(counter: &std::sync::atomic::AtomicU64, n: u64) {
    use std::sync::atomic::Ordering::Relaxed;
    if !PANEL_DIAG_ENABLED.load(Relaxed) {
        return;
    }
    counter.fetch_add(n, Relaxed);
}
/// Absolute floor used by the 2x2 determinant checks: a 2x2 block whose
/// largest entry, or whose scaled determinant, falls below it is refused.
const SSIDS_DET_SMALL: f64 = 1e-20;
/// Magnitude above which an entry of `L` marks the factors as needing refinement.
const L_GROWTH_THRESHOLD: f64 = 1e6;

/// Sets `*needs_refinement` when any entry of `l` exceeds
/// `L_GROWTH_THRESHOLD` in absolute value.
///
/// The flag is only ever raised, never cleared; when it is already set the
/// slice is not scanned at all.
fn flag_growth_for_refinement(l: &[f64], needs_refinement: &mut bool) {
    if !*needs_refinement && l.iter().any(|entry| entry.abs() > L_GROWTH_THRESHOLD) {
        *needs_refinement = true;
    }
}
/// Tuning parameters shared by the dense and frontal factorizations in this file.
#[derive(Debug, Clone)]
pub struct BunchKaufmanParams {
/// Pivot-selection constant: the diagonal entry of column `k` is taken as a
/// 1x1 pivot when `|a_kk| >= alpha * gamma0`, where `gamma0` is the largest
/// off-diagonal magnitude in that column.
pub alpha: f64,
/// A 1x1 pivot whose magnitude is at or below this value is treated as zero.
pub zero_tol: f64,
/// Tolerance associated with 2x2 pivots. The factorization routines visible
/// here only copy it into the returned factors.
pub zero_tol_2x2: f64,
/// What to do when a zero pivot is met and cannot be delayed.
pub on_zero_pivot: ZeroPivotAction,
/// Relative threshold `u`: a 1x1 pivot `d` is considered too small when
/// `|d| <= max(u * col_max, zero_tol)`; `u` also scales the 2x2 growth test.
pub pivot_threshold: f64,
/// Maximum number of columns per panel in `factor_frontal_blocked_in_place`
/// (clamped there to the number of columns to eliminate).
pub block_size: usize,
}
/// Policy applied when a pivot is numerically zero.
#[derive(Debug, Clone)]
pub enum ZeroPivotAction {
/// Keep going: the pivot is counted as a zero eigenvalue and the result is
/// flagged as needing refinement.
ForceAccept,
/// Abort the factorization with `FeralError::NumericallyRankDeficient`.
Fail,
}
impl Default for BunchKaufmanParams {
fn default() -> Self {
let zero_tol = f64::EPSILON;
Self {
alpha: (1.0 + 17f64.sqrt()) / 8.0, zero_tol,
zero_tol_2x2: zero_tol * zero_tol,
on_zero_pivot: ZeroPivotAction::Fail,
pivot_threshold: 0.0,
block_size: 64,
}
}
}
/// Result of a full dense factorization (`factor` / `factor_single_front`).
#[derive(Debug)]
pub struct Factors {
/// Matrix dimension.
pub n: usize,
/// Unit lower-triangular factor, `n * n` entries; the entry for column `j`,
/// row `i >= j` lives at `l[j * n + i]`.
pub l: Vec<f64>,
/// Diagonal of the block-diagonal factor `D`.
pub d_diag: Vec<f64>,
/// Off-diagonal of 2x2 blocks of `D`: a non-zero entry at `j` marks a 2x2
/// block spanning positions `j` and `j + 1`.
pub d_subdiag: Vec<f64>,
/// Pivot permutation accumulated by the row/column swaps.
pub perm: Vec<usize>,
/// Inverse of `perm`: `perm_inv[perm[i]] == i`.
pub perm_inv: Vec<usize>,
/// Symmetric scaling applied before factorization (`d_eq[i] * a_ij * d_eq[j]`).
pub d_eq: Vec<f64>,
/// Set when a pivot was force-accepted, accepted below threshold, or when
/// an entry of `l` exceeded the growth threshold.
pub needs_refinement: bool,
/// Copy of `BunchKaufmanParams::zero_tol` used for this factorization.
pub zero_tol: f64,
/// Copy of `BunchKaufmanParams::zero_tol_2x2` used for this factorization.
pub zero_tol_2x2: f64,
}
/// Factorizes a dense symmetric matrix using 1x1 and 2x2 diagonal pivots.
///
/// The matrix is first scaled symmetrically (`d_eq[i] * a_ij * d_eq[j]`),
/// then eliminated column by column with symmetric row/column swaps recorded
/// in `perm`. Returns the factors together with the inertia counted from the
/// accepted pivots.
///
/// # Errors
/// Propagates the error from `matrix.validate()`, returns
/// `FeralError::NumericallyRankDeficient` when the last pivot is zero and
/// `params.on_zero_pivot` is `Fail`, and propagates any error from the pivot
/// helpers.
pub fn factor(
matrix: &crate::dense::matrix::SymmetricMatrix,
params: &BunchKaufmanParams,
) -> Result<(Factors, Inertia), FeralError> {
matrix.validate()?;
let n = matrix.n;
let d_eq = crate::dense::equilibrate::equilibrate_scaling(matrix);
// Working copy of the scaled matrix; only entries `a[j * n + i]` with `i >= j` are written.
let mut a = vec![0.0; n * n];
for j in 0..n {
for i in j..n {
a[j * n + i] = d_eq[i] * matrix.data[j * n + i] * d_eq[j];
}
}
let mut perm: Vec<usize> = (0..n).collect();
let mut subdiag = vec![0.0; n];
let mut pos = 0usize;
let mut neg = 0usize;
let mut zero = 0usize;
let mut needs_refinement = false;
let alpha = params.alpha;
let mut k = 0;
// The pivot helpers return a `(gamma, row)` pair that is reused as the
// off-diagonal maximum of the next pivot column, saving a rescan.
// NOTE(review): presumably computed during the trailing update — confirm in `do_1x1_pivot` / `do_2x2_pivot`.
let mut fused_gamma0 = 0.0f64;
let mut fused_r = 0usize;
let mut have_fused = false;
while k < n {
let remaining = n - k;
// Last column: only a 1x1 pivot is possible and there is nothing left to update.
if remaining == 1 {
let d = a[k * n + k];
if d.abs() <= params.zero_tol {
match params.on_zero_pivot {
ZeroPivotAction::ForceAccept => {
needs_refinement = true;
zero += 1;
}
ZeroPivotAction::Fail => {
return Err(FeralError::NumericallyRankDeficient);
}
}
} else if d > 0.0 {
pos += 1;
} else {
neg += 1;
}
k += 1;
continue;
}
// gamma0: largest off-diagonal magnitude in column k; r: its row.
let (gamma0, r) = if have_fused {
have_fused = false;
(fused_gamma0, fused_r)
} else {
column_offdiag_max(&a, n, k)
};
// Column is already diagonal: count the pivot, no update needed.
if gamma0 == 0.0 {
let d = a[k * n + k];
count_1x1_inertia(
d,
params,
&mut pos,
&mut neg,
&mut zero,
&mut needs_refinement,
)?;
set_l_column_identity(&mut a, n, k);
k += 1;
continue;
}
let akk = a[k * n + k].abs();
// Case 1: the diagonal dominates its column, use it as a 1x1 pivot.
if akk >= alpha * gamma0 {
let (ng, nr) = do_1x1_pivot(
&mut a,
n,
k,
gamma0,
params,
&mut pos,
&mut neg,
&mut zero,
&mut needs_refinement,
)?;
fused_gamma0 = ng;
fused_r = nr;
have_fused = k + 1 < n;
k += 1;
continue;
}
let gamma_r = symmetric_row_offdiag_max(&a, n, k, r);
let arr = a[r * n + r].abs();
// Case 2: the diagonal of row r dominates, swap it into position k and use it as a 1x1 pivot.
if arr >= alpha * gamma_r {
swap_rows_cols(&mut a, n, k, r, &mut perm);
let (ng, nr) = do_1x1_pivot(
&mut a,
n,
k,
gamma_r,
params,
&mut pos,
&mut neg,
&mut zero,
&mut needs_refinement,
)?;
fused_gamma0 = ng;
fused_r = nr;
have_fused = k + 1 < n;
k += 1;
continue;
}
// Case 3: a_kk is still acceptable relative to gamma_r, 1x1 pivot without a swap.
if akk * gamma_r >= alpha * gamma0 * gamma0 {
let (ng, nr) = do_1x1_pivot(
&mut a,
n,
k,
gamma0,
params,
&mut pos,
&mut neg,
&mut zero,
&mut needs_refinement,
)?;
fused_gamma0 = ng;
fused_r = nr;
have_fused = k + 1 < n;
k += 1;
continue;
}
// Case 4: 2x2 pivot candidate on (k, r); bring row r next to k.
if r != k + 1 {
swap_rows_cols(&mut a, n, k + 1, r, &mut perm);
}
let d11_v = a[k * n + k];
let d21_v = a[k * n + (k + 1)];
let d22_v = a[(k + 1) * n + (k + 1)];
let det_v = d11_v * d22_v - d21_v * d21_v;
let absdet = det_v.abs();
// rmax / tmax: largest magnitudes below the 2x2 block in columns k and k + 1.
let mut rmax = 0.0f64;
let mut tmax = 0.0f64;
for i in (k + 2)..n {
let v0 = a[k * n + i].abs();
if v0 > rmax {
rmax = v0;
}
let v1 = a[(k + 1) * n + i].abs();
if v1 > tmax {
tmax = v1;
}
}
let amax = d21_v.abs();
let u = params.pivot_threshold;
// Growth test scaled by the pivot threshold; on failure fall back to a 1x1 pivot on column k.
let growth_fail = (d22_v.abs() * rmax + amax * tmax) * u > absdet
|| (d11_v.abs() * tmax + amax * rmax) * u > absdet;
if growth_fail {
let (ng, nr) = do_1x1_pivot(
&mut a,
n,
k,
gamma0,
params,
&mut pos,
&mut neg,
&mut zero,
&mut needs_refinement,
)?;
fused_gamma0 = ng;
fused_r = nr;
have_fused = k + 1 < n;
k += 1;
continue;
}
let (ng, nr) = do_2x2_pivot(
&mut a,
n,
k,
&mut subdiag,
params,
&mut pos,
&mut neg,
&mut zero,
&mut needs_refinement,
)?;
fused_gamma0 = ng;
fused_r = nr;
have_fused = k + 2 < n;
k += 2;
}
// Extract L (unit diagonal) and the diagonal of D from the working array.
// A non-zero `subdiag[j]` marks a 2x2 block on columns j and j + 1.
let mut l = vec![0.0; n * n];
let mut d_diag = vec![0.0; n];
let mut j = 0;
while j < n {
d_diag[j] = a[j * n + j];
l[j * n + j] = 1.0;
if j + 1 < n && subdiag[j] != 0.0 {
d_diag[j + 1] = a[(j + 1) * n + (j + 1)];
l[(j + 1) * n + (j + 1)] = 1.0;
for i in (j + 2)..n {
l[j * n + i] = a[j * n + i];
l[(j + 1) * n + i] = a[(j + 1) * n + i];
}
j += 2;
} else {
for i in (j + 1)..n {
l[j * n + i] = a[j * n + i];
}
j += 1;
}
}
let mut perm_inv = vec![0usize; n];
for (i, &p) in perm.iter().enumerate() {
perm_inv[p] = i;
}
let inertia = Inertia::new(pos, neg, zero);
// Large entries in L also request refinement.
flag_growth_for_refinement(&l, &mut needs_refinement);
Ok((
Factors {
n,
l,
d_diag,
d_subdiag: subdiag,
perm,
perm_inv,
d_eq,
needs_refinement,
zero_tol: params.zero_tol,
zero_tol_2x2: params.zero_tol_2x2,
},
inertia,
))
}
pub fn factor_single_front(
matrix: &crate::dense::matrix::SymmetricMatrix,
params: &BunchKaufmanParams,
) -> Result<(Factors, Inertia), FeralError> {
matrix.validate()?;
let n = matrix.n;
let d_eq = crate::dense::equilibrate::equilibrate_scaling(matrix);
let mut eq_data = vec![0.0; n * n];
for j in 0..n {
for i in j..n {
eq_data[j * n + i] = d_eq[i] * matrix.data[j * n + i] * d_eq[j];
}
}
let eq_matrix = crate::dense::matrix::SymmetricMatrix { n, data: eq_data };
let front = factor_frontal_blocked(&eq_matrix, n, false, params)?;
debug_assert_eq!(front.nelim, n);
debug_assert_eq!(front.n_delayed, 0);
debug_assert_eq!(front.contrib_dim, 0);
let inertia = front.inertia;
let factors = Factors {
n,
l: front.l,
d_diag: front.d_diag,
d_subdiag: front.d_subdiag,
perm: front.perm,
perm_inv: front.perm_inv,
d_eq,
needs_refinement: front.needs_refinement,
zero_tol: front.zero_tol,
zero_tol_2x2: front.zero_tol_2x2,
};
Ok((factors, inertia))
}
/// Result of partially factorizing one frontal matrix.
#[derive(Debug)]
pub struct FrontalFactors {
/// Dimension of the frontal matrix.
pub nrow: usize,
/// Number of leading columns that were candidates for elimination.
pub ncol: usize,
/// Number of columns actually eliminated (`nelim <= ncol`).
pub nelim: usize,
/// Unit lower-triangular factor columns, `nrow * nelim` entries; column `j`,
/// row `i >= j` is stored at `l[j * nrow + i]`.
pub l: Vec<f64>,
/// Diagonal of `D` for the eliminated columns.
pub d_diag: Vec<f64>,
/// Off-diagonal of 2x2 blocks of `D`; non-zero at `j` marks a block on `j`, `j + 1`.
pub d_subdiag: Vec<f64>,
/// Row/column permutation applied to the front.
pub perm: Vec<usize>,
/// Inverse of `perm`: `perm_inv[perm[i]] == i`.
pub perm_inv: Vec<usize>,
/// Trailing (not eliminated) part of the front after the updates,
/// `contrib_dim * contrib_dim` entries with only the `row >= col` triangle filled.
pub contrib: Vec<f64>,
/// Dimension of `contrib` (`nrow - nelim`).
pub contrib_dim: usize,
/// Candidate columns that were not eliminated (`ncol - nelim`).
pub n_delayed: usize,
/// Inertia counted from the accepted pivots.
pub inertia: Inertia,
/// Set when a questionable pivot was accepted or `l` shows large growth.
pub needs_refinement: bool,
/// Number of times the rook-rescue search supplied a pivot.
pub n_rook_rescues: usize,
/// Copy of `BunchKaufmanParams::zero_tol` used for this front.
pub zero_tol: f64,
/// Copy of `BunchKaufmanParams::zero_tol_2x2` used for this front.
pub zero_tol_2x2: f64,
}
/// Decision reached for a 1x1 pivot candidate.
#[derive(Debug, Clone, Copy, PartialEq)]
enum PivotOutcome {
/// The pivot was accepted; inertia has been counted, the column still needs its update.
Accepted,
/// The pivot was treated as zero and its column was zeroed; no update is needed.
Rejected,
/// The pivot was refused and must be left for a later front.
Delayed,
/// Rook rescue supplied a 2x2 pivot; the block entries are carried along
/// so the caller can count inertia and apply the 2x2 update.
AcceptedRook2x2 { d11: f64, d21: f64, d22: f64 },
}
/// Result of one scalar pivot step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PivotStepResult {
/// The step consumed this many columns (1 for a 1x1 pivot, 2 for a 2x2 pivot).
Advanced(usize),
/// The pivot was delayed; elimination of this front stops here.
Delayed,
}
/// How a panel factorization ended. The caller uses it to decide which
/// trailing columns still need the blocked Schur update.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PanelStatus {
/// All panel columns were eliminated.
Full,
/// The panel stopped at a column that must be handled by the scalar step;
/// that column has already received the panel's updates.
ScalarFallback,
/// Like `ScalarFallback`, but the following column has also already
/// received the panel's updates.
ScalarFallbackPeekedNext,
/// The panel stopped because a pivot was delayed.
Delayed,
}
/// Accumulated wall-clock timings for `factor_frontal_with_profile`.
/// All `*_ns` fields are nanosecond totals summed over calls.
#[doc(hidden)]
#[derive(Default, Debug, Clone, Copy)]
pub struct FrontalProfile {
/// Time spent allocating and filling the working copy of the front.
pub alloc_copy_ns: u128,
/// Time spent setting up permutation, sub-diagonal and counters.
pub setup_ns: u128,
/// Time spent in the pivot loop.
pub pivot_loop_ns: u128,
/// Time spent extracting `L`, `D` and the contribution block.
pub extract_ns: u128,
/// Number of profiled calls that ran to completion.
pub n_calls: u64,
}
/// Partially factorizes a frontal matrix with the scalar (column-at-a-time)
/// algorithm, eliminating up to `ncol` leading columns.
///
/// Equivalent to `factor_frontal_with_profile` with profiling disabled.
///
/// # Errors
/// Same as `factor_frontal_with_profile`.
pub fn factor_frontal(
matrix: &crate::dense::matrix::SymmetricMatrix,
ncol: usize,
may_delay: bool,
params: &BunchKaufmanParams,
) -> Result<FrontalFactors, FeralError> {
factor_frontal_with_profile(matrix, ncol, may_delay, params, None)
}
/// Scalar partial factorization of a frontal matrix, optionally recording
/// per-phase timings into `profile`.
///
/// Up to `ncol` leading columns are eliminated one pivot step at a time on a
/// private copy of the matrix. Elimination stops early when a pivot is
/// delayed; the remaining candidates are reported in `n_delayed` and the
/// updated trailing block is returned as `contrib`.
///
/// The early return for `ncol == 0` does not touch `profile`.
///
/// # Errors
/// Propagates `matrix.validate()` errors, returns `FeralError::InvalidInput`
/// when `ncol > nrow`, and propagates errors from `scalar_pivot_step`.
#[doc(hidden)]
pub fn factor_frontal_with_profile(
matrix: &crate::dense::matrix::SymmetricMatrix,
ncol: usize,
may_delay: bool,
params: &BunchKaufmanParams,
mut profile: Option<&mut FrontalProfile>,
) -> Result<FrontalFactors, FeralError> {
matrix.validate()?;
let nrow = matrix.n;
if ncol > nrow {
return Err(FeralError::InvalidInput(format!(
"ncol {} > nrow {}",
ncol, nrow
)));
}
// Nothing to eliminate: the whole matrix is passed through as the contribution block.
if ncol == 0 {
return Ok(FrontalFactors {
nrow,
ncol: 0,
nelim: 0,
l: Vec::new(),
d_diag: Vec::new(),
d_subdiag: Vec::new(),
perm: (0..nrow).collect(),
perm_inv: (0..nrow).collect(),
contrib: matrix.data.clone(),
contrib_dim: nrow,
n_delayed: 0,
inertia: Inertia {
positive: 0,
negative: 0,
zero: 0,
},
needs_refinement: false,
n_rook_rescues: 0,
zero_tol: params.zero_tol,
zero_tol_2x2: params.zero_tol_2x2,
});
}
// Timers are only started when a profile was supplied.
let t0 = profile.as_ref().map(|_| std::time::Instant::now());
// Working copy; only the `row >= col` triangle is copied.
let mut a = vec![0.0; nrow * nrow];
for j in 0..nrow {
for i in j..nrow {
a[j * nrow + i] = matrix.data[j * nrow + i];
}
}
if let (Some(p), Some(t)) = (profile.as_deref_mut(), t0) {
p.alloc_copy_ns += t.elapsed().as_nanos();
}
let t0 = profile.as_ref().map(|_| std::time::Instant::now());
let mut perm: Vec<usize> = (0..nrow).collect();
let mut subdiag = vec![0.0; nrow];
let mut pos = 0usize;
let mut neg = 0usize;
let mut zero = 0usize;
let mut needs_refinement = false;
let mut n_rook_rescues = 0usize;
if let (Some(p), Some(t)) = (profile.as_deref_mut(), t0) {
p.setup_ns += t.elapsed().as_nanos();
}
let t_pivot = profile.as_ref().map(|_| std::time::Instant::now());
// Pivot loop: each step eliminates one or two columns, or stops on a delay.
let mut k = 0;
while k < ncol {
match scalar_pivot_step(
&mut a,
nrow,
ncol,
k,
may_delay,
params,
&mut perm,
&mut subdiag,
&mut pos,
&mut neg,
&mut zero,
&mut needs_refinement,
&mut n_rook_rescues,
)? {
PivotStepResult::Advanced(n) => k += n,
PivotStepResult::Delayed => break,
}
}
if let (Some(p), Some(t)) = (profile.as_deref_mut(), t_pivot) {
p.pivot_loop_ns += t.elapsed().as_nanos();
}
let t_extract = profile.as_ref().map(|_| std::time::Instant::now());
let nelim = k;
let n_delayed = ncol - nelim;
// Extract the eliminated columns of L (unit diagonal) and the diagonal of D.
// A non-zero `subdiag[j]` marks a 2x2 block on columns j and j + 1.
let mut l = vec![0.0; nrow * nelim];
let mut d_diag = vec![0.0; nelim];
let mut j = 0;
while j < nelim {
d_diag[j] = a[j * nrow + j];
l[j * nrow + j] = 1.0;
if j + 1 < nelim && subdiag[j] != 0.0 {
d_diag[j + 1] = a[(j + 1) * nrow + (j + 1)];
l[(j + 1) * nrow + (j + 1)] = 1.0;
for i in (j + 2)..nrow {
l[j * nrow + i] = a[j * nrow + i];
l[(j + 1) * nrow + i] = a[(j + 1) * nrow + i];
}
j += 2;
} else {
for i in (j + 1)..nrow {
l[j * nrow + i] = a[j * nrow + i];
}
j += 1;
}
}
// Contribution block: the trailing `cdim x cdim` triangle, repacked with stride `cdim`.
let cdim = nrow - nelim;
let mut contrib = vec![0.0; cdim * cdim];
for cj in 0..cdim {
for ci in cj..cdim {
contrib[cj * cdim + ci] = a[(nelim + cj) * nrow + (nelim + ci)];
}
}
let mut perm_inv = vec![0usize; nrow];
for (i, &p) in perm.iter().enumerate() {
perm_inv[p] = i;
}
flag_growth_for_refinement(&l, &mut needs_refinement);
let result = FrontalFactors {
nrow,
ncol,
nelim,
l,
d_diag,
d_subdiag: subdiag[..nelim].to_vec(),
perm,
perm_inv,
contrib,
contrib_dim: cdim,
n_delayed,
inertia: Inertia::new(pos, neg, zero),
needs_refinement,
n_rook_rescues,
zero_tol: params.zero_tol,
zero_tol_2x2: params.zero_tol_2x2,
};
if let (Some(p), Some(t)) = (profile, t_extract) {
p.extract_ns += t.elapsed().as_nanos();
p.n_calls += 1;
}
Ok(result)
}
/// Panel-based partial factorization of a frontal matrix that leaves the
/// input untouched.
///
/// The stored (`row >= col`) triangle of `matrix` is copied into a scratch
/// matrix, which is then factorized by `factor_frontal_blocked_in_place`.
///
/// # Errors
/// Propagates errors from `matrix.validate()` and from
/// `factor_frontal_blocked_in_place`.
pub fn factor_frontal_blocked(
    matrix: &crate::dense::matrix::SymmetricMatrix,
    ncol: usize,
    may_delay: bool,
    params: &BunchKaufmanParams,
) -> Result<FrontalFactors, FeralError> {
    matrix.validate()?;
    let nrow = matrix.n;

    let mut lower = vec![0.0; nrow * nrow];
    for col in 0..nrow {
        // Rows col..nrow of this column are contiguous in storage.
        let first = col * nrow + col;
        let last = col * nrow + nrow;
        lower[first..last].copy_from_slice(&matrix.data[first..last]);
    }

    let mut scratch = crate::dense::matrix::SymmetricMatrix {
        n: nrow,
        data: lower,
    };
    factor_frontal_blocked_in_place(&mut scratch, ncol, may_delay, params)
}
/// Panel-based partial factorization of a frontal matrix, working directly
/// in `matrix.data`.
///
/// Columns are eliminated in panels of at most `params.block_size` columns:
/// `lblt_panel_frontal` factorizes a panel and `apply_blocked_schur` then
/// updates the trailing columns. When a panel stops early, or when fewer
/// than `PANEL_MIN_NCOL` candidate columns remain, single pivots are taken
/// with `scalar_pivot_step`.
///
/// The call falls back to `factor_frontal` (which works on its own copy and
/// leaves `matrix` unmodified) when `FORCE_SCALAR_FRONTAL` is set, when the
/// effective block size is below 2, or when `ncol < PANEL_MIN_NCOL`.
///
/// Unlike `factor_frontal_blocked`, this function does not call
/// `matrix.validate()` itself.
///
/// # Errors
/// Returns `FeralError::InvalidInput` when `ncol > nrow` and propagates
/// errors from the panel and scalar pivot routines.
pub fn factor_frontal_blocked_in_place(
matrix: &mut crate::dense::matrix::SymmetricMatrix,
ncol: usize,
may_delay: bool,
params: &BunchKaufmanParams,
) -> Result<FrontalFactors, FeralError> {
let nrow = matrix.n;
if ncol > nrow {
return Err(FeralError::InvalidInput(format!(
"ncol {} > nrow {}",
ncol, nrow
)));
}
// Nothing to eliminate: the whole matrix is passed through as the contribution block.
if ncol == 0 {
return Ok(FrontalFactors {
nrow,
ncol: 0,
nelim: 0,
l: Vec::new(),
d_diag: Vec::new(),
d_subdiag: Vec::new(),
perm: (0..nrow).collect(),
perm_inv: (0..nrow).collect(),
contrib: matrix.data.clone(),
contrib_dim: nrow,
n_delayed: 0,
inertia: Inertia {
positive: 0,
negative: 0,
zero: 0,
},
needs_refinement: false,
n_rook_rescues: 0,
zero_tol: params.zero_tol,
zero_tol_2x2: params.zero_tol_2x2,
});
}
if FORCE_SCALAR_FRONTAL.load(std::sync::atomic::Ordering::Relaxed) {
return factor_frontal(matrix, ncol, may_delay, params);
}
// Fronts with fewer candidate columns than this use the scalar path only.
const PANEL_MIN_NCOL: usize = 8;
let bs = params.block_size.min(ncol);
if bs < 2 || ncol < PANEL_MIN_NCOL {
return factor_frontal(matrix, ncol, may_delay, params);
}
let a: &mut [f64] = matrix.data.as_mut_slice();
let mut perm: Vec<usize> = (0..nrow).collect();
let mut subdiag = vec![0.0; nrow];
let mut pos = 0usize;
let mut neg = 0usize;
let mut zero = 0usize;
let mut needs_refinement = false;
let mut n_rook_rescues = 0usize;
// Pivot diagonal values of the current panel, indexed by position within the panel.
let mut d_panel = vec![0.0f64; bs];
let mut k = 0;
while k < ncol {
let remaining = ncol - k;
// Tail of the front: too few columns for a panel, take scalar steps.
if remaining < PANEL_MIN_NCOL {
diag_inc(&panel_diag::SCALAR_TAIL_STEPS);
match scalar_pivot_step(
&mut *a,
nrow,
ncol,
k,
may_delay,
params,
&mut perm,
&mut subdiag,
&mut pos,
&mut neg,
&mut zero,
&mut needs_refinement,
&mut n_rook_rescues,
)? {
PivotStepResult::Advanced(n) => {
diag_add(&panel_diag::PIVOTS_SCALAR, n as u64);
k += n;
}
PivotStepResult::Delayed => break,
}
continue;
}
let panel_cap = bs.min(remaining);
let (n_elim, status) = lblt_panel_frontal(
&mut *a,
nrow,
ncol,
k,
panel_cap,
may_delay,
params,
&mut pos,
&mut neg,
&mut zero,
&mut needs_refinement,
&mut d_panel,
&mut subdiag,
&mut perm,
)?;
// First trailing column that still needs the panel's update. Columns the
// panel already brought up to date before stopping (one, or two when the
// next column was peeked) are skipped.
let j_start = match status {
PanelStatus::Full => k + n_elim,
PanelStatus::ScalarFallback | PanelStatus::Delayed => k + n_elim + 1,
PanelStatus::ScalarFallbackPeekedNext => k + n_elim + 2,
};
apply_blocked_schur(&mut *a, nrow, k, n_elim, j_start, &d_panel, &subdiag);
k += n_elim;
diag_add(&panel_diag::PIVOTS_INLINE, n_elim as u64);
match status {
PanelStatus::Full => diag_inc(&panel_diag::PANEL_FULL),
PanelStatus::ScalarFallback | PanelStatus::ScalarFallbackPeekedNext => {
diag_inc(&panel_diag::PANEL_PARTIAL)
}
PanelStatus::Delayed => diag_inc(&panel_diag::PANEL_DELAYED),
}
match status {
PanelStatus::Full => {}
// The panel could not decide this column inline: take one scalar step, then resume panels.
PanelStatus::ScalarFallback | PanelStatus::ScalarFallbackPeekedNext => {
if k >= ncol {
break;
}
match scalar_pivot_step(
&mut *a,
nrow,
ncol,
k,
may_delay,
params,
&mut perm,
&mut subdiag,
&mut pos,
&mut neg,
&mut zero,
&mut needs_refinement,
&mut n_rook_rescues,
)? {
PivotStepResult::Advanced(n) => {
diag_add(&panel_diag::PIVOTS_SCALAR, n as u64);
k += n;
}
PivotStepResult::Delayed => break,
}
}
PanelStatus::Delayed => break,
}
}
let nelim = k;
let n_delayed = ncol - nelim;
// Extract the eliminated columns of L (unit diagonal) and the diagonal of D.
// A non-zero `subdiag[j]` marks a 2x2 block on columns j and j + 1.
let mut l = vec![0.0; nrow * nelim];
let mut d_diag = vec![0.0; nelim];
let mut j = 0;
while j < nelim {
d_diag[j] = a[j * nrow + j];
l[j * nrow + j] = 1.0;
if j + 1 < nelim && subdiag[j] != 0.0 {
d_diag[j + 1] = a[(j + 1) * nrow + (j + 1)];
l[(j + 1) * nrow + (j + 1)] = 1.0;
for i in (j + 2)..nrow {
l[j * nrow + i] = a[j * nrow + i];
l[(j + 1) * nrow + i] = a[(j + 1) * nrow + i];
}
j += 2;
} else {
for i in (j + 1)..nrow {
l[j * nrow + i] = a[j * nrow + i];
}
j += 1;
}
}
// Contribution block: the trailing `cdim x cdim` triangle, repacked with stride `cdim`.
let cdim = nrow - nelim;
let mut contrib = vec![0.0; cdim * cdim];
for cj in 0..cdim {
for ci in cj..cdim {
contrib[cj * cdim + ci] = a[(nelim + cj) * nrow + (nelim + ci)];
}
}
let mut perm_inv = vec![0usize; nrow];
for (i, &p) in perm.iter().enumerate() {
perm_inv[p] = i;
}
flag_growth_for_refinement(&l, &mut needs_refinement);
Ok(FrontalFactors {
nrow,
ncol,
nelim,
l,
d_diag,
d_subdiag: subdiag[..nelim].to_vec(),
perm,
perm_inv,
contrib,
contrib_dim: cdim,
n_delayed,
inertia: Inertia::new(pos, neg, zero),
needs_refinement,
n_rook_rescues,
zero_tol: params.zero_tol,
zero_tol_2x2: params.zero_tol_2x2,
})
}
/// Factorizes one panel of up to `bs` columns starting at column `k`,
/// without updating the columns outside the panel.
///
/// Before a column is examined it is brought up to date by replaying the
/// pivots already committed in this panel (`peek_ahead_column`). Pivot values
/// are stored in `d_panel` (indexed by position in the panel) and the
/// off-diagonal of a 2x2 pivot in `subdiag[k + c]`.
///
/// Returns `(c, status)` where `c` is the number of panel columns committed.
/// A non-`Full` status tells the caller that the column at `k + c` (and, for
/// `ScalarFallbackPeekedNext`, also `k + c + 1`) has already been brought up
/// to date and must be skipped by the blocked Schur update.
///
/// # Errors
/// Propagates errors from `count_1x1_inertia` and `try_reject_1x1_frontal`.
#[allow(clippy::too_many_arguments)]
fn lblt_panel_frontal(
a: &mut [f64],
nrow: usize,
ncol: usize,
k: usize,
bs: usize,
may_delay: bool,
params: &BunchKaufmanParams,
pos: &mut usize,
neg: &mut usize,
zero: &mut usize,
needs_refinement: &mut bool,
d_panel: &mut [f64],
subdiag: &mut [f64],
perm: &mut [usize],
) -> Result<(usize, PanelStatus), FeralError> {
let alpha_bk = params.alpha;
let cap = bs;
// c: number of panel columns committed so far.
let mut c = 0usize;
while c < cap {
let col = k + c;
// Apply the updates of the `c` committed panel pivots to this column.
peek_ahead_column(a, nrow, k, c, d_panel, subdiag);
// gamma0: largest off-diagonal magnitude in the column; r: its row.
let mut gamma0 = 0.0f64;
let mut r = col + 1;
for i in (col + 1)..nrow {
let v = a[col * nrow + i].abs();
if v > gamma0 {
gamma0 = v;
r = i;
}
}
// Column is already diagonal: count the pivot and move on.
if gamma0 == 0.0 {
let d = a[col * nrow + col];
count_1x1_inertia(d, params, pos, neg, zero, needs_refinement)?;
set_l_column_identity(a, nrow, col);
d_panel[c] = d;
c += 1;
continue;
}
let akk = a[col * nrow + col].abs();
// Diagonal too small for a plain 1x1 pivot: try an inline 2x2 pivot on (col, col + 1).
if akk < alpha_bk * gamma0 {
let need_swap = r > col + 1;
// The partner column must be a candidate column inside this panel, the
// swap row (if any) must be a candidate too, and inline 2x2 must be enabled.
let bounds_ok = col + 1 < ncol
&& c + 1 < cap
&& (!need_swap || r < ncol)
&& !DISABLE_PANEL_INLINE_2X2.load(std::sync::atomic::Ordering::Relaxed);
// A swap is only permitted on the first panel column (`c == 0`), i.e.
// before any panel pivot has been committed.
let allowed = bounds_ok && (!need_swap || c == 0);
if !allowed {
diag_inc(&panel_diag::FALLBACK_2X2_NEED_SWAP_OR_BOUND);
return Ok((c, PanelStatus::ScalarFallback));
}
if need_swap {
// If either 1x1 alternative would be chosen, let the scalar step handle it.
let gamma_r_pre = symmetric_row_offdiag_max(a, nrow, col, r);
let arr_pre = a[r * nrow + r].abs();
if arr_pre >= alpha_bk * gamma_r_pre {
diag_inc(&panel_diag::FALLBACK_2X2_SWAP_1X1_WINS);
return Ok((c, PanelStatus::ScalarFallback));
}
if akk * gamma_r_pre >= alpha_bk * gamma0 * gamma0 {
diag_inc(&panel_diag::FALLBACK_2X2_LAPACK_1X1_WINS);
return Ok((c, PanelStatus::ScalarFallback));
}
swap_rows_cols(a, nrow, col + 1, r, perm);
} else {
let r_idx = col + 1;
// Bring the partner column up to date as well; from here on a fallback
// must report `ScalarFallbackPeekedNext`.
peek_ahead_replay(a, nrow, k, c, r_idx, d_panel, subdiag);
let gamma_r = symmetric_row_offdiag_max(a, nrow, col, r_idx);
let arr = a[r_idx * nrow + r_idx].abs();
if arr >= alpha_bk * gamma_r {
diag_inc(&panel_diag::FALLBACK_2X2_SWAP_1X1_WINS);
return Ok((c, PanelStatus::ScalarFallbackPeekedNext));
}
if akk * gamma_r >= alpha_bk * gamma0 * gamma0 {
diag_inc(&panel_diag::FALLBACK_2X2_LAPACK_1X1_WINS);
return Ok((c, PanelStatus::ScalarFallbackPeekedNext));
}
}
let d11 = a[col * nrow + col];
let d21 = a[col * nrow + (col + 1)];
let d22 = a[(col + 1) * nrow + (col + 1)];
let det = d11 * d22 - d21 * d21;
// rmax / tmax: largest magnitudes below the 2x2 block in its two columns.
let mut rmax = 0.0f64;
let mut tmax = 0.0f64;
for i in (col + 2)..nrow {
let v0 = a[col * nrow + i].abs();
if v0 > rmax {
rmax = v0;
}
let v1 = a[(col + 1) * nrow + i].abs();
if v1 > tmax {
tmax = v1;
}
}
let amax = d21.abs();
let absdet = det.abs();
let u = params.pivot_threshold;
let growth_fail = (d22.abs() * rmax + amax * tmax) * u > absdet
|| (d11.abs() * tmax + amax * rmax) * u > absdet;
// Determinant floor: refuse the block when it is tiny, or when the scaled
// determinant is smaller than half of either of its two terms (cancellation).
let max_piv = d11.abs().max(d21.abs()).max(d22.abs());
let det_floor_fail = if max_piv < SSIDS_DET_SMALL {
true
} else {
let det_scale = 1.0 / max_piv;
let detpiv0 = (d11 * det_scale) * d22;
let detpiv1 = (d21 * det_scale) * d21;
let detpiv = detpiv0 - detpiv1;
let cancel_floor = SSIDS_DET_SMALL
.max(detpiv0.abs() * 0.5)
.max(detpiv1.abs() * 0.5);
detpiv.abs() < cancel_floor
};
if growth_fail || det_floor_fail {
diag_inc(&panel_diag::FALLBACK_2X2_GROWTH_OR_DET);
// Only the no-swap branch replayed the partner column.
if need_swap {
return Ok((c, PanelStatus::ScalarFallback));
}
return Ok((c, PanelStatus::ScalarFallbackPeekedNext));
}
if need_swap {
diag_inc(&panel_diag::INLINE_2X2_SWAP_OK);
}
let pivot_inertia = count_2x2_inertia_val(d11, d21, d22);
*pos += pivot_inertia.positive;
*neg += pivot_inertia.negative;
*zero += pivot_inertia.zero;
d_panel[c] = d11;
d_panel[c + 1] = d22;
subdiag[k + c] = d21;
// Overwrite the two columns below the block with the L entries,
// i.e. the stored rows multiplied by the inverse of the 2x2 block.
if det.abs() != 0.0 {
let inv_det = 1.0 / det;
for i in (col + 2)..nrow {
let a_ik = a[col * nrow + i];
let a_ik1 = a[(col + 1) * nrow + i];
a[col * nrow + i] = (d22 * a_ik - d21 * a_ik1) * inv_det;
a[(col + 1) * nrow + i] = (d11 * a_ik1 - d21 * a_ik) * inv_det;
}
}
c += 2;
continue;
}
// Diagonal is large enough relative to its column: threshold-test it as a 1x1 pivot.
let outcome = try_reject_1x1_frontal(
a,
nrow,
col,
gamma0,
may_delay,
params,
pos,
neg,
zero,
needs_refinement,
)?;
match outcome {
PivotOutcome::Accepted => {
// Scale the column by 1/d to form the L entries.
let d = a[col * nrow + col];
if d.abs() != 0.0 {
let inv_d = 1.0 / d;
for i in (col + 1)..nrow {
a[col * nrow + i] *= inv_d;
}
}
d_panel[c] = d;
}
PivotOutcome::Rejected => {
// A zero pivot contributes no update; record it as 0 so later replays skip it.
d_panel[c] = 0.0;
}
PivotOutcome::Delayed => {
return Ok((c, PanelStatus::Delayed));
}
PivotOutcome::AcceptedRook2x2 { .. } => {
unreachable!("panel path never enables rook rescue")
}
}
c += 1;
}
Ok((c, PanelStatus::Full))
}
/// Brings panel column `k + c` up to date by replaying the `c` pivots
/// already committed in the panel that starts at column `k`.
///
/// Shorthand for `peek_ahead_replay` with `target_col = k + c`.
fn peek_ahead_column(
a: &mut [f64],
nrow: usize,
k: usize,
c: usize,
d_panel: &[f64],
subdiag: &[f64],
) {
peek_ahead_replay(a, nrow, k, c, k + c, d_panel, subdiag);
}
/// Applies the updates of the first `n_committed` pivots of the panel that
/// starts at column `k` to rows `target_col..nrow` of column `target_col`.
///
/// `d_panel[q]` holds the pivot value of panel position `q`; a non-zero
/// `subdiag[k + q]` marks positions `q` and `q + 1` as one 2x2 pivot.
/// 1x1 pivots whose value or whose multiplier is exactly zero are skipped.
fn peek_ahead_replay(
    a: &mut [f64],
    nrow: usize,
    k: usize,
    n_committed: usize,
    target_col: usize,
    d_panel: &[f64],
    subdiag: &[f64],
) {
    let dst_base = target_col * nrow;
    let mut done = 0usize;
    while done < n_committed {
        let piv_base = (k + done) * nrow;
        let paired = done + 1 < n_committed && subdiag[k + done] != 0.0;

        if paired {
            // 2x2 pivot: combine both L columns with the block of D.
            let next_base = (k + done + 1) * nrow;
            let d11 = d_panel[done];
            let d22 = d_panel[done + 1];
            let d21 = subdiag[k + done];
            let l0 = a[piv_base + target_col];
            let l1 = a[next_base + target_col];
            let w0 = d11 * l0 + d21 * l1;
            let w1 = d21 * l0 + d22 * l1;
            let (head, tail) = a.split_at_mut(dst_base);
            let col0 = &head[piv_base + target_col..piv_base + nrow];
            let col1 = &head[next_base + target_col..next_base + nrow];
            let target = &mut tail[target_col..nrow];
            schur_kernel::axpy2_minus_unroll4_nofma(target, col0, w0, col1, w1);
            done += 2;
            continue;
        }

        // 1x1 pivot.
        let d = d_panel[done];
        if d.abs() != 0.0 {
            let scale = a[piv_base + target_col] * d;
            if scale != 0.0 {
                let (head, tail) = a.split_at_mut(dst_base);
                let source = &head[piv_base + target_col..piv_base + nrow];
                let target = &mut tail[target_col..nrow];
                schur_kernel::axpy_minus_unroll4_nofma(target, source, scale);
            }
        }
        done += 1;
    }
}
/// Applies the Schur-complement update of one factorized panel to the
/// trailing columns `j_start..nrow`.
///
/// The panel starts at column `k` and contains `n_elim` eliminated columns;
/// `d_panel[q]` is the pivot value of panel position `q`, and a non-zero
/// `subdiag[k + q]` marks positions `q` and `q + 1` as one 2x2 pivot.
///
/// When the panel consists solely of non-zero 1x1 pivots and its width fits
/// the fast kernel's scratch buffers, the whole panel is applied at once by
/// `apply_blocked_schur_panel`. Otherwise pivots are applied one at a time.
fn apply_blocked_schur(
    a: &mut [f64],
    nrow: usize,
    k: usize,
    n_elim: usize,
    j_start: usize,
    d_panel: &[f64],
    subdiag: &[f64],
) {
    if n_elim == 0 || j_start >= nrow {
        return;
    }
    const W2_RANK_BS_MIN: usize = 2;
    // Upper bound for the fast path. It must not exceed the fixed scratch
    // capacity (`MAX_N_ELIM`) of `apply_blocked_schur_panel`: that helper only
    // debug-asserts the bound, so a wider panel (possible when
    // `params.block_size > 64`) would panic on its scratch slices in release
    // builds. Wider panels use the per-pivot loop below instead.
    const W2_RANK_BS_MAX: usize = 64;
    let any_zero_d = d_panel.iter().take(n_elim).any(|&d| d.abs() == 0.0);
    let has_2x2 = subdiag[k..k + n_elim].iter().any(|&s| s != 0.0);
    let fits_fast_path = (W2_RANK_BS_MIN..=W2_RANK_BS_MAX).contains(&n_elim);
    if !any_zero_d && !has_2x2 && fits_fast_path {
        apply_blocked_schur_panel(a, nrow, k, n_elim, j_start, d_panel);
        return;
    }
    let mut q = 0usize;
    while q < n_elim {
        if q + 1 < n_elim && subdiag[k + q] != 0.0 {
            // 2x2 pivot on panel positions q and q + 1.
            let d11 = d_panel[q];
            let d22 = d_panel[q + 1];
            let d21 = subdiag[k + q];
            let q_col = k + q;
            let q1_col = k + q + 1;
            for j in j_start..nrow {
                let l_jq = a[q_col * nrow + j];
                let l_jq1 = a[q1_col * nrow + j];
                let dl_jq = d11 * l_jq + d21 * l_jq1;
                let dl_jq1 = d21 * l_jq + d22 * l_jq1;
                // Split so the L columns (before column j) can be read while
                // column j is written.
                let (before, rest) = a.split_at_mut(j * nrow);
                let src0 = &before[q_col * nrow + j..q_col * nrow + nrow];
                let src1 = &before[q1_col * nrow + j..q1_col * nrow + nrow];
                let dst = &mut rest[j..nrow];
                schur_kernel::axpy2_minus_unroll4_nofma(dst, src0, dl_jq, src1, dl_jq1);
            }
            q += 2;
        } else {
            let d_q = d_panel[q];
            // A zero pivot contributes no update.
            if d_q.abs() == 0.0 {
                q += 1;
                continue;
            }
            let q_col = k + q;
            for j in j_start..nrow {
                let l_jk = a[q_col * nrow + j];
                let alpha = l_jk * d_q;
                if alpha == 0.0 {
                    continue;
                }
                let (before, rest) = a.split_at_mut(j * nrow);
                let src = &before[q_col * nrow + j..q_col * nrow + nrow];
                let dst = &mut rest[j..nrow];
                schur_kernel::axpy_minus_unroll4_nofma(dst, src, alpha);
            }
            q += 1;
        }
    }
}
/// Fast Schur-complement update for a panel made only of non-zero 1x1 pivots.
///
/// For each trailing column `j >= j_start` the multipliers
/// `alpha_q = a[(k + q) * nrow + j] * d_panel[q]` are gathered for all
/// `n_elim` panel columns and handed to a `schur_kernel` routine that
/// updates rows `j..nrow` of column `j`. Columns are processed two at a
/// time, with a single-column call for a leftover last column. A column
/// (or pair) whose multipliers are all exactly zero is skipped.
///
/// `n_elim` must not exceed `MAX_N_ELIM`; this is only checked by a
/// `debug_assert!`, and a larger value makes the scratch-slice indexing
/// below panic in builds without debug assertions.
fn apply_blocked_schur_panel(
a: &mut [f64],
nrow: usize,
k: usize,
n_elim: usize,
j_start: usize,
d_panel: &[f64],
) {
const MAX_N_ELIM: usize = 64;
debug_assert!(
n_elim <= MAX_N_ELIM,
"apply_blocked_schur_panel: n_elim {} exceeds MAX_N_ELIM {}",
n_elim,
MAX_N_ELIM
);
// Stack scratch for the multipliers of the current column pair.
let mut alphas0_buf = [0.0f64; MAX_N_ELIM];
let mut alphas1_buf = [0.0f64; MAX_N_ELIM];
let mut j = j_start;
// Main loop: columns j and j + 1 together.
while j + 1 < nrow {
let alphas0 = &mut alphas0_buf[..n_elim];
let alphas1 = &mut alphas1_buf[..n_elim];
let mut all_zero = true;
for q in 0..n_elim {
let q_col = k + q;
let l_jk0 = a[q_col * nrow + j];
let l_jk1 = a[q_col * nrow + j + 1];
let d_q = d_panel[q];
let alpha0 = l_jk0 * d_q;
let alpha1 = l_jk1 * d_q;
alphas0[q] = alpha0;
alphas1[q] = alpha1;
if alpha0 != 0.0 || alpha1 != 0.0 {
all_zero = false;
}
}
if !all_zero {
// `before` holds the columns ahead of j (including the panel's L columns),
// `col_j` is column j and `after_j` starts at column j + 1.
let (before, rest) = a.split_at_mut(j * nrow);
let (col_j, after_j) = rest.split_at_mut(nrow);
let dst0 = &mut col_j[j..];
let dst1 = &mut after_j[(j + 1)..nrow];
schur_kernel::schur_panel_minus_nofma_strided_dual(
dst0, dst1, before, k, n_elim, nrow, j, alphas0, alphas1,
);
}
j += 2;
}
// Leftover single column when the number of trailing columns is odd.
if j < nrow {
let alphas = &mut alphas0_buf[..n_elim];
let mut all_zero_alpha = true;
for q in 0..n_elim {
let q_col = k + q;
let l_jk = a[q_col * nrow + j];
let d_q = d_panel[q];
let alpha = l_jk * d_q;
alphas[q] = alpha;
if alpha != 0.0 {
all_zero_alpha = false;
}
}
if !all_zero_alpha {
let trailing_len = nrow - j;
let (before, rest) = a.split_at_mut(j * nrow);
let dst = &mut rest[j..nrow];
schur_kernel::schur_panel_minus_nofma_strided(
dst,
before,
k,
n_elim,
nrow,
j,
trailing_len,
alphas,
);
}
}
}
/// Completes a pivot step once the pivot decision is known.
///
/// An accepted 1x1 pivot gets its trailing update. A rook-rescue 2x2 pivot
/// has its inertia counted, its off-diagonal recorded in `subdiag[k]` and its
/// 2x2 update applied. A rejected pivot needs no further work. Returns how
/// far the caller should advance, or `Delayed`.
#[inline]
#[allow(clippy::too_many_arguments)]
fn finish_1x1_outcome(
    outcome: PivotOutcome,
    a: &mut [f64],
    nrow: usize,
    k: usize,
    subdiag: &mut [f64],
    pos: &mut usize,
    neg: &mut usize,
    zero: &mut usize,
) -> PivotStepResult {
    let consumed = match outcome {
        PivotOutcome::Delayed => return PivotStepResult::Delayed,
        PivotOutcome::Rejected => 1,
        PivotOutcome::Accepted => {
            do_1x1_update(a, nrow, k);
            1
        }
        PivotOutcome::AcceptedRook2x2 { d11, d21, d22 } => {
            let block_inertia = count_2x2_inertia_val(d11, d21, d22);
            *pos += block_inertia.positive;
            *neg += block_inertia.negative;
            *zero += block_inertia.zero;
            subdiag[k] = d21;
            do_2x2_update(a, nrow, k, d11, d21, d22);
            2
        }
    };
    PivotStepResult::Advanced(consumed)
}
/// Performs one pivot step at column `k` of a frontal matrix whose trailing
/// part is fully up to date.
///
/// Only the first `ncol` rows/columns are candidates for elimination, so
/// swaps and 2x2 partners are restricted to rows `< ncol`. Depending on the
/// pivot tests the step takes a 1x1 pivot (possibly after a swap, possibly
/// via rook rescue) or a 2x2 pivot, and applies the corresponding trailing
/// update.
///
/// Returns `Advanced(1)` or `Advanced(2)` for the number of columns consumed,
/// or `Delayed` when the pivot has to be left to a later front.
///
/// # Errors
/// Returns `FeralError::NumericallyRankDeficient` when a 2x2 candidate fails
/// the determinant floor, delaying is not allowed and `on_zero_pivot` is
/// `Fail`; also propagates errors from the 1x1 pivot helpers.
#[allow(clippy::too_many_arguments)]
fn scalar_pivot_step(
a: &mut [f64],
nrow: usize,
ncol: usize,
k: usize,
may_delay: bool,
params: &BunchKaufmanParams,
perm: &mut [usize],
subdiag: &mut [f64],
pos: &mut usize,
neg: &mut usize,
zero: &mut usize,
needs_refinement: &mut bool,
n_rook_rescues: &mut usize,
) -> Result<PivotStepResult, FeralError> {
let alpha = params.alpha;
let remaining = ncol - k;
// Last candidate column: no partner is available, so only the plain
// threshold test on the diagonal is made.
if remaining == 1 {
let mut col_max = 0.0f64;
for i in (k + 1)..nrow {
let v = a[k * nrow + i].abs();
if v > col_max {
col_max = v;
}
}
let outcome = try_reject_1x1_frontal(
a,
nrow,
k,
col_max,
may_delay,
params,
pos,
neg,
zero,
needs_refinement,
)?;
match outcome {
PivotOutcome::Accepted => do_1x1_update(a, nrow, k),
PivotOutcome::Rejected => {}
PivotOutcome::Delayed => return Ok(PivotStepResult::Delayed),
PivotOutcome::AcceptedRook2x2 { .. } => {
unreachable!("remaining==1 never triggers rook rescue")
}
}
return Ok(PivotStepResult::Advanced(1));
}
// gamma0: largest off-diagonal magnitude in column k over all rows; r: its row.
// The two loops together scan rows k + 1 .. nrow in order.
let (gamma0, r) = {
let mut max_val = 0.0f64;
let mut max_row = k + 1;
for i in (k + 1)..ncol {
let v = a[k * nrow + i].abs();
if v > max_val {
max_val = v;
max_row = i;
}
}
for i in ncol..nrow {
let v = a[k * nrow + i].abs();
if v > max_val {
max_val = v;
max_row = i;
}
}
(max_val, max_row)
};
// Column is already diagonal: count the pivot, no update needed.
if gamma0 == 0.0 {
let d = a[k * nrow + k];
count_1x1_inertia(d, params, pos, neg, zero, needs_refinement)?;
set_l_column_identity(a, nrow, k);
return Ok(PivotStepResult::Advanced(1));
}
let akk = a[k * nrow + k].abs();
// Case 1: the diagonal dominates its column.
if akk >= alpha * gamma0 {
let outcome = try_reject_1x1_with_rook_rescue(
a,
nrow,
ncol,
k,
gamma0,
may_delay,
params,
perm,
pos,
neg,
zero,
needs_refinement,
n_rook_rescues,
)?;
return Ok(finish_1x1_outcome(
outcome, a, nrow, k, subdiag, pos, neg, zero,
));
}
let gamma_r = symmetric_row_offdiag_max(a, nrow, k, r);
let arr = a[r * nrow + r].abs();
// Row r may only be swapped into the pivot position if it is a candidate row.
let r_is_fully_summed = r < ncol;
// Case 2: the diagonal of row r dominates, swap it to position k.
if r_is_fully_summed && arr >= alpha * gamma_r {
swap_rows_cols(a, nrow, k, r, perm);
let outcome = try_reject_1x1_with_rook_rescue(
a,
nrow,
ncol,
k,
gamma_r,
may_delay,
params,
perm,
pos,
neg,
zero,
needs_refinement,
n_rook_rescues,
)?;
return Ok(finish_1x1_outcome(
outcome, a, nrow, k, subdiag, pos, neg, zero,
));
}
// Case 3: a_kk is still acceptable relative to gamma_r, 1x1 pivot without a swap.
if akk * gamma_r >= alpha * gamma0 * gamma0 {
let outcome = try_reject_1x1_with_rook_rescue(
a,
nrow,
ncol,
k,
gamma0,
may_delay,
params,
perm,
pos,
neg,
zero,
needs_refinement,
n_rook_rescues,
)?;
return Ok(finish_1x1_outcome(
outcome, a, nrow, k, subdiag, pos, neg, zero,
));
}
// Case 4: 2x2 pivot on (k, r), possible only when r is a candidate row.
if r_is_fully_summed && k + 1 < ncol {
if r != k + 1 {
swap_rows_cols(a, nrow, k + 1, r, perm);
}
let d11 = a[k * nrow + k];
let d21 = a[k * nrow + (k + 1)];
let d22 = a[(k + 1) * nrow + (k + 1)];
let det = d11 * d22 - d21 * d21;
// rmax / tmax: largest magnitudes below the 2x2 block in its two columns.
let mut rmax = 0.0f64;
let mut tmax = 0.0f64;
for i in (k + 2)..nrow {
let v0 = a[k * nrow + i].abs();
if v0 > rmax {
rmax = v0;
}
let v1 = a[(k + 1) * nrow + i].abs();
if v1 > tmax {
tmax = v1;
}
}
let amax = d21.abs();
let absdet = det.abs();
let u = params.pivot_threshold;
let growth_fail = (d22.abs() * rmax + amax * tmax) * u > absdet
|| (d11.abs() * tmax + amax * rmax) * u > absdet;
// Determinant floor: refuse the block when it is tiny, or when the scaled
// determinant is smaller than half of either of its two terms (cancellation).
let max_piv = d11.abs().max(d21.abs()).max(d22.abs());
let det_floor_fail = if max_piv < SSIDS_DET_SMALL {
true
} else {
let det_scale = 1.0 / max_piv;
let detpiv0 = (d11 * det_scale) * d22;
let detpiv1 = (d21 * det_scale) * d21;
let detpiv = detpiv0 - detpiv1;
let cancel_floor = SSIDS_DET_SMALL
.max(detpiv0.abs() * 0.5)
.max(detpiv1.abs() * 0.5);
detpiv.abs() < cancel_floor
};
if growth_fail || det_floor_fail {
// The 2x2 block is unusable: delay if allowed, otherwise fall back to a
// 1x1 pivot on column k (the swap above stays in place).
if may_delay {
return Ok(PivotStepResult::Delayed);
}
if det_floor_fail {
match params.on_zero_pivot {
ZeroPivotAction::Fail => {
return Err(FeralError::NumericallyRankDeficient);
}
ZeroPivotAction::ForceAccept => {
*needs_refinement = true;
}
}
}
let outcome = try_reject_1x1_with_rook_rescue(
a,
nrow,
ncol,
k,
gamma0,
may_delay,
params,
perm,
pos,
neg,
zero,
needs_refinement,
n_rook_rescues,
)?;
return Ok(finish_1x1_outcome(
outcome, a, nrow, k, subdiag, pos, neg, zero,
));
}
let pivot_inertia = count_2x2_inertia_val(d11, d21, d22);
*pos += pivot_inertia.positive;
*neg += pivot_inertia.negative;
*zero += pivot_inertia.zero;
subdiag[k] = d21;
do_2x2_update(a, nrow, k, d11, d21, d22);
Ok(PivotStepResult::Advanced(2))
} else {
// The partner row is not a candidate row, so no 2x2 pivot is possible:
// fall back to a 1x1 pivot on column k.
let outcome = try_reject_1x1_with_rook_rescue(
a,
nrow,
ncol,
k,
gamma0,
may_delay,
params,
perm,
pos,
neg,
zero,
needs_refinement,
n_rook_rescues,
)?;
Ok(finish_1x1_outcome(
outcome, a, nrow, k, subdiag, pos, neg, zero,
))
}
}
/// Threshold-tests the diagonal entry of column `k` as a 1x1 pivot.
///
/// With `d = a[k * nrow + k]` and
/// `threshold = max(pivot_threshold * col_max, zero_tol)`:
///
/// * `|d| > threshold`: accepted, inertia counted.
/// * `|d| <= threshold` and `may_delay`: `Delayed`, nothing is modified.
/// * `|d| <= zero_tol` (no delay): under `ForceAccept` the pivot is counted
///   as zero, the diagonal and the rows below it in column `k` are set to
///   zero and `Rejected` is returned; under `Fail` an error is returned.
/// * otherwise (small but non-zero, no delay): accepted with
///   `needs_refinement` set.
///
/// # Errors
/// `FeralError::NumericallyRankDeficient` for a zero pivot under
/// `ZeroPivotAction::Fail` when delaying is not allowed.
#[allow(clippy::too_many_arguments)]
fn try_reject_1x1_frontal(
    a: &mut [f64],
    nrow: usize,
    k: usize,
    col_max: f64,
    may_delay: bool,
    params: &BunchKaufmanParams,
    pos: &mut usize,
    neg: &mut usize,
    zero: &mut usize,
    needs_refinement: &mut bool,
) -> Result<PivotOutcome, FeralError> {
    let diag_idx = k * nrow + k;
    let d = a[diag_idx];
    let threshold = (params.pivot_threshold * col_max).max(params.zero_tol);

    if d.abs() <= threshold {
        if may_delay {
            return Ok(PivotOutcome::Delayed);
        }
        if d.abs() <= params.zero_tol {
            match params.on_zero_pivot {
                ZeroPivotAction::Fail => return Err(FeralError::NumericallyRankDeficient),
                ZeroPivotAction::ForceAccept => {}
            }
            *needs_refinement = true;
            *zero += 1;
            // Zero the diagonal and everything below it in column k.
            a[diag_idx..(k + 1) * nrow].fill(0.0);
            return Ok(PivotOutcome::Rejected);
        }
        // Small but usable pivot: accept it and ask for refinement.
        *needs_refinement = true;
    }

    if d > 0.0 {
        *pos += 1;
    } else {
        *neg += 1;
    }
    Ok(PivotOutcome::Accepted)
}
/// 1x1 pivot test with a rook-pivoting rescue attempt before giving up.
///
/// When the diagonal entry at `(k, k)` exceeds
/// `max(pivot_threshold * col_max, zero_tol)` the call is forwarded straight to
/// `try_reject_1x1_frontal`. Otherwise `rook_rescue` is consulted first:
///
/// * if it proposes a pivot, `n_rook_rescues` is incremented and the proposed
///   symmetric swaps are applied to `a` and `perm` in order. A 1x1 proposal
///   bumps `pos`/`neg` from the sign of the new diagonal and returns
///   `Accepted`; a 2x2 proposal returns `AcceptedRook2x2` carrying the three
///   entries of the block at rows/columns `k`, `k + 1` (no inertia counters
///   are touched in that case).
/// * if it proposes nothing, the decision falls back to
///   `try_reject_1x1_frontal` with the original arguments.
///
/// # Errors
///
/// Propagates whatever `try_reject_1x1_frontal` returns.
#[allow(clippy::too_many_arguments)]
fn try_reject_1x1_with_rook_rescue(
    a: &mut [f64],
    nrow: usize,
    ncol: usize,
    k: usize,
    col_max: f64,
    may_delay: bool,
    params: &BunchKaufmanParams,
    perm: &mut [usize],
    pos: &mut usize,
    neg: &mut usize,
    zero: &mut usize,
    needs_refinement: &mut bool,
    n_rook_rescues: &mut usize,
) -> Result<PivotOutcome, FeralError> {
    let top = k * nrow + k;
    let accept_bound = (params.pivot_threshold * col_max).max(params.zero_tol);
    // Kept as a `>` test so that a NaN diagonal still takes the rescue path.
    let passes_threshold = a[top].abs() > accept_bound;

    if !passes_threshold {
        if let Some(rescue) = rook_rescue(a, nrow, ncol, k, params) {
            *n_rook_rescues += 1;
            for step in 0..rescue.n_swaps {
                let (p, q) = rescue.swaps[step];
                swap_rows_cols(a, nrow, p, q, perm);
            }

            // Entries are read only after all swaps have been applied.
            return Ok(match rescue.kind {
                RookKind::Pivot2x2 => PivotOutcome::AcceptedRook2x2 {
                    d11: a[top],
                    d21: a[top + 1],
                    d22: a[(k + 1) * nrow + (k + 1)],
                },
                RookKind::Pivot1x1 => {
                    if a[top] > 0.0 {
                        *pos += 1;
                    } else {
                        *neg += 1;
                    }
                    PivotOutcome::Accepted
                }
            });
        }
    }

    // Either the pivot passed outright or no rescue was available.
    try_reject_1x1_frontal(
        a,
        nrow,
        k,
        col_max,
        may_delay,
        params,
        pos,
        neg,
        zero,
        needs_refinement,
    )
}
/// Eliminates with the 1x1 pivot at `(k, k)` of the `n`-strided matrix `a`.
///
/// The entries `a[k * n + i]`, `i > k`, are multiplied by the reciprocal of
/// the pivot `d`; then for every `j > k` the slice `a[j * n + j..j * n + n]`
/// is updated through `schur_kernel::axpy_minus_unroll4_nofma` using the
/// scaled column tail and the coefficient `a[k * n + j] * d`.
///
/// A pivot equal to zero leaves `a` untouched.
fn do_1x1_update(a: &mut [f64], n: usize, k: usize) {
    let col = k * n;
    let pivot = a[col + k];
    if pivot == 0.0 {
        return;
    }

    let recip = 1.0 / pivot;
    for entry in a[col..col + n].iter_mut().skip(k + 1) {
        *entry *= recip;
    }

    for j in (k + 1)..n {
        // Scaled multiplier times the pivot restores the original entry weight.
        let coeff = a[col + j] * pivot;
        // Column k lives strictly before column j, so the split separates the
        // read-only multiplier column from the column being updated.
        let (head, tail) = a.split_at_mut(j * n);
        schur_kernel::axpy_minus_unroll4_nofma(&mut tail[j..n], &head[col + j..col + n], coeff);
    }
}
/// Eliminates with the symmetric 2x2 pivot `[[d11, d21], [d21, d22]]` located
/// at rows/columns `k`, `k + 1` of the `n`-strided matrix `a`.
///
/// For every row `i >= k + 2` the pair `(a[k * n + i], a[(k + 1) * n + i])` is
/// replaced by that pair multiplied by the inverse of the pivot block. Then
/// each trailing column `j >= k + 2` (entries `j..n`) is updated through
/// `schur_kernel::axpy2_minus_unroll4_nofma` with the two multiplier columns
/// and the coefficients `D * (l_j0, l_j1)`.
///
/// A pivot block whose determinant is exactly zero leaves `a` untouched.
fn do_2x2_update(a: &mut [f64], n: usize, k: usize, d11: f64, d21: f64, d22: f64) {
    let det = d11 * d22 - d21 * d21;
    if det == 0.0 {
        return;
    }
    let inv_det = 1.0 / det;
    let (c0, c1) = (k * n, (k + 1) * n);

    // Multiplier rows: (x0, x1) * D^{-1}.
    for row in (k + 2)..n {
        let (x0, x1) = (a[c0 + row], a[c1 + row]);
        a[c0 + row] = (d22 * x0 - d21 * x1) * inv_det;
        a[c1 + row] = (d11 * x1 - d21 * x0) * inv_det;
    }

    // Trailing update, one column at a time.
    for j in (k + 2)..n {
        let (l0, l1) = (a[c0 + j], a[c1 + j]);
        let w0 = d11 * l0 + d21 * l1;
        let w1 = d21 * l0 + d22 * l1;
        let (head, tail) = a.split_at_mut(j * n);
        schur_kernel::axpy2_minus_unroll4_nofma(
            &mut tail[j..n],
            &head[c0 + j..c0 + n],
            w0,
            &head[c1 + j..c1 + n],
            w1,
        );
    }
}
/// Classifies the eigenvalue signs of the symmetric block
/// `[[d11, d21], [d21, d22]]` from its determinant and trace.
///
/// * negative determinant: one positive, one negative;
/// * positive determinant: two positive when the trace is positive, otherwise
///   two negative;
/// * determinant neither positive nor negative: one zero plus the sign of the
///   trace, or two zeros when the trace is neither positive nor negative.
///
/// No tolerance is applied; comparisons are against exact `0.0`.
fn count_2x2_inertia_val(d11: f64, d21: f64, d22: f64) -> Inertia {
    let det = d11 * d22 - d21 * d21;
    let trace = d11 + d22;

    let (positive, negative, zero) = if det < 0.0 {
        (1, 1, 0)
    } else if det > 0.0 {
        if trace > 0.0 {
            (2, 0, 0)
        } else {
            (0, 2, 0)
        }
    } else if trace > 0.0 {
        (1, 0, 1)
    } else if trace < 0.0 {
        (0, 1, 1)
    } else {
        (0, 0, 2)
    };

    Inertia::new(positive, negative, zero)
}
/// Scans `a[k * n + i]` for `i` in `k + 1..n` and returns the largest absolute
/// value together with the index `i` where it first occurs.
///
/// When the range is empty or no entry has a magnitude above zero the result
/// is `(0.0, k + 1)`. NaN entries never win the comparison and are skipped.
fn column_offdiag_max(a: &[f64], n: usize, k: usize) -> (f64, usize) {
    ((k + 1)..n).fold((0.0, k + 1), |(best, best_row), row| {
        let magnitude = a[k * n + row].abs();
        // Strict comparison keeps the first occurrence on ties.
        if magnitude > best {
            (magnitude, row)
        } else {
            (best, best_row)
        }
    })
}
/// Largest absolute off-diagonal value associated with index `r`, restricted
/// to indices `>= k`.
///
/// Two runs of entries are inspected: `a[r * n + i]` for `i` in `r + 1..n`,
/// and `a[j * n + r]` for `j` in `k..r`. Returns `0.0` when both runs are
/// empty or contain nothing larger than zero; NaN entries are skipped.
fn symmetric_row_offdiag_max(a: &[f64], n: usize, k: usize, r: usize) -> f64 {
    let below_diagonal = ((r + 1)..n).map(|i| a[r * n + i]);
    let left_of_diagonal = (k..r).map(|j| a[j * n + r]);

    below_diagonal
        .chain(left_of_diagonal)
        .fold(0.0, |best, entry| {
            let magnitude = entry.abs();
            if magnitude > best {
                magnitude
            } else {
                best
            }
        })
}
/// Symmetrically interchanges indices `p` and `q` of the `n`-strided matrix
/// `a` and records the interchange in `perm`.
///
/// Only entries `a[c * n + r]` with `r >= c` are touched. With `lo < hi` the
/// smaller/larger of the two indices, the exchange consists of:
///
/// * the two diagonal entries;
/// * for every `c < lo`, the entries at rows `lo` and `hi` of column `c`;
/// * for every `lo < m < hi`, entry `(m, lo)` with entry `(hi, m)`;
/// * for every `r > hi`, the entries at row `r` of columns `lo` and `hi`.
///
/// The entry `(hi, lo)` stays in place. Calling with `p == q` is a no-op.
fn swap_rows_cols(a: &mut [f64], n: usize, p: usize, q: usize, perm: &mut [usize]) {
    if p == q {
        return;
    }
    let lo = p.min(q);
    let hi = p.max(q);

    perm.swap(lo, hi);
    a.swap(lo * n + lo, hi * n + hi);

    // Columns to the left of `lo`: exchange rows lo and hi.
    for col in 0..lo {
        let base = col * n;
        a.swap(base + lo, base + hi);
    }
    // Strictly between: column `lo` segment against row `hi` segment.
    for mid in (lo + 1)..hi {
        a.swap(lo * n + mid, mid * n + hi);
    }
    // Rows below `hi`: exchange columns lo and hi.
    for row in (hi + 1)..n {
        a.swap(lo * n + row, hi * n + row);
    }
}
/// Performs one 1x1 elimination step at `(k, k)` of the `n`-strided matrix `a`
/// and reports pivot-search data for the next column.
///
/// If `|d| <= max(pivot_threshold * col_max, zero_tol)` the pivot is handled
/// as a zero pivot: under `ZeroPivotAction::ForceAccept` the `zero` counter
/// and `needs_refinement` are updated, the diagonal and the entries below it
/// in column `k` are set to `0.0`, and `(0.0, k + 2)` is returned without any
/// trailing update.
///
/// Otherwise `pos`/`neg` is incremented from the sign of `d`, the entries
/// below the diagonal are multiplied by `1 / d`, and every trailing column
/// `j > k` (rows `j..n`) has `l_ik * (l_jk * d)` subtracted. While column
/// `k + 1` is updated, the largest absolute value strictly below its diagonal
/// and the row where it first occurs are collected and returned; the defaults
/// are `(0.0, k + 2)`.
///
/// # Errors
///
/// Returns `FeralError::NumericallyRankDeficient` for a pivot that fails the
/// threshold when the policy is `ZeroPivotAction::Fail`.
#[allow(clippy::too_many_arguments)]
fn do_1x1_pivot(
    a: &mut [f64],
    n: usize,
    k: usize,
    col_max: f64,
    params: &BunchKaufmanParams,
    pos: &mut usize,
    neg: &mut usize,
    zero: &mut usize,
    needs_refinement: &mut bool,
) -> Result<(f64, usize), FeralError> {
    let col = k * n;
    let pivot = a[col + k];
    let accept_bound = (params.pivot_threshold * col_max).max(params.zero_tol);

    if pivot.abs() <= accept_bound {
        match params.on_zero_pivot {
            ZeroPivotAction::Fail => return Err(FeralError::NumericallyRankDeficient),
            ZeroPivotAction::ForceAccept => {
                *needs_refinement = true;
                *zero += 1;
            }
        }
        a[col + k] = 0.0;
        for row in (k + 1)..n {
            a[col + row] = 0.0;
        }
        return Ok((0.0, k + 2));
    }

    if pivot > 0.0 {
        *pos += 1;
    } else {
        *neg += 1;
    }

    let recip = 1.0 / pivot;
    for row in (k + 1)..n {
        a[col + row] *= recip;
    }

    let mut next_gamma0 = 0.0;
    let mut next_r = k + 2;
    for j in (k + 1)..n {
        let weight = a[col + j] * pivot;
        // Column k precedes column j in memory, so the split hands out the
        // multipliers read-only and the target column mutably.
        let (head, tail) = a.split_at_mut(j * n);
        let multipliers = &head[col + j..col + n];
        let target = &mut tail[j..n];
        for (dst, &l_ik) in target.iter_mut().zip(multipliers) {
            *dst -= l_ik * weight;
        }

        // Only the immediately following column feeds the next pivot search;
        // its diagonal (offset 0) is excluded from the scan.
        if j == k + 1 {
            for (offset, entry) in target.iter().enumerate().skip(1) {
                let magnitude = entry.abs();
                if magnitude > next_gamma0 {
                    next_gamma0 = magnitude;
                    next_r = j + offset;
                }
            }
        }
    }

    Ok((next_gamma0, next_r))
}
/// Performs one 2x2 elimination step with the block at rows/columns `k`,
/// `k + 1` of the `n`-strided matrix `a`.
///
/// The off-diagonal entry of the block is stored in `subdiag[k]` and the
/// block's inertia is accumulated through `count_2x2_inertia` before anything
/// else happens. Afterwards:
///
/// * if no rows remain below the block, `(0.0, 0)` is returned;
/// * if the off-diagonal magnitude is below `f64::EPSILON * 1e-10`, the two
///   multiplier columns below the block are zeroed and `(0.0, k + 3)` is
///   returned without a trailing update;
/// * otherwise the block is scaled by the off-diagonal magnitude, and for each
///   trailing column `j >= k + 2` the pair of multipliers `(w0, w1)` is formed
///   from the still-unscaled entries, rows `j..n` of column `j` are updated,
///   and only then are the multipliers written back into columns `k` and
///   `k + 1`. While column `k + 2` is processed, the largest absolute value
///   strictly below its diagonal and its first row index are collected and
///   returned; the defaults are `(0.0, k + 3)`.
///
/// # Errors
///
/// Propagates the error returned by `count_2x2_inertia`.
#[allow(clippy::too_many_arguments)]
fn do_2x2_pivot(
    a: &mut [f64],
    n: usize,
    k: usize,
    subdiag: &mut [f64],
    params: &BunchKaufmanParams,
    pos: &mut usize,
    neg: &mut usize,
    zero: &mut usize,
    needs_refinement: &mut bool,
) -> Result<(f64, usize), FeralError> {
    let c0 = k * n;
    let c1 = (k + 1) * n;
    let a00 = a[c0 + k];
    let a10 = a[c0 + k + 1];
    let a11 = a[c1 + k + 1];

    subdiag[k] = a10;
    let det = a00 * a11 - a10 * a10;
    count_2x2_inertia(det, a00, a11, params, pos, neg, zero, needs_refinement)?;

    if k + 2 >= n {
        return Ok((0.0, 0));
    }

    let scale = a10.abs();
    if scale < f64::EPSILON * 1e-10 {
        for row in (k + 2)..n {
            a[c0 + row] = 0.0;
            a[c1 + row] = 0.0;
        }
        return Ok((0.0, k + 3));
    }

    // Block scaled by |a10| so that the off-diagonal becomes +/-1.
    let d00 = a00 / scale;
    let d11 = a11 / scale;
    let d10 = a10 / scale;
    let t = 1.0 / (d00 * d11 - 1.0);
    let d = t / scale;

    let mut next_gamma0 = 0.0;
    let mut next_r = k + 3;
    for j in (k + 2)..n {
        let x0 = a[c0 + j];
        let x1 = a[c1 + j];
        let w0 = (x0 * d11 - x1 * d10) * d;
        let w1 = (x1 * d00 - x0 * d10) * d;

        let (head, tail) = a.split_at_mut(j * n);
        let src0 = &head[c0 + j..c0 + n];
        let src1 = &head[c1 + j..c1 + n];
        let target = &mut tail[j..n];
        for ((dst, &u), &v) in target.iter_mut().zip(src0).zip(src1) {
            *dst -= u * w0 + v * w1;
        }

        // Only the first trailing column feeds the next pivot search; its
        // diagonal (offset 0) is excluded from the scan.
        if j == k + 2 {
            for (offset, entry) in target.iter().enumerate().skip(1) {
                let magnitude = entry.abs();
                if magnitude > next_gamma0 {
                    next_gamma0 = magnitude;
                    next_r = j + offset;
                }
            }
        }

        // Write-back happens last: later columns still need rows > j of the
        // unscaled data, and row j itself is no longer read after this point.
        a[c0 + j] = w0;
        a[c1 + j] = w1;
    }

    Ok((next_gamma0, next_r))
}
/// Adds the sign of the 1x1 pivot `d` to the running inertia counters.
///
/// A pivot with `|d| <= zero_tol` is a zero pivot: under
/// `ZeroPivotAction::ForceAccept` it increments `zero` and raises
/// `needs_refinement`. Any other pivot increments `pos` when `d > 0.0` and
/// `neg` otherwise.
///
/// # Errors
///
/// Returns `FeralError::NumericallyRankDeficient` for a zero pivot when the
/// policy is `ZeroPivotAction::Fail`; no counter is modified in that case.
fn count_1x1_inertia(
    d: f64,
    params: &BunchKaufmanParams,
    pos: &mut usize,
    neg: &mut usize,
    zero: &mut usize,
    needs_refinement: &mut bool,
) -> Result<(), FeralError> {
    if d.abs() <= params.zero_tol {
        return match params.on_zero_pivot {
            ZeroPivotAction::Fail => Err(FeralError::NumericallyRankDeficient),
            ZeroPivotAction::ForceAccept => {
                *needs_refinement = true;
                *zero += 1;
                Ok(())
            }
        };
    }

    let counter = if d > 0.0 { pos } else { neg };
    *counter += 1;
    Ok(())
}
/// Adds the inertia of a symmetric 2x2 pivot block to the running counters.
///
/// `det` is the block's determinant and `a00`, `a11` its diagonal entries.
///
/// * `|det| <= zero_tol_2x2`: the block is singular. Under
///   `ZeroPivotAction::ForceAccept`, `needs_refinement` is raised and the
///   block is counted as one zero eigenvalue plus one eigenvalue carrying the
///   sign of the trace. When the trace is neither positive nor negative (for
///   example an all-zero block) there is no signed eigenvalue to report, so
///   the block is counted as two zeros, matching `count_2x2_inertia_val`.
/// * `det > 0.0`: both eigenvalues share the sign of the trace.
/// * otherwise: one positive and one negative eigenvalue.
///
/// # Errors
///
/// Returns `FeralError::NumericallyRankDeficient` for a singular block when
/// the policy is `ZeroPivotAction::Fail`; no counter is modified in that case.
#[allow(clippy::too_many_arguments)]
fn count_2x2_inertia(
    det: f64,
    a00: f64,
    a11: f64,
    params: &BunchKaufmanParams,
    pos: &mut usize,
    neg: &mut usize,
    zero: &mut usize,
    needs_refinement: &mut bool,
) -> Result<(), FeralError> {
    let trace = a00 + a11;
    if det.abs() <= params.zero_tol_2x2 {
        match params.on_zero_pivot {
            ZeroPivotAction::ForceAccept => {
                *needs_refinement = true;
                if trace > 0.0 {
                    *pos += 1;
                    *zero += 1;
                } else if trace < 0.0 {
                    *neg += 1;
                    *zero += 1;
                } else {
                    // Zero determinant and zero trace: neither eigenvalue has
                    // a sign, so do not invent a negative one.
                    *zero += 2;
                }
                Ok(())
            }
            ZeroPivotAction::Fail => Err(FeralError::NumericallyRankDeficient),
        }
    } else if det > 0.0 {
        if trace > 0.0 {
            *pos += 2;
        } else {
            *neg += 2;
        }
        Ok(())
    } else {
        *pos += 1;
        *neg += 1;
        Ok(())
    }
}
/// Overwrites `a[k * n + i]` with `0.0` for every `i` in `k + 1..n`, leaving
/// the entry at `(k, k)` and all other columns untouched.
fn set_l_column_identity(a: &mut [f64], n: usize, k: usize) {
    if k + 1 < n {
        let col = k * n;
        a[col + k + 1..col + n]
            .iter_mut()
            .for_each(|entry| *entry = 0.0);
    }
}
#[cfg(test)]
mod growth_flag_tests {
    use super::*;

    /// Runs `flag_growth_for_refinement` over `entries` with the flag preset
    /// to `initial` and hands back the flag's final value.
    fn flag_after(entries: Vec<f64>, initial: bool) -> bool {
        let mut flag = initial;
        flag_growth_for_refinement(&entries, &mut flag);
        flag
    }

    #[test]
    fn growth_below_threshold_does_not_flag() {
        let raised = flag_after(vec![1.0, 2.78, -2.5, 0.0, 100.0, -999_999.0], false);
        assert!(!raised, "max|L| = 999_999 < 1e6 should not flag");
    }

    #[test]
    fn growth_above_threshold_flags() {
        let raised = flag_after(vec![1.0, 2.0, 1.5e6, -3.0], false);
        assert!(raised, "max|L| = 1.5e6 > 1e6 must flag");
    }

    #[test]
    fn catastrophic_growth_flags() {
        let raised = flag_after(vec![1.0, 1.0, 8.06e16, 1.0], false);
        assert!(raised, "max|L| = 8e16 (bratu3d-class) must flag");
    }

    #[test]
    fn negative_large_entry_flags() {
        let raised = flag_after(vec![-2e10, 1.0], false);
        assert!(raised, "negative large |L| must flag");
    }

    #[test]
    fn already_set_flag_is_preserved() {
        let raised = flag_after(vec![0.0, 0.0], true);
        assert!(raised, "must not clobber pre-set flag");
    }

    #[test]
    fn empty_l_does_not_flag() {
        let raised = flag_after(Vec::new(), false);
        assert!(!raised);
    }

    // NOTE(review): despite the name, only the Inf case is exercised here;
    // the NaN behaviour of `flag_growth_for_refinement` is not pinned by this
    // test.
    #[test]
    fn nan_and_inf_in_l_flag() {
        let raised = flag_after(vec![1.0, f64::INFINITY], false);
        assert!(raised, "Inf entry must trigger");
    }
}