1#![allow(unused_doc_comments)]
3use std::{
4 borrow::Borrow,
5 convert::TryInto,
6 iter::{Product, Sum},
7 ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub, SubAssign},
8};
9
10use clear_on_drop::clear::Clear;
11use curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar};
12use rand_core::{CryptoRng, OsRng, RngCore};
13use subtle::ConstantTimeEq;
14use tokio::runtime::Handle;
15use zeroize::Zeroize;
16
17use crate::{
18 beaver::SharedValueSource,
19 commitment::PedersenCommitment,
20 error::{MpcError, MpcNetworkError},
21 macros::{self},
22 network::MpcNetwork,
23 BeaverSource, SharedNetwork, Visibility, Visible,
24};
25
/// A scalar participating in the MPC protocol, tagged with its visibility.
///
/// How `value` should be read depends on `visibility`. For `Shared` values it is only
/// this party's additive share: `open` adds it to the share received from the peer.
/// `Public` and `Private` values hold the full scalar locally.
#[derive(Debug)]
pub struct MpcScalar<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> {
    /// The underlying scalar, or this party's additive share of it when `Shared`.
    pub value: Scalar,
    /// Whether the value is private to this party, additively shared, or public.
    pub(crate) visibility: Visibility,
    /// Handle to the network used to exchange shares with the peer.
    pub(crate) network: SharedNetwork<N>,
    /// Source of Beaver triplets consumed by shared * shared multiplication.
    pub(crate) beaver_source: BeaverSource<S>,
}
38
39impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Clone for MpcScalar<N, S> {
40 fn clone(&self) -> Self {
41 Self {
42 value: self.value,
43 visibility: self.visibility,
44 network: self.network.clone(),
45 beaver_source: self.beaver_source.clone(),
46 }
47 }
48}
49
50pub fn scalar_to_u64(a: &Scalar) -> u64 {
56 u64::from_le_bytes(a.to_bytes()[..8].try_into().unwrap())
57}
58
59impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> MpcScalar<N, S> {
63 #[inline]
67 pub(crate) fn is_private(&self) -> bool {
68 self.visibility == Visibility::Private
69 }
70
71 #[inline]
72 pub(crate) fn is_shared(&self) -> bool {
73 self.visibility == Visibility::Shared
74 }
75
76 #[inline]
77 pub(crate) fn is_public(&self) -> bool {
78 self.visibility == Visibility::Public
79 }
80
81 #[inline]
82 pub fn value(&self) -> Scalar {
83 self.value
84 }
85
86 #[inline]
87 pub fn to_scalar(&self) -> Scalar {
88 self.value()
89 }
90
91 #[inline]
92 pub(crate) fn network(&self) -> SharedNetwork<N> {
93 self.network.clone()
94 }
95
96 #[inline]
97 pub(crate) fn beaver_source(&self) -> BeaverSource<S> {
98 self.beaver_source.clone()
99 }
100
101 pub fn from_public_u64(
107 a: u64,
108 network: SharedNetwork<N>,
109 beaver_source: BeaverSource<S>,
110 ) -> Self {
111 Self::from_u64_with_visibility(a, Visibility::Public, network, beaver_source)
112 }
113
114 pub fn from_private_u64(
116 a: u64,
117 network: SharedNetwork<N>,
118 beaver_source: BeaverSource<S>,
119 ) -> Self {
120 Self::from_u64_with_visibility(a, Visibility::Private, network, beaver_source)
121 }
122
123 pub(crate) fn from_u64_with_visibility(
125 a: u64,
126 visibility: Visibility,
127 network: SharedNetwork<N>,
128 beaver_source: BeaverSource<S>,
129 ) -> Self {
130 Self {
131 network,
132 visibility,
133 beaver_source,
134 value: Scalar::from(a),
135 }
136 }
137
138 pub fn from_public_scalar(
140 value: Scalar,
141 network: SharedNetwork<N>,
142 beaver_source: BeaverSource<S>,
143 ) -> Self {
144 Self::from_scalar_with_visibility(value, Visibility::Public, network, beaver_source)
145 }
146
147 pub fn from_private_scalar(
149 value: Scalar,
150 network: SharedNetwork<N>,
151 beaver_source: BeaverSource<S>,
152 ) -> Self {
153 Self::from_scalar_with_visibility(value, Visibility::Private, network, beaver_source)
154 }
155
156 pub(crate) fn from_scalar_with_visibility(
158 value: Scalar,
159 visibility: Visibility,
160 network: SharedNetwork<N>,
161 beaver_source: BeaverSource<S>,
162 ) -> Self {
163 Self {
164 network,
165 visibility,
166 value,
167 beaver_source,
168 }
169 }
170
171 pub fn random<R: RngCore + CryptoRng>(
174 rng: &mut R,
175 network: SharedNetwork<N>,
176 beaver_source: BeaverSource<S>,
177 ) -> Self {
178 Self {
179 network,
180 visibility: Visibility::Private,
181 beaver_source,
182 value: Scalar::random(rng),
183 }
184 }
185
186 pub fn default(network: SharedNetwork<N>, beaver_source: BeaverSource<S>) -> Self {
188 Self::zero(network, beaver_source)
189 }
190
191 macros::impl_delegated_wrapper!(
193 Scalar,
194 from_bytes_mod_order,
195 from_bytes_mod_order_with_visibility,
196 bytes,
197 [u8; 32]
198 );
199 macros::impl_delegated_wrapper!(
200 Scalar,
201 from_bytes_mod_order_wide,
202 from_bytes_mod_order_wide_with_visibility,
203 input,
204 &[u8; 64]
205 );
206
207 pub fn from_canonical_bytes(
208 bytes: [u8; 32],
209 network: SharedNetwork<N>,
210 beaver_source: BeaverSource<S>,
211 ) -> Option<MpcScalar<N, S>> {
212 Self::from_canonical_bytes_with_visibility(
213 bytes,
214 Visibility::Public,
215 network,
216 beaver_source,
217 )
218 }
219
220 pub fn from_canonical_bytes_with_visibility(
221 bytes: [u8; 32],
222 visibility: Visibility,
223 network: SharedNetwork<N>,
224 beaver_source: BeaverSource<S>,
225 ) -> Option<MpcScalar<N, S>> {
226 Some(MpcScalar {
227 visibility,
228 network,
229 beaver_source,
230 value: Scalar::from_canonical_bytes(bytes)?,
231 })
232 }
233
234 macros::impl_delegated_wrapper!(
235 Scalar,
236 from_bits,
237 from_bits_with_visibility,
238 bytes,
239 [u8; 32]
240 );
241
242 macros::impl_delegated!(to_bytes, self, [u8; 32]);
244 macros::impl_delegated!(as_bytes, self, &[u8; 32]);
245 macros::impl_delegated!(is_canonical, self, bool);
247 macros::impl_delegated_wrapper!(Scalar, zero);
249 macros::impl_delegated_wrapper!(Scalar, one);
251}
252
253impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> MpcScalar<N, S> {
257 pub fn share_secret(&self, party_id: u64) -> Result<MpcScalar<N, S>, MpcNetworkError> {
263 let my_party_id = self.network.as_ref().borrow().party_id();
264
265 if my_party_id == party_id {
266 let mut rng = OsRng {};
269 let random_share = Scalar::random(&mut rng);
270
271 Handle::current().block_on(
273 self.network
274 .as_ref()
275 .borrow_mut()
276 .send_single_scalar(random_share),
277 )?;
278
279 Ok(MpcScalar {
282 value: self.value - random_share,
283 visibility: Visibility::Shared,
284 network: self.network.clone(),
285 beaver_source: self.beaver_source.clone(),
286 })
287 } else {
288 Self::receive_value(self.network.clone(), self.beaver_source.clone())
289 }
290 }
291
292 pub fn batch_share_secrets(
294 party_id: u64,
295 secrets: &[MpcScalar<N, S>],
296 ) -> Result<Vec<MpcScalar<N, S>>, MpcNetworkError> {
297 assert!(
298 secrets.iter().all(|secret| secret.is_private()),
299 "Values to be shared must be in private state"
300 );
301
302 if secrets.is_empty() {
303 return Ok(Vec::new());
304 }
305
306 let network = secrets[0].network();
307 let beaver_source = secrets[0].beaver_source();
308 let my_party_id = network.as_ref().borrow().party_id();
309
310 if my_party_id == party_id {
311 let mut rng = OsRng {};
313 let random_shares: Vec<Scalar> = (0..secrets.len())
314 .map(|_| Scalar::random(&mut rng))
315 .collect();
316
317 Handle::current()
319 .block_on(network.as_ref().borrow_mut().send_scalars(&random_shares))?;
320
321 Ok(secrets
322 .iter()
323 .zip(random_shares.iter())
324 .map(|(secret, blinding)| MpcScalar {
325 value: secret.value() - blinding,
326 visibility: Visibility::Shared,
327 network: network.clone(),
328 beaver_source: beaver_source.clone(),
329 })
330 .collect())
331 } else {
332 Self::batch_receive_values(secrets.len(), network, beaver_source)
333 }
334 }
335
336 pub fn receive_value(
338 network: SharedNetwork<N>,
339 beaver_source: BeaverSource<S>,
340 ) -> Result<MpcScalar<N, S>, MpcNetworkError> {
341 let value =
342 Handle::current().block_on(network.as_ref().borrow_mut().receive_single_scalar())?;
343
344 Ok(MpcScalar {
345 value,
346 visibility: Visibility::Shared,
347 network,
348 beaver_source,
349 })
350 }
351
352 pub fn batch_receive_values(
354 num_expected: usize,
355 network: SharedNetwork<N>,
356 beaver_source: BeaverSource<S>,
357 ) -> Result<Vec<MpcScalar<N, S>>, MpcNetworkError> {
358 let values = Handle::current()
359 .block_on(network.as_ref().borrow_mut().receive_scalars(num_expected))?;
360
361 Ok(values
362 .iter()
363 .map(|value| MpcScalar {
364 value: *value,
365 visibility: Visibility::Shared,
366 network: network.clone(),
367 beaver_source: beaver_source.clone(),
368 })
369 .collect())
370 }
371
372 pub fn open(&self) -> Result<MpcScalar<N, S>, MpcNetworkError> {
376 assert!(!self.is_private(), "Private values may not be opened...");
377 if self.is_public() {
378 return Ok(self.clone());
379 }
380
381 let received_scalar = Handle::current().block_on(
383 self.network
384 .as_ref()
385 .borrow_mut()
386 .broadcast_single_scalar(self.value),
387 )?;
388
389 Ok(MpcScalar::from_public_scalar(
391 self.value + received_scalar,
392 self.network.clone(),
393 self.beaver_source.clone(),
394 ))
395 }
396
397 pub fn batch_open(values: &[MpcScalar<N, S>]) -> Result<Vec<MpcScalar<N, S>>, MpcNetworkError> {
399 assert!(
400 values.iter().all(|value| !value.is_private()),
401 "Private values may not be opened..."
402 );
403
404 if values.is_empty() {
405 return Ok(Vec::new());
406 }
407
408 let network = values[0].network();
409 let beaver_source = values[0].beaver_source();
410
411 let received_scalars = Handle::current().block_on(
413 network.as_ref().borrow_mut().broadcast_scalars(
414 &values
415 .iter()
416 .map(|value| value.value())
417 .collect::<Vec<Scalar>>(),
418 ),
419 )?;
420
421 Ok(values
422 .iter()
423 .zip(received_scalars.iter())
424 .map(|(my_share, peer_share)| {
425 if my_share.is_public() {
426 return my_share.clone();
427 }
428
429 MpcScalar::from_public_scalar(
430 my_share.value() + peer_share,
431 network.clone(),
432 beaver_source.clone(),
433 )
434 })
435 .collect())
436 }
437
438 pub fn commit_and_open(&self) -> Result<MpcScalar<N, S>, MpcError> {
443 assert!(!self.is_private(), "Private values may not be opened...");
444 if self.is_public() {
445 return Ok(self.clone());
446 }
447
448 let commitment = PedersenCommitment::commit(self.to_scalar());
450 let peer_commitment = Handle::current()
451 .block_on(
452 self.network()
453 .as_ref()
454 .borrow_mut()
455 .broadcast_single_point(commitment.get_commitment()),
456 )
457 .map_err(MpcError::NetworkError)?;
458
459 let received_scalars = Handle::current()
461 .block_on(
462 self.network()
463 .as_ref()
464 .borrow_mut()
465 .broadcast_scalars(&[commitment.get_blinding(), commitment.get_value()]),
466 )
467 .map_err(MpcError::NetworkError)?;
468
469 let (peer_blinding, peer_value) = (received_scalars[0], received_scalars[1]);
470
471 if !PedersenCommitment::verify_from_values(peer_commitment, peer_blinding, peer_value) {
473 return Err(MpcError::AuthenticationError);
474 }
475
476 Ok(Self {
477 value: self.value() + peer_value,
478 visibility: Visibility::Public,
479 network: self.network(),
480 beaver_source: self.beaver_source(),
481 })
482 }
483
484 pub fn batch_commit_and_open(
486 values: &[MpcScalar<N, S>],
487 ) -> Result<Vec<MpcScalar<N, S>>, MpcError> {
488 assert!(
489 values.iter().all(|value| !value.is_private()),
490 "Private values may not be opened...",
491 );
492
493 if values.is_empty() {
494 return Ok(Vec::new());
495 }
496
497 let network = values[0].network();
498 let beaver_source = values[0].beaver_source();
499
500 let commitments: Vec<PedersenCommitment> = values
502 .iter()
503 .map(|value| PedersenCommitment::commit(value.to_scalar()))
504 .collect();
505 let peer_commitments = Handle::current()
506 .block_on(
507 network.as_ref().borrow_mut().broadcast_points(
508 &commitments
509 .iter()
510 .map(|comm| comm.get_commitment())
511 .collect::<Vec<RistrettoPoint>>(),
512 ),
513 )
514 .map_err(MpcError::NetworkError)?;
515
516 let mut commitment_data: Vec<Scalar> = Vec::new();
518 commitments.iter().for_each(|comm| {
519 commitment_data.push(comm.get_blinding());
520 commitment_data.push(comm.get_value());
521 });
522
523 let received_values = Handle::current()
524 .block_on(
525 network
526 .as_ref()
527 .borrow_mut()
528 .broadcast_scalars(&commitment_data),
529 )
530 .map_err(MpcError::NetworkError)?;
531
532 let mut peer_values: Vec<Scalar> = Vec::new();
534 received_values
535 .chunks(2 ) .zip(peer_commitments.into_iter())
537 .try_for_each(|(revealed_values, comm)| {
538 let (blinding, value) = (revealed_values[0], revealed_values[1]);
540 peer_values.push(value);
541
542 if !PedersenCommitment::verify_from_values(comm, blinding, value) {
544 return Err(MpcError::AuthenticationError);
545 }
546
547 Ok(())
548 })?;
549
550 Ok(values
552 .iter()
553 .zip(peer_values)
554 .map(|(my_value, peer_value)| {
555 if my_value.is_public() {
556 return my_value.clone();
557 }
558
559 MpcScalar {
560 value: my_value.value() + peer_value,
561 visibility: Visibility::Public,
562 network: network.clone(),
563 beaver_source: beaver_source.clone(),
564 }
565 })
566 .collect())
567 }
568
569 fn next_beaver_triplet(&self) -> (MpcScalar<N, S>, MpcScalar<N, S>, MpcScalar<N, S>) {
571 let (a, b, c) = self.beaver_source.as_ref().borrow_mut().next_triplet();
572
573 (
574 MpcScalar::from_scalar_with_visibility(
575 a,
576 Visibility::Shared,
577 self.network.clone(),
578 self.beaver_source.clone(),
579 ),
580 MpcScalar::from_scalar_with_visibility(
581 b,
582 Visibility::Shared,
583 self.network.clone(),
584 self.beaver_source.clone(),
585 ),
586 MpcScalar::from_scalar_with_visibility(
587 c,
588 Visibility::Shared,
589 self.network.clone(),
590 self.beaver_source.clone(),
591 ),
592 )
593 }
594
595 #[allow(clippy::type_complexity)]
597 fn next_beaver_triplet_batch(
598 &self,
599 num_triplets: usize,
600 ) -> Vec<(MpcScalar<N, S>, MpcScalar<N, S>, MpcScalar<N, S>)> {
601 let triplet_batch = self
602 .beaver_source
603 .as_ref()
604 .borrow_mut()
605 .next_triplet_batch(num_triplets);
606
607 triplet_batch
609 .iter()
610 .map(|(a, b, c)| {
611 (
612 MpcScalar::from_scalar_with_visibility(
613 *a,
614 Visibility::Shared,
615 self.network.clone(),
616 self.beaver_source.clone(),
617 ),
618 MpcScalar::from_scalar_with_visibility(
619 *b,
620 Visibility::Shared,
621 self.network.clone(),
622 self.beaver_source.clone(),
623 ),
624 MpcScalar::from_scalar_with_visibility(
625 *c,
626 Visibility::Shared,
627 self.network.clone(),
628 self.beaver_source.clone(),
629 ),
630 )
631 })
632 .collect::<Vec<_>>()
633 }
634}
635
636impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Visible for MpcScalar<N, S> {
640 fn visibility(&self) -> Visibility {
641 self.visibility
642 }
643}
644
645impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> PartialEq for MpcScalar<N, S> {
646 fn eq(&self, other: &Self) -> bool {
647 self.value.eq(&other.value)
648 }
649}
650
651impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> ConstantTimeEq for MpcScalar<N, S> {
652 fn ct_eq(&self, other: &Self) -> subtle::Choice {
653 self.value.ct_eq(&other.value)
654 }
655}
656
657impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Index<usize> for MpcScalar<N, S> {
658 type Output = u8;
659
660 fn index(&self, index: usize) -> &Self::Output {
661 self.value.index(index)
662 }
663}
664
// Clears the scalar held by an `MpcScalar` via `clear_on_drop::clear::Clear`.
//
// The impl is on `&mut MpcScalar` rather than `MpcScalar`, so callers write
// `(&mut value).clear()`. Only `value` is cleared; the visibility and the network/beaver
// handles are left untouched.
impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Clear for &mut MpcScalar<N, S> {
    // NOTE(review): the explicit `&mut` borrow (and the clippy allow) presumably keeps the
    // call resolved as `Clear::clear` on the inner `Scalar` — confirm before simplifying.
    #[allow(clippy::needless_borrow)]
    fn clear(&mut self) {
        (&mut self.value).clear();
    }
}
671
672impl<'a, N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Mul<&'a MpcScalar<N, S>>
680 for &'a MpcScalar<N, S>
681{
682 type Output = MpcScalar<N, S>;
683
684 fn mul(self, rhs: &'a MpcScalar<N, S>) -> Self::Output {
692 if self.is_shared() && rhs.is_shared() {
693 let (a, b, c) = self.next_beaver_triplet();
694
695 let opened_values = MpcScalar::batch_open(&[(self - &a), (rhs - &b)]).unwrap();
697 let lhs_minus_a = &opened_values[0];
698 let rhs_minus_b = &opened_values[1];
699
700 let mut res = lhs_minus_a * &b + rhs_minus_b * &a + c;
704
705 if self.network.as_ref().borrow().am_king() {
707 res += lhs_minus_a * rhs_minus_b;
708 }
709
710 res
711 } else {
712 MpcScalar {
714 visibility: Visibility::min_visibility_two(self, rhs),
715 network: self.network.clone(),
716 beaver_source: self.beaver_source.clone(),
717 value: self.value * rhs.value,
718 }
719 }
720 }
721}
722
// Multiplication boilerplate generated by crate macros (definitions not visible here).
// Presumably it covers:
// - the remaining owned/borrowed combinations of `MpcScalar * MpcScalar`;
// - the mixed forms with a plain `Scalar`, lifted via `MpcScalar::from_public_scalar`
//   (per the macro argument);
// - `*=` for both right-hand-side types.
// TODO confirm against `macros`.
macros::impl_operator_variants!(MpcScalar<N, S>, Mul, mul, *, MpcScalar<N, S>);
macros::impl_wrapper_type!(MpcScalar<N, S>, Scalar, MpcScalar::from_public_scalar, Mul, mul, *, authenticated=false);
macros::impl_arithmetic_assign!(MpcScalar<N, S>, MulAssign, mul_assign, *, MpcScalar<N, S>);
macros::impl_arithmetic_assign!(MpcScalar<N, S>, MulAssign, mul_assign, *, Scalar);
729
730impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> MpcScalar<N, S> {
735 pub fn batch_mul(
740 a: &[MpcScalar<N, S>],
741 b: &[MpcScalar<N, S>],
742 ) -> Result<Vec<MpcScalar<N, S>>, MpcNetworkError> {
743 assert_eq!(
744 a.len(),
745 b.len(),
746 "input arrays to batch_mul must be of equal length"
747 );
748
749 if a.is_empty() {
750 return Ok(Vec::new());
751 }
752
753 let n = a.len();
754 let mut res = Vec::with_capacity(n);
755
756 let mut beaver_mul_pairs = Vec::new();
759 for i in 0..a.len() {
760 if !a[i].is_public() && !b[i].is_public() {
761 beaver_mul_pairs.push((&a[i], &b[i]))
762 }
763 }
764
765 let num_beaver_muls = beaver_mul_pairs.len();
767 let mut beaver_triplets = a[0].next_beaver_triplet_batch(num_beaver_muls);
768
769 let mut beaver_subs = Vec::with_capacity(2 * n);
771 beaver_mul_pairs
772 .iter()
773 .zip(beaver_triplets.iter())
774 .for_each(|((a_val, b_val), (beaver_a, beaver_b, _))| {
775 beaver_subs.push(*a_val - beaver_a);
776 beaver_subs.push(*b_val - beaver_b);
777 });
778
779 let mut opened_beaver_subs = if num_beaver_muls == 0 {
781 Vec::new()
782 } else {
783 MpcScalar::batch_open(&beaver_subs)?
784 };
785 for i in 0..n {
786 if a[i].is_public() || b[i].is_public() {
787 res.push(&a[i] * &b[i])
788 } else {
789 let (lhs_minus_a, rhs_minus_b) =
791 (opened_beaver_subs.remove(0), opened_beaver_subs.remove(0));
792
793 let (beaver_a, beaver_b, beaver_c) = beaver_triplets.remove(0);
794
795 let result = &lhs_minus_a * &beaver_b
800 + &rhs_minus_b * &beaver_a
801 + lhs_minus_a * rhs_minus_b
802 + &beaver_c;
803
804 res.push(result);
805 }
806 }
807
808 Ok(res)
809 }
810}
811
812impl<'a, N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Add<&'a MpcScalar<N, S>>
816 for &'a MpcScalar<N, S>
817{
818 type Output = MpcScalar<N, S>;
819
820 fn add(self, rhs: &'a MpcScalar<N, S>) -> Self::Output {
821 if self.is_public() && rhs.is_shared() {
823 return rhs + self;
824 }
825
826 let am_king = self.network.as_ref().borrow().am_king();
835
836 let res = {
837 if self.is_public() && rhs.is_public() || self.is_shared() && rhs.is_shared() || am_king
840 {
842 self.value() + rhs.value()
843 } else {
844 self.value()
845 }
846 };
847
848 MpcScalar {
849 value: res,
850 visibility: Visibility::min_visibility_two(self, rhs),
851 network: self.network.clone(),
852 beaver_source: self.beaver_source.clone(),
853 }
854 }
855}
856
// Addition boilerplate generated by crate macros (definitions not visible here).
// Presumably it covers:
// - the remaining owned/borrowed combinations of `MpcScalar + MpcScalar`;
// - the mixed forms with a plain `Scalar`, lifted to a public value via
//   `MpcScalar::from_public_scalar` (per the macro argument);
// - `+=` for both right-hand-side types.
// TODO confirm against `macros`.
macros::impl_operator_variants!(MpcScalar<N, S>, Add, add, +, MpcScalar<N, S>);
macros::impl_wrapper_type!(MpcScalar<N, S>, Scalar, MpcScalar::from_public_scalar, Add, add, +, authenticated=false);
macros::impl_arithmetic_assign!(MpcScalar<N, S>, AddAssign, add_assign, +, MpcScalar<N, S>);
macros::impl_arithmetic_assign!(MpcScalar<N, S>, AddAssign, add_assign, +, Scalar);
861
862impl<'a, N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Sub<&'a MpcScalar<N, S>>
866 for &'a MpcScalar<N, S>
867{
868 type Output = MpcScalar<N, S>;
869
870 #[allow(clippy::suspicious_arithmetic_impl)]
871 fn sub(self, rhs: &'a MpcScalar<N, S>) -> Self::Output {
872 self + rhs.neg()
873 }
874}
875
// Subtraction boilerplate generated by crate macros (definitions not visible here).
// Presumably it covers:
// - the remaining owned/borrowed combinations of `MpcScalar - MpcScalar`;
// - the mixed forms with a plain `Scalar`, lifted via `MpcScalar::from_public_scalar`
//   (per the macro argument);
// - `-=` for both right-hand-side types.
// TODO confirm against `macros`.
macros::impl_operator_variants!(MpcScalar<N, S>, Sub, sub, -, MpcScalar<N, S>);
macros::impl_wrapper_type!(MpcScalar<N, S>, Scalar, MpcScalar::from_public_scalar, Sub, sub, -, authenticated=false);
macros::impl_arithmetic_assign!(MpcScalar<N, S>, SubAssign, sub_assign, -, MpcScalar<N, S>);
macros::impl_arithmetic_assign!(MpcScalar<N, S>, SubAssign, sub_assign, -, Scalar);
880
881impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Neg for MpcScalar<N, S> {
882 type Output = MpcScalar<N, S>;
883
884 fn neg(self) -> Self::Output {
885 (&self).neg()
886 }
887}
888
889impl<'a, N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Neg for &'a MpcScalar<N, S> {
890 type Output = MpcScalar<N, S>;
891
892 fn neg(self) -> Self::Output {
893 MpcScalar {
894 visibility: self.visibility,
895 network: self.network.clone(),
896 beaver_source: self.beaver_source.clone(),
897 value: self.value.neg(),
898 }
899 }
900}
901
902impl<N, S, T> Product<T> for MpcScalar<N, S>
908where
909 N: MpcNetwork + Send,
910 S: SharedValueSource<Scalar>,
911 T: Borrow<MpcScalar<N, S>>,
912{
913 fn product<I: Iterator<Item = T>>(iter: I) -> Self {
914 let mut peekable = iter.peekable();
915 let first_elem = peekable.peek().unwrap();
916 let network: SharedNetwork<N> = first_elem.borrow().network.clone();
917 let beaver_source: BeaverSource<S> = first_elem.borrow().beaver_source.clone();
918
919 peekable.fold(MpcScalar::one(network, beaver_source), |acc, item| {
920 acc * item.borrow()
921 })
922 }
923}
924
925impl<N, S, T> Sum<T> for MpcScalar<N, S>
926where
927 N: MpcNetwork + Send,
928 S: SharedValueSource<Scalar>,
929 T: Borrow<MpcScalar<N, S>>,
930{
931 fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
932 let mut peekable = iter.peekable();
934 let first_elem = peekable.peek().unwrap();
935 let network = first_elem.borrow().network.clone();
936 let beaver_source: BeaverSource<S> = first_elem.borrow().beaver_source.clone();
937
938 peekable.fold(
939 MpcScalar::from_u64_with_visibility(0, Visibility::Shared, network, beaver_source),
940 |acc, item| acc + item.borrow(),
941 )
942 }
943}
944
945impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> MpcScalar<N, S> {
946 pub fn linear_combination(
948 scalars: &[MpcScalar<N, S>],
949 coeffs: &[MpcScalar<N, S>],
950 ) -> Result<MpcScalar<N, S>, MpcNetworkError> {
951 Ok(MpcScalar::batch_mul(scalars, coeffs)?.iter().sum())
952 }
953}
954
955impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Zeroize for MpcScalar<N, S> {
956 fn zeroize(&mut self) {
957 self.value.zeroize()
958 }
959}
960
// Unit tests for `MpcScalar`, run against `DummyMpcNetwork` and `DummySharedScalarSource`.
//
// Mock scalars queued with `add_mock_scalars` are what the dummy network hands back as the
// peer's contribution. Each `add_mock_scalars` call therefore has to happen before the
// `open` that consumes it. `PartialEq` on `MpcScalar` compares only the underlying scalar,
// so equality asserts check values, not visibility.
#[cfg(test)]
mod test {
    use std::{cell::RefCell, rc::Rc};

    use clear_on_drop::clear::Clear;
    use curve25519_dalek::scalar::Scalar;
    use rand_core::OsRng;
    use tokio::runtime::{Builder as RuntimeBuilder, Runtime};

    use crate::{beaver::DummySharedScalarSource, network::dummy_network::DummyMpcNetwork};

    use super::{MpcScalar, Visibility};

    /// Builds a multi-threaded runtime with a blocking pool.
    ///
    /// `MpcScalar` drives network I/O through `Handle::current().block_on`, so the
    /// network-touching tests run their body inside `spawn_blocking` on this runtime.
    fn create_blockable_runtime() -> Runtime {
        RuntimeBuilder::new_multi_thread()
            .enable_all()
            .worker_threads(1)
            .max_blocking_threads(1)
            .build()
            .unwrap()
    }

    /// `MpcScalar::zero` holds the same scalar as a public zero.
    #[test]
    fn test_zero() {
        let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
        let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));

        let expected =
            MpcScalar::from_public_scalar(Scalar::zero(), network.clone(), beaver_source.clone());
        let zero = MpcScalar::zero(network, beaver_source);

        assert_eq!(zero, expected);
    }

    /// Opening a local share of 1 with a mocked peer share of 1 yields 2.
    #[test]
    fn test_open() {
        let rt = create_blockable_runtime();
        let handle = rt.spawn_blocking(|| {
            let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(1u8)]);

            let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));
            let expected = MpcScalar::from_public_scalar(
                Scalar::from(2u8),
                network.clone(),
                beaver_source.clone(),
            );

            let my_share = MpcScalar::from_u64_with_visibility(
                1u64,
                Visibility::Shared,
                network,
                beaver_source,
            );
            assert_eq!(my_share.open().unwrap(), expected);
        });

        rt.block_on(handle).unwrap();
    }

    /// Shared + public and shared + shared addition.
    ///
    /// The expected 7 = (2 + 3) + 2 assumes the public 3 is added locally, i.e. that the
    /// dummy network reports this party as king. TODO confirm in `DummyMpcNetwork`.
    #[test]
    fn test_add() {
        let rt = create_blockable_runtime();
        let handle = rt.spawn_blocking(|| {
            let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
            // Peer share consumed by the first `open` below.
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(2u8)]);

            let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));

            let shared_value1 = MpcScalar::from_u64_with_visibility(
                2u64,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );

            let res = &shared_value1 + Scalar::from(3u64);
            assert_eq!(res.visibility, Visibility::Shared);
            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(7u64, network.clone(), beaver_source.clone())
            );

            let shared_value2 = MpcScalar::from_u64_with_visibility(
                4u64,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );

            // 2 + 4 = 6 locally, plus the mocked peer share 3 gives 9 when opened.
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(3u8)]);
            let res = shared_value1 + shared_value2;
            assert_eq!(res.visibility, Visibility::Shared);
            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(9, network, beaver_source)
            )
        });

        rt.block_on(handle).unwrap();
    }

    /// `MpcScalar + Scalar` and `Scalar + MpcScalar` produce the same value.
    #[test]
    fn test_add_associative() {
        let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
        let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));

        let mut rng = OsRng {};
        let v1 = MpcScalar::random(&mut rng, network, beaver_source);
        let v2 = Scalar::random(&mut rng);

        let res1 = &v1 + v2;
        let res2 = v2 + &v1;

        assert_eq!(res1, res2);
    }

    /// Shared - public and shared - shared subtraction.
    ///
    /// Expected values follow the same king assumption as `test_add`.
    #[test]
    fn test_sub() {
        let rt = create_blockable_runtime();
        let handle = rt.spawn_blocking(|| {
            let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
            let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));

            let shared_value1 = MpcScalar::from_u64_with_visibility(
                2u64,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );
            // 2 - 2 = 0 locally, plus the mocked peer share 1 gives 1.
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(1u8)]);

            let res = &shared_value1 - Scalar::from(2u8);
            assert_eq!(res.visibility, Visibility::Shared);
            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(1u64, network.clone(), beaver_source.clone())
            );

            let shared_value2 = MpcScalar::from_u64_with_visibility(
                5,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );
            // 5 - 2 = 3 locally, plus the mocked peer share 2 gives 5.
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(2u8)]);

            let res = shared_value2 - shared_value1;
            assert_eq!(res.visibility, Visibility::Shared);
            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(5, network, beaver_source)
            )
        });

        rt.block_on(handle).unwrap();
    }

    /// Shared * public, public * shared, and shared * shared (Beaver) multiplication.
    #[test]
    fn test_mul() {
        let rt = create_blockable_runtime();
        let handle = rt.spawn_blocking(|| {
            let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
            let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));

            let shared_value1 = MpcScalar::from_u64_with_visibility(
                6u64,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );

            // Shared * public is local: 6 * 2 = 12.
            let res = &shared_value1 * Scalar::from(2u8);
            assert_eq!(res.visibility, Visibility::Shared);

            // 12 plus the mocked peer share 10 gives 22.
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(10u8)]);

            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(22, network.clone(), beaver_source.clone())
            );

            let public_value =
                MpcScalar::from_public_u64(3u64, network.clone(), beaver_source.clone());

            // Public * shared is local: 3 * 6 = 18; plus the mocked peer share 15 gives 33.
            let res = public_value * &shared_value1;
            assert_eq!(res.visibility, Visibility::Shared);

            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(15u8)]);
            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(33u64, network.clone(), beaver_source.clone())
            );

            let shared_value2 = MpcScalar::from_u64_with_visibility(
                5u64,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );
            // Mocked peer shares of the masked operands (6 - a) and (5 - b), consumed by
            // the Beaver opening inside `*`.
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(5u8), Scalar::from(7u8)]);

            let res = shared_value1 * shared_value2;
            assert_eq!(res.visibility, Visibility::Shared);

            // Peer share for the final open. The expected 12 * 11 depends on the triplet
            // produced by `DummySharedScalarSource` (not visible here). TODO confirm.
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(0u64)]);

            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(12 * 11, network, beaver_source)
            );
        });

        rt.block_on(handle).unwrap();
    }

    /// `Clear` on `&mut MpcScalar` zeroes the wrapped scalar.
    #[tokio::test]
    async fn test_clear() {
        let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
        let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));
        let mut value = MpcScalar::from_public_u64(2, network, beaver_source);

        (&mut value).clear();
        assert_eq!(value.value(), Scalar::zero());
    }
}