Skip to main content

trueno_sparse/
sell.rs

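//! Sliced ELLPACK (SELL) sparse matrix format.
//!
//! # Contract: sparse-formats-v1.yaml
//!
//! SELL-C-σ format with σ = 1: rows are grouped into slices of C consecutive
//! rows and kept in their original order (they are not sorted by length).
//! Each slice is padded to the max row length in that slice and stored
//! column-major, which gives SIMD-friendly contiguous access patterns.
//!
//! ## References
//! - Kreutzer et al., "A unified sparse matrix data format for efficient general sparse
//!   matrix-vector multiplication on modern processors with wide SIMD units",
//!   SIAM J. Sci. Comput. 36(5), 2014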
11
12use crate::csr::CsrMatrix;
13use crate::error::SparseError;
14
/// Sparse matrix stored in Sliced ELLPACK (SELL) layout.
///
/// Consecutive rows are grouped into slices of `slice_size` rows, kept in
/// their original order. Each slice is padded to the length of its longest
/// row (the slice *width*) and laid out column-major: entry `j` of local row
/// `r` in slice `s` lives at index `slice_offsets[s] + j * slice_size + r`
/// of both `col_indices` and `values`. Padding entries hold column 0 and
/// value 0.0.
#[derive(Clone, Debug)]
pub struct SellMatrix {
    /// Number of matrix rows.
    rows: usize,
    /// Number of matrix columns.
    cols: usize,
    /// Rows per slice (the C in SELL-C).
    slice_size: usize,
    /// Number of slices, i.e. `ceil(rows / slice_size)`.
    num_slices: usize,
    /// Start of each slice in `col_indices`/`values`, followed by the total
    /// stored-entry count (length `num_slices + 1`).
    slice_offsets: Vec<u32>,
    /// Padded row length of each slice (length `num_slices`).
    slice_widths: Vec<u32>,
    /// Column index of every stored entry, padding included.
    col_indices: Vec<u32>,
    /// Value of every stored entry, padding included.
    values: Vec<f32>,
}
36
37impl SellMatrix {
38    /// Convert a CSR matrix to SELL format with the given slice size.
39    ///
40    /// Typical slice_size: 32 or 64 (matching SIMD width or warp size).
41    #[must_use]
42    pub fn from_csr(csr: &CsrMatrix<f32>, slice_size: usize) -> Self {
43        let rows = csr.rows();
44        let cols = csr.cols();
45        let c = if slice_size == 0 { 1 } else { slice_size };
46        let num_slices = rows.div_ceil(c);
47
48        let mut slice_offsets = Vec::with_capacity(num_slices + 1);
49        let mut slice_widths = Vec::with_capacity(num_slices);
50        let mut col_indices = Vec::new();
51        let mut values = Vec::new();
52
53        slice_offsets.push(0u32);
54
55        for s in 0..num_slices {
56            let row_start = s * c;
57            let row_end = (row_start + c).min(rows);
58            let actual_rows = row_end - row_start;
59
60            // Find max row length in this slice
61            let max_len = compute_slice_width(csr, row_start, row_end);
62            slice_widths.push(max_len as u32);
63
64            // Store in column-major order within the slice
65            fill_slice_data(
66                csr,
67                row_start,
68                actual_rows,
69                c,
70                max_len,
71                &mut col_indices,
72                &mut values,
73            );
74
75            let slice_elements = c * max_len;
76            let offset = slice_offsets.last().copied().unwrap_or(0);
77            slice_offsets.push(offset + slice_elements as u32);
78        }
79
80        Self {
81            rows,
82            cols,
83            slice_size: c,
84            num_slices,
85            slice_offsets,
86            slice_widths,
87            col_indices,
88            values,
89        }
90    }
91
92    /// SpMV: y = α·A·x + β·y
93    ///
94    /// # Errors
95    ///
96    /// Returns error on dimension mismatch.
97    pub fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError> {
98        if x.len() != self.cols {
99            return Err(SparseError::SpMVDimensionMismatch {
100                matrix_cols: self.cols,
101                x_len: x.len(),
102            });
103        }
104        if y.len() != self.rows {
105            return Err(SparseError::SpMVOutputDimensionMismatch {
106                matrix_rows: self.rows,
107                y_len: y.len(),
108            });
109        }
110
111        // Scale y by beta
112        for val in y.iter_mut() {
113            *val *= beta;
114        }
115
116        let c = self.slice_size;
117
118        for s in 0..self.num_slices {
119            let base = self.slice_offsets[s] as usize;
120            let width = self.slice_widths[s] as usize;
121            let row_start = s * c;
122            let row_end = (row_start + c).min(self.rows);
123
124            spmv_slice(
125                &self.col_indices,
126                &self.values,
127                x,
128                y,
129                alpha,
130                base,
131                c,
132                width,
133                row_start,
134                row_end,
135            );
136        }
137
138        Ok(())
139    }
140
141    /// Number of rows.
142    #[must_use]
143    pub fn rows(&self) -> usize {
144        self.rows
145    }
146
147    /// Number of columns.
148    #[must_use]
149    pub fn cols(&self) -> usize {
150        self.cols
151    }
152
153    /// Slice size (C parameter).
154    #[must_use]
155    pub fn slice_size(&self) -> usize {
156        self.slice_size
157    }
158
159    /// Total stored elements (including padding zeros).
160    #[must_use]
161    pub fn storage_size(&self) -> usize {
162        self.values.len()
163    }
164}
165
166/// Compute max row length in a slice.
167fn compute_slice_width(csr: &CsrMatrix<f32>, row_start: usize, row_end: usize) -> usize {
168    let offsets = csr.offsets();
169    let mut max_len = 0usize;
170    for r in row_start..row_end {
171        let len = (offsets[r + 1] - offsets[r]) as usize;
172        if len > max_len {
173            max_len = len;
174        }
175    }
176    max_len
177}
178
179/// Fill column-major data for one slice.
180fn fill_slice_data(
181    csr: &CsrMatrix<f32>,
182    row_start: usize,
183    actual_rows: usize,
184    c: usize,
185    max_len: usize,
186    col_indices: &mut Vec<u32>,
187    values: &mut Vec<f32>,
188) {
189    let csr_off = csr.offsets();
190    let csr_cols = csr.col_indices();
191    let csr_vals = csr.values();
192
193    // Column-major: for each column position j, store all rows
194    for j in 0..max_len {
195        for local_r in 0..c {
196            let global_r = row_start + local_r;
197            if local_r < actual_rows {
198                let row_start_idx = csr_off[global_r] as usize;
199                let row_len = (csr_off[global_r + 1] - csr_off[global_r]) as usize;
200                if j < row_len {
201                    col_indices.push(csr_cols[row_start_idx + j]);
202                    values.push(csr_vals[row_start_idx + j]);
203                } else {
204                    col_indices.push(0);
205                    values.push(0.0);
206                }
207            } else {
208                // Padding rows (beyond actual matrix rows)
209                col_indices.push(0);
210                values.push(0.0);
211            }
212        }
213    }
214}
215
/// Accumulate `alpha * A_slice * x` into rows `row_start..row_end` of `y`.
///
/// `base` is the slice's first index in `col_indices`/`values`, `c` is the
/// stride between consecutive entry positions, and `width` is the number of
/// entry positions per row. Entries are visited position-major, matching the
/// column-major storage. Padding entries (column 0, value 0.0) are multiplied
/// like real ones, so a non-finite `x[0]` yields NaN in rows shorter than `width`.
// Many scalar arguments keep this a plain free function over borrowed slices.
#[allow(clippy::too_many_arguments)]
fn spmv_slice(
    col_indices: &[u32],
    values: &[f32],
    x: &[f32],
    y: &mut [f32],
    alpha: f32,
    base: usize,
    c: usize,
    width: usize,
    row_start: usize,
    row_end: usize,
) {
    (0..width)
        .flat_map(|lane| {
            let lane_base = base + lane * c;
            (0..row_end - row_start).map(move |local| (lane_base + local, row_start + local))
        })
        .for_each(|(idx, row)| {
            let col = col_indices[idx] as usize;
            y[row] += alpha * values[idx] * x[col];
        });
}