use super::{G1Point, G2Point, PairingPair};
use crate::{
bls12_381_const::{FP_LENGTH, G1_LENGTH, G2_LENGTH, SCALAR_LENGTH},
PrecompileHalt,
};
use ark_bls12_381::{Bls12_381, Fq, Fq2, Fr, G1Affine, G1Projective, G2Affine, G2Projective};
use ark_ec::{
hashing::{curve_maps::wb::WBMap, map_to_curve_hasher::MapToCurve},
pairing::Pairing,
AffineRepr, CurveGroup, VariableBaseMSM,
};
use ark_ff::{One, PrimeField, Zero};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use std::vec::Vec;
/// Decodes a big-endian, `FP_LENGTH`-byte buffer into a base-field element `Fq`.
///
/// The bytes are reversed before being handed to `deserialize_uncompressed`,
/// which consumes them in the opposite (little-endian) order.
///
/// # Errors
/// Returns [`PrecompileHalt::NonCanonicalFp`] if the bytes cannot be deserialized
/// as a field element.
///
/// # Panics
/// Panics if `input_be.len() != FP_LENGTH`.
#[inline]
fn read_fp(input_be: &[u8]) -> Result<Fq, PrecompileHalt> {
    assert_eq!(input_be.len(), FP_LENGTH, "input must be {FP_LENGTH} bytes");
    let mut le_bytes = [0u8; FP_LENGTH];
    // Copy in reverse so the buffer holds the little-endian form.
    for (dst, src) in le_bytes.iter_mut().zip(input_be.iter().rev()) {
        *dst = *src;
    }
    match Fq::deserialize_uncompressed(le_bytes.as_slice()) {
        Ok(fp) => Ok(fp),
        Err(_) => Err(PrecompileHalt::NonCanonicalFp),
    }
}
/// Encodes a base-field element as `FP_LENGTH` big-endian bytes.
///
/// The element is serialized uncompressed and the resulting bytes are then
/// emitted in reverse order to produce the big-endian encoding.
///
/// # Panics
/// Panics if serialization into the fixed-size buffer fails.
fn encode_fp(fp: &Fq) -> [u8; FP_LENGTH] {
    let mut le_bytes = [0u8; FP_LENGTH];
    fp.serialize_uncompressed(le_bytes.as_mut_slice())
        .expect("Failed to serialize field element");
    let mut be_bytes = [0u8; FP_LENGTH];
    for (dst, src) in be_bytes.iter_mut().zip(le_bytes.iter().rev()) {
        *dst = *src;
    }
    be_bytes
}
/// Decodes an `Fq2` element from its two big-endian coefficients.
///
/// `input_1` becomes the first coefficient and `input_2` the second,
/// in the order passed to `Fq2::new`.
///
/// # Errors
/// Propagates any error from [`read_fp`]; `input_1` is decoded first.
#[inline]
fn read_fp2(input_1: &[u8; FP_LENGTH], input_2: &[u8; FP_LENGTH]) -> Result<Fq2, PrecompileHalt> {
    let first = read_fp(input_1)?;
    read_fp(input_2).map(|second| Fq2::new(first, second))
}
/// Builds a G1 affine point from its coordinates, checking only that it lies on the curve.
///
/// The coordinate pair `(0, 0)` is interpreted as the point at infinity.
///
/// # Errors
/// Returns [`PrecompileHalt::Bls12381G1NotOnCurve`] if `(px, py)` is not on the curve.
#[inline]
fn new_g1_point_no_subgroup_check(px: Fq, py: Fq) -> Result<G1Affine, PrecompileHalt> {
    if px.is_zero() && py.is_zero() {
        return Ok(G1Affine::zero());
    }
    let candidate = G1Affine::new_unchecked(px, py);
    candidate
        .is_on_curve()
        .then_some(candidate)
        .ok_or(PrecompileHalt::Bls12381G1NotOnCurve)
}
/// Builds a G2 affine point from its `Fq2` coordinates, checking only that it lies on the curve.
///
/// The coordinate pair `(0, 0)` is interpreted as the point at infinity.
///
/// # Errors
/// Returns [`PrecompileHalt::Bls12381G2NotOnCurve`] if `(x, y)` is not on the curve.
#[inline]
fn new_g2_point_no_subgroup_check(x: Fq2, y: Fq2) -> Result<G2Affine, PrecompileHalt> {
    if x.is_zero() && y.is_zero() {
        return Ok(G2Affine::zero());
    }
    let candidate = G2Affine::new_unchecked(x, y);
    if candidate.is_on_curve() {
        Ok(candidate)
    } else {
        Err(PrecompileHalt::Bls12381G2NotOnCurve)
    }
}
/// Decodes a G1 point and additionally requires it to pass the subgroup check.
///
/// # Errors
/// Any error from [`read_g1_no_subgroup_check`], or
/// [`PrecompileHalt::Bls12381G1NotInSubgroup`] if the decoded point fails
/// `is_in_correct_subgroup_assuming_on_curve`.
#[inline]
fn read_g1(x: &[u8; FP_LENGTH], y: &[u8; FP_LENGTH]) -> Result<G1Affine, PrecompileHalt> {
    read_g1_no_subgroup_check(x, y).and_then(|point| {
        if point.is_in_correct_subgroup_assuming_on_curve() {
            Ok(point)
        } else {
            Err(PrecompileHalt::Bls12381G1NotInSubgroup)
        }
    })
}
/// Decodes big-endian `x`/`y` coordinates into a G1 point, checking it is on the curve
/// but skipping the subgroup check.
///
/// # Errors
/// Field-decoding errors from [`read_fp`] (`x` first, then `y`), or the on-curve error
/// from [`new_g1_point_no_subgroup_check`].
#[inline]
fn read_g1_no_subgroup_check(
    x: &[u8; FP_LENGTH],
    y: &[u8; FP_LENGTH],
) -> Result<G1Affine, PrecompileHalt> {
    // Tuple operands evaluate left to right, so `x` is still decoded before `y`.
    let (px, py) = (read_fp(x)?, read_fp(y)?);
    new_g1_point_no_subgroup_check(px, py)
}
/// Serializes a G1 point as `x || y`, each coordinate `FP_LENGTH` big-endian bytes.
///
/// A point with no affine coordinates (the point at infinity) is encoded as all zeros.
#[inline]
fn encode_g1_point(input: &G1Affine) -> [u8; G1_LENGTH] {
    let mut output = [0u8; G1_LENGTH];
    if let Some((x, y)) = input.xy() {
        let (x_half, y_half) = output.split_at_mut(FP_LENGTH);
        x_half.copy_from_slice(&encode_fp(&x));
        y_half.copy_from_slice(&encode_fp(&y));
    }
    output
}
/// Decodes a G2 point and additionally requires it to pass the subgroup check.
///
/// # Errors
/// Any error from [`read_g2_no_subgroup_check`], or
/// [`PrecompileHalt::Bls12381G2NotInSubgroup`] if the decoded point fails
/// `is_in_correct_subgroup_assuming_on_curve`.
#[inline]
fn read_g2(
    a_x_0: &[u8; FP_LENGTH],
    a_x_1: &[u8; FP_LENGTH],
    a_y_0: &[u8; FP_LENGTH],
    a_y_1: &[u8; FP_LENGTH],
) -> Result<G2Affine, PrecompileHalt> {
    let candidate = read_g2_no_subgroup_check(a_x_0, a_x_1, a_y_0, a_y_1)?;
    candidate
        .is_in_correct_subgroup_assuming_on_curve()
        .then_some(candidate)
        .ok_or(PrecompileHalt::Bls12381G2NotInSubgroup)
}
/// Decodes the four big-endian coefficients of a G2 point, checking it is on the curve
/// but skipping the subgroup check.
///
/// The `x` coefficients are decoded before the `y` coefficients.
///
/// # Errors
/// Field-decoding errors from [`read_fp2`], or the on-curve error from
/// [`new_g2_point_no_subgroup_check`].
#[inline]
fn read_g2_no_subgroup_check(
    a_x_0: &[u8; FP_LENGTH],
    a_x_1: &[u8; FP_LENGTH],
    a_y_0: &[u8; FP_LENGTH],
    a_y_1: &[u8; FP_LENGTH],
) -> Result<G2Affine, PrecompileHalt> {
    let x = read_fp2(a_x_0, a_x_1)?;
    read_fp2(a_y_0, a_y_1).and_then(|y| new_g2_point_no_subgroup_check(x, y))
}
/// Serializes a G2 point as `x.c0 || x.c1 || y.c0 || y.c1`, each `FP_LENGTH`
/// big-endian bytes.
///
/// A point with no affine coordinates (the point at infinity) is encoded as all zeros.
#[inline]
fn encode_g2_point(input: &G2Affine) -> [u8; G2_LENGTH] {
    let mut output = [0u8; G2_LENGTH];
    let Some((x, y)) = input.xy() else {
        return output;
    };
    let coefficients = [x.c0, x.c1, y.c0, y.c1];
    for (index, coefficient) in coefficients.iter().enumerate() {
        let start = index * FP_LENGTH;
        output[start..start + FP_LENGTH].copy_from_slice(&encode_fp(coefficient));
    }
    output
}
/// Interprets a big-endian byte string as a scalar, reducing it modulo the group order.
///
/// # Errors
/// Returns [`PrecompileHalt::Bls12381ScalarInputLength`] if `input` is not exactly
/// `SCALAR_LENGTH` bytes long.
#[inline]
fn read_scalar(input: &[u8]) -> Result<Fr, PrecompileHalt> {
    if input.len() == SCALAR_LENGTH {
        Ok(Fr::from_be_bytes_mod_order(input))
    } else {
        Err(PrecompileHalt::Bls12381ScalarInputLength)
    }
}
#[inline]
fn p1_add_affine(p1: &G1Affine, p2: &G1Affine) -> G1Affine {
let p1_proj: G1Projective = (*p1).into();
let p3 = p1_proj + p2;
p3.into_affine()
}
#[inline]
fn p2_add_affine(p1: &G2Affine, p2: &G2Affine) -> G2Affine {
let p1_proj: G2Projective = (*p1).into();
let p3 = p1_proj + p2;
p3.into_affine()
}
/// Computes the G1 multi-scalar multiplication `sum(scalars[i] * g1_points[i])`.
///
/// An empty input yields the identity (`G1Affine::zero()`). A single pair is handled
/// with a direct scalar multiplication instead of the general MSM routine.
///
/// # Panics
/// Panics if `g1_points` and `scalars` have different lengths, or if the MSM routine
/// reports an error.
#[inline]
fn p1_msm(g1_points: Vec<G1Affine>, scalars: Vec<Fr>) -> G1Affine {
    assert_eq!(
        g1_points.len(),
        scalars.len(),
        "number of scalars should equal the number of g1 points"
    );
    match g1_points.as_slice() {
        [] => G1Affine::zero(),
        [point] => point.mul_bigint(scalars[0].into_bigint()).into_affine(),
        _ => G1Projective::msm(&g1_points, &scalars)
            .expect("MSM should succeed")
            .into_affine(),
    }
}
/// Computes the G2 multi-scalar multiplication `sum(scalars[i] * g2_points[i])`.
///
/// An empty input yields the identity (`G2Affine::zero()`). A single pair is handled
/// with a direct scalar multiplication instead of the general MSM routine.
///
/// # Panics
/// Panics if `g2_points` and `scalars` have different lengths, or if the MSM routine
/// reports an error.
#[inline]
fn p2_msm(g2_points: Vec<G2Affine>, scalars: Vec<Fr>) -> G2Affine {
    assert_eq!(
        g2_points.len(),
        scalars.len(),
        "number of scalars should equal the number of g2 points"
    );
    match g2_points.as_slice() {
        [] => G2Affine::zero(),
        [point] => point.mul_bigint(scalars[0].into_bigint()).into_affine(),
        _ => G2Projective::msm(&g2_points, &scalars)
            .expect("MSM should succeed")
            .into_affine(),
    }
}
/// Maps a base-field element to G1 using the `WBMap` curve map, followed by cofactor clearing.
///
/// # Panics
/// Panics if `map_to_curve` returns an error (treated here as impossible).
#[inline]
fn map_fp_to_g1(fp: &Fq) -> G1Affine {
    let mapped: G1Affine = WBMap::map_to_curve(*fp).expect("map_to_curve is infallible");
    mapped.clear_cofactor()
}
/// Maps an `Fq2` element to G2 using the `WBMap` curve map, followed by cofactor clearing.
///
/// # Panics
/// Panics if `map_to_curve` returns an error (treated here as impossible).
#[inline]
fn map_fp2_to_g2(fp2: &Fq2) -> G2Affine {
    let mapped: G2Affine = WBMap::map_to_curve(*fp2).expect("map_to_curve is infallible");
    mapped.clear_cofactor()
}
/// Returns `true` when the product of the pairings `e(P_i, Q_i)` over all pairs is one
/// in the target group.
///
/// An empty slice is accepted and returns `true`.
#[inline]
pub(crate) fn pairing_check(pairs: &[(G1Affine, G2Affine)]) -> bool {
    if pairs.is_empty() {
        return true;
    }
    let mut g1_points = Vec::with_capacity(pairs.len());
    let mut g2_points = Vec::with_capacity(pairs.len());
    for &(p, q) in pairs {
        g1_points.push(p);
        g2_points.push(q);
    }
    Bls12_381::multi_pairing(&g1_points, &g2_points).0.is_one()
}
/// Byte-level pairing check for this backend.
///
/// Delegates to `pairing_common::pairing_check_bytes_generic`. It passes that function
/// the subgroup-checking decoders [`read_g1`]/[`read_g2`] and [`pairing_check`] as the
/// evaluator.
#[inline]
pub(crate) fn pairing_check_bytes(pairs: &[PairingPair]) -> Result<bool, PrecompileHalt> {
    use super::pairing_common::pairing_check_bytes_generic;
    pairing_check_bytes_generic(pairs, read_g1, read_g2, pairing_check)
}
/// Adds two encoded G1 points and returns the encoded sum.
///
/// Inputs are only checked to lie on the curve; no subgroup check is performed.
///
/// # Errors
/// Propagates decoding/on-curve errors from [`read_g1_no_subgroup_check`]; `a` is
/// decoded before `b`.
#[inline]
pub(crate) fn p1_add_affine_bytes(
    a: G1Point,
    b: G1Point,
) -> Result<[u8; G1_LENGTH], PrecompileHalt> {
    let lhs = read_g1_no_subgroup_check(&a.0, &a.1)?;
    let rhs = read_g1_no_subgroup_check(&b.0, &b.1)?;
    Ok(encode_g1_point(&p1_add_affine(&lhs, &rhs)))
}
/// Adds two encoded G2 points and returns the encoded sum.
///
/// Inputs are only checked to lie on the curve; no subgroup check is performed.
///
/// # Errors
/// Propagates decoding/on-curve errors from [`read_g2_no_subgroup_check`]; `a` is
/// decoded before `b`.
#[inline]
pub(crate) fn p2_add_affine_bytes(
    a: G2Point,
    b: G2Point,
) -> Result<[u8; G2_LENGTH], PrecompileHalt> {
    let lhs = read_g2_no_subgroup_check(&a.0, &a.1, &a.2, &a.3)?;
    let rhs = read_g2_no_subgroup_check(&b.0, &b.1, &b.2, &b.3)?;
    Ok(encode_g2_point(&p2_add_affine(&lhs, &rhs)))
}
/// Decodes a base-field element, maps it to G1, and returns the encoded point.
///
/// # Errors
/// Propagates any decoding error from [`read_fp`].
#[inline]
pub(crate) fn map_fp_to_g1_bytes(
    fp_bytes: &[u8; FP_LENGTH],
) -> Result<[u8; G1_LENGTH], PrecompileHalt> {
    read_fp(fp_bytes).map(|fp| encode_g1_point(&map_fp_to_g1(&fp)))
}
/// Decodes an `Fq2` element from its two coefficients, maps it to G2, and returns the
/// encoded point.
///
/// # Errors
/// Propagates any decoding error from [`read_fp2`].
#[inline]
pub(crate) fn map_fp2_to_g2_bytes(
    fp2_x: &[u8; FP_LENGTH],
    fp2_y: &[u8; FP_LENGTH],
) -> Result<[u8; G2_LENGTH], PrecompileHalt> {
    read_fp2(fp2_x, fp2_y).map(|fp2| encode_g2_point(&map_fp2_to_g2(&fp2)))
}
/// Byte-level G1 multi-scalar multiplication.
///
/// Every point is decoded with the subgroup check, including points whose scalar is
/// all zero bytes. Those pairs are then dropped, because they contribute nothing to
/// the sum. If no pairs remain, the all-zero encoding is returned.
///
/// # Errors
/// Returns the first error produced by the iterator itself, by [`read_g1`], or by
/// [`read_scalar`].
#[inline]
pub(crate) fn p1_msm_bytes(
    point_scalar_pairs: impl Iterator<Item = Result<(G1Point, [u8; SCALAR_LENGTH]), PrecompileHalt>>,
) -> Result<[u8; G1_LENGTH], PrecompileHalt> {
    let (capacity, _) = point_scalar_pairs.size_hint();
    let mut points = Vec::with_capacity(capacity);
    let mut scalars = Vec::with_capacity(capacity);

    for item in point_scalar_pairs {
        let ((x, y), scalar_bytes) = item?;
        // Validate the point before deciding whether to skip it.
        let point = read_g1(&x, &y)?;
        let scalar_is_zero = scalar_bytes.iter().all(|&b| b == 0);
        if !scalar_is_zero {
            scalars.push(read_scalar(&scalar_bytes)?);
            points.push(point);
        }
    }

    if points.is_empty() {
        Ok([0u8; G1_LENGTH])
    } else {
        Ok(encode_g1_point(&p1_msm(points, scalars)))
    }
}
/// Byte-level G2 multi-scalar multiplication.
///
/// Every point is decoded with the subgroup check, including points whose scalar is
/// all zero bytes. Those pairs are then dropped, because they contribute nothing to
/// the sum. If no pairs remain, the all-zero encoding is returned.
///
/// # Errors
/// Returns the first error produced by the iterator itself, by [`read_g2`], or by
/// [`read_scalar`].
#[inline]
pub(crate) fn p2_msm_bytes(
    point_scalar_pairs: impl Iterator<Item = Result<(G2Point, [u8; SCALAR_LENGTH]), PrecompileHalt>>,
) -> Result<[u8; G2_LENGTH], PrecompileHalt> {
    let (capacity, _) = point_scalar_pairs.size_hint();
    let mut points = Vec::with_capacity(capacity);
    let mut scalars = Vec::with_capacity(capacity);

    for item in point_scalar_pairs {
        let ((x_0, x_1, y_0, y_1), scalar_bytes) = item?;
        // Validate the point before deciding whether to skip it.
        let point = read_g2(&x_0, &x_1, &y_0, &y_1)?;
        let scalar_is_zero = scalar_bytes.iter().all(|&b| b == 0);
        if !scalar_is_zero {
            scalars.push(read_scalar(&scalar_bytes)?);
            points.push(point);
        }
    }

    if points.is_empty() {
        Ok([0u8; G2_LENGTH])
    } else {
        Ok(encode_g2_point(&p2_msm(points, scalars)))
    }
}