//! Activation function modules implementing the `Module` trait.
//!
//! `ReLU`, `Sigmoid`, `Tanh`, `GELU`, `SiLU` (Swish), `ELU`,
//! `LeakyReLU` (configurable negative_slope), `Softmax` (with dim),
//! `LogSoftmax`, `Identity`, `Flatten`. Each wraps the corresponding
//! `Variable` method in a `Module::forward` so it can be used in `Sequential`
//! or `ModuleList` chains. Stateless (no parameters, no train/eval difference).
//!
//! # File
//! `crates/axonml-nn/src/activation.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_autograd::Variable;

use crate::module::Module;

28// =============================================================================
29// ReLU
30// =============================================================================
31
32/// Applies the rectified linear unit function element-wise.
33///
34/// ReLU(x) = max(0, x)
35#[derive(Debug, Clone, Copy, Default)]
36pub struct ReLU;
37
38impl ReLU {
39    /// Creates a new ReLU activation.
40    pub fn new() -> Self {
41        Self
42    }
43}
44
45impl Module for ReLU {
46    fn forward(&self, input: &Variable) -> Variable {
47        input.relu()
48    }
49
50    fn name(&self) -> &'static str {
51        "ReLU"
52    }
53}
54
55// =============================================================================
56// LeakyReLU
57// =============================================================================
58
59/// Applies the leaky ReLU function element-wise.
60///
61/// LeakyReLU(x) = max(0, x) + negative_slope * min(0, x)
62#[derive(Debug, Clone, Copy)]
63pub struct LeakyReLU {
64    negative_slope: f32,
65}
66
67impl LeakyReLU {
68    /// Creates a new LeakyReLU with default negative slope (0.01).
69    pub fn new() -> Self {
70        Self {
71            negative_slope: 0.01,
72        }
73    }
74
75    /// Creates a LeakyReLU with custom negative slope.
76    pub fn with_slope(negative_slope: f32) -> Self {
77        Self { negative_slope }
78    }
79}
80
81impl Default for LeakyReLU {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl Module for LeakyReLU {
88    fn forward(&self, input: &Variable) -> Variable {
89        input.leaky_relu(self.negative_slope)
90    }
91
92    fn name(&self) -> &'static str {
93        "LeakyReLU"
94    }
95}
96
97// =============================================================================
98// Sigmoid
99// =============================================================================
100
101/// Applies the sigmoid function element-wise.
102///
103/// Sigmoid(x) = 1 / (1 + exp(-x))
104#[derive(Debug, Clone, Copy, Default)]
105pub struct Sigmoid;
106
107impl Sigmoid {
108    /// Creates a new Sigmoid activation.
109    pub fn new() -> Self {
110        Self
111    }
112}
113
114impl Module for Sigmoid {
115    fn forward(&self, input: &Variable) -> Variable {
116        input.sigmoid()
117    }
118
119    fn name(&self) -> &'static str {
120        "Sigmoid"
121    }
122}
123
124// =============================================================================
125// Tanh
126// =============================================================================
127
128/// Applies the hyperbolic tangent function element-wise.
129///
130/// Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
131#[derive(Debug, Clone, Copy, Default)]
132pub struct Tanh;
133
134impl Tanh {
135    /// Creates a new Tanh activation.
136    pub fn new() -> Self {
137        Self
138    }
139}
140
141impl Module for Tanh {
142    fn forward(&self, input: &Variable) -> Variable {
143        input.tanh()
144    }
145
146    fn name(&self) -> &'static str {
147        "Tanh"
148    }
149}
150
151// =============================================================================
152// Softmax
153// =============================================================================
154
155/// Applies the softmax function along a dimension.
156///
157/// Softmax(x_i) = exp(x_i) / sum(exp(x_j))
158#[derive(Debug, Clone, Copy)]
159pub struct Softmax {
160    dim: i64,
161}
162
163impl Softmax {
164    /// Creates a new Softmax along the specified dimension.
165    pub fn new(dim: i64) -> Self {
166        Self { dim }
167    }
168}
169
170impl Default for Softmax {
171    fn default() -> Self {
172        Self::new(-1)
173    }
174}
175
176impl Module for Softmax {
177    fn forward(&self, input: &Variable) -> Variable {
178        input.softmax(self.dim as i32)
179    }
180
181    fn name(&self) -> &'static str {
182        "Softmax"
183    }
184}
185
186// =============================================================================
187// LogSoftmax
188// =============================================================================
189
190/// Applies log(softmax(x)) along a dimension.
191#[derive(Debug, Clone, Copy)]
192pub struct LogSoftmax {
193    dim: i64,
194}
195
196impl LogSoftmax {
197    /// Creates a new LogSoftmax along the specified dimension.
198    pub fn new(dim: i64) -> Self {
199        Self { dim }
200    }
201}
202
203impl Default for LogSoftmax {
204    fn default() -> Self {
205        Self::new(-1)
206    }
207}
208
209impl Module for LogSoftmax {
210    fn forward(&self, input: &Variable) -> Variable {
211        input.log_softmax(self.dim as i32)
212    }
213
214    fn name(&self) -> &'static str {
215        "LogSoftmax"
216    }
217}
218
219// =============================================================================
220// GELU
221// =============================================================================
222
223/// Applies the Gaussian Error Linear Unit function.
224///
225/// GELU(x) = x * Phi(x) where Phi is the CDF of standard normal distribution.
226#[derive(Debug, Clone, Copy, Default)]
227pub struct GELU;
228
229impl GELU {
230    /// Creates a new GELU activation.
231    pub fn new() -> Self {
232        Self
233    }
234}
235
236impl Module for GELU {
237    fn forward(&self, input: &Variable) -> Variable {
238        input.gelu()
239    }
240
241    fn name(&self) -> &'static str {
242        "GELU"
243    }
244}
245
246// =============================================================================
247// SiLU / Swish
248// =============================================================================
249
250/// Applies the SiLU (Swish) function element-wise.
251///
252/// SiLU(x) = x * sigmoid(x)
253#[derive(Debug, Clone, Copy, Default)]
254pub struct SiLU;
255
256impl SiLU {
257    /// Creates a new SiLU activation.
258    pub fn new() -> Self {
259        Self
260    }
261}
262
263impl Module for SiLU {
264    fn forward(&self, input: &Variable) -> Variable {
265        let sigmoid = input.sigmoid();
266        input.mul_var(&sigmoid)
267    }
268
269    fn name(&self) -> &'static str {
270        "SiLU"
271    }
272}
273
274// =============================================================================
275// ELU
276// =============================================================================
277
278/// Applies the Exponential Linear Unit function.
279///
280/// ELU(x) = x if x > 0, else alpha * (exp(x) - 1)
281#[derive(Debug, Clone, Copy)]
282pub struct ELU {
283    alpha: f32,
284}
285
286impl ELU {
287    /// Creates a new ELU with default alpha (1.0).
288    pub fn new() -> Self {
289        Self { alpha: 1.0 }
290    }
291
292    /// Creates an ELU with custom alpha.
293    pub fn with_alpha(alpha: f32) -> Self {
294        Self { alpha }
295    }
296}
297
298impl Default for ELU {
299    fn default() -> Self {
300        Self::new()
301    }
302}
303
304impl Module for ELU {
305    fn forward(&self, input: &Variable) -> Variable {
306        input.elu(self.alpha)
307    }
308
309    fn name(&self) -> &'static str {
310        "ELU"
311    }
312}
313
314// =============================================================================
315// Identity
316// =============================================================================
317
318/// Identity activation (no-op).
319#[derive(Debug, Clone, Copy, Default)]
320pub struct Identity;
321
322impl Identity {
323    /// Creates a new Identity activation.
324    pub fn new() -> Self {
325        Self
326    }
327}
328
329impl Module for Identity {
330    fn forward(&self, input: &Variable) -> Variable {
331        input.clone()
332    }
333
334    fn name(&self) -> &'static str {
335        "Identity"
336    }
337}
338
339// =============================================================================
340// Flatten
341// =============================================================================
342
343/// Flattens all dimensions from `start_dim` to the end into a single dimension.
344///
345/// Default: `start_dim = 1` (preserves batch dimension).
346///
347/// # Examples
348/// ```ignore
349/// let flatten = Flatten::new();      // flattens from dim 1 (batch preserved)
350/// let flat_all = Flatten::from(0);   // flattens everything
351/// ```
352#[derive(Debug, Clone, Copy)]
353pub struct Flatten {
354    start_dim: usize,
355}
356
357impl Flatten {
358    /// Creates a Flatten module that flattens from dimension 1 (preserves batch dim).
359    pub fn new() -> Self {
360        Self { start_dim: 1 }
361    }
362
363    /// Creates a Flatten module that flattens from the specified start dimension.
364    pub fn from(start_dim: usize) -> Self {
365        Self { start_dim }
366    }
367}
368
369impl Default for Flatten {
370    fn default() -> Self {
371        Self::new()
372    }
373}
374
375impl Module for Flatten {
376    fn forward(&self, input: &Variable) -> Variable {
377        input.flatten(self.start_dim)
378    }
379
380    fn parameters(&self) -> Vec<crate::Parameter> {
381        Vec::new()
382    }
383
384    fn name(&self) -> &'static str {
385        "Flatten"
386    }
387}
388
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Builds a non-differentiable `Variable` from raw data and a shape.
    fn var(data: Vec<f32>, shape: &[usize]) -> Variable {
        Variable::new(Tensor::from_vec(data, shape).unwrap(), false)
    }

    #[test]
    fn test_relu() {
        let output = ReLU::new().forward(&var(vec![-1.0, 0.0, 1.0, 2.0], &[4]));
        assert_eq!(output.data().to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        // sigmoid(0) == 0.5 exactly (up to float rounding).
        let output = Sigmoid::new().forward(&var(vec![0.0], &[1]));
        assert!((output.data().to_vec()[0] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_softmax() {
        // Softmax outputs along the normalized axis must sum to 1.
        let output = Softmax::new(-1).forward(&var(vec![1.0, 2.0, 3.0], &[1, 3]));
        let sum: f32 = output.data().to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_leaky_relu() {
        // Negative inputs are scaled by the slope; non-negative pass through.
        let output = LeakyReLU::with_slope(0.1).forward(&var(vec![-1.0, 0.0, 1.0], &[3]));
        assert_eq!(output.data().to_vec(), vec![-0.1, 0.0, 1.0]);
    }

    #[test]
    fn test_identity() {
        let output = Identity::new().forward(&var(vec![1.0, 2.0, 3.0], &[3]));
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
}