Skip to main content

dbsp/operator/dynamic/time_series/
rolling_aggregate.rs

1use crate::{
2    Circuit, DBData, DBWeight, DynZWeight, Position, RootCircuit, Stream, ZWeight,
3    algebra::{HasZero, IndexedZSet, UnsignedPrimInt, ZRingValue},
4    circuit::{
5        Scope,
6        metadata::{BatchSizeStats, INPUT_BATCHES_STATS, OUTPUT_BATCHES_STATS, OperatorMeta},
7        operator_traits::Operator,
8        splitter_output_chunk_size, splitter_output_first_chunk_size,
9    },
10    dynamic::{
11        ClonableTrait, Data, DataTrait, DowncastTrait, DynDataTyped, DynOpt, DynPair, DynUnit,
12        Erase, Factory, WeightTrait, WithFactory,
13    },
14    operator::{
15        Avg,
16        async_stream_operators::{StreamingQuaternaryOperator, StreamingQuaternaryWrapper},
17        dynamic::{
18            accumulate_trace::AccumulateTraceFeedback,
19            aggregate::{AggCombineFunc, AggOutputFunc, DynAggregator, DynAverage},
20            filter_map::DynFilterMap,
21            time_series::{
22                OrdPartitionedIndexedZSet, PartitionCursor, PartitionedBatch,
23                PartitionedIndexedZSet, RelOffset,
24                radix_tree::{
25                    OrdPartitionedTreeAggregateFactories, PartitionedRadixTreeBatch,
26                    RadixTreeCursor, TreeNode,
27                },
28                range::{Range, RangeCursor, Ranges, RelRange},
29            },
30            trace::{TraceBound, TraceBounds},
31        },
32    },
33    trace::{
34        Batch, BatchReader, BatchReaderFactories, Builder, Cursor, Spine, SpineSnapshot,
35        merge_batches, spine_async::WithSnapshot,
36    },
37    utils::Tup2,
38};
39use async_stream::stream;
40use dyn_clone::{DynClone, clone_box};
41use futures::Stream as AsyncStream;
42use num::Bounded;
43use std::{
44    borrow::Cow,
45    cell::{Cell, RefCell},
46    marker::PhantomData,
47    ops::{Deref, Div, Neg},
48    rc::Rc,
49};
50
51use super::radix_tree::{FilePartitionedRadixTreeFactories, Prefix};
52
53pub trait WeighFunc<V: ?Sized, R: ?Sized, A: ?Sized>: Fn(&V, &R, &mut A) + DynClone {}
54
55impl<V: ?Sized, R: ?Sized, A: ?Sized, F> WeighFunc<V, R, A> for F where F: Fn(&V, &R, &mut A) + Clone
56{}
57
58dyn_clone::clone_trait_object! {<V: ?Sized, R: ?Sized, A: ?Sized> WeighFunc<V, R, A>}
59
60pub trait PartitionFunc<IV: ?Sized, PK: ?Sized, OV: ?Sized>:
61    Fn(&IV, &mut PK, &mut OV) + DynClone
62{
63}
64
65impl<IV: ?Sized, PK: ?Sized, OV: ?Sized, F> PartitionFunc<IV, PK, OV> for F where
66    F: Fn(&IV, &mut PK, &mut OV) + Clone
67{
68}
69
70pub type OrdPartitionedOverStream<PK, TS, A> =
71    Stream<RootCircuit, OrdPartitionedIndexedZSet<PK, TS, DynOpt<A>>>;
72
73pub struct PartitionedRollingAggregateFactories<TS, V, Acc, Out, B, O>
74where
75    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
76    O: IndexedZSet<Key = B::Key>,
77    Acc: DataTrait + ?Sized,
78    Out: DataTrait + ?Sized,
79    TS: DBData + UnsignedPrimInt,
80    V: DataTrait + ?Sized,
81{
82    input_factories: B::Factories,
83    radix_tree_factories: FilePartitionedRadixTreeFactories<B::Key, TS, Acc>,
84    partitioned_tree_aggregate_factories: OrdPartitionedTreeAggregateFactories<TS, V, B, Acc>,
85    output_factories: O::Factories,
86    phantom: PhantomData<fn(&Out)>,
87}
88
89impl<TS, V, Acc, Out, B, O> PartitionedRollingAggregateFactories<TS, V, Acc, Out, B, O>
90where
91    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
92    O: PartitionedIndexedZSet<DynDataTyped<TS>, DynOpt<Out>, Key = B::Key>,
93    Acc: DataTrait + ?Sized,
94    Out: DataTrait + ?Sized,
95    TS: DBData + UnsignedPrimInt,
96    V: DataTrait + ?Sized,
97{
98    pub fn new<KType, VType, AType, OType>() -> Self
99    where
100        KType: DBData + Erase<B::Key>,
101        VType: DBData + Erase<V>,
102        AType: DBData + Erase<Acc>,
103        OType: DBData + Erase<Out>,
104    {
105        Self {
106            input_factories: BatchReaderFactories::new::<KType, Tup2<TS, VType>, ZWeight>(),
107            radix_tree_factories: BatchReaderFactories::new::<
108                KType,
109                Tup2<Prefix<TS>, TreeNode<TS, AType>>,
110                ZWeight,
111            >(),
112            partitioned_tree_aggregate_factories: OrdPartitionedTreeAggregateFactories::new::<
113                KType,
114                VType,
115                AType,
116            >(),
117            output_factories: BatchReaderFactories::new::<KType, Tup2<TS, Option<OType>>, ZWeight>(
118            ),
119            phantom: PhantomData,
120        }
121    }
122}
123
124pub struct PartitionedRollingAggregateWithWaterlineFactories<PK, TS, V, Acc, Out, B>
125where
126    PK: DataTrait + ?Sized,
127    B: IndexedZSet,
128    TS: DBData + UnsignedPrimInt + Erase<B::Key>,
129    Acc: DataTrait + ?Sized,
130    Out: DataTrait + ?Sized,
131    V: DataTrait + ?Sized,
132{
133    input_factories: B::Factories,
134    rolling_aggregate_factories: PartitionedRollingAggregateFactories<
135        TS,
136        V,
137        Acc,
138        Out,
139        OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, V>,
140        OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, DynOpt<Out>>,
141    >,
142}
143
144impl<PK, TS, V, Acc, Out, B>
145    PartitionedRollingAggregateWithWaterlineFactories<PK, TS, V, Acc, Out, B>
146where
147    PK: DataTrait + ?Sized,
148    B: IndexedZSet,
149    TS: DBData + UnsignedPrimInt + Erase<B::Key>,
150    Acc: DataTrait + ?Sized,
151    Out: DataTrait + ?Sized,
152    V: DataTrait + ?Sized,
153{
154    pub fn new<VType, PKType, PVType, AType, OType>() -> Self
155    where
156        VType: DBData + Erase<B::Val>,
157        PKType: DBData + Erase<PK>,
158        PVType: DBData + Erase<V>,
159        AType: DBData + Erase<Acc>,
160        OType: DBData + Erase<Out>,
161    {
162        Self {
163            input_factories: BatchReaderFactories::new::<TS, VType, ZWeight>(),
164            rolling_aggregate_factories: PartitionedRollingAggregateFactories::new::<
165                PKType,
166                PVType,
167                AType,
168                OType,
169            >(),
170        }
171    }
172}
173
174pub struct PartitionedRollingAggregateLinearFactories<TS, V, OV, A, B, O>
175where
176    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
177    O: IndexedZSet<Key = B::Key>,
178    TS: DBData + UnsignedPrimInt,
179    V: DataTrait + ?Sized,
180    OV: DataTrait + ?Sized,
181    A: WeightTrait + ?Sized,
182{
183    aggregate_factory: &'static dyn Factory<A>,
184    opt_accumulator_factory: &'static dyn Factory<DynOpt<A>>,
185    output_factory: &'static dyn Factory<OV>,
186    rolling_aggregate_factories: PartitionedRollingAggregateFactories<TS, V, A, OV, B, O>,
187}
188
189impl<TS, V, OV, A, B, O> PartitionedRollingAggregateLinearFactories<TS, V, OV, A, B, O>
190where
191    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
192    O: PartitionedIndexedZSet<DynDataTyped<TS>, DynOpt<OV>, Key = B::Key>,
193    B::Key: DataTrait,
194    TS: DBData + UnsignedPrimInt,
195    V: DataTrait + ?Sized,
196    OV: DataTrait + ?Sized,
197    A: WeightTrait + ?Sized,
198{
199    pub fn new<KType, VType, AType, OVType>() -> Self
200    where
201        KType: DBData + Erase<B::Key>,
202        VType: DBData + Erase<V>,
203        AType: DBWeight + Erase<A>,
204        OVType: DBData + Erase<OV>,
205    {
206        Self {
207            aggregate_factory: WithFactory::<AType>::FACTORY,
208            opt_accumulator_factory: WithFactory::<Option<AType>>::FACTORY,
209            output_factory: WithFactory::<OVType>::FACTORY,
210            rolling_aggregate_factories: PartitionedRollingAggregateFactories::new::<
211                KType,
212                VType,
213                AType,
214                OVType,
215            >(),
216        }
217    }
218}
219
220pub struct PartitionedRollingAverageFactories<TS, V, W, B, O>
221where
222    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
223    O: IndexedZSet<Key = B::Key>,
224    TS: DBData + UnsignedPrimInt,
225    V: DataTrait + ?Sized,
226    W: WeightTrait + ?Sized,
227{
228    aggregate_factories:
229        PartitionedRollingAggregateLinearFactories<TS, V, V, DynAverage<W, B::R>, B, O>,
230    weight_factory: &'static dyn Factory<W>,
231}
232
233impl<TS, V, W, B, O> PartitionedRollingAverageFactories<TS, V, W, B, O>
234where
235    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
236    O: PartitionedIndexedZSet<DynDataTyped<TS>, DynOpt<V>, Key = B::Key>,
237    TS: DBData + UnsignedPrimInt,
238    V: DataTrait + ?Sized,
239    W: WeightTrait + ?Sized,
240{
241    pub fn new<KType, VType, WType>() -> Self
242    where
243        KType: DBData + Erase<B::Key>,
244        VType: DBData + Erase<V>,
245        WType: DBWeight + From<ZWeight> + Div<Output = WType> + Erase<W>,
246    {
247        Self {
248            aggregate_factories: PartitionedRollingAggregateLinearFactories::new::<
249                KType,
250                VType,
251                Avg<WType, ZWeight>,
252                VType,
253            >(),
254            weight_factory: WithFactory::<WType>::FACTORY,
255        }
256    }
257}
258
259/// `Aggregator` object that computes a linear aggregation function.
260// TODO: we need this because we currently compute linear aggregates
261// using the same algorithm as general aggregates.  Additional performance
262// gains can be obtained with an optimized implementation of radix trees
263// for linear aggregates (specifically, updating a node when only
264// some of its children have changed can be done without computing
265// the sum of all children from scratch).
266struct LinearAggregator<V, A, O>
267where
268    V: DataTrait + ?Sized,
269    A: WeightTrait + ?Sized,
270    O: DataTrait + ?Sized,
271{
272    acc_factory: &'static dyn Factory<A>,
273    opt_accumulator_factory: &'static dyn Factory<DynOpt<A>>,
274    output_factory: &'static dyn Factory<O>,
275    f: Box<dyn WeighFunc<V, DynZWeight, A>>,
276    output_func: Box<dyn AggOutputFunc<A, O>>,
277    combine: Box<dyn AggCombineFunc<A>>,
278}
279
280impl<V, A, O> Clone for LinearAggregator<V, A, O>
281where
282    V: DataTrait + ?Sized,
283    A: WeightTrait + ?Sized,
284    O: DataTrait + ?Sized,
285{
286    fn clone(&self) -> Self {
287        Self {
288            acc_factory: self.acc_factory,
289            opt_accumulator_factory: self.opt_accumulator_factory,
290            output_factory: self.output_factory,
291            f: clone_box(self.f.as_ref()),
292            output_func: clone_box(self.output_func.as_ref()),
293            combine: clone_box(self.combine.as_ref()),
294        }
295    }
296}
297
298impl<V, A, O> LinearAggregator<V, A, O>
299where
300    V: DataTrait + ?Sized,
301    A: WeightTrait + ?Sized,
302    O: DataTrait + ?Sized,
303{
304    fn new(
305        acc_factory: &'static dyn Factory<A>,
306        opt_accumulator_factory: &'static dyn Factory<DynOpt<A>>,
307        output_factory: &'static dyn Factory<O>,
308        f: Box<dyn WeighFunc<V, DynZWeight, A>>,
309        output_func: Box<dyn AggOutputFunc<A, O>>,
310    ) -> Self {
311        Self {
312            acc_factory,
313            opt_accumulator_factory,
314            output_factory,
315            f,
316            output_func,
317            combine: Box::new(|acc, v| acc.add_assign(v)),
318        }
319    }
320}
321
322impl<V, A, O> DynAggregator<V, (), DynZWeight> for LinearAggregator<V, A, O>
323where
324    V: DataTrait + ?Sized,
325    A: WeightTrait + ?Sized,
326    O: DataTrait + ?Sized,
327{
328    type Accumulator = A;
329    type Output = O;
330
331    fn combine(&self) -> &dyn AggCombineFunc<A> {
332        self.combine.as_ref()
333    }
334
335    fn aggregate(&self, cursor: &mut dyn Cursor<V, DynUnit, (), DynZWeight>, agg: &mut DynOpt<A>) {
336        agg.set_none();
337        while cursor.key_valid() {
338            self.acc_factory.with(&mut |tmp_agg| {
339                let w = *cursor.weight().deref();
340                (self.f)(cursor.key(), w.erase(), tmp_agg);
341                match agg.get_mut() {
342                    None => agg.from_val(tmp_agg),
343                    Some(old) => old.add_assign(tmp_agg),
344                };
345            });
346            cursor.step_key();
347        }
348    }
349
350    fn finalize(&self, accumulator: &mut A, output: &mut O) {
351        (self.output_func)(accumulator, output)
352    }
353
354    fn aggregate_and_finalize(
355        &self,
356        _cursor: &mut dyn Cursor<V, DynUnit, (), DynZWeight>,
357        _output: &mut DynOpt<Self::Output>,
358    ) {
359        todo!()
360    }
361
362    fn opt_accumulator_factory(&self) -> &'static dyn Factory<DynOpt<Self::Accumulator>> {
363        self.opt_accumulator_factory
364    }
365
366    fn output_factory(&self) -> &'static dyn Factory<Self::Output> {
367        self.output_factory
368    }
369}
370
371impl<B> Stream<RootCircuit, B>
372where
373    B: IndexedZSet,
374{
375    /// See [`Stream::partitioned_rolling_aggregate_with_waterline`].
376    pub fn dyn_partitioned_rolling_aggregate_with_waterline<PK, TS, V, Acc, Out>(
377        &self,
378        persistent_id: Option<&str>,
379        factories: &PartitionedRollingAggregateWithWaterlineFactories<PK, TS, V, Acc, Out, B>,
380        waterline: &Stream<RootCircuit, Box<DynDataTyped<TS>>>,
381        partition_func: Box<dyn PartitionFunc<B::Val, PK, V>>,
382        aggregator: &dyn DynAggregator<V, (), B::R, Accumulator = Acc, Output = Out>,
383        range: RelRange<TS>,
384    ) -> OrdPartitionedOverStream<PK, DynDataTyped<TS>, Out>
385    where
386        B: IndexedZSet,
387        B: for<'a> DynFilterMap<
388            DynItemRef<'a> = (&'a <B as BatchReader>::Key, &'a <B as BatchReader>::Val),
389        >,
390        Box<B::Key>: Clone,
391        PK: DataTrait + ?Sized,
392        TS: DBData + UnsignedPrimInt + Erase<B::Key>,
393        V: DataTrait + ?Sized,
394        Acc: DataTrait + ?Sized,
395        Out: DataTrait + ?Sized,
396    {
397        self.circuit()
398            .region("partitioned_rolling_aggregate_with_waterline", || {
399                // Shift the aggregation window so that its right end is at 0.
400                let shifted_range =
401                    RelRange::new(range.from - range.to, RelOffset::Before(HasZero::zero()));
402
403                // Trace bound used inside `partitioned_rolling_aggregate_inner` to
404                // bound its output trace.  This is the same bound we use to construct
405                // the input window here.
406                let bound: TraceBound<DynPair<DynDataTyped<TS>, DynOpt<Out>>> = TraceBound::new();
407                let bound_clone = bound.clone();
408
409                let mut bound_box = factories
410                    .rolling_aggregate_factories
411                    .output_factories
412                    .val_factory()
413                    .default_box();
414
415                // Restrict the input stream to the `[lb -> ∞)` time window,
416                // where `lb = waterline - (range.to - range.from)` is the lower
417                // bound on input timestamps that may be used to compute
418                // changes to the rolling aggregate operator.
419                let bounds = waterline.apply_mut(move |wm| {
420                    let lower = shifted_range
421                        .range_of(wm.as_ref().deref())
422                        .map(|range| range.from)
423                        .unwrap_or_else(|| Bounded::min_value());
424                    **bound_box.fst_mut() = lower;
425                    bound_box.snd_mut().set_none();
426                    bound_clone.set(clone_box(bound_box.as_ref()));
427                    (
428                        Box::new(lower).erase_box(),
429                        Box::new(<TS as Bounded>::max_value()).erase_box(),
430                    )
431                });
432                let window = self
433                    .dyn_window(&factories.input_factories, (true, true), &bounds)
434                    .set_persistent_id(
435                        persistent_id
436                            .map(|name| format!("{name}.window"))
437                            .as_deref(),
438                    );
439
440                // Now that we've truncated old inputs, which required the
441                // input stream to be indexed by time, we can re-index it
442                // by partition id.
443                let partition_func_clone = clone_box(partition_func.as_ref());
444
445                let partitioned_window = window
446                    .dyn_map_index(
447                        &factories.rolling_aggregate_factories.input_factories,
448                        Box::new(move |(ts, v), res| {
449                            let (partition_key, ts_val) = res.split_mut();
450                            let (res_ts, val) = ts_val.split_mut();
451                            partition_func_clone(v, partition_key, val);
452                            unsafe { *res_ts.downcast_mut::<TS>() = *ts.downcast::<TS>() };
453                        }),
454                    )
455                    .set_persistent_id(
456                        persistent_id
457                            .map(|name| format!("{name}-partitioned_window"))
458                            .as_deref(),
459                    );
460                let partitioned_self = self
461                    .dyn_map_index(
462                        &factories.rolling_aggregate_factories.input_factories,
463                        Box::new(move |(ts, v), res| {
464                            let (partition_key, ts_val) = res.split_mut();
465                            let (res_ts, val) = ts_val.split_mut();
466                            partition_func(v, partition_key, val);
467                            unsafe { *res_ts.downcast_mut::<TS>() = *ts.downcast::<TS>() };
468                        }),
469                    )
470                    .set_persistent_id(
471                        persistent_id
472                            .map(|name| format!("{name}-partitioned"))
473                            .as_deref(),
474                    );
475
476                partitioned_self.dyn_partitioned_rolling_aggregate_inner(
477                    persistent_id,
478                    &factories.rolling_aggregate_factories,
479                    &partitioned_window,
480                    aggregator,
481                    range,
482                    bound,
483                )
484            })
485    }
486
487    /// Like [`Self::dyn_partitioned_rolling_aggregate`], but can return any
488    /// batch type.
489    pub fn dyn_partitioned_rolling_aggregate<PK, TS, V, Acc, Out>(
490        &self,
491        persistent_id: Option<&str>,
492        factories: &PartitionedRollingAggregateFactories<
493            TS,
494            V,
495            Acc,
496            Out,
497            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, V>,
498            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, DynOpt<Out>>,
499        >,
500        partition_func: Box<dyn PartitionFunc<B::Val, PK, V>>,
501        aggregator: &dyn DynAggregator<V, (), B::R, Accumulator = Acc, Output = Out>,
502        range: RelRange<TS>,
503    ) -> OrdPartitionedOverStream<PK, DynDataTyped<TS>, Out>
504    where
505        B: IndexedZSet,
506        B: for<'a> DynFilterMap<
507            DynItemRef<'a> = (&'a <B as BatchReader>::Key, &'a <B as BatchReader>::Val),
508        >,
509        Acc: DataTrait + ?Sized,
510        Out: DataTrait + ?Sized,
511        PK: DataTrait + ?Sized,
512        TS: DBData + UnsignedPrimInt + Erase<B::Key>,
513        V: DataTrait + ?Sized,
514    {
515        // ```
516        //                  ┌───────────────┐   input_trace
517        //      ┌──────────►│integrate_trace├──────────────┐                              output
518        //      │           └───────────────┘              │                           ┌────────────────────────────────────►
519        //      │                                          ▼                           │
520        // self │    ┌──────────────────────────┐  tree  ┌───────────────────────────┐ │  ┌──────────────────┐ output_trace
521        // ─────┼───►│partitioned_tree_aggregate├───────►│PartitionedRollingAggregate├─┴──┤UntimedTraceAppend├────────┐
522        //      │    └──────────────────────────┘        └───────────────────────────┘    └──────────────────┘        │
523        //      │                                          ▲               ▲                 ▲                        │
524        //      └──────────────────────────────────────────┘               │                 │                        │
525        //                                                                 │               ┌─┴──┐                     │
526        //                                                                 └───────────────┤Z^-1│◄────────────────────┘
527        //                                                                   delayed_trace └────┘
528        // ```
529        self.circuit().region("partitioned_rolling_aggregate", || {
530            let partitioned = self
531                .dyn_map_index(
532                    &factories.input_factories,
533                    Box::new(move |(ts, v), res| {
534                        let (partition_key, ts_val) = res.split_mut();
535                        let (res_ts, val) = ts_val.split_mut();
536                        partition_func(v, partition_key, val);
537                        unsafe { *res_ts.downcast_mut::<TS>() = *ts.downcast::<TS>() };
538                    }),
539                )
540                .set_persistent_id(
541                    persistent_id
542                        .map(|name| format!("{name}-partitioned"))
543                        .as_deref(),
544                );
545
546            partitioned.dyn_partitioned_rolling_aggregate_inner(
547                persistent_id,
548                factories,
549                &partitioned,
550                aggregator,
551                range,
552                TraceBound::new(),
553            )
554        })
555    }
556
557    /// See [`Stream::partitioned_rolling_aggregate_linear`].
558    pub fn dyn_partitioned_rolling_aggregate_linear<PK, TS, V, A, O>(
559        &self,
560        persistent_id: Option<&str>,
561        factories: &PartitionedRollingAggregateLinearFactories<
562            TS,
563            V,
564            O,
565            A,
566            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, V>,
567            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, DynOpt<O>>,
568        >,
569        partition_func: Box<dyn PartitionFunc<B::Val, PK, V>>,
570        f: Box<dyn WeighFunc<V, B::R, A>>,
571        output_func: Box<dyn AggOutputFunc<A, O>>,
572        range: RelRange<TS>,
573    ) -> OrdPartitionedOverStream<PK, DynDataTyped<TS>, O>
574    where
575        B: IndexedZSet,
576        B: for<'a> DynFilterMap<
577            DynItemRef<'a> = (&'a <B as BatchReader>::Key, &'a <B as BatchReader>::Val),
578        >,
579        PK: DataTrait + ?Sized,
580        TS: DBData + UnsignedPrimInt + Erase<B::Key>,
581        V: DataTrait + ?Sized,
582        A: WeightTrait + ?Sized,
583        O: DataTrait + ?Sized,
584    {
585        let aggregator = LinearAggregator::new(
586            factories.aggregate_factory,
587            factories.opt_accumulator_factory,
588            factories.output_factory,
589            f,
590            output_func,
591        );
592        self.dyn_partitioned_rolling_aggregate::<PK, TS, V, _, _>(
593            persistent_id,
594            &factories.rolling_aggregate_factories,
595            partition_func,
596            &aggregator,
597            range,
598        )
599    }
600
601    pub fn dyn_partitioned_rolling_average<PK, TS, V, W>(
602        &self,
603        persistent_id: Option<&str>,
604        factories: &PartitionedRollingAverageFactories<
605            TS,
606            V,
607            W,
608            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, V>,
609            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, DynOpt<V>>,
610        >,
611        partition_func: Box<dyn PartitionFunc<B::Val, PK, V>>,
612        f: Box<dyn WeighFunc<V, B::R, W>>,
613        out_func: Box<dyn AggOutputFunc<W, V>>,
614        range: RelRange<TS>,
615    ) -> OrdPartitionedOverStream<PK, DynDataTyped<TS>, V>
616    where
617        B: IndexedZSet,
618        B: for<'a> DynFilterMap<
619            DynItemRef<'a> = (&'a <B as BatchReader>::Key, &'a <B as BatchReader>::Val),
620        >,
621        PK: DataTrait + ?Sized,
622        TS: DBData + UnsignedPrimInt + Erase<B::Key>,
623        V: DataTrait + ?Sized,
624        W: WeightTrait + ?Sized,
625    {
626        let weight_factory = factories.weight_factory;
627        self.dyn_partitioned_rolling_aggregate_linear(
628            persistent_id,
629            &factories.aggregate_factories,
630            partition_func,
631            Box::new(move |v: &V, w: &B::R, avg: &mut DynAverage<W, B::R>| {
632                let (sum, count) = avg.split_mut();
633                w.clone_to(count);
634                f(v, w, sum);
635            }),
636            Box::new(move |avg, out| {
637                weight_factory.with(&mut |avg_val| {
638                    avg.compute_avg(avg_val);
639                    out_func(avg_val, out)
640                })
641            }),
642            range,
643        )
644    }
645}
646
647impl<B> Stream<RootCircuit, B> {
648    #[doc(hidden)]
649    pub fn dyn_partitioned_rolling_aggregate_inner<TS, V, Acc, Out, O>(
650        &self,
651        partition_id: Option<&str>,
652        factories: &PartitionedRollingAggregateFactories<TS, V, Acc, Out, B, O>,
653        self_window: &Self,
654        aggregator: &dyn DynAggregator<V, (), DynZWeight, Accumulator = Acc, Output = Out>,
655        range: RelRange<TS>,
656        bound: TraceBound<DynPair<DynDataTyped<TS>, DynOpt<Out>>>,
657    ) -> Stream<RootCircuit, O>
658    where
659        B: PartitionedIndexedZSet<DynDataTyped<TS>, V> + Send,
660        O: PartitionedIndexedZSet<DynDataTyped<TS>, DynOpt<Out>, Key = B::Key>,
661        Acc: DataTrait + ?Sized,
662        Out: DataTrait + ?Sized,
663        TS: DBData + UnsignedPrimInt,
664        V: DataTrait + ?Sized,
665    {
666        let circuit = self.circuit();
667
668        let partitioned_tree_aggregate_name =
669            partition_id.map(|name| format!("{name}-tree_aggregate"));
670
671        // Build the radix tree over the bounded window.
672        let tree = self_window
673            .partitioned_tree_aggregate::<TS, V, Acc, Out>(
674                partitioned_tree_aggregate_name.as_deref(),
675                &factories.partitioned_tree_aggregate_factories,
676                aggregator,
677            )
678            .set_persistent_id(partitioned_tree_aggregate_name.as_deref())
679            .dyn_accumulate_integrate_trace(&factories.radix_tree_factories);
680
681        let input_trace =
682            self_window.dyn_shard_accumulate_integrate_trace(&factories.input_factories);
683
684        // Truncate timestamps `< bound` in the output trace.
685        let bounds = TraceBounds::new();
686        bounds.add_key_bound(TraceBound::new());
687        bounds.add_val_bound(bound);
688
689        let feedback = circuit.add_accumulate_integrate_trace_feedback::<Spine<O>>(
690            partition_id,
691            &factories.output_factories,
692            bounds,
693        );
694
695        let output = circuit
696            .add_quaternary_operator(
697                StreamingQuaternaryWrapper::new(
698                    <PartitionedRollingAggregate<TS, B, V, Acc, Out, _>>::new(
699                        &factories.output_factories,
700                        range,
701                        aggregator,
702                    ),
703                ),
704                &self.dyn_shard_accumulate(&factories.input_factories),
705                &input_trace,
706                &tree,
707                &feedback.delayed_trace,
708            )
709            .mark_distinct()
710            .mark_sharded();
711
712        feedback.connect(&output, &factories.output_factories);
713
714        output
715    }
716}
717
718/// Quaternary operator that implements the internals of
719/// `partitioned_rolling_aggregate`.
720///
721/// * Input stream 1: updates to the time series.  Used to identify affected
722///   partitions and times.
723/// * Input stream 2: trace containing the accumulated time series data.
724/// * Input stream 3: trace containing the partitioned radix tree over the input
725///   time series.
726/// * Input stream 4: trace of previously produced outputs.  Used to compute
727///   retractions.
728struct PartitionedRollingAggregate<
729    TS: DBData,
730    B: PartitionedBatch<DynDataTyped<TS>, V, R = DynZWeight>,
731    V: DataTrait + ?Sized,
732    Acc: DataTrait + ?Sized,
733    Out: DataTrait + ?Sized,
734    O: Batch,
735> {
736    output_factories: O::Factories,
737    range: RelRange<TS>,
738    aggregator: Box<dyn DynAggregator<V, (), DynZWeight, Accumulator = Acc, Output = Out>>,
739    flush: Cell<bool>,
740    input_delta: RefCell<Option<SpineSnapshot<B>>>,
741
742    // Input batch sizes.
743    input_batch_stats: RefCell<BatchSizeStats>,
744
745    // Output batch sizes.
746    output_batch_stats: RefCell<BatchSizeStats>,
747
748    phantom: PhantomData<fn(&V, &O)>,
749}
750
751impl<TS, B, V, Acc, Out, O> PartitionedRollingAggregate<TS, B, V, Acc, Out, O>
752where
753    TS: DBData,
754    B: PartitionedBatch<DynDataTyped<TS>, V, R = DynZWeight>,
755    V: DataTrait + ?Sized,
756    Acc: DataTrait + ?Sized,
757    Out: DataTrait + ?Sized,
758    O: Batch,
759{
760    fn new(
761        output_factories: &O::Factories,
762        range: RelRange<TS>,
763        aggregator: &dyn DynAggregator<V, (), DynZWeight, Accumulator = Acc, Output = Out>,
764    ) -> Self {
765        Self {
766            output_factories: output_factories.clone(),
767            range,
768            aggregator: clone_box(aggregator),
769            flush: Cell::new(false),
770            input_delta: RefCell::new(None),
771            input_batch_stats: RefCell::new(BatchSizeStats::new()),
772            output_batch_stats: RefCell::new(BatchSizeStats::new()),
773
774            phantom: PhantomData,
775        }
776    }
777
778    fn affected_ranges<R, C>(&self, delta_cursor: &mut C) -> Ranges<TS>
779    where
780        C: Cursor<DynDataTyped<TS>, V, (), R>,
781        TS: DBData + UnsignedPrimInt,
782        R: ?Sized,
783    {
784        let mut affected_ranges = Ranges::new();
785        let mut delta_ranges = Ranges::new();
786
787        while delta_cursor.key_valid() {
788            if let Some(range) = self.range.affected_range_of(delta_cursor.key().deref()) {
789                affected_ranges.push_monotonic(range);
790            }
791            // If `delta_cursor.key()` is a new key that doesn't yet occur in the input
792            // z-set, we need to compute its aggregate even if it is outside
793            // affected range.
794            delta_ranges.push_monotonic(Range::new(**delta_cursor.key(), **delta_cursor.key()));
795            delta_cursor.step_key();
796        }
797
798        affected_ranges.merge(&delta_ranges)
799    }
800}
801
802impl<TS, B, V, Acc, Out, O> Operator for PartitionedRollingAggregate<TS, B, V, Acc, Out, O>
803where
804    TS: DBData,
805    B: PartitionedBatch<DynDataTyped<TS>, V, R = DynZWeight>,
806    V: DataTrait + ?Sized,
807    Acc: DataTrait + ?Sized,
808    Out: DataTrait + ?Sized,
809    O: Batch,
810{
811    fn name(&self) -> Cow<'static, str> {
812        Cow::from("PartitionedRollingAggregate")
813    }
814
815    fn metadata(&self, meta: &mut OperatorMeta) {
816        meta.extend(metadata! {
817            INPUT_BATCHES_STATS => self.input_batch_stats.borrow().metadata(),
818            OUTPUT_BATCHES_STATS => self.output_batch_stats.borrow().metadata(),
819        });
820    }
821
822    fn fixedpoint(&self, _scope: Scope) -> bool {
823        true
824    }
825
826    fn flush(&mut self) {
827        self.flush.set(true);
828    }
829}
830
/// Streaming evaluation of the partitioned rolling aggregate.
///
/// The operator consumes four inputs: an optional delta of the input
/// collection, a trace of the input, a radix tree over the input, and a trace
/// of the operator's own previous outputs.
impl<TS, V, Acc, Out, B, T, RT, OT, O>
    StreamingQuaternaryOperator<Option<Spine<B>>, Spine<T>, Spine<RT>, Spine<OT>, O>
    for PartitionedRollingAggregate<TS, B, V, Acc, Out, O>
where
    TS: DBData + UnsignedPrimInt,
    V: DataTrait + ?Sized,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    B: PartitionedBatch<DynDataTyped<TS>, V, R = DynZWeight>,
    T: PartitionedBatch<DynDataTyped<TS>, V, Key = B::Key, R = B::R> + Clone,
    RT: PartitionedRadixTreeBatch<TS, Acc, Key = B::Key> + Clone,
    OT: PartitionedBatch<DynDataTyped<TS>, DynOpt<Out>, Key = B::Key, R = B::R> + Clone,
    O: IndexedZSet<Key = B::Key, Val = DynPair<DynDataTyped<TS>, DynOpt<Out>>>,
{
    /// Buffers `input_delta` and, once the `flush` flag is set, produces the
    /// output delta as a stream of chunks.
    ///
    /// While `flush` is not set, the returned stream yields a single empty
    /// batch and nothing else. When `flush` is set, for every partition key in
    /// the buffered delta the stream:
    ///
    /// 1. computes the affected timestamp ranges from the delta;
    /// 2. emits retractions (negated weights) for all previously produced
    ///    outputs of that partition that fall into those ranges;
    /// 3. emits a freshly computed aggregate, with weight 1, for every
    ///    timestamp of the input trace in those ranges that has a value whose
    ///    weight is not `le0`.
    ///
    /// Each stream item is `(batch, flag, position)`. Intermediate chunks are
    /// yielded with `flag == false` once a builder reaches
    /// `splitter_output_chunk_size()` tuples; the last item carries
    /// `flag == true` (NOTE(review): presumably the "final chunk" marker --
    /// confirm against `StreamingQuaternaryOperator`). `position` is the
    /// delta cursor's position when the chunk is yielded.
    ///
    /// # Panics
    ///
    /// Panics if a new delta arrives while a previous one is still buffered,
    /// or if `flush` is set but no delta has been buffered.
    fn eval(
        self: Rc<Self>,
        input_delta: Cow<'_, Option<Spine<B>>>,
        input_trace: Cow<'_, Spine<T>>,
        radix_tree: Cow<'_, Spine<RT>>,
        output_trace: Cow<'_, Spine<OT>>,
    ) -> impl AsyncStream<Item = (O, bool, Option<Position>)> + 'static {
        let chunk_size = splitter_output_chunk_size();

        // Buffer a read-only snapshot of the delta until the flush arrives.
        // At most one delta may be buffered at a time.
        if let Some(input_delta) = input_delta.as_ref().as_ref() {
            assert!(self.input_delta.borrow().is_none());
            *self.input_delta.borrow_mut() = Some(input_delta.ro_snapshot());
        };

        // The three traces are only snapshotted when this call is going to
        // produce real output, i.e. when `flush` is set.
        let input_trace = if self.flush.get() {
            Some(input_trace.as_ref().ro_snapshot())
        } else {
            None
        };

        let radix_tree = if self.flush.get() {
            Some(radix_tree.as_ref().ro_snapshot())
        } else {
            None
        };

        let output_trace = if self.flush.get() {
            Some(output_trace.as_ref().ro_snapshot())
        } else {
            None
        };

        stream! {
            // Nothing to compute before the flush: emit one empty batch.
            if !self.flush.get() {
                yield (O::dyn_empty(&self.output_factories), true, None);
                return;
            }

            // Take ownership of the buffered delta, leaving the slot empty for
            // the next round.
            let input_delta = self.input_delta.borrow_mut().take().unwrap();
            self.input_batch_stats.borrow_mut().add_batch(input_delta.len());

            let mut delta_cursor = input_delta.cursor();
            let mut output_trace_cursor = output_trace.unwrap().cursor();
            let mut input_trace_cursor = input_trace.unwrap().cursor();
            let mut tree_cursor = radix_tree.unwrap().cursor();

            // Limit the initial capacity of the builder in case the chunk size
            // is bigger than memory (e.g. `usize::MAX`).
            let capacity = splitter_output_first_chunk_size();
            // Retractions and insertions are accumulated in separate builders
            // and only merged for the final chunk.
            let mut retraction_builder =
                O::Builder::with_capacity(&self.output_factories, capacity, capacity);
            let mut insertion_builder =
                O::Builder::with_capacity(&self.output_factories, capacity, capacity);

            // Scratch values reused across iterations: the output
            // `(timestamp, aggregate)` pair, the optional accumulator and the
            // finalized aggregate.
            let mut val = self.output_factories.val_factory().default_box();
            let mut acc = self.aggregator.opt_accumulator_factory().default_box();
            let mut agg = self.aggregator.output_factory().default_box();

            // Iterate over affected partitions.
            while delta_cursor.key_valid() {
                // Compute affected intervals using `input_delta`.
                let ranges = self.affected_ranges(&mut PartitionCursor::new(&mut delta_cursor));

                // Clear old outputs.
                let hash = delta_cursor.key().default_hash();
                if output_trace_cursor.seek_key_exact(delta_cursor.key(), Some(hash)) {
                    let mut range_cursor = RangeCursor::new(
                        PartitionCursor::new(&mut output_trace_cursor),
                        ranges.clone(),
                    );
                    // Tracks whether the current builder holds values for this
                    // partition that still need their key pushed.
                    let mut any_values = false;
                    while range_cursor.key_valid() {
                        while range_cursor.val_valid() {
                            let weight = **range_cursor.weight();
                            debug_assert!(weight != 0);
                            // Retract the old output by re-emitting it with the
                            // negated weight.
                            val.from_refs(range_cursor.key(), range_cursor.val());
                            retraction_builder.push_val_diff_mut(&mut *val, &mut weight.neg());
                            any_values = true;

                            // Chunk is full: close the current key, emit the
                            // chunk and start a new builder.
                            if retraction_builder.num_tuples() >= chunk_size {
                                retraction_builder.push_key(delta_cursor.key());
                                let result = retraction_builder.done();
                                self.output_batch_stats.borrow_mut().add_batch(result.len());
                                yield (result, false, delta_cursor.position());
                                any_values = false;
                                retraction_builder = O::Builder::with_capacity(&self.output_factories, chunk_size, chunk_size);
                            }
                            range_cursor.step_val();
                        }
                        range_cursor.step_key();
                    }
                    if any_values {
                        retraction_builder.push_key(delta_cursor.key());
                    }
                };

                // Compute new outputs.
                if input_trace_cursor.seek_key_exact(delta_cursor.key(), Some(hash))
                    // It's possible that the key is in the input trace with weight 0, but it's no longer in the tree, which
                    // caused `test_empty_tree()` to fail without this check.
                    && tree_cursor.seek_key_exact(delta_cursor.key(), Some(hash))
                {
                    let mut tree_partition_cursor = PartitionCursor::new(&mut tree_cursor);
                    let mut input_range_cursor =
                        RangeCursor::new(PartitionCursor::new(&mut input_trace_cursor), ranges);

                    // For all affected times, seek them in `input_trace` and compute aggregates
                    // using `radix_tree`.
                    let mut any_values = false;
                    while input_range_cursor.key_valid() {
                        // Window to aggregate over for this timestamp; `None`
                        // results in a `None` aggregate below.
                        let range = self.range.range_of(input_range_cursor.key());
                        tree_partition_cursor.rewind_keys();

                        while input_range_cursor.val_valid() {
                            // Generate output update.
                            if !input_range_cursor.weight().le0() {
                                **val.fst_mut() = **input_range_cursor.key();
                                if let Some(range) = range {
                                    tree_partition_cursor.aggregate_range(
                                        &range,
                                        self.aggregator.combine(),
                                        acc.as_mut(),
                                    );
                                    // An empty accumulator yields a `None`
                                    // aggregate.
                                    if let Some(acc) = acc.get_mut() {
                                        self.aggregator.finalize(acc, agg.as_mut());
                                        val.snd_mut().from_val(agg.as_mut());
                                    } else {
                                        val.snd_mut().set_none();
                                    }
                                } else {
                                    val.snd_mut().set_none();
                                }

                                insertion_builder.push_val_diff_mut(&mut *val, 1.erase_mut());
                                any_values = true;

                                // Chunk is full: close the current key, emit
                                // the chunk and start a new builder.
                                if insertion_builder.num_tuples() >= chunk_size {
                                    insertion_builder.push_key(delta_cursor.key());
                                    any_values = false;

                                    let result = insertion_builder.done();
                                    self.output_batch_stats.borrow_mut().add_batch(result.len());

                                    yield (result, false, delta_cursor.position());
                                    insertion_builder =
                                        O::Builder::with_capacity(&self.output_factories, chunk_size, chunk_size);
                                }

                                // The output depends on the timestamp only, so
                                // one output per timestamp is enough.
                                break;
                            }

                            input_range_cursor.step_val();
                        }

                        input_range_cursor.step_key();
                    }
                    if any_values {
                        insertion_builder.push_key(delta_cursor.key());
                    }
                }

                delta_cursor.step_key();
            }

            // All partitions processed: clear the flush flag and emit whatever
            // is left in both builders as the final chunk.
            self.flush.set(false);
            let retractions = retraction_builder.done();
            let insertions = insertion_builder.done();

            let result = merge_batches(&insertions.factories(), [insertions,retractions], &None, &None);
            self.output_batch_stats.borrow_mut().add_batch(result.len());
            yield (result, true, delta_cursor.position());
        }
    }
}
1036
1037#[cfg(test)]
1038mod test {
1039    use crate::{
1040        DBData, DBSPHandle, IndexedZSetHandle, OrdIndexedZSet, OutputHandle, RootCircuit, Runtime,
1041        Stream, TypedBox, ZWeight,
1042        algebra::{DefaultSemigroup, UnsignedPrimInt},
1043        circuit::CircuitConfig,
1044        dynamic::{DowncastTrait, DynData, DynDataTyped, DynOpt, DynPair, Erase},
1045        lean_vec,
1046        operator::{
1047            Fold,
1048            dynamic::{
1049                input::AddInputIndexedZSetFactories,
1050                time_series::{
1051                    PartitionCursor,
1052                    range::{Range, RelOffset, RelRange},
1053                },
1054                trace::TraceBound,
1055            },
1056            time_series::OrdPartitionedIndexedZSet,
1057        },
1058        trace::{BatchReaderFactories, Cursor},
1059        typed_batch::{
1060            DynBatchReader, DynOrdIndexedZSet, IndexedZSetReader, SpineSnapshot, TypedBatch,
1061        },
1062        utils::Tup2,
1063    };
1064    use proptest::{collection, prelude::*};
1065    use size_of::SizeOf;
1066
    /// Dynamically typed input batch consumed by the reference
    /// implementation. Per the inline annotations, the key is a `u64` and the
    /// value is a `(u64, i64)` pair.
    type DataBatch = DynOrdIndexedZSet<
        DynData, /* <u64> */
        DynPair<DynDataTyped<u64>, DynData /* <i64> */>,
    >;
    /// Statically typed view of the reference implementation's output: `u64`
    /// keys mapped to `(u64, Option<i64>)` values with `ZWeight` weights,
    /// backed by a dynamically typed indexed z-set.
    type OutputBatch = TypedBatch<
        u64,
        Tup2<u64, Option<i64>>,
        ZWeight,
        DynOrdIndexedZSet<
            DynData, /* <u64> */
            DynPair<DynDataTyped<u64>, DynOpt<DynData /* <i64> */>>,
        >,
    >;
1080
1081    impl<PK, TS, V> Stream<RootCircuit, OrdIndexedZSet<PK, Tup2<TS, V>>>
1082    where
1083        PK: DBData,
1084        TS: DBData + UnsignedPrimInt,
1085        V: DBData,
1086    {
1087        pub fn as_partitioned_zset(
1088            &self,
1089        ) -> Stream<RootCircuit, OrdPartitionedIndexedZSet<PK, TS, DynDataTyped<TS>, V, DynData>>
1090        {
1091            let factories = BatchReaderFactories::new::<PK, Tup2<TS, V>, ZWeight>();
1092
1093            self.inner()
1094                .dyn_map_index(
1095                    &factories,
1096                    Box::new(|(k, v), kv| {
1097                        kv.from_refs(k, unsafe { v.downcast::<Tup2<TS, V>>().erase() })
1098                    }),
1099                )
1100                .typed()
1101        }
1102    }
1103
1104    // Reference implementation of `aggregate_range` for testing.
1105    fn aggregate_range_slow(batch: &DataBatch, partition: u64, range: Range<u64>) -> Option<i64> {
1106        let mut cursor = batch.cursor();
1107
1108        cursor.seek_key(&partition);
1109        assert!(cursor.key_valid());
1110        assert!(*cursor.key().downcast_checked::<u64>() == partition);
1111        let mut partition_cursor = PartitionCursor::new(&mut cursor);
1112
1113        let mut agg = None;
1114        partition_cursor.seek_key(&range.from);
1115        while partition_cursor.key_valid()
1116            && *partition_cursor.key().downcast_checked::<u64>() <= range.to
1117        {
1118            while partition_cursor.val_valid() {
1119                let w = *partition_cursor.weight().downcast_checked::<ZWeight>();
1120                debug_assert!(w != 0);
1121                agg = if let Some(a) = agg {
1122                    Some(a + *partition_cursor.val().downcast_checked::<i64>() * w)
1123                } else {
1124                    Some(*partition_cursor.val().downcast_checked::<i64>() * w)
1125                };
1126                partition_cursor.step_val();
1127            }
1128            partition_cursor.step_key();
1129        }
1130
1131        agg
1132    }
1133
1134    // Reference implementation of `partitioned_rolling_aggregate` for testing.
1135    fn partitioned_rolling_aggregate_slow(
1136        stream: &Stream<RootCircuit, DataBatch>,
1137        range_spec: RelRange<u64>,
1138    ) -> Stream<RootCircuit, OutputBatch> {
1139        let stream = stream.typed::<TypedBatch<u64, Tup2<u64, i64>, ZWeight, _>>();
1140
1141        stream
1142            .circuit()
1143            .non_incremental(&stream, |_child, stream| {
1144                Ok(stream
1145                    .gather(0)
1146                    .integrate()
1147                    .apply(move |batch: &TypedBatch<_, _, _, DataBatch>| {
1148                        let mut tuples = Vec::with_capacity(batch.len());
1149
1150                        let mut cursor = batch.cursor();
1151
1152                        while cursor.key_valid() {
1153                            while cursor.val_valid() {
1154                                let partition = *cursor.key().downcast_checked::<u64>();
1155                                let Tup2(ts, _val) =
1156                                    *cursor.val().downcast_checked::<Tup2<u64, i64>>();
1157                                let agg = range_spec.range_of(&ts).and_then(|range| {
1158                                    aggregate_range_slow(batch, partition, range)
1159                                });
1160                                tuples.push(Tup2(Tup2(partition, Tup2(ts, agg)), 1));
1161                                cursor.step_val();
1162                            }
1163                            cursor.step_key();
1164                        }
1165
1166                        OutputBatch::from_tuples((), tuples)
1167                    })
1168                    .stream_distinct()
1169                    .gather(0))
1170            })
1171            .unwrap()
1172    }
1173
    /// Output handle type shared by every output of the test circuit: spine
    /// snapshots of a partitioned indexed z-set with `u64` partition keys,
    /// `u64` timestamps and `Option<i64>` aggregates.
    type TestOutputHandle = OutputHandle<
        SpineSnapshot<
            OrdPartitionedIndexedZSet<u64, u64, DynDataTyped<u64>, Option<i64>, DynOpt<DynData>>,
        >,
    >;
1179
    /// Builds the circuit shared by the rolling-aggregate tests.
    ///
    /// The circuit takes one indexed z-set input of `(partition, (ts, value))`
    /// records and computes a sum aggregate over three window specifications:
    /// `[ts - 1000, ts]`, `[ts - 500, ts + 500]` and `[ts - 500, ts - 100]`.
    /// The first two windows are evaluated with the general, the
    /// waterline-based and the linear operator variants; the last one with the
    /// general variant only. Every output is paired with an "expected" output
    /// computed by `partitioned_rolling_aggregate_slow`.
    ///
    /// * `lateness` - subtracted (saturating) from timestamps to compute the
    ///   waterline.
    /// * `size_bound` - when `Some`, the integrated trace of the
    ///   `[ts - 500, ts + 500]` waterline output is asserted not to exceed
    ///   this many bytes.
    ///
    /// Returns the circuit handle together with the input handle followed by
    /// seven `(actual, expected)` pairs of output handles, in the order in
    /// which they are constructed below.
    fn partition_rolling_aggregate_circuit(
        lateness: u64,
        size_bound: Option<usize>,
    ) -> (
        DBSPHandle,
        (
            IndexedZSetHandle<u64, Tup2<u64, i64>>,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
        ),
    ) {
        // NOTE(review): `CircuitConfig::from(2)` presumably selects two
        // workers -- confirm. The splitter chunk size is set to 6 records.
        Runtime::init_circuit(
            CircuitConfig::from(2).with_splitter_chunk_size_records(6),
            move |circuit| {
                let (input_stream, input_handle) =
                    circuit.add_input_indexed_zset::<u64, Tup2<u64, i64>>();

                // Re-index the input by timestamp for the operators under test.
                let input_by_time = input_stream
                    .map_index(|(partition, Tup2(ts, val))| (*ts, Tup2(*partition, *val)));

                // Partition-indexed view of the input for the reference
                // implementation.
                let input_stream = input_stream.as_partitioned_zset();

                let waterline: Stream<_, TypedBox<u64, DynDataTyped<u64>>> = input_by_time
                    .waterline_monotonic(|| 0, move |ts| ts.saturating_sub(lateness))
                    .transaction_delay_with_initial_value(TypedBox::new(0))
                    /* .inspect(|w| println!("waterline: {w:?}"))*/;

                // Sum aggregate: adds `val * w` for every record.
                let aggregator = <Fold<i64, i64, DefaultSemigroup<_>, _, _>>::new(
                    0i64,
                    |agg: &mut i64, val: &i64, w: ZWeight| *agg += val * w,
                );

                // Window `[ts - 1000, ts]`.
                let range_spec = RelRange::new(RelOffset::Before(1000), RelOffset::Before(0));
                let output_1000_0 = input_by_time
                    .partitioned_rolling_aggregate(
                        |Tup2(partition, val)| (*partition, *val),
                        aggregator.clone(),
                        range_spec,
                    )
                    .accumulate_integrate()
                    .accumulate_output();

                let output_1000_0_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                let output_1000_0_waterline = Stream::partitioned_rolling_aggregate_with_waterline(
                    &input_by_time,
                    &waterline,
                    |Tup2(partition, val)| (*partition, *val),
                    aggregator.clone(),
                    range_spec,
                )
                .accumulate_integrate()
                .accumulate_output();

                let output_1000_0_waterline_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                let output_1000_0_linear = input_by_time
                    .partitioned_rolling_aggregate_linear(
                        |Tup2(partition, val)| (*partition, *val),
                        |v| *v,
                        |v| v,
                        range_spec,
                    )
                    .accumulate_integrate()
                    .accumulate_output();

                let output_1000_0_linear_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                // Window `[ts - 500, ts + 500]`.
                let range_spec = RelRange::new(RelOffset::Before(500), RelOffset::After(500));
                let aggregate_500_500 = input_by_time
                    .partitioned_rolling_aggregate(
                        |Tup2(partition, val)| (*partition, *val),
                        aggregator.clone(),
                        range_spec,
                    )
                    .accumulate_integrate()
                    .accumulate_output();

                let aggregate_500_500_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                let aggregate_500_500_waterline = input_by_time
                    .partitioned_rolling_aggregate_with_waterline(
                        &waterline,
                        |Tup2(partition, val)| (*partition, *val),
                        aggregator.clone(),
                        range_spec,
                    );

                // Value bound `(u64::MAX, None)` for the integrated trace of
                // the waterline output.
                let bound: TraceBound<DynPair<DynDataTyped<u64>, DynOpt<DynData>>> =
                    TraceBound::new();
                let b: Tup2<u64, Option<i64>> = Tup2(u64::MAX, None::<i64>);

                bound.set(Box::new(b).erase_box());

                // Whenever a `size_bound` was requested, check that the bounded
                // trace stays within it.
                aggregate_500_500_waterline
                    .integrate_trace_with_bound(TraceBound::new(), bound)
                    .apply(move |trace| {
                        if let Some(bound) = size_bound {
                            assert!(trace.size_of().total_bytes() <= bound);
                        }
                    });

                let aggregate_500_500_waterline = aggregate_500_500_waterline
                    .accumulate_integrate()
                    .accumulate_output();

                let aggregate_500_500_waterline_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                let output_500_500_linear = input_by_time
                    .partitioned_rolling_aggregate_linear(
                        |Tup2(partition, val)| (*partition, *val),
                        |v| *v,
                        |v| v,
                        range_spec,
                    )
                    .accumulate_integrate()
                    .accumulate_output();

                let output_500_500_linear_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                // Window `[ts - 500, ts - 100]`.
                let range_spec = RelRange::new(RelOffset::Before(500), RelOffset::Before(100));
                let output_500_100 = input_by_time
                    .partitioned_rolling_aggregate(
                        |Tup2(partition, val)| (*partition, *val),
                        aggregator,
                        range_spec,
                    )
                    .accumulate_integrate()
                    .accumulate_output();

                let output_500_100_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                Ok((
                    input_handle,
                    output_1000_0,
                    output_1000_0_expected,
                    output_1000_0_waterline,
                    output_1000_0_waterline_expected,
                    output_1000_0_linear,
                    output_1000_0_linear_expected,
                    aggregate_500_500,
                    aggregate_500_500_expected,
                    aggregate_500_500_waterline,
                    aggregate_500_500_waterline_expected,
                    output_500_500_linear,
                    output_500_500_linear_expected,
                    output_500_100,
                    output_500_100_expected,
                ))
            },
        )
        .unwrap()
    }
1361
1362    fn test_partition_rolling_aggregate(
1363        lateness: u64,
1364        size_bound: Option<usize>,
1365        trace: Vec<InputBatch>,
1366        transaction: bool,
1367    ) {
1368        let (
1369            mut circuit,
1370            (
1371                input,
1372                output_1000_0,
1373                output_1000_0_expected,
1374                output_1000_0_waterline,
1375                output_1000_0_waterline_expected,
1376                output_1000_0_linear,
1377                output_1000_0_linear_expected,
1378                aggregate_500_500,
1379                aggregate_500_500_expected,
1380                aggregate_500_500_waterline,
1381                aggregate_500_500_waterline_expected,
1382                aggregate_500_500_linear,
1383                aggregate_500_500_linear_expected,
1384                output_500_100,
1385                output_500_100_expected,
1386            ),
1387        ) = partition_rolling_aggregate_circuit(lateness, size_bound);
1388
1389        if transaction {
1390            circuit.start_transaction().unwrap();
1391            for mut batch in trace {
1392                input.append(&mut batch);
1393                circuit.step().unwrap();
1394            }
1395
1396            circuit.commit_transaction().unwrap();
1397
1398            assert_eq!(
1399                output_1000_0.concat().consolidate(),
1400                output_1000_0_expected.concat().consolidate()
1401            );
1402            assert_eq!(
1403                output_1000_0_waterline.concat().consolidate(),
1404                output_1000_0_waterline_expected.concat().consolidate()
1405            );
1406            assert_eq!(
1407                output_1000_0_linear.concat().consolidate(),
1408                output_1000_0_linear_expected.concat().consolidate()
1409            );
1410            assert_eq!(
1411                aggregate_500_500.concat().consolidate(),
1412                aggregate_500_500_expected.concat().consolidate()
1413            );
1414            assert_eq!(
1415                aggregate_500_500_waterline.concat().consolidate(),
1416                aggregate_500_500_waterline_expected.concat().consolidate()
1417            );
1418            assert_eq!(
1419                aggregate_500_500_linear.concat().consolidate(),
1420                aggregate_500_500_linear_expected.concat().consolidate()
1421            );
1422            assert_eq!(
1423                output_500_100.concat().consolidate(),
1424                output_500_100_expected.concat().consolidate()
1425            );
1426        } else {
1427            for mut batch in trace {
1428                input.append(&mut batch);
1429                circuit.transaction().unwrap();
1430
1431                assert_eq!(
1432                    output_1000_0.concat().consolidate(),
1433                    output_1000_0_expected.concat().consolidate()
1434                );
1435                assert_eq!(
1436                    output_1000_0_waterline.concat().consolidate(),
1437                    output_1000_0_waterline_expected.concat().consolidate()
1438                );
1439                assert_eq!(
1440                    output_1000_0_linear.concat().consolidate(),
1441                    output_1000_0_linear_expected.concat().consolidate()
1442                );
1443                assert_eq!(
1444                    aggregate_500_500.concat().consolidate(),
1445                    aggregate_500_500_expected.concat().consolidate()
1446                );
1447                assert_eq!(
1448                    aggregate_500_500_waterline.concat().consolidate(),
1449                    aggregate_500_500_waterline_expected.concat().consolidate()
1450                );
1451                assert_eq!(
1452                    aggregate_500_500_linear.concat().consolidate(),
1453                    aggregate_500_500_linear_expected.concat().consolidate()
1454                );
1455                assert_eq!(
1456                    output_500_100.concat().consolidate(),
1457                    output_500_100_expected.concat().consolidate()
1458                );
1459            }
1460        }
1461
1462        circuit.kill().unwrap();
1463    }
1464
1465    #[test]
1466    fn test_partitioned_over_range_2() {
1467        test_partition_rolling_aggregate(
1468            u64::MAX,
1469            None,
1470            vec![
1471                vec![Tup2(2u64, Tup2(Tup2(110271u64, 100i64), 1i64))],
1472                vec![Tup2(2u64, Tup2(Tup2(0u64, 100i64), 1i64))],
1473            ],
1474            false,
1475        );
1476    }
1477
1478    #[test]
1479    fn test_partitioned_over_range() {
1480        test_partition_rolling_aggregate(
1481            u64::MAX,
1482            None,
1483            vec![
1484                vec![
1485                    Tup2(0u64, Tup2(Tup2(1u64, 100i64), 1)),
1486                    Tup2(0, Tup2(Tup2(10, 100), 1)),
1487                    Tup2(0, Tup2(Tup2(20, 100), 1)),
1488                    Tup2(0, Tup2(Tup2(30, 100), 1)),
1489                ],
1490                vec![
1491                    Tup2(0u64, Tup2(Tup2(1u64, 100i64), 1)),
1492                    Tup2(0, Tup2(Tup2(10, 100), 1)),
1493                    Tup2(0, Tup2(Tup2(20, 100), 1)),
1494                    Tup2(0, Tup2(Tup2(30, 100), 1)),
1495                ],
1496                vec![
1497                    Tup2(0u64, Tup2(Tup2(5u64, 100i64), 1)),
1498                    Tup2(0, Tup2(Tup2(15, 100), 1)),
1499                    Tup2(0, Tup2(Tup2(25, 100), 1)),
1500                    Tup2(0, Tup2(Tup2(35, 100), 1)),
1501                ],
1502                vec![
1503                    Tup2(1u64, Tup2(Tup2(1u64, 100i64), 1)),
1504                    Tup2(1, Tup2(Tup2(1000, 100), 1)),
1505                    Tup2(1, Tup2(Tup2(2000, 100), 1)),
1506                    Tup2(1, Tup2(Tup2(3000, 100), 1)),
1507                ],
1508            ],
1509            false,
1510        );
1511    }
1512
1513    #[test]
1514    fn test_empty_tree() {
1515        test_partition_rolling_aggregate(
1516            u64::MAX,
1517            None,
1518            std::iter::repeat_n(
1519                vec![
1520                    vec![Tup2(0u64, Tup2(Tup2(1u64, 100i64), 1))],
1521                    vec![Tup2(0u64, Tup2(Tup2(1u64, 100i64), -1))],
1522                ],
1523                1000,
1524            )
1525            .flatten()
1526            .collect::<Vec<_>>(),
1527            false,
1528        );
1529    }
1530
1531    // Test derived from issue #199 (https://github.com/feldera/feldera/issues/199).
1532    #[test]
1533    fn test_partitioned_rolling_aggregate2() {
1534        let (mut circuit, (input, expected)) = Runtime::init_circuit(1, move |circuit| {
1535            let (input_stream, input_handle) =
1536                circuit.add_input_indexed_zset::<u64, Tup2<u64, i64>>();
1537
1538            let (expected, expected_handle) =
1539                circuit.dyn_add_input_indexed_zset::<DynData/*<u64>*/, DynPair<DynDataTyped<u64>, DynOpt<DynData/*<i64>*/>>>(&AddInputIndexedZSetFactories::new::<u64, Tup2<u64, Option<i64>>>());
1540
1541            let expected = expected.typed::<OrdPartitionedIndexedZSet<u64, u64, _, Option<i64>, _>>();
1542
1543            let input_by_time =
1544                input_stream.map_index(|(partition, Tup2(ts, val))| (*ts, Tup2(*partition, *val)));
1545
1546            input_stream.inspect(|f| {
1547                for (p, Tup2(ts, v), w) in f.iter() {
1548                    println!(" input {p} {ts} {v:6} {w:+}");
1549                }
1550            });
1551            let range_spec = RelRange::new(RelOffset::Before(3), RelOffset::Before(2));
1552            let sum = input_by_time.partitioned_rolling_aggregate_linear(
1553                |Tup2(partition, val)| (*partition, *val),
1554                |&f| f,
1555                |x| x, range_spec);
1556            sum.inspect(|f| {
1557                for (p, Tup2(ts, sum), w) in f.iter() {
1558                    println!("output {p} {ts} {:6} {w:+}", sum.unwrap_or_default());
1559                }
1560            });
1561            expected.accumulate_apply2(&sum, |expected, actual| assert_eq!(expected.iter().collect::<Vec<_>>(), actual.iter().collect::<Vec<_>>()));
1562            Ok((input_handle, expected_handle))
1563        })
1564        .unwrap();
1565
1566        input.append(&mut vec![
1567            Tup2(1u64, Tup2(Tup2(0u64, 1i64), 1)),
1568            Tup2(1, Tup2(Tup2(1, 10), 1)),
1569            Tup2(1, Tup2(Tup2(2, 100), 1)),
1570            Tup2(1, Tup2(Tup2(3, 1000), 1)),
1571            Tup2(1, Tup2(Tup2(4, 10000), 1)),
1572            Tup2(1, Tup2(Tup2(5, 100000), 1)),
1573            Tup2(1, Tup2(Tup2(9, 123456), 1)),
1574        ]);
1575        expected.dyn_append(
1576            &mut Box::new(lean_vec![
1577                Tup2(1u64, Tup2(Tup2(0u64, None::<i64>), 1)),
1578                Tup2(1, Tup2(Tup2(1, None), 1)),
1579                Tup2(1, Tup2(Tup2(2, Some(1)), 1)),
1580                Tup2(1, Tup2(Tup2(3, Some(11)), 1)),
1581                Tup2(1, Tup2(Tup2(4, Some(110)), 1)),
1582                Tup2(1, Tup2(Tup2(5, Some(1100)), 1)),
1583                Tup2(1, Tup2(Tup2(9, None), 1)),
1584            ])
1585            .erase_box(),
1586        );
1587        circuit.transaction().unwrap();
1588    }
1589
1590    #[test]
1591    fn test_partitioned_rolling_average() {
1592        let (mut circuit, (input, expected)) = Runtime::init_circuit(1, move |circuit| {
1593            let (input_stream, input_handle) =
1594                circuit.add_input_indexed_zset::<u64, Tup2<u64, i64>>();
1595
1596            let (expected_stream, expected_handle) =
1597                circuit.dyn_add_input_indexed_zset::<DynData/*<u64>*/, DynPair<DynDataTyped<u64>, DynOpt<DynData/*<i64>*/>>>(&AddInputIndexedZSetFactories::new::<u64, Tup2<u64, Option<i64>>>());
1598
1599            let expected_stream = expected_stream.typed::<OrdPartitionedIndexedZSet<u64, u64, _, Option<i64>, _>>();
1600
1601            let input_by_time =
1602                input_stream.map_index(|(partition, Tup2(ts, val))| (*ts, Tup2(*partition, *val)));
1603
1604            let range_spec = RelRange::new(RelOffset::Before(3), RelOffset::Before(1));
1605            input_by_time
1606                .partitioned_rolling_average(
1607                    |Tup2(partition, val)| (*partition, *val),
1608                    range_spec)
1609                .accumulate_apply2(&expected_stream, |avg: SpineSnapshot<OrdPartitionedIndexedZSet<u64, u64, _, Option<i64>, _>>, expected| assert_eq!(avg.iter().collect::<Vec<_>>(), expected.iter().collect::<Vec<_>>()));
1610            Ok((input_handle, expected_handle))
1611        })
1612        .unwrap();
1613
1614        circuit.transaction().unwrap();
1615
1616        input.append(&mut vec![
1617            Tup2(0u64, Tup2(Tup2(10u64, 10i64), 1)),
1618            Tup2(0, Tup2(Tup2(11, 20), 1)),
1619            Tup2(0, Tup2(Tup2(12, 30), 1)),
1620            Tup2(0, Tup2(Tup2(13, 40), 1)),
1621            Tup2(0, Tup2(Tup2(14, 50), 1)),
1622            Tup2(0, Tup2(Tup2(15, 60), 1)),
1623        ]);
1624        expected.dyn_append(
1625            &mut Box::new(lean_vec![
1626                Tup2(0u64, Tup2(Tup2(10u64, None::<i64>), 1)),
1627                Tup2(0, Tup2(Tup2(11, Some(10)), 1)),
1628                Tup2(0, Tup2(Tup2(12, Some(15)), 1)),
1629                Tup2(0, Tup2(Tup2(13, Some(20)), 1)),
1630                Tup2(0, Tup2(Tup2(14, Some(30)), 1)),
1631                Tup2(0, Tup2(Tup2(15, Some(40)), 1)),
1632            ])
1633            .erase_box(),
1634        );
1635        circuit.transaction().unwrap();
1636    }
1637
1638    #[test]
1639    fn test_partitioned_rolling_aggregate() {
1640        let (mut circuit, input) = Runtime::init_circuit(1, move |circuit| {
1641            let (input_stream, input_handle) =
1642                circuit.add_input_indexed_zset::<u64, Tup2<u64, i64>>();
1643
1644            input_stream.inspect(|f| {
1645                for (p, Tup2(ts, v), w) in f.iter() {
1646                    println!(" input {p} {ts} {v:6} {w:+}");
1647                }
1648            });
1649            let input_by_time =
1650                input_stream.map_index(|(partition, Tup2(ts, val))| (*ts, Tup2(*partition, *val)));
1651
1652            let range_spec = RelRange::new(RelOffset::Before(3), RelOffset::Before(2));
1653            let sum = input_by_time.partitioned_rolling_aggregate_linear(
1654                |Tup2(partition, val)| (*partition, *val),
1655                |&f| f,
1656                |x| x,
1657                range_spec,
1658            );
1659            sum.inspect(|f| {
1660                for (p, Tup2(ts, sum), w) in f.iter() {
1661                    println!("output {p} {ts} {:6} {w:+}", sum.unwrap_or_default());
1662                }
1663            });
1664            Ok(input_handle)
1665        })
1666        .unwrap();
1667
1668        input.append(&mut vec![
1669            Tup2(1u64, Tup2(Tup2(0u64, 1i64), 1)),
1670            Tup2(1, Tup2(Tup2(1, 10), 1)),
1671            Tup2(1, Tup2(Tup2(2, 100), 1)),
1672            Tup2(1, Tup2(Tup2(3, 1000), 1)),
1673            Tup2(1, Tup2(Tup2(4, 10000), 1)),
1674            Tup2(1, Tup2(Tup2(5, 100000), 1)),
1675            Tup2(1, Tup2(Tup2(9, 123456), 1)),
1676        ]);
1677        circuit.transaction().unwrap();
1678    }
1679
1680    type InputTuple = Tup2<u64, Tup2<Tup2<u64, i64>, ZWeight>>;
1681    type InputBatch = Vec<InputTuple>;
1682
1683    fn input_tuple(partitions: u64, window: (u64, u64)) -> impl Strategy<Value = InputTuple> {
1684        (
1685            (0..partitions),
1686            (
1687                (window.0..window.1, 100..101i64).prop_map(|(x, y)| Tup2(x, y)),
1688                1..2i64,
1689            )
1690                .prop_map(|(x, y)| Tup2(x, y)),
1691        )
1692            .prop_map(|(x, y)| Tup2(x, y))
1693    }
1694
1695    fn input_batch(
1696        partitions: u64,
1697        window: (u64, u64),
1698        max_batch_size: usize,
1699    ) -> impl Strategy<Value = InputBatch> {
1700        collection::vec(input_tuple(partitions, window), 0..max_batch_size)
1701    }
1702
1703    fn input_trace(
1704        partitions: u64,
1705        epoch: u64,
1706        max_batch_size: usize,
1707        max_batches: usize,
1708    ) -> impl Strategy<Value = Vec<InputBatch>> {
1709        collection::vec(
1710            input_batch(partitions, (0, epoch), max_batch_size),
1711            0..max_batches,
1712        )
1713    }
1714
1715    fn input_trace_quasi_monotone(
1716        partitions: u64,
1717        window_size: u64,
1718        window_step: u64,
1719        max_batch_size: usize,
1720        batches: usize,
1721    ) -> impl Strategy<Value = Vec<InputBatch>> {
1722        (0..batches)
1723            .map(|i| {
1724                input_batch(
1725                    partitions,
1726                    (i as u64 * window_step, i as u64 * window_step + window_size),
1727                    max_batch_size,
1728                )
1729                .boxed()
1730            })
1731            .collect::<Vec<_>>()
1732    }
1733
1734    proptest! {
1735        #![proptest_config(ProptestConfig::with_cases(5))]
1736
1737        #[test]
1738        fn proptest_partitioned_rolling_aggregate_quasi_monotone_small_steps(trace in input_trace_quasi_monotone(5, 10_000, 2_000, 20, 200)) {
1739            // 10_000 is an empirically established bound: without GC this test needs >10KB.
1740            test_partition_rolling_aggregate(10000, Some(30_000), trace, false);
1741        }
1742
1743        #[test]
1744        #[ignore = "https://github.com/feldera/feldera/issues/4764"]
1745        fn proptest_partitioned_rolling_aggregate_quasi_monotone_big_step(trace in input_trace_quasi_monotone(5, 10_000, 2_000, 20, 200)) {
1746            // 10_000 is an empirically established bound: without GC this test needs >10KB.
1747            test_partition_rolling_aggregate(10000, Some(30_000), trace, true);
1748        }
1749    }
1750
1751    proptest! {
1752        #[test]
1753        fn proptest_partitioned_over_range_sparse_small_steps(trace in input_trace(5, 1_000_000, 10, 10)) {
1754            test_partition_rolling_aggregate(u64::MAX, None, trace, false);
1755        }
1756
1757        #[test]
1758        fn proptest_partitioned_over_range_sparse_big_step(trace in input_trace(5, 1_000_000, 10, 10)) {
1759            test_partition_rolling_aggregate(u64::MAX, None, trace, true);
1760        }
1761
1762        #[test]
1763        fn proptest_partitioned_over_range_dense_small_steps(trace in input_trace(5, 500, 25, 10)) {
1764            test_partition_rolling_aggregate(u64::MAX, None, trace, false);
1765        }
1766
1767        #[test]
1768        fn proptest_partitioned_over_range_dense_big_step(trace in input_trace(5, 500, 25, 10)) {
1769            test_partition_rolling_aggregate(u64::MAX, None, trace, true);
1770        }
1771    }
1772}