1use crate::blas_helpers::{axpy, inner_product, vector_norm};
10use crate::traits::{ComplexField, LinearOperator, Preconditioner};
11use ndarray::{Array1, Array2};
12use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
13
/// Settings shared by the restarted GMRES solvers in this module.
///
/// `R` is the real scalar type of the tolerance; the solvers take a
/// `GmresConfig<T::Real>` for a system with scalar type `T`.
#[derive(Clone, Debug)]
pub struct GmresConfig<R> {
    /// Upper bound on the number of restart cycles (outer iterations); each cycle
    /// performs at most `restart` Arnoldi steps.
    pub max_iterations: usize,
    /// Maximum Krylov subspace dimension built before the method restarts.
    pub restart: usize,
    /// Threshold on the relative residual below which the solve counts as converged.
    pub tolerance: R,
    /// When non-zero, progress is logged every `print_interval` iterations; `0`
    /// disables logging (honoured by [`gmres`] / [`gmres_with_guess`]).
    pub print_interval: usize,
}
26
27impl Default for GmresConfig<f64> {
28 fn default() -> Self {
29 Self {
30 max_iterations: 100,
31 restart: 30,
32 tolerance: 1e-6,
33 print_interval: 0,
34 }
35 }
36}
37
38impl Default for GmresConfig<f32> {
39 fn default() -> Self {
40 Self {
41 max_iterations: 100,
42 restart: 30,
43 tolerance: 1e-5,
44 print_interval: 0,
45 }
46 }
47}
48
49impl<R: Float + FromPrimitive> GmresConfig<R> {
50 pub fn for_small_problems() -> Self {
52 Self {
53 max_iterations: 50,
54 restart: 50,
55 tolerance: R::from_f64(1e-8).unwrap(),
56 print_interval: 0,
57 }
58 }
59
60 pub fn with_restart(restart: usize) -> Self
62 where
63 Self: Default,
64 {
65 Self {
66 restart,
67 ..Default::default()
68 }
69 }
70}
71
72#[derive(Debug)]
74pub struct GmresSolution<T: ComplexField> {
75 pub x: Array1<T>,
77 pub iterations: usize,
79 pub restarts: usize,
81 pub residual: T::Real,
83 pub converged: bool,
85}
86
87pub fn gmres<T, A>(operator: &A, b: &Array1<T>, config: &GmresConfig<T::Real>) -> GmresSolution<T>
97where
98 T: ComplexField,
99 A: LinearOperator<T>,
100{
101 gmres_with_guess(operator, b, None, config)
102}
103
104pub fn gmres_with_guess<T, A>(
106 operator: &A,
107 b: &Array1<T>,
108 x0: Option<&Array1<T>>,
109 config: &GmresConfig<T::Real>,
110) -> GmresSolution<T>
111where
112 T: ComplexField,
113 A: LinearOperator<T>,
114{
115 let n = b.len();
116 let m = config.restart;
117
118 let mut x = match x0 {
120 Some(x0) => x0.clone(),
121 None => Array1::from_elem(n, T::zero()),
122 };
123
124 let b_norm = vector_norm(b);
126 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
127 if b_norm < tol_threshold {
128 return GmresSolution {
129 x,
130 iterations: 0,
131 restarts: 0,
132 residual: T::Real::zero(),
133 converged: true,
134 };
135 }
136
137 let mut total_iterations = 0;
138 let mut restarts = 0;
139
140 for _outer in 0..config.max_iterations {
142 let ax = operator.apply(&x);
144 let r: Array1<T> = b - &ax;
145 let beta = vector_norm(&r);
146
147 let rel_residual = beta / b_norm;
149 if rel_residual < config.tolerance {
150 return GmresSolution {
151 x,
152 iterations: total_iterations,
153 restarts,
154 residual: rel_residual,
155 converged: true,
156 };
157 }
158
159 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
161 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
162
163 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
165
166 let mut cs: Vec<T> = Vec::with_capacity(m);
168 let mut sn: Vec<T> = Vec::with_capacity(m);
169
170 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
172 g[0] = T::from_real(beta);
173
174 let mut inner_converged = false;
175
176 for j in 0..m {
178 total_iterations += 1;
179
180 let mut w = operator.apply(&v[j]);
182
183 for i in 0..=j {
185 h[[i, j]] = inner_product(&v[i], &w);
186 let h_ij = h[[i, j]];
187 axpy(-h_ij, &v[i], &mut w);
188 }
189
190 let w_norm = vector_norm(&w);
191 h[[j + 1, j]] = T::from_real(w_norm);
192
193 let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
195 if w_norm < breakdown_tol {
196 inner_converged = true;
197 } else {
198 let inv_norm = T::from_real(T::Real::one() / w_norm);
199 let mut new_v = w.clone();
200 axpy(inv_norm - T::one(), &w, &mut new_v);
201 v.push(new_v);
202 }
203
204 for i in 0..j {
206 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
207 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
208 h[[i, j]] = temp;
209 }
210
211 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
213 cs.push(c);
214 sn.push(s);
215
216 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
218 h[[j + 1, j]] = T::zero();
219
220 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
221 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
222 g[j] = temp;
223
224 let rel_residual = g[j + 1].norm() / b_norm;
226
227 if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
228 log::info!(
229 "GMRES iteration {} (restart {}): relative residual = {:.6e}",
230 total_iterations,
231 restarts,
232 rel_residual.to_f64().unwrap_or(0.0)
233 );
234 }
235
236 if rel_residual < config.tolerance || inner_converged {
237 let y = solve_upper_triangular(&h, &g, j + 1);
239
240 for (i, &yi) in y.iter().enumerate() {
242 axpy(yi, &v[i], &mut x);
243 }
244
245 return GmresSolution {
246 x,
247 iterations: total_iterations,
248 restarts,
249 residual: rel_residual,
250 converged: true,
251 };
252 }
253 }
254
255 let y = solve_upper_triangular(&h, &g, m);
257
258 for (i, &yi) in y.iter().enumerate() {
259 axpy(yi, &v[i], &mut x);
260 }
261
262 restarts += 1;
263 }
264
265 let ax = operator.apply(&x);
267 let r: Array1<T> = b - &ax;
268 let rel_residual = vector_norm(&r) / b_norm;
269
270 GmresSolution {
271 x,
272 iterations: total_iterations,
273 restarts,
274 residual: rel_residual,
275 converged: false,
276 }
277}
278
279pub fn gmres_preconditioned<T, A, P>(
283 operator: &A,
284 precond: &P,
285 b: &Array1<T>,
286 config: &GmresConfig<T::Real>,
287) -> GmresSolution<T>
288where
289 T: ComplexField,
290 A: LinearOperator<T>,
291 P: Preconditioner<T>,
292{
293 let n = b.len();
294 let m = config.restart;
295
296 let mut x = Array1::from_elem(n, T::zero());
297
298 let pb = precond.apply(b);
300 let b_norm = vector_norm(&pb);
301 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
302 if b_norm < tol_threshold {
303 return GmresSolution {
304 x,
305 iterations: 0,
306 restarts: 0,
307 residual: T::Real::zero(),
308 converged: true,
309 };
310 }
311
312 let mut total_iterations = 0;
313 let mut restarts = 0;
314
315 for _outer in 0..config.max_iterations {
316 let ax = operator.apply(&x);
318 let residual: Array1<T> = b - &ax;
319 let r = precond.apply(&residual);
320 let beta = vector_norm(&r);
321
322 let rel_residual = beta / b_norm;
323 if rel_residual < config.tolerance {
324 return GmresSolution {
325 x,
326 iterations: total_iterations,
327 restarts,
328 residual: rel_residual,
329 converged: true,
330 };
331 }
332
333 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
334 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
335
336 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
337 let mut cs: Vec<T> = Vec::with_capacity(m);
338 let mut sn: Vec<T> = Vec::with_capacity(m);
339
340 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
341 g[0] = T::from_real(beta);
342
343 let mut inner_converged = false;
344
345 for j in 0..m {
346 total_iterations += 1;
347
348 let av = operator.apply(&v[j]);
350 let mut w = precond.apply(&av);
351
352 for i in 0..=j {
354 h[[i, j]] = inner_product(&v[i], &w);
355 let h_ij = h[[i, j]];
356 w = &w - &v[i].mapv(|vi| vi * h_ij);
357 }
358
359 let w_norm = vector_norm(&w);
360 h[[j + 1, j]] = T::from_real(w_norm);
361
362 let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
363 if w_norm < breakdown_tol {
364 inner_converged = true;
365 } else {
366 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
367 }
368
369 for i in 0..j {
371 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
372 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
373 h[[i, j]] = temp;
374 }
375
376 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
377 cs.push(c);
378 sn.push(s);
379
380 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
381 h[[j + 1, j]] = T::zero();
382
383 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
384 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
385 g[j] = temp;
386
387 let rel_residual = g[j + 1].norm() / b_norm;
388
389 if rel_residual < config.tolerance || inner_converged {
390 let y = solve_upper_triangular(&h, &g, j + 1);
391
392 for (i, &yi) in y.iter().enumerate() {
393 x = &x + &v[i].mapv(|vi| vi * yi);
394 }
395
396 return GmresSolution {
397 x,
398 iterations: total_iterations,
399 restarts,
400 residual: rel_residual,
401 converged: true,
402 };
403 }
404 }
405
406 let y = solve_upper_triangular(&h, &g, m);
408 for (i, &yi) in y.iter().enumerate() {
409 x = &x + &v[i].mapv(|vi| vi * yi);
410 }
411
412 restarts += 1;
413 }
414
415 let ax = operator.apply(&x);
417 let residual: Array1<T> = b - &ax;
418 let r = precond.apply(&residual);
419 let rel_residual = vector_norm(&r) / b_norm;
420
421 GmresSolution {
422 x,
423 iterations: total_iterations,
424 restarts,
425 residual: rel_residual,
426 converged: false,
427 }
428}
429
430pub fn gmres_preconditioned_with_guess<T, A, P>(
435 operator: &A,
436 precond: &P,
437 b: &Array1<T>,
438 x0: Option<&Array1<T>>,
439 config: &GmresConfig<T::Real>,
440) -> GmresSolution<T>
441where
442 T: ComplexField,
443 A: LinearOperator<T>,
444 P: Preconditioner<T>,
445{
446 let n = b.len();
447 let m = config.restart;
448
449 let mut x = match x0 {
451 Some(guess) => guess.clone(),
452 None => Array1::from_elem(n, T::zero()),
453 };
454
455 let pb = precond.apply(b);
457 let b_norm = vector_norm(&pb);
458 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
459 if b_norm < tol_threshold {
460 return GmresSolution {
461 x,
462 iterations: 0,
463 restarts: 0,
464 residual: T::Real::zero(),
465 converged: true,
466 };
467 }
468
469 let mut total_iterations = 0;
470 let mut restarts = 0;
471
472 for _outer in 0..config.max_iterations {
473 let ax = operator.apply(&x);
475 let residual: Array1<T> = b - &ax;
476 let r = precond.apply(&residual);
477 let beta = vector_norm(&r);
478
479 let rel_residual = beta / b_norm;
480 if rel_residual < config.tolerance {
481 return GmresSolution {
482 x,
483 iterations: total_iterations,
484 restarts,
485 residual: rel_residual,
486 converged: true,
487 };
488 }
489
490 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
491 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
492
493 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
494 let mut cs: Vec<T> = Vec::with_capacity(m);
495 let mut sn: Vec<T> = Vec::with_capacity(m);
496
497 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
498 g[0] = T::from_real(beta);
499
500 let mut inner_converged = false;
501
502 for j in 0..m {
503 total_iterations += 1;
504
505 let av = operator.apply(&v[j]);
507 let mut w = precond.apply(&av);
508
509 for i in 0..=j {
511 h[[i, j]] = inner_product(&v[i], &w);
512 let h_ij = h[[i, j]];
513 w = &w - &v[i].mapv(|vi| vi * h_ij);
514 }
515
516 let w_norm = vector_norm(&w);
517 h[[j + 1, j]] = T::from_real(w_norm);
518
519 let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
520 if w_norm < breakdown_tol {
521 inner_converged = true;
522 } else {
523 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
524 }
525
526 for i in 0..j {
528 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
529 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
530 h[[i, j]] = temp;
531 }
532
533 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
534 cs.push(c);
535 sn.push(s);
536
537 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
538 h[[j + 1, j]] = T::zero();
539
540 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
541 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
542 g[j] = temp;
543
544 let rel_residual = g[j + 1].norm() / b_norm;
545
546 if rel_residual < config.tolerance || inner_converged {
547 let y = solve_upper_triangular(&h, &g, j + 1);
548
549 for (i, &yi) in y.iter().enumerate() {
550 x = &x + &v[i].mapv(|vi| vi * yi);
551 }
552
553 return GmresSolution {
554 x,
555 iterations: total_iterations,
556 restarts,
557 residual: rel_residual,
558 converged: true,
559 };
560 }
561 }
562
563 let y = solve_upper_triangular(&h, &g, m);
565 for (i, &yi) in y.iter().enumerate() {
566 x = &x + &v[i].mapv(|vi| vi * yi);
567 }
568
569 restarts += 1;
570 }
571
572 let ax = operator.apply(&x);
574 let residual: Array1<T> = b - &ax;
575 let r = precond.apply(&residual);
576 let rel_residual = vector_norm(&r) / b_norm;
577
578 GmresSolution {
579 x,
580 iterations: total_iterations,
581 restarts,
582 residual: rel_residual,
583 converged: false,
584 }
585}
586
587#[inline]
589fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
590 let tol = T::Real::from_f64(1e-30).unwrap();
591 if b.norm() < tol {
592 return (T::one(), T::zero());
593 }
594 if a.norm() < tol {
595 return (T::zero(), T::one());
596 }
597
598 let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
599 let c = a * T::from_real(T::Real::one() / r);
600 let s = b * T::from_real(T::Real::one() / r);
601
602 (c, s)
603}
604
605fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
607 let mut y = vec![T::zero(); k];
608 let tol = T::Real::from_f64(1e-30).unwrap();
609
610 for i in (0..k).rev() {
611 let mut sum = g[i];
612 for j in (i + 1)..k {
613 sum -= h[[i, j]] * y[j];
614 }
615 if h[[i, i]].norm() > tol {
616 y[i] = sum * h[[i, i]].inv();
617 }
618 }
619
620 y
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Builds a double-precision configuration with progress logging disabled.
    fn make_config(max_iterations: usize, restart: usize, tolerance: f64) -> GmresConfig<f64> {
        GmresConfig {
            max_iterations,
            restart,
            tolerance,
            print_interval: 0,
        }
    }

    #[test]
    fn test_gmres_simple() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let solution = gmres(&a, &b, &make_config(100, 10, 1e-10));

        assert!(solution.converged, "GMRES should converge");

        // Check the true residual rather than the solver's own estimate.
        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_gmres_identity() {
        let n = 5;
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(n);
        let b = Array1::from_iter((1..=n).map(|i| Complex64::new(i as f64, 0.0)));

        let solution = gmres(&id, &b, &make_config(10, 10, 1e-12));

        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        let error: f64 = (&solution.x - &b)
            .iter()
            .map(|e| e.norm_sqr())
            .sum::<f64>()
            .sqrt();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_gmres_f64() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0],];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        let solution = gmres(&a, &b, &make_config(100, 10, 1e-10));

        assert!(solution.converged);

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e * e).sum::<f64>().sqrt();
        assert_relative_eq!(error, 0.0, epsilon = 1e-8);
    }

    #[test]
    fn test_gmres_with_exact_guess_needs_no_iterations() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];
        // A^{-1} b = [1/11, 7/11] for this matrix.
        let x0 = array![1.0_f64 / 11.0, 7.0 / 11.0];

        let solution = gmres_with_guess(&a, &b, Some(&x0), &make_config(10, 10, 1e-10));

        assert!(solution.converged);
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.restarts, 0);
    }

    #[test]
    fn test_gmres_converges_across_restarts() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        // restart = 1 forces one Arnoldi step per cycle, exercising the restart path;
        // GMRES(1) still converges because this matrix is symmetric positive definite.
        let solution = gmres(&a, &b, &make_config(500, 1, 1e-10));

        assert!(solution.converged);
        assert!(solution.restarts > 0);

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e * e).sum::<f64>().sqrt();
        assert!(error < 1e-8);
    }
}