1use crate::blas_helpers::{axpy, inner_product, vector_norm};
10use crate::traits::{ComplexField, LinearOperator, Preconditioner, SolverStatus};
11use ndarray::{Array1, Array2};
12use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
13
/// Configuration shared by the restarted GMRES solvers in this module.
#[derive(Debug, Clone)]
pub struct GmresConfig<R> {
    /// Maximum number of restart cycles (outer iterations). Each cycle performs up to
    /// `restart` inner Arnoldi iterations, so the total number of inner iterations can
    /// reach `max_iterations * restart`.
    pub max_iterations: usize,
    /// Krylov subspace dimension `m` of GMRES(m): inner iterations per restart cycle.
    pub restart: usize,
    /// Relative residual tolerance used as the convergence criterion.
    pub tolerance: R,
    /// Progress-logging period in inner iterations (`0` disables logging).
    pub print_interval: usize,
}
26
27impl Default for GmresConfig<f64> {
28 fn default() -> Self {
29 Self {
30 max_iterations: 100,
31 restart: 30,
32 tolerance: 1e-6,
33 print_interval: 0,
34 }
35 }
36}
37
38impl Default for GmresConfig<f32> {
39 fn default() -> Self {
40 Self {
41 max_iterations: 100,
42 restart: 30,
43 tolerance: 1e-5,
44 print_interval: 0,
45 }
46 }
47}
48
49impl<R: Float + FromPrimitive> GmresConfig<R> {
50 pub fn for_small_problems() -> Self {
52 Self {
53 max_iterations: 50,
54 restart: 50,
55 tolerance: R::from_f64(1e-8).unwrap(),
56 print_interval: 0,
57 }
58 }
59
60 pub fn with_restart(restart: usize) -> Self
62 where
63 Self: Default,
64 {
65 Self {
66 restart,
67 ..Default::default()
68 }
69 }
70}
71
/// Outcome of a GMRES solve.
#[derive(Debug)]
pub struct GmresSolution<T: ComplexField> {
    /// Computed (approximate) solution vector.
    pub x: Array1<T>,
    /// Total number of inner Arnoldi iterations, summed over all restart cycles.
    pub iterations: usize,
    /// Number of restart cycles that ran to completion without stopping early.
    pub restarts: usize,
    /// Relative residual at termination: `‖b − A x‖ / ‖b‖` for the unpreconditioned
    /// solvers and `‖P (b − A x)‖ / ‖P b‖` for the preconditioned ones. When the solver
    /// stops inside a cycle this is the Givens-rotation estimate rather than a freshly
    /// recomputed residual; it is reported as zero for a (numerically) zero right-hand side.
    pub residual: T::Real,
    /// Whether the solver reports success; see `status` for the exact termination reason.
    pub converged: bool,
    /// Termination reason (converged, breakdown, or iteration limit reached).
    pub status: SolverStatus,
}
88
89pub fn gmres<T, A>(operator: &A, b: &Array1<T>, config: &GmresConfig<T::Real>) -> GmresSolution<T>
99where
100 T: ComplexField,
101 A: LinearOperator<T>,
102{
103 gmres_with_guess(operator, b, None, config)
104}
105
106pub fn gmres_with_guess<T, A>(
108 operator: &A,
109 b: &Array1<T>,
110 x0: Option<&Array1<T>>,
111 config: &GmresConfig<T::Real>,
112) -> GmresSolution<T>
113where
114 T: ComplexField,
115 A: LinearOperator<T>,
116{
117 let n = b.len();
118 let m = config.restart;
119
120 let mut x = match x0 {
122 Some(x0) => x0.clone(),
123 None => Array1::from_elem(n, T::zero()),
124 };
125
126 let b_norm = vector_norm(b);
128 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
129 if b_norm < tol_threshold {
130 return GmresSolution {
131 x,
132 iterations: 0,
133 restarts: 0,
134 residual: T::Real::zero(),
135 converged: true,
136 status: SolverStatus::Converged,
137 };
138 }
139
140 let mut total_iterations = 0;
141 let mut restarts = 0;
142
143 for _outer in 0..config.max_iterations {
145 let ax = operator.apply(&x);
147 let r: Array1<T> = b - &ax;
148 let beta = vector_norm(&r);
149
150 let rel_residual = beta / b_norm;
152 if rel_residual < config.tolerance {
153 return GmresSolution {
154 x,
155 iterations: total_iterations,
156 restarts,
157 residual: rel_residual,
158 converged: true,
159 status: SolverStatus::Converged,
160 };
161 }
162
163 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
165 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
166
167 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
169
170 let mut cs: Vec<T> = Vec::with_capacity(m);
172 let mut sn: Vec<T> = Vec::with_capacity(m);
173
174 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
176 g[0] = T::from_real(beta);
177
178 let mut inner_converged = false;
179
180 for j in 0..m {
182 total_iterations += 1;
183
184 let mut w = operator.apply(&v[j]);
186
187 for i in 0..=j {
189 h[[i, j]] = inner_product(&v[i], &w);
190 let h_ij = h[[i, j]];
191 axpy(-h_ij, &v[i], &mut w);
192 }
193
194 let w_norm = vector_norm(&w);
195 h[[j + 1, j]] = T::from_real(w_norm);
196
197 let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
199 if w_norm < breakdown_tol {
200 inner_converged = true;
201 } else {
202 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
203 }
204
205 for i in 0..j {
207 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
208 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
209 h[[i, j]] = temp;
210 }
211
212 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
214 cs.push(c);
215 sn.push(s);
216
217 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
219 h[[j + 1, j]] = T::zero();
220
221 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
222 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
223 g[j] = temp;
224
225 let abs_residual = g[j + 1].norm();
227 let rel_residual = abs_residual / b_norm;
228 let abs_tol = T::Real::from_f64(1e-20).unwrap();
229
230 if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
231 log::info!(
232 "GMRES iteration {} (restart {}): relative residual = {:.6e}",
233 total_iterations,
234 restarts,
235 rel_residual.to_f64().unwrap_or(0.0)
236 );
237 }
238
239 if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
240 let y = solve_upper_triangular(&h, &g, j + 1);
242
243 for (i, &yi) in y.iter().enumerate() {
245 axpy(yi, &v[i], &mut x);
246 }
247
248 let status = if inner_converged && rel_residual >= config.tolerance && abs_residual >= abs_tol {
249 SolverStatus::Breakdown
250 } else {
251 SolverStatus::Converged
252 };
253
254 return GmresSolution {
255 x,
256 iterations: total_iterations,
257 restarts,
258 residual: rel_residual,
259 converged: true,
260 status,
261 };
262 }
263 }
264
265 let y = solve_upper_triangular(&h, &g, m);
267
268 for (i, &yi) in y.iter().enumerate() {
269 axpy(yi, &v[i], &mut x);
270 }
271
272 restarts += 1;
273 }
274
275 let ax = operator.apply(&x);
277 let r: Array1<T> = b - &ax;
278 let rel_residual = vector_norm(&r) / b_norm;
279
280 GmresSolution {
281 x,
282 iterations: total_iterations,
283 restarts,
284 residual: rel_residual,
285 converged: false,
286 status: SolverStatus::MaxIterationsReached,
287 }
288}
289
290pub fn gmres_preconditioned<T, A, P>(
294 operator: &A,
295 precond: &P,
296 b: &Array1<T>,
297 config: &GmresConfig<T::Real>,
298) -> GmresSolution<T>
299where
300 T: ComplexField,
301 A: LinearOperator<T>,
302 P: Preconditioner<T>,
303{
304 let n = b.len();
305 let m = config.restart;
306
307 let mut x = Array1::from_elem(n, T::zero());
308
309 let pb = precond.apply(b);
311 let b_norm = vector_norm(&pb);
312 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
313 if b_norm < tol_threshold {
314 return GmresSolution {
315 x,
316 iterations: 0,
317 restarts: 0,
318 residual: T::Real::zero(),
319 converged: true,
320 status: SolverStatus::Converged,
321 };
322 }
323
324 let mut total_iterations = 0;
325 let mut restarts = 0;
326
327 for _outer in 0..config.max_iterations {
328 let ax = operator.apply(&x);
330 let residual: Array1<T> = b - &ax;
331 let r = precond.apply(&residual);
332 let beta = vector_norm(&r);
333
334 let rel_residual = beta / b_norm;
335 if rel_residual < config.tolerance {
336 return GmresSolution {
337 x,
338 iterations: total_iterations,
339 restarts,
340 residual: rel_residual,
341 converged: true,
342 status: SolverStatus::Converged,
343 };
344 }
345
346 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
347 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
348
349 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
350 let mut cs: Vec<T> = Vec::with_capacity(m);
351 let mut sn: Vec<T> = Vec::with_capacity(m);
352
353 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
354 g[0] = T::from_real(beta);
355
356 let mut inner_converged = false;
357
358 for j in 0..m {
359 total_iterations += 1;
360
361 let av = operator.apply(&v[j]);
363 let mut w = precond.apply(&av);
364
365 for i in 0..=j {
367 h[[i, j]] = inner_product(&v[i], &w);
368 let h_ij = h[[i, j]];
369 w = &w - &v[i].mapv(|vi| vi * h_ij);
370 }
371
372 let w_norm = vector_norm(&w);
373 h[[j + 1, j]] = T::from_real(w_norm);
374
375 let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
376 if w_norm < breakdown_tol {
377 inner_converged = true;
378 } else {
379 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
380 }
381
382 for i in 0..j {
384 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
385 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
386 h[[i, j]] = temp;
387 }
388
389 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
390 cs.push(c);
391 sn.push(s);
392
393 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
394 h[[j + 1, j]] = T::zero();
395
396 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
397 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
398 g[j] = temp;
399
400 let abs_residual = g[j + 1].norm();
401 let rel_residual = abs_residual / b_norm;
402 let abs_tol = T::Real::from_f64(1e-20).unwrap();
403
404 if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
405 let y = solve_upper_triangular(&h, &g, j + 1);
406
407 for (i, &yi) in y.iter().enumerate() {
408 x = &x + &v[i].mapv(|vi| vi * yi);
409 }
410
411 return GmresSolution {
412 x,
413 iterations: total_iterations,
414 restarts,
415 residual: rel_residual,
416 converged: true,
417 status: if inner_converged && rel_residual >= config.tolerance && abs_residual >= abs_tol {
418 SolverStatus::Breakdown
419 } else {
420 SolverStatus::Converged
421 },
422 };
423 }
424 }
425
426 let y = solve_upper_triangular(&h, &g, m);
428 for (i, &yi) in y.iter().enumerate() {
429 x = &x + &v[i].mapv(|vi| vi * yi);
430 }
431
432 restarts += 1;
433 }
434
435 let ax = operator.apply(&x);
437 let residual: Array1<T> = b - &ax;
438 let r = precond.apply(&residual);
439 let rel_residual = vector_norm(&r) / b_norm;
440
441 GmresSolution {
442 x,
443 iterations: total_iterations,
444 restarts,
445 residual: rel_residual,
446 converged: false,
447 status: SolverStatus::MaxIterationsReached,
448 }
449}
450
451pub fn gmres_preconditioned_with_guess<T, A, P>(
456 operator: &A,
457 precond: &P,
458 b: &Array1<T>,
459 x0: Option<&Array1<T>>,
460 config: &GmresConfig<T::Real>,
461) -> GmresSolution<T>
462where
463 T: ComplexField,
464 A: LinearOperator<T>,
465 P: Preconditioner<T>,
466{
467 let n = b.len();
468 let m = config.restart;
469
470 let mut x = match x0 {
472 Some(guess) => guess.clone(),
473 None => Array1::from_elem(n, T::zero()),
474 };
475
476 let pb = precond.apply(b);
478 let b_norm = vector_norm(&pb);
479 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
480 if b_norm < tol_threshold {
481 return GmresSolution {
482 x,
483 iterations: 0,
484 restarts: 0,
485 residual: T::Real::zero(),
486 converged: true,
487 status: SolverStatus::Converged,
488 };
489 }
490
491 let mut total_iterations = 0;
492 let mut restarts = 0;
493
494 for _outer in 0..config.max_iterations {
495 let ax = operator.apply(&x);
497 let residual: Array1<T> = b - &ax;
498 let r = precond.apply(&residual);
499 let beta = vector_norm(&r);
500
501 let rel_residual = beta / b_norm;
502 if rel_residual < config.tolerance {
503 return GmresSolution {
504 x,
505 iterations: total_iterations,
506 restarts,
507 residual: rel_residual,
508 converged: true,
509 status: SolverStatus::Converged,
510 };
511 }
512
513 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
514 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
515
516 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
517 let mut cs: Vec<T> = Vec::with_capacity(m);
518 let mut sn: Vec<T> = Vec::with_capacity(m);
519
520 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
521 g[0] = T::from_real(beta);
522
523 let mut inner_converged = false;
524
525 for j in 0..m {
526 total_iterations += 1;
527
528 let av = operator.apply(&v[j]);
530 let mut w = precond.apply(&av);
531
532 for i in 0..=j {
534 h[[i, j]] = inner_product(&v[i], &w);
535 let h_ij = h[[i, j]];
536 w = &w - &v[i].mapv(|vi| vi * h_ij);
537 }
538
539 let w_norm = vector_norm(&w);
540 h[[j + 1, j]] = T::from_real(w_norm);
541
542 let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
543 if w_norm < breakdown_tol {
544 inner_converged = true;
545 } else {
546 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
547 }
548
549 for i in 0..j {
551 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
552 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
553 h[[i, j]] = temp;
554 }
555
556 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
557 cs.push(c);
558 sn.push(s);
559
560 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
561 h[[j + 1, j]] = T::zero();
562
563 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
564 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
565 g[j] = temp;
566
567 let abs_residual = g[j + 1].norm();
568 let rel_residual = abs_residual / b_norm;
569 let abs_tol = T::Real::from_f64(1e-20).unwrap();
570
571 if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
572 let y = solve_upper_triangular(&h, &g, j + 1);
573
574 for (i, &yi) in y.iter().enumerate() {
575 x = &x + &v[i].mapv(|vi| vi * yi);
576 }
577
578 return GmresSolution {
579 x,
580 iterations: total_iterations,
581 restarts,
582 residual: rel_residual,
583 converged: true,
584 status: if inner_converged && rel_residual >= config.tolerance && abs_residual >= abs_tol {
585 SolverStatus::Breakdown
586 } else {
587 SolverStatus::Converged
588 },
589 };
590 }
591 }
592
593 let y = solve_upper_triangular(&h, &g, m);
595 for (i, &yi) in y.iter().enumerate() {
596 x = &x + &v[i].mapv(|vi| vi * yi);
597 }
598
599 restarts += 1;
600 }
601
602 let ax = operator.apply(&x);
604 let residual: Array1<T> = b - &ax;
605 let r = precond.apply(&residual);
606 let rel_residual = vector_norm(&r) / b_norm;
607
608 GmresSolution {
609 x,
610 iterations: total_iterations,
611 restarts,
612 residual: rel_residual,
613 converged: false,
614 status: SolverStatus::MaxIterationsReached,
615 }
616}
617
618#[inline]
620fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
621 let tol = T::Real::from_f64(1e-30).unwrap();
622 if b.norm() < tol {
623 return (T::one(), T::zero());
624 }
625 if a.norm() < tol {
626 return (T::zero(), T::one());
627 }
628
629 let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
630 let c = a * T::from_real(T::Real::one() / r);
631 let s = b * T::from_real(T::Real::one() / r);
632
633 (c, s)
634}
635
636fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
638 let mut y = vec![T::zero(); k];
639 let tol = T::Real::from_f64(1e-30).unwrap();
640
641 for i in (0..k).rev() {
642 let mut sum = g[i];
643 for j in (i + 1)..k {
644 sum -= h[[i, j]] * y[j];
645 }
646 if h[[i, i]].norm() > tol {
647 y[i] = sum * h[[i, i]].inv();
648 }
649 }
650
651 y
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Builds a solver configuration with progress logging disabled.
    fn quiet_config(max_iterations: usize, restart: usize, tolerance: f64) -> GmresConfig<f64> {
        GmresConfig {
            max_iterations,
            restart,
            tolerance,
            print_interval: 0,
        }
    }

    /// Euclidean norm of a complex vector.
    fn complex_norm(values: &Array1<Complex64>) -> f64 {
        values.iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt()
    }

    #[test]
    fn test_gmres_simple() {
        let c = |re: f64| Complex64::new(re, 0.0);
        let dense = array![[c(4.0), c(1.0)], [c(1.0), c(3.0)]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![c(1.0), c(2.0)];

        let solution = gmres(&a, &b, &quiet_config(100, 10, 1e-10));
        assert!(solution.converged, "GMRES should converge");

        let residual = &a.matvec(&solution.x) - &b;
        assert!(complex_norm(&residual) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_gmres_identity() {
        let n = 5;
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(n);
        let b: Array1<Complex64> = (1..=n).map(|i| Complex64::new(i as f64, 0.0)).collect();

        let solution = gmres(&id, &b, &quiet_config(10, 10, 1e-12));

        assert!(solution.converged);
        assert!(solution.iterations <= 2);
        assert!(complex_norm(&(&solution.x - &b)) < 1e-10);
    }

    #[test]
    fn test_gmres_f64() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        let solution = gmres(&a, &b, &quiet_config(100, 10, 1e-10));
        assert!(solution.converged);

        let error: f64 = (&a.matvec(&solution.x) - &b)
            .iter()
            .map(|e| e * e)
            .sum::<f64>()
            .sqrt();
        assert_relative_eq!(error, 0.0, epsilon = 1e-8);
    }
}