1use std::collections::HashMap;
2
3use borsh::{BorshDeserialize, BorshSerialize};
4use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
5use kvdb::{DBTransaction, KeyValueDB};
6use kvdb_memorydb::InMemory as MemoryDatabase;
7#[cfg(feature = "native")]
8use kvdb_persy::PersyDatabase as NativeDatabase;
9#[cfg(feature = "web")]
10use kvdb_web::Database as WebDatabase;
11use libzeropool::{
12 constants,
13 fawkes_crypto::{
14 core::sizedvec::SizedVec,
15 ff_uint::{Num, PrimeField},
16 native::poseidon::{poseidon, MerkleProof},
17 },
18 native::params::PoolParams,
19};
20use serde::{Deserialize, Serialize};
21
22use crate::utils::zero_note;
23
24pub type Hash<F> = Num<F>;
25
/// Number of key-value columns the tree uses (one per [`DbCols`] variant).
const NUM_COLUMNS: u32 = 4;

/// Key under which the next free leaf index is stored in [`DbCols::NextIndex`].
const NEXT_INDEX_KEY: &[u8] = b"next_index";

/// Column layout of the backing key-value database.
enum DbCols {
    /// Node hashes at every height, keyed by `node_key(height, index)`.
    Leaves = 0,
    /// Number of temporary leaves below a node, keyed like `Leaves`.
    TempLeaves = 1,
    /// Named `u64` bookkeeping values (e.g. `clean_index`).
    NamedIndex = 2,
    /// The persisted next free leaf index.
    NextIndex = 3,
}
34
35pub struct MerkleTree<D: KeyValueDB, P: PoolParams> {
36 db: D,
37 params: P,
38 default_hashes: Vec<Hash<P::Fr>>,
39 zero_note_hashes: Vec<Hash<P::Fr>>,
40 next_index: u64,
41}
42
/// Tree persisted through `kvdb_persy` (only with the `native` feature).
#[cfg(feature = "native")]
pub type NativeMerkleTree<P> = MerkleTree<NativeDatabase, P>;

/// Tree persisted through `kvdb_web` (only with the `web` feature).
#[cfg(feature = "web")]
pub type WebMerkleTree<P> = MerkleTree<WebDatabase, P>;
48
#[cfg(feature = "web")]
impl<P: PoolParams> MerkleTree<WebDatabase, P> {
    /// Opens the `kvdb_web` database called `name` with [`NUM_COLUMNS`] columns and wraps
    /// it into a tree.
    ///
    /// # Panics
    /// Panics if the database cannot be opened.
    pub async fn new_web(name: &str, params: P) -> MerkleTree<WebDatabase, P> {
        let db = WebDatabase::open(name.to_owned(), NUM_COLUMNS)
            .await
            .expect("failed to open the web Merkle tree database");

        Self::new(db, params)
    }
}
59
#[cfg(feature = "native")]
impl<P: PoolParams> MerkleTree<NativeDatabase, P> {
    /// Opens the Persy database at `path` and wraps it into a tree.
    ///
    /// # Errors
    /// Returns the I/O error reported by `PersyDatabase::open`.
    pub fn new_native(path: &str, params: P) -> std::io::Result<MerkleTree<NativeDatabase, P>> {
        // Key prefix of height-0 nodes (`node_key` starts with the big-endian height); it is
        // the same prefix `get_leaves_after` iterates with.
        let leaf_prefix = 0u32.to_be_bytes();
        // Use the shared column count instead of a literal so all backends agree.
        let db = NativeDatabase::open(path, NUM_COLUMNS, &[&leaf_prefix])?;

        Ok(Self::new(db, params))
    }
}
69
70impl<P: PoolParams> MerkleTree<MemoryDatabase, P> {
71 pub fn new_test(params: P) -> MerkleTree<MemoryDatabase, P> {
72 Self::new(kvdb_memorydb::create(NUM_COLUMNS), params)
73 }
74}
75
76impl<D: KeyValueDB, P: PoolParams> MerkleTree<D, P> {
78 pub fn new(db: D, params: P) -> Self {
79 let db_next_index = db.get(DbCols::NextIndex as u32, NEXT_INDEX_KEY);
80 let next_index = match db_next_index {
81 Ok(Some(next_index)) => next_index.as_slice().read_u64::<BigEndian>().unwrap(),
82 _ => {
83 let mut cur_next_index = 0;
84 for (k, _v) in db.iter(0).map(|res| res.unwrap()) {
85 let (height, index) = Self::parse_node_key(&k);
86
87 if height == 0 && index >= cur_next_index {
88 cur_next_index = Self::calc_next_index(index);
89 }
90 }
91 cur_next_index
92 }
93 };
94
95 MerkleTree {
96 db,
97 default_hashes: Self::gen_default_hashes(¶ms),
98 zero_note_hashes: Self::gen_empty_note_hashes(¶ms),
99 params,
100 next_index,
101 }
102 }
103
104 pub fn add_hash_at_height(
108 &mut self,
109 height: u32,
110 index: u64,
111 hash: Hash<P::Fr>,
112 temporary: bool,
113 ) {
114 let next_index_was_updated = self.update_next_index_from_node(height, index);
116
117 if hash == self.zero_note_hashes[height as usize] && !next_index_was_updated {
118 return;
119 }
120
121 let mut batch = self.db.transaction();
122
123 let temporary_leaves_count = if temporary { 1 } else { 0 };
125 self.set_batched(&mut batch, height, index, hash, temporary_leaves_count);
126
127 self.update_path_batched(&mut batch, height, index, hash, temporary_leaves_count);
129
130 self.db.write(batch).unwrap();
131 }
132
133 pub fn add_hash(&mut self, index: u64, hash: Hash<P::Fr>, temporary: bool) {
134 self.add_hash_at_height(0, index, hash, temporary)
135 }
136
137 pub fn append_hash(&mut self, hash: Hash<P::Fr>, temporary: bool) -> u64 {
138 let index = self.next_index;
139 self.add_hash(index, hash, temporary);
140 index
141 }
142
143 pub fn add_leafs_and_commitments(
144 &mut self,
145 leafs: Vec<(u64, Vec<Hash<P::Fr>>)>,
146 commitments: Vec<(u64, Hash<P::Fr>)>,
147 ) {
148 if leafs.is_empty() && commitments.is_empty() {
149 return;
150 }
151
152 let mut next_index: u64 = 0;
153 let mut start_index: u64 = u64::MAX;
154 let mut virtual_nodes: HashMap<(u32, u64), Hash<P::Fr>> = commitments
155 .into_iter()
156 .map(|(index, hash)| {
157 assert_eq!(index & ((1 << constants::OUTPLUSONELOG) - 1), 0);
158 start_index = start_index.min(index);
159 next_index = next_index.max(index + 1);
160 (
161 (
162 constants::OUTPLUSONELOG as u32,
163 index >> constants::OUTPLUSONELOG,
164 ),
165 hash,
166 )
167 })
168 .collect();
169
170 leafs.into_iter().for_each(|(index, leafs)| {
171 assert_eq!(index & ((1 << constants::OUTPLUSONELOG) - 1), 0);
172 start_index = start_index.min(index);
173 next_index = next_index.max(index + leafs.len() as u64);
174 (0..constants::OUTPLUSONELOG).for_each(|height| {
175 virtual_nodes.insert(
176 (
177 height as u32,
178 ((index + leafs.len() as u64 - 1) >> height) + 1,
179 ),
180 self.zero_note_hashes[height],
181 );
182 });
183 leafs.into_iter().enumerate().for_each(|(i, leaf)| {
184 virtual_nodes.insert((0_u32, index + i as u64), leaf);
185 });
186 });
187
188 let original_next_index = self.next_index;
189 self.update_next_index_from_node(0, next_index);
190
191 let update_boundaries = UpdateBoundaries {
192 updated_range_left_index: original_next_index,
193 updated_range_right_index: self.next_index,
194 new_hashes_left_index: start_index,
195 new_hashes_right_index: next_index,
196 };
197
198 self.get_virtual_node_full(
200 constants::HEIGHT as u32,
201 0,
202 &mut virtual_nodes,
203 &update_boundaries,
204 );
205
206 self.put_hashes(virtual_nodes);
208 }
209
210 pub fn add_hashes<I>(&mut self, start_index: u64, hashes: I)
211 where
212 I: IntoIterator<Item = Hash<P::Fr>>,
213 {
214 assert_eq!(start_index & ((1 << constants::OUTPLUSONELOG) - 1), 0);
216
217 let mut virtual_nodes: HashMap<(u32, u64), Hash<P::Fr>> = hashes
218 .into_iter()
219 .filter(|hash| *hash != self.zero_note_hashes[0])
221 .enumerate()
222 .map(|(index, hash)| ((0, start_index + index as u64), hash))
223 .collect();
224 let new_hashes_count = virtual_nodes.len() as u64;
225
226 assert!(new_hashes_count <= (2u64 << constants::OUTPLUSONELOG));
227
228 let original_next_index = self.next_index;
229 self.update_next_index_from_node(0, start_index);
230
231 let update_boundaries = UpdateBoundaries {
232 updated_range_left_index: original_next_index,
233 updated_range_right_index: self.next_index,
234 new_hashes_left_index: start_index,
235 new_hashes_right_index: start_index + new_hashes_count,
236 };
237
238 self.get_virtual_node_full(
240 constants::HEIGHT as u32,
241 0,
242 &mut virtual_nodes,
243 &update_boundaries,
244 );
245
246 self.put_hashes(virtual_nodes);
248 }
249
250 fn put_hashes(&mut self, virtual_nodes: HashMap<(u32, u64), Hash<<P as PoolParams>::Fr>>) {
251 let mut batch = self.db.transaction();
252
253 for ((height, index), value) in virtual_nodes {
254 self.set_batched(&mut batch, height, index, value, 0);
255 }
256
257 self.db.write(batch).unwrap();
258 }
259
260 #[cfg(test)]
262 fn add_subtree_root(&mut self, height: u32, index: u64, hash: Hash<P::Fr>) {
263 self.update_next_index_from_node(height, index);
264
265 let mut batch = self.db.transaction();
266
267 self.set_batched(&mut batch, height, index, hash, 1 << height);
269
270 self.update_path_batched(&mut batch, height, index, hash, 1 << height);
272
273 self.db.write(batch).unwrap();
274 }
275
276 pub fn get(&self, height: u32, index: u64) -> Hash<P::Fr> {
277 self.get_with_next_index(height, index, self.next_index)
278 }
279
280 fn get_with_next_index(&self, height: u32, index: u64, next_index: u64) -> Hash<P::Fr> {
281 match self.get_opt(height, index) {
282 Some(val) => val,
283 _ => {
284 let next_leave_index = u64::pow(2, height) * (index + 1);
285 if next_leave_index <= next_index {
286 self.zero_note_hashes[height as usize]
287 } else {
288 self.default_hashes[height as usize]
289 }
290 }
291 }
292 }
293
294 pub fn last_leaf(&self) -> Hash<P::Fr> {
295 match self.get_opt(0, self.next_index.saturating_sub(1)) {
297 Some(val) => val,
298 _ => self.default_hashes[0],
299 }
300 }
301
302 pub fn get_root(&self) -> Hash<P::Fr> {
303 self.get(constants::HEIGHT as u32, 0)
304 }
305
306 pub fn get_root_after_virtual<I>(&self, new_commitments: I) -> Hash<P::Fr>
307 where
308 I: IntoIterator<Item = Hash<P::Fr>>,
309 {
310 let next_leaf_index = self.next_index;
311 let next_commitment_index = next_leaf_index / 2u64.pow(constants::OUTPLUSONELOG as u32);
312 let index_step = constants::OUT as u64 + 1;
313
314 let mut virtual_commitment_nodes: HashMap<(u32, u64), Hash<P::Fr>> = new_commitments
315 .into_iter()
316 .enumerate()
317 .map(|(index, hash)| {
318 (
319 (
320 constants::OUTPLUSONELOG as u32,
321 next_commitment_index + index as u64,
322 ),
323 hash,
324 )
325 })
326 .collect();
327 let new_commitments_count = virtual_commitment_nodes.len() as u64;
328
329 self.get_virtual_node(
330 constants::HEIGHT as u32,
331 0,
332 &mut virtual_commitment_nodes,
333 next_leaf_index,
334 next_leaf_index + new_commitments_count * index_step,
335 )
336 }
337
338 pub fn get_root_optimistic(
339 &self,
340 virtual_nodes: &mut HashMap<(u32, u64), Hash<P::Fr>>,
341 update_boundaries: &UpdateBoundaries,
342 ) -> Hash<P::Fr> {
343 self.get_virtual_node_full(
344 constants::HEIGHT as u32,
345 0,
346 virtual_nodes,
347 &update_boundaries,
348 )
349 }
350
351 pub fn get_opt(&self, height: u32, index: u64) -> Option<Hash<P::Fr>> {
352 assert!(height <= constants::HEIGHT as u32);
353
354 let key = Self::node_key(height, index);
355 let res = self.db.get(0, &key);
356
357 match res {
358 Ok(Some(ref val)) => Some(Hash::<P::Fr>::try_from_slice(val).unwrap()),
359 _ => None,
360 }
361 }
362
363 pub fn get_proof_unchecked<const H: usize>(&self, index: u64) -> MerkleProof<P::Fr, { H }> {
364 let mut sibling: SizedVec<_, { H }> = (0..H).map(|_| Num::ZERO).collect();
365 let mut path: SizedVec<_, { H }> = (0..H).map(|_| false).collect();
366
367 let start_height = constants::HEIGHT - H;
368
369 sibling.iter_mut().zip(path.iter_mut()).enumerate().fold(
370 index,
371 |x, (h, (sibling, is_right))| {
372 let cur_height = (start_height + h) as u32;
373 *is_right = x % 2 == 1;
374 *sibling = self.get(cur_height, x ^ 1);
375
376 x / 2
377 },
378 );
379
380 MerkleProof { sibling, path }
381 }
382
383 pub fn get_leaf_proof(&self, index: u64) -> Option<MerkleProof<P::Fr, { constants::HEIGHT }>> {
384 let key = Self::node_key(0, index);
385 let node_present = self.db.get(0, &key).map_or(false, |value| value.is_some());
386 if !node_present {
387 return None;
388 }
389 Some(self.get_proof_unchecked(index))
390 }
391
392 #[cfg(test)]
394 fn get_proof_after<I>(
395 &mut self,
396 new_hashes: I,
397 ) -> Vec<MerkleProof<P::Fr, { constants::HEIGHT }>>
398 where
399 I: IntoIterator<Item = Hash<P::Fr>>,
400 {
401 let new_hashes: Vec<_> = new_hashes.into_iter().collect();
402 let size = new_hashes.len() as u64;
403
404 let index_offset = self.next_index;
406 self.add_hashes(index_offset, new_hashes);
407
408 let proofs = (index_offset..index_offset + size)
409 .map(|index| {
410 self.get_leaf_proof(index)
411 .expect("Leaf was expected to be present (bug)")
412 })
413 .collect();
414
415 self.next_index = index_offset;
417 for index in index_offset..index_offset + size {
419 self.remove_leaf(index);
420 }
421
422 proofs
423 }
424
425 pub fn get_proof_after_virtual<I>(
426 &self,
427 new_hashes: I,
428 ) -> Vec<MerkleProof<P::Fr, { constants::HEIGHT }>>
429 where
430 I: IntoIterator<Item = Hash<P::Fr>>,
431 {
432 let index_offset = self.next_index;
433
434 let mut virtual_nodes: HashMap<(u32, u64), Hash<P::Fr>> = new_hashes
435 .into_iter()
436 .enumerate()
437 .map(|(index, hash)| ((0, index_offset + index as u64), hash))
438 .collect();
439 let new_hashes_count = virtual_nodes.len() as u64;
440
441 let update_boundaries = UpdateBoundaries {
442 updated_range_left_index: index_offset,
443 updated_range_right_index: Self::calc_next_index(index_offset),
444 new_hashes_left_index: index_offset,
445 new_hashes_right_index: index_offset + new_hashes_count,
446 };
447
448 (index_offset..index_offset + new_hashes_count)
449 .map(|index| self.get_proof_virtual(index, &mut virtual_nodes, &update_boundaries))
450 .collect()
451 }
452
453 pub fn get_proof_virtual_index<I>(
454 &self,
455 index: u64,
456 new_hashes: I,
457 ) -> Option<MerkleProof<P::Fr, { constants::HEIGHT }>>
458 where
459 I: IntoIterator<Item = Hash<P::Fr>>,
460 {
461 let index_offset = self.next_index;
462
463 let mut virtual_nodes: HashMap<(u32, u64), Hash<P::Fr>> = new_hashes
464 .into_iter()
465 .enumerate()
466 .map(|(index, hash)| ((0, index_offset + index as u64), hash))
467 .collect();
468 let new_hashes_count = virtual_nodes.len() as u64;
469
470 let update_boundaries = UpdateBoundaries {
471 updated_range_left_index: index_offset,
472 updated_range_right_index: index_offset + new_hashes_count,
473 new_hashes_left_index: index_offset,
474 new_hashes_right_index: index_offset + new_hashes_count,
475 };
476
477 Some(self.get_proof_virtual(index, &mut virtual_nodes, &update_boundaries))
478 }
479
480 pub fn get_virtual_subtree<I1, I2>(
481 &self,
482 new_hashes: I1,
483 new_commitments: I2,
484 ) -> (HashMap<(u32, u64), Hash<P::Fr>>, UpdateBoundaries)
485 where
486 I1: IntoIterator<Item = (u64, Vec<Hash<P::Fr>>)>,
487 I2: IntoIterator<Item = (u64, Hash<P::Fr>)>,
488 {
489 let mut next_index: Option<u64> = None;
490 let mut start_index: Option<u64> = None;
491 let mut virtual_nodes: HashMap<(u32, u64), Hash<P::Fr>> = new_commitments
492 .into_iter()
493 .map(|(index, hash)| {
494 assert_eq!(index & ((1 << constants::OUTPLUSONELOG) - 1), 0);
495 start_index = Some(start_index.unwrap_or(u64::MAX).min(index));
496 next_index = Some(next_index.unwrap_or(0).max(index + 1));
497 (
498 (
499 constants::OUTPLUSONELOG as u32,
500 index >> constants::OUTPLUSONELOG,
501 ),
502 hash,
503 )
504 })
505 .collect();
506
507 new_hashes.into_iter().for_each(|(index, leafs)| {
508 assert_eq!(index & ((1 << constants::OUTPLUSONELOG) - 1), 0);
509 start_index = Some(start_index.unwrap_or(u64::MAX).min(index));
510 next_index = Some(next_index.unwrap_or(0).max(index + leafs.len() as u64));
511 (0..constants::OUTPLUSONELOG).for_each(|height| {
512 virtual_nodes.insert(
513 (
514 height as u32,
515 ((index + leafs.len() as u64 - 1) >> height) + 1,
516 ),
517 self.zero_note_hashes[height],
518 );
519 });
520 leafs.into_iter().enumerate().for_each(|(i, leaf)| {
521 virtual_nodes.insert((0_u32, index + i as u64), leaf);
522 });
523 });
524
525 let update_boundaries = {
526 if let (Some(start_index), Some(next_index)) = (start_index, next_index) {
527 UpdateBoundaries {
528 updated_range_left_index: self.next_index,
529 updated_range_right_index: Self::calc_next_index(next_index),
530 new_hashes_left_index: start_index,
531 new_hashes_right_index: next_index,
532 }
533 } else {
534 UpdateBoundaries {
535 updated_range_left_index: self.next_index,
536 updated_range_right_index: self.next_index,
537 new_hashes_left_index: self.next_index,
538 new_hashes_right_index: self.next_index,
539 }
540 }
541 };
542
543 self.get_virtual_node_full(
545 constants::HEIGHT as u32,
546 0,
547 &mut virtual_nodes,
548 &update_boundaries,
549 );
550
551 (virtual_nodes, update_boundaries)
552 }
553
554 pub fn get_proof_optimistic_index(
555 &self,
556 index: u64,
557 virtual_nodes: &mut HashMap<(u32, u64), Hash<P::Fr>>,
558 update_boundaries: &UpdateBoundaries,
559 ) -> Option<MerkleProof<P::Fr, { constants::HEIGHT }>> {
560 Some(self.get_proof_virtual(index, virtual_nodes, update_boundaries))
561 }
562
563 fn get_proof_virtual<const H: usize>(
564 &self,
565 index: u64,
566 virtual_nodes: &mut HashMap<(u32, u64), Hash<P::Fr>>,
567 update_boundaries: &UpdateBoundaries,
568 ) -> MerkleProof<P::Fr, { H }> {
569 let mut sibling: SizedVec<_, { H }> = (0..H).map(|_| Num::ZERO).collect();
570 let mut path: SizedVec<_, { H }> = (0..H).map(|_| false).collect();
571
572 let start_height = constants::HEIGHT - H;
573
574 sibling.iter_mut().zip(path.iter_mut()).enumerate().fold(
575 index,
576 |x, (h, (sibling, is_right))| {
577 let cur_height = (start_height + h) as u32;
578 *is_right = x % 2 == 1;
579 *sibling =
580 self.get_virtual_node_full(cur_height, x ^ 1, virtual_nodes, update_boundaries);
581
582 x / 2
583 },
584 );
585
586 MerkleProof { sibling, path }
587 }
588
589 pub fn get_virtual_node(
590 &self,
591 height: u32,
592 index: u64,
593 virtual_nodes: &mut HashMap<(u32, u64), Hash<P::Fr>>,
594 new_hashes_left_index: u64,
595 new_hashes_right_index: u64,
596 ) -> Hash<P::Fr> {
597 let update_boundaries = UpdateBoundaries {
598 updated_range_left_index: new_hashes_left_index,
599 updated_range_right_index: new_hashes_right_index,
600 new_hashes_left_index,
601 new_hashes_right_index,
602 };
603
604 self.get_virtual_node_full(height, index, virtual_nodes, &update_boundaries)
605 }
606
607 fn get_virtual_node_full(
608 &self,
609 height: u32,
610 index: u64,
611 virtual_nodes: &mut HashMap<(u32, u64), Hash<P::Fr>>,
612 update_boundaries: &UpdateBoundaries,
613 ) -> Hash<P::Fr> {
614 let node_left = index * (1 << height);
615 let node_right = (index + 1) * (1 << height);
616 if node_right <= update_boundaries.updated_range_left_index
617 || update_boundaries.updated_range_right_index <= node_left
618 {
619 return self.get(height, index);
620 }
621 if (node_right <= update_boundaries.new_hashes_left_index
622 || update_boundaries.new_hashes_right_index <= node_left)
623 && update_boundaries.updated_range_left_index <= node_left
624 && node_right <= update_boundaries.updated_range_right_index
625 {
626 return self.zero_note_hashes[height as usize];
627 }
628
629 let key = (height, index);
630 match virtual_nodes.get(&key) {
631 Some(hash) => *hash,
632 None => {
633 let left_child = self.get_virtual_node_full(
634 height - 1,
635 2 * index,
636 virtual_nodes,
637 update_boundaries,
638 );
639 let right_child = self.get_virtual_node_full(
640 height - 1,
641 2 * index + 1,
642 virtual_nodes,
643 update_boundaries,
644 );
645 let pair = [left_child, right_child];
646 let hash = poseidon(pair.as_ref(), self.params.compress());
647 virtual_nodes.insert(key, hash);
648
649 hash
650 }
651 }
652 }
653
654 pub fn clean(&mut self) -> u64 {
655 self.clean_before_index(u64::MAX)
656 }
657
658 pub fn clean_before_index(&mut self, clean_before_index: u64) -> u64 {
659 let mut batch = self.db.transaction();
660
661 let keys: Vec<(u32, u64)> = self
664 .db
665 .iter(0)
666 .map(|res| Self::parse_node_key(&res.unwrap().0))
667 .collect();
668 for (height, index) in keys {
670 if height == 0 {
672 continue;
673 }
674
675 if (index + 1) * (1 << height) > clean_before_index {
677 continue;
678 }
679
680 if self.subtree_contains_only_temporary_leaves(height, index) {
681 self.remove_batched(&mut batch, height - 1, 2 * index);
683 self.remove_batched(&mut batch, height - 1, 2 * index + 1);
684 }
685 }
686
687 self.set_clean_index_batched(&mut batch, clean_before_index);
688
689 self.db.write(batch).unwrap();
690
691 self.next_index
692 }
693
694 pub fn rollback(&mut self, rollback_index: u64) -> Option<u64> {
695 let mut result: Option<u64> = None;
696
697 let clean_index = self.get_clean_index();
699 if rollback_index < clean_index {
700 let mut nodes_request_index = self.next_index;
702 let mut index = rollback_index;
703 for height in 0..constants::HEIGHT as u32 {
704 let sibling_index = index ^ 1;
705 if sibling_index < index
706 && !self.subtree_contains_only_temporary_leaves(height, sibling_index)
707 {
708 let leaf_index = index * (1 << height);
709 if leaf_index < nodes_request_index {
710 nodes_request_index = leaf_index
711 }
712 }
713 index /= 2;
714 }
715 if nodes_request_index < clean_index {
716 result = Some(nodes_request_index)
717 }
718 }
719
720 let original_next_index = self.next_index;
722 self.next_index = if rollback_index > 0 {
723 Self::calc_next_index(rollback_index - 1)
724 } else {
725 0
726 };
727 for index in (rollback_index..original_next_index).rev() {
729 self.remove_leaf(index);
730 }
731
732 result
733 }
734
735 pub fn get_all_nodes(&self) -> Vec<Node<P::Fr>> {
736 self.db
737 .iter(0)
738 .map(|res| {
739 let (key, value) = res.unwrap();
740 Self::build_node(&key, &value)
741 })
742 .collect()
743 }
744
745 pub fn get_leaves(&self) -> Vec<Node<P::Fr>> {
746 self.get_leaves_after(0)
747 }
748
749 pub fn get_leaves_after(&self, index: u64) -> Vec<Node<P::Fr>> {
750 let prefix = (0u32).to_be_bytes();
751 self.db
752 .iter_with_prefix(0, &prefix)
753 .map(|res| {
754 let (key, value) = res.unwrap();
755 Self::build_node(&key, &value)
756 })
757 .filter(|node| node.index >= index)
758 .collect()
759 }
760
761 pub fn next_index(&self) -> u64 {
762 self.next_index
763 }
764
765 fn update_next_index(&mut self, next_index: u64) -> bool {
766 if next_index >= self.next_index {
767 let mut transaction = self.db.transaction();
768 let mut data = [0u8; 8];
769 {
770 let mut bytes = &mut data[..];
771 let _ = bytes.write_u64::<BigEndian>(next_index);
772 }
773 transaction.put(DbCols::NextIndex as u32, NEXT_INDEX_KEY, &data);
774 self.db.write(transaction).unwrap();
775
776 self.next_index = next_index;
777 true
778 } else {
779 false
780 }
781 }
782
783 fn update_next_index_from_node(&mut self, height: u32, index: u64) -> bool {
784 let leaf_index = u64::pow(2, height) * (index + 1) - 1;
785 self.update_next_index(Self::calc_next_index(leaf_index))
786 }
787
788 #[inline]
789 fn calc_next_index(leaf_index: u64) -> u64 {
790 ((leaf_index >> constants::OUTPLUSONELOG) + 1) << constants::OUTPLUSONELOG
791 }
792
793 fn update_path_batched(
794 &mut self,
795 batch: &mut DBTransaction,
796 height: u32,
797 index: u64,
798 hash: Hash<P::Fr>,
799 temporary_leaves_count: u64,
800 ) {
801 let mut child_index = index;
802 let mut child_hash = hash;
803 let mut child_temporary_leaves_count = temporary_leaves_count;
804 for current_height in height + 1..=constants::HEIGHT as u32 {
806 let parent_index = child_index / 2;
807
808 let second_child_index = child_index ^ 1;
810
811 let pair = if child_index % 2 == 0 {
813 [child_hash, self.get(current_height - 1, second_child_index)]
814 } else {
815 [self.get(current_height - 1, second_child_index), child_hash]
816 };
817 let hash = poseidon(pair.as_ref(), self.params.compress());
818
819 let second_child_temporary_leaves_count =
821 self.get_temporary_count(current_height - 1, second_child_index);
822 let parent_temporary_leaves_count =
823 child_temporary_leaves_count + second_child_temporary_leaves_count;
824
825 self.set_batched(
826 batch,
827 current_height,
828 parent_index,
829 hash,
830 parent_temporary_leaves_count,
831 );
832
833 child_index = parent_index;
840 child_hash = hash;
841 child_temporary_leaves_count = parent_temporary_leaves_count;
842 }
843 }
844
845 fn set_batched(
846 &mut self,
847 batch: &mut DBTransaction,
848 height: u32,
849 index: u64,
850 hash: Hash<P::Fr>,
851 temporary_leaves_count: u64,
852 ) {
853 let key = Self::node_key(height, index);
854 if hash != self.zero_note_hashes[height as usize] {
855 batch.put(DbCols::Leaves as u32, &key, &hash.try_to_vec().unwrap());
856 } else {
857 batch.delete(DbCols::Leaves as u32, &key);
858 }
859 if temporary_leaves_count > 0 {
860 batch.put(
861 DbCols::TempLeaves as u32,
862 &key,
863 &temporary_leaves_count.to_be_bytes(),
864 );
865 } else if self
866 .db
867 .has_key(DbCols::TempLeaves as u32, &key)
868 .unwrap_or(false)
869 {
870 batch.delete(DbCols::TempLeaves as u32, &key);
871 }
872 }
873
874 fn remove_batched(&mut self, batch: &mut DBTransaction, height: u32, index: u64) {
875 let key = Self::node_key(height, index);
876 batch.delete(DbCols::Leaves as u32, &key);
877 batch.delete(DbCols::TempLeaves as u32, &key);
878 }
879
880 fn remove_leaf(&mut self, index: u64) {
881 let mut batch = self.db.transaction();
882
883 self.remove_batched(&mut batch, 0, index);
884 self.update_path_batched(&mut batch, 0, index, self.default_hashes[0], 0);
885
886 self.db.write(batch).unwrap();
887 }
888
889 fn get_clean_index(&self) -> u64 {
890 match self.get_named_index_opt("clean_index") {
891 Some(val) => val,
892 _ => 0,
893 }
894 }
895
896 fn set_clean_index_batched(&mut self, batch: &mut DBTransaction, value: u64) {
897 self.set_named_index_batched(batch, "clean_index", value);
898 }
899
900 fn get_named_index_opt(&self, key: &str) -> Option<u64> {
901 let res = self.db.get(2, key.as_bytes());
902 match res {
903 Ok(Some(ref val)) => Some((&val[..]).read_u64::<BigEndian>().unwrap()),
904 _ => None,
905 }
906 }
907
908 fn set_named_index_batched(&mut self, batch: &mut DBTransaction, key: &str, value: u64) {
909 batch.put(
910 DbCols::NamedIndex as u32,
911 key.as_bytes(),
912 &value.to_be_bytes(),
913 );
914 }
915
916 fn get_temporary_count(&self, height: u32, index: u64) -> u64 {
917 match self.get_temporary_count_opt(height, index) {
918 Some(val) => val,
919 _ => 0,
920 }
921 }
922
923 fn get_temporary_count_opt(&self, height: u32, index: u64) -> Option<u64> {
924 assert!(height <= constants::HEIGHT as u32);
925
926 let key = Self::node_key(height, index);
927 let res = self.db.get(1, &key);
928
929 match res {
930 Ok(Some(ref val)) => Some((&val[..]).read_u64::<BigEndian>().unwrap()),
931 _ => None,
932 }
933 }
934
935 fn subtree_contains_only_temporary_leaves(&self, height: u32, index: u64) -> bool {
936 self.get_temporary_count(height, index) == (1 << height)
937 }
938
939 #[inline]
940 fn node_key(height: u32, index: u64) -> [u8; 12] {
941 let mut data = [0u8; 12];
942 {
943 let mut bytes = &mut data[..];
944 let _ = bytes.write_u32::<BigEndian>(height);
945 let _ = bytes.write_u64::<BigEndian>(index);
946 }
947
948 data
949 }
950
951 fn parse_node_key(data: &[u8]) -> (u32, u64) {
952 let mut bytes = data;
953 let height = bytes.read_u32::<BigEndian>().unwrap();
954 let index = bytes.read_u64::<BigEndian>().unwrap();
955
956 (height, index)
957 }
958
959 fn build_node(key: &[u8], value: &[u8]) -> Node<P::Fr> {
960 let (height, index) = Self::parse_node_key(key);
961 let value = Hash::try_from_slice(value).unwrap();
962
963 Node {
964 index,
965 height,
966 value,
967 }
968 }
969
970 fn gen_default_hashes(params: &P) -> Vec<Hash<P::Fr>> {
971 let mut default_hashes = vec![Num::ZERO; constants::HEIGHT + 1];
972
973 Self::fill_default_hashes(&mut default_hashes, params);
974
975 default_hashes
976 }
977
978 fn gen_empty_note_hashes(params: &P) -> Vec<Hash<P::Fr>> {
979 let empty_note_hash = zero_note().hash(params);
980
981 let mut empty_note_hashes = vec![empty_note_hash; constants::HEIGHT + 1];
982
983 Self::fill_default_hashes(&mut empty_note_hashes, params);
984
985 empty_note_hashes
986 }
987
988 fn fill_default_hashes(default_hashes: &mut Vec<Hash<P::Fr>>, params: &P) {
989 for i in 1..default_hashes.len() {
990 let t = default_hashes[i - 1];
991 default_hashes[i] = poseidon([t, t].as_ref(), params.compress());
992 }
993 }
994}
995
996#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
997pub struct Node<F: PrimeField> {
998 pub index: u64,
999 pub height: u32,
1000 #[serde(bound(serialize = "", deserialize = ""))]
1001 pub value: Num<F>,
1002}
1003
/// Leaf ranges that steer the virtual-node computations of [`MerkleTree`].
///
/// Both ranges are half-open `[left, right)` in leaf indices:
/// - nodes entirely outside the updated range are read from storage;
/// - nodes entirely inside the updated range but outside the new-hashes range are
///   zero-note subtrees;
/// - all other nodes come from (or are computed into) the virtual-node map.
#[derive(Debug, Clone)]
pub struct UpdateBoundaries {
    updated_range_left_index: u64,
    updated_range_right_index: u64,
    new_hashes_left_index: u64,
    new_hashes_right_index: u64,
}
1010
1011#[cfg(test)]
1012mod tests {
1013 use std::sync::atomic::AtomicUsize;
1014
1015 #[cfg(not(feature = "native"))]
1016 use kvdb_memorydb::InMemory as Database;
1017 #[cfg(feature = "native")]
1018 use kvdb_persy::PersyDatabase as Database;
1019 use libzeropool::{
1020 fawkes_crypto::ff_uint::rand::Rng,
1021 native::{params::PoolBN256, tx},
1022 POOL_PARAMS,
1023 };
1024 use rand::{seq::SliceRandom, thread_rng};
1025 use test_case::test_case;
1026
1027 use super::*;
1028 use crate::random::CustomRng;
1029
1030 struct TestContext {
1031 tree: MerkleTree<Database, PoolBN256>,
1032 #[cfg(feature = "native")]
1033 db_path: String,
1034 }
1035
#[cfg(feature = "native")]
impl Drop for TestContext {
    /// Deletes the Persy file backing this context.
    ///
    /// A failed removal only panics when no panic is already in flight: panicking inside
    /// `drop` during unwinding aborts the test binary and hides the original failure.
    fn drop(&mut self) {
        if let Err(err) = std::fs::remove_file(&self.db_path) {
            if !std::thread::panicking() {
                panic!("failed to remove test database {}: {}", self.db_path, err);
            }
        }
    }
}
1042
/// Opens a tree on a fresh, uniquely numbered Persy file in the working directory.
#[cfg(feature = "native")]
fn init() -> TestContext {
    static FILE_COUNTER: AtomicUsize = AtomicUsize::new(0);

    let id = FILE_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
    let db_path = format!("merkle-test-{}.persy", id);
    let tree = MerkleTree::new_native(&db_path, POOL_PARAMS.clone()).unwrap();

    TestContext { tree, db_path }
}
1055
1056 #[cfg(not(feature = "native"))]
1057 fn init() -> TestContext {
1058 let db = kvdb_memorydb::create(3);
1059
1060 TestContext {
1061 tree: MerkleTree::new(db, POOL_PARAMS.clone()),
1062 }
1063 }
1064
#[test]
fn test_add_hashes_first_3() {
    let mut rng = CustomRng;
    let tree = &mut init().tree;

    let hashes: Vec<_> = (0..3).map(|_| rng.gen()).collect();
    tree.add_hashes(0, hashes.clone());

    assert_eq!(tree.get_all_nodes().len(), constants::HEIGHT + 4);

    // The leftmost node of every level below the root must be stored.
    for height in 0..constants::HEIGHT as u32 {
        assert!(tree.get_opt(height, 0).is_some());
    }

    for (index, hash) in (0u64..).zip(hashes) {
        assert_eq!(tree.get(0, index), hash);
    }
}
1083
#[test]
fn test_add_hashes_last_3() {
    let mut rng = CustomRng;
    let tree = &mut init().tree;

    let max_index: u64 = (1 << constants::HEIGHT) - 1;
    let start_index = max_index - 127;

    let hashes: Vec<_> = (0..3).map(|_| rng.gen()).collect();
    tree.add_hashes(start_index, hashes.clone());

    assert_eq!(tree.get_all_nodes().len(), constants::HEIGHT + 4);

    // Every ancestor of the last leaf above the block level must be stored.
    for height in constants::OUTPLUSONELOG as u32 + 1..constants::HEIGHT as u32 {
        assert!(tree.get_opt(height, max_index >> height).is_some());
    }

    for (offset, hash) in (0u64..).zip(hashes) {
        assert_eq!(tree.get(0, start_index + offset), hash);
    }
}
1105
#[test]
fn test_add_hashes() {
    let tree_expected = &mut init().tree;
    let tree_actual = &mut init().tree;

    for &(start_index, count) in &[(0, 3), (128, 8), (256, 1)] {
        add_hashes_to_test_trees(tree_expected, tree_actual, start_index, count);
        check_trees_are_equal(tree_expected, tree_actual);
    }
}
1123
#[test]
fn test_add_hashes_with_gap() {
    let tree_expected = &mut init().tree;
    let tree_actual = &mut init().tree;

    add_hashes_to_test_trees(tree_expected, tree_actual, 0, 3);
    check_trees_are_equal(tree_expected, tree_actual);

    // Mark block 1 as an empty (zero-note) block in the expected tree only.
    let zero_block = tree_expected.zero_note_hashes[constants::OUTPLUSONELOG];
    tree_expected.add_hash_at_height(constants::OUTPLUSONELOG as u32, 1, zero_block, false);

    add_hashes_to_test_trees(tree_expected, tree_actual, 256, 7);
    check_trees_are_equal(tree_expected, tree_actual);
}
1144
1145 fn add_hashes_to_test_trees<D: KeyValueDB, P: PoolParams>(
1146 tree_expected: &mut MerkleTree<D, P>,
1147 tree_actual: &mut MerkleTree<D, P>,
1148 start_index: u64,
1149 count: u64,
1150 ) {
1151 let mut rng = CustomRng;
1152
1153 let hashes: Vec<_> = (0..count).map(|_| rng.gen()).collect();
1154
1155 for (i, hash) in hashes.clone().into_iter().enumerate() {
1156 tree_expected.add_hash(start_index + i as u64, hash, false);
1157 }
1158 tree_actual.add_hashes(start_index, hashes);
1159 }
1160
1161 fn check_trees_are_equal<D: KeyValueDB, P: PoolParams>(
1162 tree_first: &MerkleTree<D, P>,
1163 tree_second: &MerkleTree<D, P>,
1164 ) {
1165 assert_eq!(tree_first.next_index, tree_second.next_index);
1166 assert_eq!(tree_first.get_root(), tree_second.get_root());
1167
1168 let mut first_nodes = tree_first.get_all_nodes();
1169 let mut second_nodes = tree_second.get_all_nodes();
1170 assert_eq!(first_nodes.len(), second_nodes.len());
1171
1172 first_nodes.sort_by_key(|node| (node.height, node.index));
1173 second_nodes.sort_by_key(|node| (node.height, node.index));
1174
1175 assert_eq!(first_nodes, second_nodes);
1176 }
1177
#[test]
fn test_get_leaf_proof() {
    let mut rng = CustomRng;
    let tree = &mut init().tree;

    // Nothing has been written at index 123 yet, so no proof is returned.
    assert!(tree.get_leaf_proof(123).is_none());

    tree.add_hash(123, rng.gen(), false);
    let proof = tree
        .get_leaf_proof(123)
        .expect("a proof must exist once the leaf has been added");

    // A leaf proof spans the full tree height.
    assert_eq!(proof.sibling.as_slice().len(), constants::HEIGHT);
    assert_eq!(proof.path.as_slice().len(), constants::HEIGHT);
}
1219
#[test]
fn test_get_proof_unchecked() {
    let mut rng = CustomRng;
    let tree = &mut init().tree;

    const SUBROOT_HEIGHT: usize = 1;
    // Height whose default / zero-note hash the proof's top sibling is compared against.
    let sibling_height = constants::HEIGHT - SUBROOT_HEIGHT;

    // Empty tree: the sibling equals the default hash at that height.
    let proof = tree.get_proof_unchecked::<SUBROOT_HEIGHT>(1);
    assert_eq!(
        proof.sibling[SUBROOT_HEIGHT - 1],
        tree.default_hashes[sibling_height]
    );

    assert_eq!(proof.sibling.as_slice().len(), SUBROOT_HEIGHT);
    assert_eq!(proof.path.as_slice().len(), SUBROOT_HEIGHT);

    // After a leaf is written at `1 << 47`, the sibling must be the zero-note hash.
    tree.add_hash(1 << 47, rng.gen(), false);
    let proof = tree.get_proof_unchecked::<SUBROOT_HEIGHT>(1);
    assert_eq!(
        proof.sibling[SUBROOT_HEIGHT - 1],
        tree.zero_note_hashes[sibling_height]
    );

    // Once a real leaf lands at `(1 << 47) - 1`, the sibling must no longer be the
    // zero-note hash.
    tree.add_hash((1 << 47) - 1, rng.gen(), false);
    let proof = tree.get_proof_unchecked::<SUBROOT_HEIGHT>(1);
    assert_ne!(
        proof.sibling[SUBROOT_HEIGHT - 1],
        tree.zero_note_hashes[sibling_height]
    );
}
1253
#[test]
fn test_temporary_nodes_are_used_to_calculate_hashes_first() {
    let mut rng = CustomRng;
    let tree = &mut init().tree;

    let hash0: Hash<_> = rng.gen();
    let hash1: Hash<_> = rng.gen();

    // Leaf 0 is added with the third argument `true` (the temporary flag, going by this
    // test's name) and leaf 1 with `false`.
    tree.add_hash(0, hash0, true);
    tree.add_hash(1, hash1, false);

    // The parent at height 1 must be the Poseidon hash of both leaves, i.e. the
    // temporary leaf 0 takes part in the computation.
    let parent_hash = tree.get(1, 0);
    let expected_parent_hash = poseidon([hash0, hash1].as_ref(), POOL_PARAMS.compress());

    assert_eq!(parent_hash, expected_parent_hash);
}
1273
1274 #[test_case(0, 5)]
1275 #[test_case(1, 5)]
1276 #[test_case(2, 5)]
1277 #[test_case(4, 5)]
1278 #[test_case(5, 5)]
1279 #[test_case(5, 8)]
1280 #[test_case(10, 15)]
1281 #[test_case(12, 15)]
1282 fn test_all_temporary_nodes_in_subtree_are_removed(subtree_height: u32, full_height: usize) {
1283 let mut rng = CustomRng;
1284
1285 let subtree_size = 1 << subtree_height;
1286 let subtrees_count = (1 << full_height) / subtree_size;
1287 let start_index = 1 << 12;
1288 let mut subtree_indexes: Vec<_> = (0..subtrees_count).map(|i| start_index + i).collect();
1289 subtree_indexes.shuffle(&mut thread_rng());
1290
1291 let mut tree = &mut init().tree;
1292 for subtree_index in subtree_indexes {
1293 tree.add_subtree_root(subtree_height, subtree_index, rng.gen());
1294 }
1295
1296 tree.clean();
1297
1298 let tree_nodes = tree.get_all_nodes();
1299 assert_eq!(
1300 tree_nodes.len(),
1301 constants::HEIGHT - full_height + 1,
1302 "Some temporary subtree nodes were not removed."
1303 );
1304 }
1305
#[test]
fn test_rollback_all_works_correctly() {
    // Number of leaves written before rolling everything back.
    let added_count: u64 = 24;

    let mut rng = CustomRng;
    let tree = &mut init().tree;

    let original_root = tree.get_root();

    for index in 0..added_count {
        tree.add_hash(index, rng.gen(), false);
    }

    // Rolling back to index 0 must restore the empty tree's root and next index.
    assert!(tree.rollback(0).is_none());
    assert_eq!(tree.get_root(), original_root);
    assert_eq!(tree.next_index, 0);
}
1326
1327 #[test_case(32, 16)]
1328 #[test_case(16, 0)]
1329 #[test_case(11, 7)]
1330 fn test_rollback_removes_nodes_correctly(keep_size: u64, remove_size: u64) {
1331 let mut rng = CustomRng;
1332 let mut tree = &mut init().tree;
1333
1334 for index in 0..keep_size {
1335 let leaf = rng.gen();
1336 tree.add_hash(index, leaf, false);
1337 }
1338 let original_root = tree.get_root();
1339
1340 for index in 0..remove_size {
1341 let leaf = rng.gen();
1342 tree.add_hash(128 + index, leaf, false);
1343 }
1344
1345 let rollback_result = tree.rollback(128);
1346 assert!(rollback_result.is_none());
1347 let rollback_root = tree.get_root();
1348 assert_eq!(rollback_root, original_root);
1349 assert_eq!(tree.next_index, 128);
1350 }
1351
#[test]
fn test_get_leaves() {
    let mut rng = CustomRng;
    let tree = &mut init().tree;

    let leaves_count = 6;

    for index in 0..leaves_count {
        tree.add_hash(index, rng.gen(), true);
    }

    // Every added leaf is returned exactly once (count matches, each index present).
    let leaves = tree.get_leaves();

    assert_eq!(leaves.len(), leaves_count as usize);
    for index in 0..leaves_count {
        assert!(leaves.iter().any(|node| node.index == index));
    }
}
1435
#[test]
fn test_get_leaves_after() {
    let mut rng = CustomRng;
    let tree = &mut init().tree;

    let leaves_count = 6;
    let skip_count = 2;

    for index in 0..leaves_count {
        tree.add_hash(index, rng.gen(), true);
    }

    // Only leaves with index >= `skip_count` are returned.
    let leaves = tree.get_leaves_after(skip_count);

    assert_eq!(leaves.len(), (leaves_count - skip_count) as usize);
    for index in skip_count..leaves_count {
        assert!(leaves.iter().any(|node| node.index == index));
    }
}
1456
#[test]
fn test_get_proof_after() {
    let mut rng = CustomRng;
    let tree = &mut init().tree;

    let tree_size = 6;
    let new_hashes_size = 3;

    for index in 0..tree_size {
        tree.add_hash(index, rng.gen(), false);
    }

    let root_before_call = tree.get_root();

    // Only the call's effect on the tree matters here; the proofs themselves are
    // deliberately discarded.
    let new_hashes: Vec<_> = (0..new_hashes_size).map(|_| rng.gen()).collect();
    let _proofs = tree.get_proof_after(new_hashes);

    // Computing proofs for hypothetical leaves must not change the stored tree.
    let root_after_call = tree.get_root();

    assert_eq!(root_before_call, root_after_call);
}
1479
1480 #[test_case(12, 4)]
1481 #[test_case(13, 5)]
1482 #[test_case(0, 1)]
1483 #[test_case(0, 5)]
1484 #[test_case(0, 8)]
1485 #[test_case(4, 16)]
1486 fn test_get_proof_after_virtual(tree_size: u64, new_hashes_size: u64) {
1487 let mut rng = CustomRng;
1488 let mut tree = &mut init().tree;
1489
1490 for index in 0..tree_size {
1491 let leaf = rng.gen();
1492 tree.add_hash(index, leaf, false);
1493 }
1494
1495 let new_hashes: Vec<_> = (0..new_hashes_size).map(|_| rng.gen()).collect();
1496
1497 let root_before_call = tree.get_root();
1498
1499 let proofs_virtual = tree.get_proof_after_virtual(new_hashes.clone());
1500 let proofs_simple = tree.get_proof_after(new_hashes);
1501
1502 let root_after_call = tree.get_root();
1503
1504 assert_eq!(root_before_call, root_after_call);
1505 assert_eq!(proofs_simple.len(), proofs_virtual.len());
1506 for (simple_proof, virtual_proof) in proofs_simple.iter().zip(proofs_virtual) {
1507 for (simple_sibling, virtual_sibling) in simple_proof
1508 .sibling
1509 .iter()
1510 .zip(virtual_proof.sibling.iter())
1511 {
1512 assert_eq!(simple_sibling, virtual_sibling);
1513 }
1514 for (simple_path, virtual_path) in
1515 simple_proof.path.iter().zip(virtual_proof.path.iter())
1516 {
1517 assert_eq!(simple_path, virtual_path);
1518 }
1519 }
1520 }
1521
#[test]
fn test_default_hashes_are_added_correctly() {
    let mut rng = CustomRng;
    let tree = &mut init().tree;

    // Empty tree: every node reads as the default hash for its height.
    assert_eq!(tree.get(0, 0), tree.default_hashes[0]);
    assert_eq!(tree.get(0, 3), tree.default_hashes[0]);
    assert_eq!(tree.get(2, 0), tree.default_hashes[2]);

    let hashes: Vec<_> = (0..3).map(|_| rng.gen()).collect();
    tree.add_hashes(0, hashes);

    // After writing leaves 0..3: the node covering them is neither zero-note nor
    // default. Unwritten positions up to leaf 127 read as zero-note hashes, and
    // positions from leaf 128 on still read as default hashes.
    assert_ne!(tree.get(2, 0), tree.zero_note_hashes[2]);
    assert_ne!(tree.get(2, 0), tree.default_hashes[2]);
    assert_eq!(tree.get(0, 4), tree.zero_note_hashes[0]);
    assert_eq!(tree.get(0, 127), tree.zero_note_hashes[0]);
    assert_eq!(tree.get(2, 1), tree.zero_note_hashes[2]);
    assert_eq!(tree.get(0, 128), tree.default_hashes[0]);
    assert_eq!(tree.get(7, 1), tree.default_hashes[7]);

    // Same pattern for a second batch starting at leaf 128.
    let hashes: Vec<_> = (0..2).map(|_| rng.gen()).collect();
    tree.add_hashes(128, hashes);
    assert_eq!(tree.get(0, 128 + 4), tree.zero_note_hashes[0]);
    assert_eq!(tree.get(0, 128 + 127), tree.zero_note_hashes[0]);
    assert_eq!(tree.get(2, 32 + 1), tree.zero_note_hashes[2]);
    assert_eq!(tree.get(0, 128 + 128), tree.default_hashes[0]);
    assert_eq!(tree.get(7, 2), tree.default_hashes[7]);
}
1556
1557 #[test_case(0, 0, 0.0)]
1558 #[test_case(1, 1, 0.0)]
1559 #[test_case(1, 1, 1.0)]
1560 #[test_case(4, 2, 0.0)]
1561 #[test_case(4, 2, 0.5)]
1562 #[test_case(4, 2, 1.0)]
1563 #[test_case(15, 7, 0.0)]
1564 #[test_case(15, 7, 0.5)]
1565 #[test_case(15, 7, 1.0)]
1566 fn test_add_leafs_and_commitments(
1567 tx_count: u64,
1568 max_leafs_count: u32,
1569 commitments_probability: f64,
1570 ) {
1571 let mut rng = CustomRng;
1572 let mut first_tree = &mut init().tree;
1573 let mut second_tree = &mut init().tree;
1574
1575 let leafs: Vec<(u64, Vec<_>)> = (0..tx_count)
1576 .map(|i| {
1577 let leafs_count: u32 = 1 + (rng.gen::<u32>() % max_leafs_count);
1578 (
1579 i * (constants::OUT + 1) as u64,
1580 (0..leafs_count).map(|_| rng.gen()).collect(),
1581 )
1582 })
1583 .collect();
1584
1585 let now = std::time::Instant::now();
1586 for (index, leafs) in leafs.clone().into_iter() {
1587 first_tree.add_hashes(index, leafs)
1588 }
1589 println!(
1590 "({}, {}, {}) add_hashes elapsed: {}",
1591 tx_count,
1592 max_leafs_count,
1593 commitments_probability,
1594 now.elapsed().as_millis()
1595 );
1596
1597 let commitments: Vec<(u64, _)> = leafs
1598 .clone()
1599 .into_iter()
1600 .map(|(index, leafs)| {
1601 let mut out_hashes = leafs.clone();
1602 out_hashes.resize(constants::OUT + 1, first_tree.zero_note_hashes[0]);
1603 let commitment =
1604 tx::out_commitment_hash(out_hashes.as_slice(), &POOL_PARAMS.clone());
1605 (index, commitment)
1606 })
1607 .collect();
1608
1609 commitments.iter().for_each(|(index, commitment)| {
1610 assert_eq!(
1611 first_tree.get(
1612 constants::OUTPLUSONELOG as u32,
1613 *index >> constants::OUTPLUSONELOG
1614 ),
1615 *commitment
1616 );
1617 });
1618
1619 let mut sub_leafs: Vec<(u64, Vec<_>)> = Vec::new();
1620 let mut sub_commitments: Vec<(u64, _)> = Vec::new();
1621 (0..tx_count).for_each(|i| {
1622 if rng.gen_bool(commitments_probability) {
1623 sub_commitments.push(commitments[i as usize]);
1624 } else {
1625 sub_leafs.push((leafs[i as usize].0, leafs[i as usize].1.clone()));
1626 }
1627 });
1628
1629 let now = std::time::Instant::now();
1630 second_tree.add_leafs_and_commitments(sub_leafs, sub_commitments);
1631 println!(
1632 "({}, {}, {}) add_leafs_and_commitments elapsed: {}",
1633 tx_count,
1634 max_leafs_count,
1635 commitments_probability,
1636 now.elapsed().as_millis()
1637 );
1638
1639 assert_eq!(
1640 first_tree.get_root().to_string(),
1641 second_tree.get_root().to_string()
1642 );
1643 assert_eq!(first_tree.next_index(), second_tree.next_index());
1644 }
1645
1646 #[test_case(0, 0, 0.0)]
1647 #[test_case(1, 1, 0.0)]
1648 #[test_case(1, 1, 1.0)]
1649 #[test_case(4, 2, 0.0)]
1650 #[test_case(4, 2, 0.5)]
1651 #[test_case(4, 2, 1.0)]
1652 #[test_case(15, 7, 0.0)]
1653 #[test_case(15, 7, 0.5)]
1654 #[test_case(15, 7, 1.0)]
1655 fn test_get_root_optimistic(tx_count: u64, max_leafs_count: u32, commitments_probability: f64) {
1656 let mut rng = CustomRng;
1657 let mut tree = &mut init().tree;
1658
1659 let leafs: Vec<(u64, Vec<_>)> = (0..tx_count)
1660 .map(|i| {
1661 let leafs_count: u32 = 1 + (rng.gen::<u32>() % max_leafs_count);
1662 (
1663 i * (constants::OUT + 1) as u64,
1664 (0..leafs_count).map(|_| rng.gen()).collect(),
1665 )
1666 })
1667 .collect();
1668
1669 let commitments: Vec<(u64, _)> = leafs
1670 .clone()
1671 .into_iter()
1672 .map(|(index, leafs)| {
1673 let mut out_hashes = leafs.clone();
1674 out_hashes.resize(constants::OUT + 1, tree.zero_note_hashes[0]);
1675 let commitment =
1676 tx::out_commitment_hash(out_hashes.as_slice(), &POOL_PARAMS.clone());
1677 (index, commitment)
1678 })
1679 .collect();
1680
1681 let mut sub_leafs: Vec<(u64, Vec<_>)> = Vec::new();
1682 let mut sub_commitments: Vec<(u64, _)> = Vec::new();
1683 (0..tx_count).for_each(|i| {
1684 if rng.gen_bool(commitments_probability) {
1685 sub_commitments.push(commitments[i as usize]);
1686 } else {
1687 sub_leafs.push((leafs[i as usize].0, leafs[i as usize].1.clone()));
1688 }
1689 });
1690
1691 let (mut virtual_nodes, update_boundaries) =
1692 tree.get_virtual_subtree(sub_leafs.clone(), sub_commitments.clone());
1693 let optimistic_root = tree.get_root_optimistic(&mut virtual_nodes, &update_boundaries);
1694
1695 tree.add_leafs_and_commitments(sub_leafs, sub_commitments);
1696 let root = tree.get_root();
1697
1698 assert_eq!(optimistic_root.to_string(), root.to_string());
1699 }
1700}