1use serde::{Deserialize, Serialize, de::DeserializeOwned};
20
21use crate::{
22 error::{ProofError, Result},
23 snark, stark, zkml,
24};
25
/// Largest CBOR payload, in bytes, that the verifier will attempt to decode (1 MiB).
///
/// Each of the proof bytes and the public-input bytes is checked against this limit
/// independently, before any decoding work is done.
const MAX_VERIFIER_CBOR_BYTES: usize = 1024 * 1024;
27
/// Selects which proof system [`verify_any`] uses to decode and check the supplied bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProofType {
    /// SNARK: the proof bytes hold a CBOR [`SnarkBundle`] and the public-input bytes hold a
    /// CBOR-encoded `Vec<u64>`.
    Snark,
    /// STARK: the proof bytes hold a CBOR [`StarkBundle`] and the public-input bytes hold a
    /// CBOR [`StarkPublicInputs`].
    Stark,
    /// ZKML inference: the proof bytes hold a CBOR [`ZkmlBundle`]; the public-input bytes are
    /// not read for this proof type.
    Zkml,
}
42
/// CBOR payload expected in the proof bytes for [`ProofType::Snark`].
///
/// NOTE(review): the verifying key is decoded from the same untrusted bytes as the proof and
/// this module never compares it against a trusted key. A caller that needs verification to be
/// bound to a specific circuit must pin `vk` itself. Confirm that callers do so.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnarkBundle {
    /// Verifying key that `proof` is checked against.
    pub vk: snark::VerifyingKey,
    /// The SNARK proof to verify.
    pub proof: snark::Proof,
}
53
/// CBOR payload expected in the proof bytes for [`ProofType::Stark`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StarkBundle {
    /// The STARK proof to verify against a [`StarkPublicInputs`].
    pub proof: stark::StarkProof,
}
59
/// CBOR payload expected in the public-input bytes for [`ProofType::Stark`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StarkPublicInputs {
    /// Public input values handed to the STARK verifier.
    pub inputs: Vec<u64>,
    /// Constraints the proof is checked against. These come from the caller-supplied
    /// public-input bytes, not from the proof bundle.
    pub constraints: Vec<stark::StarkConstraint>,
}
68
/// CBOR payload expected in the proof bytes for [`ProofType::Zkml`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZkmlBundle {
    /// The inference proof; it is verified on its own, with no separate public inputs.
    pub proof: zkml::InferenceProof,
}
74
75pub fn verify_any(
84 proof_type: ProofType,
85 proof_bytes: &[u8],
86 public_inputs_bytes: &[u8],
87) -> Result<bool> {
88 match proof_type {
89 ProofType::Snark => verify_snark(proof_bytes, public_inputs_bytes),
90 ProofType::Stark => verify_stark(proof_bytes, public_inputs_bytes),
91 ProofType::Zkml => verify_zkml(proof_bytes),
92 }
93}
94
95fn decode_cbor<T: DeserializeOwned>(bytes: &[u8], label: &'static str) -> Result<T> {
96 if bytes.len() > MAX_VERIFIER_CBOR_BYTES {
97 return Err(ProofError::DeserializationError(format!(
98 "{label}: canonical CBOR input exceeds maximum size of {MAX_VERIFIER_CBOR_BYTES} bytes"
99 )));
100 }
101 ciborium::from_reader(bytes).map_err(|e| {
102 ProofError::DeserializationError(format!("{label}: canonical CBOR decode failed: {e}"))
103 })
104}
105
106fn verify_snark(proof_bytes: &[u8], public_inputs_bytes: &[u8]) -> Result<bool> {
107 let bundle: SnarkBundle = decode_cbor(proof_bytes, "snark proof bundle")?;
108 let public_inputs: Vec<u64> = decode_cbor(public_inputs_bytes, "snark public inputs")?;
109
110 snark::verify(&bundle.vk, &bundle.proof, &public_inputs)
111}
112
113fn verify_stark(proof_bytes: &[u8], public_inputs_bytes: &[u8]) -> Result<bool> {
114 let bundle: StarkBundle = decode_cbor(proof_bytes, "stark proof bundle")?;
115 let public_inputs: StarkPublicInputs = decode_cbor(public_inputs_bytes, "stark public inputs")?;
116
117 stark::verify_stark_with_constraints(
118 &bundle.proof,
119 &public_inputs.inputs,
120 &public_inputs.constraints,
121 )
122}
123
124fn verify_zkml(proof_bytes: &[u8]) -> Result<bool> {
125 let bundle: ZkmlBundle = decode_cbor(proof_bytes, "zkml proof bundle")?;
126
127 zkml::verify_inference(&bundle.proof)
128}
129
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod canonical_encoding_contract_tests {
    /// Returns this file's source text up to this test module's own declaration.
    ///
    /// The split point is derived from `module_path!()` rather than a comment marker, so it
    /// cannot silently go missing. If the declaration is not found, the test panics instead
    /// of scanning the whole file, which would include this module's own needle literals.
    fn production_source() -> &'static str {
        const SOURCE: &str = include_str!("verifier.rs");
        let module_name = module_path!()
            .rsplit("::")
            .next()
            .expect("rsplit always yields at least one segment");
        let declaration = format!("mod {module_name}");
        SOURCE
            .split_once(declaration.as_str())
            .map(|(production, _tests)| production)
            .unwrap_or_else(|| panic!("`{declaration}` not found in verifier.rs"))
    }

    #[test]
    fn verify_any_uses_canonical_cbor_not_json() {
        let production = production_source();

        for json_decoder in [
            "serde_json::from_slice",
            "serde_json::from_str",
            "serde_json::from_reader",
        ] {
            assert!(
                !production.contains(json_decoder),
                "proof verifier must not decode proof bundles or public inputs as JSON \
                 (found `{json_decoder}`)"
            );
        }
        assert!(
            production.contains("ciborium::from_reader"),
            "proof verifier must decode proof bundles and public inputs as canonical CBOR"
        );
    }

    #[test]
    fn decode_cbor_rejects_oversized_inputs_before_deserialization() {
        let oversized = vec![0u8; super::MAX_VERIFIER_CBOR_BYTES + 1];
        let err = super::decode_cbor::<Vec<u8>>(&oversized, "oversized proof").unwrap_err();
        assert!(
            err.to_string().contains("exceeds maximum"),
            "oversized proof input must fail before CBOR decode: {err}"
        );
    }

    #[test]
    fn decode_cbor_size_guard_allows_inputs_at_the_limit() {
        // An input of exactly the maximum size must get past the size guard and reach the
        // decoder. A 0x00 byte is CBOR unsigned integer 0, not an array, so decoding it as
        // `Vec<u8>` fails there with a decode error rather than a size error.
        let at_limit = vec![0u8; super::MAX_VERIFIER_CBOR_BYTES];
        let err = super::decode_cbor::<Vec<u8>>(&at_limit, "limit-sized proof").unwrap_err();
        assert!(
            !err.to_string().contains("exceeds maximum"),
            "input of exactly the maximum size must not be rejected by the size guard: {err}"
        );
    }
}
164
#[cfg(all(test, feature = "unaudited-pedagogical-proofs"))]
mod tests {
    use super::*;
    use crate::{
        circuit::{
            Circuit, ConstraintSystem, LinearCombination, allocate, allocate_public, enforce,
        },
        snark,
        stark::{StarkConfig, StarkConstraint},
        zkml::{self, ModelCommitment},
    };

    /// Encodes `value` as CBOR, the wire format `verify_any` decodes.
    fn cbor_bytes<T: Serialize>(value: &T) -> Vec<u8> {
        let mut encoded = Vec::new();
        ciborium::into_writer(value, &mut encoded).expect("canonical CBOR encode");
        encoded
    }

    /// Asserts that `verify_any` failed with a `DeserializationError` (i.e. while decoding its
    /// inputs), panicking on success or on any other error.
    fn assert_deserialization_error(result: crate::error::Result<bool>) {
        match result {
            Err(ProofError::DeserializationError(_)) => {}
            other => panic!("expected a DeserializationError, got {other:?}"),
        }
    }

    /// Test circuit enforcing `x * y == z`, with `x` and `z` allocated as public and `y` as
    /// private.
    #[derive(Debug)]
    struct MulCircuit {
        x: Option<u64>,
        y: Option<u64>,
        z: Option<u64>,
    }

    impl Circuit for MulCircuit {
        fn synthesize(&self, cs: &mut ConstraintSystem) -> crate::error::Result<()> {
            let x = allocate_public(cs, self.x);
            let y = allocate(cs, self.y);
            let z = allocate_public(cs, self.z);
            enforce(
                cs,
                &LinearCombination::from_variable(x),
                &LinearCombination::from_variable(y),
                &LinearCombination::from_variable(z),
            );
            Ok(())
        }
    }

    /// Sets up and proves the `3 * 4 == 12` instance of [`MulCircuit`].
    ///
    /// The tests below verify it against public inputs `[3, 12]`.
    fn snark_bundle() -> SnarkBundle {
        let circuit = MulCircuit {
            x: Some(3),
            y: Some(4),
            z: Some(12),
        };
        let (pk, vk) = snark::setup(&circuit).unwrap();
        let proof = snark::prove(&pk, &circuit, &[3, 4, 12]).unwrap();
        SnarkBundle { vk, proof }
    }

    /// Proves a single-column trace of `rows` ones against the `"constant"` constraint.
    ///
    /// Returns the bundle together with the constraints the verifier must be given.
    fn constant_stark_bundle(
        config: &StarkConfig,
        rows: usize,
    ) -> (StarkBundle, Vec<StarkConstraint>) {
        let trace: Vec<Vec<u64>> = vec![vec![1]; rows];
        let constraints = vec![StarkConstraint {
            name: "constant".to_string(),
            column_indices: vec![0],
            coefficients: vec![(1, config.field_size - 1)],
        }];
        let proof = crate::stark::prove_stark(&trace, &constraints, config).unwrap();
        (StarkBundle { proof }, constraints)
    }

    /// CBOR-encodes the STARK public inputs (`inputs = [1]`) used with
    /// [`constant_stark_bundle`].
    fn constant_stark_public_inputs(constraints: Vec<StarkConstraint>) -> Vec<u8> {
        cbor_bytes(&StarkPublicInputs {
            inputs: vec![1u64],
            constraints,
        })
    }

    /// Proves inference for a fixed toy model commitment over `b"input"` -> `b"output"`.
    fn zkml_proof() -> zkml::InferenceProof {
        let model = ModelCommitment::new(b"arch", b"weights", 1);
        zkml::prove_inference(&model, b"input", b"output").unwrap()
    }

    #[test]
    fn verify_any_snark() {
        let proof_bytes = cbor_bytes(&snark_bundle());
        let public_inputs_bytes = cbor_bytes(&vec![3u64, 12u64]);

        let result = verify_any(ProofType::Snark, &proof_bytes, &public_inputs_bytes).unwrap();
        assert!(result);
    }

    #[test]
    fn verify_any_rejects_json_snark_bundle() {
        let proof_bytes = serde_json::to_vec(&snark_bundle()).unwrap();
        let public_inputs_bytes = serde_json::to_vec(&vec![3u64, 12u64]).unwrap();

        assert_deserialization_error(verify_any(
            ProofType::Snark,
            &proof_bytes,
            &public_inputs_bytes,
        ));
    }

    #[test]
    fn verify_any_rejects_json_stark_bundle() {
        let config = StarkConfig {
            num_queries: 2,
            ..StarkConfig::default_config()
        };
        let (bundle, constraints) = constant_stark_bundle(&config, 3);
        let proof_bytes = serde_json::to_vec(&bundle).unwrap();
        let public_inputs_bytes = constant_stark_public_inputs(constraints);

        assert_deserialization_error(verify_any(
            ProofType::Stark,
            &proof_bytes,
            &public_inputs_bytes,
        ));
    }

    #[test]
    fn verify_any_rejects_json_zkml_bundle() {
        let proof_bytes = serde_json::to_vec(&ZkmlBundle {
            proof: zkml_proof(),
        })
        .unwrap();

        assert_deserialization_error(verify_any(ProofType::Zkml, &proof_bytes, b"[]"));
    }

    #[test]
    fn verify_any_snark_invalid() {
        let proof_bytes = cbor_bytes(&snark_bundle());
        let wrong_inputs = cbor_bytes(&vec![3u64, 13u64]);

        let result = verify_any(ProofType::Snark, &proof_bytes, &wrong_inputs).unwrap();
        assert!(!result);
    }

    #[test]
    fn verify_any_stark() {
        let config = StarkConfig {
            num_queries: 2,
            ..StarkConfig::default_config()
        };
        let (bundle, constraints) = constant_stark_bundle(&config, 3);
        let proof_bytes = cbor_bytes(&bundle);
        let public_inputs_bytes = constant_stark_public_inputs(constraints);

        let result = verify_any(ProofType::Stark, &proof_bytes, &public_inputs_bytes).unwrap();
        assert!(result);
    }

    #[test]
    fn verify_any_zkml() {
        let proof_bytes = cbor_bytes(&ZkmlBundle {
            proof: zkml_proof(),
        });

        let result = verify_any(ProofType::Zkml, &proof_bytes, b"[]").unwrap();
        assert!(result);
    }

    #[test]
    fn verify_any_zkml_tampered() {
        let mut proof = zkml_proof();
        proof.output_hash = exo_core::types::Hash256::ZERO;
        let proof_bytes = cbor_bytes(&ZkmlBundle { proof });

        let result = verify_any(ProofType::Zkml, &proof_bytes, b"[]").unwrap();
        assert!(!result);
    }

    #[test]
    fn verify_any_bad_proof_bytes() {
        assert_deserialization_error(verify_any(ProofType::Snark, b"not cbor", b"[]"));
    }

    #[test]
    fn verify_any_bad_public_inputs_bytes() {
        let proof_bytes = cbor_bytes(&snark_bundle());
        let legacy_json_inputs = serde_json::to_vec(&vec![3u64, 12u64]).unwrap();

        assert_deserialization_error(verify_any(
            ProofType::Snark,
            &proof_bytes,
            &legacy_json_inputs,
        ));
    }

    #[test]
    fn proof_type_serde() {
        for t in [ProofType::Snark, ProofType::Stark, ProofType::Zkml] {
            let json = serde_json::to_string(&t).unwrap();
            let from_json: ProofType = serde_json::from_str(&json).unwrap();
            assert_eq!(t, from_json);

            // Also round-trip through the verifier's own CBOR decoder.
            let cbor = cbor_bytes(&t);
            let from_cbor: ProofType = decode_cbor(&cbor, "proof type").unwrap();
            assert_eq!(t, from_cbor);
        }
    }

    #[test]
    fn proof_type_eq() {
        assert_eq!(ProofType::Snark, ProofType::Snark);
        assert_ne!(ProofType::Snark, ProofType::Stark);
        assert_ne!(ProofType::Stark, ProofType::Zkml);
    }

    #[test]
    fn verify_any_stark_bad_proof_bytes() {
        assert_deserialization_error(verify_any(ProofType::Stark, b"not cbor", b"[]"));
    }

    #[test]
    fn verify_any_stark_bad_public_inputs_bytes() {
        let config = StarkConfig {
            num_queries: 1,
            ..StarkConfig::default_config()
        };
        let (bundle, _constraints) = constant_stark_bundle(&config, 2);
        let proof_bytes = cbor_bytes(&bundle);
        let legacy_json_inputs = serde_json::to_vec(&vec![1u64]).unwrap();

        assert_deserialization_error(verify_any(
            ProofType::Stark,
            &proof_bytes,
            &legacy_json_inputs,
        ));
    }

    #[test]
    fn verify_any_zkml_bad_proof_bytes() {
        assert_deserialization_error(verify_any(ProofType::Zkml, b"not cbor", b"[]"));
    }
}