iter_set_ops/
subtract.rs

1use core::cmp::Ordering;
2use core::mem::swap;
3
4/// Iterates over the difference of many sorted deduplicated iterators.
5///
6/// # Examples
7///
8/// ```
9/// use iter_set_ops::subtract_iters;
10///
11/// let it1 = 1u8..=9;
12/// let it2 = 6u8..=10;
13/// let it3 = 2u8..=4;
14/// let res: Vec<_> = subtract_iters(it1, vec![it2, it3]).collect();
15///
16/// assert_eq!(res, vec![1, 5]);
17/// ```
18#[inline]
19pub fn subtract_iters<T: Ord, I: Iterator<Item = T>, J: Iterator<Item = T>>(
20    left_iter: I,
21    right_iters: Vec<J>,
22) -> SubtractIterator<T, I, J, impl Fn(&T, &T) -> Ordering> {
23    subtract_iters_by(left_iter, right_iters, T::cmp)
24}
25
26/// Iterates over the difference of many sorted deduplicated iterators, using `cmp` as the comparison operator.
27///
28/// # Examples
29///
30/// ```
31/// use iter_set_ops::subtract_iters_by;
32///
33/// let it1 = (1u8..=9).rev();
34/// let it2 = (6u8..=10).rev();
35/// let it3 = (2u8..=4).rev();
36/// let res: Vec<_> = subtract_iters_by(it1, vec![it2, it3], |x, y| y.cmp(x)).collect();
37///
38/// assert_eq!(res, vec![5, 1]);
39/// ```
40pub fn subtract_iters_by<
41    T,
42    I: Iterator<Item = T>,
43    J: Iterator<Item = T>,
44    F: Fn(&T, &T) -> Ordering + Copy,
45>(
46    mut left_iter: I,
47    right_iters: Vec<J>,
48    cmp: F,
49) -> SubtractIterator<T, I, J, F> {
50    let left = left_iter.next();
51    let front = Vec::with_capacity(right_iters.len());
52    SubtractIterator {
53        left_iter,
54        right_iters,
55        cmp,
56        left,
57        front,
58        all_greater: false,
59        collision_index: 0,
60        exhausted_indices: Vec::new(),
61    }
62}
63
/// Iterator over the items of a left-hand iterator that occur in none of a set of
/// right-hand iterators.
///
/// Created by `subtract_iters` and `subtract_iters_by`. All inputs are expected to be
/// sorted consistently with `cmp` and free of duplicates. Once `next` has returned
/// `None` it keeps returning `None` without polling the underlying iterators again.
pub struct SubtractIterator<
    T,
    I: Iterator<Item = T>,
    J: Iterator<Item = T>,
    F: Fn(&T, &T) -> Ordering,
> {
    /// Source of the items that may be yielded.
    left_iter: I,
    /// Right-hand iterators that can still produce items; exhausted ones are removed.
    right_iters: Vec<J>,
    /// Comparison operator shared by all inputs.
    cmp: F,
    /// Look-ahead item pulled from `left_iter`; `None` once the left side is exhausted.
    left: Option<T>,
    /// `front[i]` is the most recent item pulled from `right_iters[i]`. It may be shorter
    /// than `right_iters` until every right-hand iterator has been pulled at least once.
    front: Vec<T>,
    /// Set once a full pass found no remaining front item equal to `left`.
    all_greater: bool,
    /// Index of the right-hand iterator that most recently matched `left`; it is
    /// inspected first on the next pass.
    collision_index: usize,
    /// Indices of right-hand iterators that ran dry during the current pass and are
    /// waiting to be removed.
    exhausted_indices: Vec<usize>,
}

impl<T, I: Iterator<Item = T>, J: Iterator<Item = T>, F: Fn(&T, &T) -> Ordering>
    SubtractIterator<T, I, J, F>
{
    /// Removes every right-hand iterator recorded in `exhausted_indices`, together with
    /// its `front` slot, and keeps `collision_index` pointing at the same iterator.
    fn remove_exhausted(&mut self) {
        // Remove from the highest index down so that the element swapped into a freed
        // slot is never one that is itself still waiting to be removed.
        self.exhausted_indices.sort_unstable();
        for j in self.exhausted_indices.drain(..).rev() {
            self.front.swap_remove(j);
            // `front` may be shorter than `right_iters`: first move the iterator that
            // belongs to the relocated front item into slot `j`, then drop the dead one.
            self.right_iters.swap(j, self.front.len());
            self.right_iters.swap_remove(self.front.len());
            if self.collision_index == j {
                self.collision_index = 0;
            } else if self.collision_index == self.front.len() {
                // The tracked iterator was the one relocated into slot `j`.
                self.collision_index = j;
            }
        }
    }
}

impl<T, I: Iterator<Item = T>, J: Iterator<Item = T>, F: Fn(&T, &T) -> Ordering> Iterator
    for SubtractIterator<T, I, J, F>
{
    type Item = T;

    /// Returns the next left-hand item that is matched by no right-hand iterator, or
    /// `None` once the left-hand iterator is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        self.left.as_ref()?;
        'until_all_greater: while !self.all_greater {
            // Start with the iterator that matched last, walk down to index 0, then
            // cover the indices above it.
            let index_iter = (0..=self.collision_index)
                .rev()
                .chain((self.collision_index + 1)..self.right_iters.len());
            for i in index_iter {
                if i >= self.front.len() {
                    // First visit of slot `i`: pull its initial item. Iterators that
                    // are empty from the start are dropped right away so that `front`
                    // and `right_iters` stay index-aligned.
                    let mut first = None;
                    while first.is_none() && i < self.right_iters.len() {
                        first = self.right_iters[i].next();
                        if first.is_none() {
                            self.right_iters.swap_remove(i);
                        }
                    }
                    match first {
                        Some(x) => self.front.push(x),
                        // No right-hand iterator is left at or beyond `i` (this also
                        // covers an empty `right_iters`), so the pass is complete.
                        None => break,
                    }
                }
                let left = self
                    .left
                    .as_ref()
                    .expect("`left` is `Some` while right-hand iterators are scanned");
                let mut ord = (self.cmp)(&self.front[i], left);
                // Skip right-hand items that sort before the current left-hand item.
                while ord.is_lt() {
                    if let Some(x) = self.right_iters[i].next() {
                        ord = (self.cmp)(&x, left);
                        self.front[i] = x;
                    } else {
                        self.exhausted_indices.push(i);
                        break;
                    }
                }
                if ord.is_eq() {
                    // The current left-hand item is subtracted; move on to the next one
                    // and rescan, starting with the iterator that just matched.
                    if let Some(x) = self.left_iter.next() {
                        self.left = Some(x);
                        self.collision_index = i;
                        self.remove_exhausted();
                        continue 'until_all_greater;
                    } else {
                        self.left = None;
                        return None;
                    }
                }
            }
            self.remove_exhausted();
            self.all_greater = true;
        }
        // `left` survived a full pass: yield it and refill the look-ahead slot.
        let mut res = self.left_iter.next();
        swap(&mut res, &mut self.left);
        self.all_greater = false;
        res
    }
}
150
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use std::collections::HashSet;

    /// Ascending ranges compared with the natural ordering.
    #[test]
    fn test_subtract() {
        let minuend = 1u8..=9;
        let subtrahends = vec![6u8..=10, 2u8..=4];
        let difference: Vec<_> = subtract_iters(minuend, subtrahends).collect();

        assert_eq!(difference, vec![1, 5]);
    }

    /// Descending ranges compared with a reversed comparison operator.
    #[test]
    fn test_subtract_by() {
        let minuend = (1u8..=9).rev();
        let subtrahends = vec![(6u8..=10).rev(), (2u8..=4).rev()];
        let difference: Vec<_> =
            subtract_iters_by(minuend, subtrahends, |a, b| b.cmp(a)).collect();

        assert_eq!(difference, vec![5, 1]);
    }

    /// Many multiples of `C` on the left; the right-hand iterators together cover every
    /// value below `C * N`, so only `C * N` itself survives.
    #[test]
    fn test_large_subtract() {
        const C: usize = 10;
        const N: usize = 100_000;

        let upper = C * N;
        let left_iter = (0..=upper).step_by(C);
        let right_iters: Vec<_> = (0..C)
            .map(|offset| (0..upper).skip(offset).step_by(C))
            .collect();
        let difference: Vec<_> = subtract_iters(left_iter, right_iters).collect();

        assert_eq!(difference, vec![upper]);
    }

    /// Seeded random inputs: every yielded item comes from the left-hand input and no
    /// item of any right-hand input is yielded.
    #[test]
    fn test_random_subtract() {
        const C: usize = 10;
        const N: usize = 10_000;

        let mut rng = StdRng::seed_from_u64(42);
        let vecs: Vec<Vec<u16>> = (0..C)
            .map(|_| {
                let mut values: Vec<u16> = (0..N).map(|_| rng.gen::<u16>()).collect();
                values.sort_unstable();
                values.dedup();
                values
            })
            .collect();

        let left_iter = vecs[0].iter();
        let right_iters: Vec<_> = vecs[1..].iter().map(|values| values.iter()).collect();
        let difference: HashSet<_> = subtract_iters(left_iter, right_iters).collect();

        for &kept in difference.iter() {
            assert!(vecs[0].contains(kept));
        }
        for removed in vecs[1..].iter().flatten() {
            assert!(!difference.contains(removed));
        }
    }
}