Skip to main content

math_audio_solvers/iterative/
gmres_deflated.rs

1//! Deflated GMRES solver for systems with clustered eigenvalues
2//!
3//! Deflated GMRES projects out a subspace W spanning approximations to
4//! problematic eigenmodes before iterating, dramatically reducing iteration
5//! counts for problems like Helmholtz where eigenvalues cluster near k².
6//!
7//! Reference: Erlangga & Nabben (2008), Gaul et al. (2013).
8
9use crate::blas_helpers::{axpy, inner_product, vector_norm};
10use crate::direct::LuError;
11use crate::direct::lu::{LuFactorization, lu_factorize};
12use crate::traits::{ComplexField, LinearOperator, Preconditioner};
13use ndarray::{Array1, Array2};
14use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
15
16use super::gmres::{GmresConfig, GmresSolution};
17
18/// Pre-computed deflation data for deflated GMRES
19///
20/// Given deflation vectors W = [w₁,...,wᵣ] (n×r, r≪n) and operator A:
21/// - AW = A·W columns
22/// - E = Wᴴ·A·W (r×r dense), factored via LU
23///
24/// Provides projectors:
25/// - Left: P(v) = v - AW·E⁻¹·Wᴴ·v
26/// - Coarse correction: W·E⁻¹·Wᴴ·b
27/// - Recovery: Q(v) = v - W·E⁻¹·Wᴴ·A·v
28#[derive(Debug, Clone)]
29pub struct DeflationSubspace<T: ComplexField> {
30    w_columns: Vec<Array1<T>>,
31    aw_columns: Vec<Array1<T>>,
32    e_lu: LuFactorization<T>,
33}
34
35impl<T: ComplexField> DeflationSubspace<T> {
36    /// Construct deflation subspace from deflation vectors and operator
37    ///
38    /// Computes AW = A·W and factors E = Wᴴ·A·W via LU.
39    ///
40    /// # Arguments
41    /// * `w_columns` - Deflation vectors (should be orthonormal or near-orthonormal)
42    /// * `operator` - The linear operator A
43    pub fn new<A: LinearOperator<T>>(
44        w_columns: Vec<Array1<T>>,
45        operator: &A,
46    ) -> Result<Self, LuError> {
47        let r = w_columns.len();
48        if r == 0 {
49            return Ok(Self {
50                w_columns,
51                aw_columns: Vec::new(),
52                e_lu: LuFactorization {
53                    lu: Array2::from_elem((0, 0), T::zero()),
54                    pivots: Vec::new(),
55                    n: 0,
56                },
57            });
58        }
59
60        // Compute AW columns
61        let aw_columns: Vec<Array1<T>> = w_columns.iter().map(|w| operator.apply(w)).collect();
62
63        // Build E = W^H * A * W (r x r dense)
64        let mut e = Array2::from_elem((r, r), T::zero());
65        for i in 0..r {
66            for j in 0..r {
67                e[[i, j]] = inner_product(&w_columns[i], &aw_columns[j]);
68            }
69        }
70
71        let e_lu = lu_factorize(&e)?;
72
73        Ok(Self {
74            w_columns,
75            aw_columns,
76            e_lu,
77        })
78    }
79
80    /// Number of deflation vectors
81    pub fn num_vectors(&self) -> usize {
82        self.w_columns.len()
83    }
84
85    /// Apply left deflation projector: P(v) = v - AW·E⁻¹·Wᴴ·v
86    pub fn apply_left_projector(&self, v: &Array1<T>) -> Array1<T> {
87        let r = self.w_columns.len();
88        if r == 0 {
89            return v.clone();
90        }
91
92        // Compute W^H * v (r-dimensional)
93        let mut wh_v = Array1::from_elem(r, T::zero());
94        for i in 0..r {
95            wh_v[i] = inner_product(&self.w_columns[i], v);
96        }
97
98        // Solve E * y = W^H * v
99        let y = self
100            .e_lu
101            .solve(&wh_v)
102            .expect("Deflation matrix E should be non-singular");
103
104        // result = v - AW * y
105        let mut result = v.clone();
106        for i in 0..r {
107            axpy(-y[i], &self.aw_columns[i], &mut result);
108        }
109
110        result
111    }
112
113    /// Compute coarse correction: W·E⁻¹·Wᴴ·b
114    pub fn coarse_correction(&self, b: &Array1<T>) -> Array1<T> {
115        let r = self.w_columns.len();
116        let n = b.len();
117        if r == 0 {
118            return Array1::from_elem(n, T::zero());
119        }
120
121        // Compute W^H * b
122        let mut wh_b = Array1::from_elem(r, T::zero());
123        for i in 0..r {
124            wh_b[i] = inner_product(&self.w_columns[i], b);
125        }
126
127        // Solve E * y = W^H * b
128        let y = self
129            .e_lu
130            .solve(&wh_b)
131            .expect("Deflation matrix E should be non-singular");
132
133        // result = W * y
134        let mut result = Array1::from_elem(n, T::zero());
135        for i in 0..r {
136            axpy(y[i], &self.w_columns[i], &mut result);
137        }
138
139        result
140    }
141
142    /// Apply recovery operator: Q(v) = v - W·E⁻¹·Wᴴ·A·v
143    pub fn apply_recovery<A: LinearOperator<T>>(&self, v: &Array1<T>, operator: &A) -> Array1<T> {
144        let r = self.w_columns.len();
145        if r == 0 {
146            return v.clone();
147        }
148
149        // Compute A * v
150        let av = operator.apply(v);
151
152        // Compute W^H * A * v
153        let mut wh_av = Array1::from_elem(r, T::zero());
154        for i in 0..r {
155            wh_av[i] = inner_product(&self.w_columns[i], &av);
156        }
157
158        // Solve E * y = W^H * A * v
159        let y = self
160            .e_lu
161            .solve(&wh_av)
162            .expect("Deflation matrix E should be non-singular");
163
164        // result = v - W * y
165        let mut result = v.clone();
166        for i in 0..r {
167            axpy(-y[i], &self.w_columns[i], &mut result);
168        }
169
170        result
171    }
172}
173
174/// Solve Ax = b using deflated GMRES (no preconditioner)
175///
176/// Projects out the deflation subspace before each Arnoldi step.
177/// After convergence, applies recovery operator and adds coarse correction.
178pub fn gmres_deflated<T, A>(
179    operator: &A,
180    deflation: &DeflationSubspace<T>,
181    b: &Array1<T>,
182    x0: Option<&Array1<T>>,
183    config: &GmresConfig<T::Real>,
184) -> GmresSolution<T>
185where
186    T: ComplexField,
187    A: LinearOperator<T>,
188{
189    use crate::traits::IdentityPreconditioner;
190    let precond = IdentityPreconditioner;
191    gmres_deflated_preconditioned(operator, &precond, deflation, b, x0, config)
192}
193
194/// Solve Ax = b using deflated preconditioned GMRES
195///
196/// Combines left preconditioning M⁻¹ with deflation projection P:
197/// At each Arnoldi step: w = M⁻¹ · P(A · vⱼ)
198/// After convergence to x̂: x = Q(x̂) + coarse_correction(b)
199///
200/// Reference: Erlangga & Nabben (2008)
201pub fn gmres_deflated_preconditioned<T, A, P>(
202    operator: &A,
203    precond: &P,
204    deflation: &DeflationSubspace<T>,
205    b: &Array1<T>,
206    x0: Option<&Array1<T>>,
207    config: &GmresConfig<T::Real>,
208) -> GmresSolution<T>
209where
210    T: ComplexField,
211    A: LinearOperator<T>,
212    P: Preconditioner<T>,
213{
214    // If no deflation vectors, fall back to standard preconditioned GMRES
215    if deflation.num_vectors() == 0 {
216        return super::gmres::gmres_preconditioned_with_guess(operator, precond, b, x0, config);
217    }
218
219    let n = b.len();
220    let m = config.restart;
221
222    // Compute coarse correction: x_c = W·E⁻¹·Wᴴ·b
223    let x_c = deflation.coarse_correction(b);
224
225    // Initialize solution from initial guess or zero
226    let mut x_hat = match x0 {
227        Some(guess) => guess.clone(),
228        None => Array1::from_elem(n, T::zero()),
229    };
230
231    // Compute preconditioned deflated RHS norm for relative tolerance
232    let pb = precond.apply(&deflation.apply_left_projector(b));
233    let b_norm = vector_norm(&pb);
234    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
235    if b_norm < tol_threshold {
236        // Deflated RHS is essentially zero — coarse correction is the solution
237        return GmresSolution {
238            x: x_c,
239            iterations: 0,
240            restarts: 0,
241            residual: T::Real::zero(),
242            converged: true,
243            status: crate::traits::SolverStatus::Converged,
244        };
245    }
246
247    let mut total_iterations = 0;
248    let mut restarts = 0;
249
250    // Outer iteration (restarts)
251    for _outer in 0..config.max_iterations {
252        // Compute deflated preconditioned residual: M⁻¹ · P · (b - A · x̂)
253        let ax = operator.apply(&x_hat);
254        let residual: Array1<T> = b - &ax;
255        let deflated_residual = deflation.apply_left_projector(&residual);
256        let r = precond.apply(&deflated_residual);
257        let beta = vector_norm(&r);
258
259        let rel_residual = beta / b_norm;
260        if rel_residual < config.tolerance {
261            // Apply recovery and coarse correction
262            let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
263            return GmresSolution {
264                x,
265                iterations: total_iterations,
266                restarts,
267                residual: rel_residual,
268                converged: true,
269                status: crate::traits::SolverStatus::Converged,
270            };
271        }
272
273        // Initialize Krylov basis
274        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
275        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
276
277        // Upper Hessenberg matrix
278        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
279
280        // Givens rotation coefficients
281        let mut cs: Vec<T> = Vec::with_capacity(m);
282        let mut sn: Vec<T> = Vec::with_capacity(m);
283
284        // Right-hand side of least squares problem
285        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
286        g[0] = T::from_real(beta);
287
288        let mut inner_converged = false;
289
290        // Inner iteration (Arnoldi process with deflation)
291        for j in 0..m {
292            total_iterations += 1;
293
294            // Deflated preconditioned matvec: w = M⁻¹ · P · A · vⱼ
295            let av = operator.apply(&v[j]);
296            let pav = deflation.apply_left_projector(&av);
297            let mut w = precond.apply(&pav);
298
299            // Modified Gram-Schmidt orthogonalization
300            for i in 0..=j {
301                h[[i, j]] = inner_product(&v[i], &w);
302                let h_ij = h[[i, j]];
303                w = &w - &v[i].mapv(|vi| vi * h_ij);
304            }
305
306            let w_norm = vector_norm(&w);
307            h[[j + 1, j]] = T::from_real(w_norm);
308
309            // Check for breakdown
310            let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
311            if w_norm < breakdown_tol {
312                inner_converged = true;
313            } else {
314                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
315            }
316
317            // Apply previous Givens rotations to new column of H
318            for i in 0..j {
319                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
320                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
321                h[[i, j]] = temp;
322            }
323
324            // Compute new Givens rotation
325            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
326            cs.push(c);
327            sn.push(s);
328
329            // Apply Givens rotation to H and g
330            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
331            h[[j + 1, j]] = T::zero();
332
333            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
334            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
335            g[j] = temp;
336
337            // Check convergence
338            let abs_residual = g[j + 1].norm();
339            let rel_residual = abs_residual / b_norm;
340            let abs_tol = T::Real::from_f64(1e-20).unwrap();
341
342            if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
343                log::info!(
344                    "Deflated GMRES iteration {} (restart {}): relative residual = {:.6e}",
345                    total_iterations,
346                    restarts,
347                    rel_residual.to_f64().unwrap_or(0.0)
348                );
349            }
350
351            if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
352                // Solve upper triangular system Hy = g
353                let y = solve_upper_triangular(&h, &g, j + 1);
354
355                // Update x̂ = x̂ + V * y
356                for (i, &yi) in y.iter().enumerate() {
357                    x_hat = &x_hat + &v[i].mapv(|vi| vi * yi);
358                }
359
360                // Apply recovery and coarse correction: x = Q(x̂) + x_c
361                let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
362
363                let status = if inner_converged
364                    && rel_residual >= config.tolerance
365                    && abs_residual >= abs_tol
366                {
367                    crate::traits::SolverStatus::Breakdown
368                } else {
369                    crate::traits::SolverStatus::Converged
370                };
371
372                return GmresSolution {
373                    x,
374                    iterations: total_iterations,
375                    restarts,
376                    residual: rel_residual,
377                    converged: true,
378                    status,
379                };
380            }
381        }
382
383        // Maximum inner iterations reached — update x̂ and restart
384        let y = solve_upper_triangular(&h, &g, m);
385        for (i, &yi) in y.iter().enumerate() {
386            x_hat = &x_hat + &v[i].mapv(|vi| vi * yi);
387        }
388
389        restarts += 1;
390    }
391
392    // Did not converge — still apply recovery for best-effort solution
393    let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
394
395    // Compute final true residual
396    let ax = operator.apply(&x);
397    let r: Array1<T> = b - &ax;
398    let r_norm = vector_norm(&r);
399    let b_true_norm = vector_norm(b);
400    let rel_residual = if b_true_norm > T::Real::from_f64(1e-15).unwrap() {
401        r_norm / b_true_norm
402    } else {
403        r_norm
404    };
405
406    GmresSolution {
407        x,
408        iterations: total_iterations,
409        restarts,
410        residual: rel_residual,
411        converged: false,
412        status: crate::traits::SolverStatus::MaxIterationsReached,
413    }
414}
415
416/// Compute Givens rotation coefficients
417#[inline]
418fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
419    let tol = T::Real::from_f64(1e-30).unwrap();
420    if b.norm() < tol {
421        return (T::one(), T::zero());
422    }
423    if a.norm() < tol {
424        return (T::zero(), T::one());
425    }
426
427    let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
428    let c = a * T::from_real(T::Real::one() / r);
429    let s = b * T::from_real(T::Real::one() / r);
430
431    (c, s)
432}
433
434/// Solve upper triangular system Hy = g
435fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
436    let mut y = vec![T::zero(); k];
437    let tol = T::Real::from_f64(1e-30).unwrap();
438
439    for i in (0..k).rev() {
440        let mut sum = g[i];
441        for j in (i + 1)..k {
442            sum -= h[[i, j]] * y[j];
443        }
444        if h[[i, i]].norm() > tol {
445            y[i] = sum * h[[i, i]].inv();
446        }
447    }
448
449    y
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Purely real complex number.
    fn real(x: f64) -> Complex64 {
        Complex64::new(x, 0.0)
    }

    /// Euclidean norm of a complex vector.
    fn norm_c(v: &Array1<Complex64>) -> f64 {
        v.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt()
    }

    /// Unit vector e_k of length `n`.
    fn unit(n: usize, k: usize) -> Array1<Complex64> {
        let mut e = Array1::from_elem(n, Complex64::ZERO);
        e[k] = real(1.0);
        e
    }

    /// Shared solver settings: 50 cycles of 10 steps, tolerance 1e-10, no logging.
    fn base_config() -> GmresConfig<f64> {
        GmresConfig {
            max_iterations: 50,
            restart: 10,
            tolerance: 1e-10,
            print_interval: 0,
        }
    }

    /// Diagonal operator whose first n-3 entries cluster around `cluster_center`
    /// (spacing 0.01) and whose last three are well separated at 10·(i+1).
    fn build_clustered_diagonal(n: usize, cluster_center: f64) -> CsrMatrix<Complex64> {
        let mid = (n as f64 - 3.0) / 2.0;
        let mut dense = Array2::from_elem((n, n), Complex64::new(0.0, 0.0));
        for i in 0..n {
            let value = if i < n - 3 {
                cluster_center + 0.01 * (i as f64 - mid)
            } else {
                (i + 1) as f64 * 10.0
            };
            dense[[i, i]] = real(value);
        }
        CsrMatrix::from_dense(&dense, 1e-15)
    }

    #[test]
    fn test_deflation_subspace_construction() {
        // diag(1, 2, 3): the coordinate vectors are exact eigenvectors
        let dense = array![
            [real(1.0), Complex64::ZERO, Complex64::ZERO],
            [Complex64::ZERO, real(2.0), Complex64::ZERO],
            [Complex64::ZERO, Complex64::ZERO, real(3.0)],
        ];
        let a = CsrMatrix::from_dense(&dense, 1e-15);

        // Deflate the eigenvector belonging to eigenvalue 1
        let deflation = DeflationSubspace::new(vec![unit(3, 0)], &a).unwrap();
        assert_eq!(deflation.num_vectors(), 1);

        // The deflated eigenvector is projected away entirely...
        let pe1 = deflation.apply_left_projector(&unit(3, 0));
        assert!(norm_c(&pe1) < 1e-12, "Deflated eigenvector should be zero");

        // ...while an orthogonal direction survives
        let pe2 = deflation.apply_left_projector(&unit(3, 1));
        assert!(norm_c(&pe2) > 0.5, "Non-deflated vector should remain");
    }

    #[test]
    fn test_deflated_gmres_fewer_iterations() {
        let n = 20;
        let a = build_clustered_diagonal(n, 5.0);
        let b: Array1<Complex64> = (1..=n).map(|k| real(k as f64)).collect();
        let config = GmresConfig {
            max_iterations: 200,
            restart: 30,
            ..base_config()
        };

        let sol_standard = super::super::gmres::gmres(&a, &b, &config);

        // Exact eigenvectors for the clustered part of the spectrum
        let w_cols: Vec<_> = (0..n - 3).map(|k| unit(n, k)).collect();
        let deflation = DeflationSubspace::new(w_cols, &a).unwrap();
        let sol_deflated = gmres_deflated(&a, &deflation, &b, None, &config);

        assert!(sol_standard.converged, "Standard GMRES should converge");
        assert!(sol_deflated.converged, "Deflated GMRES should converge");
        assert!(
            sol_deflated.iterations <= sol_standard.iterations,
            "Deflated ({}) should use <= iterations than standard ({})",
            sol_deflated.iterations,
            sol_standard.iterations
        );

        let residual = &a.matvec(&sol_deflated.x) - &b;
        assert!(
            norm_c(&residual) < 1e-8,
            "Deflated solution should satisfy Ax = b"
        );
    }

    #[test]
    fn test_deflated_gmres_with_preconditioner() {
        let dense = array![
            [real(4.0), real(1.0), Complex64::ZERO],
            [real(1.0), real(3.0), real(1.0)],
            [Complex64::ZERO, real(1.0), real(5.0)],
        ];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![real(1.0), real(2.0), real(3.0)];

        // Jacobi (diagonal) preconditioner
        let precond = crate::preconditioners::DiagonalPreconditioner::from_csr(&a);

        // One normalized approximate eigenvector
        let w1 = array![real(0.5), real(0.7), real(0.5)];
        let scale = real(1.0 / vector_norm(&w1));
        let deflation = DeflationSubspace::new(vec![w1.mapv(|v| v * scale)], &a).unwrap();

        let config = base_config();
        let sol = gmres_deflated_preconditioned(&a, &precond, &deflation, &b, None, &config);
        assert!(
            sol.converged,
            "Deflated preconditioned GMRES should converge"
        );

        let residual = &a.matvec(&sol.x) - &b;
        assert!(norm_c(&residual) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_deflated_gmres_zero_vectors_fallback() {
        // With r = 0 the solver must defer to standard GMRES
        let dense = array![[real(4.0), real(1.0)], [real(1.0), real(3.0)]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![real(1.0), real(2.0)];

        let deflation = DeflationSubspace::new(Vec::new(), &a).unwrap();
        assert_eq!(deflation.num_vectors(), 0);

        let config = base_config();
        let sol = gmres_deflated(&a, &deflation, &b, None, &config);
        assert!(sol.converged, "Zero-vector deflated GMRES should converge");

        let sol_standard = super::super::gmres::gmres(&a, &b, &config);
        let diff = &sol.x - &sol_standard.x;
        assert!(
            norm_c(&diff) < 1e-8,
            "Zero-deflation should match standard GMRES solution"
        );
    }

    #[test]
    fn test_coarse_correction_exact_for_deflation_space() {
        // Deflating every eigenvector of diag(2, 3) makes the coarse solve exact
        let dense = array![[real(2.0), Complex64::ZERO], [Complex64::ZERO, real(3.0)]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let deflation = DeflationSubspace::new(vec![unit(2, 0), unit(2, 1)], &a).unwrap();

        let b = array![real(4.0), real(9.0)];
        let x_c = deflation.coarse_correction(&b);

        // x_c should equal A⁻¹b = [2, 3]
        let residual = &a.matvec(&x_c) - &b;
        assert!(
            norm_c(&residual) < 1e-10,
            "Coarse correction should be exact solution"
        );
    }

    #[test]
    fn test_deflated_gmres_f64() {
        let a = CsrMatrix::from_dense(&array![[4.0_f64, 1.0], [1.0, 3.0]], 1e-15);
        let b = array![1.0_f64, 2.0];

        // One (re-)normalized approximate eigenvector
        let w1 = array![
            std::f64::consts::FRAC_1_SQRT_2,
            std::f64::consts::FRAC_1_SQRT_2
        ];
        let len = vector_norm(&w1);
        let deflation = DeflationSubspace::new(vec![w1.mapv(|v| v / len)], &a).unwrap();

        let config = base_config();
        let sol = gmres_deflated(&a, &deflation, &b, None, &config);
        assert!(sol.converged);

        let residual = &a.matvec(&sol.x) - &b;
        let error: f64 = residual.iter().map(|e| e * e).sum::<f64>().sqrt();
        assert_relative_eq!(error, 0.0, epsilon = 1e-8);
    }
}