dbsp/operator/
plus.rs

1//! Binary plus and minus operators.
2
3use crate::{
4    algebra::{AddAssignByRef, AddByRef, NegByRef},
5    circuit::{
6        operator_traits::{BinaryOperator, Operator},
7        Circuit, OwnershipPreference, Scope, Stream,
8    },
9};
10use std::{borrow::Cow, marker::PhantomData, ops::Neg};
11
12impl<C, D> Stream<C, D>
13where
14    C: Circuit,
15    D: AddByRef + AddAssignByRef + Clone + 'static,
16{
17    /// Apply the [`Plus`] operator to `self` and `other`.
18    /// Adding two indexed Z-sets adds the weights of matching key-value pairs.
19    ///
20    /// The stream type's addition operation must be commutative.
21    ///
22    /// # Examples
23    ///
24    /// ```
25    /// # use dbsp::{
26    /// #   operator::Generator,
27    /// #   Circuit, RootCircuit,
28    /// # };
29    /// let circuit = RootCircuit::build(move |circuit| {
30    ///     // Stream of non-negative values: 0, 1, 2, ...
31    ///     let mut n = 0;
32    ///     let source1 = circuit.add_source(Generator::new(move || {
33    ///         let res = n;
34    ///         n += 1;
35    ///         res
36    ///     }));
37    ///     // Stream of non-positive values: 0, -1, -2, ...
38    ///     let mut n = 0;
39    ///     let source2 = circuit.add_source(Generator::new(move || {
40    ///         let res = n;
41    ///         n -= 1;
42    ///         res
43    ///     }));
44    ///     // Compute pairwise sums of values in the stream; the output stream will contain zeros.
45    ///     source1.plus(&source2).inspect(|n| assert_eq!(*n, 0));
46    ///     Ok(())
47    /// })
48    /// .unwrap()
49    /// .0;
50    ///
51    /// # for _ in 0..5 {
52    /// #     circuit.transaction().unwrap();
53    /// # }
54    /// ```
55    #[track_caller]
56    pub fn plus(&self, other: &Stream<C, D>) -> Stream<C, D> {
57        // If both inputs are properly sharded then the sum of those inputs will be
58        // sharded
59        if self.has_sharded_version() && other.has_sharded_version() {
60            self.circuit()
61                .add_binary_operator(
62                    Plus::new(),
63                    &self.try_sharded_version(),
64                    &other.try_sharded_version(),
65                )
66                .mark_sharded()
67        } else {
68            self.circuit().add_binary_operator(Plus::new(), self, other)
69        }
70    }
71}
72
73impl<C, D> Stream<C, D>
74where
75    C: Circuit,
76    D: AddByRef + AddAssignByRef + Neg<Output = D> + NegByRef + Clone + 'static,
77{
78    /// Apply the [`Minus`] operator to `self` and `other`.
79    /// Subtracting two indexed Z-sets subtracts the weights of matching
80    /// key-value pairs.
81    #[track_caller]
82    pub fn minus(&self, other: &Stream<C, D>) -> Stream<C, D> {
83        // If both inputs are properly sharded then the difference of those inputs will
84        // be sharded
85        if self.has_sharded_version() && other.has_sharded_version() {
86            self.circuit()
87                .add_binary_operator(
88                    Minus::new(),
89                    &self.try_sharded_version(),
90                    &other.try_sharded_version(),
91                )
92                .mark_sharded()
93        } else {
94            self.circuit()
95                .add_binary_operator(Minus::new(), self, other)
96        }
97    }
98}
99
/// Operator that adds the values of its two input streams at every timestamp.
///
/// Depending on which input it owns, the operator may accumulate either input
/// into the other, so the stream type's addition operation must be
/// commutative.
pub struct Plus<D> {
    phantom: PhantomData<D>,
}

// Written by hand instead of `#[derive(Default)]`, which would impose an
// unnecessary `D: Default` bound.
impl<D> Default for Plus<D> {
    fn default() -> Self {
        Self::new()
    }
}

impl<D> Plus<D> {
    /// Creates a `Plus` operator; being a `const fn`, it is usable in constant
    /// expressions.
    pub const fn new() -> Self {
        Plus { phantom: PhantomData }
    }
}
123
124impl<D> Operator for Plus<D>
125where
126    D: 'static,
127{
128    fn name(&self) -> Cow<'static, str> {
129        Cow::from("Plus")
130    }
131
132    fn fixedpoint(&self, _scope: Scope) -> bool {
133        true
134    }
135}
136
137impl<D> BinaryOperator<D, D, D> for Plus<D>
138where
139    D: AddByRef + AddAssignByRef + Clone + 'static,
140{
141    async fn eval(&mut self, i1: &D, i2: &D) -> D {
142        i1.add_by_ref(i2)
143    }
144
145    async fn eval_owned_and_ref(&mut self, mut i1: D, i2: &D) -> D {
146        i1.add_assign_by_ref(i2);
147        i1
148    }
149
150    async fn eval_ref_and_owned(&mut self, i1: &D, mut i2: D) -> D {
151        i2.add_assign_by_ref(i1);
152        i2
153    }
154
155    async fn eval_owned(&mut self, i1: D, i2: D) -> D {
156        i1.add_by_ref(&i2)
157    }
158
159    fn input_preference(&self) -> (OwnershipPreference, OwnershipPreference) {
160        (
161            OwnershipPreference::PREFER_OWNED,
162            OwnershipPreference::PREFER_OWNED,
163        )
164    }
165}
166
/// Operator that computes the difference of values in its two input streams at
/// each timestamp.
///
/// The difference is computed as a sum involving the negation of the second
/// input, sometimes in the order `(-i2) + i1`. Therefore the stream type's
/// addition operation must be commutative.
pub struct Minus<D> {
    phantom: PhantomData<D>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add an
// unnecessary `D: Default` bound.
impl<D> Default for Minus<D> {
    fn default() -> Self {
        Self::new()
    }
}

impl<D> Minus<D> {
    /// Creates a new `Minus` operator. This is a `const fn`, so it can be used in
    /// constant expressions.
    pub const fn new() -> Self {
        Self {
            phantom: PhantomData,
        }
    }
}
188
189impl<D> Operator for Minus<D>
190where
191    D: 'static,
192{
193    fn name(&self) -> Cow<'static, str> {
194        Cow::from("Minus")
195    }
196
197    fn fixedpoint(&self, _scope: Scope) -> bool {
198        true
199    }
200}
201
202// TODO: Add `subtract` operation to `GroupValue`, which
203// can be more efficient than negate followed by plus.
204impl<D> BinaryOperator<D, D, D> for Minus<D>
205where
206    D: AddByRef + AddAssignByRef + Neg<Output = D> + NegByRef + Clone + 'static,
207{
208    async fn eval(&mut self, i1: &D, i2: &D) -> D {
209        let mut i2neg = i2.neg_by_ref();
210        i2neg.add_assign_by_ref(i1);
211        i2neg
212    }
213
214    async fn eval_owned_and_ref(&mut self, i1: D, i2: &D) -> D {
215        i1.add_by_ref(&i2.neg_by_ref())
216    }
217
218    async fn eval_ref_and_owned(&mut self, i1: &D, i2: D) -> D {
219        i2.neg().add_by_ref(i1)
220    }
221
222    async fn eval_owned(&mut self, i1: D, i2: D) -> D {
223        i1.add_by_ref(&i2.neg())
224    }
225
226    fn input_preference(&self) -> (OwnershipPreference, OwnershipPreference) {
227        (
228            OwnershipPreference::PREFER_OWNED,
229            OwnershipPreference::PREFER_OWNED,
230        )
231    }
232}
233
#[cfg(test)]
mod test {
    use crate::{
        algebra::HasZero,
        circuit::OwnershipPreference,
        operator::{Generator, Inspect},
        typed_batch::OrdZSet,
        zset, Circuit, RootCircuit,
    };

    #[test]
    fn scalar_plus() {
        let circuit = RootCircuit::build(move |circuit| {
            let mut n = 0;
            let source1 = circuit.add_source(Generator::new(move || {
                let res = n;
                n += 1;
                res
            }));
            let mut n = 100;
            let source2 = circuit.add_source(Generator::new(move || {
                let res = n;
                n -= 1;
                res
            }));
            source1.plus(&source2).inspect(|n| assert_eq!(*n, 100));
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }
    }

    #[test]
    fn scalar_minus() {
        let circuit = RootCircuit::build(move |circuit| {
            // 0, 1, 2, ...
            let mut n: i64 = 0;
            let source1 = circuit.add_source(Generator::new(move || {
                let res = n;
                n += 1;
                res
            }));
            // -100, -99, -98, ...
            let mut n: i64 = -100;
            let source2 = circuit.add_source(Generator::new(move || {
                let res = n;
                n += 1;
                res
            }));
            // At step k the inputs are k and k - 100, so the difference is 100.
            source1.minus(&source2).inspect(|n| assert_eq!(*n, 100));
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn zset_plus() {
        let build_plus_circuit = |circuit: &RootCircuit| {
            let mut s = <OrdZSet<_>>::zero();
            let delta = zset! { 5 => 1};
            let source1 = circuit.add_source(Generator::new(move || {
                s = s.merge(&delta);
                s.clone()
            }));
            let mut s = <OrdZSet<_>>::zero();
            let delta = zset! { 5 => -1};
            let source2 = circuit.add_source(Generator::new(move || {
                s = s.merge(&delta);
                s.clone()
            }));
            source1
                .plus(&source2)
                .inspect(|s| assert_eq!(s, &<OrdZSet<u64>>::zero()));
            (source1, source2)
        };

        let build_minus_circuit = |circuit: &RootCircuit| {
            let mut s = <OrdZSet<_>>::zero();
            let delta = zset! { 5 => 1};
            let source1 = circuit.add_source(Generator::new(move || {
                s = s.merge(&delta);
                s.clone()
            }));
            let mut s = <OrdZSet<_>>::zero();
            let delta = zset! { 5 => 1};
            let source2 = circuit.add_source(Generator::new(move || {
                s = s.merge(&delta);
                s.clone()
            }));
            source1
                .minus(&source2)
                .inspect(|s| assert_eq!(s, &<OrdZSet<_>>::zero()));
            (source1, source2)
        };
        // Allow `Plus` and `Minus` to consume both of their inputs by value.
        let circuit = RootCircuit::build(move |circuit| {
            build_plus_circuit(circuit);
            build_minus_circuit(circuit);
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }

        // Only consume the second input (`source2` / `source4`) by value.
        let circuit = RootCircuit::build(move |circuit| {
            let (source1, _source2) = build_plus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source1,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            let (source3, _source4) = build_minus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source3,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }

        // Only consume the first input (`source1` / `source3`) by value.
        let circuit = RootCircuit::build(move |circuit| {
            let (_source1, source2) = build_plus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source2,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );

            let (_source3, source4) = build_minus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source4,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }

        // Consume both streams by reference.
        let circuit = RootCircuit::build(move |circuit| {
            let (source1, source2) = build_plus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source1,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source2,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );

            let (source3, source4) = build_minus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source3,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source4,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }
    }
}