use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
#[cfg(feature = "std")]
use std::{collections::HashMap, string::String};
#[cfg(not(feature = "std"))]
use alloc::string::String;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Softmax activation module: maps values to non-negative weights that sum to 1.
///
/// With `dim = Some(d)` the normalization is applied independently along
/// dimension `d`; with `dim = None` it is applied over every element of the
/// input at once.
pub struct Softmax {
    base: ModuleBase,
    /// Dimension to normalize over, or `None` for the whole tensor.
    dim: Option<usize>,
}

impl Softmax {
    /// Creates a softmax that normalizes over `dim` (`None` = all elements).
    pub fn new(dim: Option<usize>) -> Self {
        let base = ModuleBase::new();
        Self { base, dim }
    }

    /// Creates a softmax over dimension 1.
    ///
    /// NOTE: the dimension is fixed at 1, so it is the *last* dimension only
    /// for 2-D inputs (e.g. `(batch, features)`); for inputs of any other rank
    /// a different axis (or an out-of-range one) is used.
    pub fn last_dim() -> Self {
        Self {
            base: ModuleBase::new(),
            dim: Some(1),
        }
    }

    /// Creates a softmax that normalizes over every element of the input.
    pub fn all_dims() -> Self {
        Self {
            base: ModuleBase::new(),
            dim: None,
        }
    }
}

impl Default for Softmax {
    /// Same as [`Softmax::all_dims`].
    fn default() -> Self {
        Self::all_dims()
    }
}
impl Module for Softmax {
    /// Computes `exp(x - max) / sum(exp(x - max))`.
    ///
    /// The max (and the sum) are taken along `self.dim` with the reduced
    /// dimension kept so they broadcast against `input`; when `self.dim` is
    /// `None`, the global max/sum is expanded to `input`'s full shape.
    /// Subtracting the max keeps every exponent `<= 0`, so `exp` cannot overflow.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shifted = match self.dim {
            Some(d) => input.sub(&input.max_dim(&[d as i32], true)?)?,
            None => {
                let global_max = input.max()?;
                input.sub(&full(input.shape().dims(), global_max)?)?
            }
        };
        let numerators = shifted.exp()?;
        let denominators = match self.dim {
            Some(d) => numerators.sum_dim(&[d as i32], true)?,
            None => {
                let total = numerators.sum()?;
                full(input.shape().dims(), total)?
            }
        };
        numerators.div(&denominators)
    }

    /// Softmax owns no trainable parameters; this returns whatever the base holds.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn train(&mut self) {
        self.base.set_training(true);
    }

    fn eval(&mut self) {
        self.base.set_training(false);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }
}

impl std::fmt::Debug for Softmax {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Softmax");
        out.field("dim", &self.dim);
        out.finish()
    }
}
/// Log-softmax activation module: `log(softmax(x))`, computed directly.
///
/// With `dim = Some(d)` the normalization is applied along dimension `d`;
/// with `dim = None` it is applied over every element of the input.
pub struct LogSoftmax {
    base: ModuleBase,
    /// Dimension to normalize over, or `None` for the whole tensor.
    dim: Option<usize>,
}

impl LogSoftmax {
    /// Creates a log-softmax that normalizes over `dim` (`None` = all elements).
    pub fn new(dim: Option<usize>) -> Self {
        let base = ModuleBase::new();
        Self { base, dim }
    }

    /// Creates a log-softmax over dimension 1.
    ///
    /// NOTE: the dimension is fixed at 1, so it is the *last* dimension only
    /// for 2-D inputs; for inputs of any other rank a different axis is used.
    pub fn last_dim() -> Self {
        Self {
            base: ModuleBase::new(),
            dim: Some(1),
        }
    }

    /// Creates a log-softmax that normalizes over every element of the input.
    pub fn all_dims() -> Self {
        Self {
            base: ModuleBase::new(),
            dim: None,
        }
    }
}

impl Default for LogSoftmax {
    /// Same as [`LogSoftmax::all_dims`].
    fn default() -> Self {
        Self::all_dims()
    }
}
impl Module for LogSoftmax {
    /// Computes `log_softmax(x) = (x - max) - log(sum(exp(x - max)))`.
    ///
    /// The max and sum are taken along `self.dim` with the reduced dimension
    /// kept (so they broadcast against `input`), or over every element when
    /// `self.dim` is `None` (expanded to `input`'s full shape).
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let max_vals = if let Some(dim) = self.dim {
            input.max_dim(&[dim as i32], true)?
        } else {
            let max_val = input.max()?;
            full(input.shape().dims(), max_val)?
        };
        // Shifting by the max makes every exponent <= 0, so exp() cannot overflow.
        let shifted_input = input.sub(&max_vals)?;
        let exp_shifted = shifted_input.exp()?;
        let sum_exp = if let Some(dim) = self.dim {
            exp_shifted.sum_dim(&[dim as i32], true)?
        } else {
            let sum_val = exp_shifted.sum()?;
            full(input.shape().dims(), sum_val)?
        };
        let log_norm = sum_exp.log()?;
        // Subtract from the *shifted* input rather than computing
        // `input - (max + log_norm)`: adding `max` back first rounds `log_norm`
        // away for large-magnitude inputs (e.g. values around 1e7 in f32),
        // which can produce log-probabilities >= 0.
        shifted_input.sub(&log_norm)
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn train(&mut self) {
        self.base.set_training(true);
    }

    fn eval(&mut self) {
        self.base.set_training(false);
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}

impl std::fmt::Debug for LogSoftmax {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LogSoftmax")
            .field("dim", &self.dim)
            .finish()
    }
}
/// Softplus activation module: `(1 / beta) * log(1 + exp(beta * x))`.
///
/// Where `beta * x` exceeds `threshold`, the input is passed through unchanged.
pub struct Softplus {
    base: ModuleBase,
    /// Scale applied to the input inside the exponential.
    beta: f32,
    /// Above this value of `beta * x` the output is just `x`.
    threshold: f32,
}

impl Softplus {
    /// Threshold used by the convenience constructors.
    const DEFAULT_THRESHOLD: f32 = 20.0;

    /// Creates a softplus with explicit `beta` and `threshold`.
    pub fn new(beta: f32, threshold: f32) -> Self {
        Self {
            threshold,
            beta,
            base: ModuleBase::new(),
        }
    }

    /// Creates a softplus with `beta = 1.0` and `threshold = 20.0`.
    pub fn default_params() -> Self {
        Self::with_beta(1.0)
    }

    /// Creates a softplus with the given `beta` and `threshold = 20.0`.
    pub fn with_beta(beta: f32) -> Self {
        Self::new(beta, Self::DEFAULT_THRESHOLD)
    }
}

impl Default for Softplus {
    /// Same as [`Softplus::default_params`].
    fn default() -> Self {
        Self::default_params()
    }
}
impl Module for Softplus {
    /// Computes softplus element-wise:
    /// `x` where `beta * x > threshold`, otherwise `ln(1 + exp(beta * x)) / beta`.
    ///
    /// The result is built on the host from `input`'s values and placed on
    /// `input`'s device with `input`'s shape.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let inv_beta = 1.0 / self.beta;
        let values: Vec<f32> = input.to_vec()?;
        let selected_values: Vec<f32> = values
            .into_iter()
            .map(|x| {
                let beta_x = x * self.beta;
                if beta_x > self.threshold {
                    // Linear regime: softplus(x) == x up to rounding, and
                    // exp(beta_x) could overflow anyway.
                    x
                } else {
                    // ln_1p keeps precision when exp(beta_x) is tiny; the
                    // naive ln(1 + e) collapses to exactly 0 once e is below
                    // f32 epsilon (beta_x < ~-17), breaking softplus(x) > 0.
                    beta_x.exp().ln_1p() * inv_beta
                }
            })
            .collect();
        Tensor::from_data(
            selected_values,
            input.shape().dims().to_vec(),
            input.device(),
        )
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn train(&mut self) {
        self.base.set_training(true);
    }

    fn eval(&mut self) {
        self.base.set_training(false);
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}

impl std::fmt::Debug for Softplus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Softplus")
            .field("beta", &self.beta)
            .field("threshold", &self.threshold)
            .finish()
    }
}
/// Softsign activation module: `x / (1 + |x|)`.
pub struct Softsign {
    base: ModuleBase,
}

impl Softsign {
    /// Creates a softsign module (it has no configuration).
    pub fn new() -> Self {
        let base = ModuleBase::new();
        Self { base }
    }
}

impl Default for Softsign {
    /// Same as [`Softsign::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Module for Softsign {
    /// Computes `x / (1 + |x|)` element-wise.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let magnitude = input.abs()?;
        let denominator = ones(input.shape().dims())?.add(&magnitude)?;
        input.div(&denominator)
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn train(&mut self) {
        self.base.set_training(true);
    }

    fn eval(&mut self) {
        self.base.set_training(false);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }
}

impl std::fmt::Debug for Softsign {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let out = f.debug_struct("Softsign");
        let mut out = out;
        out.finish()
    }
}
/// Computes `log(sum(exp(input)))` in a numerically stable way (max-shifted).
///
/// * `dim = Some(d)`: reduces along dimension `d`. With `keepdim = true` the
///   result keeps `d` as a size-1 dimension; with `keepdim = false` that
///   dimension is dropped (as produced by `max_dim`/`sum_dim` with `false`).
/// * `dim = None`: reduces over every element. With `keepdim = true` the
///   scalar result is expanded to `input`'s full shape; otherwise a
///   one-element tensor of shape `[1]` is returned.
///
/// # Errors
/// Propagates any error from the underlying tensor operations.
pub fn log_sum_exp(input: &Tensor, dim: Option<usize>, keepdim: bool) -> Result<Tensor> {
    match dim {
        Some(d) => {
            let axis = [d as i32];
            // The max used for the shift must broadcast against `input`, so it
            // is always taken with keepdim = true. Subtracting a reduced
            // (keepdim = false) max would fail to broadcast or, for shapes that
            // happen to line up, silently broadcast along the wrong axis.
            let max_keep = input.max_dim(&axis, true)?;
            let sum_exp = input.sub(&max_keep)?.exp()?.sum_dim(&axis, keepdim)?;
            // The max added back must have the same shape as `sum_exp`.
            let max_out = if keepdim {
                max_keep
            } else {
                input.max_dim(&axis, false)?
            };
            max_out.add(&sum_exp.log()?)
        }
        None => {
            let max_val = input.max()?;
            let max_vals = if keepdim {
                full(input.shape().dims(), max_val)?
            } else {
                Tensor::from_data(vec![max_val], vec![1], input.device())?
            };
            let exp_shifted = input.sub(&max_vals)?.exp()?;
            let sum_val = exp_shifted.sum()?;
            let sum_exp = if keepdim {
                full(input.shape().dims(), sum_val)?
            } else {
                Tensor::from_data(vec![sum_val], vec![1], input.device())?
            };
            max_vals.add(&sum_exp.log()?)
        }
    }
}
pub fn stable_softmax(input: &Tensor, dim: Option<usize>) -> Result<Tensor> {
let log_sum_exp_val = log_sum_exp(input, dim, true)?;
let log_softmax_val = input.sub(&log_sum_exp_val)?;
log_softmax_val.exp()
}
pub fn stable_log_softmax(input: &Tensor, dim: Option<usize>) -> Result<Tensor> {
let log_sum_exp_val = log_sum_exp(input, dim, true)?;
input.sub(&log_sum_exp_val)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use torsh_tensor::creation::*;

    /// Softmax over all elements yields strictly positive values summing to 1.
    #[test]
    fn test_softmax_forward() {
        let module = Softmax::new(None);
        let x = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let probs = module
            .forward(&x)
            .expect("forward pass should succeed")
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        let total: f32 = probs.iter().sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-5);
        assert!(probs.iter().all(|&p| p > 0.0));
    }

    /// Log-probabilities are <= 0 and exponentiate back to the softmax output.
    #[test]
    fn test_log_softmax_forward() {
        let x = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], DeviceType::Cpu)
            .expect("Tensor should succeed");

        let log_probs = LogSoftmax::new(None)
            .forward(&x)
            .expect("forward pass should succeed")
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        assert!(log_probs.iter().all(|&lp| lp <= 0.0));

        let probs = Softmax::new(None)
            .forward(&x)
            .expect("forward pass should succeed")
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        let recovered: Vec<f32> = log_probs.iter().map(|&lp| lp.exp()).collect();
        for (actual, expected) in recovered.iter().zip(probs.iter()) {
            assert_relative_eq!(actual, expected, epsilon = 1e-5);
        }
    }

    /// Softplus is positive, equals ln(2) at 0, and is ~linear for large inputs.
    #[test]
    fn test_softplus_forward() {
        let module = Softplus::new(1.0, 20.0);
        let x = Tensor::from_data(vec![-2.0, 0.0, 2.0], vec![3], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let y = module
            .forward(&x)
            .expect("forward pass should succeed")
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        assert!(y.iter().all(|&v| v > 0.0));
        assert_relative_eq!(y[1], 2.0_f32.ln(), epsilon = 1e-3);

        let big = Tensor::from_data(vec![10.0], vec![1], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let big_out = module
            .forward(&big)
            .expect("forward pass should succeed");
        let big_vals = big_out
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        assert_relative_eq!(big_vals[0], 10.0, epsilon = 1e-3);
    }

    /// Softsign is odd, zero at 0, bounded in (-1, 1), and 2/3 at x = 2.
    #[test]
    fn test_softsign_forward() {
        let module = Softsign::new();
        let x = Tensor::from_data(vec![-2.0, 0.0, 2.0], vec![3], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let y = module
            .forward(&x)
            .expect("forward pass should succeed")
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        assert_relative_eq!(y[1], 0.0, epsilon = 1e-5);
        assert_relative_eq!(y[0], -y[2], epsilon = 1e-5);
        assert!(y.iter().all(|&v| v > -1.0 && v < 1.0));
        assert_relative_eq!(y[2], 2.0 / 3.0, epsilon = 1e-5);
    }

    /// Global log-sum-exp matches the direct formula on small inputs.
    #[test]
    fn test_log_sum_exp_utility() {
        let x = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let lse = log_sum_exp(&x, None, false).expect("log sum exp should succeed");
        let got = lse
            .to_vec()
            .expect("tensor to vec conversion should succeed")[0];
        let want = (1.0_f32.exp() + 2.0_f32.exp() + 3.0_f32.exp()).ln();
        assert_relative_eq!(got, want, epsilon = 1e-5);
    }

    /// Stable softmax sums to 1 for both small and shifted-large inputs.
    #[test]
    fn test_stable_softmax_utility() {
        let small = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let small_probs = stable_softmax(&small, None)
            .expect("stable softmax should succeed")
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        let small_total: f32 = small_probs.iter().sum();
        assert_relative_eq!(small_total, 1.0, epsilon = 1e-5);

        let large = Tensor::from_data(vec![100.0, 101.0, 102.0], vec![3], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let large_probs = stable_softmax(&large, None).expect("stable softmax should succeed");
        let large_total: f32 = large_probs
            .to_vec()
            .expect("tensor to vec conversion should succeed")
            .iter()
            .sum();
        assert_relative_eq!(large_total, 1.0, epsilon = 1e-5);
    }

    /// Training-mode toggling works and the module owns no parameters.
    #[test]
    fn test_module_interface() {
        let mut module = Softmax::new(Some(0));
        assert!(module.training());
        module.eval();
        assert!(!module.training());
        module.train();
        assert!(module.training());
        assert!(module.parameters().is_empty());
        assert!(module.named_parameters().is_empty());
    }

    /// Global softmax sums to 1 overall; per-dim softmax sums to 1 per row.
    #[test]
    fn test_softmax_dimensions() {
        let matrix = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DeviceType::Cpu,
        )
        .expect("operation should succeed");

        let global = Softmax::new(None)
            .forward(&matrix)
            .expect("forward pass should succeed");
        let global_total: f32 = global
            .to_vec()
            .expect("tensor to vec conversion should succeed")
            .iter()
            .sum();
        assert_relative_eq!(global_total, 1.0, epsilon = 1e-5);

        let per_row = Softmax::new(Some(1))
            .forward(&matrix)
            .expect("forward pass should succeed")
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        let first_row: f32 = per_row[0..3].iter().sum();
        let second_row: f32 = per_row[3..6].iter().sum();
        assert_relative_eq!(first_row, 1.0, epsilon = 1e-5);
        assert_relative_eq!(second_row, 1.0, epsilon = 1e-5);
    }
}