1use crate::{Flat, HardwareField};
19use alloc::vec::Vec;
20use core::arch::asm;
21use core::mem::MaybeUninit;
22#[cfg(feature = "parallel")]
23use rayon::prelude::*;
24
/// Number of output rows handed to a single `process_chunk` call; `spmv`
/// splits its output vector into chunks of this size in both the sequential
/// and the rayon code path.
const CHUNK_SIZE: usize = 1024;

/// Minimum row count at which `spmv` takes the rayon-parallel path.
#[cfg(feature = "parallel")]
const PARALLEL_THRESHOLD: usize = 32768;

/// How many rows ahead of the current row `process_chunk` issues prefetch
/// hints for.
const LOOKAHEAD: usize = 8;
39
/// Read-only, index-addressable sequence of `F` values.
///
/// `ByteSparseMatrix::spmv` reads its input vector exclusively through this
/// trait, so the input may be an in-memory slice or any type that produces
/// elements on demand. The `Sync` supertrait is required because the parallel
/// `spmv` path shares one source between worker threads.
pub trait VectorSource<F>: Sync {
    /// Number of addressable elements; valid indices are `0..len()`.
    fn len(&self) -> usize;

    /// Returns `true` when the source holds no elements.
    ///
    /// The default derives the answer from [`len`](Self::len), so implementors
    /// only need to override it when a cheaper check exists.
    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the element stored at `index`.
    ///
    /// Implementations may panic when `index >= len()` (the slice
    /// implementation does).
    fn get_at(&self, index: usize) -> F;

    /// Fetches `N` elements at once, in the order given by `indices`.
    ///
    /// The default simply calls [`get_at`](Self::get_at) once per index;
    /// implementors may override it with a faster gather.
    #[inline(always)]
    fn get_batch<const N: usize>(&self, indices: &[usize; N]) -> [F; N] {
        core::array::from_fn(|i| self.get_at(indices[i]))
    }

    /// Hints that the elements at `indices` will be read soon.
    ///
    /// Purely advisory: the default does nothing.
    #[inline(always)]
    fn prefetch(&self, _indices: &[usize]) {}
}
65
66impl<F: Copy + Sync> VectorSource<F> for [F] {
69 #[inline(always)]
70 fn len(&self) -> usize {
71 self.len()
72 }
73
74 #[inline(always)]
75 fn is_empty(&self) -> bool {
76 self.is_empty()
77 }
78
79 #[inline(always)]
80 fn get_at(&self, index: usize) -> F {
81 self[index]
82 }
83
84 #[inline(always)]
86 fn prefetch(&self, indices: &[usize]) {
87 let base_ptr = self.as_ptr();
88 for &idx in indices {
89 unsafe {
90 let ptr = base_ptr.wrapping_add(idx) as *const u8;
91
92 #[cfg(target_arch = "aarch64")]
94 asm!(
95 "prfm pldl1keep, [{p}]",
96 p = in(reg) ptr,
97 options(nostack, preserves_flags, readonly)
98 );
99
100 #[cfg(target_arch = "x86_64")]
102 asm!(
103 "prefetcht0 [{p}]",
104 p = in(reg) ptr,
105 options(nostack, preserves_flags, readonly)
106 );
107 }
108 }
109 }
110}
111
/// Sparse matrix that stores exactly `degree` entries for every row, each
/// entry being a one-byte weight plus a column index.
///
/// Entries are laid out row after row: the entries of row `r` occupy
/// positions `r * degree .. (r + 1) * degree` of both `weights` and
/// `col_indices`. `ByteSparseMatrix::new` validates the lengths, the weight
/// values and the column range before the struct is built.
#[derive(Clone, Debug)]
pub struct ByteSparseMatrix {
    /// Number of rows (length of the vector produced by `spmv`).
    rows: usize,
    /// Number of columns (required length of the vector passed to `spmv`).
    cols: usize,
    /// Number of stored entries per row.
    degree: usize,

    /// One weight per stored entry; `new` only accepts the values `0` and `1`.
    weights: Vec<u8>,

    /// Column of each stored entry; `new` checks every value is `< cols`.
    col_indices: Vec<u32>,
}
134
135impl ByteSparseMatrix {
136 pub fn new(
139 rows: usize,
140 cols: usize,
141 degree: usize,
142 weights: Vec<u8>,
143 col_indices: Vec<u32>,
144 ) -> Self {
145 let expected_len = rows.checked_mul(degree).expect("Matrix size overflow");
146
147 assert_eq!(
148 weights.len(),
149 expected_len,
150 "Weights vector length mismatch"
151 );
152 assert_eq!(
153 col_indices.len(),
154 expected_len,
155 "Column indices vector length mismatch"
156 );
157 assert!(
158 weights.iter().all(|&w| w == 0 || w == 1),
159 "Virtual packing requires binary weights"
160 );
161
162 for &idx in &col_indices {
163 assert!(
164 (idx as usize) < cols,
165 "Column index {} exceeds matrix columns count {}",
166 idx,
167 cols
168 );
169 }
170
171 Self {
172 rows,
173 cols,
174 degree,
175 weights,
176 col_indices,
177 }
178 }
179
180 #[inline]
181 pub fn rows(&self) -> usize {
182 self.rows
183 }
184
185 #[inline]
186 pub fn cols(&self) -> usize {
187 self.cols
188 }
189
190 #[inline]
191 pub fn degree(&self) -> usize {
192 self.degree
193 }
194
195 #[inline]
196 pub fn weights(&self) -> &[u8] {
197 &self.weights
198 }
199
200 #[inline]
201 pub fn col_indices(&self) -> &[u32] {
202 &self.col_indices
203 }
204
205 pub fn spmv<F, V>(&self, x: &V) -> Vec<Flat<F>>
210 where
211 F: HardwareField,
212 V: VectorSource<Flat<F>> + ?Sized,
213 {
214 assert_eq!(x.len(), self.cols);
215
216 let mut y: Vec<MaybeUninit<Flat<F>>> = Vec::with_capacity(self.rows);
217
218 unsafe {
222 y.set_len(self.rows);
223 }
224
225 #[cfg(feature = "parallel")]
226 if self.rows >= PARALLEL_THRESHOLD {
227 y.par_chunks_mut(CHUNK_SIZE)
228 .enumerate()
229 .for_each(|(chunk_id, out_chunk)| {
230 let start_row = chunk_id * CHUNK_SIZE;
231 self.process_chunk(start_row, out_chunk, x);
232 });
233
234 return unsafe { assume_init_vec(y) };
237 }
238
239 for (chunk_id, out_chunk) in y.chunks_mut(CHUNK_SIZE).enumerate() {
240 let start_row = chunk_id * CHUNK_SIZE;
241 self.process_chunk(start_row, out_chunk, x);
242 }
243
244 unsafe { assume_init_vec(y) }
245 }
246
247 #[inline(always)]
250 fn process_chunk<F, V>(&self, start_row: usize, out_chunk: &mut [MaybeUninit<Flat<F>>], x: &V)
251 where
252 F: HardwareField + Default + Copy,
253 V: VectorSource<Flat<F>> + ?Sized,
254 {
255 for i in 0..out_chunk.len() {
260 let row_idx = start_row + i;
261
262 if i + LOOKAHEAD < out_chunk.len() {
266 let next_row = row_idx + LOOKAHEAD;
267 let row_offset = next_row * self.degree;
268
269 unsafe {
271 for k in 0..self.degree {
272 let col_idx = *self.col_indices.get_unchecked(row_offset + k) as usize;
273 x.prefetch(&[col_idx]);
274 }
275 }
276 }
277
278 const B: usize = 8; let row_offset = row_idx * self.degree;
282
283 let mut acc = Flat::from_raw(F::ZERO);
284 let mut j = 0;
285
286 while j + B <= self.degree {
289 let mut col_idxs = [0usize; B];
290 unsafe {
291 for (k, slot) in col_idxs.iter_mut().enumerate() {
292 *slot = *self.col_indices.get_unchecked(row_offset + j + k) as usize;
293 }
294 }
295
296 let values = x.get_batch::<B>(&col_idxs);
297 unsafe {
298 for (k, &val) in values.iter().enumerate() {
299 if *self.weights.get_unchecked(row_offset + j + k) != 0 {
300 acc += val;
301 }
302 }
303 }
304
305 j += B;
306 }
307
308 while j < self.degree {
309 unsafe {
310 let curr = row_offset + j;
311 if *self.weights.get_unchecked(curr) != 0 {
312 let col_idx = *self.col_indices.get_unchecked(curr) as usize;
313 acc += x.get_at(col_idx);
314 }
315 }
316
317 j += 1;
318 }
319
320 out_chunk[i].write(acc);
321 }
322 }
323}
324
/// Reinterprets a `Vec<MaybeUninit<T>>` as a `Vec<T>` without copying,
/// reusing the same allocation, length and capacity.
///
/// # Safety
///
/// Every element of `v` (positions `0..v.len()`) must have been initialised
/// before this is called; otherwise the returned vector exposes
/// uninitialised `T` values.
#[inline]
unsafe fn assume_init_vec<T>(v: Vec<MaybeUninit<T>>) -> Vec<T> {
    // `ManuallyDrop` gives up ownership of the buffer up front, so the raw
    // parts are read from a vector that will never free them. This is the
    // form the standard library documentation recommends over taking the
    // pointer first and calling `mem::forget` afterwards.
    let mut v = core::mem::ManuallyDrop::new(v);

    let ptr = v.as_mut_ptr().cast::<T>();
    let len = v.len();
    let cap = v.capacity();

    // SAFETY: `ptr`, `len` and `cap` come from a live `Vec` whose destructor
    // will not run; `MaybeUninit<T>` has the same size and alignment as `T`;
    // the caller guarantees the first `len` elements are initialised.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Block128, HardwareField};
    use alloc::vec;

    /// Storage-free vector source: element `i` is `Block128::from(i * multiplier)`
    /// passed through `to_hardware()`.
    struct VirtualLinearSource {
        size: usize,
        multiplier: u128,
    }

    impl VectorSource<Flat<Block128>> for VirtualLinearSource {
        fn len(&self) -> usize {
            self.size
        }

        // Implemented for real (it used to be `unimplemented!()`), so the
        // source honours the whole trait contract.
        fn is_empty(&self) -> bool {
            self.size == 0
        }

        fn get_at(&self, index: usize) -> Flat<Block128> {
            Block128::from((index as u128) * self.multiplier).to_hardware()
        }
    }

    /// Shorthand for building a `Block128` from an integer literal.
    fn b128(v: u128) -> Block128 {
        Block128::from(v)
    }

    #[test]
    fn spmv_with_virtual_source() {
        // Row 0 reads columns {0, 1}, row 1 reads columns {1, 0}.
        let weights = vec![1u8, 1u8, 1u8, 1u8];
        let col_indices = vec![0, 1, 1, 0];

        let matrix = ByteSparseMatrix::new(2, 2, 2, weights, col_indices);

        let source = VirtualLinearSource {
            size: 2,
            multiplier: 10,
        };

        // x = [0, 10], so both rows are expected to equal 10.
        let expected_val = Block128::from(10u128).to_hardware();
        let expected = vec![expected_val, expected_val];

        let res = matrix.spmv(&source);

        assert_eq!(res, expected, "SpMV failed with VirtualSource");
    }

    #[test]
    fn byte_sparse_matrix_spmv() {
        let weights = vec![1u8, 1u8, 1u8, 1u8];

        // Row 0 reads columns {0, 2}, row 1 reads columns {1, 0}.
        let col_indices = vec![0, 2, 1, 0];

        let matrix = ByteSparseMatrix::new(2, 3, 2, weights, col_indices);

        let x0_tower = b128(10);
        let x1_tower = b128(100);
        let x2_tower = b128(255);

        let x = vec![
            x0_tower.to_hardware(),
            x1_tower.to_hardware(),
            x2_tower.to_hardware(),
        ];

        let y0_tower = x0_tower + x2_tower;

        let y1_tower = x1_tower + x0_tower;

        let expected = vec![y0_tower.to_hardware(), y1_tower.to_hardware()];
        let res = matrix.spmv(x.as_slice());

        assert_eq!(res, expected, "Sequential SpMV failed (Basis Mismatch?)");
    }

    #[test]
    fn zero_weight_entries_contribute_nothing() {
        // Row 0 keeps columns 0 and 2; row 1 keeps only column 1.
        let weights = vec![1, 0, 1, 0, 1, 0];
        let col_indices = vec![0, 1, 2, 0, 1, 2];
        let matrix = ByteSparseMatrix::new(2, 3, 3, weights, col_indices);

        let x0 = b128(0xA0);
        let x1 = b128(0xB0);
        let x2 = b128(0xC0);
        let x = vec![x0.to_hardware(), x1.to_hardware(), x2.to_hardware()];

        let expected = vec![(x0 + x2).to_hardware(), x1.to_hardware()];

        assert_eq!(matrix.spmv(x.as_slice()), expected);
    }

    /// Exercises the code paths the small fixtures above never reach: a
    /// degree of 10 gives one full 8-entry batch plus a 2-entry tail, and
    /// 12 rows make the lookahead prefetch branch run for the first rows.
    #[test]
    fn batched_rows_match_naive_sum() {
        const ROWS: usize = 12;
        const COLS: usize = 16;
        const DEGREE: usize = 10;

        let mut weights = Vec::with_capacity(ROWS * DEGREE);
        let mut col_indices = Vec::with_capacity(ROWS * DEGREE);
        for r in 0..ROWS {
            for k in 0..DEGREE {
                weights.push(u8::from((r + k) % 3 != 0));
                col_indices.push(((r * 7 + k * 3) % COLS) as u32);
            }
        }

        let x_tower: Vec<Block128> = (0..COLS)
            .map(|c| b128((c as u128 + 1) * 0x0101))
            .collect();
        let x: Vec<_> = x_tower.iter().map(|&v| v.to_hardware()).collect();

        // Reference result, summed in the tower representation exactly like
        // the other tests in this module do.
        let expected: Vec<_> = (0..ROWS)
            .map(|r| {
                let mut sum = b128(0);
                for k in 0..DEGREE {
                    let entry = r * DEGREE + k;
                    if weights[entry] != 0 {
                        sum = sum + x_tower[col_indices[entry] as usize];
                    }
                }
                sum.to_hardware()
            })
            .collect();

        let matrix = ByteSparseMatrix::new(ROWS, COLS, DEGREE, weights, col_indices);

        assert_eq!(matrix.spmv(x.as_slice()), expected);
    }

    #[test]
    #[should_panic(expected = "binary weights")]
    fn rejects_non_binary_weights() {
        ByteSparseMatrix::new(1, 2, 2, vec![1, 3], vec![0, 1]);
    }
}