use std::{
borrow::Cow,
cell::{Cell, RefCell},
cmp::Ordering,
marker::PhantomData,
panic::Location,
rc::Rc,
};
use async_stream::stream;
use crate::{
Circuit, DBData, DynZWeight, Position, RootCircuit, Scope, Stream, ZWeight,
algebra::{IndexedZSet, IndexedZSetReader, OrdIndexedZSet, OrdZSet, ZBatchReader, ZCursor},
circuit::{
metadata::{
BatchSizeStats, LEFT_INPUT_BATCHES_STATS, OUTPUT_BATCHES_STATS, OperatorLocation,
OperatorMeta, RIGHT_INPUT_BATCHES_STATS,
},
operator_traits::Operator,
splitter_output_chunk_size,
},
dynamic::{
ClonableTrait, Data, DataTrait, DowncastTrait, DynData, DynPair, DynUnit, DynVec,
DynWeightedPairs, Erase, Factory, LeanVec, WeightTrait, WithFactory,
},
operator::async_stream_operators::{StreamingQuaternaryOperator, StreamingQuaternaryWrapper},
trace::{
BatchFactories, BatchReader, BatchReaderFactories, Cursor, Spine, SpineSnapshot,
cursor::{CursorEmpty, CursorPair},
spine_async::WithSnapshot,
},
};
use super::{MonoIndexedZSet, MonoZSet};
/// Bundle of factories needed to instantiate the as-of join operator.
///
/// Type parameters:
/// * `TS` - timestamp type.
/// * `I1` - batch type of the left input.
/// * `I2` - batch type of the right input.
/// * `O`  - batch type of the output.
pub struct AsofJoinFactories<TS, I1, I2, O>
where
TS: DataTrait + ?Sized,
I1: IndexedZSetReader,
I2: IndexedZSetReader,
O: IndexedZSet,
{
/// Factory for individual timestamp values.
pub timestamp_factory: &'static dyn Factory<TS>,
/// Factory for vectors of timestamps; the operator uses it to allocate the
/// `affected_times` scratch vector during evaluation.
pub timestamps_factory: &'static dyn Factory<DynVec<TS>>,
/// Factories for the left input batch type; used for sharding, accumulating and
/// integrating the left stream.
pub left_factories: I1::Factories,
/// Factories for the right input batch type; used the same way for the right stream.
pub right_factories: I2::Factories,
/// Factories for the output batch type; used to build output batches and the
/// buffer of weighted output tuples.
pub output_factories: O::Factories,
}
impl<TS, I1, I2, O> AsofJoinFactories<TS, I1, I2, O>
where
TS: DataTrait + ?Sized,
I1: IndexedZSetReader,
I2: IndexedZSetReader<Key = I1::Key>,
O: IndexedZSet,
{
pub fn new<TSType, KType, V1Type, V2Type, OKType, OVType>() -> Self
where
TSType: DBData + Erase<TS>,
KType: DBData + Erase<I1::Key>,
V1Type: DBData + Erase<I1::Val>,
V2Type: DBData + Erase<I2::Val>,
OKType: DBData + Erase<O::Key>,
OVType: DBData + Erase<O::Val>,
{
Self {
timestamp_factory: WithFactory::<TSType>::FACTORY,
timestamps_factory: WithFactory::<LeanVec<TSType>>::FACTORY,
left_factories: BatchReaderFactories::new::<KType, V1Type, ZWeight>(),
right_factories: BatchReaderFactories::new::<KType, V2Type, ZWeight>(),
output_factories: BatchReaderFactories::new::<OKType, OVType, ZWeight>(),
}
}
}
/// Manual `Clone` implementation: the `'static` factory references are copied, the
/// per-batch factory bundles are cloned.
impl<TS, I1, I2, O> Clone for AsofJoinFactories<TS, I1, I2, O>
where
    TS: DataTrait + ?Sized,
    I1: IndexedZSetReader,
    I2: IndexedZSetReader,
    O: IndexedZSet,
{
    fn clone(&self) -> Self {
        // Destructure so that adding a field without handling it here fails to compile.
        let Self {
            timestamp_factory,
            timestamps_factory,
            left_factories,
            right_factories,
            output_factories,
        } = self;

        Self {
            timestamp_factory: *timestamp_factory,
            timestamps_factory: *timestamps_factory,
            left_factories: left_factories.clone(),
            right_factories: right_factories.clone(),
            output_factories: output_factories.clone(),
        }
    }
}
impl Stream<RootCircuit, MonoIndexedZSet> {
    /// Monomorphic entry point for the as-of join: every key, value and timestamp type is
    /// fixed to `DynData`, and the output is a `MonoZSet`.
    ///
    /// Forwards all arguments unchanged to `dyn_asof_join`.
    #[track_caller]
    pub fn dyn_asof_join_mono(
        &self,
        factories: &AsofJoinFactories<DynData, MonoIndexedZSet, MonoIndexedZSet, MonoZSet>,
        other: &Stream<RootCircuit, MonoIndexedZSet>,
        ts_func1: Box<dyn Fn(&DynData, &mut DynData)>,
        tscmp_func: Box<dyn Fn(&DynData, &DynData) -> Ordering>,
        valts_cmp_func: Box<dyn Fn(&DynData, &DynData) -> Ordering>,
        join_func: Box<AsofJoinFunc<DynData, DynData, DynData, DynData, DynUnit>>,
    ) -> Stream<RootCircuit, MonoZSet> {
        Self::dyn_asof_join(
            self,
            factories,
            other,
            ts_func1,
            tscmp_func,
            valts_cmp_func,
            join_func,
        )
    }
}
impl<I1> Stream<RootCircuit, I1>
where
    I1: IndexedZSet + Send,
{
    /// As-of join whose output is a non-indexed Z-set of `V` values (`OrdZSet<V>`).
    ///
    /// Thin wrapper around [`Self::dyn_asof_join_generic`]; see that method for the meaning
    /// of the arguments.
    #[track_caller]
    pub fn dyn_asof_join<TS, I2, V>(
        &self,
        factories: &AsofJoinFactories<TS, I1, I2, OrdZSet<V>>,
        other: &Stream<RootCircuit, I2>,
        ts_func1: Box<dyn Fn(&I1::Val, &mut TS)>,
        tscmp_func: Box<dyn Fn(&I1::Val, &I2::Val) -> Ordering>,
        valts_cmp_func: Box<dyn Fn(&I1::Val, &TS) -> Ordering>,
        join_func: Box<AsofJoinFunc<I1::Key, I1::Val, I2::Val, V, DynUnit>>,
    ) -> Stream<RootCircuit, OrdZSet<V>>
    where
        TS: DataTrait + ?Sized,
        I2: IndexedZSet<Key = I1::Key>,
        V: DataTrait + ?Sized,
    {
        self.dyn_asof_join_generic(
            factories,
            other,
            ts_func1,
            tscmp_func,
            valts_cmp_func,
            join_func,
        )
    }

    /// As-of join whose output is an indexed Z-set (`OrdIndexedZSet<K, V>`).
    ///
    /// Thin wrapper around [`Self::dyn_asof_join_generic`]; see that method for the meaning
    /// of the arguments.
    #[track_caller]
    pub fn dyn_asof_join_index<TS, I2, K, V>(
        &self,
        factories: &AsofJoinFactories<TS, I1, I2, OrdIndexedZSet<K, V>>,
        other: &Stream<RootCircuit, I2>,
        ts_func1: Box<dyn Fn(&I1::Val, &mut TS)>,
        tscmp_func: Box<dyn Fn(&I1::Val, &I2::Val) -> Ordering>,
        valts_cmp_func: Box<dyn Fn(&I1::Val, &TS) -> Ordering>,
        join_func: Box<AsofJoinFunc<I1::Key, I1::Val, I2::Val, K, V>>,
    ) -> Stream<RootCircuit, OrdIndexedZSet<K, V>>
    where
        TS: DataTrait + ?Sized,
        I2: IndexedZSet<Key = I1::Key>,
        K: DataTrait + ?Sized,
        V: DataTrait + ?Sized,
    {
        self.dyn_asof_join_generic(
            factories,
            other,
            ts_func1,
            tscmp_func,
            valts_cmp_func,
            join_func,
        )
    }

    /// Builds the as-of join of `self` (left input) with `other` (right input) for an
    /// arbitrary output batch type `Z`.
    ///
    /// # Arguments
    ///
    /// * `factories` - factories for timestamps and for the input/output batch types.
    /// * `other` - right-hand input stream; it shares the key type of `self`.
    /// * `ts_func1` - writes the timestamp of a left value into the supplied `TS` slot.
    /// * `tscmp_func` - orders a left value relative to a right value.
    /// * `valts_cmp_func` - orders a left value relative to a timestamp.
    /// * `join_func` - called for every left value with the matching right value, if any,
    ///   and a callback through which it emits output key/value pairs.
    ///
    /// # Circuit structure
    ///
    /// Inside an `"asof_join"` region both inputs are sharded, and for each of them an
    /// accumulated delta stream and a delayed integrated trace are created. These four
    /// streams feed a single quaternary [`AsofJoin`] operator.
    #[track_caller]
    pub fn dyn_asof_join_generic<TS, I2, Z>(
        &self,
        factories: &AsofJoinFactories<TS, I1, I2, Z>,
        other: &Stream<RootCircuit, I2>,
        ts_func1: Box<dyn Fn(&I1::Val, &mut TS)>,
        tscmp_func: Box<dyn Fn(&I1::Val, &I2::Val) -> Ordering>,
        valts_cmp_func: Box<dyn Fn(&I1::Val, &TS) -> Ordering>,
        join_func: Box<AsofJoinFunc<I1::Key, I1::Val, I2::Val, Z::Key, Z::Val>>,
    ) -> Stream<RootCircuit, Z>
    where
        TS: DataTrait + ?Sized,
        I2: IndexedZSet<Key = I1::Key>,
        Z: IndexedZSet,
    {
        // `#[track_caller]` does not extend into closures: calling `Location::caller()`
        // inside the `region` closure below would report a line in this file rather than
        // the call site that created the operator. Resolve the location up front instead.
        let location = Location::caller();

        self.circuit().region("asof_join", || {
            let left = self.dyn_shard(&factories.left_factories);
            let right = other.dyn_shard(&factories.right_factories);

            // Delayed integrals of both inputs.
            let left_trace = left
                .dyn_accumulate_integrate_trace(&factories.left_factories)
                .accumulate_delay_trace();
            let right_trace = right
                .dyn_accumulate_integrate_trace(&factories.right_factories)
                .accumulate_delay_trace();

            self.circuit().add_quaternary_operator(
                StreamingQuaternaryWrapper::new(AsofJoin::new(
                    factories.clone(),
                    ts_func1,
                    tscmp_func,
                    valts_cmp_func,
                    join_func,
                    location,
                )),
                &left.dyn_accumulate(&factories.left_factories),
                &left_trace,
                &right.dyn_accumulate(&factories.right_factories),
                &right_trace,
            )
        })
    }
}
/// Signature of the user-supplied join callback.
///
/// As invoked by the operator, the arguments are:
/// 1. the key shared by both inputs,
/// 2. the left value,
/// 3. the matching right value, or `None` when the left value has no match,
/// 4. a callback that receives each output key/value pair; the operator moves the
///    contents of both out into its output buffer.
pub type AsofJoinFunc<K, V1, V2, OK, OV> =
dyn Fn(&K, &V1, Option<&V2>, &mut dyn FnMut(&mut OK, &mut OV));
/// State of the as-of join operator.
///
/// Type parameters:
/// * `TS` - timestamp type.
/// * `I1`, `I2` - batch types of the left and right input deltas.
/// * `T1`, `T2` - types of the delayed left and right traces.
/// * `Z` - output batch type.
pub struct AsofJoin<TS, I1, T1, I2, T2, Z>
where
TS: DataTrait + ?Sized,
I1: IndexedZSet,
T1: ZBatchReader,
I2: IndexedZSet,
T2: ZBatchReader,
Z: IndexedZSet,
{
/// Factories for timestamps and for the input/output batch types.
factories: AsofJoinFactories<TS, I1, I2, Z>,
/// Writes the timestamp of a left value into the supplied `TS` slot.
ts_func1: Box<dyn Fn(&I1::Val, &mut TS)>,
/// Orders a left value relative to a right value.
tscmp_func: Box<dyn Fn(&I1::Val, &I2::Val) -> Ordering>,
/// Orders a left value relative to a timestamp.
valts_cmp_func: Box<dyn Fn(&I1::Val, &TS) -> Ordering>,
/// User-supplied function that produces output tuples for a left value and its
/// (optional) matching right value.
join_func: Box<AsofJoinFunc<I1::Key, I1::Val, I2::Val, Z::Key, Z::Val>>,
/// Source location reported by `Operator::location`.
location: &'static Location<'static>,
/// Set by `Operator::flush`; read and reset by `eval`, which only produces real output
/// when the flag is set.
flush: Cell<bool>,
/// Snapshot of the most recent left delta seen by `eval`; consumed on flush.
delta1: RefCell<Option<SpineSnapshot<I1>>>,
/// Snapshot of the most recent right delta seen by `eval`; consumed on flush.
delta2: RefCell<Option<SpineSnapshot<I2>>>,
/// Size statistics of the left input deltas, reported through `Operator::metadata`.
delta1_batch_stats: RefCell<BatchSizeStats>,
/// Size statistics of the right input deltas, reported through `Operator::metadata`.
delta2_batch_stats: RefCell<BatchSizeStats>,
/// Size statistics of the output batches, reported through `Operator::metadata`.
output_batch_stats: RefCell<BatchSizeStats>,
/// Ties the otherwise unused type parameters to the struct.
phantom: PhantomData<(I1, T1, I2, T2, Z)>,
}
impl<TS, I1, T1, I2, T2, Z> AsofJoin<TS, I1, T1, I2, T2, Z>
where
TS: DataTrait + ?Sized,
I1: IndexedZSet,
T1: ZBatchReader,
I2: IndexedZSet<Key = I1::Key>,
T2: ZBatchReader,
Z: IndexedZSet,
{
/// Creates the operator with the given factories and user-supplied functions.
///
/// The operator starts with the flush flag cleared, no buffered deltas and empty batch
/// statistics.
pub fn new(
factories: AsofJoinFactories<TS, I1, I2, Z>,
ts_func1: Box<dyn Fn(&I1::Val, &mut TS)>,
tscmp_func: Box<dyn Fn(&I1::Val, &I2::Val) -> Ordering>,
valts_cmp_func: Box<dyn Fn(&I1::Val, &TS) -> Ordering>,
join_func: Box<AsofJoinFunc<I1::Key, I1::Val, I2::Val, Z::Key, Z::Val>>,
location: &'static Location<'static>,
) -> Self {
Self {
factories,
ts_func1,
tscmp_func,
valts_cmp_func,
join_func,
location,
flush: Cell::new(false),
delta1: RefCell::new(None),
delta2: RefCell::new(None),
delta1_batch_stats: RefCell::new(BatchSizeStats::new()),
delta2_batch_stats: RefCell::new(BatchSizeStats::new()),
output_batch_stats: RefCell::new(BatchSizeStats::new()),
phantom: PhantomData,
}
}
/// Seeks `cursor` to exactly `key` (passing the precomputed `hash`).
///
/// Returns the cursor when `seek_key_exact` reports success and `None` otherwise, so
/// callers can treat "key absent" as "no cursor".
fn try_seek<'a, C, K, V, T, R>(cursor: &'a mut C, key: &K, hash: u64) -> Option<&'a mut C>
where
K: DataTrait + ?Sized,
V: DataTrait + ?Sized,
R: WeightTrait + ?Sized,
C: Cursor<K, V, T, R>,
{
if cursor.seek_key_exact(key, Some(hash)) {
Some(cursor)
} else {
None
}
}
/// Collects into `affected_times` the timestamps of left values whose join output may
/// change for the key the cursors are currently positioned at.
///
/// * `delta1` - changes to the left input for this key.
/// * `delta2` - changes to the right input for this key.
/// * `delayed_cursor1` - left trace before this step, or `None` if it lacks the key.
/// * `cursor2` - right trace including this step's changes (an empty cursor if it
///   lacks the key).
/// * `affected_times` - output vector; cleared first, sorted and deduplicated on return.
fn compute_affected_times<DC1, DC2, ZC1, C2>(
&self,
delta1: &mut DC1,
delta2: &mut DC2,
delayed_cursor1: &mut Option<&mut ZC1>,
cursor2: &mut C2,
affected_times: &mut DynVec<TS>,
) where
DC1: ZCursor<I1::Key, I1::Val, ()>,
DC2: ZCursor<I2::Key, I2::Val, ()>,
ZC1: ZCursor<I1::Key, I1::Val, ()>,
C2: ZCursor<I2::Key, I2::Val, ()>,
{
affected_times.clear();
// Every timestamp that occurs in the left delta is affected.
while delta1.val_valid() {
affected_times.push_with(&mut |ts| (self.ts_func1)(delta1.val(), ts));
delta1.step_val();
}
// NOTE(review): relies on the value order of `delta1` being consistent with the
// order of the extracted timestamps - confirm against callers.
debug_assert!(affected_times.is_sorted_by(&|ts1, ts2| ts1.cmp(ts2)));
if let Some(delayed_cursor1) = delayed_cursor1 {
// For each change to the right input, find the existing left values it can affect.
while delta2.val_valid() {
// First left value that does not compare `Less` than the changed right value.
delayed_cursor1
.seek_val_with(&|v| (self.tscmp_func)(v, delta2.val()) != Ordering::Less);
// First right value strictly greater than the changed right value.
cursor2.seek_val_with(&|v| v > delta2.val());
debug_assert!(!cursor2.val_valid() || **cursor2.weight() != 0);
// Left values that compare `Less` than that next right value (or all remaining
// left values if there is no such right value) are affected.
while delayed_cursor1.val_valid()
&& (!cursor2.val_valid()
|| (self.tscmp_func)(delayed_cursor1.val(), cursor2.val())
== Ordering::Less)
{
affected_times.push_with(&mut |ts| (self.ts_func1)(delayed_cursor1.val(), ts));
delayed_cursor1.step_val();
}
// No greater right value: the loop above has consumed all remaining left values.
if !cursor2.val_valid() {
break;
}
// No left values remain, so nothing else can be affected.
if !delayed_cursor1.val_valid() {
break;
}
// Skip the changes in `delta2` that precede the right value found above.
delta2.seek_val(cursor2.val());
}
// Timestamps from the left trace were appended after those from `delta1`, so restore
// the sort order.
affected_times.sort();
}
affected_times.dedup();
}
/// Evaluates the join for all left values under the current key whose timestamp equals
/// `ts`, appending the resulting weighted tuples to `output_tuples`.
///
/// * `cursor1` - left cursor positioned at the key, or `None` if the key is absent (in
///   which case nothing is produced).
/// * `cursor2` - right cursor positioned at the key (an empty cursor if absent).
/// * `multiplier` - factor applied to every output weight; callers pass `-1` to retract
///   old output and `1` to insert new output.
///
/// Both cursors are moved backwards only, so successive calls must supply
/// non-increasing values of `ts`.
fn eval_val<C1, C2>(
&self,
ts: &TS,
cursor1: &mut Option<&mut C1>,
cursor2: &mut C2,
multiplier: ZWeight,
output_tuples: &mut DynWeightedPairs<DynPair<Z::Key, Z::Val>, DynZWeight>,
) where
C1: ZCursor<I1::Key, I1::Val, ()>,
C2: ZCursor<I2::Key, I2::Val, ()>,
{
let Some(cursor1) = cursor1 else {
return;
};
// Move backwards to a left value whose timestamp is not greater than `ts`.
cursor1.seek_val_with_reverse(&|v| (self.valts_cmp_func)(v, ts) != Ordering::Greater);
// Process every left value whose timestamp is exactly `ts`.
while cursor1.val_valid() && (self.valts_cmp_func)(cursor1.val(), ts) == Ordering::Equal {
// Move the right cursor backwards to a value that the left value does not compare
// `Less` than; if there is none the cursor becomes invalid.
cursor2
.seek_val_with_reverse(&|v| (self.tscmp_func)(cursor1.val(), v) != Ordering::Less);
debug_assert!(!cursor2.val_valid() || **cursor2.weight() != 0);
let w1 = **cursor1.weight();
// Without a matching right value the left value is emitted with its own weight.
let w2 = if cursor2.val_valid() {
**cursor2.weight()
} else {
1
};
let w = w1 * w2 * multiplier;
(self.join_func)(
cursor1.key(),
cursor1.val(),
cursor2.get_val(),
&mut |k, v| {
// Move the produced key/value into a fresh output slot and set its weight.
output_tuples.push_with(&mut move |tup| {
let (kv, neww) = tup.split_mut();
let (newk, newv) = kv.split_mut();
k.move_to(newk);
v.move_to(newv);
// NOTE(review): assumes the weight slot's concrete type is `ZWeight`, the type
// of `w` - confirm against `DynZWeight`.
*unsafe { neww.downcast_mut() } = w;
});
},
);
cursor1.step_val_reverse();
}
}
/// Processes a single key: computes the affected timestamps and, for each of them,
/// retracts the output derived from the previous state and inserts the output derived
/// from the new state.
///
/// * `delta1` / `delta2` - input changes; the key is taken from `delta1` if it is
///   positioned at a valid key and from `delta2` otherwise.
/// * `delayed_cursor1` / `delayed_cursor2` - left and right traces before this step.
/// * `cursor1` / `cursor2` - left and right traces including this step's changes.
/// * `affected_times` - scratch vector, overwritten on every call.
/// * `output_tuples` - output buffer that the weighted tuples are appended to.
#[allow(clippy::too_many_arguments)]
fn eval_key<DC1, DC2, ZC1, ZC2, C1, C2>(
&self,
delta1: &mut DC1,
delta2: &mut DC2,
delayed_cursor1: &mut ZC1,
delayed_cursor2: &mut ZC2,
cursor1: &mut C1,
cursor2: &mut C2,
affected_times: &mut DynVec<TS>,
output_tuples: &mut DynWeightedPairs<DynPair<Z::Key, Z::Val>, DynZWeight>,
) where
DC1: ZCursor<I1::Key, I1::Val, ()>,
DC2: ZCursor<I2::Key, I2::Val, ()>,
ZC1: ZCursor<I1::Key, I1::Val, ()>,
ZC2: ZCursor<I2::Key, I2::Val, ()>,
C1: ZCursor<I1::Key, I1::Val, ()>,
C2: ZCursor<I2::Key, I2::Val, ()>,
{
let key = if delta1.key_valid() {
delta1.key()
} else {
delta2.key()
};
// Hash the key once and reuse it for all four seeks.
let hash = key.default_hash();
let mut delayed_cursor1 = Self::try_seek(delayed_cursor1, key, hash);
let mut delayed_cursor2 = Self::try_seek(delayed_cursor2, key, hash);
let mut cursor1 = Self::try_seek(cursor1, key, hash);
let mut cursor2 = Self::try_seek(cursor2, key, hash);
// Stand-in for a right cursor when the key is missing from the right-hand side.
let mut empty_cursor = CursorEmpty::new(WithFactory::<ZWeight>::FACTORY);
if let Some(cursor2) = &mut cursor2 {
self.compute_affected_times(
delta1,
delta2,
&mut delayed_cursor1,
*cursor2,
affected_times,
);
} else {
self.compute_affected_times(
delta1,
delta2,
&mut delayed_cursor1,
&mut empty_cursor,
affected_times,
);
}
// `eval_val` walks values backwards; presumably `fast_forward_vals` moves each cursor
// to its last value for the current key - TODO confirm.
if let Some(c) = cursor1.as_mut() {
c.fast_forward_vals()
}
if let Some(c) = cursor2.as_mut() {
c.fast_forward_vals()
}
if let Some(c) = delayed_cursor1.as_mut() {
c.fast_forward_vals()
}
if let Some(c) = delayed_cursor2.as_mut() {
c.fast_forward_vals()
}
// Visit affected timestamps from largest to smallest, matching the backward cursor
// movement in `eval_val`.
for i in (0..affected_times.len()).rev() {
// `i` is drawn from `0..affected_times.len()`, so the index is in bounds.
let ts = unsafe { affected_times.index_unchecked(i) };
// Retract output computed from the previous state of both inputs.
if let Some(delayed_cursor2) = &mut delayed_cursor2 {
self.eval_val(
ts,
&mut delayed_cursor1,
*delayed_cursor2,
-1,
output_tuples,
);
} else {
self.eval_val(
ts,
&mut delayed_cursor1,
&mut empty_cursor,
-1,
output_tuples,
);
}
// Insert output computed from the new state of both inputs.
if let Some(cursor2) = &mut cursor2 {
self.eval_val(ts, &mut cursor1, *cursor2, 1, output_tuples);
} else {
self.eval_val(ts, &mut cursor1, &mut empty_cursor, 1, output_tuples);
}
}
}
}
/// Generic operator plumbing for the as-of join: name, source location, flush signal,
/// batch-size metadata and fixed-point check.
impl<TS, I1, T1, I2, T2, Z> Operator for AsofJoin<TS, I1, T1, I2, T2, Z>
where
    TS: DataTrait + ?Sized,
    I1: IndexedZSet,
    T1: ZBatchReader,
    I2: IndexedZSet,
    T2: ZBatchReader,
    Z: IndexedZSet,
{
    /// Operator name.
    fn name(&self) -> Cow<'static, str> {
        Cow::from("AsofJoin")
    }

    /// Source location the operator was constructed with.
    fn location(&self) -> OperatorLocation {
        Some(self.location)
    }

    /// Raises the flush flag; the next evaluation consumes it and emits the real output.
    fn flush(&mut self) {
        *self.flush.get_mut() = true;
    }

    /// Publishes batch-size statistics for both inputs and for the output.
    fn metadata(&self, meta: &mut OperatorMeta) {
        let left_stats = self.delta1_batch_stats.borrow();
        let right_stats = self.delta2_batch_stats.borrow();
        let output_stats = self.output_batch_stats.borrow();

        meta.extend(metadata! {
            LEFT_INPUT_BATCHES_STATS => left_stats.metadata(),
            RIGHT_INPUT_BATCHES_STATS => right_stats.metadata(),
            OUTPUT_BATCHES_STATS => output_stats.metadata(),
        });
    }

    /// Always reports a fixed point, regardless of scope.
    fn fixedpoint(&self, _scope: Scope) -> bool {
        true
    }
}
/// Streaming evaluation of the as-of join over four inputs: the left delta, the delayed
/// left trace, the right delta and the delayed right trace.
impl<TS, I1, T1, I2, T2, Z>
StreamingQuaternaryOperator<Option<Spine<I1>>, T1, Option<Spine<I2>>, T2, Z>
for AsofJoin<TS, I1, T1, I2, T2, Z>
where
TS: DataTrait + ?Sized,
I1: IndexedZSet,
T1: ZBatchReader<Key = I1::Key, Val = I1::Val, Time = ()> + Clone + WithSnapshot<Batch = I1>,
I2: IndexedZSet<Key = I1::Key>,
T2: ZBatchReader<Key = I2::Key, Val = I2::Val, Time = ()> + Clone + WithSnapshot<Batch = I2>,
Z: IndexedZSet,
{
/// Returns a stream of `(batch, is_final, position)` items.
///
/// When the flush flag is not set, the stream yields a single empty, final batch.
/// Otherwise it merges the keys of the two deltas, evaluates each key with `eval_key`,
/// and yields a non-final batch whenever at least `splitter_output_chunk_size()` tuples
/// have been buffered, followed by one final batch with the remaining tuples.
///
/// # Panics
///
/// On flush, panics if no left or right delta has been stored since the previous flush.
fn eval(
self: Rc<Self>,
delta1: Cow<'_, Option<Spine<I1>>>,
delayed_trace1: Cow<'_, T1>,
delta2: Cow<'_, Option<Spine<I2>>>,
delayed_trace2: Cow<'_, T2>,
) -> impl futures::Stream<Item = (Z, bool, Option<Position>)> + 'static {
// Remember the most recent deltas; they are consumed when the operator is flushed.
if let Some(delta1) = delta1.as_ref() {
*self.delta1.borrow_mut() = Some(delta1.ro_snapshot());
};
if let Some(delta2) = delta2.as_ref() {
*self.delta2.borrow_mut() = Some(delta2.ro_snapshot());
};
// The delayed traces are only needed on flush, so only snapshot them then. The snapshots
// are taken here because the borrowed inputs cannot be moved into the `'static` stream.
let delayed_trace1 = if self.flush.get() {
Some(delayed_trace1.as_ref().ro_snapshot())
} else {
None
};
let delayed_trace2 = if self.flush.get() {
Some(delayed_trace2.ro_snapshot())
} else {
None
};
stream! {
let chunk_size = splitter_output_chunk_size();
// Read and reset the flush flag; without a flush emit an empty final batch.
if !self.flush.replace(false) {
yield(Z::dyn_empty(&self.factories.output_factories), true, None);
return;
}
let delta1 = self.delta1.take().unwrap();
let delta2 = self.delta2.take().unwrap();
self.delta1_batch_stats.borrow_mut().add_batch(delta1.len());
self.delta2_batch_stats.borrow_mut().add_batch(delta2.len());
// First set of cursors: combined below into cursors over the current traces.
let mut delta1_cursor = delta1.cursor();
let mut delta2_cursor = delta2.cursor();
let delayed_trace1 = delayed_trace1.expect("no delayed trace1 provided before flush");
let delayed_trace2 = delayed_trace2.expect("no delayed trace2 provided before flush");
let mut delayed_trace1_cursor = delayed_trace1.cursor();
let mut delayed_trace2_cursor = delayed_trace2.cursor();
// Current trace = delta + delayed trace.
let mut trace1_cursor = CursorPair::new(&mut delta1_cursor, &mut delayed_trace1_cursor);
let mut trace2_cursor = CursorPair::new(&mut delta2_cursor, &mut delayed_trace2_cursor);
// Second, independent set of cursors (shadowing the first) for iterating over the deltas
// and the delayed traces on their own.
let mut delta1_cursor = delta1.cursor();
let mut delta2_cursor = delta2.cursor();
let mut delayed_trace1_cursor = delayed_trace1.cursor();
let mut delayed_trace2_cursor = delayed_trace2.cursor();
let weighted_items_factory = self.factories.output_factories.weighted_items_factory();
let mut output_tuples = weighted_items_factory.default_box();
output_tuples.reserve(chunk_size);
let mut affected_times = self.factories.timestamps_factory.default_box();
// Merge-iterate over the keys of both deltas.
while delta1_cursor.key_valid() && delta2_cursor.key_valid() {
match delta1_cursor.key().cmp(delta2_cursor.key()) {
// Key changed in the left input only.
Ordering::Less => {
self.eval_key(
&mut delta1_cursor,
&mut CursorEmpty::new(WithFactory::<ZWeight>::FACTORY),
&mut delayed_trace1_cursor,
&mut delayed_trace2_cursor,
&mut trace1_cursor,
&mut trace2_cursor,
affected_times.as_mut(),
output_tuples.as_mut(),
);
delta1_cursor.step_key();
}
// Key changed in both inputs.
Ordering::Equal => {
self.eval_key(
&mut delta1_cursor,
&mut delta2_cursor,
&mut delayed_trace1_cursor,
&mut delayed_trace2_cursor,
&mut trace1_cursor,
&mut trace2_cursor,
affected_times.as_mut(),
output_tuples.as_mut(),
);
delta1_cursor.step_key();
delta2_cursor.step_key();
}
// Key changed in the right input only.
Ordering::Greater => {
self.eval_key(
&mut CursorEmpty::new(WithFactory::<ZWeight>::FACTORY),
&mut delta2_cursor,
&mut delayed_trace1_cursor,
&mut delayed_trace2_cursor,
&mut trace1_cursor,
&mut trace2_cursor,
affected_times.as_mut(),
output_tuples.as_mut(),
);
delta2_cursor.step_key();
}
}
// Emit a non-final chunk once enough tuples have accumulated.
if output_tuples.len() >= chunk_size {
let result = Z::dyn_from_tuples(&self.factories.output_factories, (), &mut output_tuples);
self.output_batch_stats.borrow_mut().add_batch(result.len());
yield (result, false, delta1_cursor.position());
output_tuples = weighted_items_factory.default_box();
output_tuples.reserve(chunk_size);
}
}
// Remaining keys that only occur in the left delta.
while delta1_cursor.key_valid() {
self.eval_key(
&mut delta1_cursor,
&mut CursorEmpty::new(WithFactory::<ZWeight>::FACTORY),
&mut delayed_trace1_cursor,
&mut delayed_trace2_cursor,
&mut trace1_cursor,
&mut trace2_cursor,
affected_times.as_mut(),
output_tuples.as_mut(),
);
delta1_cursor.step_key();
if output_tuples.len() >= chunk_size {
let result = Z::dyn_from_tuples(&self.factories.output_factories, (), &mut output_tuples);
self.output_batch_stats.borrow_mut().add_batch(result.len());
yield (result, false, delta1_cursor.position());
output_tuples = weighted_items_factory.default_box();
output_tuples.reserve(chunk_size);
}
}
// Remaining keys that only occur in the right delta. The reported position is still
// taken from `delta1_cursor`.
while delta2_cursor.key_valid() {
self.eval_key(
&mut CursorEmpty::new(WithFactory::<ZWeight>::FACTORY),
&mut delta2_cursor,
&mut delayed_trace1_cursor,
&mut delayed_trace2_cursor,
&mut trace1_cursor,
&mut trace2_cursor,
affected_times.as_mut(),
output_tuples.as_mut(),
);
delta2_cursor.step_key();
if output_tuples.len() >= chunk_size {
let result = Z::dyn_from_tuples(&self.factories.output_factories, (), &mut output_tuples);
self.output_batch_stats.borrow_mut().add_batch(result.len());
yield (result, false, delta1_cursor.position());
output_tuples = weighted_items_factory.default_box();
output_tuples.reserve(chunk_size);
}
}
// Final batch with whatever is left in the buffer (possibly nothing).
let result = Z::dyn_from_tuples(&self.factories.output_factories, (), &mut output_tuples);
self.output_batch_stats.borrow_mut().add_batch(result.len());
yield (result, true, delta1_cursor.position());
}
}
}
// Tests for the as-of join: a hand-written scenario, a regression scenario, and two
// property-based tests that compare the integrated operator output against a naive
// reference implementation.
#[cfg(test)]
mod test {
use std::cmp::{max, min};
use crate::{
DBData, DBSPHandle, OrdIndexedZSet, OrdZSet, OutputHandle, Runtime, TypedBox, ZSetHandle,
ZWeight,
algebra::F32,
circuit::CircuitConfig,
dynamic::DowncastTrait,
typed_batch::{IndexedZSetReader, SpineSnapshot},
utils::{Tup2, Tup3, Tup4},
zset,
};
use proptest::{collection::vec, prelude::*};
// Test schema: transactions (left input) are joined with users (right input) on the
// credit card number.
type Time = u64;
type CCNum = u64;
type Amt = F32;
// (timestamp, credit card number, amount)
type Transaction = Tup3<Time, CCNum, Amt>;
// (timestamp, credit card number, user name)
type User = Tup3<Time, CCNum, String>;
// (transaction timestamp, credit card number, amount, matching user name if any)
type Output = Tup4<Time, CCNum, Amt, Option<String>>;
/// Join function: copies the transaction fields and attaches the name of the matching
/// user, or `None` when there is no match.
fn join(
_key: &CCNum,
transaction: &Transaction,
user: Option<&User>,
) -> Tup4<Time, CCNum, Amt, Option<String>> {
Tup4(
transaction.0,
transaction.1,
transaction.2,
user.map(|u| u.2.clone()),
)
}
/// Timestamp of a transaction.
fn ts_func1(transaction: &Transaction) -> Time {
transaction.0
}
/// Timestamp of a user record.
fn ts_func2(user: &User) -> Time {
user.0
}
/// Builds a two-worker circuit (splitter chunk size of 2 records) that joins
/// transactions with users.
///
/// Returns the runtime handle together with: the two input handles, the join output, the
/// integral of the join output, and the integrals of the indexed transactions and users.
fn test_circuit() -> (
DBSPHandle,
(
ZSetHandle<Transaction>,
ZSetHandle<User>,
OutputHandle<SpineSnapshot<OrdZSet<Output>>>,
OutputHandle<SpineSnapshot<OrdZSet<Output>>>,
OutputHandle<SpineSnapshot<OrdIndexedZSet<CCNum, Transaction>>>,
OutputHandle<SpineSnapshot<OrdIndexedZSet<CCNum, User>>>,
),
) {
Runtime::init_circuit(
CircuitConfig::with_workers(2).with_splitter_chunk_size_records(2),
|circuit| {
let (transactions, transactions_handle) = circuit.add_input_zset::<Transaction>();
let (users, users_handle) = circuit.add_input_zset::<User>();
// Index both inputs by credit card number.
let transactions =
transactions.map_index(|transaction| (transaction.1, *transaction));
let users = users.map_index(|user| (user.1, user.clone()));
let result = transactions.asof_join(&users, join, ts_func1, ts_func2);
// Integrals of the inputs, used as input to the reference implementation.
let transactions_output_handle = transactions
.shard()
.accumulate_integrate()
.accumulate_output();
let users_output_handle = users.shard().accumulate_integrate().accumulate_output();
let output_handle = result.accumulate_output();
let output_integral_handle = result.accumulate_integrate().accumulate_output();
Ok((
transactions_handle,
users_handle,
output_handle,
output_integral_handle,
transactions_output_handle,
users_output_handle,
))
},
)
.unwrap()
}
/// Same circuit as `test_circuit`, with waterline-based retention configured on the
/// integrals of both inputs.
///
/// Returns the same handles as `test_circuit`.
fn test_circuit_with_waterline() -> (
DBSPHandle,
(
ZSetHandle<Transaction>,
ZSetHandle<User>,
OutputHandle<SpineSnapshot<OrdZSet<Output>>>,
OutputHandle<SpineSnapshot<OrdZSet<Output>>>,
OutputHandle<SpineSnapshot<OrdIndexedZSet<CCNum, Transaction>>>,
OutputHandle<SpineSnapshot<OrdIndexedZSet<CCNum, User>>>,
),
) {
Runtime::init_circuit(
CircuitConfig::with_workers(2).with_splitter_chunk_size_records(2),
|circuit| {
let (transactions, transactions_handle) = circuit.add_input_zset::<Transaction>();
let (users, users_handle) = circuit.add_input_zset::<User>();
let transactions =
transactions.map_index(|transaction| (transaction.1, *transaction));
let users = users.map_index(|user| (user.1, user.clone()));
// Per-input waterline: the maximum over all records of (timestamp - LATENESS),
// saturating at zero.
let user_waterline = users.waterline(
|| u64::MIN,
|_k, Tup3(ts, _, _)| {
(*ts).saturating_sub(LATENESS)
},
|ts1, ts2| {
max(*ts1, *ts2)
},
);
let transaction_waterline = transactions.waterline(
|| u64::MIN,
|_k, Tup3(ts, _, _)| (*ts).saturating_sub(LATENESS),
|ts1, ts2| max(*ts1, *ts2),
);
// Combined waterline: the smaller of the two per-input waterlines.
let waterline = transaction_waterline.apply2(&user_waterline, |ts1, ts2| {
TypedBox::new(min(unsafe { *ts1.inner().downcast::<u64>() }, unsafe {
*ts2.inner().downcast::<u64>()
}))
});
let result = transactions.asof_join(&users, join, ts_func1, ts_func2);
// Retention predicate for transactions: timestamp not below the waterline.
transactions.accumulate_integrate_trace_retain_values(
&waterline,
|transaction: &Transaction, ts: &u64| transaction.0 >= *ts,
);
// Same predicate for users; presumably the trailing `1` additionally keeps the last
// value that fails the predicate - TODO confirm against the `_last_n` API.
users.accumulate_integrate_trace_retain_values_last_n(
&waterline,
|user: &User, ts: &u64| user.0 >= *ts,
1,
);
let transactions_output_handle = transactions
.shard()
.accumulate_integrate()
.accumulate_output();
let users_output_handle = users.shard().accumulate_integrate().accumulate_output();
let output_handle = result.accumulate_output();
let output_integral_handle = result.accumulate_integrate().accumulate_output();
Ok((
transactions_handle,
users_handle,
output_handle,
output_integral_handle,
transactions_output_handle,
users_output_handle,
))
},
)
.unwrap()
}
/// Hand-written scenario that checks the output delta of the join after every step.
#[test]
fn asof_join_test() {
let (
mut dbsp,
(
transactions,
users,
result,
_result_integral,
_transactions_output_handle,
_users_output_handle,
),
) = test_circuit();
// Step 1: transactions only; there are no users yet, so nothing matches.
transactions.append(&mut vec![
Tup2(Tup3(100, 1, F32::new(10.0)), 1),
Tup2(Tup3(200, 1, F32::new(10.0)), 1),
Tup2(Tup3(300, 1, F32::new(10.0)), 1),
Tup2(Tup3(100, 2, F32::new(20.0)), 1),
Tup2(Tup3(100, 3, F32::new(30.0)), 1),
]);
dbsp.transaction().unwrap();
assert_eq!(
result.concat().consolidate(),
zset! {
Tup4(100, 1, F32::new(10.0), None) => 1,
Tup4(200, 1, F32::new(10.0), None) => 1,
Tup4(300, 1, F32::new(10.0), None) => 1,
Tup4(100, 2, F32::new(20.0), None) => 1,
Tup4(100, 3, F32::new(30.0), None) => 1,
}
);
// Step 2: add users. The user for card 3 has timestamp 110, later than the only
// transaction for that card (100), so that transaction stays unmatched.
users.append(&mut vec![
Tup2(Tup3(50, 1, "A50".to_string()), 1),
Tup2(Tup3(100, 2, "B100".to_string()), 1),
Tup2(Tup3(110, 3, "C110".to_string()), 1),
]);
dbsp.transaction().unwrap();
assert_eq!(
result.concat().consolidate(),
zset! {
Tup4(100, 1, F32::new(10.0), None) => -1,
Tup4(200, 1, F32::new(10.0), None) => -1,
Tup4(300, 1, F32::new(10.0), None) => -1,
Tup4(100, 2, F32::new(20.0), None) => -1,
Tup4(100, 1, F32::new(10.0), Some("A50".to_string())) => 1,
Tup4(200, 1, F32::new(10.0), Some("A50".to_string())) => 1,
Tup4(300, 1, F32::new(10.0), Some("A50".to_string())) => 1,
Tup4(100, 2, F32::new(20.0), Some("B100".to_string())) => 1,
}
);
// Step 3: more recent users replace earlier matches; also add a new transaction for
// card 3.
users.append(&mut vec![
Tup2(Tup3(60, 1, "A60".to_string()), 1),
Tup2(Tup3(120, 2, "B120".to_string()), 1),
Tup2(Tup3(50, 3, "C50".to_string()), 1),
]);
transactions.append(&mut vec![Tup2(Tup3(200, 3, F32::new(30.0)), 1)]);
dbsp.transaction().unwrap();
assert_eq!(
result.concat().consolidate(),
zset! {
Tup4(100, 1, F32::new(10.0), Some("A50".to_string())) => -1,
Tup4(100, 1, F32::new(10.0), Some("A60".to_string())) => 1,
Tup4(200, 1, F32::new(10.0), Some("A50".to_string())) => -1,
Tup4(200, 1, F32::new(10.0), Some("A60".to_string())) => 1,
Tup4(300, 1, F32::new(10.0), Some("A50".to_string())) => -1,
Tup4(300, 1, F32::new(10.0), Some("A60".to_string())) => 1,
Tup4(100, 3, F32::new(30.0), None) => -1,
Tup4(100, 3, F32::new(30.0), Some("C50".to_string())) => 1,
Tup4(200, 3, F32::new(30.0), Some("C110".to_string())) => 1,
}
);
// Step 4: users that are older than the current matches, plus "C105", which shares
// timestamp 110 with "C110"; the expected output delta is empty.
users.append(&mut vec![
Tup2(Tup3(10, 1, "A10".to_string()), 1),
Tup2(Tup3(10, 2, "B10".to_string()), 1),
Tup2(Tup3(10, 3, "C10".to_string()), 1),
Tup2(Tup3(110, 3, "C105".to_string()), 1),
]);
dbsp.transaction().unwrap();
assert_eq!(result.concat().consolidate(), zset! {});
// Step 5: new transactions are matched against the existing users.
transactions.append(&mut vec![
Tup2(Tup3(100, 1, F32::new(100.0)), 1),
Tup2(Tup3(200, 1, F32::new(100.0)), 1),
Tup2(Tup3(300, 1, F32::new(100.0)), 1),
Tup2(Tup3(100, 2, F32::new(200.0)), 1),
Tup2(Tup3(100, 3, F32::new(300.0)), 1),
]);
dbsp.transaction().unwrap();
assert_eq!(
result.concat().consolidate(),
zset! {
Tup4(100, 1, F32::new(100.0), Some("A60".to_string())) => 1,
Tup4(200, 1, F32::new(100.0), Some("A60".to_string())) => 1,
Tup4(300, 1, F32::new(100.0), Some("A60".to_string())) => 1,
Tup4(100, 2, F32::new(200.0), Some("B100".to_string())) => 1,
Tup4(100, 3, F32::new(300.0), Some("C50".to_string())) => 1,
}
);
// Step 6: deleting users that are not matched by any transaction changes nothing.
users.append(&mut vec![
Tup2(Tup3(10, 1, "A10".to_string()), -1),
Tup2(Tup3(10, 2, "B10".to_string()), -1),
Tup2(Tup3(10, 3, "C10".to_string()), -1),
]);
dbsp.transaction().unwrap();
assert_eq!(result.concat().consolidate(), zset! {});
// Step 7: deleting matched users makes the affected transactions fall back to the
// next-best user.
users.append(&mut vec![
Tup2(Tup3(60, 1, "A60".to_string()), -1),
Tup2(Tup3(120, 2, "B120".to_string()), -1),
Tup2(Tup3(110, 3, "C110".to_string()), -1),
]);
dbsp.transaction().unwrap();
assert_eq!(
result.concat().consolidate(),
zset! {
Tup4(100, 1, F32::new(100.0), Some("A60".to_string())) => -1,
Tup4(200, 1, F32::new(100.0), Some("A60".to_string())) => -1,
Tup4(300, 1, F32::new(100.0), Some("A60".to_string())) => -1,
Tup4(100, 1, F32::new(10.0), Some("A60".to_string())) => -1,
Tup4(200, 1, F32::new(10.0), Some("A60".to_string())) => -1,
Tup4(300, 1, F32::new(10.0), Some("A60".to_string())) => -1,
Tup4(100, 1, F32::new(100.0), Some("A50".to_string())) => 1,
Tup4(200, 1, F32::new(100.0), Some("A50".to_string())) => 1,
Tup4(300, 1, F32::new(100.0), Some("A50".to_string())) => 1,
Tup4(100, 1, F32::new(10.0), Some("A50".to_string())) => 1,
Tup4(200, 1, F32::new(10.0), Some("A50".to_string())) => 1,
Tup4(300, 1, F32::new(10.0), Some("A50".to_string())) => 1,
Tup4(200, 3, F32::new(30.0), Some("C110".to_string())) => -1,
Tup4(200, 3, F32::new(30.0), Some("C105".to_string())) => 1,
}
);
}
/// Fixed input sequence; after every step the integrated output is compared with the
/// reference implementation.
#[test]
fn asof_join_regressions() {
let (
mut dbsp,
(
transactions,
users,
_result,
result_integral,
transactions_output_handle,
users_output_handle,
),
) = test_circuit();
users.append(&mut vec![
Tup2(Tup3(37, 0, "L".to_string()), 1),
Tup2(Tup3(0, 0, "A".to_string()), 1),
]);
dbsp.transaction().unwrap();
assert_eq!(
result_integral.concat().consolidate(),
asof_join_reference2(&transactions_output_handle, &users_output_handle,)
);
transactions.append(&mut vec![Tup2(Tup3(37, 0, F32::new(0.0)), 1)]);
dbsp.transaction().unwrap();
assert_eq!(
result_integral.concat().consolidate(),
asof_join_reference2(&transactions_output_handle, &users_output_handle,)
);
users.append(&mut vec![Tup2(Tup3(37, 0, "L".to_string()), -1)]);
dbsp.transaction().unwrap();
assert_eq!(
result_integral.concat().consolidate(),
asof_join_reference2(&transactions_output_handle, &users_output_handle,)
);
users.append(&mut vec![Tup2(Tup3(0, 0, "A".to_string()), -1)]);
dbsp.transaction().unwrap();
assert_eq!(
result_integral.concat().consolidate(),
asof_join_reference2(&transactions_output_handle, &users_output_handle,)
);
}
/// Runs the reference implementation on the current contents of the two input-integral
/// output handles.
fn asof_join_reference2(
transactions_handle: &OutputHandle<SpineSnapshot<OrdIndexedZSet<CCNum, Transaction>>>,
users_handle: &OutputHandle<SpineSnapshot<OrdIndexedZSet<CCNum, User>>>,
) -> OrdZSet<Output> {
asof_join_reference(
&transactions_handle.concat().consolidate(),
&users_handle.concat().consolidate(),
join,
ts_func1,
ts_func2,
)
}
/// Naive reference implementation of the as-of join.
///
/// For every left tuple it scans the right collection in reverse iteration order and
/// picks the first tuple with the same key whose timestamp does not exceed the left
/// timestamp. A match is emitted with weight `w1 * w2`; a left tuple without a match is
/// emitted with `None` and weight `w1`.
fn asof_join_reference<TS, F, TSF1, TSF2, K, V1, V2, OV>(
left: &OrdIndexedZSet<K, V1>,
right: &OrdIndexedZSet<K, V2>,
join: F,
ts_func1: TSF1,
ts_func2: TSF2,
) -> OrdZSet<OV>
where
TS: DBData,
K: DBData,
V1: DBData,
V2: DBData,
OV: DBData,
F: Fn(&K, &V1, Option<&V2>) -> OV + Clone + 'static,
TSF1: Fn(&V1) -> TS + Clone + 'static,
TSF2: Fn(&V2) -> TS + 'static,
{
let left = left.iter().collect::<Vec<_>>();
let right = right.iter().collect::<Vec<_>>();
let mut result = Vec::new();
for (k, v1, w1) in left.iter() {
let (ov, ow) = right
.iter()
.rev()
.find_map(|(k2, v2, w2)| {
if k2 == k && ts_func2(v2) <= ts_func1(v1) {
Some((join(k, v1, Some(v2)), w1 * w2))
} else {
None
}
})
.unwrap_or_else(|| (join(k, v1, None), *w1));
result.push(Tup2(ov, ow));
}
OrdZSet::from_keys((), result)
}
// Lateness subtracted from timestamps when computing waterlines; also the exclusive upper
// bound of the random offset added to the step number in generated timestamps.
const LATENESS: u64 = 20;
// Strategy for one weighted transaction at step `step`: timestamp in
// `step..step + LATENESS`, card number in `0..10`, integral amount in `0..100`, weight 1 or 2.
prop_compose! {
fn transaction(step: usize)
(time in 0..LATENESS,
cc_num in 0..10u64,
amt in 0..100i32,
w in 1..=2 as ZWeight)
-> Tup2<Transaction, ZWeight> {
Tup2(Tup3(step as u64 + time, cc_num, F32::new(amt as f32)), w)
}
}
// Strategy for one weighted user at step `step`: timestamp in `step..step + LATENESS`,
// card number in `0..5`, a six-letter name, weight 1 or 2.
prop_compose! {
fn user(step: usize)
(time in 0..LATENESS,
cc_num in 0..5u64,
name in "[A-Z][a-z]{5}",
w in 1..=2 as ZWeight)
-> Tup2<User, ZWeight> {
Tup2(Tup3(step as u64 + time, cc_num, name), w)
}
}
// Strategy for the input of one step: fewer than 20 transactions and fewer than 10 users.
prop_compose! {
fn input(step: usize)
(transactions in vec(transaction(step), 0..20),
users in vec(user(step), 0..10))
-> (Vec<Tup2<Transaction, ZWeight>>, Vec<Tup2<User, ZWeight>>) {
(transactions, users)
}
}
/// Strategy producing the inputs for `steps` consecutive steps.
fn inputs(
steps: usize,
) -> impl Strategy<Value = Vec<(Vec<Tup2<Transaction, ZWeight>>, Vec<Tup2<User, ZWeight>>)>>
{
(0..steps).map(input).collect::<Vec<_>>().prop_map(|v| v)
}
// Feeds random inputs for 50 steps and then the same inputs with negated weights,
// comparing the integrated output against the reference after every step.
proptest! {
#![proptest_config(ProptestConfig::with_cases(30))]
#[test]
fn asof_join_proptest(inputs in inputs(50)) {
let (mut dbsp, (htransactions, husers, _hresult, hresult_integral, htransactions_output_handle, husers_output_handle)) = test_circuit();
// Copy of the inputs with every weight negated, used to delete everything again.
let mut deletions = inputs.clone();
for (ts, us) in deletions.iter_mut() {
for Tup2(_t, w) in ts.iter_mut(){
*w = -*w;
}
for Tup2(_u, w) in us.iter_mut() {
*w = -*w;
}
}
// Insertion phase.
for (mut transactions, mut users) in inputs {
htransactions.append(&mut transactions);
husers.append(&mut users);
dbsp.transaction().unwrap();
assert_eq!(
hresult_integral.concat().consolidate(),
asof_join_reference2(&htransactions_output_handle, &husers_output_handle)
);
}
// Deletion phase.
for (mut transactions, mut users) in deletions {
htransactions.append(&mut transactions);
husers.append(&mut users);
dbsp.transaction().unwrap();
assert_eq!(
hresult_integral.concat().consolidate(),
asof_join_reference2(&htransactions_output_handle, &husers_output_handle)
);
}
}
}
// Feeds random inputs for 100 steps into the circuit with waterline-based retention and
// compares the integrated output against the reference after every step.
proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn asof_join_with_waterline_proptest(inputs in inputs(100)) {
let (mut dbsp, (htransactions, husers, _hresult, hresult_integral, htransactions_output_handle, husers_output_handle)) = test_circuit_with_waterline();
for (mut transactions, mut users) in inputs {
htransactions.append(&mut transactions);
husers.append(&mut users);
dbsp.transaction().unwrap();
assert_eq!(
hresult_integral.concat().consolidate(),
asof_join_reference2(&htransactions_output_handle, &husers_output_handle)
);
}
}
}
}