1use super::normalize_index;
13use crate::array::owned::Array;
14use crate::dimension::{Axis, Dimension, Ix2, IxDyn};
15use crate::dtype::Element;
16use crate::error::{FerrayError, FerrayResult};
17
18pub fn take<T: Element, D: Dimension>(
30 a: &Array<T, D>,
31 indices: &[isize],
32 axis: Axis,
33) -> FerrayResult<Array<T, IxDyn>> {
34 a.index_select(axis, indices)
35}
36
37pub fn take_along_axis<T: Element, D: Dimension>(
47 a: &Array<T, D>,
48 indices: &[isize],
49 axis: Axis,
50) -> FerrayResult<Array<T, IxDyn>> {
51 a.index_select(axis, indices)
52}
53
54impl<T: Element, D: Dimension> Array<T, D> {
59 pub fn put(&mut self, indices: &[isize], values: &[T]) -> FerrayResult<()> {
69 if values.is_empty() {
70 return Err(FerrayError::invalid_value("values must not be empty"));
71 }
72 let size = self.size();
73 let normalized: Vec<usize> = indices
74 .iter()
75 .map(|&idx| normalize_index(idx, size, 0))
76 .collect::<FerrayResult<Vec<_>>>()?;
77
78 let mut flat: Vec<&mut T> = self.inner.iter_mut().collect();
79
80 for (i, &idx) in normalized.iter().enumerate() {
81 let val_idx = i % values.len();
82 *flat[idx] = values[val_idx].clone();
83 }
84 Ok(())
85 }
86
87 pub fn put_along_axis(
96 &mut self,
97 indices: &[isize],
98 values: &Array<T, IxDyn>,
99 axis: Axis,
100 ) -> FerrayResult<()>
101 where
102 D::NdarrayDim: ndarray::RemoveAxis,
103 {
104 let ndim = self.ndim();
105 let ax = axis.index();
106 if ax >= ndim {
107 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
108 }
109 let axis_size = self.shape()[ax];
110
111 let normalized: Vec<usize> = indices
112 .iter()
113 .map(|&idx| normalize_index(idx, axis_size, ax))
114 .collect::<FerrayResult<Vec<_>>>()?;
115
116 let nd_axis = ndarray::Axis(ax);
117 let mut val_iter = values.inner.iter();
118
119 for &idx in &normalized {
120 let mut sub = self.inner.index_axis_mut(nd_axis, idx);
121 for elem in sub.iter_mut() {
122 if let Some(v) = val_iter.next() {
123 *elem = v.clone();
124 }
125 }
126 }
127 Ok(())
128 }
129
130 pub fn fill_diagonal(&mut self, val: T) {
137 let shape = self.shape().to_vec();
138 if shape.is_empty() {
139 return;
140 }
141 let min_dim = *shape.iter().min().unwrap_or(&0);
142 let ndim = shape.len();
143
144 for i in 0..min_dim {
145 let idx: Vec<usize> = vec![i; ndim];
146 let nd_idx = ndarray::IxDyn(&idx);
147 let mut dyn_view = self.inner.view_mut().into_dyn();
148 dyn_view[nd_idx] = val.clone();
149 }
150 }
151}
152
153pub fn choose<T: Element, D: Dimension>(
167 index_arr: &Array<u64, D>,
168 choices: &[Array<T, D>],
169) -> FerrayResult<Array<T, IxDyn>> {
170 if choices.is_empty() {
171 return Err(FerrayError::invalid_value("choices must not be empty"));
172 }
173
174 let shape = index_arr.shape();
175 for (i, c) in choices.iter().enumerate() {
176 if c.shape() != shape {
177 return Err(FerrayError::shape_mismatch(format!(
178 "choice[{}] shape {:?} does not match index array shape {:?}",
179 i,
180 c.shape(),
181 shape
182 )));
183 }
184 }
185
186 let n_choices = choices.len();
187 let choice_iters: Vec<Vec<T>> = choices
188 .iter()
189 .map(|c| c.inner.iter().cloned().collect())
190 .collect();
191
192 let mut data = Vec::with_capacity(index_arr.size());
193 for (pos, idx_val) in index_arr.inner.iter().enumerate() {
194 let idx = *idx_val as usize;
195 if idx >= n_choices {
196 return Err(FerrayError::index_out_of_bounds(idx as isize, 0, n_choices));
197 }
198 data.push(choice_iters[idx][pos].clone());
199 }
200
201 let dyn_shape = IxDyn::new(shape);
202 Array::from_vec(dyn_shape, data)
203}
204
205pub fn compress<T: Element, D: Dimension>(
217 condition: &[bool],
218 a: &Array<T, D>,
219 axis: Axis,
220) -> FerrayResult<Array<T, IxDyn>> {
221 let ndim = a.ndim();
222 let ax = axis.index();
223 if ax >= ndim {
224 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
225 }
226 let axis_size = a.shape()[ax];
227 if condition.len() > axis_size {
228 return Err(FerrayError::shape_mismatch(format!(
229 "condition length {} exceeds axis size {}",
230 condition.len(),
231 axis_size
232 )));
233 }
234
235 let indices: Vec<isize> = condition
236 .iter()
237 .enumerate()
238 .filter_map(|(i, &c)| if c { Some(i as isize) } else { None })
239 .collect();
240
241 a.index_select(axis, &indices)
242}
243
244pub fn select<T: Element, D: Dimension>(
257 condlist: &[Array<bool, D>],
258 choicelist: &[Array<T, D>],
259 default: T,
260) -> FerrayResult<Array<T, IxDyn>> {
261 if condlist.len() != choicelist.len() {
262 return Err(FerrayError::invalid_value(format!(
263 "condlist length {} != choicelist length {}",
264 condlist.len(),
265 choicelist.len()
266 )));
267 }
268 if condlist.is_empty() {
269 return Err(FerrayError::invalid_value(
270 "condlist and choicelist must not be empty",
271 ));
272 }
273
274 let shape = condlist[0].shape();
275 for (i, (c, ch)) in condlist.iter().zip(choicelist.iter()).enumerate() {
276 if c.shape() != shape || ch.shape() != shape {
277 return Err(FerrayError::shape_mismatch(format!(
278 "condlist[{}]/choicelist[{}] shape mismatch with reference shape {:?}",
279 i, i, shape
280 )));
281 }
282 }
283
284 let size = condlist[0].size();
285 let mut data = vec![default; size];
286
287 for (cond, choice) in condlist.iter().zip(choicelist.iter()).rev() {
289 for (i, (&c, v)) in cond.inner.iter().zip(choice.inner.iter()).enumerate() {
290 if c {
291 data[i] = v.clone();
292 }
293 }
294 }
295
296 let dyn_shape = IxDyn::new(shape);
297 Array::from_vec(dyn_shape, data)
298}
299
300pub fn indices(dimensions: &[usize]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
312 let ndim = dimensions.len();
313 let total: usize = dimensions.iter().product();
314
315 let mut result = Vec::with_capacity(ndim);
316
317 for ax in 0..ndim {
318 let mut data = Vec::with_capacity(total);
319 for flat_idx in 0..total {
320 let mut rem = flat_idx;
321 let mut idx_for_ax = 0;
322 for (d, &dim_size) in dimensions.iter().enumerate().rev() {
323 let coord = rem % dim_size;
324 rem /= dim_size;
325 if d == ax {
326 idx_for_ax = coord;
327 }
328 }
329 data.push(idx_for_ax as u64);
330 }
331 let dim = IxDyn::new(dimensions);
332 result.push(Array::from_vec(dim, data)?);
333 }
334
335 Ok(result)
336}
337
338pub fn ix_(sequences: &[&[u64]]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
350 let ndim = sequences.len();
351 let mut result = Vec::with_capacity(ndim);
352
353 for (i, seq) in sequences.iter().enumerate() {
354 let mut shape = vec![1usize; ndim];
355 shape[i] = seq.len();
356
357 let data = seq.to_vec();
358 let dim = IxDyn::new(&shape);
359 result.push(Array::from_vec(dim, data)?);
360 }
361
362 Ok(result)
363}
364
/// Index vectors addressing the main diagonal of an `ndim`-dimensional array
/// with side length `n` (NumPy `diag_indices`).
///
/// Returns `ndim` copies of `[0, 1, ..., n - 1]`. Taking element `k` from each
/// vector gives the index `[k, k, ..., k]`.
pub fn diag_indices(n: usize, ndim: usize) -> Vec<Vec<usize>> {
    (0..ndim).map(|_| (0..n).collect()).collect()
}
377
378pub fn diag_indices_from<T: Element, D: Dimension>(
386 a: &Array<T, D>,
387) -> FerrayResult<Vec<Vec<usize>>> {
388 let ndim = a.ndim();
389 if ndim < 2 {
390 return Err(FerrayError::invalid_value(
391 "diag_indices_from requires at least 2 dimensions",
392 ));
393 }
394 let shape = a.shape();
395 let n = shape[0];
396 for &s in &shape[1..] {
397 if s != n {
398 return Err(FerrayError::shape_mismatch(format!(
399 "all dimensions must be equal for diag_indices_from, got {:?}",
400 shape
401 )));
402 }
403 }
404 Ok(diag_indices(n, ndim))
405}
406
/// Row and column indices of the lower triangle of an `n x m` matrix
/// (NumPy `tril_indices`).
///
/// Position `(i, j)` is included when `j <= i + k`:
/// - `k = 0` selects up to and including the main diagonal.
/// - `k > 0` also includes `k` diagonals above it.
/// - `k < 0` stops that many diagonals below it.
///
/// `m` defaults to `n`. Indices are returned in row-major order as
/// `(rows, cols)`. Any `k`, including `isize::MIN`/`isize::MAX`, is accepted.
pub fn tril_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // Last included column on row `i` is `i + k`. Saturate so an extreme
        // `k` cannot overflow (which would panic in debug or wrap in release).
        let last = (i as isize).saturating_add(k);
        if last < 0 {
            continue;
        }
        // Exclusive end of the column range, clamped to the matrix width.
        let end = (last as usize).saturating_add(1).min(m);
        rows.extend(std::iter::repeat(i).take(end));
        cols.extend(0..end);
    }

    (rows, cols)
}
432
/// Row and column indices of the upper triangle of an `n x m` matrix
/// (NumPy `triu_indices`).
///
/// Position `(i, j)` is included when `j >= i + k`:
/// - `k = 0` starts at the main diagonal.
/// - `k > 0` starts that many diagonals above it.
/// - `k < 0` also includes diagonals below it.
///
/// `m` defaults to `n`. Indices are returned in row-major order as
/// `(rows, cols)`. Any `k`, including `isize::MIN`/`isize::MAX`, is accepted.
pub fn triu_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // First included column on row `i` is `i + k`, clamped at 0.
        // Saturate so an extreme `k` cannot overflow.
        let first = (i as isize).saturating_add(k).max(0) as usize;
        if first >= m {
            continue;
        }
        rows.extend(std::iter::repeat(i).take(m - first));
        cols.extend(first..m);
    }

    (rows, cols)
}
452
453pub fn tril_indices_from<T: Element, D: Dimension>(
458 a: &Array<T, D>,
459 k: isize,
460) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
461 let shape = a.shape();
462 if shape.len() != 2 {
463 return Err(FerrayError::invalid_value(
464 "tril_indices_from requires a 2-D array",
465 ));
466 }
467 Ok(tril_indices(shape[0], k, Some(shape[1])))
468}
469
470pub fn triu_indices_from<T: Element, D: Dimension>(
475 a: &Array<T, D>,
476 k: isize,
477) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
478 let shape = a.shape();
479 if shape.len() != 2 {
480 return Err(FerrayError::invalid_value(
481 "triu_indices_from requires a 2-D array",
482 ));
483 }
484 Ok(triu_indices(shape[0], k, Some(shape[1])))
485}
486
487#[allow(clippy::needless_range_loop)]
500pub fn ravel_multi_index(multi_index: &[&[usize]], dims: &[usize]) -> FerrayResult<Vec<usize>> {
501 if multi_index.len() != dims.len() {
502 return Err(FerrayError::invalid_value(format!(
503 "multi_index has {} components but dims has {} dimensions",
504 multi_index.len(),
505 dims.len()
506 )));
507 }
508 if multi_index.is_empty() {
509 return Ok(vec![]);
510 }
511
512 let n = multi_index[0].len();
513 for (i, idx_arr) in multi_index.iter().enumerate() {
514 if idx_arr.len() != n {
515 return Err(FerrayError::invalid_value(format!(
516 "multi_index[{}] has length {} but expected {}",
517 i,
518 idx_arr.len(),
519 n
520 )));
521 }
522 }
523
524 let ndim = dims.len();
526 let mut strides = vec![1usize; ndim];
527 for i in (0..ndim - 1).rev() {
528 strides[i] = strides[i + 1] * dims[i + 1];
529 }
530
531 let mut flat = Vec::with_capacity(n);
532 #[allow(clippy::needless_range_loop)]
533 for pos in 0..n {
534 let mut linear = 0usize;
535 for (d, &dim_size) in dims.iter().enumerate() {
536 let coord = multi_index[d][pos];
537 if coord >= dim_size {
538 return Err(FerrayError::index_out_of_bounds(
539 coord as isize,
540 d,
541 dim_size,
542 ));
543 }
544 linear += coord * strides[d];
545 }
546 flat.push(linear);
547 }
548
549 Ok(flat)
550}
551
552pub fn unravel_index(flat_indices: &[usize], shape: &[usize]) -> FerrayResult<Vec<Vec<usize>>> {
560 let total: usize = shape.iter().product();
561 let ndim = shape.len();
562 let n = flat_indices.len();
563
564 let mut result: Vec<Vec<usize>> = vec![Vec::with_capacity(n); ndim];
565
566 for &flat_idx in flat_indices {
567 if flat_idx >= total {
568 return Err(FerrayError::index_out_of_bounds(
569 flat_idx as isize,
570 0,
571 total,
572 ));
573 }
574 let mut rem = flat_idx;
575 for (d, &dim_size) in shape.iter().enumerate().rev() {
576 result[d].push(rem % dim_size);
577 rem /= dim_size;
578 }
579 }
580
581 Ok(result)
582}
583
584pub fn flatnonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<usize> {
593 let zero = T::zero();
594 a.inner
595 .iter()
596 .enumerate()
597 .filter_map(|(i, val)| if *val != zero { Some(i) } else { None })
598 .collect()
599}
600
601pub fn nonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<Vec<usize>> {
621 let zero = T::zero();
622 let ndim = a.ndim();
623 let mut result: Vec<Vec<usize>> = vec![Vec::new(); ndim];
624 for (idx, val) in a.indexed_iter() {
625 if *val != zero {
626 for (d, &c) in idx.iter().enumerate() {
627 result[d].push(c);
628 }
629 }
630 }
631 result
632}
633
634pub fn argwhere<T: Element + PartialEq, D: Dimension>(
647 a: &Array<T, D>,
648) -> FerrayResult<Array<i64, Ix2>> {
649 let zero = T::zero();
650 let ndim = a.ndim();
651 let mut data: Vec<i64> = Vec::new();
652 let mut count: usize = 0;
653 for (idx, val) in a.indexed_iter() {
654 if *val != zero {
655 for &c in &idx {
656 data.push(c as i64);
657 }
658 count += 1;
659 }
660 }
661 Array::<i64, Ix2>::from_vec(Ix2::new([count, ndim]), data)
662}
663
/// Iterator over every multi-index of a shape, in row-major (C) order.
///
/// Created by [`ndindex`]. Each item is a fresh `Vec<usize>` of length
/// `shape.len()`. A shape containing a zero-length axis yields nothing; the
/// empty shape `[]` yields exactly one empty index.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct NdIndex {
    /// Extent of each axis.
    shape: Vec<usize>,
    /// The index the next call to `next` will return.
    current: Vec<usize>,
    /// Set once every index has been yielded (or at construction for empty shapes).
    done: bool,
}

impl NdIndex {
    fn new(shape: &[usize]) -> Self {
        Self {
            shape: shape.to_vec(),
            current: vec![0; shape.len()],
            // A zero-length axis means there is no index at all.
            done: shape.contains(&0),
        }
    }
}

impl Iterator for NdIndex {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }
        let result = self.current.clone();

        // Odometer increment: bump the last axis, and carry into earlier axes
        // whenever an axis wraps back to 0.
        for (coord, &extent) in self.current.iter_mut().zip(&self.shape).rev() {
            *coord += 1;
            if *coord < extent {
                return Some(result);
            }
            *coord = 0;
        }

        // Every axis wrapped (or there are no axes), so `result` was the final index.
        self.done = true;
        Some(result)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.done {
            return (0, Some(0));
        }
        // `current` is the next index to yield, so its row-major offset equals
        // the number already yielded. After the loop, `total` is the product
        // of all extents.
        let mut yielded = 0usize;
        let mut total = 1usize;
        for (&coord, &extent) in self.current.iter().zip(&self.shape).rev() {
            yielded += coord * total;
            total *= extent;
        }
        let remaining = total - yielded;
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, so `len()` is available to callers.
impl ExactSizeIterator for NdIndex {}

// `done` never resets, so `None` is returned forever once exhausted.
impl std::iter::FusedIterator for NdIndex {}
735
736pub fn ndindex(shape: &[usize]) -> NdIndex {
740 NdIndex::new(shape)
741}
742
743pub fn ndenumerate<'a, T: Element, D: Dimension>(
747 a: &'a Array<T, D>,
748) -> impl Iterator<Item = (Vec<usize>, &'a T)> + 'a {
749 let shape = a.shape().to_vec();
750 let ndim = shape.len();
751 a.inner.iter().enumerate().map(move |(flat_idx, val)| {
752 let mut idx = vec![0usize; ndim];
753 let mut rem = flat_idx;
754 for (d, s) in shape.iter().enumerate().rev() {
755 if *s > 0 {
756 idx[d] = rem % s;
757 rem /= s;
758 }
759 }
760 (idx, val)
761 })
762}
763
764pub fn where_select<T: Element + Copy, D: Dimension>(
780 condition: &Array<bool, D>,
781 x: &Array<T, D>,
782 y: &Array<T, D>,
783) -> FerrayResult<Array<T, D>> {
784 if condition.shape() != x.shape() || condition.shape() != y.shape() {
785 return Err(FerrayError::shape_mismatch(format!(
786 "where_select: condition shape {:?}, x shape {:?}, y shape {:?} must all match",
787 condition.shape(),
788 x.shape(),
789 y.shape()
790 )));
791 }
792 let data: Vec<T> = condition
793 .iter()
794 .zip(x.iter().zip(y.iter()))
795 .map(|(&c, (&xi, &yi))| if c { xi } else { yi })
796 .collect();
797 Array::from_vec(x.dim().clone(), data)
798}
799
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    // ---- take / take_along_axis ----

    #[test]
    fn take_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let taken = take(&arr, &[0, 2, 4], Axis(0)).unwrap();
        assert_eq!(taken.shape(), &[3]);
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![10, 30, 50]);
    }

    #[test]
    fn take_2d_axis1() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        // Columns 0 and 2 of every row; the axis length becomes indices.len().
        let taken = take(&arr, &[0, 2], Axis(1)).unwrap();
        assert_eq!(taken.shape(), &[3, 2]);
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![0, 2, 4, 6, 8, 10]);
    }

    #[test]
    fn take_negative_indices() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        // Negative indices count back from the end of the axis.
        let taken = take(&arr, &[-1, -3], Axis(0)).unwrap();
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![40, 20]);
    }

    // With a flat index list the same positions apply to every row.
    #[test]
    fn take_along_axis_basic() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let taken = take_along_axis(&arr, &[1, 3], Axis(1)).unwrap();
        assert_eq!(taken.shape(), &[3, 2]);
    }

    // ---- put ----

    #[test]
    fn put_flat() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 0, 0, 0, 0]).unwrap();
        arr.put(&[1, 3], &[99, 88]).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[0, 99, 0, 88, 0]);
    }

    #[test]
    fn put_cycling_values() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0; 5]).unwrap();
        // Fewer values than indices: values repeat cyclically.
        arr.put(&[0, 1, 2, 3, 4], &[10, 20]).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[10, 20, 10, 20, 10]);
    }

    #[test]
    fn put_out_of_bounds() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
        assert!(arr.put(&[5], &[1]).is_err());
    }

    // ---- fill_diagonal ----

    #[test]
    fn fill_diagonal_2d() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 3]), vec![0; 9]).unwrap();
        arr.fill_diagonal(1);
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, 0, 0, 0, 1, 0, 0, 0, 1]);
    }

    #[test]
    fn fill_diagonal_rectangular() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 4]), vec![0; 8]).unwrap();
        // Only min(2, 4) = 2 diagonal elements exist: (0,0) and (1,1).
        arr.fill_diagonal(5);
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![5, 0, 0, 0, 0, 5, 0, 0]);
    }

    // ---- choose ----

    #[test]
    fn choose_basic() {
        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([4]), vec![0, 1, 0, 1]).unwrap();
        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![100, 200, 300, 400]).unwrap();
        let result = choose(&idx, &[c0, c1]).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![10, 200, 30, 400]);
    }

    #[test]
    fn choose_out_of_bounds() {
        // Index 2 with only two choices must be rejected.
        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([2]), vec![0, 2]).unwrap();
        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 2]).unwrap();
        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![3, 4]).unwrap();
        assert!(choose(&idx, &[c0, c1]).is_err());
    }

    // ---- compress ----

    #[test]
    fn compress_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let result = compress(&[true, false, true, false, true], &arr, Axis(0)).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![10, 30, 50]);
    }

    #[test]
    fn compress_2d_axis0() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        // Keep rows 0 and 2.
        let result = compress(&[true, false, true], &arr, Axis(0)).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![0, 1, 2, 3, 8, 9, 10, 11]);
    }

    // ---- select ----

    // Each position takes the first matching choice; unmatched positions get the default (0).
    #[test]
    fn select_basic() {
        let c1 =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, false]).unwrap();
        let c2 =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, false]).unwrap();
        let ch1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 1, 1, 1]).unwrap();
        let ch2 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![2, 2, 2, 2]).unwrap();
        let result = select(&[c1, c2], &[ch1, ch2], 0).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 0, 0]);
    }

    // ---- indices ----

    #[test]
    fn indices_2d() {
        let idx = indices(&[2, 3]).unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0].shape(), &[2, 3]);
        assert_eq!(idx[1].shape(), &[2, 3]);
        let rows: Vec<u64> = idx[0].iter().copied().collect();
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1]);
        let cols: Vec<u64> = idx[1].iter().copied().collect();
        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2]);
    }

    // ---- ix_ ----

    // Each output has length 1 on every axis except its own.
    #[test]
    fn ix_basic() {
        let result = ix_(&[&[0, 1], &[2, 3, 4]]).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[2, 1]);
        assert_eq!(result[1].shape(), &[1, 3]);
    }

    // ---- diag_indices / diag_indices_from ----

    #[test]
    fn diag_indices_basic() {
        let idx = diag_indices(3, 2);
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0], vec![0, 1, 2]);
        assert_eq!(idx[1], vec![0, 1, 2]);
    }

    #[test]
    fn diag_indices_from_square() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([4, 4])).unwrap();
        let idx = diag_indices_from(&arr).unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0].len(), 4);
    }

    #[test]
    fn diag_indices_from_not_square() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
        assert!(diag_indices_from(&arr).is_err());
    }

    // ---- tril / triu indices ----

    #[test]
    fn tril_indices_basic() {
        let (rows, cols) = tril_indices(3, 0, None);
        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
    }

    #[test]
    fn triu_indices_basic() {
        let (rows, cols) = triu_indices(3, 0, None);
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 2]);
        assert_eq!(cols, vec![0, 1, 2, 1, 2, 2]);
    }

    #[test]
    fn tril_indices_with_k() {
        // k = 1 also includes the first superdiagonal.
        let (rows, cols) = tril_indices(3, 1, None);
        assert_eq!(rows, vec![0, 0, 1, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 1, 0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn triu_indices_with_negative_k() {
        // k = -1 also includes the first subdiagonal.
        let (rows, cols) = triu_indices(3, -1, None);
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1, 2, 2]);
        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2, 1, 2]);
    }

    #[test]
    fn tril_indices_from_test() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
        let (rows, _cols) = tril_indices_from(&arr, 0).unwrap();
        assert_eq!(rows.len(), 6);
    }

    #[test]
    fn triu_indices_from_test() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
        let (rows, _cols) = triu_indices_from(&arr, 0).unwrap();
        assert_eq!(rows.len(), 6);
    }

    #[test]
    fn tril_indices_rectangular() {
        // A wider matrix (m = 4) adds no lower-triangle entries for k = 0.
        let (rows, cols) = tril_indices(3, 0, Some(4));
        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
    }

    // ---- ravel_multi_index / unravel_index ----

    #[test]
    fn ravel_multi_index_basic() {
        // Points (0,1), (1,2), (2,0) in a 3x4 array -> 1, 6, 8.
        let flat = ravel_multi_index(&[&[0, 1, 2], &[1, 2, 0]], &[3, 4]).unwrap();
        assert_eq!(flat, vec![1, 6, 8]);
    }

    #[test]
    fn ravel_multi_index_3d() {
        // (0,1,2) in 2x3x4: 0*12 + 1*4 + 2*1 = 6.
        let flat = ravel_multi_index(&[&[0], &[1], &[2]], &[2, 3, 4]).unwrap();
        assert_eq!(flat, vec![6]);
    }

    #[test]
    fn ravel_multi_index_out_of_bounds() {
        assert!(ravel_multi_index(&[&[3]], &[3]).is_err());
    }

    #[test]
    fn unravel_index_basic() {
        let coords = unravel_index(&[1, 6, 8], &[3, 4]).unwrap();
        assert_eq!(coords[0], vec![0, 1, 2]);
        assert_eq!(coords[1], vec![1, 2, 0]);
    }

    #[test]
    fn unravel_index_out_of_bounds() {
        // 12 == 3 * 4 is one past the last valid offset.
        assert!(unravel_index(&[12], &[3, 4]).is_err());
    }

    #[test]
    fn ravel_unravel_roundtrip() {
        let dims = &[3, 4, 5];
        let a: &[usize] = &[1, 2];
        let b: &[usize] = &[2, 3];
        let c: &[usize] = &[3, 4];
        let multi: &[&[usize]] = &[a, b, c];
        let flat = ravel_multi_index(multi, dims).unwrap();
        let coords = unravel_index(&flat, dims).unwrap();
        assert_eq!(coords[0], vec![1, 2]);
        assert_eq!(coords[1], vec![2, 3]);
        assert_eq!(coords[2], vec![3, 4]);
    }

    // ---- flatnonzero ----

    #[test]
    fn flatnonzero_basic() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz, vec![1, 3]);
    }

    #[test]
    fn flatnonzero_2d() {
        // Positions are flat row-major offsets, not (row, col) pairs.
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz, vec![1, 3, 5]);
    }

    #[test]
    fn flatnonzero_all_zero() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz.len(), 0);
    }

    // ---- nonzero ----

    #[test]
    fn nonzero_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz.len(), 1);
        assert_eq!(nz[0], vec![1, 3]);
    }

    #[test]
    fn nonzero_2d_yields_row_and_col_indices() {
        // Nonzeros at (0,1), (1,0), (1,2): one vector of rows, one of columns.
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz.len(), 2);
        assert_eq!(nz[0], vec![0, 1, 1]);
        assert_eq!(nz[1], vec![1, 0, 2]);
    }

    #[test]
    fn nonzero_all_zero_returns_empty_per_axis() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz.len(), 2);
        assert!(nz[0].is_empty());
        assert!(nz[1].is_empty());
    }

    #[test]
    fn nonzero_f64_treats_negative_zero_as_zero() {
        // -0.0 == 0.0 under IEEE comparison, so it is not reported.
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![-0.0, 1.5, 0.0, -2.5]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz[0], vec![1, 3]);
    }

    // ---- argwhere ----

    #[test]
    fn argwhere_2d_has_one_row_per_nonzero() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[3, 2]);
        assert_eq!(coords.as_slice().unwrap(), &[0, 1, 1, 0, 1, 2]);
    }

    #[test]
    fn argwhere_1d_is_column_vector() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 7, 0, 9, 3]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[3, 1]);
        assert_eq!(coords.as_slice().unwrap(), &[1, 3, 4]);
    }

    #[test]
    fn argwhere_all_zero_returns_empty() {
        // Zero rows, but the column count still equals ndim.
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[0, 2]);
        assert_eq!(coords.size(), 0);
    }

    // ---- ndindex ----

    #[test]
    fn ndindex_2d() {
        // Row-major order: the last axis varies fastest.
        let indices: Vec<Vec<usize>> = ndindex(&[2, 3]).collect();
        assert_eq!(indices.len(), 6);
        assert_eq!(indices[0], vec![0, 0]);
        assert_eq!(indices[1], vec![0, 1]);
        assert_eq!(indices[2], vec![0, 2]);
        assert_eq!(indices[3], vec![1, 0]);
        assert_eq!(indices[4], vec![1, 1]);
        assert_eq!(indices[5], vec![1, 2]);
    }

    #[test]
    fn ndindex_1d() {
        let indices: Vec<Vec<usize>> = ndindex(&[4]).collect();
        assert_eq!(indices.len(), 4);
        assert_eq!(indices[0], vec![0]);
        assert_eq!(indices[3], vec![3]);
    }

    #[test]
    fn ndindex_empty() {
        // A zero-length axis means there are no indices.
        let indices: Vec<Vec<usize>> = ndindex(&[0]).collect();
        assert_eq!(indices.len(), 0);
    }

    #[test]
    fn ndindex_scalar() {
        // The empty shape has exactly one (empty) index.
        let indices: Vec<Vec<usize>> = ndindex(&[]).collect();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], Vec::<usize>::new());
    }

    // ---- ndenumerate ----

    #[test]
    fn ndenumerate_2d() {
        let arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
        let items: Vec<(Vec<usize>, &i32)> = ndenumerate(&arr).collect();
        assert_eq!(items.len(), 6);
        assert_eq!(items[0], (vec![0, 0], &10));
        assert_eq!(items[1], (vec![0, 1], &20));
        assert_eq!(items[5], (vec![1, 2], &60));
    }

    // ---- put_along_axis ----

    // Rows 0 and 2 are overwritten in order by the eight values; row 1 is untouched.
    #[test]
    fn put_along_axis_basic() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), vec![0; 12]).unwrap();
        let values =
            Array::<i32, IxDyn>::from_vec(IxDyn::new(&[8]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        arr.put_along_axis(&[0, 2], &values, Axis(0)).unwrap();
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3, 4, 0, 0, 0, 0, 5, 6, 7, 8]);
    }

    // ---- where_select ----

    #[test]
    fn where_basic() {
        let cond =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[1.0, 20.0, 3.0, 40.0]);
    }

    #[test]
    fn where_all_true() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn where_all_false() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[10, 20, 30]);
    }

    #[test]
    fn where_shape_mismatch() {
        // x has length 4 while condition and y have length 3.
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true; 3]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0; 4]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0; 3]).unwrap();
        assert!(where_select(&cond, &x, &y).is_err());
    }

    #[test]
    fn where_2d() {
        let cond =
            Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![true, false, false, true]).unwrap();
        let x = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let y = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![10, 20, 30, 40]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![1, 20, 30, 4]);
    }
}