use torsh_core::{dtype::FloatElement, Result as TorshResult};
use torsh_tensor::Tensor;
/// Rectified linear unit, `max(0, x)`, applied element-wise.
///
/// When `inplace` is `false` the tensor's own `relu` implementation is used;
/// otherwise the shared element-wise in-place helper clamps negatives to zero.
pub fn relu<T: FloatElement>(input: &Tensor<T>, inplace: bool) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd,
{
    if !inplace {
        return input.relu();
    }
    use crate::activations::apply_elementwise_inplace;
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    apply_elementwise_inplace(input, inplace, |x| if x < zero { zero } else { x })
}
pub fn leaky_relu<T: FloatElement>(
input: &Tensor<T>,
negative_slope: f64,
inplace: bool,
) -> TorshResult<Tensor<T>>
where
T: Copy + PartialOrd + From<f32>,
{
if inplace {
let data = input.data()?;
let slope = <T as From<f32>>::from(negative_slope as f32);
let zero = <T as torsh_core::dtype::TensorElement>::zero();
let result_data: Vec<T> = data
.iter()
.map(|&x| if x < zero { x * slope } else { x })
.collect();
Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
} else {
input.leaky_relu(<T as From<f32>>::from(negative_slope as f32))
}
}
/// Exponential linear unit: `x` for `x > 0`, `alpha * (exp(x) - 1)` otherwise.
///
/// NOTE(review): `_inplace` is accepted for API parity but ignored — a new
/// tensor is always allocated from host data. `alpha` is narrowed through
/// `f32` before conversion into `T` (the only conversion the bound provides).
pub fn elu<T: FloatElement>(input: &Tensor<T>, alpha: f64, _inplace: bool) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32>,
{
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    let one = <T as torsh_core::dtype::TensorElement>::one();
    let a = <T as From<f32>>::from(alpha as f32);
    let values = input.data()?;
    let mut mapped: Vec<T> = Vec::with_capacity(values.len());
    for &x in values.iter() {
        mapped.push(if x > zero { x } else { a * (x.exp() - one) });
    }
    Tensor::from_data(mapped, input.shape().dims().to_vec(), input.device())
}
/// Scaled exponential linear unit: `scale * elu(x, alpha)` with the fixed
/// self-normalizing constants from Klambauer et al., 2017.
///
/// NOTE(review): `_inplace` is accepted for API parity but ignored.
pub fn selu<T: FloatElement>(input: &Tensor<T>, _inplace: bool) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32> + Default,
{
    // SELU paper constants (alpha, lambda); fixed by the self-normalizing
    // derivation, never user-tunable.
    const ALPHA: f64 = 1.673_263_242_354_377_2;
    const SCALE: f64 = 1.050_700_987_355_480_5;
    let elu_result = elu(input, ALPHA, false)?;
    // The cast target is the tensor element type `T`, not `f32` — the old
    // expect message ("castable to f32") described the wrong conversion.
    elu_result.mul_scalar(
        num_traits::cast(SCALE).expect("SELU scale constant should be representable in the tensor element type"),
    )
}
/// ReLU6: clamps every element to the interval `[0, 6]`.
///
/// Routed through the shared element-wise helper; `inplace` selects whether
/// the helper mutates in place or allocates.
pub fn relu6<T: FloatElement>(input: &Tensor<T>, inplace: bool) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32>,
{
    use crate::activations::apply_elementwise_inplace;
    let lo = <T as torsh_core::dtype::TensorElement>::zero();
    let hi = <T as From<f32>>::from(6.0);
    apply_elementwise_inplace(input, inplace, |x| {
        if x < lo {
            lo
        } else if hi < x {
            hi
        } else {
            x
        }
    })
}
/// Parametric ReLU: `x` where `x > 0`, otherwise `weight * x`.
///
/// NOTE(review): assumes `where_tensor(mask, other)` selects `self` where the
/// mask is true and `other` elsewhere — confirm against the tensor crate.
/// Also presumably relies on `mul` broadcasting `weight` against `input`
/// (e.g. a per-channel slope); verify the broadcast rules before relying on
/// non-matching shapes.
pub fn prelu<T: FloatElement>(input: &Tensor<T>, weight: &Tensor<T>) -> TorshResult<Tensor<T>>
where
T: Copy + PartialOrd,
{
// Mask of strictly positive elements.
let zero = Tensor::zeros_like(input)?;
let positive_mask = input.gt(&zero)?;
// `weight * x` supplies the value used wherever the mask is false.
let negative_part = input.mul(weight)?;
input.where_tensor(&positive_mask, &negative_part)
}
/// CELU: `x` for `x > 0`, `alpha * (exp(x / alpha) - 1)` otherwise.
///
/// NOTE(review): `alpha` must be non-zero — the negative branch divides by it.
/// As elsewhere in this file, `alpha` is narrowed through `f32` on conversion.
pub fn celu<T: FloatElement>(input: &Tensor<T>, alpha: f64) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32>,
{
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    let one = <T as torsh_core::dtype::TensorElement>::one();
    let a = <T as From<f32>>::from(alpha as f32);
    let values = input.data()?;
    let mut transformed: Vec<T> = Vec::with_capacity(values.len());
    for &x in values.iter() {
        let y = if x > zero { x } else { a * ((x / a).exp() - one) };
        transformed.push(y);
    }
    Tensor::from_data(transformed, input.shape().dims().to_vec(), input.device())
}
/// Randomized leaky ReLU.
///
/// During training the negative slope is drawn uniformly from
/// `[lower, upper)`; during evaluation the deterministic midpoint
/// `(lower + upper) / 2` is used.
///
/// The previous implementation called `gen_range(lower..upper)`
/// unconditionally while training, which panics on an empty range
/// (`lower >= upper`); that case now degenerates to `lower`.
pub fn rrelu<T: FloatElement>(
    input: &Tensor<T>,
    lower: f64,
    upper: f64,
    training: bool,
    inplace: bool,
) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32>,
{
    let slope = if !training {
        (lower + upper) / 2.0
    } else if lower < upper {
        use scirs2_core::random::thread_rng;
        let mut rng = thread_rng();
        rng.gen_range(lower..upper)
    } else {
        // Degenerate range: sampling is impossible; the only consistent
        // choice is the single endpoint.
        lower
    };
    leaky_relu(input, slope, inplace)
}
/// Hard shrinkage: zeroes every element whose magnitude is at most `lambd`,
/// passing larger-magnitude elements through unchanged.
pub fn hardshrink<T: FloatElement>(input: &Tensor<T>, lambd: f64) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32>,
{
    let threshold = <T as From<f32>>::from(lambd as f32);
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    let values = input.data()?;
    let shrunk: Vec<T> = values
        .iter()
        .copied()
        .map(|x| if x.abs() > threshold { x } else { zero })
        .collect();
    Tensor::from_data(shrunk, input.shape().dims().to_vec(), input.device())
}
/// Soft shrinkage: shifts each element toward zero by `lambd` and maps the
/// band `[-lambd, lambd]` to exactly zero.
pub fn softshrink<T: FloatElement>(input: &Tensor<T>, lambd: f64) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32>,
{
    let lam = <T as From<f32>>::from(lambd as f32);
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    let values = input.data()?;
    let mut shrunk: Vec<T> = Vec::with_capacity(values.len());
    for &x in values.iter() {
        if x > lam {
            shrunk.push(x - lam);
        } else if x < -lam {
            shrunk.push(x + lam);
        } else {
            shrunk.push(zero);
        }
    }
    Tensor::from_data(shrunk, input.shape().dims().to_vec(), input.device())
}
/// Replaces every element not strictly greater than `threshold` with `value`.
///
/// NOTE(review): `inplace` is currently not honored — a new tensor is always
/// built from host data. The previous implementation ended in
/// `if inplace { Ok(result) } else { Ok(result) }`, a dead conditional whose
/// two branches were identical; it has been removed.
pub fn threshold<T: FloatElement>(
    input: &Tensor<T>,
    threshold: f64,
    value: f64,
    inplace: bool,
) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32>,
{
    // Accepted for API parity only; see the note above.
    let _ = inplace;
    let cutoff = <T as From<f32>>::from(threshold as f32);
    let replacement = <T as From<f32>>::from(value as f32);
    let data = input.data()?;
    let result_data: Vec<T> = data
        .iter()
        .map(|&x| if x > cutoff { x } else { replacement })
        .collect();
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    /// ReLU zeroes negatives and passes non-negatives through.
    #[test]
    fn test_relu_basic() -> TorshResult<()> {
        let input = from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5], DeviceType::Cpu)?;
        let output = relu(&input, false)?;
        let expected = from_vec(vec![0.0, 0.0, 0.0, 1.0, 2.0], &[5], DeviceType::Cpu)?;
        let actual_values = output.data()?;
        let expected_values = expected.data()?;
        for (i, (&out, &exp)) in actual_values.iter().zip(expected_values.iter()).enumerate() {
            assert!(
                ((out - exp) as f32).abs() < 1e-6,
                "Mismatch at index {}: {} vs {}",
                i,
                out,
                exp
            );
        }
        Ok(())
    }

    /// LeakyReLU scales negatives by the slope and keeps non-negatives.
    #[test]
    fn test_leaky_relu_basic() -> TorshResult<()> {
        let input = from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5], DeviceType::Cpu)?;
        let output = leaky_relu(&input, 0.1, false)?;
        let actual_values = output.data()?;
        let expected = vec![-0.2, -0.1, 0.0, 1.0, 2.0];
        for (i, (&out, &exp)) in actual_values.iter().zip(expected.iter()).enumerate() {
            assert!(
                ((out - exp) as f32).abs() < 1e-6,
                "Mismatch at index {}: {} vs {}",
                i,
                out,
                exp
            );
        }
        Ok(())
    }

    /// ReLU6 clamps to [0, 6] on both sides.
    #[test]
    fn test_relu6_clipping() -> TorshResult<()> {
        let input = from_vec(vec![-1.0, 3.0, 8.0], &[3], DeviceType::Cpu)?;
        let output = relu6(&input, false)?;
        let actual_values = output.data()?;
        let expected = vec![0.0, 3.0, 6.0];
        for (i, (&out, &exp)) in actual_values.iter().zip(expected.iter()).enumerate() {
            assert!(
                ((out - exp) as f32).abs() < 1e-6,
                "Mismatch at index {}: {} vs {}",
                i,
                out,
                exp
            );
        }
        Ok(())
    }

    /// ELU is the identity at 0 and 1, and maps -1 to exp(-1) - 1 ≈ -0.632.
    #[test]
    fn test_elu_properties() -> TorshResult<()> {
        let input = from_vec(vec![-1.0, 0.0, 1.0], &[3], DeviceType::Cpu)?;
        let output = elu(&input, 1.0, false)?;
        let values = output.data()?;
        assert!((values[1] - 0.0_f32).abs() < 1e-6);
        assert!((values[2] - 1.0_f32).abs() < 1e-6);
        assert!((values[0] + 0.632_f32).abs() < 0.01);
        Ok(())
    }
}