Skip to main content

ferray_stats/
searching.rs

// ferray-stats: Searching — unique, nonzero, where_, count_nonzero (REQ-14, REQ-15, REQ-16, REQ-17)
//
// ## REQ status (ferray-stats searching, NumPy parity)
//  - REQ-14 (unique with return_index / return_counts) — SHIPPED:
//    `pub fn unique` returns sorted unique elements with optional
//    first-occurrence indices, inverse indices, and counts,
//    plus `pub fn unique_values` / `pub fn unique_counts` / `pub fn unique_inverse`
//    / `pub fn unique_all` / `pub fn unique_axis`, matching numpy's
//    `unique`/`unique_*` family (numpy/lib/_arraysetops_impl.py:138 `unique`).
//    Non-test consumers: `ferray_stats::unique` / `ferray_stats::unique_values`
//    `#[pyfunction]` shims in `ferray-python/src/stats.rs`.
//  - REQ-15 (nonzero returning a tuple of index arrays) — SHIPPED:
//    `pub fn nonzero` returns `Vec<Array<u64, Ix1>>` (one index array per
//    dimension), matching numpy/_core/fromnumeric.py:1937 `nonzero`. Non-test
//    consumer: the `nonzero` `#[pyfunction]` in `ferray-python/src/indexing.rs`
//    (`fi::nonzero(&fa)`, where `fi` aliases `ferray_stats`).
//  - REQ-16 (where_ conditional selection) — SHIPPED: `pub fn where_` plus
//    `pub fn where_broadcast` (broadcasting cond/x/y) and `pub fn where_condition`
//    (single-argument index form), matching numpy's `where`
//    (numpy/_core/multiarray `where`). Consumer: `ferray_stats::where_`
//    `#[pyfunction]` shim in `stats.rs`.
//  - REQ-17 (count_nonzero with optional axis) — SHIPPED: `pub fn count_nonzero`
//    returns `Array<u64, IxDyn>`, taking `axis: Option<usize>`, matching
//    numpy/_core/numeric.py `count_nonzero`. Consumer:
//    `ferray_stats::count_nonzero` `#[pyfunction]` shim in `stats.rs`.
25
26use ferray_core::dimension::broadcast::broadcast_shapes;
27use ferray_core::error::{FerrayError, FerrayResult};
28use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
29
30use crate::reductions::{
31    borrow_data, make_result, output_shape, reduce_axis_general_u64, validate_axis,
32};
33
34// ---------------------------------------------------------------------------
35// unique
36// ---------------------------------------------------------------------------
37
38/// Result from the `unique` function.
39#[derive(Debug)]
40pub struct UniqueResult<T: Element> {
41    /// The sorted unique values.
42    pub values: Array<T, Ix1>,
43    /// If requested, the indices of the first occurrence of each unique value
44    /// in the original array (as u64).
45    pub indices: Option<Array<u64, Ix1>>,
46    /// If requested, the inverse index array such that
47    /// `values[inverse]` reconstructs the flattened input. Essential for
48    /// label encoding (#463).
49    pub inverse: Option<Array<u64, Ix1>>,
50    /// If requested, the count of each unique value (as u64).
51    pub counts: Option<Array<u64, Ix1>>,
52}
53
54/// Find the sorted unique elements of an array.
55///
56/// The input is flattened. Optionally returns:
57/// - `return_index`: indices of the first occurrence of each unique value.
58/// - `return_inverse`: an array of the same length as the flattened input,
59///   where each entry is the index into `values` of the corresponding
60///   original element. Satisfies `values[inverse] == flat_input`.
61/// - `return_counts`: count of each unique value.
62///
63/// Equivalent to `numpy.unique`.
64pub fn unique<T, D>(
65    a: &Array<T, D>,
66    return_index: bool,
67    return_inverse: bool,
68    return_counts: bool,
69) -> FerrayResult<UniqueResult<T>>
70where
71    T: Element + PartialOrd + Copy,
72    D: Dimension,
73{
74    let data: Vec<T> = a.iter().copied().collect();
75    let n_data = data.len();
76
77    // Create (value, original_index) pairs, then sort by value.
78    let mut pairs: Vec<(T, usize)> = data
79        .iter()
80        .copied()
81        .enumerate()
82        .map(|(i, v)| (v, i))
83        .collect();
84    // NaN-aware total order (#977, same comparator as the #918 sort fix):
85    // every NaN sorts LAST, so after sorting all NaN are adjacent at the tail
86    // and the dedup below collapses them to a single value (numpy's default
87    // `equal_nan=True`, numpy/lib/_arraysetops_impl.py:389-399). A plain
88    // `partial_cmp(...).unwrap_or(Equal)` would instead treat NaN as equal to
89    // everything, leaving the result unsorted with NaN un-deduplicated.
90    pairs.sort_by(|a, b| crate::parallel::nan_last_cmp(&a.0, &b.0));
91
92    // Deduplicate. The inverse is built in lockstep: for each original
93    // position `orig_idx`, `inverse[orig_idx]` is the u64 index into
94    // `unique_vals` where that original's value ended up. We walk the
95    // sorted pairs, advancing the unique position whenever we see a new
96    // value, and use the sorted pair's `original_idx` to write into the
97    // right inverse slot.
98    let mut unique_vals = Vec::new();
99    let mut unique_indices: Vec<u64> = Vec::new();
100    let mut unique_counts: Vec<u64> = Vec::new();
101    let mut inverse_vec: Vec<u64> = if return_inverse {
102        vec![0u64; n_data]
103    } else {
104        Vec::new()
105    };
106
107    if !pairs.is_empty() {
108        unique_vals.push(pairs[0].0);
109        unique_indices.push(pairs[0].1 as u64);
110        if return_inverse {
111            inverse_vec[pairs[0].1] = 0;
112        }
113        let mut count = 1u64;
114        let mut unique_pos: u64 = 0;
115
116        for i in 1..pairs.len() {
117            // Adjacent NaN compare Equal under `nan_last_cmp`, so they collapse
118            // into one unique value (#977, `equal_nan=True`).
119            if crate::parallel::nan_last_cmp(&pairs[i].0, &pairs[i - 1].0)
120                == std::cmp::Ordering::Equal
121            {
122                count += 1;
123                // Keep the first occurrence index (smallest original index).
124                let last = unique_indices.len() - 1;
125                let new_idx = pairs[i].1 as u64;
126                if new_idx < unique_indices[last] {
127                    unique_indices[last] = new_idx;
128                }
129            } else {
130                if return_counts {
131                    unique_counts.push(count);
132                }
133                unique_vals.push(pairs[i].0);
134                unique_indices.push(pairs[i].1 as u64);
135                count = 1;
136                unique_pos += 1;
137            }
138            if return_inverse {
139                inverse_vec[pairs[i].1] = unique_pos;
140            }
141        }
142        if return_counts {
143            unique_counts.push(count);
144        }
145    }
146
147    let n = unique_vals.len();
148    let values = Array::from_vec(Ix1::new([n]), unique_vals)?;
149    let indices = if return_index {
150        Some(Array::from_vec(Ix1::new([n]), unique_indices)?)
151    } else {
152        None
153    };
154    let inverse = if return_inverse {
155        Some(Array::from_vec(Ix1::new([n_data]), inverse_vec)?)
156    } else {
157        None
158    };
159    let counts = if return_counts {
160        Some(Array::from_vec(Ix1::new([n]), unique_counts)?)
161    } else {
162        None
163    };
164
165    Ok(UniqueResult {
166        values,
167        indices,
168        inverse,
169        counts,
170    })
171}
172
173/// Find unique hyperslices along an axis.
174///
175/// Equivalent to `numpy.unique(a, axis=axis)`: returns an array of the
176/// same dimensionality as `a` with the `axis` dimension reduced to the
177/// number of unique slices. Slices are compared element-wise (lex order)
178/// and returned in sorted order, matching numpy's behavior (#464).
179///
180/// For axis-less unique-on-flattened, see [`unique`].
181///
182/// # Errors
183/// - `FerrayError::AxisOutOfBounds` if `axis >= a.ndim()`.
184pub fn unique_axis<T, D>(a: &Array<T, D>, axis: usize) -> FerrayResult<Array<T, IxDyn>>
185where
186    T: Element + PartialOrd + Copy,
187    D: Dimension,
188{
189    let shape = a.shape().to_vec();
190    let ndim = shape.len();
191    if axis >= ndim {
192        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
193    }
194    let n = shape[axis];
195    let inner_stride: usize = shape[axis + 1..].iter().product();
196    let outer_size: usize = shape[..axis].iter().product();
197    let block = n * inner_stride;
198    let slice_len = outer_size * inner_stride;
199
200    let data: Vec<T> = a.iter().copied().collect();
201
202    // Empty axis: return the input unchanged.
203    if n == 0 {
204        return Array::<T, IxDyn>::from_vec(IxDyn::new(&shape), data);
205    }
206
207    // Gather each axis-slice in canonical (outer, inner) order.
208    let mut slices: Vec<Vec<T>> = Vec::with_capacity(n);
209    for i in 0..n {
210        let mut s = Vec::with_capacity(slice_len);
211        for o in 0..outer_size {
212            let base = o * block + i * inner_stride;
213            s.extend_from_slice(&data[base..base + inner_stride]);
214        }
215        slices.push(s);
216    }
217
218    // Sort axis indices by lex comparison of their slices. Use
219    // partial_cmp so floating-point T (NaN-tolerating) works the same
220    // way it does in `unique`.
221    let mut order: Vec<usize> = (0..n).collect();
222    order.sort_by(|&i, &j| {
223        slices[i]
224            .iter()
225            .zip(slices[j].iter())
226            .find_map(|(a, b)| match a.partial_cmp(b) {
227                None | Some(std::cmp::Ordering::Equal) => None,
228                Some(c) => Some(c),
229            })
230            .unwrap_or(std::cmp::Ordering::Equal)
231    });
232
233    // Dedupe consecutive equal slices.
234    let mut kept: Vec<usize> = Vec::with_capacity(n);
235    for &idx in &order {
236        if let Some(&prev) = kept.last() {
237            let equal = slices[idx]
238                .iter()
239                .zip(slices[prev].iter())
240                .all(|(a, b)| a.partial_cmp(b) == Some(std::cmp::Ordering::Equal));
241            if equal {
242                continue;
243            }
244        }
245        kept.push(idx);
246    }
247
248    let new_n = kept.len();
249    let mut out_shape = shape.clone();
250    out_shape[axis] = new_n;
251    let new_block = new_n * inner_stride;
252    let total: usize = out_shape.iter().product();
253    let mut out_data: Vec<T> = vec![data[0]; total];
254    for (out_i, &src_i) in kept.iter().enumerate() {
255        for o in 0..outer_size {
256            let src_base = o * block + src_i * inner_stride;
257            let dst_base = o * new_block + out_i * inner_stride;
258            out_data[dst_base..dst_base + inner_stride]
259                .copy_from_slice(&data[src_base..src_base + inner_stride]);
260        }
261    }
262    Array::<T, IxDyn>::from_vec(IxDyn::new(&out_shape), out_data)
263}
264
265// ---------------------------------------------------------------------------
266// nonzero
267// ---------------------------------------------------------------------------
268
269/// Return the indices of non-zero elements.
270///
271/// Returns a vector of 1-D arrays (u64), one per dimension. For a 1-D input,
272/// returns a single array of indices.
273///
274/// Equivalent to `numpy.nonzero`.
275pub fn nonzero<T, D>(a: &Array<T, D>) -> FerrayResult<Vec<Array<u64, Ix1>>>
276where
277    T: Element + PartialEq + Copy,
278    D: Dimension,
279{
280    let shape = a.shape();
281    let ndim = shape.len();
282    let zero = <T as Element>::zero();
283
284    // Collect all multi-indices where element != 0
285    let mut indices_per_dim: Vec<Vec<u64>> = vec![Vec::new(); ndim];
286
287    // Compute strides for index conversion
288    let mut strides = vec![1usize; ndim];
289    for i in (0..ndim.saturating_sub(1)).rev() {
290        strides[i] = strides[i + 1] * shape[i + 1];
291    }
292
293    for (flat_idx, &val) in a.iter().enumerate() {
294        if val != zero {
295            let mut rem = flat_idx;
296            for d in 0..ndim {
297                indices_per_dim[d].push((rem / strides[d]) as u64);
298                rem %= strides[d];
299            }
300        }
301    }
302
303    let mut result = Vec::with_capacity(ndim);
304    for idx_vec in indices_per_dim {
305        let n = idx_vec.len();
306        result.push(Array::from_vec(Ix1::new([n]), idx_vec)?);
307    }
308
309    Ok(result)
310}
311
312// ---------------------------------------------------------------------------
313// where_
314// ---------------------------------------------------------------------------
315
316/// Conditional element selection.
317///
318/// For each element, if the corresponding element of `condition` is non-zero,
319/// select from `x`; otherwise select from `y`.
320///
321/// All three arrays must have the same shape.
322///
323/// Equivalent to `numpy.where`.
324pub fn where_<T, D>(
325    condition: &Array<bool, D>,
326    x: &Array<T, D>,
327    y: &Array<T, D>,
328) -> FerrayResult<Array<T, D>>
329where
330    T: Element + Copy,
331    D: Dimension,
332{
333    if condition.shape() != x.shape() || condition.shape() != y.shape() {
334        return Err(FerrayError::shape_mismatch(format!(
335            "condition, x, y shapes must match: {:?}, {:?}, {:?}",
336            condition.shape(),
337            x.shape(),
338            y.shape()
339        )));
340    }
341
342    let result: Vec<T> = condition
343        .iter()
344        .zip(x.iter())
345        .zip(y.iter())
346        .map(|((&c, &xv), &yv)| if c { xv } else { yv })
347        .collect();
348
349    Array::from_vec(condition.dim().clone(), result)
350}
351
352/// Broadcast-aware ternary `where`: pick from `x` where `condition` is
353/// true, else from `y`. Each of `condition`, `x`, `y` may have its
354/// own shape; they are NumPy-broadcast against each other to produce
355/// the output shape (#468).
356///
357/// For the same-shape fast path use [`where_`] — it skips the
358/// broadcast view setup and is lower overhead.
359///
360/// # Errors
361/// `FerrayError::ShapeMismatch` if any pair of shapes is not
362/// broadcast-compatible.
363pub fn where_broadcast<T>(
364    condition: &Array<bool, IxDyn>,
365    x: &Array<T, IxDyn>,
366    y: &Array<T, IxDyn>,
367) -> FerrayResult<Array<T, IxDyn>>
368where
369    T: Element + Copy,
370{
371    let cx = broadcast_shapes(condition.shape(), x.shape())?;
372    let target = broadcast_shapes(&cx, y.shape())?;
373    let cv = condition.broadcast_to(&target)?;
374    let xv = x.broadcast_to(&target)?;
375    let yv = y.broadcast_to(&target)?;
376    let result: Vec<T> = cv
377        .iter()
378        .zip(xv.iter())
379        .zip(yv.iter())
380        .map(|((&c, &xv), &yv)| if c { xv } else { yv })
381        .collect();
382    Array::<T, IxDyn>::from_vec(IxDyn::new(&target), result)
383}
384
385/// One-argument form of `where`: return the indices where `condition`
386/// is true, as a vector of 1-D index arrays (one per dimension).
387///
388/// Equivalent to `numpy.where(condition)` (single-argument form) or
389/// `numpy.nonzero(condition.astype(int))`. Added for `NumPy` parity
390/// (#166) — the three-argument form above is [`where_`].
391pub fn where_condition<D: Dimension>(
392    condition: &Array<bool, D>,
393) -> FerrayResult<Vec<Array<u64, Ix1>>> {
394    let shape = condition.shape();
395    let ndim = shape.len();
396    let mut indices_per_dim: Vec<Vec<u64>> = vec![Vec::new(); ndim];
397
398    let mut strides = vec![1usize; ndim];
399    for i in (0..ndim.saturating_sub(1)).rev() {
400        strides[i] = strides[i + 1] * shape[i + 1];
401    }
402
403    for (flat_idx, &val) in condition.iter().enumerate() {
404        if val {
405            let mut rem = flat_idx;
406            for d in 0..ndim {
407                indices_per_dim[d].push((rem / strides[d]) as u64);
408                rem %= strides[d];
409            }
410        }
411    }
412
413    indices_per_dim
414        .into_iter()
415        .map(|v| {
416            let n = v.len();
417            Array::from_vec(Ix1::new([n]), v)
418        })
419        .collect()
420}
421
422// ---------------------------------------------------------------------------
423// count_nonzero
424// ---------------------------------------------------------------------------
425
426/// Count the number of non-zero elements along a given axis.
427///
428/// Equivalent to `numpy.count_nonzero`.
429pub fn count_nonzero<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
430where
431    T: Element + PartialEq + Copy,
432    D: Dimension,
433{
434    let zero = <T as Element>::zero();
435    let data = borrow_data(a);
436    match axis {
437        None => {
438            let count = data.iter().filter(|&&x| x != zero).count() as u64;
439            make_result(&[], vec![count])
440        }
441        Some(ax) => {
442            validate_axis(ax, a.ndim())?;
443            let shape = a.shape();
444            let out_s = output_shape(shape, ax);
445            let result = reduce_axis_general_u64(&data, shape, ax, |lane| {
446                lane.iter().filter(|&&x| x != zero).count() as u64
447            });
448            make_result(&out_s, result)
449        }
450    }
451}
452
453// ---------------------------------------------------------------------------
454// Array API standard names: unique_values / unique_counts / unique_inverse / unique_all
455// ---------------------------------------------------------------------------
456
457/// Sorted unique values of the (flattened) array.
458///
459/// Array-API-standard alias for [`unique`] with no extra return arrays.
460/// Equivalent to `numpy.unique_values(x)`.
461pub fn unique_values<T, D>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>>
462where
463    T: Element + PartialOrd + Copy,
464    D: Dimension,
465{
466    Ok(unique(a, false, false, false)?.values)
467}
468
469/// Sorted unique values and their occurrence counts.
470///
471/// Equivalent to `numpy.unique_counts(x)` — returns `(values, counts)`.
472pub fn unique_counts<T, D>(a: &Array<T, D>) -> FerrayResult<(Array<T, Ix1>, Array<u64, Ix1>)>
473where
474    T: Element + PartialOrd + Copy,
475    D: Dimension,
476{
477    let r = unique(a, false, false, true)?;
478    Ok((r.values, r.counts.expect("return_counts requested")))
479}
480
481/// Sorted unique values and the inverse-index array.
482///
483/// Equivalent to `numpy.unique_inverse(x)` — returns `(values, inverse)`
484/// where `values[inverse]` reconstructs the flattened input.
485pub fn unique_inverse<T, D>(a: &Array<T, D>) -> FerrayResult<(Array<T, Ix1>, Array<u64, Ix1>)>
486where
487    T: Element + PartialOrd + Copy,
488    D: Dimension,
489{
490    let r = unique(a, false, true, false)?;
491    Ok((r.values, r.inverse.expect("return_inverse requested")))
492}
493
494/// Sorted unique values along with first-occurrence indices, inverse, and
495/// counts.
496///
497/// Equivalent to `numpy.unique_all(x)` — returns
498/// `(values, indices, inverse, counts)`. The four-tuple is the documented
499/// Array-API contract; suppress clippy's complexity warning.
500#[allow(clippy::type_complexity)]
501pub fn unique_all<T, D>(
502    a: &Array<T, D>,
503) -> FerrayResult<(
504    Array<T, Ix1>,
505    Array<u64, Ix1>,
506    Array<u64, Ix1>,
507    Array<u64, Ix1>,
508)>
509where
510    T: Element + PartialOrd + Copy,
511    D: Dimension,
512{
513    let r = unique(a, true, true, true)?;
514    Ok((
515        r.values,
516        r.indices.expect("return_index requested"),
517        r.inverse.expect("return_inverse requested"),
518        r.counts.expect("return_counts requested"),
519    ))
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::{Ix1, Ix2};

    // ---- where_broadcast (#468) ----------------------------------------

    #[test]
    fn where_broadcast_scalar_else() {
        // condition: (3,), x: (3,), y: (1,) → pick from y for false slots.
        let c = Array::<bool, IxDyn>::from_vec(IxDyn::new(&[3]), vec![true, false, true]).unwrap();
        let x = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), vec![1, 2, 3]).unwrap();
        let y = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[1]), vec![99]).unwrap();
        let r = where_broadcast(&c, &x, &y).unwrap();
        assert_eq!(r.shape(), &[3]);
        assert_eq!(r.as_slice().unwrap(), &[1, 99, 3]);
    }

    #[test]
    fn where_broadcast_2d_outer() {
        // condition: (3, 1), x: (1, 4), y: scalar (1,1) → output (3, 4).
        let c =
            Array::<bool, IxDyn>::from_vec(IxDyn::new(&[3, 1]), vec![true, false, true]).unwrap();
        let x = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[1, 4]), vec![10, 20, 30, 40]).unwrap();
        let y = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[1, 1]), vec![-1]).unwrap();
        let r = where_broadcast(&c, &x, &y).unwrap();
        assert_eq!(r.shape(), &[3, 4]);
        let s = r.as_slice().unwrap();
        // Row 0: condition=true → row from x.
        assert_eq!(&s[0..4], &[10, 20, 30, 40]);
        // Row 1: condition=false → all -1.
        assert_eq!(&s[4..8], &[-1, -1, -1, -1]);
        // Row 2: condition=true → row from x.
        assert_eq!(&s[8..12], &[10, 20, 30, 40]);
    }

    #[test]
    fn where_broadcast_shape_mismatch() {
        let c = Array::<bool, IxDyn>::from_vec(IxDyn::new(&[3]), vec![true, false, true]).unwrap();
        let x = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1, 2]).unwrap();
        let y = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[1]), vec![0]).unwrap();
        assert!(where_broadcast(&c, &x, &y).is_err());
    }

    // ---- unique_axis (#464) --------------------------------------------

    #[test]
    fn unique_axis_rows_dedup() {
        // axis=0 on a 2D array dedupes rows.
        let a =
            Array::<i32, Ix2>::from_vec(Ix2::new([4, 3]), vec![1, 2, 3, 4, 5, 6, 1, 2, 3, 7, 8, 9])
                .unwrap();
        let u = unique_axis(&a, 0).unwrap();
        // Three unique rows: [1,2,3], [4,5,6], [7,8,9] in sorted order.
        assert_eq!(u.shape(), &[3, 3]);
        let s = u.as_slice().unwrap();
        assert_eq!(s, &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn unique_axis_columns_dedup() {
        // axis=1 dedupes columns. Construct a matrix where columns 0 and 2
        // are identical.
        let a =
            Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), vec![1, 2, 1, 3, 4, 5, 4, 6, 7, 8, 7, 9])
                .unwrap();
        let u = unique_axis(&a, 1).unwrap();
        // 3 unique columns: [1,4,7], [2,5,8], [3,6,9] sorted lex.
        assert_eq!(u.shape(), &[3, 3]);
        let s = u.as_slice().unwrap();
        assert_eq!(s, &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn unique_axis_all_distinct_keeps_count_but_sorts() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([3, 2]), vec![3, 4, 1, 2, 5, 6]).unwrap();
        let u = unique_axis(&a, 0).unwrap();
        assert_eq!(u.shape(), &[3, 2]);
        let s = u.as_slice().unwrap();
        assert_eq!(s, &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn unique_axis_all_same_collapses() {
        let a =
            Array::<i32, Ix2>::from_vec(Ix2::new([4, 2]), vec![1, 2, 1, 2, 1, 2, 1, 2]).unwrap();
        let u = unique_axis(&a, 0).unwrap();
        assert_eq!(u.shape(), &[1, 2]);
        let s = u.as_slice().unwrap();
        assert_eq!(s, &[1, 2]);
    }

    #[test]
    fn unique_axis_3d_axis0() {
        // A genuinely 3-D input of shape (4, 2, 2) whose hyperslices 0 and 2
        // along axis 0 are identical.
        let a = Array::<i32, IxDyn>::from_vec(
            IxDyn::new(&[4, 2, 2]),
            vec![
                1, 2, 3, 4, // slice 0
                5, 6, 7, 8, // slice 1
                1, 2, 3, 4, // slice 2 (== slice 0)
                9, 0, 1, 2, // slice 3
            ],
        )
        .unwrap();
        let u = unique_axis(&a, 0).unwrap();
        assert_eq!(u.shape(), &[3, 2, 2]);
        // Slices in lex order: [1,2,3,4] < [5,6,7,8] < [9,0,1,2].
        assert_eq!(
            u.as_slice().unwrap(),
            &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2]
        );
    }

    #[test]
    fn unique_axis_out_of_bounds() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        assert!(unique_axis(&a, 5).is_err());
    }

    #[test]
    fn test_unique_basic() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let u = unique(&a, false, false, false).unwrap();
        let data: Vec<i32> = u.values.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3]);
    }

    #[test]
    fn test_unique_with_counts() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let u = unique(&a, false, false, true).unwrap();
        let vals: Vec<i32> = u.values.iter().copied().collect();
        let cnts: Vec<u64> = u.counts.unwrap().iter().copied().collect();
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(cnts, vec![2, 2, 2]);
    }

    #[test]
    fn test_unique_with_index() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![5, 3, 3, 1, 5]).unwrap();
        let u = unique(&a, true, false, false).unwrap();
        let vals: Vec<i32> = u.values.iter().copied().collect();
        let idxs: Vec<u64> = u.indices.unwrap().iter().copied().collect();
        assert_eq!(vals, vec![1, 3, 5]);
        assert_eq!(idxs, vec![3, 1, 0]);
    }

    #[test]
    fn test_unique_nan_collapses_and_sorts_last() {
        // #977: every NaN sorts last and all NaN collapse to one value
        // (numpy's default equal_nan=True).
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![f64::NAN, 1.0, f64::NAN, 0.5])
            .unwrap();
        let u = unique(&a, false, false, true).unwrap();
        let vals: Vec<f64> = u.values.iter().copied().collect();
        assert_eq!(vals.len(), 3);
        assert_eq!(&vals[..2], &[0.5, 1.0]);
        assert!(vals[2].is_nan());
        let cnts: Vec<u64> = u.counts.unwrap().iter().copied().collect();
        assert_eq!(cnts, vec![1, 1, 2]);
    }

    // ---- return_inverse (#463) ----

    #[test]
    fn test_unique_inverse_reconstructs_input() {
        // The canonical label-encoding use case.
        let input = vec![3, 1, 2, 1, 3, 2];
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), input.clone()).unwrap();
        let u = unique(&a, false, true, false).unwrap();
        let vals: Vec<i32> = u.values.iter().copied().collect();
        let inv: Vec<u64> = u.inverse.unwrap().iter().copied().collect();
        // Unique values must be sorted.
        assert_eq!(vals, vec![1, 2, 3]);
        // values[inverse] must reconstruct the flattened input.
        let reconstructed: Vec<i32> = inv.iter().map(|&i| vals[i as usize]).collect();
        assert_eq!(reconstructed, input);
    }

    #[test]
    fn test_unique_inverse_all_together() {
        // Request indices + inverse + counts in one call; each field must
        // independently match what a single-flag call would produce.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([7]), vec![2, 1, 2, 3, 1, 2, 3]).unwrap();
        let u = unique(&a, true, true, true).unwrap();
        let vals: Vec<i32> = u.values.iter().copied().collect();
        let idxs: Vec<u64> = u.indices.unwrap().iter().copied().collect();
        let inv: Vec<u64> = u.inverse.unwrap().iter().copied().collect();
        let cnts: Vec<u64> = u.counts.unwrap().iter().copied().collect();
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(idxs, vec![1, 0, 3]); // first positions of 1, 2, 3
        assert_eq!(cnts, vec![2, 3, 2]);
        // Reconstruct via inverse.
        let reconstructed: Vec<i32> = inv.iter().map(|&i| vals[i as usize]).collect();
        assert_eq!(reconstructed, vec![2, 1, 2, 3, 1, 2, 3]);
    }

    #[test]
    fn test_unique_inverse_with_2d_flattens_first() {
        // NumPy's unique flattens the input; inverse has length
        // shape.iter().product(), indexing into the flat logical traversal.
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 1, 3, 2, 1]).unwrap();
        let u = unique(&a, false, true, false).unwrap();
        let vals: Vec<i32> = u.values.iter().copied().collect();
        let inv: Vec<u64> = u.inverse.unwrap().iter().copied().collect();
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(inv.len(), 6);
        let flat: Vec<i32> = vec![1, 2, 1, 3, 2, 1];
        let reconstructed: Vec<i32> = inv.iter().map(|&i| vals[i as usize]).collect();
        assert_eq!(reconstructed, flat);
    }

    #[test]
    fn test_unique_inverse_empty_input() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let u = unique(&a, false, true, false).unwrap();
        assert_eq!(u.values.shape(), &[0]);
        let inv = u.inverse.unwrap();
        assert_eq!(inv.shape(), &[0]);
    }

    #[test]
    fn test_unique_inverse_single_value() {
        // Every element identical → all inverse entries point at position 0.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![7, 7, 7, 7]).unwrap();
        let u = unique(&a, false, true, false).unwrap();
        let inv: Vec<u64> = u.inverse.unwrap().iter().copied().collect();
        assert_eq!(inv, vec![0, 0, 0, 0]);
    }

    #[test]
    fn test_unique_without_inverse_leaves_field_none() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 1]).unwrap();
        let u = unique(&a, false, false, false).unwrap();
        assert!(u.inverse.is_none());
    }

    #[test]
    fn test_nonzero_1d() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let nz = nonzero(&a).unwrap();
        assert_eq!(nz.len(), 1);
        let data: Vec<u64> = nz[0].iter().copied().collect();
        assert_eq!(data, vec![1, 3]);
    }

    #[test]
    fn test_nonzero_2d() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 3, 0, 5]).unwrap();
        let nz = nonzero(&a).unwrap();
        assert_eq!(nz.len(), 2);
        let rows: Vec<u64> = nz[0].iter().copied().collect();
        let cols: Vec<u64> = nz[1].iter().copied().collect();
        assert_eq!(rows, vec![0, 1, 1]);
        assert_eq!(cols, vec![1, 0, 2]);
    }

    #[test]
    fn test_where_basic() {
        let cond =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let r = where_(&cond, &x, &y).unwrap();
        let data: Vec<f64> = r.iter().copied().collect();
        assert_eq!(data, vec![1.0, 20.0, 3.0, 40.0]);
    }

    #[test]
    fn test_where_shape_mismatch() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        assert!(where_(&cond, &x, &y).is_err());
    }

    #[test]
    fn test_count_nonzero_total() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let c = count_nonzero(&a, None).unwrap();
        assert_eq!(c.iter().next(), Some(&2u64));
    }

    #[test]
    fn test_count_nonzero_axis() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 3, 0, 5]).unwrap();
        let c = count_nonzero(&a, Some(0)).unwrap();
        let data: Vec<u64> = c.iter().copied().collect();
        assert_eq!(data, vec![1, 1, 1]);
    }

    #[test]
    fn test_count_nonzero_axis1() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 3, 0, 5]).unwrap();
        let c = count_nonzero(&a, Some(1)).unwrap();
        let data: Vec<u64> = c.iter().copied().collect();
        assert_eq!(data, vec![1, 2]);
    }

    // -- Array API unique_* --

    #[test]
    fn test_unique_values_alias() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let v = unique_values(&a).unwrap();
        assert_eq!(v.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
    }

    #[test]
    fn test_unique_counts_alias() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let (v, c) = unique_counts(&a).unwrap();
        assert_eq!(v.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
        assert_eq!(c.iter().copied().collect::<Vec<_>>(), vec![2, 2, 2]);
    }

    #[test]
    fn test_unique_inverse_alias() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let (v, inv) = unique_inverse(&a).unwrap();
        // values = [1, 2, 3]; inverse maps each original to that index
        // 3→2, 1→0, 2→1, 1→0, 3→2, 2→1
        assert_eq!(v.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
        assert_eq!(
            inv.iter().copied().collect::<Vec<_>>(),
            vec![2, 0, 1, 0, 2, 1]
        );
    }

    #[test]
    fn test_unique_all_alias() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let (v, idx, inv, c) = unique_all(&a).unwrap();
        assert_eq!(v.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
        // First occurrences: 1 at position 1, 2 at position 2, 3 at position 0.
        assert_eq!(idx.iter().copied().collect::<Vec<_>>(), vec![1, 2, 0]);
        assert_eq!(
            inv.iter().copied().collect::<Vec<_>>(),
            vec![2, 0, 1, 0, 2, 1]
        );
        assert_eq!(c.iter().copied().collect::<Vec<_>>(), vec![2, 2, 2]);
    }
}