burn_flex/ops/activation.rs

//! Activation function operations for the Flex backend.
//!
//! Each activation is implemented as a single-pass unary operation,
//! replacing the multi-op compositions that Burn's trait defaults would use.

use alloc::vec;
use alloc::vec::Vec;
use burn_backend::Scalar;
use burn_backend::ops::{ActivationOps, FloatTensorOps};
use burn_backend::tensor::FloatTensor;
use burn_backend::{DType, TensorMetadata};
use burn_std::{Bytes, bf16, f16};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use num_traits::ToPrimitive;

use crate::ops::binary::binary_op;
use crate::ops::unary::unary_op;
use crate::{Flex, FlexTensor, Layout};

impl ActivationOps<Flex> for Flex {
    fn relu(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary_op(tensor, |x: f32| x.max(0.0), |x: f64| x.max(0.0))
    }

    fn relu_backward(output: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        // grad * (output > 0): zero the gradient where output was zero
        binary_op(
            output,
            grad,
            |out: f32, g| if out > 0.0 { g } else { 0.0 },
            |out: f64, g| if out > 0.0 { g } else { 0.0 },
            None,
        )
    }

    fn leaky_relu(tensor: FloatTensor<Flex>, negative_slope: Scalar) -> FloatTensor<Flex> {
        let ns32 = negative_slope.to_f32().unwrap();
        let ns64 = negative_slope.to_f64().unwrap();
        unary_op(
            tensor,
            move |x: f32| if x >= 0.0 { x } else { ns32 * x },
            move |x: f64| if x >= 0.0 { x } else { ns64 * x },
        )
    }

    fn prelu(tensor: FloatTensor<Flex>, alpha: FloatTensor<Flex>) -> FloatTensor<Flex> {
        // x if x >= 0, alpha * x otherwise
        binary_op(
            tensor,
            alpha,
            |x: f32, a| if x >= 0.0 { x } else { a * x },
            |x: f64, a| if x >= 0.0 { x } else { a * x },
            None,
        )
    }

    fn gelu(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        // 0.5 * x * (1 + erf(x / sqrt(2)))
        use crate::ops::unary::{erf_f32, erf_f64};
        let sqrt2_f32: f32 = core::f32::consts::SQRT_2;
        let sqrt2_f64: f64 = core::f64::consts::SQRT_2;
        unary_op(
            tensor,
            move |x: f32| 0.5 * x * (1.0 + erf_f32(x / sqrt2_f32)),
            move |x: f64| 0.5 * x * (1.0 + erf_f64(x / sqrt2_f64)),
        )
    }

    fn gelu_backward(x: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        // d/dx[gelu(x)] = 0.5 * (1 + erf(x/sqrt(2))) + x * (1/sqrt(2*pi)) * exp(-x^2/2)
        use crate::ops::unary::{erf_f32, erf_f64};
        let sqrt2_f32: f32 = core::f32::consts::SQRT_2;
        let sqrt2_f64: f64 = core::f64::consts::SQRT_2;
        let inv_sqrt_2pi_f32: f32 = 1.0 / (2.0 * core::f32::consts::PI).sqrt();
        let inv_sqrt_2pi_f64: f64 = 1.0 / (2.0 * core::f64::consts::PI).sqrt();
        binary_op(
            x,
            grad,
            move |x: f32, g| {
                let cdf = 0.5 * (1.0 + erf_f32(x / sqrt2_f32));
                let pdf = inv_sqrt_2pi_f32 * (-0.5 * x * x).exp();
                g * (cdf + x * pdf)
            },
            move |x: f64, g| {
                let cdf = 0.5 * (1.0 + erf_f64(x / sqrt2_f64));
                let pdf = inv_sqrt_2pi_f64 * (-0.5 * x * x).exp();
                g * (cdf + x * pdf)
            },
            None,
        )
    }

    fn sigmoid(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary_op(tensor, sigmoid_f32, sigmoid_f64)
    }

    fn sigmoid_backward(output: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        // grad * output * (1 - output)
        binary_op(
            output,
            grad,
            |s: f32, g| g * s * (1.0 - s),
            |s: f64, g| g * s * (1.0 - s),
            None,
        )
    }

    fn hard_sigmoid(tensor: FloatTensor<Flex>, alpha: Scalar, beta: Scalar) -> FloatTensor<Flex> {
        let alpha32 = alpha.to_f32().unwrap();
        let beta32 = beta.to_f32().unwrap();
        let alpha64 = alpha.to_f64().unwrap();
        let beta64 = beta.to_f64().unwrap();
        unary_op(
            tensor,
            move |x: f32| (alpha32 * x + beta32).clamp(0.0, 1.0),
            move |x: f64| (alpha64 * x + beta64).clamp(0.0, 1.0),
        )
    }

    fn log_sigmoid(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        // Numerically stable: -softplus(-x) = -log(1 + exp(-x))
        // For x >= 0: -log(1 + exp(-x))  (standard form; exp(-x) is small)
        // For x < 0: x - log(1 + exp(x))  (avoids exp of a large positive argument)
        unary_op(
            tensor,
            |x: f32| {
                if x >= 0.0 {
                    -((-x).exp().ln_1p())
                } else {
                    x - x.exp().ln_1p()
                }
            },
            |x: f64| {
                if x >= 0.0 {
                    -((-x).exp().ln_1p())
                } else {
                    x - x.exp().ln_1p()
                }
            },
        )
    }

    fn log_sigmoid_backward(x: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        // d/dx[log_sigmoid(x)] = d/dx[-softplus(-x)] = sigmoid(-x) = 1 - sigmoid(x)
        // (the outer negation and the chain-rule factor of -1 cancel).
        // So: grad * sigmoid(-x)
        binary_op(
            x,
            grad,
            |x: f32, g| g * sigmoid_f32(-x),
            |x: f64, g| g * sigmoid_f64(-x),
            None,
        )
    }

    fn softmax(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        softmax(tensor, dim)
    }
}

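/// Numerically stable logistic sigmoid: branching on the sign of `x` means
/// `exp` is only ever evaluated at a non-positive argument, which cannot
/// overflow; both branches are algebraically equal to 1 / (1 + exp(-x)).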
#[inline]
fn sigmoid_f32(x: f32) -> f32 {
    if x >= 0.0 {
        1.0 / (1.0 + (-x).exp())
    } else {
        let e = x.exp();
        e / (1.0 + e)
    }
}

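/// f64 twin of `sigmoid_f32`: the same sign-branched, overflow-free
/// formulation of 1 / (1 + exp(-x)).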
#[inline]
fn sigmoid_f64(x: f64) -> f64 {
    if x >= 0.0 {
        1.0 / (1.0 + (-x).exp())
    } else {
        let e = x.exp();
        e / (1.0 + e)
    }
}

// ============================================================================
// Fused softmax
// ============================================================================
//
// Without a fused kernel, `burn_tensor::activation::softmax` decomposes into
// a 5-op sequence (`max_dim`/`sub`/`exp`/`sum_dim`/`div`) with intermediate
// allocations. This module provides the fused kernel that backs the
// `ActivationOps::softmax` override above, and exports it directly so users
// can also call the fused path by hand.

/// Fused softmax along `dim`.
///
/// Three-pass row-wise algorithm (max, exp+sum, normalize) keeping each row
/// cache-hot. Rows are processed in parallel via rayon. For axes other than
/// the last, the tensor is permuted to put `dim` last, the fused kernel runs,
/// and the result is permuted back (both permutes are metadata-only; the
/// fused kernel's internal `to_contiguous` materializes the permuted layout
/// once).
///
/// # Panics
///
/// * If `dim` is out of range for `input`.
/// * If `input`'s dtype is not one of `f32`/`f64`/`f16`/`bf16`.
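///
/// # Example
///
/// A minimal usage sketch (marked `ignore` rather than run as a doctest);
/// it assumes the `FlexTensor::from_data` constructor used by this module's
/// tests and a `burn_flex` crate path:
///
/// ```ignore
/// use burn_backend::TensorData;
/// use burn_flex::FlexTensor;
///
/// // One row of four logits, normalized along the last axis (dim 1).
/// let logits = FlexTensor::from_data(TensorData::new(
///     vec![1.0f32, 2.0, 3.0, 4.0],
///     vec![1, 4],
/// ));
/// let probs = burn_flex::ops::activation::softmax(logits, 1);
/// // Each row of `probs` is non-negative and sums to 1.
/// ```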
pub fn softmax(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
    let rank = tensor.shape().num_dims();
    assert!(
        dim < rank,
        "softmax dim {} out of range for rank {}",
        dim,
        rank
    );

    if dim != rank - 1 {
        let swapped = Flex::float_swap_dims(tensor, dim, rank - 1);
        let normed = softmax_last(swapped);
        return Flex::float_swap_dims(normed, dim, rank - 1);
    }

    softmax_last(tensor)
}

fn softmax_last(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
    let tensor = tensor.to_contiguous();
    match tensor.dtype() {
        DType::F32 => softmax_last_f32(tensor),
        DType::F64 => softmax_last_f64(tensor),
        DType::F16 => softmax_last_f16(tensor),
        DType::BF16 => softmax_last_bf16(tensor),
        dtype => panic!("softmax: unsupported dtype {:?}", dtype),
    }
}

fn softmax_last_f32(tensor: FlexTensor) -> FlexTensor {
    let shape = tensor.layout().shape().clone();
    let last = *shape.last().expect("softmax: empty shape");
    if last == 0 {
        return tensor;
    }
    let input: &[f32] = tensor.storage();
    let n = input.len();

    // Zero-initialize the output. The previous implementation used
    // `Vec::with_capacity` + `spare_capacity_mut` + a raw-pointer cast to
    // `&mut [f32]` to skip the memset, but forming a `&mut [f32]` over
    // uninitialized memory violates Rust's validity invariant (references
    // must point to initialized values of the correct type) even if every
    // element is written before it is read. The sound zero-memset
    // alternative would require threading `&mut [MaybeUninit<f32>]` through
    // the row kernel, which does not compose with macerator's `#[with_simd]`
    // signature. The memset is a streaming write on a bandwidth-bound
    // kernel, so the overhead is small (~10% at the largest bench shape)
    // and the fused path remains well ahead of both the decomposed path
    // and Candle.
    let mut output: Vec<f32> = vec![0.0; n];
    let out_slice = output.as_mut_slice();

    // Row-parallel via rayon: one macerator dispatch per chunk of rows,
    // amortized over all rows in the chunk.
    #[cfg(feature = "rayon")]
    {
        use rayon::prelude::*;
        const ROWS_PER_TASK: usize = 64;
        let chunk_elems = ROWS_PER_TASK * last;
        out_slice
            .par_chunks_mut(chunk_elems)
            .zip(input.par_chunks(chunk_elems))
            .for_each(|(o, i)| softmax_rows_f32(i, o, last));
    }
    #[cfg(not(feature = "rayon"))]
    {
        softmax_rows_f32(input, out_slice, last);
    }

    FlexTensor::new(
        Bytes::from_elems(output),
        Layout::contiguous(shape),
        DType::F32,
    )
}

/// Row sweep for f32 softmax. With the `simd` feature, delegates to the
/// `#[macerator::with_simd]` SIMD kernel (one dispatch per chunk of rows,
/// amortized over all rows in the chunk). Without `simd`, uses a scalar
/// row kernel.
#[inline]
fn softmax_rows_f32(input: &[f32], output: &mut [f32], row_len: usize) {
    // Release-mode invariant checks. These run once per chunk of rows
    // (dozens of times per call, not per-element), so the overhead is
    // unmeasurable against the kernel work. A debug-only check would
    // silently pass a short final chunk to the row kernel on release
    // builds if a future refactor broke the row alignment at the call
    // site, yielding wrong softmax output with no panic.
    assert_eq!(input.len(), output.len());
    assert_eq!(input.len() % row_len, 0);
    #[cfg(feature = "simd")]
    softmax_rows_f32_simd(input, output, row_len);
    #[cfg(not(feature = "simd"))]
    {
        for (in_row, out_row) in input.chunks(row_len).zip(output.chunks_mut(row_len)) {
            softmax_row_f32_scalar(in_row, out_row);
        }
    }
}

#[cfg(feature = "simd")]
#[macerator::with_simd]
fn softmax_rows_f32_simd<S: macerator::Simd>(input: &[f32], output: &mut [f32], row_len: usize) {
    debug_assert_eq!(input.len(), output.len());
    debug_assert_eq!(input.len() % row_len, 0);
    for (in_row, out_row) in input.chunks(row_len).zip(output.chunks_mut(row_len)) {
        softmax_row_f32_simd::<S>(in_row, out_row);
    }
}

/// Scalar fallback row kernel for f32 softmax when the `simd` feature is
/// disabled. Uses the same 3-pass algorithm as the SIMD path; LLVM
/// autovectorizes the max-reduce and normalize loops on most targets.
#[cfg(not(feature = "simd"))]
#[inline]
fn softmax_row_f32_scalar(input: &[f32], output: &mut [f32]) {
    let mut max_val = f32::NEG_INFINITY;
    for &x in input {
        if x > max_val {
            max_val = x;
        }
    }
    let mut sum = 0.0f32;
    for (i, &x) in input.iter().enumerate() {
        let e = (x - max_val).exp();
        output[i] = e;
        sum += e;
    }
    let inv = 1.0f32 / sum;
    for x in output.iter_mut() {
        *x *= inv;
    }
}

/// Inner row kernel for a single softmax row. `#[inline(always)]` so it
/// inlines into `softmax_rows_f32_simd`'s loop body for each monomorphized S,
/// avoiding a per-row call boundary.
#[cfg(feature = "simd")]
#[inline(always)]
fn softmax_row_f32_simd<S: macerator::Simd>(input: &[f32], output: &mut [f32]) {
    use macerator::{Scalar, vload_unaligned, vstore_unaligned};
    let lanes = <f32 as Scalar>::lanes::<S>();
    let len = input.len();
    let simd_len = len / lanes * lanes;

    // Pass 1: row max for numerical stability.
    // SIMD max-reduction across the row, scalar tail.
    let (mut max_val, tail_start) = if simd_len >= lanes {
        let mut max_vec = unsafe { vload_unaligned::<S, _>(input.as_ptr()) };
        let mut j = lanes;
        while j < simd_len {
            let v = unsafe { vload_unaligned::<S, _>(input.as_ptr().add(j)) };
            max_vec = max_vec.max(v);
            j += lanes;
        }
        (max_vec.reduce_max(), simd_len)
    } else {
        (f32::NEG_INFINITY, 0)
    };
    for &x in &input[tail_start..] {
        if x > max_val {
            max_val = x;
        }
    }

    // Pass 2: compute exp(x - max), store in output, accumulate sum.
    // Scalar exp (no SIMD exp in macerator). This pass is the one that
    // actually does memory reads + writes on the whole row, so scalar
    // here still lands us at memory bandwidth.
    let mut sum = 0.0f32;
    for idx in 0..len {
        let e = (input[idx] - max_val).exp();
        output[idx] = e;
        sum += e;
    }

    // Pass 3: normalize.
    // SIMD splat + multiply, scalar tail.
    let inv = 1.0f32 / sum;
    let inv_vec = inv.splat::<S>();
    let mut i = 0;
    while i < simd_len {
        unsafe {
            let v = vload_unaligned::<S, _>(output.as_ptr().add(i));
            vstore_unaligned::<S, _>(output.as_mut_ptr().add(i), v * inv_vec);
        }
        i += lanes;
    }
    for x in &mut output[i..] {
        *x *= inv;
    }
}

// f64, f16, bf16 softmax share the same row-parallel dispatcher shell and
// differ only in their row kernel (native f64 vs via-f32 for half
// precision). Generated via macros to keep the three variants in lockstep.
// Only f32 has a dedicated SIMD fast path above.

macro_rules! softmax_last_dtype {
    ($fn_name:ident, $T:ty, $zero:expr, $dtype:expr, $row_fn:ident) => {
        fn $fn_name(tensor: FlexTensor) -> FlexTensor {
            let shape = tensor.layout().shape().clone();
            let last = *shape.last().expect("softmax: empty shape");
            if last == 0 {
                return tensor;
            }
            let input: &[$T] = tensor.storage();
            let mut output: Vec<$T> = vec![$zero; input.len()];

            #[cfg(feature = "rayon")]
            {
                use rayon::prelude::*;
                output
                    .par_chunks_mut(last)
                    .zip(input.par_chunks(last))
                    .for_each(|(o, i)| $row_fn(i, o));
            }
            #[cfg(not(feature = "rayon"))]
            {
                for (i, o) in input.chunks(last).zip(output.chunks_mut(last)) {
                    $row_fn(i, o);
                }
            }

            FlexTensor::new(Bytes::from_elems(output), Layout::contiguous(shape), $dtype)
        }
    };
}

/// Half-precision softmax row kernel. Accumulates in f32 for numerical
/// stability and converts back to the target type at each write. This
/// double-rounds across passes 2 and 3; acceptable for half precision. An
/// f32 scratch buffer would remove the double rounding at the cost of a
/// per-row allocation.
macro_rules! softmax_row_half {
    ($fn_name:ident, $T:ty) => {
        #[inline]
        fn $fn_name(input: &[$T], output: &mut [$T]) {
            let mut max_val = f32::NEG_INFINITY;
            for &x in input {
                let xf = x.to_f32();
                if xf > max_val {
                    max_val = xf;
                }
            }
            let mut sum = 0.0f32;
            for (i, &x) in input.iter().enumerate() {
                let e = (x.to_f32() - max_val).exp();
                output[i] = <$T>::from_f32(e);
                sum += e;
            }
            let inv = 1.0f32 / sum;
            for x in output.iter_mut() {
                *x = <$T>::from_f32(x.to_f32() * inv);
            }
        }
    };
}

#[inline]
fn softmax_row_f64(input: &[f64], output: &mut [f64]) {
    let mut max_val = f64::NEG_INFINITY;
    for &x in input {
        if x > max_val {
            max_val = x;
        }
    }
    let mut sum = 0.0f64;
    for (i, &x) in input.iter().enumerate() {
        let e = (x - max_val).exp();
        output[i] = e;
        sum += e;
    }
    let inv = 1.0f64 / sum;
    for x in output.iter_mut() {
        *x *= inv;
    }
}

softmax_row_half!(softmax_row_f16, f16);
softmax_row_half!(softmax_row_bf16, bf16);

softmax_last_dtype!(softmax_last_f64, f64, 0.0f64, DType::F64, softmax_row_f64);
softmax_last_dtype!(
    softmax_last_f16,
    f16,
    f16::from_f32(0.0),
    DType::F16,
    softmax_row_f16
);
softmax_last_dtype!(
    softmax_last_bf16,
    bf16,
    bf16::from_f32(0.0),
    DType::BF16,
    softmax_row_bf16
);

// ============================================================================
// Fused layer_norm
// ============================================================================
//
// `burn::nn::LayerNorm::forward` decomposes into ~6 primitive tensor ops
// with intermediate allocations, and there is no backend trait hook for
// layer_norm. This module provides a fused alternative users can opt into
// directly. Two-pass row kernel (sum+sumsq sweep, then normalize+affine
// sweep), both vectorized via macerator.

/// Fused layer normalization along the last axis.
///
/// Applies `y = ((x - mean) / sqrt(var + eps)) * gamma + beta`, where
/// `mean` and `var` are computed per row along the last axis of `input`.
/// `gamma` and `beta` are 1-D tensors of length `input.shape()[-1]`;
/// `beta` is optional (set to `None` for a bias-free layer norm).
///
/// Two-pass row kernel (mean/variance via a single sum+sum-of-squares
/// sweep, then one normalize+affine sweep). Both passes are SIMD via
/// macerator; each row stays cache-hot across both passes.
///
/// Supports `f32` (SIMD-vectorized), `f64` (scalar + LLVM autovec), and
/// `f16`/`bf16` (via an f32 cast-fuse-cast shell; the f32 row kernel
/// already accumulates in f32, so this matches the precision a
/// half-precision-native kernel would produce).
///
/// # Panics
///
/// * If `input`'s dtype is not one of `f32`/`f64`/`f16`/`bf16`.
/// * If `input` has rank 0.
/// * If `gamma` (or `beta`, when present) is not a 1-D tensor of length
///   equal to the last dim of `input`.
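///
/// # Example
///
/// A minimal usage sketch (marked `ignore` rather than run as a doctest),
/// again assuming the `FlexTensor::from_data` constructor used by this
/// module's tests:
///
/// ```ignore
/// use burn_backend::TensorData;
/// use burn_flex::FlexTensor;
///
/// // Two rows of width 4; unit gamma and no beta give a bias-free norm.
/// let x = FlexTensor::from_data(TensorData::new(
///     vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
///     vec![2, 4],
/// ));
/// let gamma = FlexTensor::from_data(TensorData::new(vec![1.0f32; 4], vec![4]));
/// let y = burn_flex::ops::activation::layer_norm(x, gamma, None, 1e-5);
/// // Each output row has approximately zero mean and unit variance.
/// ```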
pub fn layer_norm(
    input: FloatTensor<Flex>,
    gamma: FloatTensor<Flex>,
    beta: Option<FloatTensor<Flex>>,
    epsilon: f64,
) -> FloatTensor<Flex> {
    let rank = input.shape().num_dims();
    assert!(rank >= 1, "layer_norm: input must have at least one dim");
    // Keep gamma/beta dtypes aligned with the input. The half-precision path
    // (see `layer_norm_via_f32`) ultimately accesses storage using the input's
    // element type, and a mismatch would panic there; reject it up front with
    // a clearer layer_norm-specific error message.
    assert_eq!(
        gamma.dtype(),
        input.dtype(),
        "layer_norm: gamma dtype {:?} does not match input dtype {:?}",
        gamma.dtype(),
        input.dtype(),
    );
    if let Some(ref b) = beta {
        assert_eq!(
            b.dtype(),
            input.dtype(),
            "layer_norm: beta dtype {:?} does not match input dtype {:?}",
            b.dtype(),
            input.dtype(),
        );
    }
    let input = input.to_contiguous();
    let gamma = gamma.to_contiguous();
    let beta = beta.map(|b| b.to_contiguous());

    let d_model = *input
        .layout()
        .shape()
        .last()
        .expect("layer_norm: empty shape");
    // Validate rank + length explicitly rather than just last-dim == d_model.
    // A gamma shaped like `[2, d_model]` would pass a last-dim check but
    // has 2*d_model elements, which would index the wrong data in the row
    // kernel (caught by an inner assert, but with a confusing message).
    let gamma_shape = gamma.layout().shape();
    assert!(
        gamma_shape.len() == 1 && gamma_shape[0] == d_model,
        "layer_norm: gamma must be a 1-D tensor of length equal to last dim of input \
         (got shape {:?}, expected [{}])",
        gamma_shape,
        d_model,
    );
    if let Some(ref b) = beta {
        let beta_shape = b.layout().shape();
        assert!(
            beta_shape.len() == 1 && beta_shape[0] == d_model,
            "layer_norm: beta must be a 1-D tensor of length equal to last dim of input \
             (got shape {:?}, expected [{}])",
            beta_shape,
            d_model,
        );
    }

    match input.dtype() {
        DType::F32 => layer_norm_f32(input, gamma, beta, epsilon as f32),
        DType::F64 => layer_norm_f64(input, gamma, beta, epsilon),
        DType::F16 => {
            layer_norm_via_f32::<f16>(input, gamma, beta, epsilon, f16::to_f32, f16::from_f32)
        }
        DType::BF16 => {
            layer_norm_via_f32::<bf16>(input, gamma, beta, epsilon, bf16::to_f32, bf16::from_f32)
        }
        dtype => panic!("burn_flex::layer_norm: unsupported dtype {:?}", dtype),
    }
}

fn layer_norm_via_f32<E: burn_backend::Element + bytemuck::Pod + Copy>(
    input: FlexTensor,
    gamma: FlexTensor,
    beta: Option<FlexTensor>,
    epsilon: f64,
    to_f32: fn(E) -> f32,
    from_f32: fn(f32) -> E,
) -> FlexTensor {
    let input_f32 = crate::ops::module::cast_to_f32::<E>(input, to_f32);
    let gamma_f32 = crate::ops::module::cast_to_f32::<E>(gamma, to_f32);
    let beta_f32 = beta.map(|b| crate::ops::module::cast_to_f32::<E>(b, to_f32));
    let out = layer_norm_f32(input_f32, gamma_f32, beta_f32, epsilon as f32);
    crate::ops::module::cast_from_f32::<E>(out, from_f32)
}

/// Fused f64 layer_norm. The Welford mean/variance pass is serial (the
/// mean update on iteration `k` depends on iteration `k-1`); the
/// normalize+affine pass autovectorizes on targets with f64 SIMD. A
/// macerator f64 path can be added if profiling shows it matters.
fn layer_norm_f64(
    input: FlexTensor,
    gamma: FlexTensor,
    beta: Option<FlexTensor>,
    epsilon: f64,
) -> FlexTensor {
    let shape = input.layout().shape().clone();
    let d_model = *shape.last().expect("layer_norm: empty shape");
    if d_model == 0 {
        return input;
    }
    let input_data: &[f64] = input.storage();
    let gamma_data: &[f64] = gamma.storage();
    let beta_data: Option<&[f64]> = beta.as_ref().map(|b| b.storage());
    let mut output: Vec<f64> = vec![0.0; input_data.len()];

    #[cfg(feature = "rayon")]
    {
        use rayon::prelude::*;
        const ROWS_PER_TASK: usize = 64;
        let chunk_elems = ROWS_PER_TASK * d_model;
        match beta_data {
            Some(beta_slice) => {
                output
                    .par_chunks_mut(chunk_elems)
                    .zip(input_data.par_chunks(chunk_elems))
                    .for_each(|(o, i)| {
                        layer_norm_rows_f64_with_beta(
                            i, o, gamma_data, beta_slice, d_model, epsilon,
                        );
                    });
            }
            None => {
                output
                    .par_chunks_mut(chunk_elems)
                    .zip(input_data.par_chunks(chunk_elems))
                    .for_each(|(o, i)| {
                        layer_norm_rows_f64_no_beta(i, o, gamma_data, d_model, epsilon);
                    });
            }
        }
    }
    #[cfg(not(feature = "rayon"))]
    {
        match beta_data {
            Some(beta_slice) => layer_norm_rows_f64_with_beta(
                input_data,
                output.as_mut_slice(),
                gamma_data,
                beta_slice,
                d_model,
                epsilon,
            ),
            None => layer_norm_rows_f64_no_beta(
                input_data,
                output.as_mut_slice(),
                gamma_data,
                d_model,
                epsilon,
            ),
        }
    }

    FlexTensor::new(
        Bytes::from_elems(output),
        Layout::contiguous(shape),
        DType::F64,
    )
}

#[inline]
fn layer_norm_rows_f64_with_beta(
    input: &[f64],
    output: &mut [f64],
    gamma: &[f64],
    beta: &[f64],
    d_model: usize,
    epsilon: f64,
) {
    for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
        let (mean, inv_std) = welford_f64(in_row, epsilon);
        for (i, &x) in in_row.iter().enumerate() {
            out_row[i] = (x - mean) * (inv_std * gamma[i]) + beta[i];
        }
    }
}

#[inline]
fn layer_norm_rows_f64_no_beta(
    input: &[f64],
    output: &mut [f64],
    gamma: &[f64],
    d_model: usize,
    epsilon: f64,
) {
    for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
        let (mean, inv_std) = welford_f64(in_row, epsilon);
        for (i, &x) in in_row.iter().enumerate() {
            out_row[i] = (x - mean) * (inv_std * gamma[i]);
        }
    }
}

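/// Single-pass Welford sweep over one row, returning
/// `(mean, 1 / sqrt(var + epsilon))`. The variance is the biased
/// estimator (divide by `n`), matching the f32 kernel's convention.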
#[inline]
fn welford_f64(row: &[f64], epsilon: f64) -> (f64, f64) {
    let mut mean = 0.0f64;
    let mut m2 = 0.0f64;
    for (k, &x) in row.iter().enumerate() {
        let n_k = (k + 1) as f64;
        let delta = x - mean;
        mean += delta / n_k;
        m2 += delta * (x - mean);
    }
    let var = m2 / row.len() as f64;
    (mean, 1.0f64 / (var + epsilon).sqrt())
}

fn layer_norm_f32(
    input: FlexTensor,
    gamma: FlexTensor,
    beta: Option<FlexTensor>,
    epsilon: f32,
) -> FlexTensor {
    let shape = input.layout().shape().clone();
    let d_model = *shape.last().expect("layer_norm: empty shape");
    if d_model == 0 {
        return input;
    }

    let input_data: &[f32] = input.storage();
    let gamma_data: &[f32] = gamma.storage();
    let beta_data: Option<&[f32]> = beta.as_ref().map(|b| b.storage());

    let n = input_data.len();
    // See softmax_last_f32 for the rationale on zero-init instead of
    // `spare_capacity_mut` + `&mut [f32]` cast: the latter creates a
    // reference to uninitialized f32 values, which violates Rust's
    // validity invariant even with no intervening read.
    let mut output: Vec<f32> = vec![0.0; n];
    let out_slice = output.as_mut_slice();

    // `#[macerator::with_simd]` can't auto-lifetime through
    // `Option<&[T]>`, so we dispatch two separate monomorphized
    // versions, one with beta and one without. Both call into the
    // same shared row kernel.
    #[cfg(feature = "rayon")]
    {
        use rayon::prelude::*;
        const ROWS_PER_TASK: usize = 64;
        let chunk_elems = ROWS_PER_TASK * d_model;
        match beta_data {
            Some(beta_slice) => {
                out_slice
                    .par_chunks_mut(chunk_elems)
                    .zip(input_data.par_chunks(chunk_elems))
                    .for_each(|(o, i)| {
                        layer_norm_rows_f32_with_beta(
                            i, o, gamma_data, beta_slice, d_model, epsilon,
                        );
                    });
            }
            None => {
                out_slice
                    .par_chunks_mut(chunk_elems)
                    .zip(input_data.par_chunks(chunk_elems))
                    .for_each(|(o, i)| {
                        layer_norm_rows_f32_no_beta(i, o, gamma_data, d_model, epsilon);
                    });
            }
        }
    }
    #[cfg(not(feature = "rayon"))]
    {
        match beta_data {
            Some(beta_slice) => layer_norm_rows_f32_with_beta(
                input_data, out_slice, gamma_data, beta_slice, d_model, epsilon,
            ),
            None => {
                layer_norm_rows_f32_no_beta(input_data, out_slice, gamma_data, d_model, epsilon)
            }
        }
    }

    FlexTensor::new(
        Bytes::from_elems(output),
        Layout::contiguous(shape),
        DType::F32,
    )
}

/// Row sweep for f32 layer_norm with bias. Delegates to the SIMD kernel
/// when the `simd` feature is enabled; otherwise uses a scalar row loop.
#[inline]
fn layer_norm_rows_f32_with_beta(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    // Release-mode invariant checks; see softmax_rows_f32 for rationale.
    assert_eq!(input.len(), output.len());
    assert_eq!(input.len() % d_model, 0);
    assert_eq!(gamma.len(), d_model);
    assert_eq!(beta.len(), d_model);
    #[cfg(feature = "simd")]
    layer_norm_rows_f32_with_beta_simd(input, output, gamma, beta, d_model, epsilon);
    #[cfg(not(feature = "simd"))]
    {
        for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
            layer_norm_row_f32_scalar(in_row, out_row, gamma, Some(beta), epsilon);
        }
    }
}

/// Row sweep for f32 layer_norm without bias.
#[inline]
fn layer_norm_rows_f32_no_beta(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    // Release-mode invariant checks; see softmax_rows_f32 for rationale.
    assert_eq!(input.len(), output.len());
    assert_eq!(input.len() % d_model, 0);
    assert_eq!(gamma.len(), d_model);
    #[cfg(feature = "simd")]
    layer_norm_rows_f32_no_beta_simd(input, output, gamma, d_model, epsilon);
    #[cfg(not(feature = "simd"))]
    {
        for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
            layer_norm_row_f32_scalar(in_row, out_row, gamma, None, epsilon);
        }
    }
}

/// Scalar fallback row kernel for layer_norm when the `simd` feature is
/// disabled. Two-pass structure matching the SIMD version, but the
/// statistics pass uses Welford's algorithm instead of the SIMD path's
/// sum+sumsq identity (see the comment in the body).
#[cfg(not(feature = "simd"))]
#[inline]
fn layer_norm_row_f32_scalar(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: Option<&[f32]>,
    epsilon: f32,
) {
    // Welford's online algorithm for mean and variance, rather than the
    // `sumsq / n - mean * mean` identity the SIMD path uses. The identity
    // is vulnerable to catastrophic cancellation when the two terms are
    // close in magnitude (large mean relative to variance). Welford's
    // single-pass formulation avoids that by tracking the running mean
    // and accumulating squared deviations from it. The scalar path is
    // the contract used when `simd` is disabled, so we prefer numerical
    // stability over bit-for-bit match with the SIMD tree reduction.
    let len = input.len();
    let mut mean = 0.0f32;
    let mut m2 = 0.0f32;
    for (k, &x) in input.iter().enumerate() {
        let n_k = (k + 1) as f32;
        let delta = x - mean;
        mean += delta / n_k;
        let delta2 = x - mean;
        m2 += delta * delta2;
    }
    let var = m2 / len as f32;
    let inv_std = 1.0f32 / (var + epsilon).sqrt();
    for (i, &x) in input.iter().enumerate() {
        let scale = inv_std * gamma[i];
        let normed = (x - mean) * scale;
        output[i] = match beta {
            Some(b) => normed + b[i],
            None => normed,
        };
    }
}

/// SIMD-dispatched row sweep for f32 layer_norm with bias (beta). One
/// macerator dispatch per chunk of rows, amortized over the whole chunk.
#[cfg(feature = "simd")]
#[macerator::with_simd]
fn layer_norm_rows_f32_with_beta_simd<S: macerator::Simd>(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    debug_assert_eq!(input.len(), output.len());
    debug_assert_eq!(input.len() % d_model, 0);
    debug_assert_eq!(gamma.len(), d_model);
    debug_assert_eq!(beta.len(), d_model);
    for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
        layer_norm_row_f32_simd::<S>(in_row, out_row, gamma, Some(beta), epsilon);
    }
}

/// SIMD-dispatched row sweep for f32 layer_norm without bias.
#[cfg(feature = "simd")]
#[macerator::with_simd]
fn layer_norm_rows_f32_no_beta_simd<S: macerator::Simd>(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    debug_assert_eq!(input.len(), output.len());
    debug_assert_eq!(input.len() % d_model, 0);
    debug_assert_eq!(gamma.len(), d_model);
    for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
        layer_norm_row_f32_simd::<S>(in_row, out_row, gamma, None, epsilon);
    }
}

/// Single-row layer_norm kernel. Two vectorized passes.
#[cfg(feature = "simd")]
#[inline(always)]
fn layer_norm_row_f32_simd<S: macerator::Simd>(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: Option<&[f32]>,
    epsilon: f32,
) {
    use macerator::{Scalar, vload_unaligned, vstore_unaligned};
    let lanes = <f32 as Scalar>::lanes::<S>();
    let len = input.len();
    let simd_len = len / lanes * lanes;

    // Pass 1: compute sum and sum-of-squares in one sweep, then derive
    // mean and variance. Two independent SIMD accumulators (sum, sumsq)
    // expose ILP to the two FMA ports.
    let (sum, sumsq) = if simd_len >= lanes {
        let mut acc_sum = 0.0f32.splat::<S>();
        let mut acc_sumsq = 0.0f32.splat::<S>();
        let mut i = 0;
        while i < simd_len {
            unsafe {
                let v = vload_unaligned::<S, _>(input.as_ptr().add(i));
                acc_sum += v;
                // acc_sumsq += v * v, via FMA: Vector::mul_add(self, a, b)
                // computes self * a + b, so v.mul_add(v, acc_sumsq) = v*v + acc_sumsq.
                acc_sumsq = v.mul_add(v, acc_sumsq);
            }
            i += lanes;
        }
        let mut s = acc_sum.reduce_add();
        let mut sq = acc_sumsq.reduce_add();
        for &x in &input[simd_len..] {
            s += x;
            sq += x * x;
        }
        (s, sq)
    } else {
        let mut s = 0.0f32;
        let mut sq = 0.0f32;
        for &x in input {
            s += x;
            sq += x * x;
        }
        (s, sq)
    };

    let n = len as f32;
    let mean = sum / n;
    // Biased variance: E[x^2] - E[x]^2. Matches burn::nn::LayerNorm, which
    // uses var_mean_bias (the biased estimator) rather than Bessel's
    // correction.
    let var = (sumsq / n) - mean * mean;
    let inv_std = 1.0f32 / (var + epsilon).sqrt();

    // Pass 2: normalize and affine transform.
    //   out[i] = (x[i] - mean) * inv_std * gamma[i] + beta[i]
    // mean_vec and inv_std_vec are hoisted outside the loop (one splat
    // each per row). gamma and beta are read once per element; both
    // fit in L1 and are shared across all rows within a rayon chunk.
    let mean_vec = mean.splat::<S>();
    let inv_std_vec = inv_std.splat::<S>();
    let mut i = 0;
    while i < simd_len {
        unsafe {
            let x = vload_unaligned::<S, _>(input.as_ptr().add(i));
            let g = vload_unaligned::<S, _>(gamma.as_ptr().add(i));
            // scale = inv_std * g
            let scale = inv_std_vec * g;
            // centered = x - mean
            let centered = x - mean_vec;
            // out = centered * scale  (+ beta if present)
            let normed = centered * scale;
            let out = if let Some(b) = beta {
                let b_vec = vload_unaligned::<S, _>(b.as_ptr().add(i));
                normed + b_vec
            } else {
                normed
            };
            vstore_unaligned::<S, _>(output.as_mut_ptr().add(i), out);
        }
        i += lanes;
    }
    // Scalar tail
    while i < len {
        let centered = input[i] - mean;
        let normed = centered * inv_std * gamma[i];
        output[i] = match beta {
            Some(b) => normed + b[i],
            None => normed,
        };
        i += 1;
    }
}

// Tests kept here exercise flex-specific behavior: SIMD boundaries, rayon
// chunk boundaries, non-contiguous input handling, the flex-internal
// layer_norm op (no public API yet), and dtype-specific fused softmax
// paths (f16/bf16/f64). Plain activation/softmax smoke tests have been
// migrated to burn-backend-tests so they cover every backend. When adding
// new tests, keep them here only if they probe flex internals; otherwise
// add them to crates/burn-backend-tests/tests/tensor/float/activation/.
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;
    use burn_backend::{DType, TensorData, TensorMetadata, Tolerance};
    use burn_std::{bf16, f16};
    use num_traits::Float;

    use crate::FlexTensor;

    // ============================================================================
    // Reference implementations (per-row, last-axis).
    //
    // These mirror the contract the fused kernel commits to: stable softmax via
    // (x - max), layer_norm via (x - mean) * inv(sqrt(var + eps)) with optional
    // affine. Written in plain Rust over f32/f64 slices so the tests avoid any
    // tensor-library dependency.
    // ============================================================================

    fn softmax_row<T: Float>(row_in: &[T], row_out: &mut [T]) {
        let max = row_in
            .iter()
            .copied()
            .fold(T::neg_infinity(), |a, b| if a > b { a } else { b });
        let mut sum = T::zero();
        for (i, &x) in row_in.iter().enumerate() {
            let e = (x - max).exp();
            row_out[i] = e;
            sum = sum + e;
        }
        for v in row_out.iter_mut() {
            *v = *v / sum;
        }
    }

    fn softmax_last_ref<T: Float>(data: &[T], row_len: usize) -> Vec<T> {
        let mut out = vec![T::zero(); data.len()];
        for (i, o) in data.chunks(row_len).zip(out.chunks_mut(row_len)) {
            softmax_row(i, o);
        }
        out
    }

    fn layer_norm_row<T: Float>(
        row_in: &[T],
        gamma: &[T],
        beta: Option<&[T]>,
        eps: T,
        row_out: &mut [T],
    ) {
        let n = T::from(row_in.len()).unwrap();
        let mean = row_in.iter().copied().fold(T::zero(), |a, b| a + b) / n;
        let var = row_in
            .iter()
            .map(|&x| (x - mean) * (x - mean))
            .fold(T::zero(), |a, b| a + b)
            / n;
        let inv_std = T::one() / (var + eps).sqrt();
        for (i, &x) in row_in.iter().enumerate() {
            let normed = (x - mean) * inv_std;
            let scaled = normed * gamma[i];
            row_out[i] = match beta {
                Some(b) => scaled + b[i],
                None => scaled,
            };
        }
    }

    fn layer_norm_last_ref<T: Float>(
        data: &[T],
        gamma: &[T],
        beta: Option<&[T]>,
        eps: T,
        row_len: usize,
    ) -> Vec<T> {
        let mut out = vec![T::zero(); data.len()];
        for (i, o) in data.chunks(row_len).zip(out.chunks_mut(row_len)) {
            layer_norm_row(i, gamma, beta, eps, o);
        }
        out
    }

    // ============================================================================
    // Helpers: FlexTensor constructors for typed inputs.
    // ============================================================================

    fn flex_f32(data: Vec<f32>, shape: &[usize]) -> FlexTensor {
        FlexTensor::from_data(TensorData::new(data, shape.to_vec()))
    }

    fn flex_f64(data: Vec<f64>, shape: &[usize]) -> FlexTensor {
        FlexTensor::from_data(TensorData::new(data, shape.to_vec()))
    }

    fn flex_half<T: burn_backend::Element>(data: Vec<T>, shape: &[usize]) -> FlexTensor {
        FlexTensor::from_data(TensorData::new(data, shape.to_vec()))
    }

    // ============================================================================
    // layer_norm tests
    // ============================================================================

    #[test]
    fn test_layer_norm_2d_with_beta() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]);
        let gamma = flex_f32(vec![1.0; 4], &[4]);
        let beta = flex_f32(vec![0.0; 4], &[4]);
        let out = crate::ops::activation::layer_norm(t, gamma, Some(beta), 1e-5);

        let expected: Vec<f32> = vec![
            -1.3416408, -0.4472136, 0.4472136, 1.3416408, -1.3416408, -0.4472136, 0.4472136,
            1.3416408,
        ];
        out.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![2, 4]),
            Tolerance::absolute(1e-4),
        );
    }

    #[test]
    fn test_layer_norm_with_affine() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f32(vec![2.0, 0.5, 1.0, 3.0], &[4]);
        let beta = flex_f32(vec![1.0, -1.0, 0.0, 2.0], &[4]);
        let out = crate::ops::activation::layer_norm(t, gamma, Some(beta), 1e-5);

        // normalized = [-1.3416, -0.4472, 0.4472, 1.3416]
        // affine: [-1.6833, -1.2236, 0.4472, 6.0249]
        out.into_data().assert_approx_eq::<f32>(
            &TensorData::new(vec![-1.6833, -1.2236, 0.4472, 6.0249], vec![1, 4]),
            Tolerance::absolute(1e-3),
        );
    }

    #[test]
    fn test_layer_norm_no_beta() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f32(vec![1.0; 4], &[4]);
        let out = crate::ops::activation::layer_norm(t, gamma, None, 1e-5);

        out.into_data().assert_approx_eq::<f32>(
            &TensorData::new(
                vec![-1.3416408, -0.4472136, 0.4472136, 1.3416408],
                vec![1, 4],
            ),
            Tolerance::absolute(1e-4),
        );
    }

    // ============================================================================
    // softmax SIMD / rayon boundary tests
    // ============================================================================

    #[test]
    fn test_softmax_simd_body_row() {
        // Row length 32 ensures the SIMD body runs on every supported target:
        // NEON (lanes=4), AVX2 (lanes=8), AVX-512 (lanes=16), SIMD128 (lanes=4).
        let data: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
        let expected = softmax_last_ref(&data, 32);
        let fused = crate::ops::activation::softmax(flex_f32(data, &[1, 32]), 1);
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![1, 32]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_softmax_multi_chunk_rayon() {
        // 100 rows > ROWS_PER_TASK (64) triggers the rayon par_chunks path.
        let data: Vec<f32> = (0..100 * 16).map(|i| ((i % 17) as f32) * 0.05).collect();
        let expected = softmax_last_ref(&data, 16);
        let fused = crate::ops::activation::softmax(flex_f32(data, &[100, 16]), 1);
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![100, 16]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_softmax_f64() {
        // Exercises the softmax_last_dtype! shell + softmax_row_f64 path.
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let expected = softmax_last_ref(&data, 4);
        let fused = crate::ops::activation::softmax(flex_f64(data, &[2, 4]), 1);
        fused.into_data().assert_approx_eq::<f64>(
            &TensorData::new(expected, vec![2, 4]),
            Tolerance::absolute(1e-10),
        );
    }

    #[test]
    fn test_softmax_f16() {
        // Exercises softmax_last_dtype! + softmax_row_half f16 path.
        let source: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 0.5, 0.5, 0.5, 0.5];
        let data: Vec<f16> = source.iter().map(|&x| f16::from_f32(x)).collect();
        let expected = softmax_last_ref(&data, 4);
        let fused = crate::ops::activation::softmax(flex_half(data, &[2, 4]), 1);
        fused.into_data().assert_approx_eq::<f16>(
            &TensorData::new(expected, vec![2, 4]),
            Tolerance::absolute(1e-2),
        );
    }

    #[test]
    fn test_softmax_bf16() {
        // Exercises softmax_last_dtype! + softmax_row_half bf16 path.
        let source: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 0.5, 0.5, 0.5, 0.5];
        let data: Vec<bf16> = source.iter().map(|&x| bf16::from_f32(x)).collect();
        let expected = softmax_last_ref(&data, 4);
        let fused = crate::ops::activation::softmax(flex_half(data, &[2, 4]), 1);
        fused.into_data().assert_approx_eq::<bf16>(
            &TensorData::new(expected, vec![2, 4]),
            Tolerance::absolute(5e-2),
        );
    }

    #[test]
    fn test_layer_norm_multi_chunk_rayon() {
        // 128 rows > ROWS_PER_TASK (64) triggers the rayon path.
        let data: Vec<f32> = (0..128 * 16).map(|i| ((i % 19) as f32) * 0.03).collect();
        let gamma_data: Vec<f32> = vec![1.0; 16];
        let beta_data: Vec<f32> = vec![0.0; 16];
        let expected = layer_norm_last_ref(&data, &gamma_data, Some(&beta_data), 1e-5f32, 16);
        let fused = crate::ops::activation::layer_norm(
            flex_f32(data, &[128, 16]),
            flex_f32(gamma_data, &[16]),
            Some(flex_f32(beta_data, &[16])),
            1e-5,
        );
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![128, 16]),
            Tolerance::absolute(1e-4),
        );
    }

    #[test]
    fn test_softmax_empty_last_dim_returns_input() {
        // shape [2, 0]: empty last dim should round-trip unchanged instead
        // of producing NaN via 0/0.
        let t = flex_f32(Vec::<f32>::new(), &[2, 0]);
        let result = crate::ops::activation::softmax(t, 1);
        assert_eq!(result.shape().as_slice(), &[2, 0]);
    }

    #[test]
    fn test_layer_norm_empty_last_dim_returns_input() {
        let t = flex_f32(Vec::<f32>::new(), &[3, 0]);
        let gamma = flex_f32(Vec::<f32>::new(), &[0]);
        let beta = flex_f32(Vec::<f32>::new(), &[0]);
        let result = crate::ops::activation::layer_norm(t, gamma, Some(beta), 1e-5);
        assert_eq!(result.shape().as_slice(), &[3, 0]);
    }

    #[test]
    #[should_panic(expected = "gamma must be a 1-D tensor")]
    fn test_layer_norm_gamma_length_mismatch_panics() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f32(vec![1.0, 1.0, 1.0], &[3]);
        let _ = crate::ops::activation::layer_norm(t, gamma, None, 1e-5);
    }

    #[test]
    #[should_panic(expected = "beta must be a 1-D tensor")]
    fn test_layer_norm_beta_length_mismatch_panics() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f32(vec![1.0, 1.0, 1.0, 1.0], &[4]);
        let beta = flex_f32(vec![0.0, 0.0, 0.0], &[3]);
        let _ = crate::ops::activation::layer_norm(t, gamma, Some(beta), 1e-5);
    }

    #[test]
    #[should_panic(expected = "gamma must be a 1-D tensor")]
    fn test_layer_norm_gamma_rank_mismatch_panics() {
        // gamma [2, 4] has matching last-dim but rank 2, so the old last-dim
        // check alone would have accepted it and then indexed wrong storage.
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f32(vec![1.0; 8], &[2, 4]);
        let _ = crate::ops::activation::layer_norm(t, gamma, None, 1e-5);
    }

    // Row length 17 leaves exactly one scalar-tail element after every common
    // SIMD width (NEON/SSE f32x4: body=16, tail=1; AVX2 f32x8: body=16, tail=1;
    // AVX-512 f32x16: body=16, tail=1). Row lengths that divide evenly by the
    // SIMD width skip the tail branch entirely, so a bug in the scalar tail
    // kernel would sail past CI without a test like this.
    #[test]
    fn test_softmax_simd_body_plus_scalar_tail() {
        let data: Vec<f32> = (0..34).map(|i| (i as f32 * 0.137) - 2.3).collect();
        let expected = softmax_last_ref(&data, 17);
        let fused = crate::ops::activation::softmax(flex_f32(data, &[2, 17]), 1);
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![2, 17]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_layer_norm_simd_body_plus_scalar_tail() {
        let data: Vec<f32> = (0..34).map(|i| (i as f32 * 0.137) - 2.3).collect();
        let gamma_data: Vec<f32> = (0..17).map(|i| 1.0 + i as f32 * 0.05).collect();
        let beta_data: Vec<f32> = (0..17).map(|i| i as f32 * 0.01).collect();
        let expected = layer_norm_last_ref(&data, &gamma_data, Some(&beta_data), 1e-5f32, 17);
        let fused = crate::ops::activation::layer_norm(
            flex_f32(data, &[2, 17]),
            flex_f32(gamma_data, &[17]),
            Some(flex_f32(beta_data, &[17])),
            1e-5,
        );
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![2, 17]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_layer_norm_f64_with_beta_multi_chunk() {
        // 80 rows > ROWS_PER_TASK (64) exercises the rayon multi-chunk f64 path.
        let d_model = 16;
        let n_rows = 80;
        let data: Vec<f64> = (0..n_rows * d_model)
            .map(|i| ((i % 13) as f64) * 0.07 - 0.3)
            .collect();
        let gamma_data: Vec<f64> = vec![0.9; d_model];
        let beta_data: Vec<f64> = vec![0.05; d_model];
        let eps = 1e-5f64;
        let expected = layer_norm_last_ref(&data, &gamma_data, Some(&beta_data), eps, d_model);
        let fused = crate::ops::activation::layer_norm(
            flex_f64(data, &[n_rows, d_model]),
            flex_f64(gamma_data, &[d_model]),
            Some(flex_f64(beta_data, &[d_model])),
            eps,
        );
        fused.into_data().assert_approx_eq::<f64>(
            &TensorData::new(expected, vec![n_rows, d_model]),
            Tolerance::absolute(1e-10),
        );
    }

    #[test]
    fn test_layer_norm_f64_no_beta() {
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, -1.0, 0.5, 1.5, -0.5];
        let gamma_data: Vec<f64> = vec![1.0; 4];
        let eps = 1e-5f64;
        let expected = layer_norm_last_ref(&data, &gamma_data, None, eps, 4);
        let fused = crate::ops::activation::layer_norm(
            flex_f64(data, &[2, 4]),
            flex_f64(gamma_data, &[4]),
            None,
            eps,
        );
        fused.into_data().assert_approx_eq::<f64>(
            &TensorData::new(expected, vec![2, 4]),
            Tolerance::absolute(1e-10),
        );
    }

    // Shared body for f16/bf16 layer_norm tests. The fused half-precision
    // kernel casts to f32 internally, so the reference is computed in f32
    // and compared back against the half output with an f32 tolerance.
    fn check_layer_norm_half_precision<E>(from_f32: fn(f32) -> E, dtype: DType)
    where
        E: burn_backend::Element + Float,
    {
        let rows_f32: [f32; 12] = [
            1.0, 2.0, 3.0, 4.0, -1.0, 0.0, 1.0, 2.0, 0.5, -0.5, 1.5, -1.5,
        ];
        let gamma_f32: [f32; 4] = [1.0, 0.5, 1.5, 1.0];
        let beta_f32: [f32; 4] = [0.1, -0.1, 0.0, 0.2];
        let eps = 1e-5f32;

        let expected_f32 = layer_norm_last_ref(&rows_f32, &gamma_f32, Some(&beta_f32), eps, 4);

        let data: Vec<E> = rows_f32.iter().map(|&x| from_f32(x)).collect();
        let gamma_data: Vec<E> = gamma_f32.iter().map(|&x| from_f32(x)).collect();
        let beta_data: Vec<E> = beta_f32.iter().map(|&x| from_f32(x)).collect();
        assert_eq!(E::dtype(), dtype);

        let fused = crate::ops::activation::layer_norm(
            flex_half(data, &[3, 4]),
            flex_half(gamma_data, &[4]),
            Some(flex_half(beta_data, &[4])),
            eps as f64,
        );
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected_f32, vec![3, 4]),
            Tolerance::absolute(3e-2),
        );
    }

    #[test]
    fn test_layer_norm_f16_via_f32_cast() {
        check_layer_norm_half_precision::<f16>(f16::from_f32, DType::F16);
    }

    #[test]
    fn test_layer_norm_bf16_via_f32_cast() {
        check_layer_norm_half_precision::<bf16>(bf16::from_f32, DType::BF16);
    }

    #[test]
    #[should_panic(expected = "softmax dim")]
    fn test_softmax_dim_out_of_range_panics() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let _ = crate::ops::activation::softmax(t, 2);
    }

    #[test]
    #[should_panic(expected = "gamma dtype")]
    fn test_layer_norm_gamma_dtype_mismatch_panics() {
        // Input f32, gamma f64: layer_norm rejects the mismatch up front
        // rather than panicking later inside the storage-typed access.
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f64(vec![1.0; 4], &[4]);
        let _ = crate::ops::activation::layer_norm(t, gamma, None, 1e-5);
    }
}