use field_cat::{Field, FieldBytes};
use plonkish_cat::{Constraint, ConstraintSet, Expression, Wire};
use crate::commit::merkle::{MerkleProof, MerkleRoot, MerkleTree};
use crate::error::Error;
use crate::poly::MultilinearPoly;
use crate::sumcheck::{SumcheckClaim, SumcheckProof, sumcheck_prove, sumcheck_verify};
use crate::transcript::Transcript;
/// Label passed to `Transcript::new` by both `prove` and `verify`; the two
/// sides must use the same label for the verifier to reproduce the prover's
/// transcript (NOTE(review): bump the version suffix if the protocol changes).
const TRANSCRIPT_LABEL: &[u8] = b"proof-cat-v0.1";
#[derive(Debug, Clone)]
pub struct Witness<F: Field> {
values: Vec<F>,
}
impl<F: Field> Witness<F> {
#[must_use]
pub fn new(values: Vec<F>) -> Self {
Self { values }
}
#[must_use]
pub fn values(&self) -> &[F] {
&self.values
}
fn assignment(&self) -> impl Fn(Wire) -> Result<F, plonkish_cat::Error> + '_ {
|wire| {
self.values
.get(wire.index())
.cloned()
.ok_or(plonkish_cat::Error::WireOutOfBounds {
wire_index: wire.index(),
allocated: self.values.len(),
})
}
}
}
/// One opened witness wire: its index, the claimed value, and the Merkle
/// authentication path for that value.
#[derive(Debug, Clone)]
pub struct WireOpening<F: Field> {
    wire_index: usize,
    value: F,
    merkle_proof: MerkleProof,
}

impl<F: Field> WireOpening<F> {
    /// Merkle authentication path for [`Self::value`] at [`Self::wire_index`].
    #[must_use]
    pub fn merkle_proof(&self) -> &MerkleProof {
        let Self { merkle_proof, .. } = self;
        merkle_proof
    }

    /// Value claimed for the opened wire.
    #[must_use]
    pub fn value(&self) -> &F {
        let Self { value, .. } = self;
        value
    }

    /// Index of the opened wire within the witness.
    #[must_use]
    pub fn wire_index(&self) -> usize {
        let Self { wire_index, .. } = self;
        *wire_index
    }
}
/// A proof as produced by `prove`.
///
/// It contains the Merkle root committing to the witness, the sum-check proof
/// over the constraint evaluations, and the wire openings.
#[derive(Debug, Clone)]
pub struct Proof<F: Field> {
    witness_commitment: MerkleRoot,
    sumcheck: SumcheckProof<F>,
    wire_openings: Vec<WireOpening<F>>,
}

impl<F: Field> Proof<F> {
    /// Wire openings carried by this proof, in the order they were emitted.
    #[must_use]
    pub fn wire_openings(&self) -> &[WireOpening<F>] {
        self.wire_openings.as_slice()
    }

    /// Sum-check proof over the (padded) constraint evaluations.
    #[must_use]
    pub fn sumcheck_proof(&self) -> &SumcheckProof<F> {
        let Self { sumcheck, .. } = self;
        sumcheck
    }

    /// Merkle root committing to the witness values.
    #[must_use]
    pub fn witness_commitment(&self) -> &MerkleRoot {
        let Self { witness_commitment, .. } = self;
        witness_commitment
    }
}
pub fn prove<F: FieldBytes>(
constraints: &ConstraintSet<F>,
witness: &Witness<F>,
) -> Result<Proof<F>, Error> {
let all_constraints = flatten_constraints(constraints);
if all_constraints.is_empty() {
Err(Error::EmptyConstraintSet)
} else {
validate_witness(&all_constraints, witness)?;
let tree = MerkleTree::from_field_values(witness.values());
let evals = evaluate_constraints(&all_constraints, witness)?;
let padded = pad_to_power_of_two(evals);
let poly = MultilinearPoly::from_evals(padded)?;
let transcript = Transcript::new(TRANSCRIPT_LABEL)
.absorb_bytes(tree.root().as_bytes())
.absorb_bytes(&all_constraints.len().to_le_bytes());
let (sumcheck, _, _) = sumcheck_prove(&SumcheckClaim::new(poly, F::zero()), transcript)?;
let wire_openings: Result<Vec<WireOpening<F>>, Error> = (0..witness.values().len())
.map(|i| {
let merkle_proof = tree.open(i)?;
Ok(WireOpening {
wire_index: i,
value: witness.values()[i].clone(),
merkle_proof,
})
})
.collect();
Ok(Proof {
witness_commitment: tree.root(),
sumcheck,
wire_openings: wire_openings?,
})
}
}
pub fn verify<F: FieldBytes>(
constraints: &ConstraintSet<F>,
proof: &Proof<F>,
) -> Result<bool, Error> {
let all_constraints = flatten_constraints(constraints);
if all_constraints.is_empty() {
Err(Error::EmptyConstraintSet)
} else {
let padded_len = pad_to_power_of_two_len(all_constraints.len());
let num_vars = usize::try_from(padded_len.trailing_zeros())
.map_err(|_| Error::NotPowerOfTwo { value: padded_len })?;
let transcript = Transcript::new(TRANSCRIPT_LABEL)
.absorb_bytes(proof.witness_commitment.as_bytes())
.absorb_bytes(&all_constraints.len().to_le_bytes());
let (final_eval, challenges, _) = sumcheck_verify(
&proof.sumcheck,
&F::zero(),
crate::poly::NumVars::new(num_vars),
transcript,
)?;
let all_openings_valid = proof.wire_openings.iter().all(|opening| {
MerkleTree::verify_opening(
&proof.witness_commitment,
opening.wire_index,
&opening.value,
&opening.merkle_proof,
)
});
if all_openings_valid {
let assignment = build_assignment_from_openings(&proof.wire_openings);
let evals = evaluate_constraints_with(&all_constraints, &assignment)?;
let padded = pad_to_power_of_two(evals);
let poly = MultilinearPoly::from_evals(padded)?;
let expected_eval = poly.evaluate(&challenges)?;
Ok(expected_eval == final_eval)
} else {
Err(Error::MerkleVerificationFailed)
}
}
}
/// Gathers every constraint to prove into a single list.
///
/// The polynomial constraints come first, in their original order. They are
/// followed by one `left - right` constraint per copy constraint, also in
/// order.
fn flatten_constraints<F: Field>(cs: &ConstraintSet<F>) -> Vec<Constraint<F>> {
    let copy_as_polynomial = cs.copy_constraints().iter().map(|copy| {
        Constraint::new(Expression::Wire(copy.left()) - Expression::Wire(copy.right()))
    });
    cs.constraints()
        .iter()
        .cloned()
        .chain(copy_as_polynomial)
        .collect()
}
fn validate_witness<F: Field>(
constraints: &[Constraint<F>],
witness: &Witness<F>,
) -> Result<(), Error> {
let assign = witness.assignment();
constraints.iter().enumerate().try_for_each(|(i, c)| {
c.is_satisfied(&assign).map_err(Error::from).and_then(|ok| {
if ok {
Ok(())
} else {
Err(Error::UnsatisfiedConstraint { index: i })
}
})
})
}
fn evaluate_constraints<F: Field>(
constraints: &[Constraint<F>],
witness: &Witness<F>,
) -> Result<Vec<F>, Error> {
let assign = witness.assignment();
constraints
.iter()
.map(|c| c.expression().evaluate(&assign).map_err(Error::from))
.collect()
}
/// Evaluates each constraint's expression using `assignment` to resolve wires.
///
/// Returns one value per constraint, in input order.
///
/// # Errors
///
/// Stops at the first constraint whose evaluation fails and returns that
/// error, converted into [`Error`].
fn evaluate_constraints_with<F: Field>(
    constraints: &[Constraint<F>],
    assignment: &impl Fn(Wire) -> Result<F, plonkish_cat::Error>,
) -> Result<Vec<F>, Error> {
    let mut results = Vec::with_capacity(constraints.len());
    for constraint in constraints {
        let value = constraint.expression().evaluate(assignment)?;
        results.push(value);
    }
    Ok(results)
}
/// Builds a wire-lookup closure backed by the values opened in a proof.
///
/// The openings are indexed once into a hash map, so each lookup is O(1).
/// Previously every lookup scanned all openings, which made evaluating the
/// constraint set quadratic in the witness size. A hash map is used instead
/// of a dense vector because `wire_index` comes from the (untrusted) proof
/// and may be arbitrarily large.
///
/// If several openings share a `wire_index`, the first one wins, matching a
/// front-to-back search. A wire with no opening yields
/// `plonkish_cat::Error::WireOutOfBounds`, with `allocated` set to the number
/// of openings.
fn build_assignment_from_openings<F: Field>(
    openings: &[WireOpening<F>],
) -> impl Fn(Wire) -> Result<F, plonkish_cat::Error> + '_ {
    let mut by_index: std::collections::HashMap<usize, F> =
        std::collections::HashMap::with_capacity(openings.len());
    for opening in openings {
        by_index
            .entry(opening.wire_index)
            .or_insert_with(|| opening.value.clone());
    }
    let allocated = openings.len();
    move |wire| {
        by_index
            .get(&wire.index())
            .cloned()
            .ok_or(plonkish_cat::Error::WireOutOfBounds {
                wire_index: wire.index(),
                allocated,
            })
    }
}
/// Appends zeros to `v` until its length reaches the value returned by
/// `pad_to_power_of_two_len`: the next power of two, with a minimum of 1.
///
/// Existing entries keep their positions; only trailing zeros are added.
fn pad_to_power_of_two<F: Field>(v: Vec<F>) -> Vec<F> {
    let target = pad_to_power_of_two_len(v.len());
    let mut padded = v;
    padded.resize(target, F::zero());
    padded
}
/// Smallest power of two that is `>= n`, with 0 and 1 both mapping to 1.
///
/// `usize::next_power_of_two` already returns 1 for an input of 0, so no
/// special case is needed.
fn pad_to_power_of_two_len(n: usize) -> usize {
    n.next_power_of_two()
}
#[cfg(test)]
mod tests {
    use super::*;
    use field_cat::F101;
    use plonkish_cat::{CopyConstraint, Expression, Wire};

    #[test]
    fn add_gate_prove_verify() -> Result<(), Error> {
        let expr = Expression::Wire(Wire::new(2))
            - Expression::Wire(Wire::new(0))
            - Expression::Wire(Wire::new(1));
        let cs = ConstraintSet::empty().with_constraint(Constraint::new(expr));
        let witness = Witness::new(vec![F101::new(3), F101::new(4), F101::new(7)]);
        let proof = prove(&cs, &witness)?;
        let valid = verify(&cs, &proof)?;
        assert!(valid);
        Ok(())
    }

    #[test]
    fn mul_gate_prove_verify() -> Result<(), Error> {
        let expr = Expression::Wire(Wire::new(2))
            - Expression::Wire(Wire::new(0)) * Expression::Wire(Wire::new(1));
        let cs = ConstraintSet::empty().with_constraint(Constraint::new(expr));
        let witness = Witness::new(vec![F101::new(5), F101::new(6), F101::new(30)]);
        let proof = prove(&cs, &witness)?;
        assert!(verify(&cs, &proof)?);
        Ok(())
    }

    #[test]
    fn invalid_witness_rejected() {
        let expr = Expression::Wire(Wire::new(2))
            - Expression::Wire(Wire::new(0))
            - Expression::Wire(Wire::new(1));
        let cs = ConstraintSet::empty().with_constraint(Constraint::new(expr));
        let witness = Witness::new(vec![F101::new(3), F101::new(4), F101::new(8)]);
        let result = prove(&cs, &witness);
        assert!(result.is_err());
    }

    #[test]
    fn copy_constraint_prove_verify() -> Result<(), Error> {
        let cs = ConstraintSet::empty().with_copy(CopyConstraint::new(Wire::new(0), Wire::new(1)));
        let witness = Witness::new(vec![F101::new(42), F101::new(42)]);
        let proof = prove(&cs, &witness)?;
        assert!(verify(&cs, &proof)?);
        Ok(())
    }

    #[test]
    fn copy_constraint_invalid_rejected() {
        let cs = ConstraintSet::empty().with_copy(CopyConstraint::new(Wire::new(0), Wire::new(1)));
        let witness = Witness::new(vec![F101::new(1), F101::new(2)]);
        let result = prove(&cs, &witness);
        assert!(result.is_err());
    }

    #[test]
    fn const_gate_prove_verify() -> Result<(), Error> {
        let expr = Expression::Wire(Wire::new(0)) - Expression::Constant(F101::new(42));
        let cs = ConstraintSet::empty().with_constraint(Constraint::new(expr));
        let witness = Witness::new(vec![F101::new(42)]);
        let proof = prove(&cs, &witness)?;
        assert!(verify(&cs, &proof)?);
        Ok(())
    }

    /// Changing an opened value without a matching Merkle path must be caught.
    #[test]
    fn tampered_opening_rejected() -> Result<(), Error> {
        let expr = Expression::Wire(Wire::new(2))
            - Expression::Wire(Wire::new(0))
            - Expression::Wire(Wire::new(1));
        let cs = ConstraintSet::empty().with_constraint(Constraint::new(expr));
        let witness = Witness::new(vec![F101::new(3), F101::new(4), F101::new(7)]);
        let mut proof = prove(&cs, &witness)?;
        proof.wire_openings[2].value = F101::new(8);
        assert!(matches!(
            verify(&cs, &proof),
            Err(Error::MerkleVerificationFailed)
        ));
        Ok(())
    }

    /// Neither side accepts a constraint set with nothing in it.
    #[test]
    fn empty_constraint_set_rejected() {
        let cs: ConstraintSet<F101> = ConstraintSet::empty();
        let witness = Witness::new(vec![F101::new(1)]);
        assert!(matches!(
            prove(&cs, &witness),
            Err(Error::EmptyConstraintSet)
        ));
    }

    /// Padding boundaries: 0 and 1 map to 1, exact powers are unchanged.
    #[test]
    fn pad_len_rounds_up_to_power_of_two() {
        assert_eq!(pad_to_power_of_two_len(0), 1);
        assert_eq!(pad_to_power_of_two_len(1), 1);
        assert_eq!(pad_to_power_of_two_len(2), 2);
        assert_eq!(pad_to_power_of_two_len(3), 4);
        assert_eq!(pad_to_power_of_two_len(8), 8);
        assert_eq!(pad_to_power_of_two_len(9), 16);
    }
}