1use super::normalize_index;
13use crate::array::owned::Array;
14use crate::dimension::{Axis, Dimension, Ix2, IxDyn};
15use crate::dtype::Element;
16use crate::error::{FerrayError, FerrayResult};
17
18pub fn take<T: Element, D: Dimension>(
30 a: &Array<T, D>,
31 indices: &[isize],
32 axis: Axis,
33) -> FerrayResult<Array<T, IxDyn>> {
34 a.index_select(axis, indices)
35}
36
37pub fn take_along_axis<T: Element, D: Dimension>(
47 a: &Array<T, D>,
48 indices: &[isize],
49 axis: Axis,
50) -> FerrayResult<Array<T, IxDyn>> {
51 a.index_select(axis, indices)
52}
53
54impl<T: Element, D: Dimension> Array<T, D> {
59 pub fn put(&mut self, indices: &[isize], values: &[T]) -> FerrayResult<()> {
69 if values.is_empty() {
70 return Err(FerrayError::invalid_value("values must not be empty"));
71 }
72 let size = self.size();
73 let normalized: Vec<usize> = indices
74 .iter()
75 .map(|&idx| normalize_index(idx, size, 0))
76 .collect::<FerrayResult<Vec<_>>>()?;
77
78 let mut flat: Vec<&mut T> = self.inner.iter_mut().collect();
79
80 for (i, &idx) in normalized.iter().enumerate() {
81 let val_idx = i % values.len();
82 *flat[idx] = values[val_idx].clone();
83 }
84 Ok(())
85 }
86
87 pub fn put_along_axis(
96 &mut self,
97 indices: &[isize],
98 values: &Array<T, IxDyn>,
99 axis: Axis,
100 ) -> FerrayResult<()>
101 where
102 D::NdarrayDim: ndarray::RemoveAxis,
103 {
104 let ndim = self.ndim();
105 let ax = axis.index();
106 if ax >= ndim {
107 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
108 }
109 let axis_size = self.shape()[ax];
110
111 let normalized: Vec<usize> = indices
112 .iter()
113 .map(|&idx| normalize_index(idx, axis_size, ax))
114 .collect::<FerrayResult<Vec<_>>>()?;
115
116 let nd_axis = ndarray::Axis(ax);
117 let mut val_iter = values.inner.iter();
118
119 for &idx in &normalized {
120 let mut sub = self.inner.index_axis_mut(nd_axis, idx);
121 for elem in &mut sub {
122 if let Some(v) = val_iter.next() {
123 *elem = v.clone();
124 }
125 }
126 }
127 Ok(())
128 }
129
130 pub fn fill_diagonal(&mut self, val: T) {
137 let shape = self.shape().to_vec();
138 if shape.is_empty() {
139 return;
140 }
141 let min_dim = *shape.iter().min().unwrap_or(&0);
142 let ndim = shape.len();
143
144 for i in 0..min_dim {
145 let idx: Vec<usize> = vec![i; ndim];
146 let nd_idx = ndarray::IxDyn(&idx);
147 let mut dyn_view = self.inner.view_mut().into_dyn();
148 dyn_view[nd_idx] = val.clone();
149 }
150 }
151}
152
153pub fn choose<T: Element, D: Dimension>(
167 index_arr: &Array<u64, D>,
168 choices: &[Array<T, D>],
169) -> FerrayResult<Array<T, IxDyn>> {
170 if choices.is_empty() {
171 return Err(FerrayError::invalid_value("choices must not be empty"));
172 }
173
174 let shape = index_arr.shape();
175 for (i, c) in choices.iter().enumerate() {
176 if c.shape() != shape {
177 return Err(FerrayError::shape_mismatch(format!(
178 "choice[{}] shape {:?} does not match index array shape {:?}",
179 i,
180 c.shape(),
181 shape
182 )));
183 }
184 }
185
186 let n_choices = choices.len();
187 let choice_iters: Vec<Vec<T>> = choices
188 .iter()
189 .map(|c| c.inner.iter().cloned().collect())
190 .collect();
191
192 let mut data = Vec::with_capacity(index_arr.size());
193 for (pos, idx_val) in index_arr.inner.iter().enumerate() {
194 let idx = *idx_val as usize;
195 if idx >= n_choices {
196 return Err(FerrayError::index_out_of_bounds(idx as isize, 0, n_choices));
197 }
198 data.push(choice_iters[idx][pos].clone());
199 }
200
201 let dyn_shape = IxDyn::new(shape);
202 Array::from_vec(dyn_shape, data)
203}
204
205pub fn compress<T: Element, D: Dimension>(
217 condition: &[bool],
218 a: &Array<T, D>,
219 axis: Axis,
220) -> FerrayResult<Array<T, IxDyn>> {
221 let ndim = a.ndim();
222 let ax = axis.index();
223 if ax >= ndim {
224 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
225 }
226 let axis_size = a.shape()[ax];
227 if condition.len() > axis_size {
228 return Err(FerrayError::shape_mismatch(format!(
229 "condition length {} exceeds axis size {}",
230 condition.len(),
231 axis_size
232 )));
233 }
234
235 let indices: Vec<isize> = condition
236 .iter()
237 .enumerate()
238 .filter_map(|(i, &c)| if c { Some(i as isize) } else { None })
239 .collect();
240
241 a.index_select(axis, &indices)
242}
243
244pub fn select<T: Element, D: Dimension>(
257 condlist: &[Array<bool, D>],
258 choicelist: &[Array<T, D>],
259 default: T,
260) -> FerrayResult<Array<T, IxDyn>> {
261 if condlist.len() != choicelist.len() {
262 return Err(FerrayError::invalid_value(format!(
263 "condlist length {} != choicelist length {}",
264 condlist.len(),
265 choicelist.len()
266 )));
267 }
268 if condlist.is_empty() {
269 return Err(FerrayError::invalid_value(
270 "condlist and choicelist must not be empty",
271 ));
272 }
273
274 let shape = condlist[0].shape();
275 for (i, (c, ch)) in condlist.iter().zip(choicelist.iter()).enumerate() {
276 if c.shape() != shape || ch.shape() != shape {
277 return Err(FerrayError::shape_mismatch(format!(
278 "condlist[{i}]/choicelist[{i}] shape mismatch with reference shape {shape:?}"
279 )));
280 }
281 }
282
283 let size = condlist[0].size();
284 let mut data = vec![default; size];
285
286 for (cond, choice) in condlist.iter().zip(choicelist.iter()).rev() {
288 for (i, (&c, v)) in cond.inner.iter().zip(choice.inner.iter()).enumerate() {
289 if c {
290 data[i] = v.clone();
291 }
292 }
293 }
294
295 let dyn_shape = IxDyn::new(shape);
296 Array::from_vec(dyn_shape, data)
297}
298
299pub fn indices(dimensions: &[usize]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
311 let ndim = dimensions.len();
312 let total: usize = dimensions.iter().product();
313
314 let mut result = Vec::with_capacity(ndim);
315
316 for ax in 0..ndim {
317 let mut data = Vec::with_capacity(total);
318 for flat_idx in 0..total {
319 let mut rem = flat_idx;
320 let mut idx_for_ax = 0;
321 for (d, &dim_size) in dimensions.iter().enumerate().rev() {
322 let coord = rem % dim_size;
323 rem /= dim_size;
324 if d == ax {
325 idx_for_ax = coord;
326 }
327 }
328 data.push(idx_for_ax as u64);
329 }
330 let dim = IxDyn::new(dimensions);
331 result.push(Array::from_vec(dim, data)?);
332 }
333
334 Ok(result)
335}
336
337pub fn ix_(sequences: &[&[u64]]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
349 let ndim = sequences.len();
350 let mut result = Vec::with_capacity(ndim);
351
352 for (i, seq) in sequences.iter().enumerate() {
353 let mut shape = vec![1usize; ndim];
354 shape[i] = seq.len();
355
356 let data = seq.to_vec();
357 let dim = IxDyn::new(&shape);
358 result.push(Array::from_vec(dim, data)?);
359 }
360
361 Ok(result)
362}
363
/// Coordinates of the main diagonal of an array with `ndim` axes of length `n`.
///
/// Returns `ndim` lists, each equal to `0..n`; zipping them gives the
/// positions `[i, i, ..., i]`.
#[must_use]
pub fn diag_indices(n: usize, ndim: usize) -> Vec<Vec<usize>> {
    (0..ndim).map(|_| (0..n).collect()).collect()
}
377
378pub fn diag_indices_from<T: Element, D: Dimension>(
386 a: &Array<T, D>,
387) -> FerrayResult<Vec<Vec<usize>>> {
388 let ndim = a.ndim();
389 if ndim < 2 {
390 return Err(FerrayError::invalid_value(
391 "diag_indices_from requires at least 2 dimensions",
392 ));
393 }
394 let shape = a.shape();
395 let n = shape[0];
396 for &s in &shape[1..] {
397 if s != n {
398 return Err(FerrayError::shape_mismatch(format!(
399 "all dimensions must be equal for diag_indices_from, got {shape:?}"
400 )));
401 }
402 }
403 Ok(diag_indices(n, ndim))
404}
405
/// Row and column indices of the lower triangle of an `n x m` matrix.
///
/// Position `(i, j)` is included when `j <= i + k`: `k = 0` is the main
/// diagonal, `k > 0` adds diagonals above it and `k < 0` removes diagonals.
/// `m` defaults to `n`. Indices are emitted row by row, with columns
/// ascending within a row.
#[must_use]
pub fn tril_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // Last admissible column is `i + k`, computed in i128 so that extreme
        // `k` (e.g. isize::MAX) cannot overflow.
        let last = i as i128 + k as i128;
        if last < 0 {
            continue;
        }
        // Columns 0..=last qualify, capped at the row width `m`.
        let end = if last + 1 >= m as i128 {
            m
        } else {
            (last + 1) as usize
        };
        rows.resize(rows.len() + end, i);
        cols.extend(0..end);
    }

    (rows, cols)
}
432
/// Row and column indices of the upper triangle of an `n x m` matrix.
///
/// Position `(i, j)` is included when `j >= i + k`: `k = 0` is the main
/// diagonal, `k > 0` drops diagonals and `k < 0` adds diagonals below it.
/// `m` defaults to `n`. Indices are emitted row by row, with columns
/// ascending within a row.
#[must_use]
pub fn triu_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // First admissible column is `i + k`, computed in i128 so that extreme
        // `k` (e.g. isize::MAX) cannot overflow; clamp it into 0..=m.
        let first = i as i128 + k as i128;
        let start = if first <= 0 {
            0
        } else if first >= m as i128 {
            m
        } else {
            first as usize
        };
        rows.resize(rows.len() + (m - start), i);
        cols.extend(start..m);
    }

    (rows, cols)
}
453
454pub fn tril_indices_from<T: Element, D: Dimension>(
459 a: &Array<T, D>,
460 k: isize,
461) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
462 let shape = a.shape();
463 if shape.len() != 2 {
464 return Err(FerrayError::invalid_value(
465 "tril_indices_from requires a 2-D array",
466 ));
467 }
468 Ok(tril_indices(shape[0], k, Some(shape[1])))
469}
470
471pub fn triu_indices_from<T: Element, D: Dimension>(
476 a: &Array<T, D>,
477 k: isize,
478) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
479 let shape = a.shape();
480 if shape.len() != 2 {
481 return Err(FerrayError::invalid_value(
482 "triu_indices_from requires a 2-D array",
483 ));
484 }
485 Ok(triu_indices(shape[0], k, Some(shape[1])))
486}
487
488#[allow(clippy::needless_range_loop)]
501pub fn ravel_multi_index(multi_index: &[&[usize]], dims: &[usize]) -> FerrayResult<Vec<usize>> {
502 if multi_index.len() != dims.len() {
503 return Err(FerrayError::invalid_value(format!(
504 "multi_index has {} components but dims has {} dimensions",
505 multi_index.len(),
506 dims.len()
507 )));
508 }
509 if multi_index.is_empty() {
510 return Ok(vec![]);
511 }
512
513 let n = multi_index[0].len();
514 for (i, idx_arr) in multi_index.iter().enumerate() {
515 if idx_arr.len() != n {
516 return Err(FerrayError::invalid_value(format!(
517 "multi_index[{}] has length {} but expected {}",
518 i,
519 idx_arr.len(),
520 n
521 )));
522 }
523 }
524
525 let ndim = dims.len();
527 let mut strides = vec![1usize; ndim];
528 for i in (0..ndim - 1).rev() {
529 strides[i] = strides[i + 1] * dims[i + 1];
530 }
531
532 let mut flat = Vec::with_capacity(n);
533 #[allow(clippy::needless_range_loop)]
534 for pos in 0..n {
535 let mut linear = 0usize;
536 for (d, &dim_size) in dims.iter().enumerate() {
537 let coord = multi_index[d][pos];
538 if coord >= dim_size {
539 return Err(FerrayError::index_out_of_bounds(
540 coord as isize,
541 d,
542 dim_size,
543 ));
544 }
545 linear += coord * strides[d];
546 }
547 flat.push(linear);
548 }
549
550 Ok(flat)
551}
552
553pub fn unravel_index(flat_indices: &[usize], shape: &[usize]) -> FerrayResult<Vec<Vec<usize>>> {
561 let total: usize = shape.iter().product();
562 let ndim = shape.len();
563 let n = flat_indices.len();
564
565 let mut result: Vec<Vec<usize>> = vec![Vec::with_capacity(n); ndim];
566
567 for &flat_idx in flat_indices {
568 if flat_idx >= total {
569 return Err(FerrayError::index_out_of_bounds(
570 flat_idx as isize,
571 0,
572 total,
573 ));
574 }
575 let mut rem = flat_idx;
576 for (d, &dim_size) in shape.iter().enumerate().rev() {
577 result[d].push(rem % dim_size);
578 rem /= dim_size;
579 }
580 }
581
582 Ok(result)
583}
584
585pub fn flatnonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<usize> {
594 let zero = T::zero();
595 a.inner
596 .iter()
597 .enumerate()
598 .filter_map(|(i, val)| if *val == zero { None } else { Some(i) })
599 .collect()
600}
601
602pub fn nonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<Vec<usize>> {
622 let zero = T::zero();
623 let ndim = a.ndim();
624 let mut result: Vec<Vec<usize>> = vec![Vec::new(); ndim];
625 for (idx, val) in a.indexed_iter() {
626 if *val != zero {
627 for (d, &c) in idx.iter().enumerate() {
628 result[d].push(c);
629 }
630 }
631 }
632 result
633}
634
635pub fn argwhere<T: Element + PartialEq, D: Dimension>(
648 a: &Array<T, D>,
649) -> FerrayResult<Array<i64, Ix2>> {
650 let zero = T::zero();
651 let ndim = a.ndim();
652 let mut data: Vec<i64> = Vec::new();
653 let mut count: usize = 0;
654 for (idx, val) in a.indexed_iter() {
655 if *val != zero {
656 for &c in &idx {
657 data.push(c as i64);
658 }
659 count += 1;
660 }
661 }
662 Array::<i64, Ix2>::from_vec(Ix2::new([count, ndim]), data)
663}
664
/// Iterator over every multi-index of an n-dimensional shape, in row-major
/// order (last axis varies fastest).
///
/// Created by [`ndindex`]. A shape with a zero-length axis yields nothing; an
/// empty (0-d) shape yields a single empty index.
#[derive(Debug, Clone)]
pub struct NdIndex {
    /// Extent of each axis.
    shape: Vec<usize>,
    /// The multi-index that will be yielded next.
    current: Vec<usize>,
    /// Set once the final index has been yielded, or from the start when an
    /// axis has length zero.
    done: bool,
}
677
678impl NdIndex {
679 fn new(shape: &[usize]) -> Self {
680 let done = shape.contains(&0);
681 Self {
682 shape: shape.to_vec(),
683 current: vec![0; shape.len()],
684 done,
685 }
686 }
687}
688
689impl Iterator for NdIndex {
690 type Item = Vec<usize>;
691
692 fn next(&mut self) -> Option<Self::Item> {
693 if self.done {
694 return None;
695 }
696
697 let result = self.current.clone();
698
699 let mut carry = true;
701 for i in (0..self.shape.len()).rev() {
702 if carry {
703 self.current[i] += 1;
704 if self.current[i] >= self.shape[i] {
705 self.current[i] = 0;
706 carry = true;
707 } else {
708 carry = false;
709 }
710 }
711 }
712 if carry {
713 self.done = true;
714 }
715
716 Some(result)
717 }
718
719 fn size_hint(&self) -> (usize, Option<usize>) {
720 if self.done {
721 return (0, Some(0));
722 }
723 let total: usize = self.shape.iter().product();
724 let mut yielded = 0usize;
726 let ndim = self.shape.len();
727 let mut stride = 1usize;
728 for i in (0..ndim).rev() {
729 yielded += self.current[i] * stride;
730 stride *= self.shape[i];
731 }
732 let remaining = total - yielded;
733 (remaining, Some(remaining))
734 }
735}
736
/// Iterate over every multi-index of `shape` in row-major order (last axis
/// varies fastest), like NumPy's `ndindex`.
///
/// Each item is a `Vec<usize>` of length `shape.len()`. A shape containing a
/// zero-length axis yields nothing; an empty shape yields one empty index.
#[must_use]
pub fn ndindex(shape: &[usize]) -> NdIndex {
    NdIndex::new(shape)
}
744
745pub fn ndenumerate<T: Element, D: Dimension>(
749 a: &Array<T, D>,
750) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
751 let shape = a.shape().to_vec();
752 let ndim = shape.len();
753 a.inner.iter().enumerate().map(move |(flat_idx, val)| {
754 let mut idx = vec![0usize; ndim];
755 let mut rem = flat_idx;
756 for (d, s) in shape.iter().enumerate().rev() {
757 if *s > 0 {
758 idx[d] = rem % s;
759 rem /= s;
760 }
761 }
762 (idx, val)
763 })
764}
765
766pub fn where_select<T: Element + Copy, D: Dimension>(
782 condition: &Array<bool, D>,
783 x: &Array<T, D>,
784 y: &Array<T, D>,
785) -> FerrayResult<Array<T, D>> {
786 if condition.shape() != x.shape() || condition.shape() != y.shape() {
787 return Err(FerrayError::shape_mismatch(format!(
788 "where_select: condition shape {:?}, x shape {:?}, y shape {:?} must all match",
789 condition.shape(),
790 x.shape(),
791 y.shape()
792 )));
793 }
794 let data: Vec<T> = condition
795 .iter()
796 .zip(x.iter().zip(y.iter()))
797 .map(|(&c, (&xi, &yi))| if c { xi } else { yi })
798 .collect();
799 Array::from_vec(x.dim().clone(), data)
800}
801
// Unit tests for the indexing helpers in this module, grouped by function.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    // ---- take ----

    #[test]
    fn take_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let taken = take(&arr, &[0, 2, 4], Axis(0)).unwrap();
        assert_eq!(taken.shape(), &[3]);
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![10, 30, 50]);
    }

    #[test]
    fn take_2d_axis1() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let taken = take(&arr, &[0, 2], Axis(1)).unwrap();
        assert_eq!(taken.shape(), &[3, 2]);
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![0, 2, 4, 6, 8, 10]);
    }

    // Negative indices count from the end of the axis.
    #[test]
    fn take_negative_indices() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let taken = take(&arr, &[-1, -3], Axis(0)).unwrap();
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![40, 20]);
    }

    // ---- take_along_axis ----

    #[test]
    fn take_along_axis_basic() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let taken = take_along_axis(&arr, &[1, 3], Axis(1)).unwrap();
        assert_eq!(taken.shape(), &[3, 2]);
    }

    // ---- put ----

    #[test]
    fn put_flat() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 0, 0, 0, 0]).unwrap();
        arr.put(&[1, 3], &[99, 88]).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[0, 99, 0, 88, 0]);
    }

    // Shorter `values` are repeated cyclically across the indices.
    #[test]
    fn put_cycling_values() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0; 5]).unwrap();
        arr.put(&[0, 1, 2, 3, 4], &[10, 20]).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[10, 20, 10, 20, 10]);
    }

    #[test]
    fn put_out_of_bounds() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
        assert!(arr.put(&[5], &[1]).is_err());
    }

    // ---- fill_diagonal ----

    #[test]
    fn fill_diagonal_2d() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 3]), vec![0; 9]).unwrap();
        arr.fill_diagonal(1);
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, 0, 0, 0, 1, 0, 0, 0, 1]);
    }

    // Non-square arrays: the diagonal stops after min(shape) elements.
    #[test]
    fn fill_diagonal_rectangular() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 4]), vec![0; 8]).unwrap();
        arr.fill_diagonal(5);
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![5, 0, 0, 0, 0, 5, 0, 0]);
    }

    // ---- choose ----

    #[test]
    fn choose_basic() {
        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([4]), vec![0, 1, 0, 1]).unwrap();
        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![100, 200, 300, 400]).unwrap();
        let result = choose(&idx, &[c0, c1]).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![10, 200, 30, 400]);
    }

    #[test]
    fn choose_out_of_bounds() {
        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([2]), vec![0, 2]).unwrap();
        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 2]).unwrap();
        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![3, 4]).unwrap();
        assert!(choose(&idx, &[c0, c1]).is_err());
    }

    // ---- compress ----

    #[test]
    fn compress_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let result = compress(&[true, false, true, false, true], &arr, Axis(0)).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![10, 30, 50]);
    }

    #[test]
    fn compress_2d_axis0() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let result = compress(&[true, false, true], &arr, Axis(0)).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![0, 1, 2, 3, 8, 9, 10, 11]);
    }

    // ---- select ----

    // Positions where no condition holds fall back to the default (0).
    #[test]
    fn select_basic() {
        let c1 =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, false]).unwrap();
        let c2 =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, false]).unwrap();
        let ch1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 1, 1, 1]).unwrap();
        let ch2 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![2, 2, 2, 2]).unwrap();
        let result = select(&[c1, c2], &[ch1, ch2], 0).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 0, 0]);
    }

    // ---- indices ----

    #[test]
    fn indices_2d() {
        let idx = indices(&[2, 3]).unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0].shape(), &[2, 3]);
        assert_eq!(idx[1].shape(), &[2, 3]);
        let rows: Vec<u64> = idx[0].iter().copied().collect();
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1]);
        let cols: Vec<u64> = idx[1].iter().copied().collect();
        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2]);
    }

    // ---- ix_ ----

    #[test]
    fn ix_basic() {
        let result = ix_(&[&[0, 1], &[2, 3, 4]]).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[2, 1]);
        assert_eq!(result[1].shape(), &[1, 3]);
    }

    // ---- diag_indices / diag_indices_from ----

    #[test]
    fn diag_indices_basic() {
        let idx = diag_indices(3, 2);
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0], vec![0, 1, 2]);
        assert_eq!(idx[1], vec![0, 1, 2]);
    }

    #[test]
    fn diag_indices_from_square() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([4, 4])).unwrap();
        let idx = diag_indices_from(&arr).unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0].len(), 4);
    }

    #[test]
    fn diag_indices_from_not_square() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
        assert!(diag_indices_from(&arr).is_err());
    }

    // ---- tril / triu indices ----

    #[test]
    fn tril_indices_basic() {
        let (rows, cols) = tril_indices(3, 0, None);
        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
    }

    #[test]
    fn triu_indices_basic() {
        let (rows, cols) = triu_indices(3, 0, None);
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 2]);
        assert_eq!(cols, vec![0, 1, 2, 1, 2, 2]);
    }

    // k = 1 also includes the first super-diagonal.
    #[test]
    fn tril_indices_with_k() {
        let (rows, cols) = tril_indices(3, 1, None);
        assert_eq!(rows, vec![0, 0, 1, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 1, 0, 1, 2, 0, 1, 2]);
    }

    // k = -1 also includes the first sub-diagonal.
    #[test]
    fn triu_indices_with_negative_k() {
        let (rows, cols) = triu_indices(3, -1, None);
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1, 2, 2]);
        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2, 1, 2]);
    }

    #[test]
    fn tril_indices_from_test() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
        let (rows, _cols) = tril_indices_from(&arr, 0).unwrap();
        assert_eq!(rows.len(), 6);
    }

    #[test]
    fn triu_indices_from_test() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
        let (rows, _cols) = triu_indices_from(&arr, 0).unwrap();
        assert_eq!(rows.len(), 6);
    }

    #[test]
    fn tril_indices_rectangular() {
        let (rows, cols) = tril_indices(3, 0, Some(4));
        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
    }

    // ---- ravel_multi_index / unravel_index ----

    #[test]
    fn ravel_multi_index_basic() {
        let flat = ravel_multi_index(&[&[0, 1, 2], &[1, 2, 0]], &[3, 4]).unwrap();
        assert_eq!(flat, vec![1, 6, 8]);
    }

    #[test]
    fn ravel_multi_index_3d() {
        let flat = ravel_multi_index(&[&[0], &[1], &[2]], &[2, 3, 4]).unwrap();
        assert_eq!(flat, vec![6]);
    }

    #[test]
    fn ravel_multi_index_out_of_bounds() {
        assert!(ravel_multi_index(&[&[3]], &[3]).is_err());
    }

    #[test]
    fn unravel_index_basic() {
        let coords = unravel_index(&[1, 6, 8], &[3, 4]).unwrap();
        assert_eq!(coords[0], vec![0, 1, 2]);
        assert_eq!(coords[1], vec![1, 2, 0]);
    }

    #[test]
    fn unravel_index_out_of_bounds() {
        assert!(unravel_index(&[12], &[3, 4]).is_err());
    }

    // ravel followed by unravel must reproduce the original coordinates.
    #[test]
    fn ravel_unravel_roundtrip() {
        let dims = &[3, 4, 5];
        let a: &[usize] = &[1, 2];
        let b: &[usize] = &[2, 3];
        let c: &[usize] = &[3, 4];
        let multi: &[&[usize]] = &[a, b, c];
        let flat = ravel_multi_index(multi, dims).unwrap();
        let coords = unravel_index(&flat, dims).unwrap();
        assert_eq!(coords[0], vec![1, 2]);
        assert_eq!(coords[1], vec![2, 3]);
        assert_eq!(coords[2], vec![3, 4]);
    }

    // ---- flatnonzero ----

    #[test]
    fn flatnonzero_basic() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz, vec![1, 3]);
    }

    #[test]
    fn flatnonzero_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz, vec![1, 3, 5]);
    }

    #[test]
    fn flatnonzero_all_zero() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz.len(), 0);
    }

    // ---- nonzero ----

    #[test]
    fn nonzero_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz.len(), 1);
        assert_eq!(nz[0], vec![1, 3]);
    }

    // Nonzeros at (0,1), (1,0), (1,2): one list of rows, one list of columns.
    #[test]
    fn nonzero_2d_yields_row_and_col_indices() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz.len(), 2);
        assert_eq!(nz[0], vec![0, 1, 1]);
        assert_eq!(nz[1], vec![1, 0, 2]);
    }

    #[test]
    fn nonzero_all_zero_returns_empty_per_axis() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz.len(), 2);
        assert!(nz[0].is_empty());
        assert!(nz[1].is_empty());
    }

    // -0.0 == 0.0 under PartialEq, so negative zero is not reported.
    #[test]
    fn nonzero_f64_treats_negative_zero_as_zero() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![-0.0, 1.5, 0.0, -2.5]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz[0], vec![1, 3]);
    }

    // ---- argwhere ----

    #[test]
    fn argwhere_2d_has_one_row_per_nonzero() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[3, 2]);
        assert_eq!(coords.as_slice().unwrap(), &[0, 1, 1, 0, 1, 2]);
    }

    #[test]
    fn argwhere_1d_is_column_vector() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 7, 0, 9, 3]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[3, 1]);
        assert_eq!(coords.as_slice().unwrap(), &[1, 3, 4]);
    }

    #[test]
    fn argwhere_all_zero_returns_empty() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[0, 2]);
        assert_eq!(coords.size(), 0);
    }

    // ---- ndindex ----

    // Row-major order: the last axis varies fastest.
    #[test]
    fn ndindex_2d() {
        let indices: Vec<Vec<usize>> = ndindex(&[2, 3]).collect();
        assert_eq!(indices.len(), 6);
        assert_eq!(indices[0], vec![0, 0]);
        assert_eq!(indices[1], vec![0, 1]);
        assert_eq!(indices[2], vec![0, 2]);
        assert_eq!(indices[3], vec![1, 0]);
        assert_eq!(indices[4], vec![1, 1]);
        assert_eq!(indices[5], vec![1, 2]);
    }

    #[test]
    fn ndindex_1d() {
        let indices: Vec<Vec<usize>> = ndindex(&[4]).collect();
        assert_eq!(indices.len(), 4);
        assert_eq!(indices[0], vec![0]);
        assert_eq!(indices[3], vec![3]);
    }

    // A zero-length axis means there is nothing to iterate.
    #[test]
    fn ndindex_empty() {
        assert_eq!(ndindex(&[0]).count(), 0);
    }

    // A 0-d shape has exactly one (empty) index.
    #[test]
    fn ndindex_scalar() {
        let indices: Vec<Vec<usize>> = ndindex(&[]).collect();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], Vec::<usize>::new());
    }

    // ---- ndenumerate ----

    #[test]
    fn ndenumerate_2d() {
        let arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
        let items: Vec<(Vec<usize>, &i32)> = ndenumerate(&arr).collect();
        assert_eq!(items.len(), 6);
        assert_eq!(items[0], (vec![0, 0], &10));
        assert_eq!(items[1], (vec![0, 1], &20));
        assert_eq!(items[5], (vec![1, 2], &60));
    }

    // ---- put_along_axis ----

    // Rows 0 and 2 are overwritten in order from the 8 supplied values.
    #[test]
    fn put_along_axis_basic() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), vec![0; 12]).unwrap();
        let values =
            Array::<i32, IxDyn>::from_vec(IxDyn::new(&[8]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        arr.put_along_axis(&[0, 2], &values, Axis(0)).unwrap();
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3, 4, 0, 0, 0, 0, 5, 6, 7, 8]);
    }

    // ---- where_select ----

    #[test]
    fn where_basic() {
        let cond =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[1.0, 20.0, 3.0, 40.0]);
    }

    #[test]
    fn where_all_true() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn where_all_false() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[10, 20, 30]);
    }

    #[test]
    fn where_shape_mismatch() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true; 3]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0; 4]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0; 3]).unwrap();
        assert!(where_select(&cond, &x, &y).is_err());
    }

    #[test]
    fn where_2d() {
        let cond =
            Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![true, false, false, true]).unwrap();
        let x = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let y = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![10, 20, 30, 40]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![1, 20, 30, 4]);
    }
}