Skip to main content

ferray_stats/
searching.rs

1// ferray-stats: Searching — unique, nonzero, where_, count_nonzero (REQ-14, REQ-15, REQ-16, REQ-17)
2
3use ferray_core::dimension::broadcast::broadcast_shapes;
4use ferray_core::error::{FerrayError, FerrayResult};
5use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
6
7use crate::reductions::{
8    borrow_data, make_result, output_shape, reduce_axis_general_u64, validate_axis,
9};
10
11// ---------------------------------------------------------------------------
12// unique
13// ---------------------------------------------------------------------------
14
/// Result from the `unique` function.
///
/// `values` is always populated. Each `Option` field is `Some` exactly when
/// the matching `return_index` / `return_inverse` / `return_counts` flag was
/// passed to `unique`, and `None` otherwise.
#[derive(Debug)]
pub struct UniqueResult<T: Element> {
    /// The sorted unique values.
    pub values: Array<T, Ix1>,
    /// If requested, the indices of the first occurrence of each unique value
    /// in the flattened original array (as u64). Same length as `values`.
    pub indices: Option<Array<u64, Ix1>>,
    /// If requested, the inverse index array such that
    /// `values[inverse]` reconstructs the flattened input. Essential for
    /// label encoding (#463). Has one entry per element of the input.
    pub inverse: Option<Array<u64, Ix1>>,
    /// If requested, the count of each unique value (as u64). Same length as
    /// `values`.
    pub counts: Option<Array<u64, Ix1>>,
}
30
31/// Find the sorted unique elements of an array.
32///
33/// The input is flattened. Optionally returns:
34/// - `return_index`: indices of the first occurrence of each unique value.
35/// - `return_inverse`: an array of the same length as the flattened input,
36///   where each entry is the index into `values` of the corresponding
37///   original element. Satisfies `values[inverse] == flat_input`.
38/// - `return_counts`: count of each unique value.
39///
40/// Equivalent to `numpy.unique`.
41pub fn unique<T, D>(
42    a: &Array<T, D>,
43    return_index: bool,
44    return_inverse: bool,
45    return_counts: bool,
46) -> FerrayResult<UniqueResult<T>>
47where
48    T: Element + PartialOrd + Copy,
49    D: Dimension,
50{
51    let data: Vec<T> = a.iter().copied().collect();
52    let n_data = data.len();
53
54    // Create (value, original_index) pairs, then sort by value.
55    let mut pairs: Vec<(T, usize)> = data
56        .iter()
57        .copied()
58        .enumerate()
59        .map(|(i, v)| (v, i))
60        .collect();
61    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
62
63    // Deduplicate. The inverse is built in lockstep: for each original
64    // position `orig_idx`, `inverse[orig_idx]` is the u64 index into
65    // `unique_vals` where that original's value ended up. We walk the
66    // sorted pairs, advancing the unique position whenever we see a new
67    // value, and use the sorted pair's `original_idx` to write into the
68    // right inverse slot.
69    let mut unique_vals = Vec::new();
70    let mut unique_indices: Vec<u64> = Vec::new();
71    let mut unique_counts: Vec<u64> = Vec::new();
72    let mut inverse_vec: Vec<u64> = if return_inverse {
73        vec![0u64; n_data]
74    } else {
75        Vec::new()
76    };
77
78    if !pairs.is_empty() {
79        unique_vals.push(pairs[0].0);
80        unique_indices.push(pairs[0].1 as u64);
81        if return_inverse {
82            inverse_vec[pairs[0].1] = 0;
83        }
84        let mut count = 1u64;
85        let mut unique_pos: u64 = 0;
86
87        for i in 1..pairs.len() {
88            if pairs[i].0.partial_cmp(&pairs[i - 1].0) == Some(std::cmp::Ordering::Equal) {
89                count += 1;
90                // Keep the first occurrence index (smallest original index).
91                let last = unique_indices.len() - 1;
92                let new_idx = pairs[i].1 as u64;
93                if new_idx < unique_indices[last] {
94                    unique_indices[last] = new_idx;
95                }
96            } else {
97                if return_counts {
98                    unique_counts.push(count);
99                }
100                unique_vals.push(pairs[i].0);
101                unique_indices.push(pairs[i].1 as u64);
102                count = 1;
103                unique_pos += 1;
104            }
105            if return_inverse {
106                inverse_vec[pairs[i].1] = unique_pos;
107            }
108        }
109        if return_counts {
110            unique_counts.push(count);
111        }
112    }
113
114    let n = unique_vals.len();
115    let values = Array::from_vec(Ix1::new([n]), unique_vals)?;
116    let indices = if return_index {
117        Some(Array::from_vec(Ix1::new([n]), unique_indices)?)
118    } else {
119        None
120    };
121    let inverse = if return_inverse {
122        Some(Array::from_vec(Ix1::new([n_data]), inverse_vec)?)
123    } else {
124        None
125    };
126    let counts = if return_counts {
127        Some(Array::from_vec(Ix1::new([n]), unique_counts)?)
128    } else {
129        None
130    };
131
132    Ok(UniqueResult {
133        values,
134        indices,
135        inverse,
136        counts,
137    })
138}
139
140/// Find unique hyperslices along an axis.
141///
142/// Equivalent to `numpy.unique(a, axis=axis)`: returns an array of the
143/// same dimensionality as `a` with the `axis` dimension reduced to the
144/// number of unique slices. Slices are compared element-wise (lex order)
145/// and returned in sorted order, matching numpy's behavior (#464).
146///
147/// For axis-less unique-on-flattened, see [`unique`].
148///
149/// # Errors
150/// - `FerrayError::AxisOutOfBounds` if `axis >= a.ndim()`.
151pub fn unique_axis<T, D>(a: &Array<T, D>, axis: usize) -> FerrayResult<Array<T, IxDyn>>
152where
153    T: Element + PartialOrd + Copy,
154    D: Dimension,
155{
156    let shape = a.shape().to_vec();
157    let ndim = shape.len();
158    if axis >= ndim {
159        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
160    }
161    let n = shape[axis];
162    let inner_stride: usize = shape[axis + 1..].iter().product();
163    let outer_size: usize = shape[..axis].iter().product();
164    let block = n * inner_stride;
165    let slice_len = outer_size * inner_stride;
166
167    let data: Vec<T> = a.iter().copied().collect();
168
169    // Empty axis: return the input unchanged.
170    if n == 0 {
171        return Array::<T, IxDyn>::from_vec(IxDyn::new(&shape), data);
172    }
173
174    // Gather each axis-slice in canonical (outer, inner) order.
175    let mut slices: Vec<Vec<T>> = Vec::with_capacity(n);
176    for i in 0..n {
177        let mut s = Vec::with_capacity(slice_len);
178        for o in 0..outer_size {
179            let base = o * block + i * inner_stride;
180            s.extend_from_slice(&data[base..base + inner_stride]);
181        }
182        slices.push(s);
183    }
184
185    // Sort axis indices by lex comparison of their slices. Use
186    // partial_cmp so floating-point T (NaN-tolerating) works the same
187    // way it does in `unique`.
188    let mut order: Vec<usize> = (0..n).collect();
189    order.sort_by(|&i, &j| {
190        slices[i]
191            .iter()
192            .zip(slices[j].iter())
193            .find_map(|(a, b)| match a.partial_cmp(b) {
194                None | Some(std::cmp::Ordering::Equal) => None,
195                Some(c) => Some(c),
196            })
197            .unwrap_or(std::cmp::Ordering::Equal)
198    });
199
200    // Dedupe consecutive equal slices.
201    let mut kept: Vec<usize> = Vec::with_capacity(n);
202    for &idx in &order {
203        if let Some(&prev) = kept.last() {
204            let equal = slices[idx]
205                .iter()
206                .zip(slices[prev].iter())
207                .all(|(a, b)| a.partial_cmp(b) == Some(std::cmp::Ordering::Equal));
208            if equal {
209                continue;
210            }
211        }
212        kept.push(idx);
213    }
214
215    let new_n = kept.len();
216    let mut out_shape = shape.clone();
217    out_shape[axis] = new_n;
218    let new_block = new_n * inner_stride;
219    let total: usize = out_shape.iter().product();
220    let mut out_data: Vec<T> = vec![data[0]; total];
221    for (out_i, &src_i) in kept.iter().enumerate() {
222        for o in 0..outer_size {
223            let src_base = o * block + src_i * inner_stride;
224            let dst_base = o * new_block + out_i * inner_stride;
225            out_data[dst_base..dst_base + inner_stride]
226                .copy_from_slice(&data[src_base..src_base + inner_stride]);
227        }
228    }
229    Array::<T, IxDyn>::from_vec(IxDyn::new(&out_shape), out_data)
230}
231
232// ---------------------------------------------------------------------------
233// nonzero
234// ---------------------------------------------------------------------------
235
236/// Return the indices of non-zero elements.
237///
238/// Returns a vector of 1-D arrays (u64), one per dimension. For a 1-D input,
239/// returns a single array of indices.
240///
241/// Equivalent to `numpy.nonzero`.
242pub fn nonzero<T, D>(a: &Array<T, D>) -> FerrayResult<Vec<Array<u64, Ix1>>>
243where
244    T: Element + PartialEq + Copy,
245    D: Dimension,
246{
247    let shape = a.shape();
248    let ndim = shape.len();
249    let zero = <T as Element>::zero();
250
251    // Collect all multi-indices where element != 0
252    let mut indices_per_dim: Vec<Vec<u64>> = vec![Vec::new(); ndim];
253
254    // Compute strides for index conversion
255    let mut strides = vec![1usize; ndim];
256    for i in (0..ndim.saturating_sub(1)).rev() {
257        strides[i] = strides[i + 1] * shape[i + 1];
258    }
259
260    for (flat_idx, &val) in a.iter().enumerate() {
261        if val != zero {
262            let mut rem = flat_idx;
263            for d in 0..ndim {
264                indices_per_dim[d].push((rem / strides[d]) as u64);
265                rem %= strides[d];
266            }
267        }
268    }
269
270    let mut result = Vec::with_capacity(ndim);
271    for idx_vec in indices_per_dim {
272        let n = idx_vec.len();
273        result.push(Array::from_vec(Ix1::new([n]), idx_vec)?);
274    }
275
276    Ok(result)
277}
278
279// ---------------------------------------------------------------------------
280// where_
281// ---------------------------------------------------------------------------
282
283/// Conditional element selection.
284///
285/// For each element, if the corresponding element of `condition` is non-zero,
286/// select from `x`; otherwise select from `y`.
287///
288/// All three arrays must have the same shape.
289///
290/// Equivalent to `numpy.where`.
291pub fn where_<T, D>(
292    condition: &Array<bool, D>,
293    x: &Array<T, D>,
294    y: &Array<T, D>,
295) -> FerrayResult<Array<T, D>>
296where
297    T: Element + Copy,
298    D: Dimension,
299{
300    if condition.shape() != x.shape() || condition.shape() != y.shape() {
301        return Err(FerrayError::shape_mismatch(format!(
302            "condition, x, y shapes must match: {:?}, {:?}, {:?}",
303            condition.shape(),
304            x.shape(),
305            y.shape()
306        )));
307    }
308
309    let result: Vec<T> = condition
310        .iter()
311        .zip(x.iter())
312        .zip(y.iter())
313        .map(|((&c, &xv), &yv)| if c { xv } else { yv })
314        .collect();
315
316    Array::from_vec(condition.dim().clone(), result)
317}
318
319/// Broadcast-aware ternary `where`: pick from `x` where `condition` is
320/// true, else from `y`. Each of `condition`, `x`, `y` may have its
321/// own shape; they are NumPy-broadcast against each other to produce
322/// the output shape (#468).
323///
324/// For the same-shape fast path use [`where_`] — it skips the
325/// broadcast view setup and is lower overhead.
326///
327/// # Errors
328/// `FerrayError::ShapeMismatch` if any pair of shapes is not
329/// broadcast-compatible.
330pub fn where_broadcast<T>(
331    condition: &Array<bool, IxDyn>,
332    x: &Array<T, IxDyn>,
333    y: &Array<T, IxDyn>,
334) -> FerrayResult<Array<T, IxDyn>>
335where
336    T: Element + Copy,
337{
338    let cx = broadcast_shapes(condition.shape(), x.shape())?;
339    let target = broadcast_shapes(&cx, y.shape())?;
340    let cv = condition.broadcast_to(&target)?;
341    let xv = x.broadcast_to(&target)?;
342    let yv = y.broadcast_to(&target)?;
343    let result: Vec<T> = cv
344        .iter()
345        .zip(xv.iter())
346        .zip(yv.iter())
347        .map(|((&c, &xv), &yv)| if c { xv } else { yv })
348        .collect();
349    Array::<T, IxDyn>::from_vec(IxDyn::new(&target), result)
350}
351
352/// One-argument form of `where`: return the indices where `condition`
353/// is true, as a vector of 1-D index arrays (one per dimension).
354///
355/// Equivalent to `numpy.where(condition)` (single-argument form) or
356/// `numpy.nonzero(condition.astype(int))`. Added for `NumPy` parity
357/// (#166) — the three-argument form above is [`where_`].
358pub fn where_condition<D: Dimension>(
359    condition: &Array<bool, D>,
360) -> FerrayResult<Vec<Array<u64, Ix1>>> {
361    let shape = condition.shape();
362    let ndim = shape.len();
363    let mut indices_per_dim: Vec<Vec<u64>> = vec![Vec::new(); ndim];
364
365    let mut strides = vec![1usize; ndim];
366    for i in (0..ndim.saturating_sub(1)).rev() {
367        strides[i] = strides[i + 1] * shape[i + 1];
368    }
369
370    for (flat_idx, &val) in condition.iter().enumerate() {
371        if val {
372            let mut rem = flat_idx;
373            for d in 0..ndim {
374                indices_per_dim[d].push((rem / strides[d]) as u64);
375                rem %= strides[d];
376            }
377        }
378    }
379
380    indices_per_dim
381        .into_iter()
382        .map(|v| {
383            let n = v.len();
384            Array::from_vec(Ix1::new([n]), v)
385        })
386        .collect()
387}
388
389// ---------------------------------------------------------------------------
390// count_nonzero
391// ---------------------------------------------------------------------------
392
393/// Count the number of non-zero elements along a given axis.
394///
395/// Equivalent to `numpy.count_nonzero`.
396pub fn count_nonzero<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
397where
398    T: Element + PartialEq + Copy,
399    D: Dimension,
400{
401    let zero = <T as Element>::zero();
402    let data = borrow_data(a);
403    match axis {
404        None => {
405            let count = data.iter().filter(|&&x| x != zero).count() as u64;
406            make_result(&[], vec![count])
407        }
408        Some(ax) => {
409            validate_axis(ax, a.ndim())?;
410            let shape = a.shape();
411            let out_s = output_shape(shape, ax);
412            let result = reduce_axis_general_u64(&data, shape, ax, |lane| {
413                lane.iter().filter(|&&x| x != zero).count() as u64
414            });
415            make_result(&out_s, result)
416        }
417    }
418}
419
420// ---------------------------------------------------------------------------
421// Array API standard names: unique_values / unique_counts / unique_inverse / unique_all
422// ---------------------------------------------------------------------------
423
424/// Sorted unique values of the (flattened) array.
425///
426/// Array-API-standard alias for [`unique`] with no extra return arrays.
427/// Equivalent to `numpy.unique_values(x)`.
428pub fn unique_values<T, D>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>>
429where
430    T: Element + PartialOrd + Copy,
431    D: Dimension,
432{
433    Ok(unique(a, false, false, false)?.values)
434}
435
436/// Sorted unique values and their occurrence counts.
437///
438/// Equivalent to `numpy.unique_counts(x)` — returns `(values, counts)`.
439pub fn unique_counts<T, D>(a: &Array<T, D>) -> FerrayResult<(Array<T, Ix1>, Array<u64, Ix1>)>
440where
441    T: Element + PartialOrd + Copy,
442    D: Dimension,
443{
444    let r = unique(a, false, false, true)?;
445    Ok((r.values, r.counts.expect("return_counts requested")))
446}
447
448/// Sorted unique values and the inverse-index array.
449///
450/// Equivalent to `numpy.unique_inverse(x)` — returns `(values, inverse)`
451/// where `values[inverse]` reconstructs the flattened input.
452pub fn unique_inverse<T, D>(a: &Array<T, D>) -> FerrayResult<(Array<T, Ix1>, Array<u64, Ix1>)>
453where
454    T: Element + PartialOrd + Copy,
455    D: Dimension,
456{
457    let r = unique(a, false, true, false)?;
458    Ok((r.values, r.inverse.expect("return_inverse requested")))
459}
460
461/// Sorted unique values along with first-occurrence indices, inverse, and
462/// counts.
463///
464/// Equivalent to `numpy.unique_all(x)` — returns
465/// `(values, indices, inverse, counts)`. The four-tuple is the documented
466/// Array-API contract; suppress clippy's complexity warning.
467#[allow(clippy::type_complexity)]
468pub fn unique_all<T, D>(
469    a: &Array<T, D>,
470) -> FerrayResult<(
471    Array<T, Ix1>,
472    Array<u64, Ix1>,
473    Array<u64, Ix1>,
474    Array<u64, Ix1>,
475)>
476where
477    T: Element + PartialOrd + Copy,
478    D: Dimension,
479{
480    let r = unique(a, true, true, true)?;
481    Ok((
482        r.values,
483        r.indices.expect("return_index requested"),
484        r.inverse.expect("return_inverse requested"),
485        r.counts.expect("return_counts requested"),
486    ))
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::{Ix1, Ix2};

    // ---- where_broadcast (#468) ----------------------------------------

    #[test]
    fn where_broadcast_scalar_else() {
        // condition: (3,), x: (3,), y: (1,) → pick from y for false slots.
        let c = Array::<bool, IxDyn>::from_vec(IxDyn::new(&[3]), vec![true, false, true]).unwrap();
        let x = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), vec![1, 2, 3]).unwrap();
        let y = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[1]), vec![99]).unwrap();
        let r = where_broadcast(&c, &x, &y).unwrap();
        assert_eq!(r.shape(), &[3]);
        assert_eq!(r.as_slice().unwrap(), &[1, 99, 3]);
    }

    #[test]
    fn where_broadcast_2d_outer() {
        // condition: (3, 1), x: (1, 4), y: scalar (1,1) → output (3, 4).
        let c =
            Array::<bool, IxDyn>::from_vec(IxDyn::new(&[3, 1]), vec![true, false, true]).unwrap();
        let x = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[1, 4]), vec![10, 20, 30, 40]).unwrap();
        let y = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[1, 1]), vec![-1]).unwrap();
        let r = where_broadcast(&c, &x, &y).unwrap();
        assert_eq!(r.shape(), &[3, 4]);
        let s = r.as_slice().unwrap();
        // Row 0: condition=true → row from x.
        assert_eq!(&s[0..4], &[10, 20, 30, 40]);
        // Row 1: condition=false → all -1.
        assert_eq!(&s[4..8], &[-1, -1, -1, -1]);
        // Row 2: condition=true → row from x.
        assert_eq!(&s[8..12], &[10, 20, 30, 40]);
    }

    #[test]
    fn where_broadcast_shape_mismatch() {
        let c = Array::<bool, IxDyn>::from_vec(IxDyn::new(&[3]), vec![true, false, true]).unwrap();
        let x = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1, 2]).unwrap();
        let y = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[1]), vec![0]).unwrap();
        assert!(where_broadcast(&c, &x, &y).is_err());
    }

    // ---- unique_axis (#464) --------------------------------------------

    #[test]
    fn unique_axis_rows_dedup() {
        // axis=0 on a 2D array dedupes rows.
        let a =
            Array::<i32, Ix2>::from_vec(Ix2::new([4, 3]), vec![1, 2, 3, 4, 5, 6, 1, 2, 3, 7, 8, 9])
                .unwrap();
        let u = unique_axis(&a, 0).unwrap();
        // Three unique rows: [1,2,3], [4,5,6], [7,8,9] in sorted order.
        assert_eq!(u.shape(), &[3, 3]);
        let s = u.as_slice().unwrap();
        assert_eq!(s, &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn unique_axis_columns_dedup() {
        // axis=1 dedupes columns. Construct a matrix where columns 0 and 2
        // are identical.
        let a =
            Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), vec![1, 2, 1, 3, 4, 5, 4, 6, 7, 8, 7, 9])
                .unwrap();
        let u = unique_axis(&a, 1).unwrap();
        // 3 unique columns: [1,4,7], [2,5,8], [3,6,9] sorted lex.
        assert_eq!(u.shape(), &[3, 3]);
        let s = u.as_slice().unwrap();
        assert_eq!(s, &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn unique_axis_all_distinct_keeps_count_but_sorts() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([3, 2]), vec![3, 4, 1, 2, 5, 6]).unwrap();
        let u = unique_axis(&a, 0).unwrap();
        assert_eq!(u.shape(), &[3, 2]);
        let s = u.as_slice().unwrap();
        assert_eq!(s, &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn unique_axis_all_same_collapses() {
        let a =
            Array::<i32, Ix2>::from_vec(Ix2::new([4, 2]), vec![1, 2, 1, 2, 1, 2, 1, 2]).unwrap();
        let u = unique_axis(&a, 0).unwrap();
        assert_eq!(u.shape(), &[1, 2]);
        let s = u.as_slice().unwrap();
        assert_eq!(s, &[1, 2]);
    }

    #[test]
    fn unique_axis_3d_axis0() {
        // Shape (4, 2, 2) with hyperslices 0 and 2 identical along axis 0.
        let a = Array::<i32, IxDyn>::from_vec(
            IxDyn::new(&[4, 2, 2]),
            vec![
                1, 2, 3, 4, // slice 0
                5, 6, 7, 8, // slice 1
                1, 2, 3, 4, // slice 2 (== slice 0)
                9, 0, 1, 2, // slice 3
            ],
        )
        .unwrap();
        let u = unique_axis(&a, 0).unwrap();
        assert_eq!(u.shape(), &[3, 2, 2]);
        // Sorted lex: [1,2,3,4] < [5,6,7,8] < [9,0,1,2].
        let s = u.as_slice().unwrap();
        assert_eq!(s, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2]);
    }

    #[test]
    fn unique_axis_3d_axis1() {
        // Shape (2, 3, 2); along axis 1, slices 0 and 2 are identical in
        // every outer block, so the middle axis shrinks from 3 to 2.
        let a = Array::<i32, IxDyn>::from_vec(
            IxDyn::new(&[2, 3, 2]),
            vec![
                1, 2, 3, 4, 1, 2, // outer 0
                5, 6, 7, 8, 5, 6, // outer 1
            ],
        )
        .unwrap();
        let u = unique_axis(&a, 1).unwrap();
        assert_eq!(u.shape(), &[2, 2, 2]);
        let s = u.as_slice().unwrap();
        assert_eq!(s, &[1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn unique_axis_out_of_bounds() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        assert!(unique_axis(&a, 5).is_err());
    }

    #[test]
    fn test_unique_basic() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let u = unique(&a, false, false, false).unwrap();
        let data: Vec<i32> = u.values.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3]);
    }

    #[test]
    fn test_unique_with_counts() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let u = unique(&a, false, false, true).unwrap();
        let vals: Vec<i32> = u.values.iter().copied().collect();
        let cnts: Vec<u64> = u.counts.unwrap().iter().copied().collect();
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(cnts, vec![2, 2, 2]);
    }

    #[test]
    fn test_unique_with_index() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![5, 3, 3, 1, 5]).unwrap();
        let u = unique(&a, true, false, false).unwrap();
        let vals: Vec<i32> = u.values.iter().copied().collect();
        let idxs: Vec<u64> = u.indices.unwrap().iter().copied().collect();
        assert_eq!(vals, vec![1, 3, 5]);
        assert_eq!(idxs, vec![3, 1, 0]);
    }

    // ---- return_inverse (#463) ----

    #[test]
    fn test_unique_inverse_reconstructs_input() {
        // The canonical label-encoding use case.
        let input = vec![3, 1, 2, 1, 3, 2];
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), input.clone()).unwrap();
        let u = unique(&a, false, true, false).unwrap();
        let vals: Vec<i32> = u.values.iter().copied().collect();
        let inv: Vec<u64> = u.inverse.unwrap().iter().copied().collect();
        // Unique values must be sorted.
        assert_eq!(vals, vec![1, 2, 3]);
        // values[inverse] must reconstruct the flattened input.
        let reconstructed: Vec<i32> = inv.iter().map(|&i| vals[i as usize]).collect();
        assert_eq!(reconstructed, input);
    }

    #[test]
    fn test_unique_inverse_all_together() {
        // Request indices + inverse + counts in one call; each field must
        // independently match what a single-flag call would produce.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([7]), vec![2, 1, 2, 3, 1, 2, 3]).unwrap();
        let u = unique(&a, true, true, true).unwrap();
        let vals: Vec<i32> = u.values.iter().copied().collect();
        let idxs: Vec<u64> = u.indices.unwrap().iter().copied().collect();
        let inv: Vec<u64> = u.inverse.unwrap().iter().copied().collect();
        let cnts: Vec<u64> = u.counts.unwrap().iter().copied().collect();
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(idxs, vec![1, 0, 3]); // first positions of 1, 2, 3
        assert_eq!(cnts, vec![2, 3, 2]);
        // Reconstruct via inverse.
        let reconstructed: Vec<i32> = inv.iter().map(|&i| vals[i as usize]).collect();
        assert_eq!(reconstructed, vec![2, 1, 2, 3, 1, 2, 3]);
    }

    #[test]
    fn test_unique_inverse_with_2d_flattens_first() {
        // NumPy's unique flattens the input; inverse has length
        // shape.iter().product(), indexing into the flat logical traversal.
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 1, 3, 2, 1]).unwrap();
        let u = unique(&a, false, true, false).unwrap();
        let vals: Vec<i32> = u.values.iter().copied().collect();
        let inv: Vec<u64> = u.inverse.unwrap().iter().copied().collect();
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(inv.len(), 6);
        let flat: Vec<i32> = vec![1, 2, 1, 3, 2, 1];
        let reconstructed: Vec<i32> = inv.iter().map(|&i| vals[i as usize]).collect();
        assert_eq!(reconstructed, flat);
    }

    #[test]
    fn test_unique_inverse_empty_input() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let u = unique(&a, false, true, false).unwrap();
        assert_eq!(u.values.shape(), &[0]);
        let inv = u.inverse.unwrap();
        assert_eq!(inv.shape(), &[0]);
    }

    #[test]
    fn test_unique_inverse_single_value() {
        // Every element identical → all inverse entries point at position 0.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![7, 7, 7, 7]).unwrap();
        let u = unique(&a, false, true, false).unwrap();
        let inv: Vec<u64> = u.inverse.unwrap().iter().copied().collect();
        assert_eq!(inv, vec![0, 0, 0, 0]);
    }

    #[test]
    fn test_unique_without_inverse_leaves_field_none() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 1]).unwrap();
        let u = unique(&a, false, false, false).unwrap();
        assert!(u.inverse.is_none());
    }

    #[test]
    fn test_nonzero_1d() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let nz = nonzero(&a).unwrap();
        assert_eq!(nz.len(), 1);
        let data: Vec<u64> = nz[0].iter().copied().collect();
        assert_eq!(data, vec![1, 3]);
    }

    #[test]
    fn test_nonzero_2d() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 3, 0, 5]).unwrap();
        let nz = nonzero(&a).unwrap();
        assert_eq!(nz.len(), 2);
        let rows: Vec<u64> = nz[0].iter().copied().collect();
        let cols: Vec<u64> = nz[1].iter().copied().collect();
        assert_eq!(rows, vec![0, 1, 1]);
        assert_eq!(cols, vec![1, 0, 2]);
    }

    #[test]
    fn test_where_basic() {
        let cond =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let r = where_(&cond, &x, &y).unwrap();
        let data: Vec<f64> = r.iter().copied().collect();
        assert_eq!(data, vec![1.0, 20.0, 3.0, 40.0]);
    }

    #[test]
    fn test_where_shape_mismatch() {
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        let x = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        assert!(where_(&cond, &x, &y).is_err());
    }

    #[test]
    fn test_count_nonzero_total() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let c = count_nonzero(&a, None).unwrap();
        assert_eq!(c.iter().next(), Some(&2u64));
    }

    #[test]
    fn test_count_nonzero_axis() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 3, 0, 5]).unwrap();
        let c = count_nonzero(&a, Some(0)).unwrap();
        let data: Vec<u64> = c.iter().copied().collect();
        assert_eq!(data, vec![1, 1, 1]);
    }

    #[test]
    fn test_count_nonzero_axis1() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 3, 0, 5]).unwrap();
        let c = count_nonzero(&a, Some(1)).unwrap();
        let data: Vec<u64> = c.iter().copied().collect();
        assert_eq!(data, vec![1, 2]);
    }

    // -- Array API unique_* --

    #[test]
    fn test_unique_values_alias() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let v = unique_values(&a).unwrap();
        assert_eq!(v.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
    }

    #[test]
    fn test_unique_counts_alias() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let (v, c) = unique_counts(&a).unwrap();
        assert_eq!(v.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
        assert_eq!(c.iter().copied().collect::<Vec<_>>(), vec![2, 2, 2]);
    }

    #[test]
    fn test_unique_inverse_alias() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let (v, inv) = unique_inverse(&a).unwrap();
        // values = [1, 2, 3]; inverse maps each original to that index
        // 3→2, 1→0, 2→1, 1→0, 3→2, 2→1
        assert_eq!(v.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
        assert_eq!(
            inv.iter().copied().collect::<Vec<_>>(),
            vec![2, 0, 1, 0, 2, 1]
        );
    }

    #[test]
    fn test_unique_all_alias() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let (v, idx, inv, c) = unique_all(&a).unwrap();
        assert_eq!(v.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
        // First occurrences: 1 at position 1, 2 at position 2, 3 at position 0.
        assert_eq!(idx.iter().copied().collect::<Vec<_>>(), vec![1, 2, 0]);
        assert_eq!(
            inv.iter().copied().collect::<Vec<_>>(),
            vec![2, 0, 1, 0, 2, 1]
        );
        assert_eq!(c.iter().copied().collect::<Vec<_>>(), vec![2, 2, 2]);
    }
}