export_aptos_verifier_core/
model.rs1use crate::error::{Error, Result};
2use crate::snarkjs::{
3 validate_curve_match, validate_protocol, validate_public_counts,
4 validate_verification_key_geometry, Proof as LegacyProof, SnarkJsG1, SnarkJsG2,
5 VerificationKey as LegacyVerificationKey,
6};
7
/// String-encoded numeric value (the name indicates base-10 digits).
///
/// Values of this type are only moved between structures in this module;
/// they are never parsed or normalized here.
pub type DecimalValue = std::string::String;
9
/// Pairing-friendly curve that a Groth16 verifying key is defined over.
///
/// `Hash` is derived alongside `Eq` so curves can be used as keys in maps
/// and sets (for example, per-curve dispatch tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CurveKind {
    /// BN254; `CurveKind::from_name` also accepts `bn128` and `altbn128`.
    Bn254,
    /// BLS12-381.
    Bls12_381,
}
15
16impl CurveKind {
17 pub fn from_name(value: &str) -> Result<Self> {
18 match normalize_curve_name(value).as_str() {
19 "bn128" | "bn254" | "altbn128" => Ok(Self::Bn254),
20 "bls12381" => Ok(Self::Bls12_381),
21 _ => Err(Error::UnsupportedCurve(format!(
22 "unsupported curve: {value}"
23 ))),
24 }
25 }
26
27 pub fn canonical_name(self) -> &'static str {
28 match self {
29 Self::Bn254 => "bn254",
30 Self::Bls12_381 => "bls12381",
31 }
32 }
33}
34
/// Tag recording which input format the verifier inputs were decoded from.
///
/// The tag is supplied by the caller of the `Groth16VerifierInputs`
/// constructors and stored verbatim. `Hash` is derived so formats can key
/// maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SourceFormat {
    SnarkjsJson,
    Arkworks,
    ArkworksCompact,
    GnarkJson,
    GnarkBinary,
    Sp1Groth16,
}
44
45#[derive(Debug, Clone)]
46pub struct Groth16G1Point {
47 pub x: DecimalValue,
48 pub y: DecimalValue,
49 pub z: DecimalValue,
50}
51
52#[derive(Debug, Clone)]
53pub struct Groth16G2Point {
54 pub x0: DecimalValue,
55 pub x1: DecimalValue,
56 pub y0: DecimalValue,
57 pub y1: DecimalValue,
58 pub z0: DecimalValue,
59 pub z1: DecimalValue,
60}
61
62#[derive(Debug, Clone)]
63pub struct Groth16VerificationKey {
64 pub n_public: usize,
65 pub vk_alpha_1: Groth16G1Point,
66 pub vk_beta_2: Groth16G2Point,
67 pub vk_gamma_2: Groth16G2Point,
68 pub vk_delta_2: Groth16G2Point,
69 pub ic: Vec<Groth16G1Point>,
70}
71
72#[derive(Debug, Clone)]
73pub struct Groth16Proof {
74 pub pi_a: Groth16G1Point,
75 pub pi_b: Groth16G2Point,
76 pub pi_c: Groth16G1Point,
77}
78
79#[derive(Debug, Clone)]
80pub struct Groth16VerifierInputs {
81 pub curve: CurveKind,
82 pub protocol: String,
83 pub verifying_key: Groth16VerificationKey,
84 pub proof: Option<Groth16Proof>,
85 pub public_inputs: Vec<DecimalValue>,
86 pub source_format: SourceFormat,
87}
88
89impl Groth16VerifierInputs {
90 pub fn from_legacy(
91 vk: LegacyVerificationKey,
92 proof: LegacyProof,
93 public_inputs: Vec<DecimalValue>,
94 source_format: SourceFormat,
95 ) -> Result<Self> {
96 validate_protocol(vk.protocol.as_ref(), proof.protocol.as_ref())?;
97 validate_verification_key_geometry(&vk)?;
98 validate_public_counts(&vk, &public_inputs)?;
99
100 let curve_name = validate_curve_match(vk.curve.as_ref(), proof.curve.as_ref())?;
101 let curve = CurveKind::from_name(&curve_name)?;
102 let protocol = vk
103 .protocol
104 .clone()
105 .or_else(|| proof.protocol.clone())
106 .unwrap_or_else(|| "groth16".to_string());
107
108 Ok(Self {
109 curve,
110 protocol,
111 verifying_key: Groth16VerificationKey {
112 n_public: vk.n_public,
113 vk_alpha_1: vk.vk_alpha_1.into(),
114 vk_beta_2: vk.vk_beta_2.into(),
115 vk_gamma_2: vk.vk_gamma_2.into(),
116 vk_delta_2: vk.vk_delta_2.into(),
117 ic: vk.ic.into_iter().map(Into::into).collect(),
118 },
119 proof: Some(Groth16Proof {
120 pi_a: proof.pi_a.into(),
121 pi_b: proof.pi_b.into(),
122 pi_c: proof.pi_c.into(),
123 }),
124 public_inputs,
125 source_format,
126 })
127 }
128
129 pub fn from_legacy_vk_only(
130 vk: LegacyVerificationKey,
131 public_inputs: Vec<DecimalValue>,
132 source_format: SourceFormat,
133 ) -> Result<Self> {
134 validate_protocol(vk.protocol.as_ref(), None)?;
135 validate_verification_key_geometry(&vk)?;
136
137 if vk.ic.len() != vk.n_public + 1 {
138 return Err(Error::IcLengthMismatch(format!(
139 "expected IC length = nPublic + 1, got {}",
140 vk.ic.len()
141 )));
142 }
143 if !public_inputs.is_empty() && vk.n_public != public_inputs.len() {
144 return Err(Error::PublicInputCountMismatch(format!(
145 "expected nPublic={}, got {}",
146 vk.n_public,
147 public_inputs.len()
148 )));
149 }
150
151 let curve_name = validate_curve_match(vk.curve.as_ref(), None)?;
152 let curve = CurveKind::from_name(&curve_name)?;
153 let protocol = vk.protocol.clone().unwrap_or_else(|| "groth16".to_string());
154
155 Ok(Self {
156 curve,
157 protocol,
158 verifying_key: Groth16VerificationKey {
159 n_public: vk.n_public,
160 vk_alpha_1: vk.vk_alpha_1.into(),
161 vk_beta_2: vk.vk_beta_2.into(),
162 vk_gamma_2: vk.vk_gamma_2.into(),
163 vk_delta_2: vk.vk_delta_2.into(),
164 ic: vk.ic.into_iter().map(Into::into).collect(),
165 },
166 proof: None,
167 public_inputs,
168 source_format,
169 })
170 }
171
172 pub fn from_parts(
173 curve: CurveKind,
174 verifying_key: Groth16VerificationKey,
175 proof: Option<Groth16Proof>,
176 public_inputs: Vec<DecimalValue>,
177 source_format: SourceFormat,
178 ) -> Result<Self> {
179 if verifying_key.ic.len() != verifying_key.n_public + 1 {
180 return Err(Error::IcLengthMismatch(format!(
181 "expected {} IC points, got {}",
182 verifying_key.n_public + 1,
183 verifying_key.ic.len()
184 )));
185 }
186 if proof.is_some() && public_inputs.len() != verifying_key.n_public {
187 return Err(Error::PublicInputCountMismatch(format!(
188 "verification key expects {} public inputs, got {}",
189 verifying_key.n_public,
190 public_inputs.len()
191 )));
192 }
193 if proof.is_none()
194 && !public_inputs.is_empty()
195 && public_inputs.len() != verifying_key.n_public
196 {
197 return Err(Error::PublicInputCountMismatch(format!(
198 "verification key expects {} public inputs, got {}",
199 verifying_key.n_public,
200 public_inputs.len()
201 )));
202 }
203
204 Ok(Self {
205 curve,
206 protocol: "groth16".to_string(),
207 verifying_key,
208 proof,
209 public_inputs,
210 source_format,
211 })
212 }
213
214 pub fn has_test_vectors(&self) -> bool {
215 self.proof.is_some()
216 }
217}
218
219impl From<SnarkJsG1> for Groth16G1Point {
220 fn from(value: SnarkJsG1) -> Self {
221 Self {
222 x: value.x,
223 y: value.y,
224 z: value.z,
225 }
226 }
227}
228
229impl From<SnarkJsG2> for Groth16G2Point {
230 fn from(value: SnarkJsG2) -> Self {
231 Self {
232 x0: value.x0,
233 x1: value.x1,
234 y0: value.y0,
235 y1: value.y1,
236 z0: value.z0,
237 z1: value.z1,
238 }
239 }
240}
241
/// Normalizes a curve name for matching.
///
/// Trims surrounding whitespace, lowercases the name, and drops `-` and `_`
/// separators, so `"BLS12-381"`, `"bls12_381"` and `" bls12381\n"` all become
/// `"bls12381"`. Trimming lets names read from hand-edited inputs with stray
/// whitespace still match.
fn normalize_curve_name(value: &str) -> String {
    value.trim().to_lowercase().replace(['-', '_'], "")
}