use crate::dense::factor::Factors;
use crate::dense::matrix::SymmetricMatrix;
use crate::error::FeralError;
/// Solves a linear system for `rhs` using previously computed `factors`.
///
/// The pipeline is:
/// 1. scale `rhs` entry-wise by `factors.d_eq` and gather it through
///    `factors.perm`,
/// 2. run [`forward_substitute`], [`d_block_solve`] and
///    [`backward_substitute`] in place on the gathered vector,
/// 3. scatter the result back through `factors.perm` and scale it
///    entry-wise by `factors.d_eq` again.
///
/// # Errors
///
/// Returns [`FeralError::DimensionMismatch`] when `rhs.len()` differs from
/// `factors.n`.
pub fn solve(factors: &Factors, rhs: &[f64]) -> Result<Vec<f64>, FeralError> {
    let dim = factors.n;
    if rhs.len() != dim {
        return Err(FeralError::DimensionMismatch {
            expected: dim,
            got: rhs.len(),
        });
    }

    // Scale and permute in a single pass: work[slot] = d_eq[p] * rhs[p]
    // with p = perm[slot].
    let mut work: Vec<f64> = (0..dim)
        .map(|slot| {
            let src = factors.perm[slot];
            factors.d_eq[src] * rhs[src]
        })
        .collect();

    // The three triangular / block-diagonal stages all operate in place.
    forward_substitute(factors, &mut work);
    d_block_solve(factors, &mut work);
    backward_substitute(factors, &mut work);

    // Undo the permutation, then apply the scaling a second time.
    let mut x = vec![0.0; dim];
    for (slot, &value) in work.iter().enumerate() {
        x[factors.perm[slot]] = value;
    }
    x.iter_mut()
        .zip(factors.d_eq.iter())
        .for_each(|(entry, &scale)| *entry *= scale);

    Ok(x)
}
/// Solves for `rhs` with [`solve`] and then improves the result by
/// iterative refinement against `matrix`.
///
/// Each refinement step solves for a correction using the current residual
/// `rhs - matrix * x` (the product is obtained from `matrix.symv`), adds it
/// to the iterate and re-evaluates the residual 2-norm. The iterate with the
/// smallest residual norm seen so far is remembered and is what gets
/// returned.
///
/// Refinement stops when any of the following holds:
/// * 10 steps have been taken,
/// * the best residual norm is below
///   `f64::EPSILON * sqrt(n) * max(norm2(rhs), 1.0)`,
/// * the latest residual norm exceeds 100 times the best one (divergence).
///
/// # Errors
///
/// Returns [`FeralError::DimensionMismatch`] when `rhs.len()` differs from
/// `factors.n`, and forwards any error produced by [`solve`].
pub fn solve_refined(
    matrix: &SymmetricMatrix,
    factors: &Factors,
    rhs: &[f64],
) -> Result<Vec<f64>, FeralError> {
    const MAX_STEPS: usize = 10;
    const DIVERGENCE_FACTOR: f64 = 100.0;

    let dim = factors.n;
    if rhs.len() != dim {
        return Err(FeralError::DimensionMismatch {
            expected: dim,
            got: rhs.len(),
        });
    }

    // Residual rhs - matrix * point. The product buffer is freshly zeroed
    // for every call; the argument is taken as `&Vec<f64>` so `symv` is
    // handed exactly the same argument types as before.
    let residual_at = |point: &Vec<f64>| -> Vec<f64> {
        let mut product = vec![0.0; dim];
        matrix.symv(point, &mut product);
        rhs.iter()
            .zip(product.iter())
            .map(|(&b, &ax)| b - ax)
            .collect()
    };

    let mut x = solve(factors, rhs)?;
    let mut residual = residual_at(&x);

    let mut best_x = x.clone();
    let mut best_norm = norm2(&residual);

    // Convergence target, relative to the right-hand side (floored at 1.0).
    let target = f64::EPSILON * (dim as f64).sqrt() * norm2(rhs).max(1.0);

    for _ in 0..MAX_STEPS {
        if best_norm < target {
            break;
        }

        // Correction step: x <- x + solve(residual).
        let correction = solve(factors, &residual)?;
        for (entry, &delta) in x.iter_mut().zip(correction.iter()) {
            *entry += delta;
        }

        residual = residual_at(&x);
        let current_norm = norm2(&residual);

        if current_norm < best_norm {
            best_norm = current_norm;
            best_x.clone_from(&x);
        }

        // Give up once the residual has grown far beyond the best one seen.
        if current_norm > best_norm * DIVERGENCE_FACTOR {
            break;
        }
    }

    Ok(best_x)
}
/// In-place forward substitution with the strictly-lower part of
/// `factors.l`.
///
/// For every column `j` (ascending) and every row `i > j`, subtracts
/// `l[j * n + i] * z[j]` from `z[i]`. Diagonal and upper entries of `l` are
/// never read, so the diagonal is implicitly treated as one.
///
/// Assumes `factors.l` holds a full `n * n` entries with stride `n`.
fn forward_substitute(factors: &Factors, z: &mut [f64]) {
    let dim = factors.n;
    for col in 0..dim {
        let base = col * dim;
        // Entries of this column that lie strictly below the diagonal.
        let below_diag = &factors.l[base + col + 1..base + dim];
        // Split so the pivot can be read while the trailing part is updated.
        let (head, tail) = z.split_at_mut(col + 1);
        let pivot = head[col];
        for (target, &multiplier) in tail.iter_mut().zip(below_diag.iter()) {
            *target -= multiplier * pivot;
        }
    }
}
/// In-place backward substitution with the transpose of the strictly-lower
/// part of `factors.l`.
///
/// For every column `j` (descending), subtracts the sum over `i > j` of
/// `l[j * n + i] * v[i]` from `v[j]`. Diagonal and upper entries of `l` are
/// never read, so the diagonal is implicitly treated as one.
///
/// Assumes `factors.l` holds a full `n * n` entries with stride `n`.
fn backward_substitute(factors: &Factors, v: &mut [f64]) {
    let dim = factors.n;
    for col in (0..dim).rev() {
        let base = col * dim;
        // Entries of this column that lie strictly below the diagonal.
        let below_diag = &factors.l[base + col + 1..base + dim];
        // Split so the already-solved tail can be read while v[col] is written.
        let (head, tail) = v.split_at_mut(col + 1);
        // Explicit fold from 0.0 keeps the accumulation order of a plain loop.
        let dot = below_diag
            .iter()
            .zip(tail.iter())
            .fold(0.0, |acc, (&multiplier, &solved)| acc + multiplier * solved);
        head[col] -= dot;
    }
}
/// In-place solve with the block-diagonal factor described by
/// `factors.d_diag` and `factors.d_subdiag`.
///
/// Walks the diagonal from the top. At position `k`:
/// * if a next position exists and `d_subdiag[k]` is non-zero, positions `k`
///   and `k + 1` form a symmetric 2x2 block
///   `[[d_diag[k], d_subdiag[k]], [d_subdiag[k], d_diag[k + 1]]]`. When the
///   absolute determinant exceeds `factors.zero_tol_2x2` the two entries of
///   `w` are replaced by the block solve; otherwise they are left untouched.
///   Either way the walk advances by two.
/// * otherwise `w[k]` is divided by `d_diag[k]` when its absolute value
///   exceeds `factors.zero_tol` (and left untouched if not), and the walk
///   advances by one.
fn d_block_solve(factors: &Factors, w: &mut [f64]) {
    let dim = factors.n;
    let mut pos = 0;
    while pos < dim {
        // The sub-diagonal is only consulted when a following row exists.
        let is_pair = pos + 1 < dim && factors.d_subdiag[pos] != 0.0;

        if !is_pair {
            let pivot = factors.d_diag[pos];
            if pivot.abs() > factors.zero_tol {
                w[pos] /= pivot;
            }
            pos += 1;
            continue;
        }

        let d11 = factors.d_diag[pos];
        let d21 = factors.d_subdiag[pos];
        let d22 = factors.d_diag[pos + 1];

        if (d11 * d22 - d21 * d21).abs() > factors.zero_tol_2x2 {
            // Work with the block divided through by its off-diagonal entry.
            let inv_off = 1.0 / d21;
            let s11 = d11 * inv_off;
            let s22 = d22 * inv_off;
            let scale = 1.0 / (s11 * s22 - 1.0);
            let r0 = w[pos] * inv_off;
            let r1 = w[pos + 1] * inv_off;
            w[pos] = (s22 * r0 - r1) * scale;
            w[pos + 1] = (s11 * r1 - r0) * scale;
        }
        pos += 2;
    }
}
/// Euclidean (2-) norm of `v`: the square root of the sum of squared
/// entries, accumulated in slice order without any overflow scaling.
fn norm2(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}