iter_set_ops/
intersect.rs

1use core::cmp::Ordering;
2use core::mem::swap;
3
4/// Iterates over the intersection of many sorted deduplicated iterators.
5///
6/// # Examples
7///
8/// ```
9/// use iter_set_ops::intersect_iters;
10///
11/// let it1 = 1u8..=5;
12/// let it2 = 3u8..=7;
13/// let it3 = 2u8..=4;
14/// let mut iters = [it1, it2, it3];
15/// let res: Vec<_> = intersect_iters(&mut iters).collect();
16///
17/// assert_eq!(res, vec![3, 4]);
18/// ```
19#[inline]
20pub fn intersect_iters<'a, T: Ord + 'a, I: Iterator<Item = T>>(
21    iters: &mut [I],
22) -> IntersectIterator<T, I, impl Fn(&T, &T) -> Ordering + 'a> {
23    intersect_iters_by(iters, T::cmp)
24}
25
26/// Iterates over the intersection of many sorted deduplicated iterators, using `cmp` as the comparison operator.
27///
28/// # Examples
29///
30/// ```
31/// use iter_set_ops::intersect_iters_by;
32///
33/// let it1 = (1u8..=5).rev();
34/// let it2 = (3u8..=7).rev();
35/// let it3 = (2u8..=4).rev();
36/// let mut iters = [it1, it2, it3];
37/// let res: Vec<_> = intersect_iters_by(&mut iters, |x, y| y.cmp(x)).collect();
38///
39/// assert_eq!(res, vec![4, 3]);
40/// ```
41pub fn intersect_iters_by<'a, T, I: Iterator<Item = T>, F: Fn(&T, &T) -> Ordering + Copy + 'a>(
42    iters: &mut [I],
43    cmp: F,
44) -> IntersectIterator<T, I, F> {
45    let mut front: Vec<T> = Vec::with_capacity(iters.len());
46    let mut max_index = 0;
47    let mut nonmax_index = 0;
48    let mut all_equal = true;
49    if iters.is_empty() {
50        return IntersectIterator {
51            iters,
52            cmp,
53            front,
54            max_index,
55            nonmax_index,
56            all_equal,
57            exhausted: true,
58        };
59    }
60    if let Some(x) = iters[0].next() {
61        front.push(x);
62    } else {
63        return IntersectIterator {
64            iters,
65            cmp,
66            front,
67            max_index,
68            nonmax_index,
69            all_equal,
70            exhausted: true,
71        };
72    }
73    for (i, iter) in iters.iter_mut().enumerate().skip(1) {
74        if let Some(x) = iter.next() {
75            front.push(x);
76            match cmp(&front[i], &front[max_index]) {
77                Ordering::Less => {
78                    nonmax_index = i;
79                    all_equal = false;
80                    break;
81                }
82                Ordering::Greater => {
83                    nonmax_index = max_index;
84                    max_index = i;
85                    all_equal = false;
86                    break;
87                }
88                _ => (),
89            }
90        } else {
91            return IntersectIterator {
92                iters,
93                cmp,
94                front,
95                max_index,
96                nonmax_index,
97                all_equal,
98                exhausted: true,
99            };
100        }
101    }
102    IntersectIterator {
103        iters,
104        cmp,
105        front,
106        max_index,
107        nonmax_index,
108        all_equal,
109        exhausted: false,
110    }
111}
112
/// Iterator over the items that every input iterator has in common.
///
/// Created by [`intersect_iters`] and [`intersect_iters_by`]. Once it has
/// returned `None` it keeps returning `None`.
pub struct IntersectIterator<'a, T, I: Iterator<Item = T>, F: Fn(&T, &T) -> Ordering> {
    /// The borrowed input iterators.
    iters: &'a mut [I],
    /// Comparison used to order items.
    cmp: F,
    /// Most recently pulled, not yet yielded item of each input; `front[i]`
    /// belongs to `iters[i]`. May be shorter than `iters` until every input
    /// has been pulled from at least once.
    front: Vec<T>,
    /// Index of the `front` entry the other entries are compared against.
    max_index: usize,
    /// Index at which the scan over the other inputs starts.
    nonmax_index: usize,
    /// `true` when every `front` entry compared equal to `front[max_index]`.
    all_equal: bool,
    /// Set once no further item can be produced.
    exhausted: bool,
}

impl<T, I: Iterator<Item = T>, F: Fn(&T, &T) -> Ordering> Iterator
    for IntersectIterator<'_, T, I, F>
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        if self.exhausted {
            return None;
        }
        let len = self.iters.len();
        // Phase 1: advance the lagging inputs until every head equals the
        // current maximum, switching the maximum whenever a larger head shows up.
        'until_all_equal: while !self.all_equal {
            // Copy the index so the filter closure does not borrow `self`.
            let max = self.max_index;
            let scan = (0..=self.nonmax_index)
                .rev()
                .chain((self.nonmax_index + 1)..len)
                .filter(move |&i| i != max);
            for i in scan {
                if i >= self.front.len() {
                    // This input has not been primed yet; pull its first item.
                    match self.iters[i].next() {
                        Some(x) => self.front.push(x),
                        None => {
                            self.exhausted = true;
                            return None;
                        }
                    }
                }
                let mut ord = (self.cmp)(&self.front[i], &self.front[max]);
                while ord.is_lt() {
                    // Skip items smaller than the current maximum.
                    match self.iters[i].next() {
                        Some(x) => {
                            ord = (self.cmp)(&x, &self.front[max]);
                            self.front[i] = x;
                        }
                        None => {
                            self.exhausted = true;
                            return None;
                        }
                    }
                }
                if ord.is_gt() {
                    // Found a larger head: it becomes the new maximum and the
                    // scan restarts against it.
                    self.nonmax_index = max;
                    self.max_index = i;
                    continue 'until_all_equal;
                }
            }
            self.all_equal = true;
        }
        // Phase 2: all heads are equal. Yield the one at `max_index` and
        // refill it from its input.
        let max = self.max_index;
        let res = match self.iters[max].next() {
            Some(mut x) => {
                swap(&mut x, &mut self.front[max]);
                x
            }
            None => {
                // Final common item: that input has nothing left.
                self.exhausted = true;
                return Some(self.front.swap_remove(max));
            }
        };
        // Phase 3: advance the other inputs past the yielded item, stopping at
        // the first head that differs from the new maximum. Inputs not reached
        // here keep their old head and are advanced by phase 1 on the next call.
        let scan = (0..=self.nonmax_index)
            .rev()
            .chain((self.nonmax_index + 1)..len)
            .filter(move |&i| i != max);
        for i in scan {
            match self.iters[i].next() {
                Some(x) => {
                    self.front[i] = x;
                    match (self.cmp)(&self.front[i], &self.front[max]) {
                        Ordering::Less => {
                            self.nonmax_index = i;
                            self.all_equal = false;
                            break;
                        }
                        Ordering::Greater => {
                            self.nonmax_index = max;
                            self.max_index = i;
                            self.all_equal = false;
                            break;
                        }
                        Ordering::Equal => {}
                    }
                }
                None => {
                    self.exhausted = true;
                    break;
                }
            }
        }
        Some(res)
    }
}

// Every path that returns `None` sets `exhausted` first, and `exhausted` is
// never cleared, so `None` is permanent.
impl<T, I: Iterator<Item = T>, F: Fn(&T, &T) -> Ordering> core::iter::FusedIterator
    for IntersectIterator<'_, T, I, F>
{
}
198
199/// Iterates over the intersection of many sorted deduplicated iterators and groups equal items into a [`Vec`].
200///
201/// # Examples
202///
203/// ```
204/// use iter_set_ops::intersect_iters_detailed;
205///
206/// let it1 = 1u8..=5;
207/// let it2 = 3u8..=7;
208/// let it3 = 2u8..=4;
209/// let mut iters = [it1, it2, it3];
210/// let res: Vec<_> = intersect_iters_detailed(&mut iters).collect();
211///
212/// assert_eq!(res, vec![vec![3, 3, 3], vec![4, 4, 4]]);
213/// ```
214#[inline]
215pub fn intersect_iters_detailed<'a, T: Ord + 'a, I: Iterator<Item = T>>(
216    iters: &mut [I],
217) -> DetailedIntersectIterator<T, I, impl Fn(&T, &T) -> Ordering + 'a> {
218    intersect_iters_detailed_by(iters, T::cmp)
219}
220
221/// Iterates over the intersection of many sorted deduplicated iterators and groups equal items into a [`Vec`], using `cmp` as the comparison operator.
222///
223/// # Examples
224///
225/// ```
226/// use iter_set_ops::intersect_iters_detailed_by;
227///
228/// let it1 = (1u8..=5).rev();
229/// let it2 = (3u8..=7).rev();
230/// let it3 = (2u8..=4).rev();
231/// let mut iters = [it1, it2, it3];
232/// let res: Vec<_> = intersect_iters_detailed_by(&mut iters, |x, y| y.cmp(x)).collect();
233///
234/// assert_eq!(res, vec![vec![4, 4, 4], vec![3, 3, 3]]);
235/// ```
236pub fn intersect_iters_detailed_by<
237    'a,
238    T,
239    I: Iterator<Item = T>,
240    F: Fn(&T, &T) -> Ordering + Copy + 'a,
241>(
242    iters: &mut [I],
243    cmp: F,
244) -> DetailedIntersectIterator<T, I, F> {
245    let mut front: Vec<T> = Vec::with_capacity(iters.len());
246    let mut max_index = 0;
247    let mut nonmax_index = 0;
248    let mut all_equal = true;
249    if iters.is_empty() {
250        return DetailedIntersectIterator {
251            iters,
252            cmp,
253            front,
254            max_index,
255            nonmax_index,
256            all_equal,
257            exhausted: true,
258        };
259    }
260    if let Some(x) = iters[0].next() {
261        front.push(x);
262    } else {
263        return DetailedIntersectIterator {
264            iters,
265            cmp,
266            front,
267            max_index,
268            nonmax_index,
269            all_equal,
270            exhausted: true,
271        };
272    }
273    for (i, iter) in iters.iter_mut().enumerate().skip(1) {
274        if let Some(x) = iter.next() {
275            front.push(x);
276            match cmp(&front[i], &front[max_index]) {
277                Ordering::Less => {
278                    nonmax_index = i;
279                    all_equal = false;
280                    break;
281                }
282                Ordering::Greater => {
283                    nonmax_index = max_index;
284                    max_index = i;
285                    all_equal = false;
286                    break;
287                }
288                _ => (),
289            }
290        } else {
291            return DetailedIntersectIterator {
292                iters,
293                cmp,
294                front,
295                max_index,
296                nonmax_index,
297                all_equal,
298                exhausted: true,
299            };
300        }
301    }
302    DetailedIntersectIterator {
303        iters,
304        cmp,
305        front,
306        max_index,
307        nonmax_index,
308        all_equal,
309        exhausted: false,
310    }
311}
312
/// Iterator over the items that every input iterator has in common, yielding
/// the matching item of each input grouped into a [`Vec`].
///
/// Created by [`intersect_iters_detailed`] and [`intersect_iters_detailed_by`].
/// Once it has returned `None` it keeps returning `None`.
pub struct DetailedIntersectIterator<'a, T, I: Iterator<Item = T>, F: Fn(&T, &T) -> Ordering> {
    /// The borrowed input iterators.
    iters: &'a mut [I],
    /// Comparison used to order items.
    cmp: F,
    /// Most recently pulled, not yet yielded item of each input; `front[i]`
    /// belongs to `iters[i]`. May be shorter than `iters` until every input
    /// has been pulled from at least once.
    front: Vec<T>,
    /// Index of the `front` entry the other entries are compared against.
    max_index: usize,
    /// Index at which the scan over the other inputs starts.
    nonmax_index: usize,
    /// `true` when every `front` entry compared equal to `front[max_index]`.
    all_equal: bool,
    /// Set once no further group can be produced.
    exhausted: bool,
}

impl<T, I: Iterator<Item = T>, F: Fn(&T, &T) -> Ordering> Iterator
    for DetailedIntersectIterator<'_, T, I, F>
{
    type Item = Vec<T>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.exhausted {
            return None;
        }
        let len = self.iters.len();
        // Phase 1: advance the lagging inputs until every head equals the
        // current maximum, switching the maximum whenever a larger head shows up.
        'until_all_equal: while !self.all_equal {
            // Copy the index so the filter closure does not borrow `self`.
            let max = self.max_index;
            let scan = (0..=self.nonmax_index)
                .rev()
                .chain((self.nonmax_index + 1)..len)
                .filter(move |&i| i != max);
            for i in scan {
                if i >= self.front.len() {
                    // This input has not been primed yet; pull its first item.
                    match self.iters[i].next() {
                        Some(x) => self.front.push(x),
                        None => {
                            self.exhausted = true;
                            return None;
                        }
                    }
                }
                let mut ord = (self.cmp)(&self.front[i], &self.front[max]);
                while ord.is_lt() {
                    // Skip items smaller than the current maximum.
                    match self.iters[i].next() {
                        Some(x) => {
                            ord = (self.cmp)(&x, &self.front[max]);
                            self.front[i] = x;
                        }
                        None => {
                            self.exhausted = true;
                            return None;
                        }
                    }
                }
                if ord.is_gt() {
                    // Found a larger head: it becomes the new maximum and the
                    // scan restarts against it.
                    self.nonmax_index = max;
                    self.max_index = i;
                    continue 'until_all_equal;
                }
            }
            self.all_equal = true;
        }
        // Phase 2: all heads are equal. Move every head into the result and
        // refill it from its input, tracking the new maximum along the way.
        self.max_index = 0;
        self.nonmax_index = 0;
        let mut res = Vec::with_capacity(self.front.len());
        // Walk from the last input down so that `swap_remove(i)` only disturbs
        // positions at or above `i`, which have already been handled.
        for (i, iter) in self.iters.iter_mut().enumerate().rev() {
            if let Some(mut x) = iter.next() {
                swap(&mut x, &mut self.front[i]);
                res.push(x);
                // Once any input has run dry the bookkeeping is never read again.
                if !self.exhausted {
                    match (self.cmp)(&self.front[i], &self.front[self.max_index]) {
                        Ordering::Less => {
                            self.nonmax_index = i;
                            self.all_equal = false;
                        }
                        Ordering::Greater => {
                            self.nonmax_index = self.max_index;
                            self.max_index = i;
                            self.all_equal = false;
                        }
                        Ordering::Equal => {}
                    }
                }
            } else {
                // This input is finished: hand out its last head and stop after
                // this group.
                self.exhausted = true;
                res.push(self.front.swap_remove(i));
            }
        }
        // Items were collected last-to-first; restore input order.
        res.reverse();
        Some(res)
    }
}

// Every path that returns `None` sets `exhausted` first, and `exhausted` is
// never cleared, so `None` is permanent.
impl<T, I: Iterator<Item = T>, F: Fn(&T, &T) -> Ordering> core::iter::FusedIterator
    for DetailedIntersectIterator<'_, T, I, F>
{
}
393
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use std::collections::HashSet;

    #[test]
    fn test_intersect() {
        let it1 = 1u8..=5;
        let it2 = 3u8..=7;
        let it3 = 2u8..=4;
        let mut iters = [it1, it2, it3];
        let res: Vec<_> = intersect_iters(&mut iters).collect();

        assert_eq!(res, vec![3, 4]);
        // The inputs are only borrowed: unread items stay available.
        assert!(iters[1].next().is_some());
    }

    #[test]
    fn test_intersect_pair() {
        let it1 = 3u8..=7;
        let it2 = 2u8..=4;
        let mut iters = [it1, it2];
        let res: Vec<_> = intersect_iters(&mut iters).collect();

        assert_eq!(res, vec![3, 4]);
        assert!(iters[0].next().is_some());
    }

    #[test]
    fn test_intersect_by() {
        let it1 = (1u8..=5).rev();
        let it2 = (3u8..=7).rev();
        let it3 = (2u8..=4).rev();
        let mut iters = [it1, it2, it3];
        let res: Vec<_> = intersect_iters_by(&mut iters, |x, y| y.cmp(x)).collect();

        assert_eq!(res, vec![4, 3]);
        assert!(iters[0].next().is_some());
    }

    #[test]
    fn test_intersect_detailed() {
        let it1 = 1u8..=5;
        let it2 = 3u8..=7;
        let it3 = 2u8..=4;
        let mut iters = [it1, it2, it3];
        let res: Vec<_> = intersect_iters_detailed(&mut iters).collect();

        assert_eq!(res, vec![vec![3, 3, 3], vec![4, 4, 4]]);
        assert!(iters[1].next().is_some());
    }

    #[test]
    fn test_intersect_detailed_by() {
        let it1 = (1u8..=5).rev();
        let it2 = (3u8..=7).rev();
        let it3 = (2u8..=4).rev();
        let mut iters = [it1, it2, it3];
        let res: Vec<_> = intersect_iters_detailed_by(&mut iters, |x, y| y.cmp(x)).collect();

        assert_eq!(res, vec![vec![4, 4, 4], vec![3, 3, 3]]);
        assert!(iters[0].next().is_some());
    }

    #[test]
    fn test_random_intersect() {
        const C: usize = 5;
        const N: usize = 100_000;

        let mut rng = StdRng::seed_from_u64(42);
        let mut vecs = Vec::with_capacity(C);
        for _ in 0..C {
            let mut vec = Vec::with_capacity(N);
            for _ in 0..N {
                vec.push(rng.gen::<u16>());
            }
            vec.sort_unstable();
            vec.dedup();
            vecs.push(vec);
        }
        let mut iters: Vec<_> = vecs.iter().map(|v| v.iter()).collect();
        let res: Vec<&u16> = intersect_iters(&mut iters).collect();
        let sets: Vec<HashSet<u16>> = vecs
            .iter()
            .map(|vec| vec.iter().copied().collect())
            .collect();

        // Reference answer: the items of the first input that occur in every
        // input, in sorted order. Comparing whole vectors checks membership,
        // completeness, ordering and absence of repeats at once.
        let expected: Vec<&u16> = vecs[0]
            .iter()
            .filter(|x| sets.iter().all(|set| set.contains(*x)))
            .collect();

        assert_eq!(res, expected);
    }

    #[test]
    fn test_intersect_preserve_details() {
        const C: usize = 5;
        const N: usize = 100_000;

        let mut rng = StdRng::seed_from_u64(42);
        let mut vecs = Vec::with_capacity(C);
        for i in 0..C {
            let mut vec = Vec::with_capacity(N);
            for _ in 0..N {
                // Tag every value with the index of the input it comes from.
                vec.push((i, rng.gen::<u16>()));
            }
            vec.sort_unstable();
            vec.dedup();
            vecs.push(vec);
        }
        let mut iters: Vec<_> = vecs.iter().map(|v| v.iter()).collect();
        for details in intersect_iters_detailed_by(&mut iters, |(_, x), (_, y)| x.cmp(y)) {
            let x = details[0].1;
            for (i, &&(j, y)) in details.iter().enumerate() {
                // Same value everywhere, and slot `i` holds the item of input `i`.
                assert_eq!(x, y);
                assert_eq!(i, j);
            }
        }
    }

    #[test]
    fn test_associative_intersect() {
        const C: usize = 6;
        const N: usize = 100_000;

        let mut rng = StdRng::seed_from_u64(42);
        let mut vecs = Vec::with_capacity(C);
        for _ in 0..C {
            let mut vec = Vec::with_capacity(N);
            for _ in 0..N {
                vec.push(rng.gen::<u16>());
            }
            vec.sort_unstable();
            vec.dedup();
            vecs.push(vec);
        }

        // Intersect all six inputs at once.
        let mut iters: Vec<_> = vecs.iter().map(|v| v.iter()).collect();
        let res6: HashSet<_> = intersect_iters(&mut iters).collect();

        // Intersect two groups of three, then intersect the two results.
        let mut nested_iters: Vec<Vec<_>> = (0..C)
            .step_by(3)
            .map(|i| vecs.iter().skip(i).take(3).map(|v| v.iter()).collect())
            .collect();
        let res3: HashSet<_> = intersect_iters(
            &mut nested_iters
                .iter_mut()
                .map(|inner_iters| intersect_iters(inner_iters))
                .collect::<Vec<_>>(),
        )
        .collect();

        // Intersect three groups of two, then intersect the three results.
        let mut nested_iters: Vec<Vec<_>> = (0..C)
            .step_by(2)
            .map(|i| vecs.iter().skip(i).take(2).map(|v| v.iter()).collect())
            .collect();
        let res2: HashSet<_> = intersect_iters(
            &mut nested_iters
                .iter_mut()
                .map(|inner_iters| intersect_iters(inner_iters))
                .collect::<Vec<_>>(),
        )
        .collect();

        assert_eq!(res6, res3);
        assert_eq!(res6, res2);
    }
}