// forester_utils/instructions/state_batch_nullify.rs

1use account_compression::processor::initialize_address_merkle_tree::Pubkey;
2use light_batched_merkle_tree::{
3    constants::DEFAULT_BATCH_STATE_TREE_HEIGHT,
4    merkle_tree::{BatchedMerkleTreeAccount, InstructionDataBatchNullifyInputs},
5};
6use light_client::{indexer::Indexer, rpc::Rpc};
7use light_compressed_account::instruction_data::compressed_proof::CompressedProof;
8use light_hasher::{bigint::bigint_to_be_bytes_array, Hasher, Poseidon};
9use light_merkle_tree_metadata::QueueType;
10use light_prover_client::{
11    proof_client::ProofClient,
12    proof_types::batch_update::{get_batch_update_inputs, BatchUpdateCircuitInputs},
13};
14use tracing::{error, trace};
15
16use crate::{error::ForesterUtilsError, utils::wait_for_indexer};
17
18pub async fn create_nullify_batch_ix_data<R: Rpc, I: Indexer>(
19    rpc: &mut R,
20    indexer: &mut I,
21    merkle_tree_pubkey: Pubkey,
22) -> Result<Vec<InstructionDataBatchNullifyInputs>, ForesterUtilsError> {
23    trace!("create_multiple_nullify_batch_ix_data");
24    // Get the tree information and find out how many ZKP batches need processing
25    let (
26        batch_idx,
27        zkp_batch_size,
28        num_inserted_zkps,
29        num_ready_zkps,
30        old_root,
31        root_history,
32        leaves_hash_chains,
33    ) = {
34        let mut account = rpc.get_account(merkle_tree_pubkey).await.unwrap().unwrap();
35        let merkle_tree = BatchedMerkleTreeAccount::state_from_bytes(
36            account.data.as_mut_slice(),
37            &merkle_tree_pubkey.into(),
38        )
39        .unwrap();
40
41        trace!("queue_batches: {:?}", merkle_tree.queue_batches);
42
43        let batch_idx = merkle_tree.queue_batches.pending_batch_index as usize;
44        let zkp_size = merkle_tree.queue_batches.zkp_batch_size;
45        let batch = &merkle_tree.queue_batches.batches[batch_idx];
46        let num_inserted_zkps = batch.get_num_inserted_zkps();
47        let num_current_zkp = batch.get_current_zkp_batch_index();
48        let num_ready_zkps = num_current_zkp.saturating_sub(num_inserted_zkps);
49
50        let mut leaves_hash_chains = Vec::new();
51        for i in num_inserted_zkps..num_current_zkp {
52            leaves_hash_chains.push(merkle_tree.hash_chain_stores[batch_idx][i as usize]);
53        }
54
55        let root = *merkle_tree.root_history.last().unwrap();
56        let root_history = merkle_tree.root_history.to_vec();
57
58        (
59            batch_idx,
60            zkp_size as u16,
61            num_inserted_zkps,
62            num_ready_zkps,
63            root,
64            root_history,
65            leaves_hash_chains,
66        )
67    };
68
69    trace!(
70        "batch_idx: {}, zkp_batch_size: {}, num_inserted_zkps: {}, num_ready_zkps: {}, leaves_hash_chains: {:?}",
71        batch_idx, zkp_batch_size, num_inserted_zkps, num_ready_zkps, leaves_hash_chains.len()
72    );
73
74    if leaves_hash_chains.is_empty() {
75        return Ok(Vec::new());
76    }
77
78    wait_for_indexer(rpc, indexer).await?;
79
80    let current_slot = rpc.get_slot().await.unwrap();
81    trace!("current_slot: {}", current_slot);
82
83    let total_elements = zkp_batch_size as usize * leaves_hash_chains.len();
84    let offset = num_inserted_zkps * zkp_batch_size as u64;
85
86    trace!(
87        "Requesting {} total elements with offset {}",
88        total_elements,
89        offset
90    );
91
92    let all_queue_elements = indexer
93        .get_queue_elements(
94            merkle_tree_pubkey.to_bytes(),
95            QueueType::InputStateV2,
96            total_elements as u16,
97            Some(offset),
98            None,
99        )
100        .await
101        .map_err(|e| {
102            error!(
103                "create_multiple_nullify_batch_ix_data: failed to get queue elements from indexer: {:?}",
104                e
105            );
106            ForesterUtilsError::Indexer("Failed to get queue elements".into())
107        })?.value.items;
108
109    trace!("Got {} queue elements in total", all_queue_elements.len());
110    if all_queue_elements.len() != total_elements {
111        return Err(ForesterUtilsError::Indexer(format!(
112            "Expected {} elements, got {}",
113            total_elements,
114            all_queue_elements.len()
115        )));
116    }
117
118    let indexer_root = all_queue_elements.first().unwrap().root;
119    debug_assert_eq!(
120        indexer_root, old_root,
121        "Root mismatch. Expected: {:?}, Got: {:?}. Root history: {:?}",
122        old_root, indexer_root, root_history
123    );
124
125    let mut all_changelogs = Vec::new();
126    let mut proof_futures = Vec::new();
127
128    let mut current_root = old_root;
129
130    for (batch_offset, leaves_hash_chain) in leaves_hash_chains.iter().enumerate() {
131        let start_idx = batch_offset * zkp_batch_size as usize;
132        let end_idx = start_idx + zkp_batch_size as usize;
133        let batch_elements = &all_queue_elements[start_idx..end_idx];
134
135        trace!(
136            "Processing batch {} with offset {}-{}",
137            batch_offset,
138            start_idx,
139            end_idx
140        );
141
142        // Process this batch's data
143        let mut leaves = Vec::new();
144        let mut tx_hashes = Vec::new();
145        let mut old_leaves = Vec::new();
146        let mut path_indices = Vec::new();
147        let mut merkle_proofs = Vec::new();
148        let mut nullifiers = Vec::new();
149
150        for (i, leaf_info) in batch_elements.iter().enumerate() {
151            let global_leaf_index = start_idx + i;
152            trace!(
153                "Element {}: local index={}, global index={}, reported index={}",
154                i,
155                i,
156                global_leaf_index,
157                leaf_info.leaf_index
158            );
159
160            path_indices.push(leaf_info.leaf_index as u32);
161            leaves.push(leaf_info.account_hash);
162            old_leaves.push(leaf_info.leaf);
163            merkle_proofs.push(leaf_info.proof.clone());
164
165            // Make sure tx_hash exists
166            let tx_hash = match leaf_info.tx_hash {
167                Some(hash) => hash,
168                None => {
169                    return Err(ForesterUtilsError::Indexer(format!(
170                        "Missing tx_hash for leaf index {}",
171                        leaf_info.leaf_index
172                    )))
173                }
174            };
175
176            tx_hashes.push(tx_hash);
177
178            let index_bytes = leaf_info.leaf_index.to_be_bytes();
179            let nullifier =
180                Poseidon::hashv(&[&leaf_info.account_hash, &index_bytes, &tx_hash]).unwrap();
181            nullifiers.push(nullifier);
182        }
183
184        let (circuit_inputs, batch_changelog) =
185            get_batch_update_inputs::<{ DEFAULT_BATCH_STATE_TREE_HEIGHT as usize }>(
186                current_root,
187                tx_hashes.clone(),
188                leaves.clone(),
189                *leaves_hash_chain,
190                old_leaves.clone(),
191                merkle_proofs.clone(),
192                path_indices.clone(),
193                zkp_batch_size as u32,
194                &all_changelogs,
195            )
196            .map_err(|e| {
197                error!("Failed to get batch update inputs: {:?}", e);
198                ForesterUtilsError::Prover("Failed to get batch update inputs".into())
199            })?;
200
201        all_changelogs.extend(batch_changelog);
202        current_root =
203            bigint_to_be_bytes_array::<32>(&circuit_inputs.new_root.to_biguint().unwrap())
204                .map_err(|_| {
205                    ForesterUtilsError::Prover("Failed to convert new root to bytes".into())
206                })?;
207
208        let proof_future = tokio::spawn(generate_nullify_zkp_proof(circuit_inputs));
209        proof_futures.push(proof_future);
210    }
211
212    // Wait for all proof generation to complete
213    let mut results = Vec::new();
214
215    for (i, future) in futures::future::join_all(proof_futures)
216        .await
217        .into_iter()
218        .enumerate()
219    {
220        match future {
221            Ok(result) => match result {
222                Ok((proof, new_root)) => {
223                    results.push(InstructionDataBatchNullifyInputs {
224                        new_root,
225                        compressed_proof: proof,
226                    });
227                    trace!("Successfully generated proof for batch {}", i);
228                }
229                Err(e) => {
230                    error!("Error generating proof for batch {}: {:?}", i, e);
231                    return Err(e);
232                }
233            },
234            Err(e) => {
235                error!("Task error for batch {}: {:?}", i, e);
236                return Err(ForesterUtilsError::Prover(format!(
237                    "Task error for batch {}: {:?}",
238                    i, e
239                )));
240            }
241        }
242    }
243
244    Ok(results)
245}
246async fn generate_nullify_zkp_proof(
247    inputs: BatchUpdateCircuitInputs,
248) -> Result<(CompressedProof, [u8; 32]), ForesterUtilsError> {
249    let proof_client = ProofClient::local();
250    let (proof, new_root) = proof_client
251        .generate_batch_update_proof(inputs)
252        .await
253        .map_err(|e| ForesterUtilsError::Prover(e.to_string()))?;
254    Ok((
255        CompressedProof {
256            a: proof.a,
257            b: proof.b,
258            c: proof.c,
259        },
260        new_root,
261    ))
262}