Skip to main content

math_audio_solvers/iterative/
gmres_deflated.rs

//! Deflated GMRES solver for systems with clustered eigenvalues
//!
//! Deflated GMRES projects out a subspace W spanning approximations to
//! problematic eigenmodes before iterating, which can dramatically reduce
//! iteration counts for problems such as Helmholtz, whose eigenvalues cluster near k².
//!
//! Reference: Erlangga & Nabben (2008), Gaul et al. (2013).
8
9use crate::blas_helpers::{axpy, inner_product, vector_norm};
10use crate::direct::lu::{LuFactorization, lu_factorize};
11use crate::direct::LuError;
12use crate::traits::{ComplexField, LinearOperator, Preconditioner};
13use ndarray::{Array1, Array2};
14use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
15
16use super::gmres::{GmresConfig, GmresSolution};
17
18/// Pre-computed deflation data for deflated GMRES
19///
20/// Given deflation vectors W = [w₁,...,wᵣ] (n×r, r≪n) and operator A:
21/// - AW = A·W columns
22/// - E = Wᴴ·A·W (r×r dense), factored via LU
23///
24/// Provides projectors:
25/// - Left: P(v) = v - AW·E⁻¹·Wᴴ·v
26/// - Coarse correction: W·E⁻¹·Wᴴ·b
27/// - Recovery: Q(v) = v - W·E⁻¹·Wᴴ·A·v
28#[derive(Debug, Clone)]
29pub struct DeflationSubspace<T: ComplexField> {
30    w_columns: Vec<Array1<T>>,
31    aw_columns: Vec<Array1<T>>,
32    e_lu: LuFactorization<T>,
33}
34
35impl<T: ComplexField> DeflationSubspace<T> {
36    /// Construct deflation subspace from deflation vectors and operator
37    ///
38    /// Computes AW = A·W and factors E = Wᴴ·A·W via LU.
39    ///
40    /// # Arguments
41    /// * `w_columns` - Deflation vectors (should be orthonormal or near-orthonormal)
42    /// * `operator` - The linear operator A
43    pub fn new<A: LinearOperator<T>>(
44        w_columns: Vec<Array1<T>>,
45        operator: &A,
46    ) -> Result<Self, LuError> {
47        let r = w_columns.len();
48        if r == 0 {
49            return Ok(Self {
50                w_columns,
51                aw_columns: Vec::new(),
52                e_lu: LuFactorization {
53                    lu: Array2::from_elem((0, 0), T::zero()),
54                    pivots: Vec::new(),
55                    n: 0,
56                },
57            });
58        }
59
60        // Compute AW columns
61        let aw_columns: Vec<Array1<T>> = w_columns.iter().map(|w| operator.apply(w)).collect();
62
63        // Build E = W^H * A * W (r x r dense)
64        let mut e = Array2::from_elem((r, r), T::zero());
65        for i in 0..r {
66            for j in 0..r {
67                e[[i, j]] = inner_product(&w_columns[i], &aw_columns[j]);
68            }
69        }
70
71        let e_lu = lu_factorize(&e)?;
72
73        Ok(Self {
74            w_columns,
75            aw_columns,
76            e_lu,
77        })
78    }
79
80    /// Number of deflation vectors
81    pub fn num_vectors(&self) -> usize {
82        self.w_columns.len()
83    }
84
85    /// Apply left deflation projector: P(v) = v - AW·E⁻¹·Wᴴ·v
86    pub fn apply_left_projector(&self, v: &Array1<T>) -> Array1<T> {
87        let r = self.w_columns.len();
88        if r == 0 {
89            return v.clone();
90        }
91
92        // Compute W^H * v (r-dimensional)
93        let mut wh_v = Array1::from_elem(r, T::zero());
94        for i in 0..r {
95            wh_v[i] = inner_product(&self.w_columns[i], v);
96        }
97
98        // Solve E * y = W^H * v
99        let y = self
100            .e_lu
101            .solve(&wh_v)
102            .expect("Deflation matrix E should be non-singular");
103
104        // result = v - AW * y
105        let mut result = v.clone();
106        for i in 0..r {
107            axpy(-y[i], &self.aw_columns[i], &mut result);
108        }
109
110        result
111    }
112
113    /// Compute coarse correction: W·E⁻¹·Wᴴ·b
114    pub fn coarse_correction(&self, b: &Array1<T>) -> Array1<T> {
115        let r = self.w_columns.len();
116        let n = b.len();
117        if r == 0 {
118            return Array1::from_elem(n, T::zero());
119        }
120
121        // Compute W^H * b
122        let mut wh_b = Array1::from_elem(r, T::zero());
123        for i in 0..r {
124            wh_b[i] = inner_product(&self.w_columns[i], b);
125        }
126
127        // Solve E * y = W^H * b
128        let y = self
129            .e_lu
130            .solve(&wh_b)
131            .expect("Deflation matrix E should be non-singular");
132
133        // result = W * y
134        let mut result = Array1::from_elem(n, T::zero());
135        for i in 0..r {
136            axpy(y[i], &self.w_columns[i], &mut result);
137        }
138
139        result
140    }
141
142    /// Apply recovery operator: Q(v) = v - W·E⁻¹·Wᴴ·A·v
143    pub fn apply_recovery<A: LinearOperator<T>>(&self, v: &Array1<T>, operator: &A) -> Array1<T> {
144        let r = self.w_columns.len();
145        if r == 0 {
146            return v.clone();
147        }
148
149        // Compute A * v
150        let av = operator.apply(v);
151
152        // Compute W^H * A * v
153        let mut wh_av = Array1::from_elem(r, T::zero());
154        for i in 0..r {
155            wh_av[i] = inner_product(&self.w_columns[i], &av);
156        }
157
158        // Solve E * y = W^H * A * v
159        let y = self
160            .e_lu
161            .solve(&wh_av)
162            .expect("Deflation matrix E should be non-singular");
163
164        // result = v - W * y
165        let mut result = v.clone();
166        for i in 0..r {
167            axpy(-y[i], &self.w_columns[i], &mut result);
168        }
169
170        result
171    }
172}
173
174/// Solve Ax = b using deflated GMRES (no preconditioner)
175///
176/// Projects out the deflation subspace before each Arnoldi step.
177/// After convergence, applies recovery operator and adds coarse correction.
178pub fn gmres_deflated<T, A>(
179    operator: &A,
180    deflation: &DeflationSubspace<T>,
181    b: &Array1<T>,
182    x0: Option<&Array1<T>>,
183    config: &GmresConfig<T::Real>,
184) -> GmresSolution<T>
185where
186    T: ComplexField,
187    A: LinearOperator<T>,
188{
189    use crate::traits::IdentityPreconditioner;
190    let precond = IdentityPreconditioner;
191    gmres_deflated_preconditioned(operator, &precond, deflation, b, x0, config)
192}
193
194/// Solve Ax = b using deflated preconditioned GMRES
195///
196/// Combines left preconditioning M⁻¹ with deflation projection P:
197/// At each Arnoldi step: w = M⁻¹ · P(A · vⱼ)
198/// After convergence to x̂: x = Q(x̂) + coarse_correction(b)
199///
200/// Reference: Erlangga & Nabben (2008)
201pub fn gmres_deflated_preconditioned<T, A, P>(
202    operator: &A,
203    precond: &P,
204    deflation: &DeflationSubspace<T>,
205    b: &Array1<T>,
206    x0: Option<&Array1<T>>,
207    config: &GmresConfig<T::Real>,
208) -> GmresSolution<T>
209where
210    T: ComplexField,
211    A: LinearOperator<T>,
212    P: Preconditioner<T>,
213{
214    // If no deflation vectors, fall back to standard preconditioned GMRES
215    if deflation.num_vectors() == 0 {
216        return super::gmres::gmres_preconditioned_with_guess(operator, precond, b, x0, config);
217    }
218
219    let n = b.len();
220    let m = config.restart;
221
222    // Compute coarse correction: x_c = W·E⁻¹·Wᴴ·b
223    let x_c = deflation.coarse_correction(b);
224
225    // Initialize solution from initial guess or zero
226    let mut x_hat = match x0 {
227        Some(guess) => guess.clone(),
228        None => Array1::from_elem(n, T::zero()),
229    };
230
231    // Compute preconditioned deflated RHS norm for relative tolerance
232    let pb = precond.apply(&deflation.apply_left_projector(b));
233    let b_norm = vector_norm(&pb);
234    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
235    if b_norm < tol_threshold {
236        // Deflated RHS is essentially zero — coarse correction is the solution
237        return GmresSolution {
238            x: x_c,
239            iterations: 0,
240            restarts: 0,
241            residual: T::Real::zero(),
242            converged: true,
243            status: crate::traits::SolverStatus::Converged,
244        };
245    }
246
247    let mut total_iterations = 0;
248    let mut restarts = 0;
249
250    // Outer iteration (restarts)
251    for _outer in 0..config.max_iterations {
252        // Compute deflated preconditioned residual: M⁻¹ · P · (b - A · x̂)
253        let ax = operator.apply(&x_hat);
254        let residual: Array1<T> = b - &ax;
255        let deflated_residual = deflation.apply_left_projector(&residual);
256        let r = precond.apply(&deflated_residual);
257        let beta = vector_norm(&r);
258
259        let rel_residual = beta / b_norm;
260        if rel_residual < config.tolerance {
261            // Apply recovery and coarse correction
262            let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
263            return GmresSolution {
264                x,
265                iterations: total_iterations,
266                restarts,
267                residual: rel_residual,
268                converged: true,
269                status: crate::traits::SolverStatus::Converged,
270            };
271        }
272
273        // Initialize Krylov basis
274        let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
275        v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
276
277        // Upper Hessenberg matrix
278        let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
279
280        // Givens rotation coefficients
281        let mut cs: Vec<T> = Vec::with_capacity(m);
282        let mut sn: Vec<T> = Vec::with_capacity(m);
283
284        // Right-hand side of least squares problem
285        let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
286        g[0] = T::from_real(beta);
287
288        let mut inner_converged = false;
289
290        // Inner iteration (Arnoldi process with deflation)
291        for j in 0..m {
292            total_iterations += 1;
293
294            // Deflated preconditioned matvec: w = M⁻¹ · P · A · vⱼ
295            let av = operator.apply(&v[j]);
296            let pav = deflation.apply_left_projector(&av);
297            let mut w = precond.apply(&pav);
298
299            // Modified Gram-Schmidt orthogonalization
300            for i in 0..=j {
301                h[[i, j]] = inner_product(&v[i], &w);
302                let h_ij = h[[i, j]];
303                w = &w - &v[i].mapv(|vi| vi * h_ij);
304            }
305
306            let w_norm = vector_norm(&w);
307            h[[j + 1, j]] = T::from_real(w_norm);
308
309            // Check for breakdown
310            let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
311            if w_norm < breakdown_tol {
312                inner_converged = true;
313            } else {
314                v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
315            }
316
317            // Apply previous Givens rotations to new column of H
318            for i in 0..j {
319                let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
320                h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
321                h[[i, j]] = temp;
322            }
323
324            // Compute new Givens rotation
325            let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
326            cs.push(c);
327            sn.push(s);
328
329            // Apply Givens rotation to H and g
330            h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
331            h[[j + 1, j]] = T::zero();
332
333            let temp = c.conj() * g[j] + s.conj() * g[j + 1];
334            g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
335            g[j] = temp;
336
337            // Check convergence
338            let abs_residual = g[j + 1].norm();
339            let rel_residual = abs_residual / b_norm;
340            let abs_tol = T::Real::from_f64(1e-20).unwrap();
341
342            if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
343                log::info!(
344                    "Deflated GMRES iteration {} (restart {}): relative residual = {:.6e}",
345                    total_iterations,
346                    restarts,
347                    rel_residual.to_f64().unwrap_or(0.0)
348                );
349            }
350
351            if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
352                // Solve upper triangular system Hy = g
353                let y = solve_upper_triangular(&h, &g, j + 1);
354
355                // Update x̂ = x̂ + V * y
356                for (i, &yi) in y.iter().enumerate() {
357                    x_hat = &x_hat + &v[i].mapv(|vi| vi * yi);
358                }
359
360                // Apply recovery and coarse correction: x = Q(x̂) + x_c
361                let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
362
363                let status =
364                    if inner_converged && rel_residual >= config.tolerance && abs_residual >= abs_tol
365                    {
366                        crate::traits::SolverStatus::Breakdown
367                    } else {
368                        crate::traits::SolverStatus::Converged
369                    };
370
371                return GmresSolution {
372                    x,
373                    iterations: total_iterations,
374                    restarts,
375                    residual: rel_residual,
376                    converged: true,
377                    status,
378                };
379            }
380        }
381
382        // Maximum inner iterations reached — update x̂ and restart
383        let y = solve_upper_triangular(&h, &g, m);
384        for (i, &yi) in y.iter().enumerate() {
385            x_hat = &x_hat + &v[i].mapv(|vi| vi * yi);
386        }
387
388        restarts += 1;
389    }
390
391    // Did not converge — still apply recovery for best-effort solution
392    let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
393
394    // Compute final true residual
395    let ax = operator.apply(&x);
396    let r: Array1<T> = b - &ax;
397    let r_norm = vector_norm(&r);
398    let b_true_norm = vector_norm(b);
399    let rel_residual = if b_true_norm > T::Real::from_f64(1e-15).unwrap() {
400        r_norm / b_true_norm
401    } else {
402        r_norm
403    };
404
405    GmresSolution {
406        x,
407        iterations: total_iterations,
408        restarts,
409        residual: rel_residual,
410        converged: false,
411        status: crate::traits::SolverStatus::MaxIterationsReached,
412    }
413}
414
415/// Compute Givens rotation coefficients
416#[inline]
417fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
418    let tol = T::Real::from_f64(1e-30).unwrap();
419    if b.norm() < tol {
420        return (T::one(), T::zero());
421    }
422    if a.norm() < tol {
423        return (T::zero(), T::one());
424    }
425
426    let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
427    let c = a * T::from_real(T::Real::one() / r);
428    let s = b * T::from_real(T::Real::one() / r);
429
430    (c, s)
431}
432
433/// Solve upper triangular system Hy = g
434fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
435    let mut y = vec![T::zero(); k];
436    let tol = T::Real::from_f64(1e-30).unwrap();
437
438    for i in (0..k).rev() {
439        let mut sum = g[i];
440        for j in (i + 1)..k {
441            sum -= h[[i, j]] * y[j];
442        }
443        if h[[i, i]].norm() > tol {
444            y[i] = sum * h[[i, i]].inv();
445        }
446    }
447
448    y
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Purely real complex scalar.
    fn real(x: f64) -> Complex64 {
        Complex64::new(x, 0.0)
    }

    /// Euclidean norm of a complex vector.
    fn l2(v: &Array1<Complex64>) -> f64 {
        v.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt()
    }

    /// ‖A·x − b‖₂ for a complex system.
    fn residual_norm(
        a: &CsrMatrix<Complex64>,
        x: &Array1<Complex64>,
        b: &Array1<Complex64>,
    ) -> f64 {
        l2(&(&a.matvec(x) - b))
    }

    /// Standard basis vector e_k of length `n`.
    fn unit(n: usize, k: usize) -> Array1<Complex64> {
        let mut e = Array1::from_elem(n, Complex64::ZERO);
        e[k] = real(1.0);
        e
    }

    /// Diagonal operator whose first `n - 3` eigenvalues are spaced 0.01 apart
    /// around `cluster_center`, followed by three well-separated eigenvalues.
    fn build_clustered_diagonal(n: usize, cluster_center: f64) -> CsrMatrix<Complex64> {
        let mut dense = Array2::from_elem((n, n), Complex64::new(0.0, 0.0));
        for i in 0..n {
            dense[[i, i]] = if i < n - 3 {
                let offset = 0.01 * (i as f64 - (n as f64 - 3.0) / 2.0);
                real(cluster_center + offset)
            } else {
                real((i + 1) as f64 * 10.0)
            };
        }
        CsrMatrix::from_dense(&dense, 1e-15)
    }

    #[test]
    fn test_deflation_subspace_construction() {
        // diag(1, 2, 3): the standard basis vectors are eigenvectors.
        let mut dense = Array2::from_elem((3, 3), Complex64::ZERO);
        for (k, &lambda) in [1.0, 2.0, 3.0].iter().enumerate() {
            dense[[k, k]] = real(lambda);
        }
        let a = CsrMatrix::from_dense(&dense, 1e-15);

        let deflation = DeflationSubspace::new(vec![unit(3, 0)], &a).unwrap();
        assert_eq!(deflation.num_vectors(), 1);

        // The deflated eigenvector is annihilated by P.
        let pe1_norm = l2(&deflation.apply_left_projector(&unit(3, 0)));
        assert!(pe1_norm < 1e-12, "Deflated eigenvector should be zero");

        // A direction outside the deflation space survives.
        let pe2_norm = l2(&deflation.apply_left_projector(&unit(3, 1)));
        assert!(pe2_norm > 0.5, "Non-deflated vector should remain");
    }

    #[test]
    fn test_deflated_gmres_fewer_iterations() {
        let n = 20;
        let a = build_clustered_diagonal(n, 5.0);
        let b = Array1::from_iter((1..=n).map(|k| real(k as f64)));

        let config = GmresConfig {
            max_iterations: 200,
            restart: 30,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let sol_standard = super::super::gmres::gmres(&a, &b, &config);

        // Deflate with the exact eigenvectors of the clustered eigenvalues.
        let w_cols: Vec<Array1<Complex64>> = (0..n - 3).map(|k| unit(n, k)).collect();
        let deflation = DeflationSubspace::new(w_cols, &a).unwrap();
        let sol_deflated = gmres_deflated(&a, &deflation, &b, None, &config);

        assert!(sol_standard.converged, "Standard GMRES should converge");
        assert!(sol_deflated.converged, "Deflated GMRES should converge");

        assert!(
            sol_deflated.iterations <= sol_standard.iterations,
            "Deflated ({}) should use <= iterations than standard ({})",
            sol_deflated.iterations,
            sol_standard.iterations
        );

        let error = residual_norm(&a, &sol_deflated.x, &b);
        assert!(error < 1e-8, "Deflated solution should satisfy Ax = b");
    }

    #[test]
    fn test_deflated_gmres_with_preconditioner() {
        let dense = array![[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 5.0]].mapv(real);
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0, 2.0, 3.0].mapv(real);

        // Jacobi preconditioner
        let precond = crate::preconditioners::DiagonalPreconditioner::from_csr(&a);

        // Single normalized approximate eigenvector
        let w1 = array![0.5, 0.7, 0.5].mapv(real);
        let scale = real(1.0 / vector_norm(&w1));
        let w1_unit = w1.mapv(|v| v * scale);
        let deflation = DeflationSubspace::new(vec![w1_unit], &a).unwrap();

        let config = GmresConfig {
            max_iterations: 50,
            restart: 10,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let sol = gmres_deflated_preconditioned(&a, &precond, &deflation, &b, None, &config);
        assert!(sol.converged, "Deflated preconditioned GMRES should converge");

        let error = residual_norm(&a, &sol.x, &b);
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_deflated_gmres_zero_vectors_fallback() {
        // An empty deflation space must reduce to standard GMRES.
        let dense = array![[4.0, 1.0], [1.0, 3.0]].mapv(real);
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0, 2.0].mapv(real);

        let deflation = DeflationSubspace::new(Vec::new(), &a).unwrap();
        assert_eq!(deflation.num_vectors(), 0);

        let config = GmresConfig {
            max_iterations: 50,
            restart: 10,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let sol = gmres_deflated(&a, &deflation, &b, None, &config);
        assert!(sol.converged, "Zero-vector deflated GMRES should converge");

        let sol_standard = super::super::gmres::gmres(&a, &b, &config);
        let error = l2(&(&sol.x - &sol_standard.x));
        assert!(
            error < 1e-8,
            "Zero-deflation should match standard GMRES solution"
        );
    }

    #[test]
    fn test_coarse_correction_exact_for_deflation_space() {
        // W spans the whole space, so the coarse correction alone solves A·x = b.
        let dense = array![[2.0, 0.0], [0.0, 3.0]].mapv(real);
        let a = CsrMatrix::from_dense(&dense, 1e-15);

        let deflation = DeflationSubspace::new(vec![unit(2, 0), unit(2, 1)], &a).unwrap();

        let b = array![4.0, 9.0].mapv(real);
        let x_c = deflation.coarse_correction(&b);

        // x_c should be the exact solution A⁻¹b = [2, 3]
        let error = residual_norm(&a, &x_c, &b);
        assert!(error < 1e-10, "Coarse correction should be exact solution");
    }

    #[test]
    fn test_deflated_gmres_f64() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        // One normalized approximate deflation vector
        let w1 = array![0.7071_f64, 0.7071];
        let w1_norm = vector_norm(&w1);
        let deflation = DeflationSubspace::new(vec![w1.mapv(|v| v / w1_norm)], &a).unwrap();

        let config = GmresConfig {
            max_iterations: 50,
            restart: 10,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let sol = gmres_deflated(&a, &deflation, &b, None, &config);
        assert!(sol.converged);

        let residual = &a.matvec(&sol.x) - &b;
        let error: f64 = residual.iter().map(|e| e * e).sum::<f64>().sqrt();
        assert_relative_eq!(error, 0.0, epsilon = 1e-8);
    }
}