libzeropool_rs/
merkle.rs

1use std::collections::HashMap;
2
3use borsh::{BorshDeserialize, BorshSerialize};
4use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
5use kvdb::{DBTransaction, KeyValueDB};
6use kvdb_memorydb::InMemory as MemoryDatabase;
7#[cfg(feature = "native")]
8use kvdb_persy::PersyDatabase as NativeDatabase;
9#[cfg(feature = "web")]
10use kvdb_web::Database as WebDatabase;
11use libzeropool::{
12    constants,
13    fawkes_crypto::{
14        core::sizedvec::SizedVec,
15        ff_uint::{Num, PrimeField},
16        native::poseidon::{poseidon, MerkleProof},
17    },
18    native::params::PoolParams,
19};
20use serde::{Deserialize, Serialize};
21
22use crate::utils::zero_note;
23
24pub type Hash<F> = Num<F>;
25
/// Total number of database columns.
///
/// Derived from the last [`DbCols`] variant so the column count used to open a database cannot
/// drift from the column enum.
const NUM_COLUMNS: u32 = DbCols::NextIndex as u32 + 1;

/// Key under which `next_index` is stored in the [`DbCols::NextIndex`] column.
const NEXT_INDEX_KEY: &[u8] = b"next_index";

/// Column layout of the tree's key-value database.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DbCols {
    /// Node hashes at every height, keyed by node key.
    Leaves = 0,
    /// Number of temporary leaves below a node, stored as a big-endian `u64`.
    TempLeaves = 1,
    /// Named big-endian `u64` bookkeeping values such as `clean_index`.
    NamedIndex = 2,
    /// The persisted `next_index`, stored under [`NEXT_INDEX_KEY`].
    NextIndex = 3,
}
34
35pub struct MerkleTree<D: KeyValueDB, P: PoolParams> {
36    db: D,
37    params: P,
38    default_hashes: Vec<Hash<P::Fr>>,
39    zero_note_hashes: Vec<Hash<P::Fr>>,
40    next_index: u64,
41}
42
/// Merkle tree backed by the `kvdb_persy` database (only with the `native` feature).
#[cfg(feature = "native")]
pub type NativeMerkleTree<P> = MerkleTree<NativeDatabase, P>;
45
/// Merkle tree backed by the `kvdb_web` database (only with the `web` feature).
#[cfg(feature = "web")]
pub type WebMerkleTree<P> = MerkleTree<WebDatabase, P>;
48
#[cfg(feature = "web")]
impl<P: PoolParams> MerkleTree<WebDatabase, P> {
    /// Opens (or creates) the `kvdb_web` database called `name` and wraps it in a tree.
    ///
    /// # Panics
    /// Panics if the database cannot be opened.
    pub async fn new_web(name: &str, params: P) -> MerkleTree<WebDatabase, P> {
        let opened = WebDatabase::open(name.to_owned(), NUM_COLUMNS).await;
        Self::new(opened.unwrap(), params)
    }
}
59
#[cfg(feature = "native")]
impl<P: PoolParams> MerkleTree<NativeDatabase, P> {
    /// Opens (or creates) a Persy-backed database at `path` and wraps it in a tree.
    ///
    /// The column count comes from [`NUM_COLUMNS`] rather than a duplicated literal, so the native
    /// backend always opens the same columns as the other backends.
    ///
    /// # Errors
    /// Returns any I/O error reported while opening the database.
    pub fn new_native(path: &str, params: P) -> std::io::Result<MerkleTree<NativeDatabase, P>> {
        // Big-endian `0u32` is the same prefix used when iterating height-0 keys; presumably the
        // backend builds a prefix index for it (TODO confirm against kvdb_persy).
        let prefix = (0u32).to_be_bytes();
        let db = NativeDatabase::open(path, NUM_COLUMNS, &[&prefix])?;

        Ok(Self::new(db, params))
    }
}
69
70impl<P: PoolParams> MerkleTree<MemoryDatabase, P> {
71    pub fn new_test(params: P) -> MerkleTree<MemoryDatabase, P> {
72        Self::new(kvdb_memorydb::create(NUM_COLUMNS), params)
73    }
74}
75
76// TODO: Proper error handling.
77impl<D: KeyValueDB, P: PoolParams> MerkleTree<D, P> {
78    pub fn new(db: D, params: P) -> Self {
79        let db_next_index = db.get(DbCols::NextIndex as u32, NEXT_INDEX_KEY);
80        let next_index = match db_next_index {
81            Ok(Some(next_index)) => next_index.as_slice().read_u64::<BigEndian>().unwrap(),
82            _ => {
83                let mut cur_next_index = 0;
84                for (k, _v) in db.iter(0).map(|res| res.unwrap()) {
85                    let (height, index) = Self::parse_node_key(&k);
86
87                    if height == 0 && index >= cur_next_index {
88                        cur_next_index = Self::calc_next_index(index);
89                    }
90                }
91                cur_next_index
92            }
93        };
94
95        MerkleTree {
96            db,
97            default_hashes: Self::gen_default_hashes(&params),
98            zero_note_hashes: Self::gen_empty_note_hashes(&params),
99            params,
100            next_index,
101        }
102    }
103
104    /// Add hash for an element with a certain index at a certain height
105    /// Set `temporary` to true if you want this leaf and all unneeded connected nodes to be removed
106    /// during cleanup.
107    pub fn add_hash_at_height(
108        &mut self,
109        height: u32,
110        index: u64,
111        hash: Hash<P::Fr>,
112        temporary: bool,
113    ) {
114        // todo: revert index change if update fails?
115        let next_index_was_updated = self.update_next_index_from_node(height, index);
116
117        if hash == self.zero_note_hashes[height as usize] && !next_index_was_updated {
118            return;
119        }
120
121        let mut batch = self.db.transaction();
122
123        // add leaf
124        let temporary_leaves_count = if temporary { 1 } else { 0 };
125        self.set_batched(&mut batch, height, index, hash, temporary_leaves_count);
126
127        // update inner nodes
128        self.update_path_batched(&mut batch, height, index, hash, temporary_leaves_count);
129
130        self.db.write(batch).unwrap();
131    }
132
133    pub fn add_hash(&mut self, index: u64, hash: Hash<P::Fr>, temporary: bool) {
134        self.add_hash_at_height(0, index, hash, temporary)
135    }
136
137    pub fn append_hash(&mut self, hash: Hash<P::Fr>, temporary: bool) -> u64 {
138        let index = self.next_index;
139        self.add_hash(index, hash, temporary);
140        index
141    }
142
143    pub fn add_leafs_and_commitments(
144        &mut self,
145        leafs: Vec<(u64, Vec<Hash<P::Fr>>)>,
146        commitments: Vec<(u64, Hash<P::Fr>)>,
147    ) {
148        if leafs.is_empty() && commitments.is_empty() {
149            return;
150        }
151
152        let mut next_index: u64 = 0;
153        let mut start_index: u64 = u64::MAX;
154        let mut virtual_nodes: HashMap<(u32, u64), Hash<P::Fr>> = commitments
155            .into_iter()
156            .map(|(index, hash)| {
157                assert_eq!(index & ((1 << constants::OUTPLUSONELOG) - 1), 0);
158                start_index = start_index.min(index);
159                next_index = next_index.max(index + 1);
160                (
161                    (
162                        constants::OUTPLUSONELOG as u32,
163                        index >> constants::OUTPLUSONELOG,
164                    ),
165                    hash,
166                )
167            })
168            .collect();
169
170        leafs.into_iter().for_each(|(index, leafs)| {
171            assert_eq!(index & ((1 << constants::OUTPLUSONELOG) - 1), 0);
172            start_index = start_index.min(index);
173            next_index = next_index.max(index + leafs.len() as u64);
174            (0..constants::OUTPLUSONELOG).for_each(|height| {
175                virtual_nodes.insert(
176                    (
177                        height as u32,
178                        ((index + leafs.len() as u64 - 1) >> height) + 1,
179                    ),
180                    self.zero_note_hashes[height],
181                );
182            });
183            leafs.into_iter().enumerate().for_each(|(i, leaf)| {
184                virtual_nodes.insert((0_u32, index + i as u64), leaf);
185            });
186        });
187
188        let original_next_index = self.next_index;
189        self.update_next_index_from_node(0, next_index);
190
191        let update_boundaries = UpdateBoundaries {
192            updated_range_left_index: original_next_index,
193            updated_range_right_index: self.next_index,
194            new_hashes_left_index: start_index,
195            new_hashes_right_index: next_index,
196        };
197
198        // calculate new hashes
199        self.get_virtual_node_full(
200            constants::HEIGHT as u32,
201            0,
202            &mut virtual_nodes,
203            &update_boundaries,
204        );
205
206        // add new hashes to tree
207        self.put_hashes(virtual_nodes);
208    }
209
210    pub fn add_hashes<I>(&mut self, start_index: u64, hashes: I)
211    where
212        I: IntoIterator<Item = Hash<P::Fr>>,
213    {
214        // check that index is correct
215        assert_eq!(start_index & ((1 << constants::OUTPLUSONELOG) - 1), 0);
216
217        let mut virtual_nodes: HashMap<(u32, u64), Hash<P::Fr>> = hashes
218            .into_iter()
219            // todo: check that there are no zero holes?
220            .filter(|hash| *hash != self.zero_note_hashes[0])
221            .enumerate()
222            .map(|(index, hash)| ((0, start_index + index as u64), hash))
223            .collect();
224        let new_hashes_count = virtual_nodes.len() as u64;
225
226        assert!(new_hashes_count <= (2u64 << constants::OUTPLUSONELOG));
227
228        let original_next_index = self.next_index;
229        self.update_next_index_from_node(0, start_index);
230
231        let update_boundaries = UpdateBoundaries {
232            updated_range_left_index: original_next_index,
233            updated_range_right_index: self.next_index,
234            new_hashes_left_index: start_index,
235            new_hashes_right_index: start_index + new_hashes_count,
236        };
237
238        // calculate new hashes
239        self.get_virtual_node_full(
240            constants::HEIGHT as u32,
241            0,
242            &mut virtual_nodes,
243            &update_boundaries,
244        );
245
246        // add new hashes to tree
247        self.put_hashes(virtual_nodes);
248    }
249
250    fn put_hashes(&mut self, virtual_nodes: HashMap<(u32, u64), Hash<<P as PoolParams>::Fr>>) {
251        let mut batch = self.db.transaction();
252
253        for ((height, index), value) in virtual_nodes {
254            self.set_batched(&mut batch, height, index, value, 0);
255        }
256
257        self.db.write(batch).unwrap();
258    }
259
260    // This method is used in tests.
261    #[cfg(test)]
262    fn add_subtree_root(&mut self, height: u32, index: u64, hash: Hash<P::Fr>) {
263        self.update_next_index_from_node(height, index);
264
265        let mut batch = self.db.transaction();
266
267        // add root
268        self.set_batched(&mut batch, height, index, hash, 1 << height);
269
270        // update path
271        self.update_path_batched(&mut batch, height, index, hash, 1 << height);
272
273        self.db.write(batch).unwrap();
274    }
275
276    pub fn get(&self, height: u32, index: u64) -> Hash<P::Fr> {
277        self.get_with_next_index(height, index, self.next_index)
278    }
279
280    fn get_with_next_index(&self, height: u32, index: u64, next_index: u64) -> Hash<P::Fr> {
281        match self.get_opt(height, index) {
282            Some(val) => val,
283            _ => {
284                let next_leave_index = u64::pow(2, height) * (index + 1);
285                if next_leave_index <= next_index {
286                    self.zero_note_hashes[height as usize]
287                } else {
288                    self.default_hashes[height as usize]
289                }
290            }
291        }
292    }
293
294    pub fn last_leaf(&self) -> Hash<P::Fr> {
295        // todo: can last leaf be an zero note?
296        match self.get_opt(0, self.next_index.saturating_sub(1)) {
297            Some(val) => val,
298            _ => self.default_hashes[0],
299        }
300    }
301
302    pub fn get_root(&self) -> Hash<P::Fr> {
303        self.get(constants::HEIGHT as u32, 0)
304    }
305
306    pub fn get_root_after_virtual<I>(&self, new_commitments: I) -> Hash<P::Fr>
307    where
308        I: IntoIterator<Item = Hash<P::Fr>>,
309    {
310        let next_leaf_index = self.next_index;
311        let next_commitment_index = next_leaf_index / 2u64.pow(constants::OUTPLUSONELOG as u32);
312        let index_step = constants::OUT as u64 + 1;
313
314        let mut virtual_commitment_nodes: HashMap<(u32, u64), Hash<P::Fr>> = new_commitments
315            .into_iter()
316            .enumerate()
317            .map(|(index, hash)| {
318                (
319                    (
320                        constants::OUTPLUSONELOG as u32,
321                        next_commitment_index + index as u64,
322                    ),
323                    hash,
324                )
325            })
326            .collect();
327        let new_commitments_count = virtual_commitment_nodes.len() as u64;
328
329        self.get_virtual_node(
330            constants::HEIGHT as u32,
331            0,
332            &mut virtual_commitment_nodes,
333            next_leaf_index,
334            next_leaf_index + new_commitments_count * index_step,
335        )
336    }
337
338    pub fn get_root_optimistic(
339        &self,
340        virtual_nodes: &mut HashMap<(u32, u64), Hash<P::Fr>>,
341        update_boundaries: &UpdateBoundaries,
342    ) -> Hash<P::Fr> {
343        self.get_virtual_node_full(
344            constants::HEIGHT as u32,
345            0,
346            virtual_nodes,
347            &update_boundaries,
348        )
349    }
350
351    pub fn get_opt(&self, height: u32, index: u64) -> Option<Hash<P::Fr>> {
352        assert!(height <= constants::HEIGHT as u32);
353
354        let key = Self::node_key(height, index);
355        let res = self.db.get(0, &key);
356
357        match res {
358            Ok(Some(ref val)) => Some(Hash::<P::Fr>::try_from_slice(val).unwrap()),
359            _ => None,
360        }
361    }
362
363    pub fn get_proof_unchecked<const H: usize>(&self, index: u64) -> MerkleProof<P::Fr, { H }> {
364        let mut sibling: SizedVec<_, { H }> = (0..H).map(|_| Num::ZERO).collect();
365        let mut path: SizedVec<_, { H }> = (0..H).map(|_| false).collect();
366
367        let start_height = constants::HEIGHT - H;
368
369        sibling.iter_mut().zip(path.iter_mut()).enumerate().fold(
370            index,
371            |x, (h, (sibling, is_right))| {
372                let cur_height = (start_height + h) as u32;
373                *is_right = x % 2 == 1;
374                *sibling = self.get(cur_height, x ^ 1);
375
376                x / 2
377            },
378        );
379
380        MerkleProof { sibling, path }
381    }
382
383    pub fn get_leaf_proof(&self, index: u64) -> Option<MerkleProof<P::Fr, { constants::HEIGHT }>> {
384        let key = Self::node_key(0, index);
385        let node_present = self.db.get(0, &key).map_or(false, |value| value.is_some());
386        if !node_present {
387            return None;
388        }
389        Some(self.get_proof_unchecked(index))
390    }
391
392    // This method is used in tests.
393    #[cfg(test)]
394    fn get_proof_after<I>(
395        &mut self,
396        new_hashes: I,
397    ) -> Vec<MerkleProof<P::Fr, { constants::HEIGHT }>>
398    where
399        I: IntoIterator<Item = Hash<P::Fr>>,
400    {
401        let new_hashes: Vec<_> = new_hashes.into_iter().collect();
402        let size = new_hashes.len() as u64;
403
404        // TODO: Optimize, no need to mutate the database.
405        let index_offset = self.next_index;
406        self.add_hashes(index_offset, new_hashes);
407
408        let proofs = (index_offset..index_offset + size)
409            .map(|index| {
410                self.get_leaf_proof(index)
411                    .expect("Leaf was expected to be present (bug)")
412            })
413            .collect();
414
415        // Restore next_index.
416        self.next_index = index_offset;
417        // FIXME: Not all nodes are deleted here
418        for index in index_offset..index_offset + size {
419            self.remove_leaf(index);
420        }
421
422        proofs
423    }
424
425    pub fn get_proof_after_virtual<I>(
426        &self,
427        new_hashes: I,
428    ) -> Vec<MerkleProof<P::Fr, { constants::HEIGHT }>>
429    where
430        I: IntoIterator<Item = Hash<P::Fr>>,
431    {
432        let index_offset = self.next_index;
433
434        let mut virtual_nodes: HashMap<(u32, u64), Hash<P::Fr>> = new_hashes
435            .into_iter()
436            .enumerate()
437            .map(|(index, hash)| ((0, index_offset + index as u64), hash))
438            .collect();
439        let new_hashes_count = virtual_nodes.len() as u64;
440
441        let update_boundaries = UpdateBoundaries {
442            updated_range_left_index: index_offset,
443            updated_range_right_index: Self::calc_next_index(index_offset),
444            new_hashes_left_index: index_offset,
445            new_hashes_right_index: index_offset + new_hashes_count,
446        };
447
448        (index_offset..index_offset + new_hashes_count)
449            .map(|index| self.get_proof_virtual(index, &mut virtual_nodes, &update_boundaries))
450            .collect()
451    }
452
453    pub fn get_proof_virtual_index<I>(
454        &self,
455        index: u64,
456        new_hashes: I,
457    ) -> Option<MerkleProof<P::Fr, { constants::HEIGHT }>>
458    where
459        I: IntoIterator<Item = Hash<P::Fr>>,
460    {
461        let index_offset = self.next_index;
462
463        let mut virtual_nodes: HashMap<(u32, u64), Hash<P::Fr>> = new_hashes
464            .into_iter()
465            .enumerate()
466            .map(|(index, hash)| ((0, index_offset + index as u64), hash))
467            .collect();
468        let new_hashes_count = virtual_nodes.len() as u64;
469
470        let update_boundaries = UpdateBoundaries {
471            updated_range_left_index: index_offset,
472            updated_range_right_index: index_offset + new_hashes_count,
473            new_hashes_left_index: index_offset,
474            new_hashes_right_index: index_offset + new_hashes_count,
475        };
476
477        Some(self.get_proof_virtual(index, &mut virtual_nodes, &update_boundaries))
478    }
479
480    pub fn get_virtual_subtree<I1, I2>(
481        &self,
482        new_hashes: I1,
483        new_commitments: I2,
484    ) -> (HashMap<(u32, u64), Hash<P::Fr>>, UpdateBoundaries)
485    where
486        I1: IntoIterator<Item = (u64, Vec<Hash<P::Fr>>)>,
487        I2: IntoIterator<Item = (u64, Hash<P::Fr>)>,
488    {
489        let mut next_index: Option<u64> = None;
490        let mut start_index: Option<u64> = None;
491        let mut virtual_nodes: HashMap<(u32, u64), Hash<P::Fr>> = new_commitments
492            .into_iter()
493            .map(|(index, hash)| {
494                assert_eq!(index & ((1 << constants::OUTPLUSONELOG) - 1), 0);
495                start_index = Some(start_index.unwrap_or(u64::MAX).min(index));
496                next_index = Some(next_index.unwrap_or(0).max(index + 1));
497                (
498                    (
499                        constants::OUTPLUSONELOG as u32,
500                        index >> constants::OUTPLUSONELOG,
501                    ),
502                    hash,
503                )
504            })
505            .collect();
506
507        new_hashes.into_iter().for_each(|(index, leafs)| {
508            assert_eq!(index & ((1 << constants::OUTPLUSONELOG) - 1), 0);
509            start_index = Some(start_index.unwrap_or(u64::MAX).min(index));
510            next_index = Some(next_index.unwrap_or(0).max(index + leafs.len() as u64));
511            (0..constants::OUTPLUSONELOG).for_each(|height| {
512                virtual_nodes.insert(
513                    (
514                        height as u32,
515                        ((index + leafs.len() as u64 - 1) >> height) + 1,
516                    ),
517                    self.zero_note_hashes[height],
518                );
519            });
520            leafs.into_iter().enumerate().for_each(|(i, leaf)| {
521                virtual_nodes.insert((0_u32, index + i as u64), leaf);
522            });
523        });
524
525        let update_boundaries = {
526            if let (Some(start_index), Some(next_index)) = (start_index, next_index) {
527                UpdateBoundaries {
528                    updated_range_left_index: self.next_index,
529                    updated_range_right_index: Self::calc_next_index(next_index),
530                    new_hashes_left_index: start_index,
531                    new_hashes_right_index: next_index,
532                }
533            } else {
534                UpdateBoundaries {
535                    updated_range_left_index: self.next_index,
536                    updated_range_right_index: self.next_index,
537                    new_hashes_left_index: self.next_index,
538                    new_hashes_right_index: self.next_index,
539                }
540            }
541        };
542
543        // calculate new hashes
544        self.get_virtual_node_full(
545            constants::HEIGHT as u32,
546            0,
547            &mut virtual_nodes,
548            &update_boundaries,
549        );
550
551        (virtual_nodes, update_boundaries)
552    }
553
554    pub fn get_proof_optimistic_index(
555        &self,
556        index: u64,
557        virtual_nodes: &mut HashMap<(u32, u64), Hash<P::Fr>>,
558        update_boundaries: &UpdateBoundaries,
559    ) -> Option<MerkleProof<P::Fr, { constants::HEIGHT }>> {
560        Some(self.get_proof_virtual(index, virtual_nodes, update_boundaries))
561    }
562
563    fn get_proof_virtual<const H: usize>(
564        &self,
565        index: u64,
566        virtual_nodes: &mut HashMap<(u32, u64), Hash<P::Fr>>,
567        update_boundaries: &UpdateBoundaries,
568    ) -> MerkleProof<P::Fr, { H }> {
569        let mut sibling: SizedVec<_, { H }> = (0..H).map(|_| Num::ZERO).collect();
570        let mut path: SizedVec<_, { H }> = (0..H).map(|_| false).collect();
571
572        let start_height = constants::HEIGHT - H;
573
574        sibling.iter_mut().zip(path.iter_mut()).enumerate().fold(
575            index,
576            |x, (h, (sibling, is_right))| {
577                let cur_height = (start_height + h) as u32;
578                *is_right = x % 2 == 1;
579                *sibling =
580                    self.get_virtual_node_full(cur_height, x ^ 1, virtual_nodes, update_boundaries);
581
582                x / 2
583            },
584        );
585
586        MerkleProof { sibling, path }
587    }
588
589    pub fn get_virtual_node(
590        &self,
591        height: u32,
592        index: u64,
593        virtual_nodes: &mut HashMap<(u32, u64), Hash<P::Fr>>,
594        new_hashes_left_index: u64,
595        new_hashes_right_index: u64,
596    ) -> Hash<P::Fr> {
597        let update_boundaries = UpdateBoundaries {
598            updated_range_left_index: new_hashes_left_index,
599            updated_range_right_index: new_hashes_right_index,
600            new_hashes_left_index,
601            new_hashes_right_index,
602        };
603
604        self.get_virtual_node_full(height, index, virtual_nodes, &update_boundaries)
605    }
606
607    fn get_virtual_node_full(
608        &self,
609        height: u32,
610        index: u64,
611        virtual_nodes: &mut HashMap<(u32, u64), Hash<P::Fr>>,
612        update_boundaries: &UpdateBoundaries,
613    ) -> Hash<P::Fr> {
614        let node_left = index * (1 << height);
615        let node_right = (index + 1) * (1 << height);
616        if node_right <= update_boundaries.updated_range_left_index
617            || update_boundaries.updated_range_right_index <= node_left
618        {
619            return self.get(height, index);
620        }
621        if (node_right <= update_boundaries.new_hashes_left_index
622            || update_boundaries.new_hashes_right_index <= node_left)
623            && update_boundaries.updated_range_left_index <= node_left
624            && node_right <= update_boundaries.updated_range_right_index
625        {
626            return self.zero_note_hashes[height as usize];
627        }
628
629        let key = (height, index);
630        match virtual_nodes.get(&key) {
631            Some(hash) => *hash,
632            None => {
633                let left_child = self.get_virtual_node_full(
634                    height - 1,
635                    2 * index,
636                    virtual_nodes,
637                    update_boundaries,
638                );
639                let right_child = self.get_virtual_node_full(
640                    height - 1,
641                    2 * index + 1,
642                    virtual_nodes,
643                    update_boundaries,
644                );
645                let pair = [left_child, right_child];
646                let hash = poseidon(pair.as_ref(), self.params.compress());
647                virtual_nodes.insert(key, hash);
648
649                hash
650            }
651        }
652    }
653
654    pub fn clean(&mut self) -> u64 {
655        self.clean_before_index(u64::MAX)
656    }
657
658    pub fn clean_before_index(&mut self, clean_before_index: u64) -> u64 {
659        let mut batch = self.db.transaction();
660
661        // get all nodes
662        // todo: improve performance?
663        let keys: Vec<(u32, u64)> = self
664            .db
665            .iter(0)
666            .map(|res| Self::parse_node_key(&res.unwrap().0))
667            .collect();
668        // remove unnecessary nodes
669        for (height, index) in keys {
670            // leaves have no children
671            if height == 0 {
672                continue;
673            }
674
675            // remove only nodes before specified index
676            if (index + 1) * (1 << height) > clean_before_index {
677                continue;
678            }
679
680            if self.subtree_contains_only_temporary_leaves(height, index) {
681                // all leaves in subtree are temporary, we can keep only subtree root
682                self.remove_batched(&mut batch, height - 1, 2 * index);
683                self.remove_batched(&mut batch, height - 1, 2 * index + 1);
684            }
685        }
686
687        self.set_clean_index_batched(&mut batch, clean_before_index);
688
689        self.db.write(batch).unwrap();
690
691        self.next_index
692    }
693
694    pub fn rollback(&mut self, rollback_index: u64) -> Option<u64> {
695        let mut result: Option<u64> = None;
696
697        // check that nodes that are necessary for rollback were not removed by clean
698        let clean_index = self.get_clean_index();
699        if rollback_index < clean_index {
700            // find what nodes are missing
701            let mut nodes_request_index = self.next_index;
702            let mut index = rollback_index;
703            for height in 0..constants::HEIGHT as u32 {
704                let sibling_index = index ^ 1;
705                if sibling_index < index
706                    && !self.subtree_contains_only_temporary_leaves(height, sibling_index)
707                {
708                    let leaf_index = index * (1 << height);
709                    if leaf_index < nodes_request_index {
710                        nodes_request_index = leaf_index
711                    }
712                }
713                index /= 2;
714            }
715            if nodes_request_index < clean_index {
716                result = Some(nodes_request_index)
717            }
718        }
719
720        // Update next_index.
721        let original_next_index = self.next_index;
722        self.next_index = if rollback_index > 0 {
723            Self::calc_next_index(rollback_index - 1)
724        } else {
725            0
726        };
727        // remove leaves
728        for index in (rollback_index..original_next_index).rev() {
729            self.remove_leaf(index);
730        }
731
732        result
733    }
734
735    pub fn get_all_nodes(&self) -> Vec<Node<P::Fr>> {
736        self.db
737            .iter(0)
738            .map(|res| {
739                let (key, value) = res.unwrap();
740                Self::build_node(&key, &value)
741            })
742            .collect()
743    }
744
745    pub fn get_leaves(&self) -> Vec<Node<P::Fr>> {
746        self.get_leaves_after(0)
747    }
748
749    pub fn get_leaves_after(&self, index: u64) -> Vec<Node<P::Fr>> {
750        let prefix = (0u32).to_be_bytes();
751        self.db
752            .iter_with_prefix(0, &prefix)
753            .map(|res| {
754                let (key, value) = res.unwrap();
755                Self::build_node(&key, &value)
756            })
757            .filter(|node| node.index >= index)
758            .collect()
759    }
760
761    pub fn next_index(&self) -> u64 {
762        self.next_index
763    }
764
765    fn update_next_index(&mut self, next_index: u64) -> bool {
766        if next_index >= self.next_index {
767            let mut transaction = self.db.transaction();
768            let mut data = [0u8; 8];
769            {
770                let mut bytes = &mut data[..];
771                let _ = bytes.write_u64::<BigEndian>(next_index);
772            }
773            transaction.put(DbCols::NextIndex as u32, NEXT_INDEX_KEY, &data);
774            self.db.write(transaction).unwrap();
775
776            self.next_index = next_index;
777            true
778        } else {
779            false
780        }
781    }
782
783    fn update_next_index_from_node(&mut self, height: u32, index: u64) -> bool {
784        let leaf_index = u64::pow(2, height) * (index + 1) - 1;
785        self.update_next_index(Self::calc_next_index(leaf_index))
786    }
787
788    #[inline]
789    fn calc_next_index(leaf_index: u64) -> u64 {
790        ((leaf_index >> constants::OUTPLUSONELOG) + 1) << constants::OUTPLUSONELOG
791    }
792
793    fn update_path_batched(
794        &mut self,
795        batch: &mut DBTransaction,
796        height: u32,
797        index: u64,
798        hash: Hash<P::Fr>,
799        temporary_leaves_count: u64,
800    ) {
801        let mut child_index = index;
802        let mut child_hash = hash;
803        let mut child_temporary_leaves_count = temporary_leaves_count;
804        // todo: improve
805        for current_height in height + 1..=constants::HEIGHT as u32 {
806            let parent_index = child_index / 2;
807
808            // get pair of children
809            let second_child_index = child_index ^ 1;
810
811            // compute hash
812            let pair = if child_index % 2 == 0 {
813                [child_hash, self.get(current_height - 1, second_child_index)]
814            } else {
815                [self.get(current_height - 1, second_child_index), child_hash]
816            };
817            let hash = poseidon(pair.as_ref(), self.params.compress());
818
819            // compute temporary leaves count
820            let second_child_temporary_leaves_count =
821                self.get_temporary_count(current_height - 1, second_child_index);
822            let parent_temporary_leaves_count =
823                child_temporary_leaves_count + second_child_temporary_leaves_count;
824
825            self.set_batched(
826                batch,
827                current_height,
828                parent_index,
829                hash,
830                parent_temporary_leaves_count,
831            );
832
833            /*if parent_temporary_leaves_count == (1 << current_height) {
834                // all leaves in subtree are temporary, we can keep only subtree root
835                self.remove_batched(batch, current_height - 1, child_index);
836                self.remove_batched(batch, current_height - 1, second_child_index);
837            }*/
838
839            child_index = parent_index;
840            child_hash = hash;
841            child_temporary_leaves_count = parent_temporary_leaves_count;
842        }
843    }
844
845    fn set_batched(
846        &mut self,
847        batch: &mut DBTransaction,
848        height: u32,
849        index: u64,
850        hash: Hash<P::Fr>,
851        temporary_leaves_count: u64,
852    ) {
853        let key = Self::node_key(height, index);
854        if hash != self.zero_note_hashes[height as usize] {
855            batch.put(DbCols::Leaves as u32, &key, &hash.try_to_vec().unwrap());
856        } else {
857            batch.delete(DbCols::Leaves as u32, &key);
858        }
859        if temporary_leaves_count > 0 {
860            batch.put(
861                DbCols::TempLeaves as u32,
862                &key,
863                &temporary_leaves_count.to_be_bytes(),
864            );
865        } else if self
866            .db
867            .has_key(DbCols::TempLeaves as u32, &key)
868            .unwrap_or(false)
869        {
870            batch.delete(DbCols::TempLeaves as u32, &key);
871        }
872    }
873
874    fn remove_batched(&mut self, batch: &mut DBTransaction, height: u32, index: u64) {
875        let key = Self::node_key(height, index);
876        batch.delete(DbCols::Leaves as u32, &key);
877        batch.delete(DbCols::TempLeaves as u32, &key);
878    }
879
880    fn remove_leaf(&mut self, index: u64) {
881        let mut batch = self.db.transaction();
882
883        self.remove_batched(&mut batch, 0, index);
884        self.update_path_batched(&mut batch, 0, index, self.default_hashes[0], 0);
885
886        self.db.write(batch).unwrap();
887    }
888
889    fn get_clean_index(&self) -> u64 {
890        match self.get_named_index_opt("clean_index") {
891            Some(val) => val,
892            _ => 0,
893        }
894    }
895
896    fn set_clean_index_batched(&mut self, batch: &mut DBTransaction, value: u64) {
897        self.set_named_index_batched(batch, "clean_index", value);
898    }
899
900    fn get_named_index_opt(&self, key: &str) -> Option<u64> {
901        let res = self.db.get(2, key.as_bytes());
902        match res {
903            Ok(Some(ref val)) => Some((&val[..]).read_u64::<BigEndian>().unwrap()),
904            _ => None,
905        }
906    }
907
908    fn set_named_index_batched(&mut self, batch: &mut DBTransaction, key: &str, value: u64) {
909        batch.put(
910            DbCols::NamedIndex as u32,
911            key.as_bytes(),
912            &value.to_be_bytes(),
913        );
914    }
915
916    fn get_temporary_count(&self, height: u32, index: u64) -> u64 {
917        match self.get_temporary_count_opt(height, index) {
918            Some(val) => val,
919            _ => 0,
920        }
921    }
922
923    fn get_temporary_count_opt(&self, height: u32, index: u64) -> Option<u64> {
924        assert!(height <= constants::HEIGHT as u32);
925
926        let key = Self::node_key(height, index);
927        let res = self.db.get(1, &key);
928
929        match res {
930            Ok(Some(ref val)) => Some((&val[..]).read_u64::<BigEndian>().unwrap()),
931            _ => None,
932        }
933    }
934
935    fn subtree_contains_only_temporary_leaves(&self, height: u32, index: u64) -> bool {
936        self.get_temporary_count(height, index) == (1 << height)
937    }
938
939    #[inline]
940    fn node_key(height: u32, index: u64) -> [u8; 12] {
941        let mut data = [0u8; 12];
942        {
943            let mut bytes = &mut data[..];
944            let _ = bytes.write_u32::<BigEndian>(height);
945            let _ = bytes.write_u64::<BigEndian>(index);
946        }
947
948        data
949    }
950
951    fn parse_node_key(data: &[u8]) -> (u32, u64) {
952        let mut bytes = data;
953        let height = bytes.read_u32::<BigEndian>().unwrap();
954        let index = bytes.read_u64::<BigEndian>().unwrap();
955
956        (height, index)
957    }
958
959    fn build_node(key: &[u8], value: &[u8]) -> Node<P::Fr> {
960        let (height, index) = Self::parse_node_key(key);
961        let value = Hash::try_from_slice(value).unwrap();
962
963        Node {
964            index,
965            height,
966            value,
967        }
968    }
969
970    fn gen_default_hashes(params: &P) -> Vec<Hash<P::Fr>> {
971        let mut default_hashes = vec![Num::ZERO; constants::HEIGHT + 1];
972
973        Self::fill_default_hashes(&mut default_hashes, params);
974
975        default_hashes
976    }
977
978    fn gen_empty_note_hashes(params: &P) -> Vec<Hash<P::Fr>> {
979        let empty_note_hash = zero_note().hash(params);
980
981        let mut empty_note_hashes = vec![empty_note_hash; constants::HEIGHT + 1];
982
983        Self::fill_default_hashes(&mut empty_note_hashes, params);
984
985        empty_note_hashes
986    }
987
988    fn fill_default_hashes(default_hashes: &mut Vec<Hash<P::Fr>>, params: &P) {
989        for i in 1..default_hashes.len() {
990            let t = default_hashes[i - 1];
991            default_hashes[i] = poseidon([t, t].as_ref(), params.compress());
992        }
993    }
994}
995
/// A single tree node with its position (`height`, `index`) and stored hash `value`.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct Node<F: PrimeField> {
    /// Position of the node within its level.
    pub index: u64,
    /// Level of the node; leaves live at height 0.
    pub height: u32,
    /// Hash stored at this node.
    // The empty serde bounds stop the derive from adding `F: Serialize/Deserialize`
    // requirements; presumably `Num<F>` (de)serializes on its own — confirm.
    #[serde(bound(serialize = "", deserialize = ""))]
    pub value: Num<F>,
}
1003
/// Index boundaries describing a pending (not yet committed) tree update.
///
/// Presumably produced by `get_virtual_subtree` together with the virtual nodes and
/// passed to `get_root_optimistic`. NOTE(review): confirm whether the bounds are
/// inclusive or exclusive before relying on them.
#[derive(Debug)]
pub struct UpdateBoundaries {
    /// Left bound of the range of existing leaves touched by the update.
    updated_range_left_index: u64,
    /// Right bound of the range of existing leaves touched by the update.
    updated_range_right_index: u64,
    /// Left bound of the range occupied by the new hashes.
    new_hashes_left_index: u64,
    /// Right bound of the range occupied by the new hashes.
    new_hashes_right_index: u64,
}
1010
#[cfg(test)]
mod tests {
    // Only the persy-backed `init` uses `AtomicUsize`; gating the import avoids an
    // unused-import warning in non-native builds.
    #[cfg(feature = "native")]
    use std::sync::atomic::AtomicUsize;

    #[cfg(not(feature = "native"))]
    use kvdb_memorydb::InMemory as Database;
    #[cfg(feature = "native")]
    use kvdb_persy::PersyDatabase as Database;
    use libzeropool::{
        fawkes_crypto::ff_uint::rand::Rng,
        native::{params::PoolBN256, tx},
        POOL_PARAMS,
    };
    use rand::{seq::SliceRandom, thread_rng};
    use test_case::test_case;

    use super::*;
    use crate::random::CustomRng;

    /// A tree under test. With the `native` feature the tree is backed by a persy
    /// file whose path is kept so that `Drop` can delete it.
    struct TestContext {
        tree: MerkleTree<Database, PoolBN256>,
        #[cfg(feature = "native")]
        db_path: String,
    }

    #[cfg(feature = "native")]
    impl Drop for TestContext {
        fn drop(&mut self) {
            std::fs::remove_file(&self.db_path).unwrap();
        }
    }

    /// Creates a tree backed by a fresh persy file. A process-wide counter keeps the
    /// file names unique between tests.
    #[cfg(feature = "native")]
    fn init() -> TestContext {
        static FILE_COUNTER: AtomicUsize = AtomicUsize::new(0);
        let file_counter = FILE_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let path = format!("merkle-test-{}.persy", file_counter);
        let tree = MerkleTree::new_native(&path, POOL_PARAMS.clone()).unwrap();

        TestContext {
            tree,
            db_path: path,
        }
    }

    /// Creates a tree backed by an in-memory database.
    #[cfg(not(feature = "native"))]
    fn init() -> TestContext {
        // Allocate every column the tree uses (`NUM_COLUMNS`, matching `DbCols`).
        // A smaller count would leave out `DbCols::NextIndex` (column 3).
        let db = kvdb_memorydb::create(NUM_COLUMNS);

        TestContext {
            tree: MerkleTree::new(db, POOL_PARAMS.clone()),
        }
    }

    #[test]
    fn test_add_hashes_first_3() {
        let mut rng = CustomRng;
        let tree = &mut init().tree;
        let hashes: Vec<_> = (0..3).map(|_| rng.gen()).collect();
        tree.add_hashes(0, hashes.clone());

        let nodes = tree.get_all_nodes();
        assert_eq!(nodes.len(), constants::HEIGHT + 4);

        for h in 0..constants::HEIGHT as u32 {
            assert!(tree.get_opt(h, 0).is_some()); // TODO: Compare with expected hash
        }

        for (i, hash) in hashes.into_iter().enumerate() {
            assert_eq!(tree.get(0, i as u64), hash);
        }
    }

    #[test]
    fn test_add_hashes_last_3() {
        let mut rng = CustomRng;
        let tree = &mut init().tree;

        let max_index = (1 << constants::HEIGHT) - 1;
        let hashes: Vec<_> = (0..3).map(|_| rng.gen()).collect();
        tree.add_hashes(max_index - 127, hashes.clone());

        let nodes = tree.get_all_nodes();
        assert_eq!(nodes.len(), constants::HEIGHT + 4);

        for h in constants::OUTPLUSONELOG as u32 + 1..constants::HEIGHT as u32 {
            let index = max_index / 2u64.pow(h);
            assert!(tree.get_opt(h, index).is_some()); // TODO: Compare with expected hash
        }

        for (i, hash) in hashes.into_iter().enumerate() {
            assert_eq!(tree.get(0, max_index - 127 + i as u64), hash);
        }
    }

    #[test]
    fn test_add_hashes() {
        let mut tree_expected = &mut init().tree;
        let mut tree_actual = &mut init().tree;

        // add first subtree
        add_hashes_to_test_trees(&mut tree_expected, &mut tree_actual, 0, 3);
        check_trees_are_equal(&tree_expected, &tree_actual);

        // add second subtree
        add_hashes_to_test_trees(&mut tree_expected, &mut tree_actual, 128, 8);
        check_trees_are_equal(&tree_expected, &tree_actual);

        // add third subtree
        add_hashes_to_test_trees(&mut tree_expected, &mut tree_actual, 256, 1);
        check_trees_are_equal(&tree_expected, &tree_actual);
    }

    #[test]
    fn test_add_hashes_with_gap() {
        let mut tree_expected = &mut init().tree;
        let mut tree_actual = &mut init().tree;

        // add first subtree
        add_hashes_to_test_trees(&mut tree_expected, &mut tree_actual, 0, 3);
        check_trees_are_equal(&tree_expected, &tree_actual);

        tree_expected.add_hash_at_height(
            constants::OUTPLUSONELOG as u32,
            1,
            tree_expected.zero_note_hashes[constants::OUTPLUSONELOG],
            false,
        );

        // add third subtree, second subtree contains zero node hashes
        add_hashes_to_test_trees(&mut tree_expected, &mut tree_actual, 256, 7);
        check_trees_are_equal(&tree_expected, &tree_actual);
    }

    /// Adds `count` random hashes starting at `start_index` to both trees.
    /// `tree_expected` gets them one by one via `add_hash`; `tree_actual` gets them
    /// in a single `add_hashes` call, so the two code paths can be compared.
    fn add_hashes_to_test_trees<D: KeyValueDB, P: PoolParams>(
        tree_expected: &mut MerkleTree<D, P>,
        tree_actual: &mut MerkleTree<D, P>,
        start_index: u64,
        count: u64,
    ) {
        let mut rng = CustomRng;

        let hashes: Vec<_> = (0..count).map(|_| rng.gen()).collect();

        for (i, hash) in hashes.clone().into_iter().enumerate() {
            tree_expected.add_hash(start_index + i as u64, hash, false);
        }
        tree_actual.add_hashes(start_index, hashes);
    }

    /// Asserts that both trees have the same `next_index`, the same root and the same
    /// set of stored nodes (compared order-independently).
    fn check_trees_are_equal<D: KeyValueDB, P: PoolParams>(
        tree_first: &MerkleTree<D, P>,
        tree_second: &MerkleTree<D, P>,
    ) {
        assert_eq!(tree_first.next_index, tree_second.next_index);
        assert_eq!(tree_first.get_root(), tree_second.get_root());

        let mut first_nodes = tree_first.get_all_nodes();
        let mut second_nodes = tree_second.get_all_nodes();
        assert_eq!(first_nodes.len(), second_nodes.len());

        first_nodes.sort_by_key(|node| (node.height, node.index));
        second_nodes.sort_by_key(|node| (node.height, node.index));

        assert_eq!(first_nodes, second_nodes);
    }

    // #[test]
    // fn test_unnecessary_temporary_nodes_are_removed() {
    //     let mut rng = CustomRng;
    //     let mut tree = create_tree();
    //
    //     let mut hashes: Vec<_> = (0..6).map(|_| rng.gen()).collect();
    //
    //     // make some hashes temporary
    //     // these two must remain after cleanup
    //     hashes[1].2 = true;
    //     hashes[3].2 = true;
    //
    //     // these two must be removed
    //     hashes[4].2 = true;
    //     hashes[5].2 = true;
    //
    //     tree.add_hashes(0, hashes);
    //
    //     let next_index = tree.clean();
    //     assert_eq!(next_index, tree.next_index);
    //
    //     let nodes = tree.get_all_nodes();
    //     assert_eq!(nodes.len(), constants::HEIGHT + 7);
    //     assert_eq!(tree.get_opt(0, 4), None);
    //     assert_eq!(tree.get_opt(0, 5), None);
    // }

    #[test]
    fn test_get_leaf_proof() {
        let mut rng = CustomRng;
        let tree = &mut init().tree;
        let proof = tree.get_leaf_proof(123);

        assert!(proof.is_none());

        tree.add_hash(123, rng.gen(), false);
        let proof = tree.get_leaf_proof(123).unwrap();

        assert_eq!(proof.sibling.as_slice().len(), constants::HEIGHT);
        assert_eq!(proof.path.as_slice().len(), constants::HEIGHT);
    }

    #[test]
    fn test_get_proof_unchecked() {
        let mut rng = CustomRng;
        let tree = &mut init().tree;

        // Get proof for the right child of the root of the tree
        const SUBROOT_HEIGHT: usize = 1;
        let proof = tree.get_proof_unchecked::<SUBROOT_HEIGHT>(1);
        assert_eq!(
            proof.sibling[SUBROOT_HEIGHT - 1],
            tree.default_hashes[constants::HEIGHT - SUBROOT_HEIGHT]
        );

        assert_eq!(proof.sibling.as_slice().len(), SUBROOT_HEIGHT);
        assert_eq!(proof.path.as_slice().len(), SUBROOT_HEIGHT);

        // First leaf of the right half of the tree (derived from the tree height
        // instead of a hard-coded shift).
        let right_half_start = 1u64 << (constants::HEIGHT - 1);

        // If we add leaf to the right branch,
        // then left child of the root should not be affected directly
        tree.add_hash(right_half_start, rng.gen(), false);
        let proof = tree.get_proof_unchecked::<SUBROOT_HEIGHT>(1);
        assert_eq!(
            proof.sibling[SUBROOT_HEIGHT - 1],
            tree.zero_note_hashes[constants::HEIGHT - SUBROOT_HEIGHT]
        );

        // But if we add leaf to the left branch, then left child of the root should change
        tree.add_hash(right_half_start - 1, rng.gen(), false);
        let proof = tree.get_proof_unchecked::<SUBROOT_HEIGHT>(1);
        assert_ne!(
            proof.sibling[SUBROOT_HEIGHT - 1],
            tree.zero_note_hashes[constants::HEIGHT - SUBROOT_HEIGHT]
        );
    }

    #[test]
    fn test_temporary_nodes_are_used_to_calculate_hashes_first() {
        let mut rng = CustomRng;
        let tree = &mut init().tree;

        let hash0: Hash<_> = rng.gen();
        let hash1: Hash<_> = rng.gen();

        // add hash for index 0
        tree.add_hash(0, hash0, true);

        // add hash for index 1
        tree.add_hash(1, hash1, false);

        let parent_hash = tree.get(1, 0);
        let expected_parent_hash = poseidon([hash0, hash1].as_ref(), POOL_PARAMS.compress());

        assert_eq!(parent_hash, expected_parent_hash);
    }

    #[test_case(0, 5)]
    #[test_case(1, 5)]
    #[test_case(2, 5)]
    #[test_case(4, 5)]
    #[test_case(5, 5)]
    #[test_case(5, 8)]
    #[test_case(10, 15)]
    #[test_case(12, 15)]
    fn test_all_temporary_nodes_in_subtree_are_removed(subtree_height: u32, full_height: usize) {
        let mut rng = CustomRng;

        let subtree_size = 1 << subtree_height;
        let subtrees_count = (1 << full_height) / subtree_size;
        let start_index = 1 << 12;
        let mut subtree_indexes: Vec<_> = (0..subtrees_count).map(|i| start_index + i).collect();
        subtree_indexes.shuffle(&mut thread_rng());

        let tree = &mut init().tree;
        for subtree_index in subtree_indexes {
            tree.add_subtree_root(subtree_height, subtree_index, rng.gen());
        }

        tree.clean();

        let tree_nodes = tree.get_all_nodes();
        assert_eq!(
            tree_nodes.len(),
            constants::HEIGHT - full_height + 1,
            "Some temporary subtree nodes were not removed."
        );
    }

    #[test]
    fn test_rollback_all_works_correctly() {
        let remove_size: u64 = 24;

        let mut rng = CustomRng;
        let tree = &mut init().tree;

        let original_root = tree.get_root();

        for index in 0..remove_size {
            let leaf = rng.gen();
            tree.add_hash(index, leaf, false);
        }

        let rollback_result = tree.rollback(0);
        assert!(rollback_result.is_none());
        let rollback_root = tree.get_root();
        assert_eq!(rollback_root, original_root);
        assert_eq!(tree.next_index, 0);
    }

    #[test_case(32, 16)]
    #[test_case(16, 0)]
    #[test_case(11, 7)]
    fn test_rollback_removes_nodes_correctly(keep_size: u64, remove_size: u64) {
        let mut rng = CustomRng;
        let tree = &mut init().tree;

        for index in 0..keep_size {
            let leaf = rng.gen();
            tree.add_hash(index, leaf, false);
        }
        let original_root = tree.get_root();

        for index in 0..remove_size {
            let leaf = rng.gen();
            tree.add_hash(128 + index, leaf, false);
        }

        let rollback_result = tree.rollback(128);
        assert!(rollback_result.is_none());
        let rollback_root = tree.get_root();
        assert_eq!(rollback_root, original_root);
        assert_eq!(tree.next_index, 128);
    }

    // #[test]
    // fn test_rollback_works_correctly_after_clean() {
    //     let mut rng = CustomRng;
    //     let mut tree = create_tree();
    //
    //     for index in 0..4 {
    //         let leaf = rng.gen();
    //         tree.add_hash(index, leaf, true);
    //     }
    //     for index in 4..6 {
    //         let leaf = rng.gen();
    //         tree.add_hash(index, leaf, false);
    //     }
    //     for index in 6..12 {
    //         let leaf = rng.gen();
    //         tree.add_hash(index, leaf, true);
    //     }
    //     let original_root = tree.get_root();
    //     for index in 12..16 {
    //         let leaf = rng.gen();
    //         tree.add_hash(index, leaf, true);
    //     }
    //
    //     tree.clean_before_index(10);
    //
    //     let rollback_result = tree.rollback(12);
    //     assert!(rollback_result.is_none());
    //     let rollback_root = tree.get_root();
    //     assert_eq!(rollback_root, original_root);
    //     assert_eq!(tree.next_index, 12)
    // }
    //
    // #[test]
    // fn test_rollback_of_cleaned_nodes() {
    //     let mut rng = CustomRng;
    //     let mut tree = create_tree();
    //
    //     for index in 0..4 {
    //         let leaf = rng.gen();
    //         tree.add_hash(index, leaf, true);
    //     }
    //     for index in 4..6 {
    //         let leaf = rng.gen();
    //         tree.add_hash(index, leaf, false);
    //     }
    //     for index in 6..7 {
    //         let leaf = rng.gen();
    //         tree.add_hash(index, leaf, true);
    //     }
    //     let original_root = tree.get_root();
    //     for index in 7..16 {
    //         let leaf = rng.gen();
    //         tree.add_hash(index, leaf, true);
    //     }
    //
    //     tree.clean_before_index(10);
    //
    //     let rollback_result = tree.rollback(7);
    //     assert_eq!(rollback_result.unwrap(), 6);
    //     let rollback_root = tree.get_root();
    //     assert_ne!(rollback_root, original_root);
    //     assert_eq!(tree.next_index, 7)
    // }

    #[test]
    fn test_get_leaves() {
        let mut rng = CustomRng;
        let tree = &mut init().tree;

        let leaves_count = 6;

        for index in 0..leaves_count {
            let leaf = rng.gen();
            tree.add_hash(index, leaf, true);
        }

        let leaves = tree.get_leaves();

        assert_eq!(leaves.len(), leaves_count as usize);
        for index in 0..leaves_count {
            assert!(leaves.iter().any(|node| node.index == index));
        }
    }

    #[test]
    fn test_get_leaves_after() {
        let mut rng = CustomRng;
        let tree = &mut init().tree;

        let leaves_count = 6;
        let skip_count = 2;

        for index in 0..leaves_count {
            let leaf = rng.gen();
            tree.add_hash(index, leaf, true);
        }

        let leaves = tree.get_leaves_after(skip_count);

        assert_eq!(leaves.len(), (leaves_count - skip_count) as usize);
        for index in skip_count..leaves_count {
            assert!(leaves.iter().any(|node| node.index == index));
        }
    }

    #[test]
    fn test_get_proof_after() {
        let mut rng = CustomRng;
        let tree = &mut init().tree;

        let tree_size = 6;
        let new_hashes_size = 3;

        for index in 0..tree_size {
            let leaf = rng.gen();
            tree.add_hash(index, leaf, false);
        }

        let root_before_call = tree.get_root();

        let new_hashes: Vec<_> = (0..new_hashes_size).map(|_| rng.gen()).collect();
        tree.get_proof_after(new_hashes);

        let root_after_call = tree.get_root();

        assert_eq!(root_before_call, root_after_call);
    }

    #[test_case(12, 4)]
    #[test_case(13, 5)]
    #[test_case(0, 1)]
    #[test_case(0, 5)]
    #[test_case(0, 8)]
    #[test_case(4, 16)]
    fn test_get_proof_after_virtual(tree_size: u64, new_hashes_size: u64) {
        let mut rng = CustomRng;
        let tree = &mut init().tree;

        for index in 0..tree_size {
            let leaf = rng.gen();
            tree.add_hash(index, leaf, false);
        }

        let new_hashes: Vec<_> = (0..new_hashes_size).map(|_| rng.gen()).collect();

        let root_before_call = tree.get_root();

        let proofs_virtual = tree.get_proof_after_virtual(new_hashes.clone());
        let proofs_simple = tree.get_proof_after(new_hashes);

        let root_after_call = tree.get_root();

        assert_eq!(root_before_call, root_after_call);
        assert_eq!(proofs_simple.len(), proofs_virtual.len());
        for (simple_proof, virtual_proof) in proofs_simple.iter().zip(proofs_virtual) {
            for (simple_sibling, virtual_sibling) in simple_proof
                .sibling
                .iter()
                .zip(virtual_proof.sibling.iter())
            {
                assert_eq!(simple_sibling, virtual_sibling);
            }
            for (simple_path, virtual_path) in
                simple_proof.path.iter().zip(virtual_proof.path.iter())
            {
                assert_eq!(simple_path, virtual_path);
            }
        }
    }

    #[test]
    fn test_default_hashes_are_added_correctly() {
        let mut rng = CustomRng;
        let tree = &mut init().tree;

        // Empty tree contains default hashes.
        assert_eq!(tree.get(0, 0), tree.default_hashes[0]);
        assert_eq!(tree.get(0, 3), tree.default_hashes[0]);
        assert_eq!(tree.get(2, 0), tree.default_hashes[2]);

        let hashes: Vec<_> = (0..3).map(|_| rng.gen()).collect();
        tree.add_hashes(0, hashes);

        // Hashes were added.
        assert_ne!(tree.get(2, 0), tree.zero_note_hashes[2]);
        assert_ne!(tree.get(2, 0), tree.default_hashes[2]);
        // First subtree contains zero note hashes instead of default hashes.
        assert_eq!(tree.get(0, 4), tree.zero_note_hashes[0]);
        assert_eq!(tree.get(0, 127), tree.zero_note_hashes[0]);
        assert_eq!(tree.get(2, 1), tree.zero_note_hashes[2]);
        // Second subtree still contains default hashes.
        assert_eq!(tree.get(0, 128), tree.default_hashes[0]);
        assert_eq!(tree.get(7, 1), tree.default_hashes[7]);

        let hashes: Vec<_> = (0..2).map(|_| rng.gen()).collect();
        tree.add_hashes(128, hashes);
        // Second subtree contains zero note hashes instead of default hashes.
        assert_eq!(tree.get(0, 128 + 4), tree.zero_note_hashes[0]);
        assert_eq!(tree.get(0, 128 + 127), tree.zero_note_hashes[0]);
        assert_eq!(tree.get(2, 32 + 1), tree.zero_note_hashes[2]);
        // Third subtree still contains default hashes.
        assert_eq!(tree.get(0, 128 + 128), tree.default_hashes[0]);
        assert_eq!(tree.get(7, 2), tree.default_hashes[7]);
    }

    #[test_case(0, 0, 0.0)]
    #[test_case(1, 1, 0.0)]
    #[test_case(1, 1, 1.0)]
    #[test_case(4, 2, 0.0)]
    #[test_case(4, 2, 0.5)]
    #[test_case(4, 2, 1.0)]
    #[test_case(15, 7, 0.0)]
    #[test_case(15, 7, 0.5)]
    #[test_case(15, 7, 1.0)]
    fn test_add_leafs_and_commitments(
        tx_count: u64,
        max_leafs_count: u32,
        commitments_probability: f64,
    ) {
        let mut rng = CustomRng;
        let first_tree = &mut init().tree;
        let second_tree = &mut init().tree;

        let leafs: Vec<(u64, Vec<_>)> = (0..tx_count)
            .map(|i| {
                let leafs_count: u32 = 1 + (rng.gen::<u32>() % max_leafs_count);
                (
                    i * (constants::OUT + 1) as u64,
                    (0..leafs_count).map(|_| rng.gen()).collect(),
                )
            })
            .collect();

        let now = std::time::Instant::now();
        for (index, leafs) in leafs.clone().into_iter() {
            first_tree.add_hashes(index, leafs)
        }
        println!(
            "({}, {}, {}) add_hashes elapsed: {}",
            tx_count,
            max_leafs_count,
            commitments_probability,
            now.elapsed().as_millis()
        );

        let commitments: Vec<(u64, _)> = leafs
            .clone()
            .into_iter()
            .map(|(index, leafs)| {
                let mut out_hashes = leafs.clone();
                out_hashes.resize(constants::OUT + 1, first_tree.zero_note_hashes[0]);
                let commitment =
                    tx::out_commitment_hash(out_hashes.as_slice(), &POOL_PARAMS.clone());
                (index, commitment)
            })
            .collect();

        commitments.iter().for_each(|(index, commitment)| {
            assert_eq!(
                first_tree.get(
                    constants::OUTPLUSONELOG as u32,
                    *index >> constants::OUTPLUSONELOG
                ),
                *commitment
            );
        });

        let mut sub_leafs: Vec<(u64, Vec<_>)> = Vec::new();
        let mut sub_commitments: Vec<(u64, _)> = Vec::new();
        (0..tx_count).for_each(|i| {
            if rng.gen_bool(commitments_probability) {
                sub_commitments.push(commitments[i as usize]);
            } else {
                sub_leafs.push((leafs[i as usize].0, leafs[i as usize].1.clone()));
            }
        });

        let now = std::time::Instant::now();
        second_tree.add_leafs_and_commitments(sub_leafs, sub_commitments);
        println!(
            "({}, {}, {}) add_leafs_and_commitments elapsed: {}",
            tx_count,
            max_leafs_count,
            commitments_probability,
            now.elapsed().as_millis()
        );

        assert_eq!(
            first_tree.get_root().to_string(),
            second_tree.get_root().to_string()
        );
        assert_eq!(first_tree.next_index(), second_tree.next_index());
    }

    #[test_case(0, 0, 0.0)]
    #[test_case(1, 1, 0.0)]
    #[test_case(1, 1, 1.0)]
    #[test_case(4, 2, 0.0)]
    #[test_case(4, 2, 0.5)]
    #[test_case(4, 2, 1.0)]
    #[test_case(15, 7, 0.0)]
    #[test_case(15, 7, 0.5)]
    #[test_case(15, 7, 1.0)]
    fn test_get_root_optimistic(tx_count: u64, max_leafs_count: u32, commitments_probability: f64) {
        let mut rng = CustomRng;
        let tree = &mut init().tree;

        let leafs: Vec<(u64, Vec<_>)> = (0..tx_count)
            .map(|i| {
                let leafs_count: u32 = 1 + (rng.gen::<u32>() % max_leafs_count);
                (
                    i * (constants::OUT + 1) as u64,
                    (0..leafs_count).map(|_| rng.gen()).collect(),
                )
            })
            .collect();

        let commitments: Vec<(u64, _)> = leafs
            .clone()
            .into_iter()
            .map(|(index, leafs)| {
                let mut out_hashes = leafs.clone();
                out_hashes.resize(constants::OUT + 1, tree.zero_note_hashes[0]);
                let commitment =
                    tx::out_commitment_hash(out_hashes.as_slice(), &POOL_PARAMS.clone());
                (index, commitment)
            })
            .collect();

        let mut sub_leafs: Vec<(u64, Vec<_>)> = Vec::new();
        let mut sub_commitments: Vec<(u64, _)> = Vec::new();
        (0..tx_count).for_each(|i| {
            if rng.gen_bool(commitments_probability) {
                sub_commitments.push(commitments[i as usize]);
            } else {
                sub_leafs.push((leafs[i as usize].0, leafs[i as usize].1.clone()));
            }
        });

        let (mut virtual_nodes, update_boundaries) =
            tree.get_virtual_subtree(sub_leafs.clone(), sub_commitments.clone());
        let optimistic_root = tree.get_root_optimistic(&mut virtual_nodes, &update_boundaries);

        tree.add_leafs_and_commitments(sub_leafs, sub_commitments);
        let root = tree.get_root();

        assert_eq!(optimistic_root.to_string(), root.to_string());
    }
}