use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
use super::basic::{ReLU, Sigmoid};
use super::normalization::Softplus;
#[cfg(feature = "std")]
use std::{collections::HashMap, string::String};
#[cfg(not(feature = "std"))]
use alloc::string::String;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Gaussian Error Linear Unit activation.
///
/// Computes `GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))` in exact mode, or the
/// tanh-based approximation when `approximate` is set. The formulas are applied
/// in the [`Module::forward`] implementation.
pub struct GELU {
    base: ModuleBase,
    /// `true` selects the tanh approximation, `false` the exact `erf` form.
    approximate: bool,
}

impl GELU {
    /// Creates a GELU using the exact `erf` formulation.
    pub fn new() -> Self {
        Self::with_approximate(false)
    }

    /// Creates a GELU; `approximate == true` selects the tanh approximation.
    pub fn with_approximate(approximate: bool) -> Self {
        let base = ModuleBase::new();
        Self { base, approximate }
    }

    /// Shorthand for `GELU::with_approximate(true)`.
    pub fn approximate() -> Self {
        Self::with_approximate(true)
    }

    /// Shorthand for `GELU::with_approximate(false)` (same as [`GELU::new`]).
    pub fn exact() -> Self {
        Self::with_approximate(false)
    }
}

impl Default for GELU {
    /// Equivalent to [`GELU::new`], i.e. the exact formulation.
    fn default() -> Self {
        Self::new()
    }
}
impl Module for GELU {
    /// Applies GELU element-wise.
    ///
    /// * exact:       `0.5 * x * (1 + erf(x / sqrt(2)))`
    /// * approximate: `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        use core::f32::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI};

        // sqrt(2/pi) == (2/sqrt(pi)) * (1/sqrt(2)). Building it from `core`
        // constants avoids `f32::sqrt` and `std::f32`, neither of which is
        // available in `no_std` builds (which this file supports via `alloc`).
        const SQRT_2_OVER_PI: f32 = FRAC_2_SQRT_PI * FRAC_1_SQRT_2;
        // Cubic coefficient of the tanh approximation.
        const TANH_CUBIC_COEFF: f32 = 0.044715;

        let gate = if self.approximate {
            let x_cubed_scaled = input.pow(3.0)?.scalar_mul(TANH_CUBIC_COEFF)?;
            let inner = input.add(&x_cubed_scaled)?;
            inner.scalar_mul(SQRT_2_OVER_PI)?.tanh()?
        } else {
            input.scalar_mul(FRAC_1_SQRT_2)?.erf()?
        };
        let one_plus_gate = gate.add(&ones(input.shape().dims())?)?;
        let half_x = input.scalar_mul(0.5)?;
        half_x.mul(&one_plus_gate)
    }

    /// GELU registers no parameters of its own; returns whatever `base` holds.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn train(&mut self) {
        self.base.set_training(true);
    }

    fn eval(&mut self) {
        self.base.set_training(false);
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}

// `core::fmt` is used instead of `std::fmt` so the impl also compiles without `std`.
impl core::fmt::Debug for GELU {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("GELU")
            .field("approximate", &self.approximate)
            .finish()
    }
}
/// Sigmoid Linear Unit activation: `SiLU(x) = x * sigmoid(x)`.
pub struct SiLU {
    base: ModuleBase,
}

/// Alias for [`SiLU`]; `x * sigmoid(x)` is also known as Swish (with beta = 1).
pub type Swish = SiLU;

impl SiLU {
    /// Creates a parameter-free SiLU module.
    pub fn new() -> Self {
        let base = ModuleBase::new();
        Self { base }
    }
}

impl Default for SiLU {
    /// Equivalent to [`SiLU::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Module for SiLU {
    /// Computes `x * sigmoid(x)` element-wise.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let gate = input.sigmoid()?;
        input.mul(&gate)
    }

    /// SiLU registers no parameters of its own; returns whatever `base` holds.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn train(&mut self) {
        self.base.set_training(true);
    }

    fn eval(&mut self) {
        self.base.set_training(false);
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}

// `core::fmt` is used instead of `std::fmt` so the impl also compiles without `std`.
impl core::fmt::Debug for SiLU {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("SiLU").finish()
    }
}
/// Mish activation: `Mish(x) = x * tanh(ln(1 + e^x))`.
pub struct Mish {
    base: ModuleBase,
}

impl Mish {
    /// Creates a parameter-free Mish module.
    pub fn new() -> Self {
        let base = ModuleBase::new();
        Self { base }
    }
}

impl Default for Mish {
    /// Equivalent to [`Mish::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Module for Mish {
    /// Computes `x * tanh(softplus(x))` element-wise, with
    /// `softplus(x) = ln(1 + e^x)`.
    ///
    /// NOTE(review): softplus is evaluated naively. For large positive `x`,
    /// `e^x` overflows to `+inf` in f32; the forward value still saturates to
    /// `x` because `tanh(+inf) == 1`, but a numerically stable softplus would be
    /// preferable if gradients flow through here.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let exp_x = input.exp()?;
        let softplus = ones(input.shape().dims())?.add(&exp_x)?.log()?;
        input.mul(&softplus.tanh()?)
    }

    /// Mish registers no parameters of its own; returns whatever `base` holds.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn train(&mut self) {
        self.base.set_training(true);
    }

    fn eval(&mut self) {
        self.base.set_training(false);
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}

// `core::fmt` is used instead of `std::fmt` so the impl also compiles without `std`.
impl core::fmt::Debug for Mish {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("Mish").finish()
    }
}
/// Hard Swish activation: `Hardswish(x) = x * clamp((x + 3) / 6, 0, 1)`.
pub struct Hardswish {
    base: ModuleBase,
}

impl Hardswish {
    /// Creates a parameter-free Hardswish module.
    pub fn new() -> Self {
        let base = ModuleBase::new();
        Self { base }
    }
}

impl Default for Hardswish {
    /// Equivalent to [`Hardswish::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Module for Hardswish {
    /// Computes `x * clamp((x + 3) / 6, 0, 1)` element-wise.
    ///
    /// The clamped factor is the "hard sigmoid"; the upper bound is applied
    /// before the lower bound, which is equivalent for a `[0, 1]` clamp.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        // Constant tensors matching the input shape; the previously allocated
        // (and never used) `six` tensor has been dropped.
        let three = full(input.shape().dims(), 3.0)?;
        let zero = zeros(input.shape().dims())?;
        let one = ones(input.shape().dims())?;

        let hardsigmoid = input
            .add(&three)?
            .scalar_mul(1.0 / 6.0)?
            .minimum(&one)?
            .maximum(&zero)?;
        input.mul(&hardsigmoid)
    }

    /// Hardswish registers no parameters of its own; returns whatever `base` holds.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn train(&mut self) {
        self.base.set_training(true);
    }

    fn eval(&mut self) {
        self.base.set_training(false);
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}

// `core::fmt` is used instead of `std::fmt` so the impl also compiles without `std`.
impl core::fmt::Debug for Hardswish {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("Hardswish").finish()
    }
}
/// Gated Linear Unit.
///
/// Splits the input into two equal halves `(value, gate)` along `dim` and
/// returns `value * sigmoid(gate)`.
pub struct GLU {
    base: ModuleBase,
    /// Dimension to split; negative values count back from the last dimension.
    dim: isize,
}

impl GLU {
    /// Creates a GLU that splits along `dim` (negative values index from the end).
    pub fn new(dim: isize) -> Self {
        let base = ModuleBase::new();
        Self { base, dim }
    }

    /// GLU splitting along the last dimension (`dim = -1`).
    pub fn last_dim() -> Self {
        Self::new(-1)
    }

    /// GLU splitting along dimension `1`.
    pub fn channel_dim() -> Self {
        Self::new(1)
    }
}

impl Default for GLU {
    /// Equivalent to [`GLU::last_dim`].
    fn default() -> Self {
        Self::last_dim()
    }
}
impl Module for GLU {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let input_shape = input.shape().dims();
let split_dim = if self.dim < 0 {
(input_shape.len() as isize + self.dim) as usize
} else {
self.dim as usize
};
if split_dim >= input_shape.len() {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Dimension {} out of range for tensor with {} dimensions",
self.dim,
input_shape.len()
)));
}
let split_size = input_shape[split_dim];
if split_size % 2 != 0 {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Input dimension {} must be even for GLU, got {}",
split_size, split_size
)));
}
let chunks = input.chunk(2, split_dim as i32)?;
if chunks.len() != 2 {
return Err(torsh_core::error::TorshError::InvalidArgument(
"Failed to split input into two chunks".to_string(),
));
}
let value = &chunks[0];
let gate = &chunks[1];
let sigmoid_gate = gate.sigmoid()?;
value.mul(&sigmoid_gate)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
impl std::fmt::Debug for GLU {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GLU").field("dim", &self.dim).finish()
}
}
/// GELU-gated linear unit.
///
/// Splits the input into two equal halves `(value, gate)` along `dim` and
/// returns `value * GELU(gate)`.
pub struct GEGLU {
    base: ModuleBase,
    /// Dimension to split; negative values count back from the last dimension.
    dim: isize,
    /// Whether the gate uses the tanh-approximate GELU.
    approximate_gelu: bool,
}

impl GEGLU {
    /// Creates a GEGLU splitting along `dim`, using approximate GELU when
    /// `approximate_gelu` is `true`.
    pub fn new(dim: isize, approximate_gelu: bool) -> Self {
        let base = ModuleBase::new();
        Self {
            base,
            dim,
            approximate_gelu,
        }
    }

    /// Last-dimension GEGLU with the exact (`erf`) GELU gate.
    pub fn exact() -> Self {
        Self::new(-1, false)
    }

    /// Last-dimension GEGLU with the tanh-approximate GELU gate.
    pub fn approximate() -> Self {
        Self::new(-1, true)
    }
}

impl Default for GEGLU {
    /// Equivalent to [`GEGLU::exact`].
    fn default() -> Self {
        Self::exact()
    }
}
impl Module for GEGLU {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let input_shape = input.shape().dims();
let split_dim = if self.dim < 0 {
(input_shape.len() as isize + self.dim) as usize
} else {
self.dim as usize
};
if split_dim >= input_shape.len() {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Dimension {} out of range for tensor with {} dimensions",
self.dim,
input_shape.len()
)));
}
let split_size = input_shape[split_dim];
if split_size % 2 != 0 {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Input dimension {} must be even for GEGLU, got {}",
split_size, split_size
)));
}
let chunks = input.chunk(2, split_dim as i32)?;
if chunks.len() != 2 {
return Err(torsh_core::error::TorshError::InvalidArgument(
"Failed to split input into two chunks".to_string(),
));
}
let value = &chunks[0];
let gate = &chunks[1];
let gelu = GELU::with_approximate(self.approximate_gelu);
let gelu_gate = gelu.forward(gate)?;
value.mul(&gelu_gate)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
impl std::fmt::Debug for GEGLU {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GEGLU")
.field("dim", &self.dim)
.field("approximate_gelu", &self.approximate_gelu)
.finish()
}
}
/// ReLU-gated linear unit.
///
/// Splits the input into two equal halves `(value, gate)` along `dim` and
/// returns `value * ReLU(gate)`.
pub struct ReGLU {
    base: ModuleBase,
    /// Dimension to split; negative values count back from the last dimension.
    dim: isize,
}

impl ReGLU {
    /// Creates a ReGLU that splits along `dim` (negative values index from the end).
    pub fn new(dim: isize) -> Self {
        let base = ModuleBase::new();
        Self { base, dim }
    }

    /// ReGLU splitting along the last dimension (`dim = -1`).
    pub fn last_dim() -> Self {
        Self::new(-1)
    }
}

impl Default for ReGLU {
    /// Equivalent to [`ReGLU::last_dim`].
    fn default() -> Self {
        Self::last_dim()
    }
}
impl Module for ReGLU {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let input_shape = input.shape().dims();
let split_dim = if self.dim < 0 {
(input_shape.len() as isize + self.dim) as usize
} else {
self.dim as usize
};
if split_dim >= input_shape.len() {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Dimension {} out of range for tensor with {} dimensions",
self.dim,
input_shape.len()
)));
}
let split_size = input_shape[split_dim];
if split_size % 2 != 0 {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Input dimension {} must be even for ReGLU, got {}",
split_size, split_size
)));
}
let chunks = input.chunk(2, split_dim as i32)?;
if chunks.len() != 2 {
return Err(torsh_core::error::TorshError::InvalidArgument(
"Failed to split input into two chunks".to_string(),
));
}
let value = &chunks[0];
let gate = &chunks[1];
let relu = ReLU::new();
let relu_gate = relu.forward(gate)?;
value.mul(&relu_gate)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
impl std::fmt::Debug for ReGLU {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ReGLU").field("dim", &self.dim).finish()
}
}
/// SiLU/Swish-gated linear unit.
///
/// Splits the input into two equal halves `(value, gate)` along `dim` and
/// returns `value * SiLU(gate)`.
pub struct SwiGLU {
    base: ModuleBase,
    /// Dimension to split; negative values count back from the last dimension.
    dim: isize,
}

impl SwiGLU {
    /// Creates a SwiGLU that splits along `dim` (negative values index from the end).
    pub fn new(dim: isize) -> Self {
        let base = ModuleBase::new();
        Self { base, dim }
    }

    /// SwiGLU splitting along the last dimension (`dim = -1`).
    pub fn last_dim() -> Self {
        Self::new(-1)
    }
}

impl Default for SwiGLU {
    /// Equivalent to [`SwiGLU::last_dim`].
    fn default() -> Self {
        Self::last_dim()
    }
}
impl Module for SwiGLU {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let input_shape = input.shape().dims();
let split_dim = if self.dim < 0 {
(input_shape.len() as isize + self.dim) as usize
} else {
self.dim as usize
};
if split_dim >= input_shape.len() {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Dimension {} out of range for tensor with {} dimensions",
self.dim,
input_shape.len()
)));
}
let split_size = input_shape[split_dim];
if split_size % 2 != 0 {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Input dimension {} must be even for SwiGLU, got {}",
split_size, split_size
)));
}
let chunks = input.chunk(2, split_dim as i32)?;
if chunks.len() != 2 {
return Err(torsh_core::error::TorshError::InvalidArgument(
"Failed to split input into two chunks".to_string(),
));
}
let value = &chunks[0];
let gate = &chunks[1];
let silu = SiLU::new();
let silu_gate = silu.forward(gate)?;
value.mul(&silu_gate)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
impl std::fmt::Debug for SwiGLU {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SwiGLU").field("dim", &self.dim).finish()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use torsh_tensor::creation::*;

    #[test]
    fn test_gelu_forward() {
        let gelu = GELU::new();
        let input = Tensor::from_data(vec![0.0], vec![1], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let output = gelu.forward(&input).expect("forward pass should succeed");
        assert_relative_eq!(
            output.to_vec().expect("tensor to vec conversion should succeed")[0],
            0.0,
            epsilon = 1e-5
        );
    }

    #[test]
    fn test_gelu_approximate_vs_exact() {
        let input = Tensor::from_data(vec![1.0], vec![1], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let output_exact = GELU::exact()
            .forward(&input)
            .expect("forward pass should succeed")
            .to_vec()
            .expect("tensor to vec conversion should succeed")[0];
        let output_approx = GELU::approximate()
            .forward(&input)
            .expect("forward pass should succeed")
            .to_vec()
            .expect("tensor to vec conversion should succeed")[0];
        // Reference values: GELU(1) = 0.8413447 (erf), 0.8411920 (tanh approx).
        assert_relative_eq!(output_exact, 0.8413447, epsilon = 1e-3);
        assert_relative_eq!(output_approx, 0.8411920, epsilon = 1e-4);
        assert!((output_exact - output_approx).abs() < 0.1);
    }

    #[test]
    fn test_silu_forward() {
        let silu = SiLU::new();
        let input = Tensor::from_data(vec![0.0, 1.0], vec![2], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let output = silu.forward(&input).expect("forward pass should succeed");
        let output_vec = output.to_vec().expect("tensor to vec conversion should succeed");
        assert_relative_eq!(output_vec[0], 0.0, epsilon = 1e-5);
        // SiLU(1) = sigmoid(1) = 0.7310586
        assert_relative_eq!(output_vec[1], 0.7310586, epsilon = 1e-4);
    }

    #[test]
    fn test_mish_forward() {
        let mish = Mish::new();
        let input = Tensor::from_data(vec![0.0, 1.0], vec![2], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let output = mish.forward(&input).expect("forward pass should succeed");
        let output_vec = output.to_vec().expect("tensor to vec conversion should succeed");
        assert_relative_eq!(output_vec[0], 0.0, epsilon = 1e-2);
        // Mish(1) = tanh(ln(1 + e)) = 0.8650984
        assert_relative_eq!(output_vec[1], 0.8650984, epsilon = 1e-4);
    }

    #[test]
    fn test_hardswish_forward() {
        let hardswish = Hardswish::new();
        let input = Tensor::from_data(vec![-3.0, 0.0, 3.0], vec![3], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let output = hardswish.forward(&input).expect("forward pass should succeed");
        let output_vec = output.to_vec().expect("tensor to vec conversion should succeed");
        assert_relative_eq!(output_vec[0], 0.0, epsilon = 1e-5);
        assert_relative_eq!(output_vec[1], 0.0, epsilon = 1e-5);
        assert_relative_eq!(output_vec[2], 3.0, epsilon = 1e-5);
    }

    #[test]
    fn test_hardswish_saturation_and_linear_region() {
        let hardswish = Hardswish::new();
        let input = Tensor::from_data(vec![-4.0, 1.0, 4.0], vec![3], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let output = hardswish.forward(&input).expect("forward pass should succeed");
        let output_vec = output.to_vec().expect("tensor to vec conversion should succeed");
        // Below -3 the gate clamps to 0; above 3 it clamps to 1.
        assert_relative_eq!(output_vec[0], 0.0, epsilon = 1e-5);
        assert_relative_eq!(output_vec[1], 4.0 / 6.0, epsilon = 1e-5);
        assert_relative_eq!(output_vec[2], 4.0, epsilon = 1e-5);
    }

    #[test]
    fn test_glu_forward() {
        // Row-major [2, 4]; along the last dim the first half is the value and
        // the second half is the gate.
        let input = Tensor::from_data(
            vec![1.0, 2.0, 0.0, 1.0, 3.0, 4.0, -1.0, 2.0],
            vec![2, 4],
            DeviceType::Cpu,
        )
        .expect("operation should succeed");
        let glu = GLU::new(-1);
        let output = glu.forward(&input).expect("forward pass should succeed");
        assert_eq!(output.shape().dims(), &[2, 2]);
        let output_vec = output.to_vec().expect("tensor to vec conversion should succeed");
        assert_eq!(output_vec.len(), 4);
        // value * sigmoid(gate): [1*s(0), 2*s(1), 3*s(-1), 4*s(2)]
        assert_relative_eq!(output_vec[0], 0.5, epsilon = 1e-4);
        assert_relative_eq!(output_vec[1], 1.4621172, epsilon = 1e-4);
        assert_relative_eq!(output_vec[2], 0.8068243, epsilon = 1e-4);
        assert_relative_eq!(output_vec[3], 3.5231884, epsilon = 1e-4);
    }

    #[test]
    fn test_glu_invalid_dimension() {
        let input = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let glu = GLU::new(-1);
        assert!(glu.forward(&input).is_err());
    }

    #[test]
    fn test_glu_dim_out_of_range() {
        let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("Tensor should succeed");
        // Negative index below -ndim and positive index >= ndim are both rejected.
        assert!(GLU::new(-3).forward(&input).is_err());
        assert!(GLU::new(2).forward(&input).is_err());
        // -ndim resolves to dimension 0, which has even size 2.
        assert!(GLU::new(-2).forward(&input).is_ok());
    }

    #[test]
    fn test_geglu_forward() {
        let input = Tensor::from_data(vec![1.0, 2.0, 0.0, 1.0], vec![4], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let geglu = GEGLU::exact();
        let output = geglu.forward(&input).expect("forward pass should succeed");
        assert_eq!(output.shape().dims(), &[2]);
        let output_vec = output.to_vec().expect("tensor to vec conversion should succeed");
        // [1 * GELU(0), 2 * GELU(1)] = [0, 2 * 0.8413447]
        assert_relative_eq!(output_vec[0], 0.0, epsilon = 1e-5);
        assert_relative_eq!(output_vec[1], 1.6826894, epsilon = 1e-3);
    }

    #[test]
    fn test_reglu_forward() {
        let input = Tensor::from_data(vec![1.0, 2.0, -1.0, 1.0], vec![4], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let reglu = ReGLU::new(-1);
        let output = reglu.forward(&input).expect("forward pass should succeed");
        assert_eq!(output.shape().dims(), &[2]);
        let output_vec = output.to_vec().expect("tensor to vec conversion should succeed");
        // [1 * ReLU(-1), 2 * ReLU(1)] = [0, 2]
        assert_relative_eq!(output_vec[0], 0.0, epsilon = 1e-5);
        assert_relative_eq!(output_vec[1], 2.0, epsilon = 1e-5);
    }

    #[test]
    fn test_swiglu_forward() {
        let input = Tensor::from_data(vec![1.0, 2.0, 0.0, 1.0], vec![4], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let swiglu = SwiGLU::new(-1);
        let output = swiglu.forward(&input).expect("forward pass should succeed");
        assert_eq!(output.shape().dims(), &[2]);
        let output_vec = output.to_vec().expect("tensor to vec conversion should succeed");
        // [1 * SiLU(0), 2 * SiLU(1)] = [0, 2 * 0.7310586]
        assert_relative_eq!(output_vec[0], 0.0, epsilon = 1e-5);
        assert_relative_eq!(output_vec[1], 1.4621172, epsilon = 1e-4);
    }

    #[test]
    fn test_module_interface() {
        let mut gelu = GELU::new();
        // Modules start in training mode.
        assert!(gelu.training());
        gelu.eval();
        assert!(!gelu.training());
        gelu.train();
        assert!(gelu.training());
        assert!(gelu.parameters().is_empty());
        assert!(gelu.named_parameters().is_empty());
    }

    #[test]
    fn test_swish_alias() {
        let silu = SiLU::new();
        let swish = Swish::new();
        let input = Tensor::from_data(vec![1.0], vec![1], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let silu_output = silu.forward(&input).expect("forward pass should succeed");
        let swish_output = swish.forward(&input).expect("forward pass should succeed");
        assert_eq!(
            silu_output.to_vec().expect("tensor to vec conversion should succeed"),
            swish_output.to_vec().expect("tensor to vec conversion should succeed")
        );
    }
}