// irithyll_core/simd/ops.rs
//! Core SIMD-accelerated operations: dot product and matrix-vector multiply.
//!
//! These are the two hottest primitives across SSM, ESN, and attention forward
//! passes. AVX2 processes 4 `f64` values per instruction, giving up to ~4x
//! throughput on aligned inner loops.
//!
//! # Architecture
//!
//! ```text
//! Public API (safe)           Internal dispatch
//! ─────────────────           ─────────────────
//! simd_dot(a, b)       ──►   avx2::dot_avx2     (x86_64 + AVX2 detected)
//!                      └──►  dot_scalar          (fallback)
//!
//! simd_mat_vec(w,x,..) ──►   avx2::mat_vec_avx2 (x86_64 + AVX2 detected)
//!                      └──►  mat_vec_scalar      (fallback)
//! ```
// Runtime detection macro — only available with std.
#[cfg(all(target_arch = "x86_64", feature = "std"))]
use std::is_x86_feature_detected;

// ---------------------------------------------------------------------------
// Scalar fallbacks (always available, no_std compatible)
// ---------------------------------------------------------------------------

/// Scalar dot product of two slices.
///
/// Computes `sum(a[i] * b[i])` for `i` in `0..min(a.len(), b.len())`.
#[inline]
fn dot_scalar(a: &[f64], b: &[f64]) -> f64 {
    // `zip` stops at the shorter slice, which implements the min-length
    // contract; `sum` left-folds from 0.0 in iteration order, so the
    // floating-point accumulation order matches an index loop exactly.
    a.iter().zip(b.iter()).map(|(ai, bi)| ai * bi).sum()
}

/// Scalar matrix-vector multiply: `out[i] = dot(w[i*cols..], x)`.
///
/// `w` is a `rows x cols` row-major matrix, `x` is a `cols`-vector,
/// `out` is a `rows`-vector (must be pre-allocated).
#[inline]
fn mat_vec_scalar(w: &[f64], x: &[f64], _rows: usize, cols: usize, out: &mut [f64]) {
    for (i, slot) in out.iter_mut().enumerate() {
        // Checked slicing panics on out-of-bounds rows, just as direct
        // indexing would; `&x[..cols]` keeps the panic if `x` is too short.
        let row = &w[i * cols..i * cols + cols];
        *slot = row.iter().zip(&x[..cols]).map(|(wij, xj)| wij * xj).sum();
    }
}

// ---------------------------------------------------------------------------
// AVX2 implementations (x86_64 + std only)
// ---------------------------------------------------------------------------

#[cfg(all(target_arch = "x86_64", feature = "std"))]
mod avx2 {
    /// AVX2-accelerated dot product: processes 4 f64 values per iteration.
    ///
    /// Accumulates four independent f64 lanes and horizontally sums them at
    /// the end, then adds a scalar tail for the last `n % 4` elements. This
    /// reorders floating-point additions relative to `dot_scalar`, so the
    /// two paths can differ by rounding error.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available at runtime (checked via
    /// `is_x86_feature_detected!("avx2")`).
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn dot_avx2(a: &[f64], b: &[f64]) -> f64 {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        // Only the overlapping prefix of the two slices participates.
        let n = a.len().min(b.len());
        let chunks = n / 4; // full 256-bit (4 x f64) vectors
        let remainder = n % 4; // leftover elements handled scalar-wise

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();

        // SAFETY: AVX2 availability verified by caller. All pointer arithmetic
        // stays within slice bounds (chunks * 4 <= n).
        unsafe {
            let mut acc = _mm256_setzero_pd();

            // Unaligned loads (`loadu`) — slice data carries no 32-byte
            // alignment guarantee.
            for i in 0..chunks {
                let offset = i * 4;
                let va = _mm256_loadu_pd(a_ptr.add(offset));
                let vb = _mm256_loadu_pd(b_ptr.add(offset));
                acc = _mm256_add_pd(acc, _mm256_mul_pd(va, vb));
            }

            // Horizontal sum of 4 f64 lanes: [a0, a1, a2, a3]
            let hi128 = _mm256_extractf128_pd(acc, 1); // [a2, a3]
            let lo128 = _mm256_castpd256_pd128(acc); // [a0, a1]
            let pair = _mm_add_pd(lo128, hi128); // [a0+a2, a1+a3]
            let high64 = _mm_unpackhi_pd(pair, pair); // [a1+a3, a1+a3]
            let total = _mm_add_sd(pair, high64); // low lane = a0+a1+a2+a3
            let mut scalar_sum = _mm_cvtsd_f64(total);

            // Handle remainder with scalar tail.
            let base = chunks * 4;
            for i in 0..remainder {
                scalar_sum += *a_ptr.add(base + i) * *b_ptr.add(base + i);
            }

            scalar_sum
        }
    }

    /// AVX2-accelerated matrix-vector multiply.
    ///
    /// Each row is computed as a SIMD dot product of `w[row*cols..]` with `x`.
    ///
    /// # Safety
    ///
    /// Caller must ensure:
    /// - AVX2 is available at runtime
    /// - `w.len() >= rows * cols`, `x.len() >= cols`, `out.len() >= rows`
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn mat_vec_avx2(
        w: &[f64],
        x: &[f64],
        _rows: usize,
        cols: usize,
        out: &mut [f64],
    ) {
        // Row access below uses checked slicing, so even if the caller's
        // length contract were violated this panics rather than reading
        // out of bounds.
        for (row, out_i) in out.iter_mut().enumerate() {
            let row_start = row * cols;
            // SAFETY: caller ensures w has at least rows*cols elements.
            // dot_avx2 uses min(a.len(), b.len()) so slicing is safe.
            unsafe {
                *out_i = dot_avx2(&w[row_start..row_start + cols], &x[..cols]);
            }
        }
    }
}

138// ---------------------------------------------------------------------------
139// Public safe dispatch functions
140// ---------------------------------------------------------------------------
141
142/// SIMD-accelerated dot product with runtime feature detection.
143///
144/// Uses AVX2 on x86_64 (with `std` feature) when available, falls back to
145/// scalar otherwise.
146///
147/// Returns the dot product of `a` and `b`, processing up to the shorter
148/// slice's length.
149///
150/// # Examples
151///
152/// ```
153/// use irithyll_core::simd::simd_dot;
154///
155/// let a = [1.0, 2.0, 3.0];
156/// let b = [4.0, 5.0, 6.0];
157/// assert!((simd_dot(&a, &b) - 32.0).abs() < 1e-12);
158/// ```
159pub fn simd_dot(a: &[f64], b: &[f64]) -> f64 {
160    #[cfg(all(target_arch = "x86_64", feature = "std"))]
161    {
162        if is_x86_feature_detected!("avx2") {
163            // SAFETY: we just checked for AVX2 support.
164            return unsafe { avx2::dot_avx2(a, b) };
165        }
166    }
167    dot_scalar(a, b)
168}
169
170/// SIMD-accelerated matrix-vector multiply with runtime feature detection.
171///
172/// Computes `out[i] = sum_j w[i*cols + j] * x[j]` for each row.
173/// Uses AVX2 on x86_64 (with `std` feature) when available, falls back to
174/// scalar otherwise.
175///
176/// # Panics
177///
178/// Panics if `w.len() < rows * cols`, `out.len() < rows`, or `x.len() < cols`.
179///
180/// # Examples
181///
182/// ```
183/// use irithyll_core::simd::simd_mat_vec;
184///
185/// // 2x3 matrix times 3-vector
186/// let w = [1.0, 2.0, 3.0,  4.0, 5.0, 6.0];
187/// let x = [1.0, 1.0, 1.0];
188/// let mut out = [0.0; 2];
189/// simd_mat_vec(&w, &x, 2, 3, &mut out);
190/// assert!((out[0] - 6.0).abs() < 1e-12);   // 1+2+3
191/// assert!((out[1] - 15.0).abs() < 1e-12);  // 4+5+6
192/// ```
193pub fn simd_mat_vec(w: &[f64], x: &[f64], rows: usize, cols: usize, out: &mut [f64]) {
194    assert!(
195        w.len() >= rows * cols,
196        "simd_mat_vec: w.len()={} < rows*cols={}",
197        w.len(),
198        rows * cols
199    );
200    assert!(
201        out.len() >= rows,
202        "simd_mat_vec: out.len()={} < rows={}",
203        out.len(),
204        rows
205    );
206    assert!(
207        x.len() >= cols,
208        "simd_mat_vec: x.len()={} < cols={}",
209        x.len(),
210        cols
211    );
212
213    #[cfg(all(target_arch = "x86_64", feature = "std"))]
214    {
215        if is_x86_feature_detected!("avx2") {
216            // SAFETY: bounds checked above, AVX2 detected.
217            unsafe {
218                avx2::mat_vec_avx2(w, x, rows, cols, out);
219            }
220            return;
221        }
222    }
223    mat_vec_scalar(w, x, rows, cols, out);
224}
225
// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `alloc` imports suggest the crate declares
    // `extern crate alloc` for no_std builds — confirm at the crate root.
    use alloc::vec;
    use alloc::vec::Vec;

    // Simple deterministic PRNG for test data generation.
    // Xorshift-style: a zero state maps to zero forever, so seeds used
    // below are nonzero.
    struct TestRng(u64);

    impl TestRng {
        fn new(seed: u64) -> Self {
            Self(seed)
        }

        // One xorshift step (shifts 13/7/17); deterministic across runs.
        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }

        fn next_f64(&mut self) -> f64 {
            // Map to [-1, 1) range for interesting test values
            // (top 53 bits give a uniform [0, 1) double, then scale/shift).
            (self.next_u64() >> 11) as f64 / ((1u64 << 53) as f64) * 2.0 - 1.0
        }

        // Build an n-element Vec of pseudo-random values in [-1, 1).
        fn fill_vec(&mut self, n: usize) -> Vec<f64> {
            (0..n).map(|_| self.next_f64()).collect()
        }
    }

    // -------------------------------------------------------------------
    // Dot product tests
    // -------------------------------------------------------------------

    #[test]
    fn dot_empty_returns_zero() {
        let a: [f64; 0] = [];
        let b: [f64; 0] = [];
        // Exact float comparison is fine here: the empty sum is exactly 0.0.
        assert_eq!(simd_dot(&a, &b), 0.0, "dot of empty slices should be 0");
    }

    #[test]
    fn dot_single_element() {
        let a = [3.0];
        let b = [4.0];
        assert!(
            (simd_dot(&a, &b) - 12.0).abs() < 1e-12,
            "dot([3], [4]) should be 12, got {}",
            simd_dot(&a, &b)
        );
    }

    #[test]
    fn dot_known_result() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 32.0).abs() < 1e-12,
            "dot([1,2,3], [4,5,6]) should be 32, got {}",
            result
        );
    }

    #[test]
    fn dot_large_matches_scalar() {
        let mut rng = TestRng::new(42);
        let a = rng.fill_vec(1000);
        let b = rng.fill_vec(1000);

        let simd_result = simd_dot(&a, &b);
        let scalar_result = dot_scalar(&a, &b);

        // Looser tolerance (1e-9): the AVX2 path accumulates 4 lanes
        // independently, reordering additions versus the scalar loop.
        assert!(
            (simd_result - scalar_result).abs() < 1e-9,
            "1000-element dot: SIMD={} vs scalar={}, diff={}",
            simd_result,
            scalar_result,
            (simd_result - scalar_result).abs()
        );
    }

    #[test]
    fn dot_mismatched_lengths() {
        // Should use the shorter length
        let a = [1.0, 2.0, 3.0, 999.0];
        let b = [4.0, 5.0, 6.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 32.0).abs() < 1e-12,
            "mismatched lengths should use min, expected 32, got {}",
            result
        );
    }

    #[test]
    fn dot_non_aligned_length() {
        // 7 elements: 1 full AVX2 chunk (4) + 3 remainder
        // Exercises both the vector loop and the scalar tail.
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let b = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 28.0).abs() < 1e-12,
            "dot of [1..7] with [1..1] should be 28, got {}",
            result
        );
    }

    #[test]
    fn dot_negative_values() {
        let a = [-1.0, -2.0, -3.0, -4.0];
        let b = [4.0, 3.0, 2.0, 1.0];
        // -4 + -6 + -6 + -4 = -20
        let result = simd_dot(&a, &b);
        assert!(
            (result - (-20.0)).abs() < 1e-12,
            "expected -20, got {}",
            result
        );
    }

    #[test]
    fn dot_orthogonal_vectors() {
        let a = [1.0, 0.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0, 0.0];
        let result = simd_dot(&a, &b);
        assert!(
            result.abs() < 1e-12,
            "orthogonal vectors should have dot=0, got {}",
            result
        );
    }

    // -------------------------------------------------------------------
    // Matrix-vector multiply tests
    // -------------------------------------------------------------------

    #[test]
    fn mat_vec_identity_like() {
        // 3x3 identity matrix times [1, 2, 3] = [1, 2, 3]
        let w = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 3];
        simd_mat_vec(&w, &x, 3, 3, &mut out);
        assert!(
            (out[0] - 1.0).abs() < 1e-12,
            "identity row 0: expected 1, got {}",
            out[0]
        );
        assert!(
            (out[1] - 2.0).abs() < 1e-12,
            "identity row 1: expected 2, got {}",
            out[1]
        );
        assert!(
            (out[2] - 3.0).abs() < 1e-12,
            "identity row 2: expected 3, got {}",
            out[2]
        );
    }

    #[test]
    fn mat_vec_known_result() {
        // 2x3 matrix:
        // [1 2 3]   [1]   [1+4+9]   [14]
        // [4 5 6] * [2] = [4+10+18] = [32]
        //           [3]
        let w = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
        assert!(
            (out[0] - 14.0).abs() < 1e-12,
            "row 0: expected 14, got {}",
            out[0]
        );
        assert!(
            (out[1] - 32.0).abs() < 1e-12,
            "row 1: expected 32, got {}",
            out[1]
        );
    }

    #[test]
    fn mat_vec_large_matches_scalar() {
        let mut rng = TestRng::new(7777);
        let rows = 100;
        let cols = 100;
        let w = rng.fill_vec(rows * cols);
        let x = rng.fill_vec(cols);
        let mut out_simd = vec![0.0; rows];
        let mut out_scalar = vec![0.0; rows];

        simd_mat_vec(&w, &x, rows, cols, &mut out_simd);
        mat_vec_scalar(&w, &x, rows, cols, &mut out_scalar);

        // 1e-9 tolerance: SIMD lane accumulation reorders additions.
        for i in 0..rows {
            assert!(
                (out_simd[i] - out_scalar[i]).abs() < 1e-9,
                "row {}: SIMD={} vs scalar={}, diff={}",
                i,
                out_simd[i],
                out_scalar[i],
                (out_simd[i] - out_scalar[i]).abs()
            );
        }
    }

    #[test]
    fn mat_vec_single_row() {
        // 1xN is just a dot product
        let w = [1.0, 2.0, 3.0, 4.0, 5.0];
        let x = [2.0, 2.0, 2.0, 2.0, 2.0];
        let mut out = [0.0; 1];
        simd_mat_vec(&w, &x, 1, 5, &mut out);
        // 2+4+6+8+10 = 30
        assert!(
            (out[0] - 30.0).abs() < 1e-12,
            "single-row mat_vec should be dot product, expected 30, got {}",
            out[0]
        );
    }

    #[test]
    fn mat_vec_single_element() {
        let w = [7.0];
        let x = [3.0];
        let mut out = [0.0; 1];
        simd_mat_vec(&w, &x, 1, 1, &mut out);
        assert!(
            (out[0] - 21.0).abs() < 1e-12,
            "1x1 mat_vec: 7*3=21, got {}",
            out[0]
        );
    }

    // -------------------------------------------------------------------
    // Panic tests — `expected` strings must stay prefixes of the assert
    // messages in `simd_mat_vec`.
    // -------------------------------------------------------------------

    #[test]
    #[should_panic(expected = "simd_mat_vec: w.len()")]
    fn mat_vec_panics_w_too_short() {
        let w = [1.0, 2.0]; // need 2*3=6
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }

    #[test]
    #[should_panic(expected = "simd_mat_vec: out.len()")]
    fn mat_vec_panics_out_too_short() {
        let w = [1.0; 6];
        let x = [1.0; 3];
        let mut out = [0.0; 1]; // need 2
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }

    #[test]
    #[should_panic(expected = "simd_mat_vec: x.len()")]
    fn mat_vec_panics_x_too_short() {
        let w = [1.0; 6];
        let x = [1.0; 2]; // need 3
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }

    // -------------------------------------------------------------------
    // Platform-specific test
    // -------------------------------------------------------------------

    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn simd_available_on_x86() {
        // On modern x86_64, AVX2 should be available.
        // This test verifies the runtime detection path doesn't panic.
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let result = simd_dot(&a, &b);
        // 8+14+18+20+20+18+14+8 = 120
        assert!(
            (result - 120.0).abs() < 1e-12,
            "8-element dot product should be 120, got {}",
            result
        );

        // Also verify AVX2 is actually detected (informational).
        if is_x86_feature_detected!("avx2") {
            // AVX2 path was used — good.
        }
    }
}