trueno/blis/elementwise.rs

//! SIMD-accelerated element-wise operations.
//!
//! AVX2 implementations of ReLU, vector add, and scalar multiply.
//! These are bandwidth-bound at large sizes; SIMD helps at small-to-medium
//! sizes by reducing instruction count and enabling wider stores.
//!
//! # Algorithm
//!
//! - ReLU: `_mm256_max_ps(x, zero)` — single instruction per 8 elements
//! - Add: `_mm256_add_ps(a, b)` — single instruction per 8 elements
//! - Mul scalar: `_mm256_mul_ps(x, scalar_vec)` — single instruction per 8 elements
//!
//! Contract: provable-contracts/contracts/activation-kernel-v1.yaml

use crate::error::TruenoError;

// ============================================================================
// ReLU
// ============================================================================

/// ReLU: output_i = max(0, input_i)
///
/// Uses AVX2 `_mm256_max_ps` when available.
///
/// # Errors
///
/// Returns `Err` if input and output lengths don't match.
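///
/// # Examples
///
/// A minimal usage sketch; the `trueno::blis::elementwise` import path is an
/// assumption for illustration and may not match this crate's actual module
/// layout.
///
/// ```ignore
/// // Import path assumed for illustration.
/// use trueno::blis::elementwise::relu;
///
/// let input = [-1.0f32, 0.5, -2.0, 3.0];
/// let mut output = [0.0f32; 4];
/// relu(&input, &mut output).unwrap();
/// assert_eq!(output, [0.0, 0.5, 0.0, 3.0]);
/// ```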
pub fn relu(input: &[f32], output: &mut [f32]) -> Result<(), TruenoError> {
    contract_pre_relu!(input);
    let n = input.len();
    if n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "relu size mismatch: input[{}], output[{}]",
            n,
            output.len()
        )));
    }

    #[cfg(target_arch = "x86_64")]
    {
        // For bandwidth-bound sizes (>4K elements), let LLVM auto-vectorize:
        // -O3 with target-cpu=native produces SIMD code that matches or beats
        // hand-written intrinsics for simple ops, with better register
        // allocation and loop fusion. For these bandwidth-bound elementwise
        // ops, AVX2 at full clock also beats AVX-512 at throttled clock
        // (Zen 4: ~30% frequency reduction).
        if n > 4096 {
            relu_autovec(input, output);
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
        if is_x86_feature_detected!("avx512f") {
            unsafe {
                relu_avx512(input, output);
            }
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
        if is_x86_feature_detected!("avx2") {
            unsafe {
                relu_avx2(input, output);
            }
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
    }

    relu_autovec(input, output);
    contract_post_elementwise_parity!(output);
    Ok(())
}

/// ReLU via simple loop — LLVM auto-vectorizes this to optimal SIMD.
/// For bandwidth-bound workloads (>4K elements), LLVM's autovectorizer
/// with -O3 -C target-cpu=native produces code that matches hand-written
/// intrinsics, with better register scheduling and no calling overhead.
#[inline]
fn relu_autovec(input: &[f32], output: &mut [f32]) {
    for i in 0..input.len() {
        output[i] = input[i].max(0.0);
    }
}

/// AVX-512 ReLU with NT stores for large arrays.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn relu_avx512(input: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;
    unsafe {
        let n = input.len();
        let ip = input.as_ptr();
        let op = output.as_mut_ptr();
        let zero = _mm512_setzero_ps();
        let mut i = 0;

        let data_bytes = n * 4;
        let op_aligned = (op as usize) % 64 == 0;
        if data_bytes > NT_STORE_THRESHOLD_BYTES && op_aligned {
            // NT path: 4-way unrolled, requires 64-byte aligned output
            while i + 64 <= n {
                // Guard prefetch to avoid reading past the input allocation
                // (mirrors the #242 SIGSEGV fix in add_avx512).
                if i + 128 <= n {
                    _mm_prefetch(ip.add(i + 128).cast::<i8>(), _MM_HINT_T0);
                }

                _mm512_stream_ps(op.add(i), _mm512_max_ps(_mm512_loadu_ps(ip.add(i)), zero));
                _mm512_stream_ps(
                    op.add(i + 16),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 16)), zero),
                );
                _mm512_stream_ps(
                    op.add(i + 32),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 32)), zero),
                );
                _mm512_stream_ps(
                    op.add(i + 48),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 48)), zero),
                );
                i += 64;
            }
            while i + 16 <= n {
                _mm512_stream_ps(op.add(i), _mm512_max_ps(_mm512_loadu_ps(ip.add(i)), zero));
                i += 16;
            }
            _mm_sfence();
        } else {
            while i + 64 <= n {
                _mm512_storeu_ps(op.add(i), _mm512_max_ps(_mm512_loadu_ps(ip.add(i)), zero));
                _mm512_storeu_ps(
                    op.add(i + 16),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 16)), zero),
                );
                _mm512_storeu_ps(
                    op.add(i + 32),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 32)), zero),
                );
                _mm512_storeu_ps(
                    op.add(i + 48),
                    _mm512_max_ps(_mm512_loadu_ps(ip.add(i + 48)), zero),
                );
                i += 64;
            }
            while i + 16 <= n {
                _mm512_storeu_ps(op.add(i), _mm512_max_ps(_mm512_loadu_ps(ip.add(i)), zero));
                i += 16;
            }
        }
        for j in i..n {
            output[j] = input[j].max(0.0);
        }
    } // unsafe
}

/// Prefetch distance in bytes. 8 cache lines (512 bytes = 128 f32) ahead.
/// Tuned for Zen 4 L1→L2 latency (~4ns) and L2→L3 latency (~12ns).
/// At ~1 iteration/ns throughput, 512B ahead hides ~12ns L2 latency.
const PREFETCH_DISTANCE: usize = 512;

/// NT store threshold (bytes). Use non-temporal stores when total working set
/// (2 inputs + 1 output = 3 arrays) exceeds L2 cache per core.
/// Zen 4 L2 = 1MB/core. For add: 3 × data_bytes. NT is beneficial when
/// data_bytes > ~333KB. Use 512KB for safety margin + alignment effects.
/// Below this, data fits in L2 and cached stores are faster.
const NT_STORE_THRESHOLD_BYTES: usize = 512 * 1024; // 512KB output = 128K f32

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn relu_avx2(input: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = input.len();
    let data_bytes = n * 4;

    // For large arrays (above the NT store threshold), use non-temporal stores
    // ONLY if output is 32-byte aligned (required by _mm256_stream_ps).
    // NT stores bypass cache write-allocate, eliminating RFO traffic.
    let out_aligned = (output.as_ptr() as usize) % 32 == 0;
    if data_bytes > NT_STORE_THRESHOLD_BYTES && out_aligned {
        unsafe { relu_avx2_nt(input, output) }
        return;
    }

    // 8× unrolled (64 elements per iteration) — no software prefetch.
    // Hardware prefetcher on Zen 4/Intel 12th gen+ detects sequential
    // streaming patterns and prefetches 2-4 cache lines ahead automatically.
    // Software prefetch adds ~1 µop/32 elements of overhead without benefit
    // for sequential access, and can interfere with HW prefetcher at L3 sizes.
    let chunks = n / 64;
    let remainder_64 = chunks * 64;

    unsafe {
        let zero = _mm256_setzero_ps();
        let inp = input.as_ptr();
        let out = output.as_mut_ptr();

        for i in 0..chunks {
            let base = i * 64;
            let v0 = _mm256_loadu_ps(inp.add(base));
            let v1 = _mm256_loadu_ps(inp.add(base + 8));
            let v2 = _mm256_loadu_ps(inp.add(base + 16));
            let v3 = _mm256_loadu_ps(inp.add(base + 24));
            let v4 = _mm256_loadu_ps(inp.add(base + 32));
            let v5 = _mm256_loadu_ps(inp.add(base + 40));
            let v6 = _mm256_loadu_ps(inp.add(base + 48));
            let v7 = _mm256_loadu_ps(inp.add(base + 56));
            _mm256_storeu_ps(out.add(base), _mm256_max_ps(v0, zero));
            _mm256_storeu_ps(out.add(base + 8), _mm256_max_ps(v1, zero));
            _mm256_storeu_ps(out.add(base + 16), _mm256_max_ps(v2, zero));
            _mm256_storeu_ps(out.add(base + 24), _mm256_max_ps(v3, zero));
            _mm256_storeu_ps(out.add(base + 32), _mm256_max_ps(v4, zero));
            _mm256_storeu_ps(out.add(base + 40), _mm256_max_ps(v5, zero));
            _mm256_storeu_ps(out.add(base + 48), _mm256_max_ps(v6, zero));
            _mm256_storeu_ps(out.add(base + 56), _mm256_max_ps(v7, zero));
        }

        let mut i = remainder_64;
        while i + 8 <= n {
            let v = _mm256_loadu_ps(inp.add(i));
            _mm256_storeu_ps(out.add(i), _mm256_max_ps(v, zero));
            i += 8;
        }

        while i < n {
            *out.add(i) = (*inp.add(i)).max(0.0);
            i += 1;
        }
    }
}

/// Non-temporal store variant for large arrays (>L2 cache size).
/// Combines software prefetch pipeline with streaming stores to maximize
/// DRAM bandwidth utilization. Write-combining buffers batch stores to
/// full cache lines, eliminating read-for-ownership transactions.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn relu_avx2_nt(input: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = input.len();
    let chunks = n / 32;
    let remainder_32 = chunks * 32;

    unsafe {
        let zero = _mm256_setzero_ps();

        for i in 0..chunks {
            let base = i * 32;
            // Prefetch input data ahead (L2→L3 latency hiding), guarded so the
            // prefetch pointer never points past the allocation (cf. #242).
            if base + PREFETCH_DISTANCE / 4 <= n {
                _mm_prefetch(
                    input.as_ptr().add(base + PREFETCH_DISTANCE / 4) as *const i8,
                    _MM_HINT_T0,
                );
            }
            let v0 = _mm256_loadu_ps(input.as_ptr().add(base));
            let v1 = _mm256_loadu_ps(input.as_ptr().add(base + 8));
            let v2 = _mm256_loadu_ps(input.as_ptr().add(base + 16));
            let v3 = _mm256_loadu_ps(input.as_ptr().add(base + 24));
            // Non-temporal stores: bypass cache, write to WC buffers
            _mm256_stream_ps(output.as_mut_ptr().add(base), _mm256_max_ps(v0, zero));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 8), _mm256_max_ps(v1, zero));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 16), _mm256_max_ps(v2, zero));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 24), _mm256_max_ps(v3, zero));
        }

        // Fence: ensure all NT stores are globally visible before return
        _mm_sfence();

        // Remainder with regular stores (< 1 cache line, no NT benefit)
        let mut i = remainder_32;
        while i + 8 <= n {
            let v = _mm256_loadu_ps(input.as_ptr().add(i));
            _mm256_storeu_ps(output.as_mut_ptr().add(i), _mm256_max_ps(v, zero));
            i += 8;
        }
        while i < n {
            output[i] = input[i].max(0.0);
            i += 1;
        }
    }
}

// ============================================================================
// Vector Add
// ============================================================================

/// Element-wise add: output_i = a_i + b_i
///
/// Uses AVX2 `_mm256_add_ps` when available.
///
/// # Errors
///
/// Returns `Err` if a, b, and output lengths don't match.
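///
/// # Examples
///
/// A minimal usage sketch (import path assumed for illustration):
///
/// ```ignore
/// use trueno::blis::elementwise::add;
///
/// let a = [1.0f32, 2.0, 3.0];
/// let b = [0.5f32, 0.5, 0.5];
/// let mut out = [0.0f32; 3];
/// add(&a, &b, &mut out).unwrap();
/// assert_eq!(out, [1.5, 2.5, 3.5]);
/// ```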
pub fn add(a: &[f32], b: &[f32], output: &mut [f32]) -> Result<(), TruenoError> {
    let n = a.len();
    if n != b.len() || n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "add size mismatch: a[{}], b[{}], output[{}]",
            n,
            b.len(),
            output.len()
        )));
    }
    contract_pre_add!(a, b);

    #[cfg(target_arch = "x86_64")]
    {
        // For bandwidth-bound sizes (>4K), let LLVM auto-vectorize.
        // LLVM -O3 with target-cpu=native matches hand-written intrinsics
        // without #[target_feature] calling convention overhead.
        if n > 4096 {
            add_autovec(a, b, output);
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
        if is_x86_feature_detected!("avx512f") {
            unsafe {
                add_avx512(a, b, output);
            }
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
        if is_x86_feature_detected!("avx2") {
            unsafe {
                add_avx2(a, b, output);
            }
            contract_post_elementwise_parity!(output);
            return Ok(());
        }
    }

    add_autovec(a, b, output);
    contract_post_elementwise_parity!(output);
    Ok(())
}

/// Add via simple loop — LLVM auto-vectorizes optimally.
#[inline]
fn add_autovec(a: &[f32], b: &[f32], output: &mut [f32]) {
    for i in 0..a.len() {
        output[i] = a[i] + b[i];
    }
}

/// AVX-512 add with NT stores for large arrays.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn add_avx512(a: &[f32], b: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;
    unsafe {
        let n = a.len();
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let rp = output.as_mut_ptr();
        let mut i = 0;

        let data_bytes = n * 4;
        let rp_aligned = (rp as usize) % 64 == 0;
        if data_bytes > NT_STORE_THRESHOLD_BYTES && rp_aligned {
            // NT path: 4-way unrolled, requires 64-byte aligned output
            while i + 64 <= n {
                // Guard prefetch to avoid reading past allocation (#242 SIGSEGV fix)
                if i + 128 <= n {
                    _mm_prefetch(ap.add(i + 128).cast::<i8>(), _MM_HINT_T0);
                    _mm_prefetch(bp.add(i + 128).cast::<i8>(), _MM_HINT_T0);
                }

                _mm512_stream_ps(
                    rp.add(i),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i)), _mm512_loadu_ps(bp.add(i))),
                );
                _mm512_stream_ps(
                    rp.add(i + 16),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 16)), _mm512_loadu_ps(bp.add(i + 16))),
                );
                _mm512_stream_ps(
                    rp.add(i + 32),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 32)), _mm512_loadu_ps(bp.add(i + 32))),
                );
                _mm512_stream_ps(
                    rp.add(i + 48),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 48)), _mm512_loadu_ps(bp.add(i + 48))),
                );
                i += 64;
            }
            while i + 16 <= n {
                _mm512_stream_ps(
                    rp.add(i),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i)), _mm512_loadu_ps(bp.add(i))),
                );
                i += 16;
            }
            _mm_sfence();
        } else {
            while i + 64 <= n {
                _mm512_storeu_ps(
                    rp.add(i),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i)), _mm512_loadu_ps(bp.add(i))),
                );
                _mm512_storeu_ps(
                    rp.add(i + 16),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 16)), _mm512_loadu_ps(bp.add(i + 16))),
                );
                _mm512_storeu_ps(
                    rp.add(i + 32),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 32)), _mm512_loadu_ps(bp.add(i + 32))),
                );
                _mm512_storeu_ps(
                    rp.add(i + 48),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i + 48)), _mm512_loadu_ps(bp.add(i + 48))),
                );
                i += 64;
            }
            while i + 16 <= n {
                _mm512_storeu_ps(
                    rp.add(i),
                    _mm512_add_ps(_mm512_loadu_ps(ap.add(i)), _mm512_loadu_ps(bp.add(i))),
                );
                i += 16;
            }
        }
        for j in i..n {
            output[j] = a[j] + b[j];
        }
    } // unsafe
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn add_avx2(a: &[f32], b: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = a.len();
    let data_bytes = n * 4;

    // Large arrays: NT stores (bypass cache for DRAM-bound writes)
    // Only if output is 32-byte aligned (required by _mm256_stream_ps).
    let out_aligned = (output.as_ptr() as usize) % 32 == 0;
    if data_bytes > NT_STORE_THRESHOLD_BYTES && out_aligned {
        unsafe { add_avx2_nt(a, b, output) }
        return;
    }

    // 8× unrolled (64 elements per iteration) — no software prefetch.
    // Hardware prefetcher handles sequential dual-stream patterns efficiently.
    let chunks = n / 64;
    let remainder_64 = chunks * 64;

    unsafe {
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let op = output.as_mut_ptr();

        for i in 0..chunks {
            let base = i * 64;
            // Interleaved loads from a and b for maximum load port utilization
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            let a1 = _mm256_loadu_ps(ap.add(base + 8));
            let b1 = _mm256_loadu_ps(bp.add(base + 8));
            let a2 = _mm256_loadu_ps(ap.add(base + 16));
            let b2 = _mm256_loadu_ps(bp.add(base + 16));
            let a3 = _mm256_loadu_ps(ap.add(base + 24));
            let b3 = _mm256_loadu_ps(bp.add(base + 24));
            let a4 = _mm256_loadu_ps(ap.add(base + 32));
            let b4 = _mm256_loadu_ps(bp.add(base + 32));
            let a5 = _mm256_loadu_ps(ap.add(base + 40));
            let b5 = _mm256_loadu_ps(bp.add(base + 40));
            let a6 = _mm256_loadu_ps(ap.add(base + 48));
            let b6 = _mm256_loadu_ps(bp.add(base + 48));
            let a7 = _mm256_loadu_ps(ap.add(base + 56));
            let b7 = _mm256_loadu_ps(bp.add(base + 56));
            _mm256_storeu_ps(op.add(base), _mm256_add_ps(a0, b0));
            _mm256_storeu_ps(op.add(base + 8), _mm256_add_ps(a1, b1));
            _mm256_storeu_ps(op.add(base + 16), _mm256_add_ps(a2, b2));
            _mm256_storeu_ps(op.add(base + 24), _mm256_add_ps(a3, b3));
            _mm256_storeu_ps(op.add(base + 32), _mm256_add_ps(a4, b4));
            _mm256_storeu_ps(op.add(base + 40), _mm256_add_ps(a5, b5));
            _mm256_storeu_ps(op.add(base + 48), _mm256_add_ps(a6, b6));
            _mm256_storeu_ps(op.add(base + 56), _mm256_add_ps(a7, b7));
        }

        let mut i = remainder_64;
        while i + 8 <= n {
            let av = _mm256_loadu_ps(ap.add(i));
            let bv = _mm256_loadu_ps(bp.add(i));
            _mm256_storeu_ps(op.add(i), _mm256_add_ps(av, bv));
            i += 8;
        }

        while i < n {
            *op.add(i) = *ap.add(i) + *bp.add(i);
            i += 1;
        }
    }
}

/// Non-temporal store variant of add for large arrays (>L2 cache).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn add_avx2_nt(a: &[f32], b: &[f32], output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = a.len();
    let chunks = n / 32;
    let remainder_32 = chunks * 32;

    unsafe {
        for i in 0..chunks {
            let base = i * 32;
            // Guard prefetch so the pointer never passes the end of either
            // input allocation (cf. #242 SIGSEGV fix).
            if base + PREFETCH_DISTANCE / 4 <= n {
                _mm_prefetch(a.as_ptr().add(base + PREFETCH_DISTANCE / 4) as *const i8, _MM_HINT_T0);
                _mm_prefetch(b.as_ptr().add(base + PREFETCH_DISTANCE / 4) as *const i8, _MM_HINT_T0);
            }
            let a0 = _mm256_loadu_ps(a.as_ptr().add(base));
            let a1 = _mm256_loadu_ps(a.as_ptr().add(base + 8));
            let a2 = _mm256_loadu_ps(a.as_ptr().add(base + 16));
            let a3 = _mm256_loadu_ps(a.as_ptr().add(base + 24));
            let b0 = _mm256_loadu_ps(b.as_ptr().add(base));
            let b1 = _mm256_loadu_ps(b.as_ptr().add(base + 8));
            let b2 = _mm256_loadu_ps(b.as_ptr().add(base + 16));
            let b3 = _mm256_loadu_ps(b.as_ptr().add(base + 24));
            _mm256_stream_ps(output.as_mut_ptr().add(base), _mm256_add_ps(a0, b0));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 8), _mm256_add_ps(a1, b1));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 16), _mm256_add_ps(a2, b2));
            _mm256_stream_ps(output.as_mut_ptr().add(base + 24), _mm256_add_ps(a3, b3));
        }

        _mm_sfence();

        let mut i = remainder_32;
        while i + 8 <= n {
            let av = _mm256_loadu_ps(a.as_ptr().add(i));
            let bv = _mm256_loadu_ps(b.as_ptr().add(i));
            _mm256_storeu_ps(output.as_mut_ptr().add(i), _mm256_add_ps(av, bv));
            i += 8;
        }
        while i < n {
            output[i] = a[i] + b[i];
            i += 1;
        }
    }
}

// ============================================================================
// Scalar Multiply
// ============================================================================

/// Element-wise scalar multiply: output_i = input_i * scalar
///
/// Uses AVX2 `_mm256_mul_ps` when available.
///
/// # Errors
///
/// Returns `Err` if input and output lengths don't match.
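///
/// # Examples
///
/// A minimal usage sketch (import path assumed for illustration):
///
/// ```ignore
/// use trueno::blis::elementwise::mul_scalar;
///
/// let input = [1.0f32, -2.0, 4.0];
/// let mut out = [0.0f32; 3];
/// mul_scalar(&input, 0.5, &mut out).unwrap();
/// assert_eq!(out, [0.5, -1.0, 2.0]);
/// ```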
pub fn mul_scalar(input: &[f32], scalar: f32, output: &mut [f32]) -> Result<(), TruenoError> {
    // Contract: elementwise-kernel-v1.yaml, equation = mul_scalar
    debug_assert!(!input.is_empty(), "Contract mul_scalar: input is empty");
    debug_assert!(scalar.is_finite(), "Contract mul_scalar: scalar is not finite");
    let n = input.len();
    if n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "mul_scalar size mismatch: input[{}], output[{}]",
            n,
            output.len()
        )));
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe {
                mul_scalar_avx2(input, scalar, output);
            }
            return Ok(());
        }
    }

    for i in 0..n {
        output[i] = input[i] * scalar;
    }
    Ok(())
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mul_scalar_avx2(input: &[f32], scalar: f32, output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = input.len();
    let chunks = n / 32;
    let remainder_32 = chunks * 32;

    unsafe {
        let s = _mm256_set1_ps(scalar);

        for i in 0..chunks {
            let base = i * 32;
            let v0 = _mm256_loadu_ps(input.as_ptr().add(base));
            let v1 = _mm256_loadu_ps(input.as_ptr().add(base + 8));
            let v2 = _mm256_loadu_ps(input.as_ptr().add(base + 16));
            let v3 = _mm256_loadu_ps(input.as_ptr().add(base + 24));
            _mm256_storeu_ps(output.as_mut_ptr().add(base), _mm256_mul_ps(v0, s));
            _mm256_storeu_ps(output.as_mut_ptr().add(base + 8), _mm256_mul_ps(v1, s));
            _mm256_storeu_ps(output.as_mut_ptr().add(base + 16), _mm256_mul_ps(v2, s));
            _mm256_storeu_ps(output.as_mut_ptr().add(base + 24), _mm256_mul_ps(v3, s));
        }

        let mut i = remainder_32;
        while i + 8 <= n {
            let v = _mm256_loadu_ps(input.as_ptr().add(i));
            _mm256_storeu_ps(output.as_mut_ptr().add(i), _mm256_mul_ps(v, s));
            i += 8;
        }

        while i < n {
            output[i] = input[i] * scalar;
            i += 1;
        }
    }
}

// ============================================================================
// Allocating variants (convenience wrappers)
// ============================================================================

/// ReLU with output allocation.
///
/// Allocates a zero-initialized output buffer and runs [`relu`] into it;
/// lengths match by construction, so the output is always fully written.
#[must_use]
pub fn relu_alloc(input: &[f32]) -> Vec<f32> {
    let n = input.len();
    let mut output = vec![0.0f32; n];
    let _ = relu(input, &mut output);
    output
}

/// Element-wise add with output allocation.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[must_use]
pub fn add_alloc(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "add_alloc: length mismatch");
    let n = a.len();
    let mut output = vec![0.0f32; n];
    let _ = add(a, b, &mut output);
    output
}

/// Scalar multiply with output allocation.
#[must_use]
pub fn mul_scalar_alloc(input: &[f32], scalar: f32) -> Vec<f32> {
    let n = input.len();
    let mut output = vec![0.0f32; n];
    let _ = mul_scalar(input, scalar, &mut output);
    output
}

// ============================================================================
// Fused Operations (PMAT-021)
// ============================================================================
// Fused ops reduce DRAM traffic by combining multiple element-wise operations
// into a single pass. For bandwidth-bound workloads (>4K elements), this is
// the only way to beat the DRAM bandwidth ceiling that limits individual ops
// to ~1.0x vs ndarray. Reference: XLA compiler fusion (arXiv:1802.04730).

/// Fused add + ReLU: output_i = max(0, a_i + b_i)
///
/// Single pass over data: 2 reads + 1 write = 12 bytes/element.
/// Unfused equivalent (add then relu) would be 2+1+1+1 = 5 accesses
/// = 20 bytes/element, so fusion cuts bandwidth by 40%.
///
/// # Errors
///
/// Returns `Err` if a, b, and output lengths don't match.
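///
/// # Examples
///
/// A sketch of replacing a separate `add` + `relu` pair with the fused call
/// (import path assumed for illustration):
///
/// ```ignore
/// use trueno::blis::elementwise::fused_add_relu;
///
/// let a = [-1.0f32, 2.0];
/// let b = [0.5f32, 1.0];
/// let mut out = [0.0f32; 2];
/// fused_add_relu(&a, &b, &mut out).unwrap();
/// // max(0, -0.5) = 0.0, max(0, 3.0) = 3.0
/// assert_eq!(out, [0.0, 3.0]);
/// ```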
pub fn fused_add_relu(a: &[f32], b: &[f32], output: &mut [f32]) -> Result<(), TruenoError> {
    let n = a.len();
    if n != b.len() || n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "fused_add_relu size mismatch: a[{}], b[{}], output[{}]",
            n,
            b.len(),
            output.len()
        )));
    }
    // LLVM auto-vectorizes this optimally with -O3 -C target-cpu=native.
    for i in 0..n {
        output[i] = (a[i] + b[i]).max(0.0);
    }
    Ok(())
}

/// Fused multiply-add: output_i = a_i * b_i + c_i
///
/// Single pass: 3 reads + 1 write = 16 bytes/element.
/// Unfused equivalent (mul then add) = 24 bytes/element.
/// 33% bandwidth reduction. Maps directly to FMA SIMD instruction.
///
/// # Errors
///
/// Returns `Err` if a, b, c, and output lengths don't match.
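///
/// # Examples
///
/// A minimal usage sketch (import path assumed for illustration):
///
/// ```ignore
/// use trueno::blis::elementwise::fused_mul_add;
///
/// let a = [1.0f32, 2.0];
/// let b = [3.0f32, 4.0];
/// let c = [0.5f32, 0.5];
/// let mut out = [0.0f32; 2];
/// fused_mul_add(&a, &b, &c, &mut out).unwrap();
/// assert_eq!(out, [3.5, 8.5]); // 1*3 + 0.5, 2*4 + 0.5
/// ```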
pub fn fused_mul_add(
    a: &[f32],
    b: &[f32],
    c: &[f32],
    output: &mut [f32],
) -> Result<(), TruenoError> {
    let n = a.len();
    if n != b.len() || n != c.len() || n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "fused_mul_add size mismatch: a[{}], b[{}], c[{}], output[{}]",
            n,
            b.len(),
            c.len(),
            output.len()
        )));
    }
    for i in 0..n {
        output[i] = a[i].mul_add(b[i], c[i]);
    }
    Ok(())
}

/// Fused scale + bias + ReLU: output_i = max(0, input_i * scale + bias)
///
/// Common in neural network inference (linear layer + activation).
/// Single pass: 1 read + 1 write = 8 bytes/element.
/// Unfused (scale, add bias, relu) = 24 bytes/element.
/// 67% bandwidth reduction.
///
/// # Errors
///
/// Returns `Err` if input and output lengths don't match.
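///
/// # Examples
///
/// A minimal usage sketch (import path assumed for illustration):
///
/// ```ignore
/// use trueno::blis::elementwise::fused_scale_bias_relu;
///
/// let input = [-1.0f32, 0.0, 2.0];
/// let mut out = [0.0f32; 3];
/// fused_scale_bias_relu(&input, 2.0, 1.0, &mut out).unwrap();
/// // 2*x + 1, clamped at zero
/// assert_eq!(out, [0.0, 1.0, 5.0]);
/// ```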
pub fn fused_scale_bias_relu(
    input: &[f32],
    scale: f32,
    bias: f32,
    output: &mut [f32],
) -> Result<(), TruenoError> {
    let n = input.len();
    if n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "fused_scale_bias_relu size mismatch: input[{}], output[{}]",
            n,
            output.len()
        )));
    }
    for i in 0..n {
        output[i] = input[i].mul_add(scale, bias).max(0.0);
    }
    Ok(())
}

// ============================================================================
// In-Place Operations
// ============================================================================
// In-place ops eliminate the output buffer entirely, reducing memory traffic
// from 2 reads + 1 write to 1 read + 1 write (33% reduction for unary ops).

/// In-place ReLU: data_i = max(0, data_i)
///
/// 1 read + 1 write = 8 bytes/element (vs 12 for out-of-place).
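///
/// # Examples
///
/// A minimal usage sketch (import path assumed for illustration):
///
/// ```ignore
/// use trueno::blis::elementwise::relu_inplace;
///
/// let mut data = [-1.0f32, 2.0, -3.0];
/// relu_inplace(&mut data);
/// assert_eq!(data, [0.0, 2.0, 0.0]);
/// ```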
#[inline]
pub fn relu_inplace(data: &mut [f32]) {
    for x in data.iter_mut() {
        *x = x.max(0.0);
    }
}

/// In-place add: a_i += b_i
///
/// 2 reads + 1 write = 12 bytes/element (same as out-of-place but no alloc).
pub fn add_inplace(a: &mut [f32], b: &[f32]) -> Result<(), TruenoError> {
    if a.len() != b.len() {
        return Err(TruenoError::InvalidInput(format!(
            "add_inplace size mismatch: a[{}], b[{}]",
            a.len(),
            b.len()
        )));
    }
    for i in 0..a.len() {
        a[i] += b[i];
    }
    Ok(())
}

/// In-place scale: data_i *= scalar
///
/// 1 read + 1 write = 8 bytes/element.
#[inline]
pub fn scale_inplace(data: &mut [f32], scalar: f32) {
    for x in data.iter_mut() {
        *x *= scalar;
    }
}

/// In-place fused add + ReLU: a_i = max(0, a_i + b_i)
///
/// Single pass: 2 reads + 1 write = 12 bytes/element. Unfused in-place
/// (add_inplace then relu_inplace) would be (2 + 1) + (1 + 1) = 5 accesses
/// = 20 bytes/element, so fusion saves 40%.
pub fn fused_add_relu_inplace(a: &mut [f32], b: &[f32]) -> Result<(), TruenoError> {
    if a.len() != b.len() {
        return Err(TruenoError::InvalidInput(format!(
            "fused_add_relu_inplace size mismatch: a[{}], b[{}]",
            a.len(),
            b.len()
        )));
    }
    for i in 0..a.len() {
        a[i] = (a[i] + b[i]).max(0.0);
    }
    Ok(())
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // ── ReLU tests ────────────────────────────────────────────────────────

    #[test]
    fn test_relu_basic() {
        let input = [-1.0, 0.0, 1.0, -0.5, 2.0, -3.0, 0.1, -0.1];
        let expected = [0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.1, 0.0];
        let mut output = vec![0.0f32; 8];
        relu(&input, &mut output).unwrap();
        assert_eq!(output, expected);
    }

    #[test]
    fn test_relu_large() {
        let n = 11008; // FFN intermediate size
        let input: Vec<f32> =
            (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut output = vec![0.0f32; n];
        relu(&input, &mut output).unwrap();
        for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
            assert_eq!(out, inp.max(0.0), "ReLU mismatch at {i}");
        }
    }

    #[test]
    fn test_relu_avx2_scalar_parity() {
        for n in [1, 7, 8, 15, 16, 31, 32, 63, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 500.0 - 1.0).collect();
            let mut output = vec![0.0f32; n];
            relu(&input, &mut output).unwrap();
            for (i, (&inp, &out)) in input.iter().zip(output.iter()).enumerate() {
                assert_eq!(out, inp.max(0.0), "ReLU parity at [{i}] n={n}");
            }
        }
    }

    #[test]
    fn test_relu_error_mismatch() {
        let input = vec![1.0f32; 4];
        let mut output = vec![0.0f32; 3];
        assert!(relu(&input, &mut output).is_err());
    }

    // ── Add tests ─────────────────────────────────────────────────────────

    #[test]
    fn test_add_basic() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [10.0, 20.0, 30.0, 40.0];
        let mut output = vec![0.0f32; 4];
        add(&a, &b, &mut output).unwrap();
        assert_eq!(output, vec![11.0, 22.0, 33.0, 44.0]);
    }

    #[test]
    fn test_add_large() {
        let n = 4096;
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (i * 2) as f32).collect();
        let mut output = vec![0.0f32; n];
        add(&a, &b, &mut output).unwrap();
        for i in 0..n {
            assert_eq!(output[i], a[i] + b[i], "Add mismatch at {i}");
        }
    }

    #[test]
    fn test_add_avx2_scalar_parity() {
        for n in [1, 7, 8, 15, 16, 31, 32, 63, 64, 128, 4096] {
            let a: Vec<f32> = (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 500.0 - 1.0).collect();
            let b: Vec<f32> = (0..n).map(|i| ((i * 13 + 7) % 1000) as f32 / 500.0 - 1.0).collect();
            let mut output = vec![0.0f32; n];
            add(&a, &b, &mut output).unwrap();
            for i in 0..n {
                assert_eq!(output[i], a[i] + b[i], "Add parity at [{i}] n={n}");
            }
        }
    }

    #[test]
    fn test_add_error_mismatch() {
        let a = vec![1.0f32; 4];
        let b = vec![1.0f32; 3];
        let mut output = vec![0.0f32; 4];
        assert!(add(&a, &b, &mut output).is_err());
    }

    // ── Mul scalar tests ──────────────────────────────────────────────────

    #[test]
    fn test_mul_scalar_basic() {
        let input = [1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0f32; 4];
        mul_scalar(&input, 2.5, &mut output).unwrap();
        assert_eq!(output, vec![2.5, 5.0, 7.5, 10.0]);
    }

    #[test]
    fn test_mul_scalar_large() {
        let n = 4096;
        let input: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let mut output = vec![0.0f32; n];
        mul_scalar(&input, std::f32::consts::PI, &mut output).unwrap();
        for i in 0..n {
            assert!(
                (output[i] - input[i] * std::f32::consts::PI).abs() < 1e-5,
                "Mul scalar mismatch at {i}"
            );
        }
    }

    #[test]
    fn test_mul_scalar_avx2_scalar_parity() {
        for n in [1, 7, 8, 15, 16, 31, 32, 63, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 500.0 - 1.0).collect();
            let mut output = vec![0.0f32; n];
            mul_scalar(&input, std::f32::consts::E, &mut output).unwrap();
            for i in 0..n {
                assert!(
                    (output[i] - input[i] * std::f32::consts::E).abs() < 1e-4,
                    "Mul scalar parity at [{i}] n={n}",
                );
            }
        }
    }

    #[test]
    fn test_mul_scalar_error_mismatch() {
        let input = vec![1.0f32; 4];
        let mut output = vec![0.0f32; 3];
        assert!(mul_scalar(&input, 1.0, &mut output).is_err());
    }

    // ── Fused ops tests (PMAT-021) ──────────────────────────────────────

    #[test]
    fn test_fused_add_relu_basic() {
        let a = vec![-2.0, -1.0, 0.0, 1.0, 2.0, -0.5, 0.5, 3.0];
        let b = vec![1.0, 0.5, -1.0, -2.0, 0.0, 1.0, -1.0, -4.0];
        let mut out = vec![0.0f32; 8];
        fused_add_relu(&a, &b, &mut out).unwrap();
        let expected: Vec<f32> = a.iter().zip(&b).map(|(a, b)| (a + b).max(0.0)).collect();
        assert_eq!(out, expected);
    }

    #[test]
    fn test_fused_add_relu_large() {
        let n = 10_000;
        let a: Vec<f32> = (0..n).map(|i| (i as f32 - 5000.0) / 100.0).collect();
        let b: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3) - 1500.0).collect();
        let mut out = vec![0.0f32; n];
        fused_add_relu(&a, &b, &mut out).unwrap();
        for i in 0..n {
            assert_eq!(out[i], (a[i] + b[i]).max(0.0), "mismatch at {i}");
        }
    }

    #[test]
    fn test_fused_mul_add_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let c = vec![0.5, 0.5, 0.5, 0.5];
        let mut out = vec![0.0f32; 4];
        fused_mul_add(&a, &b, &c, &mut out).unwrap();
        let expected: Vec<f32> = (0..4).map(|i| a[i].mul_add(b[i], c[i])).collect();
        assert_eq!(out, expected);
    }

    #[test]
    fn test_fused_scale_bias_relu_basic() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut out = vec![0.0f32; 5];
        fused_scale_bias_relu(&input, 2.0, 1.0, &mut out).unwrap();
        // 2*x + 1 = [-3, -1, 1, 3, 5]; ReLU clamps to [0, 0, 1, 3, 5]
        assert_eq!(out, vec![0.0, 0.0, 1.0, 3.0, 5.0]);
    }
}