axonml_nn/
activation.rs

//! Activation Modules - Non-linear Activation Functions
//!
//! Provides activation functions as modules for use in Sequential and other containers.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::module::Module;

// =============================================================================
// ReLU
// =============================================================================

/// Applies the rectified linear unit function element-wise.
///
/// ReLU(x) = max(0, x)
#[derive(Debug, Clone, Copy, Default)]
pub struct ReLU;

impl ReLU {
    /// Creates a new ReLU activation.
    pub fn new() -> Self {
        Self
    }
}

impl Module for ReLU {
    fn forward(&self, input: &Variable) -> Variable {
        input.relu()
    }

    fn name(&self) -> &'static str {
        "ReLU"
    }
}

// =============================================================================
// LeakyReLU
// =============================================================================

/// Applies the leaky ReLU function element-wise.
///
/// LeakyReLU(x) = max(0, x) + negative_slope * min(0, x)
#[derive(Debug, Clone, Copy)]
pub struct LeakyReLU {
    negative_slope: f32,
}

impl LeakyReLU {
    /// Creates a new LeakyReLU with default negative slope (0.01).
    pub fn new() -> Self {
        Self {
            negative_slope: 0.01,
        }
    }

    /// Creates a LeakyReLU with custom negative slope.
    pub fn with_slope(negative_slope: f32) -> Self {
        Self { negative_slope }
    }
}

impl Default for LeakyReLU {
    fn default() -> Self {
        Self::new()
    }
}

impl Module for LeakyReLU {
    fn forward(&self, input: &Variable) -> Variable {
        let data = input.data();
        let result: Vec<f32> = data
            .to_vec()
            .iter()
            .map(|&x| if x > 0.0 { x } else { x * self.negative_slope })
            .collect();
        Variable::new(
            Tensor::from_vec(result, data.shape()).unwrap(),
            input.requires_grad(),
        )
    }

    fn name(&self) -> &'static str {
        "LeakyReLU"
    }
}

// =============================================================================
// Sigmoid
// =============================================================================

/// Applies the sigmoid function element-wise.
///
/// Sigmoid(x) = 1 / (1 + exp(-x))
#[derive(Debug, Clone, Copy, Default)]
pub struct Sigmoid;

impl Sigmoid {
    /// Creates a new Sigmoid activation.
    pub fn new() -> Self {
        Self
    }
}

impl Module for Sigmoid {
    fn forward(&self, input: &Variable) -> Variable {
        input.sigmoid()
    }

    fn name(&self) -> &'static str {
        "Sigmoid"
    }
}

// =============================================================================
// Tanh
// =============================================================================

/// Applies the hyperbolic tangent function element-wise.
///
/// Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
#[derive(Debug, Clone, Copy, Default)]
pub struct Tanh;

impl Tanh {
    /// Creates a new Tanh activation.
    pub fn new() -> Self {
        Self
    }
}

impl Module for Tanh {
    fn forward(&self, input: &Variable) -> Variable {
        input.tanh()
    }

    fn name(&self) -> &'static str {
        "Tanh"
    }
}

// =============================================================================
// Softmax
// =============================================================================

/// Applies the softmax function along a dimension.
///
/// Softmax(x_i) = exp(x_i) / sum(exp(x_j))
#[derive(Debug, Clone, Copy)]
pub struct Softmax {
    dim: i64,
}

impl Softmax {
    /// Creates a new Softmax along the specified dimension.
    pub fn new(dim: i64) -> Self {
        Self { dim }
    }
}

impl Default for Softmax {
    fn default() -> Self {
        Self::new(-1)
    }
}

impl Module for Softmax {
    fn forward(&self, input: &Variable) -> Variable {
        // Resolve the (possibly negative) dimension, then apply a numerically
        // stable softmax over every slice along that dimension.
        let data = input.data();
        let shape = data.shape().to_vec();
        let data_vec = data.to_vec();

        let ndim = shape.len();
        let dim = if self.dim < 0 {
            (ndim as i64 + self.dim) as usize
        } else {
            self.dim as usize
        };

        let outer_size: usize = shape[..dim].iter().product();
        let dim_size = shape[dim];
        let inner_size: usize = shape[dim + 1..].iter().product();

        let mut result = vec![0.0f32; data_vec.len()];

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                // Find max for numerical stability
                let mut max_val = f32::NEG_INFINITY;
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    max_val = max_val.max(data_vec[idx]);
                }

                // Compute exp and sum
                let mut sum = 0.0f32;
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    let exp_val = (data_vec[idx] - max_val).exp();
                    result[idx] = exp_val;
                    sum += exp_val;
                }

                // Normalize
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    result[idx] /= sum;
                }
            }
        }

        Variable::new(
            Tensor::from_vec(result, &shape).unwrap(),
            input.requires_grad(),
        )
    }

    fn name(&self) -> &'static str {
        "Softmax"
    }
}

// =============================================================================
// LogSoftmax
// =============================================================================

/// Applies log(softmax(x)) along a dimension.
#[derive(Debug, Clone, Copy)]
pub struct LogSoftmax {
    dim: i64,
}

impl LogSoftmax {
    /// Creates a new LogSoftmax along the specified dimension.
    pub fn new(dim: i64) -> Self {
        Self { dim }
    }
}

impl Default for LogSoftmax {
    fn default() -> Self {
        Self::new(-1)
    }
}

impl Module for LogSoftmax {
    fn forward(&self, input: &Variable) -> Variable {
        let softmax = Softmax::new(self.dim);
        let sm = softmax.forward(input);
        let sm_vec = sm.data().to_vec();
        let result: Vec<f32> = sm_vec.iter().map(|&x| x.ln()).collect();
        Variable::new(
            Tensor::from_vec(result, sm.data().shape()).unwrap(),
            input.requires_grad(),
        )
    }

    fn name(&self) -> &'static str {
        "LogSoftmax"
    }
}

// =============================================================================
// GELU
// =============================================================================

/// Applies the Gaussian Error Linear Unit function.
///
/// GELU(x) = x * Phi(x) where Phi is the CDF of the standard normal distribution.
#[derive(Debug, Clone, Copy, Default)]
pub struct GELU;

impl GELU {
    /// Creates a new GELU activation.
    pub fn new() -> Self {
        Self
    }
}

impl Module for GELU {
    fn forward(&self, input: &Variable) -> Variable {
        let data = input.data();
        let data_vec = data.to_vec();
        // Tanh approximation: GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
        let sqrt_2_over_pi = (2.0_f32 / std::f32::consts::PI).sqrt();
        let result: Vec<f32> = data_vec
            .iter()
            .map(|&x| {
                let inner = sqrt_2_over_pi * (x + 0.044715 * x.powi(3));
                0.5 * x * (1.0 + inner.tanh())
            })
            .collect();
        Variable::new(
            Tensor::from_vec(result, data.shape()).unwrap(),
            input.requires_grad(),
        )
    }

    fn name(&self) -> &'static str {
        "GELU"
    }
}

// =============================================================================
// SiLU / Swish
// =============================================================================

/// Applies the SiLU (Swish) function element-wise.
///
/// SiLU(x) = x * sigmoid(x)
#[derive(Debug, Clone, Copy, Default)]
pub struct SiLU;

impl SiLU {
    /// Creates a new SiLU activation.
    pub fn new() -> Self {
        Self
    }
}

impl Module for SiLU {
    fn forward(&self, input: &Variable) -> Variable {
        let sigmoid = input.sigmoid();
        input.mul_var(&sigmoid)
    }

    fn name(&self) -> &'static str {
        "SiLU"
    }
}

// =============================================================================
// ELU
// =============================================================================

/// Applies the Exponential Linear Unit function.
///
/// ELU(x) = x if x > 0, else alpha * (exp(x) - 1)
#[derive(Debug, Clone, Copy)]
pub struct ELU {
    alpha: f32,
}

impl ELU {
    /// Creates a new ELU with default alpha (1.0).
    pub fn new() -> Self {
        Self { alpha: 1.0 }
    }

    /// Creates an ELU with custom alpha.
    pub fn with_alpha(alpha: f32) -> Self {
        Self { alpha }
    }
}

impl Default for ELU {
    fn default() -> Self {
        Self::new()
    }
}

impl Module for ELU {
    fn forward(&self, input: &Variable) -> Variable {
        let data = input.data();
        let result: Vec<f32> = data
            .to_vec()
            .iter()
            .map(|&x| {
                if x > 0.0 {
                    x
                } else {
                    self.alpha * (x.exp() - 1.0)
                }
            })
            .collect();
        Variable::new(
            Tensor::from_vec(result, data.shape()).unwrap(),
            input.requires_grad(),
        )
    }

    fn name(&self) -> &'static str {
        "ELU"
    }
}

// =============================================================================
// Identity
// =============================================================================

/// Identity activation (no-op).
#[derive(Debug, Clone, Copy, Default)]
pub struct Identity;

impl Identity {
    /// Creates a new Identity activation.
    pub fn new() -> Self {
        Self
    }
}

impl Module for Identity {
    fn forward(&self, input: &Variable) -> Variable {
        input.clone()
    }

    fn name(&self) -> &'static str {
        "Identity"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu() {
        let relu = ReLU::new();
        let input = Variable::new(
            Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap(),
            false,
        );
        let output = relu.forward(&input);
        assert_eq!(output.data().to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        let sigmoid = Sigmoid::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let output = sigmoid.forward(&input);
        assert!((output.data().to_vec()[0] - 0.5).abs() < 1e-6);
    }
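
    // Illustrative extra check (assumes the underlying Variable::tanh matches
    // f32 tanh behavior): tanh(0) is 0 and outputs stay inside (-1, 1).
    #[test]
    fn test_tanh() {
        let tanh = Tanh::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0, 3.0], &[2]).unwrap(), false);
        let output = tanh.forward(&input);
        let out = output.data().to_vec();
        assert!(out[0].abs() < 1e-6);
        assert!(out[1] > 0.0 && out[1] < 1.0);
    }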

    #[test]
    fn test_softmax() {
        let softmax = Softmax::new(-1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let output = softmax.forward(&input);
        let sum: f32 = output.data().to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }
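
    // Additional sketches (not exhaustive): exercise Softmax along a leading
    // dimension and check that LogSoftmax outputs exponentiate back to a
    // probability distribution. Both reuse the Variable/Tensor helpers from
    // the tests above.
    #[test]
    fn test_softmax_dim0() {
        let softmax = Softmax::new(0);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let output = softmax.forward(&input);
        let out = output.data().to_vec();
        // Each slice along dim 0 (i.e., each column) should sum to 1.
        assert!((out[0] + out[2] - 1.0).abs() < 1e-5);
        assert!((out[1] + out[3] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_log_softmax() {
        let log_softmax = LogSoftmax::new(-1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let output = log_softmax.forward(&input);
        // exp(log_softmax(x)) is softmax(x), so it should sum to 1.
        let sum: f32 = output.data().to_vec().iter().map(|&x| x.exp()).sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }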

    #[test]
    fn test_leaky_relu() {
        let leaky = LeakyReLU::with_slope(0.1);
        let input = Variable::new(Tensor::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap(), false);
        let output = leaky.forward(&input);
        assert_eq!(output.data().to_vec(), vec![-0.1, 0.0, 1.0]);
    }
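
    // Illustrative ELU check with the default alpha = 1.0: negative inputs map
    // to exp(x) - 1, zero stays at zero, positives pass through unchanged.
    #[test]
    fn test_elu() {
        let elu = ELU::new();
        let input = Variable::new(Tensor::from_vec(vec![-1.0, 0.0, 2.0], &[3]).unwrap(), false);
        let output = elu.forward(&input);
        let out = output.data().to_vec();
        assert!((out[0] - ((-1.0f32).exp() - 1.0)).abs() < 1e-6);
        assert!(out[1].abs() < 1e-6);
        assert!((out[2] - 2.0).abs() < 1e-6);
    }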

    #[test]
    fn test_identity() {
        let id = Identity::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = id.forward(&input);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
470}