use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
use std::iter::Sum;
use crate::decomposition::svd;
use crate::error::{LinalgError, LinalgResult};
use crate::solve::solve;
/// Outcome of an iterative solution-improvement routine.
#[derive(Debug, Clone)]
pub struct RefinementResult<F> {
    /// Final (refined) solution vector.
    pub solution: Array1<F>,
    /// Number of iterations actually performed.
    pub iterations: usize,
    /// Last relative forward-error estimate (||dx|| / ||x||).
    pub forward_error: F,
    /// Last normwise backward-error estimate.
    pub backward_error: F,
    /// Whether the requested tolerance was reached within the iteration limit.
    pub converged: bool,
}
/// Row/column scaling produced by matrix equilibration.
#[derive(Debug, Clone)]
pub struct EquilibrationResult<F> {
    /// Per-row scaling factors (applied first).
    pub row_scaling: Array1<F>,
    /// Per-column scaling factors (applied to the row-scaled matrix).
    pub col_scaling: Array1<F>,
    /// The scaled matrix diag(row_scaling) * A * diag(col_scaling).
    pub equilibrated: Array2<F>,
}
/// Condition-number estimate together with the norms it was built from.
#[derive(Debug, Clone)]
pub struct ConditionEstimate<F> {
    /// Estimated condition number ||A|| * ||A^-1||.
    pub condition_number: F,
    /// Norm of the matrix itself.
    pub norm_a: F,
    /// Estimated norm of the inverse.
    pub norm_a_inv: F,
    /// Heuristic confidence in the estimate (0..=1); see `estimate_condition`.
    pub confidence: F,
}
/// Normwise and componentwise backward-error measures for a computed solution.
#[derive(Debug, Clone)]
pub struct BackwardErrorResult<F> {
    /// Normwise backward error ||r|| / (||A|| ||x|| + ||b||).
    pub normwise: F,
    /// Componentwise backward error: max_i |r_i| / (|A||x| + |b|)_i.
    pub componentwise: F,
    /// Residual vector b - A*x.
    pub residual: Array1<F>,
    /// Infinity norm of the residual.
    pub residual_norm: F,
}
/// Solves `a * x = b` and improves the solution with classical iterative refinement.
///
/// Starting from a direct solve, repeatedly computes the residual `r = b - A*x`,
/// solves `A*dx = r`, and updates `x += dx` until the backward or forward error
/// drops below `tolerance` or `max_iter` iterations are exhausted.
///
/// # Errors
/// Returns `LinalgError::DimensionError` if `a` is not square or `b` has the
/// wrong length.
pub fn iterative_refinement<F>(
    a: &ArrayView2<F>,
    b: &ArrayView1<F>,
    max_iter: usize,
    tolerance: F,
) -> LinalgResult<RefinementResult<F>>
where
    F: Float
        + NumAssign
        + Sum
        + Debug
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
{
    let (m, n) = a.dim();
    if m != n {
        return Err(LinalgError::DimensionError(
            "Matrix must be square for iterative refinement".to_string(),
        ));
    }
    if b.len() != n {
        return Err(LinalgError::DimensionError(format!(
            "Right-hand side length ({}) does not match matrix dimension ({n})",
            b.len()
        )));
    }
    // Initial solution from the direct solver; refinement polishes it below.
    let mut x = solve(a, b, None)?;
    // Use the infinity norm of b so the backward-error denominator
    // ||A||_inf * ||x||_inf + ||b||_inf is norm-consistent
    // (previously ||b||_1 was mixed into an otherwise infinity-norm formula).
    let b_norm = vector_norm_inf(b);
    if b_norm < F::epsilon() {
        // b == 0: the direct solve already yields the exact (zero) solution.
        return Ok(RefinementResult {
            solution: x,
            iterations: 0,
            forward_error: F::zero(),
            backward_error: F::zero(),
            converged: true,
        });
    }
    let mut converged = false;
    let mut iterations = 0;
    let mut forward_error = F::infinity();
    let mut backward_error = F::infinity();
    for iter in 0..max_iter {
        iterations = iter + 1;
        // Residual r = b - A*x.
        let ax = a.dot(&x);
        let mut r = Array1::zeros(n);
        for i in 0..n {
            r[i] = b[i] - ax[i];
        }
        let r_norm = vector_norm_inf(&r.view());
        // Normwise backward error: ||r||_inf / (||A||_inf ||x||_inf + ||b||_inf).
        backward_error = r_norm / (matrix_norm_inf(a) * vector_norm_inf_arr(&x) + b_norm);
        if backward_error < tolerance {
            converged = true;
            break;
        }
        // Correction step: solve A*dx = r, then update x.
        let dx = solve(a, &r.view(), None)?;
        let dx_norm = vector_norm_inf_arr(&dx);
        let x_norm = vector_norm_inf_arr(&x);
        // Relative forward-error estimate ||dx|| / ||x|| (absolute when x ~ 0).
        forward_error = if x_norm > F::epsilon() {
            dx_norm / x_norm
        } else {
            dx_norm
        };
        for i in 0..n {
            x[i] += dx[i];
        }
        if forward_error < tolerance {
            converged = true;
            break;
        }
    }
    Ok(RefinementResult {
        solution: x,
        iterations,
        forward_error,
        backward_error,
        converged,
    })
}
/// Equilibrates a matrix by scaling each row, then each column, so that the
/// largest absolute entry in every row/column of the result is at most one.
///
/// Rows (and columns) whose entries are all numerically zero receive a
/// scaling factor of one.
pub fn equilibrate<F>(a: &ArrayView2<F>) -> LinalgResult<EquilibrationResult<F>>
where
    F: Float + NumAssign + Sum + Debug + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let (rows, cols) = a.dim();

    // Row factors: reciprocal of each row's max-magnitude entry; 1 for a zero row.
    let mut row_scaling = Array1::zeros(rows);
    for i in 0..rows {
        let largest = a.row(i).iter().fold(F::zero(), |m, &v| m.max(v.abs()));
        row_scaling[i] = if largest > F::epsilon() {
            F::one() / largest
        } else {
            F::one()
        };
    }

    // Phase one: scale the rows.
    let mut equilibrated = a.to_owned();
    for i in 0..rows {
        let factor = row_scaling[i];
        equilibrated.row_mut(i).mapv_inplace(|v| v * factor);
    }

    // Column factors are computed from the row-scaled matrix.
    let mut col_scaling = Array1::zeros(cols);
    for j in 0..cols {
        let largest = equilibrated
            .column(j)
            .iter()
            .fold(F::zero(), |m, &v| m.max(v.abs()));
        col_scaling[j] = if largest > F::epsilon() {
            F::one() / largest
        } else {
            F::one()
        };
    }

    // Phase two: scale the columns in place.
    for j in 0..cols {
        let factor = col_scaling[j];
        equilibrated.column_mut(j).mapv_inplace(|v| v * factor);
    }

    Ok(EquilibrationResult {
        row_scaling,
        col_scaling,
        equilibrated,
    })
}
/// Solves `a * x = b` by first equilibrating the matrix, solving the scaled
/// system, and then undoing the column scaling on the solution.
///
/// # Errors
/// Returns `LinalgError::DimensionError` if `a` is not square or `b` has the
/// wrong length.
pub fn equilibrated_solve<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> LinalgResult<Array1<F>>
where
    F: Float
        + NumAssign
        + Sum
        + Debug
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
{
    let n = a.nrows();
    if a.ncols() != n || b.len() != n {
        return Err(LinalgError::DimensionError(
            "Dimensions mismatch for equilibrated solve".to_string(),
        ));
    }
    let eq = equilibrate(a)?;
    // Scale the right-hand side by the row factors.
    let b_scaled: Array1<F> = eq
        .row_scaling
        .iter()
        .zip(b.iter())
        .map(|(&scale, &bi)| scale * bi)
        .collect();
    // Solve the (better-conditioned) scaled system.
    let y = solve(&eq.equilibrated.view(), &b_scaled.view(), None)?;
    // Undo the column scaling to recover the solution of the original system.
    let x: Array1<F> = eq
        .col_scaling
        .iter()
        .zip(y.iter())
        .map(|(&scale, &yi)| scale * yi)
        .collect();
    Ok(x)
}
/// Estimates the 1-norm condition number of a square matrix.
///
/// `||A||_1` is computed exactly; `||A^-1||_1` is estimated with a few steps of
/// Hager's algorithm, each requiring one solve with `A` and one with `A^T`
/// (no explicit inverse is formed).
///
/// # Errors
/// Returns `LinalgError::DimensionError` if the matrix is not square.
pub fn estimate_condition<F>(a: &ArrayView2<F>) -> LinalgResult<ConditionEstimate<F>>
where
    F: Float
        + NumAssign
        + Sum
        + Debug
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
{
    let (m, n) = a.dim();
    if m != n {
        return Err(LinalgError::DimensionError(
            "Matrix must be square for condition estimation".to_string(),
        ));
    }
    // Exact 1-norm of A (maximum absolute column sum).
    let norm_a = matrix_norm_1(a);
    if norm_a < F::epsilon() {
        // Numerically zero matrix: condition number is infinite by convention.
        return Ok(ConditionEstimate {
            condition_number: F::infinity(),
            norm_a,
            norm_a_inv: F::infinity(),
            confidence: F::one(),
        });
    }
    // Hager's estimator starts from the uniform vector x = (1/n, ..., 1/n).
    let n_f = F::from(n).unwrap_or(F::one());
    let mut x = Array1::from_elem(n, F::one() / n_f);
    let max_hager_iters = 5;
    let mut gamma = F::zero();
    for _iter in 0..max_hager_iters {
        // w = A^-1 x; gamma = ||w||_1 is the current estimate of ||A^-1||_1.
        let w = solve(a, &x.view(), None)?;
        gamma = vector_norm_1_arr(&w);
        // z = sign(w), with sign(0) taken as +1.
        let mut z = Array1::zeros(n);
        for i in 0..n {
            z[i] = if w[i] >= F::zero() {
                F::one()
            } else {
                -F::one()
            };
        }
        // v = A^-T z (solve with the transpose).
        let at = a.t().to_owned();
        let v = solve(&at.view(), &z.view(), None)?;
        let v_inf = vector_norm_inf_arr(&v);
        let zt_x: F = z
            .iter()
            .zip(x.iter())
            .fold(F::zero(), |acc, (&zi, &xi)| acc + zi * xi);
        // Convergence test: stop when ||v||_inf <= z^T x (no ascent direction left).
        if v_inf <= zt_x {
            break;
        }
        // Otherwise restart from the unit vector e_j with j = argmax |v_j|.
        let mut max_idx = 0;
        let mut max_val = F::zero();
        for i in 0..n {
            let abs_vi = v[i].abs();
            if abs_vi > max_val {
                max_val = abs_vi;
                max_idx = i;
            }
        }
        x = Array1::zeros(n);
        x[max_idx] = F::one();
    }
    // kappa_1(A) ~= ||A||_1 * estimated ||A^-1||_1.
    let norm_a_inv = gamma;
    let condition_number = norm_a * norm_a_inv;
    Ok(ConditionEstimate {
        condition_number,
        norm_a,
        norm_a_inv,
        // NOTE(review): fixed heuristic confidence value; it is not derived
        // from the iteration itself.
        confidence: F::from(0.9).unwrap_or(F::one()),
    })
}
/// Computes the 2-norm condition number of a matrix as the ratio of its
/// largest to smallest singular value, via a full SVD.
///
/// Returns infinity for an empty spectrum or a numerically singular matrix.
pub fn condition_number_svd<F>(a: &ArrayView2<F>) -> LinalgResult<F>
where
    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
    let (_, sigma, _) = svd(a, false, None)?;
    if sigma.is_empty() {
        return Ok(F::infinity());
    }
    // Singular values are assumed sorted in descending order by `svd`
    // (first = largest, last = smallest) — matches existing usage here.
    let largest = sigma[0];
    let smallest = sigma[sigma.len() - 1];
    let cond = if smallest < F::epsilon() {
        // A numerically zero smallest singular value means A is singular.
        F::infinity()
    } else {
        largest / smallest
    };
    Ok(cond)
}
/// Computes normwise and componentwise backward errors of an approximate
/// solution `x` for the system `a * x = b`.
///
/// The normwise measure is `||r||_inf / (||A||_inf ||x||_inf + ||b||_inf)`;
/// the componentwise measure is `max_i |r_i| / (|A||x| + |b|)_i`
/// (the Oettli-Prager formula), where `r = b - A*x`.
///
/// # Errors
/// Returns `LinalgError::DimensionError` when the dimensions of `a`, `b`,
/// and `x` disagree.
pub fn backward_error<F>(
    a: &ArrayView2<F>,
    b: &ArrayView1<F>,
    x: &ArrayView1<F>,
) -> LinalgResult<BackwardErrorResult<F>>
where
    F: Float + NumAssign + Sum + Debug + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let n = a.nrows();
    if a.ncols() != n || b.len() != n || x.len() != n {
        return Err(LinalgError::DimensionError(
            "Dimension mismatch in backward error analysis".to_string(),
        ));
    }
    // Residual r = b - A*x.
    let ax = a.dot(x);
    let mut residual = Array1::zeros(n);
    for i in 0..n {
        residual[i] = b[i] - ax[i];
    }
    let r_norm = vector_norm_inf_arr(&residual);
    let a_norm = matrix_norm_inf(a);
    let x_norm = vector_norm_inf(x);
    // Infinity norm of b, so the denominator is norm-consistent
    // (previously ||b||_1 was mixed with infinity norms of A, x, and r).
    let b_norm = vector_norm_inf(b);
    // Epsilon guards against a zero denominator when A, x, and b are all zero.
    let normwise = r_norm / (a_norm * x_norm + b_norm + F::epsilon());
    // Componentwise (Oettli-Prager) error: each residual entry is scaled by
    // the corresponding entry of |A||x| + |b|.
    let mut componentwise = F::zero();
    for i in 0..n {
        let mut denom = b[i].abs();
        for j in 0..n {
            denom += a[[i, j]].abs() * x[j].abs();
        }
        let omega_i = if denom > F::epsilon() {
            residual[i].abs() / denom
        } else {
            F::zero()
        };
        if omega_i > componentwise {
            componentwise = omega_i;
        }
    }
    Ok(BackwardErrorResult {
        normwise,
        componentwise,
        residual,
        residual_norm: r_norm,
    })
}
pub fn richardson_iteration<F>(
a: &ArrayView2<F>,
b: &ArrayView1<F>,
omega: Option<F>,
x0: Option<&ArrayView1<F>>,
max_iter: usize,
tolerance: F,
) -> LinalgResult<RefinementResult<F>>
where
F: Float
+ NumAssign
+ Sum
+ Debug
+ scirs2_core::ndarray::ScalarOperand
+ Send
+ Sync
+ 'static,
{
let (m, n) = a.dim();
if m != n {
return Err(LinalgError::DimensionError(
"Matrix must be square for Richardson iteration".to_string(),
));
}
if b.len() != n {
return Err(LinalgError::DimensionError(format!(
"Right-hand side length ({}) must match matrix dimension ({n})",
b.len()
)));
}
let omega_val = omega.unwrap_or_else(|| {
let rho = estimate_spectral_radius(a, 20);
if rho > F::epsilon() {
F::one() / rho
} else {
F::one()
}
});
let mut x = if let Some(x0_ref) = x0 {
x0_ref.to_owned()
} else {
Array1::zeros(n)
};
let b_norm = vector_norm_1(b);
let mut converged = false;
let mut iterations = 0;
let mut forward_error = F::infinity();
let mut backward_error = F::infinity();
for iter in 0..max_iter {
iterations = iter + 1;
let ax = a.dot(&x);
let mut r = Array1::zeros(n);
for i in 0..n {
r[i] = b[i] - ax[i];
}
let r_norm = vector_norm_inf_arr(&r);
backward_error = if b_norm > F::epsilon() {
r_norm / b_norm
} else {
r_norm
};
if backward_error < tolerance {
converged = true;
break;
}
let x_old_norm = vector_norm_inf_arr(&x);
for i in 0..n {
x[i] += omega_val * r[i];
}
let mut dx_norm = F::zero();
for i in 0..n {
let dx_i = omega_val * r[i];
if dx_i.abs() > dx_norm {
dx_norm = dx_i.abs();
}
}
forward_error = if x_old_norm > F::epsilon() {
dx_norm / x_old_norm
} else {
dx_norm
};
}
Ok(RefinementResult {
solution: x,
iterations,
forward_error,
backward_error,
converged,
})
}
/// Estimates the dominant eigenvalue magnitude of `a` by power iteration,
/// normalizing the iterate with the infinity norm at each step.
///
/// Returns zero if the iterate collapses to (numerical) zero. The result is
/// only a rough estimate after `max_iter` steps — adequate for choosing a
/// Richardson step size.
fn estimate_spectral_radius<F>(a: &ArrayView2<F>, max_iter: usize) -> F
where
    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let n = a.nrows();
    // Start from the uniform vector (1/n, ..., 1/n).
    let start = F::one() / F::from(n).unwrap_or(F::one());
    let mut iterate = Array1::from_elem(n, start);
    let mut estimate = F::one();
    for _ in 0..max_iter {
        let next = a.dot(&iterate);
        let scale = next.iter().fold(F::zero(), |m, &v| m.max(v.abs()));
        if scale < F::epsilon() {
            // A*v vanished; the dominant eigenvalue is numerically zero.
            return F::zero();
        }
        estimate = scale;
        iterate = next.mapv(|v| v / scale);
    }
    estimate
}
/// Sum of absolute values (L1 norm) of a vector view.
fn vector_norm_1<F: Float>(v: &ArrayView1<F>) -> F {
    v.iter().map(|&x| x.abs()).fold(F::zero(), |acc, t| acc + t)
}
/// Sum of absolute values (L1 norm) of an owned vector.
fn vector_norm_1_arr<F: Float>(v: &Array1<F>) -> F {
    v.iter().map(|&x| x.abs()).fold(F::zero(), |acc, t| acc + t)
}
/// Largest absolute value (infinity norm) of a vector view.
fn vector_norm_inf<F: Float>(v: &ArrayView1<F>) -> F {
    v.iter().map(|&x| x.abs()).fold(F::zero(), F::max)
}
/// Largest absolute value (infinity norm) of an owned vector.
fn vector_norm_inf_arr<F: Float>(v: &Array1<F>) -> F {
    v.iter().map(|&x| x.abs()).fold(F::zero(), F::max)
}
/// Maximum absolute column sum (induced 1-norm) of a matrix view.
fn matrix_norm_1<F: Float>(a: &ArrayView2<F>) -> F {
    (0..a.ncols())
        .map(|j| a.column(j).iter().fold(F::zero(), |s, &x| s + x.abs()))
        .fold(F::zero(), F::max)
}
/// Maximum absolute row sum (induced infinity norm) of a matrix view.
fn matrix_norm_inf<F: Float>(a: &ArrayView2<F>) -> F {
    (0..a.nrows())
        .map(|i| a.row(i).iter().fold(F::zero(), |s, &x| s + x.abs()))
        .fold(F::zero(), F::max)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_iterative_refinement_basic() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let b = array![5.0, 4.0];
        let result = iterative_refinement(&a.view(), &b.view(), 10, 1e-12);
        assert!(result.is_ok());
        let ref_result = result.expect("refinement failed");
        let ax = a.dot(&ref_result.solution);
        for i in 0..2 {
            assert!(
                (ax[i] - b[i]).abs() < 1e-10,
                "Solution inaccurate at index {i}"
            );
        }
        assert!(ref_result.converged);
    }

    #[test]
    fn test_iterative_refinement_identity() {
        let a = Array2::<f64>::eye(3);
        let b = array![1.0, 2.0, 3.0];
        let ref_result =
            iterative_refinement(&a.view(), &b.view(), 5, 1e-14).expect("refinement failed");
        for i in 0..3 {
            assert!(
                (ref_result.solution[i] - b[i]).abs() < 1e-12,
                "Identity system solution wrong"
            );
        }
    }

    #[test]
    fn test_iterative_refinement_dimension_errors() {
        // Non-square matrix is rejected.
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![1.0, 2.0];
        assert!(iterative_refinement(&a.view(), &b.view(), 5, 1e-10).is_err());
        // Right-hand side of the wrong length is rejected.
        let a2 = array![[1.0, 0.0], [0.0, 1.0]];
        let b2 = array![1.0, 2.0, 3.0];
        assert!(iterative_refinement(&a2.view(), &b2.view(), 5, 1e-10).is_err());
    }

    #[test]
    fn test_equilibrate_basic() {
        let a = array![[1000.0, 1.0], [1.0, 0.001]];
        let result = equilibrate(&a.view());
        assert!(result.is_ok());
        let eq = result.expect("equilibrate failed");
        assert_eq!(eq.equilibrated.nrows(), 2);
        assert_eq!(eq.equilibrated.ncols(), 2);
        assert_eq!(eq.row_scaling.len(), 2);
        assert_eq!(eq.col_scaling.len(), 2);
        for i in 0..2 {
            let row_max: f64 = eq
                .equilibrated
                .row(i)
                .iter()
                .map(|&x| x.abs())
                .fold(0.0, f64::max);
            assert!(
                row_max <= 1.0 + 1e-10,
                "Row {i} max should be <= 1, got {row_max}"
            );
        }
    }

    #[test]
    fn test_equilibrate_identity() {
        let a = Array2::<f64>::eye(3);
        let eq = equilibrate(&a.view()).expect("equilibrate failed");
        for i in 0..3 {
            assert!((eq.row_scaling[i] - 1.0).abs() < 1e-10);
            assert!((eq.col_scaling[i] - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_equilibrated_solve() {
        let a = array![[1000.0, 1.0], [1.0, 1000.0]];
        let x_true = array![1.0, 1.0];
        let b = a.dot(&x_true);
        let x = equilibrated_solve(&a.view(), &b.view());
        assert!(x.is_ok());
        let x_sol = x.expect("equilibrated solve failed");
        for i in 0..2 {
            assert!(
                (x_sol[i] - 1.0).abs() < 0.1,
                "Equilibrated solve inaccurate at index {i}: {}",
                x_sol[i]
            );
        }
    }

    #[test]
    fn test_equilibrated_solve_dimension_error() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![1.0, 2.0, 3.0];
        assert!(equilibrated_solve(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn test_estimate_condition_identity() {
        let a = Array2::<f64>::eye(3);
        let est = estimate_condition(&a.view()).expect("condition estimation failed");
        assert!(
            (est.condition_number - 1.0).abs() < 0.5,
            "Identity condition should be ~1, got {}",
            est.condition_number
        );
    }

    #[test]
    fn test_estimate_condition_ill_conditioned() {
        // 3x3 Hilbert matrix — a classic ill-conditioned example.
        let a = array![
            [1.0, 0.5, 1.0 / 3.0],
            [0.5, 1.0 / 3.0, 0.25],
            [1.0 / 3.0, 0.25, 0.2]
        ];
        let cond = condition_number_svd(&a.view()).expect("SVD condition failed");
        assert!(cond > 100.0, "Should detect ill-conditioning, got {cond}");
    }

    #[test]
    fn test_estimate_condition_non_square() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert!(estimate_condition(&a.view()).is_err());
    }

    #[test]
    fn test_condition_number_svd() {
        let a = array![[2.0, 0.0], [0.0, 1.0]];
        let cond = condition_number_svd(&a.view()).expect("SVD condition failed");
        assert!(
            (cond - 2.0).abs() < 0.1,
            "Condition should be ~2, got {cond}"
        );
    }

    #[test]
    fn test_condition_number_svd_singular() {
        // Rank-deficient matrix (second row is twice the first).
        let a = array![[1.0, 2.0], [2.0, 4.0]];
        let cond = condition_number_svd(&a.view()).expect("SVD condition failed");
        assert!(
            cond > 1e10 || cond.is_infinite(),
            "Singular matrix should have infinite condition"
        );
    }

    #[test]
    fn test_backward_error_exact_solution() {
        let a = array![[2.0, 1.0], [1.0, 3.0]];
        let x_true = array![1.0, 1.0];
        let b = a.dot(&x_true);
        let be =
            backward_error(&a.view(), &b.view(), &x_true.view()).expect("backward error failed");
        assert!(
            be.normwise < 1e-12,
            "Exact solution should have tiny backward error"
        );
        assert!(
            be.componentwise < 1e-12,
            "Componentwise error should be tiny"
        );
        assert!(be.residual_norm < 1e-12, "Residual should be tiny");
    }

    #[test]
    fn test_backward_error_approximate_solution() {
        let a = array![[2.0, 1.0], [1.0, 3.0]];
        let b = array![3.0, 4.0];
        let x_approx = array![1.1, 0.9];
        let be =
            backward_error(&a.view(), &b.view(), &x_approx.view()).expect("backward error failed");
        assert!(
            be.normwise > 0.0,
            "Approximate solution should have positive backward error"
        );
        assert!(be.residual_norm > 0.0, "Should have non-zero residual");
    }

    #[test]
    fn test_backward_error_dimension_mismatch() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![1.0, 2.0, 3.0];
        let x = array![1.0, 2.0];
        assert!(backward_error(&a.view(), &b.view(), &x.view()).is_err());
    }

    #[test]
    fn test_richardson_iteration_basic() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let b = array![5.0, 4.0];
        let result = richardson_iteration(&a.view(), &b.view(), None, None, 200, 1e-8);
        assert!(result.is_ok());
        let ref_result = result.expect("Richardson failed");
        let ax = a.dot(&ref_result.solution);
        for i in 0..2 {
            assert!(
                (ax[i] - b[i]).abs() < 0.1,
                "Richardson solution inaccurate at {i}: ax={}, b={}",
                ax[i],
                b[i]
            );
        }
    }

    #[test]
    fn test_richardson_with_omega() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let b = array![5.0, 4.0];
        let result = richardson_iteration(&a.view(), &b.view(), Some(0.25), None, 500, 1e-8);
        assert!(result.is_ok());
    }

    #[test]
    fn test_richardson_with_initial_guess() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let b = array![5.0, 4.0];
        let x0 = array![1.0, 1.0];
        let result =
            richardson_iteration(&a.view(), &b.view(), Some(0.2), Some(&x0.view()), 100, 1e-8);
        assert!(result.is_ok());
    }

    #[test]
    fn test_richardson_dimension_errors() {
        // Non-square matrix is rejected.
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![1.0, 2.0];
        assert!(richardson_iteration(&a.view(), &b.view(), None, None, 10, 1e-8).is_err());
        // Right-hand side of the wrong length is rejected.
        let a2 = array![[1.0, 0.0], [0.0, 1.0]];
        let b2 = array![1.0, 2.0, 3.0];
        assert!(richardson_iteration(&a2.view(), &b2.view(), None, None, 10, 1e-8).is_err());
    }

    #[test]
    fn test_norm_helpers() {
        let v = array![1.0, -2.0, 3.0];
        assert!((vector_norm_1(&v.view()) - 6.0).abs() < 1e-10);
        assert!((vector_norm_inf(&v.view()) - 3.0).abs() < 1e-10);
        let a = array![[1.0, -2.0], [3.0, 4.0]];
        assert!((matrix_norm_1(&a.view()) - 6.0).abs() < 1e-10);
        assert!((matrix_norm_inf(&a.view()) - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_estimate_spectral_radius() {
        let a = array![[2.0, 0.0], [0.0, 1.0]];
        let rho = estimate_spectral_radius(&a.view(), 30);
        assert!(
            (rho - 2.0).abs() < 0.1,
            "Spectral radius of diag(2,1) should be ~2, got {rho}"
        );
    }

    #[test]
    fn test_refinement_result_fields() {
        let a = Array2::<f64>::eye(2);
        let b = array![1.0, 2.0];
        let result =
            iterative_refinement(&a.view(), &b.view(), 5, 1e-10).expect("refinement failed");
        assert!(result.forward_error.is_finite() || result.converged);
        assert!(result.backward_error.is_finite() || result.converged);
        assert!(result.iterations <= 5);
    }

    #[test]
    fn test_iterative_refinement_zero_rhs() {
        let a = array![[1.0, 0.0], [0.0, 1.0]];
        let b = array![0.0, 0.0];
        let result =
            iterative_refinement(&a.view(), &b.view(), 5, 1e-10).expect("refinement failed");
        for i in 0..2 {
            assert!(
                result.solution[i].abs() < 1e-12,
                "Zero RHS should give zero solution"
            );
        }
        assert!(result.converged);
    }
}