Skip to main content

export_aptos_verifier_core/
model.rs

1use crate::error::{Error, Result};
2use crate::snarkjs::{
3    validate_curve_match, validate_protocol, validate_public_counts,
4    validate_verification_key_geometry, Proof as LegacyProof, SnarkJsG1, SnarkJsG2,
5    VerificationKey as LegacyVerificationKey,
6};
7
/// A field element or coordinate carried as text.
///
/// NOTE(review): the name suggests base-10 digit strings (snarkjs style); nothing in
/// this module enforces that — confirm against the parsers that fill it.
pub type DecimalValue = std::string::String;
9
/// Pairing-friendly curves this crate knows how to describe.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CurveKind {
    /// BN254 (also accepted under the aliases `bn128` / `altbn128`).
    Bn254,
    /// BLS12-381.
    Bls12_381,
}
15
16impl CurveKind {
17    pub fn from_name(value: &str) -> Result<Self> {
18        match normalize_curve_name(value).as_str() {
19            "bn128" | "bn254" | "altbn128" => Ok(Self::Bn254),
20            "bls12381" => Ok(Self::Bls12_381),
21            _ => Err(Error::UnsupportedCurve(format!(
22                "unsupported curve: {value}"
23            ))),
24        }
25    }
26
27    pub fn canonical_name(self) -> &'static str {
28        match self {
29            Self::Bn254 => "bn254",
30            Self::Bls12_381 => "bls12381",
31        }
32    }
33}
34
/// Serialization format the verifier inputs were loaded from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SourceFormat {
    /// snarkjs JSON files.
    SnarkjsJson,
    /// arkworks serialization.
    Arkworks,
    /// arkworks serialization, compact variant.
    /// NOTE(review): presumably compressed point encoding — confirm with the loader.
    ArkworksCompact,
}
41
42#[derive(Debug, Clone)]
43pub struct Groth16G1Point {
44    pub x: DecimalValue,
45    pub y: DecimalValue,
46    pub z: DecimalValue,
47}
48
49#[derive(Debug, Clone)]
50pub struct Groth16G2Point {
51    pub x0: DecimalValue,
52    pub x1: DecimalValue,
53    pub y0: DecimalValue,
54    pub y1: DecimalValue,
55    pub z0: DecimalValue,
56    pub z1: DecimalValue,
57}
58
59#[derive(Debug, Clone)]
60pub struct Groth16VerificationKey {
61    pub n_public: usize,
62    pub vk_alpha_1: Groth16G1Point,
63    pub vk_beta_2: Groth16G2Point,
64    pub vk_gamma_2: Groth16G2Point,
65    pub vk_delta_2: Groth16G2Point,
66    pub ic: Vec<Groth16G1Point>,
67}
68
69#[derive(Debug, Clone)]
70pub struct Groth16Proof {
71    pub pi_a: Groth16G1Point,
72    pub pi_b: Groth16G2Point,
73    pub pi_c: Groth16G1Point,
74}
75
76#[derive(Debug, Clone)]
77pub struct Groth16VerifierInputs {
78    pub curve: CurveKind,
79    pub protocol: String,
80    pub verifying_key: Groth16VerificationKey,
81    pub proof: Option<Groth16Proof>,
82    pub public_inputs: Vec<DecimalValue>,
83    pub source_format: SourceFormat,
84}
85
86impl Groth16VerifierInputs {
87    pub fn from_legacy(
88        vk: LegacyVerificationKey,
89        proof: LegacyProof,
90        public_inputs: Vec<DecimalValue>,
91        source_format: SourceFormat,
92    ) -> Result<Self> {
93        validate_protocol(vk.protocol.as_ref(), proof.protocol.as_ref())?;
94        validate_verification_key_geometry(&vk)?;
95        validate_public_counts(&vk, &public_inputs)?;
96
97        let curve_name = validate_curve_match(vk.curve.as_ref(), proof.curve.as_ref())?;
98        let curve = CurveKind::from_name(&curve_name)?;
99        let protocol = vk
100            .protocol
101            .clone()
102            .or_else(|| proof.protocol.clone())
103            .unwrap_or_else(|| "groth16".to_string());
104
105        Ok(Self {
106            curve,
107            protocol,
108            verifying_key: Groth16VerificationKey {
109                n_public: vk.n_public,
110                vk_alpha_1: vk.vk_alpha_1.into(),
111                vk_beta_2: vk.vk_beta_2.into(),
112                vk_gamma_2: vk.vk_gamma_2.into(),
113                vk_delta_2: vk.vk_delta_2.into(),
114                ic: vk.ic.into_iter().map(Into::into).collect(),
115            },
116            proof: Some(Groth16Proof {
117                pi_a: proof.pi_a.into(),
118                pi_b: proof.pi_b.into(),
119                pi_c: proof.pi_c.into(),
120            }),
121            public_inputs,
122            source_format,
123        })
124    }
125
126    pub fn from_legacy_vk_only(
127        vk: LegacyVerificationKey,
128        public_inputs: Vec<DecimalValue>,
129        source_format: SourceFormat,
130    ) -> Result<Self> {
131        validate_protocol(vk.protocol.as_ref(), None)?;
132        validate_verification_key_geometry(&vk)?;
133
134        if vk.ic.len() != vk.n_public + 1 {
135            return Err(Error::IcLengthMismatch(format!(
136                "expected IC length = nPublic + 1, got {}",
137                vk.ic.len()
138            )));
139        }
140        if !public_inputs.is_empty() && vk.n_public != public_inputs.len() {
141            return Err(Error::PublicInputCountMismatch(format!(
142                "expected nPublic={}, got {}",
143                vk.n_public,
144                public_inputs.len()
145            )));
146        }
147
148        let curve_name = validate_curve_match(vk.curve.as_ref(), None)?;
149        let curve = CurveKind::from_name(&curve_name)?;
150        let protocol = vk.protocol.clone().unwrap_or_else(|| "groth16".to_string());
151
152        Ok(Self {
153            curve,
154            protocol,
155            verifying_key: Groth16VerificationKey {
156                n_public: vk.n_public,
157                vk_alpha_1: vk.vk_alpha_1.into(),
158                vk_beta_2: vk.vk_beta_2.into(),
159                vk_gamma_2: vk.vk_gamma_2.into(),
160                vk_delta_2: vk.vk_delta_2.into(),
161                ic: vk.ic.into_iter().map(Into::into).collect(),
162            },
163            proof: None,
164            public_inputs,
165            source_format,
166        })
167    }
168
169    pub fn from_parts(
170        curve: CurveKind,
171        verifying_key: Groth16VerificationKey,
172        proof: Option<Groth16Proof>,
173        public_inputs: Vec<DecimalValue>,
174        source_format: SourceFormat,
175    ) -> Result<Self> {
176        if verifying_key.ic.len() != verifying_key.n_public + 1 {
177            return Err(Error::IcLengthMismatch(format!(
178                "expected {} IC points, got {}",
179                verifying_key.n_public + 1,
180                verifying_key.ic.len()
181            )));
182        }
183        if proof.is_some() && public_inputs.len() != verifying_key.n_public {
184            return Err(Error::PublicInputCountMismatch(format!(
185                "verification key expects {} public inputs, got {}",
186                verifying_key.n_public,
187                public_inputs.len()
188            )));
189        }
190        if proof.is_none()
191            && !public_inputs.is_empty()
192            && public_inputs.len() != verifying_key.n_public
193        {
194            return Err(Error::PublicInputCountMismatch(format!(
195                "verification key expects {} public inputs, got {}",
196                verifying_key.n_public,
197                public_inputs.len()
198            )));
199        }
200
201        Ok(Self {
202            curve,
203            protocol: "groth16".to_string(),
204            verifying_key,
205            proof,
206            public_inputs,
207            source_format,
208        })
209    }
210
211    pub fn has_test_vectors(&self) -> bool {
212        self.proof.is_some()
213    }
214}
215
216impl From<SnarkJsG1> for Groth16G1Point {
217    fn from(value: SnarkJsG1) -> Self {
218        Self {
219            x: value.x,
220            y: value.y,
221            z: value.z,
222        }
223    }
224}
225
226impl From<SnarkJsG2> for Groth16G2Point {
227    fn from(value: SnarkJsG2) -> Self {
228        Self {
229            x0: value.x0,
230            x1: value.x1,
231            y0: value.y0,
232            y1: value.y1,
233            z0: value.z0,
234            z1: value.z1,
235        }
236    }
237}
238
/// Canonicalizes a curve name for alias matching.
///
/// The whole string is lowercased first, then every `-` and `_` is dropped, so
/// `BLS12-381` and `alt_bn128` become `bls12381` and `altbn128`.
fn normalize_curve_name(value: &str) -> String {
    let lowered = value.to_lowercase();
    lowered
        .chars()
        .filter(|&ch| !matches!(ch, '-' | '_'))
        .collect()
}