1use crate::{Flat, HardwareField};
19use aes::Aes256;
20use aes::cipher::{BlockCipherEncrypt, KeyInit};
21use alloc::vec::Vec;
22use core::arch::asm;
23use core::convert::Infallible;
24use core::mem::MaybeUninit;
25use rand::{RngExt, SeedableRng, TryRng};
26#[cfg(feature = "parallel")]
27use rayon::prelude::*;
28
/// Rows computed per output chunk in `ByteSparseMatrix::spmv`.
const CHUNK_SIZE: usize = 1 << 10;

/// Minimum row count at which `spmv` spreads its chunks over the rayon pool.
#[cfg(feature = "parallel")]
const PARALLEL_THRESHOLD: usize = 1 << 15;

/// How many rows ahead of the current row `spmv` issues column prefetch hints.
const LOOKAHEAD: usize = 8;

/// Rows per RNG stream in `ByteSparseMatrix::generate_random`: chunk `k` of this many rows
/// draws from AES-CTR stream `k`, so changing the value changes generated matrices whose
/// rows span more than one chunk.
const GEN_CHUNK_ROWS: usize = 1 << 8;

/// AES block size in bytes.
const AES_BLOCK: usize = 16;

/// Counter blocks encrypted per keystream refill.
const AES_BATCH: usize = 8;

/// Size in bytes of the buffered keystream.
const AES_BUF_SIZE: usize = AES_BLOCK * AES_BATCH;
56
/// A random-access, read-only source of vector entries, consumed by
/// [`ByteSparseMatrix::spmv`].
///
/// Implementors must provide [`len`](VectorSource::len) and
/// [`get_at`](VectorSource::get_at). The other methods have defaults and may be
/// overridden, for example to batch lookups or to issue prefetch hints.
pub trait VectorSource<F>: Sync {
    /// Number of entries in the vector.
    fn len(&self) -> usize;

    /// Returns `true` when the vector has no entries.
    ///
    /// Defaults to `self.len() == 0`.
    #[inline]
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the entry at `index`.
    fn get_at(&self, index: usize) -> F;

    /// Fetches `N` entries at once.
    ///
    /// The default calls [`get_at`](VectorSource::get_at) once per index, in order.
    #[inline(always)]
    fn get_batch<const N: usize>(&self, indices: &[usize; N]) -> [F; N] {
        core::array::from_fn(|i| self.get_at(indices[i]))
    }

    /// Hints that the entries at `indices` will be read soon.
    ///
    /// The default does nothing. Overrides are intended purely as performance hints.
    #[inline(always)]
    fn prefetch(&self, _indices: &[usize]) {}
}
82
83impl<F: Copy + Sync> VectorSource<F> for [F] {
86 #[inline(always)]
87 fn len(&self) -> usize {
88 self.len()
89 }
90
91 #[inline(always)]
92 fn is_empty(&self) -> bool {
93 self.is_empty()
94 }
95
96 #[inline(always)]
97 fn get_at(&self, index: usize) -> F {
98 self[index]
99 }
100
101 #[inline(always)]
103 fn prefetch(&self, indices: &[usize]) {
104 let base_ptr = self.as_ptr();
105 for &idx in indices {
106 unsafe {
107 let ptr = base_ptr.wrapping_add(idx) as *const u8;
108
109 #[cfg(target_arch = "aarch64")]
111 asm!(
112 "prfm pldl1keep, [{p}]",
113 p = in(reg) ptr,
114 options(nostack, preserves_flags, readonly)
115 );
116
117 #[cfg(target_arch = "x86_64")]
119 asm!(
120 "prefetcht0 [{p}]",
121 p = in(reg) ptr,
122 options(nostack, preserves_flags, readonly)
123 );
124 }
125 }
126 }
127}
128
/// Sparse matrix with exactly `degree` stored entries per row and binary (`0`/`1`) weights.
///
/// Entries are stored row-major: row `r` occupies indices `r * degree .. (r + 1) * degree`
/// of both `weights` and `col_indices`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ByteSparseMatrix {
    /// Number of rows.
    rows: usize,
    /// Number of columns; every entry of `col_indices` is `< cols`.
    cols: usize,
    /// Number of stored entries per row.
    degree: usize,
    /// Per-entry weight, each `0` or `1`; length `rows * degree`.
    weights: Vec<u8>,
    /// Per-entry column index; length `rows * degree`.
    col_indices: Vec<u32>,
}
151
152impl ByteSparseMatrix {
153 pub fn new(
156 rows: usize,
157 cols: usize,
158 degree: usize,
159 weights: Vec<u8>,
160 col_indices: Vec<u32>,
161 ) -> Self {
162 let expected_len = rows.checked_mul(degree).expect("Matrix size overflow");
163
164 assert_eq!(
165 weights.len(),
166 expected_len,
167 "Weights vector length mismatch"
168 );
169 assert_eq!(
170 col_indices.len(),
171 expected_len,
172 "Column indices vector length mismatch"
173 );
174 assert!(
175 weights.iter().all(|&w| w == 0 || w == 1),
176 "Virtual packing requires binary weights"
177 );
178
179 for &idx in &col_indices {
180 assert!(
181 (idx as usize) < cols,
182 "Column index {} exceeds matrix columns count {}",
183 idx,
184 cols
185 );
186 }
187
188 Self {
189 rows,
190 cols,
191 degree,
192 weights,
193 col_indices,
194 }
195 }
196
197 pub fn generate_random(rows: usize, cols: usize, degree: usize, seed: [u8; 32]) -> Self {
199 const MAX_DEGREE: usize = 256;
200 assert!(
201 degree <= MAX_DEGREE,
202 "Expander degree exceeds stack buffer size"
203 );
204
205 assert!(
209 cols > 0,
210 "Matrix generation requires cols > 0 (division by zero in RNG)"
211 );
212 assert!(
213 degree <= cols,
214 "Expander degree cannot exceed cols (would cause infinite loop in generation)"
215 );
216
217 let total_elems = rows
218 .checked_mul(degree)
219 .expect("Matrix size overflow: rows * degree exceeds usize::MAX");
220
221 if total_elems == 0 {
222 return Self {
223 rows,
224 cols,
225 degree,
226 weights: Vec::new(),
227 col_indices: Vec::new(),
228 };
229 }
230
231 let mut weights: Vec<u8> = Vec::with_capacity(total_elems);
232 let mut col_indices: Vec<u32> = Vec::with_capacity(total_elems);
233
234 let weights_uninit = weights.spare_capacity_mut();
235 let col_indices_uninit = col_indices.spare_capacity_mut();
236
237 debug_assert!(weights_uninit.len() >= total_elems);
238 debug_assert!(col_indices_uninit.len() >= total_elems);
239
240 #[cfg(feature = "parallel")]
241 {
242 let rows_per_chunk = GEN_CHUNK_ROWS.min(rows.max(1));
243 let aligned_chunk_len = rows_per_chunk * degree;
244
245 weights_uninit[..total_elems]
246 .par_chunks_mut(aligned_chunk_len)
247 .zip(col_indices_uninit[..total_elems].par_chunks_mut(aligned_chunk_len))
248 .enumerate()
249 .for_each(|(chunk_id, (w_chunk, col_chunk))| {
250 let rows_in_this_chunk = w_chunk.len() / degree;
251
252 let mut rng = AesCtrPrg::from_seed(seed);
253 rng.set_stream(chunk_id as u64);
254
255 let mut used_cols = [0u32; MAX_DEGREE];
256 for r in 0..rows_in_this_chunk {
257 let row_offset = r * degree;
258
259 for d in 0..degree {
260 w_chunk[row_offset + d].write(1u8);
261
262 let mut col_idx;
263 loop {
264 col_idx = rng.random_range(0..cols as u32);
265
266 if !used_cols[..d].contains(&col_idx) {
273 break;
274 }
275 }
276
277 used_cols[d] = col_idx;
278 col_chunk[row_offset + d].write(col_idx);
279 }
280 }
281 });
282 }
283
284 #[cfg(not(feature = "parallel"))]
285 {
286 let rows_per_chunk = GEN_CHUNK_ROWS.min(rows.max(1));
287 let aligned_chunk_len = rows_per_chunk * degree;
288 let num_chunks = total_elems.div_ceil(aligned_chunk_len);
289
290 let mut used_cols = [0u32; MAX_DEGREE];
291 for chunk_id in 0..num_chunks {
292 let mut rng = AesCtrPrg::from_seed(seed);
293 rng.set_stream(chunk_id as u64);
294
295 let elem_start = chunk_id * aligned_chunk_len;
296 let elem_end = (elem_start + aligned_chunk_len).min(total_elems);
297 let rows_in_this_chunk = (elem_end - elem_start) / degree;
298
299 for r in 0..rows_in_this_chunk {
300 let row_offset = elem_start + r * degree;
301
302 for d in 0..degree {
303 weights_uninit[row_offset + d].write(1u8);
304
305 let mut col_idx;
306 loop {
307 col_idx = rng.random_range(0..cols as u32);
308 if !used_cols[..d].contains(&col_idx) {
309 break;
310 }
311 }
312
313 used_cols[d] = col_idx;
314 col_indices_uninit[row_offset + d].write(col_idx);
315 }
316 }
317 }
318 }
319
320 unsafe {
325 weights.set_len(total_elems);
326 col_indices.set_len(total_elems);
327 }
328
329 assert!(
330 weights.iter().all(|&w| w == 0 || w == 1),
331 "Binary weight invariant violated in generate_random"
332 );
333
334 Self {
335 rows,
336 cols,
337 degree,
338 weights,
339 col_indices,
340 }
341 }
342
343 #[inline]
344 pub fn rows(&self) -> usize {
345 self.rows
346 }
347
348 #[inline]
349 pub fn cols(&self) -> usize {
350 self.cols
351 }
352
353 #[inline]
354 pub fn degree(&self) -> usize {
355 self.degree
356 }
357
358 #[inline]
359 pub fn weights(&self) -> &[u8] {
360 &self.weights
361 }
362
363 #[inline]
364 pub fn col_indices(&self) -> &[u32] {
365 &self.col_indices
366 }
367
368 pub fn spmv<F, V>(&self, x: &V) -> Vec<Flat<F>>
373 where
374 F: HardwareField,
375 V: VectorSource<Flat<F>> + ?Sized,
376 {
377 assert_eq!(x.len(), self.cols);
378
379 let mut y: Vec<MaybeUninit<Flat<F>>> = Vec::with_capacity(self.rows);
380
381 unsafe {
385 y.set_len(self.rows);
386 }
387
388 #[cfg(feature = "parallel")]
389 if self.rows >= PARALLEL_THRESHOLD {
390 y.par_chunks_mut(CHUNK_SIZE)
391 .enumerate()
392 .for_each(|(chunk_id, out_chunk)| {
393 let start_row = chunk_id * CHUNK_SIZE;
394 self.process_chunk(start_row, out_chunk, x);
395 });
396
397 return unsafe { assume_init_vec(y) };
400 }
401
402 for (chunk_id, out_chunk) in y.chunks_mut(CHUNK_SIZE).enumerate() {
403 let start_row = chunk_id * CHUNK_SIZE;
404 self.process_chunk(start_row, out_chunk, x);
405 }
406
407 unsafe { assume_init_vec(y) }
408 }
409
410 #[inline(always)]
413 fn process_chunk<F, V>(&self, start_row: usize, out_chunk: &mut [MaybeUninit<Flat<F>>], x: &V)
414 where
415 F: HardwareField + Default + Copy,
416 V: VectorSource<Flat<F>> + ?Sized,
417 {
418 for i in 0..out_chunk.len() {
423 let row_idx = start_row + i;
424
425 if i + LOOKAHEAD < out_chunk.len() {
429 let next_row = row_idx + LOOKAHEAD;
430 let row_offset = next_row * self.degree;
431
432 unsafe {
434 for k in 0..self.degree {
435 let col_idx = *self.col_indices.get_unchecked(row_offset + k) as usize;
436 x.prefetch(&[col_idx]);
437 }
438 }
439 }
440
441 const B: usize = 8; let row_offset = row_idx * self.degree;
445
446 let mut acc = Flat::from_raw(F::ZERO);
447 let mut j = 0;
448
449 while j + B <= self.degree {
452 let mut col_idxs = [0usize; B];
453 unsafe {
454 for (k, slot) in col_idxs.iter_mut().enumerate() {
455 *slot = *self.col_indices.get_unchecked(row_offset + j + k) as usize;
456 }
457 }
458
459 let values = x.get_batch::<B>(&col_idxs);
460 unsafe {
461 for (k, &val) in values.iter().enumerate() {
462 if *self.weights.get_unchecked(row_offset + j + k) != 0 {
463 acc += val;
464 }
465 }
466 }
467
468 j += B;
469 }
470
471 while j < self.degree {
472 unsafe {
473 let curr = row_offset + j;
474 if *self.weights.get_unchecked(curr) != 0 {
475 let col_idx = *self.col_indices.get_unchecked(curr) as usize;
476 acc += x.get_at(col_idx);
477 }
478 }
479
480 j += 1;
481 }
482
483 out_chunk[i].write(acc);
484 }
485 }
486}
487
488struct AesCtrPrg {
491 cipher: Aes256,
492 nonce: u64,
493 counter: u64,
494 buffer: [u8; AES_BUF_SIZE],
495 buf_pos: usize,
496}
497
498impl AesCtrPrg {
499 fn set_stream(&mut self, stream_id: u64) {
500 self.nonce = stream_id;
501 self.counter = 0;
502 self.buf_pos = AES_BUF_SIZE;
503 }
504
505 fn refill(&mut self) {
506 let nonce_high = (self.nonce as u128) << 64;
507
508 let mut blocks: [aes::Block; AES_BATCH] = Default::default();
509 for (i, block) in blocks.iter_mut().enumerate() {
510 let val = (self.counter + i as u64) as u128 | nonce_high;
511 *block = val.to_le_bytes().into();
512 }
513
514 self.cipher.encrypt_blocks(&mut blocks);
515
516 for (i, block) in blocks.iter().enumerate() {
517 self.buffer[i * AES_BLOCK..(i + 1) * AES_BLOCK].copy_from_slice(block.as_slice());
518 }
519
520 self.counter += AES_BATCH as u64;
521 self.buf_pos = 0;
522 }
523}
524
525impl SeedableRng for AesCtrPrg {
526 type Seed = [u8; 32];
527
528 fn from_seed(seed: [u8; 32]) -> Self {
529 Self {
530 cipher: Aes256::new(&seed.into()),
531 nonce: 0,
532 counter: 0,
533 buffer: [0u8; AES_BUF_SIZE],
534 buf_pos: AES_BUF_SIZE,
535 }
536 }
537}
538
539impl TryRng for AesCtrPrg {
540 type Error = Infallible;
541
542 fn try_next_u32(&mut self) -> Result<u32, Infallible> {
543 if self.buf_pos + 4 > AES_BUF_SIZE {
544 self.refill();
545 }
546
547 let p = self.buf_pos;
548 let val = u32::from_le_bytes(core::array::from_fn(|i| self.buffer[p + i]));
549
550 self.buf_pos = p + 4;
551
552 Ok(val)
553 }
554
555 fn try_next_u64(&mut self) -> Result<u64, Infallible> {
556 if self.buf_pos + 8 > AES_BUF_SIZE {
557 self.refill();
558 }
559
560 let p = self.buf_pos;
561 let val = u64::from_le_bytes(core::array::from_fn(|i| self.buffer[p + i]));
562
563 self.buf_pos = p + 8;
564
565 Ok(val)
566 }
567
568 fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Infallible> {
569 let mut written = 0;
570 while written < dst.len() {
571 if self.buf_pos >= AES_BUF_SIZE {
572 self.refill();
573 }
574
575 let available = AES_BUF_SIZE - self.buf_pos;
576 let copy_len = available.min(dst.len() - written);
577
578 dst[written..written + copy_len]
579 .copy_from_slice(&self.buffer[self.buf_pos..self.buf_pos + copy_len]);
580
581 self.buf_pos += copy_len;
582 written += copy_len;
583 }
584
585 Ok(())
586 }
587}
588
/// Converts a fully initialized `Vec<MaybeUninit<T>>` into a `Vec<T>` without copying.
///
/// The allocation is reused with the same pointer, length and capacity.
///
/// # Safety
///
/// All `v.len()` elements must be initialized. `MaybeUninit<T>` has the same size and
/// alignment as `T`, so the allocation layout is valid for `Vec<T>`.
#[inline]
unsafe fn assume_init_vec<T>(v: Vec<MaybeUninit<T>>) -> Vec<T> {
    // `ManuallyDrop` guarantees the original `Vec` is never dropped once its buffer is
    // re-owned below. This is the pattern recommended by `Vec::from_raw_parts`, and it avoids
    // moving the vector into `mem::forget` after its pointer has been taken.
    let mut v = core::mem::ManuallyDrop::new(v);
    let (ptr, len, cap) = (v.as_mut_ptr().cast::<T>(), v.len(), v.capacity());

    // SAFETY: `ptr` and `cap` describe a live allocation made by `Vec` for
    // `MaybeUninit<T>`, which is layout-identical to `T`. `v` is never dropped. The caller
    // guarantees that the first `len` elements are initialized.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
599
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Block128, HardwareField};
    use alloc::vec;
    use proptest::prelude::*;

    /// Vector source whose entry `i` is `i * multiplier`, computed on demand.
    struct VirtualLinearSource {
        size: usize,
        multiplier: u128,
    }

    impl VectorSource<Flat<Block128>> for VirtualLinearSource {
        fn len(&self) -> usize {
            self.size
        }

        fn is_empty(&self) -> bool {
            self.size == 0
        }

        fn get_at(&self, index: usize) -> Flat<Block128> {
            Block128::from((index as u128) * self.multiplier).to_hardware()
        }
    }

    fn b128(v: u128) -> Block128 {
        Block128::from(v)
    }

    #[test]
    fn virtual_source_reports_emptiness() {
        let empty = VirtualLinearSource {
            size: 0,
            multiplier: 1,
        };
        let full = VirtualLinearSource {
            size: 3,
            multiplier: 1,
        };
        assert!(empty.is_empty());
        assert!(!full.is_empty());
    }

    #[test]
    fn spmv_with_virtual_source() {
        let weights = vec![1u8, 1u8, 1u8, 1u8];
        let col_indices = vec![0, 1, 1, 0];

        let matrix = ByteSparseMatrix::new(2, 2, 2, weights, col_indices);

        let source = VirtualLinearSource {
            size: 2,
            multiplier: 10,
        };

        let expected_val = Block128::from(10u128).to_hardware();
        let expected = vec![expected_val, expected_val];

        let res = matrix.spmv(&source);

        assert_eq!(res, expected, "SpMV failed with VirtualSource");
    }

    #[test]
    fn byte_sparse_matrix_spmv() {
        let weights = vec![1u8, 1u8, 1u8, 1u8];

        // Row 0 references columns 0 and 2; row 1 references columns 1 and 0.
        let col_indices = vec![0, 2, 1, 0];

        let matrix = ByteSparseMatrix::new(2, 3, 2, weights, col_indices);

        let x0_tower = b128(10);
        let x1_tower = b128(100);
        let x2_tower = b128(255);

        let x = vec![
            x0_tower.to_hardware(),
            x1_tower.to_hardware(),
            x2_tower.to_hardware(),
        ];

        let y0_tower = x0_tower + x2_tower;

        let y1_tower = x1_tower + x0_tower;

        let expected = vec![y0_tower.to_hardware(), y1_tower.to_hardware()];
        let res = matrix.spmv(x.as_slice());

        assert_eq!(res, expected, "Sequential SpMV failed (Basis Mismatch?)");
    }

    #[test]
    fn zero_weight_entries_contribute_nothing() {
        let weights = vec![1, 0, 1, 0, 1, 0];
        let col_indices = vec![0, 1, 2, 0, 1, 2];
        let matrix = ByteSparseMatrix::new(2, 3, 3, weights, col_indices);

        let x0 = b128(0xA0);
        let x1 = b128(0xB0);
        let x2 = b128(0xC0);
        let x = vec![x0.to_hardware(), x1.to_hardware(), x2.to_hardware()];

        let expected = vec![(x0 + x2).to_hardware(), x1.to_hardware()];

        assert_eq!(matrix.spmv(x.as_slice()), expected);
    }

    #[test]
    #[should_panic(expected = "binary weights")]
    fn rejects_non_binary_weights() {
        ByteSparseMatrix::new(1, 2, 2, vec![1, 3], vec![0, 1]);
    }

    #[test]
    #[should_panic(expected = "exceeds matrix columns count")]
    fn rejects_out_of_range_column_index() {
        ByteSparseMatrix::new(1, 2, 2, vec![1, 1], vec![0, 2]);
    }

    #[test]
    #[should_panic(expected = "Column indices vector length mismatch")]
    fn rejects_mismatched_index_length() {
        ByteSparseMatrix::new(1, 2, 2, vec![1, 1], vec![0]);
    }

    #[test]
    #[should_panic(expected = "cols > 0")]
    fn safety_rejects_zero_cols() {
        ByteSparseMatrix::generate_random(10, 0, 5, [1u8; 32]);
    }

    #[test]
    #[should_panic(expected = "Expander degree cannot exceed cols")]
    fn rejects_degree_above_cols() {
        ByteSparseMatrix::generate_random(4, 3, 4, [7u8; 32]);
    }

    #[test]
    fn accepts_valid_dimensions() {
        let m = ByteSparseMatrix::generate_random(10, 10, 5, [1u8; 32]);
        assert_eq!(m.rows(), 10);
        assert_eq!(m.cols(), 10);
        assert_eq!(m.degree(), 5);
        assert_eq!(m.weights().len(), 50);
    }

    #[test]
    fn accepts_zero_rows_or_degree() {
        let m1 = ByteSparseMatrix::generate_random(0, 10, 5, [1u8; 32]);
        assert_eq!(m1.weights().len(), 0);

        let m2 = ByteSparseMatrix::generate_random(10, 10, 0, [1u8; 32]);
        assert_eq!(m2.weights().len(), 0);
    }

    #[test]
    fn generated_entries_are_valid() {
        // 300 rows spans a full chunk plus a partial one (GEN_CHUNK_ROWS = 256).
        let m = ByteSparseMatrix::generate_random(300, 64, 8, [3u8; 32]);
        assert_eq!(m.weights().len(), 300 * 8);
        assert_eq!(m.col_indices().len(), 300 * 8);
        assert!(m.weights().iter().all(|&w| w == 1));
        assert!(m.col_indices().iter().all(|&c| (c as usize) < m.cols()));
    }

    #[test]
    fn expander_properties_sanity_check() {
        let rows = 4096;
        let cols = 4096;
        let degree = 16;
        let seed = [42u8; 32];

        let matrix = ByteSparseMatrix::generate_random(rows, cols, degree, seed);

        let hamming_weight = |vec: &[Flat<Block128>]| -> usize {
            vec.iter()
                .filter(|&&x| x != Block128::from(0u128).to_hardware())
                .count()
        };

        // Every probed column must reach at least one row.
        for i in 0..100 {
            let mut x = vec![Block128::from(0u128).to_hardware(); cols];
            x[i] = Block128::from(1u128).to_hardware();

            let y = matrix.spmv(x.as_slice());
            let w = hamming_weight(&y);

            assert!(w > 0, "Column {} is empty! Information loss", i);
        }

        // Adjacent column pairs should rarely share rows.
        let mut rng = AesCtrPrg::from_seed([1u8; 32]);
        let mut total_weight = 0;

        let trials = 100;
        for _ in 0..trials {
            let mut x = vec![Block128::from(0u128).to_hardware(); cols];

            let idx1 = rng.random_range(0..cols);
            let idx2 = (idx1 + 1) % cols;

            x[idx1] = Block128::from(1u128).to_hardware();
            x[idx2] = Block128::from(1u128).to_hardware();

            let y = matrix.spmv(x.as_slice());
            total_weight += hamming_weight(&y);
        }

        let avg_weight = total_weight as f64 / trials as f64;
        let expected_max = (degree * 2) as f64;

        assert!(
            avg_weight > (expected_max * 0.8),
            "Too many collisions! Poor expansion property. Avg: {}",
            avg_weight
        );

        // A weight-10 input should expand to most of its 10 * degree possible rows.
        let input_w = 10;
        let mut x = vec![Block128::from(0u128).to_hardware(); cols];

        for val in x.iter_mut().take(input_w) {
            *val = Block128::from(1u128).to_hardware();
        }

        let y = matrix.spmv(x.as_slice());
        let w_out = hamming_weight(&y);

        assert!(
            w_out > (input_w * degree * 8 / 10),
            "Weight-10 vector collapsed too much! Weight: {}",
            w_out
        );
    }

    #[test]
    fn check_determinism() {
        let seed = [42u8; 32];
        let rows = 1024;
        let cols = 1024;
        let degree = 16;

        let matrix1 = ByteSparseMatrix::generate_random(rows, cols, degree, seed);
        let matrix2 = ByteSparseMatrix::generate_random(rows, cols, degree, seed);

        assert_eq!(
            matrix1.weights(),
            matrix2.weights(),
            "Matrix weights must be deterministic for the same seed"
        );
        assert_eq!(
            matrix1.col_indices(),
            matrix2.col_indices(),
            "Matrix column indices must be deterministic for the same seed"
        );

        #[cfg(feature = "parallel")]
        {
            use rayon::ThreadPoolBuilder;

            let matrix_1thread = ThreadPoolBuilder::new()
                .num_threads(1)
                .build()
                .unwrap()
                .install(|| ByteSparseMatrix::generate_random(rows, cols, degree, seed));

            let matrix_8threads = ThreadPoolBuilder::new()
                .num_threads(8)
                .build()
                .unwrap()
                .install(|| ByteSparseMatrix::generate_random(rows, cols, degree, seed));

            assert_eq!(
                matrix_1thread.weights(),
                matrix_8threads.weights(),
                "Matrix must be identical regardless of thread count"
            );
            assert_eq!(
                matrix_1thread.col_indices(),
                matrix_8threads.col_indices(),
                "Matrix indices must be identical regardless of thread count"
            );
        }
    }

    #[test]
    fn security_prevent_expander_collapse() {
        // degree == cols: every row must be a permutation of all columns.
        let rows = 1000;
        let cols = 32;
        let degree = 32;
        let seed = [99u8; 32];

        let matrix = ByteSparseMatrix::generate_random(rows, cols, degree, seed);

        for r in 0..rows {
            let row_offset = r * degree;

            let mut row_indices: Vec<u32> =
                matrix.col_indices()[row_offset..row_offset + degree].to_vec();
            row_indices.sort_unstable();

            for d in 0..degree - 1 {
                assert_ne!(
                    row_indices[d],
                    row_indices[d + 1],
                    "Expander Collapse detected in row {}! Duplicate column index {}. \
                     The rejection sampling loop has been compromised.",
                    r,
                    row_indices[d]
                );
            }
        }
    }

    #[test]
    fn cross_feature_determinism_golden() {
        let matrix = ByteSparseMatrix::generate_random(1024, 512, 16, [42u8; 32]);

        #[rustfmt::skip]
        const EXPECTED: [u32; 64] = [
            442, 352, 465, 69, 176, 472, 322, 109,
            349, 216, 74, 35, 206, 50, 7, 443,
            349, 214, 30, 332, 66, 316, 297, 415,
            325, 88, 484, 345, 5, 224, 106, 326,
            454, 345, 295, 443, 267, 264, 91, 333,
            163, 359, 262, 49, 112, 499, 219, 67,
            420, 106, 415, 54, 437, 123, 366, 284,
            503, 249, 26, 353, 90, 29, 311, 111,
        ];

        assert_eq!(&matrix.col_indices()[..64], &EXPECTED);
    }

    #[test]
    fn aes_ctr_prg_golden() {
        #[rustfmt::skip]
        const EXPECTED: [u8; 128] = [
            0xdc, 0x95, 0xc0, 0x78, 0xa2, 0x40, 0x89, 0x89,
            0xad, 0x48, 0xa2, 0x14, 0x92, 0x84, 0x20, 0x87,
            0x52, 0x75, 0xf3, 0xd8, 0x6b, 0x4f, 0xb8, 0x68,
            0x45, 0x93, 0x13, 0x3e, 0xbf, 0xa5, 0x3c, 0xd3,
            0x77, 0x9b, 0x38, 0xd1, 0x5b, 0xff, 0xb6, 0x3d,
            0x8d, 0x60, 0x9d, 0x55, 0x1a, 0x5c, 0xc9, 0x8e,
            0x39, 0xd6, 0xe9, 0xae, 0x76, 0xa9, 0xb2, 0xf3,
            0xfc, 0x46, 0x26, 0x80, 0xf7, 0x66, 0x72, 0x0e,
            0x75, 0xd1, 0x1b, 0x0e, 0x3a, 0x68, 0xc4, 0x22,
            0x3d, 0x88, 0xdb, 0xf0, 0x17, 0x97, 0x7d, 0xd7,
            0x84, 0x5c, 0x7d, 0x46, 0x90, 0xfa, 0x59, 0x4f,
            0x90, 0xe6, 0x7f, 0x7b, 0x52, 0x11, 0xa5, 0x1a,
            0x6f, 0x87, 0x1f, 0x44, 0x5c, 0x18, 0xaf, 0xc2,
            0xf8, 0x93, 0x7a, 0xf8, 0x41, 0xfd, 0x2a, 0xd0,
            0x8d, 0x3a, 0xe1, 0x50, 0x22, 0x15, 0x52, 0x33,
            0x4d, 0xdb, 0x29, 0xfe, 0x36, 0xa0, 0xb7, 0x24,
        ];

        let mut prg = AesCtrPrg::from_seed([0u8; 32]);
        let mut output = [0u8; 128];

        let _ = prg.try_fill_bytes(&mut output);

        assert_eq!(output, EXPECTED);
    }

    #[test]
    fn aes_ctr_prg_stream_isolation() {
        let seed = [0xabu8; 32];

        let mut prg0 = AesCtrPrg::from_seed(seed);
        prg0.set_stream(0);

        let mut out0 = [0u8; 64];
        let _ = prg0.try_fill_bytes(&mut out0);

        let mut prg1 = AesCtrPrg::from_seed(seed);
        prg1.set_stream(1);

        let mut out1 = [0u8; 64];
        let _ = prg1.try_fill_bytes(&mut out1);

        assert_ne!(
            out0, out1,
            "Different streams must produce different output"
        );

        let mut prg0_again = AesCtrPrg::from_seed(seed);
        prg0_again.set_stream(0);

        let mut out0_again = [0u8; 64];
        let _ = prg0_again.try_fill_bytes(&mut out0_again);

        assert_eq!(out0, out0_again, "Same stream must be deterministic");
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn expansion_proptest(
            seed in any::<[u8; 32]>(),
            random_col in 0..1024usize,
            val_raw in 1..255u128
        ) {
            let rows = 1024;
            let cols = 1024;
            let degree = 16;
            let matrix = ByteSparseMatrix::generate_random(rows, cols, degree, seed);

            let mut x = vec![Block128::from(0u128).to_hardware(); cols];
            x[random_col] = Block128::from(val_raw).to_hardware();

            let y = matrix.spmv(x.as_slice());
            let weight = y.iter().filter(|&&v|
                v != Block128::from(0u128).to_hardware()).count();

            let min_weight = degree / 6;
            prop_assert!(
                weight >= min_weight,
                "Column {} failed expansion: weight {}",
                random_col, weight,
            );
        }
    }
}