use super::sparse_matrix::SparseColMatrix;
use super::LuScaling;
use crate::error::FeralError;
use crate::scaling::{hungarian_match, CostGraph};
/// Row permutation plus diagonal row/column scaling factors applied to a
/// matrix before LU factorization.
///
/// As used by the methods in this file, an entry of the input at original row
/// `r`, column `j` with value `v` ends up at row `i` (where `rperm[i] == r`)
/// with value `d_row[i] * v * d_col[j]`.
#[derive(Debug, Clone)]
pub struct LuScale {
/// Row permutation: position `i` of the scaled matrix takes original row
/// `rperm[i]`.
pub rperm: Vec<usize>,
/// Row scaling factors, indexed by the permuted (destination) row position.
pub d_row: Vec<f64>,
/// Column scaling factors, indexed by column.
pub d_col: Vec<f64>,
}
impl LuScale {
pub fn identity(m: usize) -> Self {
LuScale {
rperm: (0..m).collect(),
d_row: vec![1.0; m],
d_col: vec![1.0; m],
}
}
pub fn is_identity(&self) -> bool {
self.rperm.iter().enumerate().all(|(i, &r)| i == r)
&& self.d_row.iter().all(|&x| x == 1.0)
&& self.d_col.iter().all(|&x| x == 1.0)
}
pub fn scaled_dense_columns(&self, b: &SparseColMatrix) -> Vec<Vec<f64>> {
let m = b.m;
let mut rperm_inv = vec![0usize; m];
for (i, &r) in self.rperm.iter().enumerate() {
rperm_inv[r] = i;
}
let mut cols = vec![vec![0.0; m]; m];
for (j, colj) in cols.iter_mut().enumerate() {
let (rows, vals) = b.column(j);
for (&r, &v) in rows.iter().zip(vals.iter()) {
let i = rperm_inv[r];
colj[i] = self.d_row[i] * v * self.d_col[j];
}
}
cols
}
}
impl LuScale {
    /// Applies the row permutation and the row/column scaling to `b` and
    /// returns the result as a new sparse matrix.
    ///
    /// Each stored entry of `b` at original row `r`, column `j` with value `v`
    /// becomes the pair `(i, d_row[i] * v * d_col[j])` in column `j`, where
    /// `rperm[i] == r`. Entries keep the order in which `b.column(j)` yields
    /// them.
    ///
    /// # Errors
    ///
    /// Forwards whatever error `SparseColMatrix::from_sparse_columns` reports
    /// for the assembled columns.
    ///
    /// # Panics
    ///
    /// Panics on out-of-range indexing if `rperm`, `d_row` or `d_col` do not
    /// match the dimension `b.m`.
    pub fn apply_sparse(&self, b: &SparseColMatrix) -> Result<SparseColMatrix, FeralError> {
        let dim = b.m;

        // Inverse permutation: original row index -> destination row index.
        let mut dest_of_row = vec![0usize; dim];
        for (dest, &src) in self.rperm.iter().enumerate() {
            dest_of_row[src] = dest;
        }

        let scaled_cols: Vec<Vec<(usize, f64)>> = (0..dim)
            .map(|j| {
                let (rows, vals) = b.column(j);
                rows.iter()
                    .zip(vals.iter())
                    .map(|(&src, &val)| {
                        let dest = dest_of_row[src];
                        (dest, self.d_row[dest] * val * self.d_col[j])
                    })
                    .collect()
            })
            .collect();

        SparseColMatrix::from_sparse_columns(dim, &scaled_cols)
    }
}
/// Computes the pre-factorization scaling of `b` selected by `strategy`.
///
/// * `LuScaling::None` - identity scaling of dimension `b.m`.
/// * `LuScaling::InfNorm` - five sweeps of [`equilibrate_infnorm`].
/// * `LuScaling::Mc64` - the result of [`mc64_unsymmetric`].
/// * `LuScaling::Mc64ThenInfNorm` - the `mc64_unsymmetric` result further
///   refined by `compose_with_infnorm`.
///
/// # Errors
///
/// Propagates any error returned by `mc64_unsymmetric`.
pub fn compute_lu_scale(b: &SparseColMatrix, strategy: LuScaling) -> Result<LuScale, FeralError> {
    let scale = match strategy {
        LuScaling::None => LuScale::identity(b.m),
        LuScaling::InfNorm => equilibrate_infnorm(b, 5),
        LuScaling::Mc64 => mc64_unsymmetric(b)?,
        LuScaling::Mc64ThenInfNorm => compose_with_infnorm(b, mc64_unsymmetric(b)?),
    };
    Ok(scale)
}
/// Iteratively equilibrates `b` using row and column maximum magnitudes.
///
/// Runs `iters` sweeps. Each sweep measures, under the current factors, the
/// largest absolute scaled entry of every row and every column, then divides
/// the corresponding factor by the square root of that maximum. Rows and
/// columns whose maximum is zero (no stored nonzero) keep their factor.
///
/// The returned scaling carries the trivial row permutation `0..b.m`.
pub fn equilibrate_infnorm(b: &SparseColMatrix, iters: usize) -> LuScale {
    /// Divides each factor by the square root of its peak, skipping zero peaks.
    fn divide_by_sqrt_of_peaks(factors: &mut [f64], peaks: &[f64]) {
        for (factor, &peak) in factors.iter_mut().zip(peaks.iter()) {
            if peak > 0.0 {
                *factor /= peak.sqrt();
            }
        }
    }

    let dim = b.m;
    let mut row_scale = vec![1.0_f64; dim];
    let mut col_scale = vec![1.0_f64; dim];

    for _sweep in 0..iters {
        let mut row_peak = vec![0.0_f64; dim];
        let mut col_peak = vec![0.0_f64; dim];

        for (j, col_peak_j) in col_peak.iter_mut().enumerate() {
            let (rows, vals) = b.column(j);
            for (&i, &val) in rows.iter().zip(vals.iter()) {
                let mag = (row_scale[i] * val * col_scale[j]).abs();
                // Peaks start at 0.0 and are never NaN, so `max` ignores a NaN
                // magnitude exactly like a plain `>` comparison would.
                row_peak[i] = row_peak[i].max(mag);
                *col_peak_j = col_peak_j.max(mag);
            }
        }

        // Both peak vectors were measured before either factor set changes.
        divide_by_sqrt_of_peaks(&mut row_scale, &row_peak);
        divide_by_sqrt_of_peaks(&mut col_scale, &col_peak);
    }

    LuScale {
        rperm: (0..dim).collect(),
        d_row: row_scale,
        d_col: col_scale,
    }
}
/// Derives a row permutation and scaling for `b` from a weighted matching.
///
/// Steps, as implemented here:
/// 1. An empty matrix yields the empty identity scaling.
/// 2. `build_cost_graph` produces the cost graph and per-column log-maxima.
/// 3. `hungarian_match` is run on that graph. If it matches fewer than `b.m`
///    entries, the function falls back to five sweeps of
///    [`equilibrate_infnorm`].
/// 4. Otherwise the matching's `perm` becomes the row permutation, column
///    factor `j` is `exp(v[j] - cmax[j])`, and row factor `i` is
///    `exp(u[rperm[i]])`. Each exponent is clamped by `clamp_exp` first.
///
/// NOTE(review): `u` and `v` are presumably the dual variables of the
/// matching; confirm against `hungarian_match`.
///
/// # Errors
///
/// No path in this body constructs an error; every return is `Ok`.
pub fn mc64_unsymmetric(b: &SparseColMatrix) -> Result<LuScale, FeralError> {
    let dim = b.m;
    if dim == 0 {
        return Ok(LuScale::identity(0));
    }

    let (graph, col_log_max) = build_cost_graph(b);
    let matching = hungarian_match(&graph);

    // Incomplete matching: use plain equilibration instead.
    if matching.n_matched != dim {
        return Ok(equilibrate_infnorm(b, 5));
    }

    let rperm = matching.perm.clone();

    let d_col: Vec<f64> = (0..dim)
        .map(|j| clamp_exp(matching.v[j] - col_log_max[j]).exp())
        .collect();

    // Row factors are stored by destination position, so look up the dual of
    // the original row that lands at position `i`.
    let d_row: Vec<f64> = (0..dim)
        .map(|i| clamp_exp(matching.u[rperm[i]]).exp())
        .collect();

    Ok(LuScale {
        rperm,
        d_row,
        d_col,
    })
}
fn compose_with_infnorm(b: &SparseColMatrix, mut scale: LuScale) -> LuScale {
let m = b.m;
let mut rperm_inv = vec![0usize; m];
for (i, &r) in scale.rperm.iter().enumerate() {
rperm_inv[r] = i;
}
for _ in 0..3 {
let mut row_max = vec![0.0_f64; m];
let mut col_max = vec![0.0_f64; m];
for (j, cmj) in col_max.iter_mut().enumerate() {
let (rows, vals) = b.column(j);
for (&r, &v) in rows.iter().zip(vals.iter()) {
let i = rperm_inv[r];
let a = (scale.d_row[i] * v * scale.d_col[j]).abs();
if a > row_max[i] {
row_max[i] = a;
}
if a > *cmj {
*cmj = a;
}
}
}
for (dr, &rm) in scale.d_row.iter_mut().zip(row_max.iter()) {
if rm > 0.0 {
*dr /= rm.sqrt();
}
}
for (dc, &cm) in scale.d_col.iter_mut().zip(col_max.iter()) {
if cm > 0.0 {
*dc /= cm.sqrt();
}
}
}
scale
}
/// Restricts `x` to the interval `[-709, 709]`.
///
/// Callers in this file pass the result to `f64::exp`; with the argument kept
/// in this range the exponential neither overflows to infinity nor
/// underflows to zero. A NaN input is returned unchanged.
fn clamp_exp(x: f64) -> f64 {
    const EXP_ARG_LIMIT: f64 = 709.0;
    x.clamp(-EXP_ARG_LIMIT, EXP_ARG_LIMIT)
}
/// Builds the matching cost graph for `b` together with per-column log-maxima.
///
/// For every column `j`, the log-maximum is the largest `ln|v|` over its
/// stored nonzero values. It stays `f64::NEG_INFINITY` when the column has no
/// nonzero value. Each nonzero entry then contributes an edge with cost
/// `cmax[j] - ln|v|`, clamped below at `0.0`. Stored zeros are left out of the
/// graph. The edges use a compressed-column layout: `col_ptr[j]..col_ptr[j + 1]`
/// delimits column `j` inside `row_idx` and `cost`.
///
/// Returns `(graph, cmax)` where `cmax[j]` is the log-maximum of column `j`.
fn build_cost_graph(b: &SparseColMatrix) -> (CostGraph, Vec<f64>) {
    let dim = b.m;
    let mut col_ptr = Vec::with_capacity(dim + 1);
    let mut row_idx = Vec::new();
    let mut cost = Vec::new();
    let mut col_log_peak = Vec::with_capacity(dim);

    col_ptr.push(0);
    for j in 0..dim {
        let (rows, vals) = b.column(j);

        // Largest ln|v| among the nonzeros of this column. The explicit `>`
        // keeps the running peak when a candidate is NaN.
        let peak = vals
            .iter()
            .filter(|&&v| v != 0.0)
            .map(|&v| v.abs().ln())
            .fold(f64::NEG_INFINITY, |best, c| if c > best { c } else { best });

        for (&i, &v) in rows.iter().zip(vals.iter()) {
            if v == 0.0 {
                continue;
            }
            let excess = peak - v.abs().ln();
            row_idx.push(i);
            cost.push(if excess < 0.0 { 0.0 } else { excess });
        }

        col_ptr.push(row_idx.len());
        col_log_peak.push(peak);
    }

    let graph = CostGraph {
        n: dim,
        col_ptr,
        row_idx,
        cost,
    };
    (graph, col_log_peak)
}