use crate::algorithm::sparse_linalg::{IluDecomposition, IluOptions, SparseLinAlgAlgorithms};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use crate::runtime::Runtime;
use crate::sparse::{CsrData, SparseOps};
use crate::tensor::Tensor;
use super::super::helpers::{apply_ilu0_preconditioner, vector_dot, vector_norm};
use super::super::types::{BiCgStabOptions, BiCgStabResult, PreconditionerType};
/// Solves the sparse linear system `A x = b` with the BiCGSTAB iteration.
///
/// The preconditioner (if any) is applied to the search direction `p` and to the
/// intermediate residual `s` before each sparse matrix-vector product, and the
/// solution is updated with the preconditioned vectors.
///
/// # Arguments
///
/// * `client`  - Backend client providing the tensor, reduction and ILU operations.
/// * `a`       - System matrix in CSR format.
/// * `b`       - Right-hand side vector.
/// * `x0`      - Optional initial guess; a zero vector is used when `None`.
/// * `options` - Iteration limit (`max_iter`), tolerances (`atol`, `rtol`) and the
///   preconditioner selection.
///
/// # Returns
///
/// A [`BiCgStabResult`] holding the last iterate, the number of iterations that
/// touched the iterate, the residual norm associated with the returned iterate and a
/// `converged` flag. Convergence means `‖r‖ < atol` or `‖r‖ / ‖b‖ < rtol`.
///
/// When `‖b‖` is zero or below `atol`, the zero vector is returned immediately
/// (regardless of `x0`), since it satisfies the tolerance with residual `‖b‖`.
///
/// Numerical breakdown (a vanishing `rho`, `r_hat · v`, `t · t` or `omega`) stops the
/// iteration early and returns the current iterate; this is reported through the
/// `converged` flag, not as an error.
///
/// # Errors
///
/// * [`Error::UnsupportedDType`] if `b` is neither `F32` nor `F64`.
/// * [`Error::Internal`] if the `Amg` or `Ic0` preconditioner is requested.
/// * Any error from input validation, the ILU(0) factorization, or the underlying
///   tensor / sparse operations is propagated unchanged.
pub fn bicgstab_impl<R, C>(
    client: &C,
    a: &CsrData<R>,
    b: &Tensor<R>,
    x0: Option<&Tensor<R>>,
    options: BiCgStabOptions,
) -> Result<BiCgStabResult<R>>
where
    R: Runtime<DType = DType>,
    R::Client: SparseOps<R>,
    C: SparseLinAlgAlgorithms<R>
        + SparseOps<R>
        + BinaryOps<R>
        + UnaryOps<R>
        + ReduceOps<R>
        + ScalarOps<R>,
{
    let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
    let device = b.device();
    let dtype = b.dtype();

    if !matches!(dtype, DType::F32 | DType::F64) {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "bicgstab",
        });
    }

    let mut x = match x0 {
        Some(x0) => x0.clone(),
        None => Tensor::<R>::zeros(&[n], dtype, device),
    };

    // Build the preconditioner up front so an unsupported choice fails before any
    // iteration work is done.
    let precond: Option<IluDecomposition<R>> = match options.preconditioner {
        PreconditionerType::None => None,
        PreconditionerType::Ilu0 => {
            let ilu = client.ilu0(a, IluOptions::default())?;
            Some(ilu)
        }
        PreconditionerType::Amg => {
            return Err(Error::Internal(
                "AMG preconditioner not supported for BiCGSTAB - use amg_preconditioned_cg"
                    .to_string(),
            ));
        }
        PreconditionerType::Ic0 => {
            return Err(Error::Internal(
                "IC0 preconditioner not yet supported for BiCGSTAB - use ILU0".to_string(),
            ));
        }
    };

    // Magnitude below which a scalar of the recurrence is treated as a breakdown.
    let breakdown_eps = 1e-40;

    let b_norm = vector_norm(client, b)?;
    // A (near-)zero right-hand side is solved by the zero vector, whose residual is
    // exactly `b`. Returning the caller's `x0` here would pair a solution with a
    // residual norm that does not belong to it. The explicit `== 0.0` test also keeps
    // the relative-residual divisions below well defined when `atol` is zero.
    if b_norm < options.atol || b_norm == 0.0 {
        return Ok(BiCgStabResult {
            solution: Tensor::<R>::zeros(&[n], dtype, device),
            iterations: 0,
            residual_norm: b_norm,
            converged: true,
        });
    }

    // Initial residual r = b - A x and the fixed shadow residual r_hat.
    let ax = a.spmv(&x)?;
    let mut r = client.sub(b, &ax)?;
    let r_hat = r.clone();

    // With rho = alpha = omega = 1 and p = v = 0 the first direction update yields p = r.
    let mut rho = 1.0;
    let mut alpha = 1.0;
    let mut omega = 1.0;
    let mut v = Tensor::<R>::zeros(&[n], dtype, device);
    let mut p = Tensor::<R>::zeros(&[n], dtype, device);

    for iter in 0..options.max_iter {
        let rho_new = vector_dot(client, &r_hat, &r)?;
        if rho_new.abs() < breakdown_eps {
            // r is (numerically) orthogonal to r_hat; this also covers r == 0, i.e. an
            // iterate that is already exact. `x` has not been modified in this iteration.
            let res_norm = vector_norm(client, &r)?;
            return Ok(BiCgStabResult {
                solution: x,
                iterations: iter,
                residual_norm: res_norm,
                converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
            });
        }

        // p = r + beta * (p - omega * v)
        let beta = (rho_new / rho) * (alpha / omega);
        let p_scaled = client.mul_scalar(&p, beta)?;
        let v_scaled = client.mul_scalar(&v, beta * omega)?;
        let temp = client.sub(&p_scaled, &v_scaled)?;
        p = client.add(&r, &temp)?;

        // NOTE(review): presumably returns `p` unchanged when `precond` is `None` —
        // confirm against the helper.
        let p_hat = apply_ilu0_preconditioner(client, &precond, &p)?;
        v = a.spmv(&p_hat)?;

        let r_hat_v = vector_dot(client, &r_hat, &v)?;
        if r_hat_v.abs() < breakdown_eps {
            // alpha would be undefined; `x` is still the iterate that `r` belongs to.
            let res_norm = vector_norm(client, &r)?;
            return Ok(BiCgStabResult {
                solution: x,
                iterations: iter,
                residual_norm: res_norm,
                converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
            });
        }
        alpha = rho_new / r_hat_v;

        // Half step: x <- x + alpha * p_hat, with residual s = r - alpha * v.
        // Applying it to `x` right away keeps the iterate consistent with `s` on every
        // exit path below, including the `t · t` breakdown.
        let p_hat_scaled = client.mul_scalar(&p_hat, alpha)?;
        x = client.add(&x, &p_hat_scaled)?;
        let v_scaled = client.mul_scalar(&v, alpha)?;
        let s = client.sub(&r, &v_scaled)?;

        let s_norm = vector_norm(client, &s)?;
        if s_norm < options.atol || s_norm / b_norm < options.rtol {
            return Ok(BiCgStabResult {
                solution: x,
                iterations: iter + 1,
                residual_norm: s_norm,
                converged: true,
            });
        }

        let s_hat = apply_ilu0_preconditioner(client, &precond, &s)?;
        let t = a.spmv(&s_hat)?;

        let t_s = vector_dot(client, &t, &s)?;
        let t_t = vector_dot(client, &t, &t)?;
        if t_t.abs() < breakdown_eps {
            // omega would be undefined. `s` already failed the tolerance test above, so
            // the half-step iterate is returned as not converged.
            return Ok(BiCgStabResult {
                solution: x,
                iterations: iter + 1,
                residual_norm: s_norm,
                converged: false,
            });
        }
        omega = t_s / t_t;

        // Second half step: x <- x + omega * s_hat, r <- s - omega * t.
        let s_hat_scaled = client.mul_scalar(&s_hat, omega)?;
        x = client.add(&x, &s_hat_scaled)?;
        let t_scaled = client.mul_scalar(&t, omega)?;
        r = client.sub(&s, &t_scaled)?;
        rho = rho_new;

        let res_norm = vector_norm(client, &r)?;
        if res_norm < options.atol || res_norm / b_norm < options.rtol {
            return Ok(BiCgStabResult {
                solution: x,
                iterations: iter + 1,
                residual_norm: res_norm,
                converged: true,
            });
        }

        // A vanishing omega would make the next beta blow up; stop here instead.
        if omega.abs() < breakdown_eps {
            return Ok(BiCgStabResult {
                solution: x,
                iterations: iter + 1,
                residual_norm: res_norm,
                converged: false,
            });
        }
    }

    // Iteration budget exhausted: report the true residual b - A x rather than the
    // recurrence residual.
    let ax = a.spmv(&x)?;
    let r_final = client.sub(b, &ax)?;
    let final_residual = vector_norm(client, &r_final)?;
    Ok(BiCgStabResult {
        solution: x,
        iterations: options.max_iter,
        residual_norm: final_residual,
        converged: false,
    })
}