use crate::{Module, ModuleBase, Parameter};
use std::sync::Arc;
use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
#[cfg(feature = "std")]
use std::{collections::HashMap, string::String};
#[cfg(not(feature = "std"))]
use alloc::string::String;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Rectified linear unit activation layer.
///
/// The layer only wraps a [`ModuleBase`]; it does not register any
/// parameters of its own.
pub struct ReLU {
    /// Shared module state (training flag and parameter registry).
    base: ModuleBase,
}

impl ReLU {
    /// Creates a ReLU layer backed by a fresh [`ModuleBase`].
    pub fn new() -> Self {
        let base = ModuleBase::new();
        Self { base }
    }
}

impl Default for ReLU {
    /// Equivalent to [`ReLU::new`].
    fn default() -> Self {
        ReLU::new()
    }
}
impl Module for ReLU {
    /// Returns the element-wise maximum of `input` and a zero tensor of the
    /// same shape.
    ///
    /// # Errors
    /// Propagates any error from creating the zero tensor or from `maximum`.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        let floor = zeros(shape.dims())?;
        input.maximum(&floor)
    }

    /// Reports the training flag stored in the underlying [`ModuleBase`].
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Stores `training` as the new training flag on the underlying base.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Switches the layer to training mode.
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Switches the layer to evaluation mode.
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Returns a copy of the parameter map held by the underlying base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        let registry = &self.base.parameters;
        registry.clone()
    }

    /// Returns the named parameters reported by the underlying base.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }

    /// Forwards the device move to the underlying base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }
}
/// Logistic sigmoid activation layer.
///
/// The layer only wraps a [`ModuleBase`]; it does not register any
/// parameters of its own.
pub struct Sigmoid {
    /// Shared module state (training flag and parameter registry).
    base: ModuleBase,
}

impl Sigmoid {
    /// Creates a sigmoid layer backed by a fresh [`ModuleBase`].
    pub fn new() -> Self {
        let base = ModuleBase::new();
        Self { base }
    }
}

impl Default for Sigmoid {
    /// Equivalent to [`Sigmoid::new`].
    fn default() -> Self {
        Sigmoid::new()
    }
}
impl Module for Sigmoid {
    /// Computes `1 / (1 + exp(-input))` from tensor primitives.
    ///
    /// # Errors
    /// Propagates any error from the intermediate tensor operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        // Denominator: 1 + e^(-x).
        let denominator = input.neg()?.exp()?.add_scalar(1.0)?;
        // Numerator: a tensor of ones with the input's shape.
        let shape = input.shape();
        let numerator = ones(shape.dims())?;
        numerator.div(&denominator)
    }

    /// Reports the training flag stored in the underlying [`ModuleBase`].
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Stores `training` as the new training flag on the underlying base.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Switches the layer to training mode.
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Switches the layer to evaluation mode.
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Returns a copy of the parameter map held by the underlying base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        let registry = &self.base.parameters;
        registry.clone()
    }

    /// Returns the named parameters reported by the underlying base.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }

    /// Forwards the device move to the underlying base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }
}
/// Hyperbolic tangent activation layer.
///
/// The layer only wraps a [`ModuleBase`]; it does not register any
/// parameters of its own.
pub struct Tanh {
    /// Shared module state (training flag and parameter registry).
    base: ModuleBase,
}

impl Tanh {
    /// Creates a tanh layer backed by a fresh [`ModuleBase`].
    pub fn new() -> Self {
        let base = ModuleBase::new();
        Self { base }
    }
}

impl Default for Tanh {
    /// Equivalent to [`Tanh::new`].
    fn default() -> Self {
        Tanh::new()
    }
}
impl Module for Tanh {
    /// Applies the tensor's own `tanh` operation to `input`.
    ///
    /// # Errors
    /// Propagates any error returned by `Tensor::tanh`.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let activated = input.tanh()?;
        Ok(activated)
    }

    /// Reports the training flag stored in the underlying [`ModuleBase`].
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Stores `training` as the new training flag on the underlying base.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Switches the layer to training mode.
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Switches the layer to evaluation mode.
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Returns a copy of the parameter map held by the underlying base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        let registry = &self.base.parameters;
        registry.clone()
    }

    /// Returns the named parameters reported by the underlying base.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }

    /// Forwards the device move to the underlying base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }
}
/// Leaky ReLU activation layer with a fixed slope for negative inputs.
pub struct LeakyReLU {
    /// Shared module state (training flag and parameter registry).
    base: ModuleBase,
    /// Multiplier applied to the input to form the negative branch.
    negative_slope: f64,
}

impl LeakyReLU {
    /// Slope used by [`LeakyReLU::default_slope`] and [`Default`].
    const DEFAULT_NEGATIVE_SLOPE: f64 = 0.01;

    /// Creates a layer that scales negative inputs by `negative_slope`.
    pub fn new(negative_slope: f64) -> Self {
        let base = ModuleBase::new();
        Self {
            base,
            negative_slope,
        }
    }

    /// Creates a layer with a negative slope of `0.01`.
    pub fn default_slope() -> Self {
        Self::new(Self::DEFAULT_NEGATIVE_SLOPE)
    }
}

impl Default for LeakyReLU {
    /// Creates a layer with a negative slope of `0.01`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_NEGATIVE_SLOPE)
    }
}
impl Module for LeakyReLU {
    /// Applies leaky ReLU: `x` where `x >= 0`, `negative_slope * x` elsewhere.
    ///
    /// The result is built from an element-wise comparison between `x` and
    /// `negative_slope * x`:
    ///
    /// * for `negative_slope <= 1`, `max(x, slope * x)` selects `x` on the
    ///   non-negative side and `slope * x` on the negative side;
    /// * for `negative_slope > 1` the ordering of the two candidates flips,
    ///   so `min(x, slope * x)` is the one that selects the correct branch.
    ///
    /// The slope is narrowed to `f32` before the multiplication.
    ///
    /// # Errors
    /// Propagates any error from the underlying tensor operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let scaled = input.mul_scalar(self.negative_slope as f32)?;
        if self.negative_slope > 1.0 {
            // With a slope above one, `slope * x` is the larger value for
            // positive x and the smaller one for negative x, so the minimum
            // picks the right branch everywhere.
            input.minimum(&scaled)
        } else {
            input.maximum(&scaled)
        }
    }

    /// Returns a copy of the parameter map held by the underlying base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Reports the training flag stored in the underlying [`ModuleBase`].
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Switches the layer to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the layer to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Stores `training` as the new training flag on the underlying base.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Forwards the device move to the underlying base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    /// Returns the named parameters reported by the underlying base.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// ReLU activation layer whose output is additionally capped at 6.
///
/// The layer only wraps a [`ModuleBase`]; it does not register any
/// parameters of its own.
pub struct ReLU6 {
    /// Shared module state (training flag and parameter registry).
    base: ModuleBase,
}

impl ReLU6 {
    /// Creates a ReLU6 layer backed by a fresh [`ModuleBase`].
    pub fn new() -> Self {
        let base = ModuleBase::new();
        Self { base }
    }
}

impl Default for ReLU6 {
    /// Equivalent to [`ReLU6::new`].
    fn default() -> Self {
        ReLU6::new()
    }
}
impl Module for ReLU6 {
    /// Computes `min(max(input, 0), 6)` element-wise.
    ///
    /// # Errors
    /// Propagates any error from creating the bound tensors or from the
    /// `maximum` / `minimum` operations.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let shape = input.shape();
        // Bound tensors share the input's shape.
        let lower = zeros(shape.dims())?;
        let upper = full(shape.dims(), 6.0)?;
        input.maximum(&lower)?.minimum(&upper)
    }

    /// Reports the training flag stored in the underlying [`ModuleBase`].
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Stores `training` as the new training flag on the underlying base.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Switches the layer to training mode.
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Switches the layer to evaluation mode.
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Returns a copy of the parameter map held by the underlying base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        let registry = &self.base.parameters;
        registry.clone()
    }

    /// Returns the named parameters reported by the underlying base.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }

    /// Forwards the device move to the underlying base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }
}
/// Parametric ReLU activation layer with a learnable negative slope `alpha`.
pub struct PReLU {
    /// Shared module state; `alpha` is also registered here under `"alpha"`.
    base: ModuleBase,
    /// Learnable slope, stored as a 1-D tensor of length `num_parameters`.
    alpha: Parameter,
    /// Number of slope values held in `alpha`.
    num_parameters: usize,
}

impl PReLU {
    /// Creates a PReLU layer whose `alpha` holds `num_parameters` values,
    /// each initialised to `0.25`.
    ///
    /// # Errors
    /// Propagates any error from creating the initial `alpha` tensor.
    pub fn new(num_parameters: usize) -> Result<Self> {
        // A single shared slope and a multi-value slope both use a 1-D
        // tensor of length `num_parameters`.
        let alpha_shape = vec![num_parameters];
        let alpha = Parameter::new(full(&alpha_shape, 0.25)?);

        let mut base = ModuleBase::new();
        base.register_parameter(String::from("alpha"), alpha.clone());

        let layer = Self {
            base,
            alpha,
            num_parameters,
        };
        Ok(layer)
    }

    /// Creates a PReLU layer with a single slope value.
    pub fn single_parameter() -> Result<Self> {
        Self::new(1)
    }

    /// Returns the shared handle to the tensor backing `alpha`.
    pub fn alpha(&self) -> Arc<parking_lot::RwLock<Tensor>> {
        self.alpha.tensor()
    }
}
impl Module for PReLU {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let alpha_expanded = if self.num_parameters == 1 {
self.alpha
.tensor()
.read()
.broadcast_to(input.shape().dims())?
} else {
let mut alpha_shape = vec![1i32; input.shape().ndim()];
alpha_shape[1] = self.num_parameters as i32; self.alpha
.tensor()
.read()
.reshape(&alpha_shape)?
.broadcast_to(input.shape().dims())?
};
let negative_part = input.mul(&alpha_expanded)?;
input.maximum(&negative_part)
}
fn parameters(&self) -> HashMap<String, Parameter> {
let mut params = self.base.parameters.clone();
params.insert("alpha".to_string(), self.alpha.clone());
params
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)?;
self.alpha.to_device(device)?;
Ok(())
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = self.base.named_parameters();
params.insert("alpha".to_string(), self.alpha.clone());
params
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// ReLU output has the same shape as a hand-built expected tensor.
    #[test]
    fn test_relu_forward() -> Result<()> {
        let layer = ReLU::new();
        let x = tensor_1d(&[-2.0, -1.0, 0.0, 1.0, 2.0])?;
        let y = layer.forward(&x)?;
        let expected = tensor_1d(&[0.0, 0.0, 0.0, 1.0, 2.0])?;
        // Only the shapes are compared here.
        assert_eq!(y.shape(), expected.shape());
        Ok(())
    }

    /// Sigmoid keeps the shape of a 3-D random input.
    #[test]
    fn test_sigmoid_shape() -> Result<()> {
        let layer = Sigmoid::new();
        let x = randn(&[2, 3, 4])?;
        let y = layer.forward(&x)?;
        assert_eq!(y.shape(), x.shape());
        Ok(())
    }

    /// Tanh keeps the shape of a 2-D random input.
    #[test]
    fn test_tanh_shape() -> Result<()> {
        let layer = Tanh::new();
        let x = randn(&[5, 10])?;
        let y = layer.forward(&x)?;
        assert_eq!(y.shape(), x.shape());
        Ok(())
    }

    /// Leaky ReLU with slope 0.1 keeps the input shape.
    #[test]
    fn test_leaky_relu_forward() -> Result<()> {
        let layer = LeakyReLU::new(0.1);
        let x = tensor_1d(&[-2.0, -1.0, 0.0, 1.0, 2.0])?;
        let y = layer.forward(&x)?;
        assert_eq!(y.shape(), x.shape());
        Ok(())
    }

    /// ReLU6 keeps the input shape.
    #[test]
    fn test_relu6_forward() -> Result<()> {
        let layer = ReLU6::new();
        let x = tensor_1d(&[-2.0, 0.0, 3.0, 8.0])?;
        let y = layer.forward(&x)?;
        assert_eq!(y.shape(), x.shape());
        Ok(())
    }

    /// Both PReLU constructors record the requested parameter count.
    #[test]
    fn test_prelu_creation() -> Result<()> {
        let per_channel = PReLU::new(64)?;
        assert_eq!(per_channel.num_parameters, 64);

        let shared = PReLU::single_parameter()?;
        assert_eq!(shared.num_parameters, 1);
        Ok(())
    }

    /// A new layer starts in training mode and toggles with `eval` / `train`.
    #[test]
    fn test_training_mode_toggle() -> Result<()> {
        let mut layer = ReLU::new();
        assert!(layer.training());

        layer.eval();
        assert!(!layer.training());

        layer.train();
        assert!(layer.training());
        Ok(())
    }

    /// Every activation with a `Default` impl can be constructed through it.
    #[test]
    fn test_default_implementations() {
        let _relu = ReLU::default();
        let _sigmoid = Sigmoid::default();
        let _tanh = Tanh::default();
        let _leaky_relu = LeakyReLU::default();
        let _relu6 = ReLU6::default();
    }
}