use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
#[cfg(feature = "std")]
use std::{collections::HashMap, string::String};
#[cfg(not(feature = "std"))]
use alloc::string::String;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Gaussian Error Linear Unit activation module.
///
/// The `approximate` flag selects which formulation `forward` evaluates:
/// `true` uses the tanh-based approximation
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
pub struct GELU {
    base: ModuleBase,
    approximate: bool,
}

impl GELU {
    /// Creates a GELU module with `approximate` set to `false`.
    pub fn new() -> Self {
        Self::with_approximate(false)
    }

    /// Creates a GELU module; pass `true` to use the tanh approximation.
    pub fn with_approximate(approximate: bool) -> Self {
        let base = ModuleBase::new();
        GELU { base, approximate }
    }
}

impl Default for GELU {
    /// Same as [`GELU::new`].
    fn default() -> Self {
        GELU::new()
    }
}
impl Module for GELU {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
if self.approximate {
let x_cubed = input.pow(3.0)?;
let coeff_tensor = full(input.shape().dims(), 0.044715)?;
let term = x_cubed.mul(&coeff_tensor)?;
let inner = input.add(&term)?;
let scale_tensor = full(input.shape().dims(), (2.0 / std::f32::consts::PI).sqrt())?;
let scaled = inner.mul(&scale_tensor)?;
let tanh_result = scaled.tanh()?;
let ones_tensor = ones(input.shape().dims())?;
let one_plus_tanh = tanh_result.add(&ones_tensor)?;
let half_tensor = full(input.shape().dims(), 0.5)?;
let half_x = input.mul(&half_tensor)?;
half_x.mul(&one_plus_tanh)
} else {
let sqrt_2 = (2.0_f32).sqrt();
let sqrt2_tensor = full(input.shape().dims(), sqrt_2)?;
let x_div_sqrt2 = input.div(&sqrt2_tensor)?;
let erf_approx = x_div_sqrt2.tanh()?; let ones_tensor = ones(input.shape().dims())?;
let one_plus_erf = erf_approx.add(&ones_tensor)?;
let half_tensor = full(input.shape().dims(), 0.5)?;
let half_x = input.mul(&half_tensor)?;
half_x.mul(&one_plus_erf)
}
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
/// Sigmoid Linear Unit activation: `x * sigmoid(x)`.
pub struct SiLU {
    base: ModuleBase,
}

/// `Swish` is another name for [`SiLU`].
pub type Swish = SiLU;

impl SiLU {
    /// Creates a SiLU module; its only state is the shared `ModuleBase`.
    pub fn new() -> Self {
        let base = ModuleBase::new();
        SiLU { base }
    }
}

impl Default for SiLU {
    /// Same as [`SiLU::new`].
    fn default() -> Self {
        SiLU::new()
    }
}
impl Module for SiLU {
    /// Computes `x * sigmoid(x)`, with `sigmoid(x) = 1 / (1 + exp(-x))`.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let denominator = input.neg()?.exp()?.add_scalar(1.0)?;
        let gate = ones(input.shape().dims())?.div(&denominator)?;
        input.mul(&gate)
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn train(&mut self) {
        self.set_training(true);
    }

    fn eval(&mut self) {
        self.set_training(false);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Mish activation: `x * tanh(softplus(x))`.
pub struct Mish {
    base: ModuleBase,
}

impl Mish {
    /// Creates a Mish module; its only state is the shared `ModuleBase`.
    pub fn new() -> Self {
        let base = ModuleBase::new();
        Mish { base }
    }
}

impl Default for Mish {
    /// Same as [`Mish::new`].
    fn default() -> Self {
        Mish::new()
    }
}
impl Module for Mish {
    /// Computes `x * tanh(softplus(x))`, with `softplus(x) = ln(1 + exp(x))`.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let softplus = input.exp()?.add_scalar(1.0)?.log()?;
        let gate = softplus.tanh()?;
        input.mul(&gate)
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn train(&mut self) {
        self.set_training(true);
    }

    fn eval(&mut self) {
        self.set_training(false);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Hard-swish activation: `x * clamp((x + 3) / 6, 0, 1)`.
pub struct Hardswish {
    base: ModuleBase,
}

impl Hardswish {
    /// Creates a Hardswish module; its only state is the shared `ModuleBase`.
    pub fn new() -> Self {
        let base = ModuleBase::new();
        Hardswish { base }
    }
}

impl Default for Hardswish {
    /// Same as [`Hardswish::new`].
    fn default() -> Self {
        Hardswish::new()
    }
}
impl Module for Hardswish {
    /// Computes `x * hard_sigmoid(x)`, where
    /// `hard_sigmoid(x) = max(min((x + 3) / 6, 1), 0)`.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        let upper = ones(shape.dims())?;
        let lower = zeros(shape.dims())?;
        let hard_sigmoid = input
            .add(&full(shape.dims(), 3.0)?)?
            .div(&full(shape.dims(), 6.0)?)?
            .minimum(&upper)?
            .maximum(&lower)?;
        input.mul(&hard_sigmoid)
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn train(&mut self) {
        self.set_training(true);
    }

    fn eval(&mut self) {
        self.set_training(false);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Exponential Linear Unit: `x` for `x > 0`, otherwise `alpha * (exp(x) - 1)`.
pub struct ELU {
    base: ModuleBase,
    alpha: f32,
}

impl ELU {
    /// Creates an ELU module with the given `alpha`, which scales the
    /// negative branch.
    pub fn new(alpha: f32) -> Self {
        let base = ModuleBase::new();
        ELU { base, alpha }
    }
}

impl Default for ELU {
    /// Creates an ELU with `alpha = 1.0`.
    fn default() -> Self {
        ELU::new(1.0)
    }
}
impl Module for ELU {
    /// Selects `x` where `x > 0` and `alpha * (exp(x) - 1)` elsewhere.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        let is_positive = input.gt(&zeros(shape.dims())?)?;
        let negative_part = full(shape.dims(), self.alpha)?
            .mul(&input.exp()?.sub_scalar(1.0)?)?;
        input.where_tensor(&is_positive, &negative_part)
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn train(&mut self) {
        self.set_training(true);
    }

    fn eval(&mut self) {
        self.set_training(false);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Scaled Exponential Linear Unit:
/// `SCALE * (x if x > 0 else ALPHA * (exp(x) - 1))`, with fixed constants.
pub struct SELU {
    base: ModuleBase,
}

impl SELU {
    /// Creates a SELU module; its only state is the shared `ModuleBase`.
    pub fn new() -> Self {
        Self {
            base: ModuleBase::new(),
        }
    }

    // Short literals that round to exactly the same `f32` values as the
    // full-precision constants 1.6732632423543772848... and
    // 1.0507009873554804934.... Longer literals add no accuracy to an `f32`
    // and trigger `clippy::excessive_precision`.
    const ALPHA: f32 = 1.673_263_2;
    const SCALE: f32 = 1.050_701;
}

impl Default for SELU {
    /// Same as [`SELU::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Module for SELU {
    /// Computes `SCALE * elu(x)`, where `elu(x)` is `x` for `x > 0` and
    /// `ALPHA * (exp(x) - 1)` otherwise.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        let is_positive = input.gt(&zeros(shape.dims())?)?;
        let negative_part = full(shape.dims(), Self::ALPHA)?
            .mul(&input.exp()?.sub_scalar(1.0)?)?;
        let elu = input.where_tensor(&is_positive, &negative_part)?;
        elu.mul(&full(shape.dims(), Self::SCALE)?)
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn train(&mut self) {
        self.set_training(true);
    }

    fn eval(&mut self) {
        self.set_training(false);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds a random 2x3 tensor through `module` and checks the output shape
    /// matches the input shape.
    fn check_shape_preserved<M: Module>(module: &M) -> Result<()> {
        let input = randn(&[2, 3])?;
        let output = module.forward(&input)?;
        assert_eq!(output.shape(), input.shape());
        Ok(())
    }

    #[test]
    fn test_gelu_creation() {
        assert!(!GELU::new().approximate);
        assert!(GELU::with_approximate(true).approximate);
    }

    #[test]
    fn test_silu_forward() -> Result<()> {
        check_shape_preserved(&SiLU::new())
    }

    #[test]
    fn test_swish_alias() {
        let (silu, swish) = (SiLU::new(), Swish::new());
        assert_eq!(silu.training(), swish.training());
    }

    #[test]
    fn test_mish_forward() -> Result<()> {
        check_shape_preserved(&Mish::new())
    }

    #[test]
    fn test_hardswish_forward() -> Result<()> {
        check_shape_preserved(&Hardswish::new())
    }

    #[test]
    fn test_elu_parameters() {
        for &alpha in &[1.0_f32, 0.5] {
            assert_eq!(ELU::new(alpha).alpha, alpha);
        }
    }

    #[test]
    fn test_selu_constants() {
        let _selu = SELU::new();
        assert!((SELU::ALPHA - 1.6733).abs() < 0.01);
        assert!((SELU::SCALE - 1.0507).abs() < 0.01);
    }

    #[test]
    fn test_training_mode_toggle() -> Result<()> {
        let mut module = GELU::new();
        assert!(module.training());
        module.eval();
        assert!(!module.training());
        module.train();
        assert!(module.training());
        Ok(())
    }

    #[test]
    fn test_default_implementations() {
        let _gelu: GELU = Default::default();
        let _silu: SiLU = Default::default();
        let _mish: Mish = Default::default();
        let _hardswish: Hardswish = Default::default();
        let _elu: ELU = Default::default();
        let _selu: SELU = Default::default();
    }
}