// yscv_kernels/ops/simd.rs

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::{
    float32x4_t, vaddq_f32, vdivq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vmaxq_f32, vminq_f32,
    vmulq_f32, vnegq_f32, vst1q_f32, vsubq_f32,
};
#[cfg(target_arch = "x86")]
use std::arch::x86::{
    __m128, __m256, _mm_add_ps, _mm_castsi128_ps, _mm_cvtepi32_ps, _mm_cvtps_epi32, _mm_loadu_ps,
    _mm_max_ps, _mm_min_ps, _mm_mul_ps, _mm_set1_epi32, _mm_set1_ps, _mm_setzero_ps, _mm_storeu_ps,
    _mm_sub_ps, _mm256_add_ps, _mm256_castsi256_ps, _mm256_cvtepi32_ps, _mm256_cvtps_epi32,
    _mm256_loadu_ps, _mm256_max_ps, _mm256_min_ps, _mm256_mul_ps, _mm256_set1_epi32,
    _mm256_set1_ps, _mm256_setzero_ps, _mm256_storeu_ps, _mm256_sub_ps,
};
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{
    __m128, __m256, _mm_add_ps, _mm_castsi128_ps, _mm_cvtepi32_ps, _mm_cvtps_epi32, _mm_loadu_ps,
    _mm_max_ps, _mm_min_ps, _mm_mul_ps, _mm_set1_epi32, _mm_set1_ps, _mm_setzero_ps, _mm_storeu_ps,
    _mm_sub_ps, _mm256_add_ps, _mm256_castsi256_ps, _mm256_cvtepi32_ps, _mm256_cvtps_epi32,
    _mm256_loadu_ps, _mm256_max_ps, _mm256_min_ps, _mm256_mul_ps, _mm256_set1_epi32,
    _mm256_set1_ps, _mm256_setzero_ps, _mm256_storeu_ps, _mm256_sub_ps,
};

use super::config::BinaryKind;

#[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn vsAdd(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsSub(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsMul(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsDiv(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsExp(n: i32, a: *const f32, y: *mut f32);
    fn vsSqrt(n: i32, a: *const f32, y: *mut f32);
    fn vsLn(n: i32, a: *const f32, y: *mut f32);
}

#[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn armpl_svexp_f32(n: i32, x: *const f32, y: *mut f32);
    fn armpl_svadd_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svsub_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svmul_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svlog_f32(n: i32, x: *const f32, y: *mut f32);
    fn armpl_svsqrt_f32(n: i32, x: *const f32, y: *mut f32);
}

#[cfg(target_os = "macos")]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn vvexpf(result: *mut f32, input: *const f32, count: *const i32);
    fn vDSP_vadd(
        __A: *const f32,
        __IA: i32,
        __B: *const f32,
        __IB: i32,
        __C: *mut f32,
        __IC: i32,
        __N: u32,
    );
    fn vDSP_vsub(
        __B: *const f32,
        __IB: i32,
        __A: *const f32,
        __IA: i32,
        __C: *mut f32,
        __IC: i32,
        __N: u32,
    );
    fn vDSP_vmul(
        __A: *const f32,
        __IA: i32,
        __B: *const f32,
        __IB: i32,
        __C: *mut f32,
        __IC: i32,
        __N: u32,
    );
}
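
// The vvexpf/vDSP symbols above resolve against Apple's Accelerate framework,
// and the MKL/ARMPL externs assume those vendor libraries are on the link line
// when their features are enabled. A minimal build.rs sketch that would
// satisfy the Accelerate linkage (an assumption for illustration; this
// crate's actual build script is not shown here):
//
//     fn main() {
//         if std::env::var("CARGO_CFG_TARGET_OS").as_deref() == Ok("macos") {
//             println!("cargo:rustc-link-lib=framework=Accelerate");
//         }
//     }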

// ===========================================================================
// ReLU dispatch
// ===========================================================================

#[allow(unsafe_code)]
#[inline]
pub fn relu_slice_dispatch(values: &mut [f32]) {
    if cfg!(miri) {
        // SAFETY: scalar path only reads/writes within `values` bounds.
        unsafe {
            relu_slice_scalar(values);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_slice_avx(values);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_slice_sse(values);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_slice_neon(values);
            }
            return;
        }
    }

    // SAFETY: scalar path only reads/writes within `values` bounds.
    unsafe {
        relu_slice_scalar(values);
    }
}

/// Two-argument ReLU: `output[i] = max(0, input[i])`.
///
/// Avoids the clone+in-place pattern by reading from `input` and writing to
/// `output` in a single pass, halving memory traffic.
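///
/// A minimal usage sketch (illustrative; assumes the function is in scope):
///
/// ```ignore
/// let input = [-1.5f32, 0.0, 2.0, -0.25];
/// let mut output = [0.0f32; 4];
/// relu_to_slice_dispatch(&input, &mut output);
/// assert_eq!(output, [0.0, 0.0, 2.0, 0.0]);
/// ```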
#[allow(unsafe_code)]
#[inline]
pub fn relu_to_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        // SAFETY: scalar path only reads/writes within bounds.
        unsafe {
            relu_to_slice_scalar(input, output);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_to_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_to_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                relu_to_slice_neon(input, output);
            }
            return;
        }
    }

    // SAFETY: scalar path only reads/writes within bounds.
    unsafe {
        relu_to_slice_scalar(input, output);
    }
}

#[inline]
#[allow(dead_code)]
pub(crate) fn sigmoid_slice(values: &mut [f32]) {
    for value in values {
        *value = sigmoid_scalar(*value);
    }
}

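// `sigmoid_scalar` branches on the sign of `value` so that `exp` always sees
// a non-positive argument: algebraically 1/(1 + e^-x) = e^x / (1 + e^x), and
// evaluating whichever form keeps the exponent <= 0 avoids f32 overflow for
// large |x|.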
#[inline]
pub(crate) fn sigmoid_scalar(value: f32) -> f32 {
    if value >= 0.0 {
        let z = (-value).exp();
        1.0 / (1.0 + z)
    } else {
        let z = value.exp();
        z / (1.0 + z)
    }
}

// ===========================================================================
// Exp / Sigmoid / Tanh SIMD dispatch
// ===========================================================================

/// Fast exp applied element-wise: `output[i] = exp(input[i])`.
///
/// Vendor-library paths (Accelerate, MKL, ARMPL) call the platform's
/// vectorized exp; the SIMD fallbacks use range reduction plus a degree-6
/// polynomial (see the `fast_exp_*` helpers below), accurate to roughly 1e-6
/// relative error over the clamped range [-88, 88].
#[allow(unsafe_code, unreachable_code)]
#[inline]
pub fn exp_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        exp_slice_scalar(input, output);
        return;
    }

    // macOS aarch64: use Apple Accelerate vvexpf (heavily optimized).
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
        let count = input.len() as i32;
        // SAFETY: vvexpf reads `count` floats from `input` and writes to `output`.
        // Both slices have equal length (debug_assert above).
        unsafe {
            vvexpf(output.as_mut_ptr(), input.as_ptr(), &count);
        }
        return;
    }

    // x86/x86_64 with MKL: use Intel VML vsExp (heavily optimized).
    #[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
    {
        let count = input.len() as i32;
        // SAFETY: vsExp reads `count` floats from `input` and writes to `output`.
        unsafe { vsExp(count, input.as_ptr(), output.as_mut_ptr()) };
        return;
    }

    // aarch64 Linux with ARMPL: use ARM Performance Libraries vectorized exp.
    #[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
    {
        let count = input.len() as i32;
        // SAFETY: armpl_svexp_f32 reads `count` floats from `input` and writes to `output`.
        unsafe { armpl_svexp_f32(count, input.as_ptr(), output.as_mut_ptr()) };
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                exp_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                exp_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                exp_slice_neon(input, output);
            }
            return;
        }
    }

    exp_slice_scalar(input, output);
}

/// Fused subtract-and-exp: `output[i] = exp(input[i] - offset)`.
///
/// Combines the max-subtraction and exp steps of softmax into one pass,
/// avoiding an extra read/write of the output buffer.
#[allow(unsafe_code)]
#[inline]
pub fn sub_exp_slice_dispatch(input: &[f32], offset: f32, output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        sub_exp_slice_scalar(input, offset, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sub_exp_slice_avx(input, offset, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sub_exp_slice_sse(input, offset, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sub_exp_slice_neon(input, offset, output);
            }
            return;
        }
    }

    sub_exp_slice_scalar(input, offset, output);
}
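
// A minimal softmax sketch built from these dispatchers, showing where the
// fused subtract-and-exp step fits (illustrative only; the crate's actual
// softmax kernel is defined elsewhere):
//
//     fn softmax(input: &[f32], output: &mut [f32]) {
//         let max = max_reduce_dispatch(input);
//         sub_exp_slice_dispatch(input, max, output); // out[i] = exp(in[i] - max)
//         let sum = add_reduce_dispatch(output);
//         mul_scalar_inplace_dispatch(output, 1.0 / sum);
//     }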

/// Fast sigmoid applied element-wise: `output[i] = 1 / (1 + exp(-input[i]))`.
#[allow(unsafe_code, clippy::needless_return)]
#[inline]
pub fn sigmoid_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        sigmoid_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sigmoid_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sigmoid_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                sigmoid_slice_neon(input, output);
            }
            return;
        }
    }

    sigmoid_slice_dispatch_scalar(input, output);
}

/// Fast exp for sigmoid: range reduction + degree-3 Horner polynomial + IEEE bit trick.
/// WHY degree 3: a cubic suffices for sigmoid (the 1/(1+exp) form dampens the
/// error); max relative error is on the order of 1e-4.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn fast_exp_sigmoid_neon(x: float32x4_t) -> float32x4_t {
    use std::arch::aarch64::{
        vaddq_s32, vcvtnq_s32_f32, vcvtq_f32_s32, vdupq_n_s32, vreinterpretq_f32_s32, vshlq_n_s32,
        vsubq_f32,
    };
    let x = vmaxq_f32(vdupq_n_f32(-88.0), vminq_f32(vdupq_n_f32(88.0), x));
    let n_f = vmulq_f32(x, vdupq_n_f32(std::f32::consts::LOG2_E));
    let n_i = vcvtnq_s32_f32(n_f);
    let r = vsubq_f32(
        x,
        vmulq_f32(vcvtq_f32_s32(n_i), vdupq_n_f32(std::f32::consts::LN_2)),
    );
    let pow2n = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(n_i, vdupq_n_s32(127))));
    // exp(r) ≈ 1 + r*(1 + r*(0.5 + r/6)), evaluated with fused multiply-adds.
    let p = vfmaq_f32(vdupq_n_f32(0.5), r, vdupq_n_f32(1.0 / 6.0));
    let p = vfmaq_f32(vdupq_n_f32(1.0), r, p);
    vmulq_f32(vfmaq_f32(vdupq_n_f32(1.0), r, p), pow2n)
}

/// Sigmoid via hand-scheduled NEON assembly.
///
/// The main loop is 4x unrolled and processes 16 elements per iteration, with
/// a 4-wide loop and then a scalar loop for the tail. Interleaving independent
/// load/compute/store chains keeps the vector pipelines busy.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
unsafe fn sigmoid_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let mut inp = input.as_ptr();
    let mut out = output.as_mut_ptr();
    let mut remaining = len;

    // Load all constants ONCE before the loops; they stay in NEON registers.
    // NOTE: this relies on the compiler not touching v16-v23 between the
    // separate asm! blocks below.
    if remaining >= 4 {
        unsafe {
            // Constants on stack for ld1r broadcast
            let c_neg88: f32 = -88.0;
            let c_pos88: f32 = 88.0;
            // Schraudolph 1999 constants: exp(x) ≈ reinterpret(int(x * C + B))
            // C = 2^23 / ln(2) = 12102203.16, B = 127 * 2^23 = 1065353216
            // WHY: 2^23/ln(2) maps the float argument onto the IEEE 754 exponent
            // field; 127*2^23 adds the exponent bias.
            let c_schr_c: f32 = 12102203.0; // 2^23 / ln(2)
            let c_schr_b: i32 = 127 << 23; // 1065353216 as integer
            let c_sixth: f32 = 1.0 / 6.0;
            let c_half: f32 = 0.5;
            let c_one: f32 = 1.0;
            let c_127: i32 = 127;

            // Load constants into NEON registers (they stay there for the entire loop)
            std::arch::asm!(
                "ld1r {{v16.4s}}, [{p_neg88}]",
                "ld1r {{v17.4s}}, [{p_pos88}]",
                "ld1r {{v18.4s}}, [{p_schr_c}]",   // Schraudolph C (float)
                "dup  v19.4s, {p_schr_b:w}",        // Schraudolph B (integer 127<<23)
                "ld1r {{v20.4s}}, [{p_sixth}]",
                "ld1r {{v21.4s}}, [{p_half}]",
                "ld1r {{v22.4s}}, [{p_one}]",
                "dup  v23.4s, {p_127:w}",
                p_neg88 = in(reg) &c_neg88,
                p_pos88 = in(reg) &c_pos88,
                p_schr_c = in(reg) &c_schr_c,
                p_schr_b = in(reg) c_schr_b,
                p_sixth = in(reg) &c_sixth,
                p_half = in(reg) &c_half,
                p_one = in(reg) &c_one,
                p_127 = in(reg) c_127,
                out("v16") _, out("v17") _, out("v18") _, out("v19") _,
                out("v20") _, out("v21") _, out("v22") _, out("v23") _,
            );

            // Schraudolph bit trick: exp(x) ≈ reinterpret_f32(int(x * 2^23/ln2) + 127<<23)
            // Proper integer arithmetic: fcvtzs to get the int, add the bias as an
            // integer, then reinterpret the bits as f32.
            // 4x unrolled, 16 elements per iteration
            while remaining >= 16 {
                std::arch::asm!(
                    "ldp q0, q1, [{inp}]",
                    "ldp q2, q3, [{inp}, #32]",
                    "add {inp}, {inp}, #64",
                    "fneg v0.4s, v0.4s",
                    "fneg v1.4s, v1.4s",
                    "fneg v2.4s, v2.4s",
                    "fneg v3.4s, v3.4s",
                    "fmax v0.4s, v0.4s, v16.4s",
                    "fmax v1.4s, v1.4s, v16.4s",
                    "fmax v2.4s, v2.4s, v16.4s",
                    "fmax v3.4s, v3.4s, v16.4s",
                    "fmin v0.4s, v0.4s, v17.4s",
                    "fmin v1.4s, v1.4s, v17.4s",
                    "fmin v2.4s, v2.4s, v17.4s",
                    "fmin v3.4s, v3.4s, v17.4s",
                    // x * (2^23/ln2) → convert to int
                    "fmul v0.4s, v0.4s, v18.4s",
                    "fmul v1.4s, v1.4s, v18.4s",
                    "fmul v2.4s, v2.4s, v18.4s",
                    "fmul v3.4s, v3.4s, v18.4s",
                    "fcvtzs v0.4s, v0.4s",
                    "fcvtzs v1.4s, v1.4s",
                    "fcvtzs v2.4s, v2.4s",
                    "fcvtzs v3.4s, v3.4s",
                    // + 127*2^23 (integer add)
                    "add v0.4s, v0.4s, v19.4s",
                    "add v1.4s, v1.4s, v19.4s",
                    "add v2.4s, v2.4s, v19.4s",
                    "add v3.4s, v3.4s, v19.4s",
                    // v0-v3 bits ARE exp(-x) when reinterpreted as float
                    // sigmoid = 1 / (1 + exp)
                    "fadd v0.4s, v22.4s, v0.4s",
                    "fadd v1.4s, v22.4s, v1.4s",
                    "fadd v2.4s, v22.4s, v2.4s",
                    "fadd v3.4s, v22.4s, v3.4s",
                    "fdiv v0.4s, v22.4s, v0.4s",
                    "fdiv v1.4s, v22.4s, v1.4s",
                    "fdiv v2.4s, v22.4s, v2.4s",
                    "fdiv v3.4s, v22.4s, v3.4s",
                    "stp q0, q1, [{out}]",
                    "stp q2, q3, [{out}, #32]",
                    "add {out}, {out}, #64",
                    inp = inout(reg) inp,
                    out = inout(reg) out,
                    out("v0") _, out("v1") _, out("v2") _, out("v3") _,
                );
                remaining -= 16;
            }
            // 4-element tail — Schraudolph
            while remaining >= 4 {
                std::arch::asm!(
                    "ld1 {{v0.4s}}, [{inp}], #16",
                    "fneg v0.4s, v0.4s",
                    "fmax v0.4s, v0.4s, v16.4s",
                    "fmin v0.4s, v0.4s, v17.4s",
                    "fmul v0.4s, v0.4s, v18.4s",
                    "fcvtzs v0.4s, v0.4s",
                    "add v0.4s, v0.4s, v19.4s",
                    "fadd v0.4s, v22.4s, v0.4s",
                    "fdiv v0.4s, v22.4s, v0.4s",
                    "st1 {{v0.4s}}, [{out}], #16",
                    inp = inout(reg) inp,
                    out = inout(reg) out,
                    out("v0") _,
                );
                remaining -= 4;
            }
        }
    }

    // Scalar tail
    for i in 0..remaining {
        unsafe {
            let x = *inp.add(i);
            *out.add(i) = 1.0 / (1.0 + (-x).exp());
        }
    }
}
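
// Scalar rendering of the Schraudolph bit trick used in the assembly above,
// kept as a test-only sketch so the constants stay explainable (names here
// are illustrative, not part of the kernel API):
#[cfg(test)]
mod schraudolph_sketch {
    fn schraudolph_exp(x: f32) -> f32 {
        let x = x.clamp(-88.0, 88.0);
        // int(x * 2^23/ln2) + 127*2^23, reinterpreted as an f32 bit pattern.
        f32::from_bits(((x * 12_102_203.0) as i32 + (127 << 23)) as u32)
    }

    #[test]
    fn tracks_exp_within_a_few_percent() {
        for &x in &[-4.0f32, -1.0, -0.1, 0.0, 0.5, 2.0] {
            let rel = (schraudolph_exp(x) - x.exp()).abs() / x.exp();
            assert!(rel < 0.07, "x = {x}: relative error {rel}");
        }
    }
}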

// (sigmoid_vdsp and silu_vdsp removed — benchmarked slower than the NEON polynomial)

/// Fast tanh applied element-wise: `output[i] = tanh(input[i])`.
///
/// Computed as `2 * sigmoid(2x) - 1`.
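///
/// Derivation: `2 / (1 + e^{-2x}) - 1 = (1 - e^{-2x}) / (1 + e^{-2x}) = tanh(x)`.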
#[allow(unsafe_code)]
#[inline]
pub fn tanh_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        tanh_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                tanh_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                tanh_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                tanh_slice_neon(input, output);
            }
            return;
        }
    }

    tanh_slice_dispatch_scalar(input, output);
}

/// Fused SiLU (Swish) applied element-wise: `output[i] = input[i] * sigmoid(input[i])`.
///
/// A single pass over the data avoids the 2× bandwidth penalty of separate
/// sigmoid + multiply.
#[allow(unsafe_code)]
#[inline]
pub fn silu_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        silu_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                silu_slice_neon(input, output);
            }
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe { silu_slice_avx(input, output) };
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe { silu_slice_sse(input, output) };
            return;
        }
    }

    silu_slice_dispatch_scalar(input, output);
}

// ===========================================================================
// Reduction dispatchers: max_reduce, add_reduce
// ===========================================================================

/// Find the maximum value in `data`. Returns `f32::NEG_INFINITY` for empty slices.
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn max_reduce_dispatch(data: &[f32]) -> f32 {
    if data.is_empty() {
        return f32::NEG_INFINITY;
    }

    if cfg!(miri) {
        return max_reduce_scalar(data);
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { max_reduce_avx(data) };
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { max_reduce_sse(data) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { max_reduce_neon(data) };
        }
    }

    max_reduce_scalar(data)
}

/// Sum all values in `data`. Returns `0.0` for empty slices.
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn add_reduce_dispatch(data: &[f32]) -> f32 {
    if data.is_empty() {
        return 0.0;
    }

    if cfg!(miri) {
        return add_reduce_scalar(data);
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { add_reduce_avx(data) };
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { add_reduce_sse(data) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            return unsafe { add_reduce_neon(data) };
        }
    }

    add_reduce_scalar(data)
}
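
// NOTE: like any vectorized reduction, the SIMD paths accumulate partial
// sums per lane and combine them at the end, so results can differ from the
// strictly left-to-right scalar loop by a few ULPs.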

// ===========================================================================
// Scalar-broadcast multiply in-place
// ===========================================================================

/// Multiply every element of `data` by `scalar` in-place.
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn mul_scalar_inplace_dispatch(data: &mut [f32], scalar: f32) {
    if cfg!(miri) || data.is_empty() {
        for v in data.iter_mut() {
            *v *= scalar;
        }
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe { mul_scalar_inplace_neon(data, scalar) };
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe { mul_scalar_inplace_avx(data, scalar) };
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe { mul_scalar_inplace_sse(data, scalar) };
            return;
        }
    }

    for v in data.iter_mut() {
        *v *= scalar;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn mul_scalar_inplace_neon(data: &mut [f32], scalar: f32) {
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = vdupq_n_f32(scalar);
    let mut i = 0usize;
    while i + 4 <= len {
        let v = vld1q_f32(ptr.add(i));
        vst1q_f32(ptr.add(i), vmulq_f32(v, vs));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn mul_scalar_inplace_avx(data: &mut [f32], scalar: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = _mm256_set1_ps(scalar);
    let mut i = 0usize;
    while i + 8 <= len {
        let v = _mm256_loadu_ps(ptr.add(i));
        _mm256_storeu_ps(ptr.add(i), _mm256_mul_ps(v, vs));
        i += 8;
    }
    // SSE tail
    let vs4 = _mm_set1_ps(scalar);
    while i + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(i));
        _mm_storeu_ps(ptr.add(i), _mm_mul_ps(v, vs4));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn mul_scalar_inplace_sse(data: &mut [f32], scalar: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = _mm_set1_ps(scalar);
    let mut i = 0usize;
    while i + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(i));
        _mm_storeu_ps(ptr.add(i), _mm_mul_ps(v, vs));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

// ===========================================================================
// FMA dispatch (conv2d inner loop helper)
// ===========================================================================

/// Fused multiply-accumulate: `acc[i] += a[i] * b[i]`.
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn fma_slice_dispatch(a: &[f32], b: &[f32], acc: &mut [f32]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), acc.len());

    if cfg!(miri) {
        // SAFETY: scalar path only reads/writes within equal-sized slice bounds.
        unsafe {
            fma_slice_scalar(a, b, acc);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                fma_slice_avx(a, b, acc);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                fma_slice_sse(a, b, acc);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                fma_slice_neon(a, b, acc);
            }
            return;
        }
    }

    // SAFETY: scalar path only reads/writes within equal-sized slice bounds.
    unsafe {
        fma_slice_scalar(a, b, acc);
    }
}
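
// `fma_slice_dispatch` composes with `add_reduce_dispatch` into a dot
// product; a sketch (illustrative, not the crate's actual kernel):
//
//     let mut acc = vec![0.0f32; a.len()];
//     fma_slice_dispatch(&a, &b, &mut acc);
//     let dot = add_reduce_dispatch(&acc);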

// ===========================================================================
// Binary same-shape dispatch (existing)
// ===========================================================================

#[allow(unsafe_code, unreachable_code)]
#[inline]
pub fn binary_same_shape_dispatch(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    debug_assert_eq!(lhs.len(), rhs.len());
    debug_assert_eq!(lhs.len(), out.len());

    if cfg!(miri) {
        // SAFETY: scalar path only reads/writes within equal-sized slice bounds.
        unsafe {
            binary_same_shape_scalar(lhs, rhs, out, kind);
        }
        return;
    }

    // macOS: use vDSP for add/sub/mul (heavily optimized, zero loop overhead).
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
        let n = lhs.len() as u32;
        // SAFETY: vDSP functions read/write `n` floats from contiguous slices.
        unsafe {
            match kind {
                BinaryKind::Add => {
                    vDSP_vadd(lhs.as_ptr(), 1, rhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
                // NOTE: vDSP_vsub computes A - B with reversed argument order: vsub(B, ..., A, ..., C, ...)
                BinaryKind::Sub => {
                    vDSP_vsub(rhs.as_ptr(), 1, lhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
                BinaryKind::Mul => {
                    vDSP_vmul(lhs.as_ptr(), 1, rhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
            }
        }
        return;
    }

    // x86/x86_64 with MKL: use Intel VML for add/sub/mul (heavily optimized).
    #[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
    {
        let n = lhs.len() as i32;
        // SAFETY: VML functions read `n` floats from contiguous slices and write to `out`.
        unsafe {
            match kind {
                BinaryKind::Add => vsAdd(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Sub => vsSub(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Mul => vsMul(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
            }
        }
        return;
    }

    // aarch64 Linux with ARMPL: use ARM Performance Libraries for add/sub/mul.
    #[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
    {
        let n = lhs.len() as i32;
        // SAFETY: ARMPL functions read `n` floats from contiguous slices and write to `out`.
        unsafe {
            match kind {
                BinaryKind::Add => armpl_svadd_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Sub => armpl_svsub_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Mul => armpl_svmul_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
            }
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                binary_same_shape_avx(lhs, rhs, out, kind);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                binary_same_shape_sse(lhs, rhs, out, kind);
            }
            return;
        }
    }

    #[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                binary_same_shape_neon(lhs, rhs, out, kind);
            }
            return;
        }
    }

    // SAFETY: scalar path only reads/writes within equal-sized slice bounds.
    unsafe {
        binary_same_shape_scalar(lhs, rhs, out, kind);
    }
}
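
// Usage sketch (illustrative):
//
//     let (a, b) = ([1.0f32, 2.0, 3.0], [4.0f32, 5.0, 6.0]);
//     let mut out = [0.0f32; 3];
//     binary_same_shape_dispatch(&a, &b, &mut out, BinaryKind::Mul);
//     assert_eq!(out, [4.0, 10.0, 18.0]);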

// ===========================================================================
// Scalar fallbacks
// ===========================================================================

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn relu_slice_scalar(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let mut index = 0usize;

    while index + 8 <= len {
        let v0 = *ptr.add(index);
        let v1 = *ptr.add(index + 1);
        let v2 = *ptr.add(index + 2);
        let v3 = *ptr.add(index + 3);
        let v4 = *ptr.add(index + 4);
        let v5 = *ptr.add(index + 5);
        let v6 = *ptr.add(index + 6);
        let v7 = *ptr.add(index + 7);
        *ptr.add(index) = v0.max(0.0);
        *ptr.add(index + 1) = v1.max(0.0);
        *ptr.add(index + 2) = v2.max(0.0);
        *ptr.add(index + 3) = v3.max(0.0);
        *ptr.add(index + 4) = v4.max(0.0);
        *ptr.add(index + 5) = v5.max(0.0);
        *ptr.add(index + 6) = v6.max(0.0);
        *ptr.add(index + 7) = v7.max(0.0);
        index += 8;
    }

    while index < len {
        *ptr.add(index) = (*ptr.add(index)).max(0.0);
        index += 1;
    }
}

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn relu_to_slice_scalar(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 8 <= len {
        *out_ptr.add(index) = (*in_ptr.add(index)).max(0.0);
        *out_ptr.add(index + 1) = (*in_ptr.add(index + 1)).max(0.0);
        *out_ptr.add(index + 2) = (*in_ptr.add(index + 2)).max(0.0);
        *out_ptr.add(index + 3) = (*in_ptr.add(index + 3)).max(0.0);
        *out_ptr.add(index + 4) = (*in_ptr.add(index + 4)).max(0.0);
        *out_ptr.add(index + 5) = (*in_ptr.add(index + 5)).max(0.0);
        *out_ptr.add(index + 6) = (*in_ptr.add(index + 6)).max(0.0);
        *out_ptr.add(index + 7) = (*in_ptr.add(index + 7)).max(0.0);
        index += 8;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).max(0.0);
        index += 1;
    }
}

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn binary_same_shape_scalar(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    match kind {
        BinaryKind::Add => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) + *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) + *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) + *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) + *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) + *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) + *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) + *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) + *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) + *right_ptr.add(index);
                index += 1;
            }
        }
        BinaryKind::Sub => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) - *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) - *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) - *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) - *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) - *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) - *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) - *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) - *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) - *right_ptr.add(index);
                index += 1;
            }
        }
        BinaryKind::Mul => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) * *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) * *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) * *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) * *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) * *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) * *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) * *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) * *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) * *right_ptr.add(index);
                index += 1;
            }
        }
    }
}

fn exp_slice_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v.exp();
    }
}

fn sub_exp_slice_scalar(input: &[f32], offset: f32, output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = (v - offset).exp();
    }
}

fn sigmoid_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = sigmoid_scalar(v);
    }
}

fn tanh_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v.tanh();
    }
}

fn silu_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        let s = 1.0 / (1.0 + (-v).exp());
        *o = v * s;
    }
}

#[allow(dead_code)]
fn max_reduce_scalar(data: &[f32]) -> f32 {
    let mut acc = f32::NEG_INFINITY;
    for &v in data {
        acc = acc.max(v);
    }
    acc
}

#[allow(dead_code)]
fn add_reduce_scalar(data: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for &v in data {
        acc += v;
    }
    acc
}

#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn fma_slice_scalar(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        *acc_ptr.add(index) += *a_ptr.add(index) * *b_ptr.add(index);
        *acc_ptr.add(index + 1) += *a_ptr.add(index + 1) * *b_ptr.add(index + 1);
        *acc_ptr.add(index + 2) += *a_ptr.add(index + 2) * *b_ptr.add(index + 2);
        *acc_ptr.add(index + 3) += *a_ptr.add(index + 3) * *b_ptr.add(index + 3);
        index += 4;
    }
    while index < len {
        *acc_ptr.add(index) += *a_ptr.add(index) * *b_ptr.add(index);
        index += 1;
    }
}

// ===========================================================================
// SSE fast-exp helper (4-wide)
// ===========================================================================
//
// Uses the classic range-reduction approach:
//   exp(x) = 2^n * exp(r)  where  n = round(x / ln2), r = x - n*ln2
// Then exp(r) is approximated with a degree-6 polynomial on [-ln2/2, ln2/2].

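// A scalar Rust reference of the same scheme, for reading alongside the
// intrinsics (illustrative; the kernels only ever call the SIMD versions):
//
//     fn fast_exp_scalar(x: f32) -> f32 {
//         let x = x.clamp(-88.0, 88.0);
//         let n = (x * std::f32::consts::LOG2_E).round();
//         let r = x - n * std::f32::consts::LN_2;
//         let poly = 1.0
//             + r * (1.0
//                 + r * (0.5
//                     + r * (1.0 / 6.0
//                         + r * (1.0 / 24.0 + r * (1.0 / 120.0 + r / 720.0)))));
//         // Multiply by 2^n via the IEEE 754 exponent field.
//         poly * f32::from_bits((((n as i32) + 127) << 23) as u32)
//     }
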
/// Schraudolph 1999 bit-trick exp for SSE: exp(x) ≈ reinterpret(int(x * 2^23/ln2) + 127*2^23).
/// WHY: ~3x faster than the polynomial; worst-case relative error on exp is a few
/// percent, which is sufficient for sigmoid/tanh where the 1/(1+exp) form dampens it.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn fast_exp_bittrick_sse(x: __m128) -> __m128 {
    // SSE2 intrinsics used below are always available on x86_64.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{_mm_add_epi32, _mm_cvtps_epi32, _mm_set1_epi32};
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{_mm_add_epi32, _mm_cvtps_epi32, _mm_set1_epi32};
    // exp(x) ≈ reinterpret(int(x * C + B)) where C = 2^23/ln2, B = 127*2^23
    let scale = _mm_set1_ps(12102203.0); // WHY: 2^23/ln(2) maps float to IEEE 754 exponent field
    let offset = _mm_set1_epi32(1065353216); // WHY: 127*2^23 is the IEEE 754 exponent bias in integer form
    let clamp_lo = _mm_set1_ps(-87.0); // WHY: below this exp() produces denormals (underflow)
    let clamp_hi = _mm_set1_ps(88.0); // WHY: above this exp() exceeds f32 max (overflow to inf)
    let x_clamped = _mm_max_ps(_mm_min_ps(x, clamp_hi), clamp_lo);
    let val = _mm_cvtps_epi32(_mm_mul_ps(x_clamped, scale));
    _mm_castsi128_ps(_mm_add_epi32(val, offset))
}

/// Polynomial exp for SSE: range reduction + degree-6 Taylor polynomial for
/// exp(r) on [-ln2/2, ln2/2] (truncation error ~1e-7, roughly 1e-6 end-to-end
/// in f32). Used for standalone exp (softmax, etc.) where precision matters more.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn fast_exp_sse(x: __m128) -> __m128 {
    let ln2_inv = _mm_set1_ps(std::f32::consts::LOG2_E);
    let ln2_hi = _mm_set1_ps(0.693_359_4); // upper bits of ln(2)
    let ln2_lo = _mm_set1_ps(-2.121_944_4e-4); // lower bits of ln(2)

    // Polynomial coefficients (Taylor series for exp(r) on [-ln2/2, ln2/2])
    let c0 = _mm_set1_ps(1.0);
    let c1 = _mm_set1_ps(1.0);
    let c2 = _mm_set1_ps(0.5);
    let c3 = _mm_set1_ps(1.0 / 6.0);
    let c4 = _mm_set1_ps(1.0 / 24.0);
    let c5 = _mm_set1_ps(1.0 / 120.0);
    let c6 = _mm_set1_ps(1.0 / 720.0);

    // Clamp input to prevent overflow/underflow
    let x = _mm_max_ps(_mm_set1_ps(-88.0), _mm_min_ps(_mm_set1_ps(88.0), x));

    // n = round(x / ln2)
    let n_f = _mm_mul_ps(x, ln2_inv);
    // Round to nearest integer using convert (rounds to nearest by default)
    let n_i = _mm_cvtps_epi32(n_f);
    let n_f = _mm_cvtepi32_ps(n_i);

    // r = x - n * ln2  (two-step for accuracy)
    let r = _mm_sub_ps(
        _mm_sub_ps(x, _mm_mul_ps(n_f, ln2_hi)),
        _mm_mul_ps(n_f, ln2_lo),
    );

    // Polynomial: c0 + r*(c1 + r*(c2 + r*(c3 + r*(c4 + r*(c5 + r*c6)))))
    let mut poly = _mm_add_ps(c5, _mm_mul_ps(r, c6));
    poly = _mm_add_ps(c4, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c1, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c0, _mm_mul_ps(r, poly));

    // Multiply by 2^n using bit manipulation: reinterpret (n + 127) << 23 as f32.
    // _mm_add_epi32 and _mm_slli_epi32 are SSE2, always available on x86_64.
    let pow2n = {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::{_mm_add_epi32, _mm_slli_epi32};
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::{_mm_add_epi32, _mm_slli_epi32};
        let bias = _mm_set1_epi32(127);
        _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(n_i, bias), 23))
    };

    _mm_mul_ps(poly, pow2n)
}

// ===========================================================================
// AVX fast-exp helper (8-wide)
// ===========================================================================

/// Schraudolph 1999 bit-trick exp for AVX: exp(x) ≈ reinterpret(int(x * 2^23/ln2) + 127*2^23).
/// WHY: ~3x faster than the polynomial; worst-case relative error on exp is a few
/// percent, which is sufficient for sigmoid/tanh where the 1/(1+exp) form dampens it.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
#[inline]
unsafe fn fast_exp_bittrick_avx(x: __m256) -> __m256 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{_mm256_add_epi32, _mm256_cvtps_epi32, _mm256_set1_epi32};
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{_mm256_add_epi32, _mm256_cvtps_epi32, _mm256_set1_epi32};
    let scale = _mm256_set1_ps(12102203.0); // WHY: 2^23/ln(2) maps float to IEEE 754 exponent field
    let offset = _mm256_set1_epi32(1065353216); // WHY: 127*2^23 is the IEEE 754 exponent bias in integer form
    let clamp_lo = _mm256_set1_ps(-87.0); // WHY: below this exp() produces denormals
    let clamp_hi = _mm256_set1_ps(88.0); // WHY: above this exp() exceeds f32 max
    let x_clamped = _mm256_max_ps(_mm256_min_ps(x, clamp_hi), clamp_lo);
    let val = _mm256_cvtps_epi32(_mm256_mul_ps(x_clamped, scale));
    _mm256_castsi256_ps(_mm256_add_epi32(val, offset))
}

/// Polynomial exp for AVX: range reduction + degree-6 Taylor polynomial for
/// exp(r) on [-ln2/2, ln2/2] (truncation error ~1e-7, roughly 1e-6 end-to-end
/// in f32). Used for standalone exp (softmax, etc.) where precision matters more.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn fast_exp_avx(x: __m256) -> __m256 {
    let ln2_inv = _mm256_set1_ps(std::f32::consts::LOG2_E);
    let ln2_hi = _mm256_set1_ps(0.693_359_4);
    let ln2_lo = _mm256_set1_ps(-2.121_944_4e-4);

    let c0 = _mm256_set1_ps(1.0);
    let c1 = _mm256_set1_ps(1.0);
    let c2 = _mm256_set1_ps(0.5);
    let c3 = _mm256_set1_ps(1.0 / 6.0);
    let c4 = _mm256_set1_ps(1.0 / 24.0);
    let c5 = _mm256_set1_ps(1.0 / 120.0);
    let c6 = _mm256_set1_ps(1.0 / 720.0);

    let x = _mm256_max_ps(
        _mm256_set1_ps(-88.0),
        _mm256_min_ps(_mm256_set1_ps(88.0), x),
    );

    let n_f = _mm256_mul_ps(x, ln2_inv);
    let n_i = _mm256_cvtps_epi32(n_f);
    let n_f = _mm256_cvtepi32_ps(n_i);

    let r = _mm256_sub_ps(
        _mm256_sub_ps(x, _mm256_mul_ps(n_f, ln2_hi)),
        _mm256_mul_ps(n_f, ln2_lo),
    );

    let mut poly = _mm256_add_ps(c5, _mm256_mul_ps(r, c6));
    poly = _mm256_add_ps(c4, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c3, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c2, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c1, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c0, _mm256_mul_ps(r, poly));

    let bias = _mm256_set1_epi32(127);
    let pow2n = {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::{_mm256_add_epi32, _mm256_slli_epi32};
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::{_mm256_add_epi32, _mm256_slli_epi32};
        _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_add_epi32(n_i, bias), 23))
    };

    _mm256_mul_ps(poly, pow2n)
}

// ===========================================================================
// NEON fast-exp helper (4-wide)
// ===========================================================================

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn fast_exp_neon(x: float32x4_t) -> float32x4_t {
    use std::arch::aarch64::{
        vaddq_s32, vcvtnq_s32_f32, vcvtq_f32_s32, vdupq_n_s32, vreinterpretq_f32_s32, vshlq_n_s32,
    };

    let ln2_inv = vdupq_n_f32(std::f32::consts::LOG2_E);
    let ln2_hi = vdupq_n_f32(0.693_359_4);
    let ln2_lo = vdupq_n_f32(-2.121_944_4e-4);

    let c0 = vdupq_n_f32(1.0);
    let c1 = vdupq_n_f32(1.0);
    let c2 = vdupq_n_f32(0.5);
    let c3 = vdupq_n_f32(1.0 / 6.0);
    let c4 = vdupq_n_f32(1.0 / 24.0);
    let c5 = vdupq_n_f32(1.0 / 120.0);
    let c6 = vdupq_n_f32(1.0 / 720.0);

    let x = vmaxq_f32(vdupq_n_f32(-88.0), vminq_f32(vdupq_n_f32(88.0), x));

    let n_f = vmulq_f32(x, ln2_inv);
    let n_i = vcvtnq_s32_f32(n_f);
    let n_f = vcvtq_f32_s32(n_i);

    let r = vsubq_f32(vsubq_f32(x, vmulq_f32(n_f, ln2_hi)), vmulq_f32(n_f, ln2_lo));

    let mut poly = vfmaq_f32(c5, r, c6);
    poly = vfmaq_f32(c4, r, poly);
    poly = vfmaq_f32(c3, r, poly);
    poly = vfmaq_f32(c2, r, poly);
    poly = vfmaq_f32(c1, r, poly);
    poly = vfmaq_f32(c0, r, poly);

    let bias = vdupq_n_s32(127);
    let pow2n = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(n_i, bias)));

    vmulq_f32(poly, pow2n)
}

// ===========================================================================
// Exp slice implementations
// ===========================================================================

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn exp_slice_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let v = _mm_loadu_ps(in_ptr.add(index));
        let r = fast_exp_sse(v);
        _mm_storeu_ps(out_ptr.add(index), r);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).exp();
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn exp_slice_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    // 2x unrolled: process 16 floats per iteration to hide FMA latency.
    while index + 16 <= len {
        // Prefetch the next cacheline (64 bytes = 16 floats ahead)
        #[cfg(target_arch = "x86")]
        {
            use std::arch::x86::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        let v0 = _mm256_loadu_ps(in_ptr.add(index));
        let v1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let r0 = fast_exp_avx(v0);
        let r1 = fast_exp_avx(v1);
        _mm256_storeu_ps(out_ptr.add(index), r0);
        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
        index += 16;
    }

    // Handle the remaining 8-float chunk
    while index + 8 <= len {
        let v = _mm256_loadu_ps(in_ptr.add(index));
        let r = fast_exp_avx(v);
        _mm256_storeu_ps(out_ptr.add(index), r);
        index += 8;
    }

    if index < len {
        exp_slice_sse(&input[index..], &mut output[index..]);
    }
}
1466
1467#[cfg(target_arch = "aarch64")]
1468#[allow(unsafe_code, dead_code)]
1469#[allow(unsafe_op_in_unsafe_fn)]
1470#[target_feature(enable = "neon")]
1471unsafe fn exp_slice_neon(input: &[f32], output: &mut [f32]) {
1472    let len = input.len();
1473    let in_ptr = input.as_ptr();
1474    let out_ptr = output.as_mut_ptr();
1475    let mut index = 0usize;
1476
1477    while index + 4 <= len {
1478        let v = vld1q_f32(in_ptr.add(index));
1479        let r = fast_exp_neon(v);
1480        vst1q_f32(out_ptr.add(index), r);
1481        index += 4;
1482    }
1483
1484    while index < len {
1485        *out_ptr.add(index) = (*in_ptr.add(index)).exp();
1486        index += 1;
1487    }
1488}
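// Sketch of an accuracy check for the polynomial exp path. The true error
// bound of `fast_exp_sse` is assumed rather than derived here, so the
// tolerance is deliberately loose (~1% relative).
#[cfg(all(test, any(target_arch = "x86", target_arch = "x86_64")))]
#[test]
#[allow(unsafe_code)]
fn exp_slice_tracks_std_exp() {
    if !std::is_x86_feature_detected!("sse") {
        return;
    }
    let input: Vec<f32> = (-80..=80).map(|i| i as f32 * 0.1).collect();
    let mut output = vec![0.0f32; input.len()];
    // SAFETY: guarded by runtime feature detection.
    unsafe { exp_slice_sse(&input, &mut output) };
    for (&x, &y) in input.iter().zip(output.iter()) {
        let exact = x.exp();
        assert!((y - exact).abs() <= 0.01 * exact + 1e-6, "x = {x}");
    }
}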
1489
1490// ===========================================================================
1491// Fused subtract-and-exp: output[i] = exp(input[i] - offset)
1492// ===========================================================================
1493
1494#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1495#[allow(unsafe_code)]
1496#[allow(unsafe_op_in_unsafe_fn)]
1497#[target_feature(enable = "sse")]
1498unsafe fn sub_exp_slice_sse(input: &[f32], offset: f32, output: &mut [f32]) {
1499    let len = input.len();
1500    let in_ptr = input.as_ptr();
1501    let out_ptr = output.as_mut_ptr();
1502    let off = _mm_set1_ps(offset);
1503    let mut index = 0usize;
1504
1505    while index + 4 <= len {
1506        let v = _mm_loadu_ps(in_ptr.add(index));
1507        let shifted = _mm_sub_ps(v, off);
1508        let r = fast_exp_sse(shifted);
1509        _mm_storeu_ps(out_ptr.add(index), r);
1510        index += 4;
1511    }
1512
1513    while index < len {
1514        *out_ptr.add(index) = (*in_ptr.add(index) - offset).exp();
1515        index += 1;
1516    }
1517}
1518
1519#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1520#[allow(unsafe_code)]
1521#[allow(unsafe_op_in_unsafe_fn)]
1522#[target_feature(enable = "avx")]
1523unsafe fn sub_exp_slice_avx(input: &[f32], offset: f32, output: &mut [f32]) {
1524    let len = input.len();
1525    let in_ptr = input.as_ptr();
1526    let out_ptr = output.as_mut_ptr();
1527    let off = _mm256_set1_ps(offset);
1528    let mut index = 0usize;
1529
1530    // 2x unrolled: process 16 floats per iteration to hide FMA latency.
1531    while index + 16 <= len {
1532        #[cfg(target_arch = "x86")]
1533        {
1534            use std::arch::x86::_mm_prefetch;
1535            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
1536        }
1537        #[cfg(target_arch = "x86_64")]
1538        {
1539            use std::arch::x86_64::_mm_prefetch;
1540            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
1541        }
1542        let v0 = _mm256_loadu_ps(in_ptr.add(index));
1543        let v1 = _mm256_loadu_ps(in_ptr.add(index + 8));
1544        let shifted0 = _mm256_sub_ps(v0, off);
1545        let shifted1 = _mm256_sub_ps(v1, off);
1546        let r0 = fast_exp_avx(shifted0);
1547        let r1 = fast_exp_avx(shifted1);
1548        _mm256_storeu_ps(out_ptr.add(index), r0);
1549        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
1550        index += 16;
1551    }
1552
1553    // Handle remaining 8-float chunk
1554    while index + 8 <= len {
1555        let v = _mm256_loadu_ps(in_ptr.add(index));
1556        let shifted = _mm256_sub_ps(v, off);
1557        let r = fast_exp_avx(shifted);
1558        _mm256_storeu_ps(out_ptr.add(index), r);
1559        index += 8;
1560    }
1561
1562    if index < len {
1563        sub_exp_slice_sse(&input[index..], offset, &mut output[index..]);
1564    }
1565}
1566
1567#[cfg(target_arch = "aarch64")]
1568#[allow(unsafe_code)]
1569#[allow(unsafe_op_in_unsafe_fn)]
1570#[target_feature(enable = "neon")]
1571unsafe fn sub_exp_slice_neon(input: &[f32], offset: f32, output: &mut [f32]) {
1572    let len = input.len();
1573    let in_ptr = input.as_ptr();
1574    let out_ptr = output.as_mut_ptr();
1575    let off = vdupq_n_f32(offset);
1576    let mut index = 0usize;
1577
1578    while index + 4 <= len {
1579        let v = vld1q_f32(in_ptr.add(index));
1580        let shifted = vsubq_f32(v, off);
1581        let r = fast_exp_neon(shifted);
1582        vst1q_f32(out_ptr.add(index), r);
1583        index += 4;
1584    }
1585
1586    while index < len {
1587        *out_ptr.add(index) = (*in_ptr.add(index) - offset).exp();
1588        index += 1;
1589    }
1590}
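// Scalar reference for the fused kernels above (named `_reference` to avoid
// colliding with any scalar fallback the dispatchers already define). The
// fusion matters for softmax: with `offset` set to the row max, it saves a
// full extra read/write sweep over the row.
#[allow(dead_code)]
fn sub_exp_slice_reference(input: &[f32], offset: f32, output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input) {
        *o = (v - offset).exp();
    }
}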
1591
1592// ===========================================================================
1593// Sigmoid slice implementations: sigmoid(x) = 1 / (1 + exp(-x))
1594// ===========================================================================
1595
1596#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1597#[allow(unsafe_code)]
1598#[allow(unsafe_op_in_unsafe_fn)]
1599#[target_feature(enable = "sse")]
1600unsafe fn sigmoid_slice_sse(input: &[f32], output: &mut [f32]) {
1601    #[cfg(target_arch = "x86")]
1602    use std::arch::x86::_mm_div_ps;
1603    #[cfg(target_arch = "x86_64")]
1604    use std::arch::x86_64::_mm_div_ps;
1605
1606    let len = input.len();
1607    let in_ptr = input.as_ptr();
1608    let out_ptr = output.as_mut_ptr();
1609    let one = _mm_set1_ps(1.0);
1610    let zero = _mm_setzero_ps();
1611    let mut index = 0usize;
1612
1613    // Process 16 elements per iteration (4 SSE registers)
1614    while index + 16 <= len {
1615        let x0 = _mm_loadu_ps(in_ptr.add(index));
1616        let x1 = _mm_loadu_ps(in_ptr.add(index + 4));
1617        let x2 = _mm_loadu_ps(in_ptr.add(index + 8));
1618        let x3 = _mm_loadu_ps(in_ptr.add(index + 12));
1619
1620        // Use Schraudolph bit-trick exp for ~3x speedup over polynomial
1621        let e0 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x0));
1622        let e1 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x1));
1623        let e2 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x2));
1624        let e3 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x3));
1625
1626        let r0 = _mm_div_ps(one, _mm_add_ps(one, e0));
1627        let r1 = _mm_div_ps(one, _mm_add_ps(one, e1));
1628        let r2 = _mm_div_ps(one, _mm_add_ps(one, e2));
1629        let r3 = _mm_div_ps(one, _mm_add_ps(one, e3));
1630
1631        _mm_storeu_ps(out_ptr.add(index), r0);
1632        _mm_storeu_ps(out_ptr.add(index + 4), r1);
1633        _mm_storeu_ps(out_ptr.add(index + 8), r2);
1634        _mm_storeu_ps(out_ptr.add(index + 12), r3);
1635
1636        index += 16;
1637    }
1638
1639    // Remaining 4 at a time
1640    while index + 4 <= len {
1641        let x = _mm_loadu_ps(in_ptr.add(index));
1642        let neg_x = _mm_sub_ps(zero, x);
1643        let exp_neg_x = fast_exp_bittrick_sse(neg_x);
1644        let denom = _mm_add_ps(one, exp_neg_x);
1645        let result = _mm_div_ps(one, denom);
1646        _mm_storeu_ps(out_ptr.add(index), result);
1647        index += 4;
1648    }
1649
1650    while index < len {
1651        *out_ptr.add(index) = sigmoid_scalar(*in_ptr.add(index));
1652        index += 1;
1653    }
1654}
1655
1656#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1657#[allow(unsafe_code)]
1658#[allow(unsafe_op_in_unsafe_fn)]
1659#[target_feature(enable = "avx")]
1660unsafe fn sigmoid_slice_avx(input: &[f32], output: &mut [f32]) {
1661    #[cfg(target_arch = "x86")]
1662    use std::arch::x86::_mm256_div_ps;
1663    #[cfg(target_arch = "x86_64")]
1664    use std::arch::x86_64::_mm256_div_ps;
1665
1666    let len = input.len();
1667    let in_ptr = input.as_ptr();
1668    let out_ptr = output.as_mut_ptr();
1669    let one = _mm256_set1_ps(1.0);
1670    let zero = _mm256_setzero_ps();
1671    let mut index = 0usize;
1672
1673    // Process 32 elements per iteration (4 AVX registers)
1674    while index + 32 <= len {
1675        let x0 = _mm256_loadu_ps(in_ptr.add(index));
1676        let x1 = _mm256_loadu_ps(in_ptr.add(index + 8));
1677        let x2 = _mm256_loadu_ps(in_ptr.add(index + 16));
1678        let x3 = _mm256_loadu_ps(in_ptr.add(index + 24));
1679
1680        // Use Schraudolph bit-trick exp for ~3x speedup over polynomial
1681        let e0 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x0));
1682        let e1 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x1));
1683        let e2 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x2));
1684        let e3 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x3));
1685
1686        let r0 = _mm256_div_ps(one, _mm256_add_ps(one, e0));
1687        let r1 = _mm256_div_ps(one, _mm256_add_ps(one, e1));
1688        let r2 = _mm256_div_ps(one, _mm256_add_ps(one, e2));
1689        let r3 = _mm256_div_ps(one, _mm256_add_ps(one, e3));
1690
1691        _mm256_storeu_ps(out_ptr.add(index), r0);
1692        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
1693        _mm256_storeu_ps(out_ptr.add(index + 16), r2);
1694        _mm256_storeu_ps(out_ptr.add(index + 24), r3);
1695
1696        index += 32;
1697    }
1698
1699    // Remaining 8 at a time
1700    while index + 8 <= len {
1701        let x = _mm256_loadu_ps(in_ptr.add(index));
1702        let neg_x = _mm256_sub_ps(zero, x);
1703        let exp_neg_x = fast_exp_bittrick_avx(neg_x);
1704        let denom = _mm256_add_ps(one, exp_neg_x);
1705        let result = _mm256_div_ps(one, denom);
1706        _mm256_storeu_ps(out_ptr.add(index), result);
1707        index += 8;
1708    }
1709
1710    if index < len {
1711        sigmoid_slice_sse(&input[index..], &mut output[index..]);
1712    }
1713}
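// Scalar sketch of the Schraudolph bit trick assumed by
// `fast_exp_bittrick_sse/_avx`: exp(x) ~ from_bits(A*x + B), where A scales x
// into the exponent field and B is the exponent bias minus an error-minimizing
// correction. The constants below are the classic published ones; the crate's
// own kernels may use slightly different values. Error is a few percent, and
// the trick only holds while the result stays finite (roughly |x| < 87).
#[allow(dead_code)]
fn fast_exp_bittrick_reference(x: f32) -> f32 {
    const A: f32 = 12_102_203.0; // 2^23 / ln(2)
    const B: i32 = 1_064_866_805; // 127 * 2^23 - 486_411
    f32::from_bits(((A * x) as i32 + B) as u32)
}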
1714
1715// (sigmoid_slice_neon defined above at line ~291)
1716
1717// ===========================================================================
1718// Tanh slice implementations: tanh(x) = 2 * sigmoid(2x) - 1
1719// ===========================================================================
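// Derivation: tanh(x) = (e^{2x} - 1) / (e^{2x} + 1) = 2 / (1 + e^{-2x}) - 1
// = 2 * sigmoid(2x) - 1, so one fast exp evaluation per lane suffices.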
1720
1721#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1722#[allow(unsafe_code)]
1723#[allow(unsafe_op_in_unsafe_fn)]
1724#[target_feature(enable = "sse")]
1725unsafe fn tanh_slice_sse(input: &[f32], output: &mut [f32]) {
1726    let len = input.len();
1727    let in_ptr = input.as_ptr();
1728    let out_ptr = output.as_mut_ptr();
1729    let two = _mm_set1_ps(2.0);
1730    let one = _mm_set1_ps(1.0);
1731    let zero = _mm_setzero_ps();
1732    let mut index = 0usize;
1733
1734    while index + 4 <= len {
1735        let x = _mm_loadu_ps(in_ptr.add(index));
1736        let two_x = _mm_mul_ps(two, x);
1737        // sigmoid(2x) = 1 / (1 + exp(-2x))
1738        let neg_two_x = _mm_sub_ps(zero, two_x);
1739        // Use Schraudolph bit-trick exp for ~3x speedup
1740        let exp_neg = fast_exp_bittrick_sse(neg_two_x);
1741        let denom = _mm_add_ps(one, exp_neg);
1742        let sig = {
1743            #[cfg(target_arch = "x86")]
1744            use std::arch::x86::_mm_rcp_ps;
1745            #[cfg(target_arch = "x86_64")]
1746            use std::arch::x86_64::_mm_rcp_ps;
1747            let rcp = _mm_rcp_ps(denom);
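            // One Newton-Raphson step, x1 = x0 * (2 - d * x0), refines the
            // ~12-bit rcp estimate to nearly full f32 precision.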
1748            _mm_mul_ps(rcp, _mm_sub_ps(two, _mm_mul_ps(denom, rcp)))
1749        };
1750        // tanh = 2 * sig - 1
1751        let result = _mm_sub_ps(_mm_mul_ps(two, sig), one);
1752        _mm_storeu_ps(out_ptr.add(index), result);
1753        index += 4;
1754    }
1755
1756    while index < len {
1757        *out_ptr.add(index) = (*in_ptr.add(index)).tanh();
1758        index += 1;
1759    }
1760}
1761
1762#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1763#[allow(unsafe_code)]
1764#[allow(unsafe_op_in_unsafe_fn)]
1765#[target_feature(enable = "avx")]
1766unsafe fn tanh_slice_avx(input: &[f32], output: &mut [f32]) {
1767    let len = input.len();
1768    let in_ptr = input.as_ptr();
1769    let out_ptr = output.as_mut_ptr();
1770    let two = _mm256_set1_ps(2.0);
1771    let one = _mm256_set1_ps(1.0);
1772    let zero = _mm256_setzero_ps();
1773    let mut index = 0usize;
1774
1775    while index + 8 <= len {
1776        let x = _mm256_loadu_ps(in_ptr.add(index));
1777        let two_x = _mm256_mul_ps(two, x);
1778        let neg_two_x = _mm256_sub_ps(zero, two_x);
1779        // Use Schraudolph bit-trick exp for ~3x speedup
1780        let exp_neg = fast_exp_bittrick_avx(neg_two_x);
1781        let denom = _mm256_add_ps(one, exp_neg);
1782        let sig = {
1783            #[cfg(target_arch = "x86")]
1784            use std::arch::x86::_mm256_rcp_ps;
1785            #[cfg(target_arch = "x86_64")]
1786            use std::arch::x86_64::_mm256_rcp_ps;
1787            let rcp = _mm256_rcp_ps(denom);
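            // Same Newton-Raphson refinement of the rcp estimate as the SSE path.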
1788            _mm256_mul_ps(rcp, _mm256_sub_ps(two, _mm256_mul_ps(denom, rcp)))
1789        };
1790        let result = _mm256_sub_ps(_mm256_mul_ps(two, sig), one);
1791        _mm256_storeu_ps(out_ptr.add(index), result);
1792        index += 8;
1793    }
1794
1795    if index < len {
1796        tanh_slice_sse(&input[index..], &mut output[index..]);
1797    }
1798}
1799
1800#[cfg(target_arch = "aarch64")]
1801#[allow(unsafe_code, dead_code)]
1802#[allow(unsafe_op_in_unsafe_fn)]
1803#[target_feature(enable = "neon")]
1804unsafe fn tanh_slice_neon(input: &[f32], output: &mut [f32]) {
1805    let len = input.len();
1806    let in_ptr = input.as_ptr();
1807    let out_ptr = output.as_mut_ptr();
1808    let two = vdupq_n_f32(2.0);
1809    let one = vdupq_n_f32(1.0);
1810    let mut index = 0usize;
1811
1812    // 8x unrolled: 32 elements per iteration, using fast 3-term exp polynomial
1813    while index + 32 <= len {
1814        let x0 = vld1q_f32(in_ptr.add(index));
1815        let x1 = vld1q_f32(in_ptr.add(index + 4));
1816        let x2 = vld1q_f32(in_ptr.add(index + 8));
1817        let x3 = vld1q_f32(in_ptr.add(index + 12));
1818        let x4 = vld1q_f32(in_ptr.add(index + 16));
1819        let x5 = vld1q_f32(in_ptr.add(index + 20));
1820        let x6 = vld1q_f32(in_ptr.add(index + 24));
1821        let x7 = vld1q_f32(in_ptr.add(index + 28));
1822
1823        // exp(-2x) using fast 3-term polynomial (sufficient for tanh)
1824        let e0 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x0)));
1825        let e1 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x1)));
1826        let e2 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x2)));
1827        let e3 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x3)));
1828        let e4 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x4)));
1829        let e5 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x5)));
1830        let e6 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x6)));
1831        let e7 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x7)));
1832
1833        // tanh(x) = 2 * sigmoid(2x) - 1 = 2/(1+exp(-2x)) - 1
1834        vst1q_f32(
1835            out_ptr.add(index),
1836            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e0)), one),
1837        );
1838        vst1q_f32(
1839            out_ptr.add(index + 4),
1840            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e1)), one),
1841        );
1842        vst1q_f32(
1843            out_ptr.add(index + 8),
1844            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e2)), one),
1845        );
1846        vst1q_f32(
1847            out_ptr.add(index + 12),
1848            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e3)), one),
1849        );
1850        vst1q_f32(
1851            out_ptr.add(index + 16),
1852            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e4)), one),
1853        );
1854        vst1q_f32(
1855            out_ptr.add(index + 20),
1856            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e5)), one),
1857        );
1858        vst1q_f32(
1859            out_ptr.add(index + 24),
1860            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e6)), one),
1861        );
1862        vst1q_f32(
1863            out_ptr.add(index + 28),
1864            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e7)), one),
1865        );
1866        index += 32;
1867    }
1868
1869    while index + 4 <= len {
1870        let x = vld1q_f32(in_ptr.add(index));
1871        let two_x = vmulq_f32(two, x);
1872        let neg_two_x = vnegq_f32(two_x);
1873        let exp_neg = fast_exp_sigmoid_neon(neg_two_x);
1874        let denom = vaddq_f32(one, exp_neg);
1875        let result = vsubq_f32(vdivq_f32(two, denom), one);
1876        vst1q_f32(out_ptr.add(index), result);
1877        index += 4;
1878    }
1879
1880    while index < len {
1881        *out_ptr.add(index) = (*in_ptr.add(index)).tanh();
1882        index += 1;
1883    }
1884}
1885
1886#[cfg(target_arch = "aarch64")]
1887#[allow(unsafe_code, dead_code)]
1888#[allow(unsafe_op_in_unsafe_fn)]
1889#[target_feature(enable = "neon")]
1890/// Fused SiLU: output[i] = x * sigmoid(x) in a single pass.
1891/// 8x unrolled with fast 3-term exp polynomial.
1892unsafe fn silu_slice_neon(input: &[f32], output: &mut [f32]) {
1893    let len = input.len();
1894    let in_ptr = input.as_ptr();
1895    let out_ptr = output.as_mut_ptr();
1896    let one = vdupq_n_f32(1.0);
1897    let mut index = 0usize;
1898
1899    // 8x unrolled: 32 elements per iteration
1900    while index + 32 <= len {
1901        let x0 = vld1q_f32(in_ptr.add(index));
1902        let x1 = vld1q_f32(in_ptr.add(index + 4));
1903        let x2 = vld1q_f32(in_ptr.add(index + 8));
1904        let x3 = vld1q_f32(in_ptr.add(index + 12));
1905        let x4 = vld1q_f32(in_ptr.add(index + 16));
1906        let x5 = vld1q_f32(in_ptr.add(index + 20));
1907        let x6 = vld1q_f32(in_ptr.add(index + 24));
1908        let x7 = vld1q_f32(in_ptr.add(index + 28));
1909
1910        // sigmoid(x) = 1 / (1 + exp(-x))
1911        let e0 = fast_exp_sigmoid_neon(vnegq_f32(x0));
1912        let e1 = fast_exp_sigmoid_neon(vnegq_f32(x1));
1913        let e2 = fast_exp_sigmoid_neon(vnegq_f32(x2));
1914        let e3 = fast_exp_sigmoid_neon(vnegq_f32(x3));
1915        let e4 = fast_exp_sigmoid_neon(vnegq_f32(x4));
1916        let e5 = fast_exp_sigmoid_neon(vnegq_f32(x5));
1917        let e6 = fast_exp_sigmoid_neon(vnegq_f32(x6));
1918        let e7 = fast_exp_sigmoid_neon(vnegq_f32(x7));
1919
1920        // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
1921        vst1q_f32(
1922            out_ptr.add(index),
1923            vmulq_f32(x0, vdivq_f32(one, vaddq_f32(one, e0))),
1924        );
1925        vst1q_f32(
1926            out_ptr.add(index + 4),
1927            vmulq_f32(x1, vdivq_f32(one, vaddq_f32(one, e1))),
1928        );
1929        vst1q_f32(
1930            out_ptr.add(index + 8),
1931            vmulq_f32(x2, vdivq_f32(one, vaddq_f32(one, e2))),
1932        );
1933        vst1q_f32(
1934            out_ptr.add(index + 12),
1935            vmulq_f32(x3, vdivq_f32(one, vaddq_f32(one, e3))),
1936        );
1937        vst1q_f32(
1938            out_ptr.add(index + 16),
1939            vmulq_f32(x4, vdivq_f32(one, vaddq_f32(one, e4))),
1940        );
1941        vst1q_f32(
1942            out_ptr.add(index + 20),
1943            vmulq_f32(x5, vdivq_f32(one, vaddq_f32(one, e5))),
1944        );
1945        vst1q_f32(
1946            out_ptr.add(index + 24),
1947            vmulq_f32(x6, vdivq_f32(one, vaddq_f32(one, e6))),
1948        );
1949        vst1q_f32(
1950            out_ptr.add(index + 28),
1951            vmulq_f32(x7, vdivq_f32(one, vaddq_f32(one, e7))),
1952        );
1953        index += 32;
1954    }
1955
1956    while index + 4 <= len {
1957        let x = vld1q_f32(in_ptr.add(index));
1958        let e = fast_exp_sigmoid_neon(vnegq_f32(x));
1959        let sig = vdivq_f32(one, vaddq_f32(one, e));
1960        vst1q_f32(out_ptr.add(index), vmulq_f32(x, sig));
1961        index += 4;
1962    }
1963
1964    while index < len {
1965        let x = *in_ptr.add(index);
1966        let s = 1.0 / (1.0 + (-x).exp());
1967        *out_ptr.add(index) = x * s;
1968        index += 1;
1969    }
1970}
1971
1972/// Fused SiLU (x * sigmoid(x)) using SSE.
1973#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1974#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
1975#[target_feature(enable = "sse")]
1976unsafe fn silu_slice_sse(input: &[f32], output: &mut [f32]) {
1977    #[cfg(target_arch = "x86")]
1978    use std::arch::x86::_mm_div_ps;
1979    #[cfg(target_arch = "x86_64")]
1980    use std::arch::x86_64::_mm_div_ps;
1981
1982    let len = input.len();
1983    let in_ptr = input.as_ptr();
1984    let out_ptr = output.as_mut_ptr();
1985    let one = _mm_set1_ps(1.0);
1986    let zero = _mm_setzero_ps();
1987    let mut index = 0usize;
1988
1989    while index + 16 <= len {
1990        let x0 = _mm_loadu_ps(in_ptr.add(index));
1991        let x1 = _mm_loadu_ps(in_ptr.add(index + 4));
1992        let x2 = _mm_loadu_ps(in_ptr.add(index + 8));
1993        let x3 = _mm_loadu_ps(in_ptr.add(index + 12));
1994
1995        // Use Schraudolph bit-trick exp for ~3x speedup
1996        let e0 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x0));
1997        let e1 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x1));
1998        let e2 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x2));
1999        let e3 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x3));
2000
2001        // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
2002        _mm_storeu_ps(
2003            out_ptr.add(index),
2004            _mm_mul_ps(x0, _mm_div_ps(one, _mm_add_ps(one, e0))),
2005        );
2006        _mm_storeu_ps(
2007            out_ptr.add(index + 4),
2008            _mm_mul_ps(x1, _mm_div_ps(one, _mm_add_ps(one, e1))),
2009        );
2010        _mm_storeu_ps(
2011            out_ptr.add(index + 8),
2012            _mm_mul_ps(x2, _mm_div_ps(one, _mm_add_ps(one, e2))),
2013        );
2014        _mm_storeu_ps(
2015            out_ptr.add(index + 12),
2016            _mm_mul_ps(x3, _mm_div_ps(one, _mm_add_ps(one, e3))),
2017        );
2018
2019        index += 16;
2020    }
2021
2022    while index + 4 <= len {
2023        let x = _mm_loadu_ps(in_ptr.add(index));
2024        let e = fast_exp_bittrick_sse(_mm_sub_ps(zero, x));
2025        let sig = _mm_div_ps(one, _mm_add_ps(one, e));
2026        _mm_storeu_ps(out_ptr.add(index), _mm_mul_ps(x, sig));
2027        index += 4;
2028    }
2029
2030    while index < len {
2031        let v = *in_ptr.add(index);
2032        let s = 1.0 / (1.0 + (-v).exp());
2033        *out_ptr.add(index) = v * s;
2034        index += 1;
2035    }
2036}
2037
2038/// Fused SiLU (x * sigmoid(x)) using AVX.
2039#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2040#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
2041#[target_feature(enable = "avx")]
2042unsafe fn silu_slice_avx(input: &[f32], output: &mut [f32]) {
2043    #[cfg(target_arch = "x86")]
2044    use std::arch::x86::_mm256_div_ps;
2045    #[cfg(target_arch = "x86_64")]
2046    use std::arch::x86_64::_mm256_div_ps;
2047
2048    let len = input.len();
2049    let in_ptr = input.as_ptr();
2050    let out_ptr = output.as_mut_ptr();
2051    let one = _mm256_set1_ps(1.0);
2052    let zero = _mm256_setzero_ps();
2053    let mut index = 0usize;
2054
2055    while index + 32 <= len {
2056        let x0 = _mm256_loadu_ps(in_ptr.add(index));
2057        let x1 = _mm256_loadu_ps(in_ptr.add(index + 8));
2058        let x2 = _mm256_loadu_ps(in_ptr.add(index + 16));
2059        let x3 = _mm256_loadu_ps(in_ptr.add(index + 24));
2060
2061        // Use Schraudolph bit-trick exp for ~3x speedup
2062        let e0 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x0));
2063        let e1 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x1));
2064        let e2 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x2));
2065        let e3 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x3));
2066
2067        // silu(x) = x / (1 + exp(-x))
2068        _mm256_storeu_ps(
2069            out_ptr.add(index),
2070            _mm256_mul_ps(x0, _mm256_div_ps(one, _mm256_add_ps(one, e0))),
2071        );
2072        _mm256_storeu_ps(
2073            out_ptr.add(index + 8),
2074            _mm256_mul_ps(x1, _mm256_div_ps(one, _mm256_add_ps(one, e1))),
2075        );
2076        _mm256_storeu_ps(
2077            out_ptr.add(index + 16),
2078            _mm256_mul_ps(x2, _mm256_div_ps(one, _mm256_add_ps(one, e2))),
2079        );
2080        _mm256_storeu_ps(
2081            out_ptr.add(index + 24),
2082            _mm256_mul_ps(x3, _mm256_div_ps(one, _mm256_add_ps(one, e3))),
2083        );
2084
2085        index += 32;
2086    }
2087
2088    while index + 8 <= len {
2089        let x = _mm256_loadu_ps(in_ptr.add(index));
2090        let e = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x));
2091        let sig = _mm256_div_ps(one, _mm256_add_ps(one, e));
2092        _mm256_storeu_ps(out_ptr.add(index), _mm256_mul_ps(x, sig));
2093        index += 8;
2094    }
2095
2096    if index < len {
2097        silu_slice_sse(&input[index..], &mut output[index..]);
2098    }
2099}
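// Scalar reference for the SiLU kernels above. Note an accuracy wrinkle: the
// vector lanes use the approximate bit-trick exp while the remainder loops use
// the exact `exp`, so the final few elements can differ slightly from what the
// vector path would have produced.
#[allow(dead_code)]
fn silu_reference(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}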
2100
2101// ===========================================================================
2102// Max-reduce implementations
2103// ===========================================================================
2104
2105#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2106#[allow(unsafe_code)]
2107#[allow(unsafe_op_in_unsafe_fn)]
2108#[target_feature(enable = "sse")]
2109unsafe fn max_reduce_sse(data: &[f32]) -> f32 {
2110    let len = data.len();
2111    let ptr = data.as_ptr();
2112    let mut index = 0usize;
2113    let mut acc = _mm_set1_ps(f32::NEG_INFINITY);
2114
2115    while index + 4 <= len {
2116        let v = _mm_loadu_ps(ptr.add(index));
2117        acc = _mm_max_ps(acc, v);
2118        index += 4;
2119    }
2120
2121    // Horizontal max of 4-lane accumulator
2122    let mut buf = [0.0f32; 4];
2123    _mm_storeu_ps(buf.as_mut_ptr(), acc);
2124    let mut result = buf[0].max(buf[1]).max(buf[2]).max(buf[3]);
2125
2126    while index < len {
2127        result = result.max(*ptr.add(index));
2128        index += 1;
2129    }
2130    result
2131}
2132
2133#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2134#[allow(unsafe_code)]
2135#[allow(unsafe_op_in_unsafe_fn)]
2136#[target_feature(enable = "avx")]
2137unsafe fn max_reduce_avx(data: &[f32]) -> f32 {
2138    let len = data.len();
2139    let ptr = data.as_ptr();
2140    let mut index = 0usize;
2141    let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);
2142
2143    while index + 8 <= len {
2144        let v = _mm256_loadu_ps(ptr.add(index));
2145        acc = _mm256_max_ps(acc, v);
2146        index += 8;
2147    }
2148
2149    // Horizontal max of 8-lane accumulator
2150    let mut buf = [0.0f32; 8];
2151    _mm256_storeu_ps(buf.as_mut_ptr(), acc);
2152    let mut result = buf[0];
2153    for i in 1..8 {
2154        result = result.max(buf[i]);
2155    }
2156
2157    while index < len {
2158        result = result.max(*ptr.add(index));
2159        index += 1;
2160    }
2161    result
2162}
2163
2164#[cfg(target_arch = "aarch64")]
2165#[allow(unsafe_code, dead_code)]
2166#[allow(unsafe_op_in_unsafe_fn)]
2167#[target_feature(enable = "neon")]
2168unsafe fn max_reduce_neon(data: &[f32]) -> f32 {
2169    use std::arch::aarch64::vmaxvq_f32;
2170
2171    let len = data.len();
2172    let ptr = data.as_ptr();
2173    let mut index = 0usize;
2174    let mut acc = vdupq_n_f32(f32::NEG_INFINITY);
2175
2176    while index + 4 <= len {
2177        let v = vld1q_f32(ptr.add(index));
2178        acc = vmaxq_f32(acc, v);
2179        index += 4;
2180    }
2181
2182    let mut result = vmaxvq_f32(acc);
2183    while index < len {
2184        result = result.max(*ptr.add(index));
2185        index += 1;
2186    }
2187    result
2188}
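// Sketch of a correctness check for the reductions: f32 max (unlike addition)
// is associative and commutative for non-NaN inputs, so the SIMD result should
// match a scalar fold exactly. 37 elements exercises both the vector loop and
// the scalar tail.
#[cfg(all(test, target_arch = "aarch64"))]
#[test]
#[allow(unsafe_code)]
fn max_reduce_matches_scalar_fold() {
    if !std::arch::is_aarch64_feature_detected!("neon") {
        return;
    }
    let data: Vec<f32> = (0..37).map(|i| ((i * 37) % 17) as f32 - 8.0).collect();
    let expected = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // SAFETY: guarded by runtime feature detection.
    assert_eq!(unsafe { max_reduce_neon(&data) }, expected);
}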
2189
2190// ===========================================================================
2191// Add-reduce (sum) implementations
2192// ===========================================================================
2193
2194#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2195#[allow(unsafe_code)]
2196#[allow(unsafe_op_in_unsafe_fn)]
2197#[target_feature(enable = "sse")]
2198unsafe fn add_reduce_sse(data: &[f32]) -> f32 {
2199    let len = data.len();
2200    let ptr = data.as_ptr();
2201    let mut index = 0usize;
2202    let mut acc = _mm_setzero_ps();
2203
2204    while index + 4 <= len {
2205        let v = _mm_loadu_ps(ptr.add(index));
2206        acc = _mm_add_ps(acc, v);
2207        index += 4;
2208    }
2209
2210    // Horizontal sum of 4-lane accumulator
2211    let mut buf = [0.0f32; 4];
2212    _mm_storeu_ps(buf.as_mut_ptr(), acc);
2213    let mut result = buf[0] + buf[1] + buf[2] + buf[3];
2214
2215    while index < len {
2216        result += *ptr.add(index);
2217        index += 1;
2218    }
2219    result
2220}
2221
2222#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2223#[allow(unsafe_code)]
2224#[allow(unsafe_op_in_unsafe_fn)]
2225#[target_feature(enable = "avx")]
2226unsafe fn add_reduce_avx(data: &[f32]) -> f32 {
2227    let len = data.len();
2228    let ptr = data.as_ptr();
2229    let mut index = 0usize;
2230    let mut acc = _mm256_setzero_ps();
2231
2232    while index + 8 <= len {
2233        let v = _mm256_loadu_ps(ptr.add(index));
2234        acc = _mm256_add_ps(acc, v);
2235        index += 8;
2236    }
2237
2238    let mut buf = [0.0f32; 8];
2239    _mm256_storeu_ps(buf.as_mut_ptr(), acc);
2240    let mut result = buf[0] + buf[1] + buf[2] + buf[3] + buf[4] + buf[5] + buf[6] + buf[7];
2241
2242    while index < len {
2243        result += *ptr.add(index);
2244        index += 1;
2245    }
2246    result
2247}
2248
2249#[cfg(target_arch = "aarch64")]
2250#[allow(unsafe_code, dead_code)]
2251#[allow(unsafe_op_in_unsafe_fn)]
2252#[target_feature(enable = "neon")]
2253unsafe fn add_reduce_neon(data: &[f32]) -> f32 {
2254    use std::arch::aarch64::vaddvq_f32;
2255
2256    let len = data.len();
2257    let ptr = data.as_ptr();
2258    let mut index = 0usize;
2259    let mut acc = vdupq_n_f32(0.0);
2260
2261    while index + 4 <= len {
2262        let v = vld1q_f32(ptr.add(index));
2263        acc = vaddq_f32(acc, v);
2264        index += 4;
2265    }
2266
2267    let mut result = vaddvq_f32(acc);
2268    while index < len {
2269        result += *ptr.add(index);
2270        index += 1;
2271    }
2272    result
2273}
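// Note: the vector paths above add in a different order than a sequential
// scalar loop, so the result can differ from `iter().sum()` in the last few
// bits (f32 addition is not associative). The single accumulator also
// serializes the adds; splitting into 2-4 accumulators would hide add latency
// if these reductions ever dominate a profile.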
2274
2275// ===========================================================================
2276// FMA slice implementations: acc[i] += a[i] * b[i]
2277// ===========================================================================
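// Note: the x86 paths below use a separate multiply and add, since FMA3 is a
// distinct CPU feature from SSE/AVX; only the NEON path (`vfmaq_f32`) issues a
// true fused multiply-add, so results may differ in the last bit across
// architectures.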
2278
2279#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2280#[allow(unsafe_code)]
2281#[allow(unsafe_op_in_unsafe_fn)]
2282#[target_feature(enable = "sse")]
2283unsafe fn fma_slice_sse(a: &[f32], b: &[f32], acc: &mut [f32]) {
2284    let len = a.len();
2285    let a_ptr = a.as_ptr();
2286    let b_ptr = b.as_ptr();
2287    let acc_ptr = acc.as_mut_ptr();
2288    let mut index = 0usize;
2289
2290    while index + 4 <= len {
2291        let av = _mm_loadu_ps(a_ptr.add(index));
2292        let bv = _mm_loadu_ps(b_ptr.add(index));
2293        let cv = _mm_loadu_ps(acc_ptr.add(index));
2294        let result = _mm_add_ps(cv, _mm_mul_ps(av, bv));
2295        _mm_storeu_ps(acc_ptr.add(index), result);
2296        index += 4;
2297    }
2298
2299    if index < len {
2300        fma_slice_scalar(&a[index..], &b[index..], &mut acc[index..]);
2301    }
2302}
2303
2304#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2305#[allow(unsafe_code)]
2306#[allow(unsafe_op_in_unsafe_fn)]
2307#[target_feature(enable = "avx")]
2308unsafe fn fma_slice_avx(a: &[f32], b: &[f32], acc: &mut [f32]) {
2309    let len = a.len();
2310    let a_ptr = a.as_ptr();
2311    let b_ptr = b.as_ptr();
2312    let acc_ptr = acc.as_mut_ptr();
2313    let mut index = 0usize;
2314
2315    while index + 8 <= len {
2316        let av = _mm256_loadu_ps(a_ptr.add(index));
2317        let bv = _mm256_loadu_ps(b_ptr.add(index));
2318        let cv = _mm256_loadu_ps(acc_ptr.add(index));
2319        let result = _mm256_add_ps(cv, _mm256_mul_ps(av, bv));
2320        _mm256_storeu_ps(acc_ptr.add(index), result);
2321        index += 8;
2322    }
2323
2324    if index < len {
2325        fma_slice_sse(&a[index..], &b[index..], &mut acc[index..]);
2326    }
2327}
2328
2329#[cfg(target_arch = "aarch64")]
2330#[allow(unsafe_code, dead_code)]
2331#[allow(unsafe_op_in_unsafe_fn)]
2332#[target_feature(enable = "neon")]
2333unsafe fn fma_slice_neon(a: &[f32], b: &[f32], acc: &mut [f32]) {
2334    let len = a.len();
2335    let a_ptr = a.as_ptr();
2336    let b_ptr = b.as_ptr();
2337    let acc_ptr = acc.as_mut_ptr();
2338    let mut index = 0usize;
2339
2340    while index + 4 <= len {
2341        let av = vld1q_f32(a_ptr.add(index));
2342        let bv = vld1q_f32(b_ptr.add(index));
2343        let cv = vld1q_f32(acc_ptr.add(index));
2344        let result = vfmaq_f32(cv, av, bv);
2345        vst1q_f32(acc_ptr.add(index), result);
2346        index += 4;
2347    }
2348
2349    if index < len {
2350        fma_slice_scalar(&a[index..], &b[index..], &mut acc[index..]);
2351    }
2352}
2353
2354// ===========================================================================
2355// ReLU SIMD implementations (existing)
2356// ===========================================================================
2357
2358#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2359#[allow(unsafe_code)]
2360#[allow(unsafe_op_in_unsafe_fn)]
2361#[target_feature(enable = "sse")]
2362unsafe fn relu_slice_sse(values: &mut [f32]) {
2363    let len = values.len();
2364    let ptr = values.as_mut_ptr();
2365    let zero = _mm_setzero_ps();
2366    let mut index = 0usize;
2367
2368    while index + 4 <= len {
2369        let input = _mm_loadu_ps(ptr.add(index));
2370        let out = _mm_max_ps(input, zero);
2371        _mm_storeu_ps(ptr.add(index), out);
2372        index += 4;
2373    }
2374
2375    if index < len {
2376        relu_slice_scalar(&mut values[index..]);
2377    }
2378}
2379
2380#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2381#[allow(unsafe_code)]
2382#[allow(unsafe_op_in_unsafe_fn)]
2383#[target_feature(enable = "avx")]
2384unsafe fn relu_slice_avx(values: &mut [f32]) {
2385    let len = values.len();
2386    let ptr = values.as_mut_ptr();
2387    let zero = _mm256_setzero_ps();
2388    let mut index = 0usize;
2389
2390    // 4× unrolled: 32 elements per iteration
2391    while index + 32 <= len {
2392        let v0 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index)), zero);
2393        let v1 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 8)), zero);
2394        let v2 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 16)), zero);
2395        let v3 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 24)), zero);
2396        _mm256_storeu_ps(ptr.add(index), v0);
2397        _mm256_storeu_ps(ptr.add(index + 8), v1);
2398        _mm256_storeu_ps(ptr.add(index + 16), v2);
2399        _mm256_storeu_ps(ptr.add(index + 24), v3);
2400        index += 32;
2401    }
2402
2403    while index + 8 <= len {
2404        _mm256_storeu_ps(
2405            ptr.add(index),
2406            _mm256_max_ps(_mm256_loadu_ps(ptr.add(index)), zero),
2407        );
2408        index += 8;
2409    }
2410
2411    if index < len {
2412        relu_slice_sse(&mut values[index..]);
2413    }
2414}
2415
2416#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2417#[allow(unsafe_code)]
2418#[allow(unsafe_op_in_unsafe_fn)]
2419#[target_feature(enable = "sse")]
2420unsafe fn binary_same_shape_sse(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
2421    let len = lhs.len();
2422    let left_ptr = lhs.as_ptr();
2423    let right_ptr = rhs.as_ptr();
2424    let out_ptr = out.as_mut_ptr();
2425    let mut index = 0usize;
2426
2427    while index + 4 <= len {
2428        let left = _mm_loadu_ps(left_ptr.add(index));
2429        let right = _mm_loadu_ps(right_ptr.add(index));
2430        let result = match kind {
2431            BinaryKind::Add => _mm_add_ps(left, right),
2432            BinaryKind::Sub => _mm_sub_ps(left, right),
2433            BinaryKind::Mul => _mm_mul_ps(left, right),
2434        };
2435        _mm_storeu_ps(out_ptr.add(index), result);
2436        index += 4;
2437    }
2438
2439    if index < len {
2440        binary_same_shape_scalar(&lhs[index..], &rhs[index..], &mut out[index..], kind);
2441    }
2442}
2443
2444#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2445#[allow(unsafe_code)]
2446#[allow(unsafe_op_in_unsafe_fn)]
2447#[target_feature(enable = "avx")]
2448unsafe fn binary_same_shape_avx(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
2449    let len = lhs.len();
2450    let left_ptr = lhs.as_ptr();
2451    let right_ptr = rhs.as_ptr();
2452    let out_ptr = out.as_mut_ptr();
2453    let mut index = 0usize;
2454
2455    // 4x unrolled with software prefetch: 32 floats per iteration, `kind` match
2456    // hoisted out of the hot loop. Keeps the OoO pipeline saturated at vDSP-class throughput.
2457    match kind {
2458        BinaryKind::Add => {
2459            while index + 32 <= len {
2460                #[cfg(target_arch = "x86")]
2461                {
2462                    use std::arch::x86::_mm_prefetch;
2463                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
2464                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
2465                }
2466                #[cfg(target_arch = "x86_64")]
2467                {
2468                    use std::arch::x86_64::_mm_prefetch;
2469                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
2470                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
2471                }
2472                let a0 = _mm256_loadu_ps(left_ptr.add(index));
2473                let b0 = _mm256_loadu_ps(right_ptr.add(index));
2474                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
2475                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
2476                _mm256_storeu_ps(out_ptr.add(index), _mm256_add_ps(a0, b0));
2477                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_add_ps(a1, b1));
2478                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
2479                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
2480                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
2481                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
2482                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_add_ps(a2, b2));
2483                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_add_ps(a3, b3));
2484                index += 32;
2485            }
2486        }
2487        BinaryKind::Sub => {
2488            while index + 32 <= len {
2489                #[cfg(target_arch = "x86")]
2490                {
2491                    use std::arch::x86::_mm_prefetch;
2492                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
2493                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
2494                }
2495                #[cfg(target_arch = "x86_64")]
2496                {
2497                    use std::arch::x86_64::_mm_prefetch;
2498                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
2499                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
2500                }
2501                let a0 = _mm256_loadu_ps(left_ptr.add(index));
2502                let b0 = _mm256_loadu_ps(right_ptr.add(index));
2503                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
2504                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
2505                _mm256_storeu_ps(out_ptr.add(index), _mm256_sub_ps(a0, b0));
2506                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_sub_ps(a1, b1));
2507                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
2508                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
2509                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
2510                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
2511                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_sub_ps(a2, b2));
2512                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_sub_ps(a3, b3));
2513                index += 32;
2514            }
2515        }
2516        BinaryKind::Mul => {
2517            while index + 32 <= len {
2518                #[cfg(target_arch = "x86")]
2519                {
2520                    use std::arch::x86::_mm_prefetch;
2521                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
2522                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
2523                }
2524                #[cfg(target_arch = "x86_64")]
2525                {
2526                    use std::arch::x86_64::_mm_prefetch;
2527                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
2528                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
2529                }
2530                let a0 = _mm256_loadu_ps(left_ptr.add(index));
2531                let b0 = _mm256_loadu_ps(right_ptr.add(index));
2532                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
2533                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
2534                _mm256_storeu_ps(out_ptr.add(index), _mm256_mul_ps(a0, b0));
2535                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_mul_ps(a1, b1));
2536                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
2537                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
2538                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
2539                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
2540                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_mul_ps(a2, b2));
2541                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_mul_ps(a3, b3));
2542                index += 32;
2543            }
2544        }
2545    }
2546
2547    // Handle remaining elements 8 at a time
2548    while index + 8 <= len {
2549        let left = _mm256_loadu_ps(left_ptr.add(index));
2550        let right = _mm256_loadu_ps(right_ptr.add(index));
2551        let result = match kind {
2552            BinaryKind::Add => _mm256_add_ps(left, right),
2553            BinaryKind::Sub => _mm256_sub_ps(left, right),
2554            BinaryKind::Mul => _mm256_mul_ps(left, right),
2555        };
2556        _mm256_storeu_ps(out_ptr.add(index), result);
2557        index += 8;
2558    }
2559
2560    if index < len {
2561        binary_same_shape_sse(&lhs[index..], &rhs[index..], &mut out[index..], kind);
2562    }
2563}
2564
2565#[cfg(target_arch = "aarch64")]
2566#[allow(unsafe_code)]
2567#[allow(unsafe_op_in_unsafe_fn)]
2568#[target_feature(enable = "neon")]
2569unsafe fn relu_slice_neon(values: &mut [f32]) {
2570    let len = values.len();
2571    let ptr = values.as_mut_ptr();
2572    let zero = vdupq_n_f32(0.0);
2573    let mut index = 0usize;
2574
2575    // 8× unrolled: 32 elements per iteration
2576    while index + 32 <= len {
2577        let v0 = vmaxq_f32(vld1q_f32(ptr.add(index)), zero);
2578        let v1 = vmaxq_f32(vld1q_f32(ptr.add(index + 4)), zero);
2579        let v2 = vmaxq_f32(vld1q_f32(ptr.add(index + 8)), zero);
2580        let v3 = vmaxq_f32(vld1q_f32(ptr.add(index + 12)), zero);
2581        let v4 = vmaxq_f32(vld1q_f32(ptr.add(index + 16)), zero);
2582        let v5 = vmaxq_f32(vld1q_f32(ptr.add(index + 20)), zero);
2583        let v6 = vmaxq_f32(vld1q_f32(ptr.add(index + 24)), zero);
2584        let v7 = vmaxq_f32(vld1q_f32(ptr.add(index + 28)), zero);
2585        vst1q_f32(ptr.add(index), v0);
2586        vst1q_f32(ptr.add(index + 4), v1);
2587        vst1q_f32(ptr.add(index + 8), v2);
2588        vst1q_f32(ptr.add(index + 12), v3);
2589        vst1q_f32(ptr.add(index + 16), v4);
2590        vst1q_f32(ptr.add(index + 20), v5);
2591        vst1q_f32(ptr.add(index + 24), v6);
2592        vst1q_f32(ptr.add(index + 28), v7);
2593        index += 32;
2594    }
2595
2596    while index + 4 <= len {
2597        vst1q_f32(ptr.add(index), vmaxq_f32(vld1q_f32(ptr.add(index)), zero));
2598        index += 4;
2599    }
2600
2601    if index < len {
2602        relu_slice_scalar(&mut values[index..]);
2603    }
2604}
2605
2606// ===========================================================================
2607// Two-argument ReLU SIMD implementations (input -> output)
2608// ===========================================================================
2609
2610#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2611#[allow(unsafe_code)]
2612#[allow(unsafe_op_in_unsafe_fn)]
2613#[target_feature(enable = "sse")]
2614unsafe fn relu_to_slice_sse(input: &[f32], output: &mut [f32]) {
2615    let len = input.len();
2616    let in_ptr = input.as_ptr();
2617    let out_ptr = output.as_mut_ptr();
2618    let zero = _mm_setzero_ps();
2619    let mut index = 0usize;
2620
2621    while index + 4 <= len {
2622        let v = _mm_loadu_ps(in_ptr.add(index));
2623        let r = _mm_max_ps(v, zero);
2624        _mm_storeu_ps(out_ptr.add(index), r);
2625        index += 4;
2626    }
2627
2628    if index < len {
2629        relu_to_slice_scalar(&input[index..], &mut output[index..]);
2630    }
2631}
2632
2633#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2634#[allow(unsafe_code)]
2635#[allow(unsafe_op_in_unsafe_fn)]
2636#[target_feature(enable = "avx")]
2637unsafe fn relu_to_slice_avx(input: &[f32], output: &mut [f32]) {
2638    let len = input.len();
2639    let in_ptr = input.as_ptr();
2640    let out_ptr = output.as_mut_ptr();
2641    let zero = _mm256_setzero_ps();
2642    let mut index = 0usize;
2643
2644    // 4× unrolled: 32 elements per iteration (matches NEON unrolling)
2645    while index + 32 <= len {
2646        let a0 = _mm256_loadu_ps(in_ptr.add(index));
2647        let a1 = _mm256_loadu_ps(in_ptr.add(index + 8));
2648        let a2 = _mm256_loadu_ps(in_ptr.add(index + 16));
2649        let a3 = _mm256_loadu_ps(in_ptr.add(index + 24));
2650        _mm256_storeu_ps(out_ptr.add(index), _mm256_max_ps(a0, zero));
2651        _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_max_ps(a1, zero));
2652        _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_max_ps(a2, zero));
2653        _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_max_ps(a3, zero));
2654        index += 32;
2655    }
2656
2657    while index + 8 <= len {
2658        _mm256_storeu_ps(
2659            out_ptr.add(index),
2660            _mm256_max_ps(_mm256_loadu_ps(in_ptr.add(index)), zero),
2661        );
2662        index += 8;
2663    }
2664
2665    if index < len {
2666        relu_to_slice_sse(&input[index..], &mut output[index..]);
2667    }
2668}
2669
2670#[cfg(target_arch = "aarch64")]
2671#[allow(unsafe_code)]
2672#[allow(unsafe_op_in_unsafe_fn)]
2673#[target_feature(enable = "neon")]
2674unsafe fn relu_to_slice_neon(input: &[f32], output: &mut [f32]) {
2675    let len = input.len();
2676    let in_ptr = input.as_ptr();
2677    let out_ptr = output.as_mut_ptr();
2678    let zero = vdupq_n_f32(0.0);
2679    let mut index = 0usize;
2680
2681    // 8× unrolled with interleaved load/compute/store for better OoO pipelining
2682    while index + 32 <= len {
2683        let a0 = vld1q_f32(in_ptr.add(index));
2684        let a1 = vld1q_f32(in_ptr.add(index + 4));
2685        let a2 = vld1q_f32(in_ptr.add(index + 8));
2686        let a3 = vld1q_f32(in_ptr.add(index + 12));
2687        vst1q_f32(out_ptr.add(index), vmaxq_f32(a0, zero));
2688        vst1q_f32(out_ptr.add(index + 4), vmaxq_f32(a1, zero));
2689        let a4 = vld1q_f32(in_ptr.add(index + 16));
2690        let a5 = vld1q_f32(in_ptr.add(index + 20));
2691        vst1q_f32(out_ptr.add(index + 8), vmaxq_f32(a2, zero));
2692        vst1q_f32(out_ptr.add(index + 12), vmaxq_f32(a3, zero));
2693        let a6 = vld1q_f32(in_ptr.add(index + 24));
2694        let a7 = vld1q_f32(in_ptr.add(index + 28));
2695        vst1q_f32(out_ptr.add(index + 16), vmaxq_f32(a4, zero));
2696        vst1q_f32(out_ptr.add(index + 20), vmaxq_f32(a5, zero));
2697        vst1q_f32(out_ptr.add(index + 24), vmaxq_f32(a6, zero));
2698        vst1q_f32(out_ptr.add(index + 28), vmaxq_f32(a7, zero));
2699        index += 32;
2700    }
2701
2702    while index + 4 <= len {
2703        vst1q_f32(
2704            out_ptr.add(index),
2705            vmaxq_f32(vld1q_f32(in_ptr.add(index)), zero),
2706        );
2707        index += 4;
2708    }
2709
2710    if index < len {
2711        relu_to_slice_scalar(&input[index..], &mut output[index..]);
2712    }
2713}
2714
2715#[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
2716#[allow(unsafe_code)]
2717#[allow(unsafe_op_in_unsafe_fn)]
2718#[target_feature(enable = "neon")]
2719unsafe fn binary_same_shape_neon(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
2720    let len = lhs.len();
2721    let left_ptr = lhs.as_ptr();
2722    let right_ptr = rhs.as_ptr();
2723    let out_ptr = out.as_mut_ptr();
2724    let mut index = 0usize;
2725
2726    while index + 4 <= len {
2727        let left = vld1q_f32(left_ptr.add(index));
2728        let right = vld1q_f32(right_ptr.add(index));
2729        let result = match kind {
2730            BinaryKind::Add => vaddq_f32(left, right),
2731            BinaryKind::Sub => vsubq_f32(left, right),
2732            BinaryKind::Mul => vmulq_f32(left, right),
2733        };
2734        vst1q_f32(out_ptr.add(index), result);
2735        index += 4;
2736    }
2737
2738    if index < len {
2739        binary_same_shape_scalar(&lhs[index..], &rhs[index..], &mut out[index..], kind);
2740    }
2741}
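// Note: `binary_same_shape_neon` above is gated on `not(target_os = "macos")`,
// presumably because on macOS these same-shape binary ops route to the vDSP
// bindings declared at the top of this file instead.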
2742
2743// ---------------------------------------------------------------------------
2744// SIMD-accelerated matmul inner loop
2745// ---------------------------------------------------------------------------
2746//
2747// Computes one output row of C = A * B by iterating over the shared dimension k.
2748// For each k, broadcasts a[row*K + k] and multiplies by the contiguous B row
2749// b[k*N .. k*N + N], accumulating into the output row out[0..N].
2750//
2751// The "broadcast A, contiguous B row" access pattern is SIMD-friendly: every load
2752// from B and store to the output row is unit-stride, and no horizontal reduction is needed.
2753
2754/// Dispatch to the best available SIMD path for a single matmul output row.
2755///
2756/// # Safety
2757/// - `left_row` must point to at least `k` valid f32 elements.
2758/// - `right` must point to at least `k * n` valid f32 elements (row-major B).
2759/// - `out_row` must point to at least `n` valid f32 elements.
2760/// - The caller must ensure no aliasing between `out_row` and the input pointers.
2761#[inline]
2762#[allow(unsafe_code)]
2763#[allow(unsafe_op_in_unsafe_fn)]
2764pub unsafe fn matmul_row_dispatch(
2765    left_row: *const f32,
2766    right: *const f32,
2767    out_row: *mut f32,
2768    k: usize,
2769    n: usize,
2770) {
2771    if cfg!(miri) {
2772        matmul_row_scalar(left_row, right, out_row, k, n);
2773        return;
2774    }
2775
2776    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2777    {
2778        if std::is_x86_feature_detected!("avx") {
2779            matmul_row_avx(left_row, right, out_row, k, n);
2780            return;
2781        }
2782        if std::is_x86_feature_detected!("sse") {
2783            matmul_row_sse(left_row, right, out_row, k, n);
2784            return;
2785        }
2786    }
2787
2788    #[cfg(target_arch = "aarch64")]
2789    {
2790        if std::arch::is_aarch64_feature_detected!("neon") {
2791            matmul_row_neon(left_row, right, out_row, k, n);
2792            return;
2793        }
2794    }
2795
2796    matmul_row_scalar(left_row, right, out_row, k, n);
2797}
2798
2799/// Scalar fallback: broadcast-multiply-accumulate, unrolled by 4.
2800#[allow(unsafe_code)]
2801#[allow(unsafe_op_in_unsafe_fn)]
2802unsafe fn matmul_row_scalar(
2803    left_row: *const f32,
2804    right: *const f32,
2805    out_row: *mut f32,
2806    k: usize,
2807    n: usize,
2808) {
2809    for p in 0..k {
2810        let a_val = *left_row.add(p);
2811        let b_row = right.add(p * n);
2812
2813        let mut col = 0usize;
2814        while col + 4 <= n {
2815            *out_row.add(col) += a_val * *b_row.add(col);
2816            *out_row.add(col + 1) += a_val * *b_row.add(col + 1);
2817            *out_row.add(col + 2) += a_val * *b_row.add(col + 2);
2818            *out_row.add(col + 3) += a_val * *b_row.add(col + 3);
2819            col += 4;
2820        }
2821        while col < n {
2822            *out_row.add(col) += a_val * *b_row.add(col);
2823            col += 1;
2824        }
2825    }
2826}
2827
2828#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2829#[allow(unsafe_code)]
2830#[allow(unsafe_op_in_unsafe_fn)]
2831#[target_feature(enable = "sse")]
2832unsafe fn matmul_row_sse(
2833    left_row: *const f32,
2834    right: *const f32,
2835    out_row: *mut f32,
2836    k: usize,
2837    n: usize,
2838) {
2839    for p in 0..k {
2840        let a_val = _mm_set1_ps(*left_row.add(p));
2841        let b_row = right.add(p * n);
2842
2843        let mut col = 0usize;
2844        while col + 4 <= n {
2845            let b_vec = _mm_loadu_ps(b_row.add(col));
2846            let out_vec = _mm_loadu_ps(out_row.add(col));
2847            let result = _mm_add_ps(out_vec, _mm_mul_ps(a_val, b_vec));
2848            _mm_storeu_ps(out_row.add(col), result);
2849            col += 4;
2850        }
2851        while col < n {
2852            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
2853            col += 1;
2854        }
2855    }
2856}
2857
2858#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2859#[allow(unsafe_code)]
2860#[allow(unsafe_op_in_unsafe_fn)]
2861#[target_feature(enable = "avx")]
2862unsafe fn matmul_row_avx(
2863    left_row: *const f32,
2864    right: *const f32,
2865    out_row: *mut f32,
2866    k: usize,
2867    n: usize,
2868) {
2869    for p in 0..k {
2870        let a_val_avx = _mm256_set1_ps(*left_row.add(p));
2871        let a_val_sse = _mm_set1_ps(*left_row.add(p));
2872        let b_row = right.add(p * n);
2873
2874        let mut col = 0usize;
2875        while col + 8 <= n {
2876            let b_vec = _mm256_loadu_ps(b_row.add(col));
2877            let out_vec = _mm256_loadu_ps(out_row.add(col));
2878            let result = _mm256_add_ps(out_vec, _mm256_mul_ps(a_val_avx, b_vec));
2879            _mm256_storeu_ps(out_row.add(col), result);
2880            col += 8;
2881        }
2882        // Handle 4-element remainder with SSE.
2883        while col + 4 <= n {
2884            let b_vec = _mm_loadu_ps(b_row.add(col));
2885            let out_vec = _mm_loadu_ps(out_row.add(col));
2886            let result = _mm_add_ps(out_vec, _mm_mul_ps(a_val_sse, b_vec));
2887            _mm_storeu_ps(out_row.add(col), result);
2888            col += 4;
2889        }
2890        while col < n {
2891            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
2892            col += 1;
2893        }
2894    }
2895}
2896
2897#[cfg(target_arch = "aarch64")]
2898#[allow(unsafe_code)]
2899#[allow(unsafe_op_in_unsafe_fn)]
2900#[target_feature(enable = "neon")]
2901unsafe fn matmul_row_neon(
2902    left_row: *const f32,
2903    right: *const f32,
2904    out_row: *mut f32,
2905    k: usize,
2906    n: usize,
2907) {
2908    for p in 0..k {
2909        let a_val: float32x4_t = vdupq_n_f32(*left_row.add(p));
2910        let b_row = right.add(p * n);
2911
2912        let mut col = 0usize;
2913        while col + 4 <= n {
2914            let b_vec = vld1q_f32(b_row.add(col));
2915            let out_vec = vld1q_f32(out_row.add(col));
2916            let result = vfmaq_f32(out_vec, a_val, b_vec);
2917            vst1q_f32(out_row.add(col), result);
2918            col += 4;
2919        }
2920        while col < n {
2921            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
2922            col += 1;
2923        }
2924    }
2925}
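// Usage sketch for the row kernel (illustrative; `matmul_naive` is not part of
// the crate's API): a full (m, k) x (k, n) row-major matmul invokes the
// dispatcher once per output row. `c` must be zeroed first because the kernel
// accumulates into it.
#[allow(dead_code, unsafe_code)]
fn matmul_naive(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    assert_eq!(c.len(), m * n);
    c.fill(0.0);
    for row in 0..m {
        // SAFETY: the slice asserts above guarantee each pointer covers the
        // extent required by `matmul_row_dispatch`'s contract, and the output
        // row does not alias `a` or `b`.
        unsafe {
            matmul_row_dispatch(
                a.as_ptr().add(row * k),
                b.as_ptr(),
                c.as_mut_ptr().add(row * n),
                k,
                n,
            );
        }
    }
}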
2926
2927// ===========================================================================
2928// Fused softmax: max + sub-exp + sum + divide in one function
2929// ===========================================================================
2930
2931/// Fused softmax row: `out[i] = exp(input[i] - max) / sum(exp(input - max))`.
2932///
2933/// Performs all four steps (max, subtract+exp, sum, divide) inside a single
2934/// function so that data stays in L1 cache and dispatcher overhead is eliminated.
2935#[allow(unsafe_code)]
2936#[inline]
2937pub fn softmax_row_fused_dispatch(input: &[f32], output: &mut [f32]) {
2938    debug_assert_eq!(input.len(), output.len());
2939
2940    if cfg!(miri) || input.is_empty() {
2941        softmax_row_fused_scalar(input, output);
2942        return;
2943    }
2944
2945    #[cfg(target_arch = "aarch64")]
2946    {
2947        if std::arch::is_aarch64_feature_detected!("neon") {
2948            // SAFETY: guarded by runtime feature detection.
2949            unsafe {
2950                softmax_row_fused_neon(input, output);
2951            }
2952            return;
2953        }
2954    }
2955
2956    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2957    {
2958        if std::is_x86_feature_detected!("avx") {
2959            // SAFETY: guarded by runtime feature detection.
2960            unsafe {
2961                softmax_row_fused_avx(input, output);
2962            }
2963            return;
2964        }
2965        if std::is_x86_feature_detected!("sse") {
2966            // SAFETY: guarded by runtime feature detection.
2967            unsafe {
2968                softmax_row_fused_sse(input, output);
2969            }
2970            return;
2971        }
2972    }
2973
2974    softmax_row_fused_scalar(input, output);
2975}

fn softmax_row_fused_scalar(input: &[f32], output: &mut [f32]) {
    if input.is_empty() {
        return;
    }

    // 1. max
    let mut max_val = f32::NEG_INFINITY;
    for &v in input {
        max_val = max_val.max(v);
    }

    // 2. sub+exp + 3. accumulate sum
    let mut sum_exp = 0.0f32;
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        let e = (v - max_val).exp();
        *o = e;
        sum_exp += e;
    }

    // 4. divide
    let inv = 1.0 / sum_exp;
    for o in output.iter_mut() {
        *o *= inv;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn softmax_row_fused_neon(input: &[f32], output: &mut [f32]) {
    use std::arch::aarch64::{vaddvq_f32, vmaxvq_f32};

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. Find max (NEON reduce)
    let mut acc_max = vdupq_n_f32(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 16 <= len {
        let v0 = vld1q_f32(in_ptr.add(i));
        let v1 = vld1q_f32(in_ptr.add(i + 4));
        let v2 = vld1q_f32(in_ptr.add(i + 8));
        let v3 = vld1q_f32(in_ptr.add(i + 12));
        acc_max = vmaxq_f32(acc_max, vmaxq_f32(vmaxq_f32(v0, v1), vmaxq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        let v = vld1q_f32(in_ptr.add(i));
        acc_max = vmaxq_f32(acc_max, v);
        i += 4;
    }
    let mut max_val = vmaxvq_f32(acc_max);
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sub+exp (NEON fast_exp, writes output) + 3. accumulate sum
    let off = vdupq_n_f32(max_val);
    let mut acc_sum = vdupq_n_f32(0.0);
    i = 0;
    while i + 16 <= len {
        let v0 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        let v1 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), off));
        let v2 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), off));
        let v3 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), off));
        vst1q_f32(out_ptr.add(i), v0);
        vst1q_f32(out_ptr.add(i + 4), v1);
        vst1q_f32(out_ptr.add(i + 8), v2);
        vst1q_f32(out_ptr.add(i + 12), v3);
        acc_sum = vaddq_f32(acc_sum, vaddq_f32(vaddq_f32(v0, v1), vaddq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        let v = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        vst1q_f32(out_ptr.add(i), v);
        acc_sum = vaddq_f32(acc_sum, v);
        i += 4;
    }
    let mut sum_exp = vaddvq_f32(acc_sum);
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // 4. divide (NEON multiply by 1/sum)
    let inv = vdupq_n_f32(1.0 / sum_exp);
    i = 0;
    while i + 16 <= len {
        vst1q_f32(out_ptr.add(i), vmulq_f32(vld1q_f32(out_ptr.add(i)), inv));
        vst1q_f32(
            out_ptr.add(i + 4),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 4)), inv),
        );
        vst1q_f32(
            out_ptr.add(i + 8),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 8)), inv),
        );
        vst1q_f32(
            out_ptr.add(i + 12),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 12)), inv),
        );
        i += 16;
    }
    while i + 4 <= len {
        vst1q_f32(out_ptr.add(i), vmulq_f32(vld1q_f32(out_ptr.add(i)), inv));
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

/// SSE fused softmax fallback.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn softmax_row_fused_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. max
    let mut acc_max = _mm_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 4 <= len {
        acc_max = _mm_max_ps(acc_max, _mm_loadu_ps(in_ptr.add(i)));
        i += 4;
    }
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc_max);
    let mut max_val = buf[0].max(buf[1]).max(buf[2].max(buf[3]));
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sub+exp + 3. sum
    let off = _mm_set1_ps(max_val);
    let mut acc_sum = _mm_setzero_ps();
    i = 0;
    while i + 4 <= len {
        let v = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off));
        _mm_storeu_ps(out_ptr.add(i), v);
        acc_sum = _mm_add_ps(acc_sum, v);
        i += 4;
    }
    _mm_storeu_ps(buf.as_mut_ptr(), acc_sum);
    let mut sum_exp = buf[0] + buf[1] + buf[2] + buf[3];
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // 4. divide
    let inv = _mm_set1_ps(1.0 / sum_exp);
    i = 0;
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_mul_ps(_mm_loadu_ps(out_ptr.add(i)), inv),
        );
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

/// AVX fused softmax fallback; the sub-8-element tail is delegated to SSE.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn softmax_row_fused_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. max
    let mut acc_max = _mm256_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 8 <= len {
        acc_max = _mm256_max_ps(acc_max, _mm256_loadu_ps(in_ptr.add(i)));
        i += 8;
    }
    let mut buf8 = [0.0f32; 8];
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_max);
    let mut max_val = buf8[0];
    for &v in &buf8[1..] {
        max_val = max_val.max(v);
    }
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sub+exp + 3. sum
    let off = _mm256_set1_ps(max_val);
    let mut acc_sum = _mm256_setzero_ps();
    i = 0;
    while i + 8 <= len {
        let v = fast_exp_avx(_mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), off));
        _mm256_storeu_ps(out_ptr.add(i), v);
        acc_sum = _mm256_add_ps(acc_sum, v);
        i += 8;
    }
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_sum);
    let mut sum_exp: f32 = buf8.iter().sum();
    // SSE tail for remaining < 8 elements
    let off4 = _mm_set1_ps(max_val);
    while i + 4 <= len {
        let v = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off4));
        _mm_storeu_ps(out_ptr.add(i), v);
        let mut b4 = [0.0f32; 4];
        _mm_storeu_ps(b4.as_mut_ptr(), v);
        sum_exp += b4[0] + b4[1] + b4[2] + b4[3];
        i += 4;
    }
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // 4. divide
    let inv8 = _mm256_set1_ps(1.0 / sum_exp);
    i = 0;
    while i + 8 <= len {
        _mm256_storeu_ps(
            out_ptr.add(i),
            _mm256_mul_ps(_mm256_loadu_ps(out_ptr.add(i)), inv8),
        );
        i += 8;
    }
    let inv4 = _mm_set1_ps(1.0 / sum_exp);
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_mul_ps(_mm_loadu_ps(out_ptr.add(i)), inv4),
        );
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

// ===========================================================================
// Fused log-softmax: out[i] = x[i] - max - log(sum(exp(x - max)))
// ===========================================================================

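/// Fused log-softmax row: `out[i] = x[i] - max - ln(sum(exp(x - max)))`.
///
/// Same fusion strategy as [`softmax_row_fused_dispatch`]: every step (max,
/// sum of shifted exponentials, final subtraction) runs inside one function,
/// with scalar, NEON, SSE, and AVX backends selected at runtime.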
#[allow(unsafe_code)]
#[inline]
pub fn log_softmax_row_fused_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) || input.is_empty() {
        log_softmax_row_fused_scalar(input, output);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                log_softmax_row_fused_neon(input, output);
            }
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                log_softmax_row_fused_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            // SAFETY: guarded by runtime feature detection.
            unsafe {
                log_softmax_row_fused_sse(input, output);
            }
            return;
        }
    }

    log_softmax_row_fused_scalar(input, output);
}

fn log_softmax_row_fused_scalar(input: &[f32], output: &mut [f32]) {
    if input.is_empty() {
        return;
    }

    // 1. max
    let mut max_val = f32::NEG_INFINITY;
    for &v in input {
        max_val = max_val.max(v);
    }

    // 2. sum(exp(x - max))
    let mut sum_exp = 0.0f32;
    for &v in input {
        sum_exp += (v - max_val).exp();
    }

    // 3. output[i] = x[i] - max - log(sum_exp)
    let log_denom = max_val + sum_exp.ln();
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v - log_denom;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn log_softmax_row_fused_neon(input: &[f32], output: &mut [f32]) {
    use std::arch::aarch64::{vaddvq_f32, vmaxvq_f32};

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. Find max (NEON reduce)
    let mut acc_max = vdupq_n_f32(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 16 <= len {
        let v0 = vld1q_f32(in_ptr.add(i));
        let v1 = vld1q_f32(in_ptr.add(i + 4));
        let v2 = vld1q_f32(in_ptr.add(i + 8));
        let v3 = vld1q_f32(in_ptr.add(i + 12));
        acc_max = vmaxq_f32(acc_max, vmaxq_f32(vmaxq_f32(v0, v1), vmaxq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        acc_max = vmaxq_f32(acc_max, vld1q_f32(in_ptr.add(i)));
        i += 4;
    }
    let mut max_val = vmaxvq_f32(acc_max);
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sum(exp(x - max))
    let off = vdupq_n_f32(max_val);
    let mut acc_sum = vdupq_n_f32(0.0);
    i = 0;
    while i + 16 <= len {
        let e0 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        let e1 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), off));
        let e2 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), off));
        let e3 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), off));
        acc_sum = vaddq_f32(acc_sum, vaddq_f32(vaddq_f32(e0, e1), vaddq_f32(e2, e3)));
        i += 16;
    }
    while i + 4 <= len {
        let e = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        acc_sum = vaddq_f32(acc_sum, e);
        i += 4;
    }
    let mut sum_exp = vaddvq_f32(acc_sum);
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // 3. output[i] = x[i] - (max + log(sum_exp))
    let log_denom = vdupq_n_f32(max_val + sum_exp.ln());
    i = 0;
    while i + 16 <= len {
        vst1q_f32(
            out_ptr.add(i),
            vsubq_f32(vld1q_f32(in_ptr.add(i)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 4),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 8),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 12),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), log_denom),
        );
        i += 16;
    }
    while i + 4 <= len {
        vst1q_f32(
            out_ptr.add(i),
            vsubq_f32(vld1q_f32(in_ptr.add(i)), log_denom),
        );
        i += 4;
    }
    let log_denom_s = max_val + sum_exp.ln();
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_s;
        i += 1;
    }
}

/// SSE fused log-softmax.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn log_softmax_row_fused_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. max
    let mut acc_max = _mm_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 4 <= len {
        acc_max = _mm_max_ps(acc_max, _mm_loadu_ps(in_ptr.add(i)));
        i += 4;
    }
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc_max);
    let mut max_val = buf[0].max(buf[1]).max(buf[2].max(buf[3]));
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sum(exp(x - max))
    let off = _mm_set1_ps(max_val);
    let mut acc_sum = _mm_setzero_ps();
    i = 0;
    while i + 4 <= len {
        let e = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off));
        acc_sum = _mm_add_ps(acc_sum, e);
        i += 4;
    }
    _mm_storeu_ps(buf.as_mut_ptr(), acc_sum);
    let mut sum_exp = buf[0] + buf[1] + buf[2] + buf[3];
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // 3. output[i] = x[i] - (max + log(sum_exp))
    let log_denom = _mm_set1_ps(max_val + sum_exp.ln());
    i = 0;
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), log_denom),
        );
        i += 4;
    }
    let log_denom_s = max_val + sum_exp.ln();
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_s;
        i += 1;
    }
}

/// AVX fused log-softmax.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn log_softmax_row_fused_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // 1. max
    let mut acc_max = _mm256_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 8 <= len {
        acc_max = _mm256_max_ps(acc_max, _mm256_loadu_ps(in_ptr.add(i)));
        i += 8;
    }
    let mut buf8 = [0.0f32; 8];
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_max);
    let mut max_val = buf8[0];
    for &v in &buf8[1..] {
        max_val = max_val.max(v);
    }
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // 2. sum(exp(x - max))
    let off = _mm256_set1_ps(max_val);
    let mut acc_sum = _mm256_setzero_ps();
    i = 0;
    while i + 8 <= len {
        let e = fast_exp_avx(_mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), off));
        acc_sum = _mm256_add_ps(acc_sum, e);
        i += 8;
    }
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_sum);
    let mut sum_exp: f32 = buf8.iter().sum();
    // SSE tail for remaining < 8 elements
    let off4 = _mm_set1_ps(max_val);
    while i + 4 <= len {
        let e = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off4));
        let mut b4 = [0.0f32; 4];
        _mm_storeu_ps(b4.as_mut_ptr(), e);
        sum_exp += b4[0] + b4[1] + b4[2] + b4[3];
        i += 4;
    }
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // 3. output[i] = x[i] - (max + log(sum_exp))
    let log_denom_val = max_val + sum_exp.ln();
    let log_denom8 = _mm256_set1_ps(log_denom_val);
    i = 0;
    while i + 8 <= len {
        _mm256_storeu_ps(
            out_ptr.add(i),
            _mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), log_denom8),
        );
        i += 8;
    }
    let log_denom4 = _mm_set1_ps(log_denom_val);
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), log_denom4),
        );
        i += 4;
    }
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_val;
        i += 1;
    }
}

// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;

    fn assert_close(a: &[f32], b: &[f32], tol: f32) {
        assert_eq!(a.len(), b.len(), "length mismatch");
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            let d = (x - y).abs();
            assert!(d <= tol, "index {i}: {x} vs {y}, diff={d}, tolerance={tol}");
        }
    }

    #[test]
    fn exp_matches_scalar() {
        let input: Vec<f32> = (-20..=20).map(|i| i as f32 * 0.5).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        exp_slice_dispatch(&input, &mut simd_out);
        exp_slice_scalar(&input, &mut scalar_out);

        // Degree-6 Taylor polynomial is accurate to roughly 1e-6 relative error.
        for (i, (&s, &r)) in simd_out.iter().zip(scalar_out.iter()).enumerate() {
            let rel = if r.abs() > 1e-10 {
                (s - r).abs() / r.abs()
            } else {
                (s - r).abs()
            };
            assert!(
                rel < 1e-5,
                "exp mismatch at index {i}: simd={s}, scalar={r}, rel_err={rel}"
            );
        }
    }

    #[test]
    fn sigmoid_dispatch_matches_scalar() {
        let input: Vec<f32> = (-30..=30).map(|i| i as f32 * 0.3).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        sigmoid_slice_dispatch(&input, &mut simd_out);
        sigmoid_slice_dispatch_scalar(&input, &mut scalar_out);

        // Sigmoid uses Schraudolph bit-trick exp (~4% max error on exp,
        // but sigmoid squashes error near 0/1, practical max ~0.03).
        assert_close(&simd_out, &scalar_out, 0.035);
    }

    #[test]
    fn tanh_dispatch_matches_scalar() {
        let input: Vec<f32> = (-30..=30).map(|i| i as f32 * 0.3).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        tanh_slice_dispatch(&input, &mut simd_out);
        tanh_slice_dispatch_scalar(&input, &mut scalar_out);

        // Uses fast 3-term exp polynomial for sigmoid path (~2e-3 max error vs scalar tanh).
        assert_close(&simd_out, &scalar_out, 2e-3);
    }

    #[test]
    fn max_reduce_matches_scalar() {
        let data: Vec<f32> = (0..37).map(|i| (i as f32 * 0.7 - 12.0).sin()).collect();
        let simd_result = max_reduce_dispatch(&data);
        let scalar_result = max_reduce_scalar(&data);
        assert!((simd_result - scalar_result).abs() < 1e-6);
    }

    #[test]
    fn max_reduce_empty() {
        assert_eq!(max_reduce_dispatch(&[]), f32::NEG_INFINITY);
    }

    #[test]
    fn add_reduce_matches_scalar() {
        let data: Vec<f32> = (0..37).map(|i| i as f32 * 0.1).collect();
        let simd_result = add_reduce_dispatch(&data);
        let scalar_result = add_reduce_scalar(&data);
        assert!(
            (simd_result - scalar_result).abs() < 1e-3,
            "simd={simd_result}, scalar={scalar_result}"
        );
    }

    #[test]
    fn add_reduce_empty() {
        assert_eq!(add_reduce_dispatch(&[]), 0.0);
    }

    #[test]
    #[allow(unsafe_code)]
    fn fma_matches_scalar() {
        let a: Vec<f32> = (0..33).map(|i| i as f32 * 0.3).collect();
        let b: Vec<f32> = (0..33).map(|i| (i as f32 * 0.7).sin()).collect();
        let mut simd_acc = vec![1.0f32; 33];
        let mut scalar_acc = vec![1.0f32; 33];

        fma_slice_dispatch(&a, &b, &mut simd_acc);
        unsafe { fma_slice_scalar(&a, &b, &mut scalar_acc) };

        assert_close(&simd_acc, &scalar_acc, 1e-5);
    }
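
    // A minimal sketch (not part of the original suite): checks the NEON
    // matmul row kernel against a plain scalar rank-1-update loop. The
    // shapes and tolerance here are illustrative assumptions.
    #[test]
    #[cfg(target_arch = "aarch64")]
    #[allow(unsafe_code)]
    fn matmul_row_neon_matches_scalar() {
        if !std::arch::is_aarch64_feature_detected!("neon") {
            return;
        }
        let (k, n) = (5usize, 7usize);
        let left: Vec<f32> = (0..k).map(|i| i as f32 * 0.25 - 0.5).collect();
        let right: Vec<f32> = (0..k * n).map(|i| (i as f32 * 0.3).sin()).collect();
        let mut simd_out = vec![0.0f32; n];
        let mut scalar_out = vec![0.0f32; n];

        // SAFETY: pointers come from correctly sized Vecs; NEON was detected above.
        unsafe {
            matmul_row_neon(left.as_ptr(), right.as_ptr(), simd_out.as_mut_ptr(), k, n);
        }

        // Scalar reference: out[col] += sum over p of left[p] * right[p * n + col].
        for p in 0..k {
            for col in 0..n {
                scalar_out[col] += left[p] * right[p * n + col];
            }
        }
        assert_close(&simd_out, &scalar_out, 1e-5);
    }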

    #[test]
    fn sigmoid_dispatch_boundary_values() {
        // Verify sigmoid at key points.
        let input = vec![-100.0, -10.0, 0.0, 10.0, 100.0];
        let mut output = vec![0.0f32; 5];
        sigmoid_slice_dispatch(&input, &mut output);

        // sigmoid(-100) ~ 0, sigmoid(0) = 0.5, sigmoid(100) ~ 1
        assert!(
            output[0] < 0.01,
            "sigmoid(-100) should be near 0: {}",
            output[0]
        );
        assert!(
            (output[2] - 0.5).abs() < 0.01,
            "sigmoid(0) should be near 0.5: {}",
            output[2]
        );
        assert!(
            output[4] > 0.99,
            "sigmoid(100) should be near 1: {}",
            output[4]
        );
    }

    #[test]
    fn tanh_dispatch_boundary_values() {
        let input = vec![-100.0, -1.0, 0.0, 1.0, 100.0];
        let mut output = vec![0.0f32; 5];
        tanh_slice_dispatch(&input, &mut output);

        assert!(
            output[0] < -0.99,
            "tanh(-100) should be near -1: {}",
            output[0]
        );
        assert!(
            (output[2]).abs() < 0.01,
            "tanh(0) should be near 0: {}",
            output[2]
        );
        assert!(
            output[4] > 0.99,
            "tanh(100) should be near 1: {}",
            output[4]
        );
    }
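
    // Minimal sketches (not part of the original suite) for the fused row
    // kernels: the dispatched SIMD paths are checked against the scalar
    // references. Tolerances assume fast_exp_* is the same ~1e-5-accurate
    // polynomial exercised by exp_matches_scalar above.
    #[test]
    fn softmax_row_fused_matches_scalar() {
        let input: Vec<f32> = (0..37).map(|i| (i as f32 * 0.7 - 12.0).sin() * 4.0).collect();
        let mut fused = vec![0.0f32; input.len()];
        let mut scalar = vec![0.0f32; input.len()];

        softmax_row_fused_dispatch(&input, &mut fused);
        softmax_row_fused_scalar(&input, &mut scalar);

        assert_close(&fused, &scalar, 1e-4);
        // The fused kernel normalizes by its own sum, so the row sums to ~1.
        let total: f32 = fused.iter().sum();
        assert!((total - 1.0).abs() < 1e-3, "softmax sum = {total}");
    }

    #[test]
    fn log_softmax_row_fused_matches_scalar() {
        let input: Vec<f32> = (0..37).map(|i| (i as f32 * 0.7 - 12.0).sin() * 4.0).collect();
        let mut fused = vec![0.0f32; input.len()];
        let mut scalar = vec![0.0f32; input.len()];

        log_softmax_row_fused_dispatch(&input, &mut fused);
        log_softmax_row_fused_scalar(&input, &mut scalar);

        assert_close(&fused, &scalar, 1e-3);
    }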
}