use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
/// Configuration for iterative refinement of linear-system solutions.
#[derive(Debug, Clone)]
pub struct RefinementConfig<F> {
    /// Maximum number of refinement iterations to attempt.
    pub max_iterations: usize,
    /// Convergence tolerance on the componentwise backward error.
    pub tolerance: F,
    /// Stop early if the residual norm fails to decrease between iterations.
    pub stop_on_stagnation: bool,
}
impl<F: Float> Default for RefinementConfig<F> {
    /// Default: 10 iterations, tolerance `max(1e-14, F::epsilon())`,
    /// stagnation detection enabled.
    fn default() -> Self {
        // Floor the tolerance at machine epsilon. For narrow float types
        // (e.g. f32, epsilon ~= 1.19e-7) a fixed 1e-14 would be below the
        // achievable backward error, so refinement could never report
        // convergence. For f64 this is a no-op (1e-14 > ~2.2e-16).
        let tolerance = F::from(1e-14)
            .map(|t| t.max(F::epsilon()))
            .unwrap_or_else(F::epsilon);
        Self {
            max_iterations: 10,
            tolerance,
            stop_on_stagnation: true,
        }
    }
}
/// Outcome of an iterative-refinement run.
#[derive(Debug, Clone)]
pub struct IterativeRefinementResult<F> {
    /// Refined solution vector.
    pub solution: Array1<F>,
    /// Number of refinement iterations actually performed.
    pub iterations: usize,
    /// Infinity norm of the residual, recorded before refinement and at
    /// each iteration.
    pub residual_history: Vec<F>,
    /// Estimated relative forward error, `||r|| / (||A|| * ||x||)`.
    pub forward_error: F,
    /// Componentwise backward error, `||r|| / (||A|| * ||x|| + ||b||)`.
    pub backward_error: F,
    /// Whether the backward error dropped below the configured tolerance.
    pub converged: bool,
}
/// Solves the square system `A x = b` by LU factorization with partial
/// pivoting, then improves the solution with iterative refinement.
///
/// Each refinement step computes the residual `r = b - A x`, solves
/// `A dx = r` by reusing the cached LU factors, and updates `x += dx`.
/// Iteration stops when the componentwise backward error drops below
/// `config.tolerance`, when the residual stagnates (if
/// `config.stop_on_stagnation` is set), or after `config.max_iterations`
/// steps.
///
/// # Errors
///
/// Returns `LinalgError::DimensionError` if `a` is not square or `b` has
/// the wrong length; propagates `SingularMatrixError` from the
/// factorization or the triangular solves.
pub fn lu_iterative_refinement<F>(
    a: &ArrayView2<F>,
    b: &ArrayView1<F>,
    config: &RefinementConfig<F>,
) -> LinalgResult<IterativeRefinementResult<F>>
where
    F: Float + NumAssign + Sum + Debug + ScalarOperand + Send + Sync + 'static,
{
    let (m, n) = a.dim();
    if m != n {
        return Err(LinalgError::DimensionError(
            "LU refinement: matrix must be square".to_string(),
        ));
    }
    if b.len() != n {
        return Err(LinalgError::DimensionError(format!(
            "LU refinement: b length ({}) != matrix dimension ({n})",
            b.len()
        )));
    }
    // Factor once: P*A = L*U. The factors are reused for every correction.
    let (piv, l, u) = lu_factor_partial(a)?;
    // Initial solve: x = U \ (L \ (P*b)).
    let pb = apply_perm_vec(b, &piv);
    let y = forward_solve(&l, &pb.view())?;
    let mut x = back_solve(&u, &y.view())?;
    // Norms used to scale the backward-error estimate.
    let b_norm = inf_norm_vec(b);
    let a_norm = inf_norm_mat(a);
    let mut residual_history = Vec::with_capacity(config.max_iterations + 1);
    let mut converged = false;
    let mut iterations = 0;
    let mut prev_res_norm = F::infinity();
    let r0 = compute_residual(a, &x.view(), b);
    let r0_norm = inf_norm_arr(&r0);
    residual_history.push(r0_norm);
    // Fast path: the direct solve already meets the relative-residual
    // tolerance, so skip refinement entirely.
    if b_norm > F::epsilon() && r0_norm / b_norm < config.tolerance {
        let x_norm_init = inf_norm_arr(&x);
        let bw_err = r0_norm / (a_norm * x_norm_init + b_norm);
        return Ok(IterativeRefinementResult {
            solution: x,
            iterations: 0,
            residual_history,
            forward_error: F::zero(),
            backward_error: bw_err,
            converged: true,
        });
    }
    for _it in 0..config.max_iterations {
        iterations += 1;
        // NOTE(review): on the first pass this recomputes the residual for
        // the same x as `r0` above, so `residual_history` begins with a
        // duplicated entry — confirm this is intended.
        let r = compute_residual(a, &x.view(), b);
        let r_norm = inf_norm_arr(&r);
        residual_history.push(r_norm);
        let x_norm = inf_norm_arr(&x);
        // Componentwise backward error: ||r|| / (||A||*||x|| + ||b||).
        let backward_err = if a_norm * x_norm + b_norm > F::epsilon() {
            r_norm / (a_norm * x_norm + b_norm)
        } else {
            r_norm
        };
        if backward_err < config.tolerance {
            converged = true;
            break;
        }
        // Give up if the residual stopped decreasing (stagnation).
        if config.stop_on_stagnation && r_norm >= prev_res_norm {
            break;
        }
        prev_res_norm = r_norm;
        // Correction step: solve A*dx = r with the cached LU factors.
        let pr = apply_perm_vec(&r.view(), &piv);
        let y_corr = forward_solve(&l, &pr.view())?;
        let dx = back_solve(&u, &y_corr.view())?;
        for i in 0..n {
            x[i] += dx[i];
        }
    }
    // Final error estimates based on the residual of the returned x.
    let final_r = compute_residual(a, &x.view(), b);
    let final_r_norm = inf_norm_arr(&final_r);
    let x_norm = inf_norm_arr(&x);
    let forward_error = if x_norm > F::epsilon() {
        if residual_history.len() >= 2 {
            let last_res = residual_history[residual_history.len() - 1];
            last_res / (a_norm * x_norm)
        } else {
            final_r_norm / (a_norm * x_norm)
        }
    } else {
        final_r_norm
    };
    let backward_error = if a_norm * x_norm + b_norm > F::epsilon() {
        final_r_norm / (a_norm * x_norm + b_norm)
    } else {
        final_r_norm
    };
    Ok(IterativeRefinementResult {
        solution: x,
        iterations,
        residual_history,
        forward_error,
        backward_error,
        converged,
    })
}
/// Solves `A x = b` (square) or the least-squares problem
/// `min ||A x - b||` (overdetermined, m > n) via Householder QR, then
/// applies iterative refinement.
///
/// Each step computes the residual `r = b - A x`, solves
/// `R dx = (Q^T r)[..n]` with the cached factors, and updates `x += dx`.
/// For inconsistent overdetermined systems the residual cannot reach zero,
/// so termination is typically via stagnation or the iteration cap rather
/// than the backward-error tolerance.
///
/// # Errors
///
/// Returns `DimensionError` if `b` has the wrong length or `m < n`;
/// propagates `SingularMatrixError` from the triangular solve.
pub fn qr_iterative_refinement<F>(
    a: &ArrayView2<F>,
    b: &ArrayView1<F>,
    config: &RefinementConfig<F>,
) -> LinalgResult<IterativeRefinementResult<F>>
where
    F: Float + NumAssign + Sum + Debug + ScalarOperand + Send + Sync + 'static,
{
    let (m, n) = a.dim();
    if b.len() != m {
        return Err(LinalgError::DimensionError(format!(
            "QR refinement: b length ({}) != matrix rows ({m})",
            b.len()
        )));
    }
    if m < n {
        return Err(LinalgError::DimensionError(
            "QR refinement: requires m >= n (overdetermined or square)".to_string(),
        ));
    }
    // Factor once; the initial solution is x = R \ (Q^T b)[..n].
    let (q, r_mat) = householder_qr_internal(a)?;
    let qtb = q.t().dot(b);
    let qtb_n = qtb.slice(scirs2_core::ndarray::s![..n]).to_owned();
    let mut x = back_solve_rect(&r_mat, &qtb_n.view(), n)?;
    // Norms used to scale the backward-error estimate.
    let b_norm = inf_norm_vec(b);
    let a_norm = inf_norm_mat(a);
    let mut residual_history = Vec::with_capacity(config.max_iterations + 1);
    let mut converged = false;
    let mut iterations = 0;
    let mut prev_res_norm = F::infinity();
    let r0 = compute_residual(a, &x.view(), b);
    let r0_norm = inf_norm_arr(&r0);
    residual_history.push(r0_norm);
    // Fast path: the direct solve already meets the relative-residual
    // tolerance, so skip refinement entirely.
    if b_norm > F::epsilon() && r0_norm / b_norm < config.tolerance {
        let x_norm_init = inf_norm_arr(&x);
        let bw_err = r0_norm / (a_norm * x_norm_init + b_norm);
        return Ok(IterativeRefinementResult {
            solution: x,
            iterations: 0,
            residual_history,
            forward_error: F::zero(),
            backward_error: bw_err,
            converged: true,
        });
    }
    for _it in 0..config.max_iterations {
        iterations += 1;
        // NOTE(review): on the first pass this recomputes the residual for
        // the same x as `r0` above, so `residual_history` begins with a
        // duplicated entry — confirm this is intended.
        let r = compute_residual(a, &x.view(), b);
        let r_norm = inf_norm_arr(&r);
        residual_history.push(r_norm);
        let x_norm = inf_norm_arr(&x);
        // Componentwise backward error: ||r|| / (||A||*||x|| + ||b||).
        let backward_err = if a_norm * x_norm + b_norm > F::epsilon() {
            r_norm / (a_norm * x_norm + b_norm)
        } else {
            r_norm
        };
        if backward_err < config.tolerance {
            converged = true;
            break;
        }
        // Give up if the residual stopped decreasing (stagnation).
        if config.stop_on_stagnation && r_norm >= prev_res_norm {
            break;
        }
        prev_res_norm = r_norm;
        // Correction step: solve R*dx = (Q^T r)[..n] with the cached factors.
        let qt_r = q.t().dot(&r);
        let qt_r_n = qt_r.slice(scirs2_core::ndarray::s![..n]).to_owned();
        let dx = back_solve_rect(&r_mat, &qt_r_n.view(), n)?;
        for i in 0..n {
            x[i] += dx[i];
        }
    }
    // Final error estimates based on the residual of the returned x.
    let final_r = compute_residual(a, &x.view(), b);
    let final_r_norm = inf_norm_arr(&final_r);
    let x_norm = inf_norm_arr(&x);
    let forward_error = if x_norm > F::epsilon() {
        final_r_norm / (a_norm * x_norm)
    } else {
        final_r_norm
    };
    let backward_error = if a_norm * x_norm + b_norm > F::epsilon() {
        final_r_norm / (a_norm * x_norm + b_norm)
    } else {
        final_r_norm
    };
    Ok(IterativeRefinementResult {
        solution: x,
        iterations,
        residual_history,
        forward_error,
        backward_error,
        converged,
    })
}
/// Iterative refinement with a caller-supplied approximate solver.
///
/// `solver(rhs)` is invoked once for the initial solve (`rhs = b`) and once
/// per correction step (`rhs = r`, the current residual); the correction is
/// added to `x`. It is assumed to return an (approximate) solution of
/// `A z = rhs` — TODO confirm this contract with callers.
///
/// Stops when the componentwise backward error drops below
/// `config.tolerance`, when the residual stagnates (if enabled), or after
/// `config.max_iterations` steps.
///
/// # Errors
///
/// Returns `DimensionError` if `b.len()` does not match the row count of
/// `a`; propagates any error returned by `solver`.
pub fn generic_iterative_refinement<F, S>(
    a: &ArrayView2<F>,
    b: &ArrayView1<F>,
    mut solver: S,
    config: &RefinementConfig<F>,
) -> LinalgResult<IterativeRefinementResult<F>>
where
    F: Float + NumAssign + Sum + Debug + ScalarOperand + Send + Sync + 'static,
    S: FnMut(&Array1<F>) -> LinalgResult<Array1<F>>,
{
    let (m, _n) = a.dim();
    if b.len() != m {
        return Err(LinalgError::DimensionError(format!(
            "Generic refinement: b length ({}) != matrix rows ({m})",
            b.len()
        )));
    }
    // Initial solve: x ~ A^{-1} b via the supplied solver.
    let b_owned = b.to_owned();
    let mut x = solver(&b_owned)?;
    // Norms used to scale the backward-error estimate.
    let b_norm = inf_norm_vec(b);
    let a_norm = inf_norm_mat(a);
    let mut residual_history = Vec::with_capacity(config.max_iterations + 1);
    let mut converged = false;
    let mut iterations = 0;
    let mut prev_res_norm = F::infinity();
    let r0 = compute_residual(a, &x.view(), b);
    residual_history.push(inf_norm_arr(&r0));
    for _it in 0..config.max_iterations {
        iterations += 1;
        // NOTE(review): on the first pass this recomputes the residual for
        // the same x as `r0` above, so `residual_history` begins with a
        // duplicated entry — confirm this is intended.
        let r = compute_residual(a, &x.view(), b);
        let r_norm = inf_norm_arr(&r);
        residual_history.push(r_norm);
        let x_norm = inf_norm_arr(&x);
        // Componentwise backward error: ||r|| / (||A||*||x|| + ||b||).
        let backward_err = if a_norm * x_norm + b_norm > F::epsilon() {
            r_norm / (a_norm * x_norm + b_norm)
        } else {
            r_norm
        };
        if backward_err < config.tolerance {
            converged = true;
            break;
        }
        // Give up if the residual stopped decreasing (stagnation).
        if config.stop_on_stagnation && r_norm >= prev_res_norm {
            break;
        }
        prev_res_norm = r_norm;
        // Correction step: dx ~ A^{-1} r via the supplied solver.
        let dx = solver(&r)?;
        let n = x.len();
        for i in 0..n {
            x[i] += dx[i];
        }
    }
    // Final error estimates based on the residual of the returned x.
    let final_r = compute_residual(a, &x.view(), b);
    let final_r_norm = inf_norm_arr(&final_r);
    let x_norm = inf_norm_arr(&x);
    let forward_error = if x_norm > F::epsilon() {
        final_r_norm / (a_norm * x_norm)
    } else {
        final_r_norm
    };
    let backward_error = if a_norm * x_norm + b_norm > F::epsilon() {
        final_r_norm / (a_norm * x_norm + b_norm)
    } else {
        final_r_norm
    };
    Ok(IterativeRefinementResult {
        solution: x,
        iterations,
        residual_history,
        forward_error,
        backward_error,
        converged,
    })
}
/// Computes the residual vector `r = b - A*x`.
fn compute_residual<F>(a: &ArrayView2<F>, x: &ArrayView1<F>, b: &ArrayView1<F>) -> Array1<F>
where
    F: Float + NumAssign + Sum + ScalarOperand,
{
    // Form A*x once, then subtract it component-wise from b.
    let ax = a.dot(x);
    Array1::from_shape_fn(b.len(), |i| b[i] - ax[i])
}
/// Infinity norm (maximum absolute entry) of a vector view.
fn inf_norm_vec<F: Float>(v: &ArrayView1<F>) -> F {
    v.iter()
        .map(|&value| value.abs())
        .fold(F::zero(), |best, mag| best.max(mag))
}
/// Infinity norm (maximum absolute entry) of an owned vector.
fn inf_norm_arr<F: Float>(v: &Array1<F>) -> F {
    let mut largest = F::zero();
    for &value in v.iter() {
        largest = largest.max(value.abs());
    }
    largest
}
/// Infinity norm of a matrix: the maximum absolute row sum.
fn inf_norm_mat<F: Float + Sum>(a: &ArrayView2<F>) -> F {
    let (rows, cols) = a.dim();
    (0..rows)
        .map(|i| (0..cols).map(|j| a[[i, j]].abs()).sum::<F>())
        .fold(F::zero(), |best, row_sum| best.max(row_sum))
}
/// LU factorization with partial (row) pivoting: computes `P*A = L*U`.
///
/// Returns `(piv, l, u)` where `piv[i]` is the source row of `a` that ends
/// up in row `i` of the permuted matrix (see `apply_perm_vec`), `l` is unit
/// lower triangular, and `u` is upper triangular.
///
/// # Errors
///
/// Returns `SingularMatrixError` when no pivot larger than `F::epsilon()`
/// can be found in a column.
fn lu_factor_partial<F>(a: &ArrayView2<F>) -> LinalgResult<(Vec<usize>, Array2<F>, Array2<F>)>
where
    F: Float + NumAssign + Debug,
{
    let n = a.nrows();
    // In-place workspace: L (strictly below the diagonal) and U (on/above)
    // share this matrix until they are split out at the end.
    let mut lu = a.to_owned();
    let mut piv: Vec<usize> = (0..n).collect();
    for k in 0..n {
        // Find the largest-magnitude entry in column k on or below the diagonal.
        let mut max_val = lu[[k, k]].abs();
        let mut max_row = k;
        for i in (k + 1)..n {
            let v = lu[[i, k]].abs();
            if v > max_val {
                max_val = v;
                max_row = i;
            }
        }
        // NOTE(review): this is an absolute threshold; a well-conditioned
        // matrix with uniformly tiny entries would be rejected as singular —
        // confirm a scaled test isn't needed.
        if max_val <= F::epsilon() {
            return Err(LinalgError::SingularMatrixError(
                "LU factorization: matrix is singular or nearly singular".to_string(),
            ));
        }
        if max_row != k {
            piv.swap(k, max_row);
            // Swap rows k and max_row of the workspace.
            for j in 0..n {
                let tmp = lu[[k, j]];
                lu[[k, j]] = lu[[max_row, j]];
                lu[[max_row, j]] = tmp;
            }
        }
        // Eliminate column k below the pivot, storing the multipliers in place.
        for i in (k + 1)..n {
            lu[[i, k]] = lu[[i, k]] / lu[[k, k]];
            for j in (k + 1)..n {
                let lik = lu[[i, k]];
                let ukj = lu[[k, j]];
                lu[[i, j]] -= lik * ukj;
            }
        }
    }
    // Split the packed workspace into explicit L (unit diagonal) and U.
    let mut l = Array2::<F>::eye(n);
    let mut u = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            if j < i {
                l[[i, j]] = lu[[i, j]];
            } else {
                u[[i, j]] = lu[[i, j]];
            }
        }
    }
    Ok((piv, l, u))
}
/// Applies the row permutation `piv` to `b`: `result[i] = b[piv[i]]`.
fn apply_perm_vec<F: Float>(b: &ArrayView1<F>, piv: &[usize]) -> Array1<F> {
    let n = b.len();
    Array1::from_shape_fn(n, |i| b[piv[i]])
}
/// Forward substitution for the lower-triangular system `L y = b`.
///
/// `l` is treated as unit lower triangular (as produced by
/// `lu_factor_partial`, whose L starts from the identity), so no division
/// by the diagonal is performed.
fn forward_solve<F: Float + NumAssign>(
    l: &Array2<F>,
    b: &ArrayView1<F>,
) -> LinalgResult<Array1<F>> {
    let n = l.nrows();
    let mut y = Array1::<F>::zeros(n);
    for row in 0..n {
        // Subtract the contributions of the already-solved components.
        let reduced = (0..row).fold(b[row], |acc, col| acc - l[[row, col]] * y[col]);
        y[row] = reduced;
    }
    Ok(y)
}
fn back_solve<F: Float + NumAssign + Debug>(
u: &Array2<F>,
y: &ArrayView1<F>,
) -> LinalgResult<Array1<F>> {
let n = u.nrows();
let mut x = Array1::<F>::zeros(n);
for i in (0..n).rev() {
let mut s = y[i];
for j in (i + 1)..n {
s -= u[[i, j]] * x[j];
}
let diag = u[[i, i]];
if diag.abs() <= F::epsilon() {
return Err(LinalgError::SingularMatrixError(format!(
"Back substitution: zero diagonal at index {i}"
)));
}
x[i] = s / diag;
}
Ok(x)
}
/// Back substitution on the leading `k x k` upper-triangular block of `r`,
/// solving `R[..k, ..k] x = y[..k]`.
///
/// # Errors
///
/// Returns `SingularMatrixError` when a diagonal entry is numerically zero
/// (magnitude at most `F::epsilon()`).
fn back_solve_rect<F: Float + NumAssign + Debug>(
    r: &Array2<F>,
    y: &ArrayView1<F>,
    k: usize,
) -> LinalgResult<Array1<F>> {
    let mut x = Array1::<F>::zeros(k);
    for row in (0..k).rev() {
        // Refuse to divide by a numerically-zero pivot.
        let diag = r[[row, row]];
        if diag.abs() <= F::epsilon() {
            return Err(LinalgError::SingularMatrixError(format!(
                "Back substitution: zero diagonal at index {row}"
            )));
        }
        // Remove the contributions of the already-solved components.
        let rhs = ((row + 1)..k).fold(y[row], |acc, col| acc - r[[row, col]] * x[col]);
        x[row] = rhs / diag;
    }
    Ok(x)
}
/// Householder QR factorization: computes `A = Q * R` for an `m x n` matrix.
///
/// Returns the full `m x m` orthogonal factor `q` and the `m x n` factor
/// `r`, whose leading `min(m, n)` rows are upper triangular.
fn householder_qr_internal<F>(a: &ArrayView2<F>) -> LinalgResult<(Array2<F>, Array2<F>)>
where
    F: Float + NumAssign + Sum + Debug + ScalarOperand + 'static,
{
    let (m, n) = a.dim();
    let min_dim = m.min(n);
    let mut r = a.to_owned();
    let mut q = Array2::<F>::eye(m);
    let two = F::from(2.0).unwrap_or_else(|| F::one() + F::one());
    for k in 0..min_dim {
        // Copy the on/below-diagonal part of column k: the vector to annihilate.
        let mut x = Array1::<F>::zeros(m - k);
        for i in k..m {
            x[i - k] = r[[i, k]];
        }
        let x_norm = x.iter().fold(F::zero(), |acc, &v| acc + v * v).sqrt();
        // Column already (numerically) zero below the diagonal: nothing to do.
        if x_norm <= F::epsilon() {
            continue;
        }
        // Pick the sign opposite to x[0] so that v[0] = x[0] - alpha does not
        // suffer cancellation.
        let alpha = if x[0] >= F::zero() { -x_norm } else { x_norm };
        let mut v = x;
        v[0] -= alpha;
        let v_norm_sq = v.iter().fold(F::zero(), |acc, &val| acc + val * val);
        if v_norm_sq <= F::epsilon() {
            continue;
        }
        // Reflector H_k = I - beta * v * v^T, acting on rows/columns k..m.
        let beta = two / v_norm_sq;
        // Apply H_k to the trailing columns of R: r[k.., j] -= beta*v*(v^T r[k.., j]).
        for j in k..n {
            let mut dot = F::zero();
            for i in 0..(m - k) {
                dot += v[i] * r[[i + k, j]];
            }
            for i in 0..(m - k) {
                r[[i + k, j]] -= beta * v[i] * dot;
            }
        }
        // Accumulate Q <- Q * H_k by updating columns k..m of Q.
        for row in 0..m {
            let mut dot = F::zero();
            for jj in 0..(m - k) {
                dot += q[[row, jj + k]] * v[jj];
            }
            for jj in 0..(m - k) {
                q[[row, jj + k]] -= beta * dot * v[jj];
            }
        }
    }
    Ok((q, r))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Well-conditioned 2x2 SPD system: refinement must converge and recover
    // the exact solution to high precision.
    #[test]
    fn test_lu_refinement_well_conditioned() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let x_true = array![1.0, 2.0];
        let b = a.dot(&x_true);
        let config = RefinementConfig {
            max_iterations: 10,
            tolerance: 1e-14,
            stop_on_stagnation: true,
        };
        let result =
            lu_iterative_refinement(&a.view(), &b.view(), &config).expect("LU refinement failed");
        assert!(
            result.converged,
            "should converge for well-conditioned system"
        );
        for i in 0..2 {
            assert!(
                (result.solution[i] - x_true[i]).abs() < 1e-12,
                "solution[{i}] = {} != {}",
                result.solution[i],
                x_true[i]
            );
        }
    }

    // 3x3 Hilbert-like matrix (ill-conditioned): refinement should still
    // drive the residual to near machine precision.
    #[test]
    fn test_lu_refinement_ill_conditioned() {
        let a = array![
            [1.0, 0.5, 1.0 / 3.0],
            [0.5, 1.0 / 3.0, 0.25],
            [1.0 / 3.0, 0.25, 0.2]
        ];
        let x_true = array![1.0, 1.0, 1.0];
        let b = a.dot(&x_true);
        let config = RefinementConfig {
            max_iterations: 20,
            tolerance: 1e-12,
            stop_on_stagnation: true,
        };
        let result =
            lu_iterative_refinement(&a.view(), &b.view(), &config).expect("LU refinement failed");
        let residual = &a.dot(&result.solution) - &b;
        let res_norm: f64 = residual.iter().map(|&v| v.abs()).fold(0.0, f64::max);
        assert!(
            res_norm < 1e-10,
            "residual should be small after refinement, got {res_norm}"
        );
    }

    // The result bookkeeping (history, error estimates) must be populated
    // and finite even when stagnation detection is off.
    #[test]
    fn test_lu_refinement_convergence_tracking() {
        let a = array![[2.0, 1.0], [1.0, 3.0]];
        let b = array![3.0, 4.0];
        let config = RefinementConfig {
            max_iterations: 5,
            tolerance: 1e-14,
            stop_on_stagnation: false,
        };
        let result =
            lu_iterative_refinement(&a.view(), &b.view(), &config).expect("LU refinement failed");
        assert!(
            !result.residual_history.is_empty(),
            "should have residual history"
        );
        assert!(result.forward_error.is_finite());
        assert!(result.backward_error.is_finite());
    }

    // Rank-deficient matrix (row 2 = 2 * row 1): factorization must error.
    #[test]
    fn test_lu_refinement_singular_error() {
        let a = array![[1.0, 2.0], [2.0, 4.0]];
        let b = array![1.0, 2.0];
        let config = RefinementConfig::default();
        let result = lu_iterative_refinement(&a.view(), &b.view(), &config);
        assert!(result.is_err(), "should fail on singular matrix");
    }

    // b longer than the matrix dimension: must be rejected up front.
    #[test]
    fn test_lu_refinement_dimension_mismatch() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![1.0, 2.0, 3.0];
        let config = RefinementConfig::default();
        assert!(lu_iterative_refinement(&a.view(), &b.view(), &config).is_err());
    }

    // Square case of QR refinement: should match the exact solution.
    #[test]
    fn test_qr_refinement_square() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let x_true = array![1.0, 2.0];
        let b = a.dot(&x_true);
        let config = RefinementConfig {
            max_iterations: 10,
            tolerance: 1e-14,
            stop_on_stagnation: true,
        };
        let result =
            qr_iterative_refinement(&a.view(), &b.view(), &config).expect("QR refinement failed");
        for i in 0..2 {
            assert!(
                (result.solution[i] - x_true[i]).abs() < 1e-10,
                "QR solution[{i}] = {} != {}",
                result.solution[i],
                x_true[i]
            );
        }
    }

    // Consistent 3x2 system (b lies in the column space of a): the
    // least-squares residual should vanish.
    #[test]
    fn test_qr_refinement_overdetermined() {
        let a = array![[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]];
        let b = array![1.0, 2.0, 3.0];
        let config = RefinementConfig {
            max_iterations: 10,
            tolerance: 1e-12,
            stop_on_stagnation: true,
        };
        let result =
            qr_iterative_refinement(&a.view(), &b.view(), &config).expect("QR refinement failed");
        let res = &a.dot(&result.solution) - &b;
        let res_norm: f64 = res.iter().map(|&v| v * v).sum::<f64>().sqrt();
        assert!(res_norm < 1e-8, "overdetermined LS residual = {res_norm}");
    }

    // b length must equal the row count of a.
    #[test]
    fn test_qr_refinement_dimension_error() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![1.0, 2.0, 3.0];
        let config = RefinementConfig::default();
        assert!(qr_iterative_refinement(&a.view(), &b.view(), &config).is_err());
    }

    // Generic refinement with an LU-based closure solver should match the
    // dedicated LU path.
    #[test]
    fn test_generic_refinement() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let x_true = array![1.0, 2.0];
        let b = a.dot(&x_true);
        let a_clone = a.clone();
        // Solver closure: full LU solve of a_clone for any right-hand side.
        let solver = move |rhs: &Array1<f64>| -> LinalgResult<Array1<f64>> {
            let (piv, l, u) = lu_factor_partial(&a_clone.view())?;
            let pb = apply_perm_vec(&rhs.view(), &piv);
            let y = forward_solve(&l, &pb.view())?;
            back_solve(&u, &y.view())
        };
        let config = RefinementConfig {
            max_iterations: 5,
            tolerance: 1e-14,
            stop_on_stagnation: true,
        };
        let result = generic_iterative_refinement(&a.view(), &b.view(), solver, &config)
            .expect("generic refinement failed");
        for i in 0..2 {
            assert!(
                (result.solution[i] - x_true[i]).abs() < 1e-10,
                "generic solution[{i}] wrong"
            );
        }
    }

    // Defaults: 10 iterations, tight tolerance, stagnation detection on.
    #[test]
    fn test_refinement_config_default() {
        let config = RefinementConfig::<f64>::default();
        assert_eq!(config.max_iterations, 10);
        assert!(config.tolerance < 1e-10);
        assert!(config.stop_on_stagnation);
    }

    // Identity system: the solution is b itself and must converge immediately.
    #[test]
    fn test_lu_refinement_identity() {
        let a = Array2::<f64>::eye(3);
        let b = array![1.0, 2.0, 3.0];
        let config = RefinementConfig::default();
        let result =
            lu_iterative_refinement(&a.view(), &b.view(), &config).expect("identity solve failed");
        for i in 0..3 {
            assert!(
                (result.solution[i] - b[i]).abs() < 1e-14,
                "identity system wrong at {i}"
            );
        }
        assert!(result.converged);
    }

    // LU refinement only accepts square matrices.
    #[test]
    fn test_lu_refinement_non_square_error() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![1.0, 2.0];
        let config = RefinementConfig::default();
        assert!(lu_iterative_refinement(&a.view(), &b.view(), &config).is_err());
    }
}