use exo_core::types::Hash256;
use serde::{Deserialize, Serialize};
use crate::{
circuit::{Circuit, ConstraintSystem},
error::{ProofError, Result},
};
/// Proving key produced by [`setup`] and consumed by [`prove`].
///
/// Records the shape of the constraint system synthesized at setup time together
/// with its structural hash. `prove` re-synthesizes the circuit and refuses to
/// produce a proof when the recomputed hash differs from `circuit_hash`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProvingKey {
// Number of variables in the constraint system at setup time.
pub num_variables: usize,
// Number of constraints in the constraint system at setup time (never zero:
// `setup` rejects empty circuits).
pub num_constraints: usize,
// Number of public inputs the circuit declared.
pub num_public_inputs: usize,
// Structural hash of the circuit; `prove` compares its own recomputation to this.
pub circuit_hash: Hash256,
}
/// Verifying key produced by [`setup`] and consumed by [`verify`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerifyingKey {
// Exact number of public inputs `verify` accepts; any other count yields `Ok(false)`.
pub num_public_inputs: usize,
// Structural hash of the circuit, used when recomputing the expected proof.
pub circuit_hash: Hash256,
}
/// A proof emitted by [`prove`].
///
/// NOTE: pedagogical construction. Every component is a blake3 digest computed
/// from the circuit hash and the public inputs (plus `a`/`b` for `c`); no private
/// witness value is hashed. Anyone holding the verifying key and the public
/// inputs can therefore compute these bytes, so this provides no cryptographic
/// soundness. The module is gated behind `crate::guard_unaudited`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Proof {
// Digest of the `snark:a:statement:` tag, circuit hash and public inputs.
pub a: [u8; 32],
// Digest of the `snark:b:statement:` tag, circuit hash and public inputs.
pub b: [u8; 32],
// Digest of the `snark:c:verify:` tag, circuit hash, public inputs, `a` and `b`.
pub c: [u8; 32],
}
/// Runs the pedagogical setup for `circuit`, returning matching proving and
/// verifying keys.
///
/// The circuit is synthesized once into a fresh [`ConstraintSystem`]. The keys
/// capture its dimensions and its structural hash.
///
/// # Errors
///
/// * Whatever `crate::guard_unaudited` returns (presumably when the unaudited
///   feature is disabled — confirm against that helper).
/// * [`ProofError::SetupError`] when synthesis fails, the circuit has no
///   constraints, its public-input metadata is inconsistent, or hashing fails.
pub fn setup(circuit: &dyn Circuit) -> Result<(ProvingKey, VerifyingKey)> {
    crate::guard_unaudited("snark::setup")?;

    let mut system = ConstraintSystem::new();
    if let Err(err) = circuit.synthesize(&mut system) {
        return Err(ProofError::SetupError(err.to_string()));
    }

    let constraint_count = system.num_constraints();
    if constraint_count == 0 {
        return Err(ProofError::SetupError(
            "circuit has no constraints".to_string(),
        ));
    }

    validate_public_input_indices(&system).map_err(ProofError::SetupError)?;

    let circuit_hash = match compute_circuit_hash(&system) {
        Ok(hash) => hash,
        Err(err) => return Err(ProofError::SetupError(err.to_string())),
    };

    let public_count = system.num_public_inputs;
    let verifying_key = VerifyingKey {
        num_public_inputs: public_count,
        circuit_hash,
    };
    let proving_key = ProvingKey {
        num_variables: system.num_variables(),
        num_constraints: constraint_count,
        num_public_inputs: public_count,
        circuit_hash,
    };
    Ok((proving_key, verifying_key))
}
/// Produces a [`Proof`] that `witness` satisfies `circuit` under `pk`.
///
/// The circuit is re-synthesized and the witness is loaded into its variables
/// in order. The structural hash must match `pk.circuit_hash`, and every
/// constraint must hold. Proof components are then derived from the circuit
/// hash and the public-input slice of the witness only.
///
/// # Errors
///
/// * Whatever `crate::guard_unaudited` returns.
/// * [`ProofError::ProofGenerationFailed`] when synthesis or hashing fails,
///   the circuit does not match `pk`, or the witness violates a constraint.
/// * [`ProofError::InvalidWitness`] when the public-input metadata is
///   inconsistent or the witness length differs from the variable count.
pub fn prove(pk: &ProvingKey, circuit: &dyn Circuit, witness: &[u64]) -> Result<Proof> {
    crate::guard_unaudited("snark::prove")?;

    let mut system = ConstraintSystem::new();
    if let Err(err) = circuit.synthesize(&mut system) {
        return Err(ProofError::ProofGenerationFailed(err.to_string()));
    }
    validate_public_input_indices(&system).map_err(ProofError::InvalidWitness)?;

    let expected_len = system.num_variables();
    if witness.len() != expected_len {
        return Err(ProofError::InvalidWitness(format!(
            "expected {} witness values, got {}",
            expected_len,
            witness.len()
        )));
    }

    // Lengths are equal at this point, so zip assigns every variable.
    for (variable, &value) in system.variables.iter_mut().zip(witness) {
        variable.value = Some(value);
    }

    let circuit_hash = compute_circuit_hash(&system)
        .map_err(|err| ProofError::ProofGenerationFailed(err.to_string()))?;
    if circuit_hash != pk.circuit_hash {
        return Err(ProofError::ProofGenerationFailed(
            "circuit structure does not match proving key".to_string(),
        ));
    }
    if !system.is_satisfied() {
        return Err(ProofError::ProofGenerationFailed(
            "witness does not satisfy constraints".to_string(),
        ));
    }

    let statement = public_inputs_from_witness(&system, witness)?;
    let a = compute_proof_component(b"snark:a:statement:", &circuit_hash, &statement);
    let b = compute_proof_component(b"snark:b:statement:", &circuit_hash, &statement);
    let c = compute_c_component(&circuit_hash, &statement, &a, &b);
    Ok(Proof { a, b, c })
}
pub fn verify(vk: &VerifyingKey, proof: &Proof, public_inputs: &[u64]) -> Result<bool> {
crate::guard_unaudited("snark::verify")?;
if public_inputs.len() != vk.num_public_inputs {
return Ok(false);
}
let expected_c = compute_c_component(&vk.circuit_hash, public_inputs, &proof.a, &proof.b);
Ok(proof.c == expected_c)
}
/// Widens a `usize` to `u64` for hashing.
///
/// # Errors
///
/// Returns [`ProofError::SetupError`] if `n` does not fit in a `u64`.
fn usize_to_u64(n: usize) -> Result<u64> {
    match u64::try_from(n) {
        Ok(widened) => Ok(widened),
        Err(_) => Err(ProofError::SetupError(format!("value {n} overflows u64"))),
    }
}
/// Computes a structural fingerprint of a synthesized constraint system.
///
/// The digest binds:
///
/// * the variable, constraint and public-input counts;
/// * the index of every public input, so two circuits with identical
///   constraints that expose different variables cannot share keys;
/// * every `(coefficient, index)` term of each constraint's `A`, `B` and `C`
///   linear combinations.
///
/// Assigned variable values are never read, so a circuit hashes the same
/// before and after a witness is loaded.
///
/// Each term list is prefixed with its length so the byte encoding is
/// injective. A bare separator byte can collide with the first byte of a
/// term, letting a term be re-parsed as belonging to the neighbouring list.
/// The domain tag carries a version so digests from the old encoding are
/// distinguishable.
///
/// # Errors
///
/// Returns [`ProofError::SetupError`] if any count or index does not fit in a `u64`.
fn compute_circuit_hash(cs: &ConstraintSystem) -> Result<Hash256> {
    let mut hasher = blake3::Hasher::new();
    hasher.update(b"snark:circuit:v2:");
    hasher.update(&usize_to_u64(cs.num_variables())?.to_le_bytes());
    hasher.update(&usize_to_u64(cs.num_constraints())?.to_le_bytes());
    hasher.update(&usize_to_u64(cs.num_public_inputs)?.to_le_bytes());

    // Bind *which* variables are public, not just how many.
    hasher.update(&usize_to_u64(cs.public_input_indices.len())?.to_le_bytes());
    for &idx in &cs.public_input_indices {
        hasher.update(&usize_to_u64(idx)?.to_le_bytes());
    }

    for constraint in &cs.constraints {
        for lc in [&constraint.a_terms, &constraint.b_terms, &constraint.c_terms] {
            hasher.update(&usize_to_u64(lc.terms.len())?.to_le_bytes());
            for &(coeff, idx) in &lc.terms {
                hasher.update(&coeff.to_le_bytes());
                hasher.update(&usize_to_u64(idx)?.to_le_bytes());
            }
        }
    }
    Ok(Hash256::from_bytes(*hasher.finalize().as_bytes()))
}
/// Checks that the constraint system's public-input metadata is self-consistent.
///
/// Two things must hold. The number of recorded public-input indices must equal
/// `num_public_inputs`. Every recorded index must refer to an existing variable.
/// On failure, a human-readable description is returned so each caller can
/// wrap it in its own error variant.
fn validate_public_input_indices(cs: &ConstraintSystem) -> std::result::Result<(), String> {
    let declared = cs.num_public_inputs;
    let recorded = cs.public_input_indices.len();
    if recorded != declared {
        return Err(format!(
            "public input metadata mismatch: declared {} inputs but recorded {} indices",
            declared, recorded
        ));
    }

    let variable_count = cs.num_variables();
    match cs
        .public_input_indices
        .iter()
        .find(|&&idx| idx >= variable_count)
    {
        Some(idx) => Err(format!(
            "public input index {idx} is outside variable count {}",
            variable_count
        )),
        None => Ok(()),
    }
}
/// Extracts the public-input values from `witness`, in the order given by
/// `cs.public_input_indices`.
///
/// # Errors
///
/// Returns [`ProofError::InvalidWitness`] for the first index that lies outside
/// `witness`.
fn public_inputs_from_witness(cs: &ConstraintSystem, witness: &[u64]) -> Result<Vec<u64>> {
    cs.public_input_indices
        .iter()
        .map(|&idx| {
            witness.get(idx).copied().ok_or_else(|| {
                ProofError::InvalidWitness(format!(
                    "public input index {idx} is outside witness length {}",
                    witness.len()
                ))
            })
        })
        .collect()
}
/// Hashes a domain tag, the circuit hash and a list of `u64` values into a
/// 32-byte blake3 digest.
///
/// Values are absorbed as little-endian bytes in slice order.
fn compute_proof_component(prefix: &[u8], circuit_hash: &Hash256, values: &[u64]) -> [u8; 32] {
    let mut digest = blake3::Hasher::new();
    digest.update(prefix).update(circuit_hash.as_bytes());
    values.iter().for_each(|value| {
        digest.update(&value.to_le_bytes());
    });
    *digest.finalize().as_bytes()
}
/// Computes the `c` proof component, a blake3 digest that binds the circuit
/// hash, the public inputs and the `a`/`b` components.
///
/// The absorption order is: the `snark:c:verify:` tag, the circuit hash, each
/// public input as little-endian bytes, then `a`, then `b`.
fn compute_c_component(
    circuit_hash: &Hash256,
    public_inputs: &[u64],
    a: &[u8; 32],
    b: &[u8; 32],
) -> [u8; 32] {
    let mut digest = blake3::Hasher::new();
    digest
        .update(b"snark:c:verify:")
        .update(circuit_hash.as_bytes());
    for encoded in public_inputs.iter().map(|input| input.to_le_bytes()) {
        digest.update(&encoded);
    }
    digest.update(a).update(b);
    *digest.finalize().as_bytes()
}
/// Unit tests for setup / prove / verify.
///
/// Compiled only with the `unaudited-pedagogical-proofs` feature, because every
/// entry point calls `crate::guard_unaudited`.
#[cfg(all(test, feature = "unaudited-pedagogical-proofs"))]
mod tests {
    use super::*;
    use crate::circuit::{LinearCombination, allocate, allocate_public, enforce};

    /// Circuit enforcing `x * y = z`, with `x` and `z` public and `y` private.
    #[derive(Debug)]
    struct MulCircuit {
        x: Option<u64>,
        y: Option<u64>,
        z: Option<u64>,
    }

    impl Circuit for MulCircuit {
        fn synthesize(&self, cs: &mut ConstraintSystem) -> crate::error::Result<()> {
            let x = allocate_public(cs, self.x);
            let y = allocate(cs, self.y);
            let z = allocate_public(cs, self.z);
            enforce(
                cs,
                &LinearCombination::from_variable(x),
                &LinearCombination::from_variable(y),
                &LinearCombination::from_variable(z),
            );
            Ok(())
        }
    }

    fn make_mul_circuit(x: u64, y: u64) -> MulCircuit {
        MulCircuit {
            x: Some(x),
            y: Some(y),
            z: Some(x.checked_mul(y).expect("test witness product fits u64")),
        }
    }

    #[test]
    fn setup_produces_keys() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, vk) = setup(&circuit).unwrap();
        assert_eq!(pk.num_variables, 3);
        assert_eq!(pk.num_constraints, 1);
        assert_eq!(pk.num_public_inputs, 2);
        assert_eq!(vk.num_public_inputs, 2);
        assert_eq!(pk.circuit_hash, vk.circuit_hash);
    }

    #[test]
    fn valid_proof_verifies() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, vk) = setup(&circuit).unwrap();
        let proof = prove(&pk, &circuit, &[3, 4, 12]).unwrap();
        assert!(verify(&vk, &proof, &[3, 12]).unwrap());
    }

    #[test]
    fn invalid_proof_rejected() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, vk) = setup(&circuit).unwrap();
        let proof = prove(&pk, &circuit, &[3, 4, 12]).unwrap();
        assert!(!verify(&vk, &proof, &[3, 13]).unwrap());
        assert!(!verify(&vk, &proof, &[4, 12]).unwrap());
    }

    #[test]
    fn different_witnesses_produce_different_proofs() {
        let c1 = make_mul_circuit(3, 4);
        let c2 = make_mul_circuit(6, 2);
        let (pk1, _) = setup(&c1).unwrap();
        let (pk2, _) = setup(&c2).unwrap();
        let proof1 = prove(&pk1, &c1, &[3, 4, 12]).unwrap();
        let proof2 = prove(&pk2, &c2, &[6, 2, 12]).unwrap();
        assert_ne!(proof1, proof2);
    }

    #[test]
    fn wrong_witness_count_rejected() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, _) = setup(&circuit).unwrap();
        let err = prove(&pk, &circuit, &[3, 4]).unwrap_err();
        assert!(matches!(err, ProofError::InvalidWitness(_)));
    }

    #[test]
    fn proof_components_do_not_commit_private_witness_values() {
        #[derive(Debug)]
        struct UnderconstrainedCircuit {
            public: Option<u64>,
            private: Option<u64>,
        }

        impl Circuit for UnderconstrainedCircuit {
            fn synthesize(&self, cs: &mut ConstraintSystem) -> crate::error::Result<()> {
                let public = allocate_public(cs, self.public);
                let _private = allocate(cs, self.private);
                enforce(
                    cs,
                    &LinearCombination::from_variable(public),
                    &LinearCombination::constant(1),
                    &LinearCombination::from_variable(public),
                );
                Ok(())
            }
        }

        let circuit = UnderconstrainedCircuit {
            public: Some(7),
            private: Some(11),
        };
        let (pk, _) = setup(&circuit).unwrap();
        let proof_a = prove(&pk, &circuit, &[7, 11]).unwrap();
        let proof_b = prove(&pk, &circuit, &[7, 99]).unwrap();
        assert_eq!(
            proof_a, proof_b,
            "serialized proof components must not reveal deterministic commitments to private witness values"
        );
    }

    #[test]
    fn setup_rejects_out_of_range_public_input_indices() {
        #[derive(Debug)]
        struct InvalidPublicInputCircuit;

        impl Circuit for InvalidPublicInputCircuit {
            fn synthesize(&self, cs: &mut ConstraintSystem) -> crate::error::Result<()> {
                let value = allocate(cs, Some(1));
                enforce(
                    cs,
                    &LinearCombination::from_variable(value),
                    &LinearCombination::constant(1),
                    &LinearCombination::from_variable(value),
                );
                cs.public_input_indices.push(value.index + 1);
                cs.num_public_inputs += 1;
                Ok(())
            }
        }

        let err = setup(&InvalidPublicInputCircuit).unwrap_err();
        assert!(matches!(err, ProofError::SetupError(_)));
    }

    #[test]
    fn unsatisfied_witness_rejected() {
        let circuit = MulCircuit {
            x: Some(3),
            y: Some(4),
            z: Some(12),
        };
        let (pk, _) = setup(&circuit).unwrap();
        let err = prove(&pk, &circuit, &[3, 4, 13]).unwrap_err();
        assert!(matches!(err, ProofError::ProofGenerationFailed(_)));
    }

    #[test]
    fn wrong_public_input_count_rejected() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, vk) = setup(&circuit).unwrap();
        let proof = prove(&pk, &circuit, &[3, 4, 12]).unwrap();
        assert!(!verify(&vk, &proof, &[3]).unwrap());
        assert!(!verify(&vk, &proof, &[3, 12, 99]).unwrap());
    }

    #[test]
    fn tampered_proof_rejected() {
        let circuit = make_mul_circuit(3, 4);
        let (pk, vk) = setup(&circuit).unwrap();
        let mut proof = prove(&pk, &circuit, &[3, 4, 12]).unwrap();
        proof.a[0] ^= 0xFF;
        assert!(!verify(&vk, &proof, &[3, 12]).unwrap());
    }

    #[test]
    fn setup_empty_circuit_rejected() {
        #[derive(Debug)]
        struct EmptyCircuit;

        impl Circuit for EmptyCircuit {
            fn synthesize(&self, _cs: &mut ConstraintSystem) -> crate::error::Result<()> {
                Ok(())
            }
        }

        let err = setup(&EmptyCircuit).unwrap_err();
        assert!(matches!(err, ProofError::SetupError(_)));
    }

    #[test]
    fn proof_deterministic() {
        let circuit = make_mul_circuit(5, 6);
        let (pk, _) = setup(&circuit).unwrap();
        let p1 = prove(&pk, &circuit, &[5, 6, 30]).unwrap();
        let p2 = prove(&pk, &circuit, &[5, 6, 30]).unwrap();
        assert_eq!(p1, p2);
    }
}