ark_mpc/algebra/curve/
authenticated_curve.rs

//! Defines a maliciously secure wrapper around an `MpcPointResult<C>` type that
//! includes a MAC for ensuring computational integrity of an opened point
3
4use std::{
5    fmt::{Debug, Formatter, Result as FmtResult},
6    iter::Sum,
7    ops::{Add, Mul, Neg, Sub},
8    pin::Pin,
9    task::{Context, Poll},
10};
11
12use ark_ec::CurveGroup;
13use futures::{Future, FutureExt};
14use itertools::{izip, Itertools};
15
16use crate::{
17    algebra::macros::*,
18    algebra::scalar::*,
19    commitment::{HashCommitment, HashCommitmentResult},
20    error::MpcError,
21    fabric::{MpcFabric, ResultValue},
22    ResultId, PARTY0,
23};
24
25use super::{
26    curve::{BatchCurvePointResult, CurvePoint, CurvePointResult},
27    mpc_curve::MpcPointResult,
28};
29
/// How many underlying results back a single `AuthenticatedPointResult`:
/// its share, its MAC, and its public modifier
pub(crate) const AUTHENTICATED_POINT_RESULT_LEN: usize = 3;
32
33/// A maliciously secure wrapper around `MpcPointResult` that includes a MAC as
34/// per the SPDZ protocol: https://eprint.iacr.org/2011/535.pdf
35#[derive(Clone)]
36pub struct AuthenticatedPointResult<C: CurveGroup> {
37    /// The local secret share of the underlying authenticated point
38    pub(crate) share: MpcPointResult<C>,
39    /// A SPDZ style, unconditionally secure MAC of the value
40    /// This is used to ensure computational integrity of the opened value
41    ///
42    /// See the doc comment in `AuthenticatedScalar` for more details
43    pub(crate) mac: MpcPointResult<C>,
44    /// The public modifier tracks additions and subtractions of public values
45    /// to the shares
46    ///
47    /// Only the first party adds/subtracts public values to their share, but
48    /// the other parties must track this to validate the MAC when it is
49    /// opened
50    pub(crate) public_modifier: CurvePointResult<C>,
51}
52
53impl<C: CurveGroup> Debug for AuthenticatedPointResult<C> {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("AuthenticatedPointResult")
56            .field("value", &self.share.id())
57            .field("mac", &self.mac.id())
58            .field("public_modifier", &self.public_modifier.id)
59            .finish()
60    }
61}
62
63impl<C: CurveGroup> AuthenticatedPointResult<C> {
64    /// Creates a new authenticated point from a given underlying point
65    pub fn new_shared(value: CurvePointResult<C>) -> AuthenticatedPointResult<C> {
66        // Create an mpc point from the value
67        let fabric_clone = value.fabric.clone();
68
69        let mpc_value = MpcPointResult::new_shared(value);
70        let mac = fabric_clone.borrow_mac_key() * &mpc_value;
71
72        // Allocate a zero point for the public modifier
73        let public_modifier = fabric_clone.allocate_point(CurvePoint::identity());
74
75        Self {
76            share: mpc_value,
77            mac,
78            public_modifier,
79        }
80    }
81
82    /// Creates a batch of `AuthenticatedPointResult`s from a given batch of
83    /// underlying points
84    pub fn new_shared_batch(values: &[CurvePointResult<C>]) -> Vec<AuthenticatedPointResult<C>> {
85        if values.is_empty() {
86            return vec![];
87        }
88
89        // Create mpc points from the value
90        let n = values.len();
91        let fabric = values[0].fabric();
92        let mpc_values = values
93            .iter()
94            .map(|p| MpcPointResult::new_shared(p.clone()))
95            .collect_vec();
96
97        let mac_keys = (0..n)
98            .map(|_| fabric.borrow_mac_key().clone())
99            .collect_vec();
100        let macs = MpcPointResult::batch_mul(&mac_keys, &mpc_values);
101
102        mpc_values
103            .into_iter()
104            .zip(macs)
105            .map(|(share, mac)| Self {
106                share,
107                mac,
108                public_modifier: fabric.curve_identity(),
109            })
110            .collect_vec()
111    }
112
113    /// Creates a batch of `AuthenticatedPointResult`s from a batch result
114    ///
115    /// The batch result combines the batch into one result, so it must be split
116    /// out first before creating the `AuthenticatedPointResult`s
117    pub fn new_shared_from_batch_result(
118        values: BatchCurvePointResult<C>,
119        n: usize,
120    ) -> Vec<AuthenticatedPointResult<C>> {
121        // Convert to a set of scalar results
122        let scalar_results: Vec<CurvePointResult<C>> =
123            values
124                .fabric()
125                .new_batch_gate_op(vec![values.id()], n, |mut args| {
126                    let args: Vec<CurvePoint<C>> = args.pop().unwrap().into();
127                    args.into_iter().map(ResultValue::Point).collect_vec()
128                });
129
130        Self::new_shared_batch(&scalar_results)
131    }
132
133    /// Get the ID of the underlying share's result
134    pub fn id(&self) -> ResultId {
135        self.share.id()
136    }
137
138    /// Get the IDs of the results that make up the `AuthenticatedPointResult`
139    /// representation
140    pub(crate) fn ids(&self) -> Vec<ResultId> {
141        vec![self.share.id(), self.mac.id(), self.public_modifier.id]
142    }
143
144    /// Borrow the fabric that this result is allocated in
145    pub fn fabric(&self) -> &MpcFabric<C> {
146        self.share.fabric()
147    }
148
149    /// Get the underlying share as an `MpcPointResult`
150    #[cfg(feature = "test_helpers")]
151    pub fn mpc_share(&self) -> MpcPointResult<C> {
152        self.share.clone()
153    }
154
155    /// Open the value without checking the MAC
156    pub fn open(&self) -> CurvePointResult<C> {
157        self.share.open()
158    }
159
160    /// Open a batch of values without checking the MAC
161    pub fn open_batch(values: &[Self]) -> Vec<CurvePointResult<C>> {
162        MpcPointResult::open_batch(&values.iter().map(|v| v.share.clone()).collect_vec())
163    }
164
165    /// Convert a flattened iterator into a batch of `AuthenticatedPointResult`s
166    ///
167    /// We assume that the iterator has been flattened in the same way order
168    /// that `Self::id`s returns the `AuthenticatedScalar`'s values:
169    /// `[share, mac, public_modifier]`
170    pub fn from_flattened_iterator<I>(iter: I) -> Vec<Self>
171    where
172        I: Iterator<Item = CurvePointResult<C>>,
173    {
174        iter.chunks(AUTHENTICATED_POINT_RESULT_LEN)
175            .into_iter()
176            .map(|mut chunk| Self {
177                share: chunk.next().unwrap().into(),
178                mac: chunk.next().unwrap().into(),
179                public_modifier: chunk.next().unwrap(),
180            })
181            .collect_vec()
182    }
183
184    /// Verify the MAC check on an authenticated opening
185    fn verify_mac_check(
186        my_mac_share: CurvePoint<C>,
187        peer_mac_share: CurvePoint<C>,
188        peer_mac_commitment: Scalar<C>,
189        peer_blinder: Scalar<C>,
190    ) -> bool {
191        // Check that the MAC check value is the correct opening of the
192        // given commitment
193        let peer_comm = HashCommitment {
194            value: peer_mac_share,
195            blinder: peer_blinder,
196            commitment: peer_mac_commitment,
197        };
198        if !peer_comm.verify() {
199            return false;
200        }
201
202        // Check that the MAC check shares add up to the additive identity in
203        // the curve group
204        if my_mac_share + peer_mac_share != CurvePoint::identity() {
205            return false;
206        }
207
208        true
209    }
210
211    /// Open the value and check the MAC
212    ///
213    /// This follows the protocol detailed in
214    ///     https://securecomputation.org/docs/pragmaticmpc.pdf
215    pub fn open_authenticated(&self) -> AuthenticatedPointOpenResult<C> {
216        // Both parties open the underlying value
217        let recovered_value = self.share.open();
218
219        // Add a gate to compute hte MAC check value: `key_share * opened_value -
220        // mac_share`
221        let mac_check: CurvePointResult<C> = self.fabric().new_gate_op(
222            vec![
223                self.fabric().borrow_mac_key().id(),
224                recovered_value.id(),
225                self.public_modifier.id(),
226                self.mac.id(),
227            ],
228            |mut args| {
229                let mac_key_share: Scalar<C> = args.remove(0).into();
230                let value: CurvePoint<C> = args.remove(0).into();
231                let modifier: CurvePoint<C> = args.remove(0).into();
232                let mac_share: CurvePoint<C> = args.remove(0).into();
233
234                ResultValue::Point((value + modifier) * mac_key_share - mac_share)
235            },
236        );
237
238        // Compute a commitment to this value and share it with the peer
239        let my_comm = HashCommitmentResult::commit(mac_check.clone());
240        let peer_commit = self.fabric().exchange_value(my_comm.commitment);
241
242        // Once the parties have exchanged their commitments, they can open the
243        // underlying MAC check value as they are bound by the commitment
244        let peer_mac_check = self.fabric().exchange_value(my_comm.value.clone());
245        let blinder_result: ScalarResult<C> = self.fabric().allocate_scalar(my_comm.blinder);
246        let peer_blinder = self.fabric().exchange_value(blinder_result);
247
248        // Check the peer's commitment and the sum of the MAC checks
249        let commitment_check: ScalarResult<C> = self.fabric().new_gate_op(
250            vec![
251                mac_check.id,
252                peer_mac_check.id,
253                peer_blinder.id,
254                peer_commit.id,
255            ],
256            move |mut args| {
257                let my_mac_check: CurvePoint<C> = args.remove(0).into();
258                let peer_mac_check: CurvePoint<C> = args.remove(0).into();
259                let peer_blinder: Scalar<C> = args.remove(0).into();
260                let peer_commitment: Scalar<C> = args.remove(0).into();
261
262                ResultValue::Scalar(Scalar::from(Self::verify_mac_check(
263                    my_mac_check,
264                    peer_mac_check,
265                    peer_commitment,
266                    peer_blinder,
267                )))
268            },
269        );
270
271        AuthenticatedPointOpenResult {
272            value: recovered_value,
273            mac_check: commitment_check,
274        }
275    }
276
277    /// Open a batch of values and check the MACs
278    pub fn open_authenticated_batch(values: &[Self]) -> Vec<AuthenticatedPointOpenResult<C>> {
279        if values.is_empty() {
280            return Vec::new();
281        }
282
283        let n = values.len();
284        let fabric = values[0].fabric();
285
286        // Open the values
287        let opened_values = Self::open_batch(values);
288
289        // --- MAC Check --- //
290
291        // Compute the shares of the MAC check in batch
292        let mut mac_check_deps = Vec::with_capacity(1 + AUTHENTICATED_POINT_RESULT_LEN * n);
293        mac_check_deps.push(fabric.borrow_mac_key().id());
294        for i in 0..n {
295            mac_check_deps.push(opened_values[i].id());
296            mac_check_deps.push(values[i].public_modifier.id());
297            mac_check_deps.push(values[i].mac.id());
298        }
299
300        let mac_checks: Vec<CurvePointResult<C>> =
301            fabric.new_batch_gate_op(mac_check_deps, n /* output_arity */, move |mut args| {
302                let mac_key_share: Scalar<C> = args.remove(0).into();
303                let mut check_result = Vec::with_capacity(n);
304
305                for _ in 0..n {
306                    let value: CurvePoint<C> = args.remove(0).into();
307                    let modifier: CurvePoint<C> = args.remove(0).into();
308                    let mac_share: CurvePoint<C> = args.remove(0).into();
309
310                    check_result.push(mac_key_share * (value + modifier) - mac_share);
311                }
312
313                check_result.into_iter().map(ResultValue::Point).collect()
314            });
315
316        // --- Commit to the MAC checks --- //
317
318        let my_comms = mac_checks
319            .iter()
320            .cloned()
321            .map(HashCommitmentResult::commit)
322            .collect_vec();
323        let peer_comms = fabric.exchange_values(
324            &my_comms
325                .iter()
326                .map(|comm| comm.commitment.clone())
327                .collect_vec(),
328        );
329
330        // --- Exchange the MAC Checks and Commitment Blinders --- //
331
332        let peer_mac_checks = fabric.exchange_values(&mac_checks);
333        let peer_blinders = fabric.exchange_values(
334            &my_comms
335                .iter()
336                .map(|comm| fabric.allocate_scalar(comm.blinder))
337                .collect_vec(),
338        );
339
340        // --- Check the MAC Checks --- //
341
342        let mut mac_check_gate_deps = my_comms.iter().map(|comm| comm.value.id).collect_vec();
343        mac_check_gate_deps.push(peer_mac_checks.id);
344        mac_check_gate_deps.push(peer_blinders.id);
345        mac_check_gate_deps.push(peer_comms.id);
346
347        let commitment_checks: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
348            mac_check_gate_deps,
349            n, // output_arity
350            move |mut args| {
351                let my_comms: Vec<CurvePoint<C>> =
352                    args.drain(..n).map(|comm| comm.into()).collect();
353                let peer_mac_checks: Vec<CurvePoint<C>> = args.remove(0).into();
354                let peer_blinders: Vec<Scalar<C>> = args.remove(0).into();
355                let peer_comms: Vec<Scalar<C>> = args.remove(0).into();
356
357                // Build a commitment from the gate inputs
358                let mut mac_checks = Vec::with_capacity(n);
359                for (my_mac_share, peer_mac_share, peer_blinder, peer_commitment) in izip!(
360                    my_comms.into_iter(),
361                    peer_mac_checks.into_iter(),
362                    peer_blinders.into_iter(),
363                    peer_comms.into_iter()
364                ) {
365                    let mac_check = Self::verify_mac_check(
366                        my_mac_share,
367                        peer_mac_share,
368                        peer_commitment,
369                        peer_blinder,
370                    );
371                    mac_checks.push(ResultValue::Scalar(Scalar::from(mac_check)));
372                }
373
374                mac_checks
375            },
376        );
377
378        // --- Return the results --- //
379
380        opened_values
381            .into_iter()
382            .zip(commitment_checks)
383            .map(|(value, check)| AuthenticatedPointOpenResult {
384                value,
385                mac_check: check,
386            })
387            .collect_vec()
388    }
389}
390
391/// The value that results from opening an `AuthenticatedPointResult` and
392/// checking its MAC. This encapsulates both the underlying value and the result
393/// of the MAC check
394#[derive(Clone)]
395pub struct AuthenticatedPointOpenResult<C: CurveGroup> {
396    /// The underlying value
397    pub value: CurvePointResult<C>,
398    /// The result of the MAC check
399    pub mac_check: ScalarResult<C>,
400}
401
402impl<C: CurveGroup> Debug for AuthenticatedPointOpenResult<C> {
403    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
404        f.debug_struct("AuthenticatedPointOpenResult")
405            .field("value", &self.value.id)
406            .field("mac_check", &self.mac_check.id)
407            .finish()
408    }
409}
410
411impl<C: CurveGroup> Future for AuthenticatedPointOpenResult<C>
412where
413    C::ScalarField: Unpin,
414{
415    type Output = Result<CurvePoint<C>, MpcError>;
416
417    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
418        // Await both of the underlying values
419        let value = futures::ready!(self.as_mut().value.poll_unpin(cx));
420        let mac_check = futures::ready!(self.as_mut().mac_check.poll_unpin(cx));
421
422        if mac_check == Scalar::from(1u8) {
423            Poll::Ready(Ok(value))
424        } else {
425            Poll::Ready(Err(MpcError::AuthenticationError))
426        }
427    }
428}
429
430impl<C: CurveGroup> Sum for AuthenticatedPointResult<C> {
431    // Assumes the iterator is non-empty
432    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
433        let first = iter
434            .next()
435            .expect("AuthenticatedPointResult<C>::sum requires a non-empty iterator");
436        iter.fold(first, |acc, x| acc + x)
437    }
438}
439
440// --------------
441// | Arithmetic |
442// --------------
443
444// === Addition === //
445
446impl<C: CurveGroup> Add<&CurvePoint<C>> for &AuthenticatedPointResult<C> {
447    type Output = AuthenticatedPointResult<C>;
448
449    fn add(self, other: &CurvePoint<C>) -> AuthenticatedPointResult<C> {
450        let new_share = if self.fabric().party_id() == PARTY0 {
451            // Party zero adds the public value to their share
452            &self.share + other
453        } else {
454            // Other parties just add the identity to the value to allocate a new op and
455            // keep in sync with party 0
456            &self.share + CurvePoint::identity()
457        };
458
459        // Add the public value to the MAC
460        let new_modifier = &self.public_modifier - other;
461        AuthenticatedPointResult {
462            share: new_share,
463            mac: self.mac.clone(),
464            public_modifier: new_modifier,
465        }
466    }
467}
468impl_borrow_variants!(AuthenticatedPointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
469impl_commutative!(AuthenticatedPointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
470
471impl<C: CurveGroup> Add<&CurvePointResult<C>> for &AuthenticatedPointResult<C> {
472    type Output = AuthenticatedPointResult<C>;
473
474    fn add(self, other: &CurvePointResult<C>) -> AuthenticatedPointResult<C> {
475        let new_share = if self.fabric().party_id() == PARTY0 {
476            // Party zero adds the public value to their share
477            &self.share + other
478        } else {
479            // Other parties just add the identity to the value to allocate a new op and
480            // keep in sync with party 0
481            &self.share + CurvePoint::identity()
482        };
483
484        // Add the public value to the MAC
485        let new_modifier = &self.public_modifier - other;
486        AuthenticatedPointResult {
487            share: new_share,
488            mac: self.mac.clone(),
489            public_modifier: new_modifier,
490        }
491    }
492}
493impl_borrow_variants!(AuthenticatedPointResult<C>, Add, add, +, CurvePointResult<C>, C: CurveGroup);
494impl_commutative!(AuthenticatedPointResult<C>, Add, add, +, CurvePointResult<C>, C: CurveGroup);
495
496impl<C: CurveGroup> Add<&AuthenticatedPointResult<C>> for &AuthenticatedPointResult<C> {
497    type Output = AuthenticatedPointResult<C>;
498
499    fn add(self, other: &AuthenticatedPointResult<C>) -> AuthenticatedPointResult<C> {
500        let new_share = &self.share + &other.share;
501
502        // Add the public value to the MAC
503        let new_mac = &self.mac + &other.mac;
504        AuthenticatedPointResult {
505            share: new_share,
506            mac: new_mac,
507            public_modifier: self.public_modifier.clone() + other.public_modifier.clone(),
508        }
509    }
510}
511impl_borrow_variants!(AuthenticatedPointResult<C>, Add, add, +, AuthenticatedPointResult<C>, C: CurveGroup);
512
513impl<C: CurveGroup> AuthenticatedPointResult<C> {
514    /// Add two batches of `AuthenticatedPointResult`s
515    pub fn batch_add(
516        a: &[AuthenticatedPointResult<C>],
517        b: &[AuthenticatedPointResult<C>],
518    ) -> Vec<AuthenticatedPointResult<C>> {
519        assert_eq!(a.len(), b.len(), "batch_add requires equal length vectors");
520        if a.is_empty() {
521            return Vec::new();
522        }
523
524        let n = a.len();
525        let fabric = a[0].fabric();
526        let all_ids = a.iter().chain(b.iter()).flat_map(|p| p.ids()).collect_vec();
527
528        let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
529            all_ids,
530            AUTHENTICATED_POINT_RESULT_LEN * n,
531            move |mut args| {
532                let len = args.len();
533                let a_vals = args.drain(..len / 2).collect_vec();
534                let b_vals = args;
535
536                let mut result = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
537                for (a_chunk, b_chunk) in a_vals
538                    .chunks(AUTHENTICATED_POINT_RESULT_LEN)
539                    .zip(b_vals.chunks(AUTHENTICATED_POINT_RESULT_LEN))
540                {
541                    let a_share: CurvePoint<C> = a_chunk[0].clone().into();
542                    let a_mac: CurvePoint<C> = a_chunk[1].clone().into();
543                    let a_modifier: CurvePoint<C> = a_chunk[2].clone().into();
544
545                    let b_share: CurvePoint<C> = b_chunk[0].clone().into();
546                    let b_mac: CurvePoint<C> = b_chunk[1].clone().into();
547                    let b_modifier: CurvePoint<C> = b_chunk[2].clone().into();
548
549                    result.push(ResultValue::Point(a_share + b_share));
550                    result.push(ResultValue::Point(a_mac + b_mac));
551                    result.push(ResultValue::Point(a_modifier + b_modifier));
552                }
553
554                result
555            },
556        );
557
558        Self::from_flattened_iterator(res.into_iter())
559    }
560
561    /// Add a batch of `AuthenticatedPointResult`s to a batch of
562    /// `CurvePointResult`s
563    pub fn batch_add_public(
564        a: &[AuthenticatedPointResult<C>],
565        b: &[CurvePointResult<C>],
566    ) -> Vec<AuthenticatedPointResult<C>> {
567        assert_eq!(
568            a.len(),
569            b.len(),
570            "batch_add_public requires equal length vectors"
571        );
572        if a.is_empty() {
573            return Vec::new();
574        }
575
576        let n = a.len();
577        let fabric = a[0].fabric();
578        let all_ids = a
579            .iter()
580            .flat_map(|a| a.ids())
581            .chain(b.iter().map(|b| b.id()))
582            .collect_vec();
583
584        let party_id = fabric.party_id();
585        let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
586            all_ids,
587            AUTHENTICATED_POINT_RESULT_LEN * n,
588            move |mut args| {
589                let a_vals = args
590                    .drain(..AUTHENTICATED_POINT_RESULT_LEN * n)
591                    .collect_vec();
592                let b_vals = args;
593
594                let mut result = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
595                for (a_chunk, b_val) in a_vals
596                    .chunks(AUTHENTICATED_POINT_RESULT_LEN)
597                    .zip(b_vals.into_iter())
598                {
599                    let a_share: CurvePoint<C> = a_chunk[0].clone().into();
600                    let a_mac: CurvePoint<C> = a_chunk[1].clone().into();
601                    let a_modifier: CurvePoint<C> = a_chunk[2].clone().into();
602
603                    let public_value: CurvePoint<C> = b_val.into();
604
605                    // Only the first party adds the public value to their share
606                    if party_id == PARTY0 {
607                        result.push(ResultValue::Point(a_share + public_value));
608                    } else {
609                        result.push(ResultValue::Point(a_share))
610                    }
611
612                    result.push(ResultValue::Point(a_mac));
613                    result.push(ResultValue::Point(a_modifier - public_value));
614                }
615
616                result
617            },
618        );
619
620        Self::from_flattened_iterator(res.into_iter())
621    }
622}
623
624// === Subtraction === //
625
626impl<C: CurveGroup> Sub<&CurvePoint<C>> for &AuthenticatedPointResult<C> {
627    type Output = AuthenticatedPointResult<C>;
628
629    fn sub(self, other: &CurvePoint<C>) -> AuthenticatedPointResult<C> {
630        let new_share = if self.fabric().party_id() == PARTY0 {
631            // Party zero subtracts the public value from their share
632            &self.share - other
633        } else {
634            // Other parties just subtract the identity from the value to allocate a new op
635            // and keep in sync with party 0
636            &self.share - CurvePoint::identity()
637        };
638
639        // Subtract the public value from the MAC
640        let new_modifier = &self.public_modifier + other;
641        AuthenticatedPointResult {
642            share: new_share,
643            mac: self.mac.clone(),
644            public_modifier: new_modifier,
645        }
646    }
647}
648impl_borrow_variants!(AuthenticatedPointResult<C>, Sub, sub, -, CurvePoint<C>, C: CurveGroup);
649impl_commutative!(AuthenticatedPointResult<C>, Sub, sub, -, CurvePoint<C>, C: CurveGroup);
650
651impl<C: CurveGroup> Sub<&CurvePointResult<C>> for &AuthenticatedPointResult<C> {
652    type Output = AuthenticatedPointResult<C>;
653
654    fn sub(self, other: &CurvePointResult<C>) -> AuthenticatedPointResult<C> {
655        let new_share = if self.fabric().party_id() == PARTY0 {
656            // Party zero subtracts the public value from their share
657            &self.share - other
658        } else {
659            // Other parties just subtract the identity from the value to allocate a new op
660            // and keep in sync with party 0
661            &self.share - CurvePoint::identity()
662        };
663
664        // Subtract the public value from the MAC
665        let new_modifier = &self.public_modifier + other;
666        AuthenticatedPointResult {
667            share: new_share,
668            mac: self.mac.clone(),
669            public_modifier: new_modifier,
670        }
671    }
672}
673impl_borrow_variants!(AuthenticatedPointResult<C>, Sub, sub, -, CurvePointResult<C>, C: CurveGroup);
674impl_commutative!(AuthenticatedPointResult<C>, Sub, sub, -, CurvePointResult<C>, C: CurveGroup);
675
676impl<C: CurveGroup> Sub<&AuthenticatedPointResult<C>> for &AuthenticatedPointResult<C> {
677    type Output = AuthenticatedPointResult<C>;
678
679    fn sub(self, other: &AuthenticatedPointResult<C>) -> AuthenticatedPointResult<C> {
680        let new_share = &self.share - &other.share;
681
682        // Subtract the public value from the MAC
683        let new_mac = &self.mac - &other.mac;
684        AuthenticatedPointResult {
685            share: new_share,
686            mac: new_mac,
687            public_modifier: self.public_modifier.clone(),
688        }
689    }
690}
691impl_borrow_variants!(AuthenticatedPointResult<C>, Sub, sub, -, AuthenticatedPointResult<C>, C: CurveGroup);
692
693impl<C: CurveGroup> AuthenticatedPointResult<C> {
694    /// Add two batches of `AuthenticatedPointResult`s
695    pub fn batch_sub(
696        a: &[AuthenticatedPointResult<C>],
697        b: &[AuthenticatedPointResult<C>],
698    ) -> Vec<AuthenticatedPointResult<C>> {
699        assert_eq!(a.len(), b.len(), "batch_add requires equal length vectors");
700        if a.is_empty() {
701            return Vec::new();
702        }
703
704        let n = a.len();
705        let fabric = a[0].fabric();
706        let all_ids = a.iter().chain(b.iter()).flat_map(|p| p.ids()).collect_vec();
707
708        let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
709            all_ids,
710            AUTHENTICATED_POINT_RESULT_LEN * n,
711            move |mut args| {
712                let len = args.len();
713                let a_vals = args.drain(..len / 2).collect_vec();
714                let b_vals = args;
715
716                let mut result = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
717                for (a_chunk, b_chunk) in a_vals
718                    .chunks(AUTHENTICATED_POINT_RESULT_LEN)
719                    .zip(b_vals.chunks(AUTHENTICATED_POINT_RESULT_LEN))
720                {
721                    let a_share: CurvePoint<C> = a_chunk[0].clone().into();
722                    let a_mac: CurvePoint<C> = a_chunk[1].clone().into();
723                    let a_modifier: CurvePoint<C> = a_chunk[2].clone().into();
724
725                    let b_share: CurvePoint<C> = b_chunk[0].clone().into();
726                    let b_mac: CurvePoint<C> = b_chunk[1].clone().into();
727                    let b_modifier: CurvePoint<C> = b_chunk[2].clone().into();
728
729                    result.push(ResultValue::Point(a_share - b_share));
730                    result.push(ResultValue::Point(a_mac - b_mac));
731                    result.push(ResultValue::Point(a_modifier - b_modifier));
732                }
733
734                result
735            },
736        );
737
738        Self::from_flattened_iterator(res.into_iter())
739    }
740
741    /// Subtract a batch of `AuthenticatedPointResult`s to a batch of
742    /// `CurvePointResult`s
743    pub fn batch_sub_public(
744        a: &[AuthenticatedPointResult<C>],
745        b: &[CurvePointResult<C>],
746    ) -> Vec<AuthenticatedPointResult<C>> {
747        assert_eq!(
748            a.len(),
749            b.len(),
750            "batch_add_public requires equal length vectors"
751        );
752        if a.is_empty() {
753            return Vec::new();
754        }
755
756        let n = a.len();
757        let fabric = a[0].fabric();
758        let all_ids = a
759            .iter()
760            .flat_map(|a| a.ids())
761            .chain(b.iter().map(|b| b.id()))
762            .collect_vec();
763
764        let party_id = fabric.party_id();
765        let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
766            all_ids,
767            AUTHENTICATED_POINT_RESULT_LEN * n,
768            move |mut args| {
769                let a_vals = args
770                    .drain(..AUTHENTICATED_POINT_RESULT_LEN * n)
771                    .collect_vec();
772                let b_vals = args;
773
774                let mut result = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
775                for (a_chunk, b_val) in a_vals
776                    .chunks(AUTHENTICATED_POINT_RESULT_LEN)
777                    .zip(b_vals.into_iter())
778                {
779                    let a_share: CurvePoint<C> = a_chunk[0].clone().into();
780                    let a_mac: CurvePoint<C> = a_chunk[1].clone().into();
781                    let a_modifier: CurvePoint<C> = a_chunk[2].clone().into();
782
783                    let b_share: CurvePoint<C> = b_val.into();
784
785                    // Only the first party adds the public value to their share
786                    if party_id == PARTY0 {
787                        result.push(ResultValue::Point(a_share - b_share));
788                    } else {
789                        result.push(ResultValue::Point(a_share))
790                    }
791
792                    result.push(ResultValue::Point(a_mac));
793                    result.push(ResultValue::Point(a_modifier + b_share));
794                }
795
796                result
797            },
798        );
799
800        Self::from_flattened_iterator(res.into_iter())
801    }
802}
803
// === Negation === //
805
806impl<C: CurveGroup> Neg for &AuthenticatedPointResult<C> {
807    type Output = AuthenticatedPointResult<C>;
808
809    fn neg(self) -> AuthenticatedPointResult<C> {
810        let new_share = -&self.share;
811
812        // Negate the public value in the MAC
813        let new_mac = -&self.mac;
814        AuthenticatedPointResult {
815            share: new_share,
816            mac: new_mac,
817            public_modifier: self.public_modifier.clone(),
818        }
819    }
820}
821impl_borrow_variants!(AuthenticatedPointResult<C>, Neg, neg, -, C: CurveGroup);
822
823impl<C: CurveGroup> AuthenticatedPointResult<C> {
824    /// Negate a batch of `AuthenticatedPointResult`s
825    pub fn batch_neg(a: &[AuthenticatedPointResult<C>]) -> Vec<AuthenticatedPointResult<C>> {
826        if a.is_empty() {
827            return Vec::new();
828        }
829
830        let n = a.len();
831        let fabric = a[0].fabric();
832        let all_ids = a.iter().flat_map(|p| p.ids()).collect_vec();
833
834        let res: Vec<CurvePointResult<C>> =
835            fabric.new_batch_gate_op(all_ids, AUTHENTICATED_POINT_RESULT_LEN * n, move |args| {
836                args.into_iter()
837                    .map(CurvePoint::from)
838                    .map(CurvePoint::neg)
839                    .map(ResultValue::Point)
840                    .collect_vec()
841            });
842
843        Self::from_flattened_iterator(res.into_iter())
844    }
845}
846
847// === Scalar<C> Multiplication === //
848
849impl<C: CurveGroup> Mul<&Scalar<C>> for &AuthenticatedPointResult<C> {
850    type Output = AuthenticatedPointResult<C>;
851
852    fn mul(self, other: &Scalar<C>) -> AuthenticatedPointResult<C> {
853        let new_share = &self.share * other;
854
855        // Multiply the public value in the MAC
856        let new_mac = &self.mac * other;
857        let new_modifier = &self.public_modifier * other;
858        AuthenticatedPointResult {
859            share: new_share,
860            mac: new_mac,
861            public_modifier: new_modifier,
862        }
863    }
864}
865impl_borrow_variants!(AuthenticatedPointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
866impl_commutative!(AuthenticatedPointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
867
868impl<C: CurveGroup> Mul<&ScalarResult<C>> for &AuthenticatedPointResult<C> {
869    type Output = AuthenticatedPointResult<C>;
870
871    fn mul(self, other: &ScalarResult<C>) -> AuthenticatedPointResult<C> {
872        let new_share = &self.share * other;
873
874        // Multiply the public value in the MAC
875        let new_mac = &self.mac * other;
876        let new_modifier = &self.public_modifier * other;
877        AuthenticatedPointResult {
878            share: new_share,
879            mac: new_mac,
880            public_modifier: new_modifier,
881        }
882    }
883}
884impl_borrow_variants!(AuthenticatedPointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
885impl_commutative!(AuthenticatedPointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
886
impl<C: CurveGroup> Mul<&AuthenticatedScalarResult<C>> for &AuthenticatedPointResult<C> {
    type Output = AuthenticatedPointResult<C>;

    // Beaver trick
    /// Multiply an authenticated point by an authenticated scalar
    ///
    /// Writing `self = yG` and `rhs = x`, a fresh authenticated triple
    /// `(a, b, c)` is drawn from the fabric. The masked values `d = x - a` and
    /// `eG = yG - bG` are opened, and the product is recombined as
    /// `deG + d[bG] + [a]eG + [c]G`. This equals `x * yG` under the usual
    /// Beaver-triple contract `c = a * b`, which is not checked here.
    fn mul(self, rhs: &AuthenticatedScalarResult<C>) -> AuthenticatedPointResult<C> {
        // Sample a beaver triple
        let generator = CurvePoint::generator();
        let (a, b, c) = self.fabric().next_authenticated_triple();

        // Open the values d = [rhs - a] and e = [lhs - bG] for curve group generator G
        let masked_rhs = rhs - &a;
        let masked_lhs = self - (&generator * &b);

        // NOTE(review): `open` involves communication; presumably every party
        // must issue these two opens in the same order, so keep it fixed
        // (`batch_mul` opens the point before the scalar as well)
        // `eG` mirrors the math notation, hence the lint allowance
        #[allow(non_snake_case)]
        let eG_open = masked_lhs.open();
        let d_open = masked_rhs.open();

        // Identity [x * yG] = deG + d[bG] + [a]eG + [c]G
        // `&generator * &b` is recomputed here rather than reusing the value
        // that was subtracted from `self` above
        &d_open * &eG_open + &d_open * &(&generator * &b) + &a * eG_open + &c * generator
    }
}
impl_borrow_variants!(AuthenticatedPointResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, C: CurveGroup);
impl_commutative!(AuthenticatedPointResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, C: CurveGroup);
910
911impl<C: CurveGroup> AuthenticatedPointResult<C> {
912    /// Multiply a batch of `AuthenticatedPointResult`s by a batch of
913    /// `AuthenticatedScalarResult`s
914    #[allow(non_snake_case)]
915    pub fn batch_mul(
916        a: &[AuthenticatedScalarResult<C>],
917        b: &[AuthenticatedPointResult<C>],
918    ) -> Vec<AuthenticatedPointResult<C>> {
919        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
920        if a.is_empty() {
921            return Vec::new();
922        }
923
924        let n = a.len();
925        let fabric = a[0].fabric();
926
927        // Sample a set of beaver triples for the multiplications
928        let (beaver_a, beaver_b, beaver_c) = fabric.next_authenticated_triple_batch(n);
929        let beaver_b_gen = AuthenticatedPointResult::batch_mul_generator(&beaver_b);
930
931        let masked_rhs = AuthenticatedScalarResult::batch_sub(a, &beaver_a);
932        let masked_lhs = AuthenticatedPointResult::batch_sub(b, &beaver_b_gen);
933
934        let eG_open = AuthenticatedPointResult::open_batch(&masked_lhs);
935        let d_open = AuthenticatedScalarResult::open_batch(&masked_rhs);
936
937        // Identity [x * yG] = deG + d[bG] + [a]eG + [c]G
938        let deG = CurvePointResult::batch_mul(&d_open, &eG_open);
939        let dbG = AuthenticatedPointResult::batch_mul_public(&d_open, &beaver_b_gen);
940        let aeG = CurvePointResult::batch_mul_authenticated(&beaver_a, &eG_open);
941        let cG = AuthenticatedPointResult::batch_mul_generator(&beaver_c);
942
943        let de_db_G = AuthenticatedPointResult::batch_add_public(&dbG, &deG);
944        let ae_c_G = AuthenticatedPointResult::batch_add(&aeG, &cG);
945
946        AuthenticatedPointResult::batch_add(&de_db_G, &ae_c_G)
947    }
948
949    /// Multiply a batch of `AuthenticatedPointResult`s by a batch of
950    /// `ScalarResult`s
951    pub fn batch_mul_public(
952        a: &[ScalarResult<C>],
953        b: &[AuthenticatedPointResult<C>],
954    ) -> Vec<AuthenticatedPointResult<C>> {
955        assert_eq!(
956            a.len(),
957            b.len(),
958            "batch_mul_public requires equal length vectors"
959        );
960        if a.is_empty() {
961            return Vec::new();
962        }
963
964        let n = a.len();
965        let fabric = a[0].fabric();
966        let all_ids = a
967            .iter()
968            .map(|a| a.id())
969            .chain(b.iter().flat_map(|p| p.ids()))
970            .collect_vec();
971
972        let results: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
973            all_ids,
974            AUTHENTICATED_POINT_RESULT_LEN * n, // output_arity
975            move |mut args| {
976                let scalars: Vec<Scalar<C>> = args.drain(..n).map(Scalar::from).collect_vec();
977                let points: Vec<CurvePoint<C>> =
978                    args.into_iter().map(CurvePoint::from).collect_vec();
979
980                let mut result = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
981                for (scalar, points) in scalars
982                    .into_iter()
983                    .zip(points.chunks(AUTHENTICATED_POINT_RESULT_LEN))
984                {
985                    let share: CurvePoint<C> = points[0];
986                    let mac: CurvePoint<C> = points[1];
987                    let modifier: CurvePoint<C> = points[2];
988
989                    result.push(ResultValue::Point(share * scalar));
990                    result.push(ResultValue::Point(mac * scalar));
991                    result.push(ResultValue::Point(modifier * scalar));
992                }
993
994                result
995            },
996        );
997
998        Self::from_flattened_iterator(results.into_iter())
999    }
1000
1001    /// Multiply a batch of scalars by the generator
1002    pub fn batch_mul_generator(
1003        a: &[AuthenticatedScalarResult<C>],
1004    ) -> Vec<AuthenticatedPointResult<C>> {
1005        if a.is_empty() {
1006            return Vec::new();
1007        }
1008
1009        let n = a.len();
1010        let fabric = a[0].fabric();
1011        let all_ids = a.iter().flat_map(|v| v.ids()).collect_vec();
1012
1013        // Multiply the shares in a batch gate
1014        let results = fabric.new_batch_gate_op(
1015            all_ids,
1016            AUTHENTICATED_POINT_RESULT_LEN * n, // output_arity
1017            move |args| {
1018                let scalars = args.into_iter().map(Scalar::from).collect_vec();
1019                let generator = CurvePoint::generator();
1020
1021                scalars
1022                    .into_iter()
1023                    .map(|x| x * generator)
1024                    .map(ResultValue::Point)
1025                    .collect_vec()
1026            },
1027        );
1028
1029        Self::from_flattened_iterator(results.into_iter())
1030    }
1031}
1032
1033// === Multiscalar Multiplication === //
1034
1035impl<C: CurveGroup> AuthenticatedPointResult<C> {
1036    /// Multiscalar multiplication
1037    ///
1038    /// TODO: Maybe make use of a fast MSM operation under the hood once the
1039    /// blinded points are revealed
1040    pub fn msm(
1041        scalars: &[AuthenticatedScalarResult<C>],
1042        points: &[AuthenticatedPointResult<C>],
1043    ) -> AuthenticatedPointResult<C> {
1044        assert_eq!(
1045            scalars.len(),
1046            points.len(),
1047            "multiscalar_mul requires equal length vectors"
1048        );
1049        assert!(
1050            !scalars.is_empty(),
1051            "multiscalar_mul requires non-empty vectors"
1052        );
1053
1054        let mul_out = AuthenticatedPointResult::batch_mul(scalars, points);
1055
1056        // Create a gate to sum the points
1057        let fabric = scalars[0].fabric();
1058        let all_ids = mul_out.iter().flat_map(|p| p.ids()).collect_vec();
1059
1060        let results = fabric.new_batch_gate_op(
1061            all_ids,
1062            AUTHENTICATED_POINT_RESULT_LEN, // output_arity
1063            move |args| {
1064                // Accumulators
1065                let mut share = CurvePoint::identity();
1066                let mut mac = CurvePoint::identity();
1067                let mut modifier = CurvePoint::identity();
1068
1069                for mut chunk in args
1070                    .into_iter()
1071                    .map(CurvePoint::from)
1072                    .chunks(AUTHENTICATED_POINT_RESULT_LEN)
1073                    .into_iter()
1074                {
1075                    share += chunk.next().unwrap();
1076                    mac += chunk.next().unwrap();
1077                    modifier += chunk.next().unwrap();
1078                }
1079
1080                vec![
1081                    ResultValue::Point(share),
1082                    ResultValue::Point(mac),
1083                    ResultValue::Point(modifier),
1084                ]
1085            },
1086        );
1087
1088        AuthenticatedPointResult {
1089            share: results[0].clone().into(),
1090            mac: results[1].clone().into(),
1091            public_modifier: results[2].clone(),
1092        }
1093    }
1094
1095    /// Multiscalar multiplication on iterator types
1096    pub fn msm_iter<S, P>(scalars: S, points: P) -> AuthenticatedPointResult<C>
1097    where
1098        S: IntoIterator<Item = AuthenticatedScalarResult<C>>,
1099        P: IntoIterator<Item = AuthenticatedPointResult<C>>,
1100    {
1101        let scalars = scalars.into_iter().collect::<Vec<_>>();
1102        let points = points.into_iter().collect::<Vec<_>>();
1103
1104        Self::msm(&scalars, &points)
1105    }
1106}
1107
1108// ----------------
1109// | Test Helpers |
1110// ----------------
1111
/// Helpers that corrupt individual components of an authenticated point so
/// tests can exercise the secure-opening checks; these methods are not safe
/// to use outside of tests
#[cfg(feature = "test_helpers")]
pub mod test_helpers {
    use ark_ec::CurveGroup;

    use crate::algebra::curve::CurvePoint;

    use super::AuthenticatedPointResult;

    /// Corrupt the MAC of a given authenticated point
    ///
    /// The MAC is replaced by a freshly allocated result holding `new_mac`
    pub fn modify_mac<C: CurveGroup>(
        point: &mut AuthenticatedPointResult<C>,
        new_mac: CurvePoint<C>,
    ) {
        let corrupted = point.fabric().allocate_point(new_mac);
        point.mac = corrupted.into();
    }

    /// Corrupt the underlying secret share of a given authenticated point
    ///
    /// The share is replaced by a freshly allocated result holding
    /// `new_share`
    pub fn modify_share<C: CurveGroup>(
        point: &mut AuthenticatedPointResult<C>,
        new_share: CurvePoint<C>,
    ) {
        let corrupted = point.fabric().allocate_point(new_share);
        point.share = corrupted.into();
    }

    /// Corrupt the public modifier of a given authenticated point
    ///
    /// The modifier is replaced by a freshly allocated result holding
    /// `new_modifier`
    pub fn modify_public_modifier<C: CurveGroup>(
        point: &mut AuthenticatedPointResult<C>,
        new_modifier: CurvePoint<C>,
    ) {
        let corrupted = point.fabric().allocate_point(new_modifier);
        point.public_modifier = corrupted;
    }
}