1use crate::csr::CsrMatrix;
13use crate::error::SparseError;
14
/// Sparse `f32` matrix whose rows are grouped into fixed-size slices, each
/// slice padded to the length of its longest row.
///
/// Instances are built with `SellMatrix::from_csr`.
#[derive(Clone, Debug)]
pub struct SellMatrix {
    /// Number of matrix rows (length expected of `y` in `spmv`).
    rows: usize,
    /// Number of matrix columns (length expected of `x` in `spmv`).
    cols: usize,
    /// Rows per slice; never 0 for matrices built by `from_csr`.
    slice_size: usize,
    /// Number of slices, i.e. `rows` divided by `slice_size`, rounded up.
    num_slices: usize,
    /// Start of each slice inside `col_indices` / `values`; one entry per
    /// slice plus a trailing end offset.
    slice_offsets: Vec<u32>,
    /// Padded row length of each slice (length of its longest row).
    slice_widths: Vec<u32>,
    /// Column index of every stored slot, padding included.
    col_indices: Vec<u32>,
    /// Value of every stored slot, padding included.
    values: Vec<f32>,
}
36
37impl SellMatrix {
38 #[must_use]
42 pub fn from_csr(csr: &CsrMatrix<f32>, slice_size: usize) -> Self {
43 let rows = csr.rows();
44 let cols = csr.cols();
45 let c = if slice_size == 0 { 1 } else { slice_size };
46 let num_slices = rows.div_ceil(c);
47
48 let mut slice_offsets = Vec::with_capacity(num_slices + 1);
49 let mut slice_widths = Vec::with_capacity(num_slices);
50 let mut col_indices = Vec::new();
51 let mut values = Vec::new();
52
53 slice_offsets.push(0u32);
54
55 for s in 0..num_slices {
56 let row_start = s * c;
57 let row_end = (row_start + c).min(rows);
58 let actual_rows = row_end - row_start;
59
60 let max_len = compute_slice_width(csr, row_start, row_end);
62 slice_widths.push(max_len as u32);
63
64 fill_slice_data(
66 csr,
67 row_start,
68 actual_rows,
69 c,
70 max_len,
71 &mut col_indices,
72 &mut values,
73 );
74
75 let slice_elements = c * max_len;
76 let offset = slice_offsets.last().copied().unwrap_or(0);
77 slice_offsets.push(offset + slice_elements as u32);
78 }
79
80 Self {
81 rows,
82 cols,
83 slice_size: c,
84 num_slices,
85 slice_offsets,
86 slice_widths,
87 col_indices,
88 values,
89 }
90 }
91
92 pub fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError> {
98 if x.len() != self.cols {
99 return Err(SparseError::SpMVDimensionMismatch {
100 matrix_cols: self.cols,
101 x_len: x.len(),
102 });
103 }
104 if y.len() != self.rows {
105 return Err(SparseError::SpMVOutputDimensionMismatch {
106 matrix_rows: self.rows,
107 y_len: y.len(),
108 });
109 }
110
111 for val in y.iter_mut() {
113 *val *= beta;
114 }
115
116 let c = self.slice_size;
117
118 for s in 0..self.num_slices {
119 let base = self.slice_offsets[s] as usize;
120 let width = self.slice_widths[s] as usize;
121 let row_start = s * c;
122 let row_end = (row_start + c).min(self.rows);
123
124 spmv_slice(
125 &self.col_indices,
126 &self.values,
127 x,
128 y,
129 alpha,
130 base,
131 c,
132 width,
133 row_start,
134 row_end,
135 );
136 }
137
138 Ok(())
139 }
140
141 #[must_use]
143 pub fn rows(&self) -> usize {
144 self.rows
145 }
146
147 #[must_use]
149 pub fn cols(&self) -> usize {
150 self.cols
151 }
152
153 #[must_use]
155 pub fn slice_size(&self) -> usize {
156 self.slice_size
157 }
158
159 #[must_use]
161 pub fn storage_size(&self) -> usize {
162 self.values.len()
163 }
164}
165
166fn compute_slice_width(csr: &CsrMatrix<f32>, row_start: usize, row_end: usize) -> usize {
168 let offsets = csr.offsets();
169 let mut max_len = 0usize;
170 for r in row_start..row_end {
171 let len = (offsets[r + 1] - offsets[r]) as usize;
172 if len > max_len {
173 max_len = len;
174 }
175 }
176 max_len
177}
178
179fn fill_slice_data(
181 csr: &CsrMatrix<f32>,
182 row_start: usize,
183 actual_rows: usize,
184 c: usize,
185 max_len: usize,
186 col_indices: &mut Vec<u32>,
187 values: &mut Vec<f32>,
188) {
189 let csr_off = csr.offsets();
190 let csr_cols = csr.col_indices();
191 let csr_vals = csr.values();
192
193 for j in 0..max_len {
195 for local_r in 0..c {
196 let global_r = row_start + local_r;
197 if local_r < actual_rows {
198 let row_start_idx = csr_off[global_r] as usize;
199 let row_len = (csr_off[global_r + 1] - csr_off[global_r]) as usize;
200 if j < row_len {
201 col_indices.push(csr_cols[row_start_idx + j]);
202 values.push(csr_vals[row_start_idx + j]);
203 } else {
204 col_indices.push(0);
205 values.push(0.0);
206 }
207 } else {
208 col_indices.push(0);
210 values.push(0.0);
211 }
212 }
213 }
214}
215
/// Accumulates one slice's contribution `alpha * A_slice * x` into `y`.
///
/// The slice starts at `base` in `col_indices` / `values`, holds `width`
/// entry positions of `c` lanes each, and maps lane `k` to output row
/// `row_start + k`. Only the first `row_end - row_start` lanes are read;
/// every slot in those lanes is accumulated, whatever its value.
// The slice geometry is passed as plain scalars, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
fn spmv_slice(
    col_indices: &[u32],
    values: &[f32],
    x: &[f32],
    y: &mut [f32],
    alpha: f32,
    base: usize,
    c: usize,
    width: usize,
    row_start: usize,
    row_end: usize,
) {
    for depth in 0..width {
        // First slot of this entry position; consecutive lanes are adjacent.
        let first_slot = base + depth * c;
        let live_rows = row_end - row_start;

        for (slot, row) in (first_slot..first_slot + live_rows).zip(row_start..row_end) {
            let column = col_indices[slot] as usize;
            let coeff = values[slot];
            y[row] += alpha * coeff * x[column];
        }
    }
}