// noir_compute/operator/window/aggr/join.rs
use std::iter::FusedIterator;

use super::super::*;
use crate::operator::merge::MergeElement;
use crate::operator::{Data, DataKey, Operator};
use crate::stream::KeyedStream;
5
6#[derive(Clone)]
7struct Join<L, R> {
8 left: Vec<L>,
9 right: Vec<R>,
10}
11
12impl<L: Data, R: Data> WindowAccumulator for Join<L, R> {
13 type In = MergeElement<L, R>;
14 type Out = ProductIterator<L, R>; #[inline]
17 fn process(&mut self, el: Self::In) {
18 match el {
19 MergeElement::Left(l) => self.left.push(l),
20 MergeElement::Right(r) => self.right.push(r),
21 }
22 }
23
24 #[inline]
25 fn output(mut self) -> Self::Out {
26 ProductIterator::new(
27 std::mem::take(&mut self.left),
28 std::mem::take(&mut self.right),
29 )
30 }
31}
32
/// Owning iterator over the Cartesian product of two vectors.
///
/// Pairs are yielded left-major: `left[0]` is paired with every element of
/// `right` in order, then `left[1]`, and so on. Each item is built by cloning
/// one element from each buffer. If either vector is empty, no pairs are
/// yielded.
#[derive(Clone)]
struct ProductIterator<L, R> {
    left: Vec<L>,
    right: Vec<R>,
    /// Index into `left` of the next pair to yield.
    i: usize,
    /// Index into `right` of the next pair to yield.
    j: usize,
}

impl<L, R> ProductIterator<L, R> {
    /// Creates an iterator positioned at the first pair `(left[0], right[0])`.
    fn new(left: Vec<L>, right: Vec<R>) -> Self {
        Self {
            left,
            right,
            i: 0,
            j: 0,
        }
    }
}

impl<L: Clone, R: Clone> Iterator for ProductIterator<L, R> {
    type Item = (L, R);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // Both lookups happen before any clone. Running out of left rows, or
        // having an empty `right`, ends the product.
        let l = self.left.get(self.i)?;
        let r = self.right.get(self.j)?;
        let item = (l.clone(), r.clone());

        self.j += 1;
        if self.j >= self.right.len() {
            // Finished this row: restart `right` and move to the next left element.
            self.j = 0;
            self.i += 1;
        }

        Some(item)
    }

    /// Returns the exact number of remaining pairs.
    ///
    /// The upper bound is `None` only if that count does not fit in `usize`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (rows, cols) = (self.left.len(), self.right.len());
        // Same exhaustion condition as `next`.
        if self.i >= rows || self.j >= cols {
            return (0, Some(0));
        }
        // Remainder of the current row plus every full row after it.
        // Checked arithmetic because `rows * cols` can exceed `usize` for
        // very large windows.
        let remaining = (rows - self.i - 1)
            .checked_mul(cols)
            .and_then(|full_rows| full_rows.checked_add(cols - self.j));
        match remaining {
            Some(n) => (n, Some(n)),
            None => (usize::MAX, None),
        }
    }
}

// Once `next` returns `None` it mutates nothing, and `i` only ever grows, so
// the iterator stays exhausted.
impl<L: Clone, R: Clone> FusedIterator for ProductIterator<L, R> {}
72
73impl<Key, Out, OperatorChain> KeyedStream<OperatorChain>
74where
75 OperatorChain: Operator<Out = (Key, Out)> + 'static,
76 Key: ExchangeData + DataKey,
77 Out: ExchangeData,
78{
79 pub fn window_join<Out2, OperatorChain2, WindowDescr>(
80 self,
81 descr: WindowDescr,
82 right: KeyedStream<OperatorChain2>,
83 ) -> KeyedStream<impl Operator<Out = (Key, (Out, Out2))>>
84 where
85 OperatorChain2: Operator<Out = (Key, Out2)> + 'static,
86 Out2: ExchangeData,
87 WindowDescr: WindowDescription<MergeElement<Out, Out2>> + 'static,
88 {
89 let acc = Join::<Out, Out2> {
90 left: Default::default(),
91 right: Default::default(),
92 };
93
94 self.merge_distinct(right)
95 .window(descr)
96 .add_window_operator("WindowJoin", acc)
97 .flatten()
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Drains a `ProductIterator` built from `left` and `right` into a vector.
    fn product<L: Clone, R: Clone>(left: Vec<L>, right: Vec<R>) -> Vec<(L, R)> {
        ProductIterator::new(left, right).collect()
    }

    #[test]
    fn product_iterator() {
        // Single element on each side: exactly one pair.
        assert_eq!(product(vec![1], vec!["asd"]), vec![(1, "asd")]);

        // Left-major ordering over a 3x2 product.
        assert_eq!(
            product(vec![1, 3, 5], vec![2, 4]),
            vec![(1, 2), (1, 4), (3, 2), (3, 4), (5, 2), (5, 4)]
        );

        // An empty vector on either side yields no pairs at all.
        assert!(product::<usize, usize>(vec![1, 3, 5], vec![]).is_empty());
        assert!(product::<usize, usize>(vec![], vec![1, 3, 5]).is_empty());
    }
}