1use std::{
4 fmt::Debug,
5 iter::Sum,
6 ops::{Add, Mul, Neg, Sub},
7 pin::Pin,
8 task::{Context, Poll},
9};
10
11use futures::{Future, FutureExt};
12use itertools::{izip, Itertools};
13
14use crate::{
15 commitment::{PedersenCommitment, PedersenCommitmentResult},
16 error::MpcError,
17 fabric::{MpcFabric, ResultId, ResultValue},
18 ResultHandle, PARTY0,
19};
20
21use super::{
22 authenticated_stark_point::AuthenticatedStarkPointResult,
23 macros::{impl_borrow_variants, impl_commutative},
24 mpc_scalar::MpcScalarResult,
25 scalar::{BatchScalarResult, Scalar, ScalarResult},
26 stark_curve::{StarkPoint, StarkPointResult},
27};
28
/// The number of results an `AuthenticatedScalarResult` flattens into: its share,
/// its MAC share, and its public modifier (the order returned by `ids`)
pub const AUTHENTICATED_SCALAR_RESULT_LEN: usize = 3;

/// A secret-shared scalar carried together with this party's share of its MAC and
/// a public modifier
///
/// The MAC check in `open_authenticated` verifies
/// `mac_key_share * (opened_value + public_modifier) - mac_share` across both parties,
/// so public additions/subtractions can be absorbed into the modifier without
/// touching the MAC
#[derive(Clone)]
pub struct AuthenticatedScalarResult {
    /// This party's share of the underlying value
    pub(crate) share: MpcScalarResult,
    /// This party's share of the MAC on the value
    pub(crate) mac: MpcScalarResult,
    /// A public correction added to the opened value before the MAC is checked
    pub(crate) public_modifier: ScalarResult,
}
53
54impl Debug for AuthenticatedScalarResult {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct("AuthenticatedScalarResult")
57 .field("value", &self.share.id())
58 .field("mac", &self.mac.id())
59 .field("public_modifier", &self.public_modifier.id)
60 .finish()
61 }
62}
63
impl AuthenticatedScalarResult {
    /// Create an authenticated value from a `ScalarResult` that is secret-shared
    /// between the parties
    ///
    /// The MAC is computed by multiplying the fabric's MAC key by the shared value, and
    /// the public modifier starts at zero
    pub fn new_shared(value: ScalarResult) -> Self {
        // Take a fabric handle before `value` is moved into the shared wrapper
        let fabric = value.fabric.clone();

        let mpc_value = MpcScalarResult::new_shared(value);
        let mac = fabric.borrow_mac_key() * mpc_value.clone();

        // No public values have been applied yet, so there is nothing to correct for
        let public_modifier = fabric.zero();

        Self {
            share: mpc_value,
            mac,
            public_modifier,
        }
    }

    /// Create a batch of authenticated values from secret-shared `ScalarResult`s
    ///
    /// All MACs are computed with a single `MpcScalarResult::batch_mul`. An empty input
    /// yields an empty output
    pub fn new_shared_batch(values: &[ScalarResult]) -> Vec<Self> {
        // An empty batch has no element to take a fabric handle from
        if values.is_empty() {
            return vec![];
        }

        let n = values.len();
        let fabric = values[0].fabric();
        let mpc_values = values
            .iter()
            .map(|v| MpcScalarResult::new_shared(v.clone()))
            .collect_vec();

        // Pair every value with its own copy of the MAC key for the element-wise batch multiply
        let mac_keys = (0..n)
            .map(|_| fabric.borrow_mac_key().clone())
            .collect_vec();
        let values_macs = MpcScalarResult::batch_mul(&mpc_values, &mac_keys);

        mpc_values
            .into_iter()
            .zip(values_macs.into_iter())
            .map(|(value, mac)| Self {
                share: value,
                mac,
                public_modifier: fabric.zero(),
            })
            .collect_vec()
    }

    /// Create `n` authenticated values from a single `BatchScalarResult`
    ///
    /// A batch gate splits the batch into `n` individual `ScalarResult`s, which are then
    /// authenticated via `new_shared_batch`
    pub fn new_shared_from_batch_result(
        values: BatchScalarResult,
        n: usize,
    ) -> Vec<AuthenticatedScalarResult> {
        // Unpack the batch into one result per scalar
        let scalar_results = values
            .fabric()
            .new_batch_gate_op(vec![values.id()], n, |mut args| {
                let scalars: Vec<Scalar> = args.pop().unwrap().into();
                scalars.into_iter().map(ResultValue::Scalar).collect()
            });

        Self::new_shared_batch(&scalar_results)
    }

    #[cfg(feature = "test_helpers")]
    /// Get a clone of the underlying `MpcScalarResult` share (test helpers only)
    pub fn mpc_share(&self) -> MpcScalarResult {
        self.share.clone()
    }

    /// Get this party's share of the value as a plain `ScalarResult`
    pub fn share(&self) -> ScalarResult {
        self.share.to_scalar()
    }

    /// Get a reference to the fabric the value is allocated in
    pub fn fabric(&self) -> &MpcFabric {
        self.share.fabric()
    }

    /// Get the result IDs of the share, the MAC, and the public modifier, in that order
    ///
    /// The batch gates in this module take and produce values in this same order, which
    /// is the layout `from_flattened_iterator` expects
    pub fn ids(&self) -> Vec<ResultId> {
        vec![self.share.id(), self.mac.id(), self.public_modifier.id]
    }

    /// Open the value without checking its MAC
    pub fn open(&self) -> ScalarResult {
        self.share.open()
    }

    /// Open a batch of values without checking their MACs
    pub fn open_batch(values: &[Self]) -> Vec<ScalarResult> {
        MpcScalarResult::open_batch(&values.iter().map(|val| val.share.clone()).collect_vec())
    }

    /// Rebuild authenticated values from a flat stream of results laid out as
    /// `[share, mac, public_modifier, share, mac, public_modifier, ...]`
    ///
    /// # Panics
    /// Panics if the number of items is not a multiple of `AUTHENTICATED_SCALAR_RESULT_LEN`
    pub fn from_flattened_iterator<I>(iter: I) -> Vec<Self>
    where
        I: Iterator<Item = ResultHandle<Scalar>>,
    {
        iter.chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
            .into_iter()
            .map(|mut chunk| Self {
                share: chunk.next().unwrap().into(),
                mac: chunk.next().unwrap().into(),
                public_modifier: chunk.next().unwrap(),
            })
            .collect_vec()
    }

    /// Check a peer's revealed MAC check value against this party's own
    ///
    /// Returns `true` only if the peer's Pedersen commitment verifies against the
    /// revealed value and blinder, and the two parties' MAC check values sum to zero
    pub fn verify_mac_check(
        my_mac_share: Scalar,
        peer_mac_share: Scalar,
        peer_mac_commitment: StarkPoint,
        peer_commitment_blinder: Scalar,
    ) -> bool {
        let their_comm = PedersenCommitment {
            value: peer_mac_share,
            blinder: peer_commitment_blinder,
            commitment: peer_mac_commitment,
        };

        // Reject if the peer's commitment does not verify for the revealed opening
        if !their_comm.verify() {
            return false;
        }

        // The two parties' MAC check values must cancel out
        if peer_mac_share + my_mac_share != Scalar::from(0) {
            return false;
        }

        true
    }

    /// Open the value and check its MAC
    ///
    /// Each party computes `mac_key_share * (value + public_modifier) - mac_share`, commits
    /// to it, and exchanges the commitment, the value, and the blinder with its peer. The
    /// returned result resolves to `MpcError::AuthenticationError` when awaited if the
    /// check fails
    pub fn open_authenticated(&self) -> AuthenticatedScalarOpenResult {
        // Both parties open the underlying value
        let recovered_value = self.share.open();

        // This party's MAC check value: `mac_key_share * (value + modifier) - mac_share`
        let mac_check_value: ScalarResult = self.fabric().new_gate_op(
            vec![
                self.fabric().borrow_mac_key().id(),
                recovered_value.id,
                self.public_modifier.id,
                self.mac.id(),
            ],
            move |mut args| {
                let mac_key_share: Scalar = args.remove(0).into();
                let value: Scalar = args.remove(0).into();
                let modifier: Scalar = args.remove(0).into();
                let mac_share: Scalar = args.remove(0).into();

                ResultValue::Scalar(mac_key_share * (value + modifier) - mac_share)
            },
        );

        // Commit to the MAC check value and exchange commitments with the peer
        let my_comm = PedersenCommitmentResult::commit(mac_check_value);
        let peer_commit = self.fabric().exchange_value(my_comm.commitment);

        // Reveal the committed value and its blinder
        // NOTE(review): these exchanges take no explicit dependency on `peer_commit`;
        // confirm the fabric guarantees the peer's commitment arrives before this
        // party's opening is sent
        let peer_mac_check = self.fabric().exchange_value(my_comm.value.clone());

        let blinder_result: ScalarResult = self.fabric().allocate_scalar(my_comm.blinder);
        let peer_blinder = self.fabric().exchange_value(blinder_result);

        // Verify the peer's opening and that both check values sum to zero
        let commitment_check: ScalarResult = self.fabric().new_gate_op(
            vec![
                my_comm.value.id,
                peer_mac_check.id,
                peer_blinder.id,
                peer_commit.id,
            ],
            |mut args| {
                let my_comm_value: Scalar = args.remove(0).into();
                let peer_value: Scalar = args.remove(0).into();
                let blinder: Scalar = args.remove(0).into();
                let commitment: StarkPoint = args.remove(0).into();

                ResultValue::Scalar(Scalar::from(Self::verify_mac_check(
                    my_comm_value,
                    peer_value,
                    commitment,
                    blinder,
                )))
            },
        );

        AuthenticatedScalarOpenResult {
            value: recovered_value,
            mac_check: commitment_check,
        }
    }

    /// Open a batch of values and check their MACs
    ///
    /// Batched version of `open_authenticated`: the values are opened together, and the
    /// check values, commitments, and blinders are each exchanged as one batch. An empty
    /// input yields an empty output
    pub fn open_authenticated_batch(values: &[Self]) -> Vec<AuthenticatedScalarOpenResult> {
        // An empty batch has no element to take a fabric handle from
        if values.is_empty() {
            return vec![];
        }

        let n = values.len();
        let fabric = &values[0].fabric();

        // Both parties open the underlying values
        let values_open = Self::open_batch(values);

        // Gate inputs: the MAC key share, then (opened value, public modifier, MAC share)
        // for each element
        let mut mac_check_deps = Vec::with_capacity(1 + 3 * n);
        mac_check_deps.push(fabric.borrow_mac_key().id());
        for i in 0..n {
            mac_check_deps.push(values_open[i].id());
            mac_check_deps.push(values[i].public_modifier.id());
            mac_check_deps.push(values[i].mac.id());
        }

        // One MAC check value per element: `mac_key_share * (value + modifier) - mac_share`
        let mac_checks: Vec<ScalarResult> =
            fabric.new_batch_gate_op(mac_check_deps, n, move |mut args| {
                let mac_key_share: Scalar = args.remove(0).into();
                let mut check_result = Vec::with_capacity(n);

                for _ in 0..n {
                    let value: Scalar = args.remove(0).into();
                    let modifier: Scalar = args.remove(0).into();
                    let mac_share: Scalar = args.remove(0).into();

                    check_result.push(mac_key_share * (value + modifier) - mac_share);
                }

                check_result.into_iter().map(ResultValue::Scalar).collect()
            });

        // Commit to every check value and exchange the commitments as one batch
        let my_comms = mac_checks
            .iter()
            .cloned()
            .map(PedersenCommitmentResult::commit)
            .collect_vec();
        let peer_comms = fabric.exchange_values(
            &my_comms
                .iter()
                .map(|comm| comm.commitment.clone())
                .collect_vec(),
        );

        // Reveal the check values and the commitment blinders
        let peer_mac_checks = fabric.exchange_values(&mac_checks);
        let peer_blinders = fabric.exchange_values(
            &my_comms
                .iter()
                .map(|comm| fabric.allocate_scalar(comm.blinder))
                .collect_vec(),
        );

        // Gate inputs: this party's `n` committed check values, then the peer's batched
        // check values, blinders, and commitments
        let mut mac_check_gate_deps = my_comms.iter().map(|comm| comm.value.id).collect_vec();
        mac_check_gate_deps.push(peer_mac_checks.id);
        mac_check_gate_deps.push(peer_blinders.id);
        mac_check_gate_deps.push(peer_comms.id);

        let commitment_checks: Vec<ScalarResult> = fabric.new_batch_gate_op(
            mac_check_gate_deps,
            n,
            move |mut args| {
                // Note: these are this party's committed check values, not the commitments
                let my_comms: Vec<Scalar> = args.drain(..n).map(|comm| comm.into()).collect();
                let peer_mac_checks: Vec<Scalar> = args.remove(0).into();
                let peer_blinders: Vec<Scalar> = args.remove(0).into();
                let peer_comms: Vec<StarkPoint> = args.remove(0).into();

                let mut mac_checks = Vec::with_capacity(n);
                for (my_mac_share, peer_mac_share, peer_blinder, peer_commitment) in izip!(
                    my_comms.into_iter(),
                    peer_mac_checks.into_iter(),
                    peer_blinders.into_iter(),
                    peer_comms.into_iter()
                ) {
                    let mac_check = Self::verify_mac_check(
                        my_mac_share,
                        peer_mac_share,
                        peer_commitment,
                        peer_blinder,
                    );
                    mac_checks.push(ResultValue::Scalar(Scalar::from(mac_check)));
                }

                mac_checks
            },
        );

        values_open
            .into_iter()
            .zip(commitment_checks.into_iter())
            .map(|(value, check)| AuthenticatedScalarOpenResult {
                value,
                mac_check: check,
            })
            .collect_vec()
    }
}
387
/// The result of opening an `AuthenticatedScalarResult` with a MAC check
///
/// When awaited, resolves to the opened value if the MAC check passed and to
/// `MpcError::AuthenticationError` otherwise
#[derive(Clone)]
pub struct AuthenticatedScalarOpenResult {
    /// The opened value
    pub value: ScalarResult,
    /// The MAC check outcome encoded as `Scalar::from(bool)`; `Scalar::from(1)` means
    /// the check passed
    pub mac_check: ScalarResult,
}
397
398impl Future for AuthenticatedScalarOpenResult {
399 type Output = Result<Scalar, MpcError>;
400
401 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
402 let value = futures::ready!(self.as_mut().value.poll_unpin(cx));
404 let mac_check = futures::ready!(self.as_mut().mac_check.poll_unpin(cx));
405
406 if mac_check == Scalar::from(1) {
407 Poll::Ready(Ok(value))
408 } else {
409 Poll::Ready(Err(MpcError::AuthenticationError))
410 }
411 }
412}
413
414impl Add<&Scalar> for &AuthenticatedScalarResult {
421 type Output = AuthenticatedScalarResult;
422
423 fn add(self, rhs: &Scalar) -> Self::Output {
424 let new_share = if self.fabric().party_id() == PARTY0 {
425 &self.share + rhs
426 } else {
427 &self.share + Scalar::from(0)
428 };
429
430 let new_modifier = &self.public_modifier - rhs;
433 AuthenticatedScalarResult {
434 share: new_share,
435 mac: self.mac.clone(),
436 public_modifier: new_modifier,
437 }
438 }
439}
440impl_borrow_variants!(AuthenticatedScalarResult, Add, add, +, Scalar, Output=AuthenticatedScalarResult);
441impl_commutative!(AuthenticatedScalarResult, Add, add, +, Scalar, Output=AuthenticatedScalarResult);
442
443impl Add<&ScalarResult> for &AuthenticatedScalarResult {
444 type Output = AuthenticatedScalarResult;
445
446 fn add(self, rhs: &ScalarResult) -> Self::Output {
447 let new_share = if self.fabric().party_id() == PARTY0 {
452 &self.share + rhs
453 } else {
454 &self.share + Scalar::from(0)
455 };
456
457 let new_modifier = &self.public_modifier - rhs;
458 AuthenticatedScalarResult {
459 share: new_share,
460 mac: self.mac.clone(),
461 public_modifier: new_modifier,
462 }
463 }
464}
465impl_borrow_variants!(AuthenticatedScalarResult, Add, add, +, ScalarResult, Output=AuthenticatedScalarResult);
466impl_commutative!(AuthenticatedScalarResult, Add, add, +, ScalarResult, Output=AuthenticatedScalarResult);
467
468impl Add<&AuthenticatedScalarResult> for &AuthenticatedScalarResult {
469 type Output = AuthenticatedScalarResult;
470
471 fn add(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
472 AuthenticatedScalarResult {
473 share: &self.share + &rhs.share,
474 mac: &self.mac + &rhs.mac,
475 public_modifier: self.public_modifier.clone() + rhs.public_modifier.clone(),
476 }
477 }
478}
479impl_borrow_variants!(AuthenticatedScalarResult, Add, add, +, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
480
481impl AuthenticatedScalarResult {
482 pub fn batch_add(
484 a: &[AuthenticatedScalarResult],
485 b: &[AuthenticatedScalarResult],
486 ) -> Vec<AuthenticatedScalarResult> {
487 assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
488
489 let n = a.len();
490 let fabric = a[0].fabric();
491 let all_ids = a.iter().chain(b.iter()).flat_map(|v| v.ids()).collect_vec();
492
493 let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
495 all_ids,
496 AUTHENTICATED_SCALAR_RESULT_LEN * n, move |mut args| {
498 let arg_len = args.len();
499 let a_vals = args.drain(..arg_len / 2).collect_vec();
500 let b_vals = args;
501
502 let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
503 for (mut a_vals, mut b_vals) in a_vals
504 .into_iter()
505 .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
506 .into_iter()
507 .zip(
508 b_vals
509 .into_iter()
510 .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
511 .into_iter(),
512 )
513 {
514 let a_share: Scalar = a_vals.next().unwrap().into();
515 let a_mac_share: Scalar = a_vals.next().unwrap().into();
516 let a_modifier: Scalar = a_vals.next().unwrap().into();
517
518 let b_share: Scalar = b_vals.next().unwrap().into();
519 let b_mac_share: Scalar = b_vals.next().unwrap().into();
520 let b_modifier: Scalar = b_vals.next().unwrap().into();
521
522 result.push(ResultValue::Scalar(a_share + b_share));
523 result.push(ResultValue::Scalar(a_mac_share + b_mac_share));
524 result.push(ResultValue::Scalar(a_modifier + b_modifier));
525 }
526
527 result
528 },
529 );
530
531 AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
533 }
534
535 pub fn batch_add_public(
537 a: &[AuthenticatedScalarResult],
538 b: &[ScalarResult],
539 ) -> Vec<AuthenticatedScalarResult> {
540 assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
541
542 let n = a.len();
543 let results_per_value = 3;
544 let fabric = a[0].fabric();
545 let all_ids = a
546 .iter()
547 .flat_map(|v| v.ids())
548 .chain(b.iter().map(|v| v.id()))
549 .collect_vec();
550
551 let party_id = fabric.party_id();
553 let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
554 all_ids,
555 results_per_value * n, move |mut args| {
557 let a_vals = args
559 .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
560 .collect_vec();
561 let public_values = args;
562
563 let mut result = Vec::with_capacity(results_per_value * n);
564 for (mut a_vals, public_value) in a_vals
565 .into_iter()
566 .chunks(results_per_value)
567 .into_iter()
568 .zip(public_values.into_iter())
569 {
570 let a_share: Scalar = a_vals.next().unwrap().into();
571 let a_mac_share: Scalar = a_vals.next().unwrap().into();
572 let a_modifier: Scalar = a_vals.next().unwrap().into();
573
574 let public_value: Scalar = public_value.into();
575
576 if party_id == PARTY0 {
578 result.push(ResultValue::Scalar(a_share + public_value));
579 } else {
580 result.push(ResultValue::Scalar(a_share));
581 }
582
583 result.push(ResultValue::Scalar(a_mac_share));
584 result.push(ResultValue::Scalar(a_modifier - public_value));
585 }
586
587 result
588 },
589 );
590
591 AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
593 }
594}
595
596impl Sum for AuthenticatedScalarResult {
599 fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
601 let seed = iter.next().expect("Cannot sum empty iterator");
602 iter.fold(seed, |acc, val| acc + &val)
603 }
604}
605
606impl Sub<&Scalar> for &AuthenticatedScalarResult {
609 type Output = AuthenticatedScalarResult;
610
611 fn sub(self, rhs: &Scalar) -> Self::Output {
614 let new_share = &self.share - rhs;
617
618 let new_modifier = &self.public_modifier + rhs;
621 AuthenticatedScalarResult {
622 share: new_share,
623 mac: self.mac.clone(),
624 public_modifier: new_modifier,
625 }
626 }
627}
628impl_borrow_variants!(AuthenticatedScalarResult, Sub, sub, -, Scalar, Output=AuthenticatedScalarResult);
629
630impl Sub<&AuthenticatedScalarResult> for &Scalar {
631 type Output = AuthenticatedScalarResult;
632
633 fn sub(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
634 let new_share = self - &rhs.share;
637
638 let new_modifier = -self - &rhs.public_modifier;
641 AuthenticatedScalarResult {
642 share: new_share,
643 mac: -&rhs.mac,
644 public_modifier: new_modifier,
645 }
646 }
647}
648impl_borrow_variants!(Scalar, Sub, sub, -, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
649
650impl Sub<&ScalarResult> for &AuthenticatedScalarResult {
651 type Output = AuthenticatedScalarResult;
652
653 fn sub(self, rhs: &ScalarResult) -> Self::Output {
654 let new_share = &self.share - rhs;
655
656 let new_modifier = &self.public_modifier + rhs;
659 AuthenticatedScalarResult {
660 share: new_share,
661 mac: self.mac.clone(),
662 public_modifier: new_modifier,
663 }
664 }
665}
666impl_borrow_variants!(AuthenticatedScalarResult, Sub, sub, -, ScalarResult, Output=AuthenticatedScalarResult);
667
668impl Sub<&AuthenticatedScalarResult> for &ScalarResult {
669 type Output = AuthenticatedScalarResult;
670
671 fn sub(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
672 let new_share = self - &rhs.share;
675
676 let new_modifier = -self - &rhs.public_modifier;
679 AuthenticatedScalarResult {
680 share: new_share,
681 mac: -&rhs.mac,
682 public_modifier: new_modifier,
683 }
684 }
685}
686impl_borrow_variants!(ScalarResult, Sub, sub, -, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
687
688impl Sub<&AuthenticatedScalarResult> for &AuthenticatedScalarResult {
689 type Output = AuthenticatedScalarResult;
690
691 fn sub(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
692 AuthenticatedScalarResult {
693 share: &self.share - &rhs.share,
694 mac: &self.mac - &rhs.mac,
695 public_modifier: self.public_modifier.clone() - rhs.public_modifier.clone(),
696 }
697 }
698}
699impl_borrow_variants!(AuthenticatedScalarResult, Sub, sub, -, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
700
701impl AuthenticatedScalarResult {
702 pub fn batch_sub(
704 a: &[AuthenticatedScalarResult],
705 b: &[AuthenticatedScalarResult],
706 ) -> Vec<AuthenticatedScalarResult> {
707 assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
708
709 let n = a.len();
710 let fabric = &a[0].fabric();
711 let all_ids = a.iter().chain(b.iter()).flat_map(|v| v.ids()).collect_vec();
712
713 let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
715 all_ids,
716 AUTHENTICATED_SCALAR_RESULT_LEN * n, move |mut args| {
718 let arg_len = args.len();
719 let a_vals = args.drain(..arg_len / 2).collect_vec();
720 let b_vals = args;
721
722 let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
723 for (mut a_vals, mut b_vals) in a_vals
724 .into_iter()
725 .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
726 .into_iter()
727 .zip(
728 b_vals
729 .into_iter()
730 .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
731 .into_iter(),
732 )
733 {
734 let a_share: Scalar = a_vals.next().unwrap().into();
735 let a_mac_share: Scalar = a_vals.next().unwrap().into();
736 let a_modifier: Scalar = a_vals.next().unwrap().into();
737
738 let b_share: Scalar = b_vals.next().unwrap().into();
739 let b_mac_share: Scalar = b_vals.next().unwrap().into();
740 let b_modifier: Scalar = b_vals.next().unwrap().into();
741
742 result.push(ResultValue::Scalar(a_share - b_share));
743 result.push(ResultValue::Scalar(a_mac_share - b_mac_share));
744 result.push(ResultValue::Scalar(a_modifier - b_modifier));
745 }
746
747 result
748 },
749 );
750
751 AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
753 }
754
755 pub fn batch_sub_public(
757 a: &[AuthenticatedScalarResult],
758 b: &[ScalarResult],
759 ) -> Vec<AuthenticatedScalarResult> {
760 assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
761
762 let n = a.len();
763 let results_per_value = 3;
764 let fabric = a[0].fabric();
765 let all_ids = a
766 .iter()
767 .flat_map(|v| v.ids())
768 .chain(b.iter().map(|v| v.id()))
769 .collect_vec();
770
771 let party_id = fabric.party_id();
773 let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
774 all_ids,
775 results_per_value * n, move |mut args| {
777 let a_vals = args
779 .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
780 .collect_vec();
781 let public_values = args;
782
783 let mut result = Vec::with_capacity(results_per_value * n);
784 for (mut a_vals, public_value) in a_vals
785 .into_iter()
786 .chunks(results_per_value)
787 .into_iter()
788 .zip(public_values.into_iter())
789 {
790 let a_share: Scalar = a_vals.next().unwrap().into();
791 let a_mac_share: Scalar = a_vals.next().unwrap().into();
792 let a_modifier: Scalar = a_vals.next().unwrap().into();
793
794 let public_value: Scalar = public_value.into();
795
796 if party_id == PARTY0 {
798 result.push(ResultValue::Scalar(a_share - public_value));
799 } else {
800 result.push(ResultValue::Scalar(a_share));
801 }
802
803 result.push(ResultValue::Scalar(a_mac_share));
804 result.push(ResultValue::Scalar(a_modifier + public_value));
805 }
806
807 result
808 },
809 );
810
811 AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
813 }
814}
815
816impl Neg for &AuthenticatedScalarResult {
819 type Output = AuthenticatedScalarResult;
820
821 fn neg(self) -> Self::Output {
822 AuthenticatedScalarResult {
823 share: -&self.share,
824 mac: -&self.mac,
825 public_modifier: -&self.public_modifier,
826 }
827 }
828}
829impl_borrow_variants!(AuthenticatedScalarResult, Neg, neg, -);
830
831impl AuthenticatedScalarResult {
832 pub fn batch_neg(a: &[AuthenticatedScalarResult]) -> Vec<AuthenticatedScalarResult> {
834 if a.is_empty() {
835 return vec![];
836 }
837
838 let n = a.len();
839 let fabric = a[0].fabric();
840 let all_ids = a.iter().flat_map(|v| v.ids()).collect_vec();
841
842 let scalars = fabric.new_batch_gate_op(
843 all_ids,
844 AUTHENTICATED_SCALAR_RESULT_LEN * n, |args| {
846 args.into_iter()
847 .map(|arg| ResultValue::Scalar(-Scalar::from(arg)))
848 .collect()
849 },
850 );
851
852 AuthenticatedScalarResult::from_flattened_iterator(scalars.into_iter())
853 }
854}
855
856impl Mul<&Scalar> for &AuthenticatedScalarResult {
859 type Output = AuthenticatedScalarResult;
860
861 fn mul(self, rhs: &Scalar) -> Self::Output {
862 AuthenticatedScalarResult {
863 share: &self.share * rhs,
864 mac: &self.mac * rhs,
865 public_modifier: &self.public_modifier * rhs,
866 }
867 }
868}
869impl_borrow_variants!(AuthenticatedScalarResult, Mul, mul, *, Scalar, Output=AuthenticatedScalarResult);
870impl_commutative!(AuthenticatedScalarResult, Mul, mul, *, Scalar, Output=AuthenticatedScalarResult);
871
872impl Mul<&ScalarResult> for &AuthenticatedScalarResult {
873 type Output = AuthenticatedScalarResult;
874
875 fn mul(self, rhs: &ScalarResult) -> Self::Output {
876 AuthenticatedScalarResult {
877 share: &self.share * rhs,
878 mac: &self.mac * rhs,
879 public_modifier: &self.public_modifier * rhs,
880 }
881 }
882}
883impl_borrow_variants!(AuthenticatedScalarResult, Mul, mul, *, ScalarResult, Output=AuthenticatedScalarResult);
884impl_commutative!(AuthenticatedScalarResult, Mul, mul, *, ScalarResult, Output=AuthenticatedScalarResult);
885
886impl Mul<&AuthenticatedScalarResult> for &AuthenticatedScalarResult {
887 type Output = AuthenticatedScalarResult;
888
889 fn mul(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
891 let (a, b, c) = self.fabric().next_authenticated_triple();
893
894 let masked_lhs = self - &a;
896 let masked_rhs = rhs - &b;
897
898 let d = masked_lhs.open();
900 let e = masked_rhs.open();
901
902 &d * &e + d * b + e * a + c
906 }
907}
908impl_borrow_variants!(AuthenticatedScalarResult, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
909
910impl AuthenticatedScalarResult {
911 pub fn batch_mul(
913 a: &[AuthenticatedScalarResult],
914 b: &[AuthenticatedScalarResult],
915 ) -> Vec<AuthenticatedScalarResult> {
916 assert_eq!(
917 a.len(),
918 b.len(),
919 "Cannot multiply batches of different sizes"
920 );
921
922 if a.is_empty() {
923 return vec![];
924 }
925
926 let n = a.len();
927 let fabric = a[0].fabric();
928 let (beaver_a, beaver_b, beaver_c) = fabric.next_authenticated_triple_batch(n);
929
930 let masked_lhs = AuthenticatedScalarResult::batch_sub(a, &beaver_a);
932 let masked_rhs = AuthenticatedScalarResult::batch_sub(b, &beaver_b);
933
934 let all_masks = [masked_lhs, masked_rhs].concat();
935 let opened_values = AuthenticatedScalarResult::open_batch(&all_masks);
936 let (d_open, e_open) = opened_values.split_at(n);
937
938 let de = ScalarResult::batch_mul(d_open, e_open);
940 let db = AuthenticatedScalarResult::batch_mul_public(&beaver_b, d_open);
941 let ea = AuthenticatedScalarResult::batch_mul_public(&beaver_a, e_open);
942
943 let de_plus_db = AuthenticatedScalarResult::batch_add_public(&db, &de);
945 let ea_plus_c = AuthenticatedScalarResult::batch_add(&ea, &beaver_c);
946 AuthenticatedScalarResult::batch_add(&de_plus_db, &ea_plus_c)
947 }
948
949 pub fn batch_mul_public(
951 a: &[AuthenticatedScalarResult],
952 b: &[ScalarResult],
953 ) -> Vec<AuthenticatedScalarResult> {
954 assert_eq!(
955 a.len(),
956 b.len(),
957 "Cannot multiply batches of different sizes"
958 );
959 if a.is_empty() {
960 return vec![];
961 }
962
963 let n = a.len();
964 let fabric = a[0].fabric();
965 let all_ids = a
966 .iter()
967 .flat_map(|a| a.ids())
968 .chain(b.iter().map(|b| b.id()))
969 .collect_vec();
970
971 let scalars = fabric.new_batch_gate_op(
972 all_ids,
973 AUTHENTICATED_SCALAR_RESULT_LEN * n, move |mut args| {
975 let a_vals = args
976 .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
977 .collect_vec();
978 let public_values = args;
979
980 let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
981 for (a_vals, public_values) in a_vals
982 .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
983 .zip(public_values.into_iter())
984 {
985 let a_share: Scalar = a_vals[0].to_owned().into();
986 let a_mac_share: Scalar = a_vals[1].to_owned().into();
987 let a_modifier: Scalar = a_vals[2].to_owned().into();
988
989 let public_value: Scalar = public_values.into();
990
991 result.push(ResultValue::Scalar(a_share * public_value));
992 result.push(ResultValue::Scalar(a_mac_share * public_value));
993 result.push(ResultValue::Scalar(a_modifier * public_value));
994 }
995
996 result
997 },
998 );
999
1000 AuthenticatedScalarResult::from_flattened_iterator(scalars.into_iter())
1001 }
1002}
1003
1004impl Mul<&AuthenticatedScalarResult> for &StarkPoint {
1007 type Output = AuthenticatedStarkPointResult;
1008
1009 fn mul(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
1010 AuthenticatedStarkPointResult {
1011 share: self * &rhs.share,
1012 mac: self * &rhs.mac,
1013 public_modifier: self * &rhs.public_modifier,
1014 }
1015 }
1016}
1017impl_commutative!(StarkPoint, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedStarkPointResult);
1018
1019impl Mul<&AuthenticatedScalarResult> for &StarkPointResult {
1020 type Output = AuthenticatedStarkPointResult;
1021
1022 fn mul(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
1023 AuthenticatedStarkPointResult {
1024 share: self * &rhs.share,
1025 mac: self * &rhs.mac,
1026 public_modifier: self * &rhs.public_modifier,
1027 }
1028 }
1029}
1030impl_borrow_variants!(StarkPointResult, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedStarkPointResult);
1031impl_commutative!(StarkPointResult, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedStarkPointResult);
1032
#[cfg(feature = "test_helpers")]
/// Helpers that overwrite individual components of an `AuthenticatedScalarResult`,
/// e.g. to exercise MAC check failures in tests
pub mod test_helpers {
    use crate::algebra::scalar::Scalar;

    use super::AuthenticatedScalarResult;

    /// Replace the MAC share of `val` with a freshly allocated `new_value`
    pub fn modify_mac(val: &mut AuthenticatedScalarResult, new_value: Scalar) {
        let fabric = val.fabric();
        val.mac = fabric.allocate_scalar(new_value).into();
    }

    /// Replace the share of `val` with a freshly allocated `new_value`
    pub fn modify_share(val: &mut AuthenticatedScalarResult, new_value: Scalar) {
        let fabric = val.fabric();
        val.share = fabric.allocate_scalar(new_value).into();
    }

    /// Replace the public modifier of `val` with a freshly allocated `new_value`
    pub fn modify_public_modifier(val: &mut AuthenticatedScalarResult, new_value: Scalar) {
        let fabric = val.fabric();
        val.public_modifier = fabric.allocate_scalar(new_value);
    }
}
1060
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use crate::{algebra::scalar::Scalar, test_helpers::execute_mock_mpc, PARTY0};

    /// Subtraction between a shared value and an allocated public value, in both orders
    #[tokio::test]
    async fn test_sub() {
        let mut rng = thread_rng();
        let lhs = Scalar::random(&mut rng);
        let rhs = Scalar::random(&mut rng);

        let ((forward_ok, reverse_ok), _) = execute_mock_mpc(|fabric| async move {
            let shared = fabric.share_scalar(lhs, PARTY0);
            let public = fabric.allocate_scalar(rhs);

            let forward = &shared - &public;
            let forward_open = forward.open_authenticated().await.unwrap();

            let reverse = &public - &shared;
            let reverse_open = reverse.open_authenticated().await.unwrap();

            (forward_open == lhs - rhs, reverse_open == rhs - lhs)
        })
        .await;

        assert!(forward_ok);
        assert!(reverse_ok);
    }

    /// Subtraction between a shared value and a constant scalar, in both orders
    #[tokio::test]
    async fn test_sub_constant() {
        let mut rng = thread_rng();
        let lhs = Scalar::random(&mut rng);
        let rhs = Scalar::random(&mut rng);

        let ((forward_ok, reverse_ok), _) = execute_mock_mpc(|fabric| async move {
            let shared = fabric.share_scalar(lhs, PARTY0);

            let forward = &shared - rhs;
            let forward_open = forward.open_authenticated().await.unwrap();

            let reverse = rhs - &shared;
            let reverse_open = reverse.open_authenticated().await.unwrap();

            (forward_open == lhs - rhs, reverse_open == rhs - lhs)
        })
        .await;

        assert!(forward_ok);
        assert!(reverse_ok);
    }

    /// The arithmetic XOR `a + b - 2ab` of two authenticated zeros opens to zero
    #[tokio::test]
    async fn test_xor_circuit() {
        let (opened, _) = execute_mock_mpc(|fabric| async move {
            let a = &fabric.zero_authenticated();
            let b = &fabric.zero_authenticated();
            let xor = a + b - Scalar::from(2u64) * a * b;

            xor.open_authenticated().await
        })
        .await;

        assert_eq!(opened.unwrap(), 0.into());
    }
}