1use std::ops::{Add, Mul, Neg, Sub};
5
6use itertools::Itertools;
7
8use crate::{
9 algebra::scalar::BatchScalarResult,
10 fabric::{MpcFabric, ResultHandle, ResultValue},
11 network::NetworkPayload,
12 PARTY0,
13};
14
15use super::{
16 macros::{impl_borrow_variants, impl_commutative},
17 mpc_stark_point::MpcStarkPointResult,
18 scalar::{Scalar, ScalarResult},
19 stark_curve::{StarkPoint, StarkPointResult},
20};
21
22#[derive(Clone, Debug)]
24pub struct MpcScalarResult {
25 pub(crate) share: ScalarResult,
27}
28
29impl From<ScalarResult> for MpcScalarResult {
30 fn from(share: ScalarResult) -> Self {
31 Self { share }
32 }
33}
34
35impl MpcScalarResult {
37 pub fn new_shared(value: ScalarResult) -> MpcScalarResult {
39 value.into()
40 }
41
42 pub fn id(&self) -> usize {
44 self.share.id
45 }
46
47 pub fn fabric(&self) -> &MpcFabric {
49 self.share.fabric()
50 }
51
52 pub fn open(&self) -> ResultHandle<Scalar> {
54 let (val0, val1) = if self.fabric().party_id() == PARTY0 {
56 let party0_value: ResultHandle<Scalar> =
57 self.fabric().new_network_op(vec![self.id()], |args| {
58 let share: Scalar = args[0].to_owned().into();
59 NetworkPayload::Scalar(share)
60 });
61 let party1_value: ResultHandle<Scalar> = self.fabric().receive_value();
62
63 (party0_value, party1_value)
64 } else {
65 let party0_value: ResultHandle<Scalar> = self.fabric().receive_value();
66 let party1_value: ResultHandle<Scalar> =
67 self.fabric().new_network_op(vec![self.id()], |args| {
68 let share = args[0].to_owned().into();
69 NetworkPayload::Scalar(share)
70 });
71
72 (party0_value, party1_value)
73 };
74
75 &val0 + &val1
77 }
78
79 pub fn open_batch(values: &[MpcScalarResult]) -> Vec<ScalarResult> {
81 if values.is_empty() {
82 return vec![];
83 }
84
85 let n = values.len();
86 let fabric = &values[0].fabric();
87 let my_results = values.iter().map(|v| v.id()).collect_vec();
88 let send_shares_fn = |args: Vec<ResultValue>| {
89 let shares: Vec<Scalar> = args.into_iter().map(Scalar::from).collect();
90 NetworkPayload::ScalarBatch(shares)
91 };
92
93 let (party0_vals, party1_vals) = if values[0].fabric().party_id() == PARTY0 {
95 let party0_vals: BatchScalarResult = fabric.new_network_op(my_results, send_shares_fn);
97 let party1_vals: BatchScalarResult = fabric.receive_value();
98
99 (party0_vals, party1_vals)
100 } else {
101 let party0_vals: BatchScalarResult = fabric.receive_value();
102 let party1_vals: BatchScalarResult = fabric.new_network_op(my_results, send_shares_fn);
103
104 (party0_vals, party1_vals)
105 };
106
107 fabric.new_batch_gate_op(vec![party0_vals.id, party1_vals.id], n, move |args| {
109 let party0_vals: Vec<Scalar> = args[0].to_owned().into();
110 let party1_vals: Vec<Scalar> = args[1].to_owned().into();
111
112 let mut results = Vec::with_capacity(n);
113 for i in 0..n {
114 results.push(ResultValue::Scalar(party0_vals[i] + party1_vals[i]));
115 }
116
117 results
118 })
119 }
120
121 pub fn to_scalar(&self) -> ScalarResult {
123 self.share.clone()
124 }
125}
126
127impl Add<&Scalar> for &MpcScalarResult {
134 type Output = MpcScalarResult;
135
136 fn add(self, rhs: &Scalar) -> Self::Output {
138 let rhs = *rhs;
139 let party_id = self.fabric().party_id();
140
141 self.fabric()
142 .new_gate_op(vec![self.id()], move |args| {
143 let lhs_share: Scalar = args[0].to_owned().into();
145 if party_id == PARTY0 {
146 ResultValue::Scalar(lhs_share + rhs)
147 } else {
148 ResultValue::Scalar(lhs_share)
149 }
150 })
151 .into()
152 }
153}
154impl_borrow_variants!(MpcScalarResult, Add, add, +, Scalar);
155impl_commutative!(MpcScalarResult, Add, add, +, Scalar);
156
157impl Add<&ScalarResult> for &MpcScalarResult {
158 type Output = MpcScalarResult;
159
160 fn add(self, rhs: &ScalarResult) -> Self::Output {
162 let party_id = self.fabric().party_id();
163 self.fabric()
164 .new_gate_op(vec![self.id(), rhs.id], move |mut args| {
165 let lhs: Scalar = args.remove(0).into();
167 let rhs: Scalar = args.remove(0).into();
168
169 if party_id == PARTY0 {
170 ResultValue::Scalar(lhs + rhs)
171 } else {
172 ResultValue::Scalar(lhs)
173 }
174 })
175 .into()
176 }
177}
178impl_borrow_variants!(MpcScalarResult, Add, add, +, ScalarResult);
179impl_commutative!(MpcScalarResult, Add, add, +, ScalarResult);
180
181impl Add<&MpcScalarResult> for &MpcScalarResult {
182 type Output = MpcScalarResult;
183
184 fn add(self, rhs: &MpcScalarResult) -> Self::Output {
185 self.fabric()
186 .new_gate_op(vec![self.id(), rhs.id()], |args| {
187 let lhs: Scalar = args[0].to_owned().into();
189 let rhs: Scalar = args[1].to_owned().into();
190
191 ResultValue::Scalar(lhs + rhs)
192 })
193 .into()
194 }
195}
196impl_borrow_variants!(MpcScalarResult, Add, add, +, MpcScalarResult);
197
198impl MpcScalarResult {
199 pub fn batch_add(a: &[MpcScalarResult], b: &[MpcScalarResult]) -> Vec<MpcScalarResult> {
201 assert_eq!(
202 a.len(),
203 b.len(),
204 "batch_add: a and b must be the same length"
205 );
206
207 let n = a.len();
208 let fabric = a[0].fabric();
209 let ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();
210
211 let scalars = fabric.new_batch_gate_op(ids, n , move |args| {
212 let scalars = args.into_iter().map(Scalar::from).collect_vec();
214 let (a_res, b_res) = scalars.split_at(n);
215
216 a_res
218 .iter()
219 .zip(b_res.iter())
220 .map(|(a, b)| ResultValue::Scalar(a + b))
221 .collect_vec()
222 });
223
224 scalars.into_iter().map(|s| s.into()).collect_vec()
225 }
226
227 pub fn batch_add_public(a: &[MpcScalarResult], b: &[ScalarResult]) -> Vec<MpcScalarResult> {
229 assert_eq!(
230 a.len(),
231 b.len(),
232 "batch_add_public: a and b must be the same length"
233 );
234
235 let n = a.len();
236 let fabric = a[0].fabric();
237 let ids = a
238 .iter()
239 .map(|v| v.id())
240 .chain(b.iter().map(|v| v.id()))
241 .collect_vec();
242
243 let party_id = fabric.party_id();
244 let scalars: Vec<ScalarResult> =
245 fabric.new_batch_gate_op(ids, n , move |args| {
246 if party_id == PARTY0 {
247 let mut res: Vec<ResultValue> = Vec::with_capacity(n);
248
249 for i in 0..n {
250 let lhs: Scalar = args[i].to_owned().into();
251 let rhs: Scalar = args[i + n].to_owned().into();
252
253 res.push(ResultValue::Scalar(lhs + rhs));
254 }
255
256 res
257 } else {
258 args[..n].to_vec()
259 }
260 });
261
262 scalars.into_iter().map(|s| s.into()).collect_vec()
263 }
264}
265
266impl Sub<&Scalar> for &MpcScalarResult {
269 type Output = MpcScalarResult;
270
271 fn sub(self, rhs: &Scalar) -> Self::Output {
273 let rhs = *rhs;
274 let party_id = self.fabric().party_id();
275
276 if party_id == PARTY0 {
277 &self.share - rhs
278 } else {
279 &self.share - Scalar::zero()
281 }
282 .into()
283 }
284}
285impl_borrow_variants!(MpcScalarResult, Sub, sub, -, Scalar);
286
287impl Sub<&MpcScalarResult> for &Scalar {
288 type Output = MpcScalarResult;
289
290 fn sub(self, rhs: &MpcScalarResult) -> Self::Output {
292 let party_id = rhs.fabric().party_id();
293
294 if party_id == PARTY0 {
295 self - &rhs.share
296 } else {
297 Scalar::zero() - &rhs.share
299 }
300 .into()
301 }
302}
303
304impl Sub<&ScalarResult> for &MpcScalarResult {
305 type Output = MpcScalarResult;
306
307 fn sub(self, rhs: &ScalarResult) -> Self::Output {
309 let party_id = self.fabric().party_id();
310
311 if party_id == PARTY0 {
312 &self.share - rhs
313 } else {
314 self.share.clone() + Scalar::zero()
316 }
317 .into()
318 }
319}
320impl_borrow_variants!(MpcScalarResult, Sub, sub, -, ScalarResult);
321
322impl Sub<&MpcScalarResult> for &ScalarResult {
323 type Output = MpcScalarResult;
324
325 fn sub(self, rhs: &MpcScalarResult) -> Self::Output {
327 let party_id = rhs.fabric().party_id();
328
329 if party_id == PARTY0 {
330 self - &rhs.share
331 } else {
332 Scalar::zero() - rhs.share.clone()
334 }
335 .into()
336 }
337}
338impl_borrow_variants!(ScalarResult, Sub, sub, -, MpcScalarResult, Output=MpcScalarResult);
339
340impl Sub<&MpcScalarResult> for &MpcScalarResult {
341 type Output = MpcScalarResult;
342
343 fn sub(self, rhs: &MpcScalarResult) -> Self::Output {
344 self.fabric()
345 .new_gate_op(vec![self.id(), rhs.id()], |args| {
346 let lhs: Scalar = args[0].to_owned().into();
348 let rhs: Scalar = args[1].to_owned().into();
349
350 ResultValue::Scalar(lhs - rhs)
351 })
352 .into()
353 }
354}
355impl_borrow_variants!(MpcScalarResult, Sub, sub, -, MpcScalarResult);
356
357impl MpcScalarResult {
358 pub fn batch_sub(a: &[MpcScalarResult], b: &[MpcScalarResult]) -> Vec<MpcScalarResult> {
360 assert_eq!(
361 a.len(),
362 b.len(),
363 "batch_sub: a and b must be the same length"
364 );
365
366 let n = a.len();
367 let fabric = a[0].fabric();
368 let ids = a
369 .iter()
370 .map(|v| v.id())
371 .chain(b.iter().map(|v| v.id()))
372 .collect_vec();
373
374 let scalars: Vec<ScalarResult> =
375 fabric.new_batch_gate_op(ids, n , move |args| {
376 let scalars = args.into_iter().map(Scalar::from).collect_vec();
378 let (a_res, b_res) = scalars.split_at(n);
379
380 a_res
382 .iter()
383 .zip(b_res.iter())
384 .map(|(a, b)| ResultValue::Scalar(a - b))
385 .collect_vec()
386 });
387
388 scalars.into_iter().map(|s| s.into()).collect_vec()
389 }
390
391 pub fn batch_sub_public(a: &[MpcScalarResult], b: &[ScalarResult]) -> Vec<MpcScalarResult> {
393 assert_eq!(
394 a.len(),
395 b.len(),
396 "batch_sub_public: a and b must be the same length"
397 );
398
399 let n = a.len();
400 let fabric = a[0].fabric();
401 let ids = a
402 .iter()
403 .map(|v| v.id())
404 .chain(b.iter().map(|v| v.id()))
405 .collect_vec();
406
407 let party_id = fabric.party_id();
408 let scalars = fabric.new_batch_gate_op(ids, n , move |args| {
409 if party_id == PARTY0 {
410 let mut res: Vec<ResultValue> = Vec::with_capacity(n);
411
412 for i in 0..n {
413 let lhs: Scalar = args[i].to_owned().into();
414 let rhs: Scalar = args[i + n].to_owned().into();
415
416 res.push(ResultValue::Scalar(lhs - rhs));
417 }
418
419 res
420 } else {
421 args[..n].to_vec()
422 }
423 });
424
425 scalars.into_iter().map(|s| s.into()).collect_vec()
426 }
427}
428
429impl Neg for &MpcScalarResult {
432 type Output = MpcScalarResult;
433
434 fn neg(self) -> Self::Output {
435 self.fabric()
436 .new_gate_op(vec![self.id()], |args| {
437 let lhs: Scalar = args[0].to_owned().into();
439 ResultValue::Scalar(-lhs)
440 })
441 .into()
442 }
443}
444impl_borrow_variants!(MpcScalarResult, Neg, neg, -);
445
446impl MpcScalarResult {
447 pub fn batch_neg(values: &[MpcScalarResult]) -> Vec<MpcScalarResult> {
449 if values.is_empty() {
450 return vec![];
451 }
452
453 let n = values.len();
454 let fabric = values[0].fabric();
455 let ids = values.iter().map(|v| v.id()).collect_vec();
456
457 let scalars = fabric.new_batch_gate_op(ids, n , move |args| {
458 let scalars = args.into_iter().map(Scalar::from).collect_vec();
460
461 scalars
463 .iter()
464 .map(|a| ResultValue::Scalar(-a))
465 .collect_vec()
466 });
467
468 scalars.into_iter().map(|s| s.into()).collect_vec()
469 }
470}
471
472impl Mul<&Scalar> for &MpcScalarResult {
475 type Output = MpcScalarResult;
476
477 fn mul(self, rhs: &Scalar) -> Self::Output {
478 let rhs = *rhs;
479 self.fabric()
480 .new_gate_op(vec![self.id()], move |args| {
481 let lhs: Scalar = args[0].to_owned().into();
483 ResultValue::Scalar(lhs * rhs)
484 })
485 .into()
486 }
487}
488impl_borrow_variants!(MpcScalarResult, Mul, mul, *, Scalar);
489impl_commutative!(MpcScalarResult, Mul, mul, *, Scalar);
490
491impl Mul<&ScalarResult> for &MpcScalarResult {
492 type Output = MpcScalarResult;
493
494 fn mul(self, rhs: &ScalarResult) -> Self::Output {
495 self.fabric()
496 .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
497 let lhs: Scalar = args.remove(0).into();
499 let rhs: Scalar = args.remove(0).into();
500
501 ResultValue::Scalar(lhs * rhs)
502 })
503 .into()
504 }
505}
506impl_borrow_variants!(MpcScalarResult, Mul, mul, *, ScalarResult);
507impl_commutative!(MpcScalarResult, Mul, mul, *, ScalarResult);
508
509impl Mul<&MpcScalarResult> for &MpcScalarResult {
511 type Output = MpcScalarResult;
512
513 fn mul(self, rhs: &MpcScalarResult) -> Self::Output {
514 let (a, b, c) = self.fabric().next_beaver_triple();
516
517 let masked_lhs = self - &a;
519 let masked_rhs = rhs - &b;
520
521 let d_open = masked_lhs.open();
522 let e_open = masked_rhs.open();
523
524 &d_open * &b + &e_open * &a + c + &d_open * &e_open
526 }
527}
528impl_borrow_variants!(MpcScalarResult, Mul, mul, *, MpcScalarResult);
529
530impl MpcScalarResult {
531 pub fn batch_mul(a: &[MpcScalarResult], b: &[MpcScalarResult]) -> Vec<MpcScalarResult> {
533 let n = a.len();
534 assert_eq!(
535 a.len(),
536 b.len(),
537 "batch_mul: a and b must be the same length"
538 );
539
540 let fabric = &a[0].fabric();
542 let (beaver_a, beaver_b, beaver_c) = fabric.next_beaver_triple_batch(n);
543
544 let masked_lhs = MpcScalarResult::batch_sub(a, &beaver_a);
546 let masked_rhs = MpcScalarResult::batch_sub(b, &beaver_b);
547
548 let all_masks = [masked_lhs, masked_rhs].concat();
549 let opened_values = MpcScalarResult::open_batch(&all_masks);
550 let (d_open, e_open) = opened_values.split_at(n);
551
552 let de = ScalarResult::batch_mul(d_open, e_open);
554 let db = MpcScalarResult::batch_mul_public(&beaver_b, d_open);
555 let ea = MpcScalarResult::batch_mul_public(&beaver_a, e_open);
556
557 let de_plus_db = MpcScalarResult::batch_add_public(&db, &de);
559 let ea_plus_c = MpcScalarResult::batch_add(&ea, &beaver_c);
560 MpcScalarResult::batch_add(&de_plus_db, &ea_plus_c)
561 }
562
563 pub fn batch_mul_public(a: &[MpcScalarResult], b: &[ScalarResult]) -> Vec<MpcScalarResult> {
565 assert_eq!(
566 a.len(),
567 b.len(),
568 "batch_mul_public: a and b must be the same length"
569 );
570
571 let n = a.len();
572 let fabric = a[0].fabric();
573 let ids = a
574 .iter()
575 .map(|v| v.id())
576 .chain(b.iter().map(|v| v.id))
577 .collect_vec();
578
579 let scalars: Vec<ScalarResult> =
580 fabric.new_batch_gate_op(ids, n , move |args| {
581 let mut res: Vec<ResultValue> = Vec::with_capacity(n);
582 for i in 0..n {
583 let lhs: Scalar = args[i].to_owned().into();
584 let rhs: Scalar = args[i + n].to_owned().into();
585
586 res.push(ResultValue::Scalar(lhs * rhs));
587 }
588
589 res
590 });
591
592 scalars.into_iter().map(|s| s.into()).collect_vec()
593 }
594}
595
596impl Mul<&MpcScalarResult> for &StarkPoint {
599 type Output = MpcStarkPointResult;
600
601 fn mul(self, rhs: &MpcScalarResult) -> Self::Output {
602 let self_owned = *self;
603 rhs.fabric()
604 .new_gate_op(vec![rhs.id()], move |mut args| {
605 let rhs: Scalar = args.remove(0).into();
606
607 ResultValue::Point(self_owned * rhs)
608 })
609 .into()
610 }
611}
612impl_commutative!(StarkPoint, Mul, mul, *, MpcScalarResult, Output=MpcStarkPointResult);
613
614impl Mul<&MpcScalarResult> for &StarkPointResult {
615 type Output = MpcStarkPointResult;
616
617 fn mul(self, rhs: &MpcScalarResult) -> Self::Output {
618 self.fabric
619 .new_gate_op(vec![self.id(), rhs.id()], |mut args| {
620 let lhs: StarkPoint = args.remove(0).into();
621 let rhs: Scalar = args.remove(0).into();
622
623 ResultValue::Point(lhs * rhs)
624 })
625 .into()
626 }
627}
628impl_borrow_variants!(StarkPointResult, Mul, mul, *, MpcScalarResult, Output=MpcStarkPointResult);
629impl_commutative!(StarkPointResult, Mul, mul, *, MpcScalarResult, Output=MpcStarkPointResult);
630
#[cfg(test)]
mod test {
    use rand::thread_rng;

    use crate::{algebra::scalar::Scalar, test_helpers::execute_mock_mpc, PARTY0};

    /// Subtraction between a shared value and a public value, in both operand orders
    #[tokio::test]
    async fn test_sub() {
        let mut rng = thread_rng();
        let shared_plain = Scalar::random(&mut rng);
        let public_plain = Scalar::random(&mut rng);

        let ((forward_ok, reverse_ok), _) = execute_mock_mpc(|fabric| async move {
            let shared = fabric.share_scalar(shared_plain, PARTY0).mpc_share();
            let public = fabric.allocate_scalar(public_plain);

            // shared - public
            let forward = &shared - &public;
            let opened_forward = forward.open().await;

            // public - shared
            let reverse = &public - &shared;
            let opened_reverse = reverse.open().await;

            (
                opened_forward == shared_plain - public_plain,
                opened_reverse == public_plain - shared_plain,
            )
        })
        .await;

        assert!(forward_ok);
        assert!(reverse_ok)
    }
}