1use std::{
5 fmt::{Debug, Formatter, Result as FmtResult},
6 iter::Sum,
7 ops::{Add, Mul, Neg, Sub},
8 pin::Pin,
9 task::{Context, Poll},
10};
11
12use ark_ec::CurveGroup;
13use futures::{Future, FutureExt};
14use itertools::{izip, Itertools};
15
16use crate::{
17 algebra::macros::*,
18 algebra::scalar::*,
19 commitment::{HashCommitment, HashCommitmentResult},
20 error::MpcError,
21 fabric::{MpcFabric, ResultValue},
22 ResultId, PARTY0,
23};
24
25use super::{
26 curve::{BatchCurvePointResult, CurvePoint, CurvePointResult},
27 mpc_curve::MpcPointResult,
28};
29
30pub(crate) const AUTHENTICATED_POINT_RESULT_LEN: usize = 3;
32
33#[derive(Clone)]
36pub struct AuthenticatedPointResult<C: CurveGroup> {
37 pub(crate) share: MpcPointResult<C>,
39 pub(crate) mac: MpcPointResult<C>,
44 pub(crate) public_modifier: CurvePointResult<C>,
51}
52
53impl<C: CurveGroup> Debug for AuthenticatedPointResult<C> {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.debug_struct("AuthenticatedPointResult")
56 .field("value", &self.share.id())
57 .field("mac", &self.mac.id())
58 .field("public_modifier", &self.public_modifier.id)
59 .finish()
60 }
61}
62
63impl<C: CurveGroup> AuthenticatedPointResult<C> {
64 pub fn new_shared(value: CurvePointResult<C>) -> AuthenticatedPointResult<C> {
66 let fabric_clone = value.fabric.clone();
68
69 let mpc_value = MpcPointResult::new_shared(value);
70 let mac = fabric_clone.borrow_mac_key() * &mpc_value;
71
72 let public_modifier = fabric_clone.allocate_point(CurvePoint::identity());
74
75 Self {
76 share: mpc_value,
77 mac,
78 public_modifier,
79 }
80 }
81
82 pub fn new_shared_batch(values: &[CurvePointResult<C>]) -> Vec<AuthenticatedPointResult<C>> {
85 if values.is_empty() {
86 return vec![];
87 }
88
89 let n = values.len();
91 let fabric = values[0].fabric();
92 let mpc_values = values
93 .iter()
94 .map(|p| MpcPointResult::new_shared(p.clone()))
95 .collect_vec();
96
97 let mac_keys = (0..n)
98 .map(|_| fabric.borrow_mac_key().clone())
99 .collect_vec();
100 let macs = MpcPointResult::batch_mul(&mac_keys, &mpc_values);
101
102 mpc_values
103 .into_iter()
104 .zip(macs)
105 .map(|(share, mac)| Self {
106 share,
107 mac,
108 public_modifier: fabric.curve_identity(),
109 })
110 .collect_vec()
111 }
112
113 pub fn new_shared_from_batch_result(
118 values: BatchCurvePointResult<C>,
119 n: usize,
120 ) -> Vec<AuthenticatedPointResult<C>> {
121 let scalar_results: Vec<CurvePointResult<C>> =
123 values
124 .fabric()
125 .new_batch_gate_op(vec![values.id()], n, |mut args| {
126 let args: Vec<CurvePoint<C>> = args.pop().unwrap().into();
127 args.into_iter().map(ResultValue::Point).collect_vec()
128 });
129
130 Self::new_shared_batch(&scalar_results)
131 }
132
133 pub fn id(&self) -> ResultId {
135 self.share.id()
136 }
137
138 pub(crate) fn ids(&self) -> Vec<ResultId> {
141 vec![self.share.id(), self.mac.id(), self.public_modifier.id]
142 }
143
144 pub fn fabric(&self) -> &MpcFabric<C> {
146 self.share.fabric()
147 }
148
149 #[cfg(feature = "test_helpers")]
151 pub fn mpc_share(&self) -> MpcPointResult<C> {
152 self.share.clone()
153 }
154
155 pub fn open(&self) -> CurvePointResult<C> {
157 self.share.open()
158 }
159
160 pub fn open_batch(values: &[Self]) -> Vec<CurvePointResult<C>> {
162 MpcPointResult::open_batch(&values.iter().map(|v| v.share.clone()).collect_vec())
163 }
164
165 pub fn from_flattened_iterator<I>(iter: I) -> Vec<Self>
171 where
172 I: Iterator<Item = CurvePointResult<C>>,
173 {
174 iter.chunks(AUTHENTICATED_POINT_RESULT_LEN)
175 .into_iter()
176 .map(|mut chunk| Self {
177 share: chunk.next().unwrap().into(),
178 mac: chunk.next().unwrap().into(),
179 public_modifier: chunk.next().unwrap(),
180 })
181 .collect_vec()
182 }
183
184 fn verify_mac_check(
186 my_mac_share: CurvePoint<C>,
187 peer_mac_share: CurvePoint<C>,
188 peer_mac_commitment: Scalar<C>,
189 peer_blinder: Scalar<C>,
190 ) -> bool {
191 let peer_comm = HashCommitment {
194 value: peer_mac_share,
195 blinder: peer_blinder,
196 commitment: peer_mac_commitment,
197 };
198 if !peer_comm.verify() {
199 return false;
200 }
201
202 if my_mac_share + peer_mac_share != CurvePoint::identity() {
205 return false;
206 }
207
208 true
209 }
210
211 pub fn open_authenticated(&self) -> AuthenticatedPointOpenResult<C> {
216 let recovered_value = self.share.open();
218
219 let mac_check: CurvePointResult<C> = self.fabric().new_gate_op(
222 vec![
223 self.fabric().borrow_mac_key().id(),
224 recovered_value.id(),
225 self.public_modifier.id(),
226 self.mac.id(),
227 ],
228 |mut args| {
229 let mac_key_share: Scalar<C> = args.remove(0).into();
230 let value: CurvePoint<C> = args.remove(0).into();
231 let modifier: CurvePoint<C> = args.remove(0).into();
232 let mac_share: CurvePoint<C> = args.remove(0).into();
233
234 ResultValue::Point((value + modifier) * mac_key_share - mac_share)
235 },
236 );
237
238 let my_comm = HashCommitmentResult::commit(mac_check.clone());
240 let peer_commit = self.fabric().exchange_value(my_comm.commitment);
241
242 let peer_mac_check = self.fabric().exchange_value(my_comm.value.clone());
245 let blinder_result: ScalarResult<C> = self.fabric().allocate_scalar(my_comm.blinder);
246 let peer_blinder = self.fabric().exchange_value(blinder_result);
247
248 let commitment_check: ScalarResult<C> = self.fabric().new_gate_op(
250 vec![
251 mac_check.id,
252 peer_mac_check.id,
253 peer_blinder.id,
254 peer_commit.id,
255 ],
256 move |mut args| {
257 let my_mac_check: CurvePoint<C> = args.remove(0).into();
258 let peer_mac_check: CurvePoint<C> = args.remove(0).into();
259 let peer_blinder: Scalar<C> = args.remove(0).into();
260 let peer_commitment: Scalar<C> = args.remove(0).into();
261
262 ResultValue::Scalar(Scalar::from(Self::verify_mac_check(
263 my_mac_check,
264 peer_mac_check,
265 peer_commitment,
266 peer_blinder,
267 )))
268 },
269 );
270
271 AuthenticatedPointOpenResult {
272 value: recovered_value,
273 mac_check: commitment_check,
274 }
275 }
276
277 pub fn open_authenticated_batch(values: &[Self]) -> Vec<AuthenticatedPointOpenResult<C>> {
279 if values.is_empty() {
280 return Vec::new();
281 }
282
283 let n = values.len();
284 let fabric = values[0].fabric();
285
286 let opened_values = Self::open_batch(values);
288
289 let mut mac_check_deps = Vec::with_capacity(1 + AUTHENTICATED_POINT_RESULT_LEN * n);
293 mac_check_deps.push(fabric.borrow_mac_key().id());
294 for i in 0..n {
295 mac_check_deps.push(opened_values[i].id());
296 mac_check_deps.push(values[i].public_modifier.id());
297 mac_check_deps.push(values[i].mac.id());
298 }
299
300 let mac_checks: Vec<CurvePointResult<C>> =
301 fabric.new_batch_gate_op(mac_check_deps, n , move |mut args| {
302 let mac_key_share: Scalar<C> = args.remove(0).into();
303 let mut check_result = Vec::with_capacity(n);
304
305 for _ in 0..n {
306 let value: CurvePoint<C> = args.remove(0).into();
307 let modifier: CurvePoint<C> = args.remove(0).into();
308 let mac_share: CurvePoint<C> = args.remove(0).into();
309
310 check_result.push(mac_key_share * (value + modifier) - mac_share);
311 }
312
313 check_result.into_iter().map(ResultValue::Point).collect()
314 });
315
316 let my_comms = mac_checks
319 .iter()
320 .cloned()
321 .map(HashCommitmentResult::commit)
322 .collect_vec();
323 let peer_comms = fabric.exchange_values(
324 &my_comms
325 .iter()
326 .map(|comm| comm.commitment.clone())
327 .collect_vec(),
328 );
329
330 let peer_mac_checks = fabric.exchange_values(&mac_checks);
333 let peer_blinders = fabric.exchange_values(
334 &my_comms
335 .iter()
336 .map(|comm| fabric.allocate_scalar(comm.blinder))
337 .collect_vec(),
338 );
339
340 let mut mac_check_gate_deps = my_comms.iter().map(|comm| comm.value.id).collect_vec();
343 mac_check_gate_deps.push(peer_mac_checks.id);
344 mac_check_gate_deps.push(peer_blinders.id);
345 mac_check_gate_deps.push(peer_comms.id);
346
347 let commitment_checks: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
348 mac_check_gate_deps,
349 n, move |mut args| {
351 let my_comms: Vec<CurvePoint<C>> =
352 args.drain(..n).map(|comm| comm.into()).collect();
353 let peer_mac_checks: Vec<CurvePoint<C>> = args.remove(0).into();
354 let peer_blinders: Vec<Scalar<C>> = args.remove(0).into();
355 let peer_comms: Vec<Scalar<C>> = args.remove(0).into();
356
357 let mut mac_checks = Vec::with_capacity(n);
359 for (my_mac_share, peer_mac_share, peer_blinder, peer_commitment) in izip!(
360 my_comms.into_iter(),
361 peer_mac_checks.into_iter(),
362 peer_blinders.into_iter(),
363 peer_comms.into_iter()
364 ) {
365 let mac_check = Self::verify_mac_check(
366 my_mac_share,
367 peer_mac_share,
368 peer_commitment,
369 peer_blinder,
370 );
371 mac_checks.push(ResultValue::Scalar(Scalar::from(mac_check)));
372 }
373
374 mac_checks
375 },
376 );
377
378 opened_values
381 .into_iter()
382 .zip(commitment_checks)
383 .map(|(value, check)| AuthenticatedPointOpenResult {
384 value,
385 mac_check: check,
386 })
387 .collect_vec()
388 }
389}
390
391#[derive(Clone)]
395pub struct AuthenticatedPointOpenResult<C: CurveGroup> {
396 pub value: CurvePointResult<C>,
398 pub mac_check: ScalarResult<C>,
400}
401
402impl<C: CurveGroup> Debug for AuthenticatedPointOpenResult<C> {
403 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
404 f.debug_struct("AuthenticatedPointOpenResult")
405 .field("value", &self.value.id)
406 .field("mac_check", &self.mac_check.id)
407 .finish()
408 }
409}
410
411impl<C: CurveGroup> Future for AuthenticatedPointOpenResult<C>
412where
413 C::ScalarField: Unpin,
414{
415 type Output = Result<CurvePoint<C>, MpcError>;
416
417 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
418 let value = futures::ready!(self.as_mut().value.poll_unpin(cx));
420 let mac_check = futures::ready!(self.as_mut().mac_check.poll_unpin(cx));
421
422 if mac_check == Scalar::from(1u8) {
423 Poll::Ready(Ok(value))
424 } else {
425 Poll::Ready(Err(MpcError::AuthenticationError))
426 }
427 }
428}
429
430impl<C: CurveGroup> Sum for AuthenticatedPointResult<C> {
431 fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
433 let first = iter
434 .next()
435 .expect("AuthenticatedPointResult<C>::sum requires a non-empty iterator");
436 iter.fold(first, |acc, x| acc + x)
437 }
438}
439
440impl<C: CurveGroup> Add<&CurvePoint<C>> for &AuthenticatedPointResult<C> {
447 type Output = AuthenticatedPointResult<C>;
448
449 fn add(self, other: &CurvePoint<C>) -> AuthenticatedPointResult<C> {
450 let new_share = if self.fabric().party_id() == PARTY0 {
451 &self.share + other
453 } else {
454 &self.share + CurvePoint::identity()
457 };
458
459 let new_modifier = &self.public_modifier - other;
461 AuthenticatedPointResult {
462 share: new_share,
463 mac: self.mac.clone(),
464 public_modifier: new_modifier,
465 }
466 }
467}
468impl_borrow_variants!(AuthenticatedPointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
469impl_commutative!(AuthenticatedPointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
470
471impl<C: CurveGroup> Add<&CurvePointResult<C>> for &AuthenticatedPointResult<C> {
472 type Output = AuthenticatedPointResult<C>;
473
474 fn add(self, other: &CurvePointResult<C>) -> AuthenticatedPointResult<C> {
475 let new_share = if self.fabric().party_id() == PARTY0 {
476 &self.share + other
478 } else {
479 &self.share + CurvePoint::identity()
482 };
483
484 let new_modifier = &self.public_modifier - other;
486 AuthenticatedPointResult {
487 share: new_share,
488 mac: self.mac.clone(),
489 public_modifier: new_modifier,
490 }
491 }
492}
493impl_borrow_variants!(AuthenticatedPointResult<C>, Add, add, +, CurvePointResult<C>, C: CurveGroup);
494impl_commutative!(AuthenticatedPointResult<C>, Add, add, +, CurvePointResult<C>, C: CurveGroup);
495
496impl<C: CurveGroup> Add<&AuthenticatedPointResult<C>> for &AuthenticatedPointResult<C> {
497 type Output = AuthenticatedPointResult<C>;
498
499 fn add(self, other: &AuthenticatedPointResult<C>) -> AuthenticatedPointResult<C> {
500 let new_share = &self.share + &other.share;
501
502 let new_mac = &self.mac + &other.mac;
504 AuthenticatedPointResult {
505 share: new_share,
506 mac: new_mac,
507 public_modifier: self.public_modifier.clone() + other.public_modifier.clone(),
508 }
509 }
510}
511impl_borrow_variants!(AuthenticatedPointResult<C>, Add, add, +, AuthenticatedPointResult<C>, C: CurveGroup);
512
513impl<C: CurveGroup> AuthenticatedPointResult<C> {
514 pub fn batch_add(
516 a: &[AuthenticatedPointResult<C>],
517 b: &[AuthenticatedPointResult<C>],
518 ) -> Vec<AuthenticatedPointResult<C>> {
519 assert_eq!(a.len(), b.len(), "batch_add requires equal length vectors");
520 if a.is_empty() {
521 return Vec::new();
522 }
523
524 let n = a.len();
525 let fabric = a[0].fabric();
526 let all_ids = a.iter().chain(b.iter()).flat_map(|p| p.ids()).collect_vec();
527
528 let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
529 all_ids,
530 AUTHENTICATED_POINT_RESULT_LEN * n,
531 move |mut args| {
532 let len = args.len();
533 let a_vals = args.drain(..len / 2).collect_vec();
534 let b_vals = args;
535
536 let mut result = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
537 for (a_chunk, b_chunk) in a_vals
538 .chunks(AUTHENTICATED_POINT_RESULT_LEN)
539 .zip(b_vals.chunks(AUTHENTICATED_POINT_RESULT_LEN))
540 {
541 let a_share: CurvePoint<C> = a_chunk[0].clone().into();
542 let a_mac: CurvePoint<C> = a_chunk[1].clone().into();
543 let a_modifier: CurvePoint<C> = a_chunk[2].clone().into();
544
545 let b_share: CurvePoint<C> = b_chunk[0].clone().into();
546 let b_mac: CurvePoint<C> = b_chunk[1].clone().into();
547 let b_modifier: CurvePoint<C> = b_chunk[2].clone().into();
548
549 result.push(ResultValue::Point(a_share + b_share));
550 result.push(ResultValue::Point(a_mac + b_mac));
551 result.push(ResultValue::Point(a_modifier + b_modifier));
552 }
553
554 result
555 },
556 );
557
558 Self::from_flattened_iterator(res.into_iter())
559 }
560
561 pub fn batch_add_public(
564 a: &[AuthenticatedPointResult<C>],
565 b: &[CurvePointResult<C>],
566 ) -> Vec<AuthenticatedPointResult<C>> {
567 assert_eq!(
568 a.len(),
569 b.len(),
570 "batch_add_public requires equal length vectors"
571 );
572 if a.is_empty() {
573 return Vec::new();
574 }
575
576 let n = a.len();
577 let fabric = a[0].fabric();
578 let all_ids = a
579 .iter()
580 .flat_map(|a| a.ids())
581 .chain(b.iter().map(|b| b.id()))
582 .collect_vec();
583
584 let party_id = fabric.party_id();
585 let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
586 all_ids,
587 AUTHENTICATED_POINT_RESULT_LEN * n,
588 move |mut args| {
589 let a_vals = args
590 .drain(..AUTHENTICATED_POINT_RESULT_LEN * n)
591 .collect_vec();
592 let b_vals = args;
593
594 let mut result = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
595 for (a_chunk, b_val) in a_vals
596 .chunks(AUTHENTICATED_POINT_RESULT_LEN)
597 .zip(b_vals.into_iter())
598 {
599 let a_share: CurvePoint<C> = a_chunk[0].clone().into();
600 let a_mac: CurvePoint<C> = a_chunk[1].clone().into();
601 let a_modifier: CurvePoint<C> = a_chunk[2].clone().into();
602
603 let public_value: CurvePoint<C> = b_val.into();
604
605 if party_id == PARTY0 {
607 result.push(ResultValue::Point(a_share + public_value));
608 } else {
609 result.push(ResultValue::Point(a_share))
610 }
611
612 result.push(ResultValue::Point(a_mac));
613 result.push(ResultValue::Point(a_modifier - public_value));
614 }
615
616 result
617 },
618 );
619
620 Self::from_flattened_iterator(res.into_iter())
621 }
622}
623
624impl<C: CurveGroup> Sub<&CurvePoint<C>> for &AuthenticatedPointResult<C> {
627 type Output = AuthenticatedPointResult<C>;
628
629 fn sub(self, other: &CurvePoint<C>) -> AuthenticatedPointResult<C> {
630 let new_share = if self.fabric().party_id() == PARTY0 {
631 &self.share - other
633 } else {
634 &self.share - CurvePoint::identity()
637 };
638
639 let new_modifier = &self.public_modifier + other;
641 AuthenticatedPointResult {
642 share: new_share,
643 mac: self.mac.clone(),
644 public_modifier: new_modifier,
645 }
646 }
647}
648impl_borrow_variants!(AuthenticatedPointResult<C>, Sub, sub, -, CurvePoint<C>, C: CurveGroup);
649impl_commutative!(AuthenticatedPointResult<C>, Sub, sub, -, CurvePoint<C>, C: CurveGroup);
650
651impl<C: CurveGroup> Sub<&CurvePointResult<C>> for &AuthenticatedPointResult<C> {
652 type Output = AuthenticatedPointResult<C>;
653
654 fn sub(self, other: &CurvePointResult<C>) -> AuthenticatedPointResult<C> {
655 let new_share = if self.fabric().party_id() == PARTY0 {
656 &self.share - other
658 } else {
659 &self.share - CurvePoint::identity()
662 };
663
664 let new_modifier = &self.public_modifier + other;
666 AuthenticatedPointResult {
667 share: new_share,
668 mac: self.mac.clone(),
669 public_modifier: new_modifier,
670 }
671 }
672}
673impl_borrow_variants!(AuthenticatedPointResult<C>, Sub, sub, -, CurvePointResult<C>, C: CurveGroup);
674impl_commutative!(AuthenticatedPointResult<C>, Sub, sub, -, CurvePointResult<C>, C: CurveGroup);
675
676impl<C: CurveGroup> Sub<&AuthenticatedPointResult<C>> for &AuthenticatedPointResult<C> {
677 type Output = AuthenticatedPointResult<C>;
678
679 fn sub(self, other: &AuthenticatedPointResult<C>) -> AuthenticatedPointResult<C> {
680 let new_share = &self.share - &other.share;
681
682 let new_mac = &self.mac - &other.mac;
684 AuthenticatedPointResult {
685 share: new_share,
686 mac: new_mac,
687 public_modifier: self.public_modifier.clone(),
688 }
689 }
690}
691impl_borrow_variants!(AuthenticatedPointResult<C>, Sub, sub, -, AuthenticatedPointResult<C>, C: CurveGroup);
692
693impl<C: CurveGroup> AuthenticatedPointResult<C> {
694 pub fn batch_sub(
696 a: &[AuthenticatedPointResult<C>],
697 b: &[AuthenticatedPointResult<C>],
698 ) -> Vec<AuthenticatedPointResult<C>> {
699 assert_eq!(a.len(), b.len(), "batch_add requires equal length vectors");
700 if a.is_empty() {
701 return Vec::new();
702 }
703
704 let n = a.len();
705 let fabric = a[0].fabric();
706 let all_ids = a.iter().chain(b.iter()).flat_map(|p| p.ids()).collect_vec();
707
708 let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
709 all_ids,
710 AUTHENTICATED_POINT_RESULT_LEN * n,
711 move |mut args| {
712 let len = args.len();
713 let a_vals = args.drain(..len / 2).collect_vec();
714 let b_vals = args;
715
716 let mut result = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
717 for (a_chunk, b_chunk) in a_vals
718 .chunks(AUTHENTICATED_POINT_RESULT_LEN)
719 .zip(b_vals.chunks(AUTHENTICATED_POINT_RESULT_LEN))
720 {
721 let a_share: CurvePoint<C> = a_chunk[0].clone().into();
722 let a_mac: CurvePoint<C> = a_chunk[1].clone().into();
723 let a_modifier: CurvePoint<C> = a_chunk[2].clone().into();
724
725 let b_share: CurvePoint<C> = b_chunk[0].clone().into();
726 let b_mac: CurvePoint<C> = b_chunk[1].clone().into();
727 let b_modifier: CurvePoint<C> = b_chunk[2].clone().into();
728
729 result.push(ResultValue::Point(a_share - b_share));
730 result.push(ResultValue::Point(a_mac - b_mac));
731 result.push(ResultValue::Point(a_modifier - b_modifier));
732 }
733
734 result
735 },
736 );
737
738 Self::from_flattened_iterator(res.into_iter())
739 }
740
741 pub fn batch_sub_public(
744 a: &[AuthenticatedPointResult<C>],
745 b: &[CurvePointResult<C>],
746 ) -> Vec<AuthenticatedPointResult<C>> {
747 assert_eq!(
748 a.len(),
749 b.len(),
750 "batch_add_public requires equal length vectors"
751 );
752 if a.is_empty() {
753 return Vec::new();
754 }
755
756 let n = a.len();
757 let fabric = a[0].fabric();
758 let all_ids = a
759 .iter()
760 .flat_map(|a| a.ids())
761 .chain(b.iter().map(|b| b.id()))
762 .collect_vec();
763
764 let party_id = fabric.party_id();
765 let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
766 all_ids,
767 AUTHENTICATED_POINT_RESULT_LEN * n,
768 move |mut args| {
769 let a_vals = args
770 .drain(..AUTHENTICATED_POINT_RESULT_LEN * n)
771 .collect_vec();
772 let b_vals = args;
773
774 let mut result = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
775 for (a_chunk, b_val) in a_vals
776 .chunks(AUTHENTICATED_POINT_RESULT_LEN)
777 .zip(b_vals.into_iter())
778 {
779 let a_share: CurvePoint<C> = a_chunk[0].clone().into();
780 let a_mac: CurvePoint<C> = a_chunk[1].clone().into();
781 let a_modifier: CurvePoint<C> = a_chunk[2].clone().into();
782
783 let b_share: CurvePoint<C> = b_val.into();
784
785 if party_id == PARTY0 {
787 result.push(ResultValue::Point(a_share - b_share));
788 } else {
789 result.push(ResultValue::Point(a_share))
790 }
791
792 result.push(ResultValue::Point(a_mac));
793 result.push(ResultValue::Point(a_modifier + b_share));
794 }
795
796 result
797 },
798 );
799
800 Self::from_flattened_iterator(res.into_iter())
801 }
802}
803
804impl<C: CurveGroup> Neg for &AuthenticatedPointResult<C> {
807 type Output = AuthenticatedPointResult<C>;
808
809 fn neg(self) -> AuthenticatedPointResult<C> {
810 let new_share = -&self.share;
811
812 let new_mac = -&self.mac;
814 AuthenticatedPointResult {
815 share: new_share,
816 mac: new_mac,
817 public_modifier: self.public_modifier.clone(),
818 }
819 }
820}
821impl_borrow_variants!(AuthenticatedPointResult<C>, Neg, neg, -, C: CurveGroup);
822
823impl<C: CurveGroup> AuthenticatedPointResult<C> {
824 pub fn batch_neg(a: &[AuthenticatedPointResult<C>]) -> Vec<AuthenticatedPointResult<C>> {
826 if a.is_empty() {
827 return Vec::new();
828 }
829
830 let n = a.len();
831 let fabric = a[0].fabric();
832 let all_ids = a.iter().flat_map(|p| p.ids()).collect_vec();
833
834 let res: Vec<CurvePointResult<C>> =
835 fabric.new_batch_gate_op(all_ids, AUTHENTICATED_POINT_RESULT_LEN * n, move |args| {
836 args.into_iter()
837 .map(CurvePoint::from)
838 .map(CurvePoint::neg)
839 .map(ResultValue::Point)
840 .collect_vec()
841 });
842
843 Self::from_flattened_iterator(res.into_iter())
844 }
845}
846
847impl<C: CurveGroup> Mul<&Scalar<C>> for &AuthenticatedPointResult<C> {
850 type Output = AuthenticatedPointResult<C>;
851
852 fn mul(self, other: &Scalar<C>) -> AuthenticatedPointResult<C> {
853 let new_share = &self.share * other;
854
855 let new_mac = &self.mac * other;
857 let new_modifier = &self.public_modifier * other;
858 AuthenticatedPointResult {
859 share: new_share,
860 mac: new_mac,
861 public_modifier: new_modifier,
862 }
863 }
864}
865impl_borrow_variants!(AuthenticatedPointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
866impl_commutative!(AuthenticatedPointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
867
868impl<C: CurveGroup> Mul<&ScalarResult<C>> for &AuthenticatedPointResult<C> {
869 type Output = AuthenticatedPointResult<C>;
870
871 fn mul(self, other: &ScalarResult<C>) -> AuthenticatedPointResult<C> {
872 let new_share = &self.share * other;
873
874 let new_mac = &self.mac * other;
876 let new_modifier = &self.public_modifier * other;
877 AuthenticatedPointResult {
878 share: new_share,
879 mac: new_mac,
880 public_modifier: new_modifier,
881 }
882 }
883}
884impl_borrow_variants!(AuthenticatedPointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
885impl_commutative!(AuthenticatedPointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
886
887impl<C: CurveGroup> Mul<&AuthenticatedScalarResult<C>> for &AuthenticatedPointResult<C> {
888 type Output = AuthenticatedPointResult<C>;
889
890 fn mul(self, rhs: &AuthenticatedScalarResult<C>) -> AuthenticatedPointResult<C> {
892 let generator = CurvePoint::generator();
894 let (a, b, c) = self.fabric().next_authenticated_triple();
895
896 let masked_rhs = rhs - &a;
898 let masked_lhs = self - (&generator * &b);
899
900 #[allow(non_snake_case)]
901 let eG_open = masked_lhs.open();
902 let d_open = masked_rhs.open();
903
904 &d_open * &eG_open + &d_open * &(&generator * &b) + &a * eG_open + &c * generator
906 }
907}
908impl_borrow_variants!(AuthenticatedPointResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, C: CurveGroup);
909impl_commutative!(AuthenticatedPointResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, C: CurveGroup);
910
911impl<C: CurveGroup> AuthenticatedPointResult<C> {
912 #[allow(non_snake_case)]
915 pub fn batch_mul(
916 a: &[AuthenticatedScalarResult<C>],
917 b: &[AuthenticatedPointResult<C>],
918 ) -> Vec<AuthenticatedPointResult<C>> {
919 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
920 if a.is_empty() {
921 return Vec::new();
922 }
923
924 let n = a.len();
925 let fabric = a[0].fabric();
926
927 let (beaver_a, beaver_b, beaver_c) = fabric.next_authenticated_triple_batch(n);
929 let beaver_b_gen = AuthenticatedPointResult::batch_mul_generator(&beaver_b);
930
931 let masked_rhs = AuthenticatedScalarResult::batch_sub(a, &beaver_a);
932 let masked_lhs = AuthenticatedPointResult::batch_sub(b, &beaver_b_gen);
933
934 let eG_open = AuthenticatedPointResult::open_batch(&masked_lhs);
935 let d_open = AuthenticatedScalarResult::open_batch(&masked_rhs);
936
937 let deG = CurvePointResult::batch_mul(&d_open, &eG_open);
939 let dbG = AuthenticatedPointResult::batch_mul_public(&d_open, &beaver_b_gen);
940 let aeG = CurvePointResult::batch_mul_authenticated(&beaver_a, &eG_open);
941 let cG = AuthenticatedPointResult::batch_mul_generator(&beaver_c);
942
943 let de_db_G = AuthenticatedPointResult::batch_add_public(&dbG, &deG);
944 let ae_c_G = AuthenticatedPointResult::batch_add(&aeG, &cG);
945
946 AuthenticatedPointResult::batch_add(&de_db_G, &ae_c_G)
947 }
948
949 pub fn batch_mul_public(
952 a: &[ScalarResult<C>],
953 b: &[AuthenticatedPointResult<C>],
954 ) -> Vec<AuthenticatedPointResult<C>> {
955 assert_eq!(
956 a.len(),
957 b.len(),
958 "batch_mul_public requires equal length vectors"
959 );
960 if a.is_empty() {
961 return Vec::new();
962 }
963
964 let n = a.len();
965 let fabric = a[0].fabric();
966 let all_ids = a
967 .iter()
968 .map(|a| a.id())
969 .chain(b.iter().flat_map(|p| p.ids()))
970 .collect_vec();
971
972 let results: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
973 all_ids,
974 AUTHENTICATED_POINT_RESULT_LEN * n, move |mut args| {
976 let scalars: Vec<Scalar<C>> = args.drain(..n).map(Scalar::from).collect_vec();
977 let points: Vec<CurvePoint<C>> =
978 args.into_iter().map(CurvePoint::from).collect_vec();
979
980 let mut result = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
981 for (scalar, points) in scalars
982 .into_iter()
983 .zip(points.chunks(AUTHENTICATED_POINT_RESULT_LEN))
984 {
985 let share: CurvePoint<C> = points[0];
986 let mac: CurvePoint<C> = points[1];
987 let modifier: CurvePoint<C> = points[2];
988
989 result.push(ResultValue::Point(share * scalar));
990 result.push(ResultValue::Point(mac * scalar));
991 result.push(ResultValue::Point(modifier * scalar));
992 }
993
994 result
995 },
996 );
997
998 Self::from_flattened_iterator(results.into_iter())
999 }
1000
1001 pub fn batch_mul_generator(
1003 a: &[AuthenticatedScalarResult<C>],
1004 ) -> Vec<AuthenticatedPointResult<C>> {
1005 if a.is_empty() {
1006 return Vec::new();
1007 }
1008
1009 let n = a.len();
1010 let fabric = a[0].fabric();
1011 let all_ids = a.iter().flat_map(|v| v.ids()).collect_vec();
1012
1013 let results = fabric.new_batch_gate_op(
1015 all_ids,
1016 AUTHENTICATED_POINT_RESULT_LEN * n, move |args| {
1018 let scalars = args.into_iter().map(Scalar::from).collect_vec();
1019 let generator = CurvePoint::generator();
1020
1021 scalars
1022 .into_iter()
1023 .map(|x| x * generator)
1024 .map(ResultValue::Point)
1025 .collect_vec()
1026 },
1027 );
1028
1029 Self::from_flattened_iterator(results.into_iter())
1030 }
1031}
1032
1033impl<C: CurveGroup> AuthenticatedPointResult<C> {
1036 pub fn msm(
1041 scalars: &[AuthenticatedScalarResult<C>],
1042 points: &[AuthenticatedPointResult<C>],
1043 ) -> AuthenticatedPointResult<C> {
1044 assert_eq!(
1045 scalars.len(),
1046 points.len(),
1047 "multiscalar_mul requires equal length vectors"
1048 );
1049 assert!(
1050 !scalars.is_empty(),
1051 "multiscalar_mul requires non-empty vectors"
1052 );
1053
1054 let mul_out = AuthenticatedPointResult::batch_mul(scalars, points);
1055
1056 let fabric = scalars[0].fabric();
1058 let all_ids = mul_out.iter().flat_map(|p| p.ids()).collect_vec();
1059
1060 let results = fabric.new_batch_gate_op(
1061 all_ids,
1062 AUTHENTICATED_POINT_RESULT_LEN, move |args| {
1064 let mut share = CurvePoint::identity();
1066 let mut mac = CurvePoint::identity();
1067 let mut modifier = CurvePoint::identity();
1068
1069 for mut chunk in args
1070 .into_iter()
1071 .map(CurvePoint::from)
1072 .chunks(AUTHENTICATED_POINT_RESULT_LEN)
1073 .into_iter()
1074 {
1075 share += chunk.next().unwrap();
1076 mac += chunk.next().unwrap();
1077 modifier += chunk.next().unwrap();
1078 }
1079
1080 vec![
1081 ResultValue::Point(share),
1082 ResultValue::Point(mac),
1083 ResultValue::Point(modifier),
1084 ]
1085 },
1086 );
1087
1088 AuthenticatedPointResult {
1089 share: results[0].clone().into(),
1090 mac: results[1].clone().into(),
1091 public_modifier: results[2].clone(),
1092 }
1093 }
1094
1095 pub fn msm_iter<S, P>(scalars: S, points: P) -> AuthenticatedPointResult<C>
1097 where
1098 S: IntoIterator<Item = AuthenticatedScalarResult<C>>,
1099 P: IntoIterator<Item = AuthenticatedPointResult<C>>,
1100 {
1101 let scalars = scalars.into_iter().collect::<Vec<_>>();
1102 let points = points.into_iter().collect::<Vec<_>>();
1103
1104 Self::msm(&scalars, &points)
1105 }
1106}
1107
#[cfg(feature = "test_helpers")]
pub mod test_helpers {
    //! Helpers that overwrite parts of an authenticated point, for use in tests.

    use ark_ec::CurveGroup;

    use crate::algebra::curve::CurvePoint;

    use super::AuthenticatedPointResult;

    /// Overwrites the MAC share of `point` with a freshly allocated `new_mac`.
    pub fn modify_mac<C: CurveGroup>(
        point: &mut AuthenticatedPointResult<C>,
        new_mac: CurvePoint<C>,
    ) {
        let replacement = point.fabric().allocate_point(new_mac);
        point.mac = replacement.into();
    }

    /// Overwrites the share of `point` with a freshly allocated `new_share`.
    pub fn modify_share<C: CurveGroup>(
        point: &mut AuthenticatedPointResult<C>,
        new_share: CurvePoint<C>,
    ) {
        let replacement = point.fabric().allocate_point(new_share);
        point.share = replacement.into();
    }

    /// Overwrites the public modifier of `point` with a freshly allocated `new_modifier`.
    pub fn modify_public_modifier<C: CurveGroup>(
        point: &mut AuthenticatedPointResult<C>,
        new_modifier: CurvePoint<C>,
    ) {
        let replacement = point.fabric().allocate_point(new_modifier);
        point.public_modifier = replacement;
    }
}