Skip to main content

ferray_stats/
searching.rs

1// ferray-stats: Searching — unique, nonzero, where_, count_nonzero (REQ-14, REQ-15, REQ-16, REQ-17)
2
3use ferray_core::error::{FerrayError, FerrayResult};
4use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
5
6use crate::reductions::{
7    borrow_data, make_result, output_shape, reduce_axis_general_u64, validate_axis,
8};
9
10// ---------------------------------------------------------------------------
11// unique
12// ---------------------------------------------------------------------------
13
14/// Result from the `unique` function.
15#[derive(Debug)]
16pub struct UniqueResult<T: Element> {
17    /// The sorted unique values.
18    pub values: Array<T, Ix1>,
19    /// If requested, the indices of the first occurrence of each unique value
20    /// in the original array (as u64).
21    pub indices: Option<Array<u64, Ix1>>,
22    /// If requested, the count of each unique value (as u64).
23    pub counts: Option<Array<u64, Ix1>>,
24}
25
26/// Find the sorted unique elements of an array.
27///
28/// The input is flattened. Optionally returns indices and/or counts.
29///
30/// Equivalent to `numpy.unique`.
31pub fn unique<T, D>(
32    a: &Array<T, D>,
33    return_index: bool,
34    return_counts: bool,
35) -> FerrayResult<UniqueResult<T>>
36where
37    T: Element + PartialOrd + Copy,
38    D: Dimension,
39{
40    let data: Vec<T> = a.iter().copied().collect();
41
42    // Create (value, original_index) pairs, then sort by value
43    let mut pairs: Vec<(T, usize)> = data
44        .iter()
45        .copied()
46        .enumerate()
47        .map(|(i, v)| (v, i))
48        .collect();
49    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
50
51    // Deduplicate
52    let mut unique_vals = Vec::new();
53    let mut unique_indices: Vec<u64> = Vec::new();
54    let mut unique_counts: Vec<u64> = Vec::new();
55
56    if !pairs.is_empty() {
57        unique_vals.push(pairs[0].0);
58        unique_indices.push(pairs[0].1 as u64);
59        let mut count = 1u64;
60
61        for i in 1..pairs.len() {
62            if pairs[i].0.partial_cmp(&pairs[i - 1].0) != Some(std::cmp::Ordering::Equal) {
63                if return_counts {
64                    unique_counts.push(count);
65                }
66                unique_vals.push(pairs[i].0);
67                unique_indices.push(pairs[i].1 as u64);
68                count = 1;
69            } else {
70                count += 1;
71                // Keep the first occurrence index (smallest original index)
72                let last = unique_indices.len() - 1;
73                let new_idx = pairs[i].1 as u64;
74                if new_idx < unique_indices[last] {
75                    unique_indices[last] = new_idx;
76                }
77            }
78        }
79        if return_counts {
80            unique_counts.push(count);
81        }
82    }
83
84    let n = unique_vals.len();
85    let values = Array::from_vec(Ix1::new([n]), unique_vals)?;
86    let indices = if return_index {
87        Some(Array::from_vec(Ix1::new([n]), unique_indices)?)
88    } else {
89        None
90    };
91    let counts = if return_counts {
92        Some(Array::from_vec(Ix1::new([n]), unique_counts)?)
93    } else {
94        None
95    };
96
97    Ok(UniqueResult {
98        values,
99        indices,
100        counts,
101    })
102}
103
104// ---------------------------------------------------------------------------
105// nonzero
106// ---------------------------------------------------------------------------
107
108/// Return the indices of non-zero elements.
109///
110/// Returns a vector of 1-D arrays (u64), one per dimension. For a 1-D input,
111/// returns a single array of indices.
112///
113/// Equivalent to `numpy.nonzero`.
114pub fn nonzero<T, D>(a: &Array<T, D>) -> FerrayResult<Vec<Array<u64, Ix1>>>
115where
116    T: Element + PartialEq + Copy,
117    D: Dimension,
118{
119    let shape = a.shape();
120    let ndim = shape.len();
121    let zero = <T as Element>::zero();
122
123    // Collect all multi-indices where element != 0
124    let mut indices_per_dim: Vec<Vec<u64>> = vec![Vec::new(); ndim];
125
126    // Compute strides for index conversion
127    let mut strides = vec![1usize; ndim];
128    for i in (0..ndim.saturating_sub(1)).rev() {
129        strides[i] = strides[i + 1] * shape[i + 1];
130    }
131
132    for (flat_idx, &val) in a.iter().enumerate() {
133        if val != zero {
134            let mut rem = flat_idx;
135            for d in 0..ndim {
136                indices_per_dim[d].push((rem / strides[d]) as u64);
137                rem %= strides[d];
138            }
139        }
140    }
141
142    let mut result = Vec::with_capacity(ndim);
143    for idx_vec in indices_per_dim {
144        let n = idx_vec.len();
145        result.push(Array::from_vec(Ix1::new([n]), idx_vec)?);
146    }
147
148    Ok(result)
149}
150
151// ---------------------------------------------------------------------------
152// where_
153// ---------------------------------------------------------------------------
154
155/// Conditional element selection.
156///
157/// For each element, if the corresponding element of `condition` is non-zero,
158/// select from `x`; otherwise select from `y`.
159///
160/// All three arrays must have the same shape.
161///
162/// Equivalent to `numpy.where`.
163pub fn where_<T, D>(
164    condition: &Array<bool, D>,
165    x: &Array<T, D>,
166    y: &Array<T, D>,
167) -> FerrayResult<Array<T, D>>
168where
169    T: Element + Copy,
170    D: Dimension,
171{
172    if condition.shape() != x.shape() || condition.shape() != y.shape() {
173        return Err(FerrayError::shape_mismatch(format!(
174            "condition, x, y shapes must match: {:?}, {:?}, {:?}",
175            condition.shape(),
176            x.shape(),
177            y.shape()
178        )));
179    }
180
181    let result: Vec<T> = condition
182        .iter()
183        .zip(x.iter())
184        .zip(y.iter())
185        .map(|((&c, &xv), &yv)| if c { xv } else { yv })
186        .collect();
187
188    Array::from_vec(condition.dim().clone(), result)
189}
190
191// ---------------------------------------------------------------------------
192// count_nonzero
193// ---------------------------------------------------------------------------
194
195/// Count the number of non-zero elements along a given axis.
196///
197/// Equivalent to `numpy.count_nonzero`.
198pub fn count_nonzero<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
199where
200    T: Element + PartialEq + Copy,
201    D: Dimension,
202{
203    let zero = <T as Element>::zero();
204    let data = borrow_data(a);
205    match axis {
206        None => {
207            let count = data.iter().filter(|&&x| x != zero).count() as u64;
208            make_result(&[], vec![count])
209        }
210        Some(ax) => {
211            validate_axis(ax, a.ndim())?;
212            let shape = a.shape();
213            let out_s = output_shape(shape, ax);
214            let result = reduce_axis_general_u64(&data, shape, ax, |lane| {
215                lane.iter().filter(|&&x| x != zero).count() as u64
216            });
217            make_result(&out_s, result)
218        }
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::{Ix1, Ix2};

    /// Build a 1-D array whose length is taken from `values`.
    fn vec1<T: Element + Copy>(values: Vec<T>) -> Array<T, Ix1> {
        let len = values.len();
        Array::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Build a `rows x cols` 2-D array from row-major `values`.
    fn mat<T: Element + Copy>(rows: usize, cols: usize, values: Vec<T>) -> Array<T, Ix2> {
        Array::from_vec(Ix2::new([rows, cols]), values).unwrap()
    }

    /// Copy the elements of an array into a `Vec`, in iteration order.
    fn collect_all<T: Element + Copy, D: Dimension>(arr: &Array<T, D>) -> Vec<T> {
        arr.iter().copied().collect()
    }

    #[test]
    fn test_unique_basic() {
        let input = vec1::<i32>(vec![3, 1, 2, 1, 3, 2]);
        let res = unique(&input, false, false).unwrap();
        assert_eq!(collect_all(&res.values), vec![1, 2, 3]);
    }

    #[test]
    fn test_unique_with_counts() {
        let input = vec1::<i32>(vec![3, 1, 2, 1, 3, 2]);
        let res = unique(&input, false, true).unwrap();
        let counts = res.counts.unwrap();
        assert_eq!(collect_all(&res.values), vec![1, 2, 3]);
        assert_eq!(collect_all(&counts), vec![2, 2, 2]);
    }

    #[test]
    fn test_unique_with_index() {
        let input = vec1::<i32>(vec![5, 3, 3, 1, 5]);
        let res = unique(&input, true, false).unwrap();
        let first_seen = res.indices.unwrap();
        assert_eq!(collect_all(&res.values), vec![1, 3, 5]);
        assert_eq!(collect_all(&first_seen), vec![3, 1, 0]);
    }

    #[test]
    fn test_nonzero_1d() {
        let input = vec1::<i32>(vec![0, 1, 0, 3, 0]);
        let hits = nonzero(&input).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(collect_all(&hits[0]), vec![1, 3]);
    }

    #[test]
    fn test_nonzero_2d() {
        let input = mat::<i32>(2, 3, vec![0, 1, 0, 3, 0, 5]);
        let hits = nonzero(&input).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(collect_all(&hits[0]), vec![0, 1, 1]);
        assert_eq!(collect_all(&hits[1]), vec![1, 0, 2]);
    }

    #[test]
    fn test_where_basic() {
        let mask = vec1(vec![true, false, true, false]);
        let on_true = vec1::<f64>(vec![1.0, 2.0, 3.0, 4.0]);
        let on_false = vec1::<f64>(vec![10.0, 20.0, 30.0, 40.0]);
        let picked = where_(&mask, &on_true, &on_false).unwrap();
        assert_eq!(collect_all(&picked), vec![1.0, 20.0, 3.0, 40.0]);
    }

    #[test]
    fn test_where_shape_mismatch() {
        let mask = vec1(vec![true, false, true]);
        let on_true = vec1::<f64>(vec![1.0, 2.0, 3.0, 4.0]);
        let on_false = vec1::<f64>(vec![10.0, 20.0, 30.0, 40.0]);
        assert!(where_(&mask, &on_true, &on_false).is_err());
    }

    #[test]
    fn test_count_nonzero_total() {
        let input = vec1::<i32>(vec![0, 1, 0, 3, 0]);
        let total = count_nonzero(&input, None).unwrap();
        let flat = collect_all(&total);
        assert_eq!(flat.first(), Some(&2u64));
    }

    #[test]
    fn test_count_nonzero_axis() {
        let input = mat::<i32>(2, 3, vec![0, 1, 0, 3, 0, 5]);
        let per_column = count_nonzero(&input, Some(0)).unwrap();
        assert_eq!(collect_all(&per_column), vec![1, 1, 1]);
    }

    #[test]
    fn test_count_nonzero_axis1() {
        let input = mat::<i32>(2, 3, vec![0, 1, 0, 3, 0, 5]);
        let per_row = count_nonzero(&input, Some(1)).unwrap();
        assert_eq!(collect_all(&per_row), vec![1, 2]);
    }
}