Skip to main content

numr/algorithm/iterative/impl_generic/
qmr.rs

//! Generic QMR (Quasi-Minimal Residual) implementation
//!
//! Based on the QMR algorithm from Barrett et al., "Templates for the
//! Solution of Linear Systems: Building Blocks for Iterative Methods".
//!
//! QMR quasi-minimizes the residual norm over the BiCG Krylov subspace, which
//! smooths the erratic convergence of BiCG and is often smoother than
//! BiCGSTAB on non-symmetric systems.
8
9use crate::algorithm::sparse_linalg::{IluOptions, SparseLinAlgAlgorithms};
10use crate::dtype::DType;
11use crate::error::{Error, Result};
12use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
13use crate::runtime::Runtime;
14use crate::sparse::{CsrData, SparseOps};
15use crate::tensor::Tensor;
16
17use super::super::helpers::{BREAKDOWN_TOL, apply_ilu0_preconditioner, vector_dot, vector_norm};
18use super::super::types::{PreconditionerType, QmrOptions, QmrResult};
19
20/// Generic QMR implementation
21///
22/// Uses coupled two-term Lanczos biorthogonalization with quasi-minimal
23/// residual smoothing. Follows the Templates book algorithm.
24pub fn qmr_impl<R, C>(
25    client: &C,
26    a: &CsrData<R>,
27    b: &Tensor<R>,
28    x0: Option<&Tensor<R>>,
29    options: QmrOptions,
30) -> Result<QmrResult<R>>
31where
32    R: Runtime<DType = DType>,
33    R::Client: SparseOps<R>,
34    C: SparseLinAlgAlgorithms<R>
35        + SparseOps<R>
36        + BinaryOps<R>
37        + UnaryOps<R>
38        + ReduceOps<R>
39        + ScalarOps<R>,
40{
41    let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
42    let device = b.device();
43    let dtype = b.dtype();
44
45    if !matches!(dtype, DType::F32 | DType::F64) {
46        return Err(Error::UnsupportedDType { dtype, op: "qmr" });
47    }
48
49    let mut x = match x0 {
50        Some(x0) => x0.clone(),
51        None => Tensor::<R>::zeros(&[n], dtype, device),
52    };
53
54    let precond = match options.preconditioner {
55        PreconditionerType::None => None,
56        PreconditionerType::Ilu0 => Some(client.ilu0(a, IluOptions::default())?),
57        PreconditionerType::Ic0 | PreconditionerType::Amg => {
58            return Err(Error::Internal(
59                "Only None and Ilu0 preconditioners supported for QMR".to_string(),
60            ));
61        }
62    };
63
64    let b_norm = vector_norm(client, b)?;
65    if b_norm < options.atol {
66        return Ok(QmrResult {
67            solution: x,
68            iterations: 0,
69            residual_norm: b_norm,
70            converged: true,
71        });
72    }
73
74    // Build A^T for transpose SpMV
75    let at = a.transpose().to_csr()?;
76
77    // r = b - A*x
78    let ax = a.spmv(&x)?;
79    let r = client.sub(b, &ax)?;
80
81    let r_norm = vector_norm(client, &r)?;
82    if r_norm < options.atol || r_norm / b_norm < options.rtol {
83        return Ok(QmrResult {
84            solution: x,
85            iterations: 0,
86            residual_norm: r_norm,
87            converged: true,
88        });
89    }
90
91    // v_tilde = r, w_tilde = r
92    let mut v_tilde = r.clone();
93    let mut w_tilde = r.clone();
94
95    let mut rho = vector_norm(client, &v_tilde)?;
96    let mut xi = vector_norm(client, &w_tilde)?;
97
98    let mut gamma_prev = 1.0_f64;
99    let mut eta = -1.0_f64;
100    let mut theta_prev = 0.0_f64;
101
102    let mut v = client.mul_scalar(&v_tilde, 1.0 / rho)?;
103    let mut w = client.mul_scalar(&w_tilde, 1.0 / xi)?;
104
105    let mut d = Tensor::<R>::zeros(&[n], dtype, device);
106    let mut s = Tensor::<R>::zeros(&[n], dtype, device);
107
108    let mut p;
109    let mut q;
110    let mut epsilon_prev = 0.0_f64;
111
112    // Track p_prev and q_prev for the recurrence
113    let mut p_prev = Tensor::<R>::zeros(&[n], dtype, device);
114    let mut q_prev = Tensor::<R>::zeros(&[n], dtype, device);
115
116    let mut residual = r;
117
118    for iter in 0..options.max_iter {
119        let delta = vector_dot(client, &w, &v)?;
120        if delta.abs() < BREAKDOWN_TOL {
121            let res_norm = vector_norm(client, &residual)?;
122            return Ok(QmrResult {
123                solution: x,
124                iterations: iter + 1,
125                residual_norm: res_norm,
126                converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
127            });
128        }
129
130        // Apply preconditioner
131        let y = apply_ilu0_preconditioner(client, &precond, &v)?;
132        let z = apply_ilu0_preconditioner(client, &precond, &w)?;
133
134        // p, q direction vectors
135        if iter == 0 {
136            p = y.clone();
137            q = z.clone();
138        } else {
139            let coeff_p = (xi * delta) / epsilon_prev;
140            let coeff_q = (rho * delta) / epsilon_prev;
141            let pp = client.mul_scalar(&p_prev, coeff_p)?;
142            let qq = client.mul_scalar(&q_prev, coeff_q)?;
143            p = client.sub(&y, &pp)?;
144            q = client.sub(&z, &qq)?;
145        }
146
147        // epsilon = <q, A*p>
148        let ap = a.spmv(&p)?;
149        let epsilon = vector_dot(client, &q, &ap)?;
150        if epsilon.abs() < BREAKDOWN_TOL {
151            let res_norm = vector_norm(client, &residual)?;
152            return Ok(QmrResult {
153                solution: x,
154                iterations: iter + 1,
155                residual_norm: res_norm,
156                converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
157            });
158        }
159
160        let beta = epsilon / delta;
161
162        // v_tilde = A*p - beta*v
163        let bv = client.mul_scalar(&v, beta)?;
164        v_tilde = client.sub(&ap, &bv)?;
165
166        // w_tilde = A^T*q - conj(beta)*w
167        let atq = at.spmv(&q)?;
168        let bw = client.mul_scalar(&w, beta)?;
169        w_tilde = client.sub(&atq, &bw)?;
170
171        let rho_new = vector_norm(client, &v_tilde)?;
172        let xi_new = vector_norm(client, &w_tilde)?;
173
174        // QMR quasi-minimization
175        let theta = rho_new / (gamma_prev * beta.abs());
176        let gamma = 1.0 / (1.0 + theta * theta).sqrt();
177
178        if gamma.abs() < BREAKDOWN_TOL {
179            let res_norm = vector_norm(client, &residual)?;
180            return Ok(QmrResult {
181                solution: x,
182                iterations: iter + 1,
183                residual_norm: res_norm,
184                converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
185            });
186        }
187
188        let eta_new = -eta * rho * gamma * gamma / (beta * gamma_prev * gamma_prev);
189
190        // d = eta*p + (theta_prev*gamma)^2 * d_prev
191        let tg2 = (theta_prev * gamma) * (theta_prev * gamma);
192        let ep = client.mul_scalar(&p, eta_new)?;
193        let td = client.mul_scalar(&d, tg2)?;
194        d = client.add(&ep, &td)?;
195
196        // s = eta*A*p + (theta_prev*gamma)^2 * s_prev
197        let eap = client.mul_scalar(&ap, eta_new)?;
198        let ts = client.mul_scalar(&s, tg2)?;
199        s = client.add(&eap, &ts)?;
200
201        // x = x + d
202        x = client.add(&x, &d)?;
203
204        // r = r - s
205        residual = client.sub(&residual, &s)?;
206
207        let res_norm = vector_norm(client, &residual)?;
208        if res_norm < options.atol || res_norm / b_norm < options.rtol {
209            return Ok(QmrResult {
210                solution: x,
211                iterations: iter + 1,
212                residual_norm: res_norm,
213                converged: true,
214            });
215        }
216
217        // Check true residual periodically to guard against drift
218        if (iter + 1) % 50 == 0 {
219            let ax_check = a.spmv(&x)?;
220            residual = client.sub(b, &ax_check)?;
221            let true_norm = vector_norm(client, &residual)?;
222            if true_norm < options.atol || true_norm / b_norm < options.rtol {
223                return Ok(QmrResult {
224                    solution: x,
225                    iterations: iter + 1,
226                    residual_norm: true_norm,
227                    converged: true,
228                });
229            }
230        }
231
232        // Update for next iteration
233        if rho_new < BREAKDOWN_TOL || xi_new < BREAKDOWN_TOL {
234            return Ok(QmrResult {
235                solution: x,
236                iterations: iter + 1,
237                residual_norm: res_norm,
238                converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
239            });
240        }
241
242        v = client.mul_scalar(&v_tilde, 1.0 / rho_new)?;
243        w = client.mul_scalar(&w_tilde, 1.0 / xi_new)?;
244
245        p_prev = p;
246        q_prev = q;
247        rho = rho_new;
248        xi = xi_new;
249        gamma_prev = gamma;
250        theta_prev = theta;
251        eta = eta_new;
252        epsilon_prev = epsilon;
253    }
254
255    let ax = a.spmv(&x)?;
256    let r_final = client.sub(b, &ax)?;
257    let final_residual = vector_norm(client, &r_final)?;
258
259    Ok(QmrResult {
260        solution: x,
261        iterations: options.max_iter,
262        residual_norm: final_residual,
263        converged: false,
264    })
265}