Skip to main content

exo_proofs/
verifier.rs

1// Copyright 2026 Exochain Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at:
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17//! Unified proof verifier -- dispatches to the appropriate proof system.
18
19use serde::{Deserialize, Serialize, de::DeserializeOwned};
20
21use crate::{
22    error::{ProofError, Result},
23    snark, stark, zkml,
24};
25
/// Largest CBOR input (proof bundle or public inputs), in bytes, that the
/// verifier will attempt to decode; anything longer is rejected up front.
const MAX_VERIFIER_CBOR_BYTES: usize = 1 << 20;
27
28// ---------------------------------------------------------------------------
29// ProofType
30// ---------------------------------------------------------------------------
31
32/// The type of zero-knowledge proof.
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
34pub enum ProofType {
35    /// SNARK proof (succinct, pairing-based structure).
36    Snark,
37    /// STARK proof (hash-based, post-quantum).
38    Stark,
39    /// ZKML inference proof.
40    Zkml,
41}
42
43// ---------------------------------------------------------------------------
44// Serialized proof formats
45// ---------------------------------------------------------------------------
46
47/// A serialized SNARK verification bundle.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct SnarkBundle {
50    pub vk: snark::VerifyingKey,
51    pub proof: snark::Proof,
52}
53
54/// A serialized STARK verification bundle.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct StarkBundle {
57    pub proof: stark::StarkProof,
58}
59
60/// Public STARK statement supplied by the verifier.
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct StarkPublicInputs {
63    /// Public inputs, currently the first row of the committed trace.
64    pub inputs: Vec<u64>,
65    /// Public transition constraints the proof must satisfy.
66    pub constraints: Vec<stark::StarkConstraint>,
67}
68
69/// A serialized ZKML verification bundle.
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct ZkmlBundle {
72    pub proof: zkml::InferenceProof,
73}
74
75// ---------------------------------------------------------------------------
76// verify_any
77// ---------------------------------------------------------------------------
78
79/// Verify any proof type given its serialized form and public inputs.
80///
81/// The `proof_bytes` must be a canonical CBOR bundle appropriate for the proof
82/// type. The `public_inputs_bytes` contains canonical CBOR public inputs.
83pub fn verify_any(
84    proof_type: ProofType,
85    proof_bytes: &[u8],
86    public_inputs_bytes: &[u8],
87) -> Result<bool> {
88    match proof_type {
89        ProofType::Snark => verify_snark(proof_bytes, public_inputs_bytes),
90        ProofType::Stark => verify_stark(proof_bytes, public_inputs_bytes),
91        ProofType::Zkml => verify_zkml(proof_bytes),
92    }
93}
94
95fn decode_cbor<T: DeserializeOwned>(bytes: &[u8], label: &'static str) -> Result<T> {
96    if bytes.len() > MAX_VERIFIER_CBOR_BYTES {
97        return Err(ProofError::DeserializationError(format!(
98            "{label}: canonical CBOR input exceeds maximum size of {MAX_VERIFIER_CBOR_BYTES} bytes"
99        )));
100    }
101    ciborium::from_reader(bytes).map_err(|e| {
102        ProofError::DeserializationError(format!("{label}: canonical CBOR decode failed: {e}"))
103    })
104}
105
106fn verify_snark(proof_bytes: &[u8], public_inputs_bytes: &[u8]) -> Result<bool> {
107    let bundle: SnarkBundle = decode_cbor(proof_bytes, "snark proof bundle")?;
108    let public_inputs: Vec<u64> = decode_cbor(public_inputs_bytes, "snark public inputs")?;
109
110    snark::verify(&bundle.vk, &bundle.proof, &public_inputs)
111}
112
113fn verify_stark(proof_bytes: &[u8], public_inputs_bytes: &[u8]) -> Result<bool> {
114    let bundle: StarkBundle = decode_cbor(proof_bytes, "stark proof bundle")?;
115    let public_inputs: StarkPublicInputs = decode_cbor(public_inputs_bytes, "stark public inputs")?;
116
117    stark::verify_stark_with_constraints(
118        &bundle.proof,
119        &public_inputs.inputs,
120        &public_inputs.constraints,
121    )
122}
123
124fn verify_zkml(proof_bytes: &[u8]) -> Result<bool> {
125    let bundle: ZkmlBundle = decode_cbor(proof_bytes, "zkml proof bundle")?;
126
127    zkml::verify_inference(&bundle.proof)
128}
129
130// ---------------------------------------------------------------------------
131// Tests
132// ---------------------------------------------------------------------------
133
#[cfg(test)]
mod canonical_encoding_contract_tests {
    /// Separator comment that opens the test section of this file. The `\n`
    /// in the literal is an escape sequence, so this line does not match itself.
    const TESTS_SECTION_MARKER: &str =
        "// ---------------------------------------------------------------------------\n// Tests";

    /// Static contract: production code decodes with ciborium, never with JSON.
    #[test]
    fn verify_any_uses_canonical_cbor_not_json() {
        let source = include_str!("verifier.rs");
        // `split_once` yields `None` if the marker is missing. The test then
        // fails loudly instead of silently treating the whole file, tests
        // included, as production code.
        let (production, _tests) = source
            .split_once(TESTS_SECTION_MARKER)
            .expect("production section exists");

        assert!(
            !production.contains("serde_json::from_slice"),
            "proof verifier must not decode proof bundles or public inputs as JSON"
        );
        assert!(
            production.contains("ciborium::from_reader"),
            "proof verifier must decode proof bundles and public inputs as canonical CBOR"
        );
    }

    #[test]
    fn decode_cbor_rejects_oversized_inputs_before_deserialization() {
        // One byte over the limit, derived from the constant so the test tracks it.
        let oversized = vec![0u8; super::MAX_VERIFIER_CBOR_BYTES + 1];
        let err = super::decode_cbor::<Vec<u8>>(&oversized, "oversized proof").unwrap_err();
        assert!(
            err.to_string().contains("exceeds maximum"),
            "oversized proof input must fail before CBOR decode: {err}"
        );
    }
}
164
#[cfg(all(test, feature = "unaudited-pedagogical-proofs"))]
mod tests {
    use super::*;
    use crate::{
        circuit::{
            Circuit, ConstraintSystem, LinearCombination, allocate, allocate_public, enforce,
        },
        snark,
        stark::{StarkConfig, StarkConstraint},
        zkml::{self, ModelCommitment},
    };

    /// Encodes `value` as CBOR with ciborium, the codec `decode_cbor` reads.
    fn cbor_bytes<T: Serialize>(value: &T) -> Vec<u8> {
        let mut encoded = Vec::new();
        ciborium::into_writer(value, &mut encoded).expect("canonical CBOR encode");
        encoded
    }

    /// x * y = z
    #[derive(Debug)]
    struct MulCircuit {
        x: Option<u64>,
        y: Option<u64>,
        z: Option<u64>,
    }

    impl Circuit for MulCircuit {
        fn synthesize(&self, cs: &mut ConstraintSystem) -> crate::error::Result<()> {
            let x = allocate_public(cs, self.x);
            let y = allocate(cs, self.y);
            let z = allocate_public(cs, self.z);
            enforce(
                cs,
                &LinearCombination::from_variable(x),
                &LinearCombination::from_variable(y),
                &LinearCombination::from_variable(z),
            );
            Ok(())
        }
    }

    /// Proves `3 * 4 = 12` with [`MulCircuit`] and packages it for `verify_any`.
    /// The matching public inputs are `[3, 12]`.
    fn mul_snark_bundle() -> SnarkBundle {
        let circuit = MulCircuit {
            x: Some(3),
            y: Some(4),
            z: Some(12),
        };
        let (pk, vk) = snark::setup(&circuit).unwrap();
        let proof = snark::prove(&pk, &circuit, &[3, 4, 12]).unwrap();
        SnarkBundle { vk, proof }
    }

    /// Proves a single-column trace of `rows` ones under a "constant"
    /// constraint. Returns the bundle and the constraints it was proven with.
    fn constant_stark_proof(
        config: &StarkConfig,
        rows: usize,
    ) -> (StarkBundle, Vec<StarkConstraint>) {
        let trace: Vec<Vec<u64>> = vec![vec![1]; rows];
        let constraints = vec![StarkConstraint {
            name: "constant".to_string(),
            column_indices: vec![0],
            coefficients: vec![(1, config.field_size - 1)],
        }];
        let proof = crate::stark::prove_stark(&trace, &constraints, config).unwrap();
        (StarkBundle { proof }, constraints)
    }

    /// Proves inference for a fixed toy model commitment.
    fn toy_inference_proof() -> zkml::InferenceProof {
        let model = ModelCommitment::new(b"arch", b"weights", 1);
        zkml::prove_inference(&model, b"input", b"output").unwrap()
    }

    #[test]
    fn verify_any_snark() {
        let proof_bytes = cbor_bytes(&mul_snark_bundle());
        let public_inputs_bytes = cbor_bytes(&vec![3u64, 12u64]);

        let result = verify_any(ProofType::Snark, &proof_bytes, &public_inputs_bytes).unwrap();
        assert!(result);
    }

    #[test]
    fn verify_any_rejects_json_snark_bundle() {
        let proof_bytes = serde_json::to_vec(&mul_snark_bundle()).unwrap();
        let public_inputs_bytes = serde_json::to_vec(&vec![3u64, 12u64]).unwrap();

        let err = verify_any(ProofType::Snark, &proof_bytes, &public_inputs_bytes).unwrap_err();
        assert!(matches!(err, ProofError::DeserializationError(_)));
    }

    #[test]
    fn verify_any_rejects_json_stark_bundle() {
        let config = StarkConfig {
            num_queries: 2,
            ..StarkConfig::default_config()
        };
        let (bundle, constraints) = constant_stark_proof(&config, 3);
        let proof_bytes = serde_json::to_vec(&bundle).unwrap();
        let public_inputs_bytes = cbor_bytes(&StarkPublicInputs {
            inputs: vec![1u64],
            constraints,
        });

        let err = verify_any(ProofType::Stark, &proof_bytes, &public_inputs_bytes).unwrap_err();
        assert!(matches!(err, ProofError::DeserializationError(_)));
    }

    #[test]
    fn verify_any_rejects_json_zkml_bundle() {
        let bundle = ZkmlBundle {
            proof: toy_inference_proof(),
        };
        let proof_bytes = serde_json::to_vec(&bundle).unwrap();

        let err = verify_any(ProofType::Zkml, &proof_bytes, b"[]").unwrap_err();
        assert!(matches!(err, ProofError::DeserializationError(_)));
    }

    #[test]
    fn verify_any_snark_invalid() {
        let proof_bytes = cbor_bytes(&mul_snark_bundle());
        let wrong_inputs = cbor_bytes(&vec![3u64, 13u64]);

        let result = verify_any(ProofType::Snark, &proof_bytes, &wrong_inputs).unwrap();
        assert!(!result);
    }

    #[test]
    fn verify_any_stark() {
        let config = StarkConfig {
            num_queries: 2,
            ..StarkConfig::default_config()
        };
        let (bundle, constraints) = constant_stark_proof(&config, 3);
        let proof_bytes = cbor_bytes(&bundle);
        let public_inputs_bytes = cbor_bytes(&StarkPublicInputs {
            inputs: vec![1u64],
            constraints,
        });

        let result = verify_any(ProofType::Stark, &proof_bytes, &public_inputs_bytes).unwrap();
        assert!(result);
    }

    #[test]
    fn verify_any_zkml() {
        let bundle = ZkmlBundle {
            proof: toy_inference_proof(),
        };
        let proof_bytes = cbor_bytes(&bundle);

        let result = verify_any(ProofType::Zkml, &proof_bytes, b"[]").unwrap();
        assert!(result);
    }

    #[test]
    fn verify_any_zkml_tampered() {
        let mut proof = toy_inference_proof();
        proof.output_hash = exo_core::types::Hash256::ZERO;

        let bundle = ZkmlBundle { proof };
        let proof_bytes = cbor_bytes(&bundle);

        let result = verify_any(ProofType::Zkml, &proof_bytes, b"[]").unwrap();
        assert!(!result);
    }

    #[test]
    fn verify_any_bad_proof_bytes() {
        let err = verify_any(ProofType::Snark, b"not cbor", b"[]").unwrap_err();
        assert!(matches!(err, ProofError::DeserializationError(_)));
    }

    #[test]
    fn verify_any_bad_public_inputs_bytes() {
        let proof_bytes = cbor_bytes(&mul_snark_bundle());
        let legacy_json_inputs = serde_json::to_vec(&vec![3u64, 12u64]).unwrap();

        let err = verify_any(ProofType::Snark, &proof_bytes, &legacy_json_inputs).unwrap_err();
        assert!(matches!(err, ProofError::DeserializationError(_)));
    }

    #[test]
    fn proof_type_serde() {
        let types = vec![ProofType::Snark, ProofType::Stark, ProofType::Zkml];
        for t in &types {
            let json = serde_json::to_string(t).unwrap();
            let t2: ProofType = serde_json::from_str(&json).unwrap();
            assert_eq!(t, &t2);
        }
    }

    #[test]
    fn proof_type_eq() {
        assert_eq!(ProofType::Snark, ProofType::Snark);
        assert_ne!(ProofType::Snark, ProofType::Stark);
        assert_ne!(ProofType::Stark, ProofType::Zkml);
    }

    #[test]
    fn verify_any_stark_bad_proof_bytes() {
        let err = verify_any(ProofType::Stark, b"not cbor", b"[]").unwrap_err();
        assert!(matches!(err, ProofError::DeserializationError(_)));
    }

    #[test]
    fn verify_any_stark_bad_public_inputs_bytes() {
        let config = StarkConfig {
            num_queries: 1,
            ..StarkConfig::default_config()
        };
        let (bundle, _) = constant_stark_proof(&config, 2);
        let proof_bytes = cbor_bytes(&bundle);

        // Legacy bare public-input arrays are rejected because STARK
        // verification now requires caller-supplied public constraints. Check
        // the legacy JSON form and a bare CBOR array, so the rejection is not
        // explained solely by the JSON-vs-CBOR mismatch.
        let legacy_json_inputs = serde_json::to_vec(&vec![1u64]).unwrap();
        let err = verify_any(ProofType::Stark, &proof_bytes, &legacy_json_inputs).unwrap_err();
        assert!(matches!(err, ProofError::DeserializationError(_)));

        let bare_cbor_inputs = cbor_bytes(&vec![1u64]);
        let err = verify_any(ProofType::Stark, &proof_bytes, &bare_cbor_inputs).unwrap_err();
        assert!(matches!(err, ProofError::DeserializationError(_)));
    }

    #[test]
    fn verify_any_zkml_bad_proof_bytes() {
        let err = verify_any(ProofType::Zkml, b"not cbor", b"[]").unwrap_err();
        assert!(matches!(err, ProofError::DeserializationError(_)));
    }
}