sklears_multioutput/
activation.rs

//! Activation functions for neural networks
//!
//! This module provides various activation functions commonly used in neural networks,
//! including ReLU, Sigmoid, Tanh, Linear, and Softmax activations.
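//!
//! # Examples
//!
//! A minimal usage sketch; the crate and module path in the `use` line are
//! assumed from the file layout rather than taken from the crate's documentation.
//!
//! ```ignore
//! use sklears_multioutput::activation::ActivationFunction;
//! use scirs2_core::ndarray::Array1;
//!
//! let logits = Array1::from(vec![1.0, 2.0, 3.0]);
//! let probs = ActivationFunction::Softmax.apply(&logits);
//! assert!((probs.sum() - 1.0).abs() < 1e-6);
//! ```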

// Use SciRS2-Core for arrays (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::types::Float;

/// Activation functions for neural networks
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ActivationFunction {
    /// ReLU activation: max(0, x)
    ReLU,
    /// Sigmoid activation: 1 / (1 + exp(-x))
    Sigmoid,
    /// Tanh activation: (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    Tanh,
    /// Linear activation: x (identity function)
    Linear,
    /// Softmax activation for multi-class outputs
    Softmax,
}

impl ActivationFunction {
    /// Apply the activation function to a 1D array (element-wise; Softmax normalizes over the whole vector)
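    ///
    /// # Examples
    ///
    /// A minimal sketch; the crate path in the `use` line is assumed from the
    /// file layout.
    ///
    /// ```ignore
    /// use sklears_multioutput::activation::ActivationFunction;
    /// use scirs2_core::ndarray::Array1;
    ///
    /// let x = Array1::from(vec![-1.0, 0.0, 2.0]);
    /// let y = ActivationFunction::ReLU.apply(&x);
    /// assert_eq!(y, Array1::from(vec![0.0, 0.0, 2.0]));
    /// ```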
    pub fn apply(&self, x: &Array1<Float>) -> Array1<Float> {
        match self {
            ActivationFunction::ReLU => x.map(|&val| val.max(0.0)),
            ActivationFunction::Sigmoid => x.map(|&val| 1.0 / (1.0 + (-val).exp())),
            ActivationFunction::Tanh => x.map(|&val| val.tanh()),
            ActivationFunction::Linear => x.clone(),
            ActivationFunction::Softmax => {
                // Shift by the maximum value for numerical stability before exponentiating.
                let max_val = x.iter().fold(Float::NEG_INFINITY, |a, &b| a.max(b));
                let shifted = x.map(|&val| val - max_val);
                let exp_vals = shifted.map(|&val| val.exp());
                let sum_exp = exp_vals.sum();
                exp_vals.map(|&val| val / sum_exp)
            }
        }
    }

    /// Apply the activation function to a 2D array, one row at a time
    pub fn apply_2d(&self, x: &Array2<Float>) -> Array2<Float> {
        // Softmax normalizes over each row; the remaining activations are
        // element-wise, so applying them row by row gives the same result
        // without building a one-element array per value.
        let mut result = Array2::<Float>::zeros(x.dim());
        for (i, row) in x.axis_iter(Axis(0)).enumerate() {
            result.row_mut(i).assign(&self.apply(&row.to_owned()));
        }
        result
    }

    /// Compute the derivative of the activation function
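    ///
    /// For `Sigmoid` and `Tanh` the derivative is evaluated through the activation
    /// value itself: sigma'(x) = sigma(x) * (1 - sigma(x)) and tanh'(x) = 1 - tanh(x)^2.
    /// For `Softmax` a vector of ones is returned as a simplification, since the full
    /// Jacobian is normally folded into the loss gradient (e.g. softmax + cross-entropy).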
    pub fn derivative(&self, x: &Array1<Float>) -> Array1<Float> {
        match self {
            ActivationFunction::ReLU => x.map(|&val| if val > 0.0 { 1.0 } else { 0.0 }),
            ActivationFunction::Sigmoid => {
                let sigmoid_vals = self.apply(x);
                sigmoid_vals.map(|&val| val * (1.0 - val))
            }
            ActivationFunction::Tanh => {
                let tanh_vals = self.apply(x);
                tanh_vals.map(|&val| 1.0 - val * val)
            }
            ActivationFunction::Linear => Array1::ones(x.len()),
            ActivationFunction::Softmax => {
                // The full softmax Jacobian is rarely needed in practice; with a
                // cross-entropy loss the gradient simplifies, so ones serve as a placeholder.
                Array1::ones(x.len())
            }
        }
    }
}
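
// Hedged example tests (a sketch added for illustration, not part of the original
// module): they exercise `apply`, `apply_2d`, and `derivative`, assuming only that
// `Float` is a standard floating-point alias.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2, Axis};

    #[test]
    fn relu_clamps_negatives() {
        let x = Array1::from(vec![-2.0, 0.0, 3.0]);
        let y = ActivationFunction::ReLU.apply(&x);
        assert_eq!(y, Array1::from(vec![0.0, 0.0, 3.0]));
    }

    #[test]
    fn softmax_rows_sum_to_one() {
        let x = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, 0.0, 1.0]).unwrap();
        let y = ActivationFunction::Softmax.apply_2d(&x);
        for row in y.axis_iter(Axis(0)) {
            assert!((row.sum() - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn sigmoid_derivative_matches_closed_form() {
        let x = Array1::from(vec![-1.0, 0.0, 1.0]);
        let s = ActivationFunction::Sigmoid.apply(&x);
        let d = ActivationFunction::Sigmoid.derivative(&x);
        for (&si, &di) in s.iter().zip(d.iter()) {
            // derivative() computes sigma(x) * (1 - sigma(x)); check it against
            // the same closed form evaluated from apply().
            assert!((di - si * (1.0 - si)).abs() < 1e-6);
        }
    }
}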