nabla_ml/
nab_math.rs

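//! Mathematical functions for NDArray operations.
//!
//! This module provides NumPy-style element-wise operations (`sqrt`, `exp`,
//! `sin`, `cos`, `ln`) as methods on `NDArray`, plus neural-network activation
//! functions and their derivatives as associated functions on `NabMath`.
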
6use crate::nab_array::NDArray;
7
/// Namespace for neural-network activation functions (and their derivatives)
/// that operate on [`NDArray`].
///
/// `NabMath` carries no state; its functions are called as associated
/// functions, e.g. `NabMath::sigmoid(&x)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NabMath;

11impl NDArray {
12    /// Calculates the square root of each element in the array
13    ///
14    /// # Returns
15    ///
16    /// A new NDArray with the square root of each element.
17    #[allow(dead_code)]
18    pub fn sqrt(&self) -> Self {
19        let data = self.data().iter().map(|x| x.sqrt()).collect();
20        NDArray::new(data, self.shape().to_vec())
21    }
22
23    /// Calculates the exponential (e^x) of each element in the array
24    ///
25    /// # Returns
26    ///
27    /// A new NDArray with the exponential of each element.
28    #[allow(dead_code)]
29    pub fn exp(&self) -> Self {
30        let data = self.data().iter().map(|x| x.exp()).collect();
31        NDArray::new(data, self.shape().to_vec())
32    }
33
34    /// Calculates the sine of each element in the array
35    ///
36    /// # Returns
37    ///
38    /// A new NDArray with the sine of each element.
39    #[allow(dead_code)]
40    pub fn sin(&self) -> Self {
41        let data: Vec<f64> = self.data().iter().map(|&x| x.sin()).collect();
42        Self::new(data, self.shape().to_vec())
43    }
44
45    /// Calculates the cosine of each element in the array
46    ///
47    /// # Returns
48    ///
49    /// A new NDArray with the cosine of each element.
50    #[allow(dead_code)]
51    pub fn cos(&self) -> Self {
52        let data: Vec<f64> = self.data().iter().map(|&x| x.cos()).collect();
53        Self::new(data, self.shape().to_vec())
54    }
55
56    /// Calculates the natural logarithm of each element in the array
57    ///
58    /// # Returns
59    ///
60    /// A new NDArray with the natural logarithm of each element.
61    #[allow(dead_code)]
62    pub fn ln(&self) -> Self {
63        let data: Vec<f64> = self.data().iter().map(|&x| x.ln()).collect();
64        Self::new(data, self.shape().to_vec())
65    }
66
67}
68
69impl NabMath {
70    /// Computes the sigmoid function element-wise
71    /// 
72    /// sigmoid(x) = 1 / (1 + exp(-x))
73    ///
74    /// # Arguments
75    ///
76    /// * `x` - Input NDArray
77    ///
78    /// # Returns
79    ///
80    /// NDArray with sigmoid applied element-wise
81    pub fn sigmoid(x: &NDArray) -> NDArray {
82        x.map(|val| 1.0 / (1.0 + (-val).exp()))
83    }
84
85    /// Computes the derivative of sigmoid function element-wise
86    ///
87    /// sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
88    ///
89    /// # Arguments
90    ///
91    /// * `x` - Input NDArray
92    ///
93    /// # Returns
94    ///
95    /// NDArray with sigmoid derivative applied element-wise
96    pub fn sigmoid_derivative(x: &NDArray) -> NDArray {
97        let sigmoid_x = Self::sigmoid(x);
98        sigmoid_x.map(|val| val * (1.0 - val))
99    }
100
101    /// Computes the hyperbolic tangent function element-wise
102    ///
103    /// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
104    ///
105    /// # Arguments
106    ///
107    /// * `x` - Input NDArray
108    ///
109    /// # Returns
110    ///
111    /// NDArray with tanh applied element-wise
112    pub fn tanh(x: &NDArray) -> NDArray {
113        x.map(|val| val.tanh())
114    }
115
116    /// Computes the derivative of tanh function element-wise
117    ///
118    /// tanh'(x) = 1 - tanh²(x)
119    ///
120    /// # Arguments
121    ///
122    /// * `x` - Input NDArray
123    ///
124    /// # Returns
125    ///
126    /// NDArray with tanh derivative applied element-wise
127    pub fn tanh_derivative(x: &NDArray) -> NDArray {
128        let tanh_x = Self::tanh(x);
129        tanh_x.map(|val| 1.0 - val * val)
130    }
131
132    /// Computes the ReLU function element-wise
133    ///
134    /// ReLU(x) = max(0, x)
135    ///
136    /// # Arguments
137    ///
138    /// * `x` - Input NDArray
139    ///
140    /// # Returns
141    ///
142    /// NDArray with ReLU applied element-wise
143    pub fn relu(x: &NDArray) -> NDArray {
144        x.map(|val| val.max(0.0))
145    }
146
147    /// Computes the derivative of ReLU function element-wise
148    ///
149    /// ReLU'(x) = 1 if x > 0, 0 otherwise
150    ///
151    /// # Arguments
152    ///
153    /// * `x` - Input NDArray
154    ///
155    /// # Returns
156    ///
157    /// NDArray with ReLU derivative applied element-wise
158    pub fn relu_derivative(x: &NDArray) -> NDArray {
159        x.map(|val| if val > 0.0 { 1.0 } else { 0.0 })
160    }
161
162    /// Computes the softmax function along the specified axis
163    ///
164    /// softmax(x) = exp(x) / sum(exp(x))
165    ///
166    /// # Arguments
167    ///
168    /// * `x` - Input NDArray
169    /// * `axis` - Axis along which to compute softmax (default: -1 for last axis)
170    ///
171    /// # Returns
172    ///
173    /// NDArray with softmax applied along specified axis
174    pub fn softmax(x: &NDArray, _axis: Option<usize>) -> NDArray {
175        assert!(x.ndim() == 1 || x.ndim() == 2, "Softmax is only defined for 1D or 2D arrays");
176
177        let exp = x.map(|val| val.exp());
178        
179        if x.ndim() == 1 {
180            // For 1D arrays
181            let sum = exp.sum();
182            exp.map(|val| val / sum)
183        } else {
184            // For 2D arrays, always compute along rows (axis=1)
185            let (rows, cols) = (x.shape()[0], x.shape()[1]);
186            let sum = exp.sum_axis(1);  // Shape: [rows, 1]
187            
188            // Create broadcasted sum array
189            let mut result_data = Vec::with_capacity(rows * cols);
190            for i in 0..rows {
191                for j in 0..cols {
192                    // Use sum[i] for each row instead of sum[0]
193                    result_data.push(exp.data()[i * cols + j] / sum.data()[i]);
194                }
195            }
196            
197            NDArray::new(result_data, x.shape().to_vec())
198        }
199    }
200
201    /// Computes the derivative of softmax function
202    ///
203    /// # Arguments
204    ///
205    /// * `x` - Input NDArray (softmax output)
206    ///
207    /// # Returns
208    ///
209    /// NDArray with softmax derivative
210    pub fn softmax_derivative(x: &NDArray) -> NDArray {
211        x.map(|val| val * (1.0 - val))
212    }
213
214    /// Computes the Leaky ReLU function element-wise
215    ///
216    /// LeakyReLU(x) = max(alpha * x, x)
217    ///
218    /// # Arguments
219    ///
220    /// * `x` - Input NDArray
221    /// * `alpha` - Slope for negative values (default: 0.01)
222    ///
223    /// # Returns
224    ///
225    /// NDArray with Leaky ReLU applied element-wise
226    pub fn leaky_relu(x: &NDArray, alpha: Option<f64>) -> NDArray {
227        let alpha = alpha.unwrap_or(0.01);
228        x.map(|val| if val > 0.0 { val } else { alpha * val })
229    }
230
231    /// Computes the derivative of Leaky ReLU function
232    ///
233    /// # Arguments
234    ///
235    /// * `x` - Input NDArray
236    /// * `alpha` - Slope for negative values (default: 0.01)
237    ///
238    /// # Returns
239    ///
240    /// NDArray with Leaky ReLU derivative
241    pub fn leaky_relu_derivative(x: &NDArray, alpha: Option<f64>) -> NDArray {
242        let alpha = alpha.unwrap_or(0.01);
243        x.map(|val| if val > 0.0 { 1.0 } else { alpha })
244    }
245
246    /// Computes the ELU (Exponential Linear Unit) function
247    ///
248    /// ELU(x) = x if x > 0, alpha * (exp(x) - 1) if x <= 0
249    ///
250    /// # Arguments
251    ///
252    /// * `x` - Input NDArray
253    /// * `alpha` - Scale for negative values (default: 1.0)
254    ///
255    /// # Returns
256    ///
257    /// NDArray with ELU applied element-wise
258    pub fn elu(x: &NDArray, alpha: Option<f64>) -> NDArray {
259        let alpha = alpha.unwrap_or(1.0);
260        x.map(|val| if val > 0.0 { val } else { alpha * (val.exp() - 1.0) })
261    }
262
263    /// Computes the derivative of ELU function
264    ///
265    /// # Arguments
266    ///
267    /// * `x` - Input NDArray
268    /// * `alpha` - Scale for negative values (default: 1.0)
269    ///
270    /// # Returns
271    ///
272    /// NDArray with ELU derivative
273    pub fn elu_derivative(x: &NDArray, alpha: Option<f64>) -> NDArray {
274        let alpha = alpha.unwrap_or(1.0);
275        x.map(|val| if val > 0.0 { 1.0 } else { alpha * val.exp() })
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{E, FRAC_PI_2, PI};

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (tolerance {}), got {}",
            expected,
            tol,
            actual
        );
    }

    /// Asserts that two slices have equal length and are element-wise within `tol`.
    fn assert_all_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (&a, &e) in actual.iter().zip(expected) {
            assert_close(a, e, tol);
        }
    }

    #[test]
    fn test_sqrt() {
        let arr = NDArray::from_vec(vec![1.0, 4.0, 9.0]);
        let sqrt_arr = arr.sqrt();
        assert_eq!(sqrt_arr.data(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_exp() {
        let arr = NDArray::from_vec(vec![0.0, 1.0, 2.0]);
        let exp_arr = arr.exp();
        assert_all_close(exp_arr.data(), &[1.0, E, E.powi(2)], 1e-4);
    }

    /// Tests sine at angles with exactly known values.
    #[test]
    fn test_sin() {
        let arr = NDArray::from_vec(vec![0.0, FRAC_PI_2, PI]);
        assert_all_close(arr.sin().data(), &[0.0, 1.0, 0.0], 1e-12);
    }

    /// Tests cosine at angles with exactly known values.
    #[test]
    fn test_cos() {
        let arr = NDArray::from_vec(vec![0.0, FRAC_PI_2, PI]);
        assert_all_close(arr.cos().data(), &[1.0, 0.0, -1.0], 1e-12);
    }

    /// Tests the natural logarithm, including its behavior at and below zero.
    #[test]
    fn test_ln() {
        let arr = NDArray::from_vec(vec![1.0, E, E * E]);
        assert_all_close(arr.ln().data(), &[0.0, 1.0, 2.0], 1e-12);

        let edge = NDArray::from_vec(vec![0.0, -1.0]).ln();
        assert_eq!(edge.data()[0], f64::NEG_INFINITY);
        assert!(edge.data()[1].is_nan());
    }

    /// Element-wise methods must keep the input's shape.
    #[test]
    fn test_elementwise_preserves_shape() {
        let m = NDArray::from_matrix(vec![vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0]]);
        assert_eq!(m.exp().shape(), &[2, 3]);
        assert_eq!(m.sqrt().shape(), &[2, 3]);
        assert_eq!(m.sin().shape(), &[2, 3]);
    }

    /// Tests sigmoid function computation
    #[test]
    fn test_sigmoid() {
        let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
        let result = NabMath::sigmoid(&x);

        // Output range is (0, 1) for moderate inputs
        for &val in result.data() {
            assert!(val > 0.0 && val < 1.0);
        }

        // sigmoid(0) = 0.5
        assert_close(result.data()[1], 0.5, 1e-6);

        // Symmetry: sigmoid(-x) = 1 - sigmoid(x)
        assert_close(result.data()[0], 1.0 - result.data()[2], 1e-6);
    }

    /// Tests sigmoid derivative computation
    #[test]
    fn test_sigmoid_derivative() {
        let x = NDArray::from_vec(vec![-1.0, 0.0, 1.0]);
        let result = NabMath::sigmoid_derivative(&x);
        assert_all_close(result.data(), &[0.1966, 0.2500, 0.1966], 1e-4);
    }

    /// Tests tanh function computation
    #[test]
    fn test_tanh() {
        let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
        let result = NabMath::tanh(&x);

        // Output range is [-1, 1]
        for &val in result.data() {
            assert!((-1.0..=1.0).contains(&val));
        }

        // tanh(0) = 0
        assert_close(result.data()[1], 0.0, 1e-6);

        // Odd symmetry: tanh(-x) = -tanh(x)
        assert_close(result.data()[0], -result.data()[2], 1e-6);
    }

    /// Tests tanh derivative computation
    #[test]
    fn test_tanh_derivative() {
        let x = NDArray::from_vec(vec![-1.0, 0.0, 1.0]);
        let result = NabMath::tanh_derivative(&x);
        assert_all_close(result.data(), &[0.4199, 1.0000, 0.4199], 1e-4);
    }

    /// Tests ReLU function computation
    #[test]
    fn test_relu() {
        let x = NDArray::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);
        let result = NabMath::relu(&x);
        // Negatives and zero become zero; positives are unchanged.
        assert_eq!(result.data(), &[0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    /// Tests ReLU derivative computation
    #[test]
    fn test_relu_derivative() {
        let x = NDArray::from_vec(vec![-1.0, 0.0, 1.0]);
        let result = NabMath::relu_derivative(&x);
        assert_eq!(result.data(), &[0.0, 0.0, 1.0]);
    }

    /// Tests softmax computation on different dimensions
    #[test]
    fn test_softmax() {
        // 1D: compare with reference values of softmax([1, 2, 3]).
        let x = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = NabMath::softmax(&x, None);
        assert_all_close(
            result.data(),
            &[0.09003057317038046, 0.24472847105479767, 0.6652409557748219],
            1e-9,
        );

        let sum: f64 = result.data().iter().sum();
        assert_close(sum, 1.0, 1e-6);

        // Monotonicity: larger inputs -> larger probabilities
        for pair in result.data().windows(2) {
            assert!(pair[1] > pair[0]);
        }

        // 2D: each row is normalised independently.
        let x = NDArray::from_matrix(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
        let result = NabMath::softmax(&x, Some(1));
        assert_eq!(result.shape(), &[2, 3]);
        for row in result.data().chunks(3) {
            let row_sum: f64 = row.iter().sum();
            assert_close(row_sum, 1.0, 1e-6);
        }

        // Softmax is shift-invariant, so row [4, 5, 6] equals row [1, 2, 3].
        assert_all_close(&result.data()[3..6], &result.data()[0..3], 1e-9);
    }

    /// Tests softmax derivative (Jacobian diagonal s * (1 - s))
    #[test]
    fn test_softmax_derivative() {
        let x = NDArray::from_vec(vec![0.1, 0.7, 0.2]);
        let result = NabMath::softmax_derivative(&x);
        assert_eq!(result.shape(), &[3]);
        assert_all_close(result.data(), &[0.09, 0.21, 0.16], 1e-12);
        for &val in result.data() {
            // s * (1 - s) peaks at 0.25 when s = 0.5
            assert!((0.0..=0.25).contains(&val));
        }
    }

    /// Tests Leaky ReLU computation with different alphas
    #[test]
    fn test_leaky_relu() {
        let x = NDArray::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);

        // Default alpha = 0.01
        let result = NabMath::leaky_relu(&x, None);
        assert_eq!(result.data()[3], 1.0); // Positive values unchanged
        assert_eq!(result.data()[4], 2.0);
        assert_eq!(result.data()[0], -0.02); // Negative values scaled by 0.01
        assert_eq!(result.data()[1], -0.01);
        assert_eq!(result.data()[2], 0.0); // Zero unchanged

        // Custom alpha
        let result = NabMath::leaky_relu(&x, Some(0.1));
        assert_eq!(result.data()[3], 1.0); // Positive values unchanged
        assert_eq!(result.data()[0], -0.2); // Negative values scaled by 0.1
    }

    /// Tests Leaky ReLU derivative with default and custom alpha
    #[test]
    fn test_leaky_relu_derivative() {
        let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
        assert_eq!(NabMath::leaky_relu_derivative(&x, None).data(), &[0.01, 0.01, 1.0]);
        assert_eq!(NabMath::leaky_relu_derivative(&x, Some(0.1)).data(), &[0.1, 0.1, 1.0]);
    }

    /// Tests ELU computation with different alphas
    #[test]
    fn test_elu() {
        let x = NDArray::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);

        // Default alpha = 1.0
        let result = NabMath::elu(&x, None);
        assert_close(result.data()[0], (-2.0f64).exp() - 1.0, 1e-12); // ≈ -0.86
        assert_eq!(result.data()[2], 0.0);
        assert_eq!(result.data()[3], 1.0);

        // Custom alpha
        let result = NabMath::elu(&x, Some(2.0));
        assert_close(result.data()[0], 2.0 * ((-2.0f64).exp() - 1.0), 1e-12); // ≈ -1.73
        assert_eq!(result.data()[3], 1.0);
    }

    /// Tests ELU derivative computation
    #[test]
    fn test_elu_derivative() {
        let x = NDArray::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);

        let result = NabMath::elu_derivative(&x, None);
        assert_all_close(
            result.data(),
            &[(-2.0f64).exp(), (-1.0f64).exp(), 1.0, 1.0, 1.0],
            1e-12,
        );

        let result = NabMath::elu_derivative(&x, Some(2.0));
        assert_close(result.data()[0], 2.0 * (-2.0f64).exp(), 1e-12);
        assert_eq!(result.data()[3], 1.0);
    }
}