ark_mpc/algebra/curve/
curve.rs

1//! Defines the `CurvePoint` type, a wrapper around a generic curve that allows
2//! us to bring curve arithmetic into the execution graph
3
4use std::{
5    iter::Sum,
6    mem::size_of,
7    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
8};
9
10use ark_ec::{
11    hashing::{
12        curve_maps::swu::{SWUConfig, SWUMap},
13        map_to_curve_hasher::MapToCurve,
14        HashToCurveError,
15    },
16    short_weierstrass::Projective,
17    AffineRepr, CurveGroup,
18};
19use ark_ff::PrimeField;
20
21use ark_serialize::SerializationError;
22use itertools::Itertools;
23use serde::{de::Error as DeError, Deserialize, Serialize};
24
25use crate::{
26    algebra::{
27        macros::*, n_bytes_field, scalar::*, AUTHENTICATED_POINT_RESULT_LEN,
28        AUTHENTICATED_SCALAR_RESULT_LEN,
29    },
30    fabric::{ResultHandle, ResultValue},
31};
32
33use super::{authenticated_curve::AuthenticatedPointResult, mpc_curve::MpcPointResult};
34
/// Chunk length used by the streaming MSM: this many scalar/point pairs are
/// buffered from the input iterators before each partial MSM is evaluated
const MSM_CHUNK_SIZE: usize = 1 << 16;
/// Minimum MSM size for which the Arkworks MSM implementation is invoked
///
/// Smaller MSMs are evaluated as a plain sum of scalar multiplications, as
/// the parallelism overhead is too significant at that scale
const MSM_SIZE_THRESHOLD: usize = 10;

/// Security level, in bytes, used in the hash-to-curve implementation
pub const HASH_TO_CURVE_SECURITY: usize = 16; // 128 bit security
46
47/// A wrapper around the inner point that allows us to define foreign traits on
48/// the point
49#[derive(Copy, Clone, Debug, PartialEq, Eq)]
50pub struct CurvePoint<C: CurveGroup>(pub(crate) C);
51impl<C: CurveGroup> Unpin for CurvePoint<C> {}
52
53impl<C: CurveGroup> Serialize for CurvePoint<C> {
54    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
55        let bytes = self.to_bytes();
56        bytes.serialize(serializer)
57    }
58}
59
60impl<'de, C: CurveGroup> Deserialize<'de> for CurvePoint<C> {
61    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
62        let bytes = <Vec<u8>>::deserialize(deserializer)?;
63        CurvePoint::from_bytes(&bytes)
64            .map_err(|err| DeError::custom(format!("Failed to deserialize point: {err:?}")))
65    }
66}
67
68// ------------------------
69// | Misc Implementations |
70// ------------------------
71
72impl<C: CurveGroup> CurvePoint<C> {
73    /// The base field that the curve is defined over, i.e. the field in which
74    /// the curve equation's coefficients lie
75    pub type BaseField = C::BaseField;
76    /// The scalar field of the curve, i.e. Z/rZ where r is the curve group's
77    /// order
78    pub type ScalarField = C::ScalarField;
79
80    /// The additive identity in the curve group
81    pub fn identity() -> CurvePoint<C> {
82        CurvePoint(C::zero())
83    }
84
85    /// Check whether the given point is the identity point in the group
86    pub fn is_identity(&self) -> bool {
87        self == &CurvePoint::identity()
88    }
89
90    /// Return the wrapped type
91    pub fn inner(&self) -> C {
92        self.0
93    }
94
95    /// Convert the point to affine
96    pub fn to_affine(&self) -> C::Affine {
97        self.0.into_affine()
98    }
99
100    /// The group generator
101    pub fn generator() -> CurvePoint<C> {
102        CurvePoint(C::generator())
103    }
104
105    /// Serialize this point to a byte buffer
106    pub fn to_bytes(&self) -> Vec<u8> {
107        let mut out: Vec<u8> = Vec::with_capacity(size_of::<CurvePoint<C>>());
108        self.0
109            .serialize_compressed(&mut out)
110            .expect("Failed to serialize point");
111
112        out
113    }
114
115    /// Deserialize a point from a byte buffer
116    pub fn from_bytes(bytes: &[u8]) -> Result<CurvePoint<C>, SerializationError> {
117        let point = C::deserialize_compressed(bytes)?;
118        Ok(CurvePoint(point))
119    }
120}
121
122impl<C: CurveGroup> CurvePoint<C>
123where
124    C::BaseField: PrimeField,
125{
126    /// Get the number of bytes needed to represent a point, this is exactly the
127    /// number of bytes for one base field element, as we can simply use the
128    /// x-coordinate and set a high bit for the `y`
129    pub fn n_bytes() -> usize {
130        n_bytes_field::<C::BaseField>()
131    }
132}
133
134impl<C: CurveGroup> CurvePoint<C>
135where
136    C::Config: SWUConfig,
137    C::BaseField: PrimeField,
138{
139    /// Convert a uniform byte buffer to a `CurvePoint<C>` via the SWU
140    /// map-to-curve approach:
141    ///
142    /// See https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-09#simple-swu
143    /// for a description of the setup. Essentially, we assume that the buffer
144    /// provided is the result of an `extend_message` implementation that
145    /// gives us its uniform digest. From here we construct two field
146    /// elements, map to curve, and add the points to give a uniformly
147    /// distributed curve point
148    pub fn from_uniform_bytes(
149        buf: Vec<u8>,
150    ) -> Result<CurvePoint<Projective<C::Config>>, HashToCurveError> {
151        let n_bytes = Self::n_bytes();
152        assert_eq!(
153            buf.len(),
154            2 * n_bytes,
155            "Invalid buffer length, must represent two curve points"
156        );
157
158        // Sample two base field elements from the buffer
159        let f1 = Self::hash_to_field(&buf[..n_bytes / 2]);
160        let f2 = Self::hash_to_field(&buf[n_bytes / 2..]);
161
162        // Map to curve
163        let mapper = SWUMap::<C::Config>::new()?;
164        let p1 = mapper.map_to_curve(f1)?;
165        let p2 = mapper.map_to_curve(f2)?;
166
167        // Clear the cofactor
168        let p1_clear = p1.clear_cofactor();
169        let p2_clear = p2.clear_cofactor();
170
171        Ok(CurvePoint(p1_clear + p2_clear))
172    }
173
174    /// A helper that converts an arbitrarily long byte buffer to a field
175    /// element
176    fn hash_to_field(buf: &[u8]) -> C::BaseField {
177        Self::BaseField::from_be_bytes_mod_order(buf)
178    }
179}
180
181impl<C: CurveGroup> From<C> for CurvePoint<C> {
182    fn from(p: C) -> Self {
183        CurvePoint(p)
184    }
185}
186
187// ------------------------------------
188// | Curve Arithmetic Implementations |
189// ------------------------------------
190
191// === Addition === //
192
193impl<C: CurveGroup> Add<&C> for &CurvePoint<C> {
194    type Output = CurvePoint<C>;
195
196    fn add(self, rhs: &C) -> Self::Output {
197        CurvePoint(self.0 + rhs)
198    }
199}
200impl_borrow_variants!(CurvePoint<C>, Add, add, +, C, C: CurveGroup);
201
202impl<C: CurveGroup> Add<&CurvePoint<C>> for &CurvePoint<C> {
203    type Output = CurvePoint<C>;
204
205    fn add(self, rhs: &CurvePoint<C>) -> Self::Output {
206        CurvePoint(self.0 + rhs.0)
207    }
208}
209impl_borrow_variants!(CurvePoint<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
210
211/// A type alias for a result that resolves to a `CurvePoint<C>`
212pub type CurvePointResult<C> = ResultHandle<C, CurvePoint<C>>;
213/// A type alias for a result that resolves to a batch of `CurvePoint<C>`s
214pub type BatchCurvePointResult<C> = ResultHandle<C, Vec<CurvePoint<C>>>;
215
216impl<C: CurveGroup> Add<&CurvePointResult<C>> for &CurvePointResult<C> {
217    type Output = CurvePointResult<C>;
218
219    fn add(self, rhs: &CurvePointResult<C>) -> Self::Output {
220        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
221            let lhs: CurvePoint<C> = args[0].to_owned().into();
222            let rhs: CurvePoint<C> = args[1].to_owned().into();
223            ResultValue::Point(CurvePoint(lhs.0 + rhs.0))
224        })
225    }
226}
227impl_borrow_variants!(CurvePointResult<C>, Add, add, +, CurvePointResult<C>, C: CurveGroup);
228
229impl<C: CurveGroup> Add<&CurvePoint<C>> for &CurvePointResult<C> {
230    type Output = CurvePointResult<C>;
231
232    fn add(self, rhs: &CurvePoint<C>) -> Self::Output {
233        let rhs = *rhs;
234        self.fabric.new_gate_op(vec![self.id], move |args| {
235            let lhs: CurvePoint<C> = args[0].to_owned().into();
236            ResultValue::Point(CurvePoint(lhs.0 + rhs.0))
237        })
238    }
239}
240impl_borrow_variants!(CurvePointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
241impl_commutative!(CurvePointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
242
243impl<C: CurveGroup> CurvePointResult<C> {
244    /// Add two batches of `CurvePoint<C>`s together
245    pub fn batch_add(
246        a: &[CurvePointResult<C>],
247        b: &[CurvePointResult<C>],
248    ) -> Vec<CurvePointResult<C>> {
249        assert_eq!(
250            a.len(),
251            b.len(),
252            "batch_add cannot compute on vectors of unequal length"
253        );
254
255        let n = a.len();
256        let fabric = a[0].fabric();
257        let all_ids = a.iter().chain(b.iter()).map(|r| r.id).collect_vec();
258
259        fabric.new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
260            let a = args.drain(..n).map(CurvePoint::from).collect_vec();
261            let b = args.into_iter().map(CurvePoint::from).collect_vec();
262
263            a.into_iter()
264                .zip(b)
265                .map(|(a, b)| a + b)
266                .map(ResultValue::Point)
267                .collect_vec()
268        })
269    }
270}
271
272// === AddAssign === //
273
274impl<C: CurveGroup> AddAssign for CurvePoint<C> {
275    fn add_assign(&mut self, rhs: Self) {
276        self.0 += rhs.0;
277    }
278}
279
280// === Subtraction === //
281
282impl<C: CurveGroup> Sub<&CurvePoint<C>> for &CurvePoint<C> {
283    type Output = CurvePoint<C>;
284
285    fn sub(self, rhs: &CurvePoint<C>) -> Self::Output {
286        CurvePoint(self.0 - rhs.0)
287    }
288}
289impl_borrow_variants!(CurvePoint<C>, Sub, sub, -, CurvePoint<C>, C: CurveGroup);
290
291impl<C: CurveGroup> Sub<&CurvePointResult<C>> for &CurvePointResult<C> {
292    type Output = CurvePointResult<C>;
293
294    fn sub(self, rhs: &CurvePointResult<C>) -> Self::Output {
295        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
296            let lhs: CurvePoint<C> = args[0].to_owned().into();
297            let rhs: CurvePoint<C> = args[1].to_owned().into();
298            ResultValue::Point(CurvePoint(lhs.0 - rhs.0))
299        })
300    }
301}
302impl_borrow_variants!(CurvePointResult<C>, Sub, sub, -, CurvePointResult<C>, C: CurveGroup);
303
304impl<C: CurveGroup> Sub<&CurvePoint<C>> for &CurvePointResult<C> {
305    type Output = CurvePointResult<C>;
306
307    fn sub(self, rhs: &CurvePoint<C>) -> Self::Output {
308        let rhs = *rhs;
309        self.fabric.new_gate_op(vec![self.id], move |args| {
310            let lhs: CurvePoint<C> = args[0].to_owned().into();
311            ResultValue::Point(CurvePoint(lhs.0 - rhs.0))
312        })
313    }
314}
315impl_borrow_variants!(CurvePointResult<C>, Sub, sub, -, CurvePoint<C>, C: CurveGroup);
316
317impl<C: CurveGroup> Sub<&CurvePointResult<C>> for &CurvePoint<C> {
318    type Output = CurvePointResult<C>;
319
320    fn sub(self, rhs: &CurvePointResult<C>) -> Self::Output {
321        let self_owned = *self;
322        rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
323            let rhs: CurvePoint<C> = args[0].to_owned().into();
324            ResultValue::Point(CurvePoint(self_owned.0 - rhs.0))
325        })
326    }
327}
328
329impl<C: CurveGroup> CurvePointResult<C> {
330    /// Subtract two batches of `CurvePoint<C>`s
331    pub fn batch_sub(
332        a: &[CurvePointResult<C>],
333        b: &[CurvePointResult<C>],
334    ) -> Vec<CurvePointResult<C>> {
335        assert_eq!(
336            a.len(),
337            b.len(),
338            "batch_sub cannot compute on vectors of unequal length"
339        );
340
341        let n = a.len();
342        let fabric = a[0].fabric();
343        let all_ids = a.iter().chain(b.iter()).map(|r| r.id).collect_vec();
344
345        fabric.new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
346            let a = args.drain(..n).map(CurvePoint::from).collect_vec();
347            let b = args.into_iter().map(CurvePoint::from).collect_vec();
348
349            a.into_iter()
350                .zip(b)
351                .map(|(a, b)| a - b)
352                .map(ResultValue::Point)
353                .collect_vec()
354        })
355    }
356}
357
358// === SubAssign === //
359
360impl<C: CurveGroup> SubAssign for CurvePoint<C> {
361    fn sub_assign(&mut self, rhs: Self) {
362        self.0 -= rhs.0;
363    }
364}
365
366// === Negation === //
367
368impl<C: CurveGroup> Neg for &CurvePoint<C> {
369    type Output = CurvePoint<C>;
370
371    fn neg(self) -> Self::Output {
372        CurvePoint(-self.0)
373    }
374}
375impl_borrow_variants!(CurvePoint<C>, Neg, neg, -, C: CurveGroup);
376
377impl<C: CurveGroup> Neg for &CurvePointResult<C> {
378    type Output = CurvePointResult<C>;
379
380    fn neg(self) -> Self::Output {
381        self.fabric.new_gate_op(vec![self.id], |args| {
382            let lhs: CurvePoint<C> = args[0].to_owned().into();
383            ResultValue::Point(CurvePoint(-lhs.0))
384        })
385    }
386}
387impl_borrow_variants!(CurvePointResult<C>, Neg, neg, -, C:CurveGroup);
388
389impl<C: CurveGroup> CurvePointResult<C> {
390    /// Negate a batch of `CurvePoint<C>`s
391    pub fn batch_neg(a: &[CurvePointResult<C>]) -> Vec<CurvePointResult<C>> {
392        let n = a.len();
393        let fabric = a[0].fabric();
394        let all_ids = a.iter().map(|r| r.id).collect_vec();
395
396        fabric.new_batch_gate_op(all_ids, n /* output_arity */, |args| {
397            args.into_iter()
398                .map(CurvePoint::from)
399                .map(CurvePoint::neg)
400                .map(ResultValue::Point)
401                .collect_vec()
402        })
403    }
404}
405
406// === Scalar Multiplication === //
407
408impl<C: CurveGroup> Mul<&Scalar<C>> for &CurvePoint<C> {
409    type Output = CurvePoint<C>;
410
411    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
412        CurvePoint(self.0 * rhs.0)
413    }
414}
415impl_borrow_variants!(CurvePoint<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
416impl_commutative!(CurvePoint<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
417
418impl<C: CurveGroup> Mul<&Scalar<C>> for &CurvePointResult<C> {
419    type Output = CurvePointResult<C>;
420
421    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
422        let rhs = *rhs;
423        self.fabric.new_gate_op(vec![self.id], move |args| {
424            let lhs: CurvePoint<C> = args[0].to_owned().into();
425            ResultValue::Point(CurvePoint(lhs.0 * rhs.0))
426        })
427    }
428}
429impl_borrow_variants!(CurvePointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
430impl_commutative!(CurvePointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
431
432impl<C: CurveGroup> Mul<&ScalarResult<C>> for &CurvePoint<C> {
433    type Output = CurvePointResult<C>;
434
435    fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
436        let self_owned = *self;
437        rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
438            let rhs: Scalar<C> = args[0].to_owned().into();
439            ResultValue::Point(CurvePoint(self_owned.0 * rhs.0))
440        })
441    }
442}
443impl_borrow_variants!(CurvePoint<C>, Mul, mul, *, ScalarResult<C>, Output=CurvePointResult<C>, C: CurveGroup);
444impl_commutative!(CurvePoint<C>, Mul, mul, *, ScalarResult<C>, Output=CurvePointResult<C>, C: CurveGroup);
445
446impl<C: CurveGroup> Mul<&ScalarResult<C>> for &CurvePointResult<C> {
447    type Output = CurvePointResult<C>;
448
449    fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
450        self.fabric.new_gate_op(vec![self.id, rhs.id], |mut args| {
451            let lhs: CurvePoint<C> = args.remove(0).into();
452            let rhs: Scalar<C> = args.remove(0).into();
453
454            ResultValue::Point(CurvePoint(lhs.0 * rhs.0))
455        })
456    }
457}
458impl_borrow_variants!(CurvePointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
459impl_commutative!(CurvePointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
460
461impl<C: CurveGroup> CurvePointResult<C> {
462    /// Multiply a batch of `CurvePointResult<C>`s with a batch of
463    /// `ScalarResult`s
464    pub fn batch_mul(a: &[ScalarResult<C>], b: &[CurvePointResult<C>]) -> Vec<CurvePointResult<C>> {
465        assert_eq!(
466            a.len(),
467            b.len(),
468            "batch_mul cannot compute on vectors of unequal length"
469        );
470
471        let n = a.len();
472        let fabric = a[0].fabric();
473        let all_ids = a
474            .iter()
475            .map(|a| a.id())
476            .chain(b.iter().map(|b| b.id()))
477            .collect_vec();
478
479        fabric.new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
480            let a = args.drain(..n).map(Scalar::from).collect_vec();
481            let b = args.into_iter().map(CurvePoint::from).collect_vec();
482
483            a.into_iter()
484                .zip(b)
485                .map(|(a, b)| a * b)
486                .map(ResultValue::Point)
487                .collect_vec()
488        })
489    }
490
491    /// Multiply a batch of `MpcScalarResult`s with a batch of
492    /// `CurvePointResult<C>`s
493    pub fn batch_mul_shared(
494        a: &[MpcScalarResult<C>],
495        b: &[CurvePointResult<C>],
496    ) -> Vec<MpcPointResult<C>> {
497        assert_eq!(
498            a.len(),
499            b.len(),
500            "batch_mul_shared cannot compute on vectors of unequal length"
501        );
502
503        let n = a.len();
504        let fabric = a[0].fabric();
505        let all_ids = a
506            .iter()
507            .map(|a| a.id())
508            .chain(b.iter().map(|b| b.id()))
509            .collect_vec();
510
511        fabric
512            .new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
513                let a = args.drain(..n).map(Scalar::from).collect_vec();
514                let b = args.into_iter().map(CurvePoint::from).collect_vec();
515
516                a.into_iter()
517                    .zip(b)
518                    .map(|(a, b)| a * b)
519                    .map(ResultValue::Point)
520                    .collect_vec()
521            })
522            .into_iter()
523            .map(MpcPointResult::from)
524            .collect_vec()
525    }
526
527    /// Multiply a batch of `AuthenticatedScalarResult`s with a batch of
528    /// `CurvePointResult<C>`s
529    pub fn batch_mul_authenticated(
530        a: &[AuthenticatedScalarResult<C>],
531        b: &[CurvePointResult<C>],
532    ) -> Vec<AuthenticatedPointResult<C>> {
533        assert_eq!(
534            a.len(),
535            b.len(),
536            "batch_mul_authenticated cannot compute on vectors of unequal length"
537        );
538
539        let n = a.len();
540        let fabric = a[0].fabric();
541        let all_ids = b
542            .iter()
543            .map(|b| b.id())
544            .chain(a.iter().flat_map(|a| a.ids()))
545            .collect_vec();
546
547        let results = fabric.new_batch_gate_op(
548            all_ids,
549            AUTHENTICATED_POINT_RESULT_LEN * n, // output_arity
550            move |mut args| {
551                let points: Vec<CurvePoint<C>> =
552                    args.drain(..n).map(CurvePoint::from).collect_vec();
553
554                let mut results = Vec::with_capacity(AUTHENTICATED_POINT_RESULT_LEN * n);
555
556                for (scalars, point) in args
557                    .chunks_exact(AUTHENTICATED_SCALAR_RESULT_LEN)
558                    .zip(points.into_iter())
559                {
560                    let share = Scalar::from(&scalars[0]);
561                    let mac = Scalar::from(&scalars[1]);
562                    let public_modifier = Scalar::from(&scalars[2]);
563
564                    results.push(ResultValue::Point(point * share));
565                    results.push(ResultValue::Point(point * mac));
566                    results.push(ResultValue::Point(point * public_modifier));
567                }
568
569                results
570            },
571        );
572
573        AuthenticatedPointResult::from_flattened_iterator(results.into_iter())
574    }
575}
576
577// === MulAssign === //
578
579impl<C: CurveGroup> MulAssign<&Scalar<C>> for CurvePoint<C> {
580    fn mul_assign(&mut self, rhs: &Scalar<C>) {
581        self.0 *= rhs.0;
582    }
583}
584
585// -------------------
586// | Iterator Traits |
587// -------------------
588
589impl<C: CurveGroup> Sum for CurvePoint<C> {
590    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
591        iter.fold(CurvePoint::identity(), |acc, x| acc + x)
592    }
593}
594
595impl<C: CurveGroup> Sum for CurvePointResult<C> {
596    /// Assumes the iterator is non-empty
597    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
598        let first = iter.next().expect("empty iterator");
599        iter.fold(first, |acc, x| acc + x)
600    }
601}
602
603/// MSM Implementation
604impl<C: CurveGroup> CurvePoint<C> {
605    /// Compute the multiscalar multiplication of the given scalars and points
606    pub fn msm(scalars: &[Scalar<C>], points: &[CurvePoint<C>]) -> CurvePoint<C> {
607        assert_eq!(
608            scalars.len(),
609            points.len(),
610            "msm cannot compute on vectors of unequal length"
611        );
612
613        let n = scalars.len();
614        if n < MSM_SIZE_THRESHOLD {
615            return scalars.iter().zip(points.iter()).map(|(s, p)| s * p).sum();
616        }
617
618        let affine_points = points.iter().map(|p| p.0.into_affine()).collect_vec();
619        let stripped_scalars = scalars.iter().map(|s| s.0).collect_vec();
620        C::msm(&affine_points, &stripped_scalars)
621            .map(CurvePoint)
622            .unwrap()
623    }
624
625    /// Compute the multiscalar multiplication of the given scalars and points
626    /// represented as streaming iterators
627    pub fn msm_iter<I, J>(scalars: I, points: J) -> CurvePoint<C>
628    where
629        I: IntoIterator<Item = Scalar<C>>,
630        J: IntoIterator<Item = CurvePoint<C>>,
631    {
632        let mut res = CurvePoint::identity();
633        for (scalar_chunk, point_chunk) in scalars
634            .into_iter()
635            .chunks(MSM_CHUNK_SIZE)
636            .into_iter()
637            .zip(points.into_iter().chunks(MSM_CHUNK_SIZE).into_iter())
638        {
639            let scalars: Vec<Scalar<C>> = scalar_chunk.collect();
640            let points: Vec<CurvePoint<C>> = point_chunk.collect();
641            let chunk_res = CurvePoint::msm(&scalars, &points);
642
643            res += chunk_res;
644        }
645
646        res
647    }
648
649    /// Compute the multiscalar multiplication of the given points with
650    /// `ScalarResult`s
651    pub fn msm_results(
652        scalars: &[ScalarResult<C>],
653        points: &[CurvePoint<C>],
654    ) -> CurvePointResult<C> {
655        assert_eq!(
656            scalars.len(),
657            points.len(),
658            "msm cannot compute on vectors of unequal length"
659        );
660
661        let fabric = scalars[0].fabric();
662        let scalar_ids = scalars.iter().map(|s| s.id()).collect_vec();
663
664        // Clone `points` so that the gate closure may capture it
665        let points = points.to_vec();
666        fabric.new_gate_op(scalar_ids, move |args| {
667            let scalars = args.into_iter().map(Scalar::from).collect_vec();
668
669            ResultValue::Point(CurvePoint::msm(&scalars, &points))
670        })
671    }
672
673    /// Compute the multiscalar multiplication of the given points with
674    /// `ScalarResult`s as iterators. Assumes the iterators are non-empty
675    pub fn msm_results_iter<I, J>(scalars: I, points: J) -> CurvePointResult<C>
676    where
677        I: IntoIterator<Item = ScalarResult<C>>,
678        J: IntoIterator<Item = CurvePoint<C>>,
679    {
680        Self::msm_results(
681            &scalars.into_iter().collect_vec(),
682            &points.into_iter().collect_vec(),
683        )
684    }
685
686    /// Compute the multiscalar multiplication of the given authenticated
687    /// scalars and plaintext points
688    pub fn msm_authenticated(
689        scalars: &[AuthenticatedScalarResult<C>],
690        points: &[CurvePoint<C>],
691    ) -> AuthenticatedPointResult<C> {
692        assert_eq!(
693            scalars.len(),
694            points.len(),
695            "msm cannot compute on vectors of unequal length"
696        );
697
698        let n = scalars.len();
699        let fabric = scalars[0].fabric();
700        let scalar_ids = scalars.iter().flat_map(|s| s.ids()).collect_vec();
701
702        // Clone points to let the gate closure take ownership
703        let points = points.to_vec();
704        let res: Vec<CurvePointResult<C>> = fabric.new_batch_gate_op(
705            scalar_ids,
706            AUTHENTICATED_SCALAR_RESULT_LEN, // output_arity
707            move |args| {
708                let mut shares = Vec::with_capacity(n);
709                let mut macs = Vec::with_capacity(n);
710                let mut modifiers = Vec::with_capacity(n);
711
712                for chunk in args.chunks_exact(AUTHENTICATED_SCALAR_RESULT_LEN) {
713                    shares.push(Scalar::from(chunk[0].to_owned()));
714                    macs.push(Scalar::from(chunk[1].to_owned()));
715                    modifiers.push(Scalar::from(chunk[2].to_owned()));
716                }
717
718                // Compute the MSM of the point
719                vec![
720                    CurvePoint::msm(&shares, &points),
721                    CurvePoint::msm(&macs, &points),
722                    CurvePoint::msm(&modifiers, &points),
723                ]
724                .into_iter()
725                .map(ResultValue::Point)
726                .collect_vec()
727            },
728        );
729
730        AuthenticatedPointResult {
731            share: res[0].to_owned().into(),
732            mac: res[1].to_owned().into(),
733            public_modifier: res[2].to_owned(),
734        }
735    }
736
737    /// Compute the multiscalar multiplication of the given authenticated
738    /// scalars and plaintext points as iterators
739    /// This method assumes that the iterators are of the same length
740    pub fn msm_authenticated_iter<I, J>(scalars: I, points: J) -> AuthenticatedPointResult<C>
741    where
742        I: IntoIterator<Item = AuthenticatedScalarResult<C>>,
743        J: IntoIterator<Item = CurvePoint<C>>,
744    {
745        let scalars: Vec<AuthenticatedScalarResult<C>> = scalars.into_iter().collect();
746        let points: Vec<CurvePoint<C>> = points.into_iter().collect();
747
748        Self::msm_authenticated(&scalars, &points)
749    }
750}
751
752impl<C: CurveGroup> CurvePointResult<C> {
753    /// Compute the multiscalar multiplication of the given scalars and points
754    pub fn msm_results(
755        scalars: &[ScalarResult<C>],
756        points: &[CurvePointResult<C>],
757    ) -> CurvePointResult<C> {
758        assert!(!scalars.is_empty(), "msm cannot compute on an empty vector");
759        assert_eq!(
760            scalars.len(),
761            points.len(),
762            "msm cannot compute on vectors of unequal length"
763        );
764
765        let n = scalars.len();
766        let fabric = scalars[0].fabric();
767        let all_ids = scalars
768            .iter()
769            .map(|s| s.id())
770            .chain(points.iter().map(|p| p.id()))
771            .collect_vec();
772
773        fabric.new_gate_op(all_ids, move |mut args| {
774            let scalars = args.drain(..n).map(Scalar::from).collect_vec();
775            let points = args.into_iter().map(CurvePoint::from).collect_vec();
776
777            let res = CurvePoint::msm(&scalars, &points);
778            ResultValue::Point(res)
779        })
780    }
781
782    /// Compute the multiscalar multiplication of the given scalars and points
783    /// represented as streaming iterators
784    ///
785    /// Assumes the iterator is non-empty
786    pub fn msm_results_iter<I, J>(scalars: I, points: J) -> CurvePointResult<C>
787    where
788        I: IntoIterator<Item = ScalarResult<C>>,
789        J: IntoIterator<Item = CurvePointResult<C>>,
790    {
791        Self::msm_results(
792            &scalars.into_iter().collect_vec(),
793            &points.into_iter().collect_vec(),
794        )
795    }
796
797    /// Compute the multiscalar multiplication of the given
798    /// `AuthenticatedScalarResult`s and points
799    pub fn msm_authenticated(
800        scalars: &[AuthenticatedScalarResult<C>],
801        points: &[CurvePointResult<C>],
802    ) -> AuthenticatedPointResult<C> {
803        assert_eq!(
804            scalars.len(),
805            points.len(),
806            "msm cannot compute on vectors of unequal length"
807        );
808
809        let n = scalars.len();
810        let fabric = scalars[0].fabric();
811        let all_ids = scalars
812            .iter()
813            .flat_map(|s| s.ids())
814            .chain(points.iter().map(|p| p.id()))
815            .collect_vec();
816
817        let res = fabric.new_batch_gate_op(
818            all_ids,
819            AUTHENTICATED_POINT_RESULT_LEN, // output_arity
820            move |mut args| {
821                let mut shares = Vec::with_capacity(n);
822                let mut macs = Vec::with_capacity(n);
823                let mut modifiers = Vec::with_capacity(n);
824
825                for mut chunk in args
826                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
827                    .map(Scalar::from)
828                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
829                    .into_iter()
830                {
831                    shares.push(chunk.next().unwrap());
832                    macs.push(chunk.next().unwrap());
833                    modifiers.push(chunk.next().unwrap());
834                }
835
836                let points = args.into_iter().map(CurvePoint::from).collect_vec();
837
838                vec![
839                    CurvePoint::msm(&shares, &points),
840                    CurvePoint::msm(&macs, &points),
841                    CurvePoint::msm(&modifiers, &points),
842                ]
843                .into_iter()
844                .map(ResultValue::Point)
845                .collect_vec()
846            },
847        );
848
849        AuthenticatedPointResult {
850            share: res[0].to_owned().into(),
851            mac: res[1].to_owned().into(),
852            public_modifier: res[2].to_owned(),
853        }
854    }
855
856    /// Compute the multiscalar multiplication of the given
857    /// `AuthenticatedScalarResult`s and points represented as streaming
858    /// iterators
859    pub fn msm_authenticated_iter<I, J>(scalars: I, points: J) -> AuthenticatedPointResult<C>
860    where
861        I: IntoIterator<Item = AuthenticatedScalarResult<C>>,
862        J: IntoIterator<Item = CurvePointResult<C>>,
863    {
864        let scalars: Vec<AuthenticatedScalarResult<C>> = scalars.into_iter().collect();
865        let points: Vec<CurvePointResult<C>> = points.into_iter().collect();
866
867        Self::msm_authenticated(&scalars, &points)
868    }
869}
870
871// ---------
872// | Tests |
873// ---------
874
#[cfg(test)]
mod test {
    use rand::thread_rng;

    use crate::{test_helpers::mock_fabric, test_helpers::TestCurve};

    use super::*;

    /// A curve point on the test curve
    pub type TestCurvePoint = CurvePoint<TestCurve>;

    /// Generate a random point, by multiplying the basepoint with a random
    /// scalar
    pub fn random_point() -> TestCurvePoint {
        let mut rng = thread_rng();
        let scalar = Scalar::<TestCurve>::random(&mut rng);

        // Multiply exactly once: a second multiplication by the same scalar
        // would yield `G * s^2` rather than `G * s`
        TestCurvePoint::generator() * scalar
    }

    /// Tests point addition
    #[tokio::test]
    async fn test_point_addition() {
        let fabric = mock_fabric();

        let p1 = random_point();
        let p2 = random_point();

        let p1_res = fabric.allocate_point(p1);
        let p2_res = fabric.allocate_point(p2);

        let res = (p1_res + p2_res).await;
        let expected_res = p1 + p2;

        assert_eq!(res, expected_res);
        fabric.shutdown();
    }

    /// Tests scalar multiplication
    #[tokio::test]
    async fn test_scalar_mul() {
        let fabric = mock_fabric();

        let mut rng = thread_rng();
        let s1 = Scalar::<TestCurve>::random(&mut rng);
        let p1 = random_point();

        let s1_res = fabric.allocate_scalar(s1);
        let p1_res = fabric.allocate_point(p1);

        let res = (s1_res * p1_res).await;
        let expected_res = s1 * p1;

        assert_eq!(res, expected_res);
        fabric.shutdown();
    }

    /// Tests addition with the additive identity
    #[tokio::test]
    async fn test_additive_identity() {
        let fabric = mock_fabric();

        let p1 = random_point();

        let p1_res = fabric.allocate_point(p1);
        let identity_res = fabric.curve_identity();

        let res = (p1_res + identity_res).await;
        let expected_res = p1;

        assert_eq!(res, expected_res);
        fabric.shutdown();
    }
}
948}