use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
#[cfg(feature = "std")]
use std::{collections::HashMap, string::String};
#[cfg(not(feature = "std"))]
use alloc::string::String;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Hard shrinkage activation.
///
/// Element-wise, an input `x` is passed through unchanged when `|x| > lambda`
/// and replaced by zero when `|x| <= lambda`.
pub struct Hardshrink {
    /// Shared module bookkeeping; parameter, training-mode and device calls
    /// are delegated to it.
    base: ModuleBase,
    /// Half-width of the zeroed region around the origin.
    lambd: f32,
}

impl Hardshrink {
    /// Builds a `Hardshrink` whose dead zone is `[-lambd, lambd]`.
    ///
    /// `lambd` is stored as given; no validation is performed.
    pub fn new(lambd: f32) -> Self {
        let base = ModuleBase::new();
        Self { base, lambd }
    }

    /// Returns the configured shrinkage threshold.
    pub fn lambda(&self) -> f32 {
        self.lambd
    }
}

impl Default for Hardshrink {
    /// Equivalent to `Hardshrink::new(0.5)`.
    fn default() -> Self {
        Hardshrink::new(0.5)
    }
}
impl Module for Hardshrink {
    /// Applies hard shrinkage element-wise: `0` where `|x| <= lambda`, `x` elsewhere.
    ///
    /// The zeroed region is selected with `<=` rather than keeping values with
    /// `>`. The two forms agree on every ordinary number. A NaN compares false
    /// against everything, so only this form lets NaN inputs propagate, as
    /// PyTorch's `Hardshrink` does, instead of silently becoming `0`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying tensor operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let abs_input = input.abs()?;
        let lambda_tensor = full(input.shape().dims(), self.lambd)?;
        let zeros_tensor = zeros(input.shape().dims())?;
        // True only inside the dead zone [-lambda, lambda]; false for NaN.
        let in_dead_zone = abs_input.le(&lambda_tensor)?;
        // Assumes `a.where_tensor(&cond, &b)` takes `a` where `cond` is true and
        // `b` elsewhere. The previous `gt`-based version relied on the same reading.
        zeros_tensor.where_tensor(&in_dead_zone, input)
    }

    /// Returns a copy of the parameter map held by the base. This type adds
    /// no parameters of its own.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports the training flag stored in the base.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Switches the module into training mode.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the module into evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Sets the training flag explicitly.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Delegates device placement to the base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the base's named parameters.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Soft shrinkage activation.
///
/// Element-wise: `sign(x) * max(|x| - lambda, 0)`. Magnitudes are reduced by
/// `lambda`, and anything with `|x| <= lambda` becomes zero.
///
/// NOTE(review): `lambd` is not validated. The formula is only a shrinkage
/// for `lambd >= 0`, and PyTorch's `Softshrink` requires a non-negative value.
/// Confirm whether negative thresholds should be rejected.
pub struct Softshrink {
    /// Shared module bookkeeping that the `Module` methods delegate to.
    base: ModuleBase,
    /// Amount subtracted from each magnitude.
    lambd: f32,
}

impl Softshrink {
    /// Builds a `Softshrink` with shrinkage amount `lambd`, stored as given.
    pub fn new(lambd: f32) -> Self {
        let base = ModuleBase::new();
        Self { base, lambd }
    }

    /// Returns the configured shrinkage amount.
    pub fn lambda(&self) -> f32 {
        self.lambd
    }
}

impl Default for Softshrink {
    /// Equivalent to `Softshrink::new(0.5)`.
    fn default() -> Self {
        Softshrink::new(0.5)
    }
}
impl Module for Softshrink {
    /// Computes `sign(x) * max(|x| - lambda, 0)` element-wise.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying tensor operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let magnitude = input.abs()?;
        let shift = full(input.shape().dims(), self.lambd)?;
        let floor = zeros(input.shape().dims())?;
        // Pull each magnitude toward zero by `lambda`, never going below zero.
        let shrunk = magnitude.sub(&shift)?.maximum(&floor)?;
        // Put the original sign back.
        shrunk.mul(&input.sign()?)
    }

    /// Returns a copy of the parameter map held by the base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports the training flag stored in the base.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Enters training mode via [`Module::set_training`].
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Enters evaluation mode via [`Module::set_training`].
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Forwards the training flag to the base.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Delegates device placement to the base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the base's named parameters.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Hard tanh activation: clamps every element into `[min_val, max_val]`.
pub struct Hardtanh {
    /// Shared module bookkeeping that the `Module` methods delegate to.
    base: ModuleBase,
    /// Lower clamp bound.
    min_val: f32,
    /// Upper clamp bound.
    max_val: f32,
}

impl Hardtanh {
    /// Builds a `Hardtanh` that clamps to `[min_val, max_val]`.
    ///
    /// # Panics
    ///
    /// Panics with `"min_val must be less than max_val"` unless
    /// `min_val < max_val`. This also rejects a NaN bound, since comparisons
    /// with NaN are false.
    pub fn new(min_val: f32, max_val: f32) -> Self {
        assert!(min_val < max_val, "min_val must be less than max_val");
        let base = ModuleBase::new();
        Self {
            base,
            min_val,
            max_val,
        }
    }

    /// The conventional `[-1, 1]` clamp.
    pub fn standard() -> Self {
        Self::new(-1.0, 1.0)
    }

    /// Returns the lower clamp bound.
    pub fn min_val(&self) -> f32 {
        self.min_val
    }

    /// Returns the upper clamp bound.
    pub fn max_val(&self) -> f32 {
        self.max_val
    }
}

impl Default for Hardtanh {
    /// Same as [`Hardtanh::standard`], i.e. the range `[-1, 1]`.
    fn default() -> Self {
        Self::standard()
    }
}
impl Module for Hardtanh {
    /// Clamps each element into `[min_val, max_val]`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying tensor operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let lower = full(input.shape().dims(), self.min_val)?;
        let upper = full(input.shape().dims(), self.max_val)?;
        // Cap from above first, then lift anything below the floor.
        input.minimum(&upper)?.maximum(&lower)
    }

    /// Returns a copy of the parameter map held by the base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports the training flag stored in the base.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Enters training mode via [`Module::set_training`].
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Enters evaluation mode via [`Module::set_training`].
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Forwards the training flag to the base.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Delegates device placement to the base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the base's named parameters.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Threshold activation.
///
/// Elements greater than `threshold` pass through unchanged. Elements less
/// than or equal to `threshold` are replaced by `value`.
pub struct Threshold {
    /// Shared module bookkeeping that the `Module` methods delegate to.
    base: ModuleBase,
    /// Cut-off that an element must exceed to be kept.
    threshold: f32,
    /// Replacement for elements at or below the cut-off.
    value: f32,
}

impl Threshold {
    /// Builds a `Threshold` with the given cut-off and replacement value.
    pub fn new(threshold: f32, value: f32) -> Self {
        let base = ModuleBase::new();
        Self {
            base,
            threshold,
            value,
        }
    }

    /// Convenience constructor whose replacement value is `0.0`.
    pub fn zeroing(threshold: f32) -> Self {
        Self::new(threshold, 0.0)
    }

    /// Returns the cut-off.
    pub fn threshold(&self) -> f32 {
        self.threshold
    }

    /// Returns the replacement value.
    pub fn value(&self) -> f32 {
        self.value
    }
}

impl Default for Threshold {
    /// Cut-off `0.0`, replacement `0.0`.
    fn default() -> Self {
        Self::zeroing(0.0)
    }
}
impl Module for Threshold {
    /// Replaces every element `x <= threshold` with `value` and keeps the rest.
    ///
    /// The replaced region is selected with `<=` rather than keeping values
    /// with `>`. The two forms agree on ordinary numbers. A NaN compares false
    /// against everything, so this form passes NaN inputs through, as
    /// PyTorch's `Threshold` does, instead of overwriting them with `value`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying tensor operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let threshold_tensor = full(input.shape().dims(), self.threshold)?;
        let value_tensor = full(input.shape().dims(), self.value)?;
        // True only for elements at or below the cut-off; false for NaN.
        let replace = input.le(&threshold_tensor)?;
        // Assumes `a.where_tensor(&cond, &b)` takes `a` where `cond` is true and
        // `b` elsewhere. The previous `gt`-based version relied on the same reading.
        value_tensor.where_tensor(&replace, input)
    }

    /// Returns a copy of the parameter map held by the base. This type adds
    /// no parameters of its own.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports the training flag stored in the base.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Switches the module into training mode.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the module into evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Sets the training flag explicitly.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Delegates device placement to the base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the base's named parameters.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Tanh shrinkage activation: `x - tanh(x)` element-wise.
///
/// Holds no configuration beyond the shared module base.
pub struct Tanhshrink {
    /// Shared module bookkeeping that the `Module` methods delegate to.
    base: ModuleBase,
}

impl Tanhshrink {
    /// Builds a new `Tanhshrink`.
    pub fn new() -> Self {
        Tanhshrink {
            base: ModuleBase::new(),
        }
    }
}

impl Default for Tanhshrink {
    /// Same as [`Tanhshrink::new`].
    fn default() -> Self {
        Tanhshrink::new()
    }
}
impl Module for Tanhshrink {
    /// Computes `x - tanh(x)` element-wise.
    ///
    /// # Errors
    ///
    /// Propagates any error from `tanh` or the subtraction.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        input.sub(&input.tanh()?)
    }

    /// Returns a copy of the parameter map held by the base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports the training flag stored in the base.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Enters training mode via [`Module::set_training`].
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Enters evaluation mode via [`Module::set_training`].
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Forwards the training flag to the base.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Delegates device placement to the base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the base's named parameters.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `module` a random 2x3 tensor and checks that the output keeps
    /// the input's shape.
    fn assert_preserves_shape<M: Module>(module: &M) -> Result<()> {
        let input = randn(&[2, 3])?;
        let output = module.forward(&input)?;
        assert_eq!(output.shape(), input.shape());
        Ok(())
    }

    #[test]
    fn test_hardshrink_parameters() {
        assert_eq!(Hardshrink::new(0.3).lambda(), 0.3);
        assert_eq!(Hardshrink::default().lambda(), 0.5);
    }

    #[test]
    fn test_hardshrink_forward() -> Result<()> {
        assert_preserves_shape(&Hardshrink::new(0.5))
    }

    #[test]
    fn test_softshrink_parameters() {
        assert_eq!(Softshrink::new(0.3).lambda(), 0.3);
        assert_eq!(Softshrink::default().lambda(), 0.5);
    }

    #[test]
    fn test_softshrink_forward() -> Result<()> {
        assert_preserves_shape(&Softshrink::new(0.5))
    }

    #[test]
    fn test_hardtanh_parameters() {
        let custom = Hardtanh::new(-2.0, 3.0);
        assert_eq!(custom.min_val(), -2.0);
        assert_eq!(custom.max_val(), 3.0);

        let standard = Hardtanh::standard();
        assert_eq!(standard.min_val(), -1.0);
        assert_eq!(standard.max_val(), 1.0);
    }

    #[test]
    #[should_panic(expected = "min_val must be less than max_val")]
    fn test_hardtanh_invalid_range() {
        let _hardtanh = Hardtanh::new(1.0, -1.0);
    }

    #[test]
    fn test_hardtanh_forward() -> Result<()> {
        assert_preserves_shape(&Hardtanh::new(-1.0, 1.0))
    }

    #[test]
    fn test_threshold_parameters() {
        let custom = Threshold::new(0.1, -1.0);
        assert_eq!(custom.threshold(), 0.1);
        assert_eq!(custom.value(), -1.0);

        let zeroing = Threshold::zeroing(0.5);
        assert_eq!(zeroing.threshold(), 0.5);
        assert_eq!(zeroing.value(), 0.0);
    }

    #[test]
    fn test_threshold_forward() -> Result<()> {
        assert_preserves_shape(&Threshold::new(0.0, -1.0))
    }

    #[test]
    fn test_tanhshrink_forward() -> Result<()> {
        assert_preserves_shape(&Tanhshrink::new())
    }

    #[test]
    fn test_training_mode_toggle() -> Result<()> {
        let mut module = Hardshrink::new(0.5);
        // Freshly constructed modules start in training mode.
        assert!(module.training());
        module.eval();
        assert!(!module.training());
        module.train();
        assert!(module.training());
        Ok(())
    }

    #[test]
    fn test_default_implementations() {
        let _modules = (
            Hardshrink::default(),
            Softshrink::default(),
            Hardtanh::default(),
            Threshold::default(),
            Tanhshrink::default(),
        );
    }

    #[test]
    fn test_convenience_constructors() {
        let standard = Hardtanh::standard();
        assert_eq!(standard.min_val(), -1.0);
        assert_eq!(standard.max_val(), 1.0);

        let zeroing = Threshold::zeroing(0.1);
        assert_eq!(zeroing.threshold(), 0.1);
        assert_eq!(zeroing.value(), 0.0);
    }
}