commonware_cryptography/bls12381/primitives/
group.rs

//! Group operations over the BLS12-381 scalar field.
//!
//! This module implements basic group operations over BLS12-381 elements,
//! including point addition, scalar multiplication, and pairing operations.
//!
//! # Warning
//!
//! Ensure that points are checked to belong to the correct subgroup
//! (G1 or G2) to prevent small subgroup attacks. This is particularly important
//! when handling deserialized points or points received from untrusted sources. This
//! is already taken care of for you if you use the provided `deserialize` function.
12
13use super::variant::Variant;
14#[cfg(not(feature = "std"))]
15use alloc::{vec, vec::Vec};
16use blst::{
17    blst_bendian_from_fp12, blst_bendian_from_scalar, blst_expand_message_xmd, blst_fp12, blst_fr,
18    blst_fr_add, blst_fr_cneg, blst_fr_from_scalar, blst_fr_from_uint64, blst_fr_inverse,
19    blst_fr_mul, blst_fr_sub, blst_hash_to_g1, blst_hash_to_g2, blst_keygen, blst_p1,
20    blst_p1_add_or_double, blst_p1_affine, blst_p1_cneg, blst_p1_compress, blst_p1_from_affine,
21    blst_p1_in_g1, blst_p1_is_inf, blst_p1_mult, blst_p1_to_affine, blst_p1_uncompress,
22    blst_p1s_mult_pippenger, blst_p1s_mult_pippenger_scratch_sizeof, blst_p2,
23    blst_p2_add_or_double, blst_p2_affine, blst_p2_cneg, blst_p2_compress, blst_p2_from_affine,
24    blst_p2_in_g2, blst_p2_is_inf, blst_p2_mult, blst_p2_to_affine, blst_p2_uncompress,
25    blst_p2s_mult_pippenger, blst_p2s_mult_pippenger_scratch_sizeof, blst_scalar,
26    blst_scalar_from_be_bytes, blst_scalar_from_bendian, blst_scalar_from_fr, blst_sk_check,
27    BLS12_381_G1, BLS12_381_G2, BLST_ERROR,
28};
29use bytes::{Buf, BufMut};
30use commonware_codec::{
31    varint::UInt,
32    EncodeSize,
33    Error::{self, Invalid},
34    FixedSize, Read, ReadExt, Write,
35};
36use commonware_math::algebra::{
37    Additive, CryptoGroup, Field, HashToGroup, Multiplicative, Object, Random, Ring, Space,
38};
39use commonware_utils::hex;
40use core::{
41    fmt::{Debug, Display, Formatter},
42    hash::{Hash, Hasher},
43    mem::MaybeUninit,
44    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
45    ptr,
46};
47use rand_core::CryptoRngCore;
48use zeroize::{Zeroize, ZeroizeOnDrop};
49
/// Domain separation tag used when hashing a message to a curve (G1 or G2).
///
/// Reference: <https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-bls-signature-05#name-ciphersuites>
pub type DST = &'static [u8];

/// Wrapper around [blst_fr] that represents an element of the BLS12‑381
/// scalar field `F_r`.
///
/// The new‑type is marked `#[repr(transparent)]`, so it has exactly the same
/// memory layout as the underlying `blst_fr`, allowing safe passage across
/// the C FFI boundary without additional transmutation.
///
/// All arithmetic is performed modulo the prime
/// `r = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001`,
/// the order of the BLS12‑381 G1/G2 groups.
///
/// NOTE(review): the derived `Eq`/`PartialEq` compares raw Montgomery limbs;
/// this presumes every `blst_fr` this module constructs is fully reduced —
/// TODO confirm against blst's invariants.
#[derive(Clone, Eq, PartialEq)]
#[repr(transparent)]
pub struct Scalar(blst_fr);
68
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Scalar {
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Draw 32 raw bytes and reduce them into the scalar field.
        let raw = u.arbitrary::<[u8; SCALAR_LENGTH]>()?;
        let mut fr = blst_fr::default();
        // SAFETY: `raw` is a valid 32-byte array; blst_scalar_from_bendian
        // performs the modular reduction before the fr conversion.
        unsafe {
            let mut scalar = blst_scalar::default();
            blst_scalar_from_bendian(&mut scalar, raw.as_ptr());
            blst_fr_from_scalar(&mut fr, &scalar);
        }
        // Scalars are punned to private keys throughout this module and must
        // never be zero, so substitute the multiplicative identity instead.
        let candidate = Self(fr);
        if candidate == <Self as Additive>::zero() {
            return Ok(BLST_FR_ONE);
        }
        Ok(candidate)
    }
}
91
/// Number of bytes required to encode a scalar in its canonical
/// big‑endian form (`32 × 8 = 256 bits`).
///
/// Serialization in this module goes through `blst_bendian_from_scalar`,
/// i.e. big‑endian byte order.
///
/// Because `r` is only 255 bits wide, the most‑significant byte is always in
/// the range `0x00‥=0x7f`, leaving the top bit clear.
pub const SCALAR_LENGTH: usize = 32;

/// Effective bit‑length of the field modulus `r` (`⌈log_2 r⌉ = 255`).
///
/// Useful for constant‑time exponentiation loops and for validating that a
/// decoded integer lies in the range `0 ≤ x < r`.
const SCALAR_BITS: usize = 255;

/// This constant serves as the multiplicative identity (i.e., "one") in the
/// BLS12-381 finite field, ensuring that arithmetic is carried out within the
/// correct modulo.
///
/// `R = 2^256 mod q` in little-endian Montgomery form which is equivalent to 1 in little-endian
/// non-Montgomery form:
///
/// ```txt
/// mod(2^256, 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001) = 0x1824b159acc5056f998c4fefecbc4ff55884b7fa0003480200000001fffffffe
/// ```
///
/// Reference: <https://github.com/filecoin-project/blstrs/blob/ffbb41d1495d84e40a712583346439924603b49a/src/scalar.rs#L77-L89>
const BLST_FR_ONE: Scalar = Scalar(blst_fr {
    // Limbs are least-significant first, matching blst's internal layout.
    l: [
        0x0000_0001_ffff_fffe,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ],
});

/// A point on the BLS12-381 G1 curve.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G1(blst_p1);

/// The size in bytes of an encoded (compressed) G1 element.
pub const G1_ELEMENT_BYTE_LENGTH: usize = 48;

/// Domain separation tag for hashing a proof of possession (compressed G2) to G1.
pub const G1_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";

/// Domain separation tag for hashing a message to G1.
///
/// We use the `POP` scheme for hashing all messages because this crate is expected to be
/// used in a Byzantine environment (where any player may attempt a rogue key attack) and
/// any message could be aggregated into a multi-signature (which requires a proof-of-possession
/// to be safely deployed in this environment).
pub const G1_MESSAGE: DST = b"BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";
144
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for G1 {
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Derive a group element by scaling the generator with an arbitrary
        // (never-zero, see Scalar::arbitrary) scalar.
        let scalar = u.arbitrary::<Scalar>()?;
        Ok(Self::generator() * &scalar)
    }
}
151
/// A point on the BLS12-381 G2 curve.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G2(blst_p2);

/// The size in bytes of an encoded (compressed) G2 element.
pub const G2_ELEMENT_BYTE_LENGTH: usize = 96;

/// Domain separation tag for hashing a proof of possession (compressed G1) to G2.
pub const G2_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";

/// Domain separation tag for hashing a message to G2.
///
/// We use the `POP` scheme for hashing all messages because this crate is expected to be
/// used in a Byzantine environment (where any player may attempt a rogue key attack) and
/// any message could be aggregated into a multi-signature (which requires a proof-of-possession
/// to be safely deployed in this environment).
pub const G2_MESSAGE: DST = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";
170
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for G2 {
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Mirror the G1 strategy: generator times an arbitrary non-zero scalar.
        let scalar = u.arbitrary::<Scalar>()?;
        Ok(Self::generator() * &scalar)
    }
}
177
/// The target group of the BLS12-381 pairing.
///
/// This is an element in the extension field `F_p^12` and is
/// produced as the result of a pairing operation.
#[derive(Debug, Clone, Eq, PartialEq, Copy)]
#[repr(transparent)]
pub struct GT(blst_fp12);

/// The size in bytes of an encoded GT element.
///
/// GT is a 12-tuple of Fp elements, each 48 bytes.
pub const GT_ELEMENT_BYTE_LENGTH: usize = 576;

impl GT {
    /// Create GT from blst_fp12.
    pub(crate) const fn from_blst_fp12(fp12: blst_fp12) -> Self {
        Self(fp12)
    }

    /// Converts the GT element to its canonical big-endian byte representation.
    pub fn as_slice(&self) -> [u8; GT_ELEMENT_BYTE_LENGTH] {
        let mut slice = [0u8; GT_ELEMENT_BYTE_LENGTH];
        // SAFETY: blst_bendian_from_fp12 writes exactly 576 bytes to a valid buffer.
        // Using the proper serialization function ensures portable, canonical encoding.
        unsafe {
            blst_bendian_from_fp12(slice.as_mut_ptr(), &self.0);
        }
        slice
    }
}

/// The private key type.
///
/// Private keys are scalars; see [Scalar] for the zeroize-on-drop behavior.
pub type Private = Scalar;

/// The private key length.
pub const PRIVATE_KEY_LENGTH: usize = SCALAR_LENGTH;
214
impl Scalar {
    /// Derives a scalar from 64 bytes of input keying material via `blst_keygen`,
    /// then zeroizes the buffer.
    ///
    /// Takes `ikm` by value so this function owns — and can scrub — the secret bytes.
    fn from_bytes(mut ikm: [u8; 64]) -> Self {
        // Generate a scalar from the randomly populated buffer
        let mut ret = blst_fr::default();
        // SAFETY: ikm is a valid 64-byte buffer; blst_keygen handles null key_info.
        unsafe {
            let mut sc = blst_scalar::default();
            blst_keygen(&mut sc, ikm.as_ptr(), ikm.len(), ptr::null(), 0);
            blst_fr_from_scalar(&mut ret, &sc);
        }

        // Zeroize the ikm buffer
        ikm.zeroize();

        Self(ret)
    }

    /// Maps arbitrary bytes to a scalar using RFC9380 hash-to-field.
    pub fn map(dst: DST, msg: &[u8]) -> Self {
        // The BLS12-381 scalar field has a modulus of approximately 255 bits.
        // According to RFC9380, when mapping to a field element, we need to
        // generate uniform bytes with length L = ceil((ceil(log2(p)) + k) / 8),
        // where p is the field modulus and k is the security parameter.
        //
        // For BLS12-381's scalar field:
        // - log2(p) ≈ 255 bits
        // - k = 128 bits (for 128-bit security)
        // - L = ceil((255 + 128) / 8) = ceil(383 / 8) = 48 bytes
        //
        // These 48 bytes provide sufficient entropy to ensure uniform distribution
        // in the scalar field after modular reduction, maintaining the security
        // properties required by the hash-to-field construction.
        const L: usize = 48;
        let mut uniform_bytes = [0u8; L];
        // SAFETY: All buffers are valid with correct lengths; blst handles empty inputs.
        unsafe {
            blst_expand_message_xmd(
                uniform_bytes.as_mut_ptr(),
                L,
                msg.as_ptr(),
                msg.len(),
                dst.as_ptr(),
                dst.len(),
            );
        }

        // Transform expanded bytes with modular reduction
        let mut fr = blst_fr::default();
        // SAFETY: uniform_bytes is a valid 48-byte buffer.
        unsafe {
            let mut scalar = blst_scalar::default();
            blst_scalar_from_be_bytes(&mut scalar, uniform_bytes.as_ptr(), L);
            blst_fr_from_scalar(&mut fr, &scalar);
        }

        Self(fr)
    }

    /// Creates a new scalar from the provided integer.
    pub(crate) fn from_u64(i: u64) -> Self {
        // Create a new scalar
        let mut ret = blst_fr::default();

        // blst expects four 64-bit limbs (little-endian); only the low limb is set.
        let buffer = [i, 0, 0, 0];

        // SAFETY: blst_fr_from_uint64 reads exactly 4 u64 values from the buffer.
        //
        // Reference: https://github.com/supranational/blst/blob/415d4f0e2347a794091836a3065206edfd9c72f3/bindings/blst.h#L102
        unsafe { blst_fr_from_uint64(&mut ret, buffer.as_ptr()) };
        Self(ret)
    }

    /// Encodes the scalar into a slice (32-byte big-endian, canonical).
    fn as_slice(&self) -> [u8; Self::SIZE] {
        let mut slice = [0u8; Self::SIZE];
        // SAFETY: All pointers valid; blst_bendian_from_scalar writes exactly 32 bytes.
        unsafe {
            let mut scalar = blst_scalar::default();
            blst_scalar_from_fr(&mut scalar, &self.0);
            blst_bendian_from_scalar(slice.as_mut_ptr(), &scalar);
        }
        slice
    }

    /// Converts the scalar to the raw `blst_scalar` type.
    pub(crate) fn as_blst_scalar(&self) -> blst_scalar {
        let mut scalar = blst_scalar::default();
        // SAFETY: Both pointers are valid and properly aligned.
        unsafe { blst_scalar_from_fr(&mut scalar, &self.0) };
        scalar
    }
}
307
308impl Write for Scalar {
309    fn write(&self, buf: &mut impl BufMut) {
310        let slice = self.as_slice();
311        buf.put_slice(&slice);
312    }
313}
314
impl Read for Scalar {
    type Cfg = ();

    /// Decodes a 32-byte big-endian scalar, rejecting values that are zero
    /// or not less than the modulus `r`.
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let bytes = <[u8; Self::SIZE]>::read(buf)?;
        let mut ret = blst_fr::default();
        // SAFETY: bytes is a valid 32-byte array. blst_sk_check validates non-zero and in-range.
        // We use blst_sk_check instead of blst_scalar_fr_check because it also checks non-zero
        // per IETF BLS12-381 spec (Draft 4+).
        //
        // References:
        // * https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-bls-signature-03#section-2.3
        // * https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-bls-signature-04#section-2.3
        unsafe {
            let mut scalar = blst_scalar::default();
            blst_scalar_from_bendian(&mut scalar, bytes.as_ptr());
            if !blst_sk_check(&scalar) {
                return Err(Invalid("Scalar", "Invalid"));
            }
            blst_fr_from_scalar(&mut ret, &scalar);
        }
        Ok(Self(ret))
    }
}

impl FixedSize for Scalar {
    // Canonical encoding is always exactly 32 bytes.
    const SIZE: usize = SCALAR_LENGTH;
}
343
344impl Hash for Scalar {
345    fn hash<H: Hasher>(&self, state: &mut H) {
346        let slice = self.as_slice();
347        state.write(&slice);
348    }
349}
350
351impl PartialOrd for Scalar {
352    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
353        Some(self.cmp(other))
354    }
355}
356
357impl Ord for Scalar {
358    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
359        self.as_slice().cmp(&other.as_slice())
360    }
361}
362
363impl Debug for Scalar {
364    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
365        write!(f, "{}", hex(&self.as_slice()))
366    }
367}
368
369impl Display for Scalar {
370    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
371        write!(f, "{}", hex(&self.as_slice()))
372    }
373}
374
impl Zeroize for Scalar {
    /// Clears the scalar's limbs so secret key material does not linger in memory.
    fn zeroize(&mut self) {
        self.0.l.zeroize();
    }
}

impl Drop for Scalar {
    // Scalars double as private keys, so always scrub them on drop.
    fn drop(&mut self) {
        self.zeroize();
    }
}

impl ZeroizeOnDrop for Scalar {}

impl Object for Scalar {}
390
impl<'a> AddAssign<&'a Self> for Scalar {
    fn add_assign(&mut self, rhs: &'a Self) {
        let ptr = &raw mut self.0;
        // SAFETY: blst_fr_add supports in-place (ret==a). Raw pointer avoids aliased refs.
        unsafe {
            blst_fr_add(ptr, ptr, &rhs.0);
        }
    }
}

impl<'a> Add<&'a Self> for Scalar {
    type Output = Self;

    // Consumes self and mutates in place to avoid an extra field copy.
    fn add(mut self, rhs: &'a Self) -> Self::Output {
        self += rhs;
        self
    }
}

impl<'a> SubAssign<&'a Self> for Scalar {
    fn sub_assign(&mut self, rhs: &'a Self) {
        let ptr = &raw mut self.0;
        // SAFETY: blst_fr_sub supports in-place (ret==a). Raw pointer avoids aliased refs.
        unsafe { blst_fr_sub(ptr, ptr, &rhs.0) }
    }
}

impl<'a> Sub<&'a Self> for Scalar {
    type Output = Self;

    fn sub(mut self, rhs: &'a Self) -> Self::Output {
        self -= rhs;
        self
    }
}

impl Neg for Scalar {
    type Output = Self;

    fn neg(mut self) -> Self::Output {
        let ptr = &raw mut self.0;
        // SAFETY: blst_fr_cneg supports in-place (ret==a). Raw pointer avoids aliased refs.
        // The `true` flag makes the conditional negate unconditional.
        unsafe {
            blst_fr_cneg(ptr, ptr, true);
        }
        self
    }
}

impl Additive for Scalar {
    /// Returns the additive identity (zeroed limbs encode 0 in Montgomery form as well).
    fn zero() -> Self {
        Self(blst_fr::default())
    }
}
445
impl<'a> MulAssign<&'a Self> for Scalar {
    fn mul_assign(&mut self, rhs: &'a Self) {
        let ptr = &raw mut self.0;
        // SAFETY: blst_fr_mul supports in-place (ret==a). Raw pointer avoids aliased refs.
        unsafe {
            blst_fr_mul(ptr, ptr, &rhs.0);
        }
    }
}

impl<'a> Mul<&'a Self> for Scalar {
    type Output = Self;

    fn mul(mut self, rhs: &'a Self) -> Self::Output {
        self *= rhs;
        self
    }
}

impl Multiplicative for Scalar {}

impl Ring for Scalar {
    /// Returns the multiplicative identity (1, pre-encoded in Montgomery form).
    fn one() -> Self {
        BLST_FR_ONE
    }
}

impl Field for Scalar {
    /// Computes the multiplicative inverse.
    ///
    /// By convention, returns zero for a zero input (which has no inverse).
    fn inv(&self) -> Self {
        if *self == Self::zero() {
            return Self::zero();
        }
        let mut ret = blst_fr::default();
        // SAFETY: Input is non-zero (checked above); blst_fr_inverse is defined for non-zero.
        unsafe { blst_fr_inverse(&mut ret, &self.0) };
        Self(ret)
    }
}
484
485impl Random for Scalar {
486    fn random(mut rng: impl CryptoRngCore) -> Self {
487        // Generate a random 64 byte buffer
488        let mut ikm = [0u8; 64];
489        rng.fill_bytes(&mut ikm);
490        Self::from_bytes(ikm)
491    }
492}
493
/// A share of a threshold signing key.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct Share {
    /// The share's index in the polynomial.
    pub index: u32,
    /// The scalar corresponding to the share's secret.
    pub private: Private,
}

impl AsRef<Private> for Share {
    /// Borrows the share's secret scalar.
    fn as_ref(&self) -> &Private {
        &self.private
    }
}

impl Share {
    /// Returns the public key corresponding to the share.
    ///
    /// This can be verified against the public polynomial.
    pub fn public<V: Variant>(&self) -> V::Public {
        V::Public::generator() * &self.private
    }
}
518
519impl Write for Share {
520    fn write(&self, buf: &mut impl BufMut) {
521        UInt(self.index).write(buf);
522        self.private.write(buf);
523    }
524}
525
526impl Read for Share {
527    type Cfg = ();
528
529    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
530        let index = UInt::read(buf)?.into();
531        let private = Private::read(buf)?;
532        Ok(Self { index, private })
533    }
534}
535
536impl EncodeSize for Share {
537    fn encode_size(&self) -> usize {
538        UInt(self.index).encode_size() + self.private.encode_size()
539    }
540}
541
542impl Display for Share {
543    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
544        write!(f, "Share(index={}, private={})", self.index, self.private)
545    }
546}
547
548impl Debug for Share {
549    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
550        write!(f, "Share(index={}, private={})", self.index, self.private)
551    }
552}
553
impl G1 {
    /// Encodes the G1 element into a slice (48-byte compressed form).
    fn as_slice(&self) -> [u8; Self::SIZE] {
        let mut slice = [0u8; Self::SIZE];
        // SAFETY: blst_p1_compress writes exactly 48 bytes to a valid buffer.
        unsafe {
            blst_p1_compress(slice.as_mut_ptr(), &self.0);
        }
        slice
    }

    /// Like [`std::ops::Neg::neg`], except operating in place.
    ///
    /// This function exists in order to avoid an extra copy when implementing
    /// subtraction. Basically, the compiler (including LLVM) isn't smart
    /// enough to eliminate a copy that happens if you implement subtraction
    /// as `x += &-*rhs`. So, instead, we copy rhs, negate it in place, and then
    /// add it, to avoid a copy.
    fn neg_in_place(&mut self) {
        let ptr = &raw mut self.0;
        // SAFETY: ptr is valid; `true` makes the conditional negate unconditional.
        unsafe {
            blst_p1_cneg(ptr, true);
        }
    }

    /// Converts the G1 point to its affine representation.
    pub(crate) fn as_blst_p1_affine(&self) -> blst_p1_affine {
        let mut affine = blst_p1_affine::default();
        // SAFETY: Both pointers are valid and properly aligned.
        unsafe { blst_p1_to_affine(&mut affine, &self.0) };
        affine
    }

    /// Creates a G1 point from a raw `blst_p1`.
    pub(crate) const fn from_blst_p1(p: blst_p1) -> Self {
        Self(p)
    }
}
593
594impl Write for G1 {
595    fn write(&self, buf: &mut impl BufMut) {
596        let slice = self.as_slice();
597        buf.put_slice(&slice);
598    }
599}
600
impl Read for G1 {
    type Cfg = ();

    /// Decodes a compressed G1 point, rejecting malformed encodings, the point
    /// at infinity, and points outside the G1 subgroup (small-subgroup defense).
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let bytes = <[u8; Self::SIZE]>::read(buf)?;
        let mut ret = blst_p1::default();
        // SAFETY: bytes is a valid 48-byte array. blst_p1_uncompress validates encoding.
        // Additional checks for infinity and subgroup membership prevent small subgroup attacks.
        unsafe {
            let mut affine = blst_p1_affine::default();
            // Exhaustive match on BLST_ERROR: a new blst variant forces a compile error here.
            match blst_p1_uncompress(&mut affine, bytes.as_ptr()) {
                BLST_ERROR::BLST_SUCCESS => {}
                BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G1", "Bad encoding")),
                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G1", "Not on curve")),
                BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G1", "Not in group")),
                BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G1", "Type mismatch")),
                BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G1", "Verify fail")),
                BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G1", "PK is Infinity")),
                BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G1", "Bad scalar")),
            }
            blst_p1_from_affine(&mut ret, &affine);

            // Verify that deserialized element isn't infinite
            if blst_p1_is_inf(&ret) {
                return Err(Invalid("G1", "Infinity"));
            }

            // Verify that the deserialized element is in G1
            if !blst_p1_in_g1(&ret) {
                return Err(Invalid("G1", "Outside G1"));
            }
        }
        Ok(Self(ret))
    }
}

impl FixedSize for G1 {
    // Compressed encoding is always exactly 48 bytes.
    const SIZE: usize = G1_ELEMENT_BYTE_LENGTH;
}
640
641impl Hash for G1 {
642    fn hash<H: Hasher>(&self, state: &mut H) {
643        let slice = self.as_slice();
644        state.write(&slice);
645    }
646}
647
648impl PartialOrd for G1 {
649    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
650        Some(self.cmp(other))
651    }
652}
653
654impl Ord for G1 {
655    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
656        self.as_slice().cmp(&other.as_slice())
657    }
658}
659
660impl Debug for G1 {
661    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
662        write!(f, "{}", hex(&self.as_slice()))
663    }
664}
665
666impl Display for G1 {
667    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
668        write!(f, "{}", hex(&self.as_slice()))
669    }
670}
671
672impl Object for G1 {}
673
impl<'a> AddAssign<&'a Self> for G1 {
    fn add_assign(&mut self, rhs: &'a Self) {
        let ptr = &raw mut self.0;
        // SAFETY: blst_p1_add_or_double supports in-place (ret==a). Raw pointer avoids aliased refs.
        unsafe {
            blst_p1_add_or_double(ptr, ptr, &rhs.0);
        }
    }
}

impl<'a> Add<&'a Self> for G1 {
    type Output = Self;

    fn add(mut self, rhs: &'a Self) -> Self::Output {
        self += rhs;
        self
    }
}

impl Neg for G1 {
    type Output = Self;

    fn neg(mut self) -> Self::Output {
        self.neg_in_place();
        self
    }
}

impl<'a> SubAssign<&'a Self> for G1 {
    fn sub_assign(&mut self, rhs: &'a Self) {
        // Copy + in-place negation avoids the extra copy `-*rhs` would incur
        // (see [G1::neg_in_place]).
        let mut rhs_cp = *rhs;
        rhs_cp.neg_in_place();
        *self += &rhs_cp;
    }
}

impl<'a> Sub<&'a Self> for G1 {
    type Output = Self;

    fn sub(mut self, rhs: &'a Self) -> Self::Output {
        self -= rhs;
        self
    }
}

impl Additive for G1 {
    /// Returns the additive identity (the point at infinity).
    fn zero() -> Self {
        Self(blst_p1::default())
    }
}
724
impl<'a> MulAssign<&'a Scalar> for G1 {
    fn mul_assign(&mut self, rhs: &'a Scalar) {
        let ptr = &raw mut self.0;
        let mut scalar: blst_scalar = blst_scalar::default();
        // SAFETY: blst_p1_mult supports in-place (ret==a). Using SCALAR_BITS (255) ensures
        // constant-time execution. Raw pointer avoids aliased refs.
        unsafe {
            blst_scalar_from_fr(&mut scalar, &rhs.0);
            blst_p1_mult(ptr, ptr, scalar.b.as_ptr(), SCALAR_BITS);
        }
    }
}

impl<'a> Mul<&'a Scalar> for G1 {
    type Output = Self;

    fn mul(mut self, rhs: &'a Scalar) -> Self::Output {
        self *= rhs;
        self
    }
}
746
impl Space<Scalar> for G1 {
    /// Performs multi-scalar multiplication (MSM) on G1 points using Pippenger's algorithm.
    /// Computes `sum(scalars[i] * points[i])`.
    ///
    /// Filters out pairs where the point is the identity element (infinity) or the
    /// scalar is zero, since `blst` does not handle those inputs itself.
    ///
    /// The `_concurrency` hint is ignored by this implementation.
    ///
    /// # Panics
    ///
    /// Panics if the lengths of the input slices mismatch.
    fn msm(points: &[Self], scalars: &[Scalar], _concurrency: usize) -> Self {
        // Assert input validity
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");

        // Prepare points (affine) and scalars (raw blst_scalar)
        let mut points_filtered = Vec::with_capacity(points.len());
        let mut scalars_filtered = Vec::with_capacity(scalars.len());
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            // `blst` does not filter out infinity, so we must ensure it is impossible.
            //
            // Sources:
            // * https://github.com/supranational/blst/blob/cbc7e166a10d7286b91a3a7bea341e708962db13/src/multi_scalar.c#L10-L12
            // * https://github.com/MystenLabs/fastcrypto/blob/0acf0ff1a163c60e0dec1e16e4fbad4a4cf853bd/fastcrypto/src/groups/bls12381.rs#L160-L194
            if *point == Self::zero() || *scalar == Scalar::zero() {
                continue;
            }

            // Add to filtered vectors
            points_filtered.push(point.as_blst_p1_affine());
            scalars_filtered.push(scalar.as_blst_scalar());
        }

        // If all points were filtered, return zero.
        if points_filtered.is_empty() {
            return Self::zero();
        }

        // Create vectors of pointers for the blst API.
        // These vectors hold pointers *to* the elements in the filtered vectors above.
        let points: Vec<*const blst_p1_affine> =
            points_filtered.iter().map(|p| p as *const _).collect();
        let scalars: Vec<*const u8> = scalars_filtered.iter().map(|s| s.b.as_ptr()).collect();

        // Allocate scratch space for Pippenger's algorithm.
        // SAFETY: blst_p1s_mult_pippenger_scratch_sizeof returns size in bytes for valid input.
        let scratch_size = unsafe { blst_p1s_mult_pippenger_scratch_sizeof(points.len()) };
        // Ensure scratch_size is a multiple of 8 to avoid truncation in division.
        assert_eq!(scratch_size % 8, 0, "scratch_size must be multiple of 8");
        let mut scratch = vec![MaybeUninit::<u64>::uninit(); scratch_size / 8];

        // Perform multi-scalar multiplication
        let mut msm_result = blst_p1::default();
        // SAFETY: All pointer arrays are valid and point to data that outlives this call.
        // points_filtered and scalars_filtered remain alive until after this block.
        unsafe {
            blst_p1s_mult_pippenger(
                &mut msm_result,
                points.as_ptr(),
                points.len(),
                scalars.as_ptr(),
                SCALAR_BITS, // Using SCALAR_BITS (255) ensures full scalar range
                scratch.as_mut_ptr() as *mut _,
            );
        }

        Self::from_blst_p1(msm_result)
    }
}
811
impl CryptoGroup for G1 {
    type Scalar = Scalar;

    /// Returns the standard G1 generator defined by the BLS12-381 specification.
    fn generator() -> Self {
        let mut ret = blst_p1::default();
        // SAFETY: BLS12_381_G1 is a valid generator point constant.
        unsafe {
            blst_p1_from_affine(&mut ret, &BLS12_381_G1);
        }
        Self(ret)
    }
}

impl HashToGroup for G1 {
    /// Hashes `message` to a G1 point (RFC 9380 hash-to-curve) under the
    /// provided domain separator.
    fn hash_to_group(domain_separator: &[u8], message: &[u8]) -> Self {
        let mut out = blst_p1::default();
        // SAFETY: All pointers valid; blst_hash_to_g1 handles empty data. Aug is null/0 (unused).
        unsafe {
            blst_hash_to_g1(
                &mut out,
                message.as_ptr(),
                message.len(),
                domain_separator.as_ptr(),
                domain_separator.len(),
                ptr::null(),
                0,
            );
        }
        Self(out)
    }
}
843
impl G2 {
    /// Encodes the G2 element into a slice (96-byte compressed form).
    fn as_slice(&self) -> [u8; Self::SIZE] {
        let mut slice = [0u8; Self::SIZE];
        // SAFETY: blst_p2_compress writes exactly 96 bytes to a valid buffer.
        unsafe {
            blst_p2_compress(slice.as_mut_ptr(), &self.0);
        }
        slice
    }

    /// c.f. [G1::neg_in_place].
    fn neg_in_place(&mut self) {
        let ptr = &raw mut self.0;
        // SAFETY: ptr is valid; `true` makes the conditional negate unconditional.
        unsafe {
            blst_p2_cneg(ptr, true);
        }
    }

    /// Converts the G2 point to its affine representation.
    pub(crate) fn as_blst_p2_affine(&self) -> blst_p2_affine {
        let mut affine = blst_p2_affine::default();
        // SAFETY: Both pointers are valid and properly aligned.
        unsafe { blst_p2_to_affine(&mut affine, &self.0) };
        affine
    }

    /// Creates a G2 point from a raw `blst_p2`.
    pub(crate) const fn from_blst_p2(p: blst_p2) -> Self {
        Self(p)
    }
}
877
878impl Write for G2 {
879    fn write(&self, buf: &mut impl BufMut) {
880        let slice = self.as_slice();
881        buf.put_slice(&slice);
882    }
883}
884
885impl Read for G2 {
886    type Cfg = ();
887
888    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
889        let bytes = <[u8; Self::SIZE]>::read(buf)?;
890        let mut ret = blst_p2::default();
891        // SAFETY: bytes is a valid 96-byte array. blst_p2_uncompress validates encoding.
892        // Additional checks for infinity and subgroup membership prevent small subgroup attacks.
893        unsafe {
894            let mut affine = blst_p2_affine::default();
895            match blst_p2_uncompress(&mut affine, bytes.as_ptr()) {
896                BLST_ERROR::BLST_SUCCESS => {}
897                BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G2", "Bad encoding")),
898                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G2", "Not on curve")),
899                BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G2", "Not in group")),
900                BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G2", "Type mismatch")),
901                BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G2", "Verify fail")),
902                BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G2", "PK is Infinity")),
903                BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G2", "Bad scalar")),
904            }
905            blst_p2_from_affine(&mut ret, &affine);
906
907            // Verify that deserialized element isn't infinite
908            if blst_p2_is_inf(&ret) {
909                return Err(Invalid("G2", "Infinity"));
910            }
911
912            // Verify that the deserialized element is in G2
913            if !blst_p2_in_g2(&ret) {
914                return Err(Invalid("G2", "Outside G2"));
915            }
916        }
917        Ok(Self(ret))
918    }
919}
920
impl FixedSize for G2 {
    // Compressed G2 points always occupy a fixed number of bytes (96).
    const SIZE: usize = G2_ELEMENT_BYTE_LENGTH;
}
924
925impl Hash for G2 {
926    fn hash<H: Hasher>(&self, state: &mut H) {
927        let slice = self.as_slice();
928        state.write(&slice);
929    }
930}
931
impl PartialOrd for G2 {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        // Delegate to `Ord`: the ordering is total, so this is always `Some`.
        Some(self.cmp(other))
    }
}
937
938impl Ord for G2 {
939    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
940        self.as_slice().cmp(&other.as_slice())
941    }
942}
943
944impl Debug for G2 {
945    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
946        write!(f, "{}", hex(&self.as_slice()))
947    }
948}
949
950impl Display for G2 {
951    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
952        write!(f, "{}", hex(&self.as_slice()))
953    }
954}
955
956impl Object for G2 {}
957
958impl<'a> AddAssign<&'a Self> for G2 {
959    fn add_assign(&mut self, rhs: &'a Self) {
960        let ptr = &raw mut self.0;
961        // SAFETY: blst_p2_add_or_double supports in-place (ret==a). Raw pointer avoids aliased refs.
962        unsafe {
963            blst_p2_add_or_double(ptr, ptr, &rhs.0);
964        }
965    }
966}
967
968impl<'a> Add<&'a Self> for G2 {
969    type Output = Self;
970
971    fn add(mut self, rhs: &'a Self) -> Self::Output {
972        self += rhs;
973        self
974    }
975}
976
977impl Neg for G2 {
978    type Output = Self;
979
980    fn neg(mut self) -> Self::Output {
981        self.neg_in_place();
982        self
983    }
984}
985
986impl<'a> SubAssign<&'a Self> for G2 {
987    fn sub_assign(&mut self, rhs: &'a Self) {
988        let mut rhs_cp = *rhs;
989        rhs_cp.neg_in_place();
990        *self += &rhs_cp;
991    }
992}
993
994impl<'a> Sub<&'a Self> for G2 {
995    type Output = Self;
996
997    fn sub(mut self, rhs: &'a Self) -> Self::Output {
998        self -= rhs;
999        self
1000    }
1001}
1002
1003impl Additive for G2 {
1004    fn zero() -> Self {
1005        Self(blst_p2::default())
1006    }
1007}
1008
1009impl<'a> MulAssign<&'a Scalar> for G2 {
1010    fn mul_assign(&mut self, rhs: &'a Scalar) {
1011        let mut scalar = blst_scalar::default();
1012        let ptr = &raw mut self.0;
1013        // SAFETY: blst_p2_mult supports in-place (ret==a). Using SCALAR_BITS (255) ensures
1014        // constant-time execution. Raw pointer avoids aliased refs.
1015        unsafe {
1016            blst_scalar_from_fr(&mut scalar, &rhs.0);
1017            blst_p2_mult(ptr, ptr, scalar.b.as_ptr(), SCALAR_BITS);
1018        }
1019    }
1020}
1021
1022impl<'a> Mul<&'a Scalar> for G2 {
1023    type Output = Self;
1024
1025    fn mul(mut self, rhs: &'a Scalar) -> Self::Output {
1026        self *= rhs;
1027        self
1028    }
1029}
1030
impl Space<Scalar> for G2 {
    /// Performs multi-scalar multiplication (MSM) on G2 points using Pippenger's algorithm.
    /// Computes `sum(scalars[i] * points[i])`.
    ///
    /// Filters out pairs where the point is the identity element (infinity) or the
    /// scalar is zero, since `blst` does not tolerate identity inputs.
    ///
    /// # Panics
    ///
    /// Panics if the lengths of the input slices mismatch.
    fn msm(points: &[Self], scalars: &[Scalar], _concurrency: usize) -> Self {
        // Assert input validity
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");

        // Prepare points (affine) and scalars (raw blst_scalar), filtering identity points
        let mut points_filtered = Vec::with_capacity(points.len());
        let mut scalars_filtered = Vec::with_capacity(scalars.len());
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            // `blst` does not filter out infinity, so we must ensure it is impossible.
            //
            // Sources:
            // * https://github.com/supranational/blst/blob/cbc7e166a10d7286b91a3a7bea341e708962db13/src/multi_scalar.c#L10-L12
            // * https://github.com/MystenLabs/fastcrypto/blob/0acf0ff1a163c60e0dec1e16e4fbad4a4cf853bd/fastcrypto/src/groups/bls12381.rs#L160-L194
            if *point == Self::zero() || *scalar == Scalar::zero() {
                continue;
            }
            points_filtered.push(point.as_blst_p2_affine());
            scalars_filtered.push(scalar.as_blst_scalar());
        }

        // If all points were filtered, return zero.
        if points_filtered.is_empty() {
            return Self::zero();
        }

        // Create vectors of pointers for the blst API
        let points: Vec<*const blst_p2_affine> =
            points_filtered.iter().map(|p| p as *const _).collect();
        let scalars: Vec<*const u8> = scalars_filtered.iter().map(|s| s.b.as_ptr()).collect();

        // Allocate scratch space for Pippenger algorithm
        // SAFETY: blst_p2s_mult_pippenger_scratch_sizeof returns size in bytes for valid input.
        let scratch_size = unsafe { blst_p2s_mult_pippenger_scratch_sizeof(points.len()) };
        // Ensure scratch_size is a multiple of 8 to avoid truncation in division.
        assert_eq!(scratch_size % 8, 0, "scratch_size must be multiple of 8");
        // Scratch is uninitialized u64 slots; blst writes before reading.
        let mut scratch = vec![MaybeUninit::<u64>::uninit(); scratch_size / 8];

        // Perform multi-scalar multiplication
        let mut msm_result = blst_p2::default();
        // SAFETY: All pointer arrays are valid and point to data that outlives this call.
        // points_filtered and scalars_filtered remain alive until after this block.
        unsafe {
            blst_p2s_mult_pippenger(
                &mut msm_result,
                points.as_ptr(),
                points.len(),
                scalars.as_ptr(),
                SCALAR_BITS, // Using SCALAR_BITS (255) ensures full scalar range
                scratch.as_mut_ptr() as *mut _,
            );
        }

        Self::from_blst_p2(msm_result)
    }
}
1092
1093impl CryptoGroup for G2 {
1094    type Scalar = Scalar;
1095
1096    fn generator() -> Self {
1097        let mut ret = blst_p2::default();
1098        // SAFETY: BLS12_381_G2 is a valid generator point constant.
1099        unsafe {
1100            blst_p2_from_affine(&mut ret, &BLS12_381_G2);
1101        }
1102        Self(ret)
1103    }
1104}
1105
1106impl HashToGroup for G2 {
1107    fn hash_to_group(domain_separator: &[u8], message: &[u8]) -> Self {
1108        let mut out = blst_p2::default();
1109        // SAFETY: All pointers valid; blst_hash_to_g2 handles empty data. Aug is null/0 (unused).
1110        unsafe {
1111            blst_hash_to_g2(
1112                &mut out,
1113                message.as_ptr(),
1114                message.len(),
1115                domain_separator.as_ptr(),
1116                domain_separator.len(),
1117                ptr::null(),
1118                0,
1119            );
1120        }
1121        Self(out)
1122    }
1123}
1124
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{DecodeExt, Encode};
    use commonware_math::algebra::test_suites;
    use proptest::prelude::*;
    use rand::prelude::*;
    use std::collections::{BTreeSet, HashMap};

    // Proptest strategy for scalars: derive a field element from 64 uniform bytes.
    impl Arbitrary for Scalar {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            any::<[u8; 64]>().prop_map(Self::from_bytes).boxed()
        }
    }

    // Proptest strategy for G1: the identity, the generator, or a random multiple
    // of the generator (covers the edge cases alongside generic points).
    impl Arbitrary for G1 {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            prop_oneof![
                Just(Self::zero()),
                Just(Self::generator()),
                any::<Scalar>().prop_map(|s| Self::generator() * &s)
            ]
            .boxed()
        }
    }

    // Proptest strategy for G2: same shape as the G1 strategy above.
    impl Arbitrary for G2 {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            prop_oneof![
                Just(Self::zero()),
                Just(Self::generator()),
                any::<Scalar>().prop_map(|s| Self::generator() * &s)
            ]
            .boxed()
        }
    }

    #[test]
    fn test_scalar_as_field() {
        test_suites::test_field(file!(), &any::<Scalar>());
    }

    #[test]
    fn test_g1_as_space() {
        test_suites::test_space_ring(file!(), &any::<Scalar>(), &any::<G1>());
    }

    #[test]
    fn test_g2_as_space() {
        test_suites::test_space_ring(file!(), &any::<Scalar>(), &any::<G2>());
    }

    #[test]
    fn test_hash_to_g1() {
        test_suites::test_hash_to_group::<G1>(file!());
    }

    #[test]
    fn test_hash_to_g2() {
        test_suites::test_hash_to_group::<G2>(file!());
    }

    // Sanity check: (s + s) * G == 2 * (s * G), i.e. doubling commutes with
    // scalar multiplication by the generator.
    #[test]
    fn basic_group() {
        // Reference: https://github.com/celo-org/celo-threshold-bls-rs/blob/b0ef82ff79769d085a5a7d3f4fe690b1c8fe6dc9/crates/threshold-bls/src/curve/bls12381.rs#L200-L220
        let s = Scalar::random(&mut thread_rng());
        let mut s2 = s.clone();
        s2.double();

        // p1 = s2 * G = (s+s)G
        let p1 = G1::generator() * &s2;

        // p2 = sG + sG = s2 * G
        let mut p2 = G1::generator() * &s;
        p2.double();
        assert_eq!(p1, p2);
    }

    // Round-trip a random scalar through encode/decode.
    #[test]
    fn test_scalar_codec() {
        let original = Scalar::random(&mut thread_rng());
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), Scalar::SIZE);
        let decoded = Scalar::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    // Round-trip a random G1 point through encode/decode.
    #[test]
    fn test_g1_codec() {
        let original = G1::generator() * &Scalar::random(&mut thread_rng());
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), G1::SIZE);
        let decoded = G1::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    // Round-trip a random G2 point through encode/decode.
    #[test]
    fn test_g2_codec() {
        let original = G2::generator() * &Scalar::random(&mut thread_rng());
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), G2::SIZE);
        let decoded = G2::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Naive calculation of Multi-Scalar Multiplication: sum(scalar * point)
    ///
    /// Serves as the reference implementation against which the
    /// Pippenger-based `msm` is validated.
    fn naive_msm<P: Space<Scalar>>(points: &[P], scalars: &[Scalar]) -> P {
        assert_eq!(points.len(), scalars.len());
        let mut total = P::zero();
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            // Skip identity points or zero scalars, similar to the optimized MSM
            if *point == P::zero() || *scalar == Scalar::zero() {
                continue;
            }
            let term = point.clone() * scalar;
            total += &term;
        }
        total
    }

    #[test]
    fn test_g1_msm() {
        let mut rng = thread_rng();
        let n = 10; // Number of points/scalars

        // Case 1: Random points and scalars
        let points_g1: Vec<G1> = (0..n)
            .map(|_| G1::generator() * &Scalar::random(&mut rng))
            .collect();
        let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::random(&mut rng)).collect();
        let expected_g1 = naive_msm(&points_g1, &scalars);
        let result_g1 = G1::msm(&points_g1, &scalars, 1);
        assert_eq!(expected_g1, result_g1, "G1 MSM basic case failed");

        // Case 2: Include identity point
        let mut points_with_zero_g1 = points_g1.clone();
        points_with_zero_g1[n / 2] = G1::zero();
        let expected_zero_pt_g1 = naive_msm(&points_with_zero_g1, &scalars);
        let result_zero_pt_g1 = G1::msm(&points_with_zero_g1, &scalars, 1);
        assert_eq!(
            expected_zero_pt_g1, result_zero_pt_g1,
            "G1 MSM with identity point failed"
        );

        // Case 3: Include zero scalar
        let mut scalars_with_zero = scalars.clone();
        scalars_with_zero[n / 2] = Scalar::zero();
        let expected_zero_sc_g1 = naive_msm(&points_g1, &scalars_with_zero);
        let result_zero_sc_g1 = G1::msm(&points_g1, &scalars_with_zero, 1);
        assert_eq!(
            expected_zero_sc_g1, result_zero_sc_g1,
            "G1 MSM with zero scalar failed"
        );

        // Case 4: All points identity
        let zero_points_g1 = vec![G1::zero(); n];
        let expected_all_zero_pt_g1 = naive_msm(&zero_points_g1, &scalars);
        let result_all_zero_pt_g1 = G1::msm(&zero_points_g1, &scalars, 1);
        assert_eq!(
            expected_all_zero_pt_g1,
            G1::zero(),
            "G1 MSM all identity points (naive) failed"
        );
        assert_eq!(
            result_all_zero_pt_g1,
            G1::zero(),
            "G1 MSM all identity points failed"
        );

        // Case 5: All scalars zero
        let zero_scalars = vec![Scalar::zero(); n];
        let expected_all_zero_sc_g1 = naive_msm(&points_g1, &zero_scalars);
        let result_all_zero_sc_g1 = G1::msm(&points_g1, &zero_scalars, 1);
        assert_eq!(
            expected_all_zero_sc_g1,
            G1::zero(),
            "G1 MSM all zero scalars (naive) failed"
        );
        assert_eq!(
            result_all_zero_sc_g1,
            G1::zero(),
            "G1 MSM all zero scalars failed"
        );

        // Case 6: Single element
        let single_point_g1 = [points_g1[0]];
        let single_scalar = [scalars[0].clone()];
        let expected_single_g1 = naive_msm(&single_point_g1, &single_scalar);
        let result_single_g1 = G1::msm(&single_point_g1, &single_scalar, 1);
        assert_eq!(
            expected_single_g1, result_single_g1,
            "G1 MSM single element failed"
        );

        // Case 7: Empty input
        let empty_points_g1: [G1; 0] = [];
        let empty_scalars: [Scalar; 0] = [];
        let expected_empty_g1 = naive_msm(&empty_points_g1, &empty_scalars);
        let result_empty_g1 = G1::msm(&empty_points_g1, &empty_scalars, 1);
        assert_eq!(expected_empty_g1, G1::zero(), "G1 MSM empty (naive) failed");
        assert_eq!(result_empty_g1, G1::zero(), "G1 MSM empty failed");

        // Case 8: Random points and scalars (big)
        let points_g1: Vec<G1> = (0..50_000)
            .map(|_| G1::generator() * &Scalar::random(&mut rng))
            .collect();
        let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::random(&mut rng)).collect();
        let expected_g1 = naive_msm(&points_g1, &scalars);
        let result_g1 = G1::msm(&points_g1, &scalars, 1);
        assert_eq!(expected_g1, result_g1, "G1 MSM basic case failed");
    }

    #[test]
    fn test_g2_msm() {
        let mut rng = thread_rng();
        let n = 10; // Number of points/scalars

        // Case 1: Random points and scalars
        let points_g2: Vec<G2> = (0..n)
            .map(|_| G2::generator() * &Scalar::random(&mut rng))
            .collect();
        let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::random(&mut rng)).collect();
        let expected_g2 = naive_msm(&points_g2, &scalars);
        let result_g2 = G2::msm(&points_g2, &scalars, 1);
        assert_eq!(expected_g2, result_g2, "G2 MSM basic case failed");

        // Case 2: Include identity point
        let mut points_with_zero_g2 = points_g2.clone();
        points_with_zero_g2[n / 2] = G2::zero();
        let expected_zero_pt_g2 = naive_msm(&points_with_zero_g2, &scalars);
        let result_zero_pt_g2 = G2::msm(&points_with_zero_g2, &scalars, 1);
        assert_eq!(
            expected_zero_pt_g2, result_zero_pt_g2,
            "G2 MSM with identity point failed"
        );

        // Case 3: Include zero scalar
        let mut scalars_with_zero = scalars.clone();
        scalars_with_zero[n / 2] = Scalar::zero();
        let expected_zero_sc_g2 = naive_msm(&points_g2, &scalars_with_zero);
        let result_zero_sc_g2 = G2::msm(&points_g2, &scalars_with_zero, 1);
        assert_eq!(
            expected_zero_sc_g2, result_zero_sc_g2,
            "G2 MSM with zero scalar failed"
        );

        // Case 4: All points identity
        let zero_points_g2 = vec![G2::zero(); n];
        let expected_all_zero_pt_g2 = naive_msm(&zero_points_g2, &scalars);
        let result_all_zero_pt_g2 = G2::msm(&zero_points_g2, &scalars, 1);
        assert_eq!(
            expected_all_zero_pt_g2,
            G2::zero(),
            "G2 MSM all identity points (naive) failed"
        );
        assert_eq!(
            result_all_zero_pt_g2,
            G2::zero(),
            "G2 MSM all identity points failed"
        );

        // Case 5: All scalars zero
        let zero_scalars = vec![Scalar::zero(); n];
        let expected_all_zero_sc_g2 = naive_msm(&points_g2, &zero_scalars);
        let result_all_zero_sc_g2 = G2::msm(&points_g2, &zero_scalars, 1);
        assert_eq!(
            expected_all_zero_sc_g2,
            G2::zero(),
            "G2 MSM all zero scalars (naive) failed"
        );
        assert_eq!(
            result_all_zero_sc_g2,
            G2::zero(),
            "G2 MSM all zero scalars failed"
        );

        // Case 6: Single element
        let single_point_g2 = [points_g2[0]];
        let single_scalar = [scalars[0].clone()];
        let expected_single_g2 = naive_msm(&single_point_g2, &single_scalar);
        let result_single_g2 = G2::msm(&single_point_g2, &single_scalar, 1);
        assert_eq!(
            expected_single_g2, result_single_g2,
            "G2 MSM single element failed"
        );

        // Case 7: Empty input
        let empty_points_g2: [G2; 0] = [];
        let empty_scalars: [Scalar; 0] = [];
        let expected_empty_g2 = naive_msm(&empty_points_g2, &empty_scalars);
        let result_empty_g2 = G2::msm(&empty_points_g2, &empty_scalars, 1);
        assert_eq!(expected_empty_g2, G2::zero(), "G2 MSM empty (naive) failed");
        assert_eq!(result_empty_g2, G2::zero(), "G2 MSM empty failed");

        // Case 8: Random points and scalars (big)
        let points_g2: Vec<G2> = (0..50_000)
            .map(|_| G2::generator() * &Scalar::random(&mut rng))
            .collect();
        let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::random(&mut rng)).collect();
        let expected_g2 = naive_msm(&points_g2, &scalars);
        let result_g2 = G2::msm(&points_g2, &scalars, 1);
        assert_eq!(expected_g2, result_g2, "G2 MSM basic case failed");
    }

    // Exercises Ord/Hash/Eq for Scalar, G1, G2, and Share via BTreeSet/HashMap usage.
    #[test]
    fn test_trait_implementations() {
        // Generate a set of unique items to test.
        let mut rng = thread_rng();
        const NUM_ITEMS: usize = 10;
        let mut scalar_set = BTreeSet::new();
        let mut g1_set = BTreeSet::new();
        let mut g2_set = BTreeSet::new();
        let mut share_set = BTreeSet::new();
        while scalar_set.len() < NUM_ITEMS {
            let scalar = Scalar::random(&mut rng);
            let g1 = G1::generator() * &scalar;
            let g2 = G2::generator() * &scalar;
            let share = Share {
                index: scalar_set.len() as u32,
                private: scalar.clone(),
            };

            scalar_set.insert(scalar);
            g1_set.insert(g1);
            g2_set.insert(g2);
            share_set.insert(share);
        }

        // Verify that the sets contain the expected number of unique items.
        assert_eq!(scalar_set.len(), NUM_ITEMS);
        assert_eq!(g1_set.len(), NUM_ITEMS);
        assert_eq!(g2_set.len(), NUM_ITEMS);
        assert_eq!(share_set.len(), NUM_ITEMS);

        // Verify that `BTreeSet` iteration is sorted, which relies on `Ord`.
        let scalars: Vec<_> = scalar_set.iter().collect();
        assert!(scalars.windows(2).all(|w| w[0] <= w[1]));
        let g1s: Vec<_> = g1_set.iter().collect();
        assert!(g1s.windows(2).all(|w| w[0] <= w[1]));
        let g2s: Vec<_> = g2_set.iter().collect();
        assert!(g2s.windows(2).all(|w| w[0] <= w[1]));
        let shares: Vec<_> = share_set.iter().collect();
        assert!(shares.windows(2).all(|w| w[0] <= w[1]));

        // Test that we can use these types as keys in hash maps, which relies on `Hash` and `Eq`.
        let scalar_map: HashMap<_, _> = scalar_set.iter().cloned().zip(0..).collect();
        let g1_map: HashMap<_, _> = g1_set.iter().cloned().zip(0..).collect();
        let g2_map: HashMap<_, _> = g2_set.iter().cloned().zip(0..).collect();
        let share_map: HashMap<_, _> = share_set.iter().cloned().zip(0..).collect();

        // Verify that the maps contain the expected number of unique items.
        assert_eq!(scalar_map.len(), NUM_ITEMS);
        assert_eq!(g1_map.len(), NUM_ITEMS);
        assert_eq!(g2_map.len(), NUM_ITEMS);
        assert_eq!(share_map.len(), NUM_ITEMS);
    }

    // Checks determinism, input sensitivity, and domain separation of Scalar::map.
    #[test]
    fn test_scalar_map() {
        // Test 1: Basic functionality
        let msg = b"test message";
        let dst = b"TEST_DST";
        let scalar1 = Scalar::map(dst, msg);
        let scalar2 = Scalar::map(dst, msg);
        assert_eq!(scalar1, scalar2, "Same input should produce same output");

        // Test 2: Different messages produce different scalars
        let msg2 = b"different message";
        let scalar3 = Scalar::map(dst, msg2);
        assert_ne!(
            scalar1, scalar3,
            "Different messages should produce different scalars"
        );

        // Test 3: Different DSTs produce different scalars
        let dst2 = b"DIFFERENT_DST";
        let scalar4 = Scalar::map(dst2, msg);
        assert_ne!(
            scalar1, scalar4,
            "Different DSTs should produce different scalars"
        );

        // Test 4: Empty message
        let empty_msg = b"";
        let scalar_empty = Scalar::map(dst, empty_msg);
        assert_ne!(
            scalar_empty,
            Scalar::zero(),
            "Empty message should not produce zero"
        );

        // Test 5: Large message
        let large_msg = vec![0x42u8; 1000];
        let scalar_large = Scalar::map(dst, &large_msg);
        assert_ne!(
            scalar_large,
            Scalar::zero(),
            "Large message should not produce zero"
        );

        // Test 6: Verify the scalar is valid (not zero)
        assert_ne!(
            scalar1,
            Scalar::zero(),
            "Hash should not produce zero scalar"
        );
    }

    // Codec conformance tests, only built when the `arbitrary` feature is enabled.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<G1>,
            CodecConformance<G2>,
            CodecConformance<Scalar>,
            CodecConformance<Share>
        }
    }
}