Skip to main content

oxicuda_solver/sparse/
gmres.rs

1//! GMRES(m) (Generalized Minimal Residual with restart) iterative solver.
2//!
3//! Solves the linear system `A * x = b` for general matrices using the
4//! GMRES algorithm with periodic restarts after `m` iterations.
5//!
6//! # Algorithm
7//!
8//! GMRES builds an orthonormal basis for the Krylov subspace
9//! `K_m = span{r, A*r, A^2*r, ..., A^{m-1}*r}` via the Arnoldi process,
10//! then solves a small least squares problem on the resulting upper
11//! Hessenberg matrix using Givens rotations.
12//!
13//! After `m` iterations without convergence, the algorithm restarts with
14//! the current best solution as the new initial guess.
15
16#![allow(dead_code)]
17
18use oxicuda_blas::GpuFloat;
19
20use crate::error::{SolverError, SolverResult};
21use crate::handle::SolverHandle;
22
23// ---------------------------------------------------------------------------
24// GpuFloat <-> f64 conversion helpers
25// ---------------------------------------------------------------------------
26
27fn to_f64<T: GpuFloat>(val: T) -> f64 {
28    if T::SIZE == 4 {
29        f32::from_bits(val.to_bits_u64() as u32) as f64
30    } else {
31        f64::from_bits(val.to_bits_u64())
32    }
33}
34
35fn from_f64<T: GpuFloat>(val: f64) -> T {
36    if T::SIZE == 4 {
37        T::from_bits_u64(u64::from((val as f32).to_bits()))
38    } else {
39        T::from_bits_u64(val.to_bits())
40    }
41}
42
/// Restart length used by `GmresConfig::default()`.
const DEFAULT_RESTART: u32 = 30;

// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Tuning parameters for the GMRES(m) solver.
#[derive(Debug, Clone)]
pub struct GmresConfig {
    /// Upper bound on the iteration count, summed over every restart cycle.
    pub max_iter: u32,
    /// Stopping tolerance relative to `||b||`: the solve succeeds once the
    /// residual norm drops below `tol * ||b||`.
    pub tol: f64,
    /// Number of Arnoldi steps per cycle before the Krylov basis is discarded
    /// and the solver restarts; values larger than the system dimension are
    /// capped at that dimension.
    pub restart: u32,
}

impl Default for GmresConfig {
    /// `max_iter = 1000`, `tol = 1e-6`, `restart = DEFAULT_RESTART` (30).
    fn default() -> Self {
        let restart = DEFAULT_RESTART;
        let tol = 1e-6;
        let max_iter = 1000;
        Self {
            max_iter,
            tol,
            restart,
        }
    }
}
70
71// ---------------------------------------------------------------------------
72// Public API
73// ---------------------------------------------------------------------------
74
75/// Solves `A * x = b` using GMRES(m) with restart.
76///
77/// On entry, `x` should contain an initial guess. On exit, `x` contains
78/// the approximate solution.
79///
80/// # Arguments
81///
82/// * `_handle` — solver handle (reserved for future GPU-accelerated variants).
83/// * `spmv` — closure computing `y = A * x`: `spmv(x, y)`.
84/// * `b` — right-hand side vector (length n).
85/// * `x` — initial guess / solution vector (length n), modified in-place.
86/// * `n` — system dimension.
87/// * `config` — solver configuration.
88///
89/// # Returns
90///
91/// The total number of matrix-vector products performed.
92///
93/// # Errors
94///
95/// Returns [`SolverError::ConvergenceFailure`] if the solver does not converge.
96pub fn gmres_solve<T, F>(
97    _handle: &SolverHandle,
98    spmv: F,
99    b: &[T],
100    x: &mut [T],
101    n: u32,
102    config: &GmresConfig,
103) -> SolverResult<u32>
104where
105    T: GpuFloat,
106    F: Fn(&[T], &mut [T]) -> SolverResult<()>,
107{
108    let n_usize = n as usize;
109
110    // Validate dimensions.
111    if b.len() < n_usize {
112        return Err(SolverError::DimensionMismatch(format!(
113            "gmres_solve: b length ({}) < n ({n})",
114            b.len()
115        )));
116    }
117    if x.len() < n_usize {
118        return Err(SolverError::DimensionMismatch(format!(
119            "gmres_solve: x length ({}) < n ({n})",
120            x.len()
121        )));
122    }
123    if n == 0 {
124        return Ok(0);
125    }
126
127    let b_norm = vec_norm(b, n_usize);
128    let abs_tol = if b_norm > 0.0 {
129        config.tol * b_norm
130    } else {
131        for xi in x.iter_mut().take(n_usize) {
132            *xi = T::gpu_zero();
133        }
134        return Ok(0);
135    };
136
137    let m = config.restart.min(n) as usize;
138    let mut total_iters = 0_u32;
139
140    // Outer restart loop.
141    while total_iters < config.max_iter {
142        let iters = gmres_cycle(
143            &spmv,
144            b,
145            x,
146            n_usize,
147            m,
148            abs_tol,
149            config.max_iter - total_iters,
150        )?;
151        total_iters += iters;
152
153        // Check if we converged in this cycle.
154        let mut r = vec![T::gpu_zero(); n_usize];
155        let mut ax = vec![T::gpu_zero(); n_usize];
156        spmv(x, &mut ax)?;
157        for i in 0..n_usize {
158            r[i] = sub_t(b[i], ax[i]);
159        }
160        total_iters += 1; // Count the residual check spmv.
161
162        let r_norm = vec_norm(&r, n_usize);
163        if r_norm < abs_tol {
164            return Ok(total_iters);
165        }
166
167        if iters == 0 {
168            break; // No progress in this cycle.
169        }
170    }
171
172    // Compute final residual for error reporting.
173    let mut r = vec![T::gpu_zero(); n_usize];
174    let mut ax = vec![T::gpu_zero(); n_usize];
175    spmv(x, &mut ax)?;
176    for i in 0..n_usize {
177        r[i] = sub_t(b[i], ax[i]);
178    }
179    let r_norm = vec_norm(&r, n_usize);
180
181    if r_norm < abs_tol {
182        Ok(total_iters)
183    } else {
184        Err(SolverError::ConvergenceFailure {
185            iterations: total_iters,
186            residual: r_norm,
187        })
188    }
189}
190
191// ---------------------------------------------------------------------------
192// GMRES cycle (one restart)
193// ---------------------------------------------------------------------------
194
195/// One GMRES cycle: runs up to `m` Arnoldi steps, solves the Hessenberg
196/// least squares problem, and updates `x`.
197///
198/// Returns the number of matrix-vector products performed in this cycle.
199fn gmres_cycle<T, F>(
200    spmv: &F,
201    b: &[T],
202    x: &mut [T],
203    n: usize,
204    m: usize,
205    abs_tol: f64,
206    max_iters: u32,
207) -> SolverResult<u32>
208where
209    T: GpuFloat,
210    F: Fn(&[T], &mut [T]) -> SolverResult<()>,
211{
212    // Compute initial residual r = b - A*x.
213    let mut r = vec![T::gpu_zero(); n];
214    let mut ax = vec![T::gpu_zero(); n];
215    spmv(x, &mut ax)?;
216    for i in 0..n {
217        r[i] = sub_t(b[i], ax[i]);
218    }
219    let beta = vec_norm(&r, n);
220
221    if beta < abs_tol {
222        return Ok(0);
223    }
224
225    // Arnoldi basis vectors: V = [v_0, v_1, ..., v_m] where each v_i is length n.
226    let mut v_basis: Vec<Vec<T>> = Vec::with_capacity(m + 1);
227
228    // v_0 = r / beta
229    let inv_beta = from_f64(1.0 / beta);
230    let v0: Vec<T> = r.iter().map(|&ri| mul_t(ri, inv_beta)).collect();
231    v_basis.push(v0);
232
233    // Upper Hessenberg matrix H (m+1 x m), stored column-major as Vec<Vec<f64>>.
234    let mut h = vec![vec![0.0_f64; m + 1]; m];
235
236    // Givens rotation parameters.
237    let mut cs = vec![0.0_f64; m];
238    let mut sn = vec![0.0_f64; m];
239
240    // Right-hand side for the Hessenberg least squares: g = beta * e_1.
241    let mut g = vec![0.0_f64; m + 1];
242    g[0] = beta;
243
244    let mut j = 0;
245    let max_j = m.min(max_iters as usize);
246
247    while j < max_j {
248        // Arnoldi step: w = A * v_j.
249        let mut w = vec![T::gpu_zero(); n];
250        spmv(&v_basis[j], &mut w)?;
251
252        // Modified Gram-Schmidt orthogonalization.
253        for i in 0..=j {
254            h[j][i] = dot_product(&v_basis[i], &w, n);
255            let h_ij_t = from_f64(h[j][i]);
256            for k in 0..n {
257                w[k] = sub_t(w[k], mul_t(h_ij_t, v_basis[i][k]));
258            }
259        }
260
261        let w_norm = vec_norm(&w, n);
262        h[j][j + 1] = w_norm;
263
264        // Normalize w to get v_{j+1}.
265        if w_norm > 1e-300 {
266            let inv_w = from_f64(1.0 / w_norm);
267            let vj1: Vec<T> = w.iter().map(|&wi| mul_t(wi, inv_w)).collect();
268            v_basis.push(vj1);
269        } else {
270            // Lucky breakdown: w is in the span of existing basis.
271            let vj1 = vec![T::gpu_zero(); n];
272            v_basis.push(vj1);
273        }
274
275        // Apply previous Givens rotations to the new column of H.
276        for i in 0..j {
277            let tmp = cs[i] * h[j][i] + sn[i] * h[j][i + 1];
278            h[j][i + 1] = -sn[i] * h[j][i] + cs[i] * h[j][i + 1];
279            h[j][i] = tmp;
280        }
281
282        // Compute new Givens rotation for row (j, j+1).
283        let (c, s) = givens_rotation(h[j][j], h[j][j + 1]);
284        cs[j] = c;
285        sn[j] = s;
286
287        // Apply to H.
288        h[j][j] = c * h[j][j] + s * h[j][j + 1];
289        h[j][j + 1] = 0.0;
290
291        // Apply to g.
292        let tmp = cs[j] * g[j] + sn[j] * g[j + 1];
293        g[j + 1] = -sn[j] * g[j] + cs[j] * g[j + 1];
294        g[j] = tmp;
295
296        j += 1;
297
298        // Check convergence: |g[j]| is the residual norm.
299        if g[j].abs() < abs_tol {
300            break;
301        }
302    }
303
304    // Solve the upper triangular system H[0:j, 0:j] * y = g[0:j].
305    let mut y = vec![0.0_f64; j];
306    for i in (0..j).rev() {
307        y[i] = g[i];
308        for k in (i + 1)..j {
309            y[i] -= h[k][i] * y[k];
310        }
311        if h[i][i].abs() > 1e-300 {
312            y[i] /= h[i][i];
313        }
314    }
315
316    // Update x: x += V * y.
317    for i in 0..j {
318        let yi_t = from_f64(y[i]);
319        for k in 0..n {
320            x[k] = add_t(x[k], mul_t(yi_t, v_basis[i][k]));
321        }
322    }
323
324    Ok(j as u32)
325}
326
327// ---------------------------------------------------------------------------
328// Helpers
329// ---------------------------------------------------------------------------
330
/// Computes a Givens rotation `(c, s)` with `c*a + s*b = r` and
/// `-s*a + c*b = 0`, i.e. the rotation that zeroes `b` against `a`.
///
/// * `|b| < 1e-300` returns the identity `(1, 0)`, leaving `r = a`.
/// * `|a| < 1e-300` returns `(0, ±1)` with the sign of `b`, giving `r = |b|`.
/// * Otherwise `r = hypot(a, b)`. `f64::hypot` avoids the intermediate
///   overflow/underflow of `sqrt(a*a + b*b)`. That form returned `(0, 0)` for
///   |a|, |b| ~ 1e200, and divided by zero for |a|, |b| ~ 1e-200.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b.abs() < 1e-300 {
        return (1.0, 0.0);
    }
    if a.abs() < 1e-300 {
        return (0.0, if b >= 0.0 { 1.0 } else { -1.0 });
    }
    let r = a.hypot(b);
    (a / r, b / r)
}
341
342fn dot_product<T: GpuFloat>(a: &[T], b: &[T], n: usize) -> f64 {
343    let mut sum = 0.0_f64;
344    for i in 0..n {
345        sum += to_f64(a[i]) * to_f64(b[i]);
346    }
347    sum
348}
349
350fn vec_norm<T: GpuFloat>(v: &[T], n: usize) -> f64 {
351    dot_product(v, v, n).sqrt()
352}
353
354fn add_t<T: GpuFloat>(a: T, b: T) -> T {
355    from_f64(to_f64(a) + to_f64(b))
356}
357
358fn sub_t<T: GpuFloat>(a: T, b: T) -> T {
359    from_f64(to_f64(a) - to_f64(b))
360}
361
362fn mul_t<T: GpuFloat>(a: T, b: T) -> T {
363    from_f64(to_f64(a) * to_f64(b))
364}
365
366// ---------------------------------------------------------------------------
367// Tests
368// ---------------------------------------------------------------------------
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Host-only copy of the `gmres_solve` driver loop used by these tests.
    ///
    /// It takes no `SolverHandle`, so the solver can be exercised with a
    /// plain closure-based matrix-vector product and no GPU. It duplicates the
    /// driver logic rather than calling `gmres_solve`, so keep the two in sync
    /// whenever either one changes.
    fn gmres_solve_cpu<T, F>(
        spmv: F,
        b: &[T],
        x: &mut [T],
        n: u32,
        config: &GmresConfig,
    ) -> SolverResult<u32>
    where
        T: GpuFloat,
        F: Fn(&[T], &mut [T]) -> SolverResult<()>,
    {
        let n_usize = n as usize;

        if b.len() < n_usize {
            return Err(SolverError::DimensionMismatch(format!(
                "gmres_solve_cpu: b length ({}) < n ({n})",
                b.len()
            )));
        }
        if x.len() < n_usize {
            return Err(SolverError::DimensionMismatch(format!(
                "gmres_solve_cpu: x length ({}) < n ({n})",
                x.len()
            )));
        }
        if n == 0 {
            return Ok(0);
        }

        let b_norm = vec_norm(b, n_usize);
        let abs_tol = if b_norm > 0.0 {
            config.tol * b_norm
        } else {
            for xi in x.iter_mut().take(n_usize) {
                *xi = T::gpu_zero();
            }
            return Ok(0);
        };

        let m = config.restart.min(n) as usize;
        let mut total_iters = 0_u32;

        while total_iters < config.max_iter {
            let iters = gmres_cycle(
                &spmv,
                b,
                x,
                n_usize,
                m,
                abs_tol,
                config.max_iter - total_iters,
            )?;
            total_iters += iters;

            let mut r = vec![T::gpu_zero(); n_usize];
            let mut ax = vec![T::gpu_zero(); n_usize];
            spmv(x, &mut ax)?;
            for i in 0..n_usize {
                r[i] = sub_t(b[i], ax[i]);
            }
            total_iters += 1;

            let r_norm = vec_norm(&r, n_usize);
            if r_norm < abs_tol {
                return Ok(total_iters);
            }

            if iters == 0 {
                break;
            }
        }

        let mut r = vec![T::gpu_zero(); n_usize];
        let mut ax = vec![T::gpu_zero(); n_usize];
        spmv(x, &mut ax)?;
        for i in 0..n_usize {
            r[i] = sub_t(b[i], ax[i]);
        }
        let r_norm = vec_norm(&r, n_usize);

        if r_norm < abs_tol {
            Ok(total_iters)
        } else {
            Err(SolverError::ConvergenceFailure {
                iterations: total_iters,
                residual: r_norm,
            })
        }
    }

    #[test]
    fn gmres_config_default() {
        let cfg = GmresConfig::default();
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-15);
        assert_eq!(cfg.restart, DEFAULT_RESTART);
    }

    #[test]
    fn gmres_config_custom() {
        let cfg = GmresConfig {
            max_iter: 500,
            tol: 1e-10,
            restart: 50,
        };
        assert_eq!(cfg.restart, 50);
    }

    #[test]
    fn givens_rotation_basic() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
    }

    #[test]
    fn givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    /// With `a == 0` the rotation is a pure swap carrying the sign of `b`.
    #[test]
    fn givens_rotation_zero_a() {
        let (cs, sn) = givens_rotation(0.0, -2.0);
        assert!(cs.abs() < 1e-15);
        assert!((sn + 1.0).abs() < 1e-15);
    }

    #[test]
    fn dot_product_basic() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0, 6.0];
        assert!((dot_product(&a, &b, 3) - 32.0).abs() < 1e-10);
    }

    #[test]
    fn vec_norm_unit() {
        let v = [1.0_f64, 0.0, 0.0];
        assert!((vec_norm(&v, 3) - 1.0).abs() < 1e-15);
    }

    /// GMRES converges on a 3×3 identity matrix in a single Arnoldi step.
    ///
    /// A = I, b = [3, 7, -2] → exact solution x = [3, 7, -2].
    /// The identity matrix has a single eigenvalue λ=1, so GMRES minimises
    /// the residual in exactly one step (b is an eigenvector of A).
    #[test]
    fn gmres_converges_identity_3x3() {
        let b = vec![3.0_f64, 7.0, -2.0];
        let mut x = vec![0.0_f64; 3];
        let config = GmresConfig {
            max_iter: 50,
            tol: 1e-10,
            restart: 10,
        };

        // A = I
        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out.copy_from_slice(v);
            Ok(())
        };

        let _iters = gmres_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("GMRES should converge on identity system");

        assert!((x[0] - 3.0).abs() < 1e-8, "x[0] = {} expected 3.0", x[0]);
        assert!((x[1] - 7.0).abs() < 1e-8, "x[1] = {} expected 7.0", x[1]);
        assert!(
            (x[2] - (-2.0)).abs() < 1e-8,
            "x[2] = {} expected -2.0",
            x[2]
        );
    }

    /// GMRES converges on a 4×4 tridiagonal SPD system in ≤ N steps.
    ///
    /// A = tridiag(-1, 2, -1), b = [1, 1, 1, 1], exact x = [2, 3, 3, 2].
    #[test]
    fn gmres_converges_tridiagonal_4x4() {
        let b = vec![1.0_f64, 1.0, 1.0, 1.0];
        let mut x = vec![0.0_f64; 4];
        let config = GmresConfig {
            max_iter: 200,
            tol: 1e-10,
            restart: 10,
        };

        // A = tridiag(-1, 2, -1), 4×4
        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out[0] = 2.0 * v[0] - v[1];
            out[1] = -v[0] + 2.0 * v[1] - v[2];
            out[2] = -v[1] + 2.0 * v[2] - v[3];
            out[3] = -v[2] + 2.0 * v[3];
            Ok(())
        };

        let _iters = gmres_solve_cpu(spmv, &b, &mut x, 4, &config)
            .expect("GMRES should converge on tridiagonal system");

        assert!((x[0] - 2.0).abs() < 1e-7, "x[0] = {} expected 2.0", x[0]);
        assert!((x[1] - 3.0).abs() < 1e-7, "x[1] = {} expected 3.0", x[1]);
        assert!((x[2] - 3.0).abs() < 1e-7, "x[2] = {} expected 3.0", x[2]);
        assert!((x[3] - 2.0).abs() < 1e-7, "x[3] = {} expected 2.0", x[3]);
    }

    /// Restarted GMRES(2) still converges on the 4×4 tridiagonal system,
    /// exercising the outer restart loop (restart < n).
    #[test]
    fn gmres_converges_with_restarts() {
        let b = vec![1.0_f64, 1.0, 1.0, 1.0];
        let mut x = vec![0.0_f64; 4];
        let config = GmresConfig {
            max_iter: 1000,
            tol: 1e-10,
            restart: 2,
        };

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out[0] = 2.0 * v[0] - v[1];
            out[1] = -v[0] + 2.0 * v[1] - v[2];
            out[2] = -v[1] + 2.0 * v[2] - v[3];
            out[3] = -v[2] + 2.0 * v[3];
            Ok(())
        };

        let _iters = gmres_solve_cpu(spmv, &b, &mut x, 4, &config)
            .expect("GMRES(2) should converge on tridiagonal SPD system");

        let expected = [2.0_f64, 3.0, 3.0, 2.0];
        for (i, (&xi, &ei)) in x.iter().zip(&expected).enumerate() {
            assert!((xi - ei).abs() < 1e-7, "x[{i}] = {xi} expected {ei}");
        }
    }

    /// GMRES converges on a non-symmetric system.
    ///
    /// A = [[4, 1, 0], [2, 5, 1], [0, 3, 6]], exact x = [1, 2, 3],
    /// so b = A * x = [6, 15, 24].
    #[test]
    fn gmres_converges_nonsymmetric_3x3() {
        let b = vec![6.0_f64, 15.0, 24.0];
        let mut x = vec![0.0_f64; 3];
        let config = GmresConfig {
            max_iter: 100,
            tol: 1e-10,
            restart: 10,
        };

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out[0] = 4.0 * v[0] + v[1];
            out[1] = 2.0 * v[0] + 5.0 * v[1] + v[2];
            out[2] = 3.0 * v[1] + 6.0 * v[2];
            Ok(())
        };

        let _iters = gmres_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("GMRES should converge on non-symmetric system");

        assert!((x[0] - 1.0).abs() < 1e-8, "x[0] = {} expected 1.0", x[0]);
        assert!((x[1] - 2.0).abs() < 1e-8, "x[1] = {} expected 2.0", x[1]);
        assert!((x[2] - 3.0).abs() < 1e-8, "x[2] = {} expected 3.0", x[2]);
    }

    /// GMRES with zero RHS returns immediately without iterating.
    #[test]
    fn gmres_zero_rhs_returns_zero() {
        let b = vec![0.0_f64; 3];
        let mut x = vec![1.0_f64; 3]; // non-zero initial guess
        let config = GmresConfig::default();

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out.copy_from_slice(v);
            Ok(())
        };

        let iters =
            gmres_solve_cpu(spmv, &b, &mut x, 3, &config).expect("zero RHS should succeed");
        assert_eq!(iters, 0);
        for &xi in &x {
            assert!(xi.abs() < 1e-15, "x should be zeroed for zero RHS");
        }
    }

    /// GMRES returns DimensionMismatch when b is shorter than n.
    #[test]
    fn gmres_dimension_mismatch() {
        let b = vec![1.0_f64]; // length 1, n = 3
        let mut x = vec![0.0_f64; 3];
        let config = GmresConfig::default();
        let spmv = |_: &[f64], _: &mut [f64]| -> SolverResult<()> { Ok(()) };
        let result = gmres_solve_cpu(spmv, &b, &mut x, 3, &config);
        assert!(matches!(result, Err(SolverError::DimensionMismatch(_))));
    }

    /// GMRES reports ConvergenceFailure when the iteration budget is too small.
    ///
    /// One GMRES(1) step cannot solve diag(1, 4, 9) x = [1, 4, 9] exactly
    /// because b is not an eigenvector of A.
    #[test]
    fn gmres_reports_failure_when_budget_exhausted() {
        let b = vec![1.0_f64, 4.0, 9.0];
        let mut x = vec![0.0_f64; 3];
        let config = GmresConfig {
            max_iter: 1,
            tol: 1e-12,
            restart: 1,
        };

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out[0] = 1.0 * v[0];
            out[1] = 4.0 * v[1];
            out[2] = 9.0 * v[2];
            Ok(())
        };

        let result = gmres_solve_cpu(spmv, &b, &mut x, 3, &config);
        assert!(matches!(
            result,
            Err(SolverError::ConvergenceFailure { .. })
        ));
    }

    /// GMRES converges on a diagonal SPD system in at most N Arnoldi steps.
    ///
    /// A = diag(1, 4, 9), b = [1, 4, 9] → exact x = [1, 1, 1].
    #[test]
    fn gmres_converges_diagonal_spd() {
        let b = vec![1.0_f64, 4.0, 9.0];
        let mut x = vec![0.0_f64; 3];
        let config = GmresConfig {
            max_iter: 100,
            tol: 1e-10,
            restart: 10,
        };

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out[0] = 1.0 * v[0];
            out[1] = 4.0 * v[1];
            out[2] = 9.0 * v[2];
            Ok(())
        };

        let _iters = gmres_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("GMRES should converge on diagonal SPD");

        assert!((x[0] - 1.0).abs() < 1e-8, "x[0] = {} expected 1.0", x[0]);
        assert!((x[1] - 1.0).abs() < 1e-8, "x[1] = {} expected 1.0", x[1]);
        assert!((x[2] - 1.0).abs() < 1e-8, "x[2] = {} expected 1.0", x[2]);
    }
}