dbsp/operator/dynamic/
input_upsert.rs

1use crate::{
2    algebra::{HasOne, HasZero, IndexedZSet, OrdZSet, ZTrace},
3    circuit::{
4        checkpointer::Checkpoint,
5        circuit_builder::{register_replay_stream, CircuitBase, RefStreamValue},
6        metadata::{BatchSizeStats, OperatorMeta, INPUT_BATCHES_LABEL, OUTPUT_BATCHES_LABEL},
7        operator_traits::{BinaryOperator, Operator, TernaryOperator},
8        OwnershipPreference, Scope,
9    },
10    declare_trait_object,
11    dynamic::{ClonableTrait, Data, DataTrait, DynOpt, DynPairs, Erase, Factory, WithFactory},
12    operator::{
13        dynamic::{
14            time_series::LeastUpperBoundFunc,
15            trace::{DelayedTraceId, TraceBounds, TraceId, UntimedTraceAppend, Z1Trace},
16        },
17        Z1,
18    },
19    trace::{
20        cursor::Cursor, Batch, BatchFactories, BatchReader, BatchReaderFactories, Builder, Rkyv,
21        Spine,
22    },
23    Circuit, DBData, NumEntries, RootCircuit, Stream, ZWeight,
24};
25use minitrace::trace;
26use rkyv::{Archive, Deserialize, Serialize};
27use size_of::SizeOf;
28use std::{
29    borrow::Cow,
30    marker::PhantomData,
31    mem::take,
32    ops::{Deref, Neg},
33};
34
35use super::trace::BoundsId;
36
37#[derive(
38    Clone,
39    Debug,
40    Default,
41    SizeOf,
42    PartialEq,
43    Eq,
44    Hash,
45    PartialOrd,
46    Ord,
47    Archive,
48    Serialize,
49    Deserialize,
50)]
51#[archive_attr(derive(Ord, Eq, PartialEq, PartialOrd))]
52#[archive(compare(PartialEq, PartialOrd))]
53pub enum Update<V: DBData, U: DBData> {
54    Insert(V),
55    #[default]
56    Delete,
57    Update(U),
58}
59
/// Borrowed view of an upsert command, as returned by [`UpdateTrait::get`].
///
/// Has the same three variants as [`Update`], but holds references to the
/// (possibly unsized) value type `V` and update type `U` instead of owned
/// payloads.
pub enum UpdateRef<'a, V: DataTrait + ?Sized, U: DataTrait + ?Sized> {
    /// Reference to the complete value being inserted.
    Insert(&'a V),
    /// Deletion; carries no payload.
    Delete,
    /// Reference to the modification being applied.
    Update(&'a U),
}
65
66impl<V: DBData, U: DBData> NumEntries for Update<V, U> {
67    const CONST_NUM_ENTRIES: Option<usize> = Some(1);
68
69    fn num_entries_shallow(&self) -> usize {
70        1
71    }
72
73    fn num_entries_deep(&self) -> usize {
74        1
75    }
76}
77
78pub trait UpdateTrait<V: DataTrait + ?Sized, U: DataTrait + ?Sized>: Data {
79    fn get(&self) -> UpdateRef<'_, V, U>;
80    fn insert_ref(&mut self, val: &V);
81    fn insert_val(&mut self, val: &mut V);
82    fn delete(&mut self);
83    fn update_ref(&mut self, upd: &U);
84    fn update_val(&mut self, upd: &mut U);
85}
86
87impl<V, U, VType, UType> UpdateTrait<V, U> for Update<VType, UType>
88where
89    V: DataTrait + ?Sized,
90    U: DataTrait + ?Sized,
91    VType: DBData + Erase<V>,
92    UType: DBData + Erase<U>,
93{
94    fn get(&self) -> UpdateRef<'_, V, U> {
95        match self {
96            Update::Insert(v) => UpdateRef::Insert(v.erase()),
97            Update::Delete => UpdateRef::Delete,
98            Update::Update(u) => UpdateRef::Update(u.erase()),
99        }
100    }
101
102    fn insert_ref(&mut self, val: &V) {
103        *self = Update::Insert(unsafe { val.downcast::<VType>().clone() })
104    }
105
106    fn insert_val(&mut self, val: &mut V) {
107        *self = Update::Insert(take(unsafe { val.downcast_mut::<VType>() }))
108    }
109
110    fn delete(&mut self) {
111        *self = Update::Delete;
112    }
113
114    fn update_ref(&mut self, upd: &U) {
115        *self = Update::Update(unsafe { upd.downcast::<UType>().clone() })
116    }
117
118    fn update_val(&mut self, upd: &mut U) {
119        *self = Update::Update(take(unsafe { upd.downcast_mut::<UType>() }))
120    }
121}
122
// Declares `DynUpdate<VTrait, UTrait>` for `dyn UpdateTrait<VTrait, UTrait>`
// through the crate's `declare_trait_object!` macro.
// NOTE(review): the exact items the macro generates are not visible in this
// file — see the macro definition.
declare_trait_object!(DynUpdate<VTrait, UTrait> = dyn UpdateTrait<VTrait, UTrait>
where
    VTrait: DataTrait + ?Sized,
    UTrait: DataTrait + ?Sized,
);
128
/// Boxed callback that applies a modification of type `U` to a value of type
/// `V` in place.
pub type PatchFunc<V, U> = Box<dyn Fn(&mut V, &U)>;
130
131pub struct InputUpsertFactories<B: IndexedZSet> {
132    pub batch_factories: B::Factories,
133    pub opt_key_factory: &'static dyn Factory<DynOpt<B::Key>>,
134    pub opt_val_factory: &'static dyn Factory<DynOpt<B::Val>>,
135}
136
137impl<B: IndexedZSet> Clone for InputUpsertFactories<B> {
138    fn clone(&self) -> Self {
139        Self {
140            batch_factories: self.batch_factories.clone(),
141            opt_key_factory: self.opt_key_factory,
142            opt_val_factory: self.opt_val_factory,
143        }
144    }
145}
146
147impl<B> InputUpsertFactories<B>
148where
149    B: Batch + IndexedZSet,
150{
151    pub fn new<KType, VType>() -> Self
152    where
153        KType: DBData + Erase<B::Key>,
154        VType: DBData + Erase<B::Val>,
155    {
156        Self {
157            batch_factories: BatchReaderFactories::new::<KType, VType, ZWeight>(),
158            opt_key_factory: WithFactory::<Option<KType>>::FACTORY,
159            opt_val_factory: WithFactory::<Option<VType>>::FACTORY,
160        }
161    }
162}
163
164pub struct InputUpsertWithWaterlineFactories<B: IndexedZSet, E: DataTrait + ?Sized> {
165    pub batch_factories: B::Factories,
166    pub opt_key_factory: &'static dyn Factory<DynOpt<B::Key>>,
167    pub opt_val_factory: &'static dyn Factory<DynOpt<B::Val>>,
168    pub val_factory: &'static dyn Factory<B::Val>,
169    errors_factory: <OrdZSet<E> as BatchReader>::Factories,
170}
171
172impl<B: IndexedZSet, E: DataTrait + ?Sized> Clone for InputUpsertWithWaterlineFactories<B, E> {
173    fn clone(&self) -> Self {
174        Self {
175            batch_factories: self.batch_factories.clone(),
176            opt_key_factory: self.opt_key_factory,
177            opt_val_factory: self.opt_val_factory,
178            val_factory: self.val_factory,
179            errors_factory: self.errors_factory.clone(),
180        }
181    }
182}
183
184impl<B, E> InputUpsertWithWaterlineFactories<B, E>
185where
186    B: Batch + IndexedZSet,
187    E: DataTrait + ?Sized,
188{
189    pub fn new<KType, VType, EType>() -> Self
190    where
191        KType: DBData + Erase<B::Key>,
192        VType: DBData + Erase<B::Val>,
193        EType: DBData + Erase<E>,
194    {
195        Self {
196            batch_factories: BatchReaderFactories::new::<KType, VType, ZWeight>(),
197            opt_key_factory: WithFactory::<Option<KType>>::FACTORY,
198            opt_val_factory: WithFactory::<Option<VType>>::FACTORY,
199            val_factory: WithFactory::<VType>::FACTORY,
200            errors_factory: BatchReaderFactories::new::<EType, (), ZWeight>(),
201        }
202    }
203}
204
205impl<K, V, U> Stream<RootCircuit, Vec<Box<DynPairs<K, DynUpdate<V, U>>>>>
206where
207    K: DataTrait + ?Sized,
208    V: DataTrait + ?Sized,
209    U: DataTrait + ?Sized,
210{
211    /// Convert an input stream of upserts into a stream of updates to a
212    /// relation.
213    ///
214    /// The input stream carries changes to a key/value map in the form of
215    /// _upserts_.  An upsert assigns a new value to a key (or deletes the key
216    /// from the map) without explicitly removing the old value, if any.  The
217    /// operator converts upserts into batches of updates, which is the input
218    /// format of most DBSP operators.
219    ///
220    /// The operator assumes that the input vector is sorted by key; however,
221    /// unlike the [`Stream::upsert`] operator it allows the vector to
222    /// contain multiple updates per key.  Updates are applied one by one in
223    /// order, and the output of the operator reflects cumulative effect of
224    /// the updates.  Additionally, unlike the [`Stream::upsert`] operator,
225    /// which only supports inserts, which overwrite the entire value with a
226    /// new value, and deletions, this operator also supports updates that
227    /// modify the contents of a value, e.g., overwriting some of its
228    /// fields.  Type argument `U` defines the format of modifications,
229    /// and the `patch_func` function applies update of type `U` to a value of
230    /// type `V`.
231    ///
232    /// This is a stateful operator that internally maintains the trace of the
233    /// collection.
234    pub fn input_upsert<B>(
235        &self,
236        persistent_id: Option<&str>,
237        factories: &InputUpsertFactories<B>,
238        patch_func: PatchFunc<V, U>,
239    ) -> Stream<RootCircuit, B>
240    where
241        B: IndexedZSet<Key = K, Val = V>,
242    {
243        let circuit = self.circuit();
244
245        assert!(
246            self.is_sharded(),
247            "input_upsert operator applied to a non-sharded collection"
248        );
249
250        // We build the following circuit to implement the upsert semantics.
251        // The collection is accumulated into a trace using integrator
252        // (UntimedTraceAppend + Z1Trace = integrator).  The `InputUpsert`
253        // operator evaluates each upsert command in the input stream against
254        // the trace and computes a batch of updates to be added to the trace.
255        //
256        // ```text
257        //                               ┌────────────────────────────►
258        //                               │
259        //                               │
260        //  self        ┌───────────┐    │        ┌──────────────────┐  trace
261        // ────────────►│InputUpsert├────┴───────►│UntimedTraceAppend├────┐
262        //              └───────────┘   delta     └──────────────────┘    │
263        //                      ▲                  ▲                      │
264        //                      │                  │                      │
265        //                      │                  │   ┌───────┐          │
266        //                      └──────────────────┴───┤Z1Trace│◄─────────┘
267        //                         z1trace             └───────┘
268        // ```
269        circuit.region("input_upsert", || {
270            let bounds = <TraceBounds<K, V>>::unbounded();
271
272            let z1 = Z1Trace::new(
273                &factories.batch_factories,
274                &factories.batch_factories,
275                false,
276                circuit.root_scope(),
277                bounds.clone(),
278            );
279
280            let (delayed_trace, z1feedback) = circuit.add_feedback_persistent(
281                persistent_id
282                    .map(|name| format!("{name}.integral"))
283                    .as_deref(),
284                z1,
285            );
286
287            delayed_trace.mark_sharded();
288
289            let delta = circuit
290                .add_binary_operator(
291                    <InputUpsert<Spine<B>, U, B>>::new(
292                        factories.batch_factories.clone(),
293                        factories.opt_key_factory,
294                        factories.opt_val_factory,
295                        patch_func,
296                    ),
297                    &delayed_trace,
298                    self,
299                )
300                .mark_distinct();
301            delta.mark_sharded();
302            let replay_stream = z1feedback.operator_mut().prepare_replay_stream(&delta);
303
304            let trace = circuit.add_binary_operator_with_preference(
305                UntimedTraceAppend::<Spine<B>>::new(),
306                (&delayed_trace, OwnershipPreference::STRONGLY_PREFER_OWNED),
307                (&delta, OwnershipPreference::PREFER_OWNED),
308            );
309            trace.mark_sharded();
310
311            z1feedback.connect_with_preference(&trace, OwnershipPreference::STRONGLY_PREFER_OWNED);
312
313            register_replay_stream(circuit, &delta, &replay_stream);
314
315            circuit.cache_insert(DelayedTraceId::new(trace.stream_id()), delayed_trace);
316            circuit.cache_insert(TraceId::new(delta.stream_id()), trace);
317            circuit.cache_insert(BoundsId::<B>::new(delta.stream_id()), bounds);
318            delta
319        })
320    }
321
322    // Like `input_upsert`, but additionally tracks a waterline of the input collection and
323    // rejects inputs that are below the waterline.  An input is rejected if the input record
324    // itself is below the waterline or if the existing record it replaces is below the waterline.
325    #[allow(clippy::too_many_arguments)]
326    pub fn input_upsert_with_waterline<B, W, E>(
327        &self,
328        persistent_id: Option<&str>,
329        factories: &InputUpsertWithWaterlineFactories<B, E>,
330        patch_func: PatchFunc<V, U>,
331        init_waterline: Box<dyn Fn() -> Box<W>>,
332        extract_ts: Box<dyn Fn(&B::Key, &B::Val, &mut W)>,
333        least_upper_bound: LeastUpperBoundFunc<W>,
334        filter_func: Box<dyn Fn(&W, &B::Key, &B::Val) -> bool>,
335        report_func: Box<dyn Fn(&W, &B::Key, &B::Val, ZWeight, &mut E)>,
336    ) -> (
337        Stream<RootCircuit, B>,
338        Stream<RootCircuit, OrdZSet<E>>,
339        Stream<RootCircuit, Box<W>>,
340    )
341    where
342        B: IndexedZSet<Key = K, Val = V>,
343        W: DataTrait + Checkpoint + ?Sized,
344        E: DataTrait + ?Sized,
345        Box<W>: Checkpoint + Clone + NumEntries + Rkyv,
346    {
347        let circuit = self.circuit();
348
349        assert!(
350            self.is_sharded(),
351            "input_upsert_with_waterline operator applied to a non-sharded collection"
352        );
353
354        // ```text
355        //                   ┌─────────────────────────────────────────────►
356        //                   │ waterline
357        //             ┌─────┴─────┐
358        //             │ waterline │◄─────────┬────────────────────────────►
359        //             └──────┬────┘          │
360        //                    │ waterline     │
361        //                   Z1               │
362        //  delayed_waterline │               │
363        //                    ▼               │
364        //         ┌─────────────────────┐    │        ┌──────────────────┐  trace
365        // ───────►│InputUpsertWaterline ├────┴───────►│UntimedTraceAppend├────┐
366        //         └──────┬──────────────┘   delta     └──────────────────┘    │
367        //                │          ▲                  ▲                      │
368        //                │          │                  │                      │
369        //                │          │                  │   ┌───────┐          │
370        //                │          └──────────────────┴───┤Z1Trace│◄─────────┘
371        //                │             delayed_trace       └───────┘
372        //                │
373        //                │error stream
374        //                └────────────────────────────────────────────────►
375        // ```
376
377        circuit.region("input_upsert_waterline", || {
378            let bounds = <TraceBounds<K, V>>::unbounded();
379
380            let z1 = Z1Trace::new(
381                &factories.batch_factories,
382                &factories.batch_factories,
383                false,
384                circuit.root_scope(),
385                bounds.clone(),
386            );
387
388            let (delayed_trace, z1feedback) = circuit.add_feedback_persistent(
389                persistent_id
390                    .map(|name| format!("{name}.integral"))
391                    .as_deref(),
392                z1,
393            );
394
395            delayed_trace.mark_sharded();
396
397            let waterline_z1 = Z1::new((init_waterline)());
398
399            let (delayed_waterline, waterline_feedback) = circuit.add_feedback_persistent(
400                persistent_id
401                    .map(|name| format!("{name}.delayed_waterline"))
402                    .as_deref(),
403                waterline_z1,
404            );
405
406            let error_stream_val = RefStreamValue::empty();
407
408            let delta = circuit
409                .add_ternary_operator(
410                    <InputUpsertWithWaterline<Spine<B>, U, B, W, E>>::new(
411                        factories.clone(),
412                        patch_func,
413                        filter_func,
414                        report_func,
415                        error_stream_val.clone(),
416                    ),
417                    &delayed_trace,
418                    self,
419                    &delayed_waterline,
420                )
421                .mark_distinct();
422            delta.mark_sharded();
423            let replay_stream = z1feedback.operator_mut().prepare_replay_stream(&delta);
424
425            let waterline_id = persistent_id.map(|name| format!("{name}.input_waterline"));
426
427            let waterline = delta.dyn_waterline(
428                waterline_id.as_deref(),
429                init_waterline,
430                extract_ts,
431                least_upper_bound,
432            );
433
434            waterline_feedback.connect(&waterline);
435
436            let trace = circuit.add_binary_operator_with_preference(
437                UntimedTraceAppend::<Spine<B>>::new(),
438                (&delayed_trace, OwnershipPreference::STRONGLY_PREFER_OWNED),
439                (&delta, OwnershipPreference::PREFER_OWNED),
440            );
441            trace.mark_sharded();
442
443            z1feedback.connect_with_preference(&trace, OwnershipPreference::STRONGLY_PREFER_OWNED);
444
445            register_replay_stream(circuit, &delta, &replay_stream);
446
447            let error_stream = Stream::with_value(
448                self.circuit().clone(),
449                delta.local_node_id(),
450                error_stream_val,
451            );
452
453            circuit.cache_insert(DelayedTraceId::new(trace.stream_id()), delayed_trace);
454            circuit.cache_insert(TraceId::new(delta.stream_id()), trace);
455            circuit.cache_insert(BoundsId::<B>::new(delta.stream_id()), bounds);
456
457            (delta, error_stream, waterline)
458        })
459    }
460}
461
462pub struct InputUpsert<T, U, B>
463where
464    T: BatchReader,
465    B: Batch,
466    U: DataTrait + ?Sized,
467{
468    batch_factories: B::Factories,
469    opt_key_factory: &'static dyn Factory<DynOpt<B::Key>>,
470    opt_val_factory: &'static dyn Factory<DynOpt<B::Val>>,
471    patch_func: PatchFunc<T::Val, U>,
472
473    // Input batch sizes.
474    input_batch_stats: BatchSizeStats,
475
476    // Output batch sizes.
477    output_batch_stats: BatchSizeStats,
478
479    phantom: PhantomData<B>,
480}
481
482impl<T, U, B> InputUpsert<T, U, B>
483where
484    T: BatchReader,
485    B: Batch,
486    U: DataTrait + ?Sized,
487{
488    pub fn new(
489        batch_factories: B::Factories,
490        opt_key_factory: &'static dyn Factory<DynOpt<B::Key>>,
491        opt_val_factory: &'static dyn Factory<DynOpt<B::Val>>,
492        patch_func: PatchFunc<T::Val, U>,
493    ) -> Self {
494        Self {
495            batch_factories,
496            opt_key_factory,
497            opt_val_factory,
498            patch_func,
499            input_batch_stats: BatchSizeStats::new(),
500            output_batch_stats: BatchSizeStats::new(),
501            phantom: PhantomData,
502        }
503    }
504}
505
506impl<T, U, B> Operator for InputUpsert<T, U, B>
507where
508    T: BatchReader,
509    U: DataTrait + ?Sized,
510    B: Batch,
511{
512    fn name(&self) -> Cow<'static, str> {
513        Cow::from("InputUpsert")
514    }
515
516    fn metadata(&self, meta: &mut OperatorMeta) {
517        meta.extend(metadata! {
518            INPUT_BATCHES_LABEL => self.input_batch_stats.metadata(),
519            OUTPUT_BATCHES_LABEL => self.output_batch_stats.metadata(),
520        });
521    }
522
523    fn fixedpoint(&self, _scope: Scope) -> bool {
524        true
525    }
526}
527
528impl<T, U, B> BinaryOperator<T, Vec<Box<DynPairs<T::Key, DynUpdate<T::Val, U>>>>, B>
529    for InputUpsert<T, U, B>
530where
531    T: ZTrace<Time = ()>,
532    U: DataTrait + ?Sized,
533    B: IndexedZSet<Key = T::Key, Val = T::Val>,
534{
535    #[trace]
536    async fn eval(
537        &mut self,
538        trace: &T,
539        updates: &Vec<Box<DynPairs<T::Key, DynUpdate<T::Val, U>>>>,
540    ) -> B {
541        // Inputs must be sorted by key
542        let mut updates = updates
543            .iter()
544            .filter_map(|updates| {
545                if !updates.is_empty() {
546                    Some((&**updates, 0))
547                } else {
548                    None
549                }
550            })
551            .collect::<Vec<_>>();
552        let n_updates = updates.iter().map(|updates| updates.0.len()).sum();
553        debug_assert!(updates
554            .iter()
555            .all(|updates| updates.0.is_sorted_by(&|u1, u2| u1.fst().cmp(u2.fst()))));
556
557        self.input_batch_stats.add_batch(n_updates);
558
559        let mut key_updates = self.batch_factories.weighted_vals_factory().default_box();
560
561        let mut trace_cursor = trace.cursor();
562
563        let mut builder =
564            B::Builder::with_capacity(&self.batch_factories, n_updates * 2, n_updates * 2);
565
566        // Current key for which we are processing updates.
567        let mut cur_key: Box<DynOpt<T::Key>> = self.opt_key_factory.default_box();
568
569        // Current value associated with the key after applying all processed updates
570        // to it.
571        let mut cur_val: Box<DynOpt<T::Val>> = self.opt_val_factory.default_box();
572
573        while !updates.is_empty() {
574            let (index, key_upd) = updates
575                .iter()
576                .map(|(updates, index)| updates.index(*index))
577                .enumerate()
578                .min_by(|(_a_index, a), (_b_index, b)| a.cmp(b))
579                .unwrap();
580            updates[index].1 += 1;
581            if updates[index].1 >= updates[index].0.len() {
582                updates.remove(index);
583            }
584
585            let (key, upd) = key_upd.split();
586
587            // We finished processing updates for the previous key. Push them to the
588            // builder and generate a retraction for the new key.
589            if cur_key.get() != Some(key) {
590                // Push updates for the previous key to the builder.
591                if let Some(cur_key) = cur_key.get_mut() {
592                    if let Some(val) = cur_val.get_mut() {
593                        key_updates.push_with(&mut |item| {
594                            let (v, w) = item.split_mut();
595
596                            val.move_to(v);
597                            **w = HasOne::one();
598                        });
599                    }
600                    key_updates.consolidate();
601                    if !key_updates.is_empty() {
602                        for pair in key_updates.dyn_iter_mut() {
603                            let (v, d) = pair.split_mut();
604                            builder.push_val_diff_mut(v, d);
605                        }
606                        builder.push_key(cur_key);
607                    }
608                    key_updates.clear();
609                }
610
611                cur_key.from_ref(key);
612                cur_val.set_none();
613
614                // Generate retraction if `key` is present in the trace.
615                if trace_cursor.seek_key_exact(key, None) {
616                    // println!("{}: found key in trace_cursor", Runtime::worker_index());
617                    while trace_cursor.val_valid() {
618                        let weight = **trace_cursor.weight();
619
620                        if !weight.is_zero() {
621                            let val = trace_cursor.val();
622
623                            key_updates.push_with(&mut |item| {
624                                let (v, w) = item.split_mut();
625
626                                val.clone_to(v);
627                                **w = weight.neg()
628                            });
629                            cur_val.from_ref(val);
630                        }
631
632                        trace_cursor.step_val();
633                    }
634                }
635            }
636
637            match upd.get() {
638                UpdateRef::Delete => {
639                    // TODO: if cur_val.is_none(), report missing key.
640                    cur_val.set_none();
641                }
642                UpdateRef::Insert(val) => {
643                    cur_val.from_ref(val);
644                }
645                UpdateRef::Update(upd) => {
646                    if let Some(val) = cur_val.get_mut() {
647                        (self.patch_func)(val, upd);
648                    } else {
649                        // TODO: report missing key.
650                    }
651                }
652            }
653        }
654
655        // Push updates for the last key.
656        if let Some(cur_key) = cur_key.get_mut() {
657            if let Some(val) = cur_val.get_mut() {
658                key_updates.push_with(&mut |item| {
659                    let (v, w) = item.split_mut();
660
661                    val.move_to(v);
662                    **w = HasOne::one();
663                });
664            }
665
666            key_updates.consolidate();
667            if !key_updates.is_empty() {
668                for pair in key_updates.dyn_iter_mut() {
669                    let (v, d) = pair.split_mut();
670                    builder.push_val_diff_mut(v, d);
671                }
672                builder.push_key(cur_key);
673            }
674            key_updates.clear();
675        }
676
677        builder.done()
678    }
679
680    fn input_preference(&self) -> (OwnershipPreference, OwnershipPreference) {
681        (
682            OwnershipPreference::PREFER_OWNED,
683            OwnershipPreference::PREFER_OWNED,
684        )
685    }
686}
687
688pub struct InputUpsertWithWaterline<T, U, B, W, E>
689where
690    T: BatchReader,
691    B: IndexedZSet,
692    U: DataTrait + ?Sized,
693    W: DataTrait + ?Sized,
694    E: DataTrait + ?Sized,
695{
696    factories: InputUpsertWithWaterlineFactories<B, E>,
697    patch_func: PatchFunc<T::Val, U>,
698    filter_func: Box<dyn Fn(&W, &B::Key, &B::Val) -> bool>,
699    report_func: Box<dyn Fn(&W, &B::Key, &B::Val, ZWeight, &mut E)>,
700    error_stream_val: RefStreamValue<OrdZSet<E>>,
701
702    // Input batch sizes.
703    input_batch_stats: BatchSizeStats,
704
705    // Output batch sizes.
706    output_batch_stats: BatchSizeStats,
707
708    phantom: PhantomData<B>,
709}
710
711impl<T, U, B, W, E> InputUpsertWithWaterline<T, U, B, W, E>
712where
713    T: BatchReader,
714    B: IndexedZSet,
715    U: DataTrait + ?Sized,
716    W: DataTrait + ?Sized,
717    E: DataTrait + ?Sized,
718{
719    pub fn new(
720        factories: InputUpsertWithWaterlineFactories<B, E>,
721        patch_func: PatchFunc<T::Val, U>,
722        filter_func: Box<dyn Fn(&W, &B::Key, &B::Val) -> bool>,
723        report_func: Box<dyn Fn(&W, &B::Key, &B::Val, ZWeight, &mut E)>,
724        error_stream_val: RefStreamValue<OrdZSet<E>>,
725    ) -> Self {
726        Self {
727            factories,
728            patch_func,
729            filter_func,
730            report_func,
731            error_stream_val,
732            input_batch_stats: BatchSizeStats::new(),
733            output_batch_stats: BatchSizeStats::new(),
734            phantom: PhantomData,
735        }
736    }
737
738    fn passes_filter(&self, waterline: &W, key: &B::Key, val: &B::Val) -> bool {
739        (self.filter_func)(waterline, key, val)
740    }
741}
742
743impl<T, U, B, W, E> Operator for InputUpsertWithWaterline<T, U, B, W, E>
744where
745    T: BatchReader,
746    U: DataTrait + ?Sized,
747    B: IndexedZSet,
748    W: DataTrait + ?Sized,
749    E: DataTrait + ?Sized,
750{
751    fn name(&self) -> Cow<'static, str> {
752        Cow::from("InputUpsertWithWaterline")
753    }
754
755    fn metadata(&self, meta: &mut OperatorMeta) {
756        meta.extend(metadata! {
757            INPUT_BATCHES_LABEL => self.input_batch_stats.metadata(),
758            OUTPUT_BATCHES_LABEL => self.output_batch_stats.metadata(),
759        });
760    }
761
762    fn fixedpoint(&self, _scope: Scope) -> bool {
763        true
764    }
765}
766
767impl<T, U, B, W, E> TernaryOperator<T, Vec<Box<DynPairs<T::Key, DynUpdate<T::Val, U>>>>, Box<W>, B>
768    for InputUpsertWithWaterline<T, U, B, W, E>
769where
770    T: ZTrace<Time = ()> + Clone,
771    U: DataTrait + ?Sized,
772    B: IndexedZSet<Key = T::Key, Val = T::Val>,
773    W: DataTrait + ?Sized,
774    Box<W>: Clone,
775    E: DataTrait + ?Sized,
776{
777    #[trace]
778    async fn eval(
779        &mut self,
780        trace: Cow<'_, T>,
781        updates: Cow<'_, Vec<Box<DynPairs<T::Key, DynUpdate<T::Val, U>>>>>,
782        waterline: Cow<'_, Box<W>>,
783    ) -> B {
784        // Inputs must be sorted by key
785        let mut updates = updates
786            .iter()
787            .filter_map(|updates| {
788                if !updates.is_empty() {
789                    Some((updates, 0))
790                } else {
791                    None
792                }
793            })
794            .collect::<Vec<_>>();
795        let n_updates = updates.iter().map(|updates| updates.0.len()).sum();
796        debug_assert!(updates
797            .iter()
798            .all(|updates| updates.0.is_sorted_by(&|u1, u2| u1.fst().cmp(u2.fst()))));
799
800        self.input_batch_stats.add_batch(n_updates);
801
802        let mut errors = self
803            .factories
804            .errors_factory
805            .weighted_items_factory()
806            .default_box();
807
808        let waterline = waterline.deref();
809        let mut key_updates = self
810            .factories
811            .batch_factories
812            .weighted_vals_factory()
813            .default_box();
814
815        let mut trace_cursor = trace.deref().cursor();
816
817        let mut builder = B::Builder::with_capacity(
818            &self.factories.batch_factories,
819            n_updates * 2,
820            n_updates * 2,
821        );
822
823        // Current key for which we are processing updates.
824        let mut cur_key: Box<DynOpt<T::Key>> = self.factories.opt_key_factory.default_box();
825
826        // Current value associated with the key after applying all processed updates
827        // to it.
828        let mut cur_val: Box<DynOpt<T::Val>> = self.factories.opt_val_factory.default_box();
829        let mut tmp_val: Box<T::Val> = self.factories.val_factory.default_box();
830
831        // Set to true when the value associated with the current key doesn't
832        // satisfy `val_filter`, hence refuse to remove this value and process
833        // all updates for this key.
834        let mut skip_key = false;
835
836        while !updates.is_empty() {
837            let (index, key_upd) = updates
838                .iter()
839                .map(|(updates, index)| updates.index(*index))
840                .enumerate()
841                .min_by(|(_a_index, a), (_b_index, b)| a.cmp(b))
842                .unwrap();
843            updates[index].1 += 1;
844            if updates[index].1 >= updates[index].0.len() {
845                updates.remove(index);
846            }
847
848            let (key, upd) = key_upd.split();
849
850            // We finished processing updates for the previous key. Push them to the
851            // builder and generate a retraction for the new key.
852            if cur_key.get() != Some(key) {
853                // Push updates for the previous key to the builder.
854                if let Some(cur_key) = cur_key.get_mut() {
855                    if let Some(val) = cur_val.get_mut() {
856                        key_updates.push_with(&mut |item| {
857                            let (v, w) = item.split_mut();
858
859                            val.move_to(v);
860                            **w = HasOne::one();
861                        });
862                    }
863                    key_updates.consolidate();
864                    if !key_updates.is_empty() {
865                        for pair in key_updates.dyn_iter_mut() {
866                            let (v, d) = pair.split_mut();
867                            builder.push_val_diff_mut(v, d);
868                        }
869                        builder.push_key(cur_key);
870                    }
871                    key_updates.clear();
872                }
873
874                skip_key = false;
875                cur_key.from_ref(key);
876                cur_val.set_none();
877
878                // Generate retraction if `key` is present in the trace.
879                if trace_cursor.seek_key_exact(key, None) {
880                    // println!("{}: found key in trace_cursor", Runtime::worker_index());
881                    while trace_cursor.val_valid() {
882                        let weight = **trace_cursor.weight();
883
884                        if !weight.is_zero() {
885                            let val = trace_cursor.val();
886
887                            if self.passes_filter(waterline, key, val) {
888                                key_updates.push_with(&mut |item| {
889                                    let (v, w) = item.split_mut();
890
891                                    val.clone_to(v);
892                                    **w = weight.neg()
893                                });
894                                cur_val.from_ref(val);
895                            } else {
896                                skip_key = true;
897                                errors.push_with(&mut |item| {
898                                    let (kv, err_weight) = item.split_mut();
899                                    **err_weight = HasOne::one();
900                                    (self.report_func)(
901                                        waterline,
902                                        key,
903                                        val,
904                                        weight.neg(),
905                                        kv.fst_mut(),
906                                    );
907                                });
908                            }
909                        }
910
911                        trace_cursor.step_val();
912                    }
913                }
914            }
915
916            if !skip_key {
917                match upd.get() {
918                    UpdateRef::Delete => {
919                        // TODO: if cur_val.is_none(), report missing key.
920                        cur_val.set_none();
921                    }
922                    UpdateRef::Insert(val) => {
923                        if self.passes_filter(waterline, key, val) {
924                            cur_val.from_ref(val);
925                        } else {
926                            errors.push_with(&mut |item| {
927                                let (kv, err_weight) = item.split_mut();
928                                **err_weight = HasOne::one();
929                                (self.report_func)(waterline, key, val, 1, kv.fst_mut());
930                            });
931                        }
932                    }
933                    UpdateRef::Update(upd) => {
934                        if let Some(val) = cur_val.get_mut() {
935                            val.clone_to(&mut tmp_val);
936                            (self.patch_func)(&mut tmp_val, upd);
937                            if !self.passes_filter(waterline, key, &tmp_val) {
938                                errors.push_with(&mut |item| {
939                                    let (kv, err_weight) = item.split_mut();
940                                    **err_weight = HasOne::one();
941                                    (self.report_func)(waterline, key, &tmp_val, 1, kv.fst_mut());
942                                });
943                            } else {
944                                tmp_val.clone_to(val);
945                            }
946                        } else {
947                            // TODO: report missing key.
948                        }
949                    }
950                }
951            }
952        }
953
954        // Push updates for the last key.
955        if let Some(cur_key) = cur_key.get_mut() {
956            if let Some(val) = cur_val.get_mut() {
957                key_updates.push_with(&mut |item| {
958                    let (v, w) = item.split_mut();
959
960                    val.move_to(v);
961                    **w = HasOne::one();
962                });
963            }
964
965            key_updates.consolidate();
966            if !key_updates.is_empty() {
967                for pair in key_updates.dyn_iter_mut() {
968                    let (v, d) = pair.split_mut();
969                    builder.push_val_diff_mut(v, d);
970                }
971                builder.push_key(cur_key);
972            }
973            key_updates.clear();
974        }
975
976        let errors = <OrdZSet<E>>::dyn_from_tuples(&self.factories.errors_factory, (), &mut errors);
977        self.error_stream_val.put(errors);
978
979        let result = builder.done();
980        self.output_batch_stats.add_batch(result.len());
981        result
982    }
983
984    fn input_preference(
985        &self,
986    ) -> (
987        OwnershipPreference,
988        OwnershipPreference,
989        OwnershipPreference,
990    ) {
991        (
992            OwnershipPreference::PREFER_OWNED,
993            OwnershipPreference::PREFER_OWNED,
994            OwnershipPreference::INDIFFERENT,
995        )
996    }
997}