use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::{Debug, Display};
#[derive(Debug, Clone)]
pub struct IterativeSolverOptions<A> {
pub max_iterations: usize,
pub tolerance: A,
pub verbose: bool,
pub restart: Option<usize>,
}
impl<A: Float> Default for IterativeSolverOptions<A> {
fn default() -> Self {
Self {
max_iterations: 1000,
tolerance: A::from(1e-10).expect("Operation failed"),
verbose: false,
restart: None,
}
}
}
#[derive(Debug, Clone)]
pub struct IterativeSolverResult<A> {
pub solution: Array1<A>,
pub iterations: usize,
pub residual_norm: A,
pub converged: bool,
}
#[allow(dead_code)]
pub fn conjugate_gradient<A>(
a: &ArrayView2<A>,
b: &ArrayView1<A>,
x0: Option<&ArrayView1<A>>,
options: &IterativeSolverOptions<A>,
) -> LinalgResult<IterativeSolverResult<A>>
where
A: Float + NumAssign + Debug + Display + scirs2_core::ndarray::ScalarOperand + 'static,
{
let n = a.shape()[0];
if a.shape()[1] != n {
return Err(LinalgError::ShapeError(
"Matrix A must be square".to_string(),
));
}
if b.len() != n {
return Err(LinalgError::ShapeError(
"Vector b must have same length as matrix dimension".to_string(),
));
}
let mut x = match x0 {
Some(x0_view) => {
if x0_view.len() != n {
return Err(LinalgError::ShapeError(
"Initial guess x0 must have same length as b".to_string(),
));
}
x0_view.to_owned()
}
None => Array1::zeros(n),
};
let mut r = b.to_owned() - a.dot(&x);
let mut r_norm_sq = r.dot(&r);
if r_norm_sq.sqrt() < options.tolerance {
return Ok(IterativeSolverResult {
solution: x,
iterations: 0,
residual_norm: r_norm_sq.sqrt(),
converged: true,
});
}
let mut p = r.clone();
let tolerance_sq = options.tolerance * options.tolerance;
for iteration in 0..options.max_iterations {
let ap = a.dot(&p);
let p_ap = p.dot(&ap);
if p_ap.abs() < A::epsilon() {
return Err(LinalgError::ConvergenceError(
"Conjugate gradient failed: p^T Ap is nearly zero".to_string(),
));
}
let alpha = r_norm_sq / p_ap;
x.scaled_add(alpha, &p);
r.scaled_add(-alpha, &ap);
let r_norm_sq_new = r.dot(&r);
if options.verbose && iteration % 10 == 0 {
println!(
"CG iteration {}: residual = {}",
iteration,
r_norm_sq_new.sqrt()
);
}
if r_norm_sq_new < tolerance_sq {
return Ok(IterativeSolverResult {
solution: x,
iterations: iteration + 1,
residual_norm: r_norm_sq_new.sqrt(),
converged: true,
});
}
let beta = r_norm_sq_new / r_norm_sq;
p = &r + &p * beta;
r_norm_sq = r_norm_sq_new;
}
Ok(IterativeSolverResult {
solution: x,
iterations: options.max_iterations,
residual_norm: r_norm_sq.sqrt(),
converged: false,
})
}
#[allow(dead_code)]
pub fn preconditioned_conjugate_gradient<A, F>(
a: &ArrayView2<A>,
b: &ArrayView1<A>,
mut preconditioner: F,
x0: Option<&ArrayView1<A>>,
options: &IterativeSolverOptions<A>,
) -> LinalgResult<IterativeSolverResult<A>>
where
A: Float + NumAssign + Debug + Display + scirs2_core::ndarray::ScalarOperand + 'static,
F: FnMut(&ArrayView1<A>) -> Array1<A>,
{
let n = a.shape()[0];
if a.shape()[1] != n {
return Err(LinalgError::ShapeError(
"Matrix A must be square".to_string(),
));
}
if b.len() != n {
return Err(LinalgError::ShapeError(
"Vector b must have same length as matrix dimension".to_string(),
));
}
let mut x = match x0 {
Some(x0_view) => {
if x0_view.len() != n {
return Err(LinalgError::ShapeError(
"Initial guess x0 must have same length as b".to_string(),
));
}
x0_view.to_owned()
}
None => Array1::zeros(n),
};
let mut r = b.to_owned() - a.dot(&x);
let mut z = preconditioner(&r.view());
let mut p = z.clone();
let mut rz_old = r.dot(&z);
let tolerance_sq = options.tolerance * options.tolerance;
for iteration in 0..options.max_iterations {
let ap = a.dot(&p);
let p_ap = p.dot(&ap);
if p_ap.abs() < A::epsilon() {
return Err(LinalgError::ConvergenceError(
"PCG failed: p^T Ap is nearly zero".to_string(),
));
}
let alpha = rz_old / p_ap;
x.scaled_add(alpha, &p);
r.scaled_add(-alpha, &ap);
let r_norm_sq = r.dot(&r);
if options.verbose && iteration % 10 == 0 {
println!(
"PCG iteration {}: residual = {}",
iteration,
r_norm_sq.sqrt()
);
}
if r_norm_sq < tolerance_sq {
return Ok(IterativeSolverResult {
solution: x,
iterations: iteration + 1,
residual_norm: r_norm_sq.sqrt(),
converged: true,
});
}
z = preconditioner(&r.view());
let rz_new = r.dot(&z);
let beta = rz_new / rz_old;
p = &z + &p * beta;
rz_old = rz_new;
}
let r_norm = r.dot(&r).sqrt();
Ok(IterativeSolverResult {
solution: x,
iterations: options.max_iterations,
residual_norm: r_norm,
converged: false,
})
}
#[allow(dead_code)]
pub fn gmres<A>(
a: &ArrayView2<A>,
b: &ArrayView1<A>,
x0: Option<&ArrayView1<A>>,
options: &IterativeSolverOptions<A>,
) -> LinalgResult<IterativeSolverResult<A>>
where
A: Float + NumAssign + Debug + Display + scirs2_core::ndarray::ScalarOperand + 'static,
{
let n = a.shape()[0];
if a.shape()[1] != n {
return Err(LinalgError::ShapeError(
"Matrix A must be square".to_string(),
));
}
if b.len() != n {
return Err(LinalgError::ShapeError(
"Vector b must have same length as matrix dimension".to_string(),
));
}
let mut x = match x0 {
Some(x0_view) => {
if x0_view.len() != n {
return Err(LinalgError::ShapeError(
"Initial guess x0 must have same length as b".to_string(),
));
}
x0_view.to_owned()
}
None => Array1::zeros(n),
};
let restart = options.restart.unwrap_or_else(|| n.min(30));
let mut total_iterations = 0;
for _outer in 0..(options.max_iterations / restart).max(1) {
let r = b.to_owned() - a.dot(&x);
let beta = r.dot(&r).sqrt();
if beta < options.tolerance {
return Ok(IterativeSolverResult {
solution: x,
iterations: total_iterations,
residual_norm: beta,
converged: true,
});
}
let mut v = vec![Array1::zeros(n); restart + 1];
v[0] = &r / beta;
let mut h = Array2::<A>::zeros((restart + 1, restart));
let mut j = 0;
while j < restart && total_iterations < options.max_iterations {
let w = a.dot(&v[j]);
let mut w_orth = w.clone();
for i in 0..=j {
h[[i, j]] = w.dot(&v[i]);
w_orth.scaled_add(-h[[i, j]], &v[i]);
}
h[[j + 1, j]] = w_orth.dot(&w_orth).sqrt();
if h[[j + 1, j]] < A::epsilon() {
j += 1;
break;
}
v[j + 1] = &w_orth / h[[j + 1, j]];
j += 1;
total_iterations += 1;
}
let y = solve_least_squares_gmres(&h.slice(scirs2_core::ndarray::s![..j + 1, ..j]), beta)?;
for (i, yi) in y.iter().enumerate() {
x.scaled_add(*yi, &v[i]);
}
let r_final = b.to_owned() - a.dot(&x);
let residual_norm = r_final.dot(&r_final).sqrt();
if options.verbose {
println!("GMRES outer iteration {_outer}: residual = {residual_norm}");
}
if residual_norm < options.tolerance {
return Ok(IterativeSolverResult {
solution: x,
iterations: total_iterations,
residual_norm,
converged: true,
});
}
}
let r_final = b.to_owned() - a.dot(&x);
let residual_norm = r_final.dot(&r_final).sqrt();
Ok(IterativeSolverResult {
solution: x,
iterations: total_iterations,
residual_norm,
converged: false,
})
}
#[allow(dead_code)]
pub fn bicgstab<A>(
a: &ArrayView2<A>,
b: &ArrayView1<A>,
x0: Option<&ArrayView1<A>>,
options: &IterativeSolverOptions<A>,
) -> LinalgResult<IterativeSolverResult<A>>
where
A: Float + NumAssign + Debug + Display + scirs2_core::ndarray::ScalarOperand + 'static,
{
let n = a.shape()[0];
if a.shape()[1] != n {
return Err(LinalgError::ShapeError(
"Matrix A must be square".to_string(),
));
}
if b.len() != n {
return Err(LinalgError::ShapeError(
"Vector b must have same length as matrix dimension".to_string(),
));
}
let mut x = match x0 {
Some(x0_view) => {
if x0_view.len() != n {
return Err(LinalgError::ShapeError(
"Initial guess x0 must have same length as b".to_string(),
));
}
x0_view.to_owned()
}
None => Array1::zeros(n),
};
let mut r = b.to_owned() - a.dot(&x);
let r_norm_init = r.dot(&r).sqrt();
if r_norm_init < options.tolerance {
return Ok(IterativeSolverResult {
solution: x,
iterations: 0,
residual_norm: r_norm_init,
converged: true,
});
}
let r_hat = r.clone();
let mut rho = A::one();
let mut alpha = A::one();
let mut omega = A::one();
let mut v = Array1::zeros(n);
let mut p = Array1::zeros(n);
for iteration in 0..options.max_iterations {
let rho_old = rho;
rho = r_hat.dot(&r);
if rho.abs() < A::epsilon() {
return Err(LinalgError::ConvergenceError(
"BiCGSTAB failed: rho is nearly zero".to_string(),
));
}
let beta = (rho / rho_old) * (alpha / omega);
p = &r + &(&p - &v * omega) * beta;
v = a.dot(&p);
alpha = rho / r_hat.dot(&v);
let s = &r - &v * alpha;
let s_norm = s.dot(&s).sqrt();
if s_norm < options.tolerance {
x.scaled_add(alpha, &p);
return Ok(IterativeSolverResult {
solution: x,
iterations: iteration + 1,
residual_norm: s_norm,
converged: true,
});
}
let t = a.dot(&s);
omega = t.dot(&s) / t.dot(&t);
x.scaled_add(alpha, &p);
x.scaled_add(omega, &s);
r = &s - &t * omega;
let r_norm = r.dot(&r).sqrt();
if options.verbose && iteration % 10 == 0 {
println!("BiCGSTAB iteration {iteration}: residual = {r_norm}");
}
if r_norm < options.tolerance {
return Ok(IterativeSolverResult {
solution: x,
iterations: iteration + 1,
residual_norm: r_norm,
converged: true,
});
}
if omega.abs() < A::epsilon() {
return Err(LinalgError::ConvergenceError(
"BiCGSTAB failed: omega is nearly zero".to_string(),
));
}
}
let r_final = b.to_owned() - a.dot(&x);
let residual_norm = r_final.dot(&r_final).sqrt();
Ok(IterativeSolverResult {
solution: x,
iterations: options.max_iterations,
residual_norm,
converged: false,
})
}
#[allow(dead_code)]
fn solve_least_squares_gmres<A>(h: &ArrayView2<A>, beta: A) -> LinalgResult<Array1<A>>
where
A: Float + NumAssign + Debug + Display + scirs2_core::ndarray::ScalarOperand + 'static,
{
let (m, n) = h.dim();
if m <= n {
return Err(LinalgError::ShapeError(
"Hessenberg matrix must have more rows than columns".to_string(),
));
}
let mut h_copy = h.to_owned();
let mut g = Array1::zeros(m);
g[0] = beta;
for j in 0..n {
for i in (j + 1)..m {
if h_copy[[i, j]].abs() > A::epsilon() {
let a = h_copy[[j, j]];
let b = h_copy[[i, j]];
let r = (a * a + b * b).sqrt();
let c = a / r;
let s = b / r;
for k in j..n {
let hjk = h_copy[[j, k]];
let hik = h_copy[[i, k]];
h_copy[[j, k]] = c * hjk + s * hik;
h_copy[[i, k]] = -s * hjk + c * hik;
}
let gj = g[j];
let gi = g[i];
g[j] = c * gj + s * gi;
g[i] = -s * gj + c * gi;
}
}
}
let mut y = Array1::zeros(n);
for i in (0..n).rev() {
let mut sum = g[i];
for j in (i + 1)..n {
sum -= h_copy[[i, j]] * y[j];
}
if h_copy[[i, i]].abs() < A::epsilon() {
return Err(LinalgError::SingularMatrixError(
"Matrix is singular in GMRES least squares solve".to_string(),
));
}
y[i] = sum / h_copy[[i, i]];
}
Ok(y)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::array;
#[test]
fn test_conjugate_gradient() {
let a = array![[4.0, 1.0], [1.0, 3.0]];
let b = array![1.0, 2.0];
let options = IterativeSolverOptions::default();
let result =
conjugate_gradient(&a.view(), &b.view(), None, &options).expect("Operation failed");
assert!(result.converged);
assert!(result.iterations < 10);
let residual = &b - a.dot(&result.solution);
assert!(residual.dot(&residual).sqrt() < 1e-10);
}
#[test]
fn test_gmres() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let b = array![5.0, 11.0];
let options = IterativeSolverOptions::default();
let result = gmres(&a.view(), &b.view(), None, &options).expect("Operation failed");
assert!(result.converged);
assert_abs_diff_eq!(result.solution[0], 1.0, epsilon = 1e-10);
assert_abs_diff_eq!(result.solution[1], 2.0, epsilon = 1e-10);
}
#[test]
fn test_bicgstab() {
let a = array![[4.0, 1.0], [2.0, 3.0]];
let b = array![1.0, 2.0];
let options = IterativeSolverOptions::default();
let result = bicgstab(&a.view(), &b.view(), None, &options).expect("Operation failed");
assert!(result.converged);
let residual = &b - a.dot(&result.solution);
assert!(residual.dot(&residual).sqrt() < 1e-10);
}
#[test]
fn test_preconditioned_cg() {
let a = array![[4.0, 1.0], [1.0, 3.0]];
let b = array![1.0, 2.0];
let preconditioner =
|r: &ArrayView1<f64>| -> Array1<f64> { array![r[0] / 4.0, r[1] / 3.0] };
let options = IterativeSolverOptions::default();
let result =
preconditioned_conjugate_gradient(&a.view(), &b.view(), preconditioner, None, &options)
.expect("Operation failed");
assert!(result.converged);
assert!(result.iterations <= 10);
let residual = &b - a.dot(&result.solution);
assert!(residual.dot(&residual).sqrt() < 1e-10);
}
}