1use super::variant::Variant;
14use blst::{
15 blst_bendian_from_scalar, blst_expand_message_xmd, blst_fp12, blst_fr, blst_fr_add,
16 blst_fr_from_scalar, blst_fr_from_uint64, blst_fr_inverse, blst_fr_mul, blst_fr_sub,
17 blst_hash_to_g1, blst_hash_to_g2, blst_keygen, blst_p1, blst_p1_add_or_double, blst_p1_affine,
18 blst_p1_compress, blst_p1_from_affine, blst_p1_in_g1, blst_p1_is_inf, blst_p1_mult,
19 blst_p1_to_affine, blst_p1_uncompress, blst_p1s_mult_pippenger,
20 blst_p1s_mult_pippenger_scratch_sizeof, blst_p2, blst_p2_add_or_double, blst_p2_affine,
21 blst_p2_compress, blst_p2_from_affine, blst_p2_in_g2, blst_p2_is_inf, blst_p2_mult,
22 blst_p2_to_affine, blst_p2_uncompress, blst_p2s_mult_pippenger,
23 blst_p2s_mult_pippenger_scratch_sizeof, blst_scalar, blst_scalar_from_be_bytes,
24 blst_scalar_from_bendian, blst_scalar_from_fr, blst_sk_check, BLS12_381_G1, BLS12_381_G2,
25 BLST_ERROR,
26};
27use bytes::{Buf, BufMut};
28use commonware_codec::{
29 varint::UInt,
30 EncodeSize,
31 Error::{self, Invalid},
32 FixedSize, Read, ReadExt, Write,
33};
34use commonware_utils::hex;
35use rand::RngCore;
36use std::{
37 fmt::{Debug, Display},
38 hash::{Hash, Hasher},
39 mem::MaybeUninit,
40 ptr,
41};
42use zeroize::{Zeroize, ZeroizeOnDrop};
43
/// Domain separation tag used when hashing messages to curve points or
/// scalars (see the ciphersuite DSTs below).
pub type DST = &'static [u8];

/// An element of a prime-order group or its scalar field: supports in-place
/// addition and multiplication by a [Scalar], plus fixed-size codec,
/// ordering, and hashing.
pub trait Element:
    Read<Cfg = ()> + Write + FixedSize + Clone + Eq + PartialEq + Ord + PartialOrd + Hash + Send + Sync
{
    /// Returns the additive identity.
    fn zero() -> Self;

    /// Returns the canonical generator (or the scalar one).
    fn one() -> Self;

    /// Adds `rhs` to `self` in place.
    fn add(&mut self, rhs: &Self);

    /// Multiplies `self` by the scalar `rhs` in place.
    fn mul(&mut self, rhs: &Scalar);
}

/// A group element that additionally supports hash-to-curve and
/// multi-scalar multiplication.
pub trait Point: Element {
    /// Hashes `message` to a point using `dst` for domain separation.
    fn map(&mut self, dst: DST, message: &[u8]);

    /// Computes `sum_i points[i] * scalars[i]`.
    ///
    /// NOTE(review): the implementations in this file panic when the two
    /// slices have different lengths.
    fn msm(points: &[Self], scalars: &[Scalar]) -> Self;
}
74
/// A scalar in the BLS12-381 scalar field `Fr`, stored in blst's internal
/// `blst_fr` (Montgomery limb) representation.
#[derive(Clone, Eq, PartialEq)]
#[repr(transparent)]
pub struct Scalar(blst_fr);

/// Length of a big-endian-encoded scalar, in bytes.
pub const SCALAR_LENGTH: usize = 32;

/// Bit length of the BLS12-381 scalar field order `r` (used as the bit
/// count for `blst_*_mult` calls).
const SCALAR_BITS: usize = 255;

/// The scalar one, pre-encoded in blst's internal form (presumably the
/// Montgomery representation `R mod r` — TODO confirm against blst).
const BLST_FR_ONE: Scalar = Scalar(blst_fr {
    l: [
        0x0000_0001_ffff_fffe,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ],
});
122
/// A point in the BLS12-381 group G1, in blst's projective representation.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G1(blst_p1);

/// Length of a compressed G1 point, in bytes.
pub const G1_ELEMENT_BYTE_LENGTH: usize = 48;

/// Ciphersuite DST for proofs of possession hashed to G1.
pub const G1_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";

/// Ciphersuite DST for messages hashed to G1.
pub const G1_MESSAGE: DST = b"BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";

/// A point in the BLS12-381 group G2, in blst's projective representation.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G2(blst_p2);

/// Length of a compressed G2 point, in bytes.
pub const G2_ELEMENT_BYTE_LENGTH: usize = 96;

/// Ciphersuite DST for proofs of possession hashed to G2.
pub const G2_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";

/// Ciphersuite DST for messages hashed to G2.
pub const G2_MESSAGE: DST = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";

/// An element of the pairing target group GT (a `blst_fp12` value).
#[derive(Debug, Clone, Eq, PartialEq, Copy)]
#[repr(transparent)]
pub struct GT(blst_fp12);

/// Size of a `blst_fp12` in bytes (12 * 48).
pub const GT_ELEMENT_BYTE_LENGTH: usize = 576;
173
impl GT {
    /// Wraps a raw `blst_fp12` produced by a pairing computation.
    pub(crate) fn from_blst_fp12(fp12: blst_fp12) -> Self {
        GT(fp12)
    }

    /// Returns the raw in-memory bytes of the underlying `blst_fp12`.
    ///
    /// NOTE(review): this copies blst's internal limb layout rather than a
    /// canonical serialization, so the bytes are only meaningful for
    /// comparison within the same blst build/representation — confirm no
    /// caller treats this as a portable encoding.
    pub fn as_slice(&self) -> [u8; GT_ELEMENT_BYTE_LENGTH] {
        let mut slice = [0u8; GT_ELEMENT_BYTE_LENGTH];
        unsafe {
            // SAFETY-ish: assumes size_of::<blst_fp12>() equals
            // GT_ELEMENT_BYTE_LENGTH (576) — TODO confirm with a static
            // assertion.
            let fp12_ptr = &self.0 as *const blst_fp12 as *const u8;
            std::ptr::copy_nonoverlapping(fp12_ptr, slice.as_mut_ptr(), GT_ELEMENT_BYTE_LENGTH);
        }
        slice
    }
}
192
/// A private key is simply a scalar.
pub type Private = Scalar;

/// Length of an encoded private key, in bytes.
pub const PRIVATE_KEY_LENGTH: usize = SCALAR_LENGTH;
198
199impl Scalar {
200 pub fn from_rand<R: RngCore>(rng: &mut R) -> Self {
202 let mut ikm = [0u8; 64];
204 rng.fill_bytes(&mut ikm);
205
206 let mut ret = blst_fr::default();
208 unsafe {
209 let mut sc = blst_scalar::default();
210 blst_keygen(&mut sc, ikm.as_ptr(), ikm.len(), ptr::null(), 0);
211 blst_fr_from_scalar(&mut ret, &sc);
212 }
213
214 ikm.zeroize();
216
217 Self(ret)
218 }
219
220 pub fn map(dst: DST, msg: &[u8]) -> Self {
222 const L: usize = 48;
236 let mut uniform_bytes = [0u8; L];
237 unsafe {
238 blst_expand_message_xmd(
239 uniform_bytes.as_mut_ptr(),
240 L,
241 msg.as_ptr(),
242 msg.len(),
243 dst.as_ptr(),
244 dst.len(),
245 );
246 }
247
248 let mut fr = blst_fr::default();
250 unsafe {
251 let mut scalar = blst_scalar::default();
252 blst_scalar_from_be_bytes(&mut scalar, uniform_bytes.as_ptr(), L);
253 blst_fr_from_scalar(&mut fr, &scalar);
254 }
255
256 Self(fr)
257 }
258
259 fn from_u64(i: u64) -> Self {
261 let mut ret = blst_fr::default();
263
264 let buffer = [i, 0, 0, 0];
269 unsafe { blst_fr_from_uint64(&mut ret, buffer.as_ptr()) };
270 Self(ret)
271 }
272
273 pub fn from_index(i: u32) -> Self {
275 Self::from(i as u64 + 1)
276 }
277
278 pub fn inverse(&self) -> Option<Self> {
280 if *self == Self::zero() {
281 return None;
282 }
283 let mut ret = blst_fr::default();
284 unsafe { blst_fr_inverse(&mut ret, &self.0) };
285 Some(Self(ret))
286 }
287
288 pub fn sub(&mut self, rhs: &Self) {
290 unsafe { blst_fr_sub(&mut self.0, &self.0, &rhs.0) }
291 }
292
293 fn as_slice(&self) -> [u8; Self::SIZE] {
295 let mut slice = [0u8; Self::SIZE];
296 unsafe {
297 let mut scalar = blst_scalar::default();
298 blst_scalar_from_fr(&mut scalar, &self.0);
299 blst_bendian_from_scalar(slice.as_mut_ptr(), &scalar);
300 }
301 slice
302 }
303
304 pub(crate) fn as_blst_scalar(&self) -> blst_scalar {
306 let mut scalar = blst_scalar::default();
307 unsafe { blst_scalar_from_fr(&mut scalar, &self.0) };
308 scalar
309 }
310}
311
/// Converts a `u32` by widening to `u64` and deferring to that conversion.
impl From<u32> for Scalar {
    fn from(i: u32) -> Self {
        Self::from(i as u64)
    }
}

/// Converts a `u64` into a field element.
impl From<u64> for Scalar {
    fn from(i: u64) -> Self {
        Self::from_u64(i)
    }
}
323
impl Element for Scalar {
    /// The additive identity (all-zero limbs encode zero in any form).
    fn zero() -> Self {
        Self(blst_fr::default())
    }

    /// The multiplicative identity, pre-encoded in blst's internal form.
    fn one() -> Self {
        BLST_FR_ONE
    }

    /// Field addition in place.
    fn add(&mut self, rhs: &Self) {
        unsafe {
            blst_fr_add(&mut self.0, &self.0, &rhs.0);
        }
    }

    /// Field multiplication in place.
    fn mul(&mut self, rhs: &Self) {
        unsafe {
            blst_fr_mul(&mut self.0, &self.0, &rhs.0);
        }
    }
}
345
346impl Write for Scalar {
347 fn write(&self, buf: &mut impl BufMut) {
348 let slice = self.as_slice();
349 buf.put_slice(&slice);
350 }
351}
352
impl Read for Scalar {
    type Cfg = ();

    /// Decodes a 32-byte big-endian scalar, rejecting non-canonical
    /// (out-of-field) values.
    ///
    /// NOTE(review): `blst_sk_check` validates *secret keys*, which also
    /// rejects zero — so an encoded `Scalar::zero()` will not round-trip.
    /// Confirm this asymmetry is intended.
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let bytes = <[u8; Self::SIZE]>::read(buf)?;
        let mut ret = blst_fr::default();
        unsafe {
            let mut scalar = blst_scalar::default();
            blst_scalar_from_bendian(&mut scalar, bytes.as_ptr());
            if !blst_sk_check(&scalar) {
                return Err(Invalid("Scalar", "Invalid"));
            }
            blst_fr_from_scalar(&mut ret, &scalar);
        }
        Ok(Self(ret))
    }
}
379
impl FixedSize for Scalar {
    /// Encoded size: 32 bytes.
    const SIZE: usize = SCALAR_LENGTH;
}

/// Hashes the canonical big-endian encoding (consistent with `Eq`).
impl Hash for Scalar {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let slice = self.as_slice();
        state.write(&slice);
    }
}

impl PartialOrd for Scalar {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders scalars lexicographically by big-endian encoding, which matches
/// numeric order on the canonical values.
impl Ord for Scalar {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.as_slice().cmp(&other.as_slice())
    }
}
402
403impl Debug for Scalar {
404 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405 write!(f, "{}", hex(&self.as_slice()))
406 }
407}
408
409impl Display for Scalar {
410 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411 write!(f, "{}", hex(&self.as_slice()))
412 }
413}
414
impl Zeroize for Scalar {
    /// Clears the scalar's limbs in memory.
    fn zeroize(&mut self) {
        self.0.l.zeroize();
    }
}

/// Scalars may hold secret key material, so wipe them on drop.
impl Drop for Scalar {
    fn drop(&mut self) {
        self.zeroize();
    }
}

impl ZeroizeOnDrop for Scalar {}
428
/// A participant's share of a threshold secret: an index plus the
/// corresponding private scalar.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Share {
    // Zero-based index identifying the participant.
    pub index: u32,
    // The share's private scalar.
    pub private: Private,
}

impl Share {
    /// Derives this share's public commitment (`private * generator`) in
    /// the variant's public group.
    pub fn public<V: Variant>(&self) -> V::Public {
        let mut public = V::Public::one();
        public.mul(&self.private);
        public
    }
}
448
/// Encodes as a varint index followed by the fixed-size private scalar.
impl Write for Share {
    fn write(&self, buf: &mut impl BufMut) {
        UInt(self.index).write(buf);
        self.private.write(buf);
    }
}

impl Read for Share {
    type Cfg = ();

    /// Mirrors [`Write`]: varint index, then the scalar (which is
    /// validated by `Scalar::read_cfg`).
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let index = UInt::read(buf)?.into();
        let private = Private::read(buf)?;
        Ok(Self { index, private })
    }
}

impl EncodeSize for Share {
    /// Varint size of the index plus the fixed scalar size.
    fn encode_size(&self) -> usize {
        UInt(self.index).encode_size() + self.private.encode_size()
    }
}
471
// NOTE(review): both formatters print the private scalar in hex, leaking
// secret share material into logs/debug output — confirm this is intended.
impl Display for Share {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Share(index={}, private={})", self.index, self.private)
    }
}

impl Debug for Share {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Share(index={}, private={})", self.index, self.private)
    }
}
483
impl G1 {
    /// Returns the 48-byte compressed encoding of the point.
    fn as_slice(&self) -> [u8; Self::SIZE] {
        let mut slice = [0u8; Self::SIZE];
        unsafe {
            blst_p1_compress(slice.as_mut_ptr(), &self.0);
        }
        slice
    }

    /// Converts to blst's affine representation (as required by the MSM
    /// entry points).
    pub(crate) fn as_blst_p1_affine(&self) -> blst_p1_affine {
        let mut affine = blst_p1_affine::default();
        unsafe { blst_p1_to_affine(&mut affine, &self.0) };
        affine
    }

    /// Wraps a raw `blst_p1`; the caller is responsible for it being a
    /// valid group element.
    pub(crate) fn from_blst_p1(p: blst_p1) -> Self {
        Self(p)
    }
}
506
impl Element for G1 {
    /// The identity (point at infinity; blst's default all-zero projective
    /// point).
    fn zero() -> Self {
        Self(blst_p1::default())
    }

    /// The standard G1 generator.
    fn one() -> Self {
        let mut ret = blst_p1::default();
        unsafe {
            blst_p1_from_affine(&mut ret, &BLS12_381_G1);
        }
        Self(ret)
    }

    /// Point addition (blst handles the doubling case internally).
    fn add(&mut self, rhs: &Self) {
        unsafe {
            blst_p1_add_or_double(&mut self.0, &self.0, &rhs.0);
        }
    }

    /// Scalar multiplication in place.
    fn mul(&mut self, rhs: &Scalar) {
        let mut scalar: blst_scalar = blst_scalar::default();
        unsafe {
            blst_scalar_from_fr(&mut scalar, &rhs.0);
            // Only SCALAR_BITS (255) bits are meaningful since rhs < r.
            blst_p1_mult(&mut self.0, &self.0, scalar.b.as_ptr(), SCALAR_BITS);
        }
    }
}
536
537impl Write for G1 {
538 fn write(&self, buf: &mut impl BufMut) {
539 let slice = self.as_slice();
540 buf.put_slice(&slice);
541 }
542}
543
impl Read for G1 {
    type Cfg = ();

    /// Decodes a compressed G1 point, rejecting malformed encodings, the
    /// identity, and points outside the prime-order subgroup.
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let bytes = <[u8; Self::SIZE]>::read(buf)?;
        let mut ret = blst_p1::default();
        unsafe {
            let mut affine = blst_p1_affine::default();
            // Exhaustively map each blst error code to a codec error so a
            // new BLST_ERROR variant is surfaced at compile time.
            match blst_p1_uncompress(&mut affine, bytes.as_ptr()) {
                BLST_ERROR::BLST_SUCCESS => {}
                BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G1", "Bad encoding")),
                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G1", "Not on curve")),
                BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G1", "Not in group")),
                BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G1", "Type mismatch")),
                BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G1", "Verify fail")),
                BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G1", "PK is Infinity")),
                BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G1", "Bad scalar")),
            }
            blst_p1_from_affine(&mut ret, &affine);

            // Reject the point at infinity.
            if blst_p1_is_inf(&ret) {
                return Err(Invalid("G1", "Infinity"));
            }

            // Subgroup check: uncompression alone only proves the point is
            // on the curve.
            if !blst_p1_in_g1(&ret) {
                return Err(Invalid("G1", "Outside G1"));
            }
        }
        Ok(Self(ret))
    }
}
577
impl FixedSize for G1 {
    /// Encoded (compressed) size: 48 bytes.
    const SIZE: usize = G1_ELEMENT_BYTE_LENGTH;
}

/// Hashes the compressed encoding (consistent with `Eq`).
impl Hash for G1 {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let slice = self.as_slice();
        state.write(&slice);
    }
}

impl PartialOrd for G1 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders points lexicographically by compressed encoding.
impl Ord for G1 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.as_slice().cmp(&other.as_slice())
    }
}
600
impl Point for G1 {
    /// Hash-to-curve into G1 (blst's XMD:SHA-256 SSWU pipeline), using
    /// `dst` for domain separation.
    fn map(&mut self, dst: DST, data: &[u8]) {
        unsafe {
            blst_hash_to_g1(
                &mut self.0,
                data.as_ptr(),
                data.len(),
                dst.as_ptr(),
                dst.len(),
                ptr::null(),
                0,
            );
        }
    }

    /// Multi-scalar multiplication via blst's Pippenger implementation.
    ///
    /// # Panics
    ///
    /// Panics if `points` and `scalars` have different lengths.
    fn msm(points: &[Self], scalars: &[Scalar]) -> Self {
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");

        // Skip identity points and zero scalars: they contribute nothing
        // to the sum, and this keeps the inputs passed to blst well-formed.
        let mut points_filtered = Vec::with_capacity(points.len());
        let mut scalars_filtered = Vec::with_capacity(scalars.len());
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            if *point == G1::zero() || scalar == &Scalar::zero() {
                continue;
            }

            points_filtered.push(point.as_blst_p1_affine());
            scalars_filtered.push(scalar.as_blst_scalar());
        }

        // Everything filtered out: the sum is the identity.
        if points_filtered.is_empty() {
            return G1::zero();
        }

        // blst takes arrays of *pointers* to affine points / scalar bytes.
        let points: Vec<*const blst_p1_affine> =
            points_filtered.iter().map(|p| p as *const _).collect();
        let scalars: Vec<*const u8> = scalars_filtered.iter().map(|s| s.b.as_ptr()).collect();

        // Scratch size is reported in bytes; allocate it as u64 slots.
        let scratch_size = unsafe { blst_p1s_mult_pippenger_scratch_sizeof(points.len()) };
        let mut scratch = vec![MaybeUninit::<u64>::uninit(); scratch_size / 8];

        let mut msm_result = blst_p1::default();
        unsafe {
            blst_p1s_mult_pippenger(
                &mut msm_result,
                points.as_ptr(),
                points.len(),
                scalars.as_ptr(),
                SCALAR_BITS,
                scratch.as_mut_ptr() as *mut _,
            );
        }

        G1::from_blst_p1(msm_result)
    }
}
674
675impl Debug for G1 {
676 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
677 write!(f, "{}", hex(&self.as_slice()))
678 }
679}
680
681impl Display for G1 {
682 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
683 write!(f, "{}", hex(&self.as_slice()))
684 }
685}
686
// Identity conversion; lets APIs accept `impl AsRef<G1>` uniformly.
impl AsRef<G1> for G1 {
    fn as_ref(&self) -> &Self {
        self
    }
}
692
impl G2 {
    /// Returns the 96-byte compressed encoding of the point.
    fn as_slice(&self) -> [u8; Self::SIZE] {
        let mut slice = [0u8; Self::SIZE];
        unsafe {
            blst_p2_compress(slice.as_mut_ptr(), &self.0);
        }
        slice
    }

    /// Converts to blst's affine representation (as required by the MSM
    /// entry points).
    pub(crate) fn as_blst_p2_affine(&self) -> blst_p2_affine {
        let mut affine = blst_p2_affine::default();
        unsafe { blst_p2_to_affine(&mut affine, &self.0) };
        affine
    }

    /// Wraps a raw `blst_p2`; the caller is responsible for it being a
    /// valid group element.
    pub(crate) fn from_blst_p2(p: blst_p2) -> Self {
        Self(p)
    }
}
715
impl Element for G2 {
    /// The identity (point at infinity; blst's default all-zero projective
    /// point).
    fn zero() -> Self {
        Self(blst_p2::default())
    }

    /// The standard G2 generator.
    fn one() -> Self {
        let mut ret = blst_p2::default();
        unsafe {
            blst_p2_from_affine(&mut ret, &BLS12_381_G2);
        }
        Self(ret)
    }

    /// Point addition (blst handles the doubling case internally).
    fn add(&mut self, rhs: &Self) {
        unsafe {
            blst_p2_add_or_double(&mut self.0, &self.0, &rhs.0);
        }
    }

    /// Scalar multiplication in place.
    fn mul(&mut self, rhs: &Scalar) {
        let mut scalar = blst_scalar::default();
        unsafe {
            blst_scalar_from_fr(&mut scalar, &rhs.0);
            // Only SCALAR_BITS (255) bits are meaningful since rhs < r.
            blst_p2_mult(&mut self.0, &self.0, scalar.b.as_ptr(), SCALAR_BITS);
        }
    }
}
745
746impl Write for G2 {
747 fn write(&self, buf: &mut impl BufMut) {
748 let slice = self.as_slice();
749 buf.put_slice(&slice);
750 }
751}
752
impl Read for G2 {
    type Cfg = ();

    /// Decodes a compressed G2 point, rejecting malformed encodings, the
    /// identity, and points outside the prime-order subgroup.
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let bytes = <[u8; Self::SIZE]>::read(buf)?;
        let mut ret = blst_p2::default();
        unsafe {
            let mut affine = blst_p2_affine::default();
            // Exhaustively map each blst error code to a codec error so a
            // new BLST_ERROR variant is surfaced at compile time.
            match blst_p2_uncompress(&mut affine, bytes.as_ptr()) {
                BLST_ERROR::BLST_SUCCESS => {}
                BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G2", "Bad encoding")),
                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G2", "Not on curve")),
                BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G2", "Not in group")),
                BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G2", "Type mismatch")),
                BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G2", "Verify fail")),
                BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G2", "PK is Infinity")),
                BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G2", "Bad scalar")),
            }
            blst_p2_from_affine(&mut ret, &affine);

            // Reject the point at infinity.
            if blst_p2_is_inf(&ret) {
                return Err(Invalid("G2", "Infinity"));
            }

            // Subgroup check: uncompression alone only proves the point is
            // on the curve.
            if !blst_p2_in_g2(&ret) {
                return Err(Invalid("G2", "Outside G2"));
            }
        }
        Ok(Self(ret))
    }
}
786
impl FixedSize for G2 {
    /// Encoded (compressed) size: 96 bytes.
    const SIZE: usize = G2_ELEMENT_BYTE_LENGTH;
}

/// Hashes the compressed encoding (consistent with `Eq`).
impl Hash for G2 {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let slice = self.as_slice();
        state.write(&slice);
    }
}

impl PartialOrd for G2 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders points lexicographically by compressed encoding.
impl Ord for G2 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.as_slice().cmp(&other.as_slice())
    }
}
809
impl Point for G2 {
    /// Hash-to-curve into G2 (blst's XMD:SHA-256 SSWU pipeline), using
    /// `dst` for domain separation.
    fn map(&mut self, dst: DST, data: &[u8]) {
        unsafe {
            blst_hash_to_g2(
                &mut self.0,
                data.as_ptr(),
                data.len(),
                dst.as_ptr(),
                dst.len(),
                ptr::null(),
                0,
            );
        }
    }

    /// Multi-scalar multiplication via blst's Pippenger implementation.
    ///
    /// # Panics
    ///
    /// Panics if `points` and `scalars` have different lengths.
    fn msm(points: &[Self], scalars: &[Scalar]) -> Self {
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");

        // Skip identity points and zero scalars: they contribute nothing
        // to the sum, and this keeps the inputs passed to blst well-formed.
        let mut points_filtered = Vec::with_capacity(points.len());
        let mut scalars_filtered = Vec::with_capacity(scalars.len());
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            if *point == G2::zero() || scalar == &Scalar::zero() {
                continue;
            }
            points_filtered.push(point.as_blst_p2_affine());
            scalars_filtered.push(scalar.as_blst_scalar());
        }

        // Everything filtered out: the sum is the identity.
        if points_filtered.is_empty() {
            return G2::zero();
        }

        // blst takes arrays of *pointers* to affine points / scalar bytes.
        let points: Vec<*const blst_p2_affine> =
            points_filtered.iter().map(|p| p as *const _).collect();
        let scalars: Vec<*const u8> = scalars_filtered.iter().map(|s| s.b.as_ptr()).collect();

        // Scratch size is reported in bytes; allocate it as u64 slots.
        let scratch_size = unsafe { blst_p2s_mult_pippenger_scratch_sizeof(points.len()) };
        let mut scratch = vec![MaybeUninit::<u64>::uninit(); scratch_size / 8];

        let mut msm_result = blst_p2::default();
        unsafe {
            blst_p2s_mult_pippenger(
                &mut msm_result,
                points.as_ptr(),
                points.len(),
                scalars.as_ptr(),
                SCALAR_BITS,
                scratch.as_mut_ptr() as *mut _,
            );
        }

        G2::from_blst_p2(msm_result)
    }
}
880
881impl Debug for G2 {
882 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
883 write!(f, "{}", hex(&self.as_slice()))
884 }
885}
886
887impl Display for G2 {
888 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
889 write!(f, "{}", hex(&self.as_slice()))
890 }
891}
892
// Identity conversion; lets APIs accept `impl AsRef<G2>` uniformly.
impl AsRef<G2> for G2 {
    fn as_ref(&self) -> &Self {
        self
    }
}
898
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{DecodeExt, Encode};
    use rand::prelude::*;
    use std::collections::{BTreeSet, HashMap};

    // NOTE(review): `p1` and `p2` below both start from `G1::zero()` (the
    // identity), and the identity times any scalar is still the identity,
    // so the final assertion compares infinity with infinity and cannot
    // fail regardless of the scalar arithmetic above it. Consider starting
    // from `G1::one()` and asserting an actual algebraic identity.
    #[test]
    fn basic_group() {
        let s = Scalar::from_rand(&mut thread_rng());
        let mut e1 = s.clone();
        let e2 = s.clone();
        let mut s2 = s.clone();
        s2.add(&s);
        s2.mul(&s);
        e1.add(&e2);
        e1.mul(&e2);

        let mut p1 = G1::zero();
        p1.mul(&s2);

        let mut p2 = G1::zero();
        p2.mul(&s);
        p2.add(&p2.clone());
        assert_eq!(p1, p2);
    }

    // Round-trips a random scalar through the codec.
    #[test]
    fn test_scalar_codec() {
        let original = Scalar::from_rand(&mut thread_rng());
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), Scalar::SIZE);
        let decoded = Scalar::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    // Round-trips a random G1 point through the compressed codec.
    #[test]
    fn test_g1_codec() {
        let mut original = G1::one();
        original.mul(&Scalar::from_rand(&mut thread_rng()));
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), G1::SIZE);
        let decoded = G1::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    // Round-trips a random G2 point through the compressed codec.
    #[test]
    fn test_g2_codec() {
        let mut original = G2::one();
        original.mul(&Scalar::from_rand(&mut thread_rng()));
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), G2::SIZE);
        let decoded = G2::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    // Reference MSM: one scalar multiplication per term, skipping identity
    // points and zero scalars exactly as `Point::msm` does.
    fn naive_msm<P: Point>(points: &[P], scalars: &[Scalar]) -> P {
        assert_eq!(points.len(), scalars.len());
        let mut total = P::zero();
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            if *point == P::zero() || *scalar == Scalar::zero() {
                continue;
            }
            let mut term = point.clone();
            term.mul(scalar);
            total.add(&term);
        }
        total
    }

    #[test]
    fn test_g1_msm() {
        let mut rng = thread_rng();

        // Basic case: random points and scalars.
        let n = 10;
        let points_g1: Vec<G1> = (0..n)
            .map(|_| {
                let mut point = G1::one();
                point.mul(&Scalar::from_rand(&mut rng));
                point
            })
            .collect();
        let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::from_rand(&mut rng)).collect();
        let expected_g1 = naive_msm(&points_g1, &scalars);
        let result_g1 = G1::msm(&points_g1, &scalars);
        assert_eq!(expected_g1, result_g1, "G1 MSM basic case failed");

        // One identity point in the batch.
        let mut points_with_zero_g1 = points_g1.clone();
        points_with_zero_g1[n / 2] = G1::zero();
        let expected_zero_pt_g1 = naive_msm(&points_with_zero_g1, &scalars);
        let result_zero_pt_g1 = G1::msm(&points_with_zero_g1, &scalars);
        assert_eq!(
            expected_zero_pt_g1, result_zero_pt_g1,
            "G1 MSM with identity point failed"
        );

        // One zero scalar in the batch.
        let mut scalars_with_zero = scalars.clone();
        scalars_with_zero[n / 2] = Scalar::zero();
        let expected_zero_sc_g1 = naive_msm(&points_g1, &scalars_with_zero);
        let result_zero_sc_g1 = G1::msm(&points_g1, &scalars_with_zero);
        assert_eq!(
            expected_zero_sc_g1, result_zero_sc_g1,
            "G1 MSM with zero scalar failed"
        );

        // All identity points: the sum must be the identity.
        let zero_points_g1 = vec![G1::zero(); n];
        let expected_all_zero_pt_g1 = naive_msm(&zero_points_g1, &scalars);
        let result_all_zero_pt_g1 = G1::msm(&zero_points_g1, &scalars);
        assert_eq!(
            expected_all_zero_pt_g1,
            G1::zero(),
            "G1 MSM all identity points (naive) failed"
        );
        assert_eq!(
            result_all_zero_pt_g1,
            G1::zero(),
            "G1 MSM all identity points failed"
        );

        // All zero scalars: the sum must be the identity.
        let zero_scalars = vec![Scalar::zero(); n];
        let expected_all_zero_sc_g1 = naive_msm(&points_g1, &zero_scalars);
        let result_all_zero_sc_g1 = G1::msm(&points_g1, &zero_scalars);
        assert_eq!(
            expected_all_zero_sc_g1,
            G1::zero(),
            "G1 MSM all zero scalars (naive) failed"
        );
        assert_eq!(
            result_all_zero_sc_g1,
            G1::zero(),
            "G1 MSM all zero scalars failed"
        );

        // Single-element batch.
        let single_point_g1 = [points_g1[0]];
        let single_scalar = [scalars[0].clone()];
        let expected_single_g1 = naive_msm(&single_point_g1, &single_scalar);
        let result_single_g1 = G1::msm(&single_point_g1, &single_scalar);
        assert_eq!(
            expected_single_g1, result_single_g1,
            "G1 MSM single element failed"
        );

        // Empty batch: must return the identity without calling blst.
        let empty_points_g1: [G1; 0] = [];
        let empty_scalars: [Scalar; 0] = [];
        let expected_empty_g1 = naive_msm(&empty_points_g1, &empty_scalars);
        let result_empty_g1 = G1::msm(&empty_points_g1, &empty_scalars);
        assert_eq!(expected_empty_g1, G1::zero(), "G1 MSM empty (naive) failed");
        assert_eq!(result_empty_g1, G1::zero(), "G1 MSM empty failed");

        // Large batch to exercise the Pippenger path.
        // NOTE(review): this performs 50_000 serial scalar multiplications
        // in `naive_msm`; consider reducing or feature-gating if CI time
        // matters.
        let points_g1: Vec<G1> = (0..50_000)
            .map(|_| {
                let mut point = G1::one();
                point.mul(&Scalar::from_rand(&mut rng));
                point
            })
            .collect();
        let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::from_rand(&mut rng)).collect();
        let expected_g1 = naive_msm(&points_g1, &scalars);
        let result_g1 = G1::msm(&points_g1, &scalars);
        assert_eq!(expected_g1, result_g1, "G1 MSM basic case failed");
    }

    #[test]
    fn test_g2_msm() {
        let mut rng = thread_rng();

        // Basic case: random points and scalars.
        let n = 10;
        let points_g2: Vec<G2> = (0..n)
            .map(|_| {
                let mut point = G2::one();
                point.mul(&Scalar::from_rand(&mut rng));
                point
            })
            .collect();
        let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::from_rand(&mut rng)).collect();
        let expected_g2 = naive_msm(&points_g2, &scalars);
        let result_g2 = G2::msm(&points_g2, &scalars);
        assert_eq!(expected_g2, result_g2, "G2 MSM basic case failed");

        // One identity point in the batch.
        let mut points_with_zero_g2 = points_g2.clone();
        points_with_zero_g2[n / 2] = G2::zero();
        let expected_zero_pt_g2 = naive_msm(&points_with_zero_g2, &scalars);
        let result_zero_pt_g2 = G2::msm(&points_with_zero_g2, &scalars);
        assert_eq!(
            expected_zero_pt_g2, result_zero_pt_g2,
            "G2 MSM with identity point failed"
        );

        // One zero scalar in the batch.
        let mut scalars_with_zero = scalars.clone();
        scalars_with_zero[n / 2] = Scalar::zero();
        let expected_zero_sc_g2 = naive_msm(&points_g2, &scalars_with_zero);
        let result_zero_sc_g2 = G2::msm(&points_g2, &scalars_with_zero);
        assert_eq!(
            expected_zero_sc_g2, result_zero_sc_g2,
            "G2 MSM with zero scalar failed"
        );

        // All identity points: the sum must be the identity.
        let zero_points_g2 = vec![G2::zero(); n];
        let expected_all_zero_pt_g2 = naive_msm(&zero_points_g2, &scalars);
        let result_all_zero_pt_g2 = G2::msm(&zero_points_g2, &scalars);
        assert_eq!(
            expected_all_zero_pt_g2,
            G2::zero(),
            "G2 MSM all identity points (naive) failed"
        );
        assert_eq!(
            result_all_zero_pt_g2,
            G2::zero(),
            "G2 MSM all identity points failed"
        );

        // All zero scalars: the sum must be the identity.
        let zero_scalars = vec![Scalar::zero(); n];
        let expected_all_zero_sc_g2 = naive_msm(&points_g2, &zero_scalars);
        let result_all_zero_sc_g2 = G2::msm(&points_g2, &zero_scalars);
        assert_eq!(
            expected_all_zero_sc_g2,
            G2::zero(),
            "G2 MSM all zero scalars (naive) failed"
        );
        assert_eq!(
            result_all_zero_sc_g2,
            G2::zero(),
            "G2 MSM all zero scalars failed"
        );

        // Single-element batch.
        let single_point_g2 = [points_g2[0]];
        let single_scalar = [scalars[0].clone()];
        let expected_single_g2 = naive_msm(&single_point_g2, &single_scalar);
        let result_single_g2 = G2::msm(&single_point_g2, &single_scalar);
        assert_eq!(
            expected_single_g2, result_single_g2,
            "G2 MSM single element failed"
        );

        // Empty batch: must return the identity without calling blst.
        let empty_points_g2: [G2; 0] = [];
        let empty_scalars: [Scalar; 0] = [];
        let expected_empty_g2 = naive_msm(&empty_points_g2, &empty_scalars);
        let result_empty_g2 = G2::msm(&empty_points_g2, &empty_scalars);
        assert_eq!(expected_empty_g2, G2::zero(), "G2 MSM empty (naive) failed");
        assert_eq!(result_empty_g2, G2::zero(), "G2 MSM empty failed");

        // Large batch to exercise the Pippenger path.
        // NOTE(review): 50_000 serial scalar multiplications — see the G1
        // test for the same concern.
        let points_g2: Vec<G2> = (0..50_000)
            .map(|_| {
                let mut point = G2::one();
                point.mul(&Scalar::from_rand(&mut rng));
                point
            })
            .collect();
        let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::from_rand(&mut rng)).collect();
        let expected_g2 = naive_msm(&points_g2, &scalars);
        let result_g2 = G2::msm(&points_g2, &scalars);
        assert_eq!(expected_g2, result_g2, "G2 MSM basic case failed");
    }

    // Exercises Ord/Hash/Eq consistency via BTreeSet and HashMap.
    #[test]
    fn test_trait_implementations() {
        let mut rng = thread_rng();
        const NUM_ITEMS: usize = 10;
        let mut scalar_set = BTreeSet::new();
        let mut g1_set = BTreeSet::new();
        let mut g2_set = BTreeSet::new();
        let mut share_set = BTreeSet::new();
        while scalar_set.len() < NUM_ITEMS {
            let scalar = Scalar::from_rand(&mut rng);
            let mut g1 = G1::one();
            g1.mul(&scalar);
            let mut g2 = G2::one();
            g2.mul(&scalar);
            let share = Share {
                index: scalar_set.len() as u32,
                private: scalar.clone(),
            };

            scalar_set.insert(scalar);
            g1_set.insert(g1);
            g2_set.insert(g2);
            share_set.insert(share);
        }

        // Distinct inputs must stay distinct in every set.
        assert_eq!(scalar_set.len(), NUM_ITEMS);
        assert_eq!(g1_set.len(), NUM_ITEMS);
        assert_eq!(g2_set.len(), NUM_ITEMS);
        assert_eq!(share_set.len(), NUM_ITEMS);

        // BTreeSet iteration must follow the Ord implementations.
        let scalars: Vec<_> = scalar_set.iter().collect();
        assert!(scalars.windows(2).all(|w| w[0] <= w[1]));
        let g1s: Vec<_> = g1_set.iter().collect();
        assert!(g1s.windows(2).all(|w| w[0] <= w[1]));
        let g2s: Vec<_> = g2_set.iter().collect();
        assert!(g2s.windows(2).all(|w| w[0] <= w[1]));
        let shares: Vec<_> = share_set.iter().collect();
        assert!(shares.windows(2).all(|w| w[0] <= w[1]));

        // Hash must be consistent with Eq (no collapsed entries).
        let scalar_map: HashMap<_, _> = scalar_set.iter().cloned().zip(0..).collect();
        let g1_map: HashMap<_, _> = g1_set.iter().cloned().zip(0..).collect();
        let g2_map: HashMap<_, _> = g2_set.iter().cloned().zip(0..).collect();
        let share_map: HashMap<_, _> = share_set.iter().cloned().zip(0..).collect();

        assert_eq!(scalar_map.len(), NUM_ITEMS);
        assert_eq!(g1_map.len(), NUM_ITEMS);
        assert_eq!(g2_map.len(), NUM_ITEMS);
        assert_eq!(share_map.len(), NUM_ITEMS);
    }

    // Determinism and domain separation of Scalar::map.
    #[test]
    fn test_scalar_map() {
        let msg = b"test message";
        let dst = b"TEST_DST";
        let scalar1 = Scalar::map(dst, msg);
        let scalar2 = Scalar::map(dst, msg);
        assert_eq!(scalar1, scalar2, "Same input should produce same output");

        let msg2 = b"different message";
        let scalar3 = Scalar::map(dst, msg2);
        assert_ne!(
            scalar1, scalar3,
            "Different messages should produce different scalars"
        );

        let dst2 = b"DIFFERENT_DST";
        let scalar4 = Scalar::map(dst2, msg);
        assert_ne!(
            scalar1, scalar4,
            "Different DSTs should produce different scalars"
        );

        let empty_msg = b"";
        let scalar_empty = Scalar::map(dst, empty_msg);
        assert_ne!(
            scalar_empty,
            Scalar::zero(),
            "Empty message should not produce zero"
        );

        let large_msg = vec![0x42u8; 1000];
        let scalar_large = Scalar::map(dst, &large_msg);
        assert_ne!(
            scalar_large,
            Scalar::zero(),
            "Large message should not produce zero"
        );

        assert_ne!(
            scalar1,
            Scalar::zero(),
            "Hash should not produce zero scalar"
        );
    }
}