Skip to main content

dbsp/operator/dynamic/time_series/radix_tree/
tree_aggregate.rs

1use super::{DynTreeNode, Prefix, RadixTreeFactories, radix_tree_update};
2use crate::{
3    Circuit, DBData, DynZWeight, Stream, ZWeight,
4    algebra::{HasOne, IndexedZSet, IndexedZSetReader, OrdIndexedZSet},
5    circuit::{
6        Scope,
7        operator_traits::{Operator, TernaryOperator},
8    },
9    dynamic::{DataTrait, DynDataTyped, Erase},
10    operator::dynamic::{
11        accumulate_trace::AccumulateTraceFeedback, aggregate::DynAggregator,
12        time_series::radix_tree::treenode::TreeNode, trace::TraceBounds,
13    },
14    trace::{Batch, BatchReader, BatchReaderFactories, Builder, Spine, TupleBuilder},
15};
16use dyn_clone::clone_box;
17use num::PrimInt;
18use size_of::SizeOf;
19use std::{borrow::Cow, cmp::Ordering, marker::PhantomData, ops::Neg};
20
21/// A batch that contains updates to a radix tree.
22pub trait RadixTreeBatch<TS, A>:
23    IndexedZSet<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>
24where
25    TS: DBData + PrimInt,
26    A: DataTrait + ?Sized,
27{
28}
29
30impl<TS, A, B> RadixTreeBatch<TS, A> for B
31where
32    TS: DBData + PrimInt,
33    A: DataTrait + ?Sized,
34    B: IndexedZSet<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>,
35{
36}
37
38pub trait RadixTreeReader<TS, A>:
39    IndexedZSetReader<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>
40where
41    TS: DBData + PrimInt,
42    A: DataTrait + ?Sized,
43{
44}
45
46impl<TS, A, B> RadixTreeReader<TS, A> for B
47where
48    B: IndexedZSetReader<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>,
49    TS: DBData + PrimInt,
50    A: DataTrait + ?Sized,
51{
52}
53
54pub type OrdRadixTree<TS, A> = OrdIndexedZSet<DynDataTyped<Prefix<TS>>, DynTreeNode<TS, A>>;
55
56pub struct TreeAggregateFactories<
57    TS: DBData + PrimInt,
58    Z: IndexedZSet<Key = DynDataTyped<TS>>,
59    O: RadixTreeBatch<TS, Acc>,
60    Acc: DataTrait + ?Sized,
61> {
62    input_factories: Z::Factories,
63    output_factories: O::Factories,
64    radix_tree_factories: RadixTreeFactories<TS, Acc>,
65}
66
67impl<TS, Z, O, Acc> TreeAggregateFactories<TS, Z, O, Acc>
68where
69    TS: DBData + PrimInt,
70    Z: IndexedZSet<Key = DynDataTyped<TS>>,
71    O: RadixTreeBatch<TS, Acc>,
72    Acc: DataTrait + ?Sized,
73{
74    pub fn new<VType, AType>() -> Self
75    where
76        VType: DBData + Erase<Z::Val>,
77        AType: DBData + Erase<Acc>,
78    {
79        Self {
80            input_factories: BatchReaderFactories::new::<TS, VType, ZWeight>(),
81            output_factories: BatchReaderFactories::new::<Prefix<TS>, TreeNode<TS, AType>, ZWeight>(
82            ),
83            radix_tree_factories: RadixTreeFactories::new::<AType>(),
84        }
85    }
86}
87
88impl<C, Z, TS> Stream<C, Z>
89where
90    C: Circuit,
91    Z: IndexedZSet<Key = DynDataTyped<TS>> + SizeOf + Send,
92    TS: DBData + PrimInt,
93{
94    /// Given a batch of updates to a time series stream, computes a stream of
95    /// updates to its radix tree.
96    ///
97    /// This is intended as a building block for higher-level operators.
98    ///
99    /// # Limitations
100    ///
101    /// Unlike `Stream::partitioned_tree_aggregate()`, this operator is
102    /// currently not parallelized, performing all work in a single worker
103    /// thread.
104    pub fn tree_aggregate<Acc, Out>(
105        &self,
106        persistent_id: Option<&str>,
107        factories: &TreeAggregateFactories<TS, Z, OrdRadixTree<TS, Acc>, Acc>,
108        aggregator: &dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = Out>,
109    ) -> Stream<C, OrdRadixTree<TS, Acc>>
110    where
111        Acc: DataTrait + ?Sized,
112        Out: DataTrait + ?Sized,
113    {
114        self.tree_aggregate_generic::<Acc, Out, OrdRadixTree<TS, Acc>>(
115            persistent_id,
116            factories,
117            aggregator,
118        )
119    }
120
121    /// Like [`Self::tree_aggregate`], but can return any batch type.
122    pub fn tree_aggregate_generic<Acc, Out, O>(
123        &self,
124        persistent_id: Option<&str>,
125        factories: &TreeAggregateFactories<TS, Z, O, Acc>,
126        aggregator: &dyn DynAggregator<Z::Val, (), DynZWeight, Accumulator = Acc, Output = Out>,
127    ) -> Stream<C, O>
128    where
129        Acc: DataTrait + ?Sized,
130        Out: DataTrait + ?Sized,
131        O: RadixTreeBatch<TS, Acc>,
132    {
133        self.circuit().region("tree_aggregate", move || {
134            let circuit = self.circuit();
135            let stream = self.dyn_gather(&factories.input_factories, 0);
136
137            // We construct the following circuit.  See `RadixTreeAggregate`
138            // documentation for details.
139            //
140            // ```
141            //          ┌─────────────────────────────────────────┐
142            //          │                                         │                           output
143            //          │                                         │                 ┌─────────────────────────────────►
144            //          │                                         ▼                 │
145            //    stream│     ┌───────────────┐         ┌──────────────────────┐    │      ┌──────────────────┐
146            // ─────────┴─────┤integrate_trace├───────► │  RadixTreeAggregate  ├────┴─────►│UntimedTraceAppend├──┐
147            //                └───────────────┘         └──────────────────────┘           └──────────────────┘  │
148            //                                                    ▲                               ▲              │output_trace
149            //                                                    │                               │              │
150            //                                                    │                           ┌───┴───┐          │
151            //                                                    └───────────────────────────┤Z1Trace│◄─────────┘
152            //                                                            delayed_trace       └───────┘
153            // ```
154
155            let feedback = circuit.add_accumulate_integrate_trace_feedback::<Spine<O>>(
156                persistent_id,
157                &factories.output_factories,
158                <TraceBounds<O::Key, O::Val>>::unbounded(),
159            );
160
161            let output = circuit.add_ternary_operator(
162                RadixTreeAggregate::new(
163                    &factories.radix_tree_factories,
164                    &factories.output_factories,
165                    aggregator,
166                ),
167                &stream.dyn_accumulate(&factories.input_factories),
168                &stream.dyn_accumulate_integrate_trace(&factories.input_factories),
169                &feedback.delayed_trace,
170            );
171
172            feedback.connect(&output, &factories.output_factories);
173
174            output
175        })
176    }
177}
178
179/// Ternary operator that implements the internals of `tree_aggregate`.
180///
181/// * Input stream 1: updates to the time series.  Only used to identify
182///   affected times.
183/// * Input stream 2: trace containing the accumulated time series data.
184/// * Input stream 3: trace containing the current contents of the radix tree.
185struct RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
186where
187    Z: BatchReader<Key = DynDataTyped<TS>>,
188    TS: DBData + PrimInt,
189    O: Batch,
190    Acc: DataTrait + ?Sized,
191    Out: DataTrait + ?Sized,
192{
193    aggregator: Box<dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = Out>>,
194    radix_tree_factories: RadixTreeFactories<TS, Acc>,
195    output_factories: O::Factories,
196    phantom: PhantomData<(Z, IT, OT, O)>,
197}
198
199impl<Z, TS, IT, OT, Acc, Out, O> RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
200where
201    Z: BatchReader<Key = DynDataTyped<TS>>,
202    TS: DBData + PrimInt,
203    Acc: DataTrait + ?Sized,
204    Out: DataTrait + ?Sized,
205    O: Batch,
206{
207    pub fn new(
208        radix_tree_factories: &RadixTreeFactories<TS, Acc>,
209        output_factories: &O::Factories,
210        aggregator: &dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = Out>,
211    ) -> Self {
212        Self {
213            radix_tree_factories: radix_tree_factories.clone(),
214            output_factories: output_factories.clone(),
215            aggregator: clone_box(aggregator),
216            phantom: PhantomData,
217        }
218    }
219}
220
221impl<Z, TS, IT, OT, Acc, Out, O> Operator for RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
222where
223    Z: BatchReader<Key = DynDataTyped<TS>>,
224    Acc: DataTrait + ?Sized,
225    Out: DataTrait + ?Sized,
226    TS: DBData + PrimInt,
227    IT: 'static,
228    OT: 'static,
229    O: Batch,
230{
231    fn name(&self) -> Cow<'static, str> {
232        Cow::from("RadixTreeAggregate")
233    }
234
235    fn fixedpoint(&self, _scope: Scope) -> bool {
236        true
237    }
238}
239
240impl<Z, TS, IT, OT, Acc, Out, O> TernaryOperator<Option<Spine<Z>>, IT, OT, O>
241    for RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
242where
243    Z: IndexedZSet<Key = DynDataTyped<TS>>,
244    TS: DBData + PrimInt,
245    Acc: DataTrait + ?Sized,
246    Out: DataTrait + ?Sized,
247    O: RadixTreeBatch<TS, Acc>,
248    IT: IndexedZSetReader<Key = Z::Key, Val = Z::Val> + Clone,
249    OT: RadixTreeReader<TS, Acc> + Clone,
250{
251    async fn eval(
252        &mut self,
253        delta: Cow<'_, Option<Spine<Z>>>,
254        input_trace: Cow<'_, IT>,
255        output_trace: Cow<'_, OT>,
256    ) -> O {
257        let Some(delta) = delta.as_ref() else {
258            return O::dyn_empty(&self.output_factories);
259        };
260
261        let mut updates = self.radix_tree_factories.node_updates_factory.default_box();
262        updates.reserve(delta.key_count());
263
264        radix_tree_update::<TS, Z::Val, Acc, Out, _, _, _>(
265            &self.radix_tree_factories,
266            delta.cursor(),
267            input_trace.cursor(),
268            output_trace.cursor(),
269            self.aggregator.as_ref(),
270            &mut *updates,
271        );
272
273        let builder =
274            O::Builder::with_capacity(&self.output_factories, updates.len(), updates.len() * 2);
275        let mut builder = TupleBuilder::new(&self.output_factories, builder);
276
277        // `updates` are already ordered by prefix.  All that remains is to order
278        // insertion and deletion within each update.
279        for update in updates.dyn_iter_mut() {
280            match update.new().cmp(update.old()) {
281                Ordering::Equal => {}
282                Ordering::Less => {
283                    let mut prefix = update.prefix();
284                    if let Some(new) = update.new_mut().get_mut() {
285                        builder.push_vals(
286                            prefix.clone().erase_mut(),
287                            new,
288                            &mut (),
289                            ZWeight::one().erase_mut(),
290                        );
291                    };
292                    if let Some(old) = update.old_mut().get_mut() {
293                        builder.push_vals(
294                            prefix.erase_mut(),
295                            old,
296                            &mut (),
297                            ZWeight::one().neg().erase_mut(),
298                        );
299                    };
300                }
301                Ordering::Greater => {
302                    let mut prefix = update.prefix();
303
304                    if let Some(old) = update.old_mut().get_mut() {
305                        builder.push_vals(
306                            prefix.clone().erase_mut(),
307                            old,
308                            &mut (),
309                            ZWeight::one().neg().erase_mut(),
310                        );
311                    };
312                    if let Some(new) = update.new_mut().get_mut() {
313                        builder.push_vals(
314                            prefix.erase_mut(),
315                            new,
316                            &mut (),
317                            ZWeight::one().erase_mut(),
318                        );
319                    };
320                }
321            }
322        }
323
324        builder.done()
325    }
326}
327
#[cfg(test)]
mod test {
    use super::super::RadixTreeCursor;
    use crate::{
        DynZWeight, Runtime, Stream, ZWeight,
        algebra::{AddAssignByRef, DefaultSemigroup},
        dynamic::{DowncastTrait, DynData, DynDataTyped, DynPair, Erase},
        operator::{
            Fold,
            dynamic::{
                aggregate::DynAggregatorImpl,
                input::{AddInputIndexedZSetFactories, CollectionHandle},
                time_series::{
                    TreeNode,
                    radix_tree::{
                        Prefix,
                        test::test_aggregate_range,
                        tree_aggregate::{OrdRadixTree, TreeAggregateFactories},
                    },
                },
            },
        },
        trace::{BatchReader, BatchReaderFactories},
        utils::Tup2,
    };
    use std::{
        collections::{BTreeMap, btree_map::Entry},
        sync::{Arc, Mutex},
    };

    /// Pushes one `(key, (value, weight))` update into `input` and mirrors it
    /// in `contents`, which holds the expected sum of values for each key.
    ///
    /// Keys whose sum drops to zero are removed from `contents`.
    fn update_key(
        input: &CollectionHandle<DynDataTyped<u64>, DynPair<DynData, DynZWeight>>,
        contents: &mut BTreeMap<u64, Box<DynData /* <u64> */>>,
        key: u64,
        upd: Tup2<u64, ZWeight>,
    ) {
        // `dyn_push` takes mutable references, so hand it scratch copies.
        input.dyn_push(key.clone().erase_mut(), upd.clone().erase_mut());

        match contents.entry(key) {
            Entry::Occupied(mut entry) => {
                assert!(upd.1 == 1 || upd.1 == -1);
                let sum = entry.get_mut().downcast_mut_checked::<u64>();
                match upd.1 {
                    1 => *sum += upd.0,
                    _ => *sum -= upd.0,
                }
                if *sum == 0 {
                    entry.remove();
                }
            }
            Entry::Vacant(entry) => {
                assert_eq!(upd.1, 1);
                entry.insert(Box::new(upd.0).erase_box());
            }
        }
    }

    #[test]
    fn test_tree_aggregate() {
        let contents = Arc::new(Mutex::new(BTreeMap::new()));
        let contents_clone = contents.clone();

        let (mut circuit, input) = Runtime::init_circuit(1, move |circuit| {
            let (input, input_handle) = circuit
                .dyn_add_input_indexed_zset::<DynDataTyped<u64>, DynData /*u64*/>(
                    &AddInputIndexedZSetFactories::new::<u64, u64>(),
                );

            let aggregator = <Fold<u64, _, DefaultSemigroup<_>, _, _>>::new(
                0u64,
                |agg: &mut u64, val: &u64, _w: ZWeight| *agg += val,
            );

            let tree: Stream<_, OrdRadixTree<u64, DynData /* <u64> */>> = input
                .tree_aggregate::<DynData /*<u64>*/, DynData /*<u64>*/>(
                    None,
                    &TreeAggregateFactories::new::<u64, u64>(),
                    &DynAggregatorImpl::new(aggregator),
                );

            let tree_factories =
                BatchReaderFactories::new::<Prefix<u64>, TreeNode<u64, u64>, ZWeight>();
            tree.dyn_integrate_trace(&tree_factories)
                .apply(move |tree_trace| {
                    println!("Radix tree:");
                    let mut treestr = String::new();
                    tree_trace.cursor().format_tree(&mut treestr).unwrap();
                    println!("{treestr}");

                    let expected = contents_clone.lock().unwrap();
                    tree_trace.cursor().validate(&expected, &|acc, val| {
                        acc.downcast_mut_checked::<u64>()
                            .add_assign_by_ref(val.downcast_checked::<u64>())
                    });
                    drop(expected);

                    test_aggregate_range::<u64, u64, _, DefaultSemigroup<_>>(
                        &mut tree_trace.cursor(),
                        &contents_clone.lock().unwrap(),
                    );
                });

            Ok(input_handle)
        })
        .unwrap();

        // Each inner vector is pushed as one input batch and followed by a
        // transaction.  The first batch is intentionally empty.
        let batches: Vec<Vec<(u64, Tup2<u64, ZWeight>)>> = vec![
            vec![],
            vec![(0x1000_0000_0000_0001, Tup2(1, 1))],
            vec![(0x1000_0000_0000_0002, Tup2(2, 1))],
            vec![(0x1000_1000_0000_0000, Tup2(3, 1))],
            vec![(0x1000_0000_0000_0002, Tup2(2, -1))],
            vec![
                (0xf100_0000_0000_0001, Tup2(4, 1)),
                (0xf200_0000_0000_0001, Tup2(5, 1)),
                (0xf300_0000_0000_0001, Tup2(6, 1)),
                (0xf300_1000_0000_0001, Tup2(7, 1)),
                (0xf300_1000_1000_0001, Tup2(8, 1)),
                (0xf300_1000_1000_1001, Tup2(9, 1)),
                (0xf300_1000_1100_1001, Tup2(10, 1)),
                (0xf300_1000_1100_1001, Tup2(10, -1)),
            ],
            vec![
                (0xf400_1000_1100_1001, Tup2(11, 1)),
                (0xf300_1000_0000_0001, Tup2(7, -1)),
            ],
            vec![
                (0x1000_0000_0000_0001, Tup2(1, -1)),
                (0x1000_1000_0000_0000, Tup2(3, -1)),
                (0xf100_0000_0000_0001, Tup2(4, -1)),
                (0xf200_0000_0000_0001, Tup2(5, -1)),
            ],
            vec![
                (0xf300_0000_0000_0001, Tup2(6, -1)),
                (0xf300_1000_1000_0001, Tup2(8, -1)),
                (0xf300_1000_1000_1001, Tup2(9, -1)),
            ],
            vec![(0xf400_1000_1100_1001, Tup2(11, -1))],
            vec![
                (0xf100_0000_0000_0001, Tup2(4, 1)),
                (0xf200_0000_0000_0001, Tup2(5, 1)),
                (0xf300_0000_0000_0001, Tup2(6, 1)),
                (0xf300_1000_0000_0001, Tup2(7, 1)),
                (0xf300_1000_1000_0001, Tup2(8, 1)),
                (0xf300_1000_1000_0001, Tup2(11, 1)),
            ],
        ];

        for batch in batches {
            for (key, upd) in batch {
                // The lock guard is a temporary, released before the
                // transaction below re-locks `contents` inside the circuit.
                update_key(&input, &mut contents.lock().unwrap(), key, upd);
            }
            circuit.transaction().unwrap();
        }
    }
}