use pounce_common::types::{Index, Number};
use pounce_linsol::scaling::TSymScalingMethod;
use crate::ffi::mc19ad_;
const MAX_VALID_SCALE: Number = 1e40;
#[derive(Debug, Default, Clone, Copy)]
pub struct Mc19TSymScalingMethod;
impl Mc19TSymScalingMethod {
pub fn new() -> Self {
Self
}
}
impl TSymScalingMethod for Mc19TSymScalingMethod {
fn compute_sym_t_scaling_factors(
&mut self,
n: Index,
nnz: Index,
airn: &[Index],
ajcn: &[Index],
a: &[Number],
scaling_factors: &mut [Number],
) -> bool {
let n_us = n as usize;
let nnz_us = nnz as usize;
debug_assert_eq!(airn.len(), nnz_us);
debug_assert_eq!(ajcn.len(), nnz_us);
debug_assert_eq!(a.len(), nnz_us);
debug_assert_eq!(scaling_factors.len(), n_us);
let mut airn2: Vec<Index> = Vec::with_capacity(2 * nnz_us);
let mut ajcn2: Vec<Index> = Vec::with_capacity(2 * nnz_us);
let mut a2: Vec<Number> = Vec::with_capacity(2 * nnz_us);
for k in 0..nnz_us {
let (i, j, v) = (airn[k], ajcn[k], a[k]);
airn2.push(i);
ajcn2.push(j);
a2.push(v);
if i != j {
airn2.push(j);
ajcn2.push(i);
a2.push(v);
}
}
let nnz2: Index = airn2.len() as Index;
let mut r = vec![0.0f32; n_us];
let mut c = vec![0.0f32; n_us];
let mut w = vec![0.0f32; 5 * n_us];
unsafe {
mc19ad_(
&n,
&nnz2,
a2.as_mut_ptr(),
airn2.as_mut_ptr(),
ajcn2.as_mut_ptr(),
r.as_mut_ptr(),
c.as_mut_ptr(),
w.as_mut_ptr(),
);
}
let mut sum: Number = 0.0;
let mut smax: Number = 0.0;
for i in 0..n_us {
let s = ((Number::from(r[i]) + Number::from(c[i])) / 2.0).exp();
scaling_factors[i] = s;
sum += s;
if s > smax {
smax = s;
}
}
if !sum.is_finite() || smax > MAX_VALID_SCALE {
for s in scaling_factors.iter_mut() {
*s = 1.0;
}
}
true
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn diagonal_matrix_produces_well_balanced_factors() {
let mut method = Mc19TSymScalingMethod::new();
let irn = [1, 2];
let jcn = [1, 2];
let a = [1.0e8, 1.0e-8];
let mut s = vec![0.0; 2];
assert!(method.compute_sym_t_scaling_factors(2, 2, &irn, &jcn, &a, &mut s));
assert!(s[0] > 1e-6 && s[0] < 1e-2, "s[0] = {}", s[0]);
assert!(s[1] > 1e2 && s[1] < 1e6, "s[1] = {}", s[1]);
for k in 0..2 {
let i = irn[k] as usize - 1;
let j = jcn[k] as usize - 1;
let scaled = s[i] * a[k] * s[j];
assert!(scaled > 1e-3 && scaled < 1e3, "scaled[{k}] = {}", scaled);
}
}
#[test]
fn off_diagonals_get_mirrored_and_factors_finite() {
let mut method = Mc19TSymScalingMethod::new();
let irn = [1, 1, 2];
let jcn = [1, 2, 2];
let a = [4.0, 2.0, 100.0];
let mut s = vec![0.0; 2];
assert!(method.compute_sym_t_scaling_factors(2, 3, &irn, &jcn, &a, &mut s));
assert!(s.iter().all(|x| x.is_finite() && *x > 0.0));
}
#[test]
fn invalid_factors_fall_back_to_identity() {
let mut s = [f64::NAN, 1.0];
let sum: Number = s.iter().sum();
if !sum.is_finite() {
for x in s.iter_mut() {
*x = 1.0;
}
}
assert_eq!(s, [1.0, 1.0]);
}
}