use crate::{
bls12_381::arkworks::pairing_check, bls12_381_const::TRUSTED_SETUP_TAU_G2_BYTES, PrecompileHalt,
};
use ark_bls12_381::{Fr, G1Affine, G2Affine};
use ark_ec::{AffineRepr, CurveGroup};
use ark_ff::{BigInteger, PrimeField};
use ark_serialize::CanonicalDeserialize;
use core::ops::Neg;
use primitives::OnceLock;
/// Verifies a KZG opening proof: that the polynomial committed to by
/// `commitment` evaluates to `y` at the point `z`.
///
/// The check performed is `e(C - [y]G1, G2) == e(π, [τ]G2 - [z]G2)`, expressed
/// as a single two-pair call to `pairing_check` with the first G2 argument
/// negated.
///
/// # Arguments
/// * `commitment` - 48-byte compressed G1 commitment `C`.
/// * `z` - 32-byte big-endian evaluation point; must be a canonical scalar.
/// * `y` - 32-byte big-endian claimed evaluation; must be a canonical scalar.
/// * `proof` - 48-byte compressed G1 opening proof `π`.
///
/// # Returns
/// `true` only if every input decodes and the pairing check passes. Any
/// decoding failure (bad G1 encoding or non-canonical scalar) yields `false`;
/// the underlying error is discarded.
#[inline]
pub fn verify_kzg_proof(
    commitment: &[u8; 48],
    z: &[u8; 32],
    y: &[u8; 32],
    proof: &[u8; 48],
) -> bool {
    // Decode in the order commitment, proof, z, y, stopping at the first failure.
    let decoded = parse_g1_compressed(commitment).and_then(|c_point| {
        let pi_point = parse_g1_compressed(proof)?;
        let z_scalar = read_scalar_canonical(z)?;
        let y_scalar = read_scalar_canonical(y)?;
        Ok((c_point, pi_point, z_scalar, y_scalar))
    });
    let Ok((c_point, pi_point, z_scalar, y_scalar)) = decoded else {
        return false;
    };

    let g1 = get_g1_generator();
    let g2 = get_g2_generator();

    // G1 side: the commitment shifted by the claimed evaluation, C - [y]G1.
    let shifted_commitment = p1_sub_affine(&c_point, &p1_scalar_mul(&g1, &y_scalar));

    // G2 side: the setup point shifted by the evaluation point, [τ]G2 - [z]G2.
    let shifted_tau = p2_sub_affine(get_trusted_setup_g2(), &p2_scalar_mul(&g2, &z_scalar));

    // Pairing -G2 with the G1 side lets the equality be tested as one product.
    // Assumes `pairing_check` returns true iff the product of the pairings is
    // the identity in GT (TODO confirm against its definition).
    pairing_check(&[
        (shifted_commitment, p2_neg(&g2)),
        (pi_point, shifted_tau),
    ])
}
/// Returns the trusted-setup point `[τ]G2`, decoding it on first use and
/// caching it for the lifetime of the program.
///
/// # Panics
/// Panics if `TRUSTED_SETUP_TAU_G2_BYTES` is not a decodable compressed G2
/// point. That is a fixed constant, so such a failure is a build defect,
/// not an input error.
fn get_trusted_setup_g2() -> &'static G2Affine {
    /// Decodes the embedded compressed point. Uses the unchecked arkworks
    /// deserializer because the bytes are a constant, not untrusted input.
    fn decode_tau_g2() -> G2Affine {
        let encoded = &TRUSTED_SETUP_TAU_G2_BYTES[..];
        G2Affine::deserialize_compressed_unchecked(encoded)
            .expect("Failed to parse trusted setup G2 point")
    }

    static TAU_G2: OnceLock<G2Affine> = OnceLock::new();
    TAU_G2.get_or_init(decode_tau_g2)
}
/// Decodes a 48-byte compressed G1 point using arkworks' validating
/// (non-`_unchecked`) `deserialize_compressed`.
///
/// # Errors
/// Returns `PrecompileHalt::KzgInvalidG1Point` if arkworks rejects the
/// encoding. The underlying serialization error is discarded.
fn parse_g1_compressed(bytes: &[u8; 48]) -> Result<G1Affine, PrecompileHalt> {
    let encoded: &[u8] = bytes;
    match G1Affine::deserialize_compressed(encoded) {
        Ok(point) => Ok(point),
        Err(_) => Err(PrecompileHalt::KzgInvalidG1Point),
    }
}
/// Interprets 32 big-endian bytes as a scalar-field element, rejecting
/// encodings that are not already reduced.
///
/// `from_be_bytes_mod_order` silently reduces its input. Re-encoding the
/// result and comparing it with the original bytes detects any input that
/// required reduction (a value at or above the field modulus).
///
/// # Errors
/// Returns `PrecompileHalt::NonCanonicalFp` when the re-encoded bytes differ
/// from the input.
/// NOTE(review): the value is an `Fr` (scalar field) element despite the
/// `Fp` in the variant name; the caller in this file discards the error.
fn read_scalar_canonical(bytes: &[u8; 32]) -> Result<Fr, PrecompileHalt> {
    let scalar = Fr::from_be_bytes_mod_order(bytes);
    let reencoded = scalar.into_bigint().to_bytes_be();
    if reencoded.as_slice() == bytes.as_slice() {
        Ok(scalar)
    } else {
        Err(PrecompileHalt::NonCanonicalFp)
    }
}
#[inline]
fn get_g1_generator() -> G1Affine {
G1Affine::generator()
}
#[inline]
fn get_g2_generator() -> G2Affine {
G2Affine::generator()
}
/// Computes `[scalar]point` in G1 and normalizes the result back to affine.
#[inline]
fn p1_scalar_mul(point: &G1Affine, scalar: &Fr) -> G1Affine {
    let exponent = scalar.into_bigint();
    let product = point.mul_bigint(exponent);
    product.into_affine()
}
/// Computes `[scalar]point` in G2 and normalizes the result back to affine.
#[inline]
fn p2_scalar_mul(point: &G2Affine, scalar: &Fr) -> G2Affine {
    let exponent = scalar.into_bigint();
    let product = point.mul_bigint(exponent);
    product.into_affine()
}
/// Returns `a - b` for two affine G1 points.
///
/// The subtraction is done in the projective representation, then the
/// result is converted back to affine.
#[inline]
fn p1_sub_affine(a: &G1Affine, b: &G1Affine) -> G1Affine {
    let mut difference = a.into_group();
    difference -= b.into_group();
    difference.into_affine()
}
/// Returns `a - b` for two affine G2 points.
///
/// The subtraction is done in the projective representation, then the
/// result is converted back to affine.
#[inline]
fn p2_sub_affine(a: &G2Affine, b: &G2Affine) -> G2Affine {
    let mut difference = a.into_group();
    difference -= b.into_group();
    difference.into_affine()
}
/// Returns the additive inverse `-p` of an affine G2 point.
#[inline]
fn p2_neg(p: &G2Affine) -> G2Affine {
    let point: G2Affine = *p;
    Neg::neg(point)
}