use alloc::vec;
use alloc::vec::Vec;
use burn_backend::Scalar;
use burn_backend::ops::ActivationOps;
use burn_backend::tensor::FloatTensor;
use burn_backend::{DType, TensorMetadata};
use burn_std::{Bytes, bf16, f16};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use num_traits::ToPrimitive;
use crate::ops::binary::binary_op;
use crate::ops::unary::unary_op;
use crate::{Flex, FlexTensor, Layout};
/// Element-wise activation functions for the `Flex` CPU backend.
///
/// Every op delegates to the generic `unary_op` / `binary_op` drivers,
/// supplying one closure per float width (f32 and f64); dtype dispatch
/// happens inside those drivers.
impl ActivationOps<Flex> for Flex {
    /// ReLU: `max(x, 0)`.
    fn relu(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary_op(tensor, |x: f32| x.max(0.0), |x: f64| x.max(0.0))
    }

    /// ReLU gradient: passes `grad` through where the forward *output* was
    /// positive, zero elsewhere.
    fn relu_backward(output: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            output,
            grad,
            |out: f32, g| if out > 0.0 { g } else { 0.0 },
            |out: f64, g| if out > 0.0 { g } else { 0.0 },
            None,
        )
    }

    /// LeakyReLU: `x` for `x >= 0`, `negative_slope * x` otherwise.
    fn leaky_relu(tensor: FloatTensor<Flex>, negative_slope: Scalar) -> FloatTensor<Flex> {
        // Extract the slope once per call, at both precisions, so each
        // closure captures a plain float.
        let ns32 = negative_slope.to_f32().unwrap();
        let ns64 = negative_slope.to_f64().unwrap();
        unary_op(
            tensor,
            move |x: f32| if x >= 0.0 { x } else { ns32 * x },
            move |x: f64| if x >= 0.0 { x } else { ns64 * x },
        )
    }

    /// PReLU: like LeakyReLU, but the negative-side slope comes from the
    /// `alpha` tensor (broadcast by `binary_op`).
    fn prelu(tensor: FloatTensor<Flex>, alpha: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            tensor,
            alpha,
            |x: f32, a| if x >= 0.0 { x } else { a * x },
            |x: f64, a| if x >= 0.0 { x } else { a * x },
            None,
        )
    }

    /// Exact (erf-based) GELU: `0.5 * x * (1 + erf(x / sqrt(2)))`.
    fn gelu(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        use crate::ops::unary::{erf_f32, erf_f64};
        let sqrt2_f32: f32 = core::f32::consts::SQRT_2;
        let sqrt2_f64: f64 = core::f64::consts::SQRT_2;
        unary_op(
            tensor,
            move |x: f32| 0.5 * x * (1.0 + erf_f32(x / sqrt2_f32)),
            move |x: f64| 0.5 * x * (1.0 + erf_f64(x / sqrt2_f64)),
        )
    }

    /// GELU gradient w.r.t. the *input* `x`:
    /// `d/dx gelu(x) = Phi(x) + x * phi(x)`, where `Phi` is the standard
    /// normal CDF (here `0.5 * (1 + erf(x/sqrt(2)))`) and `phi` its PDF.
    fn gelu_backward(x: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        use crate::ops::unary::{erf_f32, erf_f64};
        let sqrt2_f32: f32 = core::f32::consts::SQRT_2;
        let sqrt2_f64: f64 = core::f64::consts::SQRT_2;
        // 1 / sqrt(2*pi): normalization constant of the standard normal PDF.
        let inv_sqrt_2pi_f32: f32 = 1.0 / (2.0 * core::f32::consts::PI).sqrt();
        let inv_sqrt_2pi_f64: f64 = 1.0 / (2.0 * core::f64::consts::PI).sqrt();
        binary_op(
            x,
            grad,
            move |x: f32, g| {
                let cdf = 0.5 * (1.0 + erf_f32(x / sqrt2_f32));
                let pdf = inv_sqrt_2pi_f32 * (-0.5 * x * x).exp();
                g * (cdf + x * pdf)
            },
            move |x: f64, g| {
                let cdf = 0.5 * (1.0 + erf_f64(x / sqrt2_f64));
                let pdf = inv_sqrt_2pi_f64 * (-0.5 * x * x).exp();
                g * (cdf + x * pdf)
            },
            None,
        )
    }

    /// Logistic sigmoid; uses the numerically stable free functions below.
    fn sigmoid(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary_op(tensor, sigmoid_f32, sigmoid_f64)
    }

    /// Sigmoid gradient expressed via the forward *output* `s`:
    /// `ds/dx = s * (1 - s)`.
    fn sigmoid_backward(output: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            output,
            grad,
            |s: f32, g| g * s * (1.0 - s),
            |s: f64, g| g * s * (1.0 - s),
            None,
        )
    }

    /// Hard sigmoid: `clamp(alpha * x + beta, 0, 1)`.
    fn hard_sigmoid(tensor: FloatTensor<Flex>, alpha: Scalar, beta: Scalar) -> FloatTensor<Flex> {
        let alpha32 = alpha.to_f32().unwrap();
        let beta32 = beta.to_f32().unwrap();
        let alpha64 = alpha.to_f64().unwrap();
        let beta64 = beta.to_f64().unwrap();
        unary_op(
            tensor,
            move |x: f32| (alpha32 * x + beta32).clamp(0.0, 1.0),
            move |x: f64| (alpha64 * x + beta64).clamp(0.0, 1.0),
        )
    }

    /// log(sigmoid(x)) in a numerically stable two-branch form: both
    /// branches only exponentiate a non-positive value, so `exp` cannot
    /// overflow for large `|x|`.
    fn log_sigmoid(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary_op(
            tensor,
            |x: f32| {
                if x >= 0.0 {
                    // -ln(1 + e^{-x})
                    -((-x).exp().ln_1p())
                } else {
                    // x - ln(1 + e^{x})
                    x - x.exp().ln_1p()
                }
            },
            |x: f64| {
                if x >= 0.0 {
                    -((-x).exp().ln_1p())
                } else {
                    x - x.exp().ln_1p()
                }
            },
        )
    }

    /// log_sigmoid gradient w.r.t. the input:
    /// `d/dx log(sigmoid(x)) = sigmoid(-x)`.
    fn log_sigmoid_backward(x: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            x,
            grad,
            |x: f32, g| g * sigmoid_f32(-x),
            |x: f64, g| g * sigmoid_f64(-x),
            None,
        )
    }
}
/// Numerically stable logistic function for `f32`.
///
/// The sign branch ensures `exp` is only ever called on a non-positive
/// argument, so no intermediate overflows for large `|x|`. NaN inputs fall
/// into the first branch and propagate as NaN.
#[inline]
fn sigmoid_f32(x: f32) -> f32 {
    if x < 0.0 {
        // e^x <= 1 here: sigma(x) = e^x / (1 + e^x).
        let ex = x.exp();
        return ex / (1.0 + ex);
    }
    // e^{-x} <= 1 here: sigma(x) = 1 / (1 + e^{-x}).
    1.0 / (1.0 + (-x).exp())
}
/// Numerically stable logistic function for `f64`.
///
/// Mirrors `sigmoid_f32`: the branch keeps the `exp` argument non-positive
/// so the computation never overflows; NaN propagates through either branch.
#[inline]
fn sigmoid_f64(x: f64) -> f64 {
    match x < 0.0 {
        // Negative inputs: e^x <= 1, so form e^x / (1 + e^x).
        true => {
            let ex = x.exp();
            ex / (1.0 + ex)
        }
        // Non-negative (and NaN) inputs: e^{-x} <= 1, so form 1 / (1 + e^{-x}).
        false => 1.0 / (1.0 + (-x).exp()),
    }
}
pub fn softmax(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
let rank = tensor.shape().num_dims();
assert!(
dim < rank,
"softmax dim {} out of range for rank {}",
dim,
rank
);
assert!(
dim == rank - 1,
"burn_flex::softmax currently only supports softmax along the last axis \
(got dim={} for rank {}). Permute the tensor or fall back to \
burn_tensor::activation::softmax for other axes.",
dim,
rank
);
let tensor = tensor.to_contiguous();
match tensor.dtype() {
DType::F32 => softmax_last_f32(tensor),
DType::F64 => softmax_last_f64(tensor),
DType::F16 => softmax_last_f16(tensor),
DType::BF16 => softmax_last_bf16(tensor),
dtype => panic!("softmax: unsupported dtype {:?}", dtype),
}
}
/// Softmax over the trailing axis of a contiguous f32 tensor.
///
/// The caller (`softmax`) has already made the tensor contiguous and checked
/// the dtype. Rows are the trailing-axis slices of length `last`.
fn softmax_last_f32(tensor: FlexTensor) -> FlexTensor {
    let shape = tensor.layout().shape().clone();
    let last = *shape.last().expect("softmax: empty shape");
    if last == 0 {
        // Zero-width rows: nothing to normalize, return the input unchanged
        // (also keeps the rayon chunk size below from being zero).
        return tensor;
    }
    let input: &[f32] = tensor.storage();
    let n = input.len();
    let mut output: Vec<f32> = vec![0.0; n];
    let out_slice = output.as_mut_slice();
    #[cfg(feature = "rayon")]
    {
        use rayon::prelude::*;
        // Batch 64 rows per rayon task so scheduling overhead is amortized
        // over many (possibly short) rows.
        const ROWS_PER_TASK: usize = 64;
        let chunk_elems = ROWS_PER_TASK * last;
        out_slice
            .par_chunks_mut(chunk_elems)
            .zip(input.par_chunks(chunk_elems))
            .for_each(|(o, i)| softmax_rows_f32(i, o, last));
    }
    #[cfg(not(feature = "rayon"))]
    {
        // Serial fallback: one pass over all rows.
        softmax_rows_f32(input, out_slice, last);
    }
    FlexTensor::new(
        Bytes::from_elems(output),
        Layout::contiguous(shape),
        DType::F32,
    )
}
/// Row-wise softmax over a batch of contiguous rows of length `row_len`.
///
/// Dispatches to the SIMD kernel when the `simd` feature is enabled,
/// otherwise loops over rows with the scalar kernel.
#[inline]
fn softmax_rows_f32(input: &[f32], output: &mut [f32], row_len: usize) {
    // Hard checks here; the SIMD path below only re-checks in debug builds.
    assert_eq!(input.len(), output.len());
    assert_eq!(input.len() % row_len, 0);
    #[cfg(feature = "simd")]
    softmax_rows_f32_simd(input, output, row_len);
    #[cfg(not(feature = "simd"))]
    {
        for (in_row, out_row) in input.chunks(row_len).zip(output.chunks_mut(row_len)) {
            softmax_row_f32_scalar(in_row, out_row);
        }
    }
}
/// SIMD batch driver: runs the vectorized row kernel over every row.
///
/// NOTE(review): `#[macerator::with_simd]` presumably compiles this for the
/// instruction sets available at runtime — confirm against macerator's docs.
#[cfg(feature = "simd")]
#[macerator::with_simd]
fn softmax_rows_f32_simd<S: macerator::Simd>(input: &[f32], output: &mut [f32], row_len: usize) {
    // Length invariants were already asserted by `softmax_rows_f32`.
    debug_assert_eq!(input.len(), output.len());
    debug_assert_eq!(input.len() % row_len, 0);
    for (in_row, out_row) in input.chunks(row_len).zip(output.chunks_mut(row_len)) {
        softmax_row_f32_simd::<S>(in_row, out_row);
    }
}
/// Scalar single-row softmax (non-SIMD builds).
///
/// Classic three-pass, max-shifted form: subtracting the row maximum keeps
/// every `exp` argument non-positive, avoiding overflow.
#[cfg(not(feature = "simd"))]
#[inline]
fn softmax_row_f32_scalar(input: &[f32], output: &mut [f32]) {
    // Pass 1: row maximum.
    let mut row_max = f32::NEG_INFINITY;
    for &v in input {
        if v > row_max {
            row_max = v;
        }
    }
    // Pass 2: shifted exponentials plus their running sum.
    let mut total = 0.0f32;
    for (dst, &v) in output.iter_mut().zip(input) {
        let e = (v - row_max).exp();
        *dst = e;
        total += e;
    }
    // Pass 3: normalize by the sum (single division, then multiplies).
    let scale = 1.0f32 / total;
    for dst in output.iter_mut() {
        *dst *= scale;
    }
}
/// Vectorized single-row softmax: SIMD max-reduction, scalar exp/sum pass,
/// SIMD scaling pass with a scalar tail.
#[cfg(feature = "simd")]
#[inline(always)]
fn softmax_row_f32_simd<S: macerator::Simd>(input: &[f32], output: &mut [f32]) {
    use macerator::{Scalar, vload_unaligned, vstore_unaligned};
    let lanes = <f32 as Scalar>::lanes::<S>();
    let len = input.len();
    // Largest multiple of `lanes` that fits in the row; the remainder is the
    // scalar tail.
    let simd_len = len / lanes * lanes;
    let (mut max_val, tail_start) = if simd_len >= lanes {
        // SAFETY: simd_len >= lanes, so the first `lanes` elements exist.
        let mut max_vec = unsafe { vload_unaligned::<S, _>(input.as_ptr()) };
        let mut j = lanes;
        while j < simd_len {
            // SAFETY: j + lanes <= simd_len <= len, so the load is in-bounds.
            let v = unsafe { vload_unaligned::<S, _>(input.as_ptr().add(j)) };
            max_vec = max_vec.max(v);
            j += lanes;
        }
        (max_vec.reduce_max(), simd_len)
    } else {
        // Row shorter than one vector: scan everything in the scalar loop.
        (f32::NEG_INFINITY, 0)
    };
    for &x in &input[tail_start..] {
        if x > max_val {
            max_val = x;
        }
    }
    // Scalar exp pass: shifting by the row max keeps every exponent <= 0.
    let mut sum = 0.0f32;
    for idx in 0..len {
        let e = (input[idx] - max_val).exp();
        output[idx] = e;
        sum += e;
    }
    let inv = 1.0f32 / sum;
    let inv_vec = inv.splat::<S>();
    let mut i = 0;
    while i < simd_len {
        // SAFETY: i + lanes <= simd_len <= len for both the load and store,
        // and input/output have equal length (checked by the caller).
        unsafe {
            let v = vload_unaligned::<S, _>(output.as_ptr().add(i));
            vstore_unaligned::<S, _>(output.as_mut_ptr().add(i), v * inv_vec);
        }
        i += lanes;
    }
    // Scalar tail of the scaling pass (i == simd_len at this point).
    for x in &mut output[i..] {
        *x *= inv;
    }
}
/// Generates a trailing-axis softmax entry point for one dtype.
///
/// `$fn_name` takes a contiguous tensor of `$T`, applies `$row_fn` to every
/// row of length `last`, and rebuilds a contiguous tensor tagged `$dtype`.
///
/// The rayon path batches `ROWS_PER_TASK` rows per task — the same scheme as
/// the f32 fast path (`softmax_last_f32`) — instead of the previous
/// one-task-per-row split, which paid scheduling overhead on every (often
/// short) row. Per-row math is unchanged.
macro_rules! softmax_last_dtype {
    ($fn_name:ident, $T:ty, $zero:expr, $dtype:expr, $row_fn:ident) => {
        fn $fn_name(tensor: FlexTensor) -> FlexTensor {
            let shape = tensor.layout().shape().clone();
            let last = *shape.last().expect("softmax: empty shape");
            if last == 0 {
                // Zero-width rows: nothing to normalize.
                return tensor;
            }
            let input: &[$T] = tensor.storage();
            let mut output: Vec<$T> = vec![$zero; input.len()];
            #[cfg(feature = "rayon")]
            {
                use rayon::prelude::*;
                // Amortize task overhead over 64 rows, matching softmax_last_f32.
                const ROWS_PER_TASK: usize = 64;
                let chunk_elems = ROWS_PER_TASK * last;
                output
                    .par_chunks_mut(chunk_elems)
                    .zip(input.par_chunks(chunk_elems))
                    .for_each(|(o, i)| {
                        for (in_row, out_row) in i.chunks(last).zip(o.chunks_mut(last)) {
                            $row_fn(in_row, out_row);
                        }
                    });
            }
            #[cfg(not(feature = "rayon"))]
            {
                for (i, o) in input.chunks(last).zip(output.chunks_mut(last)) {
                    $row_fn(i, o);
                }
            }
            FlexTensor::new(Bytes::from_elems(output), Layout::contiguous(shape), $dtype)
        }
    };
}
/// Generates a scalar row-softmax kernel for a half-precision type
/// (`f16` / `bf16`): values are widened to f32 for the max/exp/sum
/// arithmetic and narrowed back to `$T` on each store.
macro_rules! softmax_row_half {
    ($fn_name:ident, $T:ty) => {
        #[inline]
        fn $fn_name(input: &[$T], output: &mut [$T]) {
            // Row max in f32 for the usual exp-shift stabilization.
            let mut max_val = f32::NEG_INFINITY;
            for &x in input {
                let xf = x.to_f32();
                if xf > max_val {
                    max_val = xf;
                }
            }
            // Shifted exponentials; the f32 value is summed *before* the
            // narrowing round-trip, keeping the denominator full precision.
            let mut sum = 0.0f32;
            for (i, &x) in input.iter().enumerate() {
                let e = (x.to_f32() - max_val).exp();
                output[i] = <$T>::from_f32(e);
                sum += e;
            }
            let inv = 1.0f32 / sum;
            for x in output.iter_mut() {
                *x = <$T>::from_f32(x.to_f32() * inv);
            }
        }
    };
}
/// Scalar single-row softmax for f64 rows.
///
/// Same max-shifted three-pass scheme as the f32 kernels: subtracting the
/// row maximum keeps every `exp` argument non-positive.
#[inline]
fn softmax_row_f64(input: &[f64], output: &mut [f64]) {
    // `f64::max` ignores NaN operands, matching the original `>` scan.
    let row_max = input.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut denom = 0.0f64;
    for (dst, &src) in output.iter_mut().zip(input) {
        let e = (src - row_max).exp();
        *dst = e;
        denom += e;
    }
    let scale = denom.recip();
    for dst in output.iter_mut() {
        *dst *= scale;
    }
}
// Instantiate the half-precision row kernels…
softmax_row_half!(softmax_row_f16, f16);
softmax_row_half!(softmax_row_bf16, bf16);
// …and the per-dtype trailing-axis entry points consumed by `softmax`.
// (f32 has a dedicated hand-written fast path above.)
softmax_last_dtype!(softmax_last_f64, f64, 0.0f64, DType::F64, softmax_row_f64);
softmax_last_dtype!(
    softmax_last_f16,
    f16,
    f16::from_f32(0.0),
    DType::F16,
    softmax_row_f16
);
softmax_last_dtype!(
    softmax_last_bf16,
    bf16,
    bf16::from_f32(0.0),
    DType::BF16,
    softmax_row_bf16
);
/// Layer normalization over the trailing axis.
///
/// `gamma` (scale) must be a 1-D tensor whose length equals the input's last
/// dimension (`d_model`); `beta` (shift) is optional with the same shape
/// requirement. `epsilon` is added to the variance before the square root.
///
/// # Panics
/// - if the input has rank 0,
/// - if `gamma` / `beta` have the wrong rank or length,
/// - for any dtype other than f32 (only the f32 fast path exists).
pub fn layer_norm(
    input: FloatTensor<Flex>,
    gamma: FloatTensor<Flex>,
    beta: Option<FloatTensor<Flex>>,
    epsilon: f64,
) -> FloatTensor<Flex> {
    let rank = input.shape().num_dims();
    assert!(rank >= 1, "layer_norm: input must have at least one dim");
    // The row kernels assume dense, row-major storage.
    let input = input.to_contiguous();
    let gamma = gamma.to_contiguous();
    let beta = beta.map(|b| b.to_contiguous());
    let d_model = *input
        .layout()
        .shape()
        .last()
        .expect("layer_norm: empty shape");
    let gamma_shape = gamma.layout().shape();
    assert!(
        gamma_shape.len() == 1 && gamma_shape[0] == d_model,
        "layer_norm: gamma must be a 1-D tensor of length equal to last dim of input \
        (got shape {:?}, expected [{}])",
        gamma_shape,
        d_model,
    );
    if let Some(ref b) = beta {
        let beta_shape = b.layout().shape();
        assert!(
            beta_shape.len() == 1 && beta_shape[0] == d_model,
            "layer_norm: beta must be a 1-D tensor of length equal to last dim of input \
            (got shape {:?}, expected [{}])",
            beta_shape,
            d_model,
        );
    }
    match input.dtype() {
        // The f32 kernel takes epsilon at f32 precision.
        DType::F32 => layer_norm_f32(input, gamma, beta, epsilon as f32),
        dtype => panic!(
            "burn_flex::layer_norm: unsupported dtype {:?} (only f32 fast path is implemented; \
            cast to f32 or fall back to burn::nn::LayerNorm)",
            dtype
        ),
    }
}
/// f32 layer-norm kernel: per-row (mean, variance) reduction plus affine.
///
/// `input` must already be contiguous with rows of length `d_model`;
/// `gamma` / `beta` were validated by `layer_norm`.
fn layer_norm_f32(
    input: FlexTensor,
    gamma: FlexTensor,
    beta: Option<FlexTensor>,
    epsilon: f32,
) -> FlexTensor {
    let shape = input.layout().shape().clone();
    let d_model = *shape.last().expect("layer_norm: empty shape");
    if d_model == 0 {
        // Zero-width rows: nothing to normalize (also keeps the rayon chunk
        // size below from being zero).
        return input;
    }
    let input_data: &[f32] = input.storage();
    let gamma_data: &[f32] = gamma.storage();
    let beta_data: Option<&[f32]> = beta.as_ref().map(|b| b.storage());
    let n = input_data.len();
    let mut output: Vec<f32> = vec![0.0; n];
    let out_slice = output.as_mut_slice();
    #[cfg(feature = "rayon")]
    {
        use rayon::prelude::*;
        // Batch 64 rows per rayon task to amortize scheduling overhead.
        const ROWS_PER_TASK: usize = 64;
        let chunk_elems = ROWS_PER_TASK * d_model;
        // The beta/no-beta split is hoisted out of the parallel loop so the
        // per-chunk kernels don't re-branch on it.
        match beta_data {
            Some(beta_slice) => {
                out_slice
                    .par_chunks_mut(chunk_elems)
                    .zip(input_data.par_chunks(chunk_elems))
                    .for_each(|(o, i)| {
                        layer_norm_rows_f32_with_beta(
                            i, o, gamma_data, beta_slice, d_model, epsilon,
                        );
                    });
            }
            None => {
                out_slice
                    .par_chunks_mut(chunk_elems)
                    .zip(input_data.par_chunks(chunk_elems))
                    .for_each(|(o, i)| {
                        layer_norm_rows_f32_no_beta(i, o, gamma_data, d_model, epsilon);
                    });
            }
        }
    }
    #[cfg(not(feature = "rayon"))]
    {
        match beta_data {
            Some(beta_slice) => layer_norm_rows_f32_with_beta(
                input_data, out_slice, gamma_data, beta_slice, d_model, epsilon,
            ),
            None => {
                layer_norm_rows_f32_no_beta(input_data, out_slice, gamma_data, d_model, epsilon)
            }
        }
    }
    FlexTensor::new(
        Bytes::from_elems(output),
        Layout::contiguous(shape),
        DType::F32,
    )
}
/// Row-batched layer norm with an additive `beta`.
///
/// The input/output slices must hold a whole number of `d_model`-length
/// rows; dispatches to SIMD or scalar per the `simd` feature.
#[inline]
fn layer_norm_rows_f32_with_beta(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    // Hard checks here; the SIMD driver only re-checks in debug builds.
    assert_eq!(input.len(), output.len());
    assert_eq!(input.len() % d_model, 0);
    assert_eq!(gamma.len(), d_model);
    assert_eq!(beta.len(), d_model);
    #[cfg(feature = "simd")]
    layer_norm_rows_f32_with_beta_simd(input, output, gamma, beta, d_model, epsilon);
    #[cfg(not(feature = "simd"))]
    {
        for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
            layer_norm_row_f32_scalar(in_row, out_row, gamma, Some(beta), epsilon);
        }
    }
}
/// Row-batched layer norm without a shift term (gamma-only affine).
///
/// Same contract as `layer_norm_rows_f32_with_beta`, minus `beta`.
#[inline]
fn layer_norm_rows_f32_no_beta(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    assert_eq!(input.len(), output.len());
    assert_eq!(input.len() % d_model, 0);
    assert_eq!(gamma.len(), d_model);
    #[cfg(feature = "simd")]
    layer_norm_rows_f32_no_beta_simd(input, output, gamma, d_model, epsilon);
    #[cfg(not(feature = "simd"))]
    {
        for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
            layer_norm_row_f32_scalar(in_row, out_row, gamma, None, epsilon);
        }
    }
}
/// Scalar single-row layer norm (non-SIMD builds).
///
/// Mean and variance come from Welford's one-pass algorithm, which is more
/// robust to cancellation than a sum/sum-of-squares reduction. Variance is
/// the population variance (divide by `len`), matching the SIMD kernel.
#[cfg(not(feature = "simd"))]
#[inline]
fn layer_norm_row_f32_scalar(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: Option<&[f32]>,
    epsilon: f32,
) {
    // Welford update: mean_k = mean_{k-1} + (x - mean_{k-1}) / k,
    //                 M2_k   = M2_{k-1} + (x - mean_{k-1}) * (x - mean_k).
    let mut running_mean = 0.0f32;
    let mut m2 = 0.0f32;
    for (k, &x) in input.iter().enumerate() {
        let count = (k + 1) as f32;
        let delta = x - running_mean;
        running_mean += delta / count;
        m2 += delta * (x - running_mean);
    }
    let inv_std = ((m2 / input.len() as f32) + epsilon).sqrt().recip();
    // Affine pass: out = (x - mean) * inv_std * gamma [+ beta].
    match beta {
        Some(shift) => {
            for (((dst, &x), &g), &b) in output.iter_mut().zip(input).zip(gamma).zip(shift) {
                *dst = (x - running_mean) * (inv_std * g) + b;
            }
        }
        None => {
            for ((dst, &x), &g) in output.iter_mut().zip(input).zip(gamma) {
                *dst = (x - running_mean) * (inv_std * g);
            }
        }
    }
}
/// SIMD batch driver for layer norm with `beta`: iterates the vectorized
/// row kernel over every `d_model`-length row.
#[cfg(feature = "simd")]
#[macerator::with_simd]
fn layer_norm_rows_f32_with_beta_simd<S: macerator::Simd>(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    // Callers already enforced these with hard asserts.
    debug_assert_eq!(input.len(), output.len());
    debug_assert_eq!(input.len() % d_model, 0);
    debug_assert_eq!(gamma.len(), d_model);
    debug_assert_eq!(beta.len(), d_model);
    for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
        layer_norm_row_f32_simd::<S>(in_row, out_row, gamma, Some(beta), epsilon);
    }
}
/// SIMD batch driver for layer norm without `beta`.
#[cfg(feature = "simd")]
#[macerator::with_simd]
fn layer_norm_rows_f32_no_beta_simd<S: macerator::Simd>(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    d_model: usize,
    epsilon: f32,
) {
    // Callers already enforced these with hard asserts.
    debug_assert_eq!(input.len(), output.len());
    debug_assert_eq!(input.len() % d_model, 0);
    debug_assert_eq!(gamma.len(), d_model);
    for (in_row, out_row) in input.chunks(d_model).zip(output.chunks_mut(d_model)) {
        layer_norm_row_f32_simd::<S>(in_row, out_row, gamma, None, epsilon);
    }
}
/// Vectorized single-row layer norm.
///
/// Uses the one-pass sum / sum-of-squares reduction (`var = E[x^2] - mean^2`).
/// Cancellation can push `var` slightly negative for near-constant rows; the
/// `+ epsilon` under the sqrt absorbs that. `gamma` / `beta` must have the
/// row's length (checked by the batch drivers).
#[cfg(feature = "simd")]
#[inline(always)]
fn layer_norm_row_f32_simd<S: macerator::Simd>(
    input: &[f32],
    output: &mut [f32],
    gamma: &[f32],
    beta: Option<&[f32]>,
    epsilon: f32,
) {
    use macerator::{Scalar, vload_unaligned, vstore_unaligned};
    let lanes = <f32 as Scalar>::lanes::<S>();
    let len = input.len();
    // Largest multiple of `lanes` within the row; remainder handled scalar.
    let simd_len = len / lanes * lanes;
    let (sum, sumsq) = if simd_len >= lanes {
        let mut acc_sum = 0.0f32.splat::<S>();
        let mut acc_sumsq = 0.0f32.splat::<S>();
        let mut i = 0;
        while i < simd_len {
            // SAFETY: i + lanes <= simd_len <= len, so the load is in-bounds.
            unsafe {
                let v = vload_unaligned::<S, _>(input.as_ptr().add(i));
                acc_sum += v;
                acc_sumsq = v.mul_add(v, acc_sumsq);
            }
            i += lanes;
        }
        let mut s = acc_sum.reduce_add();
        let mut sq = acc_sumsq.reduce_add();
        // Scalar tail of the reduction.
        for &x in &input[simd_len..] {
            s += x;
            sq += x * x;
        }
        (s, sq)
    } else {
        // Row shorter than one vector: plain scalar reduction.
        let mut s = 0.0f32;
        let mut sq = 0.0f32;
        for &x in input {
            s += x;
            sq += x * x;
        }
        (s, sq)
    };
    let n = len as f32;
    let mean = sum / n;
    // Population variance via E[x^2] - mean^2 (see doc comment re: cancellation).
    let var = (sumsq / n) - mean * mean;
    let inv_std = 1.0f32 / (var + epsilon).sqrt();
    let mean_vec = mean.splat::<S>();
    let inv_std_vec = inv_std.splat::<S>();
    let mut i = 0;
    while i < simd_len {
        // SAFETY: i + lanes <= simd_len <= len, and gamma/beta have the row's
        // length per the callers' asserts, so every load/store is in-bounds.
        unsafe {
            let x = vload_unaligned::<S, _>(input.as_ptr().add(i));
            let g = vload_unaligned::<S, _>(gamma.as_ptr().add(i));
            let scale = inv_std_vec * g;
            let centered = x - mean_vec;
            let normed = centered * scale;
            let out = if let Some(b) = beta {
                let b_vec = vload_unaligned::<S, _>(b.as_ptr().add(i));
                normed + b_vec
            } else {
                normed
            };
            vstore_unaligned::<S, _>(output.as_mut_ptr().add(i), out);
        }
        i += lanes;
    }
    // Scalar tail of the normalization (i == simd_len at this point).
    while i < len {
        let centered = input[i] - mean;
        let normed = centered * inv_std * gamma[i];
        output[i] = match beta {
            Some(b) => normed + b[i],
            None => normed,
        };
        i += 1;
    }
}
#[cfg(test)]
mod tests {
use burn_backend::Tolerance;
use burn_tensor::{Tensor, TensorData, activation};
use crate::Flex;
#[test]
fn test_relu() {
let t: Tensor<Flex, 1> =
Tensor::from_data([-2.0f32, -1.0, 0.0, 1.0, 2.0], &Default::default());
activation::relu(t).into_data().assert_approx_eq::<f32>(
&TensorData::from([0.0, 0.0, 0.0, 1.0, 2.0]),
Tolerance::absolute(1e-6),
);
}
#[test]
fn test_sigmoid() {
let t: Tensor<Flex, 1> = Tensor::from_data([-10.0f32, 0.0, 10.0], &Default::default());
activation::sigmoid(t).into_data().assert_approx_eq::<f32>(
&TensorData::from([0.0, 0.5, 1.0]),
Tolerance::absolute(1e-3),
);
}
#[test]
fn test_gelu() {
let t: Tensor<Flex, 1> = Tensor::from_data([-3.0f32, 0.0, 3.0], &Default::default());
activation::gelu(t).into_data().assert_approx_eq::<f32>(
&TensorData::from([0.0, 0.0, 3.0]),
Tolerance::absolute(0.01),
);
}
#[test]
fn test_leaky_relu() {
let t: Tensor<Flex, 1> =
Tensor::from_data([-2.0f32, -1.0, 0.0, 1.0, 2.0], &Default::default());
activation::leaky_relu(t, 0.01)
.into_data()
.assert_approx_eq::<f32>(
&TensorData::from([-0.02, -0.01, 0.0, 1.0, 2.0]),
Tolerance::absolute(1e-6),
);
}
#[test]
fn test_softmax_1d() {
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 1> = Tensor::from_data([1.0f32, 2.0, 3.0], &Default::default());
let primitive = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let result = crate::ops::activation::softmax(primitive, 0);
let result: Tensor<Flex, 1> = Tensor::from_primitive(TensorPrimitive::Float(result));
result.into_data().assert_approx_eq::<f32>(
&TensorData::from([0.09003, 0.24473, 0.66524]),
Tolerance::absolute(1e-4),
);
}
#[test]
fn test_softmax_2d_last_axis() {
use burn_tensor::TensorPrimitive;
let data = [[-1.0f32, 0.0, 1.0, 2.0], [0.5, 0.5, 0.5, 0.5]];
let t: Tensor<Flex, 2> = Tensor::from_data(data, &Default::default());
let reference = activation::softmax(t.clone(), 1);
let primitive = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let fused = crate::ops::activation::softmax(primitive, 1);
let fused: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(fused));
fused
.into_data()
.assert_approx_eq::<f32>(&reference.into_data(), Tolerance::absolute(1e-5));
}
#[test]
fn test_layer_norm_2d_with_beta() {
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 2> = Tensor::from_data(
[[1.0f32, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
&Default::default(),
);
let gamma: Tensor<Flex, 1> =
Tensor::from_data([1.0f32, 1.0, 1.0, 1.0], &Default::default());
let beta: Tensor<Flex, 1> = Tensor::from_data([0.0f32; 4], &Default::default());
let t_prim = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let g_prim = match gamma.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let b_prim = match beta.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let out = crate::ops::activation::layer_norm(t_prim, g_prim, Some(b_prim), 1e-5);
let out: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(out));
let expected = [
[-1.3416408, -0.4472136, 0.4472136, 1.3416408],
[-1.3416408, -0.4472136, 0.4472136, 1.3416408],
];
out.into_data()
.assert_approx_eq::<f32>(&TensorData::from(expected), Tolerance::absolute(1e-4));
}
#[test]
fn test_layer_norm_with_affine() {
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 2> = Tensor::from_data([[1.0f32, 2.0, 3.0, 4.0]], &Default::default());
let gamma: Tensor<Flex, 1> =
Tensor::from_data([2.0f32, 0.5, 1.0, 3.0], &Default::default());
let beta: Tensor<Flex, 1> =
Tensor::from_data([1.0f32, -1.0, 0.0, 2.0], &Default::default());
let t_prim = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let g_prim = match gamma.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let b_prim = match beta.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let out = crate::ops::activation::layer_norm(t_prim, g_prim, Some(b_prim), 1e-5);
let out: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(out));
out.into_data().assert_approx_eq::<f32>(
&TensorData::from([[-1.6833, -1.2236, 0.4472, 6.0249]]),
Tolerance::absolute(1e-3),
);
}
#[test]
fn test_layer_norm_no_beta() {
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 2> = Tensor::from_data([[1.0f32, 2.0, 3.0, 4.0]], &Default::default());
let gamma: Tensor<Flex, 1> =
Tensor::from_data([1.0f32, 1.0, 1.0, 1.0], &Default::default());
let t_prim = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let g_prim = match gamma.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let out = crate::ops::activation::layer_norm(t_prim, g_prim, None, 1e-5);
let out: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(out));
out.into_data().assert_approx_eq::<f32>(
&TensorData::from([[-1.3416408, -0.4472136, 0.4472136, 1.3416408]]),
Tolerance::absolute(1e-4),
);
}
#[test]
fn test_softmax_3d_attention_shape() {
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 3> = Tensor::from_data(
[
[[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]],
[[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
],
&Default::default(),
);
let reference = activation::softmax(t.clone(), 2);
let primitive = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let fused = crate::ops::activation::softmax(primitive, 2);
let fused: Tensor<Flex, 3> = Tensor::from_primitive(TensorPrimitive::Float(fused));
fused
.into_data()
.assert_approx_eq::<f32>(&reference.into_data(), Tolerance::absolute(1e-5));
}
#[test]
fn test_softmax_simd_body_row() {
use burn_tensor::{Tensor, TensorData, TensorPrimitive};
let data: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
let t: Tensor<Flex, 2> =
Tensor::from_data(TensorData::new(data, [1, 32]), &Default::default());
let reference = activation::softmax(t.clone(), 1);
let primitive = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let fused = crate::ops::activation::softmax(primitive, 1);
let fused: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(fused));
fused
.into_data()
.assert_approx_eq::<f32>(&reference.into_data(), Tolerance::absolute(1e-5));
}
#[test]
fn test_softmax_multi_chunk_rayon() {
use burn_tensor::{Tensor, TensorData, TensorPrimitive};
let data: Vec<f32> = (0..100 * 16).map(|i| ((i % 17) as f32) * 0.05).collect();
let t: Tensor<Flex, 2> =
Tensor::from_data(TensorData::new(data, [100, 16]), &Default::default());
let reference = activation::softmax(t.clone(), 1);
let primitive = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let fused = crate::ops::activation::softmax(primitive, 1);
let fused: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(fused));
fused
.into_data()
.assert_approx_eq::<f32>(&reference.into_data(), Tolerance::absolute(1e-5));
}
#[test]
fn test_softmax_f64() {
use burn_tensor::{Tensor, TensorData, TensorPrimitive};
let data = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let t: Tensor<Flex, 2> = Tensor::from_data(
TensorData::new(data.to_vec(), [2, 4]),
(&Default::default(), burn_backend::DType::F64),
);
let reference = activation::softmax(t.clone(), 1);
let primitive = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let fused = crate::ops::activation::softmax(primitive, 1);
let fused: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(fused));
fused
.into_data()
.assert_approx_eq::<f64>(&reference.into_data(), Tolerance::absolute(1e-10));
}
#[test]
fn test_softmax_f16() {
use burn_std::f16;
use burn_tensor::{Tensor, TensorData, TensorPrimitive};
let data: Vec<f16> = [1.0f32, 2.0, 3.0, 4.0, 0.5, 0.5, 0.5, 0.5]
.iter()
.map(|&x| f16::from_f32(x))
.collect();
let t: Tensor<Flex, 2> = Tensor::from_data(
TensorData::new(data, [2, 4]),
(&Default::default(), burn_backend::DType::F16),
);
let reference = activation::softmax(t.clone(), 1);
let primitive = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let fused = crate::ops::activation::softmax(primitive, 1);
let fused: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(fused));
fused
.into_data()
.assert_approx_eq::<f16>(&reference.into_data(), Tolerance::absolute(1e-2));
}
#[test]
fn test_softmax_bf16() {
use burn_std::bf16;
use burn_tensor::{Tensor, TensorData, TensorPrimitive};
let data: Vec<bf16> = [1.0f32, 2.0, 3.0, 4.0, 0.5, 0.5, 0.5, 0.5]
.iter()
.map(|&x| bf16::from_f32(x))
.collect();
let t: Tensor<Flex, 2> = Tensor::from_data(
TensorData::new(data, [2, 4]),
(&Default::default(), burn_backend::DType::BF16),
);
let reference = activation::softmax(t.clone(), 1);
let primitive = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let fused = crate::ops::activation::softmax(primitive, 1);
let fused: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(fused));
fused
.into_data()
.assert_approx_eq::<bf16>(&reference.into_data(), Tolerance::absolute(5e-2));
}
#[test]
fn test_layer_norm_multi_chunk_rayon() {
use burn_tensor::TensorPrimitive;
let data: Vec<f32> = (0..128 * 16).map(|i| ((i % 19) as f32) * 0.03).collect();
let t: Tensor<Flex, 2> =
Tensor::from_data(TensorData::new(data, [128, 16]), &Default::default());
let gamma: Tensor<Flex, 1> = Tensor::from_data([1.0f32; 16], &Default::default());
let beta: Tensor<Flex, 1> = Tensor::from_data([0.0f32; 16], &Default::default());
let rows_in = t.clone().into_data().to_vec::<f32>().unwrap();
let mut expected = vec![0.0f32; rows_in.len()];
let eps = 1e-5f32;
for (in_row, out_row) in rows_in.chunks(16).zip(expected.chunks_mut(16)) {
let mean: f32 = in_row.iter().sum::<f32>() / 16.0;
let var: f32 = in_row.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / 16.0;
let inv_std = 1.0 / (var + eps).sqrt();
for (i, &x) in in_row.iter().enumerate() {
out_row[i] = (x - mean) * inv_std;
}
}
let t_prim = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let g_prim = match gamma.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let b_prim = match beta.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let fused = crate::ops::activation::layer_norm(t_prim, g_prim, Some(b_prim), 1e-5);
let fused: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(fused));
fused.into_data().assert_approx_eq::<f32>(
&TensorData::new(expected, [128, 16]),
Tolerance::absolute(1e-4),
);
}
#[test]
fn test_softmax_non_contiguous_input() {
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 2> = Tensor::from_data(
[
[1.0f32, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
],
&Default::default(),
);
let t_transposed = t.transpose();
let reference = activation::softmax(t_transposed.clone(), 1);
let primitive = match t_transposed.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let fused = crate::ops::activation::softmax(primitive, 1);
let fused: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(fused));
fused
.into_data()
.assert_approx_eq::<f32>(&reference.into_data(), Tolerance::absolute(1e-5));
}
#[test]
fn test_softmax_empty_last_dim_returns_input() {
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 2> = Tensor::from_data(
TensorData::new(Vec::<f32>::new(), [2, 0]),
&Default::default(),
);
let shape_before = t.shape();
let primitive = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let result = crate::ops::activation::softmax(primitive, 1);
let result: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(result));
assert_eq!(result.shape(), shape_before);
}
#[test]
fn test_layer_norm_empty_last_dim_returns_input() {
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 2> = Tensor::from_data(
TensorData::new(Vec::<f32>::new(), [3, 0]),
&Default::default(),
);
let gamma: Tensor<Flex, 1> =
Tensor::from_data(TensorData::new(Vec::<f32>::new(), [0]), &Default::default());
let beta: Tensor<Flex, 1> =
Tensor::from_data(TensorData::new(Vec::<f32>::new(), [0]), &Default::default());
let shape_before = t.shape();
let t_p = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let g_p = match gamma.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let b_p = match beta.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let result = crate::ops::activation::layer_norm(t_p, g_p, Some(b_p), 1e-5);
let result: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(result));
assert_eq!(result.shape(), shape_before);
}
#[test]
#[should_panic(expected = "gamma must be a 1-D tensor")]
fn test_layer_norm_gamma_length_mismatch_panics() {
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 2> = Tensor::from_data([[1.0f32, 2.0, 3.0, 4.0]], &Default::default());
let gamma: Tensor<Flex, 1> = Tensor::from_data([1.0f32, 1.0, 1.0], &Default::default());
let t_p = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let g_p = match gamma.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let _ = crate::ops::activation::layer_norm(t_p, g_p, None, 1e-5);
}
#[test]
#[should_panic(expected = "beta must be a 1-D tensor")]
fn test_layer_norm_beta_length_mismatch_panics() {
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 2> = Tensor::from_data([[1.0f32, 2.0, 3.0, 4.0]], &Default::default());
let gamma: Tensor<Flex, 1> =
Tensor::from_data([1.0f32, 1.0, 1.0, 1.0], &Default::default());
let beta: Tensor<Flex, 1> = Tensor::from_data([0.0f32, 0.0, 0.0], &Default::default());
let t_p = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let g_p = match gamma.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let b_p = match beta.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let _ = crate::ops::activation::layer_norm(t_p, g_p, Some(b_p), 1e-5);
}
#[test]
#[should_panic(expected = "gamma must be a 1-D tensor")]
fn test_layer_norm_gamma_rank_mismatch_panics() {
use burn_backend::DType;
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 2> =
Tensor::from_data([[1.0f32, 2.0, 3.0, 4.0]], (&Default::default(), DType::F32));
let gamma: Tensor<Flex, 2> = Tensor::from_data(
[[1.0f32, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]],
(&Default::default(), DType::F32),
);
let t_p = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let g_p = match gamma.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let _ = crate::ops::activation::layer_norm(t_p, g_p, None, 1e-5);
}
#[test]
#[should_panic(expected = "only supports softmax along the last axis")]
fn test_softmax_non_last_axis_panics() {
use burn_tensor::TensorPrimitive;
let t: Tensor<Flex, 2> =
Tensor::from_data([[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]], &Default::default());
let primitive = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let _ = crate::ops::activation::softmax(primitive, 0);
}
#[test]
fn test_softmax_simd_body_plus_scalar_tail() {
use burn_backend::DType;
use burn_tensor::TensorPrimitive;
let data: Vec<f32> = (0..34).map(|i| (i as f32 * 0.137) - 2.3).collect();
let t: Tensor<Flex, 2> = Tensor::from_data(
TensorData::new(data, [2, 17]),
(&Default::default(), DType::F32),
);
let reference = activation::softmax(t.clone(), 1);
let primitive = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let fused = crate::ops::activation::softmax(primitive, 1);
let fused: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(fused));
fused
.into_data()
.assert_approx_eq::<f32>(&reference.into_data(), Tolerance::absolute(1e-5));
}
#[test]
fn test_layer_norm_simd_body_plus_scalar_tail() {
use burn_backend::DType;
use burn_tensor::TensorPrimitive;
let data: Vec<f32> = (0..34).map(|i| (i as f32 * 0.137) - 2.3).collect();
let t: Tensor<Flex, 2> = Tensor::from_data(
TensorData::new(data, [2, 17]),
(&Default::default(), DType::F32),
);
let gamma_data: Vec<f32> = (0..17).map(|i| 1.0 + i as f32 * 0.05).collect();
let beta_data: Vec<f32> = (0..17).map(|i| i as f32 * 0.01).collect();
let gamma: Tensor<Flex, 1> = Tensor::from_data(
TensorData::new(gamma_data, [17]),
(&Default::default(), DType::F32),
);
let beta: Tensor<Flex, 1> = Tensor::from_data(
TensorData::new(beta_data, [17]),
(&Default::default(), DType::F32),
);
let mean = t.clone().mean_dim(1);
let centered = t.clone() - mean;
let var = centered.clone().powi_scalar(2).mean_dim(1);
let eps = 1e-5f32;
let normed = centered / (var + eps).sqrt();
let reference = normed * gamma.clone().unsqueeze::<2>() + beta.clone().unsqueeze::<2>();
let t_prim = match t.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let g_prim = match gamma.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let b_prim = match beta.into_primitive() {
TensorPrimitive::Float(x) => x,
_ => unreachable!(),
};
let fused = crate::ops::activation::layer_norm(t_prim, g_prim, Some(b_prim), 1e-5);
let fused: Tensor<Flex, 2> = Tensor::from_primitive(TensorPrimitive::Float(fused));
fused
.into_data()
.assert_approx_eq::<f32>(&reference.into_data(), Tolerance::absolute(1e-5));
}
#[test]
fn test_log_sigmoid() {
let t: Tensor<Flex, 1> = Tensor::from_data([-10.0f32, 0.0, 10.0], &Default::default());
activation::log_sigmoid(t)
.into_data()
.assert_approx_eq::<f32>(
&TensorData::from([-10.0, -0.6931472, 0.0]),
Tolerance::absolute(1e-3),
);
}
}