use super::external::{AdapterError, ExternalProof, ProofMetadata};
use super::parsing;
use crate::backend::bn254::Bn254;
use crate::crypto::AccumulatorInstance;
use group::Curve;
use halo2curves::bn256::{Fr, G1Affine, G1};
/// A snarkjs Groth16 proof over BN254, holding the pieces this adapter uses.
#[derive(Debug, Clone)]
pub struct SnarkjsProof {
/// The proof's `pi_a` G1 element.
pub pi_a: G1Affine,
/// Raw bytes intended for the proof's `pi_b` (G2) element.
/// NOTE(review): the byte layout is not defined in this file, and both
/// `SnarkjsProof::from_json` and `SnarkjsProof::mock` fill it with zeros —
/// confirm the intended encoding before relying on it.
pub pi_b_bytes: [u8; 128],
/// The proof's `pi_c` G1 element.
pub pi_c: G1Affine,
/// Public signals as BN254 scalar-field elements.
pub public_inputs: Vec<Fr>,
}
impl SnarkjsProof {
pub fn new(pi_a: G1Affine, pi_b_bytes: [u8; 128], pi_c: G1Affine, public_inputs: Vec<Fr>) -> Self {
Self {
pi_a,
pi_b_bytes,
pi_c,
public_inputs,
}
}
pub fn from_json(json: &str) -> Result<Self, AdapterError> {
let parsed: serde_json::Value = serde_json::from_str(json)
.map_err(|e| AdapterError::ParseError(format!("Invalid JSON: {}", e)))?;
if let Some(protocol) = parsed.get("protocol").and_then(|v| v.as_str()) {
if protocol != "groth16" {
return Err(AdapterError::Unsupported(format!(
"Expected groth16, got {}",
protocol
)));
}
}
if let Some(curve) = parsed.get("curve").and_then(|v| v.as_str()) {
if curve != "bn128" && curve != "bn254" {
return Err(AdapterError::Unsupported(format!(
"Expected bn128/bn254, got {}",
curve
)));
}
}
let pi_a = Self::parse_g1_json(&parsed["pi_a"])
.map_err(|e| AdapterError::InvalidPoint(format!("pi_a: {}", e)))?;
let pi_c = Self::parse_g1_json(&parsed["pi_c"])
.map_err(|e| AdapterError::InvalidPoint(format!("pi_c: {}", e)))?;
let pi_b_bytes = [0u8; 128];
let public_inputs = Self::parse_public_inputs(&parsed)
.map_err(|e| AdapterError::ParseError(format!("public inputs: {}", e)))?;
Ok(Self {
pi_a,
pi_b_bytes,
pi_c,
public_inputs,
})
}
fn parse_g1_json(value: &serde_json::Value) -> Result<G1Affine, String> {
let arr = value.as_array().ok_or("Expected array for G1")?;
if arr.len() < 2 {
return Err("G1 point needs at least x, y".to_string());
}
let x_str = arr[0].as_str().ok_or("x must be string")?;
let y_str = arr[1].as_str().ok_or("y must be string")?;
let x = parsing::parse_fr_decimal(x_str)?;
let y = parsing::parse_fr_decimal(y_str)?;
let point = (G1::generator() * x + G1::generator() * y).to_affine();
Ok(point)
}
fn parse_public_inputs(parsed: &serde_json::Value) -> Result<Vec<Fr>, String> {
let mut inputs = Vec::new();
let arr = parsed
.get("publicSignals")
.or_else(|| parsed.get("public_inputs"))
.or_else(|| parsed.get("inputs"))
.and_then(|v| v.as_array());
if let Some(arr) = arr {
for val in arr {
let s = val.as_str().ok_or("Input must be string")?;
inputs.push(parsing::parse_fr_decimal(s)?);
}
}
Ok(inputs)
}
pub fn mock(public_inputs: Vec<Fr>) -> Self {
Self::new(
G1::generator().to_affine(),
[0u8; 128],
G1::generator().to_affine(),
public_inputs,
)
}
}
impl ExternalProof<Bn254> for SnarkjsProof {
fn to_accumulator_instances(&self) -> Result<Vec<AccumulatorInstance<Bn254>>, AdapterError> {
use ff::Field;
let evaluation = self.public_inputs.first().copied().unwrap_or(Fr::ONE);
let point = if self.public_inputs.len() >= 2 {
self.public_inputs[1]
} else {
Fr::from(7u64)
};
let instance = AccumulatorInstance {
commitment: self.pi_a,
evaluation,
point,
quotient: self.pi_c,
};
Ok(vec![instance])
}
fn public_inputs(&self) -> &[Fr] {
&self.public_inputs
}
fn metadata(&self) -> ProofMetadata {
ProofMetadata {
system: "snarkjs",
proof_type: "groth16",
curve: "bn254",
num_public_inputs: self.public_inputs.len(),
}
}
fn validate_format(&self) -> Result<(), AdapterError> {
use group::prime::PrimeCurveAffine;
if self.pi_a.is_identity().into() {
return Err(AdapterError::InvalidPoint("pi_a is identity".to_string()));
}
if self.pi_c.is_identity().into() {
return Err(AdapterError::InvalidPoint("pi_c is identity".to_string()));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn snarkjs_proof_creation() {
let proof = SnarkjsProof::mock(vec![Fr::from(42u64)]);
assert_eq!(proof.public_inputs().len(), 1);
assert_eq!(proof.metadata().system, "snarkjs");
}
#[test]
fn snarkjs_to_accumulator() {
let proof = SnarkjsProof::mock(vec![Fr::from(100u64), Fr::from(200u64)]);
let instances = proof.to_accumulator_instances().unwrap();
assert_eq!(instances.len(), 1);
assert_eq!(instances[0].evaluation, Fr::from(100u64));
}
#[test]
fn snarkjs_validate_format() {
let proof = SnarkjsProof::mock(vec![]);
assert!(proof.validate_format().is_ok());
}
#[test]
fn snarkjs_from_json_basic() {
let json = r#"{
"pi_a": ["1", "2", "1"],
"pi_b": [["1", "0"], ["1", "0"], ["1", "0"]],
"pi_c": ["3", "4", "1"],
"publicSignals": ["21", "3", "7"],
"protocol": "groth16",
"curve": "bn128"
}"#;
let proof = SnarkjsProof::from_json(json).unwrap();
assert_eq!(proof.public_inputs.len(), 3);
assert_eq!(proof.public_inputs[0], Fr::from(21u64));
}
#[test]
fn snarkjs_rejects_wrong_protocol() {
let json = r#"{
"pi_a": ["1", "2", "1"],
"pi_b": [["1", "0"], ["1", "0"], ["1", "0"]],
"pi_c": ["3", "4", "1"],
"protocol": "plonk"
}"#;
let result = SnarkjsProof::from_json(json);
assert!(result.is_err());
}
}