1use std::{
4 iter::Sum,
5 mem::size_of,
6 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
7};
8
9use ark_ec::{
10 hashing::{
11 curve_maps::swu::{SWUConfig, SWUMap},
12 map_to_curve_hasher::MapToCurve,
13 HashToCurveError,
14 },
15 short_weierstrass::{Affine, Projective, SWCurveConfig},
16 CurveConfig, CurveGroup, Group, VariableBaseMSM,
17};
18use ark_ff::{BigInt, MontFp, PrimeField, Zero};
19
20use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError};
21use itertools::Itertools;
22use num_bigint::BigUint;
23use serde::{de::Error as DeError, Deserialize, Serialize};
24
25use crate::{
26 algebra::{
27 authenticated_scalar::AUTHENTICATED_SCALAR_RESULT_LEN,
28 authenticated_stark_point::AUTHENTICATED_STARK_POINT_RESULT_LEN,
29 },
30 fabric::{ResultHandle, ResultValue},
31};
32
33use super::{
34 authenticated_scalar::AuthenticatedScalarResult,
35 authenticated_stark_point::AuthenticatedStarkPointResult,
36 macros::{impl_borrow_variants, impl_commutative},
37 mpc_scalar::MpcScalarResult,
38 mpc_stark_point::MpcStarkPointResult,
39 scalar::{Scalar, ScalarInner, ScalarResult, StarknetBaseFelt, BASE_FIELD_BYTES},
40};
41
42const MSM_CHUNK_SIZE: usize = 1 << 16;
45const MSM_SIZE_THRESHOLD: usize = 10;
50
51pub const HASH_TO_CURVE_SECURITY: usize = 16; pub const STARK_POINT_BYTES: usize = 32;
55pub const STARK_UNIFORM_BYTES: usize = 2 * (BASE_FIELD_BYTES + HASH_TO_CURVE_SECURITY);
58
/// Marker type carrying the arkworks curve parameters for the Starknet curve
pub struct StarknetCurveConfig;
61impl CurveConfig for StarknetCurveConfig {
62 type BaseField = StarknetBaseFelt;
63 type ScalarField = ScalarInner;
64
65 const COFACTOR: &'static [u64] = &[1];
66 const COFACTOR_INV: Self::ScalarField = MontFp!("1");
67}
68
69impl SWCurveConfig for StarknetCurveConfig {
72 const COEFF_A: Self::BaseField = MontFp!("1");
73 const COEFF_B: Self::BaseField =
74 MontFp!("3141592653589793238462643383279502884197169399375105820974944592307816406665");
75
76 const GENERATOR: Affine<Self> = Affine {
77 x: MontFp!("874739451078007766457464989774322083649278607533249481151382481072868806602"),
78 y: MontFp!("152666792071518830868575557812948353041420400780739481342941381225525861407"),
79 infinity: false,
80 };
81}
82
83impl SWUConfig for StarknetCurveConfig {
85 const ZETA: Self::BaseField = MontFp!("3");
86}
87
/// The arkworks projective point type wrapped by `StarkPoint`
pub(crate) type StarkPointInner = Projective<StarknetCurveConfig>;
90#[derive(Copy, Clone, Debug, PartialEq, Eq)]
92pub struct StarkPoint(pub(crate) StarkPointInner);
93
94impl Serialize for StarkPoint {
95 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
96 let bytes = self.to_bytes();
97 bytes.serialize(serializer)
98 }
99}
100
101impl<'de> Deserialize<'de> for StarkPoint {
102 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
103 let bytes = <Vec<u8>>::deserialize(deserializer)?;
104 StarkPoint::from_bytes(&bytes)
105 .map_err(|err| DeError::custom(format!("Failed to deserialize point: {err:?}")))
106 }
107}
108
109impl StarkPoint {
114 pub fn identity() -> StarkPoint {
116 StarkPoint(StarkPointInner::zero())
117 }
118
119 pub fn is_identity(&self) -> bool {
121 self == &StarkPoint::identity()
122 }
123
124 pub fn to_affine(&self) -> Affine<StarknetCurveConfig> {
126 self.0.into_affine()
127 }
128
129 pub fn from_affine_coords(x: BigUint, y: BigUint) -> Self {
131 let x_bigint = BigInt::try_from(x).unwrap();
132 let y_bigint = BigInt::try_from(y).unwrap();
133 let x = StarknetBaseFelt::from(x_bigint);
134 let y = StarknetBaseFelt::from(y_bigint);
135
136 let aff = Affine {
137 x,
138 y,
139 infinity: false,
140 };
141
142 Self(aff.into())
143 }
144
145 pub fn generator() -> StarkPoint {
147 StarkPoint(StarkPointInner::generator())
148 }
149
150 pub fn to_bytes(&self) -> Vec<u8> {
152 let mut out: Vec<u8> = Vec::with_capacity(size_of::<StarkPoint>());
153 self.0
154 .serialize_compressed(&mut out)
155 .expect("Failed to serialize point");
156
157 out
158 }
159
160 pub fn from_bytes(bytes: &[u8]) -> Result<StarkPoint, SerializationError> {
162 let point = StarkPointInner::deserialize_compressed(bytes)?;
163 Ok(StarkPoint(point))
164 }
165
166 pub fn from_uniform_bytes(
174 buf: [u8; STARK_UNIFORM_BYTES],
175 ) -> Result<StarkPoint, HashToCurveError> {
176 let f1 = Self::hash_to_field(&buf[..STARK_UNIFORM_BYTES / 2]);
178 let f2 = Self::hash_to_field(&buf[STARK_UNIFORM_BYTES / 2..]);
179
180 let mapper = SWUMap::<StarknetCurveConfig>::new()?;
182 let p1 = mapper.map_to_curve(f1)?;
183 let p2 = mapper.map_to_curve(f2)?;
184
185 Ok(StarkPoint(p1 + p2))
188 }
189
190 fn hash_to_field(buf: &[u8]) -> StarknetBaseFelt {
192 StarknetBaseFelt::from_be_bytes_mod_order(buf)
193 }
194}
195
196impl From<StarkPointInner> for StarkPoint {
197 fn from(p: StarkPointInner) -> Self {
198 StarkPoint(p)
199 }
200}
201
202impl Add<&StarkPointInner> for &StarkPoint {
209 type Output = StarkPoint;
210
211 fn add(self, rhs: &StarkPointInner) -> Self::Output {
212 StarkPoint(self.0 + rhs)
213 }
214}
215impl_borrow_variants!(StarkPoint, Add, add, +, StarkPointInner);
216impl_commutative!(StarkPoint, Add, add, +, StarkPointInner);
217
218impl Add<&StarkPoint> for &StarkPoint {
219 type Output = StarkPoint;
220
221 fn add(self, rhs: &StarkPoint) -> Self::Output {
222 StarkPoint(self.0 + rhs.0)
223 }
224}
225impl_borrow_variants!(StarkPoint, Add, add, +, StarkPoint);
226
/// A handle to a `StarkPoint` that is the result of a computation in the fabric
pub type StarkPointResult = ResultHandle<StarkPoint>;
/// A handle to a vector of `StarkPoint`s computed in the fabric
pub type BatchStarkPointResult = ResultHandle<Vec<StarkPoint>>;
231
232impl Add<&StarkPointResult> for &StarkPointResult {
233 type Output = StarkPointResult;
234
235 fn add(self, rhs: &StarkPointResult) -> Self::Output {
236 self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
237 let lhs: StarkPoint = args[0].to_owned().into();
238 let rhs: StarkPoint = args[1].to_owned().into();
239 ResultValue::Point(StarkPoint(lhs.0 + rhs.0))
240 })
241 }
242}
243impl_borrow_variants!(StarkPointResult, Add, add, +, StarkPointResult);
244
245impl Add<&StarkPoint> for &StarkPointResult {
246 type Output = StarkPointResult;
247
248 fn add(self, rhs: &StarkPoint) -> Self::Output {
249 let rhs = *rhs;
250 self.fabric.new_gate_op(vec![self.id], move |args| {
251 let lhs: StarkPoint = args[0].to_owned().into();
252 ResultValue::Point(StarkPoint(lhs.0 + rhs.0))
253 })
254 }
255}
256impl_borrow_variants!(StarkPointResult, Add, add, +, StarkPoint);
257impl_commutative!(StarkPointResult, Add, add, +, StarkPoint);
258
259impl StarkPointResult {
260 pub fn batch_add(a: &[StarkPointResult], b: &[StarkPointResult]) -> Vec<StarkPointResult> {
262 assert_eq!(
263 a.len(),
264 b.len(),
265 "batch_add cannot compute on vectors of unequal length"
266 );
267
268 let n = a.len();
269 let fabric = a[0].fabric();
270 let all_ids = a.iter().chain(b.iter()).map(|r| r.id).collect_vec();
271
272 fabric.new_batch_gate_op(all_ids, n , move |mut args| {
273 let a = args.drain(..n).map(StarkPoint::from).collect_vec();
274 let b = args.into_iter().map(StarkPoint::from).collect_vec();
275
276 a.into_iter()
277 .zip(b.into_iter())
278 .map(|(a, b)| a + b)
279 .map(ResultValue::Point)
280 .collect_vec()
281 })
282 }
283}
284
285impl AddAssign for StarkPoint {
288 fn add_assign(&mut self, rhs: Self) {
289 self.0 += rhs.0;
290 }
291}
292
293impl Sub<&StarkPoint> for &StarkPoint {
296 type Output = StarkPoint;
297
298 fn sub(self, rhs: &StarkPoint) -> Self::Output {
299 StarkPoint(self.0 - rhs.0)
300 }
301}
302impl_borrow_variants!(StarkPoint, Sub, sub, -, StarkPoint);
303
304impl Sub<&StarkPointResult> for &StarkPointResult {
305 type Output = StarkPointResult;
306
307 fn sub(self, rhs: &StarkPointResult) -> Self::Output {
308 self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
309 let lhs: StarkPoint = args[0].to_owned().into();
310 let rhs: StarkPoint = args[1].to_owned().into();
311 ResultValue::Point(StarkPoint(lhs.0 - rhs.0))
312 })
313 }
314}
315impl_borrow_variants!(StarkPointResult, Sub, sub, -, StarkPointResult);
316
317impl Sub<&StarkPoint> for &StarkPointResult {
318 type Output = StarkPointResult;
319
320 fn sub(self, rhs: &StarkPoint) -> Self::Output {
321 let rhs = *rhs;
322 self.fabric.new_gate_op(vec![self.id], move |args| {
323 let lhs: StarkPoint = args[0].to_owned().into();
324 ResultValue::Point(StarkPoint(lhs.0 - rhs.0))
325 })
326 }
327}
328impl_borrow_variants!(StarkPointResult, Sub, sub, -, StarkPoint);
329
330impl Sub<&StarkPointResult> for &StarkPoint {
331 type Output = StarkPointResult;
332
333 fn sub(self, rhs: &StarkPointResult) -> Self::Output {
334 let self_owned = *self;
335 rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
336 let rhs: StarkPoint = args[0].to_owned().into();
337 ResultValue::Point(StarkPoint(self_owned.0 - rhs.0))
338 })
339 }
340}
341
342impl StarkPointResult {
343 pub fn batch_sub(a: &[StarkPointResult], b: &[StarkPointResult]) -> Vec<StarkPointResult> {
345 assert_eq!(
346 a.len(),
347 b.len(),
348 "batch_sub cannot compute on vectors of unequal length"
349 );
350
351 let n = a.len();
352 let fabric = a[0].fabric();
353 let all_ids = a.iter().chain(b.iter()).map(|r| r.id).collect_vec();
354
355 fabric.new_batch_gate_op(all_ids, n , move |mut args| {
356 let a = args.drain(..n).map(StarkPoint::from).collect_vec();
357 let b = args.into_iter().map(StarkPoint::from).collect_vec();
358
359 a.into_iter()
360 .zip(b.into_iter())
361 .map(|(a, b)| a - b)
362 .map(ResultValue::Point)
363 .collect_vec()
364 })
365 }
366}
367
368impl SubAssign for StarkPoint {
371 fn sub_assign(&mut self, rhs: Self) {
372 self.0 -= rhs.0;
373 }
374}
375
376impl Neg for &StarkPoint {
379 type Output = StarkPoint;
380
381 fn neg(self) -> Self::Output {
382 StarkPoint(-self.0)
383 }
384}
385impl_borrow_variants!(StarkPoint, Neg, neg, -);
386
387impl Neg for &StarkPointResult {
388 type Output = StarkPointResult;
389
390 fn neg(self) -> Self::Output {
391 self.fabric.new_gate_op(vec![self.id], |args| {
392 let lhs: StarkPoint = args[0].to_owned().into();
393 ResultValue::Point(StarkPoint(-lhs.0))
394 })
395 }
396}
397impl_borrow_variants!(StarkPointResult, Neg, neg, -);
398
399impl StarkPointResult {
400 pub fn batch_neg(a: &[StarkPointResult]) -> Vec<StarkPointResult> {
402 let n = a.len();
403 let fabric = a[0].fabric();
404 let all_ids = a.iter().map(|r| r.id).collect_vec();
405
406 fabric.new_batch_gate_op(all_ids, n , |args| {
407 args.into_iter()
408 .map(StarkPoint::from)
409 .map(StarkPoint::neg)
410 .map(ResultValue::Point)
411 .collect_vec()
412 })
413 }
414}
415
416impl Mul<&Scalar> for &StarkPoint {
419 type Output = StarkPoint;
420
421 fn mul(self, rhs: &Scalar) -> Self::Output {
422 StarkPoint(self.0 * rhs.0)
423 }
424}
425impl_borrow_variants!(StarkPoint, Mul, mul, *, Scalar);
426impl_commutative!(StarkPoint, Mul, mul, *, Scalar);
427
428impl Mul<&Scalar> for &StarkPointResult {
429 type Output = StarkPointResult;
430
431 fn mul(self, rhs: &Scalar) -> Self::Output {
432 let rhs = *rhs;
433 self.fabric.new_gate_op(vec![self.id], move |args| {
434 let lhs: StarkPoint = args[0].to_owned().into();
435 ResultValue::Point(StarkPoint(lhs.0 * rhs.0))
436 })
437 }
438}
439impl_borrow_variants!(StarkPointResult, Mul, mul, *, Scalar);
440impl_commutative!(StarkPointResult, Mul, mul, *, Scalar);
441
442impl Mul<&ScalarResult> for &StarkPoint {
443 type Output = StarkPointResult;
444
445 fn mul(self, rhs: &ScalarResult) -> Self::Output {
446 let self_owned = *self;
447 rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
448 let rhs: Scalar = args[0].to_owned().into();
449 ResultValue::Point(StarkPoint(self_owned.0 * rhs.0))
450 })
451 }
452}
453impl_borrow_variants!(StarkPoint, Mul, mul, *, ScalarResult, Output=StarkPointResult);
454impl_commutative!(StarkPoint, Mul, mul, *, ScalarResult, Output=StarkPointResult);
455
456impl Mul<&ScalarResult> for &StarkPointResult {
457 type Output = StarkPointResult;
458
459 fn mul(self, rhs: &ScalarResult) -> Self::Output {
460 self.fabric.new_gate_op(vec![self.id, rhs.id], |mut args| {
461 let lhs: StarkPoint = args.remove(0).into();
462 let rhs: Scalar = args.remove(0).into();
463
464 ResultValue::Point(StarkPoint(lhs.0 * rhs.0))
465 })
466 }
467}
468impl_borrow_variants!(StarkPointResult, Mul, mul, *, ScalarResult);
469impl_commutative!(StarkPointResult, Mul, mul, *, ScalarResult);
470
471impl StarkPointResult {
472 pub fn batch_mul(a: &[ScalarResult], b: &[StarkPointResult]) -> Vec<StarkPointResult> {
474 assert_eq!(
475 a.len(),
476 b.len(),
477 "batch_mul cannot compute on vectors of unequal length"
478 );
479
480 let n = a.len();
481 let fabric = a[0].fabric();
482 let all_ids = a
483 .iter()
484 .map(|a| a.id())
485 .chain(b.iter().map(|b| b.id()))
486 .collect_vec();
487
488 fabric.new_batch_gate_op(all_ids, n , move |mut args| {
489 let a = args.drain(..n).map(Scalar::from).collect_vec();
490 let b = args.into_iter().map(StarkPoint::from).collect_vec();
491
492 a.into_iter()
493 .zip(b.into_iter())
494 .map(|(a, b)| a * b)
495 .map(ResultValue::Point)
496 .collect_vec()
497 })
498 }
499
500 pub fn batch_mul_shared(
502 a: &[MpcScalarResult],
503 b: &[StarkPointResult],
504 ) -> Vec<MpcStarkPointResult> {
505 assert_eq!(
506 a.len(),
507 b.len(),
508 "batch_mul_shared cannot compute on vectors of unequal length"
509 );
510
511 let n = a.len();
512 let fabric = a[0].fabric();
513 let all_ids = a
514 .iter()
515 .map(|a| a.id())
516 .chain(b.iter().map(|b| b.id()))
517 .collect_vec();
518
519 fabric
520 .new_batch_gate_op(all_ids, n , move |mut args| {
521 let a = args.drain(..n).map(Scalar::from).collect_vec();
522 let b = args.into_iter().map(StarkPoint::from).collect_vec();
523
524 a.into_iter()
525 .zip(b.into_iter())
526 .map(|(a, b)| a * b)
527 .map(ResultValue::Point)
528 .collect_vec()
529 })
530 .into_iter()
531 .map(MpcStarkPointResult::from)
532 .collect_vec()
533 }
534
535 pub fn batch_mul_authenticated(
537 a: &[AuthenticatedScalarResult],
538 b: &[StarkPointResult],
539 ) -> Vec<AuthenticatedStarkPointResult> {
540 assert_eq!(
541 a.len(),
542 b.len(),
543 "batch_mul_authenticated cannot compute on vectors of unequal length"
544 );
545
546 let n = a.len();
547 let fabric = a[0].fabric();
548 let all_ids = b
549 .iter()
550 .map(|b| b.id())
551 .chain(a.iter().flat_map(|a| a.ids()))
552 .collect_vec();
553
554 let results = fabric.new_batch_gate_op(
555 all_ids,
556 AUTHENTICATED_STARK_POINT_RESULT_LEN * n, move |mut args| {
558 let points: Vec<StarkPoint> = args.drain(..n).map(StarkPoint::from).collect_vec();
559
560 let mut results = Vec::with_capacity(AUTHENTICATED_STARK_POINT_RESULT_LEN * n);
561
562 for (scalars, point) in args
563 .chunks_exact(AUTHENTICATED_SCALAR_RESULT_LEN)
564 .zip(points.into_iter())
565 {
566 let share = Scalar::from(&scalars[0]);
567 let mac = Scalar::from(&scalars[1]);
568 let public_modifier = Scalar::from(&scalars[2]);
569
570 results.push(ResultValue::Point(point * share));
571 results.push(ResultValue::Point(point * mac));
572 results.push(ResultValue::Point(point * public_modifier));
573 }
574
575 results
576 },
577 );
578
579 AuthenticatedStarkPointResult::from_flattened_iterator(results.into_iter())
580 }
581}
582
583impl MulAssign<&Scalar> for StarkPoint {
586 fn mul_assign(&mut self, rhs: &Scalar) {
587 self.0 *= rhs.0;
588 }
589}
590
591impl Sum for StarkPoint {
596 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
597 iter.fold(StarkPoint::identity(), |acc, x| acc + x)
598 }
599}
600
601impl Sum for StarkPointResult {
602 fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
604 let first = iter.next().expect("empty iterator");
605 iter.fold(first, |acc, x| acc + x)
606 }
607}
608
609impl StarkPoint {
611 pub fn msm(scalars: &[Scalar], points: &[StarkPoint]) -> StarkPoint {
613 assert_eq!(
614 scalars.len(),
615 points.len(),
616 "msm cannot compute on vectors of unequal length"
617 );
618
619 let n = scalars.len();
620 if n < MSM_SIZE_THRESHOLD {
621 return scalars.iter().zip(points.iter()).map(|(s, p)| s * p).sum();
622 }
623
624 let affine_points = points.iter().map(|p| p.0.into_affine()).collect_vec();
625 let stripped_scalars = scalars.iter().map(|s| s.0).collect_vec();
626 StarkPointInner::msm(&affine_points, &stripped_scalars)
627 .map(StarkPoint)
628 .unwrap()
629 }
630
631 pub fn msm_iter<I, J>(scalars: I, points: J) -> StarkPoint
634 where
635 I: IntoIterator<Item = Scalar>,
636 J: IntoIterator<Item = StarkPoint>,
637 {
638 let mut res = StarkPoint::identity();
639 for (scalar_chunk, point_chunk) in scalars
640 .into_iter()
641 .chunks(MSM_CHUNK_SIZE)
642 .into_iter()
643 .zip(points.into_iter().chunks(MSM_CHUNK_SIZE).into_iter())
644 {
645 let scalars: Vec<Scalar> = scalar_chunk.collect();
646 let points: Vec<StarkPoint> = point_chunk.collect();
647 let chunk_res = StarkPoint::msm(&scalars, &points);
648
649 res += chunk_res;
650 }
651
652 res
653 }
654
655 pub fn msm_results(scalars: &[ScalarResult], points: &[StarkPoint]) -> StarkPointResult {
657 assert_eq!(
658 scalars.len(),
659 points.len(),
660 "msm cannot compute on vectors of unequal length"
661 );
662
663 let fabric = scalars[0].fabric();
664 let scalar_ids = scalars.iter().map(|s| s.id()).collect_vec();
665
666 let points = points.to_vec();
668 fabric.new_gate_op(scalar_ids, move |args| {
669 let scalars = args.into_iter().map(Scalar::from).collect_vec();
670
671 ResultValue::Point(StarkPoint::msm(&scalars, &points))
672 })
673 }
674
675 pub fn msm_results_iter<I, J>(scalars: I, points: J) -> StarkPointResult
678 where
679 I: IntoIterator<Item = ScalarResult>,
680 J: IntoIterator<Item = StarkPoint>,
681 {
682 Self::msm_results(
683 &scalars.into_iter().collect_vec(),
684 &points.into_iter().collect_vec(),
685 )
686 }
687
688 pub fn msm_authenticated(
690 scalars: &[AuthenticatedScalarResult],
691 points: &[StarkPoint],
692 ) -> AuthenticatedStarkPointResult {
693 assert_eq!(
694 scalars.len(),
695 points.len(),
696 "msm cannot compute on vectors of unequal length"
697 );
698
699 let n = scalars.len();
700 let fabric = scalars[0].fabric();
701 let scalar_ids = scalars.iter().flat_map(|s| s.ids()).collect_vec();
702
703 let points = points.to_vec();
705 let res: Vec<StarkPointResult> = fabric.new_batch_gate_op(
706 scalar_ids,
707 AUTHENTICATED_SCALAR_RESULT_LEN, move |args| {
709 let mut shares = Vec::with_capacity(n);
710 let mut macs = Vec::with_capacity(n);
711 let mut modifiers = Vec::with_capacity(n);
712
713 for chunk in args.chunks_exact(AUTHENTICATED_SCALAR_RESULT_LEN) {
714 shares.push(Scalar::from(chunk[0].to_owned()));
715 macs.push(Scalar::from(chunk[1].to_owned()));
716 modifiers.push(Scalar::from(chunk[2].to_owned()));
717 }
718
719 vec![
721 StarkPoint::msm(&shares, &points),
722 StarkPoint::msm(&macs, &points),
723 StarkPoint::msm(&modifiers, &points),
724 ]
725 .into_iter()
726 .map(ResultValue::Point)
727 .collect_vec()
728 },
729 );
730
731 AuthenticatedStarkPointResult {
732 share: res[0].to_owned().into(),
733 mac: res[1].to_owned().into(),
734 public_modifier: res[2].to_owned(),
735 }
736 }
737
738 pub fn msm_authenticated_iter<I, J>(scalars: I, points: J) -> AuthenticatedStarkPointResult
742 where
743 I: IntoIterator<Item = AuthenticatedScalarResult>,
744 J: IntoIterator<Item = StarkPoint>,
745 {
746 let scalars: Vec<AuthenticatedScalarResult> = scalars.into_iter().collect();
747 let points: Vec<StarkPoint> = points.into_iter().collect();
748
749 Self::msm_authenticated(&scalars, &points)
750 }
751}
752
753impl StarkPointResult {
754 pub fn msm_results(scalars: &[ScalarResult], points: &[StarkPointResult]) -> StarkPointResult {
756 assert!(!scalars.is_empty(), "msm cannot compute on an empty vector");
757 assert_eq!(
758 scalars.len(),
759 points.len(),
760 "msm cannot compute on vectors of unequal length"
761 );
762
763 let n = scalars.len();
764 let fabric = scalars[0].fabric();
765 let all_ids = scalars
766 .iter()
767 .map(|s| s.id())
768 .chain(points.iter().map(|p| p.id()))
769 .collect_vec();
770
771 fabric.new_gate_op(all_ids, move |mut args| {
772 let scalars = args.drain(..n).map(Scalar::from).collect_vec();
773 let points = args.into_iter().map(StarkPoint::from).collect_vec();
774
775 let res = StarkPoint::msm(&scalars, &points);
776 ResultValue::Point(res)
777 })
778 }
779
780 pub fn msm_results_iter<I, J>(scalars: I, points: J) -> StarkPointResult
785 where
786 I: IntoIterator<Item = ScalarResult>,
787 J: IntoIterator<Item = StarkPointResult>,
788 {
789 Self::msm_results(
790 &scalars.into_iter().collect_vec(),
791 &points.into_iter().collect_vec(),
792 )
793 }
794
795 pub fn msm_authenticated(
797 scalars: &[AuthenticatedScalarResult],
798 points: &[StarkPointResult],
799 ) -> AuthenticatedStarkPointResult {
800 assert_eq!(
801 scalars.len(),
802 points.len(),
803 "msm cannot compute on vectors of unequal length"
804 );
805
806 let n = scalars.len();
807 let fabric = scalars[0].fabric();
808 let all_ids = scalars
809 .iter()
810 .flat_map(|s| s.ids())
811 .chain(points.iter().map(|p| p.id()))
812 .collect_vec();
813
814 let res = fabric.new_batch_gate_op(
815 all_ids,
816 AUTHENTICATED_STARK_POINT_RESULT_LEN, move |mut args| {
818 let mut shares = Vec::with_capacity(n);
819 let mut macs = Vec::with_capacity(n);
820 let mut modifiers = Vec::with_capacity(n);
821
822 for mut chunk in args
823 .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
824 .map(Scalar::from)
825 .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
826 .into_iter()
827 {
828 shares.push(chunk.next().unwrap());
829 macs.push(chunk.next().unwrap());
830 modifiers.push(chunk.next().unwrap());
831 }
832
833 let points = args.into_iter().map(StarkPoint::from).collect_vec();
834
835 vec![
836 StarkPoint::msm(&shares, &points),
837 StarkPoint::msm(&macs, &points),
838 StarkPoint::msm(&modifiers, &points),
839 ]
840 .into_iter()
841 .map(ResultValue::Point)
842 .collect_vec()
843 },
844 );
845
846 AuthenticatedStarkPointResult {
847 share: res[0].to_owned().into(),
848 mac: res[1].to_owned().into(),
849 public_modifier: res[2].to_owned(),
850 }
851 }
852
853 pub fn msm_authenticated_iter<I, J>(scalars: I, points: J) -> AuthenticatedStarkPointResult
856 where
857 I: IntoIterator<Item = AuthenticatedScalarResult>,
858 J: IntoIterator<Item = StarkPointResult>,
859 {
860 let scalars: Vec<AuthenticatedScalarResult> = scalars.into_iter().collect();
861 let points: Vec<StarkPointResult> = points.into_iter().collect();
862
863 Self::msm_authenticated(&scalars, &points)
864 }
865}
866
#[cfg(test)]
/// Unit tests for `StarkPoint`, cross-checked against `starknet-rs` where possible
mod test {
    use rand::{thread_rng, RngCore};
    use starknet_curve::{curve_params::GENERATOR, ProjectivePoint};

    use crate::algebra::test_helper::{
        arkworks_point_to_starknet, compare_points, prime_field_to_starknet_felt, random_point,
        starknet_rs_scalar_mul,
    };

    use super::*;

    /// The arkworks generator must match the `starknet-rs` generator
    #[test]
    fn test_generators() {
        let ours = StarkPoint::generator();
        let reference = ProjectivePoint::from_affine_point(&GENERATOR);

        assert!(compare_points(&ours, &reference));
    }

    /// Point addition must agree with `starknet-rs`
    #[test]
    fn test_point_addition() {
        let (a, b) = (random_point(), random_point());
        let (a_ref, b_ref) = (
            arkworks_point_to_starknet(&a),
            arkworks_point_to_starknet(&b),
        );

        let sum = a + b;
        let mut sum_ref = a_ref;
        sum_ref += &b_ref;

        assert!(compare_points(&sum, &sum_ref));
    }

    /// Scalar multiplication must agree with `starknet-rs`
    #[test]
    fn test_scalar_mul() {
        let mut rng = thread_rng();
        let scalar = Scalar::random(&mut rng);
        let point = random_point();

        let scalar_ref = prime_field_to_starknet_felt(&scalar.0);
        let point_ref = arkworks_point_to_starknet(&point);

        let product = point * scalar;
        let expected = starknet_rs_scalar_mul(&scalar_ref, &point_ref);

        assert!(compare_points(&product, &expected));
    }

    /// Adding the identity leaves a point unchanged
    #[test]
    fn test_additive_identity() {
        let point = random_point();
        let shifted = point + StarkPoint::identity();

        assert_eq!(point, shifted);
    }

    /// Serialization round-trips and produces `STARK_POINT_BYTES` bytes
    #[test]
    fn test_point_serialized() {
        let original = random_point();
        let encoded = original.to_bytes();

        assert_eq!(encoded.len(), STARK_POINT_BYTES);

        let decoded = StarkPoint::from_bytes(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Hashing random uniform bytes to the curve succeeds
    #[test]
    fn test_hash_to_curve() {
        let mut buf = [0u8; STARK_UNIFORM_BYTES];
        thread_rng().fill_bytes(&mut buf);

        let mapped = StarkPoint::from_uniform_bytes(buf);
        assert!(mapped.is_ok())
    }

    /// Converting to affine coordinates and back recovers the point
    #[test]
    fn test_to_from_affine_coords() {
        let point = random_point();
        let affine = point.to_affine();

        let recovered =
            StarkPoint::from_affine_coords(BigUint::from(affine.x), BigUint::from(affine.y));

        assert_eq!(point, recovered);
    }
}