nabla_ml/nab_math.rs
//! Mathematical functions for NDArray operations
//!
//! This module provides mathematical operations commonly found in NumPy,
//! implemented for the NDArray struct.

use crate::nab_array::NDArray;

/// Mathematical functions for NDArray
pub struct NabMath;

impl NDArray {
    /// Calculates the square root of each element in the array
    ///
    /// # Returns
    ///
    /// A new NDArray with the square root of each element.
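    ///
    /// # Example
    ///
    /// Illustrative sketch (marked `ignore`, so it is not compiled as a doctest);
    /// it relies only on the constructors and accessors exercised in this module's tests.
    ///
    /// ```ignore
    /// let arr = NDArray::from_vec(vec![1.0, 4.0, 9.0]);
    /// let roots = arr.sqrt();
    /// assert_eq!(roots.data(), &[1.0, 2.0, 3.0]);
    /// ```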
    #[allow(dead_code)]
    pub fn sqrt(&self) -> Self {
        let data = self.data().iter().map(|x| x.sqrt()).collect();
        NDArray::new(data, self.shape().to_vec())
    }

    /// Calculates the exponential (e^x) of each element in the array
    ///
    /// # Returns
    ///
    /// A new NDArray with the exponential of each element.
    #[allow(dead_code)]
    pub fn exp(&self) -> Self {
        let data = self.data().iter().map(|x| x.exp()).collect();
        NDArray::new(data, self.shape().to_vec())
    }

    /// Calculates the sine of each element in the array
    ///
    /// # Returns
    ///
    /// A new NDArray with the sine of each element.
    #[allow(dead_code)]
    pub fn sin(&self) -> Self {
        let data: Vec<f64> = self.data().iter().map(|&x| x.sin()).collect();
        Self::new(data, self.shape().to_vec())
    }

    /// Calculates the cosine of each element in the array
    ///
    /// # Returns
    ///
    /// A new NDArray with the cosine of each element.
    #[allow(dead_code)]
    pub fn cos(&self) -> Self {
        let data: Vec<f64> = self.data().iter().map(|&x| x.cos()).collect();
        Self::new(data, self.shape().to_vec())
    }

    /// Calculates the natural logarithm of each element in the array
    ///
    /// # Returns
    ///
    /// A new NDArray with the natural logarithm of each element.
    #[allow(dead_code)]
    pub fn ln(&self) -> Self {
        let data: Vec<f64> = self.data().iter().map(|&x| x.ln()).collect();
        Self::new(data, self.shape().to_vec())
    }

}

impl NabMath {
    /// Computes the sigmoid function element-wise
    ///
    /// sigmoid(x) = 1 / (1 + exp(-x))
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    ///
    /// # Returns
    ///
    /// NDArray with sigmoid applied element-wise
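    ///
    /// # Example
    ///
    /// Illustrative sketch (ignored doctest), using the constructors from this module's tests.
    ///
    /// ```ignore
    /// let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
    /// let y = NabMath::sigmoid(&x);
    /// // sigmoid(0) = 0.5
    /// assert!((y.data()[1] - 0.5).abs() < 1e-6);
    /// ```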
    pub fn sigmoid(x: &NDArray) -> NDArray {
        x.map(|val| 1.0 / (1.0 + (-val).exp()))
    }

    /// Computes the derivative of the sigmoid function element-wise
    ///
    /// sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    ///
    /// # Returns
    ///
    /// NDArray with the sigmoid derivative applied element-wise
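    ///
    /// # Example
    ///
    /// Illustrative sketch (ignored doctest); the value matches `test_sigmoid_derivative` below.
    ///
    /// ```ignore
    /// let x = NDArray::from_vec(vec![0.0]);
    /// let d = NabMath::sigmoid_derivative(&x);
    /// // sigmoid'(0) = 0.5 * (1 - 0.5) = 0.25
    /// assert!((d.data()[0] - 0.25).abs() < 1e-6);
    /// ```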
    pub fn sigmoid_derivative(x: &NDArray) -> NDArray {
        let sigmoid_x = Self::sigmoid(x);
        sigmoid_x.map(|val| val * (1.0 - val))
    }

    /// Computes the hyperbolic tangent function element-wise
    ///
    /// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    ///
    /// # Returns
    ///
    /// NDArray with tanh applied element-wise
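    ///
    /// # Example
    ///
    /// Illustrative sketch (ignored doctest):
    ///
    /// ```ignore
    /// let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
    /// let y = NabMath::tanh(&x);
    /// // tanh(0) = 0 and tanh is odd: tanh(-x) = -tanh(x)
    /// assert!(y.data()[1].abs() < 1e-6);
    /// assert!((y.data()[0] + y.data()[2]).abs() < 1e-6);
    /// ```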
    pub fn tanh(x: &NDArray) -> NDArray {
        x.map(|val| val.tanh())
    }

    /// Computes the derivative of the tanh function element-wise
    ///
    /// tanh'(x) = 1 - tanh²(x)
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    ///
    /// # Returns
    ///
    /// NDArray with the tanh derivative applied element-wise
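    ///
    /// # Example
    ///
    /// Illustrative sketch (ignored doctest); values match `test_tanh_derivative` below.
    ///
    /// ```ignore
    /// let x = NDArray::from_vec(vec![0.0, 1.0]);
    /// let d = NabMath::tanh_derivative(&x);
    /// // tanh'(0) = 1, tanh'(1) ≈ 0.4199
    /// assert!((d.data()[0] - 1.0).abs() < 1e-4);
    /// assert!((d.data()[1] - 0.4199).abs() < 1e-4);
    /// ```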
    pub fn tanh_derivative(x: &NDArray) -> NDArray {
        let tanh_x = Self::tanh(x);
        tanh_x.map(|val| 1.0 - val * val)
    }

    /// Computes the ReLU function element-wise
    ///
    /// ReLU(x) = max(0, x)
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    ///
    /// # Returns
    ///
    /// NDArray with ReLU applied element-wise
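    ///
    /// # Example
    ///
    /// Illustrative sketch (ignored doctest):
    ///
    /// ```ignore
    /// let x = NDArray::from_vec(vec![-2.0, 0.0, 3.0]);
    /// let y = NabMath::relu(&x);
    /// // Negative inputs clamp to zero; non-negative inputs pass through.
    /// assert_eq!(y.data(), &[0.0, 0.0, 3.0]);
    /// ```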
    pub fn relu(x: &NDArray) -> NDArray {
        x.map(|val| val.max(0.0))
    }

    /// Computes the derivative of the ReLU function element-wise
    ///
    /// ReLU'(x) = 1 if x > 0, 0 otherwise
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    ///
    /// # Returns
    ///
    /// NDArray with the ReLU derivative applied element-wise
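    ///
    /// # Example
    ///
    /// Illustrative sketch (ignored doctest); mirrors `test_relu_derivative` below.
    ///
    /// ```ignore
    /// let x = NDArray::from_vec(vec![-1.0, 0.0, 1.0]);
    /// let d = NabMath::relu_derivative(&x);
    /// // The derivative at 0 is taken to be 0 here.
    /// assert_eq!(d.data(), &[0.0, 0.0, 1.0]);
    /// ```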
    pub fn relu_derivative(x: &NDArray) -> NDArray {
        x.map(|val| if val > 0.0 { 1.0 } else { 0.0 })
    }

    /// Computes the softmax function
    ///
    /// softmax(x) = exp(x) / sum(exp(x))
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray (1D or 2D)
    /// * `_axis` - Currently ignored: 1D arrays are normalized over all elements,
    ///   and 2D arrays are normalized row-wise (along axis 1)
    ///
    /// # Returns
    ///
    /// NDArray with softmax applied along the last axis
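    ///
    /// # Example
    ///
    /// Illustrative sketch (ignored doctest); mirrors `test_softmax` below.
    ///
    /// ```ignore
    /// let x = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let p = NabMath::softmax(&x, None);
    /// // Softmax outputs form a probability distribution: they sum to 1.
    /// let sum: f64 = p.data().iter().sum();
    /// assert!((sum - 1.0).abs() < 1e-6);
    /// ```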
    pub fn softmax(x: &NDArray, _axis: Option<usize>) -> NDArray {
        assert!(x.ndim() == 1 || x.ndim() == 2, "Softmax is only defined for 1D or 2D arrays");

        let exp = x.map(|val| val.exp());

        if x.ndim() == 1 {
            // For 1D arrays, normalize over all elements
            let sum = exp.sum();
            exp.map(|val| val / sum)
        } else {
            // For 2D arrays, always compute along rows (axis = 1)
            let (rows, cols) = (x.shape()[0], x.shape()[1]);
            let sum = exp.sum_axis(1); // Shape: [rows, 1]

            // Divide each element by its row's sum (manual broadcast of the row sums)
            let mut result_data = Vec::with_capacity(rows * cols);
            for i in 0..rows {
                for j in 0..cols {
                    result_data.push(exp.data()[i * cols + j] / sum.data()[i]);
                }
            }

            NDArray::new(result_data, x.shape().to_vec())
        }
    }

    /// Computes the element-wise derivative of the softmax function
    ///
    /// Returns s * (1 - s), the diagonal of the softmax Jacobian.
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray (expected to already be a softmax output)
    ///
    /// # Returns
    ///
    /// NDArray with the softmax derivative
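    ///
    /// # Example
    ///
    /// Illustrative sketch (ignored doctest); mirrors `test_softmax_derivative` below.
    ///
    /// ```ignore
    /// // Input is already a probability vector (softmax output).
    /// let s = NDArray::from_vec(vec![0.1, 0.7, 0.2]);
    /// let d = NabMath::softmax_derivative(&s);
    /// // Each value s * (1 - s) lies in [0, 0.25].
    /// assert!(d.data().iter().all(|&v| (0.0..=0.25).contains(&v)));
    /// ```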
    pub fn softmax_derivative(x: &NDArray) -> NDArray {
        x.map(|val| val * (1.0 - val))
    }

    /// Computes the Leaky ReLU function element-wise
    ///
    /// LeakyReLU(x) = x if x > 0, alpha * x otherwise
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    /// * `alpha` - Slope for negative values (default: 0.01)
    ///
    /// # Returns
    ///
    /// NDArray with Leaky ReLU applied element-wise
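    ///
    /// # Example
    ///
    /// Illustrative sketch (ignored doctest); values match `test_leaky_relu` below.
    ///
    /// ```ignore
    /// let x = NDArray::from_vec(vec![-2.0, 1.0]);
    /// // Default alpha = 0.01
    /// let y = NabMath::leaky_relu(&x, None);
    /// assert_eq!(y.data(), &[-0.02, 1.0]);
    /// // Custom alpha = 0.1
    /// let y = NabMath::leaky_relu(&x, Some(0.1));
    /// assert_eq!(y.data(), &[-0.2, 1.0]);
    /// ```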
    pub fn leaky_relu(x: &NDArray, alpha: Option<f64>) -> NDArray {
        let alpha = alpha.unwrap_or(0.01);
        x.map(|val| if val > 0.0 { val } else { alpha * val })
    }

    /// Computes the derivative of the Leaky ReLU function
    ///
    /// LeakyReLU'(x) = 1 if x > 0, alpha otherwise
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    /// * `alpha` - Slope for negative values (default: 0.01)
    ///
    /// # Returns
    ///
    /// NDArray with the Leaky ReLU derivative
    pub fn leaky_relu_derivative(x: &NDArray, alpha: Option<f64>) -> NDArray {
        let alpha = alpha.unwrap_or(0.01);
        x.map(|val| if val > 0.0 { 1.0 } else { alpha })
    }

    /// Computes the ELU (Exponential Linear Unit) function
    ///
    /// ELU(x) = x if x > 0, alpha * (exp(x) - 1) if x <= 0
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    /// * `alpha` - Scale for negative values (default: 1.0)
    ///
    /// # Returns
    ///
    /// NDArray with ELU applied element-wise
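    ///
    /// # Example
    ///
    /// Illustrative sketch (ignored doctest):
    ///
    /// ```ignore
    /// let x = NDArray::from_vec(vec![-2.0, 1.0]);
    /// let y = NabMath::elu(&x, None); // default alpha = 1.0
    /// // ELU(-2) = exp(-2) - 1 ≈ -0.8647; positive inputs pass through.
    /// assert!((y.data()[0] + 0.8647).abs() < 1e-4);
    /// assert_eq!(y.data()[1], 1.0);
    /// ```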
    pub fn elu(x: &NDArray, alpha: Option<f64>) -> NDArray {
        let alpha = alpha.unwrap_or(1.0);
        x.map(|val| if val > 0.0 { val } else { alpha * (val.exp() - 1.0) })
    }

    /// Computes the derivative of the ELU function
    ///
    /// ELU'(x) = 1 if x > 0, alpha * exp(x) otherwise
    ///
    /// # Arguments
    ///
    /// * `x` - Input NDArray
    /// * `alpha` - Scale for negative values (default: 1.0)
    ///
    /// # Returns
    ///
    /// NDArray with the ELU derivative
    pub fn elu_derivative(x: &NDArray, alpha: Option<f64>) -> NDArray {
        let alpha = alpha.unwrap_or(1.0);
        x.map(|val| if val > 0.0 { 1.0 } else { alpha * val.exp() })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sqrt() {
        let arr = NDArray::from_vec(vec![1.0, 4.0, 9.0]);
        let sqrt_arr = arr.sqrt();
        assert_eq!(sqrt_arr.data(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_exp() {
        let arr = NDArray::from_vec(vec![0.0, 1.0, 2.0]);
        let exp_arr = arr.exp();
        assert!((exp_arr.data()[0] - 1.0).abs() < 1e-4);
        assert!((exp_arr.data()[1] - std::f64::consts::E).abs() < 1e-4);
        assert!((exp_arr.data()[2] - std::f64::consts::E.powi(2)).abs() < 1e-4);
    }

    /// Tests sigmoid function computation
    #[test]
    fn test_sigmoid() {
        let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
        let result = NabMath::sigmoid(&x);

        // Test output range (0 to 1)
        for &val in result.data() {
            assert!(val > 0.0 && val < 1.0);
        }

        // Test sigmoid(0) = 0.5
        assert!((result.data()[1] - 0.5).abs() < 1e-6);

        // Test symmetry: sigmoid(-x) = 1 - sigmoid(x)
        assert!((result.data()[0] - (1.0 - result.data()[2])).abs() < 1e-6);
    }

    /// Tests sigmoid derivative computation
    #[test]
    fn test_sigmoid_derivative() {
        let x = NDArray::from_vec(vec![-1.0, 0.0, 1.0]);
        let result = NabMath::sigmoid_derivative(&x);
        assert!((result.data()[0] - 0.1966).abs() < 1e-4);
        assert!((result.data()[1] - 0.2500).abs() < 1e-4);
        assert!((result.data()[2] - 0.1966).abs() < 1e-4);
    }

    /// Tests tanh function computation
    #[test]
    fn test_tanh() {
        let x = NDArray::from_vec(vec![-2.0, 0.0, 2.0]);
        let result = NabMath::tanh(&x);

        // Test output range (-1 to 1)
        for &val in result.data() {
            assert!(val >= -1.0 && val <= 1.0);
        }

        // Test tanh(0) = 0
        assert!(result.data()[1].abs() < 1e-6);

        // Test symmetry: tanh(-x) = -tanh(x)
        assert!((result.data()[0] + result.data()[2]).abs() < 1e-6);
    }

    /// Tests tanh derivative computation
    #[test]
    fn test_tanh_derivative() {
        let x = NDArray::from_vec(vec![-1.0, 0.0, 1.0]);
        let result = NabMath::tanh_derivative(&x);
        assert!((result.data()[0] - 0.4199).abs() < 1e-4);
        assert!((result.data()[1] - 1.0000).abs() < 1e-4);
        assert!((result.data()[2] - 0.4199).abs() < 1e-4);
    }

    /// Tests ReLU function computation
    #[test]
    fn test_relu() {
        let x = NDArray::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);
        let result = NabMath::relu(&x);

        // Test positive values remain unchanged
        assert_eq!(result.data()[3], 1.0);
        assert_eq!(result.data()[4], 2.0);

        // Test negative values become zero
        assert_eq!(result.data()[0], 0.0);
        assert_eq!(result.data()[1], 0.0);

        // Test zero remains zero
        assert_eq!(result.data()[2], 0.0);
    }

    /// Tests ReLU derivative computation
    #[test]
    fn test_relu_derivative() {
        let x = NDArray::from_vec(vec![-1.0, 0.0, 1.0]);
        let result = NabMath::relu_derivative(&x);
        assert_eq!(result.data(), &[0.0, 0.0, 1.0]);
    }

    /// Tests softmax computation on different dimensions
    #[test]
    fn test_softmax() {
        // Test 1D array
        let x = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = NabMath::softmax(&x, None);

        // Test sum equals 1
        let sum: f64 = result.data().iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // Test monotonicity (larger inputs -> larger probabilities)
        for i in 1..result.data().len() {
            assert!(result.data()[i] > result.data()[i - 1]);
        }

        // Test 2D array
        let x = NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ]);
        let result = NabMath::softmax(&x, Some(1));

        // Test each row sums to 1
        for i in 0..2 {
            let row_sum: f64 = result.data()[i * 3..(i + 1) * 3].iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-6);
        }
    }

    /// Tests softmax derivative computation
    #[test]
    fn test_softmax_derivative() {
        let x = NDArray::from_vec(vec![0.1, 0.7, 0.2]);
        let result = NabMath::softmax_derivative(&x);
        assert_eq!(result.shape(), &[3]);
        // Verify derivative values
        for &val in result.data() {
            assert!(val >= 0.0 && val <= 0.25); // Maximum value is 0.25 for the softmax derivative
        }
    }

    /// Tests Leaky ReLU computation with different alphas
    #[test]
    fn test_leaky_relu() {
        let x = NDArray::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);

        // Test with default alpha
        let result = NabMath::leaky_relu(&x, None);
        assert_eq!(result.data()[3], 1.0); // Positive values unchanged
        assert_eq!(result.data()[4], 2.0);
        assert_eq!(result.data()[0], -0.02); // Negative values scaled by 0.01
        assert_eq!(result.data()[2], 0.0); // Zero unchanged

        // Test with custom alpha
        let result = NabMath::leaky_relu(&x, Some(0.1));
        assert_eq!(result.data()[3], 1.0); // Positive values unchanged
        assert_eq!(result.data()[0], -0.2); // Negative values scaled by 0.1
    }

    /// Tests ELU computation with different alphas
    #[test]
    fn test_elu() {
        let x = NDArray::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);

        // Test with default alpha
        let result = NabMath::elu(&x, None);
        assert!(result.data()[0] < -0.8); // ELU(-2) ≈ -0.86
        assert_eq!(result.data()[3], 1.0);

        // Test with custom alpha
        let result = NabMath::elu(&x, Some(2.0));
        assert!(result.data()[0] < -1.7); // ELU(-2) with alpha = 2 ≈ -1.73
        assert_eq!(result.data()[3], 1.0);
    }

    /// Tests ELU derivative computation
    #[test]
    fn test_elu_derivative() {
        let x = NDArray::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);
        let result = NabMath::elu_derivative(&x, None);
        assert!(result.data()[0] > 0.0 && result.data()[0] < 1.0);
        assert_eq!(result.data()[3], 1.0);
    }
}