Skip to main content

hekate_math/
matrix.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate-math project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18use crate::{Flat, HardwareField};
19use aes::Aes256;
20use aes::cipher::{BlockCipherEncrypt, KeyInit};
21use alloc::vec::Vec;
22use core::arch::asm;
23use core::convert::Infallible;
24use core::mem::MaybeUninit;
25use rand::{RngExt, SeedableRng, TryRng};
26#[cfg(feature = "parallel")]
27use rayon::prelude::*;
28
/// Rows per processing unit.
/// 1024 keeps the hot set within L1 cache.
const CHUNK_SIZE: usize = 1024;

/// Min rows to trigger Rayon.
/// Binary XOR is too fast to justify
/// thread sync overhead below 32k.
#[cfg(feature = "parallel")]
const PARALLEL_THRESHOLD: usize = 32768;

/// 8 rows ahead keeps the memory
/// controller saturated during
/// random VectorSource access.
const LOOKAHEAD: usize = 8;

/// Fixed chunk size for deterministic
/// matrix generation. Both parallel
/// and sequential paths use the same
/// boundaries so output is bit-identical
/// across feature configurations.
const GEN_CHUNK_ROWS: usize = 256;

/// Size in bytes of one AES block.
const AES_BLOCK: usize = 16;

/// Blocks encrypted per PRG refill.
/// 8 blocks saturates the AES-NI
/// and ARMv8-CE pipeline via ILP.
const AES_BATCH: usize = 8;

/// Bytes of keystream buffered per
/// refill (`AES_BATCH` whole blocks).
const AES_BUF_SIZE: usize = AES_BATCH * AES_BLOCK;
56
/// Abstract source of a vector for Matrix-Vector
/// multiplication. Allows using both dense slices
/// (RAM) and algorithmic generators (JIT).
///
/// Implementors only have to provide `len` and
/// `get_at`; every other method has a default.
pub trait VectorSource<F>: Sync {
    /// Get the length of the virtual vector.
    fn len(&self) -> usize;

    /// Returns `true` when the vector holds no elements.
    ///
    /// Defaults to `self.len() == 0`, so generator-style
    /// sources do not need to hand-write it.
    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the element at the specified index.
    fn get_at(&self, index: usize) -> F;

    /// Optimized batch fetch.
    /// Allows the source to use pipelining.
    ///
    /// The default fetches the `N` elements one by one,
    /// in the order given by `indices`.
    #[inline(always)]
    fn get_batch<const N: usize>(&self, indices: &[usize; N]) -> [F; N] {
        core::array::from_fn(|i| self.get_at(indices[i]))
    }

    /// Software prefetching hook.
    ///
    /// Purely a performance hint: the default does nothing.
    #[inline(always)]
    fn prefetch(&self, _indices: &[usize]) {
        // Default no-op
    }
}
82
83/// Implementation for standard slice
84/// access (Zero-Cost abstraction).
85impl<F: Copy + Sync> VectorSource<F> for [F] {
86    #[inline(always)]
87    fn len(&self) -> usize {
88        self.len()
89    }
90
91    #[inline(always)]
92    fn is_empty(&self) -> bool {
93        self.is_empty()
94    }
95
96    #[inline(always)]
97    fn get_at(&self, index: usize) -> F {
98        self[index]
99    }
100
101    /// Explicit prefetching implementation using Inline ASM.
102    #[inline(always)]
103    fn prefetch(&self, indices: &[usize]) {
104        let base_ptr = self.as_ptr();
105        for &idx in indices {
106            unsafe {
107                let ptr = base_ptr.wrapping_add(idx) as *const u8;
108
109                // Apple Silicon (M1/M2/M3) & ARM64
110                #[cfg(target_arch = "aarch64")]
111                asm!(
112                    "prfm pldl1keep, [{p}]",
113                    p = in(reg) ptr,
114                    options(nostack, preserves_flags, readonly)
115                );
116
117                // Intel/AMD x86_64
118                #[cfg(target_arch = "x86_64")]
119                asm!(
120                    "prefetcht0 [{p}]",
121                    p = in(reg) ptr,
122                    options(nostack, preserves_flags, readonly)
123                );
124            }
125        }
126    }
127}
128
/// A Field-Agnostic Sparse Matrix.
/// Stores weights as `u8` to save memory.
/// Can be applied to ANY field that
/// implements `HardwareField`.
///
/// Every row holds exactly `degree` entries, so
/// `weights` and `col_indices` are flat, row-major
/// arrays of `rows * degree` elements each.
///
/// SOUNDNESS:
/// Weights must be binary (0 or 1).
/// tower_bit is GF(2)-linear, not
/// GF(2^k)-linear; non-binary weights
/// break virtual packing commutativity.
#[derive(Clone, Debug)]
pub struct ByteSparseMatrix {
    /// Number of rows (length of the SpMV output).
    rows: usize,
    /// Number of columns (length of the SpMV input).
    cols: usize,
    /// Number of stored entries per row.
    degree: usize,

    /// Weights stored as bytes.
    /// Entry `r * degree + k` belongs to row `r`.
    weights: Vec<u8>,

    /// Column indices.
    /// Same layout as `weights`; each value is `< cols`.
    col_indices: Vec<u32>,
}
151
152impl ByteSparseMatrix {
153    /// Creates a new matrix safely,
154    /// validating internal array lengths.
155    pub fn new(
156        rows: usize,
157        cols: usize,
158        degree: usize,
159        weights: Vec<u8>,
160        col_indices: Vec<u32>,
161    ) -> Self {
162        let expected_len = rows.checked_mul(degree).expect("Matrix size overflow");
163
164        assert_eq!(
165            weights.len(),
166            expected_len,
167            "Weights vector length mismatch"
168        );
169        assert_eq!(
170            col_indices.len(),
171            expected_len,
172            "Column indices vector length mismatch"
173        );
174        assert!(
175            weights.iter().all(|&w| w == 0 || w == 1),
176            "Virtual packing requires binary weights"
177        );
178
179        for &idx in &col_indices {
180            assert!(
181                (idx as usize) < cols,
182                "Column index {} exceeds matrix columns count {}",
183                idx,
184                cols
185            );
186        }
187
188        Self {
189            rows,
190            cols,
191            degree,
192            weights,
193            col_indices,
194        }
195    }
196
197    /// Generates the Expander Graph once.
198    pub fn generate_random(rows: usize, cols: usize, degree: usize, seed: [u8; 32]) -> Self {
199        const MAX_DEGREE: usize = 256;
200        assert!(
201            degree <= MAX_DEGREE,
202            "Expander degree exceeds stack buffer size"
203        );
204
205        // SAFETY:
206        // Validate dimensions to prevent overflow,
207        // division by zero, and infinite loops.
208        assert!(
209            cols > 0,
210            "Matrix generation requires cols > 0 (division by zero in RNG)"
211        );
212        assert!(
213            degree <= cols,
214            "Expander degree cannot exceed cols (would cause infinite loop in generation)"
215        );
216
217        let total_elems = rows
218            .checked_mul(degree)
219            .expect("Matrix size overflow: rows * degree exceeds usize::MAX");
220
221        if total_elems == 0 {
222            return Self {
223                rows,
224                cols,
225                degree,
226                weights: Vec::new(),
227                col_indices: Vec::new(),
228            };
229        }
230
231        let mut weights: Vec<u8> = Vec::with_capacity(total_elems);
232        let mut col_indices: Vec<u32> = Vec::with_capacity(total_elems);
233
234        let weights_uninit = weights.spare_capacity_mut();
235        let col_indices_uninit = col_indices.spare_capacity_mut();
236
237        debug_assert!(weights_uninit.len() >= total_elems);
238        debug_assert!(col_indices_uninit.len() >= total_elems);
239
240        #[cfg(feature = "parallel")]
241        {
242            let rows_per_chunk = GEN_CHUNK_ROWS.min(rows.max(1));
243            let aligned_chunk_len = rows_per_chunk * degree;
244
245            weights_uninit[..total_elems]
246                .par_chunks_mut(aligned_chunk_len)
247                .zip(col_indices_uninit[..total_elems].par_chunks_mut(aligned_chunk_len))
248                .enumerate()
249                .for_each(|(chunk_id, (w_chunk, col_chunk))| {
250                    let rows_in_this_chunk = w_chunk.len() / degree;
251
252                    let mut rng = AesCtrPrg::from_seed(seed);
253                    rng.set_stream(chunk_id as u64);
254
255                    let mut used_cols = [0u32; MAX_DEGREE];
256                    for r in 0..rows_in_this_chunk {
257                        let row_offset = r * degree;
258
259                        for d in 0..degree {
260                            w_chunk[row_offset + d].write(1u8);
261
262                            let mut col_idx;
263                            loop {
264                                col_idx = rng.random_range(0..cols as u32);
265
266                                // The expander collapse:
267                                // Never break early or fallback in
268                                // characteristic 2 fields. Duplicate
269                                // column indices will result in X ^ X = 0,
270                                // locally destroying the expander
271                                // graph degree and PCS soundness.
272                                if !used_cols[..d].contains(&col_idx) {
273                                    break;
274                                }
275                            }
276
277                            used_cols[d] = col_idx;
278                            col_chunk[row_offset + d].write(col_idx);
279                        }
280                    }
281                });
282        }
283
284        #[cfg(not(feature = "parallel"))]
285        {
286            let rows_per_chunk = GEN_CHUNK_ROWS.min(rows.max(1));
287            let aligned_chunk_len = rows_per_chunk * degree;
288            let num_chunks = total_elems.div_ceil(aligned_chunk_len);
289
290            let mut used_cols = [0u32; MAX_DEGREE];
291            for chunk_id in 0..num_chunks {
292                let mut rng = AesCtrPrg::from_seed(seed);
293                rng.set_stream(chunk_id as u64);
294
295                let elem_start = chunk_id * aligned_chunk_len;
296                let elem_end = (elem_start + aligned_chunk_len).min(total_elems);
297                let rows_in_this_chunk = (elem_end - elem_start) / degree;
298
299                for r in 0..rows_in_this_chunk {
300                    let row_offset = elem_start + r * degree;
301
302                    for d in 0..degree {
303                        weights_uninit[row_offset + d].write(1u8);
304
305                        let mut col_idx;
306                        loop {
307                            col_idx = rng.random_range(0..cols as u32);
308                            if !used_cols[..d].contains(&col_idx) {
309                                break;
310                            }
311                        }
312
313                        used_cols[d] = col_idx;
314                        col_indices_uninit[row_offset + d].write(col_idx);
315                    }
316                }
317            }
318        }
319
320        // SAFETY:
321        // weights_uninit[..total_elems] and
322        // col_indices_uninit[..total_elems]
323        // were fully initialized above.
324        unsafe {
325            weights.set_len(total_elems);
326            col_indices.set_len(total_elems);
327        }
328
329        assert!(
330            weights.iter().all(|&w| w == 0 || w == 1),
331            "Binary weight invariant violated in generate_random"
332        );
333
334        Self {
335            rows,
336            cols,
337            degree,
338            weights,
339            col_indices,
340        }
341    }
342
343    #[inline]
344    pub fn rows(&self) -> usize {
345        self.rows
346    }
347
348    #[inline]
349    pub fn cols(&self) -> usize {
350        self.cols
351    }
352
353    #[inline]
354    pub fn degree(&self) -> usize {
355        self.degree
356    }
357
358    #[inline]
359    pub fn weights(&self) -> &[u8] {
360        &self.weights
361    }
362
363    #[inline]
364    pub fn col_indices(&self) -> &[u32] {
365        &self.col_indices
366    }
367
368    /// Binary SpMV:
369    /// weights are 0 or 1, so each
370    /// term is a conditional XOR.
371    /// Accepts any source implementing `VectorSource`.
372    pub fn spmv<F, V>(&self, x: &V) -> Vec<Flat<F>>
373    where
374        F: HardwareField,
375        V: VectorSource<Flat<F>> + ?Sized,
376    {
377        assert_eq!(x.len(), self.cols);
378
379        let mut y: Vec<MaybeUninit<Flat<F>>> = Vec::with_capacity(self.rows);
380
381        // SAFETY:
382        // Every output slot is written
383        // exactly once below.
384        unsafe {
385            y.set_len(self.rows);
386        }
387
388        #[cfg(feature = "parallel")]
389        if self.rows >= PARALLEL_THRESHOLD {
390            y.par_chunks_mut(CHUNK_SIZE)
391                .enumerate()
392                .for_each(|(chunk_id, out_chunk)| {
393                    let start_row = chunk_id * CHUNK_SIZE;
394                    self.process_chunk(start_row, out_chunk, x);
395                });
396
397            // SAFETY:
398            // All elements were initialized above.
399            return unsafe { assume_init_vec(y) };
400        }
401
402        for (chunk_id, out_chunk) in y.chunks_mut(CHUNK_SIZE).enumerate() {
403            let start_row = chunk_id * CHUNK_SIZE;
404            self.process_chunk(start_row, out_chunk, x);
405        }
406
407        unsafe { assume_init_vec(y) }
408    }
409
410    /// Process a chunk of rows
411    /// with lookahead prefetching.
412    #[inline(always)]
413    fn process_chunk<F, V>(&self, start_row: usize, out_chunk: &mut [MaybeUninit<Flat<F>>], x: &V)
414    where
415        F: HardwareField + Default + Copy,
416        V: VectorSource<Flat<F>> + ?Sized,
417    {
418        // Strategy:
419        // Iterate rows. For row i, prefetch indices
420        // for row i+LOOKAHEAD. Keep the memory
421        // controller pipeline full.
422        for i in 0..out_chunk.len() {
423            let row_idx = start_row + i;
424
425            // A. PREFETCH LOOKAHEAD
426            // Look ahead to find which random
427            // memory addresses we will need soon.
428            if i + LOOKAHEAD < out_chunk.len() {
429                let next_row = row_idx + LOOKAHEAD;
430                let row_offset = next_row * self.degree;
431
432                // Read the column indices for the future row
433                unsafe {
434                    for k in 0..self.degree {
435                        let col_idx = *self.col_indices.get_unchecked(row_offset + k) as usize;
436                        x.prefetch(&[col_idx]);
437                    }
438                }
439            }
440
441            // B. COMPUTE CURRENT ROW
442            const B: usize = 8; // Inner loop unroll factor
443
444            let row_offset = row_idx * self.degree;
445
446            let mut acc = Flat::from_raw(F::ZERO);
447            let mut j = 0;
448
449            // Binary weight invariant:
450            // w ∈ {0, 1}
451            while j + B <= self.degree {
452                let mut col_idxs = [0usize; B];
453                unsafe {
454                    for (k, slot) in col_idxs.iter_mut().enumerate() {
455                        *slot = *self.col_indices.get_unchecked(row_offset + j + k) as usize;
456                    }
457                }
458
459                let values = x.get_batch::<B>(&col_idxs);
460                unsafe {
461                    for (k, &val) in values.iter().enumerate() {
462                        if *self.weights.get_unchecked(row_offset + j + k) != 0 {
463                            acc += val;
464                        }
465                    }
466                }
467
468                j += B;
469            }
470
471            while j < self.degree {
472                unsafe {
473                    let curr = row_offset + j;
474                    if *self.weights.get_unchecked(curr) != 0 {
475                        let col_idx = *self.col_indices.get_unchecked(curr) as usize;
476                        acc += x.get_at(col_idx);
477                    }
478                }
479
480                j += 1;
481            }
482
483            out_chunk[i].write(acc);
484        }
485    }
486}
487
/// AES-256-CTR PRG for
/// expander graph generation.
///
/// Keystream is produced `AES_BATCH` blocks at a
/// time into `buffer` and handed out from there.
struct AesCtrPrg {
    /// AES-256 instance keyed with the seed.
    cipher: Aes256,
    /// Stream id; occupies the high 64 bits
    /// of every counter block.
    nonce: u64,
    /// Index of the next block to encrypt; occupies
    /// the low 64 bits of the counter block.
    counter: u64,
    /// Keystream produced by the last refill.
    buffer: [u8; AES_BUF_SIZE],
    /// Read cursor into `buffer`.
    /// `AES_BUF_SIZE` means the buffer is exhausted.
    buf_pos: usize,
}
497
498impl AesCtrPrg {
499    fn set_stream(&mut self, stream_id: u64) {
500        self.nonce = stream_id;
501        self.counter = 0;
502        self.buf_pos = AES_BUF_SIZE;
503    }
504
505    fn refill(&mut self) {
506        let nonce_high = (self.nonce as u128) << 64;
507
508        let mut blocks: [aes::Block; AES_BATCH] = Default::default();
509        for (i, block) in blocks.iter_mut().enumerate() {
510            let val = (self.counter + i as u64) as u128 | nonce_high;
511            *block = val.to_le_bytes().into();
512        }
513
514        self.cipher.encrypt_blocks(&mut blocks);
515
516        for (i, block) in blocks.iter().enumerate() {
517            self.buffer[i * AES_BLOCK..(i + 1) * AES_BLOCK].copy_from_slice(block.as_slice());
518        }
519
520        self.counter += AES_BATCH as u64;
521        self.buf_pos = 0;
522    }
523}
524
525impl SeedableRng for AesCtrPrg {
526    type Seed = [u8; 32];
527
528    fn from_seed(seed: [u8; 32]) -> Self {
529        Self {
530            cipher: Aes256::new(&seed.into()),
531            nonce: 0,
532            counter: 0,
533            buffer: [0u8; AES_BUF_SIZE],
534            buf_pos: AES_BUF_SIZE,
535        }
536    }
537}
538
539impl TryRng for AesCtrPrg {
540    type Error = Infallible;
541
542    fn try_next_u32(&mut self) -> Result<u32, Infallible> {
543        if self.buf_pos + 4 > AES_BUF_SIZE {
544            self.refill();
545        }
546
547        let p = self.buf_pos;
548        let val = u32::from_le_bytes(core::array::from_fn(|i| self.buffer[p + i]));
549
550        self.buf_pos = p + 4;
551
552        Ok(val)
553    }
554
555    fn try_next_u64(&mut self) -> Result<u64, Infallible> {
556        if self.buf_pos + 8 > AES_BUF_SIZE {
557            self.refill();
558        }
559
560        let p = self.buf_pos;
561        let val = u64::from_le_bytes(core::array::from_fn(|i| self.buffer[p + i]));
562
563        self.buf_pos = p + 8;
564
565        Ok(val)
566    }
567
568    fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Infallible> {
569        let mut written = 0;
570        while written < dst.len() {
571            if self.buf_pos >= AES_BUF_SIZE {
572                self.refill();
573            }
574
575            let available = AES_BUF_SIZE - self.buf_pos;
576            let copy_len = available.min(dst.len() - written);
577
578            dst[written..written + copy_len]
579                .copy_from_slice(&self.buffer[self.buf_pos..self.buf_pos + copy_len]);
580
581            self.buf_pos += copy_len;
582            written += copy_len;
583        }
584
585        Ok(())
586    }
587}
588
/// Reinterprets a vector of `MaybeUninit<T>` as a vector of `T`
/// without copying; length and capacity are preserved.
///
/// # Safety
///
/// Every element in `0..v.len()` must have been initialized.
#[inline]
unsafe fn assume_init_vec<T>(v: Vec<MaybeUninit<T>>) -> Vec<T> {
    // Suppress the destructor *before* taking the raw pointer.
    // Unlike `mem::forget` after the fact, the vector is never
    // moved again once its buffer pointer has been extracted.
    let mut v = core::mem::ManuallyDrop::new(v);

    let ptr = v.as_mut_ptr().cast::<T>();
    let len = v.len();
    let cap = v.capacity();

    // SAFETY:
    // `ptr`, `len` and `cap` come from a live `Vec` whose
    // destructor is suppressed, so ownership of the allocation
    // transfers to the new vector. `MaybeUninit<T>` has the same
    // size and alignment as `T`, and the caller guarantees the
    // first `len` elements are initialized.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
599
600#[cfg(test)]
601mod tests {
602    use super::*;
603    use crate::{Block128, HardwareField};
604    use alloc::vec;
605    use proptest::prelude::*;
606
    /// Allocation-free test vector: elements are
    /// computed on demand from their index instead
    /// of being stored.
    struct VirtualLinearSource {
        /// Length reported by `len`.
        size: usize,
        /// Factor the index is multiplied by
        /// to produce an element.
        multiplier: u128,
    }
611
612    impl VectorSource<Flat<Block128>> for VirtualLinearSource {
613        fn len(&self) -> usize {
614            self.size
615        }
616
617        fn is_empty(&self) -> bool {
618            unimplemented!()
619        }
620
621        fn get_at(&self, index: usize) -> Flat<Block128> {
622            // Generates value:
623            // index * multiplier
624            Block128::from((index as u128) * self.multiplier).to_hardware()
625        }
626    }
627
628    fn b128(v: u128) -> Block128 {
629        Block128::from(v)
630    }
631
632    #[test]
633    fn spmv_with_virtual_source() {
634        // Scenario: Multiply matrix by a
635        // "Virtual" vector without allocation.
636        // Matrix:
637        // [1, 1] (Indices: 0, 1)
638        // [1, 1] (Indices: 1, 0)
639        let weights = vec![1u8, 1u8, 1u8, 1u8];
640        let col_indices = vec![0, 1, 1, 0];
641
642        let matrix = ByteSparseMatrix::new(2, 2, 2, weights, col_indices);
643
644        // Virtual Vector: [0*10, 1*10] = [0, 10]
645        let source = VirtualLinearSource {
646            size: 2,
647            multiplier: 10,
648        };
649
650        // Expected Output:
651        // Row 0: 1*0 + 1*10 = 10
652        // Row 1: 1*10 + 1*0 = 10
653        let expected_val = Block128::from(10u128).to_hardware();
654        let expected = vec![expected_val, expected_val];
655
656        // Run SpMV with virtual source
657        let res = matrix.spmv(&source);
658
659        assert_eq!(res, expected, "SpMV failed with VirtualSource");
660    }
661
662    #[test]
663    fn byte_sparse_matrix_spmv() {
664        // 2 rows, 3 cols, degree 2,
665        // binary weights only.
666        let weights = vec![1u8, 1u8, 1u8, 1u8];
667
668        // Row 0:
669        // col 0 + col 2,
670        // Row 1:
671        // col 1 + col 0
672        let col_indices = vec![0, 2, 1, 0];
673
674        let matrix = ByteSparseMatrix::new(2, 3, 2, weights, col_indices);
675
676        let x0_tower = b128(10);
677        let x1_tower = b128(100);
678        let x2_tower = b128(255);
679
680        let x = vec![
681            x0_tower.to_hardware(),
682            x1_tower.to_hardware(),
683            x2_tower.to_hardware(),
684        ];
685
686        // Row 0:
687        // 1*x0 + 1*x2 = x0 + x2 (XOR)
688        let y0_tower = x0_tower + x2_tower;
689
690        // Row 1:
691        // 1*x1 + 1*x0 = x1 + x0 (XOR)
692        let y1_tower = x1_tower + x0_tower;
693
694        let expected = vec![y0_tower.to_hardware(), y1_tower.to_hardware()];
695        let res = matrix.spmv(x.as_slice());
696
697        assert_eq!(res, expected, "Sequential SpMV failed (Basis Mismatch?)");
698    }
699
700    #[test]
701    fn zero_weight_entries_contribute_nothing() {
702        // 2 rows, 3 cols, degree 3.
703        // Row 0:
704        // w=1 col=0,
705        // w=0 col=1,
706        // w=1 col=2
707        // Row 1:
708        // w=0 col=0,
709        // w=1 col=1,
710        // w=0 col=2
711        let weights = vec![1, 0, 1, 0, 1, 0];
712        let col_indices = vec![0, 1, 2, 0, 1, 2];
713        let matrix = ByteSparseMatrix::new(2, 3, 3, weights, col_indices);
714
715        let x0 = b128(0xA0);
716        let x1 = b128(0xB0);
717        let x2 = b128(0xC0);
718        let x = vec![x0.to_hardware(), x1.to_hardware(), x2.to_hardware()];
719
720        // Row 0:
721        // 1*x0 + 0*x1 + 1*x2 = x0 + x2
722        // Row 1:
723        // 0*x0 + 1*x1 + 0*x2 = x1
724        let expected = vec![(x0 + x2).to_hardware(), x1.to_hardware()];
725
726        assert_eq!(matrix.spmv(x.as_slice()), expected);
727    }
728
729    #[test]
730    #[should_panic(expected = "binary weights")]
731    fn rejects_non_binary_weights() {
732        ByteSparseMatrix::new(1, 2, 2, vec![1, 3], vec![0, 1]);
733    }
734
735    #[test]
736    #[should_panic(expected = "cols > 0")]
737    fn safety_rejects_zero_cols() {
738        // Division by zero prevention
739        // cols == 0 would cause panic in random_range(0..cols)
740        ByteSparseMatrix::generate_random(10, 0, 5, [1u8; 32]);
741    }
742
743    #[test]
744    fn accepts_valid_dimensions() {
745        // Valid dimensions should work
746        let m = ByteSparseMatrix::generate_random(10, 10, 5, [1u8; 32]);
747        assert_eq!(m.rows(), 10);
748        assert_eq!(m.cols(), 10);
749        assert_eq!(m.degree(), 5);
750        assert_eq!(m.weights().len(), 50); // 10 * 5
751    }
752
753    #[test]
754    fn accepts_zero_rows_or_degree() {
755        // Zero rows or degree should return empty matrix
756        let m1 = ByteSparseMatrix::generate_random(0, 10, 5, [1u8; 32]);
757        assert_eq!(m1.weights().len(), 0);
758
759        let m2 = ByteSparseMatrix::generate_random(10, 10, 0, [1u8; 32]);
760        assert_eq!(m2.weights().len(), 0);
761    }
762
    #[test]
    fn expander_properties_sanity_check() {
        // Test parameters
        // Use a reasonably small matrix to verify properties fast.
        let rows = 4096;
        let cols = 4096;
        let degree = 16; // Standard degree for Brakedown
        let seed = [42u8; 32];

        // 1. Generate matrix using production logic
        let matrix = ByteSparseMatrix::generate_random(rows, cols, degree, seed);

        // Helper to count Hamming weight (non-zero elements)
        // Verify properties using the Hardware Field representation.
        let hamming_weight = |vec: &[Flat<Block128>]| -> usize {
            vec.iter()
                .filter(|&&x| x != Block128::from(0u128).to_hardware())
                .count()
        };

        // TEST 1:
        // Expansion of Weight-1 Vectors (Atomic check)
        // Input weight 1 -> Output weight must be non-zero,
        // i.e. the probed column appears in at least one row.
        // Only rows have a fixed number of entries (`degree`);
        // the number of rows touching a given column is not
        // fixed, so no exact count is asserted here.
        for i in 0..100 {
            let mut x = vec![Block128::from(0u128).to_hardware(); cols];
            // Set 1 at position i (the first 100 columns are
            // probed in order, one active variable at a time)
            x[i] = Block128::from(1u128).to_hardware();

            let y = matrix.spmv(x.as_slice());
            let w = hamming_weight(&y);

            assert!(w > 0, "Column {} is empty! Information loss", i);
        }

        // TEST 2:
        // Expansion of Weight-2 Vectors (Collision check)
        // Two columns should not share too many neighbors.
        // Expected weight ~ 2 * degree. Significantly less
        // implies poor expansion (collisions).
        let mut rng = AesCtrPrg::from_seed([1u8; 32]);
        let mut total_weight = 0;

        let trials = 100;
        for _ in 0..trials {
            let mut x = vec![Block128::from(0u128).to_hardware(); cols];

            // Pick two distinct indices: a random one
            // and its cyclic successor
            let idx1 = rng.random_range(0..cols);
            let idx2 = (idx1 + 1) % cols;

            x[idx1] = Block128::from(1u128).to_hardware();
            x[idx2] = Block128::from(1u128).to_hardware();

            let y = matrix.spmv(x.as_slice());
            total_weight += hamming_weight(&y);
        }

        let avg_weight = total_weight as f64 / trials as f64;
        let expected_max = (degree * 2) as f64;

        // Allow some collisions (birthday paradox),
        // but avg weight should be high.
        // If avg <= 25.6 (80% of 32),
        // the expander quality is suspicious.
        // println!("Average weight: {}", avg_weight);
        assert!(
            avg_weight > (expected_max * 0.8),
            "Too many collisions! Poor expansion property. Avg: {}",
            avg_weight
        );

        // TEST 3:
        // Avalanche Effect (Weight-10)
        // A small change in input should produce a large change in output.
        // Input weight 10 (the first 10 columns) ->
        // Output weight should be close to 160.
        let input_w = 10;
        let mut x = vec![Block128::from(0u128).to_hardware(); cols];

        for val in x.iter_mut().take(input_w) {
            *val = Block128::from(1u128).to_hardware();
        }

        let y = matrix.spmv(x.as_slice());
        let w_out = hamming_weight(&y);

        // Allow ~20% loss due to collisions for this density
        // (output weight must exceed 128).
        assert!(
            w_out > (input_w * degree * 8 / 10),
            "Weight-10 vector collapsed too much! Weight: {}",
            w_out
        );
    }
856
857    #[test]
858    fn check_determinism() {
859        let seed = [42u8; 32];
860        let rows = 1024;
861        let cols = 1024;
862        let degree = 16;
863
864        // Generate matrix twice (simulating
865        // different thread pool configurations).
866        let matrix1 = ByteSparseMatrix::generate_random(rows, cols, degree, seed);
867        let matrix2 = ByteSparseMatrix::generate_random(rows, cols, degree, seed);
868
869        // These should be exactly the same
870        assert_eq!(
871            matrix1.weights(),
872            matrix2.weights(),
873            "Matrix weights must be deterministic for the same seed"
874        );
875        assert_eq!(
876            matrix1.col_indices(),
877            matrix2.col_indices(),
878            "Matrix column indices must be deterministic for the same seed"
879        );
880
881        // Also test with different thread counts
882        #[cfg(feature = "parallel")]
883        {
884            use rayon::ThreadPoolBuilder;
885
886            let matrix_1thread = ThreadPoolBuilder::new()
887                .num_threads(1)
888                .build()
889                .unwrap()
890                .install(|| ByteSparseMatrix::generate_random(rows, cols, degree, seed));
891
892            let matrix_8threads = ThreadPoolBuilder::new()
893                .num_threads(8)
894                .build()
895                .unwrap()
896                .install(|| ByteSparseMatrix::generate_random(rows, cols, degree, seed));
897
898            assert_eq!(
899                matrix_1thread.weights(),
900                matrix_8threads.weights(),
901                "Matrix must be identical regardless of thread count"
902            );
903            assert_eq!(
904                matrix_1thread.col_indices(),
905                matrix_8threads.col_indices(),
906                "Matrix indices must be identical regardless of thread count"
907            );
908        }
909    }
910
#[test]
fn security_prevent_expander_collapse() {
    // SECURITY TEST: "Expander Collapse".
    //
    // Generate at maximum density (`degree == cols`), so a row
    // can only be valid if it uses every column exactly once.
    // A repeated index within a row would cancel itself out
    // (X ^ X = 0 in GF(2^k)) and destroy the PCS soundness.
    let seed = [99u8; 32];
    let (rows, cols, degree) = (1000, 32, 32);

    // A broken rejection sampler (or missing infinite-loop
    // protection) shows up here either as a hang or as an
    // invalid matrix caught by the checks below.
    let matrix = ByteSparseMatrix::generate_random(rows, cols, degree, seed);

    for r in 0..rows {
        // Row `r` owns the `degree` consecutive entries
        // starting at `r * degree`.
        let start = r * degree;
        let mut sorted: Vec<u32> = matrix.col_indices()[start..start + degree].to_vec();

        // After sorting, any duplicate must sit next to its twin,
        // so comparing neighbours is enough to prove uniqueness.
        sorted.sort_unstable();

        for pair in sorted.windows(2) {
            assert_ne!(
                pair[0],
                pair[1],
                "Expander Collapse detected in row {}! Duplicate column index {}. \
                 The rejection sampling loop has been compromised.",
                r,
                pair[0]
            );
        }
    }
}
952
/// Golden-value check: the generated matrix
/// must not depend on whether the crate is
/// built with `--features parallel` or not.
#[test]
fn cross_feature_determinism_golden() {
    // First 64 column indices (four rows at degree 16)
    // expected for the fixed parameters used below.
    #[rustfmt::skip]
    const EXPECTED: [u32; 64] = [
        442, 352, 465,  69, 176, 472, 322, 109,
        349, 216,  74,  35, 206,  50,   7, 443,
        349, 214,  30, 332,  66, 316, 297, 415,
        325,  88, 484, 345,   5, 224, 106, 326,
        454, 345, 295, 443, 267, 264,  91, 333,
        163, 359, 262,  49, 112, 499, 219,  67,
        420, 106, 415,  54, 437, 123, 366, 284,
        503, 249,  26, 353,  90,  29, 311, 111,
    ];

    let rows = 1024;
    let cols = 512;
    let degree = 16;
    let seed = [42u8; 32];

    let matrix = ByteSparseMatrix::generate_random(rows, cols, degree, seed);

    // Only the leading prefix is pinned; the rest
    // of the matrix is not inspected by this test.
    let head = &matrix.col_indices()[..EXPECTED.len()];

    assert_eq!(head, &EXPECTED);
}
973
/// Golden keystream for an all-zero seed.
///
/// Counter block = (nonce << 64 | counter).to_le_bytes()
#[test]
fn aes_ctr_prg_golden() {
    // Eight consecutive 16-byte blocks,
    // AES-256 keyed with [0; 32], counters 0 through 7.
    #[rustfmt::skip]
    const EXPECTED: [u8; 128] = [
        // block 0: AES-256([0;32], counter=0)
        0xdc, 0x95, 0xc0, 0x78, 0xa2, 0x40, 0x89, 0x89,
        0xad, 0x48, 0xa2, 0x14, 0x92, 0x84, 0x20, 0x87,
        // block 1: counter=1
        0x52, 0x75, 0xf3, 0xd8, 0x6b, 0x4f, 0xb8, 0x68,
        0x45, 0x93, 0x13, 0x3e, 0xbf, 0xa5, 0x3c, 0xd3,
        // block 2: counter=2
        0x77, 0x9b, 0x38, 0xd1, 0x5b, 0xff, 0xb6, 0x3d,
        0x8d, 0x60, 0x9d, 0x55, 0x1a, 0x5c, 0xc9, 0x8e,
        // block 3: counter=3
        0x39, 0xd6, 0xe9, 0xae, 0x76, 0xa9, 0xb2, 0xf3,
        0xfc, 0x46, 0x26, 0x80, 0xf7, 0x66, 0x72, 0x0e,
        // block 4: counter=4
        0x75, 0xd1, 0x1b, 0x0e, 0x3a, 0x68, 0xc4, 0x22,
        0x3d, 0x88, 0xdb, 0xf0, 0x17, 0x97, 0x7d, 0xd7,
        // block 5: counter=5
        0x84, 0x5c, 0x7d, 0x46, 0x90, 0xfa, 0x59, 0x4f,
        0x90, 0xe6, 0x7f, 0x7b, 0x52, 0x11, 0xa5, 0x1a,
        // block 6: counter=6
        0x6f, 0x87, 0x1f, 0x44, 0x5c, 0x18, 0xaf, 0xc2,
        0xf8, 0x93, 0x7a, 0xf8, 0x41, 0xfd, 0x2a, 0xd0,
        // block 7: counter=7
        0x8d, 0x3a, 0xe1, 0x50, 0x22, 0x15, 0x52, 0x33,
        0x4d, 0xdb, 0x29, 0xfe, 0x36, 0xa0, 0xb7, 0x24,
    ];

    let mut keystream = [0u8; 128];
    let mut generator = AesCtrPrg::from_seed([0u8; 32]);

    // The fill result is deliberately discarded; a buffer that
    // was not filled with the expected bytes is caught by the
    // comparison below.
    let _ = generator.try_fill_bytes(&mut keystream);

    assert_eq!(keystream, EXPECTED);
}
1012
/// Same seed, different stream ids: the outputs
/// must differ between streams and repeat exactly
/// when the same stream id is selected again.
#[test]
fn aes_ctr_prg_stream_isolation() {
    let seed = [0xabu8; 32];

    // 64 bytes drawn from stream 0.
    let out0 = {
        let mut prg = AesCtrPrg::from_seed(seed);
        prg.set_stream(0);

        let mut buf = [0u8; 64];
        let _ = prg.try_fill_bytes(&mut buf);
        buf
    };

    // 64 bytes drawn from stream 1, same seed.
    let out1 = {
        let mut prg = AesCtrPrg::from_seed(seed);
        prg.set_stream(1);

        let mut buf = [0u8; 64];
        let _ = prg.try_fill_bytes(&mut buf);
        buf
    };

    assert_ne!(
        out0, out1,
        "Different streams must produce different output"
    );

    // A fresh generator on stream 0 must
    // reproduce the first draw byte for byte.
    let out0_again = {
        let mut prg = AesCtrPrg::from_seed(seed);
        prg.set_stream(0);

        let mut buf = [0u8; 64];
        let _ = prg.try_fill_bytes(&mut buf);
        buf
    };

    assert_eq!(out0, out0_again, "Same stream must be deterministic");
}
1042
1043    proptest! {
1044        #![proptest_config(ProptestConfig::with_cases(1000))]
1045        #[test]
1046        fn expansion_proptest(
1047            seed in any::<[u8; 32]>(),
1048            random_col in 0..1024usize,
1049            val_raw in 1..255u128
1050        ) {
1051            let rows = 1024;
1052            let cols = 1024;
1053            let degree = 16;
1054            let matrix = ByteSparseMatrix::generate_random(rows, cols, degree, seed);
1055
1056            let mut x = vec![Block128::from(0u128).to_hardware(); cols];
1057            x[random_col] = Block128::from(val_raw).to_hardware();
1058
1059            let y = matrix.spmv(x.as_slice());
1060            let weight = y.iter().filter(|&&v|
1061                v != Block128::from(0u128).to_hardware()).count();
1062
1063            let min_weight = degree / 6;
1064            prop_assert!(
1065                weight >= min_weight,
1066                "Column {} failed expansion: weight {}",
1067                random_col, weight,
1068            );
1069        }
1070    }
1071}