use account_compression::processor::initialize_address_merkle_tree::Pubkey;
use light_batched_merkle_tree::{
constants::DEFAULT_BATCH_STATE_TREE_HEIGHT,
merkle_tree::{BatchedMerkleTreeAccount, InstructionDataBatchNullifyInputs},
};
use light_client::{indexer::Indexer, rpc::Rpc};
use light_compressed_account::instruction_data::compressed_proof::CompressedProof;
use light_hasher::{bigint::bigint_to_be_bytes_array, Hasher, Poseidon};
use light_merkle_tree_metadata::QueueType;
use light_prover_client::{
proof_client::ProofClient,
proof_types::batch_update::{get_batch_update_inputs, BatchUpdateCircuitInputs},
};
use tracing::{error, trace};
use crate::{error::ForesterUtilsError, utils::wait_for_indexer};
/// Builds instruction data for every nullify ZKP batch that is ready in the
/// pending input-queue batch of a batched state Merkle tree.
///
/// Steps:
/// 1. Reads the on-chain tree account. Collects the leaves hash chains of the
///    pending batch's ZKP batches in `num_inserted_zkps..current_zkp_batch_index`.
/// 2. Calls `wait_for_indexer`, then fetches the matching `InputStateV2` queue
///    elements (with Merkle proofs) from the indexer in one request.
/// 3. Builds circuit inputs for each ZKP batch sequentially. The new root and the
///    changelogs of one batch feed into the next.
/// 4. Generates the proofs on spawned tokio tasks and collects them in order.
///
/// Returns one [`InstructionDataBatchNullifyInputs`] per ready ZKP batch, in
/// batch order. Returns an empty vector when no ZKP batch is ready.
///
/// # Errors
/// - [`ForesterUtilsError::Indexer`] in any of these cases:
///   - the tree account cannot be fetched, is missing, or cannot be deserialized;
///   - a size does not fit the `u16` indexer request;
///   - the indexer request fails or returns the wrong number of elements;
///   - an element lacks a `tx_hash`;
///   - the indexer root differs from the on-chain root.
/// - [`ForesterUtilsError::Prover`] if circuit inputs or a nullifier cannot be
///   computed, or proof generation / its task fails.
/// - Any error propagated from `wait_for_indexer`.
pub async fn create_nullify_batch_ix_data<R: Rpc, I: Indexer>(
    rpc: &mut R,
    indexer: &mut I,
    merkle_tree_pubkey: Pubkey,
) -> Result<Vec<InstructionDataBatchNullifyInputs>, ForesterUtilsError> {
    trace!("create_multiple_nullify_batch_ix_data");

    // Snapshot everything needed from the on-chain account inside a scope, so the
    // mutably-borrowed account buffer behind the zero-copy view is released early.
    let (
        batch_idx,
        zkp_batch_size,
        num_inserted_zkps,
        num_ready_zkps,
        old_root,
        root_history,
        leaves_hash_chains,
    ) = {
        // NOTE(review): account/RPC failures are reported as `Indexer` errors
        // because that is the closest variant visible here. Switch to a dedicated
        // RPC variant if `ForesterUtilsError` provides one.
        let mut account = rpc
            .get_account(merkle_tree_pubkey)
            .await
            .map_err(|e| {
                error!(
                    "create_nullify_batch_ix_data: failed to fetch merkle tree account: {:?}",
                    e
                );
                ForesterUtilsError::Indexer(format!(
                    "Failed to fetch merkle tree account: {:?}",
                    e
                ))
            })?
            .ok_or_else(|| {
                ForesterUtilsError::Indexer(format!(
                    "Merkle tree account {} not found",
                    merkle_tree_pubkey
                ))
            })?;
        let merkle_tree = BatchedMerkleTreeAccount::state_from_bytes(
            account.data.as_mut_slice(),
            &merkle_tree_pubkey.into(),
        )
        .map_err(|e| {
            ForesterUtilsError::Indexer(format!(
                "Failed to deserialize state merkle tree account: {:?}",
                e
            ))
        })?;
        trace!("queue_batches: {:?}", merkle_tree.queue_batches);
        let batch_idx = merkle_tree.queue_batches.pending_batch_index as usize;
        let zkp_size = merkle_tree.queue_batches.zkp_batch_size;
        // The indexer request size is a u16, so reject sizes that would truncate.
        let zkp_batch_size = u16::try_from(zkp_size).map_err(|_| {
            ForesterUtilsError::Indexer(format!(
                "zkp_batch_size {} does not fit into u16",
                zkp_size
            ))
        })?;
        let batch = &merkle_tree.queue_batches.batches[batch_idx];
        let num_inserted_zkps = batch.get_num_inserted_zkps();
        let num_current_zkp = batch.get_current_zkp_batch_index();
        let num_ready_zkps = num_current_zkp.saturating_sub(num_inserted_zkps);
        // Hash chains of the ZKP batches that are complete but not yet inserted.
        let leaves_hash_chains: Vec<_> = (num_inserted_zkps..num_current_zkp)
            .map(|i| merkle_tree.hash_chain_stores[batch_idx][i as usize])
            .collect();
        let root = merkle_tree
            .root_history
            .last()
            .copied()
            .ok_or_else(|| ForesterUtilsError::Indexer("Root history is empty".into()))?;
        let root_history = merkle_tree.root_history.to_vec();
        (
            batch_idx,
            zkp_batch_size,
            num_inserted_zkps,
            num_ready_zkps,
            root,
            root_history,
            leaves_hash_chains,
        )
    };
    trace!(
        "batch_idx: {}, zkp_batch_size: {}, num_inserted_zkps: {}, num_ready_zkps: {}, leaves_hash_chains: {:?}",
        batch_idx, zkp_batch_size, num_inserted_zkps, num_ready_zkps, leaves_hash_chains.len()
    );
    if leaves_hash_chains.is_empty() {
        return Ok(Vec::new());
    }

    wait_for_indexer(rpc, indexer).await?;

    // The slot is only logged, so failing to fetch it must not abort proof creation.
    match rpc.get_slot().await {
        Ok(current_slot) => {
            trace!("current_slot: {}", current_slot);
        }
        Err(e) => {
            trace!("failed to fetch current_slot: {:?}", e);
        }
    }

    let zkp_batch_len = zkp_batch_size as usize;
    let total_elements = zkp_batch_len * leaves_hash_chains.len();
    let offset = num_inserted_zkps * u64::from(zkp_batch_size);
    trace!(
        "Requesting {} total elements with offset {}",
        total_elements,
        offset
    );
    // A silent `as u16` truncation here would request too few elements.
    let requested_elements = u16::try_from(total_elements).map_err(|_| {
        ForesterUtilsError::Indexer(format!(
            "Requested {} queue elements, which exceeds the limit of {}",
            total_elements,
            u16::MAX
        ))
    })?;
    let all_queue_elements = indexer
        .get_queue_elements(
            merkle_tree_pubkey.to_bytes(),
            QueueType::InputStateV2,
            requested_elements,
            Some(offset),
            None,
        )
        .await
        .map_err(|e| {
            error!(
                "create_multiple_nullify_batch_ix_data: failed to get queue elements from indexer: {:?}",
                e
            );
            ForesterUtilsError::Indexer("Failed to get queue elements".into())
        })?
        .value
        .items;
    trace!("Got {} queue elements in total", all_queue_elements.len());
    if all_queue_elements.len() != total_elements {
        return Err(ForesterUtilsError::Indexer(format!(
            "Expected {} elements, got {}",
            total_elements,
            all_queue_elements.len()
        )));
    }

    // The Merkle proofs come from the indexer, while the circuit starts from the
    // on-chain root. If the two roots differ, the inputs are inconsistent. This
    // must be rejected in release builds too, not only under `debug_assert`.
    let indexer_root = all_queue_elements
        .first()
        .ok_or_else(|| ForesterUtilsError::Indexer("Indexer returned no queue elements".into()))?
        .root;
    if indexer_root != old_root {
        error!(
            "Root mismatch. Expected: {:?}, Got: {:?}. Root history: {:?}",
            old_root, indexer_root, root_history
        );
        return Err(ForesterUtilsError::Indexer(format!(
            "Root mismatch. Expected: {:?}, Got: {:?}",
            old_root, indexer_root
        )));
    }

    let mut all_changelogs = Vec::new();
    let mut proof_futures = Vec::with_capacity(leaves_hash_chains.len());
    // Each ZKP batch starts from the root produced by the previous one.
    let mut current_root = old_root;
    for (batch_offset, leaves_hash_chain) in leaves_hash_chains.iter().enumerate() {
        let start_idx = batch_offset * zkp_batch_len;
        let end_idx = start_idx + zkp_batch_len;
        // In bounds: the length check above guarantees `total_elements` items.
        let batch_elements = &all_queue_elements[start_idx..end_idx];
        trace!(
            "Processing batch {} with offset {}-{}",
            batch_offset,
            start_idx,
            end_idx
        );
        let mut leaves = Vec::with_capacity(zkp_batch_len);
        let mut tx_hashes = Vec::with_capacity(zkp_batch_len);
        let mut old_leaves = Vec::with_capacity(zkp_batch_len);
        let mut path_indices = Vec::with_capacity(zkp_batch_len);
        let mut merkle_proofs = Vec::with_capacity(zkp_batch_len);
        // NOTE(review): `nullifiers` is built but never read in this function.
        // Confirm whether it (and the Poseidon import) can be dropped.
        let mut nullifiers = Vec::with_capacity(zkp_batch_len);
        for (i, leaf_info) in batch_elements.iter().enumerate() {
            let global_leaf_index = start_idx + i;
            trace!(
                "Element {}: local index={}, global index={}, reported index={}",
                i,
                i,
                global_leaf_index,
                leaf_info.leaf_index
            );
            path_indices.push(leaf_info.leaf_index as u32);
            leaves.push(leaf_info.account_hash);
            old_leaves.push(leaf_info.leaf);
            merkle_proofs.push(leaf_info.proof.clone());
            let tx_hash = leaf_info.tx_hash.ok_or_else(|| {
                ForesterUtilsError::Indexer(format!(
                    "Missing tx_hash for leaf index {}",
                    leaf_info.leaf_index
                ))
            })?;
            tx_hashes.push(tx_hash);
            let index_bytes = leaf_info.leaf_index.to_be_bytes();
            let nullifier =
                Poseidon::hashv(&[&leaf_info.account_hash, &index_bytes, &tx_hash]).map_err(
                    |e| ForesterUtilsError::Prover(format!("Failed to compute nullifier: {:?}", e)),
                )?;
            nullifiers.push(nullifier);
        }
        // The per-batch vectors are not used after this call, so move them instead of cloning.
        let (circuit_inputs, batch_changelog) =
            get_batch_update_inputs::<{ DEFAULT_BATCH_STATE_TREE_HEIGHT as usize }>(
                current_root,
                tx_hashes,
                leaves,
                *leaves_hash_chain,
                old_leaves,
                merkle_proofs,
                path_indices,
                u32::from(zkp_batch_size),
                &all_changelogs,
            )
            .map_err(|e| {
                error!("Failed to get batch update inputs: {:?}", e);
                ForesterUtilsError::Prover("Failed to get batch update inputs".into())
            })?;
        all_changelogs.extend(batch_changelog);
        current_root =
            bigint_to_be_bytes_array::<32>(&circuit_inputs.new_root.to_biguint().unwrap())
                .map_err(|_| {
                    ForesterUtilsError::Prover("Failed to convert new root to bytes".into())
                })?;
        proof_futures.push(tokio::spawn(generate_nullify_zkp_proof(circuit_inputs)));
    }

    // `join_all` keeps the input order, so results line up with the ZKP batches.
    let mut results = Vec::with_capacity(proof_futures.len());
    let joined = futures::future::join_all(proof_futures).await;
    for (i, task_result) in joined.into_iter().enumerate() {
        match task_result {
            Ok(Ok((proof, new_root))) => {
                results.push(InstructionDataBatchNullifyInputs {
                    new_root,
                    compressed_proof: proof,
                });
                trace!("Successfully generated proof for batch {}", i);
            }
            Ok(Err(e)) => {
                error!("Error generating proof for batch {}: {:?}", i, e);
                return Err(e);
            }
            Err(e) => {
                error!("Task error for batch {}: {:?}", i, e);
                return Err(ForesterUtilsError::Prover(format!(
                    "Task error for batch {}: {:?}",
                    i, e
                )));
            }
        }
    }
    Ok(results)
}
/// Requests a batch-update (nullify) proof for `inputs` from a local prover.
///
/// On success, returns the proof's `a`/`b`/`c` components packed into a
/// [`CompressedProof`], paired with the new root returned by the prover.
///
/// # Errors
/// Returns [`ForesterUtilsError::Prover`] carrying the prover error's string
/// form if proof generation fails.
async fn generate_nullify_zkp_proof(
    inputs: BatchUpdateCircuitInputs,
) -> Result<(CompressedProof, [u8; 32]), ForesterUtilsError> {
    let client = ProofClient::local();
    match client.generate_batch_update_proof(inputs).await {
        Err(err) => Err(ForesterUtilsError::Prover(err.to_string())),
        Ok((raw_proof, root_after)) => {
            let compressed = CompressedProof {
                a: raw_proof.a,
                b: raw_proof.b,
                c: raw_proof.c,
            };
            Ok((compressed, root_after))
        }
    }
}