use crate::algorithm::sparse_linalg::{IluOptions, SparseLinAlgAlgorithms};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use crate::runtime::Runtime;
use crate::sparse::{CsrData, SparseOps};
use crate::tensor::Tensor;
use super::super::helpers::{BREAKDOWN_TOL, apply_ilu0_preconditioner, vector_dot, vector_norm};
use super::super::types::{CgsOptions, CgsResult, PreconditionerType};
pub fn cgs_impl<R, C>(
client: &C,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: CgsOptions,
) -> Result<CgsResult<R>>
where
R: Runtime<DType = DType>,
R::Client: SparseOps<R>,
C: SparseLinAlgAlgorithms<R>
+ SparseOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ReduceOps<R>
+ ScalarOps<R>,
{
let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
let device = b.device();
let dtype = b.dtype();
if !matches!(dtype, DType::F32 | DType::F64) {
return Err(Error::UnsupportedDType { dtype, op: "cgs" });
}
let mut x = match x0 {
Some(x0) => x0.clone(),
None => Tensor::<R>::zeros(&[n], dtype, device),
};
let precond = match options.preconditioner {
PreconditionerType::None => None,
PreconditionerType::Ilu0 => Some(client.ilu0(a, IluOptions::default())?),
PreconditionerType::Amg => {
return Err(Error::Internal(
"AMG preconditioner not supported for CGS — use amg_preconditioned_cg".to_string(),
));
}
PreconditionerType::Ic0 => {
return Err(Error::Internal(
"IC0 preconditioner not supported for CGS — use ILU0".to_string(),
));
}
};
let b_norm = vector_norm(client, b)?;
if b_norm < options.atol {
return Ok(CgsResult {
solution: x,
iterations: 0,
residual_norm: b_norm,
converged: true,
});
}
let ax = a.spmv(&x)?;
let mut r = client.sub(b, &ax)?;
let r_hat = r.clone();
let mut rho = vector_dot(client, &r_hat, &r)?;
let mut u = r.clone();
let mut p = r.clone();
for iter in 0..options.max_iter {
if rho.abs() < BREAKDOWN_TOL {
let res_norm = vector_norm(client, &r)?;
return Ok(CgsResult {
solution: x,
iterations: iter,
residual_norm: res_norm,
converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
});
}
let p_hat = apply_ilu0_preconditioner(client, &precond, &p)?;
let v = a.spmv(&p_hat)?;
let sigma = vector_dot(client, &r_hat, &v)?;
if sigma.abs() < BREAKDOWN_TOL {
let res_norm = vector_norm(client, &r)?;
return Ok(CgsResult {
solution: x,
iterations: iter,
residual_norm: res_norm,
converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
});
}
let alpha = rho / sigma;
let v_scaled = client.mul_scalar(&v, alpha)?;
let q = client.sub(&u, &v_scaled)?;
let u_plus_q = client.add(&u, &q)?;
let uq_hat = apply_ilu0_preconditioner(client, &precond, &u_plus_q)?;
let uq_scaled = client.mul_scalar(&uq_hat, alpha)?;
x = client.add(&x, &uq_scaled)?;
let a_uq = a.spmv(&uq_hat)?;
let a_uq_scaled = client.mul_scalar(&a_uq, alpha)?;
r = client.sub(&r, &a_uq_scaled)?;
let res_norm = vector_norm(client, &r)?;
if res_norm < options.atol || res_norm / b_norm < options.rtol {
return Ok(CgsResult {
solution: x,
iterations: iter + 1,
residual_norm: res_norm,
converged: true,
});
}
let rho_new = vector_dot(client, &r_hat, &r)?;
let beta = rho_new / rho;
let q_scaled = client.mul_scalar(&q, beta)?;
u = client.add(&r, &q_scaled)?;
let p_scaled = client.mul_scalar(&p, beta)?;
let q_plus_bp = client.add(&q, &p_scaled)?;
let qbp_scaled = client.mul_scalar(&q_plus_bp, beta)?;
p = client.add(&u, &qbp_scaled)?;
rho = rho_new;
}
let ax = a.spmv(&x)?;
let r_final = client.sub(b, &ax)?;
let final_residual = vector_norm(client, &r_final)?;
Ok(CgsResult {
solution: x,
iterations: options.max_iter,
residual_norm: final_residual,
converged: false,
})
}