// snarkvm-console-collections 0.9.11
//
// Collections for a decentralized virtual machine.

// Copyright (C) 2019-2022 Aleo Systems Inc.
// This file is part of the snarkVM library.

// The snarkVM library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The snarkVM library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the snarkVM library. If not, see <https://www.gnu.org/licenses/>.

mod helpers;
pub use helpers::*;

mod path;
pub use path::*;

#[cfg(test)]
mod tests;

use snarkvm_console_types::prelude::*;

use aleo_std::prelude::*;

#[cfg(feature = "parallel")]
use rayon::prelude::*;

/// A binary Merkle tree of height at most `DEPTH`, stored as a flat array in breadth-first order.
///
/// The node at array index `i` has its children at `2 * i + 1` and `2 * i + 2`. The stored
/// tree is only as tall as the current number of leaves requires (rounded up to a power of
/// two), with unused leaf slots holding the empty hash. The root exposed by
/// [`MerkleTree::root`] is the stored root hashed with the empty hash once per level needed
/// to reach `DEPTH`.
///
/// Leaves are hashed with the leaf hasher `LH`, and internal nodes with the path hasher `PH`.
#[derive(Clone)]
pub struct MerkleTree<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>>, const DEPTH: u8> {
    /// The leaf hasher for the Merkle tree.
    leaf_hasher: LH,
    /// The path hasher for the Merkle tree.
    path_hasher: PH,
    /// The computed root of the full Merkle tree, padded up to `DEPTH` levels.
    root: PH::Hash,
    /// The internal hashes, from root to hashed leaves, of the full Merkle tree.
    /// Leaf slots beyond `number_of_leaves` hold the empty hash.
    tree: Vec<PH::Hash>,
    /// The canonical empty hash, used for unused leaf slots and for padding the root to `DEPTH`.
    /// (This is also the only field that names `E` directly.)
    empty_hash: Field<E>,
    /// The number of hashed leaves in the tree.
    number_of_leaves: usize,
}

impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>>, const DEPTH: u8>
    MerkleTree<E, LH, PH, DEPTH>
{
    #[inline]
    /// Initializes a new Merkle tree with the given leaves.
    pub fn new(leaf_hasher: &LH, path_hasher: &PH, leaves: &[LH::Leaf]) -> Result<Self> {
        let timer = timer!("MerkleTree::new");

        // Ensure the Merkle tree depth is greater than 0.
        ensure!(DEPTH > 0, "Merkle tree depth must be greater than 0");
        // Ensure the Merkle tree depth is less than or equal to 64.
        ensure!(DEPTH <= 64u8, "Merkle tree depth must be less than or equal to 64");

        // Compute the maximum number of leaves.
        let max_leaves = match leaves.len().checked_next_power_of_two() {
            Some(num_leaves) => num_leaves,
            None => bail!("Integer overflow when computing the maximum number of leaves in the Merkle tree"),
        };

        // Compute the number of nodes.
        let num_nodes = max_leaves - 1;
        // Compute the tree size as the maximum number of leaves plus the number of nodes.
        let tree_size = max_leaves + num_nodes;
        // Compute the number of levels in the Merkle tree (i.e. log2(tree_size)).
        let tree_depth = tree_depth::<DEPTH>(tree_size)?;
        // Compute the number of padded levels.
        let padding_depth = DEPTH - tree_depth;

        // Compute the empty hash.
        let empty_hash = path_hasher.hash_empty()?;

        // Initialize the Merkle tree.
        let mut tree = vec![empty_hash; tree_size];

        // Compute and store each leaf hash.
        tree[num_nodes..num_nodes + leaves.len()].copy_from_slice(&leaf_hasher.hash_leaves(leaves)?);
        lap!(timer, "Hashed {} leaves", leaves.len());

        // Compute and store the hashes for each level, iterating from the penultimate level to the root level.
        let mut start_index = num_nodes;
        // Compute the start index of the current level.
        while let Some(start) = parent(start_index) {
            // Compute the end index of the current level.
            let end = left_child(start);
            // Construct the children for each node in the current level.
            let tuples = (start..end).map(|i| (tree[left_child(i)], tree[right_child(i)])).collect::<Vec<_>>();
            // Compute and store the hashes for each node in the current level.
            tree[start..end].copy_from_slice(&path_hasher.hash_all_children(&tuples)?);
            // Update the start index for the next level.
            start_index = start;
        }
        lap!(timer, "Hashed {} levels", tree_depth);

        // Compute the root hash, by iterating from the root level up to `DEPTH`.
        let mut root_hash = tree[0];
        for _ in 0..padding_depth {
            // Update the root hash, by hashing the current root hash with the empty hash.
            root_hash = path_hasher.hash_children(&root_hash, &empty_hash)?;
        }
        lap!(timer, "Hashed {} padding levels", padding_depth);

        finish!(timer);

        Ok(Self {
            leaf_hasher: leaf_hasher.clone(),
            path_hasher: path_hasher.clone(),
            root: root_hash,
            tree,
            empty_hash,
            number_of_leaves: leaves.len(),
        })
    }

    #[inline]
    /// Returns a new Merkle tree with the given new leaves appended to it.
    pub fn prepare_append(&self, new_leaves: &[LH::Leaf]) -> Result<Self> {
        let timer = timer!("MerkleTree::prepare_append");

        // Compute the maximum number of leaves.
        let max_leaves = match (self.number_of_leaves + new_leaves.len()).checked_next_power_of_two() {
            Some(num_leaves) => num_leaves,
            None => bail!("Integer overflow when computing the maximum number of leaves in the Merkle tree"),
        };
        // Compute the number of nodes.
        let num_nodes = max_leaves - 1;
        // Compute the tree size as the maximum number of leaves plus the number of nodes.
        let tree_size = num_nodes + max_leaves;
        // Compute the number of levels in the Merkle tree (i.e. log2(tree_size)).
        let tree_depth = tree_depth::<DEPTH>(tree_size)?;
        // Compute the number of padded levels.
        let padding_depth = DEPTH - tree_depth;

        // Initialize the Merkle tree.
        let mut tree = vec![self.empty_hash; num_nodes];
        // Extend the new Merkle tree with the existing leaf hashes.
        tree.extend(self.leaf_hashes()?);
        // Extend the new Merkle tree with the new leaf hashes.
        tree.extend(&self.leaf_hasher.hash_leaves(new_leaves)?);
        // Resize the new Merkle tree with empty hashes to pad up to `tree_size`.
        tree.resize(tree_size, self.empty_hash);
        lap!(timer, "Hashed {} new leaves", new_leaves.len());

        // Initialize a start index to track the starting index of the current level.
        let start_index = num_nodes;
        // Initialize a middle index to separate the precomputed indices from the new indices that need to be computed.
        let middle_index = num_nodes + self.number_of_leaves;
        // Initialize a precompute index to track the starting index of each precomputed level.
        let start_precompute_index = match self.number_of_leaves.checked_next_power_of_two() {
            Some(num_leaves) => num_leaves - 1,
            None => bail!("Integer overflow when computing the Merkle tree precompute index"),
        };
        // Initialize a precompute index to track the middle index of each precomputed level.
        let middle_precompute_index = match num_nodes == start_precompute_index {
            // If the old tree and new tree are of the same size, then we can copy over the right half of the old tree.
            true => Some(start_precompute_index + self.number_of_leaves + new_leaves.len() + 1),
            // Otherwise, we need to compute the right half of the new tree.
            false => None,
        };

        // Compute and store the hashes for each level, iterating from the penultimate level to the root level.
        self.compute_updated_tree(
            &mut tree,
            start_index,
            middle_index,
            start_precompute_index,
            middle_precompute_index,
        )?;

        // Compute the root hash, by iterating from the root level up to `DEPTH`.
        let mut root_hash = tree[0];
        for _ in 0..padding_depth {
            // Update the root hash, by hashing the current root hash with the empty hash.
            root_hash = self.path_hasher.hash_children(&root_hash, &self.empty_hash)?;
        }
        lap!(timer, "Hashed {} padding levels", padding_depth);

        finish!(timer);

        Ok(Self {
            leaf_hasher: self.leaf_hasher.clone(),
            path_hasher: self.path_hasher.clone(),
            root: root_hash,
            tree,
            empty_hash: self.empty_hash,
            number_of_leaves: self.number_of_leaves + new_leaves.len(),
        })
    }

    /// Appends `new_leaves` to this Merkle tree in place.
    ///
    /// The replacement tree is fully built by `prepare_append` before anything is written, so
    /// if that step fails, `self` keeps its previous contents.
    #[inline]
    pub fn append(&mut self, new_leaves: &[LH::Leaf]) -> Result<()> {
        let timer = timer!("MerkleTree::append");

        // Commit only once the grown tree has been built successfully.
        let grown = self.prepare_append(new_leaves)?;
        *self = grown;

        finish!(timer);
        Ok(())
    }

    /// Returns a new Merkle tree equal to this one with its last `n` leaves removed.
    ///
    /// `self` is not modified. The leaf level is resized to the next power of two of the
    /// remaining leaf count, which may be smaller than before. Hashes that the removal cannot
    /// affect are copied from `self` by `compute_updated_tree`.
    ///
    /// # Errors
    ///
    /// Fails if `n` is 0, if `n` exceeds the current number of leaves, or if any hashing step
    /// fails.
    #[inline]
    pub fn prepare_remove_last_n(&self, n: usize) -> Result<Self> {
        let timer = timer!("MerkleTree::prepare_remove_last_n");

        ensure!(n > 0, "Cannot remove zero leaves from the Merkle tree");

        // Determine the updated number of leaves, after removing the last 'n' leaves.
        let updated_number_of_leaves = self.number_of_leaves.checked_sub(n).ok_or_else(|| {
            anyhow!("Failed to remove '{n}' leaves from the Merkle tree, as it only contains {}", self.number_of_leaves)
        })?;

        // Compute the maximum number of leaves.
        let max_leaves = match updated_number_of_leaves.checked_next_power_of_two() {
            Some(num_leaves) => num_leaves,
            None => bail!("Integer overflow when computing the maximum number of leaves in the Merkle tree"),
        };
        // Compute the number of nodes.
        let num_nodes = max_leaves - 1;
        // Compute the tree size as the maximum number of leaves plus the number of nodes.
        let tree_size = num_nodes + max_leaves;
        // Compute the number of levels in the Merkle tree (i.e. log2(tree_size)).
        let tree_depth = tree_depth::<DEPTH>(tree_size)?;
        // Compute the number of padded levels.
        let padding_depth = DEPTH - tree_depth;

        // Initialize the Merkle tree.
        let mut tree = vec![self.empty_hash; num_nodes];
        // Extend the new Merkle tree with the existing leaf hashes, excluding the last 'n' leaves.
        tree.extend(&self.leaf_hashes()?[..updated_number_of_leaves]);
        // Resize the new Merkle tree with empty hashes to pad up to `tree_size`.
        tree.resize(tree_size, self.empty_hash);
        lap!(timer, "Resizing to {} leaves", updated_number_of_leaves);

        // The new leaf level starts at `num_nodes`; every leaf slot before `middle_index` holds
        // the same hash as in the old tree, and everything from `middle_index` on may differ.
        let start_index = num_nodes;
        let middle_index = num_nodes + updated_number_of_leaves;
        // First leaf index of the old tree; locates reusable hashes inside `self.tree`.
        let start_precompute_index = match self.number_of_leaves.checked_next_power_of_two() {
            Some(num_leaves) => num_leaves - 1,
            None => bail!("Integer overflow when computing the Merkle tree precompute index"),
        };
        let middle_precompute_index = match num_nodes == start_precompute_index {
            // Same leaf capacity: the slots from the old `number_of_leaves` onward are empty in
            // both trees, so the right-hand part of each level can be copied from the old tree.
            true => Some(start_precompute_index + self.number_of_leaves + 1),
            // Smaller leaf capacity: the right-hand indices no longer line up with the old tree,
            // so everything from `middle_index` onward is recomputed on every level.
            false => None,
        };

        // Compute and store the hashes for each level, iterating from the penultimate level to the root level.
        self.compute_updated_tree(
            &mut tree,
            start_index,
            middle_index,
            start_precompute_index,
            middle_precompute_index,
        )?;

        // Compute the root hash, by iterating from the root level up to `DEPTH`.
        let mut root_hash = tree[0];
        for _ in 0..padding_depth {
            // Update the root hash, by hashing the current root hash with the empty hash.
            root_hash = self.path_hasher.hash_children(&root_hash, &self.empty_hash)?;
        }
        lap!(timer, "Hashed {} padding levels", padding_depth);

        finish!(timer);

        Ok(Self {
            leaf_hasher: self.leaf_hasher.clone(),
            path_hasher: self.path_hasher.clone(),
            root: root_hash,
            tree,
            empty_hash: self.empty_hash,
            number_of_leaves: updated_number_of_leaves,
        })
    }

    /// Removes the last `n` leaves from this Merkle tree in place.
    ///
    /// The replacement tree is fully built by `prepare_remove_last_n` before anything is
    /// written, so if that step fails, `self` keeps its previous contents.
    #[inline]
    pub fn remove_last_n(&mut self, n: usize) -> Result<()> {
        let timer = timer!("MerkleTree::remove_last_n");

        // Commit only once the shrunken tree has been built successfully.
        let shrunk = self.prepare_remove_last_n(n)?;
        *self = shrunk;

        finish!(timer);
        Ok(())
    }

    #[inline]
    /// Returns the Merkle path for the given leaf index and leaf.
    pub fn prove(&self, leaf_index: usize, leaf: &LH::Leaf) -> Result<MerklePath<E, DEPTH>> {
        // Ensure the leaf index is valid.
        ensure!(leaf_index < self.number_of_leaves, "The given Merkle leaf index is out of bounds");

        // Compute the leaf hash.
        let leaf_hash = self.leaf_hasher.hash_leaf(leaf)?;

        // Compute the start index (on the left) for the leaf hashes level in the Merkle tree.
        let start = match self.number_of_leaves.checked_next_power_of_two() {
            Some(num_leaves) => num_leaves - 1,
            None => bail!("Integer overflow when computing the Merkle tree start index"),
        };
        // Compute the absolute index of the leaf in the Merkle tree.
        let mut index = start + leaf_index;
        // Ensure the leaf index is valid.
        ensure!(index < self.tree.len(), "The given Merkle leaf index is out of bounds");
        // Ensure the leaf hash matches the one in the tree.
        ensure!(self.tree[index] == leaf_hash, "The given Merkle leaf does not match the one in the Merkle tree");

        // Initialize a vector for the Merkle path.
        let mut path = Vec::with_capacity(DEPTH as usize);

        // Iterate from the leaf hash to the root level, storing the sibling hashes along the path.
        for _ in 0..DEPTH {
            // Compute the index of the sibling hash, if it exists.
            if let Some(sibling) = sibling(index) {
                // Append the sibling hash to the path.
                path.push(self.tree[sibling]);
                // Compute the index of the parent hash, if it exists.
                match parent(index) {
                    // Update the index to the parent index.
                    Some(parent) => index = parent,
                    // If the parent does not exist, the path is complete.
                    None => break,
                }
            }
        }

        // If the Merkle path length is not equal to `DEPTH`, pad the path with the empty hash.
        path.resize(DEPTH as usize, self.empty_hash);

        // Return the Merkle path.
        MerklePath::try_from((U64::new(leaf_index as u64), path))
    }

    /// Returns `true` if the given Merkle path is valid for the given root and leaf.
    ///
    /// This delegates to `MerklePath::verify` using this tree's leaf and path hashers.
    pub fn verify(&self, path: &MerklePath<E, DEPTH>, root: &PH::Hash, leaf: &LH::Leaf) -> bool {
        path.verify(&self.leaf_hasher, &self.path_hasher, root, leaf)
    }

    /// Returns the Merkle root of the tree.
    ///
    /// This is the stored root (`tree()[0]`), hashed with the empty hash once for each level
    /// between the stored tree's height and `DEPTH`.
    pub const fn root(&self) -> &PH::Hash {
        &self.root
    }

    /// Returns the full Merkle tree as a flat array in breadth-first order.
    ///
    /// The array starts with the internal node hashes, beginning with the stored root before
    /// padding to `DEPTH`. It is followed by the leaf level, which *is* included: the hashes
    /// of the leaves, then the empty hash in every unused slot up to the next power of two.
    /// Use [`Self::leaf_hashes`] to get only the hashes of the actual leaves.
    pub fn tree(&self) -> &[PH::Hash] {
        &self.tree
    }

    /// Returns the empty hash, used for unused leaf slots and for padding the root up to `DEPTH`.
    pub const fn empty_hash(&self) -> &PH::Hash {
        &self.empty_hash
    }

    /// Returns the hashes of the `number_of_leaves` actual leaves, in leaf-index order.
    ///
    /// Empty padding slots of the leaf level are not included.
    ///
    /// # Errors
    ///
    /// Fails only if the leaf count cannot be rounded up to a power of two.
    pub fn leaf_hashes(&self) -> Result<&[LH::Hash]> {
        // Compute the start index (on the left) for the leaf hashes level in the Merkle tree.
        let start = match self.number_of_leaves.checked_next_power_of_two() {
            Some(num_leaves) => num_leaves - 1,
            None => bail!("Integer overflow when computing the Merkle tree start index"),
        };
        // Compute the end index (on the right) for the leaf hashes level in the Merkle tree.
        let end = start + self.number_of_leaves;
        // Return the leaf hashes.
        Ok(&self.tree[start..end])
    }

    /// Returns the number of leaves in the Merkle tree (excluding empty padding slots).
    pub const fn number_of_leaves(&self) -> usize {
        self.number_of_leaves
    }

    /// Compute and store the hashes for each level, iterating from the penultimate level to the root level.
    ///
    /// `tree` is the new tree, and its leaf level (beginning at `start_index`) must already be
    /// filled in. `self.tree` is the old tree, used as a source of hashes that do not need to
    /// be recomputed. On every level, the nodes are handled in up to three ranges:
    ///
    /// ```ignore
    ///  start_index      middle_index                              end_index
    ///  start_precompute_index         middle_precompute_index     end_index
    /// ```
    ///
    /// - `start..middle` is copied from `self.tree`, beginning at the corresponding level start
    ///   of the old tree (tracked via `start_precompute_index`). In both callers in this file,
    ///   `middle_index` is the first leaf slot that may differ from the old tree, so every node
    ///   in this range covers only unchanged leaves.
    /// - `middle..middle_precompute` (or `middle..end` if `middle_precompute_index` is `None`)
    ///   is rehashed from the children already stored in `tree`.
    /// - `middle_precompute..end` is only present when `middle_precompute_index` is `Some`. It
    ///   is copied from `self.tree` at the *same* indices. The callers pass `Some` only when the
    ///   old and new leaf levels have the same size, so those indices line up.
    ///
    /// `middle_precompute_index` holds one past the first reusable right-hand index of the
    /// current level. That way its `parent` is the first node on the next level up whose two
    /// children are both reusable.
    ///
    /// # Errors
    ///
    /// Returns an error if one of the `ensure!` consistency checks fails, or if hashing fails.
    #[inline]
    fn compute_updated_tree(
        &self,
        tree: &mut [Field<E>],
        mut start_index: usize,
        mut middle_index: usize,
        mut start_precompute_index: usize,
        mut middle_precompute_index: Option<usize>,
    ) -> Result<()> {
        // Initialize a timer for the while loop.
        let timer = timer!("MerkleTree::compute_updated_tree");

        // Compute and store the hashes for each level, iterating from the penultimate level to the root level.
        while let (Some(start), Some(middle)) = (parent(start_index), parent(middle_index)) {
            // Compute the end index of the current level.
            let end = left_child(start);

            // Left range: while the old tree still has a level here, copy the unchanged hashes from it.
            if let Some(start_precompute) = parent(start_precompute_index) {
                // Compute the end index of the precomputed level.
                let end_precompute = start_precompute + (middle - start);
                // Copy the hashes for each node in the current level.
                tree[start..middle].copy_from_slice(&self.tree[start_precompute..end_precompute]);
                // Update the precompute index for the next level.
                start_precompute_index = start_precompute;
            } else {
                // The old tree has no level left to copy from, so the left range must be empty.
                ensure!(start == middle, "Failed to process all left precomputed indices in the Merkle tree");
            }
            lap!(timer, "Precompute (Left): {start} -> {middle}");

            // Right range: rehash only `middle..middle_precompute`, then copy the rest of the level
            // from the same indices in the old tree.
            // Note: This logic works because the old tree and new tree are the same power of two.
            if let Some(middle_precompute) = middle_precompute_index {
                if let Some(middle_precompute) = parent(middle_precompute) {
                    // Construct the children for the new indices in the current level.
                    let tuples = (middle..middle_precompute)
                        .map(|i| (tree[left_child(i)], tree[right_child(i)]))
                        .collect::<Vec<_>>();
                    // Process the indices that need to be computed for the current level.
                    // Levels with 100 or more nodes to hash use the `cfg_iter_mut!`/`cfg_iter!` macros
                    // (presumably rayon-parallel when the `parallel` feature is enabled — confirm);
                    // smaller levels are hashed with a plain sequential iterator.
                    match tuples.len() >= 100 {
                        // Option 1: Hash through the `cfg_*` iterators and store the results in place.
                        true => cfg_iter_mut!(tree[middle..middle_precompute]).zip_eq(cfg_iter!(tuples)).try_for_each(
                            |(node, (left, right))| {
                                *node = self.path_hasher.hash_children(left, right)?;
                                Ok::<_, Error>(())
                            },
                        )?,
                        // Option 2: Hash sequentially and store the results in place.
                        false => tree[middle..middle_precompute].iter_mut().zip_eq(&tuples).try_for_each(
                            |(node, (left, right))| {
                                *node = self.path_hasher.hash_children(left, right)?;
                                Ok::<_, Error>(())
                            },
                        )?,
                    }
                    lap!(timer, "Compute: {middle} -> {middle_precompute}");

                    // Copy the hashes for each node in the current level.
                    tree[middle_precompute..end].copy_from_slice(&self.tree[middle_precompute..end]);
                    // Store one past the first reusable index, so `parent` of it lands on the next level's first reusable node.
                    middle_precompute_index = Some(middle_precompute + 1);
                    lap!(timer, "Precompute (Right): {middle_precompute} -> {end}");
                } else {
                    // Ensure the middle precompute index is equal to the end index, as all precomputed indices have been processed.
                    ensure!(
                        middle_precompute == end,
                        "Failed to process all right precomputed indices in the Merkle tree"
                    );
                }
            } else {
                // No reusable right range: rehash every node from `middle` to the end of the level.
                let tuples = (middle..end).map(|i| (tree[left_child(i)], tree[right_child(i)])).collect::<Vec<_>>();
                // Process the indices that need to be computed for the current level.
                // Levels with 100 or more nodes to hash use the `cfg_iter_mut!`/`cfg_iter!` macros
                // (presumably rayon-parallel when the `parallel` feature is enabled — confirm);
                // smaller levels are hashed with a plain sequential iterator.
                match tuples.len() >= 100 {
                    // Option 1: Hash through the `cfg_*` iterators and store the results in place.
                    true => cfg_iter_mut!(tree[middle..end]).zip_eq(cfg_iter!(tuples)).try_for_each(
                        |(node, (left, right))| {
                            *node = self.path_hasher.hash_children(left, right)?;
                            Ok::<_, Error>(())
                        },
                    )?,
                    // Option 2: Hash sequentially and store the results in place.
                    false => tree[middle..end].iter_mut().zip_eq(&tuples).try_for_each(|(node, (left, right))| {
                        *node = self.path_hasher.hash_children(left, right)?;
                        Ok::<_, Error>(())
                    })?,
                }
                lap!(timer, "Compute: {middle} -> {end}");
            }

            // Update the start index for the next level.
            start_index = start;
            // Update the middle index for the next level.
            middle_index = middle;
        }

        // End the timer for the while loop.
        finish!(timer);

        Ok(())
    }
}

/// Returns the depth of the tree, given the size of the tree.
///
/// The depth is `floor(log2(tree_size))`, or 0 for a `tree_size` of 0. For the sizes built in
/// this file, `tree_size = 2 * max_leaves - 1` with `max_leaves` a power of two. The depth
/// then equals `log2(max_leaves)`: the number of hashing levels between the leaf level and
/// the stored root.
///
/// # Errors
///
/// Returns an error if the depth exceeds `DEPTH`.
#[inline]
fn tree_depth<const DEPTH: u8>(tree_size: usize) -> Result<u8> {
    let tree_size = u64::try_from(tree_size)?;
    // Compute floor(log2(tree_size)) exactly, as the index of the highest set bit.
    // `f64::log2` is deliberately avoided: for `tree_size = 2^k - 1` with large `k` (around 49
    // and up), the floating-point result can round up to exactly `k`, and truncating it would
    // report a depth one level too deep.
    let tree_depth = match tree_size {
        0 => 0,
        size => u64::BITS - 1 - size.leading_zeros(),
    };
    // The value above is at most 63, so this conversion cannot fail.
    let tree_depth = u8::try_from(tree_depth)?;
    // Ensure the tree depth is within the depth bound.
    match tree_depth <= DEPTH {
        true => Ok(tree_depth),
        false => bail!("Merkle tree cannot exceed depth {DEPTH}: attempted to reach depth {tree_depth}"),
    }
}

/// Returns the array index of the left child of the node at `index`.
#[inline]
const fn left_child(index: usize) -> usize {
    index * 2 + 1
}

/// Returns the array index of the right child of the node at `index` (one past the left child).
#[inline]
const fn right_child(index: usize) -> usize {
    left_child(index) + 1
}

/// Returns the array index of the node sharing a parent with `index`, or `None` for the root.
#[inline]
const fn sibling(index: usize) -> Option<usize> {
    if is_root(index) {
        return None;
    }
    // Left children sit at odd indices, so their sibling is the next slot; right children look back one.
    match is_left_child(index) {
        true => Some(index + 1),
        false => Some(index - 1),
    }
}

/// Returns `true` when `index` addresses the root node (index 0).
#[inline]
const fn is_root(index: usize) -> bool {
    matches!(index, 0)
}

/// Returns `true` when `index` addresses a left child, i.e. an odd index.
#[inline]
const fn is_left_child(index: usize) -> bool {
    index % 2 != 0
}

/// Returns the array index of the parent of the node at `index`, or `None` for the root.
#[inline]
const fn parent(index: usize) -> Option<usize> {
    match index {
        0 => None,
        child => Some((child - 1) / 2),
    }
}