1use alloc::vec;
7use alloc::vec::Vec;
8use burn_backend::Scalar;
9use burn_backend::ops::{ActivationOps, FloatTensorOps};
10use burn_backend::tensor::FloatTensor;
11use burn_backend::{DType, TensorMetadata};
12use burn_std::{Bytes, bf16, f16};
13#[cfg(not(feature = "std"))]
14#[allow(unused_imports)]
15use num_traits::Float;
16use num_traits::ToPrimitive;
17
18use crate::ops::binary::binary_op;
19use crate::ops::unary::unary_op;
20use crate::{Flex, FlexTensor, Layout};
21
impl ActivationOps<Flex> for Flex {
    /// Element-wise ReLU: `max(x, 0)`.
    fn relu(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary_op(tensor, |x: f32| x.max(0.0), |x: f64| x.max(0.0))
    }

    /// ReLU gradient, computed from the forward *output*: the incoming
    /// gradient passes through wherever the output is positive and is zeroed
    /// elsewhere (including at exactly 0).
    fn relu_backward(output: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            output,
            grad,
            |out: f32, g| if out > 0.0 { g } else { 0.0 },
            |out: f64, g| if out > 0.0 { g } else { 0.0 },
            None,
        )
    }

    /// Leaky ReLU: identity for `x >= 0`, `negative_slope * x` otherwise.
    fn leaky_relu(tensor: FloatTensor<Flex>, negative_slope: Scalar) -> FloatTensor<Flex> {
        // Resolve the scalar once per call, for each element width.
        let ns32 = negative_slope.to_f32().unwrap();
        let ns64 = negative_slope.to_f64().unwrap();
        unary_op(
            tensor,
            move |x: f32| if x >= 0.0 { x } else { ns32 * x },
            move |x: f64| if x >= 0.0 { x } else { ns64 * x },
        )
    }

    /// Parametric ReLU: like leaky ReLU, but the negative-side slope is taken
    /// element-wise from the `alpha` tensor via `binary_op`.
    fn prelu(tensor: FloatTensor<Flex>, alpha: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            tensor,
            alpha,
            |x: f32, a| if x >= 0.0 { x } else { a * x },
            |x: f64, a| if x >= 0.0 { x } else { a * x },
            None,
        )
    }

    /// Exact (erf-based) GELU: `0.5 * x * (1 + erf(x / sqrt(2)))`.
    fn gelu(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        use crate::ops::unary::{erf_f32, erf_f64};
        let sqrt2_f32: f32 = core::f32::consts::SQRT_2;
        let sqrt2_f64: f64 = core::f64::consts::SQRT_2;
        unary_op(
            tensor,
            move |x: f32| 0.5 * x * (1.0 + erf_f32(x / sqrt2_f32)),
            move |x: f64| 0.5 * x * (1.0 + erf_f64(x / sqrt2_f64)),
        )
    }

    /// GELU gradient w.r.t. the forward *input* `x`:
    /// `d/dx gelu(x) = Phi(x) + x * phi(x)`, where `Phi`/`phi` are the
    /// standard normal CDF and PDF.
    fn gelu_backward(x: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        use crate::ops::unary::{erf_f32, erf_f64};
        let sqrt2_f32: f32 = core::f32::consts::SQRT_2;
        let sqrt2_f64: f64 = core::f64::consts::SQRT_2;
        // 1 / sqrt(2*pi): normalization constant of the standard normal PDF.
        let inv_sqrt_2pi_f32: f32 = 1.0 / (2.0 * core::f32::consts::PI).sqrt();
        let inv_sqrt_2pi_f64: f64 = 1.0 / (2.0 * core::f64::consts::PI).sqrt();
        binary_op(
            x,
            grad,
            move |x: f32, g| {
                let cdf = 0.5 * (1.0 + erf_f32(x / sqrt2_f32));
                let pdf = inv_sqrt_2pi_f32 * (-0.5 * x * x).exp();
                g * (cdf + x * pdf)
            },
            move |x: f64, g| {
                let cdf = 0.5 * (1.0 + erf_f64(x / sqrt2_f64));
                let pdf = inv_sqrt_2pi_f64 * (-0.5 * x * x).exp();
                g * (cdf + x * pdf)
            },
            None,
        )
    }

    /// Logistic sigmoid, using the overflow-safe helpers below.
    fn sigmoid(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary_op(tensor, sigmoid_f32, sigmoid_f64)
    }

    /// Sigmoid gradient from the forward *output* `s`: `g * s * (1 - s)`.
    fn sigmoid_backward(output: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            output,
            grad,
            |s: f32, g| g * s * (1.0 - s),
            |s: f64, g| g * s * (1.0 - s),
            None,
        )
    }

    /// Hard sigmoid: `clamp(alpha * x + beta, 0, 1)`.
    fn hard_sigmoid(tensor: FloatTensor<Flex>, alpha: Scalar, beta: Scalar) -> FloatTensor<Flex> {
        let alpha32 = alpha.to_f32().unwrap();
        let beta32 = beta.to_f32().unwrap();
        let alpha64 = alpha.to_f64().unwrap();
        let beta64 = beta.to_f64().unwrap();
        unary_op(
            tensor,
            move |x: f32| (alpha32 * x + beta32).clamp(0.0, 1.0),
            move |x: f64| (alpha64 * x + beta64).clamp(0.0, 1.0),
        )
    }

    /// `log(sigmoid(x))` in a numerically stable, sign-branched form:
    /// for `x >= 0` use `-ln(1 + e^-x)`, otherwise `x - ln(1 + e^x)`.
    /// Either way the exponential argument is non-positive, so `exp` cannot
    /// overflow, and `ln_1p` stays accurate for small arguments.
    fn log_sigmoid(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary_op(
            tensor,
            |x: f32| {
                if x >= 0.0 {
                    -((-x).exp().ln_1p())
                } else {
                    x - x.exp().ln_1p()
                }
            },
            |x: f64| {
                if x >= 0.0 {
                    -((-x).exp().ln_1p())
                } else {
                    x - x.exp().ln_1p()
                }
            },
        )
    }

    /// log-sigmoid gradient: `d/dx log(sigmoid(x)) = sigmoid(-x)`.
    fn log_sigmoid_backward(x: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            x,
            grad,
            |x: f32, g| g * sigmoid_f32(-x),
            |x: f64, g| g * sigmoid_f64(-x),
            None,
        )
    }

    /// Softmax along `dim`. The unqualified call resolves to the free
    /// `softmax` function defined later in this module (method names do not
    /// shadow free functions in call position), so this is plain delegation,
    /// not recursion.
    fn softmax(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        softmax(tensor, dim)
    }
}
161
/// Overflow-safe logistic sigmoid for `f32`.
///
/// Both branches keep the argument of `exp` non-positive, so the
/// exponential can underflow but never overflow.
#[inline]
fn sigmoid_f32(x: f32) -> f32 {
    match x >= 0.0 {
        true => 1.0 / (1.0 + (-x).exp()),
        false => {
            let ex = x.exp();
            ex / (1.0 + ex)
        }
    }
}
171
/// Overflow-safe logistic sigmoid for `f64`.
///
/// Both branches keep the argument of `exp` non-positive, so the
/// exponential can underflow but never overflow.
#[inline]
fn sigmoid_f64(x: f64) -> f64 {
    match x >= 0.0 {
        true => 1.0 / (1.0 + (-x).exp()),
        false => {
            let ex = x.exp();
            ex / (1.0 + ex)
        }
    }
}
181
182pub fn softmax(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
205 let rank = tensor.shape().num_dims();
206 assert!(
207 dim < rank,
208 "softmax dim {} out of range for rank {}",
209 dim,
210 rank
211 );
212
213 if dim != rank - 1 {
214 let swapped = Flex::float_swap_dims(tensor, dim, rank - 1);
215 let normed = softmax_last(swapped);
216 return Flex::float_swap_dims(normed, dim, rank - 1);
217 }
218
219 softmax_last(tensor)
220}
221
222fn softmax_last(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
223 let tensor = tensor.to_contiguous();
224 match tensor.dtype() {
225 DType::F32 => softmax_last_f32(tensor),
226 DType::F64 => softmax_last_f64(tensor),
227 DType::F16 => softmax_last_f16(tensor),
228 DType::BF16 => softmax_last_bf16(tensor),
229 dtype => panic!("softmax: unsupported dtype {:?}", dtype),
230 }
231}
232
/// Softmax over the last dimension of a contiguous f32 tensor.
fn softmax_last_f32(tensor: FlexTensor) -> FlexTensor {
    let shape = tensor.layout().shape().clone();
    let last = *shape.last().expect("softmax: empty shape");
    // Zero-length softmax axis: there are no elements to normalize.
    if last == 0 {
        return tensor;
    }
    let input: &[f32] = tensor.storage();
    let n = input.len();

    let mut output: Vec<f32> = vec![0.0; n];
    let out_slice = output.as_mut_slice();

    #[cfg(feature = "rayon")]
    {
        use rayon::prelude::*;
        // Batch rows per parallel task so scheduling overhead stays small
        // relative to per-row work. Chunk sizes are multiples of `last`, so
        // no row is ever split across two tasks (the final chunk may simply
        // hold fewer rows).
        const ROWS_PER_TASK: usize = 64;
        let chunk_elems = ROWS_PER_TASK * last;
        out_slice
            .par_chunks_mut(chunk_elems)
            .zip(input.par_chunks(chunk_elems))
            .for_each(|(o, i)| softmax_rows_f32(i, o, last));
    }
    #[cfg(not(feature = "rayon"))]
    {
        softmax_rows_f32(input, out_slice, last);
    }

    FlexTensor::new(
        Bytes::from_elems(output),
        Layout::contiguous(shape),
        DType::F32,
    )
}
279
/// Applies softmax to each consecutive `row_len`-sized row of `input`,
/// writing results into `output`. Dispatches to the SIMD kernel when the
/// `simd` feature is enabled, otherwise to the scalar per-row kernel.
#[inline]
fn softmax_rows_f32(input: &[f32], output: &mut [f32], row_len: usize) {
    // The slices must hold a whole number of rows of equal total length.
    assert_eq!(input.len(), output.len());
    assert_eq!(input.len() % row_len, 0);
    #[cfg(feature = "simd")]
    softmax_rows_f32_simd(input, output, row_len);
    #[cfg(not(feature = "simd"))]
    {
        for (in_row, out_row) in input.chunks(row_len).zip(output.chunks_mut(row_len)) {
            softmax_row_f32_scalar(in_row, out_row);
        }
    }
}
303
/// SIMD row loop for f32 softmax; `#[macerator::with_simd]` generates the
/// SIMD-specialized entry point for the generic body below.
#[cfg(feature = "simd")]
#[macerator::with_simd]
fn softmax_rows_f32_simd<S: macerator::Simd>(input: &[f32], output: &mut [f32], row_len: usize) {
    // Invariants already asserted by the `softmax_rows_f32` dispatcher.
    debug_assert_eq!(input.len(), output.len());
    debug_assert_eq!(input.len() % row_len, 0);
    for (in_row, out_row) in input.chunks(row_len).zip(output.chunks_mut(row_len)) {
        softmax_row_f32_simd::<S>(in_row, out_row);
    }
}
313
/// Scalar softmax for a single row (used when the `simd` feature is off).
///
/// Classic three-pass stable form: find the row max, exponentiate the
/// shifted values while accumulating their sum, then scale by `1 / sum`.
#[cfg(not(feature = "simd"))]
#[inline]
fn softmax_row_f32_scalar(input: &[f32], output: &mut [f32]) {
    // Pass 1: row maximum (subtracting it keeps every exp argument <= 0).
    let mut max_val = f32::NEG_INFINITY;
    for &x in input {
        max_val = if x > max_val { x } else { max_val };
    }
    // Pass 2: shifted exponentials and their running sum.
    let mut sum = 0.0f32;
    for (slot, &x) in output.iter_mut().zip(input) {
        let e = (x - max_val).exp();
        *slot = e;
        sum += e;
    }
    // Pass 3: normalize so the row sums to 1.
    let inv = 1.0f32 / sum;
    output.iter_mut().for_each(|v| *v *= inv);
}
337
/// Single-row f32 softmax with vectorized max-reduction and normalization;
/// the exponentiation pass is scalar.
#[cfg(feature = "simd")]
#[inline(always)]
fn softmax_row_f32_simd<S: macerator::Simd>(input: &[f32], output: &mut [f32]) {
    use macerator::{Scalar, vload_unaligned, vstore_unaligned};
    let lanes = <f32 as Scalar>::lanes::<S>();
    let len = input.len();
    // Largest prefix that is a whole number of vectors; the rest is the tail.
    let simd_len = len / lanes * lanes;

    // Pass 1: row maximum — vector body plus scalar tail. Unaligned loads
    // are used throughout, so only in-bounds access matters (guaranteed by
    // `simd_len <= len`).
    let (mut max_val, tail_start) = if simd_len >= lanes {
        let mut max_vec = unsafe { vload_unaligned::<S, _>(input.as_ptr()) };
        let mut j = lanes;
        while j < simd_len {
            let v = unsafe { vload_unaligned::<S, _>(input.as_ptr().add(j)) };
            max_vec = max_vec.max(v);
            j += lanes;
        }
        (max_vec.reduce_max(), simd_len)
    } else {
        // Row shorter than one vector: scan everything scalar.
        (f32::NEG_INFINITY, 0)
    };
    for &x in &input[tail_start..] {
        if x > max_val {
            max_val = x;
        }
    }

    // Pass 2 (scalar): shifted exponentials plus their running sum.
    let mut sum = 0.0f32;
    for idx in 0..len {
        let e = (input[idx] - max_val).exp();
        output[idx] = e;
        sum += e;
    }

    // Pass 3: multiply by 1/sum — vector body plus scalar tail.
    let inv = 1.0f32 / sum;
    let inv_vec = inv.splat::<S>();
    let mut i = 0;
    while i < simd_len {
        unsafe {
            let v = vload_unaligned::<S, _>(output.as_ptr().add(i));
            vstore_unaligned::<S, _>(output.as_mut_ptr().add(i), v * inv_vec);
        }
        i += lanes;
    }
    for x in &mut output[i..] {
        *x *= inv;
    }
}
396
/// Generates a `softmax_last_*` entry point for one dtype: splits the flat
/// storage into rows of the last dimension and applies `$row_fn` to each.
///
/// NOTE(review): unlike `softmax_last_f32`, the rayon path here parallelizes
/// one row per task (`par_chunks(last)`) rather than batching rows —
/// presumably acceptable for these less common dtypes; confirm if profiling
/// shows scheduling overhead.
macro_rules! softmax_last_dtype {
    ($fn_name:ident, $T:ty, $zero:expr, $dtype:expr, $row_fn:ident) => {
        fn $fn_name(tensor: FlexTensor) -> FlexTensor {
            let shape = tensor.layout().shape().clone();
            let last = *shape.last().expect("softmax: empty shape");
            // Zero-length softmax axis: nothing to normalize.
            if last == 0 {
                return tensor;
            }
            let input: &[$T] = tensor.storage();
            let mut output: Vec<$T> = vec![$zero; input.len()];

            #[cfg(feature = "rayon")]
            {
                use rayon::prelude::*;
                output
                    .par_chunks_mut(last)
                    .zip(input.par_chunks(last))
                    .for_each(|(o, i)| $row_fn(i, o));
            }
            #[cfg(not(feature = "rayon"))]
            {
                for (i, o) in input.chunks(last).zip(output.chunks_mut(last)) {
                    $row_fn(i, o);
                }
            }

            FlexTensor::new(Bytes::from_elems(output), Layout::contiguous(shape), $dtype)
        }
    };
}
432
/// Generates a single-row softmax for a half-precision type (`f16`/`bf16`).
/// All arithmetic accumulates in f32 to limit rounding error; only the
/// stores convert back to the narrow type.
macro_rules! softmax_row_half {
    ($fn_name:ident, $T:ty) => {
        #[inline]
        fn $fn_name(input: &[$T], output: &mut [$T]) {
            // Row maximum (in f32) for the usual overflow-avoiding shift.
            let mut max_val = f32::NEG_INFINITY;
            for &x in input {
                let xf = x.to_f32();
                if xf > max_val {
                    max_val = xf;
                }
            }
            // `sum` accumulates the unrounded f32 exponentials, not the
            // rounded narrow values written to `output`.
            let mut sum = 0.0f32;
            for (i, &x) in input.iter().enumerate() {
                let e = (x.to_f32() - max_val).exp();
                output[i] = <$T>::from_f32(e);
                sum += e;
            }
            let inv = 1.0f32 / sum;
            for x in output.iter_mut() {
                *x = <$T>::from_f32(x.to_f32() * inv);
            }
        }
    };
}
462
/// Scalar softmax for a single f64 row.
///
/// Stable three-pass form: row max, shifted exponentials with a running
/// sum, then normalization by `1 / sum`.
#[inline]
fn softmax_row_f64(input: &[f64], output: &mut [f64]) {
    // Pass 1: row maximum (keeps every exp argument <= 0).
    let max_val = input
        .iter()
        .fold(f64::NEG_INFINITY, |m, &x| if x > m { x } else { m });
    // Pass 2: shifted exponentials and their sum.
    let mut sum = 0.0f64;
    for (slot, &x) in output.iter_mut().zip(input) {
        let e = (x - max_val).exp();
        *slot = e;
        sum += e;
    }
    // Pass 3: normalize so the row sums to 1.
    let inv = 1.0f64 / sum;
    for v in output.iter_mut() {
        *v *= inv;
    }
}
482
// Instantiate the half-precision row kernels (accumulation happens in f32).
softmax_row_half!(softmax_row_f16, f16);
softmax_row_half!(softmax_row_bf16, bf16);

// Instantiate the per-dtype entry points consumed by `softmax_last`.
softmax_last_dtype!(softmax_last_f64, f64, 0.0f64, DType::F64, softmax_row_f64);
softmax_last_dtype!(
    softmax_last_f16,
    f16,
    f16::from_f32(0.0),
    DType::F16,
    softmax_row_f16
);
softmax_last_dtype!(
    softmax_last_bf16,
    bf16,
    bf16::from_f32(0.0),
    DType::BF16,
    softmax_row_bf16
);
501
/// Layer normalization over the last dimension of `input`:
/// per row, `y = (x - mean) / sqrt(var + epsilon) * gamma (+ beta)`.
///
/// `gamma` (scale) and the optional `beta` (shift) must be 1-D tensors whose
/// length equals the last dimension of `input`, with the same dtype as
/// `input`. F16/BF16 inputs are computed through an f32 round-trip.
///
/// # Panics
/// On rank-0 input, gamma/beta dtype or shape mismatch, or an unsupported
/// dtype.
pub fn layer_norm(
    input: FloatTensor<Flex>,
    gamma: FloatTensor<Flex>,
    beta: Option<FloatTensor<Flex>>,
    epsilon: f64,
) -> FloatTensor<Flex> {
    let rank = input.shape().num_dims();
    assert!(rank >= 1, "layer_norm: input must have at least one dim");
    assert_eq!(
        gamma.dtype(),
        input.dtype(),
        "layer_norm: gamma dtype {:?} does not match input dtype {:?}",
        gamma.dtype(),
        input.dtype(),
    );
    if let Some(ref b) = beta {
        assert_eq!(
            b.dtype(),
            input.dtype(),
            "layer_norm: beta dtype {:?} does not match input dtype {:?}",
            b.dtype(),
            input.dtype(),
        );
    }
    // The row kernels index flat slices, so force contiguous layouts first.
    let input = input.to_contiguous();
    let gamma = gamma.to_contiguous();
    let beta = beta.map(|b| b.to_contiguous());

    let d_model = *input
        .layout()
        .shape()
        .last()
        .expect("layer_norm: empty shape");
    let gamma_shape = gamma.layout().shape();
    assert!(
        gamma_shape.len() == 1 && gamma_shape[0] == d_model,
        "layer_norm: gamma must be a 1-D tensor of length equal to last dim of input \
        (got shape {:?}, expected [{}])",
        gamma_shape,
        d_model,
    );
    if let Some(ref b) = beta {
        let beta_shape = b.layout().shape();
        assert!(
            beta_shape.len() == 1 && beta_shape[0] == d_model,
            "layer_norm: beta must be a 1-D tensor of length equal to last dim of input \
            (got shape {:?}, expected [{}])",
            beta_shape,
            d_model,
        );
    }

    match input.dtype() {
        DType::F32 => layer_norm_f32(input, gamma, beta, epsilon as f32),
        DType::F64 => layer_norm_f64(input, gamma, beta, epsilon),
        DType::F16 => {
            layer_norm_via_f32::<f16>(input, gamma, beta, epsilon, f16::to_f32, f16::from_f32)
        }
        DType::BF16 => {
            layer_norm_via_f32::<bf16>(input, gamma, beta, epsilon, bf16::to_f32, bf16::from_f32)
        }
        dtype => panic!("burn_flex::layer_norm: unsupported dtype {:?}", dtype),
    }
}
606
607fn layer_norm_via_f32<E: burn_backend::Element + bytemuck::Pod + Copy>(
608 input: FlexTensor,
609 gamma: FlexTensor,
610 beta: Option<FlexTensor>,
611 epsilon: f64,
612 to_f32: fn(E) -> f32,
613 from_f32: fn(f32) -> E,
614) -> FlexTensor {
615 let input_f32 = crate::ops::module::cast_to_f32::<E>(input, to_f32);
616 let gamma_f32 = crate::ops::module::cast_to_f32::<E>(gamma, to_f32);
617 let beta_f32 = beta.map(|b| crate::ops::module::cast_to_f32::<E>(b, to_f32));
618 let out = layer_norm_f32(input_f32, gamma_f32, beta_f32, epsilon as f32);
619 crate::ops::module::cast_from_f32::<E>(out, from_f32)
620}
621
/// f64 layer norm over the last dimension of a contiguous tensor.
/// Per-row mean/variance come from a single Welford pass (`welford_f64`).
fn layer_norm_f64(
    input: FlexTensor,
    gamma: FlexTensor,
    beta: Option<FlexTensor>,
    epsilon: f64,
) -> FlexTensor {
    let shape = input.layout().shape().clone();
    let d_model = *shape.last().expect("layer_norm: empty shape");
    // Empty rows: nothing to normalize.
    if d_model == 0 {
        return input;
    }
    let input_data: &[f64] = input.storage();
    let gamma_data: &[f64] = gamma.storage();
    let beta_data: Option<&[f64]> = beta.as_ref().map(|b| b.storage());
    let mut output: Vec<f64> = vec![0.0; input_data.len()];

    #[cfg(feature = "rayon")]
    {
        use rayon::prelude::*;
        // Batch rows per task; chunk sizes are multiples of `d_model`, so no
        // row straddles two tasks.
        const ROWS_PER_TASK: usize = 64;
        let chunk_elems = ROWS_PER_TASK * d_model;
        // The beta/no-beta split keeps the inner row loops branch-free.
        match beta_data {
            Some(beta_slice) => {
                output
                    .par_chunks_mut(chunk_elems)
                    .zip(input_data.par_chunks(chunk_elems))
                    .for_each(|(o, i)| {
                        layer_norm_rows_f64_with_beta(
                            i, o, gamma_data, beta_slice, d_model, epsilon,
                        );
                    });
            }
            None => {
                output
                    .par_chunks_mut(chunk_elems)
                    .zip(input_data.par_chunks(chunk_elems))
                    .for_each(|(o, i)| {
                        layer_norm_rows_f64_no_beta(i, o, gamma_data, d_model, epsilon);
                    });
            }
        }
    }
    #[cfg(not(feature = "rayon"))]
    {
        match beta_data {
            Some(beta_slice) => layer_norm_rows_f64_with_beta(
                input_data,
                output.as_mut_slice(),
                gamma_data,
                beta_slice,
                d_model,
                epsilon,
            ),
            None => layer_norm_rows_f64_no_beta(
                input_data,
                output.as_mut_slice(),
                gamma_data,
                d_model,
                epsilon,
            ),
        }
    }

    FlexTensor::new(
        Bytes::from_elems(output),
        Layout::contiguous(shape),
        DType::F64,
    )
}
695
696#[inline]
697fn layer_norm_rows_f64_with_beta(
698 input: &[f64],
699 output: &mut [f64],
700 gamma: &[f64],
701 beta: &[f64],
702 d_model: usize,
703 epsilon: f64,
704) {
705 for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
706 let (mean, inv_std) = welford_f64(in_row, epsilon);
707 for (i, &x) in in_row.iter().enumerate() {
708 out_row[i] = (x - mean) * (inv_std * gamma[i]) + beta[i];
709 }
710 }
711}
712
713#[inline]
714fn layer_norm_rows_f64_no_beta(
715 input: &[f64],
716 output: &mut [f64],
717 gamma: &[f64],
718 d_model: usize,
719 epsilon: f64,
720) {
721 for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
722 let (mean, inv_std) = welford_f64(in_row, epsilon);
723 for (i, &x) in in_row.iter().enumerate() {
724 out_row[i] = (x - mean) * (inv_std * gamma[i]);
725 }
726 }
727}
728
/// Single-pass Welford estimate over one row.
///
/// Returns `(mean, 1 / sqrt(var + epsilon))`, where `var` is the population
/// variance (divided by `n`, not `n - 1`). Callers guarantee a non-empty row.
#[inline]
fn welford_f64(row: &[f64], epsilon: f64) -> (f64, f64) {
    let (mut mean, mut m2) = (0.0f64, 0.0f64);
    for (k, &x) in row.iter().enumerate() {
        // Standard Welford update: m2 accumulates sum of squared deviations.
        let delta = x - mean;
        mean += delta / ((k + 1) as f64);
        m2 += delta * (x - mean);
    }
    let var = m2 / row.len() as f64;
    (mean, 1.0f64 / (var + epsilon).sqrt())
}
742
/// f32 layer norm over the last dimension of a contiguous tensor.
/// Row work is delegated to the SIMD or scalar kernels via the
/// `layer_norm_rows_f32_*` dispatchers.
fn layer_norm_f32(
    input: FlexTensor,
    gamma: FlexTensor,
    beta: Option<FlexTensor>,
    epsilon: f32,
) -> FlexTensor {
    let shape = input.layout().shape().clone();
    let d_model = *shape.last().expect("layer_norm: empty shape");
    // Empty rows: nothing to normalize.
    if d_model == 0 {
        return input;
    }

    let input_data: &[f32] = input.storage();
    let gamma_data: &[f32] = gamma.storage();
    let beta_data: Option<&[f32]> = beta.as_ref().map(|b| b.storage());

    let n = input_data.len();
    let mut output: Vec<f32> = vec![0.0; n];
    let out_slice = output.as_mut_slice();

    #[cfg(feature = "rayon")]
    {
        use rayon::prelude::*;
        // Batch rows per task; chunk sizes are multiples of `d_model`, so no
        // row straddles two tasks.
        const ROWS_PER_TASK: usize = 64;
        let chunk_elems = ROWS_PER_TASK * d_model;
        // The beta/no-beta split keeps the inner row loops branch-free.
        match beta_data {
            Some(beta_slice) => {
                out_slice
                    .par_chunks_mut(chunk_elems)
                    .zip(input_data.par_chunks(chunk_elems))
                    .for_each(|(o, i)| {
                        layer_norm_rows_f32_with_beta(
                            i, o, gamma_data, beta_slice, d_model, epsilon,
                        );
                    });
            }
            None => {
                out_slice
                    .par_chunks_mut(chunk_elems)
                    .zip(input_data.par_chunks(chunk_elems))
                    .for_each(|(o, i)| {
                        layer_norm_rows_f32_no_beta(i, o, gamma_data, d_model, epsilon);
                    });
            }
        }
    }
    #[cfg(not(feature = "rayon"))]
    {
        match beta_data {
            Some(beta_slice) => layer_norm_rows_f32_with_beta(
                input_data, out_slice, gamma_data, beta_slice, d_model, epsilon,
            ),
            None => {
                layer_norm_rows_f32_no_beta(input_data, out_slice, gamma_data, d_model, epsilon)
            }
        }
    }

    FlexTensor::new(
        Bytes::from_elems(output),
        Layout::contiguous(shape),
        DType::F32,
    )
}
815
/// Per-row f32 layer norm with bias over flat slices holding a whole number
/// of `d_model`-sized rows. Dispatches to the SIMD kernel when available.
#[inline]
fn layer_norm_rows_f32_with_beta(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    assert_eq!(input.len(), output.len());
    assert_eq!(input.len() % d_model, 0);
    assert_eq!(gamma.len(), d_model);
    assert_eq!(beta.len(), d_model);
    #[cfg(feature = "simd")]
    layer_norm_rows_f32_with_beta_simd(input, output, gamma, beta, d_model, epsilon);
    #[cfg(not(feature = "simd"))]
    {
        for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
            layer_norm_row_f32_scalar(in_row, out_row, gamma, Some(beta), epsilon);
        }
    }
}
841
/// Per-row f32 layer norm without bias over flat slices holding a whole
/// number of `d_model`-sized rows. Dispatches to the SIMD kernel when
/// available.
#[inline]
fn layer_norm_rows_f32_no_beta(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    assert_eq!(input.len(), output.len());
    assert_eq!(input.len() % d_model, 0);
    assert_eq!(gamma.len(), d_model);
    #[cfg(feature = "simd")]
    layer_norm_rows_f32_no_beta_simd(input, output, gamma, d_model, epsilon);
    #[cfg(not(feature = "simd"))]
    {
        for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
            layer_norm_row_f32_scalar(in_row, out_row, gamma, None, epsilon);
        }
    }
}
864
/// Scalar single-row f32 layer norm (used when the `simd` feature is off).
///
/// Statistics come from a one-pass Welford update; the row is then scaled by
/// `inv_std * gamma[i]` with an optional `beta[i]` shift.
#[cfg(not(feature = "simd"))]
#[inline]
fn layer_norm_row_f32_scalar(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: Option<&[f32]>,
    epsilon: f32,
) {
    // One-pass Welford mean/variance for this row.
    let (mut mean, mut m2) = (0.0f32, 0.0f32);
    for (k, &x) in input.iter().enumerate() {
        let delta = x - mean;
        mean += delta / ((k + 1) as f32);
        m2 += delta * (x - mean);
    }
    let var = m2 / input.len() as f32;
    let inv_std = 1.0f32 / (var + epsilon).sqrt();

    // Hoist the beta branch out of the element loop.
    match beta {
        Some(b) => {
            for (i, &x) in input.iter().enumerate() {
                output[i] = (x - mean) * (inv_std * gamma[i]) + b[i];
            }
        }
        None => {
            for (i, &x) in input.iter().enumerate() {
                output[i] = (x - mean) * (inv_std * gamma[i]);
            }
        }
    }
}
906
/// SIMD row loop for f32 layer norm with bias; `#[macerator::with_simd]`
/// generates the SIMD-specialized entry point for the generic body below.
#[cfg(feature = "simd")]
#[macerator::with_simd]
fn layer_norm_rows_f32_with_beta_simd<S: macerator::Simd>(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    // Invariants already asserted by the non-SIMD dispatcher.
    debug_assert_eq!(input.len(), output.len());
    debug_assert_eq!(input.len() % d_model, 0);
    debug_assert_eq!(gamma.len(), d_model);
    debug_assert_eq!(beta.len(), d_model);
    for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
        layer_norm_row_f32_simd::<S>(in_row, out_row, gamma, Some(beta), epsilon);
    }
}
927
/// SIMD row loop for f32 layer norm without bias; see the `with_beta`
/// counterpart above.
#[cfg(feature = "simd")]
#[macerator::with_simd]
fn layer_norm_rows_f32_no_beta_simd<S: macerator::Simd>(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    // Invariants already asserted by the non-SIMD dispatcher.
    debug_assert_eq!(input.len(), output.len());
    debug_assert_eq!(input.len() % d_model, 0);
    debug_assert_eq!(gamma.len(), d_model);
    for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
        layer_norm_row_f32_simd::<S>(in_row, out_row, gamma, None, epsilon);
    }
}
945
/// Single-row f32 layer norm, vectorized.
///
/// NOTE(review): variance here uses the two-accumulator `E[x^2] - E[x]^2`
/// form, unlike the scalar path's Welford update; it can lose precision (and
/// dip slightly negative from cancellation) for rows with a large mean — the
/// `+ epsilon` before the sqrt absorbs small negatives. Confirm this is
/// acceptable for expected input ranges.
#[cfg(feature = "simd")]
#[inline(always)]
fn layer_norm_row_f32_simd<S: macerator::Simd>(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: Option<&[f32]>,
    epsilon: f32,
) {
    use macerator::{Scalar, vload_unaligned, vstore_unaligned};
    let lanes = <f32 as Scalar>::lanes::<S>();
    let len = input.len();
    // Largest prefix that is a whole number of vectors; the rest is the tail.
    let simd_len = len / lanes * lanes;

    // Pass 1: sum and sum-of-squares — vector body plus scalar tail.
    let (sum, sumsq) = if simd_len >= lanes {
        let mut acc_sum = 0.0f32.splat::<S>();
        let mut acc_sumsq = 0.0f32.splat::<S>();
        let mut i = 0;
        while i < simd_len {
            unsafe {
                // Unaligned loads: only in-bounds access is required, which
                // `simd_len <= len` guarantees.
                let v = vload_unaligned::<S, _>(input.as_ptr().add(i));
                acc_sum += v;
                acc_sumsq = v.mul_add(v, acc_sumsq);
            }
            i += lanes;
        }
        let mut s = acc_sum.reduce_add();
        let mut sq = acc_sumsq.reduce_add();
        for &x in &input[simd_len..] {
            s += x;
            sq += x * x;
        }
        (s, sq)
    } else {
        // Row shorter than one vector: accumulate everything scalar.
        let mut s = 0.0f32;
        let mut sq = 0.0f32;
        for &x in input {
            s += x;
            sq += x * x;
        }
        (s, sq)
    };

    let n = len as f32;
    let mean = sum / n;
    let var = (sumsq / n) - mean * mean;
    let inv_std = 1.0f32 / (var + epsilon).sqrt();

    // Pass 2: normalize, scale by gamma, and optionally add beta.
    let mean_vec = mean.splat::<S>();
    let inv_std_vec = inv_std.splat::<S>();
    let mut i = 0;
    while i < simd_len {
        unsafe {
            let x = vload_unaligned::<S, _>(input.as_ptr().add(i));
            let g = vload_unaligned::<S, _>(gamma.as_ptr().add(i));
            let scale = inv_std_vec * g;
            let centered = x - mean_vec;
            let normed = centered * scale;
            let out = if let Some(b) = beta {
                let b_vec = vload_unaligned::<S, _>(b.as_ptr().add(i));
                normed + b_vec
            } else {
                normed
            };
            vstore_unaligned::<S, _>(output.as_mut_ptr().add(i), out);
        }
        i += lanes;
    }
    // Scalar tail for the remaining `len - simd_len` elements.
    while i < len {
        let centered = input[i] - mean;
        let normed = centered * inv_std * gamma[i];
        output[i] = match beta {
            Some(b) => normed + b[i],
            None => normed,
        };
        i += 1;
    }
}
1042
#[cfg(test)]
mod tests {
    use alloc::vec;
    // Fix: inner modules do not inherit the file-level `use alloc::vec::Vec`,
    // and under `no_std` there is no prelude `Vec`, yet the reference helpers
    // below use `Vec` in their signatures. Import it explicitly (harmless and
    // merely shadows the prelude when the `std` feature is enabled).
    use alloc::vec::Vec;
    use burn_backend::{DType, TensorData, TensorMetadata, Tolerance};
    use burn_std::{bf16, f16};
    use num_traits::Float;

    use crate::FlexTensor;

    /// Reference softmax for a single row (textbook max-shifted form).
    fn softmax_row<T: Float>(row_in: &[T], row_out: &mut [T]) {
        let max = row_in
            .iter()
            .copied()
            .fold(T::neg_infinity(), |a, b| if a > b { a } else { b });
        let mut sum = T::zero();
        for (i, &x) in row_in.iter().enumerate() {
            let e = (x - max).exp();
            row_out[i] = e;
            sum = sum + e;
        }
        for v in row_out.iter_mut() {
            *v = *v / sum;
        }
    }

    /// Reference softmax over consecutive rows of length `row_len`.
    fn softmax_last_ref<T: Float>(data: &[T], row_len: usize) -> Vec<T> {
        let mut out = vec![T::zero(); data.len()];
        for (i, o) in data.chunks(row_len).zip(out.chunks_mut(row_len)) {
            softmax_row(i, o);
        }
        out
    }

    /// Reference layer norm for a single row (two-pass mean/variance).
    fn layer_norm_row<T: Float>(
        row_in: &[T],
        gamma: &[T],
        beta: Option<&[T]>,
        eps: T,
        row_out: &mut [T],
    ) {
        let n = T::from(row_in.len()).unwrap();
        let mean = row_in.iter().copied().fold(T::zero(), |a, b| a + b) / n;
        let var = row_in
            .iter()
            .map(|&x| (x - mean) * (x - mean))
            .fold(T::zero(), |a, b| a + b)
            / n;
        let inv_std = T::one() / (var + eps).sqrt();
        for (i, &x) in row_in.iter().enumerate() {
            let normed = (x - mean) * inv_std;
            let scaled = normed * gamma[i];
            row_out[i] = match beta {
                Some(b) => scaled + b[i],
                None => scaled,
            };
        }
    }

    /// Reference layer norm over consecutive rows of length `row_len`.
    fn layer_norm_last_ref<T: Float>(
        data: &[T],
        gamma: &[T],
        beta: Option<&[T]>,
        eps: T,
        row_len: usize,
    ) -> Vec<T> {
        let mut out = vec![T::zero(); data.len()];
        for (i, o) in data.chunks(row_len).zip(out.chunks_mut(row_len)) {
            layer_norm_row(i, gamma, beta, eps, o);
        }
        out
    }

    // Tensor construction helpers for the dtypes under test.
    fn flex_f32(data: Vec<f32>, shape: &[usize]) -> FlexTensor {
        FlexTensor::from_data(TensorData::new(data, shape.to_vec()))
    }

    fn flex_f64(data: Vec<f64>, shape: &[usize]) -> FlexTensor {
        FlexTensor::from_data(TensorData::new(data, shape.to_vec()))
    }

    fn flex_half<T: burn_backend::Element>(data: Vec<T>, shape: &[usize]) -> FlexTensor {
        FlexTensor::from_data(TensorData::new(data, shape.to_vec()))
    }

    #[test]
    fn test_layer_norm_2d_with_beta() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]);
        let gamma = flex_f32(vec![1.0; 4], &[4]);
        let beta = flex_f32(vec![0.0; 4], &[4]);
        let out = crate::ops::activation::layer_norm(t, gamma, Some(beta), 1e-5);

        let expected: Vec<f32> = vec![
            -1.3416408, -0.4472136, 0.4472136, 1.3416408, -1.3416408, -0.4472136, 0.4472136,
            1.3416408,
        ];
        out.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![2, 4]),
            Tolerance::absolute(1e-4),
        );
    }

    #[test]
    fn test_layer_norm_with_affine() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f32(vec![2.0, 0.5, 1.0, 3.0], &[4]);
        let beta = flex_f32(vec![1.0, -1.0, 0.0, 2.0], &[4]);
        let out = crate::ops::activation::layer_norm(t, gamma, Some(beta), 1e-5);

        out.into_data().assert_approx_eq::<f32>(
            &TensorData::new(vec![-1.6833, -1.2236, 0.4472, 6.0249], vec![1, 4]),
            Tolerance::absolute(1e-3),
        );
    }

    #[test]
    fn test_layer_norm_no_beta() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f32(vec![1.0; 4], &[4]);
        let out = crate::ops::activation::layer_norm(t, gamma, None, 1e-5);

        out.into_data().assert_approx_eq::<f32>(
            &TensorData::new(
                vec![-1.3416408, -0.4472136, 0.4472136, 1.3416408],
                vec![1, 4],
            ),
            Tolerance::absolute(1e-4),
        );
    }

    // Row length of 32 keeps the SIMD body busy without a scalar tail.
    #[test]
    fn test_softmax_simd_body_row() {
        let data: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
        let expected = softmax_last_ref(&data, 32);
        let fused = crate::ops::activation::softmax(flex_f32(data, &[1, 32]), 1);
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![1, 32]),
            Tolerance::absolute(1e-5),
        );
    }

    // 100 rows exceeds ROWS_PER_TASK (64), exercising multiple rayon chunks.
    #[test]
    fn test_softmax_multi_chunk_rayon() {
        let data: Vec<f32> = (0..100 * 16).map(|i| ((i % 17) as f32) * 0.05).collect();
        let expected = softmax_last_ref(&data, 16);
        let fused = crate::ops::activation::softmax(flex_f32(data, &[100, 16]), 1);
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![100, 16]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_softmax_f64() {
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let expected = softmax_last_ref(&data, 4);
        let fused = crate::ops::activation::softmax(flex_f64(data, &[2, 4]), 1);
        fused.into_data().assert_approx_eq::<f64>(
            &TensorData::new(expected, vec![2, 4]),
            Tolerance::absolute(1e-10),
        );
    }

    #[test]
    fn test_softmax_f16() {
        let source: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 0.5, 0.5, 0.5, 0.5];
        let data: Vec<f16> = source.iter().map(|&x| f16::from_f32(x)).collect();
        let expected = softmax_last_ref(&data, 4);
        let fused = crate::ops::activation::softmax(flex_half(data, &[2, 4]), 1);
        fused.into_data().assert_approx_eq::<f16>(
            &TensorData::new(expected, vec![2, 4]),
            Tolerance::absolute(1e-2),
        );
    }

    #[test]
    fn test_softmax_bf16() {
        let source: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 0.5, 0.5, 0.5, 0.5];
        let data: Vec<bf16> = source.iter().map(|&x| bf16::from_f32(x)).collect();
        let expected = softmax_last_ref(&data, 4);
        let fused = crate::ops::activation::softmax(flex_half(data, &[2, 4]), 1);
        fused.into_data().assert_approx_eq::<bf16>(
            &TensorData::new(expected, vec![2, 4]),
            Tolerance::absolute(5e-2),
        );
    }

    // 128 rows exceeds ROWS_PER_TASK (64), exercising multiple rayon chunks.
    #[test]
    fn test_layer_norm_multi_chunk_rayon() {
        let data: Vec<f32> = (0..128 * 16).map(|i| ((i % 19) as f32) * 0.03).collect();
        let gamma_data: Vec<f32> = vec![1.0; 16];
        let beta_data: Vec<f32> = vec![0.0; 16];
        let expected = layer_norm_last_ref(&data, &gamma_data, Some(&beta_data), 1e-5f32, 16);
        let fused = crate::ops::activation::layer_norm(
            flex_f32(data, &[128, 16]),
            flex_f32(gamma_data, &[16]),
            Some(flex_f32(beta_data, &[16])),
            1e-5,
        );
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![128, 16]),
            Tolerance::absolute(1e-4),
        );
    }

    #[test]
    fn test_softmax_empty_last_dim_returns_input() {
        let t = flex_f32(Vec::<f32>::new(), &[2, 0]);
        let result = crate::ops::activation::softmax(t, 1);
        assert_eq!(result.shape().as_slice(), &[2, 0]);
    }

    #[test]
    fn test_layer_norm_empty_last_dim_returns_input() {
        let t = flex_f32(Vec::<f32>::new(), &[3, 0]);
        let gamma = flex_f32(Vec::<f32>::new(), &[0]);
        let beta = flex_f32(Vec::<f32>::new(), &[0]);
        let result = crate::ops::activation::layer_norm(t, gamma, Some(beta), 1e-5);
        assert_eq!(result.shape().as_slice(), &[3, 0]);
    }

    #[test]
    #[should_panic(expected = "gamma must be a 1-D tensor")]
    fn test_layer_norm_gamma_length_mismatch_panics() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f32(vec![1.0, 1.0, 1.0], &[3]);
        let _ = crate::ops::activation::layer_norm(t, gamma, None, 1e-5);
    }

    #[test]
    #[should_panic(expected = "beta must be a 1-D tensor")]
    fn test_layer_norm_beta_length_mismatch_panics() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f32(vec![1.0, 1.0, 1.0, 1.0], &[4]);
        let beta = flex_f32(vec![0.0, 0.0, 0.0], &[3]);
        let _ = crate::ops::activation::layer_norm(t, gamma, Some(beta), 1e-5);
    }

    #[test]
    #[should_panic(expected = "gamma must be a 1-D tensor")]
    fn test_layer_norm_gamma_rank_mismatch_panics() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f32(vec![1.0; 8], &[2, 4]);
        let _ = crate::ops::activation::layer_norm(t, gamma, None, 1e-5);
    }

    // Row length of 17 forces a SIMD body plus a non-empty scalar tail.
    #[test]
    fn test_softmax_simd_body_plus_scalar_tail() {
        let data: Vec<f32> = (0..34).map(|i| (i as f32 * 0.137) - 2.3).collect();
        let expected = softmax_last_ref(&data, 17);
        let fused = crate::ops::activation::softmax(flex_f32(data, &[2, 17]), 1);
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![2, 17]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_layer_norm_simd_body_plus_scalar_tail() {
        let data: Vec<f32> = (0..34).map(|i| (i as f32 * 0.137) - 2.3).collect();
        let gamma_data: Vec<f32> = (0..17).map(|i| 1.0 + i as f32 * 0.05).collect();
        let beta_data: Vec<f32> = (0..17).map(|i| i as f32 * 0.01).collect();
        let expected = layer_norm_last_ref(&data, &gamma_data, Some(&beta_data), 1e-5f32, 17);
        let fused = crate::ops::activation::layer_norm(
            flex_f32(data, &[2, 17]),
            flex_f32(gamma_data, &[17]),
            Some(flex_f32(beta_data, &[17])),
            1e-5,
        );
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected, vec![2, 17]),
            Tolerance::absolute(1e-5),
        );
    }

    #[test]
    fn test_layer_norm_f64_with_beta_multi_chunk() {
        let d_model = 16;
        let n_rows = 80;
        let data: Vec<f64> = (0..n_rows * d_model)
            .map(|i| ((i % 13) as f64) * 0.07 - 0.3)
            .collect();
        let gamma_data: Vec<f64> = vec![0.9; d_model];
        let beta_data: Vec<f64> = vec![0.05; d_model];
        let eps = 1e-5f64;
        let expected = layer_norm_last_ref(&data, &gamma_data, Some(&beta_data), eps, d_model);
        let fused = crate::ops::activation::layer_norm(
            flex_f64(data, &[n_rows, d_model]),
            flex_f64(gamma_data, &[d_model]),
            Some(flex_f64(beta_data, &[d_model])),
            eps,
        );
        fused.into_data().assert_approx_eq::<f64>(
            &TensorData::new(expected, vec![n_rows, d_model]),
            Tolerance::absolute(1e-10),
        );
    }

    #[test]
    fn test_layer_norm_f64_no_beta() {
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, -1.0, 0.5, 1.5, -0.5];
        let gamma_data: Vec<f64> = vec![1.0; 4];
        let eps = 1e-5f64;
        let expected = layer_norm_last_ref(&data, &gamma_data, None, eps, 4);
        let fused = crate::ops::activation::layer_norm(
            flex_f64(data, &[2, 4]),
            flex_f64(gamma_data, &[4]),
            None,
            eps,
        );
        fused.into_data().assert_approx_eq::<f64>(
            &TensorData::new(expected, vec![2, 4]),
            Tolerance::absolute(1e-10),
        );
    }

    /// Shared body for the half-precision layer-norm tests: compares the
    /// narrow-dtype path against an f32 reference within a loose tolerance.
    fn check_layer_norm_half_precision<E>(from_f32: fn(f32) -> E, dtype: DType)
    where
        E: burn_backend::Element + Float,
    {
        let rows_f32: [f32; 12] = [
            1.0, 2.0, 3.0, 4.0, -1.0, 0.0, 1.0, 2.0, 0.5, -0.5, 1.5, -1.5,
        ];
        let gamma_f32: [f32; 4] = [1.0, 0.5, 1.5, 1.0];
        let beta_f32: [f32; 4] = [0.1, -0.1, 0.0, 0.2];
        let eps = 1e-5f32;

        let expected_f32 = layer_norm_last_ref(&rows_f32, &gamma_f32, Some(&beta_f32), eps, 4);

        let data: Vec<E> = rows_f32.iter().map(|&x| from_f32(x)).collect();
        let gamma_data: Vec<E> = gamma_f32.iter().map(|&x| from_f32(x)).collect();
        let beta_data: Vec<E> = beta_f32.iter().map(|&x| from_f32(x)).collect();
        assert_eq!(E::dtype(), dtype);

        let fused = crate::ops::activation::layer_norm(
            flex_half(data, &[3, 4]),
            flex_half(gamma_data, &[4]),
            Some(flex_half(beta_data, &[4])),
            eps as f64,
        );
        fused.into_data().assert_approx_eq::<f32>(
            &TensorData::new(expected_f32, vec![3, 4]),
            Tolerance::absolute(3e-2),
        );
    }

    #[test]
    fn test_layer_norm_f16_via_f32_cast() {
        check_layer_norm_half_precision::<f16>(f16::from_f32, DType::F16);
    }

    #[test]
    fn test_layer_norm_bf16_via_f32_cast() {
        check_layer_norm_half_precision::<bf16>(bf16::from_f32, DType::BF16);
    }

    #[test]
    #[should_panic(expected = "softmax dim")]
    fn test_softmax_dim_out_of_range_panics() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let _ = crate::ops::activation::softmax(t, 2);
    }

    #[test]
    #[should_panic(expected = "gamma dtype")]
    fn test_layer_norm_gamma_dtype_mismatch_panics() {
        let t = flex_f32(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gamma = flex_f64(vec![1.0; 4], &[4]);
        let _ = crate::ops::activation::layer_norm(t, gamma, None, 1e-5);
    }
}