Skip to main content

ferray_core/indexing/
extended.rs

1// ferray-core: Extended indexing functions (REQ-15a)
2//
3// take, take_along_axis, put, put_along_axis, choose, compress, select,
4// indices, ix_, diag_indices, diag_indices_from, tril_indices, triu_indices,
5// tril_indices_from, triu_indices_from, ravel_multi_index, unravel_index,
// flatnonzero, nonzero, argwhere, fill_diagonal, ndindex (iterator),
// ndenumerate (iterator), where_select
7//
8// Index-returning functions return Vec<Vec<usize>> or (Vec<usize>, Vec<usize>)
9// because usize is not an Element type. Callers can wrap these into arrays
10// of u64 or i64 if needed.
11
12use super::normalize_index;
13use crate::array::owned::Array;
14use crate::dimension::{Axis, Dimension, Ix2, IxDyn};
15use crate::dtype::Element;
16use crate::error::{FerrayError, FerrayResult};
17
18// ===========================================================================
19// take / take_along_axis
20// ===========================================================================
21
22/// Take elements from an array along an axis.
23///
24/// Equivalent to `np.take(a, indices, axis)`. Returns a copy.
25///
26/// # Errors
27/// - `AxisOutOfBounds` if `axis >= ndim`
28/// - `IndexOutOfBounds` if any index is out of range
29pub fn take<T: Element, D: Dimension>(
30    a: &Array<T, D>,
31    indices: &[isize],
32    axis: Axis,
33) -> FerrayResult<Array<T, IxDyn>> {
34    a.index_select(axis, indices)
35}
36
37/// Take values from an array along an axis using an index slice.
38///
39/// Similar to `np.take_along_axis`. The `indices` slice contains
40/// indices into `a` along the specified axis. The result replaces
41/// the `axis` dimension with the indices dimension.
42///
43/// # Errors
44/// - `AxisOutOfBounds` if `axis >= ndim`
45/// - `IndexOutOfBounds` if any index is out of range
46pub fn take_along_axis<T: Element, D: Dimension>(
47    a: &Array<T, D>,
48    indices: &[isize],
49    axis: Axis,
50) -> FerrayResult<Array<T, IxDyn>> {
51    a.index_select(axis, indices)
52}
53
54// ===========================================================================
55// put / put_along_axis
56// ===========================================================================
57
58impl<T: Element, D: Dimension> Array<T, D> {
59    /// Put values into the flattened array at the given indices.
60    ///
61    /// Equivalent to `np.put(a, ind, v)`. Modifies the array in-place.
62    /// Indices refer to the flattened (row-major) array.
63    /// Values are cycled if fewer than indices.
64    ///
65    /// # Errors
66    /// - `IndexOutOfBounds` if any index is out of range
67    /// - `InvalidValue` if values is empty
68    pub fn put(&mut self, indices: &[isize], values: &[T]) -> FerrayResult<()> {
69        if values.is_empty() {
70            return Err(FerrayError::invalid_value("values must not be empty"));
71        }
72        let size = self.size();
73        let normalized: Vec<usize> = indices
74            .iter()
75            .map(|&idx| normalize_index(idx, size, 0))
76            .collect::<FerrayResult<Vec<_>>>()?;
77
78        let mut flat: Vec<&mut T> = self.inner.iter_mut().collect();
79
80        for (i, &idx) in normalized.iter().enumerate() {
81            let val_idx = i % values.len();
82            *flat[idx] = values[val_idx].clone();
83        }
84        Ok(())
85    }
86
87    /// Put values along an axis at specified indices.
88    ///
89    /// For each index position along `axis`, assigns the values from the
90    /// corresponding sub-array of `values`.
91    ///
92    /// # Errors
93    /// - `AxisOutOfBounds` if `axis >= ndim`
94    /// - `IndexOutOfBounds` if any index is out of range
95    pub fn put_along_axis(
96        &mut self,
97        indices: &[isize],
98        values: &Array<T, IxDyn>,
99        axis: Axis,
100    ) -> FerrayResult<()>
101    where
102        D::NdarrayDim: ndarray::RemoveAxis,
103    {
104        let ndim = self.ndim();
105        let ax = axis.index();
106        if ax >= ndim {
107            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
108        }
109        let axis_size = self.shape()[ax];
110
111        let normalized: Vec<usize> = indices
112            .iter()
113            .map(|&idx| normalize_index(idx, axis_size, ax))
114            .collect::<FerrayResult<Vec<_>>>()?;
115
116        let nd_axis = ndarray::Axis(ax);
117        let mut val_iter = values.inner.iter();
118
119        for &idx in &normalized {
120            let mut sub = self.inner.index_axis_mut(nd_axis, idx);
121            for elem in sub.iter_mut() {
122                if let Some(v) = val_iter.next() {
123                    *elem = v.clone();
124                }
125            }
126        }
127        Ok(())
128    }
129
130    /// Fill the main diagonal of a 2-D (or N-D) array with a value.
131    ///
132    /// For N-D arrays, the diagonal consists of indices where all
133    /// index values are equal: `a[i, i, ..., i]`.
134    ///
135    /// Equivalent to `np.fill_diagonal(a, val)`.
136    pub fn fill_diagonal(&mut self, val: T) {
137        let shape = self.shape().to_vec();
138        if shape.is_empty() {
139            return;
140        }
141        let min_dim = *shape.iter().min().unwrap_or(&0);
142        let ndim = shape.len();
143
144        for i in 0..min_dim {
145            let idx: Vec<usize> = vec![i; ndim];
146            let nd_idx = ndarray::IxDyn(&idx);
147            let mut dyn_view = self.inner.view_mut().into_dyn();
148            dyn_view[nd_idx] = val.clone();
149        }
150    }
151}
152
153// ===========================================================================
154// choose
155// ===========================================================================
156
157/// Construct an array from an index array and a list of arrays to choose from.
158///
159/// Equivalent to `np.choose(a, choices)`. For each element in `index_arr`,
160/// the value selects which choice array to pick from at that position.
161/// Index values are given as `u64` to avoid the `usize` Element issue.
162///
163/// # Errors
164/// - `IndexOutOfBounds` if any index in `index_arr` is >= `choices.len()`
165/// - `ShapeMismatch` if choice arrays have different shapes from `index_arr`
166pub fn choose<T: Element, D: Dimension>(
167    index_arr: &Array<u64, D>,
168    choices: &[Array<T, D>],
169) -> FerrayResult<Array<T, IxDyn>> {
170    if choices.is_empty() {
171        return Err(FerrayError::invalid_value("choices must not be empty"));
172    }
173
174    let shape = index_arr.shape();
175    for (i, c) in choices.iter().enumerate() {
176        if c.shape() != shape {
177            return Err(FerrayError::shape_mismatch(format!(
178                "choice[{}] shape {:?} does not match index array shape {:?}",
179                i,
180                c.shape(),
181                shape
182            )));
183        }
184    }
185
186    let n_choices = choices.len();
187    let choice_iters: Vec<Vec<T>> = choices
188        .iter()
189        .map(|c| c.inner.iter().cloned().collect())
190        .collect();
191
192    let mut data = Vec::with_capacity(index_arr.size());
193    for (pos, idx_val) in index_arr.inner.iter().enumerate() {
194        let idx = *idx_val as usize;
195        if idx >= n_choices {
196            return Err(FerrayError::index_out_of_bounds(idx as isize, 0, n_choices));
197        }
198        data.push(choice_iters[idx][pos].clone());
199    }
200
201    let dyn_shape = IxDyn::new(shape);
202    Array::from_vec(dyn_shape, data)
203}
204
205// ===========================================================================
206// compress
207// ===========================================================================
208
209/// Select slices of an array along an axis where `condition` is true.
210///
211/// Equivalent to `np.compress(condition, a, axis)`.
212///
213/// # Errors
214/// - `AxisOutOfBounds` if `axis >= ndim`
215/// - `ShapeMismatch` if `condition.len()` exceeds axis size
216pub fn compress<T: Element, D: Dimension>(
217    condition: &[bool],
218    a: &Array<T, D>,
219    axis: Axis,
220) -> FerrayResult<Array<T, IxDyn>> {
221    let ndim = a.ndim();
222    let ax = axis.index();
223    if ax >= ndim {
224        return Err(FerrayError::axis_out_of_bounds(ax, ndim));
225    }
226    let axis_size = a.shape()[ax];
227    if condition.len() > axis_size {
228        return Err(FerrayError::shape_mismatch(format!(
229            "condition length {} exceeds axis size {}",
230            condition.len(),
231            axis_size
232        )));
233    }
234
235    let indices: Vec<isize> = condition
236        .iter()
237        .enumerate()
238        .filter_map(|(i, &c)| if c { Some(i as isize) } else { None })
239        .collect();
240
241    a.index_select(axis, &indices)
242}
243
244// ===========================================================================
245// select
246// ===========================================================================
247
248/// Return an array drawn from elements in choicelist, depending on conditions.
249///
250/// Equivalent to `np.select(condlist, choicelist, default)`.
251/// The first condition that is true determines which choice is used.
252///
253/// # Errors
254/// - `InvalidValue` if condlist and choicelist have different lengths
255/// - `ShapeMismatch` if shapes are incompatible
256pub fn select<T: Element, D: Dimension>(
257    condlist: &[Array<bool, D>],
258    choicelist: &[Array<T, D>],
259    default: T,
260) -> FerrayResult<Array<T, IxDyn>> {
261    if condlist.len() != choicelist.len() {
262        return Err(FerrayError::invalid_value(format!(
263            "condlist length {} != choicelist length {}",
264            condlist.len(),
265            choicelist.len()
266        )));
267    }
268    if condlist.is_empty() {
269        return Err(FerrayError::invalid_value(
270            "condlist and choicelist must not be empty",
271        ));
272    }
273
274    let shape = condlist[0].shape();
275    for (i, (c, ch)) in condlist.iter().zip(choicelist.iter()).enumerate() {
276        if c.shape() != shape || ch.shape() != shape {
277            return Err(FerrayError::shape_mismatch(format!(
278                "condlist[{}]/choicelist[{}] shape mismatch with reference shape {:?}",
279                i, i, shape
280            )));
281        }
282    }
283
284    let size = condlist[0].size();
285    let mut data = vec![default; size];
286
287    // Process in reverse order so first matching condition wins
288    for (cond, choice) in condlist.iter().zip(choicelist.iter()).rev() {
289        for (i, (&c, v)) in cond.inner.iter().zip(choice.inner.iter()).enumerate() {
290            if c {
291                data[i] = v.clone();
292            }
293        }
294    }
295
296    let dyn_shape = IxDyn::new(shape);
297    Array::from_vec(dyn_shape, data)
298}
299
300// ===========================================================================
301// indices
302// ===========================================================================
303
304/// Return arrays representing the indices of a grid.
305///
306/// Equivalent to `np.indices(dimensions)`. Returns one `u64` array per
307/// dimension, each with shape `dimensions`.
308///
309/// For example, `indices(&[2, 3])` returns two arrays of shape `[2, 3]`:
310/// the first contains row indices, the second column indices.
311pub fn indices(dimensions: &[usize]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
312    let ndim = dimensions.len();
313    let total: usize = dimensions.iter().product();
314
315    let mut result = Vec::with_capacity(ndim);
316
317    for ax in 0..ndim {
318        let mut data = Vec::with_capacity(total);
319        for flat_idx in 0..total {
320            let mut rem = flat_idx;
321            let mut idx_for_ax = 0;
322            for (d, &dim_size) in dimensions.iter().enumerate().rev() {
323                let coord = rem % dim_size;
324                rem /= dim_size;
325                if d == ax {
326                    idx_for_ax = coord;
327                }
328            }
329            data.push(idx_for_ax as u64);
330        }
331        let dim = IxDyn::new(dimensions);
332        result.push(Array::from_vec(dim, data)?);
333    }
334
335    Ok(result)
336}
337
338// ===========================================================================
339// ix_
340// ===========================================================================
341
342/// Construct an open mesh from multiple sequences.
343///
344/// Equivalent to `np.ix_(*args)`. Returns a list of arrays, each with
345/// shape `(1, 1, ..., N, ..., 1)` where `N` is the length of that sequence
346/// and it appears in the position corresponding to its argument index.
347///
348/// This is useful for constructing index arrays for cross-indexing.
349pub fn ix_(sequences: &[&[u64]]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
350    let ndim = sequences.len();
351    let mut result = Vec::with_capacity(ndim);
352
353    for (i, seq) in sequences.iter().enumerate() {
354        let mut shape = vec![1usize; ndim];
355        shape[i] = seq.len();
356
357        let data = seq.to_vec();
358        let dim = IxDyn::new(&shape);
359        result.push(Array::from_vec(dim, data)?);
360    }
361
362    Ok(result)
363}
364
365// ===========================================================================
366// diag_indices / diag_indices_from
367// ===========================================================================
368
/// Produce the index vectors addressing the main diagonal of an
/// `n x n x ... x n` array.
///
/// Equivalent to `np.diag_indices(n, ndim=2)`. The result holds `ndim`
/// vectors, each equal to `[0, 1, ..., n-1]`.
pub fn diag_indices(n: usize, ndim: usize) -> Vec<Vec<usize>> {
    (0..ndim).map(|_| (0..n).collect()).collect()
}
377
378/// Return the indices to access the main diagonal of the given array.
379///
380/// The array must be at least 2-D and square (all dimensions equal).
381///
382/// # Errors
383/// - `InvalidValue` if the array has fewer than 2 dimensions
384/// - `ShapeMismatch` if dimensions are not all equal
385pub fn diag_indices_from<T: Element, D: Dimension>(
386    a: &Array<T, D>,
387) -> FerrayResult<Vec<Vec<usize>>> {
388    let ndim = a.ndim();
389    if ndim < 2 {
390        return Err(FerrayError::invalid_value(
391            "diag_indices_from requires at least 2 dimensions",
392        ));
393    }
394    let shape = a.shape();
395    let n = shape[0];
396    for &s in &shape[1..] {
397        if s != n {
398            return Err(FerrayError::shape_mismatch(format!(
399                "all dimensions must be equal for diag_indices_from, got {:?}",
400                shape
401            )));
402        }
403    }
404    Ok(diag_indices(n, ndim))
405}
406
407// ===========================================================================
408// tril_indices / triu_indices / tril_indices_from / triu_indices_from
409// ===========================================================================
410
/// Return the indices for the lower triangle of an (n, m) array.
///
/// Equivalent to `np.tril_indices(n, k, m)`.
/// `k` is the diagonal offset: 0 = main diagonal, positive = above,
/// negative = below. `m` defaults to `n` when `None`.
///
/// The result is `(rows, cols)`, listing every `(i, j)` with `j <= i + k`
/// in row-major order. Any `k` is accepted: offsets beyond the matrix simply
/// select everything (large positive) or nothing (large negative).
pub fn tril_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // Row `i` keeps columns `0..=i + k`, i.e. the first `i + k + 1`
        // columns. Saturating arithmetic keeps extreme `k` from overflowing.
        let limit = (i as isize).saturating_add(k).saturating_add(1);
        let end = if limit <= 0 {
            0
        } else {
            (limit as usize).min(m)
        };
        rows.resize(rows.len() + end, i);
        cols.extend(0..end);
    }

    (rows, cols)
}
432
/// Return the indices for the upper triangle of an (n, m) array.
///
/// Equivalent to `np.triu_indices(n, k, m)`. `k` is the diagonal offset
/// (0 = main diagonal, positive = above, negative = below) and `m` defaults
/// to `n` when `None`.
///
/// The result is `(rows, cols)`, listing every `(i, j)` with `j >= i + k`
/// in row-major order. Any `k` is accepted: offsets beyond the matrix simply
/// select nothing (large positive) or everything (large negative).
pub fn triu_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let m = m.unwrap_or(n);
    let mut rows = Vec::new();
    let mut cols = Vec::new();

    for i in 0..n {
        // Row `i` keeps columns `i + k..m`. Saturating arithmetic keeps
        // extreme `k` from overflowing; the start is clamped into `[0, m]`.
        let first = (i as isize).saturating_add(k);
        let start = if first <= 0 {
            0
        } else {
            (first as usize).min(m)
        };
        rows.resize(rows.len() + (m - start), i);
        cols.extend(start..m);
    }

    (rows, cols)
}
452
453/// Return the indices for the lower triangle of the given 2-D array.
454///
455/// # Errors
456/// - `InvalidValue` if the array is not 2-D
457pub fn tril_indices_from<T: Element, D: Dimension>(
458    a: &Array<T, D>,
459    k: isize,
460) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
461    let shape = a.shape();
462    if shape.len() != 2 {
463        return Err(FerrayError::invalid_value(
464            "tril_indices_from requires a 2-D array",
465        ));
466    }
467    Ok(tril_indices(shape[0], k, Some(shape[1])))
468}
469
470/// Return the indices for the upper triangle of the given 2-D array.
471///
472/// # Errors
473/// - `InvalidValue` if the array is not 2-D
474pub fn triu_indices_from<T: Element, D: Dimension>(
475    a: &Array<T, D>,
476    k: isize,
477) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
478    let shape = a.shape();
479    if shape.len() != 2 {
480        return Err(FerrayError::invalid_value(
481            "triu_indices_from requires a 2-D array",
482        ));
483    }
484    Ok(triu_indices(shape[0], k, Some(shape[1])))
485}
486
487// ===========================================================================
488// ravel_multi_index / unravel_index
489// ===========================================================================
490
491/// Convert a tuple of index arrays to a flat index array.
492///
493/// Equivalent to `np.ravel_multi_index(multi_index, dims)`.
494/// Uses row-major (C) ordering.
495///
496/// # Errors
497/// - `InvalidValue` if multi_index arrays have different lengths
498/// - `IndexOutOfBounds` if any index is out of range for its dimension
499#[allow(clippy::needless_range_loop)]
500pub fn ravel_multi_index(multi_index: &[&[usize]], dims: &[usize]) -> FerrayResult<Vec<usize>> {
501    if multi_index.len() != dims.len() {
502        return Err(FerrayError::invalid_value(format!(
503            "multi_index has {} components but dims has {} dimensions",
504            multi_index.len(),
505            dims.len()
506        )));
507    }
508    if multi_index.is_empty() {
509        return Ok(vec![]);
510    }
511
512    let n = multi_index[0].len();
513    for (i, idx_arr) in multi_index.iter().enumerate() {
514        if idx_arr.len() != n {
515            return Err(FerrayError::invalid_value(format!(
516                "multi_index[{}] has length {} but expected {}",
517                i,
518                idx_arr.len(),
519                n
520            )));
521        }
522    }
523
524    // Compute strides for C-order
525    let ndim = dims.len();
526    let mut strides = vec![1usize; ndim];
527    for i in (0..ndim - 1).rev() {
528        strides[i] = strides[i + 1] * dims[i + 1];
529    }
530
531    let mut flat = Vec::with_capacity(n);
532    #[allow(clippy::needless_range_loop)]
533    for pos in 0..n {
534        let mut linear = 0usize;
535        for (d, &dim_size) in dims.iter().enumerate() {
536            let coord = multi_index[d][pos];
537            if coord >= dim_size {
538                return Err(FerrayError::index_out_of_bounds(
539                    coord as isize,
540                    d,
541                    dim_size,
542                ));
543            }
544            linear += coord * strides[d];
545        }
546        flat.push(linear);
547    }
548
549    Ok(flat)
550}
551
552/// Convert flat indices to a tuple of coordinate arrays.
553///
554/// Equivalent to `np.unravel_index(indices, shape)`.
555/// Uses row-major (C) ordering.
556///
557/// # Errors
558/// - `IndexOutOfBounds` if any flat index >= product(shape)
559pub fn unravel_index(flat_indices: &[usize], shape: &[usize]) -> FerrayResult<Vec<Vec<usize>>> {
560    let total: usize = shape.iter().product();
561    let ndim = shape.len();
562    let n = flat_indices.len();
563
564    let mut result: Vec<Vec<usize>> = vec![Vec::with_capacity(n); ndim];
565
566    for &flat_idx in flat_indices {
567        if flat_idx >= total {
568            return Err(FerrayError::index_out_of_bounds(
569                flat_idx as isize,
570                0,
571                total,
572            ));
573        }
574        let mut rem = flat_idx;
575        for (d, &dim_size) in shape.iter().enumerate().rev() {
576            result[d].push(rem % dim_size);
577            rem /= dim_size;
578        }
579    }
580
581    Ok(result)
582}
583
584// ===========================================================================
585// flatnonzero
586// ===========================================================================
587
588/// Return the indices of non-zero elements in the flattened array.
589///
590/// Equivalent to `np.flatnonzero(a)`. An element is "non-zero" if it
591/// is not equal to the type's zero value.
592pub fn flatnonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<usize> {
593    let zero = T::zero();
594    a.inner
595        .iter()
596        .enumerate()
597        .filter_map(|(i, val)| if *val != zero { Some(i) } else { None })
598        .collect()
599}
600
601// ===========================================================================
602// nonzero / argwhere (#373)
603// ===========================================================================
604
605/// Return the indices of non-zero elements, one index vector per axis.
606///
607/// Equivalent to `np.nonzero(a)`. For an N-dimensional array with K
608/// non-zero elements, the returned `Vec<Vec<usize>>` has length N, and
609/// each inner vector has length K. `result[d][i]` is the coordinate
610/// along axis `d` of the `i`-th non-zero element (in row-major order).
611///
612/// For a 1-D array this is equivalent to wrapping `flatnonzero` in a
613/// single-element outer vector. For 2-D, the two inner vectors give
614/// (row_indices, col_indices) — the typical pair used to reconstruct
615/// sparse matrices or index back into the original array.
616///
617/// Returns `usize` coordinates because `usize` is not an `Element` type;
618/// callers who need a `Array<i64, _>` can cast via
619/// [`argwhere`] or by wrapping each inner vector in an `Array<i64, Ix1>`.
620pub fn nonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<Vec<usize>> {
621    let zero = T::zero();
622    let ndim = a.ndim();
623    let mut result: Vec<Vec<usize>> = vec![Vec::new(); ndim];
624    for (idx, val) in a.indexed_iter() {
625        if *val != zero {
626            for (d, &c) in idx.iter().enumerate() {
627                result[d].push(c);
628            }
629        }
630    }
631    result
632}
633
634/// Return the coordinates of non-zero elements as a 2-D `(N, ndim)` array.
635///
636/// Equivalent to `np.argwhere(a)`. Each row of the result is the
637/// multi-index of one non-zero element, in row-major order. The output
638/// dtype is `i64` to match NumPy's default signed integer index type
639/// (and because `usize` is not an `Element`). For a 0-D (scalar) input
640/// the result has shape `(0, 0)` or `(1, 0)` depending on whether the
641/// scalar is zero.
642///
643/// # Errors
644/// Returns an error only if constructing the result `Array` fails, which
645/// should never happen for a well-formed input.
646pub fn argwhere<T: Element + PartialEq, D: Dimension>(
647    a: &Array<T, D>,
648) -> FerrayResult<Array<i64, Ix2>> {
649    let zero = T::zero();
650    let ndim = a.ndim();
651    let mut data: Vec<i64> = Vec::new();
652    let mut count: usize = 0;
653    for (idx, val) in a.indexed_iter() {
654        if *val != zero {
655            for &c in &idx {
656                data.push(c as i64);
657            }
658            count += 1;
659        }
660    }
661    Array::<i64, Ix2>::from_vec(Ix2::new([count, ndim]), data)
662}
663
664// ===========================================================================
665// ndindex / ndenumerate iterators
666// ===========================================================================
667
/// An iterator over all multi-dimensional indices for a given shape.
///
/// Equivalent to `np.ndindex(*shape)`. Yields indices in row-major order
/// (the last axis varies fastest). A shape containing a zero-length axis
/// yields nothing; an empty shape yields exactly one empty index.
///
/// The iterator knows its exact remaining length and keeps returning `None`
/// once exhausted.
#[derive(Debug, Clone)]
pub struct NdIndex {
    /// Extent of each axis.
    shape: Vec<usize>,
    /// The index that the next call to `next` will yield.
    current: Vec<usize>,
    /// Set once every index has been yielded (immediately for an empty grid).
    done: bool,
}

impl NdIndex {
    /// Start at the all-zero index; a zero-length axis means nothing to yield.
    fn new(shape: &[usize]) -> Self {
        Self {
            shape: shape.to_vec(),
            current: vec![0; shape.len()],
            done: shape.contains(&0),
        }
    }
}

impl Iterator for NdIndex {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        let result = self.current.clone();

        // Odometer-style increment from the rightmost axis (row-major /
        // C-order): bump a coordinate, and only move left while it wraps.
        let mut wrapped_all = true;
        for (coord, &extent) in self.current.iter_mut().zip(self.shape.iter()).rev() {
            *coord += 1;
            if *coord < extent {
                wrapped_all = false;
                break;
            }
            *coord = 0;
        }
        // Every axis wrapped (or there are no axes): `result` was the last index.
        if wrapped_all {
            self.done = true;
        }

        Some(result)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.done {
            return (0, Some(0));
        }
        let total: usize = self.shape.iter().product();
        // The row-major flat position of `current` equals the number of
        // indices already yielded.
        let mut yielded = 0usize;
        let mut stride = 1usize;
        for (&coord, &extent) in self.current.iter().zip(self.shape.iter()).rev() {
            yielded += coord * stride;
            stride *= extent;
        }
        let remaining = total - yielded;
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, so `len()` can be offered to callers.
impl ExactSizeIterator for NdIndex {}

// After `done` is set, `next` returns `None` forever.
impl std::iter::FusedIterator for NdIndex {}

/// Create an iterator over all multi-dimensional indices for a shape.
///
/// Equivalent to `np.ndindex(*shape)`.
pub fn ndindex(shape: &[usize]) -> NdIndex {
    NdIndex::new(shape)
}
742
743/// Create an iterator yielding `(index, &value)` pairs.
744///
745/// Equivalent to `np.ndenumerate(a)`.
746pub fn ndenumerate<'a, T: Element, D: Dimension>(
747    a: &'a Array<T, D>,
748) -> impl Iterator<Item = (Vec<usize>, &'a T)> + 'a {
749    let shape = a.shape().to_vec();
750    let ndim = shape.len();
751    a.inner.iter().enumerate().map(move |(flat_idx, val)| {
752        let mut idx = vec![0usize; ndim];
753        let mut rem = flat_idx;
754        for (d, s) in shape.iter().enumerate().rev() {
755            if *s > 0 {
756                idx[d] = rem % s;
757                rem /= s;
758            }
759        }
760        (idx, val)
761    })
762}
763
764// ---------------------------------------------------------------------------
765// where_select — ternary selection (np.where equivalent)
766// ---------------------------------------------------------------------------
767
768/// Select elements from `x` or `y` depending on `condition`.
769///
770/// For each element position, returns `x[i]` where `condition[i]` is `true`,
771/// and `y[i]` where `condition[i]` is `false`.
772///
773/// All three arrays must have the same shape.
774///
775/// Equivalent to `numpy.where(condition, x, y)`.
776///
777/// # Errors
778/// Returns `FerrayError::ShapeMismatch` if shapes differ.
779pub fn where_select<T: Element + Copy, D: Dimension>(
780    condition: &Array<bool, D>,
781    x: &Array<T, D>,
782    y: &Array<T, D>,
783) -> FerrayResult<Array<T, D>> {
784    if condition.shape() != x.shape() || condition.shape() != y.shape() {
785        return Err(FerrayError::shape_mismatch(format!(
786            "where_select: condition shape {:?}, x shape {:?}, y shape {:?} must all match",
787            condition.shape(),
788            x.shape(),
789            y.shape()
790        )));
791    }
792    let data: Vec<T> = condition
793        .iter()
794        .zip(x.iter().zip(y.iter()))
795        .map(|(&c, (&xi, &yi))| if c { xi } else { yi })
796        .collect();
797    Array::from_vec(x.dim().clone(), data)
798}
799
800#[cfg(test)]
801mod tests {
802    use super::*;
803    use crate::dimension::{Ix1, Ix2};
804
805    // -----------------------------------------------------------------------
806    // take
807    // -----------------------------------------------------------------------
808
809    #[test]
810    fn take_1d() {
811        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
812        let taken = take(&arr, &[0, 2, 4], Axis(0)).unwrap();
813        assert_eq!(taken.shape(), &[3]);
814        let data: Vec<i32> = taken.iter().copied().collect();
815        assert_eq!(data, vec![10, 30, 50]);
816    }
817
818    #[test]
819    fn take_2d_axis1() {
820        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
821        let taken = take(&arr, &[0, 2], Axis(1)).unwrap();
822        assert_eq!(taken.shape(), &[3, 2]);
823        let data: Vec<i32> = taken.iter().copied().collect();
824        assert_eq!(data, vec![0, 2, 4, 6, 8, 10]);
825    }
826
827    #[test]
828    fn take_negative_indices() {
829        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
830        let taken = take(&arr, &[-1, -3], Axis(0)).unwrap();
831        let data: Vec<i32> = taken.iter().copied().collect();
832        assert_eq!(data, vec![40, 20]);
833    }
834
835    // -----------------------------------------------------------------------
836    // take_along_axis
837    // -----------------------------------------------------------------------
838
839    #[test]
840    fn take_along_axis_basic() {
841        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
842        let taken = take_along_axis(&arr, &[1, 3], Axis(1)).unwrap();
843        assert_eq!(taken.shape(), &[3, 2]);
844    }
845
846    // -----------------------------------------------------------------------
847    // put
848    // -----------------------------------------------------------------------
849
850    #[test]
851    fn put_flat() {
852        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 0, 0, 0, 0]).unwrap();
853        arr.put(&[1, 3], &[99, 88]).unwrap();
854        assert_eq!(arr.as_slice().unwrap(), &[0, 99, 0, 88, 0]);
855    }
856
857    #[test]
858    fn put_cycling_values() {
859        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0; 5]).unwrap();
860        arr.put(&[0, 1, 2, 3, 4], &[10, 20]).unwrap();
861        assert_eq!(arr.as_slice().unwrap(), &[10, 20, 10, 20, 10]);
862    }
863
864    #[test]
865    fn put_out_of_bounds() {
866        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
867        assert!(arr.put(&[5], &[1]).is_err());
868    }
869
870    // -----------------------------------------------------------------------
871    // fill_diagonal
872    // -----------------------------------------------------------------------
873
874    #[test]
875    fn fill_diagonal_2d() {
876        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 3]), vec![0; 9]).unwrap();
877        arr.fill_diagonal(1);
878        let data: Vec<i32> = arr.iter().copied().collect();
879        assert_eq!(data, vec![1, 0, 0, 0, 1, 0, 0, 0, 1]);
880    }
881
882    #[test]
883    fn fill_diagonal_rectangular() {
884        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 4]), vec![0; 8]).unwrap();
885        arr.fill_diagonal(5);
886        let data: Vec<i32> = arr.iter().copied().collect();
887        assert_eq!(data, vec![5, 0, 0, 0, 0, 5, 0, 0]);
888    }
889
890    // -----------------------------------------------------------------------
891    // choose
892    // -----------------------------------------------------------------------
893
894    #[test]
895    fn choose_basic() {
896        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([4]), vec![0, 1, 0, 1]).unwrap();
897        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
898        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![100, 200, 300, 400]).unwrap();
899        let result = choose(&idx, &[c0, c1]).unwrap();
900        let data: Vec<i32> = result.iter().copied().collect();
901        assert_eq!(data, vec![10, 200, 30, 400]);
902    }
903
904    #[test]
905    fn choose_out_of_bounds() {
906        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([2]), vec![0, 2]).unwrap();
907        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 2]).unwrap();
908        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![3, 4]).unwrap();
909        assert!(choose(&idx, &[c0, c1]).is_err());
910    }
911
912    // -----------------------------------------------------------------------
913    // compress
914    // -----------------------------------------------------------------------
915
916    #[test]
917    fn compress_1d() {
918        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
919        let result = compress(&[true, false, true, false, true], &arr, Axis(0)).unwrap();
920        let data: Vec<i32> = result.iter().copied().collect();
921        assert_eq!(data, vec![10, 30, 50]);
922    }
923
924    #[test]
925    fn compress_2d_axis0() {
926        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
927        let result = compress(&[true, false, true], &arr, Axis(0)).unwrap();
928        assert_eq!(result.shape(), &[2, 4]);
929        let data: Vec<i32> = result.iter().copied().collect();
930        assert_eq!(data, vec![0, 1, 2, 3, 8, 9, 10, 11]);
931    }
932
933    // -----------------------------------------------------------------------
934    // select
935    // -----------------------------------------------------------------------
936
937    #[test]
938    fn select_basic() {
939        let c1 =
940            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, false]).unwrap();
941        let c2 =
942            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, false]).unwrap();
943        let ch1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 1, 1, 1]).unwrap();
944        let ch2 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![2, 2, 2, 2]).unwrap();
945        let result = select(&[c1, c2], &[ch1, ch2], 0).unwrap();
946        let data: Vec<i32> = result.iter().copied().collect();
947        assert_eq!(data, vec![1, 2, 0, 0]);
948    }
949
950    // -----------------------------------------------------------------------
951    // indices
952    // -----------------------------------------------------------------------
953
954    #[test]
955    fn indices_2d() {
956        let idx = indices(&[2, 3]).unwrap();
957        assert_eq!(idx.len(), 2);
958        assert_eq!(idx[0].shape(), &[2, 3]);
959        assert_eq!(idx[1].shape(), &[2, 3]);
960        let rows: Vec<u64> = idx[0].iter().copied().collect();
961        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1]);
962        let cols: Vec<u64> = idx[1].iter().copied().collect();
963        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2]);
964    }
965
966    // -----------------------------------------------------------------------
967    // ix_
968    // -----------------------------------------------------------------------
969
970    #[test]
971    fn ix_basic() {
972        let result = ix_(&[&[0, 1], &[2, 3, 4]]).unwrap();
973        assert_eq!(result.len(), 2);
974        assert_eq!(result[0].shape(), &[2, 1]);
975        assert_eq!(result[1].shape(), &[1, 3]);
976    }
977
978    // -----------------------------------------------------------------------
979    // diag_indices
980    // -----------------------------------------------------------------------
981
982    #[test]
983    fn diag_indices_basic() {
984        let idx = diag_indices(3, 2);
985        assert_eq!(idx.len(), 2);
986        assert_eq!(idx[0], vec![0, 1, 2]);
987        assert_eq!(idx[1], vec![0, 1, 2]);
988    }
989
990    #[test]
991    fn diag_indices_from_square() {
992        let arr = Array::<f64, Ix2>::zeros(Ix2::new([4, 4])).unwrap();
993        let idx = diag_indices_from(&arr).unwrap();
994        assert_eq!(idx.len(), 2);
995        assert_eq!(idx[0].len(), 4);
996    }
997
998    #[test]
999    fn diag_indices_from_not_square() {
1000        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
1001        assert!(diag_indices_from(&arr).is_err());
1002    }
1003
1004    // -----------------------------------------------------------------------
1005    // tril_indices / triu_indices
1006    // -----------------------------------------------------------------------
1007
1008    #[test]
1009    fn tril_indices_basic() {
1010        let (rows, cols) = tril_indices(3, 0, None);
1011        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
1012        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
1013    }
1014
1015    #[test]
1016    fn triu_indices_basic() {
1017        let (rows, cols) = triu_indices(3, 0, None);
1018        assert_eq!(rows, vec![0, 0, 0, 1, 1, 2]);
1019        assert_eq!(cols, vec![0, 1, 2, 1, 2, 2]);
1020    }
1021
1022    #[test]
1023    fn tril_indices_with_k() {
1024        let (rows, cols) = tril_indices(3, 1, None);
1025        assert_eq!(rows, vec![0, 0, 1, 1, 1, 2, 2, 2]);
1026        assert_eq!(cols, vec![0, 1, 0, 1, 2, 0, 1, 2]);
1027    }
1028
1029    #[test]
1030    fn triu_indices_with_negative_k() {
1031        let (rows, cols) = triu_indices(3, -1, None);
1032        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1, 2, 2]);
1033        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2, 1, 2]);
1034    }
1035
1036    #[test]
1037    fn tril_indices_from_test() {
1038        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
1039        let (rows, _cols) = tril_indices_from(&arr, 0).unwrap();
1040        assert_eq!(rows.len(), 6);
1041    }
1042
1043    #[test]
1044    fn triu_indices_from_test() {
1045        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
1046        let (rows, _cols) = triu_indices_from(&arr, 0).unwrap();
1047        assert_eq!(rows.len(), 6);
1048    }
1049
1050    #[test]
1051    fn tril_indices_rectangular() {
1052        let (rows, cols) = tril_indices(3, 0, Some(4));
1053        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
1054        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
1055    }
1056
1057    // -----------------------------------------------------------------------
1058    // ravel_multi_index / unravel_index
1059    // -----------------------------------------------------------------------
1060
1061    #[test]
1062    fn ravel_multi_index_basic() {
1063        let flat = ravel_multi_index(&[&[0, 1, 2], &[1, 2, 0]], &[3, 4]).unwrap();
1064        assert_eq!(flat, vec![1, 6, 8]);
1065    }
1066
1067    #[test]
1068    fn ravel_multi_index_3d() {
1069        let flat = ravel_multi_index(&[&[0], &[1], &[2]], &[2, 3, 4]).unwrap();
1070        assert_eq!(flat, vec![6]);
1071    }
1072
1073    #[test]
1074    fn ravel_multi_index_out_of_bounds() {
1075        assert!(ravel_multi_index(&[&[3]], &[3]).is_err());
1076    }
1077
1078    #[test]
1079    fn unravel_index_basic() {
1080        let coords = unravel_index(&[1, 6, 8], &[3, 4]).unwrap();
1081        assert_eq!(coords[0], vec![0, 1, 2]);
1082        assert_eq!(coords[1], vec![1, 2, 0]);
1083    }
1084
1085    #[test]
1086    fn unravel_index_out_of_bounds() {
1087        assert!(unravel_index(&[12], &[3, 4]).is_err());
1088    }
1089
1090    #[test]
1091    fn ravel_unravel_roundtrip() {
1092        let dims = &[3, 4, 5];
1093        let a: &[usize] = &[1, 2];
1094        let b: &[usize] = &[2, 3];
1095        let c: &[usize] = &[3, 4];
1096        let multi: &[&[usize]] = &[a, b, c];
1097        let flat = ravel_multi_index(multi, dims).unwrap();
1098        let coords = unravel_index(&flat, dims).unwrap();
1099        assert_eq!(coords[0], vec![1, 2]);
1100        assert_eq!(coords[1], vec![2, 3]);
1101        assert_eq!(coords[2], vec![3, 4]);
1102    }
1103
1104    // -----------------------------------------------------------------------
1105    // flatnonzero
1106    // -----------------------------------------------------------------------
1107
1108    #[test]
1109    fn flatnonzero_basic() {
1110        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
1111        let nz = flatnonzero(&arr);
1112        assert_eq!(nz, vec![1, 3]);
1113    }
1114
1115    #[test]
1116    fn flatnonzero_2d() {
1117        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
1118        let nz = flatnonzero(&arr);
1119        assert_eq!(nz, vec![1, 3, 5]);
1120    }
1121
1122    #[test]
1123    fn flatnonzero_all_zero() {
1124        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
1125        let nz = flatnonzero(&arr);
1126        assert_eq!(nz.len(), 0);
1127    }
1128
1129    // -----------------------------------------------------------------------
1130    // nonzero / argwhere (#373)
1131    // -----------------------------------------------------------------------
1132
1133    #[test]
1134    fn nonzero_1d() {
1135        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
1136        let nz = nonzero(&arr);
1137        // One axis, one inner vec with the hit positions.
1138        assert_eq!(nz.len(), 1);
1139        assert_eq!(nz[0], vec![1, 3]);
1140    }
1141
1142    #[test]
1143    fn nonzero_2d_yields_row_and_col_indices() {
1144        // [[0, 1, 0],
1145        //  [2, 0, 3]]
1146        // Non-zero coordinates (row-major): (0,1), (1,0), (1,2).
1147        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
1148        let nz = nonzero(&arr);
1149        assert_eq!(nz.len(), 2);
1150        assert_eq!(nz[0], vec![0, 1, 1]);
1151        assert_eq!(nz[1], vec![1, 0, 2]);
1152    }
1153
1154    #[test]
1155    fn nonzero_all_zero_returns_empty_per_axis() {
1156        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
1157        let nz = nonzero(&arr);
1158        assert_eq!(nz.len(), 2);
1159        assert!(nz[0].is_empty());
1160        assert!(nz[1].is_empty());
1161    }
1162
1163    #[test]
1164    fn nonzero_f64_treats_negative_zero_as_zero() {
1165        // -0.0 == 0.0 for PartialEq, so -0.0 is "zero" per numpy semantics.
1166        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![-0.0, 1.5, 0.0, -2.5]).unwrap();
1167        let nz = nonzero(&arr);
1168        assert_eq!(nz[0], vec![1, 3]);
1169    }
1170
1171    #[test]
1172    fn argwhere_2d_has_one_row_per_nonzero() {
1173        // Same input as nonzero_2d_yields_row_and_col_indices.
1174        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
1175        let coords = argwhere(&arr).unwrap();
1176        assert_eq!(coords.shape(), &[3, 2]);
1177        assert_eq!(coords.as_slice().unwrap(), &[0, 1, 1, 0, 1, 2]);
1178    }
1179
1180    #[test]
1181    fn argwhere_1d_is_column_vector() {
1182        // A (K, 1) shape means K non-zero rows for a 1-D array.
1183        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 7, 0, 9, 3]).unwrap();
1184        let coords = argwhere(&arr).unwrap();
1185        assert_eq!(coords.shape(), &[3, 1]);
1186        assert_eq!(coords.as_slice().unwrap(), &[1, 3, 4]);
1187    }
1188
1189    #[test]
1190    fn argwhere_all_zero_returns_empty() {
1191        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
1192        let coords = argwhere(&arr).unwrap();
1193        assert_eq!(coords.shape(), &[0, 2]);
1194        assert_eq!(coords.size(), 0);
1195    }
1196
1197    // -----------------------------------------------------------------------
1198    // ndindex
1199    // -----------------------------------------------------------------------
1200
1201    #[test]
1202    fn ndindex_2d() {
1203        let indices: Vec<Vec<usize>> = ndindex(&[2, 3]).collect();
1204        assert_eq!(indices.len(), 6);
1205        assert_eq!(indices[0], vec![0, 0]);
1206        assert_eq!(indices[1], vec![0, 1]);
1207        assert_eq!(indices[2], vec![0, 2]);
1208        assert_eq!(indices[3], vec![1, 0]);
1209        assert_eq!(indices[4], vec![1, 1]);
1210        assert_eq!(indices[5], vec![1, 2]);
1211    }
1212
1213    #[test]
1214    fn ndindex_1d() {
1215        let indices: Vec<Vec<usize>> = ndindex(&[4]).collect();
1216        assert_eq!(indices.len(), 4);
1217        assert_eq!(indices[0], vec![0]);
1218        assert_eq!(indices[3], vec![3]);
1219    }
1220
1221    #[test]
1222    fn ndindex_empty() {
1223        let indices: Vec<Vec<usize>> = ndindex(&[0]).collect();
1224        assert_eq!(indices.len(), 0);
1225    }
1226
1227    #[test]
1228    fn ndindex_scalar() {
1229        let indices: Vec<Vec<usize>> = ndindex(&[]).collect();
1230        assert_eq!(indices.len(), 1);
1231        assert_eq!(indices[0], Vec::<usize>::new());
1232    }
1233
1234    // -----------------------------------------------------------------------
1235    // ndenumerate
1236    // -----------------------------------------------------------------------
1237
1238    #[test]
1239    fn ndenumerate_2d() {
1240        let arr =
1241            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
1242        let items: Vec<(Vec<usize>, &i32)> = ndenumerate(&arr).collect();
1243        assert_eq!(items.len(), 6);
1244        assert_eq!(items[0], (vec![0, 0], &10));
1245        assert_eq!(items[1], (vec![0, 1], &20));
1246        assert_eq!(items[5], (vec![1, 2], &60));
1247    }
1248
1249    // -----------------------------------------------------------------------
1250    // put_along_axis
1251    // -----------------------------------------------------------------------
1252
1253    #[test]
1254    fn put_along_axis_basic() {
1255        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), vec![0; 12]).unwrap();
1256        let values =
1257            Array::<i32, IxDyn>::from_vec(IxDyn::new(&[8]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
1258        arr.put_along_axis(&[0, 2], &values, Axis(0)).unwrap();
1259        let data: Vec<i32> = arr.iter().copied().collect();
1260        assert_eq!(data, vec![1, 2, 3, 4, 0, 0, 0, 0, 5, 6, 7, 8]);
1261    }
1262
1263    // -----------------------------------------------------------------------
1264    // where_
1265    // -----------------------------------------------------------------------
1266
1267    #[test]
1268    fn where_basic() {
1269        let cond =
1270            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
1271        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1272        let y = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
1273        let result = where_select(&cond, &x, &y).unwrap();
1274        assert_eq!(result.as_slice().unwrap(), &[1.0, 20.0, 3.0, 40.0]);
1275    }
1276
1277    #[test]
1278    fn where_all_true() {
1279        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
1280        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
1281        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
1282        let result = where_select(&cond, &x, &y).unwrap();
1283        assert_eq!(result.as_slice().unwrap(), &[1, 2, 3]);
1284    }
1285
1286    #[test]
1287    fn where_all_false() {
1288        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
1289        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
1290        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
1291        let result = where_select(&cond, &x, &y).unwrap();
1292        assert_eq!(result.as_slice().unwrap(), &[10, 20, 30]);
1293    }
1294
1295    #[test]
1296    fn where_shape_mismatch() {
1297        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true; 3]).unwrap();
1298        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0; 4]).unwrap();
1299        let y = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0; 3]).unwrap();
1300        assert!(where_select(&cond, &x, &y).is_err());
1301    }
1302
1303    #[test]
1304    fn where_2d() {
1305        let cond =
1306            Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![true, false, false, true]).unwrap();
1307        let x = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
1308        let y = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![10, 20, 30, 40]).unwrap();
1309        let result = where_select(&cond, &x, &y).unwrap();
1310        let data: Vec<i32> = result.iter().copied().collect();
1311        assert_eq!(data, vec![1, 20, 30, 4]);
1312    }
1313}