use crate::dense::matrix::SymmetricMatrix;
use crate::scaling::ScalingInfo;
use crate::sparse::csc::CscMatrix;
/// Computes symmetric infinity-norm (Ruiz-style) equilibration factors for a sparse matrix.
///
/// Iteratively builds a diagonal scaling `d` so that every row of `D * A * D`
/// (with `D = diag(d)`) has an infinity norm close to 1. Each stored entry
/// `(i, j)` contributes its scaled magnitude to both row `i` and row `j`. The
/// matrix may therefore store a single triangle of a symmetric matrix (the tests
/// use the lower triangle). Storing both triangles yields the same result,
/// because only maxima are taken.
///
/// Each sweep divides `d[i]` by `sqrt(m_i)`, where `m_i` is the current scaled
/// row norm. At most 10 sweeps are performed. Iteration stops early once every
/// non-empty row's norm, measured before that sweep's update, is within `1e-8`
/// of 1.
///
/// Rows with no nonzero stored entry keep a factor of `1.0`. Rows whose scaled
/// norm is not finite (e.g. an infinite entry) are left unscaled rather than
/// having their factor divided down to exactly zero. Such rows also prevent
/// early termination.
///
/// Returns the factors (length `matrix.n`, empty for `n == 0`) together with
/// `ScalingInfo::Applied`.
pub fn compute_infnorm(matrix: &CscMatrix) -> (Vec<f64>, ScalingInfo) {
    let n = matrix.n;
    if n == 0 {
        return (Vec::new(), ScalingInfo::Applied);
    }
    let max_iter = 10;
    let tol = 1e-8;
    let mut d = vec![1.0f64; n];
    // Reused across sweeps to avoid reallocating.
    let mut row_max = vec![0.0f64; n];
    for _ in 0..max_iter {
        row_max.fill(0.0);
        for j in 0..n {
            for k in matrix.col_ptr[j]..matrix.col_ptr[j + 1] {
                let i = matrix.row_idx[k];
                // Keep the `d_i * a * d_j` evaluation order so results stay
                // bit-identical to the dense variant.
                let v = (d[i] * matrix.values[k] * d[j]).abs();
                row_max[i] = row_max[i].max(v);
                // An off-diagonal entry also belongs to row j by symmetry.
                if i != j {
                    row_max[j] = row_max[j].max(v);
                }
            }
        }
        let mut max_dev = 0.0f64;
        for (di, &m) in d.iter_mut().zip(&row_max) {
            if !m.is_finite() {
                // Dividing by sqrt(inf) would set the factor to exactly 0 and
                // annihilate this row/column for good. Leave it unscaled, but
                // do not report convergence.
                max_dev = f64::INFINITY;
            } else if m > 0.0 {
                *di /= m.sqrt();
                max_dev = max_dev.max((m - 1.0).abs());
            }
        }
        if max_dev < tol {
            break;
        }
    }
    (d, ScalingInfo::Applied)
}
/// Computes symmetric infinity-norm (Ruiz-style) equilibration factors for a dense matrix.
///
/// This is the dense counterpart of `compute_infnorm`. It uses the same update
/// rule, iteration cap (10 sweeps) and tolerance (`1e-8`).
///
/// `sym.data` is read as an `n x n` buffer with element `(i, j)` at index
/// `j * n + i`. Only entries with `i >= j` (the diagonal and the triangle below
/// it in that layout) are read. Each off-diagonal entry contributes its scaled
/// magnitude to both row `i` and row `j`.
/// NOTE(review): assumes `SymmetricMatrix` keeps that triangle populated — confirm.
///
/// Each sweep divides `d[i]` by `sqrt(m_i)`, where `m_i` is the current scaled
/// row norm. Rows whose norm is zero keep their factor. Rows whose scaled norm
/// is not finite are left unscaled rather than having their factor divided down
/// to exactly zero, and they prevent early termination.
///
/// Returns the factors (length `sym.n`, empty for `n == 0`) together with
/// `ScalingInfo::Applied`.
pub fn compute_infnorm_dense(sym: &SymmetricMatrix) -> (Vec<f64>, ScalingInfo) {
    let n = sym.n;
    if n == 0 {
        return (Vec::new(), ScalingInfo::Applied);
    }
    let max_iter = 10;
    let tol = 1e-8;
    let mut d = vec![1.0f64; n];
    // Reused across sweeps to avoid reallocating.
    let mut row_max = vec![0.0f64; n];
    for _ in 0..max_iter {
        row_max.fill(0.0);
        for j in 0..n {
            let dj = d[j];
            // Rows j..n of column j: the part of the column this routine reads.
            let column = &sym.data[j * n + j..j * n + n];
            for (i, &a) in (j..n).zip(column) {
                // Same `d_i * a * d_j` order as the sparse routine, so both
                // variants produce bit-identical factors.
                let v = (d[i] * a * dj).abs();
                row_max[i] = row_max[i].max(v);
                if i != j {
                    row_max[j] = row_max[j].max(v);
                }
            }
        }
        let mut max_dev = 0.0f64;
        for (di, &m) in d.iter_mut().zip(&row_max) {
            if !m.is_finite() {
                // Dividing by sqrt(inf) would set the factor to exactly 0 and
                // annihilate this row/column for good. Leave it unscaled, but
                // do not report convergence.
                max_dev = f64::INFINITY;
            } else if m > 0.0 {
                *di /= m.sqrt();
                max_dev = max_dev.max((m - 1.0).abs());
            }
        }
        if max_dev < tol {
            break;
        }
    }
    (d, ScalingInfo::Applied)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::csc::CscMatrix;

    /// Builds the 6x6 "arrow" matrix used by several tests, stored as a lower
    /// triangle: diagonal entries `j + 2`, plus ones along the last row.
    fn arrow_6x6_matrix() -> CscMatrix {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        // Diagonal: A[j][j] = j + 2.
        for j in 0..6 {
            rows.push(j);
            cols.push(j);
            vals.push((j + 2) as f64);
        }
        // Arrow head: A[5][j] = 1 for every column before the last.
        for j in 0..5 {
            rows.push(5);
            cols.push(j);
            vals.push(1.0);
        }
        CscMatrix::from_triplets(6, &rows, &cols, &vals).unwrap()
    }

    /// Looks up entry `(row, col)` of a symmetric matrix stored as a lower
    /// triangle, mirroring upper-triangle requests. Returns 0.0 when absent.
    fn lower_entry(m: &CscMatrix, row: usize, col: usize) -> f64 {
        let (r, c) = if row >= col { (row, col) } else { (col, row) };
        (m.col_ptr[c]..m.col_ptr[c + 1])
            .find(|&k| m.row_idx[k] == r)
            .map_or(0.0, |k| m.values[k])
    }

    #[test]
    fn diag_3x3() {
        let m = CscMatrix::from_triplets(3, &[0, 1, 2], &[0, 1, 2], &[2.0, 3.0, 5.0]).unwrap();
        let (d, _info) = compute_infnorm(&m);
        let expected: Vec<f64> = [2.0f64, 3.0, 5.0].iter().map(|x| 1.0 / x.sqrt()).collect();
        for (i, (&got, &want)) in d[..3].iter().zip(&expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-12,
                "d[{}] = {} != {}",
                i,
                got,
                want
            );
        }
    }

    #[test]
    fn sym_2x2() {
        let m = CscMatrix::from_triplets(2, &[0, 1, 1], &[0, 0, 1], &[4.0, 2.0, 9.0]).unwrap();
        let (scale, _) = compute_infnorm(&m);
        let (s0, s1) = (scale[0], scale[1]);
        let top_left = s0 * s0 * 4.0;
        let off_diag = s0 * s1 * 2.0;
        let bottom_right = s1 * s1 * 9.0;
        let first_row = top_left.abs().max(off_diag.abs());
        let second_row = off_diag.abs().max(bottom_right.abs());
        assert!((first_row - 1.0).abs() < 1e-6, "row0 max = {}", first_row);
        assert!((second_row - 1.0).abs() < 1e-6, "row1 max = {}", second_row);
    }

    #[test]
    fn dense_matches_sparse_on_arrow_6x6() {
        let m = arrow_6x6_matrix();
        let sym = m.to_dense();
        let (d_sparse, _) = compute_infnorm(&m);
        let (d_dense, _) = compute_infnorm_dense(&sym);
        assert_eq!(d_sparse.len(), d_dense.len());
        for (i, (sparse, dense)) in d_sparse.iter().zip(&d_dense).enumerate() {
            assert_eq!(
                sparse.to_bits(),
                dense.to_bits(),
                "dense-vs-sparse KR parity broke at i={}: sparse={} dense={}",
                i,
                sparse,
                dense,
            );
        }
    }

    #[test]
    fn dense_matches_sparse_on_dense_5x5() {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        // Full lower triangle: a strong diagonal with mildly varying off-diagonals.
        for j in 0..5 {
            for i in j..5 {
                let value = if i == j {
                    10.0 * (i as f64 + 1.0)
                } else {
                    1.0 + 0.1 * (i - j) as f64
                };
                rows.push(i);
                cols.push(j);
                vals.push(value);
            }
        }
        let m = CscMatrix::from_triplets(5, &rows, &cols, &vals).unwrap();
        let sym = m.to_dense();
        let (d_sparse, _) = compute_infnorm(&m);
        let (d_dense, _) = compute_infnorm_dense(&sym);
        for (i, (sparse, dense)) in d_sparse[..5].iter().zip(&d_dense[..5]).enumerate() {
            assert_eq!(
                sparse.to_bits(),
                dense.to_bits(),
                "dense KR diverged at i={}: sparse={} dense={}",
                i,
                sparse,
                dense,
            );
        }
    }

    #[test]
    fn arrow_6x6() {
        let m = arrow_6x6_matrix();
        let (d, _) = compute_infnorm(&m);
        for i in 0..6 {
            // Infinity norm of row i of D * A * D, reconstructed from the triangle.
            let row_max = (0..6)
                .map(|j| (d[i] * lower_entry(&m, i, j) * d[j]).abs())
                .fold(0.0f64, |best, x| if x > best { x } else { best });
            assert!(
                (row_max - 1.0).abs() < 1e-6,
                "row {} max = {}, expected 1",
                i,
                row_max
            );
        }
    }
}