use ferrotorch_core::grad_fns::activation as act;
use ferrotorch_core::grad_fns::arithmetic;
use ferrotorch_core::ops::elementwise::unary_map;
use ferrotorch_core::{normalize_axis, Float, FerrotorchError, FerrotorchResult, Tensor};
use crate::module::Module;
use crate::parameter::Parameter;
/// Implements [`Module`] for a parameter-free activation struct.
///
/// The target type must provide an inherent generic `forward` method and a
/// `training: bool` field. The generated impl forwards `Module::forward` to the
/// inherent method, reports no parameters, and toggles `training` for
/// `train`/`eval`.
macro_rules! impl_activation_module {
    ($ty:ident) => {
        impl<T: Float> Module<T> for $ty {
            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
                // Inherent methods take precedence in method resolution, so this
                // calls the type's own `forward`, not this trait method.
                self.forward(input)
            }

            fn is_training(&self) -> bool {
                self.training
            }

            fn train(&mut self) {
                self.training = true;
            }

            fn eval(&mut self) {
                self.training = false;
            }

            // Types using this macro carry no learnable parameters.
            fn parameters(&self) -> Vec<&Parameter<T>> {
                Vec::new()
            }

            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
                Vec::new()
            }

            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
                Vec::new()
            }
        }
    };
}
/// Element-wise rectified linear unit, `max(0, x)`.
///
/// Holds no state besides the train/eval flag.
#[derive(Debug, Clone)]
pub struct ReLU {
    training: bool,
}

impl ReLU {
    /// Builds a `ReLU` in training mode (equivalent to [`ReLU::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Delegates to `grad_fns::activation::relu`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        act::relu(input)
    }
}

impl Default for ReLU {
    fn default() -> Self {
        ReLU { training: true }
    }
}

impl_activation_module!(ReLU);
pub use act::GeluApproximate;
/// GELU activation.
///
/// The approximation mode is chosen at construction time and handed to
/// `act::gelu_with` on every forward call.
#[derive(Debug, Clone)]
pub struct GELU {
    approximate: GeluApproximate,
    training: bool,
}

impl GELU {
    /// Builds a `GELU` using `GeluApproximate::default()`, in training mode.
    pub fn new() -> Self {
        Self::with_approximate(GeluApproximate::default())
    }

    /// Builds a `GELU` with an explicit approximation mode, in training mode.
    pub fn with_approximate(approximate: GeluApproximate) -> Self {
        GELU {
            training: true,
            approximate,
        }
    }

    /// Delegates to `act::gelu_with` with the configured approximation.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let mode = self.approximate;
        act::gelu_with(input, mode)
    }
}

impl Default for GELU {
    fn default() -> Self {
        GELU::new()
    }
}

impl_activation_module!(GELU);
/// SiLU activation, computed by `act::silu`.
#[derive(Debug, Clone)]
pub struct SiLU {
    training: bool,
}

impl Default for SiLU {
    fn default() -> Self {
        SiLU { training: true }
    }
}

impl SiLU {
    /// Builds a `SiLU` in training mode.
    pub fn new() -> Self {
        SiLU::default()
    }

    /// Delegates to `grad_fns::activation::silu`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        act::silu(input)
    }
}

impl_activation_module!(SiLU);
/// Logistic sigmoid activation, computed by `act::sigmoid`.
#[derive(Debug, Clone)]
pub struct Sigmoid {
    training: bool,
}

impl Default for Sigmoid {
    fn default() -> Self {
        Sigmoid { training: true }
    }
}

impl Sigmoid {
    /// Builds a `Sigmoid` in training mode.
    pub fn new() -> Self {
        Sigmoid::default()
    }

    /// Delegates to `grad_fns::activation::sigmoid`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        act::sigmoid(input)
    }
}

impl_activation_module!(Sigmoid);
/// Hyperbolic tangent activation, computed by `act::tanh`.
#[derive(Debug, Clone)]
pub struct Tanh {
    training: bool,
}

impl Default for Tanh {
    fn default() -> Self {
        Tanh { training: true }
    }
}

impl Tanh {
    /// Builds a `Tanh` in training mode.
    pub fn new() -> Self {
        Tanh::default()
    }

    /// Delegates to `grad_fns::activation::tanh`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        act::tanh(input)
    }
}

impl_activation_module!(Tanh);
/// Softmax over the axis `dim`.
///
/// Only the last axis is currently supported; `dim` may be given in negative
/// (from-the-end) form and is resolved with `normalize_axis`.
#[derive(Debug, Clone)]
pub struct Softmax {
    /// Axis to normalise over (negative values count from the end).
    pub dim: isize,
    training: bool,
}

impl Softmax {
    /// Builds a `Softmax` over `dim`, in training mode.
    pub fn new(dim: isize) -> Self {
        Softmax {
            training: true,
            dim,
        }
    }

    /// Applies softmax.
    ///
    /// # Errors
    /// Propagates `normalize_axis` errors for an out-of-range `dim`, and returns
    /// `InvalidArgument` if `dim` does not resolve to the last axis.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let ndim = input.ndim();
        // A 0-D tensor has no axis to pick, so `dim` is not validated for it.
        if ndim > 0 {
            let axis = normalize_axis(self.dim, ndim)?;
            let last_axis = ndim - 1;
            if axis != last_axis {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "Softmax currently only supports dim=-1 (last axis), \
                         but got dim={} (axis={}) for a {}-D tensor",
                        self.dim, axis, ndim,
                    ),
                });
            }
        }
        act::softmax(input)
    }
}

impl Default for Softmax {
    fn default() -> Self {
        Softmax::new(-1)
    }
}

impl_activation_module!(Softmax);
/// Log-softmax over the axis `dim`.
///
/// Only the last axis is currently supported; `dim` may be given in negative
/// (from-the-end) form and is resolved with `normalize_axis`.
#[derive(Debug, Clone)]
pub struct LogSoftmax {
    /// Axis to normalise over (negative values count from the end).
    pub dim: isize,
    training: bool,
}

impl LogSoftmax {
    /// Builds a `LogSoftmax` over `dim`, in training mode.
    pub fn new(dim: isize) -> Self {
        LogSoftmax {
            training: true,
            dim,
        }
    }

    /// Applies log-softmax.
    ///
    /// # Errors
    /// Propagates `normalize_axis` errors for an out-of-range `dim`, and returns
    /// `InvalidArgument` if `dim` does not resolve to the last axis.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let ndim = input.ndim();
        // 0-D input: nothing to validate, hand it straight to the op.
        if ndim == 0 {
            return act::log_softmax(input);
        }
        match normalize_axis(self.dim, ndim)? {
            axis if axis == ndim - 1 => act::log_softmax(input),
            axis => Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "LogSoftmax currently only supports dim=-1 (last axis), \
                     but got dim={} (axis={}) for a {}-D tensor",
                    self.dim, axis, ndim,
                ),
            }),
        }
    }
}

impl Default for LogSoftmax {
    fn default() -> Self {
        LogSoftmax::new(-1)
    }
}

impl_activation_module!(LogSoftmax);
/// Leaky ReLU: `x` for positive inputs, `negative_slope * x` otherwise.
#[derive(Debug, Clone)]
pub struct LeakyReLU {
    /// Multiplier applied to negative inputs.
    pub negative_slope: f64,
    training: bool,
}

impl LeakyReLU {
    /// Builds a `LeakyReLU` with the given slope, in training mode.
    pub fn new(negative_slope: f64) -> Self {
        LeakyReLU {
            training: true,
            negative_slope,
        }
    }

    /// Applies leaky ReLU.
    ///
    /// Computed as `(1 - s) * relu(x) + s * x`, composed from the `grad_fns`
    /// ops. For `x > 0` this is `x`; for `x <= 0` it is `s * x`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let s = self.negative_slope;
        // Degenerate slopes: 0 is plain ReLU, 1 is the identity.
        if s.abs() < f64::EPSILON {
            return act::relu(input);
        }
        if (s - 1.0).abs() < f64::EPSILON {
            return Ok(input.clone());
        }
        let rectified = act::relu(input)?;
        let (relu_weight, linear_weight) = (
            T::from(1.0 - s).unwrap(),
            T::from(s).unwrap(),
        );
        let relu_weight = ferrotorch_core::scalar(relu_weight)?;
        let linear_weight = ferrotorch_core::scalar(linear_weight)?;
        let relu_term = arithmetic::mul(&rectified, &relu_weight)?;
        let linear_term = arithmetic::mul(input, &linear_weight)?;
        arithmetic::add(&relu_term, &linear_term)
    }
}

impl Default for LeakyReLU {
    fn default() -> Self {
        LeakyReLU::new(0.01)
    }
}

impl_activation_module!(LeakyReLU);
/// Exponential linear unit: `x` for positive inputs, `alpha * (exp(x) - 1)`
/// otherwise. Computed by `act::elu`.
#[derive(Debug, Clone)]
pub struct ELU {
    /// Scale of the negative branch.
    pub alpha: f64,
    training: bool,
}

impl ELU {
    /// Builds an `ELU` with the given `alpha`, in training mode.
    pub fn new(alpha: f64) -> Self {
        ELU {
            training: true,
            alpha,
        }
    }

    /// Delegates to `grad_fns::activation::elu`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let alpha = self.alpha;
        act::elu(input, alpha)
    }
}

impl Default for ELU {
    fn default() -> Self {
        ELU::new(1.0)
    }
}

impl_activation_module!(ELU);
/// Mish activation, `x * tanh(softplus(x))`, computed by `act::mish`.
#[derive(Debug, Clone)]
pub struct Mish {
    training: bool,
}

impl Default for Mish {
    fn default() -> Self {
        Mish { training: true }
    }
}

impl Mish {
    /// Builds a `Mish` in training mode.
    pub fn new() -> Self {
        Mish::default()
    }

    /// Delegates to `grad_fns::activation::mish`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        act::mish(input)
    }
}

impl_activation_module!(Mish);
/// Parametric ReLU: `max(0, x) + alpha * min(0, x)` with a learnable slope.
///
/// A single `alpha` value (a `[1]`-shaped parameter) is shared by every element.
#[derive(Debug, Clone)]
pub struct PReLU<T: Float> {
    /// Learnable slope applied to negative inputs.
    pub alpha: Parameter<T>,
    training: bool,
}

impl<T: Float> PReLU<T> {
    /// Creates a `PReLU` with `alpha` initialised to `init_alpha`, in training mode.
    ///
    /// # Errors
    /// Propagates any failure from building the 1-element alpha tensor.
    pub fn new(init_alpha: f64) -> FerrotorchResult<Self> {
        let alpha_val = T::from(init_alpha).unwrap();
        let alpha_tensor = ferrotorch_core::from_slice(&[alpha_val], &[1])?;
        Ok(Self {
            alpha: Parameter::new(alpha_tensor),
            training: true,
        })
    }

    /// Applies PReLU as `relu(x) + alpha * (x - relu(x))`.
    ///
    /// The parameter tensor itself is an operand of the `mul`, so the backward
    /// pass can deliver a gradient to `alpha`. Copying alpha's value into a
    /// fresh constant tensor would detach it and leave it untrainable.
    ///
    /// NOTE(review): this relies on `arithmetic::mul` broadcasting alpha's `[1]`
    /// shape against `input` and reducing its gradient back to `[1]`. It also
    /// presumably requires alpha to live on the same device as `input`; confirm.
    ///
    /// # Errors
    /// Propagates errors from the underlying tensor ops.
    pub fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let positive = act::relu(input)?;
        // min(0, x) = x - relu(x), written as x + (-1) * relu(x) so only the
        // `add`/`mul` ops already used in this file are needed.
        let minus_one = ferrotorch_core::scalar(T::from(-1.0).unwrap())?;
        let non_positive = arithmetic::add(input, &arithmetic::mul(&positive, &minus_one)?)?;
        let scaled_negative = arithmetic::mul(&non_positive, self.alpha.tensor())?;
        arithmetic::add(&positive, &scaled_negative)
    }
}
impl<T: Float> Module<T> for PReLU<T> {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        // Resolves to the inherent `PReLU::forward`, which takes precedence.
        self.forward(input)
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    /// The single learnable slope.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        vec![&self.alpha]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        vec![&mut self.alpha]
    }

    /// Exposes the slope under the name `"alpha"`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        vec![(String::from("alpha"), &self.alpha)]
    }
}
#[derive(Debug, Clone)]
pub struct CELU {
pub alpha: f64,
training: bool,
}
impl CELU {
pub fn new(alpha: f64) -> Self {
Self {
alpha,
training: true,
}
}
pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let zero = <T as num_traits::Zero>::zero();
let one = <T as num_traits::One>::one();
let alpha = T::from(self.alpha).unwrap();
unary_map(input, |x| {
let pos = if x > zero { x } else { zero };
let neg = if x < zero {
alpha * ((x / alpha).exp() - one)
} else {
zero
};
pos + neg
})
}
}
impl Default for CELU {
fn default() -> Self {
Self::new(1.0)
}
}
impl_activation_module!(CELU);
#[derive(Debug, Clone)]
pub struct SELU {
training: bool,
}
const SELU_ALPHA: f64 = 1.6732632423543772;
const SELU_LAMBDA: f64 = 1.0507009873554805;
impl SELU {
pub fn new() -> Self {
Self { training: true }
}
pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let zero = <T as num_traits::Zero>::zero();
let one = <T as num_traits::One>::one();
let alpha = T::from(SELU_ALPHA).unwrap();
let lambda = T::from(SELU_LAMBDA).unwrap();
unary_map(input, |x| {
if x > zero {
lambda * x
} else {
lambda * alpha * (x.exp() - one)
}
})
}
}
impl Default for SELU {
fn default() -> Self {
Self::new()
}
}
impl_activation_module!(SELU);
/// Piecewise-linear sigmoid: `(x + 3) / 6` clamped to `[0, 1]`.
#[derive(Debug, Clone)]
pub struct HardSigmoid {
    training: bool,
}

impl Default for HardSigmoid {
    fn default() -> Self {
        HardSigmoid { training: true }
    }
}

impl HardSigmoid {
    /// Builds a `HardSigmoid` in training mode.
    pub fn new() -> Self {
        HardSigmoid::default()
    }

    /// Applies hard sigmoid element-wise.
    ///
    /// NOTE(review): `unary_map` receives an opaque closure with no derivative,
    /// so this presumably does not record a backward pass. Confirm before
    /// using it in a trainable path.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let lower = <T as num_traits::Zero>::zero();
        let upper = <T as num_traits::One>::one();
        let offset = T::from(3.0).unwrap();
        let divisor = T::from(6.0).unwrap();
        unary_map(input, |x| {
            let ramp = (x + offset) / divisor;
            if ramp < lower {
                lower
            } else if ramp > upper {
                upper
            } else {
                ramp
            }
        })
    }
}

impl_activation_module!(HardSigmoid);
/// Hard swish: `x * hard_sigmoid(x)`, where `hard_sigmoid(x)` is
/// `(x + 3) / 6` clamped to `[0, 1]`.
#[derive(Debug, Clone)]
pub struct HardSwish {
    training: bool,
}

impl Default for HardSwish {
    fn default() -> Self {
        HardSwish { training: true }
    }
}

impl HardSwish {
    /// Builds a `HardSwish` in training mode.
    pub fn new() -> Self {
        HardSwish::default()
    }

    /// Applies hard swish element-wise.
    ///
    /// NOTE(review): `unary_map` receives an opaque closure with no derivative,
    /// so this presumably does not record a backward pass. Confirm before
    /// using it in a trainable path.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let lower = <T as num_traits::Zero>::zero();
        let upper = <T as num_traits::One>::one();
        let offset = T::from(3.0).unwrap();
        let divisor = T::from(6.0).unwrap();
        unary_map(input, |x| {
            let ramp = (x + offset) / divisor;
            let gate = if ramp < lower {
                lower
            } else if ramp > upper {
                upper
            } else {
                ramp
            };
            x * gate
        })
    }
}

impl_activation_module!(HardSwish);
/// Softplus activation, `ln(1 + exp(beta * x)) / beta`, computed by
/// `act::softplus`.
#[derive(Debug, Clone)]
pub struct Softplus {
    /// Sharpness of the curve.
    pub beta: f64,
    /// Passed through to `act::softplus`; presumably the point above which the
    /// op switches to its linear form (as in PyTorch). Defaults to 20.0.
    pub threshold: f64,
    training: bool,
}

impl Softplus {
    /// Builds a `Softplus` with the given `beta`, `threshold = 20.0`, in training mode.
    pub fn new(beta: f64) -> Self {
        Softplus {
            training: true,
            threshold: 20.0,
            beta,
        }
    }

    /// Delegates to `grad_fns::activation::softplus`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let (beta, threshold) = (self.beta, self.threshold);
        act::softplus(input, beta, threshold)
    }
}

impl Default for Softplus {
    fn default() -> Self {
        Softplus::new(1.0)
    }
}

impl_activation_module!(Softplus);
/// Gated linear unit over the last dimension.
///
/// The last axis (which must have even length `2h`) is split into a value half
/// `a = x[..., :h]` and a gate half `b = x[..., h:]`. The output is
/// `a * sigmoid(b)`, with last dimension `h`.
#[derive(Debug, Clone)]
pub struct GLU {
    training: bool,
}

impl GLU {
    /// Builds a `GLU` in training mode.
    pub fn new() -> Self {
        Self { training: true }
    }

    /// Applies GLU.
    ///
    /// The computation runs on a host copy of the data (`data_vec`), and the
    /// result is built with `requires_grad = false`, so it is not connected to
    /// the autograd graph. It is moved back to the input's device when that
    /// device is CUDA.
    ///
    /// # Errors
    /// Returns `InvalidArgument` for 0-D input or an odd last dimension, and
    /// propagates errors from reading the data, building the output, or the
    /// device transfer.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let shape = input.shape();
        let ndim = shape.len();
        if ndim == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "GLU requires at least 1D input".to_string(),
            });
        }
        let last_dim = shape[ndim - 1];
        if last_dim % 2 != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "GLU requires the last dimension to be even, got {}",
                    last_dim
                ),
            });
        }
        let half = last_dim / 2;
        let device = input.device();
        let data = input.data_vec()?;
        // Number of rows along the last axis. For 1-D input the empty product
        // is already 1. A zero-sized leading dimension must stay 0: clamping
        // it to 1 would index past the end of the empty buffer.
        let outer_size: usize = shape[..ndim - 1].iter().product();
        let one = <T as num_traits::One>::one();
        let mut result = Vec::with_capacity(outer_size * half);
        // `half == 0` means an empty last axis (and no data); `chunks_exact(0)` would panic.
        if half > 0 {
            for row in data.chunks_exact(last_dim) {
                let (values, gates) = row.split_at(half);
                result.extend(
                    values
                        .iter()
                        .zip(gates)
                        .map(|(&a, &b)| a * (one / (one + (-b).exp()))),
                );
            }
        }
        let mut out_shape = shape.to_vec();
        out_shape[ndim - 1] = half;
        let out = Tensor::from_storage(
            ferrotorch_core::TensorStorage::cpu(result),
            out_shape,
            false,
        )?;
        if device.is_cuda() {
            out.to(device)
        } else {
            Ok(out)
        }
    }
}

impl Default for GLU {
    fn default() -> Self {
        Self::new()
    }
}

impl_activation_module!(GLU);
// Unit tests for the activation modules: forward values, `Module` trait
// behaviour, defaults, `Send`/`Sync`, and backward passes for Softplus, ELU
// and Mish checked against central finite differences.
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;
    /// 1-D f64 CPU tensor with `requires_grad = false`.
    fn t(data: &[f64]) -> Tensor<f64> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![data.len()], false).unwrap()
    }
    /// 2-D `[rows, cols]` f64 CPU tensor with `requires_grad = false`.
    fn t2d(data: &[f64], rows: usize, cols: usize) -> Tensor<f64> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![rows, cols], false).unwrap()
    }
    /// Checks the shared contract of parameter-free modules: no parameters of
    /// any kind, starts in training mode, and `eval()`/`train()` toggle it.
    fn assert_zero_param_module<M, T: Float>(module: &mut M)
    where
        M: Module<T>,
    {
        assert!(module.parameters().is_empty(), "should have no parameters");
        assert!(
            module.parameters_mut().is_empty(),
            "should have no mutable parameters"
        );
        assert!(
            module.named_parameters().is_empty(),
            "should have no named parameters"
        );
        assert!(module.is_training(), "default should be training mode");
        module.eval();
        assert!(!module.is_training(), "eval() should set training=false");
        module.train();
        assert!(module.is_training(), "train() should set training=true");
    }
    #[test]
    fn test_relu_forward() {
        let m = ReLU::new();
        let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[0] - 0.0).abs() < 1e-7);
        assert!((d[1] - 0.0).abs() < 1e-7);
        assert!((d[2] - 0.0).abs() < 1e-7);
        assert!((d[3] - 1.0).abs() < 1e-7);
        assert!((d[4] - 2.0).abs() < 1e-7);
    }
    #[test]
    fn test_relu_module_trait() {
        let mut m = ReLU::new();
        assert_zero_param_module::<ReLU, f64>(&mut m);
    }
    #[test]
    fn test_gelu_forward() {
        let m = GELU::new();
        let x = t(&[0.0]);
        let y = m.forward(&x).unwrap();
        assert!(y.data().unwrap()[0].abs() < 1e-7);
        let x = t(&[1.0, 2.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!(d[0] > 0.0);
        assert!(d[1] > 0.0);
        // For large positive inputs GELU approaches the identity.
        let x = t(&[10.0]);
        let y = m.forward(&x).unwrap();
        assert!((y.data().unwrap()[0] - 10.0).abs() < 0.01);
    }
    #[test]
    fn test_gelu_module_trait() {
        let mut m = GELU::new();
        assert_zero_param_module::<GELU, f64>(&mut m);
    }
    #[test]
    fn test_silu_forward() {
        let m = SiLU::new();
        let x = t(&[0.0]);
        let y = m.forward(&x).unwrap();
        assert!(y.data().unwrap()[0].abs() < 1e-7);
        let x = t(&[10.0]);
        let y = m.forward(&x).unwrap();
        assert!((y.data().unwrap()[0] - 10.0).abs() < 0.01);
    }
    #[test]
    fn test_silu_module_trait() {
        let mut m = SiLU::new();
        assert_zero_param_module::<SiLU, f64>(&mut m);
    }
    #[test]
    fn test_sigmoid_forward() {
        let m = Sigmoid::new();
        let x = t(&[0.0]);
        let y = m.forward(&x).unwrap();
        assert!((y.data().unwrap()[0] - 0.5).abs() < 1e-7);
        // Saturation at both extremes.
        let x = t(&[-100.0, 100.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!(d[0] < 1e-10, "sigmoid(-100) should be ~0");
        assert!((d[1] - 1.0).abs() < 1e-10, "sigmoid(100) should be ~1");
    }
    #[test]
    fn test_sigmoid_module_trait() {
        let mut m = Sigmoid::new();
        assert_zero_param_module::<Sigmoid, f64>(&mut m);
    }
    #[test]
    fn test_tanh_forward() {
        let m = Tanh::new();
        let x = t(&[0.0]);
        let y = m.forward(&x).unwrap();
        assert!(y.data().unwrap()[0].abs() < 1e-7);
        let x = t(&[-100.0, 100.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[0] + 1.0).abs() < 1e-10, "tanh(-100) should be ~-1");
        assert!((d[1] - 1.0).abs() < 1e-10, "tanh(100) should be ~1");
    }
    #[test]
    fn test_tanh_module_trait() {
        let mut m = Tanh::new();
        assert_zero_param_module::<Tanh, f64>(&mut m);
    }
    #[test]
    fn test_softmax_forward_1d() {
        let m = Softmax::new(-1);
        let x = t(&[1.0, 2.0, 3.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        let total: f64 = d.iter().sum();
        assert!((total - 1.0).abs() < 1e-7);
        assert!(d[0] < d[1]);
        assert!(d[1] < d[2]);
    }
    #[test]
    fn test_softmax_forward_2d() {
        let m = Softmax::new(-1);
        let x = t2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        // Each row must sum to 1 independently.
        let row0_sum = d[0] + d[1];
        let row1_sum = d[2] + d[3];
        assert!((row0_sum - 1.0).abs() < 1e-7);
        assert!((row1_sum - 1.0).abs() < 1e-7);
    }
    #[test]
    fn test_softmax_wrong_dim() {
        // dim=0 on a 2-D tensor is not the last axis and must be rejected.
        let m = Softmax::new(0);
        let x = t2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        assert!(m.forward(&x).is_err());
    }
    #[test]
    fn test_softmax_module_trait() {
        let mut m = Softmax::new(-1);
        assert_zero_param_module::<Softmax, f64>(&mut m);
    }
    #[test]
    fn test_log_softmax_forward_1d() {
        let m = LogSoftmax::new(-1);
        let x = t(&[1.0, 2.0, 3.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        let total: f64 = d.iter().map(|&v| v.exp()).sum();
        assert!(
            (total - 1.0).abs() < 1e-7,
            "exp(log_softmax) sum = {total}"
        );
        assert!(d.iter().all(|&v| v <= 0.0));
    }
    #[test]
    fn test_log_softmax_module_trait() {
        let mut m = LogSoftmax::new(-1);
        assert_zero_param_module::<LogSoftmax, f64>(&mut m);
    }
    #[test]
    fn test_leaky_relu_forward() {
        let m = LeakyReLU::new(0.01);
        let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[0] - (-0.02)).abs() < 1e-7, "LeakyReLU(-2) = {}", d[0]);
        assert!((d[1] - (-0.01)).abs() < 1e-7, "LeakyReLU(-1) = {}", d[1]);
        assert!((d[2] - 0.0).abs() < 1e-7, "LeakyReLU(0) = {}", d[2]);
        assert!((d[3] - 1.0).abs() < 1e-7, "LeakyReLU(1) = {}", d[3]);
        assert!((d[4] - 2.0).abs() < 1e-7, "LeakyReLU(2) = {}", d[4]);
    }
    #[test]
    fn test_leaky_relu_large_slope() {
        let m = LeakyReLU::new(0.2);
        let x = t(&[-5.0, 3.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[0] - (-1.0)).abs() < 1e-7, "LeakyReLU(-5, slope=0.2) = {}", d[0]);
        assert!((d[1] - 3.0).abs() < 1e-7, "LeakyReLU(3, slope=0.2) = {}", d[1]);
    }
    #[test]
    fn test_leaky_relu_zero_slope_is_relu() {
        let m = LeakyReLU::new(0.0);
        let x = t(&[-2.0, 0.0, 3.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[0] - 0.0).abs() < 1e-7);
        assert!((d[1] - 0.0).abs() < 1e-7);
        assert!((d[2] - 3.0).abs() < 1e-7);
    }
    #[test]
    fn test_leaky_relu_module_trait() {
        let mut m = LeakyReLU::new(0.01);
        assert_zero_param_module::<LeakyReLU, f64>(&mut m);
    }
    #[test]
    fn test_elu_forward() {
        let m = ELU::new(1.0);
        let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[3] - 1.0).abs() < 1e-7);
        assert!((d[4] - 2.0).abs() < 1e-7);
        assert!((d[2] - 0.0).abs() < 1e-7);
        let expected_m1 = 1.0 * ((-1.0_f64).exp() - 1.0);
        assert!(
            (d[1] - expected_m1).abs() < 1e-7,
            "ELU(-1) expected {}, got {}",
            expected_m1,
            d[1]
        );
        let expected_m2 = 1.0 * ((-2.0_f64).exp() - 1.0);
        assert!(
            (d[0] - expected_m2).abs() < 1e-7,
            "ELU(-2) expected {}, got {}",
            expected_m2,
            d[0]
        );
        // The negative branch saturates at -alpha.
        let x = t(&[-100.0]);
        let y = m.forward(&x).unwrap();
        assert!((y.data().unwrap()[0] + 1.0).abs() < 1e-7);
    }
    #[test]
    fn test_elu_custom_alpha() {
        let m = ELU::new(2.0);
        let x = t(&[-1.0, 1.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        let expected = 2.0 * ((-1.0_f64).exp() - 1.0);
        assert!((d[0] - expected).abs() < 1e-7);
        assert!((d[1] - 1.0).abs() < 1e-7);
    }
    #[test]
    fn test_elu_module_trait() {
        let mut m = ELU::new(1.0);
        assert_zero_param_module::<ELU, f64>(&mut m);
    }
    #[test]
    fn test_mish_forward() {
        let m = Mish::new();
        let x = t(&[0.0]);
        let y = m.forward(&x).unwrap();
        assert!(y.data().unwrap()[0].abs() < 1e-7, "mish(0) should be 0");
        let x = t(&[20.0]);
        let y = m.forward(&x).unwrap();
        assert!(
            (y.data().unwrap()[0] - 20.0).abs() < 0.01,
            "mish(20) should be ~20"
        );
        // Reference value: x * tanh(ln(1 + e^x)) at x = -1.
        let x = t(&[-1.0]);
        let y = m.forward(&x).unwrap();
        let val = y.data().unwrap()[0];
        let softplus = (1.0 + (-1.0_f64).exp()).ln();
        let expected = -1.0 * softplus.tanh();
        assert!(
            (val - expected).abs() < 1e-7,
            "mish(-1) expected {}, got {}",
            expected,
            val
        );
    }
    #[test]
    fn test_mish_module_trait() {
        let mut m = Mish::new();
        assert_zero_param_module::<Mish, f64>(&mut m);
    }
    #[test]
    fn test_prelu_forward_default() {
        let m = PReLU::<f64>::new(0.25).unwrap();
        let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[0] - (-0.5)).abs() < 1e-6, "PReLU(-2) = {}", d[0]);
        assert!((d[1] - (-0.25)).abs() < 1e-6, "PReLU(-1) = {}", d[1]);
        assert!((d[2] - 0.0).abs() < 1e-6, "PReLU(0) = {}", d[2]);
        assert!((d[3] - 1.0).abs() < 1e-6, "PReLU(1) = {}", d[3]);
        assert!((d[4] - 2.0).abs() < 1e-6, "PReLU(2) = {}", d[4]);
    }
    #[test]
    fn test_prelu_has_parameter() {
        let m = PReLU::<f64>::new(0.25).unwrap();
        assert_eq!(m.parameters().len(), 1, "PReLU should have 1 parameter");
        let named = m.named_parameters();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "alpha");
    }
    #[test]
    fn test_prelu_module_trait() {
        // PReLU owns a parameter, so it cannot use `assert_zero_param_module`.
        let mut m = PReLU::<f64>::new(0.25).unwrap();
        assert_eq!(m.parameters().len(), 1);
        assert!(m.is_training());
        m.eval();
        assert!(!m.is_training());
        m.train();
        assert!(m.is_training());
    }
    #[test]
    fn test_celu_forward() {
        let m = CELU::new(1.0);
        let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[3] - 1.0).abs() < 1e-7);
        assert!((d[4] - 2.0).abs() < 1e-7);
        assert!((d[2] - 0.0).abs() < 1e-7);
        let expected_m1 = 1.0 * ((-1.0_f64).exp() - 1.0);
        assert!((d[1] - expected_m1).abs() < 1e-7, "CELU(-1) = {}", d[1]);
    }
    #[test]
    fn test_celu_module_trait() {
        let mut m = CELU::new(1.0);
        assert_zero_param_module::<CELU, f64>(&mut m);
    }
    #[test]
    fn test_selu_forward() {
        let m = SELU::new();
        let x = t(&[-1.0, 0.0, 1.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        // Same literal values as SELU_LAMBDA / SELU_ALPHA in the parent module.
        let lambda = 1.0507009873554805_f64;
        let alpha = 1.6732632423543772_f64;
        assert!((d[2] - lambda * 1.0).abs() < 1e-7, "SELU(1) = {}", d[2]);
        assert!((d[1] - 0.0).abs() < 1e-7, "SELU(0) = {}", d[1]);
        let expected_m1 = lambda * alpha * ((-1.0_f64).exp() - 1.0);
        assert!((d[0] - expected_m1).abs() < 1e-7, "SELU(-1) = {}", d[0]);
    }
    #[test]
    fn test_selu_module_trait() {
        let mut m = SELU::new();
        assert_zero_param_module::<SELU, f64>(&mut m);
    }
    #[test]
    fn test_hard_sigmoid_forward() {
        let m = HardSigmoid::new();
        // Covers both clamp regions, both breakpoints (-3, 3) and the midpoint.
        let x = t(&[-4.0, -3.0, 0.0, 3.0, 5.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[0] - 0.0).abs() < 1e-7, "HardSigmoid(-4) = {}", d[0]);
        assert!((d[1] - 0.0).abs() < 1e-7, "HardSigmoid(-3) = {}", d[1]);
        assert!((d[2] - 0.5).abs() < 1e-7, "HardSigmoid(0) = {}", d[2]);
        assert!((d[3] - 1.0).abs() < 1e-7, "HardSigmoid(3) = {}", d[3]);
        assert!((d[4] - 1.0).abs() < 1e-7, "HardSigmoid(5) = {}", d[4]);
    }
    #[test]
    fn test_hard_sigmoid_module_trait() {
        let mut m = HardSigmoid::new();
        assert_zero_param_module::<HardSigmoid, f64>(&mut m);
    }
    #[test]
    fn test_hard_swish_forward() {
        let m = HardSwish::new();
        let x = t(&[-4.0, 0.0, 3.0, 5.0, -1.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[0] - 0.0).abs() < 1e-7, "HardSwish(-4) = {}", d[0]);
        assert!((d[1] - 0.0).abs() < 1e-7, "HardSwish(0) = {}", d[1]);
        assert!((d[2] - 3.0).abs() < 1e-7, "HardSwish(3) = {}", d[2]);
        assert!((d[3] - 5.0).abs() < 1e-7, "HardSwish(5) = {}", d[3]);
        // Linear region: -1 * ((-1 + 3) / 6) = -1/3.
        assert!((d[4] - (-1.0 / 3.0)).abs() < 1e-7, "HardSwish(-1) = {}", d[4]);
    }
    #[test]
    fn test_hard_swish_module_trait() {
        let mut m = HardSwish::new();
        assert_zero_param_module::<HardSwish, f64>(&mut m);
    }
    #[test]
    fn test_softplus_forward() {
        let m = Softplus::new(1.0);
        let x = t(&[0.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[0] - 2.0_f64.ln()).abs() < 1e-7, "Softplus(0) = {}", d[0]);
        // 25 is above the default threshold (20.0).
        let x = t(&[25.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        assert!((d[0] - 25.0).abs() < 1e-5, "Softplus(25) = {}", d[0]);
        let x = t(&[1.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        let expected = (1.0 + 1.0_f64.exp()).ln();
        assert!((d[0] - expected).abs() < 1e-7, "Softplus(1) = {}", d[0]);
    }
    #[test]
    fn test_softplus_custom_beta() {
        let m = Softplus::new(2.0);
        let x = t(&[0.0]);
        let y = m.forward(&x).unwrap();
        let d = y.data().unwrap();
        let expected = 2.0_f64.ln() / 2.0;
        assert!((d[0] - expected).abs() < 1e-7, "Softplus(0, beta=2) = {}", d[0]);
    }
    #[test]
    fn test_softplus_module_trait() {
        let mut m = Softplus::new(1.0);
        assert_zero_param_module::<Softplus, f64>(&mut m);
    }
    #[test]
    fn test_glu_forward_1d() {
        let m = GLU::new();
        // Values [1, 0], gates [2, 0] -> [1 * sigmoid(2), 0 * sigmoid(0)].
        let x = t(&[1.0, 0.0, 2.0, 0.0]);
        let y = m.forward(&x).unwrap();
        assert_eq!(y.shape(), &[2]);
        let d = y.data().unwrap();
        let sig_2 = 1.0 / (1.0 + (-2.0_f64).exp());
        assert!((d[0] - sig_2).abs() < 1e-7, "GLU[0] = {}", d[0]);
        assert!((d[1] - 0.0).abs() < 1e-7, "GLU[1] = {}", d[1]);
    }
    #[test]
    fn test_glu_forward_2d() {
        let m = GLU::new();
        let x = t2d(&[1.0, 0.0, 2.0, 0.0], 1, 4);
        let y = m.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 2]);
        let d = y.data().unwrap();
        let sig_2 = 1.0 / (1.0 + (-2.0_f64).exp());
        assert!((d[0] - sig_2).abs() < 1e-7);
        assert!((d[1] - 0.0).abs() < 1e-7);
    }
    #[test]
    fn test_glu_odd_dim_error() {
        let m = GLU::new();
        let x = t(&[1.0, 2.0, 3.0]); assert!(m.forward(&x).is_err());
    }
    #[test]
    fn test_glu_module_trait() {
        let mut m = GLU::new();
        assert_zero_param_module::<GLU, f64>(&mut m);
    }
    #[test]
    fn test_defaults() {
        // Every module has a `Default`; those with fields check the default value.
        let _relu = ReLU::default();
        let _gelu = GELU::default();
        let _silu = SiLU::default();
        let _sigmoid = Sigmoid::default();
        let _tanh = Tanh::default();
        let _softmax = Softmax::default();
        let _log_softmax = LogSoftmax::default();
        let lrelu = LeakyReLU::default();
        assert!((lrelu.negative_slope - 0.01).abs() < f64::EPSILON);
        let elu = ELU::default();
        assert!((elu.alpha - 1.0).abs() < f64::EPSILON);
        let _mish = Mish::default();
        let celu = CELU::default();
        assert!((celu.alpha - 1.0).abs() < f64::EPSILON);
        let _selu = SELU::default();
        let _hard_sigmoid = HardSigmoid::default();
        let _hard_swish = HardSwish::default();
        let softplus = Softplus::default();
        assert!((softplus.beta - 1.0).abs() < f64::EPSILON);
        let _glu = GLU::default();
    }
    #[test]
    fn test_send_sync() {
        // Compile-time check: fails to build if any module is not Send + Sync.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ReLU>();
        assert_send_sync::<GELU>();
        assert_send_sync::<SiLU>();
        assert_send_sync::<Sigmoid>();
        assert_send_sync::<Tanh>();
        assert_send_sync::<Softmax>();
        assert_send_sync::<LogSoftmax>();
        assert_send_sync::<LeakyReLU>();
        assert_send_sync::<ELU>();
        assert_send_sync::<Mish>();
        assert_send_sync::<PReLU<f64>>();
        assert_send_sync::<CELU>();
        assert_send_sync::<SELU>();
        assert_send_sync::<HardSigmoid>();
        assert_send_sync::<HardSwish>();
        assert_send_sync::<Softplus>();
        assert_send_sync::<GLU>();
    }
    /// 1-D f64 CPU tensor with `requires_grad = true`.
    fn t_grad(data: &[f64]) -> Tensor<f64> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![data.len()], true).unwrap()
    }
    /// 0-D (scalar) f64 CPU tensor with `requires_grad = true`.
    fn t_scalar_grad(val: f64) -> Tensor<f64> {
        Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], true).unwrap()
    }
    /// Central finite-difference derivative of `f` at `x` with step `h = 1e-5`.
    fn numerical_grad(f: impl Fn(f64) -> f64, x: f64) -> f64 {
        let h = 1e-5;
        (f(x + h) - f(x - h)) / (2.0 * h)
    }
    #[test]
    fn test_softplus_backward_produces_grad() {
        let x = t_scalar_grad(1.0);
        let m = Softplus::new(1.0);
        let y = m.forward(&x).unwrap();
        ferrotorch_core::backward(&y).unwrap();
        let grad = x.grad().unwrap();
        assert!(grad.is_some(), "Softplus backward should produce a gradient");
    }
    #[test]
    fn test_softplus_backward_at_zero() {
        // d/dx ln(1 + e^x) = sigmoid(x), which is 0.5 at x = 0.
        let x = t_scalar_grad(0.0);
        let m = Softplus::new(1.0);
        let y = m.forward(&x).unwrap();
        ferrotorch_core::backward(&y).unwrap();
        let grad = x.grad().unwrap().unwrap();
        assert!(
            (grad.item().unwrap() - 0.5).abs() < 1e-6,
            "Softplus grad at x=0: expected 0.5, got {}",
            grad.item().unwrap()
        );
    }
    #[test]
    fn test_softplus_backward_matches_numerical() {
        for &val in &[-2.0, -0.5, 0.0, 1.0, 3.0] {
            let x = t_scalar_grad(val);
            let m = Softplus::new(1.0);
            let y = m.forward(&x).unwrap();
            ferrotorch_core::backward(&y).unwrap();
            let grad = x.grad().unwrap().unwrap();
            let expected = numerical_grad(|v| (1.0 + v.exp()).ln(), val);
            assert!(
                (grad.item().unwrap() - expected).abs() < 1e-4,
                "Softplus grad at x={}: expected {}, got {}",
                val,
                expected,
                grad.item().unwrap()
            );
        }
    }
    #[test]
    fn test_softplus_backward_custom_beta() {
        let val = 1.0;
        let beta = 2.0;
        let x = t_scalar_grad(val);
        let m = Softplus::new(beta);
        let y = m.forward(&x).unwrap();
        ferrotorch_core::backward(&y).unwrap();
        let grad = x.grad().unwrap().unwrap();
        let expected = numerical_grad(|v| (1.0 + (beta * v).exp()).ln() / beta, val);
        assert!(
            (grad.item().unwrap() - expected).abs() < 1e-4,
            "Softplus grad at x={}, beta={}: expected {}, got {}",
            val,
            beta,
            expected,
            grad.item().unwrap()
        );
    }
    #[test]
    fn test_softplus_backward_vector() {
        // Reduce to a scalar with `sum` so `backward` can start from it.
        let x = t_grad(&[-2.0, -0.5, 0.0, 1.0, 3.0]);
        let m = Softplus::new(1.0);
        let y = m.forward(&x).unwrap();
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let grad = x.grad().unwrap().unwrap();
        let grad_data = grad.data().unwrap();
        for (i, &val) in [-2.0_f64, -0.5, 0.0, 1.0, 3.0].iter().enumerate() {
            let expected = numerical_grad(|v| (1.0 + v.exp()).ln(), val);
            assert!(
                (grad_data[i] - expected).abs() < 1e-4,
                "Softplus grad[{}] at x={}: expected {}, got {}",
                i,
                val,
                expected,
                grad_data[i]
            );
        }
    }
    #[test]
    fn test_elu_backward_produces_grad() {
        let x = t_scalar_grad(-1.0);
        let m = ELU::new(1.0);
        let y = m.forward(&x).unwrap();
        ferrotorch_core::backward(&y).unwrap();
        let grad = x.grad().unwrap();
        assert!(grad.is_some(), "ELU backward should produce a gradient");
    }
    #[test]
    fn test_elu_backward_positive() {
        let x = t_scalar_grad(2.0);
        let m = ELU::new(1.0);
        let y = m.forward(&x).unwrap();
        ferrotorch_core::backward(&y).unwrap();
        let grad = x.grad().unwrap().unwrap();
        assert!(
            (grad.item().unwrap() - 1.0).abs() < 1e-6,
            "ELU grad at x=2: expected 1.0, got {}",
            grad.item().unwrap()
        );
    }
    #[test]
    fn test_elu_backward_matches_numerical() {
        // Sample points avoid x = 0, where the central difference straddles the kink.
        let alpha = 1.0;
        for &val in &[-2.0, -1.0, -0.5, 0.5, 2.0] {
            let x = t_scalar_grad(val);
            let m = ELU::new(alpha);
            let y = m.forward(&x).unwrap();
            ferrotorch_core::backward(&y).unwrap();
            let grad = x.grad().unwrap().unwrap();
            let expected = numerical_grad(
                |v| if v > 0.0 { v } else { alpha * (v.exp() - 1.0) },
                val,
            );
            assert!(
                (grad.item().unwrap() - expected).abs() < 1e-4,
                "ELU grad at x={}: expected {}, got {}",
                val,
                expected,
                grad.item().unwrap()
            );
        }
    }
    #[test]
    fn test_elu_backward_custom_alpha() {
        // Analytic derivative on the negative branch: alpha * e^x.
        let alpha = 2.0;
        let val = -0.5;
        let x = t_scalar_grad(val);
        let m = ELU::new(alpha);
        let y = m.forward(&x).unwrap();
        ferrotorch_core::backward(&y).unwrap();
        let grad = x.grad().unwrap().unwrap();
        let expected = alpha * val.exp();
        assert!(
            (grad.item().unwrap() - expected).abs() < 1e-5,
            "ELU grad at x={}, alpha={}: expected {}, got {}",
            val,
            alpha,
            expected,
            grad.item().unwrap()
        );
    }
    #[test]
    fn test_mish_backward_produces_grad() {
        let x = t_scalar_grad(1.0);
        let m = Mish::new();
        let y = m.forward(&x).unwrap();
        ferrotorch_core::backward(&y).unwrap();
        let grad = x.grad().unwrap();
        assert!(grad.is_some(), "Mish backward should produce a gradient");
    }
    #[test]
    fn test_mish_backward_matches_numerical() {
        let mish_fn = |v: f64| {
            let sp = (1.0 + v.exp()).ln();
            v * sp.tanh()
        };
        for &val in &[-2.0, -1.0, 0.0, 0.5, 1.5, 3.0] {
            let x = t_scalar_grad(val);
            let m = Mish::new();
            let y = m.forward(&x).unwrap();
            ferrotorch_core::backward(&y).unwrap();
            let grad = x.grad().unwrap().unwrap();
            let expected = numerical_grad(mish_fn, val);
            assert!(
                (grad.item().unwrap() - expected).abs() < 1e-4,
                "Mish grad at x={}: expected {}, got {}",
                val,
                expected,
                grad.item().unwrap()
            );
        }
    }
    #[test]
    fn test_mish_backward_vector() {
        let x = t_grad(&[-1.0, 0.0, 1.0, 2.0]);
        let m = Mish::new();
        let y = m.forward(&x).unwrap();
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let grad = x.grad().unwrap().unwrap();
        let grad_data = grad.data().unwrap();
        let mish_fn = |v: f64| {
            let sp = (1.0 + v.exp()).ln();
            v * sp.tanh()
        };
        for (i, &val) in [-1.0_f64, 0.0, 1.0, 2.0].iter().enumerate() {
            let expected = numerical_grad(mish_fn, val);
            assert!(
                (grad_data[i] - expected).abs() < 1e-4,
                "Mish grad[{}] at x={}: expected {}, got {}",
                i,
                val,
                expected,
                grad_data[i]
            );
        }
    }
    #[test]
    fn test_state_dict_empty() {
        // A parameter-free module serialises to an empty state dict.
        let m = ReLU::new();
        let sd = Module::<f64>::state_dict(&m);
        assert!(sd.is_empty());
    }
}