use super::variant::Variant;
use crate::Secret;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
use blst::{
blst_bendian_from_fp12, blst_bendian_from_scalar, blst_expand_message_xmd, blst_fp12, blst_fr,
blst_fr_add, blst_fr_cneg, blst_fr_from_scalar, blst_fr_from_uint64, blst_fr_inverse,
blst_fr_mul, blst_fr_rshift, blst_fr_sub, blst_hash_to_g1, blst_hash_to_g2, blst_keygen,
blst_p1, blst_p1_add_or_double, blst_p1_affine, blst_p1_cneg, blst_p1_compress, blst_p1_double,
blst_p1_from_affine, blst_p1_in_g1, blst_p1_is_inf, blst_p1_mult, blst_p1_to_affine,
blst_p1_uncompress, blst_p1s_mult_pippenger, blst_p1s_mult_pippenger_scratch_sizeof,
blst_p1s_tile_pippenger, blst_p1s_to_affine, blst_p2, blst_p2_add_or_double, blst_p2_affine,
blst_p2_cneg, blst_p2_compress, blst_p2_double, blst_p2_from_affine, blst_p2_in_g2,
blst_p2_is_inf, blst_p2_mult, blst_p2_to_affine, blst_p2_uncompress, blst_p2s_mult_pippenger,
blst_p2s_mult_pippenger_scratch_sizeof, blst_p2s_tile_pippenger, blst_p2s_to_affine,
blst_scalar, blst_scalar_from_be_bytes, blst_scalar_from_bendian, blst_scalar_from_fr,
blst_sk_check, Pairing, BLS12_381_G1, BLS12_381_G2, BLST_ERROR,
};
use bytes::{Buf, BufMut};
use commonware_codec::{
EncodeSize,
Error::{self, Invalid},
FixedSize, Read, ReadExt, Write,
};
use commonware_math::algebra::{
Additive, CryptoGroup, Field, FieldNTT, HashToGroup, Multiplicative, Object, Random, Ring,
Space,
};
use commonware_parallel::Strategy;
use commonware_utils::{hex, Participant};
use core::{
fmt::{Debug, Display, Formatter},
hash::{Hash, Hasher},
iter,
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
ptr,
};
use ctutils::{Choice, CtEq};
use rand_core::CryptoRngCore;
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
fn all_zero(bytes: &[u8]) -> Choice {
bytes
.iter()
.fold(Choice::TRUE, |acc, b| acc & b.ct_eq(&0u8))
}
/// Choose a Pippenger bucket-window width (in bits) for an MSM over
/// `npoints` points.
///
/// The bit length of `npoints` stands in for `log2(npoints)`; larger inputs
/// get wider windows, with a floor of 2 bits so tiny (or empty) inputs stay
/// well-formed.
const fn pippenger_window_size(npoints: usize) -> usize {
    let bit_length = (usize::BITS - npoints.leading_zeros()) as usize;
    if bit_length <= 5 {
        2
    } else if bit_length <= 13 {
        bit_length - 3
    } else {
        bit_length - 4
    }
}
/// Split an MSM into `nx` point columns and `ny` window rows for `ncpus`
/// workers, returning `(nx, ny, window)` where `window` is the adjusted
/// bucket width in bits.
///
/// `nbits` is the scalar bit length and `window` the initial width from
/// [`pippenger_window_size`]. Wide scalars are split across window rows only
/// (`nx = 1`); narrow ones are additionally split across point columns.
fn msm_breakdown(nbits: usize, window: usize, ncpus: usize) -> (usize, usize, usize) {
    let num_bits = |v: usize| (usize::BITS - v.leading_zeros()) as usize;
    let (nx, width) = if nbits > window * ncpus {
        // Plenty of bits per worker: keep a single point column and pick a
        // window width that spreads the rows evenly over the CPUs.
        let adjust = num_bits(ncpus / 4);
        let width = if window + adjust > 18 {
            window.saturating_sub(adjust).max(1)
        } else {
            // Widen the window by one bit only if that reduces the number of
            // rows each CPU has to process.
            let rows_now = (nbits / window).div_ceil(ncpus);
            let rows_wider = (nbits / (window + 1)).div_ceil(ncpus);
            if rows_wider < rows_now {
                window + 1
            } else {
                window
            }
        };
        (1, width)
    } else {
        // Few bits per worker: grow the number of point columns until the
        // tile grid covers all CPUs, shrinking the window as columns grow.
        let mut cols = 2usize;
        let mut width = window.saturating_sub(2).max(1);
        while (nbits / width + 1) * cols < ncpus {
            cols += 1;
            let next = window.saturating_sub(num_bits(3 * cols / 2));
            if next == 0 {
                break;
            }
            width = next;
        }
        cols -= 1;
        width = window.saturating_sub(num_bits(3 * cols / 2)).max(1);
        (cols, width)
    };
    // Recompute the row count and final window so the rows exactly cover
    // `nbits` bits.
    let ny = nbits / width + 1;
    let final_window = nbits / ny + 1;
    (nx, ny, final_window)
}
/// One rectangle of work for the tiled Pippenger MSM: the points in
/// `x..x + dx`, processed at scalar bit-offset `y`.
struct Tile {
    x: usize,
    dx: usize,
    y: usize,
}
/// Lay out the `nx * ny` tiles covering `npoints` points across `ny` window
/// rows, ordered from the most significant row (largest `y`) down to `y = 0`.
fn build_tiles(npoints: usize, nx: usize, ny: usize, window: usize) -> Vec<Tile> {
    let mut tiles = Vec::with_capacity(nx * ny);
    let stride = npoints / nx;
    // Emit rows top-down: y = window * (ny - 1), ..., window, 0.
    for row in (0..ny).rev() {
        let y = window * row;
        for col in 0..nx {
            let x = col * stride;
            // The last column absorbs any remainder from the division.
            let dx = if col + 1 == nx { npoints - x } else { stride };
            tiles.push(Tile { x, dx, y });
        }
    }
    tiles
}
/// Generic driver for the parallel tiled Pippenger multi-scalar
/// multiplication, shared by the G1 and G2 implementations.
///
/// The scalar bit range is split into `ny` window rows and the points into
/// `nx` columns (via [`msm_breakdown`] and [`build_tiles`]); `compute_tile`
/// evaluates one tile, tiles in the same row are summed, and rows are
/// combined most-significant-first with `window` doublings between rows.
///
/// * `A` — affine point type; `P` — projective accumulator; `R` — result.
/// * `scalars` — concatenated fixed-width scalar encodings, one per point.
#[allow(clippy::too_many_arguments)]
fn msm_parallel_generic<A, P, R>(
    affine_points: &[A],
    scalars: &[u8],
    nbits: usize,
    ncpus: usize,
    strategy: &impl Strategy,
    compute_tile: impl Fn(&[A], &[u8], &Tile, usize, usize, usize) -> P + Sync,
    add: impl Fn(&P, &P) -> P,
    double: impl Fn(&P) -> P,
    from_projective: impl Fn(P) -> R,
) -> R
where
    A: Sync,
    P: Default + Send,
{
    let npoints = affine_points.len();
    let nbytes = nbits.div_ceil(8);
    let (nx, ny, window) = msm_breakdown(nbits, pippenger_window_size(npoints), ncpus);
    let tiles = build_tiles(npoints, nx, ny, window);
    // Evaluate every tile (possibly in parallel), tagging each result with
    // its (row, column) position in the tile grid.
    let tile_results: Vec<(usize, usize, P)> =
        strategy.map_collect_vec(tiles.iter().enumerate(), |(idx, tile)| {
            let result = compute_tile(affine_points, scalars, tile, nbytes, nbits, window);
            (idx / nx, idx % nx, result)
        });
    // Sum the tiles belonging to each window row.
    let mut row_sums: Vec<Option<P>> = (0..ny).map(|_| None).collect();
    for (row, _col, point) in tile_results {
        row_sums[row] = Some(match row_sums[row].take() {
            Some(sum) => add(&sum, &point),
            None => point,
        });
    }
    // Horner-style combination: rows were built most-significant first, so
    // shift the accumulator left by `window` bits between consecutive rows.
    let mut result = P::default();
    for (i, row_sum) in row_sums.into_iter().enumerate() {
        if let Some(sum) = row_sum {
            result = add(&result, &sum);
        }
        if i < ny - 1 {
            for _ in 0..window {
                result = double(&result);
            }
        }
    }
    from_projective(result)
}
/// Domain-separation tag used by hash-to-curve / hash-to-field routines.
pub type DST = &'static [u8];
/// An element of the BLS12-381 scalar field `Fr`, wrapping blst's `blst_fr`.
#[derive(Clone, Eq, PartialEq)]
#[repr(transparent)]
pub struct Scalar(blst_fr);
#[cfg(any(test, feature = "arbitrary"))]
impl arbitrary::Arbitrary<'_> for Scalar {
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Route arbitrary bytes through the keygen KDF so every input maps
        // to a canonical field element.
        let ikm = u.arbitrary::<[u8; IKM_LENGTH]>()?;
        Ok(Self::from_ikm(&ikm))
    }
}
/// Serialized length of a scalar, in bytes (big-endian encoding).
pub const SCALAR_LENGTH: usize = 32;
/// Bit length of the scalar field modulus `r`.
const SCALAR_BITS: usize = 255;
/// Bit length of a [`SmallScalar`].
const SMALL_SCALAR_BITS: usize = 128;
/// Number of random bytes in a [`SmallScalar`] (128 bits).
const SMALL_SCALAR_LENGTH: usize = 16;
/// Input-keying-material length fed to blst's keygen KDF.
const IKM_LENGTH: usize = 64;
/// Below this point count, MSMs run sequentially (parallel overhead
/// dominates any speedup).
const MIN_PARALLEL_POINTS: usize = 32;
/// A scalar restricted to 128 bits, enabling faster scalar multiplication.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SmallScalar {
    inner: blst_scalar,
}
impl SmallScalar {
    /// Sample a uniformly random scalar below `2^128`.
    pub fn random(mut rng: impl CryptoRngCore) -> Self {
        // Only the trailing 16 bytes of the 32-byte big-endian encoding are
        // randomized, so the resulting value is always < 2^128.
        let mut bytes = [0u8; 32];
        rng.fill_bytes(&mut bytes[SMALL_SCALAR_LENGTH..]);
        let mut scalar = blst_scalar::default();
        unsafe {
            // SAFETY: `bytes` is exactly the 32 bytes blst expects.
            blst_scalar_from_bendian(&mut scalar, bytes.as_ptr());
        }
        Self { inner: scalar }
    }
    /// Borrow blst's internal byte representation of the scalar.
    pub const fn as_bytes(&self) -> &[u8] {
        self.inner.b.as_slice()
    }
    /// The zero scalar.
    pub fn zero() -> Self {
        Self {
            inner: blst_scalar::default(),
        }
    }
}
impl From<SmallScalar> for Scalar {
    /// Widen a 128-bit scalar into a full field element (always in range,
    /// so no reduction can occur).
    fn from(small: SmallScalar) -> Self {
        let mut fr = blst_fr::default();
        unsafe {
            blst_fr_from_scalar(&mut fr, &small.inner);
        }
        Self(fr)
    }
}
/// `1` in blst's internal representation of `Fr` (presumably Montgomery
/// form, matching blst's convention — see blst documentation).
const BLST_FR_ONE: Scalar = Scalar(blst_fr {
    l: [
        0x0000_0001_ffff_fffe,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ],
});
/// A `2^32`-th root of unity in `Fr`; squared down to lower orders in
/// `FieldNTT::root_of_unity` (consistent with `MAX_LG_ROOT_ORDER = 32`).
const ROOT_OF_UNITY: Scalar = Scalar(blst_fr {
    l: [
        0xb9b5_8d8c_5f0e_466a,
        0x5b1b_4c80_1819_d7ec,
        0x0af5_3ae3_52a3_1e64,
        0x5bf3_adda_19e9_b27b,
    ],
});
/// Multiplicative shift used to move an NTT evaluation domain onto a coset.
const COSET_SHIFT: Scalar = Scalar(blst_fr {
    l: [
        0x0000_000e_ffff_fff1,
        0x17e3_63d3_0018_9c0f,
        0xff9c_5787_6f84_57b0,
        0x3513_3220_8fc5_a8c4,
    ],
});
/// Multiplicative inverse of [`COSET_SHIFT`].
const COSET_SHIFT_INV: Scalar = Scalar(blst_fr {
    l: [
        0xdb6d_b6da_db6d_b6dc,
        0xe6b5_824a_db6c_c6da,
        0xf8b3_56e0_0581_0db9,
        0x66d0_f1e6_60ec_4796,
    ],
});
/// A point in the BLS12-381 G1 group, wrapping blst's projective `blst_p1`.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G1(blst_p1);
/// Compressed serialized length of a G1 point.
pub const G1_ELEMENT_BYTE_LENGTH: usize = 48;
/// Ciphersuite DST for proofs of possession hashed to G1.
pub const G1_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";
/// Ciphersuite DST for messages hashed to G1.
pub const G1_MESSAGE: DST = b"BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";
#[cfg(any(test, feature = "arbitrary"))]
impl arbitrary::Arbitrary<'_> for G1 {
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Random point = generator times a random scalar (always in-group).
        Ok(Self::generator() * &u.arbitrary::<Scalar>()?)
    }
}
/// A point in the BLS12-381 G2 group, wrapping blst's projective `blst_p2`.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G2(blst_p2);
/// Compressed serialized length of a G2 point.
pub const G2_ELEMENT_BYTE_LENGTH: usize = 96;
/// Ciphersuite DST for proofs of possession hashed to G2.
pub const G2_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";
/// Ciphersuite DST for messages hashed to G2.
pub const G2_MESSAGE: DST = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";
#[cfg(any(test, feature = "arbitrary"))]
impl arbitrary::Arbitrary<'_> for G2 {
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Random point = generator times a random scalar (always in-group).
        Ok(Self::generator() * &u.arbitrary::<Scalar>()?)
    }
}
/// An element of the pairing target group GT, wrapping blst's `blst_fp12`.
#[derive(Debug, Clone, Eq, PartialEq, Copy)]
#[repr(transparent)]
pub struct GT(blst_fp12);
/// Serialized length of a GT element (12 field elements of 48 bytes each).
pub const GT_ELEMENT_BYTE_LENGTH: usize = 576;
impl GT {
    /// Wrap a raw blst `blst_fp12` produced by a pairing computation.
    pub(crate) const fn from_blst_fp12(fp12: blst_fp12) -> Self {
        Self(fp12)
    }
    /// Serialize the element to its big-endian byte encoding.
    pub fn as_slice(&self) -> [u8; GT_ELEMENT_BYTE_LENGTH] {
        let mut slice = [0u8; GT_ELEMENT_BYTE_LENGTH];
        unsafe {
            blst_bendian_from_fp12(slice.as_mut_ptr(), &self.0);
        }
        slice
    }
}
/// A BLS private key: a scalar held inside [`Secret`] so it is only
/// observable through explicit `expose` calls.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Private {
    scalar: Secret<Scalar>,
}
impl Private {
    /// Wrap an existing scalar as a private key.
    pub const fn new(private: Scalar) -> Self {
        Self {
            scalar: Secret::new(private),
        }
    }
    /// Run `f` with read access to the underlying scalar.
    pub fn expose<R>(&self, f: impl for<'a> FnOnce(&'a Scalar) -> R) -> R {
        self.scalar.expose(f)
    }
    /// Consume the key and return the raw scalar.
    pub fn expose_unwrap(self) -> Scalar {
        self.scalar.expose_unwrap()
    }
}
impl Write for Private {
    /// Serialize by writing the wrapped scalar's canonical encoding.
    fn write(&self, buf: &mut impl BufMut) {
        self.expose(|scalar| scalar.write(buf));
    }
}
impl Read for Private {
    type Cfg = ();
    /// Deserialize a private key; validity checking is delegated to
    /// `Scalar::read`.
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let scalar = Scalar::read(buf)?;
        Ok(Self::new(scalar))
    }
}
impl FixedSize for Private {
    const SIZE: usize = PRIVATE_KEY_LENGTH;
}
impl Random for Private {
    /// Generate a fresh private key from a cryptographic RNG.
    fn random(rng: impl CryptoRngCore) -> Self {
        Self::new(Scalar::random(rng))
    }
}
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Private {
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self::new(u.arbitrary::<Scalar>()?))
    }
}
/// Serialized length of a private key (same as a scalar).
pub const PRIVATE_KEY_LENGTH: usize = SCALAR_LENGTH;
impl Scalar {
    /// Derive a scalar from input keying material using blst's keygen KDF,
    /// guaranteeing a valid (non-zero, reduced) field element.
    fn from_ikm(ikm: &[u8; IKM_LENGTH]) -> Self {
        let mut sc = blst_scalar::default();
        let mut ret = blst_fr::default();
        unsafe {
            // SAFETY: `ikm` is a valid 64-byte buffer; no key-info supplied.
            blst_keygen(&mut sc, ikm.as_ptr(), ikm.len(), ptr::null(), 0);
            blst_fr_from_scalar(&mut ret, &sc);
        }
        Self(ret)
    }
    /// Hash a message to a scalar under the given domain-separation tag.
    pub fn map(dst: DST, msg: &[u8]) -> Self {
        // 48 uniform bytes (384 bits) keep the bias from the mod-r reduction
        // negligible for the 255-bit field.
        const L: usize = 48;
        let mut uniform_bytes = Zeroizing::new([0u8; L]);
        unsafe {
            blst_expand_message_xmd(
                uniform_bytes.as_mut_ptr(),
                L,
                msg.as_ptr(),
                msg.len(),
                dst.as_ptr(),
                dst.len(),
            );
        }
        let mut fr = blst_fr::default();
        unsafe {
            let mut scalar = blst_scalar::default();
            // Interprets the 48 bytes big-endian and reduces into range.
            blst_scalar_from_be_bytes(&mut scalar, uniform_bytes.as_ptr(), L);
            blst_fr_from_scalar(&mut fr, &scalar);
        }
        Self(fr)
    }
    /// Build a scalar from four little-endian 64-bit limbs.
    pub fn from_limbs(limbs: [u64; 4]) -> Self {
        let mut ret = blst_fr::default();
        unsafe { blst_fr_from_uint64(&mut ret, limbs.as_ptr()) };
        Self(ret)
    }
    /// Build a scalar from a single `u64`.
    pub fn from_u64(i: u64) -> Self {
        Self::from_limbs([i, 0, 0, 0])
    }
    /// Canonical 32-byte big-endian encoding, zeroized when dropped.
    fn as_slice(&self) -> Zeroizing<[u8; Self::SIZE]> {
        let mut slice = Zeroizing::new([0u8; Self::SIZE]);
        unsafe {
            let mut scalar = blst_scalar::default();
            blst_scalar_from_fr(&mut scalar, &self.0);
            blst_bendian_from_scalar(slice.as_mut_ptr(), &scalar);
        }
        slice
    }
    /// Convert to blst's `blst_scalar` form (as used by point multiplication).
    pub(crate) fn as_blst_scalar(&self) -> blst_scalar {
        let mut scalar = blst_scalar::default();
        unsafe { blst_scalar_from_fr(&mut scalar, &self.0) };
        scalar
    }
}
impl Write for Scalar {
    /// Serialize as the canonical 32-byte big-endian encoding.
    fn write(&self, buf: &mut impl BufMut) {
        let slice = self.as_slice();
        buf.put_slice(slice.as_ref());
    }
}
impl Read for Scalar {
    type Cfg = ();
    /// Deserialize a scalar, rejecting encodings that are zero or not below
    /// the group order (via `blst_sk_check`).
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let bytes = Zeroizing::new(<[u8; Self::SIZE]>::read(buf)?);
        let mut ret = blst_fr::default();
        unsafe {
            let mut scalar = blst_scalar::default();
            blst_scalar_from_bendian(&mut scalar, bytes.as_ptr());
            if !blst_sk_check(&scalar) {
                return Err(Invalid("Scalar", "Invalid"));
            }
            blst_fr_from_scalar(&mut ret, &scalar);
        }
        Ok(Self(ret))
    }
}
impl FixedSize for Scalar {
    const SIZE: usize = SCALAR_LENGTH;
}
impl Hash for Scalar {
    /// Hash the canonical big-endian encoding (consistent with `Eq`).
    fn hash<H: Hasher>(&self, state: &mut H) {
        let slice = self.as_slice();
        state.write(slice.as_ref());
    }
}
impl CtEq for Scalar {
    /// Constant-time equality over the internal limbs.
    fn ct_eq(&self, other: &Self) -> ctutils::Choice {
        self.0.l.ct_eq(&other.0.l)
    }
}
impl PartialOrd for Scalar {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Scalar {
    /// Order by big-endian byte encoding, i.e. by numeric value.
    /// NOTE(review): not constant time — do not use on secret data.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.as_slice().cmp(&other.as_slice())
    }
}
impl Debug for Scalar {
    /// Redacted: a scalar may be a private key, so never print its value.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "Scalar([REDACTED])")
    }
}
impl Display for Scalar {
    /// Redacted: a scalar may be a private key, so never print its value.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "[REDACTED]")
    }
}
impl Zeroize for Scalar {
    /// Wipe the internal limbs.
    fn zeroize(&mut self) {
        self.0.l.zeroize();
    }
}
impl Drop for Scalar {
    /// Scalars are wiped on drop since they may hold key material.
    fn drop(&mut self) {
        self.zeroize();
    }
}
impl ZeroizeOnDrop for Scalar {}
impl Object for Scalar {}
impl From<u64> for Scalar {
    fn from(value: u64) -> Self {
        Self::from_u64(value)
    }
}
impl<'a> AddAssign<&'a Self> for Scalar {
    /// Field addition in place (`blst_fr_add` allows aliased operands).
    fn add_assign(&mut self, rhs: &'a Self) {
        let ptr = &raw mut self.0;
        unsafe {
            blst_fr_add(ptr, ptr, &rhs.0);
        }
    }
}
impl<'a> Add<&'a Self> for Scalar {
    type Output = Self;
    fn add(mut self, rhs: &'a Self) -> Self::Output {
        self += rhs;
        self
    }
}
impl<'a> SubAssign<&'a Self> for Scalar {
    /// Field subtraction in place.
    fn sub_assign(&mut self, rhs: &'a Self) {
        let ptr = &raw mut self.0;
        unsafe { blst_fr_sub(ptr, ptr, &rhs.0) }
    }
}
impl<'a> Sub<&'a Self> for Scalar {
    type Output = Self;
    fn sub(mut self, rhs: &'a Self) -> Self::Output {
        self -= rhs;
        self
    }
}
impl Neg for Scalar {
    type Output = Self;
    /// Field negation (`cneg` with flag `true` unconditionally negates).
    fn neg(mut self) -> Self::Output {
        let ptr = &raw mut self.0;
        unsafe {
            blst_fr_cneg(ptr, ptr, true);
        }
        self
    }
}
impl Additive for Scalar {
    /// Additive identity: blst's default `blst_fr` is all-zero limbs.
    fn zero() -> Self {
        Self(blst_fr::default())
    }
}
impl<'a> MulAssign<&'a Self> for Scalar {
    /// Field multiplication in place.
    fn mul_assign(&mut self, rhs: &'a Self) {
        let ptr = &raw mut self.0;
        unsafe {
            blst_fr_mul(ptr, ptr, &rhs.0);
        }
    }
}
impl<'a> Mul<&'a Self> for Scalar {
    type Output = Self;
    fn mul(mut self, rhs: &'a Self) -> Self::Output {
        self *= rhs;
        self
    }
}
impl Multiplicative for Scalar {}
impl Ring for Scalar {
    /// Multiplicative identity (precomputed internal representation).
    fn one() -> Self {
        BLST_FR_ONE
    }
}
impl Field for Scalar {
    /// Multiplicative inverse; by convention `inv(0) == 0` since zero has
    /// no inverse.
    fn inv(&self) -> Self {
        if *self == Self::zero() {
            return Self::zero();
        }
        let mut ret = blst_fr::default();
        unsafe { blst_fr_inverse(&mut ret, &self.0) };
        Self(ret)
    }
}
impl Random for Scalar {
    /// Sample a scalar by feeding 64 random bytes through the keygen KDF;
    /// the intermediate IKM is zeroized after use.
    fn random(mut rng: impl CryptoRngCore) -> Self {
        let mut ikm = Zeroizing::new([0u8; IKM_LENGTH]);
        rng.fill_bytes(ikm.as_mut());
        Self::from_ikm(&ikm)
    }
}
impl FieldNTT for Scalar {
    const MAX_LG_ROOT_ORDER: u8 = 32;
    /// A `2^lg`-th root of unity, obtained by repeatedly squaring the
    /// precomputed `2^32`-th root down to the requested order.
    fn root_of_unity(lg: u8) -> Option<Self> {
        if lg > Self::MAX_LG_ROOT_ORDER {
            return None;
        }
        let mut out = ROOT_OF_UNITY;
        for _ in 0..(Self::MAX_LG_ROOT_ORDER - lg) {
            // Squaring halves the order of the root.
            out = out.clone() * &out;
        }
        Some(out)
    }
    fn coset_shift() -> Self {
        COSET_SHIFT
    }
    fn coset_shift_inv() -> Self {
        COSET_SHIFT_INV
    }
    /// Divide by two via a one-bit right shift in the field.
    fn div_2(&self) -> Self {
        let mut ret = blst_fr::default();
        unsafe { blst_fr_rshift(&mut ret, &self.0, 1) };
        Self(ret)
    }
}
/// A participant's share of a threshold secret: the participant index plus
/// that participant's private scalar share.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Share {
    pub index: Participant,
    pub private: Private,
}
impl Share {
    /// Construct a share from its components.
    pub const fn new(index: Participant, private: Private) -> Self {
        Self { index, private }
    }
    /// The public key corresponding to this share (generator * private).
    pub fn public<V: Variant>(&self) -> V::Public {
        self.private
            .expose(|private| V::Public::generator() * private)
    }
}
impl Write for Share {
    /// Serialize as index followed by the private scalar.
    fn write(&self, buf: &mut impl BufMut) {
        self.index.write(buf);
        self.private.expose(|private| private.write(buf));
    }
}
impl Read for Share {
    type Cfg = ();
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let index = Participant::read(buf)?;
        let private = Private::read(buf)?;
        Ok(Self { index, private })
    }
}
impl EncodeSize for Share {
    fn encode_size(&self) -> usize {
        self.index.encode_size() + self.private.expose(|private| private.encode_size())
    }
}
impl Display for Share {
    /// Delegates to `Debug`; the private scalar renders as `[REDACTED]`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "{:?}", self)
    }
}
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Share {
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let index = u.arbitrary()?;
        let private = u.arbitrary::<Private>()?;
        Ok(Self { index, private })
    }
}
impl G1 {
    /// Compressed 48-byte serialization of the point.
    fn as_slice(&self) -> [u8; Self::SIZE] {
        let mut slice = [0u8; Self::SIZE];
        unsafe {
            blst_p1_compress(slice.as_mut_ptr(), &self.0);
        }
        slice
    }
    /// Negate the point in place (`cneg` with flag `true`).
    fn neg_in_place(&mut self) {
        let ptr = &raw mut self.0;
        unsafe {
            blst_p1_cneg(ptr, true);
        }
    }
    /// Convert to blst's affine representation (as required by pairings).
    pub(crate) fn as_blst_p1_affine(&self) -> blst_p1_affine {
        let mut affine = blst_p1_affine::default();
        unsafe { blst_p1_to_affine(&mut affine, &self.0) };
        affine
    }
    /// Wrap a raw blst projective point.
    pub(crate) const fn from_blst_p1(p: blst_p1) -> Self {
        Self(p)
    }
    /// Batch-convert projective points to affine with a single shared
    /// inversion (much cheaper than converting one at a time).
    pub fn batch_to_affine(points: &[Self]) -> Vec<blst_p1_affine> {
        if points.is_empty() {
            return Vec::new();
        }
        let n = points.len();
        let mut out = vec![blst_p1_affine::default(); n];
        unsafe {
            // blst takes a null-terminated-by-count array of pointers.
            let points_ptr: Vec<*const blst_p1> = points.iter().map(|p| &p.0 as *const _).collect();
            blst_p1s_to_affine(out.as_mut_ptr(), points_ptr.as_ptr(), n);
        }
        out
    }
    /// Check `e(t1, t2) * prod e(p1[i], p2[i]) == 1` via one Miller loop and
    /// a single final exponentiation; `(t1, t2)` is seeded first.
    #[must_use]
    pub(crate) fn multi_pairing_check(p1: &[Self], p2: &[G2], t1: &Self, t2: &G2) -> bool {
        assert_eq!(p1.len(), p2.len());
        let mut pairing = Pairing::new(false, &[]);
        let p1_affine = Self::batch_to_affine(p1);
        let p2_affine = G2::batch_to_affine(p2);
        for (p1, p2) in iter::once((&t1.as_blst_p1_affine(), &t2.as_blst_p2_affine()))
            .chain(p1_affine.iter().zip(p2_affine.iter()))
        {
            pairing.raw_aggregate(p2, p1);
        }
        pairing.commit();
        pairing.finalverify(None)
    }
    /// Common MSM entry: drop zero points / zero scalars, pack the remaining
    /// scalars contiguously, and dispatch sequential vs. parallel Pippenger.
    fn msm_inner<'a>(
        iter: impl Iterator<Item = (&'a Self, &'a [u8])>,
        nbits: usize,
        strategy: &impl Strategy,
    ) -> Self {
        let nbytes = nbits.div_ceil(8);
        let (points_filtered, scalars_filtered): (Vec<_>, Vec<_>) = iter
            .filter_map(|(point, scalar)| {
                // Identity points and zero scalars contribute nothing.
                if *point == Self::zero() || all_zero(scalar).into() {
                    return None;
                }
                Some((point, scalar))
            })
            .unzip();
        if points_filtered.is_empty() {
            return Self::zero();
        }
        let npoints = points_filtered.len();
        let ncpus = strategy.parallelism_hint();
        let affine_points = Self::batch_to_affine(&points_filtered);
        // Pack the scalars into one contiguous buffer, `nbytes` per point,
        // as blst's pippenger routines expect.
        let scalar_bytes: Vec<u8> = scalars_filtered
            .iter()
            .flat_map(|s| s[..nbytes].iter().copied())
            .collect();
        if ncpus < 2 || npoints < MIN_PARALLEL_POINTS {
            return Self::msm_sequential(&affine_points, &scalar_bytes, nbits);
        }
        Self::msm_parallel(&affine_points, &scalar_bytes, nbits, ncpus, strategy)
    }
    /// Single-threaded Pippenger MSM via `blst_p1s_mult_pippenger`.
    fn msm_sequential(affine_points: &[blst_p1_affine], scalars: &[u8], nbits: usize) -> Self {
        let npoints = affine_points.len();
        let scratch_size = unsafe { blst_p1s_mult_pippenger_scratch_sizeof(npoints) };
        // Scratch is allocated as u64s for alignment.
        assert_eq!(scratch_size % 8, 0, "scratch_size must be multiple of 8");
        let mut scratch = vec![0u64; scratch_size / 8];
        // blst takes null-terminated pointer arrays for points and scalars.
        let p: [*const blst_p1_affine; 2] = [affine_points.as_ptr(), ptr::null()];
        let s: [*const u8; 2] = [scalars.as_ptr(), ptr::null()];
        let mut result = blst_p1::default();
        unsafe {
            blst_p1s_mult_pippenger(
                &mut result,
                p.as_ptr(),
                npoints,
                s.as_ptr(),
                nbits,
                scratch.as_mut_ptr(),
            );
        }
        Self::from_blst_p1(result)
    }
    /// Multi-threaded tiled Pippenger MSM built on [`msm_parallel_generic`].
    fn msm_parallel(
        affine_points: &[blst_p1_affine],
        scalars: &[u8],
        nbits: usize,
        ncpus: usize,
        strategy: &impl Strategy,
    ) -> Self {
        // Per-point scratch base size; each tile scales it by its window.
        let scratch_size = unsafe { blst_p1s_mult_pippenger_scratch_sizeof(0) } / 8;
        msm_parallel_generic(
            affine_points,
            scalars,
            nbits,
            ncpus,
            strategy,
            |points, scalars, tile, nbytes, nbits, window| {
                // 2^(window-1) buckets' worth of scratch for this tile.
                let mut scratch = vec![0u64; scratch_size << (window - 1)];
                let mut result = blst_p1::default();
                // Offset the point and scalar arrays to this tile's column.
                let p: [*const blst_p1_affine; 2] = [points[tile.x..].as_ptr(), ptr::null()];
                let s: [*const u8; 2] = [scalars[tile.x * nbytes..].as_ptr(), ptr::null()];
                unsafe {
                    blst_p1s_tile_pippenger(
                        &mut result,
                        p.as_ptr(),
                        tile.dx,
                        s.as_ptr(),
                        nbits,
                        scratch.as_mut_ptr(),
                        tile.y,
                        window,
                    );
                }
                result
            },
            |a, b| {
                let mut result = blst_p1::default();
                unsafe { blst_p1_add_or_double(&mut result, a, b) };
                result
            },
            |a| {
                let mut result = blst_p1::default();
                unsafe { blst_p1_double(&mut result, a) };
                result
            },
            Self::from_blst_p1,
        )
    }
}
impl Write for G1 {
    /// Serialize as the 48-byte compressed encoding.
    fn write(&self, buf: &mut impl BufMut) {
        let slice = self.as_slice();
        buf.put_slice(&slice);
    }
}
impl Read for G1 {
    type Cfg = ();
    /// Deserialize a compressed G1 point, rejecting malformed encodings,
    /// the identity, and points outside the prime-order subgroup.
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let bytes = <[u8; Self::SIZE]>::read(buf)?;
        let mut ret = blst_p1::default();
        unsafe {
            let mut affine = blst_p1_affine::default();
            match blst_p1_uncompress(&mut affine, bytes.as_ptr()) {
                BLST_ERROR::BLST_SUCCESS => {}
                BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G1", "Bad encoding")),
                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G1", "Not on curve")),
                BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G1", "Not in group")),
                BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G1", "Type mismatch")),
                BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G1", "Verify fail")),
                BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G1", "PK is Infinity")),
                BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G1", "Bad scalar")),
            }
            blst_p1_from_affine(&mut ret, &affine);
            // Explicitly reject the point at infinity and perform the
            // subgroup check (uncompress alone does not guarantee either).
            if blst_p1_is_inf(&ret) {
                return Err(Invalid("G1", "Infinity"));
            }
            if !blst_p1_in_g1(&ret) {
                return Err(Invalid("G1", "Outside G1"));
            }
        }
        Ok(Self(ret))
    }
}
impl FixedSize for G1 {
    const SIZE: usize = G1_ELEMENT_BYTE_LENGTH;
}
impl Hash for G1 {
    /// Hash the compressed encoding (consistent with `Eq`).
    fn hash<H: Hasher>(&self, state: &mut H) {
        let slice = self.as_slice();
        state.write(&slice);
    }
}
impl PartialOrd for G1 {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for G1 {
    /// Order by compressed byte encoding.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.as_slice().cmp(&other.as_slice())
    }
}
impl Debug for G1 {
    /// Hex of the compressed encoding (points are public, unlike scalars).
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "{}", hex(&self.as_slice()))
    }
}
impl Display for G1 {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "{}", hex(&self.as_slice()))
    }
}
impl Object for G1 {}
impl<'a> AddAssign<&'a Self> for G1 {
    /// Point addition in place; `add_or_double` also handles `self == rhs`.
    fn add_assign(&mut self, rhs: &'a Self) {
        let ptr = &raw mut self.0;
        unsafe {
            blst_p1_add_or_double(ptr, ptr, &rhs.0);
        }
    }
}
impl<'a> Add<&'a Self> for G1 {
    type Output = Self;
    fn add(mut self, rhs: &'a Self) -> Self::Output {
        self += rhs;
        self
    }
}
impl Neg for G1 {
    type Output = Self;
    fn neg(mut self) -> Self::Output {
        self.neg_in_place();
        self
    }
}
impl<'a> SubAssign<&'a Self> for G1 {
    /// Subtraction as addition of the negation (G1 is `Copy`).
    fn sub_assign(&mut self, rhs: &'a Self) {
        let mut rhs_cp = *rhs;
        rhs_cp.neg_in_place();
        *self += &rhs_cp;
    }
}
impl<'a> Sub<&'a Self> for G1 {
    type Output = Self;
    fn sub(mut self, rhs: &'a Self) -> Self::Output {
        self -= rhs;
        self
    }
}
impl Additive for G1 {
    /// Identity element: blst's default `blst_p1` is the point at infinity.
    fn zero() -> Self {
        Self(blst_p1::default())
    }
}
impl<'a> MulAssign<&'a Scalar> for G1 {
    /// Scalar multiplication over the full 255-bit scalar range.
    fn mul_assign(&mut self, rhs: &'a Scalar) {
        let ptr = &raw mut self.0;
        let mut scalar: blst_scalar = blst_scalar::default();
        unsafe {
            blst_scalar_from_fr(&mut scalar, &rhs.0);
            blst_p1_mult(ptr, ptr, scalar.b.as_ptr(), SCALAR_BITS);
        }
    }
}
impl<'a> Mul<&'a Scalar> for G1 {
    type Output = Self;
    fn mul(mut self, rhs: &'a Scalar) -> Self::Output {
        self *= rhs;
        self
    }
}
impl<'a> MulAssign<&'a SmallScalar> for G1 {
    /// Faster multiplication: only 128 scalar bits are processed.
    fn mul_assign(&mut self, rhs: &'a SmallScalar) {
        let ptr = &raw mut self.0;
        unsafe {
            blst_p1_mult(ptr, ptr, rhs.inner.b.as_ptr(), SMALL_SCALAR_BITS);
        }
    }
}
impl<'a> Mul<&'a SmallScalar> for G1 {
    type Output = Self;
    fn mul(mut self, rhs: &'a SmallScalar) -> Self::Output {
        self *= rhs;
        self
    }
}
impl Space<Scalar> for G1 {
    /// Multi-scalar multiplication with full-width scalars.
    fn msm(points: &[Self], scalars: &[Scalar], strategy: &impl Strategy) -> Self {
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");
        // Keep the converted scalars alive for the duration of the borrow.
        let scalar_bytes: Vec<_> = scalars.iter().map(|s| s.as_blst_scalar()).collect();
        Self::msm_inner(
            points
                .iter()
                .zip(scalar_bytes.iter().map(|s| s.b.as_slice())),
            SCALAR_BITS,
            strategy,
        )
    }
}
impl Space<SmallScalar> for G1 {
    /// Multi-scalar multiplication with 128-bit scalars.
    fn msm(points: &[Self], scalars: &[SmallScalar], strategy: &impl Strategy) -> Self {
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");
        Self::msm_inner(
            points.iter().zip(scalars.iter().map(|s| s.as_bytes())),
            SMALL_SCALAR_BITS,
            strategy,
        )
    }
}
impl CryptoGroup for G1 {
    type Scalar = Scalar;
    /// The standard BLS12-381 G1 generator.
    fn generator() -> Self {
        let mut ret = blst_p1::default();
        unsafe {
            blst_p1_from_affine(&mut ret, &BLS12_381_G1);
        }
        Self(ret)
    }
}
impl HashToGroup for G1 {
    /// Hash a message to a G1 point under the given DST (no aug string).
    fn hash_to_group(domain_separator: &[u8], message: &[u8]) -> Self {
        let mut out = blst_p1::default();
        unsafe {
            blst_hash_to_g1(
                &mut out,
                message.as_ptr(),
                message.len(),
                domain_separator.as_ptr(),
                domain_separator.len(),
                ptr::null(),
                0,
            );
        }
        Self(out)
    }
}
impl G2 {
    /// Compressed 96-byte serialization of the point.
    fn as_slice(&self) -> [u8; Self::SIZE] {
        let mut slice = [0u8; Self::SIZE];
        unsafe {
            blst_p2_compress(slice.as_mut_ptr(), &self.0);
        }
        slice
    }
    /// Negate the point in place (`cneg` with flag `true`).
    fn neg_in_place(&mut self) {
        let ptr = &raw mut self.0;
        unsafe {
            blst_p2_cneg(ptr, true);
        }
    }
    /// Convert to blst's affine representation (as required by pairings).
    pub(crate) fn as_blst_p2_affine(&self) -> blst_p2_affine {
        let mut affine = blst_p2_affine::default();
        unsafe { blst_p2_to_affine(&mut affine, &self.0) };
        affine
    }
    /// Wrap a raw blst projective point.
    pub(crate) const fn from_blst_p2(p: blst_p2) -> Self {
        Self(p)
    }
    /// Batch-convert projective points to affine with a single shared
    /// inversion (much cheaper than converting one at a time).
    pub fn batch_to_affine(points: &[Self]) -> Vec<blst_p2_affine> {
        if points.is_empty() {
            return Vec::new();
        }
        let n = points.len();
        let mut out = vec![blst_p2_affine::default(); n];
        unsafe {
            let points_ptr: Vec<*const blst_p2> = points.iter().map(|p| &p.0 as *const _).collect();
            blst_p2s_to_affine(out.as_mut_ptr(), points_ptr.as_ptr(), n);
        }
        out
    }
    /// Mirror of [`G1::multi_pairing_check`] with the group roles swapped.
    #[must_use]
    pub(crate) fn multi_pairing_check(p1: &[Self], p2: &[G1], t1: &Self, t2: &G1) -> bool {
        G1::multi_pairing_check(p2, p1, t2, t1)
    }
    /// Common MSM entry: drop zero points / zero scalars, pack the remaining
    /// scalars contiguously, and dispatch sequential vs. parallel Pippenger.
    fn msm_inner<'a>(
        iter: impl Iterator<Item = (&'a Self, &'a [u8])>,
        nbits: usize,
        strategy: &impl Strategy,
    ) -> Self {
        let nbytes = nbits.div_ceil(8);
        let (points_filtered, scalars_filtered): (Vec<_>, Vec<_>) = iter
            .filter_map(|(point, scalar)| {
                // Identity points and zero scalars contribute nothing.
                if *point == Self::zero() || all_zero(scalar).into() {
                    return None;
                }
                Some((point, scalar))
            })
            .unzip();
        if points_filtered.is_empty() {
            return Self::zero();
        }
        let npoints = points_filtered.len();
        let ncpus = strategy.parallelism_hint();
        let affine_points = Self::batch_to_affine(&points_filtered);
        // Pack the scalars into one contiguous buffer, `nbytes` per point.
        let scalar_bytes: Vec<u8> = scalars_filtered
            .iter()
            .flat_map(|s| s[..nbytes].iter().copied())
            .collect();
        if ncpus < 2 || npoints < MIN_PARALLEL_POINTS {
            return Self::msm_sequential(&affine_points, &scalar_bytes, nbits);
        }
        Self::msm_parallel(&affine_points, &scalar_bytes, nbits, ncpus, strategy)
    }
    /// Single-threaded Pippenger MSM via `blst_p2s_mult_pippenger`.
    fn msm_sequential(affine_points: &[blst_p2_affine], scalars: &[u8], nbits: usize) -> Self {
        let npoints = affine_points.len();
        let scratch_size = unsafe { blst_p2s_mult_pippenger_scratch_sizeof(npoints) };
        // Scratch is allocated as u64s for alignment.
        assert_eq!(scratch_size % 8, 0, "scratch_size must be multiple of 8");
        let mut scratch = vec![0u64; scratch_size / 8];
        // blst takes null-terminated pointer arrays for points and scalars.
        let p: [*const blst_p2_affine; 2] = [affine_points.as_ptr(), ptr::null()];
        let s: [*const u8; 2] = [scalars.as_ptr(), ptr::null()];
        let mut result = blst_p2::default();
        unsafe {
            blst_p2s_mult_pippenger(
                &mut result,
                p.as_ptr(),
                npoints,
                s.as_ptr(),
                nbits,
                scratch.as_mut_ptr(),
            );
        }
        Self::from_blst_p2(result)
    }
    /// Multi-threaded tiled Pippenger MSM built on [`msm_parallel_generic`].
    fn msm_parallel(
        affine_points: &[blst_p2_affine],
        scalars: &[u8],
        nbits: usize,
        ncpus: usize,
        strategy: &impl Strategy,
    ) -> Self {
        // Per-point scratch base size; each tile scales it by its window.
        let scratch_size = unsafe { blst_p2s_mult_pippenger_scratch_sizeof(0) } / 8;
        msm_parallel_generic(
            affine_points,
            scalars,
            nbits,
            ncpus,
            strategy,
            |points, scalars, tile, nbytes, nbits, window| {
                // 2^(window-1) buckets' worth of scratch for this tile.
                let mut scratch = vec![0u64; scratch_size << (window - 1)];
                let mut result = blst_p2::default();
                // Offset the point and scalar arrays to this tile's column.
                let p: [*const blst_p2_affine; 2] = [points[tile.x..].as_ptr(), ptr::null()];
                let s: [*const u8; 2] = [scalars[tile.x * nbytes..].as_ptr(), ptr::null()];
                unsafe {
                    blst_p2s_tile_pippenger(
                        &mut result,
                        p.as_ptr(),
                        tile.dx,
                        s.as_ptr(),
                        nbits,
                        scratch.as_mut_ptr(),
                        tile.y,
                        window,
                    );
                }
                result
            },
            |a, b| {
                let mut result = blst_p2::default();
                unsafe { blst_p2_add_or_double(&mut result, a, b) };
                result
            },
            |a| {
                let mut result = blst_p2::default();
                unsafe { blst_p2_double(&mut result, a) };
                result
            },
            Self::from_blst_p2,
        )
    }
}
impl Write for G2 {
    /// Serialize as the 96-byte compressed encoding.
    fn write(&self, buf: &mut impl BufMut) {
        let slice = self.as_slice();
        buf.put_slice(&slice);
    }
}
impl Read for G2 {
    type Cfg = ();
    /// Deserialize a compressed G2 point, rejecting malformed encodings,
    /// the identity, and points outside the prime-order subgroup.
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
        let bytes = <[u8; Self::SIZE]>::read(buf)?;
        let mut ret = blst_p2::default();
        unsafe {
            let mut affine = blst_p2_affine::default();
            match blst_p2_uncompress(&mut affine, bytes.as_ptr()) {
                BLST_ERROR::BLST_SUCCESS => {}
                BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G2", "Bad encoding")),
                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G2", "Not on curve")),
                BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G2", "Not in group")),
                BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G2", "Type mismatch")),
                BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G2", "Verify fail")),
                BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G2", "PK is Infinity")),
                BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G2", "Bad scalar")),
            }
            blst_p2_from_affine(&mut ret, &affine);
            // Explicitly reject the point at infinity and perform the
            // subgroup check (uncompress alone does not guarantee either).
            if blst_p2_is_inf(&ret) {
                return Err(Invalid("G2", "Infinity"));
            }
            if !blst_p2_in_g2(&ret) {
                return Err(Invalid("G2", "Outside G2"));
            }
        }
        Ok(Self(ret))
    }
}
impl FixedSize for G2 {
    const SIZE: usize = G2_ELEMENT_BYTE_LENGTH;
}
impl Hash for G2 {
    /// Hash the compressed encoding (consistent with `Eq`).
    fn hash<H: Hasher>(&self, state: &mut H) {
        let slice = self.as_slice();
        state.write(&slice);
    }
}
impl PartialOrd for G2 {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for G2 {
    /// Order by compressed byte encoding.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.as_slice().cmp(&other.as_slice())
    }
}
impl Debug for G2 {
    /// Hex of the compressed encoding (points are public, unlike scalars).
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "{}", hex(&self.as_slice()))
    }
}
impl Display for G2 {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "{}", hex(&self.as_slice()))
    }
}
impl Object for G2 {}
impl<'a> AddAssign<&'a Self> for G2 {
    /// Point addition in place; `add_or_double` also handles `self == rhs`.
    fn add_assign(&mut self, rhs: &'a Self) {
        let ptr = &raw mut self.0;
        unsafe {
            blst_p2_add_or_double(ptr, ptr, &rhs.0);
        }
    }
}
impl<'a> Add<&'a Self> for G2 {
    type Output = Self;
    fn add(mut self, rhs: &'a Self) -> Self::Output {
        self += rhs;
        self
    }
}
impl Neg for G2 {
    type Output = Self;
    fn neg(mut self) -> Self::Output {
        self.neg_in_place();
        self
    }
}
impl<'a> SubAssign<&'a Self> for G2 {
    /// Subtraction as addition of the negation (G2 is `Copy`).
    fn sub_assign(&mut self, rhs: &'a Self) {
        let mut rhs_cp = *rhs;
        rhs_cp.neg_in_place();
        *self += &rhs_cp;
    }
}
impl<'a> Sub<&'a Self> for G2 {
    type Output = Self;
    fn sub(mut self, rhs: &'a Self) -> Self::Output {
        self -= rhs;
        self
    }
}
impl Additive for G2 {
    /// Identity element: blst's default `blst_p2` is the point at infinity.
    fn zero() -> Self {
        Self(blst_p2::default())
    }
}
impl<'a> MulAssign<&'a Scalar> for G2 {
    /// Scalar multiplication over the full 255-bit scalar range.
    fn mul_assign(&mut self, rhs: &'a Scalar) {
        let mut scalar = blst_scalar::default();
        let ptr = &raw mut self.0;
        unsafe {
            blst_scalar_from_fr(&mut scalar, &rhs.0);
            blst_p2_mult(ptr, ptr, scalar.b.as_ptr(), SCALAR_BITS);
        }
    }
}
impl<'a> Mul<&'a Scalar> for G2 {
    type Output = Self;
    fn mul(mut self, rhs: &'a Scalar) -> Self::Output {
        self *= rhs;
        self
    }
}
impl<'a> MulAssign<&'a SmallScalar> for G2 {
    /// Faster multiplication: only 128 scalar bits are processed.
    fn mul_assign(&mut self, rhs: &'a SmallScalar) {
        let ptr = &raw mut self.0;
        unsafe {
            blst_p2_mult(ptr, ptr, rhs.inner.b.as_ptr(), SMALL_SCALAR_BITS);
        }
    }
}
impl<'a> Mul<&'a SmallScalar> for G2 {
    type Output = Self;
    fn mul(mut self, rhs: &'a SmallScalar) -> Self::Output {
        self *= rhs;
        self
    }
}
impl Space<Scalar> for G2 {
    /// Multi-scalar multiplication with full-width scalars.
    fn msm(points: &[Self], scalars: &[Scalar], strategy: &impl Strategy) -> Self {
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");
        // Keep the converted scalars alive for the duration of the borrow.
        let scalar_bytes: Vec<_> = scalars.iter().map(|s| s.as_blst_scalar()).collect();
        Self::msm_inner(
            points
                .iter()
                .zip(scalar_bytes.iter().map(|s| s.b.as_slice())),
            SCALAR_BITS,
            strategy,
        )
    }
}
impl Space<SmallScalar> for G2 {
    /// Multi-scalar multiplication with 128-bit scalars.
    fn msm(points: &[Self], scalars: &[SmallScalar], strategy: &impl Strategy) -> Self {
        assert_eq!(points.len(), scalars.len(), "mismatched lengths");
        Self::msm_inner(
            points.iter().zip(scalars.iter().map(|s| s.as_bytes())),
            SMALL_SCALAR_BITS,
            strategy,
        )
    }
}
impl CryptoGroup for G2 {
    type Scalar = Scalar;
    /// The standard BLS12-381 G2 generator.
    fn generator() -> Self {
        let mut ret = blst_p2::default();
        unsafe {
            blst_p2_from_affine(&mut ret, &BLS12_381_G2);
        }
        Self(ret)
    }
}
impl HashToGroup for G2 {
    /// Hash a message to a G2 point under the given DST (no aug string).
    fn hash_to_group(domain_separator: &[u8], message: &[u8]) -> Self {
        let mut out = blst_p2::default();
        unsafe {
            blst_hash_to_g2(
                &mut out,
                message.as_ptr(),
                message.len(),
                domain_separator.as_ptr(),
                domain_separator.len(),
                ptr::null(),
                0,
            );
        }
        Self(out)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bls12381::primitives::group::Scalar;
use commonware_codec::{DecodeExt, Encode};
use commonware_invariants::minifuzz;
use commonware_macros::test_group;
use commonware_math::algebra::{test_suites, Random};
use commonware_parallel::{Rayon, Sequential};
use commonware_utils::test_rng;
use std::{
collections::{BTreeSet, HashMap},
num::NonZeroUsize,
};
    // Generic algebra fuzz suites instantiated for Scalar, G1, and G2.
    #[test]
    fn test_scalar_as_field() {
        minifuzz::test(test_suites::fuzz_field::<Scalar>);
    }
    #[test]
    fn test_scalar_as_field_ntt() {
        minifuzz::test(test_suites::fuzz_field_ntt::<Scalar>);
    }
    #[test]
    fn test_g1_as_space() {
        minifuzz::test(test_suites::fuzz_space_ring::<Scalar, G1>);
    }
    #[test]
    fn test_g2_as_space() {
        minifuzz::test(test_suites::fuzz_space_ring::<Scalar, G2>);
    }
    #[test]
    fn test_hash_to_g1() {
        minifuzz::test(test_suites::fuzz_hash_to_group::<G1>);
    }
    #[test]
    fn test_hash_to_g2() {
        minifuzz::test(test_suites::fuzz_hash_to_group::<G2>);
    }
    // Scalar multiplication commutes with doubling: (2s)*G == 2*(s*G).
    #[test]
    fn basic_group() {
        let s = Scalar::random(&mut test_rng());
        let mut s2 = s.clone();
        s2.double();
        let p1 = G1::generator() * &s2;
        let mut p2 = G1::generator() * &s;
        p2.double();
        assert_eq!(p1, p2);
    }
    // Round-trip encode/decode for each serializable type.
    #[test]
    fn test_scalar_codec() {
        let original = Scalar::random(&mut test_rng());
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), Scalar::SIZE);
        let decoded = Scalar::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }
    #[test]
    fn test_g1_codec() {
        let original = G1::generator() * &Scalar::random(&mut test_rng());
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), G1::SIZE);
        let decoded = G1::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }
    #[test]
    fn test_g2_codec() {
        let original = G2::generator() * &Scalar::random(&mut test_rng());
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), G2::SIZE);
        let decoded = G2::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }
    // Reference MSM: a plain sum of per-point scalar multiplications,
    // skipping zero terms just like the optimized implementation.
    fn naive_msm<P: Space<Scalar>>(points: &[P], scalars: &[Scalar]) -> P {
        assert_eq!(points.len(), scalars.len());
        let mut total = P::zero();
        for (point, scalar) in points.iter().zip(scalars.iter()) {
            if *point == P::zero() || *scalar == Scalar::zero() {
                continue;
            }
            let term = point.clone() * scalar;
            total += &term;
        }
        total
    }
#[test]
fn test_g1_msm() {
let mut rng = test_rng();
let n = 10;
let points_g1: Vec<G1> = (0..n)
.map(|_| G1::generator() * &Scalar::random(&mut rng))
.collect();
let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::random(&mut rng)).collect();
let expected_g1 = naive_msm(&points_g1, &scalars);
let result_g1 = G1::msm(&points_g1, &scalars, &Sequential);
assert_eq!(expected_g1, result_g1, "G1 MSM basic case failed");
let mut points_with_zero_g1 = points_g1.clone();
points_with_zero_g1[n / 2] = G1::zero();
let expected_zero_pt_g1 = naive_msm(&points_with_zero_g1, &scalars);
let result_zero_pt_g1 = G1::msm(&points_with_zero_g1, &scalars, &Sequential);
assert_eq!(
expected_zero_pt_g1, result_zero_pt_g1,
"G1 MSM with identity point failed"
);
let mut scalars_with_zero = scalars.clone();
scalars_with_zero[n / 2] = Scalar::zero();
let expected_zero_sc_g1 = naive_msm(&points_g1, &scalars_with_zero);
let result_zero_sc_g1 = G1::msm(&points_g1, &scalars_with_zero, &Sequential);
assert_eq!(
expected_zero_sc_g1, result_zero_sc_g1,
"G1 MSM with zero scalar failed"
);
let zero_points_g1 = vec![G1::zero(); n];
let expected_all_zero_pt_g1 = naive_msm(&zero_points_g1, &scalars);
let result_all_zero_pt_g1 = G1::msm(&zero_points_g1, &scalars, &Sequential);
assert_eq!(
expected_all_zero_pt_g1,
G1::zero(),
"G1 MSM all identity points (naive) failed"
);
assert_eq!(
result_all_zero_pt_g1,
G1::zero(),
"G1 MSM all identity points failed"
);
let zero_scalars = vec![Scalar::zero(); n];
let expected_all_zero_sc_g1 = naive_msm(&points_g1, &zero_scalars);
let result_all_zero_sc_g1 = G1::msm(&points_g1, &zero_scalars, &Sequential);
assert_eq!(
expected_all_zero_sc_g1,
G1::zero(),
"G1 MSM all zero scalars (naive) failed"
);
assert_eq!(
result_all_zero_sc_g1,
G1::zero(),
"G1 MSM all zero scalars failed"
);
let single_point_g1 = [points_g1[0]];
let single_scalar = [scalars[0].clone()];
let expected_single_g1 = naive_msm(&single_point_g1, &single_scalar);
let result_single_g1 = G1::msm(&single_point_g1, &single_scalar, &Sequential);
assert_eq!(
expected_single_g1, result_single_g1,
"G1 MSM single element failed"
);
let empty_points_g1: [G1; 0] = [];
let empty_scalars: [Scalar; 0] = [];
let expected_empty_g1 = naive_msm(&empty_points_g1, &empty_scalars);
let result_empty_g1 = G1::msm(&empty_points_g1, &empty_scalars, &Sequential);
assert_eq!(expected_empty_g1, G1::zero(), "G1 MSM empty (naive) failed");
assert_eq!(result_empty_g1, G1::zero(), "G1 MSM empty failed");
let points_g1: Vec<G1> = (0..50_000)
.map(|_| G1::generator() * &Scalar::random(&mut rng))
.collect();
let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::random(&mut rng)).collect();
let expected_g1 = naive_msm(&points_g1, &scalars);
let result_g1 = G1::msm(&points_g1, &scalars, &Sequential);
assert_eq!(expected_g1, result_g1, "G1 MSM basic case failed");
}
    /// Compares the optimized G2 MSM against the naive reference across the
    /// basic case, identity points, zero scalars, singletons, empty inputs,
    /// and a large batch. Gated as "slow" because G2 arithmetic is costly.
    #[test_group("slow")]
    #[test]
    fn test_g2_msm() {
        let mut rng = test_rng();
        let n = 10;
        let points_g2: Vec<G2> = (0..n)
            .map(|_| G2::generator() * &Scalar::random(&mut rng))
            .collect();
        let scalars: Vec<Scalar> = (0..n).map(|_| Scalar::random(&mut rng)).collect();
        // Basic case: the optimized MSM must match the reference sum.
        let expected_g2 = naive_msm(&points_g2, &scalars);
        let result_g2 = G2::msm(&points_g2, &scalars, &Sequential);
        assert_eq!(expected_g2, result_g2, "G2 MSM basic case failed");
        // One identity point mixed into the inputs.
        let mut points_with_zero_g2 = points_g2.clone();
        points_with_zero_g2[n / 2] = G2::zero();
        let expected_zero_pt_g2 = naive_msm(&points_with_zero_g2, &scalars);
        let result_zero_pt_g2 = G2::msm(&points_with_zero_g2, &scalars, &Sequential);
        assert_eq!(
            expected_zero_pt_g2, result_zero_pt_g2,
            "G2 MSM with identity point failed"
        );
        // One zero scalar mixed into the inputs.
        let mut scalars_with_zero = scalars.clone();
        scalars_with_zero[n / 2] = Scalar::zero();
        let expected_zero_sc_g2 = naive_msm(&points_g2, &scalars_with_zero);
        let result_zero_sc_g2 = G2::msm(&points_g2, &scalars_with_zero, &Sequential);
        assert_eq!(
            expected_zero_sc_g2, result_zero_sc_g2,
            "G2 MSM with zero scalar failed"
        );
        // All points identity: both implementations must yield the identity.
        let zero_points_g2 = vec![G2::zero(); n];
        let expected_all_zero_pt_g2 = naive_msm(&zero_points_g2, &scalars);
        let result_all_zero_pt_g2 = G2::msm(&zero_points_g2, &scalars, &Sequential);
        assert_eq!(
            expected_all_zero_pt_g2,
            G2::zero(),
            "G2 MSM all identity points (naive) failed"
        );
        assert_eq!(
            result_all_zero_pt_g2,
            G2::zero(),
            "G2 MSM all identity points failed"
        );
        // All scalars zero: both implementations must yield the identity.
        let zero_scalars = vec![Scalar::zero(); n];
        let expected_all_zero_sc_g2 = naive_msm(&points_g2, &zero_scalars);
        let result_all_zero_sc_g2 = G2::msm(&points_g2, &zero_scalars, &Sequential);
        assert_eq!(
            expected_all_zero_sc_g2,
            G2::zero(),
            "G2 MSM all zero scalars (naive) failed"
        );
        assert_eq!(
            result_all_zero_sc_g2,
            G2::zero(),
            "G2 MSM all zero scalars failed"
        );
        // Single-element MSM.
        let single_point_g2 = [points_g2[0]];
        let single_scalar = [scalars[0].clone()];
        let expected_single_g2 = naive_msm(&single_point_g2, &single_scalar);
        let result_single_g2 = G2::msm(&single_point_g2, &single_scalar, &Sequential);
        assert_eq!(
            expected_single_g2, result_single_g2,
            "G2 MSM single element failed"
        );
        // Empty MSM is defined as the identity.
        let empty_points_g2: [G2; 0] = [];
        let empty_scalars: [Scalar; 0] = [];
        let expected_empty_g2 = naive_msm(&empty_points_g2, &empty_scalars);
        let result_empty_g2 = G2::msm(&empty_points_g2, &empty_scalars, &Sequential);
        assert_eq!(expected_empty_g2, G2::zero(), "G2 MSM empty (naive) failed");
        assert_eq!(result_empty_g2, G2::zero(), "G2 MSM empty failed");
        // Large batch exercising the multi-window pippenger path.
        let points_g2: Vec<G2> = (0..50_000)
            .map(|_| G2::generator() * &Scalar::random(&mut rng))
            .collect();
        let scalars: Vec<Scalar> = (0..50_000).map(|_| Scalar::random(&mut rng)).collect();
        let expected_g2 = naive_msm(&points_g2, &scalars);
        let result_g2 = G2::msm(&points_g2, &scalars, &Sequential);
        assert_eq!(expected_g2, result_g2, "G2 MSM basic case failed");
    }
#[test]
fn test_trait_implementations() {
let mut rng = test_rng();
const NUM_ITEMS: usize = 10;
let mut scalar_set = BTreeSet::new();
let mut g1_set = BTreeSet::new();
let mut g2_set = BTreeSet::new();
while scalar_set.len() < NUM_ITEMS {
let scalar = Scalar::random(&mut rng);
let g1 = G1::generator() * &scalar;
let g2 = G2::generator() * &scalar;
scalar_set.insert(scalar);
g1_set.insert(g1);
g2_set.insert(g2);
}
assert_eq!(scalar_set.len(), NUM_ITEMS);
assert_eq!(g1_set.len(), NUM_ITEMS);
assert_eq!(g2_set.len(), NUM_ITEMS);
let scalars: Vec<_> = scalar_set.iter().collect();
assert!(scalars.windows(2).all(|w| w[0] <= w[1]));
let g1s: Vec<_> = g1_set.iter().collect();
assert!(g1s.windows(2).all(|w| w[0] <= w[1]));
let g2s: Vec<_> = g2_set.iter().collect();
assert!(g2s.windows(2).all(|w| w[0] <= w[1]));
let scalar_map: HashMap<_, _> = scalar_set.iter().cloned().zip(0..).collect();
let g1_map: HashMap<_, _> = g1_set.iter().cloned().zip(0..).collect();
let g2_map: HashMap<_, _> = g2_set.iter().cloned().zip(0..).collect();
assert_eq!(scalar_map.len(), NUM_ITEMS);
assert_eq!(g1_map.len(), NUM_ITEMS);
assert_eq!(g2_map.len(), NUM_ITEMS);
}
#[test]
fn test_scalar_map() {
let msg = b"test message";
let dst = b"TEST_DST";
let scalar1 = Scalar::map(dst, msg);
let scalar2 = Scalar::map(dst, msg);
assert_eq!(scalar1, scalar2, "Same input should produce same output");
let msg2 = b"different message";
let scalar3 = Scalar::map(dst, msg2);
assert_ne!(
scalar1, scalar3,
"Different messages should produce different scalars"
);
let dst2 = b"DIFFERENT_DST";
let scalar4 = Scalar::map(dst2, msg);
assert_ne!(
scalar1, scalar4,
"Different DSTs should produce different scalars"
);
let empty_msg = b"";
let scalar_empty = Scalar::map(dst, empty_msg);
assert_ne!(
scalar_empty,
Scalar::zero(),
"Empty message should not produce zero"
);
let large_msg = vec![0x42u8; 1000];
let scalar_large = Scalar::map(dst, &large_msg);
assert_ne!(
scalar_large,
Scalar::zero(),
"Large message should not produce zero"
);
assert_ne!(
scalar1,
Scalar::zero(),
"Hash should not produce zero scalar"
);
}
#[test]
fn test_secret_scalar_equality() {
let mut rng = test_rng();
let scalar1 = Scalar::random(&mut rng);
let scalar2 = scalar1.clone();
let scalar3 = Scalar::random(&mut rng);
let s1 = Secret::new(scalar1);
let s2 = Secret::new(scalar2);
let s3 = Secret::new(scalar3);
assert_eq!(s1, s2);
assert_ne!(s1, s3);
}
#[test]
fn test_share_redacted() {
let mut rng = test_rng();
let share = Share::new(Participant::new(1), Private::random(&mut rng));
let debug = format!("{:?}", share);
let display = format!("{}", share);
assert!(debug.contains("REDACTED"));
assert!(display.contains("REDACTED"));
}
fn test_msm_parallel_impl<G>(points: Vec<G>, scalars: Vec<Scalar>)
where
G: Space<Scalar> + PartialEq + Debug + Copy,
{
let par = Rayon::new(NonZeroUsize::new(8).unwrap()).unwrap();
let seq = G::msm(&points, &scalars, &Sequential);
assert_eq!(seq, G::msm(&points, &scalars, &par));
}
fn test_msm_parallel_edge_cases_impl<G>(
points: Vec<G>,
scalars: Vec<Scalar>,
single_point: G,
single_scalar: Scalar,
idx: usize,
) where
G: Space<Scalar> + Additive + PartialEq + Debug + Copy,
for<'a> G: Mul<&'a Scalar, Output = G>,
{
let par = Rayon::new(NonZeroUsize::new(8).unwrap()).unwrap();
let n = points.len();
assert_eq!(G::msm(&points, &vec![Scalar::zero(); n], &par), G::zero());
assert_eq!(G::msm(&vec![G::zero(); n], &scalars, &par), G::zero());
let mut pts = vec![G::zero(); n];
let mut scalars = vec![Scalar::zero(); n];
pts[idx] = single_point;
scalars[idx] = single_scalar.clone();
assert_eq!(G::msm(&pts, &scalars, &par), single_point * &single_scalar);
}
#[test]
fn test_msm_parallel_g1() {
minifuzz::test(|u| {
let n = u.int_in_range(MIN_PARALLEL_POINTS..=100)?;
let points: Vec<G1> = (0..n)
.map(|_| u.arbitrary())
.collect::<arbitrary::Result<_>>()?;
let scalars: Vec<Scalar> = (0..n)
.map(|_| u.arbitrary())
.collect::<arbitrary::Result<_>>()?;
test_msm_parallel_impl(points, scalars);
Ok(())
});
}
#[test]
fn test_msm_parallel_g2() {
minifuzz::test(|u| {
let n = u.int_in_range(MIN_PARALLEL_POINTS..=100)?;
let points: Vec<G2> = (0..n)
.map(|_| u.arbitrary())
.collect::<arbitrary::Result<_>>()?;
let scalars: Vec<Scalar> = (0..n)
.map(|_| u.arbitrary())
.collect::<arbitrary::Result<_>>()?;
test_msm_parallel_impl(points, scalars);
Ok(())
});
}
#[test]
fn test_msm_parallel_edge_cases_g1() {
minifuzz::test(|u| {
let points: Vec<G1> = (0..50)
.map(|_| u.arbitrary())
.collect::<arbitrary::Result<_>>()?;
let scalars: Vec<Scalar> = (0..50)
.map(|_| u.arbitrary())
.collect::<arbitrary::Result<_>>()?;
let single_point: G1 = u.arbitrary()?;
let single_scalar: Scalar = u.arbitrary()?;
let idx: usize = u.int_in_range(0..=49)?;
test_msm_parallel_edge_cases_impl(points, scalars, single_point, single_scalar, idx);
Ok(())
});
}
#[test]
fn test_msm_parallel_edge_cases_g2() {
minifuzz::test(|u| {
let points: Vec<G2> = (0..50)
.map(|_| u.arbitrary())
.collect::<arbitrary::Result<_>>()?;
let scalars: Vec<Scalar> = (0..50)
.map(|_| u.arbitrary())
.collect::<arbitrary::Result<_>>()?;
let single_point: G2 = u.arbitrary()?;
let single_scalar: Scalar = u.arbitrary()?;
let idx: usize = u.int_in_range(0..=49)?;
test_msm_parallel_edge_cases_impl(points, scalars, single_point, single_scalar, idx);
Ok(())
});
}
#[test]
fn test_msm_breakdown_high_parallelism() {
for npoints in [32, 50, 100, 200] {
let window = pippenger_window_size(npoints);
for ncpus in [64, 128, 256, 512, 1024, 2048] {
let (nx, ny, final_wnd) = msm_breakdown(SCALAR_BITS, window, ncpus);
assert!(nx >= 1 && ny >= 1 && final_wnd >= 1);
}
}
}
#[cfg(feature = "arbitrary")]
mod conformance {
use super::*;
use commonware_codec::conformance::CodecConformance;
commonware_conformance::conformance_tests! {
CodecConformance<G1>,
CodecConformance<G2>,
CodecConformance<Private>,
CodecConformance<Scalar>,
CodecConformance<Share>
}
}
#[test]
fn test_small_scalar_to_scalar_preserves_bytes() {
let small = SmallScalar::random(test_rng());
let scalar = Scalar::from(small.clone());
let round_tripped = scalar.as_blst_scalar();
assert_eq!(small.as_bytes(), round_tripped.b.as_slice());
}
#[test]
fn test_ntt_constants() {
let root = Scalar::root_of_unity(32).unwrap();
let root_pow_2_32 = root.exp(&[1u64 << 32]);
assert_eq!(root_pow_2_32, Scalar::one(), "root^(2^32) should be 1");
let coset = Scalar::coset_shift();
let coset_inv = Scalar::coset_shift_inv();
let product = coset * &coset_inv;
assert_eq!(
product,
Scalar::one(),
"coset_shift * coset_shift_inv should be 1"
);
let two = Scalar::from_u64(2);
let half = Scalar::one().div_2();
assert_eq!(two * &half, Scalar::one(), "2 * (1/2) should be 1");
}
}