1use ndarray::{Array1, Array2};
27use num_complex::Complex64;
28
/// Configuration for the restarted GMRES solvers [`gmres_solve`] and
/// [`gmres_solve_preconditioned`].
#[derive(Debug, Clone, PartialEq)]
pub struct GmresConfig {
    /// Maximum number of restart cycles (outer iterations). Each cycle performs up to
    /// `restart` Arnoldi steps, so the total number of operator applications can reach
    /// roughly `max_iterations * restart`.
    pub max_iterations: usize,
    /// Krylov subspace dimension per cycle (the `m` in GMRES(m)); it also sizes the
    /// `(m + 1) x m` Hessenberg matrix allocated for each cycle.
    pub restart: usize,
    /// Relative residual threshold: iteration stops once `||r|| / ||b||` (or its
    /// preconditioned analogue) drops below this value.
    pub tolerance: f64,
    /// Print a progress line to stderr every `print_interval` inner iterations;
    /// `0` disables progress output.
    pub print_interval: usize,
}
42
43impl Default for GmresConfig {
44 fn default() -> Self {
45 Self {
46 max_iterations: 100,
47 restart: 30, tolerance: 1e-6,
49 print_interval: 10,
50 }
51 }
52}
53
54impl GmresConfig {
55 pub fn for_small_problems() -> Self {
57 Self {
58 max_iterations: 50,
59 restart: 50,
60 tolerance: 1e-8,
61 print_interval: 0,
62 }
63 }
64
65 pub fn for_large_bem() -> Self {
67 Self {
68 max_iterations: 200,
69 restart: 100,
70 tolerance: 1e-6,
71 print_interval: 20,
72 }
73 }
74
75 pub fn with_restart(restart: usize) -> Self {
77 Self {
78 restart,
79 ..Default::default()
80 }
81 }
82}
83
84#[derive(Debug)]
86pub struct GmresSolution {
87 pub x: Array1<Complex64>,
89 pub iterations: usize,
91 pub restarts: usize,
93 pub residual: f64,
95 pub converged: bool,
97}
98
99pub fn gmres_solve<F>(
117 matvec: F,
118 b: &Array1<Complex64>,
119 x0: Option<&Array1<Complex64>>,
120 config: &GmresConfig,
121) -> GmresSolution
122where
123 F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
124{
125 let n = b.len();
126 let m = config.restart;
127
128 let mut x = match x0 {
130 Some(x0) => x0.clone(),
131 None => Array1::zeros(n),
132 };
133
134 let b_norm = vector_norm(b);
136 if b_norm < 1e-15 {
137 return GmresSolution {
138 x,
139 iterations: 0,
140 restarts: 0,
141 residual: 0.0,
142 converged: true,
143 };
144 }
145
146 let mut total_iterations = 0;
147 let mut restarts = 0;
148
149 for _outer in 0..config.max_iterations {
151 let ax = matvec(&x);
153 let r: Array1<Complex64> = b - &ax;
154 let beta = vector_norm(&r);
155
156 let rel_residual = beta / b_norm;
158 if rel_residual < config.tolerance {
159 return GmresSolution {
160 x,
161 iterations: total_iterations,
162 restarts,
163 residual: rel_residual,
164 converged: true,
165 };
166 }
167
168 let mut v: Vec<Array1<Complex64>> = Vec::with_capacity(m + 1);
171 v.push(&r / Complex64::new(beta, 0.0));
172
173 let mut h: Array2<Complex64> = Array2::zeros((m + 1, m));
175
176 let mut cs: Vec<Complex64> = Vec::with_capacity(m);
178 let mut sn: Vec<Complex64> = Vec::with_capacity(m);
179
180 let mut g: Array1<Complex64> = Array1::zeros(m + 1);
182 g[0] = Complex64::new(beta, 0.0);
183
184 let mut inner_converged = false;
185
186 for j in 0..m {
188 total_iterations += 1;
189
190 let w = matvec(&v[j]);
192 let mut w = w;
193
194 for i in 0..=j {
196 h[[i, j]] = inner_product(&v[i], &w);
197 w = &w - &(&v[i] * h[[i, j]]);
198 }
199
200 h[[j + 1, j]] = Complex64::new(vector_norm(&w), 0.0);
201
202 if h[[j + 1, j]].norm() < 1e-14 {
204 inner_converged = true;
206 } else {
207 v.push(&w / h[[j + 1, j]]);
209 }
210
211 for i in 0..j {
213 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
214 h[[i + 1, j]] = -sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
215 h[[i, j]] = temp;
216 }
217
218 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
220 cs.push(c);
221 sn.push(s);
222
223 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
225 h[[j + 1, j]] = Complex64::new(0.0, 0.0);
226
227 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
228 g[j + 1] = -s * g[j] + c * g[j + 1];
229 g[j] = temp;
230
231 let rel_residual = g[j + 1].norm() / b_norm;
233
234 if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
235 eprintln!(
236 "GMRES iteration {} (restart {}): relative residual = {:.6e}",
237 total_iterations, restarts, rel_residual
238 );
239 }
240
241 if rel_residual < config.tolerance || inner_converged {
242 let y = solve_upper_triangular(&h, &g, j + 1);
244
245 for (i, yi) in y.iter().enumerate() {
247 x = &x + &(&v[i] * *yi);
248 }
249
250 return GmresSolution {
251 x,
252 iterations: total_iterations,
253 restarts,
254 residual: rel_residual,
255 converged: true,
256 };
257 }
258 }
259
260 let y = solve_upper_triangular(&h, &g, m);
262
263 for (i, yi) in y.iter().enumerate() {
265 x = &x + &(&v[i] * *yi);
266 }
267
268 restarts += 1;
269 }
270
271 let ax = matvec(&x);
273 let r: Array1<Complex64> = b - &ax;
274 let rel_residual = vector_norm(&r) / b_norm;
275
276 GmresSolution {
277 x,
278 iterations: total_iterations,
279 restarts,
280 residual: rel_residual,
281 converged: false,
282 }
283}
284
285pub fn gmres_solve_preconditioned<F, P>(
289 matvec: F,
290 precond_solve: P,
291 b: &Array1<Complex64>,
292 x0: Option<&Array1<Complex64>>,
293 config: &GmresConfig,
294) -> GmresSolution
295where
296 F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
297 P: Fn(&Array1<Complex64>) -> Array1<Complex64>,
298{
299 let n = b.len();
300 let m = config.restart;
301
302 let mut x = match x0 {
304 Some(x0) => x0.clone(),
305 None => Array1::zeros(n),
306 };
307
308 let pb = precond_solve(b);
310 let b_norm = vector_norm(&pb);
311 if b_norm < 1e-15 {
312 return GmresSolution {
313 x,
314 iterations: 0,
315 restarts: 0,
316 residual: 0.0,
317 converged: true,
318 };
319 }
320
321 let mut total_iterations = 0;
322 let mut restarts = 0;
323
324 for _outer in 0..config.max_iterations {
325 let ax = matvec(&x);
327 let residual: Array1<Complex64> = b - &ax;
328 let r = precond_solve(&residual);
329 let beta = vector_norm(&r);
330
331 let rel_residual = beta / b_norm;
332 if rel_residual < config.tolerance {
333 return GmresSolution {
334 x,
335 iterations: total_iterations,
336 restarts,
337 residual: rel_residual,
338 converged: true,
339 };
340 }
341
342 let mut v: Vec<Array1<Complex64>> = Vec::with_capacity(m + 1);
343 v.push(&r / Complex64::new(beta, 0.0));
344
345 let mut h: Array2<Complex64> = Array2::zeros((m + 1, m));
346 let mut cs: Vec<Complex64> = Vec::with_capacity(m);
347 let mut sn: Vec<Complex64> = Vec::with_capacity(m);
348
349 let mut g: Array1<Complex64> = Array1::zeros(m + 1);
350 g[0] = Complex64::new(beta, 0.0);
351
352 let mut inner_converged = false;
353
354 for j in 0..m {
355 total_iterations += 1;
356
357 let av = matvec(&v[j]);
359 let w = precond_solve(&av);
360 let mut w = w;
361
362 for i in 0..=j {
364 h[[i, j]] = inner_product(&v[i], &w);
365 w = &w - &(&v[i] * h[[i, j]]);
366 }
367
368 h[[j + 1, j]] = Complex64::new(vector_norm(&w), 0.0);
369
370 if h[[j + 1, j]].norm() < 1e-14 {
371 inner_converged = true;
372 } else {
373 v.push(&w / h[[j + 1, j]]);
374 }
375
376 for i in 0..j {
378 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
379 h[[i + 1, j]] = -sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
380 h[[i, j]] = temp;
381 }
382
383 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
385 cs.push(c);
386 sn.push(s);
387
388 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
390 h[[j + 1, j]] = Complex64::new(0.0, 0.0);
391
392 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
393 g[j + 1] = -s * g[j] + c * g[j + 1];
394 g[j] = temp;
395
396 let rel_residual = g[j + 1].norm() / b_norm;
397
398 if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
399 eprintln!(
400 "GMRES (precond) iteration {} (restart {}): relative residual = {:.6e}",
401 total_iterations, restarts, rel_residual
402 );
403 }
404
405 if rel_residual < config.tolerance || inner_converged {
406 let y = solve_upper_triangular(&h, &g, j + 1);
407
408 for (i, yi) in y.iter().enumerate() {
409 x = &x + &(&v[i] * *yi);
410 }
411
412 return GmresSolution {
413 x,
414 iterations: total_iterations,
415 restarts,
416 residual: rel_residual,
417 converged: true,
418 };
419 }
420 }
421
422 let y = solve_upper_triangular(&h, &g, m);
424 for (i, yi) in y.iter().enumerate() {
425 x = &x + &(&v[i] * *yi);
426 }
427
428 restarts += 1;
429 }
430
431 let ax = matvec(&x);
433 let residual: Array1<Complex64> = b - &ax;
434 let r = precond_solve(&residual);
435 let rel_residual = vector_norm(&r) / b_norm;
436
437 GmresSolution {
438 x,
439 iterations: total_iterations,
440 restarts,
441 residual: rel_residual,
442 converged: false,
443 }
444}
445
446#[inline]
448fn inner_product(x: &Array1<Complex64>, y: &Array1<Complex64>) -> Complex64 {
449 x.iter()
450 .zip(y.iter())
451 .map(|(xi, yi)| xi.conj() * yi)
452 .sum()
453}
454
455#[inline]
457fn vector_norm(x: &Array1<Complex64>) -> f64 {
458 x.iter().map(|xi| xi.norm_sqr()).sum::<f64>().sqrt()
459}
460
461#[inline]
467fn givens_rotation(a: Complex64, b: Complex64) -> (Complex64, Complex64) {
468 if b.norm() < 1e-30 {
469 return (Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0));
470 }
471 if a.norm() < 1e-30 {
472 return (Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0));
473 }
474
475 let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
476 let c = a / Complex64::new(r, 0.0);
477 let s = b / Complex64::new(r, 0.0);
478
479 (c, s)
480}
481
482fn solve_upper_triangular(h: &Array2<Complex64>, g: &Array1<Complex64>, k: usize) -> Vec<Complex64> {
486 let mut y = vec![Complex64::new(0.0, 0.0); k];
487
488 for i in (0..k).rev() {
489 let mut sum = g[i];
490 for j in (i + 1)..k {
491 sum -= h[[i, j]] * y[j];
492 }
493 if h[[i, i]].norm() > 1e-30 {
494 y[i] = sum / h[[i, i]];
495 }
496 }
497
498 y
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration with progress output disabled.
    fn quiet_config(max_iterations: usize, restart: usize, tolerance: f64) -> GmresConfig {
        GmresConfig {
            max_iterations,
            restart,
            tolerance,
            print_interval: 0,
        }
    }

    /// Euclidean norm of `A x - b`.
    fn residual_norm(
        a: &Array2<Complex64>,
        x: &Array1<Complex64>,
        b: &Array1<Complex64>,
    ) -> f64 {
        let ax = a.dot(x);
        (&ax - b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt()
    }

    /// Euclidean norm of a vector (test-local, independent of the solver's helper).
    fn norm(v: &Array1<Complex64>) -> f64 {
        v.iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt()
    }

    /// `n x n` tridiagonal, diagonally dominant, non-Hermitian matrix.
    fn tridiagonal(n: usize) -> Array2<Complex64> {
        let mut a = Array2::zeros((n, n));
        for i in 0..n {
            a[[i, i]] = Complex64::new(4.0, 0.0);
            if i > 0 {
                a[[i, i - 1]] = Complex64::new(-1.0, 0.1);
            }
            if i < n - 1 {
                a[[i, i + 1]] = Complex64::new(-1.0, -0.1);
            }
        }
        a
    }

    /// Smooth real right-hand side `b_i = sin(0.3 i)`.
    fn sine_rhs(n: usize) -> Array1<Complex64> {
        Array1::from_iter((0..n).map(|i| Complex64::new((i as f64 * 0.3).sin(), 0.0)))
    }

    #[test]
    fn test_gmres_simple() {
        let a = Array2::from_shape_vec(
            (2, 2),
            vec![
                Complex64::new(4.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(3.0, 0.0),
            ],
        )
        .unwrap();
        let b = Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = gmres_solve(&matvec, &b, None, &quiet_config(100, 10, 1e-10));

        assert!(solution.converged, "GMRES should converge");
        assert!(residual_norm(&a, &solution.x, &b) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_gmres_complex() {
        let a = Array2::from_shape_vec(
            (2, 2),
            vec![
                Complex64::new(2.0, 1.0),
                Complex64::new(0.0, -1.0),
                Complex64::new(0.0, 1.0),
                Complex64::new(2.0, -1.0),
            ],
        )
        .unwrap();
        let b = Array1::from_vec(vec![Complex64::new(1.0, 1.0), Complex64::new(1.0, -1.0)]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = gmres_solve(&matvec, &b, None, &quiet_config(100, 10, 1e-10));

        assert!(solution.converged, "GMRES should converge for complex system");
        assert!(residual_norm(&a, &solution.x, &b) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_gmres_identity() {
        let n = 5;
        let b = Array1::from_vec(
            (1..=n)
                .map(|i| Complex64::new(i as f64, 0.0))
                .collect::<Vec<_>>(),
        );
        let matvec = |x: &Array1<Complex64>| x.clone();

        let solution = gmres_solve(&matvec, &b, None, &quiet_config(10, 10, 1e-12));

        assert!(solution.converged);
        assert!(solution.iterations <= 2);
        assert!(norm(&(&solution.x - &b)) < 1e-10);
    }

    #[test]
    fn test_gmres_non_symmetric() {
        let a = Array2::from_shape_vec(
            (3, 3),
            vec![
                Complex64::new(4.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(5.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(3.0, 0.0),
            ],
        )
        .unwrap();
        let b = Array1::from_vec(vec![
            Complex64::new(5.0, 0.0),
            Complex64::new(8.0, 0.0),
            Complex64::new(4.0, 0.0),
        ]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = gmres_solve(&matvec, &b, None, &quiet_config(100, 10, 1e-10));

        assert!(
            solution.converged,
            "GMRES should converge for non-symmetric system"
        );
        assert!(residual_norm(&a, &solution.x, &b) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_gmres_restart() {
        let n = 20;
        let a = tridiagonal(n);
        let b = sine_rhs(n);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        // A short restart length forces several restart cycles.
        let solution = gmres_solve(&matvec, &b, None, &quiet_config(50, 5, 1e-10));

        println!(
            "GMRES: {} iterations, {} restarts, residual = {:.6e}",
            solution.iterations, solution.restarts, solution.residual
        );

        assert!(solution.converged, "GMRES should converge even with restarts");
        let rel_error = residual_norm(&a, &solution.x, &b) / norm(&b);
        assert!(rel_error < 1e-8, "Solution should be accurate");
    }

    #[test]
    fn test_gmres_zero_rhs() {
        let b = Array1::<Complex64>::zeros(3);
        let matvec = |x: &Array1<Complex64>| x.clone();

        let solution = gmres_solve(&matvec, &b, None, &quiet_config(10, 5, 1e-10));

        assert!(solution.converged);
        assert_eq!(solution.iterations, 0);
        assert!(solution.x.iter().all(|xi| xi.norm() == 0.0));
    }

    #[test]
    fn test_gmres_preconditioned_jacobi() {
        let n = 20;
        let a = tridiagonal(n);
        let b = sine_rhs(n);
        let diag = a.diag().to_owned();
        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let jacobi = |r: &Array1<Complex64>| {
            Array1::from_iter(r.iter().zip(diag.iter()).map(|(ri, di)| ri / di))
        };

        let solution =
            gmres_solve_preconditioned(&matvec, &jacobi, &b, None, &quiet_config(50, 5, 1e-10));

        assert!(solution.converged, "Preconditioned GMRES should converge");
        let rel_error = residual_norm(&a, &solution.x, &b) / norm(&b);
        assert!(rel_error < 1e-8, "Preconditioned solution should be accurate");
    }

    #[test]
    fn test_gmres_reports_non_convergence() {
        let n = 20;
        let a = tridiagonal(n);
        let b = sine_rhs(n);
        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let config = quiet_config(1, 1, 1e-12);

        // A single GMRES(1) step cannot reach 1e-12 on this system.
        let solution = gmres_solve(&matvec, &b, None, &config);

        assert!(!solution.converged);
        assert_eq!(solution.iterations, 1);
        assert_eq!(solution.restarts, 1);
        assert!(solution.residual >= config.tolerance);
    }

    #[test]
    fn test_gmres_config_builders() {
        let small = GmresConfig::for_small_problems();
        assert_eq!(small.restart, 50);

        let large = GmresConfig::for_large_bem();
        assert_eq!(large.restart, 100);

        let custom = GmresConfig::with_restart(75);
        assert_eq!(custom.restart, 75);
        assert_eq!(custom.max_iterations, GmresConfig::default().max_iterations);
    }
}