Skip to main content

export_aptos_verifier_core/
model.rs

1use crate::error::{Error, Result};
2use crate::snarkjs::{
3    validate_curve_match, validate_protocol, validate_public_counts,
4    validate_verification_key_geometry, Proof as LegacyProof, SnarkJsG1, SnarkJsG2,
5    VerificationKey as LegacyVerificationKey,
6};
7
/// Owned string form of a numeric value.
///
/// Used in this module for point coordinates and for public inputs. No parsing
/// or validation happens at this level. NOTE(review): the name suggests
/// base-10 digit strings (as in snarkjs JSON) — confirm against producers.
pub type DecimalValue = std::string::String;
9
/// Curves that `CurveKind::from_name` can recognise.
///
/// Each variant maps back to a single lowercase identifier via
/// `CurveKind::canonical_name`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CurveKind {
    /// BN254; after normalization the names `bn128`, `bn254` and `altbn128`
    /// all resolve here.
    Bn254,
    /// BLS12-381; after normalization the name `bls12381` resolves here.
    Bls12_381,
}
15
16impl CurveKind {
17    pub fn from_name(value: &str) -> Result<Self> {
18        match normalize_curve_name(value).as_str() {
19            "bn128" | "bn254" | "altbn128" => Ok(Self::Bn254),
20            "bls12381" => Ok(Self::Bls12_381),
21            _ => Err(Error::UnsupportedCurve(format!(
22                "unsupported curve: {value}"
23            ))),
24        }
25    }
26
27    pub fn canonical_name(self) -> &'static str {
28        match self {
29            Self::Bn254 => "bn254",
30            Self::Bls12_381 => "bls12381",
31        }
32    }
33}
34
/// Serialization format the verifier inputs were supplied in.
///
/// Provided by the caller of `Groth16VerifierInputs::from_legacy` and stored
/// as-is on the resulting value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SourceFormat {
    /// snarkjs JSON (inferred from the variant name).
    SnarkjsJson,
    /// Arkworks compact encoding (inferred from the variant name).
    ArkworksCompact,
}
40
41#[derive(Debug, Clone)]
42pub struct Groth16G1Point {
43    pub x: DecimalValue,
44    pub y: DecimalValue,
45    pub z: DecimalValue,
46}
47
48#[derive(Debug, Clone)]
49pub struct Groth16G2Point {
50    pub x0: DecimalValue,
51    pub x1: DecimalValue,
52    pub y0: DecimalValue,
53    pub y1: DecimalValue,
54    pub z0: DecimalValue,
55    pub z1: DecimalValue,
56}
57
58#[derive(Debug, Clone)]
59pub struct Groth16VerificationKey {
60    pub n_public: usize,
61    pub vk_alpha_1: Groth16G1Point,
62    pub vk_beta_2: Groth16G2Point,
63    pub vk_gamma_2: Groth16G2Point,
64    pub vk_delta_2: Groth16G2Point,
65    pub ic: Vec<Groth16G1Point>,
66}
67
68#[derive(Debug, Clone)]
69pub struct Groth16Proof {
70    pub pi_a: Groth16G1Point,
71    pub pi_b: Groth16G2Point,
72    pub pi_c: Groth16G1Point,
73}
74
75#[derive(Debug, Clone)]
76pub struct Groth16VerifierInputs {
77    pub curve: CurveKind,
78    pub protocol: String,
79    pub verifying_key: Groth16VerificationKey,
80    pub proof: Groth16Proof,
81    pub public_inputs: Vec<DecimalValue>,
82    pub source_format: SourceFormat,
83}
84
85impl Groth16VerifierInputs {
86    pub fn from_legacy(
87        vk: LegacyVerificationKey,
88        proof: LegacyProof,
89        public_inputs: Vec<DecimalValue>,
90        source_format: SourceFormat,
91    ) -> Result<Self> {
92        validate_protocol(vk.protocol.as_ref(), proof.protocol.as_ref())?;
93        validate_verification_key_geometry(&vk)?;
94        validate_public_counts(&vk, &public_inputs)?;
95
96        let curve_name = validate_curve_match(vk.curve.as_ref(), proof.curve.as_ref())?;
97        let curve = CurveKind::from_name(&curve_name)?;
98        let protocol = vk
99            .protocol
100            .clone()
101            .or_else(|| proof.protocol.clone())
102            .unwrap_or_else(|| "groth16".to_string());
103
104        Ok(Self {
105            curve,
106            protocol,
107            verifying_key: Groth16VerificationKey {
108                n_public: vk.n_public,
109                vk_alpha_1: vk.vk_alpha_1.into(),
110                vk_beta_2: vk.vk_beta_2.into(),
111                vk_gamma_2: vk.vk_gamma_2.into(),
112                vk_delta_2: vk.vk_delta_2.into(),
113                ic: vk.ic.into_iter().map(Into::into).collect(),
114            },
115            proof: Groth16Proof {
116                pi_a: proof.pi_a.into(),
117                pi_b: proof.pi_b.into(),
118                pi_c: proof.pi_c.into(),
119            },
120            public_inputs,
121            source_format,
122        })
123    }
124}
125
126impl From<SnarkJsG1> for Groth16G1Point {
127    fn from(value: SnarkJsG1) -> Self {
128        Self {
129            x: value.x,
130            y: value.y,
131            z: value.z,
132        }
133    }
134}
135
136impl From<SnarkJsG2> for Groth16G2Point {
137    fn from(value: SnarkJsG2) -> Self {
138        Self {
139            x0: value.x0,
140            x1: value.x1,
141            y0: value.y0,
142            y1: value.y1,
143            z0: value.z0,
144            z1: value.z1,
145        }
146    }
147}
148
/// Canonicalizes a curve name so that common spelling variants compare equal.
///
/// Leading and trailing whitespace is trimmed, the text is lowercased, and
/// every `-` and `_` is removed. For example, `" BLS12-381 "` becomes
/// `"bls12381"` and `"alt_bn128"` becomes `"altbn128"`.
fn normalize_curve_name(value: &str) -> String {
    // Trim first: without it, a name carrying stray surrounding whitespace
    // (e.g. a trailing newline) can never match any known curve alias.
    value.trim().to_lowercase().replace(['-', '_'], "")
}