math_audio_solvers/iterative/
gmres.rs

1//! GMRES (Generalized Minimal Residual) solver
2//!
3//! Implementation of the restarted GMRES algorithm based on Saad & Schultz (1986).
4//!
5//! GMRES is often the best choice for large non-symmetric systems.
//! It minimizes the residual norm over a growing Krylov subspace, so the residual
//! never increases from one iteration to the next (restarted GMRES can, however, stagnate).
8
9use crate::blas_helpers::{axpy, inner_product, vector_norm};
10use crate::traits::{ComplexField, LinearOperator, Preconditioner};
11use ndarray::{Array1, Array2};
12use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
13
/// Tuning parameters for the GMRES solvers.
///
/// `R` is the real scalar type in which the tolerance is expressed.
#[derive(Debug, Clone)]
pub struct GmresConfig<R> {
    /// Upper bound on outer iterations, i.e. restart cycles (not individual Arnoldi steps).
    pub max_iterations: usize,
    /// Krylov subspace dimension per cycle: inner iterations performed before a restart.
    pub restart: usize,
    /// Convergence threshold on the relative residual.
    pub tolerance: R,
    /// Progress-report interval, in inner iterations; `0` disables reporting.
    pub print_interval: usize,
}
26
27impl Default for GmresConfig<f64> {
28    fn default() -> Self {
29        Self {
30            max_iterations: 100,
31            restart: 30,
32            tolerance: 1e-6,
33            print_interval: 0,
34        }
35    }
36}
37
38impl Default for GmresConfig<f32> {
39    fn default() -> Self {
40        Self {
41            max_iterations: 100,
42            restart: 30,
43            tolerance: 1e-5,
44            print_interval: 0,
45        }
46    }
47}
48
49impl<R: Float + FromPrimitive> GmresConfig<R> {
50    /// Create config for small problems
51    pub fn for_small_problems() -> Self {
52        Self {
53            max_iterations: 50,
54            restart: 50,
55            tolerance: R::from_f64(1e-8).unwrap(),
56            print_interval: 0,
57        }
58    }
59
60    /// Create config with specific restart parameter
61    pub fn with_restart(restart: usize) -> Self
62    where
63        Self: Default,
64    {
65        Self {
66            restart,
67            ..Default::default()
68        }
69    }
70}
71
72/// GMRES solver result
73#[derive(Debug)]
74pub struct GmresSolution<T: ComplexField> {
75    /// Solution vector
76    pub x: Array1<T>,
77    /// Total number of matrix-vector products
78    pub iterations: usize,
79    /// Number of restarts performed
80    pub restarts: usize,
81    /// Final relative residual
82    pub residual: T::Real,
83    /// Whether convergence was achieved
84    pub converged: bool,
85}
86
87/// Solve Ax = b using the restarted GMRES method
88///
89/// # Arguments
90/// * `operator` - Linear operator representing A
91/// * `b` - Right-hand side vector
92/// * `config` - Solver configuration
93///
94/// # Returns
95/// Solution struct containing x, iteration count, and convergence info
96pub fn gmres<T, A>(operator: &A, b: &Array1<T>, config: &GmresConfig<T::Real>) -> GmresSolution<T>
97where
98    T: ComplexField,
99    A: LinearOperator<T>,
100{
101    gmres_with_guess(operator, b, None, config)
102}
103
104/// Solve Ax = b using GMRES with an initial guess
105pub fn gmres_with_guess<T, A>(
106    operator: &A,
107    b: &Array1<T>,
108    x0: Option<&Array1<T>>,
109    config: &GmresConfig<T::Real>,
110) -> GmresSolution<T>
111where
112    T: ComplexField,
113    A: LinearOperator<T>,
114{
115    let n = b.len();
116    let m = config.restart;
117
118    // Initialize solution vector
119    let mut x = match x0 {
120        Some(x0) => x0.clone(),
121        None => Array1::from_elem(n, T::zero()),
122    };
123
124    // Compute initial residual norm for relative tolerance
125    let b_norm = vector_norm(b);
126    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
127    if b_norm < tol_threshold {
128        return GmresSolution {
129            x,
130            iterations: 0,
131            restarts: 0,
132            residual: T::Real::zero(),
133            converged: true,
134        };
135    }
136
137    let mut total_iterations = 0;
138    let mut restarts = 0;
139
140    // Outer iteration (restarts)
141    for _outer in 0..config.max_iterations {
142        // Compute residual r = b - Ax
143        let ax = operator.apply(&x);
144        let r: Array1<T> = b - &ax;
145        let beta = vector_norm(&r);
146
147        // Check convergence
148        let rel_residual = beta / b_norm;
149        if rel_residual < config.tolerance {
150            return GmresSolution {
151                x,
152                iterations: total_iterations,
153                restarts,
154                residual: rel_residual,
155                converged: true,
156            };
157        }
158
159        // Initialize Krylov basis V
160        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
161        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
162
163        // Upper Hessenberg matrix H
164        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
165
166        // Givens rotation coefficients
167        let mut cs: Vec<T> = Vec::with_capacity(m);
168        let mut sn: Vec<T> = Vec::with_capacity(m);
169
170        // Right-hand side of least squares problem
171        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
172        g[0] = T::from_real(beta);
173
174        let mut inner_converged = false;
175
176        // Inner iteration (Arnoldi process)
177        for j in 0..m {
178            total_iterations += 1;
179
180            // w = A * v_j
181            let mut w = operator.apply(&v[j]);
182
183            // Modified Gram-Schmidt orthogonalization
184            for i in 0..=j {
185                h[[i, j]] = inner_product(&v[i], &w);
186                let h_ij = h[[i, j]];
187                axpy(-h_ij, &v[i], &mut w);
188            }
189
190            let w_norm = vector_norm(&w);
191            h[[j + 1, j]] = T::from_real(w_norm);
192
193            // Check for breakdown
194            let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
195            if w_norm < breakdown_tol {
196                inner_converged = true;
197            } else {
198                let inv_norm = T::from_real(T::Real::one() / w_norm);
199                let mut new_v = w.clone();
200                axpy(inv_norm - T::one(), &w, &mut new_v);
201                v.push(new_v);
202            }
203
204            // Apply previous Givens rotations to new column of H
205            for i in 0..j {
206                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
207                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
208                h[[i, j]] = temp;
209            }
210
211            // Compute new Givens rotation
212            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
213            cs.push(c);
214            sn.push(s);
215
216            // Apply Givens rotation to H and g
217            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
218            h[[j + 1, j]] = T::zero();
219
220            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
221            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
222            g[j] = temp;
223
224            // Check convergence
225            let rel_residual = g[j + 1].norm() / b_norm;
226
227            if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
228                log::info!(
229                    "GMRES iteration {} (restart {}): relative residual = {:.6e}",
230                    total_iterations,
231                    restarts,
232                    rel_residual.to_f64().unwrap_or(0.0)
233                );
234            }
235
236            if rel_residual < config.tolerance || inner_converged {
237                // Solve upper triangular system Hy = g
238                let y = solve_upper_triangular(&h, &g, j + 1);
239
240                // Update solution x = x + V * y
241                for (i, &yi) in y.iter().enumerate() {
242                    axpy(yi, &v[i], &mut x);
243                }
244
245                return GmresSolution {
246                    x,
247                    iterations: total_iterations,
248                    restarts,
249                    residual: rel_residual,
250                    converged: true,
251                };
252            }
253        }
254
255        // Maximum inner iterations reached, compute solution and restart
256        let y = solve_upper_triangular(&h, &g, m);
257
258        for (i, &yi) in y.iter().enumerate() {
259            axpy(yi, &v[i], &mut x);
260        }
261
262        restarts += 1;
263    }
264
265    // Compute final residual
266    let ax = operator.apply(&x);
267    let r: Array1<T> = b - &ax;
268    let rel_residual = vector_norm(&r) / b_norm;
269
270    GmresSolution {
271        x,
272        iterations: total_iterations,
273        restarts,
274        residual: rel_residual,
275        converged: false,
276    }
277}
278
279/// GMRES solver with preconditioner
280///
281/// Solves Ax = b using left preconditioning: M⁻¹Ax = M⁻¹b
282pub fn gmres_preconditioned<T, A, P>(
283    operator: &A,
284    precond: &P,
285    b: &Array1<T>,
286    config: &GmresConfig<T::Real>,
287) -> GmresSolution<T>
288where
289    T: ComplexField,
290    A: LinearOperator<T>,
291    P: Preconditioner<T>,
292{
293    let n = b.len();
294    let m = config.restart;
295
296    let mut x = Array1::from_elem(n, T::zero());
297
298    // Compute preconditioned RHS norm
299    let pb = precond.apply(b);
300    let b_norm = vector_norm(&pb);
301    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
302    if b_norm < tol_threshold {
303        return GmresSolution {
304            x,
305            iterations: 0,
306            restarts: 0,
307            residual: T::Real::zero(),
308            converged: true,
309        };
310    }
311
312    let mut total_iterations = 0;
313    let mut restarts = 0;
314
315    for _outer in 0..config.max_iterations {
316        // Compute preconditioned residual r = M⁻¹(b - Ax)
317        let ax = operator.apply(&x);
318        let residual: Array1<T> = b - &ax;
319        let r = precond.apply(&residual);
320        let beta = vector_norm(&r);
321
322        let rel_residual = beta / b_norm;
323        if rel_residual < config.tolerance {
324            return GmresSolution {
325                x,
326                iterations: total_iterations,
327                restarts,
328                residual: rel_residual,
329                converged: true,
330            };
331        }
332
333        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
334        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
335
336        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
337        let mut cs: Vec<T> = Vec::with_capacity(m);
338        let mut sn: Vec<T> = Vec::with_capacity(m);
339
340        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
341        g[0] = T::from_real(beta);
342
343        let mut inner_converged = false;
344
345        for j in 0..m {
346            total_iterations += 1;
347
348            // w = M⁻¹ * A * v_j
349            let av = operator.apply(&v[j]);
350            let mut w = precond.apply(&av);
351
352            // Modified Gram-Schmidt
353            for i in 0..=j {
354                h[[i, j]] = inner_product(&v[i], &w);
355                let h_ij = h[[i, j]];
356                w = &w - &v[i].mapv(|vi| vi * h_ij);
357            }
358
359            let w_norm = vector_norm(&w);
360            h[[j + 1, j]] = T::from_real(w_norm);
361
362            let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
363            if w_norm < breakdown_tol {
364                inner_converged = true;
365            } else {
366                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
367            }
368
369            // Apply previous Givens rotations
370            for i in 0..j {
371                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
372                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
373                h[[i, j]] = temp;
374            }
375
376            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
377            cs.push(c);
378            sn.push(s);
379
380            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
381            h[[j + 1, j]] = T::zero();
382
383            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
384            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
385            g[j] = temp;
386
387            let rel_residual = g[j + 1].norm() / b_norm;
388
389            if rel_residual < config.tolerance || inner_converged {
390                let y = solve_upper_triangular(&h, &g, j + 1);
391
392                for (i, &yi) in y.iter().enumerate() {
393                    x = &x + &v[i].mapv(|vi| vi * yi);
394                }
395
396                return GmresSolution {
397                    x,
398                    iterations: total_iterations,
399                    restarts,
400                    residual: rel_residual,
401                    converged: true,
402                };
403            }
404        }
405
406        // Restart
407        let y = solve_upper_triangular(&h, &g, m);
408        for (i, &yi) in y.iter().enumerate() {
409            x = &x + &v[i].mapv(|vi| vi * yi);
410        }
411
412        restarts += 1;
413    }
414
415    // Final residual
416    let ax = operator.apply(&x);
417    let residual: Array1<T> = b - &ax;
418    let r = precond.apply(&residual);
419    let rel_residual = vector_norm(&r) / b_norm;
420
421    GmresSolution {
422        x,
423        iterations: total_iterations,
424        restarts,
425        residual: rel_residual,
426        converged: false,
427    }
428}
429
430/// GMRES solver with preconditioner and initial guess
431///
432/// Solves Ax = b using left preconditioning: M⁻¹Ax = M⁻¹b
433/// with an optional initial guess x0.
434pub fn gmres_preconditioned_with_guess<T, A, P>(
435    operator: &A,
436    precond: &P,
437    b: &Array1<T>,
438    x0: Option<&Array1<T>>,
439    config: &GmresConfig<T::Real>,
440) -> GmresSolution<T>
441where
442    T: ComplexField,
443    A: LinearOperator<T>,
444    P: Preconditioner<T>,
445{
446    let n = b.len();
447    let m = config.restart;
448
449    // Initialize solution vector from initial guess or zero
450    let mut x = match x0 {
451        Some(guess) => guess.clone(),
452        None => Array1::from_elem(n, T::zero()),
453    };
454
455    // Compute preconditioned RHS norm
456    let pb = precond.apply(b);
457    let b_norm = vector_norm(&pb);
458    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
459    if b_norm < tol_threshold {
460        return GmresSolution {
461            x,
462            iterations: 0,
463            restarts: 0,
464            residual: T::Real::zero(),
465            converged: true,
466        };
467    }
468
469    let mut total_iterations = 0;
470    let mut restarts = 0;
471
472    for _outer in 0..config.max_iterations {
473        // Compute preconditioned residual r = M⁻¹(b - Ax)
474        let ax = operator.apply(&x);
475        let residual: Array1<T> = b - &ax;
476        let r = precond.apply(&residual);
477        let beta = vector_norm(&r);
478
479        let rel_residual = beta / b_norm;
480        if rel_residual < config.tolerance {
481            return GmresSolution {
482                x,
483                iterations: total_iterations,
484                restarts,
485                residual: rel_residual,
486                converged: true,
487            };
488        }
489
490        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
491        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
492
493        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
494        let mut cs: Vec<T> = Vec::with_capacity(m);
495        let mut sn: Vec<T> = Vec::with_capacity(m);
496
497        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
498        g[0] = T::from_real(beta);
499
500        let mut inner_converged = false;
501
502        for j in 0..m {
503            total_iterations += 1;
504
505            // w = M⁻¹ * A * v_j
506            let av = operator.apply(&v[j]);
507            let mut w = precond.apply(&av);
508
509            // Modified Gram-Schmidt
510            for i in 0..=j {
511                h[[i, j]] = inner_product(&v[i], &w);
512                let h_ij = h[[i, j]];
513                w = &w - &v[i].mapv(|vi| vi * h_ij);
514            }
515
516            let w_norm = vector_norm(&w);
517            h[[j + 1, j]] = T::from_real(w_norm);
518
519            let breakdown_tol = T::Real::from_f64(1e-14).unwrap();
520            if w_norm < breakdown_tol {
521                inner_converged = true;
522            } else {
523                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
524            }
525
526            // Apply previous Givens rotations
527            for i in 0..j {
528                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
529                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
530                h[[i, j]] = temp;
531            }
532
533            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
534            cs.push(c);
535            sn.push(s);
536
537            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
538            h[[j + 1, j]] = T::zero();
539
540            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
541            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
542            g[j] = temp;
543
544            let rel_residual = g[j + 1].norm() / b_norm;
545
546            if rel_residual < config.tolerance || inner_converged {
547                let y = solve_upper_triangular(&h, &g, j + 1);
548
549                for (i, &yi) in y.iter().enumerate() {
550                    x = &x + &v[i].mapv(|vi| vi * yi);
551                }
552
553                return GmresSolution {
554                    x,
555                    iterations: total_iterations,
556                    restarts,
557                    residual: rel_residual,
558                    converged: true,
559                };
560            }
561        }
562
563        // Restart
564        let y = solve_upper_triangular(&h, &g, m);
565        for (i, &yi) in y.iter().enumerate() {
566            x = &x + &v[i].mapv(|vi| vi * yi);
567        }
568
569        restarts += 1;
570    }
571
572    // Final residual
573    let ax = operator.apply(&x);
574    let residual: Array1<T> = b - &ax;
575    let r = precond.apply(&residual);
576    let rel_residual = vector_norm(&r) / b_norm;
577
578    GmresSolution {
579        x,
580        iterations: total_iterations,
581        restarts,
582        residual: rel_residual,
583        converged: false,
584    }
585}
586
587/// Compute Givens rotation coefficients
588#[inline]
589fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
590    let tol = T::Real::from_f64(1e-30).unwrap();
591    if b.norm() < tol {
592        return (T::one(), T::zero());
593    }
594    if a.norm() < tol {
595        return (T::zero(), T::one());
596    }
597
598    let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
599    let c = a * T::from_real(T::Real::one() / r);
600    let s = b * T::from_real(T::Real::one() / r);
601
602    (c, s)
603}
604
605/// Solve upper triangular system Hy = g
606fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
607    let mut y = vec![T::zero(); k];
608    let tol = T::Real::from_f64(1e-30).unwrap();
609
610    for i in (0..k).rev() {
611        let mut sum = g[i];
612        for j in (i + 1)..k {
613            sum -= h[[i, j]] * y[j];
614        }
615        if h[[i, i]].norm() > tol {
616            y[i] = sum * h[[i, i]].inv();
617        }
618    }
619
620    y
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Build a quiet test configuration with the given limits.
    fn quiet_config(max_iterations: usize, restart: usize, tolerance: f64) -> GmresConfig<f64> {
        GmresConfig {
            max_iterations,
            restart,
            tolerance,
            print_interval: 0,
        }
    }

    /// Euclidean distance between two complex vectors.
    fn complex_distance(lhs: &Array1<Complex64>, rhs: &Array1<Complex64>) -> f64 {
        let diff = lhs - rhs;
        diff.iter().map(|d| d.norm_sqr()).sum::<f64>().sqrt()
    }

    /// Euclidean distance between two real vectors.
    fn real_distance(lhs: &Array1<f64>, rhs: &Array1<f64>) -> f64 {
        let diff = lhs - rhs;
        diff.iter().map(|d| d * d).sum::<f64>().sqrt()
    }

    #[test]
    fn test_gmres_simple() {
        let matrix = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];
        let op = CsrMatrix::from_dense(&matrix, 1e-15);
        let rhs = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let result = gmres(&op, &rhs, &quiet_config(100, 10, 1e-10));
        assert!(result.converged, "GMRES should converge");

        // The returned vector must satisfy Ax ≈ b.
        let applied = op.matvec(&result.x);
        assert!(
            complex_distance(&applied, &rhs) < 1e-8,
            "Solution should satisfy Ax = b"
        );
    }

    #[test]
    fn test_gmres_identity() {
        let size = 5;
        let op: CsrMatrix<Complex64> = CsrMatrix::identity(size);
        let rhs = Array1::from_iter((1..=size).map(|k| Complex64::new(k as f64, 0.0)));

        let result = gmres(&op, &rhs, &quiet_config(10, 10, 1e-12));

        assert!(result.converged);
        assert!(result.iterations <= 2);
        assert!(complex_distance(&result.x, &rhs) < 1e-10);
    }

    #[test]
    fn test_gmres_f64() {
        let matrix = array![[4.0_f64, 1.0], [1.0, 3.0],];
        let op = CsrMatrix::from_dense(&matrix, 1e-15);
        let rhs = array![1.0_f64, 2.0];

        let result = gmres(&op, &rhs, &quiet_config(100, 10, 1e-10));
        assert!(result.converged);

        let applied = op.matvec(&result.x);
        assert_relative_eq!(real_distance(&applied, &rhs), 0.0, epsilon = 1e-8);
    }
}