Skip to main content

hekate_math/
matrix.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate-math project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18use crate::{Flat, HardwareField};
19use alloc::vec::Vec;
20use core::arch::asm;
21use core::mem::MaybeUninit;
22#[cfg(feature = "parallel")]
23use rayon::prelude::*;
24
/// Number of rows handed to a single processing unit.
///
/// Tuned to 1024 so that the hot working set of
/// one chunk stays within L1 cache.
const CHUNK_SIZE: usize = 1 << 10;

/// Smallest row count for which the Rayon path is taken.
///
/// Binary XOR is cheap enough that, below 32k rows,
/// thread synchronisation overhead is not worth paying.
#[cfg(feature = "parallel")]
const PARALLEL_THRESHOLD: usize = 32 * 1024;

/// How many rows ahead of the current one are prefetched.
///
/// A distance of 8 rows is meant to keep the memory
/// controller saturated while `VectorSource` is being
/// accessed at random positions.
const LOOKAHEAD: usize = 8;
39
/// Abstract source of a vector for Matrix-Vector
/// multiplication. Allows using both dense slices
/// (RAM) and algorithmic generators (JIT).
///
/// Only `len` and `get_at` have to be provided;
/// every other method has a default that is
/// expressed in terms of those two.
pub trait VectorSource<F>: Sync {
    /// Get the length of the virtual vector.
    fn len(&self) -> usize;

    /// Returns `true` when the vector holds no elements.
    ///
    /// Defaults to `self.len() == 0`, so implementors
    /// no longer need to write (or stub out) this method.
    /// Existing implementations that override it keep working.
    #[inline]
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the element at the specified index.
    fn get_at(&self, index: usize) -> F;

    /// Optimized batch fetch.
    /// Allows the source to use pipelining.
    ///
    /// The default fetches the elements one by one with
    /// `get_at`, in the order given by `indices`, and
    /// returns them in that same order.
    #[inline(always)]
    fn get_batch<const N: usize>(&self, indices: &[usize; N]) -> [F; N] {
        core::array::from_fn(|i| self.get_at(indices[i]))
    }

    /// Software prefetching hook.
    ///
    /// Purely a performance hint: the default does nothing,
    /// and implementations must not change observable state.
    #[inline(always)]
    fn prefetch(&self, _indices: &[usize]) {
        // Default no-op
    }
}
65
66/// Implementation for standard slice
67/// access (Zero-Cost abstraction).
68impl<F: Copy + Sync> VectorSource<F> for [F] {
69    #[inline(always)]
70    fn len(&self) -> usize {
71        self.len()
72    }
73
74    #[inline(always)]
75    fn is_empty(&self) -> bool {
76        self.is_empty()
77    }
78
79    #[inline(always)]
80    fn get_at(&self, index: usize) -> F {
81        self[index]
82    }
83
84    /// Explicit prefetching implementation using Inline ASM.
85    #[inline(always)]
86    fn prefetch(&self, indices: &[usize]) {
87        let base_ptr = self.as_ptr();
88        for &idx in indices {
89            unsafe {
90                let ptr = base_ptr.wrapping_add(idx) as *const u8;
91
92                // Apple Silicon (M1/M2/M3) & ARM64
93                #[cfg(target_arch = "aarch64")]
94                asm!(
95                    "prfm pldl1keep, [{p}]",
96                    p = in(reg) ptr,
97                    options(nostack, preserves_flags, readonly)
98                );
99
100                // Intel/AMD x86_64
101                #[cfg(target_arch = "x86_64")]
102                asm!(
103                    "prefetcht0 [{p}]",
104                    p = in(reg) ptr,
105                    options(nostack, preserves_flags, readonly)
106                );
107            }
108        }
109    }
110}
111
/// Field-agnostic sparse matrix with the same number
/// of stored entries in every row.
///
/// Each weight takes a single `u8` to keep the memory
/// footprint small, and one matrix can be applied to
/// ANY field that implements `HardwareField`.
///
/// SOUNDNESS:
/// every weight must be binary (0 or 1).
/// tower_bit is GF(2)-linear, not GF(2^k)-linear,
/// so a non-binary weight breaks virtual packing
/// commutativity.
#[derive(Debug, Clone)]
pub struct ByteSparseMatrix {
    /// Number of rows.
    rows: usize,

    /// Number of columns; every stored column index is below this.
    cols: usize,

    /// Number of stored entries per row.
    degree: usize,

    /// One byte per stored entry, `rows * degree` in total.
    weights: Vec<u8>,

    /// Column of each stored entry, parallel to `weights`.
    col_indices: Vec<u32>,
}
134
135impl ByteSparseMatrix {
136    /// Creates a new matrix safely,
137    /// validating internal array lengths.
138    pub fn new(
139        rows: usize,
140        cols: usize,
141        degree: usize,
142        weights: Vec<u8>,
143        col_indices: Vec<u32>,
144    ) -> Self {
145        let expected_len = rows.checked_mul(degree).expect("Matrix size overflow");
146
147        assert_eq!(
148            weights.len(),
149            expected_len,
150            "Weights vector length mismatch"
151        );
152        assert_eq!(
153            col_indices.len(),
154            expected_len,
155            "Column indices vector length mismatch"
156        );
157        assert!(
158            weights.iter().all(|&w| w == 0 || w == 1),
159            "Virtual packing requires binary weights"
160        );
161
162        for &idx in &col_indices {
163            assert!(
164                (idx as usize) < cols,
165                "Column index {} exceeds matrix columns count {}",
166                idx,
167                cols
168            );
169        }
170
171        Self {
172            rows,
173            cols,
174            degree,
175            weights,
176            col_indices,
177        }
178    }
179
180    #[inline]
181    pub fn rows(&self) -> usize {
182        self.rows
183    }
184
185    #[inline]
186    pub fn cols(&self) -> usize {
187        self.cols
188    }
189
190    #[inline]
191    pub fn degree(&self) -> usize {
192        self.degree
193    }
194
195    #[inline]
196    pub fn weights(&self) -> &[u8] {
197        &self.weights
198    }
199
200    #[inline]
201    pub fn col_indices(&self) -> &[u32] {
202        &self.col_indices
203    }
204
205    /// Binary SpMV:
206    /// weights are 0 or 1, so each
207    /// term is a conditional XOR.
208    /// Accepts any source implementing `VectorSource`.
209    pub fn spmv<F, V>(&self, x: &V) -> Vec<Flat<F>>
210    where
211        F: HardwareField,
212        V: VectorSource<Flat<F>> + ?Sized,
213    {
214        assert_eq!(x.len(), self.cols);
215
216        let mut y: Vec<MaybeUninit<Flat<F>>> = Vec::with_capacity(self.rows);
217
218        // SAFETY:
219        // Every output slot is written
220        // exactly once below.
221        unsafe {
222            y.set_len(self.rows);
223        }
224
225        #[cfg(feature = "parallel")]
226        if self.rows >= PARALLEL_THRESHOLD {
227            y.par_chunks_mut(CHUNK_SIZE)
228                .enumerate()
229                .for_each(|(chunk_id, out_chunk)| {
230                    let start_row = chunk_id * CHUNK_SIZE;
231                    self.process_chunk(start_row, out_chunk, x);
232                });
233
234            // SAFETY:
235            // All elements were initialized above.
236            return unsafe { assume_init_vec(y) };
237        }
238
239        for (chunk_id, out_chunk) in y.chunks_mut(CHUNK_SIZE).enumerate() {
240            let start_row = chunk_id * CHUNK_SIZE;
241            self.process_chunk(start_row, out_chunk, x);
242        }
243
244        unsafe { assume_init_vec(y) }
245    }
246
247    /// Process a chunk of rows
248    /// with lookahead prefetching.
249    #[inline(always)]
250    fn process_chunk<F, V>(&self, start_row: usize, out_chunk: &mut [MaybeUninit<Flat<F>>], x: &V)
251    where
252        F: HardwareField + Default + Copy,
253        V: VectorSource<Flat<F>> + ?Sized,
254    {
255        // Strategy:
256        // Iterate rows. For row i, prefetch indices
257        // for row i+LOOKAHEAD. Keep the memory
258        // controller pipeline full.
259        for i in 0..out_chunk.len() {
260            let row_idx = start_row + i;
261
262            // A. PREFETCH LOOKAHEAD
263            // Look ahead to find which random
264            // memory addresses we will need soon.
265            if i + LOOKAHEAD < out_chunk.len() {
266                let next_row = row_idx + LOOKAHEAD;
267                let row_offset = next_row * self.degree;
268
269                // Read the column indices for the future row
270                unsafe {
271                    for k in 0..self.degree {
272                        let col_idx = *self.col_indices.get_unchecked(row_offset + k) as usize;
273                        x.prefetch(&[col_idx]);
274                    }
275                }
276            }
277
278            // B. COMPUTE CURRENT ROW
279            const B: usize = 8; // Inner loop unroll factor
280
281            let row_offset = row_idx * self.degree;
282
283            let mut acc = Flat::from_raw(F::ZERO);
284            let mut j = 0;
285
286            // Binary weight invariant:
287            // w ∈ {0, 1}
288            while j + B <= self.degree {
289                let mut col_idxs = [0usize; B];
290                unsafe {
291                    for (k, slot) in col_idxs.iter_mut().enumerate() {
292                        *slot = *self.col_indices.get_unchecked(row_offset + j + k) as usize;
293                    }
294                }
295
296                let values = x.get_batch::<B>(&col_idxs);
297                unsafe {
298                    for (k, &val) in values.iter().enumerate() {
299                        if *self.weights.get_unchecked(row_offset + j + k) != 0 {
300                            acc += val;
301                        }
302                    }
303                }
304
305                j += B;
306            }
307
308            while j < self.degree {
309                unsafe {
310                    let curr = row_offset + j;
311                    if *self.weights.get_unchecked(curr) != 0 {
312                        let col_idx = *self.col_indices.get_unchecked(curr) as usize;
313                        acc += x.get_at(col_idx);
314                    }
315                }
316
317                j += 1;
318            }
319
320            out_chunk[i].write(acc);
321        }
322    }
323}
324
/// Reinterprets a `Vec<MaybeUninit<T>>` as a `Vec<T>` in place.
///
/// The heap allocation, length and capacity are carried over
/// unchanged; nothing is copied or reallocated.
///
/// # Safety
/// Every one of the first `v.len()` elements must have been
/// initialized. The returned vector treats them as live `T`
/// values and will drop them.
#[inline]
unsafe fn assume_init_vec<T>(v: Vec<MaybeUninit<T>>) -> Vec<T> {
    // Give up ownership BEFORE taking the raw parts. Wrapping in
    // `ManuallyDrop` (rather than calling `mem::forget` afterwards)
    // means the source vector is never moved once its pointer has
    // been read, and no later edit can add a panic path that frees
    // the buffer twice.
    let mut v = core::mem::ManuallyDrop::new(v);

    let ptr = v.as_mut_ptr().cast::<T>();
    let len = v.len();
    let cap = v.capacity();

    // SAFETY:
    // `ptr`, `len` and `cap` come from a live `Vec` whose element
    // type `MaybeUninit<T>` has the same size and alignment as `T`.
    // That `Vec` will not be dropped, so ownership of the buffer is
    // unique. The caller guarantees the first `len` elements are
    // initialized.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Block128, HardwareField};
    use alloc::vec;

    /// Vector source that stores nothing: element `i`
    /// is computed on demand as `i * multiplier`.
    struct VirtualLinearSource {
        size: usize,
        multiplier: u128,
    }

    impl VectorSource<Flat<Block128>> for VirtualLinearSource {
        fn len(&self) -> usize {
            self.size
        }

        fn is_empty(&self) -> bool {
            // A real answer instead of a panicking stub, so the
            // source stays usable if `is_empty` is ever called.
            self.size == 0
        }

        fn get_at(&self, index: usize) -> Flat<Block128> {
            // Generates value:
            // index * multiplier
            Block128::from((index as u128) * self.multiplier).to_hardware()
        }
    }

    /// Shorthand for building a tower-basis element.
    fn b128(v: u128) -> Block128 {
        Block128::from(v)
    }

    #[test]
    fn spmv_with_virtual_source() {
        // Scenario: Multiply matrix by a
        // "Virtual" vector without allocation.
        // Matrix:
        // [1, 1] (Indices: 0, 1)
        // [1, 1] (Indices: 1, 0)
        let weights = vec![1u8, 1u8, 1u8, 1u8];
        let col_indices = vec![0, 1, 1, 0];

        let matrix = ByteSparseMatrix::new(2, 2, 2, weights, col_indices);

        // Virtual Vector: [0*10, 1*10] = [0, 10]
        let source = VirtualLinearSource {
            size: 2,
            multiplier: 10,
        };

        // Expected Output:
        // Row 0: 1*0 + 1*10 = 10
        // Row 1: 1*10 + 1*0 = 10
        let expected_val = Block128::from(10u128).to_hardware();
        let expected = vec![expected_val, expected_val];

        // Run SpMV with virtual source
        let res = matrix.spmv(&source);

        assert_eq!(res, expected, "SpMV failed with VirtualSource");
    }

    #[test]
    fn virtual_source_reports_emptiness() {
        let empty = VirtualLinearSource {
            size: 0,
            multiplier: 1,
        };
        let two = VirtualLinearSource {
            size: 2,
            multiplier: 1,
        };

        assert!(empty.is_empty());
        assert!(!two.is_empty());
    }

    #[test]
    fn byte_sparse_matrix_spmv() {
        // 2 rows, 3 cols, degree 2,
        // binary weights only.
        let weights = vec![1u8, 1u8, 1u8, 1u8];

        // Row 0:
        // col 0 + col 2,
        // Row 1:
        // col 1 + col 0
        let col_indices = vec![0, 2, 1, 0];

        let matrix = ByteSparseMatrix::new(2, 3, 2, weights, col_indices);

        let x0_tower = b128(10);
        let x1_tower = b128(100);
        let x2_tower = b128(255);

        let x = vec![
            x0_tower.to_hardware(),
            x1_tower.to_hardware(),
            x2_tower.to_hardware(),
        ];

        // Row 0:
        // 1*x0 + 1*x2 = x0 + x2 (XOR)
        let y0_tower = x0_tower + x2_tower;

        // Row 1:
        // 1*x1 + 1*x0 = x1 + x0 (XOR)
        let y1_tower = x1_tower + x0_tower;

        let expected = vec![y0_tower.to_hardware(), y1_tower.to_hardware()];
        let res = matrix.spmv(x.as_slice());

        assert_eq!(res, expected, "Sequential SpMV failed (Basis Mismatch?)");
    }

    #[test]
    fn zero_weight_entries_contribute_nothing() {
        // 2 rows, 3 cols, degree 3.
        // Row 0:
        // w=1 col=0,
        // w=0 col=1,
        // w=1 col=2
        // Row 1:
        // w=0 col=0,
        // w=1 col=1,
        // w=0 col=2
        let weights = vec![1, 0, 1, 0, 1, 0];
        let col_indices = vec![0, 1, 2, 0, 1, 2];
        let matrix = ByteSparseMatrix::new(2, 3, 3, weights, col_indices);

        let x0 = b128(0xA0);
        let x1 = b128(0xB0);
        let x2 = b128(0xC0);
        let x = vec![x0.to_hardware(), x1.to_hardware(), x2.to_hardware()];

        // Row 0:
        // 1*x0 + 0*x1 + 1*x2 = x0 + x2
        // Row 1:
        // 0*x0 + 1*x1 + 0*x2 = x1
        let expected = vec![(x0 + x2).to_hardware(), x1.to_hardware()];

        assert_eq!(matrix.spmv(x.as_slice()), expected);
    }

    #[test]
    fn unrolled_batch_and_tail_are_both_exercised() {
        // 1 row, 9 cols, degree 9:
        // one full 8-wide batch plus a single tail entry.
        // Weights alternate, so both paths see set and
        // cleared weights (the tail entry is set).
        let weights = vec![1u8, 0, 1, 0, 1, 0, 1, 0, 1];
        let col_indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
        let matrix = ByteSparseMatrix::new(1, 9, 9, weights, col_indices);

        let t0 = b128(1 << 0);
        let t1 = b128(1 << 1);
        let t2 = b128(1 << 2);
        let t3 = b128(1 << 3);
        let t4 = b128(1 << 4);
        let t5 = b128(1 << 5);
        let t6 = b128(1 << 6);
        let t7 = b128(1 << 7);
        let t8 = b128(1 << 8);

        let x = vec![
            t0.to_hardware(),
            t1.to_hardware(),
            t2.to_hardware(),
            t3.to_hardware(),
            t4.to_hardware(),
            t5.to_hardware(),
            t6.to_hardware(),
            t7.to_hardware(),
            t8.to_hardware(),
        ];

        // Only the even columns carry weight 1.
        let expected = vec![(t0 + t2 + t4 + t6 + t8).to_hardware()];

        assert_eq!(matrix.spmv(x.as_slice()), expected);
    }

    #[test]
    #[should_panic(expected = "binary weights")]
    fn rejects_non_binary_weights() {
        ByteSparseMatrix::new(1, 2, 2, vec![1, 3], vec![0, 1]);
    }

    #[test]
    #[should_panic(expected = "exceeds matrix columns count")]
    fn rejects_out_of_range_column_index() {
        // Column 2 does not exist in a 2-column matrix.
        ByteSparseMatrix::new(1, 2, 2, vec![1, 1], vec![0, 2]);
    }

    #[test]
    #[should_panic(expected = "Weights vector length mismatch")]
    fn rejects_wrong_weights_length() {
        // rows * degree = 2, but only one weight is supplied.
        ByteSparseMatrix::new(1, 2, 2, vec![1], vec![0, 1]);
    }
}