1use std::ops::{Add, Mul, Neg, Sub};
5
6use ark_ec::CurveGroup;
7use itertools::Itertools;
8
9use crate::{
10 algebra::macros::*,
11 algebra::BatchScalarResult,
12 algebra::{CurvePoint, CurvePointResult, MpcPointResult},
13 fabric::{MpcFabric, ResultValue},
14 network::NetworkPayload,
15 PARTY0,
16};
17
18use super::scalar::{Scalar, ScalarResult};
19
20#[derive(Clone, Debug)]
22pub struct MpcScalarResult<C: CurveGroup> {
23 pub(crate) share: ScalarResult<C>,
25}
26
27impl<C: CurveGroup> From<ScalarResult<C>> for MpcScalarResult<C> {
28 fn from(share: ScalarResult<C>) -> Self {
29 Self { share }
30 }
31}
32
33impl<C: CurveGroup> MpcScalarResult<C> {
36 pub fn new_shared(value: ScalarResult<C>) -> MpcScalarResult<C> {
39 value.into()
40 }
41
42 pub fn id(&self) -> usize {
44 self.share.id
45 }
46
47 pub fn fabric(&self) -> &MpcFabric<C> {
49 self.share.fabric()
50 }
51
52 pub fn open(&self) -> ScalarResult<C> {
54 let (val0, val1) = if self.fabric().party_id() == PARTY0 {
56 let party0_value: ScalarResult<C> =
57 self.fabric().new_network_op(vec![self.id()], |args| {
58 let share: Scalar<C> = args[0].to_owned().into();
59 NetworkPayload::Scalar(share)
60 });
61 let party1_value: ScalarResult<C> = self.fabric().receive_value();
62
63 (party0_value, party1_value)
64 } else {
65 let party0_value: ScalarResult<C> = self.fabric().receive_value();
66 let party1_value: ScalarResult<C> =
67 self.fabric().new_network_op(vec![self.id()], |args| {
68 let share = args[0].to_owned().into();
69 NetworkPayload::Scalar(share)
70 });
71
72 (party0_value, party1_value)
73 };
74
75 &val0 + &val1
77 }
78
79 pub fn open_batch(values: &[MpcScalarResult<C>]) -> Vec<ScalarResult<C>> {
81 if values.is_empty() {
82 return vec![];
83 }
84
85 let n = values.len();
86 let fabric = &values[0].fabric();
87 let my_results = values.iter().map(|v| v.id()).collect_vec();
88 let send_shares_fn = |args: Vec<ResultValue<C>>| {
89 let shares: Vec<Scalar<C>> = args.into_iter().map(Scalar::from).collect();
90 NetworkPayload::ScalarBatch(shares)
91 };
92
93 let (party0_vals, party1_vals) = if values[0].fabric().party_id() == PARTY0 {
95 let party0_vals: BatchScalarResult<C> =
97 fabric.new_network_op(my_results, send_shares_fn);
98 let party1_vals: BatchScalarResult<C> = fabric.receive_value();
99
100 (party0_vals, party1_vals)
101 } else {
102 let party0_vals: BatchScalarResult<C> = fabric.receive_value();
103 let party1_vals: BatchScalarResult<C> =
104 fabric.new_network_op(my_results, send_shares_fn);
105
106 (party0_vals, party1_vals)
107 };
108
109 fabric.new_batch_gate_op(vec![party0_vals.id, party1_vals.id], n, move |args| {
111 let party0_vals: Vec<Scalar<C>> = args[0].to_owned().into();
112 let party1_vals: Vec<Scalar<C>> = args[1].to_owned().into();
113
114 let mut results = Vec::with_capacity(n);
115 for i in 0..n {
116 results.push(ResultValue::Scalar(party0_vals[i] + party1_vals[i]));
117 }
118
119 results
120 })
121 }
122
123 pub fn to_scalar(&self) -> ScalarResult<C> {
125 self.share.clone()
126 }
127}
128
129impl<C: CurveGroup> Add<&Scalar<C>> for &MpcScalarResult<C> {
136 type Output = MpcScalarResult<C>;
137
138 fn add(self, rhs: &Scalar<C>) -> Self::Output {
140 let rhs = *rhs;
141 let party_id = self.fabric().party_id();
142
143 self.fabric()
144 .new_gate_op(vec![self.id()], move |args| {
145 let lhs_share: Scalar<C> = args[0].to_owned().into();
147 if party_id == PARTY0 {
148 ResultValue::Scalar(lhs_share + rhs)
149 } else {
150 ResultValue::Scalar(lhs_share)
151 }
152 })
153 .into()
154 }
155}
156impl_borrow_variants!(MpcScalarResult<C>, Add, add, +, Scalar<C>, C: CurveGroup);
157impl_commutative!(MpcScalarResult<C>, Add, add, +, Scalar<C>, C: CurveGroup);
158
159impl<C: CurveGroup> Add<&ScalarResult<C>> for &MpcScalarResult<C> {
160 type Output = MpcScalarResult<C>;
161
162 fn add(self, rhs: &ScalarResult<C>) -> Self::Output {
164 let party_id = self.fabric().party_id();
165 self.fabric()
166 .new_gate_op(vec![self.id(), rhs.id], move |mut args| {
167 let lhs: Scalar<C> = args.remove(0).into();
169 let rhs: Scalar<C> = args.remove(0).into();
170
171 if party_id == PARTY0 {
172 ResultValue::Scalar(lhs + rhs)
173 } else {
174 ResultValue::Scalar(lhs)
175 }
176 })
177 .into()
178 }
179}
180impl_borrow_variants!(MpcScalarResult<C>, Add, add, +, ScalarResult<C>, C: CurveGroup);
181impl_commutative!(MpcScalarResult<C>, Add, add, +, ScalarResult<C>, C: CurveGroup);
182
183impl<C: CurveGroup> Add<&MpcScalarResult<C>> for &MpcScalarResult<C> {
184 type Output = MpcScalarResult<C>;
185
186 fn add(self, rhs: &MpcScalarResult<C>) -> Self::Output {
187 self.fabric()
188 .new_gate_op(vec![self.id(), rhs.id()], |args| {
189 let lhs: Scalar<C> = args[0].to_owned().into();
191 let rhs: Scalar<C> = args[1].to_owned().into();
192
193 ResultValue::Scalar(lhs + rhs)
194 })
195 .into()
196 }
197}
198impl_borrow_variants!(MpcScalarResult<C>, Add, add, +, MpcScalarResult<C>, C: CurveGroup);
199
200impl<C: CurveGroup> MpcScalarResult<C> {
201 pub fn batch_add(
203 a: &[MpcScalarResult<C>],
204 b: &[MpcScalarResult<C>],
205 ) -> Vec<MpcScalarResult<C>> {
206 assert_eq!(
207 a.len(),
208 b.len(),
209 "batch_add: a and b must be the same length"
210 );
211
212 let n = a.len();
213 let fabric = a[0].fabric();
214 let ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();
215
216 let scalars = fabric.new_batch_gate_op(ids, n , move |args| {
217 let scalars = args.into_iter().map(Scalar::from).collect_vec();
219 let (a_res, b_res) = scalars.split_at(n);
220
221 a_res
223 .iter()
224 .zip(b_res.iter())
225 .map(|(a, b)| ResultValue::Scalar(a + b))
226 .collect_vec()
227 });
228
229 scalars.into_iter().map(|s| s.into()).collect_vec()
230 }
231
232 pub fn batch_add_public(
234 a: &[MpcScalarResult<C>],
235 b: &[ScalarResult<C>],
236 ) -> Vec<MpcScalarResult<C>> {
237 assert_eq!(
238 a.len(),
239 b.len(),
240 "batch_add_public: a and b must be the same length"
241 );
242
243 let n = a.len();
244 let fabric = a[0].fabric();
245 let ids = a
246 .iter()
247 .map(|v| v.id())
248 .chain(b.iter().map(|v| v.id()))
249 .collect_vec();
250
251 let party_id = fabric.party_id();
252 let scalars: Vec<ScalarResult<C>> =
253 fabric.new_batch_gate_op(ids, n , move |args| {
254 if party_id == PARTY0 {
255 let mut res: Vec<ResultValue<C>> = Vec::with_capacity(n);
256
257 for i in 0..n {
258 let lhs: Scalar<C> = args[i].to_owned().into();
259 let rhs: Scalar<C> = args[i + n].to_owned().into();
260
261 res.push(ResultValue::Scalar(lhs + rhs));
262 }
263
264 res
265 } else {
266 args[..n].to_vec()
267 }
268 });
269
270 scalars.into_iter().map(|s| s.into()).collect_vec()
271 }
272}
273
274impl<C: CurveGroup> Sub<&Scalar<C>> for &MpcScalarResult<C> {
277 type Output = MpcScalarResult<C>;
278
279 fn sub(self, rhs: &Scalar<C>) -> Self::Output {
281 let rhs = *rhs;
282 let party_id = self.fabric().party_id();
283
284 if party_id == PARTY0 {
285 &self.share - rhs
286 } else {
287 &self.share - Scalar::zero()
289 }
290 .into()
291 }
292}
293impl_borrow_variants!(MpcScalarResult<C>, Sub, sub, -, Scalar<C>, C: CurveGroup);
294
295impl<C: CurveGroup> Sub<&MpcScalarResult<C>> for &Scalar<C> {
296 type Output = MpcScalarResult<C>;
297
298 fn sub(self, rhs: &MpcScalarResult<C>) -> Self::Output {
300 let party_id = rhs.fabric().party_id();
301
302 if party_id == PARTY0 {
303 self - &rhs.share
304 } else {
305 Scalar::zero() - &rhs.share
307 }
308 .into()
309 }
310}
311
312impl<C: CurveGroup> Sub<&ScalarResult<C>> for &MpcScalarResult<C> {
313 type Output = MpcScalarResult<C>;
314
315 fn sub(self, rhs: &ScalarResult<C>) -> Self::Output {
317 let party_id = self.fabric().party_id();
318
319 if party_id == PARTY0 {
320 &self.share - rhs
321 } else {
322 self.share.clone() + Scalar::zero()
324 }
325 .into()
326 }
327}
328impl_borrow_variants!(MpcScalarResult<C>, Sub, sub, -, ScalarResult<C>, C: CurveGroup);
329
330impl<C: CurveGroup> Sub<&MpcScalarResult<C>> for &ScalarResult<C> {
331 type Output = MpcScalarResult<C>;
332
333 fn sub(self, rhs: &MpcScalarResult<C>) -> Self::Output {
335 let party_id = rhs.fabric().party_id();
336
337 if party_id == PARTY0 {
338 self - &rhs.share
339 } else {
340 Scalar::zero() - rhs.share.clone()
342 }
343 .into()
344 }
345}
346impl_borrow_variants!(ScalarResult<C>, Sub, sub, -, MpcScalarResult<C>, Output=MpcScalarResult<C>, C: CurveGroup);
347
348impl<C: CurveGroup> Sub<&MpcScalarResult<C>> for &MpcScalarResult<C> {
349 type Output = MpcScalarResult<C>;
350
351 fn sub(self, rhs: &MpcScalarResult<C>) -> Self::Output {
352 self.fabric()
353 .new_gate_op(vec![self.id(), rhs.id()], |args| {
354 let lhs: Scalar<C> = args[0].to_owned().into();
356 let rhs: Scalar<C> = args[1].to_owned().into();
357
358 ResultValue::Scalar(lhs - rhs)
359 })
360 .into()
361 }
362}
363impl_borrow_variants!(MpcScalarResult<C>, Sub, sub, -, MpcScalarResult<C>, C: CurveGroup);
364
365impl<C: CurveGroup> MpcScalarResult<C> {
366 pub fn batch_sub(
368 a: &[MpcScalarResult<C>],
369 b: &[MpcScalarResult<C>],
370 ) -> Vec<MpcScalarResult<C>> {
371 assert_eq!(
372 a.len(),
373 b.len(),
374 "batch_sub: a and b must be the same length"
375 );
376
377 let n = a.len();
378 let fabric = a[0].fabric();
379 let ids = a
380 .iter()
381 .map(|v| v.id())
382 .chain(b.iter().map(|v| v.id()))
383 .collect_vec();
384
385 let scalars: Vec<ScalarResult<C>> =
386 fabric.new_batch_gate_op(ids, n , move |args| {
387 let scalars = args.into_iter().map(Scalar::from).collect_vec();
389 let (a_res, b_res) = scalars.split_at(n);
390
391 a_res
393 .iter()
394 .zip(b_res.iter())
395 .map(|(a, b)| ResultValue::Scalar(a - b))
396 .collect_vec()
397 });
398
399 scalars.into_iter().map(|s| s.into()).collect_vec()
400 }
401
402 pub fn batch_sub_public(
405 a: &[MpcScalarResult<C>],
406 b: &[ScalarResult<C>],
407 ) -> Vec<MpcScalarResult<C>> {
408 assert_eq!(
409 a.len(),
410 b.len(),
411 "batch_sub_public: a and b must be the same length"
412 );
413
414 let n = a.len();
415 let fabric = a[0].fabric();
416 let ids = a
417 .iter()
418 .map(|v| v.id())
419 .chain(b.iter().map(|v| v.id()))
420 .collect_vec();
421
422 let party_id = fabric.party_id();
423 let scalars = fabric.new_batch_gate_op(ids, n , move |args| {
424 if party_id == PARTY0 {
425 let mut res: Vec<ResultValue<C>> = Vec::with_capacity(n);
426
427 for i in 0..n {
428 let lhs: Scalar<C> = args[i].to_owned().into();
429 let rhs: Scalar<C> = args[i + n].to_owned().into();
430
431 res.push(ResultValue::Scalar(lhs - rhs));
432 }
433
434 res
435 } else {
436 args[..n].to_vec()
437 }
438 });
439
440 scalars.into_iter().map(|s| s.into()).collect_vec()
441 }
442}
443
444impl<C: CurveGroup> Neg for &MpcScalarResult<C> {
447 type Output = MpcScalarResult<C>;
448
449 fn neg(self) -> Self::Output {
450 self.fabric()
451 .new_gate_op(vec![self.id()], |args| {
452 let lhs: Scalar<C> = args[0].to_owned().into();
454 ResultValue::Scalar(-lhs)
455 })
456 .into()
457 }
458}
459impl_borrow_variants!(MpcScalarResult<C>, Neg, neg, -, C: CurveGroup);
460
461impl<C: CurveGroup> MpcScalarResult<C> {
462 pub fn batch_neg(values: &[MpcScalarResult<C>]) -> Vec<MpcScalarResult<C>> {
464 if values.is_empty() {
465 return vec![];
466 }
467
468 let n = values.len();
469 let fabric = values[0].fabric();
470 let ids = values.iter().map(|v| v.id()).collect_vec();
471
472 let scalars = fabric.new_batch_gate_op(ids, n , move |args| {
473 let scalars = args.into_iter().map(Scalar::from).collect_vec();
475
476 scalars
478 .iter()
479 .map(|a| ResultValue::Scalar(-a))
480 .collect_vec()
481 });
482
483 scalars.into_iter().map(|s| s.into()).collect_vec()
484 }
485}
486
487impl<C: CurveGroup> Mul<&Scalar<C>> for &MpcScalarResult<C> {
490 type Output = MpcScalarResult<C>;
491
492 fn mul(self, rhs: &Scalar<C>) -> Self::Output {
493 let rhs = *rhs;
494 self.fabric()
495 .new_gate_op(vec![self.id()], move |args| {
496 let lhs: Scalar<C> = args[0].to_owned().into();
498 ResultValue::Scalar(lhs * rhs)
499 })
500 .into()
501 }
502}
503impl_borrow_variants!(MpcScalarResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
504impl_commutative!(MpcScalarResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
505
506impl<C: CurveGroup> Mul<&ScalarResult<C>> for &MpcScalarResult<C> {
507 type Output = MpcScalarResult<C>;
508
509 fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
510 self.fabric()
511 .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
512 let lhs: Scalar<C> = args.remove(0).into();
514 let rhs: Scalar<C> = args.remove(0).into();
515
516 ResultValue::Scalar(lhs * rhs)
517 })
518 .into()
519 }
520}
521impl_borrow_variants!(MpcScalarResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
522impl_commutative!(MpcScalarResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
523
524impl<C: CurveGroup> Mul<&MpcScalarResult<C>> for &MpcScalarResult<C> {
526 type Output = MpcScalarResult<C>;
527
528 fn mul(self, rhs: &MpcScalarResult<C>) -> Self::Output {
529 let (a, b, c) = self.fabric().next_beaver_triple();
531
532 let masked_lhs = self - &a;
534 let masked_rhs = rhs - &b;
535
536 let d_open = masked_lhs.open();
537 let e_open = masked_rhs.open();
538
539 &d_open * &b + &e_open * &a + c + &d_open * &e_open
541 }
542}
543impl_borrow_variants!(MpcScalarResult<C>, Mul, mul, *, MpcScalarResult<C>, C: CurveGroup);
544
545impl<C: CurveGroup> MpcScalarResult<C> {
546 pub fn batch_mul(
548 a: &[MpcScalarResult<C>],
549 b: &[MpcScalarResult<C>],
550 ) -> Vec<MpcScalarResult<C>> {
551 let n = a.len();
552 assert_eq!(
553 a.len(),
554 b.len(),
555 "batch_mul: a and b must be the same length"
556 );
557
558 let fabric = &a[0].fabric();
560 let (beaver_a, beaver_b, beaver_c) = fabric.next_beaver_triple_batch(n);
561
562 let masked_lhs = MpcScalarResult::batch_sub(a, &beaver_a);
564 let masked_rhs = MpcScalarResult::batch_sub(b, &beaver_b);
565
566 let all_masks = [masked_lhs, masked_rhs].concat();
567 let opened_values = MpcScalarResult::open_batch(&all_masks);
568 let (d_open, e_open) = opened_values.split_at(n);
569
570 let de = ScalarResult::batch_mul(d_open, e_open);
572 let db = MpcScalarResult::batch_mul_public(&beaver_b, d_open);
573 let ea = MpcScalarResult::batch_mul_public(&beaver_a, e_open);
574
575 let de_plus_db = MpcScalarResult::batch_add_public(&db, &de);
577 let ea_plus_c = MpcScalarResult::batch_add(&ea, &beaver_c);
578 MpcScalarResult::batch_add(&de_plus_db, &ea_plus_c)
579 }
580
581 pub fn batch_mul_public(
584 a: &[MpcScalarResult<C>],
585 b: &[ScalarResult<C>],
586 ) -> Vec<MpcScalarResult<C>> {
587 assert_eq!(
588 a.len(),
589 b.len(),
590 "batch_mul_public: a and b must be the same length"
591 );
592
593 let n = a.len();
594 let fabric = a[0].fabric();
595 let ids = a
596 .iter()
597 .map(|v| v.id())
598 .chain(b.iter().map(|v| v.id))
599 .collect_vec();
600
601 let scalars: Vec<ScalarResult<C>> =
602 fabric.new_batch_gate_op(ids, n , move |args| {
603 let mut res: Vec<ResultValue<C>> = Vec::with_capacity(n);
604 for i in 0..n {
605 let lhs: Scalar<C> = args[i].to_owned().into();
606 let rhs: Scalar<C> = args[i + n].to_owned().into();
607
608 res.push(ResultValue::Scalar(lhs * rhs));
609 }
610
611 res
612 });
613
614 scalars.into_iter().map(|s| s.into()).collect_vec()
615 }
616}
617
618impl<C: CurveGroup> Mul<&MpcScalarResult<C>> for &CurvePoint<C> {
621 type Output = MpcPointResult<C>;
622
623 fn mul(self, rhs: &MpcScalarResult<C>) -> Self::Output {
624 let self_owned = *self;
625 rhs.fabric()
626 .new_gate_op(vec![rhs.id()], move |mut args| {
627 let rhs: Scalar<C> = args.remove(0).into();
628
629 ResultValue::Point(self_owned * rhs)
630 })
631 .into()
632 }
633}
634impl_commutative!(CurvePoint<C>, Mul, mul, *, MpcScalarResult<C>, Output=MpcPointResult<C>, C: CurveGroup);
635
636impl<C: CurveGroup> Mul<&MpcScalarResult<C>> for &CurvePointResult<C> {
637 type Output = MpcPointResult<C>;
638
639 fn mul(self, rhs: &MpcScalarResult<C>) -> Self::Output {
640 self.fabric
641 .new_gate_op(vec![self.id(), rhs.id()], |mut args| {
642 let lhs: CurvePoint<C> = args.remove(0).into();
643 let rhs: Scalar<C> = args.remove(0).into();
644
645 ResultValue::Point(lhs * rhs)
646 })
647 .into()
648 }
649}
650impl_borrow_variants!(CurvePointResult<C>, Mul, mul, *, MpcScalarResult<C>, Output=MpcPointResult<C>, C: CurveGroup);
651impl_commutative!(CurvePointResult<C>, Mul, mul, *, MpcScalarResult<C>, Output=MpcPointResult<C>, C: CurveGroup);
652
#[cfg(test)]
mod test {
    use rand::thread_rng;

    use crate::{algebra::scalar::Scalar, test_helpers::execute_mock_mpc, PARTY0};

    /// Checks subtraction between a shared scalar and a public scalar result
    /// in both operand orders
    #[tokio::test]
    async fn test_sub() {
        let mut rng = thread_rng();
        let shared_plain = Scalar::random(&mut rng);
        let public_plain = Scalar::random(&mut rng);

        let (outcome, _) = execute_mock_mpc(|fabric| async move {
            let shared = fabric.share_scalar(shared_plain, PARTY0).mpc_share();
            let public = fabric.allocate_scalar(public_plain);

            let shared_minus_public = (&shared - &public).open().await;
            let public_minus_shared = (&public - &shared).open().await;

            (
                shared_minus_public == shared_plain - public_plain,
                public_minus_shared == public_plain - shared_plain,
            )
        })
        .await;

        let (forward_ok, reverse_ok) = outcome;
        assert!(forward_ok);
        assert!(reverse_ok)
    }
}