eth_trie_utils/
trie_subsets.rs

1//! Logic for calculating a subset of a [`PartialTrie`] from an existing
2//! [`PartialTrie`].
3//!
//! Given a `PartialTrie`, you can pass in keys of leaf nodes that should be
//! included in the produced subset. Any nodes that are not needed in the subset
//! are replaced with [`Hash`] nodes as far up the trie as possible.
7
8use std::sync::Arc;
9
10use ethereum_types::H256;
11use thiserror::Error;
12
13use crate::{
14    nibbles::Nibbles,
15    partial_trie::{Node, PartialTrie, WrappedNode},
16    utils::TrieNodeType,
17};
18
19pub type SubsetTrieResult<T> = Result<T, SubsetTrieError>;
20
21/// Errors that may occur when creating a subset [`PartialTrie`].
22#[derive(Debug, Error)]
23pub enum SubsetTrieError {
24    #[error("Tried to mark nodes in a tracked trie for a key that does not exist! (Key: {0}, trie: {1})")]
25    UnexpectedKey(Nibbles, String),
26}
27
28#[derive(Debug)]
29enum TrackedNodeIntern<N: PartialTrie> {
30    Empty,
31    Hash,
32    Branch(Box<[TrackedNode<N>; 16]>),
33    Extension(Box<TrackedNode<N>>),
34    Leaf,
35}
36
37#[derive(Debug)]
38struct TrackedNode<N: PartialTrie> {
39    node: TrackedNodeIntern<N>,
40    info: TrackedNodeInfo<N>,
41}
42
43impl<N: Clone + PartialTrie> TrackedNode<N> {
44    fn new(underlying_node: &N) -> Self {
45        Self {
46            node: match &**underlying_node {
47                Node::Empty => TrackedNodeIntern::Empty,
48                Node::Hash(_) => TrackedNodeIntern::Hash,
49                Node::Branch { ref children, .. } => {
50                    TrackedNodeIntern::Branch(Box::new(tracked_branch(children)))
51                }
52                Node::Extension { child, .. } => {
53                    TrackedNodeIntern::Extension(Box::new(TrackedNode::new(child)))
54                }
55                Node::Leaf { .. } => TrackedNodeIntern::Leaf,
56            },
57            info: TrackedNodeInfo::new(underlying_node.clone()),
58        }
59    }
60}
61
62fn tracked_branch<N: PartialTrie>(
63    underlying_children: &[WrappedNode<N>; 16],
64) -> [TrackedNode<N>; 16] {
65    [
66        TrackedNode::new(&underlying_children[0]),
67        TrackedNode::new(&underlying_children[1]),
68        TrackedNode::new(&underlying_children[2]),
69        TrackedNode::new(&underlying_children[3]),
70        TrackedNode::new(&underlying_children[4]),
71        TrackedNode::new(&underlying_children[5]),
72        TrackedNode::new(&underlying_children[6]),
73        TrackedNode::new(&underlying_children[7]),
74        TrackedNode::new(&underlying_children[8]),
75        TrackedNode::new(&underlying_children[9]),
76        TrackedNode::new(&underlying_children[10]),
77        TrackedNode::new(&underlying_children[11]),
78        TrackedNode::new(&underlying_children[12]),
79        TrackedNode::new(&underlying_children[13]),
80        TrackedNode::new(&underlying_children[14]),
81        TrackedNode::new(&underlying_children[15]),
82    ]
83}
84
85fn partial_trie_extension<N: PartialTrie>(nibbles: Nibbles, child: &TrackedNode<N>) -> N {
86    N::new(Node::Extension {
87        nibbles,
88        child: Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
89            child,
90        ))),
91    })
92}
93
94fn partial_trie_branch<N: PartialTrie>(
95    underlying_children: &[TrackedNode<N>; 16],
96    value: &[u8],
97) -> N {
98    let children = [
99        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
100            &underlying_children[0],
101        ))),
102        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
103            &underlying_children[1],
104        ))),
105        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
106            &underlying_children[2],
107        ))),
108        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
109            &underlying_children[3],
110        ))),
111        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
112            &underlying_children[4],
113        ))),
114        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
115            &underlying_children[5],
116        ))),
117        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
118            &underlying_children[6],
119        ))),
120        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
121            &underlying_children[7],
122        ))),
123        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
124            &underlying_children[8],
125        ))),
126        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
127            &underlying_children[9],
128        ))),
129        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
130            &underlying_children[10],
131        ))),
132        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
133            &underlying_children[11],
134        ))),
135        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
136            &underlying_children[12],
137        ))),
138        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
139            &underlying_children[13],
140        ))),
141        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
142            &underlying_children[14],
143        ))),
144        Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
145            &underlying_children[15],
146        ))),
147    ];
148
149    N::new(Node::Branch {
150        children,
151        value: value.to_owned(),
152    })
153}
154
155#[derive(Debug)]
156struct TrackedNodeInfo<N: PartialTrie> {
157    underlying_node: N,
158    touched: bool,
159}
160
161impl<N: PartialTrie> TrackedNodeInfo<N> {
162    fn new(underlying_node: N) -> Self {
163        Self {
164            underlying_node,
165            touched: false,
166        }
167    }
168
169    fn reset(&mut self) {
170        self.touched = false;
171    }
172
173    fn get_nibbles_expected(&self) -> &Nibbles {
174        match &*self.underlying_node {
175            Node::Extension { nibbles, .. } => nibbles,
176            Node::Leaf { nibbles, .. } => nibbles,
177            _ => unreachable!(
178                "Tried getting the nibbles field from a {} node!",
179                TrieNodeType::from(&*self.underlying_node)
180            ),
181        }
182    }
183
184    fn get_hash_node_hash_expected(&self) -> H256 {
185        match *self.underlying_node {
186            Node::Hash(h) => h,
187            _ => unreachable!("Expected an underlying hash node!"),
188        }
189    }
190
191    fn get_branch_value_expected(&self) -> &Vec<u8> {
192        match &*self.underlying_node {
193            Node::Branch { value, .. } => value,
194            _ => unreachable!("Expected an underlying branch node!"),
195        }
196    }
197
198    fn get_leaf_nibbles_and_value_expected(&self) -> (&Nibbles, &Vec<u8>) {
199        match &*self.underlying_node {
200            Node::Leaf { nibbles, value } => (nibbles, value),
201            _ => unreachable!("Expected an underlying leaf node!"),
202        }
203    }
204}
205
206/// Create a [`PartialTrie`] subset from a base trie given an iterator of keys
207/// of nodes that may or may not exist in the trie. All nodes traversed by the
208/// keys will not be hashed out in the trie subset. If the key does not exist in
209/// the trie at all, this is not considered an error and will still record which
210/// nodes were visited.
211pub fn create_trie_subset<N, K, I>(trie: &N, keys_involved: I) -> SubsetTrieResult<N>
212where
213    N: PartialTrie,
214    K: Into<Nibbles>,
215    I: IntoIterator<Item = K>,
216{
217    let mut tracked_trie = TrackedNode::new(trie);
218    create_trie_subset_intern(&mut tracked_trie, keys_involved.into_iter())
219}
220
221/// Create [`PartialTrie`] subsets from a given base `PartialTrie` given a
222/// iterator of keys per subset needed. See [`create_trie_subset`] for more
223/// info.
224pub fn create_trie_subsets<N, K, I, O>(base_trie: &N, keys_involved: O) -> SubsetTrieResult<Vec<N>>
225where
226    N: PartialTrie,
227    K: Into<Nibbles>,
228    I: IntoIterator<Item = K>,
229    O: IntoIterator<Item = I>,
230{
231    let mut tracked_trie = TrackedNode::new(base_trie);
232
233    keys_involved
234        .into_iter()
235        .map(|ks| {
236            let res = create_trie_subset_intern(&mut tracked_trie, ks.into_iter())?;
237            reset_tracked_trie_state(&mut tracked_trie);
238
239            Ok(res)
240        })
241        .collect::<SubsetTrieResult<_>>()
242}
243
244fn create_trie_subset_intern<N, K>(
245    tracked_trie: &mut TrackedNode<N>,
246    keys_involved: impl Iterator<Item = K>,
247) -> SubsetTrieResult<N>
248where
249    N: PartialTrie,
250    K: Into<Nibbles>,
251{
252    for k in keys_involved {
253        mark_nodes_that_are_needed(tracked_trie, &mut k.into())?;
254    }
255
256    Ok(create_partial_trie_subset_from_tracked_trie(tracked_trie))
257}
258
259fn mark_nodes_that_are_needed<N: PartialTrie>(
260    trie: &mut TrackedNode<N>,
261    curr_nibbles: &mut Nibbles,
262) -> SubsetTrieResult<()> {
263    trie.info.touched = true;
264
265    match &mut trie.node {
266        TrackedNodeIntern::Empty => Ok(()),
267        TrackedNodeIntern::Hash => match curr_nibbles.is_empty() {
268            false => Err(SubsetTrieError::UnexpectedKey(
269                *curr_nibbles,
270                format!("{:?}", trie),
271            )),
272            true => Ok(()),
273        },
274        // Note: If we end up supporting non-fixed sized keys, then we need to also check value.
275        TrackedNodeIntern::Branch(children) => {
276            // Check against branch value.
277            if curr_nibbles.is_empty() {
278                return Ok(());
279            }
280
281            let nib = curr_nibbles.pop_next_nibble_front();
282            mark_nodes_that_are_needed(&mut children[nib as usize], curr_nibbles)
283        }
284        TrackedNodeIntern::Extension(child) => {
285            let nibbles = trie.info.get_nibbles_expected();
286            let r = curr_nibbles.pop_nibbles_front(nibbles.count);
287
288            match r.nibbles_are_identical_up_to_smallest_count(nibbles) {
289                false => Ok(()),
290                true => mark_nodes_that_are_needed(child, curr_nibbles),
291            }
292        }
293        TrackedNodeIntern::Leaf => Ok(()),
294    }
295}
296
297fn create_partial_trie_subset_from_tracked_trie<N: PartialTrie>(
298    tracked_node: &TrackedNode<N>,
299) -> N {
300    match tracked_node.info.touched {
301        false => N::new(Node::Hash(tracked_node.info.underlying_node.hash())),
302        true => match &tracked_node.node {
303            TrackedNodeIntern::Empty => N::new(Node::Empty),
304            TrackedNodeIntern::Hash => {
305                N::new(Node::Hash(tracked_node.info.get_hash_node_hash_expected()))
306            }
307            TrackedNodeIntern::Branch(children) => {
308                partial_trie_branch(children, tracked_node.info.get_branch_value_expected())
309            }
310            TrackedNodeIntern::Extension(child) => {
311                partial_trie_extension(*tracked_node.info.get_nibbles_expected(), child)
312            }
313            TrackedNodeIntern::Leaf => {
314                let (nibbles, value) = tracked_node.info.get_leaf_nibbles_and_value_expected();
315                N::new(Node::Leaf {
316                    nibbles: *nibbles,
317                    value: value.clone(),
318                })
319            }
320        },
321    }
322}
323
324fn reset_tracked_trie_state<N: PartialTrie>(tracked_node: &mut TrackedNode<N>) {
325    match tracked_node.node {
326        TrackedNodeIntern::Branch(ref mut children) => {
327            children.iter_mut().for_each(|c| c.info.reset())
328        }
329        TrackedNodeIntern::Extension(ref mut child) => child.info.reset(),
330        TrackedNodeIntern::Empty | TrackedNodeIntern::Hash | TrackedNodeIntern::Leaf => {
331            tracked_node.info.reset()
332        }
333    }
334}
335
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, iter::once};

    use ethereum_types::H256;

    use super::{create_trie_subset, create_trie_subsets};
    use crate::{
        nibbles::Nibbles,
        partial_trie::{HashedPartialTrie, Node, PartialTrie},
        testing_utils::generate_n_random_fixed_trie_entries,
        trie_ops::ValOrHash,
        utils::TrieNodeType,
    };

    type TrieType = HashedPartialTrie;

    const MASSIVE_TEST_NUM_SUB_TRIES: usize = 10;
    const MASSIVE_TEST_NUM_SUB_TRIE_SIZE: usize = 5000;

    /// A node's type paired with the nibbles recorded for it by the
    /// traversal helpers below.
    #[derive(Debug, Eq, PartialEq)]
    struct NodeFullNibbles {
        n_type: TrieNodeType,
        nibbles: Nibbles,
    }

    impl NodeFullNibbles {
        fn new_from_node<N: PartialTrie>(node: &Node<N>, nibbles: Nibbles) -> Self {
            let n_type = node.into();

            Self { n_type, nibbles }
        }

        fn new_from_node_type<K: Into<Nibbles>>(n_type: TrieNodeType, nibbles: K) -> Self {
            let nibbles = nibbles.into();

            Self { n_type, nibbles }
        }
    }

    /// Collects an entry for every node of `trie` that is neither empty nor a
    /// hash node.
    fn get_all_non_empty_and_hash_nodes_in_trie(trie: &TrieType) -> Vec<NodeFullNibbles> {
        let mut collected = Vec::new();
        get_all_non_empty_and_hash_nodes_in_trie_intern(
            trie,
            Nibbles::default(),
            &mut collected,
        );

        collected
    }

    fn get_all_non_empty_and_hash_nodes_in_trie_intern(
        trie: &TrieType,
        mut curr_nibbles: Nibbles,
        nodes: &mut Vec<NodeFullNibbles>,
    ) {
        match &trie.node {
            // Empty and hash nodes are not recorded.
            Node::Empty | Node::Hash(_) => return,
            Node::Branch { children, .. } => {
                for (slot, child) in children.iter().enumerate() {
                    let child_nibbles = curr_nibbles.merge_nibble(slot as u8);
                    get_all_non_empty_and_hash_nodes_in_trie_intern(child, child_nibbles, nodes)
                }
            }
            Node::Extension { nibbles, child } => {
                let child_nibbles = curr_nibbles.merge_nibbles(nibbles);
                get_all_non_empty_and_hash_nodes_in_trie_intern(child, child_nibbles, nodes)
            }
            Node::Leaf { nibbles, .. } => curr_nibbles = curr_nibbles.merge_nibbles(nibbles),
        };

        let entry = NodeFullNibbles::new_from_node(trie, curr_nibbles.reverse());
        nodes.push(entry);
    }

    /// Keys of all entries of `trie` that hold a value (as opposed to a hash).
    fn get_all_nibbles_of_leaf_nodes_in_trie(trie: &TrieType) -> HashSet<Nibbles> {
        trie.items()
            .filter(|(_, v_or_h)| matches!(v_or_h, ValOrHash::Val(_)))
            .map(|(key, _)| key)
            .collect()
    }

    fn assert_node_exists<K: Into<Nibbles>>(
        nodes: &[NodeFullNibbles],
        n_type: TrieNodeType,
        nibbles: K,
    ) {
        let expected = NodeFullNibbles::new_from_node_type(n_type, nibbles.into().reverse());

        assert!(nodes.contains(&expected));
    }

    #[test]
    fn empty_trie_does_not_return_err_on_query() {
        let empty_trie = TrieType::default();
        let key: Nibbles = 0x1234.into();

        assert!(create_trie_subset(&empty_trie, once(key)).is_ok());
    }

    #[test]
    fn non_existent_key_does_not_return_err() {
        let mut trie = TrieType::default();
        trie.insert(0x1234, vec![0, 1, 2]);

        assert!(create_trie_subset(&trie, once(0x5678)).is_ok());
    }

    #[test]
    fn encountering_a_hash_node_returns_err() {
        let hashed_out = HashedPartialTrie::new(Node::Hash(H256::zero()));

        assert!(create_trie_subset(&hashed_out, once(0x1234)).is_err())
    }

    #[test]
    fn single_node_trie_is_queryable() {
        let mut trie = TrieType::default();
        trie.insert(0x1234, vec![0, 1, 2]);

        let subset = create_trie_subset(&trie, once(0x1234)).unwrap();

        assert_eq!(trie, subset);
    }

    #[test]
    fn multi_node_trie_returns_proper_subset() {
        let mut trie = TrieType::default();
        trie.insert(0x1234, vec![0]);
        trie.insert(0x56, vec![1]);
        trie.insert(0x12345, vec![2]);

        let subset = create_trie_subset(&trie, vec![0x1234, 0x56].into_iter()).unwrap();
        let subset_leaf_keys = get_all_nibbles_of_leaf_nodes_in_trie(&subset);

        assert!(subset_leaf_keys.contains(&(Nibbles::from(0x1234))));
        assert!(subset_leaf_keys.contains(&(Nibbles::from(0x56))));
        assert!(!subset_leaf_keys.contains(&Nibbles::from(0x12345)));
    }

    #[test]
    fn intermediate_nodes_are_included_in_subset() {
        let inserts: Vec<(Nibbles, Vec<u8>)> = vec![
            (0x1234_u64.into(), vec![0]),
            (0x1324_u64.into(), vec![1]),
            (0x132400005_u64.into(), vec![2]),
            (0x2001_u64.into(), vec![3]),
            (0x2002_u64.into(), vec![4]),
        ];

        // Node layout checked by the assertions below:
        //
        // Branch    (0x)
        //   Branch    (0x1)
        //     Leaf      (0x1234)
        //     Extension (0x13)
        //       Branch    (0x1324)
        //         Leaf      (0x132400005)
        //   Extension (0x2)
        //     Branch    (0x200)
        //       Leaf      (0x2001)
        //       Leaf      (0x2002)

        let mut trie = TrieType::default();
        for (key, value) in inserts.iter() {
            trie.insert(*key, value.clone());
        }

        // Subset containing every key: all leaves and all intermediate nodes
        // must be present.
        let all_keys: Vec<_> = inserts.iter().map(|(key, _)| key).cloned().collect();
        let full_subset = create_trie_subset(&trie, all_keys.iter().cloned()).unwrap();

        let full_subset_keys = get_all_nibbles_of_leaf_nodes_in_trie(&full_subset);
        assert!(full_subset_keys.iter().all(|k| all_keys.contains(k)));
        assert!(all_keys.iter().all(|k| full_subset_keys.contains(k)));

        let all_nodes = get_all_non_empty_and_hash_nodes_in_trie(&full_subset);

        assert_node_exists(&all_nodes, TrieNodeType::Branch, Nibbles::default());
        assert_node_exists(&all_nodes, TrieNodeType::Branch, 0x1);
        assert_node_exists(&all_nodes, TrieNodeType::Leaf, 0x1234);

        assert_node_exists(&all_nodes, TrieNodeType::Extension, 0x13);
        assert_node_exists(&all_nodes, TrieNodeType::Branch, 0x1324);
        assert_node_exists(&all_nodes, TrieNodeType::Leaf, 0x132400005_u64);

        assert_node_exists(&all_nodes, TrieNodeType::Extension, 0x2);
        assert_node_exists(&all_nodes, TrieNodeType::Branch, 0x200);
        assert_node_exists(&all_nodes, TrieNodeType::Leaf, 0x2001);
        assert_node_exists(&all_nodes, TrieNodeType::Leaf, 0x2002);

        assert_eq!(all_nodes.len(), 10);

        // Now actual subset tests.
        let subset_2001 = create_trie_subset(&trie, once(0x2001)).unwrap();
        let nodes_2001 = get_all_non_empty_and_hash_nodes_in_trie(&subset_2001);

        assert_node_exists(&nodes_2001, TrieNodeType::Branch, Nibbles::default());
        assert_node_exists(&nodes_2001, TrieNodeType::Extension, 0x2);
        assert_node_exists(&nodes_2001, TrieNodeType::Branch, 0x200);
        assert_node_exists(&nodes_2001, TrieNodeType::Leaf, 0x2001);
        assert_eq!(nodes_2001.len(), 4);

        let subset_1324 = create_trie_subset(&trie, once(0x1324)).unwrap();
        let nodes_1324 = get_all_non_empty_and_hash_nodes_in_trie(&subset_1324);

        assert_node_exists(&nodes_1324, TrieNodeType::Branch, Nibbles::default());
        assert_node_exists(&nodes_1324, TrieNodeType::Branch, 0x1);
        assert_node_exists(&nodes_1324, TrieNodeType::Extension, 0x13);
        assert_node_exists(&nodes_1324, TrieNodeType::Branch, 0x1324);
        assert_eq!(nodes_1324.len(), 4);
    }

    #[test]
    fn all_leafs_of_keys_to_create_subset_are_included_in_subset_for_giant_trie() {
        let trie_size = MASSIVE_TEST_NUM_SUB_TRIES * MASSIVE_TEST_NUM_SUB_TRIE_SIZE;

        let random_entries: Vec<_> =
            generate_n_random_fixed_trie_entries(trie_size, 9009).collect();
        let entry_keys: Vec<_> = random_entries.iter().map(|(k, _)| k.clone()).collect();
        let trie = TrieType::from_iter(random_entries);

        // Split the keys into consecutive, equally sized groups; one subset is
        // requested per group.
        let keys_of_subsets: Vec<Vec<_>> = (0..MASSIVE_TEST_NUM_SUB_TRIES)
            .map(|subset_idx| {
                let start = subset_idx * MASSIVE_TEST_NUM_SUB_TRIE_SIZE;
                let end = start + MASSIVE_TEST_NUM_SUB_TRIE_SIZE;
                entry_keys[start..end].to_vec()
            })
            .collect();

        let trie_subsets =
            create_trie_subsets(&trie, keys_of_subsets.iter().map(|v| v.iter().cloned())).unwrap();

        for (sub_trie, keys_used) in trie_subsets.iter().zip(keys_of_subsets.iter()) {
            let leaf_nibbles = get_all_nibbles_of_leaf_nodes_in_trie(sub_trie);
            assert!(keys_used.iter().all(|k| leaf_nibbles.contains(k)));
        }
    }
}