Skip to main content

ferray_stats/
searching.rs

1// ferray-stats: Searching — unique, nonzero, where_, count_nonzero (REQ-14, REQ-15, REQ-16, REQ-17)
2
3use ferray_core::error::{FerrayError, FerrayResult};
4use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
5
6use crate::reductions::{
7    borrow_data, make_result, output_shape, reduce_axis_general_u64, validate_axis,
8};
9
10// ---------------------------------------------------------------------------
11// unique
12// ---------------------------------------------------------------------------
13
14/// Result from the `unique` function.
15#[derive(Debug)]
16pub struct UniqueResult<T: Element> {
17    /// The sorted unique values.
18    pub values: Array<T, Ix1>,
19    /// If requested, the indices of the first occurrence of each unique value
20    /// in the original array (as u64).
21    pub indices: Option<Array<u64, Ix1>>,
22    /// If requested, the inverse index array such that
23    /// `values[inverse]` reconstructs the flattened input. Essential for
24    /// label encoding (#463).
25    pub inverse: Option<Array<u64, Ix1>>,
26    /// If requested, the count of each unique value (as u64).
27    pub counts: Option<Array<u64, Ix1>>,
28}
29
30/// Find the sorted unique elements of an array.
31///
32/// The input is flattened. Optionally returns:
33/// - `return_index`: indices of the first occurrence of each unique value.
34/// - `return_inverse`: an array of the same length as the flattened input,
35///   where each entry is the index into `values` of the corresponding
36///   original element. Satisfies `values[inverse] == flat_input`.
37/// - `return_counts`: count of each unique value.
38///
39/// Equivalent to `numpy.unique`.
40pub fn unique<T, D>(
41    a: &Array<T, D>,
42    return_index: bool,
43    return_inverse: bool,
44    return_counts: bool,
45) -> FerrayResult<UniqueResult<T>>
46where
47    T: Element + PartialOrd + Copy,
48    D: Dimension,
49{
50    let data: Vec<T> = a.iter().copied().collect();
51    let n_data = data.len();
52
53    // Create (value, original_index) pairs, then sort by value.
54    let mut pairs: Vec<(T, usize)> = data
55        .iter()
56        .copied()
57        .enumerate()
58        .map(|(i, v)| (v, i))
59        .collect();
60    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
61
62    // Deduplicate. The inverse is built in lockstep: for each original
63    // position `orig_idx`, `inverse[orig_idx]` is the u64 index into
64    // `unique_vals` where that original's value ended up. We walk the
65    // sorted pairs, advancing the unique position whenever we see a new
66    // value, and use the sorted pair's `original_idx` to write into the
67    // right inverse slot.
68    let mut unique_vals = Vec::new();
69    let mut unique_indices: Vec<u64> = Vec::new();
70    let mut unique_counts: Vec<u64> = Vec::new();
71    let mut inverse_vec: Vec<u64> = if return_inverse {
72        vec![0u64; n_data]
73    } else {
74        Vec::new()
75    };
76
77    if !pairs.is_empty() {
78        unique_vals.push(pairs[0].0);
79        unique_indices.push(pairs[0].1 as u64);
80        if return_inverse {
81            inverse_vec[pairs[0].1] = 0;
82        }
83        let mut count = 1u64;
84        let mut unique_pos: u64 = 0;
85
86        for i in 1..pairs.len() {
87            if pairs[i].0.partial_cmp(&pairs[i - 1].0) != Some(std::cmp::Ordering::Equal) {
88                if return_counts {
89                    unique_counts.push(count);
90                }
91                unique_vals.push(pairs[i].0);
92                unique_indices.push(pairs[i].1 as u64);
93                count = 1;
94                unique_pos += 1;
95            } else {
96                count += 1;
97                // Keep the first occurrence index (smallest original index).
98                let last = unique_indices.len() - 1;
99                let new_idx = pairs[i].1 as u64;
100                if new_idx < unique_indices[last] {
101                    unique_indices[last] = new_idx;
102                }
103            }
104            if return_inverse {
105                inverse_vec[pairs[i].1] = unique_pos;
106            }
107        }
108        if return_counts {
109            unique_counts.push(count);
110        }
111    }
112
113    let n = unique_vals.len();
114    let values = Array::from_vec(Ix1::new([n]), unique_vals)?;
115    let indices = if return_index {
116        Some(Array::from_vec(Ix1::new([n]), unique_indices)?)
117    } else {
118        None
119    };
120    let inverse = if return_inverse {
121        Some(Array::from_vec(Ix1::new([n_data]), inverse_vec)?)
122    } else {
123        None
124    };
125    let counts = if return_counts {
126        Some(Array::from_vec(Ix1::new([n]), unique_counts)?)
127    } else {
128        None
129    };
130
131    Ok(UniqueResult {
132        values,
133        indices,
134        inverse,
135        counts,
136    })
137}
138
139// ---------------------------------------------------------------------------
140// nonzero
141// ---------------------------------------------------------------------------
142
143/// Return the indices of non-zero elements.
144///
145/// Returns a vector of 1-D arrays (u64), one per dimension. For a 1-D input,
146/// returns a single array of indices.
147///
148/// Equivalent to `numpy.nonzero`.
149pub fn nonzero<T, D>(a: &Array<T, D>) -> FerrayResult<Vec<Array<u64, Ix1>>>
150where
151    T: Element + PartialEq + Copy,
152    D: Dimension,
153{
154    let shape = a.shape();
155    let ndim = shape.len();
156    let zero = <T as Element>::zero();
157
158    // Collect all multi-indices where element != 0
159    let mut indices_per_dim: Vec<Vec<u64>> = vec![Vec::new(); ndim];
160
161    // Compute strides for index conversion
162    let mut strides = vec![1usize; ndim];
163    for i in (0..ndim.saturating_sub(1)).rev() {
164        strides[i] = strides[i + 1] * shape[i + 1];
165    }
166
167    for (flat_idx, &val) in a.iter().enumerate() {
168        if val != zero {
169            let mut rem = flat_idx;
170            for d in 0..ndim {
171                indices_per_dim[d].push((rem / strides[d]) as u64);
172                rem %= strides[d];
173            }
174        }
175    }
176
177    let mut result = Vec::with_capacity(ndim);
178    for idx_vec in indices_per_dim {
179        let n = idx_vec.len();
180        result.push(Array::from_vec(Ix1::new([n]), idx_vec)?);
181    }
182
183    Ok(result)
184}
185
186// ---------------------------------------------------------------------------
187// where_
188// ---------------------------------------------------------------------------
189
190/// Conditional element selection.
191///
192/// For each element, if the corresponding element of `condition` is non-zero,
193/// select from `x`; otherwise select from `y`.
194///
195/// All three arrays must have the same shape.
196///
197/// Equivalent to `numpy.where`.
198pub fn where_<T, D>(
199    condition: &Array<bool, D>,
200    x: &Array<T, D>,
201    y: &Array<T, D>,
202) -> FerrayResult<Array<T, D>>
203where
204    T: Element + Copy,
205    D: Dimension,
206{
207    if condition.shape() != x.shape() || condition.shape() != y.shape() {
208        return Err(FerrayError::shape_mismatch(format!(
209            "condition, x, y shapes must match: {:?}, {:?}, {:?}",
210            condition.shape(),
211            x.shape(),
212            y.shape()
213        )));
214    }
215
216    let result: Vec<T> = condition
217        .iter()
218        .zip(x.iter())
219        .zip(y.iter())
220        .map(|((&c, &xv), &yv)| if c { xv } else { yv })
221        .collect();
222
223    Array::from_vec(condition.dim().clone(), result)
224}
225
226/// One-argument form of `where`: return the indices where `condition`
227/// is true, as a vector of 1-D index arrays (one per dimension).
228///
229/// Equivalent to `numpy.where(condition)` (single-argument form) or
230/// `numpy.nonzero(condition.astype(int))`. Added for NumPy parity
231/// (#166) — the three-argument form above is [`where_`].
232pub fn where_condition<D: Dimension>(
233    condition: &Array<bool, D>,
234) -> FerrayResult<Vec<Array<u64, Ix1>>> {
235    let shape = condition.shape();
236    let ndim = shape.len();
237    let mut indices_per_dim: Vec<Vec<u64>> = vec![Vec::new(); ndim];
238
239    let mut strides = vec![1usize; ndim];
240    for i in (0..ndim.saturating_sub(1)).rev() {
241        strides[i] = strides[i + 1] * shape[i + 1];
242    }
243
244    for (flat_idx, &val) in condition.iter().enumerate() {
245        if val {
246            let mut rem = flat_idx;
247            for d in 0..ndim {
248                indices_per_dim[d].push((rem / strides[d]) as u64);
249                rem %= strides[d];
250            }
251        }
252    }
253
254    indices_per_dim
255        .into_iter()
256        .map(|v| {
257            let n = v.len();
258            Array::from_vec(Ix1::new([n]), v)
259        })
260        .collect()
261}
262
263// ---------------------------------------------------------------------------
264// count_nonzero
265// ---------------------------------------------------------------------------
266
267/// Count the number of non-zero elements along a given axis.
268///
269/// Equivalent to `numpy.count_nonzero`.
270pub fn count_nonzero<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
271where
272    T: Element + PartialEq + Copy,
273    D: Dimension,
274{
275    let zero = <T as Element>::zero();
276    let data = borrow_data(a);
277    match axis {
278        None => {
279            let count = data.iter().filter(|&&x| x != zero).count() as u64;
280            make_result(&[], vec![count])
281        }
282        Some(ax) => {
283            validate_axis(ax, a.ndim())?;
284            let shape = a.shape();
285            let out_s = output_shape(shape, ax);
286            let result = reduce_axis_general_u64(&data, shape, ax, |lane| {
287                lane.iter().filter(|&&x| x != zero).count() as u64
288            });
289            make_result(&out_s, result)
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::{Ix1, Ix2};

    // ---- fixtures ----

    /// Build a 1-D `i32` array whose length matches `values`.
    fn ints_1d(values: Vec<i32>) -> Array<i32, Ix1> {
        let len = values.len();
        Array::<i32, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Build a 2x3 `i32` array from six values.
    fn ints_2x3(values: Vec<i32>) -> Array<i32, Ix2> {
        Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), values).unwrap()
    }

    /// Build a 1-D `f64` array whose length matches `values`.
    fn floats_1d(values: Vec<f64>) -> Array<f64, Ix1> {
        let len = values.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Build a 1-D `bool` array whose length matches `values`.
    fn bools_1d(values: Vec<bool>) -> Array<bool, Ix1> {
        let len = values.len();
        Array::<bool, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Copy an array's elements, in iteration order, into a `Vec`.
    fn flat<T, D>(arr: &Array<T, D>) -> Vec<T>
    where
        T: Element + Copy,
        D: Dimension,
    {
        arr.iter().copied().collect()
    }

    /// Apply an inverse-index vector to the unique values.
    fn gather(vals: &[i32], inverse: &[u64]) -> Vec<i32> {
        inverse.iter().map(|&i| vals[i as usize]).collect()
    }

    // ---- unique ----

    #[test]
    fn test_unique_basic() {
        let a = ints_1d(vec![3, 1, 2, 1, 3, 2]);
        let u = unique(&a, false, false, false).unwrap();
        assert_eq!(flat(&u.values), vec![1, 2, 3]);
    }

    #[test]
    fn test_unique_with_counts() {
        let a = ints_1d(vec![3, 1, 2, 1, 3, 2]);
        let u = unique(&a, false, false, true).unwrap();
        let vals = flat(&u.values);
        let cnts = flat(&u.counts.unwrap());
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(cnts, vec![2, 2, 2]);
    }

    #[test]
    fn test_unique_with_index() {
        let a = ints_1d(vec![5, 3, 3, 1, 5]);
        let u = unique(&a, true, false, false).unwrap();
        let vals = flat(&u.values);
        let idxs = flat(&u.indices.unwrap());
        assert_eq!(vals, vec![1, 3, 5]);
        assert_eq!(idxs, vec![3, 1, 0]);
    }

    // ---- return_inverse (#463) ----

    #[test]
    fn test_unique_inverse_reconstructs_input() {
        // Label encoding: the inverse must map back onto the original data.
        let input = vec![3, 1, 2, 1, 3, 2];
        let a = ints_1d(input.clone());
        let u = unique(&a, false, true, false).unwrap();
        let vals = flat(&u.values);
        let inv = flat(&u.inverse.unwrap());
        // Sorted unique values first ...
        assert_eq!(vals, vec![1, 2, 3]);
        // ... then values[inverse] has to equal the flattened input.
        assert_eq!(gather(&vals, &inv), input);
    }

    #[test]
    fn test_unique_inverse_all_together() {
        // All three optional outputs in one call; each must agree with what
        // the corresponding single-flag call would give.
        let a = ints_1d(vec![2, 1, 2, 3, 1, 2, 3]);
        let u = unique(&a, true, true, true).unwrap();
        let vals = flat(&u.values);
        let idxs = flat(&u.indices.unwrap());
        let inv = flat(&u.inverse.unwrap());
        let cnts = flat(&u.counts.unwrap());
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(idxs, vec![1, 0, 3]); // first positions of 1, 2, 3
        assert_eq!(cnts, vec![2, 3, 2]);
        // Round-trip through the inverse.
        assert_eq!(gather(&vals, &inv), vec![2, 1, 2, 3, 1, 2, 3]);
    }

    #[test]
    fn test_unique_inverse_with_2d_flattens_first() {
        // The input is flattened, so the inverse has one entry per element
        // (shape product) and indexes the flat logical traversal.
        let a = ints_2x3(vec![1, 2, 1, 3, 2, 1]);
        let u = unique(&a, false, true, false).unwrap();
        let vals = flat(&u.values);
        let inv = flat(&u.inverse.unwrap());
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(inv.len(), 6);
        let expected_flat: Vec<i32> = vec![1, 2, 1, 3, 2, 1];
        assert_eq!(gather(&vals, &inv), expected_flat);
    }

    #[test]
    fn test_unique_inverse_empty_input() {
        let a = ints_1d(vec![]);
        let u = unique(&a, false, true, false).unwrap();
        assert_eq!(u.values.shape(), &[0]);
        let inv = u.inverse.unwrap();
        assert_eq!(inv.shape(), &[0]);
    }

    #[test]
    fn test_unique_inverse_single_value() {
        // A constant input has one unique value, so every inverse entry is 0.
        let a = ints_1d(vec![7, 7, 7, 7]);
        let u = unique(&a, false, true, false).unwrap();
        assert_eq!(flat(&u.inverse.unwrap()), vec![0, 0, 0, 0]);
    }

    #[test]
    fn test_unique_without_inverse_leaves_field_none() {
        let a = ints_1d(vec![1, 2, 1]);
        let u = unique(&a, false, false, false).unwrap();
        assert!(u.inverse.is_none());
    }

    // ---- nonzero ----

    #[test]
    fn test_nonzero_1d() {
        let a = ints_1d(vec![0, 1, 0, 3, 0]);
        let nz = nonzero(&a).unwrap();
        assert_eq!(nz.len(), 1);
        assert_eq!(flat(&nz[0]), vec![1, 3]);
    }

    #[test]
    fn test_nonzero_2d() {
        let a = ints_2x3(vec![0, 1, 0, 3, 0, 5]);
        let nz = nonzero(&a).unwrap();
        assert_eq!(nz.len(), 2);
        let rows = flat(&nz[0]);
        let cols = flat(&nz[1]);
        assert_eq!(rows, vec![0, 1, 1]);
        assert_eq!(cols, vec![1, 0, 2]);
    }

    // ---- where_ ----

    #[test]
    fn test_where_basic() {
        let cond = bools_1d(vec![true, false, true, false]);
        let x = floats_1d(vec![1.0, 2.0, 3.0, 4.0]);
        let y = floats_1d(vec![10.0, 20.0, 30.0, 40.0]);
        let r = where_(&cond, &x, &y).unwrap();
        assert_eq!(flat(&r), vec![1.0, 20.0, 3.0, 40.0]);
    }

    #[test]
    fn test_where_shape_mismatch() {
        let cond = bools_1d(vec![true, false, true]);
        let x = floats_1d(vec![1.0, 2.0, 3.0, 4.0]);
        let y = floats_1d(vec![10.0, 20.0, 30.0, 40.0]);
        assert!(where_(&cond, &x, &y).is_err());
    }

    // ---- count_nonzero ----

    #[test]
    fn test_count_nonzero_total() {
        let a = ints_1d(vec![0, 1, 0, 3, 0]);
        let c = count_nonzero(&a, None).unwrap();
        assert_eq!(c.iter().next(), Some(&2u64));
    }

    #[test]
    fn test_count_nonzero_axis() {
        let a = ints_2x3(vec![0, 1, 0, 3, 0, 5]);
        let c = count_nonzero(&a, Some(0)).unwrap();
        assert_eq!(flat(&c), vec![1, 1, 1]);
    }

    #[test]
    fn test_count_nonzero_axis1() {
        let a = ints_2x3(vec![0, 1, 0, 3, 0, 5]);
        let c = count_nonzero(&a, Some(1)).unwrap();
        assert_eq!(flat(&c), vec![1, 2]);
    }
}