1use std::{
5 iter::Sum,
6 mem::size_of,
7 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
8};
9
10use ark_ec::{
11 hashing::{
12 curve_maps::swu::{SWUConfig, SWUMap},
13 map_to_curve_hasher::MapToCurve,
14 HashToCurveError,
15 },
16 short_weierstrass::Projective,
17 AffineRepr, CurveGroup,
18};
19use ark_ff::PrimeField;
20
21use ark_serialize::SerializationError;
22use itertools::Itertools;
23use serde::{de::Error as DeError, Deserialize, Serialize};
24
25use crate::{
26 algebra::{
27 macros::*, n_bytes_field, scalar::*, AUTHENTICATED_POINT_RESULT_LEN,
28 AUTHENTICATED_SCALAR_RESULT_LEN,
29 },
30 fabric::{ResultHandle, ResultValue},
31};
32
33use super::{authenticated_curve::AuthenticatedPointResult, mpc_curve::MpcPointResult};
34
/// Chunk length used by `CurvePoint::msm_iter` when it splits its scalar and
/// point iterators into batches
const MSM_CHUNK_SIZE: usize = 1 << 16;
/// Input length below which `CurvePoint::msm` multiplies and sums each pair
/// directly instead of calling `C::msm`
const MSM_SIZE_THRESHOLD: usize = 10;

/// Security parameter for the hash-to-curve implementation
///
/// NOTE(review): presumably measured in bytes (16 bytes = 128 bits); the
/// constant is not referenced in this part of the file -- confirm with callers
pub const HASH_TO_CURVE_SECURITY: usize = 16;

/// A newtype wrapper around a curve group element of type `C`
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct CurvePoint<C: CurveGroup>(pub(crate) C);
impl<C: CurveGroup> Unpin for CurvePoint<C> {}
52
53impl<C: CurveGroup> Serialize for CurvePoint<C> {
54 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
55 let bytes = self.to_bytes();
56 bytes.serialize(serializer)
57 }
58}
59
60impl<'de, C: CurveGroup> Deserialize<'de> for CurvePoint<C> {
61 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
62 let bytes = <Vec<u8>>::deserialize(deserializer)?;
63 CurvePoint::from_bytes(&bytes)
64 .map_err(|err| DeError::custom(format!("Failed to deserialize point: {err:?}")))
65 }
66}
67
68impl<C: CurveGroup> CurvePoint<C> {
73 pub type BaseField = C::BaseField;
76 pub type ScalarField = C::ScalarField;
79
80 pub fn identity() -> CurvePoint<C> {
82 CurvePoint(C::zero())
83 }
84
85 pub fn is_identity(&self) -> bool {
87 self == &CurvePoint::identity()
88 }
89
90 pub fn inner(&self) -> C {
92 self.0
93 }
94
95 pub fn to_affine(&self) -> C::Affine {
97 self.0.into_affine()
98 }
99
100 pub fn generator() -> CurvePoint<C> {
102 CurvePoint(C::generator())
103 }
104
105 pub fn to_bytes(&self) -> Vec<u8> {
107 let mut out: Vec<u8> = Vec::with_capacity(size_of::<CurvePoint<C>>());
108 self.0
109 .serialize_compressed(&mut out)
110 .expect("Failed to serialize point");
111
112 out
113 }
114
115 pub fn from_bytes(bytes: &[u8]) -> Result<CurvePoint<C>, SerializationError> {
117 let point = C::deserialize_compressed(bytes)?;
118 Ok(CurvePoint(point))
119 }
120}
121
122impl<C: CurveGroup> CurvePoint<C>
123where
124 C::BaseField: PrimeField,
125{
126 pub fn n_bytes() -> usize {
130 n_bytes_field::<C::BaseField>()
131 }
132}
133
134impl<C: CurveGroup> CurvePoint<C>
135where
136 C::Config: SWUConfig,
137 C::BaseField: PrimeField,
138{
139 pub fn from_uniform_bytes(
149 buf: Vec<u8>,
150 ) -> Result<CurvePoint<Projective<C::Config>>, HashToCurveError> {
151 let n_bytes = Self::n_bytes();
152 assert_eq!(
153 buf.len(),
154 2 * n_bytes,
155 "Invalid buffer length, must represent two curve points"
156 );
157
158 let f1 = Self::hash_to_field(&buf[..n_bytes / 2]);
160 let f2 = Self::hash_to_field(&buf[n_bytes / 2..]);
161
162 let mapper = SWUMap::<C::Config>::new()?;
164 let p1 = mapper.map_to_curve(f1)?;
165 let p2 = mapper.map_to_curve(f2)?;
166
167 let p1_clear = p1.clear_cofactor();
169 let p2_clear = p2.clear_cofactor();
170
171 Ok(CurvePoint(p1_clear + p2_clear))
172 }
173
174 fn hash_to_field(buf: &[u8]) -> C::BaseField {
177 Self::BaseField::from_be_bytes_mod_order(buf)
178 }
179}
180
181impl<C: CurveGroup> From<C> for CurvePoint<C> {
182 fn from(p: C) -> Self {
183 CurvePoint(p)
184 }
185}
186
187impl<C: CurveGroup> Add<&C> for &CurvePoint<C> {
194 type Output = CurvePoint<C>;
195
196 fn add(self, rhs: &C) -> Self::Output {
197 CurvePoint(self.0 + rhs)
198 }
199}
200impl_borrow_variants!(CurvePoint<C>, Add, add, +, C, C: CurveGroup);
201
202impl<C: CurveGroup> Add<&CurvePoint<C>> for &CurvePoint<C> {
203 type Output = CurvePoint<C>;
204
205 fn add(self, rhs: &CurvePoint<C>) -> Self::Output {
206 CurvePoint(self.0 + rhs.0)
207 }
208}
209impl_borrow_variants!(CurvePoint<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
210
/// A fabric result handle whose value is a single `CurvePoint`
pub type CurvePointResult<C> = ResultHandle<C, CurvePoint<C>>;
/// A fabric result handle whose value is a vector of `CurvePoint`s
pub type BatchCurvePointResult<C> = ResultHandle<C, Vec<CurvePoint<C>>>;
215
216impl<C: CurveGroup> Add<&CurvePointResult<C>> for &CurvePointResult<C> {
217 type Output = CurvePointResult<C>;
218
219 fn add(self, rhs: &CurvePointResult<C>) -> Self::Output {
220 self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
221 let lhs: CurvePoint<C> = args[0].to_owned().into();
222 let rhs: CurvePoint<C> = args[1].to_owned().into();
223 ResultValue::Point(CurvePoint(lhs.0 + rhs.0))
224 })
225 }
226}
227impl_borrow_variants!(CurvePointResult<C>, Add, add, +, CurvePointResult<C>, C: CurveGroup);
228
229impl<C: CurveGroup> Add<&CurvePoint<C>> for &CurvePointResult<C> {
230 type Output = CurvePointResult<C>;
231
232 fn add(self, rhs: &CurvePoint<C>) -> Self::Output {
233 let rhs = *rhs;
234 self.fabric.new_gate_op(vec![self.id], move |args| {
235 let lhs: CurvePoint<C> = args[0].to_owned().into();
236 ResultValue::Point(CurvePoint(lhs.0 + rhs.0))
237 })
238 }
239}
240impl_borrow_variants!(CurvePointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
241impl_commutative!(CurvePointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
242
243impl<C: CurveGroup> CurvePointResult<C> {
244 pub fn batch_add(
246 a: &[CurvePointResult<C>],
247 b: &[CurvePointResult<C>],
248 ) -> Vec<CurvePointResult<C>> {
249 assert_eq!(
250 a.len(),
251 b.len(),
252 "batch_add cannot compute on vectors of unequal length"
253 );
254
255 let n = a.len();
256 let fabric = a[0].fabric();
257 let all_ids = a.iter().chain(b.iter()).map(|r| r.id).collect_vec();
258
259 fabric.new_batch_gate_op(all_ids, n , move |mut args| {
260 let a = args.drain(..n).map(CurvePoint::from).collect_vec();
261 let b = args.into_iter().map(CurvePoint::from).collect_vec();
262
263 a.into_iter()
264 .zip(b)
265 .map(|(a, b)| a + b)
266 .map(ResultValue::Point)
267 .collect_vec()
268 })
269 }
270}
271
272impl<C: CurveGroup> AddAssign for CurvePoint<C> {
275 fn add_assign(&mut self, rhs: Self) {
276 self.0 += rhs.0;
277 }
278}
279
280impl<C: CurveGroup> Sub<&CurvePoint<C>> for &CurvePoint<C> {
283 type Output = CurvePoint<C>;
284
285 fn sub(self, rhs: &CurvePoint<C>) -> Self::Output {
286 CurvePoint(self.0 - rhs.0)
287 }
288}
289impl_borrow_variants!(CurvePoint<C>, Sub, sub, -, CurvePoint<C>, C: CurveGroup);
290
291impl<C: CurveGroup> Sub<&CurvePointResult<C>> for &CurvePointResult<C> {
292 type Output = CurvePointResult<C>;
293
294 fn sub(self, rhs: &CurvePointResult<C>) -> Self::Output {
295 self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
296 let lhs: CurvePoint<C> = args[0].to_owned().into();
297 let rhs: CurvePoint<C> = args[1].to_owned().into();
298 ResultValue::Point(CurvePoint(lhs.0 - rhs.0))
299 })
300 }
301}
302impl_borrow_variants!(CurvePointResult<C>, Sub, sub, -, CurvePointResult<C>, C: CurveGroup);
303
304impl<C: CurveGroup> Sub<&CurvePoint<C>> for &CurvePointResult<C> {
305 type Output = CurvePointResult<C>;
306
307 fn sub(self, rhs: &CurvePoint<C>) -> Self::Output {
308 let rhs = *rhs;
309 self.fabric.new_gate_op(vec![self.id], move |args| {
310 let lhs: CurvePoint<C> = args[0].to_owned().into();
311 ResultValue::Point(CurvePoint(lhs.0 - rhs.0))
312 })
313 }
314}
315impl_borrow_variants!(CurvePointResult<C>, Sub, sub, -, CurvePoint<C>, C: CurveGroup);
316
317impl<C: CurveGroup> Sub<&CurvePointResult<C>> for &CurvePoint<C> {
318 type Output = CurvePointResult<C>;
319
320 fn sub(self, rhs: &CurvePointResult<C>) -> Self::Output {
321 let self_owned = *self;
322 rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
323 let rhs: CurvePoint<C> = args[0].to_owned().into();
324 ResultValue::Point(CurvePoint(self_owned.0 - rhs.0))
325 })
326 }
327}
328
329impl<C: CurveGroup> CurvePointResult<C> {
330 pub fn batch_sub(
332 a: &[CurvePointResult<C>],
333 b: &[CurvePointResult<C>],
334 ) -> Vec<CurvePointResult<C>> {
335 assert_eq!(
336 a.len(),
337 b.len(),
338 "batch_sub cannot compute on vectors of unequal length"
339 );
340
341 let n = a.len();
342 let fabric = a[0].fabric();
343 let all_ids = a.iter().chain(b.iter()).map(|r| r.id).collect_vec();
344
345 fabric.new_batch_gate_op(all_ids, n , move |mut args| {
346 let a = args.drain(..n).map(CurvePoint::from).collect_vec();
347 let b = args.into_iter().map(CurvePoint::from).collect_vec();
348
349 a.into_iter()
350 .zip(b)
351 .map(|(a, b)| a - b)
352 .map(ResultValue::Point)
353 .collect_vec()
354 })
355 }
356}
357
358impl<C: CurveGroup> SubAssign for CurvePoint<C> {
361 fn sub_assign(&mut self, rhs: Self) {
362 self.0 -= rhs.0;
363 }
364}
365
366impl<C: CurveGroup> Neg for &CurvePoint<C> {
369 type Output = CurvePoint<C>;
370
371 fn neg(self) -> Self::Output {
372 CurvePoint(-self.0)
373 }
374}
375impl_borrow_variants!(CurvePoint<C>, Neg, neg, -, C: CurveGroup);
376
377impl<C: CurveGroup> Neg for &CurvePointResult<C> {
378 type Output = CurvePointResult<C>;
379
380 fn neg(self) -> Self::Output {
381 self.fabric.new_gate_op(vec![self.id], |args| {
382 let lhs: CurvePoint<C> = args[0].to_owned().into();
383 ResultValue::Point(CurvePoint(-lhs.0))
384 })
385 }
386}
387impl_borrow_variants!(CurvePointResult<C>, Neg, neg, -, C:CurveGroup);
388
389impl<C: CurveGroup> CurvePointResult<C> {
390 pub fn batch_neg(a: &[CurvePointResult<C>]) -> Vec<CurvePointResult<C>> {
392 let n = a.len();
393 let fabric = a[0].fabric();
394 let all_ids = a.iter().map(|r| r.id).collect_vec();
395
396 fabric.new_batch_gate_op(all_ids, n , |args| {
397 args.into_iter()
398 .map(CurvePoint::from)
399 .map(CurvePoint::neg)
400 .map(ResultValue::Point)
401 .collect_vec()
402 })
403 }
404}
405
406impl<C: CurveGroup> Mul<&Scalar<C>> for &CurvePoint<C> {
409 type Output = CurvePoint<C>;
410
411 fn mul(self, rhs: &Scalar<C>) -> Self::Output {
412 CurvePoint(self.0 * rhs.0)
413 }
414}
415impl_borrow_variants!(CurvePoint<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
416impl_commutative!(CurvePoint<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
417
418impl<C: CurveGroup> Mul<&Scalar<C>> for &CurvePointResult<C> {
419 type Output = CurvePointResult<C>;
420
421 fn mul(self, rhs: &Scalar<C>) -> Self::Output {
422 let rhs = *rhs;
423 self.fabric.new_gate_op(vec![self.id], move |args| {
424 let lhs: CurvePoint<C> = args[0].to_owned().into();
425 ResultValue::Point(CurvePoint(lhs.0 * rhs.0))
426 })
427 }
428}
429impl_borrow_variants!(CurvePointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
430impl_commutative!(CurvePointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
431
432impl<C: CurveGroup> Mul<&ScalarResult<C>> for &CurvePoint<C> {
433 type Output = CurvePointResult<C>;
434
435 fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
436 let self_owned = *self;
437 rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
438 let rhs: Scalar<C> = args[0].to_owned().into();
439 ResultValue::Point(CurvePoint(self_owned.0 * rhs.0))
440 })
441 }
442}
443impl_borrow_variants!(CurvePoint<C>, Mul, mul, *, ScalarResult<C>, Output=CurvePointResult<C>, C: CurveGroup);
444impl_commutative!(CurvePoint<C>, Mul, mul, *, ScalarResult<C>, Output=CurvePointResult<C>, C: CurveGroup);
445
446impl<C: CurveGroup> Mul<&ScalarResult<C>> for &CurvePointResult<C> {
447 type Output = CurvePointResult<C>;
448
449 fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
450 self.fabric.new_gate_op(vec![self.id, rhs.id], |mut args| {
451 let lhs: CurvePoint<C> = args.remove(0).into();
452 let rhs: Scalar<C> = args.remove(0).into();
453
454 ResultValue::Point(CurvePoint(lhs.0 * rhs.0))
455 })
456 }
457}
458impl_borrow_variants!(CurvePointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
459impl_commutative!(CurvePointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
460
461impl<C: CurveGroup> CurvePointResult<C> {
462 pub fn batch_mul(a: &[ScalarResult<C>], b: &[CurvePointResult<C>]) -> Vec<CurvePointResult<C>> {
465 assert_eq!(
466 a.len(),
467 b.len(),
468 "batch_mul cannot compute on vectors of unequal length"
469 );
470
471 let n = a.len();
472 let fabric = a[0].fabric();
473 let all_ids = a
474 .iter()
475 .map(|a| a.id())
476 .chain(b.iter().map(|b| b.id()))
477 .collect_vec();
478
479 fabric.new_batch_gate_op(all_ids, n , move |mut args| {
480 let a = args.drain(..n).map(Scalar::from).collect_vec();
481 let b = args.into_iter().map(CurvePoint::from).collect_vec();
482
483 a.into_iter()
484 .zip(b)
485 .map(|(a, b)| a * b)
486 .map(ResultValue::Point)
487 .collect_vec()
488 })
489 }
490
491 pub fn batch_mul_shared(
494 a: &[MpcScalarResult<C>],
495 b: &[CurvePointResult<C>],
496 ) -> Vec<MpcPointResult<C>> {
497 assert_eq!(
498 a.len(),
499 b.len(),
500 "batch_mul_shared cannot compute on vectors of unequal length"
501 );
502
503 let n = a.len();
504 let fabric = a[0].fabric();
505 let all_ids = a
506 .iter()
507 .map(|a| a.id())
508 .chain(b.iter().map(|b| b.id()))
509 .collect_vec();
510
511 fabric
512 .new_batch_gate_op(all_ids, n , move |mut args| {
513 let a = args.drain(..n).map(Scalar::from).collect_vec();
514 let b = args.into_iter().map(CurvePoint::from).collect_vec();
515
516 a.into_iter()
517 .zip(b)
518 .map(|(a, b)| a * b)
519 .map(ResultValue::Point)
520 .collect_vec()
521 })
522 .into_iter()
523 .map(MpcPointResult::from)
524 .collect_vec()
525 }
526
527 pub fn batch_mul_authenticated(
530 a: &[AuthenticatedScalarResult<C>],
531 b: &[CurvePointResult<C>],
532 ) -> Vec<AuthenticatedPointResult<C>> {
533 assert_eq!(
534 a.len(),
535 b.len(),
536 "batch_mul_authenticated cannot compute on vectors of unequal length"
537 );
538
539 let n = a.len();
540 let fabric = a[0].fabric();
541 let all_ids = b
542 .iter()
543 .map(|b| b.id())
544 .chain(a.iter().flat_map(|a| a.ids()))
545 .collect_vec();
546
547 let results = fabric.new_batch_gate_op(
548 all_ids,
549 AUTHENTICATED_POINT_RESULT_LEN * n, move |mut args| {
551 let points: Vec<CurvePoint<C>> =
552 args.drain(..n).map(CurvePoint::from).collect_vec();
553
554 let mut results = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
555
556 for (scalars, point) in args
557 .chunks_exact(AUTHENTICATED_SCALAR_RESULT_LEN)
558 .zip(points.into_iter())
559 {
560 let share = Scalar::from(&scalars[0]);
561 let mac = Scalar::from(&scalars[1]);
562 let public_modifier = Scalar::from(&scalars[2]);
563
564 results.push(ResultValue::Point(point * share));
565 results.push(ResultValue::Point(point * mac));
566 results.push(ResultValue::Point(point * public_modifier));
567 }
568
569 results
570 },
571 );
572
573 AuthenticatedPointResult::from_flattened_iterator(results.into_iter())
574 }
575}
576
577impl<C: CurveGroup> MulAssign<&Scalar<C>> for CurvePoint<C> {
580 fn mul_assign(&mut self, rhs: &Scalar<C>) {
581 self.0 *= rhs.0;
582 }
583}
584
585impl<C: CurveGroup> Sum for CurvePoint<C> {
590 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
591 iter.fold(CurvePoint::identity(), |acc, x| acc + x)
592 }
593}
594
595impl<C: CurveGroup> Sum for CurvePointResult<C> {
596 fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
598 let first = iter.next().expect("empty iterator");
599 iter.fold(first, |acc, x| acc + x)
600 }
601}
602
603impl<C: CurveGroup> CurvePoint<C> {
605 pub fn msm(scalars: &[Scalar<C>], points: &[CurvePoint<C>]) -> CurvePoint<C> {
607 assert_eq!(
608 scalars.len(),
609 points.len(),
610 "msm cannot compute on vectors of unequal length"
611 );
612
613 let n = scalars.len();
614 if n < MSM_SIZE_THRESHOLD {
615 return scalars.iter().zip(points.iter()).map(|(s, p)| s * p).sum();
616 }
617
618 let affine_points = points.iter().map(|p| p.0.into_affine()).collect_vec();
619 let stripped_scalars = scalars.iter().map(|s| s.0).collect_vec();
620 C::msm(&affine_points, &stripped_scalars)
621 .map(CurvePoint)
622 .unwrap()
623 }
624
625 pub fn msm_iter<I, J>(scalars: I, points: J) -> CurvePoint<C>
628 where
629 I: IntoIterator<Item = Scalar<C>>,
630 J: IntoIterator<Item = CurvePoint<C>>,
631 {
632 let mut res = CurvePoint::identity();
633 for (scalar_chunk, point_chunk) in scalars
634 .into_iter()
635 .chunks(MSM_CHUNK_SIZE)
636 .into_iter()
637 .zip(points.into_iter().chunks(MSM_CHUNK_SIZE).into_iter())
638 {
639 let scalars: Vec<Scalar<C>> = scalar_chunk.collect();
640 let points: Vec<CurvePoint<C>> = point_chunk.collect();
641 let chunk_res = CurvePoint::msm(&scalars, &points);
642
643 res += chunk_res;
644 }
645
646 res
647 }
648
649 pub fn msm_results(
652 scalars: &[ScalarResult<C>],
653 points: &[CurvePoint<C>],
654 ) -> CurvePointResult<C> {
655 assert_eq!(
656 scalars.len(),
657 points.len(),
658 "msm cannot compute on vectors of unequal length"
659 );
660
661 let fabric = scalars[0].fabric();
662 let scalar_ids = scalars.iter().map(|s| s.id()).collect_vec();
663
664 let points = points.to_vec();
666 fabric.new_gate_op(scalar_ids, move |args| {
667 let scalars = args.into_iter().map(Scalar::from).collect_vec();
668
669 ResultValue::Point(CurvePoint::msm(&scalars, &points))
670 })
671 }
672
673 pub fn msm_results_iter<I, J>(scalars: I, points: J) -> CurvePointResult<C>
676 where
677 I: IntoIterator<Item = ScalarResult<C>>,
678 J: IntoIterator<Item = CurvePoint<C>>,
679 {
680 Self::msm_results(
681 &scalars.into_iter().collect_vec(),
682 &points.into_iter().collect_vec(),
683 )
684 }
685
686 pub fn msm_authenticated(
689 scalars: &[AuthenticatedScalarResult<C>],
690 points: &[CurvePoint<C>],
691 ) -> AuthenticatedPointResult<C> {
692 assert_eq!(
693 scalars.len(),
694 points.len(),
695 "msm cannot compute on vectors of unequal length"
696 );
697
698 let n = scalars.len();
699 let fabric = scalars[0].fabric();
700 let scalar_ids = scalars.iter().flat_map(|s| s.ids()).collect_vec();
701
702 let points = points.to_vec();
704 let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
705 scalar_ids,
706 AUTHENTICATED_SCALAR_RESULT_LEN, move |args| {
708 let mut shares = Vec::with_capacity(n);
709 let mut macs = Vec::with_capacity(n);
710 let mut modifiers = Vec::with_capacity(n);
711
712 for chunk in args.chunks_exact(AUTHENTICATED_SCALAR_RESULT_LEN) {
713 shares.push(Scalar::from(chunk[0].to_owned()));
714 macs.push(Scalar::from(chunk[1].to_owned()));
715 modifiers.push(Scalar::from(chunk[2].to_owned()));
716 }
717
718 vec![
720 CurvePoint::msm(&shares, &points),
721 CurvePoint::msm(&macs, &points),
722 CurvePoint::msm(&modifiers, &points),
723 ]
724 .into_iter()
725 .map(ResultValue::Point)
726 .collect_vec()
727 },
728 );
729
730 AuthenticatedPointResult {
731 share: res[0].to_owned().into(),
732 mac: res[1].to_owned().into(),
733 public_modifier: res[2].to_owned(),
734 }
735 }
736
737 pub fn msm_authenticated_iter<I, J>(scalars: I, points: J) -> AuthenticatedPointResult<C>
741 where
742 I: IntoIterator<Item = AuthenticatedScalarResult<C>>,
743 J: IntoIterator<Item = CurvePoint<C>>,
744 {
745 let scalars: Vec<AuthenticatedScalarResult<C>> = scalars.into_iter().collect();
746 let points: Vec<CurvePoint<C>> = points.into_iter().collect();
747
748 Self::msm_authenticated(&scalars, &points)
749 }
750}
751
752impl<C: CurveGroup> CurvePointResult<C> {
753 pub fn msm_results(
755 scalars: &[ScalarResult<C>],
756 points: &[CurvePointResult<C>],
757 ) -> CurvePointResult<C> {
758 assert!(!scalars.is_empty(), "msm cannot compute on an empty vector");
759 assert_eq!(
760 scalars.len(),
761 points.len(),
762 "msm cannot compute on vectors of unequal length"
763 );
764
765 let n = scalars.len();
766 let fabric = scalars[0].fabric();
767 let all_ids = scalars
768 .iter()
769 .map(|s| s.id())
770 .chain(points.iter().map(|p| p.id()))
771 .collect_vec();
772
773 fabric.new_gate_op(all_ids, move |mut args| {
774 let scalars = args.drain(..n).map(Scalar::from).collect_vec();
775 let points = args.into_iter().map(CurvePoint::from).collect_vec();
776
777 let res = CurvePoint::msm(&scalars, &points);
778 ResultValue::Point(res)
779 })
780 }
781
782 pub fn msm_results_iter<I, J>(scalars: I, points: J) -> CurvePointResult<C>
787 where
788 I: IntoIterator<Item = ScalarResult<C>>,
789 J: IntoIterator<Item = CurvePointResult<C>>,
790 {
791 Self::msm_results(
792 &scalars.into_iter().collect_vec(),
793 &points.into_iter().collect_vec(),
794 )
795 }
796
797 pub fn msm_authenticated(
800 scalars: &[AuthenticatedScalarResult<C>],
801 points: &[CurvePointResult<C>],
802 ) -> AuthenticatedPointResult<C> {
803 assert_eq!(
804 scalars.len(),
805 points.len(),
806 "msm cannot compute on vectors of unequal length"
807 );
808
809 let n = scalars.len();
810 let fabric = scalars[0].fabric();
811 let all_ids = scalars
812 .iter()
813 .flat_map(|s| s.ids())
814 .chain(points.iter().map(|p| p.id()))
815 .collect_vec();
816
817 let res = fabric.new_batch_gate_op(
818 all_ids,
819 AUTHENTICATED_POINT_RESULT_LEN, move |mut args| {
821 let mut shares = Vec::with_capacity(n);
822 let mut macs = Vec::with_capacity(n);
823 let mut modifiers = Vec::with_capacity(n);
824
825 for mut chunk in args
826 .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
827 .map(Scalar::from)
828 .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
829 .into_iter()
830 {
831 shares.push(chunk.next().unwrap());
832 macs.push(chunk.next().unwrap());
833 modifiers.push(chunk.next().unwrap());
834 }
835
836 let points = args.into_iter().map(CurvePoint::from).collect_vec();
837
838 vec![
839 CurvePoint::msm(&shares, &points),
840 CurvePoint::msm(&macs, &points),
841 CurvePoint::msm(&modifiers, &points),
842 ]
843 .into_iter()
844 .map(ResultValue::Point)
845 .collect_vec()
846 },
847 );
848
849 AuthenticatedPointResult {
850 share: res[0].to_owned().into(),
851 mac: res[1].to_owned().into(),
852 public_modifier: res[2].to_owned(),
853 }
854 }
855
856 pub fn msm_authenticated_iter<I, J>(scalars: I, points: J) -> AuthenticatedPointResult<C>
860 where
861 I: IntoIterator<Item = AuthenticatedScalarResult<C>>,
862 J: IntoIterator<Item = CurvePointResult<C>>,
863 {
864 let scalars: Vec<AuthenticatedScalarResult<C>> = scalars.into_iter().collect();
865 let points: Vec<CurvePointResult<C>> = points.into_iter().collect();
866
867 Self::msm_authenticated(&scalars, &points)
868 }
869}
870
#[cfg(test)]
mod test {
    use rand::thread_rng;

    use crate::{test_helpers::mock_fabric, test_helpers::TestCurve};

    use super::*;

    /// `CurvePoint` instantiated over the test curve
    pub type TestCurvePoint = CurvePoint<TestCurve>;

    /// Sample a point by multiplying the generator twice by one random scalar
    pub fn random_point() -> TestCurvePoint {
        let mut rng = thread_rng();
        let scalar = Scalar::random(&mut rng);
        let scaled_generator = TestCurvePoint::generator() * scalar;
        scaled_generator * scalar
    }

    /// Adding two allocated points matches adding the plain points
    #[tokio::test]
    async fn test_point_addition() {
        let fabric = mock_fabric();

        let lhs = random_point();
        let rhs = random_point();

        let lhs_handle = fabric.allocate_point(lhs);
        let rhs_handle = fabric.allocate_point(rhs);

        let actual = (lhs_handle + rhs_handle).await;
        let expected = lhs + rhs;

        assert_eq!(actual, expected);
        fabric.shutdown();
    }

    /// Multiplying an allocated scalar by an allocated point matches the
    /// plain product
    #[tokio::test]
    async fn test_scalar_mul() {
        let fabric = mock_fabric();

        let mut rng = thread_rng();
        let scalar = Scalar::<TestCurve>::random(&mut rng);
        let point = random_point();

        let scalar_handle = fabric.allocate_scalar(scalar);
        let point_handle = fabric.allocate_point(point);

        let actual = (scalar_handle * point_handle).await;
        let expected = scalar * point;

        assert_eq!(actual, expected);
        fabric.shutdown();
    }

    /// Adding the fabric's curve identity leaves a point unchanged
    #[tokio::test]
    async fn test_additive_identity() {
        let fabric = mock_fabric();

        let point = random_point();

        let point_handle = fabric.allocate_point(point);
        let identity_handle = fabric.curve_identity();

        let actual = (point_handle + identity_handle).await;
        let expected = point;

        assert_eq!(actual, expected);
        fabric.shutdown();
    }
}