solvers/iterative/
gmres.rs

1//! GMRES (Generalized Minimal Residual) solver
2//!
3//! Implementation of the restarted GMRES algorithm based on Saad & Schultz (1986).
4//!
5//! GMRES is often the best choice for large non-symmetric systems.
6//! It minimizes the residual in a Krylov subspace and has smooth, monotonic
7//! convergence behavior.
8
9use crate::traits::{ComplexField, LinearOperator, Preconditioner};
10use ndarray::{Array1, Array2};
11use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
12
/// Configuration shared by the restarted GMRES solvers in this module.
///
/// A solve performs at most `max_iterations` restart cycles, each consisting of
/// at most `restart` inner (Arnoldi) iterations.
#[derive(Debug, Clone, PartialEq)]
pub struct GmresConfig<R> {
    /// Maximum number of outer iterations (restart cycles)
    pub max_iterations: usize,
    /// Restart parameter (number of inner iterations performed before restarting)
    pub restart: usize,
    /// Relative tolerance: the solve stops once the residual norm divided by the
    /// right-hand-side norm drops below this value
    pub tolerance: R,
    /// Log progress every N inner iterations (0 = no output)
    pub print_interval: usize,
}
25
26impl Default for GmresConfig<f64> {
27    fn default() -> Self {
28        Self {
29            max_iterations: 100,
30            restart: 30,
31            tolerance: 1e-6,
32            print_interval: 0,
33        }
34    }
35}
36
37impl Default for GmresConfig<f32> {
38    fn default() -> Self {
39        Self {
40            max_iterations: 100,
41            restart: 30,
42            tolerance: 1e-5,
43            print_interval: 0,
44        }
45    }
46}
47
48impl<R: Float + FromPrimitive> GmresConfig<R> {
49    /// Create config for small problems
50    pub fn for_small_problems() -> Self {
51        Self {
52            max_iterations: 50,
53            restart: 50,
54            tolerance: R::from_f64(1e-8).unwrap(),
55            print_interval: 0,
56        }
57    }
58
59    /// Create config with specific restart parameter
60    pub fn with_restart(restart: usize) -> Self
61    where
62        Self: Default,
63    {
64        Self {
65            restart,
66            ..Default::default()
67        }
68    }
69}
70
71/// GMRES solver result
72#[derive(Debug)]
73pub struct GmresSolution<T: ComplexField> {
74    /// Solution vector
75    pub x: Array1<T>,
76    /// Total number of matrix-vector products
77    pub iterations: usize,
78    /// Number of restarts performed
79    pub restarts: usize,
80    /// Final relative residual
81    pub residual: T::Real,
82    /// Whether convergence was achieved
83    pub converged: bool,
84}
85
86/// Solve Ax = b using the restarted GMRES method
87///
88/// # Arguments
89/// * `operator` - Linear operator representing A
90/// * `b` - Right-hand side vector
91/// * `config` - Solver configuration
92///
93/// # Returns
94/// Solution struct containing x, iteration count, and convergence info
95pub fn gmres<T, A>(operator: &A, b: &Array1<T>, config: &GmresConfig<T::Real>) -> GmresSolution<T>
96where
97    T: ComplexField,
98    A: LinearOperator<T>,
99{
100    gmres_with_guess(operator, b, None, config)
101}
102
103/// Solve Ax = b using GMRES with an initial guess
104pub fn gmres_with_guess<T, A>(
105    operator: &A,
106    b: &Array1<T>,
107    x0: Option<&Array1<T>>,
108    config: &GmresConfig<T::Real>,
109) -> GmresSolution<T>
110where
111    T: ComplexField,
112    A: LinearOperator<T>,
113{
114    let n = b.len();
115    let m = config.restart;
116
117    // Initialize solution vector
118    let mut x = match x0 {
119        Some(x0) => x0.clone(),
120        None => Array1::from_elem(n, T::zero()),
121    };
122
123    // Compute initial residual norm for relative tolerance
124    let b_norm = vector_norm(b);
125    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
126    if b_norm < tol_threshold {
127        return GmresSolution {
128            x,
129            iterations: 0,
130            restarts: 0,
131            residual: T::Real::zero(),
132            converged: true,
133        };
134    }
135
136    let mut total_iterations = 0;
137    let mut restarts = 0;
138
139    // Outer iteration (restarts)
140    for _outer in 0..config.max_iterations {
141        // Compute residual r = b - Ax
142        let ax = operator.apply(&x);
143        let r: Array1<T> = b - &ax;
144        let beta = vector_norm(&r);
145
146        // Check convergence
147        let rel_residual = beta / b_norm;
148        if rel_residual < config.tolerance {
149            return GmresSolution {
150                x,
151                iterations: total_iterations,
152                restarts,
153                residual: rel_residual,
154                converged: true,
155            };
156        }
157
158        // Initialize Krylov basis V
159        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
160        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
161
162        // Upper Hessenberg matrix H
163        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
164
165        // Givens rotation coefficients
166        let mut cs: Vec<T> = Vec::with_capacity(m);
167        let mut sn: Vec<T> = Vec::with_capacity(m);
168
169        // Right-hand side of least squares problem
170        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
171        g[0] = T::from_real(beta);
172
173        let mut inner_converged = false;
174
175        // Inner iteration (Arnoldi process)
176        for j in 0..m {
177            total_iterations += 1;
178
179            // w = A * v_j
180            let mut w = operator.apply(&v[j]);
181
182            // Modified Gram-Schmidt orthogonalization
183            for i in 0..=j {
184                h[[i, j]] = inner_product(&v[i], &w);
185                let h_ij = h[[i, j]];
186                w = &w - &v[i].mapv(|vi| vi * h_ij);
187            }
188
189            let w_norm = vector_norm(&w);
190            h[[j + 1, j]] = T::from_real(w_norm);
191
192            // Check for breakdown
193            let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
194            if w_norm < breakdown_tol {
195                inner_converged = true;
196            } else {
197                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
198            }
199
200            // Apply previous Givens rotations to new column of H
201            for i in 0..j {
202                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
203                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
204                h[[i, j]] = temp;
205            }
206
207            // Compute new Givens rotation
208            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
209            cs.push(c);
210            sn.push(s);
211
212            // Apply Givens rotation to H and g
213            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
214            h[[j + 1, j]] = T::zero();
215
216            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
217            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
218            g[j] = temp;
219
220            // Check convergence
221            let rel_residual = g[j + 1].norm() / b_norm;
222
223            if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
224                log::info!(
225                    "GMRES iteration {} (restart {}): relative residual = {:.6e}",
226                    total_iterations,
227                    restarts,
228                    rel_residual.to_f64().unwrap_or(0.0)
229                );
230            }
231
232            if rel_residual < config.tolerance || inner_converged {
233                // Solve upper triangular system Hy = g
234                let y = solve_upper_triangular(&h, &g, j + 1);
235
236                // Update solution x = x + V * y
237                for (i, &yi) in y.iter().enumerate() {
238                    x = &x + &v[i].mapv(|vi| vi * yi);
239                }
240
241                return GmresSolution {
242                    x,
243                    iterations: total_iterations,
244                    restarts,
245                    residual: rel_residual,
246                    converged: true,
247                };
248            }
249        }
250
251        // Maximum inner iterations reached, compute solution and restart
252        let y = solve_upper_triangular(&h, &g, m);
253
254        for (i, &yi) in y.iter().enumerate() {
255            x = &x + &v[i].mapv(|vi| vi * yi);
256        }
257
258        restarts += 1;
259    }
260
261    // Compute final residual
262    let ax = operator.apply(&x);
263    let r: Array1<T> = b - &ax;
264    let rel_residual = vector_norm(&r) / b_norm;
265
266    GmresSolution {
267        x,
268        iterations: total_iterations,
269        restarts,
270        residual: rel_residual,
271        converged: false,
272    }
273}
274
275/// GMRES solver with preconditioner
276///
277/// Solves Ax = b using left preconditioning: M⁻¹Ax = M⁻¹b
278pub fn gmres_preconditioned<T, A, P>(
279    operator: &A,
280    precond: &P,
281    b: &Array1<T>,
282    config: &GmresConfig<T::Real>,
283) -> GmresSolution<T>
284where
285    T: ComplexField,
286    A: LinearOperator<T>,
287    P: Preconditioner<T>,
288{
289    let n = b.len();
290    let m = config.restart;
291
292    let mut x = Array1::from_elem(n, T::zero());
293
294    // Compute preconditioned RHS norm
295    let pb = precond.apply(b);
296    let b_norm = vector_norm(&pb);
297    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
298    if b_norm < tol_threshold {
299        return GmresSolution {
300            x,
301            iterations: 0,
302            restarts: 0,
303            residual: T::Real::zero(),
304            converged: true,
305        };
306    }
307
308    let mut total_iterations = 0;
309    let mut restarts = 0;
310
311    for _outer in 0..config.max_iterations {
312        // Compute preconditioned residual r = M⁻¹(b - Ax)
313        let ax = operator.apply(&x);
314        let residual: Array1<T> = b - &ax;
315        let r = precond.apply(&residual);
316        let beta = vector_norm(&r);
317
318        let rel_residual = beta / b_norm;
319        if rel_residual < config.tolerance {
320            return GmresSolution {
321                x,
322                iterations: total_iterations,
323                restarts,
324                residual: rel_residual,
325                converged: true,
326            };
327        }
328
329        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
330        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
331
332        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
333        let mut cs: Vec<T> = Vec::with_capacity(m);
334        let mut sn: Vec<T> = Vec::with_capacity(m);
335
336        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
337        g[0] = T::from_real(beta);
338
339        let mut inner_converged = false;
340
341        for j in 0..m {
342            total_iterations += 1;
343
344            // w = M⁻¹ * A * v_j
345            let av = operator.apply(&v[j]);
346            let mut w = precond.apply(&av);
347
348            // Modified Gram-Schmidt
349            for i in 0..=j {
350                h[[i, j]] = inner_product(&v[i], &w);
351                let h_ij = h[[i, j]];
352                w = &w - &v[i].mapv(|vi| vi * h_ij);
353            }
354
355            let w_norm = vector_norm(&w);
356            h[[j + 1, j]] = T::from_real(w_norm);
357
358            let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
359            if w_norm < breakdown_tol {
360                inner_converged = true;
361            } else {
362                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
363            }
364
365            // Apply previous Givens rotations
366            for i in 0..j {
367                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
368                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
369                h[[i, j]] = temp;
370            }
371
372            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
373            cs.push(c);
374            sn.push(s);
375
376            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
377            h[[j + 1, j]] = T::zero();
378
379            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
380            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
381            g[j] = temp;
382
383            let rel_residual = g[j + 1].norm() / b_norm;
384
385            if rel_residual < config.tolerance || inner_converged {
386                let y = solve_upper_triangular(&h, &g, j + 1);
387
388                for (i, &yi) in y.iter().enumerate() {
389                    x = &x + &v[i].mapv(|vi| vi * yi);
390                }
391
392                return GmresSolution {
393                    x,
394                    iterations: total_iterations,
395                    restarts,
396                    residual: rel_residual,
397                    converged: true,
398                };
399            }
400        }
401
402        // Restart
403        let y = solve_upper_triangular(&h, &g, m);
404        for (i, &yi) in y.iter().enumerate() {
405            x = &x + &v[i].mapv(|vi| vi * yi);
406        }
407
408        restarts += 1;
409    }
410
411    // Final residual
412    let ax = operator.apply(&x);
413    let residual: Array1<T> = b - &ax;
414    let r = precond.apply(&residual);
415    let rel_residual = vector_norm(&r) / b_norm;
416
417    GmresSolution {
418        x,
419        iterations: total_iterations,
420        restarts,
421        residual: rel_residual,
422        converged: false,
423    }
424}
425
426/// GMRES solver with preconditioner and initial guess
427///
428/// Solves Ax = b using left preconditioning: M⁻¹Ax = M⁻¹b
429/// with an optional initial guess x0.
430pub fn gmres_preconditioned_with_guess<T, A, P>(
431    operator: &A,
432    precond: &P,
433    b: &Array1<T>,
434    x0: Option<&Array1<T>>,
435    config: &GmresConfig<T::Real>,
436) -> GmresSolution<T>
437where
438    T: ComplexField,
439    A: LinearOperator<T>,
440    P: Preconditioner<T>,
441{
442    let n = b.len();
443    let m = config.restart;
444
445    // Initialize solution vector from initial guess or zero
446    let mut x = match x0 {
447        Some(guess) => guess.clone(),
448        None => Array1::from_elem(n, T::zero()),
449    };
450
451    // Compute preconditioned RHS norm
452    let pb = precond.apply(b);
453    let b_norm = vector_norm(&pb);
454    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
455    if b_norm < tol_threshold {
456        return GmresSolution {
457            x,
458            iterations: 0,
459            restarts: 0,
460            residual: T::Real::zero(),
461            converged: true,
462        };
463    }
464
465    let mut total_iterations = 0;
466    let mut restarts = 0;
467
468    for _outer in 0..config.max_iterations {
469        // Compute preconditioned residual r = M⁻¹(b - Ax)
470        let ax = operator.apply(&x);
471        let residual: Array1<T> = b - &ax;
472        let r = precond.apply(&residual);
473        let beta = vector_norm(&r);
474
475        let rel_residual = beta / b_norm;
476        if rel_residual < config.tolerance {
477            return GmresSolution {
478                x,
479                iterations: total_iterations,
480                restarts,
481                residual: rel_residual,
482                converged: true,
483            };
484        }
485
486        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
487        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
488
489        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
490        let mut cs: Vec<T> = Vec::with_capacity(m);
491        let mut sn: Vec<T> = Vec::with_capacity(m);
492
493        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
494        g[0] = T::from_real(beta);
495
496        let mut inner_converged = false;
497
498        for j in 0..m {
499            total_iterations += 1;
500
501            // w = M⁻¹ * A * v_j
502            let av = operator.apply(&v[j]);
503            let mut w = precond.apply(&av);
504
505            // Modified Gram-Schmidt
506            for i in 0..=j {
507                h[[i, j]] = inner_product(&v[i], &w);
508                let h_ij = h[[i, j]];
509                w = &w - &v[i].mapv(|vi| vi * h_ij);
510            }
511
512            let w_norm = vector_norm(&w);
513            h[[j + 1, j]] = T::from_real(w_norm);
514
515            let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
516            if w_norm < breakdown_tol {
517                inner_converged = true;
518            } else {
519                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
520            }
521
522            // Apply previous Givens rotations
523            for i in 0..j {
524                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
525                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
526                h[[i, j]] = temp;
527            }
528
529            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
530            cs.push(c);
531            sn.push(s);
532
533            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
534            h[[j + 1, j]] = T::zero();
535
536            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
537            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
538            g[j] = temp;
539
540            let rel_residual = g[j + 1].norm() / b_norm;
541
542            if rel_residual < config.tolerance || inner_converged {
543                let y = solve_upper_triangular(&h, &g, j + 1);
544
545                for (i, &yi) in y.iter().enumerate() {
546                    x = &x + &v[i].mapv(|vi| vi * yi);
547                }
548
549                return GmresSolution {
550                    x,
551                    iterations: total_iterations,
552                    restarts,
553                    residual: rel_residual,
554                    converged: true,
555                };
556            }
557        }
558
559        // Restart
560        let y = solve_upper_triangular(&h, &g, m);
561        for (i, &yi) in y.iter().enumerate() {
562            x = &x + &v[i].mapv(|vi| vi * yi);
563        }
564
565        restarts += 1;
566    }
567
568    // Final residual
569    let ax = operator.apply(&x);
570    let residual: Array1<T> = b - &ax;
571    let r = precond.apply(&residual);
572    let rel_residual = vector_norm(&r) / b_norm;
573
574    GmresSolution {
575        x,
576        iterations: total_iterations,
577        restarts,
578        residual: rel_residual,
579        converged: false,
580    }
581}
582
583/// Compute inner product (x, y) = Σ conj(x_i) * y_i
584#[inline]
585fn inner_product<T: ComplexField>(x: &Array1<T>, y: &Array1<T>) -> T {
586    x.iter()
587        .zip(y.iter())
588        .fold(T::zero(), |acc, (&xi, &yi)| acc + xi.conj() * yi)
589}
590
591/// Compute vector 2-norm
592#[inline]
593fn vector_norm<T: ComplexField>(x: &Array1<T>) -> T::Real {
594    x.iter()
595        .map(|xi| xi.norm_sqr())
596        .fold(T::Real::zero(), |acc, v| acc + v)
597        .sqrt()
598}
599
600/// Compute Givens rotation coefficients
601#[inline]
602fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
603    let tol = T::Real::from_f64(1e-30).unwrap();
604    if b.norm() < tol {
605        return (T::one(), T::zero());
606    }
607    if a.norm() < tol {
608        return (T::zero(), T::one());
609    }
610
611    let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
612    let c = a * T::from_real(T::Real::one() / r);
613    let s = b * T::from_real(T::Real::one() / r);
614
615    (c, s)
616}
617
618/// Solve upper triangular system Hy = g
619fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
620    let mut y = vec![T::zero(); k];
621    let tol = T::Real::from_f64(1e-30).unwrap();
622
623    for i in (0..k).rev() {
624        let mut sum = g[i];
625        for j in (i + 1)..k {
626            sum -= h[[i, j]] * y[j];
627        }
628        if h[[i, i]].norm() > tol {
629            y[i] = sum * h[[i, i]].inv();
630        }
631    }
632
633    y
634}
635
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Solver configuration with progress output disabled.
    fn quiet_config(max_iterations: usize, restart: usize, tolerance: f64) -> GmresConfig<f64> {
        GmresConfig {
            max_iterations,
            restart,
            tolerance,
            print_interval: 0,
        }
    }

    /// Small symmetric positive definite complex system.
    #[test]
    fn test_gmres_simple() {
        let re = |value: f64| Complex64::new(value, 0.0);

        let dense = array![[re(4.0), re(1.0)], [re(1.0), re(3.0)],];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(1.0), re(2.0)];

        let solution = gmres(&a, &b, &quiet_config(100, 10, 1e-10));

        assert!(solution.converged, "GMRES should converge");

        // Verify solution: Ax ≈ b
        let ax = a.matvec(&solution.x);
        let squared_error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>();
        let error = squared_error.sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    /// The identity operator must be solved almost immediately, with x = b.
    #[test]
    fn test_gmres_identity() {
        let n = 5;
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(n);
        let b = Array1::from_iter((1..=n).map(|i| Complex64::new(i as f64, 0.0)));

        let solution = gmres(&id, &b, &quiet_config(10, 10, 1e-12));

        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        let squared_error: f64 = (&solution.x - &b)
            .iter()
            .map(|e| e.norm_sqr())
            .sum::<f64>();
        let error = squared_error.sqrt();
        assert!(error < 1e-10);
    }

    /// Same 2x2 system as `test_gmres_simple`, with real scalars.
    #[test]
    fn test_gmres_f64() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0],];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        let solution = gmres(&a, &b, &quiet_config(100, 10, 1e-10));

        assert!(solution.converged);

        let ax = a.matvec(&solution.x);
        let squared_error: f64 = (&ax - &b).iter().map(|e| e * e).sum::<f64>();
        let error = squared_error.sqrt();
        assert_relative_eq!(error, 0.0, epsilon = 1e-8);
    }
}