1pub mod extended;
7
8use crate::array::owned::Array;
9use crate::dimension::{Dimension, Ix1, IxDyn};
10use crate::dtype::Element;
11use crate::error::{FerrayError, FerrayResult};
12
13pub fn reshape<T: Element, D: Dimension>(
31 a: &Array<T, D>,
32 new_shape: &[usize],
33) -> FerrayResult<Array<T, IxDyn>> {
34 let old_size = a.size();
35 let new_size: usize = new_shape.iter().product();
36 if old_size != new_size {
37 return Err(FerrayError::shape_mismatch(format!(
38 "cannot reshape array of size {old_size} into shape {new_shape:?} (size {new_size})",
39 )));
40 }
41 let view = a.inner.view().into_dyn();
42 let reshaped = view
43 .to_shape(ndarray::IxDyn(new_shape))
44 .map_err(|e| FerrayError::shape_mismatch(e.to_string()))?;
45 Ok(Array::from_ndarray(
49 reshaped.as_standard_layout().into_owned(),
50 ))
51}
52
53pub fn ravel<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
58 let n = a.size();
59 let view = a.inner.view().into_dyn();
60 let reshaped = view
61 .to_shape(ndarray::IxDyn(&[n]))
62 .expect("1-D reshape always succeeds for a size-preserving target");
63 let standard = reshaped.as_standard_layout().into_owned();
64 let one_d = standard
65 .into_dimensionality::<ndarray::Ix1>()
66 .expect("reshape result has ndim == 1 by construction");
67 Ok(Array::from_ndarray(one_d))
68}
69
70pub fn flatten<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
74 ravel(a)
75}
76
77pub fn squeeze<T: Element, D: Dimension>(
88 a: &Array<T, D>,
89 axis: Option<usize>,
90) -> FerrayResult<Array<T, IxDyn>> {
91 let shape = a.shape();
92 if let Some(ax) = axis {
93 if ax >= shape.len() {
94 return Err(FerrayError::axis_out_of_bounds(ax, shape.len()));
95 }
96 if shape[ax] != 1 {
97 return Err(FerrayError::invalid_value(format!(
98 "cannot select axis {} with size {} for squeeze (must be 1)",
99 ax, shape[ax],
100 )));
101 }
102 let new_shape: Vec<usize> = shape
103 .iter()
104 .enumerate()
105 .filter(|&(i, _)| i != ax)
106 .map(|(_, &s)| s)
107 .collect();
108 let data: Vec<T> = a.iter().cloned().collect();
109 Array::from_vec(IxDyn::new(&new_shape), data)
110 } else {
111 let new_shape: Vec<usize> = shape.iter().copied().filter(|&s| s != 1).collect();
112 let new_shape = if new_shape.is_empty() && !shape.is_empty() {
115 vec![1]
116 } else if new_shape.is_empty() {
117 vec![]
118 } else {
119 new_shape
120 };
121 let data: Vec<T> = a.iter().cloned().collect();
122 Array::from_vec(IxDyn::new(&new_shape), data)
123 }
124}
125
126pub fn expand_dims<T: Element, D: Dimension>(
133 a: &Array<T, D>,
134 axis: usize,
135) -> FerrayResult<Array<T, IxDyn>> {
136 let ndim = a.ndim();
137 if axis > ndim {
138 return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
139 }
140 let mut new_shape: Vec<usize> = a.shape().to_vec();
141 new_shape.insert(axis, 1);
142 let data: Vec<T> = a.iter().cloned().collect();
143 Array::from_vec(IxDyn::new(&new_shape), data)
144}
145
146pub fn broadcast_to<T: Element, D: Dimension>(
159 a: &Array<T, D>,
160 new_shape: &[usize],
161) -> FerrayResult<Array<T, IxDyn>> {
162 let src_shape = a.shape();
163 let dyn_view = a.inner.view().into_dyn();
164 let broadcast_view = dyn_view
165 .broadcast(ndarray::IxDyn(new_shape))
166 .ok_or_else(|| FerrayError::BroadcastFailure {
167 shape_a: src_shape.to_vec(),
168 shape_b: new_shape.to_vec(),
169 })?;
170 Ok(Array::from_ndarray(
174 broadcast_view.as_standard_layout().into_owned(),
175 ))
176}
177
178pub fn concatenate<T: Element>(
191 arrays: &[Array<T, IxDyn>],
192 axis: usize,
193) -> FerrayResult<Array<T, IxDyn>> {
194 if arrays.is_empty() {
195 return Err(FerrayError::invalid_value(
196 "concatenate: need at least one array",
197 ));
198 }
199 let ndim = arrays[0].ndim();
200 if axis >= ndim {
201 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
202 }
203 let base_shape = arrays[0].shape();
204
205 let mut total_along_axis = 0usize;
207 for arr in arrays {
208 if arr.ndim() != ndim {
209 return Err(FerrayError::shape_mismatch(format!(
210 "all arrays must have same ndim; got {} and {}",
211 ndim,
212 arr.ndim(),
213 )));
214 }
215 for (i, (&s, &base)) in arr.shape().iter().zip(base_shape.iter()).enumerate() {
216 if i != axis && s != base {
217 return Err(FerrayError::shape_mismatch(format!(
218 "shape mismatch on axis {i}: {s} vs {base}",
219 )));
220 }
221 }
222 total_along_axis += arr.shape()[axis];
223 }
224
225 let mut new_shape = base_shape.to_vec();
227 new_shape[axis] = total_along_axis;
228 let total: usize = new_shape.iter().product();
229 let mut data = Vec::with_capacity(total);
230
231 let src_vecs: Vec<Vec<T>> = arrays.iter().map(|a| a.iter().cloned().collect()).collect();
233
234 let mut out_strides = vec![1usize; ndim];
236 for i in (0..ndim - 1).rev() {
237 out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
238 }
239
240 for flat_idx in 0..total {
242 let mut rem = flat_idx;
244 let mut nd_idx = vec![0usize; ndim];
245 for i in 0..ndim {
246 nd_idx[i] = rem / out_strides[i];
247 rem %= out_strides[i];
248 }
249
250 let concat_idx = nd_idx[axis];
252 let mut offset = 0;
253 let mut src_arr_idx = 0;
254 for (k, arr) in arrays.iter().enumerate() {
255 let len_along = arr.shape()[axis];
256 if concat_idx < offset + len_along {
257 src_arr_idx = k;
258 break;
259 }
260 offset += len_along;
261 }
262 let local_concat_idx = concat_idx - offset;
263
264 let src_shape = arrays[src_arr_idx].shape();
266 let mut src_flat = 0usize;
267 let mut src_mul = 1usize;
268 for i in (0..ndim).rev() {
269 let idx = if i == axis {
270 local_concat_idx
271 } else {
272 nd_idx[i]
273 };
274 src_flat += idx * src_mul;
275 src_mul *= src_shape[i];
276 }
277
278 let elem = src_vecs[src_arr_idx].get(src_flat).ok_or_else(|| {
279 FerrayError::invalid_value(format!(
280 "concatenate: internal index {} out of range for source array of length {}",
281 src_flat,
282 src_vecs[src_arr_idx].len(),
283 ))
284 })?;
285 data.push(elem.clone());
286 }
287
288 Array::from_vec(IxDyn::new(&new_shape), data)
289}
290
291pub fn stack<T: Element>(arrays: &[Array<T, IxDyn>], axis: usize) -> FerrayResult<Array<T, IxDyn>> {
303 if arrays.is_empty() {
304 return Err(FerrayError::invalid_value("stack: need at least one array"));
305 }
306 let base_shape = arrays[0].shape();
307 let ndim = base_shape.len();
308
309 if axis > ndim {
310 return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
311 }
312
313 for arr in &arrays[1..] {
314 if arr.shape() != base_shape {
315 return Err(FerrayError::shape_mismatch(format!(
316 "all input arrays must have the same shape; got {:?} and {:?}",
317 base_shape,
318 arr.shape(),
319 )));
320 }
321 }
322
323 let mut expanded = Vec::with_capacity(arrays.len());
325 for arr in arrays {
326 expanded.push(expand_dims(arr, axis)?);
327 }
328 concatenate(&expanded, axis)
329}
330
331pub fn vstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
336 if arrays.is_empty() {
337 return Err(FerrayError::invalid_value(
338 "vstack: need at least one array",
339 ));
340 }
341 let ndim = arrays[0].ndim();
343 if ndim == 1 {
344 let mut reshaped = Vec::with_capacity(arrays.len());
345 for arr in arrays {
346 let n = arr.shape()[0];
347 reshaped.push(reshape(arr, &[1, n])?);
348 }
349 concatenate(&reshaped, 0)
350 } else {
351 concatenate(arrays, 0)
352 }
353}
354
355pub fn hstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
360 if arrays.is_empty() {
361 return Err(FerrayError::invalid_value(
362 "hstack: need at least one array",
363 ));
364 }
365 let ndim = arrays[0].ndim();
366 if ndim == 1 {
367 concatenate(arrays, 0)
368 } else {
369 concatenate(arrays, 1)
370 }
371}
372
373pub fn dstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
381 if arrays.is_empty() {
382 return Err(FerrayError::invalid_value(
383 "dstack: need at least one array",
384 ));
385 }
386 let mut expanded = Vec::with_capacity(arrays.len());
387 for arr in arrays {
388 let shape = arr.shape();
389 match shape.len() {
390 1 => {
391 let n = shape[0];
392 expanded.push(reshape(arr, &[1, n, 1])?);
393 }
394 2 => {
395 let (m, n) = (shape[0], shape[1]);
396 expanded.push(reshape(arr, &[m, n, 1])?);
397 }
398 _ => {
399 let data: Vec<T> = arr.iter().cloned().collect();
401 expanded.push(Array::from_vec(IxDyn::new(shape), data)?);
402 }
403 }
404 }
405 concatenate(&expanded, 2)
406}
407
408pub fn column_stack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
419 if arrays.is_empty() {
420 return Err(FerrayError::invalid_value(
421 "column_stack: need at least one array",
422 ));
423 }
424 let first_ndim = arrays[0].ndim();
425 if first_ndim == 1 {
426 let n = arrays[0].shape()[0];
428 let mut reshaped = Vec::with_capacity(arrays.len());
429 for arr in arrays {
430 if arr.ndim() != 1 {
431 return Err(FerrayError::shape_mismatch(
432 "column_stack: all inputs must have the same ndim",
433 ));
434 }
435 if arr.shape()[0] != n {
436 return Err(FerrayError::shape_mismatch(format!(
437 "column_stack: 1-D inputs must have the same length; got {} and {}",
438 n,
439 arr.shape()[0],
440 )));
441 }
442 reshaped.push(reshape(arr, &[n, 1])?);
443 }
444 concatenate(&reshaped, 1)
445 } else {
446 hstack(arrays)
448 }
449}
450
451pub fn row_stack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
456 vstack(arrays)
457}
458
459pub fn block<T: Element>(blocks: &[Vec<Array<T, IxDyn>>]) -> FerrayResult<Array<T, IxDyn>> {
469 if blocks.is_empty() {
470 return Err(FerrayError::invalid_value("block: empty input"));
471 }
472 let mut rows = Vec::with_capacity(blocks.len());
473 for row in blocks {
474 if row.is_empty() {
475 return Err(FerrayError::invalid_value("block: empty row"));
476 }
477 let row_arr = if row.len() == 1 {
479 let data: Vec<T> = row[0].iter().cloned().collect();
480 Array::from_vec(IxDyn::new(row[0].shape()), data)?
481 } else {
482 hstack(row)?
483 };
484 rows.push(row_arr);
485 }
486 if rows.len() == 1 {
487 Ok(rows.pop().unwrap_or_else(|| unreachable!()))
489 } else {
490 vstack(&rows)
491 }
492}
493
494pub fn split<T: Element>(
503 a: &Array<T, IxDyn>,
504 n_sections: usize,
505 axis: usize,
506) -> FerrayResult<Vec<Array<T, IxDyn>>> {
507 let shape = a.shape();
508 if axis >= shape.len() {
509 return Err(FerrayError::axis_out_of_bounds(axis, shape.len()));
510 }
511 let axis_len = shape[axis];
512 if n_sections == 0 {
513 return Err(FerrayError::invalid_value("split: n_sections must be > 0"));
514 }
515 if axis_len % n_sections != 0 {
516 return Err(FerrayError::invalid_value(format!(
517 "array of size {axis_len} along axis {axis} cannot be evenly split into {n_sections} sections",
518 )));
519 }
520 let chunk_size = axis_len / n_sections;
521 let indices: Vec<usize> = (1..n_sections).map(|i| i * chunk_size).collect();
522 array_split(a, &indices, axis)
523}
524
525pub fn array_split<T: Element>(
534 a: &Array<T, IxDyn>,
535 indices: &[usize],
536 axis: usize,
537) -> FerrayResult<Vec<Array<T, IxDyn>>> {
538 let shape = a.shape();
539 let ndim = shape.len();
540 if axis >= ndim {
541 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
542 }
543 let axis_len = shape[axis];
544 let src_data: Vec<T> = a.iter().cloned().collect();
545
546 let mut splits = Vec::with_capacity(indices.len() + 2);
548 splits.push(0);
549 for &idx in indices {
550 splits.push(idx.min(axis_len));
551 }
552 splits.push(axis_len);
553
554 let mut src_strides = vec![1usize; ndim];
556 for i in (0..ndim - 1).rev() {
557 src_strides[i] = src_strides[i + 1] * shape[i + 1];
558 }
559
560 let mut result = Vec::with_capacity(splits.len() - 1);
561 for w in splits.windows(2) {
562 let start = w[0];
563 let end = w[1];
564 let chunk_len = end - start;
565
566 let mut sub_shape = shape.to_vec();
567 sub_shape[axis] = chunk_len;
568 let sub_total: usize = sub_shape.iter().product();
569
570 let mut sub_strides = vec![1usize; ndim];
572 for i in (0..ndim - 1).rev() {
573 sub_strides[i] = sub_strides[i + 1] * sub_shape[i + 1];
574 }
575
576 let mut sub_data = Vec::with_capacity(sub_total);
577 for flat in 0..sub_total {
578 let mut rem = flat;
580 let mut src_flat = 0usize;
581 for i in 0..ndim {
582 let idx = rem / sub_strides[i];
583 rem %= sub_strides[i];
584 let src_idx = if i == axis { idx + start } else { idx };
585 src_flat += src_idx * src_strides[i];
586 }
587 sub_data.push(src_data[src_flat].clone());
588 }
589 result.push(Array::from_vec(IxDyn::new(&sub_shape), sub_data)?);
590 }
591
592 Ok(result)
593}
594
595pub fn array_split_n<T: Element>(
605 a: &Array<T, IxDyn>,
606 n: usize,
607 axis: usize,
608) -> FerrayResult<Vec<Array<T, IxDyn>>> {
609 if n == 0 {
610 return Err(FerrayError::invalid_value("array_split_n: n must be > 0"));
611 }
612 let shape = a.shape();
613 if axis >= shape.len() {
614 return Err(FerrayError::axis_out_of_bounds(axis, shape.len()));
615 }
616 let axis_len = shape[axis];
617
618 let base = axis_len / n;
622 let extra = axis_len % n;
623 let mut indices = Vec::with_capacity(n.saturating_sub(1));
624 let mut cum = 0usize;
625 for i in 0..n - 1 {
626 cum += if i < extra { base + 1 } else { base };
627 indices.push(cum);
628 }
629 array_split(a, &indices, axis)
630}
631
632pub fn vsplit<T: Element>(
636 a: &Array<T, IxDyn>,
637 n_sections: usize,
638) -> FerrayResult<Vec<Array<T, IxDyn>>> {
639 split(a, n_sections, 0)
640}
641
642pub fn hsplit<T: Element>(
646 a: &Array<T, IxDyn>,
647 n_sections: usize,
648) -> FerrayResult<Vec<Array<T, IxDyn>>> {
649 split(a, n_sections, 1)
650}
651
652pub fn dsplit<T: Element>(
656 a: &Array<T, IxDyn>,
657 n_sections: usize,
658) -> FerrayResult<Vec<Array<T, IxDyn>>> {
659 split(a, n_sections, 2)
660}
661
662pub fn transpose<T: Element, D: Dimension>(
677 a: &Array<T, D>,
678 axes: Option<&[usize]>,
679) -> FerrayResult<Array<T, IxDyn>> {
680 let ndim = a.ndim();
681 let perm: Vec<usize> = match axes {
682 Some(ax) => {
683 if ax.len() != ndim {
684 return Err(FerrayError::invalid_value(format!(
685 "axes must have length {} but got {}",
686 ndim,
687 ax.len(),
688 )));
689 }
690 let mut seen = vec![false; ndim];
692 for &a in ax {
693 if a >= ndim {
694 return Err(FerrayError::axis_out_of_bounds(a, ndim));
695 }
696 if seen[a] {
697 return Err(FerrayError::invalid_value(format!(
698 "duplicate axis {a} in transpose",
699 )));
700 }
701 seen[a] = true;
702 }
703 ax.to_vec()
704 }
705 None => (0..ndim).rev().collect(),
706 };
707
708 let permuted = a
717 .inner
718 .view()
719 .into_dyn()
720 .permuted_axes(ndarray::IxDyn(&perm));
721 Ok(Array::from_ndarray(
722 permuted.as_standard_layout().into_owned(),
723 ))
724}
725
726pub fn swapaxes<T: Element, D: Dimension>(
733 a: &Array<T, D>,
734 axis1: usize,
735 axis2: usize,
736) -> FerrayResult<Array<T, IxDyn>> {
737 let ndim = a.ndim();
738 if axis1 >= ndim {
739 return Err(FerrayError::axis_out_of_bounds(axis1, ndim));
740 }
741 if axis2 >= ndim {
742 return Err(FerrayError::axis_out_of_bounds(axis2, ndim));
743 }
744 let mut perm: Vec<usize> = (0..ndim).collect();
745 perm.swap(axis1, axis2);
746 transpose(a, Some(&perm))
747}
748
749pub fn moveaxis<T: Element, D: Dimension>(
756 a: &Array<T, D>,
757 source: usize,
758 destination: usize,
759) -> FerrayResult<Array<T, IxDyn>> {
760 let ndim = a.ndim();
761 if source >= ndim {
762 return Err(FerrayError::axis_out_of_bounds(source, ndim));
763 }
764 if destination >= ndim {
765 return Err(FerrayError::axis_out_of_bounds(destination, ndim));
766 }
767 let mut order: Vec<usize> = (0..ndim).filter(|&x| x != source).collect();
769 order.insert(destination, source);
770 transpose(a, Some(&order))
771}
772
773pub fn rollaxis<T: Element, D: Dimension>(
780 a: &Array<T, D>,
781 axis: usize,
782 start: usize,
783) -> FerrayResult<Array<T, IxDyn>> {
784 let ndim = a.ndim();
785 if axis >= ndim {
786 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
787 }
788 if start > ndim {
789 return Err(FerrayError::axis_out_of_bounds(start, ndim + 1));
790 }
791 let dst = if start > axis { start - 1 } else { start };
792 if axis == dst {
793 let data: Vec<T> = a.iter().cloned().collect();
795 return Array::from_vec(IxDyn::new(a.shape()), data);
796 }
797 moveaxis(a, axis, dst)
798}
799
800pub fn flip<T: Element, D: Dimension>(
807 a: &Array<T, D>,
808 axis: usize,
809) -> FerrayResult<Array<T, IxDyn>> {
810 let shape = a.shape();
811 let ndim = shape.len();
812 if axis >= ndim {
813 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
814 }
815 let src_data: Vec<T> = a.iter().cloned().collect();
816 let total = src_data.len();
817
818 let mut strides = vec![1usize; ndim];
820 for i in (0..ndim.saturating_sub(1)).rev() {
821 strides[i] = strides[i + 1] * shape[i + 1];
822 }
823
824 let mut data = Vec::with_capacity(total);
825 for flat in 0..total {
826 let mut rem = flat;
827 let mut src_flat = 0usize;
828 for i in 0..ndim {
829 let idx = rem / strides[i];
830 rem %= strides[i];
831 let src_idx = if i == axis { shape[i] - 1 - idx } else { idx };
832 src_flat += src_idx * strides[i];
833 }
834 data.push(src_data[src_flat].clone());
835 }
836 Array::from_vec(IxDyn::new(shape), data)
837}
838
839pub fn fliplr<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
846 if a.ndim() < 2 {
847 return Err(FerrayError::invalid_value(
848 "fliplr: array must be at least 2-D",
849 ));
850 }
851 flip(a, 1)
852}
853
854pub fn flipud<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
861 if a.ndim() < 1 {
862 return Err(FerrayError::invalid_value(
863 "flipud: array must be at least 1-D",
864 ));
865 }
866 flip(a, 0)
867}
868
869pub fn rot90<T: Element, D: Dimension>(a: &Array<T, D>, k: i32) -> FerrayResult<Array<T, IxDyn>> {
878 if a.ndim() < 2 {
879 return Err(FerrayError::invalid_value(
880 "rot90: array must be at least 2-D",
881 ));
882 }
883 let k = k.rem_euclid(4);
885 let shape = a.shape();
886 let data: Vec<T> = a.iter().cloned().collect();
887
888 let as_dyn = Array::from_vec(IxDyn::new(shape), data)?;
890
891 match k {
892 0 => Ok(as_dyn),
893 1 => {
894 let flipped = flip(&as_dyn, 1)?;
896 swapaxes(&flipped, 0, 1)
897 }
898 2 => {
899 let f1 = flip(&as_dyn, 0)?;
901 flip(&f1, 1)
902 }
903 3 => {
904 let transposed = swapaxes(&as_dyn, 0, 1)?;
906 flip(&transposed, 1)
907 }
908 _ => unreachable!(),
909 }
910}
911
912pub fn roll<T: Element, D: Dimension>(
922 a: &Array<T, D>,
923 shift: isize,
924 axis: Option<usize>,
925) -> FerrayResult<Array<T, IxDyn>> {
926 match axis {
927 None => {
928 let data: Vec<T> = a.iter().cloned().collect();
930 let n = data.len();
931 if n == 0 {
932 return Array::from_vec(IxDyn::new(a.shape()), data);
933 }
934 let shift = ((shift % n as isize) + n as isize) as usize % n;
935 let mut rolled = Vec::with_capacity(n);
936 for i in 0..n {
937 rolled.push(data[(n + i - shift) % n].clone());
938 }
939 Array::from_vec(IxDyn::new(a.shape()), rolled)
940 }
941 Some(ax) => {
942 let shape = a.shape();
943 let ndim = shape.len();
944 if ax >= ndim {
945 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
946 }
947 let axis_len = shape[ax];
948 if axis_len == 0 {
949 let data: Vec<T> = a.iter().cloned().collect();
950 return Array::from_vec(IxDyn::new(shape), data);
951 }
952 let shift = ((shift % axis_len as isize) + axis_len as isize) as usize % axis_len;
953 let src_data: Vec<T> = a.iter().cloned().collect();
954 let total = src_data.len();
955
956 let mut strides = vec![1usize; ndim];
958 for i in (0..ndim.saturating_sub(1)).rev() {
959 strides[i] = strides[i + 1] * shape[i + 1];
960 }
961
962 let mut data = Vec::with_capacity(total);
963 for flat in 0..total {
964 let mut rem = flat;
965 let mut src_flat = 0usize;
966 #[allow(clippy::needless_range_loop)]
967 for i in 0..ndim {
968 let idx = rem / strides[i];
969 rem %= strides[i];
970 let src_idx = if i == ax {
971 (axis_len + idx - shift) % axis_len
972 } else {
973 idx
974 };
975 src_flat += src_idx * strides[i];
976 }
977 data.push(src_data[src_flat].clone());
978 }
979 Array::from_vec(IxDyn::new(shape), data)
980 }
981 }
982}
983
#[cfg(test)]
// Unit tests for the shape-manipulation routines in this module. Most tests
// build small `f64` arrays with `dyn_arr`, then check the result's shape and the
// element sequence produced by `iter()`.
mod tests {
    use super::*;

    /// Build a dynamic-rank `f64` array from a shape and flat data, panicking
    /// if `from_vec` rejects the combination.
    fn dyn_arr(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(shape), data).unwrap()
    }

    // ---- reshape / ravel / flatten ----

    #[test]
    fn test_reshape() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = reshape(&a, &[3, 2]).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        // Reshaping keeps the logical element order.
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_reshape_size_mismatch() {
        // 6 elements cannot fill a 2x4 shape.
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(reshape(&a, &[2, 4]).is_err());
    }

    #[test]
    fn test_ravel() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = ravel(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
        assert_eq!(b.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_flatten() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flatten(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
    }

    // ---- squeeze / expand_dims ----

    #[test]
    fn test_squeeze() {
        // With no axis, every length-1 axis is removed.
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, None).unwrap();
        assert_eq!(b.shape(), &[3]);
    }

    #[test]
    fn test_squeeze_specific_axis() {
        // Only the requested axis is removed.
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, Some(0)).unwrap();
        assert_eq!(b.shape(), &[3, 1]);
    }

    #[test]
    fn test_squeeze_not_size_1() {
        // Axis 0 has length 2, so it cannot be squeezed.
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(squeeze(&a, Some(0)).is_err());
    }

    #[test]
    fn test_expand_dims() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = expand_dims(&a, 0).unwrap();
        assert_eq!(b.shape(), &[1, 3]);
        let c = expand_dims(&a, 1).unwrap();
        assert_eq!(c.shape(), &[3, 1]);
    }

    #[test]
    fn test_expand_dims_oob() {
        // For a 1-D array the valid insertion positions are 0 and 1.
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(expand_dims(&a, 3).is_err());
    }

    // ---- broadcast_to ----

    #[test]
    fn test_broadcast_to() {
        let a = dyn_arr(&[1, 3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[3, 3]).unwrap();
        assert_eq!(b.shape(), &[3, 3]);
        // The single row is repeated three times.
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast_to_1d_to_2d() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[2, 3]).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
    }

    #[test]
    fn test_broadcast_to_incompatible() {
        // A length-4 axis cannot be broadcast to length 3.
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(broadcast_to(&a, &[3]).is_err());
    }

    #[test]
    fn test_broadcast_to_from_non_contiguous_source() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // Transpose to get a [3, 2] source, then broadcast it over a new leading axis.
        let t = transpose(&a, None).unwrap();
        let b = broadcast_to(&t, &[2, 3, 2]).unwrap();
        assert_eq!(b.shape(), &[2, 3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        // Both copies along the new leading axis must hold identical data.
        assert_eq!(&data[..6], &data[6..12]);
    }

    // ---- concatenate / stack family ----

    #[test]
    fn test_concatenate() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = dyn_arr(&[2, 3], vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[4, 3]);
    }

    #[test]
    fn test_concatenate_axis1() {
        // Lengths may differ along the concatenation axis.
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = concatenate(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
    }

    #[test]
    fn test_concatenate_shape_mismatch() {
        // Despite the name, this succeeds: the shapes differ only along the
        // concatenation axis (0), which is allowed.
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        let b = dyn_arr(&[3, 3], vec![1.0; 9]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[5, 3]);
    }

    #[test]
    fn test_concatenate_empty() {
        let v: Vec<Array<f64, IxDyn>> = vec![];
        assert!(concatenate(&v, 0).is_err());
    }

    #[test]
    fn test_stack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_axis1() {
        // Stacking along axis 1 interleaves the inputs element by element.
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_vstack() {
        // 1-D inputs become rows.
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = vstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
    }

    #[test]
    fn test_hstack() {
        // 1-D inputs are joined end to end.
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[6]);
    }

    #[test]
    fn test_hstack_2d() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
    }

    #[test]
    fn test_dstack() {
        // 2-D inputs gain a trailing axis and are joined along it.
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_block() {
        // A 2x2 grid of 2x2 blocks assembles into a 4x4 array.
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dyn_arr(&[2, 2], vec![9.0, 10.0, 11.0, 12.0]);
        let d = dyn_arr(&[2, 2], vec![13.0, 14.0, 15.0, 16.0]);
        let result = block(&[vec![a, b], vec![c, d]]).unwrap();
        assert_eq!(result.shape(), &[4, 4]);
    }

    // ---- split family ----

    #[test]
    fn test_split() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = split(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[2]);
    }

    #[test]
    fn test_split_uneven() {
        // 5 is not divisible by 3, so the even split is rejected.
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(split(&a, 3, 0).is_err());
    }

    #[test]
    fn test_array_split() {
        // Split points 2 and 4 give sections [0, 2), [2, 4), [4, 5).
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let parts = array_split(&a, &[2, 4], 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[1]);
    }

    #[test]
    fn test_vsplit() {
        let a = dyn_arr(&[4, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = vsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    #[test]
    fn test_hsplit() {
        let a = dyn_arr(&[2, 4], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = hsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    // ---- axis permutation ----

    #[test]
    fn test_transpose_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, None).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        // The result is materialized in the transposed logical order.
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_explicit() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, Some(&[1, 0])).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
    }

    #[test]
    fn test_transpose_bad_axes() {
        // The permutation must have one entry per axis.
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(transpose(&a, Some(&[0])).is_err());
    }

    #[test]
    fn test_swapaxes() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = swapaxes(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[4, 3, 2]);
    }

    #[test]
    fn test_moveaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = moveaxis(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[3, 4, 2]);
    }

    #[test]
    fn test_rollaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = rollaxis(&a, 2, 0).unwrap();
        assert_eq!(b.shape(), &[4, 2, 3]);
    }

    // ---- flips and rotations ----

    #[test]
    fn test_flip() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // Flipping axis 0 reverses the rows.
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);

        // Flipping axis 1 reverses each row.
        let c = flip(&a, 1).unwrap();
        let data2: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data2, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_fliplr() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = fliplr(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_flipud() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flipud(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_fliplr_1d_err() {
        // fliplr needs at least two dimensions.
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(fliplr(&a).is_err());
    }

    #[test]
    fn test_rot90_once() {
        // [[1, 2], [3, 4]] rotated once becomes [[2, 4], [1, 3]].
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 1).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 4.0, 1.0, 3.0]);
    }

    #[test]
    fn test_rot90_twice() {
        // A half turn reverses the element order of a 2x2 matrix.
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 2).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_rot90_four_is_identity() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = rot90(&a, 4).unwrap();
        let data_a: Vec<f64> = a.iter().copied().collect();
        let data_b: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data_a, data_b);
        assert_eq!(a.shape(), b.shape());
    }

    // ---- roll ----

    #[test]
    fn test_roll_flat() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, 2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        // A negative shift rolls towards lower indices.
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, -2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 4.0, 5.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll_axis() {
        // Rolling along axis 1 rotates each row independently.
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, 1, Some(1)).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
    }

    // ---- column_stack / row_stack ----

    #[test]
    fn test_column_stack_1d() {
        // Each 1-D input becomes one column of the result.
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[4], vec![10.0, 20.0, 30.0, 40.0]);
        let c = dyn_arr(&[4], vec![100.0, 200.0, 300.0, 400.0]);
        let result = column_stack(&[a, b, c]).unwrap();
        assert_eq!(result.shape(), &[4, 3]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![
                1.0, 10.0, 100.0, // row 0
                2.0, 20.0, 200.0, // row 1
                3.0, 30.0, 300.0, // row 2
                4.0, 40.0, 400.0, // row 3
            ]
        );
    }

    #[test]
    fn test_column_stack_2d_same_as_hstack() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let result = column_stack(&[a, b]).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]
        );
    }

    #[test]
    fn test_column_stack_length_mismatch() {
        // Columns of different lengths cannot be stacked.
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(column_stack(&[a, b]).is_err());
    }

    #[test]
    fn test_row_stack_is_vstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let row = row_stack(&[a.clone(), b.clone()]).unwrap();
        let v = vstack(&[a, b]).unwrap();
        assert_eq!(row.shape(), v.shape());
        assert_eq!(
            row.iter().copied().collect::<Vec<_>>(),
            v.iter().copied().collect::<Vec<_>>()
        );
    }

    // ---- array_split_n ----

    #[test]
    fn test_array_split_n_uneven() {
        // 7 elements into 3 sections: the first section gets the extra element.
        let a = dyn_arr(&[7], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let parts = array_split_n(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(
            parts[0].iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0]
        );
        assert_eq!(parts[1].iter().copied().collect::<Vec<_>>(), vec![4.0, 5.0]);
        assert_eq!(parts[2].iter().copied().collect::<Vec<_>>(), vec![6.0, 7.0]);
    }

    #[test]
    fn test_array_split_n_even() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = array_split_n(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        for (i, expected) in [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]
            .iter()
            .enumerate()
        {
            assert_eq!(&parts[i].iter().copied().collect::<Vec<_>>(), expected);
        }
    }

    #[test]
    fn test_array_split_n_more_sections_than_elements() {
        // More sections than elements: one element each, then empty sections.
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let parts = array_split_n(&a, 5, 0).unwrap();
        assert_eq!(parts.len(), 5);
        assert_eq!(parts[0].iter().copied().collect::<Vec<_>>(), vec![1.0]);
        assert_eq!(parts[1].iter().copied().collect::<Vec<_>>(), vec![2.0]);
        assert_eq!(parts[2].iter().copied().collect::<Vec<_>>(), vec![3.0]);
        assert_eq!(
            parts[3].iter().copied().collect::<Vec<_>>(),
            Vec::<f64>::new()
        );
        assert_eq!(
            parts[4].iter().copied().collect::<Vec<_>>(),
            Vec::<f64>::new()
        );
    }

    // ---- interop with statically-ranked arrays ----

    #[test]
    fn test_to_dyn_from_typed() {
        use crate::Array;
        use crate::dimension::Ix2;
        let typed =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let dy = typed.to_dyn();
        assert_eq!(dy.shape(), &[2, 3]);
        assert_eq!(
            dy.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_concatenate_typed_via_to_dyn() {
        use crate::Array;
        // Typed (Ix2) arrays are converted with `to_dyn` before calling the
        // IxDyn-only concatenate.
        use crate::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![5.0, 6.0, 7.0, 8.0]).unwrap();
        let result = concatenate(&[a.to_dyn(), b.to_dyn()], 0).unwrap();
        assert_eq!(result.shape(), &[4, 2]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
    }
}