ct_merkle/
mem_backed_tree.rs

1use crate::{tree_util::*, RootHash};
2
3use alloc::vec::Vec;
4use core::fmt;
5
6use digest::Digest;
7
8#[cfg(feature = "serde")]
9use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
10
/// An error representing what went wrong when running [`MemoryBackedTree::self_check`].
///
/// Node indices carried by the variants are indices into the tree's flat internal-node array
/// (where leaf `i`'s hash lives at index `2i`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SelfCheckError {
    /// The node at the given index is missing
    MissingNode(u64),

    /// The node at the given index has the wrong hash
    IncorrectHash(u64),

    /// The number of internal nodes in this struct exceeds the number of nodes that a tree with
    /// this many leaves would hold.
    TooManyInternalNodes,

    /// There are so many leaves that the full tree could not possibly fit in memory
    TooManyLeaves,
}
27
28impl fmt::Display for SelfCheckError {
29    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
30        match self {
31            SelfCheckError::MissingNode(idx) => write!(f, "the node at index {} is missing", idx),
32            SelfCheckError::IncorrectHash(idx) => {
33                write!(f, "the node at index {} has the wrong hash", idx)
34            }
35            SelfCheckError::TooManyInternalNodes => {
36                write!(
37                    f,
38                    "the number of internal nodes in this struct exceedsc the number of nodes \
39                    that a tree with this many leaves would hold"
40                )
41            }
42            SelfCheckError::TooManyLeaves => {
43                write!(
44                    f,
45                    "there are so many leaves that the full tree could not possibly fit in memory"
46                )
47            }
48        }
49    }
50}
51
/// With the `std` feature enabled, [`SelfCheckError`] can be used as a standard error type.
#[cfg(feature = "std")]
impl std::error::Error for SelfCheckError {
    // No variant wraps an underlying cause, so the provided `source()` (which yields `None`)
    // is exactly right and nothing needs to be overridden.
}
54
/// An in-memory append-only Merkle tree implementation, supporting inclusion and consistency
/// proofs. This stores leaf values, not just leaf hashes.
///
/// Deserializing a tree does not validate it. If the serialized data came from an untrusted
/// source, run [`MemoryBackedTree::self_check`] before using the tree.
#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
#[derive(Clone, Debug)]
pub struct MemoryBackedTree<H, T>
where
    H: Digest,
    T: HashableLeaf,
{
    /// The leaves of this tree, in the order they were pushed. This contains all the items
    pub(crate) leaves: Vec<T>,

    /// The internal nodes of the tree: the hash of every leaf and every parent node, stored in a
    /// flat array. Leaf `i`'s hash is at index `2i`, and parent nodes occupy the odd indices. A
    /// well-formed tree with `n >= 1` leaves holds exactly `2n - 1` entries here.
    // The serde bounds are "" here because every digest::Output is Serializable and
    // Deserializable, with no extra assumptions necessary
    #[cfg_attr(feature = "serde", serde(bound(deserialize = "", serialize = "")))]
    pub(crate) internal_nodes: Vec<digest::Output<H>>,
}
73
74impl<H, T> Default for MemoryBackedTree<H, T>
75where
76    H: Digest,
77    T: HashableLeaf,
78{
79    fn default() -> Self {
80        MemoryBackedTree {
81            leaves: Vec::new(),
82            internal_nodes: Vec::new(),
83        }
84    }
85}
86
87impl<H, T> MemoryBackedTree<H, T>
88where
89    H: Digest,
90    T: HashableLeaf,
91{
92    pub fn new() -> Self {
93        Self::default()
94    }
95
96    /// Appends the given item to the end of the list.
97    ///
98    /// # Panics
99    /// Panics if `self.len() > ⌊usize::MAX / 2⌋ - 1`. Also panics if this tree is malformed,
100    /// e.g., deserialized from disk without passing a [`MemoryBackedTree::self_check`].
101    pub fn push(&mut self, new_val: T) {
102        // Make sure we can push two elements to internal_nodes (two because every append involves
103        // adding a parent node somewhere). usize::MAX is the max capacity of a vector, minus 1. So
104        // usize::MAX-1 is the correct bound to use here. Equivalently, if l is a leaf, then 2l
105        // is the internal index of it. To represent the next two leaves, we need 2(self.len() + 1)
106        // <= usize::MAX, or self.len() + 1 <= usize::MAX / 2
107        assert!(
108            self.internal_nodes.len() < usize::MAX, // equivly, <= usize::MAX - 1
109            "cannot push; tree is full"
110        );
111
112        // We push the new value, a node for its hash, and a node for its parent (assuming the tree
113        // isn't a singleton). The hash and parent nodes will get overwritten by recalculate_path()
114        self.leaves.push(new_val);
115        self.internal_nodes.push(digest::Output::<H>::default());
116
117        // If the tree is not a singleton, add a new parent node
118        if self.internal_nodes.len() > 1 {
119            self.internal_nodes.push(digest::Output::<H>::default());
120        }
121
122        // Recalculate the tree starting at the new leaf
123        let num_leaves = self.len();
124        let new_leaf_idx = LeafIdx::new(num_leaves - 1);
125        // recalculate_path() requires its leaf idx to be less than usize::MAX. This is guaranteed
126        // because it's self.len() - 1.
127        self.recalculate_path(new_leaf_idx)
128    }
129
130    /// Checks that this tree is well-formed. This can take a while if the tree is large. Run this
131    /// if you've deserialized this tree and don't trust the source. If a tree is malformed, other
132    /// methods will panic or behave oddly.
133    pub fn self_check(&self) -> Result<(), SelfCheckError> {
134        // If the number of leaves is more than an in-memory tree could support, return an error
135        let num_leaves = self.len();
136        if num_leaves > (usize::MAX / 2) as u64 + 1 {
137            return Err(SelfCheckError::TooManyLeaves);
138        }
139
140        // If the number of internal nodes is less than the necessary size of the tree, return an error
141        // This cannot panic because we checked that num_leaves isn't too big above
142        let num_nodes = num_internal_nodes(num_leaves);
143        if (self.internal_nodes.len() as u64) < num_nodes {
144            return Err(SelfCheckError::MissingNode(self.internal_nodes.len() as u64));
145        }
146        // If the number of internal nodes exceeds the necessary size of the tree, return an error
147        if (self.internal_nodes.len() as u64) > num_nodes {
148            return Err(SelfCheckError::TooManyInternalNodes);
149        }
150
151        // Start on level 0. We check the leaf hashes
152        for (leaf_idx, leaf) in self.leaves.iter().enumerate() {
153            // This cannot panic because we checked that num_leaves isn't too big above
154            let leaf_hash_idx: InternalIdx = LeafIdx::new(leaf_idx as u64).into();
155
156            // Compute the leaf hash and retrieve the stored leaf hash
157            let expected_hash = leaf_hash::<H, _>(leaf);
158            // We can unwrap() because we checked above that the number of nodes necessary for this
159            // tree fits in memory
160            let Some(stored_hash) = self.internal_nodes.get(leaf_hash_idx.as_usize().unwrap())
161            else {
162                return Err(SelfCheckError::MissingNode(leaf_hash_idx.as_u64()));
163            };
164
165            // If the hashes don't match, that's an error
166            if stored_hash != &expected_hash {
167                return Err(SelfCheckError::IncorrectHash(leaf_hash_idx.as_u64()));
168            }
169        }
170
171        // Now go through the rest of the levels, checking that the current node equals the hash of
172        // the children.
173        for level in 1..=root_idx(num_leaves).level() {
174            // First index on level i is 2^i - 1. Each subsequent index at level i is at an offset
175            // of 2^(i+1).
176            let start_idx = 2u64.pow(level) - 1;
177            let step_size = 2usize.pow(level + 1);
178            for parent_idx in (start_idx..num_nodes).step_by(step_size) {
179                // Get the left and right children, erroring if they don't exist
180                // new() doesn't panic because parent_idx is a valid node in num_leaves
181                let parent_idx = InternalIdx::new(parent_idx);
182                // *_child() don't panic because parent is a parent node, since level >= 1
183                let left_child_idx = parent_idx.left_child();
184                let right_child_idx = parent_idx.right_child(num_leaves);
185
186                // We may unwrap the .as_usize() computations because we already know from the check
187                // above that self.internal_nodes.len() == num_nodes, i.e., the total number of
188                // nodes in the tree fits in memory, and therefore all the indices are at most
189                // `usize::MAX`.
190
191                let left_child = self
192                    .internal_nodes
193                    .get(left_child_idx.as_usize().unwrap())
194                    .ok_or(SelfCheckError::MissingNode(left_child_idx.as_u64()))?;
195                let right_child = self
196                    .internal_nodes
197                    .get(right_child_idx.as_usize().unwrap())
198                    .ok_or(SelfCheckError::MissingNode(right_child_idx.as_u64()))?;
199
200                // Compute the expected hash and get the stored hash
201                let expected_hash = parent_hash::<H>(left_child, right_child);
202                let stored_hash = self
203                    .internal_nodes
204                    .get(parent_idx.as_usize().unwrap())
205                    .ok_or(SelfCheckError::MissingNode(parent_idx.as_u64()))?;
206
207                // If the hashes don't match, that's an error
208                if stored_hash != &expected_hash {
209                    return Err(SelfCheckError::IncorrectHash(parent_idx.as_u64()));
210                }
211            }
212        }
213
214        Ok(())
215    }
216
217    /// Recalculates the hashes on the path from `leaf_idx` to the root.
218    ///
219    /// # Panics
220    /// Panics if the path doesn't exist. In other words, this tree MUST NOT be missing internal
221    /// nodes or leaves. Also panics if the given leaf index exceeds `usize::MAX`.
222    fn recalculate_path(&mut self, leaf_idx: LeafIdx) {
223        // First update the leaf hash
224        let leaf = &self.leaves[leaf_idx.as_usize().unwrap()];
225        let mut cur_idx: InternalIdx = leaf_idx.into();
226        self.internal_nodes[cur_idx.as_usize().unwrap()] = leaf_hash::<H, _>(leaf);
227
228        // Get some data for the upcoming loop
229        let num_leaves = self.len();
230        let root_idx = root_idx(num_leaves);
231
232        // Now iteratively update the parent of cur_idx
233        while cur_idx != root_idx {
234            let parent_idx = cur_idx.parent(num_leaves);
235
236            // We can unwrap() the .as_usize() computations because we assumed the tree is not
237            // missing any internal nodes, i.e., it fits in memory, i.e., all the indices are at
238            // most usize::MAX
239
240            // Get the values of the current node and its sibling
241            let cur_node = &self.internal_nodes[cur_idx.as_usize().unwrap()];
242            let sibling = {
243                let sibling_idx = &cur_idx.sibling(num_leaves);
244                &self.internal_nodes[sibling_idx.as_usize().unwrap()]
245            };
246
247            // Compute the parent hash. If cur_node is to the left of the parent, the hash is
248            // H(0x01 || cur_node || sibling). Otherwise it's H(0x01 || sibling || cur_node).
249            if cur_idx.is_left(num_leaves) {
250                self.internal_nodes[parent_idx.as_usize().unwrap()] =
251                    parent_hash::<H>(cur_node, sibling);
252            } else {
253                self.internal_nodes[parent_idx.as_usize().unwrap()] =
254                    parent_hash::<H>(sibling, cur_node);
255            }
256
257            // Go up a level
258            cur_idx = parent_idx;
259        }
260    }
261
262    /// Returns the root hash of this tree. The value and type uniquely describe this tree.
263    ///
264    /// # Panics
265    /// Panics if this tree is malformed, e.g., deserialized from disk without passing a
266    /// [`MemoryBackedTree::self_check`].
267    pub fn root(&self) -> RootHash<H> {
268        let num_leaves = self.len();
269
270        // Root of an empty tree is H("")
271        let root_hash = if num_leaves == 0 {
272            H::digest(b"")
273        } else {
274            //  Otherwise it's the internal node at the root index
275            // This cannot panic. In a valid tree, self.internal_nodes fits in memory, meaning that
276            // num_leaves is within range.
277            let root_idx = root_idx(num_leaves);
278            // We can unwrap() because we assume we're not missing any internal nodes. That is,
279            // self.internal_nodes.len() <= usize::MAX, which implies that root_idx <= usize::MAX
280            self.internal_nodes[root_idx.as_usize().unwrap()].clone()
281        };
282
283        RootHash::new(root_hash, num_leaves)
284    }
285
286    /// Tries to get the item at the given index
287    pub fn get(&self, idx: usize) -> Option<&T> {
288        self.leaves.get(idx)
289    }
290
291    /// Returns all the items
292    pub fn items(&self) -> &[T] {
293        &self.leaves
294    }
295
296    /// Returns the number of items
297    pub fn len(&self) -> u64 {
298        self.leaves.len() as u64
299    }
300
301    /// Returns true if this tree has no items
302    pub fn is_empty(&self) -> bool {
303        self.len() == 0
304    }
305}
306
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use crate::test_util::{Hash, Leaf};

    use rand::{Rng, RngCore};

    // Creates a random leaf value
    pub(crate) fn rand_val<R: RngCore>(mut rng: R) -> Leaf {
        let mut val = Leaf::default();
        rng.fill_bytes(&mut val);

        val
    }

    // Creates a random MemoryBackedTree with `size` items
    pub(crate) fn rand_tree<R: RngCore>(mut rng: R, size: usize) -> MemoryBackedTree<Hash, Leaf> {
        let mut t = MemoryBackedTree::<Hash, Leaf>::default();

        for _ in 0..size {
            let val = rand_val(&mut rng);
            t.push(val);
        }

        t
    }

    // Adds a bunch of elements to the tree and then tests the tree's well-formedness
    #[test]
    fn self_check() {
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let num_items = rng.random_range(0..230);
            let tree = rand_tree(&mut rng, num_items);
            tree.self_check().expect("self check failed");
        }
    }

    // Flipping a bit in a stored leaf hash must be reported at that leaf's internal index
    #[test]
    fn self_check_detects_bad_leaf_hash() {
        let mut rng = rand::rng();
        let mut tree = rand_tree(&mut rng, 10);

        // Leaf 2's hash is stored at internal index 2 * 2 = 4
        tree.internal_nodes[4][0] ^= 1;
        assert!(matches!(
            tree.self_check(),
            Err(SelfCheckError::IncorrectHash(4))
        ));
    }

    // Flipping a bit in a stored parent hash must be reported at that parent's index
    #[test]
    fn self_check_detects_bad_parent_hash() {
        let mut rng = rand::rng();
        let mut tree = rand_tree(&mut rng, 10);

        // Internal index 1 is the parent of leaves 0 and 1
        tree.internal_nodes[1][0] ^= 1;
        assert!(matches!(
            tree.self_check(),
            Err(SelfCheckError::IncorrectHash(1))
        ));
    }

    // Dropping the last internal node must be reported as that node missing
    #[test]
    fn self_check_detects_missing_node() {
        let mut rng = rand::rng();
        let mut tree = rand_tree(&mut rng, 10);

        // A 10-leaf tree has 19 internal nodes; after popping, index 18 is the first missing one
        tree.internal_nodes.pop();
        assert!(matches!(
            tree.self_check(),
            Err(SelfCheckError::MissingNode(18))
        ));
    }

    // An extra trailing internal node must be rejected
    #[test]
    fn self_check_detects_extra_node() {
        let mut rng = rand::rng();
        let mut tree = rand_tree(&mut rng, 10);

        tree.internal_nodes.push(Default::default());
        assert!(matches!(
            tree.self_check(),
            Err(SelfCheckError::TooManyInternalNodes)
        ));
    }

    // Checks that a serialization round trip doesn't affect trees or roots
    #[cfg(feature = "serde")]
    #[test]
    fn ser_deser() {
        let mut rng = rand::rng();

        for _ in 0..100 {
            let num_items = rng.random_range(0..230);
            let tree = rand_tree(&mut rng, num_items);

            // Serialize and deserialize the tree
            let roundtrip_tree = crate::test_util::serde_roundtrip(tree.clone());

            // Run a self-check and ensure the root hasn't changed
            roundtrip_tree.self_check().unwrap();
            assert_eq!(tree.root(), roundtrip_tree.root());
        }
    }
}