Skip to main content

oxicuda_solver/sparse/
bicgstab.rs

1//! BiCGSTAB (Biconjugate Gradient Stabilized) iterative solver.
2//!
3//! Solves the linear system `A * x = b` for general (non-symmetric) matrices.
4//! The solver is matrix-free: it only requires a closure that computes the
5//! matrix-vector product `y = A * x`.
6//!
7//! # Algorithm
8//!
9//! The BiCGSTAB algorithm (van der Vorst, 1992):
10//! 1. r = b - A*x; r0_hat = r; rho = alpha = omega = 1; v = p = 0
11//! 2. For each iteration:
12//!    a. rho_new = r0_hat^T * r
13//!    b. beta = (rho_new / rho) * (alpha / omega)
14//!    c. p = r + beta * (p - omega * v)
15//!    d. v = A * p
16//!    e. alpha = rho_new / (r0_hat^T * v)
17//!    f. s = r - alpha * v
18//!    g. t = A * s
19//!    h. omega = (t^T * s) / (t^T * t)
20//!    i. x += alpha * p + omega * s
21//!    j. r = s - omega * t
22//!    k. Check convergence: ||r|| < tol * ||b||
23
24#![allow(dead_code)]
25
26use oxicuda_blas::GpuFloat;
27
28use crate::error::{SolverError, SolverResult};
29use crate::handle::SolverHandle;
30
31// ---------------------------------------------------------------------------
32// GpuFloat <-> f64 conversion helpers
33// ---------------------------------------------------------------------------
34
35fn to_f64<T: GpuFloat>(val: T) -> f64 {
36    if T::SIZE == 4 {
37        f32::from_bits(val.to_bits_u64() as u32) as f64
38    } else {
39        f64::from_bits(val.to_bits_u64())
40    }
41}
42
43fn from_f64<T: GpuFloat>(val: f64) -> T {
44    if T::SIZE == 4 {
45        T::from_bits_u64(u64::from((val as f32).to_bits()))
46    } else {
47        T::from_bits_u64(val.to_bits())
48    }
49}
50
51// ---------------------------------------------------------------------------
52// Configuration
53// ---------------------------------------------------------------------------
54
/// Configuration for the BiCGSTAB solver.
///
/// The solver reports success once the residual norm falls below
/// `tol * ||b||`. It reports a convergence failure if that has not happened
/// after `max_iter` iterations.
#[derive(Debug, Clone, PartialEq)]
pub struct BiCgStabConfig {
    /// Maximum number of iterations.
    pub max_iter: u32,
    /// Convergence tolerance, relative to `||b||`.
    pub tol: f64,
}

impl Default for BiCgStabConfig {
    /// Returns `max_iter = 1000` and `tol = 1e-6`.
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tol: 1e-6,
        }
    }
}
72
73// ---------------------------------------------------------------------------
74// Public API
75// ---------------------------------------------------------------------------
76
77/// Solves `A * x = b` using the BiCGSTAB method.
78///
79/// On entry, `x` should contain an initial guess. On exit, `x` contains the
80/// approximate solution.
81///
82/// # Arguments
83///
84/// * `_handle` — solver handle (reserved for future GPU-accelerated variants).
85/// * `spmv` — closure computing `y = A * x`: `spmv(x, y)`.
86/// * `b` — right-hand side vector (length n).
87/// * `x` — initial guess / solution vector (length n), modified in-place.
88/// * `n` — system dimension.
89/// * `config` — solver configuration.
90///
91/// # Returns
92///
93/// The number of iterations performed.
94///
95/// # Errors
96///
97/// Returns [`SolverError::ConvergenceFailure`] if the solver does not converge.
98/// Returns [`SolverError::InternalError`] if a breakdown is detected (e.g., rho = 0).
99pub fn bicgstab_solve<T, F>(
100    _handle: &SolverHandle,
101    spmv: F,
102    b: &[T],
103    x: &mut [T],
104    n: u32,
105    config: &BiCgStabConfig,
106) -> SolverResult<u32>
107where
108    T: GpuFloat,
109    F: Fn(&[T], &mut [T]) -> SolverResult<()>,
110{
111    let n_usize = n as usize;
112
113    // Validate dimensions.
114    if b.len() < n_usize {
115        return Err(SolverError::DimensionMismatch(format!(
116            "bicgstab_solve: b length ({}) < n ({n})",
117            b.len()
118        )));
119    }
120    if x.len() < n_usize {
121        return Err(SolverError::DimensionMismatch(format!(
122            "bicgstab_solve: x length ({}) < n ({n})",
123            x.len()
124        )));
125    }
126    if n == 0 {
127        return Ok(0);
128    }
129
130    // Compute ||b|| for relative convergence check.
131    let b_norm = vec_norm(b, n_usize);
132    let abs_tol = if b_norm > 0.0 {
133        config.tol * b_norm
134    } else {
135        for xi in x.iter_mut().take(n_usize) {
136            *xi = T::gpu_zero();
137        }
138        return Ok(0);
139    };
140
141    // r = b - A*x
142    let mut r = vec![T::gpu_zero(); n_usize];
143    let mut tmp = vec![T::gpu_zero(); n_usize];
144    spmv(x, &mut tmp)?;
145    for i in 0..n_usize {
146        r[i] = sub_t(b[i], tmp[i]);
147    }
148
149    // r0_hat = r (shadow residual, kept constant)
150    let r0_hat = r.clone();
151
152    // Initialize scalars.
153    let mut rho = 1.0_f64;
154    let mut alpha = 1.0_f64;
155    let mut omega = 1.0_f64;
156
157    // Initialize vectors.
158    let mut v = vec![T::gpu_zero(); n_usize];
159    let mut p = vec![T::gpu_zero(); n_usize];
160    let mut s = vec![T::gpu_zero(); n_usize];
161    let mut t = vec![T::gpu_zero(); n_usize];
162
163    for iter in 0..config.max_iter {
164        // rho_new = r0_hat^T * r
165        let rho_new = dot_product(&r0_hat, &r, n_usize);
166
167        if rho_new.abs() < 1e-300 {
168            return Err(SolverError::InternalError(
169                "bicgstab_solve: rho breakdown (r0_hat^T * r ~ 0)".into(),
170            ));
171        }
172
173        // beta = (rho_new / rho) * (alpha / omega)
174        let beta = if rho.abs() > 1e-300 && omega.abs() > 1e-300 {
175            (rho_new / rho) * (alpha / omega)
176        } else {
177            0.0
178        };
179        let beta_t = from_f64(beta);
180        let omega_t = from_f64(omega);
181
182        // p = r + beta * (p - omega * v)
183        for i in 0..n_usize {
184            let pv = sub_t(p[i], mul_t(omega_t, v[i]));
185            p[i] = add_t(r[i], mul_t(beta_t, pv));
186        }
187
188        // v = A * p
189        spmv(&p, &mut v)?;
190
191        // alpha = rho_new / (r0_hat^T * v)
192        let r0v = dot_product(&r0_hat, &v, n_usize);
193        if r0v.abs() < 1e-300 {
194            return Err(SolverError::InternalError(
195                "bicgstab_solve: alpha breakdown (r0_hat^T * v ~ 0)".into(),
196            ));
197        }
198        alpha = rho_new / r0v;
199        let alpha_t = from_f64(alpha);
200
201        // s = r - alpha * v
202        for i in 0..n_usize {
203            s[i] = sub_t(r[i], mul_t(alpha_t, v[i]));
204        }
205
206        // Check if s is small enough (early exit).
207        let s_norm = vec_norm(&s, n_usize);
208        if s_norm < abs_tol {
209            // x += alpha * p
210            for i in 0..n_usize {
211                x[i] = add_t(x[i], mul_t(alpha_t, p[i]));
212            }
213            return Ok(iter + 1);
214        }
215
216        // t = A * s
217        spmv(&s, &mut t)?;
218
219        // omega = (t^T * s) / (t^T * t)
220        let tt = dot_product(&t, &t, n_usize);
221        omega = if tt.abs() > 1e-300 {
222            dot_product(&t, &s, n_usize) / tt
223        } else {
224            0.0
225        };
226        let omega_new_t = from_f64(omega);
227
228        // x += alpha * p + omega * s
229        for i in 0..n_usize {
230            x[i] = add_t(x[i], add_t(mul_t(alpha_t, p[i]), mul_t(omega_new_t, s[i])));
231        }
232
233        // r = s - omega * t
234        for i in 0..n_usize {
235            r[i] = sub_t(s[i], mul_t(omega_new_t, t[i]));
236        }
237
238        // Check convergence.
239        let r_norm = vec_norm(&r, n_usize);
240        if r_norm < abs_tol {
241            return Ok(iter + 1);
242        }
243
244        if omega.abs() < 1e-300 {
245            return Err(SolverError::InternalError(
246                "bicgstab_solve: omega breakdown".into(),
247            ));
248        }
249
250        rho = rho_new;
251    }
252
253    Err(SolverError::ConvergenceFailure {
254        iterations: config.max_iter,
255        residual: vec_norm(&r, n_usize),
256    })
257}
258
259// ---------------------------------------------------------------------------
260// Vector arithmetic helpers
261// ---------------------------------------------------------------------------
262
263fn dot_product<T: GpuFloat>(a: &[T], b: &[T], n: usize) -> f64 {
264    let mut sum = 0.0_f64;
265    for i in 0..n {
266        sum += to_f64(a[i]) * to_f64(b[i]);
267    }
268    sum
269}
270
271fn vec_norm<T: GpuFloat>(v: &[T], n: usize) -> f64 {
272    dot_product(v, v, n).sqrt()
273}
274
275fn add_t<T: GpuFloat>(a: T, b: T) -> T {
276    from_f64(to_f64(a) + to_f64(b))
277}
278
279fn sub_t<T: GpuFloat>(a: T, b: T) -> T {
280    from_f64(to_f64(a) - to_f64(b))
281}
282
283fn mul_t<T: GpuFloat>(a: T, b: T) -> T {
284    from_f64(to_f64(a) * to_f64(b))
285}
286
287// ---------------------------------------------------------------------------
288// Tests
289// ---------------------------------------------------------------------------
290
#[cfg(test)]
mod tests {
    use super::*;

    /// CPU-only BiCGSTAB solver for testing without a GPU handle.
    ///
    /// Mirrors `bicgstab_solve` but omits `_handle`, enabling pure host
    /// testing with a closure-based matrix-vector product. It applies the
    /// same early exits (zero or non-finite `||b||`, already-converged
    /// initial guess) and passes `spmv` length-`n` slices only. Keep it in
    /// sync with `bicgstab_solve`.
    fn bicgstab_solve_cpu<T, F>(
        spmv: F,
        b: &[T],
        x: &mut [T],
        n: u32,
        config: &BiCgStabConfig,
    ) -> SolverResult<u32>
    where
        T: GpuFloat,
        F: Fn(&[T], &mut [T]) -> SolverResult<()>,
    {
        let n_usize = n as usize;

        if b.len() < n_usize {
            return Err(SolverError::DimensionMismatch(format!(
                "bicgstab_solve_cpu: b length ({}) < n ({n})",
                b.len()
            )));
        }
        if x.len() < n_usize {
            return Err(SolverError::DimensionMismatch(format!(
                "bicgstab_solve_cpu: x length ({}) < n ({n})",
                x.len()
            )));
        }
        if n == 0 {
            return Ok(0);
        }

        let b_norm = vec_norm(b, n_usize);
        // A NaN/Inf norm must not be mistaken for b == 0 or yield an
        // infinite tolerance.
        if !b_norm.is_finite() {
            return Err(SolverError::InternalError(
                "bicgstab_solve_cpu: ||b|| is not finite".into(),
            ));
        }
        if b_norm <= 0.0 {
            for xi in x.iter_mut().take(n_usize) {
                *xi = T::gpu_zero();
            }
            return Ok(0);
        }
        let abs_tol = config.tol * b_norm;

        let mut r = vec![T::gpu_zero(); n_usize];
        let mut tmp = vec![T::gpu_zero(); n_usize];
        // Only the first n entries of x take part in the product.
        spmv(&x[..n_usize], &mut tmp)?;
        for i in 0..n_usize {
            r[i] = sub_t(b[i], tmp[i]);
        }

        // An exact initial guess gives r = 0. Without this check the loop
        // would report a spurious rho breakdown.
        if vec_norm(&r, n_usize) < abs_tol {
            return Ok(0);
        }

        let r0_hat = r.clone();
        let mut rho = 1.0_f64;
        let mut alpha = 1.0_f64;
        let mut omega = 1.0_f64;
        let mut v = vec![T::gpu_zero(); n_usize];
        let mut p = vec![T::gpu_zero(); n_usize];
        let mut s = vec![T::gpu_zero(); n_usize];
        let mut t = vec![T::gpu_zero(); n_usize];

        for iter in 0..config.max_iter {
            let rho_new = dot_product(&r0_hat, &r, n_usize);
            if rho_new.abs() < 1e-300 {
                return Err(SolverError::InternalError(
                    "bicgstab_solve_cpu: rho breakdown".into(),
                ));
            }

            let beta = if rho.abs() > 1e-300 && omega.abs() > 1e-300 {
                (rho_new / rho) * (alpha / omega)
            } else {
                0.0
            };
            let beta_t = from_f64(beta);
            let omega_t = from_f64(omega);

            for i in 0..n_usize {
                let pv = sub_t(p[i], mul_t(omega_t, v[i]));
                p[i] = add_t(r[i], mul_t(beta_t, pv));
            }

            spmv(&p, &mut v)?;

            let r0v = dot_product(&r0_hat, &v, n_usize);
            if r0v.abs() < 1e-300 {
                return Err(SolverError::InternalError(
                    "bicgstab_solve_cpu: alpha breakdown".into(),
                ));
            }
            alpha = rho_new / r0v;
            let alpha_t = from_f64(alpha);

            for i in 0..n_usize {
                s[i] = sub_t(r[i], mul_t(alpha_t, v[i]));
            }

            let s_norm = vec_norm(&s, n_usize);
            if s_norm < abs_tol {
                for i in 0..n_usize {
                    x[i] = add_t(x[i], mul_t(alpha_t, p[i]));
                }
                return Ok(iter + 1);
            }

            spmv(&s, &mut t)?;

            let tt = dot_product(&t, &t, n_usize);
            omega = if tt.abs() > 1e-300 {
                dot_product(&t, &s, n_usize) / tt
            } else {
                0.0
            };
            let omega_new_t = from_f64(omega);

            for i in 0..n_usize {
                x[i] = add_t(x[i], add_t(mul_t(alpha_t, p[i]), mul_t(omega_new_t, s[i])));
            }

            for i in 0..n_usize {
                r[i] = sub_t(s[i], mul_t(omega_new_t, t[i]));
            }

            let r_norm = vec_norm(&r, n_usize);
            if r_norm < abs_tol {
                return Ok(iter + 1);
            }

            if omega.abs() < 1e-300 {
                return Err(SolverError::InternalError(
                    "bicgstab_solve_cpu: omega breakdown".into(),
                ));
            }

            rho = rho_new;
        }

        Err(SolverError::ConvergenceFailure {
            iterations: config.max_iter,
            residual: vec_norm(&r, n_usize),
        })
    }

    #[test]
    fn bicgstab_config_default() {
        let cfg = BiCgStabConfig::default();
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-15);
    }

    #[test]
    fn bicgstab_config_custom() {
        let cfg = BiCgStabConfig {
            max_iter: 2000,
            tol: 1e-8,
        };
        assert_eq!(cfg.max_iter, 2000);
        assert!((cfg.tol - 1e-8).abs() < 1e-20);
    }

    #[test]
    fn dot_product_basic() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0, 6.0];
        assert!((dot_product(&a, &b, 3) - 32.0).abs() < 1e-10);
    }

    #[test]
    fn vec_norm_basic() {
        let v = [3.0_f64, 4.0];
        assert!((vec_norm(&v, 2) - 5.0).abs() < 1e-10);
    }

    /// BiCGSTAB converges on a 3×3 symmetric positive definite system.
    ///
    /// A = [[4,-1,0],[-1,4,-1],[0,-1,4]], b = [6, 0, 6].
    /// Exact solution: x = [12/7, 6/7, 12/7] ≈ [1.7143, 0.8571, 1.7143].
    #[test]
    fn bicgstab_converges_spd_3x3() {
        let b = vec![6.0_f64, 0.0, 6.0];
        let mut x = vec![0.0_f64; 3];
        let config = BiCgStabConfig {
            max_iter: 200,
            tol: 1e-10,
        };

        // A = [[4,-1,0],[-1,4,-1],[0,-1,4]]
        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out[0] = 4.0 * v[0] - v[1];
            out[1] = -v[0] + 4.0 * v[1] - v[2];
            out[2] = -v[1] + 4.0 * v[2];
            Ok(())
        };

        let _iters = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("BiCGSTAB should converge on SPD system");

        let x0_exact = 12.0_f64 / 7.0; // ≈ 1.714286
        let x1_exact = 6.0_f64 / 7.0; // ≈ 0.857143
        assert!(
            (x[0] - x0_exact).abs() < 1e-7,
            "x[0] = {} expected {x0_exact}",
            x[0]
        );
        assert!(
            (x[1] - x1_exact).abs() < 1e-7,
            "x[1] = {} expected {x1_exact}",
            x[1]
        );
        assert!(
            (x[2] - x0_exact).abs() < 1e-7,
            "x[2] = {} expected {x0_exact}",
            x[2]
        );
    }

    /// BiCGSTAB converges on a non-symmetric 3×3 system.
    ///
    /// A = [[4,1,0],[-1,4,1],[0,-1,4]], exact x = [1, 2, 3] → b = [6, 10, 10].
    #[test]
    fn bicgstab_converges_nonsymmetric_3x3() {
        let b = vec![6.0_f64, 10.0, 10.0];
        let mut x = vec![0.0_f64; 3];
        let config = BiCgStabConfig {
            max_iter: 200,
            tol: 1e-10,
        };

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out[0] = 4.0 * v[0] + v[1];
            out[1] = -v[0] + 4.0 * v[1] + v[2];
            out[2] = -v[1] + 4.0 * v[2];
            Ok(())
        };

        let _iters = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("BiCGSTAB should converge on non-symmetric system");

        let expected = [1.0_f64, 2.0, 3.0];
        for (xi, ei) in x.iter().zip(expected.iter()) {
            assert!((xi - ei).abs() < 1e-8, "x = {x:?} expected {expected:?}");
        }
    }

    /// BiCGSTAB converges on the identity system in a single iteration.
    #[test]
    fn bicgstab_converges_identity() {
        let b = vec![5.0_f64, -3.0, 2.0];
        let mut x = vec![0.0_f64; 3];
        let config = BiCgStabConfig {
            max_iter: 50,
            tol: 1e-12,
        };

        // A = I
        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out.copy_from_slice(v);
            Ok(())
        };

        let _iters = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("BiCGSTAB should converge on identity");

        assert!((x[0] - 5.0).abs() < 1e-9);
        assert!((x[1] - (-3.0)).abs() < 1e-9);
        assert!((x[2] - 2.0).abs() < 1e-9);
    }

    /// An exact initial guess is accepted with zero iterations instead of
    /// triggering a rho breakdown (r = 0 ⇒ r0_hat^T * r = 0).
    #[test]
    fn bicgstab_exact_initial_guess_returns_zero_iterations() {
        let b = vec![5.0_f64, -3.0, 2.0];
        let mut x = b.clone();
        let config = BiCgStabConfig::default();

        // A = I, so x = b is the exact solution.
        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out.copy_from_slice(v);
            Ok(())
        };

        let iters = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("exact initial guess should be accepted");
        assert_eq!(iters, 0);
        assert_eq!(x, b);
    }

    /// `spmv` only ever receives length-n slices, even when `x` is longer
    /// than `n`, and entries of `x` beyond `n` are left untouched.
    #[test]
    fn bicgstab_spmv_sees_exactly_n_entries() {
        let b = vec![5.0_f64, -3.0, 2.0];
        let mut x = vec![0.0_f64; 5];
        let config = BiCgStabConfig {
            max_iter: 50,
            tol: 1e-12,
        };

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            assert_eq!(v.len(), 3);
            out.copy_from_slice(v);
            Ok(())
        };

        let _iters = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("BiCGSTAB should converge on identity with a longer x");

        assert!((x[0] - 5.0).abs() < 1e-9);
        assert!((x[1] - (-3.0)).abs() < 1e-9);
        assert!((x[2] - 2.0).abs() < 1e-9);
        assert_eq!(&x[3..], &[0.0, 0.0]);
    }

    /// BiCGSTAB with zero RHS returns the zero vector immediately.
    #[test]
    fn bicgstab_zero_rhs_returns_zero() {
        let b = vec![0.0_f64; 3];
        let mut x = vec![1.0_f64; 3];
        let config = BiCgStabConfig::default();

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out.copy_from_slice(v);
            Ok(())
        };

        let iters =
            bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config).expect("zero RHS should succeed");
        assert_eq!(iters, 0);
        for &xi in &x {
            assert!(xi.abs() < 1e-15);
        }
    }

    /// A NaN in the RHS is reported as an error rather than being treated as
    /// a zero RHS (which would silently return x = 0).
    #[test]
    fn bicgstab_non_finite_rhs_is_error() {
        let b = vec![f64::NAN, 1.0, 1.0];
        let mut x = vec![0.0_f64; 3];
        let config = BiCgStabConfig::default();

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out.copy_from_slice(v);
            Ok(())
        };

        let result = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config);
        assert!(matches!(result, Err(SolverError::InternalError(_))));
    }

    /// BiCGSTAB returns DimensionMismatch when b is shorter than n.
    #[test]
    fn bicgstab_dimension_mismatch() {
        let b = vec![1.0_f64]; // length 1, n = 3
        let mut x = vec![0.0_f64; 3];
        let config = BiCgStabConfig::default();
        let spmv = |_: &[f64], _: &mut [f64]| -> SolverResult<()> { Ok(()) };
        let result = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config);
        assert!(matches!(result, Err(SolverError::DimensionMismatch(_))));
    }

    /// BiCGSTAB converges on a diagonal system with varying eigenvalues.
    ///
    /// A = diag(1, 3, 7), b = [2, 9, 14] → exact x = [2, 3, 2].
    #[test]
    fn bicgstab_converges_diagonal() {
        let b = vec![2.0_f64, 9.0, 14.0];
        let mut x = vec![0.0_f64; 3];
        let config = BiCgStabConfig {
            max_iter: 200,
            tol: 1e-10,
        };

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out[0] = 1.0 * v[0];
            out[1] = 3.0 * v[1];
            out[2] = 7.0 * v[2];
            Ok(())
        };

        let _iters = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("BiCGSTAB should converge on diagonal system");

        assert!((x[0] - 2.0).abs() < 1e-8, "x[0] = {} expected 2.0", x[0]);
        assert!((x[1] - 3.0).abs() < 1e-8, "x[1] = {} expected 3.0", x[1]);
        assert!((x[2] - 2.0).abs() < 1e-8, "x[2] = {} expected 2.0", x[2]);
    }
}