Skip to main content

ferray_stats/
set_ops.rs

1// ferray-stats: Set operations — union1d, intersect1d, setdiff1d, setxor1d, in1d, isin (REQ-18)
2
3use ferray_core::error::FerrayResult;
4use ferray_core::{Array, Element, Ix1};
5
6// ---------------------------------------------------------------------------
7// Helpers
8// ---------------------------------------------------------------------------
9
/// Return the distinct values of `data` in ascending order.
///
/// Values are ordered with `PartialOrd::partial_cmp`. A value that is not
/// comparable with itself (for example a floating-point NaN) is ranked after
/// every comparable value, and all such values collapse into one
/// representative. Treating "incomparable" as "equal" instead would not be a
/// total order, which leaves the result of `sort_by` unspecified (and the
/// standard library is allowed to panic in that case).
///
/// The sort is stable, so the first occurrence of each run of equal values is
/// the one that is kept.
fn sorted_unique<T: PartialOrd + Copy>(data: &[T]) -> Vec<T> {
    /// Total-order comparison that ranks self-incomparable values last.
    fn nan_last<U: PartialOrd>(x: &U, y: &U) -> std::cmp::Ordering {
        x.partial_cmp(y).unwrap_or_else(|| {
            // `true` when the value compares with itself, i.e. is not NaN-like.
            let x_ordered = x.partial_cmp(x).is_some();
            let y_ordered = y.partial_cmp(y).is_some();
            // An unordered value (`false`) must rank after an ordered one (`true`).
            y_ordered.cmp(&x_ordered)
        })
    }

    let mut values = data.to_vec();
    values.sort_by(nan_last);
    // `dedup_by` hands over (later, earlier) and drops `later` when the closure is true.
    values.dedup_by(|later, earlier| nan_last(&*later, &*earlier) == std::cmp::Ordering::Equal);
    values
}
17
18/// Collect array data into a Vec.
19fn to_vec<T: Element + Copy>(a: &Array<T, Ix1>) -> Vec<T> {
20    a.iter().copied().collect()
21}
22
23/// Make a 1-D result array.
24fn make_1d<T: Element>(data: Vec<T>) -> FerrayResult<Array<T, Ix1>> {
25    let n = data.len();
26    Array::from_vec(Ix1::new([n]), data)
27}
28
29// ---------------------------------------------------------------------------
30// union1d
31// ---------------------------------------------------------------------------
32
33/// Return the sorted union of two 1-D arrays.
34///
35/// Equivalent to `numpy.union1d`.
36pub fn union1d<T>(
37    a: &Array<T, Ix1>,
38    b: &Array<T, Ix1>,
39    assume_unique: bool,
40) -> FerrayResult<Array<T, Ix1>>
41where
42    T: Element + PartialOrd + Copy,
43{
44    let av = if assume_unique {
45        to_vec(a)
46    } else {
47        sorted_unique(&to_vec(a))
48    };
49    let bv = if assume_unique {
50        to_vec(b)
51    } else {
52        sorted_unique(&to_vec(b))
53    };
54
55    // Merge two sorted arrays
56    let mut result = Vec::with_capacity(av.len() + bv.len());
57    let (mut i, mut j) = (0, 0);
58    while i < av.len() && j < bv.len() {
59        match av[i]
60            .partial_cmp(&bv[j])
61            .unwrap_or(std::cmp::Ordering::Equal)
62        {
63            std::cmp::Ordering::Less => {
64                result.push(av[i]);
65                i += 1;
66            }
67            std::cmp::Ordering::Greater => {
68                result.push(bv[j]);
69                j += 1;
70            }
71            std::cmp::Ordering::Equal => {
72                result.push(av[i]);
73                i += 1;
74                j += 1;
75            }
76        }
77    }
78    result.extend_from_slice(&av[i..]);
79    result.extend_from_slice(&bv[j..]);
80
81    make_1d(result)
82}
83
84// ---------------------------------------------------------------------------
85// intersect1d
86// ---------------------------------------------------------------------------
87
88/// Return the sorted intersection of two 1-D arrays.
89///
90/// Equivalent to `numpy.intersect1d`.
91pub fn intersect1d<T>(
92    a: &Array<T, Ix1>,
93    b: &Array<T, Ix1>,
94    assume_unique: bool,
95) -> FerrayResult<Array<T, Ix1>>
96where
97    T: Element + PartialOrd + Copy,
98{
99    let av = if assume_unique {
100        to_vec(a)
101    } else {
102        sorted_unique(&to_vec(a))
103    };
104    let bv = if assume_unique {
105        to_vec(b)
106    } else {
107        sorted_unique(&to_vec(b))
108    };
109
110    let mut result = Vec::new();
111    let (mut i, mut j) = (0, 0);
112    while i < av.len() && j < bv.len() {
113        match av[i]
114            .partial_cmp(&bv[j])
115            .unwrap_or(std::cmp::Ordering::Equal)
116        {
117            std::cmp::Ordering::Less => i += 1,
118            std::cmp::Ordering::Greater => j += 1,
119            std::cmp::Ordering::Equal => {
120                result.push(av[i]);
121                i += 1;
122                j += 1;
123            }
124        }
125    }
126
127    make_1d(result)
128}
129
130// ---------------------------------------------------------------------------
131// setdiff1d
132// ---------------------------------------------------------------------------
133
134/// Return the sorted set difference of two 1-D arrays (elements in `a` not in `b`).
135///
136/// Equivalent to `numpy.setdiff1d`.
137pub fn setdiff1d<T>(
138    a: &Array<T, Ix1>,
139    b: &Array<T, Ix1>,
140    assume_unique: bool,
141) -> FerrayResult<Array<T, Ix1>>
142where
143    T: Element + PartialOrd + Copy,
144{
145    let av = if assume_unique {
146        to_vec(a)
147    } else {
148        sorted_unique(&to_vec(a))
149    };
150    let bv = if assume_unique {
151        to_vec(b)
152    } else {
153        sorted_unique(&to_vec(b))
154    };
155
156    let mut result = Vec::new();
157    let (mut i, mut j) = (0, 0);
158    while i < av.len() {
159        if j >= bv.len() {
160            result.push(av[i]);
161            i += 1;
162        } else {
163            match av[i]
164                .partial_cmp(&bv[j])
165                .unwrap_or(std::cmp::Ordering::Equal)
166            {
167                std::cmp::Ordering::Less => {
168                    result.push(av[i]);
169                    i += 1;
170                }
171                std::cmp::Ordering::Greater => {
172                    j += 1;
173                }
174                std::cmp::Ordering::Equal => {
175                    i += 1;
176                    j += 1;
177                }
178            }
179        }
180    }
181
182    make_1d(result)
183}
184
185// ---------------------------------------------------------------------------
186// setxor1d
187// ---------------------------------------------------------------------------
188
189/// Return the sorted symmetric difference of two 1-D arrays.
190///
191/// Elements that are in exactly one of the two arrays.
192///
193/// Equivalent to `numpy.setxor1d`.
194pub fn setxor1d<T>(
195    a: &Array<T, Ix1>,
196    b: &Array<T, Ix1>,
197    assume_unique: bool,
198) -> FerrayResult<Array<T, Ix1>>
199where
200    T: Element + PartialOrd + Copy,
201{
202    let av = if assume_unique {
203        to_vec(a)
204    } else {
205        sorted_unique(&to_vec(a))
206    };
207    let bv = if assume_unique {
208        to_vec(b)
209    } else {
210        sorted_unique(&to_vec(b))
211    };
212
213    let mut result = Vec::new();
214    let (mut i, mut j) = (0, 0);
215    while i < av.len() && j < bv.len() {
216        match av[i]
217            .partial_cmp(&bv[j])
218            .unwrap_or(std::cmp::Ordering::Equal)
219        {
220            std::cmp::Ordering::Less => {
221                result.push(av[i]);
222                i += 1;
223            }
224            std::cmp::Ordering::Greater => {
225                result.push(bv[j]);
226                j += 1;
227            }
228            std::cmp::Ordering::Equal => {
229                i += 1;
230                j += 1;
231            }
232        }
233    }
234    result.extend_from_slice(&av[i..]);
235    result.extend_from_slice(&bv[j..]);
236
237    make_1d(result)
238}
239
240// ---------------------------------------------------------------------------
241// in1d
242// ---------------------------------------------------------------------------
243
244/// Test whether each element of `a` is also present in `b`.
245///
246/// Returns a boolean array of the same length as `a`.
247///
248/// Equivalent to `numpy.in1d`.
249pub fn in1d<T>(
250    a: &Array<T, Ix1>,
251    b: &Array<T, Ix1>,
252    assume_unique: bool,
253) -> FerrayResult<Array<bool, Ix1>>
254where
255    T: Element + PartialOrd + Copy,
256{
257    let av = to_vec(a);
258    let bv = if assume_unique {
259        to_vec(b)
260    } else {
261        sorted_unique(&to_vec(b))
262    };
263
264    let result: Vec<bool> = av
265        .iter()
266        .map(|&val| {
267            bv.binary_search_by(|probe| {
268                probe.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Equal)
269            })
270            .is_ok()
271        })
272        .collect();
273
274    let n = result.len();
275    Array::from_vec(Ix1::new([n]), result)
276}
277
278// ---------------------------------------------------------------------------
279// isin
280// ---------------------------------------------------------------------------
281
282/// Test whether each element of `element` is in `test_elements`.
283///
284/// This is the same as `in1d` but named to match `numpy.isin`.
285///
286/// Equivalent to `numpy.isin`.
287pub fn isin<T>(
288    element: &Array<T, Ix1>,
289    test_elements: &Array<T, Ix1>,
290    assume_unique: bool,
291) -> FerrayResult<Array<bool, Ix1>>
292where
293    T: Element + PartialOrd + Copy,
294{
295    in1d(element, test_elements, assume_unique)
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 1-D `i32` array from a vector.
    fn arr(data: Vec<i32>) -> Array<i32, Ix1> {
        let shape = Ix1::new([data.len()]);
        Array::from_vec(shape, data).unwrap()
    }

    #[test]
    fn test_union1d() {
        let lhs = arr(vec![1, 2, 3]);
        let rhs = arr(vec![2, 3, 4]);
        let merged = union1d(&lhs, &rhs, false).unwrap();
        assert_eq!(merged.iter().copied().collect::<Vec<i32>>(), [1, 2, 3, 4]);
    }

    #[test]
    fn test_intersect1d() {
        let lhs = arr(vec![1, 2, 3, 4]);
        let rhs = arr(vec![2, 4, 6]);
        let common = intersect1d(&lhs, &rhs, false).unwrap();
        assert_eq!(common.iter().copied().collect::<Vec<i32>>(), [2, 4]);
    }

    #[test]
    fn test_setdiff1d() {
        let lhs = arr(vec![1, 2, 3, 4]);
        let rhs = arr(vec![2, 4]);
        let remaining = setdiff1d(&lhs, &rhs, false).unwrap();
        assert_eq!(remaining.iter().copied().collect::<Vec<i32>>(), [1, 3]);
    }

    #[test]
    fn test_setxor1d() {
        let lhs = arr(vec![1, 2, 3]);
        let rhs = arr(vec![2, 3, 4]);
        let exclusive = setxor1d(&lhs, &rhs, false).unwrap();
        assert_eq!(exclusive.iter().copied().collect::<Vec<i32>>(), [1, 4]);
    }

    #[test]
    fn test_in1d() {
        let needles = arr(vec![1, 2, 3, 4, 5]);
        let haystack = arr(vec![2, 4]);
        let mask = in1d(&needles, &haystack, false).unwrap();
        assert_eq!(
            mask.iter().copied().collect::<Vec<bool>>(),
            [false, true, false, true, false]
        );
    }

    #[test]
    fn test_isin() {
        let needles = arr(vec![1, 2, 3, 4, 5]);
        let haystack = arr(vec![3, 5]);
        let mask = isin(&needles, &haystack, false).unwrap();
        assert_eq!(
            mask.iter().copied().collect::<Vec<bool>>(),
            [false, false, true, false, true]
        );
    }

    #[test]
    fn test_union1d_with_duplicates() {
        // Unsorted inputs with repeats must still come out sorted and distinct.
        let lhs = arr(vec![3, 1, 2, 1]);
        let rhs = arr(vec![4, 2, 3, 2]);
        let merged = union1d(&lhs, &rhs, false).unwrap();
        assert_eq!(merged.iter().copied().collect::<Vec<i32>>(), [1, 2, 3, 4]);
    }

    #[test]
    fn test_union1d_assume_unique() {
        let lhs = arr(vec![1, 2, 3]);
        let rhs = arr(vec![2, 3, 4]);
        let merged = union1d(&lhs, &rhs, true).unwrap();
        assert_eq!(merged.iter().copied().collect::<Vec<i32>>(), [1, 2, 3, 4]);
    }

    #[test]
    fn test_setdiff1d_empty_result() {
        // Every element of the first array is covered by the second.
        let lhs = arr(vec![1, 2, 3]);
        let rhs = arr(vec![1, 2, 3, 4]);
        let remaining = setdiff1d(&lhs, &rhs, false).unwrap();
        assert_eq!(remaining.size(), 0);
    }

    #[test]
    fn test_intersect1d_empty_result() {
        // Disjoint inputs share nothing.
        let lhs = arr(vec![1, 2, 3]);
        let rhs = arr(vec![4, 5, 6]);
        let common = intersect1d(&lhs, &rhs, false).unwrap();
        assert_eq!(common.size(), 0);
    }
}