1use super::{DynTreeNode, Prefix, RadixTreeFactories, radix_tree_update};
2use crate::{
3 Circuit, DBData, DynZWeight, Stream, ZWeight,
4 algebra::{HasOne, IndexedZSet, IndexedZSetReader, OrdIndexedZSet},
5 circuit::{
6 Scope,
7 operator_traits::{Operator, TernaryOperator},
8 },
9 dynamic::{DataTrait, DynDataTyped, Erase},
10 operator::dynamic::{
11 accumulate_trace::AccumulateTraceFeedback, aggregate::DynAggregator,
12 time_series::radix_tree::treenode::TreeNode, trace::TraceBounds,
13 },
14 trace::{Batch, BatchReader, BatchReaderFactories, Builder, Spine, TupleBuilder},
15};
16use dyn_clone::clone_box;
17use num::PrimInt;
18use size_of::SizeOf;
19use std::{borrow::Cow, cmp::Ordering, marker::PhantomData, ops::Neg};
20
21pub trait RadixTreeBatch<TS, A>:
23 IndexedZSet<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>
24where
25 TS: DBData + PrimInt,
26 A: DataTrait + ?Sized,
27{
28}
29
30impl<TS, A, B> RadixTreeBatch<TS, A> for B
31where
32 TS: DBData + PrimInt,
33 A: DataTrait + ?Sized,
34 B: IndexedZSet<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>,
35{
36}
37
38pub trait RadixTreeReader<TS, A>:
39 IndexedZSetReader<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>
40where
41 TS: DBData + PrimInt,
42 A: DataTrait + ?Sized,
43{
44}
45
46impl<TS, A, B> RadixTreeReader<TS, A> for B
47where
48 B: IndexedZSetReader<Key = DynDataTyped<Prefix<TS>>, Val = DynTreeNode<TS, A>>,
49 TS: DBData + PrimInt,
50 A: DataTrait + ?Sized,
51{
52}
53
54pub type OrdRadixTree<TS, A> = OrdIndexedZSet<DynDataTyped<Prefix<TS>>, DynTreeNode<TS, A>>;
55
56pub struct TreeAggregateFactories<
57 TS: DBData + PrimInt,
58 Z: IndexedZSet<Key = DynDataTyped<TS>>,
59 O: RadixTreeBatch<TS, Acc>,
60 Acc: DataTrait + ?Sized,
61> {
62 input_factories: Z::Factories,
63 output_factories: O::Factories,
64 radix_tree_factories: RadixTreeFactories<TS, Acc>,
65}
66
67impl<TS, Z, O, Acc> TreeAggregateFactories<TS, Z, O, Acc>
68where
69 TS: DBData + PrimInt,
70 Z: IndexedZSet<Key = DynDataTyped<TS>>,
71 O: RadixTreeBatch<TS, Acc>,
72 Acc: DataTrait + ?Sized,
73{
74 pub fn new<VType, AType>() -> Self
75 where
76 VType: DBData + Erase<Z::Val>,
77 AType: DBData + Erase<Acc>,
78 {
79 Self {
80 input_factories: BatchReaderFactories::new::<TS, VType, ZWeight>(),
81 output_factories: BatchReaderFactories::new::<Prefix<TS>, TreeNode<TS, AType>, ZWeight>(
82 ),
83 radix_tree_factories: RadixTreeFactories::new::<AType>(),
84 }
85 }
86}
87
88impl<C, Z, TS> Stream<C, Z>
89where
90 C: Circuit,
91 Z: IndexedZSet<Key = DynDataTyped<TS>> + SizeOf + Send,
92 TS: DBData + PrimInt,
93{
94 pub fn tree_aggregate<Acc, Out>(
105 &self,
106 persistent_id: Option<&str>,
107 factories: &TreeAggregateFactories<TS, Z, OrdRadixTree<TS, Acc>, Acc>,
108 aggregator: &dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = Out>,
109 ) -> Stream<C, OrdRadixTree<TS, Acc>>
110 where
111 Acc: DataTrait + ?Sized,
112 Out: DataTrait + ?Sized,
113 {
114 self.tree_aggregate_generic::<Acc, Out, OrdRadixTree<TS, Acc>>(
115 persistent_id,
116 factories,
117 aggregator,
118 )
119 }
120
121 pub fn tree_aggregate_generic<Acc, Out, O>(
123 &self,
124 persistent_id: Option<&str>,
125 factories: &TreeAggregateFactories<TS, Z, O, Acc>,
126 aggregator: &dyn DynAggregator<Z::Val, (), DynZWeight, Accumulator = Acc, Output = Out>,
127 ) -> Stream<C, O>
128 where
129 Acc: DataTrait + ?Sized,
130 Out: DataTrait + ?Sized,
131 O: RadixTreeBatch<TS, Acc>,
132 {
133 self.circuit().region("tree_aggregate", move || {
134 let circuit = self.circuit();
135 let stream = self.dyn_gather(&factories.input_factories, 0);
136
137 let feedback = circuit.add_accumulate_integrate_trace_feedback::<Spine<O>>(
156 persistent_id,
157 &factories.output_factories,
158 <TraceBounds<O::Key, O::Val>>::unbounded(),
159 );
160
161 let output = circuit.add_ternary_operator(
162 RadixTreeAggregate::new(
163 &factories.radix_tree_factories,
164 &factories.output_factories,
165 aggregator,
166 ),
167 &stream.dyn_accumulate(&factories.input_factories),
168 &stream.dyn_accumulate_integrate_trace(&factories.input_factories),
169 &feedback.delayed_trace,
170 );
171
172 feedback.connect(&output, &factories.output_factories);
173
174 output
175 })
176 }
177}
178
179struct RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
186where
187 Z: BatchReader<Key = DynDataTyped<TS>>,
188 TS: DBData + PrimInt,
189 O: Batch,
190 Acc: DataTrait + ?Sized,
191 Out: DataTrait + ?Sized,
192{
193 aggregator: Box<dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = Out>>,
194 radix_tree_factories: RadixTreeFactories<TS, Acc>,
195 output_factories: O::Factories,
196 phantom: PhantomData<(Z, IT, OT, O)>,
197}
198
199impl<Z, TS, IT, OT, Acc, Out, O> RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
200where
201 Z: BatchReader<Key = DynDataTyped<TS>>,
202 TS: DBData + PrimInt,
203 Acc: DataTrait + ?Sized,
204 Out: DataTrait + ?Sized,
205 O: Batch,
206{
207 pub fn new(
208 radix_tree_factories: &RadixTreeFactories<TS, Acc>,
209 output_factories: &O::Factories,
210 aggregator: &dyn DynAggregator<Z::Val, (), Z::R, Accumulator = Acc, Output = Out>,
211 ) -> Self {
212 Self {
213 radix_tree_factories: radix_tree_factories.clone(),
214 output_factories: output_factories.clone(),
215 aggregator: clone_box(aggregator),
216 phantom: PhantomData,
217 }
218 }
219}
220
221impl<Z, TS, IT, OT, Acc, Out, O> Operator for RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
222where
223 Z: BatchReader<Key = DynDataTyped<TS>>,
224 Acc: DataTrait + ?Sized,
225 Out: DataTrait + ?Sized,
226 TS: DBData + PrimInt,
227 IT: 'static,
228 OT: 'static,
229 O: Batch,
230{
231 fn name(&self) -> Cow<'static, str> {
232 Cow::from("RadixTreeAggregate")
233 }
234
235 fn fixedpoint(&self, _scope: Scope) -> bool {
236 true
237 }
238}
239
240impl<Z, TS, IT, OT, Acc, Out, O> TernaryOperator<Option<Spine<Z>>, IT, OT, O>
241 for RadixTreeAggregate<Z, TS, IT, OT, Acc, Out, O>
242where
243 Z: IndexedZSet<Key = DynDataTyped<TS>>,
244 TS: DBData + PrimInt,
245 Acc: DataTrait + ?Sized,
246 Out: DataTrait + ?Sized,
247 O: RadixTreeBatch<TS, Acc>,
248 IT: IndexedZSetReader<Key = Z::Key, Val = Z::Val> + Clone,
249 OT: RadixTreeReader<TS, Acc> + Clone,
250{
251 async fn eval(
252 &mut self,
253 delta: Cow<'_, Option<Spine<Z>>>,
254 input_trace: Cow<'_, IT>,
255 output_trace: Cow<'_, OT>,
256 ) -> O {
257 let Some(delta) = delta.as_ref() else {
258 return O::dyn_empty(&self.output_factories);
259 };
260
261 let mut updates = self.radix_tree_factories.node_updates_factory.default_box();
262 updates.reserve(delta.key_count());
263
264 radix_tree_update::<TS, Z::Val, Acc, Out, _, _, _>(
265 &self.radix_tree_factories,
266 delta.cursor(),
267 input_trace.cursor(),
268 output_trace.cursor(),
269 self.aggregator.as_ref(),
270 &mut *updates,
271 );
272
273 let builder =
274 O::Builder::with_capacity(&self.output_factories, updates.len(), updates.len() * 2);
275 let mut builder = TupleBuilder::new(&self.output_factories, builder);
276
277 for update in updates.dyn_iter_mut() {
280 match update.new().cmp(update.old()) {
281 Ordering::Equal => {}
282 Ordering::Less => {
283 let mut prefix = update.prefix();
284 if let Some(new) = update.new_mut().get_mut() {
285 builder.push_vals(
286 prefix.clone().erase_mut(),
287 new,
288 &mut (),
289 ZWeight::one().erase_mut(),
290 );
291 };
292 if let Some(old) = update.old_mut().get_mut() {
293 builder.push_vals(
294 prefix.erase_mut(),
295 old,
296 &mut (),
297 ZWeight::one().neg().erase_mut(),
298 );
299 };
300 }
301 Ordering::Greater => {
302 let mut prefix = update.prefix();
303
304 if let Some(old) = update.old_mut().get_mut() {
305 builder.push_vals(
306 prefix.clone().erase_mut(),
307 old,
308 &mut (),
309 ZWeight::one().neg().erase_mut(),
310 );
311 };
312 if let Some(new) = update.new_mut().get_mut() {
313 builder.push_vals(
314 prefix.erase_mut(),
315 new,
316 &mut (),
317 ZWeight::one().erase_mut(),
318 );
319 };
320 }
321 }
322 }
323
324 builder.done()
325 }
326}
327
#[cfg(test)]
mod test {
    use super::super::RadixTreeCursor;
    use crate::{
        DynZWeight, Runtime, Stream, ZWeight,
        algebra::{AddAssignByRef, DefaultSemigroup},
        dynamic::{DowncastTrait, DynData, DynDataTyped, DynPair, Erase},
        operator::{
            Fold,
            dynamic::{
                aggregate::DynAggregatorImpl,
                input::{AddInputIndexedZSetFactories, CollectionHandle},
                time_series::{
                    TreeNode,
                    radix_tree::{
                        Prefix,
                        test::test_aggregate_range,
                        tree_aggregate::{OrdRadixTree, TreeAggregateFactories},
                    },
                },
            },
        },
        trace::{BatchReader, BatchReaderFactories},
        utils::Tup2,
    };
    use std::{
        collections::{BTreeMap, btree_map::Entry},
        sync::{Arc, Mutex},
    };

    /// Pushes `(key, upd)` into `input` and mirrors the change in `contents`.
    ///
    /// `contents` is the reference model and maps each key to the sum of its
    /// values. Entries whose sum drops to zero are removed. A new key must be
    /// inserted with weight 1; an existing key accepts weight 1 or -1.
    fn update_key(
        input: &CollectionHandle<DynDataTyped<u64>, DynPair<DynData, DynZWeight>>,
        contents: &mut BTreeMap<u64, Box<DynData>>,
        key: u64,
        upd: Tup2<u64, ZWeight>,
    ) {
        let mut erased_key = key;
        let mut erased_upd = upd.clone();
        input.dyn_push(erased_key.erase_mut(), erased_upd.erase_mut());

        match contents.entry(key) {
            Entry::Occupied(mut slot) => {
                assert!(upd.1 == 1 || upd.1 == -1);
                let sum = slot.get_mut().downcast_mut_checked::<u64>();
                if upd.1 == 1 {
                    *sum += upd.0;
                } else {
                    *sum -= upd.0;
                }
                if *sum == 0 {
                    slot.remove();
                }
            }
            Entry::Vacant(slot) => {
                assert_eq!(upd.1, 1);
                slot.insert(Box::new(upd.0).erase_box());
            }
        }
    }

    #[test]
    fn test_tree_aggregate() {
        // Reference model shared with the validation closure inside the circuit.
        let contents: Arc<Mutex<BTreeMap<u64, Box<DynData>>>> =
            Arc::new(Mutex::new(BTreeMap::new()));
        let expected = contents.clone();

        let (mut circuit, input) = Runtime::init_circuit(1, move |circuit| {
            let (input, input_handle) = circuit
                .dyn_add_input_indexed_zset::<DynDataTyped<u64>, DynData>(
                    &AddInputIndexedZSetFactories::new::<u64, u64>(),
                );

            // Sums the values of a key; the weight argument is ignored.
            let sum = <Fold<u64, _, DefaultSemigroup<_>, _, _>>::new(
                0u64,
                |agg: &mut u64, val: &u64, _w: ZWeight| *agg += val,
            );

            let aggregate: Stream<_, OrdRadixTree<u64, DynData>> = input
                .tree_aggregate::<DynData, DynData>(
                    None,
                    &TreeAggregateFactories::new::<u64, u64>(),
                    &DynAggregatorImpl::new(sum),
                );

            let trace_factories =
                BatchReaderFactories::new::<Prefix<u64>, TreeNode<u64, u64>, ZWeight>();

            // After every step, print the integrated tree and check it
            // against the reference model.
            aggregate
                .dyn_integrate_trace(&trace_factories)
                .apply(move |tree_trace| {
                    println!("Radix tree:");
                    let mut treestr = String::new();
                    tree_trace.cursor().format_tree(&mut treestr).unwrap();
                    println!("{treestr}");
                    tree_trace
                        .cursor()
                        .validate(&expected.lock().unwrap(), &|acc, val| {
                            acc.downcast_mut_checked::<u64>()
                                .add_assign_by_ref(val.downcast_checked::<u64>())
                        });
                    test_aggregate_range::<u64, u64, _, DefaultSemigroup<_>>(
                        &mut tree_trace.cursor(),
                        &expected.lock().unwrap(),
                    );
                });

            Ok(input_handle)
        })
        .unwrap();

        // Each entry is one transaction's worth of `(key, value, weight)` updates.
        let transactions: [&[(u64, u64, ZWeight)]; 11] = [
            &[],
            &[(0x1000_0000_0000_0001, 1, 1)],
            &[(0x1000_0000_0000_0002, 2, 1)],
            &[(0x1000_1000_0000_0000, 3, 1)],
            &[(0x1000_0000_0000_0002, 2, -1)],
            &[
                (0xf100_0000_0000_0001, 4, 1),
                (0xf200_0000_0000_0001, 5, 1),
                (0xf300_0000_0000_0001, 6, 1),
                (0xf300_1000_0000_0001, 7, 1),
                (0xf300_1000_1000_0001, 8, 1),
                (0xf300_1000_1000_1001, 9, 1),
                (0xf300_1000_1100_1001, 10, 1),
                (0xf300_1000_1100_1001, 10, -1),
            ],
            &[
                (0xf400_1000_1100_1001, 11, 1),
                (0xf300_1000_0000_0001, 7, -1),
            ],
            &[
                (0x1000_0000_0000_0001, 1, -1),
                (0x1000_1000_0000_0000, 3, -1),
                (0xf100_0000_0000_0001, 4, -1),
                (0xf200_0000_0000_0001, 5, -1),
            ],
            &[
                (0xf300_0000_0000_0001, 6, -1),
                (0xf300_1000_1000_0001, 8, -1),
                (0xf300_1000_1000_1001, 9, -1),
            ],
            &[(0xf400_1000_1100_1001, 11, -1)],
            &[
                (0xf100_0000_0000_0001, 4, 1),
                (0xf200_0000_0000_0001, 5, 1),
                (0xf300_0000_0000_0001, 6, 1),
                (0xf300_1000_0000_0001, 7, 1),
                (0xf300_1000_1000_0001, 8, 1),
                (0xf300_1000_1000_0001, 11, 1),
            ],
        ];

        for batch in transactions {
            for &(key, val, weight) in batch {
                // The lock is taken per update and released before the
                // transaction, because the circuit's validation closure locks
                // the same map.
                update_key(&input, &mut contents.lock().unwrap(), key, Tup2(val, weight));
            }
            circuit.transaction().unwrap();
        }
    }
}