use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::{creation::*, Tensor};
#[cfg(feature = "std")]
use std::{collections::HashMap, string::String};
#[cfg(not(feature = "std"))]
use alloc::string::String;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Gated Linear Unit layer.
///
/// Its `Module::forward` implementation splits the input into two equal halves
/// along `dim` and returns `first_half * sigmoid(second_half)`.
pub struct GLU {
    /// Common module bookkeeping; the `Module` impl delegates training mode,
    /// parameters and device moves to it.
    base: ModuleBase,
    /// Axis to split along. Negative values are counted back from the last axis.
    dim: isize,
}

impl GLU {
    /// Creates a GLU that splits its input along `dim`.
    pub fn new(dim: isize) -> Self {
        let base = ModuleBase::new();
        Self { base, dim }
    }

    /// Creates a GLU that splits along the last axis (`dim = -1`).
    pub fn default_dim() -> Self {
        const LAST_AXIS: isize = -1;
        Self::new(LAST_AXIS)
    }

    /// Returns the split axis exactly as it was passed to the constructor.
    pub fn dim(&self) -> isize {
        self.dim
    }
}

impl Default for GLU {
    /// Equivalent to [`GLU::default_dim`]: splits along the last axis.
    fn default() -> Self {
        Self::new(-1)
    }
}
impl Module for GLU {
    /// Applies the gated linear unit `a * sigmoid(b)`, where `a` and `b` are the
    /// first and second halves of `input` along the configured axis.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when the configured axis does not
    /// exist in `input`, or when the size along that axis is odd. Errors raised
    /// by the underlying tensor operations are propagated unchanged.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        let dims = shape.dims();
        let rank = dims.len();

        // Negative axes are relative to the end. If the shifted axis is still
        // negative, the cast wraps to a huge `usize` and the range check below
        // rejects it.
        let axis = match self.dim {
            d if d < 0 => (rank as isize + d) as usize,
            d => d as usize,
        };
        if axis >= rank {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                self.dim, rank
            )));
        }

        let extent = dims[axis];
        if extent % 2 != 0 {
            return Err(TorshError::InvalidArgument(format!(
                "Input dimension {} must be even for GLU, got {}",
                axis, extent
            )));
        }

        let half = extent / 2;
        let values = input.narrow(axis as i32, 0, half)?;
        let gate_input = input.narrow(axis as i32, half as i64, half)?;

        // sigmoid(x) = 1 / (1 + exp(-x)), assembled from elementwise tensor ops.
        let denominator = gate_input.neg()?.exp()?.add_scalar(1.0)?;
        let numerator = ones(gate_input.shape().dims())?;
        let gate = numerator.div(&denominator)?;

        values.mul(&gate)
    }

    /// Returns a clone of the parameter map stored in the embedded `ModuleBase`.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports the training flag tracked by the embedded `ModuleBase`.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Switches the module to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the module to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Sets the training flag on the embedded `ModuleBase`.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Forwards the device move to the embedded `ModuleBase`.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Delegates to `ModuleBase::named_parameters`.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// GELU-gated linear unit layer.
///
/// Its `Module::forward` implementation splits the input into two equal halves
/// along `dim` and returns `first_half * gelu(second_half)`, with GELU
/// evaluated through the private `apply_gelu` helper.
pub struct GEGLU {
    /// Common module bookkeeping; the `Module` impl delegates training mode,
    /// parameters and device moves to it.
    base: ModuleBase,
    /// Axis to split along. Negative values are counted back from the last axis.
    dim: isize,
}

impl GEGLU {
    /// Creates a GEGLU that splits its input along `dim`.
    pub fn new(dim: isize) -> Self {
        let base = ModuleBase::new();
        Self { base, dim }
    }

    /// Returns the split axis exactly as it was passed to the constructor.
    pub fn dim(&self) -> isize {
        self.dim
    }
}

impl Default for GEGLU {
    /// Splits along the last axis (`dim = -1`).
    fn default() -> Self {
        const LAST_AXIS: isize = -1;
        Self::new(LAST_AXIS)
    }
}
impl Module for GEGLU {
    /// Applies the GELU-gated linear unit `a * gelu(b)`, where `a` and `b` are
    /// the first and second halves of `input` along the configured axis.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when the configured axis does not
    /// exist in `input`, or when the size along that axis is odd. Errors raised
    /// by the underlying tensor operations are propagated unchanged.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        let dims = shape.dims();
        let rank = dims.len();

        // Negative axes are relative to the end. If the shifted axis is still
        // negative, the cast wraps to a huge `usize` and the range check below
        // rejects it.
        let axis = match self.dim {
            d if d < 0 => (rank as isize + d) as usize,
            d => d as usize,
        };
        if axis >= rank {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                self.dim, rank
            )));
        }

        let extent = dims[axis];
        if extent % 2 != 0 {
            return Err(TorshError::InvalidArgument(format!(
                "Input dimension {} must be even for GEGLU, got {}",
                axis, extent
            )));
        }

        let half = extent / 2;
        let values = input.narrow(axis as i32, 0, half)?;
        let gate_input = input.narrow(axis as i32, half as i64, half)?;

        let gate = self.apply_gelu(&gate_input)?;
        values.mul(&gate)
    }

    /// Returns a clone of the parameter map stored in the embedded `ModuleBase`.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports the training flag tracked by the embedded `ModuleBase`.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Switches the module to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the module to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Sets the training flag on the embedded `ModuleBase`.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Forwards the device move to the embedded `ModuleBase`.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Delegates to `ModuleBase::named_parameters`.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
impl GEGLU {
    /// Applies the tanh approximation of GELU elementwise:
    ///
    /// `gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying tensor operations.
    fn apply_gelu(&self, input: &Tensor) -> Result<Tensor> {
        // sqrt(2/π) ≈ 0.7978845608, written as a literal so this helper depends
        // neither on a `std::` path nor on `f32::sqrt`; the file is otherwise
        // written to build without `std`. The literal is the correctly rounded
        // f32 value and may differ from `(2.0 / PI).sqrt()` by one ULP.
        const SQRT_2_OVER_PI: f32 = 0.797_884_6;
        // Cubic coefficient of the tanh-based GELU approximation.
        const CUBIC_COEFF: f32 = 0.044_715;

        // Query the shape once; every constant tensor below reuses it.
        let shape = input.shape();
        let dims = shape.dims();

        // inner = x + 0.044715 * x³
        let x_cubed = input.pow(3.0)?;
        let cubic_term = x_cubed.mul(&full(dims, CUBIC_COEFF)?)?;
        let inner = input.add(&cubic_term)?;

        // 1 + tanh(sqrt(2/π) * inner); `add_scalar` avoids allocating a tensor
        // of ones just to add a constant.
        let scaled = inner.mul(&full(dims, SQRT_2_OVER_PI)?)?;
        let one_plus_tanh = scaled.tanh()?.add_scalar(1.0)?;

        // 0.5 * x * (1 + tanh(...))
        let half_x = input.mul(&full(dims, 0.5)?)?;
        half_x.mul(&one_plus_tanh)
    }
}
/// ReLU-gated linear unit layer.
///
/// Its `Module::forward` implementation splits the input into two equal halves
/// along `dim` and returns `first_half * max(second_half, 0)`.
pub struct ReGLU {
    /// Common module bookkeeping; the `Module` impl delegates training mode,
    /// parameters and device moves to it.
    base: ModuleBase,
    /// Axis to split along. Negative values are counted back from the last axis.
    dim: isize,
}

impl ReGLU {
    /// Creates a ReGLU that splits its input along `dim`.
    pub fn new(dim: isize) -> Self {
        let base = ModuleBase::new();
        Self { base, dim }
    }

    /// Returns the split axis exactly as it was passed to the constructor.
    pub fn dim(&self) -> isize {
        self.dim
    }
}

impl Default for ReGLU {
    /// Splits along the last axis (`dim = -1`).
    fn default() -> Self {
        const LAST_AXIS: isize = -1;
        Self::new(LAST_AXIS)
    }
}
impl Module for ReGLU {
    /// Applies the ReLU-gated linear unit `a * max(b, 0)`, where `a` and `b` are
    /// the first and second halves of `input` along the configured axis.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when the configured axis does not
    /// exist in `input`, or when the size along that axis is odd. Errors raised
    /// by the underlying tensor operations are propagated unchanged.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        let dims = shape.dims();
        let rank = dims.len();

        // Negative axes are relative to the end. If the shifted axis is still
        // negative, the cast wraps to a huge `usize` and the range check below
        // rejects it.
        let axis = match self.dim {
            d if d < 0 => (rank as isize + d) as usize,
            d => d as usize,
        };
        if axis >= rank {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                self.dim, rank
            )));
        }

        let extent = dims[axis];
        if extent % 2 != 0 {
            return Err(TorshError::InvalidArgument(format!(
                "Input dimension {} must be even for ReGLU, got {}",
                axis, extent
            )));
        }

        let half = extent / 2;
        let values = input.narrow(axis as i32, 0, half)?;
        let gate_input = input.narrow(axis as i32, half as i64, half)?;

        // relu(x) = max(x, 0), expressed as an elementwise maximum with zeros.
        let floor = zeros(gate_input.shape().dims())?;
        let gate = gate_input.maximum(&floor)?;

        values.mul(&gate)
    }

    /// Returns a clone of the parameter map stored in the embedded `ModuleBase`.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports the training flag tracked by the embedded `ModuleBase`.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Switches the module to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the module to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Sets the training flag on the embedded `ModuleBase`.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Forwards the device move to the embedded `ModuleBase`.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Delegates to `ModuleBase::named_parameters`.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Swish/SiLU-gated linear unit layer.
///
/// Its `Module::forward` implementation splits the input into two equal halves
/// along `dim` and returns `first_half * (second_half * sigmoid(second_half))`.
pub struct SwiGLU {
    /// Common module bookkeeping; the `Module` impl delegates training mode,
    /// parameters and device moves to it.
    base: ModuleBase,
    /// Axis to split along. Negative values are counted back from the last axis.
    dim: isize,
}

impl SwiGLU {
    /// Creates a SwiGLU that splits its input along `dim`.
    pub fn new(dim: isize) -> Self {
        let base = ModuleBase::new();
        Self { base, dim }
    }

    /// Returns the split axis exactly as it was passed to the constructor.
    pub fn dim(&self) -> isize {
        self.dim
    }
}

impl Default for SwiGLU {
    /// Splits along the last axis (`dim = -1`).
    fn default() -> Self {
        const LAST_AXIS: isize = -1;
        Self::new(LAST_AXIS)
    }
}
impl Module for SwiGLU {
    /// Applies the SiLU-gated linear unit `a * (b * sigmoid(b))`, where `a` and
    /// `b` are the first and second halves of `input` along the configured axis.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when the configured axis does not
    /// exist in `input`, or when the size along that axis is odd. Errors raised
    /// by the underlying tensor operations are propagated unchanged.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        let dims = shape.dims();
        let rank = dims.len();

        // Negative axes are relative to the end. If the shifted axis is still
        // negative, the cast wraps to a huge `usize` and the range check below
        // rejects it.
        let axis = match self.dim {
            d if d < 0 => (rank as isize + d) as usize,
            d => d as usize,
        };
        if axis >= rank {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                self.dim, rank
            )));
        }

        let extent = dims[axis];
        if extent % 2 != 0 {
            return Err(TorshError::InvalidArgument(format!(
                "Input dimension {} must be even for SwiGLU, got {}",
                axis, extent
            )));
        }

        let half = extent / 2;
        let values = input.narrow(axis as i32, 0, half)?;
        let gate_input = input.narrow(axis as i32, half as i64, half)?;

        // sigmoid(x) = 1 / (1 + exp(-x)), assembled from elementwise tensor ops.
        let denominator = gate_input.neg()?.exp()?.add_scalar(1.0)?;
        let numerator = ones(gate_input.shape().dims())?;
        let sigmoid = numerator.div(&denominator)?;

        // silu(x) = x * sigmoid(x)
        let gate = gate_input.mul(&sigmoid)?;

        values.mul(&gate)
    }

    /// Returns a clone of the parameter map stored in the embedded `ModuleBase`.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports the training flag tracked by the embedded `ModuleBase`.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Switches the module to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the module to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Sets the training flag on the embedded `ModuleBase`.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Forwards the device move to the embedded `ModuleBase`.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Delegates to `ModuleBase::named_parameters`.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores the axis verbatim, negative or not.
    #[test]
    fn test_glu_dimension_parameter() {
        assert_eq!(GLU::new(-1).dim(), -1);
        assert_eq!(GLU::new(0).dim(), 0);
    }

    /// GLU halves the split axis: [2, 8] -> [2, 4].
    #[test]
    fn test_glu_forward_shape() -> Result<()> {
        let layer = GLU::new(-1);
        let input = randn(&[2, 8])?;
        let output = layer.forward(&input)?;
        assert_eq!(output.shape().dims(), &[2, 4]);
        Ok(())
    }

    /// An odd size along the split axis must be rejected.
    #[test]
    fn test_glu_invalid_dimension() -> Result<()> {
        let layer = GLU::new(-1);
        let input = randn(&[2, 7])?;
        assert!(layer.forward(&input).is_err());
        Ok(())
    }

    /// GEGLU halves the split axis: [2, 8] -> [2, 4].
    #[test]
    fn test_geglu_forward_shape() -> Result<()> {
        let layer = GEGLU::new(-1);
        let input = randn(&[2, 8])?;
        let output = layer.forward(&input)?;
        assert_eq!(output.shape().dims(), &[2, 4]);
        Ok(())
    }

    /// ReGLU halves the split axis: [2, 8] -> [2, 4].
    #[test]
    fn test_reglu_forward_shape() -> Result<()> {
        let layer = ReGLU::new(-1);
        let input = randn(&[2, 8])?;
        let output = layer.forward(&input)?;
        assert_eq!(output.shape().dims(), &[2, 4]);
        Ok(())
    }

    /// SwiGLU halves the split axis: [2, 8] -> [2, 4].
    #[test]
    fn test_swiglu_forward_shape() -> Result<()> {
        let layer = SwiGLU::new(-1);
        let input = randn(&[2, 8])?;
        let output = layer.forward(&input)?;
        assert_eq!(output.shape().dims(), &[2, 4]);
        Ok(())
    }

    /// A positive axis and the equivalent negative axis select the same dimension.
    #[test]
    fn test_dimension_handling() -> Result<()> {
        for &dim in &[1isize, -2] {
            let layer = GLU::new(dim);
            let input = randn(&[2, 8, 3])?;
            let output = layer.forward(&input)?;
            assert_eq!(output.shape().dims(), &[2, 4, 3]);
        }
        Ok(())
    }

    /// `eval` and `train` flip the training flag; a new module starts in training mode.
    #[test]
    fn test_training_mode_toggle() -> Result<()> {
        let mut layer = GLU::new(-1);
        assert!(layer.training());

        layer.eval();
        assert!(!layer.training());

        layer.train();
        assert!(layer.training());
        Ok(())
    }

    /// Every gated unit can be built through `Default`.
    #[test]
    fn test_default_implementations() {
        let _glu = GLU::default();
        let _geglu = GEGLU::default();
        let _reglu = ReGLU::default();
        let _swiglu = SwiGLU::default();
    }

    /// `GLU::default_dim` selects the last axis.
    #[test]
    fn test_convenience_constructors() {
        assert_eq!(GLU::default_dim().dim(), -1);
    }

    /// An axis beyond the tensor rank must be rejected.
    #[test]
    fn test_error_handling() -> Result<()> {
        let layer = GLU::new(10);
        let input = randn(&[2, 8])?;
        assert!(layer.forward(&input).is_err());
        Ok(())
    }

    /// All four gated units produce the same output shape for the same input.
    #[test]
    fn test_all_gated_functions_consistency() -> Result<()> {
        let input = randn(&[4, 16])?;

        let glu = GLU::new(-1);
        let geglu = GEGLU::new(-1);
        let reglu = ReGLU::new(-1);
        let swiglu = SwiGLU::new(-1);

        let glu_out = glu.forward(&input)?;
        let geglu_out = geglu.forward(&input)?;
        let reglu_out = reglu.forward(&input)?;
        let swiglu_out = swiglu.forward(&input)?;

        for other in &[&geglu_out, &reglu_out, &swiglu_out] {
            assert_eq!(glu_out.shape(), other.shape());
        }
        assert_eq!(glu_out.shape().dims(), &[4, 8]);
        Ok(())
    }
}