use torsh_core::{dtype::FloatElement, Result as TorshResult};
use torsh_tensor::Tensor;
/// Applies `operation` to every element of `input` and writes the result back
/// into `input`, preserving shape and device.
///
/// NOTE(review): despite the `_inplace` naming convention of the callers, this
/// currently materialises a fresh buffer and rebuilds the tensor via
/// `Tensor::from_data`; true in-place mutation would need a mutable-data
/// accessor on `Tensor` — confirm whether one exists.
#[inline]
fn apply_inplace_elementwise<T, F>(input: &mut Tensor<T>, operation: F) -> TorshResult<()>
where
    T: FloatElement + Copy,
    F: Fn(T) -> T,
{
    let source = input.data()?;
    // Pre-size the output buffer; same element count as the source.
    let mut transformed = Vec::with_capacity(source.len());
    for &value in source.iter() {
        transformed.push(operation(value));
    }
    let dims = input.shape().dims().to_vec();
    *input = Tensor::from_data(transformed, dims, input.device())?;
    Ok(())
}
/// In-place ReLU: negative elements become zero, all others are unchanged.
///
/// The test is deliberately `x < zero` (not `x >= zero`) so that values which
/// compare as neither (e.g. NaN under `PartialOrd`) pass through untouched.
pub fn relu_<T: FloatElement>(input: &mut Tensor<T>) -> TorshResult<()>
where
    T: Copy + PartialOrd + torsh_core::dtype::TensorElement + Default,
{
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    apply_inplace_elementwise(input, |x| match x < zero {
        true => zero,
        false => x,
    })
}
/// In-place logistic sigmoid.
///
/// Branches on the sign of `x` so the argument passed to `exp` is always
/// non-positive: `exp` may underflow to zero (yielding the correct 0/1
/// saturation) but can never overflow.
pub fn sigmoid_<T: FloatElement>(input: &mut Tensor<T>) -> TorshResult<()>
where
    T: Copy + PartialOrd + torsh_core::dtype::TensorElement + Default,
{
    let one = <T as torsh_core::dtype::TensorElement>::one();
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    apply_inplace_elementwise(input, |x| match x >= zero {
        // sigmoid(x) = 1 / (1 + e^{-x}); e^{-x} <= 1 here.
        true => one / (one + (-x).exp()),
        // sigmoid(x) = e^{x} / (1 + e^{x}); e^{x} < 1 here.
        false => {
            let ex = x.exp();
            ex / (one + ex)
        }
    })
}
/// In-place hyperbolic tangent.
///
/// Uses the numerically stable identity on each half-line so the argument
/// passed to `exp` is always non-positive: `exp` may underflow to zero
/// (yielding the correct ±1 saturation) but can never overflow.
///
/// BUG FIX: the previous pairing evaluated `exp(2x)` for x >= 0 and
/// `exp(-2x)` for x < 0 — i.e. `exp(2|x|)` — which overflows to infinity for
/// large |x| (f32 overflows near |x| ≈ 44) and produces `inf / inf = NaN`
/// instead of ±1. The branches below swap the exponent signs to the stable
/// forms.
pub fn tanh_<T: FloatElement>(input: &mut Tensor<T>) -> TorshResult<()>
where
    T: Copy + PartialOrd + torsh_core::dtype::TensorElement + Default,
{
    let one = <T as torsh_core::dtype::TensorElement>::one();
    let two = one + one;
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    apply_inplace_elementwise(input, |x| {
        if x >= zero {
            // tanh(x) = (1 - e^{-2x}) / (1 + e^{-2x}); e^{-2x} ∈ (0, 1].
            let exp_neg_2x = (-two * x).exp();
            (one - exp_neg_2x) / (one + exp_neg_2x)
        } else {
            // tanh(x) = (e^{2x} - 1) / (e^{2x} + 1); e^{2x} ∈ (0, 1).
            let exp_2x = (two * x).exp();
            (exp_2x - one) / (exp_2x + one)
        }
    })
}
/// In-place LeakyReLU: identity for x >= 0, `negative_slope * x` otherwise.
///
/// NOTE(review): `negative_slope` is narrowed f64 -> f32 before the
/// `From<f32>` conversion, so slopes needing more than f32 precision are
/// rounded; dictated by the `From<f32>` bound on `T`.
pub fn leaky_relu_<T: FloatElement>(input: &mut Tensor<T>, negative_slope: f64) -> TorshResult<()>
where
    T: Copy + PartialOrd + From<f32> + torsh_core::dtype::TensorElement,
{
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    let slope: T = (negative_slope as f32).into();
    apply_inplace_elementwise(input, |x| match x >= zero {
        true => x,
        false => x * slope,
    })
}
/// In-place GELU using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
pub fn gelu_<T: FloatElement>(input: &mut Tensor<T>) -> TorshResult<()>
where
    T: Copy + PartialOrd + From<f32> + torsh_core::dtype::TensorElement,
{
    let half = <T as From<f32>>::from(0.5);
    let one = <T as torsh_core::dtype::TensorElement>::one();
    // sqrt(2 / pi) truncated to f32 precision.
    let sqrt_2_over_pi = <T as From<f32>>::from(0.797884561);
    let coeff = <T as From<f32>>::from(0.044715);
    apply_inplace_elementwise(input, |x| {
        let cubic_term = x + coeff * (x * x * x);
        let gate = (sqrt_2_over_pi * cubic_term).tanh();
        half * x * (one + gate)
    })
}
/// In-place SiLU (a.k.a. swish): `x * sigmoid(x)`.
///
/// The sigmoid gate is computed with the overflow-safe branch for each sign
/// of `x`, mirroring `sigmoid_`: the `exp` argument is always non-positive.
pub fn silu_<T: FloatElement>(input: &mut Tensor<T>) -> TorshResult<()>
where
    T: Copy + PartialOrd + torsh_core::dtype::TensorElement,
{
    let one = <T as torsh_core::dtype::TensorElement>::one();
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    apply_inplace_elementwise(input, |x| {
        let gate = match x >= zero {
            true => one / (one + (-x).exp()),
            false => {
                let ex = x.exp();
                ex / (one + ex)
            }
        };
        x * gate
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    /// ReLU maps each element to max(x, 0); checked against exact expected
    /// values within a 1e-6 tolerance.
    #[test]
    fn test_in_place_relu() -> TorshResult<()> {
        let mut input = from_vec(vec![-2.0f32, -1.0, 0.0, 1.0, 2.0], &[5], DeviceType::Cpu)?;
        let expected = vec![0.0, 0.0, 0.0, 1.0, 2.0];
        relu_(&mut input)?;
        let data = input.data()?;
        for (i, (&actual, &expected)) in data.iter().zip(expected.iter()).enumerate() {
            assert!(
                (actual - expected).abs() < 1e-6,
                "ReLU mismatch at index {}: {} vs {}",
                i,
                actual,
                expected
            );
        }
        Ok(())
    }

    /// Sigmoid outputs lie strictly in (0, 1), sigmoid(0) == 0.5, and the
    /// function is monotonically increasing across the three sample points.
    #[test]
    fn test_in_place_sigmoid() -> TorshResult<()> {
        let mut input = from_vec(vec![-2.0f32, 0.0, 2.0], &[3], DeviceType::Cpu)?;
        sigmoid_(&mut input)?;
        let data = input.data()?;
        for &val in data.iter() {
            assert!(
                val > 0.0 && val < 1.0,
                "Sigmoid output {} not in (0,1)",
                val
            );
        }
        assert!(
            (data[1] - 0.5).abs() < 1e-6,
            "sigmoid(0) should be 0.5, got {}",
            data[1]
        );
        assert!(data[0] < data[1]);
        assert!(data[1] < data[2]);
        Ok(())
    }

    /// Tanh outputs lie strictly in (-1, 1), tanh(0) == 0, monotonicity
    /// holds, and |tanh(±2)| > 0.95 (true value ≈ 0.9640).
    #[test]
    fn test_in_place_tanh() -> TorshResult<()> {
        let mut input = from_vec(vec![-2.0f32, 0.0, 2.0], &[3], DeviceType::Cpu)?;
        tanh_(&mut input)?;
        let data = input.data()?;
        for &val in data.iter() {
            assert!(val > -1.0 && val < 1.0, "Tanh output {} not in (-1,1)", val);
        }
        assert!(
            (data[1] - 0.0).abs() < 1e-6,
            "tanh(0) should be 0, got {}",
            data[1]
        );
        assert!(data[0] < data[1]);
        assert!(data[1] < data[2]);
        // Near-saturation checks at x = ±2.
        assert!(data[0] < -0.95); assert!(data[2] > 0.95);
        Ok(())
    }

    /// LeakyReLU with slope 0.1: negatives are scaled by the slope,
    /// non-negatives are unchanged; exact values within 1e-6.
    #[test]
    fn test_in_place_leaky_relu() -> TorshResult<()> {
        let mut input = from_vec(vec![-2.0f32, -1.0, 0.0, 1.0, 2.0], &[5], DeviceType::Cpu)?;
        let negative_slope = 0.1;
        leaky_relu_(&mut input, negative_slope)?;
        let data = input.data()?;
        let expected = vec![-0.2, -0.1, 0.0, 1.0, 2.0];
        for (i, (&actual, &expected)) in data.iter().zip(expected.iter()).enumerate() {
            assert!(
                (actual - expected).abs() < 1e-6,
                "LeakyReLU mismatch at index {}: {} vs {}",
                i,
                actual,
                expected
            );
        }
        Ok(())
    }

    /// GELU sign/range sanity: GELU(0) == 0, GELU(1) > 0, and GELU(-1) is a
    /// small negative value in (-0.5, 0).
    #[test]
    fn test_in_place_gelu() -> TorshResult<()> {
        let mut input = from_vec(vec![-1.0f32, 0.0, 1.0], &[3], DeviceType::Cpu)?;
        gelu_(&mut input)?;
        let data = input.data()?;
        assert!(
            (data[1] - 0.0).abs() < 1e-6,
            "GELU(0) should be 0, got {}",
            data[1]
        );
        assert!(data[2] > 0.0, "GELU(1) should be positive, got {}", data[2]);
        assert!(
            data[0] < 0.0 && data[0] > -0.5,
            "GELU(-1) should be small negative, got {}",
            data[0]
        );
        Ok(())
    }

    /// SiLU sign/range sanity: SiLU(0) == 0, SiLU(1) > 0, and SiLU(-1) is a
    /// small negative value in (-1, 0).
    #[test]
    fn test_in_place_silu() -> TorshResult<()> {
        let mut input = from_vec(vec![-1.0f32, 0.0, 1.0], &[3], DeviceType::Cpu)?;
        silu_(&mut input)?;
        let data = input.data()?;
        assert!(
            (data[1] - 0.0).abs() < 1e-6,
            "SiLU(0) should be 0, got {}",
            data[1]
        );
        assert!(data[2] > 0.0, "SiLU(1) should be positive, got {}", data[2]);
        assert!(
            data[0] < 0.0 && data[0] > -1.0,
            "SiLU(-1) should be small negative, got {}",
            data[0]
        );
        Ok(())
    }

    /// Verifies the tensor is actually mutated by the `_` (in-place) API:
    /// a snapshot of the data taken before `relu_` differs from the data
    /// afterwards at the negative positions.
    /// NOTE(review): despite the name, this does not measure allocations —
    /// it only checks observable mutation of the tensor's contents.
    #[test]
    fn test_in_place_memory_efficiency() -> TorshResult<()> {
        let mut original = from_vec(vec![1.0f32, -1.0, 2.0, -2.0], &[4], DeviceType::Cpu)?;
        let original_data = original.data()?.clone();
        relu_(&mut original)?;
        let modified_data = original.data()?;
        assert_eq!(modified_data[0], 1.0); assert_eq!(modified_data[1], 0.0); assert_eq!(modified_data[2], 2.0); assert_eq!(modified_data[3], 0.0);
        assert_ne!(original_data[1], modified_data[1]);
        assert_ne!(original_data[3], modified_data[3]);
        Ok(())
    }

    /// Stability at extreme inputs: sigmoid(±100) must stay finite and in
    /// [0, 1], saturating to ~0 and ~1 rather than producing NaN/inf.
    #[test]
    fn test_in_place_numerical_stability() -> TorshResult<()> {
        let mut large_input = from_vec(vec![-100.0f32, 0.0, 100.0], &[3], DeviceType::Cpu)?;
        sigmoid_(&mut large_input)?;
        let data = large_input.data()?;
        for &val in data.iter() {
            assert!(
                val.is_finite(),
                "Sigmoid produced non-finite value: {}",
                val
            );
            assert!(
                val >= 0.0 && val <= 1.0,
                "Sigmoid value {} not in [0,1]",
                val
            );
        }
        assert!(
            data[0] < 1e-10,
            "sigmoid(-100) should be ~0, got {}",
            data[0]
        );
        assert!(
            (data[1] - 0.5).abs() < 1e-6,
            "sigmoid(0) should be 0.5, got {}",
            data[1]
        );
        assert!(
            data[2] > 0.9999999,
            "sigmoid(100) should be ~1, got {}",
            data[2]
        );
        Ok(())
    }
}