ark_mpc/algebra/scalar/
authenticated_scalar.rs

//! Defines the authenticated (maliciously secure) variant of the MPC scalar type
2
3use std::{
4    fmt::Debug,
5    iter::Sum,
6    ops::{Add, Div, Mul, Neg, Sub},
7    pin::Pin,
8    task::{Context, Poll},
9};
10
11use ark_ec::CurveGroup;
12use ark_ff::FftField;
13use ark_poly::EvaluationDomain;
14use futures::{Future, FutureExt};
15use itertools::{izip, Itertools};
16
17use crate::{
18    algebra::{macros::*, AuthenticatedPointResult, CurvePoint, CurvePointResult},
19    commitment::{PedersenCommitment, PedersenCommitmentResult},
20    error::MpcError,
21    fabric::{MpcFabric, ResultId, ResultValue},
22    PARTY0,
23};
24
25use super::{
26    mpc_scalar::MpcScalarResult,
27    scalar::{BatchScalarResult, Scalar, ScalarResult},
28};
29
/// How many underlying results make up a single `AuthenticatedScalarResult<C>`:
/// its share, its MAC share, and its public modifier, in the order returned by
/// `AuthenticatedScalarResult::ids`
pub const AUTHENTICATED_SCALAR_RESULT_LEN: usize = 3;
32
33/// A maliciously secure wrapper around an `MpcScalarResult`, includes a MAC as
34/// per the SPDZ protocol: https://eprint.iacr.org/2011/535.pdf
35/// that ensures security against a malicious adversary
36#[derive(Clone)]
37pub struct AuthenticatedScalarResult<C: CurveGroup> {
38    /// The secret shares of the underlying value
39    pub(crate) share: MpcScalarResult<C>,
40    /// The SPDZ style, unconditionally secure MAC of the value
41    ///
42    /// If the value is `x`, parties hold secret shares of the value
43    /// \delta * x for the global MAC key `\delta`. The parties individually
44    /// hold secret shares of this MAC key [\delta], so we can very naturally
45    /// extend the secret share arithmetic of the underlying `MpcScalarResult`
46    /// to the MAC updates as well
47    pub(crate) mac: MpcScalarResult<C>,
48    /// The public modifier tracks additions and subtractions of public values
49    /// to the underlying value. This is necessary because in the case of a
50    /// public addition, only the first party adds the public value to their
51    /// share, so the second party must track this up until the point that
52    /// the value is opened and the MAC is checked
53    pub(crate) public_modifier: ScalarResult<C>,
54}
55
56impl<C: CurveGroup> Debug for AuthenticatedScalarResult<C> {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("AuthenticatedScalarResult<C>")
59            .field("value", &self.share.id())
60            .field("mac", &self.mac.id())
61            .field("public_modifier", &self.public_modifier.id)
62            .finish()
63    }
64}
65
66impl<C: CurveGroup> AuthenticatedScalarResult<C> {
67    /// Create a new result from the given shared value
68    pub fn new_shared(value: ScalarResult<C>) -> Self {
69        // Create an `MpcScalarResult` to represent the fact that this is a shared value
70        let fabric = value.fabric.clone();
71
72        let mpc_value = MpcScalarResult::new_shared(value);
73        let mac = fabric.borrow_mac_key() * mpc_value.clone();
74
75        // Allocate a zero for the public modifier
76        let public_modifier = fabric.zero();
77
78        Self {
79            share: mpc_value,
80            mac,
81            public_modifier,
82        }
83    }
84
85    /// Create a new batch of shared values
86    pub fn new_shared_batch(values: &[ScalarResult<C>]) -> Vec<Self> {
87        if values.is_empty() {
88            return vec![];
89        }
90
91        let n = values.len();
92        let fabric = values[0].fabric();
93        let mpc_values = values
94            .iter()
95            .map(|v| MpcScalarResult::new_shared(v.clone()))
96            .collect_vec();
97
98        let mac_keys = (0..n)
99            .map(|_| fabric.borrow_mac_key().clone())
100            .collect_vec();
101        let values_macs = MpcScalarResult::batch_mul(&mpc_values, &mac_keys);
102
103        mpc_values
104            .into_iter()
105            .zip(values_macs)
106            .map(|(value, mac)| Self {
107                share: value,
108                mac,
109                public_modifier: fabric.zero(),
110            })
111            .collect_vec()
112    }
113
114    /// Create a nwe shared batch of values from a batch network result
115    ///
116    /// The batch result combines the batch into one result, so it must be split
117    /// out first before creating the `AuthenticatedScalarResult`s
118    pub fn new_shared_from_batch_result(
119        values: BatchScalarResult<C>,
120        n: usize,
121    ) -> Vec<AuthenticatedScalarResult<C>> {
122        // Convert to a set of scalar results
123        let scalar_results: Vec<ScalarResult<C>> =
124            values
125                .fabric()
126                .new_batch_gate_op(vec![values.id()], n, |mut args| {
127                    let scalars: Vec<Scalar<C>> = args.pop().unwrap().into();
128                    scalars.into_iter().map(ResultValue::Scalar).collect()
129                });
130
131        Self::new_shared_batch(&scalar_results)
132    }
133
134    /// Get the raw share as an `MpcScalarResult`
135    #[cfg(feature = "test_helpers")]
136    pub fn mpc_share(&self) -> MpcScalarResult<C> {
137        self.share.clone()
138    }
139
140    /// Get the raw share as a `ScalarResult`
141    pub fn share(&self) -> ScalarResult<C> {
142        self.share.to_scalar()
143    }
144
145    /// Get the raw share of the MAC as a `ScalarResult`
146    pub fn mac_share(&self) -> ScalarResult<C> {
147        self.mac.to_scalar()
148    }
149
150    /// Get a reference to the underlying MPC fabric
151    pub fn fabric(&self) -> &MpcFabric<C> {
152        self.share.fabric()
153    }
154
155    /// Get the ids of the results that must be awaited
156    /// before the value is ready
157    pub fn ids(&self) -> Vec<ResultId> {
158        vec![self.share.id(), self.mac.id(), self.public_modifier.id]
159    }
160
161    /// Compute the inverse of a
162    pub fn inverse(&self) -> AuthenticatedScalarResult<C> {
163        let mut res = Self::batch_inverse(&[self.clone()]);
164        res.remove(0)
165    }
166
167    /// Compute a batch of inverses of an `AuthenticatedScalarResult`s
168    ///
169    /// This follows the protocol detailed in:
170    ///     https://dl.acm.org/doi/pdf/10.1145/72981.72995
171    /// Which gives a two round implementation
172    pub fn batch_inverse(
173        values: &[AuthenticatedScalarResult<C>],
174    ) -> Vec<AuthenticatedScalarResult<C>> {
175        let n = values.len();
176        assert!(n > 0, "cannot invert empty batch of scalars");
177
178        let fabric = values[0].fabric();
179
180        // For the following steps, let the input values be x_i for i=1..n
181
182        // 1. Sample a random shared group element from the shared value source
183        // call these values r_i for i=1..n
184        let shared_scalars = fabric.random_shared_scalars_authenticated(n);
185
186        // 2. Mask the values by multiplying them with the random scalars, i.e. compute
187        //    m_i = (r_i * x_i)
188        // Open the masked values to both parties
189        let masked_values = AuthenticatedScalarResult::batch_mul(values, &shared_scalars);
190        let masked_values_open = Self::open_authenticated_batch(&masked_values);
191
192        // 3. Compute the inverse of the masked values: m_i^-1 = (x_i^-1 * r_i^-1)
193        let inverted_openings = masked_values_open
194            .into_iter()
195            .map(|val| val.value.inverse())
196            .collect_vec();
197
198        // 4. Multiply these inverted openings with the original shared scalars r_i:
199        //    m_i^-1 * r_i = (x_i^-1 * r_i^-1) * r_i = x_i^-1
200        AuthenticatedScalarResult::batch_mul_public(&shared_scalars, &inverted_openings)
201    }
202
203    /// Compute the exponentiation of the given value
204    /// via recursive squaring
205    pub fn pow(&self, exp: u64) -> Self {
206        if exp == 0 {
207            return self.fabric().zero_authenticated();
208        } else if exp == 1 {
209            return self.clone();
210        }
211
212        let recursive = self.pow(exp / 2);
213        let mut res = &recursive * &recursive;
214
215        if exp % 2 == 1 {
216            res = res * self.clone();
217        }
218        res
219    }
220}
221
222/// Opening implementations
223impl<C: CurveGroup> AuthenticatedScalarResult<C> {
224    /// Open the value without checking its MAC
225    pub fn open(&self) -> ScalarResult<C> {
226        self.share.open()
227    }
228
229    /// Open a batch of values without checking their MACs
230    pub fn open_batch(values: &[Self]) -> Vec<ScalarResult<C>> {
231        MpcScalarResult::open_batch(&values.iter().map(|val| val.share.clone()).collect_vec())
232    }
233
234    /// Convert a flattened iterator into a batch of
235    /// `AuthenticatedScalarResult`s
236    ///
237    /// We assume that the iterator has been flattened in the same way order
238    /// that `Self::id`s returns the `AuthenticatedScalar<C>`'s values:
239    /// `[share, mac, public_modifier]`
240    pub fn from_flattened_iterator<I>(iter: I) -> Vec<Self>
241    where
242        I: Iterator<Item = ScalarResult<C>>,
243    {
244        iter.chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
245            .into_iter()
246            .map(|mut chunk| Self {
247                share: chunk.next().unwrap().into(),
248                mac: chunk.next().unwrap().into(),
249                public_modifier: chunk.next().unwrap(),
250            })
251            .collect_vec()
252    }
253
254    /// Check the commitment to a MAC check and that the MAC checks sum to zero
255    pub fn verify_mac_check(
256        my_mac_share: Scalar<C>,
257        peer_mac_share: Scalar<C>,
258        peer_mac_commitment: CurvePoint<C>,
259        peer_commitment_blinder: Scalar<C>,
260    ) -> bool {
261        let their_comm = PedersenCommitment {
262            value: peer_mac_share,
263            blinder: peer_commitment_blinder,
264            commitment: peer_mac_commitment,
265        };
266
267        // Verify that the commitment to the MAC check opens correctly
268        if !their_comm.verify() {
269            return false;
270        }
271
272        // Sum of the commitments should be zero
273        if peer_mac_share + my_mac_share != Scalar::zero() {
274            return false;
275        }
276
277        true
278    }
279
280    /// Open the value and check its MAC
281    ///
282    /// This follows the protocol detailed in:
283    ///     https://securecomputation.org/docs/pragmaticmpc.pdf
284    /// Section 6.6.2
285    pub fn open_authenticated(&self) -> AuthenticatedScalarOpenResult<C> {
286        // Both parties open the underlying value
287        let recovered_value = self.share.open();
288
289        // Add a gate to compute the MAC check value: `key_share * opened_value -
290        // mac_share`
291        let mac_check_value: ScalarResult<C> = self.fabric().new_gate_op(
292            vec![
293                self.fabric().borrow_mac_key().id(),
294                recovered_value.id,
295                self.public_modifier.id,
296                self.mac.id(),
297            ],
298            move |mut args| {
299                let mac_key_share: Scalar<C> = args.remove(0).into();
300                let value: Scalar<C> = args.remove(0).into();
301                let modifier: Scalar<C> = args.remove(0).into();
302                let mac_share: Scalar<C> = args.remove(0).into();
303
304                ResultValue::Scalar(mac_key_share * (value + modifier) - mac_share)
305            },
306        );
307
308        // Compute a commitment to this value and share it with the peer
309        let my_comm = PedersenCommitmentResult::commit(mac_check_value);
310        let peer_commit = self.fabric().exchange_value(my_comm.commitment);
311
312        // Once the parties have exchanged their commitments, they can open them, they
313        // have already exchanged the underlying values and their commitments so
314        // all that is left is the blinder
315        let peer_mac_check = self.fabric().exchange_value(my_comm.value.clone());
316
317        let blinder_result: ScalarResult<C> = self.fabric().allocate_scalar(my_comm.blinder);
318        let peer_blinder = self.fabric().exchange_value(blinder_result);
319
320        // Check the commitment and the MAC result
321        let commitment_check: ScalarResult<C> = self.fabric().new_gate_op(
322            vec![
323                my_comm.value.id,
324                peer_mac_check.id,
325                peer_blinder.id,
326                peer_commit.id,
327            ],
328            |mut args| {
329                let my_comm_value: Scalar<C> = args.remove(0).into();
330                let peer_value: Scalar<C> = args.remove(0).into();
331                let blinder: Scalar<C> = args.remove(0).into();
332                let commitment: CurvePoint<C> = args.remove(0).into();
333
334                // Build a commitment from the gate inputs
335                ResultValue::Scalar(Scalar::from(Self::verify_mac_check(
336                    my_comm_value,
337                    peer_value,
338                    commitment,
339                    blinder,
340                )))
341            },
342        );
343
344        AuthenticatedScalarOpenResult {
345            value: recovered_value,
346            mac_check: commitment_check,
347        }
348    }
349
350    /// Open a batch of values and check their MACs
351    pub fn open_authenticated_batch(values: &[Self]) -> Vec<AuthenticatedScalarOpenResult<C>> {
352        if values.is_empty() {
353            return vec![];
354        }
355
356        let n = values.len();
357        let fabric = &values[0].fabric();
358
359        // Both parties open the underlying values
360        let values_open = Self::open_batch(values);
361
362        // --- Mac Checks --- //
363
364        // Compute the shares of the MAC check in batch
365        let mut mac_check_deps = Vec::with_capacity(1 + 3 * n);
366        mac_check_deps.push(fabric.borrow_mac_key().id());
367        for i in 0..n {
368            mac_check_deps.push(values_open[i].id());
369            mac_check_deps.push(values[i].public_modifier.id());
370            mac_check_deps.push(values[i].mac.id());
371        }
372
373        let mac_checks: Vec<ScalarResult<C>> =
374            fabric.new_batch_gate_op(mac_check_deps, n /* output_arity */, move |mut args| {
375                let mac_key_share: Scalar<C> = args.remove(0).into();
376                let mut check_result = Vec::with_capacity(n);
377
378                for _ in 0..n {
379                    let value: Scalar<C> = args.remove(0).into();
380                    let modifier: Scalar<C> = args.remove(0).into();
381                    let mac_share: Scalar<C> = args.remove(0).into();
382
383                    check_result.push(mac_key_share * (value + modifier) - mac_share);
384                }
385
386                check_result.into_iter().map(ResultValue::Scalar).collect()
387            });
388
389        // --- Commit to MAC Checks --- //
390
391        let my_comms = mac_checks
392            .iter()
393            .cloned()
394            .map(PedersenCommitmentResult::commit)
395            .collect_vec();
396        let peer_comms = fabric.exchange_values(
397            &my_comms
398                .iter()
399                .map(|comm| comm.commitment.clone())
400                .collect_vec(),
401        );
402
403        // --- Exchange the MAC Checks and Commitment Blinders --- //
404
405        let peer_mac_checks = fabric.exchange_values(&mac_checks);
406        let peer_blinders = fabric.exchange_values(
407            &my_comms
408                .iter()
409                .map(|comm| fabric.allocate_scalar(comm.blinder))
410                .collect_vec(),
411        );
412
413        // --- Check the MAC Checks --- //
414
415        let mut mac_check_gate_deps = my_comms.iter().map(|comm| comm.value.id).collect_vec();
416        mac_check_gate_deps.push(peer_mac_checks.id);
417        mac_check_gate_deps.push(peer_blinders.id);
418        mac_check_gate_deps.push(peer_comms.id);
419
420        let commitment_checks: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
421            mac_check_gate_deps,
422            n, // output_arity
423            move |mut args| {
424                let my_comms: Vec<Scalar<C>> = args.drain(..n).map(|comm| comm.into()).collect();
425                let peer_mac_checks: Vec<Scalar<C>> = args.remove(0).into();
426                let peer_blinders: Vec<Scalar<C>> = args.remove(0).into();
427                let peer_comms: Vec<CurvePoint<C>> = args.remove(0).into();
428
429                // Build a commitment from the gate inputs
430                let mut mac_checks = Vec::with_capacity(n);
431                for (my_mac_share, peer_mac_share, peer_blinder, peer_commitment) in izip!(
432                    my_comms.into_iter(),
433                    peer_mac_checks.into_iter(),
434                    peer_blinders.into_iter(),
435                    peer_comms.into_iter()
436                ) {
437                    let mac_check = Self::verify_mac_check(
438                        my_mac_share,
439                        peer_mac_share,
440                        peer_commitment,
441                        peer_blinder,
442                    );
443                    mac_checks.push(ResultValue::Scalar(Scalar::from(mac_check)));
444                }
445
446                mac_checks
447            },
448        );
449
450        // --- Return the results --- //
451
452        values_open
453            .into_iter()
454            .zip(commitment_checks)
455            .map(|(value, check)| AuthenticatedScalarOpenResult {
456                value,
457                mac_check: check,
458            })
459            .collect_vec()
460    }
461}
462
463/// The value that results from opening an `AuthenticatedScalarResult` and
464/// checking its MAC. This encapsulates both the underlying value and the result
465/// of the MAC check
466#[derive(Clone)]
467pub struct AuthenticatedScalarOpenResult<C: CurveGroup> {
468    /// The underlying value
469    pub value: ScalarResult<C>,
470    /// The result of the MAC check
471    pub mac_check: ScalarResult<C>,
472}
473
474impl<C: CurveGroup> Future for AuthenticatedScalarOpenResult<C>
475where
476    C::ScalarField: Unpin,
477{
478    type Output = Result<Scalar<C>, MpcError>;
479
480    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
481        // Await both of the underlying values
482        let value = futures::ready!(self.as_mut().value.poll_unpin(cx));
483        let mac_check = futures::ready!(self.as_mut().mac_check.poll_unpin(cx));
484
485        if mac_check == Scalar::from(1u8) {
486            Poll::Ready(Ok(value))
487        } else {
488            Poll::Ready(Err(MpcError::AuthenticationError))
489        }
490    }
491}
492
493// --------------
494// | Arithmetic |
495// --------------
496
497// === Addition === //
498
499impl<C: CurveGroup> Add<&Scalar<C>> for &AuthenticatedScalarResult<C> {
500    type Output = AuthenticatedScalarResult<C>;
501
502    fn add(self, rhs: &Scalar<C>) -> Self::Output {
503        let new_share = if self.fabric().party_id() == PARTY0 {
504            &self.share + rhs
505        } else {
506            &self.share + Scalar::zero()
507        };
508
509        // Both parties add the public value to their modifier, and the MACs do not
510        // change when adding a public value
511        let new_modifier = &self.public_modifier - rhs;
512        AuthenticatedScalarResult {
513            share: new_share,
514            mac: self.mac.clone(),
515            public_modifier: new_modifier,
516        }
517    }
518}
519impl_borrow_variants!(AuthenticatedScalarResult<C>, Add, add, +, Scalar<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
520impl_commutative!(AuthenticatedScalarResult<C>, Add, add, +, Scalar<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
521
522impl<C: CurveGroup> Add<&ScalarResult<C>> for &AuthenticatedScalarResult<C> {
523    type Output = AuthenticatedScalarResult<C>;
524
525    fn add(self, rhs: &ScalarResult<C>) -> Self::Output {
526        // As above, only party 0 adds the public value to their share, but both parties
527        // track this with the modifier
528        //
529        // Party 1 adds a zero value to their share to allocate a new ID for the result
530        let new_share = if self.fabric().party_id() == PARTY0 {
531            &self.share + rhs
532        } else {
533            &self.share + Scalar::zero()
534        };
535
536        let new_modifier = &self.public_modifier - rhs;
537        AuthenticatedScalarResult {
538            share: new_share,
539            mac: self.mac.clone(),
540            public_modifier: new_modifier,
541        }
542    }
543}
544impl_borrow_variants!(AuthenticatedScalarResult<C>, Add, add, +, ScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
545impl_commutative!(AuthenticatedScalarResult<C>, Add, add, +, ScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
546
547impl<C: CurveGroup> Add<&AuthenticatedScalarResult<C>> for &AuthenticatedScalarResult<C> {
548    type Output = AuthenticatedScalarResult<C>;
549
550    fn add(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
551        AuthenticatedScalarResult {
552            share: &self.share + &rhs.share,
553            mac: &self.mac + &rhs.mac,
554            public_modifier: self.public_modifier.clone() + rhs.public_modifier.clone(),
555        }
556    }
557}
558impl_borrow_variants!(AuthenticatedScalarResult<C>, Add, add, +, AuthenticatedScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
559
560impl<C: CurveGroup> AuthenticatedScalarResult<C> {
561    /// Add two batches of `AuthenticatedScalarResult`s
562    pub fn batch_add(
563        a: &[AuthenticatedScalarResult<C>],
564        b: &[AuthenticatedScalarResult<C>],
565    ) -> Vec<AuthenticatedScalarResult<C>> {
566        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
567
568        let n = a.len();
569        let fabric = a[0].fabric();
570        let all_ids = a.iter().chain(b.iter()).flat_map(|v| v.ids()).collect_vec();
571
572        // Add the underlying values
573        let gate_results: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
574            all_ids,
575            AUTHENTICATED_SCALAR_RESULT_LEN * n, // output_arity
576            move |mut args| {
577                let arg_len = args.len();
578                let a_vals = args.drain(..arg_len / 2).collect_vec();
579                let b_vals = args;
580
581                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
582                for (mut a_vals, mut b_vals) in a_vals
583                    .into_iter()
584                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
585                    .into_iter()
586                    .zip(
587                        b_vals
588                            .into_iter()
589                            .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
590                            .into_iter(),
591                    )
592                {
593                    let a_share: Scalar<C> = a_vals.next().unwrap().into();
594                    let a_mac_share: Scalar<C> = a_vals.next().unwrap().into();
595                    let a_modifier: Scalar<C> = a_vals.next().unwrap().into();
596
597                    let b_share: Scalar<C> = b_vals.next().unwrap().into();
598                    let b_mac_share: Scalar<C> = b_vals.next().unwrap().into();
599                    let b_modifier: Scalar<C> = b_vals.next().unwrap().into();
600
601                    result.push(ResultValue::Scalar(a_share + b_share));
602                    result.push(ResultValue::Scalar(a_mac_share + b_mac_share));
603                    result.push(ResultValue::Scalar(a_modifier + b_modifier));
604                }
605
606                result
607            },
608        );
609
610        // Collect the gate results into a series of `AuthenticatedScalarResult`s
611        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
612    }
613
614    /// Add a batch of `AuthenticatedScalarResult`s to a batch of
615    /// `ScalarResult`s
616    pub fn batch_add_public(
617        a: &[AuthenticatedScalarResult<C>],
618        b: &[ScalarResult<C>],
619    ) -> Vec<AuthenticatedScalarResult<C>> {
620        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
621
622        let n = a.len();
623        let results_per_value = 3;
624        let fabric = a[0].fabric();
625        let all_ids = a
626            .iter()
627            .flat_map(|v| v.ids())
628            .chain(b.iter().map(|v| v.id()))
629            .collect_vec();
630
631        // Add the underlying values
632        let party_id = fabric.party_id();
633        let gate_results: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
634            all_ids,
635            results_per_value * n, // output_arity
636            move |mut args| {
637                // Split the args
638                let a_vals = args
639                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
640                    .collect_vec();
641                let public_values = args;
642
643                let mut result = Vec::with_capacity(results_per_value * n);
644                for (mut a_vals, public_value) in a_vals
645                    .into_iter()
646                    .chunks(results_per_value)
647                    .into_iter()
648                    .zip(public_values.into_iter())
649                {
650                    let a_share: Scalar<C> = a_vals.next().unwrap().into();
651                    let a_mac_share: Scalar<C> = a_vals.next().unwrap().into();
652                    let a_modifier: Scalar<C> = a_vals.next().unwrap().into();
653
654                    let public_value: Scalar<C> = public_value.into();
655
656                    // Only the first party adds the public value to their share
657                    if party_id == PARTY0 {
658                        result.push(ResultValue::Scalar(a_share + public_value));
659                    } else {
660                        result.push(ResultValue::Scalar(a_share));
661                    }
662
663                    result.push(ResultValue::Scalar(a_mac_share));
664                    result.push(ResultValue::Scalar(a_modifier - public_value));
665                }
666
667                result
668            },
669        );
670
671        // Collect the gate results into a series of `AuthenticatedScalarResult<C>`s
672        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
673    }
674}
675
676/// TODO: Maybe use a batch gate for this; performance depends on whether
677/// materializing the iterator is burdensome
678impl<C: CurveGroup> Sum for AuthenticatedScalarResult<C> {
679    /// Assumes the iterator is non-empty
680    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
681        let seed = iter.next().expect("Cannot sum empty iterator");
682        iter.fold(seed, |acc, val| acc + &val)
683    }
684}
685
686// === Subtraction === //
687
688impl<C: CurveGroup> Sub<&Scalar<C>> for &AuthenticatedScalarResult<C> {
689    type Output = AuthenticatedScalarResult<C>;
690
691    /// As in the case for addition, only party 0 subtracts the public value
692    /// from their share, but both parties track this in the public modifier
693    fn sub(self, rhs: &Scalar<C>) -> Self::Output {
694        // Party 1 subtracts a zero value from their share to allocate a new ID for the
695        // result and stay in sync with party 0
696        let new_share = &self.share - rhs;
697
698        // Both parties add the public value to their modifier, and the MACs do not
699        // change when adding a public value
700        let new_modifier = &self.public_modifier + rhs;
701        AuthenticatedScalarResult {
702            share: new_share,
703            mac: self.mac.clone(),
704            public_modifier: new_modifier,
705        }
706    }
707}
708impl_borrow_variants!(AuthenticatedScalarResult<C>, Sub, sub, -, Scalar<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
709
710impl<C: CurveGroup> Sub<&AuthenticatedScalarResult<C>> for &Scalar<C> {
711    type Output = AuthenticatedScalarResult<C>;
712
713    fn sub(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
714        // Party 1 subtracts a zero value from their share to allocate a new ID for the
715        // result and stay in sync with party 0
716        let new_share = self - &rhs.share;
717
718        // Both parties add the public value to their modifier, and the MACs do not
719        // change when adding a public value
720        let new_modifier = -self - &rhs.public_modifier;
721        AuthenticatedScalarResult {
722            share: new_share,
723            mac: -&rhs.mac,
724            public_modifier: new_modifier,
725        }
726    }
727}
728impl_borrow_variants!(Scalar<C>, Sub, sub, -, AuthenticatedScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
729
730impl<C: CurveGroup> Sub<&ScalarResult<C>> for &AuthenticatedScalarResult<C> {
731    type Output = AuthenticatedScalarResult<C>;
732
733    fn sub(self, rhs: &ScalarResult<C>) -> Self::Output {
734        let new_share = &self.share - rhs;
735
736        // Both parties add the public value to their modifier, and the MACs do not
737        // change when adding a public value
738        let new_modifier = &self.public_modifier + rhs;
739        AuthenticatedScalarResult {
740            share: new_share,
741            mac: self.mac.clone(),
742            public_modifier: new_modifier,
743        }
744    }
745}
746impl_borrow_variants!(AuthenticatedScalarResult<C>, Sub, sub, -, ScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
747
748impl<C: CurveGroup> Sub<&AuthenticatedScalarResult<C>> for &ScalarResult<C> {
749    type Output = AuthenticatedScalarResult<C>;
750
751    fn sub(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
752        // Party 1 subtracts a zero value from their share to allocate a new ID for the
753        // result and stay in sync with party 0
754        let new_share = self - &rhs.share;
755
756        // Both parties add the public value to their modifier, and the MACs do not
757        // change when adding a public value
758        let new_modifier = -self - &rhs.public_modifier;
759        AuthenticatedScalarResult {
760            share: new_share,
761            mac: -&rhs.mac,
762            public_modifier: new_modifier,
763        }
764    }
765}
766impl_borrow_variants!(ScalarResult<C>, Sub, sub, -, AuthenticatedScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
767
768impl<C: CurveGroup> Sub<&AuthenticatedScalarResult<C>> for &AuthenticatedScalarResult<C> {
769    type Output = AuthenticatedScalarResult<C>;
770
771    fn sub(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
772        AuthenticatedScalarResult {
773            share: &self.share - &rhs.share,
774            mac: &self.mac - &rhs.mac,
775            public_modifier: self.public_modifier.clone() - rhs.public_modifier.clone(),
776        }
777    }
778}
779impl_borrow_variants!(AuthenticatedScalarResult<C>, Sub, sub, -, AuthenticatedScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
780
781impl<C: CurveGroup> AuthenticatedScalarResult<C> {
782    /// Add two batches of `AuthenticatedScalarResult`s
783    pub fn batch_sub(
784        a: &[AuthenticatedScalarResult<C>],
785        b: &[AuthenticatedScalarResult<C>],
786    ) -> Vec<AuthenticatedScalarResult<C>> {
787        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
788
789        let n = a.len();
790        let fabric = &a[0].fabric();
791        let all_ids = a.iter().chain(b.iter()).flat_map(|v| v.ids()).collect_vec();
792
793        // Add the underlying values
794        let gate_results: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
795            all_ids,
796            AUTHENTICATED_SCALAR_RESULT_LEN * n, // output_arity
797            move |mut args| {
798                let arg_len = args.len();
799                let a_vals = args.drain(..arg_len / 2).collect_vec();
800                let b_vals = args;
801
802                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
803                for (mut a_vals, mut b_vals) in a_vals
804                    .into_iter()
805                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
806                    .into_iter()
807                    .zip(
808                        b_vals
809                            .into_iter()
810                            .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
811                            .into_iter(),
812                    )
813                {
814                    let a_share: Scalar<C> = a_vals.next().unwrap().into();
815                    let a_mac_share: Scalar<C> = a_vals.next().unwrap().into();
816                    let a_modifier: Scalar<C> = a_vals.next().unwrap().into();
817
818                    let b_share: Scalar<C> = b_vals.next().unwrap().into();
819                    let b_mac_share: Scalar<C> = b_vals.next().unwrap().into();
820                    let b_modifier: Scalar<C> = b_vals.next().unwrap().into();
821
822                    result.push(ResultValue::Scalar(a_share - b_share));
823                    result.push(ResultValue::Scalar(a_mac_share - b_mac_share));
824                    result.push(ResultValue::Scalar(a_modifier - b_modifier));
825                }
826
827                result
828            },
829        );
830
831        // Collect the gate results into a series of `AuthenticatedScalarResult`s
832        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
833    }
834
835    /// Subtract a batch of `ScalarResult`s from a batch of
836    /// `AuthenticatedScalarResult`s
837    pub fn batch_sub_public(
838        a: &[AuthenticatedScalarResult<C>],
839        b: &[ScalarResult<C>],
840    ) -> Vec<AuthenticatedScalarResult<C>> {
841        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
842
843        let n = a.len();
844        let results_per_value = 3;
845        let fabric = a[0].fabric();
846        let all_ids = a
847            .iter()
848            .flat_map(|v| v.ids())
849            .chain(b.iter().map(|v| v.id()))
850            .collect_vec();
851
852        // Add the underlying values
853        let party_id = fabric.party_id();
854        let gate_results: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
855            all_ids,
856            results_per_value * n, // output_arity
857            move |mut args| {
858                // Split the args
859                let a_vals = args
860                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
861                    .collect_vec();
862                let public_values = args;
863
864                let mut result = Vec::with_capacity(results_per_value * n);
865                for (mut a_vals, public_value) in a_vals
866                    .into_iter()
867                    .chunks(results_per_value)
868                    .into_iter()
869                    .zip(public_values.into_iter())
870                {
871                    let a_share: Scalar<C> = a_vals.next().unwrap().into();
872                    let a_mac_share: Scalar<C> = a_vals.next().unwrap().into();
873                    let a_modifier: Scalar<C> = a_vals.next().unwrap().into();
874
875                    let public_value: Scalar<C> = public_value.into();
876
877                    // Only the first party adds the public value to their share
878                    if party_id == PARTY0 {
879                        result.push(ResultValue::Scalar(a_share - public_value));
880                    } else {
881                        result.push(ResultValue::Scalar(a_share));
882                    }
883
884                    result.push(ResultValue::Scalar(a_mac_share));
885                    result.push(ResultValue::Scalar(a_modifier + public_value));
886                }
887
888                result
889            },
890        );
891
892        // Collect the gate results into a series of `AuthenticatedScalarResult`s
893        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
894    }
895}
896
897// === Negation === //
898
899impl<C: CurveGroup> Neg for &AuthenticatedScalarResult<C> {
900    type Output = AuthenticatedScalarResult<C>;
901
902    fn neg(self) -> Self::Output {
903        AuthenticatedScalarResult {
904            share: -&self.share,
905            mac: -&self.mac,
906            public_modifier: -&self.public_modifier,
907        }
908    }
909}
910impl_borrow_variants!(AuthenticatedScalarResult<C>, Neg, neg, -, C: CurveGroup);
911
912impl<C: CurveGroup> AuthenticatedScalarResult<C> {
913    /// Negate a batch of `AuthenticatedScalarResult`s
914    pub fn batch_neg(a: &[AuthenticatedScalarResult<C>]) -> Vec<AuthenticatedScalarResult<C>> {
915        if a.is_empty() {
916            return vec![];
917        }
918
919        let n = a.len();
920        let fabric = a[0].fabric();
921        let all_ids = a.iter().flat_map(|v| v.ids()).collect_vec();
922
923        let scalars = fabric.new_batch_gate_op(
924            all_ids,
925            AUTHENTICATED_SCALAR_RESULT_LEN * n, // output_arity
926            |args| {
927                args.into_iter()
928                    .map(|arg| ResultValue::Scalar(-Scalar::from(arg)))
929                    .collect()
930            },
931        );
932
933        AuthenticatedScalarResult::from_flattened_iterator(scalars.into_iter())
934    }
935}
936
937// === Multiplication === //
938
939impl<C: CurveGroup> Mul<&Scalar<C>> for &AuthenticatedScalarResult<C> {
940    type Output = AuthenticatedScalarResult<C>;
941
942    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
943        AuthenticatedScalarResult {
944            share: &self.share * rhs,
945            mac: &self.mac * rhs,
946            public_modifier: &self.public_modifier * rhs,
947        }
948    }
949}
950impl_borrow_variants!(AuthenticatedScalarResult<C>, Mul, mul, *, Scalar<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
951impl_commutative!(AuthenticatedScalarResult<C>, Mul, mul, *, Scalar<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
952
953impl<C: CurveGroup> Mul<&ScalarResult<C>> for &AuthenticatedScalarResult<C> {
954    type Output = AuthenticatedScalarResult<C>;
955
956    fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
957        AuthenticatedScalarResult {
958            share: &self.share * rhs,
959            mac: &self.mac * rhs,
960            public_modifier: &self.public_modifier * rhs,
961        }
962    }
963}
964impl_borrow_variants!(AuthenticatedScalarResult<C>, Mul, mul, *, ScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
965impl_commutative!(AuthenticatedScalarResult<C>, Mul, mul, *, ScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
966
967impl<C: CurveGroup> Mul<&AuthenticatedScalarResult<C>> for &AuthenticatedScalarResult<C> {
968    type Output = AuthenticatedScalarResult<C>;
969
970    // Use the Beaver trick
971    fn mul(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
972        // Sample a beaver triplet
973        let (a, b, c) = self.fabric().next_authenticated_triple();
974
975        // Mask the left and right hand sides
976        let masked_lhs = self - &a;
977        let masked_rhs = rhs - &b;
978
979        // Open these values to get d = lhs - a, e = rhs - b
980        let d = masked_lhs.open();
981        let e = masked_rhs.open();
982
983        // Use the same beaver identify as in the `MpcScalarResult<C>` case, but now the
984        // public multiplications are applied to the MACs and the public
985        // modifiers as well Identity: [x * y] = de + d[b] + e[a] + [c]
986        &d * &e + d * b + e * a + c
987    }
988}
989impl_borrow_variants!(AuthenticatedScalarResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
990
991impl<C: CurveGroup> AuthenticatedScalarResult<C> {
992    /// Multiply a batch of values using the Beaver trick
993    pub fn batch_mul(
994        a: &[AuthenticatedScalarResult<C>],
995        b: &[AuthenticatedScalarResult<C>],
996    ) -> Vec<AuthenticatedScalarResult<C>> {
997        assert_eq!(
998            a.len(),
999            b.len(),
1000            "Cannot multiply batches of different sizes"
1001        );
1002
1003        if a.is_empty() {
1004            return vec![];
1005        }
1006
1007        let n = a.len();
1008        let fabric = a[0].fabric();
1009        let (beaver_a, beaver_b, beaver_c) = fabric.next_authenticated_triple_batch(n);
1010
1011        // Open the values d = [lhs - a] and e = [rhs - b]
1012        let masked_lhs = AuthenticatedScalarResult::batch_sub(a, &beaver_a);
1013        let masked_rhs = AuthenticatedScalarResult::batch_sub(b, &beaver_b);
1014
1015        let all_masks = [masked_lhs, masked_rhs].concat();
1016        let opened_values = AuthenticatedScalarResult::open_batch(&all_masks);
1017        let (d_open, e_open) = opened_values.split_at(n);
1018
1019        // Identity: [x * y] = de + d[b] + e[a] + [c]
1020        let de = ScalarResult::batch_mul(d_open, e_open);
1021        let db = AuthenticatedScalarResult::batch_mul_public(&beaver_b, d_open);
1022        let ea = AuthenticatedScalarResult::batch_mul_public(&beaver_a, e_open);
1023
1024        // Add the terms
1025        let de_plus_db = AuthenticatedScalarResult::batch_add_public(&db, &de);
1026        let ea_plus_c = AuthenticatedScalarResult::batch_add(&ea, &beaver_c);
1027        AuthenticatedScalarResult::batch_add(&de_plus_db, &ea_plus_c)
1028    }
1029
1030    /// Multiply a batch of `AuthenticatedScalarResult`s by a batch of
1031    /// `ScalarResult`s
1032    pub fn batch_mul_public(
1033        a: &[AuthenticatedScalarResult<C>],
1034        b: &[ScalarResult<C>],
1035    ) -> Vec<AuthenticatedScalarResult<C>> {
1036        assert_eq!(
1037            a.len(),
1038            b.len(),
1039            "Cannot multiply batches of different sizes"
1040        );
1041        if a.is_empty() {
1042            return vec![];
1043        }
1044
1045        let n = a.len();
1046        let fabric = a[0].fabric();
1047        let all_ids = a
1048            .iter()
1049            .flat_map(|a| a.ids())
1050            .chain(b.iter().map(|b| b.id()))
1051            .collect_vec();
1052
1053        let scalars = fabric.new_batch_gate_op(
1054            all_ids,
1055            AUTHENTICATED_SCALAR_RESULT_LEN * n, // output_arity
1056            move |mut args| {
1057                let a_vals = args
1058                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
1059                    .collect_vec();
1060                let public_values = args;
1061
1062                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
1063                for (a_vals, public_values) in a_vals
1064                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
1065                    .zip(public_values.into_iter())
1066                {
1067                    let a_share: Scalar<C> = a_vals[0].to_owned().into();
1068                    let a_mac_share: Scalar<C> = a_vals[1].to_owned().into();
1069                    let a_modifier: Scalar<C> = a_vals[2].to_owned().into();
1070
1071                    let public_value: Scalar<C> = public_values.into();
1072
1073                    result.push(ResultValue::Scalar(a_share * public_value));
1074                    result.push(ResultValue::Scalar(a_mac_share * public_value));
1075                    result.push(ResultValue::Scalar(a_modifier * public_value));
1076                }
1077
1078                result
1079            },
1080        );
1081
1082        AuthenticatedScalarResult::from_flattened_iterator(scalars.into_iter())
1083    }
1084}
1085
1086// === Division === //
1087#[allow(clippy::suspicious_arithmetic_impl)]
1088impl<C: CurveGroup> Div<&ScalarResult<C>> for &AuthenticatedScalarResult<C> {
1089    type Output = AuthenticatedScalarResult<C>;
1090    fn div(self, rhs: &ScalarResult<C>) -> Self::Output {
1091        let rhs_inv = rhs.inverse();
1092        self * rhs_inv
1093    }
1094}
1095impl_borrow_variants!(AuthenticatedScalarResult<C>, Div, div, /, ScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
1096
1097#[allow(clippy::suspicious_arithmetic_impl)]
1098impl<C: CurveGroup> Div<&AuthenticatedScalarResult<C>> for &AuthenticatedScalarResult<C> {
1099    type Output = AuthenticatedScalarResult<C>;
1100    fn div(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
1101        let rhs_inv = rhs.inverse();
1102        self * rhs_inv
1103    }
1104}
1105impl_borrow_variants!(AuthenticatedScalarResult<C>, Div, div, /, AuthenticatedScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
1106
1107impl<C: CurveGroup> AuthenticatedScalarResult<C> {
1108    /// Divide two batches of values
1109    pub fn batch_div(a: &[Self], b: &[Self]) -> Vec<Self> {
1110        let b_inv = Self::batch_inverse(b);
1111        Self::batch_mul(a, &b_inv)
1112    }
1113}
1114
1115// === Curve Scalar Multiplication === //
1116
1117impl<C: CurveGroup> Mul<&AuthenticatedScalarResult<C>> for &CurvePoint<C> {
1118    type Output = AuthenticatedPointResult<C>;
1119
1120    fn mul(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
1121        AuthenticatedPointResult {
1122            share: self * &rhs.share,
1123            mac: self * &rhs.mac,
1124            public_modifier: self * &rhs.public_modifier,
1125        }
1126    }
1127}
1128impl_commutative!(CurvePoint<C>, Mul, mul, *, AuthenticatedScalarResult<C>, Output=AuthenticatedPointResult<C>, C: CurveGroup);
1129
1130impl<C: CurveGroup> Mul<&AuthenticatedScalarResult<C>> for &CurvePointResult<C> {
1131    type Output = AuthenticatedPointResult<C>;
1132
1133    fn mul(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
1134        AuthenticatedPointResult {
1135            share: self * &rhs.share,
1136            mac: self * &rhs.mac,
1137            public_modifier: self * &rhs.public_modifier,
1138        }
1139    }
1140}
1141impl_borrow_variants!(CurvePointResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, Output=AuthenticatedPointResult<C>, C: CurveGroup);
1142impl_commutative!(CurvePointResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, Output=AuthenticatedPointResult<C>, C: CurveGroup);
1143
1144// === FFT and IFFT === //
1145impl<C: CurveGroup> AuthenticatedScalarResult<C>
1146where
1147    C::ScalarField: FftField,
1148{
1149    /// Compute the FFT of a vector of `AuthenticatedScalarResult`s
1150    pub fn fft<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
1151        x: &[AuthenticatedScalarResult<C>],
1152    ) -> Vec<AuthenticatedScalarResult<C>> {
1153        Self::fft_with_domain::<D>(x, D::new(x.len()).unwrap())
1154    }
1155
1156    /// Compute the FFT of a vector of `AuthenticatedScalarResult`s with a given
1157    /// domain
1158    pub fn fft_with_domain<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
1159        x: &[AuthenticatedScalarResult<C>],
1160        domain: D,
1161    ) -> Vec<AuthenticatedScalarResult<C>> {
1162        Self::fft_helper::<D>(x, true /* is_forward */, domain)
1163    }
1164
1165    /// Compute the inverse FFT of a vector of `AuthenticatedScalarResult`s
1166    pub fn ifft<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
1167        x: &[AuthenticatedScalarResult<C>],
1168    ) -> Vec<AuthenticatedScalarResult<C>> {
1169        Self::fft_helper::<D>(x, false /* is_forward */, D::new(x.len()).unwrap())
1170    }
1171
1172    /// Compute the inverse FFT of a vector of `AuthenticatedScalarResult`s with
1173    /// a given domain
1174    pub fn ifft_with_domain<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
1175        x: &[AuthenticatedScalarResult<C>],
1176        domain: D,
1177    ) -> Vec<AuthenticatedScalarResult<C>> {
1178        Self::fft_helper::<D>(x, false /* is_forward */, domain)
1179    }
1180
1181    /// An FFT/IFFT helper that encapsulates the setup and restructuring of an
1182    /// FFT regardless of direction
1183    ///
1184    /// If `is_forward` is set, an FFT is performed. Otherwise, an IFFT is
1185    /// performed
1186    fn fft_helper<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
1187        x: &[AuthenticatedScalarResult<C>],
1188        is_forward: bool,
1189        domain: D,
1190    ) -> Vec<AuthenticatedScalarResult<C>> {
1191        assert!(!x.is_empty(), "Cannot compute FFT of empty vector");
1192
1193        // Take the FFT of the shares and the macs separately
1194        let shares = x.iter().map(|v| v.share()).collect_vec();
1195        let macs = x.iter().map(|v| v.mac_share()).collect_vec();
1196        let modifiers = x.iter().map(|v| v.public_modifier.clone()).collect_vec();
1197
1198        let (share_fft, mac_fft, modifier_fft) = if is_forward {
1199            (
1200                ScalarResult::fft_with_domain::<D>(&shares, domain),
1201                ScalarResult::fft_with_domain::<D>(&macs, domain),
1202                ScalarResult::fft_with_domain::<D>(&modifiers, domain),
1203            )
1204        } else {
1205            (
1206                ScalarResult::ifft_with_domain::<D>(&shares, domain),
1207                ScalarResult::ifft_with_domain::<D>(&macs, domain),
1208                ScalarResult::ifft_with_domain::<D>(&modifiers, domain),
1209            )
1210        };
1211
1212        let mut res = Vec::with_capacity(domain.size());
1213        for (share, mac, modifier) in izip!(share_fft, mac_fft, modifier_fft) {
1214            res.push(AuthenticatedScalarResult {
1215                share: MpcScalarResult::new_shared(share),
1216                mac: MpcScalarResult::new_shared(mac),
1217                public_modifier: modifier,
1218            })
1219        }
1220
1221        res
1222    }
1223}
1224
1225// ----------------
1226// | Test Helpers |
1227// ----------------
1228
/// Contains helpers that directly overwrite the components of an
/// `AuthenticatedScalarResult` (which will generally break its MAC relation).
/// These are not Rust `unsafe`, but methods in this module should *only* be
/// used for testing
#[cfg(feature = "test_helpers")]
pub mod test_helpers {
    use ark_ec::CurveGroup;

    use crate::algebra::scalar::Scalar;

    use super::AuthenticatedScalarResult;

    /// Replace the MAC share of an `AuthenticatedScalarResult` with a newly
    /// allocated `new_value`
    pub fn modify_mac<C: CurveGroup>(val: &mut AuthenticatedScalarResult<C>, new_value: Scalar<C>) {
        val.mac = val.fabric().allocate_scalar(new_value).into()
    }

    /// Replace the underlying secret share of an `AuthenticatedScalarResult`
    /// with a newly allocated `new_value`
    pub fn modify_share<C: CurveGroup>(
        val: &mut AuthenticatedScalarResult<C>,
        new_value: Scalar<C>,
    ) {
        val.share = val.fabric().allocate_scalar(new_value).into()
    }

    /// Replace the public modifier of an `AuthenticatedScalarResult` with a
    /// newly allocated `new_value`
    ///
    /// This overwrites the modifier; it does not add `new_value` as an offset
    pub fn modify_public_modifier<C: CurveGroup>(
        val: &mut AuthenticatedScalarResult<C>,
        new_value: Scalar<C>,
    ) {
        val.public_modifier = val.fabric().allocate_scalar(new_value)
    }
}
1261
#[cfg(test)]
mod tests {
    use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
    use futures::future;
    use itertools::Itertools;
    use rand::{thread_rng, Rng, RngCore};

    use crate::{
        algebra::{poly_test_helpers::TestPolyField, scalar::Scalar, AuthenticatedScalarResult},
        test_helpers::{execute_mock_mpc, TestCurve},
        PARTY0, PARTY1,
    };

    /// Test subtraction across non-commutative types
    #[tokio::test]
    async fn test_sub() {
        let mut rng = thread_rng();
        let value1 = Scalar::random(&mut rng);
        let value2 = Scalar::random(&mut rng);

        let (res, _) = execute_mock_mpc(|fabric| async move {
            // Allocate the first value as a shared scalar and the second as a public scalar
            let party0_value = fabric.share_scalar(value1, PARTY0);
            let public_value = fabric.allocate_scalar(value2);

            // Subtract the public value from the shared value
            let res1 = &party0_value - &public_value;
            let res_open1 = res1.open_authenticated().await.unwrap();
            let expected1 = value1 - value2;

            // Subtract the shared value from the public value
            let res2 = &public_value - &party0_value;
            let res_open2 = res2.open_authenticated().await.unwrap();
            let expected2 = value2 - value1;

            (res_open1 == expected1, res_open2 == expected2)
        })
        .await;

        assert!(res.0);
        assert!(res.1)
    }

    /// Tests subtraction with a constant value outside of the fabric
    #[tokio::test]
    async fn test_sub_constant() {
        let mut rng = thread_rng();
        let value1 = Scalar::random(&mut rng);
        let value2 = Scalar::random(&mut rng);

        let (res, _) = execute_mock_mpc(|fabric| async move {
            // Allocate the first value as a shared scalar and the second as a public scalar
            let party0_value = fabric.share_scalar(value1, PARTY0);

            // Subtract the public value from the shared value
            let res1 = &party0_value - value2;
            let res_open1 = res1.open_authenticated().await.unwrap();
            let expected1 = value1 - value2;

            // Subtract the shared value from the public value
            let res2 = value2 - &party0_value;
            let res_open2 = res2.open_authenticated().await.unwrap();
            let expected2 = value2 - value1;

            (res_open1 == expected1, res_open2 == expected2)
        })
        .await;

        assert!(res.0);
        assert!(res.1)
    }

    /// Tests division between a shared and public scalar
    #[tokio::test]
    async fn test_public_division() {
        let mut rng = thread_rng();
        let value1 = Scalar::random(&mut rng);
        let value2 = Scalar::random(&mut rng);

        let expected_res = value1 * value2.inverse();

        let (res, _) = execute_mock_mpc(|fabric| async move {
            let shared_value = fabric.share_scalar(value1, PARTY0);
            let public_value = fabric.allocate_scalar(value2);

            (shared_value / public_value).open().await
        })
        .await;

        assert_eq!(res, expected_res)
    }

    /// Tests division between two authenticated values
    #[tokio::test]
    async fn test_division() {
        let mut rng = thread_rng();
        let value1 = Scalar::random(&mut rng);
        let value2 = Scalar::random(&mut rng);

        let expected_res = value1 / value2;

        let (res, _) = execute_mock_mpc(|fabric| async move {
            let shared_value1 = fabric.share_scalar(value1, PARTY0 /* sender */);
            let shared_value2 = fabric.share_scalar(value2, PARTY1 /* sender */);

            (shared_value1 / shared_value2).open_authenticated().await
        })
        .await;

        assert_eq!(res.unwrap(), expected_res)
    }

    /// Tests batch division between authenticated values
    #[tokio::test]
    async fn test_batch_div() {
        const N: usize = 100;
        let mut rng = thread_rng();

        let a_values = (0..N)
            .map(|_| Scalar::<TestCurve>::random(&mut rng))
            .collect_vec();
        let b_values = (0..N)
            .map(|_| Scalar::<TestCurve>::random(&mut rng))
            .collect_vec();

        let expected_res = a_values
            .iter()
            .zip(b_values.iter())
            .map(|(a, b)| a / b)
            .collect_vec();

        let (res, _) = execute_mock_mpc(|fabric| {
            let a = a_values.clone();
            let b = b_values.clone();
            async move {
                let shared_a = fabric.batch_share_scalar(a, PARTY0 /* sender */);
                let shared_b = fabric.batch_share_scalar(b, PARTY1 /* sender */);

                let res = AuthenticatedScalarResult::batch_div(&shared_a, &shared_b);
                let opening = AuthenticatedScalarResult::open_authenticated_batch(&res);
                future::join_all(opening.into_iter())
                    .await
                    .into_iter()
                    .collect::<Result<Vec<_>, _>>()
            }
        })
        .await;

        assert_eq!(res.unwrap(), expected_res)
    }

    /// Test a simple `XOR` circuit
    #[tokio::test]
    async fn test_xor_circuit() {
        let (res, _) = execute_mock_mpc(|fabric| async move {
            let a = &fabric.zero_authenticated();
            let b = &fabric.zero_authenticated();
            let res = a + b - Scalar::from(2u64) * a * b;

            res.open_authenticated().await
        })
        .await;

        assert_eq!(res.unwrap(), 0u8.into());
    }

    /// Tests computing the inverse of a scalar
    #[tokio::test]
    async fn test_batch_inverse() {
        const N: usize = 10;

        let mut rng = thread_rng();
        let values = (0..N)
            .map(|_| Scalar::<TestCurve>::random(&mut rng))
            .collect_vec();
        let expected_res = values.iter().map(|v| v.inverse()).collect_vec();

        let (res, _) = execute_mock_mpc(|fabric| {
            let values = values.clone();
            async move {
                let shared_values = fabric.batch_share_scalar(values, PARTY0 /* sender */);
                let inverses = AuthenticatedScalarResult::batch_inverse(&shared_values);

                let opening = AuthenticatedScalarResult::open_authenticated_batch(&inverses);
                future::join_all(opening.into_iter())
                    .await
                    .into_iter()
                    .collect::<Result<Vec<_>, _>>()
            }
        })
        .await;

        assert_eq!(res.unwrap(), expected_res)
    }

    /// Tests exponentiation
    #[tokio::test]
    async fn test_pow() {
        let mut rng = thread_rng();
        let exp = rng.next_u64();
        let value = Scalar::<TestCurve>::random(&mut rng);

        let expected_res = value.pow(exp);

        let (res, _) = execute_mock_mpc(|fabric| async move {
            let shared_value = fabric.share_scalar(value, PARTY0 /* sender */);
            let res = shared_value.pow(exp);

            res.open().await
        })
        .await;

        assert_eq!(res, expected_res)
    }

    #[tokio::test]
    async fn test_fft() {
        let mut rng = thread_rng();
        // `n` must be at least 1: `gen_range(n..10 * n)` panics on an empty
        // range, and the FFT of an empty vector is rejected
        let n: usize = rng.gen_range(1..100);
        let domain_size = rng.gen_range(n..10 * n);

        let values = (0..n)
            .map(|_| Scalar::<TestCurve>::random(&mut rng))
            .collect_vec();

        let domain = Radix2EvaluationDomain::<TestPolyField>::new(domain_size).unwrap();
        let fft_res = domain.fft(
            &values
                .iter()
                // Add one to test public modifiers
                .map(|v| (v + Scalar::one()).inner())
                .collect_vec(),
        );
        let expected_res = fft_res.into_iter().map(Scalar::new).collect_vec();

        let (res, _) = execute_mock_mpc(|fabric| {
            let values = values.clone();
            async move {
                let shared_values = fabric
                    .batch_share_scalar(values, PARTY0 /* sender */)
                    .into_iter()
                    .map(|v| v + Scalar::one())
                    .collect_vec();
                let fft = AuthenticatedScalarResult::fft_with_domain::<
                    Radix2EvaluationDomain<TestPolyField>,
                >(&shared_values, domain);

                let opening = AuthenticatedScalarResult::open_authenticated_batch(&fft);
                future::join_all(opening.into_iter())
                    .await
                    .into_iter()
                    .collect::<Result<Vec<_>, _>>()
            }
        })
        .await;

        assert_eq!(res.unwrap(), expected_res)
    }

    #[tokio::test]
    async fn test_ifft() {
        let mut rng = thread_rng();
        // `n` must be at least 1: `gen_range(n..10 * n)` panics on an empty
        // range, and the IFFT of an empty vector is rejected
        let n: usize = rng.gen_range(1..100);
        let domain_size = rng.gen_range(n..10 * n);

        let values = (0..n)
            .map(|_| Scalar::<TestCurve>::random(&mut rng))
            .collect_vec();

        let domain = Radix2EvaluationDomain::<TestPolyField>::new(domain_size).unwrap();
        let ifft_res = domain.ifft(
            &values
                .iter()
                // Add one to test public modifiers
                .map(|v| (v + Scalar::one()).inner())
                .collect_vec(),
        );
        let expected_res = ifft_res.into_iter().map(Scalar::new).collect_vec();

        let (res, _) = execute_mock_mpc(|fabric| {
            let values = values.clone();
            async move {
                let shared_values = fabric.batch_share_scalar(values, PARTY0 /* sender */);
                let shared_values = shared_values
                    .into_iter()
                    .map(|v| v + Scalar::one())
                    .collect_vec();

                let ifft = AuthenticatedScalarResult::ifft_with_domain::<
                    Radix2EvaluationDomain<TestPolyField>,
                >(&shared_values, domain);

                let opening = AuthenticatedScalarResult::open_authenticated_batch(&ifft);
                future::join_all(opening.into_iter())
                    .await
                    .into_iter()
                    .collect::<Result<Vec<_>, _>>()
            }
        })
        .await;

        assert_eq!(res.unwrap(), expected_res)
    }
}