use crate::error::{SparseError, SparseResult};
use crate::linalg::interface::LinearOperator;
use scirs2_core::numeric::{Float, NumAssign, SparseElement};
use std::fmt::Display;
use std::iter::Sum;
/// Outcome of a [`qmr`] solve.
///
/// Breakdown and iteration-limit exhaustion are reported through this value
/// (with `converged == false` and an explanatory `message`), not as an `Err`.
#[derive(Clone, Debug)]
pub struct QMRResult<F> {
    /// Final iterate returned by the solver (the approximate solution).
    pub x: Vec<F>,
    /// Number of completed iterations when the solver stopped.
    pub iterations: usize,
    /// Norm of the solver's residual vector at exit (the left-preconditioned
    /// residual when a left preconditioner was supplied).
    pub residual_norm: F,
    /// `true` when the residual tolerance was met.
    pub converged: bool,
    /// Human-readable status: convergence, breakdown kind, or iteration limit.
    pub message: String,
}
/// Configuration for [`qmr`].
///
/// Use `..Default::default()` to override only selected fields.
pub struct QMROptions<F> {
    /// Upper bound on the number of iterations.
    pub max_iter: usize,
    /// Relative tolerance; the stopping threshold is `atol + rtol * ‖b‖`.
    pub rtol: F,
    /// Absolute tolerance added to the scaled relative tolerance.
    pub atol: F,
    /// Optional initial guess; `None` starts from the zero vector.
    pub x0: Option<Vec<F>>,
    /// Optional left preconditioner, applied to the initial residual and to
    /// every product with the system operator.
    pub left_preconditioner: Option<Box<dyn LinearOperator<F>>>,
    /// Optional right preconditioner, applied to each search direction before
    /// it is multiplied by the system operator and added to the iterate.
    pub right_preconditioner: Option<Box<dyn LinearOperator<F>>>,
}
impl<F> Default for QMROptions<F>
where
    F: Float,
{
    /// 1000 iterations, `rtol = 1e-8`, `atol = 1e-12`, no initial guess and
    /// no preconditioners.
    fn default() -> Self {
        // Every literal constant goes through the same checked conversion.
        let to_float = |value: f64| F::from(value).expect("Failed to convert constant to float");
        Self {
            max_iter: 1000,
            rtol: to_float(1e-8),
            atol: to_float(1e-12),
            x0: None,
            left_preconditioner: None,
            right_preconditioner: None,
        }
    }
}
/// Iteratively solves the square linear system `A x = b`.
///
/// Each iteration takes a step of length `alpha` along the search direction
/// `p`, then a stabilising step of length `omega = <t, s> / <t, t>` along the
/// intermediate residual `s` (a BiCGSTAB-style update). Shadow directions
/// (`p_tilde`, `q_tilde`) are maintained with the transposed operators
/// (`rmatvec`). NOTE(review): the recurrences do not match a textbook QMR or
/// BiCGSTAB listing exactly; confirm against the intended reference algorithm.
///
/// # Arguments
/// * `a` - square operator of size `n x n`, where `n = b.len()`.
/// * `b` - right-hand side of length `n`.
/// * `options` - tolerances, iteration limit, optional initial guess and
///   optional left/right preconditioners.
///
/// The left preconditioner `M_L` is applied to the initial residual and to every
/// `A·v` product, so the tracked residual is `M_L (b - A x)`. The right
/// preconditioner `M_R` is applied to each direction before it is multiplied by
/// `A` and before it is added to `x`.
///
/// Convergence is declared when `‖r‖ <= atol + rtol·‖b‖`, where `r` is the
/// tracked (possibly left-preconditioned) residual and `‖b‖` is the norm of the
/// unpreconditioned right-hand side.
///
/// # Returns
/// A [`QMRResult`]. Breakdown and exhausting `max_iter` are reported with
/// `converged == false` and an explanatory `message`, not as errors.
///
/// # Errors
/// * [`SparseError::DimensionMismatch`] if `a` is not `n x n` or if the initial
///   guess does not have length `n`.
/// * Any error returned by the operator or preconditioner `matvec`/`rmatvec`.
#[allow(dead_code)]
pub fn qmr<F>(
    a: &dyn LinearOperator<F>,
    b: &[F],
    options: QMROptions<F>,
) -> SparseResult<QMRResult<F>>
where
    F: Float + SparseElement + NumAssign + Sum + Display + 'static,
{
    let n = b.len();
    let shape = a.shape();
    if shape.0 != n || shape.1 != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            // Report the dimension that actually disagrees with `b`; for an
            // `n x m` operator the row count alone would look like a match.
            found: if shape.0 != n { shape.0 } else { shape.1 },
        });
    }
    let mut x = options.x0.unwrap_or_else(|| vec![F::sparse_zero(); n]);
    if x.len() != n {
        // A wrong-length guess would otherwise be silently truncated by the
        // zip-based vector helpers.
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: x.len(),
        });
    }
    // Starting from the zero vector, r0 = b - A*0 = b, so skip the matvec.
    let mut r = if x.iter().all(|&xi| xi == F::sparse_zero()) {
        b.to_vec()
    } else {
        let ax = a.matvec(&x)?;
        vec_sub(b, &ax)
    };
    if let Some(ref ml) = options.left_preconditioner {
        r = ml.matvec(&r)?;
    }
    // Fixed shadow residual used in every `rho` inner product.
    let r_tilde = r.clone();
    let r_tilde_norm = norm2(&r_tilde);
    let mut p = vec![F::sparse_zero(); n];
    let mut p_tilde = vec![F::sparse_zero(); n];
    let mut q = vec![F::sparse_zero(); n];
    let mut q_tilde = vec![F::sparse_zero(); n];
    let mut rho = F::sparse_one();
    let mut alpha = F::sparse_zero();
    let mut omega = F::sparse_one();
    let bnorm = norm2(b);
    let mut rnorm = norm2(&r);
    let tol = options.atol + options.rtol * bnorm;
    // `<=` so that a zero tolerance still accepts an exact solution.
    if rnorm <= tol {
        return Ok(QMRResult {
            x,
            iterations: 0,
            residual_norm: rnorm,
            converged: true,
            message: "Converged at initial guess".to_string(),
        });
    }
    // Breakdown tests compare |<u, v>| against `breakdown_tol * ‖u‖ * ‖v‖`, i.e.
    // they detect numerical orthogonality independently of the problem's
    // scale. A fixed absolute cutoff falsely reports breakdown for tiny
    // right-hand sides and misses it for huge ones.
    let breakdown_tol =
        F::epsilon() * F::from(10).expect("Failed to convert constant to float");
    for iter in 0..options.max_iter {
        let rho_old = rho;
        rho = dot(&r_tilde, &r);
        if rho.abs() <= breakdown_tol * r_tilde_norm * rnorm {
            return Ok(QMRResult {
                x,
                iterations: iter,
                residual_norm: rnorm,
                converged: false,
                message: "Breakdown: rho = 0".to_string(),
            });
        }
        let beta = if iter == 0 {
            F::sparse_zero()
        } else {
            (rho / rho_old) * (alpha / omega)
        };
        p = if iter == 0 {
            r.clone()
        } else {
            vec_add(&r, &vec_scaled(&vec_sub(&p, &vec_scaled(&q, omega)), beta))
        };
        let p_prec = if let Some(ref mr) = options.right_preconditioner {
            mr.matvec(&p)?
        } else {
            p.clone()
        };
        // q = M_L * A * M_R * p
        q = a.matvec(&p_prec)?;
        if let Some(ref ml) = options.left_preconditioner {
            q = ml.matvec(&q)?;
        }
        p_tilde = if iter == 0 {
            r_tilde.clone()
        } else {
            let diff = vec_sub(&p_tilde, &vec_scaled(&q_tilde, omega));
            vec_add(&r_tilde, &vec_scaled(&diff, beta))
        };
        // q_tilde = M_R^T * A^T * M_L^T * p_tilde (transposed operator chain).
        let p_tilde_prec = if let Some(ref ml) = options.left_preconditioner {
            ml.rmatvec(&p_tilde)?
        } else {
            p_tilde.clone()
        };
        q_tilde = a.rmatvec(&p_tilde_prec)?;
        if let Some(ref mr) = options.right_preconditioner {
            q_tilde = mr.rmatvec(&q_tilde)?;
        }
        let dot_pq = dot(&p_tilde, &q);
        if dot_pq.abs() <= breakdown_tol * norm2(&p_tilde) * norm2(&q) {
            return Ok(QMRResult {
                x,
                iterations: iter,
                residual_norm: rnorm,
                converged: false,
                message: "Breakdown: <p_tilde, q> = 0".to_string(),
            });
        }
        alpha = rho / dot_pq;
        // Intermediate residual after the `alpha` step.
        let s = vec_sub(&r, &vec_scaled(&q, alpha));
        let s_prec = if let Some(ref mr) = options.right_preconditioner {
            mr.matvec(&s)?
        } else {
            s.clone()
        };
        let t = a.matvec(&s_prec)?;
        let t = if let Some(ref ml) = options.left_preconditioner {
            ml.matvec(&t)?
        } else {
            t
        };
        // omega minimises ‖s - omega * t‖; t = 0 leaves no stabilising step.
        let dot_tt = dot(&t, &t);
        omega = if dot_tt == F::sparse_zero() {
            F::sparse_zero()
        } else {
            dot(&t, &s) / dot_tt
        };
        x = vec_add(&x, &vec_scaled(&p_prec, alpha));
        x = vec_add(&x, &vec_scaled(&s_prec, omega));
        r = vec_sub(&s, &vec_scaled(&t, omega));
        rnorm = norm2(&r);
        if rnorm <= tol {
            return Ok(QMRResult {
                x,
                iterations: iter + 1,
                residual_norm: rnorm,
                converged: true,
                message: format!("Converged in {} iterations", iter + 1),
            });
        }
        // |omega| * ‖t‖ = |<t, s>| / ‖t‖, so this is a scale-free test that t
        // and s are numerically orthogonal (the stabilising step made no
        // progress). The next `beta` would divide by omega.
        if omega.abs() * dot_tt.sqrt() <= F::epsilon() * norm2(&s) {
            return Ok(QMRResult {
                x,
                iterations: iter + 1,
                residual_norm: rnorm,
                converged: false,
                message: "Breakdown: omega = 0".to_string(),
            });
        }
    }
    Ok(QMRResult {
        x,
        iterations: options.max_iter,
        residual_norm: rnorm,
        converged: false,
        message: format!(
            "Did not converge in {} iterations. Residual: {}",
            options.max_iter, rnorm
        ),
    })
}
/// Inner product `Σ a[i] * b[i]`.
///
/// Elements are paired with `zip`, so only the common prefix of the two
/// slices contributes; callers pass equal-length vectors.
#[allow(dead_code)]
fn dot<F: Float + Sum>(a: &[F], b: &[F]) -> F {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Euclidean (L2) norm: square root of the sum of squared entries.
#[allow(dead_code)]
fn norm2<F: Float + Sum>(v: &[F]) -> F {
    let sum_of_squares: F = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
/// Element-wise sum `a + b`; the result has the length of the shorter input.
#[allow(dead_code)]
fn vec_add<F: Float>(a: &[F], b: &[F]) -> Vec<F> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    out.extend(a.iter().zip(b).map(|(&lhs, &rhs)| lhs + rhs));
    out
}
/// Element-wise difference `a - b`; the result has the length of the shorter input.
#[allow(dead_code)]
fn vec_sub<F: Float>(a: &[F], b: &[F]) -> Vec<F> {
    a.iter()
        .copied()
        .zip(b.iter().copied())
        .map(|(minuend, subtrahend)| minuend - subtrahend)
        .collect()
}
/// Returns a new vector with every entry of `v` multiplied by the scalar `s`.
#[allow(dead_code)]
fn vec_scaled<F: Float>(v: &[F], s: F) -> Vec<F> {
    v.iter().copied().map(|elem| elem * s).collect::<Vec<F>>()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::interface::{DiagonalOperator, IdentityOperator};

    /// With `A = I`, a single iteration must recover `x = b`.
    #[test]
    fn test_qmr_identity() {
        let identity = IdentityOperator::<f64>::new(3);
        let b = vec![1.0, 2.0, 3.0];
        let options = QMROptions::default();
        let result = qmr(&identity, &b, options).expect("Operation failed");
        assert!(result.converged);
        assert_eq!(result.iterations, 1);
        for (i, &b_val) in b.iter().enumerate() {
            assert!((result.x[i] - b_val).abs() < 1e-10);
        }
    }

    /// A small diagonal system with a known solution converges quickly.
    #[test]
    fn test_qmr_diagonal() {
        let diagonal = DiagonalOperator::new(vec![2.0, 3.0, 4.0]);
        let b = vec![2.0, 6.0, 8.0];
        let expected = [1.0, 2.0, 2.0];
        let options = QMROptions {
            rtol: 1e-10,
            atol: 1e-12,
            ..Default::default()
        };
        let result = qmr(&diagonal, &b, options).expect("Operation failed");
        assert!(result.converged);
        assert!(result.iterations <= 10);
        for (i, &exp_val) in expected.iter().enumerate() {
            assert!(
                (result.x[i] - exp_val).abs() < 1e-9,
                "x[{}] = {} != {}",
                i,
                result.x[i],
                exp_val
            );
        }
    }

    /// A non-zero initial guess is used as the starting iterate.
    #[test]
    fn test_qmr_with_initial_guess() {
        let identity = IdentityOperator::<f64>::new(3);
        let b = vec![1.0, 2.0, 3.0];
        let x0 = vec![0.9, 1.9, 2.9];
        let options = QMROptions {
            x0: Some(x0),
            rtol: 1e-10,
            atol: 1e-12,
            ..Default::default()
        };
        let result = qmr(&identity, &b, options).expect("Operation failed");
        assert!(result.converged);
        assert!(result.iterations <= 1);
        for (i, &b_val) in b.iter().enumerate() {
            assert!((result.x[i] - b_val).abs() < 1e-10);
        }
    }

    /// Exhausting the iteration budget must be reported as non-convergence.
    #[test]
    fn test_qmr_max_iterations() {
        // Three distinct eigenvalues: one iteration (a degree-2 residual
        // polynomial) cannot annihilate all of them, so the solve cannot
        // converge within `max_iter = 1`. Hand-computed first iteration:
        // alpha = 0.5, omega = 0.4, r = [0.3, 0.0, 0.1].
        let diagonal = DiagonalOperator::new(vec![1.0, 2.0, 3.0]);
        let b = vec![1.0, 1.0, 1.0];
        let options = QMROptions {
            max_iter: 1,
            rtol: 1e-10,
            atol: 1e-12,
            ..Default::default()
        };
        let result = qmr(&diagonal, &b, options).expect("Operation failed");
        assert!(!result.converged);
        assert_eq!(result.iterations, 1);
        assert!(
            result.message.contains("Did not converge"),
            "unexpected message: {}",
            result.message
        );
        assert!((result.residual_norm - 0.1_f64.sqrt()).abs() < 1e-12);
    }

    /// An operator whose size does not match the right-hand side is rejected.
    #[test]
    fn test_qmr_dimension_mismatch() {
        let identity = IdentityOperator::<f64>::new(3);
        let b = vec![1.0, 2.0];
        assert!(qmr(&identity, &b, QMROptions::default()).is_err());
    }
}