Skip to main content

math_audio_solvers/iterative/
gmres.rs

1//! GMRES (Generalized Minimal Residual) solver
2//!
3//! Implementation of the restarted GMRES algorithm based on Saad & Schultz (1986).
4//!
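//! GMRES is often a good choice for large non-symmetric systems.
//! Each iteration minimizes the residual 2-norm over the current Krylov
//! subspace, so the residual norm never increases from one iteration to the
//! next. Restarting bounds memory use but can slow down or stall convergence.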
8
9use crate::blas_helpers::{axpy, inner_product, vector_norm};
10use crate::traits::{ComplexField, LinearOperator, Preconditioner, SolverStatus};
11use ndarray::{Array1, Array2};
12use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
13
/// Tuning parameters shared by the restarted GMRES solvers in this module.
///
/// `R` is the real scalar type of the tolerance; for complex problems it is
/// the real type underlying the complex field (e.g. `f64` for `Complex64`).
#[derive(Debug, Clone)]
pub struct GmresConfig<R> {
    /// Upper bound on the number of restart cycles (outer iterations).
    pub max_iterations: usize,
    /// Krylov subspace size per cycle, i.e. inner iterations between restarts.
    pub restart: usize,
    /// Relative residual below which the solve is considered converged.
    pub tolerance: R,
    /// Progress-report interval in inner iterations (0 = no output).
    pub print_interval: usize,
}
26
27impl Default for GmresConfig<f64> {
28    fn default() -> Self {
29        Self {
30            max_iterations: 100,
31            restart: 30,
32            tolerance: 1e-6,
33            print_interval: 0,
34        }
35    }
36}
37
38impl Default for GmresConfig<f32> {
39    fn default() -> Self {
40        Self {
41            max_iterations: 100,
42            restart: 30,
43            tolerance: 1e-5,
44            print_interval: 0,
45        }
46    }
47}
48
49impl<R: Float + FromPrimitive> GmresConfig<R> {
50    /// Create config for small problems
51    pub fn for_small_problems() -> Self {
52        Self {
53            max_iterations: 50,
54            restart: 50,
55            tolerance: R::from_f64(1e-8).unwrap(),
56            print_interval: 0,
57        }
58    }
59
60    /// Create config with specific restart parameter
61    pub fn with_restart(restart: usize) -> Self
62    where
63        Self: Default,
64    {
65        Self {
66            restart,
67            ..Default::default()
68        }
69    }
70}
71
72/// GMRES solver result
73#[derive(Debug)]
74pub struct GmresSolution<T: ComplexField> {
75    /// Solution vector
76    pub x: Array1<T>,
77    /// Total number of matrix-vector products
78    pub iterations: usize,
79    /// Number of restarts performed
80    pub restarts: usize,
81    /// Final relative residual
82    pub residual: T::Real,
83    /// Whether convergence was achieved
84    pub converged: bool,
85    /// Solver status
86    pub status: SolverStatus,
87}
88
89/// Solve Ax = b using the restarted GMRES method
90///
91/// # Arguments
92/// * `operator` - Linear operator representing A
93/// * `b` - Right-hand side vector
94/// * `config` - Solver configuration
95///
96/// # Returns
97/// Solution struct containing x, iteration count, and convergence info
98pub fn gmres<T, A>(operator: &A, b: &Array1<T>, config: &GmresConfig<T::Real>) -> GmresSolution<T>
99where
100    T: ComplexField,
101    A: LinearOperator<T>,
102{
103    gmres_with_guess(operator, b, None, config)
104}
105
106/// Solve Ax = b using GMRES with an initial guess
107pub fn gmres_with_guess<T, A>(
108    operator: &A,
109    b: &Array1<T>,
110    x0: Option<&Array1<T>>,
111    config: &GmresConfig<T::Real>,
112) -> GmresSolution<T>
113where
114    T: ComplexField,
115    A: LinearOperator<T>,
116{
117    let n = b.len();
118    let m = config.restart;
119
120    // Initialize solution vector
121    let mut x = match x0 {
122        Some(x0) => x0.clone(),
123        None => Array1::from_elem(n, T::zero()),
124    };
125
126    // Compute initial residual norm for relative tolerance
127    let b_norm = vector_norm(b);
128    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
129    if b_norm < tol_threshold {
130        return GmresSolution {
131            x,
132            iterations: 0,
133            restarts: 0,
134            residual: T::Real::zero(),
135            converged: true,
136            status: SolverStatus::Converged,
137        };
138    }
139
140    let mut total_iterations = 0;
141    let mut restarts = 0;
142
143    // Outer iteration (restarts)
144    for _outer in 0..config.max_iterations {
145        // Compute residual r = b - Ax
146        let ax = operator.apply(&x);
147        let r: Array1<T> = b - &ax;
148        let beta = vector_norm(&r);
149
150        // Check convergence
151        let rel_residual = beta / b_norm;
152        if rel_residual < config.tolerance {
153            return GmresSolution {
154                x,
155                iterations: total_iterations,
156                restarts,
157                residual: rel_residual,
158                converged: true,
159                status: SolverStatus::Converged,
160            };
161        }
162
163        // Initialize Krylov basis V
164        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
165        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
166
167        // Upper Hessenberg matrix H
168        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
169
170        // Givens rotation coefficients
171        let mut cs: Vec<T> = Vec::with_capacity(m);
172        let mut sn: Vec<T> = Vec::with_capacity(m);
173
174        // Right-hand side of least squares problem
175        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
176        g[0] = T::from_real(beta);
177
178        let mut inner_converged = false;
179
180        // Inner iteration (Arnoldi process)
181        for j in 0..m {
182            total_iterations += 1;
183
184            // w = A * v_j
185            let mut w = operator.apply(&v[j]);
186
187            // Modified Gram-Schmidt orthogonalization
188            for i in 0..=j {
189                h[[i, j]] = inner_product(&v[i], &w);
190                let h_ij = h[[i, j]];
191                axpy(-h_ij, &v[i], &mut w);
192            }
193
194            let w_norm = vector_norm(&w);
195            h[[j + 1, j]] = T::from_real(w_norm);
196
197            // Check for breakdown
198            let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
199            if w_norm < breakdown_tol {
200                inner_converged = true;
201            } else {
202                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
203            }
204
205            // Apply previous Givens rotations to new column of H
206            for i in 0..j {
207                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
208                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
209                h[[i, j]] = temp;
210            }
211
212            // Compute new Givens rotation
213            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
214            cs.push(c);
215            sn.push(s);
216
217            // Apply Givens rotation to H and g
218            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
219            h[[j + 1, j]] = T::zero();
220
221            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
222            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
223            g[j] = temp;
224
225            // Check convergence
226            let abs_residual = g[j + 1].norm();
227            let rel_residual = abs_residual / b_norm;
228            let abs_tol = T::Real::from_f64(1e-20).unwrap();
229
230            if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
231                log::info!(
232                    "GMRES iteration {} (restart {}): relative residual = {:.6e}",
233                    total_iterations,
234                    restarts,
235                    rel_residual.to_f64().unwrap_or(0.0)
236                );
237            }
238
239            if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
240                // Solve upper triangular system Hy = g
241                let y = solve_upper_triangular(&h, &g, j + 1);
242
243                // Update solution x = x + V * y
244                for (i, &yi) in y.iter().enumerate() {
245                    axpy(yi, &v[i], &mut x);
246                }
247
248                let status = if inner_converged
249                    && rel_residual >= config.tolerance
250                    && abs_residual >= abs_tol
251                {
252                    SolverStatus::Breakdown
253                } else {
254                    SolverStatus::Converged
255                };
256
257                return GmresSolution {
258                    x,
259                    iterations: total_iterations,
260                    restarts,
261                    residual: rel_residual,
262                    converged: true,
263                    status,
264                };
265            }
266        }
267
268        // Maximum inner iterations reached, compute solution and restart
269        let y = solve_upper_triangular(&h, &g, m);
270
271        for (i, &yi) in y.iter().enumerate() {
272            axpy(yi, &v[i], &mut x);
273        }
274
275        restarts += 1;
276    }
277
278    // Compute final residual
279    let ax = operator.apply(&x);
280    let r: Array1<T> = b - &ax;
281    let rel_residual = vector_norm(&r) / b_norm;
282
283    GmresSolution {
284        x,
285        iterations: total_iterations,
286        restarts,
287        residual: rel_residual,
288        converged: false,
289        status: SolverStatus::MaxIterationsReached,
290    }
291}
292
293/// GMRES solver with preconditioner
294///
295/// Solves Ax = b using left preconditioning: M⁻¹Ax = M⁻¹b
296pub fn gmres_preconditioned<T, A, P>(
297    operator: &A,
298    precond: &P,
299    b: &Array1<T>,
300    config: &GmresConfig<T::Real>,
301) -> GmresSolution<T>
302where
303    T: ComplexField,
304    A: LinearOperator<T>,
305    P: Preconditioner<T>,
306{
307    let n = b.len();
308    let m = config.restart;
309
310    let mut x = Array1::from_elem(n, T::zero());
311
312    // Compute preconditioned RHS norm
313    let pb = precond.apply(b);
314    let b_norm = vector_norm(&pb);
315    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
316    if b_norm < tol_threshold {
317        return GmresSolution {
318            x,
319            iterations: 0,
320            restarts: 0,
321            residual: T::Real::zero(),
322            converged: true,
323            status: SolverStatus::Converged,
324        };
325    }
326
327    let mut total_iterations = 0;
328    let mut restarts = 0;
329
330    for _outer in 0..config.max_iterations {
331        // Compute preconditioned residual r = M⁻¹(b - Ax)
332        let ax = operator.apply(&x);
333        let residual: Array1<T> = b - &ax;
334        let r = precond.apply(&residual);
335        let beta = vector_norm(&r);
336
337        let rel_residual = beta / b_norm;
338        if rel_residual < config.tolerance {
339            return GmresSolution {
340                x,
341                iterations: total_iterations,
342                restarts,
343                residual: rel_residual,
344                converged: true,
345                status: SolverStatus::Converged,
346            };
347        }
348
349        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
350        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
351
352        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
353        let mut cs: Vec<T> = Vec::with_capacity(m);
354        let mut sn: Vec<T> = Vec::with_capacity(m);
355
356        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
357        g[0] = T::from_real(beta);
358
359        let mut inner_converged = false;
360
361        for j in 0..m {
362            total_iterations += 1;
363
364            // w = M⁻¹ * A * v_j
365            let av = operator.apply(&v[j]);
366            let mut w = precond.apply(&av);
367
368            // Modified Gram-Schmidt
369            for i in 0..=j {
370                h[[i, j]] = inner_product(&v[i], &w);
371                let h_ij = h[[i, j]];
372                w = &w - &v[i].mapv(|vi| vi * h_ij);
373            }
374
375            let w_norm = vector_norm(&w);
376            h[[j + 1, j]] = T::from_real(w_norm);
377
378            let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
379            if w_norm < breakdown_tol {
380                inner_converged = true;
381            } else {
382                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
383            }
384
385            // Apply previous Givens rotations
386            for i in 0..j {
387                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
388                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
389                h[[i, j]] = temp;
390            }
391
392            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
393            cs.push(c);
394            sn.push(s);
395
396            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
397            h[[j + 1, j]] = T::zero();
398
399            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
400            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
401            g[j] = temp;
402
403            let abs_residual = g[j + 1].norm();
404            let rel_residual = abs_residual / b_norm;
405            let abs_tol = T::Real::from_f64(1e-20).unwrap();
406
407            if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
408                let y = solve_upper_triangular(&h, &g, j + 1);
409
410                for (i, &yi) in y.iter().enumerate() {
411                    x = &x + &v[i].mapv(|vi| vi * yi);
412                }
413
414                return GmresSolution {
415                    x,
416                    iterations: total_iterations,
417                    restarts,
418                    residual: rel_residual,
419                    converged: true,
420                    status: if inner_converged
421                        && rel_residual >= config.tolerance
422                        && abs_residual >= abs_tol
423                    {
424                        SolverStatus::Breakdown
425                    } else {
426                        SolverStatus::Converged
427                    },
428                };
429            }
430        }
431
432        // Restart
433        let y = solve_upper_triangular(&h, &g, m);
434        for (i, &yi) in y.iter().enumerate() {
435            x = &x + &v[i].mapv(|vi| vi * yi);
436        }
437
438        restarts += 1;
439    }
440
441    // Final residual
442    let ax = operator.apply(&x);
443    let residual: Array1<T> = b - &ax;
444    let r = precond.apply(&residual);
445    let rel_residual = vector_norm(&r) / b_norm;
446
447    GmresSolution {
448        x,
449        iterations: total_iterations,
450        restarts,
451        residual: rel_residual,
452        converged: false,
453        status: SolverStatus::MaxIterationsReached,
454    }
455}
456
457/// GMRES solver with preconditioner and initial guess
458///
459/// Solves Ax = b using left preconditioning: M⁻¹Ax = M⁻¹b
460/// with an optional initial guess x0.
461pub fn gmres_preconditioned_with_guess<T, A, P>(
462    operator: &A,
463    precond: &P,
464    b: &Array1<T>,
465    x0: Option<&Array1<T>>,
466    config: &GmresConfig<T::Real>,
467) -> GmresSolution<T>
468where
469    T: ComplexField,
470    A: LinearOperator<T>,
471    P: Preconditioner<T>,
472{
473    let n = b.len();
474    let m = config.restart;
475
476    // Initialize solution vector from initial guess or zero
477    let mut x = match x0 {
478        Some(guess) => guess.clone(),
479        None => Array1::from_elem(n, T::zero()),
480    };
481
482    // Compute preconditioned RHS norm
483    let pb = precond.apply(b);
484    let b_norm = vector_norm(&pb);
485    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
486    if b_norm < tol_threshold {
487        return GmresSolution {
488            x,
489            iterations: 0,
490            restarts: 0,
491            residual: T::Real::zero(),
492            converged: true,
493            status: SolverStatus::Converged,
494        };
495    }
496
497    let mut total_iterations = 0;
498    let mut restarts = 0;
499
500    for _outer in 0..config.max_iterations {
501        // Compute preconditioned residual r = M⁻¹(b - Ax)
502        let ax = operator.apply(&x);
503        let residual: Array1<T> = b - &ax;
504        let r = precond.apply(&residual);
505        let beta = vector_norm(&r);
506
507        let rel_residual = beta / b_norm;
508        if rel_residual < config.tolerance {
509            return GmresSolution {
510                x,
511                iterations: total_iterations,
512                restarts,
513                residual: rel_residual,
514                converged: true,
515                status: SolverStatus::Converged,
516            };
517        }
518
519        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
520        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
521
522        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
523        let mut cs: Vec<T> = Vec::with_capacity(m);
524        let mut sn: Vec<T> = Vec::with_capacity(m);
525
526        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
527        g[0] = T::from_real(beta);
528
529        let mut inner_converged = false;
530
531        for j in 0..m {
532            total_iterations += 1;
533
534            // w = M⁻¹ * A * v_j
535            let av = operator.apply(&v[j]);
536            let mut w = precond.apply(&av);
537
538            // Modified Gram-Schmidt
539            for i in 0..=j {
540                h[[i, j]] = inner_product(&v[i], &w);
541                let h_ij = h[[i, j]];
542                w = &w - &v[i].mapv(|vi| vi * h_ij);
543            }
544
545            let w_norm = vector_norm(&w);
546            h[[j + 1, j]] = T::from_real(w_norm);
547
548            let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
549            if w_norm < breakdown_tol {
550                inner_converged = true;
551            } else {
552                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
553            }
554
555            // Apply previous Givens rotations
556            for i in 0..j {
557                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
558                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
559                h[[i, j]] = temp;
560            }
561
562            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
563            cs.push(c);
564            sn.push(s);
565
566            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
567            h[[j + 1, j]] = T::zero();
568
569            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
570            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
571            g[j] = temp;
572
573            let abs_residual = g[j + 1].norm();
574            let rel_residual = abs_residual / b_norm;
575            let abs_tol = T::Real::from_f64(1e-20).unwrap();
576
577            if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
578                let y = solve_upper_triangular(&h, &g, j + 1);
579
580                for (i, &yi) in y.iter().enumerate() {
581                    x = &x + &v[i].mapv(|vi| vi * yi);
582                }
583
584                return GmresSolution {
585                    x,
586                    iterations: total_iterations,
587                    restarts,
588                    residual: rel_residual,
589                    converged: true,
590                    status: if inner_converged
591                        && rel_residual >= config.tolerance
592                        && abs_residual >= abs_tol
593                    {
594                        SolverStatus::Breakdown
595                    } else {
596                        SolverStatus::Converged
597                    },
598                };
599            }
600        }
601
602        // Restart
603        let y = solve_upper_triangular(&h, &g, m);
604        for (i, &yi) in y.iter().enumerate() {
605            x = &x + &v[i].mapv(|vi| vi * yi);
606        }
607
608        restarts += 1;
609    }
610
611    // Final residual
612    let ax = operator.apply(&x);
613    let residual: Array1<T> = b - &ax;
614    let r = precond.apply(&residual);
615    let rel_residual = vector_norm(&r) / b_norm;
616
617    GmresSolution {
618        x,
619        iterations: total_iterations,
620        restarts,
621        residual: rel_residual,
622        converged: false,
623        status: SolverStatus::MaxIterationsReached,
624    }
625}
626
627/// Compute Givens rotation coefficients
628#[inline]
629fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
630    let tol = T::Real::from_f64(1e-30).unwrap();
631    if b.norm() < tol {
632        return (T::one(), T::zero());
633    }
634    if a.norm() < tol {
635        return (T::zero(), T::one());
636    }
637
638    let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
639    let c = a * T::from_real(T::Real::one() / r);
640    let s = b * T::from_real(T::Real::one() / r);
641
642    (c, s)
643}
644
645/// Solve upper triangular system Hy = g
646fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
647    let mut y = vec![T::zero(); k];
648    let tol = T::Real::from_f64(1e-30).unwrap();
649
650    for i in (0..k).rev() {
651        let mut sum = g[i];
652        for j in (i + 1)..k {
653            sum -= h[[i, j]] * y[j];
654        }
655        if h[[i, i]].norm() > tol {
656            y[i] = sum * h[[i, i]].inv();
657        }
658    }
659
660    y
661}
662
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Euclidean distance between two complex vectors.
    fn complex_distance(lhs: &Array1<Complex64>, rhs: &Array1<Complex64>) -> f64 {
        (lhs - rhs).iter().map(|d| d.norm_sqr()).sum::<f64>().sqrt()
    }

    /// Solver configuration with progress output disabled.
    fn quiet_config(max_iterations: usize, restart: usize, tolerance: f64) -> GmresConfig<f64> {
        GmresConfig {
            max_iterations,
            restart,
            tolerance,
            print_interval: 0,
        }
    }

    #[test]
    fn test_gmres_simple() {
        let re = |value: f64| Complex64::new(value, 0.0);
        let dense = array![[re(4.0), re(1.0)], [re(1.0), re(3.0)]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(1.0), re(2.0)];

        let solution = gmres(&a, &b, &quiet_config(100, 10, 1e-10));
        assert!(solution.converged, "GMRES should converge");

        // Multiplying the computed x by A must reproduce b.
        let ax = a.matvec(&solution.x);
        assert!(complex_distance(&ax, &b) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_gmres_identity() {
        let n = 5;
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(n);
        let b = Array1::from_iter((1..=n).map(|i| Complex64::new(i as f64, 0.0)));

        let solution = gmres(&id, &b, &quiet_config(10, 10, 1e-12));
        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        // For A = I the solution is b itself.
        assert!(complex_distance(&solution.x, &b) < 1e-10);
    }

    #[test]
    fn test_gmres_f64() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        let solution = gmres(&a, &b, &quiet_config(100, 10, 1e-10));
        assert!(solution.converged);

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e * e).sum::<f64>().sqrt();
        assert_relative_eq!(error, 0.0, epsilon = 1e-8);
    }
}