nabla_ml/
nab_array.rs

1use rand::Rng;
2use rand_distr::{StandardNormal, Uniform, Distribution};
3use std::ops::{Add, Sub, Mul, Div};
4use rayon::prelude::*;
/// A simple n-dimensional array of `f64` values stored in row-major order.
#[derive(Debug, Clone)]
pub struct NDArray {
    // Flat, row-major element buffer; length equals the product of `shape`.
    pub data: Vec<f64>,
    // Size of each dimension, e.g. `[rows, cols]` for a matrix.
    pub shape: Vec<usize>,
}
10
11impl NDArray {
12    pub fn new(data: Vec<f64>, shape: Vec<usize>) -> Self {
13        let total_size: usize = shape.iter().product();
14        assert_eq!(data.len(), total_size, "Data length must match shape dimensions");
15        NDArray { data, shape }
16    }
17
18    pub fn from_vec(data: Vec<f64>) -> Self {
19        let len = data.len();
20        Self::new(data, vec![len])
21    }
22
23    #[allow(dead_code)]
24    pub fn from_matrix(data: Vec<Vec<f64>>) -> Self {
25        let rows = data.len();
26        let cols = data.get(0).map_or(0, |row| row.len());
27        let flat_data: Vec<f64> = data.into_iter().flatten().collect();
28        Self::new(flat_data, vec![rows, cols])
29    }
30
31    pub fn shape(&self) -> &[usize] {
32        &self.shape
33    }
34
35    pub fn ndim(&self) -> usize {
36        self.shape.len()
37    }
38
39    /// Returns a reference to the data of the array
40    pub fn data(&self) -> &[f64] {
41        &self.data
42    }
43
44    /// Creates a 2D array (matrix) of random numbers between 0 and 1
45    ///
46    /// # Arguments
47    ///
48    /// * `rows` - The number of rows in the matrix.
49    /// * `cols` - The number of columns in the matrix.
50    ///
51    /// # Returns
52    ///
53    /// A 2D NDArray filled with random numbers.
54    #[allow(dead_code)]
55    pub fn rand_2d(rows: usize, cols: usize) -> Self {
56        let mut rng = rand::thread_rng();
57        let data: Vec<f64> = (0..rows * cols).map(|_| rng.gen()).collect();
58        Self::new(data, vec![rows, cols])
59    }
60
61
62    /// Creates a 1D array of random numbers following a normal distribution
63    ///
64    /// # Arguments
65    ///
66    /// * `size` - The number of elements in the array.
67    ///
68    /// # Returns
69    ///
70    /// A 1D NDArray filled with random numbers from a normal distribution.
71    #[allow(dead_code)]
72    pub fn randn(size: usize) -> Self {
73        let mut rng = rand::thread_rng();
74        let data: Vec<f64> = (0..size).map(|_| rng.sample(StandardNormal)).collect();
75        Self::from_vec(data)
76    }
77
78    /// Creates a 2D array (matrix) of random numbers following a normal distribution
79    ///
80    /// # Arguments
81    ///
82    /// * `rows` - The number of rows in the matrix.
83    /// * `cols` - The number of columns in the matrix.
84    ///
85    /// # Returns
86    ///
87    /// A 2D NDArray filled with random numbers from a normal distribution.
88    #[allow(dead_code)]
89    pub fn randn_2d(rows: usize, cols: usize) -> Self {
90        let mut rng = rand::thread_rng();
91        let data: Vec<f64> = (0..rows * cols).map(|_| rng.sample(StandardNormal)).collect();
92        Self::new(data, vec![rows, cols])
93    }
94
95    /// Creates a 1D array of random integers between `low` and `high`
96    ///
97    /// # Arguments
98    ///
99    /// * `low` - The lower bound (inclusive).
100    /// * `high` - The upper bound (exclusive).
101    /// * `size` - The number of elements in the array.
102    ///
103    /// # Returns
104    ///
105    /// A 1D NDArray filled with random integers.
106    #[allow(dead_code)]
107    pub fn randint(low: i32, high: i32, size: usize) -> Self {
108        let mut rng = rand::thread_rng();
109        let data: Vec<f64> = (0..size).map(|_| rng.gen_range(low..high) as f64).collect();
110        Self::from_vec(data)
111    }
112
113    /// Creates a 2D array (matrix) of random integers between `low` and `high`
114    ///
115    /// # Arguments
116    ///
117    /// * `low` - The lower bound (inclusive).
118    /// * `high` - The upper bound (exclusive).
119    /// * `rows` - The number of rows in the matrix.
120    /// * `cols` - The number of columns in the matrix.
121    ///
122    /// # Returns
123    ///
124    /// A 2D NDArray filled with random integers.
125    #[allow(dead_code)]
126    pub fn randint_2d(low: i32, high: i32, rows: usize, cols: usize) -> Self {
127        let mut rng = rand::thread_rng();
128        let data: Vec<f64> = (0..rows * cols).map(|_| rng.gen_range(low..high) as f64).collect();
129        Self::new(data, vec![rows, cols])
130    }
131
132    /// Reshapes the array to the specified shape, allowing one dimension to be inferred
133    ///
134    /// # Arguments
135    ///
136    /// * `new_shape` - A vector representing the new shape, with at most one dimension as `-1`.
137    ///
138    /// # Returns
139    ///
140    /// A new NDArray with the specified shape.
141    pub fn reshape(&self, new_shape: &[usize]) -> Result<Self, &'static str> {
142        let total_elements = self.data.len();
143        let new_total: usize = new_shape.iter().copied().product();
144        
145        if total_elements != new_total {
146            return Err("New shape must have same total size as original");
147        }
148        
149        Ok(NDArray {
150            data: self.data.clone(),
151            shape: new_shape.to_vec()
152        })
153    }
154
155     /// Returns the maximum value in the array
156    ///
157    /// # Returns
158    ///
159    /// The maximum value as an f64.
160    #[allow(dead_code)]
161    pub fn max(&self) -> f64 {
162        *self.data.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap()
163    }
164
165    /// Returns the index of the maximum value in the array
166    ///
167    /// # Returns
168    ///
169    /// The index of the maximum value.
170    // #[allow(dead_code)]
171    // pub fn argmax(&self) -> usize {
172    //     self.data.iter().enumerate().max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()).map(|(i, _)| i).unwrap()
173    // }
174
175    /// Returns the indices of maximum values
176    /// For 1D arrays: returns a single index
177    /// For 2D arrays: returns indices along the specified axis
178    pub fn argmax(&self, axis: Option<usize>) -> Vec<usize> {
179        match axis {
180            None => {
181                // Global argmax
182                vec![self.data.iter()
183                    .enumerate()
184                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
185                    .map(|(i, _)| i)
186                    .unwrap()]
187            },
188            Some(ax) => {
189                if ax >= self.shape.len() {
190                    panic!("Axis {} out of bounds for shape {:?}", ax, self.shape);
191                }
192                // Axis-wise argmax
193                match ax {
194                    0 => {
195                        let cols = self.shape[1];
196                        let mut indices = Vec::with_capacity(cols);
197                        for j in 0..cols {
198                            let mut max_idx = 0;
199                            let mut max_val = self.data[j];
200                            for i in 1..self.shape[0] {
201                                let val = self.data[i * cols + j];
202                                if val > max_val {
203                                    max_val = val;
204                                    max_idx = i;
205                                }
206                            }
207                            indices.push(max_idx);
208                        }
209                        indices
210                    },
211                    1 => {
212                        let cols = self.shape[1];
213                        let mut indices = Vec::with_capacity(self.shape[0]);
214                        for i in 0..self.shape[0] {
215                            let row_start = i * cols;
216                            let mut max_idx = 0;
217                            let mut max_val = self.data[row_start];
218                            for j in 1..cols {
219                                let val = self.data[row_start + j];
220                                if val > max_val {
221                                    max_val = val;
222                                    max_idx = j;
223                                }
224                            }
225                            indices.push(max_idx);
226                        }
227                        indices
228                    },
229                    _ => panic!("Unsupported axis {}", ax)
230                }
231            }
232        }
233    }
234
235
236    /// Returns the minimum value in the array
237    ///
238    /// # Returns
239    ///
240    /// The minimum value as an f64.
241    #[allow(dead_code)]
242    pub fn min(&self) -> f64 {
243        *self.data.iter().min_by(|a, b| a.partial_cmp(b).unwrap()).unwrap()
244    }
245
246    /// Creates an NDArray from a flat vector and a specified shape
247    ///
248    /// # Arguments
249    ///
250    /// * `data` - A vector of f64 values representing the array's data.
251    /// * `shape` - A vector of usize values representing the dimensions of the array.
252    ///
253    /// # Returns
254    ///
255    /// A new NDArray instance.
256    #[allow(dead_code)]
257    pub fn from_vec_reshape(data: Vec<f64>, shape: Vec<usize>) -> Self {
258        let total_size: usize = shape.iter().product();
259        assert_eq!(data.len(), total_size, "Data length must match shape dimensions");
260        NDArray { data, shape }
261    }
262
263    /// Extracts a single sample from a batch of N-dimensional arrays
264    ///
265    /// # Arguments
266    ///
267    /// * `sample_index` - The index of the sample to extract
268    ///
269    /// # Returns
270    ///
271    /// A new NDArray containing just the specified sample with N-1 dimensions
272    #[allow(dead_code)]
273    pub fn extract_sample(&self, sample_index: usize) -> Self {
274        assert!(self.ndim() >= 2, "Array must have at least 2 dimensions");
275        assert!(sample_index < self.shape[0], "Sample index out of bounds");
276
277        let sample_size: usize = self.shape.iter().skip(1).product();
278        let start_index = sample_index * sample_size;
279        let end_index = start_index + sample_size;
280        
281        // Create new shape without the first dimension
282        let new_shape: Vec<usize> = self.shape.iter().skip(1).cloned().collect();
283        
284        NDArray::new(
285            self.data[start_index..end_index].to_vec(),
286            new_shape
287        )
288    }
289
    /// Pretty prints an N-dimensional array to stdout.
    ///
    /// # Arguments
    ///
    /// * `precision` - The number of decimal places to round each value to.
    ///   NOTE(review): this value is also reused as the indentation width
    ///   (`" ".repeat(precision)`), so a higher precision widens the indent —
    ///   confirm this coupling is intentional.
    #[allow(dead_code)]
    pub fn pretty_print(&self, precision: usize) {
        let indent_str = " ".repeat(precision);
        
        // Zero is always rendered as "0.0" regardless of the requested precision.
        let format_value = |x: f64| -> String {
            if x == 0.0 {
                format!("{:.1}", x)
            } else {
                format!("{:.*}", precision, x)
            }
        };
        
        match self.ndim() {
            // 1D: one bracketed line of space-separated values.
            1 => println!("{}[{}]", indent_str, self.data.iter()
                .map(|&x| format_value(x))
                .collect::<Vec<_>>()
                .join(" ")),
                
            // 2D: one bracketed line per row, nested inside outer brackets.
            2 => {
                println!("{}[", indent_str);
                for i in 0..self.shape[0] {
                    print!("{}  [", indent_str);
                    for j in 0..self.shape[1] {
                        print!("{}", format_value(self.get_2d(i, j)));
                        if j < self.shape[1] - 1 {
                            print!(" ");
                        }
                    }
                    println!("]");
                }
                println!("{}]", indent_str);
            },
            
            // 3D+: recurse on each sample along axis 0 with a wider indent.
            _ => {
                println!("{}[", indent_str);
                for i in 0..self.shape[0] {
                    let slice = self.extract_sample(i);
                    slice.pretty_print(precision + 2);
                }
                println!("{}]", indent_str);
            }
        }
    }
338
339
340    /// Returns a specific element from the array
341    ///
342    /// # Arguments
343    ///
344    /// * `index` - The index of the element to retrieve.
345    ///
346    /// # Returns
347    ///
348    /// The element at the specified index.
349    #[allow(dead_code)]
350    pub fn get(&self, index: usize) -> f64 {
351        self.data[index]
352    }
353
354    /// Creates a 1D array with a range of numbers
355    ///
356    /// # Arguments
357    ///
358    /// * `start` - The starting value of the range (inclusive).
359    /// * `stop` - The stopping value of the range (exclusive).
360    /// * `step` - The step size between each value in the range.
361    ///
362    /// # Returns
363    ///
364    /// A 1D NDArray containing the range of numbers.
365    #[allow(dead_code)]
366    pub fn arange(start: f64, stop: f64, step: f64) -> Self {
367        let mut data = Vec::new();
368        let mut current = start;
369        while current < stop {
370            data.push(current);
371            current += step;
372        }
373        Self::from_vec(data)
374    }
375
376        /// Creates a 1D array filled with zeros
377    ///
378    /// # Arguments
379    ///
380    /// * `size` - The number of elements in the array.
381    ///
382    /// # Returns
383    ///
384    /// A 1D NDArray filled with zeros.
385    #[allow(dead_code)]
386    pub fn zeros(shape: Vec<usize>) -> Self {
387        let total_size: usize = shape.iter().product();
388        NDArray {
389            data: vec![0.0; total_size],
390            shape,
391        }
392    }
393
394
395    /// Creates a 2D array (matrix) filled with zeros
396    ///
397    /// # Arguments
398    ///
399    /// * `rows` - The number of rows in the matrix.
400    /// * `cols` - The number of columns in the matrix.
401    ///
402    /// # Returns
403    ///
404    /// A 2D NDArray filled with zeros.
405    #[allow(dead_code)]
406    pub fn zeros_2d(rows: usize, cols: usize) -> Self {
407        Self::new(vec![0.0; rows * cols], vec![rows, cols])
408    }
409
410    /// Creates a 1D array filled with ones
411    ///
412    /// # Arguments
413    ///
414    /// * `size` - The number of elements in the array.
415    ///
416    /// # Returns
417    ///
418    /// A 1D NDArray filled with ones.
419    #[allow(dead_code)]
420    pub fn ones(size: usize) -> Self {
421        Self::from_vec(vec![1.0; size])
422    }
423
424        /// Creates a 2D array (matrix) filled with ones
425    ///
426    /// # Arguments
427    ///
428    /// * `rows` - The number of rows in the matrix.
429    /// * `cols` - The number of columns in the matrix.
430    ///
431    /// # Returns
432    ///
433    /// A 2D NDArray filled with ones.
434    #[allow(dead_code)]
435    pub fn ones_2d(rows: usize, cols: usize) -> Self {
436        Self::new(vec![1.0; rows * cols], vec![rows, cols])
437    }
438
439    /// Creates a 1D array with evenly spaced numbers over a specified interval
440    ///
441    /// # Arguments
442    ///
443    /// * `start` - The starting value of the interval.
444    /// * `end` - The ending value of the interval.
445    /// * `num` - The number of evenly spaced samples to generate.
446    /// * `precision` - The number of decimal places to round each value to.
447    ///
448    /// # Returns
449    ///
450    /// A 1D NDArray containing the evenly spaced numbers.
451    #[allow(dead_code)]
452    pub fn linspace(start: f64, end: f64, num: usize, precision: usize) -> Self {
453        assert!(num > 1, "Number of samples must be greater than 1");
454        let step = (end - start) / (num - 1) as f64;
455        let mut data = Vec::with_capacity(num);
456        let factor = 10f64.powi(precision as i32);
457        for i in 0..num {
458            let value = start + step * i as f64;
459            let rounded_value = (value * factor).round() / factor;
460            data.push(rounded_value);
461        }
462        Self::from_vec(data)
463    }
464
465    /// Creates an identity matrix of size `n x n`
466    ///
467    /// # Arguments
468    ///
469    /// * `n` - The size of the identity matrix.
470    ///
471    /// # Returns
472    ///
473    /// An `n x n` identity matrix as an NDArray.
474    #[allow(dead_code)]
475    pub fn eye(n: usize) -> Self {
476        let mut data = vec![0.0; n * n];
477        for i in 0..n {
478            data[i * n + i] = 1.0;
479        }
480        Self::new(data, vec![n, n])
481    }
482
483    /// Creates a 1D array of random numbers between 0 and 1
484    ///
485    /// # Arguments
486    ///
487    /// * `size` - The number of elements in the array.
488    ///
489    /// # Returns
490    ///
491    /// A 1D NDArray filled with random numbers.
492    #[allow(dead_code)]
493    pub fn rand(size: usize) -> Self {
494        let mut rng = rand::thread_rng();
495        let data: Vec<f64> = (0..size).map(|_| rng.gen()).collect();
496        Self::from_vec(data)
497    }
498
499
500    /// Returns a sub-matrix from a 2D array
501    ///
502    /// # Arguments
503    ///
504    /// * `row_start` - The starting row index of the sub-matrix.
505    /// * `row_end` - The ending row index of the sub-matrix (exclusive).
506    /// * `col_start` - The starting column index of the sub-matrix.
507    /// * `col_end` - The ending column index of the sub-matrix (exclusive).
508    ///
509    /// # Returns
510    ///
511    /// A new NDArray representing the specified sub-matrix.
512    #[allow(dead_code)]
513    pub fn sub_matrix(&self, row_start: usize, row_end: usize, col_start: usize, col_end: usize) -> Self {
514        assert_eq!(self.ndim(), 2, "sub_matrix is only applicable to 2D arrays");
515        let cols = self.shape[1];
516        let mut data = Vec::new();
517        for row in row_start..row_end {
518            for col in col_start..col_end {
519                data.push(self.data[row * cols + col]);
520            }
521        }
522        Self::new(data, vec![row_end - row_start, col_end - col_start])
523    }
524
525    /// Sets a specific element in the array
526    ///
527    /// # Arguments
528    ///
529    /// * `index` - The index of the element to set.
530    /// * `value` - The value to set the element to.
531    #[allow(dead_code)]
532    pub fn set(&mut self, index: usize, value: f64) {
533        self.data[index] = value;
534    }
535
536    /// Sets a range of elements in the array to a specific value
537    ///
538    /// # Arguments
539    ///
540    /// * `start` - The starting index of the range.
541    /// * `end` - The ending index of the range (exclusive).
542    /// * `value` - The value to set the elements to.
543    #[allow(dead_code)]
544    pub fn set_range(&mut self, start: usize, end: usize, value: f64) {
545        for i in start..end {
546            self.data[i] = value;
547        }
548    }
549
550     /// Returns a copy of the array
551    ///
552    /// # Returns
553    ///
554    /// A new NDArray that is a copy of the original.
555    #[allow(dead_code)]
556    pub fn copy(&self) -> Self {
557        Self::new(self.data.clone(), self.shape.clone())
558    }
559
560    /// Returns a view (slice) of the array from start to end (exclusive)
561    ///
562    /// # Arguments
563    ///
564    /// * `start` - The starting index of the view.
565    /// * `end` - The ending index of the view (exclusive).
566    ///
567    /// # Returns
568    ///
569    /// A slice of f64 values representing the specified view.
570    #[allow(dead_code)]
571    pub fn view(&self, start: usize, end: usize) -> &[f64] {
572        &self.data[start..end]
573    }
574
575        /// Returns a mutable view (slice) of the array from start to end (exclusive)
576    ///
577    /// # Arguments
578    ///
579    /// * `start` - The starting index of the view.
580    /// * `end` - The ending index of the view (exclusive).
581    ///
582    /// # Returns
583    ///
584    /// A mutable slice of f64 values representing the specified view.
585    #[allow(dead_code)]
586    pub fn view_mut(&mut self, start: usize, end: usize) -> &mut [f64] {
587        &mut self.data[start..end]
588    }
589
590
591    /// Returns a specific element from a 2D array
592    ///
593    /// # Arguments
594    ///
595    /// * `row` - The row index of the element.
596    /// * `col` - The column index of the element.
597    ///
598    /// # Returns
599    ///
600    /// The element at the specified row and column.
601    #[allow(dead_code)]
602    pub fn get_2d(&self, row: usize, col: usize) -> f64 {
603        assert_eq!(self.ndim(), 2, "get_2d is only applicable to 2D arrays");
604        let cols = self.shape[1];
605        self.data[row * cols + col]
606    }
607
608    /// Sets a specific element in a 2D array
609    ///
610    /// # Arguments
611    ///
612    /// * `row` - The row index of the element.
613    /// * `col` - The column index of the element.
614    /// * `value` - The value to set the element to.
615    #[allow(dead_code)]
616    pub fn set_2d(&mut self, row: usize, col: usize, value: f64) {
617        assert_eq!(self.ndim(), 2, "set_2d is only applicable to 2D arrays");
618        let cols = self.shape[1];
619        self.data[row * cols + col] = value;
620    }
621
622    /// Adds a new axis to the array at the specified position
623    ///
624    /// # Arguments
625    ///
626    /// * `axis` - The position at which to add the new axis.
627    ///
628    /// # Returns
629    ///
630    /// A new NDArray with an additional axis.
631    #[allow(dead_code)]
632    pub fn new_axis(&self, axis: usize) -> Self {
633        let mut new_shape = self.shape.clone();
634        new_shape.insert(axis, 1);
635        Self::new(self.data.clone(), new_shape)
636    }
637
638    /// Expands the dimensions of the array by adding a new axis at the specified index
639    ///
640    /// # Arguments
641    ///
642    /// * `axis` - The index at which to add the new axis.
643    ///
644    /// # Returns
645    ///
646    /// A new NDArray with expanded dimensions.
647    #[allow(dead_code)]
648    pub fn expand_dims(&self, axis: usize) -> Self {
649        self.new_axis(axis)
650    }
651
652    /// Returns a boolean array indicating whether each element satisfies the condition
653    ///
654    /// # Arguments
655    ///
656    /// * `threshold` - The threshold value to compare each element against.
657    ///
658    /// # Returns
659    ///
660    /// A vector of boolean values indicating whether each element is greater than the threshold.
661    #[allow(dead_code)]
662    pub fn greater_than(&self, threshold: f64) -> Vec<bool> {
663        self.data.iter().map(|&x| x > threshold).collect()
664    }
665
666    /// Returns a new array containing only the elements that satisfy the condition
667    ///
668    /// # Arguments
669    ///
670    /// * `condition` - A closure that takes an f64 and returns a boolean.
671    ///
672    /// # Returns
673    ///
674    /// A new NDArray containing only the elements that satisfy the condition.
675    #[allow(dead_code)]
676    pub fn filter(&self, condition: impl Fn(&f64) -> bool) -> Self {
677        let data: Vec<f64> = self.data.iter().cloned().filter(condition).collect();
678        Self::from_vec(data)
679    }
680
681
682    /// Returns the data type of the elements in the array
683    ///
684    /// # Returns
685    ///
686    /// A string representing the data type of the elements.
687    #[allow(dead_code)]
688    pub fn dtype(&self) -> &'static str {
689        "f64" // Since we're using f64 for all elements
690    }
691
692    /// Returns the total number of elements in the array
693    ///
694    /// # Returns
695    ///
696    /// The total number of elements in the array.
697    #[allow(dead_code)]
698    pub fn size(&self) -> usize {
699        self.data.len()
700    }
701
702    /// Returns the index of the minimum value in the array
703    ///
704    /// # Returns
705    ///
706    /// The index of the minimum value.
707    #[allow(dead_code)]
708    pub fn argmin(&self) -> usize {
709        self.data.iter().enumerate().min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()).map(|(i, _)| i).unwrap()
710    }
711
712    /// Returns a slice of the array from start to end (exclusive)
713    ///
714    /// # Arguments
715    ///
716    /// * `start` - The starting index of the slice.
717    /// * `end` - The ending index of the slice (exclusive).
718    ///
719    /// # Returns
720    ///
721    /// A new NDArray containing the specified slice.
722    #[allow(dead_code)]
723    pub fn slice(&self, start: usize, end: usize) -> Self {
724        // println!("Slicing array:");
725        // println!("  Original shape: {:?}", self.shape);
726        // println!("  Start: {}, End: {}", start, end);
727
728        let mut new_shape = self.shape.clone();
729        new_shape[0] = end - start;
730
731        if self.ndim() == 2 {
732            let cols = self.shape[1];
733            let start_idx = start * cols;
734            let end_idx = end * cols;
735            // println!("  2D array: keeping columns, new shape will be: {:?}", new_shape);
736            
737            let sliced_data = self.data[start_idx..end_idx].to_vec();
738            // println!("  Sliced data length: {}", sliced_data.len());
739            
740            NDArray::new(sliced_data, new_shape)
741        } else {
742            // println!("  1D array: simple slice");
743            NDArray::new(self.data[start..end].to_vec(), new_shape)
744        }
745    }
746
747    /// Converts an NDArray of labels into a one-hot encoded NDArray
748    ///
749    /// # Arguments
750    ///
751    /// * `labels` - An NDArray containing numerical labels
752    ///
753    /// # Returns
754    ///
755    /// A new NDArray with one-hot encoded labels where each row corresponds to one label
756    ///
757    /// # Panics
758    ///
759    /// Panics if the input contains non-integer values
760    pub fn one_hot_encode(labels: &NDArray) -> Self {
761        // Verify that all values are integers
762        for &value in labels.data() {
763            // Check if the value is effectively an integer
764            if value.fract() != 0.0 {
765                panic!("All values must be integers for one-hot encoding");
766            }
767        }
768
769        // Convert values to integers and find unique classes
770        let labels_int: Vec<i32> = labels.data()
771            .iter()
772            .map(|&x| x as i32)
773            .collect();
774
775        // Find min and max to determine the range of classes
776        let min_label = labels_int.iter().min().unwrap();
777        let max_label = labels_int.iter().max().unwrap();
778        let num_classes = (max_label - min_label + 1) as usize;
779        
780        let mut data = vec![0.0; labels_int.len() * num_classes];
781        
782        // Shift indices by min_label to handle negative values
783        for (i, &label) in labels_int.iter().enumerate() {
784            let shifted_label = (label - min_label) as usize;
785            data[i * num_classes + shifted_label] = 1.0;
786        }
787        
788        NDArray::new(data, vec![labels_int.len(), num_classes])
789    }
790
791    /// Transposes a 2D array (matrix)
792    ///
793    /// # Returns
794    ///
795    /// A new NDArray with transposed dimensions.
796    pub fn transpose(&self) -> Result<Self, &'static str> {
797        if self.shape.len() != 2 {
798            return Err("transpose currently only supports 2D arrays");
799        }
800        
801        let (rows, cols) = (self.shape[0], self.shape[1]);
802        let mut new_data = vec![0.0; rows * cols];
803        
804        for i in 0..rows {
805            for j in 0..cols {
806                new_data[j * rows + i] = self.data[i * cols + j];
807            }
808        }
809        
810        Ok(NDArray {
811            data: new_data,
812            shape: vec![cols, rows]
813        })
814    }
815
816    /// Performs matrix multiplication (dot product) between two 2D arrays
817    ///
818    /// # Arguments
819    ///
820    /// * `other` - The other NDArray to multiply with.
821    ///
822    /// # Returns
823    ///
824    /// A new NDArray resulting from the matrix multiplication.
825    pub fn dot(&self, other: &NDArray) -> Self {
826        assert_eq!(self.ndim(), 2, "Dot product is only defined for 2D arrays");
827        assert_eq!(other.ndim(), 2, "Dot product is only defined for 2D arrays");
828        assert_eq!(self.shape[1], other.shape[0], "Inner dimensions must match for dot product");
829
830        let rows = self.shape[0];
831        let cols = other.shape[1];
832        let mut result_data = vec![0.0; rows * cols];
833
834        for i in 0..rows {
835            for j in 0..cols {
836                let mut sum = 0.0;
837                for k in 0..self.shape[1] {
838                    sum += self.data[i * self.shape[1] + k] * other.data[k * other.shape[1] + j];
839                }
840                result_data[i * cols + j] = sum;
841            }
842        }
843
844        NDArray::new(result_data, vec![rows, cols])
845    }
846
847    /// Performs element-wise multiplication between two arrays
848    ///
849    /// # Arguments
850    ///
851    /// * `other` - The other NDArray to multiply with.
852    ///
853    /// # Returns
854    ///
855    /// A new NDArray resulting from the element-wise multiplication.
856    // pub fn multiply(&self, other: &NDArray) -> Self {
857    //     assert_eq!(self.shape, other.shape, "Shapes must match for element-wise multiplication");
858
859    //     let data: Vec<f64> = self.data.iter().zip(other.data.iter()).map(|(a, b)| a * b).collect();
860    //     NDArray::new(data, self.shape.clone())
861    // }
862    pub fn multiply(&self, other: &NDArray) -> Self {
863        assert_eq!(self.shape, other.shape, "Shapes must match for element-wise multiplication");
864        
865        let data = if self.data.len() > 1000 {
866            self.data.par_iter()
867                .zip(other.data.par_iter())
868                .map(|(&a, &b)| a * b)
869                .collect()
870        } else {
871            self.data.iter()
872                .zip(other.data.iter())
873                .map(|(&a, &b)| a * b)
874                .collect()    
875        };
876        
877        NDArray::new(data, self.shape.clone())
878    }
879
880   
881
882
883
884
885    /// Subtracts a scalar from each element in the array
886    ///
887    /// # Arguments
888    ///
889    /// * `scalar` - The scalar value to subtract.
890    ///
891    /// # Returns
892    ///
893    /// A new NDArray with the scalar subtracted from each element.
894    pub fn scalar_sub(&self, scalar: f64) -> Self {
895        let data: Vec<f64> = self.data.iter().map(|&x| x - scalar).collect();
896        NDArray::new(data, self.shape.clone())
897    }
898
899    /// Multiplies each element in the array by a scalar
900    ///
901    /// # Arguments
902    ///
903    /// * `scalar` - The scalar value to multiply.
904    ///
905    /// # Returns
906    ///
907    /// A new NDArray with each element multiplied by the scalar.
908    pub fn multiply_scalar(&self, scalar: f64) -> Self {
909        let data: Vec<f64> = self.data.iter().map(|&x| x * scalar).collect();
910        NDArray::new(data, self.shape.clone())
911    }
912
913    /// Clips the values in the array to a specified range
914    ///
915    /// # Arguments
916    ///
917    /// * `min` - The minimum value to clip to.
918    /// * `max` - The maximum value to clip to.
919    ///
920    /// # Returns
921    ///
922    /// A new NDArray with values clipped to the specified range.
923    pub fn clip(&self, min: f64, max: f64) -> Self {
924        let data: Vec<f64> = self.data.iter().map(|&x| x.clamp(min, max)).collect();
925        NDArray::new(data, self.shape.clone())
926    }
927
928    /// Performs element-wise division between two arrays
929    ///
930    /// # Arguments
931    ///
932    /// * `other` - The other NDArray to divide by.
933    ///
934    /// # Returns
935    ///
936    /// A new NDArray resulting from the element-wise division.
937    pub fn divide(&self, other: &NDArray) -> Self {
938        assert_eq!(self.shape, other.shape, "Shapes must match for element-wise division");
939
940        let data: Vec<f64> = self.data.iter().zip(other.data.iter()).map(|(a, b)| a / b).collect();
941        NDArray::new(data, self.shape.clone())
942    }
943
944    /// Divides each element in the array by a scalar
945    ///
946    /// # Arguments
947    ///
948    /// * `scalar` - The scalar value to divide by.
949    ///
950    /// # Returns
951    ///
952    /// A new NDArray with each element divided by the scalar.
953    pub fn divide_scalar(&self, scalar: f64) -> Self {
954        let data: Vec<f64> = self.data.iter().map(|&x| x / scalar).collect();
955        NDArray::new(data, self.shape.clone())
956    }
957
958    /// Sums the elements of the array along a specified axis
959    ///
960    /// # Arguments
961    ///
962    /// * `axis` - The axis along which to sum the elements.
963    ///
964    /// # Returns
965    ///
966    /// A new NDArray with the summed elements along the specified axis.
967    pub fn sum_axis(&self, axis: usize) -> Self {
968        if axis >= self.shape.len() {
969            panic!("Axis {} out of bounds for shape {:?}", axis, self.shape);
970        }
971
972        match axis {
973            0 => {
974                let cols = self.shape[1];
975                let mut result = vec![0.0; cols];
976                
977                for j in 0..cols {
978                    for i in 0..self.shape[0] {
979                        result[j] += self.data[i * cols + j];
980                    }
981                }
982                
983                NDArray::new(result, vec![1, cols])
984            },
985            1 => {
986                let cols = self.shape[1];
987                let mut result = vec![0.0; self.shape[0]];
988                
989                for i in 0..self.shape[0] {
990                    for j in 0..cols {
991                        result[i] += self.data[i * cols + j];
992                    }
993                }
994                
995                NDArray::new(result, vec![self.shape[0], 1])
996            },
997            _ => panic!("Unsupported axis {}", axis)
998        }
999    }
1000
1001    /// Performs element-wise subtraction between two arrays
1002    ///
1003    /// # Arguments
1004    ///
1005    /// * `other` - The other NDArray to subtract.
1006    ///
1007    /// # Returns
1008    ///
1009    /// A new NDArray resulting from the element-wise subtraction.
1010    pub fn subtract(&self, other: &NDArray) -> Self {
1011        assert_eq!(self.shape, other.shape, "Shapes must match for element-wise subtraction");
1012
1013        let data: Vec<f64> = self.data.iter().zip(other.data.iter()).map(|(a, b)| a - b).collect();
1014        NDArray::new(data, self.shape.clone())
1015    }
1016
1017    /// Adds a scalar to each element in the array
1018    ///
1019    /// # Arguments
1020    ///
1021    /// * `scalar` - The scalar value to add.
1022    ///
1023    /// # Returns
1024    ///
1025    /// A new NDArray with the scalar added to each element.
1026    pub fn add_scalar(&self, scalar: f64) -> Self {
1027        let data: Vec<f64> = self.data.iter().map(|&x| x + scalar).collect();
1028        NDArray::new(data, self.shape.clone())
1029    }
1030
1031    /// Calculates the natural logarithm of each element in the array
1032    ///
1033    /// # Returns
1034    ///
1035    /// A new NDArray with the natural logarithm of each element.
1036    pub fn log(&self) -> Self {
1037        let data: Vec<f64> = self.data.iter().map(|&x| x.ln()).collect();
1038        NDArray::new(data, self.shape.clone())
1039    }
1040
1041    /// Sums all elements in the array
1042    ///
1043    /// # Returns
1044    ///
1045    /// The sum of all elements as an f64.
1046    pub fn sum(&self) -> f64 {
1047        self.data.iter().sum()
1048    }
1049
1050    pub fn pad_to_size(&self, target_size: usize) -> Self {
1051        if self.shape[0] >= target_size {
1052            return self.clone();
1053        }
1054
1055        let mut new_shape = self.shape.clone();
1056        new_shape[0] = target_size;
1057        let total_size: usize = new_shape.iter().product();
1058        
1059        // Create new data vector with zeros
1060        let mut new_data = vec![0.0; total_size];
1061        
1062        // Copy existing data
1063        let row_size = self.shape.iter().skip(1).product::<usize>();
1064        let existing_data_size = self.shape[0] * row_size;
1065        new_data[..existing_data_size].copy_from_slice(&self.data);
1066        
1067        NDArray::new(new_data, new_shape)
1068    }
1069
1070    /// Add layer normalization
1071    pub fn layer_normalize(&self) -> Self {
1072        let (rows, cols) = (self.shape[0], self.shape[1]);
1073        let mut result = vec![0.0; self.data.len()];
1074        
1075        for i in 0..rows {
1076            let start = i * cols;
1077            let end = start + cols;
1078            let row = &self.data[start..end];
1079            
1080            // Calculate mean and variance
1081            let mean: f64 = row.iter().sum::<f64>() / cols as f64;
1082            let var: f64 = row.iter()
1083                .map(|&x| (x - mean).powi(2))
1084                .sum::<f64>() / cols as f64;
1085            let std = (var + 1e-5).sqrt();
1086            
1087            // Normalize
1088            for j in 0..cols {
1089                result[start + j] = (row[j] - mean) / std;
1090            }
1091        }
1092        
1093        NDArray::new(result, self.shape.clone())
1094    }
1095
    /// Applies batch normalization: each feature (column) of a 2D
    /// [batch_size, features] array is normalized to zero mean and unit
    /// variance across the batch dimension (rows), with epsilon 1e-5 added
    /// to the variance for numerical stability.
    pub fn batch_normalize(&self) -> Self {
        let (batch_size, features) = (self.shape[0], self.shape[1]);
        let mut result = vec![0.0; self.data.len()];
        
        // For each feature (column), stride through the flat row-major buffer.
        for j in 0..features {
            // Calculate mean and variance across the batch
            let mut mean = 0.0;
            let mut var = 0.0;
            
            // Calculate mean
            for i in 0..batch_size {
                mean += self.data[i * features + j];
            }
            mean /= batch_size as f64;
            
            // Calculate variance (population variance: divide by N)
            for i in 0..batch_size {
                var += (self.data[i * features + j] - mean).powi(2);
            }
            var /= batch_size as f64;
            
            // Normalize each element of the column in place in `result`.
            let std = (var + 1e-5).sqrt();
            for i in 0..batch_size {
                result[i * features + j] = (self.data[i * features + j] - mean) / std;
            }
        }
        
        NDArray::new(result, self.shape.clone())
    }
1128
1129
1130    pub fn add(self, other: &NDArray) -> Self {
1131        // println!("Adding arrays:");
1132        // println!("  Left shape: {:?}", self.shape);
1133        // println!("  Right shape: {:?}", other.shape);
1134
1135        // Handle broadcasting for shapes like [N, M] + [1, M]
1136        if self.shape.len() == other.shape.len() && 
1137           other.shape[0] == 1 && 
1138           self.shape[1] == other.shape[1] {
1139            
1140            // println!("  Performing broadcasting addition");
1141            let mut result_data = Vec::with_capacity(self.data.len());
1142            let cols = other.shape[1];
1143            
1144            // Add the broadcasted row to each row of self
1145            for i in 0..self.shape[0] {
1146                for j in 0..cols {
1147                    result_data.push(self.data[i * cols + j] + other.data[j]);
1148                }
1149            }
1150            
1151            let result = NDArray::new(result_data, self.shape.clone());
1152            // println!("  Result shape: {:?}", result.shape);
1153            return result;
1154        }
1155        
1156        // Regular element-wise addition for matching shapes
1157        if self.shape != other.shape {
1158            panic!("Shapes must match for element-wise addition\n  left: {:?}\n right: {:?}", 
1159                   self.shape, other.shape);
1160        }
1161
1162        let data: Vec<f64> = self.data.iter()
1163            .zip(other.data.iter())
1164            .map(|(a, b)| a + b)
1165            .collect();
1166
1167        NDArray::new(data, self.shape.clone())
1168    }
1169
1170    /// Returns the mean of the array
1171    pub fn mean(&self) -> f64 {
1172        self.sum() / self.data.len() as f64
1173    }
1174
1175    /// Returns the standard deviation of the array
1176    pub fn std(&self) -> f64 {
1177        let mean = self.mean();
1178        let variance = self.data.iter()
1179            .map(|&x| (x - mean).powi(2))
1180            .sum::<f64>() / self.data.len() as f64;
1181        variance.sqrt()
1182    }
1183
1184    /// Returns the minimum value along the specified axis
1185    pub fn min_axis(&self, axis: usize) -> Result<Self, &'static str> {
1186        if axis >= self.shape.len() {
1187            return Err("Axis out of bounds");
1188        }
1189
1190        match axis {
1191            0 => {
1192                if self.shape.len() != 2 {
1193                    return Err("min_axis(0) requires 2D array");
1194                }
1195                let cols = self.shape[1];
1196                let mut result = vec![f64::INFINITY; cols];
1197                
1198                for j in 0..cols {
1199                    for i in 0..self.shape[0] {
1200                        result[j] = result[j].min(self.data[i * cols + j]);
1201                    }
1202                }
1203                
1204                Ok(NDArray::new(result, vec![1, cols]))
1205            },
1206            1 => {
1207                if self.shape.len() != 2 {
1208                    return Err("min_axis(1) requires 2D array");
1209                }
1210                let cols = self.shape[1];
1211                let mut result = vec![f64::INFINITY; self.shape[0]];
1212                
1213                for i in 0..self.shape[0] {
1214                    for j in 0..cols {
1215                        result[i] = result[i].min(self.data[i * cols + j]);
1216                    }
1217                }
1218                
1219                Ok(NDArray::new(result, vec![self.shape[0], 1]))
1220            },
1221            _ => Err("Unsupported axis")
1222        }
1223    }
1224
    /// Concatenates two arrays along the specified axis
    ///
    /// Axis 0 stacks the arrays vertically (rows of `other` appended after
    /// rows of `self`); axis 1 stacks them horizontally (each row of `other`
    /// appended to the matching row of `self`). All dimensions other than
    /// `axis` must match.
    ///
    /// # Arguments
    ///
    /// * `other` - The array to append.
    /// * `axis` - The axis along which to concatenate (0 or 1).
    ///
    /// # Returns
    ///
    /// `Ok` with the concatenated array, or `Err` if the axis is invalid,
    /// the dimensionalities differ, or the non-axis dimensions mismatch.
    pub fn concatenate(&self, other: &Self, axis: usize) -> Result<Self, &'static str> {
        if axis >= self.shape.len() {
            return Err("Axis out of bounds");
        }
        
        if self.shape.len() != other.shape.len() {
            return Err("Arrays must have same number of dimensions");
        }
        
        // Check that all dimensions except axis match
        for (i, (&s1, &s2)) in self.shape.iter().zip(other.shape.iter()).enumerate() {
            if i != axis && s1 != s2 {
                return Err("All dimensions except concatenation axis must match");
            }
        }
        
        let mut new_shape = self.shape.clone();
        new_shape[axis] += other.shape[axis];
        
        let mut new_data = Vec::with_capacity(self.data.len() + other.data.len());
        
        match axis {
            0 => {
                // Row-major layout: vertical stacking is a plain append.
                new_data.extend_from_slice(&self.data);
                new_data.extend_from_slice(&other.data);
            },
            1 => {
                // Horizontal stacking: interleave one row from each source.
                let rows = self.shape[0];
                let cols1 = self.shape[1];
                let cols2 = other.shape[1];
                
                for i in 0..rows {
                    new_data.extend_from_slice(&self.data[i * cols1..(i + 1) * cols1]);
                    new_data.extend_from_slice(&other.data[i * cols2..(i + 1) * cols2]);
                }
            },
            _ => return Err("Unsupported axis")
        }
        
        Ok(NDArray::new(new_data, new_shape))
    }
1267
1268    pub fn map<F>(&self, f: F) -> Self 
1269    where F: Fn(f64) -> f64 
1270    {
1271        let new_data: Vec<f64> = self.data.iter().map(|&x| f(x)).collect();
1272        NDArray::new(new_data, self.shape.clone())
1273    }
1274
1275    /// Returns the absolute values of array elements
1276    ///
1277    /// # Returns
1278    ///
1279    /// A new NDArray with absolute values
1280    pub fn abs(&self) -> Self {
1281        self.map(|x| x.abs())
1282    }
1283
1284    /// Returns the exponential power of array elements
1285    ///
1286    /// # Returns
1287    ///
1288    /// A new NDArray with exponential values
1289    pub fn power(&self, n: f64) -> Self {
1290        self.map(|x| x.powf(n))
1291    }
1292
1293    /// Returns the cumulative sum of array elements
1294    ///
1295    /// # Returns
1296    ///
1297    /// A new NDArray with cumulative sums
1298    pub fn cumsum(&self) -> Self {
1299        let mut result = Vec::with_capacity(self.data.len());
1300        let mut sum = 0.0;
1301        for &x in &self.data {
1302            sum += x;
1303            result.push(sum);
1304        }
1305        NDArray::new(result, self.shape.clone())
1306    }
1307
1308    /// Returns array with elements rounded to specified decimals
1309    ///
1310    /// # Arguments
1311    ///
1312    /// * `decimals` - Number of decimal places to round to
1313    ///
1314    /// # Returns
1315    ///
1316    /// A new NDArray with rounded values
1317    pub fn round(&self, decimals: i32) -> Self {
1318        let factor = 10.0_f64.powi(decimals);
1319        self.map(|x| (x * factor).round() / factor)
1320    }
1321
1322    /// Returns indices that would sort the array
1323    ///
1324    /// # Returns
1325    ///
1326    /// A vector of indices that would sort the array
1327    pub fn argsort(&self) -> Vec<usize> {
1328        let mut indices: Vec<usize> = (0..self.data.len()).collect();
1329        indices.sort_by(|&i, &j| self.data[i].partial_cmp(&self.data[j]).unwrap());
1330        indices
1331    }
1332
1333    /// Returns unique elements of the array
1334    ///
1335    /// # Returns
1336    ///
1337    /// A new NDArray containing unique elements in sorted order
1338    pub fn unique(&self) -> Self {
1339        let mut unique_vals = self.data.clone();
1340        unique_vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
1341        unique_vals.dedup();
1342        NDArray::new(unique_vals.to_vec(), vec![unique_vals.len()])
1343    }
1344
1345    /// Applies a condition element-wise and returns a new array
1346    ///
1347    /// # Arguments
1348    ///
1349    /// * `condition` - Function that returns true/false for each element
1350    /// * `x` - Value to use where condition is true
1351    /// * `y` - Value to use where condition is false
1352    ///
1353    /// # Returns
1354    ///
1355    /// A new NDArray with values chosen based on condition
1356    pub fn where_cond<F>(&self, condition: F, x: f64, y: f64) -> Self 
1357    where F: Fn(f64) -> bool 
1358    {
1359        self.map(|val| if condition(val) { x } else { y })
1360    }
1361
1362    /// Returns the median value of the array
1363    ///
1364    /// # Returns
1365    ///
1366    /// The median value as f64
1367    pub fn median(&self) -> f64 {
1368        let mut sorted = self.data.clone();
1369        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
1370        let mid = sorted.len() / 2;
1371        if sorted.len() % 2 == 0 {
1372            (sorted[mid - 1] + sorted[mid]) / 2.0
1373        } else {
1374            sorted[mid]
1375        }
1376    }
1377
1378    /// Returns the maximum values along the specified axis
1379    ///
1380    /// # Arguments
1381    ///
1382    /// * `axis` - Axis along which to find maximum values
1383    ///
1384    /// # Returns
1385    ///
1386    /// NDArray containing maximum values along specified axis
1387    pub fn max_axis(&self, axis: usize) -> Self {
1388        if axis >= self.shape.len() {
1389            panic!("Axis {} out of bounds for shape {:?}", axis, self.shape);
1390        }
1391
1392        // Handle 1D array case
1393        if self.shape.len() == 1 {
1394            return NDArray::new(vec![self.data.iter().cloned().fold(f64::NEG_INFINITY, f64::max)], vec![1]);
1395        }
1396
1397        // Handle 2D array case
1398        match axis {
1399            0 => {
1400                let cols = self.shape[1];
1401                let mut result = vec![f64::NEG_INFINITY; cols];
1402                
1403                for j in 0..cols {
1404                    for i in 0..self.shape[0] {
1405                        result[j] = result[j].max(self.data[i * cols + j]);
1406                    }
1407                }
1408                
1409                NDArray::new(result, vec![1, cols])
1410            },
1411            1 => {
1412                let cols = self.shape[1];
1413                let mut result = vec![f64::NEG_INFINITY; self.shape[0]];
1414                
1415                for i in 0..self.shape[0] {
1416                    for j in 0..cols {
1417                        result[i] = result[i].max(self.data[i * cols + j]);
1418                    }
1419                }
1420                
1421                NDArray::new(result, vec![self.shape[0], 1])
1422            },
1423            _ => panic!("Unsupported axis {}", axis)
1424        }
1425    }
1426
    /// Returns a string representation of the array, showing its shape and
    /// flat data buffer using debug formatting. Also backs the `Display`
    /// impl.
    pub fn display(&self) -> String {
        format!("NDArray(shape={:?}, data={:?})", self.shape, self.data)
    }
1431
1432    /// Creates a new NDArray with random uniform values between 0 and 1
1433    /// 
1434    /// # Arguments
1435    /// 
1436    /// * `shape` - Shape of the array
1437    /// 
1438    /// # Example
1439    /// ```
1440    /// use nabla_ml::nab_array::NDArray;
1441    /// 
1442    /// let arr = NDArray::rand_uniform(&[2, 3]);
1443    /// assert_eq!(arr.shape(), vec![2, 3]);
1444    /// ```
1445    pub fn rand_uniform(shape: &[usize]) -> Self {
1446        let size: usize = shape.iter().product();
1447        let uniform = Uniform::new(0.0, 1.0);
1448        let mut rng = rand::thread_rng();
1449        
1450        let data: Vec<f64> = (0..size)
1451            .map(|_| uniform.sample(&mut rng))
1452            .collect();
1453
1454        Self::new(data, shape.to_vec())
1455    }
1456
1457    /// Calculates the mean along the specified axis
1458    /// 
1459    /// # Arguments
1460    /// * `axis` - Axis along which to calculate mean (0 for columns, 1 for rows)
1461    pub fn mean_axis(&self, axis: usize) -> Self {
1462        let sum = self.sum_axis(axis);
1463        let n = if axis == 0 { self.shape[0] } else { self.shape[1] } as f64;
1464        sum.multiply_scalar(1.0 / n)
1465    }
1466
1467    /// Calculates the variance along the specified axis
1468    /// 
1469    /// # Arguments
1470    /// * `axis` - Axis along which to calculate variance (0 for columns, 1 for rows)
1471    pub fn var_axis(&self, axis: usize) -> Self {
1472        let mean = self.mean_axis(axis);
1473        
1474        // For axis 0, we need to broadcast the mean to match original shape
1475        let broadcasted_mean = if axis == 0 {
1476            let mut result = Vec::with_capacity(self.data.len());
1477            let cols = self.shape[1];
1478            
1479            // Repeat mean values for each row
1480            for _ in 0..self.shape[0] {
1481                for j in 0..cols {
1482                    result.push(mean.data[j]);
1483                }
1484            }
1485            
1486            NDArray::new(result, self.shape.clone())
1487        } else {
1488            mean
1489        };
1490
1491        let centered = self.subtract(&broadcasted_mean);
1492        let squared = centered.multiply(&centered);
1493        let n = if axis == 0 { self.shape[0] } else { self.shape[1] } as f64;
1494        squared.sum_axis(axis).multiply_scalar(1.0 / n)
1495    }
1496
    /// Converts a class vector (integers) to binary class matrix (one-hot encoding)
    /// 
    /// # Arguments
    /// 
    /// * `num_classes` - Optional number of classes. If None, it will be inferred from the data
    /// 
    /// # Returns
    /// 
    /// A 2D NDArray where each row is a one-hot encoded vector
    /// 
    /// # Example
    /// 
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// 
    /// let labels = NDArray::from_vec(vec![0.0, 1.0, 2.0]);
    /// let categorical = labels.to_categorical(None);
    /// assert_eq!(categorical.shape(), &[3, 3]);
    /// assert_eq!(categorical.data(), &[1.0, 0.0, 0.0,
    ///                                 0.0, 1.0, 0.0,
    ///                                 0.0, 0.0, 1.0]);
    /// ```
    pub fn to_categorical(&self, num_classes: Option<usize>) -> Self {
        // Ensure input is 1D
        assert_eq!(self.ndim(), 1, "Input must be a 1D array");
        
        // Find min and max labels to handle negative values.
        // NOTE(review): on an empty array these folds yield +/-inf and the
        // `as i32` casts saturate, so inferring `num_classes` from empty
        // data would overflow below — confirm callers never pass an empty
        // array with `num_classes == None`.
        let min_label = self.data().iter()
            .fold(f64::INFINITY, |a, &b| a.min(b)) as i32;
        let max_label = self.data().iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b)) as i32;
            
        // Determine number of classes (inferred as the inclusive label range)
        let n_classes = num_classes.unwrap_or_else(|| 
            (max_label - min_label + 1) as usize
        );
        
        let n_samples = self.shape()[0];
        let mut categorical = vec![0.0; n_samples * n_classes];
        
        // Fill the categorical array: one hot bit per sample row
        for (sample_idx, &label) in self.data().iter().enumerate() {
            // Shift label to be non-negative
            let shifted_label = (label as i32 - min_label) as usize;
            assert!(shifted_label < n_classes, 
                "Label {} is out of range for {} classes", label, n_classes);
            
            let row_offset = sample_idx * n_classes;
            categorical[row_offset + shifted_label] = 1.0;
        }
        
        NDArray::new(categorical, vec![n_samples, n_classes])
    }
1550}
1551
1552impl Add for NDArray {
1553    type Output = Self;
1554
1555    fn add(self, other: Self) -> Self::Output {
1556        assert_eq!(self.shape, other.shape, "Shapes must match for element-wise addition");
1557        let data = self.data.iter().zip(other.data.iter()).map(|(a, b)| a + b).collect();
1558        NDArray::new(data, self.shape.clone())
1559    }
1560}
1561
1562impl Add<&NDArray> for NDArray {
1563    type Output = Self;
1564
1565    fn add(self, other: &NDArray) -> Self::Output {
1566        // println!("Adding arrays:");
1567        // println!("  Left shape: {:?}", self.shape);
1568        // println!("  Right shape: {:?}", other.shape);
1569
1570        // Handle broadcasting for shapes like [N, M] + [1, M]
1571        if self.shape.len() == other.shape.len() && 
1572           other.shape[0] == 1 && 
1573           self.shape[1] == other.shape[1] {
1574            
1575            // println!("  Performing broadcasting addition");
1576            let mut result_data = Vec::with_capacity(self.data.len());
1577            let cols = other.shape[1];
1578            
1579            // Add the broadcasted row to each row of self
1580            for i in 0..self.shape[0] {
1581                for j in 0..cols {
1582                    result_data.push(self.data[i * cols + j] + other.data[j]);
1583                }
1584            }
1585            
1586            let result = NDArray::new(result_data, self.shape.clone());
1587            // println!("  Result shape: {:?}", result.shape);
1588            return result;
1589        }
1590        
1591        // Regular element-wise addition for matching shapes
1592        if self.shape != other.shape {
1593            panic!("Shapes must match for element-wise addition\n  left: {:?}\n right: {:?}", 
1594                   self.shape, other.shape);
1595        }
1596
1597        let data: Vec<f64> = self.data.iter()
1598            .zip(other.data.iter())
1599            .map(|(a, b)| a + b)
1600            .collect();
1601
1602        NDArray::new(data, self.shape.clone())
1603    }
1604}
1605
1606impl Sub for NDArray {
1607    type Output = Self;
1608
1609    fn sub(self, other: Self) -> Self::Output {
1610        assert_eq!(self.shape, other.shape, "Shapes must match for element-wise subtraction");
1611        let data = self.data.iter().zip(other.data.iter()).map(|(a, b)| a - b).collect();
1612        NDArray::new(data, self.shape.clone())
1613    }
1614}
1615
1616impl Mul<f64> for NDArray {
1617    type Output = Self;
1618
1619    fn mul(self, scalar: f64) -> Self::Output {
1620        let data = self.data.iter().map(|a| a * scalar).collect();
1621        NDArray::new(data, self.shape.clone())
1622    }
1623}
1624
1625
impl Add<f64> for NDArray {
    type Output = Self;

    /// Adds `scalar` to every element, delegating to [`NDArray::add_scalar`].
    fn add(self, scalar: f64) -> Self::Output {
        self.add_scalar(scalar)
    }
}
1633
impl Mul<&NDArray> for f64 {
    type Output = NDArray;

    /// Enables `scalar * &array` syntax, delegating to
    /// [`NDArray::multiply_scalar`].
    fn mul(self, rhs: &NDArray) -> NDArray {
        rhs.multiply_scalar(self)
    }
}
1641
1642// Add std::fmt::Display implementation for convenient printing
1643impl std::fmt::Display for NDArray {
1644    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1645        write!(f, "{}", self.display())
1646    }
1647}
1648
1649impl Sub<&NDArray> for NDArray {
1650    type Output = Self;
1651
1652    fn sub(self, other: &NDArray) -> Self::Output {
1653        if self.shape != other.shape {
1654            panic!("Shapes must match for element-wise subtraction");
1655        }
1656        let data: Vec<f64> = self.data.iter()
1657            .zip(other.data.iter())
1658            .map(|(a, b)| a - b)
1659            .collect();
1660        NDArray::new(data, self.shape.clone())
1661    }
1662}
1663
1664/// Implements element-wise subtraction between two NDArray references
1665///
1666/// # Arguments
1667///
1668/// * `self` - The first NDArray reference
1669/// * `other` - The second NDArray reference to subtract from the first
1670///
1671/// # Returns
1672///
1673/// A new NDArray containing the element-wise difference
1674///
1675/// # Panics
1676///
1677/// Panics if the shapes of the two arrays don't match
1678impl<'a, 'b> Sub<&'b NDArray> for &'a NDArray {
1679    type Output = NDArray;
1680
1681    fn sub(self, other: &'b NDArray) -> NDArray {
1682        if self.shape != other.shape {
1683            panic!("Shapes must match for element-wise subtraction");
1684        }
1685        let data: Vec<f64> = self.data.iter()
1686            .zip(other.data.iter())
1687            .map(|(a, b)| a - b)
1688            .collect();
1689        NDArray::new(data, self.shape.clone())
1690    }
1691}
1692
1693/// Implements element-wise addition between two NDArray references
1694///
1695/// # Arguments
1696///
1697/// * `self` - The first NDArray reference
1698/// * `other` - The second NDArray reference to add to the first
1699///
1700/// # Returns
1701///
1702/// A new NDArray containing the element-wise sum
1703///
1704/// # Panics
1705///
1706/// Panics if the shapes of the two arrays don't match
1707impl<'a, 'b> Add<&'b NDArray> for &'a NDArray {
1708    type Output = NDArray;
1709
1710    fn add(self, other: &'b NDArray) -> NDArray {
1711        if self.shape != other.shape {
1712            panic!("Shapes must match for element-wise addition");
1713        }
1714        let data: Vec<f64> = self.data.iter()
1715            .zip(other.data.iter())
1716            .map(|(a, b)| a + b)
1717            .collect();
1718        NDArray::new(data, self.shape.clone())
1719    }
1720}
1721
1722/// Implements element-wise multiplication between two NDArray references
1723///
1724/// # Arguments
1725///
1726/// * `self` - The first NDArray reference
1727/// * `other` - The second NDArray reference to multiply with the first
1728///
1729/// # Returns
1730///
1731/// A new NDArray containing the element-wise product
1732///
1733/// # Panics
1734///
1735/// Panics if the shapes of the two arrays don't match
1736impl<'a, 'b> Mul<&'b NDArray> for &'a NDArray {
1737    type Output = NDArray;
1738
1739    fn mul(self, other: &'b NDArray) -> NDArray {
1740        if self.shape != other.shape {
1741            panic!("Shapes must match for element-wise multiplication");
1742        }
1743        let data: Vec<f64> = self.data.iter()
1744            .zip(other.data.iter())
1745            .map(|(a, b)| a * b)
1746            .collect();
1747        NDArray::new(data, self.shape.clone())
1748    }
1749}
1750
1751/// Implements element-wise division between two NDArray references
1752///
1753/// # Arguments
1754///
1755/// * `self` - The first NDArray reference (numerator)
1756/// * `other` - The second NDArray reference (denominator)
1757///
1758/// # Returns
1759///
1760/// A new NDArray containing the element-wise quotient
1761///
1762/// # Panics
1763///
1764/// Panics if the shapes of the two arrays don't match
1765impl<'a, 'b> Div<&'b NDArray> for &'a NDArray {
1766    type Output = NDArray;
1767
1768    fn div(self, other: &'b NDArray) -> NDArray {
1769        if self.shape != other.shape {
1770            panic!("Shapes must match for element-wise division");
1771        }
1772        let data: Vec<f64> = self.data.iter()
1773            .zip(other.data.iter())
1774            .map(|(a, b)| a / b)
1775            .collect();
1776        NDArray::new(data, self.shape.clone())
1777    }
1778}
1779
1780#[cfg(test)]
1781mod tests {
1782    use super::*;
1783   
1784
    /// Tests basic NDArray creation with explicit data and shape
    #[test]
    fn test_new_ndarray() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = vec![2, 2];
        // Clone so the originals stay available for the assertions below.
        let array = NDArray::new(data.clone(), shape.clone());
        assert_eq!(array.data(), &data);
        assert_eq!(array.shape(), &shape);
    }
1794
    /// Tests creation of 1D array from vector
    #[test]
    fn test_from_vec() {
        let data = vec![1.0, 2.0, 3.0];
        let array = NDArray::from_vec(data.clone());
        assert_eq!(array.data(), &data);
        // from_vec infers a 1D shape equal to the input length.
        assert_eq!(array.shape(), &[3]);
    }
1803
    /// Tests array creation with evenly spaced values
    #[test]
    fn test_arange() {
        // Half-open range: the stop value 5.0 itself is excluded.
        let array = NDArray::arange(0.0, 5.0, 1.0);
        assert_eq!(array.data(), &[0.0, 1.0, 2.0, 3.0, 4.0]);
    }
1810
1811    /// Tests element-wise addition between two arrays
1812    #[test]
1813    fn test_element_wise_addition() {
1814        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
1815        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
1816        let sum = arr1.clone() + arr2;
1817        assert_eq!(sum.data(), &[5.0, 7.0, 9.0]);
1818    }
1819
1820    /// Tests multiplication of array by scalar value
1821    #[test]
1822    fn test_scalar_multiplication() {
1823        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
1824        let scaled = arr.clone() * 2.0;
1825        assert_eq!(scaled.data(), &[2.0, 4.0, 6.0]);
1826    }
1827
1828    /// Tests reshaping array to new dimensions while preserving data
1829    #[test]
1830    fn test_reshape() {
1831        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1832        let reshaped = arr.reshape(&[2, 3])
1833            .expect("Failed to reshape array to valid dimensions");
1834        assert_eq!(reshaped.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1835    }
1836
1837    /// Tests element-wise subtraction between arrays
1838    #[test]
1839    fn test_element_wise_subtraction() {
1840        let arr1 = NDArray::from_vec(vec![5.0, 7.0, 9.0]);
1841        let arr2 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
1842        let diff = arr1 - arr2;
1843        assert_eq!(diff.data(), &[4.0, 5.0, 6.0]);
1844    }
1845
1846    /// Tests addition of scalar to array
1847    #[test]
1848    fn test_scalar_addition() {
1849        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
1850        let result = arr + 1.0;
1851        assert_eq!(result.data(), &[2.0, 3.0, 4.0]);
1852    }
1853
1854    /// Tests combination of multiple operations in sequence
1855    #[test]
1856    #[allow(non_snake_case)]
1857    fn test_combined_operations() {
1858        let X = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
1859        let theta_1 = 2.0;
1860        let theta_0 = 1.0;
1861        let predictions = X.clone() * theta_1 + theta_0;
1862        assert_eq!(predictions.data(), &[3.0, 5.0, 7.0]);
1863    }
1864
1865    /// Tests one-hot encoding of label vectors
1866    #[test]
1867    fn test_one_hot_encode() {
1868        let labels = NDArray::from_vec(vec![0.0, 1.0, 2.0, 1.0, 0.0]);
1869        let one_hot = NDArray::one_hot_encode(&labels);
1870        
1871        let expected = vec![
1872            1.0, 0.0, 0.0,
1873            0.0, 1.0, 0.0,
1874            0.0, 0.0, 1.0,
1875            0.0, 1.0, 0.0,
1876            1.0, 0.0, 0.0
1877        ];
1878        
1879        assert_eq!(one_hot.shape(), &[5, 3]);
1880        assert_eq!(one_hot.data(), &expected);
1881    }
1882
1883    /// Tests one-hot encoding with negative label values
1884    #[test]
1885    fn test_one_hot_encode_negative() {
1886        let labels = NDArray::from_vec(vec![-1.0, 0.0, 1.0, 0.0, -1.0]);
1887        let one_hot = NDArray::one_hot_encode(&labels);
1888        
1889        let expected = vec![
1890            1.0, 0.0, 0.0,
1891            0.0, 1.0, 0.0,
1892            0.0, 0.0, 1.0,
1893            0.0, 1.0, 0.0,
1894            1.0, 0.0, 0.0
1895        ];
1896        
1897        assert_eq!(one_hot.shape(), &[5, 3]);
1898        assert_eq!(one_hot.data(), &expected);
1899    }
1900
1901    /// Tests that one-hot encoding fails with non-integer values
1902    #[test]
1903    #[should_panic(expected = "All values must be integers for one-hot encoding")]
1904    fn test_one_hot_encode_non_integer() {
1905        let labels = NDArray::from_vec(vec![0.0, 1.5, 2.0]);
1906        NDArray::one_hot_encode(&labels);
1907    }
1908
1909    /// Tests matrix transposition operation
1910    #[test]
1911    fn test_transpose() {
1912        let arr = NDArray::from_matrix(vec![
1913            vec![1.0, 2.0, 3.0],
1914            vec![4.0, 5.0, 6.0]
1915        ]);
1916        let transposed = arr.transpose()
1917            .expect("Failed to transpose valid 2D array");
1918        assert_eq!(transposed.shape(), &[3, 2]);
1919        assert_eq!(transposed.data(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
1920    }
1921
1922    /// Tests matrix multiplication (dot product)
1923    #[test]
1924    fn test_dot() {
1925        let arr1 = NDArray::from_matrix(vec![
1926            vec![1.0, 2.0, 3.0],
1927            vec![4.0, 5.0, 6.0],
1928        ]);
1929        let arr2 = NDArray::from_matrix(vec![
1930            vec![7.0, 8.0],
1931            vec![9.0, 10.0],
1932            vec![11.0, 12.0],
1933        ]);
1934        let dot = arr1.dot(&arr2);
1935        assert_eq!(dot.data(), &[58.0, 64.0, 139.0, 154.0]); // Adjust expected values based on the dot product calculation
1936    }
1937
1938    /// Tests element-wise multiplication between arrays
1939    #[test]
1940    fn test_multiply() {
1941        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
1942        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
1943        let multiply = arr1.multiply(&arr2);
1944        assert_eq!(multiply.data(), &[4.0, 10.0, 18.0]);
1945    }
1946
1947    #[test]
1948    fn test_multiply_large() {
1949        // Test small array (< 1000 elements)
1950        let arr1 = NDArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
1951        let arr2 = NDArray::new(vec![2.0, 3.0, 4.0, 5.0], vec![2, 2]);
1952        let result = arr1.multiply(&arr2);
1953        assert_eq!(result.data(), &[2.0, 6.0, 12.0, 20.0]);
1954
1955        // Test large array (> 1000 elements) 
1956        let large_arr1 = NDArray::new(vec![1.0; 2000], vec![1000, 2]);
1957        let large_arr2 = NDArray::new(vec![2.0; 2000], vec![1000, 2]);
1958        let large_result = large_arr1.multiply(&large_arr2);
1959        assert_eq!(large_result.data(), &vec![2.0; 2000]);
1960
1961        // Test mismatched shapes
1962        let arr3 = NDArray::new(vec![1.0, 2.0], vec![2, 1]);
1963        let arr4 = NDArray::new(vec![1.0, 2.0, 3.0], vec![3, 1]);
1964        let result = std::panic::catch_unwind(|| arr3.multiply(&arr4));
1965        assert!(result.is_err());
1966    }
1967
1968    /// Tests subtraction of scalar from array
1969    #[test]
1970    fn test_scalar_sub() {
1971        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
1972        let result = arr.scalar_sub(1.0);
1973        assert_eq!(result.data(), &[0.0, 1.0, 2.0]);
1974    }
1975
1976    /// Tests multiplication by scalar value
1977    #[test]
1978    fn test_multiply_scalar() {
1979        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
1980        let result = arr.multiply_scalar(2.0);
1981        assert_eq!(result.data(), &[2.0, 4.0, 6.0]);
1982    }
1983
1984    /// Tests mapping function across array elements
1985    #[test]
1986    fn test_map() {
1987        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
1988        let result = arr.map(|x| x * 2.0);
1989        assert_eq!(result.data(), &[2.0, 4.0, 6.0]);
1990    }
1991
1992    /// Tests clipping array values to specified range
1993    #[test]
1994    fn test_clip() {
1995        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
1996        let result = arr.clip(1.0, 2.0);
1997        assert_eq!(result.data(), &[1.0, 2.0, 2.0]);
1998    }
1999
2000    /// Tests element-wise division between arrays
2001    #[test]
2002    fn test_divide() {
2003        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2004        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
2005        let divide = arr1.divide(&arr2);
2006        assert_eq!(divide.data(), &[0.25, 0.4, 0.5]);
2007    }
2008
2009    /// Tests division by scalar value
2010    #[test]
2011    fn test_divide_scalar() {
2012        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2013        let result = arr.divide_scalar(2.0);
2014        assert_eq!(result.data(), &[0.5, 1.0, 1.5]);
2015    }
2016
2017    /// Tests sum operation along specified axis
2018    #[test]
2019    fn test_sum_axis() {
2020        let arr = NDArray::from_matrix(vec![
2021            vec![1.0, 2.0, 3.0],
2022            vec![4.0, 5.0, 6.0],
2023        ]);
2024        let result = arr.sum_axis(0);
2025        assert_eq!(result.data(), &[5.0, 7.0, 9.0]); // Sum along columns
2026        assert_eq!(result.shape(), &[1, 3]); // Shape should be [1, 3]
2027
2028        let result = arr.sum_axis(1);
2029        assert_eq!(result.data(), &[6.0, 15.0]); // Sum along rows
2030        assert_eq!(result.shape(), &[2, 1]); // Shape should be [2, 1]
2031    }
2032
2033    /// Tests element-wise subtraction between arrays
2034    #[test]
2035    fn test_subtract() {
2036        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2037        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
2038        let subtract = arr1.subtract(&arr2);
2039        assert_eq!(subtract.data(), &[-3.0, -3.0, -3.0]);
2040    }
2041
2042    /// Tests addition of scalar to array
2043    #[test]
2044    fn test_add_scalar() {
2045        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2046        let result = arr.add_scalar(1.0);
2047        assert_eq!(result.data(), &[2.0, 3.0, 4.0]);
2048    }
2049
2050    /// Tests creation of zero-filled array
2051    #[test]
2052    fn test_zeros() {
2053        let shape = vec![2, 3];
2054        let zeros = NDArray::zeros(shape);
2055        assert_eq!(zeros.shape(), &[2, 3]);
2056        assert_eq!(zeros.data(), &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
2057    }
2058
2059    /// Tests natural logarithm of array elements
2060    #[test]
2061    fn test_log() {
2062        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2063        let result = arr.log();
2064        assert_eq!(result.data(), &[0.0, 0.6931471805599453, 1.0986122886681098]);
2065    }
2066
2067    /// Tests sum of all array elements
2068    #[test]
2069    fn test_sum() {
2070        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2071        let result = arr.sum();
2072        assert_eq!(result, 6.0);
2073    }
2074
2075    /// Tests calculation of array mean
2076    #[test]
2077    fn test_mean() {
2078        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
2079        assert_eq!(arr.mean(), 2.5);
2080    }
2081
2082    /// Tests calculation of array standard deviation
2083    #[test]
2084    fn test_std() {
2085        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
2086        assert!((arr.std() - 1.118034).abs() < 1e-6);
2087    }
2088
2089    /// Tests finding minimum values along specified axis
2090    #[test]
2091    fn test_min_axis() {
2092        let arr = NDArray::from_matrix(vec![
2093            vec![1.0, 2.0, 3.0],
2094            vec![4.0, 0.5, 6.0],
2095        ]);
2096        
2097        let min_axis_0 = arr.min_axis(0).unwrap();
2098        assert_eq!(min_axis_0.data(), &[1.0, 0.5, 3.0]);
2099        
2100        let min_axis_1 = arr.min_axis(1).unwrap();
2101        assert_eq!(min_axis_1.data(), &[1.0, 0.5]);
2102    }
2103
2104    /// Tests array concatenation along specified axis
2105    #[test]
2106    fn test_concatenate() {
2107        let arr1 = NDArray::from_matrix(vec![
2108            vec![1.0, 2.0],
2109            vec![3.0, 4.0],
2110        ]);
2111        let arr2 = NDArray::from_matrix(vec![
2112            vec![5.0, 6.0],
2113            vec![7.0, 8.0],
2114        ]);
2115        
2116        let concat_0 = arr1.concatenate(&arr2, 0).unwrap();
2117        assert_eq!(concat_0.shape(), &[4, 2]);
2118        assert_eq!(concat_0.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
2119        
2120        let concat_1 = arr1.concatenate(&arr2, 1).unwrap();
2121        assert_eq!(concat_1.shape(), &[2, 4]);
2122        assert_eq!(concat_1.data(), &[1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]);
2123    }
2124
2125    /// Tests broadcasting addition between arrays of different shapes
2126    #[test]
2127    fn test_broadcast_addition() {
2128        let arr1 = NDArray::from_matrix(vec![
2129            vec![1.0, 2.0, 3.0],
2130            vec![4.0, 5.0, 6.0]
2131        ]);
2132        let arr2 = NDArray::from_matrix(vec![
2133            vec![1.0, 2.0, 3.0]
2134        ]);
2135        let result = arr1 + &arr2;
2136        assert_eq!(result.shape(), &[2, 3]);
2137        assert_eq!(result.data(), &[2.0, 4.0, 6.0, 5.0, 7.0, 9.0]);
2138    }
2139
2140    /// Tests finding maximum value in array
2141    #[test]
2142    fn test_max() {
2143        let arr = NDArray::from_vec(vec![1.0, 5.0, 3.0, 2.0]);
2144        assert_eq!(arr.max(), 5.0);
2145    }
2146
2147    /// Tests sum operation with broadcasting
2148    #[test]
2149    fn test_sum_with_broadcasting() {
2150        let arr = NDArray::from_matrix(vec![
2151            vec![1.0, 2.0, 3.0],
2152            vec![4.0, 5.0, 6.0]
2153        ]);
2154        let sum_cols = arr.sum_axis(0);
2155        assert_eq!(sum_cols.shape(), &[1, 3]);
2156        assert_eq!(sum_cols.data(), &[5.0, 7.0, 9.0]);
2157
2158        // Test broadcasting the sum back
2159        let result = arr + &sum_cols;
2160        assert_eq!(result.data(), &[6.0, 9.0, 12.0, 9.0, 12.0, 15.0]);
2161    }
2162
2163    /// Tests various scalar operations (multiplication and addition)
2164    #[test]
2165    fn test_scalar_operations() {
2166        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2167        
2168        // Test scalar multiplication from both sides
2169        let result1 = arr.clone() * 2.0;
2170        let result2 = 2.0 * &arr;
2171        assert_eq!(result1.data(), result2.data());
2172        
2173        // Test scalar addition
2174        let result3 = arr + 1.0;
2175        assert_eq!(result3.data(), &[2.0, 3.0, 4.0]);
2176    }
2177
2178    /// Tests error handling for invalid array addition
2179    #[test]
2180    #[should_panic(expected = "Shapes must match for element-wise addition")]
2181    fn test_invalid_addition() {
2182        let arr1 = NDArray::from_matrix(vec![vec![1.0, 2.0]]);
2183        let arr2 = NDArray::from_matrix(vec![vec![1.0, 2.0, 3.0]]);
2184        let _result = arr1 + arr2;
2185    }
2186
2187    /// Tests chaining multiple operations together
2188    #[test]
2189    fn test_chained_operations() {
2190        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2191        let result = (arr * 2.0 + 1.0).multiply_scalar(3.0);
2192        assert_eq!(result.data(), &[9.0, 15.0, 21.0]);
2193    }
2194
2195    /// Tests absolute value calculation
2196    #[test]
2197    fn test_abs() {
2198        let arr = NDArray::from_vec(vec![-1.0, 2.0, -3.0]);
2199        let result = arr.abs();
2200        assert_eq!(result.data(), &[1.0, 2.0, 3.0]);
2201    }
2202
2203    /// Tests exponential power calculation
2204    #[test]
2205    fn test_power() {
2206        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2207        let result = arr.power(2.0);
2208        assert_eq!(result.data(), &[1.0, 4.0, 9.0]);
2209    }
2210
2211    /// Tests cumulative sum calculation
2212    #[test]
2213    fn test_cumsum() {
2214        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
2215        let result = arr.cumsum();
2216        assert_eq!(result.data(), &[1.0, 3.0, 6.0, 10.0]);
2217    }
2218
2219    /// Tests rounding to specified decimals
2220    #[test]
2221    fn test_round() {
2222        let arr = NDArray::from_vec(vec![1.234, 2.345, 3.456]);
2223        let result = arr.round(2);
2224        assert_eq!(result.data(), &[1.23, 2.35, 3.46]);
2225    }
2226
2227    /// Tests getting indices that would sort the array
2228    #[test]
2229    fn test_argsort() {
2230        let arr = NDArray::from_vec(vec![3.0, 1.0, 2.0]);
2231        let indices = arr.argsort();
2232        assert_eq!(indices, vec![1, 2, 0]);
2233    }
2234
2235    /// Tests finding argmax along different axes
2236    #[test]
2237    fn test_argmax() {
2238        let arr = NDArray::from_matrix(vec![
2239            vec![1.0, 3.0, 2.0],
2240            vec![4.0, 2.0, 6.0]
2241        ]);
2242        
2243        // Test global argmax
2244        assert_eq!(arr.argmax(None), vec![5]); // 6.0 is at index 5
2245        
2246        // Test argmax along axis 0
2247        assert_eq!(arr.argmax(Some(0)), vec![1, 0, 1]); // Max along columns
2248        
2249        // Test argmax along axis 1
2250        assert_eq!(arr.argmax(Some(1)), vec![1, 2]); // Max along rows
2251    }
2252
2253    /// Tests finding unique values in array
2254    #[test]
2255    fn test_unique() {
2256        let arr = NDArray::from_vec(vec![3.0, 1.0, 2.0, 1.0, 3.0]);
2257        let unique = arr.unique();
2258        assert_eq!(unique.data(), &[1.0, 2.0, 3.0]);
2259    }
2260
2261    /// Tests conditional value selection
2262    #[test]
2263    fn test_where_cond() {
2264        let arr = NDArray::from_vec(vec![-1.0, 2.0, -3.0, 4.0]);
2265        let result = arr.where_cond(|x| x > 0.0, 1.0, -1.0);
2266        assert_eq!(result.data(), &[-1.0, 1.0, -1.0, 1.0]);
2267    }
2268
2269    /// Tests median calculation
2270    #[test]
2271    fn test_median() {
2272        // Test odd number of elements
2273        let arr1 = NDArray::from_vec(vec![1.0, 3.0, 2.0]);
2274        assert_eq!(arr1.median(), 2.0);
2275        
2276        // Test even number of elements
2277        let arr2 = NDArray::from_vec(vec![1.0, 3.0, 2.0, 4.0]);
2278        assert_eq!(arr2.median(), 2.5);
2279    }
2280
2281    /// Tests finding maximum values along specified axis
2282    #[test]
2283    fn test_max_axis() {
2284        let arr = NDArray::from_matrix(vec![
2285            vec![1.0, 2.0, 3.0],
2286            vec![4.0, 0.5, 6.0],
2287        ]);
2288        
2289        let max_axis_0 = arr.max_axis(0);
2290        assert_eq!(max_axis_0.shape(), &[1, 3]);
2291        assert_eq!(max_axis_0.data(), &[4.0, 2.0, 6.0]); // Max along columns
2292        
2293        let max_axis_1 = arr.max_axis(1);
2294        assert_eq!(max_axis_1.shape(), &[2, 1]);
2295        assert_eq!(max_axis_1.data(), &[3.0, 6.0]); // Max along rows
2296    }
2297
2298    /// Tests element-wise subtraction between NDArray references
2299    #[test]
2300    fn test_element_wise_subtraction_ref() {
2301        let arr1 = NDArray::from_vec(vec![5.0, 7.0, 9.0]);
2302        let arr2 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2303        let diff = &arr1 - &arr2;
2304        assert_eq!(diff.data(), &[4.0, 5.0, 6.0]);
2305    }
2306
2307    /// Tests element-wise addition between NDArray references
2308    #[test]
2309    fn test_element_wise_addition_ref() {
2310        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2311        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
2312        let sum = &arr1 + &arr2;
2313        assert_eq!(sum.data(), &[5.0, 7.0, 9.0]);
2314    }
2315
2316    /// Tests element-wise multiplication between NDArray references
2317    #[test]
2318    fn test_element_wise_multiplication_ref() {
2319        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2320        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
2321        let product = &arr1 * &arr2;
2322        assert_eq!(product.data(), &[4.0, 10.0, 18.0]);
2323    }
2324
2325    /// Tests element-wise division between NDArray references
2326    #[test]
2327    fn test_element_wise_division_ref() {
2328        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
2329        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
2330        let quotient = &arr1 / &arr2;
2331        assert_eq!(quotient.data(), &[0.25, 0.4, 0.5]);
2332    }
2333
2334    /// Tests random uniform distribution generation
2335    #[test]
2336    fn test_rand_uniform() {
2337        // Test shape
2338        let shape = [2, 3];
2339        let arr = NDArray::rand_uniform(&shape);
2340        assert_eq!(arr.shape(), &[2, 3]);
2341
2342        // Test range (should be between 0 and 1)
2343        for &val in arr.data() {
2344            assert!(val >= 0.0 && val <= 1.0);
2345        }
2346
2347        // Test randomness (generate multiple arrays and verify they're different)
2348        let arr2 = NDArray::rand_uniform(&shape);
2349        assert_ne!(arr.data(), arr2.data(), "Random arrays should be different");
2350
2351        // Test distribution (roughly uniform)
2352        let large_arr = NDArray::rand_uniform(&[1000]);
2353        let mean = large_arr.mean();
2354        let std = large_arr.std();
2355        
2356        // For uniform distribution between 0 and 1:
2357        // Expected mean = 0.5
2358        // Expected std = 1/sqrt(12) ≈ 0.289
2359        assert!((mean - 0.5).abs() < 0.1, "Mean should be approximately 0.5");
2360        assert!((std - 0.289).abs() < 0.1, "Std should be approximately 0.289");
2361    }
2362
2363    #[test]
2364    fn test_statistical_functions() {
2365        let arr = NDArray::from_matrix(vec![
2366            vec![1.0, 2.0, 3.0],
2367            vec![4.0, 5.0, 6.0],
2368        ]);
2369
2370        // Test mean_axis
2371        let mean_cols = arr.mean_axis(0);
2372        assert_eq!(mean_cols.shape(), &[1, 3]);
2373        assert_eq!(mean_cols.data(), &[2.5, 3.5, 4.5]);
2374
2375        // Test var_axis
2376        let var_cols = arr.var_axis(0);
2377        assert_eq!(var_cols.shape(), &[1, 3]);
2378        assert_eq!(var_cols.data(), &[2.25, 2.25, 2.25]);
2379
2380        // Test sqrt (using NabMath trait)
2381        let sqrt = arr.sqrt();  // This now uses the implementation from nab_math.rs
2382        assert_eq!(sqrt.data(), &[1.0, 2.0_f64.sqrt(), 3.0_f64.sqrt(), 
2383                                 2.0, 5.0_f64.sqrt(), 6.0_f64.sqrt()]);
2384
2385        // Test add_scalar (using NabMath trait)
2386        let added = arr.add_scalar(1.0);  // This now uses the implementation from nab_math.rs
2387        assert_eq!(added.data(), &[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
2388    }
2389
2390    #[test]
2391    fn test_to_categorical() {
2392        // Test basic functionality
2393        let labels = NDArray::from_vec(vec![0.0, 1.0, 2.0]);
2394        let categorical = labels.to_categorical(None);
2395        assert_eq!(categorical.shape(), &[3, 3]);
2396        assert_eq!(categorical.data(), &[
2397            1.0, 0.0, 0.0,
2398            0.0, 1.0, 0.0,
2399            0.0, 0.0, 1.0
2400        ]);
2401
2402        // Test with explicit num_classes
2403        let labels = NDArray::from_vec(vec![0.0, 1.0]);
2404        let categorical = labels.to_categorical(Some(3));
2405        assert_eq!(categorical.shape(), &[2, 3]);
2406        assert_eq!(categorical.data(), &[
2407            1.0, 0.0, 0.0,
2408            0.0, 1.0, 0.0
2409        ]);
2410
2411        // Test with negative labels
2412        let labels = NDArray::from_vec(vec![-1.0, 0.0, 1.0]);
2413        let categorical = labels.to_categorical(Some(3));
2414        assert_eq!(categorical.shape(), &[3, 3]);
2415        assert_eq!(categorical.data(), &[
2416            1.0, 0.0, 0.0,
2417            0.0, 1.0, 0.0,
2418            0.0, 0.0, 1.0
2419        ]);
2420    }
2421}