light_batched_merkle_tree/
batch.rs

1use light_bloom_filter::BloomFilter;
2use light_hasher::{Hasher, Poseidon};
3use light_zero_copy::vec::ZeroCopyVecU64;
4use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
5
6use crate::{errors::BatchedMerkleTreeError, BorshDeserialize, BorshSerialize};
7
/// Lifecycle state of a batch: `Fill -> Full -> Inserted -> Fill`.
///
/// The explicit discriminants are the values `Batch` stores in its raw
/// `u64` `state` field; changing them would presumably break previously
/// serialized batches.
#[derive(Clone, Debug, PartialEq, Eq, Copy)]
#[repr(u64)]
pub enum BatchState {
    /// Batch can be filled with values.
    Fill = 0,
    /// Batch has been inserted into the tree.
    Inserted = 1,
    /// Batch is full.
    Full = 2,
}

impl From<u64> for BatchState {
    /// Decodes a stored discriminant into a `BatchState`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not 0, 1 or 2.
    fn from(value: u64) -> Self {
        match value {
            0 => Self::Fill,
            1 => Self::Inserted,
            2 => Self::Full,
            _ => panic!("Invalid BatchState value"),
        }
    }
}

impl From<BatchState> for u64 {
    /// Encodes a `BatchState` as its `repr(u64)` discriminant.
    fn from(state: BatchState) -> Self {
        match state {
            BatchState::Fill => 0,
            BatchState::Inserted => 1,
            BatchState::Full => 2,
        }
    }
}
35
/// Batch structure that holds
/// the metadata and state of a batch.
///
/// A batch:
/// - has a size and a number of zkp batches.
/// - size must be divisible by zkp batch size.
/// - is part of a queue, each queue has two batches.
/// - is inserted into the tree by zkp batch.
///
/// The struct is `#[repr(C)]` and derives the zerocopy traits, so the field
/// order and field sizes below define its byte layout. Reordering or resizing
/// fields would presumably break existing serialized batches.
#[repr(C)]
#[derive(
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    KnownLayout,
    Immutable,
    IntoBytes,
    FromBytes,
    Default,
    BorshSerialize,
    BorshDeserialize,
)]
pub struct Batch {
    /// Number of elements inserted into the current (not yet full) zkp batch.
    /// Reset to zero each time a zkp batch becomes full.
    num_inserted: u64,
    /// Raw `BatchState` discriminant (0 = Fill, 1 = Inserted, 2 = Full).
    /// Read it via `get_state`.
    state: u64,
    /// Number of full zkp batches in the batch
    /// that are ready to be inserted into the tree.
    pub(crate) num_full_zkp_batches: u64,
    /// Number of zkp batches that have been inserted into the tree.
    num_inserted_zkp_batches: u64,
    /// Number of iterations for the bloom_filter.
    pub num_iters: u64,
    /// Theoretical capacity of the bloom_filter in bits.
    /// We want to make it much larger
    /// than batch_size to avoid false positives.
    pub bloom_filter_capacity: u64,
    /// Number of elements in a batch.
    pub batch_size: u64,
    /// Number of elements in a zkp batch.
    /// A batch consists of one or more zkp batches.
    pub zkp_batch_size: u64,
    /// Sequence number at which it is safe to clear the batch without advancing
    /// to the saved root index. Set when the batch is completely inserted, to the
    /// tree sequence number at that time plus the root history length.
    pub sequence_number: u64,
    /// Leaf index of the first element in the batch.
    pub start_index: u64,
    /// Slot of the first insertion into the batch.
    /// Indexers can use this slot to reindex inserted elements.
    /// Is not used for the batch itself.
    pub start_slot: u64,
    /// Root index saved when the batch was completely inserted into the tree.
    pub root_index: u32,
    /// 1 if `start_slot` has been set, 0 otherwise.
    start_slot_is_set: u8,
    /// 1 if the bloom filter is zeroed, 0 otherwise.
    bloom_filter_is_zeroed: u8,
    /// Explicit padding that rounds the struct size up to a multiple of 8 bytes
    /// (the alignment of the `u64` fields).
    _padding: [u8; 2],
}
93
94impl Batch {
95    pub fn new(
96        num_iters: u64,
97        bloom_filter_capacity: u64,
98        batch_size: u64,
99        zkp_batch_size: u64,
100        start_index: u64,
101    ) -> Self {
102        Batch {
103            num_iters,
104            bloom_filter_capacity,
105            batch_size,
106            num_inserted: 0,
107            state: BatchState::Fill.into(),
108            zkp_batch_size,
109            num_full_zkp_batches: 0,
110            num_inserted_zkp_batches: 0,
111            sequence_number: 0,
112            root_index: 0,
113            start_index,
114            start_slot: 0,
115            start_slot_is_set: 0,
116            bloom_filter_is_zeroed: 0,
117            _padding: [0u8; 2],
118        }
119    }
120
121    /// Returns the state of the batch.
122    pub fn get_state(&self) -> BatchState {
123        self.state.into()
124    }
125
126    pub fn bloom_filter_is_zeroed(&self) -> bool {
127        self.bloom_filter_is_zeroed == 1
128    }
129
130    pub fn set_bloom_filter_to_zeroed(&mut self) {
131        // 1 if bloom filter is zeroed
132        // 0 if bloom filter is not zeroed
133        self.bloom_filter_is_zeroed = 1;
134    }
135
136    pub fn set_bloom_filter_to_not_zeroed(&mut self) {
137        // 1 if bloom filter is zeroed
138        // 0 if bloom filter is not zeroed
139        self.bloom_filter_is_zeroed = 0;
140    }
141
142    pub fn start_slot_is_set(&self) -> bool {
143        self.start_slot_is_set == 1
144    }
145
146    pub fn set_start_slot(&mut self, start_slot: &u64) {
147        if !self.start_slot_is_set() {
148            self.start_slot = *start_slot;
149            self.start_slot_is_set = 1;
150        }
151    }
152
153    /// fill -> full -> inserted -> fill
154    /// (from tree insertion perspective is pending if fill or full)
155    pub fn advance_state_to_fill(
156        &mut self,
157        start_index: Option<u64>,
158    ) -> Result<(), BatchedMerkleTreeError> {
159        if self.get_state() == BatchState::Inserted {
160            self.state = BatchState::Fill.into();
161            self.set_bloom_filter_to_not_zeroed();
162            self.sequence_number = 0;
163            self.root_index = 0;
164            self.num_inserted_zkp_batches = 0;
165            self.start_slot_is_set = 0;
166            self.start_slot = 0;
167            if let Some(start_index) = start_index {
168                self.start_index = start_index;
169            }
170            self.num_full_zkp_batches = 0;
171        } else {
172            #[cfg(feature = "solana")]
173            solana_msg::msg!(
174                "Batch is in incorrect state {} expected BatchState::Inserted 1",
175                self.state
176            );
177            return Err(BatchedMerkleTreeError::BatchNotReady);
178        }
179        Ok(())
180    }
181
182    /// fill -> full -> inserted -> fill
183    /// (from tree insertion perspective is pending if fill or full)
184    pub fn advance_state_to_inserted(&mut self) -> Result<(), BatchedMerkleTreeError> {
185        if self.get_state() == BatchState::Full {
186            self.state = BatchState::Inserted.into();
187        } else {
188            #[cfg(feature = "solana")]
189            solana_msg::msg!(
190                "Batch is in incorrect state {} expected BatchState::Full 2",
191                self.state
192            );
193            return Err(BatchedMerkleTreeError::BatchNotReady);
194        }
195        Ok(())
196    }
197
198    /// fill -> full -> inserted -> fill
199    /// (from tree insertion perspective is pending if fill or full)
200    pub fn advance_state_to_full(&mut self) -> Result<(), BatchedMerkleTreeError> {
201        if self.get_state() == BatchState::Fill {
202            self.state = BatchState::Full.into();
203        } else {
204            #[cfg(feature = "solana")]
205            solana_msg::msg!(
206                "Batch is in incorrect state {} expected BatchState::Fill 0",
207                self.state
208            );
209            return Err(BatchedMerkleTreeError::BatchNotReady);
210        }
211        Ok(())
212    }
213
214    pub fn get_first_ready_zkp_batch(&self) -> Result<u64, BatchedMerkleTreeError> {
215        if self.get_state() == BatchState::Inserted {
216            Err(BatchedMerkleTreeError::BatchAlreadyInserted)
217        } else if self.batch_is_ready_to_insert() {
218            Ok(self.num_inserted_zkp_batches)
219        } else {
220            Err(BatchedMerkleTreeError::BatchNotReady)
221        }
222    }
223
224    pub fn batch_is_ready_to_insert(&self) -> bool {
225        self.num_full_zkp_batches > self.num_inserted_zkp_batches
226    }
227
228    /// Returns the number of zkp batch updates
229    /// that are ready to be inserted into the tree.
230    pub fn get_num_ready_zkp_updates(&self) -> u64 {
231        self.num_full_zkp_batches
232            .saturating_sub(self.num_inserted_zkp_batches)
233    }
234
235    /// Returns the number of inserted elements
236    /// in the current zkp batch.
237    pub fn get_num_inserted_zkp_batch(&self) -> u64 {
238        self.num_inserted
239    }
240
241    /// Returns the current zkp batch index.
242    /// New values are inserted into the current zkp batch.
243    pub fn get_current_zkp_batch_index(&self) -> u64 {
244        self.num_full_zkp_batches
245    }
246
247    /// Returns the number of inserted zkps.
248    pub fn get_num_inserted_zkps(&self) -> u64 {
249        self.num_inserted_zkp_batches
250    }
251
252    /// Returns the number of elements inserted into the tree.
253    pub fn get_num_elements_inserted_into_tree(&self) -> u64 {
254        self.num_inserted_zkp_batches * self.zkp_batch_size
255    }
256
257    /// Returns the number of inserted elements in the batch.
258    pub fn get_num_inserted_elements(&self) -> u64 {
259        self.num_full_zkp_batches * self.zkp_batch_size + self.num_inserted
260    }
261
262    /// Returns the number of zkp batches in the batch.
263    pub fn get_num_zkp_batches(&self) -> u64 {
264        self.batch_size / self.zkp_batch_size
265    }
266
267    /// Returns the number of the hash_chain stores.
268    pub fn get_num_hash_chain_store(&self) -> usize {
269        self.get_num_zkp_batches() as usize
270    }
271
272    /// Returns the index of a value by leaf index in the value store,
273    /// provided it does exist in the batch.
274    pub fn get_value_index_in_batch(&self, leaf_index: u64) -> Result<u64, BatchedMerkleTreeError> {
275        self.check_leaf_index_exists(leaf_index)?;
276        let index = leaf_index
277            .checked_sub(self.start_index)
278            .ok_or(BatchedMerkleTreeError::LeafIndexNotInBatch)?;
279        Ok(index)
280    }
281
282    /// Stores the value in a value store,
283    /// and adds the value to the current hash chain.
284    pub fn store_and_hash_value(
285        &mut self,
286        value: &[u8; 32],
287        value_store: &mut ZeroCopyVecU64<[u8; 32]>,
288        hash_chain_store: &mut ZeroCopyVecU64<[u8; 32]>,
289        start_slot: &u64,
290    ) -> Result<(), BatchedMerkleTreeError> {
291        self.set_start_slot(start_slot);
292        self.add_to_hash_chain(value, hash_chain_store)?;
293        value_store.push(*value)?;
294        Ok(())
295    }
296
297    /// Insert into the bloom filter and
298    /// add value to current hash chain.
299    /// (used by nullifier & address queues)
300    /// 1. set start slot
301    /// 2. Add value to hash chain.
302    /// 3. Insert value into the bloom filter at bloom_filter_index.
303    /// 4. Check that value is not in any other bloom filter.
304    pub fn insert(
305        &mut self,
306        bloom_filter_value: &[u8; 32],
307        hash_chain_value: &[u8; 32],
308        bloom_filter_stores: &mut [&mut [u8]],
309        hash_chain_store: &mut ZeroCopyVecU64<[u8; 32]>,
310        bloom_filter_index: usize,
311        start_slot: &u64,
312    ) -> Result<(), BatchedMerkleTreeError> {
313        // 1. set start slot if not set.
314        self.set_start_slot(start_slot);
315        // 2. add value to hash chain
316        self.add_to_hash_chain(hash_chain_value, hash_chain_store)?;
317        // insert into bloom filter & check non inclusion
318        {
319            let other_bloom_filter_index = if bloom_filter_index == 0 { 1 } else { 0 };
320
321            // 3. Insert value into the bloom filter at bloom_filter_index.
322            BloomFilter::new(
323                self.num_iters as usize,
324                self.bloom_filter_capacity,
325                bloom_filter_stores[bloom_filter_index],
326            )?
327            .insert(bloom_filter_value)?;
328            // 4. Check that value is not in any other bloom filter.
329            Self::check_non_inclusion(
330                self.num_iters as usize,
331                self.bloom_filter_capacity,
332                bloom_filter_value,
333                bloom_filter_stores[other_bloom_filter_index],
334            )?;
335        }
336        Ok(())
337    }
338
339    /// Add a value to the current hash chain, and advance batch state.
340    /// 1. Check that the batch is ready.
341    /// 2. If the zkp batch is empty, start a new hash chain.
342    /// 3. If the zkp batch is not empty, add value to last hash chain.
343    /// 4. If the zkp batch is full, increment the zkp batch index.
344    /// 5. If all zkp batches are full, set batch state to full.
345    pub fn add_to_hash_chain(
346        &mut self,
347        value: &[u8; 32],
348        hash_chain_store: &mut ZeroCopyVecU64<[u8; 32]>,
349    ) -> Result<(), BatchedMerkleTreeError> {
350        // 1. Check that the batch is ready.
351        if self.get_state() != BatchState::Fill {
352            return Err(BatchedMerkleTreeError::BatchNotReady);
353        }
354        let start_new_hash_chain = self.num_inserted == 0;
355        if start_new_hash_chain {
356            // 2. Start a new hash chain.
357            hash_chain_store.push(*value)?;
358        } else if let Some(last_hash_chain) = hash_chain_store.last_mut() {
359            // 3. Add value to last hash chain.
360            let hash_chain = Poseidon::hashv(&[last_hash_chain, value.as_slice()])?;
361            *last_hash_chain = hash_chain;
362        } else {
363            unreachable!();
364        }
365        self.num_inserted += 1;
366
367        // 4. If the zkp batch is full, increment the zkp batch index.
368        let zkp_batch_is_full = self.num_inserted == self.zkp_batch_size;
369        if zkp_batch_is_full {
370            self.num_full_zkp_batches += 1;
371            // To start a new hash chain in the next insertion
372            // set num inserted to zero.
373            self.num_inserted = 0;
374
375            // 5. If all zkp batches are full, set batch state to full.
376            let batch_is_full = self.num_full_zkp_batches == self.get_num_zkp_batches();
377            if batch_is_full {
378                self.advance_state_to_full()?;
379            }
380        }
381
382        Ok(())
383    }
384
385    /// Checks that value is not in the bloom filter.
386    pub fn check_non_inclusion(
387        num_iters: usize,
388        bloom_filter_capacity: u64,
389        value: &[u8; 32],
390        store: &mut [u8],
391    ) -> Result<(), BatchedMerkleTreeError> {
392        let mut bloom_filter = BloomFilter::new(num_iters, bloom_filter_capacity, store)?;
393        if bloom_filter.contains(value) {
394            return Err(BatchedMerkleTreeError::NonInclusionCheckFailed);
395        }
396        Ok(())
397    }
398
399    /// Marks the batch as inserted in the merkle tree.
400    /// 1. Checks that the batch is ready.
401    /// 2. increments the number of inserted zkps.
402    /// 3. If all zkps are inserted, sets the state to inserted.
403    /// 4. Returns the updated state of the batch.
404    pub fn mark_as_inserted_in_merkle_tree(
405        &mut self,
406        sequence_number: u64,
407        root_index: u32,
408        root_history_length: u32,
409    ) -> Result<BatchState, BatchedMerkleTreeError> {
410        // 1. Check that batch is ready.
411        self.get_first_ready_zkp_batch()?;
412
413        let num_zkp_batches = self.get_num_zkp_batches();
414
415        // 2. increments the number of inserted zkps.
416        self.num_inserted_zkp_batches += 1;
417        // 3. If all zkp batches are inserted, sets the state to inserted.
418        let batch_is_completely_inserted = self.num_inserted_zkp_batches == num_zkp_batches;
419        if batch_is_completely_inserted {
420            self.advance_state_to_inserted()?;
421            // Saving sequence number and root index for the batch.
422            // When the batch is cleared check that sequence number is greater or equal than self.sequence_number
423            // if not advance current root index to root index
424            self.sequence_number = sequence_number + root_history_length as u64;
425            self.root_index = root_index;
426        }
427
428        Ok(self.get_state())
429    }
430
431    pub fn check_leaf_index_exists(&self, leaf_index: u64) -> Result<(), BatchedMerkleTreeError> {
432        if !self.leaf_index_exists(leaf_index) {
433            return Err(BatchedMerkleTreeError::LeafIndexNotInBatch);
434        }
435        Ok(())
436    }
437
438    /// Returns true if value of leaf index could exist in batch.
439    /// `True` doesn't mean that the value exists in the batch,
440    /// just that it is possible. The value might already be spent
441    /// or never have been inserted in case an invalid index was provided.
442    pub fn leaf_index_exists(&self, leaf_index: u64) -> bool {
443        let next_batch_leaf_index = self.get_num_inserted_elements() + self.start_index;
444        let min_batch_leaf_index = self.start_index;
445        leaf_index < next_batch_leaf_index && leaf_index >= min_batch_leaf_index
446    }
447}
448
449#[cfg(test)]
450mod tests {
451
452    use light_compressed_account::{pubkey::Pubkey, QueueType};
453    use light_merkle_tree_metadata::queue::QueueMetadata;
454
455    use super::*;
456    use crate::queue::BatchedQueueAccount;
457
458    fn get_test_batch() -> Batch {
459        Batch::new(3, 160_000, 500, 100, 0)
460    }
461
    /// Simulates zkp batch insertion into the tree for a completely filled batch.
    ///
    /// Marks each of the 5 zkp batches as inserted and checks the batch state
    /// after every step. It expects the batch to have been filled with
    /// `start_slot == 1`. Finally it resets the batch with
    /// `advance_state_to_fill(Some(1))` and checks the result equals a fresh
    /// test batch with `start_index = 1`.
    fn test_mark_as_inserted(mut batch: Batch) {
        let mut sequence_number = 10;
        let mut root_index = 20;
        let root_history_length = 23;
        let current_slot = 1;
        for i in 0..batch.get_num_zkp_batches() {
            sequence_number += i;
            root_index += i as u32;
            batch
                .mark_as_inserted_in_merkle_tree(sequence_number, root_index, root_history_length)
                .unwrap();
            if i != batch.get_num_zkp_batches() - 1 {
                // Not all zkp batches inserted yet: the batch stays Full.
                assert_eq!(batch.get_state(), BatchState::Full);
                assert_eq!(batch.get_num_inserted_zkp_batch(), 0);
                assert_eq!(batch.get_current_zkp_batch_index(), 5);
                assert_eq!(batch.get_num_inserted_zkps(), i + 1);
            } else {
                // Last zkp batch inserted: the batch transitions to Inserted.
                assert_eq!(batch.get_state(), BatchState::Inserted);
                assert_eq!(batch.get_num_inserted_zkp_batch(), 0);
                assert_eq!(batch.get_current_zkp_batch_index(), 5);
                assert_eq!(batch.get_num_inserted_zkps(), i + 1);
            }
        }
        assert_eq!(batch.get_state(), BatchState::Inserted);
        assert_eq!(batch.get_num_inserted_zkp_batch(), 0);
        // The saved sequence number is the last one passed plus the root history length.
        let mut ref_batch = get_test_batch();
        ref_batch.state = BatchState::Inserted.into();
        ref_batch.root_index = root_index;
        ref_batch.sequence_number = sequence_number + root_history_length as u64;
        ref_batch.num_inserted_zkp_batches = 5;
        ref_batch.start_slot = current_slot;
        ref_batch.start_slot_is_set = 1;
        ref_batch.num_full_zkp_batches = 5;
        assert_eq!(batch, ref_batch);
        // Resetting to Fill clears all per-cycle metadata and moves the start index.
        batch.advance_state_to_fill(Some(1)).unwrap();
        let mut ref_batch = get_test_batch();
        ref_batch.start_index = 1;
        assert_eq!(batch, ref_batch);
    }
502
    /// Fills a batch through `store_and_hash_value`.
    ///
    /// After every insertion it checks that the batch metadata matches a
    /// manually updated reference batch and that the value was appended to
    /// the value store. Inserting into the full batch must then fail with
    /// `BatchNotReady`, and the full batch is handed to `test_mark_as_inserted`.
    #[test]
    fn test_store_value() {
        let mut batch = get_test_batch();
        let current_slot = 1;

        let mut value_store_bytes =
            vec![0u8; ZeroCopyVecU64::<[u8; 32]>::required_size_for_capacity(batch.batch_size)];
        let mut value_store =
            ZeroCopyVecU64::new(batch.batch_size, &mut value_store_bytes).unwrap();
        let mut hash_chain_store_bytes = vec![
            0u8;
            ZeroCopyVecU64::<[u8; 32]>::required_size_for_capacity(
                batch.get_num_hash_chain_store() as u64
            )
        ];
        let mut hash_chain_store = ZeroCopyVecU64::new(
            batch.get_num_hash_chain_store() as u64,
            hash_chain_store_bytes.as_mut_slice(),
        )
        .unwrap();

        let mut ref_batch = get_test_batch();
        for i in 0..batch.batch_size {
            // Only the first insertion records the start slot.
            if i == 0 {
                ref_batch.start_slot = current_slot;
                ref_batch.start_slot_is_set = 1;
            }
            ref_batch.num_inserted %= ref_batch.zkp_batch_size;

            // Encode `i` big-endian into the last 8 bytes so every value is unique.
            let mut value = [0u8; 32];
            value[24..].copy_from_slice(&i.to_be_bytes());
            assert!(batch
                .store_and_hash_value(
                    &value,
                    &mut value_store,
                    &mut hash_chain_store,
                    &current_slot
                )
                .is_ok());
            // Mirror the counter updates of `add_to_hash_chain` on the reference batch.
            ref_batch.num_inserted += 1;
            if ref_batch.num_inserted == ref_batch.zkp_batch_size {
                ref_batch.num_full_zkp_batches += 1;
                ref_batch.num_inserted = 0;
            }
            if ref_batch.num_full_zkp_batches == ref_batch.get_num_zkp_batches() {
                ref_batch.state = BatchState::Full.into();
                ref_batch.num_inserted = 0;
            }
            assert_eq!(batch, ref_batch);
            assert_eq!(*value_store.get(i as usize).unwrap(), value);
        }
        // The batch is now full: further inserts are rejected.
        let result = batch.store_and_hash_value(
            &[1u8; 32],
            &mut value_store,
            &mut hash_chain_store,
            &current_slot,
        );
        assert_eq!(result.unwrap_err(), BatchedMerkleTreeError::BatchNotReady);
        assert_eq!(batch.get_state(), BatchState::Full);
        assert_eq!(batch.get_num_inserted_zkp_batch(), 0);
        assert_eq!(batch.get_current_zkp_batch_index(), 5);
        assert_eq!(batch.get_num_zkp_batches(), 5);
        assert_eq!(batch.get_num_inserted_zkps(), 0);

        test_mark_as_inserted(batch);
    }
569
    /// Fills a batch through `insert`: the first half of the values goes into
    /// bloom filter 0 and the second half into bloom filter 1.
    ///
    /// After each insertion it checks:
    /// - the resulting hash chain,
    /// - that reinserting the same value (into a copy of the batch) fails,
    /// - that the value is in the target bloom filter but not the other,
    /// - that the batch metadata matches a manually updated reference batch.
    ///
    /// The full batch is then handed to `test_mark_as_inserted`.
    #[test]
    fn test_insert() {
        // Behavior Input queue
        let mut batch = get_test_batch();
        let mut current_slot = 1;
        let mut stores = vec![vec![0u8; 20_000]; 2];
        let mut bloom_filter_stores = stores
            .iter_mut()
            .map(|store| &mut store[..])
            .collect::<Vec<_>>();
        let mut hash_chain_store_bytes = vec![
            0u8;
            ZeroCopyVecU64::<[u8; 32]>::required_size_for_capacity(
                batch.get_num_hash_chain_store() as u64
            )
        ];
        // Initialize the hash chain store bytes; the store is reopened with
        // `from_bytes` on every iteration below.
        ZeroCopyVecU64::<[u8; 32]>::new(
            batch.get_num_hash_chain_store() as u64,
            hash_chain_store_bytes.as_mut_slice(),
        )
        .unwrap();

        let mut ref_batch = get_test_batch();
        for processing_index in 0..=1 {
            for i in 0..(batch.batch_size / 2) {
                // Global element index across both halves.
                let i = i + (batch.batch_size / 2) * (processing_index as u64);
                // The start slot is recorded on the very first insertion only.
                if i == 0 && processing_index == 0 {
                    assert_eq!(batch.start_slot, 0);
                    assert_eq!(batch.start_slot_is_set, 0);
                    ref_batch.start_slot = current_slot;
                    ref_batch.start_slot_is_set = 1;
                } else {
                    assert_eq!(batch.start_slot, 1);
                    assert_eq!(batch.start_slot_is_set, 1);
                }

                ref_batch.num_inserted %= ref_batch.zkp_batch_size;
                let mut hash_chain_store =
                    ZeroCopyVecU64::<[u8; 32]>::from_bytes(hash_chain_store_bytes.as_mut_slice())
                        .unwrap();

                let mut value = [0u8; 32];
                value[24..].copy_from_slice(&i.to_be_bytes());
                // The first element of a zkp batch starts a new chain;
                // later elements are hashed onto the last chain.
                let ref_hash_chain = if i % batch.zkp_batch_size == 0 {
                    value
                } else {
                    Poseidon::hashv(&[hash_chain_store.last().unwrap(), &value]).unwrap()
                };
                let result = batch.insert(
                    &value,
                    &value,
                    bloom_filter_stores.as_mut_slice(),
                    &mut hash_chain_store,
                    processing_index,
                    &current_slot,
                );
                // First insert should succeed
                assert!(result.is_ok(), "Failed result: {:?}", result);
                assert_eq!(*hash_chain_store.last().unwrap(), ref_hash_chain);

                {
                    // Work on copies so the failing reinsert cannot disturb the
                    // real batch or hash chain store (`Batch` is `Copy`).
                    let mut cloned_hash_chain_store = hash_chain_store_bytes.clone();
                    let mut hash_chain_store = ZeroCopyVecU64::<[u8; 32]>::from_bytes(
                        cloned_hash_chain_store.as_mut_slice(),
                    )
                    .unwrap();
                    let mut batch = batch;
                    // Reinsert should fail
                    assert!(batch
                        .insert(
                            &value,
                            &value,
                            bloom_filter_stores.as_mut_slice(),
                            &mut hash_chain_store,
                            processing_index,
                            &current_slot
                        )
                        .is_err());
                }
                let mut bloom_filter = BloomFilter {
                    num_iters: batch.num_iters as usize,
                    capacity: batch.bloom_filter_capacity,
                    store: bloom_filter_stores[processing_index],
                };
                assert!(bloom_filter.contains(&value));
                // The value must be absent from the other filter and present in its own.
                let other_index = if processing_index == 0 { 1 } else { 0 };
                Batch::check_non_inclusion(
                    batch.num_iters as usize,
                    batch.bloom_filter_capacity,
                    &value,
                    bloom_filter_stores[other_index],
                )
                .unwrap();
                Batch::check_non_inclusion(
                    batch.num_iters as usize,
                    batch.bloom_filter_capacity,
                    &value,
                    bloom_filter_stores[processing_index],
                )
                .unwrap_err();

                // Mirror the counter updates of `add_to_hash_chain` on the reference batch.
                ref_batch.num_inserted += 1;
                if ref_batch.num_inserted == ref_batch.zkp_batch_size {
                    ref_batch.num_full_zkp_batches += 1;
                    ref_batch.num_inserted = 0;
                }
                if i == batch.batch_size - 1 {
                    ref_batch.state = BatchState::Full.into();
                    ref_batch.num_inserted = 0;
                }
                assert_eq!(batch, ref_batch);
                current_slot += 1;
            }
        }
        test_mark_as_inserted(batch);
    }
686
    /// Adds two values to the hash chain of a fresh batch.
    ///
    /// The first value starts a new chain (stored verbatim). The second is
    /// hashed onto it with Poseidon. `num_inserted` is checked after each step.
    #[test]
    fn test_add_to_hash_chain() {
        let mut batch = get_test_batch();
        let mut hash_chain_store_bytes = vec![
            0u8;
            ZeroCopyVecU64::<[u8; 32]>::required_size_for_capacity(
                batch.get_num_hash_chain_store() as u64
            )
        ];
        let mut hash_chain_store = ZeroCopyVecU64::<[u8; 32]>::new(
            batch.get_num_hash_chain_store() as u64,
            hash_chain_store_bytes.as_mut_slice(),
        )
        .unwrap();
        let value = [1u8; 32];

        assert!(batch
            .add_to_hash_chain(&value, &mut hash_chain_store)
            .is_ok());
        // A new chain starts with the raw value itself.
        let mut ref_batch = get_test_batch();
        let user_hash_chain = value;
        ref_batch.num_inserted = 1;
        assert_eq!(batch, ref_batch);
        assert_eq!(hash_chain_store[0], user_hash_chain);
        // The second value is folded into the same chain: Poseidon(chain, value).
        let value = [2u8; 32];
        let ref_hash_chain = Poseidon::hashv(&[&user_hash_chain, &value]).unwrap();
        assert!(batch
            .add_to_hash_chain(&value, &mut hash_chain_store)
            .is_ok());

        ref_batch.num_inserted = 2;
        assert_eq!(batch, ref_batch);
        assert_eq!(hash_chain_store[0], ref_hash_chain);
    }
721
    /// For each bloom filter index, checks that a value is reported absent
    /// before insertion.
    ///
    /// After `insert` the value must be present in the target filter and
    /// still absent from the other filter.
    #[test]
    fn test_check_non_inclusion() {
        let mut current_slot = 1;
        for processing_index in 0..=1 {
            // Fresh batch and stores for each bloom filter index.
            let mut batch = get_test_batch();

            let value = [1u8; 32];
            let mut stores = vec![vec![0u8; 20_000]; 2];
            let mut bloom_filter_stores = stores
                .iter_mut()
                .map(|store| &mut store[..])
                .collect::<Vec<_>>();
            let mut hash_chain_store_bytes = vec![
            0u8;
            ZeroCopyVecU64::<[u8; 32]>::required_size_for_capacity(
                batch.get_num_hash_chain_store() as u64
            )
        ];
            let mut hash_chain_store = ZeroCopyVecU64::<[u8; 32]>::new(
                batch.get_num_hash_chain_store() as u64,
                hash_chain_store_bytes.as_mut_slice(),
            )
            .unwrap();

            // Empty filter: non-inclusion holds and the check does not modify the batch.
            assert_eq!(
                Batch::check_non_inclusion(
                    batch.num_iters as usize,
                    batch.bloom_filter_capacity,
                    &value,
                    bloom_filter_stores[processing_index]
                ),
                Ok(())
            );
            let ref_batch = get_test_batch();
            assert_eq!(batch, ref_batch);
            batch
                .insert(
                    &value,
                    &value,
                    bloom_filter_stores.as_mut_slice(),
                    &mut hash_chain_store,
                    processing_index,
                    &current_slot,
                )
                .unwrap();
            current_slot += 1;
            // After insertion the target filter contains the value.
            assert!(Batch::check_non_inclusion(
                batch.num_iters as usize,
                batch.bloom_filter_capacity,
                &value,
                bloom_filter_stores[processing_index]
            )
            .is_err());

            // The other filter is untouched.
            let other_index = if processing_index == 0 { 1 } else { 0 };
            assert!(Batch::check_non_inclusion(
                batch.num_iters as usize,
                batch.bloom_filter_capacity,
                &value,
                bloom_filter_stores[other_index]
            )
            .is_ok());
        }
    }
786
    /// Checks the getters on a fresh test batch (500 / 100 = 5 zkp batches,
    /// nothing inserted). Then walks the state machine Fill -> Full -> Inserted
    /// via the explicit transition methods.
    #[test]
    fn test_getters() {
        let mut batch = get_test_batch();
        assert_eq!(batch.get_num_zkp_batches(), 5);
        assert_eq!(batch.get_num_hash_chain_store(), 5);
        assert_eq!(batch.get_state(), BatchState::Fill);
        assert_eq!(batch.get_num_inserted_zkp_batch(), 0);
        assert_eq!(batch.get_current_zkp_batch_index(), 0);
        assert_eq!(batch.get_num_inserted_zkps(), 0);
        batch.advance_state_to_full().unwrap();
        assert_eq!(batch.get_state(), BatchState::Full);
        batch.advance_state_to_inserted().unwrap();
        assert_eq!(batch.get_state(), BatchState::Inserted);
    }
801
    /// Tests the bounds of `leaf_index_exists` on a batch with
    /// `start_index = 1` and 5 inserted elements (eligible range `[1, 6)`):
    /// 1. Failing test: lowest value in eligible range - 1
    /// 2. Functional test: lowest value in eligible range
    /// 3. Functional test: highest value in eligible range
    /// 4. Failing test: highest value in eligible range + 1
    #[test]
    fn test_value_is_inserted_in_batch() {
        let mut batch = get_test_batch();
        batch.advance_state_to_full().unwrap();
        batch.advance_state_to_inserted().unwrap();
        batch.start_index = 1;
        batch.num_inserted = 5;
        let lowest_eligible_value = batch.start_index;
        let highest_eligible_value = batch.start_index + batch.get_num_inserted_elements() - 1;
        // 1. Failing test lowest value in eligible range - 1
        assert!(!batch.leaf_index_exists(lowest_eligible_value - 1));
        // 2. Functional test lowest value in eligible range
        assert!(batch.leaf_index_exists(lowest_eligible_value));
        // 3. Functional test highest value in eligible range
        assert!(batch.leaf_index_exists(highest_eligible_value));
        // 4. Failing test eligible range + 1
        assert!(!batch.leaf_index_exists(highest_eligible_value + 1));
    }
825
826    /// 1. Failing: empty batch
827    /// 2. Functional: if zkp batch size is full else failing
828    /// 3. Failing: batch is completely inserted
829    #[test]
830    fn test_can_insert_batch() {
831        let mut batch = get_test_batch();
832        let mut current_slot = 1;
833        assert_eq!(
834            batch.get_first_ready_zkp_batch(),
835            Err(BatchedMerkleTreeError::BatchNotReady)
836        );
837        let mut value_store_bytes =
838            vec![0u8; ZeroCopyVecU64::<[u8; 32]>::required_size_for_capacity(batch.batch_size)];
839        let mut value_store =
840            ZeroCopyVecU64::<[u8; 32]>::new(batch.batch_size, &mut value_store_bytes).unwrap();
841        let mut hash_chain_store_bytes = vec![
842            0u8;
843            ZeroCopyVecU64::<[u8; 32]>::required_size_for_capacity(
844                batch.get_num_hash_chain_store() as u64
845            )
846        ];
847        let mut hash_chain_store = ZeroCopyVecU64::<[u8; 32]>::new(
848            batch.get_num_hash_chain_store() as u64,
849            hash_chain_store_bytes.as_mut_slice(),
850        )
851        .unwrap();
852
853        for i in 0..batch.batch_size + 10 {
854            let mut value = [0u8; 32];
855            value[24..].copy_from_slice(&i.to_be_bytes());
856            if i < batch.batch_size {
857                batch
858                    .store_and_hash_value(
859                        &value,
860                        &mut value_store,
861                        &mut hash_chain_store,
862                        &current_slot,
863                    )
864                    .unwrap();
865            }
866            if (i + 1) % batch.zkp_batch_size == 0 && i != 0 {
867                assert_eq!(
868                    batch.get_first_ready_zkp_batch().unwrap(),
869                    i / batch.zkp_batch_size
870                );
871                batch.mark_as_inserted_in_merkle_tree(0, 0, 0).unwrap();
872            } else if i >= batch.batch_size {
873                assert_eq!(
874                    batch.get_first_ready_zkp_batch(),
875                    Err(BatchedMerkleTreeError::BatchAlreadyInserted)
876                );
877            } else {
878                assert_eq!(
879                    batch.get_first_ready_zkp_batch(),
880                    Err(BatchedMerkleTreeError::BatchNotReady)
881                );
882            }
883            current_slot += 1;
884        }
885    }
886
887    #[test]
888    fn test_get_state() {
889        let mut batch = get_test_batch();
890        assert_eq!(batch.get_state(), BatchState::Fill);
891        {
892            let result = batch.advance_state_to_inserted();
893            assert_eq!(result, Err(BatchedMerkleTreeError::BatchNotReady));
894            let result = batch.advance_state_to_fill(None);
895            assert_eq!(result, Err(BatchedMerkleTreeError::BatchNotReady));
896        }
897        batch.advance_state_to_full().unwrap();
898        assert_eq!(batch.get_state(), BatchState::Full);
899        {
900            let result = batch.advance_state_to_full();
901            assert_eq!(result, Err(BatchedMerkleTreeError::BatchNotReady));
902            let result = batch.advance_state_to_fill(None);
903            assert_eq!(result, Err(BatchedMerkleTreeError::BatchNotReady));
904        }
905        batch.advance_state_to_inserted().unwrap();
906        assert_eq!(batch.get_state(), BatchState::Inserted);
907    }
908
909    #[test]
910    fn test_bloom_filter_is_zeroed() {
911        let mut batch = get_test_batch();
912        assert!(!batch.bloom_filter_is_zeroed());
913        batch.set_bloom_filter_to_zeroed();
914        assert!(batch.bloom_filter_is_zeroed());
915        batch.set_bloom_filter_to_not_zeroed();
916        assert!(!batch.bloom_filter_is_zeroed());
917    }
918
919    #[test]
920    fn test_num_ready_zkp_updates() {
921        let mut batch = get_test_batch();
922        assert_eq!(batch.get_num_ready_zkp_updates(), 0);
923        batch.num_full_zkp_batches = 1;
924        assert_eq!(batch.get_num_ready_zkp_updates(), 1);
925        batch.num_inserted_zkp_batches = 1;
926        assert_eq!(batch.get_num_ready_zkp_updates(), 0);
927        batch.num_full_zkp_batches = 2;
928        assert_eq!(batch.get_num_ready_zkp_updates(), 1);
929    }
930
931    #[test]
932    fn test_get_num_inserted_elements() {
933        let mut batch = get_test_batch();
934        assert_eq!(batch.get_num_inserted_elements(), 0);
935        let mut hash_chain_bytes = vec![0u8; 32 * batch.batch_size as usize];
936        let mut hash_chain_store = ZeroCopyVecU64::<[u8; 32]>::new(
937            batch.get_num_zkp_batches(),
938            hash_chain_bytes.as_mut_slice(),
939        )
940        .unwrap();
941
942        for i in 0..batch.batch_size {
943            let mut value = [0u8; 32];
944            value[24..].copy_from_slice(&i.to_be_bytes());
945            batch
946                .add_to_hash_chain(&value, &mut hash_chain_store)
947                .unwrap();
948            assert_eq!(batch.get_num_inserted_elements(), i + 1);
949        }
950    }
951
952    #[test]
953    fn test_get_num_elements_inserted_into_tree() {
954        let mut batch = get_test_batch();
955        assert_eq!(batch.get_num_elements_inserted_into_tree(), 0);
956        for i in 0..batch.get_num_zkp_batches() {
957            if i % batch.zkp_batch_size == 0 {
958                batch.num_full_zkp_batches += 1;
959                batch
960                    .mark_as_inserted_in_merkle_tree(i, i as u32, 0)
961                    .unwrap();
962                assert_eq!(
963                    batch.get_num_elements_inserted_into_tree(),
964                    (i + 1) * batch.zkp_batch_size
965                );
966            }
967        }
968    }
969
970    // Moved BatchedQueueAccount test to this file
971    // to modify private Batch variables for assertions.
972    #[test]
973    fn test_get_num_inserted() {
974        let mut account_data = vec![0u8; 1000];
975        let mut queue_metadata = QueueMetadata::default();
976        let associated_merkle_tree = Pubkey::new_unique();
977        queue_metadata.associated_merkle_tree = associated_merkle_tree;
978        queue_metadata.queue_type = QueueType::OutputStateV2 as u64;
979        let batch_size = 4;
980        let zkp_batch_size = 2;
981        let bloom_filter_capacity = 0;
982        let num_iters = 0;
983        let mut current_slot = 1;
984        let mut account = BatchedQueueAccount::init(
985            &mut account_data,
986            queue_metadata,
987            batch_size,
988            zkp_batch_size,
989            num_iters,
990            bloom_filter_capacity,
991            Pubkey::new_unique(),
992        )
993        .unwrap();
994        // Tree height 4 -> capacity 16
995        account.tree_capacity = 16;
996        assert_eq!(account.get_num_inserted_in_current_batch(), 0);
997        // Fill first batch
998        for i in 1..=batch_size {
999            account
1000                .insert_into_current_batch(&[1u8; 32], &current_slot)
1001                .unwrap();
1002            if i == batch_size {
1003                // Current batch is batch[1] now since batch[0] is full
1004                assert_eq!(account.get_num_inserted_in_current_batch(), 0);
1005                assert_eq!(
1006                    account.batch_metadata.batches[0].get_num_inserted_elements(),
1007                    i
1008                );
1009            } else {
1010                assert_eq!(account.get_num_inserted_in_current_batch(), i);
1011            }
1012            current_slot += 1;
1013        }
1014        println!("full batch 0 {:?}", account.batch_metadata.batches[0]);
1015
1016        // Fill second batch
1017        for i in 1..=batch_size {
1018            account
1019                .insert_into_current_batch(&[2u8; 32], &current_slot)
1020                .unwrap();
1021            if i == batch_size {
1022                // Current batch is batch[0] and it is still full
1023                assert_eq!(account.get_num_inserted_in_current_batch(), 4);
1024                assert_eq!(
1025                    account.batch_metadata.batches[1].get_num_inserted_elements(),
1026                    i
1027                );
1028            } else {
1029                assert_eq!(account.get_num_inserted_in_current_batch(), i);
1030            }
1031            current_slot += 1;
1032        }
1033        println!("account {:?}", account.batch_metadata);
1034        println!("account {:?}", account.batch_metadata.batches[0]);
1035        println!("account {:?}", account.batch_metadata.batches[1]);
1036        assert_eq!(account.get_num_inserted_in_current_batch(), batch_size);
1037        assert_eq!(
1038            account.insert_into_current_batch(&[1u8; 32], &current_slot),
1039            Err(BatchedMerkleTreeError::BatchNotReady)
1040        );
1041        let ref_value_array = vec![[1u8; 32]; 4];
1042        assert_eq!(account.value_vecs[0].as_slice(), ref_value_array.as_slice());
1043        let ref_value_array = vec![[2u8; 32]; 4];
1044        assert_eq!(account.value_vecs[1].as_slice(), ref_value_array.as_slice());
1045        assert_eq!(account.batch_metadata.get_current_batch().start_index, 0);
1046        {
1047            let batch_0 = account.batch_metadata.batches[0];
1048            let mut expected_batch = Batch::new(
1049                num_iters,
1050                bloom_filter_capacity,
1051                batch_size,
1052                zkp_batch_size,
1053                0,
1054            );
1055            expected_batch.num_full_zkp_batches = 2;
1056            expected_batch.start_slot = 1;
1057            expected_batch.start_slot_is_set = 1;
1058            expected_batch.advance_state_to_full().unwrap();
1059            assert_eq!(batch_0, expected_batch);
1060        }
1061        {
1062            let batch_1 = account.batch_metadata.batches[1];
1063            let mut expected_batch = Batch::new(
1064                num_iters,
1065                bloom_filter_capacity,
1066                batch_size,
1067                zkp_batch_size,
1068                batch_size,
1069            );
1070            expected_batch.num_full_zkp_batches = 2;
1071            expected_batch.start_slot = 1 + batch_size;
1072            expected_batch.start_slot_is_set = 1;
1073            expected_batch.advance_state_to_full().unwrap();
1074            assert_eq!(batch_1, expected_batch);
1075        }
1076        // Mark first batch as inserted
1077        {
1078            account.batch_metadata.batches[0]
1079                .advance_state_to_inserted()
1080                .unwrap();
1081        }
1082        // Check that batch is cleared properly.
1083        {
1084            assert_eq!(
1085                account.batch_metadata.get_current_batch().get_state(),
1086                BatchState::Inserted
1087            );
1088            account
1089                .insert_into_current_batch(&[1u8; 32], &current_slot)
1090                .unwrap();
1091            assert_eq!(account.value_vecs[0].as_slice(), [[1u8; 32]].as_slice());
1092            assert_eq!(account.value_vecs[1].as_slice(), ref_value_array.as_slice());
1093            assert_eq!(
1094                account.hash_chain_stores[0].as_slice(),
1095                [[1u8; 32]].as_slice()
1096            );
1097            assert_eq!(
1098                account.batch_metadata.get_current_batch().get_state(),
1099                BatchState::Fill
1100            );
1101            let mut expected_batch = Batch::new(
1102                num_iters,
1103                bloom_filter_capacity,
1104                batch_size,
1105                zkp_batch_size,
1106                batch_size * 2,
1107            );
1108
1109            assert_ne!(*account.batch_metadata.get_current_batch(), expected_batch);
1110            expected_batch.num_inserted = 1;
1111            expected_batch.start_slot_is_set = 1;
1112            expected_batch.start_slot = current_slot;
1113            assert_eq!(*account.batch_metadata.get_current_batch(), expected_batch);
1114
1115            assert_eq!(account.batch_metadata.get_current_batch().start_index, 8);
1116        }
1117        // Fill cleared batch
1118        {
1119            let expected_start_slot = current_slot;
1120            for i in 1..batch_size {
1121                assert_eq!(account.get_num_inserted_in_current_batch(), i);
1122                account
1123                    .insert_into_current_batch(&[1u8; 32], &current_slot)
1124                    .unwrap();
1125                current_slot += 1;
1126            }
1127            assert_eq!(account.get_num_inserted_in_current_batch(), batch_size);
1128            let mut expected_batch = Batch::new(
1129                num_iters,
1130                bloom_filter_capacity,
1131                batch_size,
1132                zkp_batch_size,
1133                batch_size * 2,
1134            );
1135
1136            expected_batch.num_full_zkp_batches = 2;
1137            expected_batch.advance_state_to_full().unwrap();
1138            expected_batch.start_slot = expected_start_slot;
1139            expected_batch.start_slot_is_set = 1;
1140            assert_eq!(account.batch_metadata.batches[0], expected_batch);
1141            assert_ne!(*account.batch_metadata.get_current_batch(), expected_batch);
1142            assert_eq!(
1143                *account.batch_metadata.get_current_batch(),
1144                account.batch_metadata.batches[1]
1145            );
1146        }
1147        assert_eq!(account.batch_metadata.next_index, 12);
1148        // Mark second batch as inserted
1149        account
1150            .batch_metadata
1151            .get_current_batch_mut()
1152            .advance_state_to_inserted()
1153            .unwrap();
1154
1155        {
1156            let expected_start_slot = current_slot;
1157            for _ in 0..batch_size {
1158                assert!(!account.tree_is_full());
1159                assert!(account.check_tree_is_full().is_ok());
1160                account
1161                    .insert_into_current_batch(&[1u8; 32], &current_slot)
1162                    .unwrap();
1163                current_slot += 1;
1164            }
1165            assert_eq!(account.get_num_inserted_in_current_batch(), batch_size);
1166            let mut expected_batch = Batch::new(
1167                num_iters,
1168                bloom_filter_capacity,
1169                batch_size,
1170                zkp_batch_size,
1171                batch_size * 3,
1172            );
1173            expected_batch.num_full_zkp_batches = 2;
1174            expected_batch.start_slot = expected_start_slot;
1175            expected_batch.start_slot_is_set = 1;
1176            expected_batch.advance_state_to_full().unwrap();
1177            assert_eq!(account.batch_metadata.batches[1], expected_batch);
1178        }
1179        assert_eq!(account.batch_metadata.next_index, 16);
1180        assert!(account.tree_is_full());
1181        assert_eq!(
1182            account.check_tree_is_full(),
1183            Err(BatchedMerkleTreeError::TreeIsFull)
1184        );
1185    }
1186}