use blst::{
blst_bendian_from_fp, blst_final_exp, blst_fp, blst_fp_from_bendian, blst_fp2, blst_fp12,
blst_fp12_is_one, blst_fp12_mul, blst_map_to_g1, blst_map_to_g2, blst_miller_loop, blst_p1,
blst_p1_add_or_double_affine, blst_p1_affine, blst_p1_affine_in_g1, blst_p1_affine_is_inf,
blst_p1_affine_on_curve, blst_p1_from_affine, blst_p1_mult, blst_p1_to_affine,
blst_p1s_mult_pippenger, blst_p1s_mult_pippenger_scratch_sizeof, blst_p2,
blst_p2_add_or_double_affine, blst_p2_affine, blst_p2_affine_in_g2, blst_p2_affine_is_inf,
blst_p2_affine_on_curve, blst_p2_from_affine, blst_p2_mult, blst_p2_to_affine,
blst_p2s_mult_pippenger, blst_p2s_mult_pippenger_scratch_sizeof, blst_scalar,
blst_scalar_from_be_bytes,
};
use crate::provider::CryptoError;
/// Byte length of one serialized base-field coordinate.
const FP_LENGTH: usize = 48;
/// Scalar width, in bits, handed to the blst Pippenger multi-scalar multiplication calls.
const SCALAR_BITS: usize = 256;
/// Big-endian byte string used by `read_fp` as the exclusive upper bound for a coordinate;
/// `read_fp`'s error message refers to it as the field modulus.
const MODULUS_REPR: [u8; FP_LENGTH] = [
0x1a, 0x01, 0x11, 0xea, 0x39, 0x7f, 0xe6, 0x9a, 0x4b, 0x1b, 0xa7, 0xb6, 0x43, 0x4b, 0xac, 0xd7,
0x64, 0x77, 0x4b, 0x84, 0xf3, 0x85, 0x12, 0xbf, 0x67, 0x30, 0xd2, 0xa0, 0xf6, 0xb0, 0xf6, 0x24,
0x1e, 0xab, 0xff, 0xfe, 0xb1, 0x53, 0xff, 0xff, 0xb9, 0xfe, 0xff, 0xff, 0xff, 0xff, 0xaa, 0xab,
];
/// Converts a `blst_p1` point into its `blst_p1_affine` representation.
#[inline]
fn p1_to_affine(p: &blst_p1) -> blst_p1_affine {
    let mut affine = blst_p1_affine::default();
    // SAFETY: both pointers come from live Rust references that outlive the call.
    unsafe {
        blst_p1_to_affine(&mut affine, p);
    }
    affine
}
/// Lifts a `blst_p1_affine` point into the non-affine `blst_p1` representation.
#[inline]
fn p1_from_affine(p: &blst_p1_affine) -> blst_p1 {
    let mut lifted = blst_p1::default();
    // SAFETY: both pointers come from live Rust references that outlive the call.
    unsafe {
        blst_p1_from_affine(&mut lifted, p);
    }
    lifted
}
/// Converts a `blst_p2` point into its `blst_p2_affine` representation.
#[inline]
fn p2_to_affine(p: &blst_p2) -> blst_p2_affine {
    let mut affine = blst_p2_affine::default();
    // SAFETY: both pointers come from live Rust references that outlive the call.
    unsafe {
        blst_p2_to_affine(&mut affine, p);
    }
    affine
}
/// Lifts a `blst_p2_affine` point into the non-affine `blst_p2` representation.
#[inline]
fn p2_from_affine(p: &blst_p2_affine) -> blst_p2 {
    let mut lifted = blst_p2::default();
    // SAFETY: both pointers come from live Rust references that outlive the call.
    unsafe {
        blst_p2_from_affine(&mut lifted, p);
    }
    lifted
}
/// Adds two affine G1 points and returns the affine result.
///
/// `a` is first lifted to `blst_p1`; the sum is produced by
/// `blst_p1_add_or_double_affine` and converted back to affine form.
#[inline]
fn p1_add_affine(a: &blst_p1_affine, b: &blst_p1_affine) -> blst_p1_affine {
    let lhs = p1_from_affine(a);
    let mut total = blst_p1::default();
    // SAFETY: `total` and `lhs` are live locals and `b` is a valid reference.
    unsafe {
        blst_p1_add_or_double_affine(&mut total, &lhs, b);
    }
    p1_to_affine(&total)
}
/// Adds two affine G2 points and returns the affine result.
///
/// `a` is first lifted to `blst_p2`; the sum is produced by
/// `blst_p2_add_or_double_affine` and converted back to affine form.
#[inline]
fn p2_add_affine(a: &blst_p2_affine, b: &blst_p2_affine) -> blst_p2_affine {
    let lhs = p2_from_affine(a);
    let mut total = blst_p2::default();
    // SAFETY: `total` and `lhs` are live locals and `b` is a valid reference.
    unsafe {
        blst_p2_add_or_double_affine(&mut total, &lhs, b);
    }
    p2_to_affine(&total)
}
/// Multiplies an affine G1 point by `scalar` and returns the affine product.
///
/// The full width of the scalar's byte buffer (`scalar.b.len() * 8` bits) is passed to blst.
#[inline]
fn p1_scalar_mul(p: &blst_p1_affine, scalar: &blst_scalar) -> blst_p1_affine {
    let base = p1_from_affine(p);
    let bit_len = scalar.b.len() * 8;
    let mut product = blst_p1::default();
    // SAFETY: `scalar.b` is a live buffer of exactly `bit_len / 8` bytes; the point
    // pointers come from live locals.
    unsafe {
        blst_p1_mult(&mut product, &base, scalar.b.as_ptr(), bit_len);
    }
    p1_to_affine(&product)
}
/// Multiplies an affine G2 point by `scalar` and returns the affine product.
///
/// The full width of the scalar's byte buffer (`scalar.b.len() * 8` bits) is passed to blst.
#[inline]
fn p2_scalar_mul(p: &blst_p2_affine, scalar: &blst_scalar) -> blst_p2_affine {
    let base = p2_from_affine(p);
    let bit_len = scalar.b.len() * 8;
    let mut product = blst_p2::default();
    // SAFETY: `scalar.b` is a live buffer of exactly `bit_len / 8` bytes; the point
    // pointers come from live locals.
    unsafe {
        blst_p2_mult(&mut product, &base, scalar.b.as_ptr(), bit_len);
    }
    p2_to_affine(&product)
}
/// Runs blst's Pippenger multi-scalar multiplication over `points` and `scalars` in G1.
///
/// `scalars` is a flat byte buffer read at `SCALAR_BITS` bits per point; this function
/// does not check its length, so the caller must supply `SCALAR_BITS / 8` bytes per point.
#[inline]
fn p1_msm(points: &[blst_p1_affine], scalars: &[u8]) -> blst_p1 {
    let count = points.len();
    // Two-entry pointer tables: the data pointer followed by a null entry.
    // NOTE(review): presumably the null entry tells blst the data is one contiguous
    // array — confirm against the blst documentation.
    let point_table: [*const blst_p1_affine; 2] = [points.as_ptr(), core::ptr::null()];
    let scalar_table: [*const u8; 2] = [scalars.as_ptr(), core::ptr::null()];

    // SAFETY: only an integer is passed by value.
    let scratch_bytes = unsafe { blst_p1s_mult_pippenger_scratch_sizeof(count) };
    // The size is reported in bytes; round up to whole u64 words.
    let mut scratch = vec![0u64; scratch_bytes.div_ceil(8)];

    let mut acc = blst_p1::default();
    // SAFETY: the tables, `acc` and `scratch` are live for the whole call; `scratch` holds
    // at least `scratch_bytes` bytes. Relies on `scalars` being long enough (see above).
    unsafe {
        blst_p1s_mult_pippenger(
            &mut acc,
            point_table.as_ptr(),
            count,
            scalar_table.as_ptr(),
            SCALAR_BITS,
            scratch.as_mut_ptr(),
        );
    }
    acc
}
/// Runs blst's Pippenger multi-scalar multiplication over `points` and `scalars` in G2.
///
/// `scalars` is a flat byte buffer read at `SCALAR_BITS` bits per point; this function
/// does not check its length, so the caller must supply `SCALAR_BITS / 8` bytes per point.
#[inline]
fn p2_msm(points: &[blst_p2_affine], scalars: &[u8]) -> blst_p2 {
    let count = points.len();
    // Two-entry pointer tables: the data pointer followed by a null entry.
    // NOTE(review): presumably the null entry tells blst the data is one contiguous
    // array — confirm against the blst documentation.
    let point_table: [*const blst_p2_affine; 2] = [points.as_ptr(), core::ptr::null()];
    let scalar_table: [*const u8; 2] = [scalars.as_ptr(), core::ptr::null()];

    // SAFETY: only an integer is passed by value.
    let scratch_bytes = unsafe { blst_p2s_mult_pippenger_scratch_sizeof(count) };
    // The size is reported in bytes; round up to whole u64 words.
    let mut scratch = vec![0u64; scratch_bytes.div_ceil(8)];

    let mut acc = blst_p2::default();
    // SAFETY: the tables, `acc` and `scratch` are live for the whole call; `scratch` holds
    // at least `scratch_bytes` bytes. Relies on `scalars` being long enough (see above).
    unsafe {
        blst_p2s_mult_pippenger(
            &mut acc,
            point_table.as_ptr(),
            count,
            scalar_table.as_ptr(),
            SCALAR_BITS,
            scratch.as_mut_ptr(),
        );
    }
    acc
}
/// Parses a big-endian field element, rejecting encodings that are not below `MODULUS_REPR`.
///
/// # Errors
/// Returns `CryptoError::InvalidInput` when `input >= MODULUS_REPR`.
#[inline]
fn read_fp(input: &[u8; FP_LENGTH]) -> Result<blst_fp, CryptoError> {
    // Equal-length big-endian byte arrays compare lexicographically in numeric order.
    if *input < MODULUS_REPR {
        let mut element = blst_fp::default();
        // SAFETY: `input` points at exactly FP_LENGTH readable bytes and `element` is a live local.
        unsafe {
            blst_fp_from_bendian(&mut element, input.as_ptr());
        }
        Ok(element)
    } else {
        Err(CryptoError::InvalidInput("fp coordinate >= field modulus"))
    }
}
/// Parses two field-element encodings into a `blst_fp2`, with `c0` stored at index 0 and
/// `c1` at index 1.
///
/// # Errors
/// Propagates the `read_fp` error of the first component that fails (`c0` is parsed first).
#[inline]
fn read_fp2(c0: &[u8; FP_LENGTH], c1: &[u8; FP_LENGTH]) -> Result<blst_fp2, CryptoError> {
    let first = read_fp(c0)?;
    let second = read_fp(c1)?;
    Ok(blst_fp2 {
        fp: [first, second],
    })
}
/// Decodes an affine G1 point from its `x`/`y` encodings and checks that blst accepts it
/// as on-curve. No subgroup check is performed here.
///
/// # Errors
/// * `CryptoError::InvalidInput` (via `read_fp`) for an out-of-range coordinate.
/// * `CryptoError::InvalidPoint` when `blst_p1_affine_on_curve` rejects the point.
fn decode_g1_on_curve(x: &[u8; 48], y: &[u8; 48]) -> Result<blst_p1_affine, CryptoError> {
    let x = read_fp(x)?;
    let y = read_fp(y)?;
    let candidate = blst_p1_affine { x, y };
    // SAFETY: `candidate` is a fully initialised local.
    let on_curve = unsafe { blst_p1_affine_on_curve(&candidate) };
    if on_curve {
        Ok(candidate)
    } else {
        Err(CryptoError::InvalidPoint("G1 point not on curve"))
    }
}
/// Decodes an affine G1 point and additionally requires `blst_p1_affine_in_g1` to accept it.
///
/// # Errors
/// Any error from `decode_g1_on_curve`, or `CryptoError::InvalidPoint` when the subgroup
/// check fails.
fn read_g1_subgroup(x: &[u8; 48], y: &[u8; 48]) -> Result<blst_p1_affine, CryptoError> {
    let candidate = decode_g1_on_curve(x, y)?;
    // SAFETY: `candidate` is a fully initialised local.
    let in_subgroup = unsafe { blst_p1_affine_in_g1(&candidate) };
    if in_subgroup {
        Ok(candidate)
    } else {
        Err(CryptoError::InvalidPoint("G1 point not in subgroup"))
    }
}
/// Decodes an affine G2 point from four coordinate encodings (`x0`, `x1`, `y0`, `y1`) and
/// checks that blst accepts it as on-curve. No subgroup check is performed here.
///
/// # Errors
/// * `CryptoError::InvalidInput` (via `read_fp2`) for an out-of-range coordinate.
/// * `CryptoError::InvalidPoint` when `blst_p2_affine_on_curve` rejects the point.
fn decode_g2_on_curve(
    x0: &[u8; 48],
    x1: &[u8; 48],
    y0: &[u8; 48],
    y1: &[u8; 48],
) -> Result<blst_p2_affine, CryptoError> {
    let x = read_fp2(x0, x1)?;
    let y = read_fp2(y0, y1)?;
    let candidate = blst_p2_affine { x, y };
    // SAFETY: `candidate` is a fully initialised local.
    let on_curve = unsafe { blst_p2_affine_on_curve(&candidate) };
    if on_curve {
        Ok(candidate)
    } else {
        Err(CryptoError::InvalidPoint("G2 point not on curve"))
    }
}
/// Decodes an affine G2 point and additionally requires `blst_p2_affine_in_g2` to accept it.
///
/// # Errors
/// Any error from `decode_g2_on_curve`, or `CryptoError::InvalidPoint` when the subgroup
/// check fails.
fn read_g2_subgroup(
    x0: &[u8; 48],
    x1: &[u8; 48],
    y0: &[u8; 48],
    y1: &[u8; 48],
) -> Result<blst_p2_affine, CryptoError> {
    let candidate = decode_g2_on_curve(x0, x1, y0, y1)?;
    // SAFETY: `candidate` is a fully initialised local.
    let in_subgroup = unsafe { blst_p2_affine_in_g2(&candidate) };
    if in_subgroup {
        Ok(candidate)
    } else {
        Err(CryptoError::InvalidPoint("G2 point not in subgroup"))
    }
}
#[cfg(test)]
use blst::blst_scalar_from_bendian;
/// Test-only helper: loads 32 big-endian bytes into a `blst_scalar` via
/// `blst_scalar_from_bendian`.
#[cfg(test)]
#[inline]
fn read_scalar(bytes: &[u8; 32]) -> blst_scalar {
    let mut scalar = blst_scalar::default();
    // SAFETY: `bytes` points at exactly 32 readable bytes and `scalar` is a live local.
    unsafe {
        blst_scalar_from_bendian(&mut scalar, bytes.as_ptr());
    }
    scalar
}
/// Loads 32 big-endian bytes through `blst_scalar_from_be_bytes`.
///
/// Returns `Some(scalar)` when blst returns `true` and `None` otherwise; the callers in
/// this file treat `None` as a zero scalar and skip the term.
#[inline]
fn read_scalar_mod_r(bytes: &[u8; 32]) -> Option<blst_scalar> {
    let mut scalar = blst_scalar::default();
    // SAFETY: `bytes` points at exactly 32 readable bytes, matching the length argument.
    let is_nonzero = unsafe { blst_scalar_from_be_bytes(&mut scalar, bytes.as_ptr(), 32) };
    if is_nonzero {
        Some(scalar)
    } else {
        None
    }
}
/// Writes `fp` into `out` as big-endian bytes using `blst_bendian_from_fp`.
#[inline]
fn fp_to_bytes(out: &mut [u8; FP_LENGTH], fp: &blst_fp) {
    let dst = out.as_mut_ptr();
    // SAFETY: `dst` points at FP_LENGTH writable bytes borrowed exclusively for this call.
    unsafe {
        blst_bendian_from_fp(dst, fp);
    }
}
/// Serializes an affine G1 point as 96 bytes: the `x` coordinate followed by `y`, each
/// written by `fp_to_bytes`.
fn encode_g1(point: &blst_p1_affine) -> [u8; 96] {
    let mut encoded = [0u8; 96];
    let coords = [&point.x, &point.y];
    for (chunk, coord) in encoded.chunks_exact_mut(FP_LENGTH).zip(coords) {
        // Each chunk is exactly FP_LENGTH bytes, so the conversion cannot fail.
        fp_to_bytes(chunk.try_into().expect("48 bytes"), coord);
    }
    encoded
}
/// Serializes an affine G2 point as 192 bytes in the order `x.fp[0]`, `x.fp[1]`,
/// `y.fp[0]`, `y.fp[1]`, each written by `fp_to_bytes`.
fn encode_g2(point: &blst_p2_affine) -> [u8; 192] {
    let mut encoded = [0u8; 192];
    let coords = [
        &point.x.fp[0],
        &point.x.fp[1],
        &point.y.fp[0],
        &point.y.fp[1],
    ];
    for (chunk, coord) in encoded.chunks_exact_mut(48).zip(coords) {
        // Each chunk is exactly 48 bytes, so the conversion cannot fail.
        fp_to_bytes(chunk.try_into().expect("48 bytes"), coord);
    }
    encoded
}
/// Returns `true` when every byte of `bytes` is zero (including the empty slice).
#[inline]
fn is_zero(bytes: &[u8]) -> bool {
    !bytes.iter().any(|&byte| byte != 0)
}
/// Adds two G1 points given as `(x, y)` coordinate encodings and returns the 96-byte
/// encoding of the sum.
///
/// Both inputs are validated with `decode_g1_on_curve` (on-curve only; no subgroup check).
///
/// # Errors
/// Propagates the decode error of the first operand that fails (`a` before `b`).
pub fn g1_add(a: ([u8; 48], [u8; 48]), b: ([u8; 48], [u8; 48])) -> Result<[u8; 96], CryptoError> {
    let (ax, ay) = a;
    let (bx, by) = b;
    let lhs = decode_g1_on_curve(&ax, &ay)?;
    let rhs = decode_g1_on_curve(&bx, &by)?;
    let sum = p1_add_affine(&lhs, &rhs);
    Ok(encode_g1(&sum))
}
/// Adds two G2 points given as `(x0, x1, y0, y1)` coordinate encodings and returns the
/// 192-byte encoding of the sum.
///
/// Both inputs are validated with `decode_g2_on_curve` (on-curve only; no subgroup check).
///
/// # Errors
/// Propagates the decode error of the first operand that fails (`a` before `b`).
pub fn g2_add(
    a: ([u8; 48], [u8; 48], [u8; 48], [u8; 48]),
    b: ([u8; 48], [u8; 48], [u8; 48], [u8; 48]),
) -> Result<[u8; 192], CryptoError> {
    let (ax0, ax1, ay0, ay1) = a;
    let (bx0, bx1, by0, by1) = b;
    let lhs = decode_g2_on_curve(&ax0, &ax1, &ay0, &ay1)?;
    let rhs = decode_g2_on_curve(&bx0, &bx1, &by0, &by1)?;
    let sum = p2_add_affine(&lhs, &rhs);
    Ok(encode_g2(&sum))
}
/// Computes the G1 multi-scalar multiplication of `(point, scalar)` pairs and returns the
/// 96-byte encoding of the result.
///
/// Per pair, in order: the point is decoded and checked on-curve; points blst reports as
/// infinity are skipped; remaining points must pass the subgroup check; pairs whose
/// scalar `read_scalar_mod_r` rejects (zero) are skipped. With no surviving term the
/// all-zero encoding is returned; a single term uses a plain scalar multiplication.
///
/// # Errors
/// The first decode failure, or `CryptoError::InvalidPoint` for a point outside the subgroup.
#[allow(clippy::type_complexity)]
pub fn g1_msm(pairs: &[(([u8; 48], [u8; 48]), [u8; 32])]) -> Result<[u8; 96], CryptoError> {
    let mut points = Vec::with_capacity(pairs.len());
    let mut scalars = Vec::with_capacity(pairs.len());

    for ((x, y), scalar_bytes) in pairs {
        let point = decode_g1_on_curve(x, y)?;
        // SAFETY: `point` is a fully initialised local.
        let is_infinity = unsafe { blst_p1_affine_is_inf(&point) };
        if is_infinity {
            continue;
        }
        // SAFETY: as above.
        let in_subgroup = unsafe { blst_p1_affine_in_g1(&point) };
        if !in_subgroup {
            return Err(CryptoError::InvalidPoint("G1 point not in subgroup"));
        }
        // The scalar is only examined after the point has been fully validated.
        if let Some(scalar) = read_scalar_mod_r(scalar_bytes) {
            points.push(point);
            scalars.push(scalar);
        }
    }

    match points.as_slice() {
        [] => Ok([0u8; 96]),
        [only] => Ok(encode_g1(&p1_scalar_mul(only, &scalars[0]))),
        _ => {
            // Flatten the scalars into one contiguous byte buffer for the Pippenger call.
            let mut flat: Vec<u8> = Vec::with_capacity(scalars.len() * 32);
            for scalar in &scalars {
                flat.extend_from_slice(&scalar.b);
            }
            let sum = p1_msm(&points, &flat);
            Ok(encode_g1(&p1_to_affine(&sum)))
        }
    }
}
/// Computes the G2 multi-scalar multiplication of `(point, scalar)` pairs and returns the
/// 192-byte encoding of the result.
///
/// Per pair, in order: the point is decoded and checked on-curve; points blst reports as
/// infinity are skipped; remaining points must pass the subgroup check; pairs whose
/// scalar `read_scalar_mod_r` rejects (zero) are skipped. With no surviving term the
/// all-zero encoding is returned; a single term uses a plain scalar multiplication.
///
/// # Errors
/// The first decode failure, or `CryptoError::InvalidPoint` for a point outside the subgroup.
#[allow(clippy::type_complexity)]
pub fn g2_msm(
    pairs: &[(([u8; 48], [u8; 48], [u8; 48], [u8; 48]), [u8; 32])],
) -> Result<[u8; 192], CryptoError> {
    let mut points = Vec::with_capacity(pairs.len());
    let mut scalars = Vec::with_capacity(pairs.len());

    for ((x0, x1, y0, y1), scalar_bytes) in pairs {
        let point = decode_g2_on_curve(x0, x1, y0, y1)?;
        // SAFETY: `point` is a fully initialised local.
        let is_infinity = unsafe { blst_p2_affine_is_inf(&point) };
        if is_infinity {
            continue;
        }
        // SAFETY: as above.
        let in_subgroup = unsafe { blst_p2_affine_in_g2(&point) };
        if !in_subgroup {
            return Err(CryptoError::InvalidPoint("G2 point not in subgroup"));
        }
        // The scalar is only examined after the point has been fully validated.
        if let Some(scalar) = read_scalar_mod_r(scalar_bytes) {
            points.push(point);
            scalars.push(scalar);
        }
    }

    match points.as_slice() {
        [] => Ok([0u8; 192]),
        [only] => Ok(encode_g2(&p2_scalar_mul(only, &scalars[0]))),
        _ => {
            // Flatten the scalars into one contiguous byte buffer for the Pippenger call.
            let mut flat: Vec<u8> = Vec::with_capacity(scalars.len() * 32);
            for scalar in &scalars {
                flat.extend_from_slice(&scalar.b);
            }
            let sum = p2_msm(&points, &flat);
            Ok(encode_g2(&p2_to_affine(&sum)))
        }
    }
}
/// Checks whether the product of pairings over all `(G1, G2)` pairs equals one.
///
/// Every point that is not encoded as all zeros is decoded with the subgroup-checking
/// readers. A pair in which either side is all zeros is validated and then left out of the
/// product. When no pair remains the result is `Ok(true)`. Otherwise the Miller-loop
/// outputs are multiplied together, one final exponentiation is applied, and the result is
/// compared against one.
///
/// # Errors
/// The first decode or subgroup failure, with the G1 side of a pair checked before its G2 side.
#[allow(clippy::type_complexity)]
pub fn pairing_check(
    pairs: &[(
        ([u8; 48], [u8; 48]),
        ([u8; 48], [u8; 48], [u8; 48], [u8; 48]),
    )],
) -> Result<bool, CryptoError> {
    let mut parsed = Vec::with_capacity(pairs.len());
    for ((g1x, g1y), (g2x0, g2x1, g2y0, g2y1)) in pairs {
        let g1_inf = is_zero(g1x) && is_zero(g1y);
        let g2_inf = is_zero(g2x0) && is_zero(g2x1) && is_zero(g2y0) && is_zero(g2y1);
        match (g1_inf, g2_inf) {
            // Both sides are all zeros: nothing to validate, nothing to pair.
            (true, true) => {}
            // Only one side is all zeros: validate the other side, then drop the pair.
            (false, true) => {
                read_g1_subgroup(g1x, g1y)?;
            }
            (true, false) => {
                read_g2_subgroup(g2x0, g2x1, g2y0, g2y1)?;
            }
            (false, false) => {
                let g1 = read_g1_subgroup(g1x, g1y)?;
                let g2 = read_g2_subgroup(g2x0, g2x1, g2y0, g2y1)?;
                parsed.push((g1, g2));
            }
        }
    }

    let Some(((head_g1, head_g2), rest)) = parsed.split_first() else {
        return Ok(true);
    };

    let mut acc = blst_fp12::default();
    // SAFETY: all pointers come from live, initialised values.
    unsafe {
        blst_miller_loop(&mut acc, head_g2, head_g1);
    }
    for (g1, g2) in rest {
        let mut term = blst_fp12::default();
        let mut product = blst_fp12::default();
        // SAFETY: as above; `product` is a distinct output buffer from both inputs.
        unsafe {
            blst_miller_loop(&mut term, g2, g1);
            blst_fp12_mul(&mut product, &acc, &term);
        }
        acc = product;
    }

    let mut exponentiated = blst_fp12::default();
    // SAFETY: `exponentiated` and `acc` are distinct live locals.
    unsafe {
        blst_final_exp(&mut exponentiated, &acc);
    }
    // SAFETY: `exponentiated` is fully initialised.
    let is_one = unsafe { blst_fp12_is_one(&exponentiated) };
    Ok(is_one)
}
/// Maps a field-element encoding to a G1 point via `blst_map_to_g1` and returns the
/// 96-byte encoding of the result.
///
/// # Errors
/// Returns the `read_fp` error when the input is out of range.
pub fn fp_to_g1(fp: &[u8; 48]) -> Result<[u8; 96], CryptoError> {
    let element = read_fp(fp)?;
    let mut mapped = blst_p1::default();
    // SAFETY: `mapped` and `element` are live locals; the third argument is passed as null.
    // NOTE(review): null presumably means "no second field element" — confirm in blst docs.
    unsafe {
        blst_map_to_g1(&mut mapped, &element, core::ptr::null());
    }
    let affine = p1_to_affine(&mapped);
    Ok(encode_g1(&affine))
}
/// Maps a two-component field-element encoding to a G2 point via `blst_map_to_g2` and
/// returns the 192-byte encoding of the result.
///
/// # Errors
/// Returns the `read_fp2` error when either component is out of range.
pub fn fp2_to_g2(fp2: ([u8; 48], [u8; 48])) -> Result<[u8; 192], CryptoError> {
    let (c0, c1) = fp2;
    let element = read_fp2(&c0, &c1)?;
    let mut mapped = blst_p2::default();
    // SAFETY: `mapped` and `element` are live locals; the third argument is passed as null.
    // NOTE(review): null presumably means "no second field element" — confirm in blst docs.
    unsafe {
        blst_map_to_g2(&mut mapped, &element, core::ptr::null());
    }
    let affine = p2_to_affine(&mapped);
    Ok(encode_g2(&affine))
}
#[cfg(test)]
mod tests {
    use super::*;
    use blst::{blst_p1_generator, blst_p2_generator};

    /// Builds a 32-byte big-endian scalar from `seed`; the low bit is forced on so the
    /// value is never zero.
    fn scalar_bytes(seed: u64) -> [u8; 32] {
        let odd = seed | 1;
        let mut bytes = [0u8; 32];
        bytes[24..].copy_from_slice(&odd.to_be_bytes());
        bytes
    }

    /// Returns the encoded `(x, y)` of the G1 generator multiplied by `scalar_bytes(seed)`.
    fn g1_xy(seed: u64) -> ([u8; 48], [u8; 48]) {
        // SAFETY: the test relies on blst returning a valid pointer to its G1 generator.
        let generator = p1_to_affine(unsafe { &*blst_p1_generator() });
        let scalar = read_scalar(&scalar_bytes(seed));
        let encoded = encode_g1(&p1_scalar_mul(&generator, &scalar));
        let (x, y) = encoded.split_at(48);
        (x.try_into().unwrap(), y.try_into().unwrap())
    }

    /// Returns the encoded `(x0, x1, y0, y1)` of the G2 generator multiplied by
    /// `scalar_bytes(seed)`.
    fn g2_xy(seed: u64) -> ([u8; 48], [u8; 48], [u8; 48], [u8; 48]) {
        // SAFETY: the test relies on blst returning a valid pointer to its G2 generator.
        let generator = p2_to_affine(unsafe { &*blst_p2_generator() });
        let scalar = read_scalar(&scalar_bytes(seed));
        let encoded = encode_g2(&p2_scalar_mul(&generator, &scalar));
        let coord = |index: usize| -> [u8; 48] {
            encoded[index * 48..(index + 1) * 48].try_into().unwrap()
        };
        (coord(0), coord(1), coord(2), coord(3))
    }

    #[test]
    fn g1_msm_skips_infinity() {
        let point = g1_xy(3);
        let infinity = ([0u8; 48], [0u8; 48]);

        let mixed = g1_msm(&[(point, scalar_bytes(5)), (infinity, scalar_bytes(9))]).unwrap();
        let alone = g1_msm(&[(point, scalar_bytes(5))]).unwrap();
        assert_eq!(mixed, alone, "infinity term must contribute nothing");

        let only_infinity = g1_msm(&[(infinity, scalar_bytes(9))]).unwrap();
        assert_eq!(only_infinity, [0u8; 96]);
    }

    #[test]
    fn g2_msm_skips_infinity() {
        let point = g2_xy(3);
        let infinity = ([0u8; 48], [0u8; 48], [0u8; 48], [0u8; 48]);

        let mixed = g2_msm(&[(point, scalar_bytes(5)), (infinity, scalar_bytes(9))]).unwrap();
        let alone = g2_msm(&[(point, scalar_bytes(5))]).unwrap();
        assert_eq!(mixed, alone, "infinity term must contribute nothing");

        let only_infinity = g2_msm(&[(infinity, scalar_bytes(9))]).unwrap();
        assert_eq!(only_infinity, [0u8; 192]);
    }

    #[test]
    fn g1_msm_point_times_order_is_identity_then_infinity() {
        // Scalar `r`; the first assertion's message states that `P * r` is the identity.
        const R: [u8; 32] = [
            0x73, 0xed, 0xa7, 0x53, 0x29, 0x9d, 0x7d, 0x48, 0x33, 0x39, 0xd8, 0x08, 0x09, 0xa1,
            0xd8, 0x05, 0x53, 0xbd, 0xa4, 0x02, 0xff, 0xfe, 0x5b, 0xfe, 0xff, 0xff, 0xff, 0xff,
            0x00, 0x00, 0x00, 0x01,
        ];

        let product = g1_msm(&[(g1_xy(3), R)]).unwrap();
        assert_eq!(product, [0u8; 96], "P * r == identity");

        let infinity = ([0u8; 48], [0u8; 48]);
        let from_infinity = g1_msm(&[(infinity, R)]).unwrap();
        assert_eq!(from_infinity, [0u8; 96]);
    }
}