mpc_ristretto/
mpc_scalar.rs

1//! Groups the definitions and trait implementations for a scalar value within an MPC network
2#![allow(unused_doc_comments)]
3use std::{
4    borrow::Borrow,
5    convert::TryInto,
6    iter::{Product, Sum},
7    ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub, SubAssign},
8};
9
10use clear_on_drop::clear::Clear;
11use curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar};
12use rand_core::{CryptoRng, OsRng, RngCore};
13use subtle::ConstantTimeEq;
14use tokio::runtime::Handle;
15use zeroize::Zeroize;
16
17use crate::{
18    beaver::SharedValueSource,
19    commitment::PedersenCommitment,
20    error::{MpcError, MpcNetworkError},
21    macros::{self},
22    network::MpcNetwork,
23    BeaverSource, SharedNetwork, Visibility, Visible,
24};
25
26/// Represents a scalar value allocated in an MPC network
27#[derive(Debug)]
28pub struct MpcScalar<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> {
29    /// the underlying value of the scalar allocated in the network
30    pub value: Scalar,
31    /// The visibility flag; what amount of information parties have
32    pub(crate) visibility: Visibility,
33    /// The underlying network that the MPC operates on
34    pub(crate) network: SharedNetwork<N>,
35    /// The source for shared values; MAC keys, beaver triples, etc
36    pub(crate) beaver_source: BeaverSource<S>,
37}
38
39impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Clone for MpcScalar<N, S> {
40    fn clone(&self) -> Self {
41        Self {
42            value: self.value,
43            visibility: self.visibility,
44            network: self.network.clone(),
45            beaver_source: self.beaver_source.clone(),
46        }
47    }
48}
49
50/**
51 * Static and helper methods
52 */
53
54/// Converts a scalar to u64
55pub fn scalar_to_u64(a: &Scalar) -> u64 {
56    u64::from_le_bytes(a.to_bytes()[..8].try_into().unwrap())
57}
58
59/**
60 * Wrapper type implementations
61 */
62impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> MpcScalar<N, S> {
63    /**
64     * Helper methods
65     */
66    #[inline]
67    pub(crate) fn is_private(&self) -> bool {
68        self.visibility == Visibility::Private
69    }
70
71    #[inline]
72    pub(crate) fn is_shared(&self) -> bool {
73        self.visibility == Visibility::Shared
74    }
75
76    #[inline]
77    pub(crate) fn is_public(&self) -> bool {
78        self.visibility == Visibility::Public
79    }
80
81    #[inline]
82    pub fn value(&self) -> Scalar {
83        self.value
84    }
85
86    #[inline]
87    pub fn to_scalar(&self) -> Scalar {
88        self.value()
89    }
90
91    #[inline]
92    pub(crate) fn network(&self) -> SharedNetwork<N> {
93        self.network.clone()
94    }
95
96    #[inline]
97    pub(crate) fn beaver_source(&self) -> BeaverSource<S> {
98        self.beaver_source.clone()
99    }
100
101    /**
102     * Casting methods
103     */
104
105    /// Create a public network scalar from a u64
106    pub fn from_public_u64(
107        a: u64,
108        network: SharedNetwork<N>,
109        beaver_source: BeaverSource<S>,
110    ) -> Self {
111        Self::from_u64_with_visibility(a, Visibility::Public, network, beaver_source)
112    }
113
114    /// Create a private network scalar from a given u64
115    pub fn from_private_u64(
116        a: u64,
117        network: SharedNetwork<N>,
118        beaver_source: BeaverSource<S>,
119    ) -> Self {
120        Self::from_u64_with_visibility(a, Visibility::Private, network, beaver_source)
121    }
122
123    /// Create a scalar from a given u64 and visibility
124    pub(crate) fn from_u64_with_visibility(
125        a: u64,
126        visibility: Visibility,
127        network: SharedNetwork<N>,
128        beaver_source: BeaverSource<S>,
129    ) -> Self {
130        Self {
131            network,
132            visibility,
133            beaver_source,
134            value: Scalar::from(a),
135        }
136    }
137
138    /// Allocate a public network value from an underlying scalar
139    pub fn from_public_scalar(
140        value: Scalar,
141        network: SharedNetwork<N>,
142        beaver_source: BeaverSource<S>,
143    ) -> Self {
144        Self::from_scalar_with_visibility(value, Visibility::Public, network, beaver_source)
145    }
146
147    /// Allocate a private network value from an underlying scalar
148    pub fn from_private_scalar(
149        value: Scalar,
150        network: SharedNetwork<N>,
151        beaver_source: BeaverSource<S>,
152    ) -> Self {
153        Self::from_scalar_with_visibility(value, Visibility::Private, network, beaver_source)
154    }
155
156    /// Allocate an existing scalar in the network with given visibility
157    pub(crate) fn from_scalar_with_visibility(
158        value: Scalar,
159        visibility: Visibility,
160        network: SharedNetwork<N>,
161        beaver_source: BeaverSource<S>,
162    ) -> Self {
163        Self {
164            network,
165            visibility,
166            value,
167            beaver_source,
168        }
169    }
170
171    /// Generate a random scalar
172    /// Random will always be SharedWithOwner(self); two parties cannot reliably generate the same random value
173    pub fn random<R: RngCore + CryptoRng>(
174        rng: &mut R,
175        network: SharedNetwork<N>,
176        beaver_source: BeaverSource<S>,
177    ) -> Self {
178        Self {
179            network,
180            visibility: Visibility::Private,
181            beaver_source,
182            value: Scalar::random(rng),
183        }
184    }
185
186    /// Default-esque implementation
187    pub fn default(network: SharedNetwork<N>, beaver_source: BeaverSource<S>) -> Self {
188        Self::zero(network, beaver_source)
189    }
190
191    // Build a scalar from bytes
192    macros::impl_delegated_wrapper!(
193        Scalar,
194        from_bytes_mod_order,
195        from_bytes_mod_order_with_visibility,
196        bytes,
197        [u8; 32]
198    );
199    macros::impl_delegated_wrapper!(
200        Scalar,
201        from_bytes_mod_order_wide,
202        from_bytes_mod_order_wide_with_visibility,
203        input,
204        &[u8; 64]
205    );
206
207    pub fn from_canonical_bytes(
208        bytes: [u8; 32],
209        network: SharedNetwork<N>,
210        beaver_source: BeaverSource<S>,
211    ) -> Option<MpcScalar<N, S>> {
212        Self::from_canonical_bytes_with_visibility(
213            bytes,
214            Visibility::Public,
215            network,
216            beaver_source,
217        )
218    }
219
220    pub fn from_canonical_bytes_with_visibility(
221        bytes: [u8; 32],
222        visibility: Visibility,
223        network: SharedNetwork<N>,
224        beaver_source: BeaverSource<S>,
225    ) -> Option<MpcScalar<N, S>> {
226        Some(MpcScalar {
227            visibility,
228            network,
229            beaver_source,
230            value: Scalar::from_canonical_bytes(bytes)?,
231        })
232    }
233
234    macros::impl_delegated_wrapper!(
235        Scalar,
236        from_bits,
237        from_bits_with_visibility,
238        bytes,
239        [u8; 32]
240    );
241
242    // Convert a scalar to bytes
243    macros::impl_delegated!(to_bytes, self, [u8; 32]);
244    macros::impl_delegated!(as_bytes, self, &[u8; 32]);
245    // Check whether the scalar is canonically represented mod l
246    macros::impl_delegated!(is_canonical, self, bool);
247    // Generate the additive identity
248    macros::impl_delegated_wrapper!(Scalar, zero);
249    // Generate the multiplicative identity
250    macros::impl_delegated_wrapper!(Scalar, one);
251}
252
253/**
254 * Secret sharing implementation
255 */
256impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> MpcScalar<N, S> {
257    /// From a privately held value, construct an additive secret share and distribute this
258    /// to the counterparty. The local party samples a random value R which is given to the peer
259    /// The local party then holds a - R where a is the underlying value.
260    /// This method is called by both parties, only one of which transmits, the peer will simply
261    /// await the sent share
262    pub fn share_secret(&self, party_id: u64) -> Result<MpcScalar<N, S>, MpcNetworkError> {
263        let my_party_id = self.network.as_ref().borrow().party_id();
264
265        if my_party_id == party_id {
266            // Sender party
267            // Sample a random additive complement
268            let mut rng = OsRng {};
269            let random_share = Scalar::random(&mut rng);
270
271            // Broadcast the counterparty's share
272            Handle::current().block_on(
273                self.network
274                    .as_ref()
275                    .borrow_mut()
276                    .send_single_scalar(random_share),
277            )?;
278
279            // Do not subtract directly as the random scalar is not directly allocated in the network
280            // subtracting directly ties it to the subtraction implementation in a fragile way
281            Ok(MpcScalar {
282                value: self.value - random_share,
283                visibility: Visibility::Shared,
284                network: self.network.clone(),
285                beaver_source: self.beaver_source.clone(),
286            })
287        } else {
288            Self::receive_value(self.network.clone(), self.beaver_source.clone())
289        }
290    }
291
292    /// Share a batch of privately held secrets by constructing additive shares
293    pub fn batch_share_secrets(
294        party_id: u64,
295        secrets: &[MpcScalar<N, S>],
296    ) -> Result<Vec<MpcScalar<N, S>>, MpcNetworkError> {
297        assert!(
298            secrets.iter().all(|secret| secret.is_private()),
299            "Values to be shared must be in private state"
300        );
301
302        if secrets.is_empty() {
303            return Ok(Vec::new());
304        }
305
306        let network = secrets[0].network();
307        let beaver_source = secrets[0].beaver_source();
308        let my_party_id = network.as_ref().borrow().party_id();
309
310        if my_party_id == party_id {
311            // Sender party
312            let mut rng = OsRng {};
313            let random_shares: Vec<Scalar> = (0..secrets.len())
314                .map(|_| Scalar::random(&mut rng))
315                .collect();
316
317            // Broadcast the random shares to the peer
318            Handle::current()
319                .block_on(network.as_ref().borrow_mut().send_scalars(&random_shares))?;
320
321            Ok(secrets
322                .iter()
323                .zip(random_shares.iter())
324                .map(|(secret, blinding)| MpcScalar {
325                    value: secret.value() - blinding,
326                    visibility: Visibility::Shared,
327                    network: network.clone(),
328                    beaver_source: beaver_source.clone(),
329                })
330                .collect())
331        } else {
332            Self::batch_receive_values(secrets.len(), network, beaver_source)
333        }
334    }
335
336    /// Local party receives a secret share of a value; as opposed to using share_secret, no existing value is needed
337    pub fn receive_value(
338        network: SharedNetwork<N>,
339        beaver_source: BeaverSource<S>,
340    ) -> Result<MpcScalar<N, S>, MpcNetworkError> {
341        let value =
342            Handle::current().block_on(network.as_ref().borrow_mut().receive_single_scalar())?;
343
344        Ok(MpcScalar {
345            value,
346            visibility: Visibility::Shared,
347            network,
348            beaver_source,
349        })
350    }
351
352    /// Local party receives a batch of shared values
353    pub fn batch_receive_values(
354        num_expected: usize,
355        network: SharedNetwork<N>,
356        beaver_source: BeaverSource<S>,
357    ) -> Result<Vec<MpcScalar<N, S>>, MpcNetworkError> {
358        let values = Handle::current()
359            .block_on(network.as_ref().borrow_mut().receive_scalars(num_expected))?;
360
361        Ok(values
362            .iter()
363            .map(|value| MpcScalar {
364                value: *value,
365                visibility: Visibility::Shared,
366                network: network.clone(),
367                beaver_source: beaver_source.clone(),
368            })
369            .collect())
370    }
371
372    /// From a shared value, both parties open their shares and construct the plaintext value.
373    /// Note that the parties no longer hold valid additive secret shares of the value, this is used
374    /// at the end of a computation
375    pub fn open(&self) -> Result<MpcScalar<N, S>, MpcNetworkError> {
376        assert!(!self.is_private(), "Private values may not be opened...");
377        if self.is_public() {
378            return Ok(self.clone());
379        }
380
381        // Send my scalar and expect one back
382        let received_scalar = Handle::current().block_on(
383            self.network
384                .as_ref()
385                .borrow_mut()
386                .broadcast_single_scalar(self.value),
387        )?;
388
389        // Reconstruct the plaintext from the peer's share
390        Ok(MpcScalar::from_public_scalar(
391            self.value + received_scalar,
392            self.network.clone(),
393            self.beaver_source.clone(),
394        ))
395    }
396
397    /// Open a batch of shared values
398    pub fn batch_open(values: &[MpcScalar<N, S>]) -> Result<Vec<MpcScalar<N, S>>, MpcNetworkError> {
399        assert!(
400            values.iter().all(|value| !value.is_private()),
401            "Private values may not be opened..."
402        );
403
404        if values.is_empty() {
405            return Ok(Vec::new());
406        }
407
408        let network = values[0].network();
409        let beaver_source = values[0].beaver_source();
410
411        // Both parties share their values
412        let received_scalars = Handle::current().block_on(
413            network.as_ref().borrow_mut().broadcast_scalars(
414                &values
415                    .iter()
416                    .map(|value| value.value())
417                    .collect::<Vec<Scalar>>(),
418            ),
419        )?;
420
421        Ok(values
422            .iter()
423            .zip(received_scalars.iter())
424            .map(|(my_share, peer_share)| {
425                if my_share.is_public() {
426                    return my_share.clone();
427                }
428
429                MpcScalar::from_public_scalar(
430                    my_share.value() + peer_share,
431                    network.clone(),
432                    beaver_source.clone(),
433                )
434            })
435            .collect())
436    }
437
438    /// From a shared value:
439    ///     1. Commit to the value and exchange commitments
440    ///     2. Open those commitments to the underlying value
441    ///     3. Verify that the peer's opening matches their commitment
442    pub fn commit_and_open(&self) -> Result<MpcScalar<N, S>, MpcError> {
443        assert!(!self.is_private(), "Private values may not be opened...");
444        if self.is_public() {
445            return Ok(self.clone());
446        }
447
448        // Compute a Pedersen commitment to the value
449        let commitment = PedersenCommitment::commit(self.to_scalar());
450        let peer_commitment = Handle::current()
451            .block_on(
452                self.network()
453                    .as_ref()
454                    .borrow_mut()
455                    .broadcast_single_point(commitment.get_commitment()),
456            )
457            .map_err(MpcError::NetworkError)?;
458
459        // Open the commitment to the underlying value
460        let received_scalars = Handle::current()
461            .block_on(
462                self.network()
463                    .as_ref()
464                    .borrow_mut()
465                    .broadcast_scalars(&[commitment.get_blinding(), commitment.get_value()]),
466            )
467            .map_err(MpcError::NetworkError)?;
468
469        let (peer_blinding, peer_value) = (received_scalars[0], received_scalars[1]);
470
471        // Verify the commitment and return the opened value
472        if !PedersenCommitment::verify_from_values(peer_commitment, peer_blinding, peer_value) {
473            return Err(MpcError::AuthenticationError);
474        }
475
476        Ok(Self {
477            value: self.value() + peer_value,
478            visibility: Visibility::Public,
479            network: self.network(),
480            beaver_source: self.beaver_source(),
481        })
482    }
483
484    /// Commit to and open a batch of secret shared values
485    pub fn batch_commit_and_open(
486        values: &[MpcScalar<N, S>],
487    ) -> Result<Vec<MpcScalar<N, S>>, MpcError> {
488        assert!(
489            values.iter().all(|value| !value.is_private()),
490            "Private values may not be opened...",
491        );
492
493        if values.is_empty() {
494            return Ok(Vec::new());
495        }
496
497        let network = values[0].network();
498        let beaver_source = values[0].beaver_source();
499
500        // Generate commitments to the values and share them with the peer
501        let commitments: Vec<PedersenCommitment> = values
502            .iter()
503            .map(|value| PedersenCommitment::commit(value.to_scalar()))
504            .collect();
505        let peer_commitments = Handle::current()
506            .block_on(
507                network.as_ref().borrow_mut().broadcast_points(
508                    &commitments
509                        .iter()
510                        .map(|comm| comm.get_commitment())
511                        .collect::<Vec<RistrettoPoint>>(),
512                ),
513            )
514            .map_err(MpcError::NetworkError)?;
515
516        // Open both the underlying values and the blinding factors
517        let mut commitment_data: Vec<Scalar> = Vec::new();
518        commitments.iter().for_each(|comm| {
519            commitment_data.push(comm.get_blinding());
520            commitment_data.push(comm.get_value());
521        });
522
523        let received_values = Handle::current()
524            .block_on(
525                network
526                    .as_ref()
527                    .borrow_mut()
528                    .broadcast_scalars(&commitment_data),
529            )
530            .map_err(MpcError::NetworkError)?;
531
532        // Verify the peer's commitments
533        let mut peer_values: Vec<Scalar> = Vec::new();
534        received_values
535            .chunks(2 /* chunk_size */) // Fetch each pair of blinding, value
536            .zip(peer_commitments.into_iter())
537            .try_for_each(|(revealed_values, comm)| {
538                // Destructure the received payload and append to the peer values vector
539                let (blinding, value) = (revealed_values[0], revealed_values[1]);
540                peer_values.push(value);
541
542                // Verify the Pedersen commitment, report an authentication error if opening fails
543                if !PedersenCommitment::verify_from_values(comm, blinding, value) {
544                    return Err(MpcError::AuthenticationError);
545                }
546
547                Ok(())
548            })?;
549
550        // If the commitments open properly then add shares together to recover cleartext
551        Ok(values
552            .iter()
553            .zip(peer_values)
554            .map(|(my_value, peer_value)| {
555                if my_value.is_public() {
556                    return my_value.clone();
557                }
558
559                MpcScalar {
560                    value: my_value.value() + peer_value,
561                    visibility: Visibility::Public,
562                    network: network.clone(),
563                    beaver_source: beaver_source.clone(),
564                }
565            })
566            .collect())
567    }
568
569    /// Retrieves the next Beaver triplet from the Beaver source and allocates the values within the network
570    fn next_beaver_triplet(&self) -> (MpcScalar<N, S>, MpcScalar<N, S>, MpcScalar<N, S>) {
571        let (a, b, c) = self.beaver_source.as_ref().borrow_mut().next_triplet();
572
573        (
574            MpcScalar::from_scalar_with_visibility(
575                a,
576                Visibility::Shared,
577                self.network.clone(),
578                self.beaver_source.clone(),
579            ),
580            MpcScalar::from_scalar_with_visibility(
581                b,
582                Visibility::Shared,
583                self.network.clone(),
584                self.beaver_source.clone(),
585            ),
586            MpcScalar::from_scalar_with_visibility(
587                c,
588                Visibility::Shared,
589                self.network.clone(),
590                self.beaver_source.clone(),
591            ),
592        )
593    }
594
595    /// Retrieves the next Beaver triplet batch from the Beaver source and allocates the value in the network
596    #[allow(clippy::type_complexity)]
597    fn next_beaver_triplet_batch(
598        &self,
599        num_triplets: usize,
600    ) -> Vec<(MpcScalar<N, S>, MpcScalar<N, S>, MpcScalar<N, S>)> {
601        let triplet_batch = self
602            .beaver_source
603            .as_ref()
604            .borrow_mut()
605            .next_triplet_batch(num_triplets);
606
607        // Allocate values as shared in the network
608        triplet_batch
609            .iter()
610            .map(|(a, b, c)| {
611                (
612                    MpcScalar::from_scalar_with_visibility(
613                        *a,
614                        Visibility::Shared,
615                        self.network.clone(),
616                        self.beaver_source.clone(),
617                    ),
618                    MpcScalar::from_scalar_with_visibility(
619                        *b,
620                        Visibility::Shared,
621                        self.network.clone(),
622                        self.beaver_source.clone(),
623                    ),
624                    MpcScalar::from_scalar_with_visibility(
625                        *c,
626                        Visibility::Shared,
627                        self.network.clone(),
628                        self.beaver_source.clone(),
629                    ),
630                )
631            })
632            .collect::<Vec<_>>()
633    }
634}
635
636/**
637 * Generic trait implementations
638 */
639impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Visible for MpcScalar<N, S> {
640    fn visibility(&self) -> Visibility {
641        self.visibility
642    }
643}
644
645impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> PartialEq for MpcScalar<N, S> {
646    fn eq(&self, other: &Self) -> bool {
647        self.value.eq(&other.value)
648    }
649}
650
651impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> ConstantTimeEq for MpcScalar<N, S> {
652    fn ct_eq(&self, other: &Self) -> subtle::Choice {
653        self.value.ct_eq(&other.value)
654    }
655}
656
657impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Index<usize> for MpcScalar<N, S> {
658    type Output = u8;
659
660    fn index(&self, index: usize) -> &Self::Output {
661        self.value.index(index)
662    }
663}
664
665impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Clear for &mut MpcScalar<N, S> {
666    #[allow(clippy::needless_borrow)]
667    fn clear(&mut self) {
668        (&mut self.value).clear();
669    }
670}
671
672/**
673 * Mul and variants for: borrowed, non-borrowed, and Scalar types
674 */
675
676/// Implementation of mul with the beaver trick
677/// This implementation panics in the case of a network error.
678/// Ideally this is done in a thread where the panic can be handled by the parent.
679impl<'a, N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Mul<&'a MpcScalar<N, S>>
680    for &'a MpcScalar<N, S>
681{
682    type Output = MpcScalar<N, S>;
683
684    /// Multiplies two (possibly shared) values. The only case in which we need a Beaver trick
685    /// is when both lhs and rhs are Shared. If only one is shared, multiplying by a public value
686    /// directly leads to an additive sharing. If both are public, we do not need an additive share.
687    /// TODO(@joey): What is the correct behavior when one or both of lhs and rhs are private
688    ///
689    /// See https://securecomputation.org/docs/pragmaticmpc.pdf (Section 3.4) for the identities this
690    /// implementation makes use of.
691    fn mul(self, rhs: &'a MpcScalar<N, S>) -> Self::Output {
692        if self.is_shared() && rhs.is_shared() {
693            let (a, b, c) = self.next_beaver_triplet();
694
695            // Open the values d = [lhs - a] and e = [rhs - b]
696            let opened_values = MpcScalar::batch_open(&[(self - &a), (rhs - &b)]).unwrap();
697            let lhs_minus_a = &opened_values[0];
698            let rhs_minus_b = &opened_values[1];
699
700            // Identity: [a * b] = de + d[b] + e[a] + [c]
701            // All multiplications here are between a public and shared value or
702            // two public values, so the recursion will not hit this case
703            let mut res = lhs_minus_a * &b + rhs_minus_b * &a + c;
704
705            // Split into additive shares, the king holds de + res
706            if self.network.as_ref().borrow().am_king() {
707                res += lhs_minus_a * rhs_minus_b;
708            }
709
710            res
711        } else {
712            // Directly multiply
713            MpcScalar {
714                visibility: Visibility::min_visibility_two(self, rhs),
715                network: self.network.clone(),
716                beaver_source: self.beaver_source.clone(),
717                value: self.value * rhs.value,
718            }
719        }
720    }
721}
722
723// Multiplication with a scalar value is equivalent to a public multiplication, no Beaver
724// trick needed
725macros::impl_operator_variants!(MpcScalar<N, S>, Mul, mul, *, MpcScalar<N, S>);
726macros::impl_wrapper_type!(MpcScalar<N, S>, Scalar, MpcScalar::from_public_scalar, Mul, mul, *, authenticated=false);
727macros::impl_arithmetic_assign!(MpcScalar<N, S>, MulAssign, mul_assign, *, MpcScalar<N, S>);
728macros::impl_arithmetic_assign!(MpcScalar<N, S>, MulAssign, mul_assign, *, Scalar);
729
730/**
731 * Batch multiply allowing for batches of communication
732 */
733
734impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> MpcScalar<N, S> {
735    /// Returns the result [a_1 * b_1, ..., a_n * b_n]
736    ///
737    /// This method is not meant to be used directly, instead, it should be called
738    /// through the MPC fabric which will inject `am_king` and `beaver_source`
739    pub fn batch_mul(
740        a: &[MpcScalar<N, S>],
741        b: &[MpcScalar<N, S>],
742    ) -> Result<Vec<MpcScalar<N, S>>, MpcNetworkError> {
743        assert_eq!(
744            a.len(),
745            b.len(),
746            "input arrays to batch_mul must be of equal length"
747        );
748
749        if a.is_empty() {
750            return Ok(Vec::new());
751        }
752
753        let n = a.len();
754        let mut res = Vec::with_capacity(n);
755
756        // If one (or both) of a and b is public, it can be multiplied locally
757        // so we first separate out these values to avoid unnecessary computation/communication
758        let mut beaver_mul_pairs = Vec::new();
759        for i in 0..a.len() {
760            if !a[i].is_public() && !b[i].is_public() {
761                beaver_mul_pairs.push((&a[i], &b[i]))
762            }
763        }
764
765        // For each of the multiplications that requires a beaver-style mul; sample a multiplication triplet
766        let num_beaver_muls = beaver_mul_pairs.len();
767        let mut beaver_triplets = a[0].next_beaver_triplet_batch(num_beaver_muls);
768
769        // Tile a payload buffer with the beaver openings then share
770        let mut beaver_subs = Vec::with_capacity(2 * n);
771        beaver_mul_pairs
772            .iter()
773            .zip(beaver_triplets.iter())
774            .for_each(|((a_val, b_val), (beaver_a, beaver_b, _))| {
775                beaver_subs.push(*a_val - beaver_a);
776                beaver_subs.push(*b_val - beaver_b);
777            });
778
779        // Open the tiled beaver subtractions
780        let mut opened_beaver_subs = if num_beaver_muls == 0 {
781            Vec::new()
782        } else {
783            MpcScalar::batch_open(&beaver_subs)?
784        };
785        for i in 0..n {
786            if a[i].is_public() || b[i].is_public() {
787                res.push(&a[i] * &b[i])
788            } else {
789                // Fetch the next opening of a beaver sub
790                let (lhs_minus_a, rhs_minus_b) =
791                    (opened_beaver_subs.remove(0), opened_beaver_subs.remove(0));
792
793                let (beaver_a, beaver_b, beaver_c) = beaver_triplets.remove(0);
794
795                // Perform the multiplication and place it in the result
796                // Identity: [a * b] = de + d[b] + e[a] + [c]
797                // All multiplications here are between a public and shared value or
798                // two public values, so the recursion will not hit this case
799                let result = &lhs_minus_a * &beaver_b
800                    + &rhs_minus_b * &beaver_a
801                    + lhs_minus_a * rhs_minus_b
802                    + &beaver_c;
803
804                res.push(result);
805            }
806        }
807
808        Ok(res)
809    }
810}
811
812/**
813 * Add and variants for: borrowed, non-borrowed, and scalar types
814 */
815impl<'a, N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Add<&'a MpcScalar<N, S>>
816    for &'a MpcScalar<N, S>
817{
818    type Output = MpcScalar<N, S>;
819
820    fn add(self, rhs: &'a MpcScalar<N, S>) -> Self::Output {
821        // If public + shared swap the arguments for simplicity
822        if self.is_public() && rhs.is_shared() {
823            return rhs + self;
824        }
825
826        // If both values are public; both parties add the values together to obtain
827        // a public result.
828        // If both values are shared; both parties add the shared values together to
829        // obtain a shared result.
830        // If only one value is public, the king adds the public valid to her share
831        // I.e. if the parties hold an additive sharing of a = a_1 + a_2 and with to
832        // add public b; the king now holds a_1 + b and the peer holds a_2. Effectively
833        // they construct an implicit secret sharing of b where b_1 = b and b_2 = 0
834        let am_king = self.network.as_ref().borrow().am_king();
835
836        let res = {
837            if self.is_public() && rhs.is_public() ||  // Both public
838                self.is_shared() && rhs.is_shared() ||  // Both shared
839                am_king
840            // One public, but local peer is king
841            {
842                self.value() + rhs.value()
843            } else {
844                self.value()
845            }
846        };
847
848        MpcScalar {
849            value: res,
850            visibility: Visibility::min_visibility_two(self, rhs),
851            network: self.network.clone(),
852            beaver_source: self.beaver_source.clone(),
853        }
854    }
855}
856
857macros::impl_operator_variants!(MpcScalar<N, S>, Add, add, +, MpcScalar<N, S>);
858macros::impl_wrapper_type!(MpcScalar<N, S>, Scalar, MpcScalar::from_public_scalar, Add, add, +, authenticated=false);
859macros::impl_arithmetic_assign!(MpcScalar<N, S>, AddAssign, add_assign, +, MpcScalar<N, S>);
860macros::impl_arithmetic_assign!(MpcScalar<N, S>, AddAssign, add_assign, +, Scalar);
861
862/**
863 * Sub and variants for: borrowed, non-borrowed, and scalar types
864 */
865impl<'a, N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Sub<&'a MpcScalar<N, S>>
866    for &'a MpcScalar<N, S>
867{
868    type Output = MpcScalar<N, S>;
869
870    #[allow(clippy::suspicious_arithmetic_impl)]
871    fn sub(self, rhs: &'a MpcScalar<N, S>) -> Self::Output {
872        self + rhs.neg()
873    }
874}
875
876macros::impl_operator_variants!(MpcScalar<N, S>, Sub, sub, -, MpcScalar<N, S>);
877macros::impl_wrapper_type!(MpcScalar<N, S>, Scalar, MpcScalar::from_public_scalar, Sub, sub, -, authenticated=false);
878macros::impl_arithmetic_assign!(MpcScalar<N, S>, SubAssign, sub_assign, -, MpcScalar<N, S>);
879macros::impl_arithmetic_assign!(MpcScalar<N, S>, SubAssign, sub_assign, -, Scalar);
880
881impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Neg for MpcScalar<N, S> {
882    type Output = MpcScalar<N, S>;
883
884    fn neg(self) -> Self::Output {
885        (&self).neg()
886    }
887}
888
889impl<'a, N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Neg for &'a MpcScalar<N, S> {
890    type Output = MpcScalar<N, S>;
891
892    fn neg(self) -> Self::Output {
893        MpcScalar {
894            visibility: self.visibility,
895            network: self.network.clone(),
896            beaver_source: self.beaver_source.clone(),
897            value: self.value.neg(),
898        }
899    }
900}
901
902/**
903 * Iterator traits
904 */
905
906/// TODO: Optimize this to use tree-structured round parallelism
907impl<N, S, T> Product<T> for MpcScalar<N, S>
908where
909    N: MpcNetwork + Send,
910    S: SharedValueSource<Scalar>,
911    T: Borrow<MpcScalar<N, S>>,
912{
913    fn product<I: Iterator<Item = T>>(iter: I) -> Self {
914        let mut peekable = iter.peekable();
915        let first_elem = peekable.peek().unwrap();
916        let network: SharedNetwork<N> = first_elem.borrow().network.clone();
917        let beaver_source: BeaverSource<S> = first_elem.borrow().beaver_source.clone();
918
919        peekable.fold(MpcScalar::one(network, beaver_source), |acc, item| {
920            acc * item.borrow()
921        })
922    }
923}
924
925impl<N, S, T> Sum<T> for MpcScalar<N, S>
926where
927    N: MpcNetwork + Send,
928    S: SharedValueSource<Scalar>,
929    T: Borrow<MpcScalar<N, S>>,
930{
931    fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
932        // This operation is invalid on an empty iterator, unwrap is expected
933        let mut peekable = iter.peekable();
934        let first_elem = peekable.peek().unwrap();
935        let network = first_elem.borrow().network.clone();
936        let beaver_source: BeaverSource<S> = first_elem.borrow().beaver_source.clone();
937
938        peekable.fold(
939            MpcScalar::from_u64_with_visibility(0, Visibility::Shared, network, beaver_source),
940            |acc, item| acc + item.borrow(),
941        )
942    }
943}
944
945impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> MpcScalar<N, S> {
946    /// Takes a linear combination of the input scalars
947    pub fn linear_combination(
948        scalars: &[MpcScalar<N, S>],
949        coeffs: &[MpcScalar<N, S>],
950    ) -> Result<MpcScalar<N, S>, MpcNetworkError> {
951        Ok(MpcScalar::batch_mul(scalars, coeffs)?.iter().sum())
952    }
953}
954
955impl<N: MpcNetwork + Send, S: SharedValueSource<Scalar>> Zeroize for MpcScalar<N, S> {
956    fn zeroize(&mut self) {
957        self.value.zeroize()
958    }
959}
960
/// For these tests, we explicitly build a tokio runtime and spawn tests as blocking
/// tasks within the runtime. This allows us to block in the execution of a test without
/// blocking the tokio async driver.
///
/// The tests assume the dummy network answers each opening with the next scalar queued
/// via `add_mock_scalars`, standing in for the peer's share. An opened value is therefore
/// the local share plus the mocked peer share.
#[cfg(test)]
mod test {
    use std::{cell::RefCell, rc::Rc};

    use clear_on_drop::clear::Clear;
    use curve25519_dalek::scalar::Scalar;
    use rand_core::OsRng;
    use tokio::runtime::{Builder as RuntimeBuilder, Runtime};

    use crate::{beaver::DummySharedScalarSource, network::dummy_network::DummyMpcNetwork};

    use super::{MpcScalar, Visibility};

    /// Helper to create a tokio runtime that allows the implementation to block on async network
    /// results
    fn create_blockable_runtime() -> Runtime {
        RuntimeBuilder::new_multi_thread()
            .enable_all()
            .worker_threads(1)
            .max_blocking_threads(1)
            .build()
            .unwrap()
    }

    #[test]
    fn test_zero() {
        let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
        let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));

        let expected =
            MpcScalar::from_public_scalar(Scalar::zero(), network.clone(), beaver_source.clone());
        let zero = MpcScalar::zero(network, beaver_source);

        assert_eq!(zero, expected);
    }

    #[test]
    fn test_open() {
        let rt = create_blockable_runtime();
        let handle = rt.spawn_blocking(|| {
            let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(1u8)]);

            let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));
            let expected = MpcScalar::from_public_scalar(
                Scalar::from(2u8),
                network.clone(),
                beaver_source.clone(),
            );

            // The mocked peer share is Scalar(1) and the local share is Scalar(1), so the
            // shared value opens to Scalar(2)
            let my_share = MpcScalar::from_u64_with_visibility(
                1u64,
                Visibility::Shared,
                network,
                beaver_source,
            );
            assert_eq!(my_share.open().unwrap(), expected);
        });

        rt.block_on(handle).unwrap();
    }

    #[test]
    fn test_add() {
        let rt = create_blockable_runtime();
        let handle = rt.spawn_blocking(|| {
            let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(2u8)]);

            let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));

            // Assume that parties hold a secret share of [4] as individual shares of 2 each
            let shared_value1 = MpcScalar::from_u64_with_visibility(
                2u64,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );

            // Test adding a scalar value first
            let res = &shared_value1 + Scalar::from(3u64); // [4] + 3
            assert_eq!(res.visibility, Visibility::Shared);
            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(7u64, network.clone(), beaver_source.clone())
            );

            // Test adding another shared value
            // Assume now that parties have additive shares of [5]
            // The peer holds 1, the local party holds 4
            let shared_value2 = MpcScalar::from_u64_with_visibility(
                4u64,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );

            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(3u8)]); // The peer's share of [4] + [5]

            let res = shared_value1 + shared_value2;
            assert_eq!(res.visibility, Visibility::Shared);
            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(9, network, beaver_source)
            )
        });

        rt.block_on(handle).unwrap();
    }

    #[test]
    fn test_add_commutative() {
        let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
        let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));

        // Add a random MpcScalar and a random raw Scalar in both operand orders and
        // ensure the results agree (commutativity of mixed addition)
        let mut rng = OsRng {};
        let v1 = MpcScalar::random(&mut rng, network, beaver_source);
        let v2 = Scalar::random(&mut rng);

        let res1 = &v1 + v2;
        let res2 = v2 + &v1;

        assert_eq!(res1, res2);
    }

    #[test]
    fn test_sub() {
        let rt = create_blockable_runtime();
        let handle = rt.spawn_blocking(|| {
            let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
            let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));

            // Subtract a raw scalar from a shared value
            // Assume parties hold secret shares 2 (local) and 1 (peer) of [3]; the mock
            // below is the peer's share of [3] - 2, which opens to 1
            let shared_value1 = MpcScalar::from_u64_with_visibility(
                2u64,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(1u8)]);

            let res = &shared_value1 - Scalar::from(2u8);
            assert_eq!(res.visibility, Visibility::Shared);
            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(1u64, network.clone(), beaver_source.clone())
            );

            // Subtract two shared values
            // Assume parties hold [8] as shares 5 (local) and 3 (peer); the peer's share of
            // [8] - [3] is 3 - 1 = 2, which is mocked below, so the result opens to 5
            let shared_value2 = MpcScalar::from_u64_with_visibility(
                5,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(2u8)]);

            let res = shared_value2 - shared_value1;
            assert_eq!(res.visibility, Visibility::Shared);
            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(5, network, beaver_source)
            )
        });

        rt.block_on(handle).unwrap();
    }

    #[test]
    fn test_mul() {
        let rt = create_blockable_runtime();
        let handle = rt.spawn_blocking(|| {
            let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
            let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));

            // Multiply a scalar with a shared value
            // Assume both parties have a sharing of [11], local party holds 6
            let shared_value1 = MpcScalar::from_u64_with_visibility(
                6u64,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );

            // Populate the network mock after the multiplication; this implicitly asserts that
            // no network call was used for multiplying by a scalar (assumed public)
            let res = &shared_value1 * Scalar::from(2u8);
            assert_eq!(res.visibility, Visibility::Shared);

            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(10u8)]);

            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(22, network.clone(), beaver_source.clone())
            );

            // Multiply a shared value with a public value
            let public_value =
                MpcScalar::from_public_u64(3u64, network.clone(), beaver_source.clone());

            // As above, populate the network mock after the multiplication
            let res = public_value * &shared_value1;
            assert_eq!(res.visibility, Visibility::Shared);

            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(15u8)]);
            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(33u64, network.clone(), beaver_source.clone())
            );

            // Multiply two shared values, a beaver triplet (a, b, c) will be needed
            // Populate the network mock with two openings:
            //      1. [shared1 - a]
            //      2. [shared2 - b]
            // Assume that the parties hold [shared2] = [12] where the peer holds 7 and the local holds 5
            let shared_value2 = MpcScalar::from_u64_with_visibility(
                5u64,
                Visibility::Shared,
                network.clone(),
                beaver_source.clone(),
            );
            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(5u8), Scalar::from(7u8)]);

            // Populate the network with the peer's res share after the computation
            let res = shared_value1 * shared_value2;
            assert_eq!(res.visibility, Visibility::Shared);

            network
                .borrow_mut()
                .add_mock_scalars(vec![Scalar::from(0u64)]);

            assert_eq!(
                res.open().unwrap(),
                MpcScalar::from_public_u64(12 * 11, network, beaver_source)
            );
        });

        rt.block_on(handle).unwrap();
    }

    #[tokio::test]
    async fn test_clear() {
        let network = Rc::new(RefCell::new(DummyMpcNetwork::new()));
        let beaver_source = Rc::new(RefCell::new(DummySharedScalarSource::new()));
        let mut value = MpcScalar::from_public_u64(2, network, beaver_source);

        (&mut value).clear();
        assert_eq!(value.value(), Scalar::zero());
    }
}