use async_stream::stream;
use dyn_clone::{DynClone, clone_box};
use futures::Stream as AsyncStream;
use std::{
any::TypeId,
borrow::Cow,
cell::{Cell, RefCell},
cmp::{Ordering, min},
collections::BTreeMap,
marker::PhantomData,
rc::Rc,
};
use crate::{
DBData, DBWeight, DynZWeight, NestedCircuit, Position, RootCircuit, ZWeight,
algebra::{HasOne, IndexedZSet, IndexedZSetReader, Lattice, OrdIndexedZSet, PartialOrder},
circuit::{
Circuit, Scope, Stream, WithClock,
metadata::{BatchSizeStats, INPUT_BATCHES_STATS, OUTPUT_BATCHES_STATS, OperatorMeta},
operator_traits::{Operator, UnaryOperator},
splitter_output_chunk_size,
},
dynamic::{
DataTrait, DynData, DynOpt, DynPair, DynPairs, DynSet, DynUnit, DynWeight, Erase, Factory,
LeanVec, Weight, WeightTrait, WithFactory,
},
operator::{
async_stream_operators::{StreamingBinaryOperator, StreamingBinaryWrapper},
dynamic::upsert::UpsertFactories,
},
time::Timestamp,
trace::{
Batch, BatchReader, BatchReaderFactories, Builder, Cursor, Filter, OrdWSet,
OrdWSetFactories, Spine,
cursor::CursorGroup,
spine_async::{SpineCursor, WithSnapshot},
},
};
mod aggregator;
mod average;
mod chain_aggregate;
mod fold;
mod max;
mod min;
use crate::{
dynamic::{BSet, ClonableTrait},
storage::file::Deserializable,
utils::Tup2,
};
pub use aggregator::{
AggCombineFunc, AggOutputFunc, Aggregator, DynAggregator, DynAggregatorImpl, Postprocess,
};
pub use average::{Avg, AvgFactories, DynAverage};
pub use fold::Fold;
pub use max::{Max, MaxSemigroup};
pub use min::{ArgMinSome, Min, MinSemigroup, MinSome1, MinSome1Semigroup};
use super::MonoIndexedZSet;
/// Factory bundle used by the incremental aggregation operator chain
/// (see `dyn_aggregate_generic`).
///
/// `I` is the input batch type, `O` the output indexed Z-set type, and `T`
/// the timestamp type of the trace that accumulates input across steps.
pub struct IncAggregateFactories<I: Batch<Time = ()>, O: IndexedZSet, T: Timestamp> {
    /// Factories for the (sharded) input batches.
    pub input_factories: I::Factories,
    /// Factories for the timed trace of accumulated input.
    pub trace_factories: <T::TimedBatch<I> as BatchReader>::Factories,
    /// Factories used by the downstream `upsert` operator.
    pub upsert_factories: UpsertFactories<T, O>,
    // Factory for sets of input keys (used for "keys of interest" bookkeeping).
    keys_factory: &'static dyn Factory<DynSet<I::Key>>,
    // Factory for a single (key, optional aggregate) output pair.
    output_pair_factory: &'static dyn Factory<DynPair<I::Key, DynOpt<O::Val>>>,
    // Factory for vectors of (key, optional aggregate) output pairs.
    output_pairs_factory: &'static dyn Factory<DynPairs<I::Key, DynOpt<O::Val>>>,
}
impl<I, O, T> IncAggregateFactories<I, O, T>
where
    I: Batch<Time = ()>,
    O: IndexedZSet<Key = I::Key>,
    T: Timestamp,
{
    /// Builds all factories from the concrete Rust types that the erased
    /// key (`KType`), value (`VType`), weight (`RType`), and output value
    /// (`OType`) types stand for.
    pub fn new<KType, VType, RType, OType>() -> Self
    where
        KType: DBData + Erase<I::Key>,
        <KType as Deserializable>::ArchivedDeser: Ord,
        VType: DBData + Erase<I::Val>,
        RType: DBWeight + Erase<I::R>,
        OType: DBData + Erase<O::Val>,
    {
        Self {
            input_factories: BatchReaderFactories::new::<KType, VType, RType>(),
            trace_factories: BatchReaderFactories::new::<KType, VType, RType>(),
            upsert_factories: UpsertFactories::new::<KType, OType>(),
            keys_factory: WithFactory::<BSet<KType>>::FACTORY,
            output_pair_factory: WithFactory::<Tup2<KType, Option<OType>>>::FACTORY,
            output_pairs_factory: WithFactory::<LeanVec<Tup2<KType, Option<OType>>>>::FACTORY,
        }
    }
}
/// Factory bundle for incremental *linear* aggregation
/// (see `dyn_aggregate_linear_generic`).
///
/// Linear aggregation first maps the input into a weighted set with weight
/// type `R`, then aggregates that set, so it embeds an
/// [`IncAggregateFactories`] over `OrdWSet<I::Key, R>`.
pub struct IncAggregateLinearFactories<
    I: BatchReader,
    R: WeightTrait + ?Sized,
    O: IndexedZSet,
    T: Timestamp,
> {
    /// Factory for output values.
    pub out_factory: &'static dyn Factory<O::Val>,
    /// Factory for the intermediate aggregate (weight) values.
    pub agg_factory: &'static dyn Factory<R>,
    /// Factory for optional aggregate values.
    pub option_agg_factory: &'static dyn Factory<DynOpt<R>>,
    /// Factories for the inner aggregation over the weighted input.
    pub aggregate_factories: IncAggregateFactories<OrdWSet<I::Key, R>, O, T>,
}
impl<I, R, O, T> IncAggregateLinearFactories<I, R, O, T>
where
    I: BatchReader,
    O: IndexedZSet<Key = I::Key>,
    R: WeightTrait + ?Sized,
    T: Timestamp,
{
    /// Builds the factories from concrete key, aggregate weight, and output
    /// value types. Note: the inner aggregation uses `()` as the value type,
    /// since the weighted input set has unit values.
    pub fn new<KType, RType, OType>() -> Self
    where
        KType: DBData + Erase<I::Key>,
        <KType as Deserializable>::ArchivedDeser: Ord,
        RType: DBWeight + Erase<R>,
        OType: DBData + Erase<O::Val>,
    {
        Self {
            out_factory: WithFactory::<OType>::FACTORY,
            agg_factory: WithFactory::<RType>::FACTORY,
            option_agg_factory: WithFactory::<Option<RType>>::FACTORY,
            aggregate_factories: IncAggregateFactories::new::<KType, (), RType, OType>(),
        }
    }
}
/// Factory bundle for the non-incremental (per-step) aggregation operator
/// (see `dyn_stream_aggregate_generic`).
pub struct StreamAggregateFactories<I: BatchReader, O: IndexedZSet> {
    /// Factories for the input batches.
    pub input_factories: I::Factories,
    /// Factories for the output indexed Z-sets.
    pub output_factories: O::Factories,
    /// Factory for optional output values (per-key aggregation result).
    pub option_output_factory: &'static dyn Factory<DynOpt<O::Val>>,
}
impl<I, O> StreamAggregateFactories<I, O>
where
    I: BatchReader,
    O: IndexedZSet<Key = I::Key>,
{
    /// Builds the factories from the concrete key/value/weight/output types.
    /// The output Z-set uses the standard `ZWeight` weight type.
    pub fn new<KType, VType, RType, OType>() -> Self
    where
        KType: DBData + Erase<I::Key>,
        VType: DBData + Erase<I::Val>,
        RType: DBWeight + Erase<I::R>,
        OType: DBData + Erase<O::Val>,
    {
        Self {
            input_factories: BatchReaderFactories::new::<KType, VType, RType>(),
            output_factories: BatchReaderFactories::new::<KType, OType, ZWeight>(),
            option_output_factory: WithFactory::<Option<OType>>::FACTORY,
        }
    }
}
/// Factory bundle for non-incremental *linear* aggregation
/// (see `dyn_stream_aggregate_linear_generic`).
pub struct StreamLinearAggregateFactories<I: IndexedZSetReader, R, O: IndexedZSet>
where
    R: WeightTrait + ?Sized,
{
    // Factories for the inner aggregation over the weighted input set.
    aggregate_factories: StreamAggregateFactories<OrdWSet<I::Key, R>, O>,
    // Factory for output values.
    out_factory: &'static dyn Factory<O::Val>,
    // Factory for intermediate aggregate (weight) values.
    agg_factory: &'static dyn Factory<R>,
    // Factory for optional aggregate values.
    option_agg_factory: &'static dyn Factory<DynOpt<R>>,
}
impl<I, R, O> StreamLinearAggregateFactories<I, R, O>
where
    I: IndexedZSetReader,
    O: IndexedZSet<Key = I::Key>,
    R: WeightTrait + ?Sized,
{
    /// Builds the factories from concrete types; the inner aggregation uses
    /// `()` values because the weighted input set has unit values.
    pub fn new<KType, VType, RType, OType>() -> Self
    where
        KType: DBData + Erase<I::Key>,
        VType: DBData + Erase<I::Val>,
        RType: DBWeight + Erase<R>,
        OType: DBData + Erase<O::Val>,
    {
        Self {
            aggregate_factories: StreamAggregateFactories::new::<KType, (), RType, OType>(),
            out_factory: WithFactory::<OType>::FACTORY,
            agg_factory: WithFactory::<RType>::FACTORY,
            option_agg_factory: WithFactory::<Option<RType>>::FACTORY,
        }
    }
}
/// Clonable callback that converts a final accumulator `R` into an output
/// value `O` (both passed by mutable reference, erased-type style).
pub trait WeightedCountOutFunc<R: ?Sized, O: ?Sized>: Fn(&mut R, &mut O) + DynClone {}

// Blanket impl: any clonable closure with the matching signature qualifies.
impl<R: ?Sized, O: ?Sized, F> WeightedCountOutFunc<R, O> for F where F: Fn(&mut R, &mut O) + DynClone
{}
/// Aggregator that sums the weights of a group (used to implement linear
/// aggregation and weighted counts), then converts the resulting weight into
/// an output value via `out_func`.
struct WeightedCount<R: WeightTrait + ?Sized, O: DataTrait + ?Sized> {
    // Factory for output values.
    out_factory: &'static dyn Factory<O>,
    // Factory for the weight accumulator.
    weight_factory: &'static dyn Factory<R>,
    // Factory for optional weight accumulators.
    option_weight_factory: &'static dyn Factory<DynOpt<R>>,
    // Converts the accumulated weight into the output value.
    out_func: Box<dyn WeightedCountOutFunc<R, O>>,
}
/// Manual `Clone`: the factory references are plain `'static` references and
/// copy freely; the boxed output callback must be duplicated with
/// `dyn_clone::clone_box`.
impl<R: WeightTrait + ?Sized, O: DataTrait + ?Sized> Clone for WeightedCount<R, O> {
    fn clone(&self) -> Self {
        let out_func = clone_box(&*self.out_func);
        Self {
            out_func,
            out_factory: self.out_factory,
            weight_factory: self.weight_factory,
            option_weight_factory: self.option_weight_factory,
        }
    }
}
impl<R: WeightTrait + ?Sized, O: DataTrait + ?Sized> WeightedCount<R, O> {
    /// Assembles a weighted-count aggregator from its three factories and the
    /// accumulator-to-output conversion callback.
    fn new(
        out_factory: &'static dyn Factory<O>,
        weight_factory: &'static dyn Factory<R>,
        option_weight_factory: &'static dyn Factory<DynOpt<R>>,
        out_func: Box<dyn WeightedCountOutFunc<R, O>>,
    ) -> Self {
        Self { out_factory, weight_factory, option_weight_factory, out_func }
    }
}
impl<T, R, O> DynAggregator<DynUnit, T, R> for WeightedCount<R, O>
where
    T: Timestamp,
    R: WeightTrait + ?Sized,
    O: DataTrait + ?Sized,
{
    type Accumulator = R;
    type Output = O;

    fn opt_accumulator_factory(&self) -> &'static dyn Factory<DynOpt<Self::Accumulator>> {
        self.option_weight_factory
    }

    fn output_factory(&self) -> &'static dyn Factory<Self::Output> {
        self.out_factory
    }

    /// Accumulators are combined by weight addition.
    fn combine(&self) -> &dyn AggCombineFunc<R> {
        // Closure literal without captures is promoted to a `'static` value,
        // so returning a reference to it is sound.
        &|left: &mut R, right: &R| left.add_assign(right)
    }

    /// Sums all weights of the group under the cursor; a zero total yields
    /// `None` (the key has effectively no content).
    fn aggregate(
        &self,
        cursor: &mut dyn Cursor<DynUnit, DynUnit, T, R>,
        aggregate: &mut DynOpt<R>,
    ) {
        let mut weight = self.weight_factory.default_box();
        cursor.map_times(&mut |_t, w| weight.add_assign(w));
        if weight.is_zero() {
            aggregate.set_none();
        } else {
            aggregate.from_val(&mut weight)
        }
    }

    /// Converts the accumulated weight into the output value.
    fn finalize(&self, accumulator: &mut R, output: &mut O) {
        (self.out_func)(accumulator, output);
    }

    /// Convenience combination of [`Self::aggregate`] and [`Self::finalize`],
    /// writing `None` when the group sums to zero.
    fn aggregate_and_finalize(
        &self,
        cursor: &mut dyn Cursor<DynUnit, DynUnit, T, R>,
        output: &mut DynOpt<Self::Output>,
    ) {
        self.option_weight_factory.with(&mut |w| {
            self.aggregate(cursor, w);
            match w.get_mut() {
                None => output.set_none(),
                Some(w) => {
                    output.set_some_with(&mut |o| DynAggregator::<_, T, _>::finalize(self, w, o))
                }
            }
        })
    }
}
impl Stream<RootCircuit, MonoIndexedZSet> {
    /// Monomorphic version of [`Stream::dyn_aggregate`] for `MonoIndexedZSet`
    /// streams in the root circuit (where the timestamp type is `()`).
    #[allow(clippy::type_complexity)]
    pub fn dyn_aggregate_mono(
        &self,
        persistent_id: Option<&str>,
        factories: &IncAggregateFactories<MonoIndexedZSet, MonoIndexedZSet, ()>,
        aggregator: &dyn DynAggregator<DynData, (), DynZWeight, Accumulator = DynData, Output = DynData>,
    ) -> Stream<RootCircuit, MonoIndexedZSet> {
        self.dyn_aggregate(persistent_id, factories, aggregator)
    }

    /// Monomorphic version of [`Stream::dyn_aggregate_linear_generic`].
    pub fn dyn_aggregate_linear_mono(
        &self,
        // Fixed misspelling (`persisent_id`); the parameter is positional, so
        // callers are unaffected.
        persistent_id: Option<&str>,
        factories: &IncAggregateLinearFactories<MonoIndexedZSet, DynWeight, MonoIndexedZSet, ()>,
        agg_func: Box<dyn Fn(&DynData, &DynData, &DynZWeight, &mut DynWeight)>,
        out_func: Box<dyn WeightedCountOutFunc<DynWeight, DynData>>,
    ) -> Stream<RootCircuit, MonoIndexedZSet> {
        self.dyn_aggregate_linear_generic(persistent_id, factories, agg_func, out_func)
    }
}
impl Stream<NestedCircuit, MonoIndexedZSet> {
    /// Monomorphic version of [`Stream::dyn_aggregate`] for `MonoIndexedZSet`
    /// streams inside a nested circuit (nested timestamp type).
    #[allow(clippy::type_complexity)]
    pub fn dyn_aggregate_mono(
        &self,
        persistent_id: Option<&str>,
        factories: &IncAggregateFactories<
            MonoIndexedZSet,
            MonoIndexedZSet,
            <NestedCircuit as WithClock>::Time,
        >,
        aggregator: &dyn DynAggregator<
            DynData,
            <NestedCircuit as WithClock>::Time,
            DynZWeight,
            Accumulator = DynData,
            Output = DynData,
        >,
    ) -> Stream<NestedCircuit, MonoIndexedZSet> {
        self.dyn_aggregate(persistent_id, factories, aggregator)
    }

    /// Monomorphic version of [`Stream::dyn_aggregate_linear_generic`] for
    /// nested circuits.
    pub fn dyn_aggregate_linear_mono(
        &self,
        persistent_id: Option<&str>,
        factories: &IncAggregateLinearFactories<
            MonoIndexedZSet,
            DynWeight,
            MonoIndexedZSet,
            <NestedCircuit as WithClock>::Time,
        >,
        agg_func: Box<dyn Fn(&DynData, &DynData, &DynZWeight, &mut DynWeight)>,
        out_func: Box<dyn WeightedCountOutFunc<DynWeight, DynData>>,
    ) -> Stream<NestedCircuit, MonoIndexedZSet> {
        self.dyn_aggregate_linear_generic(persistent_id, factories, agg_func, out_func)
    }
}
impl<C, Z> Stream<C, Z>
where
    C: Circuit,
    Z: Clone + 'static,
{
    /// Non-incremental aggregation: applies `aggregator` to each key's group
    /// of the input Z-set independently at every clock tick.
    #[allow(clippy::type_complexity)]
    pub fn dyn_stream_aggregate<Acc, Out>(
        &self,
        factories: &StreamAggregateFactories<Z, OrdIndexedZSet<Z::Key, Out>>,
        aggregator: &dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = Out>,
    ) -> Stream<C, OrdIndexedZSet<Z::Key, Out>>
    where
        Z: IndexedZSet,
        Acc: DataTrait + ?Sized,
        Out: DataTrait + ?Sized,
    {
        self.dyn_stream_aggregate_generic(factories, aggregator)
    }

    /// Like [`Self::dyn_stream_aggregate`], but generic over the output batch
    /// type `O`.
    pub fn dyn_stream_aggregate_generic<Acc, Out, O>(
        &self,
        factories: &StreamAggregateFactories<Z, O>,
        aggregator: &dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = Out>,
    ) -> Stream<C, O>
    where
        Acc: DataTrait + ?Sized,
        Out: DataTrait + ?Sized,
        Z: Batch<Time = ()>,
        O: IndexedZSet<Key = Z::Key, Val = Out>,
    {
        // Shard the input by key across workers, aggregate each shard locally;
        // the output is sharded by construction.
        self.circuit()
            .add_unary_operator(
                Aggregate::new(
                    &factories.output_factories,
                    factories.option_output_factory,
                    aggregator,
                ),
                &self.dyn_shard(&factories.input_factories),
            )
            .mark_sharded()
    }

    /// Non-incremental linear aggregation: applies the linear function `f` to
    /// each tuple and sums the results per key.
    pub fn dyn_stream_aggregate_linear<A>(
        &self,
        factories: &StreamLinearAggregateFactories<Z, A, OrdIndexedZSet<Z::Key, A>>,
        f: Box<dyn Fn(&Z::Key, &Z::Val, &Z::R, &mut A)>,
    ) -> Stream<C, OrdIndexedZSet<Z::Key, A>>
    where
        Z: IndexedZSet,
        A: WeightTrait + ?Sized,
    {
        // The output value *is* the aggregate, so the out-func just moves it.
        self.dyn_stream_aggregate_linear_generic(factories, f, Box::new(|w, out| w.move_to(out)))
    }

    /// Like [`Self::dyn_stream_aggregate_linear`], but generic over the
    /// output batch type `O` and the accumulator-to-output conversion.
    pub fn dyn_stream_aggregate_linear_generic<A, O>(
        &self,
        factories: &StreamLinearAggregateFactories<Z, A, O>,
        agg_func: Box<dyn Fn(&Z::Key, &Z::Val, &Z::R, &mut A)>,
        out_func: Box<dyn WeightedCountOutFunc<A, O::Val>>,
    ) -> Stream<C, O>
    where
        Z: IndexedZSet,
        O: IndexedZSet<Key = Z::Key>,
        A: WeightTrait + ?Sized,
    {
        // Linear aggregation = weigh each tuple into a weighted set, then
        // sum the weights per key with `WeightedCount`.
        self.dyn_weigh(&factories.aggregate_factories.input_factories, agg_func)
            .dyn_stream_aggregate_generic(
                &factories.aggregate_factories,
                &WeightedCount::new(
                    factories.out_factory,
                    factories.agg_factory,
                    factories.option_agg_factory,
                    out_func,
                ),
            )
    }

    /// Incremental aggregation: maintains per-key aggregates over the
    /// integral of the input stream, emitting only changed keys.
    #[allow(clippy::type_complexity)]
    pub fn dyn_aggregate<Acc, Out>(
        &self,
        persistent_id: Option<&str>,
        factories: &IncAggregateFactories<Z, OrdIndexedZSet<Z::Key, Out>, C::Time>,
        aggregator: &dyn DynAggregator<Z::Val, <C as WithClock>::Time, Z::R, Accumulator = Acc, Output = Out>,
    ) -> Stream<C, OrdIndexedZSet<Z::Key, Out>>
    where
        Acc: DataTrait + ?Sized,
        Out: DataTrait + ?Sized,
        Z: IndexedZSet,
    {
        self.dyn_aggregate_generic::<Acc, Out, OrdIndexedZSet<Z::Key, Out>>(
            persistent_id,
            factories,
            aggregator,
        )
    }

    /// Like [`Self::dyn_aggregate`], but generic over the output batch type.
    pub fn dyn_aggregate_generic<Acc, Out, O>(
        &self,
        persistent_id: Option<&str>,
        factories: &IncAggregateFactories<Z, O, C::Time>,
        aggregator: &dyn DynAggregator<Z::Val, <C as WithClock>::Time, Z::R, Accumulator = Acc, Output = Out>,
    ) -> Stream<C, O>
    where
        Acc: DataTrait + ?Sized,
        Out: DataTrait + ?Sized,
        Z: Batch<Time = ()>,
        O: IndexedZSet<Key = Z::Key, Val = Out>,
    {
        let circuit = self.circuit();
        circuit
            .region("aggregate", || {
                let stream = self.dyn_shard(&factories.input_factories);
                // `AggregateIncremental` consumes the accumulated delta and a
                // trace of all past input; its (key, Option<aggregate>) output
                // is turned into a Z-set of changes by the `upsert` operator.
                circuit
                    .add_binary_operator(
                        StreamingBinaryWrapper::new(AggregateIncremental::new(
                            factories.keys_factory,
                            factories.output_pair_factory,
                            factories.output_pairs_factory,
                            aggregator,
                            circuit.clone(),
                        )),
                        &stream.dyn_accumulate(&factories.input_factories),
                        &stream.dyn_accumulate_trace(
                            &factories.trace_factories,
                            &factories.input_factories,
                        ),
                    )
                    .mark_sharded()
                    .upsert::<O>(persistent_id, &factories.upsert_factories)
                    .mark_sharded()
            })
            .clone()
    }

    /// Incremental linear aggregation with the identity output conversion.
    pub fn dyn_aggregate_linear<A>(
        &self,
        persistent_id: Option<&str>,
        factories: &IncAggregateLinearFactories<Z, A, OrdIndexedZSet<Z::Key, A>, C::Time>,
        f: Box<dyn Fn(&Z::Key, &Z::Val, &Z::R, &mut A)>,
    ) -> Stream<C, OrdIndexedZSet<Z::Key, A>>
    where
        Z: IndexedZSet,
        A: WeightTrait + ?Sized,
    {
        self.dyn_aggregate_linear_generic(
            persistent_id,
            factories,
            f,
            Box::new(|w, out| w.move_to(out)),
        )
    }

    /// Like [`Self::dyn_aggregate_linear`], but generic over the output batch
    /// type and the accumulator-to-output conversion.
    pub fn dyn_aggregate_linear_generic<A, O>(
        &self,
        persistent_id: Option<&str>,
        factories: &IncAggregateLinearFactories<Z, A, O, C::Time>,
        agg_func: Box<dyn Fn(&Z::Key, &Z::Val, &Z::R, &mut A)>,
        out_func: Box<dyn WeightedCountOutFunc<A, O::Val>>,
    ) -> Stream<C, O>
    where
        Z: IndexedZSet,
        O: IndexedZSet<Key = Z::Key>,
        A: WeightTrait + ?Sized,
    {
        self.circuit()
            .region("aggregate_linear", || {
                // Weigh the input, give the intermediate stream a derived
                // persistent id, then aggregate incrementally.
                self.dyn_weigh(&factories.aggregate_factories.input_factories, agg_func)
                    .set_persistent_id(
                        persistent_id
                            .map(|name| format!("{name}-weighted"))
                            .as_deref(),
                    )
                    .dyn_aggregate_generic(
                        persistent_id,
                        &factories.aggregate_factories,
                        &WeightedCount::new(
                            factories.out_factory,
                            factories.agg_factory,
                            factories.option_agg_factory,
                            out_func,
                        ),
                    )
            })
            .clone()
    }

    /// Maps each `(key, value, weight)` tuple to a weight of type `T` and
    /// sums the results per key, producing a weighted set keyed by `Z::Key`.
    pub fn dyn_weigh<T>(
        &self,
        output_factories: &OrdWSetFactories<Z::Key, T>,
        f: Box<dyn Fn(&Z::Key, &Z::Val, &Z::R, &mut T)>,
    ) -> Stream<C, OrdWSet<Z::Key, T>>
    where
        T: WeightTrait + ?Sized,
        Z: IndexedZSet,
    {
        self.dyn_weigh_generic::<OrdWSet<_, _>>(output_factories, f)
    }

    /// Like [`Self::dyn_weigh`], but generic over the output batch type.
    pub fn dyn_weigh_generic<O>(
        &self,
        output_factories: &O::Factories,
        f: Box<dyn Fn(&Z::Key, &Z::Val, &Z::R, &mut O::R)>,
    ) -> Stream<C, O>
    where
        Z: IndexedZSet,
        O: Batch<Key = Z::Key, Val = DynUnit, Time = ()>,
    {
        let output_factories = output_factories.clone();
        // Prefer the sharded version of the input if one already exists.
        let output = self
            .try_sharded_version()
            .apply_named("Weigh", move |batch| {
                // Reusable scratch values for the per-key total, the per-tuple
                // contribution, and the copied input weight.
                let mut agg = output_factories.weight_factory().default_box();
                let mut agg_delta = output_factories.weight_factory().default_box();
                let mut input_weight = batch.factories().weight_factory().default_box();
                let mut delta = <O::Builder>::with_capacity(
                    &output_factories,
                    batch.key_count(),
                    batch.key_count(),
                );
                let mut cursor = batch.cursor();
                while cursor.key_valid() {
                    agg.set_zero();
                    // Sum `f(key, val, weight)` over all values of the key.
                    while cursor.val_valid() {
                        **input_weight = **cursor.weight();
                        f(
                            cursor.key(),
                            cursor.val(),
                            &*input_weight,
                            agg_delta.as_mut(),
                        );
                        agg.add_assign(&agg_delta);
                        cursor.step_val();
                    }
                    // Keys whose total is zero are omitted from the output.
                    if !agg.is_zero() {
                        delta.push_val_diff_mut(().erase_mut(), &mut agg);
                        delta.push_key(cursor.key());
                    }
                    cursor.step_key();
                }
                delta.done()
            });
        // Weigh is a per-key operation, so sharding is preserved.
        output.mark_sharded_if(self);
        output
    }
}
impl Stream<RootCircuit, MonoIndexedZSet> {
    /// Monomorphic version of
    /// [`Stream::dyn_aggregate_linear_retain_keys_generic`].
    ///
    /// (A vacuous empty `where` clause on this method was removed.)
    pub fn dyn_aggregate_linear_retain_keys_mono(
        &self,
        persistent_id: Option<&str>,
        factories: &IncAggregateLinearFactories<MonoIndexedZSet, DynWeight, MonoIndexedZSet, ()>,
        waterline: &Stream<RootCircuit, Box<DynData>>,
        retain_key_func: Box<dyn Fn(&DynData) -> Filter<DynData>>,
        agg_func: Box<dyn Fn(&DynData, &DynData, &DynZWeight, &mut DynWeight)>,
        out_func: Box<dyn WeightedCountOutFunc<DynWeight, DynData>>,
    ) -> Stream<RootCircuit, MonoIndexedZSet> {
        self.dyn_aggregate_linear_retain_keys_generic(
            persistent_id,
            factories,
            waterline,
            retain_key_func,
            agg_func,
            out_func,
        )
    }
}
impl<Z> Stream<RootCircuit, Z>
where
    Z: Clone + 'static,
{
    /// Incremental linear aggregation that additionally garbage-collects the
    /// intermediate weighted trace: keys rejected by `retain_key_func`
    /// (derived from the current `waterline` value) may be dropped from the
    /// integral of the weighted stream.
    pub fn dyn_aggregate_linear_retain_keys_generic<A, O, TS>(
        &self,
        persistent_id: Option<&str>,
        factories: &IncAggregateLinearFactories<Z, A, O, ()>,
        waterline: &Stream<RootCircuit, Box<TS>>,
        retain_key_func: Box<dyn Fn(&TS) -> Filter<Z::Key>>,
        agg_func: Box<dyn Fn(&Z::Key, &Z::Val, &Z::R, &mut A)>,
        out_func: Box<dyn WeightedCountOutFunc<A, O::Val>>,
    ) -> Stream<RootCircuit, O>
    where
        Z: IndexedZSet<Time = ()>,
        O: IndexedZSet<Key = Z::Key>,
        A: WeightTrait + ?Sized,
        TS: DataTrait + ?Sized,
        Box<TS>: Clone,
    {
        self.circuit()
            .region("aggregate_linear_retain_keys", || {
                let weighted =
                    self.dyn_weigh(&factories.aggregate_factories.input_factories, agg_func);
                // Attach the retention policy to the weighted stream's trace.
                weighted.dyn_integrate_trace_retain_keys(waterline, retain_key_func);
                weighted
                    .set_persistent_id(
                        persistent_id
                            .map(|name| format!("{name}-weighted"))
                            .as_deref(),
                    )
                    .dyn_aggregate_generic(
                        persistent_id,
                        &factories.aggregate_factories,
                        &WeightedCount::new(
                            factories.out_factory,
                            factories.agg_factory,
                            factories.option_agg_factory,
                            out_func,
                        ),
                    )
            })
            .clone()
    }
}
/// Unary operator implementing non-incremental aggregation: every input batch
/// is aggregated from scratch, key by key.
struct Aggregate<Z, Acc, O>
where
    Z: BatchReader,
    O: IndexedZSet,
    Acc: DataTrait + ?Sized,
{
    // Factories for building output batches.
    factories: O::Factories,
    // Factory for the per-key optional aggregation result.
    option_output_factory: &'static dyn Factory<DynOpt<O::Val>>,
    // The aggregation logic (type-erased, clonable).
    aggregator: Box<dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = O::Val>>,
    _type: PhantomData<(Z, O)>,
}
impl<Z, Acc, O: Batch> Aggregate<Z, Acc, O>
where
    Z: BatchReader,
    O: IndexedZSet,
    Acc: DataTrait + ?Sized,
{
    /// Creates the operator, cloning the aggregator so the operator owns it.
    pub fn new(
        factories: &O::Factories,
        option_output_factory: &'static dyn Factory<DynOpt<O::Val>>,
        aggregator: &dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = O::Val>,
    ) -> Self {
        Self {
            factories: factories.clone(),
            option_output_factory,
            aggregator: clone_box(aggregator),
            _type: PhantomData,
        }
    }
}
impl<Z, Acc, O: Batch> Operator for Aggregate<Z, Acc, O>
where
    Z: BatchReader,
    O: IndexedZSet,
    Acc: DataTrait + ?Sized,
{
    /// Operator name shown in circuit diagnostics.
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("Aggregate")
    }

    /// The operator keeps no state across steps, so it is always at a
    /// fixed point.
    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }
}
impl<Z, Acc, O> UnaryOperator<Z, O> for Aggregate<Z, Acc, O>
where
    Z: BatchReader<Time = ()>,
    O: IndexedZSet<Key = Z::Key>,
    Acc: DataTrait + ?Sized,
{
    /// Aggregates every key group of the input batch into the output Z-set
    /// with unit (+1) weights; keys whose aggregate is `None` are skipped.
    async fn eval(&mut self, i: &Z) -> O {
        let n = i.key_count();
        let mut builder = O::Builder::with_capacity(&self.factories, n, n);
        // Reusable buffer for the per-key optional aggregate.
        let mut agg = self.option_output_factory.default_box();
        let mut cursor = i.cursor();
        while cursor.key_valid() {
            self.aggregator
                .aggregate_and_finalize(&mut CursorGroup::new(&mut cursor, ()), agg.as_mut());
            if let Some(agg) = agg.get_mut() {
                builder.push_val_diff_mut(agg, ZWeight::one().erase_mut());
                builder.push_key(cursor.key());
            }
            cursor.step_key();
        }
        builder.done()
    }
}
/// Binary operator implementing incremental aggregation: it reads the delta
/// of the input stream plus a trace of all past input and re-aggregates only
/// affected keys.
struct AggregateIncremental<Z, IT, Acc, Out, Clk>
where
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    Z: BatchReader,
    IT: WithSnapshot,
{
    // Factory for key sets used in `keys_of_interest`.
    keys_factory: &'static dyn Factory<DynSet<Z::Key>>,
    // Factory for a single (key, optional aggregate) output pair.
    output_pair_factory: &'static dyn Factory<DynPair<Z::Key, DynOpt<Out>>>,
    // Factory for vectors of output pairs (the operator's output chunks).
    output_pairs_factory: &'static dyn Factory<DynPairs<Z::Key, DynOpt<Out>>>,
    // Clock of the owning circuit; supplies the current logical time.
    clock: Clk,
    aggregator: Box<
        dyn DynAggregator<
            Z::Val,
            <IT::Batch as BatchReader>::Time,
            Z::R,
            Accumulator = Acc,
            Output = Out,
        >,
    >,
    // Set when the last delta was empty; part of the fixedpoint condition.
    empty_input: Cell<bool>,
    // Set when the last output was empty; part of the fixedpoint condition.
    empty_output: Cell<bool>,
    // Keys that must be re-aggregated at a future (nested) timestamp.
    keys_of_interest: RefCell<BTreeMap<<IT::Batch as BatchReader>::Time, Box<DynSet<Z::Key>>>>,
    // Size statistics reported via `metadata()`.
    input_batch_stats: RefCell<BatchSizeStats>,
    output_batch_stats: RefCell<BatchSizeStats>,
    _type: PhantomData<fn(&Z, &IT)>,
}
impl<Z, IT, Acc, Out, Clk> AggregateIncremental<Z, IT, Acc, Out, Clk>
where
    Clk: WithClock<Time = <IT::Batch as BatchReader>::Time>,
    Z: BatchReader<Time = ()>,
    IT: WithSnapshot,
    IT::Batch: BatchReader<Key = Z::Key, Val = Z::Val, R = Z::R>,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
{
    /// Creates the operator, cloning the aggregator so the operator owns it.
    pub fn new(
        keys_factory: &'static dyn Factory<DynSet<Z::Key>>,
        output_pair_factory: &'static dyn Factory<DynPair<Z::Key, DynOpt<Out>>>,
        output_pairs_factory: &'static dyn Factory<DynPairs<Z::Key, DynOpt<Out>>>,
        aggregator: &dyn DynAggregator<
            Z::Val,
            <IT::Batch as BatchReader>::Time,
            Z::R,
            Accumulator = Acc,
            Output = Out,
        >,
        clock: Clk,
    ) -> Self {
        Self {
            keys_factory,
            output_pair_factory,
            output_pairs_factory,
            clock,
            aggregator: clone_box(aggregator),
            empty_input: Cell::new(false),
            empty_output: Cell::new(false),
            keys_of_interest: RefCell::new(BTreeMap::new()),
            input_batch_stats: RefCell::new(BatchSizeStats::new()),
            output_batch_stats: RefCell::new(BatchSizeStats::new()),
            _type: PhantomData,
        }
    }

    /// Re-aggregates a single `key` at logical time `time` from the input
    /// trace and appends the `(key, Option<aggregate>)` result to `output`.
    ///
    /// `key_aggregate` is a caller-provided scratch pair reused across calls.
    fn eval_key(
        self: &Rc<Self>,
        key: &Z::Key,
        input_cursor: &mut SpineCursor<IT::Batch>,
        output: &mut DynPairs<Z::Key, DynOpt<Out>>,
        time: &Clk::Time,
        key_aggregate: &mut DynPair<Z::Key, DynOpt<Out>>,
    ) {
        let (output_key, aggregate) = key_aggregate.split_mut();
        key.clone_to(output_key);
        if input_cursor.seek_key_exact(key, None) {
            self.aggregator.aggregate_and_finalize(
                &mut CursorGroup::new(input_cursor, time.clone()),
                aggregate,
            );
            output.push_val(key_aggregate);
            // For nested (non-unit) timestamps, the aggregate may change at a
            // later time even without new input: find the earliest future
            // timestamp affecting this key and remember to revisit it then.
            if TypeId::of::<<IT::Batch as BatchReader>::Time>() != TypeId::of::<()>() {
                input_cursor.rewind_vals();
                let mut time_of_interest = None;
                while input_cursor.val_valid() {
                    input_cursor.map_times(&mut |ts, _| {
                        // The earliest `time.join(ts)` over all updates whose
                        // timestamp is not `<= time`.
                        time_of_interest = if !ts.less_equal(time) {
                            match &time_of_interest {
                                None => Some(time.join(ts)),
                                Some(time_of_interest) => {
                                    Some(min(time_of_interest, &time.join(ts)).clone())
                                }
                            }
                        } else {
                            time_of_interest.clone()
                        }
                    });
                    input_cursor.step_val();
                }
                if let Some(t) = time_of_interest {
                    self.keys_of_interest
                        .borrow_mut()
                        .entry(t)
                        .or_insert_with(|| self.keys_factory.default_box())
                        .insert_ref(key);
                }
            }
        } else {
            // Key not present in the trace: emit the pair as-is.
            // NOTE(review): `key_aggregate` is reused across calls, so the
            // aggregate slot here holds whatever the previous call left in it
            // (initially `None` from `default_box`). This appears to assume
            // the not-found case only produces `None` — confirm upstream
            // invariants (e.g. that every evaluated key exists in the trace).
            output.push_val(key_aggregate);
        }
    }
}
impl<Z, IT, Acc, Out, Clk> Operator for AggregateIncremental<Z, IT, Acc, Out, Clk>
where
    Clk: WithClock<Time = <IT::Batch as BatchReader>::Time> + 'static,
    Z: BatchReader,
    IT: WithSnapshot + 'static,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
{
    fn name(&self) -> Cow<'static, str> {
        Cow::from("AggregateIncremental")
    }

    /// Resets the emptiness flags at the start of each top-level clock epoch.
    fn clock_start(&mut self, scope: Scope) {
        if scope == 0 {
            self.empty_input.set(false);
            self.empty_output.set(false);
        }
    }

    /// Debug-build sanity check at the end of an epoch: no pending key of
    /// interest may belong to a timestamp inside the epoch that just ended,
    /// or its update would never be produced.
    /// (A leftover empty `if` statement inside the assertion was removed, and
    /// `epoch_end` is now computed once instead of per key.)
    fn clock_end(&mut self, scope: Scope) {
        let epoch_end = self.clock.time().epoch_end(scope);
        debug_assert!(
            self.keys_of_interest
                .borrow()
                .keys()
                .all(|ts| !ts.less_equal(&epoch_end))
        );
    }

    /// Reports input/output batch-size statistics.
    fn metadata(&self, meta: &mut OperatorMeta) {
        meta.extend(metadata! {
            INPUT_BATCHES_STATS => self.input_batch_stats.borrow().metadata(),
            OUTPUT_BATCHES_STATS => self.output_batch_stats.borrow().metadata(),
        });
    }

    /// Fixed point: the last input and output were empty and no key of
    /// interest remains within the current epoch.
    fn fixedpoint(&self, scope: Scope) -> bool {
        let epoch_end = self.clock.time().epoch_end(scope);
        self.empty_input.get()
            && self.empty_output.get()
            && self
                .keys_of_interest
                .borrow()
                .keys()
                .all(|ts| !ts.less_equal(&epoch_end))
    }
}
impl<Z, IT, Acc, Out, Clk>
    StreamingBinaryOperator<Option<Spine<Z>>, IT, Box<DynPairs<Z::Key, DynOpt<Out>>>>
    for AggregateIncremental<Z, IT, Acc, Out, Clk>
where
    Clk: WithClock<Time = <IT::Batch as BatchReader>::Time> + 'static,
    Z: Batch<Time = ()>,
    IT: WithSnapshot + 'static,
    IT::Batch: BatchReader<Key = Z::Key, Val = Z::Val, R = Z::R>,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
{
    /// Produces an async stream of `(chunk, is_final, position)` triples:
    /// each chunk holds up to `splitter_output_chunk_size()` output pairs,
    /// `is_final` marks the last chunk of this step.
    fn eval(
        self: Rc<Self>,
        delta: &Option<Spine<Z>>,
        input_trace: &IT,
    ) -> impl AsyncStream<Item = (Box<DynPairs<Z::Key, DynOpt<Out>>>, bool, Option<Position>)> + 'static
    {
        let chunk_size = splitter_output_chunk_size();
        // Snapshot both inputs up front so the generated stream can own them.
        let delta = delta.as_ref().map(|b| b.ro_snapshot());
        let input_trace = if delta.is_some() {
            Some(input_trace.ro_snapshot())
        } else {
            None
        };
        stream! {
            // No delta this step: emit a single empty, final chunk.
            let Some(delta) = delta else {
                yield (self.output_pairs_factory.default_box(), true, None);
                return;
            };
            self.input_batch_stats.borrow_mut().add_batch(delta.len());
            self.empty_input.set(delta.is_empty());
            self.empty_output.set(true);
            let mut result = self.output_pairs_factory.default_box();
            result.reserve(chunk_size);
            let mut delta_cursor = delta.cursor();
            let mut input_trace_cursor = input_trace.unwrap().cursor();
            let time = self.clock.time();
            // Keys scheduled earlier for re-evaluation at the current time.
            let keys_of_interest = self
                .keys_of_interest
                .borrow_mut()
                .remove(&time)
                .unwrap_or_else(|| self.keys_factory.default_box());
            let mut keys_of_interest = keys_of_interest.dyn_iter();
            let mut key_of_interest = keys_of_interest.next();
            // Scratch pair reused by every `eval_key` call.
            let mut key_aggregate = self.output_pair_factory.default_box();
            // Merge-walk the delta keys and the keys of interest (both are in
            // sorted order), evaluating each key exactly once.
            while delta_cursor.key_valid() && key_of_interest.is_some() {
                let key_of_interest_ref = key_of_interest.unwrap();
                match delta_cursor.key().cmp(key_of_interest_ref) {
                    Ordering::Less => {
                        self.eval_key(
                            delta_cursor.key(),
                            &mut input_trace_cursor,
                            result.as_mut(),
                            &time,
                            key_aggregate.as_mut(),
                        );
                        delta_cursor.step_key();
                    }
                    Ordering::Greater => {
                        self.eval_key(
                            key_of_interest_ref,
                            &mut input_trace_cursor,
                            result.as_mut(),
                            &time,
                            key_aggregate.as_mut(),
                        );
                        key_of_interest = keys_of_interest.next();
                    }
                    Ordering::Equal => {
                        self.eval_key(
                            delta_cursor.key(),
                            &mut input_trace_cursor,
                            result.as_mut(),
                            &time,
                            key_aggregate.as_mut(),
                        );
                        delta_cursor.step_key();
                        key_of_interest = keys_of_interest.next();
                    }
                }
                // Yield a non-final chunk once it fills up.
                if result.len() >= chunk_size {
                    self.empty_output.update(|empty_output| empty_output & result.is_empty());
                    self.output_batch_stats.borrow_mut().add_batch(result.len());
                    yield (result, false, delta_cursor.position());
                    result = self.output_pairs_factory.default_box();
                    result.reserve(chunk_size);
                }
            }
            // Drain remaining delta keys.
            while delta_cursor.key_valid() {
                self.eval_key(
                    delta_cursor.key(),
                    &mut input_trace_cursor,
                    result.as_mut(),
                    &time,
                    key_aggregate.as_mut(),
                );
                delta_cursor.step_key();
                if result.len() >= chunk_size {
                    self.empty_output.update(|empty_output| empty_output & result.is_empty());
                    self.output_batch_stats.borrow_mut().add_batch(result.len());
                    yield (result, false, delta_cursor.position());
                    result = self.output_pairs_factory.default_box();
                    result.reserve(chunk_size);
                }
            }
            // Drain remaining keys of interest.
            while key_of_interest.is_some() {
                self.eval_key(
                    key_of_interest.unwrap(),
                    &mut input_trace_cursor,
                    result.as_mut(),
                    &time,
                    key_aggregate.as_mut(),
                );
                key_of_interest = keys_of_interest.next();
                if result.len() >= chunk_size {
                    self.empty_output.update(|empty_output| empty_output & result.is_empty());
                    self.output_batch_stats.borrow_mut().add_batch(result.len());
                    // Final only if nothing at all is left to process.
                    yield (result, !(delta_cursor.key_valid() || key_of_interest.is_some()), delta_cursor.position());
                    result = self.output_pairs_factory.default_box();
                    result.reserve(chunk_size);
                }
            }
            // Last (possibly empty) chunk; always marked final.
            self.output_batch_stats.borrow_mut().add_batch(result.len());
            yield (result, true, delta_cursor.position());
        }
    }
}
#[cfg(test)]
pub mod test {
use anyhow::Result as AnyResult;
use std::{cell::RefCell, rc::Rc};
use crate::{
Circuit, RootCircuit, Runtime, Stream,
algebra::DefaultSemigroup,
circuit::CircuitConfig,
indexed_zset,
operator::{
Fold, GeneratorNested, Min,
dynamic::aggregate::{MinSome1, Postprocess},
},
trace::{BatchReader, Cursor},
typed_batch::{OrdIndexedZSet, OrdZSet, TypedBatch},
utils::{Tup1, Tup3},
zset,
};
/// Z-set of `(key, value)` tuples used as test input.
type TestZSet = OrdZSet<Tup2<u64, i64>>;
/// Builds a nested test circuit that aggregates the same input with a `Fold`
/// sum and with `aggregate_linear`, asserting at every step that both
/// produce the same result (ignoring keys whose aggregate is zero, which
/// `aggregate_linear` omits).
fn aggregate_test_circuit(
    circuit: &mut RootCircuit,
    inputs: Vec<Vec<TestZSet>>,
) -> AnyResult<()> {
    let mut inputs = inputs.into_iter();
    circuit
        .iterate(|child| {
            // Counts nested iterations; the loop exits at MAX_ITERATIONS.
            let counter = Rc::new(RefCell::new(0));
            let counter_clone = counter.clone();
            // Only worker 0 feeds real data; others produce empty Z-sets.
            let input: Stream<_, OrdIndexedZSet<u64, i64>> = child
                .add_source(GeneratorNested::new(
                    Box::new(move || {
                        *counter_clone.borrow_mut() = 0;
                        if Runtime::worker_index() == 0 {
                            let mut deltas = inputs.next().unwrap_or_default().into_iter();
                            Box::new(move || deltas.next().unwrap_or_else(|| zset! {}))
                        } else {
                            Box::new(|| zset! {})
                        }
                    }),
                    zset! {},
                ))
                .map_index(|Tup2(k, v)| (*k, *v));
            // Weighted sum of values per key, via the generic Fold aggregator.
            let sum = <Fold<i64, i64, DefaultSemigroup<_>, _, _>>::new(
                0,
                |acc: &mut i64, v: &i64, w: i64| *acc += *v * w,
            );
            // The same sum expressed as a linear aggregate.
            let sum_linear = |val: &i64| -> i64 { *val };
            let sum_inc: Stream<_, OrdIndexedZSet<u64, i64>> =
                input.aggregate(sum.clone()).gather(0);
            let sum_inc_linear: Stream<_, OrdIndexedZSet<u64, i64>> =
                input.aggregate_linear(sum_linear).gather(0);
            // Compare the two outputs: every non-zero entry of the Fold
            // output must appear identically in the linear output, and the
            // linear output must contain nothing else.
            sum_inc.accumulate_apply2(&sum_inc_linear, |d1, d2| {
                let mut cursor1 = d1.cursor();
                let mut cursor2 = d2.cursor();
                while cursor1.key_valid() {
                    while cursor1.val_valid() {
                        if *cursor1.val().downcast_checked::<i64>() != 0 {
                            assert!(cursor2.key_valid());
                            assert_eq!(cursor2.key(), cursor1.key());
                            assert!(cursor2.val_valid());
                            assert_eq!(cursor2.val(), cursor1.val());
                            assert_eq!(cursor2.weight(), cursor1.weight());
                            cursor2.step_val();
                        }
                        cursor1.step_val();
                    }
                    if cursor2.key_valid() && cursor2.key() == cursor1.key() {
                        cursor2.step_key();
                    }
                    cursor1.step_key();
                }
                assert!(!cursor2.key_valid());
            });
            Ok((
                async move || {
                    *counter.borrow_mut() += 1;
                    Ok(*counter.borrow() == MAX_ITERATIONS)
                },
                (),
            ))
        })
        .unwrap();
    Ok(())
}
use crate::{dynamic::DowncastTrait, utils::Tup2};
use proptest::{collection, prelude::*};
// Bounds for the proptest-generated workloads.
const MAX_ROUNDS: usize = 10; // max outer clock rounds per test case
const MAX_ITERATIONS: usize = 10; // nested iterations per round
const NUM_KEYS: u64 = 3; // keys drawn from 0..NUM_KEYS
const MAX_VAL: i64 = 3; // values drawn from -MAX_VAL..MAX_VAL
const MAX_TUPLES: usize = 8; // max tuples per generated Z-set
/// Strategy producing small random test Z-sets: up to `MAX_TUPLES` entries of
/// `Tup2(key, value)` with weights in `-1..=1`.
pub fn test_zset() -> impl Strategy<Value = TestZSet> {
    let entry = (
        (0..NUM_KEYS, -MAX_VAL..MAX_VAL).prop_map(|(x, y)| Tup2(x, y)),
        -1..=1i64,
    );
    collection::vec(entry, 0..MAX_TUPLES).prop_map(|entries| {
        let tuples = entries
            .into_iter()
            .map(|(k, w)| Tup2(Tup2(k, ()), w))
            .collect();
        OrdZSet::from_tuples((), tuples)
    })
}
/// Strategy producing a full test workload: up to `MAX_ROUNDS` rounds, each
/// with up to `MAX_ITERATIONS` Z-set deltas.
pub fn test_input() -> impl Strategy<Value = Vec<Vec<TestZSet>>> {
    let round = collection::vec(test_zset(), 0..MAX_ITERATIONS);
    collection::vec(round, 0..MAX_ROUNDS)
}
proptest! {
    // Single-worker run of the aggregate comparison circuit.
    #[test]
    fn proptest_aggregate_test_st(inputs in test_input()) {
        let iterations = inputs.len();
        let mut circuit = Runtime::init_circuit(1, |circuit| aggregate_test_circuit(circuit, inputs)).unwrap().0;
        for _ in 0..iterations {
            circuit.transaction().unwrap();
        }
    }

    // Multi-worker run with 2, 4, 8, or 16 workers.
    #[test]
    fn proptest_aggregate_test_mt(inputs in test_input(), log_workers in (1..=4)) {
        let workers = 1usize << log_workers;
        let iterations = inputs.len();
        let mut circuit = Runtime::init_circuit(workers, |circuit| aggregate_test_circuit(circuit, inputs)).unwrap().0;
        for _ in 0..iterations {
            circuit.transaction().unwrap();
        }
        circuit.kill().unwrap();
    }
}
/// End-to-end test comparing four aggregations (weighted count, linear sum,
/// fold-based distinct count and distinct sum) over a fixed input sequence.
///
/// With `transaction == true`, all input batches are applied inside a single
/// transaction and only the merged result is checked; otherwise each batch is
/// a separate transaction checked against per-step expected deltas.
fn count_test(workers: usize, transaction: bool) {
    let (
        mut dbsp,
        (
            input_handle,
            count_weighted_output,
            sum_weighted_output,
            count_distinct_output,
            sum_distinct_output,
        ),
    ) = Runtime::init_circuit(
        // Small splitter chunk size exercises the chunked output path.
        CircuitConfig::from(workers).with_splitter_chunk_size_records(2),
        move |circuit| {
            let (input_stream, input_handle) = circuit.add_input_indexed_zset::<u64, u64>();
            let count_weighted_output = input_stream.weighted_count().accumulate_output();
            let sum_weighted_output = input_stream
                .aggregate_linear(|value: &u64| *value as i64)
                .accumulate_output();
            // Distinct count: +1 per distinct value, ignoring weights.
            let count_distinct_output = input_stream
                .aggregate(<Fold<_, _, DefaultSemigroup<_>, _, _>>::new(
                    0,
                    |sum: &mut u64, _v: &u64, _w| *sum += 1,
                ))
                .accumulate_output();
            // Distinct sum: adds each distinct value once, ignoring weights.
            let sum_distinct_output = input_stream
                .aggregate(<Fold<_, _, DefaultSemigroup<_>, _, _>>::new(
                    0,
                    |sum: &mut u64, v: &u64, _w| *sum += v,
                ))
                .accumulate_output();
            Ok((
                input_handle,
                count_weighted_output,
                sum_weighted_output,
                count_distinct_output,
                sum_distinct_output,
            ))
        },
    )
    .unwrap();
    // Input: (key, (value, weight)) tuples, three batches.
    let input = [
        vec![Tup2(1u64, Tup2(1u64, 1)), Tup2(1, Tup2(2, 2))],
        vec![
            Tup2(2, Tup2(2, 1)),
            Tup2(2, Tup2(4, 1)),
            Tup2(1, Tup2(2, -1)),
        ],
        vec![Tup2(1, Tup2(3, 1)), Tup2(1, Tup2(2, -1))],
    ];
    // Expected per-step output deltas for each aggregation.
    let expected_count_output = vec![
        indexed_zset! {1 => {2 => 1}},
        indexed_zset! {2 => {2 => 1}},
        indexed_zset! {},
    ];
    let expected_sum_output = vec![
        indexed_zset! {1 => {3 => 1}},
        indexed_zset! {2 => {6 => 1}},
        indexed_zset! {1 => {3 => -1, 4 => 1}},
    ];
    let expected_count_weighted_output = vec![
        indexed_zset! {1 => {3 => 1}},
        indexed_zset! {1 => {3 => -1, 2 => 1}, 2 => {2 => 1}},
        indexed_zset! {},
    ];
    let expected_sum_weighted_output = vec![
        indexed_zset! {1 => {5 => 1}},
        indexed_zset! {2 => {6 => 1}, 1 => {5 => -1, 3 => 1}},
        indexed_zset! {1 => {3 => -1, 4 => 1}},
    ];
    if transaction {
        // One big transaction; compare against the merged expected deltas.
        dbsp.start_transaction().unwrap();
        for i in input {
            input_handle.append(&mut i.clone());
            dbsp.step().unwrap();
        }
        dbsp.commit_transaction().unwrap();
        assert_eq!(
            count_distinct_output.concat().consolidate(),
            TypedBatch::merge_batches(expected_count_output)
        );
        assert_eq!(
            sum_distinct_output.concat().consolidate(),
            TypedBatch::merge_batches(expected_sum_output)
        );
        assert_eq!(
            count_weighted_output.concat().consolidate(),
            TypedBatch::merge_batches(expected_count_weighted_output)
        );
        assert_eq!(
            sum_weighted_output.concat().consolidate(),
            TypedBatch::merge_batches(expected_sum_weighted_output)
        );
    } else {
        // One transaction per batch; check each step's delta.
        for i in 0..input.len() {
            input_handle.append(&mut input[i].clone());
            dbsp.transaction().unwrap();
            assert_eq!(
                count_distinct_output.concat().consolidate(),
                expected_count_output[i]
            );
            assert_eq!(
                sum_distinct_output.concat().consolidate(),
                expected_sum_output[i]
            );
            assert_eq!(
                count_weighted_output.concat().consolidate(),
                expected_count_weighted_output[i]
            );
            assert_eq!(
                sum_weighted_output.concat().consolidate(),
                expected_sum_weighted_output[i]
            );
        }
    }
    dbsp.kill().unwrap();
}
// Wrappers around `count_test(workers, transaction)` covering both worker
// counts (1 vs. 4 threads) and both commit modes: "small step" commits one
// transaction per input batch, "big step" applies all batches inside a
// single transaction.
#[test]
fn count_test1_small_step() {
    count_test(1, false);
}

#[test]
fn count_test4_small_step() {
    count_test(4, false);
}

#[test]
fn count_test1_big_step() {
    count_test(1, true);
}

#[test]
fn count_test4_big_step() {
    count_test(4, true);
}
// Wrappers around `min_some_test(transaction)`: per-batch commits ("small
// steps") vs. a single transaction over all batches ("big step").
#[test]
fn min_some_test_small_steps() {
    min_some_test(false);
}

#[test]
fn min_some_test_big_step() {
    min_some_test(true);
}
/// Exercises the `MinSome1` aggregator. Judging by the expected outputs, it
/// returns the smallest `Some(_)` value in a group and only falls back to
/// `None` when the group contains no `Some` values.
///
/// When `transaction` is true, all batches are stepped inside one transaction
/// and only the final accumulated output is checked; otherwise each batch is
/// committed as its own transaction and the output is checked per step.
fn min_some_test(transaction: bool) {
    let (mut dbsp, (input_handle, output_handle)) = Runtime::init_circuit(
        // 4 workers; chunk size 1 forces the splitter to emit minimal chunks.
        CircuitConfig::from(4).with_splitter_chunk_size_records(1),
        move |circuit| {
            let (input_stream, input_handle) =
                circuit.add_input_indexed_zset::<u64, Tup1<Option<u64>>>();
            let output_handle = input_stream
                .aggregate(MinSome1)
                .accumulate_integrate()
                .accumulate_output();
            Ok((input_handle, output_handle))
        },
    )
    .unwrap();
    let inputs = [
        vec![
            Tup2(1u64, Tup2(Tup1(None), 1)),
            Tup2(2u64, Tup2(Tup1(Some(5)), 1)),
        ],
        vec![
            Tup2(1u64, Tup2(Tup1(Some(3)), 1)),
            Tup2(2u64, Tup2(Tup1(None), 1)),
        ],
        vec![
            Tup2(1u64, Tup2(Tup1(None), -1)),
            Tup2(2u64, Tup2(Tup1(Some(5)), -1)),
        ],
    ];
    let expected_outputs = [
        indexed_zset! {1 => {Tup1(None) => 1}, 2 => { Tup1(Some(5)) => 1 }},
        indexed_zset! {1 => {Tup1(Some(3)) => 1}, 2 => { Tup1(Some(5)) => 1 }},
        indexed_zset! {1 => {Tup1(Some(3)) => 1}, 2 => { Tup1(None) => 1 }},
    ];
    if transaction {
        dbsp.start_transaction().unwrap();
        // `inputs` is consumed by value here, so no clone is needed.
        for mut batch in inputs {
            input_handle.append(&mut batch);
            dbsp.step().unwrap();
        }
        dbsp.commit_transaction().unwrap();
        assert_eq!(
            output_handle.concat().consolidate(),
            expected_outputs.last().unwrap().clone()
        );
    } else {
        for (batch, expected) in inputs.iter().zip(&expected_outputs) {
            input_handle.append(&mut batch.clone());
            dbsp.transaction().unwrap();
            assert_eq!(output_handle.concat().consolidate(), *expected);
        }
    }
    dbsp.kill().unwrap();
}
// Wrappers around `postprocess_test(transaction)`: per-batch commits vs. a
// single transaction over all batches.
#[test]
fn postprocess_test_small_steps() {
    postprocess_test(false);
}

#[test]
fn postprocess_test_big_steps() {
    postprocess_test(true);
}
/// Exercises the `Postprocess` aggregator wrapper: `Min` computes the group
/// minimum and the postprocess closure maps it from `Tup1<u64>` to
/// `Tup1<Option<u64>>` before it reaches the output.
///
/// `transaction == true` applies all batches in one transaction and checks
/// only the final output; otherwise each batch is its own transaction and is
/// checked immediately.
fn postprocess_test(transaction: bool) {
    let (mut dbsp, (input_handle, output_handle)) = Runtime::init_circuit(
        CircuitConfig::from(4).with_splitter_chunk_size_records(2),
        move |circuit| {
            let (input_stream, input_handle) =
                circuit.add_input_indexed_zset::<u64, Tup1<u64>>();
            let output_handle = input_stream
                .aggregate(Postprocess::new(Min, |x: &Tup1<u64>| Tup1(Some(x.0))))
                .accumulate_integrate()
                .accumulate_output();
            Ok((input_handle, output_handle))
        },
    )
    .unwrap();
    let inputs = [
        vec![Tup2(1u64, Tup2(Tup1(1), 1)), Tup2(2u64, Tup2(Tup1(5), 1))],
        vec![Tup2(1u64, Tup2(Tup1(3), 1)), Tup2(2u64, Tup2(Tup1(2), 1))],
        vec![Tup2(1u64, Tup2(Tup1(1), -1)), Tup2(2u64, Tup2(Tup1(5), -1))],
    ];
    let expected_outputs = [
        indexed_zset! {1 => {Tup1(Some(1)) => 1}, 2 => { Tup1(Some(5)) => 1 }},
        indexed_zset! {1 => {Tup1(Some(1)) => 1}, 2 => { Tup1(Some(2)) => 1 }},
        indexed_zset! {1 => {Tup1(Some(3)) => 1}, 2 => { Tup1(Some(2)) => 1 }},
    ];
    if transaction {
        dbsp.start_transaction().unwrap();
        // `inputs` is consumed by value here, so no clone is needed.
        for mut batch in inputs {
            input_handle.append(&mut batch);
            dbsp.step().unwrap();
        }
        dbsp.commit_transaction().unwrap();
        let output = output_handle.concat().consolidate();
        assert_eq!(output, expected_outputs.last().unwrap().clone());
    } else {
        for (batch, expected) in inputs.iter().zip(&expected_outputs) {
            input_handle.append(&mut batch.clone());
            dbsp.transaction().unwrap();
            assert_eq!(output_handle.concat().consolidate(), *expected);
        }
    }
    dbsp.kill().unwrap();
}
// Wrappers around `aggregate_linear_postprocess_test(transaction)`:
// per-batch commits vs. a single transaction over all batches.
#[test]
fn aggregate_linear_postprocess_test_small_step() {
    aggregate_linear_postprocess_test(false)
}

#[test]
fn aggregate_linear_postprocess_test_big_step() {
    aggregate_linear_postprocess_test(true)
}
/// Checks that `aggregate_linear_postprocess` (and its `_retain_keys`
/// variant) produces the same results as the equivalent two-stage pipeline of
/// `aggregate_linear` followed by a `map_index` postprocessing step.
///
/// The linear aggregate computes `(count, non-null count, sum)` per key; the
/// postprocess turns it into `Some(sum)` when any non-null value was seen and
/// `None` otherwise (SQL `SUM`-style null handling).
fn aggregate_linear_postprocess_test(transaction: bool) {
    fn agg_func(x: &Option<i32>) -> Tup3<i32, i32, i32> {
        let v = x.unwrap_or_default();
        Tup3(1, if x.is_some() { 1 } else { 0 }, v)
    }
    fn postprocess_func(Tup3(_count, non_nulls, sum): Tup3<i32, i32, i32>) -> Option<i32> {
        if non_nulls > 0 { Some(sum) } else { None }
    }
    let (
        mut dbsp,
        (
            input_handle,
            _waterline_handle,
            sum_handle,
            sum_slow_handle,
            sum_retain_handle,
            sum_slow_retain_handle,
        ),
    ) = Runtime::init_circuit(
        CircuitConfig::from(2).with_splitter_chunk_size_records(1),
        |circuit| {
            let (input, input_handle) = circuit.add_input_zset::<Tup2<i32, Option<i32>>>();
            let (waterline, waterline_handle) = circuit.add_input_stream::<i32>();
            let indexed_input = input.map_index(|Tup2(k, v)| (*k, *v));
            // Fused aggregate + postprocess operators under test.
            let sum = indexed_input.aggregate_linear_postprocess(agg_func, postprocess_func);
            let sum_retain = indexed_input.aggregate_linear_postprocess_retain_keys(
                &waterline.typed_box(),
                |k, ts| k >= ts,
                agg_func,
                postprocess_func,
            );
            // Reference pipelines: aggregate first, postprocess separately.
            let sum_slow = indexed_input
                .aggregate_linear(agg_func)
                .map_index(|(k, v)| (*k, postprocess_func(*v)));
            let sum_slow_retain = indexed_input
                .aggregate_linear_retain_keys(&waterline.typed_box(), |k, ts| k >= ts, agg_func)
                .map_index(|(k, v)| (*k, postprocess_func(*v)));
            let sum_handle = sum.accumulate_integrate().accumulate_output();
            let sum_slow_handle = sum_slow.accumulate_integrate().accumulate_output();
            let sum_retain_handle = sum_retain.accumulate_integrate().accumulate_output();
            let sum_slow_retain_handle =
                sum_slow_retain.accumulate_integrate().accumulate_output();
            Ok((
                input_handle,
                waterline_handle,
                sum_handle,
                sum_slow_handle,
                sum_retain_handle,
                sum_slow_retain_handle,
            ))
        },
    )
    .unwrap();
    let inputs = vec![
        vec![Tup2(Tup2(1i32, None), 2)],
        vec![Tup2(Tup2(1i32, Some(5)), 1)],
        vec![Tup2(Tup2(1i32, Some(3)), 2)],
        vec![Tup2(Tup2(1i32, Some(-11)), 1)],
        vec![
            Tup2(Tup2(1i32, Some(-11)), -1),
            Tup2(Tup2(1i32, Some(5)), -1),
            Tup2(Tup2(1i32, Some(3)), -2),
        ],
        vec![Tup2(Tup2(1i32, None), -2)],
    ];
    let expected_output = [
        indexed_zset! {1 => {None => 1}},
        indexed_zset! {1 => {Some(5) => 1}},
        indexed_zset! {1 => {Some(11) => 1}},
        indexed_zset! {1 => {Some(0) => 1}},
        indexed_zset! {1 => {None => 1}},
        indexed_zset! {},
    ];
    if transaction {
        dbsp.start_transaction().unwrap();
        // `inputs` is consumed by value here, so no clone is needed.
        for mut batch in inputs {
            input_handle.append(&mut batch);
            dbsp.step().unwrap();
        }
        dbsp.commit_transaction().unwrap();
        // After one big transaction, all four outputs must equal the final
        // expected state.
        let expected = expected_output.last().unwrap();
        assert_eq!(sum_handle.concat().consolidate(), *expected);
        assert_eq!(sum_slow_handle.concat().consolidate(), *expected);
        assert_eq!(sum_retain_handle.concat().consolidate(), *expected);
        assert_eq!(sum_slow_retain_handle.concat().consolidate(), *expected);
    } else {
        for (batch, expected) in inputs.iter().zip(&expected_output) {
            input_handle.append(&mut batch.clone());
            dbsp.transaction().unwrap();
            assert_eq!(sum_handle.concat().consolidate(), *expected);
            assert_eq!(sum_slow_handle.concat().consolidate(), *expected);
            assert_eq!(sum_retain_handle.concat().consolidate(), *expected);
            assert_eq!(sum_slow_retain_handle.concat().consolidate(), *expected);
        }
    }
    dbsp.kill().unwrap()
}
// Wrappers around `aggregate_linear_postprocess_test_i8(transaction)`:
// per-batch commits vs. a single transaction over all batches.
#[test]
fn aggregate_linear_postprocess_test_i8_small_step() {
    aggregate_linear_postprocess_test_i8(false)
}

#[test]
fn aggregate_linear_postprocess_test_i8_big_step() {
    aggregate_linear_postprocess_test_i8(true)
}
/// `i8` twin of `aggregate_linear_postprocess_test`: same scenario with a
/// narrow integer weight/value type to exercise small-integer aggregate
/// arithmetic. See that test for the full description.
fn aggregate_linear_postprocess_test_i8(transaction: bool) {
    fn agg_func(x: &Option<i8>) -> Tup3<i8, i8, i8> {
        let v = x.unwrap_or_default();
        Tup3(1, if x.is_some() { 1 } else { 0 }, v)
    }
    fn postprocess_func(Tup3(_count, non_nulls, sum): Tup3<i8, i8, i8>) -> Option<i8> {
        if non_nulls > 0 { Some(sum) } else { None }
    }
    let (
        mut dbsp,
        (
            input_handle,
            _waterline_handle,
            sum_handle,
            sum_slow_handle,
            sum_retain_handle,
            sum_slow_retain_handle,
        ),
    ) = Runtime::init_circuit(
        CircuitConfig::from(2).with_splitter_chunk_size_records(1),
        |circuit| {
            let (input, input_handle) = circuit.add_input_zset::<Tup2<i8, Option<i8>>>();
            let (waterline, waterline_handle) = circuit.add_input_stream::<i8>();
            let indexed_input = input.map_index(|Tup2(k, v)| (*k, *v));
            // Fused aggregate + postprocess operators under test.
            let sum = indexed_input.aggregate_linear_postprocess(agg_func, postprocess_func);
            let sum_retain = indexed_input.aggregate_linear_postprocess_retain_keys(
                &waterline.typed_box(),
                |k, ts| k >= ts,
                agg_func,
                postprocess_func,
            );
            // Reference pipelines: aggregate first, postprocess separately.
            let sum_slow = indexed_input
                .aggregate_linear(agg_func)
                .map_index(|(k, v)| (*k, postprocess_func(*v)));
            let sum_slow_retain = indexed_input
                .aggregate_linear_retain_keys(&waterline.typed_box(), |k, ts| k >= ts, agg_func)
                .map_index(|(k, v)| (*k, postprocess_func(*v)));
            let sum_handle = sum.accumulate_integrate().accumulate_output();
            let sum_slow_handle = sum_slow.accumulate_integrate().accumulate_output();
            let sum_retain_handle = sum_retain.accumulate_integrate().accumulate_output();
            let sum_slow_retain_handle =
                sum_slow_retain.accumulate_integrate().accumulate_output();
            Ok((
                input_handle,
                waterline_handle,
                sum_handle,
                sum_slow_handle,
                sum_retain_handle,
                sum_slow_retain_handle,
            ))
        },
    )
    .unwrap();
    let inputs = vec![
        vec![Tup2(Tup2(1i8, None), 2)],
        vec![Tup2(Tup2(1i8, Some(5)), 1)],
        vec![Tup2(Tup2(1i8, Some(3)), 2)],
        vec![Tup2(Tup2(1i8, Some(-11)), 1)],
        vec![
            Tup2(Tup2(1i8, Some(-11)), -1),
            Tup2(Tup2(1i8, Some(5)), -1),
            Tup2(Tup2(1i8, Some(3)), -2),
        ],
        vec![Tup2(Tup2(1i8, None), -2)],
    ];
    let expected_output = [
        indexed_zset! {1 => {None => 1}},
        indexed_zset! {1 => {Some(5) => 1}},
        indexed_zset! {1 => {Some(11) => 1}},
        indexed_zset! {1 => {Some(0) => 1}},
        indexed_zset! {1 => {None => 1}},
        indexed_zset! {},
    ];
    if transaction {
        dbsp.start_transaction().unwrap();
        // `inputs` is consumed by value here, so no clone is needed.
        for mut batch in inputs {
            input_handle.append(&mut batch);
            dbsp.step().unwrap();
        }
        dbsp.commit_transaction().unwrap();
        // After one big transaction, all four outputs must equal the final
        // expected state.
        let expected = expected_output.last().unwrap();
        assert_eq!(sum_handle.concat().consolidate(), *expected);
        assert_eq!(sum_slow_handle.concat().consolidate(), *expected);
        assert_eq!(sum_retain_handle.concat().consolidate(), *expected);
        assert_eq!(sum_slow_retain_handle.concat().consolidate(), *expected);
    } else {
        for (batch, expected) in inputs.iter().zip(&expected_output) {
            input_handle.append(&mut batch.clone());
            dbsp.transaction().unwrap();
            assert_eq!(sum_handle.concat().consolidate(), *expected);
            assert_eq!(sum_slow_handle.concat().consolidate(), *expected);
            assert_eq!(sum_retain_handle.concat().consolidate(), *expected);
            assert_eq!(sum_slow_retain_handle.concat().consolidate(), *expected);
        }
    }
    dbsp.kill().unwrap()
}
mod retain_values_test {
    use proptest::{collection, prelude::*};

    use crate::{Runtime, ZWeight, circuit::CircuitConfig, operator::Max, utils::Tup2};

    // Lateness bound: values whose timestamp is more than `LATENESS` behind
    // the waterline become eligible for garbage collection.
    const LATENESS: u32 = 10;

    /// Builds a circuit whose integral traces are pruned via
    /// `accumulate_integrate_trace_retain_values_{top_n,bottom_n}` and checks
    /// inside the circuit (via `apply2`) that `Max`, `topk_desc(3)` and
    /// `topk_asc(3)` over the pruned traces match the same aggregates over
    /// unpruned copies of the input.
    fn max_retain_values_test(inputs: Vec<Vec<Tup2<u32, Tup2<Tup2<u32, u32>, ZWeight>>>>) {
        let (mut dbsp, input_handle) = Runtime::init_circuit(
            CircuitConfig::from(2).with_splitter_chunk_size_records(2),
            |circuit| {
                let (input, input_handle) =
                    circuit.add_input_indexed_zset::<u32, Tup2<u32, u32>>();
                // Waterline: the largest timestamp seen so far minus LATENESS.
                let waterline = input.waterline(
                    || u32::MIN,
                    |_k, Tup2(_val, ts)| (*ts).saturating_sub(LATENESS),
                    |ts1, ts2| std::cmp::max(*ts1, *ts2),
                );
                // `max` runs over a trace pruned to the top-1 value per key;
                // `expected_max` over an unpruned copy of the input.
                let max = input.aggregate(Max);
                input.accumulate_integrate_trace_retain_values_top_n(
                    &waterline,
                    |val, ts| val.1 >= ts.saturating_sub(LATENESS),
                    1,
                );
                let input2 = input.map_index(|(k, v)| (*k, *v));
                let expected_max = input2.aggregate(Max);
                max.apply2(&expected_max, |val, expected_val| {
                    assert_eq!(val, expected_val);
                });
                // Same comparison for topk_desc(3) with top-3 retention ...
                let input3 = input.map_index(|(k, v)| (*k, *v));
                let top3 = input3.topk_desc(3);
                input3.accumulate_integrate_trace_retain_values_top_n(
                    &waterline,
                    |val, ts| val.1 >= ts.saturating_sub(LATENESS),
                    3,
                );
                let expected_top3 = input.map_index(|(k, v)| (*k, *v)).topk_desc(3);
                top3.apply2(&expected_top3, |val, expected_val| {
                    assert_eq!(val, expected_val);
                });
                // ... and for topk_asc(3) with bottom-3 retention.
                let input4 = input.map_index(|(k, v)| (*k, *v));
                let bottom3 = input4.topk_asc(3);
                input4.accumulate_integrate_trace_retain_values_bottom_n(
                    &waterline,
                    |val, ts| val.1 >= ts.saturating_sub(LATENESS),
                    3,
                );
                let expected_bottom3 = input.map_index(|(k, v)| (*k, *v)).topk_asc(3);
                bottom3.apply2(&expected_bottom3, |val, expected_val| {
                    assert_eq!(val, expected_val);
                });
                Ok(input_handle)
            },
        )
        .unwrap();
        // `inputs` is consumed by value, so no clone is needed per batch.
        for mut batch in inputs {
            input_handle.append(&mut batch);
            dbsp.transaction().unwrap();
        }
        dbsp.kill().unwrap()
    }

    /// Strategy for one input batch of `(key, Tup2(Tup2(value, timestamp), weight))`
    /// tuples; values and timestamps are offset by `step` so later batches
    /// advance the waterline.
    fn input(
        step: usize,
        max_key: u32,
        max_tuples: usize,
    ) -> impl Strategy<Value = Vec<Tup2<u32, Tup2<Tup2<u32, u32>, ZWeight>>>> {
        collection::vec(
            (0..max_key, 0..LATENESS, 0..LATENESS, 0..2i64),
            0..max_tuples,
        )
        .prop_map(move |v| {
            v.into_iter()
                .map(|(k, v, ts, w)| Tup2(k, Tup2(Tup2(step as u32 + v, step as u32 + ts), w)))
                .collect()
        })
    }

    /// Strategy for a sequence of `steps` batches.
    fn inputs(
        steps: usize,
        max_key: u32,
        max_tuples: usize,
    ) -> impl Strategy<Value = Vec<Vec<Tup2<u32, Tup2<Tup2<u32, u32>, ZWeight>>>>> {
        // A `Vec` of strategies is itself a strategy producing a `Vec` of
        // values, so no further adapter (`prop_map`) is required.
        (0..steps)
            .map(|step| input(step, max_key, max_tuples))
            .collect::<Vec<_>>()
    }

    proptest! {
        #[test]
        fn proptest_max_retain_values_test(inputs in inputs(100, 100, 20)) {
            max_retain_values_test(inputs);
        }
    }
}
}