samaharam 0.1.0

Scalable heterogeneous zero-knowledge proof aggregation for EVM chains
Documentation
//! Adapter for snarkjs/circom Groth16 proofs.
//!
//! Supports proofs generated by [snarkjs](https://github.com/iden3/snarkjs)
//! using the BN254 (bn128) curve.

use super::external::{AdapterError, ExternalProof, ProofMetadata};
use super::parsing;
use crate::backend::bn254::Bn254;
use crate::crypto::AccumulatorInstance;
use group::Curve;
use halo2curves::bn256::{Fr, G1Affine, G1};

/// Groth16 proof from snarkjs/circom.
///
/// Holds the three Groth16 proof elements (`A`, `B`, `C`) together with the
/// public inputs that accompany the proof.
///
/// # Format
///
/// snarkjs outputs proofs as JSON with the following structure:
/// ```json
/// {
///   "pi_a": ["x", "y", "1"],
///   "pi_b": [["x0", "x1"], ["y0", "y1"], ["1", "0"]],
///   "pi_c": ["x", "y", "1"],
///   "protocol": "groth16",
///   "curve": "bn128"
/// }
/// ```
#[derive(Debug, Clone)]
pub struct SnarkjsProof {
    /// A point (G1).
    pub pi_a: G1Affine,
    /// B point (G2) - stored as bytes since G2 parsing is complex.
    ///
    /// Both `from_json` and `mock` currently fill this with zeros, so it
    /// carries no information for proofs built through those constructors.
    /// NOTE(review): the byte layout expected here is not defined in this
    /// file - confirm against whatever consumes it.
    pub pi_b_bytes: [u8; 128],
    /// C point (G1).
    pub pi_c: G1Affine,
    /// Public inputs, in the order they were supplied.
    pub public_inputs: Vec<Fr>,
}

impl SnarkjsProof {
    /// Create from individual components.
    pub fn new(pi_a: G1Affine, pi_b_bytes: [u8; 128], pi_c: G1Affine, public_inputs: Vec<Fr>) -> Self {
        Self {
            pi_a,
            pi_b_bytes,
            pi_c,
            public_inputs,
        }
    }

    /// Parse from snarkjs JSON format.
    ///
    /// # Example JSON
    /// ```json
    /// {
    ///   "pi_a": ["123...", "456...", "1"],
    ///   "pi_b": [["x0", "x1"], ["y0", "y1"], ["1", "0"]],
    ///   "pi_c": ["789...", "012...", "1"],
    ///   "publicSignals": ["21", "55"],
    ///   "protocol": "groth16",
    ///   "curve": "bn128"
    /// }
    /// ```
    pub fn from_json(json: &str) -> Result<Self, AdapterError> {
        // Parse JSON
        let parsed: serde_json::Value = serde_json::from_str(json)
            .map_err(|e| AdapterError::ParseError(format!("Invalid JSON: {}", e)))?;

        // Verify protocol
        if let Some(protocol) = parsed.get("protocol").and_then(|v| v.as_str()) {
            if protocol != "groth16" {
                return Err(AdapterError::Unsupported(format!(
                    "Expected groth16, got {}",
                    protocol
                )));
            }
        }

        // Verify curve
        if let Some(curve) = parsed.get("curve").and_then(|v| v.as_str()) {
            if curve != "bn128" && curve != "bn254" {
                return Err(AdapterError::Unsupported(format!(
                    "Expected bn128/bn254, got {}",
                    curve
                )));
            }
        }

        // Parse pi_a
        let pi_a = Self::parse_g1_json(&parsed["pi_a"])
            .map_err(|e| AdapterError::InvalidPoint(format!("pi_a: {}", e)))?;

        // Parse pi_c
        let pi_c = Self::parse_g1_json(&parsed["pi_c"])
            .map_err(|e| AdapterError::InvalidPoint(format!("pi_c: {}", e)))?;

        // Parse pi_b (G2 as bytes - placeholder for now)
        let pi_b_bytes = [0u8; 128];

        // Parse public inputs
        let public_inputs = Self::parse_public_inputs(&parsed)
            .map_err(|e| AdapterError::ParseError(format!("public inputs: {}", e)))?;

        Ok(Self {
            pi_a,
            pi_b_bytes,
            pi_c,
            public_inputs,
        })
    }

    /// Parse a G1 point from snarkjs JSON array ["x", "y", "1"].
    fn parse_g1_json(value: &serde_json::Value) -> Result<G1Affine, String> {
        let arr = value.as_array().ok_or("Expected array for G1")?;
        if arr.len() < 2 {
            return Err("G1 point needs at least x, y".to_string());
        }

        let x_str = arr[0].as_str().ok_or("x must be string")?;
        let y_str = arr[1].as_str().ok_or("y must be string")?;

        // Parse as decimal field elements
        let x = parsing::parse_fr_decimal(x_str)?;
        let y = parsing::parse_fr_decimal(y_str)?;

        // Convert Fr to G1 point: use x*G + y as approximation
        // Note: This isn't correct for Groth16 - we'd need to parse Fq, not Fr
        // For proper implementation, we'd need Fq parsing
        let point = (G1::generator() * x + G1::generator() * y).to_affine();
        Ok(point)
    }

    fn parse_public_inputs(parsed: &serde_json::Value) -> Result<Vec<Fr>, String> {
        let mut inputs = Vec::new();

        // Try different field names that snarkjs uses
        let arr = parsed
            .get("publicSignals")
            .or_else(|| parsed.get("public_inputs"))
            .or_else(|| parsed.get("inputs"))
            .and_then(|v| v.as_array());

        if let Some(arr) = arr {
            for val in arr {
                let s = val.as_str().ok_or("Input must be string")?;
                inputs.push(parsing::parse_fr_decimal(s)?);
            }
        }

        Ok(inputs)
    }

    /// Create a mock proof for testing.
    pub fn mock(public_inputs: Vec<Fr>) -> Self {
        Self::new(
            G1::generator().to_affine(),
            [0u8; 128],
            G1::generator().to_affine(),
            public_inputs,
        )
    }
}

impl ExternalProof<Bn254> for SnarkjsProof {
    fn to_accumulator_instances(&self) -> Result<Vec<AccumulatorInstance<Bn254>>, AdapterError> {
        use ff::Field;

        // Groth16 proofs use pairings directly, not KZG openings.
        // We create a synthetic instance that captures the proof structure.
        //
        // For aggregation, we treat:
        // - A as the primary commitment
        // - First public input as the evaluation
        // - C as the quotient (opening proof)

        let evaluation = self.public_inputs.first().copied().unwrap_or(Fr::ONE);
        
        // Derive point from public inputs hash
        let point = if self.public_inputs.len() >= 2 {
            self.public_inputs[1]
        } else {
            Fr::from(7u64)
        };

        let instance = AccumulatorInstance {
            commitment: self.pi_a,
            evaluation,
            point,
            quotient: self.pi_c,
        };

        Ok(vec![instance])
    }

    fn public_inputs(&self) -> &[Fr] {
        &self.public_inputs
    }

    fn metadata(&self) -> ProofMetadata {
        ProofMetadata {
            system: "snarkjs",
            proof_type: "groth16",
            curve: "bn254",
            num_public_inputs: self.public_inputs.len(),
        }
    }

    fn validate_format(&self) -> Result<(), AdapterError> {
        use group::prime::PrimeCurveAffine;

        // Check A is on curve (not identity for valid proof)
        if self.pi_a.is_identity().into() {
            return Err(AdapterError::InvalidPoint("pi_a is identity".to_string()));
        }

        // Check C is on curve
        if self.pi_c.is_identity().into() {
            return Err(AdapterError::InvalidPoint("pi_c is identity".to_string()));
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A mock proof exposes its inputs and reports the snarkjs metadata.
    #[test]
    fn snarkjs_proof_creation() {
        let proof = SnarkjsProof::mock(vec![Fr::from(42u64)]);

        assert_eq!(proof.public_inputs().len(), 1);
        assert_eq!(proof.metadata().system, "snarkjs");
        assert_eq!(proof.metadata().num_public_inputs, 1);
    }

    /// The first two public inputs become the evaluation and the point.
    #[test]
    fn snarkjs_to_accumulator() {
        let proof = SnarkjsProof::mock(vec![Fr::from(100u64), Fr::from(200u64)]);

        let instances = proof.to_accumulator_instances().unwrap();
        assert_eq!(instances.len(), 1);
        assert_eq!(instances[0].evaluation, Fr::from(100u64));
        assert_eq!(instances[0].point, Fr::from(200u64));
        assert_eq!(instances[0].commitment, proof.pi_a);
        assert_eq!(instances[0].quotient, proof.pi_c);
    }

    /// Without public inputs the fixed fallbacks (1 and 7) are used.
    #[test]
    fn snarkjs_to_accumulator_defaults() {
        let proof = SnarkjsProof::mock(vec![]);

        let instances = proof.to_accumulator_instances().unwrap();
        assert_eq!(instances.len(), 1);
        assert_eq!(instances[0].evaluation, Fr::from(1u64));
        assert_eq!(instances[0].point, Fr::from(7u64));
    }

    #[test]
    fn snarkjs_validate_format() {
        let proof = SnarkjsProof::mock(vec![]);
        assert!(proof.validate_format().is_ok());
    }

    /// An identity element in the proof must be rejected.
    #[test]
    fn snarkjs_validate_format_rejects_identity() {
        use group::prime::PrimeCurveAffine;

        let mut proof = SnarkjsProof::mock(vec![]);
        proof.pi_a = G1Affine::identity();
        assert!(matches!(
            proof.validate_format(),
            Err(AdapterError::InvalidPoint(_))
        ));

        let mut proof = SnarkjsProof::mock(vec![]);
        proof.pi_c = G1Affine::identity();
        assert!(matches!(
            proof.validate_format(),
            Err(AdapterError::InvalidPoint(_))
        ));
    }

    #[test]
    fn snarkjs_from_json_basic() {
        let json = r#"{
            "pi_a": ["1", "2", "1"],
            "pi_b": [["1", "0"], ["1", "0"], ["1", "0"]],
            "pi_c": ["3", "4", "1"],
            "publicSignals": ["21", "3", "7"],
            "protocol": "groth16",
            "curve": "bn128"
        }"#;

        let proof = SnarkjsProof::from_json(json).unwrap();
        assert_eq!(proof.public_inputs.len(), 3);
        assert_eq!(proof.public_inputs[0], Fr::from(21u64));
        assert_eq!(proof.public_inputs[1], Fr::from(3u64));
        assert_eq!(proof.public_inputs[2], Fr::from(7u64));
    }

    /// Omitting the public signals is allowed and yields no inputs.
    #[test]
    fn snarkjs_from_json_without_public_signals() {
        let json = r#"{
            "pi_a": ["1", "2", "1"],
            "pi_b": [["1", "0"], ["1", "0"], ["1", "0"]],
            "pi_c": ["3", "4", "1"]
        }"#;

        let proof = SnarkjsProof::from_json(json).unwrap();
        assert!(proof.public_inputs.is_empty());
    }

    #[test]
    fn snarkjs_rejects_wrong_protocol() {
        let json = r#"{
            "pi_a": ["1", "2", "1"],
            "pi_b": [["1", "0"], ["1", "0"], ["1", "0"]],
            "pi_c": ["3", "4", "1"],
            "protocol": "plonk"
        }"#;

        // Pin the variant so an unrelated parse failure cannot satisfy this.
        let result = SnarkjsProof::from_json(json);
        assert!(matches!(result, Err(AdapterError::Unsupported(_))));
    }

    #[test]
    fn snarkjs_rejects_wrong_curve() {
        let json = r#"{
            "pi_a": ["1", "2", "1"],
            "pi_b": [["1", "0"], ["1", "0"], ["1", "0"]],
            "pi_c": ["3", "4", "1"],
            "protocol": "groth16",
            "curve": "bls12381"
        }"#;

        let result = SnarkjsProof::from_json(json);
        assert!(matches!(result, Err(AdapterError::Unsupported(_))));
    }

    #[test]
    fn snarkjs_rejects_invalid_json() {
        let result = SnarkjsProof::from_json("not json");
        assert!(matches!(result, Err(AdapterError::ParseError(_))));
    }

    #[test]
    fn snarkjs_rejects_missing_pi_a() {
        let json = r#"{
            "pi_b": [["1", "0"], ["1", "0"], ["1", "0"]],
            "pi_c": ["3", "4", "1"]
        }"#;

        let result = SnarkjsProof::from_json(json);
        assert!(matches!(result, Err(AdapterError::InvalidPoint(_))));
    }
}