1use core::cmp::Ordering;
2use core::mem::swap;
3
4#[inline]
19pub fn subtract_iters<T: Ord, I: Iterator<Item = T>, J: Iterator<Item = T>>(
20 left_iter: I,
21 right_iters: Vec<J>,
22) -> SubtractIterator<T, I, J, impl Fn(&T, &T) -> Ordering> {
23 subtract_iters_by(left_iter, right_iters, T::cmp)
24}
25
/// Lazily yields the elements of `left_iter` that are not equal to any element of
/// `right_iters`, using `cmp` both to order and to compare elements.
///
/// `left_iter` and every iterator in `right_iters` must be sorted in ascending
/// order according to `cmp`; otherwise the result is unspecified. The output keeps
/// the order of `left_iter`, and every occurrence of a value found on the right is
/// dropped (duplicates on the left are all removed, not just one).
///
/// The first element of `left_iter` is pulled immediately; the right-hand
/// iterators are not polled until the returned iterator is first advanced. An
/// empty `right_iters`, or right-hand iterators that are empty or run out early,
/// are all allowed.
pub fn subtract_iters_by<
    T,
    I: Iterator<Item = T>,
    J: Iterator<Item = T>,
    F: Fn(&T, &T) -> Ordering,
>(
    mut left_iter: I,
    right_iters: Vec<J>,
    cmp: F,
) -> SubtractIterator<T, I, J, F> {
    let left = left_iter.next();
    let front = Vec::with_capacity(right_iters.len());
    SubtractIterator {
        left_iter,
        cmp,
        left,
        right: RightHeads {
            iters: right_iters,
            front,
            collision_index: 0,
        },
    }
}

/// Iterator returned by [`subtract_iters`] and [`subtract_iters_by`].
///
/// It is fused: once it has returned `None` it keeps returning `None`.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct SubtractIterator<
    T,
    I: Iterator<Item = T>,
    J: Iterator<Item = T>,
    F: Fn(&T, &T) -> Ordering,
> {
    left_iter: I,
    cmp: F,
    /// Next left element still to be checked; `None` once the left side is used up.
    left: Option<T>,
    right: RightHeads<T, J>,
}

/// Cursor state for the right-hand iterators.
///
/// Invariant: `front.len() <= iters.len()` and, for every `k < front.len()`,
/// `front[k]` is the most recently pulled element of `iters[k]`. After
/// `fill_front` has run, both vectors have the same length. An iterator that runs
/// out is removed together with its head (via `swap_remove` on both vectors), so
/// no index ever refers to an exhausted iterator.
struct RightHeads<T, J> {
    iters: Vec<J>,
    front: Vec<T>,
    /// Slot of the iterator that produced the most recent match. It is probed
    /// first on the next lookup as a heuristic. It may be stale after removals,
    /// and is only used when it is in bounds.
    collision_index: usize,
}

impl<T, J: Iterator<Item = T>> RightHeads<T, J> {
    /// Returns `true` if some right-hand iterator holds an element equal to `left`.
    ///
    /// Heads that compare less than `left` are discarded along the way, so
    /// successive calls must pass non-decreasing candidates.
    fn contains<F: Fn(&T, &T) -> Ordering>(&mut self, left: &T, cmp: &F) -> bool {
        self.fill_front();

        let hint = self.collision_index;
        if hint < self.front.len() && self.settle(hint, left, cmp) == Some(Ordering::Equal) {
            return true;
        }

        let mut i = 0;
        while i < self.front.len() {
            match self.settle(i, left, cmp) {
                Some(Ordering::Equal) => {
                    self.collision_index = i;
                    return true;
                }
                Some(_) => i += 1,
                // `iters[i]` ran out and was removed; another (not yet visited)
                // iterator now occupies slot `i`, so stay on it.
                None => {}
            }
        }
        false
    }

    /// Pulls the first element of every iterator that has not been started yet,
    /// dropping iterators that turn out to be empty.
    fn fill_front(&mut self) {
        while self.front.len() < self.iters.len() {
            let i = self.front.len();
            match self.iters[i].next() {
                Some(head) => self.front.push(head),
                // Only not-yet-started iterators live at index >= front.len(),
                // so swapping one of them into slot `i` keeps the invariant.
                None => {
                    self.iters.swap_remove(i);
                }
            }
        }
    }

    /// Advances `iters[i]` until its head no longer compares less than `left`.
    ///
    /// Returns the ordering of the final head relative to `left`. It returns
    /// `None` if the iterator ran out; in that case the iterator and its head are
    /// removed and the last slot is swapped into `i`.
    fn settle<F: Fn(&T, &T) -> Ordering>(
        &mut self,
        i: usize,
        left: &T,
        cmp: &F,
    ) -> Option<Ordering> {
        loop {
            let ord = cmp(&self.front[i], left);
            if !ord.is_lt() {
                return Some(ord);
            }
            match self.iters[i].next() {
                Some(head) => self.front[i] = head,
                None => {
                    self.front.swap_remove(i);
                    self.iters.swap_remove(i);
                    return None;
                }
            }
        }
    }
}

impl<T, I: Iterator<Item = T>, J: Iterator<Item = T>, F: Fn(&T, &T) -> Ordering> Iterator
    for SubtractIterator<T, I, J, F>
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // Skip candidates found on the right. The `?` ends iteration as soon as
        // the left side is exhausted; `self.left` then stays `None`, which keeps
        // the iterator fused.
        while self.right.contains(self.left.as_ref()?, &self.cmp) {
            self.left = self.left_iter.next();
        }
        // `self.left` holds an accepted element: prefetch the following
        // candidate and hand back the accepted one.
        let mut res = self.left_iter.next();
        swap(&mut res, &mut self.left);
        res
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.left.is_none() {
            // Finished: `left_iter` is never polled again.
            return (0, Some(0));
        }
        // The pending candidate plus whatever `left_iter` still holds may
        // survive; any of them may also be filtered out.
        let (_, upper) = self.left_iter.size_hint();
        (0, upper.and_then(|n| n.checked_add(1)))
    }
}

impl<T, I: Iterator<Item = T>, J: Iterator<Item = T>, F: Fn(&T, &T) -> Ordering>
    core::iter::FusedIterator for SubtractIterator<T, I, J, F>
{
}
150
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use std::collections::HashSet;

    #[test]
    fn test_subtract() {
        let it1 = 1u8..=9;
        let it2 = 6u8..=10;
        let it3 = 2u8..=4;
        let res: Vec<_> = subtract_iters(it1, vec![it2, it3]).collect();

        assert_eq!(res, vec![1, 5]);
    }

    #[test]
    fn test_subtract_by() {
        let it1 = (1u8..=9).rev();
        let it2 = (6u8..=10).rev();
        let it3 = (2u8..=4).rev();
        let res: Vec<_> = subtract_iters_by(it1, vec![it2, it3], |x, y| y.cmp(x)).collect();

        assert_eq!(res, vec![5, 1]);
    }

    #[test]
    fn test_large_subtract() {
        const C: usize = 10;
        const N: usize = 100_000;

        let left_iter = (0..=(C * N)).step_by(C);
        let right_iters: Vec<_> = (0..C).map(|i| (0..(C * N)).skip(i).step_by(C)).collect();
        let res: Vec<_> = subtract_iters(left_iter, right_iters).collect();

        assert_eq!(res, vec![C * N]);
    }

    /// Every occurrence of a left value that appears on the right is dropped.
    #[test]
    fn test_duplicates_in_left() {
        let left = vec![1u8, 2, 2, 3, 3];
        let rights = vec![vec![2u8, 5].into_iter(), vec![3u8].into_iter()];
        let res: Vec<_> = subtract_iters(left.into_iter(), rights).collect();

        assert_eq!(res, vec![1]);
    }

    #[test]
    fn test_random_subtract() {
        const C: usize = 10;
        const N: usize = 10_000;

        let mut rng = StdRng::seed_from_u64(42);
        let mut vecs = Vec::with_capacity(C);
        for _ in 0..C {
            let mut vec = Vec::with_capacity(N);
            for _ in 0..N {
                vec.push(rng.gen::<u16>());
            }
            vec.sort_unstable();
            vec.dedup();
            vecs.push(vec);
        }
        let left_iter = vecs[0].iter();
        let right_iters: Vec<_> = vecs.iter().skip(1).map(|v| v.iter()).collect();
        let res: Vec<u16> = subtract_iters(left_iter, right_iters).copied().collect();

        // Brute-force reference: the left elements that no right vector contains,
        // in their original order.
        let excluded: HashSet<u16> = vecs.iter().skip(1).flatten().copied().collect();
        let expected: Vec<u16> = vecs[0]
            .iter()
            .copied()
            .filter(|x| !excluded.contains(x))
            .collect();

        // An exact comparison checks soundness (no excluded value leaks through),
        // completeness (no surviving value is lost) and ordering at once.
        assert_eq!(res, expected);
    }
}