// noir_compute/operator/window/aggr/join.rs

1use super::super::*;
2use crate::operator::merge::MergeElement;
3use crate::operator::{Data, DataKey, Operator};
4use crate::stream::KeyedStream;
5
6#[derive(Clone)]
7struct Join<L, R> {
8    left: Vec<L>,
9    right: Vec<R>,
10}
11
12impl<L: Data, R: Data> WindowAccumulator for Join<L, R> {
13    type In = MergeElement<L, R>;
14    type Out = ProductIterator<L, R>; // TODO: may have more efficient formulations
15
16    #[inline]
17    fn process(&mut self, el: Self::In) {
18        match el {
19            MergeElement::Left(l) => self.left.push(l),
20            MergeElement::Right(r) => self.right.push(r),
21        }
22    }
23
24    #[inline]
25    fn output(mut self) -> Self::Out {
26        ProductIterator::new(
27            std::mem::take(&mut self.left),
28            std::mem::take(&mut self.right),
29        )
30    }
31}
32
/// Lazy cartesian product of two owned vectors.
///
/// Yields `(left[i], right[j])` in row-major order: every element of `right`
/// is paired with `left[0]`, then with `left[1]`, and so on.
///
/// Elements are cloned on demand, so both vectors stay intact while iterating.
/// If either side is empty, nothing is yielded.
#[derive(Clone)]
struct ProductIterator<L, R> {
    left: Vec<L>,
    right: Vec<R>,
    /// Index into `left` of the next pair to yield.
    i: usize,
    /// Index into `right` of the next pair to yield. It is reset to 0 (and `i`
    /// advanced) as soon as it reaches `right.len()`.
    j: usize,
}

impl<L, R> ProductIterator<L, R> {
    /// Creates an iterator positioned at the first pair `(left[0], right[0])`.
    fn new(left: Vec<L>, right: Vec<R>) -> Self {
        Self {
            left,
            right,
            i: 0,
            j: 0,
        }
    }
}

impl<L: Clone, R: Clone> Iterator for ProductIterator<L, R> {
    type Item = (L, R);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.i >= self.left.len() || self.j >= self.right.len() {
            return None;
        }

        let ret = (self.left[self.i].clone(), self.right[self.j].clone());

        // Advance along the current row; wrap to the next row after its last column.
        self.j += 1;
        if self.j >= self.right.len() {
            self.j = 0;
            self.i += 1;
        }

        Some(ret)
    }

    /// Reports the number of pairs still to be yielded.
    ///
    /// The hint is exact whenever that count fits in `usize`. Otherwise the
    /// lower bound saturates at `usize::MAX` and the upper bound is `None`.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.i >= self.left.len() || self.j >= self.right.len() {
            return (0, Some(0));
        }
        let cols = self.right.len();
        // Rows after the current one are still complete.
        let full_rows = self.left.len() - self.i - 1;
        // The current row has already produced `j` of its `cols` pairs.
        let current_row = cols - self.j;
        let lower = full_rows.saturating_mul(cols).saturating_add(current_row);
        let upper = full_rows
            .checked_mul(cols)
            .and_then(|n| n.checked_add(current_row));
        (lower, upper)
    }
}

// Once `i` reaches `left.len()` (or `right` is empty), `next` keeps returning
// `None`, so the iterator is fused.
impl<L: Clone, R: Clone> std::iter::FusedIterator for ProductIterator<L, R> {}
72
73impl<Key, Out, OperatorChain> KeyedStream<OperatorChain>
74where
75    OperatorChain: Operator<Out = (Key, Out)> + 'static,
76    Key: ExchangeData + DataKey,
77    Out: ExchangeData,
78{
79    pub fn window_join<Out2, OperatorChain2, WindowDescr>(
80        self,
81        descr: WindowDescr,
82        right: KeyedStream<OperatorChain2>,
83    ) -> KeyedStream<impl Operator<Out = (Key, (Out, Out2))>>
84    where
85        OperatorChain2: Operator<Out = (Key, Out2)> + 'static,
86        Out2: ExchangeData,
87        WindowDescr: WindowDescription<MergeElement<Out, Out2>> + 'static,
88    {
89        let acc = Join::<Out, Out2> {
90            left: Default::default(),
91            right: Default::default(),
92        };
93
94        self.merge_distinct(right)
95            .window(descr)
96            .add_window_operator("WindowJoin", acc)
97            .flatten()
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Drains a freshly built `ProductIterator` into a vector.
    fn product<L: Clone, R: Clone>(left: Vec<L>, right: Vec<R>) -> Vec<(L, R)> {
        ProductIterator::new(left, right).collect()
    }

    #[test]
    fn product_iterator() {
        // One element on each side gives exactly one pair.
        assert_eq!(vec![(1, "asd")], product(vec![1], vec!["asd"]));

        // Row-major order: each left value is paired with every right value in turn.
        assert_eq!(
            vec![(1, 2), (1, 4), (3, 2), (3, 4), (5, 2), (5, 4)],
            product(vec![1, 3, 5], vec![2, 4])
        );

        // An empty side, whichever one it is, yields no pairs.
        let nothing: Vec<(usize, usize)> = Vec::new();
        assert_eq!(nothing, product::<usize, usize>(vec![1, 3, 5], vec![]));
        assert_eq!(nothing, product::<usize, usize>(vec![], vec![1, 3, 5]));
    }
}