use rand::Rng;
use rand_distr::{StandardNormal, Uniform, Distribution};
use std::ops::{Add, Sub, Mul, Div};
use rayon::prelude::*;
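
/// A dense, row-major n-dimensional array of `f64` values: `data` stores the
/// elements contiguously and `shape` records the extent of each axis, so
/// element `(i, j)` of a 2D array lives at `data[i * shape[1] + j]`.
///
/// A minimal construction sketch (values chosen purely for illustration; the
/// doctests in this file are marked `ignore` because the crate path is not
/// shown here):
///
/// ```ignore
/// let a = NDArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
/// assert_eq!(a.shape(), &[2, 2]);
/// assert_eq!(a.get_2d(1, 0), 3.0);
/// ```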
#[derive(Debug, Clone)]
pub struct NDArray {
    pub data: Vec<f64>,
    pub shape: Vec<usize>,
}

impl NDArray {
    pub fn new(data: Vec<f64>, shape: Vec<usize>) -> Self {
        let total_size: usize = shape.iter().product();
        assert_eq!(data.len(), total_size, "Data length must match shape dimensions");
        NDArray { data, shape }
    }

    pub fn from_vec(data: Vec<f64>) -> Self {
        let len = data.len();
        Self::new(data, vec![len])
    }

    #[allow(dead_code)]
    pub fn from_matrix(data: Vec<Vec<f64>>) -> Self {
        let rows = data.len();
        let cols = data.get(0).map_or(0, |row| row.len());
        let flat_data: Vec<f64> = data.into_iter().flatten().collect();
        Self::new(flat_data, vec![rows, cols])
    }

    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    pub fn data(&self) -> &[f64] {
        &self.data
    }

    #[allow(dead_code)]
    pub fn rand_2d(rows: usize, cols: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..rows * cols).map(|_| rng.gen()).collect();
        Self::new(data, vec![rows, cols])
    }

    #[allow(dead_code)]
    pub fn randn(size: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..size).map(|_| rng.sample(StandardNormal)).collect();
        Self::from_vec(data)
    }

    #[allow(dead_code)]
    pub fn randn_2d(rows: usize, cols: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..rows * cols).map(|_| rng.sample(StandardNormal)).collect();
        Self::new(data, vec![rows, cols])
    }

    #[allow(dead_code)]
    pub fn randint(low: i32, high: i32, size: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..size).map(|_| rng.gen_range(low..high) as f64).collect();
        Self::from_vec(data)
    }

    #[allow(dead_code)]
    pub fn randint_2d(low: i32, high: i32, rows: usize, cols: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..rows * cols).map(|_| rng.gen_range(low..high) as f64).collect();
        Self::new(data, vec![rows, cols])
    }

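    /// Returns a copy of this array with the same elements reinterpreted
    /// under `new_shape`. The element count must be unchanged; only the shape
    /// metadata differs, since the data is row-major and contiguous.
    ///
    /// A minimal sketch (shapes chosen for illustration):
    ///
    /// ```ignore
    /// let a = NDArray::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let m = a.reshape(&[2, 3]).unwrap();
    /// assert_eq!(m.shape(), &[2, 3]);
    /// ```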
    pub fn reshape(&self, new_shape: &[usize]) -> Result<Self, &'static str> {
        let total_elements = self.data.len();
        let new_total: usize = new_shape.iter().copied().product();

        if total_elements != new_total {
            return Err("New shape must have same total size as original");
        }

        Ok(NDArray {
            data: self.data.clone(),
            shape: new_shape.to_vec(),
        })
    }

    #[allow(dead_code)]
    pub fn max(&self) -> f64 {
        *self.data.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap()
    }

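    /// Returns the indices of maximum values. With `None`, a single flat
    /// index into `data` is returned; with `Some(0)` or `Some(1)` on a 2D
    /// array, one index per column or per row respectively.
    ///
    /// A minimal sketch (a 2x3 matrix, values chosen for illustration):
    ///
    /// ```ignore
    /// let m = NDArray::new(vec![1.0, 3.0, 2.0, 4.0, 2.0, 6.0], vec![2, 3]);
    /// assert_eq!(m.argmax(None), vec![5]);
    /// assert_eq!(m.argmax(Some(1)), vec![1, 2]);
    /// ```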
    pub fn argmax(&self, axis: Option<usize>) -> Vec<usize> {
        match axis {
            None => {
                vec![self.data.iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(i, _)| i)
                    .unwrap()]
            },
            Some(ax) => {
                if ax >= self.shape.len() {
                    panic!("Axis {} out of bounds for shape {:?}", ax, self.shape);
                }
                match ax {
                    0 => {
                        let cols = self.shape[1];
                        let mut indices = Vec::with_capacity(cols);
                        for j in 0..cols {
                            let mut max_idx = 0;
                            let mut max_val = self.data[j];
                            for i in 1..self.shape[0] {
                                let val = self.data[i * cols + j];
                                if val > max_val {
                                    max_val = val;
                                    max_idx = i;
                                }
                            }
                            indices.push(max_idx);
                        }
                        indices
                    },
                    1 => {
                        let cols = self.shape[1];
                        let mut indices = Vec::with_capacity(self.shape[0]);
                        for i in 0..self.shape[0] {
                            let row_start = i * cols;
                            let mut max_idx = 0;
                            let mut max_val = self.data[row_start];
                            for j in 1..cols {
                                let val = self.data[row_start + j];
                                if val > max_val {
                                    max_val = val;
                                    max_idx = j;
                                }
                            }
                            indices.push(max_idx);
                        }
                        indices
                    },
                    _ => panic!("Unsupported axis {}", ax),
                }
            }
        }
    }

    #[allow(dead_code)]
    pub fn min(&self) -> f64 {
        *self.data.iter().min_by(|a, b| a.partial_cmp(b).unwrap()).unwrap()
    }

    #[allow(dead_code)]
    pub fn from_vec_reshape(data: Vec<f64>, shape: Vec<usize>) -> Self {
        let total_size: usize = shape.iter().product();
        assert_eq!(data.len(), total_size, "Data length must match shape dimensions");
        NDArray { data, shape }
    }

    #[allow(dead_code)]
    pub fn extract_sample(&self, sample_index: usize) -> Self {
        assert!(self.ndim() >= 2, "Array must have at least 2 dimensions");
        assert!(sample_index < self.shape[0], "Sample index out of bounds");

        let sample_size: usize = self.shape.iter().skip(1).product();
        let start_index = sample_index * sample_size;
        let end_index = start_index + sample_size;

        let new_shape: Vec<usize> = self.shape.iter().skip(1).cloned().collect();

        NDArray::new(
            self.data[start_index..end_index].to_vec(),
            new_shape,
        )
    }

    #[allow(dead_code)]
    pub fn pretty_print(&self, precision: usize) {
        self.pretty_print_indented(precision, 0);
    }

    // Internal helper so that nesting increases the indentation without also
    // increasing the number of printed decimals.
    fn pretty_print_indented(&self, precision: usize, indent: usize) {
        let indent_str = " ".repeat(indent);

        let format_value = |x: f64| -> String {
            if x == 0.0 {
                format!("{:.1}", x)
            } else {
                format!("{:.*}", precision, x)
            }
        };

        match self.ndim() {
            1 => println!("{}[{}]", indent_str, self.data.iter()
                .map(|&x| format_value(x))
                .collect::<Vec<_>>()
                .join(" ")),

            2 => {
                println!("{}[", indent_str);
                for i in 0..self.shape[0] {
                    print!("{} [", indent_str);
                    for j in 0..self.shape[1] {
                        print!("{}", format_value(self.get_2d(i, j)));
                        if j < self.shape[1] - 1 {
                            print!(" ");
                        }
                    }
                    println!("]");
                }
                println!("{}]", indent_str);
            },

            _ => {
                println!("{}[", indent_str);
                for i in 0..self.shape[0] {
                    let slice = self.extract_sample(i);
                    slice.pretty_print_indented(precision, indent + 2);
                }
                println!("{}]", indent_str);
            }
        }
    }

    #[allow(dead_code)]
    pub fn get(&self, index: usize) -> f64 {
        self.data[index]
    }

    #[allow(dead_code)]
    pub fn arange(start: f64, stop: f64, step: f64) -> Self {
        let mut data = Vec::new();
        let mut current = start;
        while current < stop {
            data.push(current);
            current += step;
        }
        Self::from_vec(data)
    }

    #[allow(dead_code)]
    pub fn zeros(shape: Vec<usize>) -> Self {
        let total_size: usize = shape.iter().product();
        NDArray {
            data: vec![0.0; total_size],
            shape,
        }
    }

    #[allow(dead_code)]
    pub fn zeros_2d(rows: usize, cols: usize) -> Self {
        Self::new(vec![0.0; rows * cols], vec![rows, cols])
    }

    #[allow(dead_code)]
    pub fn ones(size: usize) -> Self {
        Self::from_vec(vec![1.0; size])
    }

    #[allow(dead_code)]
    pub fn ones_2d(rows: usize, cols: usize) -> Self {
        Self::new(vec![1.0; rows * cols], vec![rows, cols])
    }

    #[allow(dead_code)]
    pub fn linspace(start: f64, end: f64, num: usize, precision: usize) -> Self {
        assert!(num > 1, "Number of samples must be greater than 1");
        let step = (end - start) / (num - 1) as f64;
        let mut data = Vec::with_capacity(num);
        let factor = 10f64.powi(precision as i32);
        for i in 0..num {
            let value = start + step * i as f64;
            let rounded_value = (value * factor).round() / factor;
            data.push(rounded_value);
        }
        Self::from_vec(data)
    }

    #[allow(dead_code)]
    pub fn eye(n: usize) -> Self {
        let mut data = vec![0.0; n * n];
        for i in 0..n {
            data[i * n + i] = 1.0;
        }
        Self::new(data, vec![n, n])
    }

    #[allow(dead_code)]
    pub fn rand(size: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..size).map(|_| rng.gen()).collect();
        Self::from_vec(data)
    }

    #[allow(dead_code)]
    pub fn sub_matrix(&self, row_start: usize, row_end: usize, col_start: usize, col_end: usize) -> Self {
        assert_eq!(self.ndim(), 2, "sub_matrix is only applicable to 2D arrays");
        let cols = self.shape[1];
        let mut data = Vec::new();
        for row in row_start..row_end {
            for col in col_start..col_end {
                data.push(self.data[row * cols + col]);
            }
        }
        Self::new(data, vec![row_end - row_start, col_end - col_start])
    }

    #[allow(dead_code)]
    pub fn set(&mut self, index: usize, value: f64) {
        self.data[index] = value;
    }

    #[allow(dead_code)]
    pub fn set_range(&mut self, start: usize, end: usize, value: f64) {
        for i in start..end {
            self.data[i] = value;
        }
    }

    #[allow(dead_code)]
    pub fn copy(&self) -> Self {
        Self::new(self.data.clone(), self.shape.clone())
    }

    #[allow(dead_code)]
    pub fn view(&self, start: usize, end: usize) -> &[f64] {
        &self.data[start..end]
    }

    #[allow(dead_code)]
    pub fn view_mut(&mut self, start: usize, end: usize) -> &mut [f64] {
        &mut self.data[start..end]
    }

    #[allow(dead_code)]
    pub fn get_2d(&self, row: usize, col: usize) -> f64 {
        assert_eq!(self.ndim(), 2, "get_2d is only applicable to 2D arrays");
        let cols = self.shape[1];
        self.data[row * cols + col]
    }

    #[allow(dead_code)]
    pub fn set_2d(&mut self, row: usize, col: usize, value: f64) {
        assert_eq!(self.ndim(), 2, "set_2d is only applicable to 2D arrays");
        let cols = self.shape[1];
        self.data[row * cols + col] = value;
    }

    #[allow(dead_code)]
    pub fn new_axis(&self, axis: usize) -> Self {
        let mut new_shape = self.shape.clone();
        new_shape.insert(axis, 1);
        Self::new(self.data.clone(), new_shape)
    }

    #[allow(dead_code)]
    pub fn expand_dims(&self, axis: usize) -> Self {
        self.new_axis(axis)
    }

    #[allow(dead_code)]
    pub fn greater_than(&self, threshold: f64) -> Vec<bool> {
        self.data.iter().map(|&x| x > threshold).collect()
    }

    #[allow(dead_code)]
    pub fn filter(&self, condition: impl Fn(&f64) -> bool) -> Self {
        let data: Vec<f64> = self.data.iter().cloned().filter(condition).collect();
        Self::from_vec(data)
    }

    #[allow(dead_code)]
    pub fn dtype(&self) -> &'static str {
        "f64"
    }

    #[allow(dead_code)]
    pub fn size(&self) -> usize {
        self.data.len()
    }

    #[allow(dead_code)]
    pub fn argmin(&self) -> usize {
        self.data.iter().enumerate().min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()).map(|(i, _)| i).unwrap()
    }

    #[allow(dead_code)]
    pub fn slice(&self, start: usize, end: usize) -> Self {
        let mut new_shape = self.shape.clone();
        new_shape[0] = end - start;

        // Each index along axis 0 spans `row_size` contiguous elements, so
        // the same computation covers 1D and higher-dimensional arrays.
        let row_size: usize = self.shape.iter().skip(1).product();
        let sliced_data = self.data[start * row_size..end * row_size].to_vec();
        NDArray::new(sliced_data, new_shape)
    }

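    /// One-hot encodes an array of integer-valued labels. The class count is
    /// inferred from the minimum and maximum labels present, so negative
    /// labels are shifted to start at column zero.
    ///
    /// A minimal sketch (labels chosen for illustration):
    ///
    /// ```ignore
    /// let labels = NDArray::from_vec(vec![0.0, 2.0, 1.0]);
    /// let encoded = NDArray::one_hot_encode(&labels);
    /// assert_eq!(encoded.shape(), &[3, 3]);
    /// assert_eq!(encoded.data(), &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0]);
    /// ```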
    pub fn one_hot_encode(labels: &NDArray) -> Self {
        for &value in labels.data() {
            if value.fract() != 0.0 {
                panic!("All values must be integers for one-hot encoding");
            }
        }

        let labels_int: Vec<i32> = labels.data()
            .iter()
            .map(|&x| x as i32)
            .collect();

        let min_label = labels_int.iter().min().unwrap();
        let max_label = labels_int.iter().max().unwrap();
        let num_classes = (max_label - min_label + 1) as usize;

        let mut data = vec![0.0; labels_int.len() * num_classes];

        for (i, &label) in labels_int.iter().enumerate() {
            let shifted_label = (label - min_label) as usize;
            data[i * num_classes + shifted_label] = 1.0;
        }

        NDArray::new(data, vec![labels_int.len(), num_classes])
    }

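    /// Transposes a 2D array by copying each element `(i, j)` to `(j, i)` in
    /// a freshly allocated buffer (no stride tricks; the result is
    /// contiguous).
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let m = NDArray::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
    /// let t = m.transpose().unwrap();
    /// assert_eq!(t.shape(), &[3, 2]);
    /// ```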
    pub fn transpose(&self) -> Result<Self, &'static str> {
        if self.shape.len() != 2 {
            return Err("transpose currently only supports 2D arrays");
        }

        let (rows, cols) = (self.shape[0], self.shape[1]);
        let mut new_data = vec![0.0; rows * cols];

        for i in 0..rows {
            for j in 0..cols {
                new_data[j * rows + i] = self.data[i * cols + j];
            }
        }

        Ok(NDArray {
            data: new_data,
            shape: vec![cols, rows],
        })
    }

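    /// Matrix multiplication via the textbook triple loop: for an `(n, k)`
    /// and a `(k, m)` input it performs `n * m * k` multiply-adds into an
    /// `(n, m)` result.
    ///
    /// A minimal sketch (a 2x3 times 3x2 product):
    ///
    /// ```ignore
    /// let a = NDArray::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
    /// let b = NDArray::new(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);
    /// assert_eq!(a.dot(&b).data(), &[58.0, 64.0, 139.0, 154.0]);
    /// ```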
    pub fn dot(&self, other: &NDArray) -> Self {
        assert_eq!(self.ndim(), 2, "Dot product is only defined for 2D arrays");
        assert_eq!(other.ndim(), 2, "Dot product is only defined for 2D arrays");
        assert_eq!(self.shape[1], other.shape[0], "Inner dimensions must match for dot product");

        let rows = self.shape[0];
        let cols = other.shape[1];
        let mut result_data = vec![0.0; rows * cols];

        for i in 0..rows {
            for j in 0..cols {
                let mut sum = 0.0;
                for k in 0..self.shape[1] {
                    sum += self.data[i * self.shape[1] + k] * other.data[k * other.shape[1] + j];
                }
                result_data[i * cols + j] = sum;
            }
        }

        NDArray::new(result_data, vec![rows, cols])
    }

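    /// Element-wise product. Arrays with more than 1000 elements are
    /// processed in parallel with rayon; smaller ones use a plain sequential
    /// zip, since thread-pool overhead would dominate at that size.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let a = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let b = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
    /// assert_eq!(a.multiply(&b).data(), &[4.0, 10.0, 18.0]);
    /// ```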
    pub fn multiply(&self, other: &NDArray) -> Self {
        assert_eq!(self.shape, other.shape, "Shapes must match for element-wise multiplication");

        let data = if self.data.len() > 1000 {
            self.data.par_iter()
                .zip(other.data.par_iter())
                .map(|(&a, &b)| a * b)
                .collect()
        } else {
            self.data.iter()
                .zip(other.data.iter())
                .map(|(&a, &b)| a * b)
                .collect()
        };

        NDArray::new(data, self.shape.clone())
    }

    pub fn scalar_sub(&self, scalar: f64) -> Self {
        let data: Vec<f64> = self.data.iter().map(|&x| x - scalar).collect();
        NDArray::new(data, self.shape.clone())
    }

    pub fn multiply_scalar(&self, scalar: f64) -> Self {
        let data: Vec<f64> = self.data.iter().map(|&x| x * scalar).collect();
        NDArray::new(data, self.shape.clone())
    }

    pub fn clip(&self, min: f64, max: f64) -> Self {
        let data: Vec<f64> = self.data.iter().map(|&x| x.clamp(min, max)).collect();
        NDArray::new(data, self.shape.clone())
    }

    pub fn divide(&self, other: &NDArray) -> Self {
        assert_eq!(self.shape, other.shape, "Shapes must match for element-wise division");

        let data: Vec<f64> = self.data.iter().zip(other.data.iter()).map(|(a, b)| a / b).collect();
        NDArray::new(data, self.shape.clone())
    }

    pub fn divide_scalar(&self, scalar: f64) -> Self {
        let data: Vec<f64> = self.data.iter().map(|&x| x / scalar).collect();
        NDArray::new(data, self.shape.clone())
    }

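    /// Sums a 2D array along one axis. `axis == 0` collapses the rows into a
    /// `[1, cols]` result; `axis == 1` collapses the columns into
    /// `[rows, 1]`, keeping the reduced axis so the result can be broadcast.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let m = NDArray::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
    /// assert_eq!(m.sum_axis(0).data(), &[5.0, 7.0, 9.0]);
    /// assert_eq!(m.sum_axis(1).data(), &[6.0, 15.0]);
    /// ```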
    pub fn sum_axis(&self, axis: usize) -> Self {
        if axis >= self.shape.len() {
            panic!("Axis {} out of bounds for shape {:?}", axis, self.shape);
        }

        match axis {
            0 => {
                let cols = self.shape[1];
                let mut result = vec![0.0; cols];

                for j in 0..cols {
                    for i in 0..self.shape[0] {
                        result[j] += self.data[i * cols + j];
                    }
                }

                NDArray::new(result, vec![1, cols])
            },
            1 => {
                let cols = self.shape[1];
                let mut result = vec![0.0; self.shape[0]];

                for i in 0..self.shape[0] {
                    for j in 0..cols {
                        result[i] += self.data[i * cols + j];
                    }
                }

                NDArray::new(result, vec![self.shape[0], 1])
            },
            _ => panic!("Unsupported axis {}", axis),
        }
    }

    pub fn subtract(&self, other: &NDArray) -> Self {
        assert_eq!(self.shape, other.shape, "Shapes must match for element-wise subtraction");

        let data: Vec<f64> = self.data.iter().zip(other.data.iter()).map(|(a, b)| a - b).collect();
        NDArray::new(data, self.shape.clone())
    }

    pub fn add_scalar(&self, scalar: f64) -> Self {
        let data: Vec<f64> = self.data.iter().map(|&x| x + scalar).collect();
        NDArray::new(data, self.shape.clone())
    }

    pub fn log(&self) -> Self {
        let data: Vec<f64> = self.data.iter().map(|&x| x.ln()).collect();
        NDArray::new(data, self.shape.clone())
    }

    pub fn sum(&self) -> f64 {
        self.data.iter().sum()
    }

    pub fn pad_to_size(&self, target_size: usize) -> Self {
        if self.shape[0] >= target_size {
            return self.clone();
        }

        let mut new_shape = self.shape.clone();
        new_shape[0] = target_size;
        let total_size: usize = new_shape.iter().product();

        let mut new_data = vec![0.0; total_size];

        let row_size = self.shape.iter().skip(1).product::<usize>();
        let existing_data_size = self.shape[0] * row_size;
        new_data[..existing_data_size].copy_from_slice(&self.data);

        NDArray::new(new_data, new_shape)
    }

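    /// Layer normalization: each row is standardized with its own mean and
    /// variance, `(x - mean) / sqrt(var + 1e-5)`, so statistics are computed
    /// per sample rather than per feature.
    ///
    /// A minimal sketch (every normalized row sums to roughly zero):
    ///
    /// ```ignore
    /// let m = NDArray::new(vec![1.0, 2.0, 3.0, 10.0, 20.0, 30.0], vec![2, 3]);
    /// let normed = m.layer_normalize();
    /// assert!(normed.data()[..3].iter().sum::<f64>().abs() < 1e-9);
    /// ```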
    pub fn layer_normalize(&self) -> Self {
        let (rows, cols) = (self.shape[0], self.shape[1]);
        let mut result = vec![0.0; self.data.len()];

        for i in 0..rows {
            let start = i * cols;
            let end = start + cols;
            let row = &self.data[start..end];

            let mean: f64 = row.iter().sum::<f64>() / cols as f64;
            let var: f64 = row.iter()
                .map(|&x| (x - mean).powi(2))
                .sum::<f64>() / cols as f64;
            let std = (var + 1e-5).sqrt();

            for j in 0..cols {
                result[start + j] = (row[j] - mean) / std;
            }
        }

        NDArray::new(result, self.shape.clone())
    }

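    /// Batch normalization: the mirror image of `layer_normalize`. Each
    /// feature (column) is standardized with the mean and variance computed
    /// over the batch dimension (rows).
    ///
    /// A minimal sketch (each normalized column sums to roughly zero):
    ///
    /// ```ignore
    /// let batch = NDArray::new(vec![1.0, 100.0, 3.0, 200.0], vec![2, 2]);
    /// let normed = batch.batch_normalize();
    /// assert!((normed.get_2d(0, 0) + normed.get_2d(1, 0)).abs() < 1e-9);
    /// ```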
    pub fn batch_normalize(&self) -> Self {
        let (batch_size, features) = (self.shape[0], self.shape[1]);
        let mut result = vec![0.0; self.data.len()];

        for j in 0..features {
            let mut mean = 0.0;
            let mut var = 0.0;

            for i in 0..batch_size {
                mean += self.data[i * features + j];
            }
            mean /= batch_size as f64;

            for i in 0..batch_size {
                var += (self.data[i * features + j] - mean).powi(2);
            }
            var /= batch_size as f64;

            let std = (var + 1e-5).sqrt();
            for i in 0..batch_size {
                result[i * features + j] = (self.data[i * features + j] - mean) / std;
            }
        }

        NDArray::new(result, self.shape.clone())
    }

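    /// Element-wise addition with one broadcasting case: a `[1, cols]` row
    /// vector on the right is added to every row of a `[rows, cols]` array
    /// (the bias-addition pattern). All other shape mismatches panic.
    ///
    /// A minimal sketch of the broadcasting path:
    ///
    /// ```ignore
    /// let m = NDArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
    /// let bias = NDArray::new(vec![10.0, 20.0], vec![1, 2]);
    /// assert_eq!(m.add(&bias).data(), &[11.0, 22.0, 13.0, 24.0]);
    /// ```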
    pub fn add(self, other: &NDArray) -> Self {
        // Broadcasting path: add a [1, cols] row vector to every row.
        if self.shape.len() == 2 &&
           other.shape.len() == 2 &&
           other.shape[0] == 1 &&
           self.shape[1] == other.shape[1] {

            let mut result_data = Vec::with_capacity(self.data.len());
            let cols = other.shape[1];

            for i in 0..self.shape[0] {
                for j in 0..cols {
                    result_data.push(self.data[i * cols + j] + other.data[j]);
                }
            }

            return NDArray::new(result_data, self.shape.clone());
        }

        if self.shape != other.shape {
            panic!("Shapes must match for element-wise addition\n left: {:?}\n right: {:?}",
                   self.shape, other.shape);
        }

        let data: Vec<f64> = self.data.iter()
            .zip(other.data.iter())
            .map(|(a, b)| a + b)
            .collect();

        NDArray::new(data, self.shape.clone())
    }

    pub fn mean(&self) -> f64 {
        self.sum() / self.data.len() as f64
    }

    pub fn std(&self) -> f64 {
        let mean = self.mean();
        let variance = self.data.iter()
            .map(|&x| (x - mean).powi(2))
            .sum::<f64>() / self.data.len() as f64;
        variance.sqrt()
    }

    pub fn min_axis(&self, axis: usize) -> Result<Self, &'static str> {
        if axis >= self.shape.len() {
            return Err("Axis out of bounds");
        }

        match axis {
            0 => {
                if self.shape.len() != 2 {
                    return Err("min_axis(0) requires 2D array");
                }
                let cols = self.shape[1];
                let mut result = vec![f64::INFINITY; cols];

                for j in 0..cols {
                    for i in 0..self.shape[0] {
                        result[j] = result[j].min(self.data[i * cols + j]);
                    }
                }

                Ok(NDArray::new(result, vec![1, cols]))
            },
            1 => {
                if self.shape.len() != 2 {
                    return Err("min_axis(1) requires 2D array");
                }
                let cols = self.shape[1];
                let mut result = vec![f64::INFINITY; self.shape[0]];

                for i in 0..self.shape[0] {
                    for j in 0..cols {
                        result[i] = result[i].min(self.data[i * cols + j]);
                    }
                }

                Ok(NDArray::new(result, vec![self.shape[0], 1]))
            },
            _ => Err("Unsupported axis"),
        }
    }

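    /// Concatenates two arrays along `axis`. Along axis 0 the flat buffers
    /// are appended directly; along axis 1 the rows of both arrays are
    /// interleaved row by row, since the layout is row-major.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let a = NDArray::new(vec![1.0, 2.0], vec![1, 2]);
    /// let b = NDArray::new(vec![3.0, 4.0], vec![1, 2]);
    /// assert_eq!(a.concatenate(&b, 1).unwrap().shape(), &[1, 4]);
    /// ```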
    pub fn concatenate(&self, other: &Self, axis: usize) -> Result<Self, &'static str> {
        if axis >= self.shape.len() {
            return Err("Axis out of bounds");
        }

        if self.shape.len() != other.shape.len() {
            return Err("Arrays must have same number of dimensions");
        }

        for (i, (&s1, &s2)) in self.shape.iter().zip(other.shape.iter()).enumerate() {
            if i != axis && s1 != s2 {
                return Err("All dimensions except concatenation axis must match");
            }
        }

        let mut new_shape = self.shape.clone();
        new_shape[axis] += other.shape[axis];

        let mut new_data = Vec::with_capacity(self.data.len() + other.data.len());

        match axis {
            0 => {
                new_data.extend_from_slice(&self.data);
                new_data.extend_from_slice(&other.data);
            },
            1 => {
                let rows = self.shape[0];
                let cols1 = self.shape[1];
                let cols2 = other.shape[1];

                for i in 0..rows {
                    new_data.extend_from_slice(&self.data[i * cols1..(i + 1) * cols1]);
                    new_data.extend_from_slice(&other.data[i * cols2..(i + 1) * cols2]);
                }
            },
            _ => return Err("Unsupported axis"),
        }

        Ok(NDArray::new(new_data, new_shape))
    }

    pub fn map<F>(&self, f: F) -> Self
    where F: Fn(f64) -> f64
    {
        let new_data: Vec<f64> = self.data.iter().map(|&x| f(x)).collect();
        NDArray::new(new_data, self.shape.clone())
    }

    pub fn abs(&self) -> Self {
        self.map(|x| x.abs())
    }

    pub fn power(&self, n: f64) -> Self {
        self.map(|x| x.powf(n))
    }

    /// Element-wise square root.
    pub fn sqrt(&self) -> Self {
        self.map(|x| x.sqrt())
    }

    pub fn cumsum(&self) -> Self {
        let mut result = Vec::with_capacity(self.data.len());
        let mut sum = 0.0;
        for &x in &self.data {
            sum += x;
            result.push(sum);
        }
        NDArray::new(result, self.shape.clone())
    }

    pub fn round(&self, decimals: i32) -> Self {
        let factor = 10.0_f64.powi(decimals);
        self.map(|x| (x * factor).round() / factor)
    }

    pub fn argsort(&self) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..self.data.len()).collect();
        indices.sort_by(|&i, &j| self.data[i].partial_cmp(&self.data[j]).unwrap());
        indices
    }

    pub fn unique(&self) -> Self {
        let mut unique_vals = self.data.clone();
        unique_vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
        unique_vals.dedup();
        let len = unique_vals.len();
        NDArray::new(unique_vals, vec![len])
    }

    pub fn where_cond<F>(&self, condition: F, x: f64, y: f64) -> Self
    where F: Fn(f64) -> bool
    {
        self.map(|val| if condition(val) { x } else { y })
    }

    pub fn median(&self) -> f64 {
        let mut sorted = self.data.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let mid = sorted.len() / 2;
        if sorted.len() % 2 == 0 {
            (sorted[mid - 1] + sorted[mid]) / 2.0
        } else {
            sorted[mid]
        }
    }

    pub fn max_axis(&self, axis: usize) -> Self {
        if axis >= self.shape.len() {
            panic!("Axis {} out of bounds for shape {:?}", axis, self.shape);
        }

        if self.shape.len() == 1 {
            return NDArray::new(vec![self.data.iter().cloned().fold(f64::NEG_INFINITY, f64::max)], vec![1]);
        }

        match axis {
            0 => {
                let cols = self.shape[1];
                let mut result = vec![f64::NEG_INFINITY; cols];

                for j in 0..cols {
                    for i in 0..self.shape[0] {
                        result[j] = result[j].max(self.data[i * cols + j]);
                    }
                }

                NDArray::new(result, vec![1, cols])
            },
            1 => {
                let cols = self.shape[1];
                let mut result = vec![f64::NEG_INFINITY; self.shape[0]];

                for i in 0..self.shape[0] {
                    for j in 0..cols {
                        result[i] = result[i].max(self.data[i * cols + j]);
                    }
                }

                NDArray::new(result, vec![self.shape[0], 1])
            },
            _ => panic!("Unsupported axis {}", axis),
        }
    }

    pub fn display(&self) -> String {
        format!("NDArray(shape={:?}, data={:?})", self.shape, self.data)
    }

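    /// Fills an array of the given shape with samples drawn uniformly from
    /// `[0, 1)`.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let r = NDArray::rand_uniform(&[2, 3]);
    /// assert!(r.data().iter().all(|&v| (0.0..1.0).contains(&v)));
    /// ```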
    pub fn rand_uniform(shape: &[usize]) -> Self {
        let size: usize = shape.iter().product();
        let uniform = Uniform::new(0.0, 1.0);
        let mut rng = rand::thread_rng();

        let data: Vec<f64> = (0..size)
            .map(|_| uniform.sample(&mut rng))
            .collect();

        Self::new(data, shape.to_vec())
    }

    pub fn mean_axis(&self, axis: usize) -> Self {
        let sum = self.sum_axis(axis);
        let n = if axis == 0 { self.shape[0] } else { self.shape[1] } as f64;
        sum.multiply_scalar(1.0 / n)
    }

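    /// Population variance along one axis of a 2D array: the per-axis mean is
    /// broadcast back to the full shape, then squared deviations are averaged
    /// over the reduced axis.
    ///
    /// A minimal sketch (per-column variance):
    ///
    /// ```ignore
    /// let m = NDArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
    /// assert_eq!(m.var_axis(0).data(), &[1.0, 1.0]);
    /// ```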
    pub fn var_axis(&self, axis: usize) -> Self {
        let mean = self.mean_axis(axis);

        // Broadcast the reduced mean back to the full shape before centering.
        let mut broadcast = Vec::with_capacity(self.data.len());
        let cols = self.shape[1];

        if axis == 0 {
            // Repeat the per-column means across every row.
            for _ in 0..self.shape[0] {
                for j in 0..cols {
                    broadcast.push(mean.data[j]);
                }
            }
        } else {
            // Repeat each per-row mean across its columns.
            for i in 0..self.shape[0] {
                for _ in 0..cols {
                    broadcast.push(mean.data[i]);
                }
            }
        }
        let broadcasted_mean = NDArray::new(broadcast, self.shape.clone());

        let centered = self.subtract(&broadcasted_mean);
        let squared = centered.multiply(&centered);
        let n = if axis == 0 { self.shape[0] } else { self.shape[1] } as f64;
        squared.sum_axis(axis).multiply_scalar(1.0 / n)
    }

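    /// Converts a 1D array of integer-valued labels into a one-hot matrix,
    /// like `one_hot_encode` but callable on `self` and with an optional
    /// explicit class count.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let labels = NDArray::from_vec(vec![0.0, 1.0]);
    /// let cat = labels.to_categorical(Some(3));
    /// assert_eq!(cat.shape(), &[2, 3]);
    /// ```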
    pub fn to_categorical(&self, num_classes: Option<usize>) -> Self {
        assert_eq!(self.ndim(), 1, "Input must be a 1D array");

        let min_label = self.data().iter()
            .fold(f64::INFINITY, |a, &b| a.min(b)) as i32;
        let max_label = self.data().iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b)) as i32;

        let n_classes = num_classes.unwrap_or_else(||
            (max_label - min_label + 1) as usize
        );

        let n_samples = self.shape()[0];
        let mut categorical = vec![0.0; n_samples * n_classes];

        for (sample_idx, &label) in self.data().iter().enumerate() {
            let shifted_label = (label as i32 - min_label) as usize;
            assert!(shifted_label < n_classes,
                "Label {} is out of range for {} classes", label, n_classes);

            let row_offset = sample_idx * n_classes;
            categorical[row_offset + shifted_label] = 1.0;
        }

        NDArray::new(categorical, vec![n_samples, n_classes])
    }
}

impl Add for NDArray {
    type Output = Self;

    fn add(self, other: Self) -> Self::Output {
        assert_eq!(self.shape, other.shape, "Shapes must match for element-wise addition");
        let data = self.data.iter().zip(other.data.iter()).map(|(a, b)| a + b).collect();
        NDArray::new(data, self.shape.clone())
    }
}

impl Add<&NDArray> for NDArray {
    type Output = Self;

    fn add(self, other: &NDArray) -> Self::Output {
        // Broadcasting path: add a [1, cols] row vector to every row.
        if self.shape.len() == 2 &&
           other.shape.len() == 2 &&
           other.shape[0] == 1 &&
           self.shape[1] == other.shape[1] {

            let mut result_data = Vec::with_capacity(self.data.len());
            let cols = other.shape[1];

            for i in 0..self.shape[0] {
                for j in 0..cols {
                    result_data.push(self.data[i * cols + j] + other.data[j]);
                }
            }

            return NDArray::new(result_data, self.shape.clone());
        }

        if self.shape != other.shape {
            panic!("Shapes must match for element-wise addition\n left: {:?}\n right: {:?}",
                   self.shape, other.shape);
        }

        let data: Vec<f64> = self.data.iter()
            .zip(other.data.iter())
            .map(|(a, b)| a + b)
            .collect();

        NDArray::new(data, self.shape.clone())
    }
}

impl Sub for NDArray {
    type Output = Self;

    fn sub(self, other: Self) -> Self::Output {
        assert_eq!(self.shape, other.shape, "Shapes must match for element-wise subtraction");
        let data = self.data.iter().zip(other.data.iter()).map(|(a, b)| a - b).collect();
        NDArray::new(data, self.shape.clone())
    }
}

impl Mul<f64> for NDArray {
    type Output = Self;

    fn mul(self, scalar: f64) -> Self::Output {
        let data = self.data.iter().map(|a| a * scalar).collect();
        NDArray::new(data, self.shape.clone())
    }
}

impl Add<f64> for NDArray {
    type Output = Self;

    fn add(self, scalar: f64) -> Self::Output {
        self.add_scalar(scalar)
    }
}

impl Mul<&NDArray> for f64 {
    type Output = NDArray;

    fn mul(self, rhs: &NDArray) -> NDArray {
        rhs.multiply_scalar(self)
    }
}

impl std::fmt::Display for NDArray {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.display())
    }
}

impl Sub<&NDArray> for NDArray {
    type Output = Self;

    fn sub(self, other: &NDArray) -> Self::Output {
        if self.shape != other.shape {
            panic!("Shapes must match for element-wise subtraction");
        }
        let data: Vec<f64> = self.data.iter()
            .zip(other.data.iter())
            .map(|(a, b)| a - b)
            .collect();
        NDArray::new(data, self.shape.clone())
    }
}

impl<'a, 'b> Sub<&'b NDArray> for &'a NDArray {
    type Output = NDArray;

    fn sub(self, other: &'b NDArray) -> NDArray {
        if self.shape != other.shape {
            panic!("Shapes must match for element-wise subtraction");
        }
        let data: Vec<f64> = self.data.iter()
            .zip(other.data.iter())
            .map(|(a, b)| a - b)
            .collect();
        NDArray::new(data, self.shape.clone())
    }
}

impl<'a, 'b> Add<&'b NDArray> for &'a NDArray {
    type Output = NDArray;

    fn add(self, other: &'b NDArray) -> NDArray {
        if self.shape != other.shape {
            panic!("Shapes must match for element-wise addition");
        }
        let data: Vec<f64> = self.data.iter()
            .zip(other.data.iter())
            .map(|(a, b)| a + b)
            .collect();
        NDArray::new(data, self.shape.clone())
    }
}

impl<'a, 'b> Mul<&'b NDArray> for &'a NDArray {
    type Output = NDArray;

    fn mul(self, other: &'b NDArray) -> NDArray {
        if self.shape != other.shape {
            panic!("Shapes must match for element-wise multiplication");
        }
        let data: Vec<f64> = self.data.iter()
            .zip(other.data.iter())
            .map(|(a, b)| a * b)
            .collect();
        NDArray::new(data, self.shape.clone())
    }
}

impl<'a, 'b> Div<&'b NDArray> for &'a NDArray {
    type Output = NDArray;

    fn div(self, other: &'b NDArray) -> NDArray {
        if self.shape != other.shape {
            panic!("Shapes must match for element-wise division");
        }
        let data: Vec<f64> = self.data.iter()
            .zip(other.data.iter())
            .map(|(a, b)| a / b)
            .collect();
        NDArray::new(data, self.shape.clone())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_ndarray() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = vec![2, 2];
        let array = NDArray::new(data.clone(), shape.clone());
        assert_eq!(array.data(), &data);
        assert_eq!(array.shape(), &shape);
    }

    #[test]
    fn test_from_vec() {
        let data = vec![1.0, 2.0, 3.0];
        let array = NDArray::from_vec(data.clone());
        assert_eq!(array.data(), &data);
        assert_eq!(array.shape(), &[3]);
    }

    #[test]
    fn test_arange() {
        let array = NDArray::arange(0.0, 5.0, 1.0);
        assert_eq!(array.data(), &[0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_element_wise_addition() {
        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
        let sum = arr1.clone() + arr2;
        assert_eq!(sum.data(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_scalar_multiplication() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let scaled = arr.clone() * 2.0;
        assert_eq!(scaled.data(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_reshape() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let reshaped = arr.reshape(&[2, 3])
            .expect("Failed to reshape array to valid dimensions");
        assert_eq!(reshaped.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_element_wise_subtraction() {
        let arr1 = NDArray::from_vec(vec![5.0, 7.0, 9.0]);
        let arr2 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let diff = arr1 - arr2;
        assert_eq!(diff.data(), &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_scalar_addition() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = arr + 1.0;
        assert_eq!(result.data(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_combined_operations() {
        let X = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let theta_1 = 2.0;
        let theta_0 = 1.0;
        let predictions = X.clone() * theta_1 + theta_0;
        assert_eq!(predictions.data(), &[3.0, 5.0, 7.0]);
    }

    #[test]
    fn test_one_hot_encode() {
        let labels = NDArray::from_vec(vec![0.0, 1.0, 2.0, 1.0, 0.0]);
        let one_hot = NDArray::one_hot_encode(&labels);

        let expected = vec![
            1.0, 0.0, 0.0,
            0.0, 1.0, 0.0,
            0.0, 0.0, 1.0,
            0.0, 1.0, 0.0,
            1.0, 0.0, 0.0,
        ];

        assert_eq!(one_hot.shape(), &[5, 3]);
        assert_eq!(one_hot.data(), &expected);
    }

    #[test]
    fn test_one_hot_encode_negative() {
        let labels = NDArray::from_vec(vec![-1.0, 0.0, 1.0, 0.0, -1.0]);
        let one_hot = NDArray::one_hot_encode(&labels);

        let expected = vec![
            1.0, 0.0, 0.0,
            0.0, 1.0, 0.0,
            0.0, 0.0, 1.0,
            0.0, 1.0, 0.0,
            1.0, 0.0, 0.0,
        ];

        assert_eq!(one_hot.shape(), &[5, 3]);
        assert_eq!(one_hot.data(), &expected);
    }

    #[test]
    #[should_panic(expected = "All values must be integers for one-hot encoding")]
    fn test_one_hot_encode_non_integer() {
        let labels = NDArray::from_vec(vec![0.0, 1.5, 2.0]);
        NDArray::one_hot_encode(&labels);
    }

    #[test]
    fn test_transpose() {
        let arr = NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ]);
        let transposed = arr.transpose()
            .expect("Failed to transpose valid 2D array");
        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(transposed.data(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_dot() {
        let arr1 = NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ]);
        let arr2 = NDArray::from_matrix(vec![
            vec![7.0, 8.0],
            vec![9.0, 10.0],
            vec![11.0, 12.0],
        ]);
        let dot = arr1.dot(&arr2);
        assert_eq!(dot.data(), &[58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_multiply() {
        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
        let multiply = arr1.multiply(&arr2);
        assert_eq!(multiply.data(), &[4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_multiply_large() {
        let arr1 = NDArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let arr2 = NDArray::new(vec![2.0, 3.0, 4.0, 5.0], vec![2, 2]);
        let result = arr1.multiply(&arr2);
        assert_eq!(result.data(), &[2.0, 6.0, 12.0, 20.0]);

        let large_arr1 = NDArray::new(vec![1.0; 2000], vec![1000, 2]);
        let large_arr2 = NDArray::new(vec![2.0; 2000], vec![1000, 2]);
        let large_result = large_arr1.multiply(&large_arr2);
        assert_eq!(large_result.data(), &vec![2.0; 2000]);

        let arr3 = NDArray::new(vec![1.0, 2.0], vec![2, 1]);
        let arr4 = NDArray::new(vec![1.0, 2.0, 3.0], vec![3, 1]);
        let result = std::panic::catch_unwind(|| arr3.multiply(&arr4));
        assert!(result.is_err());
    }

    #[test]
    fn test_scalar_sub() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = arr.scalar_sub(1.0);
        assert_eq!(result.data(), &[0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_multiply_scalar() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = arr.multiply_scalar(2.0);
        assert_eq!(result.data(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_map() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = arr.map(|x| x * 2.0);
        assert_eq!(result.data(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clip() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = arr.clip(1.0, 2.0);
        assert_eq!(result.data(), &[1.0, 2.0, 2.0]);
    }

    #[test]
    fn test_divide() {
        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
        let divide = arr1.divide(&arr2);
        assert_eq!(divide.data(), &[0.25, 0.4, 0.5]);
    }

    #[test]
    fn test_divide_scalar() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = arr.divide_scalar(2.0);
        assert_eq!(result.data(), &[0.5, 1.0, 1.5]);
    }

    #[test]
    fn test_sum_axis() {
        let arr = NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ]);
        let result = arr.sum_axis(0);
        assert_eq!(result.data(), &[5.0, 7.0, 9.0]);
        assert_eq!(result.shape(), &[1, 3]);

        let result = arr.sum_axis(1);
        assert_eq!(result.data(), &[6.0, 15.0]);
        assert_eq!(result.shape(), &[2, 1]);
    }

    #[test]
    fn test_subtract() {
        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
        let subtract = arr1.subtract(&arr2);
        assert_eq!(subtract.data(), &[-3.0, -3.0, -3.0]);
    }

    #[test]
    fn test_add_scalar() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = arr.add_scalar(1.0);
        assert_eq!(result.data(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_zeros() {
        let shape = vec![2, 3];
        let zeros = NDArray::zeros(shape);
        assert_eq!(zeros.shape(), &[2, 3]);
        assert_eq!(zeros.data(), &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_log() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = arr.log();
        assert_eq!(result.data(), &[0.0, 0.6931471805599453, 1.0986122886681098]);
    }

    #[test]
    fn test_sum() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = arr.sum();
        assert_eq!(result, 6.0);
    }

    #[test]
    fn test_mean() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(arr.mean(), 2.5);
    }

    #[test]
    fn test_std() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        assert!((arr.std() - 1.118034).abs() < 1e-6);
    }

    #[test]
    fn test_min_axis() {
        let arr = NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 0.5, 6.0],
        ]);

        let min_axis_0 = arr.min_axis(0).unwrap();
        assert_eq!(min_axis_0.data(), &[1.0, 0.5, 3.0]);

        let min_axis_1 = arr.min_axis(1).unwrap();
        assert_eq!(min_axis_1.data(), &[1.0, 0.5]);
    }

    #[test]
    fn test_concatenate() {
        let arr1 = NDArray::from_matrix(vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
        ]);
        let arr2 = NDArray::from_matrix(vec![
            vec![5.0, 6.0],
            vec![7.0, 8.0],
        ]);

        let concat_0 = arr1.concatenate(&arr2, 0).unwrap();
        assert_eq!(concat_0.shape(), &[4, 2]);
        assert_eq!(concat_0.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let concat_1 = arr1.concatenate(&arr2, 1).unwrap();
        assert_eq!(concat_1.shape(), &[2, 4]);
        assert_eq!(concat_1.data(), &[1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]);
    }

    #[test]
    fn test_broadcast_addition() {
        let arr1 = NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ]);
        let arr2 = NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0],
        ]);
        let result = arr1 + &arr2;
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.data(), &[2.0, 4.0, 6.0, 5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_max() {
        let arr = NDArray::from_vec(vec![1.0, 5.0, 3.0, 2.0]);
        assert_eq!(arr.max(), 5.0);
    }

    #[test]
    fn test_sum_with_broadcasting() {
        let arr = NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ]);
        let sum_cols = arr.sum_axis(0);
        assert_eq!(sum_cols.shape(), &[1, 3]);
        assert_eq!(sum_cols.data(), &[5.0, 7.0, 9.0]);

        let result = arr + &sum_cols;
        assert_eq!(result.data(), &[6.0, 9.0, 12.0, 9.0, 12.0, 15.0]);
    }

    #[test]
    fn test_scalar_operations() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);

        let result1 = arr.clone() * 2.0;
        let result2 = 2.0 * &arr;
        assert_eq!(result1.data(), result2.data());

        let result3 = arr + 1.0;
        assert_eq!(result3.data(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    #[should_panic(expected = "Shapes must match for element-wise addition")]
    fn test_invalid_addition() {
        let arr1 = NDArray::from_matrix(vec![vec![1.0, 2.0]]);
        let arr2 = NDArray::from_matrix(vec![vec![1.0, 2.0, 3.0]]);
        let _result = arr1 + arr2;
    }

    #[test]
    fn test_chained_operations() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = (arr * 2.0 + 1.0).multiply_scalar(3.0);
        assert_eq!(result.data(), &[9.0, 15.0, 21.0]);
    }

    #[test]
    fn test_abs() {
        let arr = NDArray::from_vec(vec![-1.0, 2.0, -3.0]);
        let result = arr.abs();
        assert_eq!(result.data(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_power() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let result = arr.power(2.0);
        assert_eq!(result.data(), &[1.0, 4.0, 9.0]);
    }

    #[test]
    fn test_cumsum() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let result = arr.cumsum();
        assert_eq!(result.data(), &[1.0, 3.0, 6.0, 10.0]);
    }

    #[test]
    fn test_round() {
        let arr = NDArray::from_vec(vec![1.234, 2.345, 3.456]);
        let result = arr.round(2);
        assert_eq!(result.data(), &[1.23, 2.35, 3.46]);
    }

    #[test]
    fn test_argsort() {
        let arr = NDArray::from_vec(vec![3.0, 1.0, 2.0]);
        let indices = arr.argsort();
        assert_eq!(indices, vec![1, 2, 0]);
    }

    #[test]
    fn test_argmax() {
        let arr = NDArray::from_matrix(vec![
            vec![1.0, 3.0, 2.0],
            vec![4.0, 2.0, 6.0],
        ]);

        assert_eq!(arr.argmax(None), vec![5]);
        assert_eq!(arr.argmax(Some(0)), vec![1, 0, 1]);
        assert_eq!(arr.argmax(Some(1)), vec![1, 2]);
    }

    #[test]
    fn test_unique() {
        let arr = NDArray::from_vec(vec![3.0, 1.0, 2.0, 1.0, 3.0]);
        let unique = arr.unique();
        assert_eq!(unique.data(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_where_cond() {
        let arr = NDArray::from_vec(vec![-1.0, 2.0, -3.0, 4.0]);
        let result = arr.where_cond(|x| x > 0.0, 1.0, -1.0);
        assert_eq!(result.data(), &[-1.0, 1.0, -1.0, 1.0]);
    }

    #[test]
    fn test_median() {
        let arr1 = NDArray::from_vec(vec![1.0, 3.0, 2.0]);
        assert_eq!(arr1.median(), 2.0);

        let arr2 = NDArray::from_vec(vec![1.0, 3.0, 2.0, 4.0]);
        assert_eq!(arr2.median(), 2.5);
    }

    #[test]
    fn test_max_axis() {
        let arr = NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 0.5, 6.0],
        ]);

        let max_axis_0 = arr.max_axis(0);
        assert_eq!(max_axis_0.shape(), &[1, 3]);
        assert_eq!(max_axis_0.data(), &[4.0, 2.0, 6.0]);

        let max_axis_1 = arr.max_axis(1);
        assert_eq!(max_axis_1.shape(), &[2, 1]);
        assert_eq!(max_axis_1.data(), &[3.0, 6.0]);
    }

    #[test]
    fn test_element_wise_subtraction_ref() {
        let arr1 = NDArray::from_vec(vec![5.0, 7.0, 9.0]);
        let arr2 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let diff = &arr1 - &arr2;
        assert_eq!(diff.data(), &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_element_wise_addition_ref() {
        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
        let sum = &arr1 + &arr2;
        assert_eq!(sum.data(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_element_wise_multiplication_ref() {
        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
        let product = &arr1 * &arr2;
        assert_eq!(product.data(), &[4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_element_wise_division_ref() {
        let arr1 = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let arr2 = NDArray::from_vec(vec![4.0, 5.0, 6.0]);
        let quotient = &arr1 / &arr2;
        assert_eq!(quotient.data(), &[0.25, 0.4, 0.5]);
    }

    #[test]
    fn test_rand_uniform() {
        let shape = [2, 3];
        let arr = NDArray::rand_uniform(&shape);
        assert_eq!(arr.shape(), &[2, 3]);

        for &val in arr.data() {
            assert!(val >= 0.0 && val <= 1.0);
        }

        let arr2 = NDArray::rand_uniform(&shape);
        assert_ne!(arr.data(), arr2.data(), "Random arrays should be different");

        let large_arr = NDArray::rand_uniform(&[1000]);
        let mean = large_arr.mean();
        let std = large_arr.std();

        assert!((mean - 0.5).abs() < 0.1, "Mean should be approximately 0.5");
        assert!((std - 0.289).abs() < 0.1, "Std should be approximately 0.289");
    }

    #[test]
    fn test_statistical_functions() {
        let arr = NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ]);

        let mean_cols = arr.mean_axis(0);
        assert_eq!(mean_cols.shape(), &[1, 3]);
        assert_eq!(mean_cols.data(), &[2.5, 3.5, 4.5]);

        let var_cols = arr.var_axis(0);
        assert_eq!(var_cols.shape(), &[1, 3]);
        assert_eq!(var_cols.data(), &[2.25, 2.25, 2.25]);

        let sqrt = arr.sqrt();
        assert_eq!(sqrt.data(), &[1.0, 2.0_f64.sqrt(), 3.0_f64.sqrt(),
            2.0, 5.0_f64.sqrt(), 6.0_f64.sqrt()]);

        let added = arr.add_scalar(1.0);
        assert_eq!(added.data(), &[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    }

    #[test]
    fn test_to_categorical() {
        let labels = NDArray::from_vec(vec![0.0, 1.0, 2.0]);
        let categorical = labels.to_categorical(None);
        assert_eq!(categorical.shape(), &[3, 3]);
        assert_eq!(categorical.data(), &[
            1.0, 0.0, 0.0,
            0.0, 1.0, 0.0,
            0.0, 0.0, 1.0,
        ]);

        let labels = NDArray::from_vec(vec![0.0, 1.0]);
        let categorical = labels.to_categorical(Some(3));
        assert_eq!(categorical.shape(), &[2, 3]);
        assert_eq!(categorical.data(), &[
            1.0, 0.0, 0.0,
            0.0, 1.0, 0.0,
        ]);

        let labels = NDArray::from_vec(vec![-1.0, 0.0, 1.0]);
        let categorical = labels.to_categorical(Some(3));
        assert_eq!(categorical.shape(), &[3, 3]);
        assert_eq!(categorical.data(), &[
            1.0, 0.0, 0.0,
            0.0, 1.0, 0.0,
            0.0, 0.0, 1.0,
        ]);
    }
}