Skip to main content

dbsp/operator/dynamic/
join_range.rs

//! Range-join operators.
//!
//! Range-join is a form of non-equi join where each key in the left operand
//! matches a contiguous range of keys in the right operand.
//!
//! Consider two indexed Z-sets `z1` and `z2`, a function `range_func` that
//! maps a key in `z1` to a half-closed interval of keys `[lower, upper)` in
//! `z2`, and another function `join_func` that, given `(k1, v1)` in `z1` and
//! `(k2, v2)` in `z2`, returns an iterable collection `c` of output values.
//! The range-join operator works as follows:
//! * For each `k1` in `z1`, find all keys in `z2` within the half-closed
//!   interval `range_func(k1)`.
//! * For each `((k1, v1), w1)` in `z1` and `((k2, v2), w2)` in `z2` where `k2 ∈
//!   range_func(k1)`, add all values in `join_func(k1, v1, k2, v2)` to the
//!   output batch with weight `w1 * w2`.
16
17use crate::{
18    DBData, ZWeight,
19    algebra::{IndexedZSet, IndexedZSetReader, MulByRef, OrdIndexedZSet, OrdZSet},
20    circuit::{
21        Circuit, Scope, Stream,
22        operator_traits::{BinaryOperator, Operator},
23    },
24    dynamic::{DataTrait, DynUnit, Erase},
25    trace::{BatchFactories, BatchReaderFactories, Cursor},
26};
27use std::{borrow::Cow, marker::PhantomData};
28
29pub struct StreamJoinRangeFactories<I, O>
30where
31    I: IndexedZSetReader,
32    O: IndexedZSet,
33{
34    input2_factories: I::Factories,
35    output_factories: O::Factories,
36}
37
38impl<I: IndexedZSetReader, O: IndexedZSet> Clone for StreamJoinRangeFactories<I, O> {
39    fn clone(&self) -> Self {
40        Self {
41            input2_factories: self.input2_factories.clone(),
42            output_factories: self.output_factories.clone(),
43        }
44    }
45}
46
47impl<I, O> StreamJoinRangeFactories<I, O>
48where
49    I: IndexedZSetReader,
50    O: IndexedZSet,
51{
52    pub fn new<IKType, IVType, OKType, OVType>() -> Self
53    where
54        IKType: DBData + Erase<I::Key>,
55        IVType: DBData + Erase<I::Val>,
56        OKType: DBData + Erase<O::Key>,
57        OVType: DBData + Erase<O::Val>,
58    {
59        Self {
60            input2_factories: BatchReaderFactories::new::<IKType, IVType, ZWeight>(),
61            output_factories: BatchReaderFactories::new::<OKType, OVType, ZWeight>(),
62        }
63    }
64}
65
66impl<C, I1> Stream<C, I1>
67where
68    C: Circuit,
69{
70    /// See [`Stream::stream_join_range`].
71    pub fn dyn_stream_join_range<I2, V>(
72        &self,
73        factories: &StreamJoinRangeFactories<I2, OrdZSet<V>>,
74        other: &Stream<C, I2>,
75        range_func: Box<dyn Fn(&I1::Key, &mut I2::Key, &mut I2::Key)>,
76        join_func: Box<
77            dyn Fn(&I1::Key, &I1::Val, &I2::Key, &I2::Val, &mut dyn FnMut(&mut V, &mut DynUnit)),
78        >,
79    ) -> Stream<C, OrdZSet<V>>
80    where
81        I1: IndexedZSetReader + Clone,
82        I2: IndexedZSetReader + Clone,
83        V: DataTrait + ?Sized,
84    {
85        self.dyn_stream_join_range_generic(factories, other, range_func, join_func)
86    }
87
88    /// See [`Stream::stream_join_range_index`].
89    pub fn dyn_stream_join_range_index<K, V, I2>(
90        &self,
91        factories: &StreamJoinRangeFactories<I2, OrdIndexedZSet<K, V>>,
92        other: &Stream<C, I2>,
93        range_func: Box<dyn Fn(&I1::Key, &mut I2::Key, &mut I2::Key)>,
94        join_func: Box<
95            dyn Fn(&I1::Key, &I1::Val, &I2::Key, &I2::Val, &mut dyn FnMut(&mut K, &mut V)),
96        >,
97    ) -> Stream<C, OrdIndexedZSet<K, V>>
98    where
99        I1: IndexedZSetReader + Clone,
100        I2: IndexedZSetReader + Clone,
101        K: DataTrait + ?Sized,
102        V: DataTrait + ?Sized,
103    {
104        self.dyn_stream_join_range_generic(factories, other, range_func, join_func)
105    }
106
107    /// Like [`Self::dyn_stream_join_range`], but can return any indexed Z-set
108    /// type.
109    pub fn dyn_stream_join_range_generic<I2, O>(
110        &self,
111        factories: &StreamJoinRangeFactories<I2, O>,
112        other: &Stream<C, I2>,
113        range_func: Box<dyn Fn(&I1::Key, &mut I2::Key, &mut I2::Key)>,
114        join_func: Box<
115            dyn Fn(
116                &I1::Key,
117                &I1::Val,
118                &I2::Key,
119                &I2::Val,
120                &mut dyn FnMut(&mut O::Key, &mut O::Val),
121            ),
122        >,
123    ) -> Stream<C, O>
124    where
125        I1: IndexedZSetReader + Clone,
126        I2: IndexedZSetReader + Clone,
127        O: IndexedZSet,
128    {
129        self.circuit().add_binary_operator(
130            StreamJoinRange::new(factories, range_func, join_func),
131            self,
132            other,
133        )
134    }
135}
136
137pub struct StreamJoinRange<I1, I2, O>
138where
139    I1: IndexedZSetReader,
140    I2: IndexedZSetReader,
141    O: IndexedZSet,
142{
143    factories: StreamJoinRangeFactories<I2, O>,
144    range_func: Box<dyn Fn(&I1::Key, &mut I2::Key, &mut I2::Key)>,
145    join_func: Box<
146        dyn Fn(&I1::Key, &I1::Val, &I2::Key, &I2::Val, &mut dyn FnMut(&mut O::Key, &mut O::Val)),
147    >,
148    _types: PhantomData<(I1, I2, O)>,
149}
150
151impl<I1, I2, O> StreamJoinRange<I1, I2, O>
152where
153    I1: IndexedZSetReader,
154    I2: IndexedZSetReader,
155    O: IndexedZSet,
156{
157    pub fn new(
158        factories: &StreamJoinRangeFactories<I2, O>,
159        range_func: Box<dyn Fn(&I1::Key, &mut I2::Key, &mut I2::Key)>,
160        join_func: Box<
161            dyn Fn(
162                &I1::Key,
163                &I1::Val,
164                &I2::Key,
165                &I2::Val,
166                &mut dyn FnMut(&mut O::Key, &mut O::Val),
167            ),
168        >,
169    ) -> Self {
170        Self {
171            factories: factories.clone(),
172            range_func,
173            join_func,
174            _types: PhantomData,
175        }
176    }
177}
178
179impl<I1, I2, O> Operator for StreamJoinRange<I1, I2, O>
180where
181    I1: IndexedZSetReader,
182    I2: IndexedZSetReader,
183    O: IndexedZSet,
184{
185    fn name(&self) -> Cow<'static, str> {
186        Cow::from("StreamJoinRange")
187    }
188    fn fixedpoint(&self, _scope: Scope) -> bool {
189        true
190    }
191}
192
193impl<I1, I2, O> BinaryOperator<I1, I2, O> for StreamJoinRange<I1, I2, O>
194where
195    I1: IndexedZSetReader + Clone,
196    I2: IndexedZSetReader + Clone,
197    O: IndexedZSet,
198{
199    async fn eval(&mut self, i1: &I1, i2: &I2) -> O {
200        let mut tuples = self
201            .factories
202            .output_factories
203            .weighted_items_factory()
204            .default_box();
205
206        let mut item = self
207            .factories
208            .output_factories
209            .weighted_item_factory()
210            .default_box();
211
212        let mut i1_cursor = i1.cursor();
213        let mut i2_cursor = i2.cursor();
214
215        let mut lower = self.factories.input2_factories.key_factory().default_box();
216        let mut upper = self.factories.input2_factories.key_factory().default_box();
217
218        // For each key in `i1`.
219        while i1_cursor.key_valid() {
220            // Compute the range of matching keys in `i2`.
221            (self.range_func)(i1_cursor.key(), lower.as_mut(), upper.as_mut());
222
223            // Assuming that `lower` grows monotonically, we wouldn't need to rewind every
224            // time.
225            i2_cursor.rewind_keys();
226            i2_cursor.seek_key(&lower);
227
228            // Iterate over the `[lower, upper)` interval.
229            while i2_cursor.key_valid() && i2_cursor.key() < &upper {
230                // Iterate over all pairs of values in i1 and i2.
231                i1_cursor.rewind_vals();
232                while i1_cursor.val_valid() {
233                    let w1 = **i1_cursor.weight();
234                    let k1 = i1_cursor.key();
235                    let v1 = i1_cursor.val();
236                    i2_cursor.rewind_vals();
237
238                    while i2_cursor.val_valid() {
239                        let w2 = **i2_cursor.weight();
240                        let w = w1.mul_by_ref(&w2);
241
242                        // Add all `(k,v)` tuples output by `join_func` to the output batch.
243                        (self.join_func)(k1, v1, i2_cursor.key(), i2_cursor.val(), &mut |k, v| {
244                            let (kv, weight) = item.split_mut();
245                            kv.from_vals(k, v);
246                            **weight = w;
247                            tuples.push_val(item.as_mut());
248                        });
249                        i2_cursor.step_val();
250                    }
251                    i1_cursor.step_val();
252                }
253                i2_cursor.step_key();
254            }
255            i1_cursor.step_key();
256        }
257
258        O::dyn_from_tuples(&self.factories.output_factories, (), &mut tuples)
259    }
260}
261
#[cfg(test)]
mod test {
    use crate::{Circuit, RootCircuit, operator::Generator, utils::Tup2, zset};

    /// Feeds five batches into each side of a range join and checks the
    /// result of `stream_join_range` against hand-computed batches, then
    /// checks that `stream_join_range_index` agrees with it.
    #[test]
    fn stream_join_range_test() {
        let (circuit, _) = RootCircuit::build(move |circuit| {
            // Batches for the left input, one per transaction.
            let mut left_batches = vec![
                zset! {
                    Tup2(1, 'a') => 1,
                    Tup2(1, 'b') => 2,
                    Tup2(2, 'c') => 3,
                    Tup2(2, 'd') => 4,
                    Tup2(3, 'e') => 5,
                    Tup2(3, 'f') => -2,
                },
                zset! {Tup2(1, 'a') => 1},
                zset! {Tup2(1, 'a') => 1},
                zset! {Tup2(4, 'n') => 2},
                zset! {Tup2(1, 'a') => 0},
            ]
            .into_iter();

            // Batches for the right input, one per transaction.
            let mut right_batches = vec![
                zset! {
                    Tup2(2, 'g') => 3,
                    Tup2(2, 'h') => 4,
                    Tup2(3, 'i') => 5,
                    Tup2(3, 'j') => -2,
                    Tup2(4, 'k') => 5,
                    Tup2(4, 'l') => -2,
                },
                zset! {Tup2(1, 'b') => 1},
                zset! {Tup2(4, 'm') => 1},
                zset! {},
                zset! {},
            ]
            .into_iter();

            // Expected join results; a left key `k` matches right keys in
            // `[k - 1, k + 2)`.
            let mut expected = vec![
                zset! {
                    Tup2(Tup2(1, 'a'), Tup2(2, 'g')) => 3,
                    Tup2(Tup2(1, 'a'), Tup2(2, 'h')) => 4,
                    Tup2(Tup2(1, 'b'), Tup2(2, 'g')) => 6,
                    Tup2(Tup2(1, 'b'), Tup2(2, 'h')) => 8,
                    Tup2(Tup2(2, 'c'), Tup2(2, 'g')) => 9,
                    Tup2(Tup2(2, 'c'), Tup2(2, 'h')) => 12,
                    Tup2(Tup2(2, 'c'), Tup2(3, 'i')) => 15,
                    Tup2(Tup2(2, 'c'), Tup2(3, 'j')) => -6,
                    Tup2(Tup2(2, 'd'), Tup2(2, 'g')) => 12,
                    Tup2(Tup2(2, 'd'), Tup2(2, 'h')) => 16,
                    Tup2(Tup2(2, 'd'), Tup2(3, 'i')) => 20,
                    Tup2(Tup2(2, 'd'), Tup2(3, 'j')) => -8,
                    Tup2(Tup2(3, 'e'), Tup2(2, 'g')) => 15,
                    Tup2(Tup2(3, 'e'), Tup2(2, 'h')) => 20,
                    Tup2(Tup2(3, 'e'), Tup2(3, 'i')) => 25,
                    Tup2(Tup2(3, 'e'), Tup2(3, 'j')) => -10,
                    Tup2(Tup2(3, 'e'), Tup2(4, 'k')) => 25,
                    Tup2(Tup2(3, 'e'), Tup2(4, 'l')) => -10,
                    Tup2(Tup2(3, 'f'), Tup2(2, 'g')) => -6,
                    Tup2(Tup2(3, 'f'), Tup2(2, 'h')) => -8,
                    Tup2(Tup2(3, 'f'), Tup2(3, 'i')) => -10,
                    Tup2(Tup2(3, 'f'), Tup2(3, 'j')) => 4,
                    Tup2(Tup2(3, 'f'), Tup2(4, 'k')) => -10,
                    Tup2(Tup2(3, 'f'), Tup2(4, 'l')) => 4,
                },
                zset! {
                    Tup2(Tup2(1, 'a'), Tup2(1, 'b')) => 1,
                },
                zset! {},
                zset! {},
                zset! {},
            ]
            .into_iter();

            // Index both inputs by the first tuple component.
            let left = circuit
                .add_source(Generator::new(move || left_batches.next().unwrap()))
                .map_index(|Tup2(k, v)| (*k, *v));
            let right = circuit
                .add_source(Generator::new(move || right_batches.next().unwrap()))
                .map_index(|Tup2(k, v)| (*k, *v));

            // Non-indexed variant: compare with the expected batches.
            let joined = left.stream_join_range(
                &right,
                |&k| (k - 1, k + 2),
                |&k1, &v1, &k2, &v2| Some(Tup2(Tup2(k1, v1), Tup2(k2, v2))),
            );
            joined.inspect(move |fm| assert_eq!(fm, &expected.next().unwrap()));

            // Indexed variant: must equal the indexed form of the result above.
            let joined_indexed = left.stream_join_range_index(
                &right,
                |&k| (k - 1, k + 2),
                |&k1, &v1, &k2, &v2| Some((Tup2(k1, v1), Tup2(k2, v2))),
            );
            joined
                .map_index(|Tup2(k, v)| (*k, *v))
                .apply2(&joined_indexed, |o1, o2| assert_eq!(o1, o2));
            Ok(())
        })
        .unwrap();

        // One transaction per prepared batch.
        for _ in 0..5 {
            circuit.transaction().unwrap();
        }
    }
}