use super::external::{AdapterError, ExternalProof, ProofMetadata};
use super::parsing;
use crate::backend::bn254::Bn254;
use crate::crypto::AccumulatorInstance;
use halo2curves::bn256::{Fr, G1Affine, G2Affine};
#[cfg(test)]
use group::Curve;
#[cfg(test)]
use halo2curves::bn256::{G1, G2};
/// A Groth16 proof over BN254 in the shape produced by snarkjs, together with
/// the public inputs it was generated against.
#[derive(Debug, Clone)]
pub struct SnarkjsProof {
    /// Proof element `pi_a`, a G1 point.
    pub pi_a: G1Affine,
    /// Proof element `pi_b`, a G2 point.
    pub pi_b: G2Affine,
    /// Proof element `pi_c`, a G1 point.
    pub pi_c: G1Affine,
    /// Public inputs as scalar-field elements, in the order they were supplied.
    pub public_inputs: Vec<Fr>,
}
impl SnarkjsProof {
pub fn new(pi_a: G1Affine, pi_b: G2Affine, pi_c: G1Affine, public_inputs: Vec<Fr>) -> Self {
Self {
pi_a,
pi_b,
pi_c,
public_inputs,
}
}
pub fn from_json(json: &str) -> Result<Self, AdapterError> {
let parsed: serde_json::Value = serde_json::from_str(json)
.map_err(|e| AdapterError::ParseError(format!("Invalid JSON: {}", e)))?;
if let Some(protocol) = parsed.get("protocol").and_then(|v| v.as_str()) {
if protocol != "groth16" {
return Err(AdapterError::Unsupported(format!(
"Expected groth16, got {}",
protocol
)));
}
}
if let Some(curve) = parsed.get("curve").and_then(|v| v.as_str()) {
if curve != "bn128" && curve != "bn254" {
return Err(AdapterError::Unsupported(format!(
"Expected bn128/bn254, got {}",
curve
)));
}
}
let pi_a = Self::parse_g1_json(&parsed["pi_a"])
.map_err(|e| AdapterError::InvalidPoint(format!("pi_a: {}", e)))?;
let pi_b = Self::parse_g2_json(&parsed["pi_b"])
.map_err(|e| AdapterError::InvalidPoint(format!("pi_b: {}", e)))?;
let pi_c = Self::parse_g1_json(&parsed["pi_c"])
.map_err(|e| AdapterError::InvalidPoint(format!("pi_c: {}", e)))?;
let public_inputs = Self::parse_public_inputs(&parsed)
.map_err(|e| AdapterError::ParseError(format!("public inputs: {}", e)))?;
Ok(Self {
pi_a,
pi_b,
pi_c,
public_inputs,
})
}
fn parse_g1_json(value: &serde_json::Value) -> Result<G1Affine, String> {
let arr = value.as_array().ok_or("Expected array for G1")?;
if arr.len() < 2 {
return Err("G1 point needs at least x, y".to_string());
}
let x_str = arr[0].as_str().ok_or("x must be string")?;
let y_str = arr[1].as_str().ok_or("y must be string")?;
let x = parsing::parse_fq_decimal(x_str)?;
let y = parsing::parse_fq_decimal(y_str)?;
parsing::g1_from_xy(x, y)
}
fn parse_g2_json(value: &serde_json::Value) -> Result<G2Affine, String> {
let arr = value.as_array().ok_or("Expected array for G2")?;
if arr.len() < 2 {
return Err("G2 point needs at least x, y components".to_string());
}
let x_arr = arr[0].as_array().ok_or("G2 x must be array")?;
if x_arr.len() < 2 {
return Err("G2 x needs two components".to_string());
}
let x_im = x_arr[0].as_str().ok_or("x_im must be string")?;
let x_re = x_arr[1].as_str().ok_or("x_re must be string")?;
let y_arr = arr[1].as_array().ok_or("G2 y must be array")?;
if y_arr.len() < 2 {
return Err("G2 y needs two components".to_string());
}
let y_im = y_arr[0].as_str().ok_or("y_im must be string")?;
let y_re = y_arr[1].as_str().ok_or("y_re must be string")?;
let x_c0 = parsing::parse_fq_decimal(x_re)?; let x_c1 = parsing::parse_fq_decimal(x_im)?; let y_c0 = parsing::parse_fq_decimal(y_re)?;
let y_c1 = parsing::parse_fq_decimal(y_im)?;
parsing::g2_from_fq2(x_c0, x_c1, y_c0, y_c1)
}
fn parse_public_inputs(parsed: &serde_json::Value) -> Result<Vec<Fr>, String> {
let mut inputs = Vec::new();
let arr = parsed
.get("publicSignals")
.or_else(|| parsed.get("public_inputs"))
.or_else(|| parsed.get("inputs"))
.and_then(|v| v.as_array());
if let Some(arr) = arr {
for val in arr {
let s = val.as_str().ok_or("Input must be string")?;
inputs.push(parsing::parse_fr_decimal(s)?);
}
}
Ok(inputs)
}
#[cfg(test)]
pub fn mock(public_inputs: Vec<Fr>) -> Self {
Self::new(
G1::generator().to_affine(),
G2::generator().to_affine(),
G1::generator().to_affine(),
public_inputs,
)
}
pub fn to_bytes(&self) -> Vec<u8> {
use group::GroupEncoding;
let mut bytes = Vec::with_capacity(32 + 64 + 32 + 4 + 32 * self.public_inputs.len());
bytes.extend_from_slice(self.pi_a.to_bytes().as_ref());
bytes.extend_from_slice(self.pi_b.to_bytes().as_ref());
bytes.extend_from_slice(self.pi_c.to_bytes().as_ref());
bytes.extend_from_slice(&(self.public_inputs.len() as u32).to_le_bytes());
for input in &self.public_inputs {
use ff::PrimeField;
bytes.extend_from_slice(input.to_repr().as_ref());
}
bytes
}
}
impl ExternalProof<Bn254> for SnarkjsProof {
    /// Produces a single accumulator instance bound to this proof.
    ///
    /// The three proof points and every public input are hashed with SHA-256;
    /// the digest is turned into a scalar `s`, and the instance's `commitment`
    /// and `quotient` are both set to `s * G1::generator()`, with `evaluation`
    /// and `point` set to zero. No pairing check is performed here.
    fn to_accumulator_instances(&self) -> Result<Vec<AccumulatorInstance<Bn254>>, AdapterError> {
        use ff::{Field, PrimeField};
        use group::{Curve, GroupEncoding};
        use sha2::{Digest, Sha256};

        let mut hasher = Sha256::new();
        hasher.update(self.pi_a.to_bytes().as_ref());
        hasher.update(self.pi_b.to_bytes().as_ref());
        hasher.update(self.pi_c.to_bytes().as_ref());
        for input in &self.public_inputs {
            hasher.update(input.to_repr().as_ref());
        }
        let hash = hasher.finalize();

        let mut scalar_bytes = [0u8; 32];
        scalar_bytes.copy_from_slice(&hash);
        // `Fr::from_bytes` only accepts canonical little-endian encodings, i.e.
        // values below the scalar modulus r (~2^253.6). A raw 256-bit digest
        // exceeds r roughly 4 times out of 5, which previously collapsed all of
        // those proofs onto the same fallback scalar. Clearing the top three
        // bits keeps the value below 2^253 < r, so decoding always succeeds and
        // distinct digests keep distinct scalars.
        scalar_bytes[31] &= 0x1f;
        // The fallback is unreachable after masking; it is kept so the
        // expression stays total without introducing a panic path.
        let scalar = Fr::from_bytes(&scalar_bytes).unwrap_or(Fr::ONE);

        let commitment = (halo2curves::bn256::G1::generator() * scalar).to_affine();
        let instance = AccumulatorInstance {
            commitment,
            evaluation: Fr::ZERO,
            point: Fr::ZERO,
            quotient: commitment,
        };
        Ok(vec![instance])
    }

    /// Returns the public inputs in the order they were supplied.
    fn public_inputs(&self) -> &[Fr] {
        &self.public_inputs
    }

    /// Describes this proof as a snarkjs Groth16 proof over BN254.
    fn metadata(&self) -> ProofMetadata {
        ProofMetadata {
            system: "snarkjs",
            proof_type: "groth16",
            curve: "bn254",
            num_public_inputs: self.public_inputs.len(),
        }
    }

    /// Rejects proofs in which any of the three proof points is the identity.
    ///
    /// # Errors
    ///
    /// Returns [`AdapterError::InvalidPoint`] naming the first offending point.
    fn validate_format(&self) -> Result<(), AdapterError> {
        use group::prime::PrimeCurveAffine;

        if self.pi_a.is_identity().into() {
            return Err(AdapterError::InvalidPoint("pi_a is identity".to_string()));
        }
        // pi_b is held to the same rule as the two G1 points.
        if self.pi_b.is_identity().into() {
            return Err(AdapterError::InvalidPoint("pi_b is identity".to_string()));
        }
        if self.pi_c.is_identity().into() {
            return Err(AdapterError::InvalidPoint("pi_c is identity".to_string()));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal snarkjs-shaped document with the given trailing fields appended.
    fn proof_json(extra_fields: &str) -> String {
        format!(
            r#"{{
                "pi_a": ["1", "2", "1"],
                "pi_b": [["1", "0"], ["1", "0"], ["1", "0"]],
                "pi_c": ["1", "2", "1"]{}
            }}"#,
            extra_fields
        )
    }

    #[test]
    fn snarkjs_proof_creation() {
        let proof = SnarkjsProof::mock(vec![Fr::from(42u64)]);
        assert_eq!(proof.public_inputs().len(), 1);
        assert_eq!(proof.metadata().system, "snarkjs");
    }

    #[test]
    fn snarkjs_to_accumulator() {
        use ff::Field;
        let proof = SnarkjsProof::mock(vec![Fr::from(100u64), Fr::from(200u64)]);
        let instances = proof.to_accumulator_instances().unwrap();
        assert_eq!(instances.len(), 1);
        assert_eq!(instances[0].evaluation, Fr::ZERO);
        assert_eq!(instances[0].point, Fr::ZERO);
        assert_eq!(instances[0].commitment, instances[0].quotient);
    }

    #[test]
    fn snarkjs_validate_format() {
        let proof = SnarkjsProof::mock(vec![]);
        assert!(proof.validate_format().is_ok());
    }

    #[test]
    fn snarkjs_from_json_basic() {
        let json = proof_json(
            r#",
                "publicSignals": ["21", "3", "7"],
                "protocol": "groth16",
                "curve": "bn128""#,
        );
        let proof = SnarkjsProof::from_json(&json).unwrap();
        assert_eq!(proof.public_inputs.len(), 3);
        assert_eq!(proof.public_inputs[0], Fr::from(21u64));
        assert_eq!(proof.public_inputs[1], Fr::from(3u64));
        assert_eq!(proof.public_inputs[2], Fr::from(7u64));
    }

    #[test]
    fn snarkjs_rejects_wrong_protocol() {
        let json = proof_json(
            r#",
                "protocol": "plonk""#,
        );
        let result = SnarkjsProof::from_json(&json);
        // Pin the variant so an unrelated parse failure cannot satisfy the test.
        assert!(matches!(result, Err(AdapterError::Unsupported(_))));
    }

    #[test]
    fn snarkjs_rejects_wrong_curve() {
        let json = proof_json(
            r#",
                "protocol": "groth16",
                "curve": "bls12381""#,
        );
        let result = SnarkjsProof::from_json(&json);
        assert!(matches!(result, Err(AdapterError::Unsupported(_))));
    }

    #[test]
    fn snarkjs_to_bytes_roundtrip() {
        use ff::PrimeField;

        let inputs = vec![Fr::from(42u64), Fr::from(123u64)];
        let proof = SnarkjsProof::mock(inputs.clone());
        let bytes = proof.to_bytes();

        // Three encoded points, a 4-byte count, then one 32-byte scalar per input.
        let points_len = 32 + 64 + 32;
        assert_eq!(bytes.len(), points_len + 4 + 32 * 2);

        // The count is stored little-endian directly after the points.
        assert_eq!(
            &bytes[points_len..points_len + 4],
            &2u32.to_le_bytes()[..]
        );

        // Each input is written back-to-back in its canonical representation.
        let encoded_inputs = &bytes[points_len + 4..];
        for (chunk, input) in encoded_inputs.chunks_exact(32).zip(&inputs) {
            let repr = input.to_repr();
            let expected: &[u8] = repr.as_ref();
            assert_eq!(chunk, expected);
        }
    }
}