mpc_stark/algebra/
stark_curve.rs

//! Defines the `StarkPoint` type, a point on the Stark curve, together with its curve configuration and arithmetic
2
3use std::{
4    iter::Sum,
5    mem::size_of,
6    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
7};
8
9use ark_ec::{
10    hashing::{
11        curve_maps::swu::{SWUConfig, SWUMap},
12        map_to_curve_hasher::MapToCurve,
13        HashToCurveError,
14    },
15    short_weierstrass::{Affine, Projective, SWCurveConfig},
16    CurveConfig, CurveGroup, Group, VariableBaseMSM,
17};
18use ark_ff::{BigInt, MontFp, PrimeField, Zero};
19
20use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError};
21use itertools::Itertools;
22use num_bigint::BigUint;
23use serde::{de::Error as DeError, Deserialize, Serialize};
24
25use crate::{
26    algebra::{
27        authenticated_scalar::AUTHENTICATED_SCALAR_RESULT_LEN,
28        authenticated_stark_point::AUTHENTICATED_STARK_POINT_RESULT_LEN,
29    },
30    fabric::{ResultHandle, ResultValue},
31};
32
33use super::{
34    authenticated_scalar::AuthenticatedScalarResult,
35    authenticated_stark_point::AuthenticatedStarkPointResult,
36    macros::{impl_borrow_variants, impl_commutative},
37    mpc_scalar::MpcScalarResult,
38    mpc_stark_point::MpcStarkPointResult,
39    scalar::{Scalar, ScalarInner, ScalarResult, StarknetBaseFelt, BASE_FIELD_BYTES},
40};
41
42/// The number of points and scalars to pull from an iterated MSM when
43/// performing a multiscalar multiplication
44const MSM_CHUNK_SIZE: usize = 1 << 16;
45/// The threshold at which we call out to the Arkworks MSM implementation
46///
47/// MSM sizes below this threshold are computed serially as the parallelism overhead is
48/// too significant
49const MSM_SIZE_THRESHOLD: usize = 10;
50
51/// The security level used in the hash-to-curve implementation, in bytes
52pub const HASH_TO_CURVE_SECURITY: usize = 16; // 128 bit security
53/// The number of bytes needed to serialize a `StarkPoint`
54pub const STARK_POINT_BYTES: usize = 32;
55/// The number of uniformly distributed bytes needed to construct a uniformly
56/// distributed Stark point
57pub const STARK_UNIFORM_BYTES: usize = 2 * (BASE_FIELD_BYTES + HASH_TO_CURVE_SECURITY);
58
59/// The Stark curve in the arkworks short Weierstrass curve representation
60pub struct StarknetCurveConfig;
61impl CurveConfig for StarknetCurveConfig {
62    type BaseField = StarknetBaseFelt;
63    type ScalarField = ScalarInner;
64
65    const COFACTOR: &'static [u64] = &[1];
66    const COFACTOR_INV: Self::ScalarField = MontFp!("1");
67}
68
69/// See https://docs.starkware.co/starkex/crypto/stark-curve.html
70/// for curve parameters
71impl SWCurveConfig for StarknetCurveConfig {
72    const COEFF_A: Self::BaseField = MontFp!("1");
73    const COEFF_B: Self::BaseField =
74        MontFp!("3141592653589793238462643383279502884197169399375105820974944592307816406665");
75
76    const GENERATOR: Affine<Self> = Affine {
77        x: MontFp!("874739451078007766457464989774322083649278607533249481151382481072868806602"),
78        y: MontFp!("152666792071518830868575557812948353041420400780739481342941381225525861407"),
79        infinity: false,
80    };
81}
82
83/// Defines the \zeta constant for the SWU map to curve implementation
84impl SWUConfig for StarknetCurveConfig {
85    const ZETA: Self::BaseField = MontFp!("3");
86}
87
88/// A type alias for a projective curve point on the Stark curve
89pub(crate) type StarkPointInner = Projective<StarknetCurveConfig>;
90/// A wrapper around the inner point that allows us to define foreign traits on the point
91#[derive(Copy, Clone, Debug, PartialEq, Eq)]
92pub struct StarkPoint(pub(crate) StarkPointInner);
93
94impl Serialize for StarkPoint {
95    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
96        let bytes = self.to_bytes();
97        bytes.serialize(serializer)
98    }
99}
100
101impl<'de> Deserialize<'de> for StarkPoint {
102    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
103        let bytes = <Vec<u8>>::deserialize(deserializer)?;
104        StarkPoint::from_bytes(&bytes)
105            .map_err(|err| DeError::custom(format!("Failed to deserialize point: {err:?}")))
106    }
107}
108
109// ------------------------
110// | Misc Implementations |
111// ------------------------
112
113impl StarkPoint {
114    /// The additive identity in the curve group
115    pub fn identity() -> StarkPoint {
116        StarkPoint(StarkPointInner::zero())
117    }
118
119    /// Check whether the given point is the identity point in the group
120    pub fn is_identity(&self) -> bool {
121        self == &StarkPoint::identity()
122    }
123
124    /// Convert the point to affine
125    pub fn to_affine(&self) -> Affine<StarknetCurveConfig> {
126        self.0.into_affine()
127    }
128
129    /// Construct a `StarkPoint` from its affine coordinates
130    pub fn from_affine_coords(x: BigUint, y: BigUint) -> Self {
131        let x_bigint = BigInt::try_from(x).unwrap();
132        let y_bigint = BigInt::try_from(y).unwrap();
133        let x = StarknetBaseFelt::from(x_bigint);
134        let y = StarknetBaseFelt::from(y_bigint);
135
136        let aff = Affine {
137            x,
138            y,
139            infinity: false,
140        };
141
142        Self(aff.into())
143    }
144
145    /// The group generator
146    pub fn generator() -> StarkPoint {
147        StarkPoint(StarkPointInner::generator())
148    }
149
150    /// Serialize this point to a byte buffer
151    pub fn to_bytes(&self) -> Vec<u8> {
152        let mut out: Vec<u8> = Vec::with_capacity(size_of::<StarkPoint>());
153        self.0
154            .serialize_compressed(&mut out)
155            .expect("Failed to serialize point");
156
157        out
158    }
159
160    /// Deserialize a point from a byte buffer
161    pub fn from_bytes(bytes: &[u8]) -> Result<StarkPoint, SerializationError> {
162        let point = StarkPointInner::deserialize_compressed(bytes)?;
163        Ok(StarkPoint(point))
164    }
165
166    /// Convert a uniform byte buffer to a `StarkPoint` via the SWU map-to-curve approach:
167    ///
168    /// See https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-09#simple-swu
169    /// for a description of the setup. Essentially, we assume that the buffer provided is the
170    /// result of an `extend_message` implementation that gives us its uniform digest. From here
171    /// we construct two field elements, map to curve, and add the points to give a uniformly
172    /// distributed curve point
173    pub fn from_uniform_bytes(
174        buf: [u8; STARK_UNIFORM_BYTES],
175    ) -> Result<StarkPoint, HashToCurveError> {
176        // Sample two base field elements from the buffer
177        let f1 = Self::hash_to_field(&buf[..STARK_UNIFORM_BYTES / 2]);
178        let f2 = Self::hash_to_field(&buf[STARK_UNIFORM_BYTES / 2..]);
179
180        // Map to curve
181        let mapper = SWUMap::<StarknetCurveConfig>::new()?;
182        let p1 = mapper.map_to_curve(f1)?;
183        let p2 = mapper.map_to_curve(f2)?;
184
185        // The IETF spec above requires that we clear the cofactor. However, the STARK curve has cofactor
186        // h = 1, so no works needs to be done
187        Ok(StarkPoint(p1 + p2))
188    }
189
190    /// A helper that converts an arbitrarily long byte buffer to a field element
191    fn hash_to_field(buf: &[u8]) -> StarknetBaseFelt {
192        StarknetBaseFelt::from_be_bytes_mod_order(buf)
193    }
194}
195
196impl From<StarkPointInner> for StarkPoint {
197    fn from(p: StarkPointInner) -> Self {
198        StarkPoint(p)
199    }
200}
201
202// ------------------------------------
203// | Curve Arithmetic Implementations |
204// ------------------------------------
205
206// === Addition === //
207
208impl Add<&StarkPointInner> for &StarkPoint {
209    type Output = StarkPoint;
210
211    fn add(self, rhs: &StarkPointInner) -> Self::Output {
212        StarkPoint(self.0 + rhs)
213    }
214}
215impl_borrow_variants!(StarkPoint, Add, add, +, StarkPointInner);
216impl_commutative!(StarkPoint, Add, add, +, StarkPointInner);
217
218impl Add<&StarkPoint> for &StarkPoint {
219    type Output = StarkPoint;
220
221    fn add(self, rhs: &StarkPoint) -> Self::Output {
222        StarkPoint(self.0 + rhs.0)
223    }
224}
225impl_borrow_variants!(StarkPoint, Add, add, +, StarkPoint);
226
227/// A type alias for a result that resolves to a `StarkPoint`
228pub type StarkPointResult = ResultHandle<StarkPoint>;
229/// A type alias for a result that resolves to a batch of `StarkPoint`s
230pub type BatchStarkPointResult = ResultHandle<Vec<StarkPoint>>;
231
232impl Add<&StarkPointResult> for &StarkPointResult {
233    type Output = StarkPointResult;
234
235    fn add(self, rhs: &StarkPointResult) -> Self::Output {
236        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
237            let lhs: StarkPoint = args[0].to_owned().into();
238            let rhs: StarkPoint = args[1].to_owned().into();
239            ResultValue::Point(StarkPoint(lhs.0 + rhs.0))
240        })
241    }
242}
243impl_borrow_variants!(StarkPointResult, Add, add, +, StarkPointResult);
244
245impl Add<&StarkPoint> for &StarkPointResult {
246    type Output = StarkPointResult;
247
248    fn add(self, rhs: &StarkPoint) -> Self::Output {
249        let rhs = *rhs;
250        self.fabric.new_gate_op(vec![self.id], move |args| {
251            let lhs: StarkPoint = args[0].to_owned().into();
252            ResultValue::Point(StarkPoint(lhs.0 + rhs.0))
253        })
254    }
255}
256impl_borrow_variants!(StarkPointResult, Add, add, +, StarkPoint);
257impl_commutative!(StarkPointResult, Add, add, +, StarkPoint);
258
259impl StarkPointResult {
260    /// Add two batches of `StarkPoint`s together
261    pub fn batch_add(a: &[StarkPointResult], b: &[StarkPointResult]) -> Vec<StarkPointResult> {
262        assert_eq!(
263            a.len(),
264            b.len(),
265            "batch_add cannot compute on vectors of unequal length"
266        );
267
268        let n = a.len();
269        let fabric = a[0].fabric();
270        let all_ids = a.iter().chain(b.iter()).map(|r| r.id).collect_vec();
271
272        fabric.new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
273            let a = args.drain(..n).map(StarkPoint::from).collect_vec();
274            let b = args.into_iter().map(StarkPoint::from).collect_vec();
275
276            a.into_iter()
277                .zip(b.into_iter())
278                .map(|(a, b)| a + b)
279                .map(ResultValue::Point)
280                .collect_vec()
281        })
282    }
283}
284
285// === AddAssign === //
286
287impl AddAssign for StarkPoint {
288    fn add_assign(&mut self, rhs: Self) {
289        self.0 += rhs.0;
290    }
291}
292
293// === Subtraction === //
294
295impl Sub<&StarkPoint> for &StarkPoint {
296    type Output = StarkPoint;
297
298    fn sub(self, rhs: &StarkPoint) -> Self::Output {
299        StarkPoint(self.0 - rhs.0)
300    }
301}
302impl_borrow_variants!(StarkPoint, Sub, sub, -, StarkPoint);
303
304impl Sub<&StarkPointResult> for &StarkPointResult {
305    type Output = StarkPointResult;
306
307    fn sub(self, rhs: &StarkPointResult) -> Self::Output {
308        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
309            let lhs: StarkPoint = args[0].to_owned().into();
310            let rhs: StarkPoint = args[1].to_owned().into();
311            ResultValue::Point(StarkPoint(lhs.0 - rhs.0))
312        })
313    }
314}
315impl_borrow_variants!(StarkPointResult, Sub, sub, -, StarkPointResult);
316
317impl Sub<&StarkPoint> for &StarkPointResult {
318    type Output = StarkPointResult;
319
320    fn sub(self, rhs: &StarkPoint) -> Self::Output {
321        let rhs = *rhs;
322        self.fabric.new_gate_op(vec![self.id], move |args| {
323            let lhs: StarkPoint = args[0].to_owned().into();
324            ResultValue::Point(StarkPoint(lhs.0 - rhs.0))
325        })
326    }
327}
328impl_borrow_variants!(StarkPointResult, Sub, sub, -, StarkPoint);
329
330impl Sub<&StarkPointResult> for &StarkPoint {
331    type Output = StarkPointResult;
332
333    fn sub(self, rhs: &StarkPointResult) -> Self::Output {
334        let self_owned = *self;
335        rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
336            let rhs: StarkPoint = args[0].to_owned().into();
337            ResultValue::Point(StarkPoint(self_owned.0 - rhs.0))
338        })
339    }
340}
341
342impl StarkPointResult {
343    /// Subtract two batches of `StarkPoint`s
344    pub fn batch_sub(a: &[StarkPointResult], b: &[StarkPointResult]) -> Vec<StarkPointResult> {
345        assert_eq!(
346            a.len(),
347            b.len(),
348            "batch_sub cannot compute on vectors of unequal length"
349        );
350
351        let n = a.len();
352        let fabric = a[0].fabric();
353        let all_ids = a.iter().chain(b.iter()).map(|r| r.id).collect_vec();
354
355        fabric.new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
356            let a = args.drain(..n).map(StarkPoint::from).collect_vec();
357            let b = args.into_iter().map(StarkPoint::from).collect_vec();
358
359            a.into_iter()
360                .zip(b.into_iter())
361                .map(|(a, b)| a - b)
362                .map(ResultValue::Point)
363                .collect_vec()
364        })
365    }
366}
367
368// === SubAssign === //
369
370impl SubAssign for StarkPoint {
371    fn sub_assign(&mut self, rhs: Self) {
372        self.0 -= rhs.0;
373    }
374}
375
376// === Negation === //
377
378impl Neg for &StarkPoint {
379    type Output = StarkPoint;
380
381    fn neg(self) -> Self::Output {
382        StarkPoint(-self.0)
383    }
384}
385impl_borrow_variants!(StarkPoint, Neg, neg, -);
386
387impl Neg for &StarkPointResult {
388    type Output = StarkPointResult;
389
390    fn neg(self) -> Self::Output {
391        self.fabric.new_gate_op(vec![self.id], |args| {
392            let lhs: StarkPoint = args[0].to_owned().into();
393            ResultValue::Point(StarkPoint(-lhs.0))
394        })
395    }
396}
397impl_borrow_variants!(StarkPointResult, Neg, neg, -);
398
399impl StarkPointResult {
400    /// Negate a batch of `StarkPoint`s
401    pub fn batch_neg(a: &[StarkPointResult]) -> Vec<StarkPointResult> {
402        let n = a.len();
403        let fabric = a[0].fabric();
404        let all_ids = a.iter().map(|r| r.id).collect_vec();
405
406        fabric.new_batch_gate_op(all_ids, n /* output_arity */, |args| {
407            args.into_iter()
408                .map(StarkPoint::from)
409                .map(StarkPoint::neg)
410                .map(ResultValue::Point)
411                .collect_vec()
412        })
413    }
414}
415
416// === Scalar Multiplication === //
417
418impl Mul<&Scalar> for &StarkPoint {
419    type Output = StarkPoint;
420
421    fn mul(self, rhs: &Scalar) -> Self::Output {
422        StarkPoint(self.0 * rhs.0)
423    }
424}
425impl_borrow_variants!(StarkPoint, Mul, mul, *, Scalar);
426impl_commutative!(StarkPoint, Mul, mul, *, Scalar);
427
428impl Mul<&Scalar> for &StarkPointResult {
429    type Output = StarkPointResult;
430
431    fn mul(self, rhs: &Scalar) -> Self::Output {
432        let rhs = *rhs;
433        self.fabric.new_gate_op(vec![self.id], move |args| {
434            let lhs: StarkPoint = args[0].to_owned().into();
435            ResultValue::Point(StarkPoint(lhs.0 * rhs.0))
436        })
437    }
438}
439impl_borrow_variants!(StarkPointResult, Mul, mul, *, Scalar);
440impl_commutative!(StarkPointResult, Mul, mul, *, Scalar);
441
442impl Mul<&ScalarResult> for &StarkPoint {
443    type Output = StarkPointResult;
444
445    fn mul(self, rhs: &ScalarResult) -> Self::Output {
446        let self_owned = *self;
447        rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
448            let rhs: Scalar = args[0].to_owned().into();
449            ResultValue::Point(StarkPoint(self_owned.0 * rhs.0))
450        })
451    }
452}
453impl_borrow_variants!(StarkPoint, Mul, mul, *, ScalarResult, Output=StarkPointResult);
454impl_commutative!(StarkPoint, Mul, mul, *, ScalarResult, Output=StarkPointResult);
455
456impl Mul<&ScalarResult> for &StarkPointResult {
457    type Output = StarkPointResult;
458
459    fn mul(self, rhs: &ScalarResult) -> Self::Output {
460        self.fabric.new_gate_op(vec![self.id, rhs.id], |mut args| {
461            let lhs: StarkPoint = args.remove(0).into();
462            let rhs: Scalar = args.remove(0).into();
463
464            ResultValue::Point(StarkPoint(lhs.0 * rhs.0))
465        })
466    }
467}
468impl_borrow_variants!(StarkPointResult, Mul, mul, *, ScalarResult);
469impl_commutative!(StarkPointResult, Mul, mul, *, ScalarResult);
470
471impl StarkPointResult {
472    /// Multiply a batch of `StarkPointResult`s with a batch of `ScalarResult`s
473    pub fn batch_mul(a: &[ScalarResult], b: &[StarkPointResult]) -> Vec<StarkPointResult> {
474        assert_eq!(
475            a.len(),
476            b.len(),
477            "batch_mul cannot compute on vectors of unequal length"
478        );
479
480        let n = a.len();
481        let fabric = a[0].fabric();
482        let all_ids = a
483            .iter()
484            .map(|a| a.id())
485            .chain(b.iter().map(|b| b.id()))
486            .collect_vec();
487
488        fabric.new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
489            let a = args.drain(..n).map(Scalar::from).collect_vec();
490            let b = args.into_iter().map(StarkPoint::from).collect_vec();
491
492            a.into_iter()
493                .zip(b.into_iter())
494                .map(|(a, b)| a * b)
495                .map(ResultValue::Point)
496                .collect_vec()
497        })
498    }
499
500    /// Multiply a batch of `MpcScalarResult`s with a batch of `StarkPointResult`s
501    pub fn batch_mul_shared(
502        a: &[MpcScalarResult],
503        b: &[StarkPointResult],
504    ) -> Vec<MpcStarkPointResult> {
505        assert_eq!(
506            a.len(),
507            b.len(),
508            "batch_mul_shared cannot compute on vectors of unequal length"
509        );
510
511        let n = a.len();
512        let fabric = a[0].fabric();
513        let all_ids = a
514            .iter()
515            .map(|a| a.id())
516            .chain(b.iter().map(|b| b.id()))
517            .collect_vec();
518
519        fabric
520            .new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
521                let a = args.drain(..n).map(Scalar::from).collect_vec();
522                let b = args.into_iter().map(StarkPoint::from).collect_vec();
523
524                a.into_iter()
525                    .zip(b.into_iter())
526                    .map(|(a, b)| a * b)
527                    .map(ResultValue::Point)
528                    .collect_vec()
529            })
530            .into_iter()
531            .map(MpcStarkPointResult::from)
532            .collect_vec()
533    }
534
535    /// Multiply a batch of `AuthenticatedScalarResult`s with a batch of `StarkPointResult`s
536    pub fn batch_mul_authenticated(
537        a: &[AuthenticatedScalarResult],
538        b: &[StarkPointResult],
539    ) -> Vec<AuthenticatedStarkPointResult> {
540        assert_eq!(
541            a.len(),
542            b.len(),
543            "batch_mul_authenticated cannot compute on vectors of unequal length"
544        );
545
546        let n = a.len();
547        let fabric = a[0].fabric();
548        let all_ids = b
549            .iter()
550            .map(|b| b.id())
551            .chain(a.iter().flat_map(|a| a.ids()))
552            .collect_vec();
553
554        let results = fabric.new_batch_gate_op(
555            all_ids,
556            AUTHENTICATED_STARK_POINT_RESULT_LEN * n, /* output_arity */
557            move |mut args| {
558                let points: Vec<StarkPoint> = args.drain(..n).map(StarkPoint::from).collect_vec();
559
560                let mut results = Vec::with_capacity(AUTHENTICATED_STARK_POINT_RESULT_LEN * n);
561
562                for (scalars, point) in args
563                    .chunks_exact(AUTHENTICATED_SCALAR_RESULT_LEN)
564                    .zip(points.into_iter())
565                {
566                    let share = Scalar::from(&scalars[0]);
567                    let mac = Scalar::from(&scalars[1]);
568                    let public_modifier = Scalar::from(&scalars[2]);
569
570                    results.push(ResultValue::Point(point * share));
571                    results.push(ResultValue::Point(point * mac));
572                    results.push(ResultValue::Point(point * public_modifier));
573                }
574
575                results
576            },
577        );
578
579        AuthenticatedStarkPointResult::from_flattened_iterator(results.into_iter())
580    }
581}
582
583// === MulAssign === //
584
585impl MulAssign<&Scalar> for StarkPoint {
586    fn mul_assign(&mut self, rhs: &Scalar) {
587        self.0 *= rhs.0;
588    }
589}
590
591// -------------------
592// | Iterator Traits |
593// -------------------
594
595impl Sum for StarkPoint {
596    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
597        iter.fold(StarkPoint::identity(), |acc, x| acc + x)
598    }
599}
600
601impl Sum for StarkPointResult {
602    /// Assumes the iterator is non-empty
603    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
604        let first = iter.next().expect("empty iterator");
605        iter.fold(first, |acc, x| acc + x)
606    }
607}
608
609/// MSM Implementation
610impl StarkPoint {
611    /// Compute the multiscalar multiplication of the given scalars and points
612    pub fn msm(scalars: &[Scalar], points: &[StarkPoint]) -> StarkPoint {
613        assert_eq!(
614            scalars.len(),
615            points.len(),
616            "msm cannot compute on vectors of unequal length"
617        );
618
619        let n = scalars.len();
620        if n < MSM_SIZE_THRESHOLD {
621            return scalars.iter().zip(points.iter()).map(|(s, p)| s * p).sum();
622        }
623
624        let affine_points = points.iter().map(|p| p.0.into_affine()).collect_vec();
625        let stripped_scalars = scalars.iter().map(|s| s.0).collect_vec();
626        StarkPointInner::msm(&affine_points, &stripped_scalars)
627            .map(StarkPoint)
628            .unwrap()
629    }
630
631    /// Compute the multiscalar multiplication of the given scalars and points
632    /// represented as streaming iterators
633    pub fn msm_iter<I, J>(scalars: I, points: J) -> StarkPoint
634    where
635        I: IntoIterator<Item = Scalar>,
636        J: IntoIterator<Item = StarkPoint>,
637    {
638        let mut res = StarkPoint::identity();
639        for (scalar_chunk, point_chunk) in scalars
640            .into_iter()
641            .chunks(MSM_CHUNK_SIZE)
642            .into_iter()
643            .zip(points.into_iter().chunks(MSM_CHUNK_SIZE).into_iter())
644        {
645            let scalars: Vec<Scalar> = scalar_chunk.collect();
646            let points: Vec<StarkPoint> = point_chunk.collect();
647            let chunk_res = StarkPoint::msm(&scalars, &points);
648
649            res += chunk_res;
650        }
651
652        res
653    }
654
655    /// Compute the multiscalar multiplication of the given points with `ScalarResult`s
656    pub fn msm_results(scalars: &[ScalarResult], points: &[StarkPoint]) -> StarkPointResult {
657        assert_eq!(
658            scalars.len(),
659            points.len(),
660            "msm cannot compute on vectors of unequal length"
661        );
662
663        let fabric = scalars[0].fabric();
664        let scalar_ids = scalars.iter().map(|s| s.id()).collect_vec();
665
666        // Clone `points` so that the gate closure may capture it
667        let points = points.to_vec();
668        fabric.new_gate_op(scalar_ids, move |args| {
669            let scalars = args.into_iter().map(Scalar::from).collect_vec();
670
671            ResultValue::Point(StarkPoint::msm(&scalars, &points))
672        })
673    }
674
675    /// Compute the multiscalar multiplication of the given points with `ScalarResult`s
676    /// as iterators. Assumes the iterators are non-empty
677    pub fn msm_results_iter<I, J>(scalars: I, points: J) -> StarkPointResult
678    where
679        I: IntoIterator<Item = ScalarResult>,
680        J: IntoIterator<Item = StarkPoint>,
681    {
682        Self::msm_results(
683            &scalars.into_iter().collect_vec(),
684            &points.into_iter().collect_vec(),
685        )
686    }
687
688    /// Compute the multiscalar multiplication of the given authenticated scalars and plaintext points
689    pub fn msm_authenticated(
690        scalars: &[AuthenticatedScalarResult],
691        points: &[StarkPoint],
692    ) -> AuthenticatedStarkPointResult {
693        assert_eq!(
694            scalars.len(),
695            points.len(),
696            "msm cannot compute on vectors of unequal length"
697        );
698
699        let n = scalars.len();
700        let fabric = scalars[0].fabric();
701        let scalar_ids = scalars.iter().flat_map(|s| s.ids()).collect_vec();
702
703        // Clone points to let the gate closure take ownership
704        let points = points.to_vec();
705        let res: Vec<StarkPointResult> = fabric.new_batch_gate_op(
706            scalar_ids,
707            AUTHENTICATED_SCALAR_RESULT_LEN, /* output_arity */
708            move |args| {
709                let mut shares = Vec::with_capacity(n);
710                let mut macs = Vec::with_capacity(n);
711                let mut modifiers = Vec::with_capacity(n);
712
713                for chunk in args.chunks_exact(AUTHENTICATED_SCALAR_RESULT_LEN) {
714                    shares.push(Scalar::from(chunk[0].to_owned()));
715                    macs.push(Scalar::from(chunk[1].to_owned()));
716                    modifiers.push(Scalar::from(chunk[2].to_owned()));
717                }
718
719                // Compute the MSM of the point
720                vec![
721                    StarkPoint::msm(&shares, &points),
722                    StarkPoint::msm(&macs, &points),
723                    StarkPoint::msm(&modifiers, &points),
724                ]
725                .into_iter()
726                .map(ResultValue::Point)
727                .collect_vec()
728            },
729        );
730
731        AuthenticatedStarkPointResult {
732            share: res[0].to_owned().into(),
733            mac: res[1].to_owned().into(),
734            public_modifier: res[2].to_owned(),
735        }
736    }
737
738    /// Compute the multiscalar multiplication of the given authenticated scalars and plaintext points
739    /// as iterators
740    /// This method assumes that the iterators are of the same length
741    pub fn msm_authenticated_iter<I, J>(scalars: I, points: J) -> AuthenticatedStarkPointResult
742    where
743        I: IntoIterator<Item = AuthenticatedScalarResult>,
744        J: IntoIterator<Item = StarkPoint>,
745    {
746        let scalars: Vec<AuthenticatedScalarResult> = scalars.into_iter().collect();
747        let points: Vec<StarkPoint> = points.into_iter().collect();
748
749        Self::msm_authenticated(&scalars, &points)
750    }
751}
752
753impl StarkPointResult {
754    /// Compute the multiscalar multiplication of the given scalars and points
755    pub fn msm_results(scalars: &[ScalarResult], points: &[StarkPointResult]) -> StarkPointResult {
756        assert!(!scalars.is_empty(), "msm cannot compute on an empty vector");
757        assert_eq!(
758            scalars.len(),
759            points.len(),
760            "msm cannot compute on vectors of unequal length"
761        );
762
763        let n = scalars.len();
764        let fabric = scalars[0].fabric();
765        let all_ids = scalars
766            .iter()
767            .map(|s| s.id())
768            .chain(points.iter().map(|p| p.id()))
769            .collect_vec();
770
771        fabric.new_gate_op(all_ids, move |mut args| {
772            let scalars = args.drain(..n).map(Scalar::from).collect_vec();
773            let points = args.into_iter().map(StarkPoint::from).collect_vec();
774
775            let res = StarkPoint::msm(&scalars, &points);
776            ResultValue::Point(res)
777        })
778    }
779
780    /// Compute the multiscalar multiplication of the given scalars and points
781    /// represented as streaming iterators
782    ///
783    /// Assumes the iterator is non-empty
784    pub fn msm_results_iter<I, J>(scalars: I, points: J) -> StarkPointResult
785    where
786        I: IntoIterator<Item = ScalarResult>,
787        J: IntoIterator<Item = StarkPointResult>,
788    {
789        Self::msm_results(
790            &scalars.into_iter().collect_vec(),
791            &points.into_iter().collect_vec(),
792        )
793    }
794
795    /// Compute the multiscalar multiplication of the given `AuthenticatedScalar`s and points
796    pub fn msm_authenticated(
797        scalars: &[AuthenticatedScalarResult],
798        points: &[StarkPointResult],
799    ) -> AuthenticatedStarkPointResult {
800        assert_eq!(
801            scalars.len(),
802            points.len(),
803            "msm cannot compute on vectors of unequal length"
804        );
805
806        let n = scalars.len();
807        let fabric = scalars[0].fabric();
808        let all_ids = scalars
809            .iter()
810            .flat_map(|s| s.ids())
811            .chain(points.iter().map(|p| p.id()))
812            .collect_vec();
813
814        let res = fabric.new_batch_gate_op(
815            all_ids,
816            AUTHENTICATED_STARK_POINT_RESULT_LEN, /* output_arity */
817            move |mut args| {
818                let mut shares = Vec::with_capacity(n);
819                let mut macs = Vec::with_capacity(n);
820                let mut modifiers = Vec::with_capacity(n);
821
822                for mut chunk in args
823                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
824                    .map(Scalar::from)
825                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
826                    .into_iter()
827                {
828                    shares.push(chunk.next().unwrap());
829                    macs.push(chunk.next().unwrap());
830                    modifiers.push(chunk.next().unwrap());
831                }
832
833                let points = args.into_iter().map(StarkPoint::from).collect_vec();
834
835                vec![
836                    StarkPoint::msm(&shares, &points),
837                    StarkPoint::msm(&macs, &points),
838                    StarkPoint::msm(&modifiers, &points),
839                ]
840                .into_iter()
841                .map(ResultValue::Point)
842                .collect_vec()
843            },
844        );
845
846        AuthenticatedStarkPointResult {
847            share: res[0].to_owned().into(),
848            mac: res[1].to_owned().into(),
849            public_modifier: res[2].to_owned(),
850        }
851    }
852
853    /// Compute the multiscalar multiplication of the given `AuthenticatedScalar`s and points
854    /// represented as streaming iterators
855    pub fn msm_authenticated_iter<I, J>(scalars: I, points: J) -> AuthenticatedStarkPointResult
856    where
857        I: IntoIterator<Item = AuthenticatedScalarResult>,
858        J: IntoIterator<Item = StarkPointResult>,
859    {
860        let scalars: Vec<AuthenticatedScalarResult> = scalars.into_iter().collect();
861        let points: Vec<StarkPointResult> = points.into_iter().collect();
862
863        Self::msm_authenticated(&scalars, &points)
864    }
865}
866
867// ---------
868// | Tests |
869// ---------
870
/// Cross-checks our curve configuration against an independent implementation of the
/// Stark curve: https://github.com/xJonathanLEI/starknet-rs
#[cfg(test)]
mod test {
    use rand::{thread_rng, RngCore};
    use starknet_curve::{curve_params::GENERATOR, ProjectivePoint};

    use crate::algebra::test_helper::{
        arkworks_point_to_starknet, compare_points, prime_field_to_starknet_felt, random_point,
        starknet_rs_scalar_mul,
    };

    use super::*;

    /// Both implementations must agree on the group generator
    #[test]
    fn test_generators() {
        let ours = StarkPoint::generator();
        let reference = ProjectivePoint::from_affine_point(&GENERATOR);

        assert!(compare_points(&ours, &reference));
    }

    /// Point addition must match the reference implementation
    #[test]
    fn test_point_addition() {
        let a = random_point();
        let b = random_point();

        let a_ref = arkworks_point_to_starknet(&a);
        let b_ref = arkworks_point_to_starknet(&b);

        let ours = a + b;

        // The reference `ProjectivePoint` only exposes `AddAssign`
        let mut reference = a_ref;
        reference += &b_ref;

        assert!(compare_points(&ours, &reference));
    }

    /// Scalar multiplication must match the reference implementation
    #[test]
    fn test_scalar_mul() {
        let mut rng = thread_rng();
        let scalar = Scalar::random(&mut rng);
        let point = random_point();

        let scalar_ref = prime_field_to_starknet_felt(&scalar.0);
        let point_ref = arkworks_point_to_starknet(&point);

        let ours = point * scalar;
        let reference = starknet_rs_scalar_mul(&scalar_ref, &point_ref);

        assert!(compare_points(&ours, &reference));
    }

    /// Adding the identity must leave a point unchanged
    #[test]
    fn test_additive_identity() {
        let point = random_point();
        let sum = point + StarkPoint::identity();

        assert_eq!(point, sum);
    }

    /// Serialized points have the advertised length and round-trip through `from_bytes`
    #[test]
    fn test_point_serialized() {
        let point = random_point();
        let encoded = point.to_bytes();

        assert_eq!(encoded.len(), STARK_POINT_BYTES);

        let decoded = StarkPoint::from_bytes(&encoded).unwrap();
        assert_eq!(point, decoded);
    }

    /// `StarkPoint::from_uniform_bytes` succeeds on random input
    #[test]
    fn test_hash_to_curve() {
        let mut rng = thread_rng();
        let mut buf = [0u8; STARK_UNIFORM_BYTES];
        rng.fill_bytes(&mut buf);

        // Success here only means the map-to-curve did not error
        assert!(StarkPoint::from_uniform_bytes(buf).is_ok())
    }

    /// Converting to affine coordinates and back recovers the original point
    #[test]
    fn test_to_from_affine_coords() {
        let original = random_point();
        let affine = original.to_affine();

        let recovered =
            StarkPoint::from_affine_coords(BigUint::from(affine.x), BigUint::from(affine.y));

        assert_eq!(original, recovered);
    }
}