1pub mod extended;
7
8use crate::array::owned::Array;
9use crate::dimension::{Dimension, Ix1, IxDyn};
10use crate::dtype::Element;
11use crate::error::{FerrayError, FerrayResult};
12
13pub fn reshape<T: Element, D: Dimension>(
27 a: &Array<T, D>,
28 new_shape: &[usize],
29) -> FerrayResult<Array<T, IxDyn>> {
30 let old_size = a.size();
31 let new_size: usize = new_shape.iter().product();
32 if old_size != new_size {
33 return Err(FerrayError::shape_mismatch(format!(
34 "cannot reshape array of size {} into shape {:?} (size {})",
35 old_size, new_shape, new_size,
36 )));
37 }
38 let data: Vec<T> = a.iter().cloned().collect();
39 Array::from_vec(IxDyn::new(new_shape), data)
40}
41
42pub fn ravel<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
46 let data: Vec<T> = a.iter().cloned().collect();
47 let n = data.len();
48 Array::from_vec(Ix1::new([n]), data)
49}
50
51pub fn flatten<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
55 ravel(a)
56}
57
58pub fn squeeze<T: Element, D: Dimension>(
69 a: &Array<T, D>,
70 axis: Option<usize>,
71) -> FerrayResult<Array<T, IxDyn>> {
72 let shape = a.shape();
73 match axis {
74 Some(ax) => {
75 if ax >= shape.len() {
76 return Err(FerrayError::axis_out_of_bounds(ax, shape.len()));
77 }
78 if shape[ax] != 1 {
79 return Err(FerrayError::invalid_value(format!(
80 "cannot select axis {} with size {} for squeeze (must be 1)",
81 ax, shape[ax],
82 )));
83 }
84 let new_shape: Vec<usize> = shape
85 .iter()
86 .enumerate()
87 .filter(|&(i, _)| i != ax)
88 .map(|(_, &s)| s)
89 .collect();
90 let data: Vec<T> = a.iter().cloned().collect();
91 Array::from_vec(IxDyn::new(&new_shape), data)
92 }
93 None => {
94 let new_shape: Vec<usize> = shape.iter().copied().filter(|&s| s != 1).collect();
95 let new_shape = if new_shape.is_empty() && !shape.is_empty() {
98 vec![1]
99 } else if new_shape.is_empty() {
100 vec![]
101 } else {
102 new_shape
103 };
104 let data: Vec<T> = a.iter().cloned().collect();
105 Array::from_vec(IxDyn::new(&new_shape), data)
106 }
107 }
108}
109
110pub fn expand_dims<T: Element, D: Dimension>(
117 a: &Array<T, D>,
118 axis: usize,
119) -> FerrayResult<Array<T, IxDyn>> {
120 let ndim = a.ndim();
121 if axis > ndim {
122 return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
123 }
124 let mut new_shape: Vec<usize> = a.shape().to_vec();
125 new_shape.insert(axis, 1);
126 let data: Vec<T> = a.iter().cloned().collect();
127 Array::from_vec(IxDyn::new(&new_shape), data)
128}
129
130pub fn broadcast_to<T: Element, D: Dimension>(
139 a: &Array<T, D>,
140 new_shape: &[usize],
141) -> FerrayResult<Array<T, IxDyn>> {
142 let src_shape = a.shape();
143 let src_ndim = src_shape.len();
144 let dst_ndim = new_shape.len();
145
146 if dst_ndim < src_ndim {
147 return Err(FerrayError::BroadcastFailure {
148 shape_a: src_shape.to_vec(),
149 shape_b: new_shape.to_vec(),
150 });
151 }
152
153 let pad = dst_ndim - src_ndim;
155 for i in 0..src_ndim {
156 let s = src_shape[i];
157 let d = new_shape[pad + i];
158 if s != d && s != 1 {
159 return Err(FerrayError::BroadcastFailure {
160 shape_a: src_shape.to_vec(),
161 shape_b: new_shape.to_vec(),
162 });
163 }
164 }
165
166 let total: usize = new_shape.iter().product();
168 let mut data = Vec::with_capacity(total);
169 let src_data: Vec<T> = a.iter().cloned().collect();
170
171 let mut src_strides = vec![1usize; src_ndim];
173 for i in (0..src_ndim.saturating_sub(1)).rev() {
174 src_strides[i] = src_strides[i + 1] * src_shape[i + 1];
175 }
176
177 let mut out_strides = vec![1usize; dst_ndim];
179 for i in (0..dst_ndim.saturating_sub(1)).rev() {
180 out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
181 }
182
183 for flat in 0..total {
184 let mut rem = flat;
185 let mut s_idx = 0usize;
186 #[allow(clippy::needless_range_loop)]
187 for i in 0..dst_ndim {
188 let idx = rem / out_strides[i];
189 rem %= out_strides[i];
190 if i >= pad {
191 let src_i = i - pad;
192 let src_idx = if src_shape[src_i] == 1 { 0 } else { idx };
193 s_idx += src_idx * src_strides[src_i];
194 }
195 }
196 data.push(src_data[s_idx].clone());
197 }
198
199 Array::from_vec(IxDyn::new(new_shape), data)
200}
201
202pub fn concatenate<T: Element>(
215 arrays: &[Array<T, IxDyn>],
216 axis: usize,
217) -> FerrayResult<Array<T, IxDyn>> {
218 if arrays.is_empty() {
219 return Err(FerrayError::invalid_value(
220 "concatenate: need at least one array",
221 ));
222 }
223 let ndim = arrays[0].ndim();
224 if axis >= ndim {
225 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
226 }
227 let base_shape = arrays[0].shape();
228
229 let mut total_along_axis = 0usize;
231 for arr in arrays {
232 if arr.ndim() != ndim {
233 return Err(FerrayError::shape_mismatch(format!(
234 "all arrays must have same ndim; got {} and {}",
235 ndim,
236 arr.ndim(),
237 )));
238 }
239 for (i, (&s, &base)) in arr.shape().iter().zip(base_shape.iter()).enumerate() {
240 if i != axis && s != base {
241 return Err(FerrayError::shape_mismatch(format!(
242 "shape mismatch on axis {}: {} vs {}",
243 i, s, base,
244 )));
245 }
246 }
247 total_along_axis += arr.shape()[axis];
248 }
249
250 let mut new_shape = base_shape.to_vec();
252 new_shape[axis] = total_along_axis;
253 let total: usize = new_shape.iter().product();
254 let mut data = Vec::with_capacity(total);
255
256 let mut out_strides = vec![1usize; ndim];
258 for i in (0..ndim - 1).rev() {
259 out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
260 }
261
262 for flat_idx in 0..total {
264 let mut rem = flat_idx;
266 let mut nd_idx = vec![0usize; ndim];
267 for i in 0..ndim {
268 nd_idx[i] = rem / out_strides[i];
269 rem %= out_strides[i];
270 }
271
272 let concat_idx = nd_idx[axis];
274 let mut offset = 0;
275 let mut src_arr_idx = 0;
276 for (k, arr) in arrays.iter().enumerate() {
277 let len_along = arr.shape()[axis];
278 if concat_idx < offset + len_along {
279 src_arr_idx = k;
280 break;
281 }
282 offset += len_along;
283 }
284 let local_concat_idx = concat_idx - offset;
285
286 let src = &arrays[src_arr_idx];
288 let src_shape = src.shape();
289 let mut src_flat = 0usize;
290 let mut src_mul = 1usize;
291 for i in (0..ndim).rev() {
292 let idx = if i == axis {
293 local_concat_idx
294 } else {
295 nd_idx[i]
296 };
297 src_flat += idx * src_mul;
298 src_mul *= src_shape[i];
299 }
300
301 let src_data: &T = src.iter().nth(src_flat).unwrap();
302 data.push(src_data.clone());
303 }
304
305 Array::from_vec(IxDyn::new(&new_shape), data)
306}
307
308pub fn stack<T: Element>(arrays: &[Array<T, IxDyn>], axis: usize) -> FerrayResult<Array<T, IxDyn>> {
320 if arrays.is_empty() {
321 return Err(FerrayError::invalid_value("stack: need at least one array"));
322 }
323 let base_shape = arrays[0].shape();
324 let ndim = base_shape.len();
325
326 if axis > ndim {
327 return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
328 }
329
330 for arr in &arrays[1..] {
331 if arr.shape() != base_shape {
332 return Err(FerrayError::shape_mismatch(format!(
333 "all input arrays must have the same shape; got {:?} and {:?}",
334 base_shape,
335 arr.shape(),
336 )));
337 }
338 }
339
340 let mut expanded = Vec::with_capacity(arrays.len());
342 for arr in arrays {
343 expanded.push(expand_dims(arr, axis)?);
344 }
345 concatenate(&expanded, axis)
346}
347
348pub fn vstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
353 if arrays.is_empty() {
354 return Err(FerrayError::invalid_value(
355 "vstack: need at least one array",
356 ));
357 }
358 let ndim = arrays[0].ndim();
360 if ndim == 1 {
361 let mut reshaped = Vec::with_capacity(arrays.len());
362 for arr in arrays {
363 let n = arr.shape()[0];
364 reshaped.push(reshape(arr, &[1, n])?);
365 }
366 concatenate(&reshaped, 0)
367 } else {
368 concatenate(arrays, 0)
369 }
370}
371
372pub fn hstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
377 if arrays.is_empty() {
378 return Err(FerrayError::invalid_value(
379 "hstack: need at least one array",
380 ));
381 }
382 let ndim = arrays[0].ndim();
383 if ndim == 1 {
384 concatenate(arrays, 0)
385 } else {
386 concatenate(arrays, 1)
387 }
388}
389
390pub fn dstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
398 if arrays.is_empty() {
399 return Err(FerrayError::invalid_value(
400 "dstack: need at least one array",
401 ));
402 }
403 let mut expanded = Vec::with_capacity(arrays.len());
404 for arr in arrays {
405 let shape = arr.shape();
406 match shape.len() {
407 1 => {
408 let n = shape[0];
409 expanded.push(reshape(arr, &[1, n, 1])?);
410 }
411 2 => {
412 let (m, n) = (shape[0], shape[1]);
413 expanded.push(reshape(arr, &[m, n, 1])?);
414 }
415 _ => {
416 let data: Vec<T> = arr.iter().cloned().collect();
418 expanded.push(Array::from_vec(IxDyn::new(shape), data)?);
419 }
420 }
421 }
422 concatenate(&expanded, 2)
423}
424
425pub fn block<T: Element>(blocks: &[Vec<Array<T, IxDyn>>]) -> FerrayResult<Array<T, IxDyn>> {
435 if blocks.is_empty() {
436 return Err(FerrayError::invalid_value("block: empty input"));
437 }
438 let mut rows = Vec::with_capacity(blocks.len());
439 for row in blocks {
440 if row.is_empty() {
441 return Err(FerrayError::invalid_value("block: empty row"));
442 }
443 let row_arr = if row.len() == 1 {
445 let data: Vec<T> = row[0].iter().cloned().collect();
446 Array::from_vec(IxDyn::new(row[0].shape()), data)?
447 } else {
448 hstack(row)?
449 };
450 rows.push(row_arr);
451 }
452 if rows.len() == 1 {
453 Ok(rows.into_iter().next().unwrap())
454 } else {
455 vstack(&rows)
456 }
457}
458
459pub fn split<T: Element>(
468 a: &Array<T, IxDyn>,
469 n_sections: usize,
470 axis: usize,
471) -> FerrayResult<Vec<Array<T, IxDyn>>> {
472 let shape = a.shape();
473 if axis >= shape.len() {
474 return Err(FerrayError::axis_out_of_bounds(axis, shape.len()));
475 }
476 let axis_len = shape[axis];
477 if n_sections == 0 {
478 return Err(FerrayError::invalid_value("split: n_sections must be > 0"));
479 }
480 if axis_len % n_sections != 0 {
481 return Err(FerrayError::invalid_value(format!(
482 "array of size {} along axis {} cannot be evenly split into {} sections",
483 axis_len, axis, n_sections,
484 )));
485 }
486 let chunk_size = axis_len / n_sections;
487 let indices: Vec<usize> = (1..n_sections).map(|i| i * chunk_size).collect();
488 array_split(a, &indices, axis)
489}
490
491pub fn array_split<T: Element>(
500 a: &Array<T, IxDyn>,
501 indices: &[usize],
502 axis: usize,
503) -> FerrayResult<Vec<Array<T, IxDyn>>> {
504 let shape = a.shape();
505 let ndim = shape.len();
506 if axis >= ndim {
507 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
508 }
509 let axis_len = shape[axis];
510 let src_data: Vec<T> = a.iter().cloned().collect();
511
512 let mut splits = Vec::with_capacity(indices.len() + 2);
514 splits.push(0);
515 for &idx in indices {
516 splits.push(idx.min(axis_len));
517 }
518 splits.push(axis_len);
519
520 let mut src_strides = vec![1usize; ndim];
522 for i in (0..ndim - 1).rev() {
523 src_strides[i] = src_strides[i + 1] * shape[i + 1];
524 }
525
526 let mut result = Vec::with_capacity(splits.len() - 1);
527 for w in splits.windows(2) {
528 let start = w[0];
529 let end = w[1];
530 let chunk_len = end - start;
531
532 let mut sub_shape = shape.to_vec();
533 sub_shape[axis] = chunk_len;
534 let sub_total: usize = sub_shape.iter().product();
535
536 let mut sub_strides = vec![1usize; ndim];
538 for i in (0..ndim - 1).rev() {
539 sub_strides[i] = sub_strides[i + 1] * sub_shape[i + 1];
540 }
541
542 let mut sub_data = Vec::with_capacity(sub_total);
543 for flat in 0..sub_total {
544 let mut rem = flat;
546 let mut src_flat = 0usize;
547 for i in 0..ndim {
548 let idx = rem / sub_strides[i];
549 rem %= sub_strides[i];
550 let src_idx = if i == axis { idx + start } else { idx };
551 src_flat += src_idx * src_strides[i];
552 }
553 sub_data.push(src_data[src_flat].clone());
554 }
555 result.push(Array::from_vec(IxDyn::new(&sub_shape), sub_data)?);
556 }
557
558 Ok(result)
559}
560
561pub fn vsplit<T: Element>(
565 a: &Array<T, IxDyn>,
566 n_sections: usize,
567) -> FerrayResult<Vec<Array<T, IxDyn>>> {
568 split(a, n_sections, 0)
569}
570
571pub fn hsplit<T: Element>(
575 a: &Array<T, IxDyn>,
576 n_sections: usize,
577) -> FerrayResult<Vec<Array<T, IxDyn>>> {
578 split(a, n_sections, 1)
579}
580
581pub fn dsplit<T: Element>(
585 a: &Array<T, IxDyn>,
586 n_sections: usize,
587) -> FerrayResult<Vec<Array<T, IxDyn>>> {
588 split(a, n_sections, 2)
589}
590
591pub fn transpose<T: Element, D: Dimension>(
606 a: &Array<T, D>,
607 axes: Option<&[usize]>,
608) -> FerrayResult<Array<T, IxDyn>> {
609 let shape = a.shape();
610 let ndim = shape.len();
611 let perm: Vec<usize> = match axes {
612 Some(ax) => {
613 if ax.len() != ndim {
614 return Err(FerrayError::invalid_value(format!(
615 "axes must have length {} but got {}",
616 ndim,
617 ax.len(),
618 )));
619 }
620 let mut seen = vec![false; ndim];
622 for &a in ax {
623 if a >= ndim {
624 return Err(FerrayError::axis_out_of_bounds(a, ndim));
625 }
626 if seen[a] {
627 return Err(FerrayError::invalid_value(format!(
628 "duplicate axis {} in transpose",
629 a,
630 )));
631 }
632 seen[a] = true;
633 }
634 ax.to_vec()
635 }
636 None => (0..ndim).rev().collect(),
637 };
638
639 let new_shape: Vec<usize> = perm.iter().map(|&ax| shape[ax]).collect();
640 let total: usize = new_shape.iter().product();
641 let src_data: Vec<T> = a.iter().cloned().collect();
642
643 let mut src_strides = vec![1usize; ndim];
645 for i in (0..ndim.saturating_sub(1)).rev() {
646 src_strides[i] = src_strides[i + 1] * shape[i + 1];
647 }
648
649 let mut out_strides = vec![1usize; ndim];
651 for i in (0..ndim.saturating_sub(1)).rev() {
652 out_strides[i] = out_strides[i + 1] * new_shape[i + 1];
653 }
654
655 let mut data = Vec::with_capacity(total);
656 for flat_out in 0..total {
657 let mut rem = flat_out;
659 let mut src_flat = 0usize;
660 #[allow(clippy::needless_range_loop)]
661 for i in 0..ndim {
662 let idx = rem / out_strides[i];
663 rem %= out_strides[i];
664 src_flat += idx * src_strides[perm[i]];
666 }
667 data.push(src_data[src_flat].clone());
668 }
669
670 Array::from_vec(IxDyn::new(&new_shape), data)
671}
672
673pub fn swapaxes<T: Element, D: Dimension>(
680 a: &Array<T, D>,
681 axis1: usize,
682 axis2: usize,
683) -> FerrayResult<Array<T, IxDyn>> {
684 let ndim = a.ndim();
685 if axis1 >= ndim {
686 return Err(FerrayError::axis_out_of_bounds(axis1, ndim));
687 }
688 if axis2 >= ndim {
689 return Err(FerrayError::axis_out_of_bounds(axis2, ndim));
690 }
691 let mut perm: Vec<usize> = (0..ndim).collect();
692 perm.swap(axis1, axis2);
693 transpose(a, Some(&perm))
694}
695
696pub fn moveaxis<T: Element, D: Dimension>(
703 a: &Array<T, D>,
704 source: usize,
705 destination: usize,
706) -> FerrayResult<Array<T, IxDyn>> {
707 let ndim = a.ndim();
708 if source >= ndim {
709 return Err(FerrayError::axis_out_of_bounds(source, ndim));
710 }
711 if destination >= ndim {
712 return Err(FerrayError::axis_out_of_bounds(destination, ndim));
713 }
714 let mut order: Vec<usize> = (0..ndim).filter(|&x| x != source).collect();
716 order.insert(destination, source);
717 transpose(a, Some(&order))
718}
719
720pub fn rollaxis<T: Element, D: Dimension>(
727 a: &Array<T, D>,
728 axis: usize,
729 start: usize,
730) -> FerrayResult<Array<T, IxDyn>> {
731 let ndim = a.ndim();
732 if axis >= ndim {
733 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
734 }
735 if start > ndim {
736 return Err(FerrayError::axis_out_of_bounds(start, ndim + 1));
737 }
738 let dst = if start > axis { start - 1 } else { start };
739 if axis == dst {
740 let data: Vec<T> = a.iter().cloned().collect();
742 return Array::from_vec(IxDyn::new(a.shape()), data);
743 }
744 moveaxis(a, axis, dst)
745}
746
747pub fn flip<T: Element, D: Dimension>(
754 a: &Array<T, D>,
755 axis: usize,
756) -> FerrayResult<Array<T, IxDyn>> {
757 let shape = a.shape();
758 let ndim = shape.len();
759 if axis >= ndim {
760 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
761 }
762 let src_data: Vec<T> = a.iter().cloned().collect();
763 let total = src_data.len();
764
765 let mut strides = vec![1usize; ndim];
767 for i in (0..ndim.saturating_sub(1)).rev() {
768 strides[i] = strides[i + 1] * shape[i + 1];
769 }
770
771 let mut data = Vec::with_capacity(total);
772 for flat in 0..total {
773 let mut rem = flat;
774 let mut src_flat = 0usize;
775 for i in 0..ndim {
776 let idx = rem / strides[i];
777 rem %= strides[i];
778 let src_idx = if i == axis { shape[i] - 1 - idx } else { idx };
779 src_flat += src_idx * strides[i];
780 }
781 data.push(src_data[src_flat].clone());
782 }
783 Array::from_vec(IxDyn::new(shape), data)
784}
785
786pub fn fliplr<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
793 if a.ndim() < 2 {
794 return Err(FerrayError::invalid_value(
795 "fliplr: array must be at least 2-D",
796 ));
797 }
798 flip(a, 1)
799}
800
801pub fn flipud<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
808 if a.ndim() < 1 {
809 return Err(FerrayError::invalid_value(
810 "flipud: array must be at least 1-D",
811 ));
812 }
813 flip(a, 0)
814}
815
816pub fn rot90<T: Element, D: Dimension>(a: &Array<T, D>, k: i32) -> FerrayResult<Array<T, IxDyn>> {
825 if a.ndim() < 2 {
826 return Err(FerrayError::invalid_value(
827 "rot90: array must be at least 2-D",
828 ));
829 }
830 let k = k.rem_euclid(4);
832 let shape = a.shape();
833 let data: Vec<T> = a.iter().cloned().collect();
834
835 let as_dyn = Array::from_vec(IxDyn::new(shape), data)?;
837
838 match k {
839 0 => Ok(as_dyn),
840 1 => {
841 let flipped = flip(&as_dyn, 1)?;
843 swapaxes(&flipped, 0, 1)
844 }
845 2 => {
846 let f1 = flip(&as_dyn, 0)?;
848 flip(&f1, 1)
849 }
850 3 => {
851 let transposed = swapaxes(&as_dyn, 0, 1)?;
853 flip(&transposed, 1)
854 }
855 _ => unreachable!(),
856 }
857}
858
859pub fn roll<T: Element, D: Dimension>(
869 a: &Array<T, D>,
870 shift: isize,
871 axis: Option<usize>,
872) -> FerrayResult<Array<T, IxDyn>> {
873 match axis {
874 None => {
875 let data: Vec<T> = a.iter().cloned().collect();
877 let n = data.len();
878 if n == 0 {
879 return Array::from_vec(IxDyn::new(a.shape()), data);
880 }
881 let shift = ((shift % n as isize) + n as isize) as usize % n;
882 let mut rolled = Vec::with_capacity(n);
883 for i in 0..n {
884 rolled.push(data[(n + i - shift) % n].clone());
885 }
886 Array::from_vec(IxDyn::new(a.shape()), rolled)
887 }
888 Some(ax) => {
889 let shape = a.shape();
890 let ndim = shape.len();
891 if ax >= ndim {
892 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
893 }
894 let axis_len = shape[ax];
895 if axis_len == 0 {
896 let data: Vec<T> = a.iter().cloned().collect();
897 return Array::from_vec(IxDyn::new(shape), data);
898 }
899 let shift = ((shift % axis_len as isize) + axis_len as isize) as usize % axis_len;
900 let src_data: Vec<T> = a.iter().cloned().collect();
901 let total = src_data.len();
902
903 let mut strides = vec![1usize; ndim];
905 for i in (0..ndim.saturating_sub(1)).rev() {
906 strides[i] = strides[i + 1] * shape[i + 1];
907 }
908
909 let mut data = Vec::with_capacity(total);
910 for flat in 0..total {
911 let mut rem = flat;
912 let mut src_flat = 0usize;
913 #[allow(clippy::needless_range_loop)]
914 for i in 0..ndim {
915 let idx = rem / strides[i];
916 rem %= strides[i];
917 let src_idx = if i == ax {
918 (axis_len + idx - shift) % axis_len
919 } else {
920 idx
921 };
922 src_flat += src_idx * strides[i];
923 }
924 data.push(src_data[src_flat].clone());
925 }
926 Array::from_vec(IxDyn::new(shape), data)
927 }
928 }
929}
930
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a dynamically-shaped `f64` array; panics if `data` does not match `shape`.
    fn dyn_arr(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(shape), data).unwrap()
    }

    /// Collect an array's elements in iteration order for value comparisons.
    fn values(a: &Array<f64, IxDyn>) -> Vec<f64> {
        a.iter().copied().collect()
    }

    // --- reshape / ravel / flatten ---

    #[test]
    fn test_reshape() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = reshape(&a, &[3, 2]).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(values(&b), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_reshape_size_mismatch() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(reshape(&a, &[2, 4]).is_err());
    }

    #[test]
    fn test_ravel() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = ravel(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
        assert_eq!(b.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_flatten() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flatten(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // --- squeeze / expand_dims ---

    #[test]
    fn test_squeeze() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, None).unwrap();
        assert_eq!(b.shape(), &[3]);
        assert_eq!(values(&b), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_squeeze_specific_axis() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, Some(0)).unwrap();
        assert_eq!(b.shape(), &[3, 1]);
    }

    #[test]
    fn test_squeeze_all_unit_axes_keeps_1d() {
        let a = dyn_arr(&[1, 1], vec![7.0]);
        let b = squeeze(&a, None).unwrap();
        assert_eq!(b.shape(), &[1]);
        assert_eq!(values(&b), vec![7.0]);
    }

    #[test]
    fn test_squeeze_not_size_1() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(squeeze(&a, Some(0)).is_err());
    }

    #[test]
    fn test_squeeze_axis_oob() {
        let a = dyn_arr(&[1, 3], vec![1.0; 3]);
        assert!(squeeze(&a, Some(2)).is_err());
    }

    #[test]
    fn test_expand_dims() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = expand_dims(&a, 0).unwrap();
        assert_eq!(b.shape(), &[1, 3]);
        let c = expand_dims(&a, 1).unwrap();
        assert_eq!(c.shape(), &[3, 1]);
        assert_eq!(values(&c), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_expand_dims_oob() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(expand_dims(&a, 3).is_err());
    }

    // --- broadcast_to ---

    #[test]
    fn test_broadcast_to() {
        let a = dyn_arr(&[1, 3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[3, 3]).unwrap();
        assert_eq!(b.shape(), &[3, 3]);
        assert_eq!(
            values(&b),
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn test_broadcast_to_1d_to_2d() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[2, 3]).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        assert_eq!(values(&b), vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast_to_column() {
        let a = dyn_arr(&[2, 1], vec![1.0, 2.0]);
        let b = broadcast_to(&a, &[2, 3]).unwrap();
        assert_eq!(values(&b), vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
    }

    #[test]
    fn test_broadcast_to_incompatible() {
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(broadcast_to(&a, &[3]).is_err());
    }

    // --- concatenate / stack family ---

    #[test]
    fn test_concatenate() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = dyn_arr(&[2, 3], vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[4, 3]);
        assert_eq!(
            values(&c),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]
        );
    }

    #[test]
    fn test_concatenate_axis1() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = concatenate(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
        assert_eq!(
            values(&c),
            vec![1.0, 2.0, 5.0, 6.0, 7.0, 3.0, 4.0, 8.0, 9.0, 10.0]
        );
    }

    #[test]
    fn test_concatenate_unequal_axis_lengths() {
        // Lengths may differ along the concatenation axis itself.
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        let b = dyn_arr(&[3, 3], vec![1.0; 9]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[5, 3]);
        assert_eq!(values(&c), vec![1.0; 15]);
    }

    #[test]
    fn test_concatenate_off_axis_mismatch() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        let b = dyn_arr(&[2, 2], vec![1.0; 4]);
        assert!(concatenate(&[a, b], 0).is_err());
    }

    #[test]
    fn test_concatenate_empty() {
        let v: Vec<Array<f64, IxDyn>> = vec![];
        assert!(concatenate(&v, 0).is_err());
    }

    #[test]
    fn test_stack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        assert_eq!(values(&c), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_axis1() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        assert_eq!(values(&c), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_vstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = vstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        assert_eq!(values(&c), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_hstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[6]);
        assert_eq!(values(&c), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_hstack_2d() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
        assert_eq!(
            values(&c),
            vec![1.0, 2.0, 5.0, 6.0, 7.0, 3.0, 4.0, 8.0, 9.0, 10.0]
        );
    }

    #[test]
    fn test_dstack() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 2, 2]);
        assert_eq!(
            values(&c),
            vec![1.0, 5.0, 2.0, 6.0, 3.0, 7.0, 4.0, 8.0]
        );
    }

    #[test]
    fn test_block() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dyn_arr(&[2, 2], vec![9.0, 10.0, 11.0, 12.0]);
        let d = dyn_arr(&[2, 2], vec![13.0, 14.0, 15.0, 16.0]);
        let result = block(&[vec![a, b], vec![c, d]]).unwrap();
        assert_eq!(result.shape(), &[4, 4]);
        assert_eq!(
            values(&result),
            vec![
                1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0, 9.0, 10.0, 13.0, 14.0, 11.0, 12.0,
                15.0, 16.0,
            ]
        );
    }

    // --- split family ---

    #[test]
    fn test_split() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = split(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[2]);
        assert_eq!(values(&parts[0]), vec![1.0, 2.0]);
        assert_eq!(values(&parts[1]), vec![3.0, 4.0]);
        assert_eq!(values(&parts[2]), vec![5.0, 6.0]);
    }

    #[test]
    fn test_split_uneven() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(split(&a, 3, 0).is_err());
    }

    #[test]
    fn test_split_zero_sections() {
        let a = dyn_arr(&[4], vec![1.0; 4]);
        assert!(split(&a, 0, 0).is_err());
    }

    #[test]
    fn test_array_split() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let parts = array_split(&a, &[2, 4], 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[1]);
        assert_eq!(values(&parts[0]), vec![1.0, 2.0]);
        assert_eq!(values(&parts[1]), vec![3.0, 4.0]);
        assert_eq!(values(&parts[2]), vec![5.0]);
    }

    #[test]
    fn test_vsplit() {
        let a = dyn_arr(&[4, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = vsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
        assert_eq!(values(&parts[0]), vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(values(&parts[1]), vec![5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_hsplit() {
        let a = dyn_arr(&[2, 4], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = hsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
        assert_eq!(values(&parts[0]), vec![1.0, 2.0, 5.0, 6.0]);
        assert_eq!(values(&parts[1]), vec![3.0, 4.0, 7.0, 8.0]);
    }

    // --- axis permutation ---

    #[test]
    fn test_transpose_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, None).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(values(&b), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_explicit() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, Some(&[1, 0])).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(values(&b), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_bad_axes() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(transpose(&a, Some(&[0])).is_err());
        assert!(transpose(&a, Some(&[0, 0])).is_err());
        assert!(transpose(&a, Some(&[0, 2])).is_err());
    }

    #[test]
    fn test_swapaxes() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = swapaxes(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[4, 3, 2]);
    }

    #[test]
    fn test_swapaxes_values() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = swapaxes(&a, 0, 1).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(values(&b), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_moveaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = moveaxis(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[3, 4, 2]);
    }

    #[test]
    fn test_rollaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = rollaxis(&a, 2, 0).unwrap();
        assert_eq!(b.shape(), &[4, 2, 3]);
    }

    // --- flips and rotations ---

    #[test]
    fn test_flip() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = flip(&a, 0).unwrap();
        assert_eq!(values(&b), vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flip(&a, 0).unwrap();
        assert_eq!(values(&b), vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);

        let c = flip(&a, 1).unwrap();
        assert_eq!(values(&c), vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_fliplr() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = fliplr(&a).unwrap();
        assert_eq!(values(&b), vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_flipud() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flipud(&a).unwrap();
        assert_eq!(values(&b), vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_fliplr_1d_err() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(fliplr(&a).is_err());
    }

    #[test]
    fn test_rot90_once() {
        // [[1, 2], [3, 4]] rotated counter-clockwise is [[2, 4], [1, 3]].
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 1).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        assert_eq!(values(&b), vec![2.0, 4.0, 1.0, 3.0]);
    }

    #[test]
    fn test_rot90_twice() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 2).unwrap();
        assert_eq!(values(&b), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_rot90_three_times_and_negative() {
        // Three counter-clockwise turns equal one clockwise turn: [[3, 1], [4, 2]].
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 3).unwrap();
        assert_eq!(values(&b), vec![3.0, 1.0, 4.0, 2.0]);
        let c = rot90(&a, -1).unwrap();
        assert_eq!(values(&c), vec![3.0, 1.0, 4.0, 2.0]);
    }

    #[test]
    fn test_rot90_four_is_identity() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = rot90(&a, 4).unwrap();
        assert_eq!(values(&a), values(&b));
        assert_eq!(a.shape(), b.shape());
    }

    // --- roll ---

    #[test]
    fn test_roll_flat() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, 2, None).unwrap();
        assert_eq!(values(&b), vec![4.0, 5.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, -2, None).unwrap();
        assert_eq!(values(&b), vec![3.0, 4.0, 5.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll_axis() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, 1, Some(1)).unwrap();
        assert_eq!(values(&b), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
    }

    #[test]
    fn test_roll_axis0_and_negative_axis_shift() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, 1, Some(0)).unwrap();
        assert_eq!(values(&b), vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
        let c = roll(&a, -1, Some(1)).unwrap();
        assert_eq!(values(&c), vec![2.0, 3.0, 1.0, 5.0, 6.0, 4.0]);
    }
}
1309}