joinable/
joined_grouped.rs

1use crate::rhs::RHS;
2
/// The kind of join an intermediate join iterator performs.
///
/// `Inner` and `Outer` are used by [`JoinedGrouped`]; `Semi` and `Anti` are used by
/// [`JoinedLeft`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum JoinType {
    /// Keep only left records that have at least one match; yield them with their matches.
    Inner,
    /// Keep every left record; yield it with its (possibly empty) matches.
    Outer,
    /// Keep only left records that have at least one match; yield the left record alone.
    Semi,
    /// Keep only left records that have no match; yield the left record alone.
    Anti,
}
9
/// A trait allowing the joining of a left-hand side (LHS) and a right-hand side ([RHS]) dataset.
///
/// The iterators returned by [inner_join_grouped](JoinableGrouped::inner_join_grouped) and
/// [outer_join_grouped](JoinableGrouped::outer_join_grouped) yield one `(L, Vec<&R>)` pair per
/// retained LHS record. The `Vec<&R>` can be empty for outer joins if no match is found.
///
/// The iterators returned by [semi_join](JoinableGrouped::semi_join) and
/// [anti_join](JoinableGrouped::anti_join) yield LHS records only.
pub trait JoinableGrouped<'a, LIt, R, P, L> {
    /// Joins LHS and RHS, keeping only records from left that have one or more matches in right.
    ///
    /// The specified predicate returns a [std::cmp::Ordering] comparing left and right records;
    /// a right record matches when the predicate returns `Equal`.
    ///
    /// Like `outer_join_grouped`, this function returns an iterator of `(L, Vec<&R>)` with the
    /// matching records from RHS being collected. If multiple records from left match a given
    /// record from right, right records may be returned multiple times.
    fn inner_join_grouped(
        self,
        rhs: impl Into<RHS<'a, R>>,
        predicate: P,
    ) -> JoinedGrouped<'a, LIt, R, P>;

    /// Joins LHS and RHS, keeping _all_ records from left.
    ///
    /// The specified predicate returns a [std::cmp::Ordering] comparing left and right records;
    /// a right record matches when the predicate returns `Equal`.
    ///
    /// Like `inner_join_grouped`, this function returns an iterator of `(L, Vec<&R>)` with the
    /// matching records from RHS being collected; the `Vec` is empty for left records without a
    /// match. If multiple records from left match a given record from right, right records may
    /// be returned multiple times.
    fn outer_join_grouped(
        self,
        rhs: impl Into<RHS<'a, R>>,
        predicate: P,
    ) -> JoinedGrouped<'a, LIt, R, P>;

    /// Joins LHS and RHS, keeping all records from left that have one or more matches in right.
    ///
    /// The specified predicate returns a [std::cmp::Ordering] comparing left and right records.
    /// (Matching is delegated to `RHS::has_value`; presumably `Equal` means "match", as for the
    /// grouped joins -- confirm against the `RHS` implementation.)
    ///
    /// Like `anti_join`, this function only returns left records.
    fn semi_join(self, rhs: impl Into<RHS<'a, R>>, predicate: P) -> JoinedLeft<'a, LIt, R, P>;

    /// Joins LHS and RHS, keeping all records from left that have _no_ matches in right.
    ///
    /// The specified predicate returns a [std::cmp::Ordering] comparing left and right records.
    /// (Matching is delegated to `RHS::has_value`; presumably `Equal` means "match", as for the
    /// grouped joins -- confirm against the `RHS` implementation.)
    ///
    /// Like `semi_join`, this function only returns left records.
    fn anti_join(self, rhs: impl Into<RHS<'a, R>>, predicate: P) -> JoinedLeft<'a, LIt, R, P>;
}
56
57impl<'a, LIt, R, P, L> JoinableGrouped<'a, LIt, R, P, L> for LIt
58where
59    LIt: Iterator<Item = L>,
60    L: 'a,
61    R: 'a,
62    P: Fn(&L, &R) -> std::cmp::Ordering,
63{
64    fn inner_join_grouped(
65        self,
66        rhs: impl Into<RHS<'a, R>>,
67        predicate: P,
68    ) -> JoinedGrouped<'a, LIt, R, P> {
69        JoinedGrouped {
70            lhs_iter: self,
71            rhs: rhs.into(),
72            predicate,
73            join_type: crate::joined_grouped::JoinType::Inner,
74        }
75    }
76
77    fn outer_join_grouped(
78        self,
79        rhs: impl Into<RHS<'a, R>>,
80        predicate: P,
81    ) -> JoinedGrouped<'a, LIt, R, P> {
82        JoinedGrouped {
83            lhs_iter: self,
84            rhs: rhs.into(),
85            predicate,
86            join_type: crate::joined_grouped::JoinType::Outer,
87        }
88    }
89
90    fn semi_join(self, rhs: impl Into<RHS<'a, R>>, predicate: P) -> JoinedLeft<'a, LIt, R, P> {
91        JoinedLeft {
92            lhs_iter: self,
93            rhs: rhs.into(),
94            predicate,
95            join_type: crate::joined_grouped::JoinType::Semi,
96        }
97    }
98
99    fn anti_join(self, rhs: impl Into<RHS<'a, R>>, predicate: P) -> JoinedLeft<'a, LIt, R, P> {
100        JoinedLeft {
101            lhs_iter: self,
102            rhs: rhs.into(),
103            predicate,
104            join_type: crate::joined_grouped::JoinType::Anti,
105        }
106    }
107}
108
/// The intermediate result of an inner- or outer-join that will yield `(L, Vec<&R>)` values.
///
/// This is a lazy iterator: left records are pulled from `lhs_iter` and matched against `rhs`
/// only when the iterator is advanced.
pub struct JoinedGrouped<'a, LIt, R, P> {
    /// The iterator over all left-hand side values
    lhs_iter: LIt,

    /// A value giving us access to all right-hand side values
    rhs: RHS<'a, R>,

    /// A comparison predicate: Fn(&L, &R) -> std::cmp::Ordering
    predicate: P,

    /// The kind of join to perform. Only `Inner` and `Outer` are meaningful here; the iterator
    /// implementation treats `Semi` and `Anti` as unreachable.
    join_type: JoinType,
}
123
124impl<'a, LIt, R, P, L> Iterator for JoinedGrouped<'a, LIt, R, P>
125where
126    LIt: Iterator<Item = L>,
127    L: 'a,
128    R: 'a,
129    P: Fn(&L, &R) -> std::cmp::Ordering,
130{
131    type Item = (L, Vec<&'a R>);
132
133    fn next(&mut self) -> Option<Self::Item> {
134        loop {
135            let left = self.lhs_iter.next()?;
136
137            let rs = match self.rhs {
138                RHS::Unsorted(inner) => inner
139                    .iter()
140                    .filter(|r| (self.predicate)(&left, r).is_eq())
141                    .collect::<Vec<_>>(),
142                RHS::Sorted(inner) => {
143                    match inner.binary_search_by(|r| (self.predicate)(&left, r)) {
144                        Ok(mut pos) => {
145                            let mut rs = Vec::new();
146                            // We found *a* match, but it may not be the first one
147                            while pos > 0 && (self.predicate)(&left, &inner[pos - 1]).is_eq() {
148                                pos -= 1;
149                            }
150
151                            // Found the first; now add every one in order until we reach a different one or the end
152                            while pos < inner.len() && (self.predicate)(&left, &inner[pos]).is_eq()
153                            {
154                                rs.push(&inner[pos]);
155                                pos += 1;
156                            }
157
158                            rs
159                        }
160                        Err(_) => vec![],
161                    }
162                }
163            };
164
165            match self.join_type {
166                JoinType::Inner => {
167                    if !rs.is_empty() {
168                        return Some((left, rs));
169                    }
170                }
171
172                JoinType::Outer => return Some((left, rs)),
173
174                JoinType::Semi => unreachable!(),
175                JoinType::Anti => unreachable!(),
176            }
177        }
178    }
179}
180
/// The intermediate result of a semi- or anti-join that will yield `L` values.
///
/// This is a lazy iterator: left records are pulled from `lhs_iter` and checked against `rhs`
/// only when the iterator is advanced.
pub struct JoinedLeft<'a, LIt, R, P> {
    /// The iterator over all left-hand side values
    lhs_iter: LIt,

    /// A value giving us access to all right-hand side values
    rhs: RHS<'a, R>,

    /// A comparison predicate: Fn(&L, &R) -> std::cmp::Ordering
    predicate: P,

    /// The kind of join to perform. Only `Semi` and `Anti` are meaningful here; the iterator
    /// implementation treats `Inner` and `Outer` as unreachable.
    join_type: JoinType,
}
188
189impl<'a, LIt, R, P, L> Iterator for JoinedLeft<'a, LIt, R, P>
190where
191    LIt: Iterator<Item = L>,
192    P: Fn(&L, &R) -> std::cmp::Ordering,
193{
194    type Item = L;
195
196    fn next(&mut self) -> Option<Self::Item> {
197        loop {
198            let left = self.lhs_iter.next()?;
199
200            let has_right = self.rhs.has_value(&left, &self.predicate);
201
202            match self.join_type {
203                JoinType::Semi if has_right => return Some(left),
204                JoinType::Anti if !has_right => return Some(left),
205
206                JoinType::Semi => {}
207                JoinType::Anti => {}
208
209                JoinType::Inner => unreachable!(),
210                JoinType::Outer => unreachable!(),
211            }
212        }
213    }
214}
215
/// Left-hand fixture for the join tests: `(key, English name)`; key `0` appears twice.
#[cfg(test)]
const LEFT_ITEMS: [(usize, &str); 12] = [
    (0, "zero"),
    (0, "nil"),
    (1, "one"),
    (2, "two"),
    (3, "three"),
    (4, "four"),
    (5, "five"),
    (6, "six"),
    (7, "seven"),
    (8, "eight"),
    (9, "nine"),
    (10, "ten"),
];
231
/// Right-hand fixture for the join tests: `(key, translated name)`, ordered by key, with
/// several entries sharing keys `1` and `2`.
#[cfg(test)]
const RIGHT_ITEMS: [(usize, &str); 8] = [
    (0, "zéro"),
    (1, "un"),
    (1, "uno"),
    (1, "ichi"),
    (2, "dos"),
    (2, "deux"),
    (3, "trois"),
    (4, "quatre"),
];
243
/// A semi join keeps exactly the left records whose key occurs on the right.
#[test]
fn test_left_semi() {
    let joined: Vec<_> = LEFT_ITEMS
        .iter()
        .semi_join(&RIGHT_ITEMS[..], |l, r| l.0.cmp(&r.0))
        .collect();

    let expected: [(usize, &str); 6] = [
        (0, "zero"),
        (0, "nil"),
        (1, "one"),
        (2, "two"),
        (3, "three"),
        (4, "four"),
    ];

    assert_eq!(joined.len(), 6);

    for (actual, wanted) in joined.iter().zip(expected.iter()) {
        assert_eq!(*actual, wanted);
    }
}
260
/// An anti join keeps exactly the left records whose key never occurs on the right.
#[test]
fn test_left_anti() {
    let joined: Vec<_> = LEFT_ITEMS
        .iter()
        .anti_join(&RIGHT_ITEMS[..], |l, r| l.0.cmp(&r.0))
        .collect();

    let expected: [(usize, &str); 6] = [
        (5, "five"),
        (6, "six"),
        (7, "seven"),
        (8, "eight"),
        (9, "nine"),
        (10, "ten"),
    ];

    assert_eq!(joined.len(), 6);

    for (actual, wanted) in joined.iter().zip(expected.iter()) {
        assert_eq!(*actual, wanted);
    }
}
277
/// An inner grouped join yields each matching left record with all of its right matches.
#[test]
fn test_left_inner_grouped() {
    let joined: Vec<_> = LEFT_ITEMS
        .iter()
        .inner_join_grouped(&RIGHT_ITEMS[..], |l, r| l.0.cmp(&r.0))
        .collect();

    let expected: Vec<(&(usize, &str), Vec<&(usize, &str)>)> = vec![
        (&(0, "zero"), vec![&(0, "zéro")]),
        (&(0, "nil"), vec![&(0, "zéro")]),
        (&(1, "one"), vec![&(1, "un"), &(1, "uno"), &(1, "ichi")]),
        (&(2, "two"), vec![&(2, "dos"), &(2, "deux")]),
        (&(3, "three"), vec![&(3, "trois")]),
        (&(4, "four"), vec![&(4, "quatre")]),
    ];

    assert_eq!(joined.len(), 6);

    for (actual, wanted) in joined.into_iter().zip(expected) {
        assert_eq!(actual, wanted);
    }
}
302
/// An outer grouped join yields every left record; unmatched ones carry an empty `Vec`.
#[test]
fn test_left_outer_grouped() {
    let joined: Vec<_> = LEFT_ITEMS
        .iter()
        .outer_join_grouped(&RIGHT_ITEMS[..], |l, r| l.0.cmp(&r.0))
        .collect();

    let expected: Vec<(&(usize, &str), Vec<&(usize, &str)>)> = vec![
        (&(0, "zero"), vec![&(0, "zéro")]),
        (&(0, "nil"), vec![&(0, "zéro")]),
        (&(1, "one"), vec![&(1, "un"), &(1, "uno"), &(1, "ichi")]),
        (&(2, "two"), vec![&(2, "dos"), &(2, "deux")]),
        (&(3, "three"), vec![&(3, "trois")]),
        (&(4, "four"), vec![&(4, "quatre")]),
        // No matches here
        (&(5, "five"), vec![]),
        (&(6, "six"), vec![]),
        (&(7, "seven"), vec![]),
        (&(8, "eight"), vec![]),
        (&(9, "nine"), vec![]),
        (&(10, "ten"), vec![]),
    ];

    assert_eq!(joined.len(), 12);

    for (actual, wanted) in joined.into_iter().zip(expected) {
        assert_eq!(actual, wanted);
    }
}