Skip to main content

cjc_runtime/
tensor_simd.rs

1//! SIMD acceleration for tensor operations (AVX2, 4-wide f64).
2//!
3//! Provides AVX2-accelerated kernels for:
4//! - Element-wise binary operations (add, sub, mul, div)
5//! - Element-wise unary operations (relu, abs, neg, sqrt)
6//! - Inner loop of tiled matrix multiplication (axpy: c += a * b)
7//!
8//! # Determinism
9//!
10//! All SIMD paths produce **bit-identical** results to scalar paths because:
11//! - IEEE 754 mandates identical rounding for scalar and SIMD add/sub/mul/div/sqrt.
12//! - **No FMA** instructions are used (`_mm256_fmadd_pd` changes rounding vs
13//!   separate mul+add — we explicitly avoid it).
14//! - Element-wise ops are independent — no cross-lane reductions.
15//! - Tiled matmul SIMD processes multiple j-columns simultaneously but each
16//!   `C[i,j]` accumulates the same values in the same order.
17//!
18//! # Fallback
19//!
20//! On non-x86_64 platforms or CPUs without AVX2, all functions fall back to
21//! scalar implementations that produce identical results.
22
/// Reports whether the running CPU supports AVX2.
///
/// On x86_64 this performs runtime feature detection. On every other
/// architecture it is a constant `false`, which routes callers to their
/// scalar fallbacks.
#[inline]
pub fn has_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // The detection macro caches the CPUID result after the first query.
        is_x86_feature_detected!("avx2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
36
37// ── Element-wise binary operations ──────────────────────────────────────────
38
/// Dispatch tag for SIMD-able binary operations.
///
/// Passed to [`simd_binop`] to select which element-wise operation
/// `out[i] = a[i] ⊕ b[i]` is performed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BinOp {
    /// Element-wise addition: `a[i] + b[i]`.
    Add,
    /// Element-wise subtraction: `a[i] - b[i]`.
    Sub,
    /// Element-wise multiplication: `a[i] * b[i]`.
    Mul,
    /// Element-wise division: `a[i] / b[i]`.
    Div,
}
47
/// Minimum element count at which `simd_binop` switches to its parallel path.
///
/// The comparison is inclusive (`n >= PARALLEL_THRESHOLD`), and the constant
/// is only consulted when the `parallel` feature is enabled. Below this size
/// the work is done sequentially, on the assumption that parallel dispatch
/// overhead would dominate.
const PARALLEL_THRESHOLD: usize = 100_000;
51
52/// SIMD-accelerated element-wise binary operation on equal-length slices.
53///
54/// Returns a new Vec with `out[i] = a[i] ⊕ b[i]` for the chosen operation.
55/// Bit-identical to the scalar loop `a.iter().zip(b).map(|(&x, &y)| op(x, y))`.
56///
57/// For tensors > 100K elements (when the `parallel` feature is enabled),
58/// splits work across threads with each thread using SIMD on its chunk.
59/// Deterministic because each element is independent (no cross-element reduction).
60pub fn simd_binop(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
61    let n = a.len();
62    debug_assert_eq!(n, b.len());
63
64    // Parallel path for large tensors.
65    #[cfg(feature = "parallel")]
66    {
67        if n >= PARALLEL_THRESHOLD {
68            return simd_binop_parallel(a, b, op);
69        }
70    }
71
72    simd_binop_sequential(a, b, op)
73}
74
75/// Sequential SIMD binop (used for small/medium tensors or as fallback).
76fn simd_binop_sequential(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
77    let n = a.len();
78    let mut out = vec![0.0f64; n];
79
80    #[cfg(target_arch = "x86_64")]
81    {
82        if has_avx2() {
83            unsafe {
84                match op {
85                    BinOp::Add => avx2_binop::<ADD_TAG>(a, b, &mut out),
86                    BinOp::Sub => avx2_binop::<SUB_TAG>(a, b, &mut out),
87                    BinOp::Mul => avx2_binop::<MUL_TAG>(a, b, &mut out),
88                    BinOp::Div => avx2_binop::<DIV_TAG>(a, b, &mut out),
89                }
90            }
91            return out;
92        }
93    }
94
95    // Scalar fallback
96    match op {
97        BinOp::Add => { for i in 0..n { out[i] = a[i] + b[i]; } }
98        BinOp::Sub => { for i in 0..n { out[i] = a[i] - b[i]; } }
99        BinOp::Mul => { for i in 0..n { out[i] = a[i] * b[i]; } }
100        BinOp::Div => { for i in 0..n { out[i] = a[i] / b[i]; } }
101    }
102    out
103}
104
/// Parallel SIMD binop for large tensors.
///
/// The output is cut into fixed-size chunks that rayon processes
/// independently; each chunk uses the AVX2 kernel when available and a scalar
/// loop otherwise. Because `out[i] = a[i] ⊕ b[i]` depends on nothing but the
/// matching input pair, the result does not depend on how chunks are
/// scheduled.
#[cfg(feature = "parallel")]
fn simd_binop_parallel(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
    use rayon::prelude::*;

    // 4096 f64 values = 32 KiB per slice per chunk.
    const CHUNK_LEN: usize = 4096;

    let mut result = vec![0.0f64; a.len()];

    // The work is routed through the crate's `run_parallel` policy wrapper.
    // Elements are independent, so any concurrency limit it applies cannot
    // change the computed values.
    crate::runtime_policy::run_parallel(|| {
        result
            .par_chunks_mut(CHUNK_LEN)
            .enumerate()
            .for_each(|(index, dst)| {
                // Locate the input windows that line up with this output chunk.
                let offset = index * CHUNK_LEN;
                let lhs = &a[offset..offset + dst.len()];
                let rhs = &b[offset..offset + dst.len()];

                #[cfg(target_arch = "x86_64")]
                {
                    if has_avx2() {
                        match op {
                            BinOp::Add => unsafe { avx2_binop::<ADD_TAG>(lhs, rhs, dst) },
                            BinOp::Sub => unsafe { avx2_binop::<SUB_TAG>(lhs, rhs, dst) },
                            BinOp::Mul => unsafe { avx2_binop::<MUL_TAG>(lhs, rhs, dst) },
                            BinOp::Div => unsafe { avx2_binop::<DIV_TAG>(lhs, rhs, dst) },
                        }
                        return;
                    }
                }

                // Scalar fallback for this chunk.
                let lanes = dst.iter_mut().zip(lhs.iter().zip(rhs));
                match op {
                    BinOp::Add => lanes.for_each(|(o, (&x, &y))| *o = x + y),
                    BinOp::Sub => lanes.for_each(|(o, (&x, &y))| *o = x - y),
                    BinOp::Mul => lanes.for_each(|(o, (&x, &y))| *o = x * y),
                    BinOp::Div => lanes.for_each(|(o, (&x, &y))| *o = x / y),
                }
            });
    });

    result
}
154
// Const-generic selectors for `avx2_binop`: each value picks one operation so
// the kernel is monomorphised once per operation.
const ADD_TAG: u8 = 0;
const SUB_TAG: u8 = 1;
const MUL_TAG: u8 = 2;
const DIV_TAG: u8 = 3;

/// AVX2 kernel computing `out[i] = a[i] ⊕ b[i]` for `i` in `0..a.len()`.
///
/// `OP` selects the operation via the `*_TAG` constants; any value other than
/// `ADD_TAG`, `SUB_TAG` or `MUL_TAG` is treated as division. Full groups of
/// four elements go through 256-bit unaligned loads and stores, and the
/// remaining 0-3 elements are handled with scalar arithmetic.
///
/// # Safety
///
/// The CPU must support AVX2, and `b` and `out` must each hold at least
/// `a.len()` elements: the vector loop accesses them through raw pointers
/// without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_binop<const OP: u8>(a: &[f64], b: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;

    let len = a.len();
    // Largest multiple of four that fits: the extent of the vector loop.
    let vec_end = len - len % 4;
    let (lhs, rhs, dst) = (a.as_ptr(), b.as_ptr(), out.as_mut_ptr());

    for base in (0..vec_end).step_by(4) {
        let x = _mm256_loadu_pd(lhs.add(base));
        let y = _mm256_loadu_pd(rhs.add(base));
        let combined = match OP {
            ADD_TAG => _mm256_add_pd(x, y),
            SUB_TAG => _mm256_sub_pd(x, y),
            MUL_TAG => _mm256_mul_pd(x, y),
            // DIV_TAG (and any unrecognised tag) divides.
            _ => _mm256_div_pd(x, y),
        };
        _mm256_storeu_pd(dst.add(base), combined);
    }

    // Scalar tail (0-3 elements)
    for k in vec_end..len {
        let (x, y) = (a[k], b[k]);
        out[k] = match OP {
            ADD_TAG => x + y,
            SUB_TAG => x - y,
            MUL_TAG => x * y,
            _ => x / y,
        };
    }
}
192
193// ── Element-wise unary operations ───────────────────────────────────────────
194
/// Dispatch tag for SIMD-able unary operations.
///
/// Passed to [`simd_unary`] to select which element-wise function
/// `out[i] = f(a[i])` is applied.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UnaryOp {
    /// Square root: `a[i].sqrt()`.
    Sqrt,
    /// Absolute value: `a[i].abs()`.
    Abs,
    /// Negation: `-a[i]`.
    Neg,
    /// Rectified linear unit: `a[i]` if `a[i] > 0.0`, otherwise `0.0`.
    Relu,
}
203
204/// SIMD-accelerated element-wise unary operation.
205///
206/// Returns a new Vec with `out[i] = f(a[i])`.
207/// Bit-identical to scalar for all supported operations:
208/// - `sqrt`: IEEE 754 mandates correctly-rounded sqrt.
209/// - `abs`: Bit mask operation (clear sign bit).
210/// - `neg`: Bit flip operation (toggle sign bit).
211/// - `relu`: max(0, x) via compare + blend.
212pub fn simd_unary(a: &[f64], op: UnaryOp) -> Vec<f64> {
213    let n = a.len();
214    let mut out = vec![0.0f64; n];
215
216    #[cfg(target_arch = "x86_64")]
217    {
218        if has_avx2() {
219            unsafe {
220                match op {
221                    UnaryOp::Sqrt => avx2_sqrt(a, &mut out),
222                    UnaryOp::Abs  => avx2_abs(a, &mut out),
223                    UnaryOp::Neg  => avx2_neg(a, &mut out),
224                    UnaryOp::Relu => avx2_relu(a, &mut out),
225                }
226            }
227            return out;
228        }
229    }
230
231    // Scalar fallback
232    match op {
233        UnaryOp::Sqrt => { for i in 0..n { out[i] = a[i].sqrt(); } }
234        UnaryOp::Abs  => { for i in 0..n { out[i] = a[i].abs(); } }
235        UnaryOp::Neg  => { for i in 0..n { out[i] = -a[i]; } }
236        UnaryOp::Relu => { for i in 0..n { out[i] = if a[i] > 0.0 { a[i] } else { 0.0 }; } }
237    }
238    out
239}
240
/// AVX2 kernel writing `out[i] = a[i].sqrt()` for `i` in `0..a.len()`.
///
/// Groups of four elements use the packed square-root instruction; the
/// remaining 0-3 elements use the scalar `sqrt`.
///
/// # Safety
///
/// The CPU must support AVX2 and `out` must hold at least `a.len()` elements,
/// because the vector loop stores through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sqrt(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;

    let len = a.len();
    // Extent of the 4-wide loop: the largest multiple of four not above `len`.
    let vec_end = len - len % 4;
    let src = a.as_ptr();
    let dst = out.as_mut_ptr();

    for base in (0..vec_end).step_by(4) {
        let lanes = _mm256_loadu_pd(src.add(base));
        _mm256_storeu_pd(dst.add(base), _mm256_sqrt_pd(lanes));
    }

    // Scalar tail (0-3 elements)
    for k in vec_end..len {
        out[k] = a[k].sqrt();
    }
}
255
/// AVX2 kernel writing `out[i] = |a[i]|` for `i` in `0..a.len()`.
///
/// The absolute value is taken by clearing the sign bit of each lane; the
/// remaining 0-3 elements use the scalar `abs`.
///
/// # Safety
///
/// The CPU must support AVX2 and `out` must hold at least `a.len()` elements,
/// because the vector loop stores through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_abs(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;

    let len = a.len();
    let vec_end = len - len % 4;
    let src = a.as_ptr();
    let dst = out.as_mut_ptr();
    // -0.0 has only the sign bit (0x8000_0000_0000_0000) set. AND-NOT with it
    // keeps every other bit, i.e. an AND with 0x7FFF_FFFF_FFFF_FFFF.
    let sign_only = _mm256_set1_pd(-0.0);

    for base in (0..vec_end).step_by(4) {
        let lanes = _mm256_loadu_pd(src.add(base));
        _mm256_storeu_pd(dst.add(base), _mm256_andnot_pd(sign_only, lanes));
    }

    // Scalar tail (0-3 elements)
    for k in vec_end..len {
        out[k] = a[k].abs();
    }
}
272
/// AVX2 kernel writing `out[i] = -a[i]` for `i` in `0..a.len()`.
///
/// Negation is done by flipping the sign bit of each lane; the remaining
/// 0-3 elements use scalar negation.
///
/// # Safety
///
/// The CPU must support AVX2 and `out` must hold at least `a.len()` elements,
/// because the vector loop stores through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_neg(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;

    let len = a.len();
    let vec_end = len - len % 4;
    let src = a.as_ptr();
    let dst = out.as_mut_ptr();
    // -0.0 is exactly the bit pattern 0x8000_0000_0000_0000; XOR toggles the sign.
    let sign_only = _mm256_set1_pd(-0.0);

    for base in (0..vec_end).step_by(4) {
        let lanes = _mm256_loadu_pd(src.add(base));
        _mm256_storeu_pd(dst.add(base), _mm256_xor_pd(lanes, sign_only));
    }

    // Scalar tail (0-3 elements)
    for k in vec_end..len {
        out[k] = -a[k];
    }
}
289
/// AVX2 kernel writing `out[i] = relu(a[i])` for `i` in `0..a.len()`.
///
/// Groups of four elements are computed as a packed maximum of the input and
/// zero; the remaining 0-3 elements use `if x > 0.0 { x } else { 0.0 }`.
///
/// # Safety
///
/// The CPU must support AVX2 and `out` must hold at least `a.len()` elements,
/// because the vector loop stores through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_relu(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;

    let len = a.len();
    let vec_end = len - len % 4;
    let src = a.as_ptr();
    let dst = out.as_mut_ptr();
    let floor = _mm256_setzero_pd();

    for base in (0..vec_end).step_by(4) {
        let lanes = _mm256_loadu_pd(src.add(base));
        // Operand order matters: the input goes first, zero second.
        _mm256_storeu_pd(dst.add(base), _mm256_max_pd(lanes, floor));
    }

    // Scalar tail (0-3 elements)
    for k in vec_end..len {
        let x = a[k];
        out[k] = if x > 0.0 { x } else { 0.0 };
    }
}
305
306// ── Tiled matmul AXPY kernel ────────────────────────────────────────────────
307
308/// SIMD-accelerated AXPY: `c[0..len] += scalar * b[0..len]`.
309///
310/// Used in the inner loop of tiled matrix multiplication where `scalar = A[i,p]`
311/// and `b` is a row segment of B. Processes 4 elements per iteration with AVX2.
312///
313/// Deterministic because each `c[j]` accumulates the same `scalar * b[j]`
314/// contribution using separate mul + add (no FMA), matching scalar behavior.
315pub fn simd_axpy(c: &mut [f64], b: &[f64], scalar: f64, len: usize) {
316    debug_assert!(c.len() >= len);
317    debug_assert!(b.len() >= len);
318
319    #[cfg(target_arch = "x86_64")]
320    {
321        if has_avx2() {
322            unsafe { avx2_axpy(c, b, scalar, len); }
323            return;
324        }
325    }
326
327    // Scalar fallback
328    for j in 0..len {
329        c[j] += scalar * b[j];
330    }
331}
332
/// AVX2 kernel for `c[j] += scalar * b[j]`, `j` in `0..len`.
///
/// Groups of four elements are updated with a packed multiply followed by a
/// packed add (deliberately not a fused multiply-add, so rounding matches the
/// scalar expression). The remaining 0-3 elements are updated one at a time.
///
/// # Safety
///
/// The CPU must support AVX2, and both `c` and `b` must hold at least `len`
/// elements: every access, including the scalar tail, is unchecked.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_axpy(c: &mut [f64], b: &[f64], scalar: f64, len: usize) {
    use std::arch::x86_64::*;

    let factor = _mm256_set1_pd(scalar);
    // Extent of the 4-wide loop: the largest multiple of four not above `len`.
    let vec_end = len - len % 4;
    let acc = c.as_mut_ptr();
    let src = b.as_ptr();

    for base in (0..vec_end).step_by(4) {
        let current = _mm256_loadu_pd(acc.add(base));
        let scaled = _mm256_mul_pd(factor, _mm256_loadu_pd(src.add(base)));
        // Separate mul + add (NOT FMA) keeps bit-identity with the scalar path.
        _mm256_storeu_pd(acc.add(base), _mm256_add_pd(current, scaled));
    }

    // Scalar tail
    for k in vec_end..len {
        *acc.add(k) += scalar * *src.add(k);
    }
}
358
359// ── Tests ───────────────────────────────────────────────────────────────────
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two slices are equal bit-for-bit.
    ///
    /// `assert_eq!` on `f64` uses IEEE equality, under which `0.0 == -0.0`
    /// and `NaN != NaN`, so it cannot check the bit-identity these tests
    /// claim. Comparing the raw bit patterns can.
    fn assert_bit_identical(actual: &[f64], expected: &[f64], context: &str) {
        assert_eq!(actual.len(), expected.len(), "{}: length mismatch", context);
        for (index, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert_eq!(
                got.to_bits(),
                want.to_bits(),
                "{}: element {} differs ({:?} vs {:?})",
                context,
                index,
                got,
                want
            );
        }
    }

    // 17 elements = four full AVX2 vectors plus a one-element scalar tail, so
    // each comparison below covers both the vector loop and the tail.

    #[test]
    fn test_simd_add_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.3).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.7).collect();
        let result = simd_binop(&a, &b, BinOp::Add);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect();
        assert_bit_identical(&result, &expected, "SIMD add must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_sub_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 1.1).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.9).collect();
        let result = simd_binop(&a, &b, BinOp::Sub);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect();
        assert_bit_identical(&result, &expected, "SIMD sub must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_mul_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.1 + 0.01).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.2 + 0.03).collect();
        let result = simd_binop(&a, &b, BinOp::Mul);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect();
        assert_bit_identical(&result, &expected, "SIMD mul must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_div_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.5 + 1.0).collect();
        let b: Vec<f64> = (0..17).map(|i| (i + 1) as f64 * 0.3).collect();
        let result = simd_binop(&a, &b, BinOp::Div);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x / y).collect();
        assert_bit_identical(&result, &expected, "SIMD div must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_sqrt_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 2.5 + 0.1).collect();
        let result = simd_unary(&a, UnaryOp::Sqrt);
        let expected: Vec<f64> = a.iter().map(|&x| x.sqrt()).collect();
        assert_bit_identical(&result, &expected, "SIMD sqrt must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_abs_matches_scalar() {
        let a: Vec<f64> = (-8..9).map(|i| i as f64 * 1.5).collect();
        let result = simd_unary(&a, UnaryOp::Abs);
        let expected: Vec<f64> = a.iter().map(|&x| x.abs()).collect();
        assert_bit_identical(&result, &expected, "SIMD abs must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_neg_matches_scalar() {
        let a: Vec<f64> = (-8..9).map(|i| i as f64 * 1.5).collect();
        let result = simd_unary(&a, UnaryOp::Neg);
        let expected: Vec<f64> = a.iter().map(|&x| -x).collect();
        assert_bit_identical(&result, &expected, "SIMD neg must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_relu_matches_scalar() {
        let a: Vec<f64> = (-8..9).map(|i| i as f64 * 1.5).collect();
        let result = simd_unary(&a, UnaryOp::Relu);
        let expected: Vec<f64> = a.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect();
        assert_bit_identical(&result, &expected, "SIMD relu must be bit-identical to scalar");
    }

    /// Signed zeros and NaN are where a value-based comparison would hide a
    /// SIMD/scalar divergence, so check them explicitly for the bit-level ops.
    #[test]
    fn test_simd_unary_signed_zero_and_nan() {
        let a = [-0.0, 0.0, f64::NAN, -1.5, 2.5, -0.0, f64::NAN, 0.0, -3.0];

        let abs: Vec<f64> = a.iter().map(|&x| x.abs()).collect();
        assert_bit_identical(&simd_unary(&a, UnaryOp::Abs), &abs, "abs on signed zero / NaN");

        let neg: Vec<f64> = a.iter().map(|&x| -x).collect();
        assert_bit_identical(&simd_unary(&a, UnaryOp::Neg), &neg, "neg on signed zero / NaN");

        let relu: Vec<f64> = a.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect();
        assert_bit_identical(&simd_unary(&a, UnaryOp::Relu), &relu, "relu on signed zero / NaN");
    }

    #[test]
    fn test_simd_axpy_matches_scalar() {
        let b: Vec<f64> = (0..17).map(|i| i as f64 * 0.3).collect();
        let scalar = 2.5;
        let mut c_simd: Vec<f64> = (0..17).map(|i| i as f64 * 0.1).collect();
        let mut c_scalar = c_simd.clone();

        simd_axpy(&mut c_simd, &b, scalar, 17);
        for (acc, &x) in c_scalar.iter_mut().zip(&b) {
            *acc += scalar * x;
        }
        assert_bit_identical(&c_simd, &c_scalar, "SIMD axpy must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_empty_input() {
        let empty: Vec<f64> = vec![];
        assert_eq!(simd_binop(&empty, &empty, BinOp::Add), Vec::<f64>::new());
        assert_eq!(simd_unary(&empty, UnaryOp::Sqrt), Vec::<f64>::new());
    }

    #[test]
    fn test_simd_single_element() {
        // A single element never enters the 4-wide loop: tail path only.
        let a = vec![3.0];
        let b = vec![4.0];
        assert_eq!(simd_binop(&a, &b, BinOp::Add), vec![7.0]);
        assert_eq!(simd_unary(&a, UnaryOp::Sqrt), vec![3.0f64.sqrt()]);
    }

    #[test]
    fn test_simd_exactly_four_elements() {
        // Exactly one full vector and an empty tail.
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        assert_eq!(simd_binop(&a, &b, BinOp::Add), vec![6.0, 8.0, 10.0, 12.0]);
        assert_eq!(simd_binop(&a, &b, BinOp::Mul), vec![5.0, 12.0, 21.0, 32.0]);
    }

    #[test]
    fn test_avx2_detection() {
        // Just verify the function doesn't panic.
        let _has = has_avx2();
    }
}