use super::variant::Variant;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
use blst::{
blst_bendian_from_scalar, blst_expand_message_xmd, blst_fp12, blst_fr, blst_fr_add,
blst_fr_from_scalar, blst_fr_from_uint64, blst_fr_inverse, blst_fr_mul, blst_fr_sub,
blst_hash_to_g1, blst_hash_to_g2, blst_keygen, blst_p1, blst_p1_add_or_double, blst_p1_affine,
blst_p1_compress, blst_p1_from_affine, blst_p1_in_g1, blst_p1_is_inf, blst_p1_mult,
blst_p1_to_affine, blst_p1_uncompress, blst_p1s_mult_pippenger,
blst_p1s_mult_pippenger_scratch_sizeof, blst_p2, blst_p2_add_or_double, blst_p2_affine,
blst_p2_compress, blst_p2_from_affine, blst_p2_in_g2, blst_p2_is_inf, blst_p2_mult,
blst_p2_to_affine, blst_p2_uncompress, blst_p2s_mult_pippenger,
blst_p2s_mult_pippenger_scratch_sizeof, blst_scalar, blst_scalar_from_be_bytes,
blst_scalar_from_bendian, blst_scalar_from_fr, blst_sk_check, BLS12_381_G1, BLS12_381_G2,
BLST_ERROR,
};
use bytes::{Buf, BufMut};
use commonware_codec::{
varint::UInt,
EncodeSize,
Error::{self, Invalid},
FixedSize, Read, ReadExt, Write,
};
use commonware_utils::hex;
use core::{
fmt::{Debug, Display, Formatter},
hash::{Hash, Hasher},
mem::MaybeUninit,
ptr,
};
use rand_core::CryptoRngCore;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Domain separation tag: a static byte string handed to the hashing routines in this module.
pub type DST = &'static [u8];
/// An element of a group that supports in-place addition and multiplication by a [`Scalar`].
///
/// Implementors must also have a fixed-size encoding, be comparable and hashable, and be
/// safe to send and share across threads.
pub trait Element:
Read<Cfg = ()> + Write + FixedSize + Clone + Eq + PartialEq + Ord + PartialOrd + Hash + Send + Sync
{
/// Returns the additive identity.
fn zero() -> Self;
/// Returns the multiplicative identity (for the curve groups in this file, the generator).
fn one() -> Self;
/// Adds `rhs` to `self` in place.
fn add(&mut self, rhs: &Self);
/// Multiplies `self` by the scalar `rhs` in place.
fn mul(&mut self, rhs: &Scalar);
}
/// An [`Element`] that can be derived by hashing a message and that supports
/// multi-scalar multiplication.
pub trait Point: Element {
/// Overwrites `self` with the point obtained by hashing `message` under the tag `dst`.
fn map(&mut self, dst: DST, message: &[u8]);
/// Computes the sum of `points[i] * scalars[i]` over all `i`.
fn msm(points: &[Self], scalars: &[Scalar]) -> Self;
}
/// A scalar, stored as a `blst_fr`.
///
/// The memory layout is exactly that of the wrapped `blst_fr` (`repr(transparent)`).
#[derive(Clone, Eq, PartialEq)]
#[repr(transparent)]
pub struct Scalar(blst_fr);
/// Length in bytes of an encoded [`Scalar`].
pub const SCALAR_LENGTH: usize = 32;
/// Number of scalar bits handed to the `blst` point-multiplication routines.
const SCALAR_BITS: usize = 255;
/// The value returned by `Scalar::one()`.
///
/// NOTE(review): presumably the Montgomery-form limbs of 1 as used internally by `blst`;
/// confirm against the `blst` sources before changing.
const BLST_FR_ONE: Scalar = Scalar(blst_fr {
l: [
0x0000_0001_ffff_fffe,
0x5884_b7fa_0003_4802,
0x998c_4fef_ecbc_4ff5,
0x1824_b159_acc5_056f,
],
});
/// A point in G1, stored as a `blst_p1`.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G1(blst_p1);
/// Length in bytes of an encoded (compressed) [`G1`] element.
pub const G1_ELEMENT_BYTE_LENGTH: usize = 48;
/// Domain separation tag named for proof-of-possession hashing into G1.
pub const G1_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";
/// Domain separation tag named for message hashing into G1.
pub const G1_MESSAGE: DST = b"BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";
/// A point in G2, stored as a `blst_p2`.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G2(blst_p2);
/// Length in bytes of an encoded (compressed) [`G2`] element.
pub const G2_ELEMENT_BYTE_LENGTH: usize = 96;
/// Domain separation tag named for proof-of-possession hashing into G2.
pub const G2_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";
/// Domain separation tag named for message hashing into G2.
pub const G2_MESSAGE: DST = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";
/// An element of the target group, stored as a `blst_fp12`.
#[derive(Debug, Clone, Eq, PartialEq, Copy)]
#[repr(transparent)]
pub struct GT(blst_fp12);
/// Number of bytes copied out of a [`GT`] element by `GT::as_slice`.
pub const GT_ELEMENT_BYTE_LENGTH: usize = 576;
impl GT {
pub(crate) fn from_blst_fp12(fp12: blst_fp12) -> Self {
GT(fp12)
}
pub fn as_slice(&self) -> [u8; GT_ELEMENT_BYTE_LENGTH] {
let mut slice = [0u8; GT_ELEMENT_BYTE_LENGTH];
unsafe {
let fp12_ptr = &self.0 as *const blst_fp12 as *const u8;
core::ptr::copy_nonoverlapping(fp12_ptr, slice.as_mut_ptr(), GT_ELEMENT_BYTE_LENGTH);
}
slice
}
}
/// A private key is a [`Scalar`].
pub type Private = Scalar;
/// Length in bytes of an encoded [`Private`] key (same as [`SCALAR_LENGTH`]).
pub const PRIVATE_KEY_LENGTH: usize = SCALAR_LENGTH;
impl Scalar {
    /// Samples a fresh scalar from `rng`.
    ///
    /// 64 random bytes are drawn as input keying material and passed to `blst_keygen`; the
    /// keying material is wiped before returning.
    pub fn from_rand<R: CryptoRngCore>(rng: &mut R) -> Self {
        let mut seed = [0u8; 64];
        rng.fill_bytes(&mut seed);

        let mut fr = blst_fr::default();
        // SAFETY: all pointers refer to live locals; the optional key-info argument is
        // passed as null with length 0.
        unsafe {
            let mut key = blst_scalar::default();
            blst_keygen(&mut key, seed.as_ptr(), seed.len(), ptr::null(), 0);
            blst_fr_from_scalar(&mut fr, &key);
        }

        // Do not leave the keying material behind on the stack.
        seed.zeroize();
        Self(fr)
    }

    /// Derives a scalar from `msg` under the domain separation tag `dst`.
    ///
    /// The input is expanded to 48 bytes with `blst_expand_message_xmd`, and those bytes
    /// are then converted to a scalar via `blst_scalar_from_be_bytes`.
    pub fn map(dst: DST, msg: &[u8]) -> Self {
        const EXPANDED_LEN: usize = 48;

        let mut expanded = [0u8; EXPANDED_LEN];
        // SAFETY: every pointer/length pair comes from a live slice or array.
        unsafe {
            blst_expand_message_xmd(
                expanded.as_mut_ptr(),
                EXPANDED_LEN,
                msg.as_ptr(),
                msg.len(),
                dst.as_ptr(),
                dst.len(),
            );
        }

        let mut fr = blst_fr::default();
        // SAFETY: `expanded` holds exactly `EXPANDED_LEN` initialized bytes.
        unsafe {
            let mut reduced = blst_scalar::default();
            blst_scalar_from_be_bytes(&mut reduced, expanded.as_ptr(), EXPANDED_LEN);
            blst_fr_from_scalar(&mut fr, &reduced);
        }
        Self(fr)
    }

    /// Builds a scalar from a `u64`, placed in the first of four 64-bit limbs.
    fn from_u64(i: u64) -> Self {
        let limbs: [u64; 4] = [i, 0, 0, 0];
        let mut fr = blst_fr::default();
        // SAFETY: `limbs` is a live array of four `u64`s.
        unsafe { blst_fr_from_uint64(&mut fr, limbs.as_ptr()) };
        Self(fr)
    }

    /// Returns the scalar for index `i`, which is `i + 1` (so index 0 maps to one, not zero).
    pub fn from_index(i: u32) -> Self {
        Self::from_u64(u64::from(i) + 1)
    }

    /// Returns the multiplicative inverse of `self`, or `None` when `self` is zero.
    pub fn inverse(&self) -> Option<Self> {
        if *self != Self::zero() {
            let mut inv = blst_fr::default();
            // SAFETY: both pointers refer to live `blst_fr` values.
            unsafe { blst_fr_inverse(&mut inv, &self.0) };
            Some(Self(inv))
        } else {
            None
        }
    }

    /// Subtracts `rhs` from `self` in place.
    pub fn sub(&mut self, rhs: &Self) {
        // SAFETY: all pointers refer to live `blst_fr` values.
        unsafe { blst_fr_sub(&mut self.0, &self.0, &rhs.0) }
    }

    /// Encodes the scalar as big-endian bytes.
    fn as_slice(&self) -> [u8; Self::SIZE] {
        let canonical = self.as_blst_scalar();
        let mut out = [0u8; Self::SIZE];
        // SAFETY: `out` has room for `Self::SIZE` bytes and `canonical` is a live value.
        unsafe { blst_bendian_from_scalar(out.as_mut_ptr(), &canonical) };
        out
    }

    /// Converts the scalar into a `blst_scalar`.
    pub(crate) fn as_blst_scalar(&self) -> blst_scalar {
        let mut out = blst_scalar::default();
        // SAFETY: both pointers refer to live values.
        unsafe { blst_scalar_from_fr(&mut out, &self.0) };
        out
    }
}
/// Converts a `u32` into a scalar by widening it to `u64` first.
impl From<u32> for Scalar {
    fn from(i: u32) -> Self {
        Self::from_u64(u64::from(i))
    }
}

/// Converts a `u64` into a scalar.
impl From<u64> for Scalar {
    fn from(i: u64) -> Self {
        Self::from_u64(i)
    }
}
impl Element for Scalar {
fn zero() -> Self {
Self(blst_fr::default())
}
fn one() -> Self {
BLST_FR_ONE
}
fn add(&mut self, rhs: &Self) {
unsafe {
blst_fr_add(&mut self.0, &self.0, &rhs.0);
}
}
fn mul(&mut self, rhs: &Self) {
unsafe {
blst_fr_mul(&mut self.0, &self.0, &rhs.0);
}
}
}
/// Writes the scalar's big-endian byte encoding.
impl Write for Scalar {
    fn write(&self, buf: &mut impl BufMut) {
        buf.put_slice(&self.as_slice());
    }
}
/// Reads a scalar from its big-endian encoding, rejecting values that fail `blst_sk_check`.
impl Read for Scalar {
    type Cfg = ();

    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let encoded = <[u8; Self::SIZE]>::read(buf)?;

        let mut candidate = blst_scalar::default();
        // SAFETY: `encoded` holds `Self::SIZE` initialized bytes and `candidate` is live.
        unsafe { blst_scalar_from_bendian(&mut candidate, encoded.as_ptr()) };

        // SAFETY: `candidate` is a live, initialized `blst_scalar`.
        if !unsafe { blst_sk_check(&candidate) } {
            return Err(Invalid("Scalar", "Invalid"));
        }

        let mut fr = blst_fr::default();
        // SAFETY: both pointers refer to live values.
        unsafe { blst_fr_from_scalar(&mut fr, &candidate) };
        Ok(Self(fr))
    }
}
/// A scalar always encodes to [`SCALAR_LENGTH`] bytes.
impl FixedSize for Scalar {
const SIZE: usize = SCALAR_LENGTH;
}
/// Hashes the scalar's big-endian byte encoding.
impl Hash for Scalar {
    fn hash<H: Hasher>(&self, state: &mut H) {
        state.write(&self.as_slice());
    }
}
/// Scalars are totally ordered; this simply defers to [`Ord`].
impl PartialOrd for Scalar {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders scalars lexicographically by their big-endian byte encoding.
impl Ord for Scalar {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        let (lhs, rhs) = (self.as_slice(), other.as_slice());
        lhs.cmp(&rhs)
    }
}
/// Formats the scalar as the hex string of its big-endian encoding.
///
/// Note: a [`Scalar`] may be a private key, and this prints its full value.
impl Debug for Scalar {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        f.write_str(&hex(&self.as_slice()))
    }
}

/// Same output as the [`Debug`] implementation.
impl Display for Scalar {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        Debug::fmt(self, f)
    }
}
/// Wipes the scalar by zeroizing the limbs of the wrapped `blst_fr`.
impl Zeroize for Scalar {
fn zeroize(&mut self) {
self.0.l.zeroize();
}
}
/// Wipes the scalar's limbs whenever a [`Scalar`] is dropped.
impl Drop for Scalar {
fn drop(&mut self) {
self.zeroize();
}
}
// Marker: the `Drop` impl above zeroizes the value.
impl ZeroizeOnDrop for Scalar {}
/// An index paired with a private scalar.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Share {
/// The index of this share.
pub index: u32,
/// The private scalar held by this share.
pub private: Private,
}
/// Borrows the share's private scalar.
impl AsRef<Private> for Share {
fn as_ref(&self) -> &Private {
&self.private
}
}
impl Share {
    /// Returns the public key for this share: `V::Public::one()` multiplied by the
    /// private scalar.
    pub fn public<V: Variant>(&self) -> V::Public {
        let mut key = V::Public::one();
        key.mul(&self.private);
        key
    }
}
/// Writes the index as a varint followed by the private scalar.
impl Write for Share {
    fn write(&self, buf: &mut impl BufMut) {
        let Self { index, private } = self;
        UInt(*index).write(buf);
        private.write(buf);
    }
}
/// Reads a varint index followed by a private scalar.
impl Read for Share {
    type Cfg = ();

    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        // Fields are decoded in wire order: index first, then the scalar.
        let index: u32 = UInt::read(buf)?.into();
        let private = Private::read(buf)?;
        Ok(Self { index, private })
    }
}
/// The encoded size is the varint size of the index plus the size of the private scalar.
impl EncodeSize for Share {
    fn encode_size(&self) -> usize {
        let index_len = UInt(self.index).encode_size();
        let private_len = self.private.encode_size();
        index_len + private_len
    }
}
/// Formats the share as `Share(index=…, private=…)`.
///
/// Note: the output contains the private scalar in hex.
impl Display for Share {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let Self { index, private } = self;
        write!(f, "Share(index={}, private={})", index, private)
    }
}

/// Same output as the [`Display`] implementation.
impl Debug for Share {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        Display::fmt(self, f)
    }
}
impl G1 {
fn as_slice(&self) -> [u8; Self::SIZE] {
let mut slice = [0u8; Self::SIZE];
unsafe {
blst_p1_compress(slice.as_mut_ptr(), &self.0);
}
slice
}
pub(crate) fn as_blst_p1_affine(&self) -> blst_p1_affine {
let mut affine = blst_p1_affine::default();
unsafe { blst_p1_to_affine(&mut affine, &self.0) };
affine
}
pub(crate) fn from_blst_p1(p: blst_p1) -> Self {
Self(p)
}
}
impl Element for G1 {
fn zero() -> Self {
Self(blst_p1::default())
}
fn one() -> Self {
let mut ret = blst_p1::default();
unsafe {
blst_p1_from_affine(&mut ret, &BLS12_381_G1);
}
Self(ret)
}
fn add(&mut self, rhs: &Self) {
unsafe {
blst_p1_add_or_double(&mut self.0, &self.0, &rhs.0);
}
}
fn mul(&mut self, rhs: &Scalar) {
let mut scalar: blst_scalar = blst_scalar::default();
unsafe {
blst_scalar_from_fr(&mut scalar, &rhs.0);
blst_p1_mult(&mut self.0, &self.0, scalar.b.as_ptr(), SCALAR_BITS);
}
}
}
/// Writes the point's compressed encoding.
impl Write for G1 {
    fn write(&self, buf: &mut impl BufMut) {
        buf.put_slice(&self.as_slice());
    }
}
/// Reads a compressed G1 point.
///
/// Decoding fails if `blst_p1_uncompress` reports an error, if the point is the point at
/// infinity, or if it is not in G1.
impl Read for G1 {
    type Cfg = ();

    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let encoded = <[u8; Self::SIZE]>::read(buf)?;

        let mut affine = blst_p1_affine::default();
        // SAFETY: `encoded` holds `Self::SIZE` initialized bytes and `affine` is live.
        let status = unsafe { blst_p1_uncompress(&mut affine, encoded.as_ptr()) };
        // Translate every non-success status into its error label.
        let failure = match status {
            BLST_ERROR::BLST_SUCCESS => None,
            BLST_ERROR::BLST_BAD_ENCODING => Some("Bad encoding"),
            BLST_ERROR::BLST_POINT_NOT_ON_CURVE => Some("Not on curve"),
            BLST_ERROR::BLST_POINT_NOT_IN_GROUP => Some("Not in group"),
            BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => Some("Type mismatch"),
            BLST_ERROR::BLST_VERIFY_FAIL => Some("Verify fail"),
            BLST_ERROR::BLST_PK_IS_INFINITY => Some("PK is Infinity"),
            BLST_ERROR::BLST_BAD_SCALAR => Some("Bad scalar"),
        };
        if let Some(reason) = failure {
            return Err(Invalid("G1", reason));
        }

        let mut point = blst_p1::default();
        // SAFETY: both pointers refer to live values.
        unsafe { blst_p1_from_affine(&mut point, &affine) };

        // SAFETY: `point` is a live, initialized `blst_p1`.
        if unsafe { blst_p1_is_inf(&point) } {
            return Err(Invalid("G1", "Infinity"));
        }
        // SAFETY: as above.
        if !unsafe { blst_p1_in_g1(&point) } {
            return Err(Invalid("G1", "Outside G1"));
        }
        Ok(Self(point))
    }
}
/// A G1 point always encodes to [`G1_ELEMENT_BYTE_LENGTH`] bytes.
impl FixedSize for G1 {
const SIZE: usize = G1_ELEMENT_BYTE_LENGTH;
}
/// Hashes the point's compressed encoding.
impl Hash for G1 {
    fn hash<H: Hasher>(&self, state: &mut H) {
        state.write(&self.as_slice());
    }
}
/// G1 points are totally ordered; this simply defers to [`Ord`].
impl PartialOrd for G1 {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders points lexicographically by their compressed encoding.
impl Ord for G1 {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        let (lhs, rhs) = (self.as_slice(), other.as_slice());
        lhs.cmp(&rhs)
    }
}
impl Point for G1 {
    /// Overwrites `self` with the result of `blst_hash_to_g1` applied to `data` under
    /// the domain separation tag `dst` (no augmentation data is supplied).
    fn map(&mut self, dst: DST, data: &[u8]) {
        // SAFETY: every pointer/length pair comes from a live slice; the augmentation
        // argument is null with length 0.
        unsafe {
            blst_hash_to_g1(
                &mut self.0,
                data.as_ptr(),
                data.len(),
                dst.as_ptr(),
                dst.len(),
                ptr::null(),
                0,
            );
        }
    }

    /// Computes the sum of `points[i] * scalars[i]` using `blst`'s Pippenger routine.
    ///
    /// Pairs whose point is the identity or whose scalar is zero are skipped before
    /// calling into `blst`. If no pairs remain (including empty input), the identity
    /// is returned.
    ///
    /// # Panics
    ///
    /// Panics if `points` and `scalars` have different lengths.
    fn msm(points: &[Self], scalars: &[Scalar]) -> Self {
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");

        // Drop pairs that contribute nothing, so `blst` is never handed an identity point.
        let identity = G1::zero();
        let zero = Scalar::zero();
        let mut points_filtered = Vec::with_capacity(points.len());
        let mut scalars_filtered = Vec::with_capacity(scalars.len());
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            if *point == identity || *scalar == zero {
                continue;
            }
            points_filtered.push(point.as_blst_p1_affine());
            scalars_filtered.push(scalar.as_blst_scalar());
        }
        if points_filtered.is_empty() {
            return G1::zero();
        }

        // `blst` takes arrays of pointers; these point into the filtered vectors above,
        // which stay alive (and unmoved) until after the call.
        let points: Vec<*const blst_p1_affine> =
            points_filtered.iter().map(|p| p as *const _).collect();
        let scalars: Vec<*const u8> = scalars_filtered.iter().map(|s| s.b.as_ptr()).collect();

        // Scratch space is requested in bytes but allocated as `u64` words. Round up so
        // the buffer is never smaller than requested if the size is not a multiple of 8.
        let scratch_size = unsafe { blst_p1s_mult_pippenger_scratch_sizeof(points.len()) };
        let scratch_words = scratch_size.div_ceil(core::mem::size_of::<u64>());
        let mut scratch = vec![MaybeUninit::<u64>::uninit(); scratch_words];

        let mut msm_result = blst_p1::default();
        // SAFETY: `points` and `scalars` each hold `points.len()` valid pointers, and
        // `scratch` provides at least `scratch_size` bytes of writable memory.
        unsafe {
            blst_p1s_mult_pippenger(
                &mut msm_result,
                points.as_ptr(),
                points.len(),
                scalars.as_ptr(),
                SCALAR_BITS,
                scratch.as_mut_ptr() as *mut _,
            );
        }
        G1::from_blst_p1(msm_result)
    }
}
/// Formats the point as the hex string of its compressed encoding.
impl Debug for G1 {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        f.write_str(&hex(&self.as_slice()))
    }
}

/// Same output as the [`Debug`] implementation.
impl Display for G1 {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        Debug::fmt(self, f)
    }
}
/// Lets a `G1` be passed wherever an `AsRef<G1>` is accepted.
impl AsRef<G1> for G1 {
fn as_ref(&self) -> &Self {
self
}
}
impl G2 {
fn as_slice(&self) -> [u8; Self::SIZE] {
let mut slice = [0u8; Self::SIZE];
unsafe {
blst_p2_compress(slice.as_mut_ptr(), &self.0);
}
slice
}
pub(crate) fn as_blst_p2_affine(&self) -> blst_p2_affine {
let mut affine = blst_p2_affine::default();
unsafe { blst_p2_to_affine(&mut affine, &self.0) };
affine
}
pub(crate) fn from_blst_p2(p: blst_p2) -> Self {
Self(p)
}
}
impl Element for G2 {
fn zero() -> Self {
Self(blst_p2::default())
}
fn one() -> Self {
let mut ret = blst_p2::default();
unsafe {
blst_p2_from_affine(&mut ret, &BLS12_381_G2);
}
Self(ret)
}
fn add(&mut self, rhs: &Self) {
unsafe {
blst_p2_add_or_double(&mut self.0, &self.0, &rhs.0);
}
}
fn mul(&mut self, rhs: &Scalar) {
let mut scalar = blst_scalar::default();
unsafe {
blst_scalar_from_fr(&mut scalar, &rhs.0);
blst_p2_mult(&mut self.0, &self.0, scalar.b.as_ptr(), SCALAR_BITS);
}
}
}
/// Writes the point's compressed encoding.
impl Write for G2 {
    fn write(&self, buf: &mut impl BufMut) {
        buf.put_slice(&self.as_slice());
    }
}
/// Reads a compressed G2 point.
///
/// Decoding fails if `blst_p2_uncompress` reports an error, if the point is the point at
/// infinity, or if it is not in G2.
impl Read for G2 {
    type Cfg = ();

    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let encoded = <[u8; Self::SIZE]>::read(buf)?;

        let mut affine = blst_p2_affine::default();
        // SAFETY: `encoded` holds `Self::SIZE` initialized bytes and `affine` is live.
        let status = unsafe { blst_p2_uncompress(&mut affine, encoded.as_ptr()) };
        // Translate every non-success status into its error label.
        let failure = match status {
            BLST_ERROR::BLST_SUCCESS => None,
            BLST_ERROR::BLST_BAD_ENCODING => Some("Bad encoding"),
            BLST_ERROR::BLST_POINT_NOT_ON_CURVE => Some("Not on curve"),
            BLST_ERROR::BLST_POINT_NOT_IN_GROUP => Some("Not in group"),
            BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => Some("Type mismatch"),
            BLST_ERROR::BLST_VERIFY_FAIL => Some("Verify fail"),
            BLST_ERROR::BLST_PK_IS_INFINITY => Some("PK is Infinity"),
            BLST_ERROR::BLST_BAD_SCALAR => Some("Bad scalar"),
        };
        if let Some(reason) = failure {
            return Err(Invalid("G2", reason));
        }

        let mut point = blst_p2::default();
        // SAFETY: both pointers refer to live values.
        unsafe { blst_p2_from_affine(&mut point, &affine) };

        // SAFETY: `point` is a live, initialized `blst_p2`.
        if unsafe { blst_p2_is_inf(&point) } {
            return Err(Invalid("G2", "Infinity"));
        }
        // SAFETY: as above.
        if !unsafe { blst_p2_in_g2(&point) } {
            return Err(Invalid("G2", "Outside G2"));
        }
        Ok(Self(point))
    }
}
/// A G2 point always encodes to [`G2_ELEMENT_BYTE_LENGTH`] bytes.
impl FixedSize for G2 {
const SIZE: usize = G2_ELEMENT_BYTE_LENGTH;
}
/// Hashes the point's compressed encoding.
impl Hash for G2 {
    fn hash<H: Hasher>(&self, state: &mut H) {
        state.write(&self.as_slice());
    }
}
/// G2 points are totally ordered; this simply defers to [`Ord`].
impl PartialOrd for G2 {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders points lexicographically by their compressed encoding.
impl Ord for G2 {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        let (lhs, rhs) = (self.as_slice(), other.as_slice());
        lhs.cmp(&rhs)
    }
}
impl Point for G2 {
    /// Overwrites `self` with the result of `blst_hash_to_g2` applied to `data` under
    /// the domain separation tag `dst` (no augmentation data is supplied).
    fn map(&mut self, dst: DST, data: &[u8]) {
        // SAFETY: every pointer/length pair comes from a live slice; the augmentation
        // argument is null with length 0.
        unsafe {
            blst_hash_to_g2(
                &mut self.0,
                data.as_ptr(),
                data.len(),
                dst.as_ptr(),
                dst.len(),
                ptr::null(),
                0,
            );
        }
    }

    /// Computes the sum of `points[i] * scalars[i]` using `blst`'s Pippenger routine.
    ///
    /// Pairs whose point is the identity or whose scalar is zero are skipped before
    /// calling into `blst`. If no pairs remain (including empty input), the identity
    /// is returned.
    ///
    /// # Panics
    ///
    /// Panics if `points` and `scalars` have different lengths.
    fn msm(points: &[Self], scalars: &[Scalar]) -> Self {
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");

        // Drop pairs that contribute nothing, so `blst` is never handed an identity point.
        let identity = G2::zero();
        let zero = Scalar::zero();
        let mut points_filtered = Vec::with_capacity(points.len());
        let mut scalars_filtered = Vec::with_capacity(scalars.len());
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            if *point == identity || *scalar == zero {
                continue;
            }
            points_filtered.push(point.as_blst_p2_affine());
            scalars_filtered.push(scalar.as_blst_scalar());
        }
        if points_filtered.is_empty() {
            return G2::zero();
        }

        // `blst` takes arrays of pointers; these point into the filtered vectors above,
        // which stay alive (and unmoved) until after the call.
        let points: Vec<*const blst_p2_affine> =
            points_filtered.iter().map(|p| p as *const _).collect();
        let scalars: Vec<*const u8> = scalars_filtered.iter().map(|s| s.b.as_ptr()).collect();

        // Scratch space is requested in bytes but allocated as `u64` words. Round up so
        // the buffer is never smaller than requested if the size is not a multiple of 8.
        let scratch_size = unsafe { blst_p2s_mult_pippenger_scratch_sizeof(points.len()) };
        let scratch_words = scratch_size.div_ceil(core::mem::size_of::<u64>());
        let mut scratch = vec![MaybeUninit::<u64>::uninit(); scratch_words];

        let mut msm_result = blst_p2::default();
        // SAFETY: `points` and `scalars` each hold `points.len()` valid pointers, and
        // `scratch` provides at least `scratch_size` bytes of writable memory.
        unsafe {
            blst_p2s_mult_pippenger(
                &mut msm_result,
                points.as_ptr(),
                points.len(),
                scalars.as_ptr(),
                SCALAR_BITS,
                scratch.as_mut_ptr() as *mut _,
            );
        }
        G2::from_blst_p2(msm_result)
    }
}
/// Formats the point as the hex string of its compressed encoding.
impl Debug for G2 {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        f.write_str(&hex(&self.as_slice()))
    }
}

/// Same output as the [`Debug`] implementation.
impl Display for G2 {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        Debug::fmt(self, f)
    }
}
/// Lets a `G2` be passed wherever an `AsRef<G2>` is accepted.
impl AsRef<G2> for G2 {
fn as_ref(&self) -> &Self {
self
}
}
#[cfg(test)]
mod tests {
use super::*;
use commonware_codec::{DecodeExt, Encode};
use rand::prelude::*;
use std::collections::{BTreeSet, HashMap};
/// Checks that scalar arithmetic and point arithmetic agree on the generator.
///
/// The points are derived from `G1::one()`: multiples of the identity are always the
/// identity, which would make every comparison pass regardless of the scalars.
#[test]
fn basic_group() {
    let s = Scalar::from_rand(&mut thread_rng());

    // s2 = s + s
    let mut s2 = s.clone();
    s2.add(&s);

    // p1 = (s + s)G
    let mut p1 = G1::one();
    p1.mul(&s2);

    // p2 = sG + sG
    let mut p2 = G1::one();
    p2.mul(&s);
    let sg = p2;
    p2.add(&sg);
    assert_eq!(p1, p2);
    assert_ne!(p1, G1::zero());

    // p3 = ((s + s) * s)G must equal p4 = s * ((s + s)G)
    let mut s3 = s2.clone();
    s3.mul(&s);
    let mut p3 = G1::one();
    p3.mul(&s3);
    let mut p4 = p1;
    p4.mul(&s);
    assert_eq!(p3, p4);
}
/// A random scalar survives an encode/decode round trip and has the fixed size.
#[test]
fn test_scalar_codec() {
    let scalar = Scalar::from_rand(&mut thread_rng());

    let mut bytes = scalar.encode();
    assert_eq!(bytes.len(), Scalar::SIZE);

    let round_tripped = Scalar::decode(&mut bytes).unwrap();
    assert_eq!(scalar, round_tripped);
}
/// A random G1 point survives an encode/decode round trip and has the fixed size.
#[test]
fn test_g1_codec() {
    let secret = Scalar::from_rand(&mut thread_rng());
    let mut point = G1::one();
    point.mul(&secret);

    let mut bytes = point.encode();
    assert_eq!(bytes.len(), G1::SIZE);

    let round_tripped = G1::decode(&mut bytes).unwrap();
    assert_eq!(point, round_tripped);
}
/// A random G2 point survives an encode/decode round trip and has the fixed size.
#[test]
fn test_g2_codec() {
    let secret = Scalar::from_rand(&mut thread_rng());
    let mut point = G2::one();
    point.mul(&secret);

    let mut bytes = point.encode();
    assert_eq!(bytes.len(), G2::SIZE);

    let round_tripped = G2::decode(&mut bytes).unwrap();
    assert_eq!(point, round_tripped);
}
/// Reference multi-scalar multiplication: multiplies each pair and sums the results,
/// skipping pairs whose point is the identity or whose scalar is zero.
fn naive_msm<P: Point>(points: &[P], scalars: &[Scalar]) -> P {
    assert_eq!(points.len(), scalars.len());
    let identity = P::zero();
    let zero = Scalar::zero();
    points
        .iter()
        .zip(scalars.iter())
        .filter(|(point, scalar)| **point != identity && **scalar != zero)
        .fold(P::zero(), |mut acc, (point, scalar)| {
            let mut term = point.clone();
            term.mul(scalar);
            acc.add(&term);
            acc
        })
}
/// Compares `G1::msm` against `naive_msm` across several input shapes.
#[test]
fn test_g1_msm() {
let mut rng = thread_rng();
let n = 10;
// Basic case: random points and random scalars.
let points_g1: Vec<G1> = (0..n)
.map(|_| {
let mut point = G1::one();
point.mul(&Scalar::from_rand(&mut rng));
point
})
.collect();
let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::from_rand(&mut rng)).collect();
let expected_g1 = naive_msm(&points_g1, &scalars);
let result_g1 = G1::msm(&points_g1, &scalars);
assert_eq!(expected_g1, result_g1, "G1 MSM basic case failed");
// One point replaced by the identity.
let mut points_with_zero_g1 = points_g1.clone();
points_with_zero_g1[n / 2] = G1::zero();
let expected_zero_pt_g1 = naive_msm(&points_with_zero_g1, &scalars);
let result_zero_pt_g1 = G1::msm(&points_with_zero_g1, &scalars);
assert_eq!(
expected_zero_pt_g1, result_zero_pt_g1,
"G1 MSM with identity point failed"
);
// One scalar replaced by zero.
let mut scalars_with_zero = scalars.clone();
scalars_with_zero[n / 2] = Scalar::zero();
let expected_zero_sc_g1 = naive_msm(&points_g1, &scalars_with_zero);
let result_zero_sc_g1 = G1::msm(&points_g1, &scalars_with_zero);
assert_eq!(
expected_zero_sc_g1, result_zero_sc_g1,
"G1 MSM with zero scalar failed"
);
// Every point is the identity.
let zero_points_g1 = vec![G1::zero(); n];
let expected_all_zero_pt_g1 = naive_msm(&zero_points_g1, &scalars);
let result_all_zero_pt_g1 = G1::msm(&zero_points_g1, &scalars);
assert_eq!(
expected_all_zero_pt_g1,
G1::zero(),
"G1 MSM all identity points (naive) failed"
);
assert_eq!(
result_all_zero_pt_g1,
G1::zero(),
"G1 MSM all identity points failed"
);
// Every scalar is zero.
let zero_scalars = vec![Scalar::zero(); n];
let expected_all_zero_sc_g1 = naive_msm(&points_g1, &zero_scalars);
let result_all_zero_sc_g1 = G1::msm(&points_g1, &zero_scalars);
assert_eq!(
expected_all_zero_sc_g1,
G1::zero(),
"G1 MSM all zero scalars (naive) failed"
);
assert_eq!(
result_all_zero_sc_g1,
G1::zero(),
"G1 MSM all zero scalars failed"
);
// A single point/scalar pair.
let single_point_g1 = [points_g1[0]];
let single_scalar = [scalars[0].clone()];
let expected_single_g1 = naive_msm(&single_point_g1, &single_scalar);
let result_single_g1 = G1::msm(&single_point_g1, &single_scalar);
assert_eq!(
expected_single_g1, result_single_g1,
"G1 MSM single element failed"
);
// Empty input.
let empty_points_g1: [G1; 0] = [];
let empty_scalars: [Scalar; 0] = [];
let expected_empty_g1 = naive_msm(&empty_points_g1, &empty_scalars);
let result_empty_g1 = G1::msm(&empty_points_g1, &empty_scalars);
assert_eq!(expected_empty_g1, G1::zero(), "G1 MSM empty (naive) failed");
assert_eq!(result_empty_g1, G1::zero(), "G1 MSM empty failed");
// Large input (50,000 pairs).
let points_g1: Vec<G1> = (0..50_000)
.map(|_| {
let mut point = G1::one();
point.mul(&Scalar::from_rand(&mut rng));
point
})
.collect();
let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::from_rand(&mut rng)).collect();
let expected_g1 = naive_msm(&points_g1, &scalars);
let result_g1 = G1::msm(&points_g1, &scalars);
assert_eq!(expected_g1, result_g1, "G1 MSM basic case failed");
}
/// Compares `G2::msm` against `naive_msm` across several input shapes.
#[test]
fn test_g2_msm() {
let mut rng = thread_rng();
let n = 10;
// Basic case: random points and random scalars.
let points_g2: Vec<G2> = (0..n)
.map(|_| {
let mut point = G2::one();
point.mul(&Scalar::from_rand(&mut rng));
point
})
.collect();
let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::from_rand(&mut rng)).collect();
let expected_g2 = naive_msm(&points_g2, &scalars);
let result_g2 = G2::msm(&points_g2, &scalars);
assert_eq!(expected_g2, result_g2, "G2 MSM basic case failed");
// One point replaced by the identity.
let mut points_with_zero_g2 = points_g2.clone();
points_with_zero_g2[n / 2] = G2::zero();
let expected_zero_pt_g2 = naive_msm(&points_with_zero_g2, &scalars);
let result_zero_pt_g2 = G2::msm(&points_with_zero_g2, &scalars);
assert_eq!(
expected_zero_pt_g2, result_zero_pt_g2,
"G2 MSM with identity point failed"
);
// One scalar replaced by zero.
let mut scalars_with_zero = scalars.clone();
scalars_with_zero[n / 2] = Scalar::zero();
let expected_zero_sc_g2 = naive_msm(&points_g2, &scalars_with_zero);
let result_zero_sc_g2 = G2::msm(&points_g2, &scalars_with_zero);
assert_eq!(
expected_zero_sc_g2, result_zero_sc_g2,
"G2 MSM with zero scalar failed"
);
// Every point is the identity.
let zero_points_g2 = vec![G2::zero(); n];
let expected_all_zero_pt_g2 = naive_msm(&zero_points_g2, &scalars);
let result_all_zero_pt_g2 = G2::msm(&zero_points_g2, &scalars);
assert_eq!(
expected_all_zero_pt_g2,
G2::zero(),
"G2 MSM all identity points (naive) failed"
);
assert_eq!(
result_all_zero_pt_g2,
G2::zero(),
"G2 MSM all identity points failed"
);
// Every scalar is zero.
let zero_scalars = vec![Scalar::zero(); n];
let expected_all_zero_sc_g2 = naive_msm(&points_g2, &zero_scalars);
let result_all_zero_sc_g2 = G2::msm(&points_g2, &zero_scalars);
assert_eq!(
expected_all_zero_sc_g2,
G2::zero(),
"G2 MSM all zero scalars (naive) failed"
);
assert_eq!(
result_all_zero_sc_g2,
G2::zero(),
"G2 MSM all zero scalars failed"
);
// A single point/scalar pair.
let single_point_g2 = [points_g2[0]];
let single_scalar = [scalars[0].clone()];
let expected_single_g2 = naive_msm(&single_point_g2, &single_scalar);
let result_single_g2 = G2::msm(&single_point_g2, &single_scalar);
assert_eq!(
expected_single_g2, result_single_g2,
"G2 MSM single element failed"
);
// Empty input.
let empty_points_g2: [G2; 0] = [];
let empty_scalars: [Scalar; 0] = [];
let expected_empty_g2 = naive_msm(&empty_points_g2, &empty_scalars);
let result_empty_g2 = G2::msm(&empty_points_g2, &empty_scalars);
assert_eq!(expected_empty_g2, G2::zero(), "G2 MSM empty (naive) failed");
assert_eq!(result_empty_g2, G2::zero(), "G2 MSM empty failed");
// Large input (50,000 pairs).
let points_g2: Vec<G2> = (0..50_000)
.map(|_| {
let mut point = G2::one();
point.mul(&Scalar::from_rand(&mut rng));
point
})
.collect();
let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::from_rand(&mut rng)).collect();
let expected_g2 = naive_msm(&points_g2, &scalars);
let result_g2 = G2::msm(&points_g2, &scalars);
assert_eq!(expected_g2, result_g2, "G2 MSM basic case failed");
}
/// Exercises `Ord` (via `BTreeSet`) and `Hash` (via `HashMap`) for scalars, points and shares.
#[test]
fn test_trait_implementations() {
    const NUM_ITEMS: usize = 10;
    let mut rng = thread_rng();

    // Fill one ordered set per type with distinct random values.
    let mut scalar_set = BTreeSet::new();
    let mut g1_set = BTreeSet::new();
    let mut g2_set = BTreeSet::new();
    let mut share_set = BTreeSet::new();
    while scalar_set.len() < NUM_ITEMS {
        let scalar = Scalar::from_rand(&mut rng);

        let mut g1 = G1::one();
        g1.mul(&scalar);
        let mut g2 = G2::one();
        g2.mul(&scalar);

        // The share index is the number of scalars inserted so far.
        let share = Share {
            index: scalar_set.len() as u32,
            private: scalar.clone(),
        };

        scalar_set.insert(scalar);
        g1_set.insert(g1);
        g2_set.insert(g2);
        share_set.insert(share);
    }
    assert_eq!(scalar_set.len(), NUM_ITEMS);
    assert_eq!(g1_set.len(), NUM_ITEMS);
    assert_eq!(g2_set.len(), NUM_ITEMS);
    assert_eq!(share_set.len(), NUM_ITEMS);

    // Iteration order of each set must be non-decreasing.
    let ordered_scalars: Vec<_> = scalar_set.iter().collect();
    assert!(ordered_scalars.windows(2).all(|pair| pair[0] <= pair[1]));
    let ordered_g1: Vec<_> = g1_set.iter().collect();
    assert!(ordered_g1.windows(2).all(|pair| pair[0] <= pair[1]));
    let ordered_g2: Vec<_> = g2_set.iter().collect();
    assert!(ordered_g2.windows(2).all(|pair| pair[0] <= pair[1]));
    let ordered_shares: Vec<_> = share_set.iter().collect();
    assert!(ordered_shares.windows(2).all(|pair| pair[0] <= pair[1]));

    // Every value must work as a distinct hash-map key.
    let scalar_map: HashMap<_, _> = scalar_set.iter().cloned().zip(0..).collect();
    assert_eq!(scalar_map.len(), NUM_ITEMS);
    let g1_map: HashMap<_, _> = g1_set.iter().cloned().zip(0..).collect();
    assert_eq!(g1_map.len(), NUM_ITEMS);
    let g2_map: HashMap<_, _> = g2_set.iter().cloned().zip(0..).collect();
    assert_eq!(g2_map.len(), NUM_ITEMS);
    let share_map: HashMap<_, _> = share_set.iter().cloned().zip(0..).collect();
    assert_eq!(share_map.len(), NUM_ITEMS);
}
/// `Scalar::map` is deterministic, depends on both message and DST, and does not produce
/// zero for the sampled inputs.
#[test]
fn test_scalar_map() {
    let dst = b"TEST_DST";
    let msg = b"test message";

    // Determinism.
    let baseline = Scalar::map(dst, msg);
    let repeated = Scalar::map(dst, msg);
    assert_eq!(baseline, repeated, "Same input should produce same output");

    // Sensitivity to the message.
    let other_msg = b"different message";
    let from_other_msg = Scalar::map(dst, other_msg);
    assert_ne!(
        baseline, from_other_msg,
        "Different messages should produce different scalars"
    );

    // Sensitivity to the DST.
    let other_dst = b"DIFFERENT_DST";
    let from_other_dst = Scalar::map(other_dst, msg);
    assert_ne!(
        baseline, from_other_dst,
        "Different DSTs should produce different scalars"
    );

    // Edge-case inputs: empty and large messages.
    let from_empty = Scalar::map(dst, b"");
    assert_ne!(
        from_empty,
        Scalar::zero(),
        "Empty message should not produce zero"
    );
    let large_msg = vec![0x42u8; 1000];
    let from_large = Scalar::map(dst, &large_msg);
    assert_ne!(
        from_large,
        Scalar::zero(),
        "Large message should not produce zero"
    );

    assert_ne!(
        baseline,
        Scalar::zero(),
        "Hash should not produce zero scalar"
    );
}
}