1use crate::{
4 algebra::{AddAssignByRef, AddByRef, NegByRef},
5 circuit::{
6 operator_traits::{BinaryOperator, Operator},
7 Circuit, OwnershipPreference, Scope, Stream,
8 },
9};
10use std::{borrow::Cow, marker::PhantomData, ops::Neg};
11
12impl<C, D> Stream<C, D>
13where
14 C: Circuit,
15 D: AddByRef + AddAssignByRef + Clone + 'static,
16{
17 #[track_caller]
56 pub fn plus(&self, other: &Stream<C, D>) -> Stream<C, D> {
57 if self.has_sharded_version() && other.has_sharded_version() {
60 self.circuit()
61 .add_binary_operator(
62 Plus::new(),
63 &self.try_sharded_version(),
64 &other.try_sharded_version(),
65 )
66 .mark_sharded()
67 } else {
68 self.circuit().add_binary_operator(Plus::new(), self, other)
69 }
70 }
71}
72
73impl<C, D> Stream<C, D>
74where
75 C: Circuit,
76 D: AddByRef + AddAssignByRef + Neg<Output = D> + NegByRef + Clone + 'static,
77{
78 #[track_caller]
82 pub fn minus(&self, other: &Stream<C, D>) -> Stream<C, D> {
83 if self.has_sharded_version() && other.has_sharded_version() {
86 self.circuit()
87 .add_binary_operator(
88 Minus::new(),
89 &self.try_sharded_version(),
90 &other.try_sharded_version(),
91 )
92 .mark_sharded()
93 } else {
94 self.circuit()
95 .add_binary_operator(Minus::new(), self, other)
96 }
97 }
98}
99
/// Binary operator that adds its two inputs of type `D`.
///
/// The operator holds no data; `D` appears only through [`PhantomData`].
pub struct Plus<D> {
    phantom: PhantomData<D>,
}
107
108impl<D> Default for Plus<D> {
109 fn default() -> Self {
110 Self {
111 phantom: PhantomData,
112 }
113 }
114}
115
116impl<D> Plus<D> {
117 pub const fn new() -> Self {
118 Self {
119 phantom: PhantomData,
120 }
121 }
122}
123
124impl<D> Operator for Plus<D>
125where
126 D: 'static,
127{
128 fn name(&self) -> Cow<'static, str> {
129 Cow::from("Plus")
130 }
131
132 fn fixedpoint(&self, _scope: Scope) -> bool {
133 true
134 }
135}
136
137impl<D> BinaryOperator<D, D, D> for Plus<D>
138where
139 D: AddByRef + AddAssignByRef + Clone + 'static,
140{
141 async fn eval(&mut self, i1: &D, i2: &D) -> D {
142 i1.add_by_ref(i2)
143 }
144
145 async fn eval_owned_and_ref(&mut self, mut i1: D, i2: &D) -> D {
146 i1.add_assign_by_ref(i2);
147 i1
148 }
149
150 async fn eval_ref_and_owned(&mut self, i1: &D, mut i2: D) -> D {
151 i2.add_assign_by_ref(i1);
152 i2
153 }
154
155 async fn eval_owned(&mut self, i1: D, i2: D) -> D {
156 i1.add_by_ref(&i2)
157 }
158
159 fn input_preference(&self) -> (OwnershipPreference, OwnershipPreference) {
160 (
161 OwnershipPreference::PREFER_OWNED,
162 OwnershipPreference::PREFER_OWNED,
163 )
164 }
165}
166
/// Binary operator that subtracts its second input from its first.
///
/// The operator holds no data; `D` appears only through [`PhantomData`].
pub struct Minus<D> {
    phantom: PhantomData<D>,
}
172
173impl<D> Default for Minus<D> {
174 fn default() -> Self {
175 Self {
176 phantom: PhantomData,
177 }
178 }
179}
180
181impl<D> Minus<D> {
182 pub const fn new() -> Self {
183 Self {
184 phantom: PhantomData,
185 }
186 }
187}
188
189impl<D> Operator for Minus<D>
190where
191 D: 'static,
192{
193 fn name(&self) -> Cow<'static, str> {
194 Cow::from("Minus")
195 }
196
197 fn fixedpoint(&self, _scope: Scope) -> bool {
198 true
199 }
200}
201
202impl<D> BinaryOperator<D, D, D> for Minus<D>
205where
206 D: AddByRef + AddAssignByRef + Neg<Output = D> + NegByRef + Clone + 'static,
207{
208 async fn eval(&mut self, i1: &D, i2: &D) -> D {
209 let mut i2neg = i2.neg_by_ref();
210 i2neg.add_assign_by_ref(i1);
211 i2neg
212 }
213
214 async fn eval_owned_and_ref(&mut self, i1: D, i2: &D) -> D {
215 i1.add_by_ref(&i2.neg_by_ref())
216 }
217
218 async fn eval_ref_and_owned(&mut self, i1: &D, i2: D) -> D {
219 i2.neg().add_by_ref(i1)
220 }
221
222 async fn eval_owned(&mut self, i1: D, i2: D) -> D {
223 i1.add_by_ref(&i2.neg())
224 }
225
226 fn input_preference(&self) -> (OwnershipPreference, OwnershipPreference) {
227 (
228 OwnershipPreference::PREFER_OWNED,
229 OwnershipPreference::PREFER_OWNED,
230 )
231 }
232}
233
#[cfg(test)]
mod test {
    use crate::{
        algebra::HasZero,
        circuit::OwnershipPreference,
        operator::{Generator, Inspect},
        typed_batch::OrdZSet,
        zset, Circuit, RootCircuit,
    };

    /// Adds a counter rising from 0 to a counter falling from 100; the sum
    /// must be 100 at every step.
    #[test]
    fn scalar_plus() {
        let circuit = RootCircuit::build(move |circuit| {
            let mut n = 0;
            let source1 = circuit.add_source(Generator::new(move || {
                let res = n;
                n += 1;
                res
            }));
            let mut n = 100;
            let source2 = circuit.add_source(Generator::new(move || {
                let res = n;
                n -= 1;
                res
            }));
            source1.plus(&source2).inspect(|n| assert_eq!(*n, 100));
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }
    }

    /// Subtracts a counter rising from 0 from a counter rising from 100; the
    /// difference must be 100 at every step. Scalar counterpart of the
    /// `minus` checks in `zset_plus`.
    #[test]
    fn scalar_minus() {
        let circuit = RootCircuit::build(move |circuit| {
            let mut n = 100;
            let source1 = circuit.add_source(Generator::new(move || {
                let res = n;
                n += 1;
                res
            }));
            let mut n = 0;
            let source2 = circuit.add_source(Generator::new(move || {
                let res = n;
                n += 1;
                res
            }));
            source1.minus(&source2).inspect(|n| assert_eq!(*n, 100));
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }
    }

    /// Exercises `plus` and `minus` on Z-sets under several input-ownership
    /// configurations.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn zset_plus() {
        // Two running Z-sets: source1 accumulates {5 => k} and source2
        // accumulates {5 => -k} after k steps, so their sum is always empty.
        // The sources are returned so callers can attach extra consumers.
        let build_plus_circuit = |circuit: &RootCircuit| {
            let mut s = <OrdZSet<_>>::zero();
            let delta = zset! { 5 => 1};
            let source1 = circuit.add_source(Generator::new(move || {
                s = s.merge(&delta);
                s.clone()
            }));
            let mut s = <OrdZSet<_>>::zero();
            let delta = zset! { 5 => -1};
            let source2 = circuit.add_source(Generator::new(move || {
                s = s.merge(&delta);
                s.clone()
            }));
            source1
                .plus(&source2)
                .inspect(|s| assert_eq!(s, &<OrdZSet<u64>>::zero()));
            (source1, source2)
        };

        // Two identical running Z-sets ({5 => k} after k steps), so their
        // difference is always empty.
        let build_minus_circuit = |circuit: &RootCircuit| {
            let mut s = <OrdZSet<_>>::zero();
            let delta = zset! { 5 => 1};
            let source1 = circuit.add_source(Generator::new(move || {
                s = s.merge(&delta);
                s.clone()
            }));
            let mut s = <OrdZSet<_>>::zero();
            let delta = zset! { 5 => 1};
            let source2 = circuit.add_source(Generator::new(move || {
                s = s.merge(&delta);
                s.clone()
            }));
            source1
                .minus(&source2)
                .inspect(|s| assert_eq!(s, &<OrdZSet<_>>::zero()));
            (source1, source2)
        };

        // Case 1: plus/minus are the only consumers of their inputs.
        let circuit = RootCircuit::build(move |circuit| {
            build_plus_circuit(circuit);
            build_minus_circuit(circuit);
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }

        // Cases 2-4 attach an extra consumer that strongly prefers ownership
        // of one or both inputs. NOTE(review): the intent appears to be forcing
        // plus/minus onto their borrowed-input `eval` paths; confirm against
        // the scheduler's ownership rules.

        // Case 2: the first input also has a strongly-owning consumer.
        let circuit = RootCircuit::build(move |circuit| {
            let (source1, _source2) = build_plus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source1,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            let (source3, _source4) = build_minus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source3,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }

        // Case 3: the second input also has a strongly-owning consumer.
        let circuit = RootCircuit::build(move |circuit| {
            let (_source1, source2) = build_plus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source2,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );

            let (_source3, source4) = build_minus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source4,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }

        // Case 4: both inputs also have strongly-owning consumers.
        let circuit = RootCircuit::build(move |circuit| {
            let (source1, source2) = build_plus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source1,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source2,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );

            let (source3, source4) = build_minus_circuit(circuit);
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source3,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            circuit.add_unary_operator_with_preference(
                Inspect::new(|_| {}),
                &source4,
                OwnershipPreference::STRONGLY_PREFER_OWNED,
            );
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..100 {
            circuit.transaction().unwrap();
        }
    }
}