// dbsp/operator/dynamic/time_series/rolling_aggregate.rs
1use crate::{
2    Circuit, DBData, DBWeight, DynZWeight, Position, RootCircuit, Stream, ZWeight,
3    algebra::{HasZero, IndexedZSet, UnsignedPrimInt, ZRingValue},
4    circuit::{
5        Scope,
6        metadata::{BatchSizeStats, INPUT_BATCHES_STATS, OUTPUT_BATCHES_STATS, OperatorMeta},
7        operator_traits::Operator,
8        splitter_output_chunk_size,
9    },
10    dynamic::{
11        ClonableTrait, Data, DataTrait, DowncastTrait, DynDataTyped, DynOpt, DynPair, DynUnit,
12        Erase, Factory, WeightTrait, WithFactory,
13    },
14    operator::{
15        Avg,
16        async_stream_operators::{StreamingQuaternaryOperator, StreamingQuaternaryWrapper},
17        dynamic::{
18            accumulate_trace::AccumulateTraceFeedback,
19            aggregate::{AggCombineFunc, AggOutputFunc, DynAggregator, DynAverage},
20            filter_map::DynFilterMap,
21            time_series::{
22                OrdPartitionedIndexedZSet, PartitionCursor, PartitionedBatch,
23                PartitionedIndexedZSet, RelOffset,
24                radix_tree::{
25                    OrdPartitionedTreeAggregateFactories, PartitionedRadixTreeBatch,
26                    RadixTreeCursor, TreeNode,
27                },
28                range::{Range, RangeCursor, Ranges, RelRange},
29            },
30            trace::{TraceBound, TraceBounds},
31        },
32    },
33    trace::{
34        Batch, BatchReader, BatchReaderFactories, Builder, Cursor, Spine, SpineSnapshot,
35        merge_batches, spine_async::WithSnapshot,
36    },
37    utils::Tup2,
38};
39use async_stream::stream;
40use dyn_clone::{DynClone, clone_box};
41use futures::Stream as AsyncStream;
42use num::Bounded;
43use std::{
44    borrow::Cow,
45    cell::RefCell,
46    marker::PhantomData,
47    ops::{Deref, Div, Neg},
48    rc::Rc,
49};
50
51use super::radix_tree::{FilePartitionedRadixTreeFactories, Prefix};
52
/// Dynamically typed weighing function: folds a value `V` carrying weight `R`
/// into an accumulator `A`.  `DynClone` lets boxed instances be duplicated
/// across operators.
pub trait WeighFunc<V: ?Sized, R: ?Sized, A: ?Sized>: Fn(&V, &R, &mut A) + DynClone {}

// Blanket impl: any clonable closure with the right signature is a `WeighFunc`.
impl<V: ?Sized, R: ?Sized, A: ?Sized, F> WeighFunc<V, R, A> for F where F: Fn(&V, &R, &mut A) + Clone
{}

// Makes `Box<dyn WeighFunc<..>>` clonable via `dyn_clone::clone_box`.
dyn_clone::clone_trait_object! {<V: ?Sized, R: ?Sized, A: ?Sized> WeighFunc<V, R, A>}
59
/// Dynamically typed partitioning function: splits an input value `IV` into a
/// partition key `PK` and a per-partition value `OV`, writing both in place.
pub trait PartitionFunc<IV: ?Sized, PK: ?Sized, OV: ?Sized>:
    Fn(&IV, &mut PK, &mut OV) + DynClone
{
}

// Blanket impl: any clonable closure with the right signature qualifies.
impl<IV: ?Sized, PK: ?Sized, OV: ?Sized, F> PartitionFunc<IV, PK, OV> for F where
    F: Fn(&IV, &mut PK, &mut OV) + Clone
{
}
69
/// Output stream type of the rolling-aggregate operators: a stream of indexed
/// Z-sets partitioned by `PK`, mapping timestamps `TS` to optional aggregates
/// `A` (`None` when a partition/time has no aggregate).
pub type OrdPartitionedOverStream<PK, TS, A> =
    Stream<RootCircuit, OrdPartitionedIndexedZSet<PK, TS, DynOpt<A>>>;
72
/// Factory objects needed by `dyn_partitioned_rolling_aggregate` and related
/// operators to build the dynamically typed batches flowing through the
/// circuit: input batches, the radix tree, tree aggregates, and outputs.
pub struct PartitionedRollingAggregateFactories<TS, V, Acc, Out, B, O>
where
    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
    O: IndexedZSet<Key = B::Key>,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    TS: DBData + UnsignedPrimInt,
    V: DataTrait + ?Sized,
{
    // Factories for the partitioned input batch type `B`.
    input_factories: B::Factories,
    // Factories for the radix tree built over the input window.
    radix_tree_factories: FilePartitionedRadixTreeFactories<B::Key, TS, Acc>,
    // Factories used by the `partitioned_tree_aggregate` operator.
    partitioned_tree_aggregate_factories: OrdPartitionedTreeAggregateFactories<TS, V, B, Acc>,
    // Factories for the output batch type `O`.
    output_factories: O::Factories,
    // `Out` otherwise appears only in bounds; anchor it without ownership.
    phantom: PhantomData<fn(&Out)>,
}
88
89impl<TS, V, Acc, Out, B, O> PartitionedRollingAggregateFactories<TS, V, Acc, Out, B, O>
90where
91    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
92    O: PartitionedIndexedZSet<DynDataTyped<TS>, DynOpt<Out>, Key = B::Key>,
93    Acc: DataTrait + ?Sized,
94    Out: DataTrait + ?Sized,
95    TS: DBData + UnsignedPrimInt,
96    V: DataTrait + ?Sized,
97{
98    pub fn new<KType, VType, AType, OType>() -> Self
99    where
100        KType: DBData + Erase<B::Key>,
101        VType: DBData + Erase<V>,
102        AType: DBData + Erase<Acc>,
103        OType: DBData + Erase<Out>,
104    {
105        Self {
106            input_factories: BatchReaderFactories::new::<KType, Tup2<TS, VType>, ZWeight>(),
107            radix_tree_factories: BatchReaderFactories::new::<
108                KType,
109                Tup2<Prefix<TS>, TreeNode<TS, AType>>,
110                ZWeight,
111            >(),
112            partitioned_tree_aggregate_factories: OrdPartitionedTreeAggregateFactories::new::<
113                KType,
114                VType,
115                AType,
116            >(),
117            output_factories: BatchReaderFactories::new::<KType, Tup2<TS, Option<OType>>, ZWeight>(
118            ),
119            phantom: PhantomData,
120        }
121    }
122}
123
/// Factory objects for `dyn_partitioned_rolling_aggregate_with_waterline`:
/// wraps the plain rolling-aggregate factories plus factories for the
/// time-indexed input batch `B`.
pub struct PartitionedRollingAggregateWithWaterlineFactories<PK, TS, V, Acc, Out, B>
where
    PK: DataTrait + ?Sized,
    B: IndexedZSet,
    TS: DBData + UnsignedPrimInt + Erase<B::Key>,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    V: DataTrait + ?Sized,
{
    // Factories for the time-indexed input stream (keys are timestamps).
    input_factories: B::Factories,
    // Factories for the partition-indexed streams produced after re-indexing.
    rolling_aggregate_factories: PartitionedRollingAggregateFactories<
        TS,
        V,
        Acc,
        Out,
        OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, V>,
        OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, DynOpt<Out>>,
    >,
}
143
144impl<PK, TS, V, Acc, Out, B>
145    PartitionedRollingAggregateWithWaterlineFactories<PK, TS, V, Acc, Out, B>
146where
147    PK: DataTrait + ?Sized,
148    B: IndexedZSet,
149    TS: DBData + UnsignedPrimInt + Erase<B::Key>,
150    Acc: DataTrait + ?Sized,
151    Out: DataTrait + ?Sized,
152    V: DataTrait + ?Sized,
153{
154    pub fn new<VType, PKType, PVType, AType, OType>() -> Self
155    where
156        VType: DBData + Erase<B::Val>,
157        PKType: DBData + Erase<PK>,
158        PVType: DBData + Erase<V>,
159        AType: DBData + Erase<Acc>,
160        OType: DBData + Erase<Out>,
161    {
162        Self {
163            input_factories: BatchReaderFactories::new::<TS, VType, ZWeight>(),
164            rolling_aggregate_factories: PartitionedRollingAggregateFactories::new::<
165                PKType,
166                PVType,
167                AType,
168                OType,
169            >(),
170        }
171    }
172}
173
/// Factory objects for `dyn_partitioned_rolling_aggregate_linear`: the
/// generic rolling-aggregate factories plus factories for the linear
/// aggregator's accumulator (`A`) and output (`OV`) types.
pub struct PartitionedRollingAggregateLinearFactories<TS, V, OV, A, B, O>
where
    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
    O: IndexedZSet<Key = B::Key>,
    TS: DBData + UnsignedPrimInt,
    V: DataTrait + ?Sized,
    OV: DataTrait + ?Sized,
    A: WeightTrait + ?Sized,
{
    // Factory for scratch accumulator values.
    aggregate_factory: &'static dyn Factory<A>,
    // Factory for optional accumulators (`None` = no data in window).
    opt_accumulator_factory: &'static dyn Factory<DynOpt<A>>,
    // Factory for finalized output values.
    output_factory: &'static dyn Factory<OV>,
    rolling_aggregate_factories: PartitionedRollingAggregateFactories<TS, V, A, OV, B, O>,
}
188
189impl<TS, V, OV, A, B, O> PartitionedRollingAggregateLinearFactories<TS, V, OV, A, B, O>
190where
191    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
192    O: PartitionedIndexedZSet<DynDataTyped<TS>, DynOpt<OV>, Key = B::Key>,
193    B::Key: DataTrait,
194    TS: DBData + UnsignedPrimInt,
195    V: DataTrait + ?Sized,
196    OV: DataTrait + ?Sized,
197    A: WeightTrait + ?Sized,
198{
199    pub fn new<KType, VType, AType, OVType>() -> Self
200    where
201        KType: DBData + Erase<B::Key>,
202        VType: DBData + Erase<V>,
203        AType: DBWeight + Erase<A>,
204        OVType: DBData + Erase<OV>,
205    {
206        Self {
207            aggregate_factory: WithFactory::<AType>::FACTORY,
208            opt_accumulator_factory: WithFactory::<Option<AType>>::FACTORY,
209            output_factory: WithFactory::<OVType>::FACTORY,
210            rolling_aggregate_factories: PartitionedRollingAggregateFactories::new::<
211                KType,
212                VType,
213                AType,
214                OVType,
215            >(),
216        }
217    }
218}
219
/// Factory objects for `dyn_partitioned_rolling_average`: linear-aggregate
/// factories specialized to a sum/count pair (`DynAverage`), plus a factory
/// for the intermediate average value of type `W`.
pub struct PartitionedRollingAverageFactories<TS, V, W, B, O>
where
    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
    O: IndexedZSet<Key = B::Key>,
    TS: DBData + UnsignedPrimInt,
    V: DataTrait + ?Sized,
    W: WeightTrait + ?Sized,
{
    aggregate_factories:
        PartitionedRollingAggregateLinearFactories<TS, V, V, DynAverage<W, B::R>, B, O>,
    // Factory for the scratch value holding sum/count before conversion to `V`.
    weight_factory: &'static dyn Factory<W>,
}
232
233impl<TS, V, W, B, O> PartitionedRollingAverageFactories<TS, V, W, B, O>
234where
235    B: PartitionedIndexedZSet<DynDataTyped<TS>, V>,
236    O: PartitionedIndexedZSet<DynDataTyped<TS>, DynOpt<V>, Key = B::Key>,
237    TS: DBData + UnsignedPrimInt,
238    V: DataTrait + ?Sized,
239    W: WeightTrait + ?Sized,
240{
241    pub fn new<KType, VType, WType>() -> Self
242    where
243        KType: DBData + Erase<B::Key>,
244        VType: DBData + Erase<V>,
245        WType: DBWeight + From<ZWeight> + Div<Output = WType> + Erase<W>,
246    {
247        Self {
248            aggregate_factories: PartitionedRollingAggregateLinearFactories::new::<
249                KType,
250                VType,
251                Avg<WType, ZWeight>,
252                VType,
253            >(),
254            weight_factory: WithFactory::<WType>::FACTORY,
255        }
256    }
257}
258
/// `Aggregator` object that computes a linear aggregation function.
// TODO: we need this because we currently compute linear aggregates
// using the same algorithm as general aggregates.  Additional performance
// gains can be obtained with an optimized implementation of radix trees
// for linear aggregates (specifically, updating a node when only
// some of its children have changed can be done without computing
// the sum of all children from scratch).
struct LinearAggregator<V, A, O>
where
    V: DataTrait + ?Sized,
    A: WeightTrait + ?Sized,
    O: DataTrait + ?Sized,
{
    // Factory for scratch accumulator values of type `A`.
    acc_factory: &'static dyn Factory<A>,
    // Factory for optional accumulators (`None` = empty input).
    opt_accumulator_factory: &'static dyn Factory<DynOpt<A>>,
    // Factory for final output values of type `O`.
    output_factory: &'static dyn Factory<O>,
    // Weighing function: folds a value and its Z-weight into an accumulator.
    f: Box<dyn WeighFunc<V, DynZWeight, A>>,
    // Converts a finished accumulator into the output type.
    output_func: Box<dyn AggOutputFunc<A, O>>,
    // Combines two accumulators; for linear aggregates this is addition.
    combine: Box<dyn AggCombineFunc<A>>,
}
279
280impl<V, A, O> Clone for LinearAggregator<V, A, O>
281where
282    V: DataTrait + ?Sized,
283    A: WeightTrait + ?Sized,
284    O: DataTrait + ?Sized,
285{
286    fn clone(&self) -> Self {
287        Self {
288            acc_factory: self.acc_factory,
289            opt_accumulator_factory: self.opt_accumulator_factory,
290            output_factory: self.output_factory,
291            f: clone_box(self.f.as_ref()),
292            output_func: clone_box(self.output_func.as_ref()),
293            combine: clone_box(self.combine.as_ref()),
294        }
295    }
296}
297
298impl<V, A, O> LinearAggregator<V, A, O>
299where
300    V: DataTrait + ?Sized,
301    A: WeightTrait + ?Sized,
302    O: DataTrait + ?Sized,
303{
304    fn new(
305        acc_factory: &'static dyn Factory<A>,
306        opt_accumulator_factory: &'static dyn Factory<DynOpt<A>>,
307        output_factory: &'static dyn Factory<O>,
308        f: Box<dyn WeighFunc<V, DynZWeight, A>>,
309        output_func: Box<dyn AggOutputFunc<A, O>>,
310    ) -> Self {
311        Self {
312            acc_factory,
313            opt_accumulator_factory,
314            output_factory,
315            f,
316            output_func,
317            combine: Box::new(|acc, v| acc.add_assign(v)),
318        }
319    }
320}
321
impl<V, A, O> DynAggregator<V, (), DynZWeight> for LinearAggregator<V, A, O>
where
    V: DataTrait + ?Sized,
    A: WeightTrait + ?Sized,
    O: DataTrait + ?Sized,
{
    type Accumulator = A;
    type Output = O;

    fn combine(&self) -> &dyn AggCombineFunc<A> {
        self.combine.as_ref()
    }

    /// Aggregates all values visible through `cursor` into `agg`.
    ///
    /// Each key is weighed (via `self.f`) with its cursor weight into a
    /// scratch accumulator, then summed into the running total.  `agg` is
    /// left as `None` when the cursor is empty.
    fn aggregate(&self, cursor: &mut dyn Cursor<V, DynUnit, (), DynZWeight>, agg: &mut DynOpt<A>) {
        agg.set_none();
        while cursor.key_valid() {
            self.acc_factory.with(&mut |tmp_agg| {
                let w = *cursor.weight().deref();
                // Weigh the current value into the scratch accumulator...
                (self.f)(cursor.key(), w.erase(), tmp_agg);
                // ...then fold it into the running total.
                match agg.get_mut() {
                    None => agg.from_val(tmp_agg),
                    Some(old) => old.add_assign(tmp_agg),
                };
            });
            cursor.step_key();
        }
    }

    fn finalize(&self, accumulator: &mut A, output: &mut O) {
        (self.output_func)(accumulator, output)
    }

    // NOTE(review): intentionally unimplemented -- presumably the rolling
    // aggregate circuit only uses `aggregate` + `finalize`; confirm before
    // calling through this path.
    fn aggregate_and_finalize(
        &self,
        _cursor: &mut dyn Cursor<V, DynUnit, (), DynZWeight>,
        _output: &mut DynOpt<Self::Output>,
    ) {
        todo!()
    }

    fn opt_accumulator_factory(&self) -> &'static dyn Factory<DynOpt<Self::Accumulator>> {
        self.opt_accumulator_factory
    }

    fn output_factory(&self) -> &'static dyn Factory<Self::Output> {
        self.output_factory
    }
}
370
impl<B> Stream<RootCircuit, B>
where
    B: IndexedZSet,
{
    /// See [`Stream::partitioned_rolling_aggregate_with_waterline`].
    ///
    /// Windows the time-indexed input by the waterline, re-indexes it by
    /// partition key via `partition_func`, and computes the rolling aggregate
    /// over `range` within each partition.
    pub fn dyn_partitioned_rolling_aggregate_with_waterline<PK, TS, V, Acc, Out>(
        &self,
        persistent_id: Option<&str>,
        factories: &PartitionedRollingAggregateWithWaterlineFactories<PK, TS, V, Acc, Out, B>,
        waterline: &Stream<RootCircuit, Box<DynDataTyped<TS>>>,
        partition_func: Box<dyn PartitionFunc<B::Val, PK, V>>,
        aggregator: &dyn DynAggregator<V, (), B::R, Accumulator = Acc, Output = Out>,
        range: RelRange<TS>,
    ) -> OrdPartitionedOverStream<PK, DynDataTyped<TS>, Out>
    where
        B: IndexedZSet,
        B: for<'a> DynFilterMap<
            DynItemRef<'a> = (&'a <B as BatchReader>::Key, &'a <B as BatchReader>::Val),
        >,
        Box<B::Key>: Clone,
        PK: DataTrait + ?Sized,
        TS: DBData + UnsignedPrimInt + Erase<B::Key>,
        V: DataTrait + ?Sized,
        Acc: DataTrait + ?Sized,
        Out: DataTrait + ?Sized,
    {
        self.circuit()
            .region("partitioned_rolling_aggregate_with_waterline", || {
                // Shift the aggregation window so that its right end is at 0.
                let shifted_range =
                    RelRange::new(range.from - range.to, RelOffset::Before(HasZero::zero()));

                // Trace bound used inside `partitioned_rolling_aggregate_inner` to
                // bound its output trace.  This is the same bound we use to construct
                // the input window here.
                let bound: TraceBound<DynPair<DynDataTyped<TS>, DynOpt<Out>>> = TraceBound::new();
                let bound_clone = bound.clone();

                // Scratch (timestamp, None) pair reused on every waterline update.
                let mut bound_box = factories
                    .rolling_aggregate_factories
                    .output_factories
                    .val_factory()
                    .default_box();

                // Restrict the input stream to the `[lb -> ∞)` time window,
                // where `lb = waterline - (range.to - range.from)` is the lower
                // bound on input timestamps that may be used to compute
                // changes to the rolling aggregate operator.
                let bounds = waterline.apply_mut(move |wm| {
                    let lower = shifted_range
                        .range_of(wm.as_ref().deref())
                        .map(|range| range.from)
                        .unwrap_or_else(|| Bounded::min_value());
                    **bound_box.fst_mut() = lower;
                    bound_box.snd_mut().set_none();
                    // Propagate the same lower bound to the output trace.
                    bound_clone.set(clone_box(bound_box.as_ref()));
                    (
                        Box::new(lower).erase_box(),
                        Box::new(<TS as Bounded>::max_value()).erase_box(),
                    )
                });
                let window = self
                    .dyn_window(&factories.input_factories, (true, true), &bounds)
                    .set_persistent_id(
                        persistent_id
                            .map(|name| format!("{name}.window"))
                            .as_deref(),
                    );

                // Now that we've truncated old inputs, which required the
                // input stream to be indexed by time, we can re-index it
                // by partition id.
                let partition_func_clone = clone_box(partition_func.as_ref());

                let partitioned_window = window
                    .dyn_map_index(
                        &factories.rolling_aggregate_factories.input_factories,
                        Box::new(move |(ts, v), res| {
                            let (partition_key, ts_val) = res.split_mut();
                            let (res_ts, val) = ts_val.split_mut();
                            partition_func_clone(v, partition_key, val);
                            // SAFETY: `TS: Erase<B::Key>`, so the erased key is a `TS`.
                            unsafe { *res_ts.downcast_mut::<TS>() = *ts.downcast::<TS>() };
                        }),
                    )
                    .set_persistent_id(
                        persistent_id
                            .map(|name| format!("{name}-partitioned_window"))
                            .as_deref(),
                    );
                // The full (unwindowed) input, re-indexed the same way; drives
                // the delta input of the rolling-aggregate operator.
                let partitioned_self = self
                    .dyn_map_index(
                        &factories.rolling_aggregate_factories.input_factories,
                        Box::new(move |(ts, v), res| {
                            let (partition_key, ts_val) = res.split_mut();
                            let (res_ts, val) = ts_val.split_mut();
                            partition_func(v, partition_key, val);
                            // SAFETY: `TS: Erase<B::Key>`, so the erased key is a `TS`.
                            unsafe { *res_ts.downcast_mut::<TS>() = *ts.downcast::<TS>() };
                        }),
                    )
                    .set_persistent_id(
                        persistent_id
                            .map(|name| format!("{name}-partitioned"))
                            .as_deref(),
                    );

                partitioned_self.dyn_partitioned_rolling_aggregate_inner(
                    persistent_id,
                    &factories.rolling_aggregate_factories,
                    &partitioned_window,
                    aggregator,
                    range,
                    bound,
                )
            })
    }

    /// Dynamically typed rolling aggregate without a waterline: re-indexes the
    /// input by partition key via `partition_func` and aggregates over `range`
    /// within each partition, producing `OrdPartitionedIndexedZSet` batches.
    pub fn dyn_partitioned_rolling_aggregate<PK, TS, V, Acc, Out>(
        &self,
        persistent_id: Option<&str>,
        factories: &PartitionedRollingAggregateFactories<
            TS,
            V,
            Acc,
            Out,
            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, V>,
            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, DynOpt<Out>>,
        >,
        partition_func: Box<dyn PartitionFunc<B::Val, PK, V>>,
        aggregator: &dyn DynAggregator<V, (), B::R, Accumulator = Acc, Output = Out>,
        range: RelRange<TS>,
    ) -> OrdPartitionedOverStream<PK, DynDataTyped<TS>, Out>
    where
        B: IndexedZSet,
        B: for<'a> DynFilterMap<
            DynItemRef<'a> = (&'a <B as BatchReader>::Key, &'a <B as BatchReader>::Val),
        >,
        Acc: DataTrait + ?Sized,
        Out: DataTrait + ?Sized,
        PK: DataTrait + ?Sized,
        TS: DBData + UnsignedPrimInt + Erase<B::Key>,
        V: DataTrait + ?Sized,
    {
        // ```
        //                  ┌───────────────┐   input_trace
        //      ┌──────────►│integrate_trace├──────────────┐                              output
        //      │           └───────────────┘              │                           ┌────────────────────────────────────►
        //      │                                          ▼                           │
        // self │    ┌──────────────────────────┐  tree  ┌───────────────────────────┐ │  ┌──────────────────┐ output_trace
        // ─────┼───►│partitioned_tree_aggregate├───────►│PartitionedRollingAggregate├─┴──┤UntimedTraceAppend├────────┐
        //      │    └──────────────────────────┘        └───────────────────────────┘    └──────────────────┘        │
        //      │                                          ▲               ▲                 ▲                        │
        //      └──────────────────────────────────────────┘               │                 │                        │
        //                                                                 │               ┌─┴──┐                     │
        //                                                                 └───────────────┤Z^-1│◄────────────────────┘
        //                                                                   delayed_trace └────┘
        // ```
        self.circuit().region("partitioned_rolling_aggregate", || {
            // Re-index the input from (time -> value) to
            // (partition key -> (time, value)).
            let partitioned = self
                .dyn_map_index(
                    &factories.input_factories,
                    Box::new(move |(ts, v), res| {
                        let (partition_key, ts_val) = res.split_mut();
                        let (res_ts, val) = ts_val.split_mut();
                        partition_func(v, partition_key, val);
                        // SAFETY: `TS: Erase<B::Key>`, so the erased key is a `TS`.
                        unsafe { *res_ts.downcast_mut::<TS>() = *ts.downcast::<TS>() };
                    }),
                )
                .set_persistent_id(
                    persistent_id
                        .map(|name| format!("{name}-partitioned"))
                        .as_deref(),
                );

            // No waterline: window == full input; empty output bound.
            partitioned.dyn_partitioned_rolling_aggregate_inner(
                persistent_id,
                factories,
                &partitioned,
                aggregator,
                range,
                TraceBound::new(),
            )
        })
    }

    /// See [`Stream::partitioned_rolling_aggregate_linear`].
    ///
    /// Wraps the weighing (`f`) and finalization (`output_func`) functions
    /// into a [`LinearAggregator`] and delegates to
    /// [`Self::dyn_partitioned_rolling_aggregate`].
    pub fn dyn_partitioned_rolling_aggregate_linear<PK, TS, V, A, O>(
        &self,
        persistent_id: Option<&str>,
        factories: &PartitionedRollingAggregateLinearFactories<
            TS,
            V,
            O,
            A,
            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, V>,
            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, DynOpt<O>>,
        >,
        partition_func: Box<dyn PartitionFunc<B::Val, PK, V>>,
        f: Box<dyn WeighFunc<V, B::R, A>>,
        output_func: Box<dyn AggOutputFunc<A, O>>,
        range: RelRange<TS>,
    ) -> OrdPartitionedOverStream<PK, DynDataTyped<TS>, O>
    where
        B: IndexedZSet,
        B: for<'a> DynFilterMap<
            DynItemRef<'a> = (&'a <B as BatchReader>::Key, &'a <B as BatchReader>::Val),
        >,
        PK: DataTrait + ?Sized,
        TS: DBData + UnsignedPrimInt + Erase<B::Key>,
        V: DataTrait + ?Sized,
        A: WeightTrait + ?Sized,
        O: DataTrait + ?Sized,
    {
        let aggregator = LinearAggregator::new(
            factories.aggregate_factory,
            factories.opt_accumulator_factory,
            factories.output_factory,
            f,
            output_func,
        );
        self.dyn_partitioned_rolling_aggregate::<PK, TS, V, _, _>(
            persistent_id,
            &factories.rolling_aggregate_factories,
            partition_func,
            &aggregator,
            range,
        )
    }

    /// Rolling average over `range` within each partition: accumulates a
    /// (sum, count) pair per window via [`DynAverage`] and divides on
    /// finalization.
    pub fn dyn_partitioned_rolling_average<PK, TS, V, W>(
        &self,
        persistent_id: Option<&str>,
        factories: &PartitionedRollingAverageFactories<
            TS,
            V,
            W,
            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, V>,
            OrdPartitionedIndexedZSet<PK, DynDataTyped<TS>, DynOpt<V>>,
        >,
        partition_func: Box<dyn PartitionFunc<B::Val, PK, V>>,
        f: Box<dyn WeighFunc<V, B::R, W>>,
        out_func: Box<dyn AggOutputFunc<W, V>>,
        range: RelRange<TS>,
    ) -> OrdPartitionedOverStream<PK, DynDataTyped<TS>, V>
    where
        B: IndexedZSet,
        B: for<'a> DynFilterMap<
            DynItemRef<'a> = (&'a <B as BatchReader>::Key, &'a <B as BatchReader>::Val),
        >,
        PK: DataTrait + ?Sized,
        TS: DBData + UnsignedPrimInt + Erase<B::Key>,
        V: DataTrait + ?Sized,
        W: WeightTrait + ?Sized,
    {
        let weight_factory = factories.weight_factory;
        self.dyn_partitioned_rolling_aggregate_linear(
            persistent_id,
            &factories.aggregate_factories,
            partition_func,
            // Weigh each value into the sum and record its weight as the count.
            Box::new(move |v: &V, w: &B::R, avg: &mut DynAverage<W, B::R>| {
                let (sum, count) = avg.split_mut();
                w.clone_to(count);
                f(v, w, sum);
            }),
            // Finalize: divide sum by count, then convert to the output type.
            Box::new(move |avg, out| {
                weight_factory.with(&mut |avg_val| {
                    avg.compute_avg(avg_val);
                    out_func(avg_val, out)
                })
            }),
            range,
        )
    }
}
646
impl<B> Stream<RootCircuit, B> {
    /// Shared implementation behind the `partitioned_rolling_aggregate`
    /// family: builds the radix tree and input/output traces and wires up the
    /// quaternary `PartitionedRollingAggregate` operator with an output-trace
    /// feedback loop.
    ///
    /// * `self` - full partitioned input stream (drives the delta input).
    /// * `self_window` - input restricted to the active time window.
    /// * `bound` - value bound used to truncate the output trace.
    #[doc(hidden)]
    pub fn dyn_partitioned_rolling_aggregate_inner<TS, V, Acc, Out, O>(
        &self,
        partition_id: Option<&str>,
        factories: &PartitionedRollingAggregateFactories<TS, V, Acc, Out, B, O>,
        self_window: &Self,
        aggregator: &dyn DynAggregator<V, (), DynZWeight, Accumulator = Acc, Output = Out>,
        range: RelRange<TS>,
        bound: TraceBound<DynPair<DynDataTyped<TS>, DynOpt<Out>>>,
    ) -> Stream<RootCircuit, O>
    where
        B: PartitionedIndexedZSet<DynDataTyped<TS>, V> + Send,
        O: PartitionedIndexedZSet<DynDataTyped<TS>, DynOpt<Out>, Key = B::Key>,
        Acc: DataTrait + ?Sized,
        Out: DataTrait + ?Sized,
        TS: DBData + UnsignedPrimInt,
        V: DataTrait + ?Sized,
    {
        let circuit = self.circuit();
        // Shard both inputs so matching partitions land on the same worker.
        let stream = self.dyn_shard(&factories.input_factories);
        let stream_window = self_window.dyn_shard(&factories.input_factories);

        let partitioned_tree_aggregate_name =
            partition_id.map(|name| format!("{name}-tree_aggregate"));

        // Build the radix tree over the bounded window.
        let tree = stream_window
            .partitioned_tree_aggregate::<TS, V, Acc, Out>(
                partitioned_tree_aggregate_name.as_deref(),
                &factories.partitioned_tree_aggregate_factories,
                aggregator,
            )
            .set_persistent_id(partitioned_tree_aggregate_name.as_deref())
            .dyn_accumulate_integrate_trace(&factories.radix_tree_factories);

        // Trace of accumulated (windowed) time-series data.
        let input_trace = stream_window.dyn_accumulate_integrate_trace(&factories.input_factories);

        // Truncate timestamps `< bound` in the output trace.
        let bounds = TraceBounds::new();
        bounds.add_key_bound(TraceBound::new());
        bounds.add_val_bound(bound);

        // Feedback loop carrying previously produced outputs, used by the
        // operator to compute retractions.
        let feedback = circuit.add_accumulate_integrate_trace_feedback::<Spine<O>>(
            partition_id,
            &factories.output_factories,
            bounds,
        );

        let output = circuit
            .add_quaternary_operator(
                StreamingQuaternaryWrapper::new(
                    <PartitionedRollingAggregate<TS, B, V, Acc, Out, _>>::new(
                        &factories.output_factories,
                        range,
                        aggregator,
                    ),
                ),
                &stream.dyn_accumulate(&factories.input_factories),
                &input_trace,
                &tree,
                &feedback.delayed_trace,
            )
            .mark_distinct()
            .mark_sharded();

        feedback.connect(&output, &factories.output_factories);

        output
    }
}
718
/// Quaternary operator that implements the internals of
/// `partitioned_rolling_aggregate`.
///
/// * Input stream 1: updates to the time series.  Used to identify affected
///   partitions and times.
/// * Input stream 2: trace containing the accumulated time series data.
/// * Input stream 3: trace containing the partitioned radix tree over the input
///   time series.
/// * Input stream 4: trace of previously produced outputs.  Used to compute
///   retractions.
struct PartitionedRollingAggregate<
    TS: DBData,
    B: PartitionedBatch<DynDataTyped<TS>, V, R = DynZWeight>,
    V: DataTrait + ?Sized,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    O: Batch,
> {
    // Factories for constructing output batches of type `O`.
    output_factories: O::Factories,
    // Relative aggregation window applied per partition.
    range: RelRange<TS>,
    // Aggregate function applied to each window's contents.
    aggregator: Box<dyn DynAggregator<V, (), DynZWeight, Accumulator = Acc, Output = Out>>,
    // Set by `Operator::flush`; consumed during evaluation.
    flush: RefCell<bool>,
    // Snapshot of the latest input delta, buffered until evaluation.
    input_delta: RefCell<Option<SpineSnapshot<B>>>,

    // Input batch sizes.
    input_batch_stats: RefCell<BatchSizeStats>,

    // Output batch sizes.
    output_batch_stats: RefCell<BatchSizeStats>,

    // Anchor `V` and `O` without owning values of either type.
    phantom: PhantomData<fn(&V, &O)>,
}
751
impl<TS, B, V, Acc, Out, O> PartitionedRollingAggregate<TS, B, V, Acc, Out, O>
where
    TS: DBData,
    B: PartitionedBatch<DynDataTyped<TS>, V, R = DynZWeight>,
    V: DataTrait + ?Sized,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    O: Batch,
{
    /// Creates the operator from output factories, the aggregation window
    /// `range`, and the aggregator (cloned into the operator).
    fn new(
        output_factories: &O::Factories,
        range: RelRange<TS>,
        aggregator: &dyn DynAggregator<V, (), DynZWeight, Accumulator = Acc, Output = Out>,
    ) -> Self {
        Self {
            output_factories: output_factories.clone(),
            range,
            aggregator: clone_box(aggregator),
            flush: RefCell::new(false),
            input_delta: RefCell::new(None),
            input_batch_stats: RefCell::new(BatchSizeStats::new()),
            output_batch_stats: RefCell::new(BatchSizeStats::new()),

            phantom: PhantomData,
        }
    }

    /// Computes the set of timestamps whose aggregates may change given the
    /// updates under `delta_cursor`: for each updated timestamp, the range of
    /// timestamps whose windows contain it (per `self.range`), merged with
    /// the updated timestamps themselves.
    fn affected_ranges<R, C>(&self, delta_cursor: &mut C) -> Ranges<TS>
    where
        C: Cursor<DynDataTyped<TS>, V, (), R>,
        TS: DBData + UnsignedPrimInt,
        R: ?Sized,
    {
        let mut affected_ranges = Ranges::new();
        let mut delta_ranges = Ranges::new();

        while delta_cursor.key_valid() {
            if let Some(range) = self.range.affected_range_of(delta_cursor.key().deref()) {
                affected_ranges.push_monotonic(range);
            }
            // If `delta_cursor.key()` is a new key that doesn't yet occur in the input
            // z-set, we need to compute its aggregate even if it is outside
            // affected range.
            delta_ranges.push_monotonic(Range::new(**delta_cursor.key(), **delta_cursor.key()));
            delta_cursor.step_key();
        }

        affected_ranges.merge(&delta_ranges)
    }
}
802
impl<TS, B, V, Acc, Out, O> Operator for PartitionedRollingAggregate<TS, B, V, Acc, Out, O>
where
    TS: DBData,
    B: PartitionedBatch<DynDataTyped<TS>, V, R = DynZWeight>,
    V: DataTrait + ?Sized,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    O: Batch,
{
    fn name(&self) -> Cow<'static, str> {
        Cow::from("PartitionedRollingAggregate")
    }

    // Reports input/output batch-size statistics for circuit profiling.
    fn metadata(&self, meta: &mut OperatorMeta) {
        meta.extend(metadata! {
            INPUT_BATCHES_STATS => self.input_batch_stats.borrow().metadata(),
            OUTPUT_BATCHES_STATS => self.output_batch_stats.borrow().metadata(),
        });
    }

    // Always reports a fixedpoint.  NOTE(review): presumably safe because the
    // operator's state lives in the attached traces -- confirm.
    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }

    // Records that a flush was requested; the flag is consumed during the
    // next evaluation (see the `eval` implementation below).
    fn flush(&mut self) {
        *self.flush.borrow_mut() = true;
    }
}
831
impl<TS, V, Acc, Out, B, T, RT, OT, O>
    StreamingQuaternaryOperator<Option<Spine<B>>, Spine<T>, Spine<RT>, Spine<OT>, O>
    for PartitionedRollingAggregate<TS, B, V, Acc, Out, O>
where
    TS: DBData + UnsignedPrimInt,
    V: DataTrait + ?Sized,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    B: PartitionedBatch<DynDataTyped<TS>, V, R = DynZWeight>,
    T: PartitionedBatch<DynDataTyped<TS>, V, Key = B::Key, R = B::R> + Clone,
    RT: PartitionedRadixTreeBatch<TS, Acc, Key = B::Key> + Clone,
    OT: PartitionedBatch<DynDataTyped<TS>, DynOpt<Out>, Key = B::Key, R = B::R> + Clone,
    O: IndexedZSet<Key = B::Key, Val = DynPair<DynDataTyped<TS>, DynOpt<Out>>>,
{
    /// Incrementally recompute rolling aggregates for the partitions touched
    /// by `input_delta`.
    ///
    /// Until `flush` is requested (see `Operator::flush`), this only buffers a
    /// snapshot of `input_delta` and yields an empty output batch.  On flush
    /// it takes snapshots of the input trace, radix tree, and output trace,
    /// then for every partition in the buffered delta:
    ///
    /// 1. retracts previously emitted aggregates in the affected timestamp
    ///    ranges (negated weights read from `output_trace`), and
    /// 2. emits fresh aggregates for those ranges, computed over the radix
    ///    tree via `aggregate_range`.
    ///
    /// Output is yielded in chunks of at most `chunk_size` tuples; the final
    /// yield (flagged `true`) merges the leftover insertions and retractions.
    fn eval(
        self: Rc<Self>,
        input_delta: Cow<'_, Option<Spine<B>>>,
        input_trace: Cow<'_, Spine<T>>,
        radix_tree: Cow<'_, Spine<RT>>,
        output_trace: Cow<'_, Spine<OT>>,
    ) -> impl AsyncStream<Item = (O, bool, Option<Position>)> + 'static {
        let chunk_size = splitter_output_chunk_size();

        // Buffer a read-only snapshot of the delta.  The assert guards the
        // invariant that at most one delta is pending between flushes.
        if let Some(input_delta) = input_delta.as_ref().as_ref() {
            assert!(self.input_delta.borrow().is_none());
            *self.input_delta.borrow_mut() = Some(input_delta.ro_snapshot());
        };

        // Trace snapshots are only needed (and only taken) when flushing.
        let input_trace = if *self.flush.borrow() {
            Some(input_trace.as_ref().ro_snapshot())
        } else {
            None
        };

        let radix_tree = if *self.flush.borrow() {
            Some(radix_tree.as_ref().ro_snapshot())
        } else {
            None
        };

        let output_trace = if *self.flush.borrow() {
            Some(output_trace.as_ref().ro_snapshot())
        } else {
            None
        };

        stream! {
            // Not flushing yet: emit an empty batch and keep buffering.
            if !*self.flush.borrow() {
                yield (O::dyn_empty(&self.output_factories), true, None);
                return;
            }

            let input_delta = self.input_delta.borrow_mut().take().unwrap();
            self.input_batch_stats.borrow_mut().add_batch(input_delta.len());

            let mut delta_cursor = input_delta.cursor();
            let mut output_trace_cursor = output_trace.unwrap().cursor();
            let mut input_trace_cursor = input_trace.unwrap().cursor();
            let mut tree_cursor = radix_tree.unwrap().cursor();

            // Retractions and insertions are accumulated in separate builders
            // so partial chunks can be yielded independently.
            let mut retraction_builder =
                O::Builder::with_capacity(&self.output_factories, chunk_size, chunk_size);
            let mut insertion_builder =
                O::Builder::with_capacity(&self.output_factories, chunk_size, chunk_size);

            // Scratch values reused across iterations to avoid reallocation.
            let mut val = self.output_factories.val_factory().default_box();
            let mut acc = self.aggregator.opt_accumulator_factory().default_box();
            let mut agg = self.aggregator.output_factory().default_box();

            // Iterate over affected partitions.
            while delta_cursor.key_valid() {
                // Compute affected intervals using `input_delta`.
                let ranges = self.affected_ranges(&mut PartitionCursor::new(&mut delta_cursor));

                // Clear old outputs: retract every previously emitted tuple
                // in the affected ranges by pushing it with negated weight.
                let hash = delta_cursor.key().default_hash();
                if output_trace_cursor.seek_key_exact(delta_cursor.key(), Some(hash)) {
                    let mut range_cursor = RangeCursor::new(
                        PartitionCursor::new(&mut output_trace_cursor),
                        ranges.clone(),
                    );
                    // Tracks whether the current builder holds values for
                    // this partition key that still need `push_key`.
                    let mut any_values = false;
                    while range_cursor.key_valid() {
                        while range_cursor.val_valid() {
                            let weight = **range_cursor.weight();
                            debug_assert!(weight != 0);
                            val.from_refs(range_cursor.key(), range_cursor.val());
                            retraction_builder.push_val_diff_mut(&mut *val, &mut weight.neg());
                            any_values = true;

                            // Chunk full: seal the key, yield a non-final
                            // batch, and start a fresh builder.
                            if retraction_builder.num_tuples() >= chunk_size {
                                retraction_builder.push_key(delta_cursor.key());
                                let result = retraction_builder.done();
                                self.output_batch_stats.borrow_mut().add_batch(result.len());
                                yield (result, false, delta_cursor.position());
                                any_values = false;
                                retraction_builder = O::Builder::with_capacity(&self.output_factories, chunk_size, chunk_size);
                            }
                            range_cursor.step_val();
                        }
                        range_cursor.step_key();
                    }
                    if any_values {
                        retraction_builder.push_key(delta_cursor.key());
                    }
                };

                // Compute new outputs.
                if input_trace_cursor.seek_key_exact(delta_cursor.key(), Some(hash))
                    // It's possible that the key is in the input trace with weight 0, but it's no longer in the tree, which
                    // caused `test_empty_tree()` to fail without this check.
                    && tree_cursor.seek_key_exact(delta_cursor.key(), Some(hash))
                {
                    let mut tree_partition_cursor = PartitionCursor::new(&mut tree_cursor);
                    let mut input_range_cursor =
                        RangeCursor::new(PartitionCursor::new(&mut input_trace_cursor), ranges);

                    // For all affected times, seek them in `input_trace`, compute aggregates using
                    // using radix_tree.
                    let mut any_values = false;
                    while input_range_cursor.key_valid() {
                        let range = self.range.range_of(input_range_cursor.key());
                        tree_partition_cursor.rewind_keys();

                        while input_range_cursor.val_valid() {
                            // Generate output update.  Only timestamps that
                            // are still present (positive weight) get one.
                            if !input_range_cursor.weight().le0() {
                                **val.fst_mut() = **input_range_cursor.key();
                                if let Some(range) = range {
                                    tree_partition_cursor.aggregate_range(
                                        &range,
                                        self.aggregator.combine(),
                                        acc.as_mut(),
                                    );
                                    // Empty accumulator means no inputs fell
                                    // inside the window: emit `None`.
                                    if let Some(acc) = acc.get_mut() {
                                        self.aggregator.finalize(acc, agg.as_mut());
                                        val.snd_mut().from_val(agg.as_mut());
                                    } else {
                                        val.snd_mut().set_none();
                                    }
                                } else {
                                    val.snd_mut().set_none();
                                }

                                insertion_builder.push_val_diff_mut(&mut *val, 1.erase_mut());
                                any_values = true;

                                // Chunk full: seal the key and yield a
                                // non-final batch, as above.
                                if insertion_builder.num_tuples() >= chunk_size {
                                    insertion_builder.push_key(delta_cursor.key());
                                    any_values = false;

                                    let result = insertion_builder.done();
                                    self.output_batch_stats.borrow_mut().add_batch(result.len());

                                    yield (result, false, delta_cursor.position());
                                    insertion_builder =
                                        O::Builder::with_capacity(&self.output_factories, chunk_size, chunk_size);
                                }

                                // One output per timestamp: skip the
                                // remaining values for this key.
                                break;
                            }

                            input_range_cursor.step_val();
                        }

                        input_range_cursor.step_key();
                    }
                    if any_values {
                        insertion_builder.push_key(delta_cursor.key());
                    }
                }

                delta_cursor.step_key();
            }

            // Flush complete: reset the flag and emit the remaining updates
            // as a single merged, final batch.
            *self.flush.borrow_mut() = false;
            let retractions = retraction_builder.done();
            let insertions = insertion_builder.done();

            let result = merge_batches(&insertions.factories(), [insertions,retractions], &None, &None);
            self.output_batch_stats.borrow_mut().add_batch(result.len());
            yield (result, true, delta_cursor.position());
        }
    }
}
1034
1035#[cfg(test)]
1036mod test {
1037    use crate::{
1038        DBData, DBSPHandle, IndexedZSetHandle, OrdIndexedZSet, OutputHandle, RootCircuit, Runtime,
1039        Stream, TypedBox, ZWeight,
1040        algebra::{DefaultSemigroup, UnsignedPrimInt},
1041        circuit::CircuitConfig,
1042        dynamic::{DowncastTrait, DynData, DynDataTyped, DynOpt, DynPair, Erase},
1043        lean_vec,
1044        operator::{
1045            Fold,
1046            dynamic::{
1047                input::AddInputIndexedZSetFactories,
1048                time_series::{
1049                    PartitionCursor,
1050                    range::{Range, RelOffset, RelRange},
1051                },
1052                trace::TraceBound,
1053            },
1054            time_series::OrdPartitionedIndexedZSet,
1055        },
1056        trace::{BatchReaderFactories, Cursor},
1057        typed_batch::{
1058            DynBatchReader, DynOrdIndexedZSet, IndexedZSetReader, SpineSnapshot, TypedBatch,
1059        },
1060        utils::Tup2,
1061    };
1062    use proptest::{collection, prelude::*};
1063    use size_of::SizeOf;
1064
    // Input batch shape used by the reference implementation: a partition key
    // (`u64`) indexed by `(timestamp, value)` pairs, stored as dynamically
    // typed data.
    type DataBatch = DynOrdIndexedZSet<
        DynData, /* <u64> */
        DynPair<DynDataTyped<u64>, DynData /* <i64> */>,
    >;
    // Output batch shape: partition key indexed by `(timestamp, aggregate)`
    // pairs, where the aggregate is `None` when no inputs fall in the window.
    type OutputBatch = TypedBatch<
        u64,
        Tup2<u64, Option<i64>>,
        ZWeight,
        DynOrdIndexedZSet<
            DynData, /* <u64> */
            DynPair<DynDataTyped<u64>, DynOpt<DynData /* <i64> */>>,
        >,
    >;
1078
1079    impl<PK, TS, V> Stream<RootCircuit, OrdIndexedZSet<PK, Tup2<TS, V>>>
1080    where
1081        PK: DBData,
1082        TS: DBData + UnsignedPrimInt,
1083        V: DBData,
1084    {
1085        pub fn as_partitioned_zset(
1086            &self,
1087        ) -> Stream<RootCircuit, OrdPartitionedIndexedZSet<PK, TS, DynDataTyped<TS>, V, DynData>>
1088        {
1089            let factories = BatchReaderFactories::new::<PK, Tup2<TS, V>, ZWeight>();
1090
1091            self.inner()
1092                .dyn_map_index(
1093                    &factories,
1094                    Box::new(|(k, v), kv| {
1095                        kv.from_refs(k, unsafe { v.downcast::<Tup2<TS, V>>().erase() })
1096                    }),
1097                )
1098                .typed()
1099        }
1100    }
1101
1102    // Reference implementation of `aggregate_range` for testing.
1103    fn aggregate_range_slow(batch: &DataBatch, partition: u64, range: Range<u64>) -> Option<i64> {
1104        let mut cursor = batch.cursor();
1105
1106        cursor.seek_key(&partition);
1107        assert!(cursor.key_valid());
1108        assert!(*cursor.key().downcast_checked::<u64>() == partition);
1109        let mut partition_cursor = PartitionCursor::new(&mut cursor);
1110
1111        let mut agg = None;
1112        partition_cursor.seek_key(&range.from);
1113        while partition_cursor.key_valid()
1114            && *partition_cursor.key().downcast_checked::<u64>() <= range.to
1115        {
1116            while partition_cursor.val_valid() {
1117                let w = *partition_cursor.weight().downcast_checked::<ZWeight>();
1118                debug_assert!(w != 0);
1119                agg = if let Some(a) = agg {
1120                    Some(a + *partition_cursor.val().downcast_checked::<i64>() * w)
1121                } else {
1122                    Some(*partition_cursor.val().downcast_checked::<i64>() * w)
1123                };
1124                partition_cursor.step_val();
1125            }
1126            partition_cursor.step_key();
1127        }
1128
1129        agg
1130    }
1131
1132    // Reference implementation of `partitioned_rolling_aggregate` for testing.
1133    fn partitioned_rolling_aggregate_slow(
1134        stream: &Stream<RootCircuit, DataBatch>,
1135        range_spec: RelRange<u64>,
1136    ) -> Stream<RootCircuit, OutputBatch> {
1137        let stream = stream.typed::<TypedBatch<u64, Tup2<u64, i64>, ZWeight, _>>();
1138
1139        stream
1140            .circuit()
1141            .non_incremental(&stream, |_child, stream| {
1142                Ok(stream
1143                    .gather(0)
1144                    .integrate()
1145                    .apply(move |batch: &TypedBatch<_, _, _, DataBatch>| {
1146                        let mut tuples = Vec::with_capacity(batch.len());
1147
1148                        let mut cursor = batch.cursor();
1149
1150                        while cursor.key_valid() {
1151                            while cursor.val_valid() {
1152                                let partition = *cursor.key().downcast_checked::<u64>();
1153                                let Tup2(ts, _val) =
1154                                    *cursor.val().downcast_checked::<Tup2<u64, i64>>();
1155                                let agg = range_spec.range_of(&ts).and_then(|range| {
1156                                    aggregate_range_slow(batch, partition, range)
1157                                });
1158                                tuples.push(Tup2(Tup2(partition, Tup2(ts, agg)), 1));
1159                                cursor.step_val();
1160                            }
1161                            cursor.step_key();
1162                        }
1163
1164                        OutputBatch::from_tuples((), tuples)
1165                    })
1166                    .stream_distinct()
1167                    .gather(0))
1168            })
1169            .unwrap()
1170    }
1171
    // Output handle type shared by all aggregate outputs of the test circuit:
    // partitioned by `u64` key, with `(timestamp, Option<aggregate>)` values.
    type TestOutputHandle = OutputHandle<
        SpineSnapshot<
            OrdPartitionedIndexedZSet<u64, u64, DynDataTyped<u64>, Option<i64>, DynOpt<DynData>>,
        >,
    >;
1177
    /// Build the test circuit: for several window specifications, compute the
    /// rolling aggregate with each production operator variant (`plain`,
    /// `with_waterline`, `linear`) alongside its reference
    /// (`partitioned_rolling_aggregate_slow`) counterpart, and return all
    /// output handles in (actual, expected) pairs.
    ///
    /// * `lateness` - subtracted from timestamps to derive the waterline.
    /// * `size_bound` - if set, asserts a byte-size bound on the bounded
    ///   integral trace of the 500/500 waterline aggregate.
    fn partition_rolling_aggregate_circuit(
        lateness: u64,
        size_bound: Option<usize>,
    ) -> (
        DBSPHandle,
        (
            IndexedZSetHandle<u64, Tup2<u64, i64>>,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
            TestOutputHandle,
        ),
    ) {
        // Two workers; small splitter chunks to exercise chunked output paths.
        Runtime::init_circuit(
            CircuitConfig::from(2).with_splitter_chunk_size_records(6),
            move |circuit| {
                let (input_stream, input_handle) =
                    circuit.add_input_indexed_zset::<u64, Tup2<u64, i64>>();

                // Re-index by timestamp: (ts, (partition, value)).
                let input_by_time = input_stream
                    .map_index(|(partition, Tup2(ts, val))| (*ts, Tup2(*partition, *val)));

                let input_stream = input_stream.as_partitioned_zset();

                let waterline: Stream<_, TypedBox<u64, DynDataTyped<u64>>> = input_by_time
                    .waterline_monotonic(|| 0, move |ts| ts.saturating_sub(lateness))
                    .transaction_delay_with_initial_value(TypedBox::new(0))
                    .inspect(|w| println!("waterline: {w:?}"));

                // Weighted sum aggregator used by the non-linear variants.
                let aggregator = <Fold<i64, i64, DefaultSemigroup<_>, _, _>>::new(
                    0i64,
                    |agg: &mut i64, val: &i64, w: ZWeight| *agg += val * w,
                );

                // Window [ts - 1000, ts]: plain, waterline, and linear variants.
                let range_spec = RelRange::new(RelOffset::Before(1000), RelOffset::Before(0));
                let output_1000_0 = input_by_time
                    .partitioned_rolling_aggregate(
                        |Tup2(partition, val)| (*partition, *val),
                        aggregator.clone(),
                        range_spec,
                    )
                    .accumulate_integrate()
                    .accumulate_output();

                let output_1000_0_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                let output_1000_0_waterline = Stream::partitioned_rolling_aggregate_with_waterline(
                    &input_by_time,
                    &waterline,
                    |Tup2(partition, val)| (*partition, *val),
                    aggregator.clone(),
                    range_spec,
                )
                .accumulate_integrate()
                .accumulate_output();

                let output_1000_0_waterline_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                let output_1000_0_linear = input_by_time
                    .partitioned_rolling_aggregate_linear(
                        |Tup2(partition, val)| (*partition, *val),
                        |v| *v,
                        |v| v,
                        range_spec,
                    )
                    .accumulate_integrate()
                    .accumulate_output();

                let output_1000_0_linear_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                // Symmetric window [ts - 500, ts + 500].
                let range_spec = RelRange::new(RelOffset::Before(500), RelOffset::After(500));
                let aggregate_500_500 = input_by_time
                    .partitioned_rolling_aggregate(
                        |Tup2(partition, val)| (*partition, *val),
                        aggregator.clone(),
                        range_spec,
                    )
                    .accumulate_integrate()
                    .accumulate_output();

                let aggregate_500_500_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                let aggregate_500_500_waterline = input_by_time
                    .partitioned_rolling_aggregate_with_waterline(
                        &waterline,
                        |Tup2(partition, val)| (*partition, *val),
                        aggregator.clone(),
                        range_spec,
                    );

                // let output_500_500_waterline = aggregate_500_500_waterline.gather(0).integrate();

                // Bound the integral trace of the waterline aggregate and,
                // when `size_bound` is given, assert its byte size stays
                // within the bound.
                let bound: TraceBound<DynPair<DynDataTyped<u64>, DynOpt<DynData>>> =
                    TraceBound::new();
                let b: Tup2<u64, Option<i64>> = Tup2(u64::MAX, None::<i64>);

                bound.set(Box::new(b).erase_box());

                aggregate_500_500_waterline
                    .integrate_trace_with_bound(TraceBound::new(), bound)
                    .apply(move |trace| {
                        if let Some(bound) = size_bound {
                            assert!(trace.size_of().total_bytes() <= bound);
                        }
                    });

                let aggregate_500_500_waterline = aggregate_500_500_waterline
                    .accumulate_integrate()
                    .accumulate_output();

                let aggregate_500_500_waterline_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                let output_500_500_linear = input_by_time
                    .partitioned_rolling_aggregate_linear(
                        |Tup2(partition, val)| (*partition, *val),
                        |v| *v,
                        |v| v,
                        range_spec,
                    )
                    .accumulate_integrate()
                    .accumulate_output();

                let output_500_500_linear_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                // Strictly-past window [ts - 500, ts - 100].
                let range_spec = RelRange::new(RelOffset::Before(500), RelOffset::Before(100));
                let output_500_100 = input_by_time
                    .partitioned_rolling_aggregate(
                        |Tup2(partition, val)| (*partition, *val),
                        aggregator,
                        range_spec,
                    )
                    .accumulate_integrate()
                    .accumulate_output();

                let output_500_100_expected =
                    partitioned_rolling_aggregate_slow(&input_stream.inner(), range_spec)
                        .accumulate_output();

                Ok((
                    input_handle,
                    output_1000_0,
                    output_1000_0_expected,
                    output_1000_0_waterline,
                    output_1000_0_waterline_expected,
                    output_1000_0_linear,
                    output_1000_0_linear_expected,
                    aggregate_500_500,
                    aggregate_500_500_expected,
                    aggregate_500_500_waterline,
                    aggregate_500_500_waterline_expected,
                    output_500_500_linear,
                    output_500_500_linear_expected,
                    output_500_100,
                    output_500_100_expected,
                ))
            },
        )
        .unwrap()
    }
1359
1360    fn test_partition_rolling_aggregate(
1361        lateness: u64,
1362        size_bound: Option<usize>,
1363        trace: Vec<InputBatch>,
1364        transaction: bool,
1365    ) {
1366        let (
1367            mut circuit,
1368            (
1369                input,
1370                output_1000_0,
1371                output_1000_0_expected,
1372                output_1000_0_waterline,
1373                output_1000_0_waterline_expected,
1374                output_1000_0_linear,
1375                output_1000_0_linear_expected,
1376                aggregate_500_500,
1377                aggregate_500_500_expected,
1378                aggregate_500_500_waterline,
1379                aggregate_500_500_waterline_expected,
1380                aggregate_500_500_linear,
1381                aggregate_500_500_linear_expected,
1382                output_500_100,
1383                output_500_100_expected,
1384            ),
1385        ) = partition_rolling_aggregate_circuit(lateness, size_bound);
1386
1387        if transaction {
1388            circuit.start_transaction().unwrap();
1389            for mut batch in trace {
1390                input.append(&mut batch);
1391                circuit.step().unwrap();
1392            }
1393
1394            circuit.commit_transaction().unwrap();
1395
1396            assert_eq!(
1397                output_1000_0.concat().consolidate(),
1398                output_1000_0_expected.concat().consolidate()
1399            );
1400            assert_eq!(
1401                output_1000_0_waterline.concat().consolidate(),
1402                output_1000_0_waterline_expected.concat().consolidate()
1403            );
1404            assert_eq!(
1405                output_1000_0_linear.concat().consolidate(),
1406                output_1000_0_linear_expected.concat().consolidate()
1407            );
1408            assert_eq!(
1409                aggregate_500_500.concat().consolidate(),
1410                aggregate_500_500_expected.concat().consolidate()
1411            );
1412            assert_eq!(
1413                aggregate_500_500_waterline.concat().consolidate(),
1414                aggregate_500_500_waterline_expected.concat().consolidate()
1415            );
1416            assert_eq!(
1417                aggregate_500_500_linear.concat().consolidate(),
1418                aggregate_500_500_linear_expected.concat().consolidate()
1419            );
1420            assert_eq!(
1421                output_500_100.concat().consolidate(),
1422                output_500_100_expected.concat().consolidate()
1423            );
1424        } else {
1425            for mut batch in trace {
1426                input.append(&mut batch);
1427                circuit.transaction().unwrap();
1428
1429                assert_eq!(
1430                    output_1000_0.concat().consolidate(),
1431                    output_1000_0_expected.concat().consolidate()
1432                );
1433                assert_eq!(
1434                    output_1000_0_waterline.concat().consolidate(),
1435                    output_1000_0_waterline_expected.concat().consolidate()
1436                );
1437                assert_eq!(
1438                    output_1000_0_linear.concat().consolidate(),
1439                    output_1000_0_linear_expected.concat().consolidate()
1440                );
1441                assert_eq!(
1442                    aggregate_500_500.concat().consolidate(),
1443                    aggregate_500_500_expected.concat().consolidate()
1444                );
1445                assert_eq!(
1446                    aggregate_500_500_waterline.concat().consolidate(),
1447                    aggregate_500_500_waterline_expected.concat().consolidate()
1448                );
1449                assert_eq!(
1450                    aggregate_500_500_linear.concat().consolidate(),
1451                    aggregate_500_500_linear_expected.concat().consolidate()
1452                );
1453                assert_eq!(
1454                    output_500_100.concat().consolidate(),
1455                    output_500_100_expected.concat().consolidate()
1456                );
1457            }
1458        }
1459
1460        circuit.kill().unwrap();
1461    }
1462
1463    #[test]
1464    fn test_partitioned_over_range_2() {
1465        test_partition_rolling_aggregate(
1466            u64::MAX,
1467            None,
1468            vec![
1469                vec![Tup2(2u64, Tup2(Tup2(110271u64, 100i64), 1i64))],
1470                vec![Tup2(2u64, Tup2(Tup2(0u64, 100i64), 1i64))],
1471            ],
1472            false,
1473        );
1474    }
1475
1476    #[test]
1477    fn test_partitioned_over_range() {
1478        test_partition_rolling_aggregate(
1479            u64::MAX,
1480            None,
1481            vec![
1482                vec![
1483                    Tup2(0u64, Tup2(Tup2(1u64, 100i64), 1)),
1484                    Tup2(0, Tup2(Tup2(10, 100), 1)),
1485                    Tup2(0, Tup2(Tup2(20, 100), 1)),
1486                    Tup2(0, Tup2(Tup2(30, 100), 1)),
1487                ],
1488                vec![
1489                    Tup2(0u64, Tup2(Tup2(1u64, 100i64), 1)),
1490                    Tup2(0, Tup2(Tup2(10, 100), 1)),
1491                    Tup2(0, Tup2(Tup2(20, 100), 1)),
1492                    Tup2(0, Tup2(Tup2(30, 100), 1)),
1493                ],
1494                vec![
1495                    Tup2(0u64, Tup2(Tup2(5u64, 100i64), 1)),
1496                    Tup2(0, Tup2(Tup2(15, 100), 1)),
1497                    Tup2(0, Tup2(Tup2(25, 100), 1)),
1498                    Tup2(0, Tup2(Tup2(35, 100), 1)),
1499                ],
1500                vec![
1501                    Tup2(1u64, Tup2(Tup2(1u64, 100i64), 1)),
1502                    Tup2(1, Tup2(Tup2(1000, 100), 1)),
1503                    Tup2(1, Tup2(Tup2(2000, 100), 1)),
1504                    Tup2(1, Tup2(Tup2(3000, 100), 1)),
1505                ],
1506            ],
1507            false,
1508        );
1509    }
1510
1511    #[test]
1512    fn test_empty_tree() {
1513        test_partition_rolling_aggregate(
1514            u64::MAX,
1515            None,
1516            std::iter::repeat_n(
1517                vec![
1518                    vec![Tup2(0u64, Tup2(Tup2(1u64, 100i64), 1))],
1519                    vec![Tup2(0u64, Tup2(Tup2(1u64, 100i64), -1))],
1520                ],
1521                1000,
1522            )
1523            .flatten()
1524            .collect::<Vec<_>>(),
1525            false,
1526        );
1527    }
1528
    // Test derived from issue #199 (https://github.com/feldera/feldera/issues/199).
    //
    // Feeds a single partition of records and checks the rolling sum over the
    // window [ts - 3, ts - 2] against a hand-computed expected output, using
    // the dynamically-typed input API for the expected stream.
    #[test]
    fn test_partitioned_rolling_aggregate2() {
        let (circuit, (input, expected)) = RootCircuit::build(move |circuit| {
            // Statically-typed input: partition -> (timestamp, value).
            let (input_stream, input_handle) =
                circuit.add_input_indexed_zset::<u64, Tup2<u64, i64>>();

            // Expected output arrives through the dyn-typed input API and is
            // then re-typed to the concrete partitioned Z-set below.
            let (expected, expected_handle) =
                circuit.dyn_add_input_indexed_zset::<DynData/*<u64>*/, DynPair<DynDataTyped<u64>, DynOpt<DynData/*<i64>*/>>>(&AddInputIndexedZSetFactories::new::<u64, Tup2<u64, Option<i64>>>());

            let expected = expected.typed::<OrdPartitionedIndexedZSet<u64, u64, _, Option<i64>, _>>();

            // Re-index by timestamp, as required by the rolling aggregate.
            let input_by_time =
                input_stream.map_index(|(partition, Tup2(ts, val))| (*ts, Tup2(*partition, *val)));

            // Debug trace of every input record.
            input_stream.inspect(|f| {
                for (p, Tup2(ts, v), w) in f.iter() {
                    println!(" input {p} {ts} {v:6} {w:+}");
                }
            });
            // Window covers timestamps [ts - 3, ts - 2] relative to each record.
            let range_spec = RelRange::new(RelOffset::Before(3), RelOffset::Before(2));
            let sum = input_by_time.partitioned_rolling_aggregate_linear(
                |Tup2(partition, val)| (*partition, *val),
                |&f| f,
                |x| x, range_spec);
            // Debug trace of every output record.
            sum.inspect(|f| {
                for (p, Tup2(ts, sum), w) in f.iter() {
                    println!("output {p} {ts} {:6} {w:+}", sum.unwrap_or_default());
                }
            });
            // Compare accumulated actual output against the expected stream.
            expected.accumulate_apply2(&sum, |expected, actual| assert_eq!(expected.iter().collect::<Vec<_>>(), actual.iter().collect::<Vec<_>>()));
            Ok((input_handle, expected_handle))
        })
        .unwrap();

        input.append(&mut vec![
            Tup2(1u64, Tup2(Tup2(0u64, 1i64), 1)),
            Tup2(1, Tup2(Tup2(1, 10), 1)),
            Tup2(1, Tup2(Tup2(2, 100), 1)),
            Tup2(1, Tup2(Tup2(3, 1000), 1)),
            Tup2(1, Tup2(Tup2(4, 10000), 1)),
            Tup2(1, Tup2(Tup2(5, 100000), 1)),
            Tup2(1, Tup2(Tup2(9, 123456), 1)),
        ]);
        // Expected rolling sums: timestamps 0 and 1 have empty windows; ts=9's
        // window [6, 7] contains no records, hence `None`.
        expected.dyn_append(
            &mut Box::new(lean_vec![
                Tup2(1u64, Tup2(Tup2(0u64, None::<i64>), 1)),
                Tup2(1, Tup2(Tup2(1, None), 1)),
                Tup2(1, Tup2(Tup2(2, Some(1)), 1)),
                Tup2(1, Tup2(Tup2(3, Some(11)), 1)),
                Tup2(1, Tup2(Tup2(4, Some(110)), 1)),
                Tup2(1, Tup2(Tup2(5, Some(1100)), 1)),
                Tup2(1, Tup2(Tup2(9, None), 1)),
            ])
            .erase_box(),
        );
        circuit.transaction().unwrap();
    }
1587
    // Checks `partitioned_rolling_average` over the window [ts - 3, ts - 1]
    // against a hand-computed expected stream supplied via the dyn-typed
    // input API.
    #[test]
    fn test_partitioned_rolling_average() {
        let (circuit, (input, expected)) = RootCircuit::build(move |circuit| {
            // Statically-typed input: partition -> (timestamp, value).
            let (input_stream, input_handle) =
                circuit.add_input_indexed_zset::<u64, Tup2<u64, i64>>();

            // Expected output arrives dyn-typed and is re-typed below.
            let (expected_stream, expected_handle) =
                circuit.dyn_add_input_indexed_zset::<DynData/*<u64>*/, DynPair<DynDataTyped<u64>, DynOpt<DynData/*<i64>*/>>>(&AddInputIndexedZSetFactories::new::<u64, Tup2<u64, Option<i64>>>());

            let expected_stream = expected_stream.typed::<OrdPartitionedIndexedZSet<u64, u64, _, Option<i64>, _>>();

            // Re-index by timestamp, as required by the rolling aggregate.
            let input_by_time =
                input_stream.map_index(|(partition, Tup2(ts, val))| (*ts, Tup2(*partition, *val)));

            // Window covers timestamps [ts - 3, ts - 1] relative to each record.
            let range_spec = RelRange::new(RelOffset::Before(3), RelOffset::Before(1));
            input_by_time
                .partitioned_rolling_average(
                    |Tup2(partition, val)| (*partition, *val),
                    range_spec)
                .accumulate_apply2(&expected_stream, |avg: SpineSnapshot<OrdPartitionedIndexedZSet<u64, u64, _, Option<i64>, _>>, expected| assert_eq!(avg.iter().collect::<Vec<_>>(), expected.iter().collect::<Vec<_>>()));
            Ok((input_handle, expected_handle))
        })
        .unwrap();

        // NOTE(review): this runs an empty transaction before any input is
        // appended — presumably exercising the empty-step path; confirm it is
        // intentional.
        circuit.transaction().unwrap();

        input.append(&mut vec![
            Tup2(0u64, Tup2(Tup2(10u64, 10i64), 1)),
            Tup2(0, Tup2(Tup2(11, 20), 1)),
            Tup2(0, Tup2(Tup2(12, 30), 1)),
            Tup2(0, Tup2(Tup2(13, 40), 1)),
            Tup2(0, Tup2(Tup2(14, 50), 1)),
            Tup2(0, Tup2(Tup2(15, 60), 1)),
        ]);
        // Expected rolling averages: ts=10 has an empty window, hence `None`;
        // later timestamps average the up-to-three preceding values.
        expected.dyn_append(
            &mut Box::new(lean_vec![
                Tup2(0u64, Tup2(Tup2(10u64, None::<i64>), 1)),
                Tup2(0, Tup2(Tup2(11, Some(10)), 1)),
                Tup2(0, Tup2(Tup2(12, Some(15)), 1)),
                Tup2(0, Tup2(Tup2(13, Some(20)), 1)),
                Tup2(0, Tup2(Tup2(14, Some(30)), 1)),
                Tup2(0, Tup2(Tup2(15, Some(40)), 1)),
            ])
            .erase_box(),
        );
        circuit.transaction().unwrap();
    }
1635
1636    #[test]
1637    fn test_partitioned_rolling_aggregate() {
1638        let (circuit, input) = RootCircuit::build(move |circuit| {
1639            let (input_stream, input_handle) =
1640                circuit.add_input_indexed_zset::<u64, Tup2<u64, i64>>();
1641
1642            input_stream.inspect(|f| {
1643                for (p, Tup2(ts, v), w) in f.iter() {
1644                    println!(" input {p} {ts} {v:6} {w:+}");
1645                }
1646            });
1647            let input_by_time =
1648                input_stream.map_index(|(partition, Tup2(ts, val))| (*ts, Tup2(*partition, *val)));
1649
1650            let range_spec = RelRange::new(RelOffset::Before(3), RelOffset::Before(2));
1651            let sum = input_by_time.partitioned_rolling_aggregate_linear(
1652                |Tup2(partition, val)| (*partition, *val),
1653                |&f| f,
1654                |x| x,
1655                range_spec,
1656            );
1657            sum.inspect(|f| {
1658                for (p, Tup2(ts, sum), w) in f.iter() {
1659                    println!("output {p} {ts} {:6} {w:+}", sum.unwrap_or_default());
1660                }
1661            });
1662            Ok(input_handle)
1663        })
1664        .unwrap();
1665
1666        input.append(&mut vec![
1667            Tup2(1u64, Tup2(Tup2(0u64, 1i64), 1)),
1668            Tup2(1, Tup2(Tup2(1, 10), 1)),
1669            Tup2(1, Tup2(Tup2(2, 100), 1)),
1670            Tup2(1, Tup2(Tup2(3, 1000), 1)),
1671            Tup2(1, Tup2(Tup2(4, 10000), 1)),
1672            Tup2(1, Tup2(Tup2(5, 100000), 1)),
1673            Tup2(1, Tup2(Tup2(9, 123456), 1)),
1674        ]);
1675        circuit.transaction().unwrap();
1676    }
1677
    // One input record: (partition, ((timestamp, value), weight)).
    type InputTuple = Tup2<u64, Tup2<Tup2<u64, i64>, ZWeight>>;
    // One batch of input records, fed to the circuit in a single step.
    type InputBatch = Vec<InputTuple>;
1680
1681    fn input_tuple(partitions: u64, window: (u64, u64)) -> impl Strategy<Value = InputTuple> {
1682        (
1683            (0..partitions),
1684            (
1685                (window.0..window.1, 100..101i64).prop_map(|(x, y)| Tup2(x, y)),
1686                1..2i64,
1687            )
1688                .prop_map(|(x, y)| Tup2(x, y)),
1689        )
1690            .prop_map(|(x, y)| Tup2(x, y))
1691    }
1692
1693    fn input_batch(
1694        partitions: u64,
1695        window: (u64, u64),
1696        max_batch_size: usize,
1697    ) -> impl Strategy<Value = InputBatch> {
1698        collection::vec(input_tuple(partitions, window), 0..max_batch_size)
1699    }
1700
1701    fn input_trace(
1702        partitions: u64,
1703        epoch: u64,
1704        max_batch_size: usize,
1705        max_batches: usize,
1706    ) -> impl Strategy<Value = Vec<InputBatch>> {
1707        collection::vec(
1708            input_batch(partitions, (0, epoch), max_batch_size),
1709            0..max_batches,
1710        )
1711    }
1712
1713    fn input_trace_quasi_monotone(
1714        partitions: u64,
1715        window_size: u64,
1716        window_step: u64,
1717        max_batch_size: usize,
1718        batches: usize,
1719    ) -> impl Strategy<Value = Vec<InputBatch>> {
1720        (0..batches)
1721            .map(|i| {
1722                input_batch(
1723                    partitions,
1724                    (i as u64 * window_step, i as u64 * window_step + window_size),
1725                    max_batch_size,
1726                )
1727                .boxed()
1728            })
1729            .collect::<Vec<_>>()
1730    }
1731
    // Randomized tests over quasi-monotone traces; limited to 5 cases each
    // because every case builds and runs a full circuit.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(5))]

        #[test]
        fn proptest_partitioned_rolling_aggregate_quasi_monotone_small_steps(trace in input_trace_quasi_monotone(5, 10_000, 2_000, 20, 200)) {
            // 10_000 is an empirically established bound: without GC this test needs >10KB.
            test_partition_rolling_aggregate(10000, Some(30_000), trace, false);
        }

        #[test]
        #[ignore = "https://github.com/feldera/feldera/issues/4764"]
        fn proptest_partitioned_rolling_aggregate_quasi_monotone_big_step(trace in input_trace_quasi_monotone(5, 10_000, 2_000, 20, 200)) {
            // 10_000 is an empirically established bound: without GC this test needs >10KB.
            test_partition_rolling_aggregate(10000, Some(30_000), trace, true);
        }
    }
1748
    // Randomized tests over unbounded traces (no GC, `u64::MAX` retention):
    // "sparse" uses a huge timestamp space with few records, "dense" packs
    // many records into a small space; each comes in small-step and big-step
    // variants.
    proptest! {
        #[test]
        fn proptest_partitioned_over_range_sparse_small_steps(trace in input_trace(5, 1_000_000, 10, 10)) {
            test_partition_rolling_aggregate(u64::MAX, None, trace, false);
        }

        #[test]
        fn proptest_partitioned_over_range_sparse_big_step(trace in input_trace(5, 1_000_000, 10, 10)) {
            test_partition_rolling_aggregate(u64::MAX, None, trace, true);
        }

        #[test]
        fn proptest_partitioned_over_range_dense_small_steps(trace in input_trace(5, 500, 25, 10)) {
            test_partition_rolling_aggregate(u64::MAX, None, trace, false);
        }

        #[test]
        fn proptest_partitioned_over_range_dense_big_step(trace in input_trace(5, 500, 25, 10)) {
            test_partition_rolling_aggregate(u64::MAX, None, trace, true);
        }
    }
1770}