use std::marker::PhantomData;
use ff::Field;
use crate::crypto::transcript::Transcript;
use crate::crypto::{KzgScheme, KzgSrs, PlonkProof, ProofEvaluations, VerificationKey};
use crate::traits::PairingEngine;
/// A toy circuit made of independent multiplication gates.
///
/// Gate `i` is described by the `i`-th entry of each column: the left input
/// `a_values[i]`, the right input `b_values[i]` and the claimed output
/// `c_values[i]`. The gate holds when `a_values[i] * b_values[i] == c_values[i]`.
#[derive(Clone, Debug)]
pub struct SimpleCircuit<E: PairingEngine> {
    /// Left-input wire column, one value per gate.
    pub a_values: Vec<E::Fr>,
    /// Right-input wire column, one value per gate.
    pub b_values: Vec<E::Fr>,
    /// Output wire column, one value per gate.
    pub c_values: Vec<E::Fr>,
}
impl<E: PairingEngine> SimpleCircuit<E> {
    /// Builds a single-gate circuit whose output is `a * b`.
    pub fn multiplication(a: E::Fr, b: E::Fr) -> Self {
        Self::from_multiplications(vec![(a, b)])
    }

    /// Builds one multiplication gate per `(a, b)` pair, with output `a * b`.
    ///
    /// Gates appear in the same order as `pairs`; an empty input yields an empty circuit.
    pub fn from_multiplications(pairs: Vec<(E::Fr, E::Fr)>) -> Self {
        let (a_values, b_values): (Vec<E::Fr>, Vec<E::Fr>) = pairs.into_iter().unzip();
        let c_values = a_values
            .iter()
            .zip(&b_values)
            .map(|(&a, &b)| a * b)
            .collect();
        Self {
            a_values,
            b_values,
            c_values,
        }
    }

    /// Returns `true` when every gate satisfies `a * b == c`.
    ///
    /// The three wire columns must also have the same length. Without that check
    /// `zip` would only inspect the shortest common prefix, so a circuit with a
    /// dangling, unconstrained wire value would wrongly be reported as satisfied.
    /// An empty circuit is vacuously satisfied.
    pub fn is_satisfied(&self) -> bool {
        let num_gates = self.a_values.len();
        if self.b_values.len() != num_gates || self.c_values.len() != num_gates {
            return false;
        }
        self.a_values
            .iter()
            .zip(&self.b_values)
            .zip(&self.c_values)
            .all(|((&a, &b), &c)| a * b == c)
    }

    /// Returns the gate outputs (`c_values`), which this circuit exposes as its
    /// public inputs.
    pub fn public_inputs(&self) -> &[E::Fr] {
        &self.c_values
    }
}
/// PLONK-style prover for [`SimpleCircuit`]s, built on KZG polynomial commitments.
///
/// Every polynomial commitment and opening proof goes through `kzg`. The
/// evaluation-domain size is read from `vk`.
pub struct PlonkProver<E: PairingEngine> {
    /// Commitment scheme constructed from the structured reference string.
    kzg: KzgScheme<E>,
    /// Verification key; in this file the prover only reads its `domain_size`.
    vk: VerificationKey<E>,
    /// Marker tying the prover to the engine `E`.
    _engine: PhantomData<E>,
}
impl<E: PairingEngine> PlonkProver<E> {
pub fn new(srs: KzgSrs<E>, vk: VerificationKey<E>) -> Self {
Self {
kzg: KzgScheme::new(srs),
vk,
_engine: PhantomData,
}
}
pub fn prove(&self, circuit: &SimpleCircuit<E>) -> Result<PlonkProof<E>, String> {
if !circuit.is_satisfied() {
return Err("Circuit constraints not satisfied".to_string());
}
let n = self.vk.domain_size;
let a_poly = self.pad_and_ifft(&circuit.a_values, n);
let b_poly = self.pad_and_ifft(&circuit.b_values, n);
let c_poly = self.pad_and_ifft(&circuit.c_values, n);
let a_comm =
self.kzg.commit(&a_poly).map_err(|e| format!("Failed to commit to a: {}", e))?;
let b_comm =
self.kzg.commit(&b_poly).map_err(|e| format!("Failed to commit to b: {}", e))?;
let c_comm =
self.kzg.commit(&c_poly).map_err(|e| format!("Failed to commit to c: {}", e))?;
let mut transcript = Transcript::new("PLONK-Prover");
transcript.append_g1::<E>("a", &a_comm.point);
transcript.append_g1::<E>("b", &b_comm.point);
transcript.append_g1::<E>("c", &c_comm.point);
let beta: E::Fr = transcript.challenge_scalar::<E>("beta");
let gamma: E::Fr = transcript.challenge_scalar::<E>("gamma");
let z_poly = self.compute_permutation_polynomial(&a_poly, &b_poly, &c_poly, &beta, &gamma);
let z_comm =
self.kzg.commit(&z_poly).map_err(|e| format!("Failed to commit to z: {}", e))?;
transcript.append_g1::<E>("z", &z_comm.point);
let alpha: E::Fr = transcript.challenge_scalar::<E>("alpha");
let t_poly = self
.compute_quotient_polynomial(&a_poly, &b_poly, &c_poly, &z_poly, &alpha, &beta, &gamma);
let t_parts = self.split_quotient(&t_poly, n);
let t_comms: Vec<_> = t_parts
.iter()
.map(|p| self.kzg.commit(p).map(|c| c.point))
.collect::<Result<_, _>>()
.map_err(|e| format!("Failed to commit to t: {}", e))?;
for tc in &t_comms {
transcript.append_g1::<E>("t", tc);
}
let zeta: E::Fr = transcript.challenge_scalar::<E>("zeta");
let evaluations = self.compute_evaluations(&a_poly, &b_poly, &c_poly, &z_poly, &zeta, n);
let (opening_proof, shifted_opening_proof) = self.compute_opening_proofs(
&a_poly, &b_poly, &c_poly, &z_poly, &t_parts, &zeta, n, &alpha,
)?;
Ok(PlonkProof {
wire_commitments: [a_comm.point, b_comm.point, c_comm.point],
z_commitment: z_comm.point,
t_commitments: t_comms,
opening_proof,
shifted_opening_proof,
evaluations,
})
}
fn pad_and_ifft(&self, values: &[E::Fr], n: usize) -> Vec<E::Fr> {
let mut padded = values.to_vec();
padded.resize(n, E::Fr::ZERO);
padded
}
fn compute_permutation_polynomial(
&self,
_a: &[E::Fr],
_b: &[E::Fr],
_c: &[E::Fr],
_beta: &E::Fr,
_gamma: &E::Fr,
) -> Vec<E::Fr> {
let n = self.vk.domain_size;
let mut z = vec![E::Fr::ZERO; n];
z[0] = E::Fr::ONE;
z
}
#[allow(clippy::too_many_arguments)]
fn compute_quotient_polynomial(
&self,
a: &[E::Fr],
b: &[E::Fr],
c: &[E::Fr],
_z: &[E::Fr],
_alpha: &E::Fr,
_beta: &E::Fr,
_gamma: &E::Fr,
) -> Vec<E::Fr> {
let n = self.vk.domain_size;
let mut t = vec![E::Fr::ZERO; n * 3];
for i in 0..a.len().min(b.len()).min(c.len()) {
t[i] = a[i] * b[i] - c[i];
}
t
}
fn split_quotient(&self, t: &[E::Fr], n: usize) -> Vec<Vec<E::Fr>> {
let mut parts = Vec::new();
for chunk in t.chunks(n) {
parts.push(chunk.to_vec());
}
while parts.len() < 3 {
parts.push(vec![E::Fr::ZERO; n]);
}
parts
}
fn compute_evaluations(
&self,
a: &[E::Fr],
b: &[E::Fr],
c: &[E::Fr],
z: &[E::Fr],
zeta: &E::Fr,
n: usize,
) -> ProofEvaluations<E> {
let a_eval = self.evaluate_poly(a, zeta);
let b_eval = self.evaluate_poly(b, zeta);
let c_eval = self.evaluate_poly(c, zeta);
let omega = self.get_omega(n);
let zeta_omega = *zeta * omega;
let z_omega_eval = self.evaluate_poly(z, &zeta_omega);
ProofEvaluations {
a_eval,
b_eval,
c_eval,
s1_eval: E::Fr::ZERO,
s2_eval: E::Fr::ZERO,
z_shifted_eval: z_omega_eval,
}
}
fn evaluate_poly(&self, coeffs: &[E::Fr], point: &E::Fr) -> E::Fr {
let mut result = E::Fr::ZERO;
for coeff in coeffs.iter().rev() {
result = result * point + coeff;
}
result
}
fn get_omega(&self, n: usize) -> E::Fr {
let gen = E::Fr::from(5u64);
let mut omega = gen;
let log_n = (n as f64).log2() as usize;
for _ in 0..(256 - log_n) {
omega = omega.square();
}
omega
}
#[allow(clippy::too_many_arguments)]
fn compute_opening_proofs(
&self,
a: &[E::Fr],
b: &[E::Fr],
c: &[E::Fr],
z: &[E::Fr],
_t_parts: &[Vec<E::Fr>],
zeta: &E::Fr,
n: usize,
_alpha: &E::Fr,
) -> Result<(E::G1Affine, E::G1Affine), String> {
let mut lin = vec![E::Fr::ZERO; n];
for i in 0..n.min(a.len()) {
lin[i] = a[i] + b[i] + c[i];
}
let lin_eval = self.evaluate_poly(&lin, zeta);
let quotient = self.compute_quotient_for_opening(&lin, zeta, &lin_eval);
let opening =
self.kzg.commit("ient).map_err(|e| format!("Opening proof failed: {}", e))?;
let omega = self.get_omega(n);
let zeta_omega = *zeta * omega;
let z_eval = self.evaluate_poly(z, &zeta_omega);
let shifted_quotient = self.compute_quotient_for_opening(z, &zeta_omega, &z_eval);
let shifted_opening = self
.kzg
.commit(&shifted_quotient)
.map_err(|e| format!("Shifted opening proof failed: {}", e))?;
Ok((opening.point, shifted_opening.point))
}
fn compute_quotient_for_opening(
&self,
poly: &[E::Fr],
point: &E::Fr,
value: &E::Fr,
) -> Vec<E::Fr> {
let mut quotient = vec![E::Fr::ZERO; poly.len()];
if poly.is_empty() {
return quotient;
}
let mut remainder = poly[poly.len() - 1];
for i in (0..poly.len() - 1).rev() {
quotient[i + 1] = remainder;
remainder = poly[i] + remainder * point;
}
quotient[0] = remainder - value;
quotient.remove(0);
quotient
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::bn254::Bn254;
    use group::{Curve, Group};
    use halo2curves::bn256::{Fr, G1, G2};
    use rand::rngs::OsRng;

    /// Test-only SRS.
    ///
    /// Contains `size` G1 powers `tau^0 .. tau^(size-1)` of a fresh random `tau`,
    /// plus `tau * G2` and the G2 generator. `tau` is sampled from `OsRng` and then
    /// dropped; this is not a secure setup.
    fn mock_srs(size: usize) -> KzgSrs<Bn254> {
        let tau = Fr::random(OsRng);
        let g1 = G1::generator();
        let g2 = G2::generator();
        let powers_of_tau_g1 = std::iter::successors(Some(Fr::ONE), |power| Some(*power * tau))
            .take(size)
            .map(|power| (g1 * power).to_affine())
            .collect();
        KzgSrs {
            powers_of_tau_g1,
            tau_g2: (g2 * tau).to_affine(),
            g2_generator: g2.to_affine(),
        }
    }

    /// Verification key for tests.
    ///
    /// Every commitment is the G1 generator and both G2 fields are the G2 generator;
    /// only `domain_size` is meaningful to the prover in this file.
    fn mock_vk(domain_size: usize) -> VerificationKey<Bn254> {
        let g1 = G1::generator().to_affine();
        let g2 = G2::generator().to_affine();
        VerificationKey {
            domain_size,
            num_public_inputs: 1,
            selector_commitments: vec![g1; 5],
            permutation_commitments: vec![g1; 3],
            x_g2: g2,
            g2_generator: g2,
        }
    }

    /// Prover over a domain of `domain_size` points, with an SRS four times that size.
    fn prover_with_domain(domain_size: usize) -> PlonkProver<Bn254> {
        PlonkProver::new(mock_srs(domain_size * 4), mock_vk(domain_size))
    }

    #[test]
    fn simple_circuit_multiplication_is_satisfied() {
        let circuit = SimpleCircuit::<Bn254>::multiplication(Fr::from(3u64), Fr::from(7u64));
        assert!(circuit.is_satisfied());
        assert_eq!(circuit.c_values[0], Fr::from(21u64));
    }

    #[test]
    fn simple_circuit_multiple_gates() {
        let inputs = [(2u64, 3u64), (4, 5), (6, 7)];
        let pairs = inputs
            .iter()
            .map(|&(a, b)| (Fr::from(a), Fr::from(b)))
            .collect();
        let circuit = SimpleCircuit::<Bn254>::from_multiplications(pairs);
        assert!(circuit.is_satisfied());
        let expected: Vec<Fr> = [6u64, 20, 42].iter().map(|&v| Fr::from(v)).collect();
        assert_eq!(&circuit.c_values[..3], &expected[..]);
    }

    #[test]
    fn prover_generates_proof_for_simple_circuit() {
        let prover = prover_with_domain(8);
        let circuit = SimpleCircuit::<Bn254>::multiplication(Fr::from(3u64), Fr::from(7u64));
        let proof = prover
            .prove(&circuit)
            .expect("Prover should generate a proof");
        assert_eq!(proof.wire_commitments.len(), 3);
        assert_eq!(proof.t_commitments.len(), 3);
    }

    #[test]
    fn prover_rejects_unsatisfied_circuit() {
        let prover = prover_with_domain(8);
        // 3 * 7 != 100, so the only gate is violated.
        let circuit = SimpleCircuit::<Bn254> {
            a_values: vec![Fr::from(3u64)],
            b_values: vec![Fr::from(7u64)],
            c_values: vec![Fr::from(100u64)],
        };
        assert!(
            prover.prove(&circuit).is_err(),
            "Prover should reject unsatisfied circuit"
        );
    }

    #[test]
    fn proof_has_valid_structure() {
        let prover = prover_with_domain(16);
        let circuit = SimpleCircuit::<Bn254>::from_multiplications(vec![
            (Fr::from(2u64), Fr::from(5u64)),
            (Fr::from(3u64), Fr::from(4u64)),
        ]);
        let proof = prover.prove(&circuit).unwrap();
        assert!(!proof.wire_commitments.is_empty());
    }
}