1pub mod extended;
7
8use crate::array::owned::Array;
9use crate::dimension::{Dimension, Ix1, IxDyn};
10use crate::dtype::Element;
11use crate::error::{FerrayError, FerrayResult};
12
13pub fn reshape<T: Element, D: Dimension>(
31 a: &Array<T, D>,
32 new_shape: &[usize],
33) -> FerrayResult<Array<T, IxDyn>> {
34 let old_size = a.size();
35 let new_size: usize = new_shape.iter().product();
36 if old_size != new_size {
37 return Err(FerrayError::shape_mismatch(format!(
38 "cannot reshape array of size {} into shape {:?} (size {})",
39 old_size, new_shape, new_size,
40 )));
41 }
42 let view = a.inner.view().into_dyn();
43 let reshaped = view
44 .to_shape(ndarray::IxDyn(new_shape))
45 .map_err(|e| FerrayError::shape_mismatch(e.to_string()))?;
46 Ok(Array::from_ndarray(
50 reshaped.as_standard_layout().into_owned(),
51 ))
52}
53
54pub fn ravel<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
59 let n = a.size();
60 let view = a.inner.view().into_dyn();
61 let reshaped = view
62 .to_shape(ndarray::IxDyn(&[n]))
63 .expect("1-D reshape always succeeds for a size-preserving target");
64 let standard = reshaped.as_standard_layout().into_owned();
65 let one_d = standard
66 .into_dimensionality::<ndarray::Ix1>()
67 .expect("reshape result has ndim == 1 by construction");
68 Ok(Array::from_ndarray(one_d))
69}
70
71pub fn flatten<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
75 ravel(a)
76}
77
78pub fn squeeze<T: Element, D: Dimension>(
89 a: &Array<T, D>,
90 axis: Option<usize>,
91) -> FerrayResult<Array<T, IxDyn>> {
92 let shape = a.shape();
93 match axis {
94 Some(ax) => {
95 if ax >= shape.len() {
96 return Err(FerrayError::axis_out_of_bounds(ax, shape.len()));
97 }
98 if shape[ax] != 1 {
99 return Err(FerrayError::invalid_value(format!(
100 "cannot select axis {} with size {} for squeeze (must be 1)",
101 ax, shape[ax],
102 )));
103 }
104 let new_shape: Vec<usize> = shape
105 .iter()
106 .enumerate()
107 .filter(|&(i, _)| i != ax)
108 .map(|(_, &s)| s)
109 .collect();
110 let data: Vec<T> = a.iter().cloned().collect();
111 Array::from_vec(IxDyn::new(&new_shape), data)
112 }
113 None => {
114 let new_shape: Vec<usize> = shape.iter().copied().filter(|&s| s != 1).collect();
115 let new_shape = if new_shape.is_empty() && !shape.is_empty() {
118 vec![1]
119 } else if new_shape.is_empty() {
120 vec![]
121 } else {
122 new_shape
123 };
124 let data: Vec<T> = a.iter().cloned().collect();
125 Array::from_vec(IxDyn::new(&new_shape), data)
126 }
127 }
128}
129
130pub fn expand_dims<T: Element, D: Dimension>(
137 a: &Array<T, D>,
138 axis: usize,
139) -> FerrayResult<Array<T, IxDyn>> {
140 let ndim = a.ndim();
141 if axis > ndim {
142 return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
143 }
144 let mut new_shape: Vec<usize> = a.shape().to_vec();
145 new_shape.insert(axis, 1);
146 let data: Vec<T> = a.iter().cloned().collect();
147 Array::from_vec(IxDyn::new(&new_shape), data)
148}
149
150pub fn broadcast_to<T: Element, D: Dimension>(
163 a: &Array<T, D>,
164 new_shape: &[usize],
165) -> FerrayResult<Array<T, IxDyn>> {
166 let src_shape = a.shape();
167 let dyn_view = a.inner.view().into_dyn();
168 let broadcast_view = dyn_view
169 .broadcast(ndarray::IxDyn(new_shape))
170 .ok_or_else(|| FerrayError::BroadcastFailure {
171 shape_a: src_shape.to_vec(),
172 shape_b: new_shape.to_vec(),
173 })?;
174 Ok(Array::from_ndarray(
178 broadcast_view.as_standard_layout().into_owned(),
179 ))
180}
181
182pub fn concatenate<T: Element>(
195 arrays: &[Array<T, IxDyn>],
196 axis: usize,
197) -> FerrayResult<Array<T, IxDyn>> {
198 if arrays.is_empty() {
199 return Err(FerrayError::invalid_value(
200 "concatenate: need at least one array",
201 ));
202 }
203 let ndim = arrays[0].ndim();
204 if axis >= ndim {
205 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
206 }
207 let base_shape = arrays[0].shape();
208
209 let mut total_along_axis = 0usize;
211 for arr in arrays {
212 if arr.ndim() != ndim {
213 return Err(FerrayError::shape_mismatch(format!(
214 "all arrays must have same ndim; got {} and {}",
215 ndim,
216 arr.ndim(),
217 )));
218 }
219 for (i, (&s, &base)) in arr.shape().iter().zip(base_shape.iter()).enumerate() {
220 if i != axis && s != base {
221 return Err(FerrayError::shape_mismatch(format!(
222 "shape mismatch on axis {}: {} vs {}",
223 i, s, base,
224 )));
225 }
226 }
227 total_along_axis += arr.shape()[axis];
228 }
229
230 let mut new_shape = base_shape.to_vec();
232 new_shape[axis] = total_along_axis;
233 let total: usize = new_shape.iter().product();
234 let mut data = Vec::with_capacity(total);
235
236 let src_vecs: Vec<Vec<T>> = arrays.iter().map(|a| a.iter().cloned().collect()).collect();
238
239 let mut out_strides = vec![1usize; ndim];
241 for i in (0..ndim - 1).rev() {
242 out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
243 }
244
245 for flat_idx in 0..total {
247 let mut rem = flat_idx;
249 let mut nd_idx = vec![0usize; ndim];
250 for i in 0..ndim {
251 nd_idx[i] = rem / out_strides[i];
252 rem %= out_strides[i];
253 }
254
255 let concat_idx = nd_idx[axis];
257 let mut offset = 0;
258 let mut src_arr_idx = 0;
259 for (k, arr) in arrays.iter().enumerate() {
260 let len_along = arr.shape()[axis];
261 if concat_idx < offset + len_along {
262 src_arr_idx = k;
263 break;
264 }
265 offset += len_along;
266 }
267 let local_concat_idx = concat_idx - offset;
268
269 let src_shape = arrays[src_arr_idx].shape();
271 let mut src_flat = 0usize;
272 let mut src_mul = 1usize;
273 for i in (0..ndim).rev() {
274 let idx = if i == axis {
275 local_concat_idx
276 } else {
277 nd_idx[i]
278 };
279 src_flat += idx * src_mul;
280 src_mul *= src_shape[i];
281 }
282
283 let elem = src_vecs[src_arr_idx].get(src_flat).ok_or_else(|| {
284 FerrayError::invalid_value(format!(
285 "concatenate: internal index {} out of range for source array of length {}",
286 src_flat,
287 src_vecs[src_arr_idx].len(),
288 ))
289 })?;
290 data.push(elem.clone());
291 }
292
293 Array::from_vec(IxDyn::new(&new_shape), data)
294}
295
296pub fn stack<T: Element>(arrays: &[Array<T, IxDyn>], axis: usize) -> FerrayResult<Array<T, IxDyn>> {
308 if arrays.is_empty() {
309 return Err(FerrayError::invalid_value("stack: need at least one array"));
310 }
311 let base_shape = arrays[0].shape();
312 let ndim = base_shape.len();
313
314 if axis > ndim {
315 return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
316 }
317
318 for arr in &arrays[1..] {
319 if arr.shape() != base_shape {
320 return Err(FerrayError::shape_mismatch(format!(
321 "all input arrays must have the same shape; got {:?} and {:?}",
322 base_shape,
323 arr.shape(),
324 )));
325 }
326 }
327
328 let mut expanded = Vec::with_capacity(arrays.len());
330 for arr in arrays {
331 expanded.push(expand_dims(arr, axis)?);
332 }
333 concatenate(&expanded, axis)
334}
335
336pub fn vstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
341 if arrays.is_empty() {
342 return Err(FerrayError::invalid_value(
343 "vstack: need at least one array",
344 ));
345 }
346 let ndim = arrays[0].ndim();
348 if ndim == 1 {
349 let mut reshaped = Vec::with_capacity(arrays.len());
350 for arr in arrays {
351 let n = arr.shape()[0];
352 reshaped.push(reshape(arr, &[1, n])?);
353 }
354 concatenate(&reshaped, 0)
355 } else {
356 concatenate(arrays, 0)
357 }
358}
359
360pub fn hstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
365 if arrays.is_empty() {
366 return Err(FerrayError::invalid_value(
367 "hstack: need at least one array",
368 ));
369 }
370 let ndim = arrays[0].ndim();
371 if ndim == 1 {
372 concatenate(arrays, 0)
373 } else {
374 concatenate(arrays, 1)
375 }
376}
377
378pub fn dstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
386 if arrays.is_empty() {
387 return Err(FerrayError::invalid_value(
388 "dstack: need at least one array",
389 ));
390 }
391 let mut expanded = Vec::with_capacity(arrays.len());
392 for arr in arrays {
393 let shape = arr.shape();
394 match shape.len() {
395 1 => {
396 let n = shape[0];
397 expanded.push(reshape(arr, &[1, n, 1])?);
398 }
399 2 => {
400 let (m, n) = (shape[0], shape[1]);
401 expanded.push(reshape(arr, &[m, n, 1])?);
402 }
403 _ => {
404 let data: Vec<T> = arr.iter().cloned().collect();
406 expanded.push(Array::from_vec(IxDyn::new(shape), data)?);
407 }
408 }
409 }
410 concatenate(&expanded, 2)
411}
412
413pub fn column_stack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
424 if arrays.is_empty() {
425 return Err(FerrayError::invalid_value(
426 "column_stack: need at least one array",
427 ));
428 }
429 let first_ndim = arrays[0].ndim();
430 if first_ndim == 1 {
431 let n = arrays[0].shape()[0];
433 let mut reshaped = Vec::with_capacity(arrays.len());
434 for arr in arrays {
435 if arr.ndim() != 1 {
436 return Err(FerrayError::shape_mismatch(
437 "column_stack: all inputs must have the same ndim",
438 ));
439 }
440 if arr.shape()[0] != n {
441 return Err(FerrayError::shape_mismatch(format!(
442 "column_stack: 1-D inputs must have the same length; got {} and {}",
443 n,
444 arr.shape()[0],
445 )));
446 }
447 reshaped.push(reshape(arr, &[n, 1])?);
448 }
449 concatenate(&reshaped, 1)
450 } else {
451 hstack(arrays)
453 }
454}
455
456pub fn row_stack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
461 vstack(arrays)
462}
463
464pub fn block<T: Element>(blocks: &[Vec<Array<T, IxDyn>>]) -> FerrayResult<Array<T, IxDyn>> {
474 if blocks.is_empty() {
475 return Err(FerrayError::invalid_value("block: empty input"));
476 }
477 let mut rows = Vec::with_capacity(blocks.len());
478 for row in blocks {
479 if row.is_empty() {
480 return Err(FerrayError::invalid_value("block: empty row"));
481 }
482 let row_arr = if row.len() == 1 {
484 let data: Vec<T> = row[0].iter().cloned().collect();
485 Array::from_vec(IxDyn::new(row[0].shape()), data)?
486 } else {
487 hstack(row)?
488 };
489 rows.push(row_arr);
490 }
491 if rows.len() == 1 {
492 Ok(rows.pop().unwrap_or_else(|| unreachable!()))
494 } else {
495 vstack(&rows)
496 }
497}
498
499pub fn split<T: Element>(
508 a: &Array<T, IxDyn>,
509 n_sections: usize,
510 axis: usize,
511) -> FerrayResult<Vec<Array<T, IxDyn>>> {
512 let shape = a.shape();
513 if axis >= shape.len() {
514 return Err(FerrayError::axis_out_of_bounds(axis, shape.len()));
515 }
516 let axis_len = shape[axis];
517 if n_sections == 0 {
518 return Err(FerrayError::invalid_value("split: n_sections must be > 0"));
519 }
520 if axis_len % n_sections != 0 {
521 return Err(FerrayError::invalid_value(format!(
522 "array of size {} along axis {} cannot be evenly split into {} sections",
523 axis_len, axis, n_sections,
524 )));
525 }
526 let chunk_size = axis_len / n_sections;
527 let indices: Vec<usize> = (1..n_sections).map(|i| i * chunk_size).collect();
528 array_split(a, &indices, axis)
529}
530
531pub fn array_split<T: Element>(
540 a: &Array<T, IxDyn>,
541 indices: &[usize],
542 axis: usize,
543) -> FerrayResult<Vec<Array<T, IxDyn>>> {
544 let shape = a.shape();
545 let ndim = shape.len();
546 if axis >= ndim {
547 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
548 }
549 let axis_len = shape[axis];
550 let src_data: Vec<T> = a.iter().cloned().collect();
551
552 let mut splits = Vec::with_capacity(indices.len() + 2);
554 splits.push(0);
555 for &idx in indices {
556 splits.push(idx.min(axis_len));
557 }
558 splits.push(axis_len);
559
560 let mut src_strides = vec![1usize; ndim];
562 for i in (0..ndim - 1).rev() {
563 src_strides[i] = src_strides[i + 1] * shape[i + 1];
564 }
565
566 let mut result = Vec::with_capacity(splits.len() - 1);
567 for w in splits.windows(2) {
568 let start = w[0];
569 let end = w[1];
570 let chunk_len = end - start;
571
572 let mut sub_shape = shape.to_vec();
573 sub_shape[axis] = chunk_len;
574 let sub_total: usize = sub_shape.iter().product();
575
576 let mut sub_strides = vec![1usize; ndim];
578 for i in (0..ndim - 1).rev() {
579 sub_strides[i] = sub_strides[i + 1] * sub_shape[i + 1];
580 }
581
582 let mut sub_data = Vec::with_capacity(sub_total);
583 for flat in 0..sub_total {
584 let mut rem = flat;
586 let mut src_flat = 0usize;
587 for i in 0..ndim {
588 let idx = rem / sub_strides[i];
589 rem %= sub_strides[i];
590 let src_idx = if i == axis { idx + start } else { idx };
591 src_flat += src_idx * src_strides[i];
592 }
593 sub_data.push(src_data[src_flat].clone());
594 }
595 result.push(Array::from_vec(IxDyn::new(&sub_shape), sub_data)?);
596 }
597
598 Ok(result)
599}
600
601pub fn array_split_n<T: Element>(
611 a: &Array<T, IxDyn>,
612 n: usize,
613 axis: usize,
614) -> FerrayResult<Vec<Array<T, IxDyn>>> {
615 if n == 0 {
616 return Err(FerrayError::invalid_value("array_split_n: n must be > 0"));
617 }
618 let shape = a.shape();
619 if axis >= shape.len() {
620 return Err(FerrayError::axis_out_of_bounds(axis, shape.len()));
621 }
622 let axis_len = shape[axis];
623
624 let base = axis_len / n;
628 let extra = axis_len % n;
629 let mut indices = Vec::with_capacity(n.saturating_sub(1));
630 let mut cum = 0usize;
631 for i in 0..n - 1 {
632 cum += if i < extra { base + 1 } else { base };
633 indices.push(cum);
634 }
635 array_split(a, &indices, axis)
636}
637
638pub fn vsplit<T: Element>(
642 a: &Array<T, IxDyn>,
643 n_sections: usize,
644) -> FerrayResult<Vec<Array<T, IxDyn>>> {
645 split(a, n_sections, 0)
646}
647
648pub fn hsplit<T: Element>(
652 a: &Array<T, IxDyn>,
653 n_sections: usize,
654) -> FerrayResult<Vec<Array<T, IxDyn>>> {
655 split(a, n_sections, 1)
656}
657
658pub fn dsplit<T: Element>(
662 a: &Array<T, IxDyn>,
663 n_sections: usize,
664) -> FerrayResult<Vec<Array<T, IxDyn>>> {
665 split(a, n_sections, 2)
666}
667
668pub fn transpose<T: Element, D: Dimension>(
683 a: &Array<T, D>,
684 axes: Option<&[usize]>,
685) -> FerrayResult<Array<T, IxDyn>> {
686 let ndim = a.ndim();
687 let perm: Vec<usize> = match axes {
688 Some(ax) => {
689 if ax.len() != ndim {
690 return Err(FerrayError::invalid_value(format!(
691 "axes must have length {} but got {}",
692 ndim,
693 ax.len(),
694 )));
695 }
696 let mut seen = vec![false; ndim];
698 for &a in ax {
699 if a >= ndim {
700 return Err(FerrayError::axis_out_of_bounds(a, ndim));
701 }
702 if seen[a] {
703 return Err(FerrayError::invalid_value(format!(
704 "duplicate axis {} in transpose",
705 a,
706 )));
707 }
708 seen[a] = true;
709 }
710 ax.to_vec()
711 }
712 None => (0..ndim).rev().collect(),
713 };
714
715 let permuted = a
724 .inner
725 .view()
726 .into_dyn()
727 .permuted_axes(ndarray::IxDyn(&perm));
728 Ok(Array::from_ndarray(
729 permuted.as_standard_layout().into_owned(),
730 ))
731}
732
733pub fn swapaxes<T: Element, D: Dimension>(
740 a: &Array<T, D>,
741 axis1: usize,
742 axis2: usize,
743) -> FerrayResult<Array<T, IxDyn>> {
744 let ndim = a.ndim();
745 if axis1 >= ndim {
746 return Err(FerrayError::axis_out_of_bounds(axis1, ndim));
747 }
748 if axis2 >= ndim {
749 return Err(FerrayError::axis_out_of_bounds(axis2, ndim));
750 }
751 let mut perm: Vec<usize> = (0..ndim).collect();
752 perm.swap(axis1, axis2);
753 transpose(a, Some(&perm))
754}
755
756pub fn moveaxis<T: Element, D: Dimension>(
763 a: &Array<T, D>,
764 source: usize,
765 destination: usize,
766) -> FerrayResult<Array<T, IxDyn>> {
767 let ndim = a.ndim();
768 if source >= ndim {
769 return Err(FerrayError::axis_out_of_bounds(source, ndim));
770 }
771 if destination >= ndim {
772 return Err(FerrayError::axis_out_of_bounds(destination, ndim));
773 }
774 let mut order: Vec<usize> = (0..ndim).filter(|&x| x != source).collect();
776 order.insert(destination, source);
777 transpose(a, Some(&order))
778}
779
780pub fn rollaxis<T: Element, D: Dimension>(
787 a: &Array<T, D>,
788 axis: usize,
789 start: usize,
790) -> FerrayResult<Array<T, IxDyn>> {
791 let ndim = a.ndim();
792 if axis >= ndim {
793 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
794 }
795 if start > ndim {
796 return Err(FerrayError::axis_out_of_bounds(start, ndim + 1));
797 }
798 let dst = if start > axis { start - 1 } else { start };
799 if axis == dst {
800 let data: Vec<T> = a.iter().cloned().collect();
802 return Array::from_vec(IxDyn::new(a.shape()), data);
803 }
804 moveaxis(a, axis, dst)
805}
806
807pub fn flip<T: Element, D: Dimension>(
814 a: &Array<T, D>,
815 axis: usize,
816) -> FerrayResult<Array<T, IxDyn>> {
817 let shape = a.shape();
818 let ndim = shape.len();
819 if axis >= ndim {
820 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
821 }
822 let src_data: Vec<T> = a.iter().cloned().collect();
823 let total = src_data.len();
824
825 let mut strides = vec![1usize; ndim];
827 for i in (0..ndim.saturating_sub(1)).rev() {
828 strides[i] = strides[i + 1] * shape[i + 1];
829 }
830
831 let mut data = Vec::with_capacity(total);
832 for flat in 0..total {
833 let mut rem = flat;
834 let mut src_flat = 0usize;
835 for i in 0..ndim {
836 let idx = rem / strides[i];
837 rem %= strides[i];
838 let src_idx = if i == axis { shape[i] - 1 - idx } else { idx };
839 src_flat += src_idx * strides[i];
840 }
841 data.push(src_data[src_flat].clone());
842 }
843 Array::from_vec(IxDyn::new(shape), data)
844}
845
846pub fn fliplr<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
853 if a.ndim() < 2 {
854 return Err(FerrayError::invalid_value(
855 "fliplr: array must be at least 2-D",
856 ));
857 }
858 flip(a, 1)
859}
860
861pub fn flipud<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
868 if a.ndim() < 1 {
869 return Err(FerrayError::invalid_value(
870 "flipud: array must be at least 1-D",
871 ));
872 }
873 flip(a, 0)
874}
875
876pub fn rot90<T: Element, D: Dimension>(a: &Array<T, D>, k: i32) -> FerrayResult<Array<T, IxDyn>> {
885 if a.ndim() < 2 {
886 return Err(FerrayError::invalid_value(
887 "rot90: array must be at least 2-D",
888 ));
889 }
890 let k = k.rem_euclid(4);
892 let shape = a.shape();
893 let data: Vec<T> = a.iter().cloned().collect();
894
895 let as_dyn = Array::from_vec(IxDyn::new(shape), data)?;
897
898 match k {
899 0 => Ok(as_dyn),
900 1 => {
901 let flipped = flip(&as_dyn, 1)?;
903 swapaxes(&flipped, 0, 1)
904 }
905 2 => {
906 let f1 = flip(&as_dyn, 0)?;
908 flip(&f1, 1)
909 }
910 3 => {
911 let transposed = swapaxes(&as_dyn, 0, 1)?;
913 flip(&transposed, 1)
914 }
915 _ => unreachable!(),
916 }
917}
918
919pub fn roll<T: Element, D: Dimension>(
929 a: &Array<T, D>,
930 shift: isize,
931 axis: Option<usize>,
932) -> FerrayResult<Array<T, IxDyn>> {
933 match axis {
934 None => {
935 let data: Vec<T> = a.iter().cloned().collect();
937 let n = data.len();
938 if n == 0 {
939 return Array::from_vec(IxDyn::new(a.shape()), data);
940 }
941 let shift = ((shift % n as isize) + n as isize) as usize % n;
942 let mut rolled = Vec::with_capacity(n);
943 for i in 0..n {
944 rolled.push(data[(n + i - shift) % n].clone());
945 }
946 Array::from_vec(IxDyn::new(a.shape()), rolled)
947 }
948 Some(ax) => {
949 let shape = a.shape();
950 let ndim = shape.len();
951 if ax >= ndim {
952 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
953 }
954 let axis_len = shape[ax];
955 if axis_len == 0 {
956 let data: Vec<T> = a.iter().cloned().collect();
957 return Array::from_vec(IxDyn::new(shape), data);
958 }
959 let shift = ((shift % axis_len as isize) + axis_len as isize) as usize % axis_len;
960 let src_data: Vec<T> = a.iter().cloned().collect();
961 let total = src_data.len();
962
963 let mut strides = vec![1usize; ndim];
965 for i in (0..ndim.saturating_sub(1)).rev() {
966 strides[i] = strides[i + 1] * shape[i + 1];
967 }
968
969 let mut data = Vec::with_capacity(total);
970 for flat in 0..total {
971 let mut rem = flat;
972 let mut src_flat = 0usize;
973 #[allow(clippy::needless_range_loop)]
974 for i in 0..ndim {
975 let idx = rem / strides[i];
976 rem %= strides[i];
977 let src_idx = if i == ax {
978 (axis_len + idx - shift) % axis_len
979 } else {
980 idx
981 };
982 src_flat += src_idx * strides[i];
983 }
984 data.push(src_data[src_flat].clone());
985 }
986 Array::from_vec(IxDyn::new(shape), data)
987 }
988 }
989}
990
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dynamic-rank `f64` array, panicking on a shape/data mismatch.
    fn dyn_arr(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(shape), data).unwrap()
    }

    // --- reshape / ravel / flatten -------------------------------------------

    #[test]
    fn test_reshape() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = reshape(&a, &[3, 2]).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_reshape_size_mismatch() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(reshape(&a, &[2, 4]).is_err());
    }

    #[test]
    fn test_ravel() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = ravel(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
        assert_eq!(b.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_flatten() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flatten(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
    }

    // --- squeeze / expand_dims -----------------------------------------------

    #[test]
    fn test_squeeze() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, None).unwrap();
        assert_eq!(b.shape(), &[3]);
    }

    #[test]
    fn test_squeeze_specific_axis() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, Some(0)).unwrap();
        assert_eq!(b.shape(), &[3, 1]);
    }

    #[test]
    fn test_squeeze_not_size_1() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(squeeze(&a, Some(0)).is_err());
    }

    #[test]
    fn test_squeeze_all_ones_keeps_one_axis() {
        // An all-ones (non-0-d) shape collapses to [1], not to 0-d.
        let a = dyn_arr(&[1, 1], vec![7.0]);
        let b = squeeze(&a, None).unwrap();
        assert_eq!(b.shape(), &[1]);
        assert_eq!(b.iter().copied().collect::<Vec<_>>(), vec![7.0]);
    }

    #[test]
    fn test_expand_dims() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = expand_dims(&a, 0).unwrap();
        assert_eq!(b.shape(), &[1, 3]);
        let c = expand_dims(&a, 1).unwrap();
        assert_eq!(c.shape(), &[3, 1]);
    }

    #[test]
    fn test_expand_dims_oob() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(expand_dims(&a, 3).is_err());
    }

    // --- broadcast_to --------------------------------------------------------

    #[test]
    fn test_broadcast_to() {
        let a = dyn_arr(&[1, 3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[3, 3]).unwrap();
        assert_eq!(b.shape(), &[3, 3]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast_to_1d_to_2d() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[2, 3]).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
    }

    #[test]
    fn test_broadcast_to_incompatible() {
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(broadcast_to(&a, &[3]).is_err());
    }

    #[test]
    fn test_broadcast_to_from_non_contiguous_source() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let t = transpose(&a, None).unwrap();
        let b = broadcast_to(&t, &[2, 3, 2]).unwrap();
        assert_eq!(b.shape(), &[2, 3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        // Both leading slices are copies of the same broadcast source.
        assert_eq!(&data[..6], &data[6..12]);
    }

    // --- concatenate / stack family ------------------------------------------

    #[test]
    fn test_concatenate() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = dyn_arr(&[2, 3], vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[4, 3]);
    }

    #[test]
    fn test_concatenate_axis1() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = concatenate(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
    }

    #[test]
    fn test_concatenate_differing_concat_axis_ok() {
        // Lengths may differ along the concatenation axis itself.
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        let b = dyn_arr(&[3, 3], vec![1.0; 9]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[5, 3]);
    }

    #[test]
    fn test_concatenate_shape_mismatch() {
        // A mismatch on any axis other than the concatenation axis is an error.
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        let b = dyn_arr(&[2, 4], vec![1.0; 8]);
        assert!(concatenate(&[a, b], 0).is_err());
    }

    #[test]
    fn test_concatenate_3d_middle_axis() {
        let a = dyn_arr(&[2, 1, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(
            &[2, 2, 2],
            vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
        );
        let c = concatenate(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 3, 2]);
        assert_eq!(
            c.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 5.0, 6.0, 7.0, 8.0, 3.0, 4.0, 9.0, 10.0, 11.0, 12.0]
        );
    }

    #[test]
    fn test_concatenate_empty() {
        let v: Vec<Array<f64, IxDyn>> = vec![];
        assert!(concatenate(&v, 0).is_err());
    }

    #[test]
    fn test_stack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_axis1() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_vstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = vstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
    }

    #[test]
    fn test_hstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[6]);
    }

    #[test]
    fn test_hstack_2d() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
    }

    #[test]
    fn test_dstack() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_block() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dyn_arr(&[2, 2], vec![9.0, 10.0, 11.0, 12.0]);
        let d = dyn_arr(&[2, 2], vec![13.0, 14.0, 15.0, 16.0]);
        let result = block(&[vec![a, b], vec![c, d]]).unwrap();
        assert_eq!(result.shape(), &[4, 4]);
    }

    // --- split family --------------------------------------------------------

    #[test]
    fn test_split() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = split(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[2]);
    }

    #[test]
    fn test_split_uneven() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(split(&a, 3, 0).is_err());
    }

    #[test]
    fn test_array_split() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let parts = array_split(&a, &[2, 4], 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[1]);
    }

    #[test]
    fn test_array_split_axis1_values() {
        let a = dyn_arr(&[2, 4], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = array_split(&a, &[1, 3], 1).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2, 1]);
        assert_eq!(parts[0].iter().copied().collect::<Vec<_>>(), vec![1.0, 5.0]);
        assert_eq!(parts[1].shape(), &[2, 2]);
        assert_eq!(
            parts[1].iter().copied().collect::<Vec<_>>(),
            vec![2.0, 3.0, 6.0, 7.0]
        );
        assert_eq!(parts[2].shape(), &[2, 1]);
        assert_eq!(parts[2].iter().copied().collect::<Vec<_>>(), vec![4.0, 8.0]);
    }

    #[test]
    fn test_vsplit() {
        let a = dyn_arr(&[4, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = vsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    #[test]
    fn test_hsplit() {
        let a = dyn_arr(&[2, 4], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = hsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    // --- axis permutation ----------------------------------------------------

    #[test]
    fn test_transpose_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, None).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_explicit() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, Some(&[1, 0])).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
    }

    #[test]
    fn test_transpose_bad_axes() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(transpose(&a, Some(&[0])).is_err());
    }

    #[test]
    fn test_swapaxes() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = swapaxes(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[4, 3, 2]);
    }

    #[test]
    fn test_moveaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = moveaxis(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[3, 4, 2]);
    }

    #[test]
    fn test_rollaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = rollaxis(&a, 2, 0).unwrap();
        assert_eq!(b.shape(), &[4, 2, 3]);
    }

    // --- flip / rot90 / roll -------------------------------------------------

    #[test]
    fn test_flip() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);

        let c = flip(&a, 1).unwrap();
        let data2: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data2, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_flip_3d_middle_axis() {
        let a = dyn_arr(&[2, 2, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let b = flip(&a, 1).unwrap();
        assert_eq!(b.shape(), &[2, 2, 2]);
        assert_eq!(
            b.iter().copied().collect::<Vec<_>>(),
            vec![3.0, 4.0, 1.0, 2.0, 7.0, 8.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_fliplr() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = fliplr(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_flipud() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flipud(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_fliplr_1d_err() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(fliplr(&a).is_err());
    }

    #[test]
    fn test_rot90_once() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 1).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 4.0, 1.0, 3.0]);
    }

    #[test]
    fn test_rot90_twice() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 2).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_rot90_four_is_identity() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = rot90(&a, 4).unwrap();
        let data_a: Vec<f64> = a.iter().copied().collect();
        let data_b: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data_a, data_b);
        assert_eq!(a.shape(), b.shape());
    }

    #[test]
    fn test_roll_flat() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, 2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, -2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 4.0, 5.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll_axis() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, 1, Some(1)).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
    }

    #[test]
    fn test_roll_axis0_2d() {
        let a = dyn_arr(&[3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, 1, Some(0)).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(
            b.iter().copied().collect::<Vec<_>>(),
            vec![5.0, 6.0, 1.0, 2.0, 3.0, 4.0]
        );
    }

    // --- column_stack / row_stack / array_split_n ----------------------------

    #[test]
    fn test_column_stack_1d() {
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[4], vec![10.0, 20.0, 30.0, 40.0]);
        let c = dyn_arr(&[4], vec![100.0, 200.0, 300.0, 400.0]);
        let result = column_stack(&[a, b, c]).unwrap();
        assert_eq!(result.shape(), &[4, 3]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![
                1.0, 10.0, 100.0, 2.0, 20.0, 200.0, 3.0, 30.0, 300.0, 4.0, 40.0, 400.0,
            ]
        );
    }

    #[test]
    fn test_column_stack_2d_same_as_hstack() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let result = column_stack(&[a, b]).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]
        );
    }

    #[test]
    fn test_column_stack_length_mismatch() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(column_stack(&[a, b]).is_err());
    }

    #[test]
    fn test_row_stack_is_vstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let row = row_stack(&[a.clone(), b.clone()]).unwrap();
        let v = vstack(&[a, b]).unwrap();
        assert_eq!(row.shape(), v.shape());
        assert_eq!(
            row.iter().copied().collect::<Vec<_>>(),
            v.iter().copied().collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_array_split_n_uneven() {
        let a = dyn_arr(&[7], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let parts = array_split_n(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(
            parts[0].iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0]
        );
        assert_eq!(parts[1].iter().copied().collect::<Vec<_>>(), vec![4.0, 5.0]);
        assert_eq!(parts[2].iter().copied().collect::<Vec<_>>(), vec![6.0, 7.0]);
    }

    #[test]
    fn test_array_split_n_even() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = array_split_n(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        for (i, expected) in [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]
            .iter()
            .enumerate()
        {
            assert_eq!(&parts[i].iter().copied().collect::<Vec<_>>(), expected);
        }
    }

    #[test]
    fn test_array_split_n_more_sections_than_elements() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let parts = array_split_n(&a, 5, 0).unwrap();
        assert_eq!(parts.len(), 5);
        assert_eq!(parts[0].iter().copied().collect::<Vec<_>>(), vec![1.0]);
        assert_eq!(parts[1].iter().copied().collect::<Vec<_>>(), vec![2.0]);
        assert_eq!(parts[2].iter().copied().collect::<Vec<_>>(), vec![3.0]);
        assert_eq!(
            parts[3].iter().copied().collect::<Vec<_>>(),
            Vec::<f64>::new()
        );
        assert_eq!(
            parts[4].iter().copied().collect::<Vec<_>>(),
            Vec::<f64>::new()
        );
    }

    // --- typed <-> dynamic interop -------------------------------------------

    #[test]
    fn test_to_dyn_from_typed() {
        use crate::Array;
        use crate::dimension::Ix2;
        let typed =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let dy = typed.to_dyn();
        assert_eq!(dy.shape(), &[2, 3]);
        assert_eq!(
            dy.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_concatenate_typed_via_to_dyn() {
        use crate::Array;
        use crate::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![5.0, 6.0, 7.0, 8.0]).unwrap();
        let result = concatenate(&[a.to_dyn(), b.to_dyn()], 0).unwrap();
        assert_eq!(result.shape(), &[4, 2]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
    }
}