use super::core::{Activation, FuncResult, FunctionalConfig};
use crate::{func_error, validate_inputs};
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
/// Rectified linear unit: element-wise maximum of `input` and zero.
///
/// A zero tensor shaped like `input` is created and combined with `input`
/// through `maximum`; the result is returned as a new tensor.
///
/// # Errors
/// Propagates any error from creating the zero tensor or from `maximum`.
pub fn relu(input: &Tensor) -> Result<Tensor> {
    let floor = torsh_tensor::creation::zeros_like(input)?;
    let clamped = input.maximum(&floor)?;
    Ok(clamped)
}
/// Applies ReLU to `input` and stores the result back into `input`.
///
/// Note that the visible implementation computes a fresh tensor with
/// `maximum` and then assigns it over `*input`; it does not reuse the
/// existing tensor in place.
///
/// # Errors
/// Propagates any error from creating the zero tensor or from `maximum`.
/// On error `input` is left untouched.
pub fn relu_inplace(input: &mut Tensor) -> Result<()> {
    let floor = torsh_tensor::creation::zeros_like(input)?;
    let clamped = input.maximum(&floor)?;
    *input = clamped;
    Ok(())
}
/// Leaky ReLU: `max(x, 0) + negative_slope * min(x, 0)`, element-wise.
///
/// # Arguments
/// * `input` - tensor to activate.
/// * `negative_slope` - factor applied to the non-positive part of `input`.
///
/// # Errors
/// Propagates any error from the underlying tensor operations.
pub fn leaky_relu(input: &Tensor, negative_slope: f32) -> Result<Tensor> {
    let zero = torsh_tensor::creation::zeros_like(input)?;

    // Part of the input that passes through unchanged.
    let kept = input.maximum(&zero)?;

    // Part of the input that is attenuated by `negative_slope`.
    let leaked = {
        let below_zero = input.minimum(&zero)?;
        let slope = torsh_tensor::creation::full_like(input, negative_slope)?;
        below_zero.mul_op(&slope)?
    };

    kept.add(&leaked)
}
/// GELU computed with the sigmoid approximation `x * sigmoid(1.702 * x)`.
///
/// This is not the exact erf-based GELU; the gate is produced by this
/// module's `sigmoid` applied to the input scaled by `1.702`.
///
/// # Errors
/// Propagates any error from the underlying tensor operations.
pub fn gelu(input: &Tensor) -> Result<Tensor> {
    let gate = {
        let slope = torsh_tensor::creation::full_like(input, 1.702)?;
        let steepened = input.mul_op(&slope)?;
        sigmoid(&steepened)?
    };
    input.mul_op(&gate)
}
/// Element-wise logistic sigmoid, `1 / (1 + e^{-x})`.
///
/// The values are pulled out with `to_vec`, transformed one by one and
/// wrapped in a new tensor with the shape and device of `input`.
///
/// # Errors
/// Propagates any error from `to_vec` or `Tensor::from_data`.
pub fn sigmoid(input: &Tensor) -> Result<Tensor> {
    /// Scalar sigmoid that only ever exponentiates a non-positive number,
    /// so the exponential cannot overflow for large `|x|`.
    fn logistic(x: f32) -> f32 {
        let decay = (-(x.abs())).exp();
        let denom = 1.0 + decay;
        if x > 0.0 {
            1.0 / denom
        } else {
            decay / denom
        }
    }

    let values = input.to_vec()?;
    let activated: Vec<f32> = values.iter().map(|&x| logistic(x)).collect();
    Tensor::from_data(activated, input.shape().dims().to_vec(), input.device())
}
pub fn softmax(input: &Tensor, dim: Option<i32>) -> Result<Tensor> {
let dim = dim.unwrap_or(-1);
let shape = input.shape();
if shape.dims().len() == 2 && dim == 1 {
let data = input.to_vec()?;
let rows = shape.dims()[0];
let cols = shape.dims()[1];
let mut result_data = vec![0.0; data.len()];
for row in 0..rows {
let row_start = row * cols;
let row_end = (row + 1) * cols;
let row_data = &data[row_start..row_end];
let max_val = row_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut exp_sum = 0.0;
let mut exp_vals = Vec::with_capacity(cols);
for &x in row_data {
let exp_val = (x - max_val).exp();
exp_vals.push(exp_val);
exp_sum += exp_val;
}
for (i, exp_val) in exp_vals.into_iter().enumerate() {
result_data[row_start + i] = exp_val / exp_sum;
}
}
return Tensor::from_data(result_data, shape.dims().to_vec(), input.device());
}
let data = input.to_vec()?;
let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let shifted_data: Vec<f32> = data.iter().map(|&x| x - max_val).collect();
let exp_data: Vec<f32> = shifted_data.iter().map(|&x| x.exp()).collect();
let sum_exp: f32 = exp_data.iter().sum();
let result_data: Vec<f32> = exp_data.iter().map(|&x| x / sum_exp).collect();
Tensor::from_data(result_data, shape.dims().to_vec(), input.device())
}
/// Log-softmax along `dim`, computed as `(x - max) - log(sum(exp(x - max)))`.
///
/// # Arguments
/// * `input` - tensor to normalise.
/// * `dim` - axis passed to `max_dim` / `sum_dim`; `None` means `-1`.
///
/// # Errors
/// Propagates any error from the underlying tensor operations.
pub fn log_softmax(input: &Tensor, dim: Option<i32>) -> Result<Tensor> {
    let axis = dim.unwrap_or(-1);

    // Subtract the per-lane maximum (kept as a size-1 axis) before exp.
    let peak = input.max_dim(axis, true)?;
    let centered = input.sub(&peak)?;

    let log_norm = centered.exp()?.sum_dim(&[axis], true)?.log()?;
    centered.sub(&log_norm)
}
/// Element-wise hyperbolic tangent.
///
/// The values are pulled out with `to_vec`, mapped through the standard
/// library's `f32::tanh` and wrapped in a new tensor with the shape and
/// device of `input`.
///
/// `f32::tanh` is used instead of a hand-written `(e^{2x} - 1) / (e^{2x} + 1)`
/// because that formula loses precision through cancellation when `x` is
/// close to zero (tiny inputs collapse to exactly `0.0`). The library
/// routine also saturates to `±1.0` for large `|x|` and propagates NaN.
///
/// # Errors
/// Propagates any error from `to_vec` or `Tensor::from_data`.
pub fn tanh(input: &Tensor) -> Result<Tensor> {
    let values = input.to_vec()?;
    let activated: Vec<f32> = values.iter().copied().map(f32::tanh).collect();
    Tensor::from_data(activated, input.shape().dims().to_vec(), input.device())
}
/// Swish activation: `x * sigmoid(x)`, element-wise.
///
/// # Errors
/// Propagates any error from `sigmoid` or from the multiplication.
pub fn swish(input: &Tensor) -> Result<Tensor> {
    sigmoid(input).and_then(|gate| input.mul_op(&gate))
}
/// Mish activation: `x * tanh(softplus(x))` with `softplus(x) = ln(1 + e^x)`.
///
/// # Errors
/// Propagates any error from the underlying tensor operations.
pub fn mish(input: &Tensor) -> Result<Tensor> {
    // softplus(x) = ln(e^x + 1)
    let softplus = {
        let exponentiated = input.exp()?;
        let one = torsh_tensor::creation::ones_like(input)?;
        exponentiated.add(&one)?.log()?
    };
    let gate = tanh(&softplus)?;
    input.mul_op(&gate)
}
/// ELU activation built from `input` and `alpha * (exp(input) - 1)`.
///
/// A mask of `input > 0` is handed to `where_tensor` together with the
/// `alpha * (exp(x) - 1)` branch.
/// NOTE(review): presumably `where_tensor` keeps `input` where the mask is
/// set and takes the other tensor elsewhere; confirm against its definition.
///
/// # Arguments
/// * `input` - tensor to activate.
/// * `alpha` - scale of the exponential branch.
///
/// # Errors
/// Propagates any error from the underlying tensor operations.
pub fn elu(input: &Tensor, alpha: f32) -> Result<Tensor> {
    let zero = torsh_tensor::creation::zeros_like(input)?;
    let is_positive = input.gt(&zero)?;

    let exponentiated = input.exp()?;
    let one = torsh_tensor::creation::ones_like(input)?;
    let alpha_fill = torsh_tensor::creation::full_like(input, alpha)?;

    // alpha * (e^x - 1)
    let exp_minus_one = exponentiated.sub(&one)?;
    let exponential_branch = alpha_fill.mul_op(&exp_minus_one)?;

    input.where_tensor(&is_positive, &exponential_branch)
}
/// SELU activation: `SCALE * elu(x, ALPHA)` with the fixed constants below.
///
/// # Errors
/// Propagates any error from `elu` or from the final multiplication.
pub fn selu(input: &Tensor) -> Result<Tensor> {
    // Fixed SELU constants, rounded to `f32` when used.
    const ALPHA: f32 = 1.6732632423543772;
    const SCALE: f32 = 1.0507009873554805;

    let activated = elu(input, ALPHA)?;
    let scale_fill = torsh_tensor::creation::full_like(input, SCALE)?;
    activated.mul_op(&scale_fill)
}
/// Inverted dropout.
///
/// In training mode each element is zeroed with probability `p` and the
/// surviving elements are multiplied by `1 / (1 - p)`. Outside training, or
/// when `p == 0.0`, a clone of `input` is returned. When `p == 1.0` the
/// result is all zeros.
///
/// # Arguments
/// * `input` - tensor to apply dropout to.
/// * `p` - drop probability; must lie in `[0, 1]`.
/// * `training` - whether dropout is active.
///
/// # Errors
/// * `TorshError::InvalidArgument` if `p` is outside `[0, 1]` (NaN included).
///   This is checked before anything else, so an invalid `p` is rejected in
///   evaluation mode as well.
/// * Any error from the underlying tensor operations.
pub fn dropout(input: &Tensor, p: f32, training: bool) -> Result<Tensor> {
    use scirs2_core::random::thread_rng;

    // Validate first so a bad probability is never silently accepted just
    // because the call happened in evaluation mode.
    if !(0.0..=1.0).contains(&p) {
        return Err(TorshError::InvalidArgument(format!(
            "Dropout probability must be between 0 and 1, got {}",
            p
        )));
    }

    if !training || p == 0.0 {
        return Ok(input.clone());
    }

    if p == 1.0 {
        // Derive the zeros from `input` itself, like every other path in this
        // function, instead of from a bare shape. The scaling factor below
        // would also be infinite for p == 1, so this case must short-circuit.
        let dropped = torsh_tensor::creation::zeros_like(input)?;
        return Ok(dropped);
    }

    let data = input.data()?;
    let scale = 1.0 / (1.0 - p);
    let mut rng = thread_rng();
    let result_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            let random_val: f32 = rng.random();
            if random_val < p {
                0.0
            } else {
                x * scale
            }
        })
        .collect();
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}
/// Wrappers around the plain activation functions that take a
/// `FunctionalConfig` and return `FuncResult`.
///
/// Every wrapper has the same two steps: it invokes `validate_inputs!` with
/// the config and a `validate_not_empty(input, "input")` check, then passes
/// the plain activation call and a human-readable label to `func_error!`.
/// NOTE(review): the macro definitions live outside this file; presumably
/// `validate_inputs!` gates the check on the config and `func_error!` maps
/// the error into the `FuncResult` error type with the label as context.
/// Confirm against the macro definitions.
pub mod configured {
use super::super::core::validation;
use super::*;
/// Config-aware wrapper around `relu`, labelled "ReLU activation".
pub fn relu_configured(input: &Tensor, config: &FunctionalConfig) -> FuncResult<Tensor> {
validate_inputs!(config, validation::validate_not_empty(input, "input"));
func_error!(relu(input), "ReLU activation")
}
/// Config-aware wrapper around `sigmoid`, labelled "Sigmoid activation".
pub fn sigmoid_configured(input: &Tensor, config: &FunctionalConfig) -> FuncResult<Tensor> {
validate_inputs!(config, validation::validate_not_empty(input, "input"));
func_error!(sigmoid(input), "Sigmoid activation")
}
/// Config-aware wrapper around `tanh`, labelled "Tanh activation".
pub fn tanh_configured(input: &Tensor, config: &FunctionalConfig) -> FuncResult<Tensor> {
validate_inputs!(config, validation::validate_not_empty(input, "input"));
func_error!(tanh(input), "Tanh activation")
}
/// Config-aware wrapper around `softmax`, labelled "Softmax activation".
/// `dim` is forwarded to `softmax` unchanged.
pub fn softmax_configured(
input: &Tensor,
dim: Option<i32>,
config: &FunctionalConfig,
) -> FuncResult<Tensor> {
validate_inputs!(config, validation::validate_not_empty(input, "input"));
func_error!(softmax(input, dim), "Softmax activation")
}
/// Config-aware wrapper around `gelu`, labelled "GELU activation".
pub fn gelu_configured(input: &Tensor, config: &FunctionalConfig) -> FuncResult<Tensor> {
validate_inputs!(config, validation::validate_not_empty(input, "input"));
func_error!(gelu(input), "GELU activation")
}
/// Config-aware wrapper around `swish`, labelled "Swish activation".
pub fn swish_configured(input: &Tensor, config: &FunctionalConfig) -> FuncResult<Tensor> {
validate_inputs!(config, validation::validate_not_empty(input, "input"));
func_error!(swish(input), "Swish activation")
}
/// Config-aware wrapper around `mish`, labelled "Mish activation".
pub fn mish_configured(input: &Tensor, config: &FunctionalConfig) -> FuncResult<Tensor> {
validate_inputs!(config, validation::validate_not_empty(input, "input"));
func_error!(mish(input), "Mish activation")
}
}
/// ReLU activation object implementing [`Activation`].
pub struct ReLU {
    /// Selects the `relu_inplace` code path inside `apply`.
    inplace: bool,
}

impl ReLU {
    /// Creates a ReLU activation; `inplace` selects which helper `apply` uses.
    pub fn new(inplace: bool) -> Self {
        ReLU { inplace }
    }
}

impl Activation for ReLU {
    /// Applies ReLU to `input` and returns the result as a new tensor.
    ///
    /// Because `apply` only receives a shared reference, the `inplace`
    /// variant operates on a clone of `input`; the caller's tensor is never
    /// modified by either path.
    fn apply(&self, input: &Tensor) -> FuncResult<Tensor> {
        if !self.inplace {
            return relu(input).map_err(|err| err.into());
        }

        let mut activated = input.clone();
        relu_inplace(&mut activated)?;
        Ok(activated)
    }
}
/// Stateless sigmoid activation object implementing [`Activation`].
pub struct Sigmoid;

impl Sigmoid {
    /// Creates the (stateless) sigmoid activation.
    pub fn new() -> Self {
        Sigmoid
    }
}

impl Default for Sigmoid {
    /// Same as [`Sigmoid::new`].
    fn default() -> Self {
        Sigmoid
    }
}

impl Activation for Sigmoid {
    /// Delegates to the free function `sigmoid`, converting its error type.
    fn apply(&self, input: &Tensor) -> FuncResult<Tensor> {
        sigmoid(input).map_err(Into::into)
    }
}
/// Stateless tanh activation object implementing [`Activation`].
pub struct Tanh;

impl Tanh {
    /// Creates the (stateless) tanh activation.
    pub fn new() -> Self {
        Tanh
    }
}

impl Default for Tanh {
    /// Same as [`Tanh::new`].
    fn default() -> Self {
        Tanh
    }
}

impl Activation for Tanh {
    /// Delegates to the free function `tanh`, converting its error type.
    fn apply(&self, input: &Tensor) -> FuncResult<Tensor> {
        tanh(input).map_err(Into::into)
    }
}
/// Stateless GELU activation object implementing [`Activation`].
pub struct GELU;

impl GELU {
    /// Creates the (stateless) GELU activation.
    pub fn new() -> Self {
        GELU
    }
}

impl Default for GELU {
    /// Same as [`GELU::new`].
    fn default() -> Self {
        GELU
    }
}

impl Activation for GELU {
    /// Delegates to the free function `gelu`, converting its error type.
    fn apply(&self, input: &Tensor) -> FuncResult<Tensor> {
        gelu(input).map_err(Into::into)
    }
}
/// Stateless Swish activation object implementing [`Activation`].
pub struct Swish;

impl Swish {
    /// Creates the (stateless) Swish activation.
    pub fn new() -> Self {
        Swish
    }
}

impl Default for Swish {
    /// Same as [`Swish::new`].
    fn default() -> Self {
        Swish
    }
}

impl Activation for Swish {
    /// Delegates to the free function `swish`, converting its error type.
    fn apply(&self, input: &Tensor) -> FuncResult<Tensor> {
        swish(input).map_err(Into::into)
    }
}
/// Stateless Mish activation object implementing [`Activation`].
pub struct Mish;

impl Mish {
    /// Creates the (stateless) Mish activation.
    pub fn new() -> Self {
        Mish
    }
}

impl Default for Mish {
    /// Same as [`Mish::new`].
    fn default() -> Self {
        Mish
    }
}

impl Activation for Mish {
    /// Delegates to the free function `mish`, converting its error type.
    fn apply(&self, input: &Tensor) -> FuncResult<Tensor> {
        mish(input).map_err(Into::into)
    }
}
/// ELU activation object implementing [`Activation`].
pub struct ELU {
    /// Value forwarded as `alpha` to the free function `elu`.
    alpha: f32,
}

impl ELU {
    /// Creates an ELU activation with the given `alpha`.
    pub fn new(alpha: f32) -> Self {
        ELU { alpha }
    }
}

impl Default for ELU {
    /// ELU with `alpha = 1.0`.
    fn default() -> Self {
        ELU { alpha: 1.0 }
    }
}

impl Activation for ELU {
    /// Delegates to `elu` with the stored `alpha`, converting its error type.
    fn apply(&self, input: &Tensor) -> FuncResult<Tensor> {
        let alpha = self.alpha;
        elu(input, alpha).map_err(Into::into)
    }
}
/// Stateless SELU activation object implementing [`Activation`].
pub struct SELU;

impl SELU {
    /// Creates the (stateless) SELU activation.
    pub fn new() -> Self {
        SELU
    }
}

impl Default for SELU {
    /// Same as [`SELU::new`].
    fn default() -> Self {
        SELU
    }
}

impl Activation for SELU {
    /// Delegates to the free function `selu`, converting its error type.
    fn apply(&self, input: &Tensor) -> FuncResult<Tensor> {
        selu(input).map_err(Into::into)
    }
}
/// Leaky ReLU activation object implementing [`Activation`].
pub struct LeakyReLU {
    /// Value forwarded as `negative_slope` to the free function `leaky_relu`.
    negative_slope: f32,
}

impl LeakyReLU {
    /// Creates a Leaky ReLU activation with the given `negative_slope`.
    pub fn new(negative_slope: f32) -> Self {
        LeakyReLU { negative_slope }
    }
}

impl Default for LeakyReLU {
    /// Leaky ReLU with `negative_slope = 0.01`.
    fn default() -> Self {
        LeakyReLU {
            negative_slope: 0.01,
        }
    }
}

impl Activation for LeakyReLU {
    /// Delegates to `leaky_relu` with the stored slope, converting its error type.
    fn apply(&self, input: &Tensor) -> FuncResult<Tensor> {
        let slope = self.negative_slope;
        leaky_relu(input, slope).map_err(Into::into)
    }
}
/// Softmax activation object implementing [`Activation`].
pub struct Softmax {
    /// Axis forwarded (wrapped in `Some`) to the free function `softmax`.
    dim: i32,
}

impl Softmax {
    /// Creates a softmax activation over axis `dim`.
    pub fn new(dim: i32) -> Self {
        Softmax { dim }
    }
}

impl Default for Softmax {
    /// Softmax with `dim = -1`.
    fn default() -> Self {
        Softmax { dim: -1 }
    }
}

impl Activation for Softmax {
    /// Delegates to `softmax` with the stored axis, converting its error type.
    fn apply(&self, input: &Tensor) -> FuncResult<Tensor> {
        let axis = Some(self.dim);
        softmax(input, axis).map_err(Into::into)
    }
}
/// Log-softmax activation object implementing [`Activation`].
pub struct LogSoftmax {
    /// Axis forwarded (wrapped in `Some`) to the free function `log_softmax`.
    dim: i32,
}

impl LogSoftmax {
    /// Creates a log-softmax activation over axis `dim`.
    pub fn new(dim: i32) -> Self {
        LogSoftmax { dim }
    }
}

impl Default for LogSoftmax {
    /// Log-softmax with `dim = -1`.
    fn default() -> Self {
        LogSoftmax { dim: -1 }
    }
}

impl Activation for LogSoftmax {
    /// Delegates to `log_softmax` with the stored axis, converting its error type.
    fn apply(&self, input: &Tensor) -> FuncResult<Tensor> {
        let axis = Some(self.dim);
        log_softmax(input, axis).map_err(Into::into)
    }
}
// Unit tests for `dropout`. Several of them draw from the thread RNG and
// assert on statistical properties, so they are probabilistic rather than
// strictly deterministic.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
// With p = 0 in training mode the output must equal the input element-wise.
#[test]
fn test_dropout_training_p_zero() -> Result<()> {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5])?;
let output = dropout(&input, 0.0, true)?;
let input_data = input.to_vec()?;
let output_data = output.to_vec()?;
assert_eq!(input_data.len(), output_data.len());
for (i, o) in input_data.iter().zip(output_data.iter()) {
assert_relative_eq!(i, o, epsilon = 1e-6);
}
Ok(())
}
// With p = 1 in training mode every output element must be zero.
#[test]
fn test_dropout_training_p_one() -> Result<()> {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5])?;
let output = dropout(&input, 1.0, true)?;
let output_data = output.to_vec()?;
for &val in output_data.iter() {
assert_relative_eq!(val, 0.0, epsilon = 1e-6);
}
Ok(())
}
// With training = false the output must equal the input even for p = 0.5.
#[test]
fn test_dropout_eval_mode() -> Result<()> {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5])?;
let output = dropout(&input, 0.5, false)?;
let input_data = input.to_vec()?;
let output_data = output.to_vec()?;
assert_eq!(input_data.len(), output_data.len());
for (i, o) in input_data.iter().zip(output_data.iter()) {
assert_relative_eq!(i, o, epsilon = 1e-6);
}
Ok(())
}
// Statistical check: 1000 elements at p = 0.5 should yield 400-600 zeros.
// Note that the first input element is itself 0.0, so it always counts as a
// zero regardless of whether it was dropped.
#[test]
fn test_dropout_training_p_half() -> Result<()> {
let size = 1000;
let input_data: Vec<f32> = (0..size).map(|i| i as f32).collect();
let input = Tensor::from_vec(input_data.clone(), &[size])?;
let output = dropout(&input, 0.5, true)?;
let output_data = output.to_vec()?;
let zeros_count = output_data.iter().filter(|&&x| x == 0.0).count();
assert!(
zeros_count >= 400 && zeros_count <= 600,
"Expected 400-600 zeros, got {}",
zeros_count
);
Ok(())
}
// Checks the inverted-dropout scaling on an all-ones input with p = 0.3:
// surviving elements should equal 1 / (1 - p), and the overall mean should
// stay close to 1.0 (tolerance 0.1).
#[test]
fn test_dropout_scaling() -> Result<()> {
let size = 10000;
let input_data: Vec<f32> = vec![1.0; size];
let input = Tensor::from_vec(input_data, &[size])?;
let p = 0.3;
let output = dropout(&input, p, true)?;
let output_data = output.to_vec()?;
let non_zeros: Vec<f32> = output_data.iter().filter(|&&x| x != 0.0).copied().collect();
if !non_zeros.is_empty() {
let mean_non_zero: f32 = non_zeros.iter().sum::<f32>() / non_zeros.len() as f32;
let expected_scale = 1.0 / (1.0 - p);
assert_relative_eq!(mean_non_zero, expected_scale, epsilon = 0.01);
}
let total_mean: f32 = output_data.iter().sum::<f32>() / output_data.len() as f32;
assert_relative_eq!(total_mean, 1.0, epsilon = 0.1);
Ok(())
}
// The output of dropout must have the same dimensions as a [2, 4] input.
#[test]
fn test_dropout_shape_preservation() -> Result<()> {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])?;
let output = dropout(&input, 0.5, true)?;
assert_eq!(input.shape().dims(), output.shape().dims());
assert_eq!(input.shape().dims(), &[2, 4]);
Ok(())
}
// A negative p in training mode must produce an InvalidArgument error with
// the expected message.
#[test]
fn test_dropout_invalid_p_negative() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("Tensor should succeed");
let result = dropout(&input, -0.1, true);
assert!(result.is_err());
if let Err(TorshError::InvalidArgument(msg)) = result {
assert!(msg.contains("Dropout probability must be between 0 and 1"));
} else {
panic!("Expected InvalidArgument error for negative p");
}
}
// A p greater than 1 in training mode must produce an InvalidArgument error
// with the expected message.
#[test]
fn test_dropout_invalid_p_too_large() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("Tensor should succeed");
let result = dropout(&input, 1.5, true);
assert!(result.is_err());
if let Err(TorshError::InvalidArgument(msg)) = result {
assert!(msg.contains("Dropout probability must be between 0 and 1"));
} else {
panic!("Expected InvalidArgument error for p > 1.0");
}
}
// On a [3, 4] input with p = 0.5 the shape is kept and the output is expected
// to contain both dropped and kept elements. With only 12 elements there is
// a small chance that a random draw drops all or none of them.
#[test]
fn test_dropout_multidimensional() -> Result<()> {
let input = Tensor::from_vec(
vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
&[3, 4],
)?;
let output = dropout(&input, 0.5, true)?;
assert_eq!(output.shape().dims(), &[3, 4]);
let output_data = output.to_vec()?;
let has_zeros = output_data.iter().any(|&x| x == 0.0);
let has_nonzeros = output_data.iter().any(|&x| x != 0.0);
assert!(has_zeros, "Should have some dropped (zero) elements");
assert!(has_nonzeros, "Should have some kept (non-zero) elements");
Ok(())
}
// With a very small p (0.01) on four elements, at least three are expected to
// survive. This is probabilistic: two or more drops are unlikely but possible.
#[test]
fn test_dropout_edge_case_empty_like() -> Result<()> {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
let output = dropout(&input, 0.01, true)?;
let output_data = output.to_vec()?;
let non_zeros = output_data.iter().filter(|&&x| x != 0.0).count();
assert!(
non_zeros >= 3,
"Expected at least 3 non-zero elements with p=0.01, got {}",
non_zeros
);
Ok(())
}
}