axonml_nn/activation.rs
1//! Activation function modules implementing the `Module` trait.
2//!
3//! 437 lines. `ReLU`, `Sigmoid`, `Tanh`, `GELU`, `SiLU` (Swish), `ELU`,
4//! `LeakyReLU` (configurable negative_slope), `Mish`, `Softmax` (with dim),
5//! `LogSoftmax`, `Identity`, `Flatten`. Each wraps the corresponding
6//! `Variable` method in a `Module::forward` so it can be used in `Sequential`
7//! or `ModuleList` chains. Stateless (no parameters, no train/eval difference).
8//!
9//! # File
10//! `crates/axonml-nn/src/activation.rs`
11//!
12//! # Author
13//! Andrew Jewell Sr. — AutomataNexus LLC
14//! ORCID: 0009-0005-2158-7060
15//!
16//! # Updated
17//! April 14, 2026 11:15 PM EST
18//!
19//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
21//! kind, express or implied. The author and AutomataNexus shall not be held
22//! liable for any damages arising from the use of this software.
23
24use axonml_autograd::Variable;
25
26use crate::module::Module;
27
28// =============================================================================
29// ReLU
30// =============================================================================
31
32/// Applies the rectified linear unit function element-wise.
33///
34/// ReLU(x) = max(0, x)
35#[derive(Debug, Clone, Copy, Default)]
36pub struct ReLU;
37
38impl ReLU {
39 /// Creates a new ReLU activation.
40 pub fn new() -> Self {
41 Self
42 }
43}
44
45impl Module for ReLU {
46 fn forward(&self, input: &Variable) -> Variable {
47 input.relu()
48 }
49
50 fn name(&self) -> &'static str {
51 "ReLU"
52 }
53}
54
55// =============================================================================
56// LeakyReLU
57// =============================================================================
58
59/// Applies the leaky ReLU function element-wise.
60///
61/// LeakyReLU(x) = max(0, x) + negative_slope * min(0, x)
62#[derive(Debug, Clone, Copy)]
63pub struct LeakyReLU {
64 negative_slope: f32,
65}
66
67impl LeakyReLU {
68 /// Creates a new LeakyReLU with default negative slope (0.01).
69 pub fn new() -> Self {
70 Self {
71 negative_slope: 0.01,
72 }
73 }
74
75 /// Creates a LeakyReLU with custom negative slope.
76 pub fn with_slope(negative_slope: f32) -> Self {
77 Self { negative_slope }
78 }
79}
80
81impl Default for LeakyReLU {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl Module for LeakyReLU {
88 fn forward(&self, input: &Variable) -> Variable {
89 input.leaky_relu(self.negative_slope)
90 }
91
92 fn name(&self) -> &'static str {
93 "LeakyReLU"
94 }
95}
96
97// =============================================================================
98// Sigmoid
99// =============================================================================
100
101/// Applies the sigmoid function element-wise.
102///
103/// Sigmoid(x) = 1 / (1 + exp(-x))
104#[derive(Debug, Clone, Copy, Default)]
105pub struct Sigmoid;
106
107impl Sigmoid {
108 /// Creates a new Sigmoid activation.
109 pub fn new() -> Self {
110 Self
111 }
112}
113
114impl Module for Sigmoid {
115 fn forward(&self, input: &Variable) -> Variable {
116 input.sigmoid()
117 }
118
119 fn name(&self) -> &'static str {
120 "Sigmoid"
121 }
122}
123
124// =============================================================================
125// Tanh
126// =============================================================================
127
128/// Applies the hyperbolic tangent function element-wise.
129///
130/// Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
131#[derive(Debug, Clone, Copy, Default)]
132pub struct Tanh;
133
134impl Tanh {
135 /// Creates a new Tanh activation.
136 pub fn new() -> Self {
137 Self
138 }
139}
140
141impl Module for Tanh {
142 fn forward(&self, input: &Variable) -> Variable {
143 input.tanh()
144 }
145
146 fn name(&self) -> &'static str {
147 "Tanh"
148 }
149}
150
151// =============================================================================
152// Softmax
153// =============================================================================
154
155/// Applies the softmax function along a dimension.
156///
157/// Softmax(x_i) = exp(x_i) / sum(exp(x_j))
158#[derive(Debug, Clone, Copy)]
159pub struct Softmax {
160 dim: i64,
161}
162
163impl Softmax {
164 /// Creates a new Softmax along the specified dimension.
165 pub fn new(dim: i64) -> Self {
166 Self { dim }
167 }
168}
169
170impl Default for Softmax {
171 fn default() -> Self {
172 Self::new(-1)
173 }
174}
175
176impl Module for Softmax {
177 fn forward(&self, input: &Variable) -> Variable {
178 input.softmax(self.dim as i32)
179 }
180
181 fn name(&self) -> &'static str {
182 "Softmax"
183 }
184}
185
186// =============================================================================
187// LogSoftmax
188// =============================================================================
189
190/// Applies log(softmax(x)) along a dimension.
191#[derive(Debug, Clone, Copy)]
192pub struct LogSoftmax {
193 dim: i64,
194}
195
196impl LogSoftmax {
197 /// Creates a new LogSoftmax along the specified dimension.
198 pub fn new(dim: i64) -> Self {
199 Self { dim }
200 }
201}
202
203impl Default for LogSoftmax {
204 fn default() -> Self {
205 Self::new(-1)
206 }
207}
208
209impl Module for LogSoftmax {
210 fn forward(&self, input: &Variable) -> Variable {
211 input.log_softmax(self.dim as i32)
212 }
213
214 fn name(&self) -> &'static str {
215 "LogSoftmax"
216 }
217}
218
219// =============================================================================
220// GELU
221// =============================================================================
222
223/// Applies the Gaussian Error Linear Unit function.
224///
225/// GELU(x) = x * Phi(x) where Phi is the CDF of standard normal distribution.
226#[derive(Debug, Clone, Copy, Default)]
227pub struct GELU;
228
229impl GELU {
230 /// Creates a new GELU activation.
231 pub fn new() -> Self {
232 Self
233 }
234}
235
236impl Module for GELU {
237 fn forward(&self, input: &Variable) -> Variable {
238 input.gelu()
239 }
240
241 fn name(&self) -> &'static str {
242 "GELU"
243 }
244}
245
246// =============================================================================
247// SiLU / Swish
248// =============================================================================
249
250/// Applies the SiLU (Swish) function element-wise.
251///
252/// SiLU(x) = x * sigmoid(x)
253#[derive(Debug, Clone, Copy, Default)]
254pub struct SiLU;
255
256impl SiLU {
257 /// Creates a new SiLU activation.
258 pub fn new() -> Self {
259 Self
260 }
261}
262
263impl Module for SiLU {
264 fn forward(&self, input: &Variable) -> Variable {
265 let sigmoid = input.sigmoid();
266 input.mul_var(&sigmoid)
267 }
268
269 fn name(&self) -> &'static str {
270 "SiLU"
271 }
272}
273
274// =============================================================================
275// ELU
276// =============================================================================
277
278/// Applies the Exponential Linear Unit function.
279///
280/// ELU(x) = x if x > 0, else alpha * (exp(x) - 1)
281#[derive(Debug, Clone, Copy)]
282pub struct ELU {
283 alpha: f32,
284}
285
286impl ELU {
287 /// Creates a new ELU with default alpha (1.0).
288 pub fn new() -> Self {
289 Self { alpha: 1.0 }
290 }
291
292 /// Creates an ELU with custom alpha.
293 pub fn with_alpha(alpha: f32) -> Self {
294 Self { alpha }
295 }
296}
297
298impl Default for ELU {
299 fn default() -> Self {
300 Self::new()
301 }
302}
303
304impl Module for ELU {
305 fn forward(&self, input: &Variable) -> Variable {
306 input.elu(self.alpha)
307 }
308
309 fn name(&self) -> &'static str {
310 "ELU"
311 }
312}
313
314// =============================================================================
315// Identity
316// =============================================================================
317
318/// Identity activation (no-op).
319#[derive(Debug, Clone, Copy, Default)]
320pub struct Identity;
321
322impl Identity {
323 /// Creates a new Identity activation.
324 pub fn new() -> Self {
325 Self
326 }
327}
328
329impl Module for Identity {
330 fn forward(&self, input: &Variable) -> Variable {
331 input.clone()
332 }
333
334 fn name(&self) -> &'static str {
335 "Identity"
336 }
337}
338
339// =============================================================================
340// Flatten
341// =============================================================================
342
343/// Flattens all dimensions from `start_dim` to the end into a single dimension.
344///
345/// Default: `start_dim = 1` (preserves batch dimension).
346///
347/// # Examples
348/// ```ignore
349/// let flatten = Flatten::new(); // flattens from dim 1 (batch preserved)
350/// let flat_all = Flatten::from(0); // flattens everything
351/// ```
352#[derive(Debug, Clone, Copy)]
353pub struct Flatten {
354 start_dim: usize,
355}
356
357impl Flatten {
358 /// Creates a Flatten module that flattens from dimension 1 (preserves batch dim).
359 pub fn new() -> Self {
360 Self { start_dim: 1 }
361 }
362
363 /// Creates a Flatten module that flattens from the specified start dimension.
364 pub fn from(start_dim: usize) -> Self {
365 Self { start_dim }
366 }
367}
368
369impl Default for Flatten {
370 fn default() -> Self {
371 Self::new()
372 }
373}
374
375impl Module for Flatten {
376 fn forward(&self, input: &Variable) -> Variable {
377 input.flatten(self.start_dim)
378 }
379
380 fn parameters(&self) -> Vec<crate::Parameter> {
381 Vec::new()
382 }
383
384 fn name(&self) -> &'static str {
385 "Flatten"
386 }
387}
388
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    #[test]
    fn test_relu() {
        // Negatives clamp to zero; non-negatives pass through.
        let act = ReLU::new();
        let x = Variable::new(
            Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap(),
            false,
        );
        let y = act.forward(&x);
        assert_eq!(y.data().to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        // sigmoid(0) == 0.5 exactly, up to float tolerance.
        let act = Sigmoid::new();
        let x = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let y = act.forward(&x);
        assert!((y.data().to_vec()[0] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_softmax() {
        // The outputs along the softmax dim must sum to 1.
        let act = Softmax::new(-1);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let y = act.forward(&x);
        let total: f32 = y.data().to_vec().into_iter().sum();
        assert!((total - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_leaky_relu() {
        // Negative inputs are scaled by the configured slope.
        let act = LeakyReLU::with_slope(0.1);
        let x = Variable::new(Tensor::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap(), false);
        let y = act.forward(&x);
        assert_eq!(y.data().to_vec(), vec![-0.1, 0.0, 1.0]);
    }

    #[test]
    fn test_identity() {
        // Identity must return the data unchanged.
        let act = Identity::new();
        let x = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let y = act.forward(&x);
        assert_eq!(y.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
}
444}