// milhouse/tree.rs

1use crate::utils::{Length, opt_hash, opt_packing_depth, opt_packing_factor};
2use crate::{Arc, Error, Leaf, PackedLeaf, UpdateMap, Value};
3use educe::Educe;
4use ethereum_hashing::{ZERO_HASHES, hash32_concat};
5use parking_lot::RwLock;
6use std::collections::BTreeMap;
7use std::collections::HashMap;
8use std::ops::ControlFlow;
9use tree_hash::Hash256;
10
/// A node in a persistent binary Merkle tree over values of type `T`.
///
/// Subtrees are shared via `Arc`, so cloning and updating trees is cheap:
/// only the path from the root to a changed leaf is rewritten.
#[derive(Debug, Educe)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[educe(PartialEq(bound(T: Value)), Hash)]
pub enum Tree<T: Value> {
    /// A single unpacked value at the bottom of the tree.
    Leaf(Leaf<T>),
    /// Several small values packed into a single leaf
    /// (see `T::tree_hash_packing_factor`).
    PackedLeaf(PackedLeaf<T>),
    /// An internal node with two children and a lazily-computed cached hash.
    Node {
        /// Cached hash of this subtree; `Hash256::ZERO` means "not yet computed".
        /// Ignored by `PartialEq`/`Hash` so that logically-equal trees compare
        /// equal regardless of which hashes happen to be cached.
        #[educe(PartialEq(ignore), Hash(ignore))]
        #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::utils::arb_rwlock))]
        hash: RwLock<Hash256>,
        #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::utils::arb_arc))]
        left: Arc<Self>,
        #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::utils::arb_arc))]
        right: Arc<Self>,
    },
    /// An entirely-zero subtree of the given depth, stored compactly instead
    /// of materialising all of its nodes.
    Zero(usize),
}
28
29impl<T: Value> Clone for Tree<T> {
30    fn clone(&self) -> Self {
31        match self {
32            Self::Node { hash, left, right } => Self::Node {
33                hash: RwLock::new(*hash.read()),
34                left: left.clone(),
35                right: right.clone(),
36            },
37            Self::Leaf(leaf) => Self::Leaf(leaf.clone()),
38            Self::PackedLeaf(leaf) => Self::PackedLeaf(leaf.clone()),
39            Self::Zero(depth) => Self::Zero(*depth),
40        }
41    }
42}
43
impl<T: Value> Tree<T> {
    /// Return an empty (all-zero) tree of the given `depth`.
    pub fn empty(depth: usize) -> Arc<Self> {
        Self::zero(depth)
    }

    /// Construct an internal node from two children and a cached `hash`.
    ///
    /// Pass `Hash256::ZERO` if the hash is not yet known: a zero hash is
    /// treated as "uncomputed" and filled in lazily by `tree_hash`.
    pub fn node(left: Arc<Self>, right: Arc<Self>, hash: Hash256) -> Arc<Self> {
        Arc::new(Self::Node {
            hash: RwLock::new(hash),
            left,
            right,
        })
    }

    /// Construct a compact all-zero subtree of the given `depth`.
    pub fn zero(depth: usize) -> Arc<Self> {
        Arc::new(Self::Zero(depth))
    }

    /// Construct a leaf holding `value` with its hash left uncomputed.
    pub fn leaf(value: T) -> Arc<Self> {
        Arc::new(Self::Leaf(Leaf::new(value)))
    }

    /// Construct a leaf holding `value` with a known `hash`, avoiding a
    /// later re-hash.
    pub fn leaf_with_hash(value: T, hash: Hash256) -> Arc<Self> {
        Arc::new(Self::Leaf(Leaf::with_hash(value, hash)))
    }

    /// Like `node`, but returns the node by value (un-`Arc`ed) with an
    /// uncomputed (zero) hash.
    pub fn node_unboxed(left: Arc<Self>, right: Arc<Self>) -> Self {
        Self::Node {
            hash: RwLock::new(Hash256::ZERO),
            left,
            right,
        }
    }

    /// Like `zero`, but returns the node by value (un-`Arc`ed).
    pub fn zero_unboxed(depth: usize) -> Self {
        Self::Zero(depth)
    }

    /// Like `leaf`, but returns the node by value (un-`Arc`ed).
    pub fn leaf_unboxed(value: T) -> Self {
        Self::Leaf(Leaf::new(value))
    }

    /// Fetch a reference to the value at `index` in a subtree of the given
    /// `depth`, or `None` if the index falls in a `Zero` subtree or the tree
    /// shape does not match `depth`.
    ///
    /// At each internal node, bit `new_depth + packing_depth` of `index`
    /// selects the child: 0 = left, 1 = right. Within a packed leaf the low
    /// bits of `index` select the value.
    pub fn get_recursive(&self, index: usize, depth: usize, packing_depth: usize) -> Option<&T> {
        match self {
            Self::Leaf(Leaf { value, .. }) if depth == 0 => Some(value),
            Self::PackedLeaf(PackedLeaf { values, .. }) if depth == 0 => {
                values.get(index % T::tree_hash_packing_factor())
            }
            Self::Node { left, right, .. } if depth > 0 => {
                let new_depth = depth - 1;
                // Left
                if (index >> (new_depth + packing_depth)) & 1 == 0 {
                    left.get_recursive(index, new_depth, packing_depth)
                }
                // Right
                else {
                    right.get_recursive(index, new_depth, packing_depth)
                }
            }
            // Depth mismatch or `Zero` subtree: value not present.
            _ => None,
        }
    }

    /// Create a new tree where the `index`th leaf is set to `new_value`.
    ///
    /// Only the nodes on the path from the root to the updated leaf are
    /// rewritten (with their cached hashes reset to zero); all other subtrees
    /// are shared with `self` via `Arc`.
    ///
    /// NOTE: callers are responsible for bounds-checking `index` before calling this function.
    pub fn with_updated_leaf(
        &self,
        index: usize,
        new_value: T,
        depth: usize,
    ) -> Result<Arc<Self>, Error> {
        match self {
            Self::Leaf(_) if depth == 0 => Ok(Self::leaf(new_value)),
            Self::PackedLeaf(leaf) if depth == 0 => Ok(Arc::new(Self::PackedLeaf(
                leaf.insert_at_index(index, new_value)?,
            ))),
            Self::Node { left, right, .. } if depth > 0 => {
                let packing_depth = opt_packing_depth::<T>().unwrap_or(0);
                let new_depth = depth - 1;
                if (index >> (new_depth + packing_depth)) & 1 == 0 {
                    // Index lies on the left, recurse left
                    Ok(Self::node(
                        left.with_updated_leaf(index, new_value, new_depth)?,
                        right.clone(),
                        Hash256::ZERO,
                    ))
                } else {
                    // Index lies on the right, recurse right
                    Ok(Self::node(
                        left.clone(),
                        right.with_updated_leaf(index, new_value, new_depth)?,
                        Hash256::ZERO,
                    ))
                }
            }
            Self::Zero(zero_depth) if *zero_depth == depth => {
                if depth == 0 {
                    // A zero leaf becomes a packed leaf if `T` is packable,
                    // otherwise an ordinary leaf.
                    if opt_packing_factor::<T>().is_some() {
                        Ok(Arc::new(Self::PackedLeaf(PackedLeaf::single(new_value))))
                    } else {
                        Ok(Self::leaf(new_value))
                    }
                } else {
                    // Split zero node into a node with left and right, and recurse into
                    // the appropriate subtree
                    let new_zero = Self::zero(depth - 1);
                    Self::node(new_zero.clone(), new_zero, Hash256::ZERO)
                        .with_updated_leaf(index, new_value, depth)
                }
            }
            // Tree shape inconsistent with `depth`.
            _ => Err(Error::UpdateLeafError),
        }
    }

    /// Create a new tree with every update in `updates` applied to this
    /// subtree, sharing unmodified subtrees with `self`.
    ///
    /// - `prefix`: index prefix of this subtree (all of its leaf indices
    ///   start with these high bits).
    /// - `depth`: depth of this subtree.
    /// - `hashes`: optional pre-computed hashes, looked up by depth and
    ///   prefix via `opt_hash`; `Hash256::ZERO` is used when absent.
    ///
    /// Returns an error if `updates` contains no update for this subtree, or
    /// if the tree shape is inconsistent with `depth`.
    pub fn with_updated_leaves<U: UpdateMap<T>>(
        &self,
        updates: &U,
        prefix: usize,
        depth: usize,
        hashes: Option<&BTreeMap<(usize, usize), Hash256>>,
    ) -> Result<Arc<Self>, Error> {
        let hash = opt_hash(hashes, depth, prefix).unwrap_or_default();

        match self {
            Self::Leaf(_) if depth == 0 => {
                let index = prefix;
                let value = updates
                    .get(index)
                    .cloned()
                    .ok_or(Error::LeafUpdateMissing { index })?;
                Ok(Self::leaf_with_hash(value, hash))
            }
            Self::PackedLeaf(packed_leaf) if depth == 0 => Ok(Arc::new(Self::PackedLeaf(
                packed_leaf.update(prefix, hash, updates)?,
            ))),
            Self::Node { left, right, .. } if depth > 0 => {
                let packing_depth = opt_packing_depth::<T>().unwrap_or(0);
                let new_depth = depth - 1;
                // Index ranges covered by the children:
                // left  = [left_prefix, right_prefix)
                // right = [right_prefix, right_subtree_end)
                let left_prefix = prefix;
                let right_prefix = prefix | (1 << (new_depth + packing_depth));
                let right_subtree_end = prefix + (1 << (depth + packing_depth));

                // Probe each child's range for at least one update, breaking
                // out of the iteration as soon as one is found.
                let mut has_left_updates = false;
                updates.for_each_range(left_prefix, right_prefix, |_, _| {
                    has_left_updates = true;
                    ControlFlow::Break(())
                })?;
                let mut has_right_updates = false;
                updates.for_each_range(right_prefix, right_subtree_end, |_, _| {
                    has_right_updates = true;
                    ControlFlow::Break(())
                })?;

                // Must have some updates else this recursive branch is a complete waste of time.
                if !has_left_updates && !has_right_updates {
                    return Err(Error::NodeUpdatesMissing { prefix });
                }

                let new_left = if has_left_updates {
                    left.with_updated_leaves(updates, left_prefix, new_depth, hashes)?
                } else {
                    left.clone()
                };
                let new_right = if has_right_updates {
                    right.with_updated_leaves(updates, right_prefix, new_depth, hashes)?
                } else {
                    right.clone()
                };

                Ok(Self::node(new_left, new_right, hash))
            }
            Self::Zero(zero_depth) if *zero_depth == depth => {
                if depth == 0 {
                    if opt_packing_factor::<T>().is_some() {
                        let packed_leaf = PackedLeaf::empty().update(prefix, hash, updates)?;
                        Ok(Arc::new(Self::PackedLeaf(packed_leaf)))
                    } else {
                        let index = prefix;
                        let value = updates
                            .get(index)
                            .cloned()
                            .ok_or(Error::LeafUpdateMissing { index })?;
                        Ok(Self::leaf_with_hash(value, hash))
                    }
                } else {
                    // Split zero node into a node with left and right and recurse.
                    let new_zero = Self::zero(depth - 1);
                    Self::node(new_zero.clone(), new_zero, hash)
                        .with_updated_leaves(updates, prefix, depth, hashes)
                }
            }
            _ => Err(Error::UpdateLeavesError),
        }
    }

    /// Compute the number of elements stored in this subtree.
    ///
    /// This method should be avoided if possible. Prefer to read the length cached in a `List` or
    /// similar.
    pub fn compute_len(&self) -> usize {
        match self {
            Self::Leaf(_) => 1,
            Self::PackedLeaf(leaf) => leaf.values.len(),
            Self::Node { left, right, .. } => left.compute_len() + right.compute_len(),
            // Zero subtrees store no elements.
            Self::Zero(_) => 0,
        }
    }
}
252
/// Outcome of rebasing one subtree of `orig` against the corresponding
/// subtree of `base` (see `Tree::rebase_on`).
pub enum RebaseAction<'a, T> {
    /// Not equal and no changes in parent nodes required.
    NotEqualNoop,
    /// Not equal, but `new` should be replaced by the given node.
    NotEqualReplace(Arc<T>),
    /// Nodes are already exactly equal and pointer equal.
    EqualNoop,
    /// Nodes are exactly equal and `new` should be replaced by the given node.
    EqualReplace(&'a Arc<T>),
}
263
/// Outcome of an `intra_rebase` traversal for a single subtree.
pub enum IntraRebaseAction<T> {
    /// Subtree unchanged; keep the existing node.
    Noop,
    /// Subtree should be replaced by the given (deduplicated) node.
    Replace(Arc<T>),
}
268
impl<T: Value> Tree<T> {
    /// Rebase `orig` on `base`: find subtrees of `orig` that are equal to the
    /// corresponding subtrees of `base` and share `base`'s copies instead, so
    /// that the two trees use common memory for their common parts.
    ///
    /// Arguments:
    ///
    /// - `lengths`: `(orig_length, base_length)` element counts for the
    ///   current subtrees when lengths may differ (e.g. lists). Pass `None`
    ///   when lengths are known equal (e.g. vectors), enabling the hash-only
    ///   short-cut below.
    /// - `full_depth`: depth of the current subtrees.
    ///
    /// The returned `RebaseAction` tells the caller whether `orig` should be
    /// kept as-is or replaced by a node that shares structure with `base`.
    pub fn rebase_on<'a>(
        orig: &'a Arc<Self>,
        base: &'a Arc<Self>,
        lengths: Option<(Length, Length)>,
        full_depth: usize,
    ) -> Result<RebaseAction<'a, Self>, Error> {
        // Pointer-equal nodes are already shared: nothing to do.
        if Arc::ptr_eq(orig, base) {
            return Ok(RebaseAction::EqualNoop);
        }
        match (&**orig, &**base) {
            (Self::Leaf(l1), Self::Leaf(l2)) => {
                if l1.value == l2.value {
                    Ok(RebaseAction::EqualReplace(base))
                } else {
                    Ok(RebaseAction::NotEqualNoop)
                }
            }
            (Self::PackedLeaf(l1), Self::PackedLeaf(l2)) => {
                if l1.values == l2.values {
                    Ok(RebaseAction::EqualReplace(base))
                } else {
                    Ok(RebaseAction::NotEqualNoop)
                }
            }
            (Self::Zero(z1), Self::Zero(z2)) if z1 == z2 => Ok(RebaseAction::EqualReplace(base)),
            (
                Self::Node {
                    hash: orig_hash_lock,
                    left: l1,
                    right: r1,
                },
                Self::Node {
                    hash: base_hash_lock,
                    left: l2,
                    right: r2,
                },
            ) if full_depth > 0 => {
                use RebaseAction::*;

                let orig_hash = *orig_hash_lock.read();
                let base_hash = *base_hash_lock.read();

                // If hashes *and* lengths are equal then we can short-cut the recursion
                // and immediately replace `orig` by the `base` node. If `lengths` are `None`
                // then we know they are already equal (e.g. we're in a vector).
                if !orig_hash.is_zero()
                    && orig_hash == base_hash
                    && lengths.is_none_or(|(orig_length, base_length)| orig_length == base_length)
                {
                    return Ok(EqualReplace(base));
                }

                // Split the lengths between the left and right subtrees: the
                // left subtree holds at most `1 << new_full_depth` elements
                // and the right subtree holds the remainder.
                let new_full_depth = full_depth - 1;
                let (left_lengths, right_lengths) = lengths
                    .map(|(orig_length, base_length)| {
                        let max_left_length = Length(1 << new_full_depth);
                        let orig_left_length = std::cmp::min(orig_length, max_left_length);
                        let orig_right_length =
                            Length(orig_length.as_usize() - orig_left_length.as_usize());

                        let base_left_length = std::cmp::min(base_length, max_left_length);
                        let base_right_length =
                            Length(base_length.as_usize() - base_left_length.as_usize());
                        (
                            (orig_left_length, base_left_length),
                            (orig_right_length, base_right_length),
                        )
                    })
                    .unzip();

                let left_action = Tree::rebase_on(l1, l2, left_lengths, new_full_depth)?;
                let right_action = Tree::rebase_on(r1, r2, right_lengths, new_full_depth)?;

                // Combine the child actions. A replacement on either side
                // forces this node to be rebuilt (preserving `orig`'s cached
                // hash); only when both sides are equal can the whole node be
                // shared or left untouched.
                match (left_action, right_action) {
                    (NotEqualNoop, NotEqualNoop | EqualNoop) | (EqualNoop, NotEqualNoop) => {
                        Ok(NotEqualNoop)
                    }
                    (EqualNoop, EqualNoop) => Ok(EqualNoop),
                    (NotEqualNoop | EqualNoop, NotEqualReplace(new_right)) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: l1.clone(),
                            right: new_right,
                        })))
                    }
                    (NotEqualNoop | EqualNoop, EqualReplace(new_right)) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: l1.clone(),
                            right: new_right.clone(),
                        })))
                    }
                    (NotEqualReplace(new_left), NotEqualNoop | EqualNoop) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: new_left,
                            right: r1.clone(),
                        })))
                    }
                    (NotEqualReplace(new_left), NotEqualReplace(new_right)) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: new_left,
                            right: new_right,
                        })))
                    }
                    (NotEqualReplace(new_left), EqualReplace(new_right)) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: new_left,
                            right: new_right.clone(),
                        })))
                    }
                    (EqualReplace(new_left), NotEqualNoop) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: new_left.clone(),
                            right: r1.clone(),
                        })))
                    }
                    (EqualReplace(new_left), NotEqualReplace(new_right)) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: new_left.clone(),
                            right: new_right,
                        })))
                    }
                    (EqualReplace(_), EqualReplace(_)) | (EqualReplace(_), EqualNoop) => {
                        Ok(EqualReplace(base))
                    }
                }
            }
            // A zero subtree paired with anything else is simply not equal.
            (Self::Zero(_), _) | (_, Self::Zero(_)) => Ok(RebaseAction::NotEqualNoop),
            // Two nodes at full_depth == 0: the trees are malformed.
            (Self::Node { .. }, Self::Node { .. }) => Err(Error::InvalidRebaseNode),
            // A leaf paired with a non-leaf: the trees are misaligned.
            (Self::Leaf(_) | Self::PackedLeaf(_), _) | (_, Self::Leaf(_) | Self::PackedLeaf(_)) => {
                Err(Error::InvalidRebaseLeaf)
            }
        }
    }

    /// Exploit structural sharing between identical parts of the tree.
    ///
    /// This method traverses a fully-hashed tree and replaces identical subtrees with clones of
    /// the first equal subtree. The result is a tree that shares memory for common subtrees, and
    /// thus uses less memory overall.
    ///
    /// You MUST pass a fully-hashed tree to this function, or an `Error::IntraRebaseZeroHash`
    /// error will be returned.
    ///
    /// Arguments are:
    ///
    /// - `orig`: The tree to rebase.
    /// - `known_subtrees`: map from `(depth, tree_hash_root)` to `Arc<Node>`. This should be empty
    ///   for the top-level call. The recursive calls fill it in. It can be discarded after the
    ///   method returns.
    /// - `current_depth`: The depth of the tree `orig`. This will be decremented as we recurse
    ///   down the tree towards the leaves.
    ///
    /// Presently leaves are left untouched by this procedure, so it will only produce savings in
    /// trees with equal internal nodes (i.e. equal subtrees with at least two leaves/packed leaves
    /// under them).
    ///
    /// The input tree must be fully-hashed, and the result will also remain fully-hashed.
    pub fn intra_rebase(
        orig: &Arc<Self>,
        known_subtrees: &mut HashMap<(usize, Hash256), Arc<Self>>,
        current_depth: usize,
    ) -> Result<IntraRebaseAction<Self>, Error> {
        match &**orig {
            Self::Leaf(_) | Self::PackedLeaf(_) | Self::Zero(_) => Ok(IntraRebaseAction::Noop),
            Self::Node { hash, left, right } if current_depth > 0 => {
                let hash = *hash.read();

                // Tree must be fully hashed prior to intra-rebase.
                if hash.is_zero() {
                    return Err(Error::IntraRebaseZeroHash);
                }

                if let Some(known_subtree) = known_subtrees.get(&(current_depth, hash)) {
                    // Node is already known from elsewhere in the tree. We can replace it without
                    // looking at further subtrees.
                    return Ok(IntraRebaseAction::Replace(known_subtree.clone()));
                }

                let left_action = Self::intra_rebase(left, known_subtrees, current_depth - 1)?;
                let right_action = Self::intra_rebase(right, known_subtrees, current_depth - 1)?;

                // Rebuild this node only if at least one child was replaced,
                // carrying over the existing (non-zero) cached hash.
                let action = match (left_action, right_action) {
                    (IntraRebaseAction::Noop, IntraRebaseAction::Noop) => IntraRebaseAction::Noop,
                    (IntraRebaseAction::Noop, IntraRebaseAction::Replace(new_right)) => {
                        IntraRebaseAction::Replace(Self::node(left.clone(), new_right, hash))
                    }
                    (IntraRebaseAction::Replace(new_left), IntraRebaseAction::Noop) => {
                        IntraRebaseAction::Replace(Self::node(new_left, right.clone(), hash))
                    }
                    (
                        IntraRebaseAction::Replace(new_left),
                        IntraRebaseAction::Replace(new_right),
                    ) => IntraRebaseAction::Replace(Self::node(new_left, new_right, hash)),
                };

                // Add the new version of this node to the known subtrees.
                let new_subtree = match &action {
                    // `orig` has not been seen in this traversal and will not change, so we add it
                    // to the map.
                    IntraRebaseAction::Noop => orig.clone(),
                    IntraRebaseAction::Replace(new) => new.clone(),
                };
                let existing_entry = known_subtrees.insert((current_depth, hash), new_subtree);

                // We should not add any identical node to the `known_subtrees` more than once.
                // This indicates an error in this method's implementation or the map passed in not
                // being empty.
                if existing_entry.is_some() {
                    return Err(Error::IntraRebaseRepeatVisit);
                }

                Ok(action)
            }
            // An internal node at depth 0 indicates a malformed tree.
            Self::Node { .. } => Err(Error::IntraRebaseZeroDepth),
        }
    }
}
493
impl<T: Value + Send + Sync> Tree<T> {
    /// Compute (and cache) the Merkle root of this subtree.
    ///
    /// Cached hashes live in the per-node `RwLock<Hash256>` fields, with
    /// `Hash256::ZERO` meaning "not yet computed". Internal-node hashing is
    /// parallelised across children with `rayon::join`, hence the
    /// `Send + Sync` bound on `T`.
    pub fn tree_hash(&self) -> Hash256 {
        match self {
            Self::Leaf(Leaf { hash, value }) => {
                // FIXME(sproul): upgradeable RwLock?
                // Read the cached value then release the lock before any
                // (potentially slow) hashing work.
                let read_lock = hash.read();
                let existing_hash = *read_lock;
                drop(read_lock);

                // NOTE: We re-compute the hash whenever it is non-zero. Computed hashes may
                // legitimately be zero, but this only occurs at the leaf level when the value is
                // entirely zeroes (e.g. [0u64, 0, 0, 0]). In order to avoid storing an
                // `Option<Hash256>` we choose to re-compute the hash in this case. In practice
                // this is unlikely to provide any performance penalty except at very small list
                // lengths (<= 32), because a node higher in the tree will cache a non-zero hash
                // preventing its children from being visited more than once.
                if !existing_hash.is_zero() {
                    existing_hash
                } else {
                    let tree_hash = value.tree_hash_root();
                    *hash.write() = tree_hash;
                    tree_hash
                }
            }
            Self::PackedLeaf(leaf) => leaf.tree_hash(),
            // Zero subtrees use the precomputed table of zero hashes, indexed
            // by depth.
            Self::Zero(depth) => Hash256::from(ZERO_HASHES[*depth]),
            Self::Node { hash, left, right } => {
                // Same read-then-release pattern as the `Leaf` arm above.
                let read_lock = hash.read();
                let existing_hash = *read_lock;
                drop(read_lock);

                if !existing_hash.is_zero() {
                    existing_hash
                } else {
                    // Parallelism goes brrrr.
                    let (left_hash, right_hash) =
                        rayon::join(|| left.tree_hash(), || right.tree_hash());
                    let tree_hash =
                        Hash256::from(hash32_concat(left_hash.as_slice(), right_hash.as_slice()));
                    *hash.write() = tree_hash;
                    tree_hash
                }
            }
        }
    }
}
539}