mmr_crypto_primitives/merkle_tree/
mod.rs

1#![allow(clippy::needless_range_loop)]
2
3/// Defines a trait to chain two types of CRHs.
4use crate::crh::TwoToOneCRHScheme;
5use crate::{CRHScheme, Error};
6use ark_ff::ToBytes;
7use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, SerializationError, Write};
8use ark_std::borrow::Borrow;
9use ark_std::hash::Hash;
10use ark_std::vec::Vec;
11
12#[cfg(test)]
13mod tests;
14
15#[cfg(feature = "r1cs")]
16pub mod constraints;
17
/// Converts the digest produced by one layer of hashing into the input expected by the next
/// layer.
///
/// `From` is the previous layer's digest type and `To` is the next layer's input type. A
/// conversion yields a `TargetType`, which only has to be borrowable as `To`; this lets an owned
/// value (for example a `Vec<u8>`) stand in for an unsized input such as `[u8]`.
pub trait DigestConverter<From, To: ?Sized> {
    /// Owned result of the conversion; borrowed as `To` when fed to the next layer.
    type TargetType: Borrow<To>;
    /// Converts `item`, returning an error if the conversion fails.
    fn convert(item: From) -> Result<Self::TargetType, Error>;
}
24
25/// A trivial converter where digest of previous layer's hash is the same as next layer's input.
26pub struct IdentityDigestConverter<T> {
27    _prev_layer_digest: T,
28}
29
30impl<T> DigestConverter<T, T> for IdentityDigestConverter<T> {
31    type TargetType = T;
32    fn convert(item: T) -> Result<T, Error> {
33        Ok(item)
34    }
35}
36
37/// Convert previous layer's digest to bytes and use bytes as input for next layer's digest.
38/// TODO: `ToBytes` trait will be deprecated in future versions.
39pub struct ByteDigestConverter<T: CanonicalSerialize + ToBytes> {
40    _prev_layer_digest: T,
41}
42
43impl<T: CanonicalSerialize + ToBytes> DigestConverter<T, [u8]> for ByteDigestConverter<T> {
44    type TargetType = Vec<u8>;
45
46    fn convert(item: T) -> Result<Self::TargetType, Error> {
47        // TODO: In some tests, `serialize` is not consistent with constraints. Try fix those.
48        Ok(crate::to_unchecked_bytes!(item)?)
49    }
50}
51
/// Configuration of a Merkle tree: the leaf type, the digest type of each layer, and the hash
/// functions and digest converter that connect them.
///
/// Hashing proceeds in three stages:
/// * `LeafHash`: hashes a leaf into a `LeafDigest`.
/// * `LeafInnerDigestConverter` followed by `TwoToOneHash::evaluate`: two sibling leaf digests
///   are each converted to the input type of `TwoToOneHash`, then hashed into one `InnerDigest`.
/// * `TwoToOneHash::compress`: two sibling inner digests are compressed into one `InnerDigest`,
///   layer by layer up to the root.
pub trait Config {
    /// The leaf type; may be unsized. The Merkle tree stores only leaf digests, never leaves.
    type Leaf: ?Sized;
    /// Digest of a single leaf, i.e. the output of `LeafHash`.
    type LeafDigest: ToBytes
        + Clone
        + Eq
        + core::fmt::Debug
        + Hash
        + Default
        + CanonicalSerialize
        + CanonicalDeserialize;
    /// Converts a `LeafDigest` into the input type of `TwoToOneHash`; applied at the transition
    /// from the leaf layer to the first inner layer.
    type LeafInnerDigestConverter: DigestConverter<
        Self::LeafDigest,
        <Self::TwoToOneHash as TwoToOneCRHScheme>::Input,
    >;
    /// Digest of an inner (non-leaf) node, including the root; the output of `TwoToOneHash`.
    type InnerDigest: ToBytes
        + Clone
        + Eq
        + core::fmt::Debug
        + Hash
        + Default
        + CanonicalSerialize
        + CanonicalDeserialize;

    // Tom's Note: in the future, if we want different hash function, we can simply add more
    // types of digest here and specify a digest converter. Same for constraints.

    /// Leaf -> leaf digest.
    /// If a different leaf digest type is needed, a new leaf hash can wrap the original leaf
    /// hash and convert its output.
    type LeafHash: CRHScheme<Input = Self::Leaf, Output = Self::LeafDigest>;
    /// Two child digests -> one inner digest.
    type TwoToOneHash: TwoToOneCRHScheme<Output = Self::InnerDigest>;
}
93
/// Parameters of the two-to-one hash (`Config::TwoToOneHash`) of configuration `P`.
pub type TwoToOneParam<P> = <<P as Config>::TwoToOneHash as TwoToOneCRHScheme>::Parameters;
/// Parameters of the leaf hash (`Config::LeafHash`) of configuration `P`.
pub type LeafParam<P> = <<P as Config>::LeafHash as CRHScheme>::Parameters;
96
/// Authentication path for one leaf: the sibling hashes needed to recompute the root from that
/// leaf, ordered from the root side down, together with the leaf's position.
/// For example:
/// ```tree_diagram
///         [A]
///        /   \
///      [B]    C
///     / \   /  \
///    D [E] F    H
///   .. / \ ....
///    [I] J
/// ```
///  Suppose we want to prove I, then `leaf_sibling_hash` is J, `auth_path` is `[C,D]`
#[derive(Derivative, CanonicalSerialize, CanonicalDeserialize)]
#[derivative(
    Clone(bound = "P: Config"),
    Debug(bound = "P: Config"),
    Default(bound = "P: Config")
)]
pub struct Path<P: Config> {
    /// Digest of the proven leaf's sibling leaf.
    pub leaf_sibling_hash: P::LeafDigest,
    /// Hashes of the siblings of the on-path inner nodes, ordered from the top layer down to
    /// the bottom non-leaf layer. Neither the root nor the leaf-level sibling is included.
    pub auth_path: Vec<P::InnerDigest>,
    /// Position of the proven leaf among the leaves, counted from the left starting at 0.
    pub leaf_index: usize,
}
122
123impl<P: Config> Path<P> {
124    /// The position of on_path node in `leaf_and_sibling_hash` and `non_leaf_and_sibling_hash_path`.
125    /// `position[i]` is 0 (false) iff `i`th on-path node from top to bottom is on the left.
126    ///
127    /// This function simply converts `self.leaf_index` to boolean array in big endian form.
128    #[allow(unused)] // this function is actually used when r1cs feature is on
129    fn position_list(&'_ self) -> impl '_ + Iterator<Item = bool> {
130        (0..self.auth_path.len() + 1)
131            .map(move |i| ((self.leaf_index >> i) & 1) != 0)
132            .rev()
133    }
134}
135
136impl<P: Config> Path<P> {
137    /// Verify that a leaf is at `self.index` of the merkle tree.
138    /// * `leaf_size`: leaf size in number of bytes
139    ///
140    /// `verify` infers the tree height by setting `tree_height = self.auth_path.len() + 2`
141    pub fn verify<L: Borrow<P::Leaf>>(
142        &self,
143        leaf_hash_params: &LeafParam<P>,
144        two_to_one_params: &TwoToOneParam<P>,
145        root_hash: &P::InnerDigest,
146        leaf: L,
147    ) -> Result<bool, crate::Error> {
148        // calculate leaf hash
149        let claimed_leaf_hash = P::LeafHash::evaluate(&leaf_hash_params, leaf)?;
150        // check hash along the path from bottom to root
151        let (left_child, right_child) =
152            select_left_right_child(self.leaf_index, &claimed_leaf_hash, &self.leaf_sibling_hash)?;
153
154        // leaf layer to inner layer conversion
155        let left_child = P::LeafInnerDigestConverter::convert(left_child)?;
156        let right_child = P::LeafInnerDigestConverter::convert(right_child)?;
157
158        let mut curr_path_node =
159            P::TwoToOneHash::evaluate(&two_to_one_params, left_child, right_child)?;
160
161        // we will use `index` variable to track the position of path
162        let mut index = self.leaf_index;
163        index >>= 1;
164
165        // Check levels between leaf level and root
166        for level in (0..self.auth_path.len()).rev() {
167            // check if path node at this level is left or right
168            let (left, right) =
169                select_left_right_child(index, &curr_path_node, &self.auth_path[level])?;
170            // update curr_path_node
171            curr_path_node = P::TwoToOneHash::compress(&two_to_one_params, &left, &right)?;
172            index >>= 1;
173        }
174
175        // check if final hash is root
176        if &curr_path_node != root_hash {
177            return Ok(false);
178        }
179
180        Ok(true)
181    }
182}
183
184/// `index` is the first `path.len()` bits of
185/// the position of tree.
186///
187/// If the least significant bit of `index` is 0, then `sibling` will be left and `computed` will be right.
188/// Otherwise, `sibling` will be right and `computed` will be left.
189///
190/// Returns: (left, right)
191fn select_left_right_child<L: Clone>(
192    index: usize,
193    computed_hash: &L,
194    sibling_hash: &L,
195) -> Result<(L, L), crate::Error> {
196    let is_left = index & 1 == 0;
197    let mut left_child = computed_hash;
198    let mut right_child = sibling_hash;
199    if !is_left {
200        core::mem::swap(&mut left_child, &mut right_child);
201    }
202    Ok((left_child.clone(), right_child.clone()))
203}
204
/// A Merkle tree whose height is fixed at construction time. The number of leaves must be a
/// power of two greater than one; a tree of height `h` has `2^(h - 1)` leaves.
///
/// TODO: add RFC-6962 compatible merkle tree in the future.
/// For this release, padding will not be supported because of security concerns: if the leaf
/// hash and the two-to-one hash use the same underlying CRH, a malicious prover can prove a leaf
/// while the actual node is an inner node. In the future, we can prefix leaf hashes in different
/// layers to solve the problem.
#[derive(Derivative)]
#[derivative(Clone(bound = "P: Config"))]
pub struct MerkleTree<P: Config> {
    /// Non-leaf nodes in level order, with the root at index 0. The children of the node at
    /// index `i` are at indices `2 * i + 1` and `2 * i + 2`; for the bottom non-leaf level,
    /// subtracting `non_leaf_nodes.len()` from those indices gives positions in `leaf_nodes`.
    non_leaf_nodes: Vec<P::InnerDigest>,
    /// Leaf digests (the leaves themselves are not stored), from left to right.
    leaf_nodes: Vec<P::LeafDigest>,
    /// Parameters of the two-to-one (inner) hash.
    two_to_one_hash_param: TwoToOneParam<P>,
    /// Parameters of the leaf hash.
    leaf_hash_param: LeafParam<P>,
    /// Number of levels including both the leaf level and the root,
    /// i.e. `log2(leaf_nodes.len()) + 1`.
    height: usize,
}
227
228impl<P: Config> MerkleTree<P> {
229    /// Create an empty merkle tree such that all leaves are zero-filled.
230    /// Consider using a sparse merkle tree if you need the tree to be low memory
231    pub fn blank(
232        leaf_hash_param: &LeafParam<P>,
233        two_to_one_hash_param: &TwoToOneParam<P>,
234        height: usize,
235    ) -> Result<Self, crate::Error> {
236        // use empty leaf digest
237        let leaves_digest = vec![P::LeafDigest::default(); 1 << (height - 1)];
238        Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaves_digest)
239    }
240
241    /// Returns a new merkle tree. `leaves.len()` should be power of two.
242    pub fn new<L: Borrow<P::Leaf>>(
243        leaf_hash_param: &LeafParam<P>,
244        two_to_one_hash_param: &TwoToOneParam<P>,
245        leaves: impl IntoIterator<Item = L>,
246    ) -> Result<Self, crate::Error> {
247        let mut leaves_digests = Vec::new();
248
249        // compute and store hash values for each leaf
250        for leaf in leaves.into_iter() {
251            leaves_digests.push(P::LeafHash::evaluate(leaf_hash_param, leaf)?)
252        }
253
254        Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaves_digests)
255    }
256
257    pub fn new_with_leaf_digest(
258        leaf_hash_param: &LeafParam<P>,
259        two_to_one_hash_param: &TwoToOneParam<P>,
260        leaves_digest: Vec<P::LeafDigest>,
261    ) -> Result<Self, crate::Error> {
262        let leaf_nodes_size = leaves_digest.len();
263        assert!(
264            leaf_nodes_size.is_power_of_two() && leaf_nodes_size > 1,
265            "`leaves.len() should be power of two and greater than one"
266        );
267        let non_leaf_nodes_size = leaf_nodes_size - 1;
268
269        let tree_height = tree_height(leaf_nodes_size);
270
271        let hash_of_empty: P::InnerDigest = P::InnerDigest::default();
272
273        // initialize the merkle tree as array of nodes in level order
274        let mut non_leaf_nodes: Vec<P::InnerDigest> = (0..non_leaf_nodes_size)
275            .map(|_| hash_of_empty.clone())
276            .collect();
277
278        // Compute the starting indices for each non-leaf level of the tree
279        let mut index = 0;
280        let mut level_indices = Vec::with_capacity(tree_height - 1);
281        for _ in 0..(tree_height - 1) {
282            level_indices.push(index);
283            index = left_child(index);
284        }
285
286        // compute the hash values for the non-leaf bottom layer
287        {
288            let start_index = level_indices.pop().unwrap();
289            let upper_bound = left_child(start_index);
290            for current_index in start_index..upper_bound {
291                // `left_child(current_index)` and `right_child(current_index) returns the position of
292                // leaf in the whole tree (represented as a list in level order). We need to shift it
293                // by `-upper_bound` to get the index in `leaf_nodes` list.
294                let left_leaf_index = left_child(current_index) - upper_bound;
295                let right_leaf_index = right_child(current_index) - upper_bound;
296                // compute hash
297                non_leaf_nodes[current_index] = P::TwoToOneHash::evaluate(
298                    &two_to_one_hash_param,
299                    P::LeafInnerDigestConverter::convert(leaves_digest[left_leaf_index].clone())?,
300                    P::LeafInnerDigestConverter::convert(leaves_digest[right_leaf_index].clone())?,
301                )?
302            }
303        }
304
305        // compute the hash values for nodes in every other layer in the tree
306        level_indices.reverse();
307        for &start_index in &level_indices {
308            // The layer beginning `start_index` ends at `upper_bound` (exclusive).
309            let upper_bound = left_child(start_index);
310            for current_index in start_index..upper_bound {
311                let left_index = left_child(current_index);
312                let right_index = right_child(current_index);
313                non_leaf_nodes[current_index] = P::TwoToOneHash::compress(
314                    &two_to_one_hash_param,
315                    non_leaf_nodes[left_index].clone(),
316                    non_leaf_nodes[right_index].clone(),
317                )?
318            }
319        }
320
321        Ok(MerkleTree {
322            leaf_nodes: leaves_digest,
323            non_leaf_nodes,
324            height: tree_height,
325            leaf_hash_param: leaf_hash_param.clone(),
326            two_to_one_hash_param: two_to_one_hash_param.clone(),
327        })
328    }
329
330    /// Returns the root of the Merkle tree.
331    pub fn root(&self) -> P::InnerDigest {
332        self.non_leaf_nodes[0].clone()
333    }
334
335    /// Returns the height of the Merkle tree.
336    pub fn height(&self) -> usize {
337        self.height
338    }
339
340    /// Returns the authentication path from leaf at `index` to root.
341    pub fn generate_proof(&self, index: usize) -> Result<Path<P>, crate::Error> {
342        // gather basic tree information
343        let tree_height = tree_height(self.leaf_nodes.len());
344
345        // Get Leaf hash, and leaf sibling hash,
346        let leaf_index_in_tree = convert_index_to_last_level(index, tree_height);
347        let leaf_sibling_hash = if index & 1 == 0 {
348            // leaf is left child
349            self.leaf_nodes[index + 1].clone()
350        } else {
351            // leaf is right child
352            self.leaf_nodes[index - 1].clone()
353        };
354
355        // path.len() = `tree height - 2`, the two missing elements being the leaf sibling hash and the root
356        let mut path = Vec::with_capacity(tree_height - 2);
357        // Iterate from the bottom layer after the leaves, to the top, storing all sibling node's hash values.
358        let mut current_node = parent(leaf_index_in_tree).unwrap();
359        while !is_root(current_node) {
360            let sibling_node = sibling(current_node).unwrap();
361            path.push(self.non_leaf_nodes[sibling_node].clone());
362            current_node = parent(current_node).unwrap();
363        }
364
365        debug_assert_eq!(path.len(), tree_height - 2);
366
367        // we want to make path from root to bottom
368        path.reverse();
369
370        Ok(Path {
371            leaf_index: index,
372            auth_path: path,
373            leaf_sibling_hash,
374        })
375    }
376
377    /// Given the index and new leaf, return the hash of leaf and an updated path in order from root to bottom non-leaf level.
378    /// This does not mutate the underlying tree.
379    fn updated_path<T: Borrow<P::Leaf>>(
380        &self,
381        index: usize,
382        new_leaf: T,
383    ) -> Result<(P::LeafDigest, Vec<P::InnerDigest>), crate::Error> {
384        // calculate the hash of leaf
385        let new_leaf_hash: P::LeafDigest = P::LeafHash::evaluate(&self.leaf_hash_param, new_leaf)?;
386
387        // calculate leaf sibling hash and locate its position (left or right)
388        let (leaf_left, leaf_right) = if index & 1 == 0 {
389            // leaf on left
390            (&new_leaf_hash, &self.leaf_nodes[index + 1])
391        } else {
392            (&self.leaf_nodes[index - 1], &new_leaf_hash)
393        };
394
395        // calculate the updated hash at bottom non-leaf-level
396        let mut path_bottom_to_top = Vec::with_capacity(self.height - 1);
397        {
398            path_bottom_to_top.push(P::TwoToOneHash::evaluate(
399                &self.two_to_one_hash_param,
400                P::LeafInnerDigestConverter::convert(leaf_left.clone())?,
401                P::LeafInnerDigestConverter::convert(leaf_right.clone())?,
402            )?);
403        }
404
405        // then calculate the updated hash from bottom to root
406        let leaf_index_in_tree = convert_index_to_last_level(index, self.height);
407        let mut prev_index = parent(leaf_index_in_tree).unwrap();
408        while !is_root(prev_index) {
409            let (left_child, right_child) = if is_left_child(prev_index) {
410                (
411                    path_bottom_to_top.last().unwrap(),
412                    &self.non_leaf_nodes[sibling(prev_index).unwrap()],
413                )
414            } else {
415                (
416                    &self.non_leaf_nodes[sibling(prev_index).unwrap()],
417                    path_bottom_to_top.last().unwrap(),
418                )
419            };
420            let evaluated =
421                P::TwoToOneHash::compress(&self.two_to_one_hash_param, left_child, right_child)?;
422            path_bottom_to_top.push(evaluated);
423            prev_index = parent(prev_index).unwrap();
424        }
425
426        debug_assert_eq!(path_bottom_to_top.len(), self.height - 1);
427        let path_top_to_bottom: Vec<_> = path_bottom_to_top.into_iter().rev().collect();
428        Ok((new_leaf_hash, path_top_to_bottom))
429    }
430
431    /// Update the leaf at `index` to updated leaf.
432    /// ```tree_diagram
433    ///         [A]
434    ///        /   \
435    ///      [B]    C
436    ///     / \   /  \
437    ///    D [E] F    H
438    ///   .. / \ ....
439    ///    [I] J
440    /// ```
441    /// update(3, {new leaf}) would swap the leaf value at `[I]` and cause a recomputation of `[A]`, `[B]`, and `[E]`.
442    pub fn update(&mut self, index: usize, new_leaf: &P::Leaf) -> Result<(), crate::Error> {
443        assert!(index < self.leaf_nodes.len(), "index out of range");
444        let (updated_leaf_hash, mut updated_path) = self.updated_path(index, new_leaf)?;
445        self.leaf_nodes[index] = updated_leaf_hash;
446        let mut curr_index = convert_index_to_last_level(index, self.height);
447        for _ in 0..self.height - 1 {
448            curr_index = parent(curr_index).unwrap();
449            self.non_leaf_nodes[curr_index] = updated_path.pop().unwrap();
450        }
451        Ok(())
452    }
453
454    /// Update the leaf and check if the updated root is equal to `asserted_new_root`.
455    ///
456    /// Tree will not be modified if the check fails.
457    pub fn check_update<T: Borrow<P::Leaf>>(
458        &mut self,
459        index: usize,
460        new_leaf: &P::Leaf,
461        asserted_new_root: &P::InnerDigest,
462    ) -> Result<bool, crate::Error> {
463        let new_leaf = new_leaf.borrow();
464        assert!(index < self.leaf_nodes.len(), "index out of range");
465        let (updated_leaf_hash, mut updated_path) = self.updated_path(index, new_leaf)?;
466        if &updated_path[0] != asserted_new_root {
467            return Ok(false);
468        }
469        self.leaf_nodes[index] = updated_leaf_hash;
470        let mut curr_index = convert_index_to_last_level(index, self.height);
471        for _ in 0..self.height - 1 {
472            curr_index = parent(curr_index).unwrap();
473            self.non_leaf_nodes[curr_index] = updated_path.pop().unwrap();
474        }
475        Ok(true)
476    }
477}
478
479/// Returns the height of the tree, given the number of leaves.
480#[inline]
481fn tree_height(num_leaves: usize) -> usize {
482    if num_leaves == 1 {
483        return 1;
484    }
485
486    (ark_std::log2(num_leaves) as usize) + 1
487}
/// Whether `index` is the level-order index of the root (always 0).
#[inline]
fn is_root(index: usize) -> bool {
    matches!(index, 0)
}
493
/// Level-order index of the left child of the node at `index` (0-based numbering).
#[inline]
fn left_child(index: usize) -> usize {
    let doubled = 2 * index;
    doubled + 1
}
499
/// Level-order index of the right child of the node at `index` (0-based numbering).
#[inline]
fn right_child(index: usize) -> usize {
    2 * (index + 1)
}
505
/// Level-order index of the other child sharing a parent with `index`, or `None` for the root.
/// Left children have odd indices, so their sibling is the next index; right children's is
/// the previous one.
#[inline]
fn sibling(index: usize) -> Option<usize> {
    match index {
        0 => None,
        left if left % 2 == 1 => Some(left + 1),
        right => Some(right - 1),
    }
}
517
/// Whether the node at level-order `index` is a left child (left children have odd indices).
#[inline]
fn is_left_child(index: usize) -> bool {
    index & 1 == 1
}
523
/// Level-order index of the parent of the node at `index`, or `None` for the root.
#[inline]
fn parent(index: usize) -> Option<usize> {
    index.checked_sub(1).map(|predecessor| predecessor >> 1)
}
533
/// Maps a leaf position (0-based, counted from the left) to that leaf's index in the
/// level-order numbering of a whole tree with `tree_height` levels (root = 0). The leaf level
/// starts right after the `2^(tree_height - 1) - 1` non-leaf nodes.
#[inline]
fn convert_index_to_last_level(index: usize, tree_height: usize) -> usize {
    let leaf_count: usize = 1 << (tree_height - 1);
    index + leaf_count - 1
}