#![allow(clippy::needless_range_loop)]
use super::factorize::SparseFactors;
use crate::error::FeralError;
use crate::scaling::ScalingInfo;
use crate::sparse::csc::CscMatrix;
/// Solves the factorized system for one right-hand side and returns the
/// solution in a newly allocated vector.
///
/// A temporary [`SolveWorkspace`] is created for the call, so callers do not
/// have to manage scratch buffers themselves.
///
/// # Errors
///
/// Forwards the `FeralError::DimensionMismatch` reported by the workspace
/// solve when `rhs.len()` differs from `factors.n`.
pub fn solve_sparse(factors: &SparseFactors, rhs: &[f64]) -> Result<Vec<f64>, FeralError> {
    let dim = factors.n;
    // Empty system with an empty right-hand side: nothing to allocate or solve.
    if rhs.is_empty() && dim == 0 {
        return Ok(Vec::new());
    }
    let mut workspace = SolveWorkspace::for_factors(factors);
    let mut solution = vec![0.0; dim];
    solve_sparse_into_ws(factors, rhs, &mut solution, &mut workspace)?;
    Ok(solution)
}
/// Scratch buffers reused across single-right-hand-side solves so that
/// repeated solves (e.g. during iterative refinement) do not reallocate.
struct SolveWorkspace {
/// Working vector in the permuted ordering; `for_factors` allocates
/// `factors.n` entries.
y: Vec<f64>,
/// Per-front scratch; `for_factors` sizes it to the largest `nrow` among
/// the node factors.
w: Vec<f64>,
/// Storage for the scaled right-hand side; `for_factors` leaves it empty
/// when `scaling_info` is `ScalingInfo::NotApplied`.
scaled_rhs: Vec<f64>,
}
impl SolveWorkspace {
    /// Allocates zero-filled scratch buffers sized for `factors`.
    ///
    /// `y` gets `factors.n` entries, `w` gets as many entries as the tallest
    /// frontal matrix has rows, and `scaled_rhs` gets `factors.n` entries
    /// unless no scaling was applied, in which case it stays empty.
    fn for_factors(factors: &SparseFactors) -> Self {
        let dim = factors.n;
        // The front scratch must be able to hold the tallest front.
        let tallest_front = factors
            .node_factors
            .iter()
            .map(|node| node.frontal_factors.nrow)
            .fold(0, usize::max);
        let scaled_rhs = match factors.scaling_info {
            ScalingInfo::NotApplied => Vec::new(),
            _ => vec![0.0; dim],
        };
        Self {
            y: vec![0.0; dim],
            w: vec![0.0; tallest_front],
            scaled_rhs,
        }
    }
}
/// Solves for a single right-hand side, writing the result into `x_out` and
/// using `ws` for all scratch storage.
///
/// When the factors carry a scaling, the right-hand side is multiplied
/// entry-wise by `factors.scaling` before the core solve and the solution is
/// multiplied by the same vector afterwards.
///
/// # Errors
///
/// Returns `FeralError::DimensionMismatch` when `rhs` or `x_out` (checked in
/// that order) does not have exactly `factors.n` entries.
fn solve_sparse_into_ws(
    factors: &SparseFactors,
    rhs: &[f64],
    x_out: &mut [f64],
    ws: &mut SolveWorkspace,
) -> Result<(), FeralError> {
    let dim = factors.n;
    // Validate the right-hand side first, then the output buffer.
    for got in [rhs.len(), x_out.len()] {
        if got != dim {
            return Err(FeralError::DimensionMismatch { expected: dim, got });
        }
    }
    if dim == 0 {
        return Ok(());
    }

    let scaled = !matches!(factors.scaling_info, ScalingInfo::NotApplied);
    let core_rhs: &[f64] = if scaled {
        let scale = &factors.scaling[..dim];
        for ((dst, &b), &s) in ws.scaled_rhs[..dim].iter_mut().zip(rhs).zip(scale) {
            *dst = b * s;
        }
        &ws.scaled_rhs
    } else {
        rhs
    };

    solve_sparse_core_into(factors, core_rhs, x_out, &mut ws.y, &mut ws.w);

    if scaled {
        // Undo the symmetric scaling on the solution side.
        for (xi, &s) in x_out.iter_mut().zip(&factors.scaling[..dim]) {
            *xi *= s;
        }
    }
    Ok(())
}
/// Core single-right-hand-side solve on already-scaled data.
///
/// Runs five phases: permute `rhs` into `y_buf`, a forward sweep over the
/// node factors, a block-diagonal solve (1x1 and 2x2 pivots), a backward
/// sweep over the node factors in reverse order, and finally the inverse
/// permutation into `x_out`.
///
/// No dimension checks are performed here. The slicing below panics if
/// `y_buf` has fewer than `factors.n` entries or `w_buf` has fewer entries
/// than the `nrow` of any node that eliminates at least one pivot.
fn solve_sparse_core_into(
factors: &SparseFactors,
rhs: &[f64],
x_out: &mut [f64],
y_buf: &mut [f64],
w_buf: &mut [f64],
) {
let n = factors.n;
let y = &mut y_buf[..n];
// Gather the right-hand side into the factorization ordering:
// y[new] = rhs[perm[new]].
for (new_idx, &old_idx) in factors.perm.iter().enumerate() {
y[new_idx] = rhs[old_idx];
}
// Phase 1: forward substitution, node by node in storage order.
for node in &factors.node_factors {
let ff = &node.frontal_factors;
let nelim = ff.nelim;
let nrow = ff.nrow;
// Nodes that eliminated nothing contribute nothing to the solve.
if nelim == 0 {
continue;
}
let w = &mut w_buf[..nrow];
// Gather this front's entries of `y` (through the front-local
// permutation `ff.perm` and the node's global `row_indices`).
for i in 0..nrow {
w[i] = y[node.row_indices[ff.perm[i]]];
}
// Column-oriented elimination; `ff.l` is indexed as `l[j * nrow + i]`
// and no division by a diagonal entry takes place in this sweep.
for j in 0..nelim {
let w_j = w[j];
for i in (j + 1)..nrow {
w[i] -= ff.l[j * nrow + i] * w_j;
}
}
// Scatter the updated values back to the global vector.
for i in 0..nrow {
y[node.row_indices[ff.perm[i]]] = w[i];
}
}
// Phase 2: block-diagonal solve on the eliminated entries of every node.
for node in &factors.node_factors {
let ff = &node.frontal_factors;
let nelim = ff.nelim;
let nrow = ff.nrow;
if nelim == 0 {
continue;
}
let w = &mut w_buf[..nrow];
for i in 0..nrow {
w[i] = y[node.row_indices[ff.perm[i]]];
}
let mut k = 0;
while k < nelim {
// A nonzero sub-diagonal entry marks a 2x2 pivot [[a, b], [b, c]]
// occupying positions k and k + 1.
if k + 1 < nelim && ff.d_subdiag[k] != 0.0 {
let a = ff.d_diag[k];
let b = ff.d_subdiag[k];
let c = ff.d_diag[k + 1];
let det = a * c - b * b;
// Blocks whose determinant is within tolerance of zero are skipped,
// leaving the two entries of `w` unchanged.
if det.abs() > ff.zero_tol_2x2 {
let z1 = w[k];
let z2 = w[k + 1];
if b.abs() > f64::EPSILON * (a.abs() + c.abs()).max(1.0) {
// Same 2x2 inverse as the `else` branch, but with every
// quantity divided by `b` first.
let ak = a / b;
let ck = c / b;
let denom = 1.0 / (ak * ck - 1.0);
let z1k = z1 / b;
let z2k = z2 / b;
w[k] = (ck * z1k - z2k) * denom;
w[k + 1] = (ak * z2k - z1k) * denom;
} else {
// `b` is negligible relative to the diagonal: apply the
// explicit 2x2 inverse directly.
w[k] = (c * z1 - b * z2) / det;
w[k + 1] = (a * z2 - b * z1) / det;
}
}
k += 2;
} else {
// 1x1 pivot; pivots within `zero_tol` of zero are skipped and the
// entry is left unchanged.
if ff.d_diag[k].abs() > ff.zero_tol {
w[k] /= ff.d_diag[k];
}
k += 1;
}
}
for i in 0..nrow {
y[node.row_indices[ff.perm[i]]] = w[i];
}
}
// Phase 3: backward substitution, visiting the nodes in reverse order.
for node in factors.node_factors.iter().rev() {
let ff = &node.frontal_factors;
let nelim = ff.nelim;
let nrow = ff.nrow;
if nelim == 0 {
continue;
}
let w = &mut w_buf[..nrow];
for i in 0..nrow {
w[i] = y[node.row_indices[ff.perm[i]]];
}
// Transposed sweep: w[j] -= sum over i > j of l[j * nrow + i] * w[i].
for j in (0..nelim).rev() {
let mut sum = 0.0;
for i in (j + 1)..nrow {
sum += ff.l[j * nrow + i] * w[i];
}
w[j] -= sum;
}
for i in 0..nrow {
y[node.row_indices[ff.perm[i]]] = w[i];
}
}
// Scatter back to the caller's ordering: x_out[perm[new]] = y[new].
for (new_idx, &old_idx) in factors.perm.iter().enumerate() {
x_out[old_idx] = y[new_idx];
}
}
/// Scratch buffers for solving several right-hand sides at once.
///
/// Built by `SolveManyWorkspace::for_factors` for a specific set of factors
/// and a specific number of right-hand sides.
pub struct SolveManyWorkspace {
/// Permuted working block; `for_factors` allocates `n * nrhs` entries.
y: Vec<f64>,
/// Per-front scratch; `for_factors` allocates `nrhs` times the largest
/// `nrow` among the node factors.
w: Vec<f64>,
/// Storage for the scaled right-hand sides; `for_factors` leaves it empty
/// when `scaling_info` is `ScalingInfo::NotApplied`.
scaled_rhs: Vec<f64>,
/// Number of right-hand sides the buffers were sized for.
nrhs: usize,
/// System dimension (`factors.n`) the buffers were sized for.
n: usize,
}
impl SolveManyWorkspace {
    /// Allocates zero-filled scratch buffers for solving `nrhs` right-hand
    /// sides against `factors`.
    ///
    /// `y` holds `factors.n * nrhs` entries, `w` holds `nrhs` copies of the
    /// tallest front, and `scaled_rhs` holds `factors.n * nrhs` entries
    /// unless no scaling was applied, in which case it stays empty.
    pub fn for_factors(factors: &SparseFactors, nrhs: usize) -> Self {
        let n = factors.n;
        let tallest_front = factors
            .node_factors
            .iter()
            .map(|node| node.frontal_factors.nrow)
            .fold(0, usize::max);
        let total = n * nrhs;
        let scaled_rhs = match factors.scaling_info {
            ScalingInfo::NotApplied => Vec::new(),
            _ => vec![0.0; total],
        };
        Self {
            y: vec![0.0; total],
            w: vec![0.0; tallest_front * nrhs],
            scaled_rhs,
            nrhs,
            n,
        }
    }
}
/// Solves the factorized system for `nrhs` right-hand sides and returns all
/// solutions in one newly allocated vector.
///
/// `rhs` stores the right-hand sides back to back (`factors.n` entries each);
/// the returned vector uses the same layout.
///
/// # Errors
///
/// Forwards any `FeralError::DimensionMismatch` reported by
/// `solve_sparse_many_into`.
pub fn solve_sparse_many(
    factors: &SparseFactors,
    rhs: &[f64],
    nrhs: usize,
) -> Result<Vec<f64>, FeralError> {
    // Without right-hand sides there is nothing to solve or validate.
    if nrhs == 0 {
        return Ok(Vec::new());
    }
    let mut workspace = SolveManyWorkspace::for_factors(factors, nrhs);
    let mut solutions = vec![0.0; factors.n * nrhs];
    solve_sparse_many_into(factors, rhs, nrhs, &mut solutions, &mut workspace)?;
    Ok(solutions)
}
/// Solves for `nrhs` right-hand sides, writing the solutions into `x_out`
/// and using `ws` for all scratch storage.
///
/// `rhs` and `x_out` store their columns back to back with stride
/// `factors.n`. When the factors carry a scaling, every right-hand side is
/// multiplied entry-wise by `factors.scaling` before the core solve and every
/// solution column is multiplied by the same vector afterwards.
///
/// # Errors
///
/// Returns `FeralError::DimensionMismatch` when
/// - `ws` was created for a different `n` or `nrhs`,
/// - `rhs` or `x_out` does not have exactly `factors.n * nrhs` entries, or
/// - a buffer inside `ws` is too small for `factors` (for example because the
///   workspace was built for other factors that happen to share `n` and
///   `nrhs` but have a different scaling state or smaller fronts).
pub fn solve_sparse_many_into(
    factors: &SparseFactors,
    rhs: &[f64],
    nrhs: usize,
    x_out: &mut [f64],
    ws: &mut SolveManyWorkspace,
) -> Result<(), FeralError> {
    let n = factors.n;
    if nrhs == 0 {
        return Ok(());
    }
    if ws.nrhs != nrhs || ws.n != n {
        return Err(FeralError::DimensionMismatch {
            expected: n * nrhs,
            got: ws.n * ws.nrhs,
        });
    }
    let total = n * nrhs;
    if rhs.len() != total {
        return Err(FeralError::DimensionMismatch {
            expected: total,
            got: rhs.len(),
        });
    }
    if x_out.len() != total {
        return Err(FeralError::DimensionMismatch {
            expected: total,
            got: x_out.len(),
        });
    }
    if n == 0 {
        return Ok(());
    }

    let needs_scaling = !matches!(factors.scaling_info, ScalingInfo::NotApplied);

    // `n` and `nrhs` alone do not prove that the workspace fits these
    // factors. Check the actual buffer capacities so that a mismatched
    // workspace yields an error instead of an out-of-bounds panic. Only
    // fronts that eliminate at least one pivot are touched by the core solve.
    let tallest_active_front = factors
        .node_factors
        .iter()
        .filter(|node| node.frontal_factors.nelim > 0)
        .map(|node| node.frontal_factors.nrow)
        .fold(0, usize::max);
    let required_scaled = if needs_scaling { total } else { 0 };
    let capacities = [
        (total, ws.y.len()),
        (tallest_active_front * nrhs, ws.w.len()),
        (required_scaled, ws.scaled_rhs.len()),
    ];
    for (required, available) in capacities {
        if available < required {
            return Err(FeralError::DimensionMismatch {
                expected: required,
                got: available,
            });
        }
    }

    let rhs_for_core: &[f64] = if needs_scaling {
        let scale = &factors.scaling[..n];
        let scaled_cols = ws.scaled_rhs[..total].chunks_exact_mut(n);
        for (dst_col, src_col) in scaled_cols.zip(rhs.chunks_exact(n)) {
            for ((dst, &b), &s) in dst_col.iter_mut().zip(src_col).zip(scale) {
                *dst = b * s;
            }
        }
        &ws.scaled_rhs
    } else {
        rhs
    };

    solve_sparse_core_many_into(factors, rhs_for_core, nrhs, x_out, &mut ws.y, &mut ws.w);

    if needs_scaling {
        // Undo the symmetric scaling on every solution column.
        let scale = &factors.scaling[..n];
        for x_col in x_out.chunks_exact_mut(n) {
            for (xi, &s) in x_col.iter_mut().zip(scale) {
                *xi *= s;
            }
        }
    }
    Ok(())
}
/// Core multi-right-hand-side solve on already-scaled data.
///
/// Mirrors the single-right-hand-side kernel: permute, forward sweep,
/// block-diagonal solve, backward sweep, inverse permutation. `rhs`,
/// `x_out` and `y_buf` store column `c` at offset `c * n`; within a front the
/// scratch `w_buf` stores column `c` at offset `c * nrow`.
///
/// No dimension checks are performed here. The slicing below panics if
/// `y_buf` has fewer than `n * nrhs` entries or `w_buf` has fewer than
/// `nrow * nrhs` entries for any node that eliminates at least one pivot.
fn solve_sparse_core_many_into(
factors: &SparseFactors,
rhs: &[f64],
nrhs: usize,
x_out: &mut [f64],
y_buf: &mut [f64],
w_buf: &mut [f64],
) {
let n = factors.n;
let y = &mut y_buf[..n * nrhs];
// Gather every column into the factorization ordering:
// y[c * n + new] = rhs[c * n + perm[new]].
for c in 0..nrhs {
let src_off = c * n;
let dst_off = c * n;
for (new_idx, &old_idx) in factors.perm.iter().enumerate() {
y[dst_off + new_idx] = rhs[src_off + old_idx];
}
}
// Phase 1: forward substitution, node by node in storage order.
for node in &factors.node_factors {
let ff = &node.frontal_factors;
let nelim = ff.nelim;
let nrow = ff.nrow;
// Nodes that eliminated nothing contribute nothing to the solve.
if nelim == 0 {
continue;
}
let w = &mut w_buf[..nrow * nrhs];
// Gather this front's entries of every column of `y`.
for c in 0..nrhs {
let w_col = &mut w[c * nrow..(c + 1) * nrow];
let y_col = &y[c * n..(c + 1) * n];
for i in 0..nrow {
w_col[i] = y_col[node.row_indices[ff.perm[i]]];
}
}
// Each factor entry `l[j * nrow + i]` is loaded once and applied to all
// columns before moving on.
for j in 0..nelim {
for i in (j + 1)..nrow {
let l_ij = ff.l[j * nrow + i];
for c in 0..nrhs {
let off = c * nrow;
w[off + i] -= l_ij * w[off + j];
}
}
}
// Scatter the updated values back to the global block.
for c in 0..nrhs {
let w_col = &w[c * nrow..(c + 1) * nrow];
let y_col = &mut y[c * n..(c + 1) * n];
for i in 0..nrow {
y_col[node.row_indices[ff.perm[i]]] = w_col[i];
}
}
}
// Phase 2: block-diagonal solve on the eliminated entries of every node.
for node in &factors.node_factors {
let ff = &node.frontal_factors;
let nelim = ff.nelim;
let nrow = ff.nrow;
if nelim == 0 {
continue;
}
let w = &mut w_buf[..nrow * nrhs];
for c in 0..nrhs {
let w_col = &mut w[c * nrow..(c + 1) * nrow];
let y_col = &y[c * n..(c + 1) * n];
for i in 0..nrow {
w_col[i] = y_col[node.row_indices[ff.perm[i]]];
}
}
// NOTE(review): the pivot quantities below do not depend on `c` and are
// recomputed for every column.
for c in 0..nrhs {
let w_col = &mut w[c * nrow..(c + 1) * nrow];
let mut k = 0;
while k < nelim {
// A nonzero sub-diagonal entry marks a 2x2 pivot [[a, b], [b, cc]]
// occupying positions k and k + 1.
if k + 1 < nelim && ff.d_subdiag[k] != 0.0 {
let a = ff.d_diag[k];
let b = ff.d_subdiag[k];
let cc = ff.d_diag[k + 1];
let det = a * cc - b * b;
// Blocks whose determinant is within tolerance of zero are skipped,
// leaving the two entries unchanged.
if det.abs() > ff.zero_tol_2x2 {
let z1 = w_col[k];
let z2 = w_col[k + 1];
if b.abs() > f64::EPSILON * (a.abs() + cc.abs()).max(1.0) {
// Same 2x2 inverse as the `else` branch, but with every
// quantity divided by `b` first.
let ak = a / b;
let ck = cc / b;
let denom = 1.0 / (ak * ck - 1.0);
let z1k = z1 / b;
let z2k = z2 / b;
w_col[k] = (ck * z1k - z2k) * denom;
w_col[k + 1] = (ak * z2k - z1k) * denom;
} else {
// `b` is negligible relative to the diagonal: apply the
// explicit 2x2 inverse directly.
w_col[k] = (cc * z1 - b * z2) / det;
w_col[k + 1] = (a * z2 - b * z1) / det;
}
}
k += 2;
} else {
// 1x1 pivot; pivots within `zero_tol` of zero are skipped and the
// entry is left unchanged.
if ff.d_diag[k].abs() > ff.zero_tol {
w_col[k] /= ff.d_diag[k];
}
k += 1;
}
}
}
for c in 0..nrhs {
let w_col = &w[c * nrow..(c + 1) * nrow];
let y_col = &mut y[c * n..(c + 1) * n];
for i in 0..nrow {
y_col[node.row_indices[ff.perm[i]]] = w_col[i];
}
}
}
// Phase 3: backward substitution, visiting the nodes in reverse order.
for node in factors.node_factors.iter().rev() {
let ff = &node.frontal_factors;
let nelim = ff.nelim;
let nrow = ff.nrow;
if nelim == 0 {
continue;
}
let w = &mut w_buf[..nrow * nrhs];
for c in 0..nrhs {
let w_col = &mut w[c * nrow..(c + 1) * nrow];
let y_col = &y[c * n..(c + 1) * n];
for i in 0..nrow {
w_col[i] = y_col[node.row_indices[ff.perm[i]]];
}
}
// Transposed sweep per column:
// w[j] -= sum over i > j of l[j * nrow + i] * w[i].
for j in (0..nelim).rev() {
for c in 0..nrhs {
let off = c * nrow;
let mut sum = 0.0;
for i in (j + 1)..nrow {
sum += ff.l[j * nrow + i] * w[off + i];
}
w[off + j] -= sum;
}
}
for c in 0..nrhs {
let w_col = &w[c * nrow..(c + 1) * nrow];
let y_col = &mut y[c * n..(c + 1) * n];
for i in 0..nrow {
y_col[node.row_indices[ff.perm[i]]] = w_col[i];
}
}
}
// Scatter every column back to the caller's ordering:
// x_out[c * n + perm[new]] = y[c * n + new].
for c in 0..nrhs {
let src_off = c * n;
let dst_off = c * n;
for (new_idx, &old_idx) in factors.perm.iter().enumerate() {
x_out[dst_off + old_idx] = y[src_off + new_idx];
}
}
}
/// Solves `matrix * x = rhs` with the given factors and improves the result
/// by iterative refinement.
///
/// After an initial solve, up to ten correction steps are taken. Each step
/// solves for the current residual `rhs - matrix * x` and adds the correction
/// to `x`. The iterate with the smallest residual 2-norm seen so far is kept
/// and returned. Refinement stops early when
/// - the best residual norm drops below `EPSILON * sqrt(n)` relative to
///   `‖rhs‖₂` (absolute when `rhs` is all zeros),
/// - the residual grows beyond 100x the best residual (divergence), or
/// - two consecutive steps fail to improve on the best residual.
///
/// # Errors
///
/// Returns `FeralError::DimensionMismatch` when `rhs.len()` or `matrix.n`
/// differs from `factors.n`.
pub fn solve_sparse_refined(
    matrix: &CscMatrix,
    factors: &SparseFactors,
    rhs: &[f64],
) -> Result<Vec<f64>, FeralError> {
    let n = factors.n;
    if rhs.len() != n {
        return Err(FeralError::DimensionMismatch {
            expected: n,
            got: rhs.len(),
        });
    }
    // The residual is computed with `matrix`, so it must match the factors;
    // otherwise `symv` would be handed vectors of the wrong length.
    if matrix.n != n {
        return Err(FeralError::DimensionMismatch {
            expected: n,
            got: matrix.n,
        });
    }

    let mut ws = SolveWorkspace::for_factors(factors);
    let mut x = vec![0.0; n];
    solve_sparse_into_ws(factors, rhs, &mut x, &mut ws)?;

    // r = rhs - A * x
    let mut r = vec![0.0; n];
    matrix.symv(&x, &mut r);
    for (ri, &bi) in r.iter_mut().zip(rhs) {
        *ri = bi - *ri;
    }
    let mut r_norm = norm2(&r);

    let mut best_x = x.clone();
    let mut best_r_norm = r_norm;
    let mut stagnant_count: usize = 0;
    let mut dx = vec![0.0; n];

    let max_steps = 10;
    let max_stagnant_steps = 2;
    let n_sqrt = (n as f64).sqrt();
    let threshold = f64::EPSILON * n_sqrt;
    let divergence_factor = 100.0;
    let b_norm = norm2(rhs);
    // Relative test when the right-hand side is nonzero, absolute otherwise.
    let relative_reached = |r_norm: f64| -> bool {
        if b_norm > 0.0 {
            r_norm < threshold * b_norm
        } else {
            r_norm < threshold
        }
    };

    for _step in 0..max_steps {
        if relative_reached(best_r_norm) {
            break;
        }
        // Correction step: solve A * dx = r, then x += dx.
        solve_sparse_into_ws(factors, &r, &mut dx, &mut ws)?;
        for (xi, &di) in x.iter_mut().zip(&dx) {
            *xi += di;
        }
        matrix.symv(&x, &mut r);
        for (ri, &bi) in r.iter_mut().zip(rhs) {
            *ri = bi - *ri;
        }
        r_norm = norm2(&r);

        if r_norm < best_r_norm {
            best_r_norm = r_norm;
            best_x.copy_from_slice(&x);
            stagnant_count = 0;
        } else {
            stagnant_count += 1;
        }
        // Give up when the iteration is clearly diverging or has stalled.
        if r_norm > best_r_norm * divergence_factor {
            break;
        }
        if stagnant_count >= max_stagnant_steps {
            break;
        }
    }
    Ok(best_x)
}
/// Euclidean (2-)norm of `v`: the square root of the plain sum of squares.
///
/// No rescaling is applied, so very large or very small entries can
/// overflow or underflow in the intermediate sum.
fn norm2(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|&entry| entry * entry).sum();
    sum_of_squares.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dense::factor::{BunchKaufmanParams, ZeroPivotAction};
    use crate::numeric::factorize::factorize_multifrontal;
    use crate::sparse::csc::CscMatrix;
    use crate::symbolic::{symbolic_factorize, SupernodeParams};

    /// Numeric parameters used by every test: default Bunch-Kaufman settings,
    /// except that zero pivots are force-accepted.
    fn make_params() -> crate::numeric::factorize::NumericParams {
        crate::numeric::factorize::NumericParams::with_bk(BunchKaufmanParams {
            on_zero_pivot: ZeroPivotAction::ForceAccept,
            ..BunchKaufmanParams::default()
        })
    }

    /// Factorizes `m`, solves `m * x = rhs` with `solve_sparse`, and asserts
    /// that the relative residual `‖m x − rhs‖₂ / ‖rhs‖₂` (absolute residual
    /// when `rhs` is zero) is below `tol`.
    fn check_solve(m: &CscMatrix, rhs: &[f64], tol: f64) {
        let sym = symbolic_factorize(m, &SupernodeParams::default()).unwrap();
        let params = make_params();
        let (factors, _) = factorize_multifrontal(m, &sym, &params).unwrap();
        let x = solve_sparse(&factors, rhs).unwrap();

        let n = m.n;
        let mut ax = vec![0.0; n];
        m.symv(&x, &mut ax);

        let mut res_sq = 0.0;
        let mut b_sq = 0.0;
        for i in 0..n {
            res_sq += (ax[i] - rhs[i]).powi(2);
            b_sq += rhs[i].powi(2);
        }
        let rel_res = if b_sq > 0.0 {
            (res_sq / b_sq).sqrt()
        } else {
            res_sq.sqrt()
        };
        assert!(
            rel_res < tol,
            "relative residual {:.2e} exceeds tolerance {:.2e}",
            rel_res,
            tol
        );
    }

    /// Purely diagonal 3x3 system.
    #[test]
    fn test_solve_diagonal() {
        let m = CscMatrix::from_triplets(3, &[0, 1, 2], &[0, 1, 2], &[2.0, 3.0, 5.0]).unwrap();
        check_solve(&m, &[4.0, 9.0, 25.0], 1e-14);
    }

    /// 3x3 tridiagonal system with 2 on the diagonal and -1 off the diagonal.
    #[test]
    fn test_solve_tridiagonal() {
        let m = CscMatrix::from_triplets(
            3,
            &[0, 1, 1, 2, 2],
            &[0, 0, 1, 1, 2],
            &[2.0, -1.0, 2.0, -1.0, 2.0],
        )
        .unwrap();
        check_solve(&m, &[1.0, 0.0, 1.0], 1e-13);
    }

    /// Small KKT-style system with a tiny negative entry in the last diagonal
    /// position; checked against a looser tolerance.
    #[test]
    fn test_solve_kkt() {
        let m = CscMatrix::from_triplets(
            3,
            &[0, 1, 2, 2, 2],
            &[0, 1, 0, 1, 2],
            &[2.0, 3.0, 1.0, 1.0, -1e-8],
        )
        .unwrap();
        check_solve(&m, &[1.0, 2.0, 3.0], 1e-6);
    }

    /// 5x5 tridiagonal system with 4 on the diagonal and -1 below it.
    #[test]
    fn test_solve_larger_spd() {
        let n = 5;
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for i in 0..n {
            rows.push(i);
            cols.push(i);
            vals.push(4.0);
            if i + 1 < n {
                rows.push(i + 1);
                cols.push(i);
                vals.push(-1.0);
            }
        }
        let m = CscMatrix::from_triplets(n, &rows, &cols, &vals).unwrap();
        check_solve(
            &m,
            &(0..n).map(|i| (i + 1) as f64).collect::<Vec<_>>(),
            1e-13,
        );
    }

    /// 2x2 system with unit diagonal and an off-diagonal entry of 2.
    #[test]
    fn test_solve_indefinite() {
        let m = CscMatrix::from_triplets(2, &[0, 1, 1], &[0, 0, 1], &[1.0, 2.0, 1.0]).unwrap();
        check_solve(&m, &[5.0, 4.0], 1e-13);
    }

    /// 5x5 arrow-shaped system: a dense first column plus a diagonal.
    #[test]
    fn test_solve_arrow_multi_supernode() {
        let m = CscMatrix::from_triplets(
            5,
            &[0, 1, 2, 3, 4, 1, 2, 3, 4],
            &[0, 0, 0, 0, 0, 1, 2, 3, 4],
            &[10.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        )
        .unwrap();
        check_solve(&m, &[1.0, 2.0, 3.0, 4.0, 5.0], 1e-12);
    }
}