Skip to main content

cjc_runtime/
tensor_simd.rs

1//! SIMD acceleration for tensor operations (AVX2, 4-wide f64).
2//!
3//! Provides AVX2-accelerated kernels for:
4//! - Element-wise binary operations (add, sub, mul, div)
5//! - Element-wise unary operations (relu, abs, neg, sqrt)
6//! - Inner loop of tiled matrix multiplication (axpy: c += a * b)
7//!
8//! # Determinism
9//!
10//! All SIMD paths produce **bit-identical** results to scalar paths because:
11//! - IEEE 754 mandates identical rounding for scalar and SIMD add/sub/mul/div/sqrt.
12//! - **No FMA** instructions are used (`_mm256_fmadd_pd` changes rounding vs
13//!   separate mul+add — we explicitly avoid it).
14//! - Element-wise ops are independent — no cross-lane reductions.
15//! - Tiled matmul SIMD processes multiple j-columns simultaneously but each
16//!   `C[i,j]` accumulates the same values in the same order.
17//!
18//! # Fallback
19//!
20//! On non-x86_64 platforms or CPUs without AVX2, all functions fall back to
21//! scalar implementations that produce identical results.
22
/// Reports whether the running CPU supports AVX2.
///
/// On targets other than x86_64 this is always `false`. On x86_64 the answer comes
/// from `is_x86_feature_detected!`, which caches the CPUID probe after its first
/// call, so querying it from hot dispatch paths is cheap.
#[inline]
pub fn has_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        is_x86_feature_detected!("avx2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
36
37// ── Element-wise binary operations ──────────────────────────────────────────
38
/// Dispatch tag for SIMD-able binary operations.
///
/// Selects which element-wise operation `simd_binop` applies:
/// `Add` → `a + b`, `Sub` → `a - b`, `Mul` → `a * b`, `Div` → `a / b`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
}
47
/// Minimum element count at which `simd_binop` switches to its multi-threaded path.
///
/// Below this size, thread scheduling overhead outweighs the gain. The value is only
/// consulted when the `parallel` feature is enabled.
// Without the `parallel` feature no non-test code reads this constant; allow the
// resulting dead-code lint rather than compiling the constant out, so tests can
// still refer to it.
#[cfg_attr(not(feature = "parallel"), allow(dead_code))]
const PARALLEL_THRESHOLD: usize = 100_000;
51
52/// SIMD-accelerated element-wise binary operation on equal-length slices.
53///
54/// Returns a new Vec with `out[i] = a[i] ⊕ b[i]` for the chosen operation.
55/// Bit-identical to the scalar loop `a.iter().zip(b).map(|(&x, &y)| op(x, y))`.
56///
57/// For tensors > 100K elements (when the `parallel` feature is enabled),
58/// splits work across threads with each thread using SIMD on its chunk.
59/// Deterministic because each element is independent (no cross-element reduction).
60pub fn simd_binop(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
61    let n = a.len();
62    debug_assert_eq!(n, b.len());
63
64    // Parallel path for large tensors.
65    #[cfg(feature = "parallel")]
66    {
67        if n >= PARALLEL_THRESHOLD {
68            return simd_binop_parallel(a, b, op);
69        }
70    }
71
72    simd_binop_sequential(a, b, op)
73}
74
75/// Sequential SIMD binop (used for small/medium tensors or as fallback).
76fn simd_binop_sequential(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
77    let n = a.len();
78    let mut out = vec![0.0f64; n];
79
80    #[cfg(target_arch = "x86_64")]
81    {
82        if has_avx2() {
83            unsafe {
84                match op {
85                    BinOp::Add => avx2_binop::<ADD_TAG>(a, b, &mut out),
86                    BinOp::Sub => avx2_binop::<SUB_TAG>(a, b, &mut out),
87                    BinOp::Mul => avx2_binop::<MUL_TAG>(a, b, &mut out),
88                    BinOp::Div => avx2_binop::<DIV_TAG>(a, b, &mut out),
89                }
90            }
91            return out;
92        }
93    }
94
95    // Scalar fallback
96    match op {
97        BinOp::Add => { for i in 0..n { out[i] = a[i] + b[i]; } }
98        BinOp::Sub => { for i in 0..n { out[i] = a[i] - b[i]; } }
99        BinOp::Mul => { for i in 0..n { out[i] = a[i] * b[i]; } }
100        BinOp::Div => { for i in 0..n { out[i] = a[i] / b[i]; } }
101    }
102    out
103}
104
/// Multi-threaded binop for large inputs (only built with the `parallel` feature).
///
/// The output buffer is carved into fixed-size chunks that rayon distributes across
/// worker threads. Each worker runs the AVX2 kernel, or the scalar loop when AVX2 is
/// unavailable, on the slices of `a` and `b` that line up with its chunk. Because
/// `out[i]` depends only on `a[i]` and `b[i]`, the result is the same however rayon
/// schedules the chunks.
#[cfg(feature = "parallel")]
fn simd_binop_parallel(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
    use rayon::prelude::*;

    // 4096 f64 values = 32 KiB per operand slice.
    const CHUNK: usize = 4096;

    let mut result = vec![0.0f64; a.len()];

    result
        .par_chunks_mut(CHUNK)
        .enumerate()
        .for_each(|(index, dst)| {
            let lo = index * CHUNK;
            let hi = lo + dst.len();
            let xs = &a[lo..hi];
            let ys = &b[lo..hi];

            #[cfg(target_arch = "x86_64")]
            {
                if has_avx2() {
                    unsafe {
                        match op {
                            BinOp::Add => avx2_binop::<ADD_TAG>(xs, ys, dst),
                            BinOp::Sub => avx2_binop::<SUB_TAG>(xs, ys, dst),
                            BinOp::Mul => avx2_binop::<MUL_TAG>(xs, ys, dst),
                            BinOp::Div => avx2_binop::<DIV_TAG>(xs, ys, dst),
                        }
                    }
                    return;
                }
            }

            let lanes = dst.iter_mut().zip(xs.iter().zip(ys));
            match op {
                BinOp::Add => lanes.for_each(|(d, (&x, &y))| *d = x + y),
                BinOp::Sub => lanes.for_each(|(d, (&x, &y))| *d = x - y),
                BinOp::Mul => lanes.for_each(|(d, (&x, &y))| *d = x * y),
                BinOp::Div => lanes.for_each(|(d, (&x, &y))| *d = x / y),
            }
        });

    result
}
150
// Const tags selecting the operation in the generic AVX2 binop kernel. They are only
// referenced from x86_64-specific code, so they are compiled out on other targets to
// avoid dead-code warnings.
#[cfg(target_arch = "x86_64")]
const ADD_TAG: u8 = 0;
#[cfg(target_arch = "x86_64")]
const SUB_TAG: u8 = 1;
#[cfg(target_arch = "x86_64")]
const MUL_TAG: u8 = 2;
#[cfg(target_arch = "x86_64")]
const DIV_TAG: u8 = 3;

/// AVX2 kernel computing `out[i] = a[i] ⊕ b[i]` for every `i < a.len()`.
///
/// `OP` is one of `ADD_TAG`, `SUB_TAG`, `MUL_TAG` or `DIV_TAG`. Because it is a const
/// generic, each instantiation lets the compiler fold the `match` on it. Full 4-lane
/// groups use unaligned 256-bit loads and stores. The 0–3 leftover elements use the
/// equivalent scalar operation, which rounds identically.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. via `has_avx2()`).
///
/// # Panics
///
/// Panics if `b` or `out` is shorter than `a`. This is checked before any raw-pointer
/// access, because the vector loop reads `b` and writes `out` without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_binop<const OP: u8>(a: &[f64], b: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;
    let n = a.len();
    assert!(
        b.len() >= n && out.len() >= n,
        "avx2_binop: operands shorter than `a` (a={}, b={}, out={})",
        n,
        b.len(),
        out.len()
    );
    let mut i = 0;

    while i + 4 <= n {
        let va = _mm256_loadu_pd(a.as_ptr().add(i));
        let vb = _mm256_loadu_pd(b.as_ptr().add(i));
        let vr = match OP {
            ADD_TAG => _mm256_add_pd(va, vb),
            SUB_TAG => _mm256_sub_pd(va, vb),
            MUL_TAG => _mm256_mul_pd(va, vb),
            DIV_TAG => _mm256_div_pd(va, vb),
            _ => unreachable!("avx2_binop: unknown op tag {}", OP),
        };
        _mm256_storeu_pd(out.as_mut_ptr().add(i), vr);
        i += 4;
    }

    // Scalar tail (0-3 elements)
    while i < n {
        out[i] = match OP {
            ADD_TAG => a[i] + b[i],
            SUB_TAG => a[i] - b[i],
            MUL_TAG => a[i] * b[i],
            DIV_TAG => a[i] / b[i],
            _ => unreachable!("avx2_binop: unknown op tag {}", OP),
        };
        i += 1;
    }
}
188
189// ── Element-wise unary operations ───────────────────────────────────────────
190
/// Dispatch tag for SIMD-able unary operations.
///
/// Selects which element-wise operation `simd_unary` applies:
/// `Sqrt` → `x.sqrt()`, `Abs` → `x.abs()`, `Neg` → `-x`,
/// `Relu` → `if x > 0.0 { x } else { 0.0 }`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    Sqrt,
    Abs,
    Neg,
    Relu,
}
199
200/// SIMD-accelerated element-wise unary operation.
201///
202/// Returns a new Vec with `out[i] = f(a[i])`.
203/// Bit-identical to scalar for all supported operations:
204/// - `sqrt`: IEEE 754 mandates correctly-rounded sqrt.
205/// - `abs`: Bit mask operation (clear sign bit).
206/// - `neg`: Bit flip operation (toggle sign bit).
207/// - `relu`: max(0, x) via compare + blend.
208pub fn simd_unary(a: &[f64], op: UnaryOp) -> Vec<f64> {
209    let n = a.len();
210    let mut out = vec![0.0f64; n];
211
212    #[cfg(target_arch = "x86_64")]
213    {
214        if has_avx2() {
215            unsafe {
216                match op {
217                    UnaryOp::Sqrt => avx2_sqrt(a, &mut out),
218                    UnaryOp::Abs  => avx2_abs(a, &mut out),
219                    UnaryOp::Neg  => avx2_neg(a, &mut out),
220                    UnaryOp::Relu => avx2_relu(a, &mut out),
221                }
222            }
223            return out;
224        }
225    }
226
227    // Scalar fallback
228    match op {
229        UnaryOp::Sqrt => { for i in 0..n { out[i] = a[i].sqrt(); } }
230        UnaryOp::Abs  => { for i in 0..n { out[i] = a[i].abs(); } }
231        UnaryOp::Neg  => { for i in 0..n { out[i] = -a[i]; } }
232        UnaryOp::Relu => { for i in 0..n { out[i] = if a[i] > 0.0 { a[i] } else { 0.0 }; } }
233    }
234    out
235}
236
/// AVX2 element-wise square root: `out[i] = a[i].sqrt()` for every `i < a.len()`.
///
/// `vsqrtpd` is correctly rounded, so each lane matches `f64::sqrt`. The 0–3 element
/// tail calls `f64::sqrt` directly.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
///
/// Panics if `out` is shorter than `a`. This is checked before any store, because the
/// vector loop writes `out` through raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sqrt(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;
    let n = a.len();
    assert!(out.len() >= n, "avx2_sqrt: `out` is shorter than `a`");
    let mut i = 0;
    while i + 4 <= n {
        let va = _mm256_loadu_pd(a.as_ptr().add(i));
        let vr = _mm256_sqrt_pd(va);
        _mm256_storeu_pd(out.as_mut_ptr().add(i), vr);
        i += 4;
    }
    for (o, &x) in out[i..n].iter_mut().zip(&a[i..]) {
        *o = x.sqrt();
    }
}
251
/// AVX2 element-wise absolute value: `out[i] = a[i].abs()` for every `i < a.len()`.
///
/// Implemented as a bitwise AND that clears the sign bit, which is exactly what
/// `f64::abs` does. The 0–3 element tail calls `f64::abs` directly.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
///
/// Panics if `out` is shorter than `a`. This is checked before any store, because the
/// vector loop writes `out` through raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_abs(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;
    let n = a.len();
    assert!(out.len() >= n, "avx2_abs: `out` is shorter than `a`");
    // Clear sign bit: AND with 0x7FFF_FFFF_FFFF_FFFF
    let mask = _mm256_set1_pd(f64::from_bits(0x7FFF_FFFF_FFFF_FFFFu64));
    let mut i = 0;
    while i + 4 <= n {
        let va = _mm256_loadu_pd(a.as_ptr().add(i));
        let vr = _mm256_and_pd(va, mask);
        _mm256_storeu_pd(out.as_mut_ptr().add(i), vr);
        i += 4;
    }
    for (o, &x) in out[i..n].iter_mut().zip(&a[i..]) {
        *o = x.abs();
    }
}
268
/// AVX2 element-wise negation: `out[i] = -a[i]` for every `i < a.len()`.
///
/// Implemented as a bitwise XOR that flips the sign bit, which is exactly what unary
/// `-` on `f64` does. The 0–3 element tail uses unary `-` directly.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
///
/// Panics if `out` is shorter than `a`. This is checked before any store, because the
/// vector loop writes `out` through raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_neg(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;
    let n = a.len();
    assert!(out.len() >= n, "avx2_neg: `out` is shorter than `a`");
    // Toggle sign bit: XOR with 0x8000_0000_0000_0000
    let sign_bit = _mm256_set1_pd(f64::from_bits(0x8000_0000_0000_0000u64));
    let mut i = 0;
    while i + 4 <= n {
        let va = _mm256_loadu_pd(a.as_ptr().add(i));
        let vr = _mm256_xor_pd(va, sign_bit);
        _mm256_storeu_pd(out.as_mut_ptr().add(i), vr);
        i += 4;
    }
    for (o, &x) in out[i..n].iter_mut().zip(&a[i..]) {
        *o = -x;
    }
}
285
/// AVX2 element-wise ReLU: `out[i] = if a[i] > 0.0 { a[i] } else { 0.0 }`.
///
/// Uses `vmaxpd(x, +0.0)`. MAXPD returns its second operand whenever the first is not
/// strictly greater, including when either input is NaN or both are zeros. NaN and
/// `-0.0` therefore become `+0.0`, exactly like the scalar tail expression.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
///
/// Panics if `out` is shorter than `a`. This is checked before any store, because the
/// vector loop writes `out` through raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_relu(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;
    let n = a.len();
    assert!(out.len() >= n, "avx2_relu: `out` is shorter than `a`");
    let zero = _mm256_setzero_pd();
    let mut i = 0;
    while i + 4 <= n {
        let va = _mm256_loadu_pd(a.as_ptr().add(i));
        // Operand order matters: `zero` must be second so NaN/-0.0 yield +0.0.
        let vr = _mm256_max_pd(va, zero);
        _mm256_storeu_pd(out.as_mut_ptr().add(i), vr);
        i += 4;
    }
    for (o, &x) in out[i..n].iter_mut().zip(&a[i..]) {
        *o = if x > 0.0 { x } else { 0.0 };
    }
}
301
302// ── Tiled matmul AXPY kernel ────────────────────────────────────────────────
303
304/// SIMD-accelerated AXPY: `c[0..len] += scalar * b[0..len]`.
305///
306/// Used in the inner loop of tiled matrix multiplication where `scalar = A[i,p]`
307/// and `b` is a row segment of B. Processes 4 elements per iteration with AVX2.
308///
309/// Deterministic because each `c[j]` accumulates the same `scalar * b[j]`
310/// contribution using separate mul + add (no FMA), matching scalar behavior.
311pub fn simd_axpy(c: &mut [f64], b: &[f64], scalar: f64, len: usize) {
312    debug_assert!(c.len() >= len);
313    debug_assert!(b.len() >= len);
314
315    #[cfg(target_arch = "x86_64")]
316    {
317        if has_avx2() {
318            unsafe { avx2_axpy(c, b, scalar, len); }
319            return;
320        }
321    }
322
323    // Scalar fallback
324    for j in 0..len {
325        c[j] += scalar * b[j];
326    }
327}
328
/// AVX2 kernel for `c[j] += scalar * b[j]` over `j < len`.
///
/// Each 4-lane group uses a separate `vmulpd` then `vaddpd`. FMA is deliberately
/// avoided, because its single rounding would differ from the scalar
/// `c[j] + scalar * b[j]`. The 0–3 leftover elements use that same scalar expression.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
///
/// Panics if `c` or `b` has fewer than `len` elements. This is checked before any
/// raw-pointer access.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_axpy(c: &mut [f64], b: &[f64], scalar: f64, len: usize) {
    use std::arch::x86_64::*;
    assert!(
        c.len() >= len && b.len() >= len,
        "avx2_axpy: len {} exceeds slice lengths (c={}, b={})",
        len,
        c.len(),
        b.len()
    );
    let a_vec = _mm256_set1_pd(scalar);
    let mut j = 0;

    while j + 4 <= len {
        let c_ptr = c.as_mut_ptr().add(j);
        let b_ptr = b.as_ptr().add(j);
        let c_val = _mm256_loadu_pd(c_ptr);
        let b_val = _mm256_loadu_pd(b_ptr);
        // Separate mul + add (NOT FMA) — preserves bit-identity with scalar path.
        let prod = _mm256_mul_pd(a_vec, b_val);
        let result = _mm256_add_pd(c_val, prod);
        _mm256_storeu_pd(c_ptr, result);
        j += 4;
    }

    // Scalar tail (0-3 elements). Bounds were validated above, so checked indexing
    // costs at most three comparisons and needs no `unsafe`.
    for (cj, &bj) in c[j..len].iter_mut().zip(&b[j..len]) {
        *cj += scalar * bj;
    }
}
354
355// ── Tests ───────────────────────────────────────────────────────────────────
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Values that exercise IEEE 754 corner cases: signed zeros, infinities, NaN of
    /// both signs, the smallest normal and subnormal magnitudes, and the extremes.
    /// There are thirteen entries, so the 4-wide kernels also run a 1-element tail.
    fn special_values() -> Vec<f64> {
        vec![
            0.0,
            -0.0,
            1.5,
            -1.5,
            f64::INFINITY,
            f64::NEG_INFINITY,
            f64::NAN,
            -f64::NAN,
            f64::MIN_POSITIVE,
            5e-324,
            f64::MAX,
            -f64::MAX,
            0.1,
        ]
    }

    /// Raw bit patterns, so `-0.0` vs `0.0` and NaN sign bits are compared exactly.
    fn bits(v: &[f64]) -> Vec<u64> {
        v.iter().map(|x| x.to_bits()).collect()
    }

    /// Scalar reference for `BinOp`.
    fn scalar_binop(x: f64, y: f64, op: BinOp) -> f64 {
        match op {
            BinOp::Add => x + y,
            BinOp::Sub => x - y,
            BinOp::Mul => x * y,
            BinOp::Div => x / y,
        }
    }

    /// Scalar reference for `UnaryOp`.
    fn scalar_unary(x: f64, op: UnaryOp) -> f64 {
        match op {
            UnaryOp::Sqrt => x.sqrt(),
            UnaryOp::Abs => x.abs(),
            UnaryOp::Neg => -x,
            UnaryOp::Relu => {
                if x > 0.0 {
                    x
                } else {
                    0.0
                }
            }
        }
    }

    /// Asserts bit-identity, treating any two NaNs as equal.
    ///
    /// Rust does not guarantee the sign or payload of a NaN *produced* by
    /// arithmetic, so results of add/sub/mul/div/sqrt are only compared up to
    /// "both are NaN".
    fn assert_same_modulo_nan(actual: &[f64], expected: &[f64], what: &str) {
        assert_eq!(actual.len(), expected.len(), "{}: length mismatch", what);
        for (i, (x, y)) in actual.iter().zip(expected).enumerate() {
            let same = x.to_bits() == y.to_bits() || (x.is_nan() && y.is_nan());
            assert!(
                same,
                "{}[{}]: got {:?} ({:#x}), expected {:?} ({:#x})",
                what,
                i,
                x,
                x.to_bits(),
                y,
                y.to_bits()
            );
        }
    }

    const BIN_OPS: [(BinOp, &str); 4] = [
        (BinOp::Add, "add"),
        (BinOp::Sub, "sub"),
        (BinOp::Mul, "mul"),
        (BinOp::Div, "div"),
    ];

    #[test]
    fn test_simd_add_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.3).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.7).collect();
        let result = simd_binop(&a, &b, BinOp::Add);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect();
        assert_eq!(result, expected, "SIMD add must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_sub_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 1.1).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.9).collect();
        let result = simd_binop(&a, &b, BinOp::Sub);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect();
        assert_eq!(result, expected, "SIMD sub must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_mul_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.1 + 0.01).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.2 + 0.03).collect();
        let result = simd_binop(&a, &b, BinOp::Mul);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect();
        assert_eq!(result, expected, "SIMD mul must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_div_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.5 + 1.0).collect();
        let b: Vec<f64> = (0..17).map(|i| (i + 1) as f64 * 0.3).collect();
        let result = simd_binop(&a, &b, BinOp::Div);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x / y).collect();
        assert_eq!(result, expected, "SIMD div must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_sqrt_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 2.5 + 0.1).collect();
        let result = simd_unary(&a, UnaryOp::Sqrt);
        let expected: Vec<f64> = a.iter().map(|&x| x.sqrt()).collect();
        assert_eq!(result, expected, "SIMD sqrt must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_abs_matches_scalar() {
        let a: Vec<f64> = (-8..9).map(|i| i as f64 * 1.5).collect();
        let result = simd_unary(&a, UnaryOp::Abs);
        let expected: Vec<f64> = a.iter().map(|&x| x.abs()).collect();
        assert_eq!(result, expected, "SIMD abs must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_neg_matches_scalar() {
        let a: Vec<f64> = (-8..9).map(|i| i as f64 * 1.5).collect();
        let result = simd_unary(&a, UnaryOp::Neg);
        let expected: Vec<f64> = a.iter().map(|&x| -x).collect();
        assert_eq!(result, expected, "SIMD neg must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_relu_matches_scalar() {
        let a: Vec<f64> = (-8..9).map(|i| i as f64 * 1.5).collect();
        let result = simd_unary(&a, UnaryOp::Relu);
        let expected: Vec<f64> = a.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect();
        assert_eq!(result, expected, "SIMD relu must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_axpy_matches_scalar() {
        let b: Vec<f64> = (0..17).map(|i| i as f64 * 0.3).collect();
        let scalar = 2.5;
        let mut c_simd: Vec<f64> = (0..17).map(|i| i as f64 * 0.1).collect();
        let mut c_scalar = c_simd.clone();

        simd_axpy(&mut c_simd, &b, scalar, 17);
        for j in 0..17 {
            c_scalar[j] += scalar * b[j];
        }
        assert_eq!(c_simd, c_scalar, "SIMD axpy must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_empty_input() {
        let empty: Vec<f64> = vec![];
        assert_eq!(simd_binop(&empty, &empty, BinOp::Add), Vec::<f64>::new());
        assert_eq!(simd_unary(&empty, UnaryOp::Sqrt), Vec::<f64>::new());
    }

    #[test]
    fn test_simd_single_element() {
        let a = vec![3.0];
        let b = vec![4.0];
        assert_eq!(simd_binop(&a, &b, BinOp::Add), vec![7.0]);
        assert_eq!(simd_unary(&a, UnaryOp::Sqrt), vec![3.0f64.sqrt()]);
    }

    #[test]
    fn test_simd_exactly_four_elements() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        assert_eq!(simd_binop(&a, &b, BinOp::Add), vec![6.0, 8.0, 10.0, 12.0]);
        assert_eq!(simd_binop(&a, &b, BinOp::Mul), vec![5.0, 12.0, 21.0, 32.0]);
    }

    #[test]
    fn test_avx2_detection() {
        // Just verify the function doesn't panic.
        let _has = has_avx2();
    }

    #[test]
    fn test_simd_binops_special_values() {
        let a = special_values();
        let mut b = special_values();
        b.reverse();
        for &(op, name) in BIN_OPS.iter() {
            let result = simd_binop(&a, &b, op);
            let expected: Vec<f64> = a
                .iter()
                .zip(&b)
                .map(|(&x, &y)| scalar_binop(x, y, op))
                .collect();
            assert_same_modulo_nan(&result, &expected, name);
        }
    }

    #[test]
    fn test_simd_sqrt_special_values() {
        let a = special_values();
        let expected: Vec<f64> = a.iter().map(|&x| x.sqrt()).collect();
        assert_same_modulo_nan(&simd_unary(&a, UnaryOp::Sqrt), &expected, "sqrt");
    }

    #[test]
    fn test_simd_sign_ops_bit_exact_on_special_values() {
        // abs/neg only touch the sign bit and relu only selects `x` or +0.0, so
        // these must match the scalar path bit for bit, NaNs included.
        let a = special_values();
        for &(op, name) in [
            (UnaryOp::Abs, "abs"),
            (UnaryOp::Neg, "neg"),
            (UnaryOp::Relu, "relu"),
        ]
        .iter()
        {
            let expected: Vec<f64> = a.iter().map(|&x| scalar_unary(x, op)).collect();
            assert_eq!(bits(&simd_unary(&a, op)), bits(&expected), "{}", name);
        }
    }

    #[test]
    fn test_simd_relu_maps_nan_and_negative_zero_to_positive_zero() {
        let a = vec![-0.0, f64::NAN, -2.0, 3.0, -f64::NAN];
        let out = simd_unary(&a, UnaryOp::Relu);
        assert_eq!(bits(&out), bits(&[0.0, 0.0, 0.0, 3.0, 0.0]));
    }

    #[test]
    fn test_simd_every_tail_length_matches_scalar() {
        // Lengths 0..=9 cover an empty input, pure tails, exact multiples of 4 and
        // every vector-body + tail combination.
        for n in 0..=9usize {
            let a: Vec<f64> = (0..n).map(|i| i as f64 * 1.25 - 3.0).collect();
            let b: Vec<f64> = (0..n).map(|i| 2.5 - i as f64 * 0.75).collect();
            for &(op, name) in BIN_OPS.iter() {
                let expected: Vec<f64> = a
                    .iter()
                    .zip(&b)
                    .map(|(&x, &y)| scalar_binop(x, y, op))
                    .collect();
                assert_same_modulo_nan(&simd_binop(&a, &b, op), &expected, name);
            }
            for &op in [UnaryOp::Sqrt, UnaryOp::Abs, UnaryOp::Neg, UnaryOp::Relu].iter() {
                let expected: Vec<f64> = a.iter().map(|&x| scalar_unary(x, op)).collect();
                assert_same_modulo_nan(&simd_unary(&a, op), &expected, "unary");
            }
        }
    }

    #[test]
    fn test_simd_binop_at_parallel_threshold_matches_scalar() {
        // Large enough to take the parallel path when that feature is enabled.
        let n = PARALLEL_THRESHOLD + 5;
        let a: Vec<f64> = (0..n).map(|i| i as f64 * 0.5 - 1000.0).collect();
        let b: Vec<f64> = (0..n).map(|i| (i % 97) as f64 + 1.0).collect();
        for &(op, name) in BIN_OPS.iter() {
            let result = simd_binop(&a, &b, op);
            let expected: Vec<f64> = a
                .iter()
                .zip(&b)
                .map(|(&x, &y)| scalar_binop(x, y, op))
                .collect();
            assert!(result == expected, "{}: large-input mismatch", name);
        }
    }

    #[test]
    fn test_simd_axpy_leaves_elements_past_len_untouched() {
        let b: Vec<f64> = (0..9).map(|i| i as f64).collect();
        let mut c = vec![1.0; 9];
        simd_axpy(&mut c, &b, 3.0, 6);
        assert_eq!(c, vec![1.0, 4.0, 7.0, 10.0, 13.0, 16.0, 1.0, 1.0, 1.0]);
    }
}