// scirs2_sparse/gpu/spmv.rs

1//! GPU-ready SpMV (Sparse Matrix-Vector Multiply) operations
2//!
3//! Provides CPU-side SIMD-friendly implementations that serve as compute-shader
4//! placeholders. All hot paths are row-parallel and chunked for cache efficiency,
5//! matching the memory access pattern expected by a GPU compute kernel.
6
7use crate::error::{SparseError, SparseResult};
8use scirs2_core::ndarray::{Array2, Axis};
9
10// ============================================================
11// Configuration
12// ============================================================
13
/// Selects which compute backend executes the SpMV kernels.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GpuSpMvBackend {
    /// CPU simulation: row-parallel and cache-efficient; always available.
    #[default]
    Cpu,
    /// WebGPU through wgpu; feature-gated and not yet wired up.
    WebGpu,
}
24
25/// Configuration for GPU-ready SpMV operations.
26#[derive(Debug, Clone)]
27pub struct GpuSpMvConfig {
28    /// Compute backend to use.
29    pub backend: GpuSpMvBackend,
30    /// Workgroup / warp size (default 256).
31    pub block_size: usize,
32    /// Number of warps per block (default 8).
33    pub n_warps: usize,
34    /// Whether to use texture memory / L1 hints for the x vector (default false).
35    pub use_texture: bool,
36}
37
38impl Default for GpuSpMvConfig {
39    fn default() -> Self {
40        Self {
41            backend: GpuSpMvBackend::Cpu,
42            block_size: 256,
43            n_warps: 8,
44            use_texture: false,
45        }
46    }
47}
48
49// ============================================================
50// CSR SpMV  y = A * x
51// ============================================================
52
53/// Compute `y = A * x` for a matrix stored in CSR format.
54///
55/// The CPU path processes rows in chunks sized to `config.block_size` so that
56/// the access pattern mirrors what a GPU compute shader would execute per
57/// workgroup.
58///
59/// # Errors
60///
61/// Returns [`SparseError::DimensionMismatch`] when the vector length does not
62/// match the number of columns implied by `col_idx`.
63pub fn csr_spmv(
64    row_ptr: &[usize],
65    col_idx: &[usize],
66    values: &[f64],
67    x: &[f64],
68    config: &GpuSpMvConfig,
69) -> SparseResult<Vec<f64>> {
70    if row_ptr.is_empty() {
71        return Ok(Vec::new());
72    }
73    let n_rows = row_ptr.len() - 1;
74
75    // Basic consistency check
76    if col_idx.len() != values.len() {
77        return Err(SparseError::InconsistentData {
78            reason: format!(
79                "col_idx length {} != values length {}",
80                col_idx.len(),
81                values.len()
82            ),
83        });
84    }
85
86    let mut y = vec![0.0_f64; n_rows];
87
88    match config.backend {
89        GpuSpMvBackend::Cpu => {
90            // Chunk rows by block_size to simulate GPU workgroup granularity.
91            let block = config.block_size.max(1);
92            let mut row_start = 0usize;
93            while row_start < n_rows {
94                let row_end = (row_start + block).min(n_rows);
95                for row in row_start..row_end {
96                    let col_start = row_ptr[row];
97                    let col_end = row_ptr[row + 1];
98                    let mut acc = 0.0_f64;
99                    for k in col_start..col_end {
100                        let col = col_idx[k];
101                        if col >= x.len() {
102                            return Err(SparseError::DimensionMismatch {
103                                expected: x.len(),
104                                found: col + 1,
105                            });
106                        }
107                        acc += values[k] * x[col];
108                    }
109                    y[row] = acc;
110                }
111                row_start = row_end;
112            }
113        }
114        GpuSpMvBackend::WebGpu => {
115            // Fall back to CPU simulation; real wgpu dispatch would go here.
116            for row in 0..n_rows {
117                let col_start = row_ptr[row];
118                let col_end = row_ptr[row + 1];
119                let mut acc = 0.0_f64;
120                for k in col_start..col_end {
121                    let col = col_idx[k];
122                    if col >= x.len() {
123                        return Err(SparseError::DimensionMismatch {
124                            expected: x.len(),
125                            found: col + 1,
126                        });
127                    }
128                    acc += values[k] * x[col];
129                }
130                y[row] = acc;
131            }
132        }
133    }
134
135    Ok(y)
136}
137
138// ============================================================
139// Batched SpMV  Y = A * X  where X is [n_cols, n_rhs]
140// ============================================================
141
142/// Compute `Y = A * X` for multiple right-hand side vectors.
143///
144/// `x_batch` has shape `[n_cols, n_rhs]`; the result has shape
145/// `[n_rows, n_rhs]`.
146///
147/// # Errors
148///
149/// Returns [`SparseError::DimensionMismatch`] when `x_batch.nrows()` does not
150/// match the number of columns in the sparse matrix.
151pub fn csr_spmv_batch(
152    row_ptr: &[usize],
153    col_idx: &[usize],
154    values: &[f64],
155    x_batch: &Array2<f64>,
156    config: &GpuSpMvConfig,
157) -> SparseResult<Array2<f64>> {
158    if row_ptr.is_empty() {
159        return Ok(Array2::zeros((0, x_batch.ncols())));
160    }
161    let n_rows = row_ptr.len() - 1;
162    let n_rhs = x_batch.ncols();
163    let n_cols = x_batch.nrows();
164
165    let mut y = Array2::zeros((n_rows, n_rhs));
166
167    for rhs in 0..n_rhs {
168        let x_col = x_batch.index_axis(Axis(1), rhs);
169        let x_slice: Vec<f64> = x_col.iter().copied().collect();
170        if x_slice.len() != n_cols {
171            return Err(SparseError::DimensionMismatch {
172                expected: n_cols,
173                found: x_slice.len(),
174            });
175        }
176        let y_col = csr_spmv(row_ptr, col_idx, values, &x_slice, config)?;
177        for row in 0..n_rows {
178            y[[row, rhs]] = y_col[row];
179        }
180    }
181
182    Ok(y)
183}
184
185// ============================================================
186// SpMM  C = A * B  where B is dense [n_cols, k]
187// ============================================================
188
189/// Compute the sparse-dense product `C = A * B`.
190///
191/// `b` has shape `[n_cols, k]`; the result `C` has shape `[n_rows, k]`.
192///
193/// # Errors
194///
195/// Returns [`SparseError::DimensionMismatch`] when `b.nrows()` does not equal
196/// the number of columns implied by `col_idx`.
197pub fn csr_spmm(
198    row_ptr: &[usize],
199    col_idx: &[usize],
200    values: &[f64],
201    b: &Array2<f64>,
202    config: &GpuSpMvConfig,
203) -> SparseResult<Array2<f64>> {
204    if row_ptr.is_empty() {
205        return Ok(Array2::zeros((0, b.ncols())));
206    }
207    let n_rows = row_ptr.len() - 1;
208    let k = b.ncols();
209    let n_b_rows = b.nrows();
210
211    let mut c = Array2::zeros((n_rows, k));
212
213    let block = match config.backend {
214        GpuSpMvBackend::Cpu => config.block_size.max(1),
215        GpuSpMvBackend::WebGpu => config.block_size.max(1),
216    };
217
218    let mut row_start = 0usize;
219    while row_start < n_rows {
220        let row_end = (row_start + block).min(n_rows);
221        for row in row_start..row_end {
222            let col_start = row_ptr[row];
223            let col_end = row_ptr[row + 1];
224            for k_i in col_start..col_end {
225                let col = col_idx[k_i];
226                if col >= n_b_rows {
227                    return Err(SparseError::DimensionMismatch {
228                        expected: n_b_rows,
229                        found: col + 1,
230                    });
231                }
232                let a_val = values[k_i];
233                for j in 0..k {
234                    c[[row, j]] += a_val * b[[col, j]];
235                }
236            }
237        }
238        row_start = row_end;
239    }
240
241    Ok(c)
242}
243
244// ============================================================
245// Tests
246// ============================================================
247
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds an n×n identity matrix in CSR form.
    fn identity_csr(n: usize) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
        ((0..=n).collect(), (0..n).collect(), vec![1.0; n])
    }

    #[test]
    fn test_spmv_identity() {
        // I * x == x
        let (row_ptr, col_idx, values) = identity_csr(4);
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = csr_spmv(&row_ptr, &col_idx, &values, &x, &GpuSpMvConfig::default())
            .expect("spmv failed");
        assert_eq!(y, x);
    }

    #[test]
    fn test_spmv_diagonal() {
        // diag(2, 3, 5) applied to the all-ones vector.
        let row_ptr = vec![0, 1, 2, 3];
        let col_idx = vec![0, 1, 2];
        let values = vec![2.0, 3.0, 5.0];
        let y = csr_spmv(&row_ptr, &col_idx, &values, &[1.0; 3], &GpuSpMvConfig::default())
            .expect("spmv failed");
        assert_eq!(y, vec![2.0, 3.0, 5.0]);
    }

    #[test]
    fn test_spmv_dense() {
        // Fully-stored 2×3 matrix [[1,2,3],[4,5,6]] times [1,0,1].
        let row_ptr = vec![0, 3, 6];
        let col_idx = vec![0, 1, 2, 0, 1, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let y = csr_spmv(
            &row_ptr,
            &col_idx,
            &values,
            &[1.0, 0.0, 1.0],
            &GpuSpMvConfig::default(),
        )
        .expect("spmv failed");
        assert_eq!(y, vec![4.0, 10.0]);
    }

    #[test]
    fn test_spmv_batch() {
        // Identity times a [3, 2] batch leaves every column unchanged.
        let (row_ptr, col_idx, values) = identity_csr(3);
        let x_batch = Array2::from_shape_vec((3, 2), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0])
            .expect("shape error");
        let y = csr_spmv_batch(&row_ptr, &col_idx, &values, &x_batch, &GpuSpMvConfig::default())
            .expect("spmv_batch failed");
        assert_eq!(y.shape(), &[3, 2]);
        assert_eq!(y[[0, 0]], 1.0);
        assert_eq!(y[[2, 1]], 6.0);
    }

    #[test]
    fn test_spmm() {
        // I * B == B
        let (row_ptr, col_idx, values) = identity_csr(3);
        let b = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape error");
        let c = csr_spmm(&row_ptr, &col_idx, &values, &b, &GpuSpMvConfig::default())
            .expect("spmm failed");
        assert_eq!(c, b);
    }
}