use crate::dense::factor::BunchKaufmanParams;
/// Absolute tolerance used by the 2×2 pivot test (`passes_2x2_gates`): blocks whose
/// largest entry is below it are rejected, and it also floors the cancellation
/// threshold applied to the scaled determinant.
const SSIDS_DET_SMALL: f64 = 1.0e-20;
/// Maximum number of search steps `rook_rescue` takes before giving up; guarantees
/// termination if the search keeps bouncing between the same rows.
const MAX_ROOK_ITER: usize = 8;
/// Size of the pivot block selected by `rook_rescue`.
// NOTE(review): dead_code is allowed presumably because rook_rescue is not yet wired
// into the factorization; confirm and drop the attribute once it is.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum RookKind {
    /// One diagonal entry is used as the pivot (`RookPivot::single`).
    Pivot1x1,
    /// A symmetric 2×2 diagonal block is used as the pivot (`RookPivot::pair`).
    Pivot2x2,
}
/// Pivot chosen by `rook_rescue`, together with the row/column interchanges that bring
/// it to the current elimination position.
// NOTE(review): dead_code allowed presumably until the rook path is wired in — confirm.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct RookPivot {
    /// Whether this is a 1×1 or a 2×2 pivot.
    pub kind: RookKind,
    /// Interchanges `(i, j)` to apply, in order. For a 2×2 pivot the second entry is
    /// expressed in terms of positions *after* the first interchange. Entries at index
    /// `>= n_swaps` are unused and left as `(0, 0)`.
    pub swaps: [(usize, usize); 2],
    /// Number of meaningful entries at the front of `swaps` (0, 1 or 2).
    pub n_swaps: usize,
}
// NOTE(review): dead_code allowed presumably until the rook path is wired in — confirm.
#[allow(dead_code)]
impl RookPivot {
    /// 1×1 pivot on row/column `r`, moved to position `k` by one interchange
    /// (no interchange when `r == k`).
    fn single(k: usize, r: usize) -> Self {
        let mut pivot = Self {
            kind: RookKind::Pivot1x1,
            swaps: [(0, 0); 2],
            n_swaps: 0,
        };
        if r != k {
            pivot.swaps[0] = (k, r);
            pivot.n_swaps = 1;
        }
        pivot
    }

    /// 2×2 pivot on rows/columns `r` and `s`, arranged so that `r` lands at position `k`
    /// and `s` at position `k + 1`.
    ///
    /// The recorded interchanges must be applied in order: the second one refers to
    /// where `s` sits after the first interchange has exchanged `k` and `r`.
    fn pair(k: usize, r: usize, s: usize) -> Self {
        let mut pivot = Self {
            kind: RookKind::Pivot2x2,
            swaps: [(0, 0); 2],
            n_swaps: 0,
        };

        // Step 1: bring `r` to `k`, following `s` if that interchange displaced it.
        let s_pos = if r == k {
            s
        } else {
            pivot.swaps[0] = (k, r);
            pivot.n_swaps = 1;
            if s == k {
                r
            } else if s == r {
                k
            } else {
                s
            }
        };

        // Step 2: bring `s` to `k + 1` unless it is already there. This covers the
        // `r == k + 1, s == k` case, which needs just the single step-1 interchange.
        if s_pos != k + 1 {
            pivot.swaps[pivot.n_swaps] = (k + 1, s_pos);
            pivot.n_swaps += 1;
        }
        pivot
    }
}
/// Reads entry `(i, j)` of a symmetric matrix stored column-major with leading
/// dimension `nrow`, of which only the lower triangle (row >= column) is referenced.
#[inline]
fn sym_elem(a: &[f64], nrow: usize, i: usize, j: usize) -> f64 {
    let col = i.min(j);
    let row = i.max(j);
    a[col * nrow + row]
}
/// Largest off-diagonal magnitude in row `r` of the symmetric trailing block restricted
/// to indices `k..nrow`, together with the index where it occurs.
///
/// Storage is column-major with leading dimension `nrow`, lower triangle only. Ties go
/// to the smallest index, and NaN entries never win. If no entry is strictly positive,
/// the result is `(0.0, r)`.
fn sym_row_argmax(a: &[f64], nrow: usize, k: usize, r: usize) -> (f64, usize) {
    // Entries left of the diagonal, (r, j) for j in k..r, sit in row r of column j.
    let left = (k..r).map(|j| (j, a[j * nrow + r]));
    // Entries below the diagonal, (i, r) for i > r, sit in column r.
    let below = ((r + 1)..nrow).map(|i| (i, a[r * nrow + i]));

    left.chain(below)
        .fold((0.0f64, r), |(best, best_idx), (idx, val)| {
            let mag = val.abs();
            // Strict `>`: keeps the first (lowest-index) maximum and ignores NaN.
            if mag > best {
                (mag, idx)
            } else {
                (best, best_idx)
            }
        })
}
/// Tests whether the symmetric 2×2 block on rows/columns `p < q` is an acceptable pivot
/// at elimination step `k` for threshold `u`.
///
/// With `D = [[d11, d21], [d21, d22]]`, two gates must pass:
///
/// 1. **Conditioning.** The largest entry of `D` must be at least `SSIDS_DET_SMALL`.
///    The determinant of the max-scaled block must be finite and must not lose more
///    than half of either of its two product terms to cancellation.
/// 2. **Growth.** Let `rmax` / `tmax` be the largest magnitudes in rows `p` / `q`
///    outside the block, over the active rows `k..nrow`. Then
///    `|D⁻¹| · [rmax, tmax]ᵀ <= 1/u` must hold componentwise.
///
/// Both gates are evaluated on `D / max|D|`, so no intermediate product exceeds the
/// magnitude of an input entry. The unscaled determinant `d11*d22 - d21*d21` overflows
/// to `inf - inf = NaN` once entries pass roughly `1e154`, which would spuriously
/// reject well-conditioned blocks. An infinite entry makes it `inf`, which would
/// wrongly satisfy the growth gate.
fn passes_2x2_gates(a: &[f64], nrow: usize, k: usize, p: usize, q: usize, u: f64) -> bool {
    debug_assert!(p < q);
    let d11 = sym_elem(a, nrow, p, p);
    let d22 = sym_elem(a, nrow, q, q);
    let d21 = sym_elem(a, nrow, q, p);

    // Gate 1a: a (numerically) zero block can never be a pivot.
    let max_piv = d11.abs().max(d21.abs()).max(d22.abs());
    if max_piv < SSIDS_DET_SMALL {
        return false;
    }

    // Gate 1b: determinant of D / max_piv, kept as its two product terms so that
    // catastrophic cancellation between them can be detected.
    let det_scale = 1.0 / max_piv;
    let detpiv0 = (d11 * det_scale) * d22;
    let detpiv1 = (d21 * det_scale) * d21;
    let detpiv = detpiv0 - detpiv1;
    let cancel_floor = SSIDS_DET_SMALL
        .max(detpiv0.abs() * 0.5)
        .max(detpiv1.abs() * 0.5);
    // A NaN determinant would slip past `<`, so non-finite values are rejected explicitly.
    if !detpiv.is_finite() || detpiv.abs() < cancel_floor {
        return false;
    }

    // Gate 2: largest magnitudes in rows p and q outside the 2×2 block. As before,
    // NaN entries are ignored (`f64::max` returns the non-NaN operand).
    let (rmax, tmax) = (k..nrow)
        .filter(|&i| i != p && i != q)
        .fold((0.0f64, 0.0f64), |(rmax, tmax), i| {
            (
                rmax.max(sym_elem(a, nrow, i, p).abs()),
                tmax.max(sym_elem(a, nrow, i, q).abs()),
            )
        });

    // |D⁻¹| = |adj(D)| / |det|. Both sides of
    //   (|d22| rmax + |d21| tmax) u <= |det|
    // are divided by max_piv, which bounds every product by rmax or tmax.
    let s11 = d11.abs() * det_scale;
    let s22 = d22.abs() * det_scale;
    let s21 = d21.abs() * det_scale;
    let absdet = detpiv.abs();
    (s22 * rmax + s21 * tmax) * u <= absdet && (s11 * tmax + s21 * rmax) * u <= absdet
}
/// Rook-pivoting search for an admissible pivot at elimination step `k`.
///
/// `a` is column-major with leading dimension `nrow`, lower triangle referenced (see
/// `sym_elem`). Let `u = params.pivot_threshold`. Starting from row `k`, the search
/// alternates between the current row and the row holding its largest off-diagonal
/// magnitude (over indices `k..nrow`). It returns the first pivot that passes:
///
/// * a 1×1 pivot on row `i < ncol` when `|a_ii| >= u * max_{j != i} |a_ij|`;
/// * a 2×2 pivot on the current pair (both `< ncol`), when the partner row's maximum
///   does not exceed the current row's and `passes_2x2_gates` accepts the block.
///
/// Rows `ncol..nrow` ("ghost" rows) are walked through but never chosen as pivots.
/// NOTE(review): visiting a ghost row `s` reads entries in columns `>= ncol`, so `a`
/// must hold more than `ncol` columns. Confirm this against the callers.
///
/// Returns `None` in any of these cases:
/// * `u <= 0`;
/// * `k >= ncol`;
/// * row `k` has no nonzero off-diagonal entry;
/// * nothing is accepted within `MAX_ROOK_ITER` steps.
// NOTE(review): dead_code allowed presumably until this is wired into the factorization.
#[allow(dead_code)]
pub(crate) fn rook_rescue(
    a: &[f64],
    nrow: usize,
    ncol: usize,
    k: usize,
    params: &BunchKaufmanParams,
) -> Option<RookPivot> {
    let u = params.pivot_threshold;
    if u <= 0.0 || k >= ncol {
        return None;
    }

    // 1×1 threshold test. The `row < ncol` check short-circuits, so a ghost row's
    // diagonal is never read.
    let diag_ok = |row: usize, gamma: f64| row < ncol && a[row * nrow + row].abs() >= u * gamma;

    let mut cur = k;
    let (mut cur_gamma, mut partner) = sym_row_argmax(a, nrow, k, cur);
    if cur_gamma == 0.0 {
        return None;
    }

    for _ in 0..MAX_ROOK_ITER {
        if diag_ok(cur, cur_gamma) {
            return Some(RookPivot::single(k, cur));
        }

        let (partner_gamma, next) = sym_row_argmax(a, nrow, k, partner);
        if partner_gamma == 0.0 {
            return None;
        }
        if diag_ok(partner, partner_gamma) {
            return Some(RookPivot::single(k, partner));
        }

        let both_eligible = cur < ncol && partner < ncol;
        if both_eligible && partner_gamma <= cur_gamma {
            let (p, q) = (cur.min(partner), cur.max(partner));
            if passes_2x2_gates(a, nrow, k, p, q, u) {
                return Some(RookPivot::pair(k, cur, partner));
            }
        }

        // Step the rook: the partner becomes the current row.
        cur = partner;
        cur_gamma = partner_gamma;
        partner = next;
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default factorization parameters with the pivot threshold overridden to `u`.
    fn default_params_u(u: f64) -> BunchKaufmanParams {
        BunchKaufmanParams {
            pivot_threshold: u,
            ..BunchKaufmanParams::default()
        }
    }

    // The matrices below are column-major with only the lower triangle populated:
    // entry (row, col), row >= col, lives at `a[col * n + row]`.

    #[test]
    fn rook_accepts_1x1_at_row_2() {
        let n = 4;
        let mut a = vec![0.0f64; n * n];
        a[0] = 0.008;
        a[1] = 1.0;
        a[n + 1] = 0.5;
        a[n + 2] = 100.0;
        a[2 * n + 2] = 500.0;
        a[3 * n + 3] = 1.0;
        let params = default_params_u(0.01);
        let pivot = rook_rescue(&a, n, n, 0, &params).expect("rook should find a pivot");
        assert_eq!(pivot.kind, RookKind::Pivot1x1);
        assert_eq!(pivot.n_swaps, 1);
        assert_eq!(pivot.swaps[0], (0, 2));
    }

    #[test]
    fn rook_accepts_2x2_at_rows_1_2() {
        let n = 5;
        let mut a = vec![0.0f64; n * n];
        a[0] = 0.008;
        a[1] = 1.0;
        a[n + 1] = 0.1;
        a[n + 2] = 1.0e4;
        a[2 * n + 2] = 0.1;
        a[3 * n + 3] = 1.0;
        a[4 * n + 4] = 1.0;
        let params = default_params_u(0.01);
        let pivot = rook_rescue(&a, n, n, 0, &params).expect("rook should find a pivot");
        assert_eq!(pivot.kind, RookKind::Pivot2x2);
        assert_eq!(pivot.n_swaps, 2);
        assert_eq!(pivot.swaps[0], (0, 1));
        assert_eq!(pivot.swaps[1], (1, 2));
    }

    #[test]
    fn rook_accepts_1x1_at_k_when_diagonally_dominant() {
        let n = 4;
        let mut a = vec![0.0f64; n * n];
        for i in 0..n {
            a[i * n + i] = 10.0 + i as f64;
        }
        a[1] = 1.0;
        a[n + 2] = 0.5;
        a[2 * n + 3] = 0.25;
        let params = default_params_u(0.01);
        let pivot = rook_rescue(&a, n, n, 0, &params).expect("rook finds pivot on SPD");
        assert_eq!(pivot.kind, RookKind::Pivot1x1);
        assert_eq!(pivot.n_swaps, 0);
    }

    #[test]
    fn rook_declines_when_threshold_zero() {
        let n = 3;
        let mut a = vec![0.0f64; n * n];
        a[0] = 1.0;
        a[n + 1] = 1.0;
        a[2 * n + 2] = 1.0;
        let params = default_params_u(0.0);
        assert!(rook_rescue(&a, n, n, 0, &params).is_none());
    }

    #[test]
    fn rook_declines_when_only_ghost_has_good_pivot() {
        // Row 0 couples only to row 2. Row 2 has the only dominant diagonal, but it
        // lies at index >= ncol (a ghost row), so rook must not select it.
        let n = 3;
        let ncol = 2;
        let mut a = vec![0.0f64; n * n];
        a[0] = 0.001;
        a[2] = 1.0;
        a[n + 1] = 0.001;
        a[2 * n + 2] = 100.0;
        let params = default_params_u(0.01);
        let pivot = rook_rescue(&a, n, ncol, 0, &params);
        assert!(
            pivot.is_none(),
            "rook must decline when only ghost rows host sufficient pivots, got {:?}",
            pivot
        );
    }

    #[test]
    fn pair_tracks_row_displaced_by_first_swap() {
        // s == k: the first interchange (0 <-> 3) moves row s to position 3.
        let pivot = RookPivot::pair(0, 3, 0);
        assert_eq!(pivot.kind, RookKind::Pivot2x2);
        assert_eq!(pivot.n_swaps, 2);
        assert_eq!(pivot.swaps, [(0, 3), (1, 3)]);
    }

    #[test]
    fn pair_already_in_place_needs_no_swaps() {
        let pivot = RookPivot::pair(2, 2, 3);
        assert_eq!(pivot.kind, RookKind::Pivot2x2);
        assert_eq!(pivot.n_swaps, 0);
    }
}