Skip to main content

oxicuda_solver/sparse/
fgmres.rs

1//! Flexible GMRES (FGMRES) iterative solver.
2//!
3//! Flexible GMRES allows a different (possibly nonlinear or variable)
4//! preconditioner at each iteration. Unlike standard right-preconditioned
5//! GMRES which stores only the Krylov basis vectors V, FGMRES also stores
6//! the preconditioned vectors Z_j = M_j^{-1} * v_j separately. The final
7//! solution update uses the Z vectors rather than V.
8//!
9//! # Algorithm
10//!
11//! 1. Compute r_0 = b - A*x_0, beta = ||r_0||, v_1 = r_0 / beta.
12//! 2. For j = 1, 2, ..., m:
13//!    a. z_j = M_j^{-1} * v_j  (preconditioner may vary per iteration)
14//!    b. w = A * z_j
15//!    c. Modified Gram-Schmidt: orthogonalize w against v_1, ..., v_j
16//!    d. h_{j+1,j} = ||w||; v_{j+1} = w / h_{j+1,j}
17//!    e. Apply previous Givens rotations; compute new rotation
18//!    f. Check convergence
19//! 3. Solve the Hessenberg least squares: H * y = g
20//! 4. Update: x = x_0 + Z * y  (NOT x_0 + V * y)
21//!
22//! The key difference from standard GMRES is step 4: Z vectors are used
23//! instead of V vectors.
24
25#![allow(dead_code)]
26
27use oxicuda_blas::GpuFloat;
28
29use crate::error::{SolverError, SolverResult};
30use crate::handle::SolverHandle;
31use crate::sparse::preconditioned::{IterativeSolverResult, Preconditioner};
32
33// ---------------------------------------------------------------------------
34// GpuFloat <-> f64 conversion helpers
35// ---------------------------------------------------------------------------
36
37fn to_f64<T: GpuFloat>(val: T) -> f64 {
38    if T::SIZE == 4 {
39        f32::from_bits(val.to_bits_u64() as u32) as f64
40    } else {
41        f64::from_bits(val.to_bits_u64())
42    }
43}
44
45fn from_f64<T: GpuFloat>(val: f64) -> T {
46    if T::SIZE == 4 {
47        T::from_bits_u64(u64::from((val as f32).to_bits()))
48    } else {
49        T::from_bits_u64(val.to_bits())
50    }
51}
52
53// ---------------------------------------------------------------------------
54// Configuration
55// ---------------------------------------------------------------------------
56
/// Tuning parameters for the flexible GMRES solver.
#[derive(Debug, Clone)]
pub struct FgmresConfig {
    /// Krylov subspace dimension per cycle: the solver restarts after this
    /// many Arnoldi steps. The solver also caps it at the problem dimension.
    pub restart: usize,
    /// Upper bound on the number of Arnoldi steps, summed over all restart cycles.
    pub max_iter: usize,
    /// Relative tolerance: the solve counts as converged once
    /// `||b - A*x|| < tol * ||b||`.
    pub tol: f64,
}

impl Default for FgmresConfig {
    /// Returns `restart = 30`, `max_iter = 1000`, `tol = 1e-6`.
    fn default() -> Self {
        FgmresConfig {
            tol: 1e-6,
            restart: 30,
            max_iter: 1000,
        }
    }
}
77
78// ---------------------------------------------------------------------------
79// Public API
80// ---------------------------------------------------------------------------
81
82/// Solves `A * x = b` using Flexible GMRES with variable preconditioner.
83///
84/// On entry, `x` should contain an initial guess. On exit, `x` contains
85/// the approximate solution.
86///
87/// # Type Parameters
88///
89/// * `T` — floating-point type (f32 or f64).
90/// * `P` — preconditioner type (may vary internally per iteration).
91/// * `F` — closure computing `y = A * x`.
92///
93/// # Arguments
94///
95/// * `_handle` — solver handle (reserved for future GPU variants).
96/// * `spmv` — closure computing `y = A * x`: `spmv(x, y)`.
97/// * `precond` — preconditioner implementing [`Preconditioner`].
98/// * `b` — right-hand side vector (length n).
99/// * `x` — initial guess / solution vector (length n), modified in-place.
100/// * `config` — FGMRES configuration.
101///
102/// # Returns
103///
104/// An [`IterativeSolverResult`] with iteration count, residual, and convergence.
105///
106/// # Errors
107///
108/// Returns [`SolverError::DimensionMismatch`] for invalid dimensions.
109pub fn fgmres<T, P, F>(
110    _handle: &SolverHandle,
111    spmv: F,
112    precond: &P,
113    b: &[T],
114    x: &mut [T],
115    config: &FgmresConfig,
116) -> SolverResult<IterativeSolverResult<T>>
117where
118    T: GpuFloat,
119    P: Preconditioner<T>,
120    F: Fn(&[T], &mut [T]) -> SolverResult<()>,
121{
122    let n = b.len();
123    if x.len() < n {
124        return Err(SolverError::DimensionMismatch(format!(
125            "fgmres: x length ({}) < b length ({n})",
126            x.len()
127        )));
128    }
129    if n == 0 {
130        return Ok(IterativeSolverResult {
131            iterations: 0,
132            residual: T::gpu_zero(),
133            converged: true,
134        });
135    }
136
137    let b_norm = vec_norm(b, n);
138    let abs_tol = if b_norm > 0.0 {
139        config.tol * b_norm
140    } else {
141        for xi in x.iter_mut().take(n) {
142            *xi = T::gpu_zero();
143        }
144        return Ok(IterativeSolverResult {
145            iterations: 0,
146            residual: T::gpu_zero(),
147            converged: true,
148        });
149    };
150
151    let m = config.restart.min(n);
152    let mut total_iters = 0_u32;
153
154    // Outer restart loop.
155    while (total_iters as usize) < config.max_iter {
156        let remaining = config.max_iter.saturating_sub(total_iters as usize);
157        let (iters, converged, res_norm) =
158            fgmres_cycle(&spmv, precond, b, x, n, m, abs_tol, remaining)?;
159        total_iters += iters;
160
161        if converged {
162            return Ok(IterativeSolverResult {
163                iterations: total_iters,
164                residual: from_f64(res_norm),
165                converged: true,
166            });
167        }
168
169        if iters == 0 {
170            break; // No progress.
171        }
172    }
173
174    // Compute final residual.
175    let mut r = vec![T::gpu_zero(); n];
176    let mut ax = vec![T::gpu_zero(); n];
177    spmv(x, &mut ax)?;
178    for i in 0..n {
179        r[i] = sub_t(b[i], ax[i]);
180    }
181    let r_norm = vec_norm(&r, n);
182
183    Ok(IterativeSolverResult {
184        iterations: total_iters,
185        residual: from_f64(r_norm),
186        converged: r_norm < abs_tol,
187    })
188}
189
190// ---------------------------------------------------------------------------
191// FGMRES cycle
192// ---------------------------------------------------------------------------
193
194/// One FGMRES cycle: runs up to `m` Arnoldi steps with flexible preconditioning,
195/// solves the Hessenberg least squares, and updates `x` using Z vectors.
196///
197/// Returns `(iters, converged, residual_norm)`.
198#[allow(clippy::too_many_arguments)]
199fn fgmres_cycle<T, P, F>(
200    spmv: &F,
201    precond: &P,
202    b: &[T],
203    x: &mut [T],
204    n: usize,
205    m: usize,
206    abs_tol: f64,
207    max_iters: usize,
208) -> SolverResult<(u32, bool, f64)>
209where
210    T: GpuFloat,
211    P: Preconditioner<T>,
212    F: Fn(&[T], &mut [T]) -> SolverResult<()>,
213{
214    // Compute initial residual r = b - A*x.
215    let mut r = vec![T::gpu_zero(); n];
216    let mut ax = vec![T::gpu_zero(); n];
217    spmv(x, &mut ax)?;
218    for i in 0..n {
219        r[i] = sub_t(b[i], ax[i]);
220    }
221    let beta = vec_norm(&r, n);
222
223    if beta < abs_tol {
224        return Ok((0, true, beta));
225    }
226
227    // V basis vectors (Krylov space).
228    let mut v_basis: Vec<Vec<T>> = Vec::with_capacity(m + 1);
229    // Z vectors (preconditioned basis) — this is the key FGMRES difference.
230    let mut z_basis: Vec<Vec<T>> = Vec::with_capacity(m);
231
232    // v_0 = r / beta
233    let inv_beta = from_f64(1.0 / beta);
234    let v0: Vec<T> = r.iter().map(|&ri| mul_t(ri, inv_beta)).collect();
235    v_basis.push(v0);
236
237    // Upper Hessenberg matrix H (m+1 x m), stored column-major as Vec<Vec<f64>>.
238    let mut h = vec![vec![0.0_f64; m + 1]; m];
239
240    // Givens rotation parameters.
241    let mut cs = vec![0.0_f64; m];
242    let mut sn = vec![0.0_f64; m];
243
244    // Right-hand side for Hessenberg least squares: g = beta * e_1.
245    let mut g = vec![0.0_f64; m + 1];
246    g[0] = beta;
247
248    let mut j = 0;
249    let max_j = m.min(max_iters);
250    let mut converged = false;
251
252    while j < max_j {
253        // FGMRES step a: z_j = M_j^{-1} * v_j
254        let mut z_j = vec![T::gpu_zero(); n];
255        precond.apply(&v_basis[j], &mut z_j)?;
256        z_basis.push(z_j);
257
258        // FGMRES step b: w = A * z_j
259        let mut w = vec![T::gpu_zero(); n];
260        spmv(&z_basis[j], &mut w)?;
261
262        // FGMRES step c: Modified Gram-Schmidt orthogonalization.
263        for i in 0..=j {
264            h[j][i] = dot_product(&v_basis[i], &w, n);
265            let h_ij_t = from_f64(h[j][i]);
266            for k in 0..n {
267                w[k] = sub_t(w[k], mul_t(h_ij_t, v_basis[i][k]));
268            }
269        }
270
271        // FGMRES step d: Normalize w to get v_{j+1}.
272        let w_norm = vec_norm(&w, n);
273        h[j][j + 1] = w_norm;
274
275        if w_norm > 1e-300 {
276            let inv_w = from_f64(1.0 / w_norm);
277            let vj1: Vec<T> = w.iter().map(|&wi| mul_t(wi, inv_w)).collect();
278            v_basis.push(vj1);
279        } else {
280            // Lucky breakdown.
281            let vj1 = vec![T::gpu_zero(); n];
282            v_basis.push(vj1);
283        }
284
285        // FGMRES step e: Apply previous Givens rotations to new column of H.
286        for i in 0..j {
287            let tmp = cs[i] * h[j][i] + sn[i] * h[j][i + 1];
288            h[j][i + 1] = -sn[i] * h[j][i] + cs[i] * h[j][i + 1];
289            h[j][i] = tmp;
290        }
291
292        // Compute new Givens rotation.
293        let (c, s) = givens_rotation(h[j][j], h[j][j + 1]);
294        cs[j] = c;
295        sn[j] = s;
296
297        // Apply to H.
298        h[j][j] = c * h[j][j] + s * h[j][j + 1];
299        h[j][j + 1] = 0.0;
300
301        // Apply to g.
302        let tmp = cs[j] * g[j] + sn[j] * g[j + 1];
303        g[j + 1] = -sn[j] * g[j] + cs[j] * g[j + 1];
304        g[j] = tmp;
305
306        j += 1;
307
308        // FGMRES step f: Check convergence.
309        if g[j].abs() < abs_tol {
310            converged = true;
311            break;
312        }
313    }
314
315    // Solve the upper triangular system H[0:j, 0:j] * y = g[0:j].
316    let mut y = vec![0.0_f64; j];
317    for i in (0..j).rev() {
318        y[i] = g[i];
319        for k in (i + 1)..j {
320            y[i] -= h[k][i] * y[k];
321        }
322        if h[i][i].abs() > 1e-300 {
323            y[i] /= h[i][i];
324        }
325    }
326
327    // FGMRES update: x += Z * y  (NOT V * y — this is the key difference).
328    for i in 0..j {
329        let yi_t = from_f64(y[i]);
330        for k in 0..n {
331            x[k] = add_t(x[k], mul_t(yi_t, z_basis[i][k]));
332        }
333    }
334
335    // Compute actual residual for reporting.
336    let mut r_final = vec![T::gpu_zero(); n];
337    let mut ax_final = vec![T::gpu_zero(); n];
338    spmv(x, &mut ax_final)?;
339    for i in 0..n {
340        r_final[i] = sub_t(b[i], ax_final[i]);
341    }
342    let r_norm = vec_norm(&r_final, n);
343
344    Ok((j as u32, converged || r_norm < abs_tol, r_norm))
345}
346
347// ---------------------------------------------------------------------------
348// Helpers
349// ---------------------------------------------------------------------------
350
/// Computes the Givens rotation `(c, s)` that annihilates `b` in the pair `(a, b)`.
///
/// The rotation satisfies `c * a + s * b = r` and `-s * a + c * b = 0`.
/// Degenerate inputs are handled explicitly:
/// - if `|b|` is below `1e-300`, the identity rotation `(1, 0)` is returned;
/// - if `|a|` is below `1e-300`, `(0, ±1)` is returned with the sign of `b`.
///
/// Otherwise `r = hypot(a, b)`.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    const TINY: f64 = 1e-300;
    if b.abs() < TINY {
        return (1.0, 0.0);
    }
    if a.abs() < TINY {
        return (0.0, if b >= 0.0 { 1.0 } else { -1.0 });
    }
    // `hypot` avoids the overflow of `a*a + b*b` for |a|, |b| above ~1e154
    // (which yielded r = inf and a (0, 0) "rotation"), and the underflow for
    // |a|, |b| below ~1e-154 (which yielded r = 0 and inf/NaN).
    let r = a.hypot(b);
    (a / r, b / r)
}
361
362fn dot_product<T: GpuFloat>(a: &[T], b: &[T], n: usize) -> f64 {
363    let mut sum = 0.0_f64;
364    for i in 0..n {
365        sum += to_f64(a[i]) * to_f64(b[i]);
366    }
367    sum
368}
369
370fn vec_norm<T: GpuFloat>(v: &[T], n: usize) -> f64 {
371    dot_product(v, v, n).sqrt()
372}
373
374fn add_t<T: GpuFloat>(a: T, b: T) -> T {
375    from_f64(to_f64(a) + to_f64(b))
376}
377
378fn sub_t<T: GpuFloat>(a: T, b: T) -> T {
379    from_f64(to_f64(a) - to_f64(b))
380}
381
382fn mul_t<T: GpuFloat>(a: T, b: T) -> T {
383    from_f64(to_f64(a) * to_f64(b))
384}
385
386// ---------------------------------------------------------------------------
387// Tests
388// ---------------------------------------------------------------------------
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::preconditioned::IdentityPreconditioner;

    /// Applies `A = diag(1, 2, 3)`.
    fn diag_spmv(xv: &[f64], yv: &mut [f64]) -> SolverResult<()> {
        let diag = [1.0_f64, 2.0, 3.0];
        for ((yi, &xi), &di) in yv.iter_mut().zip(xv).zip(&diag) {
            *yi = di * xi;
        }
        Ok(())
    }

    #[test]
    fn fgmres_config_default() {
        let cfg = FgmresConfig::default();
        assert_eq!(cfg.restart, 30);
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-15);
    }

    #[test]
    fn fgmres_config_custom() {
        let cfg = FgmresConfig {
            restart: 50,
            max_iter: 2000,
            tol: 1e-10,
        };
        assert_eq!(cfg.restart, 50);
        assert_eq!(cfg.max_iter, 2000);
        assert!((cfg.tol - 1e-10).abs() < 1e-20);
    }

    #[test]
    fn givens_rotation_basic() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
    }

    #[test]
    fn givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    #[test]
    fn givens_rotation_zero_a() {
        let (cs, sn) = givens_rotation(0.0, 3.0);
        assert!(cs.abs() < 1e-15);
        assert!((sn - 1.0).abs() < 1e-15);
    }

    #[test]
    fn dot_product_basic() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0, 6.0];
        assert!((dot_product(&a, &b, 3) - 32.0).abs() < 1e-10);
    }

    #[test]
    fn vec_norm_unit() {
        let v = [1.0_f64, 0.0, 0.0];
        assert!((vec_norm(&v, 3) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn vec_norm_345() {
        let v = [3.0_f64, 4.0];
        assert!((vec_norm(&v, 2) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn add_sub_mul_helpers() {
        let a = 3.0_f64;
        let b = 4.0_f64;
        assert!((to_f64(add_t(a, b)) - 7.0).abs() < 1e-15);
        assert!((to_f64(sub_t(a, b)) - (-1.0)).abs() < 1e-15);
        assert!((to_f64(mul_t(a, b)) - 12.0).abs() < 1e-15);
    }

    #[test]
    fn identity_preconditioner_applies_identity() {
        let precond = IdentityPreconditioner;
        let r = [1.0_f64, 2.0, 3.0];
        let mut z = [0.0_f64; 3];
        let result = precond.apply(&r, &mut z);
        assert!(result.is_ok());
        assert!((z[0] - 1.0).abs() < 1e-15);
        assert!((z[2] - 3.0).abs() < 1e-15);
    }

    #[test]
    fn fgmres_cycle_solves_diagonal_system() {
        // diag(1, 2, 3) has 3 distinct eigenvalues, so 3 steps suffice.
        let b = [1.0_f64, 4.0, 9.0];
        let mut x = [0.0_f64; 3];
        let (iters, converged, res) = fgmres_cycle(
            &diag_spmv,
            &IdentityPreconditioner,
            &b,
            &mut x,
            3,
            3,
            1e-10,
            100,
        )
        .expect("cycle should not fail");
        assert!(converged);
        assert!(iters <= 3);
        assert!(res < 1e-10);
        for (xi, expected) in x.iter().zip([1.0_f64, 2.0, 3.0]) {
            assert!((xi - expected).abs() < 1e-8);
        }
    }

    #[test]
    fn fgmres_cycle_nonsymmetric_2x2() {
        // A = [[2, 1], [0, 3]], b = [3, 3] -> x = [1, 1].
        let spmv = |xv: &[f64], yv: &mut [f64]| -> SolverResult<()> {
            yv[0] = 2.0 * xv[0] + xv[1];
            yv[1] = 3.0 * xv[1];
            Ok(())
        };
        let b = [3.0_f64, 3.0];
        let mut x = [0.0_f64; 2];
        let (_, converged, res) = fgmres_cycle(
            &spmv,
            &IdentityPreconditioner,
            &b,
            &mut x,
            2,
            2,
            1e-10,
            100,
        )
        .expect("cycle should not fail");
        assert!(converged);
        assert!(res < 1e-10);
        assert!((x[0] - 1.0).abs() < 1e-8);
        assert!((x[1] - 1.0).abs() < 1e-8);
    }

    #[test]
    fn fgmres_cycle_respects_iteration_budget() {
        let b = [1.0_f64, 4.0, 9.0];
        let mut x = [0.0_f64; 3];
        let (iters, converged, _) = fgmres_cycle(
            &diag_spmv,
            &IdentityPreconditioner,
            &b,
            &mut x,
            3,
            3,
            1e-10,
            1,
        )
        .expect("cycle should not fail");
        assert_eq!(iters, 1);
        assert!(!converged);
    }

    #[test]
    fn fgmres_cycle_exact_initial_guess_takes_no_steps() {
        let b = [1.0_f64, 4.0, 9.0];
        let mut x = [1.0_f64, 2.0, 3.0];
        let (iters, converged, res) = fgmres_cycle(
            &diag_spmv,
            &IdentityPreconditioner,
            &b,
            &mut x,
            3,
            3,
            1e-10,
            100,
        )
        .expect("cycle should not fail");
        assert_eq!(iters, 0);
        assert!(converged);
        assert!(res < 1e-15);
    }

    #[test]
    fn f64_conversion_roundtrip() {
        let val = std::f64::consts::PI;
        let as_f64 = to_f64(val);
        let back: f64 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-15);
    }

    #[test]
    fn f32_conversion_roundtrip() {
        let val = std::f32::consts::PI;
        let as_f64 = to_f64(val);
        let back: f32 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-5);
    }
}