Skip to main content

rlx_cpu/
kernels.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SIMD kernels for fused operations.
17//!
18//! These are the production kernels extracted from burnembed's ndarray_fused.rs.
19//! Each kernel processes data in-place or into a pre-allocated output buffer
20//! (from the arena). No allocation.
21
22use crate::pool;
23
24// ── NEON vectorized exp ─────────────────────────────────────────────────
25
/// Vectorized `exp(x)` over the four f32 lanes of a NEON register.
///
/// The argument is clamped to `[-87.3, 88.7]`, decomposed as `x = n·ln2 + r`
/// with `n` rounded to the nearest integer (ln2 is subtracted as a hi/lo pair
/// to keep `r` accurate), `exp(r)` is evaluated with a 6th-order Taylor
/// polynomial in Horner form, and the result is scaled by `2^n` by adding
/// `n << 23` to the bit pattern of the polynomial value.
/// Stated max relative error: ~2e-7 across [-87, 88].
///
/// # Safety
///
/// The body only operates on register values through NEON intrinsics and does
/// not touch memory; the caller must be on a target where those intrinsics may
/// be executed.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
// The whole body is intrinsic calls; wrapping each one in `unsafe {}` adds noise only.
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn neon_exp4(x: std::arch::aarch64::float32x4_t) -> std::arch::aarch64::float32x4_t {
    use std::arch::aarch64::*;

    // Clamp so that the 2^n scaling below stays a finite, normal f32.
    let lower = vdupq_n_f32(-87.3);
    let upper = vdupq_n_f32(88.7);
    let arg = vminq_f32(vmaxq_f32(x, lower), upper);

    // Range reduction: k = round(arg / ln2), rem = arg - k·ln2_hi - k·ln2_lo.
    let k = vrndnq_f32(vmulq_f32(arg, vdupq_n_f32(std::f32::consts::LOG2_E)));
    let rem_hi = vfmsq_f32(arg, k, vdupq_n_f32(0.693_145_75));
    let rem = vfmsq_f32(rem_hi, k, vdupq_n_f32(1.428_606_8e-6));

    // Horner evaluation of 1 + r + r²/2 + r³/6 + r⁴/24 + r⁵/120 + r⁶/720.
    let unit = vdupq_n_f32(1.0);
    let mut poly = vdupq_n_f32(0.001_388_888_9);
    poly = vfmaq_f32(vdupq_n_f32(0.008_333_334), poly, rem);
    poly = vfmaq_f32(vdupq_n_f32(0.041_666_668), poly, rem);
    poly = vfmaq_f32(vdupq_n_f32(0.166_666_67), poly, rem);
    poly = vfmaq_f32(vdupq_n_f32(0.5), poly, rem);
    poly = vfmaq_f32(unit, poly, rem);
    poly = vfmaq_f32(unit, poly, rem);

    // Multiply by 2^k: add k to the biased exponent field of `poly`.
    let exponent_bump = vshlq_n_s32(vcvtq_s32_f32(k), 23);
    vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(poly), exponent_bump))
}
51
/// AVX2+FMA vectorised `exp(x)` for 8 floats. Same range reduction +
/// 6th-order Taylor polynomial as `neon_exp4`. Max relative error
/// stays in the ~2e-7 range. Requires `+avx2,+fma` codegen.
///
/// Inputs are clamped to `[-87.3, 88.7]`, so ±∞ saturate to the clamp bounds.
/// NaN lanes propagate to the output (see the clamp comment below), which is
/// the same behaviour the NEON kernel gets from `vmaxq_f32`/`vminq_f32`.
///
/// # Safety
///
/// The body only operates on register values and does not touch memory. The
/// function is compiled only when `avx2` and `fma` are enabled at build time,
/// so the caller must run it on a CPU that actually provides both.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
#[inline(always)]
// The whole body is intrinsic calls; wrapping each one in `unsafe {}` adds noise only.
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn avx2_exp8(x: std::arch::x86_64::__m256) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::*;
    // MAXPS/MINPS return their SECOND operand whenever either lane is NaN.
    // Keeping `x` as the second operand lets a NaN survive the clamp instead
    // of being silently replaced by the bound (which would yield exp(-87.3)).
    let x = _mm256_max_ps(_mm256_set1_ps(-87.3), x);
    let x = _mm256_min_ps(_mm256_set1_ps(88.7), x);
    let inv_ln2 = _mm256_set1_ps(1.442695040888963);
    let ln2_hi = _mm256_set1_ps(0.693145751953125);
    let ln2_lo = _mm256_set1_ps(1.428606765330187e-6);
    // n = round(x / ln2)  (round-to-nearest-even)
    let n = _mm256_round_ps::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(_mm256_mul_ps(
        x, inv_ln2,
    ));
    // r = x − n·ln2_hi − n·ln2_lo
    let r = _mm256_fnmadd_ps(n, ln2_lo, _mm256_fnmadd_ps(n, ln2_hi, x));
    // Horner evaluation of 1 + r + r²/2 + r³/6 + r⁴/24 + r⁵/120 + r⁶/720.
    let c1 = _mm256_set1_ps(1.0);
    let mut p = _mm256_set1_ps(0.001388888888888889);
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(0.008333333333333333));
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(0.041666666666666664));
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(0.16666666666666666));
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(0.5));
    p = _mm256_fmadd_ps(p, r, c1);
    p = _mm256_fmadd_ps(p, r, c1);
    // 2^n via integer-bias trick on the f32 exponent field. For a NaN lane the
    // conversion yields 0x8000_0000, which shifts out to 0 and leaves the NaN
    // in `p` untouched.
    let ni = _mm256_cvtps_epi32(n);
    let shifted = _mm256_slli_epi32::<23>(ni);
    _mm256_castsi256_ps(_mm256_add_epi32(_mm256_castps_si256(p), shifted))
}
88
89// ── Fused bias + GELU ───────────────────────────────────────────────────
90
/// Fused bias addition + GELU activation on a row-major `[m, n]` buffer, in place.
///
/// For every row, `bias[j]` is added to column `j` and the erf-form GELU
/// `0.5 · x · (1 + erf(x / √2))` is applied. erf uses the Abramowitz & Stegun
/// polynomial approximation with the exponential evaluated by `neon_exp4`.
/// Columns left over after the 4-wide loop go through `scalar_gelu`.
///
/// # Panics
///
/// Panics if `m * n` overflows, if `data` holds fewer than `m * n` elements,
/// or if `bias` holds fewer than `n` elements. Nothing is checked (and nothing
/// is done) when `m == 0` or `n == 0`.
#[cfg(target_arch = "aarch64")]
pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
    use std::arch::aarch64::*;
    if m == 0 || n == 0 {
        return;
    }
    // The vector loop below uses raw pointer loads/stores, so the extents must
    // be validated once up front to keep this safe function sound.
    let total = m.checked_mul(n).expect("bias_gelu: m * n overflows usize");
    assert!(data.len() >= total, "bias_gelu: data holds fewer than m * n elements");
    assert!(bias.len() >= n, "bias_gelu: bias holds fewer than n elements");

    let chunks = n / 4;
    // SAFETY: every vector access touches `data[off..off + 4]` with
    // `off + 4 <= row * n + n <= m * n <= data.len()` and
    // `bias[c * 4..c * 4 + 4]` with `c * 4 + 4 <= n <= bias.len()`.
    unsafe {
        let half = vdupq_n_f32(0.5);
        let one = vdupq_n_f32(1.0);
        let inv_sqrt2 = vdupq_n_f32(std::f32::consts::FRAC_1_SQRT_2);
        // Abramowitz & Stegun erf coefficients.
        let p = vdupq_n_f32(0.3275911);
        let a1 = vdupq_n_f32(0.254_829_6);
        let a2 = vdupq_n_f32(-0.284_496_72);
        let a3 = vdupq_n_f32(1.421_413_8);
        let a4 = vdupq_n_f32(-1.453_152_1);
        let a5 = vdupq_n_f32(1.061_405_4);
        let neg_one = vdupq_n_f32(-1.0);
        let zero = vdupq_n_f32(0.0);

        for row in 0..m {
            let base = row * n;
            for c in 0..chunks {
                let off = base + c * 4;
                let ptr = data.as_mut_ptr().add(off);
                let x = vaddq_f32(vld1q_f32(ptr), vld1q_f32(bias.as_ptr().add(c * 4)));
                let erf_arg = vmulq_f32(x, inv_sqrt2);
                let xa = vabsq_f32(erf_arg);
                // sign = (erf_arg >= 0) ? 1 : -1; erf is odd, so it is
                // evaluated on |erf_arg| and the sign is reapplied afterwards.
                let sign = vbslq_f32(vcgeq_f32(erf_arg, zero), one, neg_one);
                // t = 1 / (1 + p·|x|)
                let denom = vfmaq_f32(one, p, xa);
                let t = vdivq_f32(one, denom);
                // y = t·(a1 + t·(a2 + t·(a3 + t·(a4 + t·a5))))
                let mut y = a5;
                y = vfmaq_f32(a4, y, t);
                y = vfmaq_f32(a3, y, t);
                y = vfmaq_f32(a2, y, t);
                y = vfmaq_f32(a1, y, t);
                y = vmulq_f32(y, t);
                let exp_val = neon_exp4(vnegq_f32(vmulq_f32(xa, xa)));
                // erf = sign · (1 − y·exp(−xa²))
                let erf_val = vmulq_f32(sign, vfmsq_f32(one, y, exp_val));
                vst1q_f32(ptr, vmulq_f32(x, vmulq_f32(half, vaddq_f32(one, erf_val))));
            }
            // Scalar tail for the last `n % 4` columns of the row.
            for i in (chunks * 4)..n {
                let x = data[base + i] + bias[i];
                data[base + i] = scalar_gelu(x);
            }
        }
    }
}
138
/// Fused bias addition + GELU activation on a row-major `[m, n]` buffer, in
/// place (AVX2+FMA build).
///
/// For every row, `bias[j]` is added to column `j` and the erf-form GELU
/// `0.5 · x · (1 + erf(x / √2))` is applied. erf uses the Abramowitz & Stegun
/// polynomial approximation with the exponential evaluated by `avx2_exp8`.
/// Columns left over after the 8-wide loop go through `scalar_gelu`.
///
/// # Panics
///
/// Panics if `m * n` overflows, if `data` holds fewer than `m * n` elements,
/// or if `bias` holds fewer than `n` elements. Nothing is checked (and nothing
/// is done) when `m == 0` or `n == 0`.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
    use std::arch::x86_64::*;
    if m == 0 || n == 0 {
        return;
    }
    // The vector loop below uses raw pointer loads/stores, so the extents must
    // be validated once up front to keep this safe function sound.
    let total = m.checked_mul(n).expect("bias_gelu: m * n overflows usize");
    assert!(data.len() >= total, "bias_gelu: data holds fewer than m * n elements");
    assert!(bias.len() >= n, "bias_gelu: bias holds fewer than n elements");

    let chunks = n / 8;
    // SAFETY: every vector access touches `data[off..off + 8]` with
    // `off + 8 <= row * n + n <= m * n <= data.len()` and
    // `bias[c * 8..c * 8 + 8]` with `c * 8 + 8 <= n <= bias.len()`; the
    // unaligned load/store intrinsics impose no alignment requirement.
    unsafe {
        let half = _mm256_set1_ps(0.5);
        let one = _mm256_set1_ps(1.0);
        let inv_sqrt2 = _mm256_set1_ps(std::f32::consts::FRAC_1_SQRT_2);
        // Abramowitz & Stegun erf coefficients.
        let p = _mm256_set1_ps(0.3275911);
        let a1 = _mm256_set1_ps(0.254829592);
        let a2 = _mm256_set1_ps(-0.284496736);
        let a3 = _mm256_set1_ps(1.421413741);
        let a4 = _mm256_set1_ps(-1.453152027);
        let a5 = _mm256_set1_ps(1.061405429);
        let neg_one = _mm256_set1_ps(-1.0);
        let zero = _mm256_set1_ps(0.0);
        // Sign bit mask for fabs via AND with 0x7fffffff.
        let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fff_ffff));

        for row in 0..m {
            let base = row * n;
            for c in 0..chunks {
                let off = base + c * 8;
                let ptr = data.as_mut_ptr().add(off);
                let x = _mm256_add_ps(
                    _mm256_loadu_ps(ptr),
                    _mm256_loadu_ps(bias.as_ptr().add(c * 8)),
                );
                let erf_arg = _mm256_mul_ps(x, inv_sqrt2);
                let xa = _mm256_and_ps(erf_arg, abs_mask);
                // sign = (erf_arg >= 0) ? 1 : -1
                let ge0 = _mm256_cmp_ps::<_CMP_GE_OQ>(erf_arg, zero);
                let sign = _mm256_blendv_ps(neg_one, one, ge0);
                // t = 1 / (1 + p·|x|)
                let denom = _mm256_fmadd_ps(p, xa, one);
                let t = _mm256_div_ps(one, denom);
                // y = t·(a1 + t·(a2 + t·(a3 + t·(a4 + t·a5))))
                let mut y = a5;
                y = _mm256_fmadd_ps(y, t, a4);
                y = _mm256_fmadd_ps(y, t, a3);
                y = _mm256_fmadd_ps(y, t, a2);
                y = _mm256_fmadd_ps(y, t, a1);
                y = _mm256_mul_ps(y, t);
                let exp_val = avx2_exp8(_mm256_sub_ps(zero, _mm256_mul_ps(xa, xa)));
                // erf = sign * (1 - y*exp(-xa^2))
                let erf_val = _mm256_mul_ps(sign, _mm256_fnmadd_ps(y, exp_val, one));
                _mm256_storeu_ps(
                    ptr,
                    _mm256_mul_ps(x, _mm256_mul_ps(half, _mm256_add_ps(one, erf_val))),
                );
            }
            // Scalar tail for the last `n % 8` columns of the row.
            for i in (chunks * 8)..n {
                let x = data[base + i] + bias[i];
                data[base + i] = scalar_gelu(x);
            }
        }
    }
}
199
200#[cfg(not(any(
201    target_arch = "aarch64",
202    all(
203        target_arch = "x86_64",
204        target_feature = "avx2",
205        target_feature = "fma"
206    )
207)))]
208pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
209    for row in 0..m {
210        let base = row * n;
211        for i in 0..n {
212            let x = data[base + i] + bias[i];
213            data[base + i] = scalar_gelu(x);
214        }
215    }
216}
217
218/// Parallel bias + GELU across thread pool.
219pub fn par_bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
220    let cfg = crate::config::RuntimeConfig::global();
221    if m * n < cfg.par_threshold || m < cfg.min_rows_per_thread {
222        bias_gelu(data, bias, m, n);
223        return;
224    }
225    let data_ptr = data.as_mut_ptr() as usize;
226    let bias_ptr = bias.as_ptr() as usize;
227    pool::par_for(m, cfg.min_rows_per_thread, &|off, cnt| unsafe {
228        let d = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(off * n), cnt * n);
229        let b = std::slice::from_raw_parts(bias_ptr as *const f32, n);
230        bias_gelu(d, b, cnt, n);
231    });
232}
233
234// ── Fused SiLU ──────────────────────────────────────────────────────────
235
/// SiLU (Swish) applied in place: `x / (1 + exp(-x))`.
///
/// Full groups of four elements are computed as `x · (1 / (1 + exp(-x)))` with
/// `neon_exp4`; the remaining `len % 4` elements use the scalar `f32::exp`.
#[cfg(target_arch = "aarch64")]
pub fn silu_inplace(data: &mut [f32]) {
    use std::arch::aarch64::*;
    let mut quads = data.chunks_exact_mut(4);
    // SAFETY: each `quad` is exactly four contiguous, exclusively borrowed
    // f32s, which is precisely what one 128-bit load/store touches.
    unsafe {
        let ones = vdupq_n_f32(1.0);
        for quad in quads.by_ref() {
            let dst = quad.as_mut_ptr();
            let v = vld1q_f32(dst);
            let decay = neon_exp4(vnegq_f32(v));
            let gate = vdivq_f32(ones, vaddq_f32(ones, decay));
            vst1q_f32(dst, vmulq_f32(v, gate));
        }
    }
    // Scalar tail.
    for slot in quads.into_remainder() {
        let x = *slot;
        *slot = x / (1.0 + (-x).exp());
    }
}
256
/// SiLU (Swish) applied in place: `x / (1 + exp(-x))` (AVX2+FMA build).
///
/// Full groups of eight elements use `avx2_exp8` and a vector divide; the
/// remaining `len % 8` elements use the scalar `f32::exp`.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn silu_inplace(data: &mut [f32]) {
    use std::arch::x86_64::*;
    let mut octets = data.chunks_exact_mut(8);
    // SAFETY: each `octet` is exactly eight contiguous, exclusively borrowed
    // f32s; the unaligned load/store intrinsics need no particular alignment.
    unsafe {
        let ones = _mm256_set1_ps(1.0);
        let zeros = _mm256_set1_ps(0.0);
        for octet in octets.by_ref() {
            let dst = octet.as_mut_ptr();
            let v = _mm256_loadu_ps(dst);
            // 1 + exp(0 - v)
            let decay = avx2_exp8(_mm256_sub_ps(zeros, v));
            let divisor = _mm256_add_ps(ones, decay);
            _mm256_storeu_ps(dst, _mm256_div_ps(v, divisor));
        }
    }
    // Scalar tail.
    for slot in octets.into_remainder() {
        let x = *slot;
        *slot = x / (1.0 + (-x).exp());
    }
}
283
/// Portable SiLU (Swish) applied in place: `x / (1 + exp(-x))`.
///
/// Compiled when neither the aarch64 nor the x86_64 AVX2+FMA kernel applies.
#[cfg(not(any(
    target_arch = "aarch64",
    all(
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    )
)))]
pub fn silu_inplace(data: &mut [f32]) {
    data.iter_mut().for_each(|slot| {
        let x = *slot;
        *slot = x / (1.0 + (-x).exp());
    });
}
298
299// ── LayerNorm (2-pass) ──────────────────────────────────────────────────
300
/// Single-row LayerNorm: out = (x - mean) * inv_std * gamma + beta.
/// 2-pass: compute mean+variance (E\[x²\]-E\[x\]²), then normalize.
///
/// Only the first `h` elements of `input`, `gamma`, `beta` and `output` are
/// used. The variance is clamped at zero before `eps` is added, because the
/// E\[x²\]-E\[x\]² form can come out slightly negative from rounding.
///
/// # Panics
///
/// Panics if any of the four slices holds fewer than `h` elements.
#[cfg(target_arch = "aarch64")]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    use std::arch::aarch64::*;
    // The vector loops use raw pointer loads/stores, so the extents must be
    // validated once up front to keep this safe function sound.
    assert!(input.len() >= h, "layer_norm_row: input holds fewer than h elements");
    assert!(gamma.len() >= h, "layer_norm_row: gamma holds fewer than h elements");
    assert!(beta.len() >= h, "layer_norm_row: beta holds fewer than h elements");
    assert!(output.len() >= h, "layer_norm_row: output holds fewer than h elements");

    let inv_hf = 1.0 / h as f32;
    let chunks = h / 4;
    // SAFETY: every vector access covers `[c * 4, c * 4 + 4)` with
    // `c * 4 + 4 <= h`, and all four slices were checked to hold at least `h`
    // elements.
    unsafe {
        // Pass 1: accumulate Σx and Σx² four lanes at a time.
        let mut vsum = vdupq_n_f32(0.0);
        let mut vsumsq = vdupq_n_f32(0.0);
        for c in 0..chunks {
            let x = vld1q_f32(input.as_ptr().add(c * 4));
            vsum = vaddq_f32(vsum, x);
            vsumsq = vfmaq_f32(vsumsq, x, x);
        }
        let mut sum = vaddvq_f32(vsum);
        let mut sumsq = vaddvq_f32(vsumsq);
        for i in (chunks * 4)..h {
            sum += input[i];
            sumsq += input[i] * input[i];
        }
        let mean = sum * inv_hf;
        let var = (sumsq * inv_hf - mean * mean).max(0.0);
        let inv = 1.0 / (var + eps).sqrt();
        // Pass 2: normalize, then scale and shift.
        let vmean = vdupq_n_f32(mean);
        let vinv = vdupq_n_f32(inv);
        for c in 0..chunks {
            let off = c * 4;
            let x = vld1q_f32(input.as_ptr().add(off));
            let norm = vmulq_f32(vsubq_f32(x, vmean), vinv);
            // beta + norm · gamma
            vst1q_f32(
                output.as_mut_ptr().add(off),
                vfmaq_f32(
                    vld1q_f32(beta.as_ptr().add(off)),
                    norm,
                    vld1q_f32(gamma.as_ptr().add(off)),
                ),
            );
        }
        for i in (chunks * 4)..h {
            output[i] = (input[i] - mean) * inv * gamma[i] + beta[i];
        }
    }
}
352
/// Single-row LayerNorm: out = (x - mean) * inv_std * gamma + beta
/// (AVX2+FMA build).
/// 2-pass: compute mean+variance (E\[x²\]-E\[x\]²), then normalize.
///
/// Only the first `h` elements of `input`, `gamma`, `beta` and `output` are
/// used. The variance is clamped at zero before `eps` is added, because the
/// E\[x²\]-E\[x\]² form can come out slightly negative from rounding.
///
/// # Panics
///
/// Panics if any of the four slices holds fewer than `h` elements.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    use std::arch::x86_64::*;
    // The vector loops use raw pointer loads/stores, so the extents must be
    // validated once up front to keep this safe function sound.
    assert!(input.len() >= h, "layer_norm_row: input holds fewer than h elements");
    assert!(gamma.len() >= h, "layer_norm_row: gamma holds fewer than h elements");
    assert!(beta.len() >= h, "layer_norm_row: beta holds fewer than h elements");
    assert!(output.len() >= h, "layer_norm_row: output holds fewer than h elements");

    let inv_hf = 1.0 / h as f32;
    let chunks = h / 8;
    // SAFETY: every vector access covers `[c * 8, c * 8 + 8)` with
    // `c * 8 + 8 <= h`, and all four slices were checked to hold at least `h`
    // elements; the unaligned load/store intrinsics need no alignment.
    unsafe {
        // Pass 1: accumulate Σx and Σx² eight lanes at a time.
        let mut vsum = _mm256_setzero_ps();
        let mut vsumsq = _mm256_setzero_ps();
        for c in 0..chunks {
            let x = _mm256_loadu_ps(input.as_ptr().add(c * 8));
            vsum = _mm256_add_ps(vsum, x);
            vsumsq = _mm256_fmadd_ps(x, x, vsumsq);
        }
        // Horizontal reduce: 8 lanes → 1 (fold 256→128, then 4→2, then 2→1).
        let hsum = {
            let lo = _mm256_castps256_ps128(vsum);
            let hi = _mm256_extractf128_ps::<1>(vsum);
            let s4 = _mm_add_ps(lo, hi);
            let s2 = _mm_add_ps(s4, _mm_movehl_ps(s4, s4));
            let s1 = _mm_add_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
            _mm_cvtss_f32(s1)
        };
        let hsumsq = {
            let lo = _mm256_castps256_ps128(vsumsq);
            let hi = _mm256_extractf128_ps::<1>(vsumsq);
            let s4 = _mm_add_ps(lo, hi);
            let s2 = _mm_add_ps(s4, _mm_movehl_ps(s4, s4));
            let s1 = _mm_add_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
            _mm_cvtss_f32(s1)
        };
        let mut sum = hsum;
        let mut sumsq = hsumsq;
        for i in (chunks * 8)..h {
            sum += input[i];
            sumsq += input[i] * input[i];
        }
        let mean = sum * inv_hf;
        let var = (sumsq * inv_hf - mean * mean).max(0.0);
        let inv = 1.0 / (var + eps).sqrt();
        // Pass 2: normalize, then scale and shift.
        let vmean = _mm256_set1_ps(mean);
        let vinv = _mm256_set1_ps(inv);
        for c in 0..chunks {
            let off = c * 8;
            let x = _mm256_loadu_ps(input.as_ptr().add(off));
            let norm = _mm256_mul_ps(_mm256_sub_ps(x, vmean), vinv);
            let g = _mm256_loadu_ps(gamma.as_ptr().add(off));
            let b = _mm256_loadu_ps(beta.as_ptr().add(off));
            // norm · gamma + beta
            _mm256_storeu_ps(output.as_mut_ptr().add(off), _mm256_fmadd_ps(norm, g, b));
        }
        for i in (chunks * 8)..h {
            output[i] = (input[i] - mean) * inv * gamma[i] + beta[i];
        }
    }
}
418
/// Portable single-row LayerNorm: `out = (x - mean) * inv_std * gamma + beta`.
///
/// Compiled when neither the aarch64 nor the x86_64 AVX2+FMA kernel applies.
/// The first pass accumulates Σx and Σx² over the first `h` inputs; the
/// variance is taken as E\[x²\]-E\[x\]², clamped at zero, and `eps` is added
/// before the reciprocal square root. Slice indexing panics when a buffer is
/// shorter than `h`.
#[cfg(not(any(
    target_arch = "aarch64",
    all(
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    )
)))]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    let scale = 1.0 / h as f32;

    // Pass 1: running sums, accumulated in element order.
    let (mut total, mut total_sq) = (0f32, 0f32);
    for &x in &input[..h] {
        total += x;
        total_sq += x * x;
    }

    let mu = total * scale;
    let variance = (total_sq * scale - mu * mu).max(0.0);
    let rstd = 1.0 / (variance + eps).sqrt();

    // Pass 2: normalize, scale, shift.
    for k in 0..h {
        output[k] = (input[k] - mu) * rstd * gamma[k] + beta[k];
    }
}
449
/// Inference-mode BatchNorm using frozen running statistics (PyTorch
/// `BatchNorm*d` in eval mode).
///
/// `x` is row-major with the feature dimension `channels` on the last axis
/// (`[B, C]`, `[B, P, C]`, …); `gamma`, `beta`, `mean`, `var` are indexed by
/// channel. Writes `out = gamma * (x - mean) / sqrt(var + eps) + beta`.
/// Only `x.len() / channels` complete rows are processed; trailing elements
/// that do not fill a row are left untouched, and `channels == 0` is a no-op.
pub fn batch_norm_inference(
    x: &[f32],
    gamma: &[f32],
    beta: &[f32],
    mean: &[f32],
    var: &[f32],
    out: &mut [f32],
    channels: usize,
    eps: f32,
) {
    // `max(1)` only guards the division; with zero channels the inner loop is empty.
    let rows = x.len() / channels.max(1);
    for row in 0..rows {
        let row_start = row * channels;
        for ch in 0..channels {
            let at = row_start + ch;
            let rstd = (var[ch] + eps).sqrt().recip();
            let normalized = (x[at] - mean[ch]) * rstd;
            out[at] = gamma[ch] * normalized + beta[ch];
        }
    }
}
474
/// `d_x` for [`batch_norm_inference`] (mean/var treated as constants).
///
/// Writes `dx = dy * gamma / sqrt(var + eps)` per element. `x` is only used
/// for its length (to derive the row count) and `_mean` is unused.
pub fn batch_norm_inference_backward_input(
    x: &[f32],
    gamma: &[f32],
    _mean: &[f32],
    var: &[f32],
    dy: &[f32],
    dx: &mut [f32],
    channels: usize,
    eps: f32,
) {
    // `max(1)` only guards the division; with zero channels the inner loop is empty.
    let rows = x.len() / channels.max(1);
    for row in 0..rows {
        let row_start = row * channels;
        for ch in 0..channels {
            let at = row_start + ch;
            let rstd = (var[ch] + eps).sqrt().recip();
            dx[at] = dy[at] * gamma[ch] * rstd;
        }
    }
}
495
/// `d_gamma` for [`batch_norm_inference`].
///
/// Zeroes `dgamma`, then accumulates `dy * (x - mean) / sqrt(var + eps)` into
/// the entry of each element's channel, walking the rows in order.
pub fn batch_norm_inference_backward_gamma(
    x: &[f32],
    mean: &[f32],
    var: &[f32],
    dy: &[f32],
    dgamma: &mut [f32],
    channels: usize,
    eps: f32,
) {
    for slot in dgamma.iter_mut() {
        *slot = 0.0;
    }
    // `max(1)` only guards the division; with zero channels the inner loop is empty.
    let rows = x.len() / channels.max(1);
    for row in 0..rows {
        let row_start = row * channels;
        for ch in 0..channels {
            let at = row_start + ch;
            let rstd = (var[ch] + eps).sqrt().recip();
            let normalized = (x[at] - mean[ch]) * rstd;
            dgamma[ch] += dy[at] * normalized;
        }
    }
}
517
/// `d_beta` for [`batch_norm_inference`].
///
/// Zeroes `dbeta`, then sums `dy` over all complete rows into the entry of
/// each element's channel.
pub fn batch_norm_inference_backward_beta(dy: &[f32], dbeta: &mut [f32], channels: usize) {
    dbeta.fill(0.0);
    // `max(1)` only guards the division; with zero channels every row is empty.
    let rows = dy.len() / channels.max(1);
    for row_start in (0..rows).map(|row| row * channels) {
        let grad_row = &dy[row_start..row_start + channels];
        for (ch, &g) in grad_row.iter().enumerate() {
            dbeta[ch] += g;
        }
    }
}
528
529/// Fused residual + bias + LayerNorm on [n, h] buffers.
530/// Computes: output\[row\] = LN(a\[row\] + b\[row\] + bias, gamma, beta)
531pub fn residual_bias_layer_norm(
532    a: &[f32],
533    b: &[f32],
534    bias: &[f32],
535    gamma: &[f32],
536    beta: &[f32],
537    output: &mut [f32],
538    n: usize,
539    h: usize,
540    eps: f32,
541) {
542    // Temporary per-row buffer for a+b+bias (stack allocated for small h)
543    let mut tmp = vec![0f32; h];
544    for row in 0..n {
545        let base = row * h;
546        for i in 0..h {
547            tmp[i] = a[base + i] + b[base + i] + bias[i];
548        }
549        layer_norm_row(&tmp, gamma, beta, &mut output[base..base + h], h, eps);
550    }
551}
552
/// Fused residual + bias + RMSNorm over row-major `[n, h]` buffers.
/// Computes: output\[row\] = RmsNorm(a\[row\] + b\[row\] + bias, gamma, beta)
///
/// With `v = a + b + bias`, each row is scaled by
/// `1 / sqrt(mean(v²) + eps)`, multiplied by `gamma` and shifted by `beta`.
/// The fused sum is recomputed in the second pass instead of being stored, so
/// no scratch buffer is needed.
pub fn residual_bias_rms_norm(
    a: &[f32],
    b: &[f32],
    bias: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    n: usize,
    h: usize,
    eps: f32,
) {
    let inv_h = 1.0 / h as f32;
    for row in 0..n {
        let start = row * h;
        let fused = |col: usize| a[start + col] + b[start + col] + bias[col];

        // Pass 1: Σv² in column order.
        let sum_sq = (0..h).fold(0f32, |acc, col| {
            let v = fused(col);
            acc + v * v
        });
        let inv_rms = (sum_sq * inv_h + eps).sqrt().recip();

        // Pass 2: scale and shift.
        for col in 0..h {
            let v = fused(col);
            output[start + col] = v * inv_rms * gamma[col] + beta[col];
        }
    }
}
581
582/// Parallel residual + bias + LayerNorm.
583pub fn par_residual_bias_ln(
584    a: &[f32],
585    b: &[f32],
586    bias: &[f32],
587    gamma: &[f32],
588    beta: &[f32],
589    output: &mut [f32],
590    n: usize,
591    h: usize,
592    eps: f32,
593) {
594    let cfg = crate::config::RuntimeConfig::global();
595    if n * h < cfg.par_threshold || n < cfg.min_rows_per_thread {
596        residual_bias_layer_norm(a, b, bias, gamma, beta, output, n, h, eps);
597        return;
598    }
599    let a_ptr = a.as_ptr() as usize;
600    let b_ptr = b.as_ptr() as usize;
601    let o_ptr = output.as_mut_ptr() as usize;
602    let bias_ptr = bias.as_ptr() as usize;
603    let gamma_ptr = gamma.as_ptr() as usize;
604    let beta_ptr = beta.as_ptr() as usize;
605    pool::par_for(n, cfg.min_rows_per_thread, &|off, cnt| unsafe {
606        let a_s = std::slice::from_raw_parts((a_ptr as *const f32).add(off * h), cnt * h);
607        let b_s = std::slice::from_raw_parts((b_ptr as *const f32).add(off * h), cnt * h);
608        let o_s = std::slice::from_raw_parts_mut((o_ptr as *mut f32).add(off * h), cnt * h);
609        let bi = std::slice::from_raw_parts(bias_ptr as *const f32, h);
610        let g = std::slice::from_raw_parts(gamma_ptr as *const f32, h);
611        let be = std::slice::from_raw_parts(beta_ptr as *const f32, h);
612        residual_bias_layer_norm(a_s, b_s, bi, g, be, o_s, cnt, h, eps);
613    });
614}
615
616// ── NEON softmax ────────────────────────────────────────────────────────
617
/// NEON-vectorized softmax: 3-pass (max, exp+sum, normalize).
///
/// Operates in place on a row-major `[rows, cols]` buffer, one row at a time.
/// The row maximum is subtracted before exponentiation for numerical
/// stability. Full groups of four columns use `neon_exp4`; the remaining
/// `cols % 4` columns use the scalar `f32::exp`.
///
/// # Panics
///
/// Panics if `rows * cols` overflows or `data` holds fewer than `rows * cols`
/// elements.
#[cfg(target_arch = "aarch64")]
pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
    use std::arch::aarch64::*;
    // All element access below goes through raw pointers, so the extent must
    // be validated once up front to keep this safe function sound.
    let total = rows
        .checked_mul(cols)
        .expect("neon_softmax: rows * cols overflows usize");
    assert!(
        data.len() >= total,
        "neon_softmax: data holds fewer than rows * cols elements"
    );

    let chunks = cols / 4;
    // SAFETY: every access is at `base + k` with `k < cols` and
    // `base + cols <= rows * cols <= data.len()`.
    unsafe {
        for row in 0..rows {
            let base = row * cols;
            let ptr = data.as_mut_ptr().add(base);

            // Pass 1: find row max
            let mut vmax = vdupq_n_f32(f32::NEG_INFINITY);
            for c in 0..chunks {
                vmax = vmaxq_f32(vmax, vld1q_f32(ptr.add(c * 4)));
            }
            let mut max_val = vmaxvq_f32(vmax);
            for i in (chunks * 4)..cols {
                max_val = max_val.max(*ptr.add(i));
            }

            // Pass 2: exp(x - max) and accumulate sum
            let vmx = vdupq_n_f32(max_val);
            let mut vsum = vdupq_n_f32(0.0);
            for c in 0..chunks {
                let off = c * 4;
                let e = neon_exp4(vsubq_f32(vld1q_f32(ptr.add(off)), vmx));
                vst1q_f32(ptr.add(off), e);
                vsum = vaddq_f32(vsum, e);
            }
            let mut sum = vaddvq_f32(vsum);
            for i in (chunks * 4)..cols {
                let e = (*ptr.add(i) - max_val).exp();
                *ptr.add(i) = e;
                sum += e;
            }

            // Pass 3: normalize
            let vinv = vdupq_n_f32(1.0 / sum);
            for c in 0..chunks {
                let off = c * 4;
                vst1q_f32(ptr.add(off), vmulq_f32(vld1q_f32(ptr.add(off)), vinv));
            }
            let inv = 1.0 / sum;
            for i in (chunks * 4)..cols {
                *ptr.add(i) *= inv;
            }
        }
    }
}
667
/// AVX2+FMA softmax: 3-pass (max, exp+sum, normalize). Keeps the
/// `neon_softmax` name so callers are identical across architectures.
///
/// Operates in place on a row-major `[rows, cols]` buffer, one row at a time.
/// The row maximum is subtracted before exponentiation for numerical
/// stability. Full groups of eight columns use `avx2_exp8`; the remaining
/// `cols % 8` columns use the scalar `f32::exp`.
///
/// # Panics
///
/// Panics if `rows * cols` overflows or `data` holds fewer than `rows * cols`
/// elements.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
    use std::arch::x86_64::*;
    // All element access below goes through raw pointers, so the extent must
    // be validated once up front to keep this safe function sound.
    let total = rows
        .checked_mul(cols)
        .expect("neon_softmax: rows * cols overflows usize");
    assert!(
        data.len() >= total,
        "neon_softmax: data holds fewer than rows * cols elements"
    );

    let chunks = cols / 8;
    // SAFETY: every access is at `r * cols + k` with `k < cols` and
    // `r * cols + cols <= rows * cols <= data.len()`; the unaligned load/store
    // intrinsics need no alignment.
    unsafe {
        for r in 0..rows {
            let row = data.as_mut_ptr().add(r * cols);
            // 1) Vector max for stability.
            let mut vmax = _mm256_set1_ps(f32::NEG_INFINITY);
            for c in 0..chunks {
                vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(row.add(c * 8)));
            }
            // Horizontal max: 8 lanes → 1.
            let mut max_v = {
                let lo = _mm256_castps256_ps128(vmax);
                let hi = _mm256_extractf128_ps::<1>(vmax);
                let s4 = _mm_max_ps(lo, hi);
                let s2 = _mm_max_ps(s4, _mm_movehl_ps(s4, s4));
                let s1 = _mm_max_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
                _mm_cvtss_f32(s1)
            };
            for i in (chunks * 8)..cols {
                let v = *row.add(i);
                if v > max_v {
                    max_v = v;
                }
            }
            // 2) exp(x − max) and sum.
            let vmax = _mm256_set1_ps(max_v);
            let mut vsum = _mm256_setzero_ps();
            for c in 0..chunks {
                let off = c * 8;
                let e = avx2_exp8(_mm256_sub_ps(_mm256_loadu_ps(row.add(off)), vmax));
                _mm256_storeu_ps(row.add(off), e);
                vsum = _mm256_add_ps(vsum, e);
            }
            // Horizontal sum: 8 lanes → 1.
            let mut sum_v = {
                let lo = _mm256_castps256_ps128(vsum);
                let hi = _mm256_extractf128_ps::<1>(vsum);
                let s4 = _mm_add_ps(lo, hi);
                let s2 = _mm_add_ps(s4, _mm_movehl_ps(s4, s4));
                let s1 = _mm_add_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
                _mm_cvtss_f32(s1)
            };
            for i in (chunks * 8)..cols {
                let v = (*row.add(i) - max_v).exp();
                *row.add(i) = v;
                sum_v += v;
            }
            // 3) Normalize.
            let vinv = _mm256_set1_ps(1.0 / sum_v);
            for c in 0..chunks {
                let off = c * 8;
                _mm256_storeu_ps(
                    row.add(off),
                    _mm256_mul_ps(_mm256_loadu_ps(row.add(off)), vinv),
                );
            }
            let inv_sum = 1.0 / sum_v;
            for i in (chunks * 8)..cols {
                *row.add(i) *= inv_sum;
            }
        }
    }
}
736
737#[cfg(not(any(
738    target_arch = "aarch64",
739    all(
740        target_arch = "x86_64",
741        target_feature = "avx2",
742        target_feature = "fma"
743    )
744)))]
745pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
746    crate::naive::softmax(data, rows, cols);
747}
748
749// ── GELU in-place (no bias) ────────────────────────────────────────────
750
/// NEON erf-form GELU applied in place (no bias): `0.5 · x · (1 + erf(x / √2))`.
///
/// erf uses the Abramowitz & Stegun polynomial approximation with the
/// exponential evaluated by `neon_exp4`. Full groups of four elements take the
/// vector path; the remaining `len % 4` elements go through `scalar_gelu`.
#[cfg(target_arch = "aarch64")]
pub fn gelu_inplace(data: &mut [f32]) {
    use std::arch::aarch64::*;
    let mut quads = data.chunks_exact_mut(4);
    // SAFETY: each `quad` is exactly four contiguous, exclusively borrowed
    // f32s, which is precisely what one 128-bit load/store touches.
    unsafe {
        let k_half = vdupq_n_f32(0.5);
        let k_one = vdupq_n_f32(1.0);
        let k_neg_one = vdupq_n_f32(-1.0);
        let k_zero = vdupq_n_f32(0.0);
        let k_rsqrt2 = vdupq_n_f32(std::f32::consts::FRAC_1_SQRT_2);
        // Abramowitz & Stegun erf coefficients.
        let k_p = vdupq_n_f32(0.3275911);
        let c1 = vdupq_n_f32(0.254_829_6);
        let c2 = vdupq_n_f32(-0.284_496_72);
        let c3 = vdupq_n_f32(1.421_413_8);
        let c4 = vdupq_n_f32(-1.453_152_1);
        let c5 = vdupq_n_f32(1.061_405_4);

        for quad in quads.by_ref() {
            let dst = quad.as_mut_ptr();
            let v = vld1q_f32(dst);
            let z = vmulq_f32(v, k_rsqrt2);
            let z_abs = vabsq_f32(z);
            // erf is odd: evaluate on |z| and reapply the sign afterwards.
            let sgn = vbslq_f32(vcgeq_f32(z, k_zero), k_one, k_neg_one);
            // t = 1 / (1 + p·|z|)
            let t = vdivq_f32(k_one, vfmaq_f32(k_one, k_p, z_abs));
            // poly = t·(c1 + t·(c2 + t·(c3 + t·(c4 + t·c5))))
            let mut poly = c5;
            poly = vfmaq_f32(c4, poly, t);
            poly = vfmaq_f32(c3, poly, t);
            poly = vfmaq_f32(c2, poly, t);
            poly = vfmaq_f32(c1, poly, t);
            poly = vmulq_f32(poly, t);
            let gauss = neon_exp4(vnegq_f32(vmulq_f32(z_abs, z_abs)));
            // erf = sgn · (1 − poly·exp(−z²))
            let erf = vmulq_f32(sgn, vfmsq_f32(k_one, poly, gauss));
            vst1q_f32(dst, vmulq_f32(v, vmulq_f32(k_half, vaddq_f32(k_one, erf))));
        }
    }
    // Scalar tail.
    for slot in quads.into_remainder() {
        *slot = scalar_gelu(*slot);
    }
}
793
/// AVX2+FMA erf-form GELU applied in place (no bias):
/// `0.5 · x · (1 + erf(x / √2))`.
///
/// erf uses the Abramowitz & Stegun polynomial approximation with the
/// exponential evaluated by `avx2_exp8`. Full groups of eight elements take
/// the vector path; the remaining `len % 8` elements go through `scalar_gelu`.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn gelu_inplace(data: &mut [f32]) {
    use std::arch::x86_64::*;
    let mut octets = data.chunks_exact_mut(8);
    // SAFETY: each `octet` is exactly eight contiguous, exclusively borrowed
    // f32s; the unaligned load/store intrinsics need no particular alignment.
    unsafe {
        let k_half = _mm256_set1_ps(0.5);
        let k_one = _mm256_set1_ps(1.0);
        let k_neg_one = _mm256_set1_ps(-1.0);
        let k_zero = _mm256_set1_ps(0.0);
        let k_rsqrt2 = _mm256_set1_ps(std::f32::consts::FRAC_1_SQRT_2);
        // Abramowitz & Stegun erf coefficients.
        let k_p = _mm256_set1_ps(0.3275911);
        let c1 = _mm256_set1_ps(0.254829592);
        let c2 = _mm256_set1_ps(-0.284496736);
        let c3 = _mm256_set1_ps(1.421413741);
        let c4 = _mm256_set1_ps(-1.453152027);
        let c5 = _mm256_set1_ps(1.061405429);
        // Clearing the sign bit gives |z|.
        let magnitude_bits = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fff_ffff));

        for octet in octets.by_ref() {
            let dst = octet.as_mut_ptr();
            let v = _mm256_loadu_ps(dst);
            let z = _mm256_mul_ps(v, k_rsqrt2);
            let z_abs = _mm256_and_ps(z, magnitude_bits);
            // erf is odd: evaluate on |z| and reapply the sign afterwards.
            let non_negative = _mm256_cmp_ps::<_CMP_GE_OQ>(z, k_zero);
            let sgn = _mm256_blendv_ps(k_neg_one, k_one, non_negative);
            // t = 1 / (1 + p·|z|)
            let t = _mm256_div_ps(k_one, _mm256_fmadd_ps(k_p, z_abs, k_one));
            // poly = t·(c1 + t·(c2 + t·(c3 + t·(c4 + t·c5))))
            let mut poly = c5;
            poly = _mm256_fmadd_ps(poly, t, c4);
            poly = _mm256_fmadd_ps(poly, t, c3);
            poly = _mm256_fmadd_ps(poly, t, c2);
            poly = _mm256_fmadd_ps(poly, t, c1);
            poly = _mm256_mul_ps(poly, t);
            let gauss = avx2_exp8(_mm256_sub_ps(k_zero, _mm256_mul_ps(z_abs, z_abs)));
            // erf = sgn · (1 − poly·exp(−z²))
            let erf = _mm256_mul_ps(sgn, _mm256_fnmadd_ps(poly, gauss, k_one));
            _mm256_storeu_ps(
                dst,
                _mm256_mul_ps(v, _mm256_mul_ps(k_half, _mm256_add_ps(k_one, erf))),
            );
        }
    }
    // Scalar tail.
    for slot in octets.into_remainder() {
        *slot = scalar_gelu(*slot);
    }
}
843
844#[cfg(not(any(
845    target_arch = "aarch64",
846    all(
847        target_arch = "x86_64",
848        target_feature = "avx2",
849        target_feature = "fma"
850    )
851)))]
852pub fn gelu_inplace(data: &mut [f32]) {
853    for v in data.iter_mut() {
854        *v = scalar_gelu(*v);
855    }
856}
857
/// Minimum element count (1 Mi) before a parallel activation kernel hands work
/// to the thread pool; shorter buffers are processed on the calling thread.
///
/// Rationale: activation kernels are O(n) with very low per-element cost
/// (~10 NEON cycles on aarch64). Pool dispatch overhead — even with the
/// parked design — is in the multi-µs range under container scheduling,
/// which dwarfs the actual compute for any reasonable activation size. The
/// threshold is only crossed by very large activation tensors (e.g. an
/// H=4096, FFN=14336, S=1024 LLM up-projection at ~14M elements);
/// single-thread SIMD is the clear win below that.
const ACTIVATION_PAR_MIN: usize = 1 << 20;
869
/// Tanh-approximation GELU for a single value:
///   y = 0.5 x (1 + tanh(√(2/π) · (x + 0.044715 x³)))
///
/// This is the formula used by candle's `Tensor::gelu` and by PyTorch's
/// `approximate="tanh"` GELU mode. It is scalar-only; the erf-based
/// `gelu_inplace` is the SIMD-accelerated variant. Routed from
/// `Activation::GeluApprox` for models whose reference implementation uses
/// the tanh form; `Activation::Gelu` selects the erf form instead.
#[inline]
pub fn scalar_gelu_approx(x: f32) -> f32 {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6; // √(2/π)
    const CUBIC_COEFF: f32 = 0.044_715;
    let cubic = CUBIC_COEFF * x * x * x;
    let squashed = (SQRT_2_OVER_PI * (x + cubic)).tanh();
    0.5 * x * (1.0 + squashed)
}
884
885pub fn gelu_approx_inplace(data: &mut [f32]) {
886    for v in data.iter_mut() {
887        *v = scalar_gelu_approx(*v);
888    }
889}
890
891pub fn par_gelu_approx_inplace(data: &mut [f32]) {
892    let len = data.len();
893    if len < ACTIVATION_PAR_MIN {
894        gelu_approx_inplace(data);
895        return;
896    }
897    let cfg = crate::config::RuntimeConfig::global();
898    let chunk = 512;
899    let rows = len / chunk;
900    if rows < 2 {
901        gelu_approx_inplace(data);
902        return;
903    }
904    let data_ptr = data.as_mut_ptr() as usize;
905    pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
906        let start = off * chunk;
907        let end = if off + cnt >= rows {
908            len
909        } else {
910            (off + cnt) * chunk
911        };
912        let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
913        gelu_approx_inplace(s);
914    });
915    let done = rows * chunk;
916    if done < len {
917        gelu_approx_inplace(&mut data[done..]);
918    }
919}
920
921pub fn par_gelu_inplace(data: &mut [f32]) {
922    let len = data.len();
923    if len < ACTIVATION_PAR_MIN {
924        gelu_inplace(data);
925        return;
926    }
927    let cfg = crate::config::RuntimeConfig::global();
928    let chunk = 512;
929    let rows = len / chunk;
930    if rows < 2 {
931        gelu_inplace(data);
932        return;
933    }
934    let data_ptr = data.as_mut_ptr() as usize;
935    pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
936        let start = off * chunk;
937        let end = if off + cnt >= rows {
938            len
939        } else {
940            (off + cnt) * chunk
941        };
942        let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
943        gelu_inplace(s);
944    });
945    let done = rows * chunk;
946    if done < len {
947        gelu_inplace(&mut data[done..]);
948    }
949}
950
951/// Parallel SiLU in-place. Same threshold reasoning as `par_gelu_inplace`.
952pub fn par_silu_inplace(data: &mut [f32]) {
953    let len = data.len();
954    if len < ACTIVATION_PAR_MIN {
955        silu_inplace(data);
956        return;
957    }
958    let cfg = crate::config::RuntimeConfig::global();
959    let chunk = 512;
960    let rows = len / chunk;
961    if rows < 2 {
962        silu_inplace(data);
963        return;
964    }
965    let data_ptr = data.as_mut_ptr() as usize;
966    pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
967        let start = off * chunk;
968        let end = if off + cnt >= rows {
969            len
970        } else {
971            (off + cnt) * chunk
972        };
973        let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
974        silu_inplace(s);
975    });
976    let done = rows * chunk;
977    if done < len {
978        silu_inplace(&mut data[done..]);
979    }
980}
981
982// ── Small-m NEON matmul ─────────────────────────────────────────────────
983
/// NEON matmul for tiny m (1-8 rows). Avoids BLAS call overhead.
/// C = A @ B where A=`[m,k]`, B=`[k,n]`, C=`[m,n]`, all row-major.
/// For m≤8 with small k×n (under ~16K elements), this beats cblas_sgemm
/// by avoiding AMX setup cost and function call overhead.
///
/// Columns are processed four at a time with one NEON accumulator per output
/// row; leftover columns (`n % 4`) use a scalar dot product. Row counts above
/// 8 are accepted and handled in blocks of 8 rows.
///
/// # Panics
///
/// Panics if `a`, `b` or `c` holds fewer than `m*k`, `k*n` or `m*n` elements
/// respectively (or if one of those products overflows `usize`). The checks
/// run once up front because the vector path uses unchecked pointer access.
#[cfg(target_arch = "aarch64")]
pub fn neon_sgemm_small(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    use std::arch::aarch64::*;

    /// Output rows kept in NEON accumulators at the same time.
    const ROW_BLOCK: usize = 8;

    assert!(
        m.checked_mul(k).is_some_and(|need| a.len() >= need),
        "neon_sgemm_small: `a` holds fewer than m*k elements"
    );
    assert!(
        k.checked_mul(n).is_some_and(|need| b.len() >= need),
        "neon_sgemm_small: `b` holds fewer than k*n elements"
    );
    assert!(
        m.checked_mul(n).is_some_and(|need| c.len() >= need),
        "neon_sgemm_small: `c` holds fewer than m*n elements"
    );

    // Columns [0, n_vec) are covered by whole 4-wide vectors.
    let n_vec = n - n % 4;
    for row0 in (0..m).step_by(ROW_BLOCK) {
        let rows = (m - row0).min(ROW_BLOCK);
        for j in (0..n_vec).step_by(4) {
            // SAFETY: the asserts above guarantee `a`, `b`, `c` hold at least
            // m*k, k*n and m*n elements. Here `row0 + i < m`, `kk < k` and
            // `j + 4 <= n_vec <= n`, so every scalar read of `a` and every
            // 4-wide access of `b` / `c` stays inside its slice.
            unsafe {
                // One accumulator per output row in this block (4 columns wide).
                let mut acc = [vdupq_n_f32(0.0); ROW_BLOCK];
                for kk in 0..k {
                    let bv = vld1q_f32(b.as_ptr().add(kk * n + j));
                    for (i, lane) in acc[..rows].iter_mut().enumerate() {
                        let av = vdupq_n_f32(*a.as_ptr().add((row0 + i) * k + kk));
                        *lane = vfmaq_f32(*lane, av, bv);
                    }
                }
                for (i, lane) in acc[..rows].iter().enumerate() {
                    vst1q_f32(c.as_mut_ptr().add((row0 + i) * n + j), *lane);
                }
            }
        }
    }

    // Remainder columns
    for j in n_vec..n {
        for i in 0..m {
            let mut sum = 0f32;
            for kk in 0..k {
                sum += a[i * k + kk] * b[kk * n + j];
            }
            c[i * n + j] = sum;
        }
    }
}
1020
1021#[cfg(not(target_arch = "aarch64"))]
1022pub fn neon_sgemm_small(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
1023    crate::naive::matmul(a, b, c, m, k, n);
1024}
1025
1026/// NEON sgemm_bias for tiny m: C = A @ B + bias.
1027#[cfg(target_arch = "aarch64")]
1028pub fn neon_sgemm_bias_small(
1029    a: &[f32],
1030    b: &[f32],
1031    bias: &[f32],
1032    c: &mut [f32],
1033    m: usize,
1034    k: usize,
1035    n: usize,
1036) {
1037    neon_sgemm_small(a, b, c, m, k, n);
1038    crate::blas::bias_add(c, bias, m, n);
1039}
1040
1041#[cfg(not(target_arch = "aarch64"))]
1042pub fn neon_sgemm_bias_small(
1043    a: &[f32],
1044    b: &[f32],
1045    bias: &[f32],
1046    c: &mut [f32],
1047    m: usize,
1048    k: usize,
1049    n: usize,
1050) {
1051    crate::naive::matmul(a, b, c, m, k, n);
1052    crate::naive::bias_add(c, bias, m, n);
1053}
1054
1055// ── Scalar fallbacks ────────────────────────────────────────────────────
1056
1057fn scalar_gelu(x: f32) -> f32 {
1058    x * 0.5 * (1.0 + scalar_erf(x * std::f32::consts::FRAC_1_SQRT_2))
1059}
1060
/// Polynomial approximation of `erf(x)`.
///
/// Evaluates `1 - (a1 t + a2 t² + a3 t³ + a4 t⁴ + a5 t⁵) · exp(-x²)` with
/// `t = 1 / (1 + p·|x|)` on the magnitude of `x` and restores the sign at the
/// end (the coefficients look like the classic Abramowitz & Stegun 7.1.26
/// set — confirm before quoting an error bound).
fn scalar_erf(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    const A5: f32 = 1.061_405_4;
    // Remaining coefficients, highest degree first, for Horner evaluation.
    const LOWER: [f32; 4] = [-1.453_152_1, 1.421_413_8, -0.284_496_72, 0.254_829_6];

    let sign: f32 = if x >= 0.0 { 1.0 } else { -1.0 };
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let horner = LOWER.iter().fold(A5, |acc, &coeff| coeff + t * acc);
    let y = t * horner;
    sign * (1.0 - y * (-magnitude * magnitude).exp())
}
1070
/// LayerNorm2d over planar NCHW data (candle / SAM semantics).
///
/// For every `(batch, y, x)` position the values across the `channels` axis
/// are normalized to zero mean and unit (biased) variance, then scaled by
/// `gamma[c]` and shifted by `beta[c]`. `eps` is added to the variance before
/// the square root. `input` and `output` are indexed as
/// `((b * channels + c) * h * w) + position`.
pub fn layer_norm2d_nchw(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    eps: f32,
) {
    let plane = h * w;
    let count = channels as f32;
    for img in 0..batch {
        let first_channel = img * channels;
        for pos in 0..plane {
            // Flat index of channel `ch` at this spatial position.
            let at = |ch: usize| (first_channel + ch) * plane + pos;

            let total = (0..channels).fold(0.0f32, |acc, ch| acc + input[at(ch)]);
            let mean = total / count;

            let sq_dev = (0..channels).fold(0.0f32, |acc, ch| {
                let d = input[at(ch)] - mean;
                acc + d * d
            });
            let var = sq_dev / count;
            let inv = 1.0 / (var + eps).sqrt();

            for ch in 0..channels {
                let normed = (input[at(ch)] - mean) * inv;
                output[at(ch)] = normed * gamma[ch] + beta[ch];
            }
        }
    }
}
1106
/// NCHW transposed convolution (PyTorch `ConvTranspose2d`, no bias).
/// Weight layout `[C_in, C_out/groups, kH, kW]`.
///
/// `output` is zeroed first; each non-zero input value is then scattered
/// into the output positions `(iy*sh + ky*dh - ph, ix*sw + kx*dw - pw)` that
/// fall inside `h_out × w_out`, weighted by the matching kernel tap, for every
/// output channel of the input channel's group.
pub fn conv_transpose2d_nchw(
    input: &[f32],
    weight: &[f32],
    output: &mut [f32],
    n: usize,
    c_in: usize,
    h: usize,
    w: usize,
    c_out: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw: usize,
    groups: usize,
) {
    output.fill(0.0);
    let in_per_group = c_in / groups;
    let out_per_group = c_out / groups;
    for ni in 0..n {
        for ic in 0..c_in {
            // First output channel of the group this input channel belongs to.
            let oc_base = (ic / in_per_group) * out_per_group;
            for iy in 0..h {
                for ix in 0..w {
                    let v = input[((ni * c_in + ic) * h + iy) * w + ix];
                    // Zero inputs contribute nothing; skip the whole scatter.
                    if v == 0.0 {
                        continue;
                    }
                    for ky in 0..kh {
                        // Positions that land in the padding border are dropped.
                        let Some(oy) = (iy * sh + ky * dh).checked_sub(ph) else {
                            continue;
                        };
                        if oy >= h_out {
                            continue;
                        }
                        for kx in 0..kw {
                            let Some(ox) = (ix * sw + kx * dw).checked_sub(pw) else {
                                continue;
                            };
                            if ox >= w_out {
                                continue;
                            }
                            for oc_off in 0..out_per_group {
                                let w_idx = ((ic * out_per_group + oc_off) * kh + ky) * kw + kx;
                                let wt = weight[w_idx];
                                let o_idx =
                                    ((ni * c_out + oc_base + oc_off) * h_out + oy) * w_out + ox;
                                output[o_idx] += v * wt;
                            }
                        }
                    }
                }
            }
        }
    }
}
1174
/// Group normalization over planar NCHW data.
///
/// Channels are split into `num_groups` consecutive groups of
/// `channels / num_groups` channels. Within each image, every group's
/// `(C/G)×H×W` values are normalized to zero mean and unit (biased) variance
/// (`eps` added before the square root), then scaled and shifted with the
/// per-channel `gamma` / `beta`.
pub fn group_norm_nchw(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    num_groups: usize,
    eps: f32,
) {
    let per_group = channels / num_groups;
    let plane = h * w;
    let count = (per_group * plane) as f32;
    for img in 0..batch {
        // Element range occupied by channel `ch` of the current image.
        let plane_of = |ch: usize| {
            let start = (img * channels + ch) * plane;
            start..start + plane
        };
        for group in 0..num_groups {
            let first = group * per_group;
            let members = first..first + per_group;

            // Mean: per-plane sums accumulated in channel order.
            let mut mean = 0.0f32;
            for ch in members.clone() {
                mean += input[plane_of(ch)].iter().sum::<f32>();
            }
            mean /= count;

            // Biased variance over the whole group.
            let mut var = 0.0f32;
            for ch in members.clone() {
                var = input[plane_of(ch)].iter().fold(var, |acc, &x| {
                    let d = x - mean;
                    acc + d * d
                });
            }
            var /= count;
            let inv = 1.0 / (var + eps).sqrt();

            for ch in members {
                let scale = gamma[ch];
                let shift = beta[ch];
                let src = &input[plane_of(ch)];
                let dst = &mut output[plane_of(ch)];
                for (d, &s) in dst.iter_mut().zip(src) {
                    *d = (s - mean) * inv * scale + shift;
                }
            }
        }
    }
}
1227
/// Nearest-neighbor 2× upsample on planar NCHW.
///
/// Each of the `channels` input planes of size `h × w` is expanded into a
/// `2h × 2w` output plane in which every source pixel fills a 2×2 block.
pub fn resize_nearest_2x_nchw(
    input: &[f32],
    output: &mut [f32],
    channels: usize,
    h: usize,
    w: usize,
) {
    let out_h = h * 2;
    let out_w = w * 2;
    for c in 0..channels {
        let src = &input[c * h * w..(c + 1) * h * w];
        let dst = &mut output[c * out_h * out_w..(c + 1) * out_h * out_w];
        for y in 0..h {
            for x in 0..w {
                let v = src[y * w + x];
                // Top-left corner of the 2×2 block, plus the row beneath it.
                let top = (y * 2) * out_w + x * 2;
                let bottom = (y * 2 + 1) * out_w + x * 2;
                dst[top] = v;
                dst[top + 1] = v;
                dst[bottom] = v;
                dst[bottom + 1] = v;
            }
        }
    }
}
1253
#[cfg(test)]
mod tests {
    use super::*;

    /// The scalar erf-based GELU lands near the reference value at x = 1.5.
    #[test]
    fn gelu_correctness() {
        // Reference: gelu(1.5) ≈ 1.3990
        let g = scalar_gelu(1.5f32);
        let err = (g - 1.3990).abs();
        assert!(err < 0.01, "gelu(1.5) = {g}");
    }

    /// Strictly positive inputs and biases must stay positive after bias + GELU.
    #[test]
    fn bias_gelu_works() {
        let bias = vec![0.1, 0.2, 0.3, 0.4];
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        bias_gelu(&mut data, &bias, 2, 4);
        // After bias+gelu, values should be > 0 (all inputs positive)
        data.iter()
            .for_each(|&v| assert!(v > 0.0, "bias_gelu produced {v}"));
    }

    /// Smoke test of the inference batch-norm forward and its three backward kernels.
    #[test]
    fn batch_norm_inference_roundtrip() {
        let c = 4usize;
        let x: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mean = vec![2.5, 2.5, 2.5, 2.5];
        let var = vec![1.0; c];
        let gamma = vec![1.0; c];
        let beta = vec![0.0; c];

        // Forward pass.
        let mut y = vec![0.0; 8];
        batch_norm_inference(&x, &gamma, &beta, &mean, &var, &mut y, c, 1e-5);

        // Backward passes against an all-ones upstream gradient.
        let dy = vec![1.0; 8];
        let mut dx = vec![0.0; 8];
        let mut dgamma = vec![0.0; c];
        let mut dbeta = vec![0.0; c];
        batch_norm_inference_backward_input(&x, &gamma, &mean, &var, &dy, &mut dx, c, 1e-5);
        batch_norm_inference_backward_gamma(&x, &mean, &var, &dy, &mut dgamma, c, 1e-5);
        batch_norm_inference_backward_beta(&dy, &mut dbeta, c);

        assert!(y.iter().all(|v| v.is_finite()));
        assert!(dx.iter().all(|v| v.is_finite()));
        assert!(dgamma.iter().any(|&v| v.abs() > 1e-6));
        assert_eq!(dbeta, vec![2.0, 2.0, 2.0, 2.0]);
    }

    /// A single normalized row has the expected end values and sums to ~0.
    #[test]
    fn layer_norm_unit_test() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let gamma = vec![1.0; 4];
        let beta = vec![0.0; 4];
        let mut output = vec![0.0; 4];
        layer_norm_row(&input, &gamma, &beta, &mut output, 4, 1e-5);

        // Mean=2.5, std≈1.118. output ≈ [-1.342, -0.447, 0.447, 1.342]
        let low_err = (output[0] - -1.342).abs();
        let high_err = (output[3] - 1.342).abs();
        assert!(low_err < 0.01);
        assert!(high_err < 0.01);

        // Sum should be ~0 (normalized)
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() < 0.01, "LN sum should be ~0, got {sum}");
    }

    /// The parallel bias + GELU kernel agrees with the sequential one.
    #[test]
    fn par_bias_gelu_matches_sequential() {
        let n = 100;
        let m = 64;
        let bias = vec![0.1f32; m];
        let mut data_par = vec![0.5f32; n * m];
        let mut data_seq = data_par.clone();

        bias_gelu(&mut data_seq, &bias, n, m);
        par_bias_gelu(&mut data_par, &bias, n, m);

        let mut max_diff = 0f32;
        for (a, b) in data_par.iter().zip(data_seq.iter()) {
            max_diff = f32::max(max_diff, (a - b).abs());
        }
        assert!(max_diff < 1e-6, "par vs seq diff: {max_diff}");
    }
}