#![allow(clippy::needless_range_loop)]
use super::condition::{estimate_inverse_norm_1, matrix_norm_1};
use super::factorize::SparseFactors;
use crate::error::FeralError;
use crate::scaling::ScalingInfo;
use crate::sparse::csc::CscMatrix;
/// Minimum number of right-hand sides at which the multi-RHS core solve
/// switches from the rank-1 kernels (`fwd_rank1` / `back_rank1`) to the
/// blocked kernels (`fwd_blas3` / `back_blas3`).
const BLAS3_NRHS_THRESHOLD: usize = 32;
/// Crate-visible threshold that is not referenced in this part of the file.
/// NOTE(review): presumably the RHS count at which refined solves switch to
/// a batched path — confirm against the call sites elsewhere in the crate.
pub(crate) const BLAS3_REFINE_THRESHOLD: usize = 16;
/// Solves the factored system for a single right-hand side and returns the
/// solution as a freshly allocated vector.
///
/// An empty system (`factors.n == 0`) with an empty `rhs` short-circuits to an
/// empty vector without building a workspace.
///
/// # Errors
///
/// Propagates any error from [`solve_sparse_into_ws`], e.g. a
/// `DimensionMismatch` when `rhs.len() != factors.n`.
pub fn solve_sparse(factors: &SparseFactors, rhs: &[f64]) -> Result<Vec<f64>, FeralError> {
    let dim = factors.n;
    let trivially_empty = dim == 0 && rhs.is_empty();
    if trivially_empty {
        return Ok(Vec::new());
    }

    // One-shot call: the scratch buffers live only for this solve.
    let mut workspace = SolveWorkspace::for_factors(factors);
    let mut solution = vec![0.0; dim];
    solve_sparse_into_ws(factors, rhs, &mut solution, &mut workspace).map(|()| solution)
}
// Test-only, per-thread counter that `SolveWorkspace::for_factors` increments
// each time it builds a workspace; it starts at zero.
#[cfg(test)]
thread_local! {
    pub(super) static SOLVE_WORKSPACE_BUILDS: std::cell::Cell<usize> =
        const { std::cell::Cell::new(0) };
}
/// Test helper: resets this thread's workspace-build counter to zero.
#[cfg(test)]
pub(super) fn reset_solve_workspace_builds() {
    SOLVE_WORKSPACE_BUILDS.with(|counter| {
        counter.set(0);
    });
}
/// Test helper: returns how many workspaces this thread has built since the
/// counter was last reset.
#[cfg(test)]
pub(super) fn solve_workspace_builds() -> usize {
    SOLVE_WORKSPACE_BUILDS.with(std::cell::Cell::get)
}
/// Reusable scratch buffers for single right-hand-side solves.
///
/// Built by `SolveWorkspace::for_factors`, which sizes every buffer from the
/// factors it is given.
pub(super) struct SolveWorkspace {
    // Working vector in the permuted ordering; `for_factors` sizes it to `factors.n`.
    y: Vec<f64>,
    // Per-front gather/scatter buffer; `for_factors` sizes it to the largest front `nrow`.
    w: Vec<f64>,
    // Scaled copy of the right-hand side; empty when scaling is `ScalingInfo::NotApplied`.
    scaled_rhs: Vec<f64>,
}
impl SolveWorkspace {
    /// Allocates zero-filled scratch buffers sized for `factors`.
    ///
    /// * `y` holds `factors.n` entries.
    /// * `w` holds as many entries as the largest front has rows (zero when
    ///   there are no fronts).
    /// * `scaled_rhs` holds `factors.n` entries, or none when the factors
    ///   report `ScalingInfo::NotApplied`.
    ///
    /// In test builds every call bumps the `SOLVE_WORKSPACE_BUILDS` counter.
    pub(super) fn for_factors(factors: &SparseFactors) -> Self {
        #[cfg(test)]
        SOLVE_WORKSPACE_BUILDS.with(|counter| counter.set(counter.get() + 1));

        let dim = factors.n;

        let mut widest_front: usize = 0;
        for node in &factors.node_factors {
            widest_front = widest_front.max(node.frontal_factors.nrow);
        }

        let scaling_applied = !matches!(factors.scaling_info, ScalingInfo::NotApplied);
        let scaled_rhs = if scaling_applied {
            vec![0.0; dim]
        } else {
            Vec::new()
        };

        Self {
            y: vec![0.0; dim],
            w: vec![0.0; widest_front],
            scaled_rhs,
        }
    }
}
/// Solves the factored system for one right-hand side, writing the solution
/// into `x_out` and using `ws` for all scratch storage.
///
/// When the factors carry a scaling (`scaling_info` is anything other than
/// `ScalingInfo::NotApplied`), the right-hand side is multiplied entry-wise by
/// `factors.scaling` before the core solve and the solution is multiplied by
/// the same vector afterwards.
///
/// # Errors
///
/// Returns `FeralError::DimensionMismatch` when
/// * `rhs.len() != factors.n`,
/// * `x_out.len() != factors.n`, or
/// * `ws` is too small for `factors` (its `y` buffer, or its `scaled_rhs`
///   buffer when scaling is applied, holds fewer than `factors.n` entries).
///
/// The per-front buffer `ws.w` is not validated here; it is expected to come
/// from `SolveWorkspace::for_factors` on the same factors.
pub(super) fn solve_sparse_into_ws(
    factors: &SparseFactors,
    rhs: &[f64],
    x_out: &mut [f64],
    ws: &mut SolveWorkspace,
) -> Result<(), FeralError> {
    let n = factors.n;
    if rhs.len() != n {
        return Err(FeralError::DimensionMismatch {
            expected: n,
            got: rhs.len(),
        });
    }
    if x_out.len() != n {
        return Err(FeralError::DimensionMismatch {
            expected: n,
            got: x_out.len(),
        });
    }
    if n == 0 {
        return Ok(());
    }

    let needs_scaling = !matches!(factors.scaling_info, ScalingInfo::NotApplied);

    // A workspace built for different factors would otherwise panic on
    // indexing below; report it as an error instead, matching the multi-RHS
    // entry point.
    if ws.y.len() < n {
        return Err(FeralError::DimensionMismatch {
            expected: n,
            got: ws.y.len(),
        });
    }
    if needs_scaling && ws.scaled_rhs.len() < n {
        return Err(FeralError::DimensionMismatch {
            expected: n,
            got: ws.scaled_rhs.len(),
        });
    }

    let rhs_for_core: &[f64] = if needs_scaling {
        for i in 0..n {
            ws.scaled_rhs[i] = rhs[i] * factors.scaling[i];
        }
        &ws.scaled_rhs
    } else {
        rhs
    };

    solve_sparse_core_into(factors, rhs_for_core, x_out, &mut ws.y, &mut ws.w);

    if needs_scaling {
        // Undo the symmetric scaling on the solution side.
        for i in 0..n {
            x_out[i] *= factors.scaling[i];
        }
    }
    Ok(())
}
/// Solves the symmetric 2×2 system `[[a, b], [b, c]] · [x0, x1] = [z0, z1]`.
///
/// Returns `None` when `ssids_det_floor_fail(a, b, c)` rejects the block, so
/// the caller can leave the corresponding entries untouched. Otherwise returns
/// `Some((x0, x1))`.
///
/// Two formulas are used: when `|b|` exceeds
/// `f64::EPSILON * max(|a| + |c|, 1.0)` the block is first divided through by
/// `b`; otherwise the plain determinant formula `a*c - b*b` is applied.
#[inline]
fn solve_2x2_dblock(a: f64, b: f64, c: f64, z0: f64, z1: f64) -> Option<(f64, f64)> {
    if crate::dense::factor::ssids_det_floor_fail(a, b, c) {
        return None;
    }

    let offdiag_floor = f64::EPSILON * (a.abs() + c.abs()).max(1.0);
    let solution = if b.abs() > offdiag_floor {
        // Everything is expressed relative to the off-diagonal entry.
        let a_scaled = a / b;
        let c_scaled = c / b;
        let inv_det = 1.0 / (a_scaled * c_scaled - 1.0);
        let z0_scaled = z0 / b;
        let z1_scaled = z1 / b;
        (
            (c_scaled * z0_scaled - z1_scaled) * inv_det,
            (a_scaled * z1_scaled - z0_scaled) * inv_det,
        )
    } else {
        let det = a * c - b * b;
        ((c * z0 - b * z1) / det, (a * z1 - b * z0) / det)
    };
    Some(solution)
}
/// Core single right-hand-side solve, with no dimension checks and no scaling.
///
/// Steps:
/// 1. gather `rhs` into the permuted ordering given by `factors.perm`;
/// 2. forward substitution over the fronts in order, using each front's `l`
///    with an implicit unit diagonal (diagonal entries are never read);
/// 3. block-diagonal solve over the fronts (1×1 and 2×2 pivots);
/// 4. backward substitution over the fronts in reverse order;
/// 5. scatter the result back to the original ordering in `x_out`.
///
/// `y_buf` must hold at least `factors.n` entries and `w_buf` at least the
/// largest front `nrow`; shorter buffers panic on slicing. Fronts with
/// `nelim == 0` are skipped in every pass.
fn solve_sparse_core_into(
    factors: &SparseFactors,
    rhs: &[f64],
    x_out: &mut [f64],
    y_buf: &mut [f64],
    w_buf: &mut [f64],
) {
    let dim = factors.n;
    let y = &mut y_buf[..dim];

    for (permuted, &original) in factors.perm.iter().enumerate() {
        y[permuted] = rhs[original];
    }

    // Pass 1: forward substitution, front by front.
    for node in &factors.node_factors {
        let ff = &node.frontal_factors;
        let (nelim, nrow) = (ff.nelim, ff.nrow);
        if nelim == 0 {
            continue;
        }
        // Maps a local (pivoted) front row to its position in `y`.
        let global = |i: usize| node.row_indices[ff.perm[i]];

        let w = &mut w_buf[..nrow];
        for (i, slot) in w.iter_mut().enumerate() {
            *slot = y[global(i)];
        }
        for j in 0..nelim {
            let column = j * nrow;
            let (head, tail) = w.split_at_mut(j + 1);
            let pivot = head[j];
            for (offset, below) in tail.iter_mut().enumerate() {
                *below -= ff.l[column + j + 1 + offset] * pivot;
            }
        }
        for (i, &value) in w.iter().enumerate() {
            y[global(i)] = value;
        }
    }

    // Pass 2: block-diagonal solve.
    for node in &factors.node_factors {
        let ff = &node.frontal_factors;
        let (nelim, nrow) = (ff.nelim, ff.nrow);
        if nelim == 0 {
            continue;
        }
        let global = |i: usize| node.row_indices[ff.perm[i]];

        let w = &mut w_buf[..nrow];
        for (i, slot) in w.iter_mut().enumerate() {
            *slot = y[global(i)];
        }

        let mut k = 0;
        while k < nelim {
            let is_2x2 = k + 1 < nelim && ff.d_subdiag[k] != 0.0;
            if !is_2x2 {
                // 1×1 pivot: pivots at or below `zero_tol` leave the entry as is.
                let pivot = ff.d_diag[k];
                if pivot.abs() > ff.zero_tol {
                    w[k] /= pivot;
                }
                k += 1;
                continue;
            }
            // 2×2 pivot: a rejected block leaves both entries as is.
            let solved = solve_2x2_dblock(
                ff.d_diag[k],
                ff.d_subdiag[k],
                ff.d_diag[k + 1],
                w[k],
                w[k + 1],
            );
            if let Some((x0, x1)) = solved {
                w[k] = x0;
                w[k + 1] = x1;
            }
            k += 2;
        }

        for (i, &value) in w.iter().enumerate() {
            y[global(i)] = value;
        }
    }

    // Pass 3: backward substitution, fronts in reverse order.
    for node in factors.node_factors.iter().rev() {
        let ff = &node.frontal_factors;
        let (nelim, nrow) = (ff.nelim, ff.nrow);
        if nelim == 0 {
            continue;
        }
        let global = |i: usize| node.row_indices[ff.perm[i]];

        let w = &mut w_buf[..nrow];
        for (i, slot) in w.iter_mut().enumerate() {
            *slot = y[global(i)];
        }
        for j in (0..nelim).rev() {
            let column = j * nrow;
            let dot = ((j + 1)..nrow).fold(0.0, |acc, i| acc + ff.l[column + i] * w[i]);
            w[j] -= dot;
        }
        for (i, &value) in w.iter().enumerate() {
            y[global(i)] = value;
        }
    }

    for (permuted, &original) in factors.perm.iter().enumerate() {
        x_out[original] = y[permuted];
    }
}
/// Reusable scratch buffers for multi-right-hand-side solves.
///
/// Built by `SolveManyWorkspace::for_factors` for a specific `(factors, nrhs)`
/// pair; `solve_sparse_many_into` rejects a workspace whose recorded `n` or
/// `nrhs` differ from the call's.
pub struct SolveManyWorkspace {
    // Working block in the permuted ordering, `n * nrhs` entries.
    y: Vec<f64>,
    // Per-front gather/scatter block, `max front nrow * nrhs` entries.
    w: Vec<f64>,
    // Per-column accumulator used by the backward kernels, `nrhs` entries.
    acc: Vec<f64>,
    // Scaled copy of the right-hand sides; empty when scaling is `NotApplied`.
    scaled_rhs: Vec<f64>,
    // Number of right-hand sides this workspace was sized for.
    nrhs: usize,
    // System dimension this workspace was sized for.
    n: usize,
}
impl SolveManyWorkspace {
    /// Allocates zero-filled scratch buffers for solving `nrhs` right-hand
    /// sides against `factors`.
    ///
    /// The scaled right-hand-side buffer is left empty when the factors report
    /// `ScalingInfo::NotApplied`; otherwise it holds `factors.n * nrhs` entries.
    pub fn for_factors(factors: &SparseFactors, nrhs: usize) -> Self {
        let dim = factors.n;

        let mut widest_front: usize = 0;
        for node in &factors.node_factors {
            widest_front = widest_front.max(node.frontal_factors.nrow);
        }

        let scaling_applied = !matches!(factors.scaling_info, ScalingInfo::NotApplied);
        let scaled_rhs = if scaling_applied {
            vec![0.0; dim * nrhs]
        } else {
            Vec::new()
        };

        Self {
            y: vec![0.0; dim * nrhs],
            w: vec![0.0; widest_front * nrhs],
            acc: vec![0.0; nrhs],
            scaled_rhs,
            nrhs,
            n: dim,
        }
    }
}
/// Solves the factored system for `nrhs` right-hand sides stored back to back
/// in `rhs` (column `c` occupies `rhs[c * n..(c + 1) * n]`), returning the
/// solutions in the same layout.
///
/// `nrhs == 0` returns an empty vector without inspecting `rhs`.
///
/// # Errors
///
/// Propagates any error from [`solve_sparse_many_into`].
pub fn solve_sparse_many(
    factors: &SparseFactors,
    rhs: &[f64],
    nrhs: usize,
) -> Result<Vec<f64>, FeralError> {
    if nrhs == 0 {
        return Ok(Vec::new());
    }

    let total = factors.n * nrhs;
    let mut workspace = SolveManyWorkspace::for_factors(factors, nrhs);
    let mut solution = vec![0.0; total];
    solve_sparse_many_into(factors, rhs, nrhs, &mut solution, &mut workspace)?;
    Ok(solution)
}
/// Solves the factored system for `nrhs` right-hand sides into `x_out`,
/// reusing the buffers in `ws`.
///
/// `rhs` and `x_out` store column `c` at `[c * n..(c + 1) * n]`. When the
/// factors carry a scaling, each right-hand-side column is multiplied by
/// `factors.scaling` before the core solve and each solution column afterwards.
///
/// `nrhs == 0` returns `Ok(())` before any validation.
///
/// # Errors
///
/// Returns `FeralError::DimensionMismatch` when, in this order,
/// * `ws` was built for a different `n` or `nrhs`,
/// * `rhs.len() != n * nrhs`,
/// * `x_out.len() != n * nrhs`, or
/// * the length of `ws`'s scaled right-hand-side buffer does not match the
///   scaling state of `factors` (`n * nrhs` when scaling is applied, else 0).
pub fn solve_sparse_many_into(
    factors: &SparseFactors,
    rhs: &[f64],
    nrhs: usize,
    x_out: &mut [f64],
    ws: &mut SolveManyWorkspace,
) -> Result<(), FeralError> {
    if nrhs == 0 {
        return Ok(());
    }

    let n = factors.n;
    let total = n * nrhs;
    let mismatch = |expected: usize, got: usize| FeralError::DimensionMismatch { expected, got };

    if ws.nrhs != nrhs || ws.n != n {
        return Err(mismatch(total, ws.n * ws.nrhs));
    }
    if rhs.len() != total {
        return Err(mismatch(total, rhs.len()));
    }
    if x_out.len() != total {
        return Err(mismatch(total, x_out.len()));
    }

    let needs_scaling = !matches!(factors.scaling_info, ScalingInfo::NotApplied);
    let expected_scaled_len = if needs_scaling { total } else { 0 };
    if ws.scaled_rhs.len() != expected_scaled_len {
        return Err(mismatch(expected_scaled_len, ws.scaled_rhs.len()));
    }

    if n == 0 {
        return Ok(());
    }

    // `n > 0` from here on, so chunking by `n` is well defined.
    let rhs_for_core: &[f64] = if needs_scaling {
        let columns = ws.scaled_rhs.chunks_exact_mut(n).zip(rhs.chunks_exact(n));
        for (scaled_col, rhs_col) in columns {
            for i in 0..n {
                scaled_col[i] = rhs_col[i] * factors.scaling[i];
            }
        }
        &ws.scaled_rhs
    } else {
        rhs
    };

    solve_sparse_core_many_into(
        factors,
        rhs_for_core,
        nrhs,
        x_out,
        &mut ws.y,
        &mut ws.w,
        &mut ws.acc,
    );

    if needs_scaling {
        for solution_col in x_out.chunks_exact_mut(n) {
            for i in 0..n {
                solution_col[i] *= factors.scaling[i];
            }
        }
    }
    Ok(())
}
/// Core multi-right-hand-side solve, with no dimension checks and no scaling.
///
/// `rhs` and `x_out` are column-major (`c * n + row`), while the working block
/// `y` and the per-front block `w` are row-major (`row * nrhs + c`), so each
/// row of a front is a contiguous run of `nrhs` values.
///
/// The forward sweep applies the block-diagonal solve (`dsolve_node`) to each
/// front straight after that front's forward substitution; the backward sweep
/// then walks the fronts in reverse. With `nrhs >= BLAS3_NRHS_THRESHOLD` the
/// blocked kernels are used, otherwise the rank-1 kernels.
fn solve_sparse_core_many_into(
    factors: &SparseFactors,
    rhs: &[f64],
    nrhs: usize,
    x_out: &mut [f64],
    y_buf: &mut [f64],
    w_buf: &mut [f64],
    acc_buf: &mut [f64],
) {
    let dim = factors.n;
    let blocked = nrhs >= BLAS3_NRHS_THRESHOLD;
    let y = &mut y_buf[..dim * nrhs];
    // Index range of row `i` in a row-major block with `nrhs` columns.
    let row = |i: usize| i * nrhs..(i + 1) * nrhs;

    // Permute and transpose the right-hand sides into `y`.
    for (permuted, &original) in factors.perm.iter().enumerate() {
        for (c, slot) in y[row(permuted)].iter_mut().enumerate() {
            *slot = rhs[c * dim + original];
        }
    }

    // Forward substitution fused with the block-diagonal solve.
    for node in &factors.node_factors {
        let ff = &node.frontal_factors;
        let (nelim, nrow) = (ff.nelim, ff.nrow);
        if nelim == 0 {
            continue;
        }
        let global_row = |i: usize| row(node.row_indices[ff.perm[i]]);

        let w = &mut w_buf[..nrow * nrhs];
        for i in 0..nrow {
            w[row(i)].copy_from_slice(&y[global_row(i)]);
        }
        if blocked {
            fwd_blas3(w, &ff.l, nrow, nelim, nrhs);
        } else {
            fwd_rank1(w, &ff.l, nrow, nelim, nrhs);
        }
        dsolve_node(w, ff, nelim, nrhs);
        for i in 0..nrow {
            y[global_row(i)].copy_from_slice(&w[row(i)]);
        }
    }

    // Backward substitution, fronts in reverse order.
    let acc = &mut acc_buf[..nrhs];
    for node in factors.node_factors.iter().rev() {
        let ff = &node.frontal_factors;
        let (nelim, nrow) = (ff.nelim, ff.nrow);
        if nelim == 0 {
            continue;
        }
        let global_row = |i: usize| row(node.row_indices[ff.perm[i]]);

        let w = &mut w_buf[..nrow * nrhs];
        for i in 0..nrow {
            w[row(i)].copy_from_slice(&y[global_row(i)]);
        }
        if blocked {
            back_blas3(w, &ff.l, nrow, nelim, nrhs, acc);
        } else {
            back_rank1(w, &ff.l, nrow, nelim, nrhs, acc);
        }
        for i in 0..nrow {
            y[global_row(i)].copy_from_slice(&w[row(i)]);
        }
    }

    // Undo the permutation and transpose back to column-major output.
    for (permuted, &original) in factors.perm.iter().enumerate() {
        for (c, &value) in y[row(permuted)].iter().enumerate() {
            x_out[c * dim + original] = value;
        }
    }
}
/// Forward substitution on one front for `nrhs` right-hand sides, one
/// eliminated column at a time.
///
/// `w` is row-major (`row * nrhs + c`) with `nrow` rows; `l` is column-major
/// with leading dimension `nrow` and an implicit unit diagonal. For every
/// eliminated column `j` and every row `i > j`, row `i` of `w` has
/// `l[j * nrow + i]` times row `j` subtracted from it.
fn fwd_rank1(w: &mut [f64], l: &[f64], nrow: usize, nelim: usize, nrhs: usize) {
    for j in 0..nelim {
        // Split so the pivot row can be read while the rows below are updated.
        let (upper, lower) = w.split_at_mut((j + 1) * nrhs);
        let pivot_row = &upper[j * nrhs..];
        let column = j * nrow;
        for i in (j + 1)..nrow {
            let coeff = l[column + i];
            let start = (i - j - 1) * nrhs;
            for (dst, &p) in lower[start..start + nrhs].iter_mut().zip(pivot_row) {
                *dst -= coeff * p;
            }
        }
    }
}
/// Backward substitution on one front for `nrhs` right-hand sides, one
/// eliminated column at a time (last column first).
///
/// For each eliminated column `j`, the sum over rows `i > j` of
/// `l[j * nrow + i]` times row `i` of `w` is accumulated in `acc` and then
/// subtracted from row `j`. `acc` is a caller-provided scratch buffer with at
/// least `nrhs` entries; it is zeroed in full before each column.
fn back_rank1(w: &mut [f64], l: &[f64], nrow: usize, nelim: usize, nrhs: usize, acc: &mut [f64]) {
    for j in (0..nelim).rev() {
        acc.fill(0.0);
        let column = j * nrow;
        for i in (j + 1)..nrow {
            let coeff = l[column + i];
            let below = &w[i * nrhs..(i + 1) * nrhs];
            for (sum, &value) in acc[..nrhs].iter_mut().zip(below) {
                *sum += coeff * value;
            }
        }
        let target = &mut w[j * nrhs..(j + 1) * nrhs];
        for (dst, &sum) in target.iter_mut().zip(&acc[..nrhs]) {
            *dst -= sum;
        }
    }
}
/// Blocked forward substitution on one front for `nrhs` right-hand sides.
///
/// First the leading `nelim × nelim` triangle is solved column by column on
/// the first `nelim` rows of `w`. Then, if the front has rows below the
/// eliminated block, those rows are updated in one panel product:
/// `w[nelim..] -= L[nelim.., ..nelim] · w[..nelim]`.
fn fwd_blas3(w: &mut [f64], l: &[f64], nrow: usize, nelim: usize, nrhs: usize) {
    for j in 0..nelim {
        let (upper, lower) = w.split_at_mut((j + 1) * nrhs);
        let pivot_row = &upper[j * nrhs..];
        let column = j * nrow;
        for i in (j + 1)..nelim {
            let coeff = l[column + i];
            let start = (i - j - 1) * nrhs;
            for (dst, &p) in lower[start..start + nrhs].iter_mut().zip(pivot_row) {
                *dst -= coeff * p;
            }
        }
    }

    if nelim >= nrow {
        return;
    }

    let (solved, trailing) = w.split_at_mut(nelim * nrhs);
    // Sub-diagonal panel: entry (m, k) lives at `l[nelim + m + k * nrow]`.
    let panel = PanelBlock {
        l,
        base: nelim,
        row_stride: 1,
        col_stride: nrow,
    };
    gemm_panel_minus(trailing, &panel, solved, nrow - nelim, nelim, nrhs);
}
/// Blocked backward substitution on one front for `nrhs` right-hand sides.
///
/// If the front has rows below the eliminated block, their contribution is
/// removed first in one panel product using the transposed sub-diagonal panel:
/// `w[..nelim] -= L[nelim.., ..nelim]ᵀ · w[nelim..]`. The leading
/// `nelim × nelim` triangle is then solved column by column, last column
/// first, with `acc` as the per-column accumulator.
fn back_blas3(w: &mut [f64], l: &[f64], nrow: usize, nelim: usize, nrhs: usize, acc: &mut [f64]) {
    if nelim < nrow {
        let (leading, trailing) = w.split_at_mut(nelim * nrhs);
        // Transposed panel: entry (m, k) lives at `l[nelim + m * nrow + k]`.
        let panel_t = PanelBlock {
            l,
            base: nelim,
            row_stride: nrow,
            col_stride: 1,
        };
        gemm_panel_minus(leading, &panel_t, trailing, nelim, nrow - nelim, nrhs);
    }

    for j in (0..nelim).rev() {
        acc.fill(0.0);
        let column = j * nrow;
        for i in (j + 1)..nelim {
            let coeff = l[column + i];
            let below = &w[i * nrhs..(i + 1) * nrhs];
            for (sum, &value) in acc[..nrhs].iter_mut().zip(below) {
                *sum += coeff * value;
            }
        }
        let target = &mut w[j * nrhs..(j + 1) * nrhs];
        for (dst, &sum) in target.iter_mut().zip(&acc[..nrhs]) {
            *dst -= sum;
        }
    }
}
/// Applies the block-diagonal solve of one front to the first `nelim` rows of
/// the row-major block `w` (`row * nrhs + c`), one right-hand-side column at
/// a time.
///
/// A pivot at `k` is treated as 2×2 when `k + 1 < nelim` and
/// `ff.d_subdiag[k]` is non-zero; it is solved with `solve_2x2_dblock`, and a
/// rejected block leaves both entries unchanged. Otherwise the pivot is 1×1
/// and the entry is divided by `ff.d_diag[k]` only when its magnitude exceeds
/// `ff.zero_tol`.
fn dsolve_node(
    w: &mut [f64],
    ff: &crate::dense::factor::FrontalFactors,
    nelim: usize,
    nrhs: usize,
) {
    for c in 0..nrhs {
        let mut k = 0;
        while k < nelim {
            let top = k * nrhs + c;
            let is_2x2 = k + 1 < nelim && ff.d_subdiag[k] != 0.0;
            if is_2x2 {
                let bottom = top + nrhs;
                let solved = solve_2x2_dblock(
                    ff.d_diag[k],
                    ff.d_subdiag[k],
                    ff.d_diag[k + 1],
                    w[top],
                    w[bottom],
                );
                if let Some((x0, x1)) = solved {
                    w[top] = x0;
                    w[bottom] = x1;
                }
                k += 2;
            } else {
                let pivot = ff.d_diag[k];
                if pivot.abs() > ff.zero_tol {
                    w[top] /= pivot;
                }
                k += 1;
            }
        }
    }
}
/// Strided view of a rectangular block inside a factor array.
///
/// The GEMM helpers read element `(m, k)` of the block as
/// `l[base + m * row_stride + k * col_stride]`, so swapping the two strides
/// views the same storage transposed.
struct PanelBlock<'a> {
    // Backing storage the block is read from.
    l: &'a [f64],
    // Offset of element (0, 0) within `l`.
    base: usize,
    // Distance in `l` between consecutive block rows.
    row_stride: usize,
    // Distance in `l` between consecutive block columns.
    col_stride: usize,
}
/// Computes `C -= A · B` for row-major `C` (`m_dim × nrhs`) and `B`
/// (`k_dim × nrhs`), with `A` (`m_dim × k_dim`) read through the strided view
/// `a`.
///
/// Full `MR × NR` (4 × 8) tiles are processed by an unrolled register-tile
/// loop; the leftover columns of those rows and the leftover rows are handled
/// by `gemm_scalar_block`. For every output element the `k` products are
/// subtracted in increasing `k` order on both paths.
fn gemm_panel_minus(
    c_rows: &mut [f64],
    a: &PanelBlock,
    b_rows: &[f64],
    m_dim: usize,
    k_dim: usize,
    nrhs: usize,
) {
    const MR: usize = 4;
    const NR: usize = 8;
    let m_main = m_dim - m_dim % MR;
    let c_main = nrhs - nrhs % NR;

    for m0 in (0..m_main).step_by(MR) {
        let tile_rows = &mut c_rows[m0 * nrhs..(m0 + MR) * nrhs];
        // Offsets into `a.l` of the first element of each of the four A rows.
        let a_rows: [usize; MR] = [
            a.base + m0 * a.row_stride,
            a.base + (m0 + 1) * a.row_stride,
            a.base + (m0 + 2) * a.row_stride,
            a.base + (m0 + 3) * a.row_stride,
        ];

        for c0 in (0..c_main).step_by(NR) {
            // Load the current C tile into local accumulators.
            let mut acc = [[0.0f64; NR]; MR];
            for (r, acc_row) in acc.iter_mut().enumerate() {
                let start = r * nrhs + c0;
                acc_row.copy_from_slice(&tile_rows[start..start + NR]);
            }

            let mut b_tile = [0.0f64; NR];
            for k in 0..k_dim {
                let b_start = k * nrhs + c0;
                b_tile.copy_from_slice(&b_rows[b_start..b_start + NR]);
                let kc = k * a.col_stride;
                let coeffs: [f64; MR] = [
                    a.l[a_rows[0] + kc],
                    a.l[a_rows[1] + kc],
                    a.l[a_rows[2] + kc],
                    a.l[a_rows[3] + kc],
                ];
                for s in 0..NR {
                    let bv = b_tile[s];
                    acc[0][s] -= coeffs[0] * bv;
                    acc[1][s] -= coeffs[1] * bv;
                    acc[2][s] -= coeffs[2] * bv;
                    acc[3][s] -= coeffs[3] * bv;
                }
            }

            // Store the finished tile back into C.
            for (r, acc_row) in acc.iter().enumerate() {
                let start = r * nrhs + c0;
                tile_rows[start..start + NR].copy_from_slice(acc_row);
            }
        }
    }

    // Column remainder of the tiled rows, then the row remainder in full.
    gemm_scalar_block(c_rows, a, b_rows, 0, m_main, c_main, nrhs, k_dim, nrhs);
    gemm_scalar_block(c_rows, a, b_rows, m_main, m_dim, 0, nrhs, k_dim, nrhs);
}
/// Scalar fallback for `C -= A · B` restricted to rows `m_lo..m_hi` and
/// columns `c_lo..c_hi` of the row-major `C` (`nrhs` columns per row).
///
/// `A` is read through the strided view `a`; `B` is row-major with `nrhs`
/// columns and `k_dim` rows. Products are subtracted in increasing `k` order.
// Nine scalar parameters describe the sub-block; grouping them would only
// add a one-off struct.
#[allow(clippy::too_many_arguments)]
fn gemm_scalar_block(
    c_rows: &mut [f64],
    a: &PanelBlock,
    b_rows: &[f64],
    m_lo: usize,
    m_hi: usize,
    c_lo: usize,
    c_hi: usize,
    k_dim: usize,
    nrhs: usize,
) {
    for m in m_lo..m_hi {
        let a_row = a.base + m * a.row_stride;
        let out_row = &mut c_rows[m * nrhs..(m + 1) * nrhs];
        for c in c_lo..c_hi {
            out_row[c] = (0..k_dim).fold(out_row[c], |acc, k| {
                acc - a.l[a_row + k * a.col_stride] * b_rows[k * nrhs + c]
            });
        }
    }
}
/// Solves the system for one right-hand side with iterative refinement against
/// `matrix`, returning only the solution.
///
/// # Errors
///
/// Propagates any error from the refinement core, e.g. a `DimensionMismatch`
/// when `rhs.len() != factors.n`.
pub fn solve_sparse_refined(
    matrix: &CscMatrix,
    factors: &SparseFactors,
    rhs: &[f64],
) -> Result<Vec<f64>, FeralError> {
    solve_sparse_refined_core(matrix, factors, rhs, false).map(|(solution, _diag)| solution)
}
/// Record of one iterate produced during iterative refinement.
#[derive(Debug, Clone, Copy)]
pub struct RefinementStep {
    /// Iterate index; 0 is the initial (unrefined) solve.
    pub step: usize,
    /// 2-norm of the residual `rhs - A·x` for this iterate.
    pub residual_2norm: f64,
    /// `residual_2norm` divided by the 2-norm of `rhs`, or `residual_2norm`
    /// itself when `rhs` has zero norm.
    pub relative_residual: f64,
    /// Estimated condition number times `relative_residual`.
    pub forward_error_bound: f64,
    /// Whether this iterate's residual norm was strictly below the best seen
    /// so far; always `true` for step 0.
    pub improved: bool,
}
/// Diagnostics collected by `solve_sparse_refined_with_diagnostics`.
#[derive(Debug, Clone)]
pub struct RefinementDiagnostics {
    /// Value of `matrix_norm_1` for the matrix; 0.0 for an empty system.
    pub anorm_1: f64,
    /// `anorm_1` times the estimated 1-norm of the inverse; 0.0 for an empty
    /// system.
    pub kappa_1_est: f64,
    /// One record per iterate, starting with step 0.
    pub steps: Vec<RefinementStep>,
    /// Index of the last iterate that improved the residual, i.e. the iterate
    /// whose solution was returned.
    pub returned_step: usize,
}
/// Solves the system for one right-hand side with iterative refinement and
/// returns the solution together with per-step diagnostics.
///
/// # Errors
///
/// Propagates any error from the refinement core. Should the core return no
/// diagnostics despite being asked for them, a
/// `DimensionMismatch { expected: 1, got: 0 }` is reported.
pub fn solve_sparse_refined_with_diagnostics(
    matrix: &CscMatrix,
    factors: &SparseFactors,
    rhs: &[f64],
) -> Result<(Vec<f64>, RefinementDiagnostics), FeralError> {
    match solve_sparse_refined_core(matrix, factors, rhs, true)? {
        (solution, Some(diagnostics)) => Ok((solution, diagnostics)),
        (_, None) => Err(FeralError::DimensionMismatch {
            expected: 1,
            got: 0,
        }),
    }
}
/// Shared implementation of single right-hand-side iterative refinement.
///
/// After an initial solve, up to 10 correction steps are taken: the current
/// residual `rhs - A·x` is solved for a correction that is added to `x`. The
/// iterate with the smallest residual 2-norm is the one returned.
///
/// Iteration stops when
/// * the best residual norm drops below `EPSILON · sqrt(n) · ‖rhs‖₂` (or below
///   `EPSILON · sqrt(n)` when `rhs` has zero norm),
/// * the newest residual norm exceeds 100 times the best one, or
/// * two steps in a row fail to improve on the best residual.
///
/// With `with_diagnostics` set, the matrix 1-norm and a condition estimate are
/// computed up front (skipped for `n == 0`) and every iterate is recorded; the
/// returned option is `Some` exactly in that case.
///
/// # Errors
///
/// Returns `DimensionMismatch` when `rhs.len() != factors.n`, and propagates
/// errors from the condition estimate and from the underlying solves.
fn solve_sparse_refined_core(
    matrix: &CscMatrix,
    factors: &SparseFactors,
    rhs: &[f64],
    with_diagnostics: bool,
) -> Result<(Vec<f64>, Option<RefinementDiagnostics>), FeralError> {
    let n = factors.n;
    if rhs.len() != n {
        return Err(FeralError::DimensionMismatch {
            expected: n,
            got: rhs.len(),
        });
    }

    let (anorm_1, kappa_1_est) = if with_diagnostics && n > 0 {
        let a_norm = matrix_norm_1(matrix);
        let inv_norm = estimate_inverse_norm_1(factors)?;
        (a_norm, a_norm * inv_norm)
    } else {
        (0.0, 0.0)
    };

    // Overwrites `r` with `rhs - A·x`.
    let refresh_residual = |x: &[f64], r: &mut [f64]| {
        matrix.symv(x, r);
        for (r_i, &b_i) in r.iter_mut().zip(rhs) {
            *r_i = b_i - *r_i;
        }
    };

    // Initial solve and residual.
    let mut ws = SolveWorkspace::for_factors(factors);
    let mut x = vec![0.0; n];
    solve_sparse_into_ws(factors, rhs, &mut x, &mut ws)?;
    let mut r = vec![0.0; n];
    refresh_residual(x.as_slice(), r.as_mut_slice());
    let initial_r_norm = norm2(&r);

    let mut best_x = x.clone();
    let mut best_r_norm = initial_r_norm;
    let mut stagnant_count: usize = 0;
    let mut dx = vec![0.0; n];

    let max_steps = 10;
    let max_stagnant_steps = 2;
    let divergence_factor = 100.0;
    let threshold = f64::EPSILON * (n as f64).sqrt();
    let b_norm = norm2(rhs);
    let target = if b_norm > 0.0 {
        threshold * b_norm
    } else {
        threshold
    };
    let converged = |r_norm: f64| r_norm < target;

    let make_step = |step: usize, residual_2norm: f64, improved: bool| {
        let relative_residual = if b_norm > 0.0 {
            residual_2norm / b_norm
        } else {
            residual_2norm
        };
        RefinementStep {
            step,
            residual_2norm,
            relative_residual,
            forward_error_bound: kappa_1_est * relative_residual,
            improved,
        }
    };

    let mut steps: Vec<RefinementStep> = Vec::new();
    if with_diagnostics {
        steps.push(make_step(0, initial_r_norm, true));
    }
    let mut returned_step: usize = 0;

    for step in 1..=max_steps {
        if converged(best_r_norm) {
            break;
        }

        // Solve for the correction and apply it.
        solve_sparse_into_ws(factors, &r, &mut dx, &mut ws)?;
        for (x_i, &d_i) in x.iter_mut().zip(&dx) {
            *x_i += d_i;
        }
        refresh_residual(x.as_slice(), r.as_mut_slice());
        let r_norm = norm2(&r);

        let improved = r_norm < best_r_norm;
        if improved {
            best_r_norm = r_norm;
            best_x.copy_from_slice(&x);
            stagnant_count = 0;
            returned_step = step;
        } else {
            stagnant_count += 1;
        }

        if with_diagnostics {
            steps.push(make_step(step, r_norm, improved));
        }

        let diverged = r_norm > best_r_norm * divergence_factor;
        if diverged || stagnant_count >= max_stagnant_steps {
            break;
        }
    }

    let diag = with_diagnostics.then(|| RefinementDiagnostics {
        anorm_1,
        kappa_1_est,
        steps,
        returned_step,
    });
    Ok((best_x, diag))
}
/// Solves the system for `nrhs` right-hand sides with per-column iterative
/// refinement against `matrix`.
///
/// `rhs` and the returned vector store column `c` at `[c * n..(c + 1) * n]`.
/// After an initial batched solve, only the columns that have not yet met the
/// residual target stay "active"; each step solves a batched correction for
/// the active columns. A column is retired when its best residual meets the
/// target (`EPSILON · sqrt(n)` times the column's right-hand-side norm, or the
/// bare threshold for a zero column), when its newest residual exceeds 100
/// times its best, or after two consecutive non-improving steps. At most 10
/// steps are taken, and for every column the iterate with the smallest
/// residual is returned.
///
/// # Errors
///
/// Returns `DimensionMismatch` when `rhs.len() != n * nrhs`, and propagates
/// errors from the underlying batched solves.
pub fn solve_sparse_many_refined(
    matrix: &CscMatrix,
    factors: &SparseFactors,
    rhs: &[f64],
    nrhs: usize,
) -> Result<Vec<f64>, FeralError> {
    let n = factors.n;
    if rhs.len() != n * nrhs {
        return Err(FeralError::DimensionMismatch {
            expected: n * nrhs,
            got: rhs.len(),
        });
    }
    if nrhs == 0 || n == 0 {
        return Ok(vec![0.0; n * nrhs]);
    }

    let max_steps = 10;
    let max_stagnant_steps = 2;
    let divergence_factor = 100.0;
    let threshold = f64::EPSILON * (n as f64).sqrt();
    let converged = |r_norm: f64, b_norm: f64| -> bool {
        let target = if b_norm > 0.0 {
            threshold * b_norm
        } else {
            threshold
        };
        r_norm < target
    };

    // Index range of column `c` in a column-major block with `n` rows.
    let col = |c: usize| c * n..(c + 1) * n;
    // Overwrites `out` with `b_col - A·x_col`.
    let residual_into = |x_col: &[f64], b_col: &[f64], out: &mut [f64]| {
        matrix.symv(x_col, out);
        for (o, &b) in out.iter_mut().zip(b_col) {
            *o = b - *o;
        }
    };

    let mut x = solve_sparse_many(factors, rhs, nrhs)?;

    let mut best_rn = vec![0.0f64; nrhs];
    let mut bnorm = vec![0.0f64; nrhs];
    let mut rc = vec![0.0f64; n];
    let mut active: Vec<usize> = Vec::new();
    for c in 0..nrhs {
        residual_into(&x[col(c)], &rhs[col(c)], rc.as_mut_slice());
        bnorm[c] = norm2(&rhs[col(c)]);
        best_rn[c] = norm2(&rc);
        if !converged(best_rn[c], bnorm[c]) {
            active.push(c);
        }
    }
    if active.is_empty() {
        return Ok(x);
    }

    let mut best_x = x.clone();
    let mut stagnant = vec![0usize; nrhs];
    let mut r_act = vec![0.0f64; n * active.len()];

    for _ in 0..max_steps {
        if active.is_empty() {
            break;
        }
        let na = active.len();

        // Pack the residuals of the active columns and solve for corrections.
        for (k, &c) in active.iter().enumerate() {
            residual_into(&x[col(c)], &rhs[col(c)], &mut r_act[col(k)]);
        }
        let dx = solve_sparse_many(factors, &r_act[..n * na], na)?;

        let mut still_active: Vec<usize> = Vec::with_capacity(na);
        for (k, &c) in active.iter().enumerate() {
            for (x_i, &d_i) in x[col(c)].iter_mut().zip(&dx[col(k)]) {
                *x_i += d_i;
            }
            residual_into(&x[col(c)], &rhs[col(c)], rc.as_mut_slice());
            let rn = norm2(&rc);

            if rn < best_rn[c] {
                best_rn[c] = rn;
                best_x[col(c)].copy_from_slice(&x[col(c)]);
                stagnant[c] = 0;
            } else {
                stagnant[c] += 1;
            }

            let finished = converged(best_rn[c], bnorm[c])
                || rn > best_rn[c] * divergence_factor
                || stagnant[c] >= max_stagnant_steps;
            if !finished {
                still_active.push(c);
            }
        }
        active = still_active;
    }

    Ok(best_x)
}
/// Euclidean (2-) norm of `v`: the square root of the sum of squares, summed
/// in slice order.
fn norm2(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|&value| value * value).sum();
    sum_of_squares.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dense::factor::{BunchKaufmanParams, ZeroPivotAction};
    use crate::numeric::factorize::factorize_multifrontal;
    use crate::sparse::csc::CscMatrix;
    use crate::symbolic::{symbolic_factorize, SupernodeParams};

    /// Numeric parameters that force-accept zero pivots instead of failing.
    fn make_params() -> crate::numeric::factorize::NumericParams {
        crate::numeric::factorize::NumericParams::with_bk(BunchKaufmanParams {
            on_zero_pivot: ZeroPivotAction::ForceAccept,
            ..BunchKaufmanParams::default()
        })
    }

    /// Factorizes `m`, solves against `rhs`, and asserts that the relative
    /// residual `‖A·x − b‖₂ / ‖b‖₂` (absolute when `b` is zero) is below `tol`.
    fn check_solve(m: &CscMatrix, rhs: &[f64], tol: f64) {
        let sym = symbolic_factorize(m, &SupernodeParams::default()).unwrap();
        let params = make_params();
        let (factors, _) = factorize_multifrontal(m, &sym, &params).unwrap();
        let x = solve_sparse(&factors, rhs).unwrap();
        let n = m.n;
        let mut ax = vec![0.0; n];
        m.symv(&x, &mut ax);
        let mut res_sq = 0.0;
        let mut b_sq = 0.0;
        for i in 0..n {
            res_sq += (ax[i] - rhs[i]).powi(2);
            b_sq += rhs[i].powi(2);
        }
        let rel_res = if b_sq > 0.0 {
            (res_sq / b_sq).sqrt()
        } else {
            res_sq.sqrt()
        };
        assert!(
            rel_res < tol,
            "relative residual {:.2e} exceeds tolerance {:.2e}",
            rel_res,
            tol
        );
    }

    #[test]
    fn test_solve_diagonal() {
        let m = CscMatrix::from_triplets(3, &[0, 1, 2], &[0, 1, 2], &[2.0, 3.0, 5.0]).unwrap();
        check_solve(&m, &[4.0, 9.0, 25.0], 1e-14);
    }

    #[test]
    fn test_solve_tridiagonal() {
        let m = CscMatrix::from_triplets(
            3,
            &[0, 1, 1, 2, 2],
            &[0, 0, 1, 1, 2],
            &[2.0, -1.0, 2.0, -1.0, 2.0],
        )
        .unwrap();
        check_solve(&m, &[1.0, 0.0, 1.0], 1e-13);
    }

    #[test]
    fn test_solve_kkt() {
        let m = CscMatrix::from_triplets(
            3,
            &[0, 1, 2, 2, 2],
            &[0, 1, 0, 1, 2],
            &[2.0, 3.0, 1.0, 1.0, -1e-8],
        )
        .unwrap();
        check_solve(&m, &[1.0, 2.0, 3.0], 1e-6);
    }

    #[test]
    fn test_solve_larger_spd() {
        let n = 5;
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for i in 0..n {
            rows.push(i);
            cols.push(i);
            vals.push(4.0);
            if i + 1 < n {
                rows.push(i + 1);
                cols.push(i);
                vals.push(-1.0);
            }
        }
        let m = CscMatrix::from_triplets(n, &rows, &cols, &vals).unwrap();
        check_solve(
            &m,
            &(0..n).map(|i| (i + 1) as f64).collect::<Vec<_>>(),
            1e-13,
        );
    }

    #[test]
    fn test_solve_indefinite() {
        let m = CscMatrix::from_triplets(2, &[0, 1, 1], &[0, 0, 1], &[1.0, 2.0, 1.0]).unwrap();
        check_solve(&m, &[5.0, 4.0], 1e-13);
    }

    #[test]
    fn test_solve_arrow_multi_supernode() {
        let m = CscMatrix::from_triplets(
            5,
            &[0, 1, 2, 3, 4, 1, 2, 3, 4],
            &[0, 0, 0, 0, 0, 1, 2, 3, 4],
            &[10.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        )
        .unwrap();
        check_solve(&m, &[1.0, 2.0, 3.0, 4.0, 5.0], 1e-12);
    }

    /// Factorizes `m` with the default numeric parameters.
    fn factor_well_cond(m: &CscMatrix) -> SparseFactors {
        let sym = symbolic_factorize(m, &SupernodeParams::default()).unwrap();
        let (factors, _) = factorize_multifrontal(
            m,
            &sym,
            &crate::numeric::factorize::NumericParams::default(),
        )
        .unwrap();
        factors
    }

    /// Lower triangle of the `n × n` Hilbert matrix, `H[i][j] = 1 / (i + j + 1)`.
    fn hilbert(n: usize) -> CscMatrix {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for j in 0..n {
            for i in j..n {
                rows.push(i);
                cols.push(j);
                vals.push(1.0 / ((i + j + 1) as f64));
            }
        }
        CscMatrix::from_triplets(n, &rows, &cols, &vals).unwrap()
    }

    // Collecting diagnostics must not change the returned iterate bit-for-bit.
    #[test]
    fn diagnostics_match_non_diagnostic_solution() {
        let m = CscMatrix::from_triplets(
            3,
            &[0, 1, 2, 2, 2],
            &[0, 1, 0, 1, 2],
            &[2.0, 3.0, 1.0, 1.0, -1e-8],
        )
        .unwrap();
        let rhs = [1.0, 2.0, 3.0];
        let factors = factor_well_cond(&m);
        let x_plain = solve_sparse_refined(&m, &factors, &rhs).unwrap();
        let (x_diag, _diag) = solve_sparse_refined_with_diagnostics(&m, &factors, &rhs).unwrap();
        for i in 0..x_plain.len() {
            assert_eq!(
                x_plain[i].to_bits(),
                x_diag[i].to_bits(),
                "iterate mismatch at index {}: {} vs {}",
                i,
                x_plain[i],
                x_diag[i],
            );
        }
    }

    #[test]
    fn diagnostics_populate_well_conditioned() {
        let n = 5;
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for i in 0..n {
            rows.push(i);
            cols.push(i);
            vals.push(4.0);
            if i + 1 < n {
                rows.push(i + 1);
                cols.push(i);
                vals.push(-1.0);
            }
        }
        let m = CscMatrix::from_triplets(n, &rows, &cols, &vals).unwrap();
        let rhs: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
        let factors = factor_well_cond(&m);
        let (_, diag) = solve_sparse_refined_with_diagnostics(&m, &factors, &rhs).unwrap();
        assert!(diag.anorm_1 > 0.0, "anorm_1 must be > 0 for nonzero A");
        assert!(
            diag.kappa_1_est >= 1.0 - 1e-8,
            "kappa_1_est {} below 1.0 lower bound",
            diag.kappa_1_est
        );
        assert!(!diag.steps.is_empty(), "diagnostics must contain step 0");
        assert_eq!(diag.steps[0].step, 0);
        assert!(diag.returned_step < diag.steps.len());
        // The returned iterate must be the one with the smallest residual.
        let best = diag
            .steps
            .iter()
            .map(|s| s.residual_2norm)
            .fold(f64::INFINITY, f64::min);
        assert_eq!(diag.steps[diag.returned_step].residual_2norm, best);
    }

    #[test]
    fn diagnostics_kappa_matches_standalone() {
        let m = hilbert(6);
        let rhs = [1.0, 0.5, 1.0, 0.5, 1.0, 0.5];
        let factors = factor_well_cond(&m);
        let kappa_standalone =
            crate::numeric::condition::estimate_condition_1norm(&m, &factors).unwrap();
        let (_, diag) = solve_sparse_refined_with_diagnostics(&m, &factors, &rhs).unwrap();
        assert_eq!(
            diag.kappa_1_est.to_bits(),
            kappa_standalone.to_bits(),
            "diag kappa {} != standalone {}",
            diag.kappa_1_est,
            kappa_standalone,
        );
        assert!(
            diag.kappa_1_est > 1.0e4,
            "Hilbert-6 kappa_1_est {} too small",
            diag.kappa_1_est,
        );
    }

    #[test]
    fn diagnostics_forward_error_bound_field() {
        let m = hilbert(4);
        let rhs = [1.0, 2.0, 3.0, 4.0];
        let factors = factor_well_cond(&m);
        let (_, diag) = solve_sparse_refined_with_diagnostics(&m, &factors, &rhs).unwrap();
        for s in &diag.steps {
            let expected = diag.kappa_1_est * s.relative_residual;
            let diff = (s.forward_error_bound - expected).abs();
            assert!(
                diff <= 1e-15 * expected.max(1.0),
                "step {} fwd-err {} vs expected {} (diff {})",
                s.step,
                s.forward_error_bound,
                expected,
                diff
            );
            assert!(s.forward_error_bound >= 0.0);
            assert!(s.residual_2norm.is_finite());
        }
    }

    #[test]
    fn diagnostics_n_zero() {
        let m = CscMatrix::from_triplets(0, &[], &[], &[]).unwrap();
        let factors = factor_well_cond(&m);
        let (x, diag) = solve_sparse_refined_with_diagnostics(&m, &factors, &[]).unwrap();
        assert!(x.is_empty());
        assert_eq!(diag.anorm_1, 0.0);
        assert_eq!(diag.kappa_1_est, 0.0);
    }

    #[test]
    fn diagnostics_dim_mismatch_rejected() {
        let m = CscMatrix::from_triplets(3, &[0, 1, 2], &[0, 1, 2], &[1.0, 2.0, 3.0]).unwrap();
        let factors = factor_well_cond(&m);
        let r = solve_sparse_refined_with_diagnostics(&m, &factors, &[1.0, 2.0]);
        assert!(r.is_err());
    }

    // A workspace built for unscaled factors has no scaled-RHS buffer and must
    // be rejected when reused with scaled factors.
    #[test]
    fn solve_many_into_rejects_scaling_mismatched_workspace() {
        use crate::scaling::ScalingStrategy;
        let m = CscMatrix::from_triplets(
            3,
            &[0, 1, 1, 2, 2],
            &[0, 0, 1, 1, 2],
            &[2.0, -1.0, 2.0, -1.0, 2.0],
        )
        .unwrap();
        let sym = symbolic_factorize(&m, &SupernodeParams::default()).unwrap();
        let nrhs = 2;

        let mut params_unscaled = make_params();
        params_unscaled.scaling = ScalingStrategy::Identity;
        let (factors_unscaled, _) = factorize_multifrontal(&m, &sym, &params_unscaled).unwrap();
        assert!(matches!(
            factors_unscaled.scaling_info,
            ScalingInfo::NotApplied
        ));
        let mut ws = SolveManyWorkspace::for_factors(&factors_unscaled, nrhs);
        assert_eq!(ws.scaled_rhs.len(), 0);

        let mut params_scaled = make_params();
        params_scaled.scaling = ScalingStrategy::External(vec![1.0; m.n]);
        let (factors_scaled, _) = factorize_multifrontal(&m, &sym, &params_scaled).unwrap();
        assert!(!matches!(
            factors_scaled.scaling_info,
            ScalingInfo::NotApplied
        ));

        let rhs = vec![1.0; m.n * nrhs];
        let mut x = vec![0.0; m.n * nrhs];
        let result = solve_sparse_many_into(&factors_scaled, &rhs, nrhs, &mut x, &mut ws);
        assert!(
            matches!(result, Err(FeralError::DimensionMismatch { .. })),
            "expected DimensionMismatch for a scaling-mismatched workspace, got {result:?}"
        );
    }

    // REG-3: a 2×2 pivot block with tiny entries must still be inverted by
    // `dsolve_node` rather than skipped.
    #[test]
    fn reg3_sparse_dsolve_small_scale_2x2_inverted() {
        let (a, b, c) = (1e-16, 1e-17, 1e-16);
        let x_true = [1.0_f64, 1.0_f64];
        let rhs = [a * x_true[0] + b * x_true[1], b * x_true[0] + c * x_true[1]];
        let ff = crate::dense::factor::FrontalFactors {
            nrow: 2,
            ncol: 2,
            nelim: 2,
            l: vec![1.0, 0.0, 0.0, 1.0],
            d_diag: vec![a, c],
            d_subdiag: vec![b, 0.0],
            perm: vec![0, 1],
            perm_inv: vec![0, 1],
            contrib: vec![],
            contrib_dim: 0,
            n_delayed: 0,
            inertia: crate::inertia::Inertia::new(1, 1, 0),
            needs_refinement: false,
            n_rook_rescues: 0,
            n_tiny: 0,
            zero_tol: f64::EPSILON,
            zero_tol_2x2: f64::EPSILON * f64::EPSILON,
        };
        let mut w = vec![rhs[0], rhs[1]];
        dsolve_node(&mut w, &ff, 2, 1);
        assert!(
            (w[0] - x_true[0]).abs() < 1e-6 && (w[1] - x_true[1]).abs() < 1e-6,
            "REG-3: small-scale 2×2 must be inverted by sparse dsolve, not \
             skipped; got w = {w:?}, expected ≈ {x_true:?}"
        );
    }

    #[test]
    fn reg3_helper_inverts_small_scale_block() {
        let (a, b, c) = (1e-16, 1e-17, 1e-16);
        let x = [1.0_f64, 1.0_f64];
        let (z0, z1) = (a * x[0] + b * x[1], b * x[0] + c * x[1]);
        let (x0, x1) = solve_2x2_dblock(a, b, c, z0, z1).expect("accepted");
        assert!((x0 - 1.0).abs() < 1e-6 && (x1 - 1.0).abs() < 1e-6);
    }

    // A block that the determinant-floor test rejects must yield `None`.
    #[test]
    fn reg3_rejected_block_skipped_by_helper() {
        let p = (1u64 << 53) as f64;
        assert!(solve_2x2_dblock(p + 1.0, p, p, 1.0, 2.0).is_none());
    }
}