commonware_cryptography/bls12381/primitives/
group.rs

//! Group operations over the BLS12-381 scalar field.
//!
//! This module implements basic group operations over BLS12-381 elements,
//! including point addition, scalar multiplication, and pairing operations.
//!
//! # Warning
//!
//! Ensure that points are checked to belong to the correct subgroup
//! (G1 or G2) to prevent small subgroup attacks. This is particularly important
//! when handling deserialized points or points received from untrusted sources. This
//! is already taken care of for you when points are decoded through this module's
//! codec implementations (`Read` for `G1`/`G2`), which reject the point at infinity
//! and any point outside the expected subgroup.
12
13use super::variant::Variant;
14use blst::{
15    blst_bendian_from_scalar, blst_fp12, blst_fr, blst_fr_add, blst_fr_from_scalar,
16    blst_fr_from_uint64, blst_fr_inverse, blst_fr_mul, blst_fr_sub, blst_hash_to_g1,
17    blst_hash_to_g2, blst_keygen, blst_p1, blst_p1_add_or_double, blst_p1_affine, blst_p1_compress,
18    blst_p1_from_affine, blst_p1_in_g1, blst_p1_is_inf, blst_p1_mult, blst_p1_to_affine,
19    blst_p1_uncompress, blst_p1s_mult_pippenger, blst_p1s_mult_pippenger_scratch_sizeof, blst_p2,
20    blst_p2_add_or_double, blst_p2_affine, blst_p2_compress, blst_p2_from_affine, blst_p2_in_g2,
21    blst_p2_is_inf, blst_p2_mult, blst_p2_to_affine, blst_p2_uncompress, blst_p2s_mult_pippenger,
22    blst_p2s_mult_pippenger_scratch_sizeof, blst_scalar, blst_scalar_from_bendian,
23    blst_scalar_from_fr, blst_sk_check, BLS12_381_G1, BLS12_381_G2, BLST_ERROR,
24};
25use bytes::{Buf, BufMut};
26use commonware_codec::{
27    varint::UInt,
28    EncodeSize,
29    Error::{self, Invalid},
30    FixedSize, Read, ReadExt, Write,
31};
32use commonware_utils::hex;
33use rand::RngCore;
34use std::{
35    fmt::{Debug, Display},
36    hash::{Hash, Hasher},
37    mem::MaybeUninit,
38    ptr,
39};
40use zeroize::{Zeroize, ZeroizeOnDrop};
41
/// Domain separation tag used when hashing a message to a curve (G1 or G2).
///
/// Tags are `'static` byte strings; the ones defined in this module are all
/// byte-string literals.
///
/// Reference: <https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-bls-signature-05#name-ciphersuites>
pub type DST = &'static [u8];
46
/// An element of a group.
///
/// Implemented in this module by [`Scalar`] (the scalar field) and by the curve
/// points [`G1`] and [`G2`]. Every element has a fixed-size encoding and can be
/// decoded without any configuration (`Read<Cfg = ()>`).
pub trait Element:
    Read<Cfg = ()> + Write + FixedSize + Clone + Eq + PartialEq + Send + Sync
{
    /// Returns the additive identity.
    fn zero() -> Self;

    /// Returns the "one" element.
    ///
    /// For [`Scalar`] this is the multiplicative identity of the field. For the
    /// curve points [`G1`] and [`G2`] it is the point built from `blst`'s
    /// `BLS12_381_G1` / `BLS12_381_G2` constants (the group generators).
    fn one() -> Self;

    /// Adds `rhs` to self in-place.
    fn add(&mut self, rhs: &Self);

    /// Multiplies self by the scalar `rhs` in-place.
    fn mul(&mut self, rhs: &Scalar);
}
63
/// A point on a curve.
pub trait Point: Element {
    /// Maps the provided data to a group element, storing the result in `self`.
    ///
    /// `dst` is the domain separation tag to hash under and `message` is the
    /// data being hashed.
    fn map(&mut self, dst: DST, message: &[u8]);

    /// Performs a multi‑scalar multiplication of the provided points and scalars,
    /// i.e. computes `sum(scalars[i] * points[i])`.
    ///
    /// The implementations in this module panic if the two slices differ in length.
    fn msm(points: &[Self], scalars: &[Scalar]) -> Self;
}
72
/// Wrapper around [`blst_fr`] that represents an element of the BLS12‑381
/// scalar field `F_r`.
///
/// The new‑type is marked `#[repr(transparent)]`, so it has exactly the same
/// memory layout as the underlying `blst_fr`, allowing safe passage across
/// the C FFI boundary without additional transmutation.
///
/// All arithmetic is performed modulo the prime
/// `r = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001`,
/// the order of the BLS12‑381 G1/G2 groups.
///
/// This type doubles as the private key type (see [`Private`]); its limbs are
/// zeroized when a value is dropped.
#[derive(Clone, Eq, PartialEq)]
#[repr(transparent)]
pub struct Scalar(blst_fr);
86
/// Number of bytes required to encode a scalar in its canonical
/// big‑endian form (`32 × 8 = 256 bits`).
///
/// Because `r` is only 255 bits wide (its leading byte is `0x73`), the
/// most‑significant (first) byte of a canonical encoding is always in the range
/// `0x00‥=0x73`, leaving the top bit clear.
const SCALAR_LENGTH: usize = 32;
93
/// Effective bit‑length of the field modulus `r` (`⌈log_2 r⌉ = 255`).
///
/// Passed as the scalar bit count to `blst`'s scalar multiplication and
/// Pippenger routines so that they always process the full scalar width,
/// regardless of the scalar's value.
const SCALAR_BITS: usize = 255;
99
/// This constant serves as the multiplicative identity (i.e., "one") in the
/// BLS12-381 scalar field, ensuring that arithmetic is carried out within the
/// correct modulus.
///
/// `R = 2^256 mod r` in little-endian Montgomery form which is equivalent to 1 in little-endian
/// non-Montgomery form:
///
/// ```txt
/// mod(2^256, 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001) = 0x1824b159acc5056f998c4fefecbc4ff55884b7fa0003480200000001fffffffe
/// ```
///
/// The four 64-bit limbs below are that value listed least-significant limb first.
///
/// Reference: <https://github.com/filecoin-project/blstrs/blob/ffbb41d1495d84e40a712583346439924603b49a/src/scalar.rs#L77-L89>
const BLST_FR_ONE: Scalar = Scalar(blst_fr {
    l: [
        0x0000_0001_ffff_fffe,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ],
});
120
/// A point on the BLS12-381 G1 curve.
///
/// `#[repr(transparent)]` wrapper around `blst_p1`. Encoded in compressed form,
/// occupying [`G1_ELEMENT_BYTE_LENGTH`] bytes.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G1(blst_p1);

/// The size in bytes of an encoded (compressed) G1 element.
pub const G1_ELEMENT_BYTE_LENGTH: usize = 48;

/// Domain separation tag for hashing a proof of possession (compressed G2) to G1.
pub const G1_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";

/// Domain separation tag for hashing a message to G1.
///
/// We use the `POP` scheme for hashing all messages because this crate is expected to be
/// used in a Byzantine environment (where any player may attempt a rogue key attack) and
/// any message could be aggregated into a multi-signature (which requires a proof-of-possession
/// to be safely deployed in this environment).
pub const G1_MESSAGE: DST = b"BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";
139
/// A point on the BLS12-381 G2 curve.
///
/// `#[repr(transparent)]` wrapper around `blst_p2`. Encoded in compressed form,
/// occupying [`G2_ELEMENT_BYTE_LENGTH`] bytes.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G2(blst_p2);

/// The size in bytes of an encoded (compressed) G2 element.
pub const G2_ELEMENT_BYTE_LENGTH: usize = 96;

/// Domain separation tag for hashing a proof of possession (compressed G1) to G2.
pub const G2_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";

/// Domain separation tag for hashing a message to G2.
///
/// We use the `POP` scheme for hashing all messages because this crate is expected to be
/// used in a Byzantine environment (where any player may attempt a rogue key attack) and
/// any message could be aggregated into a multi-signature (which requires a proof-of-possession
/// to be safely deployed in this environment).
pub const G2_MESSAGE: DST = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";
158
/// The target group of the BLS12-381 pairing.
///
/// This is an element in the extension field `F_p^12` and is
/// produced as the result of a pairing operation.
///
/// Wraps `blst`'s `blst_fp12`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct GT(blst_fp12);
165
/// The private key type.
///
/// A private key is simply a [`Scalar`].
pub type Private = Scalar;

/// The private key length in bytes (identical to the encoded scalar length).
pub const PRIVATE_KEY_LENGTH: usize = SCALAR_LENGTH;
171
172impl Scalar {
173    /// Generates a random scalar using the provided RNG.
174    pub fn rand<R: RngCore>(rng: &mut R) -> Self {
175        // Generate a random 64 byte buffer
176        let mut ikm = [0u8; 64];
177        rng.fill_bytes(&mut ikm);
178
179        // Generate a scalar from the randomly populated buffer
180        let mut ret = blst_fr::default();
181        unsafe {
182            let mut sc = blst_scalar::default();
183            blst_keygen(&mut sc, ikm.as_ptr(), ikm.len(), ptr::null(), 0);
184            blst_fr_from_scalar(&mut ret, &sc);
185        }
186
187        // Zeroize the ikm buffer
188        ikm.zeroize();
189        Self(ret)
190    }
191
192    /// Sets the scalar to be the provided integer.
193    pub fn set_int(&mut self, i: u32) {
194        // blst requires a buffer of 4 uint64 values. Failure to provide one will
195        // result in unexpected behavior (will read past the provided buffer).
196        //
197        // Reference: https://github.com/supranational/blst/blob/415d4f0e2347a794091836a3065206edfd9c72f3/bindings/blst.h#L102
198        let buffer = [i as u64, 0, 0, 0];
199        unsafe { blst_fr_from_uint64(&mut self.0, buffer.as_ptr()) };
200    }
201
202    /// Computes the inverse of the scalar.
203    pub fn inverse(&self) -> Option<Self> {
204        if *self == Self::zero() {
205            return None;
206        }
207        let mut ret = blst_fr::default();
208        unsafe { blst_fr_inverse(&mut ret, &self.0) };
209        Some(Self(ret))
210    }
211
212    /// Subtracts the provided scalar from self in-place.
213    pub fn sub(&mut self, rhs: &Self) {
214        unsafe { blst_fr_sub(&mut self.0, &self.0, &rhs.0) }
215    }
216
217    /// Encodes the scalar into a slice.
218    fn as_slice(&self) -> [u8; Self::SIZE] {
219        let mut slice = [0u8; Self::SIZE];
220        unsafe {
221            let mut scalar = blst_scalar::default();
222            blst_scalar_from_fr(&mut scalar, &self.0);
223            blst_bendian_from_scalar(slice.as_mut_ptr(), &scalar);
224        }
225        slice
226    }
227
228    /// Converts the scalar to the raw `blst_scalar` type.
229    pub(crate) fn as_blst_scalar(&self) -> blst_scalar {
230        let mut scalar = blst_scalar::default();
231        unsafe { blst_scalar_from_fr(&mut scalar, &self.0) };
232        scalar
233    }
234}
235
236impl Element for Scalar {
237    fn zero() -> Self {
238        Self(blst_fr::default())
239    }
240
241    fn one() -> Self {
242        BLST_FR_ONE
243    }
244
245    fn add(&mut self, rhs: &Self) {
246        unsafe {
247            blst_fr_add(&mut self.0, &self.0, &rhs.0);
248        }
249    }
250
251    fn mul(&mut self, rhs: &Self) {
252        unsafe {
253            blst_fr_mul(&mut self.0, &self.0, &rhs.0);
254        }
255    }
256}
257
258impl Write for Scalar {
259    fn write(&self, buf: &mut impl BufMut) {
260        let slice = self.as_slice();
261        buf.put_slice(&slice);
262    }
263}
264
265impl Read for Scalar {
266    type Cfg = ();
267
268    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
269        let bytes = <[u8; Self::SIZE]>::read(buf)?;
270        let mut ret = blst_fr::default();
271        unsafe {
272            let mut scalar = blst_scalar::default();
273            blst_scalar_from_bendian(&mut scalar, bytes.as_ptr());
274            // We use `blst_sk_check` instead of `blst_scalar_fr_check` because the former
275            // performs a non-zero check.
276            //
277            // The IETF BLS12-381 specification allows for zero scalars up to (inclusive) Draft 3
278            // but disallows them after.
279            //
280            // References:
281            // * https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-bls-signature-03#section-2.3
282            // * https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-bls-signature-04#section-2.3
283            if !blst_sk_check(&scalar) {
284                return Err(Invalid("Scalar", "Invalid"));
285            }
286            blst_fr_from_scalar(&mut ret, &scalar);
287        }
288        Ok(Self(ret))
289    }
290}
291
impl FixedSize for Scalar {
    /// Every scalar encodes to exactly `SCALAR_LENGTH` bytes.
    const SIZE: usize = SCALAR_LENGTH;
}
295
296impl Hash for Scalar {
297    fn hash<H: Hasher>(&self, state: &mut H) {
298        let slice = self.as_slice();
299        state.write(&slice);
300    }
301}
302
303impl Debug for Scalar {
304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305        write!(f, "{}", hex(&self.as_slice()))
306    }
307}
308
309impl Display for Scalar {
310    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        write!(f, "{}", hex(&self.as_slice()))
312    }
313}
314
315impl Zeroize for Scalar {
316    fn zeroize(&mut self) {
317        self.0.l.zeroize();
318    }
319}
320
321impl Drop for Scalar {
322    fn drop(&mut self) {
323        self.zeroize();
324    }
325}
326
// Marker: `Scalar`'s `Drop` implementation zeroizes the value.
impl ZeroizeOnDrop for Scalar {}
328
/// A share of a threshold signing key.
///
/// Encoded as the varint-encoded `index` followed by the fixed-size `private` scalar.
#[derive(Clone, PartialEq, Hash)]
pub struct Share {
    /// The share's index in the polynomial.
    pub index: u32,
    /// The scalar corresponding to the share's secret.
    pub private: Private,
}
337
338impl Share {
339    /// Returns the public key corresponding to the share.
340    ///
341    /// This can be verified against the public polynomial.
342    pub fn public<V: Variant>(&self) -> V::Public {
343        let mut public = V::Public::one();
344        public.mul(&self.private);
345        public
346    }
347}
348
349impl Write for Share {
350    fn write(&self, buf: &mut impl BufMut) {
351        UInt(self.index).write(buf);
352        self.private.write(buf);
353    }
354}
355
356impl Read for Share {
357    type Cfg = ();
358
359    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
360        let index = UInt::read(buf)?.into();
361        let private = Private::read(buf)?;
362        Ok(Self { index, private })
363    }
364}
365
366impl EncodeSize for Share {
367    fn encode_size(&self) -> usize {
368        UInt(self.index).encode_size() + self.private.encode_size()
369    }
370}
371
372impl Display for Share {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        write!(f, "Share(index={}, private={})", self.index, self.private)
375    }
376}
377
378impl Debug for Share {
379    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        write!(f, "Share(index={}, private={})", self.index, self.private)
381    }
382}
383
384impl G1 {
385    /// Encodes the G1 element into a slice.
386    fn as_slice(&self) -> [u8; Self::SIZE] {
387        let mut slice = [0u8; Self::SIZE];
388        unsafe {
389            blst_p1_compress(slice.as_mut_ptr(), &self.0);
390        }
391        slice
392    }
393
394    /// Converts the G1 point to its affine representation.
395    pub(crate) fn as_blst_p1_affine(&self) -> blst_p1_affine {
396        let mut affine = blst_p1_affine::default();
397        unsafe { blst_p1_to_affine(&mut affine, &self.0) };
398        affine
399    }
400
401    /// Creates a G1 point from a raw `blst_p1`.
402    pub(crate) fn from_blst_p1(p: blst_p1) -> Self {
403        Self(p)
404    }
405}
406
407impl Element for G1 {
408    fn zero() -> Self {
409        Self(blst_p1::default())
410    }
411
412    fn one() -> Self {
413        let mut ret = blst_p1::default();
414        unsafe {
415            blst_p1_from_affine(&mut ret, &BLS12_381_G1);
416        }
417        Self(ret)
418    }
419
420    fn add(&mut self, rhs: &Self) {
421        unsafe {
422            blst_p1_add_or_double(&mut self.0, &self.0, &rhs.0);
423        }
424    }
425
426    fn mul(&mut self, rhs: &Scalar) {
427        let mut scalar: blst_scalar = blst_scalar::default();
428        unsafe {
429            blst_scalar_from_fr(&mut scalar, &rhs.0);
430            // To avoid a timing attack during signing, we always perform the same
431            // number of iterations during scalar multiplication.
432            blst_p1_mult(&mut self.0, &self.0, scalar.b.as_ptr(), SCALAR_BITS);
433        }
434    }
435}
436
437impl Write for G1 {
438    fn write(&self, buf: &mut impl BufMut) {
439        let slice = self.as_slice();
440        buf.put_slice(&slice);
441    }
442}
443
444impl Read for G1 {
445    type Cfg = ();
446
447    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
448        let bytes = <[u8; Self::SIZE]>::read(buf)?;
449        let mut ret = blst_p1::default();
450        unsafe {
451            let mut affine = blst_p1_affine::default();
452            match blst_p1_uncompress(&mut affine, bytes.as_ptr()) {
453                BLST_ERROR::BLST_SUCCESS => {}
454                BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G1", "Bad encoding")),
455                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G1", "Not on curve")),
456                BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G1", "Not in group")),
457                BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G1", "Type mismatch")),
458                BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G1", "Verify fail")),
459                BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G1", "PK is Infinity")),
460                BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G1", "Bad scalar")),
461            }
462            blst_p1_from_affine(&mut ret, &affine);
463
464            // Verify that deserialized element isn't infinite
465            if blst_p1_is_inf(&ret) {
466                return Err(Invalid("G1", "Infinity"));
467            }
468
469            // Verify that the deserialized element is in G1
470            if !blst_p1_in_g1(&ret) {
471                return Err(Invalid("G1", "Outside G1"));
472            }
473        }
474        Ok(Self(ret))
475    }
476}
477
impl FixedSize for G1 {
    /// Every G1 point encodes to exactly `G1_ELEMENT_BYTE_LENGTH` bytes.
    const SIZE: usize = G1_ELEMENT_BYTE_LENGTH;
}
481
482impl Hash for G1 {
483    fn hash<H: Hasher>(&self, state: &mut H) {
484        let slice = self.as_slice();
485        state.write(&slice);
486    }
487}
488
489impl Point for G1 {
490    fn map(&mut self, dst: DST, data: &[u8]) {
491        unsafe {
492            blst_hash_to_g1(
493                &mut self.0,
494                data.as_ptr(),
495                data.len(),
496                dst.as_ptr(),
497                dst.len(),
498                ptr::null(),
499                0,
500            );
501        }
502    }
503
504    /// Performs multi-scalar multiplication (MSM) on G1 points using Pippenger's algorithm.
505    /// Computes `sum(scalars[i] * points[i])`.
506    ///
507    /// Filters out pairs where the point is the identity element (infinity).
508    /// Returns an error if the lengths of the input slices mismatch.
509    fn msm(points: &[Self], scalars: &[Scalar]) -> Self {
510        // Assert input validity
511        assert_eq!(points.len(), scalars.len(), "mismatched lengths");
512
513        // Prepare points (affine) and scalars (raw blst_scalar)
514        let mut points_filtered = Vec::with_capacity(points.len());
515        let mut scalars_filtered = Vec::with_capacity(scalars.len());
516        for (point, scalar) in points.iter().zip(scalars.iter()) {
517            // `blst` does not filter out infinity, so we must ensure it is impossible.
518            //
519            // Sources:
520            // * https://github.com/supranational/blst/blob/cbc7e166a10d7286b91a3a7bea341e708962db13/src/multi_scalar.c#L10-L12
521            // * https://github.com/MystenLabs/fastcrypto/blob/0acf0ff1a163c60e0dec1e16e4fbad4a4cf853bd/fastcrypto/src/groups/bls12381.rs#L160-L194
522            if *point == G1::zero() || scalar == &Scalar::zero() {
523                continue;
524            }
525
526            // Add to filtered vectors
527            points_filtered.push(point.as_blst_p1_affine());
528            scalars_filtered.push(scalar.as_blst_scalar());
529        }
530
531        // If all points were filtered, return zero.
532        if points_filtered.is_empty() {
533            return G1::zero();
534        }
535
536        // Create vectors of pointers for the blst API.
537        // These vectors hold pointers *to* the elements in the filtered vectors above.
538        let points: Vec<*const blst_p1_affine> =
539            points_filtered.iter().map(|p| p as *const _).collect();
540        let scalars: Vec<*const u8> = scalars_filtered.iter().map(|s| s.b.as_ptr()).collect();
541
542        // Allocate scratch space for Pippenger's algorithm.
543        let scratch_size = unsafe { blst_p1s_mult_pippenger_scratch_sizeof(points.len()) };
544        let mut scratch = vec![MaybeUninit::<u64>::uninit(); scratch_size / 8];
545
546        // Perform multi-scalar multiplication
547        let mut msm_result = blst_p1::default();
548        unsafe {
549            blst_p1s_mult_pippenger(
550                &mut msm_result,
551                points.as_ptr(),
552                points.len(),
553                scalars.as_ptr(),
554                SCALAR_BITS, // Using SCALAR_BITS (255) ensures full scalar range
555                scratch.as_mut_ptr() as *mut _,
556            );
557        }
558
559        G1::from_blst_p1(msm_result)
560    }
561}
562
563impl Debug for G1 {
564    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
565        write!(f, "{}", hex(&self.as_slice()))
566    }
567}
568
569impl Display for G1 {
570    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
571        write!(f, "{}", hex(&self.as_slice()))
572    }
573}
574
/// Identity conversion, so that APIs taking `impl AsRef<G1>` accept a `G1` directly.
impl AsRef<G1> for G1 {
    fn as_ref(&self) -> &Self {
        self
    }
}
580
581impl G2 {
582    /// Encodes the G2 element into a slice.
583    fn as_slice(&self) -> [u8; Self::SIZE] {
584        let mut slice = [0u8; Self::SIZE];
585        unsafe {
586            blst_p2_compress(slice.as_mut_ptr(), &self.0);
587        }
588        slice
589    }
590
591    /// Converts the G2 point to its affine representation.
592    pub(crate) fn as_blst_p2_affine(&self) -> blst_p2_affine {
593        let mut affine = blst_p2_affine::default();
594        unsafe { blst_p2_to_affine(&mut affine, &self.0) };
595        affine
596    }
597
598    /// Creates a G2 point from a raw `blst_p2`.
599    pub(crate) fn from_blst_p2(p: blst_p2) -> Self {
600        Self(p)
601    }
602}
603
604impl Element for G2 {
605    fn zero() -> Self {
606        Self(blst_p2::default())
607    }
608
609    fn one() -> Self {
610        let mut ret = blst_p2::default();
611        unsafe {
612            blst_p2_from_affine(&mut ret, &BLS12_381_G2);
613        }
614        Self(ret)
615    }
616
617    fn add(&mut self, rhs: &Self) {
618        unsafe {
619            blst_p2_add_or_double(&mut self.0, &self.0, &rhs.0);
620        }
621    }
622
623    fn mul(&mut self, rhs: &Scalar) {
624        let mut scalar = blst_scalar::default();
625        unsafe {
626            blst_scalar_from_fr(&mut scalar, &rhs.0);
627            // To avoid a timing attack during signing, we always perform the same
628            // number of iterations during scalar multiplication.
629            blst_p2_mult(&mut self.0, &self.0, scalar.b.as_ptr(), SCALAR_BITS);
630        }
631    }
632}
633
634impl Write for G2 {
635    fn write(&self, buf: &mut impl BufMut) {
636        let slice = self.as_slice();
637        buf.put_slice(&slice);
638    }
639}
640
641impl Read for G2 {
642    type Cfg = ();
643
644    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
645        let bytes = <[u8; Self::SIZE]>::read(buf)?;
646        let mut ret = blst_p2::default();
647        unsafe {
648            let mut affine = blst_p2_affine::default();
649            match blst_p2_uncompress(&mut affine, bytes.as_ptr()) {
650                BLST_ERROR::BLST_SUCCESS => {}
651                BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G2", "Bad encoding")),
652                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G2", "Not on curve")),
653                BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G2", "Not in group")),
654                BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G2", "Type mismatch")),
655                BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G2", "Verify fail")),
656                BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G2", "PK is Infinity")),
657                BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G2", "Bad scalar")),
658            }
659            blst_p2_from_affine(&mut ret, &affine);
660
661            // Verify that deserialized element isn't infinite
662            if blst_p2_is_inf(&ret) {
663                return Err(Invalid("G2", "Infinity"));
664            }
665
666            // Verify that the deserialized element is in G2
667            if !blst_p2_in_g2(&ret) {
668                return Err(Invalid("G2", "Outside G2"));
669            }
670        }
671        Ok(Self(ret))
672    }
673}
674
impl FixedSize for G2 {
    /// Every G2 point encodes to exactly `G2_ELEMENT_BYTE_LENGTH` bytes.
    const SIZE: usize = G2_ELEMENT_BYTE_LENGTH;
}
678
679impl Hash for G2 {
680    fn hash<H: Hasher>(&self, state: &mut H) {
681        let slice = self.as_slice();
682        state.write(&slice);
683    }
684}
685
686impl Point for G2 {
687    fn map(&mut self, dst: DST, data: &[u8]) {
688        unsafe {
689            blst_hash_to_g2(
690                &mut self.0,
691                data.as_ptr(),
692                data.len(),
693                dst.as_ptr(),
694                dst.len(),
695                ptr::null(),
696                0,
697            );
698        }
699    }
700
701    /// Performs multi-scalar multiplication (MSM) on G2 points using Pippenger's algorithm.
702    /// Computes `sum(scalars[i] * points[i])`.
703    ///
704    /// Filters out pairs where the point is the identity element (infinity).
705    /// Returns an error if the lengths of the input slices mismatch.
706    fn msm(points: &[Self], scalars: &[Scalar]) -> Self {
707        // Assert input validity
708        assert_eq!(points.len(), scalars.len(), "mismatched lengths");
709
710        // Prepare points (affine) and scalars (raw blst_scalar), filtering identity points
711        let mut points_filtered = Vec::with_capacity(points.len());
712        let mut scalars_filtered = Vec::with_capacity(scalars.len());
713        for (point, scalar) in points.iter().zip(scalars.iter()) {
714            // `blst` does not filter out infinity, so we must ensure it is impossible.
715            //
716            // Sources:
717            // * https://github.com/supranational/blst/blob/cbc7e166a10d7286b91a3a7bea341e708962db13/src/multi_scalar.c#L10-L12
718            // * https://github.com/MystenLabs/fastcrypto/blob/0acf0ff1a163c60e0dec1e16e4fbad4a4cf853bd/fastcrypto/src/groups/bls12381.rs#L160-L194
719            if *point == G2::zero() || scalar == &Scalar::zero() {
720                continue;
721            }
722            points_filtered.push(point.as_blst_p2_affine());
723            scalars_filtered.push(scalar.as_blst_scalar());
724        }
725
726        // If all points were filtered, return zero.
727        if points_filtered.is_empty() {
728            return G2::zero();
729        }
730
731        // Create vectors of pointers for the blst API
732        let points: Vec<*const blst_p2_affine> =
733            points_filtered.iter().map(|p| p as *const _).collect();
734        let scalars: Vec<*const u8> = scalars_filtered.iter().map(|s| s.b.as_ptr()).collect();
735
736        // Allocate scratch space for Pippenger algorithm
737        let scratch_size = unsafe { blst_p2s_mult_pippenger_scratch_sizeof(points.len()) };
738        let mut scratch = vec![MaybeUninit::<u64>::uninit(); scratch_size / 8];
739
740        // Perform multi-scalar multiplication
741        let mut msm_result = blst_p2::default();
742        unsafe {
743            blst_p2s_mult_pippenger(
744                &mut msm_result,
745                points.as_ptr(),
746                points.len(),
747                scalars.as_ptr(),
748                SCALAR_BITS, // Using SCALAR_BITS (255) ensures full scalar range
749                scratch.as_mut_ptr() as *mut _,
750            );
751        }
752
753        G2::from_blst_p2(msm_result)
754    }
755}
756
757impl Debug for G2 {
758    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
759        write!(f, "{}", hex(&self.as_slice()))
760    }
761}
762
763impl Display for G2 {
764    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
765        write!(f, "{}", hex(&self.as_slice()))
766    }
767}
768
/// Identity conversion, so that APIs taking `impl AsRef<G2>` accept a `G2` directly.
impl AsRef<G2> for G2 {
    fn as_ref(&self) -> &Self {
        self
    }
}
774
775#[cfg(test)]
776mod tests {
777    use super::*;
778    use commonware_codec::{DecodeExt, Encode};
779    use rand::prelude::*;
780
781    #[test]
782    fn basic_group() {
783        // Reference: https://github.com/celo-org/celo-threshold-bls-rs/blob/b0ef82ff79769d085a5a7d3f4fe690b1c8fe6dc9/crates/threshold-bls/src/curve/bls12381.rs#L200-L220
784        let s = Scalar::rand(&mut thread_rng());
785        let mut e1 = s.clone();
786        let e2 = s.clone();
787        let mut s2 = s.clone();
788        s2.add(&s);
789        s2.mul(&s);
790        e1.add(&e2);
791        e1.mul(&e2);
792
793        // p1 = s2 * G = (s+s)G
794        let mut p1 = G1::zero();
795        p1.mul(&s2);
796
797        // p2 = sG + sG = s2 * G
798        let mut p2 = G1::zero();
799        p2.mul(&s);
800        p2.add(&p2.clone());
801        assert_eq!(p1, p2);
802    }
803
804    #[test]
805    fn test_scalar_codec() {
806        let original = Scalar::rand(&mut thread_rng());
807        let mut encoded = original.encode();
808        assert_eq!(encoded.len(), Scalar::SIZE);
809        let decoded = Scalar::decode(&mut encoded).unwrap();
810        assert_eq!(original, decoded);
811    }
812
813    #[test]
814    fn test_g1_codec() {
815        let mut original = G1::one();
816        original.mul(&Scalar::rand(&mut thread_rng()));
817        let mut encoded = original.encode();
818        assert_eq!(encoded.len(), G1::SIZE);
819        let decoded = G1::decode(&mut encoded).unwrap();
820        assert_eq!(original, decoded);
821    }
822
823    #[test]
824    fn test_g2_codec() {
825        let mut original = G2::one();
826        original.mul(&Scalar::rand(&mut thread_rng()));
827        let mut encoded = original.encode();
828        assert_eq!(encoded.len(), G2::SIZE);
829        let decoded = G2::decode(&mut encoded).unwrap();
830        assert_eq!(original, decoded);
831    }
832
833    /// Naive calculation of Multi-Scalar Multiplication: sum(scalar * point)
834    fn naive_msm<P: Point>(points: &[P], scalars: &[Scalar]) -> P {
835        assert_eq!(points.len(), scalars.len());
836        let mut total = P::zero();
837        for (point, scalar) in points.iter().zip(scalars.iter()) {
838            // Skip identity points or zero scalars, similar to the optimized MSM
839            if *point == P::zero() || *scalar == Scalar::zero() {
840                continue;
841            }
842            let mut term = point.clone();
843            term.mul(scalar);
844            total.add(&term);
845        }
846        total
847    }
848
849    #[test]
850    fn test_g1_msm() {
851        let mut rng = thread_rng();
852        let n = 10; // Number of points/scalars
853
854        // Case 1: Random points and scalars
855        let points_g1: Vec<G1> = (0..n)
856            .map(|_| {
857                let mut point = G1::one();
858                point.mul(&Scalar::rand(&mut rng));
859                point
860            })
861            .collect();
862        let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::rand(&mut rng)).collect();
863        let expected_g1 = naive_msm(&points_g1, &scalars);
864        let result_g1 = G1::msm(&points_g1, &scalars);
865        assert_eq!(expected_g1, result_g1, "G1 MSM basic case failed");
866
867        // Case 2: Include identity point
868        let mut points_with_zero_g1 = points_g1.clone();
869        points_with_zero_g1[n / 2] = G1::zero();
870        let expected_zero_pt_g1 = naive_msm(&points_with_zero_g1, &scalars);
871        let result_zero_pt_g1 = G1::msm(&points_with_zero_g1, &scalars);
872        assert_eq!(
873            expected_zero_pt_g1, result_zero_pt_g1,
874            "G1 MSM with identity point failed"
875        );
876
877        // Case 3: Include zero scalar
878        let mut scalars_with_zero = scalars.clone();
879        scalars_with_zero[n / 2] = Scalar::zero();
880        let expected_zero_sc_g1 = naive_msm(&points_g1, &scalars_with_zero);
881        let result_zero_sc_g1 = G1::msm(&points_g1, &scalars_with_zero);
882        assert_eq!(
883            expected_zero_sc_g1, result_zero_sc_g1,
884            "G1 MSM with zero scalar failed"
885        );
886
887        // Case 4: All points identity
888        let zero_points_g1 = vec![G1::zero(); n];
889        let expected_all_zero_pt_g1 = naive_msm(&zero_points_g1, &scalars);
890        let result_all_zero_pt_g1 = G1::msm(&zero_points_g1, &scalars);
891        assert_eq!(
892            expected_all_zero_pt_g1,
893            G1::zero(),
894            "G1 MSM all identity points (naive) failed"
895        );
896        assert_eq!(
897            result_all_zero_pt_g1,
898            G1::zero(),
899            "G1 MSM all identity points failed"
900        );
901
902        // Case 5: All scalars zero
903        let zero_scalars = vec![Scalar::zero(); n];
904        let expected_all_zero_sc_g1 = naive_msm(&points_g1, &zero_scalars);
905        let result_all_zero_sc_g1 = G1::msm(&points_g1, &zero_scalars);
906        assert_eq!(
907            expected_all_zero_sc_g1,
908            G1::zero(),
909            "G1 MSM all zero scalars (naive) failed"
910        );
911        assert_eq!(
912            result_all_zero_sc_g1,
913            G1::zero(),
914            "G1 MSM all zero scalars failed"
915        );
916
917        // Case 6: Single element
918        let single_point_g1 = [points_g1[0]];
919        let single_scalar = [scalars[0].clone()];
920        let expected_single_g1 = naive_msm(&single_point_g1, &single_scalar);
921        let result_single_g1 = G1::msm(&single_point_g1, &single_scalar);
922        assert_eq!(
923            expected_single_g1, result_single_g1,
924            "G1 MSM single element failed"
925        );
926
927        // Case 7: Empty input
928        let empty_points_g1: [G1; 0] = [];
929        let empty_scalars: [Scalar; 0] = [];
930        let expected_empty_g1 = naive_msm(&empty_points_g1, &empty_scalars);
931        let result_empty_g1 = G1::msm(&empty_points_g1, &empty_scalars);
932        assert_eq!(expected_empty_g1, G1::zero(), "G1 MSM empty (naive) failed");
933        assert_eq!(result_empty_g1, G1::zero(), "G1 MSM empty failed");
934
935        // Case 8: Random points and scalars (big)
936        let points_g1: Vec<G1> = (0..50_000)
937            .map(|_| {
938                let mut point = G1::one();
939                point.mul(&Scalar::rand(&mut rng));
940                point
941            })
942            .collect();
943        let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::rand(&mut rng)).collect();
944        let expected_g1 = naive_msm(&points_g1, &scalars);
945        let result_g1 = G1::msm(&points_g1, &scalars);
946        assert_eq!(expected_g1, result_g1, "G1 MSM basic case failed");
947    }
948
949    #[test]
950    fn test_g2_msm() {
951        let mut rng = thread_rng();
952        let n = 10; // Number of points/scalars
953
954        // Case 1: Random points and scalars
955        let points_g2: Vec<G2> = (0..n)
956            .map(|_| {
957                let mut point = G2::one();
958                point.mul(&Scalar::rand(&mut rng));
959                point
960            })
961            .collect();
962        let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::rand(&mut rng)).collect();
963        let expected_g2 = naive_msm(&points_g2, &scalars);
964        let result_g2 = G2::msm(&points_g2, &scalars);
965        assert_eq!(expected_g2, result_g2, "G2 MSM basic case failed");
966
967        // Case 2: Include identity point
968        let mut points_with_zero_g2 = points_g2.clone();
969        points_with_zero_g2[n / 2] = G2::zero();
970        let expected_zero_pt_g2 = naive_msm(&points_with_zero_g2, &scalars);
971        let result_zero_pt_g2 = G2::msm(&points_with_zero_g2, &scalars);
972        assert_eq!(
973            expected_zero_pt_g2, result_zero_pt_g2,
974            "G2 MSM with identity point failed"
975        );
976
977        // Case 3: Include zero scalar
978        let mut scalars_with_zero = scalars.clone();
979        scalars_with_zero[n / 2] = Scalar::zero();
980        let expected_zero_sc_g2 = naive_msm(&points_g2, &scalars_with_zero);
981        let result_zero_sc_g2 = G2::msm(&points_g2, &scalars_with_zero);
982        assert_eq!(
983            expected_zero_sc_g2, result_zero_sc_g2,
984            "G2 MSM with zero scalar failed"
985        );
986
987        // Case 4: All points identity
988        let zero_points_g2 = vec![G2::zero(); n];
989        let expected_all_zero_pt_g2 = naive_msm(&zero_points_g2, &scalars);
990        let result_all_zero_pt_g2 = G2::msm(&zero_points_g2, &scalars);
991        assert_eq!(
992            expected_all_zero_pt_g2,
993            G2::zero(),
994            "G2 MSM all identity points (naive) failed"
995        );
996        assert_eq!(
997            result_all_zero_pt_g2,
998            G2::zero(),
999            "G2 MSM all identity points failed"
1000        );
1001
1002        // Case 5: All scalars zero
1003        let zero_scalars = vec![Scalar::zero(); n];
1004        let expected_all_zero_sc_g2 = naive_msm(&points_g2, &zero_scalars);
1005        let result_all_zero_sc_g2 = G2::msm(&points_g2, &zero_scalars);
1006        assert_eq!(
1007            expected_all_zero_sc_g2,
1008            G2::zero(),
1009            "G2 MSM all zero scalars (naive) failed"
1010        );
1011        assert_eq!(
1012            result_all_zero_sc_g2,
1013            G2::zero(),
1014            "G2 MSM all zero scalars failed"
1015        );
1016
1017        // Case 6: Single element
1018        let single_point_g2 = [points_g2[0]];
1019        let single_scalar = [scalars[0].clone()];
1020        let expected_single_g2 = naive_msm(&single_point_g2, &single_scalar);
1021        let result_single_g2 = G2::msm(&single_point_g2, &single_scalar);
1022        assert_eq!(
1023            expected_single_g2, result_single_g2,
1024            "G2 MSM single element failed"
1025        );
1026
1027        // Case 7: Empty input
1028        let empty_points_g2: [G2; 0] = [];
1029        let empty_scalars: [Scalar; 0] = [];
1030        let expected_empty_g2 = naive_msm(&empty_points_g2, &empty_scalars);
1031        let result_empty_g2 = G2::msm(&empty_points_g2, &empty_scalars);
1032        assert_eq!(expected_empty_g2, G2::zero(), "G2 MSM empty (naive) failed");
1033        assert_eq!(result_empty_g2, G2::zero(), "G2 MSM empty failed");
1034
1035        // Case 8: Random points and scalars (big)
1036        let points_g2: Vec<G2> = (0..50_000)
1037            .map(|_| {
1038                let mut point = G2::one();
1039                point.mul(&Scalar::rand(&mut rng));
1040                point
1041            })
1042            .collect();
1043        let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::rand(&mut rng)).collect();
1044        let expected_g2 = naive_msm(&points_g2, &scalars);
1045        let result_g2 = G2::msm(&points_g2, &scalars);
1046        assert_eq!(expected_g2, result_g2, "G2 MSM basic case failed");
1047    }
1048}