Skip to main content

rlx_cpu/
kernels.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SIMD kernels for fused operations.
17//!
18//! These are the production kernels extracted from burnembed's ndarray_fused.rs.
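//! Each kernel processes data in-place or into a pre-allocated output buffer
//! (from the arena). The kernel bodies here do not allocate, with one
//! exception: `residual_bias_layer_norm` allocates an `h`-element scratch row
//! per call.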
21
22use crate::pool;
23
24// ── NEON vectorized exp ─────────────────────────────────────────────────
25
/// Vectorised `exp(x)` for the four f32 lanes of a NEON register.
///
/// Method:
/// - Clamp the input to `[-87.3, 88.7]`.
/// - Split `x = k·ln2 + r`, where `k = round(x·log2 e)` and ln2 is carried as
///   a high/low pair so `r` keeps its precision.
/// - Evaluate `exp(r)` with a degree-6 Taylor polynomial in Horner form.
/// - Scale by `2^k` by adding `k << 23` to the float bit pattern.
///
/// Max relative error: ~2e-7 across [-87, 88].
///
/// # Safety
/// No memory is accessed. The caller only needs the NEON instructions used
/// here to be available, which is the default for Rust's aarch64 targets.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn neon_exp4(x: std::arch::aarch64::float32x4_t) -> std::arch::aarch64::float32x4_t {
    use std::arch::aarch64::*;
    // Saturate so the 2^k scaling below stays in the finite f32 range.
    let xc = vminq_f32(vmaxq_f32(x, vdupq_n_f32(-87.3)), vdupq_n_f32(88.7));
    let k = vrndnq_f32(vmulq_f32(xc, vdupq_n_f32(std::f32::consts::LOG2_E)));
    // r = xc - k*ln2_hi - k*ln2_lo
    let r_hi = vfmsq_f32(xc, k, vdupq_n_f32(0.693_145_75));
    let r = vfmsq_f32(r_hi, k, vdupq_n_f32(1.428_606_8e-6));
    // Horner: start at 1/720, then fold in 1/120, 1/24, 1/6, 1/2, 1, 1.
    let mut poly = vdupq_n_f32(0.001_388_888_9);
    for coeff in [0.008_333_334, 0.041_666_668, 0.166_666_67, 0.5, 1.0, 1.0] {
        poly = vfmaq_f32(vdupq_n_f32(coeff), poly, r);
    }
    // Multiply by 2^k by adding k directly to the exponent field.
    let exp_bits = vshlq_n_s32(vcvtq_s32_f32(k), 23);
    vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(poly), exp_bits))
}
51
/// AVX2+FMA vectorised `exp(x)` for eight f32 lanes.
///
/// Uses the same scheme as the NEON `neon_exp4`:
/// - Clamp to `[-87.3, 88.7]`.
/// - Reduce `x = k·ln2 + r`, with `k` rounded to nearest-even and ln2 split
///   into high/low parts.
/// - Evaluate the degree-6 Taylor polynomial of `exp(r)`.
/// - Scale by `2^k` through the exponent bits.
///
/// Max relative error stays in the ~2e-7 range. Requires `+avx2,+fma` codegen.
///
/// # Safety
/// No memory is accessed. This function is compiled only when `avx2` and
/// `fma` are enabled at build time, so the executing CPU must support both.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn avx2_exp8(x: std::arch::x86_64::__m256) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::*;
    let xc = _mm256_min_ps(_mm256_max_ps(x, _mm256_set1_ps(-87.3)), _mm256_set1_ps(88.7));
    let scaled = _mm256_mul_ps(xc, _mm256_set1_ps(1.442695040888963));
    // k = round(x / ln2), round-to-nearest-even.
    let k = _mm256_round_ps::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(scaled);
    // r = xc − k·ln2_hi − k·ln2_lo, via two fused negate-multiply-adds.
    let r_hi = _mm256_fnmadd_ps(k, _mm256_set1_ps(0.693145751953125), xc);
    let r = _mm256_fnmadd_ps(k, _mm256_set1_ps(1.428606765330187e-6), r_hi);
    // Horner evaluation of the Taylor series, highest order (1/720) first.
    let mut poly = _mm256_set1_ps(0.001388888888888889);
    for coeff in [
        0.008333333333333333,
        0.041666666666666664,
        0.16666666666666666,
        0.5,
        1.0,
        1.0,
    ] {
        poly = _mm256_fmadd_ps(poly, r, _mm256_set1_ps(coeff));
    }
    // 2^k via integer add on the f32 exponent field.
    let exp_bits = _mm256_slli_epi32::<23>(_mm256_cvtps_epi32(k));
    _mm256_castsi256_ps(_mm256_add_epi32(_mm256_castps_si256(poly), exp_bits))
}
88
89// ── Fused bias + GELU ───────────────────────────────────────────────────
90
/// Fused bias addition + GELU activation on a row-major `[m, n]` buffer.
///
/// For each of the `m` rows, computes
/// `data[row * n + i] = gelu(data[row * n + i] + bias[i])`.
/// This is the exact (erf) GELU, using the Abramowitz & Stegun 7.1.26 erf
/// approximation with `neon_exp4`. Columns past the last full group of four
/// go through `scalar_gelu`.
///
/// # Panics
/// Panics if `m * n` overflows, if `data` has fewer than `m * n` elements, or
/// if `bias` has fewer than `n` elements. These checks keep the raw NEON
/// loads/stores below in bounds.
#[cfg(target_arch = "aarch64")]
pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
    use std::arch::aarch64::*;
    assert!(
        matches!(m.checked_mul(n), Some(need) if need <= data.len()),
        "bias_gelu: data shorter than m * n"
    );
    assert!(bias.len() >= n, "bias_gelu: bias shorter than n");
    let chunks = n / 4;
    // SAFETY: every vector access reads/writes data[row*n + c*4 .. +4] with
    // row < m and c*4 + 4 <= n, and bias[c*4 .. +4]. The asserts above make
    // both ranges in bounds. NEON is baseline on aarch64.
    unsafe {
        let half = vdupq_n_f32(0.5);
        let one = vdupq_n_f32(1.0);
        let inv_sqrt2 = vdupq_n_f32(std::f32::consts::FRAC_1_SQRT_2);
        let p = vdupq_n_f32(0.3275911);
        let a1 = vdupq_n_f32(0.254_829_6);
        let a2 = vdupq_n_f32(-0.284_496_72);
        let a3 = vdupq_n_f32(1.421_413_8);
        let a4 = vdupq_n_f32(-1.453_152_1);
        let a5 = vdupq_n_f32(1.061_405_4);
        let neg_one = vdupq_n_f32(-1.0);
        let zero = vdupq_n_f32(0.0);

        for row in 0..m {
            let base = row * n;
            for c in 0..chunks {
                let off = base + c * 4;
                let ptr = data.as_mut_ptr().add(off);
                let x = vaddq_f32(vld1q_f32(ptr), vld1q_f32(bias.as_ptr().add(c * 4)));
                let erf_arg = vmulq_f32(x, inv_sqrt2);
                let xa = vabsq_f32(erf_arg);
                // erf is odd: evaluate on |z| and restore the sign.
                let sign = vbslq_f32(vcgeq_f32(erf_arg, zero), one, neg_one);
                let denom = vfmaq_f32(one, p, xa);
                let t = vdivq_f32(one, denom);
                let mut y = a5;
                y = vfmaq_f32(a4, y, t);
                y = vfmaq_f32(a3, y, t);
                y = vfmaq_f32(a2, y, t);
                y = vfmaq_f32(a1, y, t);
                y = vmulq_f32(y, t);
                let exp_val = neon_exp4(vnegq_f32(vmulq_f32(xa, xa)));
                // erf(z) ≈ sign * (1 - poly(t) * exp(-z²))
                let erf_val = vmulq_f32(sign, vfmsq_f32(one, y, exp_val));
                vst1q_f32(ptr, vmulq_f32(x, vmulq_f32(half, vaddq_f32(one, erf_val))));
            }
            for i in (chunks * 4)..n {
                let x = data[base + i] + bias[i];
                data[base + i] = scalar_gelu(x);
            }
        }
    }
}
138
/// Fused bias addition + GELU activation on a row-major `[m, n]` buffer
/// (AVX2+FMA).
///
/// For each of the `m` rows, computes
/// `data[row * n + i] = gelu(data[row * n + i] + bias[i])`.
/// This is the exact (erf) GELU, using the Abramowitz & Stegun 7.1.26 erf
/// approximation with `avx2_exp8`. Columns past the last full group of eight
/// go through `scalar_gelu`.
///
/// # Panics
/// Panics if `m * n` overflows, if `data` has fewer than `m * n` elements, or
/// if `bias` has fewer than `n` elements. These checks keep the raw AVX
/// loads/stores below in bounds.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
    use std::arch::x86_64::*;
    assert!(
        matches!(m.checked_mul(n), Some(need) if need <= data.len()),
        "bias_gelu: data shorter than m * n"
    );
    assert!(bias.len() >= n, "bias_gelu: bias shorter than n");
    let chunks = n / 8;
    // SAFETY: every vector access touches data[row*n + c*8 .. +8] with
    // row < m and c*8 + 8 <= n, and bias[c*8 .. +8]. The asserts above make
    // both in bounds. AVX2/FMA are enabled at compile time (cfg above).
    unsafe {
        let half = _mm256_set1_ps(0.5);
        let one = _mm256_set1_ps(1.0);
        let inv_sqrt2 = _mm256_set1_ps(std::f32::consts::FRAC_1_SQRT_2);
        let p = _mm256_set1_ps(0.3275911);
        let a1 = _mm256_set1_ps(0.254829592);
        let a2 = _mm256_set1_ps(-0.284496736);
        let a3 = _mm256_set1_ps(1.421413741);
        let a4 = _mm256_set1_ps(-1.453152027);
        let a5 = _mm256_set1_ps(1.061405429);
        let neg_one = _mm256_set1_ps(-1.0);
        let zero = _mm256_set1_ps(0.0);
        // Sign bit mask for fabs via AND with 0x7fffffff.
        let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fff_ffff));

        for row in 0..m {
            let base = row * n;
            for c in 0..chunks {
                let off = base + c * 8;
                let ptr = data.as_mut_ptr().add(off);
                let x = _mm256_add_ps(
                    _mm256_loadu_ps(ptr),
                    _mm256_loadu_ps(bias.as_ptr().add(c * 8)),
                );
                let erf_arg = _mm256_mul_ps(x, inv_sqrt2);
                let xa = _mm256_and_ps(erf_arg, abs_mask);
                // sign = (erf_arg >= 0) ? 1 : -1
                let ge0 = _mm256_cmp_ps::<_CMP_GE_OQ>(erf_arg, zero);
                let sign = _mm256_blendv_ps(neg_one, one, ge0);
                let denom = _mm256_fmadd_ps(p, xa, one);
                let t = _mm256_div_ps(one, denom);
                let mut y = a5;
                y = _mm256_fmadd_ps(y, t, a4);
                y = _mm256_fmadd_ps(y, t, a3);
                y = _mm256_fmadd_ps(y, t, a2);
                y = _mm256_fmadd_ps(y, t, a1);
                y = _mm256_mul_ps(y, t);
                let exp_val = avx2_exp8(_mm256_sub_ps(zero, _mm256_mul_ps(xa, xa)));
                // erf = sign * (1 - y*exp(-xa^2))
                let erf_val = _mm256_mul_ps(sign, _mm256_fnmadd_ps(y, exp_val, one));
                _mm256_storeu_ps(
                    ptr,
                    _mm256_mul_ps(x, _mm256_mul_ps(half, _mm256_add_ps(one, erf_val))),
                );
            }
            for i in (chunks * 8)..n {
                let x = data[base + i] + bias[i];
                data[base + i] = scalar_gelu(x);
            }
        }
    }
}
199
200#[cfg(not(any(
201    target_arch = "aarch64",
202    all(
203        target_arch = "x86_64",
204        target_feature = "avx2",
205        target_feature = "fma"
206    )
207)))]
208pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
209    for row in 0..m {
210        let base = row * n;
211        for i in 0..n {
212            let x = data[base + i] + bias[i];
213            data[base + i] = scalar_gelu(x);
214        }
215    }
216}
217
218/// Parallel bias + GELU across thread pool.
219pub fn par_bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
220    let cfg = crate::config::RuntimeConfig::global();
221    if m * n < cfg.par_threshold || m < cfg.min_rows_per_thread {
222        bias_gelu(data, bias, m, n);
223        return;
224    }
225    let data_ptr = data.as_mut_ptr() as usize;
226    let bias_ptr = bias.as_ptr() as usize;
227    pool::par_for(m, cfg.min_rows_per_thread, &|off, cnt| unsafe {
228        let d = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(off * n), cnt * n);
229        let b = std::slice::from_raw_parts(bias_ptr as *const f32, n);
230        bias_gelu(d, b, cnt, n);
231    });
232}
233
234// ── Fused SiLU ──────────────────────────────────────────────────────────
235
/// SiLU (Swish) applied in place: `x * sigmoid(x)`, i.e. `x / (1 + exp(-x))`.
///
/// Groups of four elements use NEON (`neon_exp4`, with the sigmoid formed as
/// `1 / (1 + exp(-x))`). Any leftover elements use scalar `f32::exp`.
#[cfg(target_arch = "aarch64")]
pub fn silu_inplace(data: &mut [f32]) {
    use std::arch::aarch64::*;
    let mut quads = data.chunks_exact_mut(4);
    // SAFETY: each `quad` holds exactly four contiguous f32s, so the 128-bit
    // load and store stay in bounds. NEON is baseline on aarch64.
    unsafe {
        let ones = vdupq_n_f32(1.0);
        for quad in &mut quads {
            let v = vld1q_f32(quad.as_ptr());
            let sig = vdivq_f32(ones, vaddq_f32(ones, neon_exp4(vnegq_f32(v))));
            vst1q_f32(quad.as_mut_ptr(), vmulq_f32(v, sig));
        }
    }
    for x in quads.into_remainder() {
        *x /= 1.0 + (-*x).exp();
    }
}
256
/// SiLU (Swish) applied in place with AVX2+FMA: `x / (1 + exp(-x))`.
///
/// Groups of eight elements go through `avx2_exp8`. The remainder uses
/// scalar `f32::exp`.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn silu_inplace(data: &mut [f32]) {
    use std::arch::x86_64::*;
    let mut octets = data.chunks_exact_mut(8);
    // SAFETY: each `octet` holds exactly eight contiguous f32s, so the
    // unaligned 256-bit load/store stay in bounds. AVX2/FMA are enabled by cfg.
    unsafe {
        let one = _mm256_set1_ps(1.0);
        let zero = _mm256_set1_ps(0.0);
        for octet in &mut octets {
            let x = _mm256_loadu_ps(octet.as_ptr());
            // -x is formed as 0 - x before the exponential.
            let denom = _mm256_add_ps(one, avx2_exp8(_mm256_sub_ps(zero, x)));
            _mm256_storeu_ps(octet.as_mut_ptr(), _mm256_div_ps(x, denom));
        }
    }
    for v in octets.into_remainder() {
        *v /= 1.0 + (-*v).exp();
    }
}
283
/// Portable SiLU (Swish) in place: every element becomes `x / (1 + exp(-x))`.
#[cfg(not(any(
    target_arch = "aarch64",
    all(
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    )
)))]
pub fn silu_inplace(data: &mut [f32]) {
    data.iter_mut().for_each(|v| *v /= 1.0 + (-*v).exp());
}
298
299// ── LayerNorm (2-pass) ──────────────────────────────────────────────────
300
/// Single-row LayerNorm: `out = (x - mean) * inv_std * gamma + beta`.
///
/// Works in two passes over the row. The first accumulates `Σx` and `Σx²` to
/// get the mean and the variance as `E[x²] - E[x]²`. The second normalises.
///
/// That variance formula can come out slightly negative through
/// floating-point cancellation (e.g. near-constant rows with a large mean).
/// A negative variance would make `sqrt(var + eps)` NaN, so it is clamped
/// at zero.
///
/// # Panics
/// Panics if `input`, `gamma`, `beta` or `output` is shorter than `h`. The
/// NEON loads/stores below rely on that bound.
#[cfg(target_arch = "aarch64")]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    use std::arch::aarch64::*;
    assert!(
        input.len() >= h && gamma.len() >= h && beta.len() >= h && output.len() >= h,
        "layer_norm_row: every slice must hold at least h elements"
    );
    let inv_hf = 1.0 / h as f32;
    let chunks = h / 4;
    // SAFETY: vector accesses only touch indices < chunks * 4 <= h, and the
    // assert above guarantees every slice holds at least h elements.
    unsafe {
        let mut vsum = vdupq_n_f32(0.0);
        let mut vsumsq = vdupq_n_f32(0.0);
        for c in 0..chunks {
            let x = vld1q_f32(input.as_ptr().add(c * 4));
            vsum = vaddq_f32(vsum, x);
            vsumsq = vfmaq_f32(vsumsq, x, x);
        }
        let mut sum = vaddvq_f32(vsum);
        let mut sumsq = vaddvq_f32(vsumsq);
        for i in (chunks * 4)..h {
            sum += input[i];
            sumsq += input[i] * input[i];
        }
        let mean = sum * inv_hf;
        // E[x²] - E[x]² can dip below zero through cancellation. Clamp so
        // sqrt(var + eps) never sees a negative argument.
        let var = (sumsq * inv_hf - mean * mean).max(0.0);
        let inv = 1.0 / (var + eps).sqrt();
        let vmean = vdupq_n_f32(mean);
        let vinv = vdupq_n_f32(inv);
        for c in 0..chunks {
            let off = c * 4;
            let x = vld1q_f32(input.as_ptr().add(off));
            let norm = vmulq_f32(vsubq_f32(x, vmean), vinv);
            // beta + norm * gamma
            vst1q_f32(
                output.as_mut_ptr().add(off),
                vfmaq_f32(
                    vld1q_f32(beta.as_ptr().add(off)),
                    norm,
                    vld1q_f32(gamma.as_ptr().add(off)),
                ),
            );
        }
        for i in (chunks * 4)..h {
            output[i] = (input[i] - mean) * inv * gamma[i] + beta[i];
        }
    }
}
352
/// Single-row LayerNorm (AVX2+FMA): `out = (x - mean) * inv_std * gamma + beta`.
///
/// Works in two passes over the row. The first accumulates `Σx` and `Σx²` to
/// get the mean and the variance as `E[x²] - E[x]²`. The second normalises.
/// The variance is clamped at zero because cancellation can make that
/// formula slightly negative, which would otherwise produce NaN from
/// `sqrt(var + eps)`.
///
/// # Panics
/// Panics if `input`, `gamma`, `beta` or `output` is shorter than `h`. The
/// unaligned AVX loads/stores below rely on that bound.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    use std::arch::x86_64::*;
    assert!(
        input.len() >= h && gamma.len() >= h && beta.len() >= h && output.len() >= h,
        "layer_norm_row: every slice must hold at least h elements"
    );
    let inv_hf = 1.0 / h as f32;
    let chunks = h / 8;
    // SAFETY: vector accesses only touch indices < chunks * 8 <= h, and the
    // assert above guarantees every slice holds at least h elements.
    // AVX2/FMA are enabled at compile time (cfg above).
    unsafe {
        let mut vsum = _mm256_setzero_ps();
        let mut vsumsq = _mm256_setzero_ps();
        for c in 0..chunks {
            let x = _mm256_loadu_ps(input.as_ptr().add(c * 8));
            vsum = _mm256_add_ps(vsum, x);
            vsumsq = _mm256_fmadd_ps(x, x, vsumsq);
        }
        // Horizontal reduce: 8 lanes → 1.
        let hsum = {
            let lo = _mm256_castps256_ps128(vsum);
            let hi = _mm256_extractf128_ps::<1>(vsum);
            let s4 = _mm_add_ps(lo, hi);
            let s2 = _mm_add_ps(s4, _mm_movehl_ps(s4, s4));
            let s1 = _mm_add_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
            _mm_cvtss_f32(s1)
        };
        let hsumsq = {
            let lo = _mm256_castps256_ps128(vsumsq);
            let hi = _mm256_extractf128_ps::<1>(vsumsq);
            let s4 = _mm_add_ps(lo, hi);
            let s2 = _mm_add_ps(s4, _mm_movehl_ps(s4, s4));
            let s1 = _mm_add_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
            _mm_cvtss_f32(s1)
        };
        let mut sum = hsum;
        let mut sumsq = hsumsq;
        for i in (chunks * 8)..h {
            sum += input[i];
            sumsq += input[i] * input[i];
        }
        let mean = sum * inv_hf;
        // E[x²] - E[x]² can dip below zero through cancellation. Clamp so
        // sqrt(var + eps) never sees a negative argument.
        let var = (sumsq * inv_hf - mean * mean).max(0.0);
        let inv = 1.0 / (var + eps).sqrt();
        let vmean = _mm256_set1_ps(mean);
        let vinv = _mm256_set1_ps(inv);
        for c in 0..chunks {
            let off = c * 8;
            let x = _mm256_loadu_ps(input.as_ptr().add(off));
            let norm = _mm256_mul_ps(_mm256_sub_ps(x, vmean), vinv);
            let g = _mm256_loadu_ps(gamma.as_ptr().add(off));
            let b = _mm256_loadu_ps(beta.as_ptr().add(off));
            _mm256_storeu_ps(output.as_mut_ptr().add(off), _mm256_fmadd_ps(norm, g, b));
        }
        for i in (chunks * 8)..h {
            output[i] = (input[i] - mean) * inv * gamma[i] + beta[i];
        }
    }
}
418
/// Portable single-row LayerNorm: `out = (x - mean) * inv_std * gamma + beta`.
///
/// Uses two passes: `Σx` / `Σx²` first, with variance `E[x²] - E[x]²`, then
/// normalisation. The variance is clamped at zero because floating-point
/// cancellation can make it slightly negative, which would otherwise turn
/// `sqrt(var + eps)` into NaN.
///
/// # Panics
/// Panics if `input`, `gamma`, `beta` or `output` is shorter than `h`.
#[cfg(not(any(
    target_arch = "aarch64",
    all(
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    )
)))]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    assert!(
        input.len() >= h && gamma.len() >= h && beta.len() >= h && output.len() >= h,
        "layer_norm_row: every slice must hold at least h elements"
    );
    let input = &input[..h];
    let inv_hf = 1.0 / h as f32;
    let mut sum = 0f32;
    let mut sumsq = 0f32;
    for &x in input {
        sum += x;
        sumsq += x * x;
    }
    let mean = sum * inv_hf;
    // E[x²] - E[x]² can dip below zero through cancellation; clamp it.
    let var = (sumsq * inv_hf - mean * mean).max(0.0);
    let inv = 1.0 / (var + eps).sqrt();
    let params = gamma[..h].iter().zip(&beta[..h]);
    for ((out, &x), (&g, &b)) in output[..h].iter_mut().zip(input).zip(params) {
        *out = (x - mean) * inv * g + b;
    }
}
449
450/// Fused residual + bias + LayerNorm on [n, h] buffers.
451/// Computes: output\[row\] = LN(a\[row\] + b\[row\] + bias, gamma, beta)
452pub fn residual_bias_layer_norm(
453    a: &[f32],
454    b: &[f32],
455    bias: &[f32],
456    gamma: &[f32],
457    beta: &[f32],
458    output: &mut [f32],
459    n: usize,
460    h: usize,
461    eps: f32,
462) {
463    // Temporary per-row buffer for a+b+bias (stack allocated for small h)
464    let mut tmp = vec![0f32; h];
465    for row in 0..n {
466        let base = row * h;
467        for i in 0..h {
468            tmp[i] = a[base + i] + b[base + i] + bias[i];
469        }
470        layer_norm_row(&tmp, gamma, beta, &mut output[base..base + h], h, eps);
471    }
472}
473
/// Fused residual + bias + RMSNorm over row-major `[n, h]` buffers.
///
/// For every row, forms `v = a[row] + b[row] + bias` and writes
/// `output[row] = v / sqrt(mean(v²) + eps) * gamma + beta`. `v` is recomputed
/// in the second pass instead of being stored, so no scratch buffer is needed.
pub fn residual_bias_rms_norm(
    a: &[f32],
    b: &[f32],
    bias: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    n: usize,
    h: usize,
    eps: f32,
) {
    let inv_h = 1.0 / h as f32;
    for row in 0..n {
        let start = row * h;
        let fused = |i: usize| a[start + i] + b[start + i] + bias[i];
        let sumsq = (0..h).map(&fused).fold(0f32, |acc, v| acc + v * v);
        let inv_rms = (sumsq * inv_h + eps).sqrt().recip();
        for i in 0..h {
            output[start + i] = fused(i) * inv_rms * gamma[i] + beta[i];
        }
    }
}
502
503/// Parallel residual + bias + LayerNorm.
504pub fn par_residual_bias_ln(
505    a: &[f32],
506    b: &[f32],
507    bias: &[f32],
508    gamma: &[f32],
509    beta: &[f32],
510    output: &mut [f32],
511    n: usize,
512    h: usize,
513    eps: f32,
514) {
515    let cfg = crate::config::RuntimeConfig::global();
516    if n * h < cfg.par_threshold || n < cfg.min_rows_per_thread {
517        residual_bias_layer_norm(a, b, bias, gamma, beta, output, n, h, eps);
518        return;
519    }
520    let a_ptr = a.as_ptr() as usize;
521    let b_ptr = b.as_ptr() as usize;
522    let o_ptr = output.as_mut_ptr() as usize;
523    let bias_ptr = bias.as_ptr() as usize;
524    let gamma_ptr = gamma.as_ptr() as usize;
525    let beta_ptr = beta.as_ptr() as usize;
526    pool::par_for(n, cfg.min_rows_per_thread, &|off, cnt| unsafe {
527        let a_s = std::slice::from_raw_parts((a_ptr as *const f32).add(off * h), cnt * h);
528        let b_s = std::slice::from_raw_parts((b_ptr as *const f32).add(off * h), cnt * h);
529        let o_s = std::slice::from_raw_parts_mut((o_ptr as *mut f32).add(off * h), cnt * h);
530        let bi = std::slice::from_raw_parts(bias_ptr as *const f32, h);
531        let g = std::slice::from_raw_parts(gamma_ptr as *const f32, h);
532        let be = std::slice::from_raw_parts(beta_ptr as *const f32, h);
533        residual_bias_layer_norm(a_s, b_s, bi, g, be, o_s, cnt, h, eps);
534    });
535}
536
537// ── NEON softmax ────────────────────────────────────────────────────────
538
/// NEON-vectorized row-wise softmax, in place, over a row-major
/// `[rows, cols]` buffer.
///
/// Makes three passes per row:
/// 1. Find the row max, for numerical stability.
/// 2. Write back `exp(x - max)` and accumulate its sum.
/// 3. Scale by `1 / sum`.
///
/// # Panics
/// Panics if `rows * cols` overflows or exceeds `data.len()`. The
/// raw-pointer accesses below rely on that bound.
#[cfg(target_arch = "aarch64")]
pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
    use std::arch::aarch64::*;
    assert!(
        matches!(rows.checked_mul(cols), Some(need) if need <= data.len()),
        "neon_softmax: data shorter than rows * cols"
    );
    let chunks = cols / 4;
    // SAFETY: every access is at row * cols + i with row < rows and i < cols,
    // which is < rows * cols <= data.len() by the assert above.
    unsafe {
        for row in 0..rows {
            let base = row * cols;
            let ptr = data.as_mut_ptr().add(base);

            // Pass 1: find row max
            let mut vmax = vdupq_n_f32(f32::NEG_INFINITY);
            for c in 0..chunks {
                vmax = vmaxq_f32(vmax, vld1q_f32(ptr.add(c * 4)));
            }
            let mut max_val = vmaxvq_f32(vmax);
            for i in (chunks * 4)..cols {
                max_val = max_val.max(*ptr.add(i));
            }

            // Pass 2: exp(x - max) and accumulate sum
            let vmx = vdupq_n_f32(max_val);
            let mut vsum = vdupq_n_f32(0.0);
            for c in 0..chunks {
                let off = c * 4;
                let e = neon_exp4(vsubq_f32(vld1q_f32(ptr.add(off)), vmx));
                vst1q_f32(ptr.add(off), e);
                vsum = vaddq_f32(vsum, e);
            }
            let mut sum = vaddvq_f32(vsum);
            for i in (chunks * 4)..cols {
                let e = (*ptr.add(i) - max_val).exp();
                *ptr.add(i) = e;
                sum += e;
            }

            // Pass 3: normalize
            let vinv = vdupq_n_f32(1.0 / sum);
            for c in 0..chunks {
                let off = c * 4;
                vst1q_f32(ptr.add(off), vmulq_f32(vld1q_f32(ptr.add(off)), vinv));
            }
            let inv = 1.0 / sum;
            for i in (chunks * 4)..cols {
                *ptr.add(i) *= inv;
            }
        }
    }
}
588
/// AVX2+FMA row-wise softmax, in place, over a row-major `[rows, cols]`
/// buffer.
///
/// It is exported under the `neon_softmax` name so every target shares one
/// entry point. It makes three passes per row: max, `exp(x - max)` + sum,
/// then normalise.
///
/// # Panics
/// Panics if `rows * cols` overflows or exceeds `data.len()`. The
/// raw-pointer accesses below rely on that bound.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
    use std::arch::x86_64::*;
    assert!(
        matches!(rows.checked_mul(cols), Some(need) if need <= data.len()),
        "neon_softmax: data shorter than rows * cols"
    );
    let chunks = cols / 8;
    // SAFETY: every access is at r * cols + i with r < rows and i < cols,
    // which is < rows * cols <= data.len() by the assert above.
    unsafe {
        for r in 0..rows {
            let row = data.as_mut_ptr().add(r * cols);
            // 1) Vector max for stability.
            let mut vmax = _mm256_set1_ps(f32::NEG_INFINITY);
            for c in 0..chunks {
                vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(row.add(c * 8)));
            }
            let mut max_v = {
                let lo = _mm256_castps256_ps128(vmax);
                let hi = _mm256_extractf128_ps::<1>(vmax);
                let s4 = _mm_max_ps(lo, hi);
                let s2 = _mm_max_ps(s4, _mm_movehl_ps(s4, s4));
                let s1 = _mm_max_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
                _mm_cvtss_f32(s1)
            };
            for i in (chunks * 8)..cols {
                let v = *row.add(i);
                if v > max_v {
                    max_v = v;
                }
            }
            // 2) exp(x − max) and sum.
            let vmax = _mm256_set1_ps(max_v);
            let mut vsum = _mm256_setzero_ps();
            for c in 0..chunks {
                let off = c * 8;
                let e = avx2_exp8(_mm256_sub_ps(_mm256_loadu_ps(row.add(off)), vmax));
                _mm256_storeu_ps(row.add(off), e);
                vsum = _mm256_add_ps(vsum, e);
            }
            let mut sum_v = {
                let lo = _mm256_castps256_ps128(vsum);
                let hi = _mm256_extractf128_ps::<1>(vsum);
                let s4 = _mm_add_ps(lo, hi);
                let s2 = _mm_add_ps(s4, _mm_movehl_ps(s4, s4));
                let s1 = _mm_add_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
                _mm_cvtss_f32(s1)
            };
            for i in (chunks * 8)..cols {
                let v = (*row.add(i) - max_v).exp();
                *row.add(i) = v;
                sum_v += v;
            }
            // 3) Normalize.
            let vinv = _mm256_set1_ps(1.0 / sum_v);
            for c in 0..chunks {
                let off = c * 8;
                _mm256_storeu_ps(
                    row.add(off),
                    _mm256_mul_ps(_mm256_loadu_ps(row.add(off)), vinv),
                );
            }
            let inv_sum = 1.0 / sum_v;
            for i in (chunks * 8)..cols {
                *row.add(i) *= inv_sum;
            }
        }
    }
}
657
/// Portable softmax fallback for targets with neither NEON (aarch64) nor
/// AVX2+FMA (x86_64). It delegates to `crate::naive::softmax`.
///
/// The `neon_` name is kept so every target exposes the same entry point;
/// no NEON code is involved here. NOTE(review): presumably row-wise over a
/// row-major `[rows, cols]` buffer like the SIMD variants; confirm against
/// `crate::naive::softmax`.
#[cfg(not(any(
    target_arch = "aarch64",
    all(
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    )
)))]
pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
    crate::naive::softmax(data, rows, cols);
}
669
670// ── GELU in-place (no bias) ────────────────────────────────────────────
671
/// GELU activation applied in place with NEON, with no bias term.
///
/// This is the exact (erf) GELU, `0.5·x·(1 + erf(x/√2))`. The erf comes from
/// the Abramowitz & Stegun 7.1.26 rational approximation, with `exp` from
/// `neon_exp4`. Groups of four elements are vectorised; the remainder uses
/// `scalar_gelu`.
#[cfg(target_arch = "aarch64")]
pub fn gelu_inplace(data: &mut [f32]) {
    use std::arch::aarch64::*;
    let mut quads = data.chunks_exact_mut(4);
    // SAFETY: each `quad` is exactly four contiguous f32s, so the 128-bit
    // load and store below stay in bounds. NEON is baseline on aarch64.
    unsafe {
        let one = vdupq_n_f32(1.0);
        let half = vdupq_n_f32(0.5);
        let minus_one = vdupq_n_f32(-1.0);
        let zero = vdupq_n_f32(0.0);
        let rsqrt2 = vdupq_n_f32(std::f32::consts::FRAC_1_SQRT_2);
        let p = vdupq_n_f32(0.3275911);
        // A&S coefficients: a5 seeds Horner, then a4, a3, a2, a1 are folded in.
        let a5 = vdupq_n_f32(1.061_405_4);
        let rest = [
            vdupq_n_f32(-1.453_152_1),
            vdupq_n_f32(1.421_413_8),
            vdupq_n_f32(-0.284_496_72),
            vdupq_n_f32(0.254_829_6),
        ];
        for quad in &mut quads {
            let x = vld1q_f32(quad.as_ptr());
            let z = vmulq_f32(x, rsqrt2);
            let az = vabsq_f32(z);
            // erf is odd: evaluate on |z| and restore the sign afterwards.
            let sgn = vbslq_f32(vcgeq_f32(z, zero), one, minus_one);
            let t = vdivq_f32(one, vfmaq_f32(one, p, az));
            let mut poly = a5;
            for c in rest {
                poly = vfmaq_f32(c, poly, t);
            }
            let poly = vmulq_f32(poly, t);
            let gauss = neon_exp4(vnegq_f32(vmulq_f32(az, az)));
            let erf = vmulq_f32(sgn, vfmsq_f32(one, poly, gauss));
            let y = vmulq_f32(x, vmulq_f32(half, vaddq_f32(one, erf)));
            vst1q_f32(quad.as_mut_ptr(), y);
        }
    }
    for v in quads.into_remainder() {
        *v = scalar_gelu(*v);
    }
}
714
/// GELU activation applied in place with AVX2+FMA, with no bias term.
///
/// This is the exact (erf) GELU using the Abramowitz & Stegun 7.1.26 erf
/// approximation and `avx2_exp8`. Groups of eight elements are vectorised;
/// the remainder uses `scalar_gelu`.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn gelu_inplace(data: &mut [f32]) {
    use std::arch::x86_64::*;
    let mut octets = data.chunks_exact_mut(8);
    // SAFETY: each `octet` is exactly eight contiguous f32s, so the unaligned
    // 256-bit load/store stay in bounds. AVX2/FMA are enabled by cfg.
    unsafe {
        let one = _mm256_set1_ps(1.0);
        let half = _mm256_set1_ps(0.5);
        let minus_one = _mm256_set1_ps(-1.0);
        let zero = _mm256_set1_ps(0.0);
        let rsqrt2 = _mm256_set1_ps(std::f32::consts::FRAC_1_SQRT_2);
        let p = _mm256_set1_ps(0.3275911);
        // Clearing the sign bit yields |v|.
        let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fff_ffff));
        // A&S coefficients: a5 seeds Horner, then a4, a3, a2, a1 are folded in.
        let a5 = _mm256_set1_ps(1.061405429);
        let rest = [
            _mm256_set1_ps(-1.453152027),
            _mm256_set1_ps(1.421413741),
            _mm256_set1_ps(-0.284496736),
            _mm256_set1_ps(0.254829592),
        ];
        for octet in &mut octets {
            let x = _mm256_loadu_ps(octet.as_ptr());
            let z = _mm256_mul_ps(x, rsqrt2);
            let az = _mm256_and_ps(z, abs_mask);
            let non_negative = _mm256_cmp_ps::<_CMP_GE_OQ>(z, zero);
            let sgn = _mm256_blendv_ps(minus_one, one, non_negative);
            let t = _mm256_div_ps(one, _mm256_fmadd_ps(p, az, one));
            let mut poly = a5;
            for c in rest {
                poly = _mm256_fmadd_ps(poly, t, c);
            }
            let poly = _mm256_mul_ps(poly, t);
            let gauss = avx2_exp8(_mm256_sub_ps(zero, _mm256_mul_ps(az, az)));
            // erf = sgn * (1 - poly * exp(-z²))
            let erf = _mm256_mul_ps(sgn, _mm256_fnmadd_ps(poly, gauss, one));
            let y = _mm256_mul_ps(x, _mm256_mul_ps(half, _mm256_add_ps(one, erf)));
            _mm256_storeu_ps(octet.as_mut_ptr(), y);
        }
    }
    for v in octets.into_remainder() {
        *v = scalar_gelu(*v);
    }
}
764
765#[cfg(not(any(
766    target_arch = "aarch64",
767    all(
768        target_arch = "x86_64",
769        target_feature = "avx2",
770        target_feature = "fma"
771    )
772)))]
773pub fn gelu_inplace(data: &mut [f32]) {
774    for v in data.iter_mut() {
775        *v = scalar_gelu(*v);
776    }
777}
778
/// Minimum element count before the `par_*_inplace` activations dispatch
/// work to the thread pool.
///
/// Activation kernels are O(n) with very low per-element cost
/// (~10 NEON cycles on aarch64). Pool dispatch overhead, even with the
/// parked design, is in the multi-µs range under container scheduling. That
/// dwarfs the actual compute for any reasonable activation size.
///
/// The threshold is therefore 1 Mi elements. It is only crossed by very
/// large activation tensors, e.g. an H=4096, FFN=14336, S=1024 LLM
/// up-projection at ~14M elements. Single-thread NEON is the clear win
/// below that.
const ACTIVATION_PAR_MIN: usize = 1 << 20;
790
/// Tanh-approximation GELU, as in candle's `Tensor::gelu` and PyTorch's
/// `GELU(approximate="tanh")`:
///   y = 0.5 x (1 + tanh(√(2/π) · (x + 0.044715 x³)))
///
/// Scalar-only for now; the erf-based `gelu_inplace` is SIMD. Routed from
/// `Activation::GeluApprox` for models whose reference implementation uses
/// the tanh form.
///
/// Note that PyTorch's *default* `nn.GELU()` / `F.gelu` is the exact erf form,
/// which corresponds to `Activation::Gelu`. NOTE(review): confirm which form
/// each model (e.g. DINOv2, ViTs) was trained with before routing it here.
#[inline]
pub fn scalar_gelu_approx(x: f32) -> f32 {
    const C: f32 = 0.797_884_6; // √(2/π)
    const A: f32 = 0.044_715;
    0.5 * x * (1.0 + (C * (x + A * x * x * x)).tanh())
}
805
806pub fn gelu_approx_inplace(data: &mut [f32]) {
807    for v in data.iter_mut() {
808        *v = scalar_gelu_approx(*v);
809    }
810}
811
812pub fn par_gelu_approx_inplace(data: &mut [f32]) {
813    let len = data.len();
814    if len < ACTIVATION_PAR_MIN {
815        gelu_approx_inplace(data);
816        return;
817    }
818    let cfg = crate::config::RuntimeConfig::global();
819    let chunk = 512;
820    let rows = len / chunk;
821    if rows < 2 {
822        gelu_approx_inplace(data);
823        return;
824    }
825    let data_ptr = data.as_mut_ptr() as usize;
826    pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
827        let start = off * chunk;
828        let end = if off + cnt >= rows {
829            len
830        } else {
831            (off + cnt) * chunk
832        };
833        let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
834        gelu_approx_inplace(s);
835    });
836    let done = rows * chunk;
837    if done < len {
838        gelu_approx_inplace(&mut data[done..]);
839    }
840}
841
842pub fn par_gelu_inplace(data: &mut [f32]) {
843    let len = data.len();
844    if len < ACTIVATION_PAR_MIN {
845        gelu_inplace(data);
846        return;
847    }
848    let cfg = crate::config::RuntimeConfig::global();
849    let chunk = 512;
850    let rows = len / chunk;
851    if rows < 2 {
852        gelu_inplace(data);
853        return;
854    }
855    let data_ptr = data.as_mut_ptr() as usize;
856    pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
857        let start = off * chunk;
858        let end = if off + cnt >= rows {
859            len
860        } else {
861            (off + cnt) * chunk
862        };
863        let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
864        gelu_inplace(s);
865    });
866    let done = rows * chunk;
867    if done < len {
868        gelu_inplace(&mut data[done..]);
869    }
870}
871
872/// Parallel SiLU in-place. Same threshold reasoning as `par_gelu_inplace`.
873pub fn par_silu_inplace(data: &mut [f32]) {
874    let len = data.len();
875    if len < ACTIVATION_PAR_MIN {
876        silu_inplace(data);
877        return;
878    }
879    let cfg = crate::config::RuntimeConfig::global();
880    let chunk = 512;
881    let rows = len / chunk;
882    if rows < 2 {
883        silu_inplace(data);
884        return;
885    }
886    let data_ptr = data.as_mut_ptr() as usize;
887    pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
888        let start = off * chunk;
889        let end = if off + cnt >= rows {
890            len
891        } else {
892            (off + cnt) * chunk
893        };
894        let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
895        silu_inplace(s);
896    });
897    let done = rows * chunk;
898    if done < len {
899        silu_inplace(&mut data[done..]);
900    }
901}
902
903// ── Small-m NEON matmul ─────────────────────────────────────────────────
904
/// NEON matmul for small m. Avoids BLAS call overhead.
/// C = A @ B where A=\[m,k\], B=\[k,n\], C=\[m,n\], all row-major; the first
/// `m*n` elements of `c` are overwritten.
/// For m≤8 with small k×n (under ~16K elements), this beats cblas_sgemm
/// by avoiding AMX setup cost and function call overhead. Larger `m` is
/// handled in tiles of 8 rows (one NEON accumulator per row of a tile).
///
/// # Panics
/// Panics if `a`, `b` or `c` is shorter than `m*k`, `k*n` or `m*n` respectively.
#[cfg(target_arch = "aarch64")]
pub fn neon_sgemm_small(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    use std::arch::aarch64::*;

    // Output rows kept live in NEON accumulators per pass.
    const ROW_TILE: usize = 8;

    // These checks are what make the raw-pointer accesses below sound in a safe fn.
    assert!(a.len() >= m * k, "neon_sgemm_small: `a` shorter than m*k");
    assert!(b.len() >= k * n, "neon_sgemm_small: `b` shorter than k*n");
    assert!(c.len() >= m * n, "neon_sgemm_small: `c` shorter than m*n");

    let n4 = n / 4;
    for i0 in (0..m).step_by(ROW_TILE) {
        let rows = ROW_TILE.min(m - i0);
        for j4 in 0..n4 {
            let j = j4 * 4;
            // SAFETY: NEON is a baseline feature on aarch64. With i = i0 + r < m,
            // kk < k and j + 3 < n4 * 4 <= n, the offsets i*k + kk < m*k,
            // kk*n + j + 3 < k*n and i*n + j + 3 < m*n are in bounds by the
            // length asserts above.
            unsafe {
                let mut acc = [vdupq_n_f32(0.0); ROW_TILE];
                for kk in 0..k {
                    let bv = vld1q_f32(b.as_ptr().add(kk * n + j));
                    for (r, acc_r) in acc[..rows].iter_mut().enumerate() {
                        let av = vdupq_n_f32(*a.as_ptr().add((i0 + r) * k + kk));
                        *acc_r = vfmaq_f32(*acc_r, av, bv);
                    }
                }
                for (r, acc_r) in acc[..rows].iter().enumerate() {
                    vst1q_f32(c.as_mut_ptr().add((i0 + r) * n + j), *acc_r);
                }
            }
        }
    }

    // Remainder columns (n % 4) in scalar code.
    for j in (n4 * 4)..n {
        for i in 0..m {
            let mut sum = 0f32;
            for kk in 0..k {
                sum += a[i * k + kk] * b[kk * n + j];
            }
            c[i * n + j] = sum;
        }
    }
}
941
942#[cfg(not(target_arch = "aarch64"))]
943pub fn neon_sgemm_small(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
944    crate::naive::matmul(a, b, c, m, k, n);
945}
946
/// Small-m NEON GEMM with bias: computes `C = A @ B` into `c` with
/// `neon_sgemm_small`, then adds `bias` to the `m × n` result via
/// `crate::blas::bias_add`.
#[cfg(target_arch = "aarch64")]
pub fn neon_sgemm_bias_small(
    a: &[f32],
    b: &[f32],
    bias: &[f32],
    c: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
) {
    use crate::blas::bias_add;
    // Step 1: plain product; step 2: bias applied in place on the output.
    neon_sgemm_small(a, b, c, m, k, n);
    bias_add(c, bias, m, n);
}
961
962#[cfg(not(target_arch = "aarch64"))]
963pub fn neon_sgemm_bias_small(
964    a: &[f32],
965    b: &[f32],
966    bias: &[f32],
967    c: &mut [f32],
968    m: usize,
969    k: usize,
970    n: usize,
971) {
972    crate::naive::matmul(a, b, c, m, k, n);
973    crate::naive::bias_add(c, bias, m, n);
974}
975
976// ── Scalar fallbacks ────────────────────────────────────────────────────
977
978fn scalar_gelu(x: f32) -> f32 {
979    x * 0.5 * (1.0 + scalar_erf(x * std::f32::consts::FRAC_1_SQRT_2))
980}
981
/// Error function via the Abramowitz & Stegun 7.1.26 rational approximation
/// (max absolute error about 1.5e-7). Evaluated on `|x|` and given the sign of
/// `x`, since erf is odd.
fn scalar_erf(x: f32) -> f32 {
    const P: f32 = 0.327_591_1;
    const A1: f32 = 0.254_829_6;
    const A2: f32 = -0.284_496_72;
    const A3: f32 = 1.421_413_8;
    const A4: f32 = -1.453_152_1;
    const A5: f32 = 1.061_405_4;

    let polarity = if x >= 0.0 { 1.0f32 } else { -1.0 };
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    // Horner evaluation of A1 + A2·t + A3·t² + A4·t³ + A5·t⁴, innermost term first.
    let mut poly = A5;
    poly = A4 + t * poly;
    poly = A3 + t * poly;
    poly = A2 + t * poly;
    poly = A1 + t * poly;
    let y = t * poly;
    polarity * (1.0 - y * (-magnitude * magnitude).exp())
}
991
/// Channel-wise LayerNorm over an NCHW tensor (candle / SAM `LayerNorm2d`).
///
/// For each batch item and each of the `h*w` spatial positions, the `channels`
/// values at that position are normalized with their mean and biased variance
/// (divided by `channels`). Each value is then scaled by `gamma[c]` and shifted
/// by `beta[c]`. Element `(b, c, pos)` lives at `(b*channels + c)*h*w + pos`.
pub fn layer_norm2d_nchw(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    eps: f32,
) {
    let plane = h * w;
    for b in 0..batch {
        let batch_base = b * channels * plane;
        for pos in 0..plane {
            // Flat index of channel `ch` at this spatial position.
            let at = |ch: usize| batch_base + ch * plane + pos;
            let mean =
                (0..channels).fold(0.0f32, |acc, ch| acc + input[at(ch)]) / channels as f32;
            let var = (0..channels).fold(0.0f32, |acc, ch| {
                let diff = input[at(ch)] - mean;
                acc + diff * diff
            }) / channels as f32;
            let inv_std = 1.0 / (var + eps).sqrt();
            for ch in 0..channels {
                let normed = (input[at(ch)] - mean) * inv_std;
                output[at(ch)] = normed * gamma[ch] + beta[ch];
            }
        }
    }
}
1027
/// NCHW transposed convolution (PyTorch `ConvTranspose2d`, no bias).
/// Weight layout `[C_in, C_out/groups, kH, kW]`.
///
/// * `input` is `[n, c_in, h, w]`; `output` is `[n, c_out, h_out, w_out]` and is
///   zero-filled, then accumulated into.
/// * Input pixel `(iy, ix)` scattered through kernel tap `(ky, kx)` lands on
///   output `(iy*sh + ky*dh - ph, ix*sw + kx*dw - pw)`. Taps that land outside
///   `[0, h_out) × [0, w_out)` are dropped, which crops the padding and honours
///   the caller-supplied `h_out`/`w_out`.
/// * Input channel `ic` belongs to group `ic / (c_in/groups)` and only feeds
///   that group's `c_out/groups` output channels.
///
/// Input pixels that are exactly `0.0` are skipped. As a result, NaN/inf weights
/// multiplied by a zero input do not propagate as they would in a dense
/// implementation.
pub fn conv_transpose2d_nchw(
    input: &[f32],
    weight: &[f32],
    output: &mut [f32],
    n: usize,
    c_in: usize,
    h: usize,
    w: usize,
    c_out: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw: usize,
    groups: usize,
) {
    debug_assert!(
        groups > 0 && c_in % groups == 0 && c_out % groups == 0,
        "conv_transpose2d_nchw: groups must divide c_in and c_out"
    );
    output.fill(0.0);
    let c_in_per_g = c_in / groups;
    let c_out_per_g = c_out / groups;
    let k_area = kh * kw;
    for ni in 0..n {
        for ic in 0..c_in {
            let g = ic / c_in_per_g;
            let in_base = (ni * c_in + ic) * h * w;
            // Start of weight[ic, .., .., ..] (a [c_out_per_g, kh, kw] block).
            let w_base = ic * c_out_per_g * k_area;
            for iy in 0..h {
                for ix in 0..w {
                    let v = input[in_base + iy * w + ix];
                    if v == 0.0 {
                        continue;
                    }
                    for ky in 0..kh {
                        // Negative (in the top padding) or past the bottom edge: dropped.
                        let Some(oy) = (iy * sh + ky * dh).checked_sub(ph) else {
                            continue;
                        };
                        if oy >= h_out {
                            continue;
                        }
                        for kx in 0..kw {
                            let Some(ox) = (ix * sw + kx * dw).checked_sub(pw) else {
                                continue;
                            };
                            if ox >= w_out {
                                continue;
                            }
                            for oc_off in 0..c_out_per_g {
                                let oc = g * c_out_per_g + oc_off;
                                let wt = weight[w_base + oc_off * k_area + ky * kw + kx];
                                output[((ni * c_out + oc) * h_out + oy) * w_out + ox] += v * wt;
                            }
                        }
                    }
                }
            }
        }
    }
}
1095
/// Group normalization over an NCHW tensor.
///
/// The `channels` axis is split into `num_groups` contiguous groups of
/// `channels / num_groups` channels. Each `(batch, group)` slab of
/// `(C/G)·H·W` values is normalized with its own mean and biased variance, then
/// scaled by `gamma[c]` and shifted by `beta[c]` per channel.
///
/// NOTE(review): if `channels` is not a multiple of `num_groups`, the trailing
/// `channels % num_groups` channels are never written, and `num_groups == 0`
/// panics on the integer division.
pub fn group_norm_nchw(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    num_groups: usize,
    eps: f32,
) {
    let per_group = channels / num_groups;
    let plane_len = h * w;
    let count = (per_group * plane_len) as f32;
    // Half-open element range of channel `ch` of batch item `bi`.
    let plane_range = |bi: usize, ch: usize| {
        let lo = (bi * channels + ch) * plane_len;
        lo..lo + plane_len
    };

    for bi in 0..batch {
        for group in 0..num_groups {
            let first = group * per_group;
            let group_channels = first..first + per_group;

            // Mean: sum each plane, then accumulate the per-plane sums.
            let mut mean = 0.0f32;
            for ch in group_channels.clone() {
                mean += input[plane_range(bi, ch)].iter().sum::<f32>();
            }
            mean /= count;

            // Biased variance, accumulated element by element in memory order.
            let mut var = 0.0f32;
            for ch in group_channels.clone() {
                var = input[plane_range(bi, ch)].iter().fold(var, |acc, &v| {
                    let d = v - mean;
                    acc + d * d
                });
            }
            var /= count;

            let inv_std = 1.0 / (var + eps).sqrt();
            for ch in group_channels {
                let (scale, shift) = (gamma[ch], beta[ch]);
                let src = &input[plane_range(bi, ch)];
                let dst = &mut output[plane_range(bi, ch)];
                for (o, &x) in dst.iter_mut().zip(src) {
                    *o = (x - mean) * inv_std * scale + shift;
                }
            }
        }
    }
}
1148
/// Nearest-neighbor 2× upsample on planar NCHW: each of the `channels`
/// `h × w` planes becomes a `2h × 2w` plane in which every input pixel
/// fills a 2×2 block.
pub fn resize_nearest_2x_nchw(
    input: &[f32],
    output: &mut [f32],
    channels: usize,
    h: usize,
    w: usize,
) {
    let (out_h, out_w) = (2 * h, 2 * w);
    let src_plane = h * w;
    let dst_plane = out_h * out_w;
    for ch in 0..channels {
        let src = &input[ch * src_plane..(ch + 1) * src_plane];
        let dst = &mut output[ch * dst_plane..(ch + 1) * dst_plane];
        for row in 0..h {
            let src_row = &src[row * w..(row + 1) * w];
            // The two output rows produced by this input row are identical.
            let (upper, lower) =
                dst[2 * row * out_w..(2 * row + 2) * out_w].split_at_mut(out_w);
            for (pair, &value) in upper.chunks_exact_mut(2).zip(src_row) {
                pair.fill(value);
            }
            lower.copy_from_slice(upper);
        }
    }
}
1174
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest element-wise absolute difference between two slices.
    fn max_abs_diff(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter()
            .zip(rhs)
            .map(|(l, r)| (l - r).abs())
            .fold(0f32, f32::max)
    }

    #[test]
    fn gelu_correctness() {
        // Reference: gelu(1.5) ≈ 1.3990
        let g = scalar_gelu(1.5f32);
        assert!((g - 1.3990).abs() < 0.01, "gelu(1.5) = {g}");
    }

    #[test]
    fn bias_gelu_works() {
        let (rows, cols) = (2, 4);
        let mut data: Vec<f32> = (1..=8).map(|i| i as f32).collect();
        let bias = vec![0.1, 0.2, 0.3, 0.4];
        bias_gelu(&mut data, &bias, rows, cols);
        // Every biased input is positive, so GELU must keep it positive.
        for &v in data.iter() {
            assert!(v > 0.0, "bias_gelu produced {v}");
        }
    }

    #[test]
    fn layer_norm_unit_test() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let (gamma, beta) = (vec![1.0; 4], vec![0.0; 4]);
        let mut output = vec![0.0; 4];
        layer_norm_row(&input, &gamma, &beta, &mut output, 4, 1e-5);
        // Mean = 2.5, std ≈ 1.118 → output ≈ [-1.342, -0.447, 0.447, 1.342]
        assert!((output[0] + 1.342).abs() < 0.01);
        assert!((output[3] - 1.342).abs() < 0.01);
        // A normalized row sums to ~0.
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() < 0.01, "LN sum should be ~0, got {sum}");
    }

    #[test]
    fn par_bias_gelu_matches_sequential() {
        let (n, m) = (100, 64);
        let bias = vec![0.1f32; m];
        let mut data_seq = vec![0.5f32; n * m];
        let mut data_par = data_seq.clone();

        bias_gelu(&mut data_seq, &bias, n, m);
        par_bias_gelu(&mut data_par, &bias, n, m);

        let max_diff = max_abs_diff(&data_par, &data_seq);
        assert!(max_diff < 1e-6, "par vs seq diff: {max_diff}");
    }
}