use crate::sparse::CscMatrix;
use std::collections::HashMap;
/// Returns a non-negative diagonal shift `σ` such that `Q + σ·I` is positive
/// semidefinite, as certified by the Gershgorin circle theorem.
///
/// `q` is read as the symmetric matrix implied by whatever entries are stored.
/// The stored part may be the upper triangle only, the lower triangle only, both
/// triangles, or any mixture. For every unordered off-diagonal pair `{i, j}`,
/// the magnitude used is the largest `|value|` stored at `(i, j)` or `(j, i)`.
/// A pair stored in both triangles is therefore counted once, not twice.
///
/// For each row `j` the Gershgorin lower bound on the eigenvalues is
/// `diag[j] - Σ_{i≠j} |Q[i,j]|`. The returned shift is
/// `max(0, max_j (Σ_{i≠j} |Q[i,j]| - diag[j]))`. Gershgorin is only a
/// sufficient condition, so an already-PSD matrix can still receive a positive
/// shift.
///
/// Behaviour details:
/// - An empty (`0 × 0`) matrix yields `0.0`.
/// - If a diagonal position is stored more than once, the last stored value
///   (in column-major storage order) wins.
/// - A NaN off-diagonal magnitude never replaces the stored maximum for its
///   pair. A row whose deficit evaluates to NaN does not contribute.
/// - The result is deterministic: contributions are summed in sorted
///   `(row, col)` order, independent of hash-map iteration order.
///
/// # Panics
///
/// Panics if a column pointer or row index in `q` is out of range. In debug
/// builds, also panics if `q.col_ptr` does not have `q.nrows + 1` entries
/// (i.e. `q` is not a square CSC matrix).
pub(crate) fn psd_shift_from_gershgorin(q: &CscMatrix) -> f64 {
    let n = q.nrows;
    if n == 0 {
        return 0.0;
    }
    debug_assert_eq!(
        q.col_ptr.len(),
        n + 1,
        "psd_shift_from_gershgorin expects a square CSC matrix"
    );

    let mut diag = vec![0.0_f64; n];
    // Canonical key (min(row, col), max(row, col)) -> largest |value| seen for that pair,
    // so upper-only, lower-only, full and mixed storage all symmetrize the same way.
    let mut canonical: HashMap<(usize, usize), f64> = HashMap::new();
    for col in 0..n {
        let (start, end) = (q.col_ptr[col], q.col_ptr[col + 1]);
        for (&row, &val) in q.row_ind[start..end].iter().zip(&q.values[start..end]) {
            if row == col {
                diag[col] = val;
            } else {
                let slot = canonical.entry((row.min(col), row.max(col))).or_insert(0.0);
                // `f64::max` ignores a NaN operand, so a NaN never displaces the stored value.
                *slot = slot.max(val.abs());
            }
        }
    }

    // HashMap iteration order is randomized per instance and float addition is not
    // associative; sort first so the sums (and hence the shift) are bit-reproducible.
    let mut pairs: Vec<((usize, usize), f64)> = canonical.into_iter().collect();
    pairs.sort_unstable_by_key(|&(key, _)| key);

    let mut row_offdiag_sum = vec![0.0_f64; n];
    for ((i, j), abs_val) in pairs {
        row_offdiag_sum[i] += abs_val;
        row_offdiag_sum[j] += abs_val;
    }

    // Largest amount by which a Gershgorin disc extends below zero; rows that stay
    // non-negative contribute nothing, and NaN deficits are dropped by `f64::max`.
    diag.iter()
        .zip(&row_offdiag_sum)
        .map(|(&d, &r)| r - d)
        .fold(0.0_f64, f64::max)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits `(row, col, value)` triplets into the parallel slices taken by
    /// `CscMatrix::from_triplets` and builds an `n × n` matrix from them.
    fn from_entries(n: usize, entries: &[(usize, usize, f64)]) -> CscMatrix {
        let rows: Vec<usize> = entries.iter().map(|&(r, _, _)| r).collect();
        let cols: Vec<usize> = entries.iter().map(|&(_, c, _)| c).collect();
        let vals: Vec<f64> = entries.iter().map(|&(_, _, v)| v).collect();
        CscMatrix::from_triplets(&rows, &cols, &vals, n, n).unwrap()
    }

    /// Builds an `n × n` matrix that stores only the upper triangle (diagonal included).
    ///
    /// Panics if any entry lies strictly below the diagonal, so a fixture typo cannot
    /// silently turn an "upper-only" test into a mixed-layout one.
    fn upper_tri(n: usize, entries: &[(usize, usize, f64)]) -> CscMatrix {
        for &(r, c, _) in entries {
            assert!(r <= c, "upper-tri requires row <= col, got ({r},{c})");
        }
        from_entries(n, entries)
    }

    /// Builds an `n × n` matrix that stores only the lower triangle (diagonal included).
    ///
    /// Panics if any entry lies strictly above the diagonal.
    fn lower_tri(n: usize, entries: &[(usize, usize, f64)]) -> CscMatrix {
        for &(r, c, _) in entries {
            assert!(r >= c, "lower-tri requires row >= col, got ({r},{c})");
        }
        from_entries(n, entries)
    }

    #[test]
    fn empty_matrix_returns_zero() {
        let q = CscMatrix::new(0, 0);
        assert_eq!(psd_shift_from_gershgorin(&q), 0.0);
    }

    #[test]
    fn diagonal_psd_returns_zero() {
        // diag(1, 2) is already PSD: every Gershgorin disc lies in [0, ∞).
        let q = upper_tri(2, &[(0, 0, 1.0), (1, 1, 2.0)]);
        assert_eq!(psd_shift_from_gershgorin(&q), 0.0);
    }

    #[test]
    fn diagonal_negative_returns_abs_min_diag() {
        // diag(-2, -3): the most negative diagonal (-3) dictates the shift.
        let q = upper_tri(2, &[(0, 0, -2.0), (1, 1, -3.0)]);
        assert!((psd_shift_from_gershgorin(&q) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn pure_bilinear_zero_diag_off_one() {
        // [[0, 1], [1, 0]]: each row has 0 - 1 = -1 → shift 1.
        let q = upper_tri(2, &[(0, 1, 1.0)]);
        assert!((psd_shift_from_gershgorin(&q) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn full_symmetric_input_matches_upper() {
        // Storing both (0,1) and (1,0) must not double-count the pair.
        let q_full =
            CscMatrix::from_triplets(&[0, 1], &[1, 0], &[1.0, 1.0], 2, 2).unwrap();
        let q_upper = upper_tri(2, &[(0, 1, 1.0)]);
        assert_eq!(
            psd_shift_from_gershgorin(&q_full),
            psd_shift_from_gershgorin(&q_upper),
        );
    }

    #[test]
    fn mixed_zero_and_negative_diag() {
        // Row 0: 0 - 1 = -1; row 1: -1 - 1 = -2 → shift 2.
        let q = upper_tri(2, &[(0, 1, 1.0), (1, 1, -1.0)]);
        assert!((psd_shift_from_gershgorin(&q) - 2.0).abs() < 1e-12);
    }

    #[test]
    fn extreme_offdiag_dominates_diag() {
        // Row 0: 2 - 3 = -1; row 3 (zero diagonal): 0 - 3 = -3 → shift 3.
        let q = upper_tri(4, &[(0, 0, 2.0), (0, 3, 3.0)]);
        assert!((psd_shift_from_gershgorin(&q) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn psd_with_large_offdiag_returns_positive_shift_false_alarm() {
        // [[1, 1.1], [1.1, 2]] is PSD (det = 0.79 > 0), but Gershgorin row 0 gives
        // 1 - 1.1 = -0.1: a conservative "false alarm" shift of 0.1.
        let q = upper_tri(2, &[(0, 0, 1.0), (0, 1, 1.1), (1, 1, 2.0)]);
        assert!((psd_shift_from_gershgorin(&q) - 0.1).abs() < 1e-12);
    }

    #[test]
    fn lower_triangular_only_zero_diag_bilinear_matches_upper() {
        // Lower-only storage must be symmetrized exactly like upper-only storage.
        let q_lower = lower_tri(2, &[(1, 0, 1.0)]);
        let q_upper = upper_tri(2, &[(0, 1, 1.0)]);
        assert!((psd_shift_from_gershgorin(&q_lower) - 1.0).abs() < 1e-12);
        assert_eq!(
            psd_shift_from_gershgorin(&q_lower),
            psd_shift_from_gershgorin(&q_upper),
            "lower-only と upper-only は同じ shift を返すべき (対称化)"
        );
    }

    #[test]
    fn lower_triangular_only_offdiag_dominant_indefinite() {
        // [[1, 2], [2, 1]] has eigenvalues -1 and 3; Gershgorin gives 1 - 2 = -1 → shift 1.
        let q_lower = lower_tri(2, &[(0, 0, 1.0), (1, 0, 2.0), (1, 1, 1.0)]);
        assert!((psd_shift_from_gershgorin(&q_lower) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn three_layouts_agree_on_indefinite_q() {
        // Q = [[1, 2, 0], [2, 1, -1], [0, -1, 1]]:
        // row 0: 1 - 2 = -1; row 1: 1 - 3 = -2; row 2: 1 - 1 = 0 → shift 2.
        let upper = upper_tri(3, &[
            (0, 0, 1.0), (0, 1, 2.0),
            (1, 1, 1.0), (1, 2, -1.0),
            (2, 2, 1.0),
        ]);
        let lower = lower_tri(3, &[
            (0, 0, 1.0),
            (1, 0, 2.0), (1, 1, 1.0),
            (2, 1, -1.0), (2, 2, 1.0),
        ]);
        let full = CscMatrix::from_triplets(
            &[0, 0, 1, 1, 1, 2, 2],
            &[0, 1, 0, 1, 2, 1, 2],
            &[1.0, 2.0, 2.0, 1.0, -1.0, -1.0, 1.0],
            3, 3,
        ).unwrap();
        let s_upper = psd_shift_from_gershgorin(&upper);
        let s_lower = psd_shift_from_gershgorin(&lower);
        let s_full = psd_shift_from_gershgorin(&full);
        assert!((s_upper - 2.0).abs() < 1e-12, "upper shift = {s_upper}");
        assert!((s_lower - 2.0).abs() < 1e-12, "lower shift = {s_lower}");
        assert!((s_full - 2.0).abs() < 1e-12, "full shift = {s_full}");
    }

    #[test]
    fn mixed_asymmetric_no_pair_canonicalizes() {
        // (0,1) is upper and (2,1) is lower, with no mirrored partners: both still count,
        // so row 1 (zero diagonal) has 0 - 2 = -2 → shift 2 (message: "expected 2.0").
        let q = CscMatrix::from_triplets(
            &[0, 2],
            &[1, 1],
            &[1.0, 1.0],
            3,
            3,
        )
        .unwrap();
        let s = psd_shift_from_gershgorin(&q);
        assert!((s - 2.0).abs() < 1e-12, "mixed-asymm shift = {s} (期待 2.0)");
    }

    #[test]
    fn mixed_asymmetric_matches_full_symmetric() {
        // The mixed layout above must match its explicitly symmetrized full counterpart.
        let q_mixed = CscMatrix::from_triplets(
            &[0, 2],
            &[1, 1],
            &[1.0, 1.0],
            3,
            3,
        )
        .unwrap();
        let q_full = CscMatrix::from_triplets(
            &[0, 1, 2, 1],
            &[1, 0, 1, 2],
            &[1.0, 1.0, 1.0, 1.0],
            3,
            3,
        )
        .unwrap();
        let s_mixed = psd_shift_from_gershgorin(&q_mixed);
        let s_full = psd_shift_from_gershgorin(&q_full);
        assert!((s_mixed - s_full).abs() < 1e-12, "mixed={s_mixed} vs full={s_full}");
    }

    #[test]
    fn asymmetric_value_pair_takes_max_abs() {
        // Q[0,1] = 1 and Q[1,0] = 3 disagree; the larger magnitude (3) is used
        // (message: "expected 3.0").
        let q = CscMatrix::from_triplets(
            &[0, 1],
            &[1, 0],
            &[1.0, 3.0],
            2,
            2,
        )
        .unwrap();
        let s = psd_shift_from_gershgorin(&q);
        assert!((s - 3.0).abs() < 1e-12, "max-abs shift = {s} (期待 3.0)");
    }

    /// Sentinel: re-implements the legacy "row < col only" behaviour and checks that it
    /// diverges from the current implementation on lower-only input. If both agreed, the
    /// symmetrization fix would not be exercised by the suite.
    #[test]
    fn no_op_proof_lower_tri_symmetrize_required() {
        // Legacy implementation: only entries with row < col contributed, so lower-only
        // off-diagonals were silently dropped.
        fn legacy_row_lt_col_only(q: &CscMatrix) -> f64 {
            let n = q.nrows;
            if n == 0 { return 0.0; }
            let mut diag = vec![0.0_f64; n];
            let mut row_sum = vec![0.0_f64; n];
            for col in 0..n {
                for k in q.col_ptr[col]..q.col_ptr[col + 1] {
                    let row = q.row_ind[k];
                    let val = q.values[k];
                    if row == col {
                        diag[col] = val;
                    } else if row < col {
                        let abs = val.abs();
                        row_sum[row] += abs;
                        row_sum[col] += abs;
                    }
                }
            }
            let mut shift = 0.0_f64;
            for j in 0..n {
                let lower = diag[j] - row_sum[j];
                if lower < 0.0 { shift = shift.max(-lower); }
            }
            shift
        }
        let q_lower = lower_tri(2, &[(1, 0, 1.0)]);
        // Legacy drops the lower-only entry and reports shift 0 (the bug).
        let legacy = legacy_row_lt_col_only(&q_lower);
        assert_eq!(legacy, 0.0, "旧 impl は lower-only off-diag を取り零し shift=0 (bug)");
        // Current implementation symmetrizes and reports shift 1.
        let fixed = psd_shift_from_gershgorin(&q_lower);
        assert!((fixed - 1.0).abs() < 1e-12, "新 impl は対称化 shift=1 を返す");
        assert!(
            (legacy - fixed).abs() > 0.5,
            "fix の有無で挙動が乖離 (legacy={legacy}, fixed={fixed}) = sentinel が active"
        );
    }

    /// Sentinel: re-implements the legacy "halve everything when both triangles appear"
    /// heuristic and checks that it diverges from canonical pair dedup on mixed
    /// asymmetric input.
    #[test]
    fn no_op_proof_mixed_asymmetric_canonicalize_required() {
        // Legacy implementation: counted every off-diagonal twice, then halved all row sums
        // whenever both triangles were present — wrong when entries have no mirror partner.
        fn legacy_has_upper_lower_half(q: &CscMatrix) -> f64 {
            let n = q.nrows;
            if n == 0 { return 0.0; }
            let mut diag = vec![0.0_f64; n];
            let mut row_sum = vec![0.0_f64; n];
            let mut has_upper = false;
            let mut has_lower = false;
            for col in 0..n {
                for k in q.col_ptr[col]..q.col_ptr[col + 1] {
                    let row = q.row_ind[k];
                    let val = q.values[k];
                    if row == col {
                        diag[col] = val;
                    } else {
                        if row < col { has_upper = true; } else { has_lower = true; }
                        let abs = val.abs();
                        row_sum[row] += abs;
                        row_sum[col] += abs;
                    }
                }
            }
            if has_upper && has_lower {
                for r in row_sum.iter_mut() { *r *= 0.5; }
            }
            let mut shift = 0.0_f64;
            for j in 0..n {
                let lower = diag[j] - row_sum[j];
                if lower < 0.0 { shift = shift.max(-lower); }
            }
            shift
        }
        let q = CscMatrix::from_triplets(
            &[0, 2], &[1, 1], &[1.0, 1.0], 3, 3,
        ).unwrap();
        // Legacy mistakes the mixed layout for a full one and halves to shift 1.
        let legacy = legacy_has_upper_lower_half(&q);
        assert!((legacy - 1.0).abs() < 1e-12, "旧 impl は mixed-asymm を full と誤認、shift={legacy}");
        // Canonical dedup keeps both unmirrored entries → shift 2.
        let fixed = psd_shift_from_gershgorin(&q);
        assert!((fixed - 2.0).abs() < 1e-12, "新 impl: canonical dedup で shift={fixed}");
        assert!(
            (legacy - fixed).abs() > 0.5,
            "fix の有無で乖離 (legacy={legacy}, fixed={fixed}) = sentinel が active"
        );
    }
}