use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
#[cfg(feature = "std")]
use std::{collections::HashMap, string::String};
#[cfg(not(feature = "std"))]
use alloc::string::String;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Softmax activation module.
///
/// The forward pass normalises `exp(x - max(x))` by its sum, either along a
/// single dimension or over every element of the input (see `dim`).
pub struct Softmax {
// Shared module state; the training flag, device handling and parameter map
// are all delegated to it.
base: ModuleBase,
/// Dimension to normalise along; `None` normalises over all elements.
dim: Option<usize>,
}
impl Softmax {
pub fn new(dim: Option<usize>) -> Self {
Self {
base: ModuleBase::new(),
dim,
}
}
pub fn along_last_dim() -> Self {
Self::new(Some(1))
}
pub fn global() -> Self {
Self::new(None)
}
pub fn dim(&self) -> Option<usize> {
self.dim
}
}
impl Default for Softmax {
fn default() -> Self {
Self::new(None)
}
}
impl Module for Softmax {
    /// Computes `exp(x - max) / sum(exp(x - max))`.
    ///
    /// With `dim == Some(d)` the maximum and the sum are reduced along `d`
    /// with the reduced dimension kept. With `dim == None` they are reduced
    /// over the whole tensor and expanded back to the input shape via `full`.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying tensor operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        match self.dim {
            Some(axis) => {
                let axis = axis as i32;
                // Subtracting the per-slice maximum keeps `exp` from overflowing.
                let peak = input.max_dim(axis, true)?;
                let numerator = input.sub(&peak)?.exp()?;
                let denominator = numerator.sum_dim(&[axis], true)?;
                numerator.div(&denominator)
            }
            None => {
                let peak_scalar = input.max(None, false)?;
                let shape = input.shape();
                // Materialise the scalar maximum at the input's shape.
                let peak = full(shape.dims(), peak_scalar.item()?)?;
                let numerator = input.sub(&peak)?.exp()?;
                let total = numerator.sum()?;
                let denominator = full(shape.dims(), total.item()?)?;
                numerator.div(&denominator)
            }
        }
    }

    /// Returns a clone of the parameter map held by the base module.
    fn parameters(&self) -> HashMap<String, Parameter> {
        let registered = &self.base.parameters;
        registered.clone()
    }

    /// Reports the training flag stored on the base module.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Switches the base module's training flag on.
    fn train(&mut self) {
        let enabled = true;
        self.base.set_training(enabled);
    }

    /// Switches the base module's training flag off.
    fn eval(&mut self) {
        let enabled = false;
        self.base.set_training(enabled);
    }

    /// Sets the base module's training flag to `training`.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Delegates the device move to the base module.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the named parameters reported by the base module.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Log-softmax activation module.
///
/// The forward pass returns `x - (log(sum(exp(x - max(x)))) + max(x))`,
/// reduced either along a single dimension or over every element (see `dim`).
pub struct LogSoftmax {
// Shared module state; the training flag, device handling and parameter map
// are all delegated to it.
base: ModuleBase,
/// Dimension to normalise along; `None` normalises over all elements.
dim: Option<usize>,
}
impl LogSoftmax {
pub fn new(dim: Option<usize>) -> Self {
Self {
base: ModuleBase::new(),
dim,
}
}
pub fn along_last_dim() -> Self {
Self::new(Some(1))
}
pub fn global() -> Self {
Self::new(None)
}
pub fn dim(&self) -> Option<usize> {
self.dim
}
}
impl Default for LogSoftmax {
fn default() -> Self {
Self::new(None)
}
}
impl Module for LogSoftmax {
    /// Computes `x - (log(sum(exp(x - max))) + max)`.
    ///
    /// With `dim == Some(d)` the maximum and the sum are reduced along `d`
    /// with the reduced dimension kept. With `dim == None` they are reduced
    /// over the whole tensor and expanded back to the input shape via `full`.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying tensor operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        // `offset` is the maximum used to shift the input before `exp`;
        // `normaliser` is the sum of the shifted exponentials.
        let (offset, normaliser) = match self.dim {
            Some(axis) => {
                let axis = axis as i32;
                let offset = input.max_dim(axis, true)?;
                let shifted_exp = input.sub(&offset)?.exp()?;
                let normaliser = shifted_exp.sum_dim(&[axis], true)?;
                (offset, normaliser)
            }
            None => {
                let peak_scalar = input.max(None, false)?;
                let shape = input.shape();
                let offset = full(shape.dims(), peak_scalar.item()?)?;
                let shifted_exp = input.sub(&offset)?.exp()?;
                let total = shifted_exp.sum()?;
                let normaliser = full(shape.dims(), total.item()?)?;
                (offset, normaliser)
            }
        };

        // Add the maximum back onto the log-sum so it is expressed relative
        // to the unshifted input.
        let log_normaliser = normaliser.log()?.add(&offset)?;
        input.sub(&log_normaliser)
    }

    /// Returns a clone of the parameter map held by the base module.
    fn parameters(&self) -> HashMap<String, Parameter> {
        let registered = &self.base.parameters;
        registered.clone()
    }

    /// Reports the training flag stored on the base module.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Switches the base module's training flag on.
    fn train(&mut self) {
        let enabled = true;
        self.base.set_training(enabled);
    }

    /// Switches the base module's training flag off.
    fn eval(&mut self) {
        let enabled = false;
        self.base.set_training(enabled);
    }

    /// Sets the base module's training flag to `training`.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Delegates the device move to the base module.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the named parameters reported by the base module.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Log-sigmoid activation module, applied element-wise: `log(1 / (1 + exp(-x)))`.
pub struct LogSigmoid {
// Shared module state; the training flag, device handling and parameter map
// are all delegated to it.
base: ModuleBase,
}
impl LogSigmoid {
pub fn new() -> Self {
Self {
base: ModuleBase::new(),
}
}
}
impl Default for LogSigmoid {
fn default() -> Self {
Self::new()
}
}
impl Module for LogSigmoid {
    /// Computes `log(sigmoid(x))` element-wise in a numerically stable way.
    ///
    /// Uses the identity `log_sigmoid(x) = min(x, 0) - log(1 + exp(-|x|))`.
    /// `exp` is only ever evaluated on a non-positive argument, so it cannot
    /// overflow. Evaluating both `exp(x)` and `exp(-x)` on every element and
    /// blending the results with 0/1 masks is *not* safe: for large `|x|` the
    /// unused branch becomes `±inf`, and `inf * 0` is `NaN`.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying tensor operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        let zeros_tensor = zeros(shape.dims())?;
        let ones_tensor = ones(shape.dims())?;

        // True where x >= 0.
        let non_negative = input.ge(&zeros_tensor)?;

        // `a.where_tensor(c, b)` is taken to yield `a` where `c` holds and `b`
        // elsewhere. Even with the opposite convention the result below is
        // still mathematically log_sigmoid(x).
        //
        // -|x|: pick -x for non-negative entries and x itself otherwise.
        let neg_input = input.neg()?;
        let neg_abs = neg_input.where_tensor(&non_negative, input)?;

        // min(x, 0): pick 0 for non-negative entries and x itself otherwise.
        let clamped = zeros_tensor.where_tensor(&non_negative, input)?;

        // log(1 + exp(-|x|)); the exponential lies in (0, 1], so this is finite.
        let softplus_term = neg_abs.exp()?.add(&ones_tensor)?.log()?;

        clamped.sub(&softplus_term)
    }

    /// Returns a clone of the parameter map held by the base module.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports the training flag stored on the base module.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Switches the base module's training flag on.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the base module's training flag off.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Sets the base module's training flag to `training`.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Delegates the device move to the base module.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the named parameters reported by the base module.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `Softmax` constructor records the expected dimension.
    #[test]
    fn test_softmax_creation() {
        let explicit = Softmax::new(Some(1));
        let whole_tensor = Softmax::global();
        let last_dim = Softmax::along_last_dim();

        assert_eq!(explicit.dim(), Some(1));
        assert_eq!(whole_tensor.dim(), None);
        assert_eq!(last_dim.dim(), Some(1));
    }

    /// Softmax along dimension 1 keeps the input shape.
    #[test]
    fn test_softmax_forward() -> Result<()> {
        let layer = Softmax::new(Some(1));
        let sample = randn(&[2, 3])?;
        let probs = layer.forward(&sample)?;
        assert_eq!(probs.shape(), sample.shape());
        Ok(())
    }

    /// `LogSoftmax` constructors record the expected dimension.
    #[test]
    fn test_log_softmax_creation() {
        let explicit = LogSoftmax::new(Some(1));
        let whole_tensor = LogSoftmax::global();

        assert_eq!(explicit.dim(), Some(1));
        assert_eq!(whole_tensor.dim(), None);
    }

    /// Log-softmax along dimension 1 keeps the input shape.
    #[test]
    fn test_log_softmax_forward() -> Result<()> {
        let layer = LogSoftmax::new(Some(1));
        let sample = randn(&[2, 3])?;
        let log_probs = layer.forward(&sample)?;
        assert_eq!(log_probs.shape(), sample.shape());
        Ok(())
    }

    /// Log-sigmoid keeps the input shape.
    #[test]
    fn test_log_sigmoid_forward() -> Result<()> {
        let layer = LogSigmoid::new();
        let sample = randn(&[2, 3])?;
        let activated = layer.forward(&sample)?;
        assert_eq!(activated.shape(), sample.shape());
        Ok(())
    }

    /// The `dim` getter echoes whatever was passed to `new`.
    #[test]
    fn test_softmax_dimension_parameter() {
        let over_rows = Softmax::new(Some(0));
        let over_cols = Softmax::new(Some(1));
        let over_all = Softmax::new(None);

        assert_eq!(over_rows.dim(), Some(0));
        assert_eq!(over_cols.dim(), Some(1));
        assert_eq!(over_all.dim(), None);
    }

    /// `eval` and `train` flip the training flag; a new layer starts in training mode.
    #[test]
    fn test_training_mode_toggle() -> Result<()> {
        let mut layer = Softmax::new(Some(1));
        assert!(layer.training());

        layer.eval();
        assert!(!layer.training());

        layer.train();
        assert!(layer.training());
        Ok(())
    }

    /// All three modules can be built through `Default`.
    #[test]
    fn test_default_implementations() {
        let _softmax: Softmax = Default::default();
        let _log_softmax: LogSoftmax = Default::default();
        let _log_sigmoid: LogSigmoid = Default::default();
    }

    /// Large-magnitude inputs run through log-sigmoid and keep their shape.
    #[test]
    fn test_log_sigmoid_numerical_stability() -> Result<()> {
        let layer = LogSigmoid::new();
        // One strongly positive and one strongly negative fill value.
        for fill_value in [100.0, -100.0] {
            let extreme = full(&[2, 2], fill_value)?;
            let activated = layer.forward(&extreme)?;
            assert_eq!(activated.shape(), extreme.shape());
        }
        Ok(())
    }

    /// The convenience constructors map to the documented `dim` values.
    #[test]
    fn test_convenience_constructors() {
        let softmax_last = Softmax::along_last_dim();
        let softmax_all = Softmax::global();
        let log_softmax_last = LogSoftmax::along_last_dim();
        let log_softmax_all = LogSoftmax::global();

        assert_eq!(softmax_last.dim(), Some(1));
        assert_eq!(softmax_all.dim(), None);
        assert_eq!(log_softmax_last.dim(), Some(1));
        assert_eq!(log_softmax_all.dim(), None);
    }
}