export-aptos-verifier-core 0.3.0

Load Groth16 artifacts from snarkjs, Gnark, SP1, and Arkworks inputs and generate Aptos Move verifier packages.
Documentation
use ark_bls12_381::{Bls12_381, G1Affine as BlsG1Affine, G2Affine as BlsG2Affine};
use ark_bn254::{Bn254, G1Affine as Bn254G1Affine, G2Affine as Bn254G2Affine};
use ark_groth16::{Proof as ArkProof, VerifyingKey as ArkVerifyingKey};
use ark_serialize::CanonicalDeserialize;
use num_bigint::BigUint;
use serde_json::Value;
use std::fs;
use std::io::Cursor;
use std::path::Path;

use crate::error::{Error, Result};
use crate::model::{
    CurveKind, Groth16G1Point, Groth16G2Point, Groth16Proof, Groth16VerificationKey,
    Groth16VerifierInputs, SourceFormat,
};
use crate::snarkjs::parse_decimal;

pub fn load_arkworks_bundle(
    path: &Path,
    curve_hint: Option<&str>,
) -> Result<Groth16VerifierInputs> {
    let value = read_json_or_string(path)?;
    let curve = parse_curve(value.get("curve").and_then(Value::as_str), curve_hint)?;
    let vk_hex = string_from_keys(&value, &["vk", "verification_key", "verifying_key"], "vk")?;
    let vk = decode_vk(&curve, &vk_hex)?;

    let proof = optional_string_from_keys(&value, &["proof", "proof_bytes"])
        .map(|hex| decode_proof(&curve, &hex))
        .transpose()?;

    let public_inputs = match optional_value_from_keys(&value, &["public_input", "public_inputs"]) {
        Some(public) => parse_public_inputs(public)?,
        None => Vec::new(),
    };

    Groth16VerifierInputs::from_parts(curve, vk, proof, public_inputs, SourceFormat::Arkworks)
}

pub fn load_arkworks_inputs(
    vk_path: &Path,
    proof_path: Option<&Path>,
    public_path: Option<&Path>,
    curve_hint: Option<&str>,
) -> Result<Groth16VerifierInputs> {
    let vk_value = read_json_or_string(vk_path)?;
    let curve = parse_curve(vk_value.get("curve").and_then(Value::as_str), curve_hint)?;
    let vk_hex = if vk_value.is_string() {
        vk_value.as_str().unwrap_or_default().to_string()
    } else {
        string_from_keys(
            &vk_value,
            &["vk", "verification_key", "verifying_key"],
            "vk",
        )?
    };
    let vk = decode_vk(&curve, &vk_hex)?;

    let proof = match proof_path {
        Some(path) => {
            let proof_value = read_json_or_string(path)?;
            let proof_hex = if proof_value.is_string() {
                proof_value.as_str().unwrap_or_default().to_string()
            } else {
                string_from_keys(&proof_value, &["proof", "proof_bytes"], "proof")?
            };
            Some(decode_proof(&curve, &proof_hex)?)
        }
        None => optional_string_from_keys(&vk_value, &["proof", "proof_bytes"])
            .map(|hex| decode_proof(&curve, &hex))
            .transpose()?,
    };

    let public_inputs = match public_path {
        Some(path) => parse_public_inputs(&read_json_or_string(path)?)?,
        None => match optional_value_from_keys(&vk_value, &["public_input", "public_inputs"]) {
            Some(public) => parse_public_inputs(public)?,
            None => Vec::new(),
        },
    };

    Groth16VerifierInputs::from_parts(curve, vk, proof, public_inputs, SourceFormat::Arkworks)
}

fn read_json_or_string(path: &Path) -> Result<Value> {
    let content = fs::read_to_string(path).map_err(|e| Error::Io {
        source: e,
        context: format!("failed to read file {}", path.display()),
    })?;
    let trimmed = content.trim();
    match serde_json::from_str::<Value>(trimmed) {
        Ok(value) => Ok(value),
        Err(_) => Ok(Value::String(trimmed.to_string())),
    }
}

fn parse_curve(value: Option<&str>, curve_hint: Option<&str>) -> Result<CurveKind> {
    let raw = value
        .or(curve_hint)
        .ok_or_else(|| Error::MissingInput("arkworks input requires curve metadata".to_string()))?;
    CurveKind::from_name(raw)
}

fn string_from_keys(value: &Value, keys: &[&str], field: &str) -> Result<String> {
    optional_string_from_keys(value, keys)
        .ok_or_else(|| Error::MissingInput(format!("arkworks input requires {field} hex field")))
}

fn optional_string_from_keys(value: &Value, keys: &[&str]) -> Option<String> {
    if let Some(s) = value.as_str() {
        return Some(s.to_string());
    }
    keys.iter().find_map(|key| {
        value
            .get(*key)
            .and_then(Value::as_str)
            .map(ToString::to_string)
    })
}

fn optional_value_from_keys<'a>(value: &'a Value, keys: &[&str]) -> Option<&'a Value> {
    keys.iter().find_map(|key| value.get(*key))
}

fn decode_vk(curve: &CurveKind, raw: &str) -> Result<Groth16VerificationKey> {
    match curve {
        CurveKind::Bn254 => decode_vk_bn254(raw),
        CurveKind::Bls12_381 => decode_vk_bls12381(raw),
    }
}

fn decode_proof(curve: &CurveKind, raw: &str) -> Result<Groth16Proof> {
    match curve {
        CurveKind::Bn254 => decode_proof_bn254(raw),
        CurveKind::Bls12_381 => decode_proof_bls12381(raw),
    }
}

fn decode_vk_bn254(raw: &str) -> Result<Groth16VerificationKey> {
    let bytes = decode_hex(raw, "vk")?;
    let mut cursor = Cursor::new(bytes);
    let vk = ArkVerifyingKey::<Bn254>::deserialize_compressed(&mut cursor).map_err(|e| {
        Error::Serialization(format!("failed to deserialize BN254 verifying key: {e:?}"))
    })?;
    ensure_no_trailing_bytes(&cursor, "BN254 verifying key")?;
    Ok(Groth16VerificationKey {
        n_public: vk.gamma_abc_g1.len().saturating_sub(1),
        vk_alpha_1: point_from_bn254_g1(&vk.alpha_g1),
        vk_beta_2: point_from_bn254_g2(&vk.beta_g2),
        vk_gamma_2: point_from_bn254_g2(&vk.gamma_g2),
        vk_delta_2: point_from_bn254_g2(&vk.delta_g2),
        ic: vk.gamma_abc_g1.iter().map(point_from_bn254_g1).collect(),
    })
}

fn decode_proof_bn254(raw: &str) -> Result<Groth16Proof> {
    let bytes = decode_hex(raw, "proof")?;
    let mut cursor = Cursor::new(bytes);
    let proof = ArkProof::<Bn254>::deserialize_compressed(&mut cursor)
        .map_err(|e| Error::Serialization(format!("failed to deserialize BN254 proof: {e:?}")))?;
    ensure_no_trailing_bytes(&cursor, "BN254 proof")?;
    Ok(Groth16Proof {
        pi_a: point_from_bn254_g1(&proof.a),
        pi_b: point_from_bn254_g2(&proof.b),
        pi_c: point_from_bn254_g1(&proof.c),
    })
}

fn decode_vk_bls12381(raw: &str) -> Result<Groth16VerificationKey> {
    let bytes = decode_hex(raw, "vk")?;
    let mut cursor = Cursor::new(bytes);
    let vk = ArkVerifyingKey::<Bls12_381>::deserialize_compressed(&mut cursor).map_err(|e| {
        Error::Serialization(format!(
            "failed to deserialize BLS12-381 verifying key: {e:?}"
        ))
    })?;
    ensure_no_trailing_bytes(&cursor, "BLS12-381 verifying key")?;
    Ok(Groth16VerificationKey {
        n_public: vk.gamma_abc_g1.len().saturating_sub(1),
        vk_alpha_1: point_from_bls_g1(&vk.alpha_g1),
        vk_beta_2: point_from_bls_g2(&vk.beta_g2),
        vk_gamma_2: point_from_bls_g2(&vk.gamma_g2),
        vk_delta_2: point_from_bls_g2(&vk.delta_g2),
        ic: vk.gamma_abc_g1.iter().map(point_from_bls_g1).collect(),
    })
}

fn decode_proof_bls12381(raw: &str) -> Result<Groth16Proof> {
    let bytes = decode_hex(raw, "proof")?;
    let mut cursor = Cursor::new(bytes);
    let proof = ArkProof::<Bls12_381>::deserialize_compressed(&mut cursor).map_err(|e| {
        Error::Serialization(format!("failed to deserialize BLS12-381 proof: {e:?}"))
    })?;
    ensure_no_trailing_bytes(&cursor, "BLS12-381 proof")?;
    Ok(Groth16Proof {
        pi_a: point_from_bls_g1(&proof.a),
        pi_b: point_from_bls_g2(&proof.b),
        pi_c: point_from_bls_g1(&proof.c),
    })
}

fn ensure_no_trailing_bytes(cursor: &Cursor<Vec<u8>>, field: &str) -> Result<()> {
    let position: usize = cursor
        .position()
        .try_into()
        .map_err(|_| Error::Serialization(format!("{field} cursor position overflow")))?;
    let len = cursor.get_ref().len();
    if position == len {
        return Ok(());
    }

    Err(Error::Serialization(format!(
        "{field} has {} trailing bytes after Arkworks compressed artifact",
        len.saturating_sub(position)
    )))
}

fn point_from_bn254_g1(point: &Bn254G1Affine) -> Groth16G1Point {
    Groth16G1Point {
        x: point.x.to_string(),
        y: point.y.to_string(),
        z: "1".to_string(),
    }
}

fn point_from_bn254_g2(point: &Bn254G2Affine) -> Groth16G2Point {
    Groth16G2Point {
        x0: point.x.c0.to_string(),
        x1: point.x.c1.to_string(),
        y0: point.y.c0.to_string(),
        y1: point.y.c1.to_string(),
        z0: "1".to_string(),
        z1: "0".to_string(),
    }
}

fn point_from_bls_g1(point: &BlsG1Affine) -> Groth16G1Point {
    Groth16G1Point {
        x: point.x.to_string(),
        y: point.y.to_string(),
        z: "1".to_string(),
    }
}

fn point_from_bls_g2(point: &BlsG2Affine) -> Groth16G2Point {
    Groth16G2Point {
        x0: point.x.c0.to_string(),
        x1: point.x.c1.to_string(),
        y0: point.y.c0.to_string(),
        y1: point.y.c1.to_string(),
        z0: "1".to_string(),
        z1: "0".to_string(),
    }
}

fn parse_public_inputs(value: &Value) -> Result<Vec<String>> {
    match value {
        Value::Array(values) => values
            .iter()
            .enumerate()
            .map(|(idx, value)| parse_public_input_value(value, &format!("public_inputs[{idx}]")))
            .collect(),
        Value::Object(_) => {
            match optional_value_from_keys(value, &["public_input", "public_inputs"]) {
                Some(inner) => parse_public_inputs(inner),
                None => Err(Error::MissingInput(
                    "public input JSON object requires public_input or public_inputs".to_string(),
                )),
            }
        }
        _ => parse_public_input_value(value, "public_input").map(|value| vec![value]),
    }
}

fn parse_public_input_value(value: &Value, field_name: &str) -> Result<String> {
    match value {
        Value::String(raw) => parse_public_input_string(raw, field_name),
        Value::Number(num) => {
            let decimal = num.to_string();
            parse_decimal(&decimal, field_name)?;
            Ok(decimal)
        }
        _ => Err(Error::DecimalParse(format!(
            "{field_name} must be string or number"
        ))),
    }
}

fn parse_public_input_string(raw: &str, field_name: &str) -> Result<String> {
    let trimmed = raw.trim();
    let without_prefix = trimmed
        .strip_prefix("0x")
        .or_else(|| trimmed.strip_prefix("0X"))
        .unwrap_or(trimmed);

    if without_prefix.len() >= 64
        && without_prefix.len().is_multiple_of(64)
        && without_prefix.chars().all(|c| c.is_ascii_hexdigit())
    {
        let bytes = decode_hex(without_prefix, field_name)?;
        if bytes.len() == 32 {
            return Ok(BigUint::from_bytes_le(&bytes).to_string());
        }
        return Err(Error::MissingInput(format!(
            "{field_name} contains multiple public inputs; use an array or public_inputs field"
        )));
    }

    parse_scalar(trimmed, field_name)
}

fn parse_scalar(raw: &str, field_name: &str) -> Result<String> {
    let has_hex_prefix = raw.starts_with("0x") || raw.starts_with("0X");
    let value = if has_hex_prefix { &raw[2..] } else { raw };
    if value.chars().all(|c| c.is_ascii_digit()) {
        parse_decimal(value, field_name)?;
        return Ok(value.to_string());
    }
    if value.chars().all(|c| c.is_ascii_hexdigit()) {
        let parsed = BigUint::parse_bytes(value.as_bytes(), 16).ok_or_else(|| {
            Error::DecimalParse(format!("{field_name} could not parse hex scalar {raw}"))
        })?;
        return Ok(parsed.to_string());
    }
    Err(Error::DecimalParse(format!(
        "{field_name} must be decimal or hex string"
    )))
}

fn decode_hex(raw: &str, field: &str) -> Result<Vec<u8>> {
    let hex = raw.trim().trim_start_matches("0x").trim_start_matches("0X");
    if hex.is_empty() {
        return Err(Error::HexParse(format!("{field} must not be empty")));
    }
    if !hex.chars().all(|c| c.is_ascii_hexdigit()) {
        return Err(Error::HexParse(format!(
            "{field} must be a hex string, got {raw}"
        )));
    }
    if !hex.len().is_multiple_of(2) {
        return Err(Error::HexParse(format!(
            "{field} has odd hex length, got {hex}"
        )));
    }
    hex::decode(hex).map_err(|e| Error::HexParse(format!("{field}: {e}")))
}