export_aptos_verifier_core/
model.rs1use crate::error::{Error, Result};
2use crate::snarkjs::{
3 validate_curve_match, validate_protocol, validate_public_counts,
4 validate_verification_key_geometry, Proof as LegacyProof, SnarkJsG1, SnarkJsG2,
5 VerificationKey as LegacyVerificationKey,
6};
7
/// Textual scalar value carried through the verifier model.
///
/// NOTE(review): the name suggests decimal-encoded field elements (as snarkjs
/// emits them); nothing here enforces that — confirm against the loaders.
pub type DecimalValue = std::string::String;
9
/// Pairing-friendly curve a Groth16 key/proof is defined over.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum CurveKind {
    /// BN254 (also accepted by name as `bn128` / `alt_bn128`).
    Bn254,
    /// BLS12-381.
    Bls12_381,
}
15
16impl CurveKind {
17 pub fn from_name(value: &str) -> Result<Self> {
18 match normalize_curve_name(value).as_str() {
19 "bn128" | "bn254" | "altbn128" => Ok(Self::Bn254),
20 "bls12381" => Ok(Self::Bls12_381),
21 _ => Err(Error::UnsupportedCurve(format!(
22 "unsupported curve: {value}"
23 ))),
24 }
25 }
26
27 pub fn canonical_name(self) -> &'static str {
28 match self {
29 Self::Bn254 => "bn254",
30 Self::Bls12_381 => "bls12381",
31 }
32 }
33}
34
/// Records which input encoding a set of verifier inputs was built from.
///
/// Each `Groth16VerifierInputs` constructor takes this from its caller and
/// stores it unchanged.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SourceFormat {
    /// snarkjs JSON artifacts.
    SnarkjsJson,
    /// arkworks serialization.
    Arkworks,
    /// arkworks compact serialization (presumably compressed points — confirm
    /// against the arkworks loader).
    ArkworksCompact,
}
41
42#[derive(Debug, Clone)]
43pub struct Groth16G1Point {
44 pub x: DecimalValue,
45 pub y: DecimalValue,
46 pub z: DecimalValue,
47}
48
49#[derive(Debug, Clone)]
50pub struct Groth16G2Point {
51 pub x0: DecimalValue,
52 pub x1: DecimalValue,
53 pub y0: DecimalValue,
54 pub y1: DecimalValue,
55 pub z0: DecimalValue,
56 pub z1: DecimalValue,
57}
58
59#[derive(Debug, Clone)]
60pub struct Groth16VerificationKey {
61 pub n_public: usize,
62 pub vk_alpha_1: Groth16G1Point,
63 pub vk_beta_2: Groth16G2Point,
64 pub vk_gamma_2: Groth16G2Point,
65 pub vk_delta_2: Groth16G2Point,
66 pub ic: Vec<Groth16G1Point>,
67}
68
69#[derive(Debug, Clone)]
70pub struct Groth16Proof {
71 pub pi_a: Groth16G1Point,
72 pub pi_b: Groth16G2Point,
73 pub pi_c: Groth16G1Point,
74}
75
76#[derive(Debug, Clone)]
77pub struct Groth16VerifierInputs {
78 pub curve: CurveKind,
79 pub protocol: String,
80 pub verifying_key: Groth16VerificationKey,
81 pub proof: Option<Groth16Proof>,
82 pub public_inputs: Vec<DecimalValue>,
83 pub source_format: SourceFormat,
84}
85
86impl Groth16VerifierInputs {
87 pub fn from_legacy(
88 vk: LegacyVerificationKey,
89 proof: LegacyProof,
90 public_inputs: Vec<DecimalValue>,
91 source_format: SourceFormat,
92 ) -> Result<Self> {
93 validate_protocol(vk.protocol.as_ref(), proof.protocol.as_ref())?;
94 validate_verification_key_geometry(&vk)?;
95 validate_public_counts(&vk, &public_inputs)?;
96
97 let curve_name = validate_curve_match(vk.curve.as_ref(), proof.curve.as_ref())?;
98 let curve = CurveKind::from_name(&curve_name)?;
99 let protocol = vk
100 .protocol
101 .clone()
102 .or_else(|| proof.protocol.clone())
103 .unwrap_or_else(|| "groth16".to_string());
104
105 Ok(Self {
106 curve,
107 protocol,
108 verifying_key: Groth16VerificationKey {
109 n_public: vk.n_public,
110 vk_alpha_1: vk.vk_alpha_1.into(),
111 vk_beta_2: vk.vk_beta_2.into(),
112 vk_gamma_2: vk.vk_gamma_2.into(),
113 vk_delta_2: vk.vk_delta_2.into(),
114 ic: vk.ic.into_iter().map(Into::into).collect(),
115 },
116 proof: Some(Groth16Proof {
117 pi_a: proof.pi_a.into(),
118 pi_b: proof.pi_b.into(),
119 pi_c: proof.pi_c.into(),
120 }),
121 public_inputs,
122 source_format,
123 })
124 }
125
126 pub fn from_legacy_vk_only(
127 vk: LegacyVerificationKey,
128 public_inputs: Vec<DecimalValue>,
129 source_format: SourceFormat,
130 ) -> Result<Self> {
131 validate_protocol(vk.protocol.as_ref(), None)?;
132 validate_verification_key_geometry(&vk)?;
133
134 if vk.ic.len() != vk.n_public + 1 {
135 return Err(Error::IcLengthMismatch(format!(
136 "expected IC length = nPublic + 1, got {}",
137 vk.ic.len()
138 )));
139 }
140 if !public_inputs.is_empty() && vk.n_public != public_inputs.len() {
141 return Err(Error::PublicInputCountMismatch(format!(
142 "expected nPublic={}, got {}",
143 vk.n_public,
144 public_inputs.len()
145 )));
146 }
147
148 let curve_name = validate_curve_match(vk.curve.as_ref(), None)?;
149 let curve = CurveKind::from_name(&curve_name)?;
150 let protocol = vk.protocol.clone().unwrap_or_else(|| "groth16".to_string());
151
152 Ok(Self {
153 curve,
154 protocol,
155 verifying_key: Groth16VerificationKey {
156 n_public: vk.n_public,
157 vk_alpha_1: vk.vk_alpha_1.into(),
158 vk_beta_2: vk.vk_beta_2.into(),
159 vk_gamma_2: vk.vk_gamma_2.into(),
160 vk_delta_2: vk.vk_delta_2.into(),
161 ic: vk.ic.into_iter().map(Into::into).collect(),
162 },
163 proof: None,
164 public_inputs,
165 source_format,
166 })
167 }
168
169 pub fn from_parts(
170 curve: CurveKind,
171 verifying_key: Groth16VerificationKey,
172 proof: Option<Groth16Proof>,
173 public_inputs: Vec<DecimalValue>,
174 source_format: SourceFormat,
175 ) -> Result<Self> {
176 if verifying_key.ic.len() != verifying_key.n_public + 1 {
177 return Err(Error::IcLengthMismatch(format!(
178 "expected {} IC points, got {}",
179 verifying_key.n_public + 1,
180 verifying_key.ic.len()
181 )));
182 }
183 if proof.is_some() && public_inputs.len() != verifying_key.n_public {
184 return Err(Error::PublicInputCountMismatch(format!(
185 "verification key expects {} public inputs, got {}",
186 verifying_key.n_public,
187 public_inputs.len()
188 )));
189 }
190 if proof.is_none()
191 && !public_inputs.is_empty()
192 && public_inputs.len() != verifying_key.n_public
193 {
194 return Err(Error::PublicInputCountMismatch(format!(
195 "verification key expects {} public inputs, got {}",
196 verifying_key.n_public,
197 public_inputs.len()
198 )));
199 }
200
201 Ok(Self {
202 curve,
203 protocol: "groth16".to_string(),
204 verifying_key,
205 proof,
206 public_inputs,
207 source_format,
208 })
209 }
210
211 pub fn has_test_vectors(&self) -> bool {
212 self.proof.is_some()
213 }
214}
215
216impl From<SnarkJsG1> for Groth16G1Point {
217 fn from(value: SnarkJsG1) -> Self {
218 Self {
219 x: value.x,
220 y: value.y,
221 z: value.z,
222 }
223 }
224}
225
226impl From<SnarkJsG2> for Groth16G2Point {
227 fn from(value: SnarkJsG2) -> Self {
228 Self {
229 x0: value.x0,
230 x1: value.x1,
231 y0: value.y0,
232 y1: value.y1,
233 z0: value.z0,
234 z1: value.z1,
235 }
236 }
237}
238
/// Canonicalises a curve name for matching in `CurveKind::from_name`.
///
/// Surrounding whitespace is trimmed, the text is lower-cased, and every `-`
/// and `_` is removed. For example, `" BN-254 "`, `"bn_254"` and `"bn254"` all
/// become `"bn254"`.
fn normalize_curve_name(value: &str) -> String {
    // Curve names usually come from hand-edited JSON or CLI flags. Stray
    // whitespace should not make an otherwise valid name "unsupported".
    value.trim().to_lowercase().replace(['-', '_'], "")
}