mpc_stark/algebra/
authenticated_stark_point.rs

//! Defines a maliciously secure wrapper around an `MpcStarkPoint` type that includes a MAC
//! for ensuring the computational integrity of an opened point
3
4use std::{
5    fmt::{Debug, Formatter, Result as FmtResult},
6    iter::Sum,
7    ops::{Add, Mul, Neg, Sub},
8    pin::Pin,
9    task::{Context, Poll},
10};
11
12use futures::{Future, FutureExt};
13use itertools::{izip, Itertools};
14
15use crate::{
16    algebra::stark_curve::StarkPoint,
17    commitment::{HashCommitment, HashCommitmentResult},
18    error::MpcError,
19    fabric::{MpcFabric, ResultValue},
20    ResultId, PARTY0,
21};
22
23use super::{
24    authenticated_scalar::AuthenticatedScalarResult,
25    macros::{impl_borrow_variants, impl_commutative},
26    mpc_stark_point::MpcStarkPointResult,
27    scalar::{Scalar, ScalarResult},
28    stark_curve::{BatchStarkPointResult, StarkPointResult},
29};
30
31/// The number of underlying results in an `AuthenticatedStarkPointResult`
32pub(crate) const AUTHENTICATED_STARK_POINT_RESULT_LEN: usize = 3;
33
34/// A maliciously secure wrapper around `MpcStarkPoint` that includes a MAC as per
35/// the SPDZ protocol: https://eprint.iacr.org/2011/535.pdf
36#[derive(Clone)]
37pub struct AuthenticatedStarkPointResult {
38    /// The local secret share of the underlying authenticated point
39    pub(crate) share: MpcStarkPointResult,
40    /// A SPDZ style, unconditionally secure MAC of the value
41    /// This is used to ensure computational integrity of the opened value
42    /// See the doc comment in `AuthenticatedScalar` for more details
43    pub(crate) mac: MpcStarkPointResult,
44    /// The public modifier tracks additions and subtractions of public values to the shares
45    ///
46    /// Only the first party adds/subtracts public values to their share, but the other parties
47    /// must track this to validate the MAC when it is opened
48    pub(crate) public_modifier: StarkPointResult,
49}
50
51impl Debug for AuthenticatedStarkPointResult {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("AuthenticatedStarkPointResult")
54            .field("value", &self.share.id())
55            .field("mac", &self.mac.id())
56            .field("public_modifier", &self.public_modifier.id)
57            .finish()
58    }
59}
60
61impl AuthenticatedStarkPointResult {
62    /// Creates a new `AuthenticatedStarkPoint` from a given underlying point
63    pub fn new_shared(value: StarkPointResult) -> AuthenticatedStarkPointResult {
64        // Create an `MpcStarkPoint` from the value
65        let fabric_clone = value.fabric.clone();
66
67        let mpc_value = MpcStarkPointResult::new_shared(value);
68        let mac = fabric_clone.borrow_mac_key() * &mpc_value;
69
70        // Allocate a zero point for the public modifier
71        let public_modifier = fabric_clone.allocate_point(StarkPoint::identity());
72
73        Self {
74            share: mpc_value,
75            mac,
76            public_modifier,
77        }
78    }
79
80    /// Creates a batch of `AuthenticatedStarkPoint`s from a given batch of underlying points
81    pub fn new_shared_batch(values: &[StarkPointResult]) -> Vec<AuthenticatedStarkPointResult> {
82        if values.is_empty() {
83            return vec![];
84        }
85
86        // Create an `MpcStarkPoint` from the value
87        let n = values.len();
88        let fabric = values[0].fabric();
89        let mpc_values = values
90            .iter()
91            .map(|p| MpcStarkPointResult::new_shared(p.clone()))
92            .collect_vec();
93
94        let mac_keys = (0..n)
95            .map(|_| fabric.borrow_mac_key().clone())
96            .collect_vec();
97        let macs = MpcStarkPointResult::batch_mul(&mac_keys, &mpc_values);
98
99        mpc_values
100            .into_iter()
101            .zip(macs.into_iter())
102            .map(|(share, mac)| Self {
103                share,
104                mac,
105                public_modifier: fabric.curve_identity(),
106            })
107            .collect_vec()
108    }
109
110    /// Creates a batch of `AuthenticatedStarkPoint`s from a batch result
111    ///
112    /// The batch result combines the batch into one result, so it must be split out
113    /// first before creating the `AuthenticatedStarkPointResult`s
114    pub fn new_shared_from_batch_result(
115        values: BatchStarkPointResult,
116        n: usize,
117    ) -> Vec<AuthenticatedStarkPointResult> {
118        // Convert to a set of scalar results
119        let scalar_results = values
120            .fabric()
121            .new_batch_gate_op(vec![values.id()], n, |mut args| {
122                let args: Vec<StarkPoint> = args.pop().unwrap().into();
123                args.into_iter().map(ResultValue::Point).collect_vec()
124            });
125
126        Self::new_shared_batch(&scalar_results)
127    }
128
129    /// Get the ID of the underlying share's result
130    pub fn id(&self) -> ResultId {
131        self.share.id()
132    }
133
134    /// Get the IDs of the results that make up the `AuthenticatedStarkPointResult` representation
135    pub(crate) fn ids(&self) -> Vec<ResultId> {
136        vec![self.share.id(), self.mac.id(), self.public_modifier.id]
137    }
138
139    /// Borrow the fabric that this result is allocated in
140    pub fn fabric(&self) -> &MpcFabric {
141        self.share.fabric()
142    }
143
144    /// Get the underlying share as an `MpcStarkPoint`
145    #[cfg(feature = "test_helpers")]
146    pub fn mpc_share(&self) -> MpcStarkPointResult {
147        self.share.clone()
148    }
149
150    /// Open the value without checking the MAC
151    pub fn open(&self) -> StarkPointResult {
152        self.share.open()
153    }
154
155    /// Open a batch of values without checking the MAC
156    pub fn open_batch(values: &[Self]) -> Vec<StarkPointResult> {
157        MpcStarkPointResult::open_batch(&values.iter().map(|v| v.share.clone()).collect_vec())
158    }
159
160    /// Convert a flattened iterator into a batch of `AuthenticatedStarkPointResult`s
161    ///
162    /// We assume that the iterator has been flattened in the same way order that `Self::id`s returns
163    /// the `AuthenticatedScalar`'s values: `[share, mac, public_modifier]`
164    pub fn from_flattened_iterator<I>(iter: I) -> Vec<Self>
165    where
166        I: Iterator<Item = StarkPointResult>,
167    {
168        iter.chunks(AUTHENTICATED_STARK_POINT_RESULT_LEN)
169            .into_iter()
170            .map(|mut chunk| Self {
171                share: chunk.next().unwrap().into(),
172                mac: chunk.next().unwrap().into(),
173                public_modifier: chunk.next().unwrap(),
174            })
175            .collect_vec()
176    }
177
178    /// Verify the MAC check on an authenticated opening
179    fn verify_mac_check(
180        my_mac_share: StarkPoint,
181        peer_mac_share: StarkPoint,
182        peer_mac_commitment: Scalar,
183        peer_blinder: Scalar,
184    ) -> bool {
185        // Check that the MAC check value is the correct opening of the
186        // given commitment
187        let peer_comm = HashCommitment {
188            value: peer_mac_share,
189            blinder: peer_blinder,
190            commitment: peer_mac_commitment,
191        };
192        if !peer_comm.verify() {
193            return false;
194        }
195
196        // Check that the MAC check shares add up to the additive identity in
197        // the Starknet curve group
198        if my_mac_share + peer_mac_share != StarkPoint::identity() {
199            return false;
200        }
201
202        true
203    }
204
205    /// Open the value and check the MAC
206    ///
207    /// This follows the protocol detailed in
208    ///     https://securecomputation.org/docs/pragmaticmpc.pdf
209    pub fn open_authenticated(&self) -> AuthenticatedStarkPointOpenResult {
210        // Both parties open the underlying value
211        let recovered_value = self.share.open();
212
213        // Add a gate to compute hte MAC check value: `key_share * opened_value - mac_share`
214        let mac_check: StarkPointResult = self.fabric().new_gate_op(
215            vec![
216                self.fabric().borrow_mac_key().id(),
217                recovered_value.id(),
218                self.public_modifier.id(),
219                self.mac.id(),
220            ],
221            |mut args| {
222                let mac_key_share: Scalar = args.remove(0).into();
223                let value: StarkPoint = args.remove(0).into();
224                let modifier: StarkPoint = args.remove(0).into();
225                let mac_share: StarkPoint = args.remove(0).into();
226
227                ResultValue::Point((value + modifier) * mac_key_share - mac_share)
228            },
229        );
230
231        // Compute a commitment to this value and share it with the peer
232        let my_comm = HashCommitmentResult::commit(mac_check.clone());
233        let peer_commit = self.fabric().exchange_value(my_comm.commitment);
234
235        // Once the parties have exchanged their commitments, they can open the underlying MAC check value
236        // as they are bound by the commitment
237        let peer_mac_check = self.fabric().exchange_value(my_comm.value.clone());
238        let blinder_result: ScalarResult = self.fabric().allocate_scalar(my_comm.blinder);
239        let peer_blinder = self.fabric().exchange_value(blinder_result);
240
241        // Check the peer's commitment and the sum of the MAC checks
242        let commitment_check: ScalarResult = self.fabric().new_gate_op(
243            vec![
244                mac_check.id,
245                peer_mac_check.id,
246                peer_blinder.id,
247                peer_commit.id,
248            ],
249            move |mut args| {
250                let my_mac_check: StarkPoint = args.remove(0).into();
251                let peer_mac_check: StarkPoint = args.remove(0).into();
252                let peer_blinder: Scalar = args.remove(0).into();
253                let peer_commitment: Scalar = args.remove(0).into();
254
255                ResultValue::Scalar(Scalar::from(Self::verify_mac_check(
256                    my_mac_check,
257                    peer_mac_check,
258                    peer_commitment,
259                    peer_blinder,
260                )))
261            },
262        );
263
264        AuthenticatedStarkPointOpenResult {
265            value: recovered_value,
266            mac_check: commitment_check,
267        }
268    }
269
270    /// Open a batch of values and check the MACs
271    pub fn open_authenticated_batch(values: &[Self]) -> Vec<AuthenticatedStarkPointOpenResult> {
272        if values.is_empty() {
273            return Vec::new();
274        }
275
276        let n = values.len();
277        let fabric = values[0].fabric();
278
279        // Open the values
280        let opened_values = Self::open_batch(values);
281
282        // --- MAC Check --- //
283
284        // Compute the shares of the MAC check in batch
285        let mut mac_check_deps = Vec::with_capacity(1 + AUTHENTICATED_STARK_POINT_RESULT_LEN * n);
286        mac_check_deps.push(fabric.borrow_mac_key().id());
287        for i in 0..n {
288            mac_check_deps.push(opened_values[i].id());
289            mac_check_deps.push(values[i].public_modifier.id());
290            mac_check_deps.push(values[i].mac.id());
291        }
292
293        let mac_checks: Vec<StarkPointResult> =
294            fabric.new_batch_gate_op(mac_check_deps, n /* output_arity */, move |mut args| {
295                let mac_key_share: Scalar = args.remove(0).into();
296                let mut check_result = Vec::with_capacity(n);
297
298                for _ in 0..n {
299                    let value: StarkPoint = args.remove(0).into();
300                    let modifier: StarkPoint = args.remove(0).into();
301                    let mac_share: StarkPoint = args.remove(0).into();
302
303                    check_result.push(mac_key_share * (value + modifier) - mac_share);
304                }
305
306                check_result.into_iter().map(ResultValue::Point).collect()
307            });
308
309        // --- Commit to the MAC checks --- //
310
311        let my_comms = mac_checks
312            .iter()
313            .cloned()
314            .map(HashCommitmentResult::commit)
315            .collect_vec();
316        let peer_comms = fabric.exchange_values(
317            &my_comms
318                .iter()
319                .map(|comm| comm.commitment.clone())
320                .collect_vec(),
321        );
322
323        // --- Exchange the MAC Checks and Commitment Blinders --- //
324
325        let peer_mac_checks = fabric.exchange_values(&mac_checks);
326        let peer_blinders = fabric.exchange_values(
327            &my_comms
328                .iter()
329                .map(|comm| fabric.allocate_scalar(comm.blinder))
330                .collect_vec(),
331        );
332
333        // --- Check the MAC Checks --- //
334
335        let mut mac_check_gate_deps = my_comms.iter().map(|comm| comm.value.id).collect_vec();
336        mac_check_gate_deps.push(peer_mac_checks.id);
337        mac_check_gate_deps.push(peer_blinders.id);
338        mac_check_gate_deps.push(peer_comms.id);
339
340        let commitment_checks: Vec<ScalarResult> = fabric.new_batch_gate_op(
341            mac_check_gate_deps,
342            n, /* output_arity */
343            move |mut args| {
344                let my_comms: Vec<StarkPoint> = args.drain(..n).map(|comm| comm.into()).collect();
345                let peer_mac_checks: Vec<StarkPoint> = args.remove(0).into();
346                let peer_blinders: Vec<Scalar> = args.remove(0).into();
347                let peer_comms: Vec<Scalar> = args.remove(0).into();
348
349                // Build a commitment from the gate inputs
350                let mut mac_checks = Vec::with_capacity(n);
351                for (my_mac_share, peer_mac_share, peer_blinder, peer_commitment) in izip!(
352                    my_comms.into_iter(),
353                    peer_mac_checks.into_iter(),
354                    peer_blinders.into_iter(),
355                    peer_comms.into_iter()
356                ) {
357                    let mac_check = Self::verify_mac_check(
358                        my_mac_share,
359                        peer_mac_share,
360                        peer_commitment,
361                        peer_blinder,
362                    );
363                    mac_checks.push(ResultValue::Scalar(Scalar::from(mac_check)));
364                }
365
366                mac_checks
367            },
368        );
369
370        // --- Return the results --- //
371
372        opened_values
373            .into_iter()
374            .zip(commitment_checks.into_iter())
375            .map(|(value, check)| AuthenticatedStarkPointOpenResult {
376                value,
377                mac_check: check,
378            })
379            .collect_vec()
380    }
381}
382
383/// The value that results from opening an `AuthenticatedStarkPointResult` and checking its MAC. This encapsulates
384/// both the underlying value and the result of the MAC check
385#[derive(Clone)]
386pub struct AuthenticatedStarkPointOpenResult {
387    /// The underlying value
388    pub value: StarkPointResult,
389    /// The result of the MAC check
390    pub mac_check: ScalarResult,
391}
392
393impl Debug for AuthenticatedStarkPointOpenResult {
394    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
395        f.debug_struct("AuthenticatedStarkPointOpenResult")
396            .field("value", &self.value.id)
397            .field("mac_check", &self.mac_check.id)
398            .finish()
399    }
400}
401
402impl Future for AuthenticatedStarkPointOpenResult {
403    type Output = Result<StarkPoint, MpcError>;
404
405    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
406        // Await both of the underlying values
407        let value = futures::ready!(self.as_mut().value.poll_unpin(cx));
408        let mac_check = futures::ready!(self.as_mut().mac_check.poll_unpin(cx));
409
410        if mac_check == Scalar::from(1) {
411            Poll::Ready(Ok(value))
412        } else {
413            Poll::Ready(Err(MpcError::AuthenticationError))
414        }
415    }
416}
417
418impl Sum for AuthenticatedStarkPointResult {
419    // Assumes the iterator is non-empty
420    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
421        let first = iter
422            .next()
423            .expect("AuthenticatedStarkPointResult::sum requires a non-empty iterator");
424        iter.fold(first, |acc, x| acc + x)
425    }
426}
427
428// --------------
429// | Arithmetic |
430// --------------
431
432// === Addition === //
433
434impl Add<&StarkPoint> for &AuthenticatedStarkPointResult {
435    type Output = AuthenticatedStarkPointResult;
436
437    fn add(self, other: &StarkPoint) -> AuthenticatedStarkPointResult {
438        let new_share = if self.fabric().party_id() == PARTY0 {
439            // Party zero adds the public value to their share
440            &self.share + other
441        } else {
442            // Other parties just add the identity to the value to allocate a new op and keep
443            // in sync with party 0
444            &self.share + StarkPoint::identity()
445        };
446
447        // Add the public value to the MAC
448        let new_modifier = &self.public_modifier - other;
449        AuthenticatedStarkPointResult {
450            share: new_share,
451            mac: self.mac.clone(),
452            public_modifier: new_modifier,
453        }
454    }
455}
456impl_borrow_variants!(AuthenticatedStarkPointResult, Add, add, +, StarkPoint);
457impl_commutative!(AuthenticatedStarkPointResult, Add, add, +, StarkPoint);
458
459impl Add<&StarkPointResult> for &AuthenticatedStarkPointResult {
460    type Output = AuthenticatedStarkPointResult;
461
462    fn add(self, other: &StarkPointResult) -> AuthenticatedStarkPointResult {
463        let new_share = if self.fabric().party_id() == PARTY0 {
464            // Party zero adds the public value to their share
465            &self.share + other
466        } else {
467            // Other parties just add the identity to the value to allocate a new op and keep
468            // in sync with party 0
469            &self.share + StarkPoint::identity()
470        };
471
472        // Add the public value to the MAC
473        let new_modifier = &self.public_modifier - other;
474        AuthenticatedStarkPointResult {
475            share: new_share,
476            mac: self.mac.clone(),
477            public_modifier: new_modifier,
478        }
479    }
480}
481impl_borrow_variants!(AuthenticatedStarkPointResult, Add, add, +, StarkPointResult);
482impl_commutative!(AuthenticatedStarkPointResult, Add, add, +, StarkPointResult);
483
484impl Add<&AuthenticatedStarkPointResult> for &AuthenticatedStarkPointResult {
485    type Output = AuthenticatedStarkPointResult;
486
487    fn add(self, other: &AuthenticatedStarkPointResult) -> AuthenticatedStarkPointResult {
488        let new_share = &self.share + &other.share;
489
490        // Add the public value to the MAC
491        let new_mac = &self.mac + &other.mac;
492        AuthenticatedStarkPointResult {
493            share: new_share,
494            mac: new_mac,
495            public_modifier: self.public_modifier.clone() + other.public_modifier.clone(),
496        }
497    }
498}
499impl_borrow_variants!(AuthenticatedStarkPointResult, Add, add, +, AuthenticatedStarkPointResult);
500
501impl AuthenticatedStarkPointResult {
502    /// Add two batches of `AuthenticatedStarkPointResult`s
503    pub fn batch_add(
504        a: &[AuthenticatedStarkPointResult],
505        b: &[AuthenticatedStarkPointResult],
506    ) -> Vec<AuthenticatedStarkPointResult> {
507        assert_eq!(a.len(), b.len(), "batch_add requires equal length vectors");
508        if a.is_empty() {
509            return Vec::new();
510        }
511
512        let n = a.len();
513        let fabric = a[0].fabric();
514        let all_ids = a.iter().chain(b.iter()).flat_map(|p| p.ids()).collect_vec();
515
516        let res: Vec<StarkPointResult> = fabric.new_batch_gate_op(
517            all_ids,
518            AUTHENTICATED_STARK_POINT_RESULT_LEN * n,
519            move |mut args| {
520                let len = args.len();
521                let a_vals = args.drain(..len / 2).collect_vec();
522                let b_vals = args;
523
524                let mut result = Vec::with_capacity(AUTHENTICATED_STARK_POINT_RESULT_LEN * n);
525                for (a_chunk, b_chunk) in a_vals
526                    .chunks(AUTHENTICATED_STARK_POINT_RESULT_LEN)
527                    .zip(b_vals.chunks(AUTHENTICATED_STARK_POINT_RESULT_LEN))
528                {
529                    let a_share: StarkPoint = a_chunk[0].clone().into();
530                    let a_mac: StarkPoint = a_chunk[1].clone().into();
531                    let a_modifier: StarkPoint = a_chunk[2].clone().into();
532
533                    let b_share: StarkPoint = b_chunk[0].clone().into();
534                    let b_mac: StarkPoint = b_chunk[1].clone().into();
535                    let b_modifier: StarkPoint = b_chunk[2].clone().into();
536
537                    result.push(ResultValue::Point(a_share + b_share));
538                    result.push(ResultValue::Point(a_mac + b_mac));
539                    result.push(ResultValue::Point(a_modifier + b_modifier));
540                }
541
542                result
543            },
544        );
545
546        Self::from_flattened_iterator(res.into_iter())
547    }
548
549    /// Add a batch of `AuthenticatedStarkPointResult`s to a batch of `StarkPointResult`s
550    pub fn batch_add_public(
551        a: &[AuthenticatedStarkPointResult],
552        b: &[StarkPointResult],
553    ) -> Vec<AuthenticatedStarkPointResult> {
554        assert_eq!(
555            a.len(),
556            b.len(),
557            "batch_add_public requires equal length vectors"
558        );
559        if a.is_empty() {
560            return Vec::new();
561        }
562
563        let n = a.len();
564        let fabric = a[0].fabric();
565        let all_ids = a
566            .iter()
567            .flat_map(|a| a.ids())
568            .chain(b.iter().map(|b| b.id()))
569            .collect_vec();
570
571        let party_id = fabric.party_id();
572        let res: Vec<StarkPointResult> = fabric.new_batch_gate_op(
573            all_ids,
574            AUTHENTICATED_STARK_POINT_RESULT_LEN * n,
575            move |mut args| {
576                let a_vals = args
577                    .drain(..AUTHENTICATED_STARK_POINT_RESULT_LEN * n)
578                    .collect_vec();
579                let b_vals = args;
580
581                let mut result = Vec::with_capacity(AUTHENTICATED_STARK_POINT_RESULT_LEN * n);
582                for (a_chunk, b_val) in a_vals
583                    .chunks(AUTHENTICATED_STARK_POINT_RESULT_LEN)
584                    .zip(b_vals.into_iter())
585                {
586                    let a_share: StarkPoint = a_chunk[0].clone().into();
587                    let a_mac: StarkPoint = a_chunk[1].clone().into();
588                    let a_modifier: StarkPoint = a_chunk[2].clone().into();
589
590                    let public_value: StarkPoint = b_val.into();
591
592                    // Only the first party adds the public value to their share
593                    if party_id == PARTY0 {
594                        result.push(ResultValue::Point(a_share + public_value));
595                    } else {
596                        result.push(ResultValue::Point(a_share))
597                    }
598
599                    result.push(ResultValue::Point(a_mac));
600                    result.push(ResultValue::Point(a_modifier - public_value));
601                }
602
603                result
604            },
605        );
606
607        Self::from_flattened_iterator(res.into_iter())
608    }
609}
610
611// === Subtraction === //
612
613impl Sub<&StarkPoint> for &AuthenticatedStarkPointResult {
614    type Output = AuthenticatedStarkPointResult;
615
616    fn sub(self, other: &StarkPoint) -> AuthenticatedStarkPointResult {
617        let new_share = if self.fabric().party_id() == PARTY0 {
618            // Party zero subtracts the public value from their share
619            &self.share - other
620        } else {
621            // Other parties just subtract the identity from the value to allocate a new op and keep
622            // in sync with party 0
623            &self.share - StarkPoint::identity()
624        };
625
626        // Subtract the public value from the MAC
627        let new_modifier = &self.public_modifier + other;
628        AuthenticatedStarkPointResult {
629            share: new_share,
630            mac: self.mac.clone(),
631            public_modifier: new_modifier,
632        }
633    }
634}
635impl_borrow_variants!(AuthenticatedStarkPointResult, Sub, sub, -, StarkPoint);
636impl_commutative!(AuthenticatedStarkPointResult, Sub, sub, -, StarkPoint);
637
638impl Sub<&StarkPointResult> for &AuthenticatedStarkPointResult {
639    type Output = AuthenticatedStarkPointResult;
640
641    fn sub(self, other: &StarkPointResult) -> AuthenticatedStarkPointResult {
642        let new_share = if self.fabric().party_id() == PARTY0 {
643            // Party zero subtracts the public value from their share
644            &self.share - other
645        } else {
646            // Other parties just subtract the identity from the value to allocate a new op and keep
647            // in sync with party 0
648            &self.share - StarkPoint::identity()
649        };
650
651        // Subtract the public value from the MAC
652        let new_modifier = &self.public_modifier + other;
653        AuthenticatedStarkPointResult {
654            share: new_share,
655            mac: self.mac.clone(),
656            public_modifier: new_modifier,
657        }
658    }
659}
660impl_borrow_variants!(AuthenticatedStarkPointResult, Sub, sub, -, StarkPointResult);
661impl_commutative!(AuthenticatedStarkPointResult, Sub, sub, -, StarkPointResult);
662
663impl Sub<&AuthenticatedStarkPointResult> for &AuthenticatedStarkPointResult {
664    type Output = AuthenticatedStarkPointResult;
665
666    fn sub(self, other: &AuthenticatedStarkPointResult) -> AuthenticatedStarkPointResult {
667        let new_share = &self.share - &other.share;
668
669        // Subtract the public value from the MAC
670        let new_mac = &self.mac - &other.mac;
671        AuthenticatedStarkPointResult {
672            share: new_share,
673            mac: new_mac,
674            public_modifier: self.public_modifier.clone(),
675        }
676    }
677}
678impl_borrow_variants!(AuthenticatedStarkPointResult, Sub, sub, -, AuthenticatedStarkPointResult);
679
680impl AuthenticatedStarkPointResult {
681    /// Add two batches of `AuthenticatedStarkPointResult`s
682    pub fn batch_sub(
683        a: &[AuthenticatedStarkPointResult],
684        b: &[AuthenticatedStarkPointResult],
685    ) -> Vec<AuthenticatedStarkPointResult> {
686        assert_eq!(a.len(), b.len(), "batch_add requires equal length vectors");
687        if a.is_empty() {
688            return Vec::new();
689        }
690
691        let n = a.len();
692        let fabric = a[0].fabric();
693        let all_ids = a.iter().chain(b.iter()).flat_map(|p| p.ids()).collect_vec();
694
695        let res: Vec<StarkPointResult> = fabric.new_batch_gate_op(
696            all_ids,
697            AUTHENTICATED_STARK_POINT_RESULT_LEN * n,
698            move |mut args| {
699                let len = args.len();
700                let a_vals = args.drain(..len / 2).collect_vec();
701                let b_vals = args;
702
703                let mut result = Vec::with_capacity(AUTHENTICATED_STARK_POINT_RESULT_LEN * n);
704                for (a_chunk, b_chunk) in a_vals
705                    .chunks(AUTHENTICATED_STARK_POINT_RESULT_LEN)
706                    .zip(b_vals.chunks(AUTHENTICATED_STARK_POINT_RESULT_LEN))
707                {
708                    let a_share: StarkPoint = a_chunk[0].clone().into();
709                    let a_mac: StarkPoint = a_chunk[1].clone().into();
710                    let a_modifier: StarkPoint = a_chunk[2].clone().into();
711
712                    let b_share: StarkPoint = b_chunk[0].clone().into();
713                    let b_mac: StarkPoint = b_chunk[1].clone().into();
714                    let b_modifier: StarkPoint = b_chunk[2].clone().into();
715
716                    result.push(ResultValue::Point(a_share - b_share));
717                    result.push(ResultValue::Point(a_mac - b_mac));
718                    result.push(ResultValue::Point(a_modifier - b_modifier));
719                }
720
721                result
722            },
723        );
724
725        Self::from_flattened_iterator(res.into_iter())
726    }
727
728    /// Subtract a batch of `AuthenticatedStarkPointResult`s to a batch of `StarkPointResult`s
729    pub fn batch_sub_public(
730        a: &[AuthenticatedStarkPointResult],
731        b: &[StarkPointResult],
732    ) -> Vec<AuthenticatedStarkPointResult> {
733        assert_eq!(
734            a.len(),
735            b.len(),
736            "batch_add_public requires equal length vectors"
737        );
738        if a.is_empty() {
739            return Vec::new();
740        }
741
742        let n = a.len();
743        let fabric = a[0].fabric();
744        let all_ids = a
745            .iter()
746            .flat_map(|a| a.ids())
747            .chain(b.iter().map(|b| b.id()))
748            .collect_vec();
749
750        let party_id = fabric.party_id();
751        let res: Vec<StarkPointResult> = fabric.new_batch_gate_op(
752            all_ids,
753            AUTHENTICATED_STARK_POINT_RESULT_LEN * n,
754            move |mut args| {
755                let a_vals = args
756                    .drain(..AUTHENTICATED_STARK_POINT_RESULT_LEN * n)
757                    .collect_vec();
758                let b_vals = args;
759
760                let mut result = Vec::with_capacity(AUTHENTICATED_STARK_POINT_RESULT_LEN * n);
761                for (a_chunk, b_val) in a_vals
762                    .chunks(AUTHENTICATED_STARK_POINT_RESULT_LEN)
763                    .zip(b_vals.into_iter())
764                {
765                    let a_share: StarkPoint = a_chunk[0].clone().into();
766                    let a_mac: StarkPoint = a_chunk[1].clone().into();
767                    let a_modifier: StarkPoint = a_chunk[2].clone().into();
768
769                    let b_share: StarkPoint = b_val.into();
770
771                    // Only the first party adds the public value to their share
772                    if party_id == PARTY0 {
773                        result.push(ResultValue::Point(a_share - b_share));
774                    } else {
775                        result.push(ResultValue::Point(a_share))
776                    }
777
778                    result.push(ResultValue::Point(a_mac));
779                    result.push(ResultValue::Point(a_modifier + b_share));
780                }
781
782                result
783            },
784        );
785
786        Self::from_flattened_iterator(res.into_iter())
787    }
788}
789
790// === Negation == //
791
792impl Neg for &AuthenticatedStarkPointResult {
793    type Output = AuthenticatedStarkPointResult;
794
795    fn neg(self) -> AuthenticatedStarkPointResult {
796        let new_share = -&self.share;
797
798        // Negate the public value in the MAC
799        let new_mac = -&self.mac;
800        AuthenticatedStarkPointResult {
801            share: new_share,
802            mac: new_mac,
803            public_modifier: self.public_modifier.clone(),
804        }
805    }
806}
807impl_borrow_variants!(AuthenticatedStarkPointResult, Neg, neg, -);
808
809impl AuthenticatedStarkPointResult {
810    /// Negate a batch of `AuthenticatedStarkPointResult`s
811    pub fn batch_neg(a: &[AuthenticatedStarkPointResult]) -> Vec<AuthenticatedStarkPointResult> {
812        if a.is_empty() {
813            return Vec::new();
814        }
815
816        let n = a.len();
817        let fabric = a[0].fabric();
818        let all_ids = a.iter().flat_map(|p| p.ids()).collect_vec();
819
820        let res: Vec<StarkPointResult> = fabric.new_batch_gate_op(
821            all_ids,
822            AUTHENTICATED_STARK_POINT_RESULT_LEN * n,
823            move |args| {
824                args.into_iter()
825                    .map(StarkPoint::from)
826                    .map(StarkPoint::neg)
827                    .map(ResultValue::Point)
828                    .collect_vec()
829            },
830        );
831
832        Self::from_flattened_iterator(res.into_iter())
833    }
834}
835
836// === Scalar Multiplication === //
837
838impl Mul<&Scalar> for &AuthenticatedStarkPointResult {
839    type Output = AuthenticatedStarkPointResult;
840
841    fn mul(self, other: &Scalar) -> AuthenticatedStarkPointResult {
842        let new_share = &self.share * other;
843
844        // Multiply the public value in the MAC
845        let new_mac = &self.mac * other;
846        let new_modifier = &self.public_modifier * other;
847        AuthenticatedStarkPointResult {
848            share: new_share,
849            mac: new_mac,
850            public_modifier: new_modifier,
851        }
852    }
853}
854impl_borrow_variants!(AuthenticatedStarkPointResult, Mul, mul, *, Scalar);
855impl_commutative!(AuthenticatedStarkPointResult, Mul, mul, *, Scalar);
856
857impl Mul<&ScalarResult> for &AuthenticatedStarkPointResult {
858    type Output = AuthenticatedStarkPointResult;
859
860    fn mul(self, other: &ScalarResult) -> AuthenticatedStarkPointResult {
861        let new_share = &self.share * other;
862
863        // Multiply the public value in the MAC
864        let new_mac = &self.mac * other;
865        let new_modifier = &self.public_modifier * other;
866        AuthenticatedStarkPointResult {
867            share: new_share,
868            mac: new_mac,
869            public_modifier: new_modifier,
870        }
871    }
872}
873impl_borrow_variants!(AuthenticatedStarkPointResult, Mul, mul, *, ScalarResult);
874impl_commutative!(AuthenticatedStarkPointResult, Mul, mul, *, ScalarResult);
875
876impl Mul<&AuthenticatedScalarResult> for &AuthenticatedStarkPointResult {
877    type Output = AuthenticatedStarkPointResult;
878
879    // Beaver trick
880    fn mul(self, rhs: &AuthenticatedScalarResult) -> AuthenticatedStarkPointResult {
881        // Sample a beaver triple
882        let generator = StarkPoint::generator();
883        let (a, b, c) = self.fabric().next_authenticated_triple();
884
885        // Open the values d = [rhs - a] and e = [lhs - bG] for curve group generator G
886        let masked_rhs = rhs - &a;
887        let masked_lhs = self - (&generator * &b);
888
889        #[allow(non_snake_case)]
890        let eG_open = masked_lhs.open();
891        let d_open = masked_rhs.open();
892
893        // Identity [x * yG] = deG + d[bG] + [a]eG + [c]G
894        &d_open * &eG_open + &d_open * &(&generator * &b) + &a * eG_open + &c * generator
895    }
896}
897impl_borrow_variants!(AuthenticatedStarkPointResult, Mul, mul, *, AuthenticatedScalarResult);
898impl_commutative!(AuthenticatedStarkPointResult, Mul, mul, *, AuthenticatedScalarResult);
899
900impl AuthenticatedStarkPointResult {
901    /// Multiply a batch of `AuthenticatedStarkPointResult`s by a batch of `AuthenticatedScalarResult`s
902    #[allow(non_snake_case)]
903    pub fn batch_mul(
904        a: &[AuthenticatedScalarResult],
905        b: &[AuthenticatedStarkPointResult],
906    ) -> Vec<AuthenticatedStarkPointResult> {
907        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
908        if a.is_empty() {
909            return Vec::new();
910        }
911
912        let n = a.len();
913        let fabric = a[0].fabric();
914
915        // Sample a set of beaver triples for the multiplications
916        let (beaver_a, beaver_b, beaver_c) = fabric.next_authenticated_triple_batch(n);
917        let beaver_b_gen = AuthenticatedStarkPointResult::batch_mul_generator(&beaver_b);
918
919        let masked_rhs = AuthenticatedScalarResult::batch_sub(a, &beaver_a);
920        let masked_lhs = AuthenticatedStarkPointResult::batch_sub(b, &beaver_b_gen);
921
922        let eG_open = AuthenticatedStarkPointResult::open_batch(&masked_lhs);
923        let d_open = AuthenticatedScalarResult::open_batch(&masked_rhs);
924
925        // Identity [x * yG] = deG + d[bG] + [a]eG + [c]G
926        let deG = StarkPointResult::batch_mul(&d_open, &eG_open);
927        let dbG = AuthenticatedStarkPointResult::batch_mul_public(&d_open, &beaver_b_gen);
928        let aeG = StarkPointResult::batch_mul_authenticated(&beaver_a, &eG_open);
929        let cG = AuthenticatedStarkPointResult::batch_mul_generator(&beaver_c);
930
931        let de_db_G = AuthenticatedStarkPointResult::batch_add_public(&dbG, &deG);
932        let ae_c_G = AuthenticatedStarkPointResult::batch_add(&aeG, &cG);
933
934        AuthenticatedStarkPointResult::batch_add(&de_db_G, &ae_c_G)
935    }
936
937    /// Multiply a batch of `AuthenticatedStarkPointResult`s by a batch of `ScalarResult`s
938    pub fn batch_mul_public(
939        a: &[ScalarResult],
940        b: &[AuthenticatedStarkPointResult],
941    ) -> Vec<AuthenticatedStarkPointResult> {
942        assert_eq!(
943            a.len(),
944            b.len(),
945            "batch_mul_public requires equal length vectors"
946        );
947        if a.is_empty() {
948            return Vec::new();
949        }
950
951        let n = a.len();
952        let fabric = a[0].fabric();
953        let all_ids = a
954            .iter()
955            .map(|a| a.id())
956            .chain(b.iter().flat_map(|p| p.ids()))
957            .collect_vec();
958
959        let results: Vec<StarkPointResult> = fabric.new_batch_gate_op(
960            all_ids,
961            AUTHENTICATED_STARK_POINT_RESULT_LEN * n, /* output_arity */
962            move |mut args| {
963                let scalars: Vec<Scalar> = args.drain(..n).map(Scalar::from).collect_vec();
964                let points: Vec<StarkPoint> = args.into_iter().map(StarkPoint::from).collect_vec();
965
966                let mut result = Vec::with_capacity(AUTHENTICATED_STARK_POINT_RESULT_LEN * n);
967                for (scalar, points) in scalars
968                    .into_iter()
969                    .zip(points.chunks(AUTHENTICATED_STARK_POINT_RESULT_LEN))
970                {
971                    let share: StarkPoint = points[0];
972                    let mac: StarkPoint = points[1];
973                    let modifier: StarkPoint = points[2];
974
975                    result.push(ResultValue::Point(share * scalar));
976                    result.push(ResultValue::Point(mac * scalar));
977                    result.push(ResultValue::Point(modifier * scalar));
978                }
979
980                result
981            },
982        );
983
984        Self::from_flattened_iterator(results.into_iter())
985    }
986
987    /// Multiply a batch of scalars by the generator
988    pub fn batch_mul_generator(
989        a: &[AuthenticatedScalarResult],
990    ) -> Vec<AuthenticatedStarkPointResult> {
991        if a.is_empty() {
992            return Vec::new();
993        }
994
995        let n = a.len();
996        let fabric = a[0].fabric();
997        let all_ids = a.iter().flat_map(|v| v.ids()).collect_vec();
998
999        // Multiply the shares in a batch gate
1000        let results = fabric.new_batch_gate_op(
1001            all_ids,
1002            AUTHENTICATED_STARK_POINT_RESULT_LEN * n, /* output_arity */
1003            move |args| {
1004                let scalars = args.into_iter().map(Scalar::from).collect_vec();
1005                let generator = StarkPoint::generator();
1006
1007                scalars
1008                    .into_iter()
1009                    .map(|x| x * generator)
1010                    .map(ResultValue::Point)
1011                    .collect_vec()
1012            },
1013        );
1014
1015        Self::from_flattened_iterator(results.into_iter())
1016    }
1017}
1018
1019// === Multiscalar Multiplication === //
1020
1021impl AuthenticatedStarkPointResult {
1022    /// Multiscalar multiplication
1023    ///
1024    /// TODO: Maybe make use of a fast MSM operation under the hood once the blinded points are revealed
1025    pub fn msm(
1026        scalars: &[AuthenticatedScalarResult],
1027        points: &[AuthenticatedStarkPointResult],
1028    ) -> AuthenticatedStarkPointResult {
1029        assert_eq!(
1030            scalars.len(),
1031            points.len(),
1032            "multiscalar_mul requires equal length vectors"
1033        );
1034        assert!(
1035            !scalars.is_empty(),
1036            "multiscalar_mul requires non-empty vectors"
1037        );
1038
1039        let mul_out = AuthenticatedStarkPointResult::batch_mul(scalars, points);
1040
1041        // Create a gate to sum the points
1042        let fabric = scalars[0].fabric();
1043        let all_ids = mul_out.iter().flat_map(|p| p.ids()).collect_vec();
1044
1045        let results = fabric.new_batch_gate_op(
1046            all_ids,
1047            AUTHENTICATED_STARK_POINT_RESULT_LEN, /* output_arity */
1048            move |args| {
1049                // Accumulators
1050                let mut share = StarkPoint::identity();
1051                let mut mac = StarkPoint::identity();
1052                let mut modifier = StarkPoint::identity();
1053
1054                for mut chunk in args
1055                    .into_iter()
1056                    .map(StarkPoint::from)
1057                    .chunks(AUTHENTICATED_STARK_POINT_RESULT_LEN)
1058                    .into_iter()
1059                {
1060                    share += chunk.next().unwrap();
1061                    mac += chunk.next().unwrap();
1062                    modifier += chunk.next().unwrap();
1063                }
1064
1065                vec![
1066                    ResultValue::Point(share),
1067                    ResultValue::Point(mac),
1068                    ResultValue::Point(modifier),
1069                ]
1070            },
1071        );
1072
1073        AuthenticatedStarkPointResult {
1074            share: results[0].clone().into(),
1075            mac: results[1].clone().into(),
1076            public_modifier: results[2].clone(),
1077        }
1078    }
1079
1080    /// Multiscalar multiplication on iterator types
1081    pub fn msm_iter<S, P>(scalars: S, points: P) -> AuthenticatedStarkPointResult
1082    where
1083        S: IntoIterator<Item = AuthenticatedScalarResult>,
1084        P: IntoIterator<Item = AuthenticatedStarkPointResult>,
1085    {
1086        let scalars = scalars.into_iter().collect::<Vec<_>>();
1087        let points = points.into_iter().collect::<Vec<_>>();
1088
1089        Self::msm(&scalars, &points)
1090    }
1091}
1092
1093// ----------------
1094// | Test Helpers |
1095// ----------------
1096
/// Helpers for exercising the secure-opening checks in tests; these deliberately corrupt
/// authenticated values and must never be used outside of tests
#[cfg(feature = "test_helpers")]
pub mod test_helpers {
    use crate::algebra::stark_curve::StarkPoint;

    use super::AuthenticatedStarkPointResult;

    /// Overwrite the MAC of `point` with a freshly allocated `new_mac`
    pub fn modify_mac(point: &mut AuthenticatedStarkPointResult, new_mac: StarkPoint) {
        let corrupted = point.fabric().allocate_point(new_mac);
        point.mac = corrupted.into();
    }

    /// Overwrite the local secret share of `point` with a freshly allocated `new_share`
    pub fn modify_share(point: &mut AuthenticatedStarkPointResult, new_share: StarkPoint) {
        let corrupted = point.fabric().allocate_point(new_share);
        point.share = corrupted.into();
    }

    /// Overwrite the public modifier of `point` with a freshly allocated `new_modifier`
    pub fn modify_public_modifier(
        point: &mut AuthenticatedStarkPointResult,
        new_modifier: StarkPoint,
    ) {
        let corrupted = point.fabric().allocate_point(new_modifier);
        point.public_modifier = corrupted;
    }
}