1use crate::{
18 DBData, ZWeight,
19 algebra::{IndexedZSet, IndexedZSetReader, MulByRef, OrdIndexedZSet, OrdZSet},
20 circuit::{
21 Circuit, Scope, Stream,
22 operator_traits::{BinaryOperator, Operator},
23 },
24 dynamic::{DataTrait, DynUnit, Erase},
25 trace::{BatchFactories, BatchReaderFactories, Cursor},
26};
27use std::{borrow::Cow, marker::PhantomData};
28
29pub struct StreamJoinRangeFactories<I, O>
30where
31 I: IndexedZSetReader,
32 O: IndexedZSet,
33{
34 input2_factories: I::Factories,
35 output_factories: O::Factories,
36}
37
38impl<I: IndexedZSetReader, O: IndexedZSet> Clone for StreamJoinRangeFactories<I, O> {
39 fn clone(&self) -> Self {
40 Self {
41 input2_factories: self.input2_factories.clone(),
42 output_factories: self.output_factories.clone(),
43 }
44 }
45}
46
47impl<I, O> StreamJoinRangeFactories<I, O>
48where
49 I: IndexedZSetReader,
50 O: IndexedZSet,
51{
52 pub fn new<IKType, IVType, OKType, OVType>() -> Self
53 where
54 IKType: DBData + Erase<I::Key>,
55 IVType: DBData + Erase<I::Val>,
56 OKType: DBData + Erase<O::Key>,
57 OVType: DBData + Erase<O::Val>,
58 {
59 Self {
60 input2_factories: BatchReaderFactories::new::<IKType, IVType, ZWeight>(),
61 output_factories: BatchReaderFactories::new::<OKType, OVType, ZWeight>(),
62 }
63 }
64}
65
66impl<C, I1> Stream<C, I1>
67where
68 C: Circuit,
69{
70 pub fn dyn_stream_join_range<I2, V>(
72 &self,
73 factories: &StreamJoinRangeFactories<I2, OrdZSet<V>>,
74 other: &Stream<C, I2>,
75 range_func: Box<dyn Fn(&I1::Key, &mut I2::Key, &mut I2::Key)>,
76 join_func: Box<
77 dyn Fn(&I1::Key, &I1::Val, &I2::Key, &I2::Val, &mut dyn FnMut(&mut V, &mut DynUnit)),
78 >,
79 ) -> Stream<C, OrdZSet<V>>
80 where
81 I1: IndexedZSetReader + Clone,
82 I2: IndexedZSetReader + Clone,
83 V: DataTrait + ?Sized,
84 {
85 self.dyn_stream_join_range_generic(factories, other, range_func, join_func)
86 }
87
88 pub fn dyn_stream_join_range_index<K, V, I2>(
90 &self,
91 factories: &StreamJoinRangeFactories<I2, OrdIndexedZSet<K, V>>,
92 other: &Stream<C, I2>,
93 range_func: Box<dyn Fn(&I1::Key, &mut I2::Key, &mut I2::Key)>,
94 join_func: Box<
95 dyn Fn(&I1::Key, &I1::Val, &I2::Key, &I2::Val, &mut dyn FnMut(&mut K, &mut V)),
96 >,
97 ) -> Stream<C, OrdIndexedZSet<K, V>>
98 where
99 I1: IndexedZSetReader + Clone,
100 I2: IndexedZSetReader + Clone,
101 K: DataTrait + ?Sized,
102 V: DataTrait + ?Sized,
103 {
104 self.dyn_stream_join_range_generic(factories, other, range_func, join_func)
105 }
106
107 pub fn dyn_stream_join_range_generic<I2, O>(
110 &self,
111 factories: &StreamJoinRangeFactories<I2, O>,
112 other: &Stream<C, I2>,
113 range_func: Box<dyn Fn(&I1::Key, &mut I2::Key, &mut I2::Key)>,
114 join_func: Box<
115 dyn Fn(
116 &I1::Key,
117 &I1::Val,
118 &I2::Key,
119 &I2::Val,
120 &mut dyn FnMut(&mut O::Key, &mut O::Val),
121 ),
122 >,
123 ) -> Stream<C, O>
124 where
125 I1: IndexedZSetReader + Clone,
126 I2: IndexedZSetReader + Clone,
127 O: IndexedZSet,
128 {
129 self.circuit().add_binary_operator(
130 StreamJoinRange::new(factories, range_func, join_func),
131 self,
132 other,
133 )
134 }
135}
136
137pub struct StreamJoinRange<I1, I2, O>
138where
139 I1: IndexedZSetReader,
140 I2: IndexedZSetReader,
141 O: IndexedZSet,
142{
143 factories: StreamJoinRangeFactories<I2, O>,
144 range_func: Box<dyn Fn(&I1::Key, &mut I2::Key, &mut I2::Key)>,
145 join_func: Box<
146 dyn Fn(&I1::Key, &I1::Val, &I2::Key, &I2::Val, &mut dyn FnMut(&mut O::Key, &mut O::Val)),
147 >,
148 _types: PhantomData<(I1, I2, O)>,
149}
150
151impl<I1, I2, O> StreamJoinRange<I1, I2, O>
152where
153 I1: IndexedZSetReader,
154 I2: IndexedZSetReader,
155 O: IndexedZSet,
156{
157 pub fn new(
158 factories: &StreamJoinRangeFactories<I2, O>,
159 range_func: Box<dyn Fn(&I1::Key, &mut I2::Key, &mut I2::Key)>,
160 join_func: Box<
161 dyn Fn(
162 &I1::Key,
163 &I1::Val,
164 &I2::Key,
165 &I2::Val,
166 &mut dyn FnMut(&mut O::Key, &mut O::Val),
167 ),
168 >,
169 ) -> Self {
170 Self {
171 factories: factories.clone(),
172 range_func,
173 join_func,
174 _types: PhantomData,
175 }
176 }
177}
178
179impl<I1, I2, O> Operator for StreamJoinRange<I1, I2, O>
180where
181 I1: IndexedZSetReader,
182 I2: IndexedZSetReader,
183 O: IndexedZSet,
184{
185 fn name(&self) -> Cow<'static, str> {
186 Cow::from("StreamJoinRange")
187 }
188 fn fixedpoint(&self, _scope: Scope) -> bool {
189 true
190 }
191}
192
193impl<I1, I2, O> BinaryOperator<I1, I2, O> for StreamJoinRange<I1, I2, O>
194where
195 I1: IndexedZSetReader + Clone,
196 I2: IndexedZSetReader + Clone,
197 O: IndexedZSet,
198{
199 async fn eval(&mut self, i1: &I1, i2: &I2) -> O {
200 let mut tuples = self
201 .factories
202 .output_factories
203 .weighted_items_factory()
204 .default_box();
205
206 let mut item = self
207 .factories
208 .output_factories
209 .weighted_item_factory()
210 .default_box();
211
212 let mut i1_cursor = i1.cursor();
213 let mut i2_cursor = i2.cursor();
214
215 let mut lower = self.factories.input2_factories.key_factory().default_box();
216 let mut upper = self.factories.input2_factories.key_factory().default_box();
217
218 while i1_cursor.key_valid() {
220 (self.range_func)(i1_cursor.key(), lower.as_mut(), upper.as_mut());
222
223 i2_cursor.rewind_keys();
226 i2_cursor.seek_key(&lower);
227
228 while i2_cursor.key_valid() && i2_cursor.key() < &upper {
230 i1_cursor.rewind_vals();
232 while i1_cursor.val_valid() {
233 let w1 = **i1_cursor.weight();
234 let k1 = i1_cursor.key();
235 let v1 = i1_cursor.val();
236 i2_cursor.rewind_vals();
237
238 while i2_cursor.val_valid() {
239 let w2 = **i2_cursor.weight();
240 let w = w1.mul_by_ref(&w2);
241
242 (self.join_func)(k1, v1, i2_cursor.key(), i2_cursor.val(), &mut |k, v| {
244 let (kv, weight) = item.split_mut();
245 kv.from_vals(k, v);
246 **weight = w;
247 tuples.push_val(item.as_mut());
248 });
249 i2_cursor.step_val();
250 }
251 i1_cursor.step_val();
252 }
253 i2_cursor.step_key();
254 }
255 i1_cursor.step_key();
256 }
257
258 O::dyn_from_tuples(&self.factories.output_factories, (), &mut tuples)
259 }
260}
261
#[cfg(test)]
mod test {
    use crate::{Circuit, RootCircuit, operator::Generator, utils::Tup2, zset};

    /// Feeds five pairs of batches through `stream_join_range` and
    /// `stream_join_range_index`, using the key range `(k - 1, k + 2)`, and
    /// checks each step's output against the expected Z-set. Also checks that
    /// the indexed and non-indexed variants agree.
    #[test]
    fn stream_join_range_test() {
        let handle = RootCircuit::build(move |root| {
            // Batches for the left-hand side, one per step.
            let mut left_batches = vec![
                zset! {
                    Tup2(1, 'a') => 1,
                    Tup2(1, 'b') => 2,
                    Tup2(2, 'c') => 3,
                    Tup2(2, 'd') => 4,
                    Tup2(3, 'e') => 5,
                    Tup2(3, 'f') => -2,
                },
                zset! {Tup2(1, 'a') => 1},
                zset! {Tup2(1, 'a') => 1},
                zset! {Tup2(4, 'n') => 2},
                zset! {Tup2(1, 'a') => 0},
            ]
            .into_iter();

            // Batches for the right-hand side, one per step.
            let mut right_batches = vec![
                zset! {
                    Tup2(2, 'g') => 3,
                    Tup2(2, 'h') => 4,
                    Tup2(3, 'i') => 5,
                    Tup2(3, 'j') => -2,
                    Tup2(4, 'k') => 5,
                    Tup2(4, 'l') => -2,
                },
                zset! {Tup2(1, 'b') => 1},
                zset! {Tup2(4, 'm') => 1},
                zset! {},
                zset! {},
            ]
            .into_iter();

            // Expected join result for each step.
            let mut expected = vec![
                zset! {
                    Tup2(Tup2(1, 'a'), Tup2(2, 'g')) => 3,
                    Tup2(Tup2(1, 'a'), Tup2(2, 'h')) => 4,
                    Tup2(Tup2(1, 'b'), Tup2(2, 'g')) => 6,
                    Tup2(Tup2(1, 'b'), Tup2(2, 'h')) => 8,
                    Tup2(Tup2(2, 'c'), Tup2(2, 'g')) => 9,
                    Tup2(Tup2(2, 'c'), Tup2(2, 'h')) => 12,
                    Tup2(Tup2(2, 'c'), Tup2(3, 'i')) => 15,
                    Tup2(Tup2(2, 'c'), Tup2(3, 'j')) => -6,
                    Tup2(Tup2(2, 'd'), Tup2(2, 'g')) => 12,
                    Tup2(Tup2(2, 'd'), Tup2(2, 'h')) => 16,
                    Tup2(Tup2(2, 'd'), Tup2(3, 'i')) => 20,
                    Tup2(Tup2(2, 'd'), Tup2(3, 'j')) => -8,
                    Tup2(Tup2(3, 'e'), Tup2(2, 'g')) => 15,
                    Tup2(Tup2(3, 'e'), Tup2(2, 'h')) => 20,
                    Tup2(Tup2(3, 'e'), Tup2(3, 'i')) => 25,
                    Tup2(Tup2(3, 'e'), Tup2(3, 'j')) => -10,
                    Tup2(Tup2(3, 'e'), Tup2(4, 'k')) => 25,
                    Tup2(Tup2(3, 'e'), Tup2(4, 'l')) => -10,
                    Tup2(Tup2(3, 'f'), Tup2(2, 'g')) => -6,
                    Tup2(Tup2(3, 'f'), Tup2(2, 'h')) => -8,
                    Tup2(Tup2(3, 'f'), Tup2(3, 'i')) => -10,
                    Tup2(Tup2(3, 'f'), Tup2(3, 'j')) => 4,
                    Tup2(Tup2(3, 'f'), Tup2(4, 'k')) => -10,
                    Tup2(Tup2(3, 'f'), Tup2(4, 'l')) => 4,
                },
                zset! {
                    Tup2(Tup2(1, 'a'), Tup2(1, 'b')) => 1,
                },
                zset! {},
                zset! {},
                zset! {},
            ]
            .into_iter();

            // Index both inputs by the first tuple component.
            let left_index = root
                .add_source(Generator::new(move || left_batches.next().unwrap()))
                .map_index(|Tup2(k, v)| (*k, *v));
            let right_index = root
                .add_source(Generator::new(move || right_batches.next().unwrap()))
                .map_index(|Tup2(k, v)| (*k, *v));

            // Non-indexed variant: compare against the expected output.
            let joined = left_index.stream_join_range(
                &right_index,
                |&k| (k - 1, k + 2),
                |&k1, &v1, &k2, &v2| Some(Tup2(Tup2(k1, v1), Tup2(k2, v2))),
            );
            joined.inspect(move |batch| assert_eq!(batch, &expected.next().unwrap()));

            // Indexed variant: must equal the non-indexed result once the
            // latter is indexed the same way.
            let joined_indexed = left_index.stream_join_range_index(
                &right_index,
                |&k| (k - 1, k + 2),
                |&k1, &v1, &k2, &v2| Some((Tup2(k1, v1), Tup2(k2, v2))),
            );
            joined
                .map_index(|Tup2(k, v)| (*k, *v))
                .apply2(&joined_indexed, |o1, o2| assert_eq!(o1, o2));

            Ok(())
        })
        .unwrap()
        .0;

        // One transaction per prepared batch.
        for _ in 0..5 {
            handle.transaction().unwrap();
        }
    }
}