// Submodules from which the activation types below are re-exported. The type
// names listed on each one are exactly those this file re-exports from it.
/// Source of `LeakyReLU`, `PReLU`, `ReLU`, `ReLU6`, `Sigmoid`, and `Tanh`.
pub mod basic;
/// Source of `ReGLU`, `SwiGLU`, `GEGLU`, and `GLU`.
pub mod gated;
/// Source of `Hardswish`, `Mish`, `SiLU`, `Swish`, `ELU`, `GELU`, and `SELU`.
pub mod modern;
/// Source of `Hardsigmoid`, `Softplus`, and `Softsign`.
pub mod smooth;
/// Source of `LogSigmoid`, `LogSoftmax`, and `Softmax`.
pub mod softmax;
/// Source of `Hardshrink`, `Hardtanh`, `Softshrink`, `Tanhshrink`, and `Threshold`.
pub mod threshold;
// Flat re-exports: each type named here can be referenced directly from this
// module without spelling out the submodule it lives in.
pub use basic::{LeakyReLU, PReLU, ReLU, ReLU6, Sigmoid, Tanh};
pub use modern::{
Hardswish,
Mish,
SiLU,
Swish, ELU,
GELU,
SELU,
};
pub use softmax::{LogSigmoid, LogSoftmax, Softmax};
pub use threshold::{Hardshrink, Hardtanh, Softshrink, Tanhshrink, Threshold};
pub use smooth::{Hardsigmoid, Softplus, Softsign};
pub use gated::{ReGLU, SwiGLU, GEGLU, GLU};
pub mod prelude {
pub use super::basic::{LeakyReLU, ReLU, Sigmoid, Tanh};
pub use super::gated::{SwiGLU, GLU};
pub use super::modern::{Mish, SiLU, GELU};
pub use super::smooth::Softplus;
pub use super::softmax::{LogSoftmax, Softmax};
pub use super::threshold::Hardtanh;
}
/// Named groupings of activation types.
///
/// Each submodule only re-exports existing types from
/// `crate::layers::activation`; no new items are defined. The same type may
/// appear in more than one grouping.
pub mod collections {
    /// `ReLU`, `Sigmoid`, `Tanh`, `Softplus`, `Softsign`, and `Hardtanh`.
    pub mod classical {
        pub use crate::layers::activation::{
            basic::{ReLU, Sigmoid, Tanh},
            smooth::{Softplus, Softsign},
            threshold::Hardtanh,
        };
    }

    /// `SwiGLU`, `GEGLU`, `GLU`, `Mish`, `SiLU`, `ELU`, `GELU`, and `SELU`.
    pub mod modern {
        pub use crate::layers::activation::{
            gated::{SwiGLU, GEGLU, GLU},
            modern::{Mish, SiLU, ELU, GELU, SELU},
        };
    }

    /// `ReLU`, `ReLU6`, `Hardswish`, `Hardsigmoid`, and `Hardtanh`.
    pub mod mobile {
        pub use crate::layers::activation::{
            basic::{ReLU, ReLU6},
            modern::Hardswish,
            smooth::Hardsigmoid,
            threshold::Hardtanh,
        };
    }

    /// `ReGLU`, `SwiGLU`, `GEGLU`, `GLU`, `SiLU`, and `GELU`.
    pub mod transformer {
        pub use crate::layers::activation::{
            gated::{ReGLU, SwiGLU, GEGLU, GLU},
            modern::{SiLU, GELU},
        };
    }

    /// `Sigmoid`, `LogSigmoid`, `LogSoftmax`, and `Softmax`.
    pub mod classification {
        pub use crate::layers::activation::{
            basic::Sigmoid,
            softmax::{LogSigmoid, LogSoftmax, Softmax},
        };
    }

    /// `ReLU`, `ReGLU`, `Hardshrink`, `Softshrink`, and `Threshold`.
    pub mod sparse {
        pub use crate::layers::activation::{
            basic::ReLU,
            gated::ReGLU,
            threshold::{Hardshrink, Softshrink, Threshold},
        };
    }
}
/// Free-function shortcuts for building activation layers.
///
/// Each function is a thin wrapper that forwards to the matching type's
/// `new` (or `default`) with either fixed or caller-supplied arguments.
pub mod factory {
    use super::*;
    use torsh_core::error::Result;

    // ---- No-argument constructors -------------------------------------

    /// Returns `ReLU::new()`.
    pub fn relu() -> ReLU {
        ReLU::new()
    }

    /// Returns `ReLU6::new()`.
    pub fn relu6() -> ReLU6 {
        ReLU6::new()
    }

    /// Returns `LeakyReLU::default()`.
    pub fn leaky_relu() -> LeakyReLU {
        LeakyReLU::default()
    }

    /// Returns `GELU::new()`.
    pub fn gelu() -> GELU {
        GELU::new()
    }

    /// Returns `SiLU::new()`.
    pub fn silu() -> SiLU {
        SiLU::new()
    }

    /// Returns `SELU::new()`.
    pub fn selu() -> SELU {
        SELU::new()
    }

    /// Returns `Mish::new()`.
    pub fn mish() -> Mish {
        Mish::new()
    }

    /// Returns `Hardswish::new()`.
    pub fn hardswish() -> Hardswish {
        Hardswish::new()
    }

    // ---- Constructors with a fixed argument ---------------------------
    // NOTE(review): the fixed arguments below look like dimension selectors
    // (`Some(1)` / `-1`); confirm against the respective `new` signatures.

    /// Returns `Softmax::new(Some(1))`.
    pub fn softmax() -> Softmax {
        Softmax::new(Some(1))
    }

    /// Returns `LogSoftmax::new(Some(1))`.
    pub fn log_softmax() -> LogSoftmax {
        LogSoftmax::new(Some(1))
    }

    /// Returns `SwiGLU::new(-1)`.
    pub fn swiglu() -> SwiGLU {
        SwiGLU::new(-1)
    }

    /// Returns `GEGLU::new(-1)`.
    pub fn geglu() -> GEGLU {
        GEGLU::new(-1)
    }

    // ---- Caller-configured constructors -------------------------------

    /// Returns `LeakyReLU::new(slope)`.
    pub fn leaky_relu_with_slope(slope: f64) -> LeakyReLU {
        LeakyReLU::new(slope)
    }

    /// Forwards `num_parameters` to `PReLU::new`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `PReLU::new` reports, unchanged.
    pub fn prelu(num_parameters: usize) -> Result<PReLU> {
        PReLU::new(num_parameters)
    }
}
#[cfg(test)]
mod integration_tests {
    use super::*;
    use crate::Module;
    use torsh_tensor::creation::*;
    use torsh_tensor::Tensor;

    /// Smoke test: builds each activation listed below and calls `forward` once.
    ///
    /// Only success/failure is checked (errors propagate through `?`); the
    /// produced tensors are discarded without inspection.
    #[test]
    fn test_all_activations_implement_module() -> torsh_core::error::Result<()> {
        let x = randn(&[2, 4])?;

        // Types re-exported from `basic`.
        let _ = ReLU::new().forward(&x)?;
        let _ = Sigmoid::new().forward(&x)?;
        let _ = Tanh::new().forward(&x)?;
        let _ = LeakyReLU::new(0.01).forward(&x)?;
        let _ = ReLU6::new().forward(&x)?;

        // Types re-exported from `modern`.
        let _ = GELU::new().forward(&x)?;
        let _ = SiLU::new().forward(&x)?;
        let _ = Mish::new().forward(&x)?;
        let _ = Hardswish::new().forward(&x)?;
        let _ = ELU::new(1.0).forward(&x)?;
        let _ = SELU::new().forward(&x)?;

        // Types re-exported from `softmax`.
        let _ = Softmax::new(Some(1)).forward(&x)?;
        let _ = LogSoftmax::new(Some(1)).forward(&x)?;
        let _ = LogSigmoid::new().forward(&x)?;

        // Types re-exported from `threshold`.
        let _ = Hardshrink::new(0.5).forward(&x)?;
        let _ = Softshrink::new(0.5).forward(&x)?;
        let _ = Hardtanh::new(-1.0, 1.0).forward(&x)?;
        let _ = Threshold::new(0.0, 0.0).forward(&x)?;
        let _ = Tanhshrink::new().forward(&x)?;

        // Types re-exported from `smooth`.
        let _ = Softplus::new(1.0, 20.0).forward(&x)?;
        let _ = Softsign::new().forward(&x)?;
        let _ = Hardsigmoid::new().forward(&x)?;

        // Types re-exported from `gated`; these are fed a separate [2, 8] input.
        let gated_x = randn(&[2, 8])?;
        let _ = GLU::new(-1).forward(&gated_x)?;
        let _ = GEGLU::new(-1).forward(&gated_x)?;
        let _ = ReGLU::new(-1).forward(&gated_x)?;
        let _ = SwiGLU::new(-1).forward(&gated_x)?;

        Ok(())
    }

    /// Checks that the layers returned by the `factory` helpers can run `forward`.
    ///
    /// `factory::prelu` and `factory::leaky_relu_with_slope` are not exercised here.
    #[test]
    fn test_factory_functions() -> torsh_core::error::Result<()> {
        let x = randn(&[2, 4])?;

        let _ = factory::relu().forward(&x)?;
        let _ = factory::gelu().forward(&x)?;
        let _ = factory::silu().forward(&x)?;
        let _ = factory::softmax().forward(&x)?;
        let _ = factory::log_softmax().forward(&x)?;
        let _ = factory::leaky_relu().forward(&x)?;
        let _ = factory::selu().forward(&x)?;
        let _ = factory::mish().forward(&x)?;
        let _ = factory::hardswish().forward(&x)?;
        let _ = factory::relu6().forward(&x)?;

        // The gated factories get their own [2, 8] input.
        let gated_x = randn(&[2, 8])?;
        let _ = factory::swiglu().forward(&gated_x)?;
        let _ = factory::geglu().forward(&gated_x)?;

        Ok(())
    }

    /// Checks that a glob import of `prelude` is enough to name and run a
    /// selection of activations.
    #[test]
    fn test_prelude_imports() -> torsh_core::error::Result<()> {
        use super::prelude::*;

        let x = randn(&[2, 4])?;

        // Outputs are collected only so that every call is evaluated in order.
        let _outputs = [
            ReLU::new().forward(&x)?,
            Sigmoid::new().forward(&x)?,
            Tanh::new().forward(&x)?,
            GELU::new().forward(&x)?,
            SiLU::new().forward(&x)?,
            Mish::new().forward(&x)?,
            Softmax::new(Some(1)).forward(&x)?,
            LogSoftmax::new(Some(1)).forward(&x)?,
        ];

        Ok(())
    }

    /// Checks that types are reachable through the `collections::*` paths and
    /// that the resulting layers can run `forward`.
    #[test]
    fn test_collections() -> torsh_core::error::Result<()> {
        let x = randn(&[2, 4])?;

        // collections::classical
        let _ = collections::classical::ReLU::new().forward(&x)?;
        let _ = collections::classical::Sigmoid::new().forward(&x)?;

        // collections::modern
        let _ = collections::modern::GELU::new().forward(&x)?;
        let _ = collections::modern::SiLU::new().forward(&x)?;

        // collections::mobile
        let _ = collections::mobile::ReLU::new().forward(&x)?;
        let _ = collections::mobile::ReLU6::new().forward(&x)?;

        // collections::classification
        let _ = collections::classification::Softmax::new(Some(1)).forward(&x)?;
        let _ = collections::classification::LogSoftmax::new(Some(1)).forward(&x)?;

        Ok(())
    }

    /// Runs each activation through a boxed closure with one shared signature,
    /// using the names re-exported at this module's top level.
    #[test]
    fn test_backward_compatibility() -> torsh_core::error::Result<()> {
        /// Uniform "tensor in, tensor out" closure type used by both lists below.
        type ForwardFn = Box<dyn Fn(&Tensor) -> torsh_core::error::Result<Tensor>>;

        let plain_x = randn(&[2, 4])?;
        let gated_x = randn(&[2, 8])?;

        let plain: Vec<ForwardFn> = vec![
            Box::new(|x| ReLU::new().forward(x)),
            Box::new(|x| Sigmoid::new().forward(x)),
            Box::new(|x| Tanh::new().forward(x)),
            Box::new(|x| GELU::new().forward(x)),
            Box::new(|x| SiLU::new().forward(x)),
            Box::new(|x| Mish::new().forward(x)),
            Box::new(|x| LeakyReLU::new(0.01).forward(x)),
            Box::new(|x| ReLU6::new().forward(x)),
            Box::new(|x| ELU::new(1.0).forward(x)),
            Box::new(|x| SELU::new().forward(x)),
            Box::new(|x| Hardswish::new().forward(x)),
            Box::new(|x| Softmax::new(Some(1)).forward(x)),
            Box::new(|x| LogSoftmax::new(Some(1)).forward(x)),
            Box::new(|x| LogSigmoid::new().forward(x)),
            Box::new(|x| Hardshrink::new(0.5).forward(x)),
            Box::new(|x| Softshrink::new(0.5).forward(x)),
            Box::new(|x| Hardtanh::new(-1.0, 1.0).forward(x)),
            Box::new(|x| Threshold::new(0.0, 0.0).forward(x)),
            Box::new(|x| Tanhshrink::new().forward(x)),
            Box::new(|x| Softplus::new(1.0, 20.0).forward(x)),
            Box::new(|x| Softsign::new().forward(x)),
            Box::new(|x| Hardsigmoid::new().forward(x)),
        ];
        // Stop at the first failing activation; successful outputs are dropped.
        plain.iter().try_for_each(|run| run(&plain_x).map(drop))?;

        let gated: Vec<ForwardFn> = vec![
            Box::new(|x| GLU::new(-1).forward(x)),
            Box::new(|x| GEGLU::new(-1).forward(x)),
            Box::new(|x| ReGLU::new(-1).forward(x)),
            Box::new(|x| SwiGLU::new(-1).forward(x)),
        ];
        gated.iter().try_for_each(|run| run(&gated_x).map(drop))?;

        Ok(())
    }
}