#![allow(dead_code)]
use oxicuda_blas::GpuFloat;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Widens a `GpuFloat` value to `f64` by reinterpreting its raw bits.
///
/// For 4-byte types the payload sits in the low 32 bits of
/// `to_bits_u64()` and is decoded as an `f32` before widening; for
/// 8-byte types the bits are decoded directly as an `f64`.
fn to_f64<T: GpuFloat>(val: T) -> f64 {
    let bits = val.to_bits_u64();
    match T::SIZE {
        4 => f64::from(f32::from_bits(bits as u32)),
        _ => f64::from_bits(bits),
    }
}
/// Narrows an `f64` into a `GpuFloat` value via its raw bit pattern.
///
/// Inverse of [`to_f64`]: 4-byte targets round the value to `f32`
/// first (losing precision as expected), 8-byte targets keep the
/// `f64` bits unchanged.
fn from_f64<T: GpuFloat>(val: f64) -> T {
    let bits = match T::SIZE {
        4 => u64::from((val as f32).to_bits()),
        _ => val.to_bits(),
    };
    T::from_bits_u64(bits)
}
/// Configuration for the Conjugate Gradient solver ([`cg_solve`]).
#[derive(Debug, Clone)]
pub struct CgConfig {
    /// Maximum number of iterations before reporting a convergence failure.
    pub max_iter: u32,
    /// Relative convergence tolerance: iteration stops once
    /// `||r||_2 < tol * ||b||_2`.
    pub tol: f64,
}
impl Default for CgConfig {
fn default() -> Self {
Self {
max_iter: 1000,
tol: 1e-6,
}
}
}
/// Solves `A * x = b` with the (unpreconditioned) Conjugate Gradient method.
///
/// `spmv` must compute `y = A * x` for a symmetric positive-definite `A`;
/// it is always called with input and output slices of exactly `n`
/// elements. On entry `x` holds the initial guess, on successful exit the
/// solution. Convergence is declared when `||r||_2 < tol * ||b||_2`
/// (relative to the right-hand side). If `b` is the zero vector, `x` is
/// set to zero and the call returns immediately.
///
/// Returns the number of iterations performed.
///
/// # Errors
/// * `SolverError::DimensionMismatch` if `b` or `x` is shorter than `n`.
/// * `SolverError::InternalError` if `p^T A p` is near zero (suggesting
///   `A` is not SPD).
/// * `SolverError::ConvergenceFailure` if `max_iter` iterations pass
///   without reaching the tolerance.
pub fn cg_solve<T, F>(
    _handle: &SolverHandle,
    spmv: F,
    b: &[T],
    x: &mut [T],
    n: u32,
    config: &CgConfig,
) -> SolverResult<u32>
where
    T: GpuFloat,
    F: Fn(&[T], &mut [T]) -> SolverResult<()>,
{
    let n_usize = n as usize;
    if b.len() < n_usize {
        return Err(SolverError::DimensionMismatch(format!(
            "cg_solve: b length ({}) < n ({n})",
            b.len()
        )));
    }
    if x.len() < n_usize {
        return Err(SolverError::DimensionMismatch(format!(
            "cg_solve: x length ({}) < n ({n})",
            x.len()
        )));
    }
    if n == 0 {
        return Ok(0);
    }
    let b_norm = vec_norm(b, n_usize);
    let abs_tol = if b_norm > 0.0 {
        config.tol * b_norm
    } else {
        // b == 0: the unique solution of an SPD system is x = 0.
        for xi in x.iter_mut().take(n_usize) {
            *xi = T::gpu_zero();
        }
        return Ok(0);
    };
    let mut r = vec![T::gpu_zero(); n_usize];
    let mut ap = vec![T::gpu_zero(); n_usize];
    // r = b - A * x0. Slice to exactly n elements so spmv always
    // receives same-length slices (b and x are allowed to be longer
    // than n, but p/ap below are exactly n long).
    spmv(&x[..n_usize], &mut ap)?;
    for i in 0..n_usize {
        r[i] = sub_t(b[i], ap[i]);
    }
    let mut p = r.clone();
    let mut rsold = dot_product(&r, &r, n_usize);
    // Initial guess is already within tolerance.
    if rsold.sqrt() < abs_tol {
        return Ok(0);
    }
    for iter in 0..config.max_iter {
        spmv(&p, &mut ap)?;
        let pap = dot_product(&p, &ap, n_usize);
        // For SPD A and p != 0, p^T A p > 0. A vanishing value would
        // make alpha blow up, so fail loudly instead.
        if pap.abs() < 1e-300 {
            return Err(SolverError::InternalError(
                "cg_solve: p^T * A * p is near zero (A may not be SPD)".into(),
            ));
        }
        let alpha = rsold / pap;
        let alpha_t = from_f64(alpha);
        // x += alpha * p
        for i in 0..n_usize {
            x[i] = add_t(x[i], mul_t(alpha_t, p[i]));
        }
        // r -= alpha * A * p
        for i in 0..n_usize {
            r[i] = sub_t(r[i], mul_t(alpha_t, ap[i]));
        }
        let rsnew = dot_product(&r, &r, n_usize);
        if rsnew.sqrt() < abs_tol {
            return Ok(iter + 1);
        }
        // p = r + beta * p
        let beta = rsnew / rsold;
        let beta_t = from_f64(beta);
        for i in 0..n_usize {
            p[i] = add_t(r[i], mul_t(beta_t, p[i]));
        }
        rsold = rsnew;
    }
    Err(SolverError::ConvergenceFailure {
        iterations: config.max_iter,
        residual: rsold.sqrt(),
    })
}
/// Dot product of the first `n` entries of `a` and `b`, accumulated in
/// `f64` for precision regardless of `T`. Panics if either slice is
/// shorter than `n`.
fn dot_product<T: GpuFloat>(a: &[T], b: &[T], n: usize) -> f64 {
    a[..n]
        .iter()
        .zip(&b[..n])
        .map(|(&x, &y)| to_f64(x) * to_f64(y))
        .sum()
}
/// Euclidean (L2) norm of the first `n` entries of `v`.
fn vec_norm<T: GpuFloat>(v: &[T], n: usize) -> f64 {
    let sum_sq = dot_product(v, v, n);
    sum_sq.sqrt()
}
/// `a + b`, computed by round-tripping both operands through `f64`.
fn add_t<T: GpuFloat>(a: T, b: T) -> T {
    let sum = to_f64(a) + to_f64(b);
    from_f64(sum)
}
/// `a - b`, computed by round-tripping both operands through `f64`.
fn sub_t<T: GpuFloat>(a: T, b: T) -> T {
    let diff = to_f64(a) - to_f64(b);
    from_f64(diff)
}
/// `a * b`, computed by round-tripping both operands through `f64`.
fn mul_t<T: GpuFloat>(a: T, b: T) -> T {
    let prod = to_f64(a) * to_f64(b);
    from_f64(prod)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cg_config_default() {
        let cfg = CgConfig::default();
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-15);
    }

    #[test]
    fn dot_product_basic() {
        let lhs = [1.0_f64, 2.0, 3.0];
        let rhs = [4.0_f64, 5.0, 6.0];
        let got = dot_product(&lhs, &rhs, 3);
        assert!((got - 32.0).abs() < 1e-10);
    }

    #[test]
    fn vec_norm_basic() {
        // 3-4-5 triangle.
        let v = [3.0_f64, 4.0];
        assert!((vec_norm(&v, 2) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn add_sub_mul() {
        let (a, b) = (3.0_f64, 4.0_f64);
        assert!((to_f64(add_t(a, b)) - 7.0).abs() < 1e-15);
        assert!((to_f64(sub_t(a, b)) - (-1.0)).abs() < 1e-15);
        assert!((to_f64(mul_t(a, b)) - 12.0).abs() < 1e-15);
    }

    #[test]
    fn cg_config_custom() {
        let cfg = CgConfig {
            max_iter: 500,
            tol: 1e-10,
        };
        assert_eq!(cfg.max_iter, 500);
        assert!((cfg.tol - 1e-10).abs() < 1e-20);
    }

    /// Plain-`f64` reference implementation of CG used to validate the
    /// algorithm's convergence behavior; returns the iteration count
    /// (or `max_iter` if it never converged).
    fn cpu_cg_f64(
        spmv: impl Fn(&[f64], &mut [f64]),
        b: &[f64],
        x: &mut [f64],
        n: usize,
        max_iter: usize,
        tol: f64,
    ) -> usize {
        let b_norm = b.iter().map(|v| v * v).sum::<f64>().sqrt();
        let abs_tol = tol * b_norm;
        let mut ap = vec![0.0_f64; n];
        // r = b - A * x0
        spmv(x, &mut ap);
        let mut r: Vec<f64> = b.iter().zip(&ap).map(|(bi, api)| bi - api).collect();
        let mut p = r.clone();
        let mut rsold: f64 = r.iter().map(|v| v * v).sum();
        for iter in 0..max_iter {
            spmv(&p, &mut ap);
            let pap: f64 = p.iter().zip(&ap).map(|(pi, api)| pi * api).sum();
            if pap.abs() < 1e-300 {
                return iter;
            }
            let alpha = rsold / pap;
            for i in 0..n {
                x[i] += alpha * p[i];
                r[i] -= alpha * ap[i];
            }
            let rsnew: f64 = r.iter().map(|v| v * v).sum();
            if rsnew.sqrt() < abs_tol {
                return iter + 1;
            }
            let beta = rsnew / rsold;
            for i in 0..n {
                p[i] = r[i] + beta * p[i];
            }
            rsold = rsnew;
        }
        max_iter
    }

    #[test]
    fn test_cg_convergence_spd_2x2() {
        let a = [[4.0_f64, 1.0], [1.0, 3.0]];
        let spmv = |x: &[f64], y: &mut [f64]| {
            y[0] = a[0][0] * x[0] + a[0][1] * x[1];
            y[1] = a[1][0] * x[0] + a[1][1] * x[1];
        };
        let b = [1.0_f64, 2.0];
        let mut x = [0.0_f64, 0.0];
        let iters = cpu_cg_f64(spmv, &b, &mut x, 2, 100, 1e-12);
        // CG on an n×n SPD system converges in at most n exact steps;
        // allow some slack for floating-point rounding.
        assert!(
            iters <= 5,
            "CG on 2×2 SPD system must converge in ≤ 5 iterations, took {iters}"
        );
        // Closed-form solution of [[4,1],[1,3]] x = [1,2].
        let x_exact = [1.0_f64 / 11.0, 7.0 / 11.0];
        assert!(
            (x[0] - x_exact[0]).abs() < 1e-10,
            "CG 2×2: x[0]={} expected {}",
            x[0],
            x_exact[0],
        );
        assert!(
            (x[1] - x_exact[1]).abs() < 1e-10,
            "CG 2×2: x[1]={} expected {}",
            x[1],
            x_exact[1],
        );
    }

    #[test]
    fn test_cg_convergence_diagonal_5x5() {
        let diag = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let spmv = |x: &[f64], y: &mut [f64]| {
            for (yi, (di, xi)) in y.iter_mut().zip(diag.iter().zip(x)) {
                *yi = di * xi;
            }
        };
        // b = diag, so the exact solution is the all-ones vector.
        let b = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let mut x = [0.0_f64; 5];
        let iters = cpu_cg_f64(spmv, &b, &mut x, 5, 100, 1e-12);
        assert!(
            iters <= 10,
            "CG on 5×5 diagonal SPD must converge in ≤ 10 iterations, took {iters}"
        );
        for (i, &xi) in x.iter().enumerate() {
            assert!(
                (xi - 1.0).abs() < 1e-10,
                "CG diagonal 5×5: x[{i}]={xi} expected 1.0",
            );
        }
    }
}