Skip to main content

ferrotorch_nn/
activation.rs

1//! Activation function wrapper modules.
2//!
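//! Each struct is a [`Module`] that applies the corresponding non-linearity in
//! its [`forward`](Module::forward) method. All of them are parameter-free
//! except [`PReLU`], which owns a learnable `alpha` [`Parameter`]. Most are
//! elementwise; `Softmax`, `LogSoftmax`, `Softmax2d` and `GLU` combine values
//! along an axis.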
5//! They carry a `training` flag for API consistency but their behaviour is
6//! identical in train and eval modes.
7//!
8//! ## REQ status (per `.design/ferrotorch-nn/activation.md`)
9//!
10//! | REQ | Status | Evidence |
11//! |---|---|---|
12//! | REQ-1 | SHIPPED | `pub struct ReLU` zero-param module delegating `forward` to `act::relu` mirrors `torch/nn/modules/activation.py:104-152`; consumed by `ferrotorch-vision/src/models/vgg.rs:21` (`use ferrotorch_nn::activation::ReLU;`) and `ferrotorch-optim/src/sgd.rs:818` SGD module-flow harness. |
13//! | REQ-2 | SHIPPED | `pub struct Sigmoid`, `pub struct Tanh` zero-param wrappers mirror `torch/nn/modules/activation.py:337-434`; consumed by `ferrotorch-rl/src/mlp_policy.rs:53` (`use ferrotorch_nn::activation::Tanh;`). |
14//! | REQ-3 | SHIPPED | `pub struct GELU` plus `pub use act::GeluApproximate` re-export covers `None`/`Tanh`/`Sigmoid` modes mirroring `torch/nn/modules/activation.py:777-824`; consumed by `ferrotorch-bert/src/layer.rs:13` and `ferrotorch-whisper/src/encoder.rs:27`. |
15//! | REQ-4 | SHIPPED | `pub struct SiLU` (`x * sigmoid(x)`) mirrors `torch/nn/modules/activation.py:435-484`; consumed by `ferrotorch-diffusion/src/vae.rs:24` ResnetBlock chain. |
16//! | REQ-5 | SHIPPED | `pub struct Softmax`, `pub struct LogSoftmax`, `pub struct Softmin`, `pub struct Softmax2d` (GPU forward via REQ-11) mirror `torch/nn/modules/activation.py:1709-1929`; consumed by `ferrotorch-nn/src/lib.rs:189-193` re-exports and downstream classifier heads. |
17//! | REQ-6 | SHIPPED | `pub struct LeakyReLU`, `pub struct PReLU<T>`, `pub struct ELU`, `pub struct CELU`, `pub struct SELU`, `pub struct RReLU` parameterised activations mirror `torch/nn/modules/activation.py:153-218, 575-735, 874-931, 1575-1656`; consumed via `ferrotorch-nn/src/lib.rs:189-193` re-exports. |
18//! | REQ-7 | SHIPPED | `pub struct Hardtanh`, `pub struct ReLU6`, `pub struct HardSigmoid`, `pub struct HardSwish`, `pub struct Hardshrink`, `pub struct Softshrink`, `pub struct Tanhshrink`, `pub struct Softsign`, `pub struct LogSigmoid`, `pub struct Threshold`, `pub struct Softplus`, `pub struct Mish`, `pub struct GLU` mirror their counterparts at `torch/nn/modules/activation.py:219-336, 364-406, 485-574, 530-574, 680-735, 736-776, 825-873, 958-1056`; consumed by `ferrotorch-vision/src/models/mobilenet.rs:51` (`HardSigmoid, HardSwish, ReLU, ReLU6`). |
19//! | REQ-8 | SHIPPED | `pub struct PReLU<T: Float>` owns `pub alpha: Parameter<T>` and the hand-written `Module<T>` impl returns `("alpha", ..)` via `named_parameters` mirroring `torch/nn/modules/activation.py:1575-1656`; consumed via `ferrotorch-nn/src/lib.rs:189-193` re-export; pinned by `test_prelu_has_parameter`. |
20//! | REQ-9 | SHIPPED | The `impl_activation_module!` declarative macro synthesises `Module<T>::{forward, parameters, parameters_mut, named_parameters, train, eval, is_training}` for every zero-param activation (`PReLU` has a hand-written impl) mirroring `torch/nn/modules/module.py`; consumed by `ferrotorch-vision/src/models/vgg.rs` building `Module<f32>` chains. |
21//! | REQ-10 | SHIPPED | Every `forward` delegates to `act::*` in `ferrotorch_core::grad_fns::activation`, which attaches `ReluBackward`/`SigmoidBackward`/`TanhBackward`/`GeluBackward`/`SiluBackward`/`SoftmaxBackward`/`LogSoftmaxBackward`/`LeakyReluBackward`/`EluBackward`/`MishBackward`/`SoftplusBackward`/`GLUBackward`/`PReluBackward` when grad is enabled mirroring `aten/src/ATen/native/Activation.cpp`; consumed by `ferrotorch-optim/src/sgd.rs:818` end-to-end backward in the SGD harness. |
22//! | REQ-11 | SHIPPED | `Softmax2d::forward` dispatches CUDA input to `GpuBackend::softmax2d_f32` (channel-axis softmax; PTX in `ferrotorch-gpu/src/group_norm.rs`, wired via `CudaBackendImpl::softmax2d_f32`); CUDA-with-no-backend still returns `NotImplementedOnCuda`. Forward-only. Closed #1451; runtime parity pinned by `#[ignore]`'d `softmax2d_forward_gpu_matches_cpu`. |
23
24use ferrotorch_core::grad_fns::activation as act;
25use ferrotorch_core::grad_fns::arithmetic;
26use ferrotorch_core::grad_fns::transcendental;
27use ferrotorch_core::ops::elementwise::unary_map;
28use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, normalize_axis};
29
30use crate::module::Module;
31use crate::parameter::Parameter;
32
33// ---------------------------------------------------------------------------
34// Macro: implements the full `Module` trait for a zero-parameter activation.
35// ---------------------------------------------------------------------------
36
macro_rules! impl_activation_module {
    ($ty:ident) => {
        impl<T: Float> Module<T> for $ty {
            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
                // Type-relative path: inherent associated functions win over
                // trait methods of the same name, so this dispatches to the
                // type's own generic `forward` instead of recursing into
                // this trait method.
                <$ty>::forward(self, input)
            }

            // Activations generated by this macro own no learnable state, so
            // every parameter accessor reports an empty list.
            fn parameters(&self) -> Vec<&Parameter<T>> {
                Vec::new()
            }

            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
                Vec::new()
            }

            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
                Vec::new()
            }

            // The `training` flag is tracked for API parity only; the forward
            // passes of these modules never read it.
            fn train(&mut self) {
                self.training = true;
            }

            fn eval(&mut self) {
                self.training = false;
            }

            fn is_training(&self) -> bool {
                self.training
            }
        }
    };
}
70
71// ===========================================================================
72// ReLU
73// ===========================================================================
74
75/// Applies the rectified linear unit function elementwise:
76///
77/// `ReLU(x) = max(0, x)`
78#[derive(Debug, Clone)]
79pub struct ReLU {
80    training: bool,
81}
82
83impl ReLU {
84    /// Create a new `ReLU` module.
85    pub fn new() -> Self {
86        Self { training: true }
87    }
88
89    /// Forward pass.
90    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
91        act::relu(input)
92    }
93}
94
95impl Default for ReLU {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
101impl_activation_module!(ReLU);
102
103// ===========================================================================
104// Softmax2d
105// ===========================================================================
106
107/// Applies softmax over the channel dimension of 4-D input [N, C, H, W].
108///
109/// `Softmax2d(x)[n, c, h, w] = exp(x[n,c,h,w]) / sum_c'(exp(x[n,c',h,w]))`
110///
111/// Matches PyTorch's `nn.Softmax2d`.
112#[derive(Debug, Clone)]
113pub struct Softmax2d {
114    training: bool,
115}
116
117impl Softmax2d {
118    pub fn new() -> Self {
119        Self { training: true }
120    }
121
122    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
123        if input.ndim() != 4 {
124            return Err(ferrotorch_core::error::FerrotorchError::InvalidArgument {
125                message: format!(
126                    "Softmax2d expects 4-D input [N,C,H,W], got {:?}",
127                    input.shape()
128                ),
129            });
130        }
131
132        let shape = input.shape();
133        let n = shape[0];
134        let c = shape[1];
135        let h = shape[2];
136        let w = shape[3];
137
138        // GPU fast path: native channel-axis softmax kernel (#1451). Softmax2d
139        // is forward-only here (the CPU path returns a non-grad tensor), so the
140        // GPU branch likewise returns a non-grad GPU-resident tensor. Mirrors
141        // `torch.nn.Softmax2d` (softmax over `dim=1`).
142        if input.is_cuda() {
143            if let Some(backend) = ferrotorch_core::gpu_dispatch::gpu_backend() {
144                let handle = backend.softmax2d_f32(input.gpu_handle()?, n, c, h * w)?;
145                return Tensor::from_storage(
146                    ferrotorch_core::storage::TensorStorage::gpu(handle),
147                    shape.to_vec(),
148                    false,
149                );
150            }
151            return Err(
152                ferrotorch_core::error::FerrotorchError::NotImplementedOnCuda { op: "Softmax2d" },
153            );
154        }
155
156        let data = input.data()?;
157        let mut out = vec![<T as num_traits::Zero>::zero(); n * c * h * w];
158
159        // Softmax over channel dim (dim=1) for each (n, h, w) position.
160        for batch in 0..n {
161            for row in 0..h {
162                for col in 0..w {
163                    // Find max for stability.
164                    let mut max_val = T::neg_infinity();
165                    for ch in 0..c {
166                        let idx = batch * c * h * w + ch * h * w + row * w + col;
167                        if data[idx] > max_val {
168                            max_val = data[idx];
169                        }
170                    }
171                    // Compute exp and sum.
172                    let mut sum_exp = <T as num_traits::Zero>::zero();
173                    for ch in 0..c {
174                        let idx = batch * c * h * w + ch * h * w + row * w + col;
175                        let e = (data[idx] - max_val).exp();
176                        out[idx] = e;
177                        sum_exp += e;
178                    }
179                    // Normalize.
180                    for ch in 0..c {
181                        let idx = batch * c * h * w + ch * h * w + row * w + col;
182                        out[idx] = out[idx] / sum_exp;
183                    }
184                }
185            }
186        }
187
188        Tensor::from_storage(
189            ferrotorch_core::storage::TensorStorage::cpu(out),
190            shape.to_vec(),
191            false,
192        )
193    }
194}
195
196impl Default for Softmax2d {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
202impl_activation_module!(Softmax2d);
203
204// ===========================================================================
205// GELU
206// ===========================================================================
207
208pub use act::GeluApproximate;
209
210/// Applies the Gaussian Error Linear Unit activation function.
211///
212/// Three approximation modes are available (see [`GeluApproximate`]):
213///
214/// - **`None`** (default) — exact erf-based, matches PyTorch `approximate="none"`.
215/// - **`Tanh`** — tanh approximation, matches PyTorch `approximate="tanh"`.
216/// - **`Sigmoid`** — fast `x * sigmoid(1.702 * x)`.
217#[derive(Debug, Clone)]
218pub struct GELU {
219    approximate: GeluApproximate,
220    training: bool,
221}
222
223impl GELU {
224    /// Create a new `GELU` module with the default exact (erf) mode.
225    pub fn new() -> Self {
226        Self {
227            approximate: GeluApproximate::default(),
228            training: true,
229        }
230    }
231
232    /// Create a new `GELU` module with the specified approximation mode.
233    pub fn with_approximate(approximate: GeluApproximate) -> Self {
234        Self {
235            approximate,
236            training: true,
237        }
238    }
239
240    /// Forward pass.
241    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
242        act::gelu_with(input, self.approximate)
243    }
244}
245
246impl Default for GELU {
247    fn default() -> Self {
248        Self::new()
249    }
250}
251
252impl_activation_module!(GELU);
253
254// ===========================================================================
255// SiLU (Swish)
256// ===========================================================================
257
258/// Applies the Sigmoid Linear Unit (Swish) function:
259///
260/// `SiLU(x) = x * sigmoid(x)`
261#[derive(Debug, Clone)]
262pub struct SiLU {
263    training: bool,
264}
265
266impl SiLU {
267    /// Create a new `SiLU` module.
268    pub fn new() -> Self {
269        Self { training: true }
270    }
271
272    /// Forward pass.
273    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
274        act::silu(input)
275    }
276}
277
278impl Default for SiLU {
279    fn default() -> Self {
280        Self::new()
281    }
282}
283
284impl_activation_module!(SiLU);
285
286// ===========================================================================
287// Sigmoid
288// ===========================================================================
289
290/// Applies the logistic sigmoid function elementwise:
291///
292/// `Sigmoid(x) = 1 / (1 + exp(-x))`
293#[derive(Debug, Clone)]
294pub struct Sigmoid {
295    training: bool,
296}
297
298impl Sigmoid {
299    /// Create a new `Sigmoid` module.
300    pub fn new() -> Self {
301        Self { training: true }
302    }
303
304    /// Forward pass.
305    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
306        act::sigmoid(input)
307    }
308}
309
310impl Default for Sigmoid {
311    fn default() -> Self {
312        Self::new()
313    }
314}
315
316impl_activation_module!(Sigmoid);
317
318// ===========================================================================
319// Tanh
320// ===========================================================================
321
322/// Applies the hyperbolic tangent function elementwise.
323#[derive(Debug, Clone)]
324pub struct Tanh {
325    training: bool,
326}
327
328impl Tanh {
329    /// Create a new `Tanh` module.
330    pub fn new() -> Self {
331        Self { training: true }
332    }
333
334    /// Forward pass.
335    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
336        act::tanh(input)
337    }
338}
339
340impl Default for Tanh {
341    fn default() -> Self {
342        Self::new()
343    }
344}
345
346impl_activation_module!(Tanh);
347
348// ===========================================================================
349// Softmax
350// ===========================================================================
351
352/// Applies the softmax function along a given dimension.
353///
354/// Currently only the last axis (`dim = -1`) is supported because the
355/// underlying `ferrotorch_core::grad_fns::activation::softmax` operates on
356/// the last axis. Passing any other dimension returns an error.
357#[derive(Debug, Clone)]
358pub struct Softmax {
359    /// The dimension along which to compute softmax.
360    pub dim: isize,
361    training: bool,
362}
363
364impl Softmax {
365    /// Create a new `Softmax` module operating along `dim`.
366    ///
367    /// Defaults to `dim = -1` (last axis), matching PyTorch convention.
368    pub fn new(dim: isize) -> Self {
369        Self {
370            dim,
371            training: true,
372        }
373    }
374
375    /// Forward pass.
376    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
377        let ndim = input.ndim();
378        if ndim == 0 {
379            // Scalar: softmax is always 1.
380            return act::softmax(input);
381        }
382
383        let axis = normalize_axis(self.dim, ndim)?;
384        if axis != ndim - 1 {
385            return Err(FerrotorchError::InvalidArgument {
386                message: format!(
387                    "Softmax currently only supports dim=-1 (last axis), \
388                     but got dim={} (axis={}) for a {}-D tensor",
389                    self.dim, axis, ndim,
390                ),
391            });
392        }
393
394        act::softmax(input)
395    }
396}
397
398impl Default for Softmax {
399    fn default() -> Self {
400        Self::new(-1)
401    }
402}
403
404impl_activation_module!(Softmax);
405
406// ===========================================================================
407// LogSoftmax
408// ===========================================================================
409
410/// Applies log(softmax(x)) along a given dimension.
411///
412/// More numerically stable than computing `log(softmax(x))` separately.
413/// Currently only the last axis (`dim = -1`) is supported.
414#[derive(Debug, Clone)]
415pub struct LogSoftmax {
416    /// The dimension along which to compute log-softmax.
417    pub dim: isize,
418    training: bool,
419}
420
421impl LogSoftmax {
422    /// Create a new `LogSoftmax` module operating along `dim`.
423    pub fn new(dim: isize) -> Self {
424        Self {
425            dim,
426            training: true,
427        }
428    }
429
430    /// Forward pass.
431    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
432        let ndim = input.ndim();
433        if ndim == 0 {
434            return act::log_softmax(input);
435        }
436
437        let axis = normalize_axis(self.dim, ndim)?;
438        if axis != ndim - 1 {
439            return Err(FerrotorchError::InvalidArgument {
440                message: format!(
441                    "LogSoftmax currently only supports dim=-1 (last axis), \
442                     but got dim={} (axis={}) for a {}-D tensor",
443                    self.dim, axis, ndim,
444                ),
445            });
446        }
447
448        act::log_softmax(input)
449    }
450}
451
452impl Default for LogSoftmax {
453    fn default() -> Self {
454        Self::new(-1)
455    }
456}
457
458impl_activation_module!(LogSoftmax);
459
460// ===========================================================================
461// LeakyReLU
462// ===========================================================================
463
464/// Applies the leaky rectified linear unit function:
465///
466/// `LeakyReLU(x) = max(0, x) + negative_slope * min(0, x)`
467///
468/// This is implemented by composing differentiable primitives so that
469/// autograd works automatically:
470///
471/// ```text
472/// forward(x) = (1 - negative_slope) * relu(x) + negative_slope * x
473/// ```
474#[derive(Debug, Clone)]
475pub struct LeakyReLU {
476    /// Slope for negative inputs. Default: 0.01.
477    pub negative_slope: f64,
478    training: bool,
479}
480
481impl LeakyReLU {
482    /// Create a new `LeakyReLU` with the given negative slope.
483    pub fn new(negative_slope: f64) -> Self {
484        Self {
485            negative_slope,
486            training: true,
487        }
488    }
489
490    /// Forward pass.
491    ///
492    /// Computes `(1 - negative_slope) * relu(x) + negative_slope * x`
493    /// using differentiable core operations so gradients propagate correctly.
494    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
495        if (self.negative_slope - 0.0).abs() < f64::EPSILON {
496            // Degenerate case: standard ReLU.
497            return act::relu(input);
498        }
499        if (self.negative_slope - 1.0).abs() < f64::EPSILON {
500            // Degenerate case: identity.
501            return Ok(input.clone());
502        }
503
504        // relu_x = relu(input)
505        let relu_x = act::relu(input)?;
506
507        // scale = (1 - negative_slope)
508        let scale = T::from(1.0 - self.negative_slope).unwrap();
509        let slope = T::from(self.negative_slope).unwrap();
510
511        // scale_tensor = scalar(1 - negative_slope)
512        let scale_tensor = ferrotorch_core::scalar(scale)?;
513        // slope_tensor = scalar(negative_slope)
514        let slope_tensor = ferrotorch_core::scalar(slope)?;
515
516        // result = scale * relu(x) + slope * x
517        let scaled_relu = arithmetic::mul(&relu_x, &scale_tensor)?;
518        let scaled_x = arithmetic::mul(input, &slope_tensor)?;
519        arithmetic::add(&scaled_relu, &scaled_x)
520    }
521}
522
523impl Default for LeakyReLU {
524    fn default() -> Self {
525        Self::new(0.01)
526    }
527}
528
529impl_activation_module!(LeakyReLU);
530
531// ===========================================================================
532// ELU
533// ===========================================================================
534
535/// Applies the Exponential Linear Unit function:
536///
537/// ```text
538/// ELU(x) = x            if x > 0
539///        = alpha * (exp(x) - 1)  if x <= 0
540/// ```
541///
542/// Differentiable: autograd backward is supported via `EluBackward`.
543#[derive(Debug, Clone)]
544pub struct ELU {
545    /// Scale for the negative region. Default: 1.0.
546    pub alpha: f64,
547    training: bool,
548}
549
550impl ELU {
551    /// Create a new `ELU` module with the given alpha.
552    pub fn new(alpha: f64) -> Self {
553        Self {
554            alpha,
555            training: true,
556        }
557    }
558
559    /// Forward pass.
560    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
561        act::elu(input, self.alpha)
562    }
563}
564
565impl Default for ELU {
566    fn default() -> Self {
567        Self::new(1.0)
568    }
569}
570
571impl_activation_module!(ELU);
572
573// ===========================================================================
574// Mish
575// ===========================================================================
576
577/// Applies the Mish activation function:
578///
579/// `Mish(x) = x * tanh(softplus(x))`
580///
581/// where `softplus(x) = ln(1 + exp(x))`.
582///
583/// Differentiable: autograd backward is supported via `MishBackward`.
584#[derive(Debug, Clone)]
585pub struct Mish {
586    training: bool,
587}
588
589impl Mish {
590    /// Create a new `Mish` module.
591    pub fn new() -> Self {
592        Self { training: true }
593    }
594
595    /// Forward pass.
596    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
597        act::mish(input)
598    }
599}
600
601impl Default for Mish {
602    fn default() -> Self {
603        Self::new()
604    }
605}
606
607impl_activation_module!(Mish);
608
609// ===========================================================================
610// PReLU (Parametric ReLU)
611// ===========================================================================
612
613/// Parametric Rectified Linear Unit with a learnable negative slope.
614///
615/// `PReLU(x) = max(0, x) + alpha * min(0, x)`
616///
617/// where `alpha` is a learnable [`Parameter`]. This is equivalent to
618/// `(1 - alpha) * relu(x) + alpha * x` for differentiable composition.
619#[derive(Debug, Clone)]
620pub struct PReLU<T: Float> {
621    /// Learnable negative slope parameter.
622    pub alpha: Parameter<T>,
623    training: bool,
624}
625
626impl<T: Float> PReLU<T> {
627    /// Create a new `PReLU` module with the given initial negative slope.
628    pub fn new(init_alpha: f64) -> FerrotorchResult<Self> {
629        let alpha_val = T::from(init_alpha).unwrap();
630        let alpha_tensor = ferrotorch_core::from_slice(&[alpha_val], &[1])?;
631        Ok(Self {
632            alpha: Parameter::new(alpha_tensor),
633            training: true,
634        })
635    }
636
637    /// Forward pass.
638    ///
639    /// Computes `prelu(x, alpha) = max(0, x) + alpha * min(0, x)` via the
640    /// native fused [`act::prelu`] op (single forward, single backward).
641    pub fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
642        if self.alpha.tensor().is_cuda() {
643            return Err(
644                ferrotorch_core::error::FerrotorchError::NotImplementedOnCuda { op: "PReLU" },
645            );
646        }
647        act::prelu(input, self.alpha.tensor())
648    }
649}
650
651impl<T: Float> Module<T> for PReLU<T> {
652    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
653        self.forward(input)
654    }
655
656    fn parameters(&self) -> Vec<&Parameter<T>> {
657        vec![&self.alpha]
658    }
659
660    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
661        vec![&mut self.alpha]
662    }
663
664    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
665        vec![("alpha".to_string(), &self.alpha)]
666    }
667
668    fn train(&mut self) {
669        self.training = true;
670    }
671
672    fn eval(&mut self) {
673        self.training = false;
674    }
675
676    fn is_training(&self) -> bool {
677        self.training
678    }
679}
680
681// ===========================================================================
682// CELU
683// ===========================================================================
684
685/// Continuously Differentiable Exponential Linear Unit:
686///
687/// ```text
688/// CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
689/// ```
690///
691/// Unlike ELU, CELU is continuously differentiable everywhere.
692#[derive(Debug, Clone)]
693pub struct CELU {
694    /// Scale for the negative region. Default: 1.0.
695    pub alpha: f64,
696    training: bool,
697}
698
699impl CELU {
700    /// Create a new `CELU` module with the given alpha.
701    pub fn new(alpha: f64) -> Self {
702        Self {
703            alpha,
704            training: true,
705        }
706    }
707
708    /// Forward pass.
709    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
710        let zero = <T as num_traits::Zero>::zero();
711        let one = <T as num_traits::One>::one();
712        let alpha = T::from(self.alpha).unwrap();
713
714        unary_map(input, |x| {
715            let pos = if x > zero { x } else { zero };
716            let neg = if x < zero {
717                alpha * ((x / alpha).exp() - one)
718            } else {
719                zero
720            };
721            pos + neg
722        })
723    }
724}
725
726impl Default for CELU {
727    fn default() -> Self {
728        Self::new(1.0)
729    }
730}
731
732impl_activation_module!(CELU);
733
734// ===========================================================================
735// SELU
736// ===========================================================================
737
738/// Scaled Exponential Linear Unit with fixed constants:
739///
740/// ```text
741/// SELU(x) = lambda * (x                    if x > 0)
742///         = lambda * (alpha * (exp(x) - 1)  if x <= 0)
743/// ```
744///
745/// where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
746/// These constants enable self-normalizing behaviour when used with
747/// properly initialized weights (LeCun normal).
748#[derive(Debug, Clone)]
749pub struct SELU {
750    training: bool,
751}
752
753/// SELU alpha constant.
754const SELU_ALPHA: f64 = 1.6732632423543772;
755/// SELU lambda (scale) constant.
756const SELU_LAMBDA: f64 = 1.0507009873554805;
757
758impl SELU {
759    /// Create a new `SELU` module.
760    pub fn new() -> Self {
761        Self { training: true }
762    }
763
764    /// Forward pass.
765    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
766        let zero = <T as num_traits::Zero>::zero();
767        let one = <T as num_traits::One>::one();
768        let alpha = T::from(SELU_ALPHA).unwrap();
769        let lambda = T::from(SELU_LAMBDA).unwrap();
770
771        unary_map(input, |x| {
772            if x > zero {
773                lambda * x
774            } else {
775                lambda * alpha * (x.exp() - one)
776            }
777        })
778    }
779}
780
781impl Default for SELU {
782    fn default() -> Self {
783        Self::new()
784    }
785}
786
787impl_activation_module!(SELU);
788
789// ===========================================================================
790// HardSigmoid
791// ===========================================================================
792
793/// Hard Sigmoid activation:
794///
795/// `HardSigmoid(x) = clamp((x + 3) / 6, 0, 1)`
796///
797/// A piecewise-linear approximation of the sigmoid function.
798#[derive(Debug, Clone)]
799pub struct HardSigmoid {
800    training: bool,
801}
802
803impl HardSigmoid {
804    /// Create a new `HardSigmoid` module.
805    pub fn new() -> Self {
806        Self { training: true }
807    }
808
809    /// Forward pass.
810    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
811        let zero = <T as num_traits::Zero>::zero();
812        let one = <T as num_traits::One>::one();
813        let three = T::from(3.0).unwrap();
814        let six = T::from(6.0).unwrap();
815
816        unary_map(input, |x| {
817            let v = (x + three) / six;
818            if v < zero {
819                zero
820            } else if v > one {
821                one
822            } else {
823                v
824            }
825        })
826    }
827}
828
829impl Default for HardSigmoid {
830    fn default() -> Self {
831        Self::new()
832    }
833}
834
835impl_activation_module!(HardSigmoid);
836
837// ===========================================================================
838// HardSwish
839// ===========================================================================
840
841/// Hard Swish activation:
842///
843/// `HardSwish(x) = x * HardSigmoid(x) = x * clamp((x + 3) / 6, 0, 1)`
844///
845/// A piecewise-linear approximation of SiLU (Swish).
846#[derive(Debug, Clone)]
847pub struct HardSwish {
848    training: bool,
849}
850
851impl HardSwish {
852    /// Create a new `HardSwish` module.
853    pub fn new() -> Self {
854        Self { training: true }
855    }
856
857    /// Forward pass.
858    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
859        let zero = <T as num_traits::Zero>::zero();
860        let one = <T as num_traits::One>::one();
861        let three = T::from(3.0).unwrap();
862        let six = T::from(6.0).unwrap();
863
864        unary_map(input, |x| {
865            let hard_sig = {
866                let v = (x + three) / six;
867                if v < zero {
868                    zero
869                } else if v > one {
870                    one
871                } else {
872                    v
873                }
874            };
875            x * hard_sig
876        })
877    }
878}
879
880impl Default for HardSwish {
881    fn default() -> Self {
882        Self::new()
883    }
884}
885
886impl_activation_module!(HardSwish);
887
888// ===========================================================================
889// Softplus
890// ===========================================================================
891
892/// Softplus activation:
893///
894/// `Softplus(x) = log(1 + exp(beta * x)) / beta`
895///
896/// A smooth approximation of ReLU. As `beta` increases, Softplus converges
897/// to ReLU.
898#[derive(Debug, Clone)]
899pub struct Softplus {
900    /// Sharpness parameter. Default: 1.0.
901    pub beta: f64,
902    /// Threshold above which the function reverts to a linear function
903    /// for numerical stability. Default: 20.0.
904    pub threshold: f64,
905    training: bool,
906}
907
908impl Softplus {
909    /// Create a new `Softplus` module with the given beta.
910    pub fn new(beta: f64) -> Self {
911        Self {
912            beta,
913            threshold: 20.0,
914            training: true,
915        }
916    }
917
918    /// Forward pass.
919    ///
920    /// Differentiable: autograd backward is supported via `SoftplusBackward`.
921    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
922        act::softplus(input, self.beta, self.threshold)
923    }
924}
925
926impl Default for Softplus {
927    fn default() -> Self {
928        Self::new(1.0)
929    }
930}
931
932impl_activation_module!(Softplus);
933
934// ===========================================================================
935// GLU (Gated Linear Unit)
936// ===========================================================================
937
938/// Gated Linear Unit:
939///
940/// `GLU(x) = a * sigmoid(b)`
941///
942/// where `a` and `b` are the two halves of the input split along the last
943/// dimension. The input's last dimension must be even.
944///
945/// Reference: *Language Modeling with Gated Convolutional Networks* (Dauphin et al., 2017).
946#[derive(Debug, Clone)]
947pub struct GLU {
948    training: bool,
949}
950
951impl GLU {
952    /// Create a new `GLU` module.
953    pub fn new() -> Self {
954        Self { training: true }
955    }
956
957    /// Forward pass.
958    ///
959    /// Splits the input along the last dimension into two equal halves,
960    /// then computes `first_half * sigmoid(second_half)`.
961    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
962        let shape = input.shape();
963        let ndim = shape.len();
964        if ndim == 0 {
965            return Err(FerrotorchError::InvalidArgument {
966                message: "GLU requires at least 1D input".to_string(),
967            });
968        }
969
970        let last_dim = shape[ndim - 1];
971        if last_dim % 2 != 0 {
972            return Err(FerrotorchError::InvalidArgument {
973                message: format!(
974                    "GLU requires the last dimension to be even, got {}",
975                    last_dim
976                ),
977            });
978        }
979
980        let half = last_dim / 2;
981        let device = input.device();
982        let data = input.data_vec()?;
983
984        // Compute the stride of the last dimension (number of elements per
985        // "row" in the last dimension).
986        let outer_size: usize = shape[..ndim - 1].iter().product();
987        let outer_size = if outer_size == 0 { 1 } else { outer_size };
988
989        let one = <T as num_traits::One>::one();
990
991        let mut result = Vec::with_capacity(outer_size * half);
992        for i in 0..outer_size {
993            let base = i * last_dim;
994            for j in 0..half {
995                let a = data[base + j];
996                let b = data[base + half + j];
997                let sig_b = one / (one + (-b).exp());
998                result.push(a * sig_b);
999            }
1000        }
1001
1002        let mut out_shape = shape.to_vec();
1003        out_shape[ndim - 1] = half;
1004
1005        let out = Tensor::from_storage(
1006            ferrotorch_core::TensorStorage::cpu(result),
1007            out_shape,
1008            false,
1009        )?;
1010        if device.is_cuda() {
1011            out.to(device)
1012        } else {
1013            Ok(out)
1014        }
1015    }
1016}
1017
1018impl Default for GLU {
1019    fn default() -> Self {
1020        Self::new()
1021    }
1022}
1023
1024impl_activation_module!(GLU);
1025
1026// ===========================================================================
1027// ReLU6
1028// ===========================================================================
1029
1030/// Applies `ReLU6(x) = min(max(0, x), 6)` elementwise.
1031///
1032/// A ReLU clamped to `[0, 6]`, commonly used in MobileNet architectures.
1033///
1034/// Differentiable: uses [`transcendental::clamp`] which tracks gradients.
1035#[derive(Debug, Clone)]
1036pub struct ReLU6 {
1037    training: bool,
1038}
1039
1040impl ReLU6 {
1041    /// Create a new `ReLU6` module.
1042    pub fn new() -> Self {
1043        Self { training: true }
1044    }
1045
1046    /// Forward pass.
1047    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1048        let zero = <T as num_traits::Zero>::zero();
1049        let six = T::from(6.0).unwrap();
1050        transcendental::clamp(input, zero, six)
1051    }
1052}
1053
1054impl Default for ReLU6 {
1055    fn default() -> Self {
1056        Self::new()
1057    }
1058}
1059
1060impl_activation_module!(ReLU6);
1061
1062// ===========================================================================
1063// Hardtanh
1064// ===========================================================================
1065
1066/// Applies the hard tanh function elementwise:
1067///
1068/// ```text
1069/// Hardtanh(x) = min_val  if x < min_val
1070///             = max_val  if x > max_val
1071///             = x        otherwise
1072/// ```
1073///
1074/// Differentiable: uses [`transcendental::clamp`] which tracks gradients.
1075#[derive(Debug, Clone)]
1076pub struct Hardtanh {
1077    /// Minimum value. Default: -1.0.
1078    pub min_val: f64,
1079    /// Maximum value. Default: 1.0.
1080    pub max_val: f64,
1081    training: bool,
1082}
1083
1084impl Hardtanh {
1085    /// Create a new `Hardtanh` module with the given min and max values.
1086    pub fn new(min_val: f64, max_val: f64) -> Self {
1087        Self {
1088            min_val,
1089            max_val,
1090            training: true,
1091        }
1092    }
1093
1094    /// Forward pass.
1095    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1096        let min = T::from(self.min_val).unwrap();
1097        let max = T::from(self.max_val).unwrap();
1098        transcendental::clamp(input, min, max)
1099    }
1100}
1101
1102impl Default for Hardtanh {
1103    fn default() -> Self {
1104        Self::new(-1.0, 1.0)
1105    }
1106}
1107
1108impl_activation_module!(Hardtanh);
1109
1110// ===========================================================================
1111// LogSigmoid
1112// ===========================================================================
1113
1114/// Applies `LogSigmoid(x) = log(sigmoid(x))` elementwise.
1115///
1116/// Numerically stable: implemented as `-softplus(-x)` to avoid overflow.
1117///
1118/// Differentiable: composes differentiable primitives (softplus, neg).
1119#[derive(Debug, Clone)]
1120pub struct LogSigmoid {
1121    training: bool,
1122}
1123
1124impl LogSigmoid {
1125    /// Create a new `LogSigmoid` module.
1126    pub fn new() -> Self {
1127        Self { training: true }
1128    }
1129
1130    /// Forward pass.
1131    ///
1132    /// Uses the identity `log(sigmoid(x)) = -softplus(-x)` for numerical
1133    /// stability (avoids computing `exp(x)` for large positive `x`).
1134    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1135        // log(sigmoid(x)) = log(1/(1+exp(-x))) = -log(1+exp(-x)) = -softplus(-x)
1136        let neg_input = arithmetic::neg(input)?;
1137        let sp = act::softplus(&neg_input, 1.0, 20.0)?;
1138        arithmetic::neg(&sp)
1139    }
1140}
1141
1142impl Default for LogSigmoid {
1143    fn default() -> Self {
1144        Self::new()
1145    }
1146}
1147
1148impl_activation_module!(LogSigmoid);
1149
1150// ===========================================================================
1151// Softmin
1152// ===========================================================================
1153
1154/// Applies `Softmin(x) = Softmax(-x)` along a given dimension.
1155///
1156/// Reverses the ordering: the smallest input gets the largest probability.
1157/// Currently only the last axis (`dim = -1`) is supported.
1158#[derive(Debug, Clone)]
1159pub struct Softmin {
1160    /// The dimension along which to compute softmin.
1161    pub dim: isize,
1162    training: bool,
1163}
1164
1165impl Softmin {
1166    /// Create a new `Softmin` module operating along `dim`.
1167    pub fn new(dim: isize) -> Self {
1168        Self {
1169            dim,
1170            training: true,
1171        }
1172    }
1173
1174    /// Forward pass.
1175    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1176        let ndim = input.ndim();
1177        if ndim == 0 {
1178            let neg_input = arithmetic::neg(input)?;
1179            return act::softmax(&neg_input);
1180        }
1181
1182        let axis = normalize_axis(self.dim, ndim)?;
1183        if axis != ndim - 1 {
1184            return Err(FerrotorchError::InvalidArgument {
1185                message: format!(
1186                    "Softmin currently only supports dim=-1 (last axis), \
1187                     but got dim={} (axis={}) for a {}-D tensor",
1188                    self.dim, axis, ndim,
1189                ),
1190            });
1191        }
1192
1193        let neg_input = arithmetic::neg(input)?;
1194        act::softmax(&neg_input)
1195    }
1196}
1197
1198impl Default for Softmin {
1199    fn default() -> Self {
1200        Self::new(-1)
1201    }
1202}
1203
1204impl_activation_module!(Softmin);
1205
1206// ===========================================================================
1207// Threshold
1208// ===========================================================================
1209
1210/// Applies the threshold function:
1211///
1212/// ```text
1213/// Threshold(x) = x      if x > threshold
1214///              = value   otherwise
1215/// ```
1216///
1217/// Matches PyTorch `nn.Threshold(threshold, value)`.
1218#[derive(Debug, Clone)]
1219pub struct Threshold {
1220    /// Threshold value.
1221    pub threshold: f64,
1222    /// Replacement value for inputs at or below the threshold.
1223    pub value: f64,
1224    training: bool,
1225}
1226
1227impl Threshold {
1228    /// Create a new `Threshold` module.
1229    pub fn new(threshold: f64, value: f64) -> Self {
1230        Self {
1231            threshold,
1232            value,
1233            training: true,
1234        }
1235    }
1236
1237    /// Forward pass.
1238    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1239        let thresh = T::from(self.threshold).unwrap();
1240        let val = T::from(self.value).unwrap();
1241        unary_map(input, |x| if x > thresh { x } else { val })
1242    }
1243}
1244
1245impl_activation_module!(Threshold);
1246
1247// ===========================================================================
1248// Softshrink
1249// ===========================================================================
1250
1251/// Applies the soft shrinkage function elementwise:
1252///
1253/// ```text
1254/// Softshrink(x) = x - lambda  if x > lambda
1255///               = x + lambda  if x < -lambda
1256///               = 0           otherwise
1257/// ```
1258///
1259/// Default `lambda = 0.5`.
1260#[derive(Debug, Clone)]
1261pub struct Softshrink {
1262    /// Shrinkage threshold. Default: 0.5.
1263    pub lambda: f64,
1264    training: bool,
1265}
1266
1267impl Softshrink {
1268    /// Create a new `Softshrink` module with the given lambda.
1269    pub fn new(lambda: f64) -> Self {
1270        Self {
1271            lambda,
1272            training: true,
1273        }
1274    }
1275
1276    /// Forward pass.
1277    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1278        let lam = T::from(self.lambda).unwrap();
1279        let neg_lam = T::from(-self.lambda).unwrap();
1280        let zero = <T as num_traits::Zero>::zero();
1281        unary_map(input, |x| {
1282            if x > lam {
1283                x - lam
1284            } else if x < neg_lam {
1285                x + lam
1286            } else {
1287                zero
1288            }
1289        })
1290    }
1291}
1292
1293impl Default for Softshrink {
1294    fn default() -> Self {
1295        Self::new(0.5)
1296    }
1297}
1298
1299impl_activation_module!(Softshrink);
1300
1301// ===========================================================================
1302// Hardshrink
1303// ===========================================================================
1304
1305/// Applies the hard shrinkage function elementwise:
1306///
1307/// ```text
1308/// Hardshrink(x) = x  if x > lambda  or  x < -lambda
1309///               = 0  otherwise
1310/// ```
1311///
1312/// Default `lambda = 0.5`.
1313#[derive(Debug, Clone)]
1314pub struct Hardshrink {
1315    /// Shrinkage threshold. Default: 0.5.
1316    pub lambda: f64,
1317    training: bool,
1318}
1319
1320impl Hardshrink {
1321    /// Create a new `Hardshrink` module with the given lambda.
1322    pub fn new(lambda: f64) -> Self {
1323        Self {
1324            lambda,
1325            training: true,
1326        }
1327    }
1328
1329    /// Forward pass.
1330    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1331        let lam = T::from(self.lambda).unwrap();
1332        let neg_lam = T::from(-self.lambda).unwrap();
1333        let zero = <T as num_traits::Zero>::zero();
1334        unary_map(input, |x| if x > lam || x < neg_lam { x } else { zero })
1335    }
1336}
1337
1338impl Default for Hardshrink {
1339    fn default() -> Self {
1340        Self::new(0.5)
1341    }
1342}
1343
1344impl_activation_module!(Hardshrink);
1345
1346// ===========================================================================
1347// Tanhshrink
1348// ===========================================================================
1349
1350/// Applies `Tanhshrink(x) = x - tanh(x)` elementwise.
1351///
1352/// Differentiable: composes differentiable primitives (tanh, sub).
1353#[derive(Debug, Clone)]
1354pub struct Tanhshrink {
1355    training: bool,
1356}
1357
1358impl Tanhshrink {
1359    /// Create a new `Tanhshrink` module.
1360    pub fn new() -> Self {
1361        Self { training: true }
1362    }
1363
1364    /// Forward pass.
1365    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1366        let tanh_x = act::tanh(input)?;
1367        arithmetic::sub(input, &tanh_x)
1368    }
1369}
1370
1371impl Default for Tanhshrink {
1372    fn default() -> Self {
1373        Self::new()
1374    }
1375}
1376
1377impl_activation_module!(Tanhshrink);
1378
1379// ===========================================================================
1380// Softsign
1381// ===========================================================================
1382
1383/// Applies `Softsign(x) = x / (1 + |x|)` elementwise.
1384///
1385/// A smooth, bounded activation similar to tanh but with lighter tails.
1386#[derive(Debug, Clone)]
1387pub struct Softsign {
1388    training: bool,
1389}
1390
1391impl Softsign {
1392    /// Create a new `Softsign` module.
1393    pub fn new() -> Self {
1394        Self { training: true }
1395    }
1396
1397    /// Forward pass.
1398    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1399        let one = <T as num_traits::One>::one();
1400        unary_map(input, |x| x / (one + x.abs()))
1401    }
1402}
1403
1404impl Default for Softsign {
1405    fn default() -> Self {
1406        Self::new()
1407    }
1408}
1409
1410impl_activation_module!(Softsign);
1411
1412// ===========================================================================
1413// RReLU (Randomized Leaky ReLU)
1414// ===========================================================================
1415
/// Applies the Randomized Leaky ReLU function:
///
/// ```text
/// RReLU(x) = x                                   if x >= 0
///          = a * x  (a ~ Uniform[lower, upper])  if x < 0   (training)
///          = ((lower + upper) / 2) * x           if x < 0   (eval)
/// ```
///
/// In training mode, each negative element gets an independent random slope
/// drawn from `Uniform(lower, upper)`. In eval mode, the deterministic mean
/// slope `(lower + upper) / 2` is used.
///
/// The training-mode random stream is seeded internally on every forward
/// call, and the caller cannot set the seed. Stochastic outputs are
/// therefore not reproducible between calls.
///
/// Default: `lower = 1/8`, `upper = 1/3`, matching PyTorch.
#[derive(Debug, Clone)]
pub struct RReLU {
    /// Lower bound for the random slope. Default: 1/8.
    pub lower: f64,
    /// Upper bound for the random slope. Default: 1/3.
    pub upper: f64,
    // Selects the stochastic (training) or mean-slope (eval) path in `forward`.
    training: bool,
}
1437
/// Produce a non-zero xorshift64 seed for one RReLU forward pass.
///
/// Hashes the wall-clock time, the current thread id and a process-wide call
/// counter. Without the counter, two calls on the same thread within one
/// clock tick would hash identical inputs. They would then get the same seed
/// and draw identical slope sequences. The counter makes distinct seeds per
/// call overwhelmingly likely.
///
/// Never returns 0, because 0 is a fixed point of xorshift.
fn rrelu_xorshift_seed() -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::SystemTime;

    // Only the uniqueness of each fetched value matters (no other memory is
    // synchronised through it), so Relaxed ordering is sufficient.
    static CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

    let mut hasher = DefaultHasher::new();
    SystemTime::now().hash(&mut hasher);
    std::thread::current().id().hash(&mut hasher);
    CALL_COUNTER.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);
    let state = hasher.finish();
    if state == 0 {
        0xdeadbeefcafe
    } else {
        state
    }
}
1453
/// Advance the xorshift64 `state` in place and return a uniform sample in `[0, 1)`.
///
/// The top 53 bits of the new state are scaled by `2^-53`. Every result is
/// therefore exactly representable as an `f64` and strictly less than 1.
/// Dividing the full 64-bit state by `u64::MAX as f64` (which rounds to
/// `2^64`) could instead round up to exactly `1.0`.
#[inline]
fn rrelu_xorshift_next(state: &mut u64) -> f64 {
    // 2^-53: maps a 53-bit integer onto [0, 1) without rounding.
    const SCALE: f64 = 1.0 / (1u64 << 53) as f64;

    *state ^= *state << 13;
    *state ^= *state >> 7;
    *state ^= *state << 17;
    (*state >> 11) as f64 * SCALE
}
1462
1463impl RReLU {
1464    /// Create a new `RReLU` module with the given lower and upper bounds.
1465    pub fn new(lower: f64, upper: f64) -> Self {
1466        Self {
1467            lower,
1468            upper,
1469            training: true,
1470        }
1471    }
1472
1473    /// Forward pass.
1474    pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1475        let zero = <T as num_traits::Zero>::zero();
1476
1477        if self.training {
1478            // Stochastic: per-element random slope in [lower, upper].
1479            // Use Cell for interior mutability since unary_map requires Fn.
1480            let rng_state = std::cell::Cell::new(rrelu_xorshift_seed());
1481            let lower = self.lower;
1482            let upper = self.upper;
1483            let range = upper - lower;
1484
1485            unary_map(input, |x| {
1486                if x >= zero {
1487                    x
1488                } else {
1489                    let mut st = rng_state.get();
1490                    let u = rrelu_xorshift_next(&mut st);
1491                    rng_state.set(st);
1492                    let slope = T::from(lower + u * range).unwrap();
1493                    slope * x
1494                }
1495            })
1496        } else {
1497            // Deterministic: mean slope.
1498            let mean_slope = T::from((self.lower + self.upper) / 2.0).unwrap();
1499            unary_map(input, |x| if x >= zero { x } else { mean_slope * x })
1500        }
1501    }
1502}
1503
1504impl Default for RReLU {
1505    fn default() -> Self {
1506        Self::new(1.0 / 8.0, 1.0 / 3.0)
1507    }
1508}
1509
1510impl_activation_module!(RReLU);
1511
1512// ===========================================================================
1513// Tests
1514// ===========================================================================
1515
1516#[cfg(test)]
1517mod tests {
1518    use super::*;
1519    use ferrotorch_core::TensorStorage;
1520
1521    /// Helper: 1-D tensor from a slice (no grad).
1522    fn t(data: &[f64]) -> Tensor<f64> {
1523        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![data.len()], false).unwrap()
1524    }
1525
1526    /// Helper: 2-D tensor (no grad).
1527    fn t2d(data: &[f64], rows: usize, cols: usize) -> Tensor<f64> {
1528        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![rows, cols], false).unwrap()
1529    }
1530
1531    // -----------------------------------------------------------------------
1532    // Module trait compliance helpers
1533    // -----------------------------------------------------------------------
1534
1535    /// Verify that a module has zero parameters and responds to train/eval.
1536    fn assert_zero_param_module<M, T: Float>(module: &mut M)
1537    where
1538        M: Module<T>,
1539    {
1540        assert!(module.parameters().is_empty(), "should have no parameters");
1541        assert!(
1542            module.parameters_mut().is_empty(),
1543            "should have no mutable parameters"
1544        );
1545        assert!(
1546            module.named_parameters().is_empty(),
1547            "should have no named parameters"
1548        );
1549        assert!(module.is_training(), "default should be training mode");
1550        module.eval();
1551        assert!(!module.is_training(), "eval() should set training=false");
1552        module.train();
1553        assert!(module.is_training(), "train() should set training=true");
1554    }
1555
1556    // -----------------------------------------------------------------------
1557    // ReLU
1558    // -----------------------------------------------------------------------
1559
1560    #[test]
1561    fn test_relu_forward() {
1562        let m = ReLU::new();
1563        let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
1564        let y = m.forward(&x).unwrap();
1565        let d = y.data().unwrap();
1566        assert!((d[0] - 0.0).abs() < 1e-7);
1567        assert!((d[1] - 0.0).abs() < 1e-7);
1568        assert!((d[2] - 0.0).abs() < 1e-7);
1569        assert!((d[3] - 1.0).abs() < 1e-7);
1570        assert!((d[4] - 2.0).abs() < 1e-7);
1571    }
1572
1573    #[test]
1574    fn test_relu_module_trait() {
1575        let mut m = ReLU::new();
1576        assert_zero_param_module::<ReLU, f64>(&mut m);
1577    }
1578
1579    // -----------------------------------------------------------------------
1580    // GELU
1581    // -----------------------------------------------------------------------
1582
1583    #[test]
1584    fn test_gelu_forward() {
1585        let m = GELU::new();
1586        // gelu(0) = 0
1587        let x = t(&[0.0]);
1588        let y = m.forward(&x).unwrap();
1589        assert!(y.data().unwrap()[0].abs() < 1e-7);
1590
1591        // gelu(x) > 0 for x > 0
1592        let x = t(&[1.0, 2.0]);
1593        let y = m.forward(&x).unwrap();
1594        let d = y.data().unwrap();
1595        assert!(d[0] > 0.0);
1596        assert!(d[1] > 0.0);
1597
1598        // gelu(x) is close to x for large positive x
1599        let x = t(&[10.0]);
1600        let y = m.forward(&x).unwrap();
1601        assert!((y.data().unwrap()[0] - 10.0).abs() < 0.01);
1602    }
1603
1604    #[test]
1605    fn test_gelu_module_trait() {
1606        let mut m = GELU::new();
1607        assert_zero_param_module::<GELU, f64>(&mut m);
1608    }
1609
1610    // -----------------------------------------------------------------------
1611    // SiLU
1612    // -----------------------------------------------------------------------
1613
1614    #[test]
1615    fn test_silu_forward() {
1616        let m = SiLU::new();
1617        // silu(0) = 0 * sigmoid(0) = 0
1618        let x = t(&[0.0]);
1619        let y = m.forward(&x).unwrap();
1620        assert!(y.data().unwrap()[0].abs() < 1e-7);
1621
1622        // silu(x) = x * sigmoid(x); for large x, sigmoid(x) -> 1 so silu(x) -> x
1623        let x = t(&[10.0]);
1624        let y = m.forward(&x).unwrap();
1625        assert!((y.data().unwrap()[0] - 10.0).abs() < 0.01);
1626    }
1627
1628    #[test]
1629    fn test_silu_module_trait() {
1630        let mut m = SiLU::new();
1631        assert_zero_param_module::<SiLU, f64>(&mut m);
1632    }
1633
1634    // -----------------------------------------------------------------------
1635    // Sigmoid
1636    // -----------------------------------------------------------------------
1637
1638    #[test]
1639    fn test_sigmoid_forward() {
1640        let m = Sigmoid::new();
1641        let x = t(&[0.0]);
1642        let y = m.forward(&x).unwrap();
1643        assert!((y.data().unwrap()[0] - 0.5).abs() < 1e-7);
1644
1645        let x = t(&[-100.0, 100.0]);
1646        let y = m.forward(&x).unwrap();
1647        let d = y.data().unwrap();
1648        assert!(d[0] < 1e-10, "sigmoid(-100) should be ~0");
1649        assert!((d[1] - 1.0).abs() < 1e-10, "sigmoid(100) should be ~1");
1650    }
1651
1652    #[test]
1653    fn test_sigmoid_module_trait() {
1654        let mut m = Sigmoid::new();
1655        assert_zero_param_module::<Sigmoid, f64>(&mut m);
1656    }
1657
1658    // -----------------------------------------------------------------------
1659    // Tanh
1660    // -----------------------------------------------------------------------
1661
1662    #[test]
1663    fn test_tanh_forward() {
1664        let m = Tanh::new();
1665        let x = t(&[0.0]);
1666        let y = m.forward(&x).unwrap();
1667        assert!(y.data().unwrap()[0].abs() < 1e-7);
1668
1669        let x = t(&[-100.0, 100.0]);
1670        let y = m.forward(&x).unwrap();
1671        let d = y.data().unwrap();
1672        assert!((d[0] + 1.0).abs() < 1e-10, "tanh(-100) should be ~-1");
1673        assert!((d[1] - 1.0).abs() < 1e-10, "tanh(100) should be ~1");
1674    }
1675
1676    #[test]
1677    fn test_tanh_module_trait() {
1678        let mut m = Tanh::new();
1679        assert_zero_param_module::<Tanh, f64>(&mut m);
1680    }
1681
1682    // -----------------------------------------------------------------------
1683    // Softmax
1684    // -----------------------------------------------------------------------
1685
1686    #[test]
1687    fn test_softmax_forward_1d() {
1688        let m = Softmax::new(-1);
1689        let x = t(&[1.0, 2.0, 3.0]);
1690        let y = m.forward(&x).unwrap();
1691        let d = y.data().unwrap();
1692
1693        // Sum should be 1.
1694        let total: f64 = d.iter().sum();
1695        assert!((total - 1.0).abs() < 1e-7);
1696
1697        // Monotonicity.
1698        assert!(d[0] < d[1]);
1699        assert!(d[1] < d[2]);
1700    }
1701
1702    #[test]
1703    fn test_softmax_forward_2d() {
1704        let m = Softmax::new(-1);
1705        // [[1, 2], [3, 4]]
1706        let x = t2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
1707        let y = m.forward(&x).unwrap();
1708        let d = y.data().unwrap();
1709
1710        // Each row should sum to 1.
1711        let row0_sum = d[0] + d[1];
1712        let row1_sum = d[2] + d[3];
1713        assert!((row0_sum - 1.0).abs() < 1e-7);
1714        assert!((row1_sum - 1.0).abs() < 1e-7);
1715    }
1716
1717    #[test]
1718    fn test_softmax_wrong_dim() {
1719        let m = Softmax::new(0);
1720        let x = t2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
1721        // dim=0 is not the last axis for a 2-D tensor, should error.
1722        assert!(m.forward(&x).is_err());
1723    }
1724
1725    #[test]
1726    fn test_softmax_module_trait() {
1727        let mut m = Softmax::new(-1);
1728        assert_zero_param_module::<Softmax, f64>(&mut m);
1729    }
1730
1731    // -----------------------------------------------------------------------
1732    // LogSoftmax
1733    // -----------------------------------------------------------------------
1734
1735    #[test]
1736    fn test_log_softmax_forward_1d() {
1737        let m = LogSoftmax::new(-1);
1738        let x = t(&[1.0, 2.0, 3.0]);
1739        let y = m.forward(&x).unwrap();
1740        let d = y.data().unwrap();
1741
1742        // exp(log_softmax) should sum to 1.
1743        let total: f64 = d.iter().map(|&v| v.exp()).sum();
1744        assert!((total - 1.0).abs() < 1e-7, "exp(log_softmax) sum = {total}");
1745
1746        // All log-probabilities should be negative.
1747        assert!(d.iter().all(|&v| v <= 0.0));
1748    }
1749
1750    #[test]
1751    fn test_log_softmax_module_trait() {
1752        let mut m = LogSoftmax::new(-1);
1753        assert_zero_param_module::<LogSoftmax, f64>(&mut m);
1754    }
1755
1756    // -----------------------------------------------------------------------
1757    // LeakyReLU
1758    // -----------------------------------------------------------------------
1759
1760    #[test]
1761    fn test_leaky_relu_forward() {
1762        let m = LeakyReLU::new(0.01);
1763        let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
1764        let y = m.forward(&x).unwrap();
1765        let d = y.data().unwrap();
1766
1767        assert!((d[0] - (-0.02)).abs() < 1e-7, "LeakyReLU(-2) = {}", d[0]);
1768        assert!((d[1] - (-0.01)).abs() < 1e-7, "LeakyReLU(-1) = {}", d[1]);
1769        assert!((d[2] - 0.0).abs() < 1e-7, "LeakyReLU(0) = {}", d[2]);
1770        assert!((d[3] - 1.0).abs() < 1e-7, "LeakyReLU(1) = {}", d[3]);
1771        assert!((d[4] - 2.0).abs() < 1e-7, "LeakyReLU(2) = {}", d[4]);
1772    }
1773
1774    #[test]
1775    fn test_leaky_relu_large_slope() {
1776        let m = LeakyReLU::new(0.2);
1777        let x = t(&[-5.0, 3.0]);
1778        let y = m.forward(&x).unwrap();
1779        let d = y.data().unwrap();
1780
1781        assert!(
1782            (d[0] - (-1.0)).abs() < 1e-7,
1783            "LeakyReLU(-5, slope=0.2) = {}",
1784            d[0]
1785        );
1786        assert!(
1787            (d[1] - 3.0).abs() < 1e-7,
1788            "LeakyReLU(3, slope=0.2) = {}",
1789            d[1]
1790        );
1791    }
1792
1793    #[test]
1794    fn test_leaky_relu_zero_slope_is_relu() {
1795        let m = LeakyReLU::new(0.0);
1796        let x = t(&[-2.0, 0.0, 3.0]);
1797        let y = m.forward(&x).unwrap();
1798        let d = y.data().unwrap();
1799
1800        assert!((d[0] - 0.0).abs() < 1e-7);
1801        assert!((d[1] - 0.0).abs() < 1e-7);
1802        assert!((d[2] - 3.0).abs() < 1e-7);
1803    }
1804
1805    #[test]
1806    fn test_leaky_relu_module_trait() {
1807        let mut m = LeakyReLU::new(0.01);
1808        assert_zero_param_module::<LeakyReLU, f64>(&mut m);
1809    }
1810
1811    // -----------------------------------------------------------------------
1812    // ELU
1813    // -----------------------------------------------------------------------
1814
1815    #[test]
1816    fn test_elu_forward() {
1817        let m = ELU::new(1.0);
1818        let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
1819        let y = m.forward(&x).unwrap();
1820        let d = y.data().unwrap();
1821
1822        // For x > 0, ELU(x) = x.
1823        assert!((d[3] - 1.0).abs() < 1e-7);
1824        assert!((d[4] - 2.0).abs() < 1e-7);
1825
1826        // For x = 0, ELU(0) = 0.
1827        assert!((d[2] - 0.0).abs() < 1e-7);
1828
1829        // For x < 0, ELU(x) = alpha * (exp(x) - 1) < 0.
1830        let expected_m1 = 1.0 * ((-1.0_f64).exp() - 1.0);
1831        assert!(
1832            (d[1] - expected_m1).abs() < 1e-7,
1833            "ELU(-1) expected {}, got {}",
1834            expected_m1,
1835            d[1]
1836        );
1837
1838        let expected_m2 = 1.0 * ((-2.0_f64).exp() - 1.0);
1839        assert!(
1840            (d[0] - expected_m2).abs() < 1e-7,
1841            "ELU(-2) expected {}, got {}",
1842            expected_m2,
1843            d[0]
1844        );
1845
1846        // ELU approaches -alpha from below for very negative x.
1847        let x = t(&[-100.0]);
1848        let y = m.forward(&x).unwrap();
1849        assert!((y.data().unwrap()[0] + 1.0).abs() < 1e-7);
1850    }
1851
1852    #[test]
1853    fn test_elu_custom_alpha() {
1854        let m = ELU::new(2.0);
1855        let x = t(&[-1.0, 1.0]);
1856        let y = m.forward(&x).unwrap();
1857        let d = y.data().unwrap();
1858
1859        let expected = 2.0 * ((-1.0_f64).exp() - 1.0);
1860        assert!((d[0] - expected).abs() < 1e-7);
1861        assert!((d[1] - 1.0).abs() < 1e-7);
1862    }
1863
1864    #[test]
1865    fn test_elu_module_trait() {
1866        let mut m = ELU::new(1.0);
1867        assert_zero_param_module::<ELU, f64>(&mut m);
1868    }
1869
1870    // -----------------------------------------------------------------------
1871    // Mish
1872    // -----------------------------------------------------------------------
1873
1874    #[test]
1875    fn test_mish_forward() {
1876        let m = Mish::new();
1877        // mish(0) = 0 * tanh(softplus(0)) = 0 * tanh(ln(2)) = 0
1878        let x = t(&[0.0]);
1879        let y = m.forward(&x).unwrap();
1880        assert!(y.data().unwrap()[0].abs() < 1e-7, "mish(0) should be 0");
1881
1882        // For large positive x, mish(x) -> x (softplus(x) -> x, tanh(x) -> 1).
1883        let x = t(&[20.0]);
1884        let y = m.forward(&x).unwrap();
1885        assert!(
1886            (y.data().unwrap()[0] - 20.0).abs() < 0.01,
1887            "mish(20) should be ~20"
1888        );
1889
1890        // mish is slightly negative for negative inputs.
1891        let x = t(&[-1.0]);
1892        let y = m.forward(&x).unwrap();
1893        let val = y.data().unwrap()[0];
1894        let softplus = (1.0 + (-1.0_f64).exp()).ln();
1895        let expected = -softplus.tanh();
1896        assert!(
1897            (val - expected).abs() < 1e-7,
1898            "mish(-1) expected {}, got {}",
1899            expected,
1900            val
1901        );
1902    }
1903
1904    #[test]
1905    fn test_mish_module_trait() {
1906        let mut m = Mish::new();
1907        assert_zero_param_module::<Mish, f64>(&mut m);
1908    }
1909
1914    // -----------------------------------------------------------------------
1915    // PReLU
1916    // -----------------------------------------------------------------------
1917
1918    #[test]
1919    fn test_prelu_forward_default() {
1920        let m = PReLU::<f64>::new(0.25).unwrap();
1921        let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
1922        let y = m.forward(&x).unwrap();
1923        let d = y.data().unwrap();
1924        // For x > 0: output = x. For x < 0: output = 0.25 * x.
1925        assert!((d[0] - (-0.5)).abs() < 1e-6, "PReLU(-2) = {}", d[0]);
1926        assert!((d[1] - (-0.25)).abs() < 1e-6, "PReLU(-1) = {}", d[1]);
1927        assert!((d[2] - 0.0).abs() < 1e-6, "PReLU(0) = {}", d[2]);
1928        assert!((d[3] - 1.0).abs() < 1e-6, "PReLU(1) = {}", d[3]);
1929        assert!((d[4] - 2.0).abs() < 1e-6, "PReLU(2) = {}", d[4]);
1930    }
1931
1932    #[test]
1933    fn test_prelu_has_parameter() {
1934        let m = PReLU::<f64>::new(0.25).unwrap();
1935        assert_eq!(m.parameters().len(), 1, "PReLU should have 1 parameter");
1936        let named = m.named_parameters();
1937        assert_eq!(named.len(), 1);
1938        assert_eq!(named[0].0, "alpha");
1939    }
1940
1941    #[test]
1942    fn test_prelu_module_trait() {
1943        let mut m = PReLU::<f64>::new(0.25).unwrap();
1944        assert_eq!(m.parameters().len(), 1);
1945        assert!(m.is_training());
1946        m.eval();
1947        assert!(!m.is_training());
1948        m.train();
1949        assert!(m.is_training());
1950    }
1951
1952    // -----------------------------------------------------------------------
1953    // CELU
1954    // -----------------------------------------------------------------------
1955
1956    #[test]
1957    fn test_celu_forward() {
1958        let m = CELU::new(1.0);
1959        let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
1960        let y = m.forward(&x).unwrap();
1961        let d = y.data().unwrap();
1962
1963        // For x > 0: CELU(x) = x
1964        assert!((d[3] - 1.0).abs() < 1e-7);
1965        assert!((d[4] - 2.0).abs() < 1e-7);
1966        assert!((d[2] - 0.0).abs() < 1e-7);
1967
1968        // For x < 0: CELU(x) = alpha * (exp(x/alpha) - 1)
1969        let expected_m1 = 1.0 * ((-1.0_f64).exp() - 1.0);
1970        assert!((d[1] - expected_m1).abs() < 1e-7, "CELU(-1) = {}", d[1]);
1971    }
1972
1973    #[test]
1974    fn test_celu_module_trait() {
1975        let mut m = CELU::new(1.0);
1976        assert_zero_param_module::<CELU, f64>(&mut m);
1977    }
1978
1979    // -----------------------------------------------------------------------
1980    // SELU
1981    // -----------------------------------------------------------------------
1982
1983    #[test]
1984    fn test_selu_forward() {
1985        let m = SELU::new();
1986        let x = t(&[-1.0, 0.0, 1.0]);
1987        let y = m.forward(&x).unwrap();
1988        let d = y.data().unwrap();
1989
1990        // For x > 0: SELU(x) = lambda * x
1991        let lambda = 1.0507009873554805_f64;
1992        let alpha = 1.6732632423543772_f64;
1993        assert!((d[2] - lambda * 1.0).abs() < 1e-7, "SELU(1) = {}", d[2]);
1994        assert!((d[1] - 0.0).abs() < 1e-7, "SELU(0) = {}", d[1]);
1995
1996        // For x < 0: SELU(x) = lambda * alpha * (exp(x) - 1)
1997        let expected_m1 = lambda * alpha * ((-1.0_f64).exp() - 1.0);
1998        assert!((d[0] - expected_m1).abs() < 1e-7, "SELU(-1) = {}", d[0]);
1999    }
2000
2001    #[test]
2002    fn test_selu_module_trait() {
2003        let mut m = SELU::new();
2004        assert_zero_param_module::<SELU, f64>(&mut m);
2005    }
2006
2007    // -----------------------------------------------------------------------
2008    // HardSigmoid
2009    // -----------------------------------------------------------------------
2010
2011    #[test]
2012    fn test_hard_sigmoid_forward() {
2013        let m = HardSigmoid::new();
2014        // clamp((x+3)/6, 0, 1)
2015        // x = -4: (−4+3)/6 = −1/6 < 0 -> 0
2016        // x = -3: (-3+3)/6 = 0
2017        // x = 0: (0+3)/6 = 0.5
2018        // x = 3: (3+3)/6 = 1.0
2019        // x = 5: (5+3)/6 > 1 -> 1
2020        let x = t(&[-4.0, -3.0, 0.0, 3.0, 5.0]);
2021        let y = m.forward(&x).unwrap();
2022        let d = y.data().unwrap();
2023        assert!((d[0] - 0.0).abs() < 1e-7, "HardSigmoid(-4) = {}", d[0]);
2024        assert!((d[1] - 0.0).abs() < 1e-7, "HardSigmoid(-3) = {}", d[1]);
2025        assert!((d[2] - 0.5).abs() < 1e-7, "HardSigmoid(0) = {}", d[2]);
2026        assert!((d[3] - 1.0).abs() < 1e-7, "HardSigmoid(3) = {}", d[3]);
2027        assert!((d[4] - 1.0).abs() < 1e-7, "HardSigmoid(5) = {}", d[4]);
2028    }
2029
2030    #[test]
2031    fn test_hard_sigmoid_module_trait() {
2032        let mut m = HardSigmoid::new();
2033        assert_zero_param_module::<HardSigmoid, f64>(&mut m);
2034    }
2035
2036    // -----------------------------------------------------------------------
2037    // HardSwish
2038    // -----------------------------------------------------------------------
2039
2040    #[test]
2041    fn test_hard_swish_forward() {
2042        let m = HardSwish::new();
2043        // HardSwish(x) = x * clamp((x+3)/6, 0, 1)
2044        // x = -4: -4 * 0 = 0
2045        // x = 0: 0 * 0.5 = 0
2046        // x = 3: 3 * 1.0 = 3
2047        // x = 5: 5 * 1.0 = 5
2048        // x = -1: -1 * ((-1+3)/6) = -1 * (1/3) = -1/3
2049        let x = t(&[-4.0, 0.0, 3.0, 5.0, -1.0]);
2050        let y = m.forward(&x).unwrap();
2051        let d = y.data().unwrap();
2052        assert!((d[0] - 0.0).abs() < 1e-7, "HardSwish(-4) = {}", d[0]);
2053        assert!((d[1] - 0.0).abs() < 1e-7, "HardSwish(0) = {}", d[1]);
2054        assert!((d[2] - 3.0).abs() < 1e-7, "HardSwish(3) = {}", d[2]);
2055        assert!((d[3] - 5.0).abs() < 1e-7, "HardSwish(5) = {}", d[3]);
2056        assert!(
2057            (d[4] - (-1.0 / 3.0)).abs() < 1e-7,
2058            "HardSwish(-1) = {}",
2059            d[4]
2060        );
2061    }
2062
2063    #[test]
2064    fn test_hard_swish_module_trait() {
2065        let mut m = HardSwish::new();
2066        assert_zero_param_module::<HardSwish, f64>(&mut m);
2067    }
2068
2069    // -----------------------------------------------------------------------
2070    // Softplus
2071    // -----------------------------------------------------------------------
2072
2073    #[test]
2074    fn test_softplus_forward() {
2075        let m = Softplus::new(1.0);
2076        // softplus(0) = ln(1 + 1) = ln(2)
2077        let x = t(&[0.0]);
2078        let y = m.forward(&x).unwrap();
2079        let d = y.data().unwrap();
2080        assert!((d[0] - 2.0_f64.ln()).abs() < 1e-7, "Softplus(0) = {}", d[0]);
2081
2082        // For large x, softplus(x) -> x (threshold mode).
2083        let x = t(&[25.0]);
2084        let y = m.forward(&x).unwrap();
2085        let d = y.data().unwrap();
2086        assert!((d[0] - 25.0).abs() < 1e-5, "Softplus(25) = {}", d[0]);
2087
2088        // softplus(1) = ln(1 + e) ~ 1.3133
2089        let x = t(&[1.0]);
2090        let y = m.forward(&x).unwrap();
2091        let d = y.data().unwrap();
2092        let expected = (1.0 + 1.0_f64.exp()).ln();
2093        assert!((d[0] - expected).abs() < 1e-7, "Softplus(1) = {}", d[0]);
2094    }
2095
2096    #[test]
2097    fn test_softplus_custom_beta() {
2098        let m = Softplus::new(2.0);
2099        // softplus(x, beta=2) = ln(1 + exp(2*x)) / 2
2100        let x = t(&[0.0]);
2101        let y = m.forward(&x).unwrap();
2102        let d = y.data().unwrap();
2103        let expected = 2.0_f64.ln() / 2.0;
2104        assert!(
2105            (d[0] - expected).abs() < 1e-7,
2106            "Softplus(0, beta=2) = {}",
2107            d[0]
2108        );
2109    }
2110
2111    #[test]
2112    fn test_softplus_module_trait() {
2113        let mut m = Softplus::new(1.0);
2114        assert_zero_param_module::<Softplus, f64>(&mut m);
2115    }
2116
2117    // -----------------------------------------------------------------------
2118    // GLU
2119    // -----------------------------------------------------------------------
2120
2121    #[test]
2122    fn test_glu_forward_1d() {
2123        let m = GLU::new();
2124        // input = [1.0, 0.0, 2.0, 0.0]  (last dim = 4, split into [1,0] and [2,0])
2125        // a = [1.0, 0.0], b = [2.0, 0.0]
2126        // output = a * sigmoid(b) = [1.0 * sigmoid(2.0), 0.0 * sigmoid(0.0)]
2127        let x = t(&[1.0, 0.0, 2.0, 0.0]);
2128        let y = m.forward(&x).unwrap();
2129        assert_eq!(y.shape(), &[2]);
2130        let d = y.data().unwrap();
2131        let sig_2 = 1.0 / (1.0 + (-2.0_f64).exp());
2132        assert!((d[0] - sig_2).abs() < 1e-7, "GLU[0] = {}", d[0]);
2133        assert!((d[1] - 0.0).abs() < 1e-7, "GLU[1] = {}", d[1]);
2134    }
2135
2136    #[test]
2137    fn test_glu_forward_2d() {
2138        let m = GLU::new();
2139        // [[1.0, 0.0, 2.0, 0.0]] -> shape [1, 4], splits last dim
2140        let x = t2d(&[1.0, 0.0, 2.0, 0.0], 1, 4);
2141        let y = m.forward(&x).unwrap();
2142        assert_eq!(y.shape(), &[1, 2]);
2143        let d = y.data().unwrap();
2144        let sig_2 = 1.0 / (1.0 + (-2.0_f64).exp());
2145        assert!((d[0] - sig_2).abs() < 1e-7);
2146        assert!((d[1] - 0.0).abs() < 1e-7);
2147    }
2148
2149    #[test]
2150    fn test_glu_odd_dim_error() {
2151        let m = GLU::new();
2152        let x = t(&[1.0, 2.0, 3.0]); // last dim = 3 (odd)
2153        assert!(m.forward(&x).is_err());
2154    }
2155
2156    #[test]
2157    fn test_glu_module_trait() {
2158        let mut m = GLU::new();
2159        assert_zero_param_module::<GLU, f64>(&mut m);
2160    }
2161
2162    // -----------------------------------------------------------------------
2163    // ReLU6
2164    // -----------------------------------------------------------------------
2165
2166    #[test]
2167    fn test_relu6_forward() {
2168        let m = ReLU6::new();
2169        let x = t(&[-2.0, 0.0, 3.0, 6.0, 10.0]);
2170        let y = m.forward(&x).unwrap();
2171        let d = y.data().unwrap();
2172        assert!((d[0] - 0.0).abs() < 1e-7, "ReLU6(-2) = {}", d[0]);
2173        assert!((d[1] - 0.0).abs() < 1e-7, "ReLU6(0) = {}", d[1]);
2174        assert!((d[2] - 3.0).abs() < 1e-7, "ReLU6(3) = {}", d[2]);
2175        assert!((d[3] - 6.0).abs() < 1e-7, "ReLU6(6) = {}", d[3]);
2176        assert!((d[4] - 6.0).abs() < 1e-7, "ReLU6(10) = {}", d[4]);
2177    }
2178
2179    #[test]
2180    fn test_relu6_module_trait() {
2181        let mut m = ReLU6::new();
2182        assert_zero_param_module::<ReLU6, f64>(&mut m);
2183    }
2184
2185    // -----------------------------------------------------------------------
2186    // Hardtanh
2187    // -----------------------------------------------------------------------
2188
2189    #[test]
2190    fn test_hardtanh_forward_default() {
2191        let m = Hardtanh::default();
2192        // clamp(x, -1, 1)
2193        let x = t(&[-5.0, -1.0, 0.0, 0.5, 1.0, 3.0]);
2194        let y = m.forward(&x).unwrap();
2195        let d = y.data().unwrap();
2196        assert!((d[0] - (-1.0)).abs() < 1e-7, "Hardtanh(-5) = {}", d[0]);
2197        assert!((d[1] - (-1.0)).abs() < 1e-7, "Hardtanh(-1) = {}", d[1]);
2198        assert!((d[2] - 0.0).abs() < 1e-7, "Hardtanh(0) = {}", d[2]);
2199        assert!((d[3] - 0.5).abs() < 1e-7, "Hardtanh(0.5) = {}", d[3]);
2200        assert!((d[4] - 1.0).abs() < 1e-7, "Hardtanh(1) = {}", d[4]);
2201        assert!((d[5] - 1.0).abs() < 1e-7, "Hardtanh(3) = {}", d[5]);
2202    }
2203
2204    #[test]
2205    fn test_hardtanh_custom_range() {
2206        let m = Hardtanh::new(-2.0, 2.0);
2207        let x = t(&[-5.0, -2.0, 0.0, 2.0, 5.0]);
2208        let y = m.forward(&x).unwrap();
2209        let d = y.data().unwrap();
2210        assert!((d[0] - (-2.0)).abs() < 1e-7);
2211        assert!((d[1] - (-2.0)).abs() < 1e-7);
2212        assert!((d[2] - 0.0).abs() < 1e-7);
2213        assert!((d[3] - 2.0).abs() < 1e-7);
2214        assert!((d[4] - 2.0).abs() < 1e-7);
2215    }
2216
2217    #[test]
2218    fn test_hardtanh_module_trait() {
2219        let mut m = Hardtanh::default();
2220        assert_zero_param_module::<Hardtanh, f64>(&mut m);
2221    }
2222
2223    // -----------------------------------------------------------------------
2224    // LogSigmoid
2225    // -----------------------------------------------------------------------
2226
2227    #[test]
2228    fn test_log_sigmoid_forward() {
2229        let m = LogSigmoid::new();
2230        // log(sigmoid(0)) = log(0.5) = -ln(2)
2231        let x = t(&[0.0]);
2232        let y = m.forward(&x).unwrap();
2233        let d = y.data().unwrap();
2234        assert!(
2235            (d[0] - (-2.0_f64.ln())).abs() < 1e-6,
2236            "LogSigmoid(0) = {}, expected {}",
2237            d[0],
2238            -2.0_f64.ln()
2239        );
2240
2241        // All outputs should be <= 0 (log of a probability).
2242        let x = t(&[-10.0, -1.0, 0.0, 1.0, 10.0]);
2243        let y = m.forward(&x).unwrap();
2244        let d = y.data().unwrap();
2245        assert!(
2246            d.iter().all(|&v| v <= 0.0),
2247            "All LogSigmoid values should be <= 0"
2248        );
2249
2250        // For large positive x, log(sigmoid(x)) -> 0.
2251        assert!(
2252            d[4].abs() < 1e-4,
2253            "LogSigmoid(10) should be ~0, got {}",
2254            d[4]
2255        );
2256
2257        // For large negative x, log(sigmoid(x)) -> x.
2258        assert!(
2259            (d[0] - (-10.0)).abs() < 0.1,
2260            "LogSigmoid(-10) should be ~-10, got {}",
2261            d[0]
2262        );
2263    }
2264
2265    #[test]
2266    fn test_log_sigmoid_module_trait() {
2267        let mut m = LogSigmoid::new();
2268        assert_zero_param_module::<LogSigmoid, f64>(&mut m);
2269    }
2270
2271    // -----------------------------------------------------------------------
2272    // Softmin
2273    // -----------------------------------------------------------------------
2274
2275    #[test]
2276    fn test_softmin_forward_1d() {
2277        let m = Softmin::new(-1);
2278        let x = t(&[1.0, 2.0, 3.0]);
2279        let y = m.forward(&x).unwrap();
2280        let d = y.data().unwrap();
2281
2282        // Sum should be 1.
2283        let total: f64 = d.iter().sum();
2284        assert!((total - 1.0).abs() < 1e-7, "Softmin sum = {}", total);
2285
2286        // Softmin reverses ordering: smallest input gets largest probability.
2287        assert!(d[0] > d[1], "softmin(1) > softmin(2)");
2288        assert!(d[1] > d[2], "softmin(2) > softmin(3)");
2289    }
2290
2291    #[test]
2292    fn test_softmin_wrong_dim() {
2293        let m = Softmin::new(0);
2294        let x = t2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
2295        assert!(m.forward(&x).is_err());
2296    }
2297
2298    #[test]
2299    fn test_softmin_module_trait() {
2300        let mut m = Softmin::new(-1);
2301        assert_zero_param_module::<Softmin, f64>(&mut m);
2302    }
2303
2304    // -----------------------------------------------------------------------
2305    // Threshold
2306    // -----------------------------------------------------------------------
2307
2308    #[test]
2309    fn test_threshold_forward() {
2310        let m = Threshold::new(0.5, -1.0);
2311        let x = t(&[-1.0, 0.0, 0.5, 1.0, 2.0]);
2312        let y = m.forward(&x).unwrap();
2313        let d = y.data().unwrap();
2314        // x <= threshold -> value
2315        assert!((d[0] - (-1.0)).abs() < 1e-7, "Threshold(-1) = {}", d[0]);
2316        assert!((d[1] - (-1.0)).abs() < 1e-7, "Threshold(0) = {}", d[1]);
2317        assert!((d[2] - (-1.0)).abs() < 1e-7, "Threshold(0.5) = {}", d[2]);
2318        // x > threshold -> x
2319        assert!((d[3] - 1.0).abs() < 1e-7, "Threshold(1) = {}", d[3]);
2320        assert!((d[4] - 2.0).abs() < 1e-7, "Threshold(2) = {}", d[4]);
2321    }
2322
2323    #[test]
2324    fn test_threshold_module_trait() {
2325        let mut m = Threshold::new(0.5, -1.0);
2326        assert_zero_param_module::<Threshold, f64>(&mut m);
2327    }
2328
2329    // -----------------------------------------------------------------------
2330    // Softshrink
2331    // -----------------------------------------------------------------------
2332
2333    #[test]
2334    fn test_softshrink_forward() {
2335        let m = Softshrink::default(); // lambda = 0.5
2336        let x = t(&[-2.0, -0.5, -0.3, 0.0, 0.3, 0.5, 2.0]);
2337        let y = m.forward(&x).unwrap();
2338        let d = y.data().unwrap();
2339        // x > lambda: x - lambda
2340        assert!((d[6] - 1.5).abs() < 1e-7, "Softshrink(2) = {}", d[6]);
2341        // x < -lambda: x + lambda
2342        assert!((d[0] - (-1.5)).abs() < 1e-7, "Softshrink(-2) = {}", d[0]);
2343        // -lambda <= x <= lambda: 0
2344        assert!((d[2] - 0.0).abs() < 1e-7, "Softshrink(-0.3) = {}", d[2]);
2345        assert!((d[3] - 0.0).abs() < 1e-7, "Softshrink(0) = {}", d[3]);
2346        assert!((d[4] - 0.0).abs() < 1e-7, "Softshrink(0.3) = {}", d[4]);
2347        // Boundary: x == lambda or x == -lambda -> 0
2348        assert!((d[1] - 0.0).abs() < 1e-7, "Softshrink(-0.5) = {}", d[1]);
2349        assert!((d[5] - 0.0).abs() < 1e-7, "Softshrink(0.5) = {}", d[5]);
2350    }
2351
2352    #[test]
2353    fn test_softshrink_custom_lambda() {
2354        let m = Softshrink::new(1.0);
2355        let x = t(&[-2.0, -0.5, 0.5, 2.0]);
2356        let y = m.forward(&x).unwrap();
2357        let d = y.data().unwrap();
2358        assert!((d[0] - (-1.0)).abs() < 1e-7);
2359        assert!((d[1] - 0.0).abs() < 1e-7);
2360        assert!((d[2] - 0.0).abs() < 1e-7);
2361        assert!((d[3] - 1.0).abs() < 1e-7);
2362    }
2363
2364    #[test]
2365    fn test_softshrink_module_trait() {
2366        let mut m = Softshrink::default();
2367        assert_zero_param_module::<Softshrink, f64>(&mut m);
2368    }
2369
2370    // -----------------------------------------------------------------------
2371    // Hardshrink
2372    // -----------------------------------------------------------------------
2373
2374    #[test]
2375    fn test_hardshrink_forward() {
2376        let m = Hardshrink::default(); // lambda = 0.5
2377        let x = t(&[-2.0, -0.5, -0.3, 0.0, 0.3, 0.5, 2.0]);
2378        let y = m.forward(&x).unwrap();
2379        let d = y.data().unwrap();
2380        // |x| > lambda: x
2381        assert!((d[0] - (-2.0)).abs() < 1e-7, "Hardshrink(-2) = {}", d[0]);
2382        assert!((d[6] - 2.0).abs() < 1e-7, "Hardshrink(2) = {}", d[6]);
2383        // |x| <= lambda: 0
2384        assert!((d[2] - 0.0).abs() < 1e-7, "Hardshrink(-0.3) = {}", d[2]);
2385        assert!((d[3] - 0.0).abs() < 1e-7, "Hardshrink(0) = {}", d[3]);
2386        assert!((d[4] - 0.0).abs() < 1e-7, "Hardshrink(0.3) = {}", d[4]);
2387        // Boundary: x == lambda or x == -lambda -> 0
2388        assert!((d[1] - 0.0).abs() < 1e-7, "Hardshrink(-0.5) = {}", d[1]);
2389        assert!((d[5] - 0.0).abs() < 1e-7, "Hardshrink(0.5) = {}", d[5]);
2390    }
2391
2392    #[test]
2393    fn test_hardshrink_module_trait() {
2394        let mut m = Hardshrink::default();
2395        assert_zero_param_module::<Hardshrink, f64>(&mut m);
2396    }
2397
2398    // -----------------------------------------------------------------------
2399    // Tanhshrink
2400    // -----------------------------------------------------------------------
2401
2402    #[test]
2403    fn test_tanhshrink_forward() {
2404        let m = Tanhshrink::new();
2405        // tanhshrink(0) = 0 - tanh(0) = 0
2406        let x = t(&[0.0]);
2407        let y = m.forward(&x).unwrap();
2408        assert!(
2409            y.data().unwrap()[0].abs() < 1e-7,
2410            "Tanhshrink(0) should be 0"
2411        );
2412
2413        // For large |x|, tanh(x) -> sign(x), so tanhshrink(x) -> x - sign(x).
2414        let x = t(&[10.0, -10.0]);
2415        let y = m.forward(&x).unwrap();
2416        let d = y.data().unwrap();
2417        assert!(
2418            (d[0] - 9.0).abs() < 0.01,
2419            "Tanhshrink(10) should be ~9, got {}",
2420            d[0]
2421        );
2422        assert!(
2423            (d[1] - (-9.0)).abs() < 0.01,
2424            "Tanhshrink(-10) should be ~-9, got {}",
2425            d[1]
2426        );
2427
2428        // Exact check: tanhshrink(1) = 1 - tanh(1)
2429        let x = t(&[1.0]);
2430        let y = m.forward(&x).unwrap();
2431        let expected = 1.0 - 1.0_f64.tanh();
2432        assert!(
2433            (y.data().unwrap()[0] - expected).abs() < 1e-7,
2434            "Tanhshrink(1) expected {}, got {}",
2435            expected,
2436            y.data().unwrap()[0]
2437        );
2438    }
2439
2440    #[test]
2441    fn test_tanhshrink_module_trait() {
2442        let mut m = Tanhshrink::new();
2443        assert_zero_param_module::<Tanhshrink, f64>(&mut m);
2444    }
2445
2446    // -----------------------------------------------------------------------
2447    // Softsign
2448    // -----------------------------------------------------------------------
2449
2450    #[test]
2451    fn test_softsign_forward() {
2452        let m = Softsign::new();
2453        // softsign(0) = 0
2454        let x = t(&[0.0]);
2455        let y = m.forward(&x).unwrap();
2456        assert!(y.data().unwrap()[0].abs() < 1e-7, "Softsign(0) should be 0");
2457
2458        // softsign(1) = 1/2 = 0.5
2459        let x = t(&[1.0]);
2460        let y = m.forward(&x).unwrap();
2461        assert!(
2462            (y.data().unwrap()[0] - 0.5).abs() < 1e-7,
2463            "Softsign(1) should be 0.5"
2464        );
2465
2466        // softsign(-1) = -1/2 = -0.5
2467        let x = t(&[-1.0]);
2468        let y = m.forward(&x).unwrap();
2469        assert!(
2470            (y.data().unwrap()[0] - (-0.5)).abs() < 1e-7,
2471            "Softsign(-1) should be -0.5"
2472        );
2473
2474        // Bounded in (-1, 1) for large values.
2475        let x = t(&[100.0, -100.0]);
2476        let y = m.forward(&x).unwrap();
2477        let d = y.data().unwrap();
2478        assert!(
2479            d[0] > 0.99 && d[0] < 1.0,
2480            "Softsign(100) should be ~1, got {}",
2481            d[0]
2482        );
2483        assert!(
2484            d[1] < -0.99 && d[1] > -1.0,
2485            "Softsign(-100) should be ~-1, got {}",
2486            d[1]
2487        );
2488    }
2489
2490    #[test]
2491    fn test_softsign_module_trait() {
2492        let mut m = Softsign::new();
2493        assert_zero_param_module::<Softsign, f64>(&mut m);
2494    }
2495
2496    // -----------------------------------------------------------------------
2497    // RReLU
2498    // -----------------------------------------------------------------------
2499
2500    #[test]
2501    #[allow(clippy::field_reassign_with_default)]
2502    fn test_rrelu_eval_forward() {
2503        // In eval mode, RReLU uses deterministic mean slope.
2504        let mut m = RReLU::default(); // lower=1/8, upper=1/3
2505        m.training = false;
2506        let mean_slope = (1.0 / 8.0 + 1.0 / 3.0) / 2.0;
2507
2508        let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
2509        let y = m.forward(&x).unwrap();
2510        let d = y.data().unwrap();
2511
2512        assert!(
2513            (d[0] - (-2.0 * mean_slope)).abs() < 1e-7,
2514            "RReLU(-2,eval) = {}",
2515            d[0]
2516        );
2517        assert!(
2518            (d[1] - (-mean_slope)).abs() < 1e-7,
2519            "RReLU(-1,eval) = {}",
2520            d[1]
2521        );
2522        assert!((d[2] - 0.0).abs() < 1e-7, "RReLU(0,eval) = {}", d[2]);
2523        assert!((d[3] - 1.0).abs() < 1e-7, "RReLU(1,eval) = {}", d[3]);
2524        assert!((d[4] - 2.0).abs() < 1e-7, "RReLU(2,eval) = {}", d[4]);
2525    }
2526
2527    #[test]
2528    fn test_rrelu_training_positive_passthrough() {
2529        // In training mode, positive values should pass through unchanged.
2530        let m = RReLU::default();
2531        let x = t(&[0.0, 1.0, 5.0, 100.0]);
2532        let y = m.forward(&x).unwrap();
2533        let d = y.data().unwrap();
2534        assert!((d[0] - 0.0).abs() < 1e-7);
2535        assert!((d[1] - 1.0).abs() < 1e-7);
2536        assert!((d[2] - 5.0).abs() < 1e-7);
2537        assert!((d[3] - 100.0).abs() < 1e-7);
2538    }
2539
2540    #[test]
2541    fn test_rrelu_training_negative_bounded() {
2542        // In training mode, negative outputs should be scaled by a slope in [lower, upper].
2543        let m = RReLU::new(0.1, 0.5);
2544        let x = t(&[-1.0; 100]); // 100 copies of -1
2545        let y = m.forward(&x).unwrap();
2546        let d = y.data().unwrap();
2547
2548        for (i, &val) in d.iter().enumerate() {
2549            // slope * (-1) should be in [-0.5, -0.1]
2550            assert!(
2551                (-0.5 - 1e-7..=-0.1 + 1e-7).contains(&val),
2552                "RReLU(-1, train)[{}] = {} not in [-0.5, -0.1]",
2553                i,
2554                val
2555            );
2556        }
2557
2558        // With 100 samples, we should see some variance (not all the same).
2559        let first = d[0];
2560        let has_variance = d.iter().any(|&v| (v - first).abs() > 1e-10);
2561        assert!(has_variance, "RReLU training should produce varying slopes");
2562    }
2563
2564    #[test]
2565    fn test_rrelu_module_trait() {
2566        let mut m = RReLU::default();
2567        assert_zero_param_module::<RReLU, f64>(&mut m);
2568    }
2569
2570    // -----------------------------------------------------------------------
2571    // Default constructors
2572    // -----------------------------------------------------------------------
2573
2574    #[test]
2575    fn test_defaults() {
2576        let _relu = ReLU::default();
2577        let _gelu = GELU::default();
2578        let _silu = SiLU::default();
2579        let _sigmoid = Sigmoid::default();
2580        let _tanh = Tanh::default();
2581        let _softmax = Softmax::default();
2582        let _log_softmax = LogSoftmax::default();
2583
2584        let lrelu = LeakyReLU::default();
2585        assert!((lrelu.negative_slope - 0.01).abs() < f64::EPSILON);
2586
2587        let elu = ELU::default();
2588        assert!((elu.alpha - 1.0).abs() < f64::EPSILON);
2589
2590        let _mish = Mish::default();
2591
2592        let celu = CELU::default();
2593        assert!((celu.alpha - 1.0).abs() < f64::EPSILON);
2594
2595        let _selu = SELU::default();
2596        let _hard_sigmoid = HardSigmoid::default();
2597        let _hard_swish = HardSwish::default();
2598
2599        let softplus = Softplus::default();
2600        assert!((softplus.beta - 1.0).abs() < f64::EPSILON);
2601
2602        let _glu = GLU::default();
2603
2604        // New activations
2605        let _relu6 = ReLU6::default();
2606
2607        let hardtanh = Hardtanh::default();
2608        assert!((hardtanh.min_val - (-1.0)).abs() < f64::EPSILON);
2609        assert!((hardtanh.max_val - 1.0).abs() < f64::EPSILON);
2610
2611        let _log_sigmoid = LogSigmoid::default();
2612        let _softmin = Softmin::default();
2613
2614        let softshrink = Softshrink::default();
2615        assert!((softshrink.lambda - 0.5).abs() < f64::EPSILON);
2616
2617        let hardshrink = Hardshrink::default();
2618        assert!((hardshrink.lambda - 0.5).abs() < f64::EPSILON);
2619
2620        let _tanhshrink = Tanhshrink::default();
2621        let _softsign = Softsign::default();
2622
2623        let rrelu = RReLU::default();
2624        assert!((rrelu.lower - 1.0 / 8.0).abs() < f64::EPSILON);
2625        assert!((rrelu.upper - 1.0 / 3.0).abs() < f64::EPSILON);
2626    }
2627
2628    // -----------------------------------------------------------------------
2629    // Send + Sync
2630    // -----------------------------------------------------------------------
2631
2632    #[test]
2633    fn test_send_sync() {
2634        fn assert_send_sync<T: Send + Sync>() {}
2635        assert_send_sync::<ReLU>();
2636        assert_send_sync::<GELU>();
2637        assert_send_sync::<SiLU>();
2638        assert_send_sync::<Sigmoid>();
2639        assert_send_sync::<Tanh>();
2640        assert_send_sync::<Softmax>();
2641        assert_send_sync::<LogSoftmax>();
2642        assert_send_sync::<LeakyReLU>();
2643        assert_send_sync::<ELU>();
2644        assert_send_sync::<Mish>();
2645        assert_send_sync::<PReLU<f64>>();
2646        assert_send_sync::<CELU>();
2647        assert_send_sync::<SELU>();
2648        assert_send_sync::<HardSigmoid>();
2649        assert_send_sync::<HardSwish>();
2650        assert_send_sync::<Softplus>();
2651        assert_send_sync::<GLU>();
2652        // New activations
2653        assert_send_sync::<ReLU6>();
2654        assert_send_sync::<Hardtanh>();
2655        assert_send_sync::<LogSigmoid>();
2656        assert_send_sync::<Softmin>();
2657        assert_send_sync::<Threshold>();
2658        assert_send_sync::<Softshrink>();
2659        assert_send_sync::<Hardshrink>();
2660        assert_send_sync::<Tanhshrink>();
2661        assert_send_sync::<Softsign>();
2662        assert_send_sync::<RReLU>();
2663    }
2664
2665    // -----------------------------------------------------------------------
2666    // Backward (autograd) tests for Softplus, ELU, Mish
2667    // -----------------------------------------------------------------------
2668
2669    /// Helper: 1-D tensor from a slice with `requires_grad = true`.
2670    fn t_grad(data: &[f64]) -> Tensor<f64> {
2671        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![data.len()], true).unwrap()
2672    }
2673
2674    /// Helper: scalar leaf tensor with `requires_grad = true`.
2675    fn t_scalar_grad(val: f64) -> Tensor<f64> {
2676        Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], true).unwrap()
2677    }
2678
2679    /// Numerical gradient via central difference: (f(x+h) - f(x-h)) / (2h).
2680    fn numerical_grad(f: impl Fn(f64) -> f64, x: f64) -> f64 {
2681        let h = 1e-5;
2682        (f(x + h) - f(x - h)) / (2.0 * h)
2683    }
2684
2685    // -- Softplus backward --
2686
2687    #[test]
2688    fn test_softplus_backward_produces_grad() {
2689        let x = t_scalar_grad(1.0);
2690        let m = Softplus::new(1.0);
2691        let y = m.forward(&x).unwrap();
2692        ferrotorch_core::backward(&y).unwrap();
2693
2694        let grad = x.grad().unwrap();
2695        assert!(
2696            grad.is_some(),
2697            "Softplus backward should produce a gradient"
2698        );
2699    }
2700
2701    #[test]
2702    fn test_softplus_backward_at_zero() {
2703        let x = t_scalar_grad(0.0);
2704        let m = Softplus::new(1.0);
2705        let y = m.forward(&x).unwrap();
2706        ferrotorch_core::backward(&y).unwrap();
2707
2708        let grad = x.grad().unwrap().unwrap();
2709        // d/dx softplus(0) = sigmoid(0) = 0.5
2710        assert!(
2711            (grad.item().unwrap() - 0.5).abs() < 1e-6,
2712            "Softplus grad at x=0: expected 0.5, got {}",
2713            grad.item().unwrap()
2714        );
2715    }
2716
2717    #[test]
2718    fn test_softplus_backward_matches_numerical() {
2719        for &val in &[-2.0, -0.5, 0.0, 1.0, 3.0] {
2720            let x = t_scalar_grad(val);
2721            let m = Softplus::new(1.0);
2722            let y = m.forward(&x).unwrap();
2723            ferrotorch_core::backward(&y).unwrap();
2724
2725            let grad = x.grad().unwrap().unwrap();
2726            let expected = numerical_grad(|v| (1.0 + v.exp()).ln(), val);
2727            assert!(
2728                (grad.item().unwrap() - expected).abs() < 1e-4,
2729                "Softplus grad at x={}: expected {}, got {}",
2730                val,
2731                expected,
2732                grad.item().unwrap()
2733            );
2734        }
2735    }
2736
2737    #[test]
2738    fn test_softplus_backward_custom_beta() {
2739        let val = 1.0;
2740        let beta = 2.0;
2741        let x = t_scalar_grad(val);
2742        let m = Softplus::new(beta);
2743        let y = m.forward(&x).unwrap();
2744        ferrotorch_core::backward(&y).unwrap();
2745
2746        let grad = x.grad().unwrap().unwrap();
2747        let expected = numerical_grad(|v| (1.0 + (beta * v).exp()).ln() / beta, val);
2748        assert!(
2749            (grad.item().unwrap() - expected).abs() < 1e-4,
2750            "Softplus grad at x={}, beta={}: expected {}, got {}",
2751            val,
2752            beta,
2753            expected,
2754            grad.item().unwrap()
2755        );
2756    }
2757
2758    #[test]
2759    fn test_softplus_backward_vector() {
2760        let x = t_grad(&[-2.0, -0.5, 0.0, 1.0, 3.0]);
2761        let m = Softplus::new(1.0);
2762        let y = m.forward(&x).unwrap();
2763        // Sum to get a scalar for backward.
2764        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
2765        ferrotorch_core::backward(&sum).unwrap();
2766
2767        let grad = x.grad().unwrap().unwrap();
2768        let grad_data = grad.data().unwrap();
2769
2770        for (i, &val) in [-2.0_f64, -0.5, 0.0, 1.0, 3.0].iter().enumerate() {
2771            let expected = numerical_grad(|v| (1.0 + v.exp()).ln(), val);
2772            assert!(
2773                (grad_data[i] - expected).abs() < 1e-4,
2774                "Softplus grad[{}] at x={}: expected {}, got {}",
2775                i,
2776                val,
2777                expected,
2778                grad_data[i]
2779            );
2780        }
2781    }
2782
2783    // -- ELU backward --
2784
2785    #[test]
2786    fn test_elu_backward_produces_grad() {
2787        let x = t_scalar_grad(-1.0);
2788        let m = ELU::new(1.0);
2789        let y = m.forward(&x).unwrap();
2790        ferrotorch_core::backward(&y).unwrap();
2791
2792        let grad = x.grad().unwrap();
2793        assert!(grad.is_some(), "ELU backward should produce a gradient");
2794    }
2795
2796    #[test]
2797    fn test_elu_backward_positive() {
2798        let x = t_scalar_grad(2.0);
2799        let m = ELU::new(1.0);
2800        let y = m.forward(&x).unwrap();
2801        ferrotorch_core::backward(&y).unwrap();
2802
2803        let grad = x.grad().unwrap().unwrap();
2804        // d/dx elu(x) at x=2 (positive) = 1.
2805        assert!(
2806            (grad.item().unwrap() - 1.0).abs() < 1e-6,
2807            "ELU grad at x=2: expected 1.0, got {}",
2808            grad.item().unwrap()
2809        );
2810    }
2811
2812    #[test]
2813    fn test_elu_backward_matches_numerical() {
2814        let alpha = 1.0;
2815        for &val in &[-2.0, -1.0, -0.5, 0.5, 2.0] {
2816            let x = t_scalar_grad(val);
2817            let m = ELU::new(alpha);
2818            let y = m.forward(&x).unwrap();
2819            ferrotorch_core::backward(&y).unwrap();
2820
2821            let grad = x.grad().unwrap().unwrap();
2822            let expected =
2823                numerical_grad(|v| if v > 0.0 { v } else { alpha * (v.exp() - 1.0) }, val);
2824            assert!(
2825                (grad.item().unwrap() - expected).abs() < 1e-4,
2826                "ELU grad at x={}: expected {}, got {}",
2827                val,
2828                expected,
2829                grad.item().unwrap()
2830            );
2831        }
2832    }
2833
2834    #[test]
2835    fn test_elu_backward_custom_alpha() {
2836        let alpha = 2.0;
2837        let val = -0.5;
2838        let x = t_scalar_grad(val);
2839        let m = ELU::new(alpha);
2840        let y = m.forward(&x).unwrap();
2841        ferrotorch_core::backward(&y).unwrap();
2842
2843        let grad = x.grad().unwrap().unwrap();
2844        // d/dx [alpha * (exp(x) - 1)] = alpha * exp(x)
2845        let expected = alpha * val.exp();
2846        assert!(
2847            (grad.item().unwrap() - expected).abs() < 1e-5,
2848            "ELU grad at x={}, alpha={}: expected {}, got {}",
2849            val,
2850            alpha,
2851            expected,
2852            grad.item().unwrap()
2853        );
2854    }
2855
2856    // -- Mish backward --
2857
2858    #[test]
2859    fn test_mish_backward_produces_grad() {
2860        let x = t_scalar_grad(1.0);
2861        let m = Mish::new();
2862        let y = m.forward(&x).unwrap();
2863        ferrotorch_core::backward(&y).unwrap();
2864
2865        let grad = x.grad().unwrap();
2866        assert!(grad.is_some(), "Mish backward should produce a gradient");
2867    }
2868
2869    #[test]
2870    fn test_mish_backward_matches_numerical() {
2871        let mish_fn = |v: f64| {
2872            let sp = (1.0 + v.exp()).ln();
2873            v * sp.tanh()
2874        };
2875
2876        for &val in &[-2.0, -1.0, 0.0, 0.5, 1.5, 3.0] {
2877            let x = t_scalar_grad(val);
2878            let m = Mish::new();
2879            let y = m.forward(&x).unwrap();
2880            ferrotorch_core::backward(&y).unwrap();
2881
2882            let grad = x.grad().unwrap().unwrap();
2883            let expected = numerical_grad(mish_fn, val);
2884            assert!(
2885                (grad.item().unwrap() - expected).abs() < 1e-4,
2886                "Mish grad at x={}: expected {}, got {}",
2887                val,
2888                expected,
2889                grad.item().unwrap()
2890            );
2891        }
2892    }
2893
2894    #[test]
2895    fn test_mish_backward_vector() {
2896        let x = t_grad(&[-1.0, 0.0, 1.0, 2.0]);
2897        let m = Mish::new();
2898        let y = m.forward(&x).unwrap();
2899        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
2900        ferrotorch_core::backward(&sum).unwrap();
2901
2902        let grad = x.grad().unwrap().unwrap();
2903        let grad_data = grad.data().unwrap();
2904
2905        let mish_fn = |v: f64| {
2906            let sp = (1.0 + v.exp()).ln();
2907            v * sp.tanh()
2908        };
2909
2910        for (i, &val) in [-1.0_f64, 0.0, 1.0, 2.0].iter().enumerate() {
2911            let expected = numerical_grad(mish_fn, val);
2912            assert!(
2913                (grad_data[i] - expected).abs() < 1e-4,
2914                "Mish grad[{}] at x={}: expected {}, got {}",
2915                i,
2916                val,
2917                expected,
2918                grad_data[i]
2919            );
2920        }
2921    }
2922
2923    // -----------------------------------------------------------------------
    // Backward (autograd) tests for ReLU6, Hardtanh, LogSigmoid and Tanhshrink
2925    // -----------------------------------------------------------------------
2926
2927    // -- ReLU6 backward --
2928
2929    #[test]
2930    fn test_relu6_backward_matches_numerical() {
2931        let relu6_fn = |v: f64| v.clamp(0.0, 6.0);
2932
2933        for &val in &[-2.0, 0.5, 3.0, 5.5, 8.0] {
2934            let x = t_scalar_grad(val);
2935            let m = ReLU6::new();
2936            let y = m.forward(&x).unwrap();
2937            ferrotorch_core::backward(&y).unwrap();
2938
2939            let grad = x.grad().unwrap().unwrap();
2940            let expected = numerical_grad(relu6_fn, val);
2941            assert!(
2942                (grad.item().unwrap() - expected).abs() < 1e-4,
2943                "ReLU6 grad at x={}: expected {}, got {}",
2944                val,
2945                expected,
2946                grad.item().unwrap()
2947            );
2948        }
2949    }
2950
2951    // -- Hardtanh backward --
2952
2953    #[test]
2954    fn test_hardtanh_backward_matches_numerical() {
2955        let hardtanh_fn = |v: f64| v.clamp(-1.0, 1.0);
2956
2957        for &val in &[-2.0, -0.5, 0.0, 0.5, 2.0] {
2958            let x = t_scalar_grad(val);
2959            let m = Hardtanh::default();
2960            let y = m.forward(&x).unwrap();
2961            ferrotorch_core::backward(&y).unwrap();
2962
2963            let grad = x.grad().unwrap().unwrap();
2964            let expected = numerical_grad(hardtanh_fn, val);
2965            assert!(
2966                (grad.item().unwrap() - expected).abs() < 1e-4,
2967                "Hardtanh grad at x={}: expected {}, got {}",
2968                val,
2969                expected,
2970                grad.item().unwrap()
2971            );
2972        }
2973    }
2974
2975    // -- LogSigmoid backward --
2976
2977    #[test]
2978    fn test_log_sigmoid_backward_matches_numerical() {
2979        let logsigmoid_fn = |v: f64| {
2980            // log(sigmoid(v)) = -softplus(-v) = -ln(1+exp(-v))
2981            -(1.0 + (-v).exp()).ln()
2982        };
2983
2984        for &val in &[-3.0, -1.0, 0.0, 1.0, 3.0] {
2985            let x = t_scalar_grad(val);
2986            let m = LogSigmoid::new();
2987            let y = m.forward(&x).unwrap();
2988            ferrotorch_core::backward(&y).unwrap();
2989
2990            let grad = x.grad().unwrap().unwrap();
2991            let expected = numerical_grad(logsigmoid_fn, val);
2992            assert!(
2993                (grad.item().unwrap() - expected).abs() < 1e-4,
2994                "LogSigmoid grad at x={}: expected {}, got {}",
2995                val,
2996                expected,
2997                grad.item().unwrap()
2998            );
2999        }
3000    }
3001
3002    // -- Tanhshrink backward --
3003
3004    #[test]
3005    fn test_tanhshrink_backward_matches_numerical() {
3006        let tanhshrink_fn = |v: f64| v - v.tanh();
3007
3008        for &val in &[-2.0, -0.5, 0.0, 0.5, 2.0] {
3009            let x = t_scalar_grad(val);
3010            let m = Tanhshrink::new();
3011            let y = m.forward(&x).unwrap();
3012            ferrotorch_core::backward(&y).unwrap();
3013
3014            let grad = x.grad().unwrap().unwrap();
3015            let expected = numerical_grad(tanhshrink_fn, val);
3016            assert!(
3017                (grad.item().unwrap() - expected).abs() < 1e-4,
3018                "Tanhshrink grad at x={}: expected {}, got {}",
3019                val,
3020                expected,
3021                grad.item().unwrap()
3022            );
3023        }
3024    }
3025
3026    // -----------------------------------------------------------------------
    // State dict contents (parameter-free activations expose no entries)
3028    // -----------------------------------------------------------------------
3029
3030    #[test]
3031    fn test_state_dict_empty() {
3032        let m = ReLU::new();
3033        let sd = Module::<f64>::state_dict(&m);
3034        assert!(sd.is_empty());
3035    }
3036
3037    /// #1451: `Softmax2d::forward` on a CUDA-resident `[N, C, H, W]` input must
3038    /// route through `GpuBackend::softmax2d_f32` (channel-axis softmax) and
3039    /// match the CPU path within f32 tolerance.
3040    ///
3041    /// Gated `#[ignore]` because it needs real CUDA hardware (the build host
3042    /// has none); it documents the expected GPU↔CPU parity for a future
3043    /// CUDA-host run. Tracking #1451.
3044    #[test]
3045    #[cfg(feature = "cuda")]
3046    #[ignore = "needs CUDA hardware; tracking #1451"]
3047    fn softmax2d_forward_gpu_matches_cpu() {
3048        use ferrotorch_core::Device;
3049        use ferrotorch_core::storage::TensorStorage;
3050        use ferrotorch_gpu::init_cuda_backend;
3051        init_cuda_backend().expect("CUDA init failed");
3052
3053        // [N=2, C=5, H=3, W=4].
3054        let (n, c, h, w) = (2usize, 5usize, 3usize, 4usize);
3055        let total = n * c * h * w;
3056        let data: Vec<f32> = (0..total).map(|k| ((k % 11) as f32) * 0.3 - 1.4).collect();
3057
3058        let sm = Softmax2d::new();
3059
3060        let x_cpu = Tensor::from_storage(TensorStorage::cpu(data.clone()), vec![n, c, h, w], false)
3061            .unwrap();
3062        let y_cpu = sm.forward(&x_cpu).unwrap();
3063        let cpu_vals = y_cpu.data().unwrap().to_vec();
3064
3065        let x_gpu = x_cpu.to(Device::Cuda(0)).unwrap();
3066        let y_gpu = sm.forward(&x_gpu).unwrap();
3067        assert!(y_gpu.is_cuda(), "Softmax2d GPU output must stay on CUDA");
3068        let gpu_vals = y_gpu.data_vec().unwrap();
3069
3070        assert_eq!(gpu_vals.len(), cpu_vals.len());
3071        let mut max_abs = 0.0f32;
3072        for (g, c) in gpu_vals.iter().zip(cpu_vals.iter()) {
3073            max_abs = max_abs.max((g - c).abs());
3074        }
3075        assert!(max_abs < 1e-4, "Softmax2d GPU vs CPU max|Δ| = {max_abs}");
3076    }
3077}