1use crate::traits::{ComplexField, LinearOperator, Preconditioner};
10use ndarray::{Array1, Array2};
11use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
12
/// Configuration shared by the restarted GMRES solvers in this module.
///
/// `R` is the real scalar type used for tolerances (for example `f64` when solving
/// `Complex64` or `f64` systems).
#[derive(Debug, Clone, PartialEq)]
pub struct GmresConfig<R> {
    /// Maximum number of restart cycles (outer iterations). Each cycle performs up to
    /// `restart` Arnoldi steps, so this is *not* a bound on the total step count.
    pub max_iterations: usize,
    /// Krylov subspace dimension per cycle, i.e. the `m` in GMRES(m).
    pub restart: usize,
    /// Relative residual tolerance used as the stopping criterion.
    pub tolerance: R,
    /// Progress-logging interval in Arnoldi steps; `0` disables logging.
    pub print_interval: usize,
}
25
26impl Default for GmresConfig<f64> {
27 fn default() -> Self {
28 Self {
29 max_iterations: 100,
30 restart: 30,
31 tolerance: 1e-6,
32 print_interval: 0,
33 }
34 }
35}
36
37impl Default for GmresConfig<f32> {
38 fn default() -> Self {
39 Self {
40 max_iterations: 100,
41 restart: 30,
42 tolerance: 1e-5,
43 print_interval: 0,
44 }
45 }
46}
47
48impl<R: Float + FromPrimitive> GmresConfig<R> {
49 pub fn for_small_problems() -> Self {
51 Self {
52 max_iterations: 50,
53 restart: 50,
54 tolerance: R::from_f64(1e-8).unwrap(),
55 print_interval: 0,
56 }
57 }
58
59 pub fn with_restart(restart: usize) -> Self
61 where
62 Self: Default,
63 {
64 Self {
65 restart,
66 ..Default::default()
67 }
68 }
69}
70
71#[derive(Debug)]
73pub struct GmresSolution<T: ComplexField> {
74 pub x: Array1<T>,
76 pub iterations: usize,
78 pub restarts: usize,
80 pub residual: T::Real,
82 pub converged: bool,
84}
85
86pub fn gmres<T, A>(operator: &A, b: &Array1<T>, config: &GmresConfig<T::Real>) -> GmresSolution<T>
96where
97 T: ComplexField,
98 A: LinearOperator<T>,
99{
100 gmres_with_guess(operator, b, None, config)
101}
102
103pub fn gmres_with_guess<T, A>(
105 operator: &A,
106 b: &Array1<T>,
107 x0: Option<&Array1<T>>,
108 config: &GmresConfig<T::Real>,
109) -> GmresSolution<T>
110where
111 T: ComplexField,
112 A: LinearOperator<T>,
113{
114 let n = b.len();
115 let m = config.restart;
116
117 let mut x = match x0 {
119 Some(x0) => x0.clone(),
120 None => Array1::from_elem(n, T::zero()),
121 };
122
123 let b_norm = vector_norm(b);
125 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
126 if b_norm < tol_threshold {
127 return GmresSolution {
128 x,
129 iterations: 0,
130 restarts: 0,
131 residual: T::Real::zero(),
132 converged: true,
133 };
134 }
135
136 let mut total_iterations = 0;
137 let mut restarts = 0;
138
139 for _outer in 0..config.max_iterations {
141 let ax = operator.apply(&x);
143 let r: Array1<T> = b - &ax;
144 let beta = vector_norm(&r);
145
146 let rel_residual = beta / b_norm;
148 if rel_residual < config.tolerance {
149 return GmresSolution {
150 x,
151 iterations: total_iterations,
152 restarts,
153 residual: rel_residual,
154 converged: true,
155 };
156 }
157
158 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
160 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
161
162 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
164
165 let mut cs: Vec<T> = Vec::with_capacity(m);
167 let mut sn: Vec<T> = Vec::with_capacity(m);
168
169 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
171 g[0] = T::from_real(beta);
172
173 let mut inner_converged = false;
174
175 for j in 0..m {
177 total_iterations += 1;
178
179 let mut w = operator.apply(&v[j]);
181
182 for i in 0..=j {
184 h[[i, j]] = inner_product(&v[i], &w);
185 let h_ij = h[[i, j]];
186 w = &w - &v[i].mapv(|vi| vi * h_ij);
187 }
188
189 let w_norm = vector_norm(&w);
190 h[[j + 1, j]] = T::from_real(w_norm);
191
192 let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
194 if w_norm < breakdown_tol {
195 inner_converged = true;
196 } else {
197 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
198 }
199
200 for i in 0..j {
202 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
203 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
204 h[[i, j]] = temp;
205 }
206
207 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
209 cs.push(c);
210 sn.push(s);
211
212 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
214 h[[j + 1, j]] = T::zero();
215
216 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
217 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
218 g[j] = temp;
219
220 let rel_residual = g[j + 1].norm() / b_norm;
222
223 if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
224 log::info!(
225 "GMRES iteration {} (restart {}): relative residual = {:.6e}",
226 total_iterations,
227 restarts,
228 rel_residual.to_f64().unwrap_or(0.0)
229 );
230 }
231
232 if rel_residual < config.tolerance || inner_converged {
233 let y = solve_upper_triangular(&h, &g, j + 1);
235
236 for (i, &yi) in y.iter().enumerate() {
238 x = &x + &v[i].mapv(|vi| vi * yi);
239 }
240
241 return GmresSolution {
242 x,
243 iterations: total_iterations,
244 restarts,
245 residual: rel_residual,
246 converged: true,
247 };
248 }
249 }
250
251 let y = solve_upper_triangular(&h, &g, m);
253
254 for (i, &yi) in y.iter().enumerate() {
255 x = &x + &v[i].mapv(|vi| vi * yi);
256 }
257
258 restarts += 1;
259 }
260
261 let ax = operator.apply(&x);
263 let r: Array1<T> = b - &ax;
264 let rel_residual = vector_norm(&r) / b_norm;
265
266 GmresSolution {
267 x,
268 iterations: total_iterations,
269 restarts,
270 residual: rel_residual,
271 converged: false,
272 }
273}
274
275pub fn gmres_preconditioned<T, A, P>(
279 operator: &A,
280 precond: &P,
281 b: &Array1<T>,
282 config: &GmresConfig<T::Real>,
283) -> GmresSolution<T>
284where
285 T: ComplexField,
286 A: LinearOperator<T>,
287 P: Preconditioner<T>,
288{
289 let n = b.len();
290 let m = config.restart;
291
292 let mut x = Array1::from_elem(n, T::zero());
293
294 let pb = precond.apply(b);
296 let b_norm = vector_norm(&pb);
297 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
298 if b_norm < tol_threshold {
299 return GmresSolution {
300 x,
301 iterations: 0,
302 restarts: 0,
303 residual: T::Real::zero(),
304 converged: true,
305 };
306 }
307
308 let mut total_iterations = 0;
309 let mut restarts = 0;
310
311 for _outer in 0..config.max_iterations {
312 let ax = operator.apply(&x);
314 let residual: Array1<T> = b - &ax;
315 let r = precond.apply(&residual);
316 let beta = vector_norm(&r);
317
318 let rel_residual = beta / b_norm;
319 if rel_residual < config.tolerance {
320 return GmresSolution {
321 x,
322 iterations: total_iterations,
323 restarts,
324 residual: rel_residual,
325 converged: true,
326 };
327 }
328
329 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
330 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
331
332 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
333 let mut cs: Vec<T> = Vec::with_capacity(m);
334 let mut sn: Vec<T> = Vec::with_capacity(m);
335
336 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
337 g[0] = T::from_real(beta);
338
339 let mut inner_converged = false;
340
341 for j in 0..m {
342 total_iterations += 1;
343
344 let av = operator.apply(&v[j]);
346 let mut w = precond.apply(&av);
347
348 for i in 0..=j {
350 h[[i, j]] = inner_product(&v[i], &w);
351 let h_ij = h[[i, j]];
352 w = &w - &v[i].mapv(|vi| vi * h_ij);
353 }
354
355 let w_norm = vector_norm(&w);
356 h[[j + 1, j]] = T::from_real(w_norm);
357
358 let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
359 if w_norm < breakdown_tol {
360 inner_converged = true;
361 } else {
362 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
363 }
364
365 for i in 0..j {
367 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
368 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
369 h[[i, j]] = temp;
370 }
371
372 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
373 cs.push(c);
374 sn.push(s);
375
376 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
377 h[[j + 1, j]] = T::zero();
378
379 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
380 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
381 g[j] = temp;
382
383 let rel_residual = g[j + 1].norm() / b_norm;
384
385 if rel_residual < config.tolerance || inner_converged {
386 let y = solve_upper_triangular(&h, &g, j + 1);
387
388 for (i, &yi) in y.iter().enumerate() {
389 x = &x + &v[i].mapv(|vi| vi * yi);
390 }
391
392 return GmresSolution {
393 x,
394 iterations: total_iterations,
395 restarts,
396 residual: rel_residual,
397 converged: true,
398 };
399 }
400 }
401
402 let y = solve_upper_triangular(&h, &g, m);
404 for (i, &yi) in y.iter().enumerate() {
405 x = &x + &v[i].mapv(|vi| vi * yi);
406 }
407
408 restarts += 1;
409 }
410
411 let ax = operator.apply(&x);
413 let residual: Array1<T> = b - &ax;
414 let r = precond.apply(&residual);
415 let rel_residual = vector_norm(&r) / b_norm;
416
417 GmresSolution {
418 x,
419 iterations: total_iterations,
420 restarts,
421 residual: rel_residual,
422 converged: false,
423 }
424}
425
426pub fn gmres_preconditioned_with_guess<T, A, P>(
431 operator: &A,
432 precond: &P,
433 b: &Array1<T>,
434 x0: Option<&Array1<T>>,
435 config: &GmresConfig<T::Real>,
436) -> GmresSolution<T>
437where
438 T: ComplexField,
439 A: LinearOperator<T>,
440 P: Preconditioner<T>,
441{
442 let n = b.len();
443 let m = config.restart;
444
445 let mut x = match x0 {
447 Some(guess) => guess.clone(),
448 None => Array1::from_elem(n, T::zero()),
449 };
450
451 let pb = precond.apply(b);
453 let b_norm = vector_norm(&pb);
454 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
455 if b_norm < tol_threshold {
456 return GmresSolution {
457 x,
458 iterations: 0,
459 restarts: 0,
460 residual: T::Real::zero(),
461 converged: true,
462 };
463 }
464
465 let mut total_iterations = 0;
466 let mut restarts = 0;
467
468 for _outer in 0..config.max_iterations {
469 let ax = operator.apply(&x);
471 let residual: Array1<T> = b - &ax;
472 let r = precond.apply(&residual);
473 let beta = vector_norm(&r);
474
475 let rel_residual = beta / b_norm;
476 if rel_residual < config.tolerance {
477 return GmresSolution {
478 x,
479 iterations: total_iterations,
480 restarts,
481 residual: rel_residual,
482 converged: true,
483 };
484 }
485
486 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
487 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
488
489 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
490 let mut cs: Vec<T> = Vec::with_capacity(m);
491 let mut sn: Vec<T> = Vec::with_capacity(m);
492
493 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
494 g[0] = T::from_real(beta);
495
496 let mut inner_converged = false;
497
498 for j in 0..m {
499 total_iterations += 1;
500
501 let av = operator.apply(&v[j]);
503 let mut w = precond.apply(&av);
504
505 for i in 0..=j {
507 h[[i, j]] = inner_product(&v[i], &w);
508 let h_ij = h[[i, j]];
509 w = &w - &v[i].mapv(|vi| vi * h_ij);
510 }
511
512 let w_norm = vector_norm(&w);
513 h[[j + 1, j]] = T::from_real(w_norm);
514
515 let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
516 if w_norm < breakdown_tol {
517 inner_converged = true;
518 } else {
519 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
520 }
521
522 for i in 0..j {
524 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
525 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
526 h[[i, j]] = temp;
527 }
528
529 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
530 cs.push(c);
531 sn.push(s);
532
533 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
534 h[[j + 1, j]] = T::zero();
535
536 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
537 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
538 g[j] = temp;
539
540 let rel_residual = g[j + 1].norm() / b_norm;
541
542 if rel_residual < config.tolerance || inner_converged {
543 let y = solve_upper_triangular(&h, &g, j + 1);
544
545 for (i, &yi) in y.iter().enumerate() {
546 x = &x + &v[i].mapv(|vi| vi * yi);
547 }
548
549 return GmresSolution {
550 x,
551 iterations: total_iterations,
552 restarts,
553 residual: rel_residual,
554 converged: true,
555 };
556 }
557 }
558
559 let y = solve_upper_triangular(&h, &g, m);
561 for (i, &yi) in y.iter().enumerate() {
562 x = &x + &v[i].mapv(|vi| vi * yi);
563 }
564
565 restarts += 1;
566 }
567
568 let ax = operator.apply(&x);
570 let residual: Array1<T> = b - &ax;
571 let r = precond.apply(&residual);
572 let rel_residual = vector_norm(&r) / b_norm;
573
574 GmresSolution {
575 x,
576 iterations: total_iterations,
577 restarts,
578 residual: rel_residual,
579 converged: false,
580 }
581}
582
583#[inline]
585fn inner_product<T: ComplexField>(x: &Array1<T>, y: &Array1<T>) -> T {
586 x.iter()
587 .zip(y.iter())
588 .fold(T::zero(), |acc, (&xi, &yi)| acc + xi.conj() * yi)
589}
590
591#[inline]
593fn vector_norm<T: ComplexField>(x: &Array1<T>) -> T::Real {
594 x.iter()
595 .map(|xi| xi.norm_sqr())
596 .fold(T::Real::zero(), |acc, v| acc + v)
597 .sqrt()
598}
599
600#[inline]
602fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
603 let tol = T::Real::from_f64(1e-30).unwrap();
604 if b.norm() < tol {
605 return (T::one(), T::zero());
606 }
607 if a.norm() < tol {
608 return (T::zero(), T::one());
609 }
610
611 let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
612 let c = a * T::from_real(T::Real::one() / r);
613 let s = b * T::from_real(T::Real::one() / r);
614
615 (c, s)
616}
617
618fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
620 let mut y = vec![T::zero(); k];
621 let tol = T::Real::from_f64(1e-30).unwrap();
622
623 for i in (0..k).rev() {
624 let mut sum = g[i];
625 for j in (i + 1)..k {
626 sum -= h[[i, j]] * y[j];
627 }
628 if h[[i, i]].norm() > tol {
629 y[i] = sum * h[[i, i]].inv();
630 }
631 }
632
633 y
634}
635
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Builds a solver configuration with progress logging disabled.
    fn make_config(max_iterations: usize, restart: usize, tolerance: f64) -> GmresConfig<f64> {
        GmresConfig {
            max_iterations,
            restart,
            tolerance,
            print_interval: 0,
        }
    }

    #[test]
    fn test_gmres_simple() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let solution = gmres(&a, &b, &make_config(100, 10, 1e-10));

        assert!(solution.converged, "GMRES should converge");

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_gmres_identity() {
        let n = 5;
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(n);
        let b = Array1::from_iter((1..=n).map(|i| Complex64::new(i as f64, 0.0)));

        let solution = gmres(&id, &b, &make_config(10, 10, 1e-12));

        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        let error: f64 = (&solution.x - &b)
            .iter()
            .map(|e| e.norm_sqr())
            .sum::<f64>()
            .sqrt();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_gmres_f64() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0],];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        let solution = gmres(&a, &b, &make_config(100, 10, 1e-10));

        assert!(solution.converged);

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e * e).sum::<f64>().sqrt();
        assert_relative_eq!(error, 0.0, epsilon = 1e-8);
    }

    #[test]
    fn test_gmres_with_exact_guess() {
        // [[4, 1], [1, 3]] * [1/11, 7/11] = [1, 2]
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];
        let x_exact = array![1.0_f64 / 11.0, 7.0 / 11.0];

        let solution = gmres_with_guess(&a, &b, Some(&x_exact), &make_config(100, 10, 1e-10));

        // An exact initial guess passes the first residual check with no Arnoldi steps.
        assert!(solution.converged);
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.restarts, 0);
    }

    #[test]
    fn test_gmres_zero_rhs() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![0.0_f64, 0.0];

        let solution = gmres(&a, &b, &make_config(100, 10, 1e-10));

        assert!(solution.converged);
        assert_eq!(solution.iterations, 0);
        assert!(solution.x.iter().all(|&xi| xi == 0.0));
    }
}