Skip to main content

ferray_core/indexing/
extended.rs

1// ferray-core: Extended indexing functions (REQ-15a)
2//
// take, take_along_axis, put, put_along_axis, choose, compress, select,
// indices, ix_, diag_indices, diag_indices_from, tril_indices, triu_indices,
// tril_indices_from, triu_indices_from, ravel_multi_index, unravel_index,
// flatnonzero, nonzero, argwhere, fill_diagonal, ndindex (iterator),
// ndenumerate (iterator), where_select, place, putmask, extract, mask_indices
//
// Most index-returning functions return Vec<usize>, Vec<Vec<usize>> or
// (Vec<usize>, Vec<usize>) because usize is not an Element type; callers can
// wrap these into arrays of u64 or i64 if needed. The exceptions are `indices`
// and `ix_`, which return u64 arrays, and `argwhere`, which returns an i64 array.
11
12use super::normalize_index;
13use crate::array::owned::Array;
14use crate::dimension::{Axis, Dimension, Ix2, IxDyn};
15use crate::dtype::Element;
16use crate::error::{FerrayError, FerrayResult};
17
18// ===========================================================================
19// take / take_along_axis
20// ===========================================================================
21
22/// Take elements from an array along an axis.
23///
24/// Equivalent to `np.take(a, indices, axis)`. Returns a copy.
25///
26/// # Errors
27/// - `AxisOutOfBounds` if `axis >= ndim`
28/// - `IndexOutOfBounds` if any index is out of range
29pub fn take<T: Element, D: Dimension>(
30    a: &Array<T, D>,
31    indices: &[isize],
32    axis: Axis,
33) -> FerrayResult<Array<T, IxDyn>> {
34    a.index_select(axis, indices)
35}
36
37/// Take values from an array along an axis using an index slice.
38///
39/// Similar to `np.take_along_axis`. The `indices` slice contains
40/// indices into `a` along the specified axis. The result replaces
41/// the `axis` dimension with the indices dimension.
42///
43/// # Errors
44/// - `AxisOutOfBounds` if `axis >= ndim`
45/// - `IndexOutOfBounds` if any index is out of range
46pub fn take_along_axis<T: Element, D: Dimension>(
47    a: &Array<T, D>,
48    indices: &[isize],
49    axis: Axis,
50) -> FerrayResult<Array<T, IxDyn>> {
51    a.index_select(axis, indices)
52}
53
54// ===========================================================================
55// put / put_along_axis
56// ===========================================================================
57
58impl<T: Element, D: Dimension> Array<T, D> {
59    /// Put values into the flattened array at the given indices.
60    ///
61    /// Equivalent to `np.put(a, ind, v)`. Modifies the array in-place.
62    /// Indices refer to the flattened (row-major) array.
63    /// Values are cycled if fewer than indices.
64    ///
65    /// # Errors
66    /// - `IndexOutOfBounds` if any index is out of range
67    /// - `InvalidValue` if values is empty
68    pub fn put(&mut self, indices: &[isize], values: &[T]) -> FerrayResult<()> {
69        if values.is_empty() {
70            return Err(FerrayError::invalid_value("values must not be empty"));
71        }
72        let size = self.size();
73        let normalized: Vec<usize> = indices
74            .iter()
75            .map(|&idx| normalize_index(idx, size, 0))
76            .collect::<FerrayResult<Vec<_>>>()?;
77
78        let mut flat: Vec<&mut T> = self.inner.iter_mut().collect();
79
80        for (i, &idx) in normalized.iter().enumerate() {
81            let val_idx = i % values.len();
82            *flat[idx] = values[val_idx].clone();
83        }
84        Ok(())
85    }
86
87    /// Put values along an axis at specified indices.
88    ///
89    /// For each index position along `axis`, assigns the values from the
90    /// corresponding sub-array of `values`.
91    ///
92    /// # Errors
93    /// - `AxisOutOfBounds` if `axis >= ndim`
94    /// - `IndexOutOfBounds` if any index is out of range
95    pub fn put_along_axis(
96        &mut self,
97        indices: &[isize],
98        values: &Array<T, IxDyn>,
99        axis: Axis,
100    ) -> FerrayResult<()>
101    where
102        D::NdarrayDim: ndarray::RemoveAxis,
103    {
104        let ndim = self.ndim();
105        let ax = axis.index();
106        if ax >= ndim {
107            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
108        }
109        let axis_size = self.shape()[ax];
110
111        let normalized: Vec<usize> = indices
112            .iter()
113            .map(|&idx| normalize_index(idx, axis_size, ax))
114            .collect::<FerrayResult<Vec<_>>>()?;
115
116        let nd_axis = ndarray::Axis(ax);
117        let mut val_iter = values.inner.iter();
118
119        for &idx in &normalized {
120            let mut sub = self.inner.index_axis_mut(nd_axis, idx);
121            for elem in &mut sub {
122                if let Some(v) = val_iter.next() {
123                    *elem = v.clone();
124                }
125            }
126        }
127        Ok(())
128    }
129
130    /// Fill the main diagonal of a 2-D (or N-D) array with a value.
131    ///
132    /// For N-D arrays, the diagonal consists of indices where all
133    /// index values are equal: `a[i, i, ..., i]`.
134    ///
135    /// Equivalent to `np.fill_diagonal(a, val)`.
136    pub fn fill_diagonal(&mut self, val: T) {
137        let shape = self.shape().to_vec();
138        if shape.is_empty() {
139            return;
140        }
141        let min_dim = *shape.iter().min().unwrap_or(&0);
142        let ndim = shape.len();
143
144        for i in 0..min_dim {
145            let idx: Vec<usize> = vec![i; ndim];
146            let nd_idx = ndarray::IxDyn(&idx);
147            let mut dyn_view = self.inner.view_mut().into_dyn();
148            dyn_view[nd_idx] = val.clone();
149        }
150    }
151}
152
153// ===========================================================================
154// choose
155// ===========================================================================
156
157/// Construct an array from an index array and a list of arrays to choose from.
158///
159/// Equivalent to `np.choose(a, choices)`. For each element in `index_arr`,
160/// the value selects which choice array to pick from at that position.
161/// Index values are given as `u64` to avoid the `usize` Element issue.
162///
163/// # Errors
164/// - `IndexOutOfBounds` if any index in `index_arr` is >= `choices.len()`
165/// - `ShapeMismatch` if choice arrays have different shapes from `index_arr`
166pub fn choose<T: Element, D: Dimension>(
167    index_arr: &Array<u64, D>,
168    choices: &[Array<T, D>],
169) -> FerrayResult<Array<T, IxDyn>> {
170    if choices.is_empty() {
171        return Err(FerrayError::invalid_value("choices must not be empty"));
172    }
173
174    let shape = index_arr.shape();
175    for (i, c) in choices.iter().enumerate() {
176        if c.shape() != shape {
177            return Err(FerrayError::shape_mismatch(format!(
178                "choice[{}] shape {:?} does not match index array shape {:?}",
179                i,
180                c.shape(),
181                shape
182            )));
183        }
184    }
185
186    let n_choices = choices.len();
187    let choice_iters: Vec<Vec<T>> = choices
188        .iter()
189        .map(|c| c.inner.iter().cloned().collect())
190        .collect();
191
192    let mut data = Vec::with_capacity(index_arr.size());
193    for (pos, idx_val) in index_arr.inner.iter().enumerate() {
194        let idx = *idx_val as usize;
195        if idx >= n_choices {
196            return Err(FerrayError::index_out_of_bounds(idx as isize, 0, n_choices));
197        }
198        data.push(choice_iters[idx][pos].clone());
199    }
200
201    let dyn_shape = IxDyn::new(shape);
202    Array::from_vec(dyn_shape, data)
203}
204
205// ===========================================================================
206// compress
207// ===========================================================================
208
209/// Select slices of an array along an axis where `condition` is true.
210///
211/// Equivalent to `np.compress(condition, a, axis)`.
212///
213/// # Errors
214/// - `AxisOutOfBounds` if `axis >= ndim`
215/// - `ShapeMismatch` if `condition.len()` exceeds axis size
216pub fn compress<T: Element, D: Dimension>(
217    condition: &[bool],
218    a: &Array<T, D>,
219    axis: Axis,
220) -> FerrayResult<Array<T, IxDyn>> {
221    let ndim = a.ndim();
222    let ax = axis.index();
223    if ax >= ndim {
224        return Err(FerrayError::axis_out_of_bounds(ax, ndim));
225    }
226    let axis_size = a.shape()[ax];
227    if condition.len() > axis_size {
228        return Err(FerrayError::shape_mismatch(format!(
229            "condition length {} exceeds axis size {}",
230            condition.len(),
231            axis_size
232        )));
233    }
234
235    let indices: Vec<isize> = condition
236        .iter()
237        .enumerate()
238        .filter_map(|(i, &c)| if c { Some(i as isize) } else { None })
239        .collect();
240
241    a.index_select(axis, &indices)
242}
243
244// ===========================================================================
245// select
246// ===========================================================================
247
248/// Return an array drawn from elements in choicelist, depending on conditions.
249///
250/// Equivalent to `np.select(condlist, choicelist, default)`.
251/// The first condition that is true determines which choice is used.
252///
253/// # Errors
254/// - `InvalidValue` if condlist and choicelist have different lengths
255/// - `ShapeMismatch` if shapes are incompatible
256pub fn select<T: Element, D: Dimension>(
257    condlist: &[Array<bool, D>],
258    choicelist: &[Array<T, D>],
259    default: T,
260) -> FerrayResult<Array<T, IxDyn>> {
261    if condlist.len() != choicelist.len() {
262        return Err(FerrayError::invalid_value(format!(
263            "condlist length {} != choicelist length {}",
264            condlist.len(),
265            choicelist.len()
266        )));
267    }
268    if condlist.is_empty() {
269        return Err(FerrayError::invalid_value(
270            "condlist and choicelist must not be empty",
271        ));
272    }
273
274    let shape = condlist[0].shape();
275    for (i, (c, ch)) in condlist.iter().zip(choicelist.iter()).enumerate() {
276        if c.shape() != shape || ch.shape() != shape {
277            return Err(FerrayError::shape_mismatch(format!(
278                "condlist[{i}]/choicelist[{i}] shape mismatch with reference shape {shape:?}"
279            )));
280        }
281    }
282
283    let size = condlist[0].size();
284    let mut data = vec![default; size];
285
286    // Process in reverse order so first matching condition wins
287    for (cond, choice) in condlist.iter().zip(choicelist.iter()).rev() {
288        for (i, (&c, v)) in cond.inner.iter().zip(choice.inner.iter()).enumerate() {
289            if c {
290                data[i] = v.clone();
291            }
292        }
293    }
294
295    let dyn_shape = IxDyn::new(shape);
296    Array::from_vec(dyn_shape, data)
297}
298
299// ===========================================================================
300// indices
301// ===========================================================================
302
303/// Return arrays representing the indices of a grid.
304///
305/// Equivalent to `np.indices(dimensions)`. Returns one `u64` array per
306/// dimension, each with shape `dimensions`.
307///
308/// For example, `indices(&[2, 3])` returns two arrays of shape `[2, 3]`:
309/// the first contains row indices, the second column indices.
310pub fn indices(dimensions: &[usize]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
311    let ndim = dimensions.len();
312    let total: usize = dimensions.iter().product();
313
314    let mut result = Vec::with_capacity(ndim);
315
316    for ax in 0..ndim {
317        let mut data = Vec::with_capacity(total);
318        for flat_idx in 0..total {
319            let mut rem = flat_idx;
320            let mut idx_for_ax = 0;
321            for (d, &dim_size) in dimensions.iter().enumerate().rev() {
322                let coord = rem % dim_size;
323                rem /= dim_size;
324                if d == ax {
325                    idx_for_ax = coord;
326                }
327            }
328            data.push(idx_for_ax as u64);
329        }
330        let dim = IxDyn::new(dimensions);
331        result.push(Array::from_vec(dim, data)?);
332    }
333
334    Ok(result)
335}
336
337// ===========================================================================
338// ix_
339// ===========================================================================
340
341/// Construct an open mesh from multiple sequences.
342///
343/// Equivalent to `np.ix_(*args)`. Returns a list of arrays, each with
344/// shape `(1, 1, ..., N, ..., 1)` where `N` is the length of that sequence
345/// and it appears in the position corresponding to its argument index.
346///
347/// This is useful for constructing index arrays for cross-indexing.
348pub fn ix_(sequences: &[&[u64]]) -> FerrayResult<Vec<Array<u64, IxDyn>>> {
349    let ndim = sequences.len();
350    let mut result = Vec::with_capacity(ndim);
351
352    for (i, seq) in sequences.iter().enumerate() {
353        let mut shape = vec![1usize; ndim];
354        shape[i] = seq.len();
355
356        let data = seq.to_vec();
357        let dim = IxDyn::new(&shape);
358        result.push(Array::from_vec(dim, data)?);
359    }
360
361    Ok(result)
362}
363
364// ===========================================================================
365// diag_indices / diag_indices_from
366// ===========================================================================
367
/// Indices that address the main diagonal of an `n`-sided hypercube.
///
/// Mirrors `np.diag_indices(n, ndim)`: the result holds `ndim` identical
/// vectors, each equal to `[0, 1, ..., n-1]`.
#[must_use]
pub fn diag_indices(n: usize, ndim: usize) -> Vec<Vec<usize>> {
    (0..ndim).map(|_| (0..n).collect()).collect()
}
377
378/// Return the indices to access the main diagonal of the given array.
379///
380/// The array must be at least 2-D and square (all dimensions equal).
381///
382/// # Errors
383/// - `InvalidValue` if the array has fewer than 2 dimensions
384/// - `ShapeMismatch` if dimensions are not all equal
385pub fn diag_indices_from<T: Element, D: Dimension>(
386    a: &Array<T, D>,
387) -> FerrayResult<Vec<Vec<usize>>> {
388    let ndim = a.ndim();
389    if ndim < 2 {
390        return Err(FerrayError::invalid_value(
391            "diag_indices_from requires at least 2 dimensions",
392        ));
393    }
394    let shape = a.shape();
395    let n = shape[0];
396    for &s in &shape[1..] {
397        if s != n {
398            return Err(FerrayError::shape_mismatch(format!(
399                "all dimensions must be equal for diag_indices_from, got {shape:?}"
400            )));
401        }
402    }
403    Ok(diag_indices(n, ndim))
404}
405
406// ===========================================================================
407// tril_indices / triu_indices / tril_indices_from / triu_indices_from
408// ===========================================================================
409
/// Row/column indices of the lower triangle of an `(n, m)` array.
///
/// Mirrors `np.tril_indices(n, k, m)`. An entry `(i, j)` is included when
/// `j <= i + k`, so `k = 0` is the main diagonal, positive `k` extends the
/// triangle above it and negative `k` shrinks it below. `m` defaults to `n`.
/// Pairs are produced in row-major order.
#[must_use]
pub fn tril_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let cols = m.unwrap_or(n);
    (0..n)
        .flat_map(|row| (0..cols).map(move |col| (row, col)))
        .filter(|&(row, col)| (col as isize) <= (row as isize) + k)
        .unzip()
}
432
/// Row/column indices of the upper triangle of an `(n, m)` array.
///
/// Mirrors `np.triu_indices(n, k, m)`. An entry `(i, j)` is included when
/// `j >= i + k`; `m` defaults to `n`. Pairs are produced in row-major order.
#[must_use]
pub fn triu_indices(n: usize, k: isize, m: Option<usize>) -> (Vec<usize>, Vec<usize>) {
    let cols = m.unwrap_or(n);
    (0..n)
        .flat_map(|row| (0..cols).map(move |col| (row, col)))
        .filter(|&(row, col)| (col as isize) >= (row as isize) + k)
        .unzip()
}
453
454/// Return the indices for the lower triangle of the given 2-D array.
455///
456/// # Errors
457/// - `InvalidValue` if the array is not 2-D
458pub fn tril_indices_from<T: Element, D: Dimension>(
459    a: &Array<T, D>,
460    k: isize,
461) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
462    let shape = a.shape();
463    if shape.len() != 2 {
464        return Err(FerrayError::invalid_value(
465            "tril_indices_from requires a 2-D array",
466        ));
467    }
468    Ok(tril_indices(shape[0], k, Some(shape[1])))
469}
470
471/// Return the indices for the upper triangle of the given 2-D array.
472///
473/// # Errors
474/// - `InvalidValue` if the array is not 2-D
475pub fn triu_indices_from<T: Element, D: Dimension>(
476    a: &Array<T, D>,
477    k: isize,
478) -> FerrayResult<(Vec<usize>, Vec<usize>)> {
479    let shape = a.shape();
480    if shape.len() != 2 {
481        return Err(FerrayError::invalid_value(
482            "triu_indices_from requires a 2-D array",
483        ));
484    }
485    Ok(triu_indices(shape[0], k, Some(shape[1])))
486}
487
488// ===========================================================================
489// ravel_multi_index / unravel_index
490// ===========================================================================
491
492/// Convert a tuple of index arrays to a flat index array.
493///
494/// Equivalent to `np.ravel_multi_index(multi_index, dims)`.
495/// Uses row-major (C) ordering.
496///
497/// # Errors
498/// - `InvalidValue` if `multi_index` arrays have different lengths
499/// - `IndexOutOfBounds` if any index is out of range for its dimension
500#[allow(clippy::needless_range_loop)]
501pub fn ravel_multi_index(multi_index: &[&[usize]], dims: &[usize]) -> FerrayResult<Vec<usize>> {
502    if multi_index.len() != dims.len() {
503        return Err(FerrayError::invalid_value(format!(
504            "multi_index has {} components but dims has {} dimensions",
505            multi_index.len(),
506            dims.len()
507        )));
508    }
509    if multi_index.is_empty() {
510        return Ok(vec![]);
511    }
512
513    let n = multi_index[0].len();
514    for (i, idx_arr) in multi_index.iter().enumerate() {
515        if idx_arr.len() != n {
516            return Err(FerrayError::invalid_value(format!(
517                "multi_index[{}] has length {} but expected {}",
518                i,
519                idx_arr.len(),
520                n
521            )));
522        }
523    }
524
525    // Compute strides for C-order
526    let ndim = dims.len();
527    let mut strides = vec![1usize; ndim];
528    for i in (0..ndim - 1).rev() {
529        strides[i] = strides[i + 1] * dims[i + 1];
530    }
531
532    let mut flat = Vec::with_capacity(n);
533    #[allow(clippy::needless_range_loop)]
534    for pos in 0..n {
535        let mut linear = 0usize;
536        for (d, &dim_size) in dims.iter().enumerate() {
537            let coord = multi_index[d][pos];
538            if coord >= dim_size {
539                return Err(FerrayError::index_out_of_bounds(
540                    coord as isize,
541                    d,
542                    dim_size,
543                ));
544            }
545            linear += coord * strides[d];
546        }
547        flat.push(linear);
548    }
549
550    Ok(flat)
551}
552
553/// Convert flat indices to a tuple of coordinate arrays.
554///
555/// Equivalent to `np.unravel_index(indices, shape)`.
556/// Uses row-major (C) ordering.
557///
558/// # Errors
559/// - `IndexOutOfBounds` if any flat index >= product(shape)
560pub fn unravel_index(flat_indices: &[usize], shape: &[usize]) -> FerrayResult<Vec<Vec<usize>>> {
561    let total: usize = shape.iter().product();
562    let ndim = shape.len();
563    let n = flat_indices.len();
564
565    let mut result: Vec<Vec<usize>> = vec![Vec::with_capacity(n); ndim];
566
567    for &flat_idx in flat_indices {
568        if flat_idx >= total {
569            return Err(FerrayError::index_out_of_bounds(
570                flat_idx as isize,
571                0,
572                total,
573            ));
574        }
575        let mut rem = flat_idx;
576        for (d, &dim_size) in shape.iter().enumerate().rev() {
577            result[d].push(rem % dim_size);
578            rem /= dim_size;
579        }
580    }
581
582    Ok(result)
583}
584
585// ===========================================================================
586// flatnonzero
587// ===========================================================================
588
589/// Return the indices of non-zero elements in the flattened array.
590///
591/// Equivalent to `np.flatnonzero(a)`. An element is "non-zero" if it
592/// is not equal to the type's zero value.
593pub fn flatnonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<usize> {
594    let zero = T::zero();
595    a.inner
596        .iter()
597        .enumerate()
598        .filter_map(|(i, val)| if *val == zero { None } else { Some(i) })
599        .collect()
600}
601
602// ===========================================================================
603// nonzero / argwhere (#373)
604// ===========================================================================
605
606/// Return the indices of non-zero elements, one index vector per axis.
607///
608/// Equivalent to `np.nonzero(a)`. For an N-dimensional array with K
609/// non-zero elements, the returned `Vec<Vec<usize>>` has length N, and
610/// each inner vector has length K. `result[d][i]` is the coordinate
611/// along axis `d` of the `i`-th non-zero element (in row-major order).
612///
613/// For a 1-D array this is equivalent to wrapping `flatnonzero` in a
614/// single-element outer vector. For 2-D, the two inner vectors give
615/// (`row_indices`, `col_indices`) — the typical pair used to reconstruct
616/// sparse matrices or index back into the original array.
617///
618/// Returns `usize` coordinates because `usize` is not an `Element` type;
619/// callers who need a `Array<i64, _>` can cast via
620/// [`argwhere`] or by wrapping each inner vector in an `Array<i64, Ix1>`.
621pub fn nonzero<T: Element + PartialEq, D: Dimension>(a: &Array<T, D>) -> Vec<Vec<usize>> {
622    let zero = T::zero();
623    let ndim = a.ndim();
624    let mut result: Vec<Vec<usize>> = vec![Vec::new(); ndim];
625    for (idx, val) in a.indexed_iter() {
626        if *val != zero {
627            for (d, &c) in idx.iter().enumerate() {
628                result[d].push(c);
629            }
630        }
631    }
632    result
633}
634
635/// Return the coordinates of non-zero elements as a 2-D `(N, ndim)` array.
636///
637/// Equivalent to `np.argwhere(a)`. Each row of the result is the
638/// multi-index of one non-zero element, in row-major order. The output
639/// dtype is `i64` to match `NumPy`'s default signed integer index type
640/// (and because `usize` is not an `Element`). For a 0-D (scalar) input
641/// the result has shape `(0, 0)` or `(1, 0)` depending on whether the
642/// scalar is zero.
643///
644/// # Errors
645/// Returns an error only if constructing the result `Array` fails, which
646/// should never happen for a well-formed input.
647pub fn argwhere<T: Element + PartialEq, D: Dimension>(
648    a: &Array<T, D>,
649) -> FerrayResult<Array<i64, Ix2>> {
650    let zero = T::zero();
651    let ndim = a.ndim();
652    let mut data: Vec<i64> = Vec::new();
653    let mut count: usize = 0;
654    for (idx, val) in a.indexed_iter() {
655        if *val != zero {
656            for &c in &idx {
657                data.push(c as i64);
658            }
659            count += 1;
660        }
661    }
662    Array::<i64, Ix2>::from_vec(Ix2::new([count, ndim]), data)
663}
664
665// ===========================================================================
666// ndindex / ndenumerate iterators
667// ===========================================================================
668
/// An iterator over all multi-dimensional indices for a given shape.
///
/// Equivalent to `np.ndindex(*shape)`. Yields indices in row-major order
/// (the last axis varies fastest). A shape with any zero-length axis yields
/// nothing; an empty shape yields exactly one empty index.
///
/// The iterator reports its exact remaining length ([`ExactSizeIterator`])
/// and keeps returning `None` once exhausted ([`std::iter::FusedIterator`]).
#[derive(Debug, Clone)]
pub struct NdIndex {
    shape: Vec<usize>,
    /// The index that the next call to `next` will yield (valid while `!done`).
    current: Vec<usize>,
    /// True once every index has been yielded, or from the start for a
    /// zero-size shape.
    done: bool,
}

impl NdIndex {
    fn new(shape: &[usize]) -> Self {
        let done = shape.contains(&0);
        Self {
            shape: shape.to_vec(),
            current: vec![0; shape.len()],
            done,
        }
    }
}

impl Iterator for NdIndex {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        let result = self.current.clone();

        // Odometer increment: bump the rightmost axis (row-major / C-order),
        // carrying leftwards while an axis wraps around. If the carry falls
        // off the left end (or there are no axes), everything has been yielded.
        let mut carried = true;
        for (coord, &len) in self.current.iter_mut().zip(&self.shape).rev() {
            *coord += 1;
            if *coord < len {
                carried = false;
                break;
            }
            *coord = 0;
        }
        if carried {
            self.done = true;
        }

        Some(result)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.done {
            return (0, Some(0));
        }
        let total: usize = self.shape.iter().product();
        // `current` is the row-major position of the next index, which equals
        // the number of indices already yielded.
        let mut yielded = 0usize;
        let mut stride = 1usize;
        for (&coord, &len) in self.current.iter().zip(&self.shape).rev() {
            yielded += coord * stride;
            stride *= len;
        }
        let remaining = total - yielded;
        (remaining, Some(remaining))
    }
}

// `size_hint` is always exact.
impl ExactSizeIterator for NdIndex {}

// Once `done` is set it is never cleared, so `None` is returned forever.
impl std::iter::FusedIterator for NdIndex {}

/// Create an iterator over all multi-dimensional indices for a shape.
///
/// Equivalent to `np.ndindex(*shape)`.
#[must_use]
pub fn ndindex(shape: &[usize]) -> NdIndex {
    NdIndex::new(shape)
}
744
745/// Create an iterator yielding `(index, &value)` pairs.
746///
747/// Equivalent to `np.ndenumerate(a)`.
748pub fn ndenumerate<T: Element, D: Dimension>(
749    a: &Array<T, D>,
750) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
751    let shape = a.shape().to_vec();
752    let ndim = shape.len();
753    a.inner.iter().enumerate().map(move |(flat_idx, val)| {
754        let mut idx = vec![0usize; ndim];
755        let mut rem = flat_idx;
756        for (d, s) in shape.iter().enumerate().rev() {
757            if *s > 0 {
758                idx[d] = rem % s;
759                rem /= s;
760            }
761        }
762        (idx, val)
763    })
764}
765
766// ---------------------------------------------------------------------------
767// where_select — ternary selection (np.where equivalent)
768// ---------------------------------------------------------------------------
769
770/// Select elements from `x` or `y` depending on `condition`.
771///
772/// For each element position, returns `x[i]` where `condition[i]` is `true`,
773/// and `y[i]` where `condition[i]` is `false`.
774///
775/// All three arrays must have the same shape.
776///
777/// Equivalent to `numpy.where(condition, x, y)`.
778///
779/// # Errors
780/// Returns `FerrayError::ShapeMismatch` if shapes differ.
781pub fn where_select<T: Element + Copy, D: Dimension>(
782    condition: &Array<bool, D>,
783    x: &Array<T, D>,
784    y: &Array<T, D>,
785) -> FerrayResult<Array<T, D>> {
786    if condition.shape() != x.shape() || condition.shape() != y.shape() {
787        return Err(FerrayError::shape_mismatch(format!(
788            "where_select: condition shape {:?}, x shape {:?}, y shape {:?} must all match",
789            condition.shape(),
790            x.shape(),
791            y.shape()
792        )));
793    }
794    let data: Vec<T> = condition
795        .iter()
796        .zip(x.iter().zip(y.iter()))
797        .map(|(&c, (&xi, &yi))| if c { xi } else { yi })
798        .collect();
799    Array::from_vec(x.dim().clone(), data)
800}
801
802// ===========================================================================
803// place / putmask / extract / mask_indices
804// ===========================================================================
805
806/// Change elements of `a` based on a boolean mask, taking values from `vals`
807/// (cycling) where the mask is true.
808///
809/// Mutates `a` in place. The mask shape must equal `a.shape()`. `vals` is
810/// reused cyclically if it has fewer elements than there are mask hits.
811///
812/// Analogous to `numpy.place(arr, mask, vals)`.
813///
814/// # Errors
815/// Returns `FerrayError::ShapeMismatch` if mask shape != array shape.
816/// Returns `FerrayError::InvalidValue` if there are mask hits but `vals` is empty.
817pub fn place<T: Element + Copy, D: Dimension>(
818    a: &mut Array<T, D>,
819    mask: &Array<bool, D>,
820    vals: &[T],
821) -> FerrayResult<()> {
822    if a.shape() != mask.shape() {
823        return Err(FerrayError::shape_mismatch(format!(
824            "place: mask shape {:?} differs from array shape {:?}",
825            mask.shape(),
826            a.shape(),
827        )));
828    }
829    let hits: usize = mask.iter().filter(|&&m| m).count();
830    if hits > 0 && vals.is_empty() {
831        return Err(FerrayError::invalid_value(
832            "place: vals must be non-empty when mask has any true entries",
833        ));
834    }
835    let mut vi = 0usize;
836    for (slot, &m) in a.inner.iter_mut().zip(mask.iter()) {
837        if m {
838            *slot = vals[vi % vals.len()];
839            vi += 1;
840        }
841    }
842    Ok(())
843}
844
845/// Change elements of `a` based on a boolean mask, broadcasting/cycling values
846/// from a same-shape (or scalar) `values` slice where the mask is true.
847///
848/// Where `values.len() == 1`, that scalar is used. Otherwise `values` must
849/// have the same number of elements as `a` (mask hits use the corresponding
850/// element of `values`, like NumPy's `putmask`). Mutates `a` in place.
851///
852/// Analogous to `numpy.putmask(a, mask, values)`.
853///
854/// # Errors
855/// Returns `FerrayError::ShapeMismatch` if mask shape != array shape, or
856/// if `values.len()` is neither 1 nor `a.size()`.
857pub fn putmask<T: Element + Copy, D: Dimension>(
858    a: &mut Array<T, D>,
859    mask: &Array<bool, D>,
860    values: &[T],
861) -> FerrayResult<()> {
862    if a.shape() != mask.shape() {
863        return Err(FerrayError::shape_mismatch(format!(
864            "putmask: mask shape {:?} differs from array shape {:?}",
865            mask.shape(),
866            a.shape(),
867        )));
868    }
869    let n = a.size();
870    let scalar_mode = values.len() == 1;
871    if !scalar_mode && values.len() != n {
872        return Err(FerrayError::shape_mismatch(format!(
873            "putmask: values length {} must be 1 or equal to array size {}",
874            values.len(),
875            n,
876        )));
877    }
878    for (i, (slot, &m)) in a.inner.iter_mut().zip(mask.iter()).enumerate() {
879        if m {
880            *slot = if scalar_mode { values[0] } else { values[i] };
881        }
882    }
883    Ok(())
884}
885
886/// Return the elements of `a` where `condition` is true, as a 1-D array.
887///
888/// Analogous to `numpy.extract(condition, arr)`. Equivalent to
889/// `arr.flatten()[condition.flatten()]`.
890///
891/// # Errors
892/// Returns `FerrayError::ShapeMismatch` if `condition` does not have the
893/// same shape as `a`.
894pub fn extract<T: Element + Copy, D: Dimension>(
895    condition: &Array<bool, D>,
896    a: &Array<T, D>,
897) -> FerrayResult<Array<T, crate::dimension::Ix1>> {
898    if condition.shape() != a.shape() {
899        return Err(FerrayError::shape_mismatch(format!(
900            "extract: condition shape {:?} differs from array shape {:?}",
901            condition.shape(),
902            a.shape(),
903        )));
904    }
905    let data: Vec<T> = condition
906        .iter()
907        .zip(a.iter())
908        .filter_map(|(&c, &v)| if c { Some(v) } else { None })
909        .collect();
910    let n = data.len();
911    Array::from_vec(crate::dimension::Ix1::new([n]), data)
912}
913
/// Selects which region of a square matrix [`mask_indices`] reports.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MaskKind {
    /// Lower triangle, as in `np.tril`: entries on or below the `k`-th diagonal.
    Tril,
    /// Upper triangle, as in `np.triu`: entries on or above the `k`-th diagonal.
    Triu,
    /// Only the entries lying on the `k`-th diagonal itself.
    Diag,
}
924
925/// Return the flat indices into an `(n, n)` array where the chosen mask is true.
926///
927/// `k` shifts the boundary: for `Tril`, elements with `j <= i + k` are
928/// selected; for `Triu`, elements with `j >= i + k`; for `Diag`, the
929/// `k`-th diagonal.
930///
931/// Analogous to `numpy.mask_indices(n, mask_func, k)`.
932pub fn mask_indices(n: usize, kind: MaskKind, k: isize) -> Vec<usize> {
933    let mut idx = Vec::new();
934    for i in 0..n {
935        for j in 0..n {
936            let select = match kind {
937                MaskKind::Tril => (j as isize) <= (i as isize) + k,
938                MaskKind::Triu => (j as isize) >= (i as isize) + k,
939                MaskKind::Diag => (j as isize) == (i as isize) + k,
940            };
941            if select {
942                idx.push(i * n + j);
943            }
944        }
945    }
946    idx
947}
948
// Unit tests for the extended indexing functions. Inputs are small and
// expected values are hand-computed in row-major (C) order.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    // -----------------------------------------------------------------------
    // take
    // -----------------------------------------------------------------------

    #[test]
    fn take_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let taken = take(&arr, &[0, 2, 4], Axis(0)).unwrap();
        assert_eq!(taken.shape(), &[3]);
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![10, 30, 50]);
    }

    #[test]
    fn take_2d_axis1() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let taken = take(&arr, &[0, 2], Axis(1)).unwrap();
        assert_eq!(taken.shape(), &[3, 2]);
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![0, 2, 4, 6, 8, 10]);
    }

    #[test]
    fn take_negative_indices() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let taken = take(&arr, &[-1, -3], Axis(0)).unwrap();
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![40, 20]);
    }

    // -----------------------------------------------------------------------
    // take_along_axis
    // -----------------------------------------------------------------------

    #[test]
    fn take_along_axis_basic() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let taken = take_along_axis(&arr, &[1, 3], Axis(1)).unwrap();
        assert_eq!(taken.shape(), &[3, 2]);
        // Columns 1 and 3 of [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]].
        let data: Vec<i32> = taken.iter().copied().collect();
        assert_eq!(data, vec![1, 3, 5, 7, 9, 11]);
    }

    // -----------------------------------------------------------------------
    // put
    // -----------------------------------------------------------------------

    #[test]
    fn put_flat() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 0, 0, 0, 0]).unwrap();
        arr.put(&[1, 3], &[99, 88]).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[0, 99, 0, 88, 0]);
    }

    #[test]
    fn put_cycling_values() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0; 5]).unwrap();
        arr.put(&[0, 1, 2, 3, 4], &[10, 20]).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[10, 20, 10, 20, 10]);
    }

    #[test]
    fn put_out_of_bounds() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
        assert!(arr.put(&[5], &[1]).is_err());
    }

    // -----------------------------------------------------------------------
    // fill_diagonal
    // -----------------------------------------------------------------------

    #[test]
    fn fill_diagonal_2d() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 3]), vec![0; 9]).unwrap();
        arr.fill_diagonal(1);
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, 0, 0, 0, 1, 0, 0, 0, 1]);
    }

    #[test]
    fn fill_diagonal_rectangular() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 4]), vec![0; 8]).unwrap();
        arr.fill_diagonal(5);
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![5, 0, 0, 0, 0, 5, 0, 0]);
    }

    // -----------------------------------------------------------------------
    // choose
    // -----------------------------------------------------------------------

    #[test]
    fn choose_basic() {
        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([4]), vec![0, 1, 0, 1]).unwrap();
        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![100, 200, 300, 400]).unwrap();
        let result = choose(&idx, &[c0, c1]).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![10, 200, 30, 400]);
    }

    #[test]
    fn choose_out_of_bounds() {
        let idx = Array::<u64, Ix1>::from_vec(Ix1::new([2]), vec![0, 2]).unwrap();
        let c0 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 2]).unwrap();
        let c1 = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![3, 4]).unwrap();
        assert!(choose(&idx, &[c0, c1]).is_err());
    }

    // -----------------------------------------------------------------------
    // compress
    // -----------------------------------------------------------------------

    #[test]
    fn compress_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let result = compress(&[true, false, true, false, true], &arr, Axis(0)).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![10, 30, 50]);
    }

    #[test]
    fn compress_2d_axis0() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let result = compress(&[true, false, true], &arr, Axis(0)).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![0, 1, 2, 3, 8, 9, 10, 11]);
    }

    // -----------------------------------------------------------------------
    // select
    // -----------------------------------------------------------------------

    #[test]
    fn select_basic() {
        let c1 =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, false]).unwrap();
        let c2 =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, false]).unwrap();
        let ch1 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 1, 1, 1]).unwrap();
        let ch2 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![2, 2, 2, 2]).unwrap();
        let result = select(&[c1, c2], &[ch1, ch2], 0).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 0, 0]);
    }

    // -----------------------------------------------------------------------
    // indices
    // -----------------------------------------------------------------------

    #[test]
    fn indices_2d() {
        let idx = indices(&[2, 3]).unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0].shape(), &[2, 3]);
        assert_eq!(idx[1].shape(), &[2, 3]);
        let rows: Vec<u64> = idx[0].iter().copied().collect();
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1]);
        let cols: Vec<u64> = idx[1].iter().copied().collect();
        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2]);
    }

    // -----------------------------------------------------------------------
    // ix_
    // -----------------------------------------------------------------------

    #[test]
    fn ix_basic() {
        let result = ix_(&[&[0, 1], &[2, 3, 4]]).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[2, 1]);
        assert_eq!(result[1].shape(), &[1, 3]);
    }

    // -----------------------------------------------------------------------
    // diag_indices
    // -----------------------------------------------------------------------

    #[test]
    fn diag_indices_basic() {
        let idx = diag_indices(3, 2);
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0], vec![0, 1, 2]);
        assert_eq!(idx[1], vec![0, 1, 2]);
    }

    #[test]
    fn diag_indices_from_square() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([4, 4])).unwrap();
        let idx = diag_indices_from(&arr).unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0].len(), 4);
    }

    #[test]
    fn diag_indices_from_not_square() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
        assert!(diag_indices_from(&arr).is_err());
    }

    // -----------------------------------------------------------------------
    // tril_indices / triu_indices
    // -----------------------------------------------------------------------

    #[test]
    fn tril_indices_basic() {
        let (rows, cols) = tril_indices(3, 0, None);
        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
    }

    #[test]
    fn triu_indices_basic() {
        let (rows, cols) = triu_indices(3, 0, None);
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 2]);
        assert_eq!(cols, vec![0, 1, 2, 1, 2, 2]);
    }

    #[test]
    fn tril_indices_with_k() {
        let (rows, cols) = tril_indices(3, 1, None);
        assert_eq!(rows, vec![0, 0, 1, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 1, 0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn triu_indices_with_negative_k() {
        let (rows, cols) = triu_indices(3, -1, None);
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 1, 2, 2]);
        assert_eq!(cols, vec![0, 1, 2, 0, 1, 2, 1, 2]);
    }

    #[test]
    fn tril_indices_from_test() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
        let (rows, _cols) = tril_indices_from(&arr, 0).unwrap();
        assert_eq!(rows.len(), 6);
    }

    #[test]
    fn triu_indices_from_test() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 3])).unwrap();
        let (rows, _cols) = triu_indices_from(&arr, 0).unwrap();
        assert_eq!(rows.len(), 6);
    }

    #[test]
    fn tril_indices_rectangular() {
        let (rows, cols) = tril_indices(3, 0, Some(4));
        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
    }

    // -----------------------------------------------------------------------
    // ravel_multi_index / unravel_index
    // -----------------------------------------------------------------------

    #[test]
    fn ravel_multi_index_basic() {
        let flat = ravel_multi_index(&[&[0, 1, 2], &[1, 2, 0]], &[3, 4]).unwrap();
        assert_eq!(flat, vec![1, 6, 8]);
    }

    #[test]
    fn ravel_multi_index_3d() {
        let flat = ravel_multi_index(&[&[0], &[1], &[2]], &[2, 3, 4]).unwrap();
        assert_eq!(flat, vec![6]);
    }

    #[test]
    fn ravel_multi_index_out_of_bounds() {
        assert!(ravel_multi_index(&[&[3]], &[3]).is_err());
    }

    #[test]
    fn unravel_index_basic() {
        let coords = unravel_index(&[1, 6, 8], &[3, 4]).unwrap();
        assert_eq!(coords[0], vec![0, 1, 2]);
        assert_eq!(coords[1], vec![1, 2, 0]);
    }

    #[test]
    fn unravel_index_out_of_bounds() {
        assert!(unravel_index(&[12], &[3, 4]).is_err());
    }

    #[test]
    fn ravel_unravel_roundtrip() {
        let dims = &[3, 4, 5];
        let a: &[usize] = &[1, 2];
        let b: &[usize] = &[2, 3];
        let c: &[usize] = &[3, 4];
        let multi: &[&[usize]] = &[a, b, c];
        let flat = ravel_multi_index(multi, dims).unwrap();
        let coords = unravel_index(&flat, dims).unwrap();
        assert_eq!(coords[0], vec![1, 2]);
        assert_eq!(coords[1], vec![2, 3]);
        assert_eq!(coords[2], vec![3, 4]);
    }

    // -----------------------------------------------------------------------
    // flatnonzero
    // -----------------------------------------------------------------------

    #[test]
    fn flatnonzero_basic() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz, vec![1, 3]);
    }

    #[test]
    fn flatnonzero_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz, vec![1, 3, 5]);
    }

    #[test]
    fn flatnonzero_all_zero() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![0, 0, 0]).unwrap();
        let nz = flatnonzero(&arr);
        assert_eq!(nz.len(), 0);
    }

    // -----------------------------------------------------------------------
    // nonzero / argwhere (#373)
    // -----------------------------------------------------------------------

    #[test]
    fn nonzero_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let nz = nonzero(&arr);
        // One axis, one inner vec with the hit positions.
        assert_eq!(nz.len(), 1);
        assert_eq!(nz[0], vec![1, 3]);
    }

    #[test]
    fn nonzero_2d_yields_row_and_col_indices() {
        // [[0, 1, 0],
        //  [2, 0, 3]]
        // Non-zero coordinates (row-major): (0,1), (1,0), (1,2).
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz.len(), 2);
        assert_eq!(nz[0], vec![0, 1, 1]);
        assert_eq!(nz[1], vec![1, 0, 2]);
    }

    #[test]
    fn nonzero_all_zero_returns_empty_per_axis() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz.len(), 2);
        assert!(nz[0].is_empty());
        assert!(nz[1].is_empty());
    }

    #[test]
    fn nonzero_f64_treats_negative_zero_as_zero() {
        // -0.0 == 0.0 for PartialEq, so -0.0 is "zero" per numpy semantics.
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![-0.0, 1.5, 0.0, -2.5]).unwrap();
        let nz = nonzero(&arr);
        assert_eq!(nz[0], vec![1, 3]);
    }

    #[test]
    fn argwhere_2d_has_one_row_per_nonzero() {
        // Same input as nonzero_2d_yields_row_and_col_indices.
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 2, 0, 3]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[3, 2]);
        assert_eq!(coords.as_slice().unwrap(), &[0, 1, 1, 0, 1, 2]);
    }

    #[test]
    fn argwhere_1d_is_column_vector() {
        // A (K, 1) shape means K non-zero rows for a 1-D array.
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 7, 0, 9, 3]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[3, 1]);
        assert_eq!(coords.as_slice().unwrap(), &[1, 3, 4]);
    }

    #[test]
    fn argwhere_all_zero_returns_empty() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
        let coords = argwhere(&arr).unwrap();
        assert_eq!(coords.shape(), &[0, 2]);
        assert_eq!(coords.size(), 0);
    }

    // -----------------------------------------------------------------------
    // ndindex
    // -----------------------------------------------------------------------

    #[test]
    fn ndindex_2d() {
        let indices: Vec<Vec<usize>> = ndindex(&[2, 3]).collect();
        assert_eq!(indices.len(), 6);
        assert_eq!(indices[0], vec![0, 0]);
        assert_eq!(indices[1], vec![0, 1]);
        assert_eq!(indices[2], vec![0, 2]);
        assert_eq!(indices[3], vec![1, 0]);
        assert_eq!(indices[4], vec![1, 1]);
        assert_eq!(indices[5], vec![1, 2]);
    }

    #[test]
    fn ndindex_1d() {
        let indices: Vec<Vec<usize>> = ndindex(&[4]).collect();
        assert_eq!(indices.len(), 4);
        assert_eq!(indices[0], vec![0]);
        assert_eq!(indices[3], vec![3]);
    }

    #[test]
    fn ndindex_empty() {
        assert_eq!(ndindex(&[0]).count(), 0);
    }

    #[test]
    fn ndindex_scalar() {
        let indices: Vec<Vec<usize>> = ndindex(&[]).collect();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], Vec::<usize>::new());
    }

    // -----------------------------------------------------------------------
    // ndenumerate
    // -----------------------------------------------------------------------

    #[test]
    fn ndenumerate_2d() {
        let arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
        let items: Vec<(Vec<usize>, &i32)> = ndenumerate(&arr).collect();
        assert_eq!(items.len(), 6);
        assert_eq!(items[0], (vec![0, 0], &10));
        assert_eq!(items[1], (vec![0, 1], &20));
        assert_eq!(items[5], (vec![1, 2], &60));
    }

    // -----------------------------------------------------------------------
    // put_along_axis
    // -----------------------------------------------------------------------

    #[test]
    fn put_along_axis_basic() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), vec![0; 12]).unwrap();
        let values =
            Array::<i32, IxDyn>::from_vec(IxDyn::new(&[8]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        arr.put_along_axis(&[0, 2], &values, Axis(0)).unwrap();
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3, 4, 0, 0, 0, 0, 5, 6, 7, 8]);
    }

    // -----------------------------------------------------------------------
    // where_
    // -----------------------------------------------------------------------

    #[test]
    fn where_basic() {
        let cond =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[1.0, 20.0, 3.0, 40.0]);
    }

    #[test]
    fn where_all_true() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn where_all_false() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let x = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let y = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[10, 20, 30]);
    }

    #[test]
    fn where_shape_mismatch() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true; 3]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0; 4]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0; 3]).unwrap();
        assert!(where_select(&cond, &x, &y).is_err());
    }

    #[test]
    fn where_2d() {
        let cond =
            Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![true, false, false, true]).unwrap();
        let x = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let y = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![10, 20, 30, 40]).unwrap();
        let result = where_select(&cond, &x, &y).unwrap();
        let data: Vec<i32> = result.iter().copied().collect();
        assert_eq!(data, vec![1, 20, 30, 4]);
    }

    // -- place / putmask / extract / mask_indices --

    #[test]
    fn test_place_basic() {
        let mut a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![false, true, false, true, false, true],
        )
        .unwrap();
        place(&mut a, &mask, &[10, 20]).unwrap();
        // Mask hits at positions 1, 3, 5 → cycled vals 10, 20, 10
        let data: Vec<i32> = a.iter().copied().collect();
        assert_eq!(data, vec![1, 10, 3, 20, 5, 10]);
    }

    #[test]
    fn test_place_no_hits() {
        let mut a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false; 3]).unwrap();
        place(&mut a, &mask, &[]).unwrap(); // No hits, empty vals OK
        assert_eq!(a.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
    }

    #[test]
    fn test_place_shape_mismatch() {
        let mut a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
        assert!(place(&mut a, &mask, &[0]).is_err());
    }

    #[test]
    fn test_putmask_scalar() {
        let mut a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        putmask(&mut a, &mask, &[99]).unwrap();
        assert_eq!(a.iter().copied().collect::<Vec<_>>(), vec![99, 2, 99, 4]);
    }

    #[test]
    fn test_putmask_full_array() {
        let mut a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        putmask(&mut a, &mask, &[10, 20, 30, 40]).unwrap();
        assert_eq!(a.iter().copied().collect::<Vec<_>>(), vec![10, 2, 30, 4]);
    }

    #[test]
    fn test_putmask_2d_full_array() {
        // Same-size values are matched position-by-position in row-major order.
        let mut a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let mask =
            Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![false, true, true, false]).unwrap();
        putmask(&mut a, &mask, &[10, 20, 30, 40]).unwrap();
        assert_eq!(a.iter().copied().collect::<Vec<_>>(), vec![1, 20, 30, 4]);
    }

    #[test]
    fn test_putmask_bad_length() {
        let mut a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
        assert!(putmask(&mut a, &mask, &[1, 2]).is_err());
    }

    #[test]
    fn test_putmask_shape_mismatch() {
        let mut a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
        assert!(putmask(&mut a, &mask, &[0]).is_err());
        // A rejected call must leave the array untouched.
        assert_eq!(a.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
    }

    #[test]
    fn test_extract_basic() {
        let cond =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let r = extract(&cond, &a).unwrap();
        assert_eq!(r.iter().copied().collect::<Vec<_>>(), vec![1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_extract_2d() {
        let cond =
            Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![true, false, false, true]).unwrap();
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![10, 20, 30, 40]).unwrap();
        let r = extract(&cond, &a).unwrap();
        assert_eq!(r.iter().copied().collect::<Vec<_>>(), vec![10, 40]);
    }

    #[test]
    fn test_extract_no_hits() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false; 3]).unwrap();
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let r = extract(&cond, &a).unwrap();
        assert_eq!(r.shape(), &[0]);
        assert_eq!(r.size(), 0);
    }

    #[test]
    fn test_extract_shape_mismatch() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true; 2]).unwrap();
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(extract(&cond, &a).is_err());
    }

    #[test]
    fn test_mask_indices_tril() {
        let idx = mask_indices(3, MaskKind::Tril, 0);
        // Lower triangle of 3x3: positions (0,0), (1,0), (1,1), (2,0), (2,1), (2,2)
        assert_eq!(idx, vec![0, 3, 4, 6, 7, 8]);
    }

    #[test]
    fn test_mask_indices_triu() {
        let idx = mask_indices(3, MaskKind::Triu, 0);
        // Upper triangle: (0,0), (0,1), (0,2), (1,1), (1,2), (2,2)
        assert_eq!(idx, vec![0, 1, 2, 4, 5, 8]);
    }

    #[test]
    fn test_mask_indices_diag() {
        let idx = mask_indices(3, MaskKind::Diag, 0);
        assert_eq!(idx, vec![0, 4, 8]);
    }

    #[test]
    fn test_mask_indices_shifted_k() {
        // Tril, k = 1: j <= i + 1 drops only (0, 2).
        assert_eq!(mask_indices(3, MaskKind::Tril, 1), vec![0, 1, 3, 4, 5, 6, 7, 8]);
        // Triu, k = -1: j >= i - 1 drops only (2, 0).
        assert_eq!(mask_indices(3, MaskKind::Triu, -1), vec![0, 1, 2, 3, 4, 5, 7, 8]);
        // Super- and sub-diagonals.
        assert_eq!(mask_indices(3, MaskKind::Diag, 1), vec![1, 5]);
        assert_eq!(mask_indices(3, MaskKind::Diag, -1), vec![3, 7]);
        // A diagonal lying entirely outside the matrix selects nothing.
        assert!(mask_indices(3, MaskKind::Diag, 3).is_empty());
    }

    #[test]
    fn test_mask_indices_empty_matrix() {
        assert!(mask_indices(0, MaskKind::Tril, 0).is_empty());
        assert!(mask_indices(0, MaskKind::Diag, 0).is_empty());
    }
}