numr/algorithm/iterative/impl_generic/
qmr.rs1use crate::algorithm::sparse_linalg::{IluOptions, SparseLinAlgAlgorithms};
10use crate::dtype::DType;
11use crate::error::{Error, Result};
12use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
13use crate::runtime::Runtime;
14use crate::sparse::{CsrData, SparseOps};
15use crate::tensor::Tensor;
16
17use super::super::helpers::{BREAKDOWN_TOL, apply_ilu0_preconditioner, vector_dot, vector_norm};
18use super::super::types::{PreconditionerType, QmrOptions, QmrResult};
19
20pub fn qmr_impl<R, C>(
25 client: &C,
26 a: &CsrData<R>,
27 b: &Tensor<R>,
28 x0: Option<&Tensor<R>>,
29 options: QmrOptions,
30) -> Result<QmrResult<R>>
31where
32 R: Runtime<DType = DType>,
33 R::Client: SparseOps<R>,
34 C: SparseLinAlgAlgorithms<R>
35 + SparseOps<R>
36 + BinaryOps<R>
37 + UnaryOps<R>
38 + ReduceOps<R>
39 + ScalarOps<R>,
40{
41 let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
42 let device = b.device();
43 let dtype = b.dtype();
44
45 if !matches!(dtype, DType::F32 | DType::F64) {
46 return Err(Error::UnsupportedDType { dtype, op: "qmr" });
47 }
48
49 let mut x = match x0 {
50 Some(x0) => x0.clone(),
51 None => Tensor::<R>::zeros(&[n], dtype, device),
52 };
53
54 let precond = match options.preconditioner {
55 PreconditionerType::None => None,
56 PreconditionerType::Ilu0 => Some(client.ilu0(a, IluOptions::default())?),
57 PreconditionerType::Ic0 | PreconditionerType::Amg => {
58 return Err(Error::Internal(
59 "Only None and Ilu0 preconditioners supported for QMR".to_string(),
60 ));
61 }
62 };
63
64 let b_norm = vector_norm(client, b)?;
65 if b_norm < options.atol {
66 return Ok(QmrResult {
67 solution: x,
68 iterations: 0,
69 residual_norm: b_norm,
70 converged: true,
71 });
72 }
73
74 let at = a.transpose().to_csr()?;
76
77 let ax = a.spmv(&x)?;
79 let r = client.sub(b, &ax)?;
80
81 let r_norm = vector_norm(client, &r)?;
82 if r_norm < options.atol || r_norm / b_norm < options.rtol {
83 return Ok(QmrResult {
84 solution: x,
85 iterations: 0,
86 residual_norm: r_norm,
87 converged: true,
88 });
89 }
90
91 let mut v_tilde = r.clone();
93 let mut w_tilde = r.clone();
94
95 let mut rho = vector_norm(client, &v_tilde)?;
96 let mut xi = vector_norm(client, &w_tilde)?;
97
98 let mut gamma_prev = 1.0_f64;
99 let mut eta = -1.0_f64;
100 let mut theta_prev = 0.0_f64;
101
102 let mut v = client.mul_scalar(&v_tilde, 1.0 / rho)?;
103 let mut w = client.mul_scalar(&w_tilde, 1.0 / xi)?;
104
105 let mut d = Tensor::<R>::zeros(&[n], dtype, device);
106 let mut s = Tensor::<R>::zeros(&[n], dtype, device);
107
108 let mut p;
109 let mut q;
110 let mut epsilon_prev = 0.0_f64;
111
112 let mut p_prev = Tensor::<R>::zeros(&[n], dtype, device);
114 let mut q_prev = Tensor::<R>::zeros(&[n], dtype, device);
115
116 let mut residual = r;
117
118 for iter in 0..options.max_iter {
119 let delta = vector_dot(client, &w, &v)?;
120 if delta.abs() < BREAKDOWN_TOL {
121 let res_norm = vector_norm(client, &residual)?;
122 return Ok(QmrResult {
123 solution: x,
124 iterations: iter + 1,
125 residual_norm: res_norm,
126 converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
127 });
128 }
129
130 let y = apply_ilu0_preconditioner(client, &precond, &v)?;
132 let z = apply_ilu0_preconditioner(client, &precond, &w)?;
133
134 if iter == 0 {
136 p = y.clone();
137 q = z.clone();
138 } else {
139 let coeff_p = (xi * delta) / epsilon_prev;
140 let coeff_q = (rho * delta) / epsilon_prev;
141 let pp = client.mul_scalar(&p_prev, coeff_p)?;
142 let qq = client.mul_scalar(&q_prev, coeff_q)?;
143 p = client.sub(&y, &pp)?;
144 q = client.sub(&z, &qq)?;
145 }
146
147 let ap = a.spmv(&p)?;
149 let epsilon = vector_dot(client, &q, &ap)?;
150 if epsilon.abs() < BREAKDOWN_TOL {
151 let res_norm = vector_norm(client, &residual)?;
152 return Ok(QmrResult {
153 solution: x,
154 iterations: iter + 1,
155 residual_norm: res_norm,
156 converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
157 });
158 }
159
160 let beta = epsilon / delta;
161
162 let bv = client.mul_scalar(&v, beta)?;
164 v_tilde = client.sub(&ap, &bv)?;
165
166 let atq = at.spmv(&q)?;
168 let bw = client.mul_scalar(&w, beta)?;
169 w_tilde = client.sub(&atq, &bw)?;
170
171 let rho_new = vector_norm(client, &v_tilde)?;
172 let xi_new = vector_norm(client, &w_tilde)?;
173
174 let theta = rho_new / (gamma_prev * beta.abs());
176 let gamma = 1.0 / (1.0 + theta * theta).sqrt();
177
178 if gamma.abs() < BREAKDOWN_TOL {
179 let res_norm = vector_norm(client, &residual)?;
180 return Ok(QmrResult {
181 solution: x,
182 iterations: iter + 1,
183 residual_norm: res_norm,
184 converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
185 });
186 }
187
188 let eta_new = -eta * rho * gamma * gamma / (beta * gamma_prev * gamma_prev);
189
190 let tg2 = (theta_prev * gamma) * (theta_prev * gamma);
192 let ep = client.mul_scalar(&p, eta_new)?;
193 let td = client.mul_scalar(&d, tg2)?;
194 d = client.add(&ep, &td)?;
195
196 let eap = client.mul_scalar(&ap, eta_new)?;
198 let ts = client.mul_scalar(&s, tg2)?;
199 s = client.add(&eap, &ts)?;
200
201 x = client.add(&x, &d)?;
203
204 residual = client.sub(&residual, &s)?;
206
207 let res_norm = vector_norm(client, &residual)?;
208 if res_norm < options.atol || res_norm / b_norm < options.rtol {
209 return Ok(QmrResult {
210 solution: x,
211 iterations: iter + 1,
212 residual_norm: res_norm,
213 converged: true,
214 });
215 }
216
217 if (iter + 1) % 50 == 0 {
219 let ax_check = a.spmv(&x)?;
220 residual = client.sub(b, &ax_check)?;
221 let true_norm = vector_norm(client, &residual)?;
222 if true_norm < options.atol || true_norm / b_norm < options.rtol {
223 return Ok(QmrResult {
224 solution: x,
225 iterations: iter + 1,
226 residual_norm: true_norm,
227 converged: true,
228 });
229 }
230 }
231
232 if rho_new < BREAKDOWN_TOL || xi_new < BREAKDOWN_TOL {
234 return Ok(QmrResult {
235 solution: x,
236 iterations: iter + 1,
237 residual_norm: res_norm,
238 converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
239 });
240 }
241
242 v = client.mul_scalar(&v_tilde, 1.0 / rho_new)?;
243 w = client.mul_scalar(&w_tilde, 1.0 / xi_new)?;
244
245 p_prev = p;
246 q_prev = q;
247 rho = rho_new;
248 xi = xi_new;
249 gamma_prev = gamma;
250 theta_prev = theta;
251 eta = eta_new;
252 epsilon_prev = epsilon;
253 }
254
255 let ax = a.spmv(&x)?;
256 let r_final = client.sub(b, &ax)?;
257 let final_residual = vector_norm(client, &r_final)?;
258
259 Ok(QmrResult {
260 solution: x,
261 iterations: options.max_iter,
262 residual_norm: final_residual,
263 converged: false,
264 })
265}