Skip to main content

axonml_nn/
activation.rs

1//! Activation Modules - Non-linear Activation Functions
2//!
3//! # File
4//! `crates/axonml-nn/src/activation.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use axonml_autograd::Variable;
18
19use crate::module::Module;
20
21// =============================================================================
22// ReLU
23// =============================================================================
24
25/// Applies the rectified linear unit function element-wise.
26///
27/// ReLU(x) = max(0, x)
28#[derive(Debug, Clone, Copy, Default)]
29pub struct ReLU;
30
31impl ReLU {
32    /// Creates a new ReLU activation.
33    pub fn new() -> Self {
34        Self
35    }
36}
37
38impl Module for ReLU {
39    fn forward(&self, input: &Variable) -> Variable {
40        input.relu()
41    }
42
43    fn name(&self) -> &'static str {
44        "ReLU"
45    }
46}
47
48// =============================================================================
49// LeakyReLU
50// =============================================================================
51
52/// Applies the leaky ReLU function element-wise.
53///
54/// LeakyReLU(x) = max(0, x) + negative_slope * min(0, x)
55#[derive(Debug, Clone, Copy)]
56pub struct LeakyReLU {
57    negative_slope: f32,
58}
59
60impl LeakyReLU {
61    /// Creates a new LeakyReLU with default negative slope (0.01).
62    pub fn new() -> Self {
63        Self {
64            negative_slope: 0.01,
65        }
66    }
67
68    /// Creates a LeakyReLU with custom negative slope.
69    pub fn with_slope(negative_slope: f32) -> Self {
70        Self { negative_slope }
71    }
72}
73
74impl Default for LeakyReLU {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl Module for LeakyReLU {
81    fn forward(&self, input: &Variable) -> Variable {
82        input.leaky_relu(self.negative_slope)
83    }
84
85    fn name(&self) -> &'static str {
86        "LeakyReLU"
87    }
88}
89
90// =============================================================================
91// Sigmoid
92// =============================================================================
93
94/// Applies the sigmoid function element-wise.
95///
96/// Sigmoid(x) = 1 / (1 + exp(-x))
97#[derive(Debug, Clone, Copy, Default)]
98pub struct Sigmoid;
99
100impl Sigmoid {
101    /// Creates a new Sigmoid activation.
102    pub fn new() -> Self {
103        Self
104    }
105}
106
107impl Module for Sigmoid {
108    fn forward(&self, input: &Variable) -> Variable {
109        input.sigmoid()
110    }
111
112    fn name(&self) -> &'static str {
113        "Sigmoid"
114    }
115}
116
117// =============================================================================
118// Tanh
119// =============================================================================
120
121/// Applies the hyperbolic tangent function element-wise.
122///
123/// Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
124#[derive(Debug, Clone, Copy, Default)]
125pub struct Tanh;
126
127impl Tanh {
128    /// Creates a new Tanh activation.
129    pub fn new() -> Self {
130        Self
131    }
132}
133
134impl Module for Tanh {
135    fn forward(&self, input: &Variable) -> Variable {
136        input.tanh()
137    }
138
139    fn name(&self) -> &'static str {
140        "Tanh"
141    }
142}
143
144// =============================================================================
145// Softmax
146// =============================================================================
147
148/// Applies the softmax function along a dimension.
149///
150/// Softmax(x_i) = exp(x_i) / sum(exp(x_j))
151#[derive(Debug, Clone, Copy)]
152pub struct Softmax {
153    dim: i64,
154}
155
156impl Softmax {
157    /// Creates a new Softmax along the specified dimension.
158    pub fn new(dim: i64) -> Self {
159        Self { dim }
160    }
161}
162
163impl Default for Softmax {
164    fn default() -> Self {
165        Self::new(-1)
166    }
167}
168
169impl Module for Softmax {
170    fn forward(&self, input: &Variable) -> Variable {
171        input.softmax(self.dim as i32)
172    }
173
174    fn name(&self) -> &'static str {
175        "Softmax"
176    }
177}
178
179// =============================================================================
180// LogSoftmax
181// =============================================================================
182
183/// Applies log(softmax(x)) along a dimension.
184#[derive(Debug, Clone, Copy)]
185pub struct LogSoftmax {
186    dim: i64,
187}
188
189impl LogSoftmax {
190    /// Creates a new LogSoftmax along the specified dimension.
191    pub fn new(dim: i64) -> Self {
192        Self { dim }
193    }
194}
195
196impl Default for LogSoftmax {
197    fn default() -> Self {
198        Self::new(-1)
199    }
200}
201
202impl Module for LogSoftmax {
203    fn forward(&self, input: &Variable) -> Variable {
204        input.log_softmax(self.dim as i32)
205    }
206
207    fn name(&self) -> &'static str {
208        "LogSoftmax"
209    }
210}
211
212// =============================================================================
213// GELU
214// =============================================================================
215
216/// Applies the Gaussian Error Linear Unit function.
217///
218/// GELU(x) = x * Phi(x) where Phi is the CDF of standard normal distribution.
219#[derive(Debug, Clone, Copy, Default)]
220pub struct GELU;
221
222impl GELU {
223    /// Creates a new GELU activation.
224    pub fn new() -> Self {
225        Self
226    }
227}
228
229impl Module for GELU {
230    fn forward(&self, input: &Variable) -> Variable {
231        input.gelu()
232    }
233
234    fn name(&self) -> &'static str {
235        "GELU"
236    }
237}
238
239// =============================================================================
240// SiLU / Swish
241// =============================================================================
242
243/// Applies the SiLU (Swish) function element-wise.
244///
245/// SiLU(x) = x * sigmoid(x)
246#[derive(Debug, Clone, Copy, Default)]
247pub struct SiLU;
248
249impl SiLU {
250    /// Creates a new SiLU activation.
251    pub fn new() -> Self {
252        Self
253    }
254}
255
256impl Module for SiLU {
257    fn forward(&self, input: &Variable) -> Variable {
258        let sigmoid = input.sigmoid();
259        input.mul_var(&sigmoid)
260    }
261
262    fn name(&self) -> &'static str {
263        "SiLU"
264    }
265}
266
267// =============================================================================
268// ELU
269// =============================================================================
270
271/// Applies the Exponential Linear Unit function.
272///
273/// ELU(x) = x if x > 0, else alpha * (exp(x) - 1)
274#[derive(Debug, Clone, Copy)]
275pub struct ELU {
276    alpha: f32,
277}
278
279impl ELU {
280    /// Creates a new ELU with default alpha (1.0).
281    pub fn new() -> Self {
282        Self { alpha: 1.0 }
283    }
284
285    /// Creates an ELU with custom alpha.
286    pub fn with_alpha(alpha: f32) -> Self {
287        Self { alpha }
288    }
289}
290
291impl Default for ELU {
292    fn default() -> Self {
293        Self::new()
294    }
295}
296
297impl Module for ELU {
298    fn forward(&self, input: &Variable) -> Variable {
299        input.elu(self.alpha)
300    }
301
302    fn name(&self) -> &'static str {
303        "ELU"
304    }
305}
306
307// =============================================================================
308// Identity
309// =============================================================================
310
311/// Identity activation (no-op).
312#[derive(Debug, Clone, Copy, Default)]
313pub struct Identity;
314
315impl Identity {
316    /// Creates a new Identity activation.
317    pub fn new() -> Self {
318        Self
319    }
320}
321
322impl Module for Identity {
323    fn forward(&self, input: &Variable) -> Variable {
324        input.clone()
325    }
326
327    fn name(&self) -> &'static str {
328        "Identity"
329    }
330}
331
332// =============================================================================
333// Flatten
334// =============================================================================
335
336/// Flattens all dimensions from `start_dim` to the end into a single dimension.
337///
338/// Default: `start_dim = 1` (preserves batch dimension).
339///
340/// # Examples
341/// ```ignore
342/// let flatten = Flatten::new();      // flattens from dim 1 (batch preserved)
343/// let flat_all = Flatten::from(0);   // flattens everything
344/// ```
345#[derive(Debug, Clone, Copy)]
346pub struct Flatten {
347    start_dim: usize,
348}
349
350impl Flatten {
351    /// Creates a Flatten module that flattens from dimension 1 (preserves batch dim).
352    pub fn new() -> Self {
353        Self { start_dim: 1 }
354    }
355
356    /// Creates a Flatten module that flattens from the specified start dimension.
357    pub fn from(start_dim: usize) -> Self {
358        Self { start_dim }
359    }
360}
361
362impl Default for Flatten {
363    fn default() -> Self {
364        Self::new()
365    }
366}
367
368impl Module for Flatten {
369    fn forward(&self, input: &Variable) -> Variable {
370        input.flatten(self.start_dim)
371    }
372
373    fn parameters(&self) -> Vec<crate::Parameter> {
374        Vec::new()
375    }
376
377    fn name(&self) -> &'static str {
378        "Flatten"
379    }
380}
381
382// =============================================================================
383// Tests
384// =============================================================================
385
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Absolute tolerance for comparing floating-point activation outputs.
    const TOL: f32 = 1e-5;

    /// Wraps `values` in a 1-D `Variable` that does not require grad.
    fn vector(values: Vec<f32>) -> Variable {
        let len = values.len();
        Variable::new(Tensor::from_vec(values, &[len]).unwrap(), false)
    }

    /// Asserts that `actual` and `expected` agree element-wise within `TOL`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < TOL, "index {}: got {}, expected {}", i, a, e);
        }
    }

    #[test]
    fn test_relu() {
        let relu = ReLU::new();
        let input = Variable::new(
            Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap(),
            false,
        );
        let output = relu.forward(&input);
        assert_eq!(output.data().to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        let sigmoid = Sigmoid::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let output = sigmoid.forward(&input);
        assert!((output.data().to_vec()[0] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_softmax() {
        let softmax = Softmax::new(-1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let output = softmax.forward(&input);
        let sum: f32 = output.data().to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_leaky_relu() {
        let leaky = LeakyReLU::with_slope(0.1);
        let input = Variable::new(Tensor::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap(), false);
        let output = leaky.forward(&input);
        assert_eq!(output.data().to_vec(), vec![-0.1, 0.0, 1.0]);
    }

    #[test]
    fn test_identity() {
        let id = Identity::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = id.forward(&input);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_tanh() {
        let output = Tanh::new().forward(&vector(vec![-1.0, 0.0, 1.0]));
        assert_close(&output.data().to_vec(), &[-0.761_594_2, 0.0, 0.761_594_2]);
    }

    #[test]
    fn test_silu() {
        // SiLU(x) = x * sigmoid(x); sigmoid(1) ~= 0.7310586.
        let output = SiLU::new().forward(&vector(vec![-1.0, 0.0, 1.0]));
        assert_close(&output.data().to_vec(), &[-0.268_941_4, 0.0, 0.731_058_6]);
    }

    #[test]
    fn test_elu() {
        // Positive inputs pass through; non-positive ones map to alpha * (e^x - 1).
        let output = ELU::new().forward(&vector(vec![-1.0, 0.0, 2.0]));
        assert_close(&output.data().to_vec(), &[-0.632_120_6, 0.0, 2.0]);

        let scaled = ELU::with_alpha(2.0).forward(&vector(vec![-1.0]));
        assert_close(&scaled.data().to_vec(), &[-1.264_241_1]);
    }

    #[test]
    fn test_gelu_zero() {
        let output = GELU::new().forward(&vector(vec![0.0]));
        assert_close(&output.data().to_vec(), &[0.0]);
    }

    #[test]
    fn test_log_softmax() {
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let values = LogSoftmax::new(-1).forward(&input).data().to_vec();
        assert!(values.iter().all(|&v| v <= 0.0));
        let prob_sum: f32 = values.iter().map(|v| v.exp()).sum();
        assert!((prob_sum - 1.0).abs() < TOL);
    }

    #[test]
    fn test_defaults_match_new() {
        let input = vector(vec![-2.0, -0.5, 0.0, 3.0]);
        assert_eq!(
            LeakyReLU::default().forward(&input).data().to_vec(),
            LeakyReLU::with_slope(0.01).forward(&input).data().to_vec()
        );
        assert_eq!(
            ELU::default().forward(&input).data().to_vec(),
            ELU::with_alpha(1.0).forward(&input).data().to_vec()
        );
        assert_eq!(
            Softmax::default().forward(&input).data().to_vec(),
            Softmax::new(-1).forward(&input).data().to_vec()
        );
    }

    #[test]
    fn test_names() {
        assert_eq!(ReLU::new().name(), "ReLU");
        assert_eq!(LeakyReLU::new().name(), "LeakyReLU");
        assert_eq!(Sigmoid::new().name(), "Sigmoid");
        assert_eq!(Tanh::new().name(), "Tanh");
        assert_eq!(Softmax::default().name(), "Softmax");
        assert_eq!(LogSoftmax::default().name(), "LogSoftmax");
        assert_eq!(GELU::new().name(), "GELU");
        assert_eq!(SiLU::new().name(), "SiLU");
        assert_eq!(ELU::new().name(), "ELU");
        assert_eq!(Identity::new().name(), "Identity");
        assert_eq!(Flatten::new().name(), "Flatten");
    }
}