1use super::normalize_index;
13use crate::array::owned::Array;
14use crate::dimension::{Axis, Dimension, Ix2, IxDyn};
15use crate::dtype::Element;
16use crate::error::{FerrayError, FerrayResult};
17
18pub fn take<T: Element, D: Dimension>(
30 a: &Array<T, D>,
31 indices: &[isize],
32 axis: Axis,
33) -> FerrayResult<Array<T, IxDyn>> {
34 a.index_select(axis, indices)
35}
36
37pub fn take_along_axis<T: Element, D: Dimension>(
47 a: &Array<T, D>,
48 indices: &[isize],
49 axis: Axis,
50) -> FerrayResult<Array<T, IxDyn>> {
51 a.index_select(axis, indices)
52}
53
54impl<T: Element, D: Dimension> Array<T, D> {
59 pub fn put(&mut self, indices: &[isize], values: &[T]) -> FerrayResult<()> {
69 if values.is_empty() {
70 return Err(FerrayError::invalid_value("values must not be empty"));
71 }
72 let size = self.size();
73 let normalized: Vec<usize> = indices
74 .iter()
75 .map(|&idx| normalize_index(idx, size, 0))
76 .collect::<FerrayResult<Vec<_>>>()?;
77
78 let mut flat: Vec<&mut T> = self.inner.iter_mut().collect();
79
80 for (i, &idx) in normalized.iter().enumerate() {
81 let val_idx = i % values.len();
82 *flat[idx] = values[val_idx].clone();
83 }
84 Ok(())
85 }
86
87 pub fn put_along_axis(
96 &mut self,
97 indices: &[isize],
98 values: &Array<T, IxDyn>,
99 axis: Axis,
100 ) -> FerrayResult<()>
101 where
102 D::NdarrayDim: ndarray::RemoveAxis,
103 {
104 let ndim = self.ndim();
105 let ax = axis.index();
106 if ax >= ndim {
107 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
108 }
109 let axis_size = self.shape()[ax];
110
111 let normalized: Vec<usize> = indices
112 .iter()
113 .map(|&idx| normalize_index(idx, axis_size, ax))
114 .collect::<FerrayResult<Vec<_>>>()?;
115
116 let nd_axis = ndarray::Axis(ax);
117 let mut val_iter = values.inner.iter();
118
119 for &idx in &normalized {
120 let mut sub = self.inner.index_axis_mut(nd_axis, idx);
121 for elem in &mut sub {
122 if let Some(v) = val_iter.next() {
123 *elem = v.clone();
124 }
125 }
126 }
127 Ok(())
128 }
129
130 pub fn fill_diagonal(&mut self, val: T) {
137 let shape = self.shape().to_vec();
138 if shape.is_empty() {
139 return;
140 }
141 let min_dim = *shape.iter().min().unwrap_or(&0);
142 let ndim = shape.len();
143
144 for i in 0..min_dim {
145 let idx: Vec<usize> = vec![i; ndim];
146 let nd_idx = ndarray::IxDyn(&idx);
147 let mut dyn_view = self.inner.view_mut().into_dyn();
148 dyn_view[nd_idx] = val.clone();
149 }
150 }
151}
152
153pub fn choose<T: Element, D: Dimension>(
167 index_arr: &Array<u64, D>,
168 choices: &[Array<T, D>],
169) -> FerrayResult<Array<T, IxDyn>> {
170 if choices.is_empty() {
171 return Err(FerrayError::invalid_value("choices must not be empty"));
172 }
173
174 let shape = index_arr.shape();
175 for (i, c) in choices.iter().enumerate() {
176 if c.shape() != shape {
177 return Err(FerrayError::shape_mismatch(format!(
178 "choice[{}] shape {:?} does not match index array shape {:?}",
179 i,
180 c.shape(),
181 shape
182 )));
183 }
184 }
185
186 let n_choices = choices.len();
187 let choice_iters: Vec<Vec<T>> = choices
188 .iter()
189 .map(|c| c.inner.iter().cloned().collect())
190 .collect();
191
192 let mut data = Vec::with_capacity(index_arr.size());
193 for (pos, idx_val) in index_arr.inner.iter().enumerate() {
194 let idx = *idx_val as usize;
195 if idx >= n_choices {
196 return Err(FerrayError::index_out_of_bounds(idx as isize, 0, n_choices));
197 }
198 data.push(choice_iters[idx][pos].clone());
199 }
200
201 let dyn_shape = IxDyn::new(shape);
202 Array::from_vec(dyn_shape, data)
203}
204
205pub fn compress<T: Element, D: Dimension>(
217 condition: &[bool],
218 a: &Array<T, D>,
219 axis: Axis,
220) -> FerrayResult<Array<T, IxDyn>> {
221 let ndim = a.ndim();
222 let ax = axis.index();
223 if ax >= ndim {
224 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
225 }
226 let axis_size = a.shape()[ax];
227 if condition.len() > axis_size {
228 return Err(FerrayError::shape_mismatch(format!(
229 "condition length {} exceeds axis size {}",
230 condition.len(),
231 axis_size
232 )));
233 }
234
235 let indices: Vec<isize> = condition
236 .iter()
237 .enumerate()
238 .filter_map(|(i, &c)| if c { Some(i as isize) } else { None })
239 .collect();
240
241 a.index_select(axis, &indices)
242}
243
244pub fn select<T: Element, D: Dimension>(
257 condlist: &[Array<bool, D>],
258 choicelist: &[Array<T, D>],
259 default: T,
260) -> FerrayResult<Array<T, IxDyn>> {
261 if condlist.len() != choicelist.len() {
262 return Err(FerrayError::invalid_value(format!(
263 "condlist length {} != choicelist length {}",
264 condlist.len(),
265 choicelist.len()
266 )));
267 }
268 if condlist.is_empty() {
269 return Err(FerrayError::invalid_value(
270 "condlist and choicelist must not be empty",
271 ));
272 }
273
274 let shape = condlist[0].shape();
275 for (i, (c, ch)) in condlist.iter().zip(choicelist.iter()).enumerate() {
276 if c.shape() != shape || ch.shape() != shape {
277 return Err(FerrayError::shape_mismatch(format!(
278 "condlist[{i}]/choicelist[{i}] shape mismatch with reference shape {shape:?}"
279 )));
280 }
281 }
282
283 let size = condlist[0].size();
284 let mut data = vec![default; size];
285
286 for (cond, choice) in condlist.iter().zip(choicelist.iter()).rev() {
288 for (i, (&c, v)) in cond.inner.iter().zip(choice.inner.iter()).enumerate() {
289 if c {
290 data[i] = v.clone();
291 }
292 }
293 }
294
295 let dyn_shape = IxDyn::new(shape);
296 Array::from_vec(dyn_shape, data)
297}
298
299pub fn indices(dimensions: &[usize]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
311 let ndim = dimensions.len();
312 let total: usize = dimensions.iter().product();
313
314 let mut result = Vec::with_capacity(ndim);
315
316 for ax in 0..ndim {
317 let mut data = Vec::with_capacity(total);
318 for flat_idx in 0..total {
319 let mut rem = flat_idx;
320 let mut idx_for_ax = 0;
321 for (d, &dim_size) in dimensions.iter().enumerate().rev() {
322 let coord = rem % dim_size;
323 rem /= dim_size;
324 if d == ax {
325 idx_for_ax = coord;
326 }
327 }
328 data.push(idx_for_ax as u64);
329 }
330 let dim = IxDyn::new(dimensions);
331 result.push(Array::from_vec(dim, data)?);
332 }
333
334 Ok(result)
335}
336
337pub fn ix_(sequences: &[&[u64]]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
349 let ndim = sequences.len();
350 let mut result = Vec::with_capacity(ndim);
351
352 for (i, seq) in sequences.iter().enumerate() {
353 let mut shape = vec![1usize; ndim];
354 shape[i] = seq.len();
355
356 let data = seq.to_vec();
357 let dim = IxDyn::new(&shape);
358 result.push(Array::from_vec(dim, data)?);
359 }
360
361 Ok(result)
362}
363
/// Indices of the main diagonal of an `n`-wide, `ndim`-dimensional array.
///
/// Returns `ndim` identical vectors, each equal to `0..n`. Zipping them
/// element-wise yields the coordinates `(i, i, ..., i)`.
#[must_use]
pub fn diag_indices(n: usize, ndim: usize) -> Vec<Vec<usize>> {
    (0..ndim).map(|_| (0..n).collect()).collect()
}
377
378pub fn diag_indices_from<T: Element, D: Dimension>(
386 a: &Array<T, D>,
387) -> FerrayResult<Vec<Vec<usize>>> {
388 let ndim = a.ndim();
389 if ndim < 2 {
390 return Err(FerrayError::invalid_value(
391 "diag_indices_from requires at least 2 dimensions",
392 ));
393 }
394 let shape = a.shape();
395 let n = shape[0];
396 for &s in &shape[1..] {
397 if s != n {
398 return Err(FerrayError::shape_mismatch(format!(
399 "all dimensions must be equal for diag_indices_from, got {shape:?}"
400 )));
401 }
402 }
403 Ok(diag_indices(n, ndim))
404}
405
/// Row/column indices of the lower triangle of an `n x m` matrix.
///
/// An entry `(i, j)` is included when `j <= i + k`. `k = 0` is the main
/// diagonal, positive `k` reaches above it and negative `k` stays below it.
/// `m` defaults to `n`. Indices come out in row-major order.
///
/// The bound is computed per row with saturating arithmetic, so extreme `k`
/// values (e.g. `isize::MAX`/`isize::MIN`) give the mathematically correct
/// result. Previously they could overflow: a panic in debug builds and a
/// silently truncated triangle in release builds.
#[must_use]
pub fn tril_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // Exclusive end of the admissible columns: j <= i + k  <=>  j < i + k + 1.
        let end = (i as isize).saturating_add(k).saturating_add(1);
        let end = (end.max(0) as usize).min(m);
        rows.extend(std::iter::repeat(i).take(end));
        cols.extend(0..end);
    }

    (rows, cols)
}
432
/// Row/column indices of the upper triangle of an `n x m` matrix.
///
/// An entry `(i, j)` is included when `j >= i + k`. `k = 0` is the main
/// diagonal, positive `k` moves above it and negative `k` reaches below it.
/// `m` defaults to `n`. Indices come out in row-major order.
///
/// The start column is computed per row with saturating arithmetic, so
/// extreme `k` values give the correct result. Previously `i + k` could
/// overflow: a panic in debug builds, and in release builds every column was
/// wrongly included.
#[must_use]
pub fn triu_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // First admissible column: j >= i + k, clamped into 0..=m.
        let start = (i as isize).saturating_add(k);
        let start = (start.max(0) as usize).min(m);
        rows.extend(std::iter::repeat(i).take(m - start));
        cols.extend(start..m);
    }

    (rows, cols)
}
453
454pub fn tril_indices_from<T: Element, D: Dimension>(
459 a: &Array<T, D>,
460 k: isize,
461) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
462 let shape = a.shape();
463 if shape.len() != 2 {
464 return Err(FerrayError::invalid_value(
465 "tril_indices_from requires a 2-D array",
466 ));
467 }
468 Ok(tril_indices(shape[0], k, Some(shape[1])))
469}
470
471pub fn triu_indices_from<T: Element, D: Dimension>(
476 a: &Array<T, D>,
477 k: isize,
478) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
479 let shape = a.shape();
480 if shape.len() != 2 {
481 return Err(FerrayError::invalid_value(
482 "triu_indices_from requires a 2-D array",
483 ));
484 }
485 Ok(triu_indices(shape[0], k, Some(shape[1])))
486}
487
488#[allow(clippy::needless_range_loop)]
501pub fn ravel_multi_index(multi_index: &[&[usize]], dims: &[usize]) -> FerrayResult<Vec<usize>> {
502 if multi_index.len() != dims.len() {
503 return Err(FerrayError::invalid_value(format!(
504 "multi_index has {} components but dims has {} dimensions",
505 multi_index.len(),
506 dims.len()
507 )));
508 }
509 if multi_index.is_empty() {
510 return Ok(vec![]);
511 }
512
513 let n = multi_index[0].len();
514 for (i, idx_arr) in multi_index.iter().enumerate() {
515 if idx_arr.len() != n {
516 return Err(FerrayError::invalid_value(format!(
517 "multi_index[{}] has length {} but expected {}",
518 i,
519 idx_arr.len(),
520 n
521 )));
522 }
523 }
524
525 let ndim = dims.len();
527 let mut strides = vec![1usize; ndim];
528 for i in (0..ndim - 1).rev() {
529 strides[i] = strides[i + 1] * dims[i + 1];
530 }
531
532 let mut flat = Vec::with_capacity(n);
533 #[allow(clippy::needless_range_loop)]
534 for pos in 0..n {
535 let mut linear = 0usize;
536 for (d, &dim_size) in dims.iter().enumerate() {
537 let coord = multi_index[d][pos];
538 if coord >= dim_size {
539 return Err(FerrayError::index_out_of_bounds(
540 coord as isize,
541 d,
542 dim_size,
543 ));
544 }
545 linear += coord * strides[d];
546 }
547 flat.push(linear);
548 }
549
550 Ok(flat)
551}
552
553pub fn unravel_index(flat_indices: &[usize], shape: &[usize]) -> FerrayResult<Vec<Vec<usize>>> {
561 let total: usize = shape.iter().product();
562 let ndim = shape.len();
563 let n = flat_indices.len();
564
565 let mut result: Vec<Vec<usize>> = vec![Vec::with_capacity(n); ndim];
566
567 for &flat_idx in flat_indices {
568 if flat_idx >= total {
569 return Err(FerrayError::index_out_of_bounds(
570 flat_idx as isize,
571 0,
572 total,
573 ));
574 }
575 let mut rem = flat_idx;
576 for (d, &dim_size) in shape.iter().enumerate().rev() {
577 result[d].push(rem % dim_size);
578 rem /= dim_size;
579 }
580 }
581
582 Ok(result)
583}
584
585pub fn flatnonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<usize> {
594 let zero = T::zero();
595 a.inner
596 .iter()
597 .enumerate()
598 .filter_map(|(i, val)| if *val == zero { None } else { Some(i) })
599 .collect()
600}
601
602pub fn nonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<Vec<usize>> {
622 let zero = T::zero();
623 let ndim = a.ndim();
624 let mut result: Vec<Vec<usize>> = vec![Vec::new(); ndim];
625 for (idx, val) in a.indexed_iter() {
626 if *val != zero {
627 for (d, &c) in idx.iter().enumerate() {
628 result[d].push(c);
629 }
630 }
631 }
632 result
633}
634
635pub fn argwhere<T: Element + PartialEq, D: Dimension>(
648 a: &Array<T, D>,
649) -> FerrayResult<Array<i64, Ix2>> {
650 let zero = T::zero();
651 let ndim = a.ndim();
652 let mut data: Vec<i64> = Vec::new();
653 let mut count: usize = 0;
654 for (idx, val) in a.indexed_iter() {
655 if *val != zero {
656 for &c in &idx {
657 data.push(c as i64);
658 }
659 count += 1;
660 }
661 }
662 Array::<i64, Ix2>::from_vec(Ix2::new([count, ndim]), data)
663}
664
/// Iterator over every multi-dimensional index of a shape, in row-major
/// (C) order: the last axis varies fastest.
///
/// A shape containing a zero-length axis yields nothing. The empty shape
/// `[]` yields exactly one empty index.
#[derive(Debug, Clone)]
pub struct NdIndex {
    /// Extent of each axis.
    shape: Vec<usize>,
    /// The index that the next call to `next` will return.
    current: Vec<usize>,
    /// Set once every index has been produced (or immediately for an empty shape).
    done: bool,
}

impl NdIndex {
    /// Start iterating at the all-zeros index of `shape`.
    fn new(shape: &[usize]) -> Self {
        let done = shape.contains(&0);
        Self {
            shape: shape.to_vec(),
            current: vec![0; shape.len()],
            done,
        }
    }
}

impl Iterator for NdIndex {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        let result = self.current.clone();

        // Odometer-style increment: bump the last axis and propagate carries
        // towards axis 0, stopping as soon as an axis does not wrap. If the
        // carry falls off the front, every index has been produced.
        let mut carry = true;
        for (c, &extent) in self.current.iter_mut().zip(&self.shape).rev() {
            *c += 1;
            if *c < extent {
                carry = false;
                break;
            }
            *c = 0;
        }
        if carry {
            self.done = true;
        }

        Some(result)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.done {
            return (0, Some(0));
        }
        // If the total count does not fit in `usize`, the exact remainder is
        // not representable. Report an unbounded hint instead of overflowing
        // (which used to panic in debug builds, e.g. inside `collect`).
        match self.shape.iter().try_fold(1usize, |acc, &s| acc.checked_mul(s)) {
            None => (0, None),
            Some(total) => {
                // Row-major linear position of `current`, which equals the
                // number of indices already yielded. Every partial stride
                // divides `total`, so this arithmetic cannot overflow.
                let mut yielded = 0usize;
                let mut stride = 1usize;
                for (&c, &extent) in self.current.iter().zip(&self.shape).rev() {
                    yielded += c * stride;
                    stride *= extent;
                }
                let remaining = total - yielded;
                (remaining, Some(remaining))
            }
        }
    }
}

// Once `done` is set, `next` keeps returning `None`.
impl std::iter::FusedIterator for NdIndex {}
736
/// Return an iterator over every index of `shape` in row-major order.
///
/// Each item is a `Vec<usize>` with one coordinate per axis, and the last
/// axis varies fastest. A shape with any zero-length axis yields no items.
/// The empty shape `[]` yields a single empty index.
#[must_use]
pub fn ndindex(shape: &[usize]) -> NdIndex {
    NdIndex::new(shape)
}
744
745pub fn ndenumerate<T: Element, D: Dimension>(
749 a: &Array<T, D>,
750) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
751 let shape = a.shape().to_vec();
752 let ndim = shape.len();
753 a.inner.iter().enumerate().map(move |(flat_idx, val)| {
754 let mut idx = vec![0usize; ndim];
755 let mut rem = flat_idx;
756 for (d, s) in shape.iter().enumerate().rev() {
757 if *s > 0 {
758 idx[d] = rem % s;
759 rem /= s;
760 }
761 }
762 (idx, val)
763 })
764}
765
766pub fn where_select<T: Element + Copy, D: Dimension>(
782 condition: &Array<bool, D>,
783 x: &Array<T, D>,
784 y: &Array<T, D>,
785) -> FerrayResult<Array<T, D>> {
786 if condition.shape() != x.shape() || condition.shape() != y.shape() {
787 return Err(FerrayError::shape_mismatch(format!(
788 "where_select: condition shape {:?}, x shape {:?}, y shape {:?} must all match",
789 condition.shape(),
790 x.shape(),
791 y.shape()
792 )));
793 }
794 let data: Vec<T> = condition
795 .iter()
796 .zip(x.iter().zip(y.iter()))
797 .map(|(&c, (&xi, &yi))| if c { xi } else { yi })
798 .collect();
799 Array::from_vec(x.dim().clone(), data)
800}
801
802pub fn place<T: Element + Copy, D: Dimension>(
818 a: &mut Array<T, D>,
819 mask: &Array<bool, D>,
820 vals: &[T],
821) -> FerrayResult<()> {
822 if a.shape() != mask.shape() {
823 return Err(FerrayError::shape_mismatch(format!(
824 "place: mask shape {:?} differs from array shape {:?}",
825 mask.shape(),
826 a.shape(),
827 )));
828 }
829 let hits: usize = mask.iter().filter(|&&m| m).count();
830 if hits > 0 && vals.is_empty() {
831 return Err(FerrayError::invalid_value(
832 "place: vals must be non-empty when mask has any true entries",
833 ));
834 }
835 let mut vi = 0usize;
836 for (slot, &m) in a.inner.iter_mut().zip(mask.iter()) {
837 if m {
838 *slot = vals[vi % vals.len()];
839 vi += 1;
840 }
841 }
842 Ok(())
843}
844
845pub fn putmask<T: Element + Copy, D: Dimension>(
858 a: &mut Array<T, D>,
859 mask: &Array<bool, D>,
860 values: &[T],
861) -> FerrayResult<()> {
862 if a.shape() != mask.shape() {
863 return Err(FerrayError::shape_mismatch(format!(
864 "putmask: mask shape {:?} differs from array shape {:?}",
865 mask.shape(),
866 a.shape(),
867 )));
868 }
869 let n = a.size();
870 let scalar_mode = values.len() == 1;
871 if !scalar_mode && values.len() != n {
872 return Err(FerrayError::shape_mismatch(format!(
873 "putmask: values length {} must be 1 or equal to array size {}",
874 values.len(),
875 n,
876 )));
877 }
878 for (i, (slot, &m)) in a.inner.iter_mut().zip(mask.iter()).enumerate() {
879 if m {
880 *slot = if scalar_mode { values[0] } else { values[i] };
881 }
882 }
883 Ok(())
884}
885
886pub fn extract<T: Element + Copy, D: Dimension>(
895 condition: &Array<bool, D>,
896 a: &Array<T, D>,
897) -> FerrayResult<Array<T, crate::dimension::Ix1>> {
898 if condition.shape() != a.shape() {
899 return Err(FerrayError::shape_mismatch(format!(
900 "extract: condition shape {:?} differs from array shape {:?}",
901 condition.shape(),
902 a.shape(),
903 )));
904 }
905 let data: Vec<T> = condition
906 .iter()
907 .zip(a.iter())
908 .filter_map(|(&c, &v)| if c { Some(v) } else { None })
909 .collect();
910 let n = data.len();
911 Array::from_vec(crate::dimension::Ix1::new([n]), data)
912}
913
/// Which region of an `n x n` matrix [`mask_indices`] selects, relative to
/// the diagonal offset `k` (row `i`, column `j`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaskKind {
    /// Lower triangle: entries with `j <= i + k`.
    Tril,
    /// Upper triangle: entries with `j >= i + k`.
    Triu,
    /// A single diagonal: entries with `j == i + k`.
    Diag,
}
924
925pub fn mask_indices(n: usize, kind: MaskKind, k: isize) -> Vec<usize> {
933 let mut idx = Vec::new();
934 for i in 0..n {
935 for j in 0..n {
936 let select = match kind {
937 MaskKind::Tril => (j as isize) <= (i as isize) + k,
938 MaskKind::Triu => (j as isize) >= (i as isize) + k,
939 MaskKind::Diag => (j as isize) == (i as isize) + k,
940 };
941 if select {
942 idx.push(i * n + j);
943 }
944 }
945 }
946 idx
947}
948
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    // --- take / take_along_axis ---

    #[test]
    fn take_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let taken = take(&arr, &[0, 2, 4], Axis(0)).unwrap();
        assert_eq!(taken.shape(), &[3]);
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![10, 30, 50]);
    }

    #[test]
    fn take_2d_axis1() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let taken = take(&arr, &[0, 2], Axis(1)).unwrap();
        assert_eq!(taken.shape(), &[3, 2]);
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![0, 2, 4, 6, 8, 10]);
    }

    #[test]
    fn take_negative_indices() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let taken = take(&arr, &[-1, -3], Axis(0)).unwrap();
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![40, 20]);
    }

    #[test]
    fn take_along_axis_basic() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let taken = take_along_axis(&arr, &[1, 3], Axis(1)).unwrap();
        assert_eq!(taken.shape(), &[3, 2]);
    }

    // --- put ---

    #[test]
    fn put_flat() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 0, 0, 0, 0]).unwrap();
        arr.put(&[1, 3], &[99, 88]).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[0, 99, 0, 88, 0]);
    }

    #[test]
    fn put_cycling_values() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0; 5]).unwrap();
        arr.put(&[0, 1, 2, 3, 4], &[10, 20]).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[10, 20, 10, 20, 10]);
    }

    #[test]
    fn put_out_of_bounds() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
        assert!(arr.put(&[5], &[1]).is_err());
    }

    #[test]
    fn put_empty_values_is_error() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
        assert!(arr.put(&[0], &[]).is_err());
        // The array must be left untouched on error.
        assert_eq!(arr.as_slice().unwrap(), &[0, 0, 0]);
    }

    // --- fill_diagonal ---

    #[test]
    fn fill_diagonal_2d() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 3]), vec![0; 9]).unwrap();
        arr.fill_diagonal(1);
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, 0, 0, 0, 1, 0, 0, 0, 1]);
    }

    #[test]
    fn fill_diagonal_rectangular() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 4]), vec![0; 8]).unwrap();
        arr.fill_diagonal(5);
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![5, 0, 0, 0, 0, 5, 0, 0]);
    }

    // --- choose ---

    #[test]
    fn choose_basic() {
        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([4]), vec![0, 1, 0, 1]).unwrap();
        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![100, 200, 300, 400]).unwrap();
        let result = choose(&idx, &[c0, c1]).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![10, 200, 30, 400]);
    }

    #[test]
    fn choose_out_of_bounds() {
        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([2]), vec![0, 2]).unwrap();
        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 2]).unwrap();
        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![3, 4]).unwrap();
        assert!(choose(&idx, &[c0, c1]).is_err());
    }

    #[test]
    fn choose_shape_mismatch_is_error() {
        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([2]), vec![0, 0]).unwrap();
        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(choose(&idx, &[c0]).is_err());
    }

    // --- compress ---

    #[test]
    fn compress_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let result = compress(&[true, false, true, false, true], &arr, Axis(0)).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![10, 30, 50]);
    }

    #[test]
    fn compress_2d_axis0() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let result = compress(&[true, false, true], &arr, Axis(0)).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![0, 1, 2, 3, 8, 9, 10, 11]);
    }

    #[test]
    fn compress_condition_longer_than_axis_is_error() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(compress(&[true; 4], &arr, Axis(0)).is_err());
    }

    // --- select ---

    #[test]
    fn select_basic() {
        let c1 =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, false]).unwrap();
        let c2 =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, false]).unwrap();
        let ch1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 1, 1, 1]).unwrap();
        let ch2 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![2, 2, 2, 2]).unwrap();
        let result = select(&[c1, c2], &[ch1, ch2], 0).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 0, 0]);
    }

    #[test]
    fn select_length_mismatch_is_error() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, false]).unwrap();
        let choices: Vec<Array<i32, Ix1>> = Vec::new();
        assert!(select(&[cond], &choices, 0).is_err());
    }

    // --- indices / ix_ ---

    #[test]
    fn indices_2d() {
        let idx = indices(&[2, 3]).unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0].shape(), &[2, 3]);
        assert_eq!(idx[1].shape(), &[2, 3]);
        let rows: Vec<u64> = idx[0].iter().copied().collect();
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1]);
        let cols: Vec<u64> = idx[1].iter().copied().collect();
        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn ix_basic() {
        let result = ix_(&[&[0, 1], &[2, 3, 4]]).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[2, 1]);
        assert_eq!(result[1].shape(), &[1, 3]);
    }

    // --- diagonal helpers ---

    #[test]
    fn diag_indices_basic() {
        let idx = diag_indices(3, 2);
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0], vec![0, 1, 2]);
        assert_eq!(idx[1], vec![0, 1, 2]);
    }

    #[test]
    fn diag_indices_from_square() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([4, 4])).unwrap();
        let idx = diag_indices_from(&arr).unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0].len(), 4);
    }

    #[test]
    fn diag_indices_from_not_square() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
        assert!(diag_indices_from(&arr).is_err());
    }

    // --- triangle helpers ---

    #[test]
    fn tril_indices_basic() {
        let (rows, cols) = tril_indices(3, 0, None);
        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
    }

    #[test]
    fn triu_indices_basic() {
        let (rows, cols) = triu_indices(3, 0, None);
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 2]);
        assert_eq!(cols, vec![0, 1, 2, 1, 2, 2]);
    }

    #[test]
    fn tril_indices_with_k() {
        let (rows, cols) = tril_indices(3, 1, None);
        assert_eq!(rows, vec![0, 0, 1, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 1, 0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn triu_indices_with_negative_k() {
        let (rows, cols) = triu_indices(3, -1, None);
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1, 2, 2]);
        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2, 1, 2]);
    }

    #[test]
    fn tril_indices_from_test() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
        let (rows, _cols) = tril_indices_from(&arr, 0).unwrap();
        assert_eq!(rows.len(), 6);
    }

    #[test]
    fn triu_indices_from_test() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
        let (rows, _cols) = triu_indices_from(&arr, 0).unwrap();
        assert_eq!(rows.len(), 6);
    }

    #[test]
    fn tril_indices_rectangular() {
        let (rows, cols) = tril_indices(3, 0, Some(4));
        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
    }

    // --- ravel / unravel ---

    #[test]
    fn ravel_multi_index_basic() {
        let flat = ravel_multi_index(&[&[0, 1, 2], &[1, 2, 0]], &[3, 4]).unwrap();
        assert_eq!(flat, vec![1, 6, 8]);
    }

    #[test]
    fn ravel_multi_index_3d() {
        let flat = ravel_multi_index(&[&[0], &[1], &[2]], &[2, 3, 4]).unwrap();
        assert_eq!(flat, vec![6]);
    }

    #[test]
    fn ravel_multi_index_out_of_bounds() {
        assert!(ravel_multi_index(&[&[3]], &[3]).is_err());
    }

    #[test]
    fn ravel_multi_index_component_count_mismatch_is_error() {
        assert!(ravel_multi_index(&[&[0]], &[2, 2]).is_err());
    }

    #[test]
    fn unravel_index_basic() {
        let coords = unravel_index(&[1, 6, 8], &[3, 4]).unwrap();
        assert_eq!(coords[0], vec![0, 1, 2]);
        assert_eq!(coords[1], vec![1, 2, 0]);
    }

    #[test]
    fn unravel_index_out_of_bounds() {
        assert!(unravel_index(&[12], &[3, 4]).is_err());
    }

    #[test]
    fn ravel_unravel_roundtrip() {
        let dims = &[3, 4, 5];
        let a: &[usize] = &[1, 2];
        let b: &[usize] = &[2, 3];
        let c: &[usize] = &[3, 4];
        let multi: &[&[usize]] = &[a, b, c];
        let flat = ravel_multi_index(multi, dims).unwrap();
        let coords = unravel_index(&flat, dims).unwrap();
        assert_eq!(coords[0], vec![1, 2]);
        assert_eq!(coords[1], vec![2, 3]);
        assert_eq!(coords[2], vec![3, 4]);
    }

    // --- flatnonzero / nonzero / argwhere ---

    #[test]
    fn flatnonzero_basic() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz, vec![1, 3]);
    }

    #[test]
    fn flatnonzero_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz, vec![1, 3, 5]);
    }

    #[test]
    fn flatnonzero_all_zero() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz.len(), 0);
    }

    #[test]
    fn nonzero_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz.len(), 1);
        assert_eq!(nz[0], vec![1, 3]);
    }

    #[test]
    fn nonzero_2d_yields_row_and_col_indices() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz.len(), 2);
        assert_eq!(nz[0], vec![0, 1, 1]);
        assert_eq!(nz[1], vec![1, 0, 2]);
    }

    #[test]
    fn nonzero_all_zero_returns_empty_per_axis() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz.len(), 2);
        assert!(nz[0].is_empty());
        assert!(nz[1].is_empty());
    }

    #[test]
    fn nonzero_f64_treats_negative_zero_as_zero() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![-0.0, 1.5, 0.0, -2.5]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz[0], vec![1, 3]);
    }

    #[test]
    fn argwhere_2d_has_one_row_per_nonzero() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[3, 2]);
        assert_eq!(coords.as_slice().unwrap(), &[0, 1, 1, 0, 1, 2]);
    }

    #[test]
    fn argwhere_1d_is_column_vector() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 7, 0, 9, 3]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[3, 1]);
        assert_eq!(coords.as_slice().unwrap(), &[1, 3, 4]);
    }

    #[test]
    fn argwhere_all_zero_returns_empty() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[0, 2]);
        assert_eq!(coords.size(), 0);
    }

    // --- ndindex / ndenumerate ---

    #[test]
    fn ndindex_2d() {
        let indices: Vec<Vec<usize>> = ndindex(&[2, 3]).collect();
        assert_eq!(indices.len(), 6);
        assert_eq!(indices[0], vec![0, 0]);
        assert_eq!(indices[1], vec![0, 1]);
        assert_eq!(indices[2], vec![0, 2]);
        assert_eq!(indices[3], vec![1, 0]);
        assert_eq!(indices[4], vec![1, 1]);
        assert_eq!(indices[5], vec![1, 2]);
    }

    #[test]
    fn ndindex_1d() {
        let indices: Vec<Vec<usize>> = ndindex(&[4]).collect();
        assert_eq!(indices.len(), 4);
        assert_eq!(indices[0], vec![0]);
        assert_eq!(indices[3], vec![3]);
    }

    #[test]
    fn ndindex_empty() {
        assert_eq!(ndindex(&[0]).count(), 0);
    }

    #[test]
    fn ndindex_scalar() {
        let indices: Vec<Vec<usize>> = ndindex(&[]).collect();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], Vec::<usize>::new());
    }

    #[test]
    fn ndindex_size_hint_counts_down() {
        let mut it = ndindex(&[2, 3]);
        assert_eq!(it.size_hint(), (6, Some(6)));
        it.next();
        assert_eq!(it.size_hint(), (5, Some(5)));
        assert_eq!(it.count(), 5);
    }

    #[test]
    fn ndenumerate_2d() {
        let arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
        let items: Vec<(Vec<usize>, &i32)> = ndenumerate(&arr).collect();
        assert_eq!(items.len(), 6);
        assert_eq!(items[0], (vec![0, 0], &10));
        assert_eq!(items[1], (vec![0, 1], &20));
        assert_eq!(items[5], (vec![1, 2], &60));
    }

    // --- put_along_axis ---

    #[test]
    fn put_along_axis_basic() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), vec![0; 12]).unwrap();
        let values =
            Array::<i32, IxDyn>::from_vec(IxDyn::new(&[8]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        arr.put_along_axis(&[0, 2], &values, Axis(0)).unwrap();
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3, 4, 0, 0, 0, 0, 5, 6, 7, 8]);
    }

    // --- where_select ---

    #[test]
    fn where_basic() {
        let cond =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[1.0, 20.0, 3.0, 40.0]);
    }

    #[test]
    fn where_all_true() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn where_all_false() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[10, 20, 30]);
    }

    #[test]
    fn where_shape_mismatch() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true; 3]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0; 4]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0; 3]).unwrap();
        assert!(where_select(&cond, &x, &y).is_err());
    }

    #[test]
    fn where_2d() {
        let cond =
            Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![true, false, false, true]).unwrap();
        let x = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let y = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![10, 20, 30, 40]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![1, 20, 30, 4]);
    }

    // --- place / putmask / extract ---

    #[test]
    fn test_place_basic() {
        let mut a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![false, true, false, true, false, true],
        )
        .unwrap();
        place(&mut a, &mask, &[10, 20]).unwrap();
        let data: Vec<i32> = a.iter().copied().collect();
        // Three masked slots, two values: the values cycle 10, 20, 10.
        assert_eq!(data, vec![1, 10, 3, 20, 5, 10]);
    }

    #[test]
    fn test_place_no_hits() {
        let mut a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false; 3]).unwrap();
        place(&mut a, &mask, &[]).unwrap();
        assert_eq!(a.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
    }

    #[test]
    fn test_place_shape_mismatch() {
        let mut a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
        assert!(place(&mut a, &mask, &[0]).is_err());
    }

    #[test]
    fn test_putmask_scalar() {
        let mut a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        putmask(&mut a, &mask, &[99]).unwrap();
        assert_eq!(a.iter().copied().collect::<Vec<_>>(), vec![99, 2, 99, 4]);
    }

    #[test]
    fn test_putmask_full_array() {
        let mut a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        putmask(&mut a, &mask, &[10, 20, 30, 40]).unwrap();
        assert_eq!(a.iter().copied().collect::<Vec<_>>(), vec![10, 2, 30, 4]);
    }

    #[test]
    fn test_putmask_bad_length() {
        let mut a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
        assert!(putmask(&mut a, &mask, &[1, 2]).is_err());
    }

    #[test]
    fn test_extract_basic() {
        let cond =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let r = extract(&cond, &a).unwrap();
        assert_eq!(r.iter().copied().collect::<Vec<_>>(), vec![1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_extract_2d() {
        let cond =
            Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![true, false, false, true]).unwrap();
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![10, 20, 30, 40]).unwrap();
        let r = extract(&cond, &a).unwrap();
        assert_eq!(r.iter().copied().collect::<Vec<_>>(), vec![10, 40]);
    }

    // --- mask_indices ---

    #[test]
    fn test_mask_indices_tril() {
        let idx = mask_indices(3, MaskKind::Tril, 0);
        assert_eq!(idx, vec![0, 3, 4, 6, 7, 8]);
    }

    #[test]
    fn test_mask_indices_triu() {
        let idx = mask_indices(3, MaskKind::Triu, 0);
        assert_eq!(idx, vec![0, 1, 2, 4, 5, 8]);
    }

    #[test]
    fn test_mask_indices_diag() {
        let idx = mask_indices(3, MaskKind::Diag, 0);
        assert_eq!(idx, vec![0, 4, 8]);
    }
}