use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use crate::runtime::Runtime;
use crate::sparse::{CsrData, SparseOps};
use crate::tensor::Tensor;
use super::super::helpers::{extract_diagonal_inv, vector_norm};
use super::super::types::{JacobiOptions, JacobiResult};
/// Solves the sparse linear system `A x = b` with the weighted (damped) Jacobi iteration.
///
/// Each sweep applies `x <- x + omega * D^{-1} (b - A x)`, where `D^{-1}` is the
/// inverted diagonal of `a` as returned by `extract_diagonal_inv`. The iteration stops
/// as soon as the residual `r = b - A x` satisfies `||r|| < atol` or
/// `||r|| / ||b|| < rtol` (norms as computed by `vector_norm`), or after
/// `options.max_iter` sweeps.
///
/// # Arguments
/// * `client` - backend used for the dense vector arithmetic and norms.
/// * `a` - system matrix in CSR form.
/// * `b` - right-hand-side vector.
/// * `x0` - optional initial guess; a zero vector is used when `None`.
/// * `options` - absolute/relative tolerances (`atol`, `rtol`), relaxation weight
///   `omega`, and the sweep limit `max_iter`.
///
/// # Returns
/// A [`JacobiResult`] with the final iterate, the reported iteration count, the norm
/// of the final residual, and whether the tolerance test was met.
///
/// If `||b||` is zero or below `atol`, the zero vector is returned immediately with
/// `iterations == 0`. Its residual is exactly `||b||`, so the reported norm is accurate
/// even when a nonzero `x0` was supplied.
///
/// # Errors
/// * Any error from `validate_iterative_inputs` (presumably shape mismatches between
///   `a`, `b` and `x0` — see that helper).
/// * [`Error::UnsupportedDType`] if `b` is neither `F32` nor `F64`.
/// * Errors propagated from `extract_diagonal_inv`, `spmv`, and the backend ops.
pub fn jacobi_impl<R, C>(
    client: &C,
    a: &CsrData<R>,
    b: &Tensor<R>,
    x0: Option<&Tensor<R>>,
    options: JacobiOptions,
) -> Result<JacobiResult<R>>
where
    R: Runtime<DType = DType>,
    R::Client: SparseOps<R>,
    C: SparseOps<R> + BinaryOps<R> + UnaryOps<R> + ReduceOps<R> + ScalarOps<R>,
{
    let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
    let device = b.device();
    let dtype = b.dtype();
    if !matches!(dtype, DType::F32 | DType::F64) {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "jacobi",
        });
    }

    let b_norm = vector_norm(client, b)?;
    // For a (numerically) zero right-hand side, x = 0 is the answer. Returning zeros
    // rather than `x0` keeps `residual_norm == ||b - A*0|| == ||b||` truthful. The
    // explicit `== 0.0` guard covers `atol == 0`: without it the relative test below
    // would divide by zero and the solver would never report convergence.
    if b_norm == 0.0 || b_norm < options.atol {
        return Ok(JacobiResult {
            solution: Tensor::<R>::zeros(&[n], dtype, device),
            iterations: 0,
            residual_norm: b_norm,
            converged: true,
        });
    }

    let mut x = match x0 {
        Some(x0) => x0.clone(),
        None => Tensor::<R>::zeros(&[n], dtype, device),
    };

    let d_inv = extract_diagonal_inv(client, a)?;
    for iter in 0..options.max_iter {
        // Residual of the current iterate (before this sweep's update).
        let ax = a.spmv(&x)?;
        let r = client.sub(b, &ax)?;
        let res_norm = vector_norm(client, &r)?;
        if res_norm < options.atol || res_norm / b_norm < options.rtol {
            // Existing reporting convention: a pass detected at loop index `iter`
            // is reported as `iter + 1` iterations.
            return Ok(JacobiResult {
                solution: x,
                iterations: iter + 1,
                residual_norm: res_norm,
                converged: true,
            });
        }
        // Weighted Jacobi update: x += omega * D^{-1} r.
        let d_inv_r = client.mul(&d_inv, &r)?;
        let update = client.mul_scalar(&d_inv_r, options.omega)?;
        x = client.add(&x, &update)?;
    }

    // The iterate produced by the last sweep was never tested inside the loop. Test
    // it here so a solve that meets the tolerance exactly on sweep `max_iter` is not
    // misreported as a failure.
    let ax = a.spmv(&x)?;
    let r = client.sub(b, &ax)?;
    let final_residual = vector_norm(client, &r)?;
    let converged = final_residual < options.atol || final_residual / b_norm < options.rtol;
    Ok(JacobiResult {
        solution: x,
        iterations: options.max_iter,
        residual_norm: final_residual,
        converged,
    })
}