// bem/core/solver/gmres.rs

1//! GMRES (Generalized Minimal Residual) solver
2//!
3//! Implementation of the restarted GMRES algorithm based on Saad & Schultz (1986).
4//!
5//! GMRES is often the best choice for large non-symmetric systems like BEM.
6//! It minimizes the residual in a Krylov subspace and has smooth, monotonic
7//! convergence behavior.
8//!
9//! ## Algorithm
10//!
//! GMRES builds an orthonormal basis for the Krylov subspace
//! K_m = span{r, Ar, A²r, ..., A^(m-1)r} using the Arnoldi process, then finds the
//! update that minimizes ||b - Ax|| over this subspace via a QR factorization of the
//! resulting upper Hessenberg matrix (computed incrementally with Givens rotations).
14//!
15//! ## Restart
16//!
17//! Full GMRES requires storing m vectors where m is the number of iterations.
18//! For large problems, this becomes prohibitive. Restarted GMRES(m) restarts
19//! after m iterations to limit memory usage.
20//!
21//! Typical values:
22//! - m = 20-50 for moderate problems
23//! - m = 50-100 for larger problems
24//! - m = 100-200 for very large BEM problems
25
26use ndarray::{Array1, Array2};
27use num_complex::Complex64;
28
/// Tunable parameters shared by the restarted GMRES solvers in this module.
#[derive(Debug, Clone)]
pub struct GmresConfig {
    /// Upper bound on the number of outer cycles (one cycle per restart).
    pub max_iterations: usize,
    /// Number of inner iterations carried out before the solver restarts,
    /// i.e. the dimension of the Krylov subspace built in each cycle.
    pub restart: usize,
    /// Convergence threshold applied to the relative residual.
    pub tolerance: f64,
    /// Emit a progress line every N iterations; 0 disables the output.
    pub print_interval: usize,
}

impl Default for GmresConfig {
    /// General-purpose settings: GMRES(30), at most 100 cycles, tolerance 1e-6,
    /// progress reported every 10 iterations.
    fn default() -> Self {
        GmresConfig {
            max_iterations: 100,
            restart: 30,
            tolerance: 1e-6,
            print_interval: 10,
        }
    }
}

impl GmresConfig {
    /// Preset for small problems: a larger subspace (more memory) with a
    /// tighter tolerance and no progress output.
    pub fn for_small_problems() -> Self {
        GmresConfig {
            max_iterations: 50,
            restart: 50,
            tolerance: 1e-8,
            print_interval: 0,
        }
    }

    /// Preset for large BEM problems: GMRES(100) with up to 200 cycles.
    pub fn for_large_bem() -> Self {
        GmresConfig {
            max_iterations: 200,
            restart: 100,
            tolerance: 1e-6,
            print_interval: 20,
        }
    }

    /// Default settings with only the restart parameter overridden.
    pub fn with_restart(restart: usize) -> Self {
        let mut config = Self::default();
        config.restart = restart;
        config
    }
}
83
84/// GMRES solver result
85#[derive(Debug)]
86pub struct GmresSolution {
87    /// Solution vector
88    pub x: Array1<Complex64>,
89    /// Total number of matrix-vector products
90    pub iterations: usize,
91    /// Number of restarts performed
92    pub restarts: usize,
93    /// Final relative residual
94    pub residual: f64,
95    /// Whether convergence was achieved
96    pub converged: bool,
97}
98
99/// Solve Ax = b using the restarted GMRES method
100///
101/// # Arguments
102/// * `matvec` - Function to compute A*x for a given x
103/// * `b` - Right-hand side vector
104/// * `x0` - Optional initial guess (defaults to zero)
105/// * `config` - Solver configuration
106///
107/// # Returns
108/// Solution struct containing x, iteration count, and convergence info
109///
110/// # Example
111/// ```ignore
112/// let config = GmresConfig::with_restart(50);
113/// let matvec = |x: &Array1<Complex64>| system.matvec(x);
114/// let solution = gmres_solve(&matvec, &rhs, None, &config);
115/// ```
116pub fn gmres_solve<F>(
117    matvec: F,
118    b: &Array1<Complex64>,
119    x0: Option<&Array1<Complex64>>,
120    config: &GmresConfig,
121) -> GmresSolution
122where
123    F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
124{
125    let n = b.len();
126    let m = config.restart;
127
128    // Initialize solution vector
129    let mut x = match x0 {
130        Some(x0) => x0.clone(),
131        None => Array1::zeros(n),
132    };
133
134    // Compute initial residual norm for relative tolerance
135    let b_norm = vector_norm(b);
136    if b_norm < 1e-15 {
137        return GmresSolution {
138            x,
139            iterations: 0,
140            restarts: 0,
141            residual: 0.0,
142            converged: true,
143        };
144    }
145
146    let mut total_iterations = 0;
147    let mut restarts = 0;
148
149    // Outer iteration (restarts)
150    for _outer in 0..config.max_iterations {
151        // Compute residual r = b - Ax
152        let ax = matvec(&x);
153        let r: Array1<Complex64> = b - &ax;
154        let beta = vector_norm(&r);
155
156        // Check convergence
157        let rel_residual = beta / b_norm;
158        if rel_residual < config.tolerance {
159            return GmresSolution {
160                x,
161                iterations: total_iterations,
162                restarts,
163                residual: rel_residual,
164                converged: true,
165            };
166        }
167
168        // Initialize Krylov basis V (n x (m+1))
169        // V[:,0] = r / ||r||
170        let mut v: Vec<Array1<Complex64>> = Vec::with_capacity(m + 1);
171        v.push(&r / Complex64::new(beta, 0.0));
172
173        // Upper Hessenberg matrix H ((m+1) x m)
174        let mut h: Array2<Complex64> = Array2::zeros((m + 1, m));
175
176        // Givens rotation coefficients
177        let mut cs: Vec<Complex64> = Vec::with_capacity(m);
178        let mut sn: Vec<Complex64> = Vec::with_capacity(m);
179
180        // Right-hand side of least squares problem
181        let mut g: Array1<Complex64> = Array1::zeros(m + 1);
182        g[0] = Complex64::new(beta, 0.0);
183
184        let mut inner_converged = false;
185
186        // Inner iteration (Arnoldi process)
187        for j in 0..m {
188            total_iterations += 1;
189
190            // w = A * v_j
191            let w = matvec(&v[j]);
192            let mut w = w;
193
194            // Modified Gram-Schmidt orthogonalization
195            for i in 0..=j {
196                h[[i, j]] = inner_product(&v[i], &w);
197                w = &w - &(&v[i] * h[[i, j]]);
198            }
199
200            h[[j + 1, j]] = Complex64::new(vector_norm(&w), 0.0);
201
202            // Check for breakdown (lucky convergence or numerical issues)
203            if h[[j + 1, j]].norm() < 1e-14 {
204                // We can still get a solution from the current subspace
205                inner_converged = true;
206            } else {
207                // Normalize and add to basis
208                v.push(&w / h[[j + 1, j]]);
209            }
210
211            // Apply previous Givens rotations to new column of H
212            for i in 0..j {
213                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
214                h[[i + 1, j]] = -sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
215                h[[i, j]] = temp;
216            }
217
218            // Compute new Givens rotation
219            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
220            cs.push(c);
221            sn.push(s);
222
223            // Apply Givens rotation to H and g
224            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
225            h[[j + 1, j]] = Complex64::new(0.0, 0.0);
226
227            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
228            g[j + 1] = -s * g[j] + c * g[j + 1];
229            g[j] = temp;
230
231            // Check convergence
232            let rel_residual = g[j + 1].norm() / b_norm;
233
234            if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
235                eprintln!(
236                    "GMRES iteration {} (restart {}): relative residual = {:.6e}",
237                    total_iterations, restarts, rel_residual
238                );
239            }
240
241            if rel_residual < config.tolerance || inner_converged {
242                // Solve upper triangular system Hy = g
243                let y = solve_upper_triangular(&h, &g, j + 1);
244
245                // Update solution x = x + V * y
246                for (i, yi) in y.iter().enumerate() {
247                    x = &x + &(&v[i] * *yi);
248                }
249
250                return GmresSolution {
251                    x,
252                    iterations: total_iterations,
253                    restarts,
254                    residual: rel_residual,
255                    converged: true,
256                };
257            }
258        }
259
260        // Maximum inner iterations reached, compute solution and restart
261        let y = solve_upper_triangular(&h, &g, m);
262
263        // Update solution x = x + V * y
264        for (i, yi) in y.iter().enumerate() {
265            x = &x + &(&v[i] * *yi);
266        }
267
268        restarts += 1;
269    }
270
271    // Compute final residual
272    let ax = matvec(&x);
273    let r: Array1<Complex64> = b - &ax;
274    let rel_residual = vector_norm(&r) / b_norm;
275
276    GmresSolution {
277        x,
278        iterations: total_iterations,
279        restarts,
280        residual: rel_residual,
281        converged: false,
282    }
283}
284
285/// GMRES solver with preconditioner
286///
287/// Solves Ax = b using left preconditioning: M⁻¹Ax = M⁻¹b
288pub fn gmres_solve_preconditioned<F, P>(
289    matvec: F,
290    precond_solve: P,
291    b: &Array1<Complex64>,
292    x0: Option<&Array1<Complex64>>,
293    config: &GmresConfig,
294) -> GmresSolution
295where
296    F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
297    P: Fn(&Array1<Complex64>) -> Array1<Complex64>,
298{
299    let n = b.len();
300    let m = config.restart;
301
302    // Initialize solution vector
303    let mut x = match x0 {
304        Some(x0) => x0.clone(),
305        None => Array1::zeros(n),
306    };
307
308    // Compute preconditioned RHS norm
309    let pb = precond_solve(b);
310    let b_norm = vector_norm(&pb);
311    if b_norm < 1e-15 {
312        return GmresSolution {
313            x,
314            iterations: 0,
315            restarts: 0,
316            residual: 0.0,
317            converged: true,
318        };
319    }
320
321    let mut total_iterations = 0;
322    let mut restarts = 0;
323
324    for _outer in 0..config.max_iterations {
325        // Compute preconditioned residual r = M⁻¹(b - Ax)
326        let ax = matvec(&x);
327        let residual: Array1<Complex64> = b - &ax;
328        let r = precond_solve(&residual);
329        let beta = vector_norm(&r);
330
331        let rel_residual = beta / b_norm;
332        if rel_residual < config.tolerance {
333            return GmresSolution {
334                x,
335                iterations: total_iterations,
336                restarts,
337                residual: rel_residual,
338                converged: true,
339            };
340        }
341
342        let mut v: Vec<Array1<Complex64>> = Vec::with_capacity(m + 1);
343        v.push(&r / Complex64::new(beta, 0.0));
344
345        let mut h: Array2<Complex64> = Array2::zeros((m + 1, m));
346        let mut cs: Vec<Complex64> = Vec::with_capacity(m);
347        let mut sn: Vec<Complex64> = Vec::with_capacity(m);
348
349        let mut g: Array1<Complex64> = Array1::zeros(m + 1);
350        g[0] = Complex64::new(beta, 0.0);
351
352        let mut inner_converged = false;
353
354        for j in 0..m {
355            total_iterations += 1;
356
357            // w = M⁻¹ * A * v_j
358            let av = matvec(&v[j]);
359            let w = precond_solve(&av);
360            let mut w = w;
361
362            // Modified Gram-Schmidt
363            for i in 0..=j {
364                h[[i, j]] = inner_product(&v[i], &w);
365                w = &w - &(&v[i] * h[[i, j]]);
366            }
367
368            h[[j + 1, j]] = Complex64::new(vector_norm(&w), 0.0);
369
370            if h[[j + 1, j]].norm() < 1e-14 {
371                inner_converged = true;
372            } else {
373                v.push(&w / h[[j + 1, j]]);
374            }
375
376            // Apply previous Givens rotations
377            for i in 0..j {
378                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
379                h[[i + 1, j]] = -sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
380                h[[i, j]] = temp;
381            }
382
383            // Compute new Givens rotation
384            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
385            cs.push(c);
386            sn.push(s);
387
388            // Apply to H and g
389            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
390            h[[j + 1, j]] = Complex64::new(0.0, 0.0);
391
392            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
393            g[j + 1] = -s * g[j] + c * g[j + 1];
394            g[j] = temp;
395
396            let rel_residual = g[j + 1].norm() / b_norm;
397
398            if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
399                eprintln!(
400                    "GMRES (precond) iteration {} (restart {}): relative residual = {:.6e}",
401                    total_iterations, restarts, rel_residual
402                );
403            }
404
405            if rel_residual < config.tolerance || inner_converged {
406                let y = solve_upper_triangular(&h, &g, j + 1);
407
408                for (i, yi) in y.iter().enumerate() {
409                    x = &x + &(&v[i] * *yi);
410                }
411
412                return GmresSolution {
413                    x,
414                    iterations: total_iterations,
415                    restarts,
416                    residual: rel_residual,
417                    converged: true,
418                };
419            }
420        }
421
422        // Restart
423        let y = solve_upper_triangular(&h, &g, m);
424        for (i, yi) in y.iter().enumerate() {
425            x = &x + &(&v[i] * *yi);
426        }
427
428        restarts += 1;
429    }
430
431    // Final residual
432    let ax = matvec(&x);
433    let residual: Array1<Complex64> = b - &ax;
434    let r = precond_solve(&residual);
435    let rel_residual = vector_norm(&r) / b_norm;
436
437    GmresSolution {
438        x,
439        iterations: total_iterations,
440        restarts,
441        residual: rel_residual,
442        converged: false,
443    }
444}
445
446/// Compute inner product (x, y) = Σ conj(x_i) * y_i
447#[inline]
448fn inner_product(x: &Array1<Complex64>, y: &Array1<Complex64>) -> Complex64 {
449    x.iter()
450        .zip(y.iter())
451        .map(|(xi, yi)| xi.conj() * yi)
452        .sum()
453}
454
455/// Compute vector 2-norm
456#[inline]
457fn vector_norm(x: &Array1<Complex64>) -> f64 {
458    x.iter().map(|xi| xi.norm_sqr()).sum::<f64>().sqrt()
459}
460
461/// Compute Givens rotation coefficients
462///
463/// Returns (c, s) such that:
464/// [c*  s*] [a]   [r]
465/// [-s  c ] [b] = [0]
466#[inline]
467fn givens_rotation(a: Complex64, b: Complex64) -> (Complex64, Complex64) {
468    if b.norm() < 1e-30 {
469        return (Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0));
470    }
471    if a.norm() < 1e-30 {
472        return (Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0));
473    }
474
475    let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
476    let c = a / Complex64::new(r, 0.0);
477    let s = b / Complex64::new(r, 0.0);
478
479    (c, s)
480}
481
482/// Solve upper triangular system Hy = g
483///
484/// Only uses the upper k×k portion of H
485fn solve_upper_triangular(h: &Array2<Complex64>, g: &Array1<Complex64>, k: usize) -> Vec<Complex64> {
486    let mut y = vec![Complex64::new(0.0, 0.0); k];
487
488    for i in (0..k).rev() {
489        let mut sum = g[i];
490        for j in (i + 1)..k {
491            sum -= h[[i, j]] * y[j];
492        }
493        if h[[i, i]].norm() > 1e-30 {
494            y[i] = sum / h[[i, i]];
495        }
496    }
497
498    y
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Purely real complex number.
    fn real(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    /// Solver configuration with progress output disabled.
    fn quiet_config(max_iterations: usize, restart: usize, tolerance: f64) -> GmresConfig {
        GmresConfig {
            max_iterations,
            restart,
            tolerance,
            print_interval: 0,
        }
    }

    /// Euclidean norm of `A x - b`.
    fn residual_norm(
        a: &Array2<Complex64>,
        x: &Array1<Complex64>,
        b: &Array1<Complex64>,
    ) -> f64 {
        (&a.dot(x) - b)
            .iter()
            .map(|e| e.norm_sqr())
            .sum::<f64>()
            .sqrt()
    }

    /// n×n tridiagonal matrix with 4 on the diagonal and complex off-diagonals.
    fn tridiagonal(n: usize) -> Array2<Complex64> {
        let mut a = Array2::zeros((n, n));
        for i in 0..n {
            a[[i, i]] = Complex64::new(4.0, 0.0);
            if i > 0 {
                a[[i, i - 1]] = Complex64::new(-1.0, 0.1);
            }
            if i < n - 1 {
                a[[i, i + 1]] = Complex64::new(-1.0, -0.1);
            }
        }
        a
    }

    /// Right-hand side with entries sin(0.3 i).
    fn sine_rhs(n: usize) -> Array1<Complex64> {
        Array1::from_iter((0..n).map(|i| Complex64::new((i as f64 * 0.3).sin(), 0.0)))
    }

    #[test]
    fn test_gmres_simple() {
        // Simple 2x2 positive definite system
        let a = Array2::from_shape_vec(
            (2, 2),
            vec![real(4.0), real(1.0), real(1.0), real(3.0)],
        )
        .unwrap();
        let b = Array1::from_vec(vec![real(1.0), real(2.0)]);

        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let config = quiet_config(100, 10, 1e-10);

        let solution = gmres_solve(&matvec, &b, None, &config);

        assert!(solution.converged, "GMRES should converge");

        // Verify solution: Ax ≈ b
        let error = residual_norm(&a, &solution.x, &b);
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_gmres_complex() {
        // Complex non-symmetric system
        let a = Array2::from_shape_vec(
            (2, 2),
            vec![
                Complex64::new(2.0, 1.0),
                Complex64::new(0.0, -1.0),
                Complex64::new(0.0, 1.0),
                Complex64::new(2.0, -1.0),
            ],
        )
        .unwrap();
        let b = Array1::from_vec(vec![Complex64::new(1.0, 1.0), Complex64::new(1.0, -1.0)]);

        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let config = quiet_config(100, 10, 1e-10);

        let solution = gmres_solve(&matvec, &b, None, &config);

        assert!(solution.converged, "GMRES should converge for complex system");

        let error = residual_norm(&a, &solution.x, &b);
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_gmres_identity() {
        // Identity matrix - should converge in 1 iteration
        let n = 5;
        let b = Array1::from_vec((1..=n).map(|i| real(i as f64)).collect::<Vec<_>>());

        let matvec = |x: &Array1<Complex64>| x.clone();
        let config = quiet_config(10, 10, 1e-12);

        let solution = gmres_solve(&matvec, &b, None, &config);

        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        let error: f64 = (&solution.x - &b)
            .iter()
            .map(|e| e.norm_sqr())
            .sum::<f64>()
            .sqrt();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_gmres_non_symmetric() {
        // Non-symmetric matrix - GMRES should handle this
        let a = Array2::from_shape_vec(
            (3, 3),
            vec![
                real(4.0),
                real(1.0),
                real(0.0),
                real(2.0),
                real(5.0),
                real(1.0),
                real(0.0),
                real(1.0),
                real(3.0),
            ],
        )
        .unwrap();
        let b = Array1::from_vec(vec![real(5.0), real(8.0), real(4.0)]);

        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let config = quiet_config(100, 10, 1e-10);

        let solution = gmres_solve(&matvec, &b, None, &config);

        assert!(
            solution.converged,
            "GMRES should converge for non-symmetric system"
        );

        let error = residual_norm(&a, &solution.x, &b);
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_gmres_restart() {
        // Larger system to test restart behavior
        let n = 20;
        let a = tridiagonal(n);
        let b = sine_rhs(n);

        let matvec = |x: &Array1<Complex64>| a.dot(x);

        // Use small restart to force restarts
        let config = quiet_config(50, 5, 1e-10);

        let solution = gmres_solve(&matvec, &b, None, &config);

        println!(
            "GMRES: {} iterations, {} restarts, residual = {:.6e}",
            solution.iterations, solution.restarts, solution.residual
        );

        assert!(solution.converged, "GMRES should converge even with restarts");

        let error = residual_norm(&a, &solution.x, &b);
        let rel_error = error / b.iter().map(|bi| bi.norm_sqr()).sum::<f64>().sqrt();
        assert!(rel_error < 1e-8, "Solution should be accurate");
    }

    #[test]
    fn test_gmres_zero_rhs() {
        // A zero right-hand side is answered without any iteration
        let b: Array1<Complex64> = Array1::zeros(4);

        let matvec = |x: &Array1<Complex64>| x.clone();
        let config = quiet_config(10, 5, 1e-10);

        let solution = gmres_solve(&matvec, &b, None, &config);

        assert!(solution.converged, "Zero RHS should be reported as converged");
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.restarts, 0);
        assert!(solution.x.iter().all(|xi| xi.norm() == 0.0));
    }

    #[test]
    fn test_gmres_preconditioned_jacobi() {
        // Same tridiagonal system, left-preconditioned with its (constant) diagonal
        let n = 20;
        let a = tridiagonal(n);
        let b = sine_rhs(n);

        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let precond = |r: &Array1<Complex64>| r / Complex64::new(4.0, 0.0);

        let config = quiet_config(50, 10, 1e-10);

        let solution = gmres_solve_preconditioned(&matvec, &precond, &b, None, &config);

        assert!(
            solution.converged,
            "Preconditioned GMRES should converge"
        );

        let error = residual_norm(&a, &solution.x, &b);
        let rel_error = error / b.iter().map(|bi| bi.norm_sqr()).sum::<f64>().sqrt();
        assert!(rel_error < 1e-8, "Preconditioned solution should be accurate");
    }

    #[test]
    fn test_gmres_config_builders() {
        let small = GmresConfig::for_small_problems();
        assert_eq!(small.restart, 50);

        let large = GmresConfig::for_large_bem();
        assert_eq!(large.restart, 100);

        let custom = GmresConfig::with_restart(75);
        assert_eq!(custom.restart, 75);
    }
}