Skip to main content

ferray_core/indexing/
extended.rs

// ferray-core: Extended indexing functions (REQ-15a)
//
// take, take_along_axis, put, put_along_axis, choose, compress, select,
// indices, ix_, diag_indices, diag_indices_from, tril_indices, triu_indices,
// tril_indices_from, triu_indices_from, ravel_multi_index, unravel_index,
// flatnonzero, nonzero, argwhere, fill_diagonal, where_select,
// ndindex (iterator), ndenumerate (iterator)
//
// Most index-returning functions return Vec<Vec<usize>> or
// (Vec<usize>, Vec<usize>) because usize is not an Element type; callers can
// wrap these into arrays of u64 or i64 if needed. (`indices` and `ix_` return
// u64 arrays directly, and `argwhere` returns an i64 array.)
11
12use super::normalize_index;
13use crate::array::owned::Array;
14use crate::dimension::{Axis, Dimension, Ix2, IxDyn};
15use crate::dtype::Element;
16use crate::error::{FerrayError, FerrayResult};
17
18// ===========================================================================
19// take / take_along_axis
20// ===========================================================================
21
22/// Take elements from an array along an axis.
23///
24/// Equivalent to `np.take(a, indices, axis)`. Returns a copy.
25///
26/// # Errors
27/// - `AxisOutOfBounds` if `axis >= ndim`
28/// - `IndexOutOfBounds` if any index is out of range
29pub fn take<T: Element, D: Dimension>(
30    a: &Array<T, D>,
31    indices: &[isize],
32    axis: Axis,
33) -> FerrayResult<Array<T, IxDyn>> {
34    a.index_select(axis, indices)
35}
36
37/// Take values from an array along an axis using an index slice.
38///
39/// Similar to `np.take_along_axis`. The `indices` slice contains
40/// indices into `a` along the specified axis. The result replaces
41/// the `axis` dimension with the indices dimension.
42///
43/// # Errors
44/// - `AxisOutOfBounds` if `axis >= ndim`
45/// - `IndexOutOfBounds` if any index is out of range
46pub fn take_along_axis<T: Element, D: Dimension>(
47    a: &Array<T, D>,
48    indices: &[isize],
49    axis: Axis,
50) -> FerrayResult<Array<T, IxDyn>> {
51    a.index_select(axis, indices)
52}
53
54// ===========================================================================
55// put / put_along_axis
56// ===========================================================================
57
58impl<T: Element, D: Dimension> Array<T, D> {
59    /// Put values into the flattened array at the given indices.
60    ///
61    /// Equivalent to `np.put(a, ind, v)`. Modifies the array in-place.
62    /// Indices refer to the flattened (row-major) array.
63    /// Values are cycled if fewer than indices.
64    ///
65    /// # Errors
66    /// - `IndexOutOfBounds` if any index is out of range
67    /// - `InvalidValue` if values is empty
68    pub fn put(&mut self, indices: &[isize], values: &[T]) -> FerrayResult<()> {
69        if values.is_empty() {
70            return Err(FerrayError::invalid_value("values must not be empty"));
71        }
72        let size = self.size();
73        let normalized: Vec<usize> = indices
74            .iter()
75            .map(|&idx| normalize_index(idx, size, 0))
76            .collect::<FerrayResult<Vec<_>>>()?;
77
78        let mut flat: Vec<&mut T> = self.inner.iter_mut().collect();
79
80        for (i, &idx) in normalized.iter().enumerate() {
81            let val_idx = i % values.len();
82            *flat[idx] = values[val_idx].clone();
83        }
84        Ok(())
85    }
86
87    /// Put values along an axis at specified indices.
88    ///
89    /// For each index position along `axis`, assigns the values from the
90    /// corresponding sub-array of `values`.
91    ///
92    /// # Errors
93    /// - `AxisOutOfBounds` if `axis >= ndim`
94    /// - `IndexOutOfBounds` if any index is out of range
95    pub fn put_along_axis(
96        &mut self,
97        indices: &[isize],
98        values: &Array<T, IxDyn>,
99        axis: Axis,
100    ) -> FerrayResult<()>
101    where
102        D::NdarrayDim: ndarray::RemoveAxis,
103    {
104        let ndim = self.ndim();
105        let ax = axis.index();
106        if ax >= ndim {
107            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
108        }
109        let axis_size = self.shape()[ax];
110
111        let normalized: Vec<usize> = indices
112            .iter()
113            .map(|&idx| normalize_index(idx, axis_size, ax))
114            .collect::<FerrayResult<Vec<_>>>()?;
115
116        let nd_axis = ndarray::Axis(ax);
117        let mut val_iter = values.inner.iter();
118
119        for &idx in &normalized {
120            let mut sub = self.inner.index_axis_mut(nd_axis, idx);
121            for elem in &mut sub {
122                if let Some(v) = val_iter.next() {
123                    *elem = v.clone();
124                }
125            }
126        }
127        Ok(())
128    }
129
130    /// Fill the main diagonal of a 2-D (or N-D) array with a value.
131    ///
132    /// For N-D arrays, the diagonal consists of indices where all
133    /// index values are equal: `a[i, i, ..., i]`.
134    ///
135    /// Equivalent to `np.fill_diagonal(a, val)`.
136    pub fn fill_diagonal(&mut self, val: T) {
137        let shape = self.shape().to_vec();
138        if shape.is_empty() {
139            return;
140        }
141        let min_dim = *shape.iter().min().unwrap_or(&0);
142        let ndim = shape.len();
143
144        for i in 0..min_dim {
145            let idx: Vec<usize> = vec![i; ndim];
146            let nd_idx = ndarray::IxDyn(&idx);
147            let mut dyn_view = self.inner.view_mut().into_dyn();
148            dyn_view[nd_idx] = val.clone();
149        }
150    }
151}
152
153// ===========================================================================
154// choose
155// ===========================================================================
156
157/// Construct an array from an index array and a list of arrays to choose from.
158///
159/// Equivalent to `np.choose(a, choices)`. For each element in `index_arr`,
160/// the value selects which choice array to pick from at that position.
161/// Index values are given as `u64` to avoid the `usize` Element issue.
162///
163/// # Errors
164/// - `IndexOutOfBounds` if any index in `index_arr` is >= `choices.len()`
165/// - `ShapeMismatch` if choice arrays have different shapes from `index_arr`
166pub fn choose<T: Element, D: Dimension>(
167    index_arr: &Array<u64, D>,
168    choices: &[Array<T, D>],
169) -> FerrayResult<Array<T, IxDyn>> {
170    if choices.is_empty() {
171        return Err(FerrayError::invalid_value("choices must not be empty"));
172    }
173
174    let shape = index_arr.shape();
175    for (i, c) in choices.iter().enumerate() {
176        if c.shape() != shape {
177            return Err(FerrayError::shape_mismatch(format!(
178                "choice[{}] shape {:?} does not match index array shape {:?}",
179                i,
180                c.shape(),
181                shape
182            )));
183        }
184    }
185
186    let n_choices = choices.len();
187    let choice_iters: Vec<Vec<T>> = choices
188        .iter()
189        .map(|c| c.inner.iter().cloned().collect())
190        .collect();
191
192    let mut data = Vec::with_capacity(index_arr.size());
193    for (pos, idx_val) in index_arr.inner.iter().enumerate() {
194        let idx = *idx_val as usize;
195        if idx >= n_choices {
196            return Err(FerrayError::index_out_of_bounds(idx as isize, 0, n_choices));
197        }
198        data.push(choice_iters[idx][pos].clone());
199    }
200
201    let dyn_shape = IxDyn::new(shape);
202    Array::from_vec(dyn_shape, data)
203}
204
205// ===========================================================================
206// compress
207// ===========================================================================
208
209/// Select slices of an array along an axis where `condition` is true.
210///
211/// Equivalent to `np.compress(condition, a, axis)`.
212///
213/// # Errors
214/// - `AxisOutOfBounds` if `axis >= ndim`
215/// - `ShapeMismatch` if `condition.len()` exceeds axis size
216pub fn compress<T: Element, D: Dimension>(
217    condition: &[bool],
218    a: &Array<T, D>,
219    axis: Axis,
220) -> FerrayResult<Array<T, IxDyn>> {
221    let ndim = a.ndim();
222    let ax = axis.index();
223    if ax >= ndim {
224        return Err(FerrayError::axis_out_of_bounds(ax, ndim));
225    }
226    let axis_size = a.shape()[ax];
227    if condition.len() > axis_size {
228        return Err(FerrayError::shape_mismatch(format!(
229            "condition length {} exceeds axis size {}",
230            condition.len(),
231            axis_size
232        )));
233    }
234
235    let indices: Vec<isize> = condition
236        .iter()
237        .enumerate()
238        .filter_map(|(i, &c)| if c { Some(i as isize) } else { None })
239        .collect();
240
241    a.index_select(axis, &indices)
242}
243
244// ===========================================================================
245// select
246// ===========================================================================
247
248/// Return an array drawn from elements in choicelist, depending on conditions.
249///
250/// Equivalent to `np.select(condlist, choicelist, default)`.
251/// The first condition that is true determines which choice is used.
252///
253/// # Errors
254/// - `InvalidValue` if condlist and choicelist have different lengths
255/// - `ShapeMismatch` if shapes are incompatible
256pub fn select<T: Element, D: Dimension>(
257    condlist: &[Array<bool, D>],
258    choicelist: &[Array<T, D>],
259    default: T,
260) -> FerrayResult<Array<T, IxDyn>> {
261    if condlist.len() != choicelist.len() {
262        return Err(FerrayError::invalid_value(format!(
263            "condlist length {} != choicelist length {}",
264            condlist.len(),
265            choicelist.len()
266        )));
267    }
268    if condlist.is_empty() {
269        return Err(FerrayError::invalid_value(
270            "condlist and choicelist must not be empty",
271        ));
272    }
273
274    let shape = condlist[0].shape();
275    for (i, (c, ch)) in condlist.iter().zip(choicelist.iter()).enumerate() {
276        if c.shape() != shape || ch.shape() != shape {
277            return Err(FerrayError::shape_mismatch(format!(
278                "condlist[{i}]/choicelist[{i}] shape mismatch with reference shape {shape:?}"
279            )));
280        }
281    }
282
283    let size = condlist[0].size();
284    let mut data = vec![default; size];
285
286    // Process in reverse order so first matching condition wins
287    for (cond, choice) in condlist.iter().zip(choicelist.iter()).rev() {
288        for (i, (&c, v)) in cond.inner.iter().zip(choice.inner.iter()).enumerate() {
289            if c {
290                data[i] = v.clone();
291            }
292        }
293    }
294
295    let dyn_shape = IxDyn::new(shape);
296    Array::from_vec(dyn_shape, data)
297}
298
299// ===========================================================================
300// indices
301// ===========================================================================
302
303/// Return arrays representing the indices of a grid.
304///
305/// Equivalent to `np.indices(dimensions)`. Returns one `u64` array per
306/// dimension, each with shape `dimensions`.
307///
308/// For example, `indices(&[2, 3])` returns two arrays of shape `[2, 3]`:
309/// the first contains row indices, the second column indices.
310pub fn indices(dimensions: &[usize]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
311    let ndim = dimensions.len();
312    let total: usize = dimensions.iter().product();
313
314    let mut result = Vec::with_capacity(ndim);
315
316    for ax in 0..ndim {
317        let mut data = Vec::with_capacity(total);
318        for flat_idx in 0..total {
319            let mut rem = flat_idx;
320            let mut idx_for_ax = 0;
321            for (d, &dim_size) in dimensions.iter().enumerate().rev() {
322                let coord = rem % dim_size;
323                rem /= dim_size;
324                if d == ax {
325                    idx_for_ax = coord;
326                }
327            }
328            data.push(idx_for_ax as u64);
329        }
330        let dim = IxDyn::new(dimensions);
331        result.push(Array::from_vec(dim, data)?);
332    }
333
334    Ok(result)
335}
336
337// ===========================================================================
338// ix_
339// ===========================================================================
340
341/// Construct an open mesh from multiple sequences.
342///
343/// Equivalent to `np.ix_(*args)`. Returns a list of arrays, each with
344/// shape `(1, 1, ..., N, ..., 1)` where `N` is the length of that sequence
345/// and it appears in the position corresponding to its argument index.
346///
347/// This is useful for constructing index arrays for cross-indexing.
348pub fn ix_(sequences: &[&[u64]]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
349    let ndim = sequences.len();
350    let mut result = Vec::with_capacity(ndim);
351
352    for (i, seq) in sequences.iter().enumerate() {
353        let mut shape = vec![1usize; ndim];
354        shape[i] = seq.len();
355
356        let data = seq.to_vec();
357        let dim = IxDyn::new(&shape);
358        result.push(Array::from_vec(dim, data)?);
359    }
360
361    Ok(result)
362}
363
364// ===========================================================================
365// diag_indices / diag_indices_from
366// ===========================================================================
367
/// Index vectors addressing the main diagonal of an `n`-sided, `ndim`-D array.
///
/// Mirrors `np.diag_indices(n, ndim)`. The result holds `ndim` identical
/// vectors, each equal to `[0, 1, ..., n-1]`.
#[must_use]
pub fn diag_indices(n: usize, ndim: usize) -> Vec<Vec<usize>> {
    (0..ndim).map(|_| (0..n).collect()).collect()
}
377
378/// Return the indices to access the main diagonal of the given array.
379///
380/// The array must be at least 2-D and square (all dimensions equal).
381///
382/// # Errors
383/// - `InvalidValue` if the array has fewer than 2 dimensions
384/// - `ShapeMismatch` if dimensions are not all equal
385pub fn diag_indices_from<T: Element, D: Dimension>(
386    a: &Array<T, D>,
387) -> FerrayResult<Vec<Vec<usize>>> {
388    let ndim = a.ndim();
389    if ndim < 2 {
390        return Err(FerrayError::invalid_value(
391            "diag_indices_from requires at least 2 dimensions",
392        ));
393    }
394    let shape = a.shape();
395    let n = shape[0];
396    for &s in &shape[1..] {
397        if s != n {
398            return Err(FerrayError::shape_mismatch(format!(
399                "all dimensions must be equal for diag_indices_from, got {shape:?}"
400            )));
401        }
402    }
403    Ok(diag_indices(n, ndim))
404}
405
406// ===========================================================================
407// tril_indices / triu_indices / tril_indices_from / triu_indices_from
408// ===========================================================================
409
/// Return the indices for the lower triangle of an (n, m) array.
///
/// Equivalent to `np.tril_indices(n, k, m)`. `m` defaults to `n`.
/// `k` is the diagonal offset: 0 = main diagonal, positive = above,
/// negative = below. Position `(i, j)` is included when `j <= i + k`.
///
/// Returns `(rows, cols)` with entries in row-major order. Extreme offsets
/// (e.g. `isize::MAX` / `isize::MIN`) are handled without overflow.
#[must_use]
pub fn tril_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // Last admissible column for row `i` is `i + k`. Saturating add: the
        // previous `(i as isize) + k` overflowed for large `k` (debug panic,
        // wrong result in release).
        let last = (i as isize).saturating_add(k);
        if last < 0 {
            continue;
        }
        // `last` is non-negative here, so the cast is lossless.
        let end = (last as usize).saturating_add(1).min(m);
        rows.resize(rows.len() + end, i);
        cols.extend(0..end);
    }

    (rows, cols)
}
432
/// Return the indices for the upper triangle of an (n, m) array.
///
/// Equivalent to `np.triu_indices(n, k, m)`. `m` defaults to `n`.
/// `k` is the diagonal offset: 0 = main diagonal, positive = above,
/// negative = below. Position `(i, j)` is included when `j >= i + k`.
///
/// Returns `(rows, cols)` with entries in row-major order. Extreme offsets
/// (e.g. `isize::MAX` / `isize::MIN`) are handled without overflow.
#[must_use]
pub fn triu_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // First admissible column for row `i` is `i + k`, clamped to
        // `0..=m`. Saturating add: the previous `(i as isize) + k` overflowed
        // for large `k`.
        let first = (i as isize).saturating_add(k);
        let start = if first <= 0 { 0 } else { (first as usize).min(m) };
        rows.resize(rows.len() + (m - start), i);
        cols.extend(start..m);
    }

    (rows, cols)
}
453
454/// Return the indices for the lower triangle of the given 2-D array.
455///
456/// # Errors
457/// - `InvalidValue` if the array is not 2-D
458pub fn tril_indices_from<T: Element, D: Dimension>(
459    a: &Array<T, D>,
460    k: isize,
461) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
462    let shape = a.shape();
463    if shape.len() != 2 {
464        return Err(FerrayError::invalid_value(
465            "tril_indices_from requires a 2-D array",
466        ));
467    }
468    Ok(tril_indices(shape[0], k, Some(shape[1])))
469}
470
471/// Return the indices for the upper triangle of the given 2-D array.
472///
473/// # Errors
474/// - `InvalidValue` if the array is not 2-D
475pub fn triu_indices_from<T: Element, D: Dimension>(
476    a: &Array<T, D>,
477    k: isize,
478) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
479    let shape = a.shape();
480    if shape.len() != 2 {
481        return Err(FerrayError::invalid_value(
482            "triu_indices_from requires a 2-D array",
483        ));
484    }
485    Ok(triu_indices(shape[0], k, Some(shape[1])))
486}
487
488// ===========================================================================
489// ravel_multi_index / unravel_index
490// ===========================================================================
491
492/// Convert a tuple of index arrays to a flat index array.
493///
494/// Equivalent to `np.ravel_multi_index(multi_index, dims)`.
495/// Uses row-major (C) ordering.
496///
497/// # Errors
498/// - `InvalidValue` if `multi_index` arrays have different lengths
499/// - `IndexOutOfBounds` if any index is out of range for its dimension
500#[allow(clippy::needless_range_loop)]
501pub fn ravel_multi_index(multi_index: &[&[usize]], dims: &[usize]) -> FerrayResult<Vec<usize>> {
502    if multi_index.len() != dims.len() {
503        return Err(FerrayError::invalid_value(format!(
504            "multi_index has {} components but dims has {} dimensions",
505            multi_index.len(),
506            dims.len()
507        )));
508    }
509    if multi_index.is_empty() {
510        return Ok(vec![]);
511    }
512
513    let n = multi_index[0].len();
514    for (i, idx_arr) in multi_index.iter().enumerate() {
515        if idx_arr.len() != n {
516            return Err(FerrayError::invalid_value(format!(
517                "multi_index[{}] has length {} but expected {}",
518                i,
519                idx_arr.len(),
520                n
521            )));
522        }
523    }
524
525    // Compute strides for C-order
526    let ndim = dims.len();
527    let mut strides = vec![1usize; ndim];
528    for i in (0..ndim - 1).rev() {
529        strides[i] = strides[i + 1] * dims[i + 1];
530    }
531
532    let mut flat = Vec::with_capacity(n);
533    #[allow(clippy::needless_range_loop)]
534    for pos in 0..n {
535        let mut linear = 0usize;
536        for (d, &dim_size) in dims.iter().enumerate() {
537            let coord = multi_index[d][pos];
538            if coord >= dim_size {
539                return Err(FerrayError::index_out_of_bounds(
540                    coord as isize,
541                    d,
542                    dim_size,
543                ));
544            }
545            linear += coord * strides[d];
546        }
547        flat.push(linear);
548    }
549
550    Ok(flat)
551}
552
553/// Convert flat indices to a tuple of coordinate arrays.
554///
555/// Equivalent to `np.unravel_index(indices, shape)`.
556/// Uses row-major (C) ordering.
557///
558/// # Errors
559/// - `IndexOutOfBounds` if any flat index >= product(shape)
560pub fn unravel_index(flat_indices: &[usize], shape: &[usize]) -> FerrayResult<Vec<Vec<usize>>> {
561    let total: usize = shape.iter().product();
562    let ndim = shape.len();
563    let n = flat_indices.len();
564
565    let mut result: Vec<Vec<usize>> = vec![Vec::with_capacity(n); ndim];
566
567    for &flat_idx in flat_indices {
568        if flat_idx >= total {
569            return Err(FerrayError::index_out_of_bounds(
570                flat_idx as isize,
571                0,
572                total,
573            ));
574        }
575        let mut rem = flat_idx;
576        for (d, &dim_size) in shape.iter().enumerate().rev() {
577            result[d].push(rem % dim_size);
578            rem /= dim_size;
579        }
580    }
581
582    Ok(result)
583}
584
585// ===========================================================================
586// flatnonzero
587// ===========================================================================
588
589/// Return the indices of non-zero elements in the flattened array.
590///
591/// Equivalent to `np.flatnonzero(a)`. An element is "non-zero" if it
592/// is not equal to the type's zero value.
593pub fn flatnonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<usize> {
594    let zero = T::zero();
595    a.inner
596        .iter()
597        .enumerate()
598        .filter_map(|(i, val)| if *val == zero { None } else { Some(i) })
599        .collect()
600}
601
602// ===========================================================================
603// nonzero / argwhere (#373)
604// ===========================================================================
605
606/// Return the indices of non-zero elements, one index vector per axis.
607///
608/// Equivalent to `np.nonzero(a)`. For an N-dimensional array with K
609/// non-zero elements, the returned `Vec<Vec<usize>>` has length N, and
610/// each inner vector has length K. `result[d][i]` is the coordinate
611/// along axis `d` of the `i`-th non-zero element (in row-major order).
612///
613/// For a 1-D array this is equivalent to wrapping `flatnonzero` in a
614/// single-element outer vector. For 2-D, the two inner vectors give
615/// (`row_indices`, `col_indices`) — the typical pair used to reconstruct
616/// sparse matrices or index back into the original array.
617///
618/// Returns `usize` coordinates because `usize` is not an `Element` type;
619/// callers who need a `Array<i64, _>` can cast via
620/// [`argwhere`] or by wrapping each inner vector in an `Array<i64, Ix1>`.
621pub fn nonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<Vec<usize>> {
622    let zero = T::zero();
623    let ndim = a.ndim();
624    let mut result: Vec<Vec<usize>> = vec![Vec::new(); ndim];
625    for (idx, val) in a.indexed_iter() {
626        if *val != zero {
627            for (d, &c) in idx.iter().enumerate() {
628                result[d].push(c);
629            }
630        }
631    }
632    result
633}
634
635/// Return the coordinates of non-zero elements as a 2-D `(N, ndim)` array.
636///
637/// Equivalent to `np.argwhere(a)`. Each row of the result is the
638/// multi-index of one non-zero element, in row-major order. The output
639/// dtype is `i64` to match `NumPy`'s default signed integer index type
640/// (and because `usize` is not an `Element`). For a 0-D (scalar) input
641/// the result has shape `(0, 0)` or `(1, 0)` depending on whether the
642/// scalar is zero.
643///
644/// # Errors
645/// Returns an error only if constructing the result `Array` fails, which
646/// should never happen for a well-formed input.
647pub fn argwhere<T: Element + PartialEq, D: Dimension>(
648    a: &Array<T, D>,
649) -> FerrayResult<Array<i64, Ix2>> {
650    let zero = T::zero();
651    let ndim = a.ndim();
652    let mut data: Vec<i64> = Vec::new();
653    let mut count: usize = 0;
654    for (idx, val) in a.indexed_iter() {
655        if *val != zero {
656            for &c in &idx {
657                data.push(c as i64);
658            }
659            count += 1;
660        }
661    }
662    Array::<i64, Ix2>::from_vec(Ix2::new([count, ndim]), data)
663}
664
665// ===========================================================================
666// ndindex / ndenumerate iterators
667// ===========================================================================
668
/// An iterator over all multi-dimensional indices for a given shape.
///
/// Equivalent to `np.ndindex(*shape)`. Yields indices in row-major order (the
/// last axis varies fastest). A shape containing a zero-length axis yields
/// nothing. An empty shape (0-D) yields a single empty index.
///
/// Once exhausted, the iterator keeps returning `None` (it is fused).
#[derive(Debug, Clone)]
pub struct NdIndex {
    shape: Vec<usize>,
    current: Vec<usize>,
    done: bool,
}

impl NdIndex {
    fn new(shape: &[usize]) -> Self {
        Self {
            shape: shape.to_vec(),
            current: vec![0; shape.len()],
            // No index exists when any axis is empty.
            done: shape.contains(&0),
        }
    }

    /// Number of indices still to be yielded. Returns `None` if that count
    /// does not fit in `usize`, which is possible for shapes no real array
    /// could have.
    fn remaining(&self) -> Option<usize> {
        if self.done {
            return Some(0);
        }
        // remaining = 1 (the pending `current`)
        //           + sum_d (shape[d] - 1 - current[d]) * stride_d
        // While not done, every extent is >= 1 and current[d] < shape[d], so
        // the subtraction cannot underflow.
        let mut remaining = 1usize;
        let mut stride = Some(1usize);
        for (&extent, &pos) in self.shape.iter().zip(&self.current).rev() {
            let steps_left = extent - 1 - pos;
            if steps_left > 0 {
                // An overflowed stride only matters when it is multiplied by
                // a non-zero step count, in which case the total overflows.
                remaining = remaining.checked_add(steps_left.checked_mul(stride?)?)?;
            }
            stride = stride.and_then(|s| s.checked_mul(extent));
        }
        Some(remaining)
    }
}

impl Iterator for NdIndex {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        let result = self.current.clone();

        // Advance like an odometer: bump the last axis and carry leftwards,
        // stopping as soon as an axis does not wrap.
        let mut wrapped = true;
        for (pos, &extent) in self.current.iter_mut().zip(&self.shape).rev() {
            *pos += 1;
            if *pos < extent {
                wrapped = false;
                break;
            }
            *pos = 0;
        }
        // Carrying out of the first axis means every index has been produced.
        if wrapped {
            self.done = true;
        }

        Some(result)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        match self.remaining() {
            Some(n) => (n, Some(n)),
            // The count does not fit in usize. This is the conventional hint.
            None => (usize::MAX, None),
        }
    }
}

// `done` is sticky, so `next` keeps returning `None` after exhaustion.
impl std::iter::FusedIterator for NdIndex {}

/// Create an iterator over all multi-dimensional indices for a shape.
///
/// Equivalent to `np.ndindex(*shape)`.
#[must_use]
pub fn ndindex(shape: &[usize]) -> NdIndex {
    NdIndex::new(shape)
}
744
745/// Create an iterator yielding `(index, &value)` pairs.
746///
747/// Equivalent to `np.ndenumerate(a)`.
748pub fn ndenumerate<T: Element, D: Dimension>(
749    a: &Array<T, D>,
750) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
751    let shape = a.shape().to_vec();
752    let ndim = shape.len();
753    a.inner.iter().enumerate().map(move |(flat_idx, val)| {
754        let mut idx = vec![0usize; ndim];
755        let mut rem = flat_idx;
756        for (d, s) in shape.iter().enumerate().rev() {
757            if *s > 0 {
758                idx[d] = rem % s;
759                rem /= s;
760            }
761        }
762        (idx, val)
763    })
764}
765
766// ---------------------------------------------------------------------------
767// where_select — ternary selection (np.where equivalent)
768// ---------------------------------------------------------------------------
769
770/// Select elements from `x` or `y` depending on `condition`.
771///
772/// For each element position, returns `x[i]` where `condition[i]` is `true`,
773/// and `y[i]` where `condition[i]` is `false`.
774///
775/// All three arrays must have the same shape.
776///
777/// Equivalent to `numpy.where(condition, x, y)`.
778///
779/// # Errors
780/// Returns `FerrayError::ShapeMismatch` if shapes differ.
781pub fn where_select<T: Element + Copy, D: Dimension>(
782    condition: &Array<bool, D>,
783    x: &Array<T, D>,
784    y: &Array<T, D>,
785) -> FerrayResult<Array<T, D>> {
786    if condition.shape() != x.shape() || condition.shape() != y.shape() {
787        return Err(FerrayError::shape_mismatch(format!(
788            "where_select: condition shape {:?}, x shape {:?}, y shape {:?} must all match",
789            condition.shape(),
790            x.shape(),
791            y.shape()
792        )));
793    }
794    let data: Vec<T> = condition
795        .iter()
796        .zip(x.iter().zip(y.iter()))
797        .map(|(&c, (&xi, &yi))| if c { xi } else { yi })
798        .collect();
799    Array::from_vec(x.dim().clone(), data)
800}
801
802#[cfg(test)]
803mod tests {
804    use super::*;
805    use crate::dimension::{Ix1, Ix2};
806
807    // -----------------------------------------------------------------------
808    // take
809    // -----------------------------------------------------------------------
810
811    #[test]
812    fn take_1d() {
813        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
814        let taken = take(&arr, &[0, 2, 4], Axis(0)).unwrap();
815        assert_eq!(taken.shape(), &[3]);
816        let data: Vec<i32> = taken.iter().copied().collect();
817        assert_eq!(data, vec![10, 30, 50]);
818    }
819
820    #[test]
821    fn take_2d_axis1() {
822        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
823        let taken = take(&arr, &[0, 2], Axis(1)).unwrap();
824        assert_eq!(taken.shape(), &[3, 2]);
825        let data: Vec<i32> = taken.iter().copied().collect();
826        assert_eq!(data, vec![0, 2, 4, 6, 8, 10]);
827    }
828
829    #[test]
830    fn take_negative_indices() {
831        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
832        let taken = take(&arr, &[-1, -3], Axis(0)).unwrap();
833        let data: Vec<i32> = taken.iter().copied().collect();
834        assert_eq!(data, vec![40, 20]);
835    }
836
837    // -----------------------------------------------------------------------
838    // take_along_axis
839    // -----------------------------------------------------------------------
840
841    #[test]
842    fn take_along_axis_basic() {
843        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
844        let taken = take_along_axis(&arr, &[1, 3], Axis(1)).unwrap();
845        assert_eq!(taken.shape(), &[3, 2]);
846    }
847
848    // -----------------------------------------------------------------------
849    // put
850    // -----------------------------------------------------------------------
851
852    #[test]
853    fn put_flat() {
854        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 0, 0, 0, 0]).unwrap();
855        arr.put(&[1, 3], &[99, 88]).unwrap();
856        assert_eq!(arr.as_slice().unwrap(), &[0, 99, 0, 88, 0]);
857    }
858
859    #[test]
860    fn put_cycling_values() {
861        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0; 5]).unwrap();
862        arr.put(&[0, 1, 2, 3, 4], &[10, 20]).unwrap();
863        assert_eq!(arr.as_slice().unwrap(), &[10, 20, 10, 20, 10]);
864    }
865
866    #[test]
867    fn put_out_of_bounds() {
868        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
869        assert!(arr.put(&[5], &[1]).is_err());
870    }
871
872    // -----------------------------------------------------------------------
873    // fill_diagonal
874    // -----------------------------------------------------------------------
875
876    #[test]
877    fn fill_diagonal_2d() {
878        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 3]), vec![0; 9]).unwrap();
879        arr.fill_diagonal(1);
880        let data: Vec<i32> = arr.iter().copied().collect();
881        assert_eq!(data, vec![1, 0, 0, 0, 1, 0, 0, 0, 1]);
882    }
883
884    #[test]
885    fn fill_diagonal_rectangular() {
886        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 4]), vec![0; 8]).unwrap();
887        arr.fill_diagonal(5);
888        let data: Vec<i32> = arr.iter().copied().collect();
889        assert_eq!(data, vec![5, 0, 0, 0, 0, 5, 0, 0]);
890    }
891
892    // -----------------------------------------------------------------------
893    // choose
894    // -----------------------------------------------------------------------
895
896    #[test]
897    fn choose_basic() {
898        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([4]), vec![0, 1, 0, 1]).unwrap();
899        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
900        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![100, 200, 300, 400]).unwrap();
901        let result = choose(&idx, &[c0, c1]).unwrap();
902        let data: Vec<i32> = result.iter().copied().collect();
903        assert_eq!(data, vec![10, 200, 30, 400]);
904    }
905
906    #[test]
907    fn choose_out_of_bounds() {
908        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([2]), vec![0, 2]).unwrap();
909        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 2]).unwrap();
910        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![3, 4]).unwrap();
911        assert!(choose(&idx, &[c0, c1]).is_err());
912    }
913
914    // -----------------------------------------------------------------------
915    // compress
916    // -----------------------------------------------------------------------
917
918    #[test]
919    fn compress_1d() {
920        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
921        let result = compress(&[true, false, true, false, true], &arr, Axis(0)).unwrap();
922        let data: Vec<i32> = result.iter().copied().collect();
923        assert_eq!(data, vec![10, 30, 50]);
924    }
925
926    #[test]
927    fn compress_2d_axis0() {
928        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
929        let result = compress(&[true, false, true], &arr, Axis(0)).unwrap();
930        assert_eq!(result.shape(), &[2, 4]);
931        let data: Vec<i32> = result.iter().copied().collect();
932        assert_eq!(data, vec![0, 1, 2, 3, 8, 9, 10, 11]);
933    }
934
935    // -----------------------------------------------------------------------
936    // select
937    // -----------------------------------------------------------------------
938
939    #[test]
940    fn select_basic() {
941        let c1 =
942            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, false]).unwrap();
943        let c2 =
944            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, false]).unwrap();
945        let ch1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 1, 1, 1]).unwrap();
946        let ch2 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![2, 2, 2, 2]).unwrap();
947        let result = select(&[c1, c2], &[ch1, ch2], 0).unwrap();
948        let data: Vec<i32> = result.iter().copied().collect();
949        assert_eq!(data, vec![1, 2, 0, 0]);
950    }
951
952    // -----------------------------------------------------------------------
953    // indices
954    // -----------------------------------------------------------------------
955
956    #[test]
957    fn indices_2d() {
958        let idx = indices(&[2, 3]).unwrap();
959        assert_eq!(idx.len(), 2);
960        assert_eq!(idx[0].shape(), &[2, 3]);
961        assert_eq!(idx[1].shape(), &[2, 3]);
962        let rows: Vec<u64> = idx[0].iter().copied().collect();
963        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1]);
964        let cols: Vec<u64> = idx[1].iter().copied().collect();
965        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2]);
966    }
967
968    // -----------------------------------------------------------------------
969    // ix_
970    // -----------------------------------------------------------------------
971
972    #[test]
973    fn ix_basic() {
974        let result = ix_(&[&[0, 1], &[2, 3, 4]]).unwrap();
975        assert_eq!(result.len(), 2);
976        assert_eq!(result[0].shape(), &[2, 1]);
977        assert_eq!(result[1].shape(), &[1, 3]);
978    }
979
980    // -----------------------------------------------------------------------
981    // diag_indices
982    // -----------------------------------------------------------------------
983
984    #[test]
985    fn diag_indices_basic() {
986        let idx = diag_indices(3, 2);
987        assert_eq!(idx.len(), 2);
988        assert_eq!(idx[0], vec![0, 1, 2]);
989        assert_eq!(idx[1], vec![0, 1, 2]);
990    }
991
992    #[test]
993    fn diag_indices_from_square() {
994        let arr = Array::<f64, Ix2>::zeros(Ix2::new([4, 4])).unwrap();
995        let idx = diag_indices_from(&arr).unwrap();
996        assert_eq!(idx.len(), 2);
997        assert_eq!(idx[0].len(), 4);
998    }
999
1000    #[test]
1001    fn diag_indices_from_not_square() {
1002        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
1003        assert!(diag_indices_from(&arr).is_err());
1004    }
1005
1006    // -----------------------------------------------------------------------
1007    // tril_indices / triu_indices
1008    // -----------------------------------------------------------------------
1009
1010    #[test]
1011    fn tril_indices_basic() {
1012        let (rows, cols) = tril_indices(3, 0, None);
1013        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
1014        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
1015    }
1016
1017    #[test]
1018    fn triu_indices_basic() {
1019        let (rows, cols) = triu_indices(3, 0, None);
1020        assert_eq!(rows, vec![0, 0, 0, 1, 1, 2]);
1021        assert_eq!(cols, vec![0, 1, 2, 1, 2, 2]);
1022    }
1023
1024    #[test]
1025    fn tril_indices_with_k() {
1026        let (rows, cols) = tril_indices(3, 1, None);
1027        assert_eq!(rows, vec![0, 0, 1, 1, 1, 2, 2, 2]);
1028        assert_eq!(cols, vec![0, 1, 0, 1, 2, 0, 1, 2]);
1029    }
1030
1031    #[test]
1032    fn triu_indices_with_negative_k() {
1033        let (rows, cols) = triu_indices(3, -1, None);
1034        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1, 2, 2]);
1035        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2, 1, 2]);
1036    }
1037
1038    #[test]
1039    fn tril_indices_from_test() {
1040        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
1041        let (rows, _cols) = tril_indices_from(&arr, 0).unwrap();
1042        assert_eq!(rows.len(), 6);
1043    }
1044
1045    #[test]
1046    fn triu_indices_from_test() {
1047        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
1048        let (rows, _cols) = triu_indices_from(&arr, 0).unwrap();
1049        assert_eq!(rows.len(), 6);
1050    }
1051
1052    #[test]
1053    fn tril_indices_rectangular() {
1054        let (rows, cols) = tril_indices(3, 0, Some(4));
1055        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
1056        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
1057    }
1058
1059    // -----------------------------------------------------------------------
1060    // ravel_multi_index / unravel_index
1061    // -----------------------------------------------------------------------
1062
1063    #[test]
1064    fn ravel_multi_index_basic() {
1065        let flat = ravel_multi_index(&[&[0, 1, 2], &[1, 2, 0]], &[3, 4]).unwrap();
1066        assert_eq!(flat, vec![1, 6, 8]);
1067    }
1068
1069    #[test]
1070    fn ravel_multi_index_3d() {
1071        let flat = ravel_multi_index(&[&[0], &[1], &[2]], &[2, 3, 4]).unwrap();
1072        assert_eq!(flat, vec![6]);
1073    }
1074
1075    #[test]
1076    fn ravel_multi_index_out_of_bounds() {
1077        assert!(ravel_multi_index(&[&[3]], &[3]).is_err());
1078    }
1079
1080    #[test]
1081    fn unravel_index_basic() {
1082        let coords = unravel_index(&[1, 6, 8], &[3, 4]).unwrap();
1083        assert_eq!(coords[0], vec![0, 1, 2]);
1084        assert_eq!(coords[1], vec![1, 2, 0]);
1085    }
1086
1087    #[test]
1088    fn unravel_index_out_of_bounds() {
1089        assert!(unravel_index(&[12], &[3, 4]).is_err());
1090    }
1091
1092    #[test]
1093    fn ravel_unravel_roundtrip() {
1094        let dims = &[3, 4, 5];
1095        let a: &[usize] = &[1, 2];
1096        let b: &[usize] = &[2, 3];
1097        let c: &[usize] = &[3, 4];
1098        let multi: &[&[usize]] = &[a, b, c];
1099        let flat = ravel_multi_index(multi, dims).unwrap();
1100        let coords = unravel_index(&flat, dims).unwrap();
1101        assert_eq!(coords[0], vec![1, 2]);
1102        assert_eq!(coords[1], vec![2, 3]);
1103        assert_eq!(coords[2], vec![3, 4]);
1104    }
1105
1106    // -----------------------------------------------------------------------
1107    // flatnonzero
1108    // -----------------------------------------------------------------------
1109
1110    #[test]
1111    fn flatnonzero_basic() {
1112        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
1113        let nz = flatnonzero(&arr);
1114        assert_eq!(nz, vec![1, 3]);
1115    }
1116
1117    #[test]
1118    fn flatnonzero_2d() {
1119        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
1120        let nz = flatnonzero(&arr);
1121        assert_eq!(nz, vec![1, 3, 5]);
1122    }
1123
1124    #[test]
1125    fn flatnonzero_all_zero() {
1126        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
1127        let nz = flatnonzero(&arr);
1128        assert_eq!(nz.len(), 0);
1129    }
1130
1131    // -----------------------------------------------------------------------
1132    // nonzero / argwhere (#373)
1133    // -----------------------------------------------------------------------
1134
1135    #[test]
1136    fn nonzero_1d() {
1137        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
1138        let nz = nonzero(&arr);
1139        // One axis, one inner vec with the hit positions.
1140        assert_eq!(nz.len(), 1);
1141        assert_eq!(nz[0], vec![1, 3]);
1142    }
1143
1144    #[test]
1145    fn nonzero_2d_yields_row_and_col_indices() {
1146        // [[0, 1, 0],
1147        //  [2, 0, 3]]
1148        // Non-zero coordinates (row-major): (0,1), (1,0), (1,2).
1149        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
1150        let nz = nonzero(&arr);
1151        assert_eq!(nz.len(), 2);
1152        assert_eq!(nz[0], vec![0, 1, 1]);
1153        assert_eq!(nz[1], vec![1, 0, 2]);
1154    }
1155
1156    #[test]
1157    fn nonzero_all_zero_returns_empty_per_axis() {
1158        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
1159        let nz = nonzero(&arr);
1160        assert_eq!(nz.len(), 2);
1161        assert!(nz[0].is_empty());
1162        assert!(nz[1].is_empty());
1163    }
1164
1165    #[test]
1166    fn nonzero_f64_treats_negative_zero_as_zero() {
1167        // -0.0 == 0.0 for PartialEq, so -0.0 is "zero" per numpy semantics.
1168        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![-0.0, 1.5, 0.0, -2.5]).unwrap();
1169        let nz = nonzero(&arr);
1170        assert_eq!(nz[0], vec![1, 3]);
1171    }
1172
1173    #[test]
1174    fn argwhere_2d_has_one_row_per_nonzero() {
1175        // Same input as nonzero_2d_yields_row_and_col_indices.
1176        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
1177        let coords = argwhere(&arr).unwrap();
1178        assert_eq!(coords.shape(), &[3, 2]);
1179        assert_eq!(coords.as_slice().unwrap(), &[0, 1, 1, 0, 1, 2]);
1180    }
1181
1182    #[test]
1183    fn argwhere_1d_is_column_vector() {
1184        // A (K, 1) shape means K non-zero rows for a 1-D array.
1185        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 7, 0, 9, 3]).unwrap();
1186        let coords = argwhere(&arr).unwrap();
1187        assert_eq!(coords.shape(), &[3, 1]);
1188        assert_eq!(coords.as_slice().unwrap(), &[1, 3, 4]);
1189    }
1190
1191    #[test]
1192    fn argwhere_all_zero_returns_empty() {
1193        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
1194        let coords = argwhere(&arr).unwrap();
1195        assert_eq!(coords.shape(), &[0, 2]);
1196        assert_eq!(coords.size(), 0);
1197    }
1198
1199    // -----------------------------------------------------------------------
1200    // ndindex
1201    // -----------------------------------------------------------------------
1202
1203    #[test]
1204    fn ndindex_2d() {
1205        let indices: Vec<Vec<usize>> = ndindex(&[2, 3]).collect();
1206        assert_eq!(indices.len(), 6);
1207        assert_eq!(indices[0], vec![0, 0]);
1208        assert_eq!(indices[1], vec![0, 1]);
1209        assert_eq!(indices[2], vec![0, 2]);
1210        assert_eq!(indices[3], vec![1, 0]);
1211        assert_eq!(indices[4], vec![1, 1]);
1212        assert_eq!(indices[5], vec![1, 2]);
1213    }
1214
1215    #[test]
1216    fn ndindex_1d() {
1217        let indices: Vec<Vec<usize>> = ndindex(&[4]).collect();
1218        assert_eq!(indices.len(), 4);
1219        assert_eq!(indices[0], vec![0]);
1220        assert_eq!(indices[3], vec![3]);
1221    }
1222
1223    #[test]
1224    fn ndindex_empty() {
1225        assert_eq!(ndindex(&[0]).count(), 0);
1226    }
1227
1228    #[test]
1229    fn ndindex_scalar() {
1230        let indices: Vec<Vec<usize>> = ndindex(&[]).collect();
1231        assert_eq!(indices.len(), 1);
1232        assert_eq!(indices[0], Vec::<usize>::new());
1233    }
1234
1235    // -----------------------------------------------------------------------
1236    // ndenumerate
1237    // -----------------------------------------------------------------------
1238
1239    #[test]
1240    fn ndenumerate_2d() {
1241        let arr =
1242            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
1243        let items: Vec<(Vec<usize>, &i32)> = ndenumerate(&arr).collect();
1244        assert_eq!(items.len(), 6);
1245        assert_eq!(items[0], (vec![0, 0], &10));
1246        assert_eq!(items[1], (vec![0, 1], &20));
1247        assert_eq!(items[5], (vec![1, 2], &60));
1248    }
1249
1250    // -----------------------------------------------------------------------
1251    // put_along_axis
1252    // -----------------------------------------------------------------------
1253
1254    #[test]
1255    fn put_along_axis_basic() {
1256        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), vec![0; 12]).unwrap();
1257        let values =
1258            Array::<i32, IxDyn>::from_vec(IxDyn::new(&[8]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
1259        arr.put_along_axis(&[0, 2], &values, Axis(0)).unwrap();
1260        let data: Vec<i32> = arr.iter().copied().collect();
1261        assert_eq!(data, vec![1, 2, 3, 4, 0, 0, 0, 0, 5, 6, 7, 8]);
1262    }
1263
1264    // -----------------------------------------------------------------------
1265    // where_
1266    // -----------------------------------------------------------------------
1267
1268    #[test]
1269    fn where_basic() {
1270        let cond =
1271            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
1272        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1273        let y = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
1274        let result = where_select(&cond, &x, &y).unwrap();
1275        assert_eq!(result.as_slice().unwrap(), &[1.0, 20.0, 3.0, 40.0]);
1276    }
1277
1278    #[test]
1279    fn where_all_true() {
1280        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
1281        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
1282        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
1283        let result = where_select(&cond, &x, &y).unwrap();
1284        assert_eq!(result.as_slice().unwrap(), &[1, 2, 3]);
1285    }
1286
1287    #[test]
1288    fn where_all_false() {
1289        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
1290        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
1291        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
1292        let result = where_select(&cond, &x, &y).unwrap();
1293        assert_eq!(result.as_slice().unwrap(), &[10, 20, 30]);
1294    }
1295
1296    #[test]
1297    fn where_shape_mismatch() {
1298        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true; 3]).unwrap();
1299        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0; 4]).unwrap();
1300        let y = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0; 3]).unwrap();
1301        assert!(where_select(&cond, &x, &y).is_err());
1302    }
1303
1304    #[test]
1305    fn where_2d() {
1306        let cond =
1307            Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![true, false, false, true]).unwrap();
1308        let x = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
1309        let y = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![10, 20, 30, 40]).unwrap();
1310        let result = where_select(&cond, &x, &y).unwrap();
1311        let data: Vec<i32> = result.iter().copied().collect();
1312        assert_eq!(data, vec![1, 20, 30, 4]);
1313    }
1314}