// ferray_core/indexing/extended.rs

// ferray-core: Extended indexing functions (REQ-15a)
//
// take, take_along_axis, put, put_along_axis, choose, compress, select,
// indices, ix_, diag_indices, diag_indices_from, tril_indices, triu_indices,
// tril_indices_from, triu_indices_from, ravel_multi_index, unravel_index,
// flatnonzero, fill_diagonal, ndindex (iterator), ndenumerate (iterator)
//
// Most index-returning functions (diag_indices, tril/triu_indices,
// ravel_multi_index, unravel_index, flatnonzero, ndindex) return plain
// Vec<usize> collections such as Vec<Vec<usize>> or (Vec<usize>, Vec<usize>)
// because usize is not an Element type; callers can wrap these into arrays of
// u64 or i64 if needed. `indices` and `ix_` are the exceptions: they return
// u64 arrays directly.
11
12use crate::array::owned::Array;
13use crate::dimension::{Axis, Dimension, IxDyn};
14use crate::dtype::Element;
15use crate::error::{FerrayError, FerrayResult};
16
17/// Normalize a potentially negative index, returning an error on out-of-bounds.
18fn normalize_index(index: isize, size: usize, axis: usize) -> FerrayResult<usize> {
19    if index < 0 {
20        let pos = size as isize + index;
21        if pos < 0 {
22            return Err(FerrayError::index_out_of_bounds(index, axis, size));
23        }
24        Ok(pos as usize)
25    } else {
26        let idx = index as usize;
27        if idx >= size {
28            return Err(FerrayError::index_out_of_bounds(index, axis, size));
29        }
30        Ok(idx)
31    }
32}
33
34// ===========================================================================
35// take / take_along_axis
36// ===========================================================================
37
38/// Take elements from an array along an axis.
39///
40/// Equivalent to `np.take(a, indices, axis)`. Returns a copy.
41///
42/// # Errors
43/// - `AxisOutOfBounds` if `axis >= ndim`
44/// - `IndexOutOfBounds` if any index is out of range
45pub fn take<T: Element, D: Dimension>(
46    a: &Array<T, D>,
47    indices: &[isize],
48    axis: Axis,
49) -> FerrayResult<Array<T, IxDyn>> {
50    a.index_select(axis, indices)
51}
52
53/// Take values from an array along an axis using an index slice.
54///
55/// Similar to `np.take_along_axis`. The `indices` slice contains
56/// indices into `a` along the specified axis. The result replaces
57/// the `axis` dimension with the indices dimension.
58///
59/// # Errors
60/// - `AxisOutOfBounds` if `axis >= ndim`
61/// - `IndexOutOfBounds` if any index is out of range
62pub fn take_along_axis<T: Element, D: Dimension>(
63    a: &Array<T, D>,
64    indices: &[isize],
65    axis: Axis,
66) -> FerrayResult<Array<T, IxDyn>> {
67    a.index_select(axis, indices)
68}
69
70// ===========================================================================
71// put / put_along_axis
72// ===========================================================================
73
74impl<T: Element, D: Dimension> Array<T, D> {
75    /// Put values into the flattened array at the given indices.
76    ///
77    /// Equivalent to `np.put(a, ind, v)`. Modifies the array in-place.
78    /// Indices refer to the flattened (row-major) array.
79    /// Values are cycled if fewer than indices.
80    ///
81    /// # Errors
82    /// - `IndexOutOfBounds` if any index is out of range
83    /// - `InvalidValue` if values is empty
84    pub fn put(&mut self, indices: &[isize], values: &[T]) -> FerrayResult<()> {
85        if values.is_empty() {
86            return Err(FerrayError::invalid_value("values must not be empty"));
87        }
88        let size = self.size();
89        let normalized: Vec<usize> = indices
90            .iter()
91            .map(|&idx| normalize_index(idx, size, 0))
92            .collect::<FerrayResult<Vec<_>>>()?;
93
94        let mut flat: Vec<&mut T> = self.inner.iter_mut().collect();
95
96        for (i, &idx) in normalized.iter().enumerate() {
97            let val_idx = i % values.len();
98            *flat[idx] = values[val_idx].clone();
99        }
100        Ok(())
101    }
102
103    /// Put values along an axis at specified indices.
104    ///
105    /// For each index position along `axis`, assigns the values from the
106    /// corresponding sub-array of `values`.
107    ///
108    /// # Errors
109    /// - `AxisOutOfBounds` if `axis >= ndim`
110    /// - `IndexOutOfBounds` if any index is out of range
111    pub fn put_along_axis(
112        &mut self,
113        indices: &[isize],
114        values: &Array<T, IxDyn>,
115        axis: Axis,
116    ) -> FerrayResult<()>
117    where
118        D::NdarrayDim: ndarray::RemoveAxis,
119    {
120        let ndim = self.ndim();
121        let ax = axis.index();
122        if ax >= ndim {
123            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
124        }
125        let axis_size = self.shape()[ax];
126
127        let normalized: Vec<usize> = indices
128            .iter()
129            .map(|&idx| normalize_index(idx, axis_size, ax))
130            .collect::<FerrayResult<Vec<_>>>()?;
131
132        let nd_axis = ndarray::Axis(ax);
133        let mut val_iter = values.inner.iter();
134
135        for &idx in &normalized {
136            let mut sub = self.inner.index_axis_mut(nd_axis, idx);
137            for elem in sub.iter_mut() {
138                if let Some(v) = val_iter.next() {
139                    *elem = v.clone();
140                }
141            }
142        }
143        Ok(())
144    }
145
146    /// Fill the main diagonal of a 2-D (or N-D) array with a value.
147    ///
148    /// For N-D arrays, the diagonal consists of indices where all
149    /// index values are equal: `a[i, i, ..., i]`.
150    ///
151    /// Equivalent to `np.fill_diagonal(a, val)`.
152    pub fn fill_diagonal(&mut self, val: T) {
153        let shape = self.shape().to_vec();
154        if shape.is_empty() {
155            return;
156        }
157        let min_dim = *shape.iter().min().unwrap_or(&0);
158        let ndim = shape.len();
159
160        for i in 0..min_dim {
161            let idx: Vec<usize> = vec![i; ndim];
162            let nd_idx = ndarray::IxDyn(&idx);
163            let mut dyn_view = self.inner.view_mut().into_dyn();
164            dyn_view[nd_idx] = val.clone();
165        }
166    }
167}
168
169// ===========================================================================
170// choose
171// ===========================================================================
172
173/// Construct an array from an index array and a list of arrays to choose from.
174///
175/// Equivalent to `np.choose(a, choices)`. For each element in `index_arr`,
176/// the value selects which choice array to pick from at that position.
177/// Index values are given as `u64` to avoid the `usize` Element issue.
178///
179/// # Errors
180/// - `IndexOutOfBounds` if any index in `index_arr` is >= `choices.len()`
181/// - `ShapeMismatch` if choice arrays have different shapes from `index_arr`
182pub fn choose<T: Element, D: Dimension>(
183    index_arr: &Array<u64, D>,
184    choices: &[Array<T, D>],
185) -> FerrayResult<Array<T, IxDyn>> {
186    if choices.is_empty() {
187        return Err(FerrayError::invalid_value("choices must not be empty"));
188    }
189
190    let shape = index_arr.shape();
191    for (i, c) in choices.iter().enumerate() {
192        if c.shape() != shape {
193            return Err(FerrayError::shape_mismatch(format!(
194                "choice[{}] shape {:?} does not match index array shape {:?}",
195                i,
196                c.shape(),
197                shape
198            )));
199        }
200    }
201
202    let n_choices = choices.len();
203    let choice_iters: Vec<Vec<T>> = choices
204        .iter()
205        .map(|c| c.inner.iter().cloned().collect())
206        .collect();
207
208    let mut data = Vec::with_capacity(index_arr.size());
209    for (pos, idx_val) in index_arr.inner.iter().enumerate() {
210        let idx = *idx_val as usize;
211        if idx >= n_choices {
212            return Err(FerrayError::index_out_of_bounds(idx as isize, 0, n_choices));
213        }
214        data.push(choice_iters[idx][pos].clone());
215    }
216
217    let dyn_shape = IxDyn::new(shape);
218    Array::from_vec(dyn_shape, data)
219}
220
221// ===========================================================================
222// compress
223// ===========================================================================
224
225/// Select slices of an array along an axis where `condition` is true.
226///
227/// Equivalent to `np.compress(condition, a, axis)`.
228///
229/// # Errors
230/// - `AxisOutOfBounds` if `axis >= ndim`
231/// - `ShapeMismatch` if `condition.len()` exceeds axis size
232pub fn compress<T: Element, D: Dimension>(
233    condition: &[bool],
234    a: &Array<T, D>,
235    axis: Axis,
236) -> FerrayResult<Array<T, IxDyn>> {
237    let ndim = a.ndim();
238    let ax = axis.index();
239    if ax >= ndim {
240        return Err(FerrayError::axis_out_of_bounds(ax, ndim));
241    }
242    let axis_size = a.shape()[ax];
243    if condition.len() > axis_size {
244        return Err(FerrayError::shape_mismatch(format!(
245            "condition length {} exceeds axis size {}",
246            condition.len(),
247            axis_size
248        )));
249    }
250
251    let indices: Vec<isize> = condition
252        .iter()
253        .enumerate()
254        .filter_map(|(i, &c)| if c { Some(i as isize) } else { None })
255        .collect();
256
257    a.index_select(axis, &indices)
258}
259
260// ===========================================================================
261// select
262// ===========================================================================
263
264/// Return an array drawn from elements in choicelist, depending on conditions.
265///
266/// Equivalent to `np.select(condlist, choicelist, default)`.
267/// The first condition that is true determines which choice is used.
268///
269/// # Errors
270/// - `InvalidValue` if condlist and choicelist have different lengths
271/// - `ShapeMismatch` if shapes are incompatible
272pub fn select<T: Element, D: Dimension>(
273    condlist: &[Array<bool, D>],
274    choicelist: &[Array<T, D>],
275    default: T,
276) -> FerrayResult<Array<T, IxDyn>> {
277    if condlist.len() != choicelist.len() {
278        return Err(FerrayError::invalid_value(format!(
279            "condlist length {} != choicelist length {}",
280            condlist.len(),
281            choicelist.len()
282        )));
283    }
284    if condlist.is_empty() {
285        return Err(FerrayError::invalid_value(
286            "condlist and choicelist must not be empty",
287        ));
288    }
289
290    let shape = condlist[0].shape();
291    for (i, (c, ch)) in condlist.iter().zip(choicelist.iter()).enumerate() {
292        if c.shape() != shape || ch.shape() != shape {
293            return Err(FerrayError::shape_mismatch(format!(
294                "condlist[{}]/choicelist[{}] shape mismatch with reference shape {:?}",
295                i, i, shape
296            )));
297        }
298    }
299
300    let size = condlist[0].size();
301    let mut data = vec![default; size];
302
303    // Process in reverse order so first matching condition wins
304    for (cond, choice) in condlist.iter().zip(choicelist.iter()).rev() {
305        for (i, (&c, v)) in cond.inner.iter().zip(choice.inner.iter()).enumerate() {
306            if c {
307                data[i] = v.clone();
308            }
309        }
310    }
311
312    let dyn_shape = IxDyn::new(shape);
313    Array::from_vec(dyn_shape, data)
314}
315
316// ===========================================================================
317// indices
318// ===========================================================================
319
320/// Return arrays representing the indices of a grid.
321///
322/// Equivalent to `np.indices(dimensions)`. Returns one `u64` array per
323/// dimension, each with shape `dimensions`.
324///
325/// For example, `indices(&[2, 3])` returns two arrays of shape `[2, 3]`:
326/// the first contains row indices, the second column indices.
327pub fn indices(dimensions: &[usize]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
328    let ndim = dimensions.len();
329    let total: usize = dimensions.iter().product();
330
331    let mut result = Vec::with_capacity(ndim);
332
333    for ax in 0..ndim {
334        let mut data = Vec::with_capacity(total);
335        for flat_idx in 0..total {
336            let mut rem = flat_idx;
337            let mut idx_for_ax = 0;
338            for (d, &dim_size) in dimensions.iter().enumerate().rev() {
339                let coord = rem % dim_size;
340                rem /= dim_size;
341                if d == ax {
342                    idx_for_ax = coord;
343                }
344            }
345            data.push(idx_for_ax as u64);
346        }
347        let dim = IxDyn::new(dimensions);
348        result.push(Array::from_vec(dim, data)?);
349    }
350
351    Ok(result)
352}
353
354// ===========================================================================
355// ix_
356// ===========================================================================
357
358/// Construct an open mesh from multiple sequences.
359///
360/// Equivalent to `np.ix_(*args)`. Returns a list of arrays, each with
361/// shape `(1, 1, ..., N, ..., 1)` where `N` is the length of that sequence
362/// and it appears in the position corresponding to its argument index.
363///
364/// This is useful for constructing index arrays for cross-indexing.
365pub fn ix_(sequences: &[&[u64]]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
366    let ndim = sequences.len();
367    let mut result = Vec::with_capacity(ndim);
368
369    for (i, seq) in sequences.iter().enumerate() {
370        let mut shape = vec![1usize; ndim];
371        shape[i] = seq.len();
372
373        let data = seq.to_vec();
374        let dim = IxDyn::new(&shape);
375        result.push(Array::from_vec(dim, data)?);
376    }
377
378    Ok(result)
379}
380
381// ===========================================================================
382// diag_indices / diag_indices_from
383// ===========================================================================
384
/// Return the indices to access the main diagonal of an n x ... x n array.
///
/// Equivalent to `np.diag_indices(n, ndim)`. Returns `ndim` vectors, each
/// containing `[0, 1, ..., n-1]`; `ndim == 0` yields an empty outer vector.
pub fn diag_indices(n: usize, ndim: usize) -> Vec<Vec<usize>> {
    (0..ndim).map(|_| (0..n).collect()).collect()
}
393
394/// Return the indices to access the main diagonal of the given array.
395///
396/// The array must be at least 2-D and square (all dimensions equal).
397///
398/// # Errors
399/// - `InvalidValue` if the array has fewer than 2 dimensions
400/// - `ShapeMismatch` if dimensions are not all equal
401pub fn diag_indices_from<T: Element, D: Dimension>(
402    a: &Array<T, D>,
403) -> FerrayResult<Vec<Vec<usize>>> {
404    let ndim = a.ndim();
405    if ndim < 2 {
406        return Err(FerrayError::invalid_value(
407            "diag_indices_from requires at least 2 dimensions",
408        ));
409    }
410    let shape = a.shape();
411    let n = shape[0];
412    for &s in &shape[1..] {
413        if s != n {
414            return Err(FerrayError::shape_mismatch(format!(
415                "all dimensions must be equal for diag_indices_from, got {:?}",
416                shape
417            )));
418        }
419    }
420    Ok(diag_indices(n, ndim))
421}
422
423// ===========================================================================
424// tril_indices / triu_indices / tril_indices_from / triu_indices_from
425// ===========================================================================
426
/// Return the indices for the lower triangle of an (n, m) array.
///
/// Equivalent to `np.tril_indices(n, k, m)`; `m` defaults to `n`. `k` is the
/// diagonal offset: 0 = main diagonal, positive = above, negative = below.
/// The result is `(rows, cols)` in row-major order, containing every `(i, j)`
/// with `j <= i + k`. Extreme offsets (e.g. `isize::MAX`) are handled
/// without overflow.
pub fn tril_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // Row `i` keeps columns `0..i + k + 1`, clipped to `0..m`. Saturating
        // arithmetic avoids the overflow `(i as isize) + k` hit for large `k`.
        let end = (i as isize).saturating_add(k).saturating_add(1);
        let end = if end <= 0 { 0 } else { (end as usize).min(m) };
        rows.resize(rows.len() + end, i);
        cols.extend(0..end);
    }

    (rows, cols)
}
448
/// Return the indices for the upper triangle of an (n, m) array.
///
/// Equivalent to `np.triu_indices(n, k, m)`; `m` defaults to `n`. `k` is the
/// diagonal offset: 0 = main diagonal, positive = above, negative = below.
/// The result is `(rows, cols)` in row-major order, containing every `(i, j)`
/// with `j >= i + k`. Extreme offsets (e.g. `isize::MAX`) are handled
/// without overflow.
pub fn triu_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // Row `i` keeps columns `i + k..m`, with the start clipped to `0..=m`.
        // Saturating arithmetic avoids the overflow `(i as isize) + k` hit for
        // large `k`, which in release builds wrapped and selected whole rows.
        let start = (i as isize).saturating_add(k);
        let start = if start <= 0 { 0 } else { (start as usize).min(m) };
        rows.resize(rows.len() + (m - start), i);
        cols.extend(start..m);
    }

    (rows, cols)
}
468
469/// Return the indices for the lower triangle of the given 2-D array.
470///
471/// # Errors
472/// - `InvalidValue` if the array is not 2-D
473pub fn tril_indices_from<T: Element, D: Dimension>(
474    a: &Array<T, D>,
475    k: isize,
476) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
477    let shape = a.shape();
478    if shape.len() != 2 {
479        return Err(FerrayError::invalid_value(
480            "tril_indices_from requires a 2-D array",
481        ));
482    }
483    Ok(tril_indices(shape[0], k, Some(shape[1])))
484}
485
486/// Return the indices for the upper triangle of the given 2-D array.
487///
488/// # Errors
489/// - `InvalidValue` if the array is not 2-D
490pub fn triu_indices_from<T: Element, D: Dimension>(
491    a: &Array<T, D>,
492    k: isize,
493) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
494    let shape = a.shape();
495    if shape.len() != 2 {
496        return Err(FerrayError::invalid_value(
497            "triu_indices_from requires a 2-D array",
498        ));
499    }
500    Ok(triu_indices(shape[0], k, Some(shape[1])))
501}
502
503// ===========================================================================
504// ravel_multi_index / unravel_index
505// ===========================================================================
506
507/// Convert a tuple of index arrays to a flat index array.
508///
509/// Equivalent to `np.ravel_multi_index(multi_index, dims)`.
510/// Uses row-major (C) ordering.
511///
512/// # Errors
513/// - `InvalidValue` if multi_index arrays have different lengths
514/// - `IndexOutOfBounds` if any index is out of range for its dimension
515#[allow(clippy::needless_range_loop)]
516pub fn ravel_multi_index(multi_index: &[&[usize]], dims: &[usize]) -> FerrayResult<Vec<usize>> {
517    if multi_index.len() != dims.len() {
518        return Err(FerrayError::invalid_value(format!(
519            "multi_index has {} components but dims has {} dimensions",
520            multi_index.len(),
521            dims.len()
522        )));
523    }
524    if multi_index.is_empty() {
525        return Ok(vec![]);
526    }
527
528    let n = multi_index[0].len();
529    for (i, idx_arr) in multi_index.iter().enumerate() {
530        if idx_arr.len() != n {
531            return Err(FerrayError::invalid_value(format!(
532                "multi_index[{}] has length {} but expected {}",
533                i,
534                idx_arr.len(),
535                n
536            )));
537        }
538    }
539
540    // Compute strides for C-order
541    let ndim = dims.len();
542    let mut strides = vec![1usize; ndim];
543    for i in (0..ndim - 1).rev() {
544        strides[i] = strides[i + 1] * dims[i + 1];
545    }
546
547    let mut flat = Vec::with_capacity(n);
548    #[allow(clippy::needless_range_loop)]
549    for pos in 0..n {
550        let mut linear = 0usize;
551        for (d, &dim_size) in dims.iter().enumerate() {
552            let coord = multi_index[d][pos];
553            if coord >= dim_size {
554                return Err(FerrayError::index_out_of_bounds(
555                    coord as isize,
556                    d,
557                    dim_size,
558                ));
559            }
560            linear += coord * strides[d];
561        }
562        flat.push(linear);
563    }
564
565    Ok(flat)
566}
567
568/// Convert flat indices to a tuple of coordinate arrays.
569///
570/// Equivalent to `np.unravel_index(indices, shape)`.
571/// Uses row-major (C) ordering.
572///
573/// # Errors
574/// - `IndexOutOfBounds` if any flat index >= product(shape)
575pub fn unravel_index(flat_indices: &[usize], shape: &[usize]) -> FerrayResult<Vec<Vec<usize>>> {
576    let total: usize = shape.iter().product();
577    let ndim = shape.len();
578    let n = flat_indices.len();
579
580    let mut result: Vec<Vec<usize>> = vec![Vec::with_capacity(n); ndim];
581
582    for &flat_idx in flat_indices {
583        if flat_idx >= total {
584            return Err(FerrayError::index_out_of_bounds(
585                flat_idx as isize,
586                0,
587                total,
588            ));
589        }
590        let mut rem = flat_idx;
591        for (d, &dim_size) in shape.iter().enumerate().rev() {
592            result[d].push(rem % dim_size);
593            rem /= dim_size;
594        }
595    }
596
597    Ok(result)
598}
599
600// ===========================================================================
601// flatnonzero
602// ===========================================================================
603
604/// Return the indices of non-zero elements in the flattened array.
605///
606/// Equivalent to `np.flatnonzero(a)`. An element is "non-zero" if it
607/// is not equal to the type's zero value.
608pub fn flatnonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<usize> {
609    let zero = T::zero();
610    a.inner
611        .iter()
612        .enumerate()
613        .filter_map(|(i, val)| if *val != zero { Some(i) } else { None })
614        .collect()
615}
616
617// ===========================================================================
618// ndindex / ndenumerate iterators
619// ===========================================================================
620
/// An iterator over all multi-dimensional indices for a given shape.
///
/// Equivalent to `np.ndindex(*shape)`. Yields indices in row-major order (the
/// last coordinate varies fastest). A shape containing a zero-length dimension
/// yields nothing, and an empty shape yields a single empty index. The iterator
/// reports its exact remaining length and keeps returning `None` once exhausted.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct NdIndex {
    shape: Vec<usize>,
    // Next index to yield while `done` is false.
    current: Vec<usize>,
    done: bool,
}

impl NdIndex {
    fn new(shape: &[usize]) -> Self {
        // Any zero-length dimension means there are no indices at all.
        let done = shape.contains(&0);
        Self {
            shape: shape.to_vec(),
            current: vec![0; shape.len()],
            done,
        }
    }
}

impl Iterator for NdIndex {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        let result = self.current.clone();

        // Odometer increment: bump the last coordinate and carry leftwards.
        // Stop at the first coordinate that does not wrap. If every coordinate
        // wraps back to 0 (or there are none), the grid is exhausted.
        let mut exhausted = true;
        for (coord, &extent) in self.current.iter_mut().zip(&self.shape).rev() {
            *coord += 1;
            if *coord < extent {
                exhausted = false;
                break;
            }
            *coord = 0;
        }
        self.done = exhausted;

        Some(result)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.done {
            return (0, Some(0));
        }
        // `current` is the next index to yield, so its row-major rank equals
        // the number of indices already produced.
        let total: usize = self.shape.iter().product();
        let mut yielded = 0usize;
        let mut stride = 1usize;
        for (&coord, &extent) in self.current.iter().zip(&self.shape).rev() {
            yielded += coord * stride;
            stride *= extent;
        }
        let remaining = total - yielded;
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, so `len()` is available to callers.
impl ExactSizeIterator for NdIndex {}

// `done` is never cleared once set, so `None` is returned forever after.
impl std::iter::FusedIterator for NdIndex {}
688
689/// Create an iterator over all multi-dimensional indices for a shape.
690///
691/// Equivalent to `np.ndindex(*shape)`.
692pub fn ndindex(shape: &[usize]) -> NdIndex {
693    NdIndex::new(shape)
694}
695
696/// Create an iterator yielding `(index, &value)` pairs.
697///
698/// Equivalent to `np.ndenumerate(a)`.
699pub fn ndenumerate<'a, T: Element, D: Dimension>(
700    a: &'a Array<T, D>,
701) -> impl Iterator<Item = (Vec<usize>, &'a T)> + 'a {
702    let shape = a.shape().to_vec();
703    let ndim = shape.len();
704    a.inner.iter().enumerate().map(move |(flat_idx, val)| {
705        let mut idx = vec![0usize; ndim];
706        let mut rem = flat_idx;
707        for (d, s) in shape.iter().enumerate().rev() {
708            if *s > 0 {
709                idx[d] = rem % s;
710                rem /= s;
711            }
712        }
713        (idx, val)
714    })
715}
716
717#[cfg(test)]
718mod tests {
719    use super::*;
720    use crate::dimension::{Ix1, Ix2};
721
722    // -----------------------------------------------------------------------
723    // take
724    // -----------------------------------------------------------------------
725
726    #[test]
727    fn take_1d() {
728        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
729        let taken = take(&arr, &[0, 2, 4], Axis(0)).unwrap();
730        assert_eq!(taken.shape(), &[3]);
731        let data: Vec<i32> = taken.iter().copied().collect();
732        assert_eq!(data, vec![10, 30, 50]);
733    }
734
735    #[test]
736    fn take_2d_axis1() {
737        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
738        let taken = take(&arr, &[0, 2], Axis(1)).unwrap();
739        assert_eq!(taken.shape(), &[3, 2]);
740        let data: Vec<i32> = taken.iter().copied().collect();
741        assert_eq!(data, vec![0, 2, 4, 6, 8, 10]);
742    }
743
744    #[test]
745    fn take_negative_indices() {
746        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
747        let taken = take(&arr, &[-1, -3], Axis(0)).unwrap();
748        let data: Vec<i32> = taken.iter().copied().collect();
749        assert_eq!(data, vec![40, 20]);
750    }
751
752    // -----------------------------------------------------------------------
753    // take_along_axis
754    // -----------------------------------------------------------------------
755
756    #[test]
757    fn take_along_axis_basic() {
758        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
759        let taken = take_along_axis(&arr, &[1, 3], Axis(1)).unwrap();
760        assert_eq!(taken.shape(), &[3, 2]);
761    }
762
763    // -----------------------------------------------------------------------
764    // put
765    // -----------------------------------------------------------------------
766
767    #[test]
768    fn put_flat() {
769        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 0, 0, 0, 0]).unwrap();
770        arr.put(&[1, 3], &[99, 88]).unwrap();
771        assert_eq!(arr.as_slice().unwrap(), &[0, 99, 0, 88, 0]);
772    }
773
774    #[test]
775    fn put_cycling_values() {
776        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0; 5]).unwrap();
777        arr.put(&[0, 1, 2, 3, 4], &[10, 20]).unwrap();
778        assert_eq!(arr.as_slice().unwrap(), &[10, 20, 10, 20, 10]);
779    }
780
781    #[test]
782    fn put_out_of_bounds() {
783        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
784        assert!(arr.put(&[5], &[1]).is_err());
785    }
786
787    // -----------------------------------------------------------------------
788    // fill_diagonal
789    // -----------------------------------------------------------------------
790
791    #[test]
792    fn fill_diagonal_2d() {
793        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 3]), vec![0; 9]).unwrap();
794        arr.fill_diagonal(1);
795        let data: Vec<i32> = arr.iter().copied().collect();
796        assert_eq!(data, vec![1, 0, 0, 0, 1, 0, 0, 0, 1]);
797    }
798
799    #[test]
800    fn fill_diagonal_rectangular() {
801        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 4]), vec![0; 8]).unwrap();
802        arr.fill_diagonal(5);
803        let data: Vec<i32> = arr.iter().copied().collect();
804        assert_eq!(data, vec![5, 0, 0, 0, 0, 5, 0, 0]);
805    }
806
807    // -----------------------------------------------------------------------
808    // choose
809    // -----------------------------------------------------------------------
810
811    #[test]
812    fn choose_basic() {
813        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([4]), vec![0, 1, 0, 1]).unwrap();
814        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
815        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![100, 200, 300, 400]).unwrap();
816        let result = choose(&idx, &[c0, c1]).unwrap();
817        let data: Vec<i32> = result.iter().copied().collect();
818        assert_eq!(data, vec![10, 200, 30, 400]);
819    }
820
821    #[test]
822    fn choose_out_of_bounds() {
823        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([2]), vec![0, 2]).unwrap();
824        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 2]).unwrap();
825        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![3, 4]).unwrap();
826        assert!(choose(&idx, &[c0, c1]).is_err());
827    }
828
829    // -----------------------------------------------------------------------
830    // compress
831    // -----------------------------------------------------------------------
832
833    #[test]
834    fn compress_1d() {
835        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
836        let result = compress(&[true, false, true, false, true], &arr, Axis(0)).unwrap();
837        let data: Vec<i32> = result.iter().copied().collect();
838        assert_eq!(data, vec![10, 30, 50]);
839    }
840
841    #[test]
842    fn compress_2d_axis0() {
843        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
844        let result = compress(&[true, false, true], &arr, Axis(0)).unwrap();
845        assert_eq!(result.shape(), &[2, 4]);
846        let data: Vec<i32> = result.iter().copied().collect();
847        assert_eq!(data, vec![0, 1, 2, 3, 8, 9, 10, 11]);
848    }
849
850    // -----------------------------------------------------------------------
851    // select
852    // -----------------------------------------------------------------------
853
854    #[test]
855    fn select_basic() {
856        let c1 =
857            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, false]).unwrap();
858        let c2 =
859            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, false]).unwrap();
860        let ch1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 1, 1, 1]).unwrap();
861        let ch2 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![2, 2, 2, 2]).unwrap();
862        let result = select(&[c1, c2], &[ch1, ch2], 0).unwrap();
863        let data: Vec<i32> = result.iter().copied().collect();
864        assert_eq!(data, vec![1, 2, 0, 0]);
865    }
866
867    // -----------------------------------------------------------------------
868    // indices
869    // -----------------------------------------------------------------------
870
871    #[test]
872    fn indices_2d() {
873        let idx = indices(&[2, 3]).unwrap();
874        assert_eq!(idx.len(), 2);
875        assert_eq!(idx[0].shape(), &[2, 3]);
876        assert_eq!(idx[1].shape(), &[2, 3]);
877        let rows: Vec<u64> = idx[0].iter().copied().collect();
878        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1]);
879        let cols: Vec<u64> = idx[1].iter().copied().collect();
880        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2]);
881    }
882
883    // -----------------------------------------------------------------------
884    // ix_
885    // -----------------------------------------------------------------------
886
887    #[test]
888    fn ix_basic() {
889        let result = ix_(&[&[0, 1], &[2, 3, 4]]).unwrap();
890        assert_eq!(result.len(), 2);
891        assert_eq!(result[0].shape(), &[2, 1]);
892        assert_eq!(result[1].shape(), &[1, 3]);
893    }
894
895    // -----------------------------------------------------------------------
896    // diag_indices
897    // -----------------------------------------------------------------------
898
899    #[test]
900    fn diag_indices_basic() {
901        let idx = diag_indices(3, 2);
902        assert_eq!(idx.len(), 2);
903        assert_eq!(idx[0], vec![0, 1, 2]);
904        assert_eq!(idx[1], vec![0, 1, 2]);
905    }
906
907    #[test]
908    fn diag_indices_from_square() {
909        let arr = Array::<f64, Ix2>::zeros(Ix2::new([4, 4])).unwrap();
910        let idx = diag_indices_from(&arr).unwrap();
911        assert_eq!(idx.len(), 2);
912        assert_eq!(idx[0].len(), 4);
913    }
914
915    #[test]
916    fn diag_indices_from_not_square() {
917        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
918        assert!(diag_indices_from(&arr).is_err());
919    }
920
921    // -----------------------------------------------------------------------
922    // tril_indices / triu_indices
923    // -----------------------------------------------------------------------
924
925    #[test]
926    fn tril_indices_basic() {
927        let (rows, cols) = tril_indices(3, 0, None);
928        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
929        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
930    }
931
932    #[test]
933    fn triu_indices_basic() {
934        let (rows, cols) = triu_indices(3, 0, None);
935        assert_eq!(rows, vec![0, 0, 0, 1, 1, 2]);
936        assert_eq!(cols, vec![0, 1, 2, 1, 2, 2]);
937    }
938
939    #[test]
940    fn tril_indices_with_k() {
941        let (rows, cols) = tril_indices(3, 1, None);
942        assert_eq!(rows, vec![0, 0, 1, 1, 1, 2, 2, 2]);
943        assert_eq!(cols, vec![0, 1, 0, 1, 2, 0, 1, 2]);
944    }
945
946    #[test]
947    fn triu_indices_with_negative_k() {
948        let (rows, cols) = triu_indices(3, -1, None);
949        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1, 2, 2]);
950        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2, 1, 2]);
951    }
952
953    #[test]
954    fn tril_indices_from_test() {
955        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
956        let (rows, _cols) = tril_indices_from(&arr, 0).unwrap();
957        assert_eq!(rows.len(), 6);
958    }
959
960    #[test]
961    fn triu_indices_from_test() {
962        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
963        let (rows, _cols) = triu_indices_from(&arr, 0).unwrap();
964        assert_eq!(rows.len(), 6);
965    }
966
967    #[test]
968    fn tril_indices_rectangular() {
969        let (rows, cols) = tril_indices(3, 0, Some(4));
970        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
971        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
972    }
973
974    // -----------------------------------------------------------------------
975    // ravel_multi_index / unravel_index
976    // -----------------------------------------------------------------------
977
978    #[test]
979    fn ravel_multi_index_basic() {
980        let flat = ravel_multi_index(&[&[0, 1, 2], &[1, 2, 0]], &[3, 4]).unwrap();
981        assert_eq!(flat, vec![1, 6, 8]);
982    }
983
984    #[test]
985    fn ravel_multi_index_3d() {
986        let flat = ravel_multi_index(&[&[0], &[1], &[2]], &[2, 3, 4]).unwrap();
987        assert_eq!(flat, vec![6]);
988    }
989
990    #[test]
991    fn ravel_multi_index_out_of_bounds() {
992        assert!(ravel_multi_index(&[&[3]], &[3]).is_err());
993    }
994
995    #[test]
996    fn unravel_index_basic() {
997        let coords = unravel_index(&[1, 6, 8], &[3, 4]).unwrap();
998        assert_eq!(coords[0], vec![0, 1, 2]);
999        assert_eq!(coords[1], vec![1, 2, 0]);
1000    }
1001
1002    #[test]
1003    fn unravel_index_out_of_bounds() {
1004        assert!(unravel_index(&[12], &[3, 4]).is_err());
1005    }
1006
1007    #[test]
1008    fn ravel_unravel_roundtrip() {
1009        let dims = &[3, 4, 5];
1010        let a: &[usize] = &[1, 2];
1011        let b: &[usize] = &[2, 3];
1012        let c: &[usize] = &[3, 4];
1013        let multi: &[&[usize]] = &[a, b, c];
1014        let flat = ravel_multi_index(multi, dims).unwrap();
1015        let coords = unravel_index(&flat, dims).unwrap();
1016        assert_eq!(coords[0], vec![1, 2]);
1017        assert_eq!(coords[1], vec![2, 3]);
1018        assert_eq!(coords[2], vec![3, 4]);
1019    }
1020
1021    // -----------------------------------------------------------------------
1022    // flatnonzero
1023    // -----------------------------------------------------------------------
1024
1025    #[test]
1026    fn flatnonzero_basic() {
1027        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
1028        let nz = flatnonzero(&arr);
1029        assert_eq!(nz, vec![1, 3]);
1030    }
1031
1032    #[test]
1033    fn flatnonzero_2d() {
1034        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
1035        let nz = flatnonzero(&arr);
1036        assert_eq!(nz, vec![1, 3, 5]);
1037    }
1038
1039    #[test]
1040    fn flatnonzero_all_zero() {
1041        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
1042        let nz = flatnonzero(&arr);
1043        assert_eq!(nz.len(), 0);
1044    }
1045
1046    // -----------------------------------------------------------------------
1047    // ndindex
1048    // -----------------------------------------------------------------------
1049
1050    #[test]
1051    fn ndindex_2d() {
1052        let indices: Vec<Vec<usize>> = ndindex(&[2, 3]).collect();
1053        assert_eq!(indices.len(), 6);
1054        assert_eq!(indices[0], vec![0, 0]);
1055        assert_eq!(indices[1], vec![0, 1]);
1056        assert_eq!(indices[2], vec![0, 2]);
1057        assert_eq!(indices[3], vec![1, 0]);
1058        assert_eq!(indices[4], vec![1, 1]);
1059        assert_eq!(indices[5], vec![1, 2]);
1060    }
1061
1062    #[test]
1063    fn ndindex_1d() {
1064        let indices: Vec<Vec<usize>> = ndindex(&[4]).collect();
1065        assert_eq!(indices.len(), 4);
1066        assert_eq!(indices[0], vec![0]);
1067        assert_eq!(indices[3], vec![3]);
1068    }
1069
1070    #[test]
1071    fn ndindex_empty() {
1072        let indices: Vec<Vec<usize>> = ndindex(&[0]).collect();
1073        assert_eq!(indices.len(), 0);
1074    }
1075
1076    #[test]
1077    fn ndindex_scalar() {
1078        let indices: Vec<Vec<usize>> = ndindex(&[]).collect();
1079        assert_eq!(indices.len(), 1);
1080        assert_eq!(indices[0], Vec::<usize>::new());
1081    }
1082
1083    // -----------------------------------------------------------------------
1084    // ndenumerate
1085    // -----------------------------------------------------------------------
1086
1087    #[test]
1088    fn ndenumerate_2d() {
1089        let arr =
1090            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
1091        let items: Vec<(Vec<usize>, &i32)> = ndenumerate(&arr).collect();
1092        assert_eq!(items.len(), 6);
1093        assert_eq!(items[0], (vec![0, 0], &10));
1094        assert_eq!(items[1], (vec![0, 1], &20));
1095        assert_eq!(items[5], (vec![1, 2], &60));
1096    }
1097
1098    // -----------------------------------------------------------------------
1099    // put_along_axis
1100    // -----------------------------------------------------------------------
1101
1102    #[test]
1103    fn put_along_axis_basic() {
1104        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), vec![0; 12]).unwrap();
1105        let values =
1106            Array::<i32, IxDyn>::from_vec(IxDyn::new(&[8]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
1107        arr.put_along_axis(&[0, 2], &values, Axis(0)).unwrap();
1108        let data: Vec<i32> = arr.iter().copied().collect();
1109        assert_eq!(data, vec![1, 2, 3, 4, 0, 0, 0, 0, 5, 6, 7, 8]);
1110    }
1111}