1use crate::blas_helpers::{axpy, inner_product, vector_norm};
10use crate::traits::{ComplexField, LinearOperator, Preconditioner, SolverStatus};
11use ndarray::{Array1, Array2};
12use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
13
/// Tuning parameters for the restarted GMRES solvers in this module.
///
/// `R` is the real scalar type of the tolerance; [`Default`] is provided for
/// `f64` and `f32`.
#[derive(Clone, Debug)]
pub struct GmresConfig<R> {
    /// Maximum number of outer (restart) cycles; each cycle performs up to
    /// `restart` inner Arnoldi iterations.
    pub max_iterations: usize,
    /// Number of inner iterations per cycle (the `m` of GMRES(m)).
    pub restart: usize,
    /// Stopping tolerance on the relative residual norm.
    pub tolerance: R,
    /// Period, in inner iterations, of progress log messages; `0` disables them.
    pub print_interval: usize,
}

impl Default for GmresConfig<f64> {
    /// 100 restart cycles of GMRES(30), relative tolerance `1e-6`, no logging.
    fn default() -> Self {
        GmresConfig {
            print_interval: 0,
            tolerance: 1e-6,
            restart: 30,
            max_iterations: 100,
        }
    }
}

impl Default for GmresConfig<f32> {
    /// 100 restart cycles of GMRES(30), relative tolerance `1e-5` (looser than
    /// the `f64` default to suit single precision), no logging.
    fn default() -> Self {
        GmresConfig {
            print_interval: 0,
            tolerance: 1e-5,
            restart: 30,
            max_iterations: 100,
        }
    }
}
48
49impl<R: Float + FromPrimitive> GmresConfig<R> {
50 pub fn for_small_problems() -> Self {
52 Self {
53 max_iterations: 50,
54 restart: 50,
55 tolerance: R::from_f64(1e-8).unwrap(),
56 print_interval: 0,
57 }
58 }
59
60 pub fn with_restart(restart: usize) -> Self
62 where
63 Self: Default,
64 {
65 Self {
66 restart,
67 ..Default::default()
68 }
69 }
70}
71
/// Outcome of a GMRES solve.
#[derive(Debug)]
pub struct GmresSolution<T: ComplexField> {
    /// Approximate solution vector.
    pub x: Array1<T>,
    /// Total number of inner (Arnoldi) iterations performed across all cycles.
    pub iterations: usize,
    /// Number of completed restart cycles.
    pub restarts: usize,
    /// Relative residual norm at exit. Depending on where the solver stopped this is
    /// either recomputed from `b - A x` or the Givens-rotation estimate; the
    /// preconditioned solvers report the preconditioned residual.
    pub residual: T::Real,
    /// Success flag reported by the solver; `status` gives the precise reason the
    /// iteration stopped.
    pub converged: bool,
    /// Termination reason.
    pub status: SolverStatus,
}
88
89pub fn gmres<T, A>(operator: &A, b: &Array1<T>, config: &GmresConfig<T::Real>) -> GmresSolution<T>
99where
100 T: ComplexField,
101 A: LinearOperator<T>,
102{
103 gmres_with_guess(operator, b, None, config)
104}
105
106pub fn gmres_with_guess<T, A>(
108 operator: &A,
109 b: &Array1<T>,
110 x0: Option<&Array1<T>>,
111 config: &GmresConfig<T::Real>,
112) -> GmresSolution<T>
113where
114 T: ComplexField,
115 A: LinearOperator<T>,
116{
117 let n = b.len();
118 let m = config.restart;
119
120 let mut x = match x0 {
122 Some(x0) => x0.clone(),
123 None => Array1::from_elem(n, T::zero()),
124 };
125
126 let b_norm = vector_norm(b);
128 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
129 if b_norm < tol_threshold {
130 return GmresSolution {
131 x,
132 iterations: 0,
133 restarts: 0,
134 residual: T::Real::zero(),
135 converged: true,
136 status: SolverStatus::Converged,
137 };
138 }
139
140 let mut total_iterations = 0;
141 let mut restarts = 0;
142
143 for _outer in 0..config.max_iterations {
145 let ax = operator.apply(&x);
147 let r: Array1<T> = b - &ax;
148 let beta = vector_norm(&r);
149
150 let rel_residual = beta / b_norm;
152 if rel_residual < config.tolerance {
153 return GmresSolution {
154 x,
155 iterations: total_iterations,
156 restarts,
157 residual: rel_residual,
158 converged: true,
159 status: SolverStatus::Converged,
160 };
161 }
162
163 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
165 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
166
167 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
169
170 let mut cs: Vec<T> = Vec::with_capacity(m);
172 let mut sn: Vec<T> = Vec::with_capacity(m);
173
174 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
176 g[0] = T::from_real(beta);
177
178 let mut inner_converged = false;
179
180 for j in 0..m {
182 total_iterations += 1;
183
184 let mut w = operator.apply(&v[j]);
186
187 for i in 0..=j {
189 h[[i, j]] = inner_product(&v[i], &w);
190 let h_ij = h[[i, j]];
191 axpy(-h_ij, &v[i], &mut w);
192 }
193
194 let w_norm = vector_norm(&w);
195 h[[j + 1, j]] = T::from_real(w_norm);
196
197 let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
199 if w_norm < breakdown_tol {
200 inner_converged = true;
201 } else {
202 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
203 }
204
205 for i in 0..j {
207 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
208 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
209 h[[i, j]] = temp;
210 }
211
212 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
214 cs.push(c);
215 sn.push(s);
216
217 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
219 h[[j + 1, j]] = T::zero();
220
221 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
222 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
223 g[j] = temp;
224
225 let abs_residual = g[j + 1].norm();
227 let rel_residual = abs_residual / b_norm;
228 let abs_tol = T::Real::from_f64(1e-20).unwrap();
229
230 if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
231 log::info!(
232 "GMRES iteration {} (restart {}): relative residual = {:.6e}",
233 total_iterations,
234 restarts,
235 rel_residual.to_f64().unwrap_or(0.0)
236 );
237 }
238
239 if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
240 let y = solve_upper_triangular(&h, &g, j + 1);
242
243 for (i, &yi) in y.iter().enumerate() {
245 axpy(yi, &v[i], &mut x);
246 }
247
248 let status = if inner_converged
249 && rel_residual >= config.tolerance
250 && abs_residual >= abs_tol
251 {
252 SolverStatus::Breakdown
253 } else {
254 SolverStatus::Converged
255 };
256
257 return GmresSolution {
258 x,
259 iterations: total_iterations,
260 restarts,
261 residual: rel_residual,
262 converged: true,
263 status,
264 };
265 }
266 }
267
268 let y = solve_upper_triangular(&h, &g, m);
270
271 for (i, &yi) in y.iter().enumerate() {
272 axpy(yi, &v[i], &mut x);
273 }
274
275 restarts += 1;
276 }
277
278 let ax = operator.apply(&x);
280 let r: Array1<T> = b - &ax;
281 let rel_residual = vector_norm(&r) / b_norm;
282
283 GmresSolution {
284 x,
285 iterations: total_iterations,
286 restarts,
287 residual: rel_residual,
288 converged: false,
289 status: SolverStatus::MaxIterationsReached,
290 }
291}
292
293pub fn gmres_preconditioned<T, A, P>(
297 operator: &A,
298 precond: &P,
299 b: &Array1<T>,
300 config: &GmresConfig<T::Real>,
301) -> GmresSolution<T>
302where
303 T: ComplexField,
304 A: LinearOperator<T>,
305 P: Preconditioner<T>,
306{
307 let n = b.len();
308 let m = config.restart;
309
310 let mut x = Array1::from_elem(n, T::zero());
311
312 let pb = precond.apply(b);
314 let b_norm = vector_norm(&pb);
315 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
316 if b_norm < tol_threshold {
317 return GmresSolution {
318 x,
319 iterations: 0,
320 restarts: 0,
321 residual: T::Real::zero(),
322 converged: true,
323 status: SolverStatus::Converged,
324 };
325 }
326
327 let mut total_iterations = 0;
328 let mut restarts = 0;
329
330 for _outer in 0..config.max_iterations {
331 let ax = operator.apply(&x);
333 let residual: Array1<T> = b - &ax;
334 let r = precond.apply(&residual);
335 let beta = vector_norm(&r);
336
337 let rel_residual = beta / b_norm;
338 if rel_residual < config.tolerance {
339 return GmresSolution {
340 x,
341 iterations: total_iterations,
342 restarts,
343 residual: rel_residual,
344 converged: true,
345 status: SolverStatus::Converged,
346 };
347 }
348
349 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
350 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
351
352 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
353 let mut cs: Vec<T> = Vec::with_capacity(m);
354 let mut sn: Vec<T> = Vec::with_capacity(m);
355
356 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
357 g[0] = T::from_real(beta);
358
359 let mut inner_converged = false;
360
361 for j in 0..m {
362 total_iterations += 1;
363
364 let av = operator.apply(&v[j]);
366 let mut w = precond.apply(&av);
367
368 for i in 0..=j {
370 h[[i, j]] = inner_product(&v[i], &w);
371 let h_ij = h[[i, j]];
372 w = &w - &v[i].mapv(|vi| vi * h_ij);
373 }
374
375 let w_norm = vector_norm(&w);
376 h[[j + 1, j]] = T::from_real(w_norm);
377
378 let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
379 if w_norm < breakdown_tol {
380 inner_converged = true;
381 } else {
382 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
383 }
384
385 for i in 0..j {
387 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
388 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
389 h[[i, j]] = temp;
390 }
391
392 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
393 cs.push(c);
394 sn.push(s);
395
396 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
397 h[[j + 1, j]] = T::zero();
398
399 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
400 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
401 g[j] = temp;
402
403 let abs_residual = g[j + 1].norm();
404 let rel_residual = abs_residual / b_norm;
405 let abs_tol = T::Real::from_f64(1e-20).unwrap();
406
407 if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
408 let y = solve_upper_triangular(&h, &g, j + 1);
409
410 for (i, &yi) in y.iter().enumerate() {
411 x = &x + &v[i].mapv(|vi| vi * yi);
412 }
413
414 return GmresSolution {
415 x,
416 iterations: total_iterations,
417 restarts,
418 residual: rel_residual,
419 converged: true,
420 status: if inner_converged
421 && rel_residual >= config.tolerance
422 && abs_residual >= abs_tol
423 {
424 SolverStatus::Breakdown
425 } else {
426 SolverStatus::Converged
427 },
428 };
429 }
430 }
431
432 let y = solve_upper_triangular(&h, &g, m);
434 for (i, &yi) in y.iter().enumerate() {
435 x = &x + &v[i].mapv(|vi| vi * yi);
436 }
437
438 restarts += 1;
439 }
440
441 let ax = operator.apply(&x);
443 let residual: Array1<T> = b - &ax;
444 let r = precond.apply(&residual);
445 let rel_residual = vector_norm(&r) / b_norm;
446
447 GmresSolution {
448 x,
449 iterations: total_iterations,
450 restarts,
451 residual: rel_residual,
452 converged: false,
453 status: SolverStatus::MaxIterationsReached,
454 }
455}
456
457pub fn gmres_preconditioned_with_guess<T, A, P>(
462 operator: &A,
463 precond: &P,
464 b: &Array1<T>,
465 x0: Option<&Array1<T>>,
466 config: &GmresConfig<T::Real>,
467) -> GmresSolution<T>
468where
469 T: ComplexField,
470 A: LinearOperator<T>,
471 P: Preconditioner<T>,
472{
473 let n = b.len();
474 let m = config.restart;
475
476 let mut x = match x0 {
478 Some(guess) => guess.clone(),
479 None => Array1::from_elem(n, T::zero()),
480 };
481
482 let pb = precond.apply(b);
484 let b_norm = vector_norm(&pb);
485 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
486 if b_norm < tol_threshold {
487 return GmresSolution {
488 x,
489 iterations: 0,
490 restarts: 0,
491 residual: T::Real::zero(),
492 converged: true,
493 status: SolverStatus::Converged,
494 };
495 }
496
497 let mut total_iterations = 0;
498 let mut restarts = 0;
499
500 for _outer in 0..config.max_iterations {
501 let ax = operator.apply(&x);
503 let residual: Array1<T> = b - &ax;
504 let r = precond.apply(&residual);
505 let beta = vector_norm(&r);
506
507 let rel_residual = beta / b_norm;
508 if rel_residual < config.tolerance {
509 return GmresSolution {
510 x,
511 iterations: total_iterations,
512 restarts,
513 residual: rel_residual,
514 converged: true,
515 status: SolverStatus::Converged,
516 };
517 }
518
519 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
520 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
521
522 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
523 let mut cs: Vec<T> = Vec::with_capacity(m);
524 let mut sn: Vec<T> = Vec::with_capacity(m);
525
526 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
527 g[0] = T::from_real(beta);
528
529 let mut inner_converged = false;
530
531 for j in 0..m {
532 total_iterations += 1;
533
534 let av = operator.apply(&v[j]);
536 let mut w = precond.apply(&av);
537
538 for i in 0..=j {
540 h[[i, j]] = inner_product(&v[i], &w);
541 let h_ij = h[[i, j]];
542 w = &w - &v[i].mapv(|vi| vi * h_ij);
543 }
544
545 let w_norm = vector_norm(&w);
546 h[[j + 1, j]] = T::from_real(w_norm);
547
548 let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
549 if w_norm < breakdown_tol {
550 inner_converged = true;
551 } else {
552 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
553 }
554
555 for i in 0..j {
557 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
558 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
559 h[[i, j]] = temp;
560 }
561
562 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
563 cs.push(c);
564 sn.push(s);
565
566 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
567 h[[j + 1, j]] = T::zero();
568
569 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
570 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
571 g[j] = temp;
572
573 let abs_residual = g[j + 1].norm();
574 let rel_residual = abs_residual / b_norm;
575 let abs_tol = T::Real::from_f64(1e-20).unwrap();
576
577 if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
578 let y = solve_upper_triangular(&h, &g, j + 1);
579
580 for (i, &yi) in y.iter().enumerate() {
581 x = &x + &v[i].mapv(|vi| vi * yi);
582 }
583
584 return GmresSolution {
585 x,
586 iterations: total_iterations,
587 restarts,
588 residual: rel_residual,
589 converged: true,
590 status: if inner_converged
591 && rel_residual >= config.tolerance
592 && abs_residual >= abs_tol
593 {
594 SolverStatus::Breakdown
595 } else {
596 SolverStatus::Converged
597 },
598 };
599 }
600 }
601
602 let y = solve_upper_triangular(&h, &g, m);
604 for (i, &yi) in y.iter().enumerate() {
605 x = &x + &v[i].mapv(|vi| vi * yi);
606 }
607
608 restarts += 1;
609 }
610
611 let ax = operator.apply(&x);
613 let residual: Array1<T> = b - &ax;
614 let r = precond.apply(&residual);
615 let rel_residual = vector_norm(&r) / b_norm;
616
617 GmresSolution {
618 x,
619 iterations: total_iterations,
620 restarts,
621 residual: rel_residual,
622 converged: false,
623 status: SolverStatus::MaxIterationsReached,
624 }
625}
626
627#[inline]
629fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
630 let tol = T::Real::from_f64(1e-30).unwrap();
631 if b.norm() < tol {
632 return (T::one(), T::zero());
633 }
634 if a.norm() < tol {
635 return (T::zero(), T::one());
636 }
637
638 let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
639 let c = a * T::from_real(T::Real::one() / r);
640 let s = b * T::from_real(T::Real::one() / r);
641
642 (c, s)
643}
644
645fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
647 let mut y = vec![T::zero(); k];
648 let tol = T::Real::from_f64(1e-30).unwrap();
649
650 for i in (0..k).rev() {
651 let mut sum = g[i];
652 for j in (i + 1)..k {
653 sum -= h[[i, j]] * y[j];
654 }
655 if h[[i, i]].norm() > tol {
656 y[i] = sum * h[[i, i]].inv();
657 }
658 }
659
660 y
661}
662
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Euclidean norm of a complex vector.
    fn complex_l2(values: &Array1<Complex64>) -> f64 {
        values.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt()
    }

    /// Builds a configuration with progress logging disabled.
    fn quiet_config(max_iterations: usize, restart: usize, tolerance: f64) -> GmresConfig<f64> {
        GmresConfig {
            max_iterations,
            restart,
            tolerance,
            print_interval: 0,
        }
    }

    #[test]
    fn test_gmres_simple() {
        let matrix = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];
        let a = CsrMatrix::from_dense(&matrix, 1e-15);
        let rhs = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let solution = gmres(&a, &rhs, &quiet_config(100, 10, 1e-10));
        assert!(solution.converged, "GMRES should converge");

        // Check the residual A x - b directly rather than trusting `solution.residual`.
        let product = a.matvec(&solution.x);
        let error = complex_l2(&(&product - &rhs));
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_gmres_identity() {
        let n = 5;
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(n);
        let b: Array1<Complex64> = (1..=n).map(|k| Complex64::new(k as f64, 0.0)).collect();

        let solution = gmres(&id, &b, &quiet_config(10, 10, 1e-12));

        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        // For the identity operator the solution is the right-hand side itself.
        let error = complex_l2(&(&solution.x - &b));
        assert!(error < 1e-10);
    }

    #[test]
    fn test_gmres_f64() {
        let matrix = array![[4.0_f64, 1.0], [1.0, 3.0],];
        let a = CsrMatrix::from_dense(&matrix, 1e-15);
        let rhs = array![1.0_f64, 2.0];

        let solution = gmres(&a, &rhs, &quiet_config(100, 10, 1e-10));
        assert!(solution.converged);

        let product = a.matvec(&solution.x);
        let error: f64 = (&product - &rhs).iter().map(|e| e * e).sum::<f64>().sqrt();
        assert_relative_eq!(error, 0.0, epsilon = 1e-8);
    }
}