Skip to main content

trueno_sparse/
ops.rs

//! Sparse matrix operations (SpMV, SpMM).
//!
//! # Contract: sparse-spmv-v1.yaml
//!
//! SpMV equation: `y_i = α · Σ_j A_{ij} · x_j + β · y_i`
//!
//! ## Proof obligations
//! - Output dimension: `len(y) == A.rows()`
//! - Backward error: `|Ax - y_exact| ≤ nnz_per_row · u · |A| · |x|`
//! - SIMD-scalar equivalence: within 8 ULP
//!
//! ## Kernel phases
//! 1. format_validation (at construction)
//! 2. row-split accumulation (scalar/SIMD)
//! 3. output scaling (α, β)
16
17use crate::csr::CsrMatrix;
18use crate::error::SparseError;
19
20/// Backend dispatch trait for pluggable SIMD SpMV kernels.
21///
22/// Implementations provide SpMV for a specific hardware target.
23/// The default dispatch in `SparseOps::spmv` selects the best
24/// available backend at runtime.
25pub trait SparseBackend {
26    /// Perform SpMV: `y = alpha * A * x + beta * y` using this backend.
27    fn spmv_kernel(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]);
28
29    /// Perform SpMM: `C = alpha * A * B + beta * C` using this backend.
30    ///
31    /// B is row-major with `b_cols` columns.
32    fn spmm_kernel(
33        a: &CsrMatrix<f32>,
34        alpha: f32,
35        b: &[f32],
36        b_cols: usize,
37        beta: f32,
38        c: &mut [f32],
39    );
40}
41
42/// Scalar (portable) SpMV backend.
43pub struct ScalarBackend;
44
45impl SparseBackend for ScalarBackend {
46    fn spmv_kernel(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) {
47        spmv_csr_scalar(a, alpha, x, beta, y);
48    }
49
50    fn spmm_kernel(
51        a: &CsrMatrix<f32>,
52        alpha: f32,
53        b: &[f32],
54        b_cols: usize,
55        beta: f32,
56        c: &mut [f32],
57    ) {
58        spmm_csr_scalar(a, alpha, b, b_cols, beta, c);
59    }
60}
61
62/// AVX2 SpMV backend (x86_64 with AVX2+FMA).
63#[cfg(target_arch = "x86_64")]
64pub struct Avx2Backend;
65
66#[cfg(target_arch = "x86_64")]
67impl SparseBackend for Avx2Backend {
68    fn spmv_kernel(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) {
69        // SAFETY: caller must ensure AVX2+FMA is available
70        unsafe { spmv_csr_avx2(a, alpha, x, beta, y) }
71    }
72
73    fn spmm_kernel(
74        a: &CsrMatrix<f32>,
75        alpha: f32,
76        b: &[f32],
77        b_cols: usize,
78        beta: f32,
79        c: &mut [f32],
80    ) {
81        // AVX2 SpMM not yet specialized — fall back to scalar
82        spmm_csr_scalar(a, alpha, b, b_cols, beta, c);
83    }
84}
85
/// Placeholder backend for aarch64 NEON.
///
/// No NEON kernels exist yet, so both operations forward to the scalar
/// kernels.
#[cfg(target_arch = "aarch64")]
pub struct NeonBackend;

#[cfg(target_arch = "aarch64")]
impl SparseBackend for NeonBackend {
    /// SpMV; currently the scalar kernel (no NEON specialization yet).
    fn spmv_kernel(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) {
        spmv_csr_scalar(a, alpha, x, beta, y)
    }

    /// SpMM; currently the scalar kernel (no NEON specialization yet).
    fn spmm_kernel(
        a: &CsrMatrix<f32>,
        alpha: f32,
        b: &[f32],
        b_cols: usize,
        beta: f32,
        c: &mut [f32],
    ) {
        spmm_csr_scalar(a, alpha, b, b_cols, beta, c)
    }
}
109
110/// Sparse matrix operations trait.
111///
112/// Provides SpMV and SpMM with provable error bounds.
113pub trait SparseOps {
114    /// Sparse matrix-vector multiply: `y = α * A * x + β * y`
115    ///
116    /// # Contract: sparse-spmv-v1.yaml / spmv
117    ///
118    /// **Preconditions**: `x.len() == self.cols()`, `y.len() == self.rows()`
119    /// **Postcondition**: backward error ≤ `nnz_per_row * f32::EPSILON * ||A||_inf * ||x||_inf`
120    ///
121    /// # Errors
122    ///
123    /// Returns error on dimension mismatch.
124    fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError>;
125
126    /// Sparse matrix-dense matrix multiply: `C = α * A * B + β * C`
127    ///
128    /// B is row-major with `b_cols` columns. C is row-major with `b_cols` columns.
129    ///
130    /// # Errors
131    ///
132    /// Returns error on dimension mismatch.
133    fn spmm(
134        &self,
135        alpha: f32,
136        b: &[f32],
137        b_cols: usize,
138        beta: f32,
139        c: &mut [f32],
140    ) -> Result<(), SparseError>;
141}
142
143impl SparseOps for CsrMatrix<f32> {
144    fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError> {
145        // Dimension checks (contract enforcement)
146        if x.len() != self.cols() {
147            return Err(SparseError::SpMVDimensionMismatch {
148                matrix_cols: self.cols(),
149                x_len: x.len(),
150            });
151        }
152        if y.len() != self.rows() {
153            return Err(SparseError::SpMVOutputDimensionMismatch {
154                matrix_rows: self.rows(),
155                y_len: y.len(),
156            });
157        }
158
159        // Dispatch to best available backend
160        #[cfg(target_arch = "x86_64")]
161        {
162            if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
163                // SAFETY: AVX2+FMA detected at runtime
164                unsafe {
165                    spmv_csr_avx2(self, alpha, x, beta, y);
166                    return Ok(());
167                }
168            }
169        }
170
171        // Scalar fallback
172        spmv_csr_scalar(self, alpha, x, beta, y);
173        Ok(())
174    }
175
176    fn spmm(
177        &self,
178        alpha: f32,
179        b: &[f32],
180        b_cols: usize,
181        beta: f32,
182        c: &mut [f32],
183    ) -> Result<(), SparseError> {
184        if b.len() != self.cols() * b_cols {
185            return Err(SparseError::SpMVDimensionMismatch {
186                matrix_cols: self.cols(),
187                x_len: b.len(),
188            });
189        }
190        if c.len() != self.rows() * b_cols {
191            return Err(SparseError::SpMVOutputDimensionMismatch {
192                matrix_rows: self.rows(),
193                y_len: c.len(),
194            });
195        }
196
197        spmm_csr_scalar(self, alpha, b, b_cols, beta, c);
198        Ok(())
199    }
200}
201
202/// Scalar SpMV reference implementation.
203///
204/// Contract: this is the ground truth for SIMD/GPU parity testing.
205fn spmv_csr_scalar(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) {
206    let offsets = a.offsets();
207    let col_indices = a.col_indices();
208    let values = a.values();
209
210    for i in 0..a.rows() {
211        let start = offsets[i] as usize;
212        let end = offsets[i + 1] as usize;
213
214        // Kahan summation for improved accuracy (LAProof-aligned)
215        let mut sum = 0.0_f64;
216        let mut comp = 0.0_f64;
217
218        for idx in start..end {
219            let j = col_indices[idx] as usize;
220            let product = f64::from(values[idx]) * f64::from(x[j]);
221            let t = sum + product;
222            if sum.abs() >= product.abs() {
223                comp += (sum - t) + product;
224            } else {
225                comp += (product - t) + sum;
226            }
227            sum = t;
228        }
229        sum += comp;
230
231        y[i] = (f64::from(alpha) * sum + f64::from(beta) * f64::from(y[i])) as f32;
232    }
233}
234
235/// AVX2 SpMV with gather instructions.
236///
237/// Uses `_mm256_i32gather_ps` for indirect x[col_indices[j]] access
238/// and FMA for accumulation.
239#[cfg(target_arch = "x86_64")]
240#[target_feature(enable = "avx2,fma")]
241unsafe fn spmv_csr_avx2(a: &CsrMatrix<f32>, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) {
242    use std::arch::x86_64::*;
243
244    let offsets = a.offsets();
245    let col_indices = a.col_indices();
246    let values = a.values();
247
248    for i in 0..a.rows() {
249        let start = offsets[i] as usize;
250        let end = offsets[i + 1] as usize;
251        let row_nnz = end - start;
252
253        // SAFETY: AVX2 feature gate checked by caller via is_x86_feature_detected
254        let mut acc = _mm256_setzero_ps();
255
256        // Process 8 elements at a time
257        let chunks = row_nnz / 8;
258        for c in 0..chunks {
259            let base = start + c * 8;
260            unsafe {
261                let idx = _mm256_loadu_si256(col_indices[base..].as_ptr().cast());
262                let v = _mm256_loadu_ps(values[base..].as_ptr());
263                let x_gathered = _mm256_i32gather_ps::<4>(x.as_ptr(), idx);
264                acc = _mm256_fmadd_ps(v, x_gathered, acc);
265            }
266        }
267
268        // Horizontal sum of acc
269        let hi = _mm256_extractf128_ps::<1>(acc);
270        let lo = _mm256_castps256_ps128(acc);
271        let sum128 = _mm_add_ps(lo, hi);
272        let shuf = _mm_movehdup_ps(sum128);
273        let sums = _mm_add_ps(sum128, shuf);
274        let shuf2 = _mm_movehl_ps(sums, sums);
275        let sums2 = _mm_add_ss(sums, shuf2);
276        let mut row_sum = _mm_cvtss_f32(sums2);
277
278        // Scalar tail for remaining elements
279        for idx in (start + chunks * 8)..end {
280            unsafe {
281                let j = *col_indices.get_unchecked(idx) as usize;
282                row_sum += *values.get_unchecked(idx) * *x.get_unchecked(j);
283            }
284        }
285
286        unsafe {
287            *y.get_unchecked_mut(i) = alpha * row_sum + beta * *y.get_unchecked(i);
288        }
289    }
290}
291
292/// Scalar SpMM reference implementation.
293fn spmm_csr_scalar(
294    a: &CsrMatrix<f32>,
295    alpha: f32,
296    b: &[f32],
297    b_cols: usize,
298    beta: f32,
299    c: &mut [f32],
300) {
301    let offsets = a.offsets();
302    let col_indices = a.col_indices();
303    let values = a.values();
304
305    for i in 0..a.rows() {
306        let start = offsets[i] as usize;
307        let end = offsets[i + 1] as usize;
308
309        // Scale existing C values by beta
310        for k in 0..b_cols {
311            c[i * b_cols + k] *= beta;
312        }
313
314        // Accumulate A[i,:] * B
315        for idx in start..end {
316            let j = col_indices[idx] as usize;
317            let a_val = alpha * values[idx];
318            for k in 0..b_cols {
319                c[i * b_cols + k] += a_val * b[j * b_cols + k];
320            }
321        }
322    }
323}