use crate::error::{SparseError, SparseResult};
use crate::linalg::interface::LinearOperator;
use scirs2_core::numeric::{Float, NumAssign, SparseElement};
use std::fmt::Debug;
use std::iter::Sum;
pub struct LGMRESOptions<F> {
pub max_iter: usize,
pub rtol: F,
pub atol: F,
pub inner_m: usize,
pub outer_k: usize,
pub x0: Option<Vec<F>>,
pub preconditioner: Option<Box<dyn LinearOperator<F>>>,
}
impl<F: Float> Default for LGMRESOptions<F> {
fn default() -> Self {
Self {
max_iter: 500,
rtol: F::from(1e-8).expect("Failed to convert constant to float"),
atol: F::from(1e-8).expect("Failed to convert constant to float"),
inner_m: 30,
outer_k: 3,
x0: None,
preconditioner: None,
}
}
}
#[derive(Debug, Clone)]
pub struct LGMRESResult<F> {
pub x: Vec<F>,
pub iterations: usize,
pub residual_norm: F,
pub converged: bool,
}
#[allow(dead_code)]
pub fn lgmres<F>(
a: &dyn LinearOperator<F>,
b: &[F],
options: LGMRESOptions<F>,
) -> SparseResult<LGMRESResult<F>>
where
F: Float + SparseElement + NumAssign + Sum + Debug + 'static,
{
let (m, n) = a.shape();
if m != n {
return Err(SparseError::DimensionMismatch {
expected: m,
found: n,
});
}
if b.len() != n {
return Err(SparseError::DimensionMismatch {
expected: n,
found: b.len(),
});
}
let mut x = options
.x0
.clone()
.unwrap_or_else(|| vec![F::sparse_zero(); n]);
let b_norm = b.iter().map(|&bi| bi * bi).sum::<F>().sqrt();
if b_norm < options.atol {
return Ok(LGMRESResult {
x,
iterations: 0,
residual_norm: F::sparse_zero(),
converged: true,
});
}
let mut r = if let Some(ref m) = options.preconditioner {
let ax = a.matvec(&x)?;
let residual: Vec<F> = b
.iter()
.zip(ax.iter())
.map(|(&bi, &axi)| bi - axi)
.collect();
m.matvec(&residual)?
} else {
let ax = a.matvec(&x)?;
b.iter()
.zip(ax.iter())
.map(|(&bi, &axi)| bi - axi)
.collect()
};
let mut r_norm = r.iter().map(|&ri| ri * ri).sum::<F>().sqrt();
if r_norm < options.atol || r_norm / b_norm < options.rtol {
return Ok(LGMRESResult {
x,
iterations: 0,
residual_norm: r_norm,
converged: true,
});
}
let mut augmented_vectors: Vec<Vec<F>> = Vec::new();
let mut total_iter = 0;
for _ in 0..options.max_iter {
let (y, new_r_norm, v_list) = inner_gmres(
a,
&r,
options.inner_m,
&augmented_vectors,
options.preconditioner.as_deref(),
)?;
for (xi, &yi) in x.iter_mut().zip(y.iter()) {
*xi += yi;
}
total_iter += 1;
if let Some(ref m) = options.preconditioner {
let ax = a.matvec(&x)?;
let residual: Vec<F> = b
.iter()
.zip(ax.iter())
.map(|(&bi, &axi)| bi - axi)
.collect();
r = m.matvec(&residual)?;
} else {
let ax = a.matvec(&x)?;
r = b
.iter()
.zip(ax.iter())
.map(|(&bi, &axi)| bi - axi)
.collect();
}
r_norm = r.iter().map(|&ri| ri * ri).sum::<F>().sqrt();
if r_norm < options.atol || r_norm / b_norm < options.rtol {
return Ok(LGMRESResult {
x,
iterations: total_iter,
residual_norm: r_norm,
converged: true,
});
}
for v in v_list {
augmented_vectors.push(v);
}
if augmented_vectors.len() > options.outer_k {
augmented_vectors.drain(0..augmented_vectors.len() - options.outer_k);
}
}
Ok(LGMRESResult {
x,
iterations: total_iter,
residual_norm: r_norm,
converged: false,
})
}
#[allow(dead_code)]
fn inner_gmres<F>(
a: &dyn LinearOperator<F>,
r0: &[F],
m: usize,
augmented_vectors: &[Vec<F>],
preconditioner: Option<&dyn LinearOperator<F>>,
) -> SparseResult<(Vec<F>, F, Vec<Vec<F>>)>
where
F: Float + SparseElement + NumAssign + Sum + Debug + 'static,
{
let n = r0.len();
let mut v = vec![vec![F::sparse_zero(); n]; m + 1];
let r0_norm = r0.iter().map(|&ri| ri * ri).sum::<F>().sqrt();
if r0_norm < F::epsilon() {
return Ok((vec![F::sparse_zero(); n], F::sparse_zero(), vec![]));
}
v[0] = r0.iter().map(|&ri| ri / r0_norm).collect();
let mut h = vec![vec![F::sparse_zero(); m]; m + 1];
let mut s = vec![F::sparse_zero(); m + 1];
let mut c = vec![F::sparse_zero(); m + 1];
let mut beta = vec![F::sparse_zero(); m + 2];
beta[0] = r0_norm;
let mut k = 0; for j in 0..m {
let w = if let Some(prec) = preconditioner {
let av = a.matvec(&v[j])?;
prec.matvec(&av)?
} else {
a.matvec(&v[j])?
};
let mut w_orth = w.clone();
for i in 0..=j {
let h_ij = w
.iter()
.zip(v[i].iter())
.map(|(&wi, &vi)| wi * vi)
.sum::<F>();
h[i][j] = h_ij;
for (idx, w_elem) in w_orth.iter_mut().enumerate().take(n) {
*w_elem -= h_ij * v[i][idx];
}
}
let h_jp1_j = w_orth.iter().map(|&wi| wi * wi).sum::<F>().sqrt();
if h_jp1_j > F::epsilon() {
h[j + 1][j] = h_jp1_j;
v[j + 1] = w_orth.iter().map(|&wi| wi / h_jp1_j).collect();
} else {
k = j + 1;
break;
}
for i in 0..j {
let temp = c[i] * h[i][j] + s[i] * h[i + 1][j];
h[i + 1][j] = -s[i] * h[i][j] + c[i] * h[i + 1][j];
h[i][j] = temp;
}
let h_jj = h[j][j];
let h_jp1_j = h[j + 1][j];
let rho = (h_jj * h_jj + h_jp1_j * h_jp1_j).sqrt();
if rho > F::epsilon() {
c[j] = h_jj / rho;
s[j] = h_jp1_j / rho;
h[j][j] = c[j] * h_jj + s[j] * h_jp1_j;
h[j + 1][j] = F::sparse_zero();
beta[j + 1] = -s[j] * beta[j];
beta[j] = c[j] * beta[j];
k = j + 1;
if beta[j + 1].abs() < F::from(1e-10).expect("Failed to convert constant to float") {
break;
}
} else {
k = j;
break;
}
}
if k == 0 {
return Ok((vec![F::sparse_zero(); n], r0_norm, vec![]));
}
let mut y = vec![F::sparse_zero(); k];
for i in (0..k).rev() {
y[i] = beta[i];
for j in (i + 1)..k {
y[i] = y[i] - h[i][j] * y[j];
}
if h[i][i].abs() > F::epsilon() {
y[i] /= h[i][i];
} else {
y[i] = F::sparse_zero();
}
}
let mut x = vec![F::sparse_zero(); n];
for i in 0..k {
for (j, x_val) in x.iter_mut().enumerate().take(n) {
*x_val += y[i] * v[i][j];
}
}
let v_list: Vec<Vec<F>> = if k > 1 {
v.into_iter().skip(1).take(k - 1).collect()
} else {
vec![]
};
Ok((x, beta[k].abs(), v_list))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::csr::CsrMatrix;
use crate::linalg::interface::{AsLinearOperator, IdentityOperator};
#[test]
fn test_lgmres_identity() {
let a = IdentityOperator::new(3);
let b = vec![1.0, 2.0, 3.0];
let options = LGMRESOptions::default();
let result = lgmres(&a, &b, options).expect("Operation failed");
assert!(result.converged);
assert!((result.x[0] - 1.0).abs() < 1e-10);
assert!((result.x[1] - 2.0).abs() < 1e-10);
assert!((result.x[2] - 3.0).abs() < 1e-10);
}
#[test]
fn test_lgmres_spd_matrix() {
let data = vec![4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0];
let indptr = vec![0, 2, 5, 7];
let indices = vec![0, 1, 0, 1, 2, 1, 2];
let matrix =
CsrMatrix::from_raw_csr(data, indptr, indices, (3, 3)).expect("Operation failed");
let linear_op = matrix.as_linear_operator();
let b = vec![1.0, 2.0, 3.0];
let options = LGMRESOptions::default();
let result = lgmres(linear_op.as_ref(), &b, options).expect("Operation failed");
assert!(result.converged);
assert!(result.residual_norm < 1e-8);
}
}