1use crate::error::{SparseError, SparseResult};
8use scirs2_core::ndarray::{Array2, Axis};
9
/// Selects where the sparse kernels in this module evaluate their products.
///
/// `Cpu` is the default variant. The kernels in this module currently evaluate
/// the `WebGpu` variant on the host as well; no device dispatch is visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum GpuSpMvBackend {
    /// Host execution; returned by `GpuSpMvBackend::default()`.
    #[default]
    Cpu,
    /// WebGPU backend (evaluated on the host by the kernels in this module).
    WebGpu,
}
24
25#[derive(Debug, Clone)]
27pub struct GpuSpMvConfig {
28 pub backend: GpuSpMvBackend,
30 pub block_size: usize,
32 pub n_warps: usize,
34 pub use_texture: bool,
36}
37
38impl Default for GpuSpMvConfig {
39 fn default() -> Self {
40 Self {
41 backend: GpuSpMvBackend::Cpu,
42 block_size: 256,
43 n_warps: 8,
44 use_texture: false,
45 }
46 }
47}
48
49pub fn csr_spmv(
64 row_ptr: &[usize],
65 col_idx: &[usize],
66 values: &[f64],
67 x: &[f64],
68 config: &GpuSpMvConfig,
69) -> SparseResult<Vec<f64>> {
70 if row_ptr.is_empty() {
71 return Ok(Vec::new());
72 }
73 let n_rows = row_ptr.len() - 1;
74
75 if col_idx.len() != values.len() {
77 return Err(SparseError::InconsistentData {
78 reason: format!(
79 "col_idx length {} != values length {}",
80 col_idx.len(),
81 values.len()
82 ),
83 });
84 }
85
86 let mut y = vec![0.0_f64; n_rows];
87
88 match config.backend {
89 GpuSpMvBackend::Cpu => {
90 let block = config.block_size.max(1);
92 let mut row_start = 0usize;
93 while row_start < n_rows {
94 let row_end = (row_start + block).min(n_rows);
95 for row in row_start..row_end {
96 let col_start = row_ptr[row];
97 let col_end = row_ptr[row + 1];
98 let mut acc = 0.0_f64;
99 for k in col_start..col_end {
100 let col = col_idx[k];
101 if col >= x.len() {
102 return Err(SparseError::DimensionMismatch {
103 expected: x.len(),
104 found: col + 1,
105 });
106 }
107 acc += values[k] * x[col];
108 }
109 y[row] = acc;
110 }
111 row_start = row_end;
112 }
113 }
114 GpuSpMvBackend::WebGpu => {
115 for row in 0..n_rows {
117 let col_start = row_ptr[row];
118 let col_end = row_ptr[row + 1];
119 let mut acc = 0.0_f64;
120 for k in col_start..col_end {
121 let col = col_idx[k];
122 if col >= x.len() {
123 return Err(SparseError::DimensionMismatch {
124 expected: x.len(),
125 found: col + 1,
126 });
127 }
128 acc += values[k] * x[col];
129 }
130 y[row] = acc;
131 }
132 }
133 }
134
135 Ok(y)
136}
137
138pub fn csr_spmv_batch(
152 row_ptr: &[usize],
153 col_idx: &[usize],
154 values: &[f64],
155 x_batch: &Array2<f64>,
156 config: &GpuSpMvConfig,
157) -> SparseResult<Array2<f64>> {
158 if row_ptr.is_empty() {
159 return Ok(Array2::zeros((0, x_batch.ncols())));
160 }
161 let n_rows = row_ptr.len() - 1;
162 let n_rhs = x_batch.ncols();
163 let n_cols = x_batch.nrows();
164
165 let mut y = Array2::zeros((n_rows, n_rhs));
166
167 for rhs in 0..n_rhs {
168 let x_col = x_batch.index_axis(Axis(1), rhs);
169 let x_slice: Vec<f64> = x_col.iter().copied().collect();
170 if x_slice.len() != n_cols {
171 return Err(SparseError::DimensionMismatch {
172 expected: n_cols,
173 found: x_slice.len(),
174 });
175 }
176 let y_col = csr_spmv(row_ptr, col_idx, values, &x_slice, config)?;
177 for row in 0..n_rows {
178 y[[row, rhs]] = y_col[row];
179 }
180 }
181
182 Ok(y)
183}
184
185pub fn csr_spmm(
198 row_ptr: &[usize],
199 col_idx: &[usize],
200 values: &[f64],
201 b: &Array2<f64>,
202 config: &GpuSpMvConfig,
203) -> SparseResult<Array2<f64>> {
204 if row_ptr.is_empty() {
205 return Ok(Array2::zeros((0, b.ncols())));
206 }
207 let n_rows = row_ptr.len() - 1;
208 let k = b.ncols();
209 let n_b_rows = b.nrows();
210
211 let mut c = Array2::zeros((n_rows, k));
212
213 let block = match config.backend {
214 GpuSpMvBackend::Cpu => config.block_size.max(1),
215 GpuSpMvBackend::WebGpu => config.block_size.max(1),
216 };
217
218 let mut row_start = 0usize;
219 while row_start < n_rows {
220 let row_end = (row_start + block).min(n_rows);
221 for row in row_start..row_end {
222 let col_start = row_ptr[row];
223 let col_end = row_ptr[row + 1];
224 for k_i in col_start..col_end {
225 let col = col_idx[k_i];
226 if col >= n_b_rows {
227 return Err(SparseError::DimensionMismatch {
228 expected: n_b_rows,
229 found: col + 1,
230 });
231 }
232 let a_val = values[k_i];
233 for j in 0..k {
234 c[[row, j]] += a_val * b[[col, j]];
235 }
236 }
237 }
238 row_start = row_end;
239 }
240
241 Ok(c)
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::SparseError;
    use scirs2_core::ndarray::Array2;

    /// `n x n` identity matrix in CSR form.
    fn identity_csr(n: usize) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
        let row_ptr: Vec<usize> = (0..=n).collect();
        let col_idx: Vec<usize> = (0..n).collect();
        let values: Vec<f64> = vec![1.0; n];
        (row_ptr, col_idx, values)
    }

    /// The dense 2x3 matrix `[[1, 2, 3], [4, 5, 6]]` in CSR form.
    fn dense_2x3_csr() -> (Vec<usize>, Vec<usize>, Vec<f64>) {
        (
            vec![0, 3, 6],
            vec![0, 1, 2, 0, 1, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        )
    }

    /// Default configuration with the `WebGpu` backend selected.
    fn webgpu_config() -> GpuSpMvConfig {
        GpuSpMvConfig {
            backend: GpuSpMvBackend::WebGpu,
            ..GpuSpMvConfig::default()
        }
    }

    #[test]
    fn test_spmv_identity() {
        let n = 4;
        let (row_ptr, col_idx, values) = identity_csr(n);
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let config = GpuSpMvConfig::default();
        let y = csr_spmv(&row_ptr, &col_idx, &values, &x, &config).expect("spmv failed");
        assert_eq!(y, x);
    }

    #[test]
    fn test_spmv_diagonal() {
        let row_ptr = vec![0, 1, 2, 3];
        let col_idx = vec![0, 1, 2];
        let values = vec![2.0, 3.0, 5.0];
        let x = vec![1.0, 1.0, 1.0];
        let config = GpuSpMvConfig::default();
        let y = csr_spmv(&row_ptr, &col_idx, &values, &x, &config).expect("spmv failed");
        assert_eq!(y, vec![2.0, 3.0, 5.0]);
    }

    #[test]
    fn test_spmv_dense() {
        let (row_ptr, col_idx, values) = dense_2x3_csr();
        let x = vec![1.0, 0.0, 1.0];
        let config = GpuSpMvConfig::default();
        let y = csr_spmv(&row_ptr, &col_idx, &values, &x, &config).expect("spmv failed");
        assert_eq!(y, vec![4.0, 10.0]);
    }

    #[test]
    fn test_spmv_empty_row_ptr_yields_empty_vector() {
        let y = csr_spmv(&[], &[], &[], &[], &GpuSpMvConfig::default()).expect("spmv failed");
        assert!(y.is_empty());
    }

    #[test]
    fn test_spmv_empty_rows_are_zero() {
        // Row 0 stores nothing; row 1 stores a single entry.
        let row_ptr = vec![0, 0, 1];
        let col_idx = vec![0];
        let values = vec![2.0];
        let x = vec![3.0];
        let config = GpuSpMvConfig::default();
        let y = csr_spmv(&row_ptr, &col_idx, &values, &x, &config).expect("spmv failed");
        assert_eq!(y, vec![0.0, 6.0]);
    }

    #[test]
    fn test_spmv_block_size_does_not_change_result() {
        let (row_ptr, col_idx, values) = dense_2x3_csr();
        let x = vec![1.0, 0.0, 1.0];
        // 0 exercises the `.max(1)` clamp; 1000 exceeds the row count.
        for block_size in [0, 1, 2, 1000] {
            let config = GpuSpMvConfig {
                block_size,
                ..GpuSpMvConfig::default()
            };
            let y = csr_spmv(&row_ptr, &col_idx, &values, &x, &config).expect("spmv failed");
            assert_eq!(y, vec![4.0, 10.0], "block_size = {}", block_size);
        }
    }

    #[test]
    fn test_spmv_webgpu_matches_cpu() {
        let (row_ptr, col_idx, values) = dense_2x3_csr();
        let x = vec![1.0, 0.0, 1.0];
        let cpu = csr_spmv(&row_ptr, &col_idx, &values, &x, &GpuSpMvConfig::default())
            .expect("cpu spmv failed");
        let gpu = csr_spmv(&row_ptr, &col_idx, &values, &x, &webgpu_config())
            .expect("webgpu spmv failed");
        assert_eq!(cpu, gpu);
        assert_eq!(gpu, vec![4.0, 10.0]);
    }

    #[test]
    fn test_spmv_rejects_out_of_range_column() {
        let result = csr_spmv(&[0, 1], &[3], &[1.0], &[1.0, 1.0], &GpuSpMvConfig::default());
        assert!(matches!(
            result,
            Err(SparseError::DimensionMismatch {
                expected: 2,
                found: 4
            })
        ));
    }

    #[test]
    fn test_spmv_rejects_length_mismatch() {
        let result = csr_spmv(&[0, 2], &[0, 1], &[1.0], &[1.0, 1.0], &GpuSpMvConfig::default());
        assert!(matches!(result, Err(SparseError::InconsistentData { .. })));
    }

    #[test]
    fn test_spmv_batch() {
        let n = 3;
        let (row_ptr, col_idx, values) = identity_csr(n);
        let x_batch = Array2::from_shape_vec((3, 2), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0])
            .expect("shape error");
        let config = GpuSpMvConfig::default();
        let y = csr_spmv_batch(&row_ptr, &col_idx, &values, &x_batch, &config)
            .expect("spmv_batch failed");
        assert_eq!(y.shape(), &[3, 2]);
        assert_eq!(y[[0, 0]], 1.0);
        assert_eq!(y[[2, 1]], 6.0);
    }

    #[test]
    fn test_spmv_batch_matches_column_wise_spmv() {
        let (row_ptr, col_idx, values) = dense_2x3_csr();
        // Columns are the right-hand sides [1, 0, 1] and [0, 1, 2].
        let x_batch = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 1.0, 1.0, 2.0])
            .expect("shape error");
        let y = csr_spmv_batch(&row_ptr, &col_idx, &values, &x_batch, &GpuSpMvConfig::default())
            .expect("spmv_batch failed");
        let expected =
            Array2::from_shape_vec((2, 2), vec![4.0, 8.0, 10.0, 17.0]).expect("shape error");
        assert_eq!(y, expected);
    }

    #[test]
    fn test_spmv_batch_empty_row_ptr() {
        let x_batch = Array2::<f64>::zeros((3, 2));
        let y = csr_spmv_batch(&[], &[], &[], &x_batch, &GpuSpMvConfig::default())
            .expect("spmv_batch failed");
        assert_eq!(y.shape(), &[0, 2]);
    }

    #[test]
    fn test_spmm() {
        let n = 3;
        let (row_ptr, col_idx, values) = identity_csr(n);
        let b = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape error");
        let config = GpuSpMvConfig::default();
        let c = csr_spmm(&row_ptr, &col_idx, &values, &b, &config).expect("spmm failed");
        assert_eq!(c, b);
    }

    #[test]
    fn test_spmm_general_on_both_backends() {
        // A = [[1, 2], [0, 3]], B = [[1, 2], [3, 4]]  =>  C = [[7, 10], [9, 12]]
        let row_ptr = vec![0, 2, 3];
        let col_idx = vec![0, 1, 1];
        let values = vec![1.0, 2.0, 3.0];
        let b = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("shape error");
        let expected =
            Array2::from_shape_vec((2, 2), vec![7.0, 10.0, 9.0, 12.0]).expect("shape error");
        for config in [GpuSpMvConfig::default(), webgpu_config()] {
            let c = csr_spmm(&row_ptr, &col_idx, &values, &b, &config).expect("spmm failed");
            assert_eq!(c, expected);
        }
    }

    #[test]
    fn test_spmm_rejects_out_of_range_column() {
        let b = Array2::<f64>::zeros((2, 2));
        let result = csr_spmm(&[0, 1], &[5], &[1.0], &b, &GpuSpMvConfig::default());
        assert!(matches!(
            result,
            Err(SparseError::DimensionMismatch {
                expected: 2,
                found: 6
            })
        ));
    }
}