// oxicuda_solver/sparse/cg.rs

//! Conjugate Gradient (CG) iterative solver.
//!
//! Solves the linear system `A * x = b` where A is symmetric positive definite.
//! The solver is matrix-free: it only requires a closure that computes the
//! matrix-vector product `y = A * x`.
//!
//! # Algorithm
//!
//! The standard Conjugate Gradient algorithm (Hestenes & Stiefel, 1952):
//! 1. r = b - A*x; p = r; rsold = r^T * r
//! 2. For each iteration:
//!    a. Ap = A * p
//!    b. alpha = rsold / (p^T * Ap)
//!    c. x += alpha * p
//!    d. r -= alpha * Ap
//!    e. rsnew = r^T * r
//!    f. If sqrt(rsnew) < tol * ||b||: converged
//!    g. p = r + (rsnew / rsold) * p
//!    h. rsold = rsnew
//!
//! The solver operates on host-side vectors. For GPU-accelerated sparse
//! matrix-vector products, the `spmv` closure should internally manage
//! device memory transfers.

#![allow(dead_code)]

use oxicuda_blas::GpuFloat;

use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;

// ---------------------------------------------------------------------------
// GpuFloat <-> f64 conversion helpers
// ---------------------------------------------------------------------------

36/// Converts a `GpuFloat` value to `f64` via bit reinterpretation.
37fn to_f64<T: GpuFloat>(val: T) -> f64 {
38    if T::SIZE == 4 {
39        f32::from_bits(val.to_bits_u64() as u32) as f64
40    } else {
41        f64::from_bits(val.to_bits_u64())
42    }
43}
44
45/// Converts an `f64` value to `T: GpuFloat` via bit reinterpretation.
46fn from_f64<T: GpuFloat>(val: f64) -> T {
47    if T::SIZE == 4 {
48        T::from_bits_u64(u64::from((val as f32).to_bits()))
49    } else {
50        T::from_bits_u64(val.to_bits())
51    }
52}

// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Configuration for the Conjugate Gradient solver.
#[derive(Debug, Clone)]
pub struct CgConfig {
    /// Maximum number of iterations before the solver gives up and
    /// reports a convergence failure.
    pub max_iter: u32,
    /// Convergence tolerance, relative to ||b||: iteration stops once
    /// the residual norm satisfies ||r|| < tol * ||b||.
    pub tol: f64,
}

67impl Default for CgConfig {
68    fn default() -> Self {
69        Self {
70            max_iter: 1000,
71            tol: 1e-6,
72        }
73    }
74}

// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------

/// Solves `A * x = b` using the Conjugate Gradient method.
///
/// The matrix A is not passed directly. Instead, the caller provides a closure
/// `spmv` that computes `y = A * x` given `x` and `y` buffers. This enables
/// use with any sparse format, preconditioner, or matrix-free operator.
///
/// On entry, `x` should contain an initial guess (e.g., zeros). On exit, `x`
/// contains the approximate solution. Only the first `n` entries of `b` and
/// `x` are read/written; longer slices are accepted.
///
/// All reductions (dot products, norms) are accumulated in `f64`; the scalars
/// `alpha` and `beta` are rounded back to `T` (lossy when `T` is single
/// precision) before the vector updates.
///
/// # Arguments
///
/// * `_handle` — solver handle (reserved for future GPU-accelerated variants).
/// * `spmv` — closure computing `y = A * x`: `spmv(x, y)`.
/// * `b` — right-hand side vector (length >= n).
/// * `x` — initial guess / solution vector (length >= n), modified in-place.
/// * `n` — system dimension.
/// * `config` — solver configuration (tolerance, max iterations).
///
/// # Returns
///
/// The number of iterations performed. Returns 0 when `n == 0`, when `b`
/// is all-zero (x is set to zero), or when the initial guess already meets
/// the tolerance.
///
/// # Errors
///
/// Returns [`SolverError::ConvergenceFailure`] if the solver does not converge
/// within `max_iter` iterations.
/// Returns [`SolverError::DimensionMismatch`] if vector lengths are invalid.
/// Returns [`SolverError::InternalError`] if `p^T * A * p` is near zero,
/// which suggests A is not symmetric positive definite.
pub fn cg_solve<T, F>(
    _handle: &SolverHandle,
    spmv: F,
    b: &[T],
    x: &mut [T],
    n: u32,
    config: &CgConfig,
) -> SolverResult<u32>
where
    T: GpuFloat,
    F: Fn(&[T], &mut [T]) -> SolverResult<()>,
{
    let n_usize = n as usize;

    // Validate dimensions. Slices longer than n are deliberately accepted;
    // only the first n entries participate.
    if b.len() < n_usize {
        return Err(SolverError::DimensionMismatch(format!(
            "cg_solve: b length ({}) < n ({n})",
            b.len()
        )));
    }
    if x.len() < n_usize {
        return Err(SolverError::DimensionMismatch(format!(
            "cg_solve: x length ({}) < n ({n})",
            x.len()
        )));
    }
    if n == 0 {
        // Empty system: nothing to solve.
        return Ok(0);
    }

    // Compute ||b|| for relative convergence check.
    let b_norm = vec_norm(b, n_usize);
    let abs_tol = if b_norm > 0.0 {
        config.tol * b_norm
    } else {
        // b = 0 => x = 0 is the exact solution.
        for xi in x.iter_mut().take(n_usize) {
            *xi = T::gpu_zero();
        }
        return Ok(0);
    };

    // r = b - A*x  (initial residual for the caller-supplied guess)
    let mut r = vec![T::gpu_zero(); n_usize];
    let mut ap = vec![T::gpu_zero(); n_usize];
    spmv(x, &mut ap)?;
    for i in 0..n_usize {
        r[i] = sub_t(b[i], ap[i]);
    }

    // p = r.clone()  (first search direction = steepest-descent direction)
    let mut p = r.clone();

    // rsold = r^T * r  (squared residual norm, carried across iterations)
    let mut rsold = dot_product(&r, &r, n_usize);

    // The initial guess may already satisfy the tolerance.
    if rsold.sqrt() < abs_tol {
        return Ok(0);
    }

    for iter in 0..config.max_iter {
        // Ap = A * p
        spmv(&p, &mut ap)?;

        // alpha = rsold / (p^T * Ap)
        let pap = dot_product(&p, &ap, n_usize);
        // Breakdown guard against division by (near-)zero curvature.
        // NOTE(review): a *negative* pap also indicates A is not SPD, but
        // only near-zero magnitudes are rejected here — confirm intended.
        if pap.abs() < 1e-300 {
            return Err(SolverError::InternalError(
                "cg_solve: p^T * A * p is near zero (A may not be SPD)".into(),
            ));
        }
        let alpha = rsold / pap;
        // alpha is rounded to T here (lossy for f32) before the updates.
        let alpha_t = from_f64(alpha);

        // x += alpha * p
        for i in 0..n_usize {
            x[i] = add_t(x[i], mul_t(alpha_t, p[i]));
        }

        // r -= alpha * Ap  (recurrence form of the residual update; avoids
        // a second spmv per iteration)
        for i in 0..n_usize {
            r[i] = sub_t(r[i], mul_t(alpha_t, ap[i]));
        }

        // rsnew = r^T * r
        let rsnew = dot_product(&r, &r, n_usize);

        // Check convergence: ||r|| < tol * ||b||.
        if rsnew.sqrt() < abs_tol {
            return Ok(iter + 1);
        }

        // beta = rsnew / rsold
        let beta = rsnew / rsold;
        let beta_t = from_f64(beta);

        // p = r + beta * p  (new A-conjugate search direction)
        for i in 0..n_usize {
            p[i] = add_t(r[i], mul_t(beta_t, p[i]));
        }

        rsold = rsnew;
    }

    // Out of iterations: report the last residual norm for diagnostics
    // (rsold holds the final rsnew at this point).
    Err(SolverError::ConvergenceFailure {
        iterations: config.max_iter,
        residual: rsold.sqrt(),
    })
}

// ---------------------------------------------------------------------------
// Vector arithmetic helpers (host-side, generic over GpuFloat)
// ---------------------------------------------------------------------------

222/// Computes the dot product of two vectors as f64.
223fn dot_product<T: GpuFloat>(a: &[T], b: &[T], n: usize) -> f64 {
224    let mut sum = 0.0_f64;
225    for i in 0..n {
226        sum += to_f64(a[i]) * to_f64(b[i]);
227    }
228    sum
229}
230
231/// Computes the 2-norm of a vector as f64.
232fn vec_norm<T: GpuFloat>(v: &[T], n: usize) -> f64 {
233    dot_product(v, v, n).sqrt()
234}
235
236/// Adds two GpuFloat values.
237fn add_t<T: GpuFloat>(a: T, b: T) -> T {
238    from_f64(to_f64(a) + to_f64(b))
239}
240
241/// Subtracts two GpuFloat values.
242fn sub_t<T: GpuFloat>(a: T, b: T) -> T {
243    from_f64(to_f64(a) - to_f64(b))
244}
245
246/// Multiplies two GpuFloat values.
247fn mul_t<T: GpuFloat>(a: T, b: T) -> T {
248    from_f64(to_f64(a) * to_f64(b))
249}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cg_config_default() {
        let defaults = CgConfig::default();
        assert_eq!(defaults.max_iter, 1000);
        assert!((defaults.tol - 1e-6).abs() < 1e-15);
    }

    #[test]
    fn dot_product_basic() {
        // <1,2,3> . <4,5,6> = 4 + 10 + 18 = 32
        let lhs = [1.0_f64, 2.0, 3.0];
        let rhs = [4.0_f64, 5.0, 6.0];
        assert!((dot_product(&lhs, &rhs, 3) - 32.0).abs() < 1e-10);
    }

    #[test]
    fn vec_norm_basic() {
        // 3-4-5 triangle: ||(3, 4)|| = 5.
        let v = [3.0_f64, 4.0];
        assert!((vec_norm(&v, 2) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn add_sub_mul() {
        let (a, b) = (3.0_f64, 4.0_f64);
        assert!((to_f64(add_t(a, b)) - 7.0).abs() < 1e-15);
        assert!((to_f64(sub_t(a, b)) - (-1.0)).abs() < 1e-15);
        assert!((to_f64(mul_t(a, b)) - 12.0).abs() < 1e-15);
    }

    #[test]
    fn cg_config_custom() {
        let cfg = CgConfig {
            max_iter: 500,
            tol: 1e-10,
        };
        assert_eq!(cfg.max_iter, 500);
        assert!((cfg.tol - 1e-10).abs() < 1e-20);
    }

    // -----------------------------------------------------------------------
    // Quality gate: CG convergence on small SPD systems (CPU simulation)
    // -----------------------------------------------------------------------

    /// Reference conjugate-gradient loop in plain `f64`.
    ///
    /// Validates the algorithm without needing a `SolverHandle` (GPU
    /// context), isolating algorithmic correctness from infrastructure.
    /// Returns the number of iterations performed.
    fn cpu_cg_f64(
        spmv: impl Fn(&[f64], &mut [f64]),
        b: &[f64],
        x: &mut [f64],
        n: usize,
        max_iter: usize,
        tol: f64,
    ) -> usize {
        let rhs_norm = b.iter().map(|v| v * v).sum::<f64>().sqrt();
        let threshold = tol * rhs_norm;

        // `work` doubles as the A*x / A*p scratch buffer.
        let mut work = vec![0.0_f64; n];
        spmv(x, &mut work);
        let mut resid: Vec<f64> = b.iter().zip(&work).map(|(bi, wi)| bi - wi).collect();
        let mut dir = resid.clone();
        let mut rs_prev: f64 = resid.iter().map(|v| v * v).sum();

        for step in 0..max_iter {
            spmv(&dir, &mut work);
            let curvature: f64 = dir.iter().zip(&work).map(|(di, wi)| di * wi).sum();
            if curvature.abs() < 1e-300 {
                return step;
            }
            let alpha = rs_prev / curvature;
            for i in 0..n {
                x[i] += alpha * dir[i];
                resid[i] -= alpha * work[i];
            }
            let rs_next: f64 = resid.iter().map(|v| v * v).sum();
            if rs_next.sqrt() < threshold {
                return step + 1;
            }
            let beta = rs_next / rs_prev;
            for i in 0..n {
                dir[i] = resid[i] + beta * dir[i];
            }
            rs_prev = rs_next;
        }
        max_iter
    }

    /// Quality gate: CG on A = [[4, 1], [1, 3]], b = [1, 2].
    ///
    /// det(A) = 11 and A^{-1} = (1/11) [[3, -1], [-1, 4]], so the exact
    /// solution is x = [1/11, 7/11]. In exact arithmetic CG finishes in
    /// at most n = 2 steps; the gate allows 5 to absorb rounding.
    #[test]
    fn test_cg_convergence_spd_2x2() {
        // Symmetric positive definite 2×2 matrix.
        let a = [[4.0_f64, 1.0], [1.0, 3.0]];
        let apply = |v: &[f64], out: &mut [f64]| {
            out[0] = a[0][0] * v[0] + a[0][1] * v[1];
            out[1] = a[1][0] * v[0] + a[1][1] * v[1];
        };

        let b = [1.0_f64, 2.0];
        let mut x = [0.0_f64, 0.0]; // zero initial guess

        let iters = cpu_cg_f64(apply, &b, &mut x, 2, 100, 1e-12);

        assert!(
            iters <= 5,
            "CG on 2×2 SPD system must converge in ≤ 5 iterations, took {iters}"
        );

        // Verify solution matches x = [1/11, 7/11].
        let x_exact = [1.0_f64 / 11.0, 7.0 / 11.0];
        assert!(
            (x[0] - x_exact[0]).abs() < 1e-10,
            "CG 2×2: x[0]={} expected {}",
            x[0],
            x_exact[0],
        );
        assert!(
            (x[1] - x_exact[1]).abs() < 1e-10,
            "CG 2×2: x[1]={} expected {}",
            x[1],
            x_exact[1],
        );
    }

    /// Quality gate: CG on a 5×5 diagonal SPD system.
    ///
    /// With D = diag(1..=5) and b = [1, 2, 3, 4, 5] the exact solution is
    /// the all-ones vector.
    #[test]
    fn test_cg_convergence_diagonal_5x5() {
        let diag = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let apply = |v: &[f64], out: &mut [f64]| {
            for (i, d) in diag.iter().enumerate() {
                out[i] = d * v[i];
            }
        };
        let b = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let mut x = [0.0_f64; 5];

        let iters = cpu_cg_f64(apply, &b, &mut x, 5, 100, 1e-12);

        assert!(
            iters <= 10,
            "CG on 5×5 diagonal SPD must converge in ≤ 10 iterations, took {iters}"
        );

        for (i, &xi) in x.iter().enumerate() {
            assert!(
                (xi - 1.0).abs() < 1e-10,
                "CG diagonal 5×5: x[{i}]={xi} expected 1.0",
            );
        }
    }
}