use ferrotorch_core::grad_fns::activation as act;
use ferrotorch_core::grad_fns::arithmetic;
use ferrotorch_core::grad_fns::transcendental;
use ferrotorch_core::ops::elementwise::unary_map;
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, normalize_axis};
use crate::module::Module;
use crate::parameter::Parameter;
/// Implements `Module<T>` for an activation type that owns no parameters.
///
/// The target type must expose an inherent `forward` method and a
/// `training: bool` field. Every parameter accessor yields an empty vector,
/// and `train` / `eval` only flip the `training` flag.
macro_rules! impl_activation_module {
    ($ty:ident) => {
        impl<T: Float> Module<T> for $ty {
            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
                // Method-call syntax picks the type's inherent `forward` ahead of
                // this trait method, so this does not recurse.
                self.forward(input)
            }

            fn parameters(&self) -> Vec<&Parameter<T>> {
                Vec::new()
            }

            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
                Vec::new()
            }

            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
                Vec::new()
            }

            fn train(&mut self) {
                self.training = true;
            }

            fn eval(&mut self) {
                self.training = false;
            }

            fn is_training(&self) -> bool {
                self.training
            }
        }
    };
}
/// Rectified linear unit activation module; `forward` delegates to `act::relu`.
#[derive(Debug, Clone)]
pub struct ReLU {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl ReLU {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Applies `act::relu` to `input` and returns its result.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        act::relu(input)
    }
}

impl Default for ReLU {
    /// Same as [`ReLU::new`].
    fn default() -> Self {
        ReLU::new()
    }
}

impl_activation_module!(ReLU);
/// Softmax over the channel axis (axis 1) of a 4-D `[N, C, H, W]` tensor,
/// computed independently at every `(batch, row, col)` position.
#[derive(Debug, Clone)]
pub struct Softmax2d {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Softmax2d {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        Self { training: true }
    }

    /// Computes the channel-wise softmax of `input` on the CPU.
    ///
    /// # Errors
    /// * `InvalidArgument` when `input` is not 4-D.
    /// * `NotImplementedOnCuda` when `input` lives on a CUDA device.
    /// * Any error produced by `input.data()` or `Tensor::from_storage`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        if input.ndim() != 4 {
            return Err(ferrotorch_core::error::FerrotorchError::InvalidArgument {
                message: format!("Softmax2d expects 4-D input [N,C,H,W], got {:?}", input.shape()),
            });
        }
        if input.is_cuda() {
            return Err(ferrotorch_core::error::FerrotorchError::NotImplementedOnCuda {
                op: "Softmax2d",
            });
        }

        let shape = input.shape();
        let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
        let data = input.data()?;
        let mut out = vec![<T as num_traits::Zero>::zero(); n * c * h * w];

        for batch in 0..n {
            // `pixel` walks row * w + col in the same order as nested row/col loops.
            for pixel in 0..h * w {
                let offset_of = |ch: usize| batch * c * h * w + ch * h * w + pixel;

                // Pass 1: channel maximum, subtracted before exponentiating.
                let mut max_val = T::neg_infinity();
                for ch in 0..c {
                    let candidate = data[offset_of(ch)];
                    if candidate > max_val {
                        max_val = candidate;
                    }
                }

                // Pass 2: shifted exponentials and their running sum.
                let mut sum_exp = <T as num_traits::Zero>::zero();
                for ch in 0..c {
                    let idx = offset_of(ch);
                    let e = (data[idx] - max_val).exp();
                    out[idx] = e;
                    sum_exp = sum_exp + e;
                }

                // Pass 3: normalise by the sum.
                for ch in 0..c {
                    let idx = offset_of(ch);
                    out[idx] = out[idx] / sum_exp;
                }
            }
        }

        let storage = ferrotorch_core::storage::TensorStorage::cpu(out);
        Tensor::from_storage(storage, shape.to_vec(), false)
    }
}

impl Default for Softmax2d {
    /// Same as [`Softmax2d::new`].
    fn default() -> Self {
        Softmax2d::new()
    }
}

impl_activation_module!(Softmax2d);
pub use act::GeluApproximate;
/// GELU activation module.
///
/// Holds the [`GeluApproximate`] variant that is forwarded to
/// `act::gelu_with` on every call.
#[derive(Debug, Clone)]
pub struct GELU {
    /// Variant passed to `act::gelu_with`.
    approximate: GeluApproximate,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl GELU {
    /// Creates the module with `GeluApproximate::default()`, in training mode.
    pub fn new() -> Self {
        Self::with_approximate(GeluApproximate::default())
    }

    /// Creates the module with an explicit `approximate` variant, in training mode.
    pub fn with_approximate(approximate: GeluApproximate) -> Self {
        Self {
            training: true,
            approximate,
        }
    }

    /// Applies `act::gelu_with` using the stored variant.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let variant = self.approximate;
        act::gelu_with(input, variant)
    }
}

impl Default for GELU {
    /// Same as [`GELU::new`].
    fn default() -> Self {
        GELU::new()
    }
}

impl_activation_module!(GELU);
/// SiLU activation module; `forward` delegates to `act::silu`.
#[derive(Debug, Clone)]
pub struct SiLU {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl SiLU {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Applies `act::silu` to `input` and returns its result.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        act::silu(input)
    }
}

impl Default for SiLU {
    /// Same as [`SiLU::new`].
    fn default() -> Self {
        SiLU::new()
    }
}

impl_activation_module!(SiLU);
/// Sigmoid activation module; `forward` delegates to `act::sigmoid`.
#[derive(Debug, Clone)]
pub struct Sigmoid {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Sigmoid {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Applies `act::sigmoid` to `input` and returns its result.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        act::sigmoid(input)
    }
}

impl Default for Sigmoid {
    /// Same as [`Sigmoid::new`].
    fn default() -> Self {
        Sigmoid::new()
    }
}

impl_activation_module!(Sigmoid);
/// Hyperbolic-tangent activation module; `forward` delegates to `act::tanh`.
#[derive(Debug, Clone)]
pub struct Tanh {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Tanh {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Applies `act::tanh` to `input` and returns its result.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        act::tanh(input)
    }
}

impl Default for Tanh {
    /// Same as [`Tanh::new`].
    fn default() -> Self {
        Tanh::new()
    }
}

impl_activation_module!(Tanh);
/// Softmax activation module, currently limited to the last axis.
#[derive(Debug, Clone)]
pub struct Softmax {
    /// Requested axis; negative values are resolved by `normalize_axis`.
    pub dim: isize,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Softmax {
    /// Creates the module for axis `dim`, in training mode.
    pub fn new(dim: isize) -> Self {
        Self {
            training: true,
            dim,
        }
    }

    /// Applies `act::softmax` to `input`.
    ///
    /// A 0-D input skips axis validation entirely. Otherwise `dim` must
    /// normalise to the last axis.
    ///
    /// # Errors
    /// * Whatever `normalize_axis` returns for an out-of-range `dim`.
    /// * `InvalidArgument` when the normalised axis is not the last one.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let ndim = input.ndim();
        if ndim != 0 {
            let axis = normalize_axis(self.dim, ndim)?;
            let last_axis = ndim - 1;
            if axis != last_axis {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "Softmax currently only supports dim=-1 (last axis), \
                         but got dim={} (axis={}) for a {}-D tensor",
                        self.dim, axis, ndim,
                    ),
                });
            }
        }
        act::softmax(input)
    }
}

impl Default for Softmax {
    /// Softmax over the last axis (`dim = -1`).
    fn default() -> Self {
        Softmax::new(-1)
    }
}

impl_activation_module!(Softmax);
/// Log-softmax activation module, currently limited to the last axis.
#[derive(Debug, Clone)]
pub struct LogSoftmax {
    /// Requested axis; negative values are resolved by `normalize_axis`.
    pub dim: isize,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl LogSoftmax {
    /// Creates the module for axis `dim`, in training mode.
    pub fn new(dim: isize) -> Self {
        Self {
            training: true,
            dim,
        }
    }

    /// Applies `act::log_softmax` to `input`.
    ///
    /// A 0-D input skips axis validation entirely. Otherwise `dim` must
    /// normalise to the last axis.
    ///
    /// # Errors
    /// * Whatever `normalize_axis` returns for an out-of-range `dim`.
    /// * `InvalidArgument` when the normalised axis is not the last one.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let ndim = input.ndim();
        if ndim != 0 {
            let axis = normalize_axis(self.dim, ndim)?;
            let last_axis = ndim - 1;
            if axis != last_axis {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "LogSoftmax currently only supports dim=-1 (last axis), \
                         but got dim={} (axis={}) for a {}-D tensor",
                        self.dim, axis, ndim,
                    ),
                });
            }
        }
        act::log_softmax(input)
    }
}

impl Default for LogSoftmax {
    /// Log-softmax over the last axis (`dim = -1`).
    fn default() -> Self {
        LogSoftmax::new(-1)
    }
}

impl_activation_module!(LogSoftmax);
/// Leaky ReLU activation module with a fixed negative slope.
#[derive(Debug, Clone)]
pub struct LeakyReLU {
    /// Multiplier applied to negative inputs.
    pub negative_slope: f64,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl LeakyReLU {
    /// Creates the module with the given slope, in training mode.
    pub fn new(negative_slope: f64) -> Self {
        Self {
            training: true,
            negative_slope,
        }
    }

    /// Computes `(1 - slope) * relu(x) + slope * x` from tensor ops.
    ///
    /// Two shortcuts apply: a slope within `f64::EPSILON` of 0 returns
    /// `act::relu(input)`, and a slope within `f64::EPSILON` of 1 returns a
    /// clone of `input`.
    ///
    /// # Panics
    /// Panics if the slope (or `1 - slope`) cannot be converted to `T`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let slope_f64 = self.negative_slope;
        if slope_f64.abs() < f64::EPSILON {
            return act::relu(input);
        }
        if (slope_f64 - 1.0).abs() < f64::EPSILON {
            return Ok(input.clone());
        }

        let rectified = act::relu(input)?;
        let complement = T::from(1.0 - slope_f64).unwrap();
        let slope = T::from(slope_f64).unwrap();
        let complement_t = ferrotorch_core::scalar(complement)?;
        let slope_t = ferrotorch_core::scalar(slope)?;
        let positive_term = arithmetic::mul(&rectified, &complement_t)?;
        let linear_term = arithmetic::mul(input, &slope_t)?;
        arithmetic::add(&positive_term, &linear_term)
    }
}

impl Default for LeakyReLU {
    /// Uses a negative slope of `0.01`.
    fn default() -> Self {
        LeakyReLU::new(0.01)
    }
}

impl_activation_module!(LeakyReLU);
/// ELU activation module; `forward` delegates to `act::elu` with `alpha`.
#[derive(Debug, Clone)]
pub struct ELU {
    /// Coefficient passed to `act::elu`.
    pub alpha: f64,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl ELU {
    /// Creates the module with the given `alpha`, in training mode.
    pub fn new(alpha: f64) -> Self {
        Self {
            training: true,
            alpha,
        }
    }

    /// Applies `act::elu(input, alpha)`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let alpha = self.alpha;
        act::elu(input, alpha)
    }
}

impl Default for ELU {
    /// Uses `alpha = 1.0`.
    fn default() -> Self {
        ELU::new(1.0)
    }
}

impl_activation_module!(ELU);
/// Mish activation module; `forward` delegates to `act::mish`.
#[derive(Debug, Clone)]
pub struct Mish {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Mish {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Applies `act::mish` to `input` and returns its result.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        act::mish(input)
    }
}

impl Default for Mish {
    /// Same as [`Mish::new`].
    fn default() -> Self {
        Mish::new()
    }
}

impl_activation_module!(Mish);
/// Parametric ReLU with a single slope stored as a `Parameter`.
#[derive(Debug, Clone)]
pub struct PReLU<T: Float> {
    /// Slope parameter, created as a one-element tensor of shape `[1]`.
    pub alpha: Parameter<T>,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl<T: Float> PReLU<T> {
    /// Creates the module with `alpha` initialised to `init_alpha`.
    ///
    /// # Errors
    /// Propagates any error from `ferrotorch_core::from_slice`.
    ///
    /// # Panics
    /// Panics if `init_alpha` cannot be converted to `T`.
    pub fn new(init_alpha: f64) -> FerrotorchResult<Self> {
        let initial = T::from(init_alpha).unwrap();
        let alpha = Parameter::new(ferrotorch_core::from_slice(&[initial], &[1])?);
        Ok(Self {
            alpha,
            training: true,
        })
    }

    /// Computes `(1 - alpha) * relu(x) + alpha * x`, reading `alpha` as the
    /// first element of the parameter tensor.
    ///
    /// NOTE(review): `alpha` is re-wrapped via `ferrotorch_core::scalar`, so
    /// whether gradients reach the parameter is not visible here — confirm.
    ///
    /// # Errors
    /// * `NotImplementedOnCuda` when the `alpha` tensor is on a CUDA device.
    /// * Any error from the underlying tensor operations.
    pub fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let rectified = act::relu(input)?;
        if self.alpha.tensor().is_cuda() {
            return Err(ferrotorch_core::error::FerrotorchError::NotImplementedOnCuda {
                op: "PReLU",
            });
        }

        let alpha_values = self.alpha.tensor().data()?;
        let slope = alpha_values[0];
        let complement = T::from(1.0).unwrap() - slope;
        let complement_t = ferrotorch_core::scalar(complement)?;
        let slope_t = ferrotorch_core::scalar(slope)?;
        let positive_term = arithmetic::mul(&rectified, &complement_t)?;
        let linear_term = arithmetic::mul(input, &slope_t)?;
        arithmetic::add(&positive_term, &linear_term)
    }
}

impl<T: Float> Module<T> for PReLU<T> {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        // Resolves to the inherent `PReLU::forward`, which takes precedence
        // over this trait method.
        self.forward(input)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        let alpha = &self.alpha;
        vec![alpha]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let alpha = &mut self.alpha;
        vec![alpha]
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let entry = ("alpha".to_string(), &self.alpha);
        vec![entry]
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }
}
/// CELU activation module: `x` for positive inputs and
/// `alpha * (exp(x / alpha) - 1)` for negative inputs.
#[derive(Debug, Clone)]
pub struct CELU {
    /// Scale of the exponential branch.
    pub alpha: f64,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl CELU {
    /// Creates the module with the given `alpha`, in training mode.
    pub fn new(alpha: f64) -> Self {
        Self {
            alpha,
            training: true,
        }
    }

    /// Applies CELU element-wise via `unary_map`.
    ///
    /// Zero maps to zero. An input that is neither greater than, less than,
    /// nor equal to zero (i.e. NaN) is returned unchanged, so numerical
    /// failures upstream stay visible instead of being turned into `0`.
    ///
    /// # Panics
    /// Panics if `alpha` cannot be converted to `T`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let zero = <T as num_traits::Zero>::zero();
        let one = <T as num_traits::One>::one();
        let alpha = T::from(self.alpha).unwrap();
        unary_map(input, |x| {
            if x > zero {
                x
            } else if x < zero {
                alpha * ((x / alpha).exp() - one)
            } else if x == zero {
                zero
            } else {
                // Only NaN fails all three comparisons: propagate it.
                x
            }
        })
    }
}

impl Default for CELU {
    /// Uses `alpha = 1.0`.
    fn default() -> Self {
        Self::new(1.0)
    }
}

impl_activation_module!(CELU);
/// SELU activation module: `lambda * x` for positive inputs and
/// `lambda * alpha * (exp(x) - 1)` otherwise.
#[derive(Debug, Clone)]
pub struct SELU {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

/// `alpha` coefficient of the exponential branch.
const SELU_ALPHA: f64 = 1.6732632423543772;
/// `lambda` scale applied to both branches.
const SELU_LAMBDA: f64 = 1.0507009873554805;

impl SELU {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Applies SELU element-wise via `unary_map`.
    ///
    /// # Panics
    /// Panics if the SELU constants cannot be converted to `T`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let zero = <T as num_traits::Zero>::zero();
        let one = <T as num_traits::One>::one();
        let alpha = T::from(SELU_ALPHA).unwrap();
        let lambda = T::from(SELU_LAMBDA).unwrap();
        let selu = |x: T| -> T {
            if x > zero {
                return lambda * x;
            }
            lambda * alpha * (x.exp() - one)
        };
        unary_map(input, selu)
    }
}

impl Default for SELU {
    /// Same as [`SELU::new`].
    fn default() -> Self {
        SELU::new()
    }
}

impl_activation_module!(SELU);
/// Hard sigmoid activation module: `(x + 3) / 6` limited to `[0, 1]`.
#[derive(Debug, Clone)]
pub struct HardSigmoid {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl HardSigmoid {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Applies the hard sigmoid element-wise via `unary_map`.
    ///
    /// # Panics
    /// Panics if `3.0` or `6.0` cannot be converted to `T`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let zero = <T as num_traits::Zero>::zero();
        let one = <T as num_traits::One>::one();
        let three = T::from(3.0).unwrap();
        let six = T::from(6.0).unwrap();
        unary_map(input, |x| {
            let ramp = (x + three) / six;
            if ramp < zero {
                return zero;
            }
            if ramp > one {
                return one;
            }
            ramp
        })
    }
}

impl Default for HardSigmoid {
    /// Same as [`HardSigmoid::new`].
    fn default() -> Self {
        HardSigmoid::new()
    }
}

impl_activation_module!(HardSigmoid);
/// Hard swish activation module: `x * clamp((x + 3) / 6, 0, 1)`.
#[derive(Debug, Clone)]
pub struct HardSwish {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl HardSwish {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Applies hard swish element-wise via `unary_map`.
    ///
    /// # Panics
    /// Panics if `3.0` or `6.0` cannot be converted to `T`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let zero = <T as num_traits::Zero>::zero();
        let one = <T as num_traits::One>::one();
        let three = T::from(3.0).unwrap();
        let six = T::from(6.0).unwrap();
        unary_map(input, |x| {
            let ramp = (x + three) / six;
            // Gate is the ramp limited to [0, 1].
            let gate = if ramp < zero {
                zero
            } else if ramp > one {
                one
            } else {
                ramp
            };
            x * gate
        })
    }
}

impl Default for HardSwish {
    /// Same as [`HardSwish::new`].
    fn default() -> Self {
        HardSwish::new()
    }
}

impl_activation_module!(HardSwish);
/// Softplus activation module; `forward` delegates to `act::softplus`.
#[derive(Debug, Clone)]
pub struct Softplus {
    /// `beta` argument passed to `act::softplus`.
    pub beta: f64,
    /// `threshold` argument passed to `act::softplus`.
    pub threshold: f64,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Softplus {
    /// Creates the module with the given `beta` and a threshold of `20.0`.
    pub fn new(beta: f64) -> Self {
        let threshold = 20.0;
        Self {
            beta,
            threshold,
            training: true,
        }
    }

    /// Applies `act::softplus(input, beta, threshold)`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let (beta, threshold) = (self.beta, self.threshold);
        act::softplus(input, beta, threshold)
    }
}

impl Default for Softplus {
    /// Uses `beta = 1.0`.
    fn default() -> Self {
        Softplus::new(1.0)
    }
}

impl_activation_module!(Softplus);
/// Gated linear unit over the last axis.
///
/// The last axis is split into halves `a` and `b`; the output is
/// `a * sigmoid(b)` with the last axis halved.
#[derive(Debug, Clone)]
pub struct GLU {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl GLU {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        Self { training: true }
    }

    /// Computes `a * sigmoid(b)` on a host copy of the data, then moves the
    /// result back to the input's device when that device is CUDA.
    ///
    /// Inputs with a zero-sized leading dimension (e.g. shape `[0, 4]`)
    /// produce an empty output of the matching shape.
    ///
    /// # Errors
    /// * `InvalidArgument` for 0-D input or an odd last dimension.
    /// * Any error from `data_vec`, `Tensor::from_storage` or `Tensor::to`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let shape = input.shape();
        let ndim = shape.len();
        if ndim == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "GLU requires at least 1D input".to_string(),
            });
        }
        let last_dim = shape[ndim - 1];
        if last_dim % 2 != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "GLU requires the last dimension to be even, got {}",
                    last_dim
                ),
            });
        }
        let half = last_dim / 2;
        let device = input.device();
        let data = input.data_vec()?;
        // The empty product is 1, so 1-D input yields one row. A genuine 0
        // means there are no rows at all and must not be bumped to 1, which
        // would index into an empty buffer.
        let outer_size: usize = shape[..ndim - 1].iter().product();
        let one = <T as num_traits::One>::one();
        let mut result = Vec::with_capacity(outer_size * half);
        for row in 0..outer_size {
            let base = row * last_dim;
            for j in 0..half {
                let a = data[base + j];
                let b = data[base + half + j];
                let sig_b = one / (one + (-b).exp());
                result.push(a * sig_b);
            }
        }
        let mut out_shape = shape.to_vec();
        out_shape[ndim - 1] = half;
        let out = Tensor::from_storage(
            ferrotorch_core::TensorStorage::cpu(result),
            out_shape,
            false,
        )?;
        if device.is_cuda() {
            out.to(device)
        } else {
            Ok(out)
        }
    }
}

impl Default for GLU {
    /// Same as [`GLU::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl_activation_module!(GLU);
/// ReLU6 activation module: clamps the input to `[0, 6]`.
#[derive(Debug, Clone)]
pub struct ReLU6 {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl ReLU6 {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Applies `transcendental::clamp(input, 0, 6)`.
    ///
    /// # Panics
    /// Panics if `6.0` cannot be converted to `T`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let lower = <T as num_traits::Zero>::zero();
        let upper = T::from(6.0).unwrap();
        transcendental::clamp(input, lower, upper)
    }
}

impl Default for ReLU6 {
    /// Same as [`ReLU6::new`].
    fn default() -> Self {
        ReLU6::new()
    }
}

impl_activation_module!(ReLU6);
/// Hardtanh activation module: clamps the input to `[min_val, max_val]`.
#[derive(Debug, Clone)]
pub struct Hardtanh {
    /// Lower clamp bound.
    pub min_val: f64,
    /// Upper clamp bound.
    pub max_val: f64,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Hardtanh {
    /// Creates the module with the given bounds, in training mode.
    pub fn new(min_val: f64, max_val: f64) -> Self {
        Self {
            training: true,
            min_val,
            max_val,
        }
    }

    /// Applies `transcendental::clamp(input, min_val, max_val)`.
    ///
    /// # Panics
    /// Panics if either bound cannot be converted to `T`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let lower = T::from(self.min_val).unwrap();
        let upper = T::from(self.max_val).unwrap();
        transcendental::clamp(input, lower, upper)
    }
}

impl Default for Hardtanh {
    /// Uses the range `[-1.0, 1.0]`.
    fn default() -> Self {
        Hardtanh::new(-1.0, 1.0)
    }
}

impl_activation_module!(Hardtanh);
/// Log-sigmoid activation module, computed as `-softplus(-x)`.
#[derive(Debug, Clone)]
pub struct LogSigmoid {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl LogSigmoid {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Negates the input, applies `act::softplus` with `beta = 1.0` and
    /// `threshold = 20.0`, then negates the result.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let negated = arithmetic::neg(input)?;
        let softplus_of_neg = act::softplus(&negated, 1.0, 20.0)?;
        arithmetic::neg(&softplus_of_neg)
    }
}

impl Default for LogSigmoid {
    /// Same as [`LogSigmoid::new`].
    fn default() -> Self {
        LogSigmoid::new()
    }
}

impl_activation_module!(LogSigmoid);
/// Softmin activation module (`softmax(-x)`), currently limited to the last axis.
#[derive(Debug, Clone)]
pub struct Softmin {
    /// Requested axis; negative values are resolved by `normalize_axis`.
    pub dim: isize,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Softmin {
    /// Creates the module for axis `dim`, in training mode.
    pub fn new(dim: isize) -> Self {
        Self {
            training: true,
            dim,
        }
    }

    /// Negates `input` and applies `act::softmax`.
    ///
    /// A 0-D input skips axis validation entirely. Otherwise `dim` must
    /// normalise to the last axis; validation happens before the negation.
    ///
    /// # Errors
    /// * Whatever `normalize_axis` returns for an out-of-range `dim`.
    /// * `InvalidArgument` when the normalised axis is not the last one.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let ndim = input.ndim();
        if ndim != 0 {
            let axis = normalize_axis(self.dim, ndim)?;
            let last_axis = ndim - 1;
            if axis != last_axis {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "Softmin currently only supports dim=-1 (last axis), \
                         but got dim={} (axis={}) for a {}-D tensor",
                        self.dim, axis, ndim,
                    ),
                });
            }
        }
        let negated = arithmetic::neg(input)?;
        act::softmax(&negated)
    }
}

impl Default for Softmin {
    /// Softmin over the last axis (`dim = -1`).
    fn default() -> Self {
        Softmin::new(-1)
    }
}

impl_activation_module!(Softmin);
/// Threshold activation module: elements at or below `threshold` are
/// replaced by `value`; all others pass through.
#[derive(Debug, Clone)]
pub struct Threshold {
    /// Cut-off compared against each element.
    pub threshold: f64,
    /// Replacement for elements at or below the cut-off.
    pub value: f64,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Threshold {
    /// Creates the module with the given cut-off and replacement value.
    pub fn new(threshold: f64, value: f64) -> Self {
        Self {
            threshold,
            value,
            training: true,
        }
    }

    /// Applies the threshold element-wise via `unary_map`.
    ///
    /// The test is written as `x <= threshold` so that an element for which
    /// the comparison is false — including NaN — is passed through unchanged
    /// rather than being silently replaced by `value`.
    ///
    /// # Panics
    /// Panics if `threshold` or `value` cannot be converted to `T`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let thresh = T::from(self.threshold).unwrap();
        let val = T::from(self.value).unwrap();
        unary_map(input, |x| if x <= thresh { val } else { x })
    }
}

impl_activation_module!(Threshold);
/// Softshrink activation module: shrinks each element towards zero by `lambda`.
#[derive(Debug, Clone)]
pub struct Softshrink {
    /// Width of the zeroed band on either side of zero.
    pub lambda: f64,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Softshrink {
    /// Creates the module with the given `lambda`, in training mode.
    pub fn new(lambda: f64) -> Self {
        Self {
            training: true,
            lambda,
        }
    }

    /// Returns `x - lambda` above `lambda`, `x + lambda` below `-lambda`,
    /// and zero otherwise.
    ///
    /// An element for which both comparisons are false (which includes NaN)
    /// falls into the zero branch.
    ///
    /// # Panics
    /// Panics if `lambda` cannot be converted to `T`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let upper = T::from(self.lambda).unwrap();
        let lower = T::from(-self.lambda).unwrap();
        let zero = <T as num_traits::Zero>::zero();
        unary_map(input, |x| {
            if x > upper {
                return x - upper;
            }
            if x < lower {
                return x + upper;
            }
            zero
        })
    }
}

impl Default for Softshrink {
    /// Uses `lambda = 0.5`.
    fn default() -> Self {
        Softshrink::new(0.5)
    }
}

impl_activation_module!(Softshrink);
/// Hardshrink activation module: zeroes elements inside `[-lambda, lambda]`
/// and passes all others through.
#[derive(Debug, Clone)]
pub struct Hardshrink {
    /// Half-width of the zeroed band.
    pub lambda: f64,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Hardshrink {
    /// Creates the module with the given `lambda`, in training mode.
    pub fn new(lambda: f64) -> Self {
        Self {
            lambda,
            training: true,
        }
    }

    /// Applies hardshrink element-wise via `unary_map`.
    ///
    /// The band test is written as "inside `[-lambda, lambda]`" so that an
    /// element failing the comparisons — including NaN — is passed through
    /// unchanged instead of being silently replaced by zero.
    ///
    /// # Panics
    /// Panics if `lambda` cannot be converted to `T`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let lam = T::from(self.lambda).unwrap();
        let neg_lam = T::from(-self.lambda).unwrap();
        let zero = <T as num_traits::Zero>::zero();
        unary_map(input, |x| {
            if x >= neg_lam && x <= lam {
                zero
            } else {
                x
            }
        })
    }
}

impl Default for Hardshrink {
    /// Uses `lambda = 0.5`.
    fn default() -> Self {
        Self::new(0.5)
    }
}

impl_activation_module!(Hardshrink);
/// Tanhshrink activation module: `x - tanh(x)`.
#[derive(Debug, Clone)]
pub struct Tanhshrink {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Tanhshrink {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Subtracts `act::tanh(input)` from `input` using `arithmetic::sub`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let squashed = act::tanh(input)?;
        arithmetic::sub(input, &squashed)
    }
}

impl Default for Tanhshrink {
    /// Same as [`Tanhshrink::new`].
    fn default() -> Self {
        Tanhshrink::new()
    }
}

impl_activation_module!(Tanhshrink);
/// Softsign activation module: `x / (1 + |x|)`.
#[derive(Debug, Clone)]
pub struct Softsign {
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}

impl Softsign {
    /// Creates the module in training mode.
    pub fn new() -> Self {
        let training = true;
        Self { training }
    }

    /// Applies softsign element-wise via `unary_map`.
    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let one = <T as num_traits::One>::one();
        unary_map(input, |x| {
            let denominator = one + x.abs();
            x / denominator
        })
    }
}

impl Default for Softsign {
    /// Same as [`Softsign::new`].
    fn default() -> Self {
        Softsign::new()
    }
}

impl_activation_module!(Softsign);
/// Randomized leaky ReLU configuration.
///
/// Negative inputs are scaled by a slope taken from the `lower`..`upper`
/// range; how the slope is chosen depends on the `training` flag.
#[derive(Debug, Clone)]
pub struct RReLU {
    /// Lower end of the slope range.
    pub lower: f64,
    /// Upper end of the slope range.
    pub upper: f64,
    /// Training-mode flag toggled through the `Module` trait.
    training: bool,
}
/// Produces a non-zero seed for the xorshift generator.
///
/// The seed is a `DefaultHasher` digest of the current system time and the
/// current thread id. A zero digest is replaced by a fixed constant, because
/// an all-zero xorshift state never changes.
fn rrelu_xorshift_seed() -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::time::SystemTime;

    let mut mixer = DefaultHasher::new();
    SystemTime::now().hash(&mut mixer);
    std::thread::current().id().hash(&mut mixer);
    match mixer.finish() {
        0 => 0xdeadbeefcafe,
        seed => seed,
    }
}
/// Advances `state` by one xorshift step (shifts 13, 7, 17) and returns the
/// new state scaled into `[0, 1]` by dividing by `u64::MAX`.
#[inline]
fn rrelu_xorshift_next(state: &mut u64) -> f64 {
    let mut s = *state;
    s ^= s << 13;
    s ^= s >> 7;
    s ^= s << 17;
    *state = s;
    (s as f64) / (u64::MAX as f64)
}
impl RReLU {
pub fn new(lower: f64, upper: f64) -> Self {
Self {
lower,
upper,
training: true,
}
}
pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let zero = <T as num_traits::Zero>::zero();
if self.training {
let rng_state = std::cell::Cell::new(rrelu_xorshift_seed());
let lower = self.lower;
let upper = self.upper;
let range = upper - lower;
unary_map(input, |x| {
if x >= zero {
x
} else {
let mut st = rng_state.get();
let u = rrelu_xorshift_next(&mut st);
rng_state.set(st);
let slope = T::from(lower + u * range).unwrap();
slope * x
}
})
} else {
let mean_slope = T::from((self.lower + self.upper) / 2.0).unwrap();
unary_map(input, |x| if x >= zero { x } else { mean_slope * x })
}
}
}
impl Default for RReLU {
fn default() -> Self {
Self::new(1.0 / 8.0, 1.0 / 3.0)
}
}
impl_activation_module!(RReLU);
#[cfg(test)]
mod tests {
use super::*;
use ferrotorch_core::TensorStorage;
/// Builds a 1-D `f64` CPU tensor holding a copy of `data`.
fn t(data: &[f64]) -> Tensor<f64> {
    let storage = TensorStorage::cpu(data.to_vec());
    let shape = vec![data.len()];
    Tensor::from_storage(storage, shape, false).unwrap()
}
/// Builds a `rows x cols` `f64` CPU tensor holding a copy of `data`.
fn t2d(data: &[f64], rows: usize, cols: usize) -> Tensor<f64> {
    let storage = TensorStorage::cpu(data.to_vec());
    let shape = vec![rows, cols];
    Tensor::from_storage(storage, shape, false).unwrap()
}
/// Checks the `Module` contract shared by parameter-free activations: no
/// parameters of any kind, training mode by default, and working
/// `eval()` / `train()` toggles.
fn assert_zero_param_module<M: Module<T>, T: Float>(module: &mut M) {
    let param_count = module.parameters().len();
    assert!(param_count == 0, "should have no parameters");

    let mutable_count = module.parameters_mut().len();
    assert!(mutable_count == 0, "should have no mutable parameters");

    let named_count = module.named_parameters().len();
    assert!(named_count == 0, "should have no named parameters");

    assert!(module.is_training(), "default should be training mode");
    module.eval();
    assert!(!module.is_training(), "eval() should set training=false");
    module.train();
    assert!(module.is_training(), "train() should set training=true");
}
/// ReLU zeroes negatives and keeps non-negative values.
#[test]
fn test_relu_forward() {
    let relu = ReLU::new();
    let input = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let output = relu.forward(&input).unwrap();
    let got = output.data().unwrap();
    let want: [f64; 5] = [0.0, 0.0, 0.0, 1.0, 2.0];
    for (i, w) in want.iter().enumerate() {
        assert!((got[i] - *w).abs() < 1e-7);
    }
}
/// `ReLU` satisfies the zero-parameter `Module` contract.
#[test]
fn test_relu_module_trait() {
    let mut module = ReLU::new();
    assert_zero_param_module::<ReLU, f64>(&mut module);
}
/// GELU is zero at zero, positive for positive inputs, and close to the
/// identity for large inputs.
#[test]
fn test_gelu_forward() {
    let gelu = GELU::new();

    let at_zero = gelu.forward(&t(&[0.0])).unwrap();
    assert!(at_zero.data().unwrap()[0].abs() < 1e-7);

    let positive = gelu.forward(&t(&[1.0, 2.0])).unwrap();
    let values = positive.data().unwrap();
    assert!(values[0] > 0.0);
    assert!(values[1] > 0.0);

    let large = gelu.forward(&t(&[10.0])).unwrap();
    assert!((large.data().unwrap()[0] - 10.0).abs() < 0.01);
}
/// `GELU` satisfies the zero-parameter `Module` contract.
#[test]
fn test_gelu_module_trait() {
    let mut module = GELU::new();
    assert_zero_param_module::<GELU, f64>(&mut module);
}
/// SiLU is zero at zero and close to the identity for large inputs.
#[test]
fn test_silu_forward() {
    let silu = SiLU::new();

    let at_zero = silu.forward(&t(&[0.0])).unwrap();
    assert!(at_zero.data().unwrap()[0].abs() < 1e-7);

    let large = silu.forward(&t(&[10.0])).unwrap();
    assert!((large.data().unwrap()[0] - 10.0).abs() < 0.01);
}
/// `SiLU` satisfies the zero-parameter `Module` contract.
#[test]
fn test_silu_module_trait() {
    let mut module = SiLU::new();
    assert_zero_param_module::<SiLU, f64>(&mut module);
}
/// Sigmoid is 0.5 at zero and saturates at 0 and 1 for extreme inputs.
#[test]
fn test_sigmoid_forward() {
    let sigmoid = Sigmoid::new();

    let at_zero = sigmoid.forward(&t(&[0.0])).unwrap();
    assert!((at_zero.data().unwrap()[0] - 0.5).abs() < 1e-7);

    let extremes = sigmoid.forward(&t(&[-100.0, 100.0])).unwrap();
    let values = extremes.data().unwrap();
    assert!(values[0] < 1e-10, "sigmoid(-100) should be ~0");
    assert!((values[1] - 1.0).abs() < 1e-10, "sigmoid(100) should be ~1");
}
/// `Sigmoid` satisfies the zero-parameter `Module` contract.
#[test]
fn test_sigmoid_module_trait() {
    let mut module = Sigmoid::new();
    assert_zero_param_module::<Sigmoid, f64>(&mut module);
}
/// Tanh is zero at zero and saturates at -1 and 1 for extreme inputs.
#[test]
fn test_tanh_forward() {
    let tanh = Tanh::new();

    let at_zero = tanh.forward(&t(&[0.0])).unwrap();
    assert!(at_zero.data().unwrap()[0].abs() < 1e-7);

    let extremes = tanh.forward(&t(&[-100.0, 100.0])).unwrap();
    let values = extremes.data().unwrap();
    assert!((values[0] + 1.0).abs() < 1e-10, "tanh(-100) should be ~-1");
    assert!((values[1] - 1.0).abs() < 1e-10, "tanh(100) should be ~1");
}
/// `Tanh` satisfies the zero-parameter `Module` contract.
#[test]
fn test_tanh_module_trait() {
    let mut module = Tanh::new();
    assert_zero_param_module::<Tanh, f64>(&mut module);
}
/// 1-D softmax sums to one and preserves the ordering of its inputs.
#[test]
fn test_softmax_forward_1d() {
    let softmax = Softmax::new(-1);
    let output = softmax.forward(&t(&[1.0, 2.0, 3.0])).unwrap();
    let probs = output.data().unwrap();
    let total = probs.iter().sum::<f64>();
    assert!((total - 1.0).abs() < 1e-7);
    assert!(probs[0] < probs[1]);
    assert!(probs[1] < probs[2]);
}
/// 2-D softmax over the last axis makes every row sum to one.
#[test]
fn test_softmax_forward_2d() {
    let softmax = Softmax::new(-1);
    let input = t2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
    let output = softmax.forward(&input).unwrap();
    let probs = output.data().unwrap();
    let first_row = probs[0] + probs[1];
    let second_row = probs[2] + probs[3];
    assert!((first_row - 1.0).abs() < 1e-7);
    assert!((second_row - 1.0).abs() < 1e-7);
}
/// Softmax over a non-last axis is rejected.
#[test]
fn test_softmax_wrong_dim() {
    let softmax = Softmax::new(0);
    let input = t2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
    let outcome = softmax.forward(&input);
    assert!(outcome.is_err());
}
/// `Softmax` satisfies the zero-parameter `Module` contract.
#[test]
fn test_softmax_module_trait() {
    let mut module = Softmax::new(-1);
    assert_zero_param_module::<Softmax, f64>(&mut module);
}
/// Exponentiated log-softmax sums to one, and every log-probability is <= 0.
#[test]
fn test_log_softmax_forward_1d() {
    let log_softmax = LogSoftmax::new(-1);
    let output = log_softmax.forward(&t(&[1.0, 2.0, 3.0])).unwrap();
    let log_probs = output.data().unwrap();
    let total: f64 = log_probs.iter().map(|&v| v.exp()).sum();
    assert!((total - 1.0).abs() < 1e-7, "exp(log_softmax) sum = {total}");
    assert!(log_probs.iter().all(|&v| v <= 0.0));
}
/// `LogSoftmax` satisfies the zero-parameter `Module` contract.
#[test]
fn test_log_softmax_module_trait() {
    let mut module = LogSoftmax::new(-1);
    assert_zero_param_module::<LogSoftmax, f64>(&mut module);
}
/// LeakyReLU with slope 0.01 scales negatives and keeps non-negatives.
#[test]
fn test_leaky_relu_forward() {
    let leaky = LeakyReLU::new(0.01);
    let input = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let output = leaky.forward(&input).unwrap();
    let out = output.data().unwrap();
    assert!((out[0] - (-0.02)).abs() < 1e-7, "LeakyReLU(-2) = {}", out[0]);
    assert!((out[1] - (-0.01)).abs() < 1e-7, "LeakyReLU(-1) = {}", out[1]);
    assert!((out[2] - 0.0).abs() < 1e-7, "LeakyReLU(0) = {}", out[2]);
    assert!((out[3] - 1.0).abs() < 1e-7, "LeakyReLU(1) = {}", out[3]);
    assert!((out[4] - 2.0).abs() < 1e-7, "LeakyReLU(2) = {}", out[4]);
}
/// LeakyReLU with slope 0.2 maps -5 to -1 and leaves 3 unchanged.
#[test]
fn test_leaky_relu_large_slope() {
    let leaky = LeakyReLU::new(0.2);
    let output = leaky.forward(&t(&[-5.0, 3.0])).unwrap();
    let out = output.data().unwrap();
    let negative_err = (out[0] - (-1.0)).abs();
    assert!(negative_err < 1e-7, "LeakyReLU(-5, slope=0.2) = {}", out[0]);
    let positive_err = (out[1] - 3.0).abs();
    assert!(positive_err < 1e-7, "LeakyReLU(3, slope=0.2) = {}", out[1]);
}
/// A zero slope makes LeakyReLU behave like ReLU.
#[test]
fn test_leaky_relu_zero_slope_is_relu() {
    let leaky = LeakyReLU::new(0.0);
    let output = leaky.forward(&t(&[-2.0, 0.0, 3.0])).unwrap();
    let got = output.data().unwrap();
    let want: [f64; 3] = [0.0, 0.0, 3.0];
    for (i, w) in want.iter().enumerate() {
        assert!((got[i] - *w).abs() < 1e-7);
    }
}
/// `LeakyReLU` satisfies the zero-parameter `Module` contract.
#[test]
fn test_leaky_relu_module_trait() {
    let mut module = LeakyReLU::new(0.01);
    assert_zero_param_module::<LeakyReLU, f64>(&mut module);
}
/// ELU(alpha = 1) keeps non-negatives, follows `exp(x) - 1` for negatives,
/// and approaches -1 for very negative inputs.
#[test]
fn test_elu_forward() {
    let elu = ELU::new(1.0);
    let output = elu.forward(&t(&[-2.0, -1.0, 0.0, 1.0, 2.0])).unwrap();
    let out = output.data().unwrap();

    assert!((out[3] - 1.0).abs() < 1e-7);
    assert!((out[4] - 2.0).abs() < 1e-7);
    assert!((out[2] - 0.0).abs() < 1e-7);

    let expected_m1 = 1.0 * ((-1.0_f64).exp() - 1.0);
    let err_m1 = (out[1] - expected_m1).abs();
    assert!(err_m1 < 1e-7, "ELU(-1) expected {}, got {}", expected_m1, out[1]);

    let expected_m2 = 1.0 * ((-2.0_f64).exp() - 1.0);
    let err_m2 = (out[0] - expected_m2).abs();
    assert!(err_m2 < 1e-7, "ELU(-2) expected {}, got {}", expected_m2, out[0]);

    let saturated = elu.forward(&t(&[-100.0])).unwrap();
    assert!((saturated.data().unwrap()[0] + 1.0).abs() < 1e-7);
}
/// ELU with alpha = 2 scales the negative branch by 2.
#[test]
fn test_elu_custom_alpha() {
    let elu = ELU::new(2.0);
    let output = elu.forward(&t(&[-1.0, 1.0])).unwrap();
    let out = output.data().unwrap();
    let negative_want = 2.0 * ((-1.0_f64).exp() - 1.0);
    assert!((out[0] - negative_want).abs() < 1e-7);
    assert!((out[1] - 1.0).abs() < 1e-7);
}
/// `ELU` satisfies the zero-parameter `Module` contract.
#[test]
fn test_elu_module_trait() {
    let mut module = ELU::new(1.0);
    assert_zero_param_module::<ELU, f64>(&mut module);
}
/// Mish is zero at zero, near the identity for large inputs, and equals
/// `x * tanh(softplus(x))` at -1.
#[test]
fn test_mish_forward() {
    let mish = Mish::new();

    let at_zero = mish.forward(&t(&[0.0])).unwrap();
    assert!(at_zero.data().unwrap()[0].abs() < 1e-7, "mish(0) should be 0");

    let large = mish.forward(&t(&[20.0])).unwrap();
    let large_err = (large.data().unwrap()[0] - 20.0).abs();
    assert!(large_err < 0.01, "mish(20) should be ~20");

    let negative = mish.forward(&t(&[-1.0])).unwrap();
    let val = negative.data().unwrap()[0];
    let sp = (1.0 + (-1.0_f64).exp()).ln();
    let expected = -1.0 * sp.tanh();
    let err = (val - expected).abs();
    assert!(err < 1e-7, "mish(-1) expected {}, got {}", expected, val);
}
/// `Mish` satisfies the zero-parameter `Module` contract.
#[test]
fn test_mish_module_trait() {
    let mut module = Mish::new();
    assert_zero_param_module::<Mish, f64>(&mut module);
}
/// PReLU with alpha = 0.25 scales negatives by 0.25 and keeps non-negatives.
#[test]
fn test_prelu_forward_default() {
    let prelu = PReLU::<f64>::new(0.25).unwrap();
    let input = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let output = prelu.forward(&input).unwrap();
    let out = output.data().unwrap();
    assert!((out[0] - (-0.5)).abs() < 1e-6, "PReLU(-2) = {}", out[0]);
    assert!((out[1] - (-0.25)).abs() < 1e-6, "PReLU(-1) = {}", out[1]);
    assert!((out[2] - 0.0).abs() < 1e-6, "PReLU(0) = {}", out[2]);
    assert!((out[3] - 1.0).abs() < 1e-6, "PReLU(1) = {}", out[3]);
    assert!((out[4] - 2.0).abs() < 1e-6, "PReLU(2) = {}", out[4]);
}
/// PReLU exposes exactly one parameter, named "alpha".
#[test]
fn test_prelu_has_parameter() {
    let prelu = PReLU::<f64>::new(0.25).unwrap();
    let param_count = prelu.parameters().len();
    assert_eq!(param_count, 1, "PReLU should have 1 parameter");
    let named = prelu.named_parameters();
    assert_eq!(named.len(), 1);
    let (name, _) = &named[0];
    assert_eq!(name, "alpha");
}
/// PReLU has one parameter and toggles training mode through the trait.
#[test]
fn test_prelu_module_trait() {
    let mut prelu = PReLU::<f64>::new(0.25).unwrap();
    assert_eq!(prelu.parameters().len(), 1);

    assert!(prelu.is_training());
    prelu.eval();
    assert!(!prelu.is_training());
    prelu.train();
    assert!(prelu.is_training());
}
/// CELU(alpha = 1) keeps non-negatives and follows `exp(x) - 1` at -1.
#[test]
fn test_celu_forward() {
    let celu = CELU::new(1.0);
    let output = celu.forward(&t(&[-2.0, -1.0, 0.0, 1.0, 2.0])).unwrap();
    let out = output.data().unwrap();
    assert!((out[3] - 1.0).abs() < 1e-7);
    assert!((out[4] - 2.0).abs() < 1e-7);
    assert!((out[2] - 0.0).abs() < 1e-7);
    let want_m1 = 1.0 * ((-1.0_f64).exp() - 1.0);
    assert!((out[1] - want_m1).abs() < 1e-7, "CELU(-1) = {}", out[1]);
}
/// `CELU` satisfies the zero-parameter `Module` contract.
#[test]
fn test_celu_module_trait() {
    let mut module = CELU::new(1.0);
    assert_zero_param_module::<CELU, f64>(&mut module);
}
/// SELU matches its closed form at -1, 0 and 1.
#[test]
fn test_selu_forward() {
    let scale = 1.0507009873554805_f64;
    let coeff = 1.6732632423543772_f64;

    let selu = SELU::new();
    let output = selu.forward(&t(&[-1.0, 0.0, 1.0])).unwrap();
    let out = output.data().unwrap();

    assert!((out[2] - scale * 1.0).abs() < 1e-7, "SELU(1) = {}", out[2]);
    assert!((out[1] - 0.0).abs() < 1e-7, "SELU(0) = {}", out[1]);
    let want_m1 = scale * coeff * ((-1.0_f64).exp() - 1.0);
    assert!((out[0] - want_m1).abs() < 1e-7, "SELU(-1) = {}", out[0]);
}
/// `SELU` satisfies the zero-parameter `Module` contract.
#[test]
fn test_selu_module_trait() {
    let mut module = SELU::new();
    assert_zero_param_module::<SELU, f64>(&mut module);
}
/// HardSigmoid saturates at 0 below -3, at 1 above 3, and is 0.5 at zero.
#[test]
fn test_hard_sigmoid_forward() {
    let hard_sigmoid = HardSigmoid::new();
    let input = t(&[-4.0, -3.0, 0.0, 3.0, 5.0]);
    let output = hard_sigmoid.forward(&input).unwrap();
    let out = output.data().unwrap();
    assert!((out[0] - 0.0).abs() < 1e-7, "HardSigmoid(-4) = {}", out[0]);
    assert!((out[1] - 0.0).abs() < 1e-7, "HardSigmoid(-3) = {}", out[1]);
    assert!((out[2] - 0.5).abs() < 1e-7, "HardSigmoid(0) = {}", out[2]);
    assert!((out[3] - 1.0).abs() < 1e-7, "HardSigmoid(3) = {}", out[3]);
    assert!((out[4] - 1.0).abs() < 1e-7, "HardSigmoid(5) = {}", out[4]);
}
/// `HardSigmoid` satisfies the zero-parameter `Module` contract.
#[test]
fn test_hard_sigmoid_module_trait() {
    let mut module = HardSigmoid::new();
    assert_zero_param_module::<HardSigmoid, f64>(&mut module);
}
/// HardSwish is zero at or below -3, the identity at or above 3, and
/// `x * (x + 3) / 6` in between.
#[test]
fn test_hard_swish_forward() {
    let hard_swish = HardSwish::new();
    let input = t(&[-4.0, 0.0, 3.0, 5.0, -1.0]);
    let output = hard_swish.forward(&input).unwrap();
    let out = output.data().unwrap();
    assert!((out[0] - 0.0).abs() < 1e-7, "HardSwish(-4) = {}", out[0]);
    assert!((out[1] - 0.0).abs() < 1e-7, "HardSwish(0) = {}", out[1]);
    assert!((out[2] - 3.0).abs() < 1e-7, "HardSwish(3) = {}", out[2]);
    assert!((out[3] - 5.0).abs() < 1e-7, "HardSwish(5) = {}", out[3]);
    let mid_err = (out[4] - (-1.0 / 3.0)).abs();
    assert!(mid_err < 1e-7, "HardSwish(-1) = {}", out[4]);
}
/// `HardSwish` satisfies the zero-parameter `Module` contract.
#[test]
fn test_hard_swish_module_trait() {
    let mut module = HardSwish::new();
    assert_zero_param_module::<HardSwish, f64>(&mut module);
}
/// Softplus(beta = 1) is ln(2) at zero, ~x for large x, and ln(1 + e) at 1.
#[test]
fn test_softplus_forward() {
    let softplus = Softplus::new(1.0);

    let at_zero = softplus.forward(&t(&[0.0])).unwrap();
    let out = at_zero.data().unwrap();
    assert!((out[0] - 2.0_f64.ln()).abs() < 1e-7, "Softplus(0) = {}", out[0]);

    let large = softplus.forward(&t(&[25.0])).unwrap();
    let out = large.data().unwrap();
    assert!((out[0] - 25.0).abs() < 1e-5, "Softplus(25) = {}", out[0]);

    let at_one = softplus.forward(&t(&[1.0])).unwrap();
    let out = at_one.data().unwrap();
    let want = (1.0 + 1.0_f64.exp()).ln();
    assert!((out[0] - want).abs() < 1e-7, "Softplus(1) = {}", out[0]);
}
/// Softplus with beta = 2 evaluates to ln(2) / 2 at zero.
#[test]
fn test_softplus_custom_beta() {
    let softplus = Softplus::new(2.0);
    let output = softplus.forward(&t(&[0.0])).unwrap();
    let out = output.data().unwrap();
    let want = 2.0_f64.ln() / 2.0;
    let err = (out[0] - want).abs();
    assert!(err < 1e-7, "Softplus(0, beta=2) = {}", out[0]);
}
/// `Softplus` satisfies the zero-parameter `Module` contract.
#[test]
fn test_softplus_module_trait() {
    let mut module = Softplus::new(1.0);
    assert_zero_param_module::<Softplus, f64>(&mut module);
}
/// 1-D GLU halves the length and gates the first half by sigmoid(second half).
#[test]
fn test_glu_forward_1d() {
    let glu = GLU::new();
    let output = glu.forward(&t(&[1.0, 0.0, 2.0, 0.0])).unwrap();
    assert_eq!(output.shape(), &[2]);
    let out = output.data().unwrap();
    let gate = 1.0 / (1.0 + (-2.0_f64).exp());
    assert!((out[0] - gate).abs() < 1e-7, "GLU[0] = {}", out[0]);
    assert!((out[1] - 0.0).abs() < 1e-7, "GLU[1] = {}", out[1]);
}
/// 2-D GLU halves the last axis and keeps the leading axis.
#[test]
fn test_glu_forward_2d() {
    let glu = GLU::new();
    let input = t2d(&[1.0, 0.0, 2.0, 0.0], 1, 4);
    let output = glu.forward(&input).unwrap();
    assert_eq!(output.shape(), &[1, 2]);
    let out = output.data().unwrap();
    let gate = 1.0 / (1.0 + (-2.0_f64).exp());
    assert!((out[0] - gate).abs() < 1e-7);
    assert!((out[1] - 0.0).abs() < 1e-7);
}
/// GLU rejects an odd-sized last dimension.
#[test]
fn test_glu_odd_dim_error() {
    let glu = GLU::new();
    let odd_input = t(&[1.0, 2.0, 3.0]);
    let outcome = glu.forward(&odd_input);
    assert!(outcome.is_err());
}
/// `GLU` satisfies the zero-parameter `Module` contract.
#[test]
fn test_glu_module_trait() {
    let mut module = GLU::new();
    assert_zero_param_module::<GLU, f64>(&mut module);
}
/// ReLU6 clamps its input to the range [0, 6].
#[test]
fn test_relu6_forward() {
    let relu6 = ReLU6::new();
    let input = t(&[-2.0, 0.0, 3.0, 6.0, 10.0]);
    let output = relu6.forward(&input).unwrap();
    let out = output.data().unwrap();
    assert!((out[0] - 0.0).abs() < 1e-7, "ReLU6(-2) = {}", out[0]);
    assert!((out[1] - 0.0).abs() < 1e-7, "ReLU6(0) = {}", out[1]);
    assert!((out[2] - 3.0).abs() < 1e-7, "ReLU6(3) = {}", out[2]);
    assert!((out[3] - 6.0).abs() < 1e-7, "ReLU6(6) = {}", out[3]);
    assert!((out[4] - 6.0).abs() < 1e-7, "ReLU6(10) = {}", out[4]);
}
/// `ReLU6` satisfies the zero-parameter `Module` contract.
#[test]
fn test_relu6_module_trait() {
    let mut module = ReLU6::new();
    assert_zero_param_module::<ReLU6, f64>(&mut module);
}
/// Default `Hardtanh` must map `[-5, -1, 0, 0.5, 1, 3]` to
/// `[-1, -1, 0, 0.5, 1, 1]`.
#[test]
fn test_hardtanh_forward_default() {
    let clamp = Hardtanh::default();
    let input = t(&[-5.0, -1.0, 0.0, 0.5, 1.0, 3.0]);
    let output = clamp.forward(&input).unwrap();
    let out = output.data().unwrap();

    // `v + 1.0` is the same IEEE operation as `v - (-1.0)`.
    assert!((out[0] + 1.0).abs() < 1e-7, "Hardtanh(-5) = {}", out[0]);
    assert!((out[1] + 1.0).abs() < 1e-7, "Hardtanh(-1) = {}", out[1]);
    assert!((out[2] - 0.0).abs() < 1e-7, "Hardtanh(0) = {}", out[2]);
    assert!((out[3] - 0.5).abs() < 1e-7, "Hardtanh(0.5) = {}", out[3]);
    assert!((out[4] - 1.0).abs() < 1e-7, "Hardtanh(1) = {}", out[4]);
    assert!((out[5] - 1.0).abs() < 1e-7, "Hardtanh(3) = {}", out[5]);
}
/// `Hardtanh::new(-2, 2)` must map `[-5, -2, 0, 2, 5]` to `[-2, -2, 0, 2, 2]`.
#[test]
fn test_hardtanh_custom_range() {
    let clamp = Hardtanh::new(-2.0, 2.0);
    let input = t(&[-5.0, -2.0, 0.0, 2.0, 5.0]);
    let output = clamp.forward(&input).unwrap();
    let out = output.data().unwrap();

    let wanted = [-2.0_f64, -2.0, 0.0, 2.0, 2.0];
    // Index into `out` directly so a short output still panics.
    for (idx, &want) in wanted.iter().enumerate() {
        assert!((out[idx] - want).abs() < 1e-7);
    }
}
/// Runs the shared `assert_zero_param_module` helper against `Hardtanh`.
#[test]
fn test_hardtanh_module_trait() {
    let mut clamp = Hardtanh::default();
    assert_zero_param_module::<Hardtanh, f64>(&mut clamp);
}
/// Checks `LogSigmoid` at zero (`-ln 2`), that all outputs are non-positive,
/// and its behaviour at the two extremes `-10` and `10`.
#[test]
fn test_log_sigmoid_forward() {
    let log_sigmoid = LogSigmoid::new();

    // Exact value at the origin.
    let neg_ln_two = -(2.0_f64.ln());
    let origin = t(&[0.0]);
    let origin_out = log_sigmoid.forward(&origin).unwrap();
    let at_origin = origin_out.data().unwrap()[0];
    assert!(
        (at_origin - neg_ln_two).abs() < 1e-6,
        "LogSigmoid(0) = {}, expected {}",
        at_origin,
        neg_ln_two
    );

    // Sign and extremes over a spread of inputs.
    let spread = t(&[-10.0, -1.0, 0.0, 1.0, 10.0]);
    let spread_out = log_sigmoid.forward(&spread).unwrap();
    let out = spread_out.data().unwrap();
    let all_non_positive = out.iter().copied().all(|v| v <= 0.0);
    assert!(all_non_positive, "All LogSigmoid values should be <= 0");

    let far_right = out[4];
    assert!(
        far_right.abs() < 1e-4,
        "LogSigmoid(10) should be ~0, got {}",
        far_right
    );
    let far_left = out[0];
    assert!(
        (far_left + 10.0).abs() < 0.1,
        "LogSigmoid(-10) should be ~-10, got {}",
        far_left
    );
}
/// Runs the shared `assert_zero_param_module` helper against `LogSigmoid`.
#[test]
fn test_log_sigmoid_module_trait() {
    let mut log_sigmoid = LogSigmoid::new();
    assert_zero_param_module::<LogSigmoid, f64>(&mut log_sigmoid);
}
/// `Softmin` over `[1, 2, 3]` must sum to one and be strictly decreasing.
#[test]
fn test_softmin_forward_1d() {
    let softmin = Softmin::new(-1);
    let input = t(&[1.0, 2.0, 3.0]);
    let output = softmin.forward(&input).unwrap();
    let probs = output.data().unwrap();

    let total = probs.iter().copied().sum::<f64>();
    assert!((total - 1.0).abs() < 1e-7, "Softmin sum = {}", total);

    let (p1, p2, p3) = (probs[0], probs[1], probs[2]);
    assert!(p1 > p2, "softmin(1) > softmin(2)");
    assert!(p2 > p3, "softmin(2) > softmin(3)");
}
/// `Softmin::new(0)` applied to a `2 x 2` tensor must return an error.
#[test]
fn test_softmin_wrong_dim() {
    let softmin = Softmin::new(0);
    let matrix = t2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
    let result = softmin.forward(&matrix);
    assert!(result.is_err());
}
/// Runs the shared `assert_zero_param_module` helper against `Softmin`.
#[test]
fn test_softmin_module_trait() {
    let mut softmin = Softmin::new(-1);
    assert_zero_param_module::<Softmin, f64>(&mut softmin);
}
/// `Threshold::new(0.5, -1.0)` must map `[-1, 0, 0.5, 1, 2]` to
/// `[-1, -1, -1, 1, 2]`; the input equal to the threshold is replaced too.
#[test]
fn test_threshold_forward() {
    let gate = Threshold::new(0.5, -1.0);
    let input = t(&[-1.0, 0.0, 0.5, 1.0, 2.0]);
    let output = gate.forward(&input).unwrap();
    let out = output.data().unwrap();

    // `v + 1.0` is the same IEEE operation as `v - (-1.0)`.
    assert!((out[0] + 1.0).abs() < 1e-7, "Threshold(-1) = {}", out[0]);
    assert!((out[1] + 1.0).abs() < 1e-7, "Threshold(0) = {}", out[1]);
    assert!((out[2] + 1.0).abs() < 1e-7, "Threshold(0.5) = {}", out[2]);
    assert!((out[3] - 1.0).abs() < 1e-7, "Threshold(1) = {}", out[3]);
    assert!((out[4] - 2.0).abs() < 1e-7, "Threshold(2) = {}", out[4]);
}
/// Runs the shared `assert_zero_param_module` helper against `Threshold`.
#[test]
fn test_threshold_module_trait() {
    let mut gate = Threshold::new(0.5, -1.0);
    assert_zero_param_module::<Threshold, f64>(&mut gate);
}
/// Default `Softshrink` must map `[-2, -0.5, -0.3, 0, 0.3, 0.5, 2]` to
/// `[-1.5, 0, 0, 0, 0, 0, 1.5]`.
#[test]
fn test_softshrink_forward() {
    let shrink = Softshrink::default();
    let input = t(&[-2.0, -0.5, -0.3, 0.0, 0.3, 0.5, 2.0]);
    let output = shrink.forward(&input).unwrap();
    let out = output.data().unwrap();

    // Outside the dead zone: shifted toward zero.
    assert!((out[0] + 1.5).abs() < 1e-7, "Softshrink(-2) = {}", out[0]);
    assert!((out[6] - 1.5).abs() < 1e-7, "Softshrink(2) = {}", out[6]);

    // On the boundary of the dead zone.
    assert!((out[1] - 0.0).abs() < 1e-7, "Softshrink(-0.5) = {}", out[1]);
    assert!((out[5] - 0.0).abs() < 1e-7, "Softshrink(0.5) = {}", out[5]);

    // Strictly inside the dead zone.
    assert!((out[2] - 0.0).abs() < 1e-7, "Softshrink(-0.3) = {}", out[2]);
    assert!((out[3] - 0.0).abs() < 1e-7, "Softshrink(0) = {}", out[3]);
    assert!((out[4] - 0.0).abs() < 1e-7, "Softshrink(0.3) = {}", out[4]);
}
/// `Softshrink::new(1.0)` must map `[-2, -0.5, 0.5, 2]` to `[-1, 0, 0, 1]`.
#[test]
fn test_softshrink_custom_lambda() {
    let shrink = Softshrink::new(1.0);
    let input = t(&[-2.0, -0.5, 0.5, 2.0]);
    let output = shrink.forward(&input).unwrap();
    let out = output.data().unwrap();

    let wanted = [-1.0_f64, 0.0, 0.0, 1.0];
    // Index into `out` directly so a short output still panics.
    for (idx, &want) in wanted.iter().enumerate() {
        assert!((out[idx] - want).abs() < 1e-7);
    }
}
/// Runs the shared `assert_zero_param_module` helper against `Softshrink`.
#[test]
fn test_softshrink_module_trait() {
    let mut shrink = Softshrink::default();
    assert_zero_param_module::<Softshrink, f64>(&mut shrink);
}
/// Default `Hardshrink` must map `[-2, -0.5, -0.3, 0, 0.3, 0.5, 2]` to
/// `[-2, 0, 0, 0, 0, 0, 2]`.
#[test]
fn test_hardshrink_forward() {
    let shrink = Hardshrink::default();
    let input = t(&[-2.0, -0.5, -0.3, 0.0, 0.3, 0.5, 2.0]);
    let output = shrink.forward(&input).unwrap();
    let out = output.data().unwrap();

    // Outside the dead zone: passed through unchanged.
    assert!((out[0] + 2.0).abs() < 1e-7, "Hardshrink(-2) = {}", out[0]);
    assert!((out[6] - 2.0).abs() < 1e-7, "Hardshrink(2) = {}", out[6]);

    // On the boundary of the dead zone.
    assert!((out[1] - 0.0).abs() < 1e-7, "Hardshrink(-0.5) = {}", out[1]);
    assert!((out[5] - 0.0).abs() < 1e-7, "Hardshrink(0.5) = {}", out[5]);

    // Strictly inside the dead zone.
    assert!((out[2] - 0.0).abs() < 1e-7, "Hardshrink(-0.3) = {}", out[2]);
    assert!((out[3] - 0.0).abs() < 1e-7, "Hardshrink(0) = {}", out[3]);
    assert!((out[4] - 0.0).abs() < 1e-7, "Hardshrink(0.3) = {}", out[4]);
}
/// Runs the shared `assert_zero_param_module` helper against `Hardshrink`.
#[test]
fn test_hardshrink_module_trait() {
    let mut shrink = Hardshrink::default();
    assert_zero_param_module::<Hardshrink, f64>(&mut shrink);
}
/// Checks `Tanhshrink` at zero, at `+/-10` (roughly `+/-9`), and at one
/// against the closed form `x - tanh(x)`.
#[test]
fn test_tanhshrink_forward() {
    let shrink = Tanhshrink::new();

    // Fixed point at the origin.
    let origin = t(&[0.0]);
    let origin_out = shrink.forward(&origin).unwrap();
    let at_origin = origin_out.data().unwrap()[0];
    assert!(at_origin.abs() < 1e-7, "Tanhshrink(0) should be 0");

    // Large magnitudes.
    let extremes = t(&[10.0, -10.0]);
    let extremes_out = shrink.forward(&extremes).unwrap();
    let far = extremes_out.data().unwrap();
    let far_right = far[0];
    assert!(
        (far_right - 9.0).abs() < 0.01,
        "Tanhshrink(10) should be ~9, got {}",
        far_right
    );
    let far_left = far[1];
    assert!(
        (far_left + 9.0).abs() < 0.01,
        "Tanhshrink(-10) should be ~-9, got {}",
        far_left
    );

    // Closed-form comparison at one.
    let unit = t(&[1.0]);
    let unit_out = shrink.forward(&unit).unwrap();
    let expected = 1.0 - 1.0_f64.tanh();
    let at_one = unit_out.data().unwrap()[0];
    assert!(
        (at_one - expected).abs() < 1e-7,
        "Tanhshrink(1) expected {}, got {}",
        expected,
        at_one
    );
}
/// Runs the shared `assert_zero_param_module` helper against `Tanhshrink`.
#[test]
fn test_tanhshrink_module_trait() {
    let mut shrink = Tanhshrink::new();
    assert_zero_param_module::<Tanhshrink, f64>(&mut shrink);
}
/// Checks `Softsign` at `0`, `1`, `-1`, and that `+/-100` land strictly
/// inside `(0.99, 1)` and `(-1, -0.99)`.
#[test]
fn test_softsign_forward() {
    let softsign = Softsign::new();

    let origin = t(&[0.0]);
    let origin_out = softsign.forward(&origin).unwrap();
    let at_origin = origin_out.data().unwrap()[0];
    assert!(at_origin.abs() < 1e-7, "Softsign(0) should be 0");

    let plus_one = t(&[1.0]);
    let plus_one_out = softsign.forward(&plus_one).unwrap();
    let at_plus_one = plus_one_out.data().unwrap()[0];
    assert!(
        (at_plus_one - 0.5).abs() < 1e-7,
        "Softsign(1) should be 0.5"
    );

    let minus_one = t(&[-1.0]);
    let minus_one_out = softsign.forward(&minus_one).unwrap();
    let at_minus_one = minus_one_out.data().unwrap()[0];
    assert!(
        (at_minus_one + 0.5).abs() < 1e-7,
        "Softsign(-1) should be -0.5"
    );

    let extremes = t(&[100.0, -100.0]);
    let extremes_out = softsign.forward(&extremes).unwrap();
    let far = extremes_out.data().unwrap();
    let far_right = far[0];
    assert!(
        far_right > 0.99 && far_right < 1.0,
        "Softsign(100) should be ~1, got {}",
        far_right
    );
    let far_left = far[1];
    assert!(
        far_left < -0.99 && far_left > -1.0,
        "Softsign(-100) should be ~-1, got {}",
        far_left
    );
}
/// Runs the shared `assert_zero_param_module` helper against `Softsign`.
#[test]
fn test_softsign_module_trait() {
    let mut softsign = Softsign::new();
    assert_zero_param_module::<Softsign, f64>(&mut softsign);
}
/// With `training` switched off, default `RReLU` must scale negative inputs
/// by the midpoint of `1/8` and `1/3` and leave non-negative inputs untouched.
#[test]
fn test_rrelu_eval_forward() {
    let mut rrelu = RReLU::default();
    rrelu.training = false;

    let mean_slope = (1.0 / 8.0 + 1.0 / 3.0) / 2.0;
    let input = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let output = rrelu.forward(&input).unwrap();
    let out = output.data().unwrap();

    let want_minus_two = -2.0 * mean_slope;
    assert!(
        (out[0] - want_minus_two).abs() < 1e-7,
        "RReLU(-2,eval) = {}",
        out[0]
    );
    let want_minus_one = -mean_slope;
    assert!(
        (out[1] - want_minus_one).abs() < 1e-7,
        "RReLU(-1,eval) = {}",
        out[1]
    );
    assert!((out[2] - 0.0).abs() < 1e-7, "RReLU(0,eval) = {}", out[2]);
    assert!((out[3] - 1.0).abs() < 1e-7, "RReLU(1,eval) = {}", out[3]);
    assert!((out[4] - 2.0).abs() < 1e-7, "RReLU(2,eval) = {}", out[4]);
}
/// A freshly constructed default `RReLU` must pass `[0, 1, 5, 100]` through
/// unchanged.
#[test]
fn test_rrelu_training_positive_passthrough() {
    let rrelu = RReLU::default();
    let inputs = [0.0_f64, 1.0, 5.0, 100.0];
    let x = t(&inputs);
    let output = rrelu.forward(&x).unwrap();
    let out = output.data().unwrap();

    // Each output is compared with the input at the same position.
    for (idx, &want) in inputs.iter().enumerate() {
        assert!((out[idx] - want).abs() < 1e-7);
    }
}
/// For 100 copies of `-1`, a freshly built `RReLU::new(0.1, 0.5)` must yield
/// outputs within `[-0.5, -0.1]` (with `1e-7` slack) that are not all equal.
#[test]
fn test_rrelu_training_negative_bounded() {
    let rrelu = RReLU::new(0.1, 0.5);
    let input = t(&[-1.0; 100]);
    let output = rrelu.forward(&input).unwrap();
    let out = output.data().unwrap();

    let lowest: f64 = -0.5 - 1e-7;
    let highest: f64 = -0.1 + 1e-7;
    for (i, &val) in out.iter().enumerate() {
        assert!(
            (lowest..=highest).contains(&val),
            "RReLU(-1, train)[{}] = {} not in [-0.5, -0.1]",
            i,
            val
        );
    }

    // At least one element must differ from the first.
    let reference = out[0];
    let has_variance = out
        .iter()
        .copied()
        .any(|v| (v - reference).abs() > 1e-10);
    assert!(has_variance, "RReLU training should produce varying slopes");
}
/// Runs the shared `assert_zero_param_module` helper against `RReLU`.
#[test]
fn test_rrelu_module_trait() {
    let mut rrelu = RReLU::default();
    assert_zero_param_module::<RReLU, f64>(&mut rrelu);
}
/// Every activation must be constructible through `Default`, and the
/// configurable ones must carry the default hyper-parameters asserted below.
#[test]
fn test_defaults() {
    let _relu: ReLU = Default::default();
    let _gelu: GELU = Default::default();
    let _silu: SiLU = Default::default();
    let _sigmoid: Sigmoid = Default::default();
    let _tanh: Tanh = Default::default();
    let _softmax: Softmax = Default::default();
    let _log_softmax: LogSoftmax = Default::default();

    let leaky: LeakyReLU = Default::default();
    assert!((leaky.negative_slope - 0.01).abs() < f64::EPSILON);

    let elu: ELU = Default::default();
    assert!((elu.alpha - 1.0).abs() < f64::EPSILON);

    let _mish: Mish = Default::default();

    let celu: CELU = Default::default();
    assert!((celu.alpha - 1.0).abs() < f64::EPSILON);

    let _selu: SELU = Default::default();
    let _hard_sigmoid: HardSigmoid = Default::default();
    let _hard_swish: HardSwish = Default::default();

    let softplus: Softplus = Default::default();
    assert!((softplus.beta - 1.0).abs() < f64::EPSILON);

    let _glu: GLU = Default::default();
    let _relu6: ReLU6 = Default::default();

    let hardtanh: Hardtanh = Default::default();
    assert!((hardtanh.min_val + 1.0).abs() < f64::EPSILON);
    assert!((hardtanh.max_val - 1.0).abs() < f64::EPSILON);

    let _log_sigmoid: LogSigmoid = Default::default();
    let _softmin: Softmin = Default::default();

    let softshrink: Softshrink = Default::default();
    assert!((softshrink.lambda - 0.5).abs() < f64::EPSILON);

    let hardshrink: Hardshrink = Default::default();
    assert!((hardshrink.lambda - 0.5).abs() < f64::EPSILON);

    let _tanhshrink: Tanhshrink = Default::default();
    let _softsign: Softsign = Default::default();

    let rrelu: RReLU = Default::default();
    assert!((rrelu.lower - 1.0 / 8.0).abs() < f64::EPSILON);
    assert!((rrelu.upper - 1.0 / 3.0).abs() < f64::EPSILON);
}
/// Compile-time check that every activation module type is `Send + Sync`.
#[test]
fn test_send_sync() {
    // Instantiating this with a type only compiles if the bound holds.
    fn require_send_sync<M>()
    where
        M: Send + Sync,
    {
    }

    require_send_sync::<ReLU>();
    require_send_sync::<GELU>();
    require_send_sync::<SiLU>();
    require_send_sync::<Sigmoid>();
    require_send_sync::<Tanh>();
    require_send_sync::<Softmax>();
    require_send_sync::<LogSoftmax>();
    require_send_sync::<LeakyReLU>();
    require_send_sync::<ELU>();
    require_send_sync::<Mish>();
    require_send_sync::<PReLU<f64>>();
    require_send_sync::<CELU>();
    require_send_sync::<SELU>();
    require_send_sync::<HardSigmoid>();
    require_send_sync::<HardSwish>();
    require_send_sync::<Softplus>();
    require_send_sync::<GLU>();
    require_send_sync::<ReLU6>();
    require_send_sync::<Hardtanh>();
    require_send_sync::<LogSigmoid>();
    require_send_sync::<Softmin>();
    require_send_sync::<Threshold>();
    require_send_sync::<Softshrink>();
    require_send_sync::<Hardshrink>();
    require_send_sync::<Tanhshrink>();
    require_send_sync::<Softsign>();
    require_send_sync::<RReLU>();
}
/// Builds a 1-D CPU tensor from `data`, passing `true` as the last
/// `from_storage` argument (presumably `requires_grad`; the backward tests
/// read `.grad()` from tensors built here).
fn t_grad(data: &[f64]) -> Tensor<f64> {
    let len = data.len();
    let storage = TensorStorage::cpu(data.to_vec());
    Tensor::from_storage(storage, vec![len], true).unwrap()
}
/// Builds a CPU tensor with an empty shape holding `val`, passing `true` as
/// the last `from_storage` argument (presumably `requires_grad`).
fn t_scalar_grad(val: f64) -> Tensor<f64> {
    let storage = TensorStorage::cpu(vec![val]);
    let scalar_shape = Vec::new();
    Tensor::from_storage(storage, scalar_shape, true).unwrap()
}
/// Central-difference estimate of `f'(x)` using a fixed step of `1e-5`.
fn numerical_grad(f: impl Fn(f64) -> f64, x: f64) -> f64 {
    let step = 1e-5;
    let ahead = f(x + step);
    let behind = f(x - step);
    (ahead - behind) / (2.0 * step)
}
/// After `backward` through `Softplus`, the scalar input must hold a gradient.
#[test]
fn test_softplus_backward_produces_grad() {
    let input = t_scalar_grad(1.0);
    let layer = Softplus::new(1.0);
    let output = layer.forward(&input).unwrap();
    ferrotorch_core::backward(&output).unwrap();

    let maybe_grad = input.grad().unwrap();
    assert!(
        maybe_grad.is_some(),
        "Softplus backward should produce a gradient"
    );
}
/// The `Softplus` gradient at `x = 0` must be `0.5`.
#[test]
fn test_softplus_backward_at_zero() {
    let input = t_scalar_grad(0.0);
    let layer = Softplus::new(1.0);
    let output = layer.forward(&input).unwrap();
    ferrotorch_core::backward(&output).unwrap();

    let grad = input.grad().unwrap().unwrap();
    let got = grad.item().unwrap();
    assert!(
        (got - 0.5).abs() < 1e-6,
        "Softplus grad at x=0: expected 0.5, got {}",
        got
    );
}
/// The autograd gradient of `Softplus` (beta 1) must agree with a central
/// difference of `ln(1 + e^x)` at several points.
#[test]
fn test_softplus_backward_matches_numerical() {
    fn softplus_ref(v: f64) -> f64 {
        (1.0 + v.exp()).ln()
    }

    let points = [-2.0_f64, -0.5, 0.0, 1.0, 3.0];
    for val in points.iter().copied() {
        let input = t_scalar_grad(val);
        let layer = Softplus::new(1.0);
        let output = layer.forward(&input).unwrap();
        ferrotorch_core::backward(&output).unwrap();

        let grad = input.grad().unwrap().unwrap();
        let got = grad.item().unwrap();
        let expected = numerical_grad(softplus_ref, val);
        assert!(
            (got - expected).abs() < 1e-4,
            "Softplus grad at x={}: expected {}, got {}",
            val,
            expected,
            got
        );
    }
}
/// With `beta = 2`, the `Softplus` gradient at `x = 1` must match a central
/// difference of `ln(1 + e^(beta x)) / beta`.
#[test]
fn test_softplus_backward_custom_beta() {
    let val: f64 = 1.0;
    let beta: f64 = 2.0;

    let input = t_scalar_grad(val);
    let layer = Softplus::new(beta);
    let output = layer.forward(&input).unwrap();
    ferrotorch_core::backward(&output).unwrap();

    let grad = input.grad().unwrap().unwrap();
    let got = grad.item().unwrap();
    let reference = |v: f64| (1.0 + (beta * v).exp()).ln() / beta;
    let expected = numerical_grad(reference, val);
    assert!(
        (got - expected).abs() < 1e-4,
        "Softplus grad at x={}, beta={}: expected {}, got {}",
        val,
        beta,
        expected,
        got
    );
}
/// Backpropagating `sum(Softplus(x))` over a vector must give per-element
/// gradients that match a central difference of `ln(1 + e^x)`.
#[test]
fn test_softplus_backward_vector() {
    fn softplus_ref(v: f64) -> f64 {
        (1.0 + v.exp()).ln()
    }

    let points = [-2.0_f64, -0.5, 0.0, 1.0, 3.0];
    let input = t_grad(&points);
    let layer = Softplus::new(1.0);
    let output = layer.forward(&input).unwrap();
    let total = ferrotorch_core::grad_fns::reduction::sum(&output).unwrap();
    ferrotorch_core::backward(&total).unwrap();

    let grad = input.grad().unwrap().unwrap();
    let grad_values = grad.data().unwrap();
    for (i, &val) in points.iter().enumerate() {
        let expected = numerical_grad(softplus_ref, val);
        let got = grad_values[i];
        assert!(
            (got - expected).abs() < 1e-4,
            "Softplus grad[{}] at x={}: expected {}, got {}",
            i,
            val,
            expected,
            got
        );
    }
}
/// After `backward` through `ELU` at `x = -1`, the input must hold a gradient.
#[test]
fn test_elu_backward_produces_grad() {
    let input = t_scalar_grad(-1.0);
    let layer = ELU::new(1.0);
    let output = layer.forward(&input).unwrap();
    ferrotorch_core::backward(&output).unwrap();

    let maybe_grad = input.grad().unwrap();
    assert!(maybe_grad.is_some(), "ELU backward should produce a gradient");
}
/// The `ELU` gradient at the positive point `x = 2` must be `1.0`.
#[test]
fn test_elu_backward_positive() {
    let input = t_scalar_grad(2.0);
    let layer = ELU::new(1.0);
    let output = layer.forward(&input).unwrap();
    ferrotorch_core::backward(&output).unwrap();

    let grad = input.grad().unwrap().unwrap();
    let got = grad.item().unwrap();
    assert!(
        (got - 1.0).abs() < 1e-6,
        "ELU grad at x=2: expected 1.0, got {}",
        got
    );
}
/// The autograd gradient of `ELU` (alpha 1) must agree with a central
/// difference of the piecewise reference at points on both sides of zero.
#[test]
fn test_elu_backward_matches_numerical() {
    let alpha: f64 = 1.0;
    let elu_ref = |v: f64| -> f64 {
        if v > 0.0 {
            v
        } else {
            alpha * (v.exp() - 1.0)
        }
    };

    let points = [-2.0_f64, -1.0, -0.5, 0.5, 2.0];
    for val in points.iter().copied() {
        let input = t_scalar_grad(val);
        let layer = ELU::new(alpha);
        let output = layer.forward(&input).unwrap();
        ferrotorch_core::backward(&output).unwrap();

        let grad = input.grad().unwrap().unwrap();
        let got = grad.item().unwrap();
        let expected = numerical_grad(&elu_ref, val);
        assert!(
            (got - expected).abs() < 1e-4,
            "ELU grad at x={}: expected {}, got {}",
            val,
            expected,
            got
        );
    }
}
/// With `alpha = 2`, the `ELU` gradient at `x = -0.5` must equal the analytic
/// value `alpha * e^x`.
#[test]
fn test_elu_backward_custom_alpha() {
    let alpha: f64 = 2.0;
    let val: f64 = -0.5;

    let input = t_scalar_grad(val);
    let layer = ELU::new(alpha);
    let output = layer.forward(&input).unwrap();
    ferrotorch_core::backward(&output).unwrap();

    let grad = input.grad().unwrap().unwrap();
    let got = grad.item().unwrap();
    let expected = alpha * val.exp();
    assert!(
        (got - expected).abs() < 1e-5,
        "ELU grad at x={}, alpha={}: expected {}, got {}",
        val,
        alpha,
        expected,
        got
    );
}
/// After `backward` through `Mish`, the scalar input must hold a gradient.
#[test]
fn test_mish_backward_produces_grad() {
    let input = t_scalar_grad(1.0);
    let layer = Mish::new();
    let output = layer.forward(&input).unwrap();
    ferrotorch_core::backward(&output).unwrap();

    let maybe_grad = input.grad().unwrap();
    assert!(maybe_grad.is_some(), "Mish backward should produce a gradient");
}
/// The autograd gradient of `Mish` must agree with a central difference of
/// `x * tanh(ln(1 + e^x))` at several points.
#[test]
fn test_mish_backward_matches_numerical() {
    fn mish_ref(v: f64) -> f64 {
        let softplus = (1.0 + v.exp()).ln();
        v * softplus.tanh()
    }

    let points = [-2.0_f64, -1.0, 0.0, 0.5, 1.5, 3.0];
    for val in points.iter().copied() {
        let input = t_scalar_grad(val);
        let layer = Mish::new();
        let output = layer.forward(&input).unwrap();
        ferrotorch_core::backward(&output).unwrap();

        let grad = input.grad().unwrap().unwrap();
        let got = grad.item().unwrap();
        let expected = numerical_grad(mish_ref, val);
        assert!(
            (got - expected).abs() < 1e-4,
            "Mish grad at x={}: expected {}, got {}",
            val,
            expected,
            got
        );
    }
}
/// Backpropagating `sum(Mish(x))` over a vector must give per-element
/// gradients that match a central difference of `x * tanh(ln(1 + e^x))`.
#[test]
fn test_mish_backward_vector() {
    fn mish_ref(v: f64) -> f64 {
        let softplus = (1.0 + v.exp()).ln();
        v * softplus.tanh()
    }

    let points = [-1.0_f64, 0.0, 1.0, 2.0];
    let input = t_grad(&points);
    let layer = Mish::new();
    let output = layer.forward(&input).unwrap();
    let total = ferrotorch_core::grad_fns::reduction::sum(&output).unwrap();
    ferrotorch_core::backward(&total).unwrap();

    let grad = input.grad().unwrap().unwrap();
    let grad_values = grad.data().unwrap();
    for (i, &val) in points.iter().enumerate() {
        let expected = numerical_grad(mish_ref, val);
        let got = grad_values[i];
        assert!(
            (got - expected).abs() < 1e-4,
            "Mish grad[{}] at x={}: expected {}, got {}",
            i,
            val,
            expected,
            got
        );
    }
}
/// The autograd gradient of `ReLU6` must agree with a central difference of
/// `min(max(x, 0), 6)` at points away from the kinks at 0 and 6.
#[test]
fn test_relu6_backward_matches_numerical() {
    fn relu6_ref(v: f64) -> f64 {
        v.max(0.0).min(6.0)
    }

    let points = [-2.0_f64, 0.5, 3.0, 5.5, 8.0];
    for val in points.iter().copied() {
        let input = t_scalar_grad(val);
        let layer = ReLU6::new();
        let output = layer.forward(&input).unwrap();
        ferrotorch_core::backward(&output).unwrap();

        let grad = input.grad().unwrap().unwrap();
        let got = grad.item().unwrap();
        let expected = numerical_grad(relu6_ref, val);
        assert!(
            (got - expected).abs() < 1e-4,
            "ReLU6 grad at x={}: expected {}, got {}",
            val,
            expected,
            got
        );
    }
}
/// The autograd gradient of default `Hardtanh` must agree with a central
/// difference of `min(max(x, -1), 1)` at several points.
#[test]
fn test_hardtanh_backward_matches_numerical() {
    fn hardtanh_ref(v: f64) -> f64 {
        v.max(-1.0).min(1.0)
    }

    let points = [-2.0_f64, -0.5, 0.0, 0.5, 2.0];
    for val in points.iter().copied() {
        let input = t_scalar_grad(val);
        let layer = Hardtanh::default();
        let output = layer.forward(&input).unwrap();
        ferrotorch_core::backward(&output).unwrap();

        let grad = input.grad().unwrap().unwrap();
        let got = grad.item().unwrap();
        let expected = numerical_grad(hardtanh_ref, val);
        assert!(
            (got - expected).abs() < 1e-4,
            "Hardtanh grad at x={}: expected {}, got {}",
            val,
            expected,
            got
        );
    }
}
/// The autograd gradient of `LogSigmoid` must agree with a central difference
/// of `-ln(1 + e^(-x))` at several points.
#[test]
fn test_log_sigmoid_backward_matches_numerical() {
    fn log_sigmoid_ref(v: f64) -> f64 {
        -(1.0 + (-v).exp()).ln()
    }

    let points = [-3.0_f64, -1.0, 0.0, 1.0, 3.0];
    for val in points.iter().copied() {
        let input = t_scalar_grad(val);
        let layer = LogSigmoid::new();
        let output = layer.forward(&input).unwrap();
        ferrotorch_core::backward(&output).unwrap();

        let grad = input.grad().unwrap().unwrap();
        let got = grad.item().unwrap();
        let expected = numerical_grad(log_sigmoid_ref, val);
        assert!(
            (got - expected).abs() < 1e-4,
            "LogSigmoid grad at x={}: expected {}, got {}",
            val,
            expected,
            got
        );
    }
}
/// The autograd gradient of `Tanhshrink` must agree with a central difference
/// of `x - tanh(x)` at several points.
#[test]
fn test_tanhshrink_backward_matches_numerical() {
    fn tanhshrink_ref(v: f64) -> f64 {
        v - v.tanh()
    }

    let points = [-2.0_f64, -0.5, 0.0, 0.5, 2.0];
    for val in points.iter().copied() {
        let input = t_scalar_grad(val);
        let layer = Tanhshrink::new();
        let output = layer.forward(&input).unwrap();
        ferrotorch_core::backward(&output).unwrap();

        let grad = input.grad().unwrap().unwrap();
        let got = grad.item().unwrap();
        let expected = numerical_grad(tanhshrink_ref, val);
        assert!(
            (got - expected).abs() < 1e-4,
            "Tanhshrink grad at x={}: expected {}, got {}",
            val,
            expected,
            got
        );
    }
}
/// The `state_dict` of `ReLU`, viewed as a `Module<f64>`, must be empty.
#[test]
fn test_state_dict_empty() {
    let relu = ReLU::new();
    let state = <ReLU as Module<f64>>::state_dict(&relu);
    assert!(state.is_empty());
}
}