Skip to main content

math_audio_solvers/iterative/
gmres.rs

1//! GMRES (Generalized Minimal Residual) solver
2//!
3//! Implementation of the restarted GMRES algorithm based on Saad & Schultz (1986).
4//!
5//! GMRES is often the best choice for large non-symmetric systems.
6//! It minimizes the residual in a Krylov subspace and has smooth, monotonic
7//! convergence behavior.
8
9use crate::blas_helpers::{axpy, inner_product, vector_norm};
10use crate::traits::{ComplexField, LinearOperator, Preconditioner, SolverStatus};
11use ndarray::{Array1, Array2};
12use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
13
/// Tunable parameters for the restarted GMRES solvers in this module.
///
/// `R` is the real scalar type in which the convergence tolerance is expressed.
#[derive(Clone, Debug)]
pub struct GmresConfig<R> {
    /// Upper bound on the number of outer cycles (restarts).
    pub max_iterations: usize,
    /// Number of inner iterations carried out before the method restarts.
    pub restart: usize,
    /// Relative residual below which a solve is considered converged.
    pub tolerance: R,
    /// Emit a progress message every this many iterations; `0` disables output.
    pub print_interval: usize,
}
26
27impl Default for GmresConfig<f64> {
28    fn default() -> Self {
29        Self {
30            max_iterations: 100,
31            restart: 30,
32            tolerance: 1e-6,
33            print_interval: 0,
34        }
35    }
36}
37
38impl Default for GmresConfig<f32> {
39    fn default() -> Self {
40        Self {
41            max_iterations: 100,
42            restart: 30,
43            tolerance: 1e-5,
44            print_interval: 0,
45        }
46    }
47}
48
49impl<R: Float + FromPrimitive> GmresConfig<R> {
50    /// Create config for small problems
51    pub fn for_small_problems() -> Self {
52        Self {
53            max_iterations: 50,
54            restart: 50,
55            tolerance: R::from_f64(1e-8).unwrap(),
56            print_interval: 0,
57        }
58    }
59
60    /// Create config with specific restart parameter
61    pub fn with_restart(restart: usize) -> Self
62    where
63        Self: Default,
64    {
65        Self {
66            restart,
67            ..Default::default()
68        }
69    }
70}
71
72/// GMRES solver result
73#[derive(Debug)]
74pub struct GmresSolution<T: ComplexField> {
75    /// Solution vector
76    pub x: Array1<T>,
77    /// Total number of matrix-vector products
78    pub iterations: usize,
79    /// Number of restarts performed
80    pub restarts: usize,
81    /// Final relative residual
82    pub residual: T::Real,
83    /// Whether convergence was achieved
84    pub converged: bool,
85    /// Solver status
86    pub status: SolverStatus,
87}
88
89/// Solve Ax = b using the restarted GMRES method
90///
91/// # Arguments
92/// * `operator` - Linear operator representing A
93/// * `b` - Right-hand side vector
94/// * `config` - Solver configuration
95///
96/// # Returns
97/// Solution struct containing x, iteration count, and convergence info
98pub fn gmres<T, A>(operator: &A, b: &Array1<T>, config: &GmresConfig<T::Real>) -> GmresSolution<T>
99where
100    T: ComplexField,
101    A: LinearOperator<T>,
102{
103    gmres_with_guess(operator, b, None, config)
104}
105
106/// Solve Ax = b using GMRES with an initial guess
107pub fn gmres_with_guess<T, A>(
108    operator: &A,
109    b: &Array1<T>,
110    x0: Option<&Array1<T>>,
111    config: &GmresConfig<T::Real>,
112) -> GmresSolution<T>
113where
114    T: ComplexField,
115    A: LinearOperator<T>,
116{
117    let n = b.len();
118    let m = config.restart;
119
120    // Initialize solution vector
121    let mut x = match x0 {
122        Some(x0) => x0.clone(),
123        None => Array1::from_elem(n, T::zero()),
124    };
125
126    // Compute initial residual norm for relative tolerance
127    let b_norm = vector_norm(b);
128    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
129    if b_norm < tol_threshold {
130        return GmresSolution {
131            x,
132            iterations: 0,
133            restarts: 0,
134            residual: T::Real::zero(),
135            converged: true,
136            status: SolverStatus::Converged,
137        };
138    }
139
140    let mut total_iterations = 0;
141    let mut restarts = 0;
142
143    // Outer iteration (restarts)
144    for _outer in 0..config.max_iterations {
145        // Compute residual r = b - Ax
146        let ax = operator.apply(&x);
147        let r: Array1<T> = b - &ax;
148        let beta = vector_norm(&r);
149
150        // Check convergence
151        let rel_residual = beta / b_norm;
152        if rel_residual < config.tolerance {
153            return GmresSolution {
154                x,
155                iterations: total_iterations,
156                restarts,
157                residual: rel_residual,
158                converged: true,
159                status: SolverStatus::Converged,
160            };
161        }
162
163        // Initialize Krylov basis V
164        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
165        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
166
167        // Upper Hessenberg matrix H
168        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
169
170        // Givens rotation coefficients
171        let mut cs: Vec<T> = Vec::with_capacity(m);
172        let mut sn: Vec<T> = Vec::with_capacity(m);
173
174        // Right-hand side of least squares problem
175        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
176        g[0] = T::from_real(beta);
177
178        let mut inner_converged = false;
179
180        // Inner iteration (Arnoldi process)
181        for j in 0..m {
182            total_iterations += 1;
183
184            // w = A * v_j
185            let mut w = operator.apply(&v[j]);
186
187            // Modified Gram-Schmidt orthogonalization
188            for i in 0..=j {
189                h[[i, j]] = inner_product(&v[i], &w);
190                let h_ij = h[[i, j]];
191                axpy(-h_ij, &v[i], &mut w);
192            }
193
194            let w_norm = vector_norm(&w);
195            h[[j + 1, j]] = T::from_real(w_norm);
196
197            // Check for breakdown
198            let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
199            if w_norm < breakdown_tol {
200                inner_converged = true;
201            } else {
202                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
203            }
204
205            // Apply previous Givens rotations to new column of H
206            for i in 0..j {
207                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
208                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
209                h[[i, j]] = temp;
210            }
211
212            // Compute new Givens rotation
213            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
214            cs.push(c);
215            sn.push(s);
216
217            // Apply Givens rotation to H and g
218            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
219            h[[j + 1, j]] = T::zero();
220
221            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
222            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
223            g[j] = temp;
224
225            // Check convergence
226            let abs_residual = g[j + 1].norm();
227            let rel_residual = abs_residual / b_norm;
228            let abs_tol = T::Real::from_f64(1e-20).unwrap();
229
230            if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
231                log::info!(
232                    "GMRES iteration {} (restart {}): relative residual = {:.6e}",
233                    total_iterations,
234                    restarts,
235                    rel_residual.to_f64().unwrap_or(0.0)
236                );
237            }
238
239            if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
240                // Solve upper triangular system Hy = g
241                let y = solve_upper_triangular(&h, &g, j + 1);
242
243                // Update solution x = x + V * y
244                for (i, &yi) in y.iter().enumerate() {
245                    axpy(yi, &v[i], &mut x);
246                }
247
248                let status = if inner_converged && rel_residual >= config.tolerance && abs_residual >= abs_tol {
249                    SolverStatus::Breakdown
250                } else {
251                    SolverStatus::Converged
252                };
253
254                return GmresSolution {
255                    x,
256                    iterations: total_iterations,
257                    restarts,
258                    residual: rel_residual,
259                    converged: true,
260                    status,
261                };
262            }
263        }
264
265        // Maximum inner iterations reached, compute solution and restart
266        let y = solve_upper_triangular(&h, &g, m);
267
268        for (i, &yi) in y.iter().enumerate() {
269            axpy(yi, &v[i], &mut x);
270        }
271
272        restarts += 1;
273    }
274
275    // Compute final residual
276    let ax = operator.apply(&x);
277    let r: Array1<T> = b - &ax;
278    let rel_residual = vector_norm(&r) / b_norm;
279
280    GmresSolution {
281        x,
282        iterations: total_iterations,
283        restarts,
284        residual: rel_residual,
285        converged: false,
286        status: SolverStatus::MaxIterationsReached,
287    }
288}
289
290/// GMRES solver with preconditioner
291///
292/// Solves Ax = b using left preconditioning: M⁻¹Ax = M⁻¹b
293pub fn gmres_preconditioned<T, A, P>(
294    operator: &A,
295    precond: &P,
296    b: &Array1<T>,
297    config: &GmresConfig<T::Real>,
298) -> GmresSolution<T>
299where
300    T: ComplexField,
301    A: LinearOperator<T>,
302    P: Preconditioner<T>,
303{
304    let n = b.len();
305    let m = config.restart;
306
307    let mut x = Array1::from_elem(n, T::zero());
308
309    // Compute preconditioned RHS norm
310    let pb = precond.apply(b);
311    let b_norm = vector_norm(&pb);
312    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
313    if b_norm < tol_threshold {
314        return GmresSolution {
315            x,
316            iterations: 0,
317            restarts: 0,
318            residual: T::Real::zero(),
319            converged: true,
320            status: SolverStatus::Converged,
321        };
322    }
323
324    let mut total_iterations = 0;
325    let mut restarts = 0;
326
327    for _outer in 0..config.max_iterations {
328        // Compute preconditioned residual r = M⁻¹(b - Ax)
329        let ax = operator.apply(&x);
330        let residual: Array1<T> = b - &ax;
331        let r = precond.apply(&residual);
332        let beta = vector_norm(&r);
333
334        let rel_residual = beta / b_norm;
335        if rel_residual < config.tolerance {
336            return GmresSolution {
337                x,
338                iterations: total_iterations,
339                restarts,
340                residual: rel_residual,
341                converged: true,
342                status: SolverStatus::Converged,
343            };
344        }
345
346        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
347        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
348
349        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
350        let mut cs: Vec<T> = Vec::with_capacity(m);
351        let mut sn: Vec<T> = Vec::with_capacity(m);
352
353        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
354        g[0] = T::from_real(beta);
355
356        let mut inner_converged = false;
357
358        for j in 0..m {
359            total_iterations += 1;
360
361            // w = M⁻¹ * A * v_j
362            let av = operator.apply(&v[j]);
363            let mut w = precond.apply(&av);
364
365            // Modified Gram-Schmidt
366            for i in 0..=j {
367                h[[i, j]] = inner_product(&v[i], &w);
368                let h_ij = h[[i, j]];
369                w = &w - &v[i].mapv(|vi| vi * h_ij);
370            }
371
372            let w_norm = vector_norm(&w);
373            h[[j + 1, j]] = T::from_real(w_norm);
374
375            let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
376            if w_norm < breakdown_tol {
377                inner_converged = true;
378            } else {
379                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
380            }
381
382            // Apply previous Givens rotations
383            for i in 0..j {
384                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
385                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
386                h[[i, j]] = temp;
387            }
388
389            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
390            cs.push(c);
391            sn.push(s);
392
393            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
394            h[[j + 1, j]] = T::zero();
395
396            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
397            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
398            g[j] = temp;
399
400            let abs_residual = g[j + 1].norm();
401            let rel_residual = abs_residual / b_norm;
402            let abs_tol = T::Real::from_f64(1e-20).unwrap();
403
404            if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
405                let y = solve_upper_triangular(&h, &g, j + 1);
406
407                for (i, &yi) in y.iter().enumerate() {
408                    x = &x + &v[i].mapv(|vi| vi * yi);
409                }
410
411                return GmresSolution {
412                    x,
413                    iterations: total_iterations,
414                    restarts,
415                    residual: rel_residual,
416                    converged: true,
417                    status: if inner_converged && rel_residual >= config.tolerance && abs_residual >= abs_tol {
418                        SolverStatus::Breakdown
419                    } else {
420                        SolverStatus::Converged
421                    },
422                };
423            }
424        }
425
426        // Restart
427        let y = solve_upper_triangular(&h, &g, m);
428        for (i, &yi) in y.iter().enumerate() {
429            x = &x + &v[i].mapv(|vi| vi * yi);
430        }
431
432        restarts += 1;
433    }
434
435    // Final residual
436    let ax = operator.apply(&x);
437    let residual: Array1<T> = b - &ax;
438    let r = precond.apply(&residual);
439    let rel_residual = vector_norm(&r) / b_norm;
440
441    GmresSolution {
442        x,
443        iterations: total_iterations,
444        restarts,
445        residual: rel_residual,
446        converged: false,
447        status: SolverStatus::MaxIterationsReached,
448    }
449}
450
451/// GMRES solver with preconditioner and initial guess
452///
453/// Solves Ax = b using left preconditioning: M⁻¹Ax = M⁻¹b
454/// with an optional initial guess x0.
455pub fn gmres_preconditioned_with_guess<T, A, P>(
456    operator: &A,
457    precond: &P,
458    b: &Array1<T>,
459    x0: Option<&Array1<T>>,
460    config: &GmresConfig<T::Real>,
461) -> GmresSolution<T>
462where
463    T: ComplexField,
464    A: LinearOperator<T>,
465    P: Preconditioner<T>,
466{
467    let n = b.len();
468    let m = config.restart;
469
470    // Initialize solution vector from initial guess or zero
471    let mut x = match x0 {
472        Some(guess) => guess.clone(),
473        None => Array1::from_elem(n, T::zero()),
474    };
475
476    // Compute preconditioned RHS norm
477    let pb = precond.apply(b);
478    let b_norm = vector_norm(&pb);
479    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
480    if b_norm < tol_threshold {
481        return GmresSolution {
482            x,
483            iterations: 0,
484            restarts: 0,
485            residual: T::Real::zero(),
486            converged: true,
487            status: SolverStatus::Converged,
488        };
489    }
490
491    let mut total_iterations = 0;
492    let mut restarts = 0;
493
494    for _outer in 0..config.max_iterations {
495        // Compute preconditioned residual r = M⁻¹(b - Ax)
496        let ax = operator.apply(&x);
497        let residual: Array1<T> = b - &ax;
498        let r = precond.apply(&residual);
499        let beta = vector_norm(&r);
500
501        let rel_residual = beta / b_norm;
502        if rel_residual < config.tolerance {
503            return GmresSolution {
504                x,
505                iterations: total_iterations,
506                restarts,
507                residual: rel_residual,
508                converged: true,
509                status: SolverStatus::Converged,
510            };
511        }
512
513        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
514        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
515
516        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
517        let mut cs: Vec<T> = Vec::with_capacity(m);
518        let mut sn: Vec<T> = Vec::with_capacity(m);
519
520        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
521        g[0] = T::from_real(beta);
522
523        let mut inner_converged = false;
524
525        for j in 0..m {
526            total_iterations += 1;
527
528            // w = M⁻¹ * A * v_j
529            let av = operator.apply(&v[j]);
530            let mut w = precond.apply(&av);
531
532            // Modified Gram-Schmidt
533            for i in 0..=j {
534                h[[i, j]] = inner_product(&v[i], &w);
535                let h_ij = h[[i, j]];
536                w = &w - &v[i].mapv(|vi| vi * h_ij);
537            }
538
539            let w_norm = vector_norm(&w);
540            h[[j + 1, j]] = T::from_real(w_norm);
541
542            let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
543            if w_norm < breakdown_tol {
544                inner_converged = true;
545            } else {
546                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
547            }
548
549            // Apply previous Givens rotations
550            for i in 0..j {
551                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
552                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
553                h[[i, j]] = temp;
554            }
555
556            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
557            cs.push(c);
558            sn.push(s);
559
560            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
561            h[[j + 1, j]] = T::zero();
562
563            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
564            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
565            g[j] = temp;
566
567            let abs_residual = g[j + 1].norm();
568            let rel_residual = abs_residual / b_norm;
569            let abs_tol = T::Real::from_f64(1e-20).unwrap();
570
571            if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
572                let y = solve_upper_triangular(&h, &g, j + 1);
573
574                for (i, &yi) in y.iter().enumerate() {
575                    x = &x + &v[i].mapv(|vi| vi * yi);
576                }
577
578                return GmresSolution {
579                    x,
580                    iterations: total_iterations,
581                    restarts,
582                    residual: rel_residual,
583                    converged: true,
584                    status: if inner_converged && rel_residual >= config.tolerance && abs_residual >= abs_tol {
585                        SolverStatus::Breakdown
586                    } else {
587                        SolverStatus::Converged
588                    },
589                };
590            }
591        }
592
593        // Restart
594        let y = solve_upper_triangular(&h, &g, m);
595        for (i, &yi) in y.iter().enumerate() {
596            x = &x + &v[i].mapv(|vi| vi * yi);
597        }
598
599        restarts += 1;
600    }
601
602    // Final residual
603    let ax = operator.apply(&x);
604    let residual: Array1<T> = b - &ax;
605    let r = precond.apply(&residual);
606    let rel_residual = vector_norm(&r) / b_norm;
607
608    GmresSolution {
609        x,
610        iterations: total_iterations,
611        restarts,
612        residual: rel_residual,
613        converged: false,
614        status: SolverStatus::MaxIterationsReached,
615    }
616}
617
618/// Compute Givens rotation coefficients
619#[inline]
620fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
621    let tol = T::Real::from_f64(1e-30).unwrap();
622    if b.norm() < tol {
623        return (T::one(), T::zero());
624    }
625    if a.norm() < tol {
626        return (T::zero(), T::one());
627    }
628
629    let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
630    let c = a * T::from_real(T::Real::one() / r);
631    let s = b * T::from_real(T::Real::one() / r);
632
633    (c, s)
634}
635
636/// Solve upper triangular system Hy = g
637fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
638    let mut y = vec![T::zero(); k];
639    let tol = T::Real::from_f64(1e-30).unwrap();
640
641    for i in (0..k).rev() {
642        let mut sum = g[i];
643        for j in (i + 1)..k {
644            sum -= h[[i, j]] * y[j];
645        }
646        if h[[i, i]].norm() > tol {
647            y[i] = sum * h[[i, i]].inv();
648        }
649    }
650
651    y
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Shorthand for a complex number with zero imaginary part.
    fn re(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    /// Euclidean norm of a sequence of complex entries.
    fn complex_norm<'a>(entries: impl Iterator<Item = &'a Complex64>) -> f64 {
        entries.map(|e| e.norm_sqr()).sum::<f64>().sqrt()
    }

    /// Silent solver configuration with the given limits.
    fn quiet_config(max_iterations: usize, restart: usize, tolerance: f64) -> GmresConfig<f64> {
        GmresConfig {
            max_iterations,
            restart,
            tolerance,
            print_interval: 0,
        }
    }

    /// 2x2 complex system: the solver must converge and satisfy Ax = b.
    #[test]
    fn test_gmres_simple() {
        let dense = array![[re(4.0), re(1.0)], [re(1.0), re(3.0)],];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(1.0), re(2.0)];

        let config = quiet_config(100, 10, 1e-10);
        let solution = gmres(&a, &b, &config);

        assert!(solution.converged, "GMRES should converge");

        // Verify solution: Ax ≈ b
        let ax = a.matvec(&solution.x);
        let defect = &ax - &b;
        let error: f64 = complex_norm(defect.iter());
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    /// Identity system: converges almost immediately with x = b.
    #[test]
    fn test_gmres_identity() {
        let n = 5;
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(n);
        let b = Array1::from_iter((1..=n).map(|i| Complex64::new(i as f64, 0.0)));

        let config = quiet_config(10, 10, 1e-12);
        let solution = gmres(&id, &b, &config);

        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        let defect = &solution.x - &b;
        let error: f64 = complex_norm(defect.iter());
        assert!(error < 1e-10);
    }

    /// Same 2x2 system as above, but with real `f64` scalars.
    #[test]
    fn test_gmres_f64() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0],];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        let config = quiet_config(100, 10, 1e-10);
        let solution = gmres(&a, &b, &config);

        assert!(solution.converged);

        let ax = a.matvec(&solution.x);
        let defect = &ax - &b;
        let error: f64 = defect.iter().map(|e| e * e).sum::<f64>().sqrt();
        assert_relative_eq!(error, 0.0, epsilon = 1e-8);
    }
}