Skip to main content

export_aptos_verifier_core/
model.rs

1use crate::error::{Error, Result};
2use crate::snarkjs::{
3    validate_curve_match, validate_protocol, validate_public_counts,
4    validate_verification_key_geometry, Proof as LegacyProof, SnarkJsG1, SnarkJsG2,
5    VerificationKey as LegacyVerificationKey,
6};
7
/// Numeric value (point coordinate or public input) kept in its textual form.
///
/// Values are copied verbatim from the snarkjs point types and are neither
/// parsed nor range-checked in this module.
/// NOTE(review): the name suggests base-10 strings as emitted by snarkjs —
/// confirm the encoding used by the non-snarkjs source formats.
pub type DecimalValue = String;
9
/// Pairing-friendly curve targeted by the Groth16 verifier.
///
/// Construct it from a textual name with [`CurveKind::from_name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CurveKind {
    /// BN254. `from_name` also accepts the aliases `bn128` and `altbn128`.
    Bn254,
    /// BLS12-381 (`from_name` accepts `bls12381`, `bls12-381`, `bls12_381`, …).
    Bls12_381,
}
15
16impl CurveKind {
17    pub fn from_name(value: &str) -> Result<Self> {
18        match normalize_curve_name(value).as_str() {
19            "bn128" | "bn254" | "altbn128" => Ok(Self::Bn254),
20            "bls12381" => Ok(Self::Bls12_381),
21            _ => Err(Error::UnsupportedCurve(format!(
22                "unsupported curve: {value}"
23            ))),
24        }
25    }
26
27    pub fn canonical_name(self) -> &'static str {
28        match self {
29            Self::Bn254 => "bn254",
30            Self::Bls12_381 => "bls12381",
31        }
32    }
33}
34
/// Tag recording which input format the verifier inputs were parsed from.
///
/// This module only stores the tag (in `Groth16VerifierInputs::source_format`);
/// it never branches on it. The meaning of each variant is defined by the
/// parsers that set it, which are not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SourceFormat {
    SnarkjsJson,
    Arkworks,
    ArkworksCompact,
    GnarkJson,
    GnarkBinary,
    Sp1Groth16,
}
44
/// G1 point stored as three textual coordinates.
///
/// The fields mirror `SnarkJsG1` one-to-one (see `From<SnarkJsG1>`); no
/// on-curve or range validation happens in this module.
/// NOTE(review): presumably projective coordinates as in snarkjs JSON — confirm.
#[derive(Debug, Clone)]
pub struct Groth16G1Point {
    pub x: DecimalValue,
    pub y: DecimalValue,
    pub z: DecimalValue,
}
51
/// G2 point stored as six textual values: two components per coordinate.
///
/// The fields mirror `SnarkJsG2` one-to-one (see `From<SnarkJsG2>`); no
/// validation happens in this module.
/// NOTE(review): presumably `(x0, x1)` etc. are the two extension-field
/// components of each projective coordinate, in snarkjs order — confirm the
/// component ordering with the consumers of this type.
#[derive(Debug, Clone)]
pub struct Groth16G2Point {
    pub x0: DecimalValue,
    pub x1: DecimalValue,
    pub y0: DecimalValue,
    pub y1: DecimalValue,
    pub z0: DecimalValue,
    pub z1: DecimalValue,
}
61
/// Groth16 verifying key in this crate's format-independent model.
///
/// The `Groth16VerifierInputs` constructors `from_legacy_vk_only` and
/// `from_parts` reject keys whose `ic` length is not `n_public + 1`.
#[derive(Debug, Clone)]
pub struct Groth16VerificationKey {
    /// Number of public inputs the key expects.
    pub n_public: usize,
    /// Alpha element (G1).
    pub vk_alpha_1: Groth16G1Point,
    /// Beta element (G2).
    pub vk_beta_2: Groth16G2Point,
    /// Gamma element (G2).
    pub vk_gamma_2: Groth16G2Point,
    /// Delta element (G2).
    pub vk_delta_2: Groth16G2Point,
    /// IC points; expected length is `n_public + 1`.
    pub ic: Vec<Groth16G1Point>,
}
71
/// Groth16 proof consisting of the `A` (G1), `B` (G2) and `C` (G1) elements.
#[derive(Debug, Clone)]
pub struct Groth16Proof {
    pub pi_a: Groth16G1Point,
    pub pi_b: Groth16G2Point,
    pub pi_c: Groth16G1Point,
}
78
/// Everything needed to emit or exercise a Groth16 verifier, independent of
/// the format it was parsed from.
#[derive(Debug, Clone)]
pub struct Groth16VerifierInputs {
    /// Curve the key and proof live on.
    pub curve: CurveKind,
    /// Protocol name; the constructors default it to `"groth16"`.
    pub protocol: String,
    /// Verifying key.
    pub verifying_key: Groth16VerificationKey,
    /// Optional proof; `None` for key-only inputs (see `has_test_vectors`).
    pub proof: Option<Groth16Proof>,
    /// Public inputs; may be empty when `proof` is `None`.
    pub public_inputs: Vec<DecimalValue>,
    /// Format the inputs were parsed from.
    pub source_format: SourceFormat,
}
88
89impl Groth16VerifierInputs {
90    pub fn from_legacy(
91        vk: LegacyVerificationKey,
92        proof: LegacyProof,
93        public_inputs: Vec<DecimalValue>,
94        source_format: SourceFormat,
95    ) -> Result<Self> {
96        validate_protocol(vk.protocol.as_ref(), proof.protocol.as_ref())?;
97        validate_verification_key_geometry(&vk)?;
98        validate_public_counts(&vk, &public_inputs)?;
99
100        let curve_name = validate_curve_match(vk.curve.as_ref(), proof.curve.as_ref())?;
101        let curve = CurveKind::from_name(&curve_name)?;
102        let protocol = vk
103            .protocol
104            .clone()
105            .or_else(|| proof.protocol.clone())
106            .unwrap_or_else(|| "groth16".to_string());
107
108        Ok(Self {
109            curve,
110            protocol,
111            verifying_key: Groth16VerificationKey {
112                n_public: vk.n_public,
113                vk_alpha_1: vk.vk_alpha_1.into(),
114                vk_beta_2: vk.vk_beta_2.into(),
115                vk_gamma_2: vk.vk_gamma_2.into(),
116                vk_delta_2: vk.vk_delta_2.into(),
117                ic: vk.ic.into_iter().map(Into::into).collect(),
118            },
119            proof: Some(Groth16Proof {
120                pi_a: proof.pi_a.into(),
121                pi_b: proof.pi_b.into(),
122                pi_c: proof.pi_c.into(),
123            }),
124            public_inputs,
125            source_format,
126        })
127    }
128
129    pub fn from_legacy_vk_only(
130        vk: LegacyVerificationKey,
131        public_inputs: Vec<DecimalValue>,
132        source_format: SourceFormat,
133    ) -> Result<Self> {
134        validate_protocol(vk.protocol.as_ref(), None)?;
135        validate_verification_key_geometry(&vk)?;
136
137        if vk.ic.len() != vk.n_public + 1 {
138            return Err(Error::IcLengthMismatch(format!(
139                "expected IC length = nPublic + 1, got {}",
140                vk.ic.len()
141            )));
142        }
143        if !public_inputs.is_empty() && vk.n_public != public_inputs.len() {
144            return Err(Error::PublicInputCountMismatch(format!(
145                "expected nPublic={}, got {}",
146                vk.n_public,
147                public_inputs.len()
148            )));
149        }
150
151        let curve_name = validate_curve_match(vk.curve.as_ref(), None)?;
152        let curve = CurveKind::from_name(&curve_name)?;
153        let protocol = vk.protocol.clone().unwrap_or_else(|| "groth16".to_string());
154
155        Ok(Self {
156            curve,
157            protocol,
158            verifying_key: Groth16VerificationKey {
159                n_public: vk.n_public,
160                vk_alpha_1: vk.vk_alpha_1.into(),
161                vk_beta_2: vk.vk_beta_2.into(),
162                vk_gamma_2: vk.vk_gamma_2.into(),
163                vk_delta_2: vk.vk_delta_2.into(),
164                ic: vk.ic.into_iter().map(Into::into).collect(),
165            },
166            proof: None,
167            public_inputs,
168            source_format,
169        })
170    }
171
172    pub fn from_parts(
173        curve: CurveKind,
174        verifying_key: Groth16VerificationKey,
175        proof: Option<Groth16Proof>,
176        public_inputs: Vec<DecimalValue>,
177        source_format: SourceFormat,
178    ) -> Result<Self> {
179        if verifying_key.ic.len() != verifying_key.n_public + 1 {
180            return Err(Error::IcLengthMismatch(format!(
181                "expected {} IC points, got {}",
182                verifying_key.n_public + 1,
183                verifying_key.ic.len()
184            )));
185        }
186        if proof.is_some() && public_inputs.len() != verifying_key.n_public {
187            return Err(Error::PublicInputCountMismatch(format!(
188                "verification key expects {} public inputs, got {}",
189                verifying_key.n_public,
190                public_inputs.len()
191            )));
192        }
193        if proof.is_none()
194            && !public_inputs.is_empty()
195            && public_inputs.len() != verifying_key.n_public
196        {
197            return Err(Error::PublicInputCountMismatch(format!(
198                "verification key expects {} public inputs, got {}",
199                verifying_key.n_public,
200                public_inputs.len()
201            )));
202        }
203
204        Ok(Self {
205            curve,
206            protocol: "groth16".to_string(),
207            verifying_key,
208            proof,
209            public_inputs,
210            source_format,
211        })
212    }
213
214    pub fn has_test_vectors(&self) -> bool {
215        self.proof.is_some()
216    }
217}
218
219impl From<SnarkJsG1> for Groth16G1Point {
220    fn from(value: SnarkJsG1) -> Self {
221        Self {
222            x: value.x,
223            y: value.y,
224            z: value.z,
225        }
226    }
227}
228
229impl From<SnarkJsG2> for Groth16G2Point {
230    fn from(value: SnarkJsG2) -> Self {
231        Self {
232            x0: value.x0,
233            x1: value.x1,
234            y0: value.y0,
235            y1: value.y1,
236            z0: value.z0,
237            z1: value.z1,
238        }
239    }
240}
241
/// Canonicalises a curve identifier for matching.
///
/// The name is lowercased and every `-` and `_` is dropped, so `"BLS12-381"`
/// becomes `"bls12381"` and `"alt_bn128"` becomes `"altbn128"`. All other
/// characters, including whitespace, are kept as-is.
fn normalize_curve_name(value: &str) -> String {
    let lowered = value.to_lowercase();
    lowered
        .chars()
        .filter(|&ch| ch != '-' && ch != '_')
        .collect()
}