use crate::algorithm::sparse_linalg::{IluOptions, SparseLinAlgAlgorithms};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use crate::runtime::Runtime;
use crate::sparse::{CsrData, SparseOps};
use crate::tensor::Tensor;
use super::super::helpers::{BREAKDOWN_TOL, apply_ilu0_preconditioner, vector_dot, vector_norm};
use super::super::types::{PreconditionerType, QmrOptions, QmrResult};
/// Solves the sparse linear system `A x = b` with the QMR (Quasi-Minimal Residual) method.
///
/// The iteration builds two coupled sequences, one driven by `a` and one by its
/// transpose, and updates the iterate and a recurrence-based residual each step.
/// Every 50 iterations the residual is recomputed as `b - A x` to limit drift of
/// the recurrence.
///
/// # Arguments
///
/// * `client` - backend used for element-wise, scalar, reduction and ILU(0) operations.
/// * `a` - system matrix in CSR form.
/// * `b` - right-hand side; its dtype must be `F32` or `F64`.
/// * `x0` - optional initial guess; a zero vector is used when `None`.
/// * `options` - tolerances (`atol`, `rtol`), `max_iter` and the preconditioner choice.
///
/// # Returns
///
/// A [`QmrResult`] holding the final iterate, the number of iterations performed,
/// the last residual norm and whether `residual < atol || residual / ||b|| < rtol`
/// was met. When `||b|| < atol` the zero vector is returned, because its residual
/// is exactly `b` and therefore already satisfies the tolerance. On a breakdown
/// (a near-zero `delta`, `epsilon`, `gamma`, or Lanczos vector norm) the current
/// iterate is returned as-is with `converged` reflecting the tolerance test.
///
/// # Errors
///
/// * [`Error::UnsupportedDType`] if `b` is neither `F32` nor `F64`.
/// * [`Error::Internal`] if `options.preconditioner` is `Ic0` or `Amg`.
/// * Any error propagated from input validation, the ILU(0) factorization, the
///   transpose conversion, or the tensor/sparse operations.
pub fn qmr_impl<R, C>(
    client: &C,
    a: &CsrData<R>,
    b: &Tensor<R>,
    x0: Option<&Tensor<R>>,
    options: QmrOptions,
) -> Result<QmrResult<R>>
where
    R: Runtime<DType = DType>,
    R::Client: SparseOps<R>,
    C: SparseLinAlgAlgorithms<R>
        + SparseOps<R>
        + BinaryOps<R>
        + UnaryOps<R>
        + ReduceOps<R>
        + ScalarOps<R>,
{
    let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
    let device = b.device();
    let dtype = b.dtype();
    if !matches!(dtype, DType::F32 | DType::F64) {
        return Err(Error::UnsupportedDType { dtype, op: "qmr" });
    }

    let precond = match options.preconditioner {
        PreconditionerType::None => None,
        PreconditionerType::Ilu0 => Some(client.ilu0(a, IluOptions::default())?),
        PreconditionerType::Ic0 | PreconditionerType::Amg => {
            return Err(Error::Internal(
                "Only None and Ilu0 preconditioners supported for QMR".to_string(),
            ));
        }
    };

    let (atol, rtol) = (options.atol, options.rtol);
    let b_norm = vector_norm(client, b)?;
    if b_norm < atol {
        // The reported residual `||b||` belongs to x = 0, not to an arbitrary
        // caller-supplied `x0`, so return the zero vector to keep the result
        // self-consistent.
        return Ok(QmrResult {
            solution: Tensor::<R>::zeros(&[n], dtype, device),
            iterations: 0,
            residual_norm: b_norm,
            converged: true,
        });
    }

    // Single definition of the stopping criterion used at every exit point.
    let meets_tol = |norm: f64| norm < atol || norm / b_norm < rtol;

    let mut x = match x0 {
        Some(x0) => x0.clone(),
        None => Tensor::<R>::zeros(&[n], dtype, device),
    };

    // Initial residual r = b - A x.
    let ax = a.spmv(&x)?;
    let r = client.sub(b, &ax)?;
    let r_norm = vector_norm(client, &r)?;
    if meets_tol(r_norm) {
        return Ok(QmrResult {
            solution: x,
            iterations: 0,
            residual_norm: r_norm,
            converged: true,
        });
    }

    // Only needed once we actually iterate, so build it after the early exits.
    let at = a.transpose().to_csr()?;

    // Both sequences start from the initial residual, so their norms equal `r_norm`.
    let mut rho = r_norm;
    let mut xi = r_norm;
    let mut gamma_prev = 1.0_f64;
    let mut eta = -1.0_f64;
    let mut theta_prev = 0.0_f64;
    let mut epsilon_prev = 0.0_f64;

    let mut v = client.mul_scalar(&r, 1.0 / rho)?;
    let mut w = client.mul_scalar(&r, 1.0 / xi)?;

    // Accumulated update direction for x (`d`) and for the residual (`s`).
    let mut d = Tensor::<R>::zeros(&[n], dtype, device);
    let mut s = Tensor::<R>::zeros(&[n], dtype, device);
    let mut p_prev = Tensor::<R>::zeros(&[n], dtype, device);
    let mut q_prev = Tensor::<R>::zeros(&[n], dtype, device);
    let mut residual = r;

    for iter in 0..options.max_iter {
        let delta = vector_dot(client, &w, &v)?;
        if delta.abs() < BREAKDOWN_TOL {
            let res_norm = vector_norm(client, &residual)?;
            return Ok(QmrResult {
                solution: x,
                iterations: iter + 1,
                residual_norm: res_norm,
                converged: meets_tol(res_norm),
            });
        }

        let y = apply_ilu0_preconditioner(client, &precond, &v)?;
        let z = apply_ilu0_preconditioner(client, &precond, &w)?;

        // Search directions: the first iteration has no previous direction to
        // combine with (and `epsilon_prev` is not yet defined).
        let (p, q) = if iter == 0 {
            (y, z)
        } else {
            let coeff_p = (xi * delta) / epsilon_prev;
            let coeff_q = (rho * delta) / epsilon_prev;
            let pp = client.mul_scalar(&p_prev, coeff_p)?;
            let qq = client.mul_scalar(&q_prev, coeff_q)?;
            (client.sub(&y, &pp)?, client.sub(&z, &qq)?)
        };

        let ap = a.spmv(&p)?;
        let epsilon = vector_dot(client, &q, &ap)?;
        if epsilon.abs() < BREAKDOWN_TOL {
            let res_norm = vector_norm(client, &residual)?;
            return Ok(QmrResult {
                solution: x,
                iterations: iter + 1,
                residual_norm: res_norm,
                converged: meets_tol(res_norm),
            });
        }

        // Next (unnormalized) vectors of the two sequences.
        let beta = epsilon / delta;
        let bv = client.mul_scalar(&v, beta)?;
        let v_tilde = client.sub(&ap, &bv)?;
        let atq = at.spmv(&q)?;
        let bw = client.mul_scalar(&w, beta)?;
        let w_tilde = client.sub(&atq, &bw)?;
        let rho_new = vector_norm(client, &v_tilde)?;
        let xi_new = vector_norm(client, &w_tilde)?;

        // Scalars of the quasi-minimization step.
        let theta = rho_new / (gamma_prev * beta.abs());
        let gamma = 1.0 / (1.0 + theta * theta).sqrt();
        if gamma.abs() < BREAKDOWN_TOL {
            let res_norm = vector_norm(client, &residual)?;
            return Ok(QmrResult {
                solution: x,
                iterations: iter + 1,
                residual_norm: res_norm,
                converged: meets_tol(res_norm),
            });
        }
        let eta_new = -eta * rho * gamma * gamma / (beta * gamma_prev * gamma_prev);
        let tg2 = (theta_prev * gamma) * (theta_prev * gamma);

        // d = eta * p + (theta_prev * gamma)^2 * d
        let ep = client.mul_scalar(&p, eta_new)?;
        let td = client.mul_scalar(&d, tg2)?;
        d = client.add(&ep, &td)?;

        // s = eta * A p + (theta_prev * gamma)^2 * s
        let eap = client.mul_scalar(&ap, eta_new)?;
        let ts = client.mul_scalar(&s, tg2)?;
        s = client.add(&eap, &ts)?;

        x = client.add(&x, &d)?;
        residual = client.sub(&residual, &s)?;

        let mut res_norm = vector_norm(client, &residual)?;
        if meets_tol(res_norm) {
            return Ok(QmrResult {
                solution: x,
                iterations: iter + 1,
                residual_norm: res_norm,
                converged: true,
            });
        }

        // Periodically replace the recurrence residual with the true residual
        // and keep `res_norm` in sync so later exits report the refreshed value.
        if (iter + 1) % 50 == 0 {
            let ax_check = a.spmv(&x)?;
            residual = client.sub(b, &ax_check)?;
            res_norm = vector_norm(client, &residual)?;
            if meets_tol(res_norm) {
                return Ok(QmrResult {
                    solution: x,
                    iterations: iter + 1,
                    residual_norm: res_norm,
                    converged: true,
                });
            }
        }

        // The next vectors cannot be normalized: stop with the current iterate.
        if rho_new < BREAKDOWN_TOL || xi_new < BREAKDOWN_TOL {
            return Ok(QmrResult {
                solution: x,
                iterations: iter + 1,
                residual_norm: res_norm,
                // Every tolerance check above has already failed for `res_norm`.
                converged: false,
            });
        }

        v = client.mul_scalar(&v_tilde, 1.0 / rho_new)?;
        w = client.mul_scalar(&w_tilde, 1.0 / xi_new)?;
        p_prev = p;
        q_prev = q;
        rho = rho_new;
        xi = xi_new;
        gamma_prev = gamma;
        theta_prev = theta;
        eta = eta_new;
        epsilon_prev = epsilon;
    }

    // Iteration budget exhausted: report the true residual of the final iterate.
    let ax = a.spmv(&x)?;
    let r_final = client.sub(b, &ax)?;
    let final_residual = vector_norm(client, &r_final)?;
    Ok(QmrResult {
        solution: x,
        iterations: options.max_iter,
        residual_norm: final_residual,
        converged: false,
    })
}