ct_merkle/mem_backed_tree.rs
1use crate::{tree_util::*, RootHash};
2
3use alloc::vec::Vec;
4use core::fmt;
5
6use digest::Digest;
7
8#[cfg(feature = "serde")]
9use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
10
/// Describes why [`MemoryBackedTree::self_check`] rejected a tree.
///
/// The indices carried by [`SelfCheckError::MissingNode`] and [`SelfCheckError::IncorrectHash`]
/// refer to positions in the tree's internal node list.
///
/// The type is `Clone`, `PartialEq` and `Eq` so that callers can compare a returned error against
/// an expected one (e.g., with `assert_eq!`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SelfCheckError {
    /// The node at the given index is missing
    MissingNode(u64),

    /// The node at the given index has the wrong hash
    IncorrectHash(u64),

    /// The number of internal nodes in this struct exceeds the number of nodes that a tree with
    /// this many leaves would hold.
    TooManyInternalNodes,

    /// There are so many leaves that the full tree could not possibly fit in memory
    TooManyLeaves,
}
27
28impl fmt::Display for SelfCheckError {
29 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
30 match self {
31 SelfCheckError::MissingNode(idx) => write!(f, "the node at index {} is missing", idx),
32 SelfCheckError::IncorrectHash(idx) => {
33 write!(f, "the node at index {} has the wrong hash", idx)
34 }
35 SelfCheckError::TooManyInternalNodes => {
36 write!(
37 f,
38 "the number of internal nodes in this struct exceedsc the number of nodes \
39 that a tree with this many leaves would hold"
40 )
41 }
42 SelfCheckError::TooManyLeaves => {
43 write!(
44 f,
45 "there are so many leaves that the full tree could not possibly fit in memory"
46 )
47 }
48 }
49 }
50}
51
// Marker impl so `SelfCheckError` can be used as a standard error type. It is gated on the `std`
// feature because it names `std::error::Error`; the trait's default methods are used as-is.
#[cfg(feature = "std")]
impl std::error::Error for SelfCheckError {}
54
55/// An in-memory append-only Merkle tree implementation, supporting inclusion and consistency
56/// proofs. This stores leaf values, not just leaf hashes.
57#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
58#[derive(Clone, Debug)]
59pub struct MemoryBackedTree<H, T>
60where
61 H: Digest,
62 T: HashableLeaf,
63{
64 /// The leaves of this tree. This contains all the items
65 pub(crate) leaves: Vec<T>,
66
67 /// The internal nodes of the tree. This contains all the hashes of the leaves and parents, etc.
68 // The serde bounds are "" here because every digest::Output is Serializable and
69 // Deserializable, with no extra assumptions necessary
70 #[cfg_attr(feature = "serde", serde(bound(deserialize = "", serialize = "")))]
71 pub(crate) internal_nodes: Vec<digest::Output<H>>,
72}
73
74impl<H, T> Default for MemoryBackedTree<H, T>
75where
76 H: Digest,
77 T: HashableLeaf,
78{
79 fn default() -> Self {
80 MemoryBackedTree {
81 leaves: Vec::new(),
82 internal_nodes: Vec::new(),
83 }
84 }
85}
86
87impl<H, T> MemoryBackedTree<H, T>
88where
89 H: Digest,
90 T: HashableLeaf,
91{
92 pub fn new() -> Self {
93 Self::default()
94 }
95
96 /// Appends the given item to the end of the list.
97 ///
98 /// # Panics
99 /// Panics if `self.len() > ⌊usize::MAX / 2⌋ - 1`. Also panics if this tree is malformed,
100 /// e.g., deserialized from disk without passing a [`MemoryBackedTree::self_check`].
101 pub fn push(&mut self, new_val: T) {
102 // Make sure we can push two elements to internal_nodes (two because every append involves
103 // adding a parent node somewhere). usize::MAX is the max capacity of a vector, minus 1. So
104 // usize::MAX-1 is the correct bound to use here. Equivalently, if l is a leaf, then 2l
105 // is the internal index of it. To represent the next two leaves, we need 2(self.len() + 1)
106 // <= usize::MAX, or self.len() + 1 <= usize::MAX / 2
107 assert!(
108 self.internal_nodes.len() < usize::MAX, // equivly, <= usize::MAX - 1
109 "cannot push; tree is full"
110 );
111
112 // We push the new value, a node for its hash, and a node for its parent (assuming the tree
113 // isn't a singleton). The hash and parent nodes will get overwritten by recalculate_path()
114 self.leaves.push(new_val);
115 self.internal_nodes.push(digest::Output::<H>::default());
116
117 // If the tree is not a singleton, add a new parent node
118 if self.internal_nodes.len() > 1 {
119 self.internal_nodes.push(digest::Output::<H>::default());
120 }
121
122 // Recalculate the tree starting at the new leaf
123 let num_leaves = self.len();
124 let new_leaf_idx = LeafIdx::new(num_leaves - 1);
125 // recalculate_path() requires its leaf idx to be less than usize::MAX. This is guaranteed
126 // because it's self.len() - 1.
127 self.recalculate_path(new_leaf_idx)
128 }
129
130 /// Checks that this tree is well-formed. This can take a while if the tree is large. Run this
131 /// if you've deserialized this tree and don't trust the source. If a tree is malformed, other
132 /// methods will panic or behave oddly.
133 pub fn self_check(&self) -> Result<(), SelfCheckError> {
134 // If the number of leaves is more than an in-memory tree could support, return an error
135 let num_leaves = self.len();
136 if num_leaves > (usize::MAX / 2) as u64 + 1 {
137 return Err(SelfCheckError::TooManyLeaves);
138 }
139
140 // If the number of internal nodes is less than the necessary size of the tree, return an error
141 // This cannot panic because we checked that num_leaves isn't too big above
142 let num_nodes = num_internal_nodes(num_leaves);
143 if (self.internal_nodes.len() as u64) < num_nodes {
144 return Err(SelfCheckError::MissingNode(self.internal_nodes.len() as u64));
145 }
146 // If the number of internal nodes exceeds the necessary size of the tree, return an error
147 if (self.internal_nodes.len() as u64) > num_nodes {
148 return Err(SelfCheckError::TooManyInternalNodes);
149 }
150
151 // Start on level 0. We check the leaf hashes
152 for (leaf_idx, leaf) in self.leaves.iter().enumerate() {
153 // This cannot panic because we checked that num_leaves isn't too big above
154 let leaf_hash_idx: InternalIdx = LeafIdx::new(leaf_idx as u64).into();
155
156 // Compute the leaf hash and retrieve the stored leaf hash
157 let expected_hash = leaf_hash::<H, _>(leaf);
158 // We can unwrap() because we checked above that the number of nodes necessary for this
159 // tree fits in memory
160 let Some(stored_hash) = self.internal_nodes.get(leaf_hash_idx.as_usize().unwrap())
161 else {
162 return Err(SelfCheckError::MissingNode(leaf_hash_idx.as_u64()));
163 };
164
165 // If the hashes don't match, that's an error
166 if stored_hash != &expected_hash {
167 return Err(SelfCheckError::IncorrectHash(leaf_hash_idx.as_u64()));
168 }
169 }
170
171 // Now go through the rest of the levels, checking that the current node equals the hash of
172 // the children.
173 for level in 1..=root_idx(num_leaves).level() {
174 // First index on level i is 2^i - 1. Each subsequent index at level i is at an offset
175 // of 2^(i+1).
176 let start_idx = 2u64.pow(level) - 1;
177 let step_size = 2usize.pow(level + 1);
178 for parent_idx in (start_idx..num_nodes).step_by(step_size) {
179 // Get the left and right children, erroring if they don't exist
180 // new() doesn't panic because parent_idx is a valid node in num_leaves
181 let parent_idx = InternalIdx::new(parent_idx);
182 // *_child() don't panic because parent is a parent node, since level >= 1
183 let left_child_idx = parent_idx.left_child();
184 let right_child_idx = parent_idx.right_child(num_leaves);
185
186 // We may unwrap the .as_usize() computations because we already know from the check
187 // above that self.internal_nodes.len() == num_nodes, i.e., the total number of
188 // nodes in the tree fits in memory, and therefore all the indices are at most
189 // `usize::MAX`.
190
191 let left_child = self
192 .internal_nodes
193 .get(left_child_idx.as_usize().unwrap())
194 .ok_or(SelfCheckError::MissingNode(left_child_idx.as_u64()))?;
195 let right_child = self
196 .internal_nodes
197 .get(right_child_idx.as_usize().unwrap())
198 .ok_or(SelfCheckError::MissingNode(right_child_idx.as_u64()))?;
199
200 // Compute the expected hash and get the stored hash
201 let expected_hash = parent_hash::<H>(left_child, right_child);
202 let stored_hash = self
203 .internal_nodes
204 .get(parent_idx.as_usize().unwrap())
205 .ok_or(SelfCheckError::MissingNode(parent_idx.as_u64()))?;
206
207 // If the hashes don't match, that's an error
208 if stored_hash != &expected_hash {
209 return Err(SelfCheckError::IncorrectHash(parent_idx.as_u64()));
210 }
211 }
212 }
213
214 Ok(())
215 }
216
217 /// Recalculates the hashes on the path from `leaf_idx` to the root.
218 ///
219 /// # Panics
220 /// Panics if the path doesn't exist. In other words, this tree MUST NOT be missing internal
221 /// nodes or leaves. Also panics if the given leaf index exceeds `usize::MAX`.
222 fn recalculate_path(&mut self, leaf_idx: LeafIdx) {
223 // First update the leaf hash
224 let leaf = &self.leaves[leaf_idx.as_usize().unwrap()];
225 let mut cur_idx: InternalIdx = leaf_idx.into();
226 self.internal_nodes[cur_idx.as_usize().unwrap()] = leaf_hash::<H, _>(leaf);
227
228 // Get some data for the upcoming loop
229 let num_leaves = self.len();
230 let root_idx = root_idx(num_leaves);
231
232 // Now iteratively update the parent of cur_idx
233 while cur_idx != root_idx {
234 let parent_idx = cur_idx.parent(num_leaves);
235
236 // We can unwrap() the .as_usize() computations because we assumed the tree is not
237 // missing any internal nodes, i.e., it fits in memory, i.e., all the indices are at
238 // most usize::MAX
239
240 // Get the values of the current node and its sibling
241 let cur_node = &self.internal_nodes[cur_idx.as_usize().unwrap()];
242 let sibling = {
243 let sibling_idx = &cur_idx.sibling(num_leaves);
244 &self.internal_nodes[sibling_idx.as_usize().unwrap()]
245 };
246
247 // Compute the parent hash. If cur_node is to the left of the parent, the hash is
248 // H(0x01 || cur_node || sibling). Otherwise it's H(0x01 || sibling || cur_node).
249 if cur_idx.is_left(num_leaves) {
250 self.internal_nodes[parent_idx.as_usize().unwrap()] =
251 parent_hash::<H>(cur_node, sibling);
252 } else {
253 self.internal_nodes[parent_idx.as_usize().unwrap()] =
254 parent_hash::<H>(sibling, cur_node);
255 }
256
257 // Go up a level
258 cur_idx = parent_idx;
259 }
260 }
261
262 /// Returns the root hash of this tree. The value and type uniquely describe this tree.
263 ///
264 /// # Panics
265 /// Panics if this tree is malformed, e.g., deserialized from disk without passing a
266 /// [`MemoryBackedTree::self_check`].
267 pub fn root(&self) -> RootHash<H> {
268 let num_leaves = self.len();
269
270 // Root of an empty tree is H("")
271 let root_hash = if num_leaves == 0 {
272 H::digest(b"")
273 } else {
274 // Otherwise it's the internal node at the root index
275 // This cannot panic. In a valid tree, self.internal_nodes fits in memory, meaning that
276 // num_leaves is within range.
277 let root_idx = root_idx(num_leaves);
278 // We can unwrap() because we assume we're not missing any internal nodes. That is,
279 // self.internal_nodes.len() <= usize::MAX, which implies that root_idx <= usize::MAX
280 self.internal_nodes[root_idx.as_usize().unwrap()].clone()
281 };
282
283 RootHash::new(root_hash, num_leaves)
284 }
285
286 /// Tries to get the item at the given index
287 pub fn get(&self, idx: usize) -> Option<&T> {
288 self.leaves.get(idx)
289 }
290
291 /// Returns all the items
292 pub fn items(&self) -> &[T] {
293 &self.leaves
294 }
295
296 /// Returns the number of items
297 pub fn len(&self) -> u64 {
298 self.leaves.len() as u64
299 }
300
301 /// Returns true if this tree has no items
302 pub fn is_empty(&self) -> bool {
303 self.len() == 0
304 }
305}
306
307#[cfg(test)]
308pub(crate) mod test {
309 use super::*;
310 use crate::test_util::{Hash, Leaf};
311
312 use rand::{Rng, RngCore};
313
314 // Creates a random T
315 pub(crate) fn rand_val<R: RngCore>(mut rng: R) -> Leaf {
316 let mut val = Leaf::default();
317 rng.fill_bytes(&mut val);
318
319 val
320 }
321
322 // Creates a random CtMerkleTree with `size` items
323 pub(crate) fn rand_tree<R: RngCore>(mut rng: R, size: usize) -> MemoryBackedTree<Hash, Leaf> {
324 let mut t = MemoryBackedTree::<Hash, Leaf>::default();
325
326 for _ in 0..size {
327 let val = rand_val(&mut rng);
328 t.push(val);
329 }
330
331 t
332 }
333
// Builds many random trees of varying size and requires each one to pass its own self-check
#[test]
fn self_check() {
    let mut rng = rand::rng();

    let mut trials_left = 1000;
    while trials_left > 0 {
        let tree_size = rng.random_range(0..230);
        rand_tree(&mut rng, tree_size)
            .self_check()
            .expect("self check failed");

        trials_left -= 1;
    }
}
344
// A tree that goes through serialization and back must still be well-formed and must still have
// the same root
#[cfg(feature = "serde")]
#[test]
fn ser_deser() {
    let mut rng = rand::rng();

    for _ in 0..100 {
        let tree_size = rng.random_range(0..230);
        let original = rand_tree(&mut rng, tree_size);

        // Remember the root, then send the tree through a serialize/deserialize round trip
        let original_root = original.root();
        let restored = crate::test_util::serde_roundtrip(original);

        // The restored tree has to pass the self-check and reproduce the original root
        restored.self_check().unwrap();
        assert_eq!(original_root, restored.root());
    }
}
363}