//! Element-wise activation functions grouped by family (ReLU, sigmoid, tanh,
//! and softmax variants, plus advanced activations), along with their in-place
//! counterparts and shared element-wise helpers.

pub mod advanced;
pub mod inplace;
pub mod relu_family;
pub mod sigmoid_family;
pub mod softmax_family;
pub mod tanh_family;
use torsh_core::dtype::FloatElement;
use torsh_core::Result as TorshResult;
use torsh_tensor::Tensor;
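
/// Applies `operation` to each element of `input` and returns a new tensor
/// with the same shape and device.
///
/// A minimal usage sketch (marked `ignore` because it assumes the same
/// `from_vec` / `DeviceType` helpers the tests below use and a caller that
/// returns `torsh_core::Result`):
///
/// ```ignore
/// use torsh_core::device::DeviceType;
/// use torsh_tensor::creation::from_vec;
///
/// let input = from_vec(vec![-1.0_f32, 0.0, 2.0], &[3], DeviceType::Cpu)?;
/// // Any `Fn(T) -> T` closure works; here each element is squared.
/// let squared = apply_elementwise(&input, |x| x * x)?;
/// ```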
pub fn apply_elementwise<T, F>(input: &Tensor<T>, operation: F) -> TorshResult<Tensor<T>>
where
    T: FloatElement + Copy,
    F: Fn(T) -> T,
{
    let data = input.data()?;
    let result_data: Vec<T> = data.iter().map(|&x| operation(x)).collect();
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}
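
/// Variant of [`apply_elementwise`] that accepts an `inplace` flag.
///
/// Note: the flag is currently ignored; the call always delegates to
/// [`apply_elementwise`] and returns a new tensor rather than mutating
/// `input` in place.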
pub fn apply_elementwise_inplace<T, F>(
    input: &Tensor<T>,
    _inplace: bool,
    operation: F,
) -> TorshResult<Tensor<T>>
where
    T: FloatElement + Copy,
    F: Fn(T) -> T,
{
    apply_elementwise(input, operation)
}
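
// Re-export the individual activation functions so callers can use them
// directly from this module.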
pub use relu_family::{
    celu, elu, hardshrink, leaky_relu, prelu, relu, relu6, rrelu, selu, softshrink, threshold,
};
pub use sigmoid_family::{
    hardsigmoid, hardsigmoid_v2, hardswish, log_sigmoid, mish, sigmoid, silu, softplus, swish,
};
pub use tanh_family::{hardtanh, softsign, tanh, tanhshrink};
pub use softmax_family::{gumbel_softmax, log_softmax, softmax, softmin};
pub use advanced::{gelu, glu, local_response_norm, scaled_dot_product_attention};
pub use inplace::{gelu_, leaky_relu_, relu_, sigmoid_, silu_, tanh_};
#[cfg(test)]
mod integration_tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;
    #[test]
    fn test_activation_functions_integration() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;
        let input = from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5], device)?;

        // ReLU family
        let _relu_out = relu(&input, false)?;
        let _leaky_relu_out = leaky_relu(&input, 0.1, false)?;
        let _elu_out = elu(&input, 1.0, false)?;
        let _selu_out = selu(&input, false)?;

        // Sigmoid family
        let _sigmoid_out = sigmoid(&input)?;
        let _silu_out = silu(&input, false)?;
        let _mish_out = mish(&input, false)?;

        // Tanh family
        let _tanh_out = tanh(&input)?;
        let _softsign_out = softsign(&input)?;
        let _hardtanh_out = hardtanh(&input, -1.0, 1.0)?;

        // Advanced
        let _gelu_out = gelu(&input)?;

        // Softmax family
        let logits = from_vec(vec![1.0, 2.0, 3.0], &[3], device)?;
        let _softmax_out = softmax(&logits, 0, None)?;
        let _log_softmax_out = log_softmax(&logits, 0, None)?;
        Ok(())
    }
    #[test]
    fn test_activation_functions_numerical_stability() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;
        let extreme_input = from_vec(vec![-100.0, -1e-8, 0.0, 1e-8, 100.0], &[5], device)?;

        let sigmoid_out = sigmoid(&extreme_input)?;
        let sigmoid_data = sigmoid_out.data()?;
        for &val in sigmoid_data.iter() {
            let val: f32 = val;
            assert!(
                val.is_finite() && !val.is_nan(),
                "Sigmoid produced invalid value: {}",
                val
            );
            assert!(
                val >= 0.0 && val <= 1.0,
                "Sigmoid value {} not in [0,1]",
                val
            );
        }

        let tanh_out = tanh(&extreme_input)?;
        let tanh_data = tanh_out.data()?;
        for &val in tanh_data.iter() {
            let val: f32 = val;
            assert!(
                val.is_finite() && !val.is_nan(),
                "Tanh produced invalid value: {}",
                val
            );
            assert!(
                val >= -1.0 && val <= 1.0,
                "Tanh value {} not in [-1,1]",
                val
            );
        }

        let gelu_out = gelu(&extreme_input)?;
        let gelu_data = gelu_out.data()?;
        for &val in gelu_data.iter() {
            let val: f32 = val;
            assert!(
                val.is_finite() && !val.is_nan(),
                "GELU produced invalid value: {}",
                val
            );
        }
        Ok(())
    }
    #[test]
    fn test_inplace_operations() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;
        let mut input = from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5], device)?;

        // relu_ clamps negative entries to zero and leaves the rest unchanged.
        relu_(&mut input)?;
        let data = input.data()?;
        assert_eq!(data[0], 0.0);
        assert_eq!(data[1], 0.0);
        assert_eq!(data[2], 0.0);
        assert_eq!(data[3], 1.0);
        assert_eq!(data[4], 2.0);

        // sigmoid_ maps zero to 0.5.
        let mut input2 = from_vec(vec![0.0], &[1], device)?;
        sigmoid_(&mut input2)?;
        let data2 = input2.data()?;
        assert!((data2[0] - 0.5_f32).abs() < 1e-6);
        Ok(())
    }
    #[test]
    fn test_activation_output_ranges() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;
        let input = from_vec(vec![-5.0, -1.0, 0.0, 1.0, 5.0], &[5], device)?;

        let relu_out = relu(&input, false)?;
        let relu_data = relu_out.data()?;
        for &val in relu_data.iter() {
            assert!(val >= 0.0, "ReLU output {} should be non-negative", val);
        }

        let sigmoid_out = sigmoid(&input)?;
        let sigmoid_data = sigmoid_out.data()?;
        for &val in sigmoid_data.iter() {
            assert!(
                val > 0.0 && val < 1.0,
                "Sigmoid output {} not in (0,1)",
                val
            );
        }

        let tanh_out = tanh(&input)?;
        let tanh_data = tanh_out.data()?;
        for &val in tanh_data.iter() {
            assert!(val > -1.0 && val < 1.0, "Tanh output {} not in (-1,1)", val);
        }

        let logits = from_vec(vec![1.0, 2.0, 3.0], &[3], device)?;
        let softmax_out = softmax(&logits, 0, None)?;
        let softmax_data = softmax_out.data()?;
        let sum: f32 = softmax_data.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "Softmax should sum to 1, got {}",
            sum
        );
        for &val in softmax_data.iter() {
            assert!(
                val >= 0.0 && val <= 1.0,
                "Softmax output {} not in [0,1]",
                val
            );
        }
        Ok(())
    }
    #[test]
    fn test_activation_monotonicity() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;
        let input = from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5], device)?;

        // ReLU, sigmoid, and tanh are all non-decreasing, so their outputs
        // must preserve the ordering of the sorted input.
        let monotonic_functions = vec![
            ("relu", relu(&input, false)?),
            ("sigmoid", sigmoid(&input)?),
            ("tanh", tanh(&input)?),
        ];
        for (name, output) in monotonic_functions {
            let data = output.data()?;
            for i in 1..data.len() {
                assert!(
                    data[i] >= data[i - 1],
                    "{} should be non-decreasing: {} < {} at indices {}, {}",
                    name,
                    data[i],
                    data[i - 1],
                    i,
                    i - 1
                );
            }
        }
        Ok(())
    }
}