use super::{DynTreeNode, Prefix, RadixTreeFactories, radix_tree_update};
use crate::{
Circuit, DBData, DynZWeight, Stream, ZWeight,
algebra::{HasOne, IndexedZSet, IndexedZSetReader, OrdIndexedZSet},
circuit::{
Scope,
operator_traits::{Operator, TernaryOperator},
},
dynamic::{DataTrait, DynDataTyped, Erase},
operator::dynamic::{
accumulate_trace::AccumulateTraceFeedback, aggregate::DynAggregator,
time_series::radix_tree::treenode::TreeNode, trace::TraceBounds,
},
trace::{Batch, BatchReader, BatchReaderFactories, Builder, Spine, TupleBuilder},
};
use dyn_clone::clone_box;
use num::PrimInt;
use size_of::SizeOf;
use std::{borrow::Cow, cmp::Ordering, marker::PhantomData, ops::Neg};
/// Batch type that can hold radix tree nodes.
///
/// Shorthand for an [`IndexedZSet`] whose keys are `DynDataTyped<Prefix<TS>>`
/// and whose values are `DynTreeNode<TS, A>`.  The trait declares no items of
/// its own.
pub trait RadixTreeBatch<TS: DBData + PrimInt, A: DataTrait + ?Sized>:
    IndexedZSet<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>
{
}
// Blanket implementation: every `IndexedZSet` with the matching key and value
// types is a `RadixTreeBatch`.
impl<TS, A, B> RadixTreeBatch<TS, A> for B
where
    B: IndexedZSet<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>,
    A: DataTrait + ?Sized,
    TS: DBData + PrimInt,
{
}
/// Read-only counterpart of a radix tree batch.
///
/// Shorthand for an [`IndexedZSetReader`] whose keys are
/// `DynDataTyped<Prefix<TS>>` and whose values are `DynTreeNode<TS, A>`.  The
/// trait declares no items of its own.
pub trait RadixTreeReader<TS: DBData + PrimInt, A: DataTrait + ?Sized>:
    IndexedZSetReader<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>
{
}
// Blanket implementation: every `IndexedZSetReader` with the matching key and
// value types is a `RadixTreeReader`.
impl<TS, A, B> RadixTreeReader<TS, A> for B
where
    TS: DBData + PrimInt,
    A: DataTrait + ?Sized,
    B: IndexedZSetReader<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>,
{
}
/// Radix tree stored in an [`OrdIndexedZSet`]: keys are
/// `DynDataTyped<Prefix<TS>>`, values are `DynTreeNode<TS, A>`.
pub type OrdRadixTree<TS, A> = OrdIndexedZSet<DynDataTyped<Prefix<TS>>, DynTreeNode<TS, A>>;
/// Factories needed to instantiate the tree aggregate operator.
///
/// * `TS`  - timestamp type used as the key of the input collection.
/// * `Z`   - input batch type.
/// * `O`   - output batch type holding the radix tree.
/// * `Acc` - accumulator type stored in tree nodes.
pub struct TreeAggregateFactories<TS, Z, O, Acc>
where
    TS: DBData + PrimInt,
    Z: IndexedZSet<Key = DynDataTyped<TS>>,
    O: RadixTreeBatch<TS, Acc>,
    Acc: DataTrait + ?Sized,
{
    /// Factories for batches of type `Z` (the operator input).
    input_factories: Z::Factories,
    /// Factories for batches of type `O` (the operator output).
    output_factories: O::Factories,
    /// Factories handed to the radix tree update routine.
    radix_tree_factories: RadixTreeFactories<TS, Acc>,
}
impl<TS, Z, O, Acc> TreeAggregateFactories<TS, Z, O, Acc>
where
    TS: DBData + PrimInt,
    Z: IndexedZSet<Key = DynDataTyped<TS>>,
    O: RadixTreeBatch<TS, Acc>,
    Acc: DataTrait + ?Sized,
{
    /// Creates the full set of factories from concrete types.
    ///
    /// * `VType` - concrete type that erases to the input value type `Z::Val`.
    /// * `AType` - concrete type that erases to the accumulator type `Acc`.
    ///
    /// Input batches are built for `(TS, VType, ZWeight)`; output batches for
    /// `(Prefix<TS>, TreeNode<TS, AType>, ZWeight)`.
    pub fn new<VType, AType>() -> Self
    where
        VType: DBData + Erase<Z::Val>,
        AType: DBData + Erase<Acc>,
    {
        let input_factories: Z::Factories = BatchReaderFactories::new::<TS, VType, ZWeight>();
        let output_factories: O::Factories =
            BatchReaderFactories::new::<Prefix<TS>, TreeNode<TS, AType>, ZWeight>();
        let radix_tree_factories = RadixTreeFactories::new::<AType>();

        Self {
            input_factories,
            output_factories,
            radix_tree_factories,
        }
    }
}
impl<C, Z, TS> Stream<C, Z>
where
C: Circuit,
Z: IndexedZSet<Key = DynDataTyped<TS>> + SizeOf + Send,
TS: DBData + PrimInt,
{
pub fn tree_aggregate<Acc, Out>(
&self,
persistent_id: Option<&str>,
factories: &TreeAggregateFactories<TS, Z, OrdRadixTree<TS, Acc>, Acc>,
aggregator: &dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = Out>,
) -> Stream<C, OrdRadixTree<TS, Acc>>
where
Acc: DataTrait + ?Sized,
Out: DataTrait + ?Sized,
{
self.tree_aggregate_generic::<Acc, Out, OrdRadixTree<TS, Acc>>(
persistent_id,
factories,
aggregator,
)
}
pub fn tree_aggregate_generic<Acc, Out, O>(
&self,
persistent_id: Option<&str>,
factories: &TreeAggregateFactories<TS, Z, O, Acc>,
aggregator: &dyn DynAggregator<Z::Val, (), DynZWeight, Accumulator = Acc, Output = Out>,
) -> Stream<C, O>
where
Acc: DataTrait + ?Sized,
Out: DataTrait + ?Sized,
O: RadixTreeBatch<TS, Acc>,
{
self.circuit().region("tree_aggregate", move || {
let circuit = self.circuit();
let stream = self.dyn_gather(&factories.input_factories, 0);
let feedback = circuit.add_accumulate_integrate_trace_feedback::<Spine<O>>(
persistent_id,
&factories.output_factories,
<TraceBounds<O::Key, O::Val>>::unbounded(),
);
let output = circuit.add_ternary_operator(
RadixTreeAggregate::new(
&factories.radix_tree_factories,
&factories.output_factories,
aggregator,
),
&stream.dyn_accumulate(&factories.input_factories),
&stream.dyn_accumulate_integrate_trace(&factories.input_factories),
&feedback.delayed_trace,
);
feedback.connect(&output, &factories.output_factories);
output
})
}
}
/// Ternary operator that computes changes to a radix tree from changes to
/// the input collection.
///
/// `IT` and `OT` (the input and output trace types) and `Z` appear only in
/// `phantom`; they tie the operator to the stream types it is connected to.
struct RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
where
    Z: BatchReader<Key = DynDataTyped<TS>>,
    TS: DBData + PrimInt,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    O: Batch,
{
    /// Owned copy of the aggregator passed to `new`.
    aggregator: Box<dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = Out>>,
    /// Factories forwarded to `radix_tree_update`.
    radix_tree_factories: RadixTreeFactories<TS, Acc>,
    /// Factories used to build output batches.
    output_factories: O::Factories,
    phantom: PhantomData<(Z, IT, OT, O)>,
}
impl<Z, TS, IT, OT, Acc, Out, O> RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
where
    Z: BatchReader<Key = DynDataTyped<TS>>,
    TS: DBData + PrimInt,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    O: Batch,
{
    /// Creates the operator.
    ///
    /// Nothing is borrowed past the call: both factory sets are cloned and the
    /// aggregator is cloned into a box.
    pub fn new(
        radix_tree_factories: &RadixTreeFactories<TS, Acc>,
        output_factories: &O::Factories,
        aggregator: &dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = Out>,
    ) -> Self {
        let radix_tree_factories = radix_tree_factories.clone();
        let output_factories = output_factories.clone();
        let aggregator = clone_box(aggregator);

        Self {
            aggregator,
            radix_tree_factories,
            output_factories,
            phantom: PhantomData,
        }
    }
}
impl<Z, TS, IT, OT, Acc, Out, O> Operator for RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
where
    Z: BatchReader<Key = DynDataTyped<TS>>,
    TS: DBData + PrimInt,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    O: Batch,
    IT: 'static,
    OT: 'static,
{
    /// Operator name: `"RadixTreeAggregate"`.
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("RadixTreeAggregate")
    }

    /// Always reports `true`, regardless of the scope.
    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }
}
impl<Z, TS, IT, OT, Acc, Out, O> TernaryOperator<Option<Spine<Z>>, IT, OT, O>
    for RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
where
    Z: IndexedZSet<Key = DynDataTyped<TS>>,
    TS: DBData + PrimInt,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
    O: RadixTreeBatch<TS, Acc>,
    IT: IndexedZSetReader<Key = Z::Key, Val = Z::Val> + Clone,
    OT: RadixTreeReader<TS, Acc> + Clone,
{
    /// Produces the batch of radix tree changes for one evaluation.
    ///
    /// * `delta` - changes to the input collection; when it is `None` an empty
    ///   output batch is returned immediately.
    /// * `input_trace` - trace of the input collection.
    /// * `output_trace` - trace of the radix tree.
    ///
    /// `radix_tree_update` is called with cursors over all three inputs and
    /// the aggregator, and a vector of node updates for it to populate.  Each
    /// update whose new and old values differ is then turned into up to two
    /// tuples under the update's prefix: the new value (if present) with
    /// weight `+1` and the old value (if present) with weight `-1`.
    async fn eval(
        &mut self,
        delta: Cow<'_, Option<Spine<Z>>>,
        input_trace: Cow<'_, IT>,
        output_trace: Cow<'_, OT>,
    ) -> O {
        let delta = match delta.as_ref() {
            Some(delta) => delta,
            None => return O::dyn_empty(&self.output_factories),
        };

        let mut node_updates = self.radix_tree_factories.node_updates_factory.default_box();
        node_updates.reserve(delta.key_count());

        radix_tree_update::<TS, Z::Val, Acc, Out, _, _, _>(
            &self.radix_tree_factories,
            delta.cursor(),
            input_trace.cursor(),
            output_trace.cursor(),
            self.aggregator.as_ref(),
            &mut *node_updates,
        );

        // Each update yields at most two tuples (one insertion, one
        // retraction), hence the `* 2`.
        let update_count = node_updates.len();
        let batch_builder =
            O::Builder::with_capacity(&self.output_factories, update_count, update_count * 2);
        let mut builder = TupleBuilder::new(&self.output_factories, batch_builder);

        for update in node_updates.dyn_iter_mut() {
            let ordering = update.new().cmp(update.old());
            if ordering == Ordering::Equal {
                // Node unchanged: nothing to emit.
                continue;
            }

            let mut prefix = update.prefix();

            // The smaller of the two values is pushed first.
            // NOTE(review): presumably the builder expects the values of a key
            // in ascending order -- confirm against `TupleBuilder`.
            if ordering == Ordering::Less {
                if let Some(new) = update.new_mut().get_mut() {
                    builder.push_vals(
                        prefix.clone().erase_mut(),
                        new,
                        &mut (),
                        ZWeight::one().erase_mut(),
                    );
                }
                if let Some(old) = update.old_mut().get_mut() {
                    builder.push_vals(
                        prefix.erase_mut(),
                        old,
                        &mut (),
                        ZWeight::one().neg().erase_mut(),
                    );
                }
            } else {
                if let Some(old) = update.old_mut().get_mut() {
                    builder.push_vals(
                        prefix.clone().erase_mut(),
                        old,
                        &mut (),
                        ZWeight::one().neg().erase_mut(),
                    );
                }
                if let Some(new) = update.new_mut().get_mut() {
                    builder.push_vals(
                        prefix.erase_mut(),
                        new,
                        &mut (),
                        ZWeight::one().erase_mut(),
                    );
                }
            }
        }

        builder.done()
    }
}
#[cfg(test)]
mod test {
    use super::super::RadixTreeCursor;
    use crate::{
        DynZWeight, Runtime, Stream, ZWeight,
        algebra::{AddAssignByRef, DefaultSemigroup},
        dynamic::{DowncastTrait, DynData, DynDataTyped, DynPair, Erase},
        operator::{
            Fold,
            dynamic::{
                aggregate::DynAggregatorImpl,
                input::{AddInputIndexedZSetFactories, CollectionHandle},
                time_series::{
                    TreeNode,
                    radix_tree::{
                        Prefix,
                        test::test_aggregate_range,
                        tree_aggregate::{OrdRadixTree, TreeAggregateFactories},
                    },
                },
            },
        },
        trace::{BatchReader, BatchReaderFactories},
        utils::Tup2,
    };
    use std::{
        collections::{BTreeMap, btree_map::Entry},
        sync::{Arc, Mutex},
    };

    /// Pushes `(key, upd)` to the circuit input and mirrors the change in
    /// `contents`, the reference map the tree is validated against.
    ///
    /// `upd` is `(value, weight)`.  For a key not yet in `contents` the weight
    /// must be `1` and the value is stored.  For an existing key the weight
    /// must be `1` or `-1`; the value is added to or subtracted from the
    /// stored number, and the entry is removed once it reaches zero.
    fn update_key(
        input: &CollectionHandle<DynDataTyped<u64>, DynPair<DynData, DynZWeight>>,
        contents: &mut BTreeMap<u64, Box<DynData>>,
        key: u64,
        upd: Tup2<u64, ZWeight>,
    ) {
        input.dyn_push(key.clone().erase_mut(), upd.clone().erase_mut());

        match contents.entry(key) {
            Entry::Occupied(mut slot) => {
                assert!(upd.1 == 1 || upd.1 == -1);
                let total = slot.get_mut().downcast_mut_checked::<u64>();
                if upd.1 == 1 {
                    *total += upd.0;
                } else {
                    *total -= upd.0;
                }
                let emptied = *total == 0;
                if emptied {
                    slot.remove();
                }
            }
            Entry::Vacant(slot) => {
                assert_eq!(upd.1, 1);
                slot.insert(Box::new(upd.0).erase_box());
            }
        }
    }

    /// Drives the tree aggregate operator through a sequence of transactions
    /// and, after each one, validates the integrated radix tree against the
    /// reference map and runs the range-aggregate checks.
    #[test]
    fn test_tree_aggregate() {
        let contents = Arc::new(Mutex::new(BTreeMap::new()));
        let contents_clone = contents.clone();

        let (mut circuit, input) = Runtime::init_circuit(1, move |circuit| {
            let (input, input_handle) = circuit
                .dyn_add_input_indexed_zset::<DynDataTyped<u64>, DynData>(
                    &AddInputIndexedZSetFactories::new::<u64, u64>(),
                );

            // Sums the values stored under each key.
            let aggregator = <Fold<u64, _, DefaultSemigroup<_>, _, _>>::new(
                0u64,
                |agg: &mut u64, val: &u64, _w: ZWeight| *agg += val,
            );

            let aggregate: Stream<_, OrdRadixTree<u64, DynData>> = input
                .tree_aggregate::<DynData, DynData>(
                    None,
                    &TreeAggregateFactories::new::<u64, u64>(),
                    &DynAggregatorImpl::new(aggregator),
                );

            let factory = BatchReaderFactories::new::<Prefix<u64>, TreeNode<u64, u64>, ZWeight>();

            aggregate
                .dyn_integrate_trace(&factory)
                .apply(move |tree_trace| {
                    println!("Radix tree:");
                    let mut treestr = String::new();
                    tree_trace.cursor().format_tree(&mut treestr).unwrap();
                    println!("{treestr}");

                    tree_trace
                        .cursor()
                        .validate(&contents_clone.lock().unwrap(), &|acc, val| {
                            acc.downcast_mut_checked::<u64>()
                                .add_assign_by_ref(val.downcast_checked::<u64>())
                        });

                    test_aggregate_range::<u64, u64, _, DefaultSemigroup<_>>(
                        &mut tree_trace.cursor(),
                        &contents_clone.lock().unwrap(),
                    );
                });

            Ok(input_handle)
        })
        .unwrap();

        // One inner slice per transaction; each entry is `(key, value, weight)`.
        let batches: &[&[(u64, u64, ZWeight)]] = &[
            &[(0x1000_0000_0000_0001, 1, 1)],
            &[(0x1000_0000_0000_0002, 2, 1)],
            &[(0x1000_1000_0000_0000, 3, 1)],
            &[(0x1000_0000_0000_0002, 2, -1)],
            &[
                (0xf100_0000_0000_0001, 4, 1),
                (0xf200_0000_0000_0001, 5, 1),
                (0xf300_0000_0000_0001, 6, 1),
                (0xf300_1000_0000_0001, 7, 1),
                (0xf300_1000_1000_0001, 8, 1),
                (0xf300_1000_1000_1001, 9, 1),
                (0xf300_1000_1100_1001, 10, 1),
                (0xf300_1000_1100_1001, 10, -1),
            ],
            &[
                (0xf400_1000_1100_1001, 11, 1),
                (0xf300_1000_0000_0001, 7, -1),
            ],
            &[
                (0x1000_0000_0000_0001, 1, -1),
                (0x1000_1000_0000_0000, 3, -1),
                (0xf100_0000_0000_0001, 4, -1),
                (0xf200_0000_0000_0001, 5, -1),
            ],
            &[
                (0xf300_0000_0000_0001, 6, -1),
                (0xf300_1000_1000_0001, 8, -1),
                (0xf300_1000_1000_1001, 9, -1),
            ],
            &[(0xf400_1000_1100_1001, 11, -1)],
            &[
                (0xf100_0000_0000_0001, 4, 1),
                (0xf200_0000_0000_0001, 5, 1),
                (0xf300_0000_0000_0001, 6, 1),
                (0xf300_1000_0000_0001, 7, 1),
                (0xf300_1000_1000_0001, 8, 1),
                (0xf300_1000_1000_0001, 11, 1),
            ],
        ];

        // First transaction runs with no input at all.
        circuit.transaction().unwrap();

        for batch in batches {
            for &(key, value, weight) in batch.iter() {
                // The mutex guard is a temporary, so it is released before the
                // transaction below runs the validation closure.
                update_key(
                    &input,
                    &mut contents.lock().unwrap(),
                    key,
                    Tup2(value, weight),
                );
            }
            circuit.transaction().unwrap();
        }
    }
}