1use crate::array::owned::Array;
13use crate::dimension::{Axis, Dimension, IxDyn};
14use crate::dtype::Element;
15use crate::error::{FerrayError, FerrayResult};
16
17fn normalize_index(index: isize, size: usize, axis: usize) -> FerrayResult<usize> {
19 if index < 0 {
20 let pos = size as isize + index;
21 if pos < 0 {
22 return Err(FerrayError::index_out_of_bounds(index, axis, size));
23 }
24 Ok(pos as usize)
25 } else {
26 let idx = index as usize;
27 if idx >= size {
28 return Err(FerrayError::index_out_of_bounds(index, axis, size));
29 }
30 Ok(idx)
31 }
32}
33
34pub fn take<T: Element, D: Dimension>(
46 a: &Array<T, D>,
47 indices: &[isize],
48 axis: Axis,
49) -> FerrayResult<Array<T, IxDyn>> {
50 a.index_select(axis, indices)
51}
52
53pub fn take_along_axis<T: Element, D: Dimension>(
63 a: &Array<T, D>,
64 indices: &[isize],
65 axis: Axis,
66) -> FerrayResult<Array<T, IxDyn>> {
67 a.index_select(axis, indices)
68}
69
70impl<T: Element, D: Dimension> Array<T, D> {
75 pub fn put(&mut self, indices: &[isize], values: &[T]) -> FerrayResult<()> {
85 if values.is_empty() {
86 return Err(FerrayError::invalid_value("values must not be empty"));
87 }
88 let size = self.size();
89 let normalized: Vec<usize> = indices
90 .iter()
91 .map(|&idx| normalize_index(idx, size, 0))
92 .collect::<FerrayResult<Vec<_>>>()?;
93
94 let mut flat: Vec<&mut T> = self.inner.iter_mut().collect();
95
96 for (i, &idx) in normalized.iter().enumerate() {
97 let val_idx = i % values.len();
98 *flat[idx] = values[val_idx].clone();
99 }
100 Ok(())
101 }
102
103 pub fn put_along_axis(
112 &mut self,
113 indices: &[isize],
114 values: &Array<T, IxDyn>,
115 axis: Axis,
116 ) -> FerrayResult<()>
117 where
118 D::NdarrayDim: ndarray::RemoveAxis,
119 {
120 let ndim = self.ndim();
121 let ax = axis.index();
122 if ax >= ndim {
123 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
124 }
125 let axis_size = self.shape()[ax];
126
127 let normalized: Vec<usize> = indices
128 .iter()
129 .map(|&idx| normalize_index(idx, axis_size, ax))
130 .collect::<FerrayResult<Vec<_>>>()?;
131
132 let nd_axis = ndarray::Axis(ax);
133 let mut val_iter = values.inner.iter();
134
135 for &idx in &normalized {
136 let mut sub = self.inner.index_axis_mut(nd_axis, idx);
137 for elem in sub.iter_mut() {
138 if let Some(v) = val_iter.next() {
139 *elem = v.clone();
140 }
141 }
142 }
143 Ok(())
144 }
145
146 pub fn fill_diagonal(&mut self, val: T) {
153 let shape = self.shape().to_vec();
154 if shape.is_empty() {
155 return;
156 }
157 let min_dim = *shape.iter().min().unwrap_or(&0);
158 let ndim = shape.len();
159
160 for i in 0..min_dim {
161 let idx: Vec<usize> = vec![i; ndim];
162 let nd_idx = ndarray::IxDyn(&idx);
163 let mut dyn_view = self.inner.view_mut().into_dyn();
164 dyn_view[nd_idx] = val.clone();
165 }
166 }
167}
168
169pub fn choose<T: Element, D: Dimension>(
183 index_arr: &Array<u64, D>,
184 choices: &[Array<T, D>],
185) -> FerrayResult<Array<T, IxDyn>> {
186 if choices.is_empty() {
187 return Err(FerrayError::invalid_value("choices must not be empty"));
188 }
189
190 let shape = index_arr.shape();
191 for (i, c) in choices.iter().enumerate() {
192 if c.shape() != shape {
193 return Err(FerrayError::shape_mismatch(format!(
194 "choice[{}] shape {:?} does not match index array shape {:?}",
195 i,
196 c.shape(),
197 shape
198 )));
199 }
200 }
201
202 let n_choices = choices.len();
203 let choice_iters: Vec<Vec<T>> = choices
204 .iter()
205 .map(|c| c.inner.iter().cloned().collect())
206 .collect();
207
208 let mut data = Vec::with_capacity(index_arr.size());
209 for (pos, idx_val) in index_arr.inner.iter().enumerate() {
210 let idx = *idx_val as usize;
211 if idx >= n_choices {
212 return Err(FerrayError::index_out_of_bounds(idx as isize, 0, n_choices));
213 }
214 data.push(choice_iters[idx][pos].clone());
215 }
216
217 let dyn_shape = IxDyn::new(shape);
218 Array::from_vec(dyn_shape, data)
219}
220
221pub fn compress<T: Element, D: Dimension>(
233 condition: &[bool],
234 a: &Array<T, D>,
235 axis: Axis,
236) -> FerrayResult<Array<T, IxDyn>> {
237 let ndim = a.ndim();
238 let ax = axis.index();
239 if ax >= ndim {
240 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
241 }
242 let axis_size = a.shape()[ax];
243 if condition.len() > axis_size {
244 return Err(FerrayError::shape_mismatch(format!(
245 "condition length {} exceeds axis size {}",
246 condition.len(),
247 axis_size
248 )));
249 }
250
251 let indices: Vec<isize> = condition
252 .iter()
253 .enumerate()
254 .filter_map(|(i, &c)| if c { Some(i as isize) } else { None })
255 .collect();
256
257 a.index_select(axis, &indices)
258}
259
260pub fn select<T: Element, D: Dimension>(
273 condlist: &[Array<bool, D>],
274 choicelist: &[Array<T, D>],
275 default: T,
276) -> FerrayResult<Array<T, IxDyn>> {
277 if condlist.len() != choicelist.len() {
278 return Err(FerrayError::invalid_value(format!(
279 "condlist length {} != choicelist length {}",
280 condlist.len(),
281 choicelist.len()
282 )));
283 }
284 if condlist.is_empty() {
285 return Err(FerrayError::invalid_value(
286 "condlist and choicelist must not be empty",
287 ));
288 }
289
290 let shape = condlist[0].shape();
291 for (i, (c, ch)) in condlist.iter().zip(choicelist.iter()).enumerate() {
292 if c.shape() != shape || ch.shape() != shape {
293 return Err(FerrayError::shape_mismatch(format!(
294 "condlist[{}]/choicelist[{}] shape mismatch with reference shape {:?}",
295 i, i, shape
296 )));
297 }
298 }
299
300 let size = condlist[0].size();
301 let mut data = vec![default; size];
302
303 for (cond, choice) in condlist.iter().zip(choicelist.iter()).rev() {
305 for (i, (&c, v)) in cond.inner.iter().zip(choice.inner.iter()).enumerate() {
306 if c {
307 data[i] = v.clone();
308 }
309 }
310 }
311
312 let dyn_shape = IxDyn::new(shape);
313 Array::from_vec(dyn_shape, data)
314}
315
316pub fn indices(dimensions: &[usize]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
328 let ndim = dimensions.len();
329 let total: usize = dimensions.iter().product();
330
331 let mut result = Vec::with_capacity(ndim);
332
333 for ax in 0..ndim {
334 let mut data = Vec::with_capacity(total);
335 for flat_idx in 0..total {
336 let mut rem = flat_idx;
337 let mut idx_for_ax = 0;
338 for (d, &dim_size) in dimensions.iter().enumerate().rev() {
339 let coord = rem % dim_size;
340 rem /= dim_size;
341 if d == ax {
342 idx_for_ax = coord;
343 }
344 }
345 data.push(idx_for_ax as u64);
346 }
347 let dim = IxDyn::new(dimensions);
348 result.push(Array::from_vec(dim, data)?);
349 }
350
351 Ok(result)
352}
353
354pub fn ix_(sequences: &[&[u64]]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
366 let ndim = sequences.len();
367 let mut result = Vec::with_capacity(ndim);
368
369 for (i, seq) in sequences.iter().enumerate() {
370 let mut shape = vec![1usize; ndim];
371 shape[i] = seq.len();
372
373 let data = seq.to_vec();
374 let dim = IxDyn::new(&shape);
375 result.push(Array::from_vec(dim, data)?);
376 }
377
378 Ok(result)
379}
380
/// Index lists addressing the main diagonal of an `n`-wide, `ndim`-dimensional
/// array.
///
/// Returns `ndim` copies of `[0, 1, ..., n-1]`, one per axis.
pub fn diag_indices(n: usize, ndim: usize) -> Vec<Vec<usize>> {
    (0..ndim).map(|_| (0..n).collect()).collect()
}
393
394pub fn diag_indices_from<T: Element, D: Dimension>(
402 a: &Array<T, D>,
403) -> FerrayResult<Vec<Vec<usize>>> {
404 let ndim = a.ndim();
405 if ndim < 2 {
406 return Err(FerrayError::invalid_value(
407 "diag_indices_from requires at least 2 dimensions",
408 ));
409 }
410 let shape = a.shape();
411 let n = shape[0];
412 for &s in &shape[1..] {
413 if s != n {
414 return Err(FerrayError::shape_mismatch(format!(
415 "all dimensions must be equal for diag_indices_from, got {:?}",
416 shape
417 )));
418 }
419 }
420 Ok(diag_indices(n, ndim))
421}
422
/// Row and column indices of the lower triangle of an `n x m` matrix.
///
/// Position `(i, j)` is included when `j <= i + k`. `k = 0` is the main
/// diagonal, positive `k` extends above it, and negative `k` stays below it.
/// `m` defaults to `n`. Pairs are produced row by row, columns ascending.
pub fn tril_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let n_cols = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for row in 0..n {
        // Last admissible column for this row (may be negative, i.e. none).
        let limit = row as isize + k;
        // The condition is monotone in the column, so stop at the first miss.
        for col in (0..n_cols).take_while(|&c| (c as isize) <= limit) {
            rows.push(row);
            cols.push(col);
        }
    }

    (rows, cols)
}
448
/// Row and column indices of the upper triangle of an `n x m` matrix.
///
/// Position `(i, j)` is included when `j >= i + k`. `k = 0` is the main
/// diagonal, positive `k` starts above it, and negative `k` extends below
/// it. `m` defaults to `n`. Pairs are produced row by row, columns ascending.
pub fn triu_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let n_cols = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for row in 0..n {
        // First admissible column for this row.
        let start = row as isize + k;
        // The condition is monotone in the column, so skip the leading misses.
        for col in (0..n_cols).skip_while(|&c| (c as isize) < start) {
            rows.push(row);
            cols.push(col);
        }
    }

    (rows, cols)
}
468
469pub fn tril_indices_from<T: Element, D: Dimension>(
474 a: &Array<T, D>,
475 k: isize,
476) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
477 let shape = a.shape();
478 if shape.len() != 2 {
479 return Err(FerrayError::invalid_value(
480 "tril_indices_from requires a 2-D array",
481 ));
482 }
483 Ok(tril_indices(shape[0], k, Some(shape[1])))
484}
485
486pub fn triu_indices_from<T: Element, D: Dimension>(
491 a: &Array<T, D>,
492 k: isize,
493) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
494 let shape = a.shape();
495 if shape.len() != 2 {
496 return Err(FerrayError::invalid_value(
497 "triu_indices_from requires a 2-D array",
498 ));
499 }
500 Ok(triu_indices(shape[0], k, Some(shape[1])))
501}
502
503#[allow(clippy::needless_range_loop)]
516pub fn ravel_multi_index(multi_index: &[&[usize]], dims: &[usize]) -> FerrayResult<Vec<usize>> {
517 if multi_index.len() != dims.len() {
518 return Err(FerrayError::invalid_value(format!(
519 "multi_index has {} components but dims has {} dimensions",
520 multi_index.len(),
521 dims.len()
522 )));
523 }
524 if multi_index.is_empty() {
525 return Ok(vec![]);
526 }
527
528 let n = multi_index[0].len();
529 for (i, idx_arr) in multi_index.iter().enumerate() {
530 if idx_arr.len() != n {
531 return Err(FerrayError::invalid_value(format!(
532 "multi_index[{}] has length {} but expected {}",
533 i,
534 idx_arr.len(),
535 n
536 )));
537 }
538 }
539
540 let ndim = dims.len();
542 let mut strides = vec![1usize; ndim];
543 for i in (0..ndim - 1).rev() {
544 strides[i] = strides[i + 1] * dims[i + 1];
545 }
546
547 let mut flat = Vec::with_capacity(n);
548 #[allow(clippy::needless_range_loop)]
549 for pos in 0..n {
550 let mut linear = 0usize;
551 for (d, &dim_size) in dims.iter().enumerate() {
552 let coord = multi_index[d][pos];
553 if coord >= dim_size {
554 return Err(FerrayError::index_out_of_bounds(
555 coord as isize,
556 d,
557 dim_size,
558 ));
559 }
560 linear += coord * strides[d];
561 }
562 flat.push(linear);
563 }
564
565 Ok(flat)
566}
567
568pub fn unravel_index(flat_indices: &[usize], shape: &[usize]) -> FerrayResult<Vec<Vec<usize>>> {
576 let total: usize = shape.iter().product();
577 let ndim = shape.len();
578 let n = flat_indices.len();
579
580 let mut result: Vec<Vec<usize>> = vec![Vec::with_capacity(n); ndim];
581
582 for &flat_idx in flat_indices {
583 if flat_idx >= total {
584 return Err(FerrayError::index_out_of_bounds(
585 flat_idx as isize,
586 0,
587 total,
588 ));
589 }
590 let mut rem = flat_idx;
591 for (d, &dim_size) in shape.iter().enumerate().rev() {
592 result[d].push(rem % dim_size);
593 rem /= dim_size;
594 }
595 }
596
597 Ok(result)
598}
599
600pub fn flatnonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<usize> {
609 let zero = T::zero();
610 a.inner
611 .iter()
612 .enumerate()
613 .filter_map(|(i, val)| if *val != zero { Some(i) } else { None })
614 .collect()
615}
616
/// Iterator over every multi-dimensional index of a shape, in C (row-major)
/// order: the last axis varies fastest.
///
/// Each item is a freshly allocated `Vec<usize>` of length `shape.len()`.
/// A shape containing a zero-length axis yields nothing. An empty shape
/// (0-d) yields exactly one empty index.
#[derive(Debug, Clone)]
pub struct NdIndex {
    /// Extent of each axis.
    shape: Vec<usize>,
    /// The next index to yield (meaningful only while `done` is false).
    current: Vec<usize>,
    /// Set once every index has been yielded, or immediately for an empty
    /// iteration space. Once set it never clears.
    done: bool,
}

impl NdIndex {
    fn new(shape: &[usize]) -> Self {
        Self {
            shape: shape.to_vec(),
            current: vec![0; shape.len()],
            done: shape.contains(&0),
        }
    }
}

impl Iterator for NdIndex {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        let result = self.current.clone();

        // Odometer increment: bump the last axis and carry into earlier axes
        // only while an axis wraps around.
        let mut carry = true;
        for (c, &s) in self.current.iter_mut().zip(&self.shape).rev() {
            *c += 1;
            if *c < s {
                carry = false;
                break;
            }
            *c = 0;
        }
        if carry {
            // Every axis wrapped (or there are no axes): iteration is complete.
            self.done = true;
        }

        Some(result)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.done {
            return (0, Some(0));
        }
        let total: usize = self.shape.iter().product();
        // The linear position of `current` equals the number already yielded.
        let yielded = self
            .current
            .iter()
            .zip(&self.shape)
            .fold(0usize, |acc, (&c, &s)| acc * s + c);
        let remaining = total - yielded;
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, so the length is known up front.
impl ExactSizeIterator for NdIndex {}

// `done` is sticky, so `None` is returned forever after exhaustion.
impl std::iter::FusedIterator for NdIndex {}
688
689pub fn ndindex(shape: &[usize]) -> NdIndex {
693 NdIndex::new(shape)
694}
695
696pub fn ndenumerate<'a, T: Element, D: Dimension>(
700 a: &'a Array<T, D>,
701) -> impl Iterator<Item = (Vec<usize>, &'a T)> + 'a {
702 let shape = a.shape().to_vec();
703 let ndim = shape.len();
704 a.inner.iter().enumerate().map(move |(flat_idx, val)| {
705 let mut idx = vec![0usize; ndim];
706 let mut rem = flat_idx;
707 for (d, s) in shape.iter().enumerate().rev() {
708 if *s > 0 {
709 idx[d] = rem % s;
710 rem /= s;
711 }
712 }
713 (idx, val)
714 })
715}
716
// Unit tests for the indexing helpers in this module. Expected values are
// written out by hand for small arrays.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    // ---------------------------------------------------------------
    // take / take_along_axis
    // ---------------------------------------------------------------

    #[test]
    fn take_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let taken = take(&arr, &[0, 2, 4], Axis(0)).unwrap();
        assert_eq!(taken.shape(), &[3]);
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![10, 30, 50]);
    }

    // Selecting columns 0 and 2 of a 3x4 matrix keeps every row.
    #[test]
    fn take_2d_axis1() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let taken = take(&arr, &[0, 2], Axis(1)).unwrap();
        assert_eq!(taken.shape(), &[3, 2]);
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![0, 2, 4, 6, 8, 10]);
    }

    // Negative indices count from the end of the axis.
    #[test]
    fn take_negative_indices() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let taken = take(&arr, &[-1, -3], Axis(0)).unwrap();
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![40, 20]);
    }

    #[test]
    fn take_along_axis_basic() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let taken = take_along_axis(&arr, &[1, 3], Axis(1)).unwrap();
        assert_eq!(taken.shape(), &[3, 2]);
    }

    // ---------------------------------------------------------------
    // put
    // ---------------------------------------------------------------

    #[test]
    fn put_flat() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 0, 0, 0, 0]).unwrap();
        arr.put(&[1, 3], &[99, 88]).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[0, 99, 0, 88, 0]);
    }

    // A shorter `values` slice is repeated cyclically across the indices.
    #[test]
    fn put_cycling_values() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0; 5]).unwrap();
        arr.put(&[0, 1, 2, 3, 4], &[10, 20]).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[10, 20, 10, 20, 10]);
    }

    #[test]
    fn put_out_of_bounds() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
        assert!(arr.put(&[5], &[1]).is_err());
    }

    // ---------------------------------------------------------------
    // fill_diagonal
    // ---------------------------------------------------------------

    #[test]
    fn fill_diagonal_2d() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 3]), vec![0; 9]).unwrap();
        arr.fill_diagonal(1);
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, 0, 0, 0, 1, 0, 0, 0, 1]);
    }

    // Only min(rows, cols) diagonal entries are written.
    #[test]
    fn fill_diagonal_rectangular() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 4]), vec![0; 8]).unwrap();
        arr.fill_diagonal(5);
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![5, 0, 0, 0, 0, 5, 0, 0]);
    }

    // ---------------------------------------------------------------
    // choose
    // ---------------------------------------------------------------

    #[test]
    fn choose_basic() {
        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([4]), vec![0, 1, 0, 1]).unwrap();
        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![100, 200, 300, 400]).unwrap();
        let result = choose(&idx, &[c0, c1]).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![10, 200, 30, 400]);
    }

    // Index 2 is invalid when only two choices exist.
    #[test]
    fn choose_out_of_bounds() {
        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([2]), vec![0, 2]).unwrap();
        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 2]).unwrap();
        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![3, 4]).unwrap();
        assert!(choose(&idx, &[c0, c1]).is_err());
    }

    // ---------------------------------------------------------------
    // compress
    // ---------------------------------------------------------------

    #[test]
    fn compress_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let result = compress(&[true, false, true, false, true], &arr, Axis(0)).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![10, 30, 50]);
    }

    // Keeping rows 0 and 2 of a 3x4 matrix.
    #[test]
    fn compress_2d_axis0() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let result = compress(&[true, false, true], &arr, Axis(0)).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![0, 1, 2, 3, 8, 9, 10, 11]);
    }

    // ---------------------------------------------------------------
    // select
    // ---------------------------------------------------------------

    // Positions matched by no condition receive the default (0).
    #[test]
    fn select_basic() {
        let c1 =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, false]).unwrap();
        let c2 =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, false]).unwrap();
        let ch1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 1, 1, 1]).unwrap();
        let ch2 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![2, 2, 2, 2]).unwrap();
        let result = select(&[c1, c2], &[ch1, ch2], 0).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 0, 0]);
    }

    // ---------------------------------------------------------------
    // indices / ix_
    // ---------------------------------------------------------------

    #[test]
    fn indices_2d() {
        let idx = indices(&[2, 3]).unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0].shape(), &[2, 3]);
        assert_eq!(idx[1].shape(), &[2, 3]);
        let rows: Vec<u64> = idx[0].iter().copied().collect();
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1]);
        let cols: Vec<u64> = idx[1].iter().copied().collect();
        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2]);
    }

    // Each output is a singleton along every axis except its own.
    #[test]
    fn ix_basic() {
        let result = ix_(&[&[0, 1], &[2, 3, 4]]).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[2, 1]);
        assert_eq!(result[1].shape(), &[1, 3]);
    }

    // ---------------------------------------------------------------
    // diag_indices / diag_indices_from
    // ---------------------------------------------------------------

    #[test]
    fn diag_indices_basic() {
        let idx = diag_indices(3, 2);
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0], vec![0, 1, 2]);
        assert_eq!(idx[1], vec![0, 1, 2]);
    }

    #[test]
    fn diag_indices_from_square() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([4, 4])).unwrap();
        let idx = diag_indices_from(&arr).unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0].len(), 4);
    }

    #[test]
    fn diag_indices_from_not_square() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
        assert!(diag_indices_from(&arr).is_err());
    }

    // ---------------------------------------------------------------
    // tril / triu indices
    // ---------------------------------------------------------------

    #[test]
    fn tril_indices_basic() {
        let (rows, cols) = tril_indices(3, 0, None);
        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
    }

    #[test]
    fn triu_indices_basic() {
        let (rows, cols) = triu_indices(3, 0, None);
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 2]);
        assert_eq!(cols, vec![0, 1, 2, 1, 2, 2]);
    }

    // k = 1 also includes the first super-diagonal.
    #[test]
    fn tril_indices_with_k() {
        let (rows, cols) = tril_indices(3, 1, None);
        assert_eq!(rows, vec![0, 0, 1, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 1, 0, 1, 2, 0, 1, 2]);
    }

    // k = -1 also includes the first sub-diagonal.
    #[test]
    fn triu_indices_with_negative_k() {
        let (rows, cols) = triu_indices(3, -1, None);
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1, 2, 2]);
        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2, 1, 2]);
    }

    #[test]
    fn tril_indices_from_test() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
        let (rows, _cols) = tril_indices_from(&arr, 0).unwrap();
        assert_eq!(rows.len(), 6);
    }

    #[test]
    fn triu_indices_from_test() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
        let (rows, _cols) = triu_indices_from(&arr, 0).unwrap();
        assert_eq!(rows.len(), 6);
    }

    // An extra column beyond the diagonal adds nothing to the lower triangle.
    #[test]
    fn tril_indices_rectangular() {
        let (rows, cols) = tril_indices(3, 0, Some(4));
        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
    }

    // ---------------------------------------------------------------
    // ravel_multi_index / unravel_index
    // ---------------------------------------------------------------

    #[test]
    fn ravel_multi_index_basic() {
        let flat = ravel_multi_index(&[&[0, 1, 2], &[1, 2, 0]], &[3, 4]).unwrap();
        assert_eq!(flat, vec![1, 6, 8]);
    }

    // (0, 1, 2) in shape (2, 3, 4) -> 0*12 + 1*4 + 2 = 6.
    #[test]
    fn ravel_multi_index_3d() {
        let flat = ravel_multi_index(&[&[0], &[1], &[2]], &[2, 3, 4]).unwrap();
        assert_eq!(flat, vec![6]);
    }

    #[test]
    fn ravel_multi_index_out_of_bounds() {
        assert!(ravel_multi_index(&[&[3]], &[3]).is_err());
    }

    #[test]
    fn unravel_index_basic() {
        let coords = unravel_index(&[1, 6, 8], &[3, 4]).unwrap();
        assert_eq!(coords[0], vec![0, 1, 2]);
        assert_eq!(coords[1], vec![1, 2, 0]);
    }

    // 12 == 3 * 4 is one past the last valid flat index.
    #[test]
    fn unravel_index_out_of_bounds() {
        assert!(unravel_index(&[12], &[3, 4]).is_err());
    }

    #[test]
    fn ravel_unravel_roundtrip() {
        let dims = &[3, 4, 5];
        let a: &[usize] = &[1, 2];
        let b: &[usize] = &[2, 3];
        let c: &[usize] = &[3, 4];
        let multi: &[&[usize]] = &[a, b, c];
        let flat = ravel_multi_index(multi, dims).unwrap();
        let coords = unravel_index(&flat, dims).unwrap();
        assert_eq!(coords[0], vec![1, 2]);
        assert_eq!(coords[1], vec![2, 3]);
        assert_eq!(coords[2], vec![3, 4]);
    }

    // ---------------------------------------------------------------
    // flatnonzero
    // ---------------------------------------------------------------

    #[test]
    fn flatnonzero_basic() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz, vec![1, 3]);
    }

    // Positions are flat (row-major) offsets, not 2-D coordinates.
    #[test]
    fn flatnonzero_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz, vec![1, 3, 5]);
    }

    #[test]
    fn flatnonzero_all_zero() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz.len(), 0);
    }

    // ---------------------------------------------------------------
    // ndindex / ndenumerate
    // ---------------------------------------------------------------

    // The last axis varies fastest.
    #[test]
    fn ndindex_2d() {
        let indices: Vec<Vec<usize>> = ndindex(&[2, 3]).collect();
        assert_eq!(indices.len(), 6);
        assert_eq!(indices[0], vec![0, 0]);
        assert_eq!(indices[1], vec![0, 1]);
        assert_eq!(indices[2], vec![0, 2]);
        assert_eq!(indices[3], vec![1, 0]);
        assert_eq!(indices[4], vec![1, 1]);
        assert_eq!(indices[5], vec![1, 2]);
    }

    #[test]
    fn ndindex_1d() {
        let indices: Vec<Vec<usize>> = ndindex(&[4]).collect();
        assert_eq!(indices.len(), 4);
        assert_eq!(indices[0], vec![0]);
        assert_eq!(indices[3], vec![3]);
    }

    // A zero-length axis means there are no indices at all.
    #[test]
    fn ndindex_empty() {
        let indices: Vec<Vec<usize>> = ndindex(&[0]).collect();
        assert_eq!(indices.len(), 0);
    }

    // A 0-d shape has exactly one (empty) index.
    #[test]
    fn ndindex_scalar() {
        let indices: Vec<Vec<usize>> = ndindex(&[]).collect();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], Vec::<usize>::new());
    }

    #[test]
    fn ndenumerate_2d() {
        let arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
        let items: Vec<(Vec<usize>, &i32)> = ndenumerate(&arr).collect();
        assert_eq!(items.len(), 6);
        assert_eq!(items[0], (vec![0, 0], &10));
        assert_eq!(items[1], (vec![0, 1], &20));
        assert_eq!(items[5], (vec![1, 2], &60));
    }

    // ---------------------------------------------------------------
    // put_along_axis
    // ---------------------------------------------------------------

    // Rows 0 and 2 are overwritten in order with 1..=4 and 5..=8; row 1 stays 0.
    #[test]
    fn put_along_axis_basic() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), vec![0; 12]).unwrap();
        let values =
            Array::<i32, IxDyn>::from_vec(IxDyn::new(&[8]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        arr.put_along_axis(&[0, 2], &values, Axis(0)).unwrap();
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3, 4, 0, 0, 0, 0, 5, 6, 7, 8]);
    }
}