noir_compute/operator/join/
local_sort_merge.rs

1#![allow(clippy::type_complexity)]
2
3use std::collections::VecDeque;
4use std::fmt::Display;
5use std::marker::PhantomData;
6
7use crate::block::{BlockStructure, OperatorStructure};
8use crate::operator::join::ship::{ShipBroadcastRight, ShipHash, ShipStrategy};
9use crate::operator::join::{InnerJoinTuple, JoinVariant, LeftJoinTuple, OuterJoinTuple};
10use crate::operator::start::{BinaryElement, BinaryStartOperator};
11use crate::operator::{Data, DataKey, ExchangeData, KeyerFn, Operator, StreamElement};
12use crate::scheduler::ExecutionMetadata;
13use crate::stream::{KeyedStream, Stream};
14use crate::worker::replica_coord;
15
16/// This operator performs the join using the local sort-merge strategy.
17///
18/// This operator is able to produce the outer join tuples (the most general type of join), but it
19/// can be asked to skip generating the `None` tuples if the join was actually inner.
20#[derive(Clone, Debug)]
21struct JoinLocalSortMerge<
22    Key: Data + Ord,
23    Out1: ExchangeData,
24    Out2: ExchangeData,
25    Keyer1: KeyerFn<Key, Out1>,
26    Keyer2: KeyerFn<Key, Out2>,
27    OperatorChain: Operator<Out = BinaryElement<Out1, Out2>>,
28> {
29    prev: OperatorChain,
30
31    keyer1: Keyer1,
32    keyer2: Keyer2,
33
34    /// Whether the left side has ended.
35    left_ended: bool,
36    /// Whether the right side has ended.
37    right_ended: bool,
38    /// Elements of the left side.
39    left: Vec<(Key, Out1)>,
40    /// Elements of the right side.
41    right: Vec<(Key, Out2)>,
42    /// Buffer with elements ready to be sent downstream.
43    buffer: VecDeque<(Key, OuterJoinTuple<Out1, Out2>)>,
44    /// Join variant.
45    variant: JoinVariant,
46    /// The last key of the last element processed by `advance()` coming from the left side.
47    /// This is used to check whether an element of the right side was matched with an element
48    /// of the left side or not.
49    last_left_key: Option<Key>,
50}
51
52impl<
53        Key: Data + Ord,
54        Out1: ExchangeData,
55        Out2: ExchangeData,
56        Keyer1: KeyerFn<Key, Out1>,
57        Keyer2: KeyerFn<Key, Out2>,
58        OperatorChain: Operator<Out = BinaryElement<Out1, Out2>>,
59    > Display for JoinLocalSortMerge<Key, Out1, Out2, Keyer1, Keyer2, OperatorChain>
60{
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        write!(
63            f,
64            "{} -> JoinLocalSortMerge<{}>",
65            self.prev,
66            std::any::type_name::<Key>()
67        )
68    }
69}
70impl<
71        Key: Data + Ord,
72        Out1: ExchangeData,
73        Out2: ExchangeData,
74        Keyer1: KeyerFn<Key, Out1>,
75        Keyer2: KeyerFn<Key, Out2>,
76        OperatorChain: Operator<Out = BinaryElement<Out1, Out2>>,
77    > JoinLocalSortMerge<Key, Out1, Out2, Keyer1, Keyer2, OperatorChain>
78{
79    fn new(prev: OperatorChain, variant: JoinVariant, keyer1: Keyer1, keyer2: Keyer2) -> Self {
80        Self {
81            prev,
82            keyer1,
83            keyer2,
84            left_ended: false,
85            right_ended: false,
86            left: Default::default(),
87            right: Default::default(),
88            buffer: Default::default(),
89            variant,
90            last_left_key: None,
91        }
92    }
93
94    /// Discard the last element in the buffer containing elements of the right side.
95    /// If needed, generate the right-outer join tuple.
96    fn discard_right(&mut self) {
97        let (rkey, rvalue) = self.right.pop().unwrap();
98
99        // check if the element has been matched with at least one left-side element
100        let matched = matches!(&self.last_left_key, Some(lkey) if lkey == &rkey);
101
102        if !matched && self.variant.right_outer() {
103            self.buffer.push_back((rkey, (None, Some(rvalue))));
104        }
105    }
106
107    /// Generate some join tuples. Since the number of join tuples can be quite high,
108    /// this is used to generate the tuples incrementally, so that the memory usage is lower.
109    fn advance(&mut self) {
110        while self.buffer.is_empty() && (!self.left.is_empty() || !self.right.is_empty()) {
111            // try matching one element of the left side with some elements of the right side
112            if let Some((lkey, lvalue)) = self.left.pop() {
113                // discard the elements of the right side with key bigger than the key of
114                // the element of the left side
115                let discarded = self
116                    .right
117                    .iter()
118                    .rev()
119                    .take_while(|(rkey, _)| rkey > &lkey)
120                    .count();
121                for _ in 0..discarded {
122                    self.discard_right();
123                }
124
125                // check if there is at least one element matching in the right side
126                let has_matches = matches!(self.right.last(), Some((rkey, _)) if rkey == &lkey);
127
128                if has_matches {
129                    let matches = self
130                        .right
131                        .iter()
132                        .rev()
133                        .take_while(|(rkey, _)| &lkey == rkey)
134                        .map(|(_, rvalue)| {
135                            (lkey.clone(), (Some(lvalue.clone()), Some(rvalue.clone())))
136                        });
137                    self.buffer.extend(matches);
138                } else if self.variant.left_outer() {
139                    self.buffer.push_back((lkey.clone(), (Some(lvalue), None)));
140                }
141
142                // set this key as the last key processed
143                self.last_left_key = Some(lkey);
144            } else {
145                // there are no elements left in the left side,
146                // so discard what is remaining in the right side
147                while !self.right.is_empty() {
148                    self.discard_right();
149                }
150            }
151        }
152    }
153}
154
155impl<
156        Key: Data + Ord,
157        Out1: ExchangeData,
158        Out2: ExchangeData,
159        Keyer1: KeyerFn<Key, Out1>,
160        Keyer2: KeyerFn<Key, Out2>,
161        OperatorChain: Operator<Out = BinaryElement<Out1, Out2>>,
162    > Operator for JoinLocalSortMerge<Key, Out1, Out2, Keyer1, Keyer2, OperatorChain>
163{
164    type Out = (Key, OuterJoinTuple<Out1, Out2>);
165
166    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
167        self.prev.setup(metadata);
168    }
169
170    fn next(&mut self) -> StreamElement<(Key, (Option<Out1>, Option<Out2>))> {
171        loop {
172            if self.buffer.is_empty() && self.left_ended && self.right_ended {
173                // try to generate some join tuples
174                self.advance();
175            }
176
177            if let Some(item) = self.buffer.pop_front() {
178                return StreamElement::Item(item);
179            }
180
181            match self.prev.next() {
182                StreamElement::Item(BinaryElement::Left(item)) => {
183                    self.left.push(((self.keyer1)(&item), item));
184                }
185                StreamElement::Item(BinaryElement::Right(item)) => {
186                    self.right.push(((self.keyer2)(&item), item));
187                }
188                StreamElement::Item(BinaryElement::LeftEnd) => {
189                    self.left_ended = true;
190                    self.left.sort_unstable_by(|(k1, _), (k2, _)| k1.cmp(k2));
191                }
192                StreamElement::Item(BinaryElement::RightEnd) => {
193                    self.right_ended = true;
194                    self.right.sort_unstable_by(|(k1, _), (k2, _)| k1.cmp(k2));
195                }
196                StreamElement::Timestamped(_, _) | StreamElement::Watermark(_) => {
197                    panic!("Cannot join timestamp streams")
198                }
199                StreamElement::FlushAndRestart => {
200                    assert!(self.left_ended, "{} left missing", replica_coord().unwrap());
201                    assert!(
202                        self.right_ended,
203                        "{} right missing",
204                        replica_coord().unwrap()
205                    );
206                    assert!(self.left.is_empty());
207                    assert!(self.right.is_empty());
208
209                    // reset the state of the operator
210                    self.left_ended = false;
211                    self.right_ended = false;
212                    self.last_left_key = None;
213
214                    return StreamElement::FlushAndRestart;
215                }
216                StreamElement::FlushBatch => return StreamElement::FlushBatch,
217                StreamElement::Terminate => return StreamElement::Terminate,
218            }
219        }
220    }
221
222    fn structure(&self) -> BlockStructure {
223        self.prev.structure().add_operator(OperatorStructure::new::<
224            (Key, OuterJoinTuple<Out1, Out2>),
225            _,
226        >("JoinLocalSortMerge"))
227    }
228}
229
230/// This is an intermediate type for building a join operator.
231///
232/// The ship strategy has already been selected and it's stored in `ShipStrat`, the local strategy
233/// is hash and now the join variant has to be selected.
234///
235/// Note that `outer` join is not supported if the ship strategy is `broadcast_right`.
236pub struct JoinStreamLocalSortMerge<
237    Key: Data + Ord,
238    Out1: ExchangeData,
239    Out2: ExchangeData,
240    Keyer1: KeyerFn<Key, Out1>,
241    Keyer2: KeyerFn<Key, Out2>,
242    ShipStrat: ShipStrategy,
243> {
244    stream: Stream<BinaryStartOperator<Out1, Out2>>,
245    keyer1: Keyer1,
246    keyer2: Keyer2,
247    _key: PhantomData<Key>,
248    _s: PhantomData<ShipStrat>,
249}
250
251impl<Key: Data + Ord, Out1: ExchangeData, Out2: ExchangeData, Keyer1, Keyer2, ShipStrat>
252    JoinStreamLocalSortMerge<Key, Out1, Out2, Keyer1, Keyer2, ShipStrat>
253where
254    Keyer1: KeyerFn<Key, Out1>,
255    Keyer2: KeyerFn<Key, Out2>,
256    ShipStrat: ShipStrategy,
257{
258    pub(crate) fn new(
259        stream: Stream<BinaryStartOperator<Out1, Out2>>,
260        keyer1: Keyer1,
261        keyer2: Keyer2,
262    ) -> Self {
263        Self {
264            stream,
265            keyer1,
266            keyer2,
267            _key: Default::default(),
268            _s: Default::default(),
269        }
270    }
271}
272
273impl<Key: DataKey + Ord, Out1: ExchangeData, Out2: ExchangeData, Keyer1, Keyer2>
274    JoinStreamLocalSortMerge<Key, Out1, Out2, Keyer1, Keyer2, ShipHash>
275where
276    Keyer1: KeyerFn<Key, Out1>,
277    Keyer2: KeyerFn<Key, Out2>,
278{
279    /// Finalize the join operator by specifying that this is an _inner join_.
280    ///
281    /// Given two stream, create a stream with all the pairs (left item from the left stream, right
282    /// item from the right), such that the key obtained with `keyer1` on an item from the left is
283    /// equal to the key obtained with `keyer2` on an item from the right.
284    ///
285    /// This is an inner join, very similarly to `SELECT a, b FROM a JOIN b ON keyer1(a) = keyer2(b)`.
286    ///
287    /// **Note**: this operator will split the current block.
288    pub fn inner(self) -> KeyedStream<impl Operator<Out = (Key, InnerJoinTuple<Out1, Out2>)>> {
289        let keyer1 = self.keyer1;
290        let keyer2 = self.keyer2;
291        let inner = self
292            .stream
293            .add_operator(|prev| JoinLocalSortMerge::new(prev, JoinVariant::Inner, keyer1, keyer2));
294        KeyedStream(inner.map(|(key, (lhs, rhs))| (key, (lhs.unwrap(), rhs.unwrap()))))
295    }
296
297    /// Finalize the join operator by specifying that this is a _left join_.
298    ///
299    /// Given two stream, create a stream with all the pairs (left item from the left stream, right
300    /// item from the right), such that the key obtained with `keyer1` on an item from the left is
301    /// equal to the key obtained with `keyer2` on an item from the right.
302    ///
303    /// This is a **left** join, meaning that if an item from the left does not find and element
304    /// from the right with which make a pair, an extra pair `(left, None)` is generated. If you
305    /// want to have a _right_ join, you just need to switch the two sides and use a left join.
306    ///
307    /// This is very similar to `SELECT a, b FROM a LEFT JOIN b ON keyer1(a) = keyer2(b)`.    
308    ///
309    /// **Note**: this operator will split the current block.
310    pub fn left(self) -> KeyedStream<impl Operator<Out = (Key, LeftJoinTuple<Out1, Out2>)>> {
311        let keyer1 = self.keyer1;
312        let keyer2 = self.keyer2;
313        let inner = self
314            .stream
315            .add_operator(|prev| JoinLocalSortMerge::new(prev, JoinVariant::Left, keyer1, keyer2));
316        KeyedStream(inner.map(|(key, (lhs, rhs))| (key, (lhs.unwrap(), rhs))))
317    }
318
319    /// Finalize the join operator by specifying that this is an _outer join_.
320    ///
321    /// Given two stream, create a stream with all the pairs (left item from the left stream, right
322    /// item from the right), such that the key obtained with `keyer1` on an item from the left is
323    /// equal to the key obtained with `keyer2` on an item from the right.
324    ///
325    /// This is a **full-outer** join, meaning that if an item from the left does not find and element
326    /// from the right with which make a pair, an extra pair `(left, None)` is generated. Similarly
327    /// if an element from the right does not appear in any pair, a new one is generated with
328    /// `(None, right)`.
329    ///
330    /// This is very similar to `SELECT a, b FROM a FULL OUTER JOIN b ON keyer1(a) = keyer2(b)`.
331    ///
332    /// **Note**: this operator will split the current block.
333    pub fn outer(self) -> KeyedStream<impl Operator<Out = (Key, OuterJoinTuple<Out1, Out2>)>> {
334        let keyer1 = self.keyer1;
335        let keyer2 = self.keyer2;
336        let inner = self
337            .stream
338            .add_operator(|prev| JoinLocalSortMerge::new(prev, JoinVariant::Outer, keyer1, keyer2));
339        KeyedStream(inner)
340    }
341}
342
343impl<Key: Data + Ord, Out1: ExchangeData, Out2: ExchangeData, Keyer1, Keyer2>
344    JoinStreamLocalSortMerge<Key, Out1, Out2, Keyer1, Keyer2, ShipBroadcastRight>
345where
346    Keyer1: KeyerFn<Key, Out1>,
347    Keyer2: KeyerFn<Key, Out2>,
348{
349    /// Finalize the join operator by specifying that this is an _inner join_.
350    ///
351    /// Given two stream, create a stream with all the pairs (left item from the left stream, right
352    /// item from the right), such that the key obtained with `keyer1` on an item from the left is
353    /// equal to the key obtained with `keyer2` on an item from the right.
354    ///
355    /// This is an inner join, very similarly to `SELECT a, b FROM a JOIN b ON keyer1(a) = keyer2(b)`.
356    ///
357    /// **Note**: this operator will split the current block.
358    pub fn inner(self) -> Stream<impl Operator<Out = (Key, InnerJoinTuple<Out1, Out2>)>> {
359        let keyer1 = self.keyer1;
360        let keyer2 = self.keyer2;
361        self.stream
362            .add_operator(|prev| JoinLocalSortMerge::new(prev, JoinVariant::Inner, keyer1, keyer2))
363            .map(|(key, (lhs, rhs))| (key, (lhs.unwrap(), rhs.unwrap())))
364    }
365
366    /// Finalize the join operator by specifying that this is a _left join_.
367    ///
368    /// Given two stream, create a stream with all the pairs (left item from the left stream, right
369    /// item from the right), such that the key obtained with `keyer1` on an item from the left is
370    /// equal to the key obtained with `keyer2` on an item from the right.
371    ///
372    /// This is a **left** join, meaning that if an item from the left does not find and element
373    /// from the right with which make a pair, an extra pair `(left, None)` is generated. If you
374    /// want to have a _right_ join, you just need to switch the two sides and use a left join.
375    ///
376    /// This is very similar to `SELECT a, b FROM a LEFT JOIN b ON keyer1(a) = keyer2(b)`.    
377    ///
378    /// **Note**: this operator will split the current block.
379    pub fn left(self) -> Stream<impl Operator<Out = (Key, LeftJoinTuple<Out1, Out2>)>> {
380        let keyer1 = self.keyer1;
381        let keyer2 = self.keyer2;
382        self.stream
383            .add_operator(|prev| JoinLocalSortMerge::new(prev, JoinVariant::Left, keyer1, keyer2))
384            .map(|(key, (lhs, rhs))| (key, (lhs.unwrap(), rhs)))
385    }
386}