use std::collections::HashMap;
use std::fs::create_dir_all;
use std::panic::{self, AssertUnwindSafe};
use std::path::PathBuf;
use std::sync::{Arc, LazyLock, RwLock};
use std::time::Instant;
use anyhow::{Context, anyhow};
use midnight_circuits::hash::poseidon::PoseidonState;
use midnight_curves::Bls12;
use midnight_proofs::plonk::Error as PlonkError;
use midnight_proofs::poly::kzg::params::ParamsKZG;
use midnight_zk_stdlib as zk;
use midnight_zk_stdlib::{MidnightCircuit, MidnightPK, MidnightVK};
use rand_chacha::ChaCha20Rng;
use rand_core::SeedableRng;
use crate::circuits::halo2::circuit::StmCircuit;
use crate::circuits::halo2::errors::StmCircuitError;
use crate::circuits::halo2::types::CircuitBase;
use crate::circuits::halo2::witness::{
CircuitMerkleTreeLeaf, CircuitWitnessEntry, LotteryTargetValue as CircuitLotteryTargetValue,
MerklePath, MerkleRoot, SignedMessageWithoutPrefix,
};
use crate::circuits::test_utils::setup::{generate_params, load_params};
use crate::hash::poseidon::MidnightPoseidonDigest;
use crate::membership_commitment::{
MerkleTree as StmMerkleTree, MerkleTreeSnarkLeaf as StmMerkleTreeSnarkLeaf,
};
use crate::signature_scheme::{
BaseFieldElement, SchnorrSigningKey, SchnorrVerificationKey, UniqueSchnorrSignature,
};
use crate::{LotteryIndex, LotteryTargetValue, Parameters, StmError, StmResult};
const DEFAULT_NUM_SIGNERS: usize = 3000;
pub(crate) const LOTTERIES_PER_K: u32 = 10;
const DEFAULT_TEST_MSG: u64 = 42;
type CircuitVerificationAndProvingKeyPair = (MidnightVK, MidnightPK<StmCircuit>);
type CircuitKeysCache = HashMap<StmCircuitConfig, Arc<CircuitVerificationAndProvingKeyPair>>;
fn checked_len_u32(actual: usize) -> u32 {
u32::try_from(actual).unwrap_or(u32::MAX)
}
pub(crate) fn assert_proof_rejected_by_verifier(result: StmResult<()>) {
match result {
Err(error) => {
let has_verification_rejected = error.chain().any(|err| {
matches!(
err.downcast_ref::<StmCircuitError>(),
Some(StmCircuitError::VerificationRejected)
)
});
assert!(
has_verification_rejected,
"expected verification failure, got: {error}"
);
}
other => panic!("expected verification failure, got: {other:?}"),
}
}
pub(crate) fn assert_proving_circuit_error<T>(result: StmResult<T>) -> StmCircuitError {
match result {
Err(error) => error
.chain()
.find_map(|err| err.downcast_ref::<StmCircuitError>())
.cloned()
.unwrap_or_else(|| panic!("expected StmCircuitError in source chain, got: {error}")),
_ => panic!("expected circuit proving failure"),
}
}
pub(crate) fn assert_proving_backend_message_contains<T>(result: StmResult<T>, expected: &str) {
match result {
Err(error) => {
let message = error.chain().find_map(|err| match err.downcast_ref::<PlonkError>() {
Some(PlonkError::Synthesis(message)) => Some(message),
_ => None,
});
let message = message.unwrap_or_else(|| {
panic!("expected PlonkError::Synthesis in source chain, got: {error}")
});
assert!(
message.contains(expected),
"expected backend message to contain '{expected}', got '{message}'"
);
}
_ => panic!("expected circuit proving failure"),
}
}
fn validate_relation_for_setup(relation: &StmCircuit) -> StmResult<()> {
relation
.validate_parameters()
.context("Circuit parameter validation failed before setup")
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct StmCircuitConfig {
circuit_degree: u32,
k: u32,
m: u32,
merkle_tree_depth: u32,
}
pub(crate) struct StmCircuitEnv {
srs: ParamsKZG<Bls12>,
relation: StmCircuit,
vk: MidnightVK,
pk: MidnightPK<StmCircuit>,
num_signers: usize,
m: u32,
}
pub(crate) struct StmCircuitScenario {
merkle_tree_commitment: MerkleRoot,
message: SignedMessageWithoutPrefix,
witness: Vec<CircuitWitnessEntry>,
}
#[derive(Clone)]
pub(crate) struct SignerFixture {
sk: SchnorrSigningKey,
vk: SchnorrVerificationKey,
circuit_lottery_target_value: CircuitLotteryTargetValue,
stm_lottery_target_value: LotteryTargetValue,
}
impl From<&SignerFixture> for StmMerkleTreeSnarkLeaf {
fn from(value: &SignerFixture) -> Self {
StmMerkleTreeSnarkLeaf(value.vk, value.stm_lottery_target_value)
}
}
impl From<&SignerFixture> for CircuitMerkleTreeLeaf {
fn from(value: &SignerFixture) -> Self {
CircuitMerkleTreeLeaf(value.vk, value.circuit_lottery_target_value)
}
}
fn stm_lottery_target_value_from_circuit(
lottery_target_value: CircuitLotteryTargetValue,
) -> LotteryTargetValue {
lottery_target_value.into()
}
fn generate_signer_fixture(
rng: &mut ChaCha20Rng,
lottery_target_value: CircuitLotteryTargetValue,
) -> StmResult<SignerFixture> {
let stm_sk = SchnorrSigningKey::generate(rng);
let stm_vk = SchnorrVerificationKey::new_from_signing_key(stm_sk.clone());
let stm_lottery_target_value = stm_lottery_target_value_from_circuit(lottery_target_value);
Ok(SignerFixture {
sk: stm_sk,
vk: stm_vk,
circuit_lottery_target_value: lottery_target_value,
stm_lottery_target_value,
})
}
pub(crate) struct StmMerkleTreeWrapper {
stm_tree: StmMerkleTree<MidnightPoseidonDigest, StmMerkleTreeSnarkLeaf>,
merkle_tree_commitment: MerkleRoot,
signer_fixtures: Vec<SignerFixture>,
}
pub(crate) enum LeafSelector {
LeftMost,
RightMost,
Index(usize),
}
impl StmMerkleTreeWrapper {
pub(crate) fn merkle_tree_commitment(&self) -> MerkleRoot {
self.merkle_tree_commitment
}
pub(crate) fn get_path(&self, i: usize) -> StmResult<MerklePath> {
let stm_path = self.stm_tree.compute_merkle_tree_path(i);
let stm_leaf: StmMerkleTreeSnarkLeaf = self.signer_fixture(i)?.into();
self.stm_tree
.to_merkle_tree_commitment()
.verify_leaf_membership_from_path(&stm_leaf, &stm_path)
.map_err(|_| anyhow!(StmCircuitError::MerklePathVerificationFailed))?;
(&stm_path).try_into().map_err(Into::into)
}
fn signer_fixture(&self, i: usize) -> StmResult<&SignerFixture> {
self.signer_fixtures.get(i).ok_or_else(|| {
anyhow!(StmCircuitError::InvalidSignerFixtureIndex {
index: checked_len_u32(i),
num_signers: checked_len_u32(self.signer_fixtures.len()),
})
})
}
}
fn decode_merkle_root(root_bytes: &[u8]) -> StmResult<MerkleRoot> {
let actual = root_bytes.len();
let root_array: [u8; 32] = root_bytes.try_into().map_err(|_| {
anyhow!(StmCircuitError::InvalidMerkleRootDigestLength {
actual: checked_len_u32(actual),
})
})?;
BaseFieldElement::from_bytes(&root_array)
.ok()
.map(Into::into)
.ok_or_else(|| anyhow!(StmCircuitError::NonCanonicalMerkleRootDigest))
}
fn build_merkle_tree_wrapper(
n: usize,
selected_index: Option<usize>,
lottery_target_value: CircuitLotteryTargetValue,
) -> StmResult<StmMerkleTreeWrapper> {
if let Some(i) = selected_index
&& i >= n
{
return Err(anyhow!(StmCircuitError::InvalidSelectedLeafIndex {
index: checked_len_u32(i),
num_leaves: checked_len_u32(n),
}));
}
let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
let mut signer_fixtures = Vec::with_capacity(n);
for i in 0..n {
let leaf_lottery_target_value = if selected_index == Some(i) {
lottery_target_value
} else {
-CircuitLotteryTargetValue::ONE
};
signer_fixtures.push(generate_signer_fixture(
&mut rng,
leaf_lottery_target_value,
)?);
}
let stm_leaves: Vec<StmMerkleTreeSnarkLeaf> = signer_fixtures.iter().map(Into::into).collect();
let stm_tree =
StmMerkleTree::<MidnightPoseidonDigest, StmMerkleTreeSnarkLeaf>::new(&stm_leaves);
let root_bytes = stm_tree.to_merkle_tree_commitment().root;
let merkle_tree_commitment = decode_merkle_root(root_bytes.as_slice())?;
Ok(StmMerkleTreeWrapper {
stm_tree,
merkle_tree_commitment,
signer_fixtures,
})
}
pub(crate) fn create_default_merkle_tree(n: usize) -> StmResult<StmMerkleTreeWrapper> {
build_merkle_tree_wrapper(n, None, -CircuitLotteryTargetValue::ONE)
}
pub(crate) fn create_merkle_tree_with_leaf_selector(
depth: u32,
selector: LeafSelector,
lottery_target_value: CircuitLotteryTargetValue,
) -> StmResult<(StmMerkleTreeWrapper, usize)> {
if depth >= usize::BITS {
return Err(anyhow!(StmCircuitError::InvalidMerkleTreeDepth { depth }));
}
let n = 1usize << depth;
let selected_index = match selector {
LeafSelector::LeftMost => 0usize,
LeafSelector::RightMost => n - 1,
LeafSelector::Index(i) => {
if i >= n {
return Err(anyhow!(StmCircuitError::InvalidSelectedLeafIndex {
index: checked_len_u32(i),
num_leaves: checked_len_u32(n),
}));
}
i
}
};
let tree = build_merkle_tree_wrapper(n, Some(selected_index), lottery_target_value)?;
Ok((tree, selected_index))
}
fn transcript_message(
merkle_tree_commitment: MerkleRoot,
message: SignedMessageWithoutPrefix,
) -> [BaseFieldElement; 2] {
[merkle_tree_commitment.into(), message.into()]
}
fn assert_challenge_endianness(sig: &UniqueSchnorrSignature) -> StmResult<()> {
let challenge_bytes = sig.challenge.to_bytes();
let challenge_base = CircuitBase::from_bytes_le(&challenge_bytes)
.into_option()
.ok_or_else(|| anyhow!(StmCircuitError::InvalidChallengeBytes))?;
if SignedMessageWithoutPrefix::from(challenge_base)
!= SignedMessageWithoutPrefix::from(sig.challenge)
{
return Err(anyhow!(StmCircuitError::ChallengeEndiannessMismatch));
}
Ok(())
}
fn sign_and_verify_lottery_message(
signer_fixture: &SignerFixture,
merkle_tree_commitment: MerkleRoot,
message: SignedMessageWithoutPrefix,
rng: &mut ChaCha20Rng,
) -> StmResult<UniqueSchnorrSignature> {
let transcript = transcript_message(merkle_tree_commitment, message);
let stm_sig = signer_fixture
.sk
.sign(&transcript, rng)
.map_err(|_| anyhow!(StmCircuitError::SignatureGenerationFailed))?;
assert_challenge_endianness(&stm_sig)?;
stm_sig
.verify(&transcript, &signer_fixture.vk)
.map_err(|_| anyhow!(StmCircuitError::SignatureVerificationFailed))?;
Ok(stm_sig)
}
pub(crate) fn build_witness(
merkle_tree: &StmMerkleTreeWrapper,
merkle_tree_commitment: MerkleRoot,
message: SignedMessageWithoutPrefix,
k: u32,
) -> StmResult<Vec<CircuitWitnessEntry>> {
let indices: Vec<LotteryIndex> = (0..k).map(u64::from).collect();
build_witness_with_indices(merkle_tree, merkle_tree_commitment, message, &indices)
}
pub(crate) fn build_witness_with_indices(
merkle_tree: &StmMerkleTreeWrapper,
merkle_tree_commitment: MerkleRoot,
message: SignedMessageWithoutPrefix,
indices: &[LotteryIndex],
) -> StmResult<Vec<CircuitWitnessEntry>> {
build_witness_internal(
merkle_tree,
merkle_tree_commitment,
message,
WitnessBuildMode::Indices(indices),
)
}
pub(crate) fn build_witness_with_fixed_signer(
merkle_tree: &StmMerkleTreeWrapper,
signer_index: usize,
merkle_tree_commitment: MerkleRoot,
message: SignedMessageWithoutPrefix,
indices: &[LotteryIndex],
) -> StmResult<Vec<CircuitWitnessEntry>> {
build_witness_internal(
merkle_tree,
merkle_tree_commitment,
message,
WitnessBuildMode::FixedSigner {
signer_index,
indices,
},
)
}
enum WitnessBuildMode<'a> {
Indices(&'a [LotteryIndex]),
FixedSigner {
signer_index: usize,
indices: &'a [LotteryIndex],
},
}
fn build_witness_internal(
merkle_tree: &StmMerkleTreeWrapper,
merkle_tree_commitment: MerkleRoot,
message: SignedMessageWithoutPrefix,
mode: WitnessBuildMode<'_>,
) -> StmResult<Vec<CircuitWitnessEntry>> {
let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
let mut witness = Vec::new();
match mode {
WitnessBuildMode::Indices(indices) => {
if indices.is_empty() {
return Err(anyhow!(StmCircuitError::EmptyIndices));
}
let num_signers = merkle_tree.signer_fixtures.len();
if num_signers == 0 {
return Err(anyhow!(StmCircuitError::EmptySignerLeaves));
}
for (i, index) in indices.iter().enumerate() {
let signer_index = i % num_signers;
let signer_fixture = merkle_tree.signer_fixture(signer_index)?;
let stm_sig = sign_and_verify_lottery_message(
signer_fixture,
merkle_tree_commitment,
message,
&mut rng,
)?;
let merkle_path = merkle_tree.get_path(signer_index)?;
witness.push(CircuitWitnessEntry {
leaf: signer_fixture.into(),
merkle_path,
unique_schnorr_signature: stm_sig,
lottery_index: *index,
});
}
}
WitnessBuildMode::FixedSigner {
signer_index,
indices,
} => {
if indices.is_empty() {
return Err(anyhow!(StmCircuitError::EmptyIndices));
}
let signer_fixture = merkle_tree.signer_fixture(signer_index)?;
let merkle_path = merkle_tree.get_path(signer_index)?;
let stm_sig = sign_and_verify_lottery_message(
signer_fixture,
merkle_tree_commitment,
message,
&mut rng,
)?;
for index in indices {
witness.push(CircuitWitnessEntry {
leaf: signer_fixture.into(),
merkle_path: merkle_path.clone(),
unique_schnorr_signature: stm_sig,
lottery_index: *index,
});
}
}
}
Ok(witness)
}
pub(crate) fn find_two_distinct_witness_entries(
witness: &[CircuitWitnessEntry],
) -> StmResult<(usize, usize)> {
if witness.len() < 2 {
return Err(anyhow!(StmCircuitError::WitnessTooShort {
actual: checked_len_u32(witness.len()),
}));
}
let leaf0 = witness[0].leaf;
let leaf1 = witness[1].leaf;
if leaf0 != leaf1 {
return Ok((0, 1));
}
Err(anyhow!(StmCircuitError::NoDistinctWitnessEntries))
}
impl StmCircuitEnv {
pub(crate) fn num_signers(&self) -> usize {
self.num_signers
}
pub(crate) fn m(&self) -> u32 {
self.m
}
}
impl StmCircuitScenario {
pub(crate) fn new(
merkle_tree_commitment: MerkleRoot,
message: SignedMessageWithoutPrefix,
witness: Vec<CircuitWitnessEntry>,
) -> Self {
Self {
merkle_tree_commitment,
message,
witness,
}
}
}
pub(crate) fn setup_stm_circuit_env(
case_name: &str,
circuit_degree: u32,
k: u32,
m: u32,
) -> StmResult<StmCircuitEnv> {
let srs = load_or_generate_params(circuit_degree)?;
let num_signers: usize = DEFAULT_NUM_SIGNERS;
let depth = num_signers.next_power_of_two().trailing_zeros();
let stm_params = Parameters {
k: k as u64,
m: m as u64,
phi_f: 0.2,
};
let relation = StmCircuit::try_new(&stm_params, depth).unwrap();
validate_relation_for_setup(&relation)?;
{
let stm_circuit = MidnightCircuit::from_relation(&relation);
println!("\n=== STM circuit case: {case_name} ===");
println!("circuit degree (selected) {circuit_degree}");
println!("k {k}");
println!("min_k {:?}", stm_circuit.min_k());
println!("{:?}", zk::cost_model(&relation));
}
let config = StmCircuitConfig {
circuit_degree,
k,
m,
merkle_tree_depth: depth,
};
let key_pair = get_or_build_circuit_keys(config, &relation, &srs)?;
let (vk, pk) = (&key_pair.0, &key_pair.1);
Ok(StmCircuitEnv {
srs,
relation,
vk: vk.clone(),
pk: pk.clone(),
num_signers,
m,
})
}
pub(crate) fn prove_and_verify_result(
env: &StmCircuitEnv,
scenario: StmCircuitScenario,
) -> StmResult<()> {
let instance = (scenario.merkle_tree_commitment, scenario.message);
let start = Instant::now();
let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
let proof = zk::prove::<StmCircuit, PoseidonState<CircuitBase>>(
&env.srs,
&env.pk,
&env.relation,
&instance,
scenario.witness,
&mut rng,
)
.map_err(map_proving_backend_error)?;
let duration = start.elapsed();
println!("\nProof generation took: {:?}", duration);
println!("Proof size: {:?}", proof.len());
let start = Instant::now();
let verify_result = zk::verify::<StmCircuit, PoseidonState<CircuitBase>>(
&env.srs.verifier_params(),
&env.vk,
&instance,
None,
&proof,
);
let duration = start.elapsed();
println!("Proof verification took: {:?}", duration);
if verify_result.is_ok() {
Ok(())
} else {
Err(anyhow!(StmCircuitError::VerificationRejected))
.context(format!("Proof verification step failed: {verify_result:?}"))
}
}
fn map_proving_backend_error(error: PlonkError) -> StmError {
anyhow::Error::new(error).context("Proving step failed")
}
pub(crate) fn run_stm_circuit_case_default(
case_name: &str,
circuit_degree: u32,
k: u32,
) -> StmResult<()> {
run_stm_circuit_case(
case_name,
circuit_degree,
k,
SignedMessageWithoutPrefix::from(DEFAULT_TEST_MSG),
)
}
pub(crate) fn run_stm_circuit_case(
case_name: &str,
circuit_degree: u32,
k: u32,
message: SignedMessageWithoutPrefix,
) -> StmResult<()> {
let m = k * LOTTERIES_PER_K;
let env = setup_stm_circuit_env(case_name, circuit_degree, k, m)?;
let merkle_tree = create_default_merkle_tree(env.num_signers())?;
let merkle_tree_commitment = merkle_tree.merkle_tree_commitment();
let witness = build_witness(&merkle_tree, merkle_tree_commitment, message, k)?;
let scenario = StmCircuitScenario::new(merkle_tree_commitment, message, witness);
prove_and_verify_result(&env, scenario)
}
fn load_or_generate_params(circuit_degree: u32) -> StmResult<ParamsKZG<Bls12>> {
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let assets_dir = manifest_dir.join("src").join("circuits").join("halo2").join("assets");
let path = assets_dir.join(format!("params_kzg_unsafe_{}", circuit_degree));
if path.exists() {
return Ok(load_params(path.to_string_lossy().as_ref()));
}
create_dir_all(&assets_dir)
.map_err(|_| anyhow!(StmCircuitError::ParamsAssetsDirCreate))
.with_context(|| {
format!(
"Failed to create params assets directory at '{}'",
assets_dir.to_string_lossy()
)
})?;
Ok(generate_params(
circuit_degree,
path.to_string_lossy().as_ref(),
))
}
fn get_or_build_circuit_keys(
config: StmCircuitConfig,
relation: &StmCircuit,
srs: &ParamsKZG<Bls12>,
) -> StmResult<Arc<CircuitVerificationAndProvingKeyPair>> {
static STM_CIRCUIT_KEYS_CACHE: LazyLock<RwLock<CircuitKeysCache>> =
LazyLock::new(|| RwLock::new(HashMap::new()));
if let Some(key_pair) = STM_CIRCUIT_KEYS_CACHE
.read()
.map_err(|_| anyhow!(StmCircuitError::CircuitKeysCacheLockPoisoned { operation: "read" }))?
.get(&config)
.cloned()
{
return Ok(key_pair);
}
let start = Instant::now();
let (vk, pk) = panic::catch_unwind(AssertUnwindSafe(|| {
let vk = zk::setup_vk(srs, relation);
let pk = zk::setup_pk(relation, &vk);
(vk, pk)
}))
.map_err(|panic_payload| {
let details = panic_payload
.downcast_ref::<&str>()
.map(|message| (*message).to_owned())
.or_else(|| panic_payload.downcast_ref::<String>().cloned())
.unwrap_or_else(|| "non-string panic payload".to_string());
anyhow!("Midnight setup panicked before proving: {details}")
})?;
let duration = start.elapsed();
println!("\nvk pk generation took: {:?}", duration);
let key_pair = Arc::new((vk, pk));
STM_CIRCUIT_KEYS_CACHE
.write()
.map_err(|_| anyhow!(StmCircuitError::CircuitKeysCacheLockPoisoned { operation: "write" }))?
.insert(config, key_pair.clone());
Ok(key_pair)
}