use crate::algorithm::sparse_linalg::{IluOptions, SparseLinAlgAlgorithms};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use crate::runtime::Runtime;
use crate::sparse::{CsrData, SparseOps};
use crate::tensor::Tensor;
use super::super::helpers::{
apply_ilu0_preconditioner, givens_rotation, solve_upper_triangular, update_solution,
vector_dot, vector_norm,
};
use super::super::types::{
ConvergenceReason, GmresDiagnostics, GmresOptions, GmresResult, PreconditionerType,
};
pub fn gmres_impl<R, C>(
client: &C,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: GmresOptions,
) -> Result<GmresResult<R>>
where
R: Runtime<DType = DType>,
R::Client: SparseOps<R>,
C: SparseLinAlgAlgorithms<R>
+ SparseOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ReduceOps<R>
+ ScalarOps<R>,
{
let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
let device = b.device();
let dtype = b.dtype();
if !matches!(dtype, DType::F32 | DType::F64) {
return Err(Error::UnsupportedDType { dtype, op: "gmres" });
}
let mut x = match x0 {
Some(x0) => x0.clone(),
None => Tensor::<R>::zeros(&[n], dtype, device),
};
let precond = match options.preconditioner {
PreconditionerType::None => None,
PreconditionerType::Ilu0 => {
let ilu = client.ilu0(a, IluOptions::default())?;
Some(ilu)
}
PreconditionerType::Amg => {
return Err(Error::Internal(
"AMG preconditioner not supported for GMRES — use amg_preconditioned_cg"
.to_string(),
));
}
PreconditionerType::Ic0 => {
return Err(Error::Internal(
"IC0 preconditioner not yet supported for GMRES - use ILU0".to_string(),
));
}
};
let b_norm = vector_norm(client, b)?;
let mut residual_history = if options.track_residual_history {
Vec::with_capacity(options.max_iter)
} else {
Vec::new()
};
let build_diagnostics = |initial_residual_norm: f64, history: Vec<f64>| GmresDiagnostics {
rtol: options.rtol,
atol: options.atol,
max_iter: options.max_iter,
restart: options.restart,
initial_residual_norm,
rhs_norm: b_norm,
residual_history: history,
};
if b_norm < options.atol {
let reason = ConvergenceReason::ZeroRhs;
return Ok(GmresResult {
solution: x,
iterations: 0,
residual_norm: b_norm,
converged: true,
reason,
diagnostics: build_diagnostics(b_norm, residual_history),
});
}
let m = options.restart;
let mut total_iterations = 0;
let mut initial_residual_norm = 0.0;
for restart_cycle in 0..(options.max_iter / m + 1) {
let ax = a.spmv(&x)?;
let r = client.sub(b, &ax)?;
let beta = vector_norm(client, &r)?;
if restart_cycle == 0 {
initial_residual_norm = beta;
}
let atol_met = beta < options.atol;
let rtol_met = beta / b_norm < options.rtol;
if atol_met || rtol_met {
let reason = if atol_met && rtol_met {
ConvergenceReason::BothTolerances
} else if atol_met {
ConvergenceReason::AbsoluteTolerance
} else {
ConvergenceReason::RelativeTolerance
};
return Ok(GmresResult {
solution: x,
iterations: total_iterations,
residual_norm: beta,
converged: true,
reason,
diagnostics: build_diagnostics(initial_residual_norm, residual_history),
});
}
let v0 = client.mul_scalar(&r, 1.0 / beta)?;
let mut v_basis: Vec<Tensor<R>> = vec![v0];
let mut z_basis: Vec<Tensor<R>> = Vec::with_capacity(m);
let mut h_matrix: Vec<Vec<f64>> = Vec::with_capacity(m);
let mut cs: Vec<f64> = Vec::with_capacity(m);
let mut sn: Vec<f64> = Vec::with_capacity(m);
let mut g: Vec<f64> = vec![beta];
let mut j = 0;
while j < m && total_iterations < options.max_iter {
total_iterations += 1;
let vj = &v_basis[j];
let z = apply_ilu0_preconditioner(client, &precond, vj)?;
let w = a.spmv(&z)?;
z_basis.push(z);
let mut h_col: Vec<f64> = Vec::with_capacity(j + 2);
let mut w_current = w;
for i in 0..=j {
let h_ij = vector_dot(client, &w_current, &v_basis[i])?;
h_col.push(h_ij);
let scaled_vi = client.mul_scalar(&v_basis[i], h_ij)?;
w_current = client.sub(&w_current, &scaled_vi)?;
}
let h_jp1_j = vector_norm(client, &w_current)?;
h_col.push(h_jp1_j);
for i in 0..j {
let temp = cs[i] * h_col[i] + sn[i] * h_col[i + 1];
h_col[i + 1] = -sn[i] * h_col[i] + cs[i] * h_col[i + 1];
h_col[i] = temp;
}
let (c, s, r) = givens_rotation(h_col[j], h_col[j + 1]);
cs.push(c);
sn.push(s);
h_col[j] = r;
h_col[j + 1] = 0.0;
let g_old_j = g[j];
g.push(-s * g_old_j);
g[j] = c * g_old_j;
h_matrix.push(h_col);
let res_norm = g[j + 1].abs();
if options.track_residual_history {
residual_history.push(res_norm);
}
let atol_met = res_norm < options.atol;
let rtol_met = res_norm / b_norm < options.rtol;
if atol_met || rtol_met {
let y = solve_upper_triangular(&h_matrix, &g[..j + 1]);
x = update_solution(client, &x, &z_basis, &y)?;
let reason = if atol_met && rtol_met {
ConvergenceReason::BothTolerances
} else if atol_met {
ConvergenceReason::AbsoluteTolerance
} else {
ConvergenceReason::RelativeTolerance
};
return Ok(GmresResult {
solution: x,
iterations: total_iterations,
residual_norm: res_norm,
converged: true,
reason,
diagnostics: build_diagnostics(initial_residual_norm, residual_history),
});
}
if h_jp1_j < 1e-14 {
let y = solve_upper_triangular(&h_matrix, &g[..j + 1]);
x = update_solution(client, &x, &z_basis, &y)?;
return Ok(GmresResult {
solution: x,
iterations: total_iterations,
residual_norm: g[j + 1].abs(),
converged: true,
reason: ConvergenceReason::LuckyBreakdown,
diagnostics: build_diagnostics(initial_residual_norm, residual_history),
});
}
let v_jp1 = client.mul_scalar(&w_current, 1.0 / h_jp1_j)?;
v_basis.push(v_jp1);
j += 1;
}
if !h_matrix.is_empty() {
let y = solve_upper_triangular(&h_matrix, &g[..j]);
x = update_solution(client, &x, &z_basis, &y)?;
}
}
let ax = a.spmv(&x)?;
let r = client.sub(b, &ax)?;
let final_residual = vector_norm(client, &r)?;
Ok(GmresResult {
solution: x,
iterations: total_iterations,
residual_norm: final_residual,
converged: false,
reason: ConvergenceReason::MaxIterationsReached,
diagnostics: build_diagnostics(initial_residual_norm, residual_history),
})
}