1pub mod extended;
7
8use crate::array::owned::Array;
9use crate::dimension::{Dimension, Ix1, IxDyn};
10use crate::dtype::Element;
11use crate::error::{FerrayError, FerrayResult};
12
13pub fn reshape<T: Element, D: Dimension>(
27 a: &Array<T, D>,
28 new_shape: &[usize],
29) -> FerrayResult<Array<T, IxDyn>> {
30 let old_size = a.size();
31 let new_size: usize = new_shape.iter().product();
32 if old_size != new_size {
33 return Err(FerrayError::shape_mismatch(format!(
34 "cannot reshape array of size {} into shape {:?} (size {})",
35 old_size, new_shape, new_size,
36 )));
37 }
38 let data: Vec<T> = a.iter().cloned().collect();
39 Array::from_vec(IxDyn::new(new_shape), data)
40}
41
42pub fn ravel<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
46 let data: Vec<T> = a.iter().cloned().collect();
47 let n = data.len();
48 Array::from_vec(Ix1::new([n]), data)
49}
50
51pub fn flatten<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
55 ravel(a)
56}
57
58pub fn squeeze<T: Element, D: Dimension>(
69 a: &Array<T, D>,
70 axis: Option<usize>,
71) -> FerrayResult<Array<T, IxDyn>> {
72 let shape = a.shape();
73 match axis {
74 Some(ax) => {
75 if ax >= shape.len() {
76 return Err(FerrayError::axis_out_of_bounds(ax, shape.len()));
77 }
78 if shape[ax] != 1 {
79 return Err(FerrayError::invalid_value(format!(
80 "cannot select axis {} with size {} for squeeze (must be 1)",
81 ax, shape[ax],
82 )));
83 }
84 let new_shape: Vec<usize> = shape
85 .iter()
86 .enumerate()
87 .filter(|&(i, _)| i != ax)
88 .map(|(_, &s)| s)
89 .collect();
90 let data: Vec<T> = a.iter().cloned().collect();
91 Array::from_vec(IxDyn::new(&new_shape), data)
92 }
93 None => {
94 let new_shape: Vec<usize> = shape.iter().copied().filter(|&s| s != 1).collect();
95 let new_shape = if new_shape.is_empty() && !shape.is_empty() {
98 vec![1]
99 } else if new_shape.is_empty() {
100 vec![]
101 } else {
102 new_shape
103 };
104 let data: Vec<T> = a.iter().cloned().collect();
105 Array::from_vec(IxDyn::new(&new_shape), data)
106 }
107 }
108}
109
110pub fn expand_dims<T: Element, D: Dimension>(
117 a: &Array<T, D>,
118 axis: usize,
119) -> FerrayResult<Array<T, IxDyn>> {
120 let ndim = a.ndim();
121 if axis > ndim {
122 return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
123 }
124 let mut new_shape: Vec<usize> = a.shape().to_vec();
125 new_shape.insert(axis, 1);
126 let data: Vec<T> = a.iter().cloned().collect();
127 Array::from_vec(IxDyn::new(&new_shape), data)
128}
129
130pub fn broadcast_to<T: Element, D: Dimension>(
139 a: &Array<T, D>,
140 new_shape: &[usize],
141) -> FerrayResult<Array<T, IxDyn>> {
142 let src_shape = a.shape();
143 let src_ndim = src_shape.len();
144 let dst_ndim = new_shape.len();
145
146 if dst_ndim < src_ndim {
147 return Err(FerrayError::BroadcastFailure {
148 shape_a: src_shape.to_vec(),
149 shape_b: new_shape.to_vec(),
150 });
151 }
152
153 let pad = dst_ndim - src_ndim;
155 for i in 0..src_ndim {
156 let s = src_shape[i];
157 let d = new_shape[pad + i];
158 if s != d && s != 1 {
159 return Err(FerrayError::BroadcastFailure {
160 shape_a: src_shape.to_vec(),
161 shape_b: new_shape.to_vec(),
162 });
163 }
164 }
165
166 let total: usize = new_shape.iter().product();
168 let mut data = Vec::with_capacity(total);
169 let src_data: Vec<T> = a.iter().cloned().collect();
170
171 let mut src_strides = vec![1usize; src_ndim];
173 for i in (0..src_ndim.saturating_sub(1)).rev() {
174 src_strides[i] = src_strides[i + 1] * src_shape[i + 1];
175 }
176
177 let mut out_strides = vec![1usize; dst_ndim];
179 for i in (0..dst_ndim.saturating_sub(1)).rev() {
180 out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
181 }
182
183 for flat in 0..total {
184 let mut rem = flat;
185 let mut s_idx = 0usize;
186 #[allow(clippy::needless_range_loop)]
187 for i in 0..dst_ndim {
188 let idx = rem / out_strides[i];
189 rem %= out_strides[i];
190 if i >= pad {
191 let src_i = i - pad;
192 let src_idx = if src_shape[src_i] == 1 { 0 } else { idx };
193 s_idx += src_idx * src_strides[src_i];
194 }
195 }
196 data.push(src_data[s_idx].clone());
197 }
198
199 Array::from_vec(IxDyn::new(new_shape), data)
200}
201
202pub fn concatenate<T: Element>(
215 arrays: &[Array<T, IxDyn>],
216 axis: usize,
217) -> FerrayResult<Array<T, IxDyn>> {
218 if arrays.is_empty() {
219 return Err(FerrayError::invalid_value(
220 "concatenate: need at least one array",
221 ));
222 }
223 let ndim = arrays[0].ndim();
224 if axis >= ndim {
225 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
226 }
227 let base_shape = arrays[0].shape();
228
229 let mut total_along_axis = 0usize;
231 for arr in arrays {
232 if arr.ndim() != ndim {
233 return Err(FerrayError::shape_mismatch(format!(
234 "all arrays must have same ndim; got {} and {}",
235 ndim,
236 arr.ndim(),
237 )));
238 }
239 for (i, (&s, &base)) in arr.shape().iter().zip(base_shape.iter()).enumerate() {
240 if i != axis && s != base {
241 return Err(FerrayError::shape_mismatch(format!(
242 "shape mismatch on axis {}: {} vs {}",
243 i, s, base,
244 )));
245 }
246 }
247 total_along_axis += arr.shape()[axis];
248 }
249
250 let mut new_shape = base_shape.to_vec();
252 new_shape[axis] = total_along_axis;
253 let total: usize = new_shape.iter().product();
254 let mut data = Vec::with_capacity(total);
255
256 let src_vecs: Vec<Vec<T>> = arrays.iter().map(|a| a.iter().cloned().collect()).collect();
258
259 let mut out_strides = vec![1usize; ndim];
261 for i in (0..ndim - 1).rev() {
262 out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
263 }
264
265 for flat_idx in 0..total {
267 let mut rem = flat_idx;
269 let mut nd_idx = vec![0usize; ndim];
270 for i in 0..ndim {
271 nd_idx[i] = rem / out_strides[i];
272 rem %= out_strides[i];
273 }
274
275 let concat_idx = nd_idx[axis];
277 let mut offset = 0;
278 let mut src_arr_idx = 0;
279 for (k, arr) in arrays.iter().enumerate() {
280 let len_along = arr.shape()[axis];
281 if concat_idx < offset + len_along {
282 src_arr_idx = k;
283 break;
284 }
285 offset += len_along;
286 }
287 let local_concat_idx = concat_idx - offset;
288
289 let src_shape = arrays[src_arr_idx].shape();
291 let mut src_flat = 0usize;
292 let mut src_mul = 1usize;
293 for i in (0..ndim).rev() {
294 let idx = if i == axis {
295 local_concat_idx
296 } else {
297 nd_idx[i]
298 };
299 src_flat += idx * src_mul;
300 src_mul *= src_shape[i];
301 }
302
303 let elem = src_vecs[src_arr_idx].get(src_flat).ok_or_else(|| {
304 FerrayError::invalid_value(format!(
305 "concatenate: internal index {} out of range for source array of length {}",
306 src_flat,
307 src_vecs[src_arr_idx].len(),
308 ))
309 })?;
310 data.push(elem.clone());
311 }
312
313 Array::from_vec(IxDyn::new(&new_shape), data)
314}
315
316pub fn stack<T: Element>(arrays: &[Array<T, IxDyn>], axis: usize) -> FerrayResult<Array<T, IxDyn>> {
328 if arrays.is_empty() {
329 return Err(FerrayError::invalid_value("stack: need at least one array"));
330 }
331 let base_shape = arrays[0].shape();
332 let ndim = base_shape.len();
333
334 if axis > ndim {
335 return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
336 }
337
338 for arr in &arrays[1..] {
339 if arr.shape() != base_shape {
340 return Err(FerrayError::shape_mismatch(format!(
341 "all input arrays must have the same shape; got {:?} and {:?}",
342 base_shape,
343 arr.shape(),
344 )));
345 }
346 }
347
348 let mut expanded = Vec::with_capacity(arrays.len());
350 for arr in arrays {
351 expanded.push(expand_dims(arr, axis)?);
352 }
353 concatenate(&expanded, axis)
354}
355
356pub fn vstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
361 if arrays.is_empty() {
362 return Err(FerrayError::invalid_value(
363 "vstack: need at least one array",
364 ));
365 }
366 let ndim = arrays[0].ndim();
368 if ndim == 1 {
369 let mut reshaped = Vec::with_capacity(arrays.len());
370 for arr in arrays {
371 let n = arr.shape()[0];
372 reshaped.push(reshape(arr, &[1, n])?);
373 }
374 concatenate(&reshaped, 0)
375 } else {
376 concatenate(arrays, 0)
377 }
378}
379
380pub fn hstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
385 if arrays.is_empty() {
386 return Err(FerrayError::invalid_value(
387 "hstack: need at least one array",
388 ));
389 }
390 let ndim = arrays[0].ndim();
391 if ndim == 1 {
392 concatenate(arrays, 0)
393 } else {
394 concatenate(arrays, 1)
395 }
396}
397
398pub fn dstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
406 if arrays.is_empty() {
407 return Err(FerrayError::invalid_value(
408 "dstack: need at least one array",
409 ));
410 }
411 let mut expanded = Vec::with_capacity(arrays.len());
412 for arr in arrays {
413 let shape = arr.shape();
414 match shape.len() {
415 1 => {
416 let n = shape[0];
417 expanded.push(reshape(arr, &[1, n, 1])?);
418 }
419 2 => {
420 let (m, n) = (shape[0], shape[1]);
421 expanded.push(reshape(arr, &[m, n, 1])?);
422 }
423 _ => {
424 let data: Vec<T> = arr.iter().cloned().collect();
426 expanded.push(Array::from_vec(IxDyn::new(shape), data)?);
427 }
428 }
429 }
430 concatenate(&expanded, 2)
431}
432
433pub fn block<T: Element>(blocks: &[Vec<Array<T, IxDyn>>]) -> FerrayResult<Array<T, IxDyn>> {
443 if blocks.is_empty() {
444 return Err(FerrayError::invalid_value("block: empty input"));
445 }
446 let mut rows = Vec::with_capacity(blocks.len());
447 for row in blocks {
448 if row.is_empty() {
449 return Err(FerrayError::invalid_value("block: empty row"));
450 }
451 let row_arr = if row.len() == 1 {
453 let data: Vec<T> = row[0].iter().cloned().collect();
454 Array::from_vec(IxDyn::new(row[0].shape()), data)?
455 } else {
456 hstack(row)?
457 };
458 rows.push(row_arr);
459 }
460 if rows.len() == 1 {
461 Ok(rows.pop().unwrap_or_else(|| unreachable!()))
463 } else {
464 vstack(&rows)
465 }
466}
467
468pub fn split<T: Element>(
477 a: &Array<T, IxDyn>,
478 n_sections: usize,
479 axis: usize,
480) -> FerrayResult<Vec<Array<T, IxDyn>>> {
481 let shape = a.shape();
482 if axis >= shape.len() {
483 return Err(FerrayError::axis_out_of_bounds(axis, shape.len()));
484 }
485 let axis_len = shape[axis];
486 if n_sections == 0 {
487 return Err(FerrayError::invalid_value("split: n_sections must be > 0"));
488 }
489 if axis_len % n_sections != 0 {
490 return Err(FerrayError::invalid_value(format!(
491 "array of size {} along axis {} cannot be evenly split into {} sections",
492 axis_len, axis, n_sections,
493 )));
494 }
495 let chunk_size = axis_len / n_sections;
496 let indices: Vec<usize> = (1..n_sections).map(|i| i * chunk_size).collect();
497 array_split(a, &indices, axis)
498}
499
500pub fn array_split<T: Element>(
509 a: &Array<T, IxDyn>,
510 indices: &[usize],
511 axis: usize,
512) -> FerrayResult<Vec<Array<T, IxDyn>>> {
513 let shape = a.shape();
514 let ndim = shape.len();
515 if axis >= ndim {
516 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
517 }
518 let axis_len = shape[axis];
519 let src_data: Vec<T> = a.iter().cloned().collect();
520
521 let mut splits = Vec::with_capacity(indices.len() + 2);
523 splits.push(0);
524 for &idx in indices {
525 splits.push(idx.min(axis_len));
526 }
527 splits.push(axis_len);
528
529 let mut src_strides = vec![1usize; ndim];
531 for i in (0..ndim - 1).rev() {
532 src_strides[i] = src_strides[i + 1] * shape[i + 1];
533 }
534
535 let mut result = Vec::with_capacity(splits.len() - 1);
536 for w in splits.windows(2) {
537 let start = w[0];
538 let end = w[1];
539 let chunk_len = end - start;
540
541 let mut sub_shape = shape.to_vec();
542 sub_shape[axis] = chunk_len;
543 let sub_total: usize = sub_shape.iter().product();
544
545 let mut sub_strides = vec![1usize; ndim];
547 for i in (0..ndim - 1).rev() {
548 sub_strides[i] = sub_strides[i + 1] * sub_shape[i + 1];
549 }
550
551 let mut sub_data = Vec::with_capacity(sub_total);
552 for flat in 0..sub_total {
553 let mut rem = flat;
555 let mut src_flat = 0usize;
556 for i in 0..ndim {
557 let idx = rem / sub_strides[i];
558 rem %= sub_strides[i];
559 let src_idx = if i == axis { idx + start } else { idx };
560 src_flat += src_idx * src_strides[i];
561 }
562 sub_data.push(src_data[src_flat].clone());
563 }
564 result.push(Array::from_vec(IxDyn::new(&sub_shape), sub_data)?);
565 }
566
567 Ok(result)
568}
569
570pub fn vsplit<T: Element>(
574 a: &Array<T, IxDyn>,
575 n_sections: usize,
576) -> FerrayResult<Vec<Array<T, IxDyn>>> {
577 split(a, n_sections, 0)
578}
579
580pub fn hsplit<T: Element>(
584 a: &Array<T, IxDyn>,
585 n_sections: usize,
586) -> FerrayResult<Vec<Array<T, IxDyn>>> {
587 split(a, n_sections, 1)
588}
589
590pub fn dsplit<T: Element>(
594 a: &Array<T, IxDyn>,
595 n_sections: usize,
596) -> FerrayResult<Vec<Array<T, IxDyn>>> {
597 split(a, n_sections, 2)
598}
599
600pub fn transpose<T: Element, D: Dimension>(
615 a: &Array<T, D>,
616 axes: Option<&[usize]>,
617) -> FerrayResult<Array<T, IxDyn>> {
618 let shape = a.shape();
619 let ndim = shape.len();
620 let perm: Vec<usize> = match axes {
621 Some(ax) => {
622 if ax.len() != ndim {
623 return Err(FerrayError::invalid_value(format!(
624 "axes must have length {} but got {}",
625 ndim,
626 ax.len(),
627 )));
628 }
629 let mut seen = vec![false; ndim];
631 for &a in ax {
632 if a >= ndim {
633 return Err(FerrayError::axis_out_of_bounds(a, ndim));
634 }
635 if seen[a] {
636 return Err(FerrayError::invalid_value(format!(
637 "duplicate axis {} in transpose",
638 a,
639 )));
640 }
641 seen[a] = true;
642 }
643 ax.to_vec()
644 }
645 None => (0..ndim).rev().collect(),
646 };
647
648 let new_shape: Vec<usize> = perm.iter().map(|&ax| shape[ax]).collect();
649 let total: usize = new_shape.iter().product();
650 let src_data: Vec<T> = a.iter().cloned().collect();
651
652 let mut src_strides = vec![1usize; ndim];
654 for i in (0..ndim.saturating_sub(1)).rev() {
655 src_strides[i] = src_strides[i + 1] * shape[i + 1];
656 }
657
658 let mut out_strides = vec![1usize; ndim];
660 for i in (0..ndim.saturating_sub(1)).rev() {
661 out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
662 }
663
664 let mut data = Vec::with_capacity(total);
665 for flat_out in 0..total {
666 let mut rem = flat_out;
668 let mut src_flat = 0usize;
669 #[allow(clippy::needless_range_loop)]
670 for i in 0..ndim {
671 let idx = rem / out_strides[i];
672 rem %= out_strides[i];
673 src_flat += idx * src_strides[perm[i]];
675 }
676 data.push(src_data[src_flat].clone());
677 }
678
679 Array::from_vec(IxDyn::new(&new_shape), data)
680}
681
682pub fn swapaxes<T: Element, D: Dimension>(
689 a: &Array<T, D>,
690 axis1: usize,
691 axis2: usize,
692) -> FerrayResult<Array<T, IxDyn>> {
693 let ndim = a.ndim();
694 if axis1 >= ndim {
695 return Err(FerrayError::axis_out_of_bounds(axis1, ndim));
696 }
697 if axis2 >= ndim {
698 return Err(FerrayError::axis_out_of_bounds(axis2, ndim));
699 }
700 let mut perm: Vec<usize> = (0..ndim).collect();
701 perm.swap(axis1, axis2);
702 transpose(a, Some(&perm))
703}
704
705pub fn moveaxis<T: Element, D: Dimension>(
712 a: &Array<T, D>,
713 source: usize,
714 destination: usize,
715) -> FerrayResult<Array<T, IxDyn>> {
716 let ndim = a.ndim();
717 if source >= ndim {
718 return Err(FerrayError::axis_out_of_bounds(source, ndim));
719 }
720 if destination >= ndim {
721 return Err(FerrayError::axis_out_of_bounds(destination, ndim));
722 }
723 let mut order: Vec<usize> = (0..ndim).filter(|&x| x != source).collect();
725 order.insert(destination, source);
726 transpose(a, Some(&order))
727}
728
729pub fn rollaxis<T: Element, D: Dimension>(
736 a: &Array<T, D>,
737 axis: usize,
738 start: usize,
739) -> FerrayResult<Array<T, IxDyn>> {
740 let ndim = a.ndim();
741 if axis >= ndim {
742 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
743 }
744 if start > ndim {
745 return Err(FerrayError::axis_out_of_bounds(start, ndim + 1));
746 }
747 let dst = if start > axis { start - 1 } else { start };
748 if axis == dst {
749 let data: Vec<T> = a.iter().cloned().collect();
751 return Array::from_vec(IxDyn::new(a.shape()), data);
752 }
753 moveaxis(a, axis, dst)
754}
755
756pub fn flip<T: Element, D: Dimension>(
763 a: &Array<T, D>,
764 axis: usize,
765) -> FerrayResult<Array<T, IxDyn>> {
766 let shape = a.shape();
767 let ndim = shape.len();
768 if axis >= ndim {
769 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
770 }
771 let src_data: Vec<T> = a.iter().cloned().collect();
772 let total = src_data.len();
773
774 let mut strides = vec![1usize; ndim];
776 for i in (0..ndim.saturating_sub(1)).rev() {
777 strides[i] = strides[i + 1] * shape[i + 1];
778 }
779
780 let mut data = Vec::with_capacity(total);
781 for flat in 0..total {
782 let mut rem = flat;
783 let mut src_flat = 0usize;
784 for i in 0..ndim {
785 let idx = rem / strides[i];
786 rem %= strides[i];
787 let src_idx = if i == axis { shape[i] - 1 - idx } else { idx };
788 src_flat += src_idx * strides[i];
789 }
790 data.push(src_data[src_flat].clone());
791 }
792 Array::from_vec(IxDyn::new(shape), data)
793}
794
795pub fn fliplr<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
802 if a.ndim() < 2 {
803 return Err(FerrayError::invalid_value(
804 "fliplr: array must be at least 2-D",
805 ));
806 }
807 flip(a, 1)
808}
809
810pub fn flipud<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
817 if a.ndim() < 1 {
818 return Err(FerrayError::invalid_value(
819 "flipud: array must be at least 1-D",
820 ));
821 }
822 flip(a, 0)
823}
824
825pub fn rot90<T: Element, D: Dimension>(a: &Array<T, D>, k: i32) -> FerrayResult<Array<T, IxDyn>> {
834 if a.ndim() < 2 {
835 return Err(FerrayError::invalid_value(
836 "rot90: array must be at least 2-D",
837 ));
838 }
839 let k = k.rem_euclid(4);
841 let shape = a.shape();
842 let data: Vec<T> = a.iter().cloned().collect();
843
844 let as_dyn = Array::from_vec(IxDyn::new(shape), data)?;
846
847 match k {
848 0 => Ok(as_dyn),
849 1 => {
850 let flipped = flip(&as_dyn, 1)?;
852 swapaxes(&flipped, 0, 1)
853 }
854 2 => {
855 let f1 = flip(&as_dyn, 0)?;
857 flip(&f1, 1)
858 }
859 3 => {
860 let transposed = swapaxes(&as_dyn, 0, 1)?;
862 flip(&transposed, 1)
863 }
864 _ => unreachable!(),
865 }
866}
867
868pub fn roll<T: Element, D: Dimension>(
878 a: &Array<T, D>,
879 shift: isize,
880 axis: Option<usize>,
881) -> FerrayResult<Array<T, IxDyn>> {
882 match axis {
883 None => {
884 let data: Vec<T> = a.iter().cloned().collect();
886 let n = data.len();
887 if n == 0 {
888 return Array::from_vec(IxDyn::new(a.shape()), data);
889 }
890 let shift = ((shift % n as isize) + n as isize) as usize % n;
891 let mut rolled = Vec::with_capacity(n);
892 for i in 0..n {
893 rolled.push(data[(n + i - shift) % n].clone());
894 }
895 Array::from_vec(IxDyn::new(a.shape()), rolled)
896 }
897 Some(ax) => {
898 let shape = a.shape();
899 let ndim = shape.len();
900 if ax >= ndim {
901 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
902 }
903 let axis_len = shape[ax];
904 if axis_len == 0 {
905 let data: Vec<T> = a.iter().cloned().collect();
906 return Array::from_vec(IxDyn::new(shape), data);
907 }
908 let shift = ((shift % axis_len as isize) + axis_len as isize) as usize % axis_len;
909 let src_data: Vec<T> = a.iter().cloned().collect();
910 let total = src_data.len();
911
912 let mut strides = vec![1usize; ndim];
914 for i in (0..ndim.saturating_sub(1)).rev() {
915 strides[i] = strides[i + 1] * shape[i + 1];
916 }
917
918 let mut data = Vec::with_capacity(total);
919 for flat in 0..total {
920 let mut rem = flat;
921 let mut src_flat = 0usize;
922 #[allow(clippy::needless_range_loop)]
923 for i in 0..ndim {
924 let idx = rem / strides[i];
925 rem %= strides[i];
926 let src_idx = if i == ax {
927 (axis_len + idx - shift) % axis_len
928 } else {
929 idx
930 };
931 src_flat += src_idx * strides[i];
932 }
933 data.push(src_data[src_flat].clone());
934 }
935 Array::from_vec(IxDyn::new(shape), data)
936 }
937 }
938}
939
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a dynamic-rank f64 array, panicking on invalid test input.
    fn dyn_arr(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(shape), data).unwrap()
    }

    // ---- reshape / ravel / flatten ----

    #[test]
    fn test_reshape() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = reshape(&a, &[3, 2]).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_reshape_size_mismatch() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(reshape(&a, &[2, 4]).is_err());
    }

    #[test]
    fn test_ravel() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = ravel(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
        assert_eq!(b.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_flatten() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flatten(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
    }

    // ---- squeeze / expand_dims / broadcast_to ----

    #[test]
    fn test_squeeze() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, None).unwrap();
        assert_eq!(b.shape(), &[3]);
    }

    #[test]
    fn test_squeeze_specific_axis() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, Some(0)).unwrap();
        assert_eq!(b.shape(), &[3, 1]);
    }

    #[test]
    fn test_squeeze_not_size_1() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(squeeze(&a, Some(0)).is_err());
    }

    #[test]
    fn test_squeeze_all_ones_keeps_one_axis() {
        // Squeezing an all-ones N-D shape yields [1], not a 0-D array.
        let a = dyn_arr(&[1, 1], vec![7.0]);
        let b = squeeze(&a, None).unwrap();
        assert_eq!(b.shape(), &[1]);
    }

    #[test]
    fn test_expand_dims() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = expand_dims(&a, 0).unwrap();
        assert_eq!(b.shape(), &[1, 3]);
        let c = expand_dims(&a, 1).unwrap();
        assert_eq!(c.shape(), &[3, 1]);
    }

    #[test]
    fn test_expand_dims_oob() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(expand_dims(&a, 3).is_err());
    }

    #[test]
    fn test_broadcast_to() {
        let a = dyn_arr(&[1, 3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[3, 3]).unwrap();
        assert_eq!(b.shape(), &[3, 3]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast_to_1d_to_2d() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[2, 3]).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
    }

    #[test]
    fn test_broadcast_to_incompatible() {
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(broadcast_to(&a, &[3]).is_err());
    }

    // ---- concatenate / stack family ----

    #[test]
    fn test_concatenate() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = dyn_arr(&[2, 3], vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[4, 3]);
    }

    #[test]
    fn test_concatenate_axis1() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = concatenate(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
    }

    #[test]
    fn test_concatenate_unequal_axis_lengths() {
        // Differing lengths along the concatenation axis itself are allowed.
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        let b = dyn_arr(&[3, 3], vec![1.0; 9]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[5, 3]);
    }

    #[test]
    fn test_concatenate_shape_mismatch() {
        // A differing length on a non-concatenation axis is rejected.
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        let b = dyn_arr(&[2, 4], vec![1.0; 8]);
        assert!(concatenate(&[a, b], 0).is_err());
    }

    #[test]
    fn test_concatenate_empty() {
        let v: Vec<Array<f64, IxDyn>> = vec![];
        assert!(concatenate(&v, 0).is_err());
    }

    #[test]
    fn test_stack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_axis1() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_vstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = vstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
    }

    #[test]
    fn test_hstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[6]);
    }

    #[test]
    fn test_hstack_2d() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
    }

    #[test]
    fn test_dstack() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_block() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dyn_arr(&[2, 2], vec![9.0, 10.0, 11.0, 12.0]);
        let d = dyn_arr(&[2, 2], vec![13.0, 14.0, 15.0, 16.0]);
        let result = block(&[vec![a, b], vec![c, d]]).unwrap();
        assert_eq!(result.shape(), &[4, 4]);
    }

    // ---- splitting ----

    #[test]
    fn test_split() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = split(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[2]);
    }

    #[test]
    fn test_split_uneven() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(split(&a, 3, 0).is_err());
    }

    #[test]
    fn test_array_split() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let parts = array_split(&a, &[2, 4], 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[1]);
    }

    #[test]
    fn test_array_split_clamps_indices() {
        // Cut points past the axis length are clamped, yielding an empty tail.
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let parts = array_split(&a, &[2, 10], 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[3]);
        assert_eq!(parts[2].shape(), &[0]);
    }

    #[test]
    fn test_vsplit() {
        let a = dyn_arr(&[4, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = vsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    #[test]
    fn test_hsplit() {
        let a = dyn_arr(&[2, 4], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = hsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    // ---- axis permutation ----

    #[test]
    fn test_transpose_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, None).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_explicit() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, Some(&[1, 0])).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
    }

    #[test]
    fn test_transpose_bad_axes() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(transpose(&a, Some(&[0])).is_err());
    }

    #[test]
    fn test_transpose_duplicate_axis() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(transpose(&a, Some(&[0, 0])).is_err());
    }

    #[test]
    fn test_swapaxes() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = swapaxes(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[4, 3, 2]);
    }

    #[test]
    fn test_moveaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = moveaxis(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[3, 4, 2]);
    }

    #[test]
    fn test_rollaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = rollaxis(&a, 2, 0).unwrap();
        assert_eq!(b.shape(), &[4, 2, 3]);
    }

    // ---- flipping / rotation / rolling ----

    #[test]
    fn test_flip() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);

        let c = flip(&a, 1).unwrap();
        let data2: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data2, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_fliplr() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = fliplr(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_flipud() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flipud(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_fliplr_1d_err() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(fliplr(&a).is_err());
    }

    #[test]
    fn test_rot90_once() {
        // [[1, 2], [3, 4]] rotated counter-clockwise is [[2, 4], [1, 3]].
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 1).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 4.0, 1.0, 3.0]);
    }

    #[test]
    fn test_rot90_twice() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 2).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_rot90_four_is_identity() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = rot90(&a, 4).unwrap();
        let data_a: Vec<f64> = a.iter().copied().collect();
        let data_b: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data_a, data_b);
        assert_eq!(a.shape(), b.shape());
    }

    #[test]
    fn test_roll_flat() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, 2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, -2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 4.0, 5.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll_axis() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, 1, Some(1)).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
    }

    #[test]
    fn test_roll_axis_negative() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, -1, Some(1)).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 3.0, 1.0, 5.0, 6.0, 4.0]);
    }
}