mpc_stark/algebra/authenticated_scalar.rs

//! Defines the authenticated (maliciously secure) variant of the MPC scalar type
2
3use std::{
4    fmt::Debug,
5    iter::Sum,
6    ops::{Add, Mul, Neg, Sub},
7    pin::Pin,
8    task::{Context, Poll},
9};
10
11use futures::{Future, FutureExt};
12use itertools::{izip, Itertools};
13
14use crate::{
15    commitment::{PedersenCommitment, PedersenCommitmentResult},
16    error::MpcError,
17    fabric::{MpcFabric, ResultId, ResultValue},
18    ResultHandle, PARTY0,
19};
20
21use super::{
22    authenticated_stark_point::AuthenticatedStarkPointResult,
23    macros::{impl_borrow_variants, impl_commutative},
24    mpc_scalar::MpcScalarResult,
25    scalar::{BatchScalarResult, Scalar, ScalarResult},
26    stark_curve::{StarkPoint, StarkPointResult},
27};
28
/// How many underlying result handles make up a single `AuthenticatedScalarResult`:
/// its share, its MAC share, and its public modifier (flattened in that order)
pub const AUTHENTICATED_SCALAR_RESULT_LEN: usize = 3;
31
32/// A maliciously secure wrapper around an `MpcScalarResult`, includes a MAC as per the
33/// SPDZ protocol: https://eprint.iacr.org/2011/535.pdf
34/// that ensures security against a malicious adversary
35#[derive(Clone)]
36pub struct AuthenticatedScalarResult {
37    /// The secret shares of the underlying value
38    pub(crate) share: MpcScalarResult,
39    /// The SPDZ style, unconditionally secure MAC of the value
40    ///
41    /// If the value is `x`, parties hold secret shares of the value
42    /// \delta * x for the global MAC key `\delta`. The parties individually
43    /// hold secret shares of this MAC key [\delta], so we can very naturally
44    /// extend the secret share arithmetic of the underlying `MpcScalarResult` to
45    /// the MAC updates as well
46    pub(crate) mac: MpcScalarResult,
47    /// The public modifier tracks additions and subtractions of public values to the
48    /// underlying value. This is necessary because in the case of a public addition, only the first
49    /// party adds the public value to their share, so the second party must track this up
50    /// until the point that the value is opened and the MAC is checked
51    pub(crate) public_modifier: ScalarResult,
52}
53
54impl Debug for AuthenticatedScalarResult {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct("AuthenticatedScalarResult")
57            .field("value", &self.share.id())
58            .field("mac", &self.mac.id())
59            .field("public_modifier", &self.public_modifier.id)
60            .finish()
61    }
62}
63
64impl AuthenticatedScalarResult {
65    /// Create a new result from the given shared value
66    pub fn new_shared(value: ScalarResult) -> Self {
67        // Create an `MpcScalarResult` to represent the fact that this is a shared value
68        let fabric = value.fabric.clone();
69
70        let mpc_value = MpcScalarResult::new_shared(value);
71        let mac = fabric.borrow_mac_key() * mpc_value.clone();
72
73        // Allocate a zero for the public modifier
74        let public_modifier = fabric.zero();
75
76        Self {
77            share: mpc_value,
78            mac,
79            public_modifier,
80        }
81    }
82
83    /// Create a new batch of shared values
84    pub fn new_shared_batch(values: &[ScalarResult]) -> Vec<Self> {
85        if values.is_empty() {
86            return vec![];
87        }
88
89        let n = values.len();
90        let fabric = values[0].fabric();
91        let mpc_values = values
92            .iter()
93            .map(|v| MpcScalarResult::new_shared(v.clone()))
94            .collect_vec();
95
96        let mac_keys = (0..n)
97            .map(|_| fabric.borrow_mac_key().clone())
98            .collect_vec();
99        let values_macs = MpcScalarResult::batch_mul(&mpc_values, &mac_keys);
100
101        mpc_values
102            .into_iter()
103            .zip(values_macs.into_iter())
104            .map(|(value, mac)| Self {
105                share: value,
106                mac,
107                public_modifier: fabric.zero(),
108            })
109            .collect_vec()
110    }
111
112    /// Create a nwe shared batch of values from a batch network result
113    ///
114    /// The batch result combines the batch into one result, so it must be split out
115    /// first before creating the `AuthenticatedScalarResult`s
116    pub fn new_shared_from_batch_result(
117        values: BatchScalarResult,
118        n: usize,
119    ) -> Vec<AuthenticatedScalarResult> {
120        // Convert to a set of scalar results
121        let scalar_results = values
122            .fabric()
123            .new_batch_gate_op(vec![values.id()], n, |mut args| {
124                let scalars: Vec<Scalar> = args.pop().unwrap().into();
125                scalars.into_iter().map(ResultValue::Scalar).collect()
126            });
127
128        Self::new_shared_batch(&scalar_results)
129    }
130
131    /// Get the raw share as an `MpcScalarResult`
132    #[cfg(feature = "test_helpers")]
133    pub fn mpc_share(&self) -> MpcScalarResult {
134        self.share.clone()
135    }
136
137    /// Get the raw share as a `ScalarResult`
138    pub fn share(&self) -> ScalarResult {
139        self.share.to_scalar()
140    }
141
142    /// Get a reference to the underlying MPC fabric
143    pub fn fabric(&self) -> &MpcFabric {
144        self.share.fabric()
145    }
146
147    /// Get the ids of the results that must be awaited
148    /// before the value is ready
149    pub fn ids(&self) -> Vec<ResultId> {
150        vec![self.share.id(), self.mac.id(), self.public_modifier.id]
151    }
152
153    /// Open the value without checking its MAC
154    pub fn open(&self) -> ScalarResult {
155        self.share.open()
156    }
157
158    /// Open a batch of values without checking their MACs
159    pub fn open_batch(values: &[Self]) -> Vec<ScalarResult> {
160        MpcScalarResult::open_batch(&values.iter().map(|val| val.share.clone()).collect_vec())
161    }
162
163    /// Convert a flattened iterator into a batch of `AuthenticatedScalarResult`s
164    ///
165    /// We assume that the iterator has been flattened in the same way order that `Self::id`s returns
166    /// the `AuthenticatedScalar`'s values: `[share, mac, public_modifier]`
167    pub fn from_flattened_iterator<I>(iter: I) -> Vec<Self>
168    where
169        I: Iterator<Item = ResultHandle<Scalar>>,
170    {
171        iter.chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
172            .into_iter()
173            .map(|mut chunk| Self {
174                share: chunk.next().unwrap().into(),
175                mac: chunk.next().unwrap().into(),
176                public_modifier: chunk.next().unwrap(),
177            })
178            .collect_vec()
179    }
180
181    /// Check the commitment to a MAC check and that the MAC checks sum to zero
182    pub fn verify_mac_check(
183        my_mac_share: Scalar,
184        peer_mac_share: Scalar,
185        peer_mac_commitment: StarkPoint,
186        peer_commitment_blinder: Scalar,
187    ) -> bool {
188        let their_comm = PedersenCommitment {
189            value: peer_mac_share,
190            blinder: peer_commitment_blinder,
191            commitment: peer_mac_commitment,
192        };
193
194        // Verify that the commitment to the MAC check opens correctly
195        if !their_comm.verify() {
196            return false;
197        }
198
199        // Sum of the commitments should be zero
200        if peer_mac_share + my_mac_share != Scalar::from(0) {
201            return false;
202        }
203
204        true
205    }
206
207    /// Open the value and check its MAC
208    ///
209    /// This follows the protocol detailed in:
210    ///     https://securecomputation.org/docs/pragmaticmpc.pdf
211    /// Section 6.6.2
212    pub fn open_authenticated(&self) -> AuthenticatedScalarOpenResult {
213        // Both parties open the underlying value
214        let recovered_value = self.share.open();
215
216        // Add a gate to compute the MAC check value: `key_share * opened_value - mac_share`
217        let mac_check_value: ScalarResult = self.fabric().new_gate_op(
218            vec![
219                self.fabric().borrow_mac_key().id(),
220                recovered_value.id,
221                self.public_modifier.id,
222                self.mac.id(),
223            ],
224            move |mut args| {
225                let mac_key_share: Scalar = args.remove(0).into();
226                let value: Scalar = args.remove(0).into();
227                let modifier: Scalar = args.remove(0).into();
228                let mac_share: Scalar = args.remove(0).into();
229
230                ResultValue::Scalar(mac_key_share * (value + modifier) - mac_share)
231            },
232        );
233
234        // Compute a commitment to this value and share it with the peer
235        let my_comm = PedersenCommitmentResult::commit(mac_check_value);
236        let peer_commit = self.fabric().exchange_value(my_comm.commitment);
237
238        // Once the parties have exchanged their commitments, they can open them, they have already exchanged
239        // the underlying values and their commitments so all that is left is the blinder
240        let peer_mac_check = self.fabric().exchange_value(my_comm.value.clone());
241
242        let blinder_result: ScalarResult = self.fabric().allocate_scalar(my_comm.blinder);
243        let peer_blinder = self.fabric().exchange_value(blinder_result);
244
245        // Check the commitment and the MAC result
246        let commitment_check: ScalarResult = self.fabric().new_gate_op(
247            vec![
248                my_comm.value.id,
249                peer_mac_check.id,
250                peer_blinder.id,
251                peer_commit.id,
252            ],
253            |mut args| {
254                let my_comm_value: Scalar = args.remove(0).into();
255                let peer_value: Scalar = args.remove(0).into();
256                let blinder: Scalar = args.remove(0).into();
257                let commitment: StarkPoint = args.remove(0).into();
258
259                // Build a commitment from the gate inputs
260                ResultValue::Scalar(Scalar::from(Self::verify_mac_check(
261                    my_comm_value,
262                    peer_value,
263                    commitment,
264                    blinder,
265                )))
266            },
267        );
268
269        AuthenticatedScalarOpenResult {
270            value: recovered_value,
271            mac_check: commitment_check,
272        }
273    }
274
275    /// Open a batch of values and check their MACs
276    pub fn open_authenticated_batch(values: &[Self]) -> Vec<AuthenticatedScalarOpenResult> {
277        if values.is_empty() {
278            return vec![];
279        }
280
281        let n = values.len();
282        let fabric = &values[0].fabric();
283
284        // Both parties open the underlying values
285        let values_open = Self::open_batch(values);
286
287        // --- Mac Checks --- //
288
289        // Compute the shares of the MAC check in batch
290        let mut mac_check_deps = Vec::with_capacity(1 + 3 * n);
291        mac_check_deps.push(fabric.borrow_mac_key().id());
292        for i in 0..n {
293            mac_check_deps.push(values_open[i].id());
294            mac_check_deps.push(values[i].public_modifier.id());
295            mac_check_deps.push(values[i].mac.id());
296        }
297
298        let mac_checks: Vec<ScalarResult> =
299            fabric.new_batch_gate_op(mac_check_deps, n /* output_arity */, move |mut args| {
300                let mac_key_share: Scalar = args.remove(0).into();
301                let mut check_result = Vec::with_capacity(n);
302
303                for _ in 0..n {
304                    let value: Scalar = args.remove(0).into();
305                    let modifier: Scalar = args.remove(0).into();
306                    let mac_share: Scalar = args.remove(0).into();
307
308                    check_result.push(mac_key_share * (value + modifier) - mac_share);
309                }
310
311                check_result.into_iter().map(ResultValue::Scalar).collect()
312            });
313
314        // --- Commit to MAC Checks --- //
315
316        let my_comms = mac_checks
317            .iter()
318            .cloned()
319            .map(PedersenCommitmentResult::commit)
320            .collect_vec();
321        let peer_comms = fabric.exchange_values(
322            &my_comms
323                .iter()
324                .map(|comm| comm.commitment.clone())
325                .collect_vec(),
326        );
327
328        // --- Exchange the MAC Checks and Commitment Blinders --- //
329
330        let peer_mac_checks = fabric.exchange_values(&mac_checks);
331        let peer_blinders = fabric.exchange_values(
332            &my_comms
333                .iter()
334                .map(|comm| fabric.allocate_scalar(comm.blinder))
335                .collect_vec(),
336        );
337
338        // --- Check the MAC Checks --- //
339
340        let mut mac_check_gate_deps = my_comms.iter().map(|comm| comm.value.id).collect_vec();
341        mac_check_gate_deps.push(peer_mac_checks.id);
342        mac_check_gate_deps.push(peer_blinders.id);
343        mac_check_gate_deps.push(peer_comms.id);
344
345        let commitment_checks: Vec<ScalarResult> = fabric.new_batch_gate_op(
346            mac_check_gate_deps,
347            n, /* output_arity */
348            move |mut args| {
349                let my_comms: Vec<Scalar> = args.drain(..n).map(|comm| comm.into()).collect();
350                let peer_mac_checks: Vec<Scalar> = args.remove(0).into();
351                let peer_blinders: Vec<Scalar> = args.remove(0).into();
352                let peer_comms: Vec<StarkPoint> = args.remove(0).into();
353
354                // Build a commitment from the gate inputs
355                let mut mac_checks = Vec::with_capacity(n);
356                for (my_mac_share, peer_mac_share, peer_blinder, peer_commitment) in izip!(
357                    my_comms.into_iter(),
358                    peer_mac_checks.into_iter(),
359                    peer_blinders.into_iter(),
360                    peer_comms.into_iter()
361                ) {
362                    let mac_check = Self::verify_mac_check(
363                        my_mac_share,
364                        peer_mac_share,
365                        peer_commitment,
366                        peer_blinder,
367                    );
368                    mac_checks.push(ResultValue::Scalar(Scalar::from(mac_check)));
369                }
370
371                mac_checks
372            },
373        );
374
375        // --- Return the results --- //
376
377        values_open
378            .into_iter()
379            .zip(commitment_checks.into_iter())
380            .map(|(value, check)| AuthenticatedScalarOpenResult {
381                value,
382                mac_check: check,
383            })
384            .collect_vec()
385    }
386}
387
388/// The value that results from opening an `AuthenticatedScalarResult` and checking its
389/// MAC. This encapsulates both the underlying value and the result of the MAC check
390#[derive(Clone)]
391pub struct AuthenticatedScalarOpenResult {
392    /// The underlying value
393    pub value: ScalarResult,
394    /// The result of the MAC check
395    pub mac_check: ScalarResult,
396}
397
398impl Future for AuthenticatedScalarOpenResult {
399    type Output = Result<Scalar, MpcError>;
400
401    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
402        // Await both of the underlying values
403        let value = futures::ready!(self.as_mut().value.poll_unpin(cx));
404        let mac_check = futures::ready!(self.as_mut().mac_check.poll_unpin(cx));
405
406        if mac_check == Scalar::from(1) {
407            Poll::Ready(Ok(value))
408        } else {
409            Poll::Ready(Err(MpcError::AuthenticationError))
410        }
411    }
412}
413
414// --------------
415// | Arithmetic |
416// --------------
417
418// === Addition === //
419
420impl Add<&Scalar> for &AuthenticatedScalarResult {
421    type Output = AuthenticatedScalarResult;
422
423    fn add(self, rhs: &Scalar) -> Self::Output {
424        let new_share = if self.fabric().party_id() == PARTY0 {
425            &self.share + rhs
426        } else {
427            &self.share + Scalar::from(0)
428        };
429
430        // Both parties add the public value to their modifier, and the MACs do not change
431        // when adding a public value
432        let new_modifier = &self.public_modifier - rhs;
433        AuthenticatedScalarResult {
434            share: new_share,
435            mac: self.mac.clone(),
436            public_modifier: new_modifier,
437        }
438    }
439}
440impl_borrow_variants!(AuthenticatedScalarResult, Add, add, +, Scalar, Output=AuthenticatedScalarResult);
441impl_commutative!(AuthenticatedScalarResult, Add, add, +, Scalar, Output=AuthenticatedScalarResult);
442
443impl Add<&ScalarResult> for &AuthenticatedScalarResult {
444    type Output = AuthenticatedScalarResult;
445
446    fn add(self, rhs: &ScalarResult) -> Self::Output {
447        // As above, only party 0 adds the public value to their share, but both parties
448        // track this with the modifier
449        //
450        // Party 1 adds a zero value to their share to allocate a new ID for the result
451        let new_share = if self.fabric().party_id() == PARTY0 {
452            &self.share + rhs
453        } else {
454            &self.share + Scalar::from(0)
455        };
456
457        let new_modifier = &self.public_modifier - rhs;
458        AuthenticatedScalarResult {
459            share: new_share,
460            mac: self.mac.clone(),
461            public_modifier: new_modifier,
462        }
463    }
464}
465impl_borrow_variants!(AuthenticatedScalarResult, Add, add, +, ScalarResult, Output=AuthenticatedScalarResult);
466impl_commutative!(AuthenticatedScalarResult, Add, add, +, ScalarResult, Output=AuthenticatedScalarResult);
467
468impl Add<&AuthenticatedScalarResult> for &AuthenticatedScalarResult {
469    type Output = AuthenticatedScalarResult;
470
471    fn add(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
472        AuthenticatedScalarResult {
473            share: &self.share + &rhs.share,
474            mac: &self.mac + &rhs.mac,
475            public_modifier: self.public_modifier.clone() + rhs.public_modifier.clone(),
476        }
477    }
478}
479impl_borrow_variants!(AuthenticatedScalarResult, Add, add, +, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
480
481impl AuthenticatedScalarResult {
482    /// Add two batches of `AuthenticatedScalarResult`s
483    pub fn batch_add(
484        a: &[AuthenticatedScalarResult],
485        b: &[AuthenticatedScalarResult],
486    ) -> Vec<AuthenticatedScalarResult> {
487        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
488
489        let n = a.len();
490        let fabric = a[0].fabric();
491        let all_ids = a.iter().chain(b.iter()).flat_map(|v| v.ids()).collect_vec();
492
493        // Add the underlying values
494        let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
495            all_ids,
496            AUTHENTICATED_SCALAR_RESULT_LEN * n, /* output_arity */
497            move |mut args| {
498                let arg_len = args.len();
499                let a_vals = args.drain(..arg_len / 2).collect_vec();
500                let b_vals = args;
501
502                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
503                for (mut a_vals, mut b_vals) in a_vals
504                    .into_iter()
505                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
506                    .into_iter()
507                    .zip(
508                        b_vals
509                            .into_iter()
510                            .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
511                            .into_iter(),
512                    )
513                {
514                    let a_share: Scalar = a_vals.next().unwrap().into();
515                    let a_mac_share: Scalar = a_vals.next().unwrap().into();
516                    let a_modifier: Scalar = a_vals.next().unwrap().into();
517
518                    let b_share: Scalar = b_vals.next().unwrap().into();
519                    let b_mac_share: Scalar = b_vals.next().unwrap().into();
520                    let b_modifier: Scalar = b_vals.next().unwrap().into();
521
522                    result.push(ResultValue::Scalar(a_share + b_share));
523                    result.push(ResultValue::Scalar(a_mac_share + b_mac_share));
524                    result.push(ResultValue::Scalar(a_modifier + b_modifier));
525                }
526
527                result
528            },
529        );
530
531        // Collect the gate results into a series of `AuthenticatedScalarResult`s
532        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
533    }
534
535    /// Add a batch of `AuthenticatedScalarResult`s to a batch of `ScalarResult`s
536    pub fn batch_add_public(
537        a: &[AuthenticatedScalarResult],
538        b: &[ScalarResult],
539    ) -> Vec<AuthenticatedScalarResult> {
540        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
541
542        let n = a.len();
543        let results_per_value = 3;
544        let fabric = a[0].fabric();
545        let all_ids = a
546            .iter()
547            .flat_map(|v| v.ids())
548            .chain(b.iter().map(|v| v.id()))
549            .collect_vec();
550
551        // Add the underlying values
552        let party_id = fabric.party_id();
553        let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
554            all_ids,
555            results_per_value * n, /* output_arity */
556            move |mut args| {
557                // Split the args
558                let a_vals = args
559                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
560                    .collect_vec();
561                let public_values = args;
562
563                let mut result = Vec::with_capacity(results_per_value * n);
564                for (mut a_vals, public_value) in a_vals
565                    .into_iter()
566                    .chunks(results_per_value)
567                    .into_iter()
568                    .zip(public_values.into_iter())
569                {
570                    let a_share: Scalar = a_vals.next().unwrap().into();
571                    let a_mac_share: Scalar = a_vals.next().unwrap().into();
572                    let a_modifier: Scalar = a_vals.next().unwrap().into();
573
574                    let public_value: Scalar = public_value.into();
575
576                    // Only the first party adds the public value to their share
577                    if party_id == PARTY0 {
578                        result.push(ResultValue::Scalar(a_share + public_value));
579                    } else {
580                        result.push(ResultValue::Scalar(a_share));
581                    }
582
583                    result.push(ResultValue::Scalar(a_mac_share));
584                    result.push(ResultValue::Scalar(a_modifier - public_value));
585                }
586
587                result
588            },
589        );
590
591        // Collect the gate results into a series of `AuthenticatedScalarResult`s
592        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
593    }
594}
595
596/// TODO: Maybe use a batch gate for this; performance depends on whether materializing the
597/// iterator is burdensome
598impl Sum for AuthenticatedScalarResult {
599    /// Assumes the iterator is non-empty
600    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
601        let seed = iter.next().expect("Cannot sum empty iterator");
602        iter.fold(seed, |acc, val| acc + &val)
603    }
604}
605
606// === Subtraction === //
607
608impl Sub<&Scalar> for &AuthenticatedScalarResult {
609    type Output = AuthenticatedScalarResult;
610
611    /// As in the case for addition, only party 0 subtracts the public value from their share,
612    /// but both parties track this in the public modifier
613    fn sub(self, rhs: &Scalar) -> Self::Output {
614        // Party 1 subtracts a zero value from their share to allocate a new ID for the result
615        // and stay in sync with party 0
616        let new_share = &self.share - rhs;
617
618        // Both parties add the public value to their modifier, and the MACs do not change
619        // when adding a public value
620        let new_modifier = &self.public_modifier + rhs;
621        AuthenticatedScalarResult {
622            share: new_share,
623            mac: self.mac.clone(),
624            public_modifier: new_modifier,
625        }
626    }
627}
628impl_borrow_variants!(AuthenticatedScalarResult, Sub, sub, -, Scalar, Output=AuthenticatedScalarResult);
629
630impl Sub<&AuthenticatedScalarResult> for &Scalar {
631    type Output = AuthenticatedScalarResult;
632
633    fn sub(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
634        // Party 1 subtracts a zero value from their share to allocate a new ID for the result
635        // and stay in sync with party 0
636        let new_share = self - &rhs.share;
637
638        // Both parties add the public value to their modifier, and the MACs do not change
639        // when adding a public value
640        let new_modifier = -self - &rhs.public_modifier;
641        AuthenticatedScalarResult {
642            share: new_share,
643            mac: -&rhs.mac,
644            public_modifier: new_modifier,
645        }
646    }
647}
648impl_borrow_variants!(Scalar, Sub, sub, -, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
649
650impl Sub<&ScalarResult> for &AuthenticatedScalarResult {
651    type Output = AuthenticatedScalarResult;
652
653    fn sub(self, rhs: &ScalarResult) -> Self::Output {
654        let new_share = &self.share - rhs;
655
656        // Both parties add the public value to their modifier, and the MACs do not change
657        // when adding a public value
658        let new_modifier = &self.public_modifier + rhs;
659        AuthenticatedScalarResult {
660            share: new_share,
661            mac: self.mac.clone(),
662            public_modifier: new_modifier,
663        }
664    }
665}
666impl_borrow_variants!(AuthenticatedScalarResult, Sub, sub, -, ScalarResult, Output=AuthenticatedScalarResult);
667
668impl Sub<&AuthenticatedScalarResult> for &ScalarResult {
669    type Output = AuthenticatedScalarResult;
670
671    fn sub(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
672        // Party 1 subtracts a zero value from their share to allocate a new ID for the result
673        // and stay in sync with party 0
674        let new_share = self - &rhs.share;
675
676        // Both parties add the public value to their modifier, and the MACs do not change
677        // when adding a public value
678        let new_modifier = -self - &rhs.public_modifier;
679        AuthenticatedScalarResult {
680            share: new_share,
681            mac: -&rhs.mac,
682            public_modifier: new_modifier,
683        }
684    }
685}
686impl_borrow_variants!(ScalarResult, Sub, sub, -, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
687
688impl Sub<&AuthenticatedScalarResult> for &AuthenticatedScalarResult {
689    type Output = AuthenticatedScalarResult;
690
691    fn sub(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
692        AuthenticatedScalarResult {
693            share: &self.share - &rhs.share,
694            mac: &self.mac - &rhs.mac,
695            public_modifier: self.public_modifier.clone() - rhs.public_modifier.clone(),
696        }
697    }
698}
699impl_borrow_variants!(AuthenticatedScalarResult, Sub, sub, -, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
700
701impl AuthenticatedScalarResult {
702    /// Add two batches of `AuthenticatedScalarResult`s
703    pub fn batch_sub(
704        a: &[AuthenticatedScalarResult],
705        b: &[AuthenticatedScalarResult],
706    ) -> Vec<AuthenticatedScalarResult> {
707        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
708
709        let n = a.len();
710        let fabric = &a[0].fabric();
711        let all_ids = a.iter().chain(b.iter()).flat_map(|v| v.ids()).collect_vec();
712
713        // Add the underlying values
714        let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
715            all_ids,
716            AUTHENTICATED_SCALAR_RESULT_LEN * n, /* output_arity */
717            move |mut args| {
718                let arg_len = args.len();
719                let a_vals = args.drain(..arg_len / 2).collect_vec();
720                let b_vals = args;
721
722                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
723                for (mut a_vals, mut b_vals) in a_vals
724                    .into_iter()
725                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
726                    .into_iter()
727                    .zip(
728                        b_vals
729                            .into_iter()
730                            .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
731                            .into_iter(),
732                    )
733                {
734                    let a_share: Scalar = a_vals.next().unwrap().into();
735                    let a_mac_share: Scalar = a_vals.next().unwrap().into();
736                    let a_modifier: Scalar = a_vals.next().unwrap().into();
737
738                    let b_share: Scalar = b_vals.next().unwrap().into();
739                    let b_mac_share: Scalar = b_vals.next().unwrap().into();
740                    let b_modifier: Scalar = b_vals.next().unwrap().into();
741
742                    result.push(ResultValue::Scalar(a_share - b_share));
743                    result.push(ResultValue::Scalar(a_mac_share - b_mac_share));
744                    result.push(ResultValue::Scalar(a_modifier - b_modifier));
745                }
746
747                result
748            },
749        );
750
751        // Collect the gate results into a series of `AuthenticatedScalarResult`s
752        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
753    }
754
755    /// Subtract a batch of `ScalarResult`s from a batch of `AuthenticatedScalarResult`s
756    pub fn batch_sub_public(
757        a: &[AuthenticatedScalarResult],
758        b: &[ScalarResult],
759    ) -> Vec<AuthenticatedScalarResult> {
760        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
761
762        let n = a.len();
763        let results_per_value = 3;
764        let fabric = a[0].fabric();
765        let all_ids = a
766            .iter()
767            .flat_map(|v| v.ids())
768            .chain(b.iter().map(|v| v.id()))
769            .collect_vec();
770
771        // Add the underlying values
772        let party_id = fabric.party_id();
773        let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
774            all_ids,
775            results_per_value * n, /* output_arity */
776            move |mut args| {
777                // Split the args
778                let a_vals = args
779                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
780                    .collect_vec();
781                let public_values = args;
782
783                let mut result = Vec::with_capacity(results_per_value * n);
784                for (mut a_vals, public_value) in a_vals
785                    .into_iter()
786                    .chunks(results_per_value)
787                    .into_iter()
788                    .zip(public_values.into_iter())
789                {
790                    let a_share: Scalar = a_vals.next().unwrap().into();
791                    let a_mac_share: Scalar = a_vals.next().unwrap().into();
792                    let a_modifier: Scalar = a_vals.next().unwrap().into();
793
794                    let public_value: Scalar = public_value.into();
795
796                    // Only the first party adds the public value to their share
797                    if party_id == PARTY0 {
798                        result.push(ResultValue::Scalar(a_share - public_value));
799                    } else {
800                        result.push(ResultValue::Scalar(a_share));
801                    }
802
803                    result.push(ResultValue::Scalar(a_mac_share));
804                    result.push(ResultValue::Scalar(a_modifier + public_value));
805                }
806
807                result
808            },
809        );
810
811        // Collect the gate results into a series of `AuthenticatedScalarResult`s
812        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
813    }
814}
815
816// === Negation === //
817
818impl Neg for &AuthenticatedScalarResult {
819    type Output = AuthenticatedScalarResult;
820
821    fn neg(self) -> Self::Output {
822        AuthenticatedScalarResult {
823            share: -&self.share,
824            mac: -&self.mac,
825            public_modifier: -&self.public_modifier,
826        }
827    }
828}
829impl_borrow_variants!(AuthenticatedScalarResult, Neg, neg, -);
830
831impl AuthenticatedScalarResult {
832    /// Negate a batch of `AuthenticatedScalarResult`s
833    pub fn batch_neg(a: &[AuthenticatedScalarResult]) -> Vec<AuthenticatedScalarResult> {
834        if a.is_empty() {
835            return vec![];
836        }
837
838        let n = a.len();
839        let fabric = a[0].fabric();
840        let all_ids = a.iter().flat_map(|v| v.ids()).collect_vec();
841
842        let scalars = fabric.new_batch_gate_op(
843            all_ids,
844            AUTHENTICATED_SCALAR_RESULT_LEN * n, /* output_arity */
845            |args| {
846                args.into_iter()
847                    .map(|arg| ResultValue::Scalar(-Scalar::from(arg)))
848                    .collect()
849            },
850        );
851
852        AuthenticatedScalarResult::from_flattened_iterator(scalars.into_iter())
853    }
854}
855
856// === Multiplication === //
857
858impl Mul<&Scalar> for &AuthenticatedScalarResult {
859    type Output = AuthenticatedScalarResult;
860
861    fn mul(self, rhs: &Scalar) -> Self::Output {
862        AuthenticatedScalarResult {
863            share: &self.share * rhs,
864            mac: &self.mac * rhs,
865            public_modifier: &self.public_modifier * rhs,
866        }
867    }
868}
869impl_borrow_variants!(AuthenticatedScalarResult, Mul, mul, *, Scalar, Output=AuthenticatedScalarResult);
870impl_commutative!(AuthenticatedScalarResult, Mul, mul, *, Scalar, Output=AuthenticatedScalarResult);
871
872impl Mul<&ScalarResult> for &AuthenticatedScalarResult {
873    type Output = AuthenticatedScalarResult;
874
875    fn mul(self, rhs: &ScalarResult) -> Self::Output {
876        AuthenticatedScalarResult {
877            share: &self.share * rhs,
878            mac: &self.mac * rhs,
879            public_modifier: &self.public_modifier * rhs,
880        }
881    }
882}
883impl_borrow_variants!(AuthenticatedScalarResult, Mul, mul, *, ScalarResult, Output=AuthenticatedScalarResult);
884impl_commutative!(AuthenticatedScalarResult, Mul, mul, *, ScalarResult, Output=AuthenticatedScalarResult);
885
886impl Mul<&AuthenticatedScalarResult> for &AuthenticatedScalarResult {
887    type Output = AuthenticatedScalarResult;
888
889    // Use the Beaver trick
890    fn mul(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
891        // Sample a beaver triplet
892        let (a, b, c) = self.fabric().next_authenticated_triple();
893
894        // Mask the left and right hand sides
895        let masked_lhs = self - &a;
896        let masked_rhs = rhs - &b;
897
898        // Open these values to get d = lhs - a, e = rhs - b
899        let d = masked_lhs.open();
900        let e = masked_rhs.open();
901
902        // Use the same beaver identify as in the `MpcScalarResult` case, but now the public
903        // multiplications are applied to the MACs and the public modifiers as well
904        // Identity: [x * y] = de + d[b] + e[a] + [c]
905        &d * &e + d * b + e * a + c
906    }
907}
908impl_borrow_variants!(AuthenticatedScalarResult, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
909
910impl AuthenticatedScalarResult {
911    /// Multiply a batch of values using the Beaver trick
912    pub fn batch_mul(
913        a: &[AuthenticatedScalarResult],
914        b: &[AuthenticatedScalarResult],
915    ) -> Vec<AuthenticatedScalarResult> {
916        assert_eq!(
917            a.len(),
918            b.len(),
919            "Cannot multiply batches of different sizes"
920        );
921
922        if a.is_empty() {
923            return vec![];
924        }
925
926        let n = a.len();
927        let fabric = a[0].fabric();
928        let (beaver_a, beaver_b, beaver_c) = fabric.next_authenticated_triple_batch(n);
929
930        // Open the values d = [lhs - a] and e = [rhs - b]
931        let masked_lhs = AuthenticatedScalarResult::batch_sub(a, &beaver_a);
932        let masked_rhs = AuthenticatedScalarResult::batch_sub(b, &beaver_b);
933
934        let all_masks = [masked_lhs, masked_rhs].concat();
935        let opened_values = AuthenticatedScalarResult::open_batch(&all_masks);
936        let (d_open, e_open) = opened_values.split_at(n);
937
938        // Identity: [x * y] = de + d[b] + e[a] + [c]
939        let de = ScalarResult::batch_mul(d_open, e_open);
940        let db = AuthenticatedScalarResult::batch_mul_public(&beaver_b, d_open);
941        let ea = AuthenticatedScalarResult::batch_mul_public(&beaver_a, e_open);
942
943        // Add the terms
944        let de_plus_db = AuthenticatedScalarResult::batch_add_public(&db, &de);
945        let ea_plus_c = AuthenticatedScalarResult::batch_add(&ea, &beaver_c);
946        AuthenticatedScalarResult::batch_add(&de_plus_db, &ea_plus_c)
947    }
948
949    /// Multiply a batch of `AuthenticatedScalarResult`s by a batch of `ScalarResult`s
950    pub fn batch_mul_public(
951        a: &[AuthenticatedScalarResult],
952        b: &[ScalarResult],
953    ) -> Vec<AuthenticatedScalarResult> {
954        assert_eq!(
955            a.len(),
956            b.len(),
957            "Cannot multiply batches of different sizes"
958        );
959        if a.is_empty() {
960            return vec![];
961        }
962
963        let n = a.len();
964        let fabric = a[0].fabric();
965        let all_ids = a
966            .iter()
967            .flat_map(|a| a.ids())
968            .chain(b.iter().map(|b| b.id()))
969            .collect_vec();
970
971        let scalars = fabric.new_batch_gate_op(
972            all_ids,
973            AUTHENTICATED_SCALAR_RESULT_LEN * n, /* output_arity */
974            move |mut args| {
975                let a_vals = args
976                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
977                    .collect_vec();
978                let public_values = args;
979
980                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
981                for (a_vals, public_values) in a_vals
982                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
983                    .zip(public_values.into_iter())
984                {
985                    let a_share: Scalar = a_vals[0].to_owned().into();
986                    let a_mac_share: Scalar = a_vals[1].to_owned().into();
987                    let a_modifier: Scalar = a_vals[2].to_owned().into();
988
989                    let public_value: Scalar = public_values.into();
990
991                    result.push(ResultValue::Scalar(a_share * public_value));
992                    result.push(ResultValue::Scalar(a_mac_share * public_value));
993                    result.push(ResultValue::Scalar(a_modifier * public_value));
994                }
995
996                result
997            },
998        );
999
1000        AuthenticatedScalarResult::from_flattened_iterator(scalars.into_iter())
1001    }
1002}
1003
1004// === Curve Scalar Multiplication === //
1005
1006impl Mul<&AuthenticatedScalarResult> for &StarkPoint {
1007    type Output = AuthenticatedStarkPointResult;
1008
1009    fn mul(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
1010        AuthenticatedStarkPointResult {
1011            share: self * &rhs.share,
1012            mac: self * &rhs.mac,
1013            public_modifier: self * &rhs.public_modifier,
1014        }
1015    }
1016}
1017impl_commutative!(StarkPoint, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedStarkPointResult);
1018
1019impl Mul<&AuthenticatedScalarResult> for &StarkPointResult {
1020    type Output = AuthenticatedStarkPointResult;
1021
1022    fn mul(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
1023        AuthenticatedStarkPointResult {
1024            share: self * &rhs.share,
1025            mac: self * &rhs.mac,
1026            public_modifier: self * &rhs.public_modifier,
1027        }
1028    }
1029}
1030impl_borrow_variants!(StarkPointResult, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedStarkPointResult);
1031impl_commutative!(StarkPointResult, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedStarkPointResult);
1032
1033// ----------------
1034// | Test Helpers |
1035// ----------------
1036
/// Helpers that overwrite individual components of an `AuthenticatedScalarResult` directly,
/// bypassing the normal arithmetic (and therefore the MAC bookkeeping). These are not
/// `unsafe` in the Rust sense, but they break the value's invariants. They should *only*
/// be used for testing.
#[cfg(feature = "test_helpers")]
pub mod test_helpers {
    use crate::algebra::scalar::Scalar;

    use super::AuthenticatedScalarResult;

    /// Overwrite the MAC share of an `AuthenticatedScalarResult` with `new_value`
    pub fn modify_mac(val: &mut AuthenticatedScalarResult, new_value: Scalar) {
        val.mac = val.fabric().allocate_scalar(new_value).into()
    }

    /// Overwrite the underlying secret share of an `AuthenticatedScalarResult` with `new_value`
    pub fn modify_share(val: &mut AuthenticatedScalarResult, new_value: Scalar) {
        val.share = val.fabric().allocate_scalar(new_value).into()
    }

    /// Overwrite the public modifier of an `AuthenticatedScalarResult` with `new_value`
    ///
    /// The previous modifier is replaced outright, not offset by `new_value`.
    pub fn modify_public_modifier(val: &mut AuthenticatedScalarResult, new_value: Scalar) {
        val.public_modifier = val.fabric().allocate_scalar(new_value)
    }
}
1060
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use crate::{algebra::scalar::Scalar, test_helpers::execute_mock_mpc, PARTY0};

    use super::AuthenticatedScalarResult;

    /// Test subtraction across non-commutative types
    #[tokio::test]
    async fn test_sub() {
        let mut rng = thread_rng();
        let value1 = Scalar::random(&mut rng);
        let value2 = Scalar::random(&mut rng);

        let (res, _) = execute_mock_mpc(|fabric| async move {
            // Allocate the first value as a shared scalar and the second as a public scalar
            let party0_value = fabric.share_scalar(value1, PARTY0);
            let public_value = fabric.allocate_scalar(value2);

            // Subtract the public value from the shared value
            let res1 = &party0_value - &public_value;
            let res_open1 = res1.open_authenticated().await.unwrap();

            // Subtract the shared value from the public value
            let res2 = &public_value - &party0_value;
            let res_open2 = res2.open_authenticated().await.unwrap();

            (res_open1, res_open2)
        })
        .await;

        // Compare the opened values directly so a failure reports both sides
        assert_eq!(res.0, value1 - value2);
        assert_eq!(res.1, value2 - value1);
    }

    /// Tests subtraction with a constant value outside of the fabric
    #[tokio::test]
    async fn test_sub_constant() {
        let mut rng = thread_rng();
        let value1 = Scalar::random(&mut rng);
        let value2 = Scalar::random(&mut rng);

        let (res, _) = execute_mock_mpc(|fabric| async move {
            // Allocate the first value as a shared scalar; the second stays a plain constant
            let party0_value = fabric.share_scalar(value1, PARTY0);

            // Subtract the public value from the shared value
            let res1 = &party0_value - value2;
            let res_open1 = res1.open_authenticated().await.unwrap();

            // Subtract the shared value from the public value
            let res2 = value2 - &party0_value;
            let res_open2 = res2.open_authenticated().await.unwrap();

            (res_open1, res_open2)
        })
        .await;

        assert_eq!(res.0, value1 - value2);
        assert_eq!(res.1, value2 - value1);
    }

    /// Tests batched subtraction of public values against element-wise subtraction
    #[tokio::test]
    async fn test_batch_sub_public() {
        const N: usize = 5;
        let mut rng = thread_rng();
        let secret_values = (0..N)
            .map(|_| Scalar::random(&mut rng))
            .collect::<Vec<_>>();
        let public_values = (0..N)
            .map(|_| Scalar::random(&mut rng))
            .collect::<Vec<_>>();

        let expected = secret_values
            .iter()
            .zip(public_values.iter())
            .map(|(secret, public)| *secret - *public)
            .collect::<Vec<_>>();

        let (res, _) = execute_mock_mpc(|fabric| {
            let secret_values = secret_values.clone();
            let public_values = public_values.clone();
            async move {
                let shared = secret_values
                    .into_iter()
                    .map(|value| fabric.share_scalar(value, PARTY0))
                    .collect::<Vec<_>>();
                let public = public_values
                    .into_iter()
                    .map(|value| fabric.allocate_scalar(value))
                    .collect::<Vec<_>>();

                let diffs = AuthenticatedScalarResult::batch_sub_public(&shared, &public);
                let mut opened = Vec::with_capacity(diffs.len());
                for diff in diffs {
                    opened.push(diff.open_authenticated().await.unwrap());
                }

                opened
            }
        })
        .await;

        assert_eq!(res, expected);
    }

    /// Test a simple `XOR` circuit
    #[tokio::test]
    async fn test_xor_circuit() {
        let (res, _) = execute_mock_mpc(|fabric| async move {
            let a = &fabric.zero_authenticated();
            let b = &fabric.zero_authenticated();
            let res = a + b - Scalar::from(2u64) * a * b;

            res.open_authenticated().await
        })
        .await;

        assert_eq!(res.unwrap(), 0.into());
    }
}