1use crate::{Array, DType, Shape};
6
7impl Array {
8 pub fn concatenate(arrays: &[Array], axis: usize) -> Array {
20 assert!(!arrays.is_empty(), "Need at least one array to concatenate");
21 assert_eq!(
22 arrays[0].dtype(),
23 DType::Float32,
24 "Only Float32 supported"
25 );
26
27 let first_shape = arrays[0].shape();
29 let ndim = first_shape.ndim();
30 assert!(axis < ndim, "Axis out of bounds");
31
32 for arr in arrays.iter().skip(1) {
33 assert_eq!(
34 arr.ndim(),
35 ndim,
36 "All arrays must have same number of dimensions"
37 );
38 for (i, (&dim1, &dim2)) in first_shape
39 .as_slice()
40 .iter()
41 .zip(arr.shape().as_slice().iter())
42 .enumerate()
43 {
44 if i != axis {
45 assert_eq!(
46 dim1, dim2,
47 "Dimensions must match except on concatenation axis"
48 );
49 }
50 }
51 }
52
53 let mut result_dims = first_shape.as_slice().to_vec();
55 result_dims[axis] =
56 arrays.iter().map(|a| a.shape().as_slice()[axis]).sum();
57 let result_shape = Shape::new(result_dims.clone());
58
59 if axis == 0 {
61 let mut data = Vec::new();
62 for arr in arrays {
63 data.extend(arr.to_vec());
64 }
65 Array::from_vec(data, result_shape)
66 } else {
67 let total_size: usize = result_dims.iter().product();
70 let mut result = vec![0.0f32; total_size];
71
72 let outer_size: usize = result_dims[..axis].iter().product();
74 let inner_size: usize = result_dims[axis + 1..].iter().product();
75
76 let mut offset_along_axis = 0;
77 for arr in arrays {
78 let arr_data = arr.to_vec();
79 let arr_shape = arr.shape().as_slice();
80 let arr_axis_size = arr_shape[axis];
81
82 for outer in 0..outer_size {
83 for ax in 0..arr_axis_size {
84 for inner in 0..inner_size {
85 let src_idx = outer * arr_axis_size * inner_size + ax * inner_size + inner;
86 let dst_ax = offset_along_axis + ax;
87 let dst_idx = outer * result_dims[axis] * inner_size + dst_ax * inner_size + inner;
88 result[dst_idx] = arr_data[src_idx];
89 }
90 }
91 }
92 offset_along_axis += arr_axis_size;
93 }
94
95 Array::from_vec(result, result_shape)
96 }
97 }
98
99 pub fn stack(arrays: &[Array], axis: usize) -> Array {
112 assert!(!arrays.is_empty(), "Need at least one array to stack");
113
114 let first_shape = arrays[0].shape();
116 for arr in arrays.iter().skip(1) {
117 assert_eq!(
118 arr.shape(),
119 first_shape,
120 "All arrays must have the same shape for stacking"
121 );
122 }
123
124 let ndim = first_shape.ndim();
125 assert!(axis <= ndim, "Axis out of bounds for stacking");
126
127 let mut result_dims = Vec::new();
129 for (i, &dim) in first_shape.as_slice().iter().enumerate() {
130 if i == axis {
131 result_dims.push(arrays.len());
132 }
133 result_dims.push(dim);
134 }
135 if axis == ndim {
136 result_dims.push(arrays.len());
137 }
138
139 if axis == 0 {
141 let mut data = Vec::new();
142 for arr in arrays {
143 data.extend(arr.to_vec());
144 }
145 let mut shape_dims = vec![arrays.len()];
146 shape_dims.extend_from_slice(first_shape.as_slice());
147 Array::from_vec(data, Shape::new(shape_dims))
148 } else {
149 panic!("stack only supports axis=0 for now");
150 }
151 }
152
153 pub fn split(array: &Array, num_sections: usize, axis: usize) -> Vec<Array> {
173 assert_eq!(array.dtype(), DType::Float32, "Only Float32 supported");
174 assert!(num_sections > 0, "Number of sections must be positive");
175
176 let shape = array.shape().as_slice();
177 assert!(axis < shape.len(), "Axis out of bounds");
178 let axis_size = shape[axis];
179 assert_eq!(
180 axis_size % num_sections,
181 0,
182 "Array size along axis must be divisible by number of sections"
183 );
184
185 let section_size = axis_size / num_sections;
186 let data = array.to_vec();
187
188 let mut result = Vec::with_capacity(num_sections);
189
190 if axis == 0 {
191 let elements_per_section = data.len() / num_sections;
193
194 for i in 0..num_sections {
195 let start = i * elements_per_section;
196 let end = start + elements_per_section;
197 let section_data = data[start..end].to_vec();
198
199 let mut section_shape = shape.to_vec();
200 section_shape[axis] = section_size;
201
202 result.push(Array::from_vec(section_data, Shape::new(section_shape)));
203 }
204 } else {
205 let outer_size: usize = shape[..axis].iter().product();
207 let inner_size: usize = shape[axis + 1..].iter().product();
208
209 for section_idx in 0..num_sections {
210 let mut section_data = Vec::with_capacity(outer_size * section_size * inner_size);
211
212 for outer in 0..outer_size {
213 for ax in 0..section_size {
214 let src_ax = section_idx * section_size + ax;
215 for inner in 0..inner_size {
216 let src_idx = outer * axis_size * inner_size + src_ax * inner_size + inner;
217 section_data.push(data[src_idx]);
218 }
219 }
220 }
221
222 let mut section_shape = shape.to_vec();
223 section_shape[axis] = section_size;
224
225 result.push(Array::from_vec(section_data, Shape::new(section_shape)));
226 }
227 }
228
229 result
230 }
231
232 pub fn where_cond(condition: &Array, x: &Array, y: &Array) -> Array {
254 assert_eq!(condition.dtype(), DType::Float32, "Only Float32 supported");
255 assert_eq!(x.dtype(), DType::Float32, "Only Float32 supported");
256 assert_eq!(y.dtype(), DType::Float32, "Only Float32 supported");
257
258 let shape1 = condition
260 .shape()
261 .broadcast_with(x.shape())
262 .expect("Condition and x shapes are not broadcast-compatible");
263 let result_shape = shape1
264 .broadcast_with(y.shape())
265 .expect("Cannot broadcast all three arrays to common shape");
266
267 let cond_data = condition.to_vec();
268 let x_data = x.to_vec();
269 let y_data = y.to_vec();
270
271 if condition.shape() == x.shape()
273 && x.shape() == y.shape()
274 && condition.shape() == &result_shape
275 {
276 let result_data: Vec<f32> = cond_data
277 .iter()
278 .zip(x_data.iter().zip(y_data.iter()))
279 .map(|(&c, (&x_val, &y_val))| if c != 0.0 { x_val } else { y_val })
280 .collect();
281 return Array::from_vec(result_data, result_shape);
282 }
283
284 let size = result_shape.size();
286 let result_data: Vec<f32> = (0..size)
287 .map(|i| {
288 let cond_idx =
289 crate::ops::binary::broadcast_index(i, &result_shape, condition.shape());
290 let x_idx = crate::ops::binary::broadcast_index(i, &result_shape, x.shape());
291 let y_idx = crate::ops::binary::broadcast_index(i, &result_shape, y.shape());
292
293 if cond_data[cond_idx] != 0.0 {
294 x_data[x_idx]
295 } else {
296 y_data[y_idx]
297 }
298 })
299 .collect();
300
301 Array::from_vec(result_data, result_shape)
302 }
303
304 pub fn select(indices: &Array, choices: &[Array]) -> Array {
326 assert_eq!(indices.dtype(), DType::Float32, "Only Float32 supported");
327 assert!(!choices.is_empty(), "Must provide at least one choice");
328
329 let choice_shape = choices[0].shape();
331 for choice in choices.iter().skip(1) {
332 assert_eq!(
333 choice.dtype(),
334 DType::Float32,
335 "Only Float32 supported for choices"
336 );
337 assert_eq!(
338 choice.shape(),
339 choice_shape,
340 "All choices must have the same shape"
341 );
342 }
343
344 assert_eq!(
346 indices.shape(),
347 choice_shape,
348 "Indices and choices must have the same shape"
349 );
350
351 let indices_data = indices.to_vec();
352 let choice_data: Vec<Vec<f32>> = choices.iter().map(|c| c.to_vec()).collect();
353
354 let result_data: Vec<f32> = indices_data
355 .iter()
356 .enumerate()
357 .map(|(i, &idx)| {
358 let idx_int = idx as usize;
359 assert!(
360 idx_int < choices.len(),
361 "Index {} out of bounds for {} choices",
362 idx_int,
363 choices.len()
364 );
365 choice_data[idx_int][i]
366 })
367 .collect();
368
369 Array::from_vec(result_data, choice_shape.clone())
370 }
371
372 pub fn clip(&self, min: f32, max: f32) -> Array {
383 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
384
385 let data = self.to_vec();
386 let result_data: Vec<f32> =
387 data.iter().map(|&x| x.clamp(min, max)).collect();
388
389 Array::from_vec(result_data, self.shape().clone())
390 }
391
392 pub fn flip(&self, axis: usize) -> Array {
403 assert!(axis < self.ndim(), "Axis out of bounds");
404 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
405
406 let shape = self.shape();
407 let dims = shape.as_slice();
408 let data = self.to_vec();
409
410 if self.ndim() == 1 {
412 let mut result: Vec<f32> = data.clone();
413 result.reverse();
414 return Array::from_vec(result, shape.clone());
415 }
416
417 if axis == 0 {
419 let slice_size = data.len() / dims[0];
420 let mut result = Vec::with_capacity(data.len());
421
422 for i in (0..dims[0]).rev() {
423 let start = i * slice_size;
424 let end = start + slice_size;
425 result.extend_from_slice(&data[start..end]);
426 }
427
428 Array::from_vec(result, shape.clone())
429 } else {
430 panic!("flip only supports axis=0 for multi-dimensional arrays");
431 }
432 }
433
434 pub fn pad(&self, pad_width: &[(usize, usize)], constant_value: f32) -> Array {
450 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
451 assert_eq!(
452 pad_width.len(),
453 self.ndim(),
454 "pad_width must match number of dimensions"
455 );
456
457 let shape = self.shape().as_slice();
458 let data = self.to_vec();
459
460 let mut out_shape = Vec::with_capacity(shape.len());
462 for (i, &dim) in shape.iter().enumerate() {
463 out_shape.push(pad_width[i].0 + dim + pad_width[i].1);
464 }
465
466 if self.ndim() == 1 {
468 let (before, after) = pad_width[0];
469 let mut result = vec![constant_value; before];
470 result.extend_from_slice(&data);
471 result.extend(vec![constant_value; after]);
472 return Array::from_vec(result, Shape::new(out_shape));
473 }
474
475 if self.ndim() == 2 {
477 let (h, w) = (shape[0], shape[1]);
478 let (h_before, _) = pad_width[0];
479 let (w_before, _) = pad_width[1];
480
481 let out_h = out_shape[0];
482 let out_w = out_shape[1];
483 let mut result = vec![constant_value; out_h * out_w];
484
485 for i in 0..h {
486 for j in 0..w {
487 let out_i = i + h_before;
488 let out_j = j + w_before;
489 result[out_i * out_w + out_j] = data[i * w + j];
490 }
491 }
492
493 return Array::from_vec(result, Shape::new(out_shape));
494 }
495
496 panic!("pad only supports 1D and 2D arrays for now");
497 }
498
499 pub fn pad_edge(&self, pad_width: &[(usize, usize)]) -> Array {
510 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
511 assert_eq!(
512 pad_width.len(),
513 self.ndim(),
514 "pad_width must match number of dimensions"
515 );
516
517 let shape = self.shape().as_slice();
518 let data = self.to_vec();
519
520 if self.ndim() == 1 {
522 let (before, after) = pad_width[0];
523 let mut result = vec![data[0]; before];
524 result.extend_from_slice(&data);
525 result.extend(vec![data[data.len() - 1]; after]);
526
527 let out_len = before + shape[0] + after;
528 return Array::from_vec(result, Shape::new(vec![out_len]));
529 }
530
531 if self.ndim() == 2 {
533 let (h, w) = (shape[0], shape[1]);
534 let (h_before, h_after) = pad_width[0];
535 let (w_before, w_after) = pad_width[1];
536
537 let out_h = h_before + h + h_after;
538 let out_w = w_before + w + w_after;
539 let mut result = vec![0.0; out_h * out_w];
540
541 for out_i in 0..out_h {
542 for out_j in 0..out_w {
543 let in_i = if out_i < h_before {
545 0
546 } else if out_i >= h_before + h {
547 h - 1
548 } else {
549 out_i - h_before
550 };
551
552 let in_j = if out_j < w_before {
553 0
554 } else if out_j >= w_before + w {
555 w - 1
556 } else {
557 out_j - w_before
558 };
559
560 result[out_i * out_w + out_j] = data[in_i * w + in_j];
561 }
562 }
563
564 return Array::from_vec(result, Shape::new(vec![out_h, out_w]));
565 }
566
567 panic!("pad_edge only supports 1D and 2D arrays for now");
568 }
569
570 pub fn pad_reflect(&self, pad_width: &[(usize, usize)]) -> Array {
581 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
582 assert_eq!(
583 pad_width.len(),
584 self.ndim(),
585 "pad_width must match number of dimensions"
586 );
587
588 let shape = self.shape().as_slice();
589 let data = self.to_vec();
590
591 if self.ndim() == 1 {
593 let len = shape[0];
594 let (before, after) = pad_width[0];
595
596 assert!(
597 before < len && after < len,
598 "Padding width must be less than array size for reflect mode"
599 );
600
601 let mut result = Vec::with_capacity(before + len + after);
602
603 for i in 0..before {
605 result.push(data[before - i]);
606 }
607
608 result.extend_from_slice(&data);
610
611 for i in 0..after {
613 result.push(data[len - 2 - i]);
614 }
615
616 let out_len = before + len + after;
617 return Array::from_vec(result, Shape::new(vec![out_len]));
618 }
619
620 if self.ndim() == 2 {
622 let (h, w) = (shape[0], shape[1]);
623 let (h_before, h_after) = pad_width[0];
624 let (w_before, w_after) = pad_width[1];
625
626 assert!(
627 h_before < h && h_after < h && w_before < w && w_after < w,
628 "Padding width must be less than array size for reflect mode"
629 );
630
631 let out_h = h_before + h + h_after;
632 let out_w = w_before + w + w_after;
633 let mut result = vec![0.0; out_h * out_w];
634
635 for out_i in 0..out_h {
636 for out_j in 0..out_w {
637 let in_i = if out_i < h_before {
639 h_before - out_i
640 } else if out_i >= h_before + h {
641 h - 2 - (out_i - h_before - h)
642 } else {
643 out_i - h_before
644 };
645
646 let in_j = if out_j < w_before {
647 w_before - out_j
648 } else if out_j >= w_before + w {
649 w - 2 - (out_j - w_before - w)
650 } else {
651 out_j - w_before
652 };
653
654 result[out_i * out_w + out_j] = data[in_i * w + in_j];
655 }
656 }
657
658 return Array::from_vec(result, Shape::new(vec![out_h, out_w]));
659 }
660
661 panic!("pad_reflect only supports 1D and 2D arrays for now");
662 }
663
664 pub fn nan_to_num(&self, nan: f32, posinf: f32, neginf: f32) -> Array {
684 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
685 let data = self.to_vec();
686 let result: Vec<f32> = data
687 .iter()
688 .map(|&x| {
689 if x.is_nan() {
690 nan
691 } else if x.is_infinite() && x > 0.0 {
692 posinf
693 } else if x.is_infinite() && x < 0.0 {
694 neginf
695 } else {
696 x
697 }
698 })
699 .collect();
700 Array::from_vec(result, self.shape().clone())
701 }
702
703 pub fn isnan(&self) -> Array {
716 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
717 let data = self.to_vec();
718 let result: Vec<f32> = data
719 .iter()
720 .map(|&x| if x.is_nan() { 1.0 } else { 0.0 })
721 .collect();
722 Array::from_vec(result, self.shape().clone())
723 }
724
725 pub fn isinf(&self) -> Array {
738 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
739 let data = self.to_vec();
740 let result: Vec<f32> = data
741 .iter()
742 .map(|&x| if x.is_infinite() { 1.0 } else { 0.0 })
743 .collect();
744 Array::from_vec(result, self.shape().clone())
745 }
746
747 pub fn isfinite(&self) -> Array {
760 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
761 let data = self.to_vec();
762 let result: Vec<f32> = data
763 .iter()
764 .map(|&x| if x.is_finite() { 1.0 } else { 0.0 })
765 .collect();
766 Array::from_vec(result, self.shape().clone())
767 }
768
769 pub fn clip_by_norm(&self, max_norm: f32) -> Array {
786 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
787 let data = self.to_vec();
788
789 let norm: f32 = data.iter().map(|&x| x * x).sum::<f32>().sqrt();
791
792 if norm <= max_norm {
793 return self.clone();
794 }
795
796 let scale = max_norm / norm;
798 let result: Vec<f32> = data.iter().map(|&x| x * scale).collect();
799 Array::from_vec(result, self.shape().clone())
800 }
801
802 pub fn ravel(&self) -> Array {
816 let size = self.size();
817 Array::from_vec(self.to_vec(), Shape::new(vec![size]))
818 }
819
820 pub fn flatten(&self) -> Array {
833 self.ravel()
834 }
835
836 pub fn atleast_1d(&self) -> Array {
849 if self.shape().ndim() == 0 {
850 Array::from_vec(self.to_vec(), Shape::new(vec![1]))
851 } else {
852 self.clone()
853 }
854 }
855
856 pub fn atleast_2d(&self) -> Array {
869 match self.shape().ndim() {
870 0 => Array::from_vec(self.to_vec(), Shape::new(vec![1, 1])),
871 1 => {
872 let n = self.shape().as_slice()[0];
873 Array::from_vec(self.to_vec(), Shape::new(vec![1, n]))
874 }
875 _ => self.clone(),
876 }
877 }
878
879 pub fn atleast_3d(&self) -> Array {
892 match self.shape().ndim() {
893 0 => Array::from_vec(self.to_vec(), Shape::new(vec![1, 1, 1])),
894 1 => {
895 let n = self.shape().as_slice()[0];
896 Array::from_vec(self.to_vec(), Shape::new(vec![1, n, 1]))
897 }
898 2 => {
899 let dims = self.shape().as_slice();
900 Array::from_vec(self.to_vec(), Shape::new(vec![dims[0], dims[1], 1]))
901 }
902 _ => self.clone(),
903 }
904 }
905
906 pub fn broadcast_to(&self, new_shape: Shape) -> Array {
920 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
921
922 let _result_shape = self
924 .shape()
925 .broadcast_with(&new_shape)
926 .unwrap_or_else(|| {
927 panic!(
928 "Cannot broadcast array of shape {:?} to shape {:?}. \
929 Broadcasting requires dimensions to be equal or one of them to be 1.",
930 self.shape().as_slice(),
931 new_shape.as_slice()
932 )
933 });
934
935 let data = self.to_vec();
936 let size = new_shape.size();
937 let mut result = Vec::with_capacity(size);
938
939 for i in 0..size {
940 let src_idx =
941 crate::ops::binary::broadcast_index(i, &new_shape, self.shape());
942 result.push(data[src_idx]);
943 }
944
945 Array::from_vec(result, new_shape)
946 }
947
948 pub fn broadcast_arrays(arrays: &[Array]) -> Vec<Array> {
974 if arrays.is_empty() {
975 return vec![];
976 }
977
978 if arrays.len() == 1 {
979 return vec![arrays[0].clone()];
980 }
981
982 let mut common_shape = arrays[0].shape().clone();
984
985 for array in &arrays[1..] {
986 common_shape = common_shape
987 .broadcast_with(array.shape())
988 .unwrap_or_else(|| {
989 panic!(
990 "Cannot broadcast arrays with shapes {:?} and {:?}. \
991 Broadcasting requires dimensions to be equal or one of them to be 1.",
992 common_shape.as_slice(),
993 array.shape().as_slice()
994 )
995 });
996 }
997
998 arrays
1000 .iter()
1001 .map(|arr| arr.broadcast_to(common_shape.clone()))
1002 .collect()
1003 }
1004
1005 pub fn take(&self, indices: &[usize]) -> Array {
1017 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1018 let data = self.to_vec();
1019
1020 let result: Vec<f32> = indices
1021 .iter()
1022 .map(|&idx| {
1023 assert!(idx < data.len(), "Index {} out of bounds", idx);
1024 data[idx]
1025 })
1026 .collect();
1027
1028 let len = result.len();
1029 Array::from_vec(result, Shape::new(vec![len]))
1030 }
1031
1032 pub fn put(&self, indices: &[usize], values: &[f32]) -> Array {
1051 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1052 assert_eq!(
1053 indices.len(),
1054 values.len(),
1055 "Number of indices must match number of values"
1056 );
1057
1058 let mut data = self.to_vec();
1059
1060 for (i, &idx) in indices.iter().enumerate() {
1061 assert!(idx < data.len(), "Index {} out of bounds", idx);
1062 data[idx] = values[i];
1063 }
1064
1065 Array::from_vec(data, self.shape().clone())
1066 }
1067
1068 pub fn scatter(&self, indices: &[usize], updates: &[f32]) -> Array {
1088 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1089 assert_eq!(
1090 indices.len(),
1091 updates.len(),
1092 "Number of indices must match number of updates"
1093 );
1094
1095 let mut data = self.to_vec();
1096
1097 for (i, &idx) in indices.iter().enumerate() {
1098 assert!(idx < data.len(), "Index {} out of bounds", idx);
1099 data[idx] = updates[i];
1100 }
1101
1102 Array::from_vec(data, self.shape().clone())
1103 }
1104
1105 pub fn scatter_add(&self, indices: &[usize], updates: &[f32]) -> Array {
1125 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1126 assert_eq!(
1127 indices.len(),
1128 updates.len(),
1129 "Number of indices must match number of updates"
1130 );
1131
1132 let mut data = self.to_vec();
1133
1134 for (i, &idx) in indices.iter().enumerate() {
1135 assert!(idx < data.len(), "Index {} out of bounds", idx);
1136 data[idx] += updates[i];
1137 }
1138
1139 Array::from_vec(data, self.shape().clone())
1140 }
1141
1142 pub fn scatter_min(&self, indices: &[usize], updates: &[f32]) -> Array {
1161 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1162 assert_eq!(
1163 indices.len(),
1164 updates.len(),
1165 "Number of indices must match number of updates"
1166 );
1167
1168 let mut data = self.to_vec();
1169
1170 for (i, &idx) in indices.iter().enumerate() {
1171 assert!(idx < data.len(), "Index {} out of bounds", idx);
1172 data[idx] = data[idx].min(updates[i]);
1173 }
1174
1175 Array::from_vec(data, self.shape().clone())
1176 }
1177
1178 pub fn scatter_max(&self, indices: &[usize], updates: &[f32]) -> Array {
1197 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1198 assert_eq!(
1199 indices.len(),
1200 updates.len(),
1201 "Number of indices must match number of updates"
1202 );
1203
1204 let mut data = self.to_vec();
1205
1206 for (i, &idx) in indices.iter().enumerate() {
1207 assert!(idx < data.len(), "Index {} out of bounds", idx);
1208 data[idx] = data[idx].max(updates[i]);
1209 }
1210
1211 Array::from_vec(data, self.shape().clone())
1212 }
1213
1214 pub fn scatter_mul(&self, indices: &[usize], updates: &[f32]) -> Array {
1233 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1234 assert_eq!(
1235 indices.len(),
1236 updates.len(),
1237 "Number of indices must match number of updates"
1238 );
1239
1240 let mut data = self.to_vec();
1241
1242 for (i, &idx) in indices.iter().enumerate() {
1243 assert!(idx < data.len(), "Index {} out of bounds", idx);
1244 data[idx] *= updates[i];
1245 }
1246
1247 Array::from_vec(data, self.shape().clone())
1248 }
1249
1250 pub fn take_along_axis(&self, indices: &Array, axis: usize) -> Array {
1276 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1277 assert_eq!(indices.dtype(), DType::Float32, "Indices must be Float32");
1278 assert!(axis < self.ndim(), "Axis {} out of bounds", axis);
1279
1280 let data = self.to_vec();
1281 let idx_data = indices.to_vec();
1282 let shape = self.shape().as_slice();
1283
1284 if indices.ndim() == 1 && self.ndim() == 2 {
1286 if axis == 1 {
1287 let rows = shape[0];
1289 let cols = shape[1];
1290 assert_eq!(
1291 indices.size(),
1292 rows,
1293 "Indices size must match number of rows"
1294 );
1295 let result: Vec<f32> = idx_data
1296 .iter()
1297 .enumerate()
1298 .map(|(row, &idx)| {
1299 let col = idx as usize;
1300 assert!(col < cols, "Column index {} out of bounds", col);
1301 data[row * cols + col]
1302 })
1303 .collect();
1304 return Array::from_vec(result, Shape::new(vec![rows]));
1305 } else {
1306 let rows = shape[0];
1308 let cols = shape[1];
1309 assert_eq!(
1310 indices.size(),
1311 cols,
1312 "Indices size must match number of columns"
1313 );
1314 let result: Vec<f32> = idx_data
1315 .iter()
1316 .enumerate()
1317 .map(|(col, &idx)| {
1318 let row = idx as usize;
1319 assert!(row < rows, "Row index {} out of bounds", row);
1320 data[row * cols + col]
1321 })
1322 .collect();
1323 return Array::from_vec(result, Shape::new(vec![cols]));
1324 }
1325 }
1326
1327 if self.ndim() == 1 {
1329 let result: Vec<f32> = idx_data
1330 .iter()
1331 .map(|&idx| {
1332 let i = idx as usize;
1333 assert!(i < data.len(), "Index {} out of bounds", i);
1334 data[i]
1335 })
1336 .collect();
1337 return Array::from_vec(result, indices.shape().clone());
1338 }
1339
1340 let idx_shape = indices.shape().as_slice();
1342 assert_eq!(
1343 self.ndim(),
1344 indices.ndim(),
1345 "For N-dimensional take_along_axis, array and indices must have same number of dimensions"
1346 );
1347 for (i, (&s, &is)) in shape.iter().zip(idx_shape.iter()).enumerate() {
1348 if i != axis {
1349 assert_eq!(
1350 s, is,
1351 "Dimension {} must match: array has {}, indices has {}",
1352 i, s, is
1353 );
1354 }
1355 }
1356
1357 let ndim = self.ndim();
1358 let out_size = indices.size();
1359 let mut result = vec![0.0f32; out_size];
1360
1361 let mut strides = vec![1usize; ndim];
1363 for i in (0..ndim - 1).rev() {
1364 strides[i] = strides[i + 1] * shape[i + 1];
1365 }
1366
1367 let mut idx_strides = vec![1usize; ndim];
1369 for i in (0..ndim - 1).rev() {
1370 idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
1371 }
1372
1373 for out_flat in 0..out_size {
1375 let mut multi_idx = vec![0usize; ndim];
1377 let mut remaining = out_flat;
1378 for i in 0..ndim {
1379 multi_idx[i] = remaining / idx_strides[i];
1380 remaining %= idx_strides[i];
1381 }
1382
1383 let idx_val = idx_data[out_flat] as usize;
1385 assert!(
1386 idx_val < shape[axis],
1387 "Index {} out of bounds for axis {} with size {}",
1388 idx_val,
1389 axis,
1390 shape[axis]
1391 );
1392
1393 let mut input_idx = multi_idx.clone();
1395 input_idx[axis] = idx_val;
1396
1397 let in_flat: usize = input_idx
1399 .iter()
1400 .zip(strides.iter())
1401 .map(|(i, s)| i * s)
1402 .sum();
1403
1404 result[out_flat] = data[in_flat];
1405 }
1406
1407 Array::from_vec(result, indices.shape().clone())
1408 }
1409
1410 pub fn nonzero(&self) -> Vec<usize> {
1423 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1424 let data = self.to_vec();
1425
1426 data.iter()
1427 .enumerate()
1428 .filter(|(_, &val)| val != 0.0)
1429 .map(|(idx, _)| idx)
1430 .collect()
1431 }
1432
1433 pub fn argwhere(&self) -> Vec<usize> {
1447 self.nonzero()
1448 }
1449
1450 pub fn compress(&self, condition: &Array) -> Array {
1464 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1465 assert_eq!(
1466 condition.dtype(),
1467 DType::Float32,
1468 "Only Float32 supported"
1469 );
1470 assert_eq!(
1471 self.size(),
1472 condition.size(),
1473 "Array and condition must have same size"
1474 );
1475
1476 let data = self.to_vec();
1477 let cond_data = condition.to_vec();
1478
1479 let result: Vec<f32> = data
1480 .iter()
1481 .zip(cond_data.iter())
1482 .filter(|(_, &c)| c != 0.0)
1483 .map(|(&val, _)| val)
1484 .collect();
1485
1486 let len = result.len();
1487 Array::from_vec(result, Shape::new(vec![len]))
1488 }
1489
1490 pub fn choose(indices: &[usize], choices: &[Array]) -> Array {
1507 assert!(!choices.is_empty(), "Must provide at least one choice");
1508 let size = choices[0].size();
1509
1510 for choice in choices.iter() {
1511 assert_eq!(
1512 choice.size(),
1513 size,
1514 "All choices must have the same size"
1515 );
1516 }
1517
1518 assert_eq!(
1519 indices.len(),
1520 size,
1521 "Indices must have same length as choices"
1522 );
1523
1524 let choice_data: Vec<Vec<f32>> =
1525 choices.iter().map(|c| c.to_vec()).collect();
1526
1527 let result: Vec<f32> = (0..size)
1528 .map(|i| {
1529 let choice_idx = indices[i];
1530 assert!(
1531 choice_idx < choices.len(),
1532 "Index {} out of bounds",
1533 choice_idx
1534 );
1535 choice_data[choice_idx][i]
1536 })
1537 .collect();
1538
1539 Array::from_vec(result, choices[0].shape().clone())
1540 }
1541
1542 pub fn extract(&self, condition: &Array) -> Array {
1556 self.compress(condition)
1557 }
1558
1559 pub fn roll(&self, shift: isize) -> Array {
1572 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1573 let data = self.to_vec();
1574 let len = data.len();
1575
1576 if len == 0 {
1577 return self.clone();
1578 }
1579
1580 let shift = ((shift % len as isize) + len as isize) as usize % len;
1582
1583 let mut result = vec![0.0; len];
1584 for i in 0..len {
1585 result[(i + shift) % len] = data[i];
1586 }
1587
1588 Array::from_vec(result, self.shape().clone())
1589 }
1590
1591 pub fn rot90(&self, k: isize) -> Array {
1604 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1605 assert_eq!(self.shape().ndim(), 2, "Only 2D arrays supported");
1606
1607 let shape = self.shape().as_slice();
1608 let (h, w) = (shape[0], shape[1]);
1609 let data = self.to_vec();
1610
1611 let k = k.rem_euclid(4);
1613
1614 match k {
1615 0 => self.clone(),
1616 1 => {
1617 let mut result = vec![0.0; h * w];
1619 for i in 0..h {
1620 for j in 0..w {
1621 let new_i = w - 1 - j;
1622 let new_j = i;
1623 result[new_i * h + new_j] = data[i * w + j];
1624 }
1625 }
1626 Array::from_vec(result, Shape::new(vec![w, h]))
1627 }
1628 2 => {
1629 let mut result = vec![0.0; h * w];
1631 for i in 0..h {
1632 for j in 0..w {
1633 let new_i = h - 1 - i;
1634 let new_j = w - 1 - j;
1635 result[new_i * w + new_j] = data[i * w + j];
1636 }
1637 }
1638 Array::from_vec(result, Shape::new(vec![h, w]))
1639 }
1640 3 => {
1641 let mut result = vec![0.0; h * w];
1643 for i in 0..h {
1644 for j in 0..w {
1645 let new_i = j;
1646 let new_j = h - 1 - i;
1647 result[new_i * h + new_j] = data[i * w + j];
1648 }
1649 }
1650 Array::from_vec(result, Shape::new(vec![w, h]))
1651 }
1652 _ => unreachable!(),
1653 }
1654 }
1655
1656 pub fn swapaxes(&self, axis1: usize, axis2: usize) -> Array {
1667 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1668 let ndim = self.shape().ndim();
1669 assert!(axis1 < ndim, "axis1 out of bounds");
1670 assert!(axis2 < ndim, "axis2 out of bounds");
1671
1672 if axis1 == axis2 {
1673 return self.clone();
1674 }
1675
1676 if ndim == 2 && ((axis1 == 0 && axis2 == 1) || (axis1 == 1 && axis2 == 0)) {
1678 return self.transpose();
1679 }
1680
1681 let old_shape = self.shape().as_slice();
1683 let mut new_shape = old_shape.to_vec();
1684 new_shape.swap(axis1, axis2);
1685
1686 let data = self.to_vec();
1687 let size = self.size();
1688 let mut result = vec![0.0; size];
1689
1690 let old_strides = self.shape().default_strides();
1692 let mut new_strides = vec![1; ndim];
1693 for i in (0..ndim - 1).rev() {
1694 new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
1695 }
1696
1697 for i in 0..size {
1699 let mut old_indices = vec![0; ndim];
1700 let mut temp = i;
1701 for j in 0..ndim {
1702 old_indices[j] = temp / old_strides[j];
1703 temp %= old_strides[j];
1704 }
1705
1706 old_indices.swap(axis1, axis2);
1708
1709 let mut new_idx = 0;
1711 for j in 0..ndim {
1712 new_idx += old_indices[j] * new_strides[j];
1713 }
1714
1715 result[new_idx] = data[i];
1716 }
1717
1718 Array::from_vec(result, Shape::new(new_shape))
1719 }
1720
1721 pub fn moveaxis(&self, source: usize, destination: usize) -> Array {
1734 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1735 let ndim = self.shape().ndim();
1736 assert!(source < ndim, "source axis out of bounds");
1737 assert!(destination < ndim, "destination axis out of bounds");
1738
1739 if source == destination {
1740 return self.clone();
1741 }
1742
1743 let old_shape = self.shape().as_slice();
1744 let mut new_shape = Vec::new();
1745
1746 for (i, &dim) in old_shape.iter().enumerate() {
1748 if i != source {
1749 new_shape.push(dim);
1750 }
1751 }
1752 new_shape.insert(destination, old_shape[source]);
1753
1754 if ndim == 2 {
1756 return self.swapaxes(source, destination);
1757 }
1758
1759 if ndim == 3 {
1761 let data = self.to_vec();
1762 let size = self.size();
1763 let mut result = vec![0.0; size];
1764
1765 let old_strides = self.shape().default_strides();
1766 let mut new_strides = vec![1; ndim];
1767 for i in (0..ndim - 1).rev() {
1768 new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
1769 }
1770
1771 for i in 0..size {
1772 let mut old_indices = vec![0; ndim];
1773 let mut temp = i;
1774 for j in 0..ndim {
1775 old_indices[j] = temp / old_strides[j];
1776 temp %= old_strides[j];
1777 }
1778
1779 let moved_val = old_indices[source];
1781 old_indices.remove(source);
1782 old_indices.insert(destination, moved_val);
1783
1784 let mut new_idx = 0;
1786 for j in 0..ndim {
1787 new_idx += old_indices[j] * new_strides[j];
1788 }
1789
1790 result[new_idx] = data[i];
1791 }
1792
1793 return Array::from_vec(result, Shape::new(new_shape));
1794 }
1795
1796 self.swapaxes(source, destination)
1798 }
1799
1800 pub fn interp(x: &Array, xp: &Array, fp: &Array) -> Array {
1821 assert_eq!(x.dtype(), DType::Float32, "Only Float32 supported");
1822 assert_eq!(xp.dtype(), DType::Float32, "Only Float32 supported");
1823 assert_eq!(fp.dtype(), DType::Float32, "Only Float32 supported");
1824 assert_eq!(
1825 xp.size(),
1826 fp.size(),
1827 "xp and fp must have the same size"
1828 );
1829
1830 let x_data = x.to_vec();
1831 let xp_data = xp.to_vec();
1832 let fp_data = fp.to_vec();
1833
1834 let result: Vec<f32> = x_data
1835 .iter()
1836 .map(|&xi| {
1837 if xi <= xp_data[0] {
1839 return fp_data[0];
1840 }
1841 if xi >= xp_data[xp_data.len() - 1] {
1842 return fp_data[fp_data.len() - 1];
1843 }
1844
1845 for i in 0..xp_data.len() - 1 {
1847 if xi >= xp_data[i] && xi <= xp_data[i + 1] {
1848 let t = (xi - xp_data[i]) / (xp_data[i + 1] - xp_data[i]);
1850 return fp_data[i] + t * (fp_data[i + 1] - fp_data[i]);
1851 }
1852 }
1853
1854 fp_data[fp_data.len() - 1]
1855 })
1856 .collect();
1857
1858 Array::from_vec(result, x.shape().clone())
1859 }
1860
1861 pub fn lerp(&self, other: &Array, weight: f32) -> Array {
1875 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1876 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
1877 assert_eq!(
1878 self.shape(),
1879 other.shape(),
1880 "Arrays must have the same shape"
1881 );
1882
1883 let self_data = self.to_vec();
1884 let other_data = other.to_vec();
1885
1886 let result: Vec<f32> = self_data
1887 .iter()
1888 .zip(other_data.iter())
1889 .map(|(&a, &b)| a + weight * (b - a))
1890 .collect();
1891
1892 Array::from_vec(result, self.shape().clone())
1893 }
1894
1895 pub fn lerp_array(&self, other: &Array, weights: &Array) -> Array {
1910 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1911 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
1912 assert_eq!(weights.dtype(), DType::Float32, "Only Float32 supported");
1913 assert_eq!(
1914 self.shape(),
1915 other.shape(),
1916 "Arrays must have the same shape"
1917 );
1918 assert_eq!(
1919 self.shape(),
1920 weights.shape(),
1921 "Arrays and weights must have the same shape"
1922 );
1923
1924 let self_data = self.to_vec();
1925 let other_data = other.to_vec();
1926 let weight_data = weights.to_vec();
1927
1928 let result: Vec<f32> = self_data
1929 .iter()
1930 .zip(other_data.iter())
1931 .zip(weight_data.iter())
1932 .map(|((&a, &b), &w)| a + w * (b - a))
1933 .collect();
1934
1935 Array::from_vec(result, self.shape().clone())
1936 }
1937
1938 pub fn convolve(&self, kernel: &Array) -> Array {
1956 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1957 assert_eq!(kernel.dtype(), DType::Float32, "Only Float32 supported");
1958 assert_eq!(self.ndim(), 1, "Convolve only supports 1D arrays");
1959 assert_eq!(kernel.ndim(), 1, "Kernel must be 1D");
1960
1961 let signal = self.to_vec();
1962 let mut k = kernel.to_vec();
1963 let n = signal.len();
1964 let m = k.len();
1965
1966 if m > n {
1967 return Array::zeros(Shape::new(vec![0]), DType::Float32);
1969 }
1970
1971 k.reverse();
1973
1974 let out_size = n - m + 1;
1976 let mut result = Vec::with_capacity(out_size);
1977
1978 for i in 0..out_size {
1979 let mut sum = 0.0;
1980 for j in 0..m {
1981 sum += signal[i + j] * k[j];
1982 }
1983 result.push(sum);
1984 }
1985
1986 Array::from_vec(result, Shape::new(vec![out_size]))
1987 }
1988
1989 pub fn correlate(&self, template: &Array) -> Array {
2006 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2007 assert_eq!(template.dtype(), DType::Float32, "Only Float32 supported");
2008 assert_eq!(self.ndim(), 1, "Correlate only supports 1D arrays");
2009 assert_eq!(template.ndim(), 1, "Template must be 1D");
2010
2011 let signal = self.to_vec();
2012 let t = template.to_vec();
2013 let n = signal.len();
2014 let m = t.len();
2015
2016 if m > n {
2017 return Array::zeros(Shape::new(vec![0]), DType::Float32);
2019 }
2020
2021 let out_size = n - m + 1;
2023 let mut result = Vec::with_capacity(out_size);
2024
2025 for i in 0..out_size {
2026 let mut sum = 0.0;
2027 for j in 0..m {
2028 sum += signal[i + j] * t[j];
2030 }
2031 result.push(sum);
2032 }
2033
2034 Array::from_vec(result, Shape::new(vec![out_size]))
2035 }
2036
2037 pub fn vstack(&self, other: &Array) -> Array {
2052 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2053 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2054
2055 let self_shape = self.shape().as_slice();
2056 let other_shape = other.shape().as_slice();
2057
2058 let self_2d = if self_shape.len() == 1 {
2060 self.reshape(Shape::new(vec![1, self_shape[0]]))
2061 } else {
2062 self.clone()
2063 };
2064
2065 let other_2d = if other_shape.len() == 1 {
2066 other.reshape(Shape::new(vec![1, other_shape[0]]))
2067 } else {
2068 other.clone()
2069 };
2070
2071 Array::concatenate(&[self_2d, other_2d], 0)
2073 }
2074
2075 pub fn hstack(&self, other: &Array) -> Array {
2090 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2091 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2092
2093 let self_shape = self.shape().as_slice();
2094 let other_shape = other.shape().as_slice();
2095
2096 if self_shape.len() == 1 && other_shape.len() == 1 {
2098 return Array::concatenate(&[self.clone(), other.clone()], 0);
2099 }
2100
2101 Array::concatenate(&[self.clone(), other.clone()], 1)
2103 }
2104
2105 pub fn vsplit(&self, num_sections: usize) -> Vec<Array> {
2118 let shape = self.shape().as_slice();
2119
2120 if shape.len() == 1 {
2121 return Array::split(self, num_sections, 0);
2123 }
2124
2125 Array::split(self, num_sections, 0)
2127 }
2128
2129 pub fn hsplit(&self, num_sections: usize) -> Vec<Array> {
2142 let shape = self.shape().as_slice();
2143 assert!(!shape.is_empty(), "hsplit requires at least 1D array");
2144
2145 if shape.len() == 1 {
2146 return Array::split(self, num_sections, 0);
2148 }
2149
2150 Array::split(self, num_sections, 1)
2152 }
2153
2154 pub fn append(&self, values: &Array) -> Array {
2166 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2167 assert_eq!(values.dtype(), DType::Float32, "Only Float32 supported");
2168
2169 let mut data = self.to_vec();
2170 data.extend(values.to_vec());
2171
2172 let new_size = data.len();
2173 Array::from_vec(data, Shape::new(vec![new_size]))
2174 }
2175
2176 pub fn insert(&self, index: usize, values: &Array) -> Array {
2188 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2189 assert_eq!(values.dtype(), DType::Float32, "Only Float32 supported");
2190
2191 let mut data = self.to_vec();
2192 let values_data = values.to_vec();
2193
2194 assert!(index <= data.len(), "Index out of bounds");
2195
2196 for (i, &val) in values_data.iter().enumerate() {
2198 data.insert(index + i, val);
2199 }
2200
2201 let new_size = data.len();
2202 Array::from_vec(data, Shape::new(vec![new_size]))
2203 }
2204
2205 pub fn delete(&self, indices: &[usize]) -> Array {
2216 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2217
2218 let data = self.to_vec();
2219 let mut result = Vec::new();
2220
2221 for (i, &val) in data.iter().enumerate() {
2222 if !indices.contains(&i) {
2223 result.push(val);
2224 }
2225 }
2226
2227 let new_size = result.len();
2228 Array::from_vec(result, Shape::new(vec![new_size]))
2229 }
2230
2231 pub fn trim_zeros(&self) -> Array {
2242 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2243
2244 let data = self.to_vec();
2245
2246 let start = data.iter().position(|&x| x.abs() > 1e-10).unwrap_or(data.len());
2248
2249 let end = data.iter().rposition(|&x| x.abs() > 1e-10).map(|i| i + 1).unwrap_or(0);
2251
2252 if start >= end {
2253 return Array::zeros(Shape::new(vec![0]), DType::Float32);
2254 }
2255
2256 let result = data[start..end].to_vec();
2257 let new_size = result.len();
2258 Array::from_vec(result, Shape::new(vec![new_size]))
2259 }
2260
2261 pub fn repeat_elements(&self, repeats: usize) -> Array {
2272 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2273
2274 let data = self.to_vec();
2275 let mut result = Vec::with_capacity(data.len() * repeats);
2276
2277 for &val in data.iter() {
2278 for _ in 0..repeats {
2279 result.push(val);
2280 }
2281 }
2282
2283 let new_size = result.len();
2284 Array::from_vec(result, Shape::new(vec![new_size]))
2285 }
2286
2287 pub fn resize(&self, new_size: usize) -> Array {
2298 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2299
2300 let data = self.to_vec();
2301 let mut result = Vec::with_capacity(new_size);
2302
2303 for i in 0..new_size {
2304 result.push(data[i % data.len()]);
2305 }
2306
2307 Array::from_vec(result, Shape::new(vec![new_size]))
2308 }
2309
2310 pub fn corrcoef(&self, other: &Array) -> f32 {
2322 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2323 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2324 assert_eq!(self.size(), other.size(), "Arrays must have same size");
2325
2326 let x = self.to_vec();
2327 let y = other.to_vec();
2328 let n = x.len() as f32;
2329
2330 let x_mean: f32 = x.iter().sum::<f32>() / n;
2332 let y_mean: f32 = y.iter().sum::<f32>() / n;
2333
2334 let mut cov = 0.0;
2336 let mut x_var = 0.0;
2337 let mut y_var = 0.0;
2338
2339 for (x_val, y_val) in x.iter().zip(y.iter()) {
2340 let x_diff = x_val - x_mean;
2341 let y_diff = y_val - y_mean;
2342 cov += x_diff * y_diff;
2343 x_var += x_diff * x_diff;
2344 y_var += y_diff * y_diff;
2345 }
2346
2347 if x_var.abs() < 1e-10 || y_var.abs() < 1e-10 {
2349 return 0.0;
2350 }
2351
2352 cov / (x_var * y_var).sqrt()
2353 }
2354
2355 pub fn flatnonzero(&self) -> Vec<usize> {
2366 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2367
2368 let data = self.to_vec();
2369 data.iter()
2370 .enumerate()
2371 .filter_map(|(i, &val)| if val.abs() > 1e-10 { Some(i) } else { None })
2372 .collect()
2373 }
2374
2375 pub fn tile_1d(&self, reps: usize) -> Array {
2386 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2387 let data = self.to_vec();
2388
2389 let mut result = Vec::with_capacity(data.len() * reps);
2390 for _ in 0..reps {
2391 result.extend_from_slice(&data);
2392 }
2393
2394 let new_size = result.len();
2395 Array::from_vec(result, Shape::new(vec![new_size]))
2396 }
2397
2398 pub fn column_stack(arrays: &[Array]) -> Array {
2411 assert!(!arrays.is_empty(), "Need at least one array");
2412 assert_eq!(arrays[0].dtype(), DType::Float32, "Only Float32 supported");
2413
2414 let n_rows = arrays[0].size();
2415 let n_cols = arrays.len();
2416
2417 let mut result = Vec::with_capacity(n_rows * n_cols);
2418 for row_idx in 0..n_rows {
2419 for arr in arrays {
2420 let data = arr.to_vec();
2421 result.push(data[row_idx]);
2422 }
2423 }
2424
2425 Array::from_vec(result, Shape::new(vec![n_rows, n_cols]))
2426 }
2427
2428 pub fn row_stack(arrays: &[Array]) -> Array {
2443 assert!(!arrays.is_empty(), "Need at least one array");
2444
2445 let arrays_2d: Vec<Array> = arrays.iter().map(|arr| {
2447 if arr.shape().as_slice().len() == 1 {
2448 let size = arr.size();
2449 let data = arr.to_vec();
2450 Array::from_vec(data, Shape::new(vec![1, size]))
2451 } else {
2452 arr.clone()
2453 }
2454 }).collect();
2455
2456 Array::concatenate(&arrays_2d, 0)
2457 }
2458
2459 pub fn dstack(arrays: &[Array]) -> Array {
2471 assert!(!arrays.is_empty(), "Need at least one array");
2472
2473 let arrays_3d: Vec<Array> = arrays.iter().map(|arr| {
2475 let shape = arr.shape().as_slice();
2476 match shape.len() {
2477 1 => {
2478 let size = arr.size();
2479 let data = arr.to_vec();
2480 Array::from_vec(data, Shape::new(vec![1, size, 1]))
2481 }
2482 2 => {
2483 let data = arr.to_vec();
2484 Array::from_vec(data, Shape::new(vec![shape[0], shape[1], 1]))
2485 }
2486 _ => arr.clone(),
2487 }
2488 }).collect();
2489
2490 Array::concatenate(&arrays_3d, 2)
2491 }
2492
    /// Alias for [`Array::abs`]: returns whatever `self.abs()` returns.
    pub fn absolute(&self) -> Array {
        self.abs()
    }
2506
    /// Alias for [`Array::clip`]: forwards `min` and `max` unchanged and
    /// returns whatever `self.clip(min, max)` returns.
    pub fn clamp(&self, min: f32, max: f32) -> Array {
        self.clip(min, max)
    }
2520
2521 pub fn fill_diagonal(&self, value: f32) -> Array {
2532 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2533
2534 let shape = self.shape();
2535 let dims = shape.as_slice();
2536 assert_eq!(dims.len(), 2, "fill_diagonal only supports 2D arrays");
2537
2538 let (rows, cols) = (dims[0], dims[1]);
2539 let data = self.to_vec();
2540 let mut result = data.clone();
2541
2542 let min_dim = rows.min(cols);
2543 for i in 0..min_dim {
2544 result[i * cols + i] = value;
2545 }
2546
2547 Array::from_vec(result, shape.clone())
2548 }
2549
2550 pub fn polyval(&self, coeffs: &Array) -> Array {
2567 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2568 assert_eq!(coeffs.dtype(), DType::Float32, "Only Float32 supported");
2569 assert_eq!(coeffs.ndim(), 1, "Coefficients must be 1D");
2570
2571 let x_data = self.to_vec();
2572 let c_data = coeffs.to_vec();
2573
2574 let result_data: Vec<f32> = x_data
2575 .iter()
2576 .map(|&x| {
2577 let mut result = 0.0;
2579 for &coeff in &c_data {
2580 result = result * x + coeff;
2581 }
2582 result
2583 })
2584 .collect();
2585
2586 Array::from_vec(result_data, self.shape().clone())
2587 }
2588
2589 pub fn polyadd(&self, other: &Array) -> Array {
2603 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2604 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2605 assert_eq!(self.ndim(), 1, "Polynomials must be 1D");
2606 assert_eq!(other.ndim(), 1, "Polynomials must be 1D");
2607
2608 let p_data = self.to_vec();
2609 let q_data = other.to_vec();
2610
2611 let max_len = p_data.len().max(q_data.len());
2612 let mut result = vec![0.0; max_len];
2613
2614 let p_offset = max_len - p_data.len();
2616 let q_offset = max_len - q_data.len();
2617
2618 for (i, &val) in p_data.iter().enumerate() {
2619 result[p_offset + i] += val;
2620 }
2621
2622 for (i, &val) in q_data.iter().enumerate() {
2623 result[q_offset + i] += val;
2624 }
2625
2626 Array::from_vec(result, Shape::new(vec![max_len]))
2627 }
2628
2629 pub fn polymul(&self, other: &Array) -> Array {
2643 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2644 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2645 assert_eq!(self.ndim(), 1, "Polynomials must be 1D");
2646 assert_eq!(other.ndim(), 1, "Polynomials must be 1D");
2647
2648 let p = self.to_vec();
2649 let q = other.to_vec();
2650 let result_len = p.len() + q.len() - 1;
2651 let mut result = vec![0.0; result_len];
2652
2653 for (i, &pi) in p.iter().enumerate() {
2654 for (j, &qj) in q.iter().enumerate() {
2655 result[i + j] += pi * qj;
2656 }
2657 }
2658
2659 Array::from_vec(result, Shape::new(vec![result_len]))
2660 }
2661
2662 pub fn polyder(&self) -> Array {
2675 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2676 assert_eq!(self.ndim(), 1, "Polynomial must be 1D");
2677
2678 let coeffs = self.to_vec();
2679 if coeffs.len() <= 1 {
2680 return Array::from_vec(vec![0.0], Shape::new(vec![1]));
2681 }
2682
2683 let n = coeffs.len() - 1;
2684 let mut result = Vec::with_capacity(n);
2685
2686 for (i, &c) in coeffs.iter().take(n).enumerate() {
2687 let degree = (n - i) as f32;
2688 result.push(c * degree);
2689 }
2690
2691 Array::from_vec(result, Shape::new(vec![n]))
2692 }
2693
2694 pub fn polysub(&self, other: &Array) -> Array {
2707 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2708 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2709 assert_eq!(self.ndim(), 1, "Polynomials must be 1D");
2710 assert_eq!(other.ndim(), 1, "Polynomials must be 1D");
2711
2712 let p_data = self.to_vec();
2713 let q_data = other.to_vec();
2714
2715 let max_len = p_data.len().max(q_data.len());
2716 let mut result = vec![0.0; max_len];
2717
2718 let p_offset = max_len - p_data.len();
2719 let q_offset = max_len - q_data.len();
2720
2721 for (i, &val) in p_data.iter().enumerate() {
2722 result[p_offset + i] += val;
2723 }
2724
2725 for (i, &val) in q_data.iter().enumerate() {
2726 result[q_offset + i] -= val;
2727 }
2728
2729 Array::from_vec(result, Shape::new(vec![max_len]))
2730 }
2731
2732 pub fn piecewise(&self, conditions: &[Array], functions: &[Array]) -> Array {
2759 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2760 assert_eq!(conditions.len(), functions.len(), "Number of conditions must match number of functions");
2761 assert!(!conditions.is_empty(), "At least one condition required");
2762
2763 let n = self.size();
2764 for cond in conditions {
2765 assert_eq!(cond.size(), n, "Condition size must match array size");
2766 }
2767 for func in functions {
2768 assert_eq!(func.size(), n, "Function output size must match array size");
2769 }
2770
2771 let mut result = vec![0.0; n];
2772 let mut assigned = vec![false; n];
2773
2774 for (cond, func) in conditions.iter().zip(functions.iter()) {
2775 let cond_data = cond.to_vec();
2776 let func_data = func.to_vec();
2777
2778 for i in 0..n {
2779 if !assigned[i] && cond_data[i] != 0.0 {
2780 result[i] = func_data[i];
2781 assigned[i] = true;
2782 }
2783 }
2784 }
2785
2786 Array::from_vec(result, self.shape().clone())
2787 }
2788
2789 pub fn place(&self, mask: &Array, values: &[f32]) -> Array {
2804 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2805 assert_eq!(mask.dtype(), DType::Float32, "Only Float32 supported");
2806 assert_eq!(self.size(), mask.size(), "Array and mask must have same size");
2807
2808 let mut data = self.to_vec();
2809 let mask_data = mask.to_vec();
2810
2811 let mut value_idx = 0;
2812 for (i, &m) in mask_data.iter().enumerate() {
2813 if m != 0.0 && value_idx < values.len() {
2814 data[i] = values[value_idx];
2815 value_idx += 1;
2816 }
2817 }
2818
2819 Array::from_vec(data, self.shape().clone())
2820 }
2821
2822 pub fn copyto(&self, src: &Array, mask: &Array) -> Array {
2837 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2838 assert_eq!(src.dtype(), DType::Float32, "Only Float32 supported");
2839 assert_eq!(mask.dtype(), DType::Float32, "Only Float32 supported");
2840 assert_eq!(self.size(), src.size(), "Arrays must have same size");
2841 assert_eq!(self.size(), mask.size(), "Array and mask must have same size");
2842
2843 let mut data = self.to_vec();
2844 let src_data = src.to_vec();
2845 let mask_data = mask.to_vec();
2846
2847 for i in 0..data.len() {
2848 if mask_data[i] != 0.0 {
2849 data[i] = src_data[i];
2850 }
2851 }
2852
2853 Array::from_vec(data, self.shape().clone())
2854 }
2855
2856 pub fn argmax_with_value(&self) -> (usize, f32) {
2868 let data = self.to_vec();
2869 let mut max_idx = 0;
2870 let mut max_val = f32::NEG_INFINITY;
2871
2872 for (i, &x) in data.iter().enumerate() {
2873 if x > max_val {
2874 max_val = x;
2875 max_idx = i;
2876 }
2877 }
2878
2879 (max_idx, max_val)
2880 }
2881
2882 pub fn argmin_with_value(&self) -> (usize, f32) {
2894 let data = self.to_vec();
2895 let mut min_idx = 0;
2896 let mut min_val = f32::INFINITY;
2897
2898 for (i, &x) in data.iter().enumerate() {
2899 if x < min_val {
2900 min_val = x;
2901 min_idx = i;
2902 }
2903 }
2904
2905 (min_idx, min_val)
2906 }
2907
2908 pub fn permute(&self, axes: &[usize]) -> Array {
2922 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2923 let shape = self.shape().as_slice();
2924 assert_eq!(axes.len(), shape.len(), "axes must match number of dimensions");
2925
2926 let new_shape: Vec<usize> = axes.iter().map(|&a| shape[a]).collect();
2928
2929 if shape.len() == 2 && axes == [1, 0] {
2931 return self.transpose();
2932 }
2933
2934 let data = self.to_vec();
2936 let mut result = vec![0.0; data.len()];
2937
2938 let mut old_strides = vec![1usize; shape.len()];
2940 for i in (0..shape.len() - 1).rev() {
2941 old_strides[i] = old_strides[i + 1] * shape[i + 1];
2942 }
2943
2944 let mut new_strides = vec![1usize; new_shape.len()];
2946 for i in (0..new_shape.len() - 1).rev() {
2947 new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
2948 }
2949
2950 let permuted_old_strides: Vec<usize> = axes.iter().map(|&a| old_strides[a]).collect();
2952
2953 for new_idx in 0..data.len() {
2955 let mut old_idx = 0;
2956 let mut remainder = new_idx;
2957 for (d, &new_stride) in new_strides.iter().enumerate() {
2958 let coord = remainder / new_stride;
2959 remainder %= new_stride;
2960 old_idx += coord * permuted_old_strides[d];
2961 }
2962 result[new_idx] = data[old_idx];
2963 }
2964
2965 Array::from_vec(result, Shape::new(new_shape))
2966 }
2967
2968 pub fn gather(&self, indices: &Array, axis: usize) -> Array {
2983 let shape = self.shape().as_slice();
2984 assert!(axis < shape.len(), "Axis out of bounds");
2985
2986 let indices_data: Vec<usize> = indices.to_vec().iter().map(|&x| x as usize).collect();
2987 let data = self.to_vec();
2988
2989 if axis == 0 && shape.len() == 1 {
2990 let result: Vec<f32> = indices_data.iter().map(|&i| data[i]).collect();
2992 return Array::from_vec(result, Shape::new(vec![indices_data.len()]));
2993 }
2994
2995 let mut result = Vec::new();
2997 let mut new_shape = shape.to_vec();
2998 new_shape[axis] = indices_data.len();
2999
3000 let mut strides: Vec<usize> = Vec::with_capacity(shape.len());
3002 let mut stride = 1;
3003 for &dim in shape.iter().rev() {
3004 strides.push(stride);
3005 stride *= dim;
3006 }
3007 strides.reverse();
3008
3009 let total_size: usize = new_shape.iter().product();
3010 result.reserve(total_size);
3011
3012 for out_idx in 0..total_size {
3014 let mut coords = Vec::with_capacity(shape.len());
3016 let mut remainder = out_idx;
3017 for &dim in &new_shape {
3018 coords.push(remainder % dim);
3019 remainder /= dim;
3020 }
3021 coords.reverse();
3022
3023 let idx_in_indices = coords[axis];
3025 coords[axis] = indices_data[idx_in_indices];
3026
3027 let mut in_idx = 0;
3029 for (d, &coord) in coords.iter().enumerate() {
3030 in_idx += coord * strides[d];
3031 }
3032
3033 result.push(data[in_idx]);
3034 }
3035
3036 Array::from_vec(result, Shape::new(new_shape))
3037 }
3038
3039 pub fn gather_nd(&self, indices: &[(usize, usize)]) -> Array {
3051 let data = self.to_vec();
3052 let shape = self.shape().as_slice();
3053 assert_eq!(shape.len(), 2, "gather_nd only supports 2D arrays for now");
3054
3055 let cols = shape[1];
3056 let result: Vec<f32> = indices
3057 .iter()
3058 .map(|&(r, c)| data[r * cols + c])
3059 .collect();
3060
3061 Array::from_vec(result, Shape::new(vec![indices.len()]))
3062 }
3063
3064 pub fn segment_sum(&self, segment_ids: &Array, num_segments: usize) -> Array {
3076 assert_eq!(self.size(), segment_ids.size(), "Data and segment_ids must have same size");
3077
3078 let data = self.to_vec();
3079 let ids: Vec<usize> = segment_ids.to_vec().iter().map(|&x| x as usize).collect();
3080
3081 let mut result = vec![0.0; num_segments];
3082 for (val, &seg_id) in data.iter().zip(ids.iter()) {
3083 if seg_id < num_segments {
3084 result[seg_id] += val;
3085 }
3086 }
3087
3088 Array::from_vec(result, Shape::new(vec![num_segments]))
3089 }
3090
3091 pub fn segment_mean(&self, segment_ids: &Array, num_segments: usize) -> Array {
3093 assert_eq!(self.size(), segment_ids.size(), "Data and segment_ids must have same size");
3094
3095 let data = self.to_vec();
3096 let ids: Vec<usize> = segment_ids.to_vec().iter().map(|&x| x as usize).collect();
3097
3098 let mut sums = vec![0.0; num_segments];
3099 let mut counts = vec![0usize; num_segments];
3100
3101 for (val, &seg_id) in data.iter().zip(ids.iter()) {
3102 if seg_id < num_segments {
3103 sums[seg_id] += val;
3104 counts[seg_id] += 1;
3105 }
3106 }
3107
3108 let result: Vec<f32> = sums
3109 .iter()
3110 .zip(counts.iter())
3111 .map(|(&sum, &count)| if count > 0 { sum / count as f32 } else { 0.0 })
3112 .collect();
3113
3114 Array::from_vec(result, Shape::new(vec![num_segments]))
3115 }
3116
3117 pub fn segment_max(&self, segment_ids: &Array, num_segments: usize) -> Array {
3119 assert_eq!(self.size(), segment_ids.size(), "Data and segment_ids must have same size");
3120
3121 let data = self.to_vec();
3122 let ids: Vec<usize> = segment_ids.to_vec().iter().map(|&x| x as usize).collect();
3123
3124 let mut result = vec![f32::NEG_INFINITY; num_segments];
3125
3126 for (val, &seg_id) in data.iter().zip(ids.iter()) {
3127 if seg_id < num_segments && *val > result[seg_id] {
3128 result[seg_id] = *val;
3129 }
3130 }
3131
3132 Array::from_vec(result, Shape::new(vec![num_segments]))
3133 }
3134
3135 pub fn segment_min(&self, segment_ids: &Array, num_segments: usize) -> Array {
3137 assert_eq!(self.size(), segment_ids.size(), "Data and segment_ids must have same size");
3138
3139 let data = self.to_vec();
3140 let ids: Vec<usize> = segment_ids.to_vec().iter().map(|&x| x as usize).collect();
3141
3142 let mut result = vec![f32::INFINITY; num_segments];
3143
3144 for (val, &seg_id) in data.iter().zip(ids.iter()) {
3145 if seg_id < num_segments && *val < result[seg_id] {
3146 result[seg_id] = *val;
3147 }
3148 }
3149
3150 Array::from_vec(result, Shape::new(vec![num_segments]))
3151 }
3152
3153 pub fn flip_axes(&self, axes: &[usize]) -> Array {
3155 let mut result = self.clone();
3156 for &axis in axes {
3157 result = result.flip(axis);
3158 }
3159 result
3160 }
3161
3162 pub fn moveaxis_multiple(&self, sources: &[usize], destinations: &[usize]) -> Array {
3164 assert_eq!(sources.len(), destinations.len(), "sources and destinations must have same length");
3165
3166 let mut result = self.clone();
3167 for (&src, &dst) in sources.iter().zip(destinations.iter()) {
3168 result = result.moveaxis(src, dst);
3169 }
3170 result
3171 }
3172
3173 pub fn expand_dims_multiple(&self, axes: &[usize]) -> Array {
3175 let mut sorted_axes = axes.to_vec();
3176 sorted_axes.sort();
3177
3178 let mut result = self.clone();
3179 for (i, &axis) in sorted_axes.iter().enumerate() {
3180 result = result.expand_dims(axis + i);
3181 }
3182 result
3183 }
3184
3185 pub fn squeeze_all(&self) -> Array {
3187 let shape = self.shape().as_slice();
3188 let new_shape: Vec<usize> = shape.iter().cloned().filter(|&d| d != 1).collect();
3189
3190 if new_shape.is_empty() {
3191 return Array::from_vec(self.to_vec(), Shape::new(vec![1]));
3193 }
3194
3195 self.reshape(Shape::new(new_shape))
3196 }
3197
3198 pub fn unflatten(&self, dim: usize, sizes: &[usize]) -> Array {
3200 let shape = self.shape().as_slice();
3201 assert!(dim < shape.len(), "dim out of bounds");
3202 assert_eq!(
3203 sizes.iter().product::<usize>(),
3204 shape[dim],
3205 "sizes must multiply to the dimension size"
3206 );
3207
3208 let mut new_shape = Vec::with_capacity(shape.len() - 1 + sizes.len());
3209 new_shape.extend(&shape[..dim]);
3210 new_shape.extend(sizes);
3211 new_shape.extend(&shape[dim + 1..]);
3212
3213 self.reshape(Shape::new(new_shape))
3214 }
3215
3216 pub fn repeat_axis(&self, repeats: usize, axis: usize) -> Array {
3218 let shape = self.shape().as_slice();
3219 assert!(axis < shape.len(), "axis out of bounds");
3220
3221 if axis == 0 {
3222 let data = self.to_vec();
3224 let chunk_size = self.size() / shape[0];
3225 let mut result = Vec::with_capacity(self.size() * repeats);
3226
3227 for chunk in data.chunks(chunk_size) {
3228 for _ in 0..repeats {
3229 result.extend(chunk);
3230 }
3231 }
3232
3233 let mut new_shape = shape.to_vec();
3234 new_shape[axis] *= repeats;
3235
3236 Array::from_vec(result, Shape::new(new_shape))
3237 } else {
3238 let mut new_shape = shape.to_vec();
3241 new_shape[axis] *= repeats;
3242
3243 let data = self.to_vec();
3244 let mut result = Vec::with_capacity(new_shape.iter().product());
3245
3246 let inner_size: usize = shape[axis + 1..].iter().product();
3248 let outer_size: usize = shape[..axis].iter().product();
3249 let axis_size = shape[axis];
3250
3251 for outer in 0..outer_size {
3252 for ax in 0..axis_size {
3253 for _ in 0..repeats {
3254 let start = outer * axis_size * inner_size + ax * inner_size;
3255 result.extend(&data[start..start + inner_size]);
3256 }
3257 }
3258 }
3259
3260 Array::from_vec(result, Shape::new(new_shape))
3261 }
3262 }
3263
3264 pub fn tile_nd(&self, reps: &[usize]) -> Array {
3266 assert_eq!(reps.len(), self.ndim(), "reps must have same length as ndim");
3267
3268 let mut result = self.clone();
3269 for (axis, &rep) in reps.iter().enumerate() {
3270 if rep > 1 {
3271 result = result.repeat_axis(rep, axis);
3272 }
3273 }
3274 result
3275 }
3276}
3277
3278#[cfg(test)]
3279mod tests {
3280 use super::*;
3281
3282 #[test]
3283 fn test_concatenate_1d() {
3284 let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
3285 let b = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
3286 let c = Array::from_vec(vec![5.0, 6.0], Shape::new(vec![2]));
3287
3288 let result = Array::concatenate(&[a, b, c], 0);
3289 assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
3290 }
3291
3292 #[test]
3293 fn test_stack() {
3294 let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
3295 let b = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
3296
3297 let result = Array::stack(&[a, b], 0);
3298 assert_eq!(result.shape().as_slice(), &[2, 2]);
3299 assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
3300 }
3301
3302 #[test]
3303 fn test_split() {
3304 let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![6]));
3306 let parts = Array::split(&a, 3, 0);
3307
3308 assert_eq!(parts.len(), 3);
3309 assert_eq!(parts[0].to_vec(), vec![1.0, 2.0]);
3310 assert_eq!(parts[1].to_vec(), vec![3.0, 4.0]);
3311 assert_eq!(parts[2].to_vec(), vec![5.0, 6.0]);
3312 }
3313
3314 #[test]
3315 fn test_split_2d() {
3316 let a = Array::from_vec(
3318 vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
3319 Shape::new(vec![4, 2]),
3320 );
3321 let parts = Array::split(&a, 2, 0);
3322
3323 assert_eq!(parts.len(), 2);
3324 assert_eq!(parts[0].shape().as_slice(), &[2, 2]);
3325 assert_eq!(parts[0].to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
3326 assert_eq!(parts[1].to_vec(), vec![5.0, 6.0, 7.0, 8.0]);
3327 }
3328
3329 #[test]
3330 fn test_where_cond() {
3331 let cond =
3332 Array::from_vec(vec![1.0, 0.0, 1.0, 0.0], Shape::new(vec![4]));
3333 let x =
3334 Array::from_vec(vec![10.0, 20.0, 30.0, 40.0], Shape::new(vec![4]));
3335 let y = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
3336
3337 let result = Array::where_cond(&cond, &x, &y);
3338 assert_eq!(result.to_vec(), vec![10.0, 2.0, 30.0, 4.0]);
3339 }
3340
3341 #[test]
3342 fn test_where_cond_broadcast_scalar_condition() {
3343 let condition = Array::from_vec(vec![1.0], Shape::new(vec![1]));
3344 let x = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
3345 let y = Array::from_vec(vec![100.0, 200.0, 300.0], Shape::new(vec![3]));
3346 let result = Array::where_cond(&condition, &x, &y);
3347 assert_eq!(result.to_vec(), vec![10.0, 20.0, 30.0]);
3348 }
3349
3350 #[test]
3351 fn test_where_cond_broadcast_scalar_x() {
3352 let condition = Array::from_vec(vec![1.0, 0.0, 1.0], Shape::new(vec![3]));
3353 let x = Array::from_vec(vec![42.0], Shape::new(vec![1]));
3354 let y = Array::from_vec(vec![100.0, 200.0, 300.0], Shape::new(vec![3]));
3355 let result = Array::where_cond(&condition, &x, &y);
3356 assert_eq!(result.to_vec(), vec![42.0, 200.0, 42.0]);
3357 }
3358
3359 #[test]
3360 fn test_where_cond_broadcast_scalar_y() {
3361 let condition = Array::from_vec(vec![1.0, 0.0, 1.0], Shape::new(vec![3]));
3362 let x = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
3363 let y = Array::from_vec(vec![99.0], Shape::new(vec![1]));
3364 let result = Array::where_cond(&condition, &x, &y);
3365 assert_eq!(result.to_vec(), vec![10.0, 99.0, 30.0]);
3366 }
3367
3368 #[test]
3369 fn test_where_cond_2d() {
3370 let condition = Array::from_vec(
3371 vec![1.0, 0.0, 0.0, 1.0],
3372 Shape::new(vec![2, 2])
3373 );
3374 let x = Array::from_vec(
3375 vec![1.0, 2.0, 3.0, 4.0],
3376 Shape::new(vec![2, 2])
3377 );
3378 let y = Array::from_vec(
3379 vec![10.0, 20.0, 30.0, 40.0],
3380 Shape::new(vec![2, 2])
3381 );
3382 let result = Array::where_cond(&condition, &x, &y);
3383 assert_eq!(result.to_vec(), vec![1.0, 20.0, 30.0, 4.0]);
3384 }
3385
/// A length-2 condition combined with 2x2 operands: the expected output
/// applies the same condition pattern to each row.
#[test]
fn test_where_cond_broadcast_2d() {
    let row_mask = Array::from_vec(vec![1.0, 0.0], Shape::new(vec![2]));

    let square = |values: Vec<f32>| Array::from_vec(values, Shape::new(vec![2, 2]));
    let if_true = square(vec![1.0, 2.0, 3.0, 4.0]);
    let if_false = square(vec![10.0, 20.0, 30.0, 40.0]);

    let selected = Array::where_cond(&row_mask, &if_true, &if_false);

    let expected = vec![1.0, 20.0, 3.0, 40.0];
    assert_eq!(selected.to_vec(), expected);
}
3400
/// Negative and fractional condition values are expected to count as true;
/// only the 0.0 entry falls through to `y`.
#[test]
fn test_where_cond_negative_values() {
    let mask = Array::from_vec(vec![-5.0, 0.0, 3.14], Shape::new(vec![3]));
    let if_true = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
    let if_false = Array::from_vec(vec![100.0, 200.0, 300.0], Shape::new(vec![3]));

    let selected = Array::where_cond(&mask, &if_true, &if_false).to_vec();

    let expected = vec![10.0, 200.0, 30.0];
    assert_eq!(selected, expected);
}
3410
/// Each index is expected to pick the value from the matching constant-filled
/// choice array.
#[test]
fn test_select_basic() {
    // Builds a length-4 array holding the same value in every slot.
    let filled = |value: f32| Array::from_vec(vec![value; 4], Shape::new(vec![4]));

    let selector = Array::from_vec(vec![0.0, 1.0, 2.0, 1.0], Shape::new(vec![4]));
    let tens = filled(10.0);
    let twenties = filled(20.0);
    let thirties = filled(30.0);

    let chosen = Array::select(&selector, &[tens, twenties, thirties]);

    let expected = vec![10.0, 20.0, 30.0, 20.0];
    assert_eq!(chosen.to_vec(), expected);
}
3420
/// With non-constant choices, the expected output takes the element at the
/// same position from whichever choice the index names.
#[test]
fn test_select_varying_values() {
    let selector = Array::from_vec(vec![0.0, 1.0, 0.0], Shape::new(vec![3]));
    let small = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    let large = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));

    let chosen = Array::select(&selector, &[small, large]).to_vec();

    let expected = vec![1.0, 20.0, 3.0];
    assert_eq!(chosen, expected);
}
3431
/// `select` on 2x2 operands: indices and both choices share one shape.
#[test]
fn test_select_2d() {
    let square = |values: Vec<f32>| Array::from_vec(values, Shape::new(vec![2, 2]));

    let selector = square(vec![0.0, 1.0, 1.0, 0.0]);
    let small = square(vec![1.0, 2.0, 3.0, 4.0]);
    let large = square(vec![10.0, 20.0, 30.0, 40.0]);

    let chosen = Array::select(&selector, &[small, large]);

    let expected = vec![1.0, 20.0, 30.0, 4.0];
    assert_eq!(chosen.to_vec(), expected);
}
3440
/// Values below 0 and above 10 are expected to be pinned to those bounds.
#[test]
fn test_clip() {
    let samples = vec![-5.0, 0.0, 5.0, 10.0, 15.0];
    let input = Array::from_vec(samples, Shape::new(vec![5]));

    let bounded = input.clip(0.0, 10.0).to_vec();

    let expected = vec![0.0, 0.0, 5.0, 10.0, 10.0];
    assert_eq!(bounded, expected);
}
3450
/// Flipping a 1-D array along axis 0 is expected to reverse its elements.
#[test]
fn test_flip_1d() {
    let input = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));

    let reversed = input.flip(0).to_vec();

    let expected = vec![4.0, 3.0, 2.0, 1.0];
    assert_eq!(reversed, expected);
}
3457
/// Flipping a 3x2 array along axis 0 is expected to reverse the row order
/// while leaving each row intact.
#[test]
fn test_flip_2d() {
    let rows = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let input = Array::from_vec(rows, Shape::new(vec![3, 2]));

    let reversed = input.flip(0).to_vec();

    let expected = vec![5.0, 6.0, 3.0, 4.0, 1.0, 2.0];
    assert_eq!(reversed, expected);
}
3467
/// NaN, +infinity and -infinity are expected to be replaced by the three
/// supplied substitutes, while finite entries pass through unchanged.
#[test]
fn test_nan_to_num() {
    let input = Array::from_vec(
        vec![1.0, f32::NAN, f32::INFINITY, -f32::INFINITY, 5.0],
        Shape::new(vec![5]),
    );

    let cleaned = input.nan_to_num(0.0, 1e10, -1e10).to_vec();

    // Checked element by element, in index order.
    let expected = [1.0, 0.0, 1e10, -1e10, 5.0];
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(cleaned[idx], *want);
    }
}
3481
/// The expected mask holds 1.0 at NaN positions and 0.0 elsewhere.
#[test]
fn test_isnan() {
    let samples = vec![1.0, f32::NAN, 3.0, f32::NAN, 5.0];
    let input = Array::from_vec(samples, Shape::new(vec![5]));

    let mask = input.isnan().to_vec();

    let expected = vec![0.0, 1.0, 0.0, 1.0, 0.0];
    assert_eq!(mask, expected);
}
3491
/// The expected mask holds 1.0 for both positive and negative infinity.
#[test]
fn test_isinf() {
    let samples = vec![1.0, f32::INFINITY, -f32::INFINITY, 3.0];
    let input = Array::from_vec(samples, Shape::new(vec![4]));

    let mask = input.isinf().to_vec();

    let expected = vec![0.0, 1.0, 1.0, 0.0];
    assert_eq!(mask, expected);
}
3501
/// The expected mask holds 1.0 only for entries that are neither NaN nor
/// infinite.
#[test]
fn test_isfinite() {
    let samples = vec![1.0, f32::NAN, f32::INFINITY, 3.0];
    let input = Array::from_vec(samples, Shape::new(vec![4]));

    let mask = input.isfinite().to_vec();

    let expected = vec![1.0, 0.0, 0.0, 1.0];
    assert_eq!(mask, expected);
}
3511
/// `[3, 4]` clipped to 2.0 is expected to come back as roughly `[1.2, 1.6]`
/// (the input scaled by 2/5); `[1, 1]` clipped to 5.0 is expected unchanged.
#[test]
fn test_clip_by_norm() {
    // Approximate comparison for the rescaled values.
    let close = |got: f32, want: f32| (got - want).abs() < 1e-5;

    let long_vector = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
    let shrunk = long_vector.clip_by_norm(2.0).to_vec();
    assert!(close(shrunk[0], 1.2));
    assert!(close(shrunk[1], 1.6));

    let short_vector = Array::from_vec(vec![1.0, 1.0], Shape::new(vec![2]));
    let untouched = short_vector.clip_by_norm(5.0).to_vec();
    assert_eq!(untouched, vec![1.0, 1.0]);
}
3528
/// `ravel` on a 2x3 array is expected to yield shape `[6]` with the data in
/// its original order.
#[test]
fn test_ravel() {
    let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let matrix = Array::from_vec(values, Shape::new(vec![2, 3]));

    let raveled = matrix.ravel();

    assert_eq!(raveled.shape().as_slice(), &[6]);
    let expected = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    assert_eq!(raveled.to_vec(), expected);
}
3539
/// `flatten` on a 2x2 array is expected to yield shape `[4]` with the data in
/// its original order.
#[test]
fn test_flatten() {
    let matrix = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));

    let flattened = matrix.flatten();

    assert_eq!(flattened.shape().as_slice(), &[4]);
    let expected = vec![1.0, 2.0, 3.0, 4.0];
    assert_eq!(flattened.to_vec(), expected);
}
3547
/// An array with an empty shape is expected to become shape `[1]`; an array
/// that is already 1-D is expected to keep its shape.
#[test]
fn test_atleast_1d() {
    let zero_dim = Array::from_vec(vec![5.0], Shape::new(vec![]));
    let promoted = zero_dim.atleast_1d();
    assert_eq!(promoted.shape().as_slice(), &[1]);

    let one_dim = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
    let unchanged = one_dim.atleast_1d();
    assert_eq!(unchanged.shape().as_slice(), &[2]);
}
3560
/// A length-3 vector is expected to become shape `[1, 3]`; an array that is
/// already 2-D is expected to keep its shape.
#[test]
fn test_atleast_2d() {
    let one_dim = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    let promoted = one_dim.atleast_2d();
    assert_eq!(promoted.shape().as_slice(), &[1, 3]);

    let two_dim = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![1, 2]));
    let unchanged = two_dim.atleast_2d();
    assert_eq!(unchanged.shape().as_slice(), &[1, 2]);
}
3573
/// Both a length-2 vector and a 1x2 matrix are expected to be promoted to
/// shape `[1, 2, 1]`.
#[test]
fn test_atleast_3d() {
    let one_dim = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
    let from_vector = one_dim.atleast_3d();
    assert_eq!(from_vector.shape().as_slice(), &[1, 2, 1]);

    let two_dim = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![1, 2]));
    let from_matrix = two_dim.atleast_3d();
    assert_eq!(from_matrix.shape().as_slice(), &[1, 2, 1]);
}
3586
/// Broadcasting a length-3 vector to `[2, 3]` is expected to repeat the
/// vector once per row.
#[test]
fn test_broadcast_to() {
    let row = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    let target = Shape::new(vec![2, 3]);

    let tiled = row.broadcast_to(target);

    assert_eq!(tiled.shape().as_slice(), &[2, 3]);
    let expected = vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0];
    assert_eq!(tiled.to_vec(), expected);
}
3595
/// `take` with positions 0, 2 and 4 is expected to return those elements in
/// the order requested.
#[test]
fn test_take() {
    let values = vec![10.0, 20.0, 30.0, 40.0, 50.0];
    let source = Array::from_vec(values, Shape::new(vec![5]));
    let positions = vec![0, 2, 4];

    let gathered = source.take(&positions).to_vec();

    let expected = vec![10.0, 30.0, 50.0];
    assert_eq!(gathered, expected);
}
3606
/// `nonzero` is expected to report the positions of the non-zero entries.
#[test]
fn test_nonzero() {
    let values = vec![0.0, 1.0, 0.0, 3.0, 0.0, 5.0];
    let source = Array::from_vec(values, Shape::new(vec![6]));

    let positions = source.nonzero();

    assert_eq!(positions, vec![1, 3, 5]);
}
3616
/// `argwhere` is expected to report the positions holding a non-zero value.
#[test]
fn test_argwhere() {
    let flags = Array::from_vec(vec![0.0, 1.0, 0.0, 1.0], Shape::new(vec![4]));

    let positions = flags.argwhere();

    assert_eq!(positions, vec![1, 3]);
}
3623
/// `compress` is expected to keep only the elements whose condition entry is
/// non-zero.
#[test]
fn test_compress() {
    let source = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0], Shape::new(vec![4]));
    let keep = Array::from_vec(vec![1.0, 0.0, 1.0, 0.0], Shape::new(vec![4]));

    let kept = source.compress(&keep).to_vec();

    let expected = vec![10.0, 30.0];
    assert_eq!(kept, expected);
}
3632
/// Each integer index is expected to pick the same-position element from the
/// choice array it names.
#[test]
fn test_choose() {
    let low = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
    let high = Array::from_vec(vec![100.0, 200.0, 300.0], Shape::new(vec![3]));
    let options = vec![low, high];
    let picks = vec![0, 1, 0];

    let chosen = Array::choose(&picks, &options).to_vec();

    let expected = vec![10.0, 200.0, 30.0];
    assert_eq!(chosen, expected);
}
3643
/// `extract` is expected to return the elements whose condition entry is
/// non-zero.
#[test]
fn test_extract() {
    let five = Shape::new(vec![5]);
    let source = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], five);

    let five = Shape::new(vec![5]);
    let keep = Array::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0], five);

    let kept = source.extract(&keep).to_vec();

    let expected = vec![1.0, 3.0, 5.0];
    assert_eq!(kept, expected);
}
3657
/// A positive shift is expected to move elements toward the end with
/// wrap-around; a negative shift is expected to move them toward the front.
#[test]
fn test_roll() {
    let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let source = Array::from_vec(values, Shape::new(vec![5]));

    let forward = source.roll(2).to_vec();
    assert_eq!(forward, vec![4.0, 5.0, 1.0, 2.0, 3.0]);

    let backward = source.roll(-1).to_vec();
    assert_eq!(backward, vec![2.0, 3.0, 4.0, 5.0, 1.0]);
}
3671
/// Checks the element order expected from `rot90` on a 2x2 array for
/// `k = 1, 2, 3`.
#[test]
fn test_rot90() {
    let grid = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));

    // (k, expected flattened result) pairs, evaluated in ascending k.
    let cases = vec![
        (1, vec![2.0, 4.0, 1.0, 3.0]),
        (2, vec![4.0, 3.0, 2.0, 1.0]),
        (3, vec![3.0, 1.0, 4.0, 2.0]),
    ];

    for (k, expected) in cases {
        assert_eq!(grid.rot90(k).to_vec(), expected);
    }
}
3688
/// Swapping the two axes of a 2x3 array is expected to give shape `[3, 2]`
/// with the data transposed.
#[test]
fn test_swapaxes() {
    let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let matrix = Array::from_vec(values, Shape::new(vec![2, 3]));

    let transposed = matrix.swapaxes(0, 1);

    assert_eq!(transposed.shape().as_slice(), &[3, 2]);
    let expected = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
    assert_eq!(transposed.to_vec(), expected);
}
3699
/// Moving axis 2 to position 0 of a `[1, 2, 3]` array is expected to give
/// shape `[3, 1, 2]`. Only the shape is checked here.
#[test]
fn test_moveaxis() {
    let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let cube = Array::from_vec(values, Shape::new(vec![1, 2, 3]));

    let rearranged = cube.moveaxis(2, 0);

    assert_eq!(rearranged.shape().as_slice(), &[3, 1, 2]);
}
3709
/// Query points between sample points are expected to be linearly
/// interpolated; query points outside the sampled range are expected to take
/// the nearest end value.
#[test]
fn test_interp() {
    let sample_x = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    let sample_y = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));

    let inside = Array::from_vec(vec![1.5, 2.5], Shape::new(vec![2]));
    let interpolated = Array::interp(&inside, &sample_x, &sample_y).to_vec();
    assert_eq!(interpolated, vec![15.0, 25.0]);

    let outside = Array::from_vec(vec![0.5, 3.5], Shape::new(vec![2]));
    let clamped = Array::interp(&outside, &sample_x, &sample_y).to_vec();
    assert_eq!(clamped, vec![10.0, 30.0]);
}
3723
/// A weight of 0.5 is expected to give the midpoint, 0.0 the start array and
/// 1.0 the end array.
#[test]
fn test_lerp() {
    let start = Array::from_vec(vec![0.0, 10.0, 20.0], Shape::new(vec![3]));
    let end = Array::from_vec(vec![100.0, 110.0, 120.0], Shape::new(vec![3]));

    let midpoint = start.lerp(&end, 0.5).to_vec();
    assert_eq!(midpoint, vec![50.0, 60.0, 70.0]);

    let at_start = start.lerp(&end, 0.0).to_vec();
    assert_eq!(at_start, start.to_vec());

    let at_end = start.lerp(&end, 1.0).to_vec();
    assert_eq!(at_end, end.to_vec());
}
3739
/// With per-element weights 0.0, 0.5 and 1.0 the expected result is the start
/// value, the midpoint and the end value respectively.
#[test]
fn test_lerp_array() {
    let start = Array::from_vec(vec![0.0, 10.0, 20.0], Shape::new(vec![3]));
    let end = Array::from_vec(vec![100.0, 110.0, 120.0], Shape::new(vec![3]));
    let per_element = Array::from_vec(vec![0.0, 0.5, 1.0], Shape::new(vec![3]));

    let blended = start.lerp_array(&end, &per_element).to_vec();

    let expected = vec![0.0, 60.0, 120.0];
    assert_eq!(blended, expected);
}
3748
/// `put` is expected to overwrite the listed positions and keep the shape.
/// For the 2x3 input, positions 0 and 5 are expected to address the first and
/// last elements of the flattened data.
#[test]
fn test_put() {
    let vector = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
    let updated = vector.put(&[0, 2, 4], &[10.0, 30.0, 50.0]);
    let expected = vec![10.0, 2.0, 30.0, 4.0, 50.0];
    assert_eq!(updated.to_vec(), expected);
    assert_eq!(updated.shape().as_slice(), &[5]);

    let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let matrix = Array::from_vec(values, Shape::new(vec![2, 3]));
    let updated_matrix = matrix.put(&[0, 5], &[100.0, 600.0]);
    let expected_matrix = vec![100.0, 2.0, 3.0, 4.0, 5.0, 600.0];
    assert_eq!(updated_matrix.to_vec(), expected_matrix);
}
3764
/// `scatter` is expected to overwrite the listed positions and keep the
/// shape. For the 2x3 input, positions 0 and 5 are expected to address the
/// first and last elements of the flattened data.
#[test]
fn test_scatter() {
    let vector = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
    let updated = vector.scatter(&[0, 2, 4], &[10.0, 30.0, 50.0]);
    let expected = vec![10.0, 2.0, 30.0, 4.0, 50.0];
    assert_eq!(updated.to_vec(), expected);
    assert_eq!(updated.shape().as_slice(), &[5]);

    let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let matrix = Array::from_vec(values, Shape::new(vec![2, 3]));
    let updated_matrix = matrix.scatter(&[0, 5], &[100.0, 600.0]);
    let expected_matrix = vec![100.0, 2.0, 3.0, 4.0, 5.0, 600.0];
    assert_eq!(updated_matrix.to_vec(), expected_matrix);
}
3780
/// `scatter_add` is expected to add each update onto the existing value; when
/// an index repeats, every update for it is expected to be accumulated.
#[test]
fn test_scatter_add() {
    let base = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
    let summed = base.scatter_add(&[0, 2, 4], &[10.0, 30.0, 50.0]).to_vec();
    assert_eq!(summed, vec![11.0, 2.0, 33.0, 4.0, 55.0]);

    // Index 0 appears twice: 1 + 5 + 3 = 9.
    let small = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    let accumulated = small.scatter_add(&[0, 0, 1], &[5.0, 3.0, 10.0]).to_vec();
    assert_eq!(accumulated, vec![9.0, 12.0, 3.0]);
}
3792
/// `scatter_min` is expected to keep the smaller of the existing value and
/// the update at each listed position.
#[test]
fn test_scatter_min() {
    let base = Array::from_vec(vec![5.0, 10.0, 15.0, 20.0, 25.0], Shape::new(vec![5]));
    let lowered = base.scatter_min(&[1, 2, 3], &[8.0, 20.0, 15.0]).to_vec();
    assert_eq!(lowered, vec![5.0, 8.0, 15.0, 15.0, 25.0]);

    // Every update is larger than the existing value, so nothing changes.
    let small = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    let unchanged = small.scatter_min(&[0, 1, 2], &[5.0, 10.0, 15.0]).to_vec();
    assert_eq!(unchanged, vec![1.0, 2.0, 3.0]);
}
3804
/// `scatter_max` is expected to keep the larger of the existing value and the
/// update at each listed position.
#[test]
fn test_scatter_max() {
    let base = Array::from_vec(vec![5.0, 10.0, 15.0, 20.0, 25.0], Shape::new(vec![5]));
    let raised = base.scatter_max(&[1, 2, 3], &[12.0, 10.0, 25.0]).to_vec();
    assert_eq!(raised, vec![5.0, 12.0, 15.0, 25.0, 25.0]);

    // Every update is smaller than the existing value, so nothing changes.
    let large = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
    let unchanged = large.scatter_max(&[0, 1, 2], &[5.0, 10.0, 15.0]).to_vec();
    assert_eq!(unchanged, vec![10.0, 20.0, 30.0]);
}
3816
/// `scatter_mul` is expected to multiply the existing value by each update;
/// when an index repeats, every factor for it is expected to be applied.
#[test]
fn test_scatter_mul() {
    let base = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
    let scaled = base.scatter_mul(&[1, 2, 3], &[2.0, 3.0, 0.5]).to_vec();
    assert_eq!(scaled, vec![1.0, 4.0, 9.0, 2.0, 5.0]);

    // Index 0 appears twice: 1 * 2 * 3 = 6.
    let small = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    let compounded = small.scatter_mul(&[0, 0, 1], &[2.0, 3.0, 5.0]).to_vec();
    assert_eq!(compounded, vec![6.0, 10.0, 3.0]);
}
3828
/// With a repeated index, `scatter` is expected to keep the last value
/// written, while `scatter_add` is expected to accumulate both updates.
#[test]
fn test_scatter_duplicate_indices() {
    let base = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));

    let overwritten = base.scatter(&[0, 0], &[10.0, 20.0]).to_vec();
    assert_eq!(overwritten, vec![20.0, 2.0, 3.0]);

    // 1 + 10 + 20 = 31.
    let accumulated = base.scatter_add(&[0, 0], &[10.0, 20.0]).to_vec();
    assert_eq!(accumulated, vec![31.0, 2.0, 3.0]);
}
3840
/// On a 1-D array, the index array `[0, 2, 4]` is expected to gather those
/// elements along axis 0.
#[test]
fn test_take_along_axis_1d() {
    let source = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0], Shape::new(vec![5]));
    let positions = Array::from_vec(vec![0.0, 2.0, 4.0], Shape::new(vec![3]));

    let gathered = source.take_along_axis(&positions, 0).to_vec();

    let expected = vec![10.0, 30.0, 50.0];
    assert_eq!(gathered, expected);
}
3848
/// With one index per row along axis 1 of a 2x3 array, the expected result
/// is column 0 of the first row and column 2 of the second row.
#[test]
fn test_take_along_axis_2d_axis1() {
    let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0];
    let matrix = Array::from_vec(values, Shape::new(vec![2, 3]));
    let column_picks = Array::from_vec(vec![0.0, 2.0], Shape::new(vec![2]));

    let gathered = matrix.take_along_axis(&column_picks, 1).to_vec();

    let expected = vec![10.0, 60.0];
    assert_eq!(gathered, expected);
}
3860
/// With one index per column along axis 0 of a 2x3 array, the expected
/// result takes rows 1, 0 and 1 for columns 0, 1 and 2.
#[test]
fn test_take_along_axis_2d_axis0() {
    let values = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0];
    let matrix = Array::from_vec(values, Shape::new(vec![2, 3]));
    let row_picks = Array::from_vec(vec![1.0, 0.0, 1.0], Shape::new(vec![3]));

    let gathered = matrix.take_along_axis(&row_picks, 0).to_vec();

    let expected = vec![40.0, 20.0, 60.0];
    assert_eq!(gathered, expected);
}
3872
/// Gathers along the last axis of a `[2, 3, 4]` array holding 0..24, using a
/// `[2, 3, 2]` index array (two picks per length-4 lane).
#[test]
fn test_take_along_axis_3d() {
    let ramp: Vec<f32> = (0..24).map(|x| x as f32).collect();
    let cube = Array::from_vec(ramp, Shape::new(vec![2, 3, 4]));

    // One row of two picks for each of the 2 * 3 lanes.
    let picks = vec![
        0.0, 3.0, //
        1.0, 2.0, //
        0.0, 1.0, //
        3.0, 0.0, //
        2.0, 1.0, //
        1.0, 3.0, //
    ];
    let positions = Array::from_vec(picks, Shape::new(vec![2, 3, 2]));

    let gathered = cube.take_along_axis(&positions, 2);
    assert_eq!(gathered.shape().as_slice(), &[2, 3, 2]);

    let expected = vec![
        0.0, 3.0, 5.0, 6.0, 8.0, 9.0, 15.0, 12.0, 18.0, 17.0, 21.0, 23.0,
    ];
    assert_eq!(gathered.to_vec(), expected);
}
3912
/// Gathers along the middle axis of a `[2, 3, 2]` array holding 0..12, using
/// a `[2, 2, 2]` index array.
#[test]
fn test_take_along_axis_3d_middle_axis() {
    let ramp: Vec<f32> = (0..12).map(|x| x as f32).collect();
    let cube = Array::from_vec(ramp, Shape::new(vec![2, 3, 2]));

    let picks = vec![
        0.0, 2.0, //
        1.0, 0.0, //
        2.0, 1.0, //
        0.0, 2.0, //
    ];
    let positions = Array::from_vec(picks, Shape::new(vec![2, 2, 2]));

    let gathered = cube.take_along_axis(&positions, 1);
    assert_eq!(gathered.shape().as_slice(), &[2, 2, 2]);

    let expected = vec![0.0, 5.0, 2.0, 1.0, 10.0, 9.0, 6.0, 11.0];
    assert_eq!(gathered.to_vec(), expected);
}
3951
/// Two length-3 arrays and one length-1 array are expected to all come back
/// with shape `[3]`, the length-1 array having its value repeated.
#[test]
fn test_broadcast_arrays_compatible() {
    let ones = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    let tens = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
    let single = Array::from_vec(vec![100.0], Shape::new(vec![1]));

    let outputs = Array::broadcast_arrays(&[ones, tens, single]);

    assert_eq!(outputs.len(), 3);

    // Shapes first, then contents, in the same order as the inputs.
    assert_eq!(outputs[0].shape().as_slice(), &[3]);
    assert_eq!(outputs[1].shape().as_slice(), &[3]);
    assert_eq!(outputs[2].shape().as_slice(), &[3]);

    assert_eq!(outputs[0].to_vec(), vec![1.0, 2.0, 3.0]);
    assert_eq!(outputs[1].to_vec(), vec![10.0, 20.0, 30.0]);
    assert_eq!(outputs[2].to_vec(), vec![100.0, 100.0, 100.0]);
}
3971
/// A `[1, 3]` row and a `[2, 1]` column are expected to both broadcast to
/// `[2, 3]`: the row repeated per row, the column repeated per column.
#[test]
fn test_broadcast_arrays_2d() {
    let row = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![1, 3]));
    let column = Array::from_vec(vec![10.0, 20.0], Shape::new(vec![2, 1]));

    let outputs = Array::broadcast_arrays(&[row, column]);

    assert_eq!(outputs.len(), 2);
    assert_eq!(outputs[0].shape().as_slice(), &[2, 3]);
    assert_eq!(outputs[1].shape().as_slice(), &[2, 3]);

    let expected_row = vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0];
    assert_eq!(outputs[0].to_vec(), expected_row);

    let expected_column = vec![10.0, 10.0, 10.0, 20.0, 20.0, 20.0];
    assert_eq!(outputs[1].to_vec(), expected_column);
}
3994
/// An empty input slice is expected to produce an empty result.
#[test]
fn test_broadcast_arrays_empty() {
    let outputs = Array::broadcast_arrays(&[]);

    let count = outputs.len();
    assert_eq!(count, 0);
}
4001
/// A single input array is expected to come back with its shape and data
/// unchanged.
#[test]
fn test_broadcast_arrays_single() {
    let only = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));

    let outputs = Array::broadcast_arrays(&[only.clone()]);

    assert_eq!(outputs.len(), 1);
    let first = &outputs[0];
    assert_eq!(first.shape().as_slice(), &[3]);
    assert_eq!(first.to_vec(), vec![1.0, 2.0, 3.0]);
}
4012
/// Shapes `[3]` and `[2]` cannot be broadcast together; the call is expected
/// to panic with a message containing the text below.
#[test]
#[should_panic(expected = "Cannot broadcast arrays with shapes")]
fn test_broadcast_arrays_incompatible() {
    let three = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    let two = Array::from_vec(vec![10.0, 20.0], Shape::new(vec![2]));

    let inputs = [three, two];
    Array::broadcast_arrays(&inputs);
}
4022
/// Broadcasting a length-3 array to shape `[2]` is expected to panic with a
/// message containing the text below.
#[test]
#[should_panic(expected = "Cannot broadcast array of shape")]
fn test_broadcast_to_error_message() {
    let source = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    let target = Shape::new(vec![2]);

    source.broadcast_to(target);
}
4030
/// Convolving a length-5 ramp with `[1, 0, -1]` is expected to give three
/// outputs, each equal to 2.0.
#[test]
fn test_convolve() {
    let ramp = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
    let difference = Array::from_vec(vec![1.0, 0.0, -1.0], Shape::new(vec![3]));

    let output = ramp.convolve(&difference).to_vec();

    let expected = vec![2.0, 2.0, 2.0];
    assert_eq!(output, expected);
}
4041
/// Convolving a length-5 ramp with a three-tap averaging kernel is expected
/// to give the moving averages 2, 3 and 4.
///
/// The kernel weights are `1.0 / 3.0`, which has no exact `f32`
/// representation, so the outputs are compared within a small tolerance
/// instead of bit-for-bit. The exact rounding of each sum depends on the
/// order in which the products are accumulated.
#[test]
fn test_convolve_averaging() {
    let signal = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
    let kernel = Array::from_vec(vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], Shape::new(vec![3]));

    let conv = signal.convolve(&kernel).to_vec();

    let expected: [f32; 3] = [2.0, 3.0, 4.0];
    // Length is checked separately because `zip` would silently stop early.
    assert_eq!(conv.len(), expected.len());
    for (i, (&got, &want)) in conv.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - want).abs() < 1e-5,
            "element {}: expected {}, got {}",
            i,
            want,
            got
        );
    }
}
4050
/// Convolving with the single-tap kernel `[1.0]` is expected to return the
/// signal unchanged.
#[test]
fn test_convolve_identity() {
    let input = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    let unit = Array::from_vec(vec![1.0], Shape::new(vec![1]));

    let output = input.convolve(&unit).to_vec();

    let expected = vec![1.0, 2.0, 3.0, 4.0];
    assert_eq!(output, expected);
}
4059
/// Correlating a length-5 ramp with `[1, 2, 1]` is expected to give the
/// three sliding dot products 8, 12 and 16.
#[test]
fn test_correlate() {
    let ramp = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
    let pattern = Array::from_vec(vec![1.0, 2.0, 1.0], Shape::new(vec![3]));

    let scores = ramp.correlate(&pattern).to_vec();

    let expected = vec![8.0, 12.0, 16.0];
    assert_eq!(scores, expected);
}
4068
/// The signal embeds the pattern `[1, 2, 1]` starting at position 2; the
/// largest correlation score is expected at output index 2.
#[test]
fn test_correlate_pattern_detection() {
    let samples = vec![0.0, 0.0, 1.0, 2.0, 1.0, 0.0, 0.0];
    let signal = Array::from_vec(samples, Shape::new(vec![7]));
    let pattern = Array::from_vec(vec![1.0, 2.0, 1.0], Shape::new(vec![3]));

    let scores = signal.correlate(&pattern).to_vec();

    // Position of the largest score.
    let (peak_at, _) = scores
        .iter()
        .enumerate()
        .max_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap())
        .unwrap();

    assert_eq!(peak_at, 2);
}
4085
/// For the asymmetric kernel `[1, 2]`, convolution and correlation are
/// expected to produce different outputs.
#[test]
fn test_convolve_correlate_difference() {
    let input = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    let asymmetric = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));

    let convolved = input.convolve(&asymmetric).to_vec();
    let correlated = input.correlate(&asymmetric).to_vec();

    assert_ne!(convolved, correlated);
}
4098}