light_indexed_merkle_tree/
lib.rs

use std::{
    fmt,
    marker::PhantomData,
    mem,
    ops::{Deref, DerefMut},
};

use array::{IndexedArray, IndexedElement};
use changelog::IndexedChangelogEntry;
use light_bounded_vec::{BoundedVec, CyclicBoundedVec, CyclicBoundedVecMetadata};
use light_concurrent_merkle_tree::{
    errors::ConcurrentMerkleTreeError,
    event::{IndexedMerkleTreeUpdate, RawIndexedElement},
    light_hasher::Hasher,
    ConcurrentMerkleTree,
};
use light_hasher::bigint::bigint_to_be_bytes_array;
use num_bigint::BigUint;
use num_traits::{CheckedAdd, CheckedSub, ToBytes, Unsigned};

pub mod array;
pub mod changelog;
pub mod copy;
pub mod errors;
pub mod reference;
pub mod zero_copy;

use crate::errors::IndexedMerkleTreeError;

pub const HIGHEST_ADDRESS_PLUS_ONE: &str =
    "452312848583266388373324160190187140051835877600158453279131187530910662655";

#[derive(Debug)]
#[repr(C)]
pub struct IndexedMerkleTree<H, I, const HEIGHT: usize, const NET_HEIGHT: usize>
where
    H: Hasher,
    I: CheckedAdd
        + CheckedSub
        + Copy
        + Clone
        + fmt::Debug
        + PartialOrd
        + ToBytes
        + TryFrom<usize>
        + Unsigned,
    usize: From<I>,
{
    pub merkle_tree: ConcurrentMerkleTree<H, HEIGHT>,
    pub indexed_changelog: CyclicBoundedVec<IndexedChangelogEntry<I, NET_HEIGHT>>,

    _index: PhantomData<I>,
}

pub type IndexedMerkleTree26<H, I> = IndexedMerkleTree<H, I, 26, 16>;

impl<H, I, const HEIGHT: usize, const NET_HEIGHT: usize> IndexedMerkleTree<H, I, HEIGHT, NET_HEIGHT>
where
    H: Hasher,
    I: CheckedAdd
        + CheckedSub
        + Copy
        + Clone
        + fmt::Debug
        + PartialOrd
        + ToBytes
        + TryFrom<usize>
        + Unsigned,
    usize: From<I>,
{
    /// Size of the struct **without** dynamically sized fields (`BoundedVec`,
    /// `CyclicBoundedVec`).
    pub fn non_dyn_fields_size() -> usize {
        ConcurrentMerkleTree::<H, HEIGHT>::non_dyn_fields_size()
            // indexed_changelog (metadata)
            + mem::size_of::<CyclicBoundedVecMetadata>()
    }

    // TODO(vadorovsky): Make a macro for that.
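    /// Number of bytes needed to store an `IndexedMerkleTree` with the given
    /// parameters in an account.
    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); the hasher and the
    /// parameter values below are assumptions for illustration, not defaults
    /// of this crate:
    ///
    /// ```ignore
    /// use light_hasher::Poseidon;
    ///
    /// let account_size = IndexedMerkleTree26::<Poseidon, usize>::size_in_account(
    ///     26,   // height
    ///     1400, // changelog_size
    ///     2400, // roots_size
    ///     10,   // canopy_depth
    ///     256,  // indexed_changelog_size
    /// );
    /// ```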
    pub fn size_in_account(
        height: usize,
        changelog_size: usize,
        roots_size: usize,
        canopy_depth: usize,
        indexed_changelog_size: usize,
    ) -> usize {
        ConcurrentMerkleTree::<H, HEIGHT>::size_in_account(
            height,
            changelog_size,
            roots_size,
            canopy_depth,
        )
        // indexed_changelog (metadata)
        + mem::size_of::<CyclicBoundedVecMetadata>()
        // indexed_changelog
        + mem::size_of::<IndexedChangelogEntry<I, NET_HEIGHT>>() * indexed_changelog_size
    }

    pub fn new(
        height: usize,
        changelog_size: usize,
        roots_size: usize,
        canopy_depth: usize,
        indexed_changelog_size: usize,
    ) -> Result<Self, ConcurrentMerkleTreeError> {
        let merkle_tree = ConcurrentMerkleTree::<H, HEIGHT>::new(
            height,
            changelog_size,
            roots_size,
            canopy_depth,
        )?;
        Ok(Self {
            merkle_tree,
            indexed_changelog: CyclicBoundedVec::with_capacity(indexed_changelog_size),
            _index: PhantomData,
        })
    }

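    /// Initializes the underlying Merkle tree, appends the first low leaf
    /// (value 0, pointing to no other leaf yet) and seeds the indexed
    /// changelog with the corresponding entries.
    ///
    /// Must be called on a freshly constructed tree before any other
    /// operation; it is typically followed by
    /// [`add_highest_element`](Self::add_highest_element).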
    pub fn init(&mut self) -> Result<(), IndexedMerkleTreeError> {
        self.merkle_tree.init()?;

        // Append the first low leaf, which has value 0 and does not point
        // to any other leaf yet.
        // This low leaf is going to be updated during the first `update`
        // operation.
        self.merkle_tree.append(&H::zero_indexed_leaf())?;

        // Emit first changelog entries.
        let element = RawIndexedElement {
            value: [0_u8; 32],
            next_index: I::zero(),
            next_value: [0_u8; 32],
            index: I::zero(),
        };
        let changelog_entry = IndexedChangelogEntry {
            element,
            proof: H::zero_bytes()[..NET_HEIGHT].try_into().unwrap(),
            changelog_index: 0,
        };
        self.indexed_changelog.push(changelog_entry.clone());
        self.indexed_changelog.push(changelog_entry);

        Ok(())
    }

    /// Add the highest element, with the maximum value allowed by the prime
    /// field.
    ///
    /// Initializing an indexed Merkle tree not only with the lowest element
    /// (mandatory for the IMT algorithm to work), but also with the highest
    /// element, makes non-inclusion proofs easier: no special case is needed
    /// for the first insertion.
    ///
    /// The tradeoff is that the number of leaves available in the tree
    /// decreases by one.
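    ///
    /// # Example
    ///
    /// A sketch of the usual initialization sequence (not compiled as a
    /// doctest; the hasher and parameter values are illustrative assumptions):
    ///
    /// ```ignore
    /// use light_hasher::Poseidon;
    ///
    /// let mut tree = IndexedMerkleTree26::<Poseidon, usize>::new(26, 1400, 2400, 10, 256)?;
    /// tree.init()?;
    /// tree.add_highest_element()?;
    /// // The tree now holds two leaves: the zero leaf and the highest element.
    /// assert_eq!(tree.next_index(), 2);
    /// ```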
    pub fn add_highest_element(&mut self) -> Result<(), IndexedMerkleTreeError> {
        let mut indexed_array = IndexedArray::<H, I>::default();
        let element_bundle = indexed_array.init()?;
        let new_low_leaf = element_bundle
            .new_low_element
            .hash::<H>(&element_bundle.new_element.value)?;

        let mut proof = BoundedVec::with_capacity(self.merkle_tree.height);
        for i in 0..self.merkle_tree.height - self.merkle_tree.canopy_depth {
            // PANICS: Pushing into this bounded vec cannot panic, because it
            // was allocated with enough capacity.
            proof.push(H::zero_bytes()[i]).unwrap();
        }

        let (changelog_index, _) = self.merkle_tree.update(
            self.changelog_index(),
            &H::zero_indexed_leaf(),
            &new_low_leaf,
            0,
            &mut proof,
        )?;

        // Emit changelog for low element.
        let low_element = RawIndexedElement {
            value: bigint_to_be_bytes_array::<32>(&element_bundle.new_low_element.value)?,
            next_index: element_bundle.new_low_element.next_index,
            next_value: bigint_to_be_bytes_array::<32>(&element_bundle.new_element.value)?,
            index: element_bundle.new_low_element.index,
        };

        let low_element_changelog_entry = IndexedChangelogEntry {
            element: low_element,
            proof: H::zero_bytes()[..NET_HEIGHT].try_into().unwrap(),
            changelog_index,
        };
        self.indexed_changelog.push(low_element_changelog_entry);

        let new_leaf = element_bundle
            .new_element
            .hash::<H>(&element_bundle.new_element_next_value)?;
        let mut proof = BoundedVec::with_capacity(self.height);
        let (changelog_index, _) = self.merkle_tree.append_with_proof(&new_leaf, &mut proof)?;

        // Emit changelog for new element.
        let new_element = RawIndexedElement {
            value: bigint_to_be_bytes_array::<32>(&element_bundle.new_element.value)?,
            next_index: element_bundle.new_element.next_index,
            next_value: [0_u8; 32],
            index: element_bundle.new_element.index,
        };
        let new_element_changelog_entry = IndexedChangelogEntry {
            element: new_element,
            proof: proof.as_slice()[..NET_HEIGHT].try_into().unwrap(),
            changelog_index,
        };

        self.indexed_changelog.push(new_element_changelog_entry);

        Ok(())
    }

    pub fn indexed_changelog_index(&self) -> usize {
        self.indexed_changelog.last_index()
    }

    /// Checks whether the given Merkle `proof` for the given `leaf` (with
    /// index `leaf_index`) is valid. The proof is valid when computing parent
    /// node hashes along the whole proof path yields the tree's root.
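    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); `leaf`, `leaf_index` and
    /// `proof` are assumed to come from an off-chain source such as the
    /// reference implementation in [`crate::reference`]:
    ///
    /// ```ignore
    /// tree.validate_proof(&leaf, leaf_index, &proof)?;
    /// ```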
    pub fn validate_proof(
        &self,
        leaf: &[u8; 32],
        leaf_index: usize,
        proof: &BoundedVec<[u8; 32]>,
    ) -> Result<(), IndexedMerkleTreeError> {
        self.merkle_tree.validate_proof(leaf, leaf_index, proof)?;
        Ok(())
    }

    /// Iterates over the indexed changelog and, every time an entry
    /// corresponding to the provided `low_element` is found, patches:
    ///
    /// * Changelog index - indexed changelog entries contain corresponding
    ///   changelog indices.
    /// * New element - changes might impact the `next_index` field, which in
    ///   such case is updated.
    /// * Low element - it might completely change if a change introduced an
    ///   element in our range.
    /// * Merkle proof.
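    ///
    /// If one of those changes introduced an element inside the range of
    /// `low_element`, that element becomes the new low element and the
    /// patching restarts from its changelog entry (the method recurses).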
    #[allow(clippy::type_complexity)]
    pub fn patch_elements_and_proof(
        &mut self,
        indexed_changelog_index: usize,
        changelog_index: &mut usize,
        new_element: &mut IndexedElement<I>,
        low_element: &mut IndexedElement<I>,
        low_element_next_value: &mut BigUint,
        low_leaf_proof: &mut BoundedVec<[u8; 32]>,
    ) -> Result<(), IndexedMerkleTreeError> {
        let next_indexed_changelog_indices: Vec<usize> = self
            .indexed_changelog
            .iter_from(indexed_changelog_index)?
            .skip(1)
            .enumerate()
            .filter_map(|(index, changelog_entry)| {
                if changelog_entry.element.index == low_element.index {
                    Some((indexed_changelog_index + 1 + index) % self.indexed_changelog.len())
                } else {
                    None
                }
            })
            .collect();

        let mut new_low_element = None;

        for next_indexed_changelog_index in next_indexed_changelog_indices {
            let changelog_entry = &mut self.indexed_changelog[next_indexed_changelog_index];

            let next_element_value = BigUint::from_bytes_be(&changelog_entry.element.next_value);
            if next_element_value < new_element.value {
                // If the next element is lower than the element we are
                // inserting, it should become the new low element.
                //
                // Save it and break the loop.
                new_low_element = Some((
                    (next_indexed_changelog_index + 1) % self.indexed_changelog.len(),
                    next_element_value,
                ));
                break;
            }

            // Patch the changelog index.
            *changelog_index = changelog_entry.changelog_index;

            // Patch the `next_index` of `new_element`.
            new_element.next_index = changelog_entry.element.next_index;

            // Patch the element.
            low_element.update_from_raw_element(&changelog_entry.element);
            // Patch the next value.
            *low_element_next_value = BigUint::from_bytes_be(&changelog_entry.element.next_value);
            // Patch the proof.
            for i in 0..low_leaf_proof.len() {
                low_leaf_proof[i] = changelog_entry.proof[i];
            }
        }

        // If we found a new low element.
        if let Some((new_low_element_changelog_index, new_low_element)) = new_low_element {
            let new_low_element_changelog_entry =
                &self.indexed_changelog[new_low_element_changelog_index];
            *changelog_index = new_low_element_changelog_entry.changelog_index;
            *low_element = IndexedElement {
                index: new_low_element_changelog_entry.element.index,
                value: new_low_element.clone(),
                next_index: new_low_element_changelog_entry.element.next_index,
            };

            for i in 0..low_leaf_proof.len() {
                low_leaf_proof[i] = new_low_element_changelog_entry.proof[i];
            }
            new_element.next_index = low_element.next_index;

            // Start the patching process from scratch for the new low element.
            return self.patch_elements_and_proof(
                new_low_element_changelog_index,
                changelog_index,
                new_element,
                low_element,
                low_element_next_value,
                low_leaf_proof,
            );
        }

        Ok(())
    }

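    /// Inserts `new_element_value` into the indexed Merkle tree:
    ///
    /// 1. Patches `low_element`, its next value and its proof against the
    ///    indexed changelog (to account for concurrent updates).
    /// 2. Checks that `new_element_value` falls inside the range of the
    ///    (patched) low element.
    /// 3. Updates the low leaf to point to the new element and appends the
    ///    new leaf, emitting an indexed changelog entry for each.
    ///
    /// # Example
    ///
    /// A rough sketch (not compiled as a doctest); how `low_element`,
    /// `low_element_next_value` and `low_leaf_proof` are obtained (e.g. from
    /// an off-chain `IndexedArray` plus a reference tree) is an assumption
    /// here:
    ///
    /// ```ignore
    /// let update = tree.update(
    ///     changelog_index,
    ///     indexed_changelog_index,
    ///     new_element_value,
    ///     low_element,
    ///     low_element_next_value,
    ///     &mut low_leaf_proof,
    /// )?;
    /// ```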
    pub fn update(
        &mut self,
        mut changelog_index: usize,
        indexed_changelog_index: usize,
        new_element_value: BigUint,
        mut low_element: IndexedElement<I>,
        mut low_element_next_value: BigUint,
        low_leaf_proof: &mut BoundedVec<[u8; 32]>,
    ) -> Result<IndexedMerkleTreeUpdate<I>, IndexedMerkleTreeError> {
        let mut new_element = IndexedElement {
            index: I::try_from(self.merkle_tree.next_index())
                .map_err(|_| IndexedMerkleTreeError::IntegerOverflow)?,
            value: new_element_value,
            next_index: low_element.next_index,
        };

        self.patch_elements_and_proof(
            indexed_changelog_index,
            &mut changelog_index,
            &mut new_element,
            &mut low_element,
            &mut low_element_next_value,
            low_leaf_proof,
        )?;

        // Check that the value of `new_element` belongs to the range
        // of `low_element`.
        if low_element.next_index == I::zero() {
            // In this case, `low_element` is the greatest element.
            // The value of `new_element` needs to be greater than the value of
            // `low_element` (and therefore become the greatest).
            if new_element.value <= low_element.value {
                return Err(IndexedMerkleTreeError::LowElementGreaterOrEqualToNewElement);
            }
        } else {
            // The value of `new_element` needs to be greater than the value of
            // `low_element`.
            if new_element.value <= low_element.value {
                return Err(IndexedMerkleTreeError::LowElementGreaterOrEqualToNewElement);
            }
            // The value of `new_element` needs to be lower than the value of
            // the next element pointed to by `low_element`.
            if new_element.value >= low_element_next_value {
                return Err(IndexedMerkleTreeError::NewElementGreaterOrEqualToNextElement);
            }
        }
        // Instantiate `new_low_element` - the low element with updated values.
        let new_low_element = IndexedElement::<I> {
            index: low_element.index,
            value: low_element.value.clone(),
            next_index: new_element.index,
        };
        // Update the low leaf. If `low_element` does not belong to the tree,
        // validating the proof is going to fail.
        let old_low_leaf = low_element.hash::<H>(&low_element_next_value)?;

        let new_low_leaf = new_low_element.hash::<H>(&new_element.value)?;

        let (new_changelog_index, _) = self.merkle_tree.update(
            changelog_index,
            &old_low_leaf,
            &new_low_leaf,
            low_element.index.into(),
            low_leaf_proof,
        )?;

        // Emit changelog entry for low element.
        let new_low_element = RawIndexedElement {
            value: bigint_to_be_bytes_array::<32>(&new_low_element.value)?,
            next_index: new_low_element.next_index,
            next_value: bigint_to_be_bytes_array::<32>(&new_element.value)?,
            index: new_low_element.index,
        };
        let low_element_changelog_entry = IndexedChangelogEntry {
            element: new_low_element,
            proof: low_leaf_proof.as_slice()[..NET_HEIGHT].try_into().unwrap(),
            changelog_index: new_changelog_index,
        };

        self.indexed_changelog.push(low_element_changelog_entry);

        // The new element is always the newest one in the tree. Since we
        // support concurrent updates, the index provided by the caller
        // might be outdated, so use the latest index indicated by the tree.
        new_element.index =
            I::try_from(self.next_index()).map_err(|_| IndexedMerkleTreeError::IntegerOverflow)?;

        // Append the new element.
        let mut proof = BoundedVec::with_capacity(self.height);
        let new_leaf = new_element.hash::<H>(&low_element_next_value)?;
        let (new_changelog_index, _) = self.merkle_tree.append_with_proof(&new_leaf, &mut proof)?;

        // Prepare the raw new element to save in the changelog.
        let raw_new_element = RawIndexedElement {
            value: bigint_to_be_bytes_array::<32>(&new_element.value)?,
            next_index: new_element.next_index,
            next_value: bigint_to_be_bytes_array::<32>(&low_element_next_value)?,
            index: new_element.index,
        };

        // Emit changelog entry for new element.
        let new_element_changelog_entry = IndexedChangelogEntry {
            element: raw_new_element,
            proof: proof.as_slice()[..NET_HEIGHT].try_into().unwrap(),
            changelog_index: new_changelog_index,
        };
        self.indexed_changelog.push(new_element_changelog_entry);

        let output = IndexedMerkleTreeUpdate {
            new_low_element,
            new_low_element_hash: new_low_leaf,
            new_high_element: raw_new_element,
            new_high_element_hash: new_leaf,
        };

        Ok(output)
    }
}

impl<H, I, const HEIGHT: usize, const NET_HEIGHT: usize> Deref
    for IndexedMerkleTree<H, I, HEIGHT, NET_HEIGHT>
where
    H: Hasher,
    I: CheckedAdd
        + CheckedSub
        + Copy
        + Clone
        + fmt::Debug
        + PartialOrd
        + ToBytes
        + TryFrom<usize>
        + Unsigned,
    usize: From<I>,
{
    type Target = ConcurrentMerkleTree<H, HEIGHT>;

    fn deref(&self) -> &Self::Target {
        &self.merkle_tree
    }
}

impl<H, I, const HEIGHT: usize, const NET_HEIGHT: usize> DerefMut
    for IndexedMerkleTree<H, I, HEIGHT, NET_HEIGHT>
where
    H: Hasher,
    I: CheckedAdd
        + CheckedSub
        + Copy
        + Clone
        + fmt::Debug
        + PartialOrd
        + ToBytes
        + TryFrom<usize>
        + Unsigned,
    usize: From<I>,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.merkle_tree
    }
}

impl<H, I, const HEIGHT: usize, const NET_HEIGHT: usize> PartialEq
    for IndexedMerkleTree<H, I, HEIGHT, NET_HEIGHT>
where
    H: Hasher,
    I: CheckedAdd
        + CheckedSub
        + Copy
        + Clone
        + fmt::Debug
        + PartialOrd
        + ToBytes
        + TryFrom<usize>
        + Unsigned,
    usize: From<I>,
{
    fn eq(&self, other: &Self) -> bool {
        self.merkle_tree.eq(&other.merkle_tree)
            && self
                .indexed_changelog
                .capacity()
                .eq(&other.indexed_changelog.capacity())
            && self
                .indexed_changelog
                .len()
                .eq(&other.indexed_changelog.len())
            && self
                .indexed_changelog
                .first_index()
                .eq(&other.indexed_changelog.first_index())
            && self
                .indexed_changelog
                .last_index()
                .eq(&other.indexed_changelog.last_index())
            && self.indexed_changelog.eq(&other.indexed_changelog)
    }
}
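
// A minimal usage sketch, kept as a test so it stays in sync with the API.
// The Poseidon hasher and all parameter values below are assumptions chosen
// for illustration; they are not defaults prescribed by this crate.
#[cfg(test)]
mod usage_sketch {
    use light_hasher::Poseidon;

    use super::*;

    #[test]
    fn init_and_add_highest_element() {
        // Height 26, canopy 10, so stored proofs have 26 - 10 = 16 nodes,
        // matching the `NET_HEIGHT` of `IndexedMerkleTree26`.
        let mut tree =
            IndexedMerkleTree26::<Poseidon, usize>::new(26, 1400, 2400, 10, 256).unwrap();
        tree.init().unwrap();
        tree.add_highest_element().unwrap();

        // After `init` (zero leaf) and `add_highest_element` (highest leaf),
        // the tree contains two leaves.
        assert_eq!(tree.next_index(), 2);
    }
}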