//! Sigmoid-family activation functions: sigmoid and hard-sigmoid variants,
//! SiLU/swish, hard swish, mish, log-sigmoid, and softplus.

use torsh_core::{dtype::FloatElement, Result as TorshResult};
use torsh_tensor::Tensor;
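
/// Logistic sigmoid, 1 / (1 + e^(-x)), applied element-wise. Delegates to the
/// tensor's built-in `sigmoid` kernel.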
pub fn sigmoid<T: FloatElement>(input: &Tensor<T>) -> TorshResult<Tensor<T>> {
    input.sigmoid()
}
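
/// Hard sigmoid: the piecewise-linear approximation clamp((x + 3) / 6, 0, 1)
/// (the PyTorch convention). Cheaper than `sigmoid` because it avoids `exp`.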
pub fn hardsigmoid<T: FloatElement>(input: &Tensor<T>) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32>,
{
    let three = <T as From<f32>>::from(3.0);
    let six = <T as From<f32>>::from(6.0);
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    let one = <T as torsh_core::dtype::TensorElement>::one();
    let data = input.data()?;
    let result_data: Vec<T> = data
        .iter()
        .map(|&x| {
            let normalized = (x + three) / six;
            if normalized < zero {
                zero
            } else if normalized > one {
                one
            } else {
                normalized
            }
        })
        .collect();
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}
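
/// Steeper hard-sigmoid variant, clamp(0.2 * x + 0.5, 0, 1), as used by e.g.
/// the classic Keras `hard_sigmoid`. `_inplace` is accepted for API parity but
/// ignored; a new tensor is always returned.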
pub fn hardsigmoid_v2<T: FloatElement>(
    input: &Tensor<T>,
    _inplace: bool,
) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32>,
{
    let point_two = <T as From<f32>>::from(0.2);
    let point_five = <T as From<f32>>::from(0.5);
    let zero = <T as torsh_core::dtype::TensorElement>::zero();
    let one = <T as torsh_core::dtype::TensorElement>::one();
    let data = input.data()?;
    let result_data: Vec<T> = data
        .iter()
        .map(|&x| {
            let value = x * point_two + point_five;
            if value < zero {
                zero
            } else if value > one {
                one
            } else {
                value
            }
        })
        .collect();
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}
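
/// SiLU (swish): x * sigmoid(x). `_inplace` is accepted for API parity but
/// ignored; the result is always a new tensor.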
pub fn silu<T: FloatElement>(input: &Tensor<T>, _inplace: bool) -> TorshResult<Tensor<T>>
where
    T: Copy,
{
    let sigmoid_input = input.sigmoid()?;
    input.mul(&sigmoid_input)
}
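
/// Swish is another name for SiLU; this alias simply forwards to [`silu`].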
pub fn swish<T: FloatElement>(input: &Tensor<T>, inplace: bool) -> TorshResult<Tensor<T>>
where
    T: Copy,
{
    silu(input, inplace)
}
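
/// Hard swish: x * hardsigmoid(x), the piecewise-linear approximation of SiLU
/// popularized by MobileNetV3.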
pub fn hardswish<T: FloatElement>(input: &Tensor<T>) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32>,
{
    let hard_sigmoid_input = hardsigmoid(input)?;
    input.mul(&hard_sigmoid_input)
}
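
/// Mish: x * tanh(softplus(x)). `_inplace` is accepted for API parity but
/// ignored; the result is always a new tensor.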
pub fn mish<T: FloatElement>(input: &Tensor<T>, _inplace: bool) -> TorshResult<Tensor<T>>
where
    T: Copy,
{
    let softplus_input = softplus_impl(input)?;
    let tanh_softplus = softplus_input.tanh()?;
    input.mul(&tanh_softplus)
}
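
/// Log-sigmoid: ln(sigmoid(x)), computed as -softplus(-x) rather than by
/// taking the log of a sigmoid that can underflow to zero for very negative
/// inputs.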
pub fn log_sigmoid<T: FloatElement>(input: &Tensor<T>) -> TorshResult<Tensor<T>>
where
    T: Copy,
{
    let neg_input = input.neg()?;
    let softplus_neg = softplus_impl(&neg_input)?;
    softplus_neg.neg()
}
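
/// Naive softplus, ln(1 + e^x), used internally by `mish` and `log_sigmoid`.
/// Note that `e^x` overflows for large x (around 88 for f32); the public
/// `softplus` below switches to the linear regime above its threshold instead.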
fn softplus_impl<T: FloatElement>(input: &Tensor<T>) -> TorshResult<Tensor<T>>
where
    T: Copy,
{
    let one = <T as torsh_core::dtype::TensorElement>::one();
    let data = input.data()?;
    let result_data: Vec<T> = data.iter().map(|&x| (one + x.exp()).ln()).collect();
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}
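
/// Softplus with the standard `beta`/`threshold` parameters (the PyTorch
/// convention): (1 / beta) * ln(1 + e^(beta * x)), reverting to the identity
/// when beta * x > threshold so that `exp` cannot overflow.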
pub fn softplus<T: FloatElement>(
    input: &Tensor<T>,
    beta: f64,
    threshold: f64,
) -> TorshResult<Tensor<T>>
where
    T: Copy + PartialOrd + From<f32>,
{
    // `T: From<f32>` is the only conversion available here, so the f64
    // parameters are narrowed through f32 first.
    let beta_val = <T as From<f32>>::from(beta as f32);
    let threshold_val = <T as From<f32>>::from(threshold as f32);
    let one = <T as torsh_core::dtype::TensorElement>::one();
    let data = input.data()?;
    let result_data: Vec<T> = data
        .iter()
        .map(|&x| {
            let scaled = x * beta_val;
            if scaled > threshold_val {
                // Linear regime: softplus(x) ~ x once beta * x exceeds the
                // threshold, and this branch avoids overflow in `exp`.
                x
            } else {
                (one + scaled.exp()).ln() / beta_val
            }
        })
        .collect();
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    #[test]
    fn test_sigmoid_range() -> TorshResult<()> {
        let input = from_vec(vec![-10.0, -1.0, 0.0, 1.0, 10.0], &[5], DeviceType::Cpu)?;
        let output = sigmoid(&input)?;
        let output_data = output.data()?;
        for &val in output_data.iter() {
            assert!(val > 0.0 && val < 1.0, "Sigmoid value {} not in (0,1)", val);
        }
        assert!((output_data[2] - 0.5_f32).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_hardsigmoid_piecewise() -> TorshResult<()> {
        let input = from_vec(vec![-10.0, -3.0, 0.0, 3.0, 10.0], &[5], DeviceType::Cpu)?;
        let output = hardsigmoid(&input)?;
        let output_data = output.data()?;
        // Saturates at 0 for x <= -3 and at 1 for x >= 3, linear in between.
        assert!(output_data[0] == 0.0);
        assert!(output_data[1] == 0.0);
        assert!((output_data[2] - 0.5_f32).abs() < 1e-6);
        assert!(output_data[3] == 1.0);
        assert!(output_data[4] == 1.0);
        Ok(())
    }

    #[test]
    fn test_silu_properties() -> TorshResult<()> {
        let input = from_vec(vec![-1.0, 0.0, 1.0], &[3], DeviceType::Cpu)?;
        let output = silu(&input, false)?;
        let output_data = output.data()?;
        assert!((output_data[1] as f32).abs() < 1e-6);
        let large_input = from_vec(vec![10.0], &[1], DeviceType::Cpu)?;
        let large_output = silu(&large_input, false)?;
        let large_val = large_output.item()?;
        assert!((large_val - 10.0_f32).abs() < 0.1);
        Ok(())
    }

    #[test]
    fn test_hardswish_boundaries() -> TorshResult<()> {
        // hardswish(-3) = -3 * 0 = 0 and hardswish(3) = 3 * 1 = 3 exactly.
        let input = from_vec(vec![-3.0, 0.0, 3.0], &[3], DeviceType::Cpu)?;
        let output = hardswish(&input)?;
        let output_data = output.data()?;
        assert!(output_data[0] == 0.0);
        assert!(output_data[2] == 3.0);
        Ok(())
    }

    #[test]
    fn test_log_sigmoid_stability() -> TorshResult<()> {
        let input = from_vec(vec![-10.0, 0.0, 10.0], &[3], DeviceType::Cpu)?;
        let output = log_sigmoid(&input)?;
        let output_data = output.data()?;
        for &val in output_data.iter() {
            assert!(
                val.is_finite(),
                "log_sigmoid produced non-finite value: {}",
                val
            );
        }
        // log_sigmoid(0) = ln(1/2) = -ln(2).
        let expected_at_zero = -std::f32::consts::LN_2;
        assert!((output_data[1] - expected_at_zero).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_softplus_approximation() -> TorshResult<()> {
        let input = from_vec(vec![-5.0, 0.0, 5.0], &[3], DeviceType::Cpu)?;
        let output = softplus(&input, 1.0, 20.0)?;
        let output_data = output.data()?;
        // softplus(-5) ~ 0.0067, softplus(0) = ln(2), softplus(5) ~ 5.0067.
        assert!(output_data[0] < 0.1);
        assert!((output_data[1] - std::f32::consts::LN_2).abs() < 1e-6);
        assert!((output_data[2] - 5.0).abs() < 0.1);
        Ok(())
    }
}