1use crate::rhs::RHS;
2
/// Which left-hand elements a join keeps; chosen by the `JoinableGrouped`
/// constructor methods.
///
/// `Inner` and `Outer` are only used by `JoinedGrouped`, while `Semi` and `Anti`
/// are only used by `JoinedLeft`; each iterator treats the other pair as
/// unreachable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum JoinType {
    /// Keep left elements that have at least one match, paired with their matches.
    Inner,
    /// Keep every left element, paired with its (possibly empty) list of matches.
    Outer,
    /// Keep left elements that have a match, without the matches themselves.
    Semi,
    /// Keep left elements that have no match.
    Anti,
}
9
10pub trait JoinableGrouped<'a, LIt, R, P, L> {
16 fn inner_join_grouped(
24 self,
25 rhs: impl Into<RHS<'a, R>>,
26 predicate: P,
27 ) -> JoinedGrouped<'a, LIt, R, P>;
28
29 fn outer_join_grouped(
37 self,
38 rhs: impl Into<RHS<'a, R>>,
39 predicate: P,
40 ) -> JoinedGrouped<'a, LIt, R, P>;
41
42 fn semi_join(self, rhs: impl Into<RHS<'a, R>>, predicate: P) -> JoinedLeft<'a, LIt, R, P>;
48
49 fn anti_join(self, rhs: impl Into<RHS<'a, R>>, predicate: P) -> JoinedLeft<'a, LIt, R, P>;
55}
56
57impl<'a, LIt, R, P, L> JoinableGrouped<'a, LIt, R, P, L> for LIt
58where
59 LIt: Iterator<Item = L>,
60 L: 'a,
61 R: 'a,
62 P: Fn(&L, &R) -> std::cmp::Ordering,
63{
64 fn inner_join_grouped(
65 self,
66 rhs: impl Into<RHS<'a, R>>,
67 predicate: P,
68 ) -> JoinedGrouped<'a, LIt, R, P> {
69 JoinedGrouped {
70 lhs_iter: self,
71 rhs: rhs.into(),
72 predicate,
73 join_type: crate::joined_grouped::JoinType::Inner,
74 }
75 }
76
77 fn outer_join_grouped(
78 self,
79 rhs: impl Into<RHS<'a, R>>,
80 predicate: P,
81 ) -> JoinedGrouped<'a, LIt, R, P> {
82 JoinedGrouped {
83 lhs_iter: self,
84 rhs: rhs.into(),
85 predicate,
86 join_type: crate::joined_grouped::JoinType::Outer,
87 }
88 }
89
90 fn semi_join(self, rhs: impl Into<RHS<'a, R>>, predicate: P) -> JoinedLeft<'a, LIt, R, P> {
91 JoinedLeft {
92 lhs_iter: self,
93 rhs: rhs.into(),
94 predicate,
95 join_type: crate::joined_grouped::JoinType::Semi,
96 }
97 }
98
99 fn anti_join(self, rhs: impl Into<RHS<'a, R>>, predicate: P) -> JoinedLeft<'a, LIt, R, P> {
100 JoinedLeft {
101 lhs_iter: self,
102 rhs: rhs.into(),
103 predicate,
104 join_type: crate::joined_grouped::JoinType::Anti,
105 }
106 }
107}
108
109pub struct JoinedGrouped<'a, LIt, R, P> {
111 lhs_iter: LIt,
113
114 rhs: RHS<'a, R>,
116
117 predicate: P,
119
120 join_type: JoinType,
122}
123
124impl<'a, LIt, R, P, L> Iterator for JoinedGrouped<'a, LIt, R, P>
125where
126 LIt: Iterator<Item = L>,
127 L: 'a,
128 R: 'a,
129 P: Fn(&L, &R) -> std::cmp::Ordering,
130{
131 type Item = (L, Vec<&'a R>);
132
133 fn next(&mut self) -> Option<Self::Item> {
134 loop {
135 let left = self.lhs_iter.next()?;
136
137 let rs = match self.rhs {
138 RHS::Unsorted(inner) => inner
139 .iter()
140 .filter(|r| (self.predicate)(&left, r).is_eq())
141 .collect::<Vec<_>>(),
142 RHS::Sorted(inner) => {
143 match inner.binary_search_by(|r| (self.predicate)(&left, r)) {
144 Ok(mut pos) => {
145 let mut rs = Vec::new();
146 while pos > 0 && (self.predicate)(&left, &inner[pos - 1]).is_eq() {
148 pos -= 1;
149 }
150
151 while pos < inner.len() && (self.predicate)(&left, &inner[pos]).is_eq()
153 {
154 rs.push(&inner[pos]);
155 pos += 1;
156 }
157
158 rs
159 }
160 Err(_) => vec![],
161 }
162 }
163 };
164
165 match self.join_type {
166 JoinType::Inner => {
167 if !rs.is_empty() {
168 return Some((left, rs));
169 }
170 }
171
172 JoinType::Outer => return Some((left, rs)),
173
174 JoinType::Semi => unreachable!(),
175 JoinType::Anti => unreachable!(),
176 }
177 }
178 }
179}
180
181pub struct JoinedLeft<'a, LIt, R, P> {
183 lhs_iter: LIt,
184 rhs: RHS<'a, R>,
185 predicate: P,
186 join_type: JoinType,
187}
188
189impl<'a, LIt, R, P, L> Iterator for JoinedLeft<'a, LIt, R, P>
190where
191 LIt: Iterator<Item = L>,
192 P: Fn(&L, &R) -> std::cmp::Ordering,
193{
194 type Item = L;
195
196 fn next(&mut self) -> Option<Self::Item> {
197 loop {
198 let left = self.lhs_iter.next()?;
199
200 let has_right = self.rhs.has_value(&left, &self.predicate);
201
202 match self.join_type {
203 JoinType::Semi if has_right => return Some(left),
204 JoinType::Anti if !has_right => return Some(left),
205
206 JoinType::Semi => {}
207 JoinType::Anti => {}
208
209 JoinType::Inner => unreachable!(),
210 JoinType::Outer => unreachable!(),
211 }
212 }
213 }
214}
215
/// Left-hand test fixture: keys `0..=10`, with key `0` listed twice so that
/// several left rows share the same right-hand partner.
#[cfg(test)]
const LEFT_ITEMS: [(usize, &str); 12] = [
    (0, "zero"),
    (0, "nil"),
    (1, "one"),
    (2, "two"),
    (3, "three"),
    (4, "four"),
    (5, "five"),
    (6, "six"),
    (7, "seven"),
    (8, "eight"),
    (9, "nine"),
    (10, "ten"),
];
231
/// Right-hand test fixture: keys `0..=4` in ascending order, with several
/// entries for keys `1` and `2` so grouped joins must collect multiple matches.
#[cfg(test)]
const RIGHT_ITEMS: [(usize, &str); 8] = [
    (0, "zéro"),
    (1, "un"),
    (1, "uno"),
    (1, "ichi"),
    (2, "dos"),
    (2, "deux"),
    (3, "trois"),
    (4, "quatre"),
];
243
#[test]
fn test_left_semi() {
    // Left keys 0..=4 each have at least one right-hand partner; duplicates on
    // the left (key 0) are all kept, in their original order.
    type Row = (usize, &'static str);

    let joined: Vec<&Row> = LEFT_ITEMS
        .iter()
        .semi_join(&RIGHT_ITEMS[..], |l, r| l.0.cmp(&r.0))
        .collect();

    let expected: Vec<&Row> = vec![
        &(0, "zero"),
        &(0, "nil"),
        &(1, "one"),
        &(2, "two"),
        &(3, "three"),
        &(4, "four"),
    ];
    assert_eq!(joined, expected);
}
260
#[test]
fn test_left_anti() {
    // Only left keys with no right-hand partner (5..=10) survive an anti join.
    type Row = (usize, &'static str);

    let joined: Vec<&Row> = LEFT_ITEMS
        .iter()
        .anti_join(&RIGHT_ITEMS[..], |l, r| l.0.cmp(&r.0))
        .collect();

    let expected: Vec<&Row> = vec![
        &(5, "five"),
        &(6, "six"),
        &(7, "seven"),
        &(8, "eight"),
        &(9, "nine"),
        &(10, "ten"),
    ];
    assert_eq!(joined, expected);
}
277
#[test]
fn test_left_inner_grouped() {
    // Every matched left row carries all of its right-hand matches, in
    // right-hand order; unmatched left rows are dropped.
    type Row = (usize, &'static str);

    let joined: Vec<(&Row, Vec<&Row>)> = LEFT_ITEMS
        .iter()
        .inner_join_grouped(&RIGHT_ITEMS[..], |l, r| l.0.cmp(&r.0))
        .collect();

    let expected: Vec<(&Row, Vec<&Row>)> = vec![
        (&(0, "zero"), vec![&(0, "zéro")]),
        (&(0, "nil"), vec![&(0, "zéro")]),
        (&(1, "one"), vec![&(1, "un"), &(1, "uno"), &(1, "ichi")]),
        (&(2, "two"), vec![&(2, "dos"), &(2, "deux")]),
        (&(3, "three"), vec![&(3, "trois")]),
        (&(4, "four"), vec![&(4, "quatre")]),
    ];
    assert_eq!(joined, expected);
}
302
#[test]
fn test_left_outer_grouped() {
    // Same groups as the inner join, plus every unmatched left row paired
    // with an empty match list.
    type Row = (usize, &'static str);

    let joined: Vec<(&Row, Vec<&Row>)> = LEFT_ITEMS
        .iter()
        .outer_join_grouped(&RIGHT_ITEMS[..], |l, r| l.0.cmp(&r.0))
        .collect();

    let expected: Vec<(&Row, Vec<&Row>)> = vec![
        (&(0, "zero"), vec![&(0, "zéro")]),
        (&(0, "nil"), vec![&(0, "zéro")]),
        (&(1, "one"), vec![&(1, "un"), &(1, "uno"), &(1, "ichi")]),
        (&(2, "two"), vec![&(2, "dos"), &(2, "deux")]),
        (&(3, "three"), vec![&(3, "trois")]),
        (&(4, "four"), vec![&(4, "quatre")]),
        (&(5, "five"), vec![]),
        (&(6, "six"), vec![]),
        (&(7, "seven"), vec![]),
        (&(8, "eight"), vec![]),
        (&(9, "nine"), vec![]),
        (&(10, "ten"), vec![]),
    ];
    assert_eq!(joined, expected);
}