use burn_backend::Scalar;
use burn_backend::ops::ActivationOps;
use burn_backend::tensor::FloatTensor;
use num_traits::ToPrimitive;

use crate::Flex;
use crate::ops::binary::binary_op;
use crate::ops::unary::unary_op;

/// Element-wise activation functions for the `Flex` backend.
///
/// Each op dispatches through `unary_op`/`binary_op`, which apply the `f32`
/// or `f64` closure depending on the tensor's element type.
impl ActivationOps<Flex> for Flex {
    fn relu(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary_op(tensor, |x: f32| x.max(0.0), |x: f64| x.max(0.0))
    }

    /// The gradient passes through wherever the forward *output* was
    /// positive; since `relu(x) > 0` exactly when `x > 0`, the saved output
    /// is all the backward pass needs.
    fn relu_backward(output: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            output,
            grad,
            |out: f32, g| if out > 0.0 { g } else { 0.0 },
            |out: f64, g| if out > 0.0 { g } else { 0.0 },
            None,
        )
    }

    /// Identity for non-negative inputs; negative inputs are scaled by
    /// `negative_slope`.
    fn leaky_relu(tensor: FloatTensor<Flex>, negative_slope: Scalar) -> FloatTensor<Flex> {
        let ns32 = negative_slope
            .to_f32()
            .expect("negative_slope must be representable as f32");
        let ns64 = negative_slope
            .to_f64()
            .expect("negative_slope must be representable as f64");
        unary_op(
            tensor,
            move |x: f32| if x >= 0.0 { x } else { ns32 * x },
            move |x: f64| if x >= 0.0 { x } else { ns64 * x },
        )
    }

    /// Like leaky ReLU, except the negative-side slope `alpha` is a learned
    /// tensor combined element-wise with the input.
    fn prelu(tensor: FloatTensor<Flex>, alpha: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            tensor,
            alpha,
            |x: f32, a| if x >= 0.0 { x } else { a * x },
            |x: f64, a| if x >= 0.0 { x } else { a * x },
            None,
        )
    }

    /// Exact (erf-based) GELU: `x * Φ(x)` with
    /// `Φ(x) = 0.5 * (1 + erf(x / √2))`, the standard normal CDF.
    fn gelu(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        use crate::ops::unary::{erf_f32, erf_f64};
        let sqrt2_f32: f32 = core::f32::consts::SQRT_2;
        let sqrt2_f64: f64 = core::f64::consts::SQRT_2;
        unary_op(
            tensor,
            move |x: f32| 0.5 * x * (1.0 + erf_f32(x / sqrt2_f32)),
            move |x: f64| 0.5 * x * (1.0 + erf_f64(x / sqrt2_f64)),
        )
    }

    /// Gradient of the exact GELU: `d/dx [x * Φ(x)] = Φ(x) + x * φ(x)`,
    /// where `φ(x) = exp(-x²/2) / √(2π)` is the standard normal density.
    fn gelu_backward(x: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        use crate::ops::unary::{erf_f32, erf_f64};
        let sqrt2_f32: f32 = core::f32::consts::SQRT_2;
        let sqrt2_f64: f64 = core::f64::consts::SQRT_2;
        let inv_sqrt_2pi_f32: f32 = 1.0 / (2.0 * core::f32::consts::PI).sqrt();
        let inv_sqrt_2pi_f64: f64 = 1.0 / (2.0 * core::f64::consts::PI).sqrt();
        binary_op(
            x,
            grad,
            move |x: f32, g| {
                let cdf = 0.5 * (1.0 + erf_f32(x / sqrt2_f32));
                let pdf = inv_sqrt_2pi_f32 * (-0.5 * x * x).exp();
                g * (cdf + x * pdf)
            },
            move |x: f64, g| {
                let cdf = 0.5 * (1.0 + erf_f64(x / sqrt2_f64));
                let pdf = inv_sqrt_2pi_f64 * (-0.5 * x * x).exp();
                g * (cdf + x * pdf)
            },
            None,
        )
    }

    /// Delegates to the numerically stable scalar helpers below.
    fn sigmoid(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary_op(tensor, sigmoid_f32, sigmoid_f64)
    }

    /// Computed from the saved output: `σ'(x) = σ(x) * (1 - σ(x))`.
    fn sigmoid_backward(output: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            output,
            grad,
            |s: f32, g| g * s * (1.0 - s),
            |s: f64, g| g * s * (1.0 - s),
            None,
        )
    }

    /// Piecewise-linear approximation of the sigmoid:
    /// `clamp(alpha * x + beta, 0, 1)`.
    fn hard_sigmoid(tensor: FloatTensor<Flex>, alpha: Scalar, beta: Scalar) -> FloatTensor<Flex> {
        let alpha32 = alpha.to_f32().expect("alpha must be representable as f32");
        let beta32 = beta.to_f32().expect("beta must be representable as f32");
        let alpha64 = alpha.to_f64().expect("alpha must be representable as f64");
        let beta64 = beta.to_f64().expect("beta must be representable as f64");
        unary_op(
            tensor,
            move |x: f32| (alpha32 * x + beta32).clamp(0.0, 1.0),
            move |x: f64| (alpha64 * x + beta64).clamp(0.0, 1.0),
        )
    }

    /// Branch-wise stable log-sigmoid: `log σ(x) = -ln(1 + e^(-x))` for
    /// `x >= 0` and `x - ln(1 + e^x)` otherwise, so the argument to `exp`
    /// is never positive and `ln_1p` preserves precision near zero.
    fn log_sigmoid(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary_op(
            tensor,
            |x: f32| {
                if x >= 0.0 {
                    -((-x).exp().ln_1p())
                } else {
                    x - x.exp().ln_1p()
                }
            },
            |x: f64| {
                if x >= 0.0 {
                    -((-x).exp().ln_1p())
                } else {
                    x - x.exp().ln_1p()
                }
            },
        )
    }

    /// `d/dx log σ(x) = σ(-x)`, reusing the stable sigmoid helpers.
    fn log_sigmoid_backward(x: FloatTensor<Flex>, grad: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            x,
            grad,
            |x: f32, g| g * sigmoid_f32(-x),
            |x: f64, g| g * sigmoid_f64(-x),
            None,
        )
    }
}

/// Numerically stable logistic sigmoid: for `x >= 0` the textbook form is
/// safe, while for `x < 0` the equivalent `e^x / (1 + e^x)` avoids an
/// overflowing `(-x).exp()` intermediate.
#[inline]
fn sigmoid_f32(x: f32) -> f32 {
    if x >= 0.0 {
        1.0 / (1.0 + (-x).exp())
    } else {
        let e = x.exp();
        e / (1.0 + e)
    }
}

/// See [`sigmoid_f32`]; identical branching for `f64`.
#[inline]
fn sigmoid_f64(x: f64) -> f64 {
    if x >= 0.0 {
        1.0 / (1.0 + (-x).exp())
    } else {
        let e = x.exp();
        e / (1.0 + e)
    }
}

#[cfg(test)]
mod tests {
    use burn_backend::Tolerance;
    use burn_tensor::{Tensor, TensorData, activation};

    use crate::Flex;

    #[test]
    fn test_relu() {
        let t: Tensor<Flex, 1> =
            Tensor::from_data([-2.0f32, -1.0, 0.0, 1.0, 2.0], &Default::default());
        activation::relu(t).into_data().assert_approx_eq::<f32>(
            &TensorData::from([0.0, 0.0, 0.0, 1.0, 2.0]),
            Tolerance::absolute(1e-6),
        );
    }

    #[test]
    fn test_sigmoid() {
        let t: Tensor<Flex, 1> = Tensor::from_data([-10.0f32, 0.0, 10.0], &Default::default());
        activation::sigmoid(t).into_data().assert_approx_eq::<f32>(
            &TensorData::from([0.0, 0.5, 1.0]),
            Tolerance::absolute(1e-3),
        );
    }
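
    // A sketch of an extra check (not in the original suite) exercising the
    // two-branch sigmoid at extreme inputs, where a naive
    // `1.0 / (1.0 + (-x).exp())` would overflow its `exp` intermediate to
    // infinity for large negative `x`. Uses the same API as the tests above.
    #[test]
    fn test_sigmoid_extreme_inputs_stay_finite() {
        let t: Tensor<Flex, 1> = Tensor::from_data([-100.0f32, 100.0], &Default::default());
        activation::sigmoid(t).into_data().assert_approx_eq::<f32>(
            &TensorData::from([0.0, 1.0]),
            Tolerance::absolute(1e-6),
        );
    }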

    #[test]
    fn test_gelu() {
        let t: Tensor<Flex, 1> = Tensor::from_data([-3.0f32, 0.0, 3.0], &Default::default());
        activation::gelu(t).into_data().assert_approx_eq::<f32>(
            &TensorData::from([0.0, 0.0, 3.0]),
            Tolerance::absolute(0.01),
        );
    }
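
    // A sketch of a finite-difference check on the scalar math behind
    // `gelu_backward` (assuming `crate::ops::unary::erf_f64` is reasonably
    // accurate and smooth): the analytic gradient `Φ(x) + x·φ(x)` should
    // agree with a central difference of the forward formula to within a
    // deliberately loose tolerance.
    #[test]
    fn test_gelu_backward_matches_finite_difference() {
        use crate::ops::unary::erf_f64;
        let sqrt2 = core::f64::consts::SQRT_2;
        let inv_sqrt_2pi = 1.0 / (2.0 * core::f64::consts::PI).sqrt();
        let gelu = |x: f64| 0.5 * x * (1.0 + erf_f64(x / sqrt2));
        for &x in &[-2.0f64, -0.5, 0.0, 0.5, 2.0] {
            let h = 1e-4;
            let numeric = (gelu(x + h) - gelu(x - h)) / (2.0 * h);
            let cdf = 0.5 * (1.0 + erf_f64(x / sqrt2));
            let pdf = inv_sqrt_2pi * (-0.5 * x * x).exp();
            let analytic = cdf + x * pdf;
            assert!(
                (numeric - analytic).abs() < 1e-2,
                "gelu'({x}) mismatch: numeric {numeric}, analytic {analytic}"
            );
        }
    }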

    #[test]
    fn test_leaky_relu() {
        let t: Tensor<Flex, 1> =
            Tensor::from_data([-2.0f32, -1.0, 0.0, 1.0, 2.0], &Default::default());
        activation::leaky_relu(t, 0.01)
            .into_data()
            .assert_approx_eq::<f32>(
                &TensorData::from([-0.02, -0.01, 0.0, 1.0, 2.0]),
                Tolerance::absolute(1e-6),
            );
    }
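
    // A sketch of a PReLU check, assuming `activation::prelu` accepts a
    // single-element alpha tensor and broadcasts it across a 1-D input
    // (mirroring PyTorch's single-weight `PReLU`): with alpha = 0.25 the
    // negative side is scaled by 0.25 and the positive side passes through.
    #[test]
    fn test_prelu() {
        let t: Tensor<Flex, 1> =
            Tensor::from_data([-2.0f32, -1.0, 0.0, 1.0, 2.0], &Default::default());
        let alpha: Tensor<Flex, 1> = Tensor::from_data([0.25f32], &Default::default());
        activation::prelu(t, alpha).into_data().assert_approx_eq::<f32>(
            &TensorData::from([-0.5, -0.25, 0.0, 1.0, 2.0]),
            Tolerance::absolute(1e-6),
        );
    }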

    #[test]
    fn test_log_sigmoid() {
        let t: Tensor<Flex, 1> = Tensor::from_data([-10.0f32, 0.0, 10.0], &Default::default());
        activation::log_sigmoid(t)
            .into_data()
            .assert_approx_eq::<f32>(
                &TensorData::from([-10.0, -0.6931472, 0.0]),
                Tolerance::absolute(1e-3),
            );
    }
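
    // A sketch of a hard-sigmoid check, assuming `activation::hard_sigmoid`
    // forwards plain scalar `alpha`/`beta` arguments to this impl: with the
    // common alpha = 0.2, beta = 0.5, the output is `0.2 * x + 0.5` clamped
    // to [0, 1], so it saturates at x = ±2.5.
    #[test]
    fn test_hard_sigmoid() {
        let t: Tensor<Flex, 1> =
            Tensor::from_data([-5.0f32, -1.0, 0.0, 1.0, 5.0], &Default::default());
        activation::hard_sigmoid(t, 0.2, 0.5)
            .into_data()
            .assert_approx_eq::<f32>(
                &TensorData::from([0.0, 0.3, 0.5, 0.7, 1.0]),
                Tolerance::absolute(1e-6),
            );
    }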
}