pub mod advanced;
pub mod basic;
pub mod normalization;
pub use basic::{
Hardshrink,
Hardsigmoid,
Hardtanh,
LeakyReLU,
LogSigmoid,
PReLU,
ReLU,
ReLU6,
Sigmoid,
Softshrink,
Tanh,
Tanhshrink,
Threshold,
ELU,
SELU,
};
pub use normalization::{
log_sum_exp,
stable_log_softmax,
stable_softmax,
LogSoftmax,
Softmax,
Softplus,
Softsign,
};
pub use advanced::{
Hardswish,
Mish,
ReGLU,
SiLU,
SwiGLU,
Swish, GEGLU,
GELU,
GLU,
};
/// Stateless factory that turns an activation name into a boxed module.
pub struct ActivationFactory;

impl ActivationFactory {
    /// Builds the activation registered under `name` using its default
    /// configuration.
    ///
    /// The lookup is case-insensitive because `name` is lowercased first.
    /// Several activations accept more than one spelling (for example
    /// `"leaky_relu"` / `"leakyrelu"`), and `"silu"` and `"swish"` both yield
    /// a `SiLU`. Returns `None` for a name that is not recognised.
    ///
    /// # Panics
    ///
    /// Panics when `"prelu"` is requested and `PReLU::default_params()` fails.
    pub fn create(name: &str) -> Option<Box<dyn crate::Module>> {
        let key = name.to_lowercase();
        let module: Box<dyn crate::Module> = match key.as_str() {
            // Rectifier family.
            "relu" => Box::new(ReLU::new()),
            "leaky_relu" | "leakyrelu" => Box::new(LeakyReLU::default()),
            "relu6" => Box::new(ReLU6::new()),
            "prelu" => Box::new(
                PReLU::default_params().expect("PReLU default params should succeed"),
            ),
            "elu" => Box::new(ELU::default()),
            "selu" => Box::new(SELU::new()),

            // Sigmoid / tanh family.
            "sigmoid" => Box::new(Sigmoid::new()),
            "hardsigmoid" => Box::new(Hardsigmoid::new()),
            "logsigmoid" | "log_sigmoid" => Box::new(LogSigmoid::new()),
            "tanh" => Box::new(Tanh::new()),
            "hardtanh" => Box::new(Hardtanh::default()),
            "tanhshrink" => Box::new(Tanhshrink::new()),

            // Thresholding and shrinkage.
            "threshold" => Box::new(Threshold::default_params()),
            "hardshrink" => Box::new(Hardshrink::default()),
            "softshrink" => Box::new(Softshrink::default()),

            // Normalising activations.
            "softmax" => Box::new(Softmax::new(None)),
            "log_softmax" | "logsoftmax" => Box::new(LogSoftmax::new(None)),
            "softplus" => Box::new(Softplus::default()),
            "softsign" => Box::new(Softsign::new()),

            // Smooth modern activations.
            "gelu" => Box::new(GELU::new()),
            "gelu_exact" => Box::new(GELU::exact()),
            "gelu_approx" | "gelu_approximate" => Box::new(GELU::approximate()),
            "silu" | "swish" => Box::new(SiLU::new()),
            "mish" => Box::new(Mish::new()),
            "hardswish" => Box::new(Hardswish::new()),

            // Gated linear units.
            "glu" => Box::new(GLU::default()),
            "geglu" => Box::new(GEGLU::default()),
            "reglu" => Box::new(ReGLU::default()),
            "swiglu" => Box::new(SwiGLU::default()),

            _ => return None,
        };
        Some(module)
    }

    /// Builds the activation registered under `name`, configuring it from the
    /// string-encoded values in `params`.
    ///
    /// Recognised keys and the fallback used when a key is missing or its
    /// value does not parse:
    ///
    /// | activation                   | key (fallback)                              |
    /// |------------------------------|---------------------------------------------|
    /// | `leaky_relu` / `leakyrelu`   | `negative_slope` (0.01)                     |
    /// | `elu`                        | `alpha` (1.0)                               |
    /// | `softmax`, `log_softmax`     | `dim` as `usize` (`None`)                   |
    /// | `softplus`                   | `beta` (1.0), `threshold` (20.0)            |
    /// | `hardtanh`                   | `min_val` (-1.0), `max_val` (1.0)           |
    /// | `threshold`                  | `threshold` (1.0), `value` (1.0)            |
    /// | `hardshrink`, `softshrink`   | `lambd` (0.5)                               |
    /// | `gelu`                       | `approximate` (false)                       |
    /// | `glu`, `reglu`, `swiglu`     | `dim` as `isize` (-1)                       |
    /// | `geglu`                      | `dim` (-1), `approximate_gelu` (false)      |
    ///
    /// Malformed values are not reported; the fallback is used instead. Any
    /// other name is forwarded to [`ActivationFactory::create`], in which case
    /// `params` is ignored.
    pub fn create_with_params(
        name: &str,
        params: &std::collections::HashMap<String, String>,
    ) -> Option<Box<dyn crate::Module>> {
        /// Looks up `key` and parses it as `T`; `None` if absent or malformed.
        fn parsed<T: std::str::FromStr>(
            params: &std::collections::HashMap<String, String>,
            key: &str,
        ) -> Option<T> {
            params.get(key)?.parse().ok()
        }

        let float = |key: &str, fallback: f32| parsed::<f32>(params, key).unwrap_or(fallback);
        let flag = |key: &str| parsed::<bool>(params, key).unwrap_or(false);
        let gate_dim = || parsed::<isize>(params, "dim").unwrap_or(-1);
        let softmax_dim = || parsed::<usize>(params, "dim");

        let key = name.to_lowercase();
        let module: Box<dyn crate::Module> = match key.as_str() {
            "leaky_relu" | "leakyrelu" => Box::new(LeakyReLU::new(float("negative_slope", 0.01))),
            "elu" => Box::new(ELU::new(float("alpha", 1.0))),
            "softmax" => Box::new(Softmax::new(softmax_dim())),
            "log_softmax" | "logsoftmax" => Box::new(LogSoftmax::new(softmax_dim())),
            "softplus" => {
                let beta = float("beta", 1.0);
                let threshold = float("threshold", 20.0);
                Box::new(Softplus::new(beta, threshold))
            }
            "hardtanh" => {
                let lower = float("min_val", -1.0);
                let upper = float("max_val", 1.0);
                Box::new(Hardtanh::new(lower, upper))
            }
            "threshold" => {
                let cutoff = float("threshold", 1.0);
                let replacement = float("value", 1.0);
                Box::new(Threshold::new(cutoff, replacement))
            }
            "hardshrink" => Box::new(Hardshrink::new(float("lambd", 0.5))),
            "softshrink" => Box::new(Softshrink::new(float("lambd", 0.5))),
            "gelu" => Box::new(GELU::with_approximate(flag("approximate"))),
            "glu" => Box::new(GLU::new(gate_dim())),
            "geglu" => {
                let dim = gate_dim();
                let approximate = flag("approximate_gelu");
                Box::new(GEGLU::new(dim, approximate))
            }
            "reglu" => Box::new(ReGLU::new(gate_dim())),
            "swiglu" => Box::new(SwiGLU::new(gate_dim())),
            // Not parameterisable here: fall back to the default construction.
            _ => return Self::create(name),
        };
        Some(module)
    }

    /// Lists the canonical names accepted by [`ActivationFactory::create`].
    ///
    /// Alternative spellings such as `"leakyrelu"` or `"gelu_approximate"` are
    /// accepted by `create` but are not part of this list.
    pub fn available_activations() -> Vec<&'static str> {
        const NAMES: &[&str] = &[
            "relu",
            "leaky_relu",
            "relu6",
            "prelu",
            "elu",
            "selu",
            "sigmoid",
            "hardsigmoid",
            "logsigmoid",
            "tanh",
            "hardtanh",
            "tanhshrink",
            "threshold",
            "hardshrink",
            "softshrink",
            "softmax",
            "log_softmax",
            "softplus",
            "softsign",
            "gelu",
            "gelu_exact",
            "gelu_approx",
            "silu",
            "swish",
            "mish",
            "hardswish",
            "glu",
            "geglu",
            "reglu",
            "swiglu",
        ];
        NAMES.to_vec()
    }
}
/// Fluent builder that accumulates string-encoded parameters for a named
/// activation and hands them to [`ActivationFactory::create_with_params`].
pub struct ActivationBuilder {
    /// Activation name exactly as given to [`ActivationBuilder::new`].
    name: String,
    /// Parameter values keyed by parameter name, stored in their text form.
    params: std::collections::HashMap<String, String>,
}

impl ActivationBuilder {
    /// Starts a builder for the activation called `name` with no parameters.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_owned(),
            params: Default::default(),
        }
    }

    /// Records `value` (in its `to_string` form) under `key`, replacing any
    /// earlier value for the same key.
    fn with_param(mut self, key: &str, value: impl ToString) -> Self {
        self.params.insert(key.to_owned(), value.to_string());
        self
    }

    /// Sets the `"dim"` parameter from a signed index.
    ///
    /// Shares its key with [`ActivationBuilder::dim_usize`]; the later call wins.
    pub fn dim(self, dim: isize) -> Self {
        self.with_param("dim", dim)
    }

    /// Sets the `"dim"` parameter from an unsigned index.
    pub fn dim_usize(self, dim: usize) -> Self {
        self.with_param("dim", dim)
    }

    /// Sets the `"alpha"` parameter.
    pub fn alpha(self, alpha: f32) -> Self {
        self.with_param("alpha", alpha)
    }

    /// Sets the `"beta"` parameter.
    pub fn beta(self, beta: f32) -> Self {
        self.with_param("beta", beta)
    }

    /// Sets the `"threshold"` parameter.
    pub fn threshold(self, threshold: f32) -> Self {
        self.with_param("threshold", threshold)
    }

    /// Sets the shrinkage parameter, stored under the key `"lambd"`.
    pub fn lambda(self, lambd: f32) -> Self {
        self.with_param("lambd", lambd)
    }

    /// Sets the `"negative_slope"` parameter.
    pub fn negative_slope(self, slope: f32) -> Self {
        self.with_param("negative_slope", slope)
    }

    /// Sets the `"approximate"` flag.
    pub fn approximate(self, approx: bool) -> Self {
        self.with_param("approximate", approx)
    }

    /// Sets the `"approximate_gelu"` flag.
    pub fn approximate_gelu(self, approx: bool) -> Self {
        self.with_param("approximate_gelu", approx)
    }

    /// Sets the `"min_val"` parameter.
    pub fn min_val(self, min_val: f32) -> Self {
        self.with_param("min_val", min_val)
    }

    /// Sets the `"max_val"` parameter.
    pub fn max_val(self, max_val: f32) -> Self {
        self.with_param("max_val", max_val)
    }

    /// Sets the `"value"` parameter.
    pub fn value(self, value: f32) -> Self {
        self.with_param("value", value)
    }

    /// Consumes the builder and constructs the activation.
    ///
    /// Returns whatever [`ActivationFactory::create_with_params`] returns for
    /// the stored name and parameters, i.e. `None` for an unknown name.
    pub fn build(self) -> Option<Box<dyn crate::Module>> {
        let Self { name, params } = self;
        ActivationFactory::create_with_params(&name, &params)
    }
}
/// Ready-made activation choices keyed by architecture or by the activation
/// they are meant to replace.
pub mod presets {
    use super::*;

    /// Picks an activation for the architecture family named `arch`.
    ///
    /// The name is matched case-insensitively. Unrecognised names get a
    /// `ReLU`.
    pub fn for_architecture(arch: &str) -> Box<dyn crate::Module> {
        let family = arch.to_lowercase();
        match family.as_str() {
            "transformer" | "attention" | "vision_transformer" | "vit" | "bert" | "gpt" => {
                Box::new(GELU::new())
            }
            "cnn" | "convnet" | "resnet" | "autoencoder" => Box::new(ReLU::new()),
            "mobile" | "mobilenet" => Box::new(Hardswish::new()),
            "efficientnet" => Box::new(SiLU::new()),
            "lstm" | "gru" | "rnn" => Box::new(Tanh::new()),
            "gan" => Box::new(LeakyReLU::new(0.2)),
            // Anything else falls back to the general-purpose default.
            _ => Box::new(ReLU::new()),
        }
    }

    /// Maps a legacy activation name to the replacement this module suggests.
    ///
    /// Names without a dedicated replacement are built through
    /// [`ActivationFactory::create`]; if the factory does not know the name
    /// either, a `ReLU` is returned.
    pub fn modern_replacement(legacy: &str) -> Box<dyn crate::Module> {
        let key = legacy.to_lowercase();
        match key.as_str() {
            "relu" => Box::new(GELU::new()),
            "sigmoid" | "swish" => Box::new(SiLU::new()),
            "tanh" => Box::new(Mish::new()),
            _ => match ActivationFactory::create(legacy) {
                Some(module) => module,
                None => Box::new(ReLU::new()),
            },
        }
    }

    /// Maps a standard activation name to the variant this module suggests
    /// for mobile deployments.
    ///
    /// Names without a dedicated variant are built through
    /// [`ActivationFactory::create`]; if the factory does not know the name
    /// either, a `ReLU6` is returned.
    pub fn mobile_optimized(standard: &str) -> Box<dyn crate::Module> {
        let key = standard.to_lowercase();
        match key.as_str() {
            "relu" => Box::new(ReLU6::new()),
            "swish" | "silu" => Box::new(Hardswish::new()),
            "sigmoid" => Box::new(Hardsigmoid::new()),
            "gelu" => Box::new(GELU::approximate()),
            _ => match ActivationFactory::create(standard) {
                Some(module) => module,
                None => Box::new(ReLU6::new()),
            },
        }
    }
}
/// Wall-clock timing helpers for activation modules.
pub mod benchmarks {
    use super::*;
    use std::time::Instant;

    /// Computes the mean and population standard deviation of `samples`.
    ///
    /// An empty slice yields `(0.0, 0.0)` rather than the `NaN` that a
    /// division by a zero count would produce.
    fn mean_and_std_dev(samples: &[f64]) -> (f64, f64) {
        if samples.is_empty() {
            return (0.0, 0.0);
        }
        let count = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / count;
        let variance = samples
            .iter()
            .map(|&sample| (sample - mean).powi(2))
            .sum::<f64>()
            / count;
        (mean, variance.sqrt())
    }

    /// Times `iterations` forward passes of `activation` over `input`.
    ///
    /// Returns `(mean, std_dev)` of the per-call wall-clock time in seconds,
    /// where `std_dev` is the population standard deviation. With
    /// `iterations == 0` nothing is executed and `(0.0, 0.0)` is returned.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `activation.forward`; timings
    /// gathered before the failure are discarded.
    pub fn benchmark_activation(
        activation: &dyn crate::Module,
        input: &torsh_tensor::Tensor,
        iterations: usize,
    ) -> Result<(f64, f64), torsh_core::error::TorshError> {
        let mut times = Vec::with_capacity(iterations);
        for _ in 0..iterations {
            let start = Instant::now();
            // Bound to a name so the output is dropped only after the clock
            // has been read.
            let _output = activation.forward(input)?;
            let elapsed = start.elapsed().as_secs_f64();
            times.push(elapsed);
        }
        Ok(mean_and_std_dev(&times))
    }

    /// Benchmarks every `(name, activation)` pair against the same `input`.
    ///
    /// Each result is `(name, mean, std_dev)` as produced by
    /// [`benchmark_activation`]. Activations whose benchmark returns an error
    /// are omitted from the output without any report.
    pub fn compare_activations(
        activations: &[(&str, Box<dyn crate::Module>)],
        input: &torsh_tensor::Tensor,
        iterations: usize,
    ) -> Vec<(String, f64, f64)> {
        activations
            .iter()
            .filter_map(|(name, activation)| {
                let (mean, std_dev) =
                    benchmark_activation(activation.as_ref(), input, iterations).ok()?;
                Some((name.to_string(), mean, std_dev))
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::*;

    /// Known names resolve to a module and unknown names resolve to `None`.
    #[test]
    fn test_activation_factory() {
        assert!(ActivationFactory::create("relu").is_some());
        assert!(ActivationFactory::create("gelu").is_some());
        assert!(ActivationFactory::create("swiglu").is_some());
        assert!(ActivationFactory::create("nonexistent").is_none());
    }

    /// A parameter map is accepted for a parameterisable activation.
    #[test]
    fn test_activation_factory_with_params() {
        let mut params = std::collections::HashMap::new();
        params.insert("negative_slope".to_string(), "0.2".to_string());
        let activation = ActivationFactory::create_with_params("leaky_relu", &params);
        assert!(activation.is_some());
    }

    /// The builder produces modules for both flag- and dim-style parameters.
    #[test]
    fn test_activation_builder() {
        let activation = ActivationBuilder::new("gelu").approximate(true).build();
        assert!(activation.is_some());
        let glu = ActivationBuilder::new("swiglu").dim(-1).build();
        assert!(glu.is_some());
    }

    /// The advertised name list is non-empty and contains common entries.
    #[test]
    fn test_available_activations() {
        let activations = ActivationFactory::available_activations();
        assert!(!activations.is_empty());
        assert!(activations.contains(&"relu"));
        assert!(activations.contains(&"gelu"));
        assert!(activations.contains(&"swiglu"));
    }

    /// Every architecture preset yields a module that can run a forward pass.
    ///
    /// The concrete type behind `Box<dyn Module>` cannot be inspected with
    /// `std::mem::discriminant` (it requires a sized type), so the presets
    /// are exercised through the trait object instead.
    #[test]
    fn test_presets() {
        let input = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![4],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");
        for arch in ["transformer", "mobile", "cnn"] {
            let activation = presets::for_architecture(arch);
            assert!(
                activation.forward(&input).is_ok(),
                "preset for `{}` failed its forward pass",
                arch
            );
        }
    }

    /// Freshly constructed activations report training mode.
    #[test]
    fn test_module_integration() {
        let relu = ReLU::new();
        let gelu = GELU::new();
        let softmax = Softmax::new(None);
        let swiglu = SwiGLU::new(-1);
        assert!(relu.training());
        assert!(gelu.training());
        assert!(softmax.training());
        assert!(swiglu.training());
    }

    /// The re-exported constructors keep their historical signatures.
    #[test]
    fn test_backward_compatibility() {
        let _relu = ReLU::new();
        let _leaky_relu = LeakyReLU::new(0.01);
        let _sigmoid = Sigmoid::new();
        let _tanh = Tanh::new();
        let _gelu = GELU::new();
        let _silu = SiLU::new();
        let _swish = Swish::new();
        let _softmax = Softmax::new(None);
        let _glu = GLU::new(-1);
        let _geglu = GEGLU::new(-1, false);
        let _reglu = ReGLU::new(-1);
        let _swiglu = SwiGLU::new(-1);
    }

    /// Benchmarking a simple activation succeeds and yields non-negative stats.
    #[test]
    fn test_performance_benchmarking() {
        let input = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![4],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("operation should succeed");
        let relu = ReLU::new();
        let result = benchmarks::benchmark_activation(&relu, &input, 5);
        assert!(result.is_ok());
        let (mean_time, std_dev) = result.expect("operation should succeed");
        assert!(mean_time >= 0.0);
        assert!(std_dev >= 0.0);
    }
}