oxicuda_solver/sparse/cg.rs
1//! Conjugate Gradient (CG) iterative solver.
2//!
3//! Solves the linear system `A * x = b` where A is symmetric positive definite.
4//! The solver is matrix-free: it only requires a closure that computes the
5//! matrix-vector product `y = A * x`.
6//!
7//! # Algorithm
8//!
9//! The standard Conjugate Gradient algorithm (Hestenes & Stiefel, 1952):
10//! 1. r = b - A*x; p = r; rsold = r^T * r
11//! 2. For each iteration:
12//! a. Ap = A * p
13//! b. alpha = rsold / (p^T * Ap)
14//! c. x += alpha * p
15//! d. r -= alpha * Ap
16//! e. rsnew = r^T * r
17//! f. If sqrt(rsnew) < tol * ||b||: converged
18//! g. p = r + (rsnew / rsold) * p
19//! h. rsold = rsnew
20//!
21//! The solver operates on host-side vectors. For GPU-accelerated sparse
22//! matrix-vector products, the `spmv` closure should internally manage
23//! device memory transfers.
24
25#![allow(dead_code)]
26
27use oxicuda_blas::GpuFloat;
28
29use crate::error::{SolverError, SolverResult};
30use crate::handle::SolverHandle;
31
32// ---------------------------------------------------------------------------
33// GpuFloat <-> f64 conversion helpers
34// ---------------------------------------------------------------------------
35
36/// Converts a `GpuFloat` value to `f64` via bit reinterpretation.
37fn to_f64<T: GpuFloat>(val: T) -> f64 {
38 if T::SIZE == 4 {
39 f32::from_bits(val.to_bits_u64() as u32) as f64
40 } else {
41 f64::from_bits(val.to_bits_u64())
42 }
43}
44
45/// Converts an `f64` value to `T: GpuFloat` via bit reinterpretation.
46fn from_f64<T: GpuFloat>(val: f64) -> T {
47 if T::SIZE == 4 {
48 T::from_bits_u64(u64::from((val as f32).to_bits()))
49 } else {
50 T::from_bits_u64(val.to_bits())
51 }
52}
53
54// ---------------------------------------------------------------------------
55// Configuration
56// ---------------------------------------------------------------------------
57
/// Configuration for the Conjugate Gradient solver.
///
/// Defaults (via [`Default`]) are 1000 iterations and a relative tolerance
/// of `1e-6`.
#[derive(Debug, Clone)]
pub struct CgConfig {
    /// Maximum number of iterations before the solver gives up with
    /// `SolverError::ConvergenceFailure`.
    pub max_iter: u32,
    /// Convergence tolerance, relative to `||b||`: the solver stops when
    /// `||r|| < tol * ||b||`.
    pub tol: f64,
}
66
67impl Default for CgConfig {
68 fn default() -> Self {
69 Self {
70 max_iter: 1000,
71 tol: 1e-6,
72 }
73 }
74}
75
76// ---------------------------------------------------------------------------
77// Public API
78// ---------------------------------------------------------------------------
79
80/// Solves `A * x = b` using the Conjugate Gradient method.
81///
82/// The matrix A is not passed directly. Instead, the caller provides a closure
83/// `spmv` that computes `y = A * x` given `x` and `y` buffers. This enables
84/// use with any sparse format, preconditioner, or matrix-free operator.
85///
86/// On entry, `x` should contain an initial guess (e.g., zeros). On exit, `x`
87/// contains the approximate solution.
88///
89/// # Arguments
90///
91/// * `_handle` — solver handle (reserved for future GPU-accelerated variants).
92/// * `spmv` — closure computing `y = A * x`: `spmv(x, y)`.
93/// * `b` — right-hand side vector (length n).
94/// * `x` — initial guess / solution vector (length n), modified in-place.
95/// * `n` — system dimension.
96/// * `config` — solver configuration (tolerance, max iterations).
97///
98/// # Returns
99///
100/// The number of iterations performed.
101///
102/// # Errors
103///
104/// Returns [`SolverError::ConvergenceFailure`] if the solver does not converge
105/// within `max_iter` iterations.
106/// Returns [`SolverError::DimensionMismatch`] if vector lengths are invalid.
107pub fn cg_solve<T, F>(
108 _handle: &SolverHandle,
109 spmv: F,
110 b: &[T],
111 x: &mut [T],
112 n: u32,
113 config: &CgConfig,
114) -> SolverResult<u32>
115where
116 T: GpuFloat,
117 F: Fn(&[T], &mut [T]) -> SolverResult<()>,
118{
119 let n_usize = n as usize;
120
121 // Validate dimensions.
122 if b.len() < n_usize {
123 return Err(SolverError::DimensionMismatch(format!(
124 "cg_solve: b length ({}) < n ({n})",
125 b.len()
126 )));
127 }
128 if x.len() < n_usize {
129 return Err(SolverError::DimensionMismatch(format!(
130 "cg_solve: x length ({}) < n ({n})",
131 x.len()
132 )));
133 }
134 if n == 0 {
135 return Ok(0);
136 }
137
138 // Compute ||b|| for relative convergence check.
139 let b_norm = vec_norm(b, n_usize);
140 let abs_tol = if b_norm > 0.0 {
141 config.tol * b_norm
142 } else {
143 // b = 0 => x = 0 is the exact solution.
144 for xi in x.iter_mut().take(n_usize) {
145 *xi = T::gpu_zero();
146 }
147 return Ok(0);
148 };
149
150 // r = b - A*x
151 let mut r = vec![T::gpu_zero(); n_usize];
152 let mut ap = vec![T::gpu_zero(); n_usize];
153 spmv(x, &mut ap)?;
154 for i in 0..n_usize {
155 r[i] = sub_t(b[i], ap[i]);
156 }
157
158 // p = r.clone()
159 let mut p = r.clone();
160
161 // rsold = r^T * r
162 let mut rsold = dot_product(&r, &r, n_usize);
163
164 if rsold.sqrt() < abs_tol {
165 return Ok(0);
166 }
167
168 for iter in 0..config.max_iter {
169 // Ap = A * p
170 spmv(&p, &mut ap)?;
171
172 // alpha = rsold / (p^T * Ap)
173 let pap = dot_product(&p, &ap, n_usize);
174 if pap.abs() < 1e-300 {
175 return Err(SolverError::InternalError(
176 "cg_solve: p^T * A * p is near zero (A may not be SPD)".into(),
177 ));
178 }
179 let alpha = rsold / pap;
180 let alpha_t = from_f64(alpha);
181
182 // x += alpha * p
183 for i in 0..n_usize {
184 x[i] = add_t(x[i], mul_t(alpha_t, p[i]));
185 }
186
187 // r -= alpha * Ap
188 for i in 0..n_usize {
189 r[i] = sub_t(r[i], mul_t(alpha_t, ap[i]));
190 }
191
192 // rsnew = r^T * r
193 let rsnew = dot_product(&r, &r, n_usize);
194
195 // Check convergence.
196 if rsnew.sqrt() < abs_tol {
197 return Ok(iter + 1);
198 }
199
200 // beta = rsnew / rsold
201 let beta = rsnew / rsold;
202 let beta_t = from_f64(beta);
203
204 // p = r + beta * p
205 for i in 0..n_usize {
206 p[i] = add_t(r[i], mul_t(beta_t, p[i]));
207 }
208
209 rsold = rsnew;
210 }
211
212 Err(SolverError::ConvergenceFailure {
213 iterations: config.max_iter,
214 residual: rsold.sqrt(),
215 })
216}
217
218// ---------------------------------------------------------------------------
219// Vector arithmetic helpers (host-side, generic over GpuFloat)
220// ---------------------------------------------------------------------------
221
222/// Computes the dot product of two vectors as f64.
223fn dot_product<T: GpuFloat>(a: &[T], b: &[T], n: usize) -> f64 {
224 let mut sum = 0.0_f64;
225 for i in 0..n {
226 sum += to_f64(a[i]) * to_f64(b[i]);
227 }
228 sum
229}
230
231/// Computes the 2-norm of a vector as f64.
232fn vec_norm<T: GpuFloat>(v: &[T], n: usize) -> f64 {
233 dot_product(v, v, n).sqrt()
234}
235
236/// Adds two GpuFloat values.
237fn add_t<T: GpuFloat>(a: T, b: T) -> T {
238 from_f64(to_f64(a) + to_f64(b))
239}
240
241/// Subtracts two GpuFloat values.
242fn sub_t<T: GpuFloat>(a: T, b: T) -> T {
243 from_f64(to_f64(a) - to_f64(b))
244}
245
246/// Multiplies two GpuFloat values.
247fn mul_t<T: GpuFloat>(a: T, b: T) -> T {
248 from_f64(to_f64(a) * to_f64(b))
249}
250
251// ---------------------------------------------------------------------------
252// Tests
253// ---------------------------------------------------------------------------
254
#[cfg(test)]
mod tests {
    use super::*;

    // Default config must match the documented values (1000 iters, tol 1e-6).
    #[test]
    fn cg_config_default() {
        let cfg = CgConfig::default();
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-15);
    }

    // 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn dot_product_basic() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0, 6.0];
        let result = dot_product(&a, &b, 3);
        assert!((result - 32.0).abs() < 1e-10);
    }

    // Classic 3-4-5 triangle: ||(3, 4)|| = 5.
    #[test]
    fn vec_norm_basic() {
        let v = [3.0_f64, 4.0];
        let result = vec_norm(&v, 2);
        assert!((result - 5.0).abs() < 1e-10);
    }

    // The f64 round-trip helpers must be exact for small integer values.
    #[test]
    fn add_sub_mul() {
        let a = 3.0_f64;
        let b = 4.0_f64;
        assert!((to_f64(add_t(a, b)) - 7.0).abs() < 1e-15);
        assert!((to_f64(sub_t(a, b)) - (-1.0)).abs() < 1e-15);
        assert!((to_f64(mul_t(a, b)) - 12.0).abs() < 1e-15);
    }

    // Custom config fields round-trip unchanged.
    #[test]
    fn cg_config_custom() {
        let cfg = CgConfig {
            max_iter: 500,
            tol: 1e-10,
        };
        assert_eq!(cfg.max_iter, 500);
        assert!((cfg.tol - 1e-10).abs() < 1e-20);
    }

    // -----------------------------------------------------------------------
    // Quality gate: CG convergence on a 2×2 SPD system (CPU simulation)
    // -----------------------------------------------------------------------

    /// CPU-only conjugate gradient implementation for testing purposes.
    ///
    /// Solves A * x = b without requiring a `SolverHandle` (GPU context).
    /// This isolates the algorithmic correctness from the GPU infrastructure.
    /// Mirrors `cg_solve` step for step, but works directly on `f64` slices.
    /// Returns the number of iterations performed (`max_iter` on failure).
    fn cpu_cg_f64(
        spmv: impl Fn(&[f64], &mut [f64]),
        b: &[f64],
        x: &mut [f64],
        n: usize,
        max_iter: usize,
        tol: f64,
    ) -> usize {
        let b_norm = b.iter().map(|v| v * v).sum::<f64>().sqrt();
        let abs_tol = tol * b_norm;

        // r = b - A*x; p = r; rsold = r^T * r
        let mut ap = vec![0.0_f64; n];
        spmv(x, &mut ap);
        let mut r: Vec<f64> = (0..n).map(|i| b[i] - ap[i]).collect();
        let mut p = r.clone();
        let mut rsold: f64 = r.iter().map(|v| v * v).sum();

        for iter in 0..max_iter {
            spmv(&p, &mut ap);
            let pap: f64 = p.iter().zip(&ap).map(|(pi, api)| pi * api).sum();
            // Breakdown guard: bail out early rather than divide by ~zero.
            if pap.abs() < 1e-300 {
                return iter;
            }
            let alpha = rsold / pap;
            for i in 0..n {
                x[i] += alpha * p[i];
                r[i] -= alpha * ap[i];
            }
            let rsnew: f64 = r.iter().map(|v| v * v).sum();
            if rsnew.sqrt() < abs_tol {
                return iter + 1;
            }
            let beta = rsnew / rsold;
            for i in 0..n {
                p[i] = r[i] + beta * p[i];
            }
            rsold = rsnew;
        }
        max_iter
    }

    /// Quality gate: CG convergence on A = [[4, 1], [1, 3]], b = [1, 2].
    ///
    /// Exact solution: x = A^{-1} b
    ///     det(A) = 4*3 - 1*1 = 11
    ///     A^{-1} = (1/11) * [[3, -1], [-1, 4]]
    ///     x = (1/11) * [3*1 + (-1)*2, (-1)*1 + 4*2] = [1/11, 7/11]
    ///
    /// CG must converge in ≤ 5 iterations (at most n=2 for exact arithmetic).
    #[test]
    fn test_cg_convergence_spd_2x2() {
        // A = [[4, 1], [1, 3]] — symmetric positive definite (eigenvalues 3.27, 3.73)
        let a = [[4.0_f64, 1.0], [1.0, 3.0]];
        let spmv = |x: &[f64], y: &mut [f64]| {
            y[0] = a[0][0] * x[0] + a[0][1] * x[1];
            y[1] = a[1][0] * x[0] + a[1][1] * x[1];
        };

        let b = [1.0_f64, 2.0];
        let mut x = [0.0_f64, 0.0]; // zero initial guess

        let iters = cpu_cg_f64(spmv, &b, &mut x, 2, 100, 1e-12);

        // CG on an n×n SPD system converges in at most n steps in exact arithmetic.
        assert!(
            iters <= 5,
            "CG on 2×2 SPD system must converge in ≤ 5 iterations, took {iters}"
        );

        // Verify solution matches x = [1/11, 7/11]
        let x_exact = [1.0_f64 / 11.0, 7.0 / 11.0];
        assert!(
            (x[0] - x_exact[0]).abs() < 1e-10,
            "CG 2×2: x[0]={} expected {}",
            x[0],
            x_exact[0],
        );
        assert!(
            (x[1] - x_exact[1]).abs() < 1e-10,
            "CG 2×2: x[1]={} expected {}",
            x[1],
            x_exact[1],
        );
    }

    /// Quality gate: CG convergence on a 5×5 diagonal SPD system.
    ///
    /// For D = diag(1, 2, 3, 4, 5) and b = [1, 2, 3, 4, 5],
    /// the exact solution is x = [1, 1, 1, 1, 1].
    /// CG must converge in ≤ 10 iterations.
    #[test]
    fn test_cg_convergence_diagonal_5x5() {
        let diag = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let spmv = |x: &[f64], y: &mut [f64]| {
            for i in 0..5 {
                y[i] = diag[i] * x[i];
            }
        };
        let b = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let mut x = [0.0_f64; 5];

        let iters = cpu_cg_f64(spmv, &b, &mut x, 5, 100, 1e-12);

        assert!(
            iters <= 10,
            "CG on 5×5 diagonal SPD must converge in ≤ 10 iterations, took {iters}"
        );

        for (i, &xi) in x.iter().enumerate() {
            assert!(
                (xi - 1.0).abs() < 1e-10,
                "CG diagonal 5×5: x[{i}]={xi} expected 1.0",
            );
        }
    }
}
423}