use crate::error::{SparseError, SparseResult};
use crate::gpu::construction::GpuCsrMatrix;
use crate::gpu::spmv::GpuSpMvBackend;
/// Execution backend selection for the iterative solvers.
///
/// Converted into a [`GpuSpMvBackend`] via the `From` impl below.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GpuSolverBackend {
    /// CPU execution (the default).
    #[default]
    Cpu,
    /// WebGPU execution (maps to `GpuSpMvBackend::WebGpu`).
    WebGpu,
}
/// Map a solver backend choice onto the equivalent SpMV backend.
impl From<GpuSolverBackend> for GpuSpMvBackend {
    fn from(b: GpuSolverBackend) -> Self {
        match b {
            GpuSolverBackend::Cpu => GpuSpMvBackend::Cpu,
            GpuSolverBackend::WebGpu => GpuSpMvBackend::WebGpu,
        }
    }
}
/// Configuration shared by the iterative solvers in this module.
#[derive(Debug, Clone)]
pub struct GpuSolverConfig {
    /// Maximum number of iterations before giving up.
    pub max_iter: usize,
    /// Convergence tolerance, relative to `||b||_2` (absolute when `b` is zero).
    pub tol: f64,
    /// Enable Jacobi (diagonal) preconditioning.
    pub precond: bool,
    /// Requested execution backend.
    /// NOTE(review): not consulted by `cg_csr`/`bicgstab_csr` in this file —
    /// presumably used by the SpMV layer; confirm against `matrix.spmv`.
    pub backend: GpuSolverBackend,
}
impl Default for GpuSolverConfig {
    /// Defaults: 1000 iterations, relative tolerance 1e-8, Jacobi
    /// preconditioning enabled, CPU backend.
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tol: 1e-8,
            precond: true,
            backend: GpuSolverBackend::Cpu,
        }
    }
}
/// Outcome of an iterative solve.
#[derive(Debug, Clone)]
pub struct SolverResult {
    /// Final solution iterate.
    pub x: Vec<f64>,
    /// Euclidean norm of the true residual `b - A x` at exit
    /// (recomputed from `x`, not the recursively updated residual).
    pub residual_norm: f64,
    /// Number of iterations performed.
    pub n_iter: usize,
    /// Whether the tolerance was met within `max_iter` iterations.
    pub converged: bool,
}
/// Inner product of two slices; iterates over the shorter length.
#[inline]
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).fold(0.0, |acc, (&x, &y)| acc + x * y)
}
/// Euclidean (L2) norm of a vector.
#[inline]
fn norm2(v: &[f64]) -> f64 {
    let sum_sq: f64 = v.iter().map(|&x| x * x).sum();
    sum_sq.sqrt()
}
/// In-place update `y += alpha * x`, over the shorter of the two lengths.
#[inline]
fn axpy(alpha: f64, x: &[f64], y: &mut [f64]) {
    y.iter_mut()
        .zip(x)
        .for_each(|(yi, &xi)| *yi += alpha * xi);
}
/// Element-wise combination `z = alpha * x + beta * y`, truncated to the
/// shortest of the three slices.
#[inline]
fn axpby(alpha: f64, x: &[f64], beta: f64, y: &[f64], z: &mut [f64]) {
    z.iter_mut()
        .zip(x.iter().zip(y))
        .for_each(|(zi, (&xi, &yi))| *zi = alpha * xi + beta * yi);
}
/// Extract the matrix diagonal for use as a Jacobi preconditioner.
///
/// Rows with no stored diagonal entry — or whose diagonal is numerically
/// zero (|d| <= EPSILON) — fall back to 1.0 so the preconditioner remains
/// invertible. If a row stores several diagonal entries, the last nonzero
/// one wins, matching a full scan of the row.
fn jacobi_diag(matrix: &GpuCsrMatrix) -> Vec<f64> {
    (0..matrix.n_rows)
        .map(|row| {
            let mut d = 1.0_f64;
            for k in matrix.row_ptr[row]..matrix.row_ptr[row + 1] {
                if matrix.col_idx[k] == row && matrix.values[k].abs() > f64::EPSILON {
                    d = matrix.values[k];
                }
            }
            d
        })
        .collect()
}
/// Apply the Jacobi preconditioner: `z = D^{-1} r`, element-wise division.
fn apply_jacobi(diag: &[f64], r: &[f64], z: &mut [f64]) {
    z.iter_mut()
        .zip(r.iter().zip(diag))
        .for_each(|(zi, (&ri, &di))| *zi = ri / di);
}
/// Preconditioned Conjugate Gradient solver for a square CSR system `A x = b`.
///
/// Uses Jacobi (diagonal) preconditioning when `config.precond` is set.
/// Convergence is declared when `||r||_2 <= tol * ||b||_2` (absolute `tol`
/// when `b` is the zero vector). CG assumes `A` is symmetric positive
/// definite; that property is not checked here.
///
/// # Errors
/// * `SparseError::ComputationError` if the matrix is not square.
/// * `SparseError::DimensionMismatch` if `b` or `x0` has length != n.
pub fn cg_csr(
    matrix: &GpuCsrMatrix,
    b: &[f64],
    x0: Option<&[f64]>,
    config: &GpuSolverConfig,
) -> SparseResult<SolverResult> {
    let n = matrix.n_rows;
    if matrix.n_cols != n {
        return Err(SparseError::ComputationError(
            "CG requires a square matrix".to_string(),
        ));
    }
    if b.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: b.len(),
        });
    }
    if let Some(x) = x0 {
        if x.len() != n {
            return Err(SparseError::DimensionMismatch {
                expected: n,
                found: x.len(),
            });
        }
    }
    // Jacobi preconditioner M = diag(A); identity when disabled.
    let diag = if config.precond {
        jacobi_diag(matrix)
    } else {
        vec![1.0; n]
    };
    let mut x = match x0 {
        Some(x0) => x0.to_vec(),
        None => vec![0.0; n],
    };
    // Initial residual r = b - A x.
    let ax = matrix.spmv(&x)?;
    let mut r: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, &ai)| bi - ai).collect();
    let mut z = vec![0.0; n];
    apply_jacobi(&diag, &r, &mut z);
    let mut p = z.clone();
    let mut rz = dot(&r, &z);
    let b_norm = norm2(b);
    let tol_abs = if b_norm > 0.0 {
        config.tol * b_norm
    } else {
        config.tol
    };
    let mut iter = 0usize;
    // Bug fix: if the initial guess already satisfies the tolerance, report
    // convergence immediately. Previously an exact initial solution produced
    // pap == 0, broke out of the loop, and was misreported as not converged.
    let mut converged = norm2(&r) <= tol_abs;
    while !converged && iter < config.max_iter {
        let ap = matrix.spmv(&p)?;
        let pap = dot(&p, &ap);
        // Breakdown guard: direction numerically in the null space of A.
        if pap.abs() < f64::MIN_POSITIVE {
            break;
        }
        let alpha = rz / pap;
        axpy(alpha, &p, &mut x);
        axpy(-alpha, &ap, &mut r);
        iter += 1;
        if norm2(&r) <= tol_abs {
            converged = true;
            break;
        }
        apply_jacobi(&diag, &r, &mut z);
        let rz_new = dot(&r, &z);
        let beta = rz_new / rz;
        rz = rz_new;
        // p = z + beta * p, updated in place. (Previously this cloned p on
        // every iteration — an O(n) allocation per step — just to feed axpby.)
        for (pi, &zi) in p.iter_mut().zip(z.iter()) {
            *pi = zi + beta * *pi;
        }
    }
    // Report the true residual of the final iterate, not the recursive one.
    let residual = matrix.spmv(&x)?;
    let res_norm = norm2(
        &b.iter()
            .zip(residual.iter())
            .map(|(bi, &ri)| bi - ri)
            .collect::<Vec<_>>(),
    );
    Ok(SolverResult {
        x,
        residual_norm: res_norm,
        n_iter: iter,
        converged,
    })
}
/// Preconditioned BiCGSTAB solver for a square CSR system `A x = b`.
///
/// Suitable for general (nonsymmetric) matrices. Uses Jacobi (diagonal)
/// preconditioning when `config.precond` is set. Convergence is declared
/// when `||r||_2 <= tol * ||b||_2` (absolute `tol` when `b` is zero).
///
/// # Errors
/// * `SparseError::ComputationError` if the matrix is not square.
/// * `SparseError::DimensionMismatch` if `b` or `x0` has length != n.
pub fn bicgstab_csr(
    matrix: &GpuCsrMatrix,
    b: &[f64],
    x0: Option<&[f64]>,
    config: &GpuSolverConfig,
) -> SparseResult<SolverResult> {
    let n = matrix.n_rows;
    if matrix.n_cols != n {
        return Err(SparseError::ComputationError(
            "BiCGSTAB requires a square matrix".to_string(),
        ));
    }
    if b.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: b.len(),
        });
    }
    // Bug fix: validate the initial guess length (cg_csr already does); a
    // mismatched x0 previously flowed straight into spmv.
    if let Some(x) = x0 {
        if x.len() != n {
            return Err(SparseError::DimensionMismatch {
                expected: n,
                found: x.len(),
            });
        }
    }
    // Jacobi preconditioner M = diag(A); identity when disabled.
    let diag = if config.precond {
        jacobi_diag(matrix)
    } else {
        vec![1.0; n]
    };
    let mut x = match x0 {
        Some(x0) => x0.to_vec(),
        None => vec![0.0; n],
    };
    // Initial residual r = b - A x; r_hat is the fixed shadow residual.
    let ax0 = matrix.spmv(&x)?;
    let mut r: Vec<f64> = b.iter().zip(ax0.iter()).map(|(bi, &ai)| bi - ai).collect();
    let r_hat = r.clone();
    let b_norm = norm2(b);
    let tol_abs = if b_norm > 0.0 {
        config.tol * b_norm
    } else {
        config.tol
    };
    let mut p = r.clone();
    let mut rho = dot(&r_hat, &r);
    let mut p_hat = vec![0.0; n];
    let mut s_hat = vec![0.0; n];
    let mut iter = 0usize;
    // Bug fix: report convergence immediately when the initial guess already
    // satisfies the tolerance (previously rtv == 0 broke the loop and the
    // exact solution was misreported as not converged).
    let mut converged = norm2(&r) <= tol_abs;
    while !converged && iter < config.max_iter {
        apply_jacobi(&diag, &p, &mut p_hat);
        let v = matrix.spmv(&p_hat)?;
        let rtv = dot(&r_hat, &v);
        // Breakdown guard: shadow residual orthogonal to v.
        if rtv.abs() < f64::MIN_POSITIVE {
            break;
        }
        let alpha = rho / rtv;
        // Half-step residual s = r - alpha * v.
        let s: Vec<f64> = r
            .iter()
            .zip(v.iter())
            .map(|(&ri, &vi)| ri - alpha * vi)
            .collect();
        if norm2(&s) <= tol_abs {
            // Converged at the half step: apply the alpha update and stop.
            axpy(alpha, &p_hat, &mut x);
            iter += 1;
            converged = true;
            break;
        }
        apply_jacobi(&diag, &s, &mut s_hat);
        let t = matrix.spmv(&s_hat)?;
        let tt = dot(&t, &t);
        // Breakdown guard: t numerically zero, omega undefined.
        if tt <= f64::MIN_POSITIVE {
            break;
        }
        // omega is scoped to the iteration; this removes the dead initial
        // `omega = 1.0` assignment and its #[allow(unused_assignments)].
        let omega = dot(&t, &s) / tt;
        axpy(alpha, &p_hat, &mut x);
        axpy(omega, &s_hat, &mut x);
        for ((ri, &si), &ti) in r.iter_mut().zip(s.iter()).zip(t.iter()) {
            *ri = si - omega * ti;
        }
        iter += 1;
        if norm2(&r) <= tol_abs {
            converged = true;
            break;
        }
        let rho_new = dot(&r_hat, &r);
        // Breakdown guard: rho == 0 would poison beta.
        if rho_new.abs() < f64::MIN_POSITIVE {
            break;
        }
        let beta = (rho_new / rho) * (alpha / omega);
        rho = rho_new;
        // p = r + beta * (p - omega * v), updated in place.
        for ((pi, &ri), &vi) in p.iter_mut().zip(r.iter()).zip(v.iter()) {
            *pi = ri + beta * (*pi - omega * vi);
        }
    }
    // Report the true residual of the final iterate.
    let residual = matrix.spmv(&x)?;
    let res_norm = norm2(
        &b.iter()
            .zip(residual.iter())
            .map(|(bi, &ri)| bi - ri)
            .collect::<Vec<_>>(),
    );
    Ok(SolverResult {
        x,
        residual_norm: res_norm,
        n_iter: iter,
        converged,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu::construction::{GpuCooMatrix, GpuCsrMatrix};

    /// Symmetric positive-definite tridiagonal test matrix:
    /// 4 on the main diagonal, -1 on both off-diagonals.
    fn tridiag_spd(n: usize) -> GpuCsrMatrix {
        let mut builder = GpuCooMatrix::new(n, n);
        for row in 0..n {
            builder.push(row, row, 4.0);
            if row > 0 {
                builder.push(row, row - 1, -1.0);
                builder.push(row - 1, row, -1.0);
            }
        }
        builder.to_csr()
    }

    #[test]
    fn test_cg_spd_system() {
        // Manufacture b from a known solution and recover it with CG.
        let mat = tridiag_spd(5);
        let x_true = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = mat.spmv(&x_true).expect("spmv failed");
        let result = cg_csr(&mat, &b, None, &GpuSolverConfig::default()).expect("CG failed");
        assert!(result.converged, "CG did not converge");
        assert!(result.residual_norm < 1e-6);
        for (xi, &xt) in result.x.iter().zip(x_true.iter()) {
            assert!((xi - xt).abs() < 1e-6, "x[i]={xi} expected {xt}");
        }
    }

    #[test]
    fn test_bicgstab_general() {
        // Nonsymmetric, diagonally dominant 4x4 system.
        let entries = [
            (0, 0, 4.0),
            (0, 1, 1.0),
            (1, 0, 2.0),
            (1, 1, 5.0),
            (1, 2, 1.0),
            (2, 1, 2.0),
            (2, 2, 6.0),
            (2, 3, 1.0),
            (3, 2, 2.0),
            (3, 3, 7.0),
        ];
        let mut coo = GpuCooMatrix::new(4, 4);
        for &(row, col, val) in &entries {
            coo.push(row, col, val);
        }
        let mat = coo.to_csr();
        let x_true = vec![1.0, 2.0, 3.0, 4.0];
        let b = mat.spmv(&x_true).expect("spmv failed");
        let result =
            bicgstab_csr(&mat, &b, None, &GpuSolverConfig::default()).expect("BiCGSTAB failed");
        assert!(result.converged, "BiCGSTAB did not converge");
        assert!(result.residual_norm < 1e-6);
    }

    #[test]
    fn test_cg_with_precond() {
        // Jacobi preconditioning should not cost many extra iterations.
        let mat = tridiag_spd(10);
        let b = vec![1.0; 10];
        let with_pc = GpuSolverConfig {
            precond: true,
            ..Default::default()
        };
        let without_pc = GpuSolverConfig {
            precond: false,
            ..Default::default()
        };
        let r_precond = cg_csr(&mat, &b, None, &with_pc).expect("CG failed");
        let r_nopc = cg_csr(&mat, &b, None, &without_pc).expect("CG failed");
        assert!(r_precond.converged);
        assert!(r_nopc.converged);
        assert!(r_precond.n_iter <= r_nopc.n_iter + 5);
    }

    #[test]
    fn test_cg_with_initial_guess() {
        // A near-exact warm start should still converge cleanly.
        let mat = tridiag_spd(5);
        let x_true = vec![1.0; 5];
        let b = mat.spmv(&x_true).expect("spmv failed");
        let warm_start = vec![0.9; 5];
        let result =
            cg_csr(&mat, &b, Some(&warm_start), &GpuSolverConfig::default()).expect("CG failed");
        assert!(result.converged);
    }

    #[test]
    fn test_solver_dimension_mismatch() {
        // Both solvers must reject a right-hand side of the wrong length.
        let mat = tridiag_spd(3);
        let b_wrong = vec![1.0; 4];
        let config = GpuSolverConfig::default();
        assert!(cg_csr(&mat, &b_wrong, None, &config).is_err());
        assert!(bicgstab_csr(&mat, &b_wrong, None, &config).is_err());
    }
}