forester_utils/instructions/
state_batch_append.rs

1use account_compression::processor::initialize_address_merkle_tree::Pubkey;
2use light_batched_merkle_tree::{
3    constants::DEFAULT_BATCH_STATE_TREE_HEIGHT,
4    merkle_tree::{BatchedMerkleTreeAccount, InstructionDataBatchAppendInputs},
5    queue::BatchedQueueAccount,
6};
7use light_client::{indexer::Indexer, rpc::Rpc};
8use light_compressed_account::instruction_data::compressed_proof::CompressedProof;
9use light_hasher::bigint::bigint_to_be_bytes_array;
10use light_merkle_tree_metadata::QueueType;
11use light_prover_client::{
12    proof_client::ProofClient,
13    proof_types::batch_append::{
14        get_batch_append_with_proofs_inputs, BatchAppendWithProofsCircuitInputs,
15    },
16};
17use light_sparse_merkle_tree::changelog::ChangelogEntry;
18use tracing::{error, trace};
19
20use crate::{error::ForesterUtilsError, utils::wait_for_indexer};
21
22pub async fn create_append_batch_ix_data<R: Rpc, I: Indexer>(
23    rpc: &mut R,
24    indexer: &mut I,
25    merkle_tree_pubkey: Pubkey,
26    output_queue_pubkey: Pubkey,
27) -> Result<Vec<InstructionDataBatchAppendInputs>, ForesterUtilsError> {
28    trace!("Creating append batch instruction data");
29
30    let (merkle_tree_next_index, current_root, root_history) =
31        get_merkle_tree_metadata(rpc, merkle_tree_pubkey).await?;
32
33    trace!(
34        "merkle_tree_next_index: {:?} current_root: {:?}",
35        merkle_tree_next_index,
36        current_root
37    );
38
39    // Get output queue metadata and hash chains
40    let (zkp_batch_size, leaves_hash_chains) =
41        get_output_queue_metadata(rpc, output_queue_pubkey).await?;
42
43    if leaves_hash_chains.is_empty() {
44        trace!("No hash chains to process");
45        return Ok(Vec::new());
46    }
47
48    wait_for_indexer(rpc, indexer).await?;
49
50    let total_elements = zkp_batch_size as usize * leaves_hash_chains.len();
51    let offset = merkle_tree_next_index;
52
53    let queue_elements = indexer
54        .get_queue_elements(
55            merkle_tree_pubkey.to_bytes(),
56            QueueType::OutputStateV2,
57            total_elements as u16,
58            Some(offset),
59            None,
60        )
61        .await
62        .map_err(|e| {
63            error!("Failed to get queue elements from indexer: {:?}", e);
64            ForesterUtilsError::Indexer("Failed to get queue elements".into())
65        })?
66        .value
67        .items;
68
69    trace!("Got {} queue elements in total", queue_elements.len());
70
71    if queue_elements.len() != total_elements {
72        return Err(ForesterUtilsError::Indexer(format!(
73            "Expected {} elements, got {}",
74            total_elements,
75            queue_elements.len()
76        )));
77    }
78    let indexer_root = queue_elements.first().unwrap().root;
79    debug_assert_eq!(
80        indexer_root, current_root,
81        "root_history: {:?}",
82        root_history
83    );
84
85    let mut current_root = current_root;
86    let mut all_changelogs: Vec<ChangelogEntry<{ DEFAULT_BATCH_STATE_TREE_HEIGHT as usize }>> =
87        Vec::new();
88    let mut proof_futures = Vec::new();
89
90    for (batch_idx, leaves_hash_chain) in leaves_hash_chains.iter().enumerate() {
91        let start_idx = batch_idx * zkp_batch_size as usize;
92        let end_idx = start_idx + zkp_batch_size as usize;
93        let batch_elements = &queue_elements[start_idx..end_idx];
94
95        trace!(
96            "Processing batch {}: index range {}-{}",
97            batch_idx,
98            start_idx,
99            end_idx
100        );
101
102        let old_leaves = batch_elements
103            .iter()
104            .map(|x| x.leaf)
105            .collect::<Vec<[u8; 32]>>();
106
107        let leaves = batch_elements
108            .iter()
109            .map(|x| x.account_hash)
110            .collect::<Vec<[u8; 32]>>();
111
112        let merkle_proofs = batch_elements
113            .iter()
114            .map(|x| x.proof.clone())
115            .collect::<Vec<Vec<[u8; 32]>>>();
116
117        let adjusted_start_index =
118            merkle_tree_next_index as u32 + (batch_idx * zkp_batch_size as usize) as u32;
119
120        let (circuit_inputs, batch_changelogs) = get_batch_append_with_proofs_inputs::<32>(
121            current_root,
122            adjusted_start_index,
123            leaves,
124            *leaves_hash_chain,
125            old_leaves,
126            merkle_proofs,
127            zkp_batch_size as u32,
128            all_changelogs.as_slice(),
129        )
130        .map_err(|e| {
131            error!("Failed to get circuit inputs: {:?}", e);
132            ForesterUtilsError::Prover("Failed to get circuit inputs".into())
133        })?;
134
135        current_root =
136            bigint_to_be_bytes_array::<32>(&circuit_inputs.new_root.to_biguint().unwrap()).unwrap();
137        all_changelogs.extend(batch_changelogs);
138
139        let proof_future = generate_zkp_proof(circuit_inputs);
140
141        proof_futures.push(proof_future);
142    }
143
144    let proof_results = futures::future::join_all(proof_futures).await;
145    let mut instruction_data_vec = Vec::new();
146
147    for (i, proof_result) in proof_results.into_iter().enumerate() {
148        match proof_result {
149            Ok((proof, new_root)) => {
150                trace!("Successfully generated proof for batch {}", i);
151                instruction_data_vec.push(InstructionDataBatchAppendInputs {
152                    new_root,
153                    compressed_proof: proof,
154                });
155            }
156            Err(e) => {
157                error!("Failed to generate proof for batch {}: {:?}", i, e);
158                return Err(e);
159            }
160        }
161    }
162
163    Ok(instruction_data_vec)
164}
165async fn generate_zkp_proof(
166    circuit_inputs: BatchAppendWithProofsCircuitInputs,
167) -> Result<(CompressedProof, [u8; 32]), ForesterUtilsError> {
168    let proof_client = ProofClient::local();
169    let (proof, new_root) = proof_client
170        .generate_batch_append_proof(circuit_inputs)
171        .await
172        .map_err(|e| ForesterUtilsError::Prover(e.to_string()))?;
173    Ok((
174        CompressedProof {
175            a: proof.a,
176            b: proof.b,
177            c: proof.c,
178        },
179        new_root,
180    ))
181}
182
183/// Get metadata from the Merkle tree account
184async fn get_merkle_tree_metadata(
185    rpc: &mut impl Rpc,
186    merkle_tree_pubkey: Pubkey,
187) -> Result<(u64, [u8; 32], Vec<[u8; 32]>), ForesterUtilsError> {
188    let mut merkle_tree_account = rpc
189        .get_account(merkle_tree_pubkey)
190        .await
191        .map_err(|e| ForesterUtilsError::Rpc(format!("Failed to get merkle tree account: {}", e)))?
192        .ok_or_else(|| ForesterUtilsError::Rpc("Merkle tree account not found".into()))?;
193
194    let merkle_tree = BatchedMerkleTreeAccount::state_from_bytes(
195        merkle_tree_account.data.as_mut_slice(),
196        &merkle_tree_pubkey.into(),
197    )
198    .map_err(|e| ForesterUtilsError::Rpc(format!("Failed to parse merkle tree: {}", e)))?;
199
200    Ok((
201        merkle_tree.next_index,
202        *merkle_tree.root_history.last().unwrap(),
203        merkle_tree.root_history.to_vec(),
204    ))
205}
206
207/// Get metadata and hash chains from the output queue
208async fn get_output_queue_metadata(
209    rpc: &mut impl Rpc,
210    output_queue_pubkey: Pubkey,
211) -> Result<(u16, Vec<[u8; 32]>), ForesterUtilsError> {
212    let mut output_queue_account = rpc
213        .get_account(output_queue_pubkey)
214        .await
215        .map_err(|e| ForesterUtilsError::Rpc(format!("Failed to get output queue account: {}", e)))?
216        .ok_or_else(|| ForesterUtilsError::Rpc("Output queue account not found".into()))?;
217
218    let output_queue =
219        BatchedQueueAccount::output_from_bytes(output_queue_account.data.as_mut_slice())
220            .map_err(|e| ForesterUtilsError::Rpc(format!("Failed to parse output queue: {}", e)))?;
221
222    let full_batch_index = output_queue.batch_metadata.pending_batch_index;
223    let zkp_batch_size = output_queue.batch_metadata.zkp_batch_size;
224    let batch = &output_queue.batch_metadata.batches[full_batch_index as usize];
225    let num_inserted_zkps = batch.get_num_inserted_zkps();
226
227    // Get all remaining hash chains for the batch
228    let mut leaves_hash_chains = Vec::new();
229    for i in num_inserted_zkps..batch.get_current_zkp_batch_index() {
230        leaves_hash_chains
231            .push(output_queue.hash_chain_stores[full_batch_index as usize][i as usize]);
232    }
233
234    trace!(
235        "ZKP batch size: {}, inserted ZKPs: {}, current ZKP index: {}, ready for insertion: {}",
236        zkp_batch_size,
237        num_inserted_zkps,
238        batch.get_current_zkp_batch_index(),
239        leaves_hash_chains.len()
240    );
241
242    Ok((zkp_batch_size as u16, leaves_hash_chains))
243}