use crate::algorithm::sparse_linalg::{IlukOptions, SparseLinAlgAlgorithms};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use crate::runtime::Runtime;
use crate::sparse::{CsrData, SparseOps};
use crate::tensor::Tensor;
use super::super::helpers::{
apply_iluk_preconditioner, detect_stagnation, givens_rotation, solve_upper_triangular,
update_solution, vector_dot, vector_norm,
};
use super::super::types::{
AdaptiveGmresResult, AdaptivePreconditionerOptions, ConvergenceReason, GmresOptions,
};
/// Solves `A x = b` with ILU(k)-preconditioned restarted GMRES, raising the
/// ILU fill level whenever the inner solve reports stagnation.
///
/// The routine repeatedly:
/// 1. factorizes `a` with `client.iluk` at the current fill level and records
///    the factorization metrics,
/// 2. runs `gmres_with_iluk` starting from the current iterate,
/// 3. returns if the inner solve converged, or if it stopped for any reason
///    that does not permit an upgrade; otherwise moves to the next fill level
///    and tries again.
///
/// An upgrade happens only when the inner solve stopped with
/// `ConvergenceReason::Stagnation`, fewer than `adaptive_opts.max_upgrades`
/// upgrades have been performed, and `current_level.upgrade()` yields a level.
///
/// # Arguments
/// * `client` - backend used for factorization and vector arithmetic.
/// * `a` - system matrix in CSR form.
/// * `b` - right-hand side; must be `F32` or `F64`.
/// * `x0` - optional initial guess; a zero vector of length `n` is used when absent.
/// * `gmres_opts` - tolerances, restart length and iteration cap for each inner solve.
/// * `adaptive_opts` - initial fill level, upgrade budget, stagnation parameters
///   and whether the residual history is cleared after an upgrade.
///
/// # Returns
/// An `AdaptiveGmresResult` holding the final iterate, the iteration count
/// summed over all inner solves, the last reported residual norm, the final
/// fill level, the number of upgrades and one metrics entry per factorization.
///
/// # Errors
/// * `Error::UnsupportedDType` when `b` is neither `F32` nor `F64`.
/// * Any error propagated from input validation, the ILU(k) factorization or
///   the tensor operations of the inner solver.
pub fn adaptive_gmres_impl<R, C>(
    client: &C,
    a: &CsrData<R>,
    b: &Tensor<R>,
    x0: Option<&Tensor<R>>,
    gmres_opts: GmresOptions,
    adaptive_opts: AdaptivePreconditionerOptions,
) -> Result<AdaptiveGmresResult<R>>
where
    R: Runtime<DType = DType>,
    R::Client: SparseOps<R>,
    C: SparseLinAlgAlgorithms<R>
        + SparseOps<R>
        + BinaryOps<R>
        + UnaryOps<R>
        + ReduceOps<R>
        + ScalarOps<R>,
{
    let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
    let device = b.device();
    let dtype = b.dtype();
    if !matches!(dtype, DType::F32 | DType::F64) {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "adaptive_gmres",
        });
    }

    let b_norm = vector_norm(client, b)?;
    // Zero (or negligible) right-hand side. The explicit `== 0.0` test matters
    // when `atol` is 0: without it the inner solver would evaluate 0/0 for the
    // relative residual and divide by a zero norm when normalizing.
    // The zero vector is returned rather than `x0`, because `b_norm` is the
    // residual of the zero vector, not of an arbitrary initial guess.
    if b_norm == 0.0 || b_norm < gmres_opts.atol {
        return Ok(AdaptiveGmresResult {
            solution: Tensor::<R>::zeros(&[n], dtype, device),
            total_iterations: 0,
            residual_norm: b_norm,
            converged: true,
            final_level: adaptive_opts.initial_level,
            upgrades: 0,
            ilu_metrics: Vec::new(),
            reason: ConvergenceReason::ZeroRhs,
        });
    }

    let mut x = match x0 {
        Some(guess) => guess.clone(),
        None => Tensor::<R>::zeros(&[n], dtype, device),
    };

    let mut current_level = adaptive_opts.initial_level;
    let mut upgrades = 0;
    let mut ilu_metrics = Vec::new();
    let mut total_iterations = 0;
    // Shared with the inner solver so stagnation can be judged across calls
    // unless `restart_on_upgrade` asks for a clean slate.
    let mut residual_history: Vec<f64> = Vec::new();

    loop {
        let iluk_opts = IlukOptions {
            fill_level: current_level,
            drop_tolerance: 0.0,
            diagonal_shift: 1e-10,
            pivot_threshold: 0.1,
        };
        let ilu = client.iluk(a, iluk_opts)?;
        ilu_metrics.push(ilu.metrics.clone());

        let result = gmres_with_iluk(
            client,
            a,
            b,
            &x,
            &ilu,
            &gmres_opts,
            &adaptive_opts.stagnation,
            b_norm,
            &mut residual_history,
        )?;
        total_iterations += result.iterations;
        x = result.solution;

        // Ask for the next level only when an upgrade is actually permitted,
        // so `upgrade()` is evaluated once and no `expect` is needed.
        let next_level = if !result.converged
            && matches!(result.reason, ConvergenceReason::Stagnation)
            && upgrades < adaptive_opts.max_upgrades
        {
            current_level.upgrade()
        } else {
            None
        };

        match next_level {
            Some(level) => {
                current_level = level;
                upgrades += 1;
                if adaptive_opts.restart_on_upgrade {
                    residual_history.clear();
                }
            }
            None => {
                // Either converged, or stopped with no further upgrade available.
                return Ok(AdaptiveGmresResult {
                    solution: x,
                    total_iterations,
                    residual_norm: result.residual_norm,
                    converged: result.converged,
                    final_level: current_level,
                    upgrades,
                    ilu_metrics,
                    reason: result.reason,
                });
            }
        }
    }
}
/// Outcome of a single `gmres_with_iluk` call, i.e. one GMRES run with a fixed
/// ILU(k) factorization. Consumed by `adaptive_gmres_impl` to decide whether
/// to return or to upgrade the preconditioner.
struct GmresInternalResult<R: Runtime> {
/// Iterate held by the solver at the moment it returned.
solution: Tensor<R>,
/// Number of inner GMRES iterations performed during this call.
iterations: usize,
/// Residual norm reported at return: either an explicitly computed
/// `||b - A x||` or the magnitude of the last entry of the rotated
/// least-squares right-hand side, depending on the exit path.
residual_norm: f64,
/// `true` for tolerance-based exits and lucky breakdown; `false` for
/// stagnation and for exhausting the iteration budget.
converged: bool,
/// Why the solver stopped.
reason: ConvergenceReason,
}
fn gmres_with_iluk<R, C>(
client: &C,
a: &CsrData<R>,
b: &Tensor<R>,
x0: &Tensor<R>,
ilu: &crate::algorithm::sparse_linalg::IlukDecomposition<R>,
opts: &GmresOptions,
stagnation: &super::super::types::StagnationParams,
b_norm: f64,
residual_history: &mut Vec<f64>,
) -> Result<GmresInternalResult<R>>
where
R: Runtime<DType = DType>,
R::Client: SparseOps<R>,
C: SparseLinAlgAlgorithms<R>
+ SparseOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ReduceOps<R>
+ ScalarOps<R>,
{
let mut x = x0.clone();
let m = opts.restart;
let mut total_iterations = 0;
for _restart_cycle in 0..(opts.max_iter / m + 1) {
let ax = a.spmv(&x)?;
let r = client.sub(b, &ax)?;
let beta = vector_norm(client, &r)?;
if beta < opts.atol || beta / b_norm < opts.rtol {
let atol_met = beta < opts.atol;
let rtol_met = beta / b_norm < opts.rtol;
let reason = if atol_met && rtol_met {
ConvergenceReason::BothTolerances
} else if atol_met {
ConvergenceReason::AbsoluteTolerance
} else {
ConvergenceReason::RelativeTolerance
};
return Ok(GmresInternalResult {
solution: x,
iterations: total_iterations,
residual_norm: beta,
converged: true,
reason,
});
}
let v0 = client.mul_scalar(&r, 1.0 / beta)?;
let mut v_basis: Vec<Tensor<R>> = vec![v0];
let mut z_basis: Vec<Tensor<R>> = Vec::with_capacity(m);
let mut h_matrix: Vec<Vec<f64>> = Vec::with_capacity(m);
let mut cs: Vec<f64> = Vec::with_capacity(m);
let mut sn: Vec<f64> = Vec::with_capacity(m);
let mut g: Vec<f64> = vec![beta];
let mut j = 0;
while j < m && total_iterations < opts.max_iter {
total_iterations += 1;
let vj = &v_basis[j];
let z = apply_iluk_preconditioner(client, ilu, vj)?;
let w = a.spmv(&z)?;
z_basis.push(z);
let mut h_col: Vec<f64> = Vec::with_capacity(j + 2);
let mut w_current = w;
for i in 0..=j {
let h_ij = vector_dot(client, &w_current, &v_basis[i])?;
h_col.push(h_ij);
let scaled_vi = client.mul_scalar(&v_basis[i], h_ij)?;
w_current = client.sub(&w_current, &scaled_vi)?;
}
let h_jp1_j = vector_norm(client, &w_current)?;
h_col.push(h_jp1_j);
for i in 0..j {
let temp = cs[i] * h_col[i] + sn[i] * h_col[i + 1];
h_col[i + 1] = -sn[i] * h_col[i] + cs[i] * h_col[i + 1];
h_col[i] = temp;
}
let (c, s, r) = givens_rotation(h_col[j], h_col[j + 1]);
cs.push(c);
sn.push(s);
h_col[j] = r;
h_col[j + 1] = 0.0;
let g_old_j = g[j];
g.push(-s * g_old_j);
g[j] = c * g_old_j;
h_matrix.push(h_col);
let res_norm = g[j + 1].abs();
residual_history.push(res_norm);
if res_norm < opts.atol || res_norm / b_norm < opts.rtol {
let y = solve_upper_triangular(&h_matrix, &g[..j + 1]);
x = update_solution(client, &x, &z_basis, &y)?;
let atol_met = res_norm < opts.atol;
let rtol_met = res_norm / b_norm < opts.rtol;
let reason = if atol_met && rtol_met {
ConvergenceReason::BothTolerances
} else if atol_met {
ConvergenceReason::AbsoluteTolerance
} else {
ConvergenceReason::RelativeTolerance
};
return Ok(GmresInternalResult {
solution: x,
iterations: total_iterations,
residual_norm: res_norm,
converged: true,
reason,
});
}
if detect_stagnation(residual_history, stagnation) {
let y = solve_upper_triangular(&h_matrix, &g[..j + 1]);
x = update_solution(client, &x, &z_basis, &y)?;
return Ok(GmresInternalResult {
solution: x,
iterations: total_iterations,
residual_norm: res_norm,
converged: false,
reason: ConvergenceReason::Stagnation,
});
}
if h_jp1_j < 1e-14 {
let y = solve_upper_triangular(&h_matrix, &g[..j + 1]);
x = update_solution(client, &x, &z_basis, &y)?;
return Ok(GmresInternalResult {
solution: x,
iterations: total_iterations,
residual_norm: g[j + 1].abs(),
converged: true,
reason: ConvergenceReason::LuckyBreakdown,
});
}
let v_jp1 = client.mul_scalar(&w_current, 1.0 / h_jp1_j)?;
v_basis.push(v_jp1);
j += 1;
}
if !h_matrix.is_empty() {
let y = solve_upper_triangular(&h_matrix, &g[..j]);
x = update_solution(client, &x, &z_basis, &y)?;
}
}
let ax = a.spmv(&x)?;
let r = client.sub(b, &ax)?;
let final_residual = vector_norm(client, &r)?;
Ok(GmresInternalResult {
solution: x,
iterations: total_iterations,
residual_norm: final_residual,
converged: false,
reason: ConvergenceReason::MaxIterationsReached,
})
}
#[cfg(test)]
mod tests {
    use super::super::super::helpers::detect_stagnation;
    use super::super::super::types::StagnationParams;

    /// Exercises `detect_stagnation` on three residual histories: one with
    /// only two samples, one that drops quickly, and one that barely shrinks.
    #[test]
    fn test_stagnation_detection() {
        let params = StagnationParams {
            reduction_factor: 0.5,
            window_size: 3,
            min_iterations: 2,
        };

        // Two samples only: expected to report no stagnation.
        let history = vec![1.0, 0.9];
        assert!(!detect_stagnation(&history, &params));

        // Residual falls from 1.0 to 0.2: expected to report no stagnation.
        let history = vec![1.0, 0.8, 0.6, 0.4, 0.2];
        assert!(!detect_stagnation(&history, &params));

        // Residual creeps from 1.0 to 0.72: expected to report stagnation.
        let history = vec![1.0, 0.9, 0.85, 0.8, 0.75, 0.72];
        assert!(detect_stagnation(&history, &params));
    }
}