use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use halo2_proofs::{
pasta::EqAffine,
plonk::{self, create_proof, keygen_pk, keygen_vk, verify_proof, SingleVerifier},
poly::commitment::Params,
transcript::{Blake2bRead, Blake2bWrite, Challenge255},
};
use pasta_curves::{pallas, vesta};
use rand::rngs::OsRng;
use super::circuit::{Circuit, Instance, K};
// NOTE(review): presumably must equal the length of `Instance::to_halo2_instance()`; confirm.
/// Number of public inputs (field elements, 32 bytes each) that `verify_vote_proof_raw`
/// expects to decode from its byte buffer.
const NUM_PUBLIC_INPUTS: usize = 9;
/// Process-wide cache of the vote-proof `(params, proving key, verifying key)` triple.
///
/// The triple is built only from `K` and `Circuit::default()`, so it is generated once on
/// first use and shared by every later prove/verify call (`std` builds only, since it relies
/// on `std::sync::OnceLock`).
#[cfg(feature = "std")]
static VOTE_PROOF_PK_CACHE: std::sync::OnceLock<(
    Params<EqAffine>,
    plonk::ProvingKey<EqAffine>,
    plonk::VerifyingKey<EqAffine>,
)> = std::sync::OnceLock::new();

/// Returns the cached vote-proof parameters and keys, generating them on the first call.
///
/// Generation reuses [`vote_proof_params`] and [`vote_proof_proving_key`] so the cached keys
/// are produced exactly the same way as the uncached (`no_std`) path.
///
/// # Panics
///
/// Panics if Halo2 key generation fails.
#[cfg(feature = "std")]
fn get_vote_proof_keys() -> &'static (
    Params<EqAffine>,
    plonk::ProvingKey<EqAffine>,
    plonk::VerifyingKey<EqAffine>,
) {
    VOTE_PROOF_PK_CACHE.get_or_init(|| {
        let params = vote_proof_params();
        let (pk, vk) = vote_proof_proving_key(&params);
        (params, pk, vk)
    })
}
/// Builds fresh IPA commitment parameters for the vote circuit, sized by the circuit's `K`.
///
/// This regenerates the parameters on every call; callers that need them repeatedly should
/// keep the returned value around.
pub fn vote_proof_params() -> Params<EqAffine> {
    let degree = K;
    Params::<EqAffine>::new(degree)
}
/// Runs Halo2 key generation for the vote circuit under the given `params`.
///
/// Keys are derived from `Circuit::default()`. Returns `(proving_key, verifying_key)`.
///
/// # Panics
///
/// Panics if `keygen_vk` or `keygen_pk` returns an error.
pub fn vote_proof_proving_key(
    params: &Params<EqAffine>,
) -> (
    plonk::ProvingKey<EqAffine>,
    plonk::VerifyingKey<EqAffine>,
) {
    let shape = Circuit::default();
    let verifying_key =
        keygen_vk(params, &shape).expect("vote_proof keygen_vk should not fail");
    // `keygen_pk` consumes a verifying key, but the caller also gets one back, hence the clone.
    let proving_key = keygen_pk(params, verifying_key.clone(), &shape)
        .expect("vote_proof keygen_pk should not fail");
    (proving_key, verifying_key)
}
/// Creates a Halo2 proof for `circuit` with the public `instance`.
///
/// Returns the finalized Blake2b transcript bytes. This is the proof format read back by
/// `verify_vote_proof` and `verify_vote_proof_raw`.
///
/// With the `std` feature, the parameters and proving key come from a lazily initialised
/// process-wide cache. Without it, they are regenerated on every call, which repeats full
/// key generation.
///
/// # Panics
///
/// Panics if key generation fails or if `create_proof` returns an error.
pub fn create_vote_proof(circuit: Circuit, instance: &Instance) -> Vec<u8> {
    #[cfg(feature = "std")]
    let (params, pk, _vk) = get_vote_proof_keys();
    // Without `std` there is no cache: build the keys locally, then borrow them so both
    // configurations bind `params`/`pk` as references (which is what `create_proof` takes).
    #[cfg(not(feature = "std"))]
    let owned_keys = {
        let params = vote_proof_params();
        let (pk, vk) = vote_proof_proving_key(&params);
        (params, pk, vk)
    };
    #[cfg(not(feature = "std"))]
    let (params, pk, _vk) = &owned_keys;

    let public_inputs = instance.to_halo2_instance();
    // `Vec::new()` rather than `vec![]`: this file imports `alloc::vec::Vec` but not the
    // `vec!` macro, which is not in scope by default in `no_std` builds.
    let mut transcript = Blake2bWrite::<_, EqAffine, Challenge255<_>>::init(Vec::new());
    create_proof(
        params,
        pk,
        &[circuit],
        &[&[&public_inputs]],
        OsRng,
        &mut transcript,
    )
    .expect("vote proof generation should not fail");
    transcript.finalize()
}
/// Verifies a serialized vote `proof` against the public `instance`.
///
/// With the `std` feature, the parameters and verifying key come from a process-wide cache.
/// Without it, they are regenerated on every call.
///
/// # Errors
///
/// Returns `Err` containing the Debug-formatted Halo2 error if verification fails.
///
/// # Panics
///
/// Panics if key generation fails.
pub fn verify_vote_proof(
    proof: &[u8],
    instance: &Instance,
) -> Result<(), String> {
    #[cfg(feature = "std")]
    let (params, _pk, vk) = get_vote_proof_keys();
    // Build the keys locally and borrow them so `vk` is a `&VerifyingKey` on both paths, as
    // `verify_proof` requires.
    #[cfg(not(feature = "std"))]
    let owned_keys = {
        let params = vote_proof_params();
        let (pk, vk) = vote_proof_proving_key(&params);
        (params, pk, vk)
    };
    #[cfg(not(feature = "std"))]
    let (params, _pk, vk) = &owned_keys;

    let public_inputs = instance.to_halo2_instance();
    let strategy = SingleVerifier::new(params);
    let mut transcript = Blake2bRead::<_, EqAffine, Challenge255<_>>::init(proof);
    verify_proof(params, vk, strategy, &[&[&public_inputs]], &mut transcript)
        .map_err(|e| format!("vote proof verification failed: {:?}", e))
}
/// Verifies a serialized vote `proof` against public inputs supplied as raw bytes.
///
/// `public_inputs_bytes` must be exactly `NUM_PUBLIC_INPUTS` consecutive 32-byte field-element
/// representations. Each one is decoded with `PrimeField::from_repr` for `pallas::Base`.
///
/// # Errors
///
/// Returns `Err` when:
/// - the byte length is not `NUM_PUBLIC_INPUTS * 32`;
/// - any 32-byte chunk is rejected by `from_repr` (not a canonical encoding);
/// - Halo2 verification fails.
///
/// # Panics
///
/// Panics if key generation fails.
pub fn verify_vote_proof_raw(
    proof: &[u8],
    public_inputs_bytes: &[u8],
) -> Result<(), String> {
    use pasta_curves::group::ff::PrimeField;

    let expected_len = NUM_PUBLIC_INPUTS * 32;
    if public_inputs_bytes.len() != expected_len {
        return Err(format!(
            "expected {} bytes ({} × 32) for public inputs, got {}",
            expected_len,
            NUM_PUBLIC_INPUTS,
            public_inputs_bytes.len()
        ));
    }

    // The length check above guarantees `chunks_exact` yields exactly NUM_PUBLIC_INPUTS chunks
    // with no remainder. Decoding stops at the first bad chunk and reports its index.
    let public_inputs = public_inputs_bytes
        .chunks_exact(32)
        .enumerate()
        .map(|(i, chunk)| {
            let mut repr = [0u8; 32];
            repr.copy_from_slice(chunk);
            let fp_opt: Option<pallas::Base> = pallas::Base::from_repr(repr).into();
            fp_opt.ok_or_else(|| {
                format!(
                    "public input {} is not a canonical Pallas Fp encoding",
                    i
                )
            })
        })
        .collect::<Result<Vec<vesta::Scalar>, String>>()?;

    #[cfg(feature = "std")]
    let (params, _pk, vk) = get_vote_proof_keys();
    // Build the keys locally and borrow them so `vk` is a `&VerifyingKey` on both paths, as
    // `verify_proof` requires.
    #[cfg(not(feature = "std"))]
    let owned_keys = {
        let params = vote_proof_params();
        let (pk, vk) = vote_proof_proving_key(&params);
        (params, pk, vk)
    };
    #[cfg(not(feature = "std"))]
    let (params, _pk, vk) = &owned_keys;

    let strategy = SingleVerifier::new(params);
    let mut transcript = Blake2bRead::<_, EqAffine, Challenge255<_>>::init(proof);
    verify_proof(
        params,
        vk,
        strategy,
        &[&[&public_inputs]],
        &mut transcript,
    )
    .map_err(|e| format!("vote proof verification failed: {:?}", e))
}