use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
#[cfg(feature = "std")]
use std::{collections::HashMap, string::String};
#[cfg(not(feature = "std"))]
use alloc::string::String;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Element-wise Softplus activation: `ln(1 + exp(beta * x)) / beta`.
///
/// For inputs where `beta * x` exceeds `threshold`, the module's `forward`
/// returns the input itself instead of evaluating the logarithm.
pub struct Softplus {
    base: ModuleBase,
    /// Scale applied to the input before the softplus, and inverted afterwards.
    beta: f32,
    /// Cut-off on `beta * x` above which the output is the input unchanged.
    threshold: f32,
}
impl Softplus {
    /// Builds a Softplus with an explicit `beta` and `threshold`.
    ///
    /// No validation is performed. Note that `beta == 0.0` makes the `1 / beta`
    /// scaling used in `forward` non-finite.
    pub fn new(beta: f32, threshold: f32) -> Self {
        let base = ModuleBase::new();
        Self {
            base,
            beta,
            threshold,
        }
    }

    /// Builds a Softplus with `beta = 1.0` and `threshold = 20.0`.
    pub fn standard() -> Self {
        let (beta, threshold) = (1.0, 20.0);
        Self::new(beta, threshold)
    }

    /// Returns the configured `beta` scale.
    pub fn beta(&self) -> f32 {
        self.beta
    }

    /// Returns the configured linear-regime `threshold`.
    pub fn threshold(&self) -> f32 {
        self.threshold
    }
}
impl Default for Softplus {
    /// Same configuration as [`Softplus::standard`].
    ///
    /// Delegates rather than repeating the `beta`/`threshold` constants, so the
    /// two constructors cannot drift apart.
    fn default() -> Self {
        Self::standard()
    }
}
impl Module for Softplus {
    /// Applies Softplus element-wise: `ln(1 + exp(beta * x)) / beta`.
    ///
    /// Wherever `beta * x > threshold`, the input value is passed through
    /// unchanged (the linear regime).
    ///
    /// # Errors
    /// Propagates any error from tensor creation or the element-wise operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        let dims = shape.dims();

        let beta_x = input.mul(&full(dims, self.beta)?)?;
        let threshold_tensor = full(dims, self.threshold)?;
        let above_threshold = beta_x.gt(&threshold_tensor)?;

        // Cap `beta * x` at the threshold before exponentiating. Elements above the
        // threshold are replaced by the linear branch below, so their softplus value is
        // never selected. Capping keeps `exp` finite (in f32, `exp` overflows to infinity
        // for arguments above ~88.7) instead of carrying `inf` through the unused branch.
        // For elements at or below the threshold, `min(beta_x, threshold) == beta_x`,
        // so the selected values are unchanged.
        let capped = beta_x.minimum(&threshold_tensor)?;
        let log_part = capped.exp()?.add(&ones(dims)?)?.log()?;
        // NOTE(review): `beta == 0.0` makes this scale factor infinite; nothing guards it.
        let softplus_part = log_part.mul(&full(dims, 1.0 / self.beta)?)?;

        // NOTE(review): assumes `where_tensor` keeps `self` where the mask is true and
        // takes `other` elsewhere — confirm against torsh_tensor.
        input.where_tensor(&above_threshold, &softplus_part)
    }

    /// Returns a copy of the parameters held by the base module.
    ///
    /// `Softplus` itself never registers any.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports whether the module is in training mode.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Switches the module to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the module to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Sets training mode explicitly.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Forwards the device move to the base module.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the base module's named parameters.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Element-wise Softsign activation: `x / (1 + |x|)`.
///
/// Takes no configuration and registers no parameters of its own.
pub struct Softsign {
    base: ModuleBase,
}
impl Softsign {
    /// Creates a new Softsign module.
    pub fn new() -> Self {
        let base = ModuleBase::new();
        Softsign { base }
    }
}
impl Default for Softsign {
fn default() -> Self {
Self::new()
}
}
impl Module for Softsign {
    /// Computes `x / (1 + |x|)` element-wise.
    ///
    /// # Errors
    /// Propagates any error from tensor creation or the element-wise operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        // `|x| + 1`, built from the absolute value and a ones tensor of the same shape.
        let one_plus_abs = input.abs()?.add(&ones(shape.dims())?)?;
        input.div(&one_plus_abs)
    }

    /// Returns a copy of the parameters held by the base module.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports whether the module is in training mode.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Enables training mode by delegating to [`Module::set_training`].
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Enables evaluation mode by delegating to [`Module::set_training`].
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Stores the training flag on the base module.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Forwards the device move to the base module.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the base module's named parameters.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Piecewise-linear approximation of the logistic sigmoid:
/// `clamp(alpha * x + beta, 0, 1)`.
///
/// NOTE(review): the defaults (`alpha = 0.2`, `beta = 0.5`) match the ONNX
/// `HardSigmoid` convention, not PyTorch's `nn.Hardsigmoid` (slope `1/6`).
pub struct Hardsigmoid {
    base: ModuleBase,
    /// Slope of the linear segment.
    alpha: f32,
    /// Offset of the linear segment.
    beta: f32,
}
impl Hardsigmoid {
    /// Creates a Hardsigmoid with a custom slope (`alpha`) and offset (`beta`).
    pub fn new_with_params(alpha: f32, beta: f32) -> Self {
        let base = ModuleBase::new();
        Hardsigmoid { base, alpha, beta }
    }

    /// Creates a Hardsigmoid with `alpha = 0.2` and `beta = 0.5`.
    pub fn new() -> Self {
        let (alpha, beta) = (0.2, 0.5);
        Self::new_with_params(alpha, beta)
    }

    /// Returns the slope `alpha`.
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// Returns the offset `beta`.
    pub fn beta(&self) -> f32 {
        self.beta
    }
}
impl Default for Hardsigmoid {
fn default() -> Self {
Self::new()
}
}
impl Module for Hardsigmoid {
    /// Computes `max(min(alpha * x + beta, 1), 0)` element-wise.
    ///
    /// # Errors
    /// Propagates any error from tensor creation or the element-wise operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        let dims = shape.dims();

        let slope = full(dims, self.alpha)?;
        let offset = full(dims, self.beta)?;
        let lower = zeros(dims)?;
        let upper = ones(dims)?;

        // Linear segment first, then clamp from above and from below.
        input
            .mul(&slope)?
            .add(&offset)?
            .minimum(&upper)?
            .maximum(&lower)
    }

    /// Returns a copy of the parameters held by the base module.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports whether the module is in training mode.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Enables training mode by delegating to [`Module::set_training`].
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Enables evaluation mode by delegating to [`Module::set_training`].
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Stores the training flag on the base module.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Forwards the device move to the base module.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the base module's named parameters.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `tensor` is non-empty and every element is within `1e-5` of `expected`.
    fn assert_all_close(tensor: &Tensor, expected: f32) -> Result<()> {
        let values = tensor.to_vec()?;
        assert!(!values.is_empty(), "tensor has no elements");
        for value in values {
            assert!(
                (value - expected).abs() < 1e-5,
                "expected {expected}, got {value}"
            );
        }
        Ok(())
    }

    #[test]
    fn test_softplus_parameters() {
        let softplus = Softplus::new(2.0, 30.0);
        assert_eq!(softplus.beta(), 2.0);
        assert_eq!(softplus.threshold(), 30.0);
        let standard_softplus = Softplus::standard();
        assert_eq!(standard_softplus.beta(), 1.0);
        assert_eq!(standard_softplus.threshold(), 20.0);
    }

    #[test]
    fn test_softplus_forward() -> Result<()> {
        let softplus = Softplus::new(1.0, 20.0);
        let input = randn(&[2, 3])?;
        let output = softplus.forward(&input)?;
        assert_eq!(output.shape(), input.shape());
        // ln(1 + exp(x)) is never negative.
        for value in output.to_vec()? {
            assert!(value >= 0.0, "softplus produced negative value {value}");
        }
        Ok(())
    }

    #[test]
    fn test_softsign_forward() -> Result<()> {
        let softsign = Softsign::new();
        let input = randn(&[2, 3])?;
        let output = softsign.forward(&input)?;
        assert_eq!(output.shape(), input.shape());
        // x / (1 + |x|) lies strictly inside (-1, 1) for finite x.
        for value in output.to_vec()? {
            assert!(value > -1.0 && value < 1.0, "softsign out of range: {value}");
        }
        Ok(())
    }

    #[test]
    fn test_hardsigmoid_parameters() {
        let hardsigmoid = Hardsigmoid::new_with_params(0.1, 0.6);
        assert_eq!(hardsigmoid.alpha(), 0.1);
        assert_eq!(hardsigmoid.beta(), 0.6);
        let standard_hardsigmoid = Hardsigmoid::new();
        assert_eq!(standard_hardsigmoid.alpha(), 0.2);
        assert_eq!(standard_hardsigmoid.beta(), 0.5);
    }

    #[test]
    fn test_hardsigmoid_forward() -> Result<()> {
        let hardsigmoid = Hardsigmoid::new();
        let input = randn(&[2, 3])?;
        let output = hardsigmoid.forward(&input)?;
        assert_eq!(output.shape(), input.shape());
        for value in output.to_vec()? {
            assert!(
                (0.0..=1.0).contains(&value),
                "hardsigmoid out of range: {value}"
            );
        }
        Ok(())
    }

    #[test]
    fn test_softplus_numerical_behavior() -> Result<()> {
        let softplus = Softplus::new(1.0, 20.0);

        // Below the threshold: ln(1 + e) ≈ 1.3132617.
        let small_input = full(&[2, 2], 1.0)?;
        let result = softplus.forward(&small_input)?;
        assert_eq!(result.shape(), small_input.shape());
        assert_all_close(&result, 1.313_261_7)?;

        // Above the threshold: the input is returned unchanged.
        let large_input = full(&[2, 2], 25.0)?;
        let result = softplus.forward(&large_input)?;
        assert_eq!(result.shape(), large_input.shape());
        assert_all_close(&result, 25.0)?;

        // Far above the threshold, where a naive exp() would overflow f32.
        let huge_input = full(&[2, 2], 100.0)?;
        let result = softplus.forward(&huge_input)?;
        assert_all_close(&result, 100.0)?;

        // Non-unit beta: softplus(0) = ln(2) / beta.
        let sharp = Softplus::new(2.0, 20.0);
        let zero_input = full(&[2, 2], 0.0)?;
        let result = sharp.forward(&zero_input)?;
        assert_all_close(&result, core::f32::consts::LN_2 / 2.0)?;
        Ok(())
    }

    #[test]
    fn test_softsign_range() -> Result<()> {
        let softsign = Softsign::new();
        let large_positive = full(&[2, 2], 1000.0)?;
        let result = softsign.forward(&large_positive)?;
        assert_eq!(result.shape(), large_positive.shape());
        assert_all_close(&result, 1000.0 / 1001.0)?;
        let large_negative = full(&[2, 2], -1000.0)?;
        let result = softsign.forward(&large_negative)?;
        assert_eq!(result.shape(), large_negative.shape());
        assert_all_close(&result, -1000.0 / 1001.0)?;
        Ok(())
    }

    #[test]
    fn test_hardsigmoid_range() -> Result<()> {
        let hardsigmoid = Hardsigmoid::new();
        let large_positive = full(&[2, 2], 1000.0)?;
        let result = hardsigmoid.forward(&large_positive)?;
        assert_eq!(result.shape(), large_positive.shape());
        assert_all_close(&result, 1.0)?;
        let large_negative = full(&[2, 2], -1000.0)?;
        let result = hardsigmoid.forward(&large_negative)?;
        assert_eq!(result.shape(), large_negative.shape());
        assert_all_close(&result, 0.0)?;
        // Centre of the linear segment: 0.2 * 0 + 0.5.
        let zero_input = full(&[2, 2], 0.0)?;
        let result = hardsigmoid.forward(&zero_input)?;
        assert_all_close(&result, 0.5)?;
        Ok(())
    }

    #[test]
    fn test_training_mode_toggle() -> Result<()> {
        let mut softplus = Softplus::new(1.0, 20.0);
        assert!(softplus.training());
        softplus.eval();
        assert!(!softplus.training());
        softplus.train();
        assert!(softplus.training());
        Ok(())
    }

    #[test]
    fn test_default_implementations() {
        let softplus = Softplus::default();
        assert_eq!(softplus.beta(), 1.0);
        assert_eq!(softplus.threshold(), 20.0);
        let _softsign = Softsign::default();
        let hardsigmoid = Hardsigmoid::default();
        assert_eq!(hardsigmoid.alpha(), 0.2);
        assert_eq!(hardsigmoid.beta(), 0.5);
    }

    #[test]
    fn test_convenience_constructors() {
        let standard_softplus = Softplus::standard();
        assert_eq!(standard_softplus.beta(), 1.0);
        assert_eq!(standard_softplus.threshold(), 20.0);
        let custom_hardsigmoid = Hardsigmoid::new_with_params(0.1, 0.3);
        assert_eq!(custom_hardsigmoid.alpha(), 0.1);
        assert_eq!(custom_hardsigmoid.beta(), 0.3);
    }
}