noir_compute/operator/join/
keyed_join.rs

1use std::{
2    collections::{HashMap, HashSet, VecDeque},
3    fmt::{Debug, Display},
4    marker::PhantomData,
5};
6
7use crate::{
8    block::{NextStrategy, OperatorStructure},
9    network::Coord,
10    operator::{
11        BinaryElement, BinaryStartOperator, Data, DataKey, ExchangeData, Operator, Start,
12        StreamElement,
13    },
14    KeyedStream,
15};
16
17use super::{InnerJoinTuple, JoinVariant, OuterJoinTuple};
18
19type BinaryTuple<K, V1, V2> = BinaryElement<(K, V1), (K, V2)>;
20
21/// This type keeps the elements of a side of the join.
22#[derive(Debug, Clone)]
23struct SideHashMap<Key: DataKey, Out> {
24    /// The actual items on this side, grouped by key.
25    ///
26    /// Note that when the other side ends this map is emptied.
27    data: HashMap<Key, Vec<Out>, crate::block::GroupHasherBuilder>,
28    /// The set of all the keys seen.
29    ///
30    /// Note that when this side ends this set is emptied since it won't be used again.
31    keys: HashSet<Key>,
32    /// Whether this side has ended.
33    ended: bool,
34    /// The number of items received.
35    count: usize,
36}
37
38impl<Key: DataKey, Out> Default for SideHashMap<Key, Out> {
39    fn default() -> Self {
40        Self {
41            data: Default::default(),
42            keys: Default::default(),
43            ended: false,
44            count: 0,
45        }
46    }
47}
48
49#[derive(Clone)]
50struct JoinKeyedOuter<K: DataKey + ExchangeData, V1: ExchangeData, V2: ExchangeData> {
51    prev: BinaryStartOperator<(K, V1), (K, V2)>,
52    variant: JoinVariant,
53    _k: PhantomData<K>,
54    _v1: PhantomData<V1>,
55    _v2: PhantomData<V2>,
56    coord: Option<Coord>,
57
58    /// The content of the left side.
59    left: SideHashMap<K, V1>,
60    /// The content of the right side.
61    right: SideHashMap<K, V2>,
62
63    buffer: VecDeque<(K, OuterJoinTuple<V1, V2>)>,
64}
65
66impl<K: DataKey + ExchangeData, V1: ExchangeData, V2: ExchangeData> JoinKeyedOuter<K, V1, V2> {
67    pub(crate) fn new(prev: BinaryStartOperator<(K, V1), (K, V2)>, variant: JoinVariant) -> Self {
68        JoinKeyedOuter {
69            prev,
70            variant,
71            _k: PhantomData,
72            _v1: PhantomData,
73            _v2: PhantomData,
74            coord: Default::default(),
75            left: Default::default(),
76            right: Default::default(),
77            buffer: Default::default(),
78        }
79    }
80
81    fn process_item(&mut self, item: BinaryTuple<K, V1, V2>) {
82        let left_outer = self.variant.left_outer();
83        let right_outer = self.variant.right_outer();
84        match item {
85            BinaryElement::Left((key, v1)) => {
86                self.left.count += 1;
87                if let Some(right) = self.right.data.get(&key) {
88                    // the left item has at least one right matching element
89                    for v2 in right {
90                        self.buffer
91                            .push_back((key.clone(), (Some(v1.clone()), Some(v2.clone()))));
92                    }
93                } else if self.right.ended && left_outer {
94                    // if the left item has no right correspondent, but the right has already ended
95                    // we might need to generate the outer tuple.
96                    self.buffer
97                        .push_back((key.clone(), (Some(v1.clone()), None)));
98                }
99                if right_outer {
100                    self.left.keys.insert(key.clone());
101                }
102                if !self.right.ended {
103                    self.left.data.entry(key).or_default().push(v1);
104                }
105            }
106            BinaryElement::Right((key, v2)) => {
107                self.right.count += 1;
108                if let Some(left) = self.left.data.get(&key) {
109                    // the left item has at least one right matching element
110                    for v1 in left {
111                        self.buffer
112                            .push_back((key.clone(), (Some(v1.clone()), Some(v2.clone()))));
113                    }
114                } else if self.left.ended && right_outer {
115                    // if the left item has no right correspondent, but the right has already ended
116                    // we might need to generate the outer tuple.
117                    self.buffer
118                        .push_back((key.clone(), (None, Some(v2.clone()))));
119                }
120                if left_outer {
121                    self.right.keys.insert(key.clone());
122                }
123                if !self.left.ended {
124                    self.right.data.entry(key).or_default().push(v2);
125                }
126            }
127            BinaryElement::LeftEnd => {
128                log::debug!(
129                    "Left side of join ended with {} elements on the left \
130                    and {} elements on the right",
131                    self.left.count,
132                    self.right.count
133                );
134                if right_outer {
135                    // left ended and this is a right-outer, so we need to generate (None, Some)
136                    // tuples. For each value on the right side, before dropping the right hashmap,
137                    // search if there was already a match.
138                    for (key, right) in self.right.data.drain() {
139                        if !self.left.keys.contains(&key) {
140                            for v2 in right {
141                                self.buffer.push_back((key.clone(), (None, Some(v2))));
142                            }
143                        }
144                    }
145                } else {
146                    // in any case, we won't need the right hashmap anymore.
147                    self.right.data.clear();
148                }
149                // we will never look at it, and nothing will be inserted, drop it freeing some memory.
150                self.left.keys.clear();
151                self.left.ended = true;
152            }
153            BinaryElement::RightEnd => {
154                log::debug!(
155                    "Right side of join ended with {} elements on the left \
156                    and {} elements on the right",
157                    self.left.count,
158                    self.right.count
159                );
160                if left_outer {
161                    // right ended and this is a left-outer, so we need to generate (None, Some)
162                    // tuples. For each value on the left side, before dropping the left hashmap,
163                    // search if there was already a match.
164                    for (key, left) in self.left.data.drain() {
165                        if !self.right.keys.contains(&key) {
166                            for v1 in left {
167                                self.buffer.push_back((key.clone(), (Some(v1), None)));
168                            }
169                        }
170                    }
171                } else {
172                    // in any case, we won't need the left hashmap anymore.
173                    self.left.data.clear();
174                }
175                // we will never look at it, and nothing will be inserted, drop it freeing some memory.
176                self.right.keys.clear();
177                self.right.ended = true;
178            }
179        }
180    }
181}
182
183impl<K: DataKey + ExchangeData, V1: ExchangeData, V2: ExchangeData> Display
184    for JoinKeyedOuter<K, V1, V2>
185{
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        write!(
188            f,
189            "{} -> JoinKeyed<{},{},{}>",
190            self.prev,
191            std::any::type_name::<K>(),
192            std::any::type_name::<V1>(),
193            std::any::type_name::<V2>(),
194        )
195    }
196}
197
198impl<K: DataKey + ExchangeData, V1: ExchangeData, V2: ExchangeData> Operator
199    for JoinKeyedOuter<K, V1, V2>
200{
201    type Out = (K, OuterJoinTuple<V1, V2>);
202
203    fn setup(&mut self, metadata: &mut crate::ExecutionMetadata) {
204        self.prev.setup(metadata);
205        self.coord = Some(metadata.coord);
206    }
207
208    fn next(&mut self) -> crate::operator::StreamElement<(K, OuterJoinTuple<V1, V2>)> {
209        while self.buffer.is_empty() {
210            match self.prev.next() {
211                StreamElement::Item(el) => self.process_item(el),
212                StreamElement::FlushAndRestart => {
213                    assert!(self.left.ended);
214                    assert!(self.right.ended);
215                    assert!(self.left.data.is_empty());
216                    assert!(self.right.data.is_empty());
217                    assert!(self.left.keys.is_empty());
218                    assert!(self.right.keys.is_empty());
219                    self.left.ended = false;
220                    self.left.count = 0;
221                    self.right.ended = false;
222                    self.right.count = 0;
223                    log::debug!(
224                        "JoinLocalHash at {} emitted FlushAndRestart",
225                        self.coord.unwrap()
226                    );
227                    return StreamElement::FlushAndRestart;
228                }
229                StreamElement::Terminate => return StreamElement::Terminate,
230                StreamElement::FlushBatch => return StreamElement::FlushBatch,
231                StreamElement::Watermark(_) | StreamElement::Timestamped(_, _) => {
232                    panic!("Cannot yet join timestamped streams")
233                }
234            }
235        }
236
237        let item = self.buffer.pop_front().unwrap();
238        StreamElement::Item(item)
239    }
240
241    fn structure(&self) -> crate::block::BlockStructure {
242        self.prev.structure().add_operator(
243            OperatorStructure::new::<(K, InnerJoinTuple<V1, V2>), _>("JoinKeyed"),
244        )
245    }
246}
247
248#[derive(Clone)]
249struct JoinKeyedInner<K: DataKey + ExchangeData, V1: ExchangeData, V2: ExchangeData> {
250    prev: BinaryStartOperator<(K, V1), (K, V2)>,
251    _k: PhantomData<K>,
252    _v1: PhantomData<V1>,
253    _v2: PhantomData<V2>,
254    coord: Option<Coord>,
255
256    /// The content of the left side.
257    left: HashMap<K, Vec<V1>, crate::block::CoordHasherBuilder>,
258    /// The content of the right side.
259    right: HashMap<K, Vec<V2>, crate::block::CoordHasherBuilder>,
260
261    left_ended: bool,
262    right_ended: bool,
263
264    buffer: VecDeque<(K, InnerJoinTuple<V1, V2>)>,
265}
266
267impl<K: DataKey + ExchangeData, V1: ExchangeData, V2: ExchangeData> Display
268    for JoinKeyedInner<K, V1, V2>
269{
270    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271        write!(
272            f,
273            "{} -> JoinKeyedInner<{},{},{}>",
274            self.prev,
275            std::any::type_name::<K>(),
276            std::any::type_name::<V1>(),
277            std::any::type_name::<V2>(),
278        )
279    }
280}
281
282impl<K: DataKey + ExchangeData + Debug, V1: ExchangeData + Debug, V2: ExchangeData + Debug>
283    JoinKeyedInner<K, V1, V2>
284{
285    pub(crate) fn new(prev: BinaryStartOperator<(K, V1), (K, V2)>) -> Self {
286        JoinKeyedInner {
287            prev,
288            _k: PhantomData,
289            _v1: PhantomData,
290            _v2: PhantomData,
291            coord: Default::default(),
292            left: Default::default(),
293            right: Default::default(),
294            buffer: Default::default(),
295            left_ended: false,
296            right_ended: false,
297        }
298    }
299
300    fn process_item(&mut self, item: BinaryTuple<K, V1, V2>) {
301        match item {
302            BinaryElement::Left((key, v1)) => {
303                if let Some(right) = self.right.get(&key) {
304                    // the left item has at least one right matching element
305                    for v2 in right {
306                        self.buffer
307                            .push_back((key.clone(), (v1.clone(), v2.clone())));
308                    }
309                }
310                self.left.entry(key).or_default().push(v1);
311            }
312            BinaryElement::Right((key, v2)) => {
313                if let Some(left) = self.left.get(&key) {
314                    // the left item has at least one right matching element
315                    for v1 in left {
316                        self.buffer
317                            .push_back((key.clone(), (v1.clone(), v2.clone())));
318                    }
319                }
320                self.right.entry(key).or_default().push(v2);
321            }
322            BinaryElement::LeftEnd => {
323                self.left_ended = true;
324                self.right.clear();
325                if self.right_ended {
326                    self.left.clear();
327                    self.right.clear();
328                }
329            }
330            BinaryElement::RightEnd => {
331                self.right_ended = true;
332                self.left.clear();
333                if self.left_ended {
334                    self.left.clear();
335                    self.right.clear();
336                }
337            }
338        }
339    }
340}
341
342impl<K: DataKey + ExchangeData + Debug, V1: ExchangeData + Debug, V2: ExchangeData + Debug> Operator
343    for JoinKeyedInner<K, V1, V2>
344{
345    type Out = (K, InnerJoinTuple<V1, V2>);
346
347    fn setup(&mut self, metadata: &mut crate::ExecutionMetadata) {
348        self.coord = Some(metadata.coord);
349        self.prev.setup(metadata);
350    }
351
352    fn next(&mut self) -> crate::operator::StreamElement<(K, InnerJoinTuple<V1, V2>)> {
353        while self.buffer.is_empty() {
354            match self.prev.next() {
355                StreamElement::Item(el) => self.process_item(el),
356                StreamElement::FlushAndRestart => {
357                    assert!(self.left.is_empty());
358                    assert!(self.right.is_empty());
359                    log::debug!(
360                        "JoinLocalHash at {} emitted FlushAndRestart",
361                        self.coord.unwrap()
362                    );
363                    self.left_ended = false;
364                    self.right_ended = false;
365                    return StreamElement::FlushAndRestart;
366                }
367                StreamElement::Terminate => return StreamElement::Terminate,
368                StreamElement::FlushBatch => return StreamElement::FlushBatch,
369                StreamElement::Watermark(_) | StreamElement::Timestamped(_, _) => {
370                    panic!("Cannot yet join timestamped streams")
371                }
372            }
373        }
374
375        let item = self.buffer.pop_front().unwrap();
376        StreamElement::Item(item)
377    }
378
379    fn structure(&self) -> crate::block::BlockStructure {
380        self.prev.structure().add_operator(
381            OperatorStructure::new::<(K, InnerJoinTuple<V1, V2>), _>("JoinKeyed"),
382        )
383    }
384}
385
386impl<K: DataKey + ExchangeData + Debug, V1: Data + ExchangeData + Debug, O1> KeyedStream<O1>
387where
388    O1: Operator<Out = (K, V1)> + 'static,
389{
390    pub fn join_outer<V2: Data + ExchangeData + Debug, O2>(
391        self,
392        rhs: KeyedStream<O2>,
393    ) -> KeyedStream<impl Operator<Out = (K, (Option<V1>, Option<V2>))>>
394    where
395        O2: Operator<Out = (K, V2)> + 'static,
396    {
397        let next_strategy1 = NextStrategy::only_one();
398        let next_strategy2 = NextStrategy::only_one();
399
400        let inner =
401            self.0
402                .binary_connection(rhs.0, Start::multiple, next_strategy1, next_strategy2);
403
404        let s = inner.add_operator(move |prev| JoinKeyedOuter::new(prev, JoinVariant::Outer));
405        KeyedStream(s)
406    }
407
408    pub fn join<V2: Data + ExchangeData + Debug, O2>(
409        self,
410        rhs: KeyedStream<O2>,
411    ) -> KeyedStream<impl Operator<Out = (K, (V1, V2))>>
412    where
413        O2: Operator<Out = (K, V2)> + 'static,
414    {
415        let next_strategy1 = NextStrategy::only_one();
416        let next_strategy2 = NextStrategy::only_one();
417
418        let inner =
419            self.0
420                .binary_connection(rhs.0, Start::multiple, next_strategy1, next_strategy2);
421
422        let s = inner.add_operator(move |prev| JoinKeyedInner::new(prev));
423        KeyedStream(s)
424    }
425}