use super::variant::Variant;
use blst::{
    blst_bendian_from_scalar, blst_fp12, blst_fr, blst_fr_add, blst_fr_from_scalar,
    blst_fr_from_uint64, blst_fr_inverse, blst_fr_mul, blst_fr_sub, blst_hash_to_g1,
    blst_hash_to_g2, blst_keygen, blst_p1, blst_p1_add_or_double, blst_p1_affine, blst_p1_compress,
    blst_p1_from_affine, blst_p1_in_g1, blst_p1_is_inf, blst_p1_mult, blst_p1_to_affine,
    blst_p1_uncompress, blst_p1s_mult_pippenger, blst_p1s_mult_pippenger_scratch_sizeof, blst_p2,
    blst_p2_add_or_double, blst_p2_affine, blst_p2_compress, blst_p2_from_affine, blst_p2_in_g2,
    blst_p2_is_inf, blst_p2_mult, blst_p2_to_affine, blst_p2_uncompress, blst_p2s_mult_pippenger,
    blst_p2s_mult_pippenger_scratch_sizeof, blst_scalar, blst_scalar_from_bendian,
    blst_scalar_from_fr, blst_sk_check, BLS12_381_G1, BLS12_381_G2, BLST_ERROR,
};
use bytes::{Buf, BufMut};
use commonware_codec::{
    varint::UInt,
    EncodeSize,
    Error::{self, Invalid},
    FixedSize, Read, ReadExt, Write,
};
use commonware_utils::hex;
use rand::RngCore;
use std::{
    fmt::{Debug, Display},
    hash::{Hash, Hasher},
    mem::MaybeUninit,
    ptr,
};
use zeroize::{Zeroize, ZeroizeOnDrop};

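/// Domain separation tag (DST) used when hashing messages to a curve point.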
pub type DST = &'static [u8];

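/// A primitive that behaves like a group element: it can be encoded and decoded
/// at a fixed size, added, and multiplied by a [Scalar].
///
/// Illustrative sketch (not compiled; `rng` is an assumed `RngCore`):
///
/// ```ignore
/// let mut g = G1::one();          // the generator
/// g.mul(&Scalar::rand(&mut rng)); // scalar multiplication in place
/// g.add(&G1::one());              // group addition in place
/// ```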
pub trait Element:
    Read<Cfg = ()> + Write + FixedSize + Clone + Eq + PartialEq + Send + Sync
{
    /// Returns the additive identity (for points, the point at infinity).
    fn zero() -> Self;

    /// Returns the canonical generator (for scalars, the field element one).
    fn one() -> Self;

    /// Adds `rhs` to `self` in place.
    fn add(&mut self, rhs: &Self);

    /// Multiplies `self` by the scalar `rhs` in place.
    fn mul(&mut self, rhs: &Scalar);
}

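/// A group element that can additionally be produced by hashing to the curve and
/// combined via multi-scalar multiplication.
///
/// Illustrative sketch (not compiled; `points` and `scalars` are assumed to be
/// equal-length slices):
///
/// ```ignore
/// // Computes sum(points[i] * scalars[i]) in a single pass.
/// let combined = G1::msm(&points, &scalars);
/// ```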
pub trait Point: Element {
    /// Hashes `message` onto the curve using the domain separation tag `dst`,
    /// replacing `self` with the result.
    fn map(&mut self, dst: DST, message: &[u8]);

    /// Computes the multi-scalar multiplication `sum(points[i] * scalars[i])`.
    ///
    /// Panics if `points` and `scalars` differ in length.
    fn msm(points: &[Self], scalars: &[Scalar]) -> Self;
}

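/// An element of the BLS12-381 scalar field `Fr`.
///
/// The in-memory representation is zeroized on drop.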
#[derive(Clone, Eq, PartialEq)]
#[repr(transparent)]
pub struct Scalar(blst_fr);

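/// Number of bytes in the canonical big-endian encoding of a [Scalar].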
const SCALAR_LENGTH: usize = 32;

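/// Number of bits needed to represent a scalar (the BLS12-381 group order is 255 bits long).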
const SCALAR_BITS: usize = 255;

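/// The scalar one, expressed in blst's internal (Montgomery) representation of `Fr`.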
const BLST_FR_ONE: Scalar = Scalar(blst_fr {
    l: [
        0x0000_0001_ffff_fffe,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ],
});

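/// An element of the BLS12-381 `G1` group.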
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G1(blst_p1);

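/// Number of bytes in a compressed `G1` element.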
pub const G1_ELEMENT_BYTE_LENGTH: usize = 48;

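/// Domain separation tag used when hashing proofs of possession to `G1`.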
pub const G1_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";

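/// Domain separation tag used when hashing messages to `G1` (proof-of-possession ciphersuite).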
pub const G1_MESSAGE: DST = b"BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";

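/// An element of the BLS12-381 `G2` group.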
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G2(blst_p2);

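/// Number of bytes in a compressed `G2` element.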
pub const G2_ELEMENT_BYTE_LENGTH: usize = 96;

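/// Domain separation tag used when hashing proofs of possession to `G2`.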
pub const G2_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";

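/// Domain separation tag used when hashing messages to `G2` (proof-of-possession ciphersuite).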
pub const G2_MESSAGE: DST = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";

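/// An element of the BLS12-381 target group `GT`, stored as an `Fp12` value.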
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct GT(blst_fp12);

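/// A private key is a scalar.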
pub type Private = Scalar;

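/// Number of bytes in an encoded private key.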
pub const PRIVATE_KEY_LENGTH: usize = SCALAR_LENGTH;

impl Scalar {
    /// Generates a uniformly random scalar by running blst's key generation on
    /// 64 bytes of RNG-provided input keying material.
    pub fn rand<R: RngCore>(rng: &mut R) -> Self {
        let mut ikm = [0u8; 64];
        rng.fill_bytes(&mut ikm);

        let mut ret = blst_fr::default();
        unsafe {
            let mut sc = blst_scalar::default();
            blst_keygen(&mut sc, ikm.as_ptr(), ikm.len(), ptr::null(), 0);
            blst_fr_from_scalar(&mut ret, &sc);
        }

        ikm.zeroize();
        Self(ret)
    }

    /// Sets `self` to the given integer.
    pub fn set_int(&mut self, i: u32) {
        let buffer = [i as u64, 0, 0, 0];
        unsafe { blst_fr_from_uint64(&mut self.0, buffer.as_ptr()) };
    }

    /// Returns the multiplicative inverse of `self`, or `None` if `self` is zero.
    pub fn inverse(&self) -> Option<Self> {
        if *self == Self::zero() {
            return None;
        }
        let mut ret = blst_fr::default();
        unsafe { blst_fr_inverse(&mut ret, &self.0) };
        Some(Self(ret))
    }

    /// Subtracts `rhs` from `self` in place.
    pub fn sub(&mut self, rhs: &Self) {
        unsafe { blst_fr_sub(&mut self.0, &self.0, &rhs.0) }
    }

    /// Returns the canonical big-endian byte encoding of the scalar.
    fn as_slice(&self) -> [u8; Self::SIZE] {
        let mut slice = [0u8; Self::SIZE];
        unsafe {
            let mut scalar = blst_scalar::default();
            blst_scalar_from_fr(&mut scalar, &self.0);
            blst_bendian_from_scalar(slice.as_mut_ptr(), &scalar);
        }
        slice
    }

    /// Converts the scalar into blst's `blst_scalar` representation.
    pub(crate) fn as_blst_scalar(&self) -> blst_scalar {
        let mut scalar = blst_scalar::default();
        unsafe { blst_scalar_from_fr(&mut scalar, &self.0) };
        scalar
    }
}

impl Element for Scalar {
    fn zero() -> Self {
        Self(blst_fr::default())
    }

    fn one() -> Self {
        BLST_FR_ONE
    }

    fn add(&mut self, rhs: &Self) {
        unsafe {
            blst_fr_add(&mut self.0, &self.0, &rhs.0);
        }
    }

    fn mul(&mut self, rhs: &Self) {
        unsafe {
            blst_fr_mul(&mut self.0, &self.0, &rhs.0);
        }
    }
}

impl Write for Scalar {
    fn write(&self, buf: &mut impl BufMut) {
        let slice = self.as_slice();
        buf.put_slice(&slice);
    }
}

impl Read for Scalar {
    type Cfg = ();

    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let bytes = <[u8; Self::SIZE]>::read(buf)?;
        let mut ret = blst_fr::default();
        unsafe {
            let mut scalar = blst_scalar::default();
            blst_scalar_from_bendian(&mut scalar, bytes.as_ptr());
            // Reject encodings that are zero or not less than the group order.
            if !blst_sk_check(&scalar) {
                return Err(Invalid("Scalar", "Invalid"));
            }
            blst_fr_from_scalar(&mut ret, &scalar);
        }
        Ok(Self(ret))
    }
}

impl FixedSize for Scalar {
    const SIZE: usize = SCALAR_LENGTH;
}

impl Hash for Scalar {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let slice = self.as_slice();
        state.write(&slice);
    }
}

impl Debug for Scalar {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", hex(&self.as_slice()))
    }
}

impl Display for Scalar {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", hex(&self.as_slice()))
    }
}

impl Zeroize for Scalar {
    fn zeroize(&mut self) {
        self.0.l.zeroize();
    }
}

impl Drop for Scalar {
    fn drop(&mut self) {
        self.zeroize();
    }
}

impl ZeroizeOnDrop for Scalar {}

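/// A share of a secret-shared private key, held by the participant at `index`.
///
/// Illustrative sketch (not compiled; `share` and a [Variant] implementation `V`
/// are assumed to be in scope):
///
/// ```ignore
/// // Derive the public commitment for this share under variant `V`.
/// let public = share.public::<V>();
/// ```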
#[derive(Clone, PartialEq, Hash)]
pub struct Share {
    /// The index of the participant holding this share.
    pub index: u32,
    /// The private share of the secret.
    pub private: Private,
}

impl Share {
    /// Returns the public commitment for this share, computed by multiplying the
    /// variant's public generator by the private scalar.
    pub fn public<V: Variant>(&self) -> V::Public {
        let mut public = V::Public::one();
        public.mul(&self.private);
        public
    }
}

impl Write for Share {
    fn write(&self, buf: &mut impl BufMut) {
        UInt(self.index).write(buf);
        self.private.write(buf);
    }
}

impl Read for Share {
    type Cfg = ();

    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let index = UInt::read(buf)?.into();
        let private = Private::read(buf)?;
        Ok(Self { index, private })
    }
}

impl EncodeSize for Share {
    fn encode_size(&self) -> usize {
        UInt(self.index).encode_size() + self.private.encode_size()
    }
}

impl Display for Share {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Share(index={}, private={})", self.index, self.private)
    }
}

impl Debug for Share {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Share(index={}, private={})", self.index, self.private)
    }
}

impl G1 {
    /// Returns the compressed byte encoding of the point.
    fn as_slice(&self) -> [u8; Self::SIZE] {
        let mut slice = [0u8; Self::SIZE];
        unsafe {
            blst_p1_compress(slice.as_mut_ptr(), &self.0);
        }
        slice
    }

    /// Converts the point into blst's affine representation.
    pub(crate) fn as_blst_p1_affine(&self) -> blst_p1_affine {
        let mut affine = blst_p1_affine::default();
        unsafe { blst_p1_to_affine(&mut affine, &self.0) };
        affine
    }

    /// Wraps a raw `blst_p1` as a `G1` element.
    pub(crate) fn from_blst_p1(p: blst_p1) -> Self {
        Self(p)
    }
}

impl Element for G1 {
    fn zero() -> Self {
        Self(blst_p1::default())
    }

    fn one() -> Self {
        let mut ret = blst_p1::default();
        unsafe {
            blst_p1_from_affine(&mut ret, &BLS12_381_G1);
        }
        Self(ret)
    }

    fn add(&mut self, rhs: &Self) {
        unsafe {
            blst_p1_add_or_double(&mut self.0, &self.0, &rhs.0);
        }
    }

    fn mul(&mut self, rhs: &Scalar) {
        let mut scalar: blst_scalar = blst_scalar::default();
        unsafe {
            blst_scalar_from_fr(&mut scalar, &rhs.0);
            blst_p1_mult(&mut self.0, &self.0, scalar.b.as_ptr(), SCALAR_BITS);
        }
    }
}

impl Write for G1 {
    fn write(&self, buf: &mut impl BufMut) {
        let slice = self.as_slice();
        buf.put_slice(&slice);
    }
}

impl Read for G1 {
    type Cfg = ();

    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let bytes = <[u8; Self::SIZE]>::read(buf)?;
        let mut ret = blst_p1::default();
        unsafe {
            let mut affine = blst_p1_affine::default();
            match blst_p1_uncompress(&mut affine, bytes.as_ptr()) {
                BLST_ERROR::BLST_SUCCESS => {}
                BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G1", "Bad encoding")),
                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G1", "Not on curve")),
                BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G1", "Not in group")),
                BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G1", "Type mismatch")),
                BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G1", "Verify fail")),
                BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G1", "PK is Infinity")),
                BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G1", "Bad scalar")),
            }
            blst_p1_from_affine(&mut ret, &affine);

            // Reject the identity (point at infinity).
            if blst_p1_is_inf(&ret) {
                return Err(Invalid("G1", "Infinity"));
            }

            // Reject points outside the G1 subgroup.
            if !blst_p1_in_g1(&ret) {
                return Err(Invalid("G1", "Outside G1"));
            }
        }
        Ok(Self(ret))
    }
}

impl FixedSize for G1 {
    const SIZE: usize = G1_ELEMENT_BYTE_LENGTH;
}

impl Hash for G1 {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let slice = self.as_slice();
        state.write(&slice);
    }
}

impl Point for G1 {
    fn map(&mut self, dst: DST, data: &[u8]) {
        unsafe {
            blst_hash_to_g1(
                &mut self.0,
                data.as_ptr(),
                data.len(),
                dst.as_ptr(),
                dst.len(),
                ptr::null(),
                0,
            );
        }
    }

    fn msm(points: &[Self], scalars: &[Scalar]) -> Self {
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");

        // Skip terms that contribute nothing (identity points or zero scalars).
        let mut points_filtered = Vec::with_capacity(points.len());
        let mut scalars_filtered = Vec::with_capacity(scalars.len());
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            if *point == G1::zero() || scalar == &Scalar::zero() {
                continue;
            }

            points_filtered.push(point.as_blst_p1_affine());
            scalars_filtered.push(scalar.as_blst_scalar());
        }

        // If nothing remains, the result is the identity.
        if points_filtered.is_empty() {
            return G1::zero();
        }

        // Build the pointer arrays expected by blst.
        let points: Vec<*const blst_p1_affine> =
            points_filtered.iter().map(|p| p as *const _).collect();
        let scalars: Vec<*const u8> = scalars_filtered.iter().map(|s| s.b.as_ptr()).collect();

        // Allocate the scratch space requested by blst (size reported in bytes).
        let scratch_size = unsafe { blst_p1s_mult_pippenger_scratch_sizeof(points.len()) };
        let mut scratch = vec![MaybeUninit::<u64>::uninit(); scratch_size / 8];

        let mut msm_result = blst_p1::default();
        unsafe {
            blst_p1s_mult_pippenger(
                &mut msm_result,
                points.as_ptr(),
                points.len(),
                scalars.as_ptr(),
                SCALAR_BITS,
                scratch.as_mut_ptr() as *mut _,
            );
        }

        G1::from_blst_p1(msm_result)
    }
}

impl Debug for G1 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", hex(&self.as_slice()))
    }
}

impl Display for G1 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", hex(&self.as_slice()))
    }
}

impl AsRef<G1> for G1 {
    fn as_ref(&self) -> &Self {
        self
    }
}

impl G2 {
    /// Returns the compressed byte encoding of the point.
    fn as_slice(&self) -> [u8; Self::SIZE] {
        let mut slice = [0u8; Self::SIZE];
        unsafe {
            blst_p2_compress(slice.as_mut_ptr(), &self.0);
        }
        slice
    }

    /// Converts the point into blst's affine representation.
    pub(crate) fn as_blst_p2_affine(&self) -> blst_p2_affine {
        let mut affine = blst_p2_affine::default();
        unsafe { blst_p2_to_affine(&mut affine, &self.0) };
        affine
    }

    /// Wraps a raw `blst_p2` as a `G2` element.
    pub(crate) fn from_blst_p2(p: blst_p2) -> Self {
        Self(p)
    }
}

impl Element for G2 {
    fn zero() -> Self {
        Self(blst_p2::default())
    }

    fn one() -> Self {
        let mut ret = blst_p2::default();
        unsafe {
            blst_p2_from_affine(&mut ret, &BLS12_381_G2);
        }
        Self(ret)
    }

    fn add(&mut self, rhs: &Self) {
        unsafe {
            blst_p2_add_or_double(&mut self.0, &self.0, &rhs.0);
        }
    }

    fn mul(&mut self, rhs: &Scalar) {
        let mut scalar = blst_scalar::default();
        unsafe {
            blst_scalar_from_fr(&mut scalar, &rhs.0);
            blst_p2_mult(&mut self.0, &self.0, scalar.b.as_ptr(), SCALAR_BITS);
        }
    }
}

impl Write for G2 {
    fn write(&self, buf: &mut impl BufMut) {
        let slice = self.as_slice();
        buf.put_slice(&slice);
    }
}

impl Read for G2 {
    type Cfg = ();

    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let bytes = <[u8; Self::SIZE]>::read(buf)?;
        let mut ret = blst_p2::default();
        unsafe {
            let mut affine = blst_p2_affine::default();
            match blst_p2_uncompress(&mut affine, bytes.as_ptr()) {
                BLST_ERROR::BLST_SUCCESS => {}
                BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G2", "Bad encoding")),
                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G2", "Not on curve")),
                BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G2", "Not in group")),
                BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G2", "Type mismatch")),
                BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G2", "Verify fail")),
                BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G2", "PK is Infinity")),
                BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G2", "Bad scalar")),
            }
            blst_p2_from_affine(&mut ret, &affine);

            // Reject the identity (point at infinity).
            if blst_p2_is_inf(&ret) {
                return Err(Invalid("G2", "Infinity"));
            }

            // Reject points outside the G2 subgroup.
            if !blst_p2_in_g2(&ret) {
                return Err(Invalid("G2", "Outside G2"));
            }
        }
        Ok(Self(ret))
    }
}

impl FixedSize for G2 {
    const SIZE: usize = G2_ELEMENT_BYTE_LENGTH;
}

impl Hash for G2 {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let slice = self.as_slice();
        state.write(&slice);
    }
}

impl Point for G2 {
    fn map(&mut self, dst: DST, data: &[u8]) {
        unsafe {
            blst_hash_to_g2(
                &mut self.0,
                data.as_ptr(),
                data.len(),
                dst.as_ptr(),
                dst.len(),
                ptr::null(),
                0,
            );
        }
    }

    fn msm(points: &[Self], scalars: &[Scalar]) -> Self {
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");

        // Skip terms that contribute nothing (identity points or zero scalars).
        let mut points_filtered = Vec::with_capacity(points.len());
        let mut scalars_filtered = Vec::with_capacity(scalars.len());
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            if *point == G2::zero() || scalar == &Scalar::zero() {
                continue;
            }
            points_filtered.push(point.as_blst_p2_affine());
            scalars_filtered.push(scalar.as_blst_scalar());
        }

        // If nothing remains, the result is the identity.
        if points_filtered.is_empty() {
            return G2::zero();
        }

        // Build the pointer arrays expected by blst.
        let points: Vec<*const blst_p2_affine> =
            points_filtered.iter().map(|p| p as *const _).collect();
        let scalars: Vec<*const u8> = scalars_filtered.iter().map(|s| s.b.as_ptr()).collect();

        // Allocate the scratch space requested by blst (size reported in bytes).
        let scratch_size = unsafe { blst_p2s_mult_pippenger_scratch_sizeof(points.len()) };
        let mut scratch = vec![MaybeUninit::<u64>::uninit(); scratch_size / 8];

        let mut msm_result = blst_p2::default();
        unsafe {
            blst_p2s_mult_pippenger(
                &mut msm_result,
                points.as_ptr(),
                points.len(),
                scalars.as_ptr(),
                SCALAR_BITS,
                scratch.as_mut_ptr() as *mut _,
            );
        }

        G2::from_blst_p2(msm_result)
    }
}

impl Debug for G2 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", hex(&self.as_slice()))
    }
}

impl Display for G2 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", hex(&self.as_slice()))
    }
}

impl AsRef<G2> for G2 {
    fn as_ref(&self) -> &Self {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{DecodeExt, Encode};
    use rand::prelude::*;

    #[test]
    fn basic_group() {
        let s = Scalar::rand(&mut thread_rng());
        let mut e1 = s.clone();
        let e2 = s.clone();
        let mut s2 = s.clone();
        s2.add(&s);
        s2.mul(&s);
        e1.add(&e2);
        e1.mul(&e2);

        let mut p1 = G1::zero();
        p1.mul(&s2);

        let mut p2 = G1::zero();
        p2.mul(&s);
        p2.add(&p2.clone());
        assert_eq!(p1, p2);
    }

    #[test]
    fn test_scalar_codec() {
        let original = Scalar::rand(&mut thread_rng());
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), Scalar::SIZE);
        let decoded = Scalar::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_g1_codec() {
        let mut original = G1::one();
        original.mul(&Scalar::rand(&mut thread_rng()));
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), G1::SIZE);
        let decoded = G1::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_g2_codec() {
        let mut original = G2::one();
        original.mul(&Scalar::rand(&mut thread_rng()));
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), G2::SIZE);
        let decoded = G2::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Reference implementation: multiply and accumulate each term directly.
    fn naive_msm<P: Point>(points: &[P], scalars: &[Scalar]) -> P {
        assert_eq!(points.len(), scalars.len());
        let mut total = P::zero();
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            if *point == P::zero() || *scalar == Scalar::zero() {
                continue;
            }
            let mut term = point.clone();
            term.mul(scalar);
            total.add(&term);
        }
        total
    }

    #[test]
    fn test_g1_msm() {
        let mut rng = thread_rng();

        // Basic case: random points and random scalars.
        let n = 10;
        let points_g1: Vec<G1> = (0..n)
            .map(|_| {
                let mut point = G1::one();
                point.mul(&Scalar::rand(&mut rng));
                point
            })
            .collect();
        let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::rand(&mut rng)).collect();
        let expected_g1 = naive_msm(&points_g1, &scalars);
        let result_g1 = G1::msm(&points_g1, &scalars);
        assert_eq!(expected_g1, result_g1, "G1 MSM basic case failed");

        // One identity point in the input.
        let mut points_with_zero_g1 = points_g1.clone();
        points_with_zero_g1[n / 2] = G1::zero();
        let expected_zero_pt_g1 = naive_msm(&points_with_zero_g1, &scalars);
        let result_zero_pt_g1 = G1::msm(&points_with_zero_g1, &scalars);
        assert_eq!(
            expected_zero_pt_g1, result_zero_pt_g1,
            "G1 MSM with identity point failed"
        );

        // One zero scalar in the input.
        let mut scalars_with_zero = scalars.clone();
        scalars_with_zero[n / 2] = Scalar::zero();
        let expected_zero_sc_g1 = naive_msm(&points_g1, &scalars_with_zero);
        let result_zero_sc_g1 = G1::msm(&points_g1, &scalars_with_zero);
        assert_eq!(
            expected_zero_sc_g1, result_zero_sc_g1,
            "G1 MSM with zero scalar failed"
        );

        // All identity points.
        let zero_points_g1 = vec![G1::zero(); n];
        let expected_all_zero_pt_g1 = naive_msm(&zero_points_g1, &scalars);
        let result_all_zero_pt_g1 = G1::msm(&zero_points_g1, &scalars);
        assert_eq!(
            expected_all_zero_pt_g1,
            G1::zero(),
            "G1 MSM all identity points (naive) failed"
        );
        assert_eq!(
            result_all_zero_pt_g1,
            G1::zero(),
            "G1 MSM all identity points failed"
        );

        // All zero scalars.
        let zero_scalars = vec![Scalar::zero(); n];
        let expected_all_zero_sc_g1 = naive_msm(&points_g1, &zero_scalars);
        let result_all_zero_sc_g1 = G1::msm(&points_g1, &zero_scalars);
        assert_eq!(
            expected_all_zero_sc_g1,
            G1::zero(),
            "G1 MSM all zero scalars (naive) failed"
        );
        assert_eq!(
            result_all_zero_sc_g1,
            G1::zero(),
            "G1 MSM all zero scalars failed"
        );

        // Single-element input.
        let single_point_g1 = [points_g1[0]];
        let single_scalar = [scalars[0].clone()];
        let expected_single_g1 = naive_msm(&single_point_g1, &single_scalar);
        let result_single_g1 = G1::msm(&single_point_g1, &single_scalar);
        assert_eq!(
            expected_single_g1, result_single_g1,
            "G1 MSM single element failed"
        );

        // Empty input.
        let empty_points_g1: [G1; 0] = [];
        let empty_scalars: [Scalar; 0] = [];
        let expected_empty_g1 = naive_msm(&empty_points_g1, &empty_scalars);
        let result_empty_g1 = G1::msm(&empty_points_g1, &empty_scalars);
        assert_eq!(expected_empty_g1, G1::zero(), "G1 MSM empty (naive) failed");
        assert_eq!(result_empty_g1, G1::zero(), "G1 MSM empty failed");

        // Large case.
        let points_g1: Vec<G1> = (0..50_000)
            .map(|_| {
                let mut point = G1::one();
                point.mul(&Scalar::rand(&mut rng));
                point
            })
            .collect();
        let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::rand(&mut rng)).collect();
        let expected_g1 = naive_msm(&points_g1, &scalars);
        let result_g1 = G1::msm(&points_g1, &scalars);
        assert_eq!(expected_g1, result_g1, "G1 MSM large case failed");
    }

    #[test]
    fn test_g2_msm() {
        let mut rng = thread_rng();

        // Basic case: random points and random scalars.
        let n = 10;
        let points_g2: Vec<G2> = (0..n)
            .map(|_| {
                let mut point = G2::one();
                point.mul(&Scalar::rand(&mut rng));
                point
            })
            .collect();
        let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::rand(&mut rng)).collect();
        let expected_g2 = naive_msm(&points_g2, &scalars);
        let result_g2 = G2::msm(&points_g2, &scalars);
        assert_eq!(expected_g2, result_g2, "G2 MSM basic case failed");

        // One identity point in the input.
        let mut points_with_zero_g2 = points_g2.clone();
        points_with_zero_g2[n / 2] = G2::zero();
        let expected_zero_pt_g2 = naive_msm(&points_with_zero_g2, &scalars);
        let result_zero_pt_g2 = G2::msm(&points_with_zero_g2, &scalars);
        assert_eq!(
            expected_zero_pt_g2, result_zero_pt_g2,
            "G2 MSM with identity point failed"
        );

        // One zero scalar in the input.
        let mut scalars_with_zero = scalars.clone();
        scalars_with_zero[n / 2] = Scalar::zero();
        let expected_zero_sc_g2 = naive_msm(&points_g2, &scalars_with_zero);
        let result_zero_sc_g2 = G2::msm(&points_g2, &scalars_with_zero);
        assert_eq!(
            expected_zero_sc_g2, result_zero_sc_g2,
            "G2 MSM with zero scalar failed"
        );

        // All identity points.
        let zero_points_g2 = vec![G2::zero(); n];
        let expected_all_zero_pt_g2 = naive_msm(&zero_points_g2, &scalars);
        let result_all_zero_pt_g2 = G2::msm(&zero_points_g2, &scalars);
        assert_eq!(
            expected_all_zero_pt_g2,
            G2::zero(),
            "G2 MSM all identity points (naive) failed"
        );
        assert_eq!(
            result_all_zero_pt_g2,
            G2::zero(),
            "G2 MSM all identity points failed"
        );

        // All zero scalars.
        let zero_scalars = vec![Scalar::zero(); n];
        let expected_all_zero_sc_g2 = naive_msm(&points_g2, &zero_scalars);
        let result_all_zero_sc_g2 = G2::msm(&points_g2, &zero_scalars);
        assert_eq!(
            expected_all_zero_sc_g2,
            G2::zero(),
            "G2 MSM all zero scalars (naive) failed"
        );
        assert_eq!(
            result_all_zero_sc_g2,
            G2::zero(),
            "G2 MSM all zero scalars failed"
        );

        // Single-element input.
        let single_point_g2 = [points_g2[0]];
        let single_scalar = [scalars[0].clone()];
        let expected_single_g2 = naive_msm(&single_point_g2, &single_scalar);
        let result_single_g2 = G2::msm(&single_point_g2, &single_scalar);
        assert_eq!(
            expected_single_g2, result_single_g2,
            "G2 MSM single element failed"
        );

        // Empty input.
        let empty_points_g2: [G2; 0] = [];
        let empty_scalars: [Scalar; 0] = [];
        let expected_empty_g2 = naive_msm(&empty_points_g2, &empty_scalars);
        let result_empty_g2 = G2::msm(&empty_points_g2, &empty_scalars);
        assert_eq!(expected_empty_g2, G2::zero(), "G2 MSM empty (naive) failed");
        assert_eq!(result_empty_g2, G2::zero(), "G2 MSM empty failed");

        // Large case.
        let points_g2: Vec<G2> = (0..50_000)
            .map(|_| {
                let mut point = G2::one();
                point.mul(&Scalar::rand(&mut rng));
                point
            })
            .collect();
        let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::rand(&mut rng)).collect();
        let expected_g2 = naive_msm(&points_g2, &scalars);
        let result_g2 = G2::msm(&points_g2, &scalars);
        assert_eq!(expected_g2, result_g2, "G2 MSM large case failed");
    }
}