use std::collections::{BTreeMap, btree_map::Entry};
use crate::{
AggregationError, BaseFieldElement, LotteryIndex, MembershipDigest, SignerIndex,
SingleSignature, StmResult,
circuits::MerklePath as Halo2MerklePath,
circuits::halo2::witness::CircuitWitnessEntry,
circuits::{CircuitInstance, CircuitMerkleTreeLeaf, CircuitWitness},
membership_commitment::{MerkleTree, MerkleTreeSnarkLeaf},
proof_system::{
AggregateVerificationKeyForSnark, SnarkClerk,
halo2_snark::{build_snark_message, compute_winning_lottery_indices},
},
};
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct SnarkProverInput {
instance: CircuitInstance,
witness: CircuitWitness,
}
#[allow(dead_code)]
impl SnarkProverInput {
pub fn prepare_prover_input<D: MembershipDigest>(
clerk: &SnarkClerk,
signatures: &[SingleSignature],
message: &[u8],
) -> StmResult<SnarkProverInput> {
let avk: AggregateVerificationKeyForSnark<D> =
clerk.compute_aggregate_verification_key_for_snark();
let message_to_sign = build_snark_message(&avk.get_merkle_tree_commitment().root, message)?;
let valid_signatures_with_indices =
Self::collect_valid_signatures_with_indices(clerk, signatures, &message_to_sign);
let unique_index_signature_map = SnarkClerk::select_valid_signatures_for_k_indices(
&clerk.parameters,
&valid_signatures_with_indices,
&message_to_sign,
)?;
let witness = Self::create_witness::<D>(unique_index_signature_map, clerk)?;
let instance = (message_to_sign[0].into(), message_to_sign[1].into());
Ok(SnarkProverInput { witness, instance })
}
fn collect_valid_signatures_with_indices(
clerk: &SnarkClerk,
signatures: &[SingleSignature],
message_to_sign: &[BaseFieldElement; 2],
) -> Vec<SingleSignature> {
signatures
.iter()
.filter_map(|sig| {
let snark_sig = sig.snark_signature.as_ref()?;
let reg_entry = match clerk.get_snark_registration_entry(sig.signer_index) {
Ok(Some(entry)) => entry,
Ok(None) => return None,
Err(_) => return None,
};
snark_sig
.verify_with_prepared_message(®_entry.0, message_to_sign)
.ok()?;
let indices = compute_winning_lottery_indices(
clerk.parameters.m,
message_to_sign,
&snark_sig.get_schnorr_signature(),
reg_entry.1,
)
.ok()?;
let mut new_sig = sig.clone();
new_sig.set_snark_signature_indices(&indices);
Some(new_sig)
})
.collect()
}
fn create_witness<D: MembershipDigest>(
unique_index_signature_map: BTreeMap<LotteryIndex, SingleSignature>,
clerk: &SnarkClerk,
) -> StmResult<CircuitWitness> {
let merkle_tree: MerkleTree<D::SnarkHash, MerkleTreeSnarkLeaf> =
clerk.closed_key_registration.to_merkle_tree();
let mut unique_signers: BTreeMap<SignerIndex, (MerkleTreeSnarkLeaf, Halo2MerklePath)> =
BTreeMap::new();
for sig in unique_index_signature_map.values() {
if let Entry::Vacant(vacant) = unique_signers.entry(sig.signer_index) {
let leaf = match clerk.get_snark_registration_entry(sig.signer_index) {
Ok(Some(entry)) => entry,
Ok(None) => continue,
Err(_) => continue,
};
let merkle_path = merkle_tree.compute_merkle_tree_path(sig.signer_index as usize);
let merkle_path_circuit: Halo2MerklePath = Halo2MerklePath::try_from(&merkle_path)?;
vacant.insert((leaf, merkle_path_circuit));
}
}
unique_index_signature_map
.into_iter()
.map(|(lottery_index, sig)| {
let (leaf, merkle_path) = unique_signers
.get(&sig.signer_index)
.ok_or(AggregationError::MissingSnarkSignerData(sig.signer_index))?;
let schnorr_sig = sig
.snark_signature
.as_ref()
.ok_or(AggregationError::MissingSnarkSignature(lottery_index))?
.get_schnorr_signature();
let circuit_leaf = CircuitMerkleTreeLeaf(leaf.0, leaf.1.into());
Ok(CircuitWitnessEntry {
leaf: circuit_leaf,
merkle_path: merkle_path.clone(),
unique_schnorr_signature: schnorr_sig,
lottery_index,
})
})
.collect()
}
pub fn get_instance(&self) -> CircuitInstance {
self.instance
}
pub fn into_witness(self) -> CircuitWitness {
self.witness
}
}
#[cfg(test)]
mod tests {
use rand_chacha::ChaCha20Rng;
use rand_core::SeedableRng;
use crate::{
AggregationError, Initializer, KeyRegistration, MithrilMembershipDigest, Parameters,
Signer, SingleSignature,
proof_system::{SnarkClerk, halo2_snark::build_snark_message},
};
use super::SnarkProverInput;
type D = MithrilMembershipDigest;
fn setup_signers_and_clerk(
params: Parameters,
nparties: usize,
rng: &mut ChaCha20Rng,
) -> (Vec<Signer<D>>, SnarkClerk) {
let mut key_reg = KeyRegistration::initialize();
let mut initializers = Vec::with_capacity(nparties);
for _ in 0..nparties {
let init = Initializer::new(params, 1, rng);
key_reg.register_by_entry(&init.clone().try_into().unwrap()).unwrap();
initializers.push(init);
}
let closed_reg = key_reg.close_registration(¶ms).unwrap();
let signers: Vec<Signer<D>> = initializers
.into_iter()
.map(|init| init.try_create_signer::<D>(&closed_reg).unwrap())
.collect();
let clerk = SnarkClerk::new_clerk_from_signer(&signers[0]);
(signers, clerk)
}
fn collect_signatures(signers: &[Signer<D>], message: &[u8]) -> Vec<SingleSignature> {
signers
.iter()
.filter_map(|signer| signer.create_single_signature(message).ok())
.collect()
}
#[test]
fn succeeds_with_enough_valid_signatures() {
let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
let params = Parameters {
m: 200,
k: 5,
phi_f: 0.8,
};
let nparties = 10;
let message = [1u8; 32];
let (signers, clerk) = setup_signers_and_clerk(params, nparties, &mut rng);
let signatures = collect_signatures(&signers, &message);
let result = SnarkProverInput::prepare_prover_input::<D>(&clerk, &signatures, &message);
assert!(result.is_ok(), "Expected success, got: {result:?}");
let prover_input = result.unwrap();
assert_eq!(
prover_input.into_witness().len() as u64,
params.k,
"Witness should contain exactly k entries"
);
}
#[test]
fn fails_when_not_enough_signatures() {
let mut rng = ChaCha20Rng::from_seed([1u8; 32]);
let params = Parameters {
m: 5,
k: 100,
phi_f: 0.2,
};
let nparties = 2;
let message = [2u8; 32];
let (signers, clerk) = setup_signers_and_clerk(params, nparties, &mut rng);
let signatures = collect_signatures(&signers, &message);
let result = SnarkProverInput::prepare_prover_input::<D>(&clerk, &signatures, &message);
assert!(
result.is_err(),
"Expected failure due to insufficient signatures"
);
let err = result.unwrap_err();
assert!(
err.downcast_ref::<AggregationError>().is_some(),
"Expected AggregationError, got: {err:?}"
);
}
#[test]
fn skips_signatures_without_snark_component() {
let mut rng = ChaCha20Rng::from_seed([2u8; 32]);
let params = Parameters {
m: 200,
k: 5,
phi_f: 0.8,
};
let nparties = 10;
let message = [3u8; 32];
let (signers, clerk) = setup_signers_and_clerk(params, nparties, &mut rng);
let mut signatures = collect_signatures(&signers, &message);
for sig in &mut signatures {
sig.snark_signature = None;
}
let result = SnarkProverInput::prepare_prover_input::<D>(&clerk, &signatures, &message);
assert!(
result.is_err(),
"Expected failure when all SNARK signatures are stripped"
);
}
#[test]
fn empty_signatures_fails() {
let mut rng = ChaCha20Rng::from_seed([3u8; 32]);
let params = Parameters {
m: 200,
k: 5,
phi_f: 0.8,
};
let nparties = 5;
let (_, clerk) = setup_signers_and_clerk(params, nparties, &mut rng);
let result = SnarkProverInput::prepare_prover_input::<D>(&clerk, &[], &[4u8; 32]);
assert!(result.is_err(), "Expected failure with empty signatures");
}
#[test]
fn instance_contains_expected_message_components() {
let mut rng = ChaCha20Rng::from_seed([4u8; 32]);
let params = Parameters {
m: 200,
k: 5,
phi_f: 0.8,
};
let nparties = 10;
let message = [5u8; 32];
let (signers, clerk) = setup_signers_and_clerk(params, nparties, &mut rng);
let signatures = collect_signatures(&signers, &message);
let prover_input =
SnarkProverInput::prepare_prover_input::<D>(&clerk, &signatures, &message).unwrap();
let avk = clerk.compute_aggregate_verification_key_for_snark::<D>();
let expected_message =
build_snark_message(&avk.get_merkle_tree_commitment().root, &message).unwrap();
let (instance_root, instance_msg) = prover_input.get_instance();
assert_eq!(
instance_root,
expected_message[0].into(),
"Instance root mismatch"
);
assert_eq!(
instance_msg,
expected_message[1].into(),
"Instance message mismatch"
);
}
}