nabla_ml/nab_activations.rs

use crate::nab_array::NDArray;
use crate::nab_math::NabMath;

/// Namespace struct for activation functions and their gradients.
///
/// Each activation provides a `*_forward` method and a matching `*_backward`
/// method that computes the gradient used during backpropagation.
pub struct NablaActivation;

impl NablaActivation {
    /// Applies the Rectified Linear Unit (ReLU) activation function in the forward pass
    ///
    /// ReLU(x) = max(0, x)
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    ///
    /// # Returns
    ///
    /// NDArray with ReLU activation applied element-wise
    ///
    /// # Example
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_activations::NablaActivation;
    ///
    /// let x = NDArray::from_vec(vec![-1.0, 0.0, 2.0]);
    /// let output = NablaActivation::relu_forward(&x);
    /// assert_eq!(output.data(), &[0.0, 0.0, 2.0]);
    /// ```
    pub fn relu_forward(x: &NDArray) -> NDArray {
        NabMath::relu(x)
    }

    /// Computes the gradient of the ReLU activation in the backward pass
    ///
    /// ReLU'(x) = 1 if x > 0, else 0
    ///
    /// # Arguments
    ///
    /// * `gradient` - Gradient from the next layer
    /// * `x` - Original input to the ReLU function
    ///
    /// # Returns
    ///
    /// NDArray containing the gradients for backpropagation
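    ///
    /// # Example
    ///
    /// Illustrative doctest; the values mirror `test_relu_forward_backward` below.
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_activations::NablaActivation;
    ///
    /// let x = NDArray::from_vec(vec![-1.0, 0.0, 2.0]);
    /// let gradient = NDArray::from_vec(vec![1.0, 1.0, 1.0]);
    /// let grads = NablaActivation::relu_backward(&gradient, &x);
    /// // The gradient flows through only where the input was positive
    /// assert_eq!(grads.data(), &[0.0, 0.0, 1.0]);
    /// ```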
    pub fn relu_backward(gradient: &NDArray, x: &NDArray) -> NDArray {
        // ReLU derivative: 1 if x > 0, 0 otherwise
        let dx = x.map(|val| if val > 0.0 { 1.0 } else { 0.0 });
        gradient * &dx
    }

    /// Applies the Softmax activation function in the forward pass
    ///
    /// Softmax(x)_i = exp(x_i) / sum(exp(x_j))
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    /// * `axis` - Optional axis along which to apply softmax
    ///
    /// # Returns
    ///
    /// NDArray with softmax probabilities that sum to 1
    ///
    /// # Example
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_activations::NablaActivation;
    ///
    /// let x = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let output = NablaActivation::softmax_forward(&x, None);
    /// let sum: f64 = output.data().iter().sum();
    /// assert!((sum - 1.0).abs() < 1e-6);
    /// ```
    pub fn softmax_forward(x: &NDArray, axis: Option<usize>) -> NDArray {
        NabMath::softmax(x, axis)
    }

    /// Computes the gradient of the Softmax activation in the backward pass
    ///
    /// Note: for numerical stability, the actual softmax gradient computation
    /// is typically combined with the loss function gradient, so this function
    /// passes the incoming gradient through unchanged.
    ///
    /// # Arguments
    ///
    /// * `gradient` - Gradient from the loss function
    /// * `_output` - Output from the softmax forward pass (currently unused)
    ///
    /// # Returns
    ///
    /// NDArray containing the gradients for backpropagation
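    ///
    /// # Example
    ///
    /// Illustrative doctest: the incoming gradient is returned as-is.
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_activations::NablaActivation;
    ///
    /// let x = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let output = NablaActivation::softmax_forward(&x, None);
    /// let gradient = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
    /// let grads = NablaActivation::softmax_backward(&gradient, &output);
    /// assert_eq!(grads.data(), gradient.data());
    /// ```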
    pub fn softmax_backward(gradient: &NDArray, _output: &NDArray) -> NDArray {
        // The softmax derivative is handled in the loss function for numerical stability
        gradient.clone()
    }

    /// Applies the Sigmoid activation function in the forward pass
    ///
    /// sigmoid(x) = 1 / (1 + exp(-x))
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    ///
    /// # Returns
    ///
    /// NDArray with values squashed between 0 and 1
    ///
    /// # Example
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_activations::NablaActivation;
    ///
    /// let x = NDArray::from_vec(vec![-1.0, 0.0, 1.0]);
    /// let output = NablaActivation::sigmoid_forward(&x);
    /// // Values should be between 0 and 1
    /// for &val in output.data() {
    ///     assert!(val > 0.0 && val < 1.0);
    /// }
    /// ```
    pub fn sigmoid_forward(x: &NDArray) -> NDArray {
        NabMath::sigmoid(x)
    }

    /// Computes the gradient of the Sigmoid activation in the backward pass
    ///
    /// sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
    ///
    /// # Arguments
    ///
    /// * `gradient` - Gradient from the next layer
    /// * `output` - Output from the sigmoid forward pass
    ///
    /// # Returns
    ///
    /// NDArray containing the gradients for backpropagation
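    ///
    /// # Example
    ///
    /// Illustrative doctest; the checks mirror `test_sigmoid_forward_backward` below.
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_activations::NablaActivation;
    ///
    /// let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
    /// let output = NablaActivation::sigmoid_forward(&x);
    /// let gradient = NDArray::from_vec(vec![1.0, 1.0, 1.0]);
    /// let grads = NablaActivation::sigmoid_backward(&gradient, &output);
    /// // The sigmoid gradient is largest at x = 0, where sigmoid'(0) = 0.25
    /// assert!(grads.data()[1] > grads.data()[0]);
    /// assert!(grads.data()[1] > grads.data()[2]);
    /// ```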
    pub fn sigmoid_backward(gradient: &NDArray, output: &NDArray) -> NDArray {
        // sigmoid'(x) = output * (1 - output), computed here as output * -(output - 1)
        let sigmoid_derivative = output * &(output.scalar_sub(1.0).multiply_scalar(-1.0));
        gradient * &sigmoid_derivative
    }

    /// Applies the Leaky ReLU activation function in the forward pass
    ///
    /// leaky_relu(x) = x if x > 0, else alpha * x
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    /// * `alpha` - Slope for negative values (default: 0.01)
    ///
    /// # Returns
    ///
    /// NDArray with Leaky ReLU activation applied element-wise
    ///
    /// # Example
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_activations::NablaActivation;
    ///
    /// let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
    /// let output = NablaActivation::leaky_relu_forward(&x, Some(0.1));
    /// // Negative values are scaled by alpha
    /// assert_eq!(output.data()[0], -0.2);
    /// // Positive values remain unchanged
    /// assert_eq!(output.data()[2], 2.0);
    /// ```
    pub fn leaky_relu_forward(x: &NDArray, alpha: Option<f64>) -> NDArray {
        NabMath::leaky_relu(x, alpha)
    }

    /// Computes the gradient of the Leaky ReLU activation in the backward pass
    ///
    /// leaky_relu'(x) = 1 if x >= 0, else alpha
    ///
    /// # Arguments
    ///
    /// * `gradient` - Gradient from the next layer
    /// * `x` - Original input to the Leaky ReLU function
    /// * `alpha` - Slope for negative values (default: 0.01)
    ///
    /// # Returns
    ///
    /// NDArray containing the gradients for backpropagation
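    ///
    /// # Example
    ///
    /// Illustrative doctest; the values mirror `test_leaky_relu_forward_backward` below.
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_activations::NablaActivation;
    ///
    /// let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
    /// let gradient = NDArray::from_vec(vec![1.0, 1.0, 1.0]);
    /// let grads = NablaActivation::leaky_relu_backward(&gradient, &x, Some(0.1));
    /// // Negative inputs scale the gradient by alpha; non-negative inputs pass it through
    /// assert_eq!(grads.data(), &[0.1, 1.0, 1.0]);
    /// ```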
    pub fn leaky_relu_backward(gradient: &NDArray, x: &NDArray, alpha: Option<f64>) -> NDArray {
        let alpha = alpha.unwrap_or(0.01);
        // Leaky ReLU derivative: 1 for non-negative inputs, alpha otherwise
        let dx = x.map(|val| if val >= 0.0 { 1.0 } else { alpha });
        gradient * &dx
    }

    /// Applies the Hyperbolic Tangent (tanh) activation function in the forward pass
    ///
    /// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    ///
    /// # Returns
    ///
    /// NDArray with values squashed between -1 and 1
    ///
    /// # Example
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_activations::NablaActivation;
    ///
    /// let x = NDArray::from_vec(vec![-1.0, 0.0, 1.0]);
    /// let output = NablaActivation::tanh_forward(&x);
    /// // Values should be between -1 and 1
    /// for &val in output.data() {
    ///     assert!(val >= -1.0 && val <= 1.0);
    /// }
    /// ```
    pub fn tanh_forward(x: &NDArray) -> NDArray {
        NabMath::tanh(x)
    }

    /// Computes the gradient of the tanh activation in the backward pass
    ///
    /// tanh'(x) = 1 - tanh²(x)
    ///
    /// # Arguments
    ///
    /// * `gradient` - Gradient from the next layer
    /// * `output` - Output from the tanh forward pass
    ///
    /// # Returns
    ///
    /// NDArray containing the gradients for backpropagation
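    ///
    /// # Example
    ///
    /// Illustrative doctest; the check mirrors `test_tanh_forward_backward` below.
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_activations::NablaActivation;
    ///
    /// let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
    /// let output = NablaActivation::tanh_forward(&x);
    /// let gradient = NDArray::from_vec(vec![1.0, 1.0, 1.0]);
    /// let grads = NablaActivation::tanh_backward(&gradient, &output);
    /// // The tanh gradient peaks at x = 0, where tanh'(0) = 1
    /// assert!((grads.data()[1] - 1.0).abs() < 1e-6);
    /// ```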
    pub fn tanh_backward(gradient: &NDArray, output: &NDArray) -> NDArray {
        let tanh_derivative = output.multiply(output)  // tanh²(x)
            .scalar_sub(1.0)                           // tanh²(x) - 1
            .multiply_scalar(-1.0);                    // 1 - tanh²(x)
        gradient * &tanh_derivative
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_forward_backward() {
        // Test forward pass with mixed positive/negative values
        let x = NDArray::from_vec(vec![-1.0, 0.0, 2.0]);
        let forward = NablaActivation::relu_forward(&x);
        // Verify ReLU zeros out negative values and keeps positive values
        assert_eq!(forward.data(), &[0.0, 0.0, 2.0]);

        // Test backward pass with uniform gradient
        let gradient = NDArray::from_vec(vec![1.0, 1.0, 1.0]);
        let backward = NablaActivation::relu_backward(&gradient, &x);
        // Verify gradient is zero for negative inputs and unchanged for positive inputs
        assert_eq!(backward.data(), &[0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_softmax_forward_backward() {
        // Test forward pass with increasing values
        let x = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let forward = NablaActivation::softmax_forward(&x, None);

        // Verify softmax output sums to 1 (probability distribution)
        let sum: f64 = forward.data().iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // Verify softmax preserves the relative ordering of the inputs (monotonicity)
        let mut prev = 0.0;
        for &val in forward.data() {
            assert!(val >= prev);
            prev = val;
        }
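
        // Test backward pass: the gradient is passed through unchanged,
        // since the softmax derivative is folded into the loss gradient
        let gradient = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
        let backward = NablaActivation::softmax_backward(&gradient, &forward);
        assert_eq!(backward.data(), gradient.data());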
    }

    #[test]
    fn test_sigmoid_forward_backward() {
        // Test forward pass with various inputs
        let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
        let forward = NablaActivation::sigmoid_forward(&x);

        // Verify sigmoid output is between 0 and 1
        for &val in forward.data() {
            assert!(val > 0.0 && val < 1.0);
        }

        // Verify sigmoid(0) ≈ 0.5
        assert!((forward.data()[1] - 0.5).abs() < 1e-6);

        // Test backward pass
        let gradient = NDArray::from_vec(vec![1.0, 1.0, 1.0]);
        let backward = NablaActivation::sigmoid_backward(&gradient, &forward);

        // Verify gradient shape matches input
        assert_eq!(backward.shape(), x.shape());

        // Verify gradient is maximum at x = 0 (where sigmoid'(0) = 0.25)
        assert!(backward.data()[1] > backward.data()[0]);
        assert!(backward.data()[1] > backward.data()[2]);
    }

    #[test]
    fn test_leaky_relu_forward_backward() {
        // Test forward pass with default alpha
        let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
        let forward = NablaActivation::leaky_relu_forward(&x, None);

        // Verify positive values remain unchanged
        assert_eq!(forward.data()[2], 2.0);
        // Verify negative values are scaled by default alpha (0.01)
        assert_eq!(forward.data()[0], -0.02);
        // Verify zero remains unchanged
        assert_eq!(forward.data()[1], 0.0);

        // Test forward pass with custom alpha
        let forward_custom = NablaActivation::leaky_relu_forward(&x, Some(0.1));
        // Verify negative values are scaled by custom alpha
        assert_eq!(forward_custom.data()[0], -0.2);

        // Test backward pass
        let gradient = NDArray::from_vec(vec![1.0, 1.0, 1.0]);
        let backward = NablaActivation::leaky_relu_backward(&gradient, &x, Some(0.1));

        // Verify gradient for positive values is unchanged
        assert_eq!(backward.data()[2], 1.0);
        // Verify gradient for negative values is scaled by alpha
        assert_eq!(backward.data()[0], 0.1);
        // Verify gradient at zero is 1 (positive side of derivative)
        assert_eq!(backward.data()[1], 1.0);
    }

    #[test]
    fn test_tanh_forward_backward() {
        // Test forward pass with various inputs
        let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
        let forward = NablaActivation::tanh_forward(&x);

        // Verify tanh output is between -1 and 1
        for &val in forward.data() {
            assert!(val >= -1.0 && val <= 1.0);
        }

        // Verify tanh(0) = 0
        assert!(forward.data()[1].abs() < 1e-6);

        // Test backward pass
        let gradient = NDArray::from_vec(vec![1.0, 1.0, 1.0]);
        let backward = NablaActivation::tanh_backward(&gradient, &forward);

        // Verify gradient shape matches input
        assert_eq!(backward.shape(), x.shape());

        // Verify gradient is maximum at x = 0 (where tanh'(0) = 1)
        assert!(backward.data()[1] > backward.data()[0]);
        assert!(backward.data()[1] > backward.data()[2]);

        // Verify gradient at x = 0 is close to 1
        assert!((backward.data()[1] - 1.0).abs() < 1e-6);
    }
}