1use exo_core::types::Hash256;
24use serde::{Deserialize, Serialize};
25
26use crate::{
27 circuit::{Circuit, ConstraintSystem},
28 error::{ProofError, Result},
29};
30
/// Proving key produced by [`setup`] and consumed by [`prove`].
///
/// Records the size of the constraint system synthesized at setup time plus a
/// hash of its structure. `prove` recomputes that hash from the circuit it is
/// given and refuses to continue when it differs from `circuit_hash`.
// NOTE(review): field order is part of the serialized form for order-sensitive
// serde formats; do not reorder fields without a migration plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProvingKey {
    /// Number of variables allocated while synthesizing the circuit in `setup`.
    pub num_variables: usize,
    /// Number of constraints in the synthesized circuit (always non-zero, since
    /// `setup` rejects empty circuits).
    pub num_constraints: usize,
    /// Number of public inputs declared by the circuit.
    pub num_public_inputs: usize,
    /// Structural hash of the constraint system (counts and constraint terms,
    /// not variable values).
    pub circuit_hash: Hash256,
}
47
/// Verifying key produced by [`setup`] and consumed by [`verify`].
///
/// Holds only public data: how many public inputs a statement must carry and
/// the structural hash of the circuit that every proof is bound to.
// NOTE(review): field order is part of the serialized form for order-sensitive
// serde formats; do not reorder fields without a migration plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerifyingKey {
    /// Exact number of public inputs `verify` expects; other lengths are rejected.
    pub num_public_inputs: usize,
    /// Structural hash of the circuit, identical to the proving key's hash.
    pub circuit_hash: Hash256,
}
56
/// Proof object emitted by [`prove`] and checked by [`verify`].
///
/// In this pedagogical scheme all three components are 32-byte BLAKE3 digests.
/// `prove` derives `a` and `b` from the circuit hash and the public inputs, and
/// derives `c` from those plus `a` and `b`. No private witness value enters any
/// component.
// NOTE(review): field order is part of the serialized form for order-sensitive
// serde formats; do not reorder fields without a migration plan.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Proof {
    /// First statement component (`snark:a:statement:` domain).
    pub a: [u8; 32],
    /// Second statement component (`snark:b:statement:` domain).
    pub b: [u8; 32],
    /// Binding component over the circuit hash, public inputs, `a` and `b`.
    pub c: [u8; 32],
}
72
73pub fn setup(circuit: &dyn Circuit) -> Result<(ProvingKey, VerifyingKey)> {
82 crate::guard_unaudited("snark::setup")?;
83 let mut cs = ConstraintSystem::new();
84 circuit
85 .synthesize(&mut cs)
86 .map_err(|e| ProofError::SetupError(e.to_string()))?;
87
88 if cs.num_constraints() == 0 {
89 return Err(ProofError::SetupError(
90 "circuit has no constraints".to_string(),
91 ));
92 }
93 validate_public_input_indices(&cs).map_err(ProofError::SetupError)?;
94
95 let circuit_hash =
97 compute_circuit_hash(&cs).map_err(|e| ProofError::SetupError(e.to_string()))?;
98
99 let pk = ProvingKey {
100 num_variables: cs.num_variables(),
101 num_constraints: cs.num_constraints(),
102 num_public_inputs: cs.num_public_inputs,
103 circuit_hash,
104 };
105
106 let vk = VerifyingKey {
107 num_public_inputs: cs.num_public_inputs,
108 circuit_hash,
109 };
110
111 Ok((pk, vk))
112}
113
114pub fn prove(pk: &ProvingKey, circuit: &dyn Circuit, witness: &[u64]) -> Result<Proof> {
124 crate::guard_unaudited("snark::prove")?;
125 let mut cs = ConstraintSystem::new();
127 circuit
128 .synthesize(&mut cs)
129 .map_err(|e| ProofError::ProofGenerationFailed(e.to_string()))?;
130 validate_public_input_indices(&cs).map_err(ProofError::InvalidWitness)?;
131
132 if witness.len() != cs.num_variables() {
134 return Err(ProofError::InvalidWitness(format!(
135 "expected {} witness values, got {}",
136 cs.num_variables(),
137 witness.len()
138 )));
139 }
140
141 for (i, var) in cs.variables.iter_mut().enumerate() {
142 var.value = Some(witness[i]);
143 }
144
145 let circuit_hash =
147 compute_circuit_hash(&cs).map_err(|e| ProofError::ProofGenerationFailed(e.to_string()))?;
148 if circuit_hash != pk.circuit_hash {
149 return Err(ProofError::ProofGenerationFailed(
150 "circuit structure does not match proving key".to_string(),
151 ));
152 }
153
154 if !cs.is_satisfied() {
156 return Err(ProofError::ProofGenerationFailed(
157 "witness does not satisfy constraints".to_string(),
158 ));
159 }
160
161 let public_inputs = public_inputs_from_witness(&cs, witness)?;
162
163 let a = compute_proof_component(b"snark:a:statement:", &circuit_hash, &public_inputs);
168 let b = compute_proof_component(b"snark:b:statement:", &circuit_hash, &public_inputs);
169
170 let c = compute_c_component(&circuit_hash, &public_inputs, &a, &b);
173
174 Ok(Proof { a, b, c })
175}
176
177pub fn verify(vk: &VerifyingKey, proof: &Proof, public_inputs: &[u64]) -> Result<bool> {
187 crate::guard_unaudited("snark::verify")?;
188 if public_inputs.len() != vk.num_public_inputs {
189 return Ok(false);
190 }
191
192 let expected_c = compute_c_component(&vk.circuit_hash, public_inputs, &proof.a, &proof.b);
195
196 Ok(proof.c == expected_c)
197}
198
199fn usize_to_u64(n: usize) -> Result<u64> {
204 u64::try_from(n).map_err(|_| ProofError::SetupError(format!("value {n} overflows u64")))
205}
206
207fn compute_circuit_hash(cs: &ConstraintSystem) -> Result<Hash256> {
208 let mut hasher = blake3::Hasher::new();
209 hasher.update(b"snark:circuit:");
210 hasher.update(&usize_to_u64(cs.num_variables())?.to_le_bytes());
211 hasher.update(&usize_to_u64(cs.num_constraints())?.to_le_bytes());
212 hasher.update(&usize_to_u64(cs.num_public_inputs)?.to_le_bytes());
213
214 for constraint in &cs.constraints {
215 for &(coeff, idx) in &constraint.a_terms.terms {
216 hasher.update(&coeff.to_le_bytes());
217 hasher.update(&usize_to_u64(idx)?.to_le_bytes());
218 }
219 hasher.update(b"|");
220 for &(coeff, idx) in &constraint.b_terms.terms {
221 hasher.update(&coeff.to_le_bytes());
222 hasher.update(&usize_to_u64(idx)?.to_le_bytes());
223 }
224 hasher.update(b"|");
225 for &(coeff, idx) in &constraint.c_terms.terms {
226 hasher.update(&coeff.to_le_bytes());
227 hasher.update(&usize_to_u64(idx)?.to_le_bytes());
228 }
229 hasher.update(b"#");
230 }
231
232 Ok(Hash256::from_bytes(*hasher.finalize().as_bytes()))
233}
234
235fn validate_public_input_indices(cs: &ConstraintSystem) -> std::result::Result<(), String> {
236 if cs.public_input_indices.len() != cs.num_public_inputs {
237 return Err(format!(
238 "public input metadata mismatch: declared {} inputs but recorded {} indices",
239 cs.num_public_inputs,
240 cs.public_input_indices.len()
241 ));
242 }
243
244 for &idx in &cs.public_input_indices {
245 if idx >= cs.num_variables() {
246 return Err(format!(
247 "public input index {idx} is outside variable count {}",
248 cs.num_variables()
249 ));
250 }
251 }
252
253 Ok(())
254}
255
256fn public_inputs_from_witness(cs: &ConstraintSystem, witness: &[u64]) -> Result<Vec<u64>> {
257 let mut public_inputs = Vec::with_capacity(cs.public_input_indices.len());
258 for &idx in &cs.public_input_indices {
259 let Some(value) = witness.get(idx) else {
260 return Err(ProofError::InvalidWitness(format!(
261 "public input index {idx} is outside witness length {}",
262 witness.len()
263 )));
264 };
265 public_inputs.push(*value);
266 }
267 Ok(public_inputs)
268}
269
270fn compute_proof_component(prefix: &[u8], circuit_hash: &Hash256, values: &[u64]) -> [u8; 32] {
271 let mut hasher = blake3::Hasher::new();
272 hasher.update(prefix);
273 hasher.update(circuit_hash.as_bytes());
274 for &value in values {
275 hasher.update(&value.to_le_bytes());
276 }
277 *hasher.finalize().as_bytes()
278}
279
280fn compute_c_component(
281 circuit_hash: &Hash256,
282 public_inputs: &[u64],
283 a: &[u8; 32],
284 b: &[u8; 32],
285) -> [u8; 32] {
286 let mut hasher = blake3::Hasher::new();
287 hasher.update(b"snark:c:verify:");
288 hasher.update(circuit_hash.as_bytes());
289 for &inp in public_inputs {
290 hasher.update(&inp.to_le_bytes());
291 }
292 hasher.update(a);
293 hasher.update(b);
294 *hasher.finalize().as_bytes()
295}
296
// Unit tests for the pedagogical setup -> prove -> verify flow. They only build
// for tests with the `unaudited-pedagogical-proofs` feature, presumably because
// `crate::guard_unaudited` rejects these entry points otherwise — TODO confirm.
#[cfg(all(test, feature = "unaudited-pedagogical-proofs"))]
mod tests {
    use super::*;
    use crate::circuit::{LinearCombination, allocate, allocate_public, enforce};

    /// Circuit enforcing `x * y = z`, with `x` and `z` public and `y` private.
    #[derive(Debug)]
    struct MulCircuit {
        x: Option<u64>,
        y: Option<u64>,
        z: Option<u64>,
    }

    impl Circuit for MulCircuit {
        fn synthesize(&self, cs: &mut ConstraintSystem) -> crate::error::Result<()> {
            // Allocation order fixes the witness layout: [x, y, z].
            let x = allocate_public(cs, self.x);
            let y = allocate(cs, self.y);
            let z = allocate_public(cs, self.z);
            enforce(
                cs,
                &LinearCombination::from_variable(x),
                &LinearCombination::from_variable(y),
                &LinearCombination::from_variable(z),
            );
            Ok(())
        }
    }

    /// Builds a `MulCircuit` whose `z` is the checked product of `x` and `y`.
    fn make_mul_circuit(x: u64, y: u64) -> MulCircuit {
        MulCircuit {
            x: Some(x),
            y: Some(y),
            z: Some(x.checked_mul(y).expect("test witness product fits u64")),
        }
    }

    /// Key sizes reflect the circuit: 3 variables, 1 constraint, 2 public inputs.
    #[test]
    fn setup_produces_keys() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, vk) = setup(&circuit).unwrap();
        assert_eq!(pk.num_variables, 3);
        assert_eq!(pk.num_constraints, 1);
        assert_eq!(pk.num_public_inputs, 2);
        assert_eq!(vk.num_public_inputs, 2);
        assert_eq!(pk.circuit_hash, vk.circuit_hash);
    }

    /// Round trip: a proof for 3 * 4 = 12 verifies against public inputs [x, z].
    #[test]
    fn valid_proof_verifies() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, vk) = setup(&circuit).unwrap();

        let proof = prove(&pk, &circuit, &[3, 4, 12]).unwrap();
        assert!(verify(&vk, &proof, &[3, 12]).unwrap());
    }

    /// A valid proof is rejected when either public input is changed.
    #[test]
    fn invalid_proof_rejected() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, vk) = setup(&circuit).unwrap();

        let proof = prove(&pk, &circuit, &[3, 4, 12]).unwrap();

        assert!(!verify(&vk, &proof, &[3, 13]).unwrap());
        assert!(!verify(&vk, &proof, &[4, 12]).unwrap());
    }

    /// Proofs differ across statements. The public `x` differs here (3 vs 6);
    /// proofs never depend on private values (see
    /// `proof_components_do_not_commit_private_witness_values`).
    #[test]
    fn different_witnesses_produce_different_proofs() {
        let c1 = make_mul_circuit(3, 4);
        let c2 = make_mul_circuit(6, 2);
        let (pk1, _) = setup(&c1).unwrap();
        let (pk2, _) = setup(&c2).unwrap();

        let proof1 = prove(&pk1, &c1, &[3, 4, 12]).unwrap();
        let proof2 = prove(&pk2, &c2, &[6, 2, 12]).unwrap();

        assert_ne!(proof1, proof2);
    }

    /// A witness shorter than the variable count is an `InvalidWitness` error.
    #[test]
    fn wrong_witness_count_rejected() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, _) = setup(&circuit).unwrap();
        let err = prove(&pk, &circuit, &[3, 4]).unwrap_err();
        assert!(matches!(err, ProofError::InvalidWitness(_)));
    }

    /// Changing only a private value must not change the proof bytes.
    #[test]
    fn proof_components_do_not_commit_private_witness_values() {
        /// Constrains only `public * 1 = public`; the private variable is left
        /// unconstrained, so any private value satisfies the circuit.
        #[derive(Debug)]
        struct UnderconstrainedCircuit {
            public: Option<u64>,
            private: Option<u64>,
        }

        impl Circuit for UnderconstrainedCircuit {
            fn synthesize(&self, cs: &mut ConstraintSystem) -> crate::error::Result<()> {
                let public = allocate_public(cs, self.public);
                let _private = allocate(cs, self.private);
                enforce(
                    cs,
                    &LinearCombination::from_variable(public),
                    &LinearCombination::constant(1),
                    &LinearCombination::from_variable(public),
                );
                Ok(())
            }
        }

        let circuit = UnderconstrainedCircuit {
            public: Some(7),
            private: Some(11),
        };
        let (pk, _) = setup(&circuit).unwrap();

        let proof_a = prove(&pk, &circuit, &[7, 11]).unwrap();
        let proof_b = prove(&pk, &circuit, &[7, 99]).unwrap();

        assert_eq!(
            proof_a, proof_b,
            "serialized proof components must not reveal deterministic commitments to private witness values"
        );
    }

    /// `setup` rejects a public-input index that points past the allocated variables.
    #[test]
    fn setup_rejects_out_of_range_public_input_indices() {
        /// Allocates one private variable, then hand-records a public-input
        /// index one past it, bypassing `allocate_public`.
        #[derive(Debug)]
        struct InvalidPublicInputCircuit;

        impl Circuit for InvalidPublicInputCircuit {
            fn synthesize(&self, cs: &mut ConstraintSystem) -> crate::error::Result<()> {
                let value = allocate(cs, Some(1));
                enforce(
                    cs,
                    &LinearCombination::from_variable(value),
                    &LinearCombination::constant(1),
                    &LinearCombination::from_variable(value),
                );
                // Keep the count consistent so only the range check can fail.
                cs.public_input_indices.push(value.index + 1);
                cs.num_public_inputs += 1;
                Ok(())
            }
        }

        let err = setup(&InvalidPublicInputCircuit).unwrap_err();

        assert!(matches!(err, ProofError::SetupError(_)));
    }

    /// A witness with z != x * y fails the satisfiability check in `prove`.
    #[test]
    fn unsatisfied_witness_rejected() {
        let circuit = MulCircuit {
            x: Some(3),
            y: Some(4),
            z: Some(12),
        };
        let (pk, _) = setup(&circuit).unwrap();

        let err = prove(&pk, &circuit, &[3, 4, 13]).unwrap_err();
        assert!(matches!(err, ProofError::ProofGenerationFailed(_)));
    }

    /// Too few or too many public inputs make `verify` return `Ok(false)`.
    #[test]
    fn wrong_public_input_count_rejected() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, vk) = setup(&circuit).unwrap();

        let proof = prove(&pk, &circuit, &[3, 4, 12]).unwrap();
        assert!(!verify(&vk, &proof, &[3]).unwrap());
        assert!(!verify(&vk, &proof, &[3, 12, 99]).unwrap());
    }

    /// Flipping bits in a proof component causes rejection.
    #[test]
    fn tampered_proof_rejected() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, vk) = setup(&circuit).unwrap();
        let mut proof = prove(&pk, &circuit, &[3, 4, 12]).unwrap();
        proof.a[0] ^= 0xFF;
        assert!(!verify(&vk, &proof, &[3, 12]).unwrap());
    }

    /// A circuit with no constraints cannot be set up.
    #[test]
    fn setup_empty_circuit_rejected() {
        struct EmptyCircuit;
        impl Circuit for EmptyCircuit {
            fn synthesize(&self, _cs: &mut ConstraintSystem) -> crate::error::Result<()> {
                Ok(())
            }
        }
        let err = setup(&EmptyCircuit).unwrap_err();
        assert!(matches!(err, ProofError::SetupError(_)));
    }

    /// Proving the same statement twice yields byte-identical proofs.
    #[test]
    fn proof_deterministic() {
        let circuit = make_mul_circuit(5, 6);
        let (pk, _) = setup(&circuit).unwrap();
        let p1 = prove(&pk, &circuit, &[5, 6, 30]).unwrap();
        let p2 = prove(&pk, &circuit, &[5, 6, 30]).unwrap();
        assert_eq!(p1, p2);
    }
}