use crate::csr::CsrMatrix;
use crate::error::SparseError;

/// Low-level compute kernels for sparse `f32` operations, one implementation
/// per instruction set.
pub trait SparseBackend {
    /// Computes `y = alpha * A * x + beta * y` for a CSR matrix `A`.
    fn spmv_kernel(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]);

    /// Computes `C = alpha * A * B + beta * C`, where `B` (`cols x b_cols`)
    /// and `C` (`rows x b_cols`) are dense row-major matrices.
    fn spmm_kernel(
        a: &CsrMatrix<f32>,
        alpha: f32,
        b: &[f32],
        b_cols: usize,
        beta: f32,
        c: &mut [f32],
    );
}

/// Portable fallback backend; available on every target.
pub struct ScalarBackend;

impl SparseBackend for ScalarBackend {
    fn spmv_kernel(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) {
        spmv_csr_scalar(a, alpha, x, beta, y);
    }

    fn spmm_kernel(
        a: &CsrMatrix<f32>,
        alpha: f32,
        b: &[f32],
        b_cols: usize,
        beta: f32,
        c: &mut [f32],
    ) {
        spmm_csr_scalar(a, alpha, b, b_cols, beta, c);
    }
}

/// SpMV backend using AVX2 and FMA intrinsics on x86_64.
#[cfg(target_arch = "x86_64")]
pub struct Avx2Backend;

#[cfg(target_arch = "x86_64")]
impl SparseBackend for Avx2Backend {
    fn spmv_kernel(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) {
        // Calling the `#[target_feature]` kernel without verifying CPU
        // support would be unsound from this safe method, so check first and
        // fall back to the scalar kernel if AVX2/FMA are unavailable.
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: the runtime check above guarantees the AVX2 and FMA
            // target features required by `spmv_csr_avx2`.
            unsafe { spmv_csr_avx2(a, alpha, x, beta, y) }
        } else {
            spmv_csr_scalar(a, alpha, x, beta, y);
        }
    }

    fn spmm_kernel(
        a: &CsrMatrix<f32>,
        alpha: f32,
        b: &[f32],
        b_cols: usize,
        beta: f32,
        c: &mut [f32],
    ) {
        // No vectorized SpMM kernel yet; fall back to the scalar path.
        spmm_csr_scalar(a, alpha, b, b_cols, beta, c);
    }
}

/// Backend stub for aarch64. NEON kernels are not implemented yet, so both
/// methods currently fall back to the scalar path.
#[cfg(target_arch = "aarch64")]
pub struct NeonBackend;

#[cfg(target_arch = "aarch64")]
impl SparseBackend for NeonBackend {
    fn spmv_kernel(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) {
        spmv_csr_scalar(a, alpha, x, beta, y);
    }

    fn spmm_kernel(
        a: &CsrMatrix<f32>,
        alpha: f32,
        b: &[f32],
        b_cols: usize,
        beta: f32,
        c: &mut [f32],
    ) {
        spmm_csr_scalar(a, alpha, b, b_cols, beta, c);
    }
}

/// Dimension-checked sparse operations, dispatching to the fastest kernel
/// available at runtime.
pub trait SparseOps {
    /// Computes `y = alpha * self * x + beta * y`.
    ///
    /// Returns an error if `x.len() != self.cols()` or
    /// `y.len() != self.rows()`.
    fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError>;

    /// Computes `C = alpha * self * B + beta * C`, where `B` is a dense
    /// row-major `cols x b_cols` matrix and `C` is `rows x b_cols`.
    ///
    /// Returns an error if `b.len() != self.cols() * b_cols` or
    /// `c.len() != self.rows() * b_cols`.
    fn spmm(
        &self,
        alpha: f32,
        b: &[f32],
        b_cols: usize,
        beta: f32,
        c: &mut [f32],
    ) -> Result<(), SparseError>;
}

impl SparseOps for CsrMatrix<f32> {
    fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError> {
        if x.len() != self.cols() {
            return Err(SparseError::SpMVDimensionMismatch {
                matrix_cols: self.cols(),
                x_len: x.len(),
            });
        }
        if y.len() != self.rows() {
            return Err(SparseError::SpMVOutputDimensionMismatch {
                matrix_rows: self.rows(),
                y_len: y.len(),
            });
        }

        // Prefer the vectorized kernel when the CPU supports it.
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
                // SAFETY: the runtime check above guarantees AVX2 and FMA,
                // and the dimension checks above guarantee valid slice sizes.
                unsafe { spmv_csr_avx2(self, alpha, x, beta, y) };
                return Ok(());
            }
        }

        spmv_csr_scalar(self, alpha, x, beta, y);
        Ok(())
    }

    fn spmm(
        &self,
        alpha: f32,
        b: &[f32],
        b_cols: usize,
        beta: f32,
        c: &mut [f32],
    ) -> Result<(), SparseError> {
        // The SpMV error variants are reused here: `b` plays the role of `x`
        // and `c` the role of `y`.
        if b.len() != self.cols() * b_cols {
            return Err(SparseError::SpMVDimensionMismatch {
                matrix_cols: self.cols(),
                x_len: b.len(),
            });
        }
        if c.len() != self.rows() * b_cols {
            return Err(SparseError::SpMVOutputDimensionMismatch {
                matrix_rows: self.rows(),
                y_len: c.len(),
            });
        }

        spmm_csr_scalar(self, alpha, b, b_cols, beta, c);
        Ok(())
    }
}
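
// Sketch of the intended call pattern for `SparseOps::spmv`. The
// `CsrMatrix::from_parts` constructor below is a hypothetical name used only
// for illustration; the real constructor lives in `crate::csr` and may
// differ.
//
//     // 2x3 matrix: row 0 holds [1, 0, 2], row 1 holds [0, 3, 0].
//     let a = CsrMatrix::from_parts(
//         2,                   // rows
//         3,                   // cols
//         vec![0, 2, 3],       // row offsets
//         vec![0, 2, 1],       // column indices
//         vec![1.0, 2.0, 3.0], // values
//     )?;
//     let x = [1.0_f32, 1.0, 1.0];
//     let mut y = [0.0_f32; 2];
//     a.spmv(1.0, &x, 0.0, &mut y)?; // y = A * x = [3.0, 3.0]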

/// Scalar SpMV reference kernel: `y = alpha * A * x + beta * y`.
///
/// Each row sum is accumulated in `f64` with compensated summation, so the
/// result stays accurate even for long rows with poorly conditioned sums.
fn spmv_csr_scalar(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) {
    let offsets = a.offsets();
    let col_indices = a.col_indices();
    let values = a.values();

    for i in 0..a.rows() {
        let start = offsets[i] as usize;
        let end = offsets[i + 1] as usize;

        // Neumaier's variant of Kahan summation: `comp` collects the
        // low-order bits lost when each product is folded into `sum`.
        let mut sum = 0.0_f64;
        let mut comp = 0.0_f64;

        for idx in start..end {
            let j = col_indices[idx] as usize;
            let product = f64::from(values[idx]) * f64::from(x[j]);
            let t = sum + product;
            if sum.abs() >= product.abs() {
                comp += (sum - t) + product;
            } else {
                comp += (product - t) + sum;
            }
            sum = t;
        }
        sum += comp;

        y[i] = (f64::from(alpha) * sum + f64::from(beta) * f64::from(y[i])) as f32;
    }
}
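
// The compensated loop above is Neumaier's variant of Kahan summation. A
// minimal standalone sketch of the same scheme, kept here for reference only
// (nothing in this module calls it):
#[allow(dead_code)]
fn neumaier_sum(terms: &[f64]) -> f64 {
    let mut sum = 0.0_f64;
    let mut comp = 0.0_f64; // running compensation for lost low-order bits
    for &t in terms {
        let new_sum = sum + t;
        if sum.abs() >= t.abs() {
            // `t`'s low-order bits were rounded away; recover them.
            comp += (sum - new_sum) + t;
        } else {
            // `sum`'s low-order bits were rounded away instead.
            comp += (t - new_sum) + sum;
        }
        sum = new_sum;
    }
    sum + comp
}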

/// AVX2+FMA SpMV kernel: `y = alpha * A * x + beta * y`.
///
/// # Safety
///
/// The caller must ensure the CPU supports the `avx2` and `fma` target
/// features, and that `x.len() == a.cols()` and `y.len() == a.rows()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn spmv_csr_avx2(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) {
    use std::arch::x86_64::*;

    let offsets = a.offsets();
    let col_indices = a.col_indices();
    let values = a.values();

    for i in 0..a.rows() {
        let start = offsets[i] as usize;
        let end = offsets[i + 1] as usize;
        let row_nnz = end - start;

        // Accumulate eight products per iteration.
        let mut acc = _mm256_setzero_ps();

        let chunks = row_nnz / 8;
        for c in 0..chunks {
            let base = start + c * 8;
            // SAFETY: `base + 8 <= end`, so the loads stay in bounds, and
            // every gathered column index is a valid index into `x`.
            unsafe {
                let idx = _mm256_loadu_si256(col_indices[base..].as_ptr().cast());
                let v = _mm256_loadu_ps(values[base..].as_ptr());
                let x_gathered = _mm256_i32gather_ps::<4>(x.as_ptr(), idx);
                acc = _mm256_fmadd_ps(v, x_gathered, acc);
            }
        }

        // Horizontal reduction: fold the eight lanes of `acc` into one f32.
        let hi = _mm256_extractf128_ps::<1>(acc);
        let lo = _mm256_castps256_ps128(acc);
        let sum128 = _mm_add_ps(lo, hi);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let sums2 = _mm_add_ss(sums, shuf2);
        let mut row_sum = _mm_cvtss_f32(sums2);

        // Scalar tail for the remaining `row_nnz % 8` nonzeros.
        for idx in (start + chunks * 8)..end {
            // SAFETY: `idx < end <= len`, and CSR column indices are in bounds.
            unsafe {
                let j = *col_indices.get_unchecked(idx) as usize;
                row_sum += *values.get_unchecked(idx) * *x.get_unchecked(j);
            }
        }

        // SAFETY: `i < a.rows() == y.len()` per the safety contract.
        unsafe {
            *y.get_unchecked_mut(i) = alpha * row_sum + beta * *y.get_unchecked(i);
        }
    }
}

/// Scalar SpMM reference kernel: `C = alpha * A * B + beta * C`, with `B`
/// and `C` dense row-major matrices of `b_cols` columns.
fn spmm_csr_scalar(
    a: &CsrMatrix<f32>,
    alpha: f32,
    b: &[f32],
    b_cols: usize,
    beta: f32,
    c: &mut [f32],
) {
    let offsets = a.offsets();
    let col_indices = a.col_indices();
    let values = a.values();

    for i in 0..a.rows() {
        let start = offsets[i] as usize;
        let end = offsets[i + 1] as usize;

        // Scale the output row by beta first...
        for k in 0..b_cols {
            c[i * b_cols + k] *= beta;
        }

        // ...then add alpha * A[i, j] * B[j, :] for each stored nonzero.
        for idx in start..end {
            let j = col_indices[idx] as usize;
            let a_val = alpha * values[idx];
            for k in 0..b_cols {
                c[i * b_cols + k] += a_val * b[j * b_cols + k];
            }
        }
    }
}
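
// A useful consistency check for tests (sketch only; assumes a hypothetical
// `CsrMatrix::from_parts` constructor and placeholder bindings `a`, `x`, and
// `alpha`): SpMM with `b_cols == 1` must agree with SpMV on the same inputs.
// Agreement is approximate rather than exact, because the SpMV path
// accumulates in compensated f64 while the SpMM path accumulates in f32.
//
//     let mut y = vec![0.0_f32; a.rows()];
//     let mut c = vec![0.0_f32; a.rows()];
//     a.spmv(alpha, &x, 0.0, &mut y)?;
//     a.spmm(alpha, &x, 1, 0.0, &mut c)?;
//     for (yi, ci) in y.iter().zip(&c) {
//         assert!((yi - ci).abs() <= 1e-5 * yi.abs().max(1.0));
//     }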