1#![allow(unused_variables)]
13#![allow(unused_assignments)]
14#![allow(unused_mut)]
15
16use crate::error::{SparseError, SparseResult};
17use crate::sparray::SparseArray;
18use scirs2_core::ndarray::{Array1, ArrayView1};
19use scirs2_core::numeric::{Float, SparseElement};
20use std::fmt::Debug;
21
/// Configuration for the `tfqmr` solver.
#[derive(Clone, Debug)]
pub struct TFQMROptions {
    /// Upper bound on the number of TFQMR iterations.
    pub max_iter: usize,
    /// Relative tolerance; `tfqmr` stops once `tau * sqrt(k)` drops below
    /// `tol * max(||b||, 1e-10)`.
    pub tol: f64,
    /// Left-preconditioning request.
    /// NOTE(review): not read by `tfqmr` in this file — currently has no effect.
    pub use_left_preconditioner: bool,
    /// Right-preconditioning request.
    /// NOTE(review): not read by `tfqmr` in this file — currently has no effect.
    pub use_right_preconditioner: bool,
}
34
35impl Default for TFQMROptions {
36 fn default() -> Self {
37 Self {
38 max_iter: 1000,
39 tol: 1e-6,
40 use_left_preconditioner: false,
41 use_right_preconditioner: false,
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
48pub struct TFQMRResult<T> {
49 pub x: Array1<T>,
51 pub iterations: usize,
53 pub residual_norm: T,
55 pub converged: bool,
57 pub residual_history: Option<Vec<T>>,
59}
60
61#[allow(dead_code)]
98pub fn tfqmr<T, S>(
99 matrix: &S,
100 b: &ArrayView1<T>,
101 x0: Option<&ArrayView1<T>>,
102 options: TFQMROptions,
103) -> SparseResult<TFQMRResult<T>>
104where
105 T: Float + SparseElement + Debug + Copy + 'static,
106 S: SparseArray<T>,
107{
108 let n = b.len();
109 let (rows, cols) = matrix.shape();
110
111 if rows != cols || rows != n {
112 return Err(SparseError::DimensionMismatch {
113 expected: n,
114 found: rows,
115 });
116 }
117
118 let one = T::sparse_one();
119 let zero = T::sparse_zero();
120
121 let mut x = match x0 {
123 Some(x0_val) => x0_val.to_owned(),
124 None => Array1::zeros(n),
125 };
126
127 let ax = matrix_vector_multiply(matrix, &x.view())?;
129 let r = b - &ax;
130
131 let r0norm = l2_norm(&r.view());
132 let b_norm = l2_norm(b);
133 let tolerance = T::from(options.tol).expect("Operation failed")
134 * b_norm.max(T::from(1e-10).expect("Operation failed"));
135
136 if r0norm <= tolerance || r0norm == zero {
137 return Ok(TFQMRResult {
138 x,
139 iterations: 0,
140 residual_norm: r0norm,
141 converged: true,
142 residual_history: Some(vec![r0norm]),
143 });
144 }
145
146 let mut u = r.clone();
148 let mut w = r.clone();
149 let rstar = r.clone(); let ar = matrix_vector_multiply(matrix, &r.view())?;
153 let mut v = ar;
154 let mut uhat = v.clone();
155
156 let mut d: Array1<T> = Array1::zeros(n);
157 let mut theta = zero;
158 let mut eta = zero;
159
160 let mut rho = dot_product(&rstar.view(), &r.view());
162 let mut rho_last = rho;
163 let mut tau = r0norm;
164
165 let mut residual_history = Vec::new();
166 residual_history.push(r0norm);
167
168 let mut converged = false;
169 let mut iter = 0;
170
171 for it in 0..options.max_iter {
172 iter = it + 1;
173 let even = it % 2 == 0;
174
175 let mut alpha = zero;
177 let mut u_next: Array1<T> = Array1::zeros(n);
178
179 if even {
180 let vtrstar = dot_product(&rstar.view(), &v.view());
181 if vtrstar.abs() < T::from(1e-300).expect("Operation failed") {
182 return Err(SparseError::ConvergenceError(
183 "TFQMR breakdown: v'*rstar = 0".to_string(),
184 ));
185 }
186 alpha = rho / vtrstar;
187
188 for i in 0..n {
190 u_next[i] = u[i] - alpha * v[i];
191 }
192 }
193
194 let alpha_used = if even {
196 alpha
197 } else {
198 rho / dot_product(&rstar.view(), &v.view())
199 };
200
201 for i in 0..n {
202 w[i] = w[i] - alpha_used * uhat[i];
203 }
204
205 let theta_sq_over_alpha = if alpha_used.abs() > T::from(1e-300).expect("Operation failed") {
207 theta * theta / alpha_used
208 } else {
209 zero
210 };
211 for i in 0..n {
212 d[i] = u[i] + theta_sq_over_alpha * eta * d[i];
213 }
214
215 theta = l2_norm(&w.view()) / tau;
217
218 let c = one / (one + theta * theta).sqrt();
220
221 tau = tau * theta * c;
223
224 eta = c * c * alpha_used;
226
227 for i in 0..n {
229 x[i] = x[i] + eta * d[i];
230 }
231
232 residual_history.push(tau);
233
234 let iter_f = T::from(iter).expect("Operation failed");
236 if tau * iter_f.sqrt() < tolerance {
237 converged = true;
238 break;
239 }
240
241 if !even {
242 rho = dot_product(&rstar.view(), &w.view());
244
245 if rho.abs() < T::from(1e-300).expect("Operation failed") {
246 return Err(SparseError::ConvergenceError(
247 "TFQMR breakdown: rho = 0".to_string(),
248 ));
249 }
250
251 let beta = rho / rho_last;
252
253 for i in 0..n {
255 u[i] = w[i] + beta * u[i];
256 }
257
258 for i in 0..n {
260 v[i] = beta * uhat[i] + beta * beta * v[i];
261 }
262
263 let au = matrix_vector_multiply(matrix, &u.view())?;
265 uhat = au;
266
267 for i in 0..n {
269 v[i] = v[i] + uhat[i];
270 }
271 } else {
272 let au_next = matrix_vector_multiply(matrix, &u_next.view())?;
275 uhat = au_next;
276
277 u = u_next;
279
280 rho_last = rho;
282 }
283 }
284
285 let ax_final = matrix_vector_multiply(matrix, &x.view())?;
287 let final_residual = b - &ax_final;
288 let final_residual_norm = l2_norm(&final_residual.view());
289
290 Ok(TFQMRResult {
291 x,
292 iterations: iter,
293 residual_norm: final_residual_norm,
294 converged,
295 residual_history: Some(residual_history),
296 })
297}
298
299#[allow(dead_code)]
301fn matrix_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
302where
303 T: Float + SparseElement + Debug + Copy + 'static,
304 S: SparseArray<T>,
305{
306 let (rows, cols) = matrix.shape();
307 if x.len() != cols {
308 return Err(SparseError::DimensionMismatch {
309 expected: cols,
310 found: x.len(),
311 });
312 }
313
314 let mut result = Array1::zeros(rows);
315 let (row_indices, col_indices, values) = matrix.find();
316
317 for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
318 result[i] = result[i] + values[k] * x[j];
319 }
320
321 Ok(result)
322}
323
324#[allow(dead_code)]
326fn l2_norm<T>(x: &ArrayView1<T>) -> T
327where
328 T: Float + SparseElement + Debug + Copy,
329{
330 (x.iter()
331 .map(|&val| val * val)
332 .fold(T::sparse_zero(), |a, b| a + b))
333 .sqrt()
334}
335
336#[allow(dead_code)]
338fn dot_product<T>(x: &ArrayView1<T>, y: &ArrayView1<T>) -> T
339where
340 T: Float + SparseElement + Debug + Copy,
341{
342 x.iter()
343 .zip(y.iter())
344 .map(|(&xi, &yi)| xi * yi)
345 .fold(T::sparse_zero(), |a, b| a + b)
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    #[test]
    fn test_tfqmr_simple_system() {
        // Non-symmetric 3x3: [[2, -1, 0], [-1, 2, 0], [0, -1, 2]]
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");

        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);
        let result =
            tfqmr(&matrix, &b.view(), None, TFQMROptions::default()).expect("Operation failed");

        assert!(result.converged);

        let ax = matrix_vector_multiply(&matrix, &result.x.view()).expect("Operation failed");
        let residual = &b - &ax;
        let residual_norm = l2_norm(&residual.view());

        assert!(residual_norm < 1e-6);
    }

    #[test]
    fn test_tfqmr_diagonal_system() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![2.0, 3.0, 4.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");

        let b = Array1::from_vec(vec![4.0, 9.0, 16.0]);
        let result =
            tfqmr(&matrix, &b.view(), None, TFQMROptions::default()).expect("Operation failed");

        assert!(result.converged);

        assert_relative_eq!(result.x[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(result.x[1], 3.0, epsilon = 1e-6);
        assert_relative_eq!(result.x[2], 4.0, epsilon = 1e-6);
    }

    #[test]
    fn test_tfqmr_with_initial_guess() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![1.0, 1.0, 1.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");

        let b = Array1::from_vec(vec![5.0, 6.0, 7.0]);
        let x0 = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        let result = tfqmr(
            &matrix,
            &b.view(),
            Some(&x0.view()),
            TFQMROptions::default(),
        )
        .expect("Operation failed");

        assert!(result.converged);
        assert!(result.iterations <= 5);
    }

    #[test]
    fn test_tfqmr_rhs_length_mismatch() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![1.0, 1.0, 1.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");

        // b has 2 entries but the matrix has 3 rows.
        let b = Array1::from_vec(vec![1.0, 2.0]);
        let result = tfqmr(&matrix, &b.view(), None, TFQMROptions::default());

        assert!(matches!(
            result,
            Err(SparseError::DimensionMismatch {
                expected: 2,
                found: 3
            })
        ));
    }

    #[test]
    fn test_tfqmr_zero_rhs_returns_immediately() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![2.0, 3.0, 4.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");

        let b = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let result =
            tfqmr(&matrix, &b.view(), None, TFQMROptions::default()).expect("Operation failed");

        assert!(result.converged);
        assert_eq!(result.iterations, 0);
        assert_eq!(result.residual_norm, 0.0);
        assert!(result.x.iter().all(|&xi| xi == 0.0));
    }

    #[test]
    fn test_tfqmr_zero_max_iter_reports_true_residual() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![2.0, 3.0, 4.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");

        let b = Array1::from_vec(vec![4.0, 9.0, 16.0]);
        let options = TFQMROptions {
            max_iter: 0,
            ..TFQMROptions::default()
        };
        let result = tfqmr(&matrix, &b.view(), None, options).expect("Operation failed");

        // No iterations run: x stays zero and the residual is ||b||.
        assert!(!result.converged);
        assert_eq!(result.iterations, 0);
        assert!(result.x.iter().all(|&xi| xi == 0.0));
        assert_relative_eq!(result.residual_norm, 353.0_f64.sqrt(), epsilon = 1e-12);
        assert_eq!(result.residual_history.map(|h| h.len()), Some(1));
    }
}