1use crate::csr::CsrMatrix;
13use crate::error::SparseError;
14
/// Sparse matrix stored in a sliced, padded layout built from a CSR matrix.
///
/// Rows are grouped into slices of `slice_size` consecutive rows. Within a
/// slice every row is padded to the length of the slice's longest row, and
/// the slots are stored position-major: the slot for local row `r` at
/// position `j` of slice `s` lives at
/// `slice_offsets[s] + j * slice_size + r` in both `col_indices` and `values`.
#[derive(Clone, Debug)]
pub struct SellMatrix {
    /// Number of matrix rows.
    rows: usize,
    /// Number of matrix columns.
    cols: usize,
    /// Rows per slice; a requested size of 0 is stored as 1.
    slice_size: usize,
    /// Number of slices, i.e. `rows` divided by `slice_size`, rounded up.
    num_slices: usize,
    /// Start index of each slice in `col_indices` / `values`, followed by one
    /// trailing entry holding the total slot count (`num_slices + 1` entries).
    slice_offsets: Vec<u32>,
    /// Padded row length of each slice (length of its longest row).
    slice_widths: Vec<u32>,
    /// Column index of every slot; padding slots hold 0.
    col_indices: Vec<u32>,
    /// Value of every slot; padding slots hold 0.0.
    values: Vec<f32>,
}
36
37impl SellMatrix {
38 #[must_use]
42 pub fn from_csr(csr: &CsrMatrix<f32>, slice_size: usize) -> Self {
43 let rows = csr.rows();
44 let cols = csr.cols();
45 let c = if slice_size == 0 { 1 } else { slice_size };
46 let num_slices = rows.div_ceil(c);
47
48 let mut slice_offsets = Vec::with_capacity(num_slices + 1);
49 let mut slice_widths = Vec::with_capacity(num_slices);
50 let mut col_indices = Vec::new();
51 let mut values = Vec::new();
52
53 slice_offsets.push(0u32);
54
55 for s in 0..num_slices {
56 let row_start = s * c;
57 let row_end = (row_start + c).min(rows);
58 let actual_rows = row_end - row_start;
59
60 let max_len = compute_slice_width(csr, row_start, row_end);
62 slice_widths.push(max_len as u32);
63
64 fill_slice_data(csr, row_start, actual_rows, c, max_len, &mut col_indices, &mut values);
66
67 let slice_elements = c * max_len;
68 let offset = slice_offsets.last().copied().unwrap_or(0);
69 slice_offsets.push(offset + slice_elements as u32);
70 }
71
72 Self {
73 rows,
74 cols,
75 slice_size: c,
76 num_slices,
77 slice_offsets,
78 slice_widths,
79 col_indices,
80 values,
81 }
82 }
83
84 pub fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError> {
90 if x.len() != self.cols {
91 return Err(SparseError::SpMVDimensionMismatch {
92 matrix_cols: self.cols,
93 x_len: x.len(),
94 });
95 }
96 if y.len() != self.rows {
97 return Err(SparseError::SpMVOutputDimensionMismatch {
98 matrix_rows: self.rows,
99 y_len: y.len(),
100 });
101 }
102
103 for val in y.iter_mut() {
105 *val *= beta;
106 }
107
108 let c = self.slice_size;
109
110 for s in 0..self.num_slices {
111 let base = self.slice_offsets[s] as usize;
112 let width = self.slice_widths[s] as usize;
113 let row_start = s * c;
114 let row_end = (row_start + c).min(self.rows);
115
116 spmv_slice(
117 &self.col_indices,
118 &self.values,
119 x,
120 y,
121 alpha,
122 base,
123 c,
124 width,
125 row_start,
126 row_end,
127 );
128 }
129
130 Ok(())
131 }
132
133 #[must_use]
135 pub fn rows(&self) -> usize {
136 self.rows
137 }
138
139 #[must_use]
141 pub fn cols(&self) -> usize {
142 self.cols
143 }
144
145 #[must_use]
147 pub fn slice_size(&self) -> usize {
148 self.slice_size
149 }
150
151 #[must_use]
153 pub fn storage_size(&self) -> usize {
154 self.values.len()
155 }
156}
157
158fn compute_slice_width(csr: &CsrMatrix<f32>, row_start: usize, row_end: usize) -> usize {
160 let offsets = csr.offsets();
161 let mut max_len = 0usize;
162 for r in row_start..row_end {
163 let len = (offsets[r + 1] - offsets[r]) as usize;
164 if len > max_len {
165 max_len = len;
166 }
167 }
168 max_len
169}
170
171fn fill_slice_data(
173 csr: &CsrMatrix<f32>,
174 row_start: usize,
175 actual_rows: usize,
176 c: usize,
177 max_len: usize,
178 col_indices: &mut Vec<u32>,
179 values: &mut Vec<f32>,
180) {
181 let csr_off = csr.offsets();
182 let csr_cols = csr.col_indices();
183 let csr_vals = csr.values();
184
185 for j in 0..max_len {
187 for local_r in 0..c {
188 let global_r = row_start + local_r;
189 if local_r < actual_rows {
190 let row_start_idx = csr_off[global_r] as usize;
191 let row_len = (csr_off[global_r + 1] - csr_off[global_r]) as usize;
192 if j < row_len {
193 col_indices.push(csr_cols[row_start_idx + j]);
194 values.push(csr_vals[row_start_idx + j]);
195 } else {
196 col_indices.push(0);
197 values.push(0.0);
198 }
199 } else {
200 col_indices.push(0);
202 values.push(0.0);
203 }
204 }
205 }
206}
207
/// Accumulates one slice's contribution into `y`.
///
/// For every position `j` in `0..width` and every local row in
/// `0..(row_end - row_start)`, the slot at `base + j * c + local_row` is read
/// from `col_indices` / `values` and `alpha * value * x[col]` is added to
/// `y[row_start + local_row]`. Every slot in that range is processed the same
/// way, whatever it holds. Positions are visited in increasing `j`, so each
/// output row accumulates its terms in storage order.
///
/// # Panics
///
/// Panics if a slot index, a column index, or an output row index is out of
/// bounds for the corresponding slice.
// The kernel needs every one of these scalars and slices; bundling them into
// a struct would only add indirection.
#[allow(clippy::too_many_arguments)]
fn spmv_slice(
    col_indices: &[u32],
    values: &[f32],
    x: &[f32],
    y: &mut [f32],
    alpha: f32,
    base: usize,
    c: usize,
    width: usize,
    row_start: usize,
    row_end: usize,
) {
    for j in 0..width {
        // Slots for position `j` are stored contiguously, one per lane of the slice.
        let column_base = base + j * c;
        for local_r in 0..(row_end - row_start) {
            let slot = column_base + local_r;
            let col = col_indices[slot] as usize;
            let contribution = alpha * values[slot] * x[col];
            y[row_start + local_r] += contribution;
        }
    }
}